from __future__ import annotations

import copy
import math
from dataclasses import dataclass
from typing import Callable, Dict, Optional, Type, Tuple, Any

import torch
from torch import nn

from phijax.torch.models.base import *


def deep_merge(base: dict, override: Optional[dict]) -> dict:
    """Return a deep copy of *base* with *override* merged in recursively.

    When a key maps to a dict in both inputs, the two dicts are merged key by
    key. Any other value from *override* replaces the copied value. Neither
    input is mutated. Values taken from *override* are inserted by reference,
    not copied.
    """
    merged = copy.deepcopy(base)
    # A falsy override (None, {}) contributes nothing; the copy is returned as-is.
    for key, value in (override or {}).items():
        current = merged.get(key)
        if isinstance(value, dict) and isinstance(current, dict):
            merged[key] = deep_merge(current, value)
        else:
            merged[key] = value
    return merged


@dataclass(frozen=True)
class OptimArtifacts:
    """Immutable bundle of training-step objects built from the optim config.

    This class only carries the values. It does not apply gradient clipping or
    accumulation itself.
    """

    # The constructed torch optimizer.
    optimizer: torch.optim.Optimizer
    # Learning-rate scheduler, or None when no schedule is configured.
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler]
    # Gradient-clipping threshold taken from the config.
    clip_norm: float
    # Gradient-accumulation step count taken from the config.
    grad_accum_steps: int


class WarmupThenScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Linear warmup for ``warmup_steps`` steps, then defer to ``base_sched``.

    While ``last_epoch < warmup_steps``, each param group's LR is
    ``base_lr * (last_epoch + 1) / warmup_steps``. From then on, every call to
    :meth:`step` first re-applies ``base_sched``'s last LR and then steps
    ``base_sched``. The base schedule therefore only starts advancing once
    warmup is over. ``base_sched`` is expected to drive the same optimizer.
    """

    def __init__(self, optimizer, warmup_steps: int, base_sched):
        # Must be set before super().__init__, which performs an initial
        # step() -> get_lr() that reads both attributes.
        self.warmup_steps = int(warmup_steps)
        self.base_sched = base_sched
        super().__init__(optimizer)

    def get_lr(self):
        step = self.last_epoch
        if step < self.warmup_steps:
            frac = float(step + 1) / float(self.warmup_steps)
            return [base * frac for base in self.base_lrs]
        # Re-apply the base schedule's current LR before it is stepped, so
        # chainable schedulers (which read group["lr"]) continue from it.
        return self.base_sched.get_last_lr()

    def step(self, epoch: Optional[int] = None):
        super().step(epoch)
        if self.last_epoch >= self.warmup_steps:
            self.base_sched.step(epoch=None)
            # base_sched.step() just rewrote the param-group LRs; refresh the
            # cached value so get_last_lr() reports what the optimizer uses.
            self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

    def state_dict(self):
        # The default implementation copies every __dict__ entry except the
        # optimizer, which would embed the base scheduler *object* (and its
        # optimizer). Store the base scheduler's own state dict instead.
        state = {
            key: value
            for key, value in self.__dict__.items()
            if key not in ("optimizer", "base_sched")
        }
        state["base_sched"] = self.base_sched.state_dict()
        return state

    def load_state_dict(self, state_dict):
        state = dict(state_dict)
        base_state = state.pop("base_sched", None)
        self.__dict__.update(state)
        if base_state is None:
            return
        if not isinstance(base_state, dict):
            # Older state dicts stored the scheduler object itself. Recover its
            # state instead of adopting an object bound to a stale optimizer.
            base_state = base_state.state_dict()
        self.base_sched.load_state_dict(base_state)


def _make_warmup_then(base_sched, warmup_steps: int):
    """Return the warmup-then-base scheduler class (:class:`WarmupThenScheduler`).

    Both arguments are unused. They are passed to the class constructor
    instead, and the signature is kept for existing callers.
    """
    return WarmupThenScheduler


_MISSING = object()


def _cfg_lookup(cfg: Any, keys: Tuple[str, ...], default: Any) -> Any:
    """Return the first value found on *cfg* for any of *keys*, else *default*.

    Attribute access is tried for every key first, then ``cfg.get(key)`` for
    every key. ``.get`` is only used when *cfg* actually has a callable ``get``,
    so attribute-only config objects work as well as plain dicts.
    """
    for key in keys:
        value = getattr(cfg, key, _MISSING)
        if value is not _MISSING:
            return value
    getter = getattr(cfg, "get", None)
    if callable(getter):
        for key in keys:
            value = getter(key, _MISSING)
            if value is not _MISSING:
                return value
    return default


def create_optimizer_and_scheduler(cfg: Any, params) -> OptimArtifacts:
    """Build the optimizer, LR scheduler and training knobs from ``cfg.optim``.

    Args:
        cfg: Config exposing ``optim`` either as an attribute or a mapping key.
            Each optim setting is read by attribute first, then via ``.get``.
        params: Parameters / parameter groups, forwarded to the torch optimizer.

    Returns:
        OptimArtifacts holding the optimizer, the scheduler (or None),
        ``clip_norm`` and ``grad_accum_steps``.

    Raises:
        NotImplementedError: for unsupported optimizer or scheduler names.
    """
    optim_cfg = cfg.optim if hasattr(cfg, "optim") else cfg["optim"]

    def setting(*keys: str, default: Any) -> Any:
        return _cfg_lookup(optim_cfg, keys, default)

    name = str(setting("optimizer", default="adam")).lower()

    lr = float(setting("learning_rate", "lr", default=1e-3))
    beta1 = float(setting("beta1", default=0.9))
    beta2 = float(setting("beta2", default=0.999))
    eps = float(setting("eps", default=1e-8))
    weight_decay = float(setting("weight_decay", default=0.0))
    momentum = float(setting("momentum", default=0.0))

    clip_norm = float(setting("clip_norm", default=0.0))
    grad_accum_steps = int(setting("grad_accum_steps", default=1))

    if name == "adam":
        # NOTE(review): weight_decay is not forwarded for plain Adam (it is
        # hard-coded to 0.0); "adamw" applies decoupled weight decay. Confirm
        # this is intended before changing it.
        optimizer = torch.optim.Adam(params, lr=lr, betas=(beta1, beta2), eps=eps, weight_decay=0.0)
    elif name == "adamw":
        optimizer = torch.optim.AdamW(params, lr=lr, betas=(beta1, beta2), eps=eps, weight_decay=weight_decay)
    elif name == "sgd":
        optimizer = torch.optim.SGD(params, lr=lr, momentum=momentum, weight_decay=weight_decay)
    elif name in {"lamb", "adagrad", "rmsprop", "muon", "soap", "kron"}:
        raise NotImplementedError(f"Optimizer '{name}' not implemented in torch translation yet.")
    else:
        raise NotImplementedError(name)

    scheduler_name = str(setting("scheduler", default="exponential")).lower()
    warmup_steps = int(setting("warmup_steps", default=0))

    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None
    if scheduler_name in {"exponential", ""}:
        # Clamp to >= 1: StepLR computes last_epoch % step_size, and the
        # per-step exponent below divides by decay_steps.
        decay_steps = max(int(setting("decay_steps", default=10000)), 1)
        decay_rate = float(setting("decay_rate", default=0.9))
        staircase = bool(setting("staircase", default=False))

        if staircase:
            # Multiply the LR by decay_rate once every decay_steps steps.
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=decay_steps, gamma=decay_rate)
        else:
            # Per-step factor chosen so the LR is scaled by decay_rate after decay_steps steps.
            scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=decay_rate ** (1.0 / decay_steps))

    elif scheduler_name == "cosine_decay":
        decay_steps = max(int(setting("decay_steps", default=10000)), 1)

        def cosine_factor(step: int) -> float:
            # Clamp progress so the LR settles at 0 once decay_steps is
            # reached. CosineAnnealingLR keeps following the cosine past
            # T_max, so its LR would climb back up.
            progress = min(step, decay_steps) / decay_steps
            return 0.5 * (1.0 + math.cos(math.pi * progress))

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=cosine_factor)

    elif scheduler_name == "none":
        # NOTE: warmup_steps has no effect when no scheduler is configured.
        scheduler = None
    else:
        raise NotImplementedError(scheduler_name)

    if scheduler is not None and warmup_steps > 0:
        warmup_cls = _make_warmup_then(scheduler, warmup_steps)
        scheduler = warmup_cls(optimizer, warmup_steps=warmup_steps, base_sched=scheduler)

    return OptimArtifacts(optimizer=optimizer, scheduler=scheduler, clip_norm=clip_norm, grad_accum_steps=grad_accum_steps)


class _Entry:
    """Registry record: model class, default kwargs, optional pre-hook, aliases."""

    __slots__ = ("cls", "defaults", "pre_hook", "aliases")

    def __init__(
        self,
        cls: Type[nn.Module],
        defaults: Optional[dict],
        pre_hook: Optional[Callable[[dict, dict], None]],
        aliases: Optional[list[str]] = None,
    ):
        self.cls = cls
        # None becomes an empty dict so the defaults can always be merged into.
        self.defaults = defaults or {}
        # Called by get_model as pre_hook(config, merged_kwargs) before construction.
        self.pre_hook = pre_hook
        self.aliases = [a.lower() for a in (aliases or [])]


# Lower-cased architecture names and aliases -> registry entry.
_ARCH_REGISTRY: Dict[str, _Entry] = {}


def register_arch(
    name: str,
    cls: Optional[Type[nn.Module]] = None,
    *,
    defaults: Optional[dict] = None,
    pre_hook: Optional[Callable[[dict, dict], None]] = None,
    aliases: Optional[list[str]] = None,
):
    """Register a model class under *name* (and optional *aliases*).

    Use it directly (``register_arch("mlp", MLP)``) or as a decorator
    (``@register_arch("mlp")``); the class is returned unchanged either way.
    Keys are lower-cased. Re-registering *name* replaces its previous entry.

    Raises:
        ValueError: if an alias is already registered. All aliases are checked
            before the registry is modified, so a failed call registers nothing.
            Duplicate aliases, and an alias equal to *name*, are ignored.
    """

    def _do_register(c: Type[nn.Module]):
        key = name.lower()
        new_aliases: list[str] = []
        for alias in aliases or []:
            al = alias.lower()
            if al == key or al in new_aliases:
                continue
            if al in _ARCH_REGISTRY:
                raise ValueError(f"Alias '{al}' is already registered")
            new_aliases.append(al)

        entry = _Entry(c, defaults, pre_hook, new_aliases)
        _ARCH_REGISTRY[key] = entry
        for al in new_aliases:
            _ARCH_REGISTRY[al] = entry
        return c

    return _do_register(cls) if cls is not None else _do_register


def get_model(config: dict, *, device: Optional[torch.device] = None) -> nn.Module:
    """Instantiate the registered architecture selected by *config*.

    The registry key is the part of ``config["exp_name"]`` before the first
    ``"-"`` when ``exp_name`` is non-empty, otherwise ``config["model"]``. The
    key is lower-cased in both cases. The entry's defaults are deep-merged with
    ``config["model_config"]``. The optional pre-hook is called with
    ``(config, merged)``, and the merged dict is then passed as keyword
    arguments to the class. The model is moved to *device* when one is given.

    Raises:
        NotImplementedError: if the resolved name is not registered.
    """
    exp_name = config.get("exp_name", "")
    if exp_name:
        key = exp_name.split("-")[0].lower()
    else:
        key = str(config.get("model", "")).lower()

    entry = _ARCH_REGISTRY.get(key)
    if entry is None:
        raise NotImplementedError(key)

    kwargs = deep_merge(entry.defaults, config.get("model_config", {}))
    hook = entry.pre_hook
    if hook is not None:
        hook(config, kwargs)

    model = entry.cls(**kwargs)
    return model if device is None else model.to(device)



# Plain MLP registered with no injected defaults. Its constructor kwargs come
# only from config["model_config"], plus any defaults MLP itself declares.
# Reference values previously kept here as commented-out code:
#   in_dim=2, out_dim=1, hidden_dim=[50, 50, 50, 50], activation="tanh",
#   use_bias=True, fourier_features=None
register_arch("mlp", MLP, defaults={})


