Source code for agents.rules

  1"""
  2Note:
  3    All rules must be imported into this module for hydra.instantiate to work
  4"""
  5# pylint: disable=unused-import
  6# pyright: reportUnusedImport=false
  7# ruff: noqa: F401, F403
  8
  9from typing import TYPE_CHECKING, Iterable, Optional, Union
 10
 11import hydra.errors
 12import omegaconf
 13from hydra.utils import call, instantiate
 14from omegaconf import DictConfig, OmegaConf
 15from typing_extensions import overload
 16
 17from agents.rules.behaviour_templates import (
 18    DEBUG_RULES,
 19    ConfigBasedRSSUpdates,
 20    NormalSpeedRule,
 21    SetNextWaypointNearby,
 22    SlowDownAtIntersectionRule,
 23)
 24from agents.rules.lane_changes import *  # allow to import all new rules
 25from agents.rules.lane_changes import AvoidTailgatorRule, RandomLaneChangeRule, SimpleOvertakeRule
 26from agents.rules.obstacles import *
 27from agents.rules.obstacles import DriveSlowTowardsTrafficLight, PassYellowTrafficLightRule
 28from agents.rules.stopped_long_trigger import StoppedTooLongTrigger
 29from agents.tools.logs import logger
 30from agents.tools.config_creation import LunaticAgentSettings
 31from classes.constants import Phase
 32from classes.rule import Rule, BlockingRule
 33
 34if TYPE_CHECKING:
 35    from agents.tools.config_creation import CallFunctionFromConfig, CreateRuleFromConfig, RuleCreatingParameters
 36    from classes.worldmodel import GameFramework
 37    
 38if DEBUG_RULES:
 39    from agents.rules._debug_rules import SimpleRule1, SimpleRule1B, debug_rules
 40    
 41
def create_default_rules(gameframework: Optional["GameFramework"] = None, random_lane_change: bool = False) -> "Iterable[Rule]":
    """
    Build the list of rules that is used when no custom rules are configured.

    Args:
        gameframework: Optional GameFramework instance. It is only used to decide
            whether a warning about BlockingRules has to be emitted.
        random_lane_change: If True, a :py:class:`RandomLaneChangeRule` is added.

    Returns:
        A new list with freshly created rules. If ``DEBUG_RULES`` is set, the
        debug rules are appended as well.
    """
    # The rules are created in a fixed order; the order inside the returned
    # list differs from the creation order.
    tailgator = AvoidTailgatorRule()
    overtake = SimpleOvertakeRule()

    nearby_waypoint = SetNextWaypointNearby()
    intersection_speed = SlowDownAtIntersectionRule()
    normal_speed = NormalSpeedRule()
    rss_updates = ConfigBasedRSSUpdates()

    # NOTE: DriveSlowTowardsTrafficLight(gameframework=gameframework), a blocking rule,
    # is currently not part of the defaults.

    rules: list[Rule] = [
        intersection_speed,
        normal_speed,
        tailgator,
        overtake,
        nearby_waypoint,
        rss_updates,
    ]
    if random_lane_change:
        rules += [RandomLaneChangeRule()]

    if DEBUG_RULES:
        rules += [StoppedTooLongTrigger(), SimpleRule1, SimpleRule1B, *debug_rules]

    if not gameframework:
        # Warn once if any of the rules is a BlockingRule.
        for candidate in rules:
            if isinstance(candidate, BlockingRule):
                logger.warning("A BlockingRule is in the default rules but no GameFramework instance is provided. Be sure to initialize a GameFramework later!")
                break

    return rules


@overload
def rule_from_config(cfg: "CreateRuleFromConfig") -> Rule:
    ...


@overload
def rule_from_config(cfg: "CallFunctionFromConfig | DictConfig") -> Union[Rule, Iterable[Rule]]:
    ...


def _expand_lazy_target(cfg: "DictConfig") -> None:
    """
    Expand a short ``_target_`` that names a global of this module to a full dotpath.

    Allows to write ``NormalSpeedRule`` instead of
    ``agents.rules.behaviour_templates.NormalSpeedRule``. If the target is not a
    global of this module the user has to provide the full path and *cfg* is
    left untouched.
    """
    target = cfg._target_  # pyright: ignore[reportPrivateUsage]
    if target not in globals():
        return
    # Globals such as imported modules have no __module__ attribute; leave the
    # target as it is so the failure is raised by the instantiation step
    # instead of as an AttributeError here.
    module_name = getattr(globals()[target], "__module__", None)
    if module_name is None:
        return
    # NOTE: Could use the class found in globals() for a more direct way compared
    # to the dummy parent handling in _attach_dummy_parent.
    cfg._target_ = module_name + "." + target  # pyright: ignore[reportPrivateUsage]


def _normalize_phases(cfg: "DictConfig") -> None:
    """
    Convert ``phases`` given as string(s), e.g. from yaml, to :py:class:`Phase` in place.

    Nothing is done if ``phases`` is absent, MISSING or already a Phase; in the
    first two cases the target refers to a function or the instantiation of
    the Rule will throw an error later.
    """
    if "phases" not in cfg or OmegaConf.is_missing(cfg, "phases"):
        return
    phases = cfg.phases  # pyright: ignore[reportAttributeAccessIssue]
    if isinstance(phases, Phase):
        return
    if isinstance(phases, str):
        cfg.phases = Phase.from_string(phases)  # pyright: ignore[reportAttributeAccessIssue]
    elif isinstance(phases, Iterable):
        cfg.phases = [Phase.from_string(phase) if isinstance(phase, str) else phase for phase in phases]  # pyright: ignore[reportAttributeAccessIssue]


