# -*- coding: utf-8 -*-
import yaml


class ModelDict(dict):
  """Nested dictionary of models with attribute-style read access.

  Every plain ``dict`` value passed at construction time is converted into a
  ``ModelDict`` as well, so an entry can be read either as
  ``models['oL0']['qO1']`` or as ``models.oL0.qO1``.

  Layout used by ``set_qold_eval`` / ``set_train`` (layer -> model name ->
  object), for example::

    {
      'oL0': {
        'qO1': qA1,
        'qO1_old': deepcopy(qA1),
        'qO1_optim': qA1_optim,
        'qO2': qA2,
        'qO2_old': deepcopy(qA2),
        'qO2_optim': qA2_optim,
      },
      'oL1': {
        'qO1': qO1,
        'qO1_old': deepcopy(qO1),
        'qO1_optim': qO1_optim,
        'qO2': qO2,
        'qO2_old': deepcopy(qO2),
        'qO2_optim': qO2_optim,
      },
    }
  """

  def __init__(self, init_dict=None):
    """Create an empty mapping, optionally populated from ``init_dict``.

    Args:
      init_dict: mapping to copy entries from; ``None`` leaves the instance
        empty.
    """
    super().__init__()
    if init_dict is not None:
      self._recursive_update(init_dict)

  def _recursive_update(self, init_dict):
    """Copy ``init_dict`` into ``self``, wrapping nested dicts in ``ModelDict``."""
    for key, value in init_dict.items():
      # Nested dicts are re-wrapped so attribute access works at every level.
      self[key] = ModelDict(value) if isinstance(value, dict) else value

  def __getattr__(self, attr):
    """Resolve an unknown attribute name as a dictionary key.

    Only called when normal attribute lookup fails, so real ``dict``
    attributes and methods are never shadowed by keys.

    Raises:
      AttributeError: if ``attr`` is not a key of this mapping.
    """
    try:
      return self[attr]
    except KeyError:
      # Report the actual class name (the message used to say 'ParamsDict').
      raise AttributeError(
          f"'{type(self).__name__}' object has no attribute '{attr}'") from None

  def set_qold_eval(self):
    """Call ``.eval()`` on every entry whose name ends with ``'_old'``."""
    for layer in self:
      for model in self[layer]:
        if model.endswith('_old'):
          self[layer][model].eval()

  def set_train(self, mode: bool):
    """Call ``.train(mode)`` on every entry that is neither ``_old`` nor ``_optim``.

    An entry is skipped when its name contains ``'_old'`` or ``'_optim'``
    anywhere (substring match, not only as a suffix).

    Args:
      mode: value forwarded unchanged to each selected entry's ``train``.
    """
    not_suffix = ('_old', '_optim')
    for layer in self:
      for model in self[layer]:
        if not any(suf in model for suf in not_suffix):
          self[layer][model].train(mode)


class ConfigDict(dict):
  """Flat configuration mapping whose keys double as attributes.

  The input is flattened one level (see ``flat``) before being stored, and
  the instance ``__dict__`` is aliased to the mapping itself, so
  ``cfg.key`` and ``cfg['key']`` read and write the same entry.
  """

  def __init__(self, config_dict):
    """Build the mapping from ``config_dict`` after one-level flattening."""
    super(ConfigDict, self).__init__(self.flat(config_dict))
    # Share storage between attributes and items.
    self.__dict__ = self

  def __setattr__(self, __name: str, __value) -> None:
    """Delegate attribute assignment to the default implementation."""
    return super().__setattr__(__name, __value)

  def flat(self, config):
    """Return a new plain dict with one level of nesting removed.

    A value that is exactly a ``dict`` has its items promoted to the top
    level and its own key dropped; any other value is kept under its
    original key. Later keys overwrite earlier ones on collision.
    """
    flattened = {}
    for name in config:
      entry = config[name]
      if type(entry) is dict:
        # Section: discard the section name, hoist its children.
        flattened.update(entry)
        continue
      # Scalar (or non-dict) entry: carried over untouched.
      flattened[name] = entry
    return flattened


def load_config(path="ts_ppo_config.yaml"):
  """Load the YAML file at ``path`` and return it as a flat ``ConfigDict``.

  The top-level ``debug_config`` section is removed from the main
  configuration, its ``debug_flag`` is copied to the top level, and the rest
  is flattened by ``ConfigDict``. The debug section is then passed to
  ``merge_debug`` together with the resulting config.

  Args:
    path: location of the YAML configuration file.

  Returns:
    The ``ConfigDict`` returned by ``merge_debug``.

  Raises:
    KeyError: if the file has no ``debug_config`` section or that section has
      no ``debug_flag`` entry.
  """
  # Decode explicitly as UTF-8 so parsing does not depend on the platform's
  # default locale encoding.
  with open(path, 'r', encoding='utf-8') as f:
    em_config = yaml.safe_load(f)

  # The file is only needed for parsing; the rest runs after it is closed.
  debug_config = em_config.pop('debug_config')
  em_config['debug_flag'] = debug_config['debug_flag']

  config = ConfigDict(em_config)
  return merge_debug(debug_config, config)


def merge_debug(debug_config, config):
  """Overlay ``debug_config`` onto ``config`` when debugging is enabled.

  If ``debug_config['debug_flag']`` is truthy, every entry of
  ``debug_config`` is written into ``config`` in place; otherwise ``config``
  is left untouched. The same ``config`` object is returned in both cases.
  """
  if not debug_config['debug_flag']:
    return config
  config.update(debug_config)
  return config
