Source code for agents.rules.behaviour_templates

  1import random
  2from functools import partial, update_wrapper
  3from typing import List, Optional
  4
  5import carla
  6from omegaconf._impl import select_node  # noqa: PLC2701
  7
  8from classes.constants import READTHEDOCS, Phase
  9from classes.rule import ConditionFunction, Context, Rule, always_execute
 10
 11_use_debug_rules = True  # TODO: Turn off again #XXX
 12DEBUG_RULES: bool = not READTHEDOCS and _use_debug_rules
 13
 14#TODO: maybe create some omega conf dict creator that allows to create settings more easily
 15# e.g. CreateOverwriteDict.speed.max_speed = 60, yields such a subdict.
 16# QUESTION: How to merge more than one entry?
 17
 18# ------ Rule Helpers ------
 19
 20
 21def _if_config_checker(ctx: "Context", config_path: str, value) -> bool:
 22    """
 23    Check if a value in the config is set to a certain value.
 24    """
 25    return select_node(ctx.config, config_path, absolute_key=True) == value
 26
 27
def if_config(config_path, value):
    """
    Build a condition that is true when the config entry at *config_path* equals *value*.

    The returned :class:`ConditionFunction` wraps :func:`_if_config_checker` with both
    arguments bound, and carries the checker's metadata via ``update_wrapper``.
    """
    bound_checker = update_wrapper(
        partial(_if_config_checker, config_path=config_path, value=value),
        _if_config_checker,
    )
    return ConditionFunction(
        bound_checker,
        name=f"Checks if {config_path} is {value}",
        # NOTE: Has to be used as _if_config_checker has > 1 argument and no self usage.
        use_self=False,
    )
38 39# --- 40 41# Make random based on probabilistic config 42 43 44# ------ Speed Rules ------ 45
def set_default_intersection_speed(ctx: "Context"):
    """
    Slow down the car when turning at a junction.

    The target speed is the smaller of ``speed.max_speed`` and the current speed
    limit reduced by ``speed.intersection_speed_decrease`` (both read from
    ``ctx.config``); it is written to ``ctx.agent.config.speed.target_speed``.
    """
    speed_settings = ctx.config.speed
    reduced_limit = ctx.config.live_info.current_speed_limit - speed_settings.intersection_speed_decrease
    # NOTE: could interpolate this in omega conf
    ctx.agent.config.speed.target_speed = min(speed_settings.max_speed, reduced_limit)
56 57
class SlowDownAtIntersectionRule(Rule):
    """
    Slow down the car when turning at a junction.

    Active in the ``TURNING_AT_JUNCTION | BEGIN`` phase, always executes, and calls
    :func:`set_default_intersection_speed`. Its ``overwrite_settings`` set
    ``speed.intersection_speed_decrease`` to 10.
    """
    # Spelled ``phases`` like on every other rule in this module; the former
    # ``phase`` spelling did not match that attribute name.
    phases = Phase.TURNING_AT_JUNCTION | Phase.BEGIN  # type: ignore[assignment]
    condition = always_execute
    action = set_default_intersection_speed
    overwrite_settings = {"speed": {"intersection_speed_decrease": 10}}
    description = "Set speed to intersection speed"
67 68 69# ----- 70
def set_default_speed(ctx: "Context"):
    """
    Speed to apply when the car drives under normal circumstances,
    i.e. no junctions, no obstacles, etc. detected.

    Uses the smaller of ``speed.max_speed`` and the current speed limit minus
    ``speed.speed_lim_dist`` (read from ``ctx.config``) and stores it in
    ``ctx.agent.config.speed.target_speed``.
    """
    cfg = ctx.config
    limit_with_margin = cfg.live_info.current_speed_limit - cfg.speed.speed_lim_dist
    # Write the result to the agent's own config
    ctx.agent.config.speed.target_speed = min(cfg.speed.max_speed, limit_with_margin)
82 83
class NormalSpeedRule(Rule):
    """
    Apply the normal driving speed.

    Intended for ordinary driving, i.e. when no junction, obstacle, etc. has been
    detected. Active in the ``TAKE_NORMAL_STEP | BEGIN`` phase, always executes,
    and calls :func:`set_default_speed`.
    """
    phases = Phase.TAKE_NORMAL_STEP | Phase.BEGIN  # type: ignore[assignment]
    condition = always_execute  # no precondition
    action = set_default_speed
    description = "Set speed to normal speed"
93 94# ----------- Plan next waypoint ----------- 95 96
def random_spawnpoint_destination(ctx: "Context", waypoints: "Optional[List[carla.Waypoint]]" = None):
    """
    Set a random waypoint as the next target.

    Announces the reached target (stdout and HUD notification), then chooses the
    new destination:

    * the ``transform.location`` of a random entry of ``waypoints`` when a
      non-empty sequence is given;
    * otherwise the ``location`` of a random spawn point from
      ``ctx.get_map().get_spawn_points()``. ``None`` and an empty sequence are
      handled the same way.

    Args:
        ctx: Rule context; ``ctx.agent.set_destination`` receives the chosen location.
        waypoints: Optional candidate waypoints to pick the destination from.
    """
    print("The target has been reached, searching for another target")
    ctx.agent._world_model.hud.notification("Target reached", seconds=4.0)
    if waypoints:
        loc = random.choice(waypoints).transform.location
    else:
        # Covers None and an empty list; random.choice([]) would raise IndexError.
        transforms = ctx.get_map().get_spawn_points()
        loc = random.choice(transforms).location
    ctx.agent.set_destination(loc)
