from __future__ import annotations

from typing import Dict, Any, Optional, Tuple

import jax
import jax.numpy as jnp

from flax import struct, linen as nn
from flax.training import train_state
from jax.flatten_util import ravel_pytree
import optax




@struct.dataclass
class RotLayer:
    """Per-layer state for the fast rotated-basis update.

    The layer weight is represented as ``W = QL @ S @ QR.T``; ``Lb`` and ``Rb``
    are running (EMA) second moments of the gradient w.r.t. ``S``, expressed in
    the S-basis.
    """

    # Left basis, shape (out, out).
    QL: jnp.ndarray
    # Right basis, shape (in, in).
    QR: jnp.ndarray
    # EMA of grad_S @ grad_S.T in the S-basis, shape (out, out).
    Lb: jnp.ndarray
    # EMA of grad_S.T @ grad_S in the S-basis, shape (in, in).
    Rb: jnp.ndarray


def _dict_to_vec(d: Dict[str, Any], keys: tuple[str, ...]) -> jnp.ndarray:
    return jnp.stack([d[k] for k in keys])


def _vec_to_dict(v: jnp.ndarray, keys: tuple[str, ...]) -> Dict[str, Any]:
    return {k: v[i] for i, k in enumerate(keys)}


@struct.dataclass
class TrainState(train_state.TrainState):
    """Flax ``TrainState`` extended with a pytree of auxiliary ``weights``.

    ``weights`` can be smoothed toward new values with an exponential moving
    average (``ema_update_weights``).  ``variables`` and ``update_bases`` are
    hooks that subclasses in this module override.
    """

    weights: Dict[str, Any]
    momentum: float = 0.9
    st_params: Dict[str, Any] = struct.field(default_factory=dict)

    def variables(self) -> Dict[str, Any]:
        """Return a variables dict containing only the model params."""
        return dict(params=self.params)

    def update_bases(self, grads, **kwargs) -> "TrainState":
        """Hook for subclasses that maintain extra bases; a no-op here."""
        del grads, kwargs  # unused by the base implementation
        return self

    def ema_update_weights(self, weights: Dict[str, Any], stop_grad: bool = True) -> "TrainState":
        """Return a copy whose ``weights`` are blended toward ``weights``.

        Each leaf becomes ``old * momentum + (1 - momentum) * new``; with
        ``stop_grad`` the result is wrapped in ``jax.lax.stop_gradient``.
        """
        beta = self.momentum

        def blend(old, new):
            return old * beta + (1.0 - beta) * new

        mixed = jax.tree_util.tree_map(blend, self.weights, weights)
        return self.replace(weights=jax.lax.stop_gradient(mixed) if stop_grad else mixed)

@struct.dataclass
class AllTrainState(TrainState):
    """TrainState carrying both the loss-EMA fields and the task-alignment fields.

    It declares the same fields as ``DROTrainState`` (``loss_ema``,
    ``loss_ema_beta``) and ``AlignTrainState`` (``align_*``) but defines no
    methods of its own.
    """

    loss_ema: Dict[str, Any] = struct.field(default_factory=dict)
    loss_ema_beta: float = 0.99
    align_freq: int = 10
    align_eps: float = 1e-12
    align_clip_min: float = 1e-20
    # Task-mixing coefficients, or None until set.  It must carry an annotation
    # to be a dataclass field: unannotated, it is a mere class attribute and
    # both the constructor and ``replace(align_alpha=...)`` reject it.  Kept
    # last so it does not shift the positional order of the other fields.
    align_alpha: Optional[jnp.ndarray] = None

