Source code for agents.tools._config_tools

  1"""
  2Helper Tools for :py:mod:`.config_tools`.
  3"""
  4
  5# pyright: reportPrivateUsage=false, reportUnknownLambdaType=false, reportUnusedClass=false
  6# pyright: reportPossiblyUnboundVariable=information,reportAttributeAccessIssue=warning
  7# pyright: reportUnknownVariableType=information, reportUnknownMemberType=information
  8from __future__ import annotations
  9
 10import ast
 11import inspect
 12import io
 13import logging
 14import os
 15from pathlib import Path
 16import re
 17import yaml
 18
 19from dataclasses import is_dataclass
 20from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast, get_type_hints
 21
 22import carla
 23import omegaconf.errors
 24from hydra.core.config_store import ConfigStore
 25from omegaconf import MISSING, DictConfig, ListConfig, MissingMandatoryValue, OmegaConf, SCMode, open_dict
 26from omegaconf._utils import is_structured_config, get_omega_conf_dumper  # noqa: PLC2701
 27from typing_extensions import Protocol, TypeAlias, TypeAliasType, TypeVar
 28
 29from classes.constants import AD_RSS_AVAILABLE, Phase, RssLogLevelStub, RssRoadBoundariesModeStub, READTHEDOCS
 30from launch_tools import ast_parse
 31
 32if TYPE_CHECKING:
 33    from typing import type_check_only
 34    from classes.type_protocols import AgentConfigT
 35    from agents.tools.config_creation import (
 36        AgentConfig,
 37        CreateRuleFromConfig,
 38        LunaticAgentSettings,
 39        RuleConfig,
 40        RuleCreatingParameters,
 41    )
 42    from classes.rule import Rule
 43    from ruamel.yaml.comments import CommentedMap
 44
 45# ----------- Resolvers -------------
 46
 47
 48def look_ahead_time(speed: float, time_to_collision: float, plus: float = 0) -> float:
 49    """
 50    Convert the current speed in km /h and a time to collision in seconds to a distance in meters
 51    and adds a slight buffer on top.
 52
 53    Use as :python:`"${look_ahead_time: ${live_info.current_speed}, time, }"`
 54    """
 55    return speed / 3.6 * time_to_collision + plus  # km / h * s = m
 56
 57
 58# need this check for readthedocs
 59if not READTHEDOCS and os.environ.get("_OMEGACONF_RESOLVERS_REGISTERED", "0") == "0":
 60    import random
 61    import operator
 62
 63    OmegaConf.register_new_resolver("add", operator.add)  # type: ignore[arg-type]
 64    OmegaConf.register_new_resolver("sub", operator.sub)  # type: ignore[arg-type]
 65    OmegaConf.register_new_resolver("mul", operator.mul)  # type: ignore[arg-type]
 66    OmegaConf.register_new_resolver("divide", operator.truediv)  # type: ignore[arg-type]
 67    OmegaConf.register_new_resolver("min", lambda *els: min(els))
 68    OmegaConf.register_new_resolver("max", lambda *els: max(els))
 69    OmegaConf.register_new_resolver("randint", random.randint)
 70    OmegaConf.register_new_resolver("randuniform", random.uniform)
 71    OmegaConf.register_new_resolver("look_ahead_time", look_ahead_time)
 72    os.environ["_OMEGACONF_RESOLVERS_REGISTERED"] = "1"
 73
 74
 75CONFIG_SCHEMA_NAME = "launch_config_schema.yaml"
 76"""Name to use for the launch_config as it cannot be launch_config itself."""
 77
 78
 79config_store = ConfigStore.instance()
 80"""Hydra_ 's ConfigStore instance to access config schemas."""
 81
 82# POSTPONE_REGISTER = sys.version_info < (3, 10)
 83# postpone_register: dict[str, type[Any]] = {}
 84
 85
 86def register_hydra_schema(obj: "type[Any]", name: Optional[str] = None):
 87    """
 88    Uses Hydra's ConfigStore to register the schema of the current class in the
 89    :py:obj:`ConfigStore <config_store>`.
 90
 91    See Also:
 92        :py:func:`config_path`
 93    """
 94    if name is None:
 95        name = cast("str", getattr(obj, "_config_path", obj.__name__))
 96    # if not POSTPOND_REGISTER:
 97    #    pass
 98    # else:
 99    #    postpond_register[name] = obj
