1"""
2Helper Tools for :py:mod:`.config_tools`.
3"""
4
5# pyright: reportPrivateUsage=false, reportUnknownLambdaType=false, reportUnusedClass=false
6# pyright: reportPossiblyUnboundVariable=information,reportAttributeAccessIssue=warning
7# pyright: reportUnknownVariableType=information, reportUnknownMemberType=information
8from __future__ import annotations
9
10import ast
11import inspect
12import io
13import logging
14import os
15from pathlib import Path
16import re
17import yaml
18
19from dataclasses import is_dataclass
20from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast, get_type_hints
21
22import carla
23import omegaconf.errors
24from hydra.core.config_store import ConfigStore
25from omegaconf import MISSING, DictConfig, ListConfig, MissingMandatoryValue, OmegaConf, SCMode, open_dict
26from omegaconf._utils import is_structured_config, get_omega_conf_dumper # noqa: PLC2701
27from typing_extensions import Protocol, TypeAlias, TypeAliasType, TypeVar
28
29from classes.constants import AD_RSS_AVAILABLE, Phase, RssLogLevelStub, RssRoadBoundariesModeStub, READTHEDOCS
30from launch_tools import ast_parse
31
32if TYPE_CHECKING:
33 from typing import type_check_only
34 from classes.type_protocols import AgentConfigT
35 from agents.tools.config_creation import (
36 AgentConfig,
37 CreateRuleFromConfig,
38 LunaticAgentSettings,
39 RuleConfig,
40 RuleCreatingParameters,
41 )
42 from classes.rule import Rule
43 from ruamel.yaml.comments import CommentedMap
44
45# ----------- Resolvers -------------
46
47
def look_ahead_time(speed: float, time_to_collision: float, plus: float = 0) -> float:
    """
    Distance in meters covered at *speed* within *time_to_collision*, plus an extra buffer.

    Used as the ``look_ahead_time`` resolver, e.g.
    :python:`"${look_ahead_time: ${live_info.current_speed}, time, }"`

    Args:
        speed: Current speed in km/h.
        time_to_collision: Time horizon in seconds.
        plus: Additional distance in meters that is added on top. Defaults to ``0``.

    Returns:
        The look-ahead distance in meters.
    """
    meters_per_second = speed / 3.6  # km/h -> m/s
    covered_distance = meters_per_second * time_to_collision  # m/s * s = m
    return covered_distance + plus
56
57
# need this check for readthedocs
if not READTHEDOCS:
    import operator
    import random

    # Registration is checked per resolver with ``OmegaConf.has_resolver``, i.e. against the registry of
    # the current interpreter. The former guard on the ``_OMEGACONF_RESOLVERS_REGISTERED`` environment
    # variable alone is inherited by child processes (e.g. spawned workers), which then skipped the
    # registration although their resolver registry was empty. The variable is still set for code that
    # may rely on it.
    for _resolver_name, _resolver in (
        ("add", operator.add),
        ("sub", operator.sub),
        ("mul", operator.mul),
        ("divide", operator.truediv),
        ("min", lambda *els: min(els)),
        ("max", lambda *els: max(els)),
        ("randint", random.randint),
        ("randuniform", random.uniform),
        ("look_ahead_time", look_ahead_time),
    ):
        if not OmegaConf.has_resolver(_resolver_name):
            OmegaConf.register_new_resolver(_resolver_name, _resolver)  # type: ignore[arg-type]
    del _resolver_name, _resolver
    os.environ["_OMEGACONF_RESOLVERS_REGISTERED"] = "1"
73
74
CONFIG_SCHEMA_NAME = "launch_config_schema.yaml"
"""
Name to use for the schema of the launch_config, as it cannot be ``launch_config`` itself
(presumably to not collide with the actual ``launch_config.yaml``).
"""


config_store = ConfigStore.instance()
"""Hydra_ 's ConfigStore instance to access config schemas."""

# NOTE: apparently a disabled idea to postpone the schema registration on Python < 3.10; unused.
# POSTPONE_REGISTER = sys.version_info < (3, 10)
# postpone_register: dict[str, type[Any]] = {}
84
85
def register_hydra_schema(obj: "type[Any]", name: Optional[str] = None):
    """
    Store the structured schema of *obj* in Hydra's :py:obj:`ConfigStore <config_store>`.

    Args:
        obj: The (dataclass) type whose schema is stored. It is converted with
            :py:meth:`OmegaConf.structured` using the ``allow_objects`` flag.
        name: Name under which the schema is stored. If :code:`None`, the ``_config_path``
            attribute of *obj* is used, falling back to :code:`obj.__name__`.

    See Also:
        :py:func:`config_path`
    """
    schema_name = name if name is not None else cast("str", getattr(obj, "_config_path", obj.__name__))
    schema_node = OmegaConf.structured(obj, flags={"allow_objects": True})
    # The Python module of the class is passed as Hydra package.
    config_store.store(
        schema_name,
        schema_node,
        provider="agents.tools.config_creation",
        group=None,
        package=obj.__module__,
    )