@struct.dataclass
class DROTrainState(TrainState):
    """TrainState with Group-DRO-style adversarial per-group weights.

    ``weights`` (inherited) maps each group name to its mixing weight.
    ``loss_ema`` maps each group name to an exponential moving average of that
    group's loss; it starts empty and is seeded on the first call to
    ``ema_update_losses``.
    """

    loss_ema: Dict[str, Any] = struct.field(default_factory=dict)
    loss_ema_beta: float = 0.99

    def ema_update_losses(
        self,
        grouped_losses: Dict[str, Any],
        beta: Optional[float] = None,
        stop_grad: bool = True,
    ) -> "TrainState":
        """Blend ``grouped_losses`` into the per-group loss EMA.

        Args:
            grouped_losses: mapping group name -> loss value (each value must
                expose ``.dtype``; presumably a 0-d array — confirm at caller).
            beta: EMA decay; defaults to ``self.loss_ema_beta``.
            stop_grad: wrap the new EMA in ``jax.lax.stop_gradient``.

        Returns:
            A copy of the state with ``loss_ema`` replaced.  When no EMA exists
            yet (empty or None), each group's EMA is seeded at 1.0 before the
            blend ``beta * ema + (1 - beta) * loss``.
        """
        if beta is None:
            beta = self.loss_ema_beta

        # Whether the EMA exists is part of the pytree *structure* (empty dict
        # vs. populated dict), so it is static even under jit and must be a
        # Python branch: ``lax.cond`` traces both branches and rejects outputs
        # with different structures, and ``len(None)`` would raise before an
        # ``is None`` test combined with ``|`` could help.
        if self.loss_ema:
            loss_ema = self.loss_ema
        else:
            loss_ema = {
                k: jnp.array(1.0, dtype=grouped_losses[k].dtype)
                for k in grouped_losses
            }

        new_ema = jax.tree_util.tree_map(
            lambda e, x: beta * e + (1.0 - beta) * x,
            loss_ema,
            grouped_losses,
        )
        if stop_grad:
            new_ema = jax.lax.stop_gradient(new_ema)
        return self.replace(loss_ema=new_ema)

    def normalized_group_losses(self, grouped_losses: Dict[str, Any], eps: float = 1e-12) -> Dict[str, Any]:
        """Divide each group's loss by its EMA (plus ``eps``).

        Returns ``grouped_losses`` unchanged while no EMA has been recorded.
        """
        if not self.loss_ema:
            return grouped_losses
        return {k: grouped_losses[k] / (self.loss_ema[k] + eps) for k in grouped_losses}

    def groupdro_update_weights(
        self,
        grouped_losses: Dict[str, Any],
        eta: float = 0.1,
        eps: float = 1e-12,
        use_loss_ema_norm: bool = True,
        smooth_weights_ema: bool = False,
    ) -> "TrainState":
        """Exponentiated-gradient ascent step on the per-group weights.

        New weights are ``softmax(log(w + eps) + eta * loss)``, i.e.
        ``q_k ∝ (w_k + eps) * exp(eta * loss_k)``, with losses optionally
        normalized by their EMA first.  Losses are detached, so no gradient
        flows into the weight update.

        Args:
            grouped_losses: mapping group name -> loss; its keys (sorted) must
                also be keys of ``self.weights``.
            eta: step size of the ascent.
            eps: floor inside the log and in the EMA normalization.
            use_loss_ema_norm: divide losses by their EMA before the step.
            smooth_weights_ema: blend the new weights into the old ones with
                ``ema_update_weights`` instead of replacing them.
        """
        losses_for_adv = (
            self.normalized_group_losses(grouped_losses, eps=eps)
            if use_loss_ema_norm
            else grouped_losses
        )
        # Sorted keys give a deterministic vector layout.
        keys = tuple(sorted(losses_for_adv.keys()))

        loss_vec = jax.lax.stop_gradient(_dict_to_vec(losses_for_adv, keys))

        w_vec = _dict_to_vec(self.weights, keys)
        logw = jnp.log(w_vec + eps) + eta * loss_vec
        q = jax.nn.softmax(logw)

        new_w = _vec_to_dict(q, keys)

        if smooth_weights_ema:
            return self.ema_update_weights(new_w, stop_grad=True)
        return self.replace(weights=new_w)


