import operator
from functools import partial
from typing import Any, Callable, Optional, Union, NamedTuple, MutableMapping
import torch
from torch.optim import Optimizer

from symo.factory import CovFactory, MeanFactory

NDArray = torch.Tensor


class Symo(Optimizer):
    """Symo optimizer.

    One call to ``step`` performs, for the (single) parameter group:

    1. an exponential moving average (EMA) of the gradients, kept in the
       per-parameter ``momentum_buffer`` and optionally bias-corrected;
    2. centering of those gradients with ``MeanFactory.avg``;
    3. a refresh of the ``CovFactory`` weights from the centered gradients
       (``outer_update``), blended by EMA with the weights from before this
       step and optionally bias-corrected;
    4. the inverse square root of the resulting covariance surrogate, handed
       back to the factory via ``cov_update``;
    5. the parameter update ``p -= lr * cov_factory.matvec(grads)``.

    Args:
        params: iterable of ``(name, tensor)`` pairs; the names are looked up in
            ``groups_spec`` to build the factor structure.
        groups_spec: ``(name, spec)`` pairs giving the group structure of each
            named parameter (interpreted by ``MeanFactory``/``CovFactory``).
        lr: learning rate. NOTE(review): a callable is accepted by the
            signature but is passed unchanged to ``update_with_lr``.
        grads_beta: EMA coefficient for the gradients, in ``[0, 1)``.
        factors_beta: EMA coefficient for the covariance weights, in ``[0, 1)``.
        grads_bias_corr: divide the gradient EMA by ``1 - grads_beta**t``.
        factors_bias_corr: bias-correct the weight EMA when forming the
            surrogate (the raw EMA is what is carried to the next step).
        update_correction: precondition the (possibly bias-corrected) gradient
            EMA instead of the raw momentum buffer.
        damping: eigenvalues of the surrogate ``<= damping`` are dropped when
            forming its inverse square root.
    """

    @torch.no_grad()
    def __init__(
        self,
        params,
        # TODO(bla): pass in dims for the compiler (e.g. ``dims: dict``).
        groups_spec: tuple[tuple[str, type | tuple], ...],
        lr: float | Callable = 1e-1,
        grads_beta: float = 0.0,
        factors_beta: float = 0.0,
        grads_bias_corr: bool = False,
        factors_bias_corr: bool = True,
        update_correction: bool = False,
        damping: float = 0.0,
    ):
        if not 0.0 <= damping:
            raise ValueError(f"Invalid damping value: {damping}")
        if not 0.0 <= grads_beta < 1.0:
            raise ValueError(f"Invalid grads_beta value: {grads_beta}")
        if not 0.0 <= factors_beta < 1.0:
            raise ValueError(f"Invalid factors_beta value: {factors_beta}")

        # Materialize once: the pairs are consumed both here and by Optimizer.
        params = list(params)
        # TODO(bla): Global factors buffer. Generalize to multiple parameter groups!
        ordered = order_group_structure(params, groups_spec)
        avg_factory = MeanFactory(ordered)
        cov_factory = CovFactory(ordered)

        defaults = dict(
            lr=lr,
            damping=damping,
            grads_beta=grads_beta,
            factors_beta=factors_beta,
            groups_spec=groups_spec,
            grads_bias_corr=grads_bias_corr,
            factors_bias_corr=factors_bias_corr,
            update_correction=update_correction,
        )

        super().__init__(params, defaults)
        self.avg_factory = avg_factory
        self.cov_factory = cov_factory
        # Shared step counter, created lazily on the device of the parameters.
        self.step_t = None

    def _init_group(
        self,
        group: MutableMapping,
    ):
        """Collect params, grads and momentum buffers of *group*.

        Momentum buffers are created lazily as zeros, and so is the shared step
        counter.

        Returns:
            ``(params, grads, momentum_buffers)`` lists, in ``group["params"]``
            order.

        Raises:
            RuntimeError: if a parameter has no gradient, is complex, or has a
                sparse gradient. Every parameter must contribute, presumably
                because the factors are built over all parameters.
        """
        params_with_grad: list[NDArray] = []
        grads: list[NDArray] = []
        grad_momentum_bufs: list[NDArray] = []

        for p in group["params"]:
            if p.grad is None:
                raise RuntimeError(
                    "Symo requires gradients for all parameters "
                    "(found a parameter with grad=None)"
                )

            if torch.is_complex(p):
                raise RuntimeError("Symo does not support complex parameters")
            if p.grad.is_sparse:
                raise RuntimeError("Symo does not support sparse gradients")

            params_with_grad.append(p)
            grads.append(p.grad)

            state = self.state[p]

            if "momentum_buffer" not in state:
                state["momentum_buffer"] = torch.zeros_like(
                    p.grad, memory_format=torch.preserve_format
                )

            grad_momentum_bufs.append(state["momentum_buffer"])

        if self.step_t is None and params_with_grad:
            ref = params_with_grad[0]
            # At least float32: in bf16/fp16, ``step += 1`` stops increasing
            # after a few hundred/thousand steps, which would freeze the bias
            # correction factor ``1 - beta**step``.
            self.step_t = torch.tensor(
                0.0,
                dtype=torch.promote_types(ref.dtype, torch.float32),
                device=ref.device,
            )

        return params_with_grad, grads, grad_momentum_bufs

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            The loss returned by *closure*, or ``None``.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params, grads, grads_buf = self._init_group(group)
            if not params:
                # Nothing to update (and no step counter to advance).
                continue

            self._symo_update(
                params,
                grads,
                grads_buf,
                self.avg_factory,
                self.cov_factory,
                self.step_t,
                group["lr"],
                grads_beta=group["grads_beta"],
                factors_beta=group["factors_beta"],
                damping=group["damping"],
                grads_corr=group["grads_bias_corr"],
                factors_corr=group["factors_bias_corr"],
                updates_corr=group["update_correction"],
            )

        return loss

    def _symo_update(
        self,
        params,
        grads,
        grads_buf,
        avg_buf: MeanFactory,
        cov_factory: CovFactory,
        step,
        lr,
        grads_beta,
        factors_beta,
        grads_corr,
        factors_corr,
        updates_corr,
        damping,
    ):
        """Core Symo update logic (see the class docstring for the stages)."""
        # ``step`` is the shared tensor counter; ``+=`` is in place, so
        # ``self.step_t`` advances as well.
        step += 1

        # 1. Gradient EMA (in place on grads_buf), optionally bias-corrected.
        apply_momentum(grads_buf, grads, grads_beta)
        grads_mmm = grads_buf
        if grads_corr:
            grads_mmm = apply_bias(grads_mmm, grads_beta, step)

        # 2. Center the gradients and refresh the covariance weights.
        grads_del = invariant_del_mean(avg_buf, grads_mmm)
        # Copy of the weights as they were before this step's outer_update.
        weights_buf = cov_factory.weights(clone=True)
        cov_factory.outer_update(grads_del)

        # 3. Weight EMA: weights_buf <- beta * old + (1 - beta) * new.
        apply_momentum(weights_buf, cov_factory.weights(), factors_beta)
        new_weights = weights_buf
        if factors_corr:
            new_weights = apply_bias(new_weights, factors_beta, step)
        cov_factory.update_weights(new_weights)

        # 4. Inverse square root of the covariance surrogate.
        surrogate = cov_factory.cov()
        u, s, vt = svd(surrogate)
        surrogate_sqrt_inv = inv_sqrt_mat(u, s, vt, damping=damping)
        cov_factory.cov_update(surrogate_sqrt_inv)

        # 5. Precondition the gradients.
        apply_grads = grads_mmm if updates_corr else grads_buf
        updates = cov_factory.matvec(apply_grads)

        # Carry the raw (non-bias-corrected) EMA weights to the next step.
        # Without factors_corr, new_weights *is* weights_buf, so this single
        # call covers both cases.
        cov_factory.update_weights(weights_buf)

        update_with_lr(lr, params, updates)