107
108
def config_path(path: Optional[str] = None):
    """
    Decorator to register the schema of the decorated dataclass with Hydra's
    :py:obj:`ConfigStore <config_store>`.
    Use the path relative to the `launch_config.yaml`, where the config is stored to use.

    Create subclasses in the following way:

    .. code-block:: python

        @config_path("agent/speed")
        @dataclass
        class AgentSpeedSettings(AgentConfig):

    Attention:
        - Use "/" as separator and not dots; only a single trailing ``.yaml`` suffix may contain a dot.
        - This is used for the Hydra schema registration and repeated paths will overwrite each other.
        - The path is stored as ``_config_path`` on the class. This value is inherited
          (if != :code:`NOT_GIVEN`), and the value of the parent is taken as default.
          Do not type-hint ``_config_path``; it must be a ClassVar to not conflict with dataclasses.

    Args:
        path: The config path. If :code:`None`, the ``_config_path`` attribute of the decorated
            class is used.

    Returns:
        (Callable[[type[AgentConfig]], type[AgentConfig]]) Wrapper function to register the schema.
        When building the documentation (``READTHEDOCS``) the wrapper returns the class unchanged.

    Raises:
        ValueError: When the wrapper is applied: if no path is available, if dots are used as
            separator, or if the class is not a dataclass.
    """
    if READTHEDOCS:
        # dummy, to avoid errors: no registration while building the documentation
        def _passthrough(obj: "type[_AnAgentConfig]") -> "type[_AnAgentConfig]":
            return obj

        return _passthrough

    def _register(obj: "type[_AnAgentConfig]") -> "type[_AnAgentConfig]":
        # getattr: a class without ``_config_path`` should raise the ValueError below, not an AttributeError
        name = getattr(obj, "_config_path", None) if path is None else path
        if name is None or name == "NOT_GIVEN":
            msg = f"Path is not given for {obj.__name__}. Use @config_path('path/to/config.yaml') to set the path."
            raise ValueError(msg)
        dots = name.count(".")
        # The only allowed dot is the one of a ".yaml" suffix
        if dots > 0 and not (dots == 1 and name.endswith(".yaml")):
            msg = f"Use '/' as separator and not dots. E.g. {name.replace('.', '/')} and not '{name}'"
            raise ValueError(msg)
        # Validate before mutating the class
        if not is_dataclass(obj):
            msg = f"Only dataclasses can be registered. {obj.__name__} is not a dataclass."
            raise ValueError(msg)
        obj._config_path = name
        register_hydra_schema(obj, name)  # type error will be fixed in pyright: 1.1.381+
        return obj

    return _register
156
157
def load_config_schema(name: str) -> Any:
    """
    Return the schema node that is stored under *name* in the :py:obj:`config_store`.

    Args:
        name: Name the schema was registered with, e.g. by :py:func:`register_hydra_schema`.

    Returns:
        The ``node`` of the loaded config store entry.
    """
    stored_entry = config_store.load(name)
    return stored_entry.node
160
161
162# ---------------------
163# Helper methods
164# ---------------------
165
166
def set_readonly_keys(conf: Union[DictConfig, ListConfig], keys: Union[str, List[str]]):
    """
    Sets the given child nodes of *conf* to readonly.

    Args:
        conf: The container that holds the nodes.
        keys: A single key or a list of keys of *conf*.

    Raises:
        KeyError: If a key does not refer to a node of *conf*.

    See: https://github.com/omry/omegaconf/issues/1161
    """
    if isinstance(keys, str):
        keys = [keys]
    for key in keys:
        node = conf._get_node(key)  # pyright: ignore[reportArgumentType]
        if node is None:
            # Do not pass None on to OmegaConf.set_readonly; report the offending key instead
            msg = f"Cannot set {key!r} to readonly, it is not a node of the config."
            raise KeyError(msg)
        OmegaConf.set_readonly(node, True)  # pyright: ignore[reportArgumentType]
177
178
def set_readonly_interpolations(conf: Union[DictConfig, ListConfig]):
    """
    Recursively sets all interpolation nodes inside *conf* to readonly.

    Containers are traversed via their nodes (without resolving). Plain leaf values that
    are not interpolations are left unchanged.

    See: https://github.com/omry/omegaconf/issues/1161
    """
    from omegaconf.nodes import ValueNode  # noqa: PLC0415

    if conf._is_interpolation():
        OmegaConf.set_readonly(conf, True)
    elif isinstance(conf, DictConfig):
        for key in conf:
            set_readonly_interpolations(conf._get_node(key))  # pyright: ignore[reportArgumentType]
    elif isinstance(conf, ListConfig):  # type: ignore
        for index in range(len(conf)):
            set_readonly_interpolations(conf._get_node(index))  # pyright: ignore[reportArgumentType]
    elif isinstance(conf, ValueNode):
        # Non-interpolated leaf reached during the recursion: nothing to do.
        # Previously every such leaf fell through to the warning below.
        pass
    else:
        print("WARNING: Could not set readonly for", type(conf))
