from __future__ import annotations
import copy
from typing import Callable, Dict, Optional, Type
from phijax.utils import Collection
from phijax.equations.base import PINN

def deep_merge(base: dict, override: Optional[dict]) -> dict:
    """Return a new dict with *override* recursively layered on top of *base*.

    *base* is deep-copied and never modified. For every key in *override*, if
    both the override value and the current merged value are dicts they are
    merged recursively; otherwise the override value replaces the base value.
    A falsy *override* (``None`` or empty) yields a plain deep copy of *base*.

    Note: values taken from *override* are inserted as-is, not copied, so
    mutable leaves (and dict subtrees absent from *base*) are shared with
    *override*.
    """
    merged = copy.deepcopy(base)
    for key, value in (override or {}).items():
        both_dicts = isinstance(value, dict) and isinstance(merged.get(key), dict)
        merged[key] = deep_merge(merged[key], value) if both_dicts else value
    return merged


class _PDEEntry:
    __slots__ = ("cls", "defaults", "pre_hook", "aliases")
    def __init__(
        self,
        cls: Type["PINN"],
        defaults: Optional[dict],
        pre_hook: Optional[Callable[[dict, dict], None]],
        aliases: Optional[list[str]] = None,
    ):
        self.cls = cls
        self.defaults = defaults or {}
        self.pre_hook = pre_hook
        self.aliases = [a.lower() for a in aliases or []]


# Global registry filled by register_pde. Keys are lower-cased PDE names *and*
# lower-cased aliases, so several keys may point at the same _PDEEntry.
_PDE_REGISTRY: Dict[str, _PDEEntry] = {}


def register_pde(
    name: str,
    cls: Optional[Type["PINN"]] = None,
    *,
    defaults: Optional[dict] = None,
    pre_hook: Optional[Callable[[dict, dict], None]] = None,
    aliases: Optional[list[str]] = None,
):
    """Register a PINN class under ``name`` and optional ``aliases``.

    Usable directly, ``register_pde("burgers", Burgers)``, or as a decorator,
    ``@register_pde("burgers")``. Names and aliases are stored lower-cased,
    so lookups through the registry are case-insensitive.

    Args:
        name: Primary registry name. Re-registering an existing name replaces
            its entry; aliases that pointed at the replaced entry may be
            re-bound to the new one.
        cls: Class to register. When omitted, a decorator is returned instead.
        defaults: Default ``pde_config`` values; ``get_pde`` deep-merges the
            user's ``pde_config`` on top of them.
        pre_hook: Optional callable that ``get_pde`` invokes as
            ``pre_hook(config, merged_pde_config)`` before instantiation.
        aliases: Additional names resolving to the same entry. Aliases equal
            to ``name`` (case-insensitively) or repeated aliases are ignored.

    Returns:
        ``cls`` when it is given, otherwise a decorator that registers and
        returns the decorated class.

    Raises:
        ValueError: if an alias is already registered to a different entry.
            All aliases are validated before the registry is touched, so a
            failed registration leaves the registry unchanged.
    """

    def _do_register(c: Type["PINN"]):
        key = name.lower()
        previous = _PDE_REGISTRY.get(key)

        # Validate every alias up front so a conflict cannot leave a
        # half-registered PDE (primary key written, aliases not) behind.
        alias_keys: list[str] = []
        for a in aliases or []:
            al = a.lower()
            if al == key or al in alias_keys:
                continue  # redundant with the primary name or an earlier alias
            owner = _PDE_REGISTRY.get(al)
            # An alias owned by the entry we are about to replace may be
            # re-bound; anything else is a genuine conflict.
            if owner is not None and owner is not previous:
                raise ValueError(f"Alias '{a}' already registered.")
            alias_keys.append(al)

        entry = _PDEEntry(c, defaults, pre_hook, aliases)
        _PDE_REGISTRY[key] = entry
        for al in alias_keys:
            _PDE_REGISTRY[al] = entry
        return c

    if cls is not None:
        return _do_register(cls)
    return _do_register


def _cfg_get(config, key: str, default=None):
    """Read ``key`` from a plain-dict config or from an attribute-style config object."""
    if isinstance(config, dict):
        return config.get(key, default)
    return getattr(config, key, default)


def get_pde(config):
    """Instantiate the registered PINN class selected by ``config``.

    The PDE name comes from ``config.pde`` (preferred) or, failing that, from
    the part of ``config.exp_name`` before the first ``-``. It is looked up
    case-insensitively in the registry. The entry's ``defaults`` are
    deep-merged with the user's ``pde_config`` (user values win). If the entry
    has a ``pre_hook``, it is called as ``pre_hook(config, merged)``; its return
    value is ignored. The class is then constructed with a config whose
    ``pde_config`` is the merged result.

    ``config`` may be a plain ``dict`` or an attribute-style object:

    * dict: it is shallow-copied. The class receives
      ``Collection.from_dict(...)`` of a copy with ``pde_config`` replaced,
      so the caller's dict is not modified.
    * attribute-style: its ``pde_config`` attribute is overwritten in place
      with ``Collection.from_dict(merged)``. The same object is then passed
      to the class.

    Raises:
        ValueError: if neither ``pde`` nor ``exp_name`` is set (or both are None).
        NotImplementedError: if the resolved name is not registered.
    """
    is_dict = isinstance(config, dict)
    cfg = dict(config) if is_dict else config

    name = _cfg_get(cfg, "pde")
    if name is None:
        exp_name = _cfg_get(cfg, "exp_name")
        # A missing *or* None exp_name falls through to the ValueError below
        # instead of crashing on None.split().
        if exp_name is not None:
            name = exp_name.split("-")[0]

    if name is None:
        raise ValueError("Provide config.pde (preferred) or config.exp_name.")

    key = str(name).lower()
    entry = _PDE_REGISTRY.get(key)
    if entry is None:
        raise NotImplementedError(f"Unknown PDE '{name}'. Registered: {sorted(_PDE_REGISTRY.keys())}")

    user_cfg = _cfg_get(cfg, "pde_config") or {}
    merged = deep_merge(entry.defaults, dict(user_cfg))

    if entry.pre_hook is not None:
        entry.pre_hook(cfg, merged)

    if is_dict:
        # Build the final dict separately from the one handed to pre_hook.
        cfg2 = dict(cfg)
        cfg2["pde_config"] = merged
        config2 = Collection.from_dict(cfg2)
    else:
        # NOTE: this overwrites pde_config on the caller's own config object.
        config.pde_config = Collection.from_dict(merged)
        config2 = config

    return entry.cls(config2)


# Names exported by ``from ... import *``. _PDE_REGISTRY is listed explicitly
# despite its leading underscore, so star-importers receive it too.
# NOTE(review): deep_merge is not exported; confirm whether that is intended.
__all__ = ["register_pde", "get_pde", "_PDE_REGISTRY"]
