from __future__ import annotations

import copy
from typing import Callable, Dict, Optional, Type, Tuple

import jax
import jax.numpy as jnp
import optax
from flax import jax_utils, linen as nn
from phijax.optimizers.soap import soap



from .base import *
from .state import *


def deep_merge(base: dict, override: Optional[dict]) -> dict:
    """Return a deep copy of *base* with *override* merged on top of it.

    Keys whose values are dicts on both sides are merged recursively; any other
    value from *override* replaces the one in the copy (by reference, not
    copied). *base* itself is never modified. A falsy *override* (``None`` or
    ``{}``) simply yields a deep copy of *base*.
    """
    merged = copy.deepcopy(base)
    if override:
        for key, value in override.items():
            existing = merged.get(key)
            if isinstance(value, dict) and isinstance(existing, dict):
                # Both sides are mappings: merge them instead of replacing.
                merged[key] = deep_merge(existing, value)
            else:
                merged[key] = value
    return merged





def create_lr_schedule(cfg) -> optax.Schedule:
    """Build the learning-rate schedule described by an optimizer config.

    Supported ``cfg.scheduler`` values (case-insensitive):

    * missing / ``None`` / ``"exponential"`` -- ``optax.exponential_decay``
      from ``cfg.learning_rate`` with ``cfg.decay_steps`` and
      ``cfg.decay_rate``; ``cfg.staircase`` is optional (default ``False``).
    * ``"cosine_decay"`` -- ``optax.cosine_decay_schedule`` from
      ``cfg.learning_rate`` down to 0 over ``cfg.decay_steps``.
    * ``"none"`` -- constant ``cfg.learning_rate``.

    If ``cfg.warmup_steps`` is a positive integer, a linear warmup from 0 to
    ``cfg.learning_rate`` is prepended with ``optax.join_schedules``.

    Raises:
        NotImplementedError: for any other scheduler name.
    """
    raw_name = getattr(cfg, "scheduler", None)
    # A missing/None scheduler falls back to exponential decay; the *string*
    # "none" (any case) selects a constant learning rate instead.
    name = "exponential" if raw_name is None else str(raw_name).lower()

    if name == "exponential":
        base = optax.exponential_decay(
            cfg.learning_rate,
            cfg.decay_steps,
            cfg.decay_rate,
            staircase=getattr(cfg, "staircase", False),
        )
    elif name == "cosine_decay":
        base = optax.cosine_decay_schedule(
            cfg.learning_rate,
            cfg.decay_steps,
            alpha=0.0,
        )
    elif name == "none":
        base = optax.constant_schedule(cfg.learning_rate)
    else:
        raise NotImplementedError(raw_name)

    # ``or 0`` makes an explicit ``warmup_steps=None`` mean "no warmup"
    # instead of crashing in ``int(None)``.
    warmup = int(getattr(cfg, "warmup_steps", 0) or 0)
    if warmup <= 0:
        return base

    warmup_schedule = optax.linear_schedule(0.0, cfg.learning_rate, warmup)
    # join_schedules evaluates ``base`` at ``step - warmup`` after the
    # boundary, so the decay starts from its initial value once warmup ends.
    return optax.join_schedules([warmup_schedule, base], [warmup])


