# adamo.py
from typing import List, Optional, Dict, Tuple, Any

import torch
from torch import Tensor

from .optimizer import Optimizer, _use_grad_for_differentiable
from .adam import adam

# Public API: the optimizer class and its functional counterpart.
__all__ = ["AdamO", "adamo"]


def _flatten_to_2d(w: Tensor) -> Tuple[Tensor, Tuple[int, ...]]:
    """Flatten weight to 2D as (rows, cols) while remembering original shape."""
    orig_shape = tuple(w.shape)
    w2 = w.reshape(w.shape[0], -1)
    return w2, orig_shape


def _get_cached_eye(state: Dict[str, Any], n: int, like: Tensor, key: str) -> Tensor:
    """
    Get (and cache) an identity matrix of size n on the same device/dtype as `like`.

    Notes:
      - Cache eye to avoid repeated allocation.
      - Storing in optimizer state (state[p]) is most convenient.
    """
    eye = state.get(key, None)
    if eye is None or eye.shape != (n, n) or eye.dtype != like.dtype or eye.device != like.device:
        eye = torch.eye(n, dtype=like.dtype, device=like.device)
        state[key] = eye
    return eye


def _orthogonality_grad_from_weight(w: Tensor, state: Dict[str, Any]) -> Tensor:
    """
    Return the gradient of the soft orthogonality penalty R at ``w``.

    Let W be ``w`` viewed as a matrix of shape (m, n) = (w.shape[0], prod(rest)).

      wide   (m <  n): R = 1/4 ||W W^T - I_m||_F^2   ->  dR/dW = (W W^T - I_m) W
      tall/square (m >= n): R = 1/4 ||W^T W - I_n||_F^2  ->  dR/dW = W (W^T W - I_n)

    The identity comes from ``_get_cached_eye`` and is stored in ``state`` under
    "orth_I_row" (wide case) or "orth_I_col" (tall/square case).

    Returns:
        A tensor with the same shape as ``w``.
    """
    mat, shape = _flatten_to_2d(w)
    n_rows, n_cols = mat.shape

    if n_rows < n_cols:
        # Wide matrix: push the rows toward an orthonormal set.
        eye = _get_cached_eye(state, n_rows, mat, key="orth_I_row")
        grad2d = (mat @ mat.t() - eye) @ mat
    else:
        # Tall or square matrix: push the columns toward an orthonormal set.
        eye = _get_cached_eye(state, n_cols, mat, key="orth_I_col")
        grad2d = mat @ (mat.t() @ mat - eye)

    return grad2d.reshape(shape)


def _fro_inner(a: Tensor, b: Tensor) -> Tensor:
    """Frobenius inner product <a,b> = sum(a*b)."""
    return torch.sum(a * b)


def _normalized_orth_step(r: Tensor, u: Tensor, kappa: float, eps_r: float) -> Tensor:
    """
    Build a scale-stable orth step:
        delta0 = kappa * ||u|| * r / (||r|| + eps_r)

    This merges the role of "orth_lambda" and "orth_rho" into one knob kappa,
    making the step magnitude roughly proportional to Adam step magnitude.
    """
    if kappa <= 0.0:
        return torch.zeros_like(r)

    # Norms may be fp32 even if r/u are fp16; keep delta0 dtype consistent with r.
    u_norm = torch.norm(u)
    r_norm = torch.norm(r)

    # Ensure eps is not flushed-to-zero for low-precision dtypes.
    if torch.is_floating_point(r):
        dtype_eps = torch.finfo(r.dtype).eps
        eps_val = max(float(eps_r), float(dtype_eps))
    else:
        eps_val = float(eps_r)

    denom = (r_norm + eps_val)
    factor = (kappa * u_norm / denom)

    # Cast factor to r.dtype to avoid dtype promotion that breaks in-place add on fp16 params.
    factor = factor.to(dtype=r.dtype)
    return r * factor


