from __future__ import annotations

import copy
from typing import Callable, Dict, Optional, Type, Any

import torch

from phijax.torch.utils import Collection
from phijax.torch.equations.base import PINN
from phijax.torch.models import get_model


def deep_merge(base: dict, override: Optional[dict]) -> dict:
    """Return a new dict with ``override`` recursively merged over ``base``.

    Where both sides hold a ``dict`` under the same key, the two are merged
    recursively; otherwise the value from ``override`` replaces the one in
    ``base``. Neither input is modified. The result shares no mutable values
    with either input, so callers may mutate it freely.

    Args:
        base: Default values.
        override: Values that take precedence. ``None`` or an empty mapping
            means "no overrides".

    Returns:
        A deep copy of ``base`` updated with deep copies of ``override``'s values.
    """
    out = copy.deepcopy(base)
    if not override:
        return out
    for k, v in override.items():
        if isinstance(v, dict) and isinstance(out.get(k), dict):
            out[k] = deep_merge(out[k], v)
        else:
            # Copy so that later mutation of the merged result (e.g. by a hook)
            # cannot leak back into the caller's `override` structure.
            out[k] = copy.deepcopy(v)
    return out


class _PDEEntry:
    """A single registry record: a PINN class plus its registration options.

    Attributes (stored in ``__slots__``):
        cls: The PINN class registered under a name.
        defaults: Default ``pde_config`` values; ``{}`` when none were given.
        pre_hook: Optional callable invoked with two dicts, or ``None``.
        aliases: Alternative names, lower-cased.
    """

    __slots__ = ("cls", "defaults", "pre_hook", "aliases")

    def __init__(
        self,
        cls: Type[PINN],
        defaults: Optional[dict],
        pre_hook: Optional[Callable[[dict, dict], None]],
        aliases: Optional[list[str]] = None,
    ):
        self.cls = cls
        self.pre_hook = pre_hook
        # Falsy defaults (None or an empty mapping) collapse to a fresh empty dict.
        self.defaults = defaults if defaults else {}
        # Aliases are normalised to lower case for case-insensitive lookup.
        self.aliases = [alias.lower() for alias in aliases] if aliases else []


# Global PDE registry, keyed by lower-cased name. register_pde stores each
# alias as an extra key that points to the same _PDEEntry object as the
# primary name; get_pde looks entries up here.
_PDE_REGISTRY: Dict[str, _PDEEntry] = {}


def register_pde(
    name: str,
    cls: Optional[Type[PINN]] = None,
    *,
    defaults: Optional[dict] = None,
    pre_hook: Optional[Callable[[dict, dict], None]] = None,
    aliases: Optional[list[str]] = None,
):
    """Register a PINN class under ``name`` and optional ``aliases``.

    Call it directly, ``register_pde("heat", HeatPINN)``, or use it as a class
    decorator, ``@register_pde("heat")``. The name and aliases are
    lower-cased before being stored. Re-registering an existing *name*
    silently replaces its entry.

    Args:
        name: Primary registry name.
        cls: The class to register. When omitted, a decorator is returned.
        defaults: Default ``pde_config`` values for this PDE.
        pre_hook: Optional callable stored on the entry.
        aliases: Extra names resolving to the same entry. An alias equal to
            ``name`` (case-insensitively) or repeated in the list is ignored.

    Returns:
        ``cls`` when it is given; otherwise a decorator that registers its
        argument and returns it unchanged.

    Raises:
        ValueError: If an alias is already present in the registry. All
            aliases are validated before the registry is touched, so a failed
            registration leaves the registry unchanged.
    """

    def _do_register(c: Type[PINN]):
        key = name.lower()

        # Validate every alias up front so a collision cannot leave a
        # half-registered entry behind.
        alias_keys: list[str] = []
        for a in aliases or ():
            al = a.lower()
            if al == key or al in alias_keys:
                # Self-alias or duplicate: nothing extra to register.
                continue
            if al in _PDE_REGISTRY:
                raise ValueError(f"Alias '{a}' already registered.")
            alias_keys.append(al)

        entry = _PDEEntry(c, defaults, pre_hook, aliases)
        _PDE_REGISTRY[key] = entry
        for al in alias_keys:
            _PDE_REGISTRY[al] = entry
        return c

    return _do_register(cls) if cls is not None else _do_register