109 110
@ConditionFunction
def is_agent_done(ctx: Context) -> bool:
    """
    Condition that holds once the agent has reached its destination.

    Simply forwards the answer of ``ctx.agent.done()``.
    """
    agent = ctx.agent
    return agent.done()
117 118
class TargetRandomSpawnpointWhenDone(Rule):
    """
    Choose a random new destination once the agent is done.

    Active in the ``DONE | BEGIN`` phase; condition :func:`is_agent_done`,
    action :func:`random_spawnpoint_destination`.
    """
    phases = Phase.DONE | Phase.BEGIN  # type: ignore[assignment]
    condition = is_agent_done
    action = random_spawnpoint_destination
    description = "Sets random waypoint when done"
127 128# --- 129 130
def set_next_waypoint_nearby(ctx: "Context"):
    """
    Pick a new destination ahead of the agent's current waypoint.

    Takes the last waypoint returned by ``_current_waypoint.next(150)`` (distance
    presumably in metres, per CARLA's ``Waypoint.next``) and chooses at random
    between it and its ``get_left_lane()`` / ``get_right_lane()`` results; a
    ``None`` choice falls back to the waypoint ahead. If ``next`` returns no
    waypoint at all (e.g. a dead end), a random spawn point of the map is used
    instead of crashing.

    Args:
        ctx: Rule context; ``ctx.agent.set_destination`` receives the chosen location.
    """
    ctx.agent._world_model.hud.notification("Target reached", seconds=4.0)
    ahead = ctx.agent._current_waypoint.next(150)
    if not ahead:
        # ``ahead[-1]`` would raise IndexError here; use a random spawn point instead.
        spawn_points = ctx.get_map().get_spawn_points()
        ctx.agent.set_destination(random.choice(spawn_points).location)
        return
    wp = ahead[-1]
    next_wp = random.choice((wp, wp.get_left_lane(), wp.get_right_lane()))
    if next_wp is None:
        # The chosen neighbouring lane does not exist; stay on the waypoint ahead.
        next_wp = wp
    ctx.agent.set_destination(next_wp.transform.location)
140 141
class SetNextWaypointNearby(Rule):
    "Sets random waypoint when done to a nearby point ahead"
    # NOTE(review): unlike most rules here no ``description`` is set; the docstring
    # above is left verbatim in case Rule uses it as a fallback — verify.
    # Same phase and condition as TargetRandomSpawnpointWhenDone; only the action differs.
    phases = Phase.DONE | Phase.BEGIN  # type: ignore[assignment]
    condition = is_agent_done
    action = set_next_waypoint_nearby
147 148# ----------- RSS Rules ----------- 149 150
def accept_rss_updates(ctx: Context):
    """
    Adopt the control computed by the RSS manager, if one is available.

    Does nothing when ``ctx.prior_result`` is ``None``; otherwise the result must
    be a ``carla.VehicleControl`` and is stored as ``ctx.control``. Always
    returns ``None``.
    """
    rss_control = ctx.prior_result
    if rss_control is not None:
        assert isinstance(rss_control, carla.VehicleControl)
        ctx.control = rss_control
    return None
# Import-time sanity check: the RSS rules below use ``if_config(...)`` as their
# ``condition``, so it must produce a ConditionFunction.
assert isinstance(if_config("rss.enabled", True), ConditionFunction)
class AlwaysAcceptRSSUpdates(Rule):
    """
    Accept every RSS update while ``rss.enabled`` is True in the config.

    Active in the ``RSS_EVALUATION | END`` phase; the action is
    :func:`accept_rss_updates`.
    """
    phases = Phase.RSS_EVALUATION | Phase.END  # type: ignore[assignment]
    condition = if_config("rss.enabled", True)  # gate on the config flag
    action = accept_rss_updates
    description = "Always accepts the updates calculated by the RSS System."
173 174
class ConfigBasedRSSUpdates(Rule):
    """Always accept RSS updates if :any:`rss.always_accept_update <LunaticAgentSettings.rss>` is set to True in the config."""
    # NOTE(review): no ``description`` is set (the line at the end of the class is
    # disabled); the docstring is kept verbatim in case Rule falls back to it.
    phases = Phase.RSS_EVALUATION | Phase.END  # type: ignore[assignment]
    # Only fires when config.rss.always_accept_update equals True.
    condition = if_config("rss.always_accept_update", True)
    action = accept_rss_updates
    #description = "Accepts RSS updates depending on the value of `config.rss.always_accept_update`"


# ----------- Tests -----------

# Pull in the debug rules when this module is executed directly or DEBUG_RULES is on.
if __name__ == "__main__" or DEBUG_RULES:
    from ._debug_rules import debug_rules as debug_rules  # mark as reexport # noqa: PLC0414