def _budget_scale(delta0: Tensor, g: Tensor, u: Tensor, tau: float, eps_g: float) -> Tensor:
    """
    Scale-only budget projection.

    We want: <g, delta> >= -tau * <g, u>
    but constrain delta = s * delta0, with s in [0, 1].

    Let target = -tau <g, u>, dot = <g, delta0>.
      - If dot >= target -> s=1
      - Else if target <= 0 and dot < 0 -> s = clip(target/(dot+eps), 0, 1)
      - Else -> s=0 (scaling can't help)

    Notes:
      - tau=0 => target=0 => if dot<0 then s=0 (strict cautious: disable orth when conflicting)
      - This avoids changing direction (unlike adding alpha*g).
    """
    # Compute inner products in fp32 for robustness; return delta in original dtype.
    dot = _fro_inner(g.float(), delta0.float())
    gu = _fro_inner(g.float(), u.float())

    if tau <= 0.0:
        target = torch.zeros_like(dot)
    else:
        target = -tau * gu

    # eps handling (dot is fp32 here)
    dtype_eps = torch.finfo(dot.dtype).eps
    eps_val = max(float(eps_g), float(dtype_eps))

    one = torch.ones_like(dot)
    zero = torch.zeros_like(dot)

    cond_ok = dot >= target
    cond_scale = (~cond_ok) & (target <= 0) & (dot < 0)

    s_raw = (target / (dot + eps_val)).clamp(min=0.0, max=1.0)
    s = torch.where(cond_ok, one, torch.where(cond_scale, s_raw, zero))

    # Cast scale back to delta dtype (avoid dtype promotion issues).
    s = s.to(dtype=delta0.dtype)
    return delta0 * s


def adamo(params: List[Tensor],
           grads: List[Tensor],
           exp_avgs: List[Tensor],
           exp_avg_sqs: List[Tensor],
           max_exp_avg_sqs: List[Tensor],
           state_steps: List[Tensor],
           *,
           amsgrad: bool,
           beta1: float,
           beta2: float,
           lr: float,
           weight_decay: float,
           eps: float,
           maximize: bool,
           foreach: Optional[bool] = None,
           capturable: bool = False,
           differentiable: bool = False,
           fused: Optional[bool] = None,
           grad_scale: Optional[Tensor] = None,
           found_inf: Optional[Tensor] = None,
           # AdamO-specific
           orth_kappa: float = 0.0,
           orth_tau: float = 0.0,
           orth_eps_g: float = 1e-4,
           orth_eps_r: float = 1e-4) -> None:
    """
    Functional AdamO (route B).

      1) run the standard Adam update (exactly identical to `adam`)
      2) apply an orthogonality correction to matrix-like params
         (ndim >= 2, real floating point):

         r      = grad_R(W_old)
         u_hat  = (W_old - W_after_adam) / lr
         delta0 = kappa * ||u_hat|| * r / (||r|| + eps_r)
         delta  = scale_budget(delta0; <g,delta> >= -tau<g,u_hat>)
         W <- W_after_adam - lr * delta

    Step 1 always runs, so moment buffers and ``state_steps`` advance exactly as
    they would with `adam`. Step 2 is skipped in three cases:
      - ``orth_kappa == 0``;
      - ``found_inf`` flags an AMP overflow;
      - ``lr == 0``, since u_hat would divide by zero.

    The gradient g used by the budget is ``grads[i]`` from the list passed in,
    negated when ``maximize`` and with ``weight_decay * W_old`` added, intended
    to match the gradient the Adam update used. ``p.grad`` is not consulted, so
    callers that supply gradients only through ``grads`` still get the
    correction.

    Raises:
        ValueError: orth_tau outside [0, 1) or orth_kappa < 0 (only checked
            when orth_kappa != 0).
        RuntimeError: capturable=True or differentiable=True together with
            orth_kappa != 0.

    Notes:
      - The functional API has no optimizer state to cache into, so the
        identity matrix used by the orth gradient is re-created every step.
      - Prefer the AdamO class below, which caches it in state[p].
    """
    def _baseline_adam() -> None:
        # Plain Adam with exactly the arguments this function received.
        adam(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps,
             foreach=foreach, capturable=capturable, differentiable=differentiable, fused=fused,
             grad_scale=grad_scale, found_inf=found_inf,
             amsgrad=amsgrad, beta1=beta1, beta2=beta2, lr=lr, weight_decay=weight_decay,
             eps=eps, maximize=maximize)

    if orth_kappa == 0.0:
        _baseline_adam()
        return

    if not (0.0 <= orth_tau < 1.0):
        raise ValueError(f"Invalid orth_tau={orth_tau}. Expected in [0,1).")
    if orth_kappa < 0.0:
        raise ValueError(f"Invalid orth_kappa={orth_kappa}. Expected >= 0.")

    if capturable:
        raise RuntimeError("AdamO does not support capturable=True when orth_kappa != 0.")
    if differentiable:
        raise RuntimeError("AdamO does not support differentiable=True when orth_kappa != 0.")

    # AMP overflow -> skip the orth module. lr == 0 -> also skip it (u_hat would
    # divide by zero). In both cases the Adam baseline still runs, so optimizer
    # state advances exactly as with `adam`.
    amp_overflow = found_inf is not None and float(found_inf) != 0.0
    if amp_overflow or lr == 0.0:
        _baseline_adam()
        return

    # Snapshot pre-Adam weights of orth-eligible params so u_hat can be
    # recovered afterwards.
    orth_infos: List[Tuple[int, Tensor]] = [
        (i, p.detach().clone())
        for i, p in enumerate(params)
        if p.ndim >= 2 and torch.is_floating_point(p) and not torch.is_complex(p)
    ]

    _baseline_adam()

    # Orth correction (route B)
    with torch.no_grad():
        for i, p_old in orth_infos:
            p = params[i]

            u_hat = (p_old - p) / lr  # p_after_adam = p_old - lr*u_hat

            # Gradient consistent with the Adam update (sign + L2 weight decay).
            g = grads[i] if not maximize else -grads[i]
            if weight_decay != 0.0:
                g = g.add(p_old, alpha=weight_decay)

            if g.shape != p.shape:
                continue

            # No cache in the functional API.
            r = _orthogonality_grad_from_weight(p_old, state={})

            delta0 = _normalized_orth_step(r, u_hat, kappa=orth_kappa, eps_r=orth_eps_r)
            if torch.count_nonzero(delta0).item() == 0:
                continue

            delta = _budget_scale(delta0, g, u_hat, tau=orth_tau, eps_g=orth_eps_g)

            # Apply orth correction: p <- p - lr*delta
            p.add_(delta, alpha=-lr)