@struct.dataclass
class AlignTrainState(TrainState):
    """TrainState that merges per-task gradients via task-space alignment.

    Every ``align_freq`` steps the task-mixing coefficients are recomputed by
    ``_task_space_alignment``; on other steps the coefficients stored in
    ``align_alpha`` are reused.
    """

    align_freq: int = 10
    align_eps: float = 1e-12
    align_clip_min: float = 1e-20
    # Last task-mixing coefficients (one per task), or None before the first
    # aggregation.  It must carry an annotation to be a dataclass field, since
    # ``replace(align_alpha=...)`` only accepts fields.  Kept last so it does
    # not shift the positional order of the other fields.  NOTE: going from
    # None to an array changes the pytree structure, so a jitted step retraces
    # once; initialize it to ``ones(T) / T`` if a fixed structure is required.
    align_alpha: Optional[jnp.ndarray] = None

    def set_align_weights(self, keys: Tuple[str, ...], alpha: jnp.ndarray) -> "TrainState":
        """Store ``alpha`` both as ``align_alpha`` and as ``weights`` keyed by ``keys``.

        Both copies are detached with ``stop_gradient``.
        """
        d = {k: alpha[i] for i, k in enumerate(keys)}
        d = jax.lax.stop_gradient(d)
        return self.replace(weights=d, align_alpha=jax.lax.stop_gradient(alpha))

    def align_and_aggregate_grads(self, grads_by_task: Dict[str, Any], keys: Tuple[str, ...]) -> Tuple["TrainState", Any]:
        """Combine per-task gradients into one gradient pytree.

        Args:
            grads_by_task: mapping task name -> gradient pytree.  All pytrees
                are assumed to share one structure: they are flattened with
                ``ravel_pytree`` and the result is unflattened with the first
                task's ``unravel``.
            keys: task names; fixes the column order of the gradient matrix.

        Returns:
            ``(state, grads)``: ``state.align_alpha`` holds the mixing
            coefficients used this step and ``grads`` the combined gradient.
        """
        G, unravel = _grads_dict_to_matrix(grads_by_task, keys)
        num_tasks = G.shape[1]
        w = jnp.ones((num_tasks,), dtype=G.dtype) / num_tasks

        # ``align_alpha is None`` is a static (pytree-structure) fact, so decide
        # it in Python: ``lax.cond`` would trace both branches and fail because
        # one returns an array and the other returns None.  The cast keeps the
        # dtype equal to ``compute_new``'s output, as ``lax.cond`` requires.
        if self.align_alpha is None:
            alpha_prev = w
        else:
            alpha_prev = jnp.asarray(self.align_alpha, dtype=G.dtype)

        do_align = (self.step % self.align_freq) == 0

        def compute_new(_):
            return _task_space_alignment(G, w, self.align_eps, self.align_clip_min)

        def reuse_old(_):
            return G @ alpha_prev, alpha_prev

        g_hat_flat, alpha = jax.lax.cond(do_align, compute_new, reuse_old, operand=None)
        grads = unravel(g_hat_flat)
        state = self.replace(align_alpha=jax.lax.stop_gradient(alpha))
        return state, grads
    

    
