"""Config helper."""

import ast
import collections
import copy
from typing import Any, Dict, Text, List, Callable, Optional
import tensorflow as tf
import yaml

# Placeholder value for config members that must be filled in before use;
# Config.validate() raises ValueError listing top-level members still set to it.
REQUIRED = '__required__'


def eval_str_fn(val: str) -> Any:
  """Converts one override value string into a Python value.

  Conversion rules, in order:
    * a string containing '|' becomes a list, each '|'-separated piece
      converted recursively (e.g. '1|true|abc' -> [1, True, 'abc']);
    * 'true' / 'false' (lower-case) become booleans;
    * anything `ast.literal_eval` accepts becomes that literal;
    * everything else is returned unchanged as a string.

  Args:
    val: the raw value text.

  Returns:
    The converted value.
  """
  if '|' in val:
    return [eval_str_fn(v) for v in val.split('|')]
  if val in {'true', 'false'}:
    return val == 'true'
  try:
    return ast.literal_eval(val)
  except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
    # These are the documented failure modes of ast.literal_eval for malformed
    # input (e.g. '{[1]}' raises TypeError: unhashable list). Any of them means
    # "not a Python literal", so fall back to the raw string.
    return val


class ConfigKeyError(KeyError, AttributeError):
  """Raised when a missing Config member is read as an attribute.

  It subclasses KeyError so existing `except KeyError` handlers around
  `config.some_key` keep working. It also subclasses AttributeError so that
  `hasattr`, `getattr(config, name, default)`, pickling and other attribute
  probes treat a missing key as an absent attribute.
  """