class Symo2(Optimizer):
    """Symo2 optimizer with gradient and covariance averaging.

    Gradients (optionally EMA-averaged and bias-corrected) are preconditioned
    with the matrix built by ``geom_matrix_mean``. It combines a
    gradient-covariance factor (``sigma_g``, EMA-averaged) and a
    parameter-covariance factor (``sigma_t``, refreshed each step from the
    centered parameters).

    Args:
        params: iterable of ``(name, tensor)`` pairs; names are looked up in
            ``groups_spec``.
        groups_spec: ``(name, spec)`` pairs giving the group structure of each
            named parameter (interpreted by ``MeanFactory``/``CovFactory``).
        lr: learning rate. NOTE(review): a callable is accepted by the
            signature but passed unchanged to ``update_with_lr``.
        damping: relative shift added to the eigenvalues inside
            ``geom_matrix_mean``.
        grads_beta: EMA coefficient for the gradients, in ``[0, 1)``.
        sigma_g_beta: EMA coefficient for the gradient-covariance weights.
        grads_bias_corr: bias-correct the gradient EMA.
        sigma_g_bias_corr: bias-correct the ``sigma_g`` weight EMA when forming
            the preconditioner.
        update_correction: precondition the (possibly bias-corrected) gradient
            EMA instead of the raw momentum buffer.
    """

    def __init__(
        self,
        params,
        groups_spec: tuple[tuple[str, type | tuple], ...],
        lr: Union[float, Callable] = 0.001,
        damping: float = 0.0,
        grads_beta: float = 0.0,
        sigma_g_beta: float = 0.0,
        grads_bias_corr: bool = False,
        sigma_g_bias_corr: bool = False,
        update_correction: bool = False,
    ):
        if not 0.0 <= damping:
            raise ValueError(f"Invalid damping value: {damping}")
        if not 0.0 <= grads_beta < 1.0:
            raise ValueError(f"Invalid grads_beta value: {grads_beta}")
        if not 0.0 <= sigma_g_beta < 1.0:
            raise ValueError(f"Invalid sigma_g_beta value: {sigma_g_beta}")

        params = list(params)

        # TODO(bla): Global factors buffer. Generalize to multiple parameter groups!
        ordered = order_group_structure(params, groups_spec)

        defaults = dict(
            lr=lr,
            damping=damping,
            grads_beta=grads_beta,
            sigma_g_beta=sigma_g_beta,
            grads_bias_corr=grads_bias_corr,
            sigma_g_bias_corr=sigma_g_bias_corr,
            groups_spec=groups_spec,
            update_correction=update_correction,
        )

        super().__init__(params, defaults)
        self.avg_g_buf = MeanFactory(ordered)
        self.sigma_g_buf = CovFactory(ordered)
        # Parameter-covariance factor required by ``_symo2_update`` (it was
        # previously missing, so ``step`` could not run).
        self.sigma_t_buf = CovFactory(ordered)
        # Unused; kept only for backward compatibility of the attribute.
        self.sigma_g_buffer = None
        self.step_t = None

    def _init_group(self, group):
        """Collect params, grads and momentum buffers of *group*.

        Momentum buffers and the shared step counter are created lazily.

        Raises:
            RuntimeError: if a parameter has no gradient, is complex, or has a
                sparse gradient.
        """
        params_with_grad: list[NDArray] = []
        grads: list[NDArray] = []
        grad_momentum_bufs: list[NDArray] = []

        for p in group["params"]:
            if p.grad is None:
                raise RuntimeError(
                    "Symo requires gradients for all parameters "
                    "(found a parameter with grad=None)"
                )

            if torch.is_complex(p):
                raise RuntimeError("Symo does not support complex parameters")
            if p.grad.is_sparse:
                raise RuntimeError("Symo does not support sparse gradients")

            params_with_grad.append(p)
            grads.append(p.grad)

            state = self.state[p]

            if "momentum_buffer" not in state:
                state["momentum_buffer"] = torch.zeros_like(
                    p.grad, memory_format=torch.preserve_format
                )

            grad_momentum_bufs.append(state["momentum_buffer"])

        if self.step_t is None and params_with_grad:
            ref = params_with_grad[0]
            # At least float32 so the counter keeps increasing in bf16/fp16.
            self.step_t = torch.tensor(
                0.0,
                dtype=torch.promote_types(ref.dtype, torch.float32),
                device=ref.device,
            )

        return params_with_grad, grads, grad_momentum_bufs

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Returns:
            The loss returned by *closure* (run with gradients enabled), or
            ``None``.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params, grads, grads_buf = self._init_group(group)
            if not params:
                continue

            # Keyword arguments: the previous positional call swapped
            # avg/sigma_g and omitted sigma_t.
            self._symo2_update(
                params,
                grads,
                grads_buf,
                avg=self.avg_g_buf,
                sigma_g=self.sigma_g_buf,
                sigma_t=self.sigma_t_buf,
                lr=group["lr"],
                step=self.step_t,
                damping=group["damping"],
                grads_beta=group["grads_beta"],
                sigma_g_beta=group["sigma_g_beta"],
                grads_bias_corr=group["grads_bias_corr"],
                sigma_g_bias_corr=group["sigma_g_bias_corr"],
                updates_corr=group["update_correction"],
            )

        return loss

    def _symo2_update(
        self,
        params: list[torch.nn.Parameter],
        grads: list[torch.Tensor],
        grads_buf: list[torch.Tensor],
        avg: MeanFactory,
        sigma_g: CovFactory,
        sigma_t: CovFactory,
        lr,
        step: torch.Tensor,
        damping,
        grads_beta,
        sigma_g_beta,
        grads_bias_corr,
        sigma_g_bias_corr,
        updates_corr,
    ):
        """Core Symo2 update logic with gradient and covariance averaging."""
        # In place on the shared counter, so ``self.step_t`` advances too.
        step += 1

        grads_avg = apply_grads_beta(
            grads_buf, grads, grads_beta, step, bias=grads_bias_corr
        )

        grads_del = invariant_del_mean(avg, grads_avg)
        params_del = invariant_del_mean(avg, params)

        # Copy of sigma_g's weights before this step's outer_update.
        sigma_g_weights = sigma_g.weights(clone=True)
        sigma_g.outer_update(grads_del)
        sigma_t.outer_update(params_del)

        # EMA of sigma_g's weights (mirrors Symo): beta * old + (1 - beta) * new,
        # then optionally bias-corrected for use in this step only.
        apply_momentum(sigma_g_weights, sigma_g.weights(), sigma_g_beta)
        effective_weights = sigma_g_weights
        if sigma_g_bias_corr:
            effective_weights = apply_bias(sigma_g_weights, sigma_g_beta, step)
        sigma_g.update_weights(effective_weights)

        mat = geom_matrix_mean(sigma_g, sigma_t, damping=damping)
        # NOTE(review): following Symo's pattern (cov_update then matvec on the
        # same factory), the preconditioner is applied through sigma_t.
        sigma_t.cov_update(mat)

        apply_grads = grads_avg if updates_corr else grads_buf
        updates = sigma_t.matvec(apply_grads)

        # Carry the raw (non-bias-corrected) EMA weights to the next step.
        sigma_g.update_weights(sigma_g_weights)

        update_with_lr(lr, params, updates)