@struct.dataclass
class LossTrainState(TrainState):
    """TrainState with a second, separately optimized parameter set.

    ``loss_params`` are stepped by their own optax transformation ``loss_tx``,
    whose state lives in ``loss_opt_state``; ``apply_gradients`` can update the
    model params, the loss params, or both in one call.
    """

    loss_params: Any = None
    loss_tx: optax.GradientTransformation = struct.field(pytree_node=False, default=None)
    loss_opt_state: optax.OptState = None

    def variables(self) -> Dict[str, Any]:
        """Return the model params and the loss params in one dict."""
        return dict(params=self.params, loss_params=self.loss_params)

    @classmethod
    def create(
        cls,
        *,
        apply_fn,
        params,
        tx,
        weights,
        loss_params: Dict[str, jnp.ndarray],
        loss_tx: optax.GradientTransformation,
        st_params=None,
        **kwargs,
    ) -> "LossTrainState":
        """Build a state at step 0, initializing both optimizer states.

        ``opt_state`` comes from ``tx.init(params)`` and ``loss_opt_state`` from
        ``loss_tx.init(loss_params)``; extra keyword arguments are forwarded to
        the constructor.
        """
        model_opt_state = tx.init(params)
        aux_opt_state = loss_tx.init(loss_params)
        if st_params is None:
            st_params = {}

        return cls(
            step=0,
            apply_fn=apply_fn,
            params=params,
            tx=tx,
            opt_state=model_opt_state,
            weights=weights,
            st_params=st_params,
            loss_params=loss_params,
            loss_tx=loss_tx,
            loss_opt_state=aux_opt_state,
            **kwargs,
        )

    def apply_gradients(
        self,
        *,
        grads=None,
        loss_grads=None,
        **kwargs,
    ) -> "LossTrainState":
        """Increment ``step`` and apply optimizer updates where gradients are given.

        ``grads`` (if not None) updates ``params`` via ``tx``; ``loss_grads``
        (if not None) updates ``loss_params`` via ``loss_tx``.  Remaining
        keyword arguments are passed through to ``replace``.

        Raises:
            ValueError: ``loss_grads`` is given but ``loss_tx`` is None.
        """

        def optimizer_step(transform, gradients, opt_state, current):
            updates, next_opt_state = transform.update(gradients, opt_state, current)
            return optax.apply_updates(current, updates), next_opt_state

        if grads is None:
            new_params, new_opt_state = self.params, self.opt_state
        else:
            new_params, new_opt_state = optimizer_step(
                self.tx, grads, self.opt_state, self.params
            )

        if loss_grads is None:
            new_loss_params, new_loss_opt_state = self.loss_params, self.loss_opt_state
        elif self.loss_tx is None:
            raise ValueError("loss_grads provided but loss_tx is None.")
        else:
            new_loss_params, new_loss_opt_state = optimizer_step(
                self.loss_tx, loss_grads, self.loss_opt_state, self.loss_params
            )

        return self.replace(
            step=self.step + 1,
            params=new_params,
            opt_state=new_opt_state,
            loss_params=new_loss_params,
            loss_opt_state=new_loss_opt_state,
            **kwargs,
        )



def eigvecs_psd(A, eps=1e-6):
    """Return the eigenvectors (as columns) of ``A + eps * I``.

    ``jnp.linalg.eigh`` orders them by ascending eigenvalue.  NOTE(review):
    this module defines ``eigvecs_psd`` again further down with the same
    computation; that later definition is the one bound after import.
    """
    ridge = eps * jnp.eye(A.shape[0], dtype=A.dtype)
    return jnp.linalg.eigh(A + ridge)[1]

def is_rot_layer(node):
    """Return True if *node* is a dict containing all of "QL", "QR", "L" and "R"."""
    if not isinstance(node, dict):
        return False
    return all(key in node for key in ("QL", "QR", "L", "R"))


def update_rot_state_and_S_for_layer(rot_layer, S, grad_S, *, rho, step, precond_freq, eps, max_dim):
    """Update one dict-based rotated layer ``W = QL @ S @ QR.T``.

    ``rot_layer`` holds bases "QL", "QR" and EMAs "L" (of ``G @ G.T``) and "R"
    (of ``G.T @ G``), where ``G = QL @ grad_S @ QR.T`` is the gradient mapped to
    W-space.  On refresh steps (``step == 1`` or ``step % precond_freq == 0``)
    each side whose size is ``<= max_dim`` gets a new basis from the
    eigenvectors of its EMA, and ``S`` is re-expressed so ``W`` is unchanged.

    Returns:
        ``(rot_out, S_out)``; ``rot_out`` has the same four keys.
    """
    QL, QR, A, B = rot_layer["QL"], rot_layer["QR"], rot_layer["L"], rot_layer["R"]

    # G is grad in W-space (assuming orthogonal QL / QR).
    G = QL @ grad_S @ QR.T

    A_new = rho * A + (1.0 - rho) * (G @ G.T)
    B_new = rho * B + (1.0 - rho) * (G.T @ G)

    m, n = grad_S.shape
    # Shapes are static, so these are plain Python bools.  Branching on them
    # with Python ``if`` (not ``lax.cond``) means an oversized side never
    # traces/compiles an eigendecomposition it will not use.
    use_left = m <= max_dim
    use_right = n <= max_dim

    do_refresh = jnp.logical_or(step == 1, (step % precond_freq) == 0)

    def refresh(_):
        # Change of basis keeping W invariant for orthogonal QL_new / QR_new:
        #   QL_new @ (QL_new.T @ QL @ S @ QR.T @ QR_new) @ QR_new.T == QL @ S @ QR.T
        # A side whose basis stays the same is skipped outright rather than
        # multiplying S by QL.T @ QL (only approximately I, and O(m^3)).
        QL_new, QR_new, S_new = QL, QR, S
        if use_left:
            QL_new = eigvecs_psd(A_new, eps)
            S_new = (QL_new.T @ QL) @ S_new
        if use_right:
            QR_new = eigvecs_psd(B_new, eps)
            S_new = S_new @ (QR.T @ QR_new)
        return QL_new, QR_new, S_new

    def no_refresh(_):
        return QL, QR, S

    QL_out, QR_out, S_out = jax.lax.cond(do_refresh, refresh, no_refresh, operand=None)
    rot_out = {"QL": QL_out, "QR": QR_out, "L": A_new, "R": B_new}
    return rot_out, S_out


    