195
196
_NOTSET = object()
"""Sentinel value for not set default values; compare by identity (``is _NOTSET``)."""

# ----- Type Annotations -----

ConfigType = TypeVar("ConfigType", DictConfig, "AgentConfig")
"""
Generic of an object that is a :py:class:`omegaconf.DictConfig`
or a subclass of :py:class:`AgentConfig`.
"""

# Generic "same type in, same type out" variable, e.g. for ``replace_carla_enum(content: _T) -> _T``
_T = TypeVar("_T")
MT = TypeVar("MT", Dict[str, Any], DictConfig)
"""A generic type variable for a mapping type: either a plain ``Dict[str, Any]`` or a :py:class:`omegaconf.DictConfig`."""

_AnAgentConfig = TypeVar("_AnAgentConfig", bound="AgentConfig")
"""A generic :py:class:`AgentConfig` type variable"""

# Generic alias: ``AsDictConfig[X]`` is just ``X`` for the type checker; the name only documents intent.
AsDictConfig = TypeAliasType("AsDictConfig", _AnAgentConfig, type_params=(_AnAgentConfig,))
"""
This annotation hints that object is a duck-typed :py:class:`omegaconf.DictConfig`
and not a subclass of :py:class:`AgentConfig`.
"""
220
# Problem in Sphinx is entered twice and creates large signatures
# NOTE(review): the string below is not attached to any name (no assignment precedes it directly);
# it has no runtime effect.

"""
Allowed types for nested config
"""

# Special annotations
# Only while building the documentation: the aliases are wrapped in ``TypeAliasType``, presumably
# so that Sphinx shows the alias name instead of expanding the (recursive) type.
if READTHEDOCS and not TYPE_CHECKING:
    from typing_extensions import TypeAliasType

    # annotate MISSING instead of ???
    # NOTE: this shadows the ``omegaconf.MISSING`` import, only in the documentation build.
    MISSING = TypeAliasType("MISSING", Any)
    """
    Alias for :py:obj:`omegaconf.MISSING`, is literally :python:`"???"` but has type :python:`Any`.

    If an attribute with this value is accessed from a :py:class:`DictConfig`,
    it will raise a :py:exc:`MissingMandatoryValue` error.

    :meta hide-value:
    :meta public:
    """

    # prevent unpack of nested types
    NestedConfigDict = TypeAliasType(
        "NestedConfigDict", dict[str, "AgentConfig | DictConfig | Any | NestedConfigDict"]
    )  # type: ignore
    """
    Type alias for nested configurations: :python:`Dict[str, NestedConfigDict | AgentConfig | DictConfig | Any]`

    :meta hide-value:
    """
else:
    # Recursive alias; the value types are a string (forward reference) and are not evaluated at runtime.
    NestedConfigDict: TypeAlias = Dict[str, "AgentConfig | DictConfig | Any | NestedConfigDict"]

# Recursive alias for the docstring mappings produced by ``extract_annotations``.
_NestedStrDict = Dict[str, "str | _NestedStrDict"]
"""Nested dict with str as leaves"""
257
# Two definitions of the same names: expressive types for static type checkers and
# plain, cheap stand-ins at runtime.
if TYPE_CHECKING:
    # AgentConfig parent should include DictConfig interface; without being a DictConfig
    # BaseContainer adds the methods, however is ABC with more methods
    from omegaconf.basecontainer import BaseContainer  # noqa: F401

    # More informative types when type checking; need primitive types at runtime
    DictConfigAlias: TypeAlias = DictConfig | NestedConfigDict
    OverwriteDictTypes: TypeAlias = dict[str, dict[str, NestedConfigDict] | "AgentConfig"]

    class DictConfigLike(DictConfig):
        """
        Duck-typed DictConfig still appears like a DictConfig.

        Note:
            At runtime this is just :py:class:`object`.
        """

        # Re-exposed explicitly so the type checker sees the DictConfig signatures
        keys = DictConfig.keys
        values = DictConfig.values

else:
    # primitive type at runtime
    DictConfigAlias: TypeAlias = Dict[str, Any]
    OverwriteDictTypes: TypeAlias = Dict[str, Dict[str, Any]]
    DictConfigLike = object
283
# --------------- YAML Export -----------------

# NOTE(review): not referenced in this part of the file; presumably the field name under which a
# config's path is stored/exported - confirm against its users.
PATH_FIELD_NAME = "config_path"
287
288
def _format_class_doc(class_name: str, doc: str) -> str:
    """Frame a cleaned class docstring with dashed lines (length of *class_name*) for the YAML comment."""
    rule = "-" * len(class_name)
    header = rule + "\n"
    footer = "\n" + rule
    if not doc.startswith(".. @package"):
        return header + doc + footer
    # Keep "@package ..." (without the leading ".. ") as first line, followed by a titled header
    start = doc.find("\n") + 1
    if start == 0:
        # no linebreak found
        doc += "\n"
        start = doc.find("\n") + 1
    header += class_name + "\n" + rule + "\n"
    content = doc[start:].lstrip()
    if content:
        return doc[3:start] + header + content + footer + "\n\n"
    # no content beside package
    return doc[3:start] + header + "\n"


