1"""
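Note:
    All rules should be imported into this module for hydra.instantiate to work:
    :py:func:`rule_from_config` expands short ``_target_`` names (e.g. ``NormalSpeedRule``)
    by looking them up in this module's globals.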
4"""
5# pylint: disable=unused-import
6# pyright: reportUnusedImport=false
7# ruff: noqa: F401, F403
8
9from typing import TYPE_CHECKING, Iterable, Optional, Union
10
11import hydra.errors
12import omegaconf
13from hydra.utils import call, instantiate
14from omegaconf import DictConfig, OmegaConf
15from typing_extensions import overload
16
17from agents.rules.behaviour_templates import (
18 DEBUG_RULES,
19 ConfigBasedRSSUpdates,
20 NormalSpeedRule,
21 SetNextWaypointNearby,
22 SlowDownAtIntersectionRule,
23)
24from agents.rules.lane_changes import * # allow to import all new rules
25from agents.rules.lane_changes import AvoidTailgatorRule, RandomLaneChangeRule, SimpleOvertakeRule
26from agents.rules.obstacles import *
27from agents.rules.obstacles import DriveSlowTowardsTrafficLight, PassYellowTrafficLightRule
28from agents.rules.stopped_long_trigger import StoppedTooLongTrigger
29from agents.tools.logs import logger
30from agents.tools.config_creation import LunaticAgentSettings
31from classes.constants import Phase
32from classes.rule import Rule, BlockingRule
33
34if TYPE_CHECKING:
35 from agents.tools.config_creation import CallFunctionFromConfig, CreateRuleFromConfig, RuleCreatingParameters
36 from classes.worldmodel import GameFramework
37
38if DEBUG_RULES:
39 from agents.rules._debug_rules import SimpleRule1, SimpleRule1B, debug_rules
40
41
[docs]
def create_default_rules(gameframework: Optional["GameFramework"] = None, random_lane_change: bool = False) -> "Iterable[Rule]":
    """
    Create a fresh set of the default rules.

    Args:
        gameframework: Optional GameFramework instance. Currently it is only consulted to emit a
            warning when a :py:class:`BlockingRule` ends up in the defaults while no GameFramework
            is available.
        random_lane_change: If True, a :py:class:`RandomLaneChangeRule` is added as well.

    Returns:
        A list of rules. When ``DEBUG_RULES`` is enabled, a :py:class:`StoppedTooLongTrigger`,
        ``SimpleRule1``, ``SimpleRule1B`` and the entries of ``debug_rules`` are appended.
    """
    # Instantiate the standard rules; the construction order is kept in case constructors have side effects.
    avoid_tailgator = AvoidTailgatorRule()
    overtake = SimpleOvertakeRule()
    waypoint_when_done = SetNextWaypointNearby()
    intersection_speed = SlowDownAtIntersectionRule()
    normal_speed = NormalSpeedRule()
    rss_updates = ConfigBasedRSSUpdates()

    # NOTE: DriveSlowTowardsTrafficLight(gameframework=gameframework), a BlockingRule, is currently disabled.

    rules: list[Rule] = [
        intersection_speed,
        normal_speed,
        avoid_tailgator,
        overtake,
        waypoint_when_done,
        rss_updates,
    ]
    if random_lane_change:
        rules.append(RandomLaneChangeRule())

    if DEBUG_RULES:
        rules += [StoppedTooLongTrigger(), SimpleRule1, SimpleRule1B, *debug_rules]

    # BlockingRules need a GameFramework; warn early if none was handed in.
    missing_gameframework = not gameframework
    if missing_gameframework and any(isinstance(rule, BlockingRule) for rule in rules):
        logger.warning("A BlockingRule is in the default rules but no GameFramework instance is provided. Be sure to initialize a GameFramework later!")

    return rules
67
68
# Typing-only signature: a CreateRuleFromConfig is typed to produce exactly one Rule.
@overload
def rule_from_config(cfg: "CreateRuleFromConfig") -> Rule:
    ...


# Typing-only signature: a function call config or a plain DictConfig may produce one Rule or several.
@overload
def rule_from_config(cfg: "CallFunctionFromConfig | DictConfig") -> Union[Rule, Iterable[Rule]]:
    ...
77
78
[docs]
def rule_from_config(cfg: "CallFunctionFromConfig | DictConfig | CreateRuleFromConfig") -> Union[Rule, Iterable[Rule]]:
    """
    Instantiates Rules through Hydra's instantiate function.

    Before Hydra is invoked, the config is prepared:

    - A plain ``dict`` is converted to a :py:class:`omegaconf.DictConfig` with the ``allow_objects`` flag.
    - If ``_target_`` is a name found in this module's globals, e.g. ``NormalSpeedRule``, it is expanded
      to the full dotted path of that object. Otherwise the full path has to be provided.
    - ``phases`` given as a string, or as a list containing strings, are converted with
      :py:meth:`Phase.from_string`.
    - All keys whose value is MISSING are dropped.

    Configs without ``_args_`` are passed to :py:func:`hydra.utils.instantiate`, configs with ``_args_``
    to :py:func:`hydra.utils.call`. An ``_args_`` value of ``None`` is logged and replaced by ``[]``.

    Note:
        The _target_ interface also allows to call functions, e.g. :py:func:`create_default_rules`,
        hence you need to check if the return value is a Rule or an Iterable[Rule]

    Warning:
        If *cfg* is not a plain ``dict`` it is modified in place: ``_target_`` may be expanded
        and ``phases`` may be converted.

    Args:
        cfg: Configuration of a rule, or of a function returning rules, that contains a ``_target_``.

    Returns:
        Rule or Iterable[Rule]

    Raises:
        hydra.errors.InstantiationException: If Hydra cannot instantiate or call the target.
            The error is logged before it is re-raised.

    See Also:
        - :py:class:`agents.tools.config_creation.CreateRuleFromConfig`
        - :py:class:`agents.tools.config_creation.CallFunctionFromConfig`
    """
    if isinstance(cfg, dict):
        cfg = OmegaConf.create(cfg, flags={"allow_objects": True})

    # Lazy dotpath from globals:
    # Allow to write NormalSpeedRule instead of agents.rules.behaviour_templates.NormalSpeedRule.
    # Any other _target_ must already be a full dotted path.
    # NOTE: globals()[cfg._target_] is the target object itself and could be used more directly than
    # the dummy-parent workaround further below.
    if cfg._target_ in globals():  # pyright: ignore[reportPrivateUsage]
        cfg._target_ = globals()[cfg._target_].__module__ + "." + cfg._target_  # pyright: ignore[reportPrivateUsage]

    # Fix phase as string from yaml
    if "phases" not in cfg or OmegaConf.is_missing(cfg, "phases") or isinstance(cfg.phases, Phase):  # type: ignore[attr-defined]
        # Nothing to convert. Target refers to a function or this will throw an error when applied to a Rule
        pass
    elif isinstance(cfg.phases, str):  # pyright: ignore[reportAttributeAccessIssue]
        cfg.phases = Phase.from_string(cfg.phases)  # pyright: ignore[reportAttributeAccessIssue]
    elif isinstance(cfg.phases, Iterable):  # pyright: ignore[reportAttributeAccessIssue]
        cfg.phases = [Phase.from_string(phase) if isinstance(phase, str) else phase for phase in cfg.phases]  # pyright: ignore[reportAttributeAccessIssue, reportUnknownMemberType, reportUnknownVariableType]

    # Throw out all keys that are not valid for the target, i.e. MISSING; _target_ is kept for instantiate.
    valid_keys = [key for key in cfg if not OmegaConf.is_missing(cfg, key)]
    clean_cfg: RuleCreatingParameters = OmegaConf.masked_copy(cfg, valid_keys)  # pyright: ignore[reportArgumentType]

    if "_args_" in clean_cfg and clean_cfg._args_ is None:
        logger.error("_args_ argument for %s should be a list, not None", cfg._target_)  # pyright: ignore[reportPrivateUsage]
        clean_cfg._args_ = []

    # Note: call is an alias for instantiate
    try:
        if "_args_" not in clean_cfg:
            if "self_config" in clean_cfg:
                try:
                    # Test. If this fails, then the instantiation will also fail -> fix it
                    OmegaConf.to_container(clean_cfg, resolve=True, throw_on_missing=True)
                except (omegaconf.MissingMandatoryValue, omegaconf.errors.InterpolationKeyError):
                    logger.debug("Could not resolve all values for %s, will set up a dummy parent", cfg._target_)
                    # HACK:
                    # Resolving the interpolations fails because clean_cfg has no parent,
                    # so attach a dummy LunaticAgentSettings config as parent.
                    # TODO: should also get rid of ALL missing values
                    # Alternatively could escape all interpolations as strings and recreate the interpolations
                    # afterwards, however, need to assume that all interpolation like strings are meant as
                    # interpolations.
                    parent: LunaticAgentSettings = OmegaConf.structured(LunaticAgentSettings(rules=[]),
                                                                        flags={"allow_objects": True})
                    # Fill live_info with placeholder values so interpolations into it can resolve.
                    for key in parent.live_info.keys():  # noqa: SIM118,RUF100
                        if key in ("executed_direction", "incoming_direction"):
                            parent.live_info[key] = "VOID"
                        else:
                            try:
                                parent.live_info[key] = 0
                            except Exception:
                                # Best effort: keys that do not accept 0 keep their default.
                                logger.debug("Could not set %s to 0", key)
                    clean_cfg.self_config._set_parent(parent)
                    clean_cfg._set_parent(parent)  # pyright: ignore[reportPrivateUsage]
                    parent["self"] = clean_cfg.self_config
                    # NOTE: If this still fails, can go over the target class directly (see above);
                    # which might be better/easier than this hack

            rule: Rule = instantiate(clean_cfg, _convert_="none")
            if "self_config" in clean_cfg:
                # Interpolations are resolved, adding them back as strings
                rule.self_config.merge_with(clean_cfg.self_config)
            return rule
        else:
            rule_or_rules: Union[Rule, Iterable[Rule]] = call(clean_cfg, _convert_="none")
            return rule_or_rules
    except hydra.errors.InstantiationException:
        logger.error("Could not instantiate rule. The _target_ must exist in %s or you need to provide a global _target_.module.submodule... path ", __file__)
        raise
165 logger.error("Could not instantiate rule. The _target_ must exist in %s or you need to provide a global _target_.module.submodule... path ", __file__)
166 raise
167
168
# Add rules to extracted schema.
# The config_creation module is bound to a temporary dunder name and deleted afterwards,
# presumably to keep it out of this module's namespace.
import agents.tools.config_creation as __config_creation  # noqa

if __config_creation.READTHEDOCS:
    # Schema export is skipped when READTHEDOCS is set (presumably during documentation builds).
    pass
else:
    try:
        __config_creation.export_schemas(detailed_rules=True)
    except Exception:
        # Best effort: a failing export must not break importing this module.
        logger.exception("Error exporting schemas with rules")

del __config_creation