@struct.dataclass
class RotTrainState(TrainState):
    """TrainState whose rotated layers keep explicit bases in ``rot_state``.

    ``rot_state`` mirrors ``params``: a node accepted by ``is_rot_layer`` (a
    dict with "QL", "QR", "L", "R") pairs with a params subtree holding the
    core matrix "S", the layer weight being ``W = QL @ S @ QR.T``.
    """

    rot_state: Any = None
    use_rot: bool = struct.field(pytree_node=False, default=True)

    def variables(self) -> Dict[str, Any]:
        """Return params together with the rotation state."""
        return dict(params=self.params, rot_state=self.rot_state)

    def update_bases(self, grads, *, rho=0.99, precond_freq=50, eps=1e-6, max_dim=512):
        """Run ``update_rot_state_and_S_for_layer`` on every rotated layer.

        Returns a copy of the state with updated ``rot_state`` and ``params``.
        """
        current_step = self.step

        def visit(rot_node, param_node, grad_node):
            handled = is_rot_layer(rot_node) and isinstance(grad_node, dict) and "S" in grad_node
            if not handled:
                return (rot_node, param_node)
            fresh_rot, fresh_S = update_rot_state_and_S_for_layer(
                rot_node,
                param_node["S"],
                grad_node["S"],
                rho=rho,
                step=current_step,
                precond_freq=precond_freq,
                eps=eps,
                max_dim=max_dim,
            )
            return (fresh_rot, {**param_node, "S": fresh_S})

        pairs = jax.tree_util.tree_map(
            visit,
            self.rot_state,
            self.params,
            grads,
            is_leaf=is_rot_layer,
        )

        # Split the (rot, param) pairs back into two trees.
        def component(index):
            return jax.tree_util.tree_map(
                lambda pair: pair[index],
                pairs,
                is_leaf=lambda x: isinstance(x, tuple),
            )

        return self.replace(rot_state=component(0), params=component(1))


def is_rot_layer_f(node: Any) -> bool:
    """Return True iff *node* is a ``RotLayer`` instance (fast-path state)."""
    matches = isinstance(node, RotLayer)
    return matches


def _symmetrize(M: jnp.ndarray) -> jnp.ndarray:
    return 0.5 * (M + M.T)


def eigvecs_psd(A: jnp.ndarray, eps: float = 1e-6) -> jnp.ndarray:
    """Return the eigenvector basis (columns) of ``A + eps * I``.

    The ``eps`` ridge keeps a near-singular PSD matrix well conditioned.
    Eigenvectors come back in ascending-eigenvalue order.  No explicit
    symmetrization is needed: ``jnp.linalg.eigh`` symmetrizes its input by
    default (``symmetrize_input=True``).
    """
    size = A.shape[0]
    shifted = A + eps * jnp.eye(size, dtype=A.dtype)
    _eigenvalues, eigenvectors = jnp.linalg.eigh(shifted)
    return eigenvectors