class AdamO(Optimizer):
    """
    AdamO = Adam + orthogonality correction (scale-only budget projection, route B).

    Each step first runs the standard Adam update via `adam`. Then, for every
    real floating-point parameter with ndim >= 2, it adds a correction toward
    row/column orthonormality:
      - the correction is sized relative to the Adam step;
      - it is scaled down so that <g, delta> >= -tau <g, u_hat>.

    Key knobs (recommended minimal set):
      orth_kappa : orth step scaling factor relative to the Adam step
                   (dimensionless, usually more stable). 0 disables the
                   correction and the optimizer performs plain Adam steps.
      orth_tau   : budget coefficient tau.
                   constraint: <g,delta> >= -tau <g,u_hat>
                   tau=0 => strict cautious: shrink/disable orth when conflicting
                   [0,1) is the documented range. Validation accepts values up
                   to ``_ORTH_TAU_MAX``.
      orth_eps_g, orth_eps_r : stabilizers for the budget scale and the
                   orth-step normalization (must be >= 0).

    Notes:
      - The default ``eps`` here is 1e-4.
      - orth_kappa != 0 cannot be combined with capturable=True or
        differentiable=True (raised in ``step``).
      - Identity matrices used by the orth gradient are cached in
        ``self.state[p]`` under "orth_I_row" / "orth_I_col".
    """

    # Upper bound enforced for orth_tau. Values >= 1 are accepted, presumably so
    # the budget can be loosened until it effectively never binds.
    _ORTH_TAU_MAX = 10000.0

    def __init__(self,
                 params,
                 lr: float = 1e-3,
                 betas=(0.9, 0.999),
                 eps: float = 1e-4,
                 weight_decay: float = 0.0,
                 amsgrad: bool = False,
                 *,
                 foreach: Optional[bool] = None,
                 maximize: bool = False,
                 capturable: bool = False,
                 differentiable: bool = False,
                 fused: Optional[bool] = None,
                 # AdamO-specific
                 orth_kappa: float = 0.0,
                 orth_tau: float = 0.0,
                 orth_eps_g: float = 1e-4,
                 orth_eps_r: float = 1e-4):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        if orth_kappa < 0.0:
            raise ValueError(f"Invalid orth_kappa={orth_kappa}. Expected >= 0.")
        if not (0.0 <= orth_tau < self._ORTH_TAU_MAX):
            raise ValueError(
                f"Invalid orth_tau={orth_tau}. Expected in [0,{self._ORTH_TAU_MAX:g}) "
                "([0,1) recommended)."
            )
        if orth_eps_g < 0.0:
            raise ValueError(f"Invalid orth_eps_g={orth_eps_g}. Expected >= 0.")
        if orth_eps_r < 0.0:
            raise ValueError(f"Invalid orth_eps_r={orth_eps_r}. Expected >= 0.")

        defaults = dict(
            lr=lr, betas=betas, eps=eps,
            weight_decay=weight_decay, amsgrad=amsgrad,
            maximize=maximize, foreach=foreach,
            capturable=capturable, differentiable=differentiable, fused=fused,
            # AdamO-specific
            orth_kappa=orth_kappa,
            orth_tau=orth_tau,
            orth_eps_g=orth_eps_g,
            orth_eps_r=orth_eps_r,
        )
        super().__init__(params, defaults)

        if fused:
            if differentiable:
                raise RuntimeError("`fused` does not support `differentiable`")
            self._step_supports_amp_scaling = True
            if not all(
                p.is_cuda and torch.is_floating_point(p)
                for pg in self.param_groups for p in pg["params"]
            ):
                raise RuntimeError("`fused=True` requires all the params to be CUDA, floating point Tensor")
            if foreach:
                raise RuntimeError("`fused` and `foreach` cannot be `True` together.")

    def __setstate__(self, state):
        """Restore state, back-filling group options missing from older checkpoints."""
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsgrad", False)
            group.setdefault("maximize", False)
            group.setdefault("foreach", None)
            group.setdefault("capturable", False)
            group.setdefault("differentiable", False)
            group.setdefault("fused", None)
            # AdamO-specific
            group.setdefault("orth_kappa", 0.0)
            group.setdefault("orth_tau", 0.0)
            group.setdefault("orth_eps_g", 1e-4)
            group.setdefault("orth_eps_r", 1e-4)

        # Convert legacy scalar step counters to tensors (only the first entry is
        # inspected to decide whether conversion is needed).
        state_values = list(self.state.values())
        step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0].get("step", None))
        if not step_is_tensor:
            for s in state_values:
                if "step" in s:
                    s["step"] = torch.tensor(float(s["step"]))

    def _init_group(self,
                    group,
                    params_with_grad,
                    grads,
                    exp_avgs,
                    exp_avg_sqs,
                    max_exp_avg_sqs,
                    state_steps):
        """Collect params with grads and their Adam state, lazily initializing state."""
        for p in group["params"]:
            if p.grad is None:
                continue
            params_with_grad.append(p)
            if p.grad.is_sparse:
                raise RuntimeError("AdamO does not support sparse gradients, please consider SparseAdam instead")
            grads.append(p.grad)

            state = self.state[p]
            if len(state) == 0:
                state["step"] = (
                    torch.zeros((1,), dtype=torch.float, device=p.device)
                    if group["capturable"] or group["fused"]
                    else torch.tensor(0.0)
                )
                state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                if group["amsgrad"]:
                    state["max_exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)

            exp_avgs.append(state["exp_avg"])
            exp_avg_sqs.append(state["exp_avg_sq"])
            if group["amsgrad"]:
                max_exp_avg_sqs.append(state["max_exp_avg_sq"])

            if group["differentiable"] and state["step"].requires_grad:
                raise RuntimeError("`requires_grad` is not supported for `step` in differentiable mode")
            state_steps.append(state["step"])

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): re-evaluates the model and returns the loss.

        Returns:
            The value returned by ``closure``, or None when no closure is given.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            exp_avgs: List[Tensor] = []
            exp_avg_sqs: List[Tensor] = []
            max_exp_avg_sqs: List[Tensor] = []
            state_steps: List[Tensor] = []
            beta1, beta2 = group["betas"]

            self._init_group(group, params_with_grad, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps)

            orth_kappa = float(group.get("orth_kappa", 0.0))
            orth_tau = float(group.get("orth_tau", 0.0))
            orth_eps_g = float(group.get("orth_eps_g", 1e-4))
            orth_eps_r = float(group.get("orth_eps_r", 1e-4))

            if orth_kappa != 0.0:
                # Re-checked per step: group settings may differ from those
                # validated in __init__ (e.g. groups added or edited later).
                if not (0.0 <= orth_tau < self._ORTH_TAU_MAX):
                    raise ValueError(
                        f"Invalid orth_tau={orth_tau}. Expected in [0,{self._ORTH_TAU_MAX:g}) "
                        "([0,1) recommended)."
                    )
                if group.get("capturable", False):
                    raise RuntimeError("AdamO does not support capturable=True when orth_kappa != 0.")
                if group.get("differentiable", False):
                    raise RuntimeError("AdamO does not support differentiable=True when orth_kappa != 0.")

            # Save pre-Adam weights only for orth-eligible params
            # (real floating point, ndim >= 2) so u_hat can be recovered.
            orth_infos: List[Tuple[int, Tensor]] = []
            if orth_kappa != 0.0:
                orth_infos = [
                    (idx, p.detach().clone())
                    for idx, p in enumerate(params_with_grad)
                    if p.grad is not None
                    and p.ndim >= 2
                    and torch.is_floating_point(p)
                    and not torch.is_complex(p)
                ]

            # Baseline Adam step (fully reusing the provided adam.py)
            adam(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                amsgrad=group["amsgrad"],
                beta1=beta1,
                beta2=beta2,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                eps=group["eps"],
                maximize=group["maximize"],
                foreach=group["foreach"],
                capturable=group["capturable"],
                differentiable=group["differentiable"],
                fused=group["fused"],
                grad_scale=getattr(self, "grad_scale", None),
                found_inf=getattr(self, "found_inf", None),
            )

            if orth_kappa == 0.0:
                continue

            # AMP overflow -> skip orth module
            found_inf = getattr(self, "found_inf", None)
            if found_inf is not None and float(found_inf) != 0.0:
                continue

            lr = float(group["lr"])
            if lr == 0.0:
                # u_hat would divide by zero; Adam state has already advanced.
                continue
            wd = float(group["weight_decay"])
            maximize = bool(group["maximize"])

            with torch.no_grad():
                for (idx, p_old) in orth_infos:
                    p = params_with_grad[idx]
                    if p.grad is None:
                        continue

                    # Adam direction u_hat (so p_after_adam = p_old - lr*u_hat)
                    u_hat = (p_old - p) / lr

                    # Gradient consistent with the Adam update (sign + L2 weight decay).
                    g = grads[idx] if not maximize else -grads[idx]
                    if wd != 0.0:
                        g = g.add(p_old, alpha=wd)

                    # Orth gradient at the pre-Adam weight; identity cached in state[p].
                    state = self.state[p]
                    r = _orthogonality_grad_from_weight(p_old, state)

                    delta0 = _normalized_orth_step(r, u_hat, kappa=orth_kappa, eps_r=orth_eps_r)
                    # Fast skip when there is nothing to apply.
                    if torch.count_nonzero(delta0).item() == 0:
                        continue

                    delta = _budget_scale(delta0, g, u_hat, tau=orth_tau, eps_g=orth_eps_g)

                    # Apply orth correction: p <- p - lr*delta
                    p.add_(delta, alpha=-lr)

        return loss
