from typing import Any, Callable, Union, MutableMapping, Sequence, Literal
from functools import partial
import torch
from torch.optim import Optimizer

import torch.nn as nn
import numpy as np
from symo.factory2 import GroupsSpec, CovFactory, MeanFactory
from symo.utils import to_dtype

NDArray = torch.Tensor


class NSymo(Optimizer):
    """NS Symo optimizer."""

    @torch.no_grad
    def __init__(
        self,
        params,
        groups_spec: GroupsSpec,
        lr: float | Callable = 1e-1,
        weight_decay: float = 0.1,
        grads_beta: float = 0.0,
        factors_beta: float = 0.0,
        grads_bias_corr: bool = False,
        factors_bias_corr: bool = True,
        update_correction: bool = False,
        update_avg: bool = False,
        block_diag: bool = False,
        nesterov = True,
    ):
        if not 0.0 <= grads_beta <= 1.0:
            raise ValueError(f"Invalid grads_beta value: {grads_beta}")
        if not 0.0 <= factors_beta <= 1.0:
            raise ValueError(f"Invalid factors_beta value: {factors_beta}")

        params = list(params)
        # TODO(bla): Global factors buffer. Generalize to multiple parameter groups!

        dev_cfg = dict(device=params[0].device, dtype=params[0].dtype)

        avg_factory = MeanFactory(groups_spec).to(**dev_cfg)
        cov_factory = CovFactory(groups_spec, block_diag_only=block_diag).to(**dev_cfg)

        defaults = dict(
            lr=lr,
            weight_decay=weight_decay,
            grads_beta=grads_beta,
            factors_beta=factors_beta,
            groups_spec=groups_spec,
            grads_bias_corr=grads_bias_corr,
            factors_bias_corr=factors_bias_corr,
            update_correction=update_correction,
            update_avg=update_avg,
            nesterov=nesterov,
        )

        super().__init__(params, defaults)
        self.avg_factory = avg_factory
        self.cov_factory = cov_factory
        self.step_t = None

    def _init_group(
        self,
        group: MutableMapping,
    ):
        params_with_grad: list[NDArray] = []
        grads: list[NDArray] = []
        grad_momentum_bufs: list[NDArray] = []

        for p in group["params"]:
            if p.grad is None:
                raise RuntimeError(
                    "Symo requires gradients to be finite for all parameters"
                )

            if torch.is_complex(p):
                raise RuntimeError("Symo does not support complex parameters")
            if p.grad.is_sparse:
                raise RuntimeError("Symo does not support sparse gradients")

            params_with_grad.append(p)
            grads.append(p.grad)

            state = self.state[p]

            if "momentum_buffer" not in state:
                state["momentum_buffer"] = torch.zeros_like(
                    p.grad, memory_format=torch.preserve_format
                )

            grad_momentum_bufs.append(state["momentum_buffer"])

        if self.step_t is None:
            self.step_t = torch.tensor(0.0, dtype=p.dtype, device=p.device)

        return params_with_grad, grads, grad_momentum_bufs

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            grads_beta = group["grads_beta"]
            grads_corr = group["grads_bias_corr"]
            factors_beta = group["factors_beta"]
            factors_corr = group["factors_bias_corr"]
            update_correction = group["update_correction"]
            updates_avg = group["update_avg"]
            nesterov = group["nesterov"]
            cov_factory = self.cov_factory
            avg_factory = self.avg_factory

            group_variables = self._init_group(group)
            params, grads, grads_buf = group_variables

            self._symo_update(
                params,
                grads,
                grads_buf,
                avg_factory,
                cov_factory,
                self.step_t,
                lr,
                weight_decay=weight_decay,
                grads_beta=grads_beta,
                factors_beta=factors_beta,
                grads_corr=grads_corr,
                factors_corr=factors_corr,
                updates_corr=update_correction,
                updates_avg=updates_avg,
                nesterov=nesterov,
            )

        return loss

    def _symo_update(
        self,
        params: Sequence[NDArray | nn.Parameter],
        grads: Sequence[NDArray],
        grads_buf: Sequence[NDArray],
        avg_buf: MeanFactory,
        cov_factory: CovFactory,
        step: float | NDArray,
        lr: float | NDArray,
        weight_decay: float | NDArray,
        grads_beta: float | NDArray,
        factors_beta: float | NDArray,
        grads_corr: bool,
        factors_corr: bool,
        updates_corr: bool,
        updates_avg: bool,
        nesterov: bool,
    ):
        """Core Symo update logic."""
        step += 1

        apply_momentum(grads_buf, grads, grads_beta)

        new_grads = grads_buf
        if nesterov:
            new_grads = apply_nesterov(new_grads, grads, grads_beta)
        
        new_grads = normalize(new_grads, diag=cov_factory.block_diag_only)
        updates = newton_schulz_iter(
            new_grads, cov_factory
        )

        lr = adjust_lr(lr, cov_factory.groups_spec, params)
        weight_decay_params(lr, params, weight_decay)
        update_with_lr(lr, params, updates)