def extract_annotations(
    parent: "ast.Module", docs: "Dict[str, _NestedStrDict]", global_annotations: "Dict[str, _NestedStrDict]"
):
    """
    Extracts the docstrings of classes and of their attributes from the source code.

    For each top-level class of *parent* (except a few helper classes) ``docs[class_name]`` is
    created. It maps attribute names to the string literal that follows the attribute's
    assignment; the class docstring is stored under ``"__doc__"`` framed by dashed lines.
    Entries of base classes that are already present in *docs* are copied first, so they are
    inherited and can be overridden. Nested classes are extracted recursively into
    ``docs[class_name]``. Simple reST roles like :code:`:py:class:\`X\`` are reduced to ```X```.

    Special docstrings:
        - ``.. <take doc|key>`` reuses the entry ``key`` of the current class or, if not present,
          ``global_annotations[key]``.
        - A class docstring starting with ``.. @package`` keeps that line (without ``.. ``) on top.

    Expression statements that are not strings (e.g. a ``...`` placeholder) are ignored.

    Args:
        parent: The parsed module (or a module wrapping a single nested class).
        docs: The mapping to fill in-place.
        global_annotations: Top-level mapping used for ``<take doc|...>`` look-ups.

    Raises:
        NameError: If a ``<take doc|key>`` reference cannot be resolved, or (as UnboundLocalError)
            if a string literal is not preceded by an attribute assignment.
    """
    for main_body in parent.body:
        # Skip non-classes
        if not isinstance(main_body, ast.ClassDef):
            continue
        if main_body.name in ("AgentConfig", "SimpleConfig", "class_or_instance_method", "_from_config_default_rules"):
            continue
        class_docs: _NestedStrDict = {}
        docs[main_body.name] = class_docs
        for base in reversed(main_body.bases):
            # Fill in parent information. Only plain names can be looked up; other bases like
            # ``(DictConfig if TYPE_CHECKING else object)``, ``Generic[T]`` or ``module.Class`` are skipped.
            if not isinstance(base, ast.Name):
                logging.debug(
                    "Skipping base of type %s of %s when collecting docs", type(base).__name__, main_body.name
                )
                continue
            inherited = docs.get(base.id)
            if isinstance(inherited, dict):
                class_docs.update(inherited)  # pyright: ignore[reportUnknownArgumentType]

        target: str
        for i, body in enumerate(main_body.body):
            if isinstance(body, ast.ClassDef):
                # Nested classes, extract recursive
                extract_annotations(ast.Module([body], type_ignores=[]), class_docs, global_annotations)  # type: ignore[arg-type]
                continue
            if isinstance(body, ast.AnnAssign):
                target = body.target.id  # type: ignore[union-attr]
                continue
            if isinstance(body, ast.Assign):
                target = body.targets[0].id  # type: ignore[attr-defined]
                continue
            if not isinstance(body, ast.Expr):
                continue
            node = body.value
            # ast.Constant for Python >= 3.8; ast.Str with attribute ``s`` before
            doc = node.value if isinstance(node, ast.Constant) else getattr(node, "s", None)
            if not isinstance(doc, str):
                # e.g. a ``...`` placeholder body or a call; not documentation
                continue
            if i == 0:  # Docstring of class
                target = "__doc__"
            # else: use last found target

            if doc.startswith(".. <take doc|") and doc.endswith(">"):
                key = doc[len(".. <take doc|") : -1]
                if key in class_docs:
                    class_docs[target] = class_docs[key]
                elif key in global_annotations:
                    # Do global look up, docs is here _class_annotations
                    class_docs[target] = global_annotations[key]  # pyright: ignore[reportOptionalSubscript]
                else:
                    msg = f"{key} needs to be defined before {target} or globally"
                    raise NameError(msg)
                continue
            doc = inspect.cleandoc(doc)
            if target == "__doc__":
                doc = _format_class_doc(main_body.name, doc)
            # remove rst
            doc = re.sub(r":py:\w+:\\?`[~.!]*(.+?)\\?`", r"`\1`", doc)
            doc = re.sub(r":(?::|\w|-)+?:`+(.+?)`+", r"`\1`", doc)
            class_docs[target] = doc
            del target  # delete to get better errors
