import copy
import inspect

import torch
from torch.optim import Optimizer

from project_simplex import project_simplex
from waterfill_clip import waterfill_clip


class TrailMix(Optimizer):
    """Meta-optimizer that mixes the steps proposed by several base optimizers.

    Each call to :meth:`step`:

    1. Asks every base optimizer for a proposal Δ_i by running one step on a
       private "shadow" copy of the parameters whose optimizer state is a copy
       of the real base optimizer's state (Δ_i = x_old - x_shadow_new).
    2. Applies the λ-weighted sum of the proposals to the real parameters
       (x <- x - Σ_i λ_i Δ_i).
    3. Advances every real base optimizer's state with an lr=0 step on a
       scaled copy of the gradient, so their state follows the mixed trajectory.
    4. Scores each proposal by its first-order predicted decrease ⟨g, Δ_i⟩
       relative to the mixture (plus an optional curvature bonus) and, every
       ``meta_update_every`` steps, updates λ with an exponentiated
       (mirror-descent) step projected back onto the probability simplex.

    The meta learning rate is read from a dedicated param group tagged
    ``group_type="meta"``, so an LR scheduler that updates param-group lrs
    controls the meta step size.
    """

    def __init__(
        self,
        params,
        base_optimizers,                 # list[torch.optim.Optimizer] - (initialized with model parameters)
        lr_meta: float,
        initial_weighting: list = None,
        base_optimizer_names=None,                  # names for logging

        ### Stability hyperparameters ########################
        base_opt_grad_clipping=None,                # Per-base gradient clipping (applied to shadow grads before each base step).
        meta_update_every: int = 1,                 # update λ every K steps (>=1)
        lambda_tv_cap: float = 1e9,                 # max L1 change per λ update (large = effectively disabled)
        entropy_tau0: float = 0.0,                  # entropy bonus coefficient (annealed to 0)
        entropy_tau_steps: int = 400,               # steps to anneal entropy to 0
        meta_temperature_early: float = 10.0,       # >1 early to soften LR; 1.0 means off
        meta_temperature_switch_step: int = 60,     # when to switch temperature to 1
        lambda_eps_floor: float = 1e-3,             # ε-exploration floor on λ (e.g., 1e-3), 0 disables

        ### Curvature Reward ########################
        curvature_bonus: float = 0.8,        # weight to add curvature reward to meta scores (0 disables)
        curvature_clip_min: float = 1e-6,    # clamp diagonal curvature to avoid division blowups
        curvature_clip_max: float = 1e6,
        curvature_eps: float = 1e-12,
        curvature_alpha_local: float = 0.6,  # blend weight of the instant curvature reward vs. its historical consistency
        curvature_ma_beta: float = 0.8,      # EMA beta for curvature match (higher = slower)
        curvature_kappa_lcb: float = 0.2,    # curvature variance penalty

        ### Base-state synchronisation ########################
        state_sync_lambda_cap: float = 0.25, # max gradient scale fed to each real base's state-only step
    ):
        """Build the mixer.

        Raises:
            ValueError: if the parameter list is empty, or if
                ``initial_weighting`` has the wrong length, contains negative
                or non-finite values, or sums to zero.
        """
        self.base_optimizers = list(base_optimizers)
        self.num_optimizers = len(self.base_optimizers)
        self.base_opt_names = (
            list(base_optimizer_names) if base_optimizer_names and len(base_optimizer_names) == self.num_optimizers
            else [f"opt{i}" for i in range(self.num_optimizers)]
        )

        # Per-base gradient clipping (applied to the shadow grads only)
        self.base_opt_grad_clipping = base_opt_grad_clipping

        # Fairness + meta knobs
        self.meta_update_every = int(max(1, meta_update_every))
        self.lambda_tv_cap = float(lambda_tv_cap)
        self.entropy_tau0 = float(entropy_tau0)
        self.entropy_tau_steps = int(max(1, entropy_tau_steps))
        self.meta_temperature_early = float(meta_temperature_early)
        self.meta_temperature_switch_step = int(max(0, meta_temperature_switch_step))
        self.lambda_eps_floor = float(lambda_eps_floor)
        self.state_sync_lambda_cap = float(state_sync_lambda_cap)

        # Curvature knobs
        self.curvature_bonus = float(curvature_bonus)
        self.curv_clip_min = float(curvature_clip_min)
        self.curv_clip_max = float(curvature_clip_max)
        self.curv_eps = float(curvature_eps)
        self._curv_alpha_local = float(curvature_alpha_local)
        self._curv_ma_beta = float(curvature_ma_beta)
        self._curv_kappa_lcb = float(curvature_kappa_lcb)

        # Optimizer shell + meta group
        super().__init__(params, defaults={'lr': float(lr_meta)})
        if len(self.param_groups) == 0 or len(self.param_groups[0]['params']) == 0:
            raise ValueError("TrailMix received an empty parameter list.")

        param_device = self.param_groups[0]['params'][0].device

        # λ weights (exposed for logging); always a point on the probability simplex.
        # `is not None` + len() instead of truthiness so tensors/arrays are accepted too.
        if initial_weighting is not None and len(initial_weighting) > 0:
            if len(initial_weighting) != self.num_optimizers:
                raise ValueError("Bad initial weightings")
            lambda_init = torch.as_tensor(initial_weighting, dtype=torch.get_default_dtype(), device=param_device)
            total = lambda_init.sum()
            if (not bool(torch.isfinite(lambda_init).all())
                    or bool((lambda_init < 0).any())
                    or float(total) <= 0.0):
                raise ValueError("Bad initial weightings: need finite, non-negative values with a positive sum")
            lambda_init = lambda_init / total
        else:
            lambda_init = torch.ones(self.num_optimizers, device=param_device) / self.num_optimizers
        self.lambdas = torch.nn.Parameter(lambda_init, requires_grad=False)

        # Meta group (scheduler targets this group's lr)
        self.add_param_group({"params": [torch.tensor(0.0, device=param_device)],
                              "lr": float(lr_meta), "group_type": "meta"})
        self.meta_group_idx = len(self.param_groups) - 1

        # Meta step counter
        self.meta_step = 0

        # Buffers for Δ^t
        self._prev_updates = None
        self._prev_updates_normed = None

        # Persistent shadows
        self._shadow_params = None
        self._shadow_opts = None

        # Logging
        self.last_base_updates = {}
        self.last_base_updates_normed = {}

        # Cache for curvature (g_{t-1}, final_update_{t-1})
        self._prev_grad_flat = None
        self._prev_final_update = None

        # Per-base curvature reputation buffers (EMA mean & mean-square) and the
        # number of EMA updates so far (used for bias correction).
        self._curv_ma = torch.zeros(self.num_optimizers, device=param_device)
        self._curv_msq = torch.zeros_like(self._curv_ma)
        self._curv_ema_steps = 0

        # Build persistent shadows now
        self._ensure_shadows_ready()

        # score EMA buffer
        self._h_ema = torch.zeros_like(self.lambdas.data)

    def _current_lr_meta(self) -> float:
        """Return the meta learning rate stored in the meta param group."""
        return float(self.param_groups[self.meta_group_idx]["lr"])

    @torch.no_grad()
    def _collect_model_params(self):
        """Return every parameter of every non-meta param group, in group order."""
        return [p for g in self.param_groups if g.get("group_type") != "meta" for p in g['params']]

    @torch.no_grad()
    def _flat_all_grads_with_zeros(self, params):
        """Concatenate the flattened grads of `params`; a missing grad contributes zeros."""
        flats = []
        for p in params:
            flats.append(torch.zeros_like(p).reshape(-1) if p.grad is None else p.grad.detach().reshape(-1))
        return torch.cat(flats)

    @torch.no_grad()
    def _flat_params(self, params):
        """Concatenate the flattened values of `params` (empty tensor for no params)."""
        return torch.cat([p.detach().reshape(-1) for p in params]) if params else torch.tensor([])

    def _filtered_cfg_from(self, opt):
        """Hyperparameters of `opt`'s first param group that its constructor accepts by keyword."""
        group = {k: v for k, v in opt.param_groups[0].items() if k not in ('params', 'initial_lr')}
        sig = inspect.signature(type(opt).__init__)
        accepted = {name for name, pp in sig.parameters.items()
                    if name != 'self' and pp.kind in (pp.POSITIONAL_OR_KEYWORD, pp.KEYWORD_ONLY)}
        return {k: v for k, v in group.items() if k in accepted}

    @staticmethod
    def _copy_base_state(base):
        """Return an independent (deep) copy of `base.state_dict()`.

        Since torch 2.1, ``Optimizer.load_state_dict`` only shallow-copies, and
        state tensors that already have the right dtype/device are reused as-is.
        Loading the real state dict directly would make the shadow optimizer
        mutate the real base optimizer's moments and step counters.
        """
        return copy.deepcopy(base.state_dict())

    @torch.no_grad()
    def _ensure_shadows_ready(self):
        """Create or refresh shadow params/opts to match current model layout/device."""
        model_params = self._collect_model_params()
        if self._shadow_params is None:
            self._shadow_params = []
            self._shadow_opts = []
            for base in self.base_optimizers:
                sp = [p.detach().clone().requires_grad_(True) for p in model_params]
                cfg = self._filtered_cfg_from(base)
                try:
                    sopt = type(base)(sp, **cfg)
                except TypeError:
                    sopt = type(base)(sp)
                # initialize shadow state from a copy of the real state
                sopt.load_state_dict(self._copy_base_state(base))
                self._shadow_params.append(sp)
                self._shadow_opts.append(sopt)
        else:
            need_rebuild = len(self._shadow_params[0]) != len(model_params) or any(
                p.shape != sp.shape or p.device != sp.device or p.dtype != sp.dtype
                for p, sp in zip(model_params, self._shadow_params[0])
            )
            if need_rebuild:
                self._shadow_params = None
                self._shadow_opts = None
                self._ensure_shadows_ready()
                return

        # Sync shadow values/hyperparams (state is synced per-step before proposing)
        for base, sp_list, sopt in zip(self.base_optimizers, self._shadow_params, self._shadow_opts):
            for p, sp in zip(model_params, sp_list):
                sp.data.copy_(p.data)
            for g_base, g_shadow in zip(base.param_groups, sopt.param_groups):
                for k, v in g_base.items():
                    if k != "params" and k in g_shadow:
                        g_shadow[k] = v

    @torch.no_grad()
    def _state_only_step_with_scaled_grads(self, base_opt, lam_i: float, model_params):
        """
        Advance optimizer state as if it saw gradient lam_i * g_t, without moving params:
          - Scale current grads by lam_i
          - Temporarily set lr=0
          - base_opt.step() to update moments/step counters
          - Restore lrs and grads (also if step() raises)

        NOTE(review): relies on the base optimizer not moving params when
        lr == 0; holds for lr-scaled updates such as SGD/Adam — confirm for others.
        """
        missing = object()
        saved_grads = [None if p.grad is None else p.grad.detach().clone() for p in model_params]
        saved_lrs = [g.get('lr', missing) for g in base_opt.param_groups]
        try:
            # Scale grads by λ_i
            for p in model_params:
                if p.grad is not None:
                    p.grad.mul_(float(lam_i))

            # State-only step via lr=0 trick
            for g in base_opt.param_groups:
                g['lr'] = 0.0
            base_opt.step()
        finally:
            # Restore lrs/grads exactly as they were
            for g, lr in zip(base_opt.param_groups, saved_lrs):
                if lr is missing:
                    g.pop('lr', None)
                else:
                    g['lr'] = lr
            for p, g in zip(model_params, saved_grads):
                if p.grad is not None and g is not None:
                    p.grad.copy_(g)

    @torch.no_grad()
    def _compute_updates_matrix(self, grad_flat_t, model_params):
        """
        Return:
          updates_raw:  [K, D] raw proposed steps Δ_i = x_old - x_shadow_new
          updates_norm: [K, D] steps used for mixing; currently no fairness
                        normalization is applied, so this is `updates_raw` itself
          preds_raw:    [K]    raw ⟨g, Δ_i⟩ for logging
          norms_raw:    [K]    raw ‖Δ_i‖ for logging
        """
        self._ensure_shadows_ready()

        # grads aligned with params
        grads_list = [None if p.grad is None else p.grad.detach().clone() for p in model_params]

        K = self.num_optimizers
        device, dtype = grad_flat_t.device, grad_flat_t.dtype

        updates_raw = torch.zeros((K, grad_flat_t.numel()), device=device, dtype=dtype)
        preds_raw = torch.zeros(K, device=device, dtype=dtype)
        norms_raw = torch.zeros_like(preds_raw)

        old_flat = self._flat_params(model_params)

        # Proposals per base using persistent shadows
        for i, (name, sp_list, sopt, base) in enumerate(
            zip(self.base_opt_names, self._shadow_params, self._shadow_opts, self.base_optimizers)
        ):
            # keep shadow *state* aligned with (a copy of) the REAL base before proposing
            sopt.load_state_dict(self._copy_base_state(base))

            # assign grads (respect None); each shadow gets its own copy because
            # sanitizing/clipping below mutates them in place
            for sp, gp in zip(sp_list, grads_list):
                sp.grad = None if gp is None else gp.detach().clone()

            # sanitize grads before clipping
            plist = [p for g in sopt.param_groups for p in g['params'] if p.grad is not None]
            for tp in plist:
                if not torch.isfinite(tp.grad).all():
                    tp.grad.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)

            # optional per-base grad clip
            if self.base_opt_grad_clipping is not None:
                torch.nn.utils.clip_grad_norm_(plist, max_norm=float(self.base_opt_grad_clipping),
                                               norm_type=2.0, error_if_nonfinite=False)

            # one shadow step -> proposal
            sopt.step()

            delta_flat = old_flat - self._flat_params(sp_list)
            updates_raw[i] = delta_flat
            preds_raw[i] = torch.dot(grad_flat_t, delta_flat)
            norms_raw[i] = delta_flat.norm()
            self.last_base_updates[name] = delta_flat.detach().clone()

        # No fairness normalization yet: the mixing steps are the raw steps, so the
        # "normed" log shares the same tensors as the raw log (no extra memory).
        updates_norm = updates_raw
        self.last_base_updates_normed = dict(self.last_base_updates)

        return updates_raw, updates_norm, preds_raw, norms_raw

    # Interfaces for logging update histories.
    def get_last_base_updates(self):
        """Return a shallow copy of the last raw proposals, keyed by base name."""
        return dict(self.last_base_updates)

    def get_last_base_updates_normed(self):
        """Return a shallow copy of the last mixing proposals, keyed by base name."""
        return dict(self.last_base_updates_normed)

    # Curvature calculation
    @torch.no_grad()
    def _curvature_reward_instant(self, updates_norm: torch.Tensor, grad_flat_t: torch.Tensor) -> torch.Tensor:
        """Per-base cosine similarity with a diagonal secant-Newton step, scaled by secant reliability.

        Uses s = x_t - x_{t-1} and y = g_t - g_{t-1} from the previous step. It
        returns zeros when there is no history, when y·s <= 0, or when the
        Newton step vanishes.
        """
        def _zeros():
            return torch.zeros(self.num_optimizers, device=grad_flat_t.device, dtype=grad_flat_t.dtype)

        if self._prev_grad_flat is None or self._prev_final_update is None:
            return _zeros()

        # step() applies x <- x - final_update, so the previous move is -final_update.
        s = -self._prev_final_update.to(grad_flat_t)
        y = grad_flat_t - self._prev_grad_flat.to(grad_flat_t)

        ys = torch.dot(y, s)
        if float(ys) <= 0.0:
            return _zeros()

        # cosine between y and s, in [0, 1]: how "convex-like" the last secant pair was
        rel = (ys / ((y.norm() * s.norm()) + self.curv_eps)).clamp_min(0.0)

        diag_h = (y.abs() / s.abs().clamp_min(self.curv_eps)).clamp_(self.curv_clip_min, self.curv_clip_max)
        # Diagonal Newton step in the same "subtract from params" convention as the
        # rows of updates_norm (x_new = x - Δ): Δ_N = g / diag(H).
        newton_step = grad_flat_t / diag_h

        newton_norm = newton_step.norm()
        if float(newton_norm) == 0.0:
            return _zeros()

        row_norms = updates_norm.norm(dim=1) + self.curv_eps
        cos_sim = torch.mv(updates_norm, newton_step) / (row_norms * (newton_norm + self.curv_eps))
        return rel * cos_sim

    @torch.no_grad()
    def _curvature_combined(self, updates_norm: torch.Tensor, grad_flat_t: torch.Tensor) -> torch.Tensor:
        """Blend the instant curvature reward with a bias-corrected, variance-penalized EMA of it."""
        curv_local = self._curvature_reward_instant(updates_norm, grad_flat_t)
        beta = self._curv_ma_beta
        one_m_beta = 1.0 - beta
        self._curv_ma.mul_(beta).add_(curv_local, alpha=one_m_beta)
        self._curv_msq.mul_(beta).addcmul_(curv_local, curv_local, value=one_m_beta)

        # Bias correction uses the number of EMA updates performed so far.
        self._curv_ema_steps += 1
        debias = 1.0 - (beta ** self._curv_ema_steps)
        ma_hat = self._curv_ma / (debias + 1e-12)
        msq_hat = self._curv_msq / (debias + 1e-12)
        std_hat = (msq_hat - ma_hat * ma_hat).clamp_min(0.0).sqrt()
        consistency = ma_hat - self._curv_kappa_lcb * std_hat
        alpha = self._curv_alpha_local
        return alpha * curv_local + (1.0 - alpha) * consistency

    @torch.no_grad()
    def _apply_flat_update(self, model_params, flat_update):
        """Subtract consecutive slices of `flat_update` from each param in order."""
        offset = 0
        for p in model_params:
            n = p.numel()
            p.add_(flat_update[offset:offset + n].view_as(p), alpha=-1.0)
            offset += n

    @torch.no_grad()
    def _meta_scores(self, lam, updates_norm, grad_flat_t):
        """Return the EMA-smoothed, clamped advantage score h (one entry per base)."""
        # Advantage meta-scores on the mixing steps (first-order predicted decrease)
        pred = torch.mv(updates_norm, grad_flat_t)
        h = pred - torch.dot(lam, pred)

        # Mild norm tempering
        norms = updates_norm.norm(dim=1) + 1e-12
        temper = (norms / (torch.median(norms) + 1e-12)).clamp_(0.5, 2.0)
        h = h * temper

        # Curvature reward
        if self.curvature_bonus != 0.0:
            curv_combined = self._curvature_combined(updates_norm, grad_flat_t)
            h = h + torch.norm(h, 2) * self.curvature_bonus * curv_combined

        # EMA smoothing on h; the clamp is in place, so the EMA state itself stays
        # bounded and the returned tensor aliases it.
        self._h_ema.mul_(0.9).add_(h, alpha=0.1)
        return self._h_ema.clamp_(-3.0, 3.0)

    @torch.no_grad()
    def _meta_update_lambdas(self, h):
        """One mirror-descent step on λ (temperature, entropy bonus, TV cap, ε-floor)."""
        # Early meta update temperature
        lr_t = self._current_lr_meta()
        T = self.meta_temperature_early if self.meta_step < self.meta_temperature_switch_step else 1.0

        # Entropy bonus: d/dλ [tau * H(λ)] with H(λ) = -Σ λ log λ is -tau * (log λ + 1).
        # h is maximized (λ grows with h), so adding this term favours spreading λ out.
        tau = self.entropy_tau0 * max(0.0, 1.0 - self.meta_step / float(self.entropy_tau_steps))
        h_eff = h
        if tau > 0.0:
            h_eff = h - tau * (torch.log(self.lambdas.data.clamp_min(1e-12)) + 1.0)

        # Exponentiated-gradient step followed by projection onto the simplex
        delta = ((lr_t / T) * h_eff).clamp_(min=-6.0, max=6.0)
        unconstrained = self.lambdas.data * torch.exp(delta)
        unconstrained = torch.nan_to_num(unconstrained, nan=0.0, posinf=1e6, neginf=0.0)
        new_l = project_simplex(unconstrained)

        # Total-variation cap on the change of λ
        dlam = new_l - self.lambdas.data
        delta_l1 = float(dlam.abs().sum().item())
        if delta_l1 > self.lambda_tv_cap:
            new_l = self.lambdas.data + dlam * (self.lambda_tv_cap / (delta_l1 + 1e-12))

        # ε-exploration floor (mix with the uniform distribution)
        if self.lambda_eps_floor > 0.0:
            m = float(self.num_optimizers)
            new_l = (1.0 - self.lambda_eps_floor) * new_l + self.lambda_eps_floor / m

        self.lambdas.data = project_simplex(new_l)

    # Main step
    @torch.no_grad()
    def step(self, closure=None):
        """Perform one TrailMix step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss (standard ``torch.optim`` convention). It is run with
                gradients enabled before anything else; the gradients it leaves
                on the parameters are the ones this step consumes.

        Returns:
            The value returned by ``closure`` (``None`` when no closure is given).
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        model_params = self._collect_model_params()
        if not model_params:
            return loss

        # g_t at x_t
        grad_flat_t = self._flat_all_grads_with_zeros(model_params)

        # Build Δ^{t} (raw) + updates used for mixing
        updates_raw, updates_norm, _preds_raw, _norms_raw = self._compute_updates_matrix(
            grad_flat_t, model_params
        )

        # Apply mixed update with current feasible λ_t
        lam = self.lambdas.data.to(updates_norm.device, dtype=updates_norm.dtype)
        final_update = torch.matmul(lam, updates_norm)
        self._apply_flat_update(model_params, final_update)

        # Keep real base moments/step counters aligned with the mixed trajectory.
        cap = self.state_sync_lambda_cap
        for lam_i, base in zip(lam, self.base_optimizers):
            self._state_only_step_with_scaled_grads(base, min(cap, float(lam_i)), model_params)

        h = self._meta_scores(lam, updates_norm, grad_flat_t)

        # Meta-update λ (cadence/regularizers)
        self.meta_step += 1
        if (self.meta_step % self.meta_update_every) == 0:
            self._meta_update_lambdas(h)

        # Store for logging/inspection parity
        self._prev_updates = updates_raw.detach()
        self._prev_updates_normed = updates_norm.detach()

        # Cache for next-step curvature estimate
        self._prev_grad_flat = grad_flat_t.detach().clone()
        self._prev_final_update = final_update.detach().clone()

        return loss