100    config_store.store(
101        name,
102        OmegaConf.structured(obj, flags={"allow_objects": True}),
103        provider="agents.tools.config_creation",
104        group=None,
105        package=obj.__module__,
106    )
107
108
def config_path(path: Optional[str] = None):
    """
    Decorator to register the schema of the decorated class with Hydra's :py:obj:`ConfigStore <hydra>`.
    Use the path relative to the `launch_config.yaml`, where the config is stored to use.

    Create subclasses in the following way:

    .. code-block:: python

        @config_path("agent/speed")
        @dataclass
        class AgentSpeedSettings(AgentConfig):

    Attention:
        - Use "/" as separator and not dots; only a single trailing ``.yaml`` suffix may contain a dot.
        - This is used for the Hydra schema registration and repeated paths will overwrite each other.
        - This value is inherited (if != :code:`NOT_GIVEN`), and the value of the parent is taken
          as default. Do not type-hint this value it must be a ClassVar to not conflict with dataclasses.
        - If ``READTHEDOCS`` is set, the returned decorator returns the class unchanged.

    Args:
        path: The config path. If :code:`None`, the class attribute ``_config_path`` is used.

    Raises:
        ValueError: (when the decorator is applied) If no path is available, dots are used as
            separator, or the decorated class is not a dataclass.

    Returns:
        (Callable[[type[AgentConfig]], type[AgentConfig]]) Wrapper function to register the schema.
    """

    if not READTHEDOCS:

        def _register(obj: "type[_AnAgentConfig]") -> "type[_AnAgentConfig]":  # pyright: ignore[reportRedeclaration]
            # An explicit path takes precedence over the (possibly inherited) class attribute.
            name = getattr(obj, "_config_path", None) if path is None else path
            if name is None or name == "NOT_GIVEN":
                msg = f"Path is not given for {obj.__name__}. Use @config_path('path/to/config') to set the path."
                raise ValueError(msg)
            dots = name.count(".")
            # Only a single dot that belongs to a trailing ".yaml" suffix is accepted.
            if dots > 0 and not (dots == 1 and name.endswith(".yaml")):
                msg = f"Use '/' as separator and not dots. E.g. {name.replace('.', '/')} and not '{name}'"
                raise ValueError(msg)
            if not is_dataclass(obj):
                msg = f"Only dataclasses can be registered. {obj.__name__} is not a dataclass."
                raise ValueError(msg)
            # Mutate the class only after all validations passed.
            obj._config_path = name
            register_hydra_schema(obj, name)  # type error will be fixed in pyright: 1.1.381+
            return obj

    else:
        # dummy, to avoid errors
        def _register(obj: "type[_AnAgentConfig]") -> "type[_AnAgentConfig]":
            return obj

    return _register
def load_config_schema(name: str) -> Any:
    """
    Load the entry stored under ``name`` in the :py:obj:`ConfigStore <config_store>`
    and return its ``node``.
    """
    return config_store.load(name).node


# ---------------------
# Helper methods
# ---------------------


def set_readonly_keys(conf: Union[DictConfig, ListConfig], keys: Union[str, List[str]]):
    """
    Sets the nodes of the given keys to readonly.

    Args:
        conf: The config that contains the nodes.
        keys: A single key or a list of keys whose nodes are set to readonly.

    See: https://github.com/omry/omegaconf/issues/1161
    """
    if isinstance(keys, str):
        keys = [keys]
    for key in keys:
        OmegaConf.set_readonly(conf._get_node(key), True)  # pyright: ignore[reportArgumentType]


def set_readonly_interpolations(conf: Union[DictConfig, ListConfig]):
    """
    Recursively sets all interpolation nodes to readonly.

    Dict and list containers are traversed. Plain (non-interpolated) leaf values are left
    untouched. A warning is printed only for nodes of an unexpected type.

    See: https://github.com/omry/omegaconf/issues/1161
    """
    from omegaconf.nodes import ValueNode  # noqa: PLC0415

    if conf._is_interpolation():
        OmegaConf.set_readonly(conf, True)
    elif isinstance(conf, DictConfig):
        for key in conf:
            set_readonly_interpolations(conf._get_node(key))  # pyright: ignore[reportArgumentType]
    elif isinstance(conf, ListConfig):  # type: ignore
        for key in range(len(conf)):
            set_readonly_interpolations(conf._get_node(key))  # pyright: ignore[reportArgumentType]
    elif isinstance(conf, ValueNode):
        # Plain leaf value without interpolation: nothing to protect, not an error.
        pass
    else:
        print("WARNING: Could not set readonly for", type(conf))


