Source code for agents.rules.behaviour_templates

  1import random
  2from functools import partial, update_wrapper
  3from typing import List, Optional
  4
  5import carla
  6from omegaconf._impl import select_node  # noqa: PLC2701
  7
  8from classes.constants import READTHEDOCS, Phase
  9from classes.rule import ConditionFunction, Context, Rule, always_execute
 10
# Developer switch for the debug rules; see the conditional import of ``._debug_rules`` at the end of this module.
_use_debug_rules = True  # TODO: Turn off again #XXX
# True only if READTHEDOCS is falsy and the developer switch above is set.
DEBUG_RULES: bool = not READTHEDOCS and _use_debug_rules
 13
# TODO: Maybe create an OmegaConf dict builder that makes it easier to create settings,
# e.g. `CreateOverwriteDict.speed.max_speed = 60` would yield the matching sub-dict.
# QUESTION: How to merge more than one entry?
 17
 18# ------ Rule Helpers ------
 19
 20
 21def _if_config_checker(ctx: "Context", config_path: str, value) -> bool:
 22    """
 23    Check if a value in the config is set to a certain value.
 24    """
 25    return select_node(ctx.config, config_path, absolute_key=True) == value
 26
 27
def if_config(config_path, value):
    """
    Build a :class:`ConditionFunction` that tests a config entry for equality with ``value``.

    ``_if_config_checker`` is bound to ``config_path`` and ``value`` through
    :func:`functools.partial`, so the resulting condition only needs the context.
    """
    checker = update_wrapper(
        partial(_if_config_checker, config_path=config_path, value=value),
        _if_config_checker,
    )
    return ConditionFunction(
        checker,
        name=f"Checks if {config_path} is {value}",
        use_self=False,  # NOTE: Has to be used as _if_config_checker has > 1 argument and no self usage.
    )
39 40 41# --- 42 43# Make random based on probabilistic config 44 45 46# ------ Speed Rules ------ 47 48
def set_default_intersection_speed(ctx: "Context"):
    """
    Lower the agent's target speed for turning at a junction.

    The new target is the smaller of ``speed.max_speed`` and the current speed limit
    reduced by ``speed.intersection_speed_decrease``; both are read from ``ctx.config``
    and the result is written to ``ctx.agent.config.speed.target_speed``.
    """
    speed_cfg = ctx.config.speed
    reduced_limit = ctx.config.live_info.current_speed_limit - speed_cfg.intersection_speed_decrease
    # NOTE: could interpolate this in omega conf
    ctx.agent.config.speed.target_speed = min(speed_cfg.max_speed, reduced_limit)
59 60
class SlowDownAtIntersectionRule(Rule):
    """
    Rule that lowers the target speed when the agent turns at a junction.
    """

    # NOTE(review): the other rules in this module set `phases`; confirm that `Rule` also accepts `phase`.
    phase = Phase.TURNING_AT_JUNCTION | Phase.BEGIN
    description = "Set speed to intersection speed"
    condition = always_execute
    action = set_default_intersection_speed
    # Presumably overrides `speed.intersection_speed_decrease` for this rule - verify against `Rule`.
    overwrite_settings = {"speed": {"intersection_speed_decrease": 10}}
71 72 73# ----- 74 75
def set_default_speed(ctx: "Context"):
    """
    Set the target speed used under normal driving circumstances,
    i.e. no junctions, no obstacles, etc. detected.

    The target is the smaller of ``speed.max_speed`` and the current speed limit
    reduced by ``speed.speed_lim_dist``; both are read from ``ctx.config`` and the
    result is written to ``ctx.agent.config.speed.target_speed``.
    """
    speed_cfg = ctx.config.speed
    limit_with_margin = ctx.config.live_info.current_speed_limit - speed_cfg.speed_lim_dist
    ctx.agent.config.speed.target_speed = min(speed_cfg.max_speed, limit_with_margin)
88 89
class NormalSpeedRule(Rule):
    """
    Rule that applies the normal driving speed,
    i.e. the speed used when no junctions, no obstacles, etc. are detected.
    """

    description = "Set speed to normal speed"
    phases = Phase.TAKE_NORMAL_STEP | Phase.BEGIN  # type: ignore[assignment]
    action = set_default_speed
    condition = always_execute