def newton_schulz_iter(
    init: Sequence[NDArray],
    cov_factory: CovFactory,
    ns: int = 5,
    ns_coefficients: tuple[float, float, float] = (3.4445, -4.775, 2.0315),
):
    """Run ``ns`` Newton–Schulz orthogonalization iterations on ``init``.

    Each iteration refreshes ``cov_factory`` from the current iterate, applies
    ``cov_factory.matvec`` once and twice, and combines the three lists with
    ``polynomial`` using ``ns_coefficients``.

    Args:
        init: Starting tensors, one per parameter.
        cov_factory: Object providing ``outer_update`` and ``matvec``.
        ns: Number of iterations.
        ns_coefficients: Weights of the iterate, the single ``matvec`` result
            and the double ``matvec`` result, in that order.

    Returns:
        The iterate after ``ns`` iterations (``init`` itself when ``ns`` is 0).
    """
    coef_first, coef_third, coef_fifth = ns_coefficients
    current = init
    for _ in range(ns):
        # The factory must be refreshed before either matvec call.
        cov_factory.outer_update(current)
        cubic = cov_factory.matvec(current)
        quintic = cov_factory.matvec(cubic)
        current = polynomial(
            current, cubic, quintic, a=coef_first, b=coef_third, c=coef_fifth
        )
    return current

def polynomial(
    first, third, fifth,
    a, b, c
):
    """Return ``a * x + b * y + c * z`` for each aligned triple of the inputs.

    The three sequences are walked in lockstep; the result is as long as the
    shortest of them.
    """
    return [
        a * lin + b * cub + c * quint
        for lin, cub, quint in zip(first, third, fifth)
    ]

def normalize(grads, eps=1e-7, diag=True):
    """Rescale tensors to unit norm.

    Args:
        grads: Tensors to rescale; they are not modified.
        eps: Lower bound applied to the norm before dividing.
        diag: If true, each tensor is divided by its own norm. Otherwise all
            tensors are divided by the norm of the per-tensor norms.

    Returns:
        A new list of rescaled tensors.
    """
    if not diag:
        # One shared scale: the norm of the stacked per-tensor norms.
        per_tensor = torch.stack([g.norm() for g in grads])
        scale = per_tensor.norm().clamp(min=eps)
        return [g / scale for g in grads]

    # Each tensor is scaled independently.
    return [g / g.norm().clamp(min=eps) for g in grads]


def apply_momentum(
    buffer: Sequence[NDArray],
    new_values: Sequence[NDArray],
    beta: float | NDArray,
):
    """Apply momentum."""

    for i, buf in enumerate(buffer):
        new_val = new_values[i]
        buf.lerp_(new_val, 1 - beta)

def apply_nesterov(
    buffer: Sequence[NDArray],
    new_values: Sequence[NDArray],
    beta: float | NDArray,
):
    """Apply nesterov."""
    updates = []

    for i, grad in enumerate(new_values):
        updates.append(
            grad.lerp(buffer[i], beta)
        )
    return updates

def apply_bias(
    values: Sequence[NDArray],
    beta: float | NDArray,
    step: float | NDArray,
):
    """Apply bias correction."""

    bias_corr = 1 - beta**step
    updates = []

    for val in values:
        val_corr = val / bias_corr
        updates.append(val_corr)

    return updates


def values_diff(lhs: Sequence[NDArray], rhs: Sequence[NDArray]) -> Sequence[NDArray]:
    """Return the element-wise differences ``lhs[i] - rhs[i]`` as a list.

    The result is as long as the shorter of the two inputs.
    """
    diffs = []
    for left, right in zip(lhs, rhs):
        diffs.append(left - right)
    return diffs

def update_with_lr(lr: float | list, params, updates):
    for i, p in enumerate(params):
        u = updates[i]
        p.sub_(u, alpha=lr[i])

def weight_decay_params(lr: float | list, params, decay = 0.1):
    for i,p in enumerate(params):
        p.mul_(1 - lr[i] * decay)

def adjust_lr(lr, groups_spec, params):
    """Compute one learning rate per parameter.

    ``groups_spec.groups`` and ``params`` are paired by position. For each
    pair the base ``lr`` is multiplied by ``sqrt(max(1, A / B))``, where
    ``A`` and ``B`` are the first two dimensions of the parameter, and divided
    by the square root of the product of ``groups_spec.dim_sizes[dim]`` over
    the group's ``"<name>_<dim>"`` labels whose name is not ``"I"``.

    Args:
        lr: Base learning rate.
        groups_spec: Object exposing ``groups`` and ``dim_sizes``.
        params: Parameters with at least two dimensions.

    Returns:
        A list of adjusted learning rates, one per (group, parameter) pair.
    """
    dim_sizes = groups_spec.dim_sizes

    per_param_lr = []
    for group, param in zip(groups_spec.groups, params):
        # Sizes of every non-"I" factor in this group.
        extents = [
            dim_sizes[dim]
            for name, dim in (label.split("_") for label in group)
            if name != "I"
        ]
        factor_size = np.prod(extents)  # compensate for Cinv

        rows, cols = param.shape[:2]
        shape_ratio = np.sqrt(max(1, rows / cols))

        per_param_lr.append(lr * shape_ratio / np.sqrt(factor_size))
    return per_param_lr