def _infer_pde_name(cfg_dict: Optional[dict], cfg_obj: Any) -> str:
    """Return the PDE name requested by a config.

    The name comes from ``pde`` when it is set (not ``None``). Otherwise it
    comes from the second ``-``-separated field of a non-empty ``exp_name``;
    for example, ``"run-burgers-01"`` gives ``"burgers"``.

    Each key is looked up in ``cfg_dict`` first (when it is a dict), then as
    an attribute of ``cfg_obj``. An empty or partial dict therefore does not
    hide a value carried by the config object.

    Args:
        cfg_dict: Dict view of the config, or ``None``.
        cfg_obj: The original config object; read via ``getattr``.

    Returns:
        The PDE name, converted with ``str``.

    Raises:
        ValueError: If neither ``pde`` nor a usable ``exp_name`` is found.
    """

    def _lookup(key: str) -> Any:
        value = cfg_dict.get(key) if isinstance(cfg_dict, dict) else None
        if value is None:
            # getattr's default only covers AttributeError, matching hasattr semantics.
            value = getattr(cfg_obj, key, None)
        return value

    name = _lookup("pde")
    if name is None:
        exp_name = _lookup("exp_name")
        if exp_name:
            parts = str(exp_name).split("-")
            if len(parts) >= 2:
                name = parts[1]
    if name is None:
        raise ValueError("Provide config.pde (preferred) or config.exp_name including PDE name.")
    return str(name)


def get_pde(config: Any, *, device: Optional[torch.device] = None) -> PINN:
    """Instantiate the registered PINN (and its model) described by ``config``.

    Steps:
      1. Resolve the PDE name with ``_infer_pde_name``.
      2. Deep-merge the entry's registered defaults with the user's
         ``pde_config``.
      3. Run the entry's optional ``pre_hook``.
      4. Build the model with ``get_model``.
      5. Construct ``entry.cls``.

    Args:
        config: Either a ``dict``, or an object exposing ``pde`` / ``exp_name`` /
            ``pde_config`` attributes that can be converted with ``dict(...)``.
        device: Forwarded to ``get_model`` and to the PINN constructor.

    Returns:
        The constructed PINN instance.

    Raises:
        ValueError: If no PDE name can be inferred from ``config``.
        NotImplementedError: If the inferred name is not registered.

    Note:
        For a dict ``config``, only shallow copies are edited, so its top-level
        keys are left untouched. A non-dict ``config`` is modified in place:
        its ``pde_config`` attribute is replaced by the merged ``Collection``.
    """
    is_mapping = isinstance(config, dict)
    # Shallow copy so top-level edits below (and by pre_hook) stay off the caller's dict.
    cfg_dict = dict(config) if is_mapping else None

    # Pass None, not an empty dict, for non-dict configs so the name is read
    # from the config object's attributes instead.
    name = _infer_pde_name(cfg_dict, config)

    key = name.lower()
    print(f"Getting PDE '{key}'")
    if key not in _PDE_REGISTRY:
        raise NotImplementedError(f"Unknown PDE '{name}'. Registered: {sorted(_PDE_REGISTRY.keys())}")

    entry = _PDE_REGISTRY[key]

    if is_mapping:
        user_cfg = cfg_dict.get("pde_config", {}) or {}
    else:
        user_cfg = getattr(config, "pde_config", {}) or {}
    merged = deep_merge(entry.defaults, dict(user_cfg))

    if entry.pre_hook is not None:
        # NOTE(review): for non-dict configs the hook receives a throwaway
        # dict(config) copy, so edits it makes to its first argument are
        # discarded; only edits to `merged` take effect.
        entry.pre_hook(cfg_dict if is_mapping else dict(config), merged)

    if is_mapping:
        cfg2 = dict(cfg_dict)
        cfg2["pde_config"] = merged
        config2 = Collection.from_dict(cfg2)
    else:
        config2 = config
        config2.pde_config = Collection.from_dict(merged)

    model = get_model(dict(config2), device=device)
    print(model)

    return entry.cls(config2, model=model, device=device)


# Names exported by `from ... import *`. The underscore-prefixed registry is
# listed explicitly (presumably so callers can inspect registrations — confirm);
# deep_merge is not exported.
__all__ = ["register_pde", "get_pde", "_PDE_REGISTRY"]