def invariant_del_mean(factory: MeanFactory, values: list[torch.Tensor]):
    """Return ``values[i] - mean_i`` for each mean produced by ``factory.avg``.

    Each subtraction creates a new tensor, so *values* is not modified here.
    """
    means = factory.avg(values)
    return [values[idx] - mean for idx, mean in enumerate(means)]


def geom_matrix_mean(
    sigma_g_factory,
    sigma_t_factory,
    damping: float = 0.0,
):
    """Build the preconditioner ``√A (√A B √A)^(-1/2) √A``.

    Here ``A = sigma_t_factory.cov()`` and ``B = sigma_g_factory.cov()``.
    Before the inverse square root, every eigenvalue of ``√A B √A`` is shifted
    by ``damping`` times its largest eigenvalue.
    """
    # Query B before A, preserving the original call order of ``cov()``.
    b_mat = sigma_g_factory.cov()
    a_mat = sigma_t_factory.cov()

    a_half = sqrt_mat(*svd(a_mat))

    middle = a_half @ b_mat @ a_half
    m_u, m_s, m_vt = svd(middle)
    shifted = m_s + m_s.max() * damping
    middle_inv_half = inv_sqrt_mat(m_u, shifted, m_vt)

    return a_half @ middle_inv_half @ a_half


def sqrt_inverse(mat: torch.Tensor, jitter: float = 0.0) -> torch.Tensor:
    """Compute the inverse square root of a symmetric matrix.

    Uses ``torch.linalg.eigh``. Eigenvalues ``<= jitter`` are mapped to 0
    instead of being inverted, so the result is an inverse square root on the
    retained eigenspace only. Batched input of shape ``(..., n, n)`` is
    supported.

    Raises:
        ValueError: if ``jitter`` is negative (or NaN).
    """
    # Explicit check instead of ``assert``, which is stripped under ``-O``.
    if not jitter >= 0.0:
        raise ValueError(f"jitter must be non-negative, got {jitter}")

    eigvals, eigvecs = torch.linalg.eigh(mat)

    dropped = eigvals <= jitter
    # Substitute 1 for dropped eigenvalues so no inf/NaN is ever computed.
    safe = torch.where(dropped, torch.ones_like(eigvals), eigvals)
    inv_sqrt_eigvals = torch.where(
        dropped, torch.zeros_like(eigvals), 1.0 / torch.sqrt(safe)
    )
    # Column scaling == eigvecs @ diag(...), but broadcasts over batch dims.
    return (eigvecs * inv_sqrt_eigvals.unsqueeze(-2)) @ eigvecs.mT


