# file: prism/utils/config.py
import argparse
from collections.abc import MutableMapping
import yaml


class AttrDict(dict):
    """A ``dict`` whose keys can also be read, written and deleted as attributes.

    Nested plain ``dict`` values are converted to ``AttrDict`` on construction
    and on attribute assignment, so ``cfg.model.type`` works at any depth.
    Item assignment (``cfg['k'] = {...}``), ``update`` and ``setdefault`` do
    not convert, and dicts stored inside lists are left as plain dicts.

    Keys that collide with ``dict`` methods (``items``, ``keys``, ...) are
    only reachable via ``cfg['items']``: ``__getattr__`` is consulted only
    after normal attribute lookup fails.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        for key, value in self.items():
            if isinstance(value, dict):
                # Replacing the value of an existing key does not resize the
                # dict, so doing it while iterating is safe.
                self[key] = AttrDict(value)

    def __getattr__(self, item):
        try:
            return self[item]
        except KeyError as e:
            # AttributeError (not KeyError) keeps getattr/hasattr/copy working.
            raise AttributeError(f"'AttrDict' object has no attribute '{item}'") from e

    def __setattr__(self, key, value):
        # Mirror __init__: wrap plain dicts so attribute access keeps working
        # one level down. Existing AttrDicts are stored as-is (no copy).
        if isinstance(value, dict) and not isinstance(value, AttrDict):
            value = AttrDict(value)
        self[key] = value

    def __delattr__(self, key):
        # Without this, ``del cfg.key`` would look in the (empty) instance
        # __dict__ and fail even though the key exists.
        try:
            del self[key]
        except KeyError as e:
            raise AttributeError(f"'AttrDict' object has no attribute '{key}'") from e

    def to_dict(self):
        """Return a deep plain-``dict`` copy of this mapping.

        Every nested ``dict`` (``AttrDict`` or not), including dicts found
        inside lists, is rebuilt as a plain ``dict``. Lists are rebuilt as
        new lists. Other values are returned as-is (not copied).
        """
        def convert(value):
            if isinstance(value, dict):
                return {k: convert(v) for k, v in value.items()}
            if isinstance(value, list):
                return [convert(v) for v in value]
            return value

        return convert(self)


def deep_merge(base_dict, new_dict):
    """Recursively merge *new_dict* into *base_dict* in place and return it.

    When both sides hold a mutable mapping under the same key, the two are
    merged recursively. Otherwise the value from *new_dict* replaces (or adds)
    the entry. Values taken from *new_dict*, including mappings, are stored by
    reference rather than copied. Later in-place edits to the result can
    therefore show up in *new_dict*.
    """
    for name, incoming in new_dict.items():
        merge_recursively = (
            isinstance(incoming, MutableMapping)
            and name in base_dict
            and isinstance(base_dict[name], MutableMapping)
        )
        base_dict[name] = deep_merge(base_dict[name], incoming) if merge_recursively else incoming
    return base_dict


def apply_cli_overrides(config, overrides):
    """Apply command-line overrides of the form ``--dotted.key value`` in place.

    Each ``--`` option names a config path with dot-separated parts, so
    ``--model.type fcn`` sets ``config['model']['type']``. Both
    ``--key value`` and ``--key=value`` are accepted. Values are parsed with
    ``yaml.safe_load`` so that e.g. ``3``, ``true`` or ``[1, 2]`` become int,
    bool or list. If YAML parsing fails, the raw string is used.

    argparse derives each destination by stripping the leading dashes and
    replacing ``-`` with ``_``, so ``--fcn-params.x`` writes to
    ``fcn_params.x``. Tokens that are not part of an override are ignored.

    Args:
        config: Mutable mapping updated in place. Missing intermediate
            sections are created as plain dicts.
        overrides: Sequence of CLI tokens. ``None`` or empty is a no-op.

    Raises:
        TypeError: If a dotted path passes through an existing value that is
            not a mutable mapping.
        SystemExit: Raised by argparse when an override is missing its value
            (e.g. a trailing ``--key``).
    """
    tokens = list(overrides or [])

    # Collect option names in first-seen order (deterministic, no duplicates).
    # ``--key value`` consumes two tokens. ``--key=value`` consumes one;
    # argparse splits it on '=' itself once ``--key`` is registered.
    option_names = {}
    i = 0
    while i < len(tokens):
        token = tokens[i]
        name, sep, _ = token.partition('=')
        if token.startswith('--') and len(name) > 2:  # skip a bare '--'
            option_names[name] = None
            i += 1 if sep else 2
        else:
            i += 1
    if not option_names:
        return

    # allow_abbrev=False: a value token must never be prefix-matched to a key.
    parser = argparse.ArgumentParser(
        description="CLI Overrides", add_help=False, allow_abbrev=False
    )
    for name in option_names:
        parser.add_argument(name)

    # Parse the explicit token list (never sys.argv).
    known_args, _ = parser.parse_known_args(tokens)

    for dest, raw in vars(known_args).items():
        if raw is None:
            continue
        try:
            value = yaml.safe_load(raw)
        except (yaml.YAMLError, TypeError):
            value = raw  # not valid YAML: keep the literal string

        *parents, leaf = dest.split('.')
        node = config
        for part in parents:
            node = node.setdefault(part, {})
            if not isinstance(node, MutableMapping):
                raise TypeError(
                    f"Cannot apply override '{dest}': '{part}' holds a "
                    f"{type(node).__name__}, not a mapping"
                )
        node[leaf] = value


def apply_dict_overrides(config, overrides_dict):
    """Apply ``{"dotted.key": value}`` overrides to *config* in place.

    Each key is split on ``.``. Missing intermediate sections are created as
    plain dicts, and the final component is assigned *value* unchanged (no
    parsing).

    Args:
        config: Mutable mapping updated in place.
        overrides_dict: Mapping of dotted paths to values. ``None`` or empty
            is a no-op.

    Raises:
        TypeError: If a path passes through an existing value that is not a
            mutable mapping (e.g. ``"a.b"`` when ``config["a"]`` is ``1``).
    """
    for dotted_key, value in (overrides_dict or {}).items():
        *parents, leaf = dotted_key.split('.')
        node = config
        for part in parents:
            node = node.setdefault(part, {})
            # Fail with a clear message instead of a depth-dependent
            # TypeError/AttributeError from indexing a scalar.
            if not isinstance(node, MutableMapping):
                raise TypeError(
                    f"Cannot apply override '{dotted_key}': '{part}' holds a "
                    f"{type(node).__name__}, not a mapping"
                )
        node[leaf] = value


def process_derived_config(config):
    """Resolve dataset properties and compute derived latent-space fields.

    1. Deep-merges ``config['data']['properties'][config['data']['name']]``
       over a shallow copy of ``config['data']`` and stores the result back
       as ``config['data']``.
    2. If ``model.type == 'fcn'``:
       - computes the encoder downsampling as
         ``downsampling_factor ** len(h_dims)``;
       - computes the latent grid from ``data.image_shape[1]`` and ``[2]``
         (presumably (C, H, W) — confirm);
       - overwrites ``latent_space.target_dim``, ``nontarget_dim`` and
         ``latent_dim`` with channels * latent_h * latent_w;
       - stores ``latent_h``, ``latent_w`` and ``nontarget_channels`` in
         ``fcn_params``.

       Otherwise it sets ``latent_space.nontarget_dim`` to
       ``latent_dim - target_dim``.
    3. Stores slice bounds in ``latent_space``:
       ``target_slice_start/stop = 0 / target_dim`` and
       ``nontarget_slice_start/stop = target_dim / latent_dim``.

    Args:
        config: An ``AttrDict`` (processed via its ``to_dict()`` copy) or a
            plain dict. A plain dict is modified in place.

    Returns:
        A new ``AttrDict`` wrapping the processed config.

    Raises:
        ValueError: If ``data.properties`` has no entry for ``data.name``.
        KeyError: If a required config key is missing.
    """
    if isinstance(config, AttrDict):
        config = config.to_dict()

    # --- 1. dataset-specific properties ---------------------------------
    data_cfg = config['data']
    dataset_name = data_cfg['name']
    all_props = data_cfg['properties']
    if dataset_name not in all_props:
        raise ValueError(f"Properties for dataset '{dataset_name}' not found in config.")
    # dict(...) avoids merging into the original top-level 'data' mapping;
    # nested sections are still shared with it and may be updated.
    config['data'] = deep_merge(dict(data_cfg), all_props[dataset_name])

    # --- 2. latent sizes -------------------------------------------------
    model_cfg = config['model']
    if model_cfg['type'] == 'fcn':
        encoder = model_cfg['architecture']['conv']['encoder']
        # Presumably one downsampling step per entry in h_dims.
        total_downsample = encoder['downsampling_factor'] ** len(encoder['h_dims'])
        # Read after the merge: image_shape may come from the dataset props.
        image_shape = config['data']['image_shape']
        latent_h = image_shape[1] // total_downsample
        latent_w = image_shape[2] // total_downsample

        fcn = model_cfg['fcn_params']
        latent_channels = fcn['latent_channels']
        target_channels = fcn['target_channels']
        nontarget_channels = latent_channels - target_channels

        flat_target = target_channels * latent_h * latent_w
        flat_nontarget = nontarget_channels * latent_h * latent_w

        space = model_cfg['latent_space']
        space['target_dim'] = flat_target
        space['nontarget_dim'] = flat_nontarget
        space['latent_dim'] = flat_target + flat_nontarget

        fcn['latent_h'] = latent_h
        fcn['latent_w'] = latent_w
        fcn['nontarget_channels'] = nontarget_channels
    else:
        space = model_cfg['latent_space']
        space['nontarget_dim'] = space['latent_dim'] - space['target_dim']

    # --- 3. slice bounds -------------------------------------------------
    space = model_cfg['latent_space']
    target_dim = space['target_dim']
    latent_dim = space['latent_dim']
    space.update(
        target_slice_start=0,
        target_slice_stop=target_dim,
        nontarget_slice_start=target_dim,
        nontarget_slice_stop=latent_dim,
    )

    return AttrDict(config)


def load_config(config_path=None, cli_overrides=None):
    """Load a YAML config file, apply CLI overrides, and wrap it in an AttrDict.

    Args:
        config_path: Path to a YAML file. When falsy, start from an empty
            config.
        cli_overrides: Optional ``--dotted.key value`` tokens forwarded to
            ``apply_cli_overrides`` when non-empty.

    Returns:
        ``AttrDict`` built from the loaded and overridden config.

    Raises:
        TypeError: If the YAML document's top level is not a mapping.
    """
    config = {}
    if config_path:
        # YAML defaults to UTF-8; don't depend on the platform locale.
        with open(config_path, 'r', encoding='utf-8') as f:
            loaded = yaml.safe_load(f)
        # An empty (or comment-only) file loads as None: treat it as an empty
        # config instead of crashing later in AttrDict(None).
        if loaded is not None:
            if not isinstance(loaded, dict):
                raise TypeError(
                    f"Top-level YAML in '{config_path}' must be a mapping, "
                    f"got {type(loaded).__name__}"
                )
            config = loaded

    if cli_overrides:
        apply_cli_overrides(config, cli_overrides)

    return AttrDict(config)