import functools
import typing
from collections.abc import Callable
from typing import Any

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.modules.batchnorm import _BatchNorm
from torch.optim import Optimizer
from torch.optim.optimizer import ParamsT

import vendor.muon

from .nn2 import LinearEmbeddingsPack, ParameterPack, get_pack_size, make_keep_pack_idx


# ――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――
# Utilities
# ――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――
def _is_valid_lr(value: float) -> bool:
    """Return whether `value` is an acceptable learning rate (non-negative)."""
    return 0.0 <= value


def _is_valid_momentum(value: float) -> bool:
    """Return whether `value` lies in the half-open interval [0, 1)."""
    return value >= 0.0 and value < 1.0


def _is_valid_beta(value: float) -> bool:
    """Return whether `value` is an acceptable beta coefficient.

    Betas are checked with exactly the same rule as momentum, so the check
    is delegated to `_is_valid_momentum`.
    """
    is_valid = _is_valid_momentum(value)
    return is_valid


def _is_valid_weight_decay(value: float) -> bool:
    """Return whether `value` is an acceptable weight decay (non-negative)."""
    return 0.0 <= value


def _is_valid_eps(value: float) -> bool:
    """Return whether `value` is an acceptable epsilon (strictly positive)."""
    return 0.0 < value


def default_zero_weight_decay_condition(
    module_name: str, module: nn.Module, parameter_name: str, parameter: nn.Parameter
):
    """Decide whether a parameter should be excluded from weight decay.

    The decision of the generic `deep.default_zero_weight_decay_condition` is
    used when it is truthy; otherwise, the result is whether `module` is
    a `LinearEmbeddingsPack`.
    """
    # The import is local to the function, as in the original layout.
    from .. import deep as libdeep

    base_decision = libdeep.default_zero_weight_decay_condition(
        module_name, module, parameter_name, parameter
    )
    if base_decision:
        return base_decision
    return isinstance(module, LinearEmbeddingsPack)


# ――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――
# Conventional optimizers
# ――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――
class Signum(Optimizer):
    def __init__(
        self,
        params: ParamsT,
        *,
        lr: float,
        momentum: float = 0.9,
        weight_decay: float,
    ) -> None:
        assert _is_valid_lr(lr)
        assert _is_valid_momentum(momentum)
        assert _is_valid_weight_decay(weight_decay)
        defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):  # type: ignore
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            momentum = group['momentum']
            weight_decay = group['weight_decay']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data

                state = self.state[p]
                if len(state) == 0:
                    # Initialize the state.
                    state['exp_avg'] = torch.zeros_like(p)
                exp_avg = state['exp_avg']

                # Apply the decoupled weight decay.
                p.mul_(1 - lr * weight_decay)

                exp_avg.lerp_(grad, 1 - momentum)
                p.sub_(exp_avg.sign(), alpha=lr)

        return loss


class AdamWReference(torch.optim.Optimizer):
    """A minimal reference implementation of AdamW.

    The implementation is based on:
    https://github.com/pytorch/pytorch/blob/e2d141dbde55c2a4370fac5165b0561b6af4798b/torch/optim/adam.py#L344

    Thus, this optimizer produces _exactly_ the same (i.e. bitwise equivalent) results
    as `torch.optim.AdamW(..., foreach=False, fused=False)` as of `torch==2.7.1`.
    """

    def __init__(
        self,
        params: ParamsT,
        *,
        lr: float,
        betas: tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float,
    ):
        assert lr >= 0.0
        assert 0.0 <= betas[0] < 1.0
        assert 0.0 <= betas[1] < 1.0
        assert weight_decay >= 0.0

        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(  # type: ignore
        self,
        closure: None | Callable[[], Tensor] = None,
    ) -> None | Tensor:
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            beta1, beta2 = group['betas']
            eps = group['eps']
            weight_decay = group['weight_decay']

            for p in group['params']:
                p = typing.cast(Tensor, p)
                if p.grad is None:
                    continue

                grad = p.grad.data
                assert not grad.is_sparse, 'Sparse gradients are not supported'

                state = self.state[p]
                if len(state) == 0:
                    # Initialize the state.
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p.data)
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                exp_avg: Tensor = state['exp_avg']
                exp_avg_sq: Tensor = state['exp_avg_sq']

                state['step'] += 1

                # Apply the decoupled weight decay.
                if weight_decay != 0:
                    p.mul_(1 - lr * weight_decay)

                # Update biased first moment estimate.
                exp_avg.lerp_(grad, 1 - beta1)

                # Update biased second moment estimate.
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Bias correction
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                step_size = lr / bias_correction1
                denom = (exp_avg_sq.sqrt() / bias_correction2**0.5).add_(eps)

                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss


