"""
Attention:
    All rules should be imported into this module for :py:attr:`.LunaticAgentSettings.rules`
    to work, because :python:`hydra.instantiate` needs the rule either as a variable in this
    module or as an absolute import path.
"""
6# pylint: disable=unused-import
7# pyright: reportUnusedImport=false
8# ruff: noqa: F401, F403
9
10from typing import TYPE_CHECKING, Iterable, Optional, Union
11
12import hydra.errors
13import omegaconf
14from hydra.utils import call, instantiate
15from omegaconf import DictConfig, OmegaConf
16from typing_extensions import overload
17
18from agents.rules.behaviour_templates import (
19 DEBUG_RULES,
20 ConfigBasedRSSUpdates,
21 NormalSpeedRule,
22 SetNextWaypointNearby,
23 SlowDownAtIntersectionRule,
24)
25from agents.rules.lane_changes import * # allow to import all new rules
26from agents.rules.lane_changes import AvoidTailgatorRule, RandomLaneChangeRule, SimpleOvertakeRule
27from agents.rules.obstacles import *
28from agents.rules.obstacles import DriveSlowTowardsTrafficLight, PassYellowTrafficLightRule
29from agents.rules.stopped_long_trigger import StoppedTooLongTrigger
30from agents.tools.logs import logger
31from agents.tools.config_creation import LunaticAgentSettings
32from classes.constants import Phase
33from classes.rule import Rule, BlockingRule
34
35if TYPE_CHECKING:
36 from agents.tools.config_creation import CallFunctionFromConfig, CreateRuleFromConfig, RuleCreatingParameters
37 from classes.worldmodel import GameFramework
38
39if DEBUG_RULES:
40 from agents.rules._debug_rules import SimpleRule1, SimpleRule1B, debug_rules
41
42
def create_default_rules(
    gameframework: Optional["GameFramework"] = None, random_lane_change: bool = False
) -> "Iterable[Rule]":
    """
    Assemble the rules an agent uses when no rules are configured explicitly.

    The returned base set is, in this order: :py:class:`SlowDownAtIntersectionRule`,
    :py:class:`NormalSpeedRule`, :py:class:`AvoidTailgatorRule`, :py:class:`SimpleOvertakeRule`,
    :py:class:`SetNextWaypointNearby` and :py:class:`ConfigBasedRSSUpdates`.

    Args:
        gameframework: Only consulted for a sanity check. If it is not given while a
            :py:class:`BlockingRule` ends up in the result, a warning is logged.
        random_lane_change: Additionally append a :py:class:`RandomLaneChangeRule`.

    Returns:
        A new list of rule objects. With ``DEBUG_RULES`` enabled it also contains a
        :py:class:`StoppedTooLongTrigger`, ``SimpleRule1``, ``SimpleRule1B`` and all ``debug_rules``.
    """
    # Construct the rules in a fixed order; the list below decides the order they are returned in.
    tailgator = AvoidTailgatorRule()
    overtake = SimpleOvertakeRule()
    waypoint_when_done = SetNextWaypointNearby()
    intersection_speed = SlowDownAtIntersectionRule()
    regular_speed = NormalSpeedRule()
    rss_updates = ConfigBasedRSSUpdates()

    # Disabled: DriveSlowTowardsTrafficLight(gameframework=gameframework) -- a BlockingRule.

    rules: list[Rule] = [
        intersection_speed,
        regular_speed,
        tailgator,
        overtake,
        waypoint_when_done,
        rss_updates,
    ]

    if random_lane_change:
        rules.append(RandomLaneChangeRule())

    if DEBUG_RULES:
        rules += [StoppedTooLongTrigger(), SimpleRule1, SimpleRule1B, *debug_rules]

    # BlockingRules need a GameFramework; only scan the rules when none was passed.
    missing_framework = not gameframework and any(isinstance(r, BlockingRule) for r in rules)
    if missing_framework:
        logger.warning(
            "A BlockingRule is in the default rules but no GameFramework instance is provided. Be sure to initialize a GameFramework later!"
        )

    return rules