def update_fast_rot_layer_and_S(
    rot: RotLayer,
    S: jnp.ndarray,
    grad_S: jnp.ndarray,
    *,
    rho: float,
    do_refresh: jnp.ndarray,     # scalar bool (JAX)
    eps: float,
    max_dim: int,
) -> tuple[RotLayer, jnp.ndarray]:
    """
    Fast update of one rotated layer ``W = rot.QL @ S @ rot.QR.T``.

      - Update EMA covariances in S-basis (NO forming G = QL @ grad_S @ QR.T).
      - On refresh, rotate bases by eigvecs of Lb/Rb, and rotate S to keep W invariant.

    Args:
        rot: current bases and S-basis gradient covariances.
        S: core matrix, shape (m, n) like ``grad_S``.
        grad_S: gradient w.r.t. ``S``.
        rho: EMA decay for ``Lb`` / ``Rb``.
        do_refresh: scalar bool; when true the bases are re-diagonalized.
        eps: ridge added before each eigendecomposition.
        max_dim: a side (m for left, n for right) larger than this is left
            untouched: no covariance update, no basis change.  Its ``Lb``/``Rb``
            may then be a placeholder of any shape.

    Returns:
        ``(rot_out, S_out)`` with ``W`` preserved up to rounding when the new
        bases are orthogonal.
    """
    m, n = grad_S.shape

    # Shapes are static, so these are plain Python bools.  Branching with
    # Python ``if`` (not ``lax.cond``) means a disabled side never traces its
    # covariance product or eigendecomposition, and never multiplies QL/QR/S by
    # an identity matrix (O(m^3) work that is not an exact no-op when the
    # default matmul precision is below float32 on some accelerators).
    use_left = m <= max_dim
    use_right = n <= max_dim

    if use_left:
        Lb_new = rho * rot.Lb + (1.0 - rho) * (grad_S @ grad_S.T)
    else:
        Lb_new = rot.Lb
    if use_right:
        Rb_new = rho * rot.Rb + (1.0 - rho) * (grad_S.T @ grad_S)
    else:
        Rb_new = rot.Rb

    def refresh(_):
        QL_out, QR_out, Lb_out, Rb_out, S_out = rot.QL, rot.QR, Lb_new, Rb_new, S
        if use_left:
            # Eigvecs in S-basis; new weight-space basis QL' = QL @ U.
            U = eigvecs_psd(Lb_new, eps)
            QL_out = rot.QL @ U
            # Keep W invariant: S' = U^T S (V applied below).
            S_out = U.T @ S_out
            # grad_S' = U^T grad_S V, so the stored covariance becomes U^T Lb U.
            Lb_out = _symmetrize(U.T @ Lb_new @ U)
        if use_right:
            V = eigvecs_psd(Rb_new, eps)
            QR_out = rot.QR @ V
            S_out = S_out @ V
            Rb_out = _symmetrize(V.T @ Rb_new @ V)
        return RotLayer(QL=QL_out, QR=QR_out, Lb=Lb_out, Rb=Rb_out), S_out

    def no_refresh(_):
        return RotLayer(QL=rot.QL, QR=rot.QR, Lb=Lb_new, Rb=Rb_new), S

    rot_out, S_out = jax.lax.cond(do_refresh, refresh, no_refresh, operand=None)
    return rot_out, S_out


