from typing import Tuple
# import math
import torch
from omegaconf import DictConfig, OmegaConf

from pado.core.base.optimizer import PadoOptimizer
from pado.optim import register_optimizer

# Public API of this module: only the ``Adam`` optimizer is exported by
# ``from <module> import *``.
__all__ = ["Adam"]


@register_optimizer("Adam")
class Adam(PadoOptimizer):
    """Adam optimizer with optional decoupled weight decay, gradient
    centralization and Lookahead.

    Per-parameter update, with ``t`` the number of steps taken for that
    parameter and ``g`` its (possibly centralized / L2-augmented) gradient::

        m_t = beta1 * m_{t-1} + (1 - beta1) * g
        v_t = beta2 * v_{t-1} + (1 - beta2) * g * g
        p  -= lr / (1 - beta1**t) * m_t / (sqrt(v_t / (1 - beta2**t)) + eps)

    Args:
        params: Parameters or parameter-group dicts; forwarded, together with
            the defaults below, to ``PadoOptimizer.__init__``.
        lr: Learning rate; must be >= 0.
        betas: Decay rates ``(beta1, beta2)`` of the running averages of the
            gradient and squared gradient; each must lie in [0, 1).
        eps: Term added to the denominator; must be > 0.
        weight_decay: Weight decay factor; must be >= 0 (0 disables it).
        decoupled: If True, weight decay shrinks the parameter directly
            (``p *= 1 - lr * weight_decay``, AdamW-style). Otherwise
            ``weight_decay * p`` is added to the gradient (L2 penalty).
        centralize: If True, for every parameter with more than one
            dimension, subtract ``self._compute_mean_over_dim(grad,
            dim=centralize_dim)`` from the gradient before the update.
        centralize_dim: ``dim`` argument passed to that mean helper.
        lookahead: If True, keep a slow copy of each parameter. Every
            ``lookahead_period`` steps, set
            ``p = lookahead_alpha * p + (1 - lookahead_alpha) * slow`` and
            then ``slow = p``.
        lookahead_period: Steps between Lookahead syncs; must be >= 1 when
            ``lookahead`` is enabled.
        lookahead_alpha: Weight of the current (fast) parameter in the
            interpolation; must lie in (0, 1) when ``lookahead`` is enabled.

    Raises:
        ValueError: If a hyper-parameter is out of range.
    """

    def __init__(self,
                 params,
                 lr: float = 1e-3,
                 betas: Tuple[float, float] = (0.9, 0.999),
                 eps: float = 1e-8,
                 weight_decay: float = 0.0,
                 *, decoupled: bool = False,
                 centralize: bool = False,  # gradient centralization
                 centralize_dim: int = 0,
                 lookahead: bool = False,
                 lookahead_period: int = 5,
                 lookahead_alpha: float = 0.5,
                 ) -> None:
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if eps <= 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if lookahead and (not (0.0 < lookahead_alpha < 1.0)):
            raise ValueError(f"Invalid lookahead alpha value: {lookahead_alpha}")
        # step() computes ``state["step"] % lookahead_period``: a period of 0
        # would raise ZeroDivisionError there, and negatives are meaningless.
        if lookahead and lookahead_period < 1:
            raise ValueError(f"Invalid lookahead period: {lookahead_period}")

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            decoupled=decoupled,
            centralize=centralize,
            centralize_dim=centralize_dim,
            lookahead=lookahead,
            lookahead_period=lookahead_period,
            lookahead_alpha=lookahead_alpha,
        )
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step over all parameter groups.

        Parameters whose ``grad`` is None are skipped. ``p.grad`` itself is
        never modified; only the parameters and the optimizer state are.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss. It is called with autograd enabled.

        Returns:
            The value returned by ``closure``, or None if no closure is given.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad; the closure needs autograd
            # to recompute gradients.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            beta1, beta2 = group["betas"]
            weight_decay = group["weight_decay"]
            eps = group["eps"]
            decoupled = group["decoupled"]
            centralize = group["centralize"]
            centralize_dim = group["centralize_dim"]
            lookahead = group["lookahead"]
            lookahead_period = group["lookahead_period"]
            lookahead_alpha = group["lookahead_alpha"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad

                # -------------------------------- #
                # Centralize
                # -------------------------------- #
                if centralize and (p.ndim > 1):
                    grad_mean = self._compute_mean_over_dim(grad, dim=centralize_dim)
                    # Out-of-place: the optimizer must not overwrite the
                    # caller-visible p.grad (e.g. used for logging/clipping).
                    grad = grad - grad_mean

                # -------------------------------- #
                # Weight decay
                # -------------------------------- #
                if weight_decay != 0:
                    if decoupled:
                        # AdamW-style: shrink the weights directly.
                        p.mul_(1 - lr * weight_decay)
                    else:
                        # L2 penalty folded into the gradient (out-of-place,
                        # again leaving p.grad untouched).
                        grad = grad.add(p, alpha=weight_decay)

                # -------------------------------- #
                # Gradient momentum
                # -------------------------------- #
                state = self.state[p]
                # Lazy state initialization on the first step with a gradient.
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)

                # Slow weights are created separately so that Lookahead can
                # also be enabled on a group whose state already exists.
                if lookahead and ("lookahead_weight" not in state):
                    state["lookahead_weight"] = torch.clone(p).detach()

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                state["step"] += 1

                # -------------------------------- #
                # Update
                # -------------------------------- #
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]

                variance_est = exp_avg_sq / bias_correction2
                denominator = variance_est.sqrt().add_(eps)

                step_size = lr / bias_correction1
                p.addcdiv_(exp_avg, denominator, value=-step_size)

                # -------------------------------- #
                # Lookahead
                # -------------------------------- #
                if lookahead and (state["step"] % lookahead_period == 0):
                    # Interpolate towards the slow weights, then resync them.
                    p.mul_(lookahead_alpha).add_(state["lookahead_weight"], alpha=1 - lookahead_alpha)
                    state["lookahead_weight"].copy_(p)

        return loss

    @classmethod
    def from_config(cls, cfg: DictConfig, params) -> "Adam":
        """Build an ``Adam`` optimizer from an OmegaConf config.

        Args:
            cfg: Config whose entries are passed as keyword arguments to
                ``__init__`` after converting it to a plain container with
                interpolations resolved.
            params: Parameters or parameter groups to optimize.

        Returns:
            A new ``Adam`` instance.
        """
        cfg = OmegaConf.to_container(cfg, resolve=True)
        return cls(params, **cfg)