def debug_eigvalues(msg: str, eigvals: torch.Tensor):
    """Print min/max of *eigvals* and their ratio (``inf`` unless min > 0)."""
    lo = eigvals.min().item()
    hi = eigvals.max().item()
    if lo > 0:
        kappa = hi / lo
    else:
        kappa = float("inf")
    print(f"{msg}.  λ_min = {lo:.3e}, λ_max = {hi:.3e}, κ = {kappa:.3e}")


def ordered_groups(
    groups: dict[str, Any], named_parameters: tuple[tuple[str, torch.Tensor], ...]
) -> tuple[Any, ...]:
    """Return ``groups[name]`` for each ``(name, tensor)`` pair, in order.

    The result is a tuple; the previous ``list`` return annotation was wrong.

    Raises:
        KeyError: if a parameter name has no entry in *groups*.
    """
    return tuple(groups[name] for name, _ in named_parameters)


def svd(mat, hermitian: bool = True):
    """Decompose *mat* into ``(u, s, vt)`` with ``mat ≈ u @ diag(s) @ vt``.

    With ``hermitian=True`` (the default, as used throughout this module) this
    is an eigendecomposition via ``torch.linalg.eigh``: ``s`` holds the
    eigenvalues in ascending order and ``vt`` is the conjugate transpose of
    ``u``. With ``hermitian=False`` (previously ignored) a true SVD from
    ``torch.linalg.svd`` is returned, with singular values in descending order.
    Batched inputs of shape ``(..., n, n)`` are supported.
    """
    if not hermitian:
        return torch.linalg.svd(mat, full_matrices=False)
    s, u = torch.linalg.eigh(mat)
    # ``.mH`` transposes only the last two dims (``.T`` reverses all dims).
    return u, s, u.mH