def _attach_dummy_parent(clean_cfg: "DictConfig") -> None:
    """
    Set a dummy :py:class:`LunaticAgentSettings` as parent of *clean_cfg* and its ``self_config``.

    HACK:
        If the values of *clean_cfg* cannot be resolved the instantiation fails,
        as OmegaConf wants to resolve the values but there is no parent.
        A dummy config is set as parent to work around this.

    TODO: should also get rid of ALL missing values
        Alternatively could escape all interpolations as strings and recreate the
        interpolations afterwards, however, need to assume that all interpolation
        like strings are meant as interpolations.
    """
    parent: LunaticAgentSettings = OmegaConf.structured(LunaticAgentSettings(rules=[]),
                                                        flags={"allow_objects": True})
    for key in parent.live_info.keys():  # noqa: SIM118,RUF100
        if key in ("executed_direction", "incoming_direction"):
            parent.live_info[key] = "VOID"
        else:
            try:
                parent.live_info[key] = 0
            except Exception:
                # Best effort only; fields that do not accept 0 keep their value.
                logger.debug("Could not set %s to 0", key)
    clean_cfg.self_config._set_parent(parent)
    clean_cfg._set_parent(parent)  # pyright: ignore[reportPrivateUsage]
    parent["self"] = clean_cfg.self_config
    # NOTE: If this still fails, can go over the rule class directly if found;
    # which might be better/easier than this hack


def rule_from_config(cfg: "CallFunctionFromConfig | DictConfig | CreateRuleFromConfig") -> Union[Rule, Iterable[Rule]]:
    """
    Instantiates Rules through Hydra's instantiate function.

    Note:
        The _target_ interface also allows to call functions, e.g. :py:func:`create_default_rules`,
        hence you need to check if the return value is a Rule or an Iterable[Rule]

    Args:
        cfg: Configuration with a ``_target_`` key. A plain dict is converted to a
            DictConfig first. A DictConfig that is passed in is modified in place:
            a short ``_target_`` is expanded to its full dotpath and ``phases``
            given as strings are converted to :py:class:`Phase`.

    Returns:
        Rule or Iterable[Rule]

    Raises:
        hydra.errors.InstantiationException: If the target could not be instantiated;
            the error is logged and re-raised.

    See Also:
        - :py:class:`agents.tools.config_creation.CreateRuleFromConfig`
        - :py:class:`agents.tools.config_creation.CallFunctionFromConfig`
    """
    if isinstance(cfg, dict):
        cfg = OmegaConf.create(cfg, flags={"allow_objects": True})

    # Lazy dotpath from globals
    _expand_lazy_target(cfg)

    # Fix phase as string from yaml
    _normalize_phases(cfg)

    # Throw out all keys that are not valid for the target, i.e. MISSING
    valid_keys = [k for k in cfg if not OmegaConf.is_missing(cfg, k)]  # _target_ is kept for instantiate
    clean_cfg: RuleCreatingParameters = OmegaConf.masked_copy(cfg, valid_keys)  # pyright: ignore[reportArgumentType]

    if "_args_" in clean_cfg and clean_cfg._args_ is None:
        logger.error("_args_ argument for %s should be a list, not None", cfg._target_)  # pyright: ignore[reportPrivateUsage]
        clean_cfg._args_ = []

    # Note: call is an alias for instantiate
    try:
        if "_args_" in clean_cfg:
            rule_or_rules: Union[Rule, Iterable[Rule]] = call(clean_cfg, _convert_="none")
            return rule_or_rules

        if "self_config" in clean_cfg:
            try:
                # Test. If this fails, then the instantiation will also fail -> fix it
                OmegaConf.to_container(clean_cfg, resolve=True, throw_on_missing=True)
            except (omegaconf.MissingMandatoryValue, omegaconf.errors.InterpolationKeyError):
                logger.debug("Could not resolve all values for %s, will set up a dummy parent", cfg._target_)
                _attach_dummy_parent(clean_cfg)

        rule: Rule = instantiate(clean_cfg, _convert_="none")
        if "self_config" in clean_cfg:
            rule.self_config.merge_with(clean_cfg.self_config)  # Interpolations are resolved, adding them back as strings
        return rule
    except hydra.errors.InstantiationException:
        logger.error("Could not instantiate rule. The _target_ must exist in %s or you need to provide a global _target_.module.submodule... path ", __file__)
        raise


# Add rules to extracted schema
# NOTE(review): presumably this import sits at the end of the module so that all
# rules above are defined before the schemas are exported -- confirm.
import agents.tools.config_creation as __config_creation  # noqa
# The export is skipped if READTHEDOCS is set.
if not __config_creation.READTHEDOCS:
    try:
        __config_creation.export_schemas(detailed_rules=True)
    except Exception:
        # A failed export is only logged (with traceback) and does not abort the import of this module.
        logger.exception("Error exporting schemas with rules")
# Remove the temporary alias from the module namespace again.
del __config_creation