_NOTSET = object()
"""Sentinel value for not set default values."""

# ----- Type Annotations -----

ConfigType = TypeVar("ConfigType", DictConfig, "AgentConfig")
"""
Generic of an object that is a :py:class:`omegaconf.DictConfig`
or a subclass of :py:class:`AgentConfig`.
"""

_T = TypeVar("_T")
MT = TypeVar("MT", Dict[str, Any], DictConfig)
"""A generic type variable for a mapping type."""

_AnAgentConfig = TypeVar("_AnAgentConfig", bound="AgentConfig")
"""A generic :py:class:`AgentConfig` type variable"""

AsDictConfig = TypeAliasType("AsDictConfig", _AnAgentConfig, type_params=(_AnAgentConfig,))
"""
This annotation hints that object is a duck-typed :py:class:`omegaconf.DictConfig`
and not a subclass of :py:class:`AgentConfig`.
"""

# Problem in Sphinx is entered twice and creates large signatures

"""
Allowed types for nested config
"""

# Special annotations
if READTHEDOCS and not TYPE_CHECKING:
    from typing_extensions import TypeAliasType

    # annotate MISSING instead of ???
    MISSING = TypeAliasType("MISSING", Any)
    """
    Alias for :py:obj:`omegaconf.MISSING`, is literally :python:`"???"` but has type :python:`Any`.

    If an attribute with this value is accessed from a :py:class:`DictConfig`,
    it will raise a :py:exc:`MissingMandatoryValue` error.

    :meta hide-value:
    :meta public:
    """

    # prevent unpack of nested types
    NestedConfigDict = TypeAliasType(
        "NestedConfigDict", dict[str, "AgentConfig | DictConfig | Any | NestedConfigDict"]
    )  # type: ignore
    """
    Type alias for nested configurations: :python:`Dict[str, NestedConfigDict | AgentConfig | DictConfig | Any]`

    :meta hide-value:
    """
else:
    NestedConfigDict: TypeAlias = Dict[str, "AgentConfig | DictConfig | Any | NestedConfigDict"]

_NestedStrDict = Dict[str, "str | _NestedStrDict"]
"""Nested dict with str as leaves"""

if TYPE_CHECKING:
    # AgentConfig parent should include DictConfig interface; without being a DictConfig
    # BaseContainer adds the methods, however is ABC with more methods
    from omegaconf.basecontainer import BaseContainer  # noqa: F401

    # More informative types when type checking; need primitive types at runtime
    DictConfigAlias: TypeAlias = DictConfig | NestedConfigDict
    OverwriteDictTypes: TypeAlias = dict[str, dict[str, NestedConfigDict] | "AgentConfig"]

    class DictConfigLike(DictConfig):
        """
        Duck-typed DictConfig still appears like a DictConfig.

        Note:
            At runtime this is just :py:class:`object`.
        """

        keys = DictConfig.keys
        values = DictConfig.values

else:
    # primitive type at runtime
    DictConfigAlias: TypeAlias = Dict[str, Any]
    OverwriteDictTypes: TypeAlias = Dict[str, Dict[str, Any]]
    DictConfigLike = object

# --------------- YAML Export -----------------

PATH_FIELD_NAME = "config_path"