379
380
class_annotations: Optional[Dict[str, _NestedStrDict]] = None
"""
Nested documentation strings for classes, keyed by class name; used for YAML comments.

:code:`None` until :py:func:`get_commented_yaml` fills it lazily via :py:func:`extract_annotations`.
"""
383
384
def get_commented_yaml(
    cls_or_self: Union[type[AgentConfig], AgentConfig],
    string: str,
    container: "DictConfig | NestedConfigDict",
    *,
    include_private: bool = False,
) -> str:
    """
    Adds the docstrings of a config class as comments to its YAML representation.

    The class and attribute docstrings are extracted from the source file of the class with
    :py:func:`extract_annotations` and cached in :py:data:`class_annotations`. The class
    docstring becomes the start comment; attribute docstrings are put before their keys,
    nested sections are handled recursively.

    Args:
        cls_or_self: Config class or instance whose documentation is used.
        string: YAML representation of *container*.
        container: The values that were dumped to *string*; its keys are walked to place comments.
        include_private: Keep entries documented with ``:meta private:``.
            Entries documented with ``:meta exclude:`` are always removed.

    Returns:
        The commented YAML string.

    Raises:
        KeyError: If no documentation was extracted for the class.
    """
    cls = cls_or_self if inspect.isclass(cls_or_self) else cls_or_self.__class__
    # Get documentations and store globally
    global class_annotations  # noqa: PLW0603
    if class_annotations is None or cls.__name__ not in class_annotations:
        # Parse lazily; also parse when the class comes from a file that has not been parsed yet.
        if class_annotations is None:
            class_annotations = {}
        # Python source files are UTF-8 by default; do not depend on the locale encoding
        tree = ast_parse(Path(inspect.getfile(cls)).read_text(encoding="utf-8"))
        extract_annotations(tree, docs=class_annotations, global_annotations=class_annotations)

    from ruamel.yaml import YAML  # optional # noqa: PLC0415

    yaml2 = YAML(typ="rt")
    data: CommentedMap = yaml2.load(string)

    cls_doc = class_annotations[cls.__name__]

    # First line
    data.yaml_set_start_comment(cls_doc.get("__doc__", cls.__name__))

    # add comments to all other attributes
    def add_comments(
        container: "DictConfig | NestedConfigDict",
        data: CommentedMap,
        lookup: Union[AgentConfig, _NestedStrDict],
        indent: int = 0,
    ):
        """
        Recursively adds comments to the YAML output.

        Args:
            container: The current dict to be commented
            data: The round-trip YAML mapping that corresponds to *container*.
            lookup: The lookup dictionary for docstrings
            indent: Indentation of the comments.
        """
        if isinstance(container, DictConfig):
            containeritems = container.items_ex(resolve=False)
        else:
            containeritems = container.items()
        for key, value in containeritems:
            if TYPE_CHECKING:
                assert isinstance(key, str)
            if isinstance(value, dict) and isinstance(cls_doc.get(key, None), dict):
                # NOTE(review): this consults the top-level `cls_doc`, not `lookup`; identical on the
                # first level, but nested levels look keys up in the top-level docs - confirm intent.
                # Add nested comments
                add_comments(value, data[key], cls_doc[key], indent=indent + 2)  # type: ignore[arg-type]
                comment_txt = "\n" + cls_doc[key].get("__doc__", "")  # type: ignore
                assert isinstance(comment_txt, str)
                # no @package in subfields
                if comment_txt.startswith("\n@package "):  # already striped here
                    comment_txt = "\n".join(comment_txt.split("\n")[2:]).strip()
            else:
                comment_txt = lookup.get(key, None)
                if comment_txt is None:
                    continue
                if isinstance(comment_txt, dict):
                    # Add comments for nested dataclasses
                    try:
                        add_comments(comment_txt, data[key], comment_txt, indent=indent + 2)
                    except KeyError:
                        # double nested will throw a KeyError here as key not in data; will only be the
                        # variable name of the nested dataclass; seems to be okay.
                        # NOTE: logging level might only be on WARNING here!
                        logging.debug(
                            "KeyError for %s in %s when adding comments. "
                            "This should be okay, report if descriptions are missing.",
                            key,
                            cls.__name__,
                        )
                    continue
            if (":meta exclude:" in comment_txt) or (not include_private and ":meta private:" in comment_txt):
                data.pop(key)
                continue  # Skip private fields; TODO does not skip "_named" fields, is that a problem?
            comment_txt = comment_txt.replace("\n\n", "\n \n")
            if comment_txt.count("\n") > 0:
                comment_txt = "\n" + comment_txt
            data.yaml_set_comment_before_after_key(key, comment_txt, indent=indent)

    add_comments(container, data, cls_doc)  # pyright: ignore[reportArgumentType]

    # Remember "key: null" lines; the round-trip dump below writes them as "key:"
    has_null_entry = re.findall(r"^\s*\w+: null$", string, re.MULTILINE)

    stream = io.StringIO()
    yaml2.dump(data, stream)
    string = stream.getvalue()

    # Fixes:
    stay_on_road_key = "use_stay_on_road_feature: "
    start = string.find(stay_on_road_key) if "rss" in data else -1
    if start != -1:  # without the key, slicing with -1 would corrupt the output
        # quote On/Off; to not be interpreted as boolean
        value_start = start + len(stay_on_road_key)
        end = string.find("\n", start)
        if end == -1:
            end = len(string)
        string = string[:value_start] + "'" + string[value_start:end] + "'" + string[end:]
    # entry: null has been replaced by entry:\n
    for entry in has_null_entry:
        key_part, _, value_part = entry.partition(":")
        if value_part != " null":
            logging.debug(
                "Warning: %s for entry %s. Entry is not ' null'. This should not happen", cls.__name__, entry
            )
            continue
        bare_entry = key_part + ":"
        string = re.sub(rf"^{re.escape(bare_entry)}$", bare_entry + " null", string, flags=re.MULTILINE)
    return string