100 101 102# ----------- Plan next waypoint ----------- 103 104
def random_spawnpoint_destination(ctx: "Context", waypoints: Optional[List[carla.Waypoint]] = None):
    """
    Pick a random location and make it the agent's next destination.

    If ``waypoints`` is given, the location of a randomly chosen waypoint is used;
    otherwise a random spawn point of the map returned by ``ctx.get_map()`` is taken.
    A message is printed and a HUD notification is shown before the choice is made.
    """
    print("The target has been reached, searching for another target")
    ctx.agent._world_model.hud.notification("Target reached", seconds=4.0)
    if waypoints is not None:
        destination = random.choice(waypoints).transform.location
    else:
        spawn_points = ctx.get_map().get_spawn_points()
        destination = random.choice(spawn_points).location
    ctx.agent.set_destination(destination)
117 118
@ConditionFunction
def is_agent_done(ctx: Context) -> bool:
    """
    Report whether the agent has reached its destination, as given by ``ctx.agent.done()``.
    """
    agent = ctx.agent
    return agent.done()
125 126
class TargetRandomSpawnpointWhenDone(Rule):
    """
    Rule that sets a random destination once the agent is done.
    """

    description = "Sets random waypoint when done"
    phases = Phase.DONE | Phase.BEGIN  # type: ignore[assignment]
    action = random_spawnpoint_destination
    condition = is_agent_done
136 137 138# --- 139 140
def set_next_waypoint_nearby(ctx: "Context"):
    """
    Set a destination ahead of the agent's current waypoint.

    Takes the last waypoint returned by ``_current_waypoint.next(150)`` and randomly
    chooses between it and its left and right lane; if the chosen lane is ``None``
    the waypoint itself is used.
    """
    agent = ctx.agent
    agent._world_model.hud.notification("Target reached", seconds=4.0)
    # NOTE(review): indexing with [-1] assumes `next(150)` returns at least one waypoint - confirm.
    ahead = agent._current_waypoint.next(150)[-1]
    candidates = (ahead, ahead.get_left_lane(), ahead.get_right_lane())
    picked = random.choice(candidates)
    target = ahead if picked is None else picked
    agent.set_destination(target.transform.location)
150 151
class SetNextWaypointNearby(Rule):
    """Sets random waypoint when done to a nearby point ahead"""

    # Evaluated in the same phases and under the same condition as `TargetRandomSpawnpointWhenDone`.
    phases = Phase.DONE | Phase.BEGIN  # type: ignore[assignment]
    action = set_next_waypoint_nearby
    condition = is_agent_done
158 159 160# ----------- RSS Rules ----------- 161 162
def accept_rss_updates(ctx: Context):
    """
    Accept the RSS update stored in ``ctx.prior_result``.

    When a prior result is present it must be a :class:`carla.VehicleControl` and
    becomes ``ctx.control``; when it is ``None`` the context is left untouched.
    """
    update = ctx.prior_result
    if update is not None:
        assert isinstance(update, carla.VehicleControl)
        ctx.control = update
    return None
# Import-time sanity check: `if_config` has to return a `ConditionFunction`.
assert isinstance(if_config("rss.enabled", True), ConditionFunction)
class AlwaysAcceptRSSUpdates(Rule):
    """
    Rule that accepts every RSS update as long as ``rss.enabled`` is True in the config.
    """

    description = "Always accepts the updates calculated by the RSS System."
    phases = Phase.RSS_EVALUATION | Phase.END  # type: ignore[assignment]
    action = accept_rss_updates
    condition = if_config("rss.enabled", True)
186 187
class ConfigBasedRSSUpdates(Rule):
    """Always accept RSS updates if :any:`rss.always_accept_update <LunaticAgentSettings.rss>` is set to True in the config."""

    phases = Phase.RSS_EVALUATION | Phase.END  # type: ignore[assignment]
    action = accept_rss_updates
    condition = if_config("rss.always_accept_update", True)
    # description = "Accepts RSS updates depending on the value of `config.rss.always_accept_update`"


# ----------- Tests -----------

# Load the debug rules when DEBUG_RULES is set or this module is executed directly.
if DEBUG_RULES or __name__ == "__main__":
    from ._debug_rules import debug_rules as debug_rules  # mark as reexport # noqa: PLC0414