# pylint: disable=protected-access
class Config(dict):
  """A config utility class.

  A `dict` subclass whose members are stored in the instance `__dict__` (not
  in the underlying dict storage). Members can therefore be read both as
  attributes (`cfg.a.b`) and as items (`cfg['a']['b']`). On assignment, plain
  dicts are converted into `Config` objects and every other value is
  deep-copied.
  """

  def __init__(self, *args, **kwargs):
    super().__init__()
    input_config_dict = dict(*args, **kwargs)
    self.update(input_config_dict)

  def __len__(self):
    return len(self.__dict__)

  def __setattr__(self, k: str, v: Any) -> None:
    if isinstance(v, dict) and not isinstance(v, Config):
      self.__dict__[k] = Config(v)
    else:
      self.__dict__[k] = copy.deepcopy(v)

  def __getattr__(self, k: str) -> Any:
    # Only reached when normal attribute lookup fails. The data model requires
    # an AttributeError here. A bare KeyError breaks hasattr(),
    # getattr(..., default) and unpickling, which probes `__setstate__`.
    try:
      return self.__dict__[k]
    except KeyError:
      raise ConfigKeyError(k) from None

  def __setitem__(self, k: str, v: Any) -> None:
    self.__setattr__(k, v)

  def __getitem__(self, k: str) -> Any:
    return self.__dict__[k]

  def __delitem__(self, k: str) -> None:
    # Members live in __dict__. The inherited dict.__delitem__ operates on the
    # always-empty dict storage and could never find them.
    del self.__dict__[k]

  def __contains__(self, k: str) -> bool:
    return self.__dict__.__contains__(k)

  def __iter__(self):
    for key in self.__dict__:
      yield key

  def __eq__(self, other: Any) -> bool:
    # The inherited dict.__eq__ compares the (always-empty) dict storage, so
    # every pair of Configs used to compare equal. Compare the members instead.
    if isinstance(other, Config):
      other = other.__dict__
    if not isinstance(other, dict):
      return NotImplemented
    return self.__dict__ == other

  def __ne__(self, other: Any) -> bool:
    # dict.__ne__ would still compare the empty storage, so mirror __eq__.
    equal = self.__eq__(other)
    return equal if equal is NotImplemented else not equal

  def items(self):
    for key, value in self.__dict__.items():
      yield key, value

  def values(self):
    """Yields member values, consistent with `keys()` and `items()`."""
    for value in self.__dict__.values():
      yield value

  def replace(self, **kwargs) -> 'Config':
    """Deep copies and replaces some values."""
    cfg = copy.deepcopy(self)
    cfg.update(dict(**kwargs))
    return cfg

  def __repr__(self):
    return repr(self.as_dict())

  def __getstate__(self):
    return self.__dict__

  def __copy__(self):
    cls = self.__class__
    result = cls.__new__(cls)
    result.__dict__.update(self.__dict__)
    return result

  def __deepcopy__(self, memo: Any) -> 'Config':
    """Deep-copies all members, honouring the `copy.deepcopy` memo."""
    if memo is None:
      memo = {}
    cls = self.__class__
    result = cls.__new__(cls)
    # Register before copying members so shared and self references resolve
    # to the copies rather than being duplicated.
    memo[id(self)] = result
    for k, v in self.__dict__.items():
      result.__dict__[k] = copy.deepcopy(v, memo)
    return result

  def __str__(self):
    try:
      return yaml.dump(self.as_dict(), indent=4)
    except TypeError:
      return str(self.as_dict())

  def _update(self, config_dict, allow_new_keys=True, skip_new_keys=False):
    """Recursively updates internal members.

    Args:
      config_dict: mapping (dict or Config) of new values. Dict values are
        merged into existing Config members instead of replacing them.
      allow_new_keys: if True, keys not yet present are added.
      skip_new_keys: only used when `allow_new_keys` is False. If True,
        unknown keys are silently ignored; otherwise they raise KeyError.

    Raises:
      KeyError: for an unknown key when neither new keys are allowed nor
        skipped (at any nesting level).
    """
    if not config_dict:
      return

    for k, v in config_dict.items():
      if k not in self.__dict__:
        if allow_new_keys:
          self.__setattr__(k, v)
        elif not skip_new_keys:
          raise KeyError('Key `{}` does not exist for overriding. '.format(k))
        # Otherwise the unknown key is deliberately skipped.
      elif isinstance(self.__dict__[k], Config) and isinstance(v, dict):
        # Config is itself a dict subclass, so this also covers Config values.
        # Both flags are forwarded so nested levels follow the same policy.
        self.__dict__[k]._update(v, allow_new_keys, skip_new_keys)
      else:
        self.__setattr__(k, v)

  def get(self, k: str, default_value: Any = None) -> Any:
    return self.__dict__.get(k, default_value)

  def update(self, config_dict: Optional[Dict[Any, Any]] = None) -> None:
    """Updates members while allowing new keys."""
    if config_dict:
      self._update(config_dict, allow_new_keys=True)

  def update_dict(self, **kwargs):
    """Updates members while allowing new keys."""
    self._update(kwargs, allow_new_keys=True)

  def keys(self) -> Any:
    return self.__dict__.keys()

  def override_dict(self, skip_new_keys: bool = True, **kwargs) -> None:
    """Overrides members and skips new keys (at every nesting level)."""
    self._update(kwargs, allow_new_keys=False, skip_new_keys=skip_new_keys)

  def override(self, config_dict_or_str: Any,
               allow_new_keys: bool = False) -> None:
    """Updates members while disallowing new keys.

    Args:
      config_dict_or_str: a dict, a 'k=v,...' override string (any string
        containing '='), or a path ending in '.yaml'. Falsy values are a
        no-op.
      allow_new_keys: whether keys not yet present may be added.

    Raises:
      ValueError: for an unsupported string or value type.
      KeyError: for unknown keys when `allow_new_keys` is False.
    """
    if not config_dict_or_str:
      return
    if isinstance(config_dict_or_str, str):
      if '=' in config_dict_or_str:
        config_dict = self.parse_from_str(config_dict_or_str)
      elif config_dict_or_str.endswith('.yaml'):
        config_dict = self.parse_from_yaml(config_dict_or_str)
      else:
        raise ValueError(
            'Invalid string {}, must end with .yaml or contains "=".'.format(
                config_dict_or_str))
    elif isinstance(config_dict_or_str, dict):
      config_dict = config_dict_or_str
    else:
      raise ValueError('Unknown value type: {}'.format(config_dict_or_str))

    self._update(config_dict, allow_new_keys)

  def parse_from_yaml(self, yaml_file_path: Text) -> Dict[Any, Any]:
    """Parses a yaml file and returns a dictionary."""
    with tf.io.gfile.GFile(yaml_file_path, 'r') as f:
      # NOTE(review): FullLoader can construct more than plain data. If this
      # path may point at untrusted files, consider yaml.safe_load instead.
      config_dict = yaml.load(f, Loader=yaml.FullLoader)
      return config_dict

  def save_to_yaml(self, yaml_file_path: str) -> None:
    """Writes a dictionary into a yaml file."""
    with tf.io.gfile.GFile(yaml_file_path, 'w') as f:
      yaml.dump(self.as_dict(), f, default_flow_style=False)

  def parse_from_str(self, config_str: Text) -> Dict[Any, Any]:
    """Parses a string like 'x.y=1,x.z=2' to nested dict {x: {y: 1, z: 2}}.

    Pairs are separated by ',' and each pair is split at its first '=', so
    values may themselves contain '='. Keys and values are stripped of
    surrounding whitespace, and values are converted with `eval_str_fn`.
    Later pairs override earlier ones; nested keys are merged.

    Args:
      config_str: the override string. Empty pairs (e.g. from a trailing
        comma) are ignored.

    Returns:
      The nested dictionary; `{}` for an empty string.

    Raises:
      ValueError: if a non-empty pair contains no '='.
    """
    if not config_str:
      return {}

    def nest(dotted_key, value_str):
      """Turns 'x.y.z' and a value string into {x: {y: {z: value}}}."""
      *parents, leaf = dotted_key.split('.')
      nested = {leaf: eval_str_fn(value_str)}
      for parent in reversed(parents):
        nested = {parent: nested}
      return nested

    def merge(target, src):
      """Recursively merges dict `src` into dict `target` in place."""
      for k, v in src.items():
        if (k in target and isinstance(target[k], dict) and
            isinstance(v, collections.abc.Mapping)):
          merge(target[k], v)
        else:
          target[k] = v

    config_dict = {}
    try:
      for kv_pair in config_str.split(','):
        # We skip the empty string here.
        if not kv_pair:
          continue
        key_str, value_str = kv_pair.split('=', 1)
        merge(config_dict, nest(key_str.strip(), value_str.strip()))
    except ValueError:
      raise ValueError(f'Invalid config_str: {config_str}') from None
    return config_dict

  def as_dict(self) -> Dict[Any, Any]:
    """Returns a dict representation."""
    config_dict = {}
    for k, v in self.__dict__.items():
      if isinstance(v, Config):
        config_dict[k] = v.as_dict()
      elif isinstance(v, (list, tuple)):
        config_dict[k] = [
            i.as_dict() if isinstance(i, Config) else copy.deepcopy(i)
            for i in v
        ]
      else:
        config_dict[k] = copy.deepcopy(v)
    return config_dict

  def validate(self):
    """Raises ValueError if any top-level member still equals `REQUIRED`.

    Only direct members are checked; nested Config members are not inspected.
    """
    # The isinstance guard keeps array-like values (whose == is element-wise)
    # from producing ambiguous truth values.
    required_keys = [
        k for k, v in self.__dict__.items()
        if isinstance(v, str) and v == REQUIRED
    ]
    if required_keys:
      raise ValueError(f'Values are required for keys: {required_keys}')


