import jax
import jax.numpy as jnp
from optax import GradientTransformation, OptState, Updates, Params
from typing import NamedTuple, Any
from flax import nnx


# Optimizer state: per-parameter Shampoo preconditioner statistics.
class ShampooState(NamedTuple):
    """Optimizer state carried between Shampoo update steps.

    Both fields are pytrees mirroring the parameter tree; leaves belonging to
    non-matrix parameters hold ``None`` instead of a statistic.
    """

    # Left statistics: one (rows x rows) matrix per 2-D parameter.
    L_mats: Any
    # Right statistics: one (cols x cols) matrix per 2-D parameter.
    R_mats: Any


class L_state(nnx.Variable):
    """nnx Variable subtype marking left-preconditioner (L) statistics."""


class R_state(nnx.Variable):
    """nnx Variable subtype marking right-preconditioner (R) statistics."""


class update_state(nnx.Variable):
    """nnx Variable subtype, presumably meant for tagging optimizer updates.

    NOTE(review): not referenced anywhere in the visible code; confirm it is
    used elsewhere before relying on (or removing) it.
    """


def shampoo(
    learning_rate: float, epsilon: float = 1e-6, sexp: float = 0.25
) -> GradientTransformation:
    """Implement the Shampoo optimizer as an Optax gradient transformation.

    For every 2-D parameter with gradient ``G`` two running statistics are
    kept, both initialised to ``epsilon * I``::

        L <- L + G @ G.T        (rows x rows)
        R <- R + G.T @ G        (cols x cols)

    and the emitted update is ``-learning_rate * L^(-sexp) @ G @ R^(-sexp)``.
    Parameters that are not 2-D (biases, scalars, higher-rank tensors) fall
    back to plain SGD, ``-learning_rate * G``, and carry ``None`` in place of
    their statistics.

    Args:
        learning_rate: Step size applied to every update.
        epsilon: Initial diagonal of both statistics, and the lower bound
            applied to their eigenvalues before taking the inverse root.
            Must be > 0 for the root to stay finite.
        sexp: Exponent of the inverse root; the default 0.25 gives the
            ``L^{-1/4}`` / ``R^{-1/4}`` preconditioners.

    Returns:
        A ``GradientTransformation`` whose state is a ``ShampooState``.
    """

    def init_fn(params: Params) -> ShampooState:
        """Create epsilon-scaled identity statistics for each 2-D parameter."""

        def init_left(p):
            # Non-matrix parameters need no preconditioner; None is an empty
            # pytree node, so it fills the slot without adding a leaf.
            return jnp.eye(p.shape[0]) * epsilon if p.ndim == 2 else None

        def init_right(p):
            return jnp.eye(p.shape[1]) * epsilon if p.ndim == 2 else None

        return ShampooState(
            L_mats=jax.tree_util.tree_map(init_left, params),
            R_mats=jax.tree_util.tree_map(init_right, params),
        )

    def inverse_root(mat):
        """Return ``mat ** -sexp`` for a symmetric positive-definite ``mat``."""
        # The statistics are symmetric PSD by construction (epsilon * I plus
        # sums of G G^T), so eigh is the appropriate and cheaper
        # factorisation. Round-off can still push eigenvalues below epsilon
        # (to zero or even negative), which would turn the negative power
        # into inf/NaN; flooring at epsilon -- their exact-arithmetic lower
        # bound -- keeps the root finite.
        eigvals, eigvecs = jnp.linalg.eigh(mat)
        eigvals = jnp.maximum(eigvals, epsilon)
        return (eigvecs * jnp.power(eigvals, -sexp)) @ eigvecs.T

    def update_leaf(grad, p, l_mat, r_mat):
        """Return ``(update, new_l_mat, new_r_mat)`` for one parameter."""
        if p.ndim != 2:
            # Plain SGD for non-matrix parameters (e.g. biases); their
            # statistics stay None.
            return (-learning_rate * grad).astype(p.dtype), l_mat, r_mat

        new_l_mat = l_mat + grad @ grad.T
        new_r_mat = r_mat + grad.T @ grad
        update = inverse_root(new_l_mat) @ grad @ inverse_root(new_r_mat)
        # Mixing with the statistics can promote the dtype; cast the update
        # back to the parameter's dtype.
        return (-learning_rate * update).astype(p.dtype), new_l_mat, new_r_mat

    @jax.jit
    def update_fn(
        grads: Updates, state: ShampooState, params: Params = None
    ) -> tuple[Updates, ShampooState]:
        """Accumulate statistics and return the preconditioned updates.

        ``params`` supplies each leaf's rank and dtype. If it is omitted the
        gradients are used instead: they have the same shapes, so the same
        leaves are preconditioned, and updates keep the gradient dtype.
        """
        if params is None:
            params = grads

        # For non-matrix leaves tree_map hands over the whole None statistic.
        per_leaf = jax.tree_util.tree_map(
            update_leaf, grads, params, state.L_mats, state.R_mats
        )

        def component(index):
            # grads is a structural prefix of per_leaf, so each callback
            # receives one parameter's complete (update, l, r) tuple.
            return jax.tree_util.tree_map(
                lambda _, out: out[index], grads, per_leaf
            )

        updates = component(0)
        new_l_mats = component(1)
        new_r_mats = component(2)
        return updates, ShampooState(L_mats=new_l_mats, R_mats=new_r_mats)

    return GradientTransformation(init_fn, update_fn)