def extract_annotations(
    parent: "ast.Module", docs: Dict[str, _NestedStrDict], global_annotations: Dict[str, _NestedStrDict]
):
    """
    Extract the docstrings of classes and their attributes from a parsed module.

    For every class in ``parent``, an entry ``docs[ClassName]`` is created. A few helper
    classes are skipped. The entry:

    - first inherits the entries of already processed base classes;
    - then maps each attribute name to the string literal that follows it;
    - maps ``"__doc__"`` to the (formatted) class docstring.

    Nested classes are processed recursively into the class' entry.

    A string of the form ``.. <take doc|Key>`` copies the entry ``Key`` of the same class.
    If that entry is absent, it falls back to ``global_annotations[Key]``.

    Sphinx/RST roles are reduced to plain backticks.

    Args:
        parent: The parsed module (or a module wrapping a single nested class).
        docs: Mapping that is filled in-place.
        global_annotations: Top-level mapping used as fallback for ``<take doc|...>``.

    Raises:
        NameError: If a ``<take doc|Key>`` reference cannot be resolved.
    """
    for main_body in parent.body:
        # Skip non-classes
        if not isinstance(main_body, ast.ClassDef):
            continue
        if main_body.name in ("AgentConfig", "SimpleConfig", "class_or_instance_method", "_from_config_default_rules"):
            continue
        docs[main_body.name] = {}
        for base in reversed(main_body.bases):
            # Fill in parent information
            try:
                if (
                    isinstance(base, ast.IfExp) and base.test.id == "TYPE_CHECKING"
                ):  # (DictConfig if TYPE_CHECKING else object):
                    continue
                docs[main_body.name].update(docs.get(base.id, {}))  # pyright: ignore[reportUnknownArgumentType]
            except Exception:
                logging.exception("Error in %s", main_body.name)

        target: str
        for i, body in enumerate(main_body.body):
            if isinstance(body, ast.ClassDef):
                # Nested classes, extract recursive
                extract_annotations(ast.Module([body], type_ignores=[]), docs[main_body.name], global_annotations)  # type: ignore[arg-type]
                continue
            if isinstance(body, ast.AnnAssign):
                target = body.target.id
                continue
            if isinstance(body, ast.Assign):
                target = body.targets[0].id
                continue
            if isinstance(body, ast.Expr):
                try:
                    # NOTE: This is different for <Python3.8; this is ast.Str
                    doc: str = body.value.value  # type: ignore
                except AttributeError:
                    # Try < 3.8 code
                    doc = body.value.s  # type: ignore
                assert isinstance(doc, str)
                if i == 0:  # Docstring of class
                    target = "__doc__"
                # else: use last found target
            else:
                continue

            if doc.startswith(".. <take doc|") and doc.endswith(">"):
                key = doc[len(".. <take doc|") : -1]
                try:
                    docs[main_body.name][target] = docs[main_body.name][key]
                except KeyError as e:
                    try:
                        # Do global look up, docs is here _class_annotations
                        # Move fitting sub-class to key
                        docs[main_body.name][target] = global_annotations[key]  # pyright: ignore[reportOptionalSubscript]
                        continue
                    except Exception:
                        pass
                    msg = f"{key} needs to be defined before {target} or globally"
                    raise NameError(msg) from e
                continue
            doc = inspect.cleandoc(doc)
            if target == "__doc__":
                header = (
                    "-" * len(main_body.name)
                ) + "\n"  # + main_body.name + "\n" + ("-" * len(main_body.name)) + "\n" + doc
                footer = "\n" + ("-" * len(main_body.name))

                if doc.startswith(".. @package"):
                    start = doc.find("\n") + 1
                    if start == 0:
                        # no linebreak found
                        doc += "\n"
                        start = doc.find("\n") + 1
                    header += main_body.name + "\n" + ("-" * len(main_body.name)) + "\n"
                    if doc[start:].lstrip():
                        # doc[3:start] strips the leading ".. " and keeps the "@package ..." line
                        doc = doc[3:start] + header + doc[start:].lstrip() + footer + "\n\n"
                    else:
                        # no content beside package
                        doc = doc[3:start] + header + "\n"
                else:
                    doc = header + doc + footer
            # remove rst
            doc = re.sub(r":py:\w+:\\?`[~.!]*(.+?)\\?`", r"`\1`", doc)
            doc = re.sub(r":(?::|\w|-)+?:`+(.+?)`+", r"`\1`", doc)
            docs[main_body.name][target] = doc
            del target  # delete to get better errors
            del doc


class_annotations: Optional[Dict[str, _NestedStrDict]] = None
"""Nested documentation strings for classes; used for YAML comments."""