@struct.dataclass
class FastRotTrainState(TrainState):
    """TrainState for layers parameterized as ``W = QL @ S @ QR.T`` (fast path).

    ``rot_state`` mirrors ``params``: where a params subtree holds the core
    matrix ``"S"``, the matching ``rot_state`` node is ``{"rot": RotLayer}``.
    Per-layer updates are done by ``update_fast_rot_layer_and_S``.
    """

    rot_state: Any = None
    use_rot: bool = struct.field(pytree_node=False, default=True)

    def variables(self) -> Dict[str, Any]:
        """Return params together with the rotation state."""
        return {"params": self.params, "rot_state": self.rot_state}

    def update_bases(
        self,
        grads,
        *,
        rho: float = 0.99,
        precond_freq: int = 50,
        eps: float = 1e-6,
        max_dim: int = 512,
    ) -> "FastRotTrainState":
        """Update every ``{"rot": RotLayer}`` node together with its ``S``.

        Args:
            grads: gradient pytree laid out like ``params``.
            rho: EMA decay of the gradient covariances.
            precond_freq: bases are re-diagonalized when ``step == 1`` or
                ``step % precond_freq == 0``.
            eps: ridge added before each eigendecomposition.
            max_dim: sides larger than this keep their current basis.

        Returns:
            A copy of the state with updated ``rot_state`` and ``params``.
        """
        # Compute refresh once (scalar JAX bool) and reuse for every layer
        step = self.step
        do_refresh = jnp.logical_or(step == 1, (step % precond_freq) == 0)

        def is_fast_rot_node(node):
            return isinstance(node, dict) and "rot" in node

        def stop_traversal(node):
            # Traversal must stop at the {"rot": RotLayer} wrapper: RotLayer is
            # itself a pytree, so otherwise tree_map descends into QL/QR/Lb/Rb,
            # whose structure cannot line up with the params/grads subtrees,
            # and update_node never sees the wrapper.  Dict-based rot nodes
            # (``is_rot_layer``) still stop traversal as before.
            return is_fast_rot_node(node) or is_rot_layer(node)

        def update_node(rot_node, param_node, grad_node):
            # Expect param_node to be a dict with key "S"
            if not (
                is_fast_rot_node(rot_node)
                and isinstance(grad_node, dict)
                and "S" in grad_node
            ):
                return (rot_node, param_node)

            rot_new, S_new = update_fast_rot_layer_and_S(
                rot_node["rot"],
                param_node["S"],
                grad_node["S"],
                rho=rho,
                do_refresh=do_refresh,
                eps=eps,
                max_dim=max_dim,
            )
            new_param = dict(param_node)
            new_param["S"] = S_new
            return ({"rot": rot_new}, new_param)

        updated = jax.tree_util.tree_map(
            update_node,
            self.rot_state,
            self.params,
            grads,
            is_leaf=stop_traversal,
        )

        # Split the (rot, param) pairs back into two trees.
        is_tuple = lambda x: isinstance(x, tuple)
        new_rot_state = jax.tree_util.tree_map(lambda t: t[0], updated, is_leaf=is_tuple)
        new_params = jax.tree_util.tree_map(lambda t: t[1], updated, is_leaf=is_tuple)

        return self.replace(rot_state=new_rot_state, params=new_params)
    

def _task_space_alignment(G: jnp.ndarray, w: jnp.ndarray, eps: float, clip_min: float) -> Tuple[jnp.ndarray, jnp.ndarray]:
    M = G.T @ G
    lam, V = jnp.linalg.eigh(M)
    pos = lam > eps

    def fallback():
        alpha = w
        g_hat = G @ alpha
        return g_hat, alpha

    def aligned():
        lam_pos = jnp.where(pos, lam, jnp.inf)
        lam_min_pos = jnp.maximum(jnp.min(lam_pos), clip_min)
        inv_sqrt = jnp.where(pos, 1.0 / jnp.sqrt(jnp.maximum(lam, clip_min)), 0.0)
        B = jnp.sqrt(lam_min_pos) * (V @ (inv_sqrt[:, None] * V.T))
        alpha = B @ w
        g_hat = G @ alpha
        return g_hat, alpha

    return jax.lax.cond(jnp.any(pos), aligned, fallback)


def _grads_dict_to_matrix(grads_dict: Dict[str, Any], keys: Tuple[str, ...]):
    flat0, unravel = ravel_pytree(grads_dict[keys[0]])
    cols = [flat0]
    for k in keys[1:]:
        fk, _ = ravel_pytree(grads_dict[k])
        cols.append(fk)
    G = jnp.stack(cols, axis=1)
    return G, unravel




    