def create_optimizer(cfg) -> Tuple[optax.Schedule, optax.GradientTransformation]:
    """Create the learning-rate schedule and gradient transformation for ``cfg``.

    ``cfg.optimizer`` (case-insensitive) selects the base optimizer. Only the
    config fields that the chosen optimizer needs are read. The base
    optimizer is then optionally wrapped:

    * ``cfg.schedule_free`` truthy: ``clip_by_global_norm`` (``cfg.clip_norm``
      if positive, otherwise 1.0) chained before
      ``optax.contrib.schedule_free(tx, schedule, b1=cfg.beta1)``.
    * otherwise, ``cfg.clip_norm > 0``: ``clip_by_global_norm`` chained before
      the base optimizer.
    * ``cfg.grad_accum_steps > 1``: the result is wrapped in
      ``optax.MultiSteps`` for gradient accumulation.

    Returns:
        ``(schedule, tx)``, where ``schedule`` comes from ``create_lr_schedule``.

    Raises:
        NotImplementedError: unknown optimizer name.
    """
    schedule = create_lr_schedule(cfg)
    opt_name = str(cfg.optimizer).lower()

    # Zero-argument factories: building this table reads nothing from ``cfg``.
    # Only the selected entry is ever called.
    # NOTE(review): ``kron`` is not among this module's explicit imports; it
    # presumably arrives via the ``from .base import *`` / ``from .state
    # import *`` lines. TODO confirm, otherwise "kron" raises NameError.
    factories = {
        "adam": lambda: optax.adam(schedule, b1=cfg.beta1, b2=cfg.beta2, eps=cfg.eps),
        "adamw": lambda: optax.adamw(
            schedule,
            b1=cfg.beta1,
            b2=cfg.beta2,
            eps=cfg.eps,
            weight_decay=float(getattr(cfg, "weight_decay", 0.0)),
        ),
        "sgd": lambda: optax.sgd(schedule, momentum=getattr(cfg, "momentum", 0.0)),
        "lamb": lambda: optax.lamb(schedule, b1=cfg.beta1, b2=cfg.beta2, eps=cfg.eps),
        "adagrad": lambda: optax.adagrad(schedule, eps=cfg.eps),
        "rmsprop": lambda: optax.rmsprop(schedule),
        "muon": lambda: optax.contrib.muon(
            learning_rate=schedule,
            ns_coeffs=(2, -1.5, 0.5),
            ns_steps=10,
            beta=0.99,
            adam_b1=0.99,
        ),
        "soap": lambda: soap(
            schedule,
            b1=cfg.beta1,
            b2=cfg.beta2,
            weight_decay=0.0,
            precondition_frequency=2,
        ),
        "kron": lambda: kron(schedule, b1=cfg.beta1),
    }
    factory = factories.get(opt_name)
    if factory is None:
        raise NotImplementedError(opt_name)
    tx = factory()

    clip_norm = float(getattr(cfg, "clip_norm", 0.0))
    if bool(getattr(cfg, "schedule_free", False)):
        # Schedule-free training always gets a global-norm clip; 1.0 is used
        # when no positive clip_norm is configured.
        max_norm = clip_norm if clip_norm > 0 else 1.0
        tx = optax.chain(
            optax.clip_by_global_norm(max_norm),
            optax.contrib.schedule_free(tx, schedule, b1=cfg.beta1),
        )
    elif clip_norm > 0:
        tx = optax.chain(optax.clip_by_global_norm(clip_norm), tx)

    accum_steps = int(getattr(cfg, "grad_accum_steps", 1))
    if accum_steps > 1:
        tx = optax.MultiSteps(tx, every_k_schedule=accum_steps)
    return schedule, tx


class _Entry:
    """One architecture-registry record: module class, default kwargs, hook, aliases."""

    __slots__ = ("cls", "defaults", "pre_hook", "aliases")

    def __init__(
        self,
        cls: Type[nn.Module],
        defaults: Optional[dict],
        pre_hook: Optional[Callable[[dict, dict], None]],
        aliases: Optional[list[str]] = None,
    ):
        # A falsy ``defaults`` (None or {}) is replaced by a fresh empty dict;
        # a non-empty dict is stored as-is (not copied).
        if not defaults:
            defaults = {}
        self.cls = cls
        self.defaults = defaults
        self.pre_hook = pre_hook
        # Aliases are stored lower-cased for case-insensitive lookup.
        self.aliases = [alias.lower() for alias in (aliases or ())]


# Maps lower-cased architecture names *and* their lower-cased aliases to the
# same ``_Entry``. Populated by ``register_arch``; looked up by ``get_model``.
_ARCH_REGISTRY: Dict[str, _Entry] = {}