def get_commented_yaml(
    cls_or_self: Union[type[AgentConfig], AgentConfig],
    string: str,
    container: "DictConfig | NestedConfigDict",
    *,
    include_private: bool = False,
) -> str:
    """
    Add the extracted documentation of ``cls_or_self`` as comments to the YAML ``string``.

    On the first call, the source file of the given class is parsed with
    :py:func:`extract_annotations`, and the result is cached in the module-level
    ``class_annotations``.

    Args:
        cls_or_self: The config class or instance whose documentation is used.
        string: YAML dump of ``container``.
        container: The data that was dumped to ``string``; it is walked to place the comments.
        include_private: Keep fields whose documentation contains ``:meta private:``.
            Fields marked ``:meta exclude:`` are always removed.

    Returns:
        The commented YAML string.

    Note:
        NOTE(review): only the file of the class passed on the first call is parsed. Classes
        from other files will raise a :py:exc:`KeyError` on lookup — confirm this is intended.
    """
    cls = cls_or_self if inspect.isclass(cls_or_self) else cls_or_self.__class__
    cls_file = inspect.getfile(cls)
    # Get documentations and store globally
    global class_annotations  # noqa: PLW0603
    if class_annotations is None:
        tree = ast_parse(Path(cls_file).read_text())
        class_annotations = {}
        extract_annotations(tree, docs=class_annotations, global_annotations=class_annotations)

    from ruamel.yaml import YAML  # optional # noqa: PLC0415

    yaml2 = YAML(typ="rt")
    data: CommentedMap = yaml2.load(string)

    cls_doc = class_annotations[cls.__name__]

    # First line
    data.yaml_set_start_comment(cls_doc.get("__doc__", cls.__name__))

    nested_data: list[CommentedMap] = []

    # add comments to all other attributes
    def add_comments(
        container: "DictConfig | NestedConfigDict",
        data: CommentedMap,
        lookup: Union[AgentConfig, _NestedStrDict],
        indent: int = 0,
    ):
        """
        Recursively adds comments to the YAML output.

        Args:
            container: The current dict to be commented
            data: The round-trip YAML mapping that corresponds to ``container``.
            lookup: The lookup dictionary for docstrings
            indent: Indentation of the comments for this nesting level.
        """
        nested_data.append(data)
        if isinstance(container, DictConfig):
            containeritems = container.items_ex(resolve=False)
        else:
            containeritems = container.items()
        for key, value in containeritems:
            if TYPE_CHECKING:
                assert isinstance(key, str)
            # NOTE(review): nested levels consult the top-level ``cls_doc`` here rather than
            # ``lookup``; the ``else`` branch below compensates for nested dataclasses. Confirm.
            if isinstance(value, dict) and isinstance(cls_doc.get(key, None), dict):
                # Add nested comments
                add_comments(value, data[key], cls_doc[key], indent=indent + 2)  # type: ignore[arg-type]
                comment_txt = "\n" + cls_doc[key].get("__doc__", "")  # type: ignore
                assert isinstance(comment_txt, str)
                # no @package in subfields
                if comment_txt.startswith("\n@package "):  # already striped here
                    comment_txt = "\n".join(comment_txt.split("\n")[2:]).strip()
            else:
                comment_txt = lookup.get(key, None)
                if comment_txt is None:
                    continue
                if isinstance(comment_txt, dict):
                    # Add comments for nested dataclasses
                    try:
                        add_comments(comment_txt, data[key], comment_txt, indent=indent + 2)
                    except KeyError:
                        # double nested will throw a KeyError here as key not in data; will only be the
                        # variable name of the nested dataclass; seems to be okay.
                        # NOTE: logging level might only be on WARNING here!
                        logging.debug(
                            "KeyError for %s in %s when adding comments. "
                            "This should be okay, report if descriptions are missing.",
                            key,
                            cls.__name__,
                        )
                    continue
            if (":meta exclude:" in comment_txt) or (not include_private and ":meta private:" in comment_txt):
                data.pop(key)
                continue  # Skip private fields; TODO does not skip "_named" fields, is that a problem?
            comment_txt = comment_txt.replace("\n\n", "\n \n")
            if comment_txt.count("\n") > 0:
                comment_txt = "\n" + comment_txt
            data.yaml_set_comment_before_after_key(key, comment_txt, indent=indent)

    add_comments(container, data, cls_doc)  # pyright: ignore[reportArgumentType]

    # Remember "key: null" entries; the round-trip dump writes them as bare "key:".
    # [ \t]* instead of \s* so that a match cannot span several lines.
    has_null_entry = re.findall(r"^[ \t]*\w+: null$", string, re.MULTILINE)

    stream = io.StringIO()
    yaml2.dump(data, stream)
    stream.seek(0)
    string = stream.read()
    # Fixes:
    if "rss" in data:
        marker = "use_stay_on_road_feature: "
        start = string.find(marker)
        if start != -1:  # the key might be absent, e.g. when it has been excluded
            end = string.find("\n", start)
            if end == -1:
                end = len(string)
            value_start = start + len(marker)
            # quote On/Off; to not be interpreted as boolean
            string = string[:value_start] + "'" + string[value_start:end] + "'" + string[end:]
    # entry: null has been replaced by entry:\n
    if has_null_entry:
        entry: str
        for entry in has_null_entry:
            parts = entry.partition(":")
            if parts[2] != " null":
                logging.debug(
                    "Warning: %s for entry %s. Entry is not ' null'. This should not happen", cls.__name__, entry
                )
                continue
            entry = parts[0] + ":"  # noqa: PLW2901 # entry should be the same
            string = re.sub(rf"^{entry}$", entry + " null", string, flags=re.MULTILINE)
    return string