# ――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――
# Sharpness-aware optimizers
# ――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――
# This is an implementation of SAM and GSAM based on this file:
# https://github.com/juntang-zhuang/GSAM/blob/a770275b3324a3ed4bf4a1cddb976b2782ba852b/gsam/gsam.py
#
# Compared to the above source:
# - The math is NOT changed.
# - The original SAM is supported by setting alpha=None.
# - The API is closer to standard PyTorch optimizers.
# - Some of the original features (e.g. adaptive=True, multi-GPU training, etc.)
#   are not supported.
#
# NOTE: with AdamWSharpnessAware, the results are not deterministic for unknown reasons.

# A small constant added to denominators (gradient norms and their products)
# in the sharpness-aware computations below to avoid division by zero.
_SAM_EPS = 1e-12


def _disable_running_stats(model):
    """Set the momentum of all batch normalization modules of `model` to zero.

    The previous momentum of each such module is saved in its
    `backup_momentum` attribute.
    """

    def _freeze(module):
        if not isinstance(module, _BatchNorm):
            return
        module.backup_momentum = module.momentum  # type: ignore
        module.momentum = 0

    model.apply(_freeze)


def _enable_running_stats(model):
    """Restore the saved momentum of the batch normalization modules of `model`.

    Only modules that carry a `backup_momentum` attribute are affected.
    """

    def _unfreeze(module):
        if not isinstance(module, _BatchNorm):
            return
        if not hasattr(module, 'backup_momentum'):
            return
        module.momentum = module.backup_momentum  # type: ignore

    model.apply(_unfreeze)


class AdamWSharpnessAware(torch.optim.AdamW):
    def __init__(
        self, params: ParamsT, *, rho: float, alpha: None | float = None, **kwargs
    ) -> None:
        super().__init__(params, **kwargs)

        self.defaults['rho'] = rho
        for group in self.param_groups:
            group['rho'] = rho

        if alpha is not None:
            self._is_gsam = True
            self.defaults['alpha'] = alpha
            for group in self.param_groups:
                group['alpha'] = rho
        else:
            self._is_gsam = False

    def _assemble_gradient(self) -> Tensor:
        return torch.cat(
            [
                p.grad.flatten()
                for group in self.param_groups
                for p in group['params']
                if p.grad is not None
            ]
        )

    def _assemble_original_gradient(
        self, original_gradients: dict[torch.nn.Parameter, Tensor]
    ) -> Tensor:
        return torch.cat(
            [
                original_gradients[p].flatten()
                for group in self.param_groups
                for p in group['params']
                if p.grad is not None
            ]
        )

    @torch.no_grad()
    def _perturb_parameters(
        self,
    ) -> tuple[
        dict[torch.nn.Parameter, Tensor], None | dict[torch.nn.Parameter, Tensor]
    ]:
        original_parameters = {}
        original_gradients = {} if self._is_gsam else None

        grad_norm = self._assemble_gradient().norm()
        for group in self.param_groups:
            scale = group['rho'] / (grad_norm + _SAM_EPS)
            for p in group['params']:
                if p.grad is None:
                    continue
                original_parameters[p] = p.data.clone()
                if self._is_gsam:
                    assert original_gradients is not None
                    original_gradients[p] = p.grad.data.clone()
                e_w = p.grad * scale
                # Perturb the parameter (climb to the local maximum).
                p.add_(e_w)

        return original_parameters, original_gradients

    @torch.no_grad()
    def _restore_original_parameters(
        self, original_parameters: dict[torch.nn.Parameter, Tensor]
    ):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                p.data = original_parameters[p]

    @torch.no_grad()
    def _apply_gsam(self, original_gradients: dict[torch.nn.Parameter, Tensor]):
        grad = self._assemble_gradient()
        original_grad = self._assemble_original_gradient(original_gradients)

        grad_norm = grad.norm()
        original_grad_norm = original_grad.norm()

        cosine = torch.dot(grad, original_grad) / (
            grad_norm * original_grad_norm + _SAM_EPS
        )

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                # fmt: off
                vertical = (
                    original_gradients[p]
                    - cosine * original_grad_norm * p.grad / (grad_norm + _SAM_EPS)
                )
                # fmt: on
                p.grad.data.sub_(vertical, alpha=group['alpha'])

    def step(self, closure, model: torch.nn.Module) -> Tensor:  # type: ignore
        # Compute the original gradients.
        self.zero_grad()
        with torch.enable_grad():
            loss = closure()
        # Perturb the weights based on the original gradients.
        original_parameters, original_gradients = self._perturb_parameters()

        # Compute the gradients for the perturbed weights.
        _disable_running_stats(model)
        self.zero_grad()
        with torch.enable_grad():
            closure()
        if self._is_gsam:
            assert original_gradients is not None
            self._apply_gsam(original_gradients)

        # Restore the original weights and perform the actual optimization step.
        self._restore_original_parameters(original_parameters)
        super().step()
        _enable_running_stats(model)

        return loss