506
507
def to_yaml(
    cls_or_self: Union[type[AgentConfig], AgentConfig],
    resolve: bool = False,
    yaml_commented: bool = True,
    detailed_rules: bool = False,
    *,
    include_private: bool = False,
) -> str:
    """
    Convert the options to a YAML string representation.

    For (instances of) ``LunaticAgentSettings`` the ``self`` and ``current_rule`` entries are removed.
    If the config has ``rules``, each rule is validated and its missing (``???``) keys are dropped
    from the export. If ``AD_RSS_AVAILABLE``, CARLA RSS enum values are replaced by their names.

    Args:
        resolve : Whether to resolve interpolations. Defaults to :code:`False`.
        yaml_commented : Whether to include comments in the YAML output. Defaults to :python:`True`.
        detailed_rules : Whether to include detailed rules in the YAML output. Defaults to :code:`False`.
            Missing ``phases`` and ``self_config`` entries of rules are then filled from the
            rule created by :code:`agents.rules.rule_from_config`.
        include_private : Whether to include fields that are marked as private. Defaults to :code:`False`.

    Returns:
        The YAML string representation of the options.

    Raises:
        ValueError: If a rule has no ``phases`` key and its ``_args_`` entry is missing.
        ImportError: If *detailed_rules* is set and :code:`agents.rules` cannot be imported.
    """
    cfg: DictConfig = OmegaConf.structured(cls_or_self, flags={"allow_objects": True})

    # Matched by name, for the class itself or an instance
    cls = cls_or_self if inspect.isclass(cls_or_self) else cls_or_self.__class__
    if cls.__name__ == "LunaticAgentSettings":
        with open_dict(cfg):
            del cfg["self"]
            del cfg["current_rule"]
    if "rules" in cfg:
        # Validate and remove missing keys for the yaml export
        if TYPE_CHECKING:
            assert isinstance(cfg, LunaticAgentSettings)
        rules: List[RuleCreatingParameters] = cfg.rules
        masked_rules: list[DictConfig] = []
        for rule_cfg in rules:
            if "phases" in rule_cfg.keys():
                if TYPE_CHECKING:
                    assert isinstance(rule_cfg, CreateRuleFromConfig)
                # > CreateRuleFromConfig
                if detailed_rules:
                    # Circular import can only call this after agents.rules
                    try:
                        from agents.rules import rule_from_config  # noqa: PLC0415
                    except ImportError:
                        print(
                            "Could not import agents.rules.rule_from_config. Set detailed_rules=False to avoid this error. Call this function somewhere else."
                        )
                        raise
                    rule: Rule = rule_from_config(rule_cfg)
                    self_config: RuleConfig = rule.self_config
                    if OmegaConf.is_missing(rule_cfg, "phases"):
                        rule_cfg.phases = next(iter(self_config.instance.phases))  # only support one atm
                    with open_dict(self_config):
                        del self_config["instance"]
                    if OmegaConf.is_missing(rule_cfg, "self_config"):
                        rule_cfg.self_config = OmegaConf.to_container(self_config, enum_to_str=True)  # type: ignore
                    else:
                        try:
                            rule_cfg.self_config.update(self_config)
                        except Exception:
                            with open_dict(rule_cfg):
                                rule_cfg.self_config = OmegaConf.to_container(
                                    OmegaConf.merge(self_config, rule_cfg.self_config), enum_to_str=True
                                )  # type: ignore

                if "phases" in rule_cfg and not isinstance(rule_cfg.phases, str):
                    assert isinstance(rule_cfg.phases, Phase), (
                        "Currently only supports a Phase as string or Phase object."
                    )
                    rule_cfg.phases = str(rule_cfg.phases)

                if detailed_rules:
                    assert not OmegaConf.is_missing(rule_cfg, "phases")

            # NOTE: For some reason "_args_" in rule does NOT WORK
            # (presumably `in` on a DictConfig is False for MISSING values, while `.keys()` lists them)
            elif "_args_" in rule_cfg.keys() and OmegaConf.is_missing(rule_cfg, key="_args_"):
                # check > CallFunctionFromConfig
                msg = (
                    f"{rule_cfg} has no phase or (positional) `_args_` key. Did you forget to add a phase? "
                    "If the _target_ is a function, still provide an empty `_args_ = []` key."
                )
                raise ValueError(msg)
            # Drop missing ("???") values from the exported rule
            missing_keys = {k for k in rule_cfg.keys() if OmegaConf.is_missing(rule_cfg, k)}
            clean_rule = OmegaConf.masked_copy(rule_cfg, set(rule_cfg.keys()) - missing_keys)  # pyright: ignore[reportArgumentType]
            masked_rules.append(clean_rule)
        cfg.rules = masked_rules  # type: ignore[attr-defined]

    container: Dict[str, Any] = OmegaConf.to_container(cfg, resolve=resolve, enum_to_str=True)  # pyright: ignore[reportAssignmentType]
    if AD_RSS_AVAILABLE:

        def replace_carla_enum(content: _T) -> _T:
            # retrieve name from the stubs
            if isinstance(content, carla.RssLogLevel):
                return RssLogLevelStub(content).name
            if isinstance(content, carla.RssRoadBoundariesMode):
                return RssRoadBoundariesModeStub(content).name
            return content

        def recursive_replace(content: _T) -> _T:
            if isinstance(content, dict):
                return {k: recursive_replace(v) for k, v in content.items()}  # type: ignore
            if isinstance(content, list):
                return [recursive_replace(v) for v in content]  # type: ignore
            return replace_carla_enum(content)

        container = recursive_replace(container)
    string = yaml.dump(
        container,
        default_flow_style=False,
        allow_unicode=True,
        sort_keys=False,
        Dumper=get_omega_conf_dumper(),
    )
    if not yaml_commented:
        return string
    # Extend
    return get_commented_yaml(cls_or_self, string, container, include_private=include_private)  # type: ignore[arg-type]