def to_yaml(
    cls_or_self: Union[type[AgentConfig], AgentConfig],
    resolve: bool = False,
    yaml_commented: bool = True,
    detailed_rules: bool = False,
    *,
    include_private: bool = False,
) -> str:
    """
    Convert the options to a YAML string representation.

    For :code:`LunaticAgentSettings`, the runtime-only keys ``self`` and ``current_rule`` are
    removed. Each entry of ``rules`` is validated, and keys with missing values are masked out.

    Args:
        resolve : Whether to resolve interpolations. Defaults to :code:`False`.
        yaml_commented : Whether to include comments in the YAML output. Defaults to :python:`True`.
        detailed_rules : Whether to include detailed rules in the YAML output. Defaults to :code:`False`.
        include_private : Whether to include fields that are marked as private. Defaults to :code:`False`.

    Returns:
        The YAML string representation of the options.

    Raises:
        ValueError: If a rule has neither ``phases`` nor a filled ``_args_`` key.
        ImportError: If ``detailed_rules`` is set but :code:`agents.rules` cannot be imported.
    """
    cfg: DictConfig = OmegaConf.structured(cls_or_self, flags={"allow_objects": True})

    cls_name = cls_or_self.__name__ if inspect.isclass(cls_or_self) else cls_or_self.__class__.__name__
    if cls_name == "LunaticAgentSettings":
        with open_dict(cfg):
            del cfg["self"]
            del cfg["current_rule"]
        if "rules" in cfg:
            # Validate and remove missing keys for the yaml export
            if TYPE_CHECKING:
                assert isinstance(cfg, LunaticAgentSettings)
            rules: List[RuleCreatingParameters] = cfg.rules
            masked_rules: list[DictConfig] = []
            for rule_cfg in rules:
                if "phases" in rule_cfg.keys():
                    if TYPE_CHECKING:
                        assert isinstance(rule_cfg, CreateRuleFromConfig)
                    # > CreateRuleFromConfig
                    if detailed_rules:
                        # Circular import can only call this after agents.rules
                        try:
                            from agents.rules import rule_from_config  # noqa: PLC0415
                        except ImportError:
                            print(
                                "Could not import agents.rules.rule_from_config. Set detailed_rules=False to avoid this error. Call this function somewhere else."
                            )
                            raise
                        rule: Rule = rule_from_config(rule_cfg)
                        self_config: RuleConfig = rule.self_config
                        if OmegaConf.is_missing(rule_cfg, "phases"):
                            rule_cfg.phases = next(iter(self_config.instance.phases))  # only support one atm
                        with open_dict(self_config):
                            del self_config["instance"]
                        if OmegaConf.is_missing(rule_cfg, "self_config"):
                            rule_cfg.self_config = OmegaConf.to_container(self_config, enum_to_str=True)  # type: ignore
                        else:
                            try:
                                rule_cfg.self_config.update(self_config)
                            except Exception:
                                with open_dict(rule_cfg):
                                    rule_cfg.self_config = OmegaConf.to_container(
                                        OmegaConf.merge(self_config, rule_cfg.self_config), enum_to_str=True
                                    )  # type: ignore

                    if "phases" in rule_cfg and not isinstance(rule_cfg.phases, str):
                        assert isinstance(rule_cfg.phases, Phase), (
                            "Currently only supports a Phase as string or Phase object."
                        )
                        rule_cfg.phases = str(rule_cfg.phases)

                    if detailed_rules:
                        assert not OmegaConf.is_missing(rule_cfg, "phases")

                # NOTE: For some reason "_args_" in rule does NOT WORK
                elif "_args_" in rule_cfg.keys() and OmegaConf.is_missing(rule_cfg, key="_args_"):
                    # check > CallFunctionFromConfig
                    msg = (
                        f"{rule_cfg} has no phase or (positional) `_args_` key. Did you forget to add a phase? "
                        "If the _target_ is a function, still provide an empty `_args_ = []` key."
                    )
                    raise ValueError(msg)
                missing_keys = {k for k in rule_cfg.keys() if OmegaConf.is_missing(rule_cfg, k)}
                clean_rule = OmegaConf.masked_copy(rule_cfg, set(rule_cfg.keys()) - missing_keys)  # pyright: ignore[reportArgumentType]
                masked_rules.append(clean_rule)
            cfg.rules = masked_rules  # type: ignore[attr-defined]

    container: Dict[str, Any] = OmegaConf.to_container(cfg, resolve=resolve, enum_to_str=True)  # pyright: ignore[reportAssignmentType]
    if AD_RSS_AVAILABLE:

        def replace_carla_enum(content: _T) -> _T:
            # retrieve name from the stubs
            if isinstance(content, carla.RssLogLevel):
                return RssLogLevelStub(content).name
            if isinstance(content, carla.RssRoadBoundariesMode):
                return RssRoadBoundariesModeStub(content).name
            return content

        def recursive_replace(content: _T) -> _T:
            if isinstance(content, dict):
                return {k: recursive_replace(v) for k, v in content.items()}  # type: ignore
            if isinstance(content, list):
                return [recursive_replace(v) for v in content]  # type: ignore
            return replace_carla_enum(content)

        container = recursive_replace(container)
    string = yaml.dump(
        container,
        default_flow_style=False,
        allow_unicode=True,
        sort_keys=False,
        Dumper=get_omega_conf_dumper(),
    )
    if not yaml_commented:
        return string
    # Extend
    return get_commented_yaml(cls_or_self, string, container, include_private=include_private)  # type: ignore[arg-type]