# ――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――
# Optimizer packs
# ――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――
def _is_shared_group_value(value) -> bool:
    """Tell whether a group value is common to all pack members.

    `None` and instances of the listed builtin types are treated as shared.
    """
    if value is None:
        return True
    shared_types = (bool, int, float, str, bytes, tuple, dict)
    return isinstance(value, shared_types)


class OptimizerPack(Optimizer):
    """The base class for optimizers operating on parameter packs.

    Default values that are not shared between pack members
    (see `_is_shared_group_value`) are converted to tensors, and the tensor
    values of each non-empty group are moved to the device of the first
    parameter of that group.
    """

    def __init__(self, params: ParamsT, defaults: dict[str, Any]) -> None:
        packed_defaults = {}
        for name, default in defaults.items():
            # Non-shared values are wrapped with `torch.tensor`
            # (which always allocates a new tensor).
            packed_defaults[name] = (
                default if _is_shared_group_value(default) else torch.tensor(default)
            )
        super().__init__(params, packed_defaults)

        for group in self.param_groups:
            members = group['params']
            if not members:
                continue
            assert all(isinstance(member, ParameterPack) for member in members), (
                'For now, only parameter packs are supported'
            )
            target_device = members[0].device
            for name in list(group):
                if name == 'params':
                    continue
                entry = group[name]
                if isinstance(entry, Tensor):
                    group[name] = entry.to(device=target_device)


def _make_weight_decay_multiplier(
    *, lr: float | Tensor, weight_decay: float | Tensor
) -> None | float | Tensor:
    return (
        None
        if (
            (isinstance(lr, float) and lr == 0.0)
            or (isinstance(weight_decay, float) and weight_decay == 0.0)
        )
        else (1 - lr * weight_decay)
    )


@typing.overload
def _maybe_unsqueeze(value: Tensor, *, p: Tensor) -> Tensor: ...


@typing.overload
def _maybe_unsqueeze[T](value: T, *, p: Tensor) -> T: ...


def _maybe_unsqueeze(value, *, p):
    return value[:, *((None,) * (p.ndim - 1))] if isinstance(value, Tensor) else value