625
626
def export_options(
    cls_or_self: "Union[type[AgentConfig], AgentConfig]",
    path: "Union[str, os.PathLike[str]]",
    *,
    resolve: bool = False,
    with_comments: bool = False,
    detailed_rules: bool = False,
    include_private: bool = False,
) -> None:
    """
    Exports the options to a YAML file.

    With *with_comments* the output of the ``to_yaml`` method of *cls_or_self* is written
    (UTF-8 encoded). Otherwise the options are saved with :py:meth:`OmegaConf.save`; a class
    is instantiated first in this case. Missing parent directories are created in both cases.

    Args:
        path : The path for the exported YAML file.
        resolve : Whether to resolve the options before exporting. Defaults to False.
        with_comments : Whether to include comments in the exported YAML file. Defaults to False.
        detailed_rules : Whether to include detailed rules in the exported YAML file. Defaults to False.
        include_private : Whether to include private fields in the exported YAML file. Defaults to False.

    Returns:
        None
    """
    target = Path(path)
    if with_comments:
        string = cls_or_self.to_yaml(
            resolve=resolve, yaml_commented=True, detailed_rules=detailed_rules, include_private=include_private
        )
        target.parent.mkdir(parents=True, exist_ok=True)
        # allow_unicode output may contain non-ASCII characters; do not rely on the locale encoding
        target.write_text(string, encoding="utf-8")
        return
    options = cls_or_self() if inspect.isclass(cls_or_self) else cls_or_self  # type: ignore[call-arg]
    if not isinstance(options, DictConfig):
        # TODO: look how we can do this directly from dataclass
        options = OmegaConf.create(options, flags={"allow_objects": True})  # type: ignore
    # OmegaConf.save does not create missing directories itself
    target.parent.mkdir(parents=True, exist_ok=True)
    # NOTE: This might raise if options is structured, for export structured this is actually not necessary.
    OmegaConf.save(options, path, resolve=resolve)  # type: ignore[argument-type]
663
664
665# --------------- Other Tools -----------------
666
667
def set_container_type(base: "type[AgentConfig]", container: Union[NestedConfigDict, "AgentConfig"]) -> None:
    """
    Sets the object_type for sub configs if the config has been initialized with
    a :py:class:`omegaconf.DictConfig` and not the respective AgentConfig subclass.

    For every key of *container* whose annotation on *base* is a structured config (dataclass or
    attrs) but whose value is a plain dict / untyped DictConfig, the value is replaced in-place
    with an instance of the annotated type. The function then recurses into the value.

    Args:
        base : The base / duck type the container should have.
        container : The passed value. Modified in-place.

    Returns:
        None. Returns early (with a debug log) if the type hints of *base* cannot be retrieved.
    """
    try:
        annotations = get_type_hints(base)
    except TypeError:
        logging.debug("Error getting type hints for %s with container %s", base.__name__, type(container))
        return
    keys: "list[str]" = container.__dataclass_fields__.keys() if is_dataclass(container) else container.keys()  # type: ignore
    for key in keys:
        if key == "overwrites" or key not in annotations:
            continue
        # Skip interpolations and keys that `in` does not report for the container
        if isinstance(container, (DictConfig, ListConfig)) and (
            OmegaConf.is_interpolation(container, key) or key not in container
        ):
            continue
        try:
            value = getattr(container, key, MISSING)
        except MissingMandatoryValue:
            continue
        if value == MISSING:
            continue
        typ = annotations[key]
        if is_structured_config(typ):  # is structured dataclass or attrs
            if OmegaConf.get_type(value) is dict:  # but is not
                if isinstance(value, DictConfig):
                    # value._metadata.object_type = typ
                    if hasattr(typ, "create"):
                        setattr(container, key, typ.create(value, as_dictconfig=True))
                    else:
                        setattr(container, key, OmegaConf.structured(typ(**value), flags={"allow_objects": True}))
                # Below might rise type-errors if the schema is not correct
                elif hasattr(typ, "uses_overwrite_interface") and typ.uses_overwrite_interface():  # type: ignore[attr-defined]
                    setattr(container, key, typ(overwrites=value))  # type: ignore[arg-type]
                else:
                    setattr(container, key, typ(**value))
            # NOTE(review): `value` still refers to the object from before the replacement above, so the
            # recursion updates the old value and not the newly created one - confirm this is intended.
            if isinstance(value, (DictConfig, dict)) or is_dataclass(value):
                set_container_type(typ, value)  # type: ignore[arg-type]