def export_options(
    cls_or_self: Union[type[AgentConfig], AgentConfig],
    path: Union[str, "os.PathLike[str]"],
    *,
    resolve: bool = False,
    with_comments: bool = False,
    detailed_rules: bool = False,
    include_private: bool = False,
) -> None:
    """
    Exports the options to a YAML file. With the :py:meth:`to_yaml` method.

    Missing parent directories of ``path`` are created.

    Args:
        path : The path for the exported YAML file.
        resolve : Whether to resolve the options before exporting. Defaults to False.
        with_comments : Whether to include comments in the exported YAML file. Defaults to False.
        detailed_rules : Whether to include detailed rules in the exported YAML file. Defaults to False.
        include_private : Whether to include private fields in the exported YAML file. Defaults to False.

    Returns:
        None
    """
    options = cls_or_self() if inspect.isclass(cls_or_self) else cls_or_self  # type: ignore[call-arg]
    if with_comments:
        string = cls_or_self.to_yaml(
            resolve=resolve, yaml_commented=True, detailed_rules=detailed_rules, include_private=include_private
        )
        Path(os.path.split(path)[0]).mkdir(parents=True, exist_ok=True)
        Path(path).write_text(string)
        return
    if not isinstance(options, DictConfig):
        # TODO: look how we can do this directly from dataclass
        options = OmegaConf.create(options, flags={"allow_objects": True})  # type: ignore
    # Same as in the commented branch: the target directory may not exist yet.
    Path(os.path.split(path)[0]).mkdir(parents=True, exist_ok=True)
    OmegaConf.save(
        options, path, resolve=resolve
    )  # NOTE: This might raise if options is structured, for export structured this is actually not necessary. # type: ignore[argument-type]


# --------------- Other Tools -----------------