class SignumPack(OptimizerPack):
    """Signum (sign of the gradient momentum) for parameter packs.

    Each hyperparameter is either a single number shared by all pack members
    or a list with one value per pack member. Per-member values are applied
    along the first dimension of each parameter pack.
    """

    def __init__(
        self,
        params: ParamsT,
        *,
        lr: float | list[float],
        momentum: float | list[float] = 0.9,
        weight_decay: float | list[float],
        pack_size: int,
    ) -> None:
        """Create the optimizer.

        Args:
            params: the parameter packs (or groups of them) to optimize.
            lr: the learning rate(s), non-negative.
            momentum: the momentum coefficient(s), in [0, 1).
            weight_decay: the decoupled weight decay coefficient(s), non-negative.
            pack_size: the number of pack members; all list-valued
                hyperparameters must have exactly this length.
        """
        assert pack_size > 0

        defaults = {}
        for key, value, is_valid_fn in [
            ('lr', lr, _is_valid_lr),
            ('momentum', momentum, _is_valid_momentum),
            ('weight_decay', weight_decay, _is_valid_weight_decay),
        ]:
            if isinstance(value, list):
                assert len(value) == pack_size
                assert all(map(is_valid_fn, value))
            else:
                # The result of the check must be asserted: a bare call
                # silently accepts invalid values.
                assert is_valid_fn(value)
            defaults[key] = value

        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):  # type: ignore
        """Perform one optimization step.

        If `closure` is given, it is evaluated with gradients enabled and its
        result is returned; otherwise `None` is returned.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            momentum = group['momentum']

            # `None` means that the weight decay step can be skipped.
            weight_decay_multiplier = _make_weight_decay_multiplier(
                lr=lr, weight_decay=group['weight_decay']
            )

            for p in group['params']:
                assert isinstance(p, ParameterPack), (
                    'For now, only parameter packs are supported'
                )
                if p.grad is None:
                    continue

                grad = p.grad.data

                # Reshapes per-member tensors for broadcasting against `p`;
                # plain numbers pass through unchanged.
                maybe_unsqueeze = functools.partial(_maybe_unsqueeze, p=p)

                state = self.state[p]
                if len(state) == 0:
                    # Initialize the state.
                    state['momentum_buffer'] = torch.zeros_like(p.data)
                momentum_buffer: Tensor = state['momentum_buffer']

                # Apply the decoupled weight decay.
                if weight_decay_multiplier is not None:
                    p.mul_(maybe_unsqueeze(weight_decay_multiplier))

                momentum_buffer.lerp_(grad, maybe_unsqueeze(1 - momentum))

                if isinstance(lr, float):
                    p.sub_(momentum_buffer.sign(), alpha=lr)
                else:
                    # `alpha` must be a number, so non-float learning rates
                    # are applied with an explicit multiplication.
                    p.sub_(momentum_buffer.sign().mul_(maybe_unsqueeze(lr)))

        return loss


class AdamWPack(OptimizerPack):
    """AdamW for parameter packs.

    Each numeric hyperparameter is either a single number shared by all pack
    members or a list with one value per pack member. Per-member values are
    applied along the first dimension of each parameter pack.
    """

    def __init__(
        self,
        params: ParamsT,
        *,
        lr: float | list[float],
        beta1: float | list[float] = 0.9,
        beta2: float | list[float] = 0.999,
        eps: float | list[float] = 1e-8,
        weight_decay: float | list[float],
        pack_size: int,
        shared_step: bool = False,
        follow_pytorch: bool = True,
    ):
        """Create the optimizer.

        Args:
            params: the parameter packs (or groups of them) to optimize.
            lr: the learning rate(s), non-negative.
            beta1: the coefficient(s) of the first moment average, in [0, 1).
            beta2: the coefficient(s) of the second moment average, in [0, 1).
            eps: the constant(s) added to the denominator, positive.
            weight_decay: the decoupled weight decay coefficient(s), non-negative.
            pack_size: the number of pack members; all list-valued
                hyperparameters must have exactly this length.
            shared_step: if True, the step counter is one Python integer per
                parameter pack; otherwise, it is an integer tensor with one
                entry per pack member.
            follow_pytorch: if True, then, whenever the involved values are
                floats, the second moment and the denominator are computed
                with the same operations as in `AdamWReference`.
        """
        assert pack_size > 0

        defaults = {}
        for key, value, is_valid_fn in [
            ('lr', lr, _is_valid_lr),
            ('beta1', beta1, _is_valid_beta),
            ('beta2', beta2, _is_valid_beta),
            ('eps', eps, _is_valid_eps),
            ('weight_decay', weight_decay, _is_valid_weight_decay),
        ]:
            if isinstance(value, list):
                assert len(value) == pack_size
                assert all(map(is_valid_fn, value))
            else:
                # The result of the check must be asserted: a bare call
                # silently accepts invalid values.
                assert is_valid_fn(value)
            defaults[key] = value

        super().__init__(params, defaults)
        self._shared_step = shared_step
        self._follow_pytorch = follow_pytorch

    @torch.no_grad()
    def step(  # type: ignore
        self,
        closure: None | Callable[[], Tensor] = None,
    ) -> None | Tensor:
        """Perform one optimization step.

        Args:
            closure: an optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The value returned by `closure`, or `None` if it is not provided.
        """
        loss = None
        if closure is not None:
            # `step` runs under `no_grad`, so gradients must be re-enabled for
            # the closure; otherwise, a `backward()` call inside it fails.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            beta1 = group['beta1']
            beta2 = group['beta2']
            eps = group['eps']

            # `None` means that the weight decay step can be skipped.
            weight_decay_multiplier = _make_weight_decay_multiplier(
                lr=lr, weight_decay=group['weight_decay']
            )

            for p in group['params']:
                assert isinstance(p, ParameterPack), (
                    'For now, only parameter packs are supported'
                )
                if p.grad is None:
                    continue

                grad = p.grad.data
                assert not grad.is_sparse, 'Sparse gradients are not supported'

                # Reshapes per-member tensors for broadcasting against `p`;
                # plain numbers pass through unchanged.
                maybe_unsqueeze = functools.partial(_maybe_unsqueeze, p=p)

                state = self.state[p]
                if len(state) == 0:
                    # Initialize the state.
                    state['step'] = (
                        0
                        if self._shared_step
                        else torch.zeros(
                            get_pack_size(p), dtype=torch.int64, device=p.device
                        )
                    )
                    state['exp_avg'] = torch.zeros_like(p.data)
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                exp_avg = state['exp_avg']
                exp_avg_sq = state['exp_avg_sq']

                state['step'] += 1

                # Apply the decoupled weight decay.
                if weight_decay_multiplier is not None:
                    p.mul_(maybe_unsqueeze(weight_decay_multiplier))

                # Update biased first moment estimate.
                exp_avg.lerp_(grad, maybe_unsqueeze(1 - beta1))

                # Update biased second raw moment estimate.
                if self._follow_pytorch and isinstance(beta2, float):
                    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                else:
                    exp_avg_sq.lerp_(grad.square(), maybe_unsqueeze(1 - beta2))

                # Perform the bias correction.
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                step_size = lr / bias_correction1

                if self._follow_pytorch and isinstance(bias_correction2, float):
                    denom = exp_avg_sq.sqrt() / bias_correction2**0.5
                else:
                    denom = exp_avg_sq.sqrt().div_(
                        maybe_unsqueeze(bias_correction2**0.5)
                    )
                denom.add_(maybe_unsqueeze(eps))

                if isinstance(step_size, float):
                    p.addcdiv_(exp_avg, denom, value=-step_size)
                else:
                    # `value` must be a number, so non-float step sizes
                    # are applied with an explicit multiplication.
                    p.sub_((exp_avg / denom).mul_(maybe_unsqueeze(step_size)))

        return loss


class MuonAdamWPack(OptimizerPack):
    """A combination of Muon and AdamW for parameter packs.

    Parameter groups with a truthy `'muon'` entry are updated by the Muon
    rule; all other groups are updated by the AdamW rule. By default,
    `'muon'` is False, so a group must opt in explicitly.

    Each numeric hyperparameter is either a single number shared by all pack
    members or a list with one value per pack member. Per-member values are
    applied along the first dimension of each parameter pack.
    """

    def __init__(
        self,
        params: ParamsT,
        *,
        lr: float | list[float],
        beta1: float | list[float] = 0.9,
        beta2: float | list[float] = 0.999,
        eps: float | list[float] = 1e-8,
        weight_decay: float | list[float],
        muon_lr: None | float | list[float],
        muon_momentum: float = 0.95,
        muon_weight_decay: None | float | list[float] = None,
        muon_ns_steps: int = 5,
        muon_nesterov: bool = True,
        pack_size: int,
        shared_step: bool = False,
        follow_pytorch: bool = True,
    ):
        """Create the optimizer.

        Args:
            params: the parameter packs (or groups of them) to optimize.
            lr: the AdamW learning rate(s), non-negative.
            beta1: the coefficient(s) of the first moment average, in [0, 1).
            beta2: the coefficient(s) of the second moment average, in [0, 1).
            eps: the constant(s) added to the AdamW denominator, positive.
            weight_decay: the decoupled weight decay coefficient(s), non-negative.
            muon_lr: the Muon learning rate(s); `None` means "use `lr`".
            muon_momentum: the Muon momentum coefficient, in [0, 1).
            muon_weight_decay: the Muon weight decay coefficient(s);
                `None` means "use `weight_decay`".
            muon_ns_steps: the `steps` argument of
                `vendor.muon.zeropower_via_newtonschulz5`.
            muon_nesterov: if True, the Muon update is built from the gradient
                interpolated towards the momentum buffer instead of
                the momentum buffer itself.
            pack_size: the number of pack members; all list-valued
                hyperparameters must have exactly this length.
            shared_step: if True, the AdamW step counter is one Python integer
                per parameter pack; otherwise, it is an integer tensor with
                one entry per pack member.
            follow_pytorch: if True, then, whenever the involved values are
                floats, the AdamW second moment and denominator are computed
                with the same operations as in `AdamWReference`.
        """
        assert pack_size > 0

        defaults: dict[str, Any] = {
            'muon': False,
            'muon_ns_steps': muon_ns_steps,
            'muon_nesterov': muon_nesterov,
        }
        for key, value, is_valid_fn in [
            ('lr', lr, _is_valid_lr),
            ('beta1', beta1, _is_valid_beta),
            ('beta2', beta2, _is_valid_beta),
            ('eps', eps, _is_valid_eps),
            ('weight_decay', weight_decay, _is_valid_weight_decay),
            ('muon_lr', muon_lr, _is_valid_lr),
            ('muon_momentum', muon_momentum, _is_valid_beta),
            ('muon_weight_decay', muon_weight_decay, _is_valid_weight_decay),
        ]:
            if isinstance(value, list):
                assert len(value) == pack_size
                assert all(map(is_valid_fn, value))
            elif value is not None:
                # The result of the check must be asserted: a bare call
                # silently accepts invalid values.
                assert is_valid_fn(value)
            defaults[key] = value

        super().__init__(params, defaults)
        self._shared_step = shared_step
        self._follow_pytorch = follow_pytorch

    def _step_muon(self, group: dict[str, Any]) -> None:
        """Update the parameters of `group` with the Muon rule.

        Must be called with gradients disabled (as `step` does).
        """
        lr = group['muon_lr']
        momentum = group['muon_momentum']
        weight_decay = group['muon_weight_decay']
        ns_steps = group['muon_ns_steps']
        nesterov = group['muon_nesterov']

        # Fall back to the AdamW hyperparameters when the Muon-specific
        # ones are not provided.
        if lr is None:
            lr = group['lr']
        if weight_decay is None:
            weight_decay = group['weight_decay']

        # NOTE: the resolved Muon weight decay must be used here,
        # not `group['weight_decay']`.
        weight_decay_multiplier = _make_weight_decay_multiplier(
            lr=lr, weight_decay=weight_decay
        )

        for p in group['params']:
            # 3 = 2 layer dimensions + 1 pack dimension
            assert p.ndim == 3
            assert isinstance(p, ParameterPack), (
                'For now, only parameter packs are supported'
            )
            if p.grad is None:
                continue

            grad = p.grad.data
            assert not grad.is_sparse, 'Sparse gradients are not supported'

            # Reshapes per-member tensors for broadcasting against `p`;
            # plain numbers pass through unchanged.
            maybe_unsqueeze = functools.partial(_maybe_unsqueeze, p=p)

            state = self.state[p]
            if len(state) == 0:
                # Initialize the state.
                state['muon_momentum_buffer'] = torch.zeros_like(p)
            momentum_buffer: Tensor = state['muon_momentum_buffer']

            # Apply the decoupled weight decay.
            if weight_decay_multiplier is not None:
                p.mul_(maybe_unsqueeze(weight_decay_multiplier))

            momentum_buffer.lerp_(grad, maybe_unsqueeze(1 - momentum))
            # NOTE: with `nesterov`, the gradient is modified in-place.
            update = (
                grad.lerp_(momentum_buffer, maybe_unsqueeze(momentum))
                if nesterov
                else momentum_buffer
            )
            update = vendor.muon.zeropower_via_newtonschulz5(update, steps=ns_steps)
            # Scale by sqrt(max(1, rows / columns)) of the layer matrix.
            update *= max(1, grad.size(-2) / grad.size(-1)) ** 0.5

            assert update.shape == p.shape
            if isinstance(lr, float):
                p.sub_(update, alpha=lr)
            else:
                # `alpha` must be a number, so non-float learning rates
                # are applied with an explicit multiplication.
                p.sub_(update.mul_(maybe_unsqueeze(lr)))

    def _step_adamw(self, group: dict[str, Any]) -> None:
        """Update the parameters of `group` with the AdamW rule.

        Must be called with gradients disabled (as `step` does).
        """
        lr = group['lr']
        beta1 = group['beta1']
        beta2 = group['beta2']
        eps = group['eps']

        # `None` means that the weight decay step can be skipped.
        weight_decay_multiplier = _make_weight_decay_multiplier(
            lr=lr, weight_decay=group['weight_decay']
        )

        for p in group['params']:
            assert isinstance(p, ParameterPack), (
                'For now, only parameter packs are supported'
            )
            if p.grad is None:
                continue

            grad = p.grad.data
            assert not grad.is_sparse, 'Sparse gradients are not supported'

            maybe_unsqueeze = functools.partial(_maybe_unsqueeze, p=p)

            state = self.state[p]
            if len(state) == 0:
                # Initialize the state.
                state['step'] = (
                    0
                    if self._shared_step
                    else torch.zeros(
                        get_pack_size(p), dtype=torch.int64, device=p.device
                    )
                )
                state['exp_avg'] = torch.zeros_like(p.data)
                state['exp_avg_sq'] = torch.zeros_like(p.data)
            exp_avg = state['exp_avg']
            exp_avg_sq = state['exp_avg_sq']

            state['step'] += 1

            # Apply the decoupled weight decay.
            if weight_decay_multiplier is not None:
                p.mul_(maybe_unsqueeze(weight_decay_multiplier))

            # Update biased first moment estimate.
            exp_avg.lerp_(grad, maybe_unsqueeze(1 - beta1))

            # Update biased second raw moment estimate.
            if self._follow_pytorch and isinstance(beta2, float):
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
            else:
                exp_avg_sq.lerp_(grad.square(), maybe_unsqueeze(1 - beta2))

            # Perform the bias correction.
            bias_correction1 = 1 - beta1 ** state['step']
            bias_correction2 = 1 - beta2 ** state['step']

            step_size = lr / bias_correction1

            if self._follow_pytorch and isinstance(bias_correction2, float):
                denom = exp_avg_sq.sqrt() / bias_correction2**0.5
            else:
                denom = exp_avg_sq.sqrt().div_(maybe_unsqueeze(bias_correction2**0.5))
            denom.add_(maybe_unsqueeze(eps))

            if isinstance(step_size, float):
                p.addcdiv_(exp_avg, denom, value=-step_size)
            else:
                # `value` must be a number, so non-float step sizes
                # are applied with an explicit multiplication.
                p.sub_((exp_avg / denom).mul_(maybe_unsqueeze(step_size)))

    @torch.no_grad()
    def step(  # type: ignore
        self,
        closure: None | Callable[[], Tensor] = None,
    ) -> None | Tensor:
        """Perform one optimization step.

        Args:
            closure: an optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The value returned by `closure`, or `None` if it is not provided.
        """
        loss = None
        if closure is not None:
            # `step` runs under `no_grad`, so gradients must be re-enabled for
            # the closure; otherwise, a `backward()` call inside it fails.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if group['muon']:
                self._step_muon(group)
            else:
                self._step_adamw(group)

        return loss


def optimizer_pack_remove(
    optimizer: torch.optim.Optimizer,
    pack_idx: Tensor,
    old_to_new: dict[ParameterPack, ParameterPack],
) -> None:
    """Remove pack members from the optimizer in-place.

    Args:
        optimizer: the optimizer to update.
        pack_idx: the (non-empty) indices of the pack members to remove.
        old_to_new: the (non-empty) mapping from the parameter packs currently
            registered in the optimizer to their replacements.

    All non-scalar tensors in the parameter groups and in the states of
    parameter packs are reduced to the kept pack members, and every parameter
    pack is replaced with its counterpart from `old_to_new`.
    """
    assert len(pack_idx) > 0
    assert old_to_new

    any_old_pack = next(iter(old_to_new))
    keep_pack_idx = make_keep_pack_idx(len(any_old_pack), pack_idx)

    def _keep_members(container) -> None:
        # Scalar (0-dim) tensors and non-tensor values are left untouched.
        for name in list(container):
            entry = container[name]
            if isinstance(entry, Tensor) and entry.ndim > 0:
                container[name] = entry[keep_pack_idx].clone()

    for group in optimizer.param_groups:
        _keep_members(group)

        members = group['params']
        for position, old_pack in enumerate(list(members)):
            if not isinstance(old_pack, ParameterPack):
                continue

            # The state can be absent even if the optimizer has already been
            # used. For example, in MLPBackbonePack, some of the blocks remain
            # unused (and thus don't have the corresponding optimizer states)
            # if the maximum allowed number of blocks is never used.
            state = optimizer.state.pop(old_pack, None)
            if state is not None:
                _keep_members(state)

            new_pack = old_to_new[old_pack]
            members[position] = new_pack
            if state is not None:
                optimizer.state[new_pack] = state
