from typing import Tuple
import torch
import torch.nn.functional as F
import math
import collections


def centralize_gradient(x):
    """Subtract, in place, the per-slice mean from a multi-dimensional tensor.

    For tensors with more than one dimension, the mean is taken over every
    dimension except the first and subtracted from ``x``. Tensors with zero
    or one dimension are left untouched. Nothing is returned; ``x`` itself
    is modified.

    Credit: https://github.com/Yonghongwei/Gradient-Centralization
    """
    n_dims = x.dim()
    if n_dims <= 1:
        # Scalars and vectors have no trailing dimensions to average over.
        return

    trailing_dims = tuple(range(1, n_dims))
    slice_mean = x.mean(dim=trailing_dims, keepdim=True)
    x.data.sub_(slice_mean)


class SING(torch.optim.Optimizer):
    """Adam-style optimizer with optional gradient pre-processing and lookahead.

    Each step, per parameter:

    1. optionally centralizes the gradient (``grad_central``),
    2. optionally divides the gradient by its norm (``grad_norm``),
    3. updates Adam-style running averages of the gradient and its square,
    4. applies decoupled weight decay to parameters with more than one
       dimension,
    5. divides the averaged gradient by either a softplus-smoothed
       denominator (``softplus``) or the usual ``sqrt(v) + eps``.

    Optionally, every ``la_mergetime`` steps the parameters are interpolated
    with a stored "lookahead" copy of themselves.

    Note:
        Steps 1 and 2 modify ``p.grad`` in place.

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: learning rate.
        betas: decay coefficients for the running averages of the gradient
            and of its square.
        weight_decay: decoupled weight-decay coefficient (0 disables it).
        eps: small constant added to the gradient norm and, when
            ``softplus`` is off, to the denominator.
        grad_central: enable gradient centralization.
        grad_norm: enable gradient normalization.
        softplus: pass the denominator through softplus instead of adding
            ``eps``.
        softplus_beta: ``beta`` argument forwarded to ``F.softplus``.
        lookahead_active: enable the lookahead merge.
        la_mergetime: number of steps between lookahead merges.
        la_alpha: weight of the current parameters in the merge; the stored
            lookahead copy gets ``1 - la_alpha``.
    """

    def __init__(
        self,
        params,
        lr: float = 5e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        weight_decay: float = 0,
        eps: float = 1e-8,
        grad_central: bool = True,
        grad_norm: bool = True,
        softplus: bool = True,
        softplus_beta: int = 50,
        lookahead_active: bool = True,
        la_mergetime: int = 5,
        la_alpha: float = 0.5
    ):

        defaults = dict(
            lr=lr, betas=betas, eps=eps,
            weight_decay=weight_decay,
            grad_central=grad_central,
            grad_norm=grad_norm,
            softplus=softplus, softplus_beta=softplus_beta,
            lookahead_active=lookahead_active,
            la_mergetime=la_mergetime, la_alpha=la_alpha,
            # Per-group counter of steps since the last lookahead merge.
            la_step=0
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        """Restore optimizer state, back-filling ``la_step`` when absent."""
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('la_step', 0)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss. It is invoked with gradients enabled.
                A non-callable value is ignored.

        Returns:
            The value returned by ``closure``, or ``None`` when no callable
            closure was given.

        Raises:
            RuntimeError: if any parameter has a sparse gradient.
        """
        loss = None
        # `collections.Callable` was removed in Python 3.10; the builtin
        # `callable()` performs the same check (and is False for None).
        if callable(closure):
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group["weight_decay"]
            eps = group["eps"]
            lr = group["lr"]
            beta1, beta2 = group["betas"]

            for p in group["params"]:
                grad = p.grad
                if grad is None:
                    continue

                if grad.is_sparse:
                    raise RuntimeError("sparse matrix not supported atm")

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0

                    state["grad_ma"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    state["variance_ma"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

                    if group['lookahead_active']:
                        # Snapshot of the weights used by the lookahead merge.
                        state["lookahead_params"] = p.detach().clone()

                # Gradient centralization (in place on p.grad)
                if group["grad_central"]:
                    centralize_gradient(grad)

                # Gradient normalization (in place on p.grad)
                if group["grad_norm"]:
                    grad.div_(grad.norm() + eps)

                state["step"] += 1

                # Update the running mean and variance of the gradients (Adam)
                variance_ma = state["variance_ma"]
                grad_ma = state["grad_ma"]

                grad_ma.mul_(beta1).add_(grad, alpha=1 - beta1)
                variance_ma.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]
                step_size = lr / bias_correction1

                # Weight decay (decoupled like AdamW)
                # Only apply weight decay to weights: https://arxiv.org/pdf/1812.01187.pdf
                if weight_decay and p.dim() > 1:
                    p.mul_(1 - lr * weight_decay)

                # Computing the denominator (Adam); this is a fresh tensor,
                # so the in-place add below does not touch variance_ma.
                denom = variance_ma.sqrt() / math.sqrt(bias_correction2)

                # SAdam : https://arxiv.org/abs/1908.00700
                if group['softplus']:
                    denom = F.softplus(denom, beta=group["softplus_beta"])
                else:
                    denom.add_(eps)

                # Update the parameter
                p.addcdiv_(grad_ma, denom, value=-step_size)

        # lookahead
        for group in self.param_groups:
            if not group['lookahead_active']:
                continue

            group['la_step'] += 1
            if group['la_step'] < group['la_mergetime']:
                continue

            # Merge time reached: reset the counter and interpolate.
            group['la_step'] = 0
            la_alpha = group['la_alpha']

            for p in group["params"]:
                if p.grad is None:
                    continue

                slow_weights = self.state[p]["lookahead_params"]

                # p <- la_alpha * p + (1 - la_alpha) * slow_weights
                p.mul_(la_alpha).add_(slow_weights, alpha=1.0 - la_alpha)
                slow_weights.copy_(p)

        return loss