712
713
def _flatten_dict(source: NestedConfigDict, target: NestedConfigDict, resolve: bool = False) -> None:
    """
    Copies all leaf values of the nested mapping *source* into the flat mapping *target*.

    Nested mappings (plain dicts and DictConfigs) are descended into; their own keys are not kept.
    If a leaf key already exists in *target*, a warning is printed and the value is overwritten.

    Args:
        source: Nested mapping, usually the output of :py:meth:`OmegaConf.to_container`.
        target: Flat mapping that is updated in-place.
        resolve: Whether interpolations of :py:class:`omegaconf.DictConfig` sources are resolved;
            passed on to nested containers.
    """
    if isinstance(source, DictConfig):
        items = source.items_ex(resolve=resolve)
    else:
        items = source.items()  # normal case after to_container
    for key, value in items:
        is_nested = isinstance(value, dict) or (
            # DictConfigs that are None, missing or an interpolation are kept as leaf values
            isinstance(value, DictConfig)
            and not (value._is_none() or value._is_missing() or value._is_interpolation())
        )
        if is_nested:
            _flatten_dict(value, target, resolve=resolve)
            continue
        if key in target:
            print(f"Warning: Key '{key}'={target[key]} already exists in target. Overwriting with {value}.")
        target[key] = value  # type: ignore[arg-type]
726
727
def flatten_config(config: "type[AgentConfig] | AgentConfig", *, resolve: bool = True) -> Dict[str, Any]:
    """
    Returns the values of *config* as a flat (single-level) dictionary.

    The config is converted to plain containers (missing values do not raise) and
    all nested sections are merged into one level; a key occurring in several sections
    keeps the last value encountered.

    Note:
        Interpolations are replaced by default.
        For example :py:attr:`target_speed` and :py:attr:`max_speed` are two *different* references.
        Also with **resolve=False** the interpolation will be just a string value, as the return
        type is a normal dictionary.

    Raises:
        omegaconf.errors.InterpolationToMissingValueError: If resolving accesses a missing value.
    """
    try:
        structured_cfg = OmegaConf.structured(config, flags={"allow_objects": True})
        plain_container = OmegaConf.to_container(
            structured_cfg,
            resolve=resolve,
            throw_on_missing=False,
            structured_config_mode=SCMode.DICT,
        )
    except omegaconf.errors.InterpolationToMissingValueError:
        # NOTE: alternatively call again with resolve=False
        print(
            "Resolving has failed because a missing value has been accessed. "
            "Fill all missing values before calling this function or use `resolve=False`."
        )
        raise
    nested = cast("NestedConfigDict", plain_container)
    flat: Dict[str, Any] = {}
    _flatten_dict(nested, flat, resolve=resolve)
    return flat
758
759
# Static-typing only helpers (never defined at runtime): protocols describing the constructor
# signature of config classes with and without the ``overwrites`` interface.
if TYPE_CHECKING:

    @type_check_only
    class ConfigWithOverwrites(Protocol["AgentConfigT"], AgentConfig):  # pyright: ignore[reportGeneralTypeIssues]
        """
        Constructor protocol of configs that accept an ``overwrites`` argument.

        :meta private:
        """

        def __new__(
            cls, overwrites: Optional["OverwriteDictTypes | NestedConfigDict | AgentConfigT"] = None, *args, **kwargs
        ) -> "AgentConfigT":
            """
            :meta public:
            """
            ...

    @type_check_only
    class ConfigWithoutOverwrites(Protocol["AgentConfigT"], AgentConfig):  # pyright: ignore[reportGeneralTypeIssues]
        """
        Constructor protocol of configs without the ``overwrites`` argument.

        :meta private:
        """

        # NOTE(review): ConfigWithOverwrites.__new__ returns an instance ("AgentConfigT") while this one
        # returns type[AgentConfigT]; __new__ normally returns an instance - confirm which is intended.
        def __new__(cls, *args, **kwargs) -> type[AgentConfigT]:
            """
            :meta public:
            """
            ...