Source code for pytext.config.utils

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Dict, List, Union


[docs]def is_component_class(obj): first = obj.__class__.__name__[0] return first.upper() == first
[docs]def find_param(root, suffix, parent=""): """ Recursively look at all fields in config to find where `suffix` would fit. This is used to change configs so that they don't use default values. Return the list of field paths matching. """ ret = [] for k in getattr(root.__class__, "__annotations__", []): here = parent + k if here.endswith(suffix): ret += [here] v = getattr(root, k) if v is not None and is_component_class(type(v)): ret += find_param(v, suffix, parent=here + ".") return ret
[docs]def resolve_optional(type_v): """Deal with Optional implemented as Union[type, None]""" if getattr(type_v, "__origin__", None) == Union and len(type_v.__args__) == 2: if type_v.__args__[0] != type(None): return type_v.__args__[0] return type_v.__args__[1] return type_v
[docs]def cast_str(to_type, value): if type(value) != str: return value if to_type == int: return int(value) elif to_type == float: return float(value) elif to_type == str: return value elif to_type == bool: if value.lower() in ("yes", "true", "t", "1"): return True elif value.lower() in ("no", "false", "f", "0", ""): return False else: raise Exception(f'Not a boolean value: "{value}"') elif getattr(to_type, "__origin__", None) in (list, List): return [cast_str(to_type.__args__[0], v.strip()) for v in value.split(",")] elif getattr(to_type, "__origin__", None) in (dict, Dict): key_type, value_type = to_type.__args__ ret = {} for entry in value.split(","): k, v = entry.split(":") typed_k = cast_str(key_type, k) typed_v = cast_str(value_type, v) ret[typed_k] = typed_v return ret else: raise Exception(f"Unsupported type: {to_type}")
[docs]def replace_param(root, path_list, value): for here in path_list[:-1]: root = getattr(root, here) param_name = path_list[-1] annotation = root.__class__.__annotations__[param_name] type_root = resolve_optional(annotation) setattr(root, param_name, cast_str(type_root, value))