def register_arch(
    name: str,
    cls: Optional[Type[nn.Module]] = None,
    *,
    defaults: Optional[dict] = None,
    pre_hook: Optional[Callable[[dict, dict], None]] = None,
    aliases: Optional[list[str]] = None,
):
    """Register a module class under ``name`` and optional aliases.

    Usable as a direct call, ``register_arch("mlp", Mlp, ...)``, or as a
    decorator, ``@register_arch("mlp", ...)``. Keys are lower-cased, so
    lookups in ``get_model`` are case-insensitive.

    Args:
        name: primary registry key. Re-registering an existing primary name
            overwrites the previous entry for that key.
        cls: the module class. When omitted, a decorator is returned.
        defaults: default constructor kwargs, deep-merged with
            ``config["model_config"]`` by ``get_model``.
        pre_hook: optional ``hook(config, merged_kwargs)`` called by
            ``get_model`` before instantiation. It returns None, so any
            changes must be made to ``merged_kwargs`` in place.
        aliases: extra names that resolve to the same entry.

    Returns:
        ``cls`` when given; otherwise a decorator that registers and returns
        its argument.

    Raises:
        ValueError: if an alias is already a registry key, equals ``name``,
            or is listed twice. All aliases are validated before anything is
            written, so a failed call leaves the registry unchanged.
    """

    def _do_register(c: Type[nn.Module]):
        key = name.lower()
        alias_keys = [a.lower() for a in aliases or []]

        # Validate every alias up front so that a collision cannot leave a
        # half-registered architecture (primary key written, aliases not).
        taken = {key}
        for alias in alias_keys:
            if alias in _ARCH_REGISTRY or alias in taken:
                raise ValueError(
                    f"Cannot register alias {alias!r} for architecture {key!r}: "
                    "the name is already in use."
                )
            taken.add(alias)

        entry = _Entry(c, defaults, pre_hook, aliases)
        _ARCH_REGISTRY[key] = entry
        for alias in alias_keys:
            _ARCH_REGISTRY[alias] = entry
        return c

    if cls is not None:
        return _do_register(cls)
    return _do_register


def get_model(config: dict, *, replicate: bool = True):
    """Build a model and its initial training state from an experiment config.

    Steps:
      1. Resolve the architecture from ``config["exp_name"]``: the text before
         the first ``-``, lower-cased, looked up in ``_ARCH_REGISTRY``.
      2. Deep-merge the registered defaults with ``config["model_config"]``,
         run the entry's ``pre_hook`` (if any) and instantiate the module.
      3. Initialise variables on a dummy input of shape
         ``(config["model_meta"]["input_dim"],)`` (default 2), using
         ``PRNGKey(config["seed"])`` (default 0).
      4. Create the optimizer from ``config["optim"]`` and build one of:
           * ``RotTrainState``: rotation enabled via ``use_rot`` in the config,
             in the merged model kwargs, or as a truthy module attribute, and
             the optimizer is not SOAP;
           * ``LossTrainState``: ``hetero`` enabled. Adds one learnable
             ``log_sigma`` scalar per weighting key, optimised by a separate
             Adam with learning rate ``optim.loss_lr`` (default 1e-5);
           * ``AllTrainState``: ``config["flag"] == "state_fail"``;
           * ``TrainState``: otherwise.

    Args:
        config: experiment configuration. Both item (``config["optim"]``) and
            attribute (``wcfg.init_weights``) access are used, so presumably an
            ``ml_collections.ConfigDict``. TODO confirm.
        replicate: if True, replicate the state with
            ``flax.jax_utils.replicate``.

    Returns:
        The (optionally replicated) train state.

    Raises:
        NotImplementedError: unknown architecture prefix.
        ValueError: rotation requested but ``model.init`` did not create a
            ``rot_state`` collection.
    """
    model_name = config["exp_name"].split("-")[0].lower()
    if model_name not in _ARCH_REGISTRY:
        raise NotImplementedError(model_name)
    entry = _ARCH_REGISTRY[model_name]

    merged = deep_merge(entry.defaults, config.get("model_config", {}))
    if entry.pre_hook is not None:
        entry.pre_hook(config, merged)
    arch = entry.cls(**merged)

    # Initialise on a single dummy input point of length input_dim.
    mmeta = config.get("model_meta", {})
    input_dim = mmeta.get("input_dim", 2)
    x = jnp.ones((input_dim,))
    key = jax.random.PRNGKey(int(config.get("seed", 0)))
    variables = arch.init(key, x)
    params = variables["params"]

    optim_cfg = config["optim"]
    _, tx = create_optimizer(optim_cfg)

    wcfg = config["weighting"]
    init_weights = dict(wcfg.init_weights)

    use_rot = (
        bool(config.get("use_rot", False))
        or bool(merged.get("use_rot", False))
        or bool(getattr(arch, "use_rot", False))
    )
    use_hetero = bool(config.get("hetero", False)) or bool(merged.get("hetero", False))

    # ``optim`` may be attribute-style or a plain mapping.
    if hasattr(optim_cfg, "optimizer"):
        opt_name = optim_cfg.optimizer
    else:
        opt_name = optim_cfg.get("optimizer", "")
    # Rotation is never combined with the SOAP optimizer.
    if str(opt_name).lower() == "soap":
        use_rot = False

    if use_rot:
        if "rot_state" not in variables:
            raise ValueError(
                "use_rot=True but model.init(...) did not create a 'rot_state' collection. "
                "Did you update RotDense to store rot_state variable 'rot' (RotLayer) ?"
            )
        state = RotTrainState.create(
            apply_fn=arch.apply,
            params=params,
            tx=tx,
            weights=init_weights,
            momentum=float(wcfg.momentum),
            rot_state=variables["rot_state"],
            use_rot=True,
        )
    elif use_hetero:
        # One learnable log-sigma scalar per loss term, initialised to 0.
        loss_params = {
            "log_sigma": {k: jnp.array(0.0, jnp.float32) for k in init_weights}
        }
        # Separate Adam for the loss parameters, sharing the model optimizer's
        # betas/eps. Previously ``loss_lr`` was read into an unused copy of the
        # config and a hard-coded 1e-5 was used; honour it now, keeping 1e-5
        # as the default.
        loss_lr = float(getattr(optim_cfg, "loss_lr", 1e-5))
        loss_tx = optax.adam(
            learning_rate=loss_lr,
            b1=optim_cfg.beta1,
            b2=optim_cfg.beta2,
            eps=optim_cfg.eps,
        )
        state = LossTrainState.create(
            apply_fn=arch.apply,
            params=params,
            tx=tx,
            weights=init_weights,
            momentum=float(wcfg.momentum),
            loss_params=loss_params,
            loss_tx=loss_tx,
        )
    elif config.get("flag", None) == "state_fail":
        state = AllTrainState.create(
            apply_fn=arch.apply,
            params=params,
            tx=tx,
            weights=init_weights,
            momentum=float(wcfg.momentum),
            # Separate copy so weights and loss_ema never share a dict object.
            loss_ema=dict(init_weights),
        )
    else:
        state = TrainState.create(
            apply_fn=arch.apply,
            params=params,
            tx=tx,
            weights=init_weights,
            momentum=float(wcfg.momentum),
        )

    return jax_utils.replicate(state) if replicate else state