class RegistryFactor():
  """A template for registry factory.

  Maps case-insensitive names to registered objects (typically classes). Both
  the name given to `register` (or the decorated object's `__name__` when no
  name is given) and the name given to `lookup` are lower-cased. They are then
  prefixed with `prefix` to form the registry key.
  """

  def __init__(self, prefix: str) -> None:
    """Creates an empty registry.

    Args:
      prefix: string prepended to every (lower-cased) name to form its key.
    """
    self.registry_map = {}
    self.prefix = prefix

  def _make_key(self, name: str) -> str:
    """Builds the registry key for `name`: prefix + lower-cased name."""
    return self.prefix + name.lower()

  def register(self, name: Optional[str] = None) -> Callable[[Any], Any]:
    """Registers a function, mainly for config here.

    Args:
      name: registration name; defaults to the decorated object's `__name__`.
        Matching is case-insensitive, consistent with `lookup`.

    Returns:
      A decorator that records its argument and returns it unchanged.

    Raises:
      ValueError: (from the decorator) if the key is already registered.
    """

    def decorator(cls):
      key = self._make_key(name or cls.__name__)
      if key in self.registry_map:
        raise ValueError(f'{key} is already registered')
      self.registry_map[key] = cls
      return cls

    return decorator

  def lookup(self, name: str) -> Any:
    """Looks up a registered object by (case-insensitive) name.

    Raises:
      KeyError: if nothing is registered under `name`.
    """
    key = self._make_key(name)
    if key not in self.registry_map:
      raise KeyError(f'{key} is not in {self.registry_map.keys()}')
    return self.registry_map[key]

  def keys(self, prefix: str = '') -> List[str]:
    """Returns registered keys starting with `prefix`, with `prefix` removed."""
    return [
        k[len(prefix):]
        for k in self.registry_map
        if k.startswith(prefix)
    ]


def create_config_from_dict(config_dict: Dict[str, Any],
                            required_keys: List[str],
                            optional_keys: Dict[str, Any]) -> Config:
  """Builds a Config holding selected entries of `config_dict`.

  Args:
    config_dict: source mapping to read values from.
    required_keys: keys that must be present in `config_dict`; they are copied
      in this order.
    optional_keys: mapping of key -> fallback value. The value from
      `config_dict` is used when the key is present there, otherwise the
      fallback.

  Returns:
    A new Config containing the required and optional keys.

  Raises:
    ValueError: naming the first required key absent from `config_dict`.
  """
  result = Config()
  for name in required_keys:
    if name not in config_dict:
      raise ValueError('Required key %s missed in config dict.' % name)
    result[name] = config_dict[name]
  for name, fallback in optional_keys.items():
    result[name] = config_dict[name] if name in config_dict else fallback
  return result