@overload
def rule_from_config(cfg: "CreateRuleFromConfig") -> Rule:
    """Typing overload: a ``CreateRuleFromConfig`` is typed to produce a single :py:class:`Rule`."""


@overload
def rule_from_config(cfg: "CallFunctionFromConfig | DictConfig") -> Union[Rule, Iterable[Rule]]:
    """Typing overload: a function target or a generic ``DictConfig`` may produce one rule or several."""


def rule_from_config(cfg: "CallFunctionFromConfig | DictConfig | CreateRuleFromConfig") -> Union[Rule, Iterable[Rule]]:
    """
    Instantiates Rules through Hydra's instantiate function.

    Processing steps:

    1. A plain ``dict`` is converted to a ``DictConfig`` (created with the ``allow_objects`` flag).
    2. If ``_target_`` is a name available in this module's globals, e.g. ``NormalSpeedRule``,
       it is expanded to ``<defining module>.<name>`` so Hydra can locate it.
       Otherwise ``_target_`` must already be a full dotted path.
    3. ``phases`` given as a string, or as a sequence containing strings, are converted with
       ``Phase.from_string``. Values that are already :py:class:`Phase` objects are left alone.
    4. Keys whose value is MISSING are dropped (``_target_`` is kept).
    5. A ``_args_`` value of ``None`` is logged as an error and replaced by an empty list.
    6. Without ``_args_`` the target is created with :py:func:`hydra.utils.instantiate`.
       If a ``self_config`` is present and resolving the config raises ``MissingMandatoryValue``
       or ``InterpolationKeyError``, a dummy :py:class:`LunaticAgentSettings` parent is attached
       first. When ``self_config`` is present, it is merged into ``rule.self_config`` after
       instantiation.
       With ``_args_`` the target is invoked via :py:func:`hydra.utils.call`.

    Note:
        The _target_ interface also allows to call functions, e.g. :py:func:`create_default_rules`,
        hence you need to check if the return value is a Rule or an Iterable[Rule]

    Warning:
        When a ``DictConfig`` is passed in, its ``_target_`` and ``phases`` entries are
        modified in place.

    Args:
        cfg: Configuration holding a ``_target_`` and the arguments for that target.

    Returns:
        Rule or Iterable[Rule]

    Raises:
        hydra.errors.InstantiationException: Logged and re-raised if Hydra cannot create the target.

    See Also:
        - :py:class:`agents.tools.config_creation.CreateRuleFromConfig`
        - :py:class:`agents.tools.config_creation.CallFunctionFromConfig`
    """
    if isinstance(cfg, dict):
        cfg = OmegaConf.create(cfg, flags={"allow_objects": True})

    # Lazy dotpath from globals
    # Allow to write NormalSpeedRule instead of agents.rules.behaviour_templates.NormalSpeedRule
    if cfg._target_ in globals():  # pyright: ignore[reportPrivateUsage]
        cfg._target_ = globals()[cfg._target_].__module__ + "." + cfg._target_  # pyright: ignore[reportPrivateUsage]
        # NOTE: Could use rule_class for a more direct way compared to the block with
        # except (omegaconf.MissingMandatoryValue, omegaconf.errors.InterpolationKeyError)
        # rule_class = globals()[cfg._target_]
    else:
        # NOTE(review): rule_class is never read; kept only for the alternative approach noted above.
        rule_class = None  # noqa: F841 # pyright: ignore[reportUnusedVariable]
        # Else user needs to provide the full path

    # Fix phase as string from yaml
    if "phases" not in cfg or OmegaConf.is_missing(cfg, "phases") or isinstance(cfg.phases, Phase):  # type: ignore[attr-defined]
        # Target refers to a function or this will throw an error when applied to a Rule
        pass
    elif isinstance(cfg.phases, str):  # pyright: ignore[reportAttributeAccessIssue]
        cfg.phases = Phase.from_string(cfg.phases)  # pyright: ignore[reportAttributeAccessIssue]
    elif isinstance(cfg.phases, Iterable):  # pyright: ignore[reportAttributeAccessIssue]
        # Only string entries are converted; other entries are kept as they are.
        cfg.phases = [Phase.from_string(phase) if isinstance(phase, str) else phase for phase in cfg.phases]  # pyright: ignore[reportAttributeAccessIssue, reportUnknownMemberType, reportUnknownVariableType]

    # Throw out all keys that are not valid for the target, i.e. MISSING
    valid_keys = list({k for k in cfg if not OmegaConf.is_missing(cfg, k)})  # _target_ is kept for instantiate
    clean_cfg: RuleCreatingParameters = OmegaConf.masked_copy(cfg, valid_keys)  # pyright: ignore[reportArgumentType]

    # An explicit `_args_: null` is treated as "call with no positional arguments".
    if "_args_" in clean_cfg and clean_cfg._args_ is None:
        logger.error("_args_ argument for %s should be a list, not None", cfg._target_)  # pyright: ignore[reportPrivateUsage]
        clean_cfg._args_ = []

    # Note: call is an alias for instantiate
    try:
        # No `_args_` -> treated as a Rule (class) target; with `_args_` -> treated as a function call.
        if "_args_" not in clean_cfg:
            if "self_config" in clean_cfg:
                try:
                    # Test. If this fails, then the instantiation will also fail -> fix it
                    OmegaConf.to_container(clean_cfg, resolve=True, throw_on_missing=True)
                except (omegaconf.MissingMandatoryValue, omegaconf.errors.InterpolationKeyError):
                    logger.debug("Could not resolve all values for %s, will set up a dummy parent", cfg._target_)
                    # HACK:
                    # If this fails the instantiation, OmegaConf wants to resolve the values, but as there is no parent
                    # Set a dummy config as parent
                    # TODO: should also get rid of ALL missing values
                    # Alternatively could escape all interpolations as strings and recreate the interpolations afterwards,
                    # however, need to assume that all interpolation like stings are meant as interpolations.

                    parent: LunaticAgentSettings = OmegaConf.structured(
                        LunaticAgentSettings(rules=[]), flags={"allow_objects": True}
                    )
                    # Fill live_info with placeholder values: "VOID" for the two direction keys,
                    # 0 for every other key that accepts it.
                    for key in parent.live_info.keys():  # noqa: SIM118,RUF100
                        if key in ("executed_direction", "incoming_direction"):
                            parent.live_info[key] = "VOID"
                        else:
                            try:
                                parent.live_info[key] = 0
                            except Exception:
                                # Best effort: keys that reject 0 keep their original value.
                                logger.debug("Could not set %s to 0", key)
                                continue  # NOTE(review): redundant, this is already the end of the loop body
                    # Attach the dummy parent and expose self_config as `self` on it,
                    # so interpolations in the rule config have a root to resolve against.
                    clean_cfg.self_config._set_parent(parent)
                    clean_cfg._set_parent(parent)  # pyright: ignore[reportPrivateUsage]
                    parent["self"] = clean_cfg.self_config
                    # NOTE: If this still fails, can go over the rule_class directly if found; which might be better/easier than this hack

            # _convert_="none": nested configs are passed to the target as OmegaConf containers.
            rule: Rule = instantiate(clean_cfg, _convert_="none")
            if "self_config" in clean_cfg:
                rule.self_config.merge_with(
                    clean_cfg.self_config
                )  # Interpolations are resolved, adding them back as strings
            return rule
        else:
            rule_or_rules: Union[Rule, Iterable[Rule]] = call(clean_cfg, _convert_="none")
            return rule_or_rules
    except hydra.errors.InstantiationException:
        logger.error(
            "Could not instantiate rule. The _target_ must exist in %s or you need to provide a global _target_.module.submodule... path ",
            __file__,
        )
        raise


182
183# Add rules to extracted schema
184import agents.tools.config_creation as __config_creation # noqa
185
186if not __config_creation.READTHEDOCS:
187 try:
188 __config_creation.export_schemas(detailed_rules=True)
189 except Exception:
190 logger.exception("Error exporting schemas with rules")
191del __config_creation