# Built-in architectures. Each name or alias can be used as the prefix of
# ``config["exp_name"]`` (text before the first "-") in ``get_model``; the
# ``defaults`` are deep-merged with the user's ``model_config``.
register_arch(
    "mlp",
    Mlp,
    defaults={
        "num_layers": 4,
        "hidden_dim": 64,
        "out_dim": 1,
        "activation": "tanh",
    },
    aliases=["pinn"],
)

register_arch(
    "piratenet",
    PirateNet,
    defaults={
        "num_layers": 3,
        "hidden_dim": 256,
        "out_dim": 1,
        "activation": "tanh",
    },
    aliases=["res"],
)

register_arch(
    "mmlp",
    ModifiedMlp,
    defaults={
        "num_layers": 4,
        "hidden_dim": 256,
        "out_dim": 1,
        "activation": "tanh",
    },
    aliases=["modified_mlp"],
)

register_arch(
    "rot_mlp",
    RotMlp,
    defaults={
        "num_layers": 4,
        "hidden_dim": 64,
        "out_dim": 1,
        "activation": "tanh",
    },
    aliases=["rmlp", "rpinn"],
)

register_arch(
    "fast_mlp",
    FastMlp,
    defaults={
        "num_layers": 4,
        "hidden_dim": 64,
        "out_dim": 1,
        "activation": "tanh",
    },
    # "frmlp" follows the "rmlp"/"frpinn" naming pattern. "frmpl" looks like a
    # transposition of it but is kept so existing experiment names still resolve.
    aliases=["frmpl", "frmlp", "frpinn", "fmlp"],
)

