import os
from typing import Any, Dict, Tuple

import yaml

from egxc.utils.typing import NnParams
from egxc.xc_energy.functionals.base import BaseEnergyFunctional

from .checkpointing import CheckpointManager


def _deep_set(d: Dict[str, Any], keys: list[str], value: Any) -> None:
    """
    Store ``value`` in ``d`` at the nested location described by ``keys``.

    Every key except the last names an intermediate dictionary. Missing
    intermediates are created, and an intermediate that exists but is not a
    dict is replaced by a fresh empty dict. ``d`` is modified in place.

    Args:
        d: Dictionary to modify.
        keys: Path components; the last one is the leaf key. Must be non-empty
            (an empty list raises ``IndexError``).
        value: Value to store at the leaf.
    """
    leaf_key = keys[-1]

    node: Dict[str, Any] = d
    for part in keys[:-1]:
        child = node.get(part)
        if not isinstance(child, dict):
            # Missing or non-dict entry: start a new nested level here.
            child = node[part] = {}
        node = child

    existing = node.get(leaf_key)
    if isinstance(existing, dict) and isinstance(value, dict):
        # Two dicts meet at the leaf: merge them instead of replacing.
        _deep_update(existing, value)  # type: ignore[arg-type]
        return
    node[leaf_key] = value


def _deep_update(dst: Dict[str, Any], src: Dict[str, Any]) -> Dict[str, Any]:
    """
    Recursively merge ``src`` into ``dst`` in place and return ``dst``.

    Where both sides hold a dict under the same key, the two dicts are merged
    recursively. In every other case the value from ``src`` replaces whatever
    ``dst`` held; the value is stored as-is, not copied.
    """
    for key, incoming in src.items():
        current = dst.get(key) if isinstance(incoming, dict) else None
        if isinstance(current, dict):
            # Both sides are dicts: descend instead of overwriting.
            _deep_update(current, incoming)  # type: ignore[arg-type]
            continue
        dst[key] = incoming
    return dst


def _undot_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """
    Expand dotted keys such as "model.name" into nested dictionaries.

    SEML/Sacred configs may mix dotted keys with already-nested dicts. A new
    dictionary is built: dict values are processed recursively, dotted string
    keys are split on '.' and written along that path, and plain keys are
    copied over. When a plain key carries a dict and the result already holds
    a dict under that key, the two are merged rather than overwritten.
    """
    nested: Dict[str, Any] = {}
    for key, raw in cfg.items():
        value = _undot_config(raw) if isinstance(raw, dict) else raw

        if isinstance(key, str) and '.' in key:
            _deep_set(nested, key.split('.'), value)
            continue

        present = nested.get(key)
        if isinstance(present, dict) and isinstance(value, dict):
            # An earlier dotted key already created this section: merge into it.
            _deep_update(present, value)  # type: ignore[arg-type]
        else:
            nested[key] = value
    return nested


def _load_checkpoint_config(checkpoint_dir: str) -> Dict[str, Any]:
    """
    Load the YAML configuration stored inside a checkpoint directory.

    Regular files ending in ``.yaml`` or ``.yml`` directly inside
    ``checkpoint_dir`` are considered. If several exist, the one with the most
    recent modification time is used. The parsed mapping is passed through
    ``_undot_config`` so that dotted keys become nested dictionaries.

    Args:
        checkpoint_dir: Directory that holds the checkpoint and its config.

    Returns:
        The configuration as a nested dictionary.

    Raises:
        FileNotFoundError: If ``checkpoint_dir`` is not a directory or holds no
            YAML config file.
        ValueError: If the top-level YAML document is not a mapping (this
            includes an empty file).
    """
    if not os.path.isdir(checkpoint_dir):
        raise FileNotFoundError(f'Checkpoint directory not found: {checkpoint_dir}')

    # Only regular files qualify. A sub-directory that happens to be named
    # "*.yaml" would otherwise be selected and then fail inside open().
    yaml_files = sorted(
        path
        for path in (
            os.path.join(checkpoint_dir, name)
            for name in os.listdir(checkpoint_dir)
            if name.endswith(('.yaml', '.yml'))
        )
        if os.path.isfile(path)
    )
    if not yaml_files:
        raise FileNotFoundError(
            f'No checkpoint config YAML found in {checkpoint_dir}. '
            'Expected a *.yaml saved alongside the checkpoint.'
        )

    # Usually there is exactly one. If multiple exist, prefer the newest.
    # The list is sorted first, so ties on mtime resolve to the first name.
    config_path = max(yaml_files, key=os.path.getmtime)

    # Read as UTF-8 explicitly instead of relying on the platform default.
    with open(config_path, 'r', encoding='utf-8') as f:
        # NOTE(security): FullLoader resolves more tags than SafeLoader and is
        # not meant for untrusted input. Kept for compatibility with existing
        # configs; switch to yaml.safe_load if checkpoints can come from
        # untrusted sources and the configs contain only plain YAML types.
        cfg = yaml.load(f, Loader=yaml.FullLoader)
    if not isinstance(cfg, dict):
        raise ValueError(f'Invalid checkpoint config in {config_path}: expected a dict')
    return _undot_config(cfg)