def set_container_type(base: "type[AgentConfig]", container: Union[NestedConfigDict, "AgentConfig"]) -> None:
    """
    Sets the object_type for sub configs if the config has been initialized with
    a :py:class:`omegaconf.DictConfig` and not the respective AgentConfig subclass.

    Args:
        base : The base / duck type the container should have.
        container : The passed value.
    """
    try:
        annotations = get_type_hints(base)
    except TypeError:
        # ``base`` may be a typing construct without ``__name__`` (e.g. on older Pythons).
        logging.debug(
            "Error getting type hints for %s with container %s", getattr(base, "__name__", base), type(container)
        )
        return
    keys: "list[str]" = container.__dataclass_fields__.keys() if is_dataclass(container) else container.keys()  # type: ignore
    for key in keys:
        if key == "overwrites" or key not in annotations:
            continue
        if isinstance(container, (DictConfig, ListConfig)) and (
            OmegaConf.is_interpolation(container, key) or key not in container
        ):
            continue
        try:
            value = getattr(container, key, MISSING)
        except MissingMandatoryValue:
            continue
        if value == MISSING:
            continue
        typ = annotations[key]
        if is_structured_config(typ):  # is structured dataclass or attrs
            if OmegaConf.get_type(value) is dict:  # but is not
                if isinstance(value, DictConfig):
                    # value._metadata.object_type = typ
                    if hasattr(typ, "create"):
                        setattr(container, key, typ.create(value, as_dictconfig=True))
                    else:
                        setattr(container, key, OmegaConf.structured(typ(**value), flags={"allow_objects": True}))
                # Below might rise type-errors if the schema is not correct
                elif hasattr(typ, "uses_overwrite_interface") and typ.uses_overwrite_interface():  # type: ignore[attr-defined]
                    setattr(container, key, typ(overwrites=value))  # type: ignore[arg-type]
                else:
                    setattr(container, key, typ(**value))
        # NOTE(review): recursion uses ``value`` as read before a possible replacement above.
        if isinstance(value, (DictConfig, dict)) or is_dataclass(value):
            set_container_type(typ, value)  # type: ignore[arg-type]


def _flatten_dict(source: NestedConfigDict, target: NestedConfigDict, resolve: bool = False) -> None:
    """
    Recursively copy all leaf entries of ``source`` into the flat mapping ``target``.

    Nested mappings (``dict`` or :py:class:`omegaconf.DictConfig`) are descended into, and
    only their leaves are stored. On duplicate keys, a warning is printed and the later value wins.

    Args:
        source: The (possibly nested) mapping to flatten.
        target: The flat mapping that is filled in-place.
        resolve: Resolve interpolations when iterating a :py:class:`omegaconf.DictConfig`.
    """
    if isinstance(source, DictConfig):
        items = source.items_ex(resolve=resolve)
    else:
        items = source.items()  # normal case after to_container
    for k, v in items:
        if isinstance(v, (dict, DictConfig)):
            # Forward ``resolve`` so nested DictConfigs are handled like the top level.
            _flatten_dict(v, target, resolve=resolve)
        else:
            if k in target:
                print(f"Warning: Key '{k}'={target[k]} already exists in target. Overwriting with {v}.")
            target[k] = v  # type: ignore[arg-type]


def flatten_config(config: "type[AgentConfig] | AgentConfig", *, resolve: bool = True) -> Dict[str, Any]:
    """
    Returns the data as a flat hierarchy.

    Note:
        Interpolations are replaced by default.
        For example :py:attr:`target_speed` and :py:attr:`max_speed` are two *different* references.
        Also with **resolve=False** the interpolation will be just a string value, as the return
        type is a normal dictionary.

    Raises:
        omegaconf.errors.InterpolationToMissingValueError: If ``resolve`` is set and an
            interpolation points to a missing value.
    """
    try:
        resolved = cast(
            "NestedConfigDict",
            OmegaConf.to_container(
                OmegaConf.structured(config, flags={"allow_objects": True}),
                resolve=resolve,
                throw_on_missing=False,
                structured_config_mode=SCMode.DICT,
            ),
        )
    except omegaconf.errors.InterpolationToMissingValueError:
        print(
            "Resolving has failed because a missing value has been accessed. "
            "Fill all missing values before calling this function or use `resolve=False`."
        )
        # NOTE: alternatively call again with resolve=False
        raise
    options: Dict[str, Any] = {}
    _flatten_dict(resolved, options, resolve=resolve)
    return options


if TYPE_CHECKING:

    @type_check_only
    class ConfigWithOverwrites(Protocol["AgentConfigT"], AgentConfig):  # pyright: ignore[reportGeneralTypeIssues]
        """
        :meta private:
        """

        def __new__(
            cls, overwrites: Optional["OverwriteDictTypes | NestedConfigDict | AgentConfigT"] = None, *args, **kwargs
        ) -> "AgentConfigT":
            """
            :meta public:
            """
            ...

    @type_check_only
    class ConfigWithoutOverwrites(Protocol["AgentConfigT"], AgentConfig):  # pyright: ignore[reportGeneralTypeIssues]
        """
        :meta private:
        """

        def __new__(cls, *args, **kwargs) -> type[AgentConfigT]:
            """
            :meta public:
            """
            ...