def inv_sqrt_mat(
    u: torch.Tensor, s: torch.Tensor, vt: torch.Tensor, damping: float = 0.0
) -> torch.Tensor:
    """Compute the inverse square root ``u @ diag(s**-0.5) @ vt`` of a decomposed matrix.

    Entries of ``s`` that are ``<= damping`` are mapped to 0 instead of being
    inverted. ``damping`` is therefore a cutoff, not an additive
    regularization. ``s`` may have leading batch dims ``(..., n)`` matching
    ``u``/``vt`` of shape ``(..., n, n)``.
    """
    keep = s > damping
    # Substitute 1 for dropped entries so no inf/NaN is ever computed.
    safe_s = torch.where(keep, s, torch.ones_like(s))
    inv_sqrt_s = torch.where(keep, 1.0 / torch.sqrt(safe_s), torch.zeros_like(s))
    # unsqueeze(-2) scales the columns of u and broadcasts over batch dims
    # (``s[None]`` only worked for 1-D ``s``).
    return (u * inv_sqrt_s.unsqueeze(-2)) @ vt


def sqrt_mat(
    u: torch.Tensor, s: torch.Tensor, vt: torch.Tensor, damping: float = 0.0
) -> torch.Tensor:
    """Compute the square root ``u @ diag(sqrt(s)) @ vt`` of a decomposed matrix.

    Entries of ``s`` that are ``<= damping`` are mapped to 0. This covers, for
    example, tiny negative round-off eigenvalues of a PSD matrix. ``s`` may
    have leading batch dims ``(..., n)`` matching ``u``/``vt`` of shape
    ``(..., n, n)``.
    """
    # Zero the dropped entries before sqrt so no NaN is ever computed.
    kept = torch.where(s > damping, s, torch.zeros_like(s))
    sqrt_s = torch.sqrt(kept)
    # unsqueeze(-2) scales the columns of u and broadcasts over batch dims.
    return (u * sqrt_s.unsqueeze(-2)) @ vt


def order_group_structure(
    params: list[tuple[str, torch.Tensor]],
    groups: tuple[tuple[str, Any], ...],
):
    """Return the group spec of each named parameter, in *params* order.

    Args:
        params: ``(name, tensor)`` pairs.
        groups: ``(name, spec)`` pairs; later duplicates override earlier ones
            (``dict`` semantics).

    Returns:
        A tuple with one spec per entry of *params*.

    Raises:
        KeyError: if a parameter name has no entry in *groups*. The message
            lists every missing name.
    """
    spec_by_name = dict(groups)
    missing = [name for name, _ in params if name not in spec_by_name]
    if missing:
        raise KeyError(f"groups_spec has no entry for parameter(s): {missing}")
    return tuple(spec_by_name[name] for name, _ in params)


def apply_momentum(
    values,
    new_values,
    beta,
):
    """Blend each buffer toward its new value in place.

    The update is ``v <- beta * v + (1 - beta) * new``, with buffer ``idx`` of
    *values* paired with ``new_values[idx]``. Returns ``None``.
    """
    for idx, buffer in enumerate(values):
        target = new_values[idx]
        buffer.lerp_(target, 1 - beta)


def apply_bias(
    values,
    beta: torch.Tensor,
    step: torch.Tensor,
):
    """Return new tensors ``v / (1 - beta**step)`` for each *v* in *values*.

    This is the usual EMA bias correction; the inputs are not modified.
    """
    correction = 1 - beta**step
    return [v / correction for v in values]


def apply_grads_beta(
    bufs,
    values,
    beta,
    step: torch.Tensor,
    bias: bool = True,
):
    """Update EMA buffers in place and return the (optionally bias-corrected) averages.

    For each index ``idx``, ``bufs[idx] <- beta * bufs[idx] + (1 - beta) * values[idx]``.
    Without *bias*, the buffer objects themselves are returned. With *bias*,
    new tensors ``bufs[idx] / (1 - beta**step)`` are returned.
    """
    correction = 1 - beta**step
    out = []
    for idx, fresh in enumerate(values):
        running = bufs[idx]
        running.lerp_(fresh, 1 - beta)
        out.append(running / correction if bias else running)
    return out


def apply_factors_beta(
    bufs,
    values,
    beta,
    step: torch.Tensor,
    bias: bool = True,
):
    """EMA of the ``.weights`` of factor records, with optional bias correction.

    For each index ``idx``, ``bufs[idx].weights`` is blended in place toward
    ``values[idx].weights`` (``w <- beta * w + (1 - beta) * new``).

    Without *bias*, the (mutated) record ``bufs[idx]`` is returned. With
    *bias*, a new record ``bufs[idx].__class__(bufs[idx].eq, w / (1 - beta**step))``
    is returned.

    NOTE(review): the records look like NamedTuples with fields
    ``(eq, weights)``; confirm against the factor types.
    """
    correction = 1 - beta**step
    out = []
    for idx, fresh in enumerate(values):
        record = bufs[idx]
        fresh_weights = fresh.weights
        running_weights = record.weights
        running_weights.lerp_(fresh_weights, 1 - beta)

        if bias:
            corrected = running_weights / correction
            out.append(record.__class__(record.eq, corrected))
        else:
            out.append(record)
    return out


def update_with_lr(lr: float | NDArray, params, updates):
    for i, p in enumerate(params):
        u = updates[i]
        p.sub_(u, alpha=lr)