def _unwrap_params_for_functional(params: NnParams) -> NnParams:
    """
    Strip enclosing module scopes so the parameters fit the bare functional.

    Checkpoints are often saved from a larger Flax module (e.g. SCF solver),
    which leaves the functional's parameters under a nested scope such as:
        params['params']['xc_module']['functional'][...]

    Two layouts are recognised inside the top-level "params" collection, tried
    in this order:
        1. 'xc_module' -> 'functional'
        2. 'functional'
    On the first match the sub-tree is returned re-wrapped as
    ``{'params': <sub-tree>}`` so it can be used as
    ``functional.apply(params, ...)``. Anything else, including non-dict
    input, is returned unchanged.
    """
    if not isinstance(params, dict):
        return params

    collection = params.get('params')
    if not isinstance(collection, dict):
        return params

    # Scope paths below the "params" collection, most deeply wrapped first.
    scope_paths = (('xc_module', 'functional'), ('functional',))
    for path in scope_paths:
        node = collection
        for part in path:
            node = node.get(part) if isinstance(node, dict) else None
        if isinstance(node, dict):
            return {'params': node}

    # Already looks like functional params.
    return params


def load_model(
    checkpoint_path: str, prefix: str = ''
) -> Tuple[NnParams, BaseEnergyFunctional]:
    """
    Restore a functional and its parameters from a checkpoint directory.

    Parameters are read through ``CheckpointManager.load_params``. The
    functional is rebuilt by passing the "model" section of the YAML config
    found in the same directory to ``functionals.get_functional`` as keyword
    arguments.

    Args:
        checkpoint_path: Path to the checkpoint directory.
        prefix: Prefix of the checkpoint file. If not specified, loads the best parameters.

    Returns:
        NnParams: The loaded parameters, with enclosing module scopes removed.
        BaseEnergyFunctional: The functional built from the stored config.

    Raises:
        KeyError: If the checkpoint config has no "model" dictionary.
    """
    raw_params = CheckpointManager.load_params(checkpoint_path, prefix=prefix)
    config = _load_checkpoint_config(checkpoint_path)

    model_section = config.get('model')
    if not isinstance(model_section, dict):
        raise KeyError(
            'Checkpoint config does not contain a "model" dict. '
            f'Available top-level keys: {sorted(config.keys())}'
        )

    base_section = config.get('base', {})
    if not isinstance(base_section, dict):
        base_section = {}

    # Work on a shallow copy so the loaded config itself is left untouched.
    model_kwargs: Dict[str, Any] = dict(model_section)

    # Some functionals (hybrids) require these flags if not included in model kwargs.
    # They are taken from the "base" section; a flag missing there becomes None.
    needs_base_flags = {'b3lyp', 'pbe0', 'wb97m-v'}
    if str(model_kwargs.get('name', '')).lower() in needs_base_flags:
        for flag in ('spin_restricted', 'use_density_fitting'):
            model_kwargs.setdefault(flag, base_section.get(flag))

    # Function-level import, presumably to avoid an import cycle - TODO confirm.
    from egxc.xc_energy import functionals

    functional = functionals.get_functional(**model_kwargs)
    return _unwrap_params_for_functional(raw_params), functional
