import torch
from omegaconf import DictConfig, OmegaConf

from pado.core.base.optimizer import PadoOptimizer
from pado.optim import register_optimizer

# Public API of this module: only the SGD optimizer is exported by `import *`.
__all__ = ["SGD"]


@register_optimizer("SGD")
class SGD(PadoOptimizer):
    """Stochastic gradient descent with optional momentum, weight decay,
    gradient centralization and Lookahead.

    Per parameter, one ``step()`` applies, in order:

    1. gradient centralization (parameters with ``ndim > 1`` only),
    2. weight decay (coupled L2 or decoupled),
    3. heavy-ball or Nesterov momentum,
    4. the update ``p -= lr * grad``,
    5. every ``lookahead_period`` steps, a Lookahead synchronization:
       ``p = alpha * p + (1 - alpha) * slow``, followed by ``slow = p``.

    Args:
        params: Parameters or parameter-group dicts, forwarded unchanged to
            ``PadoOptimizer.__init__``.
        lr: Learning rate; must be non-negative.
        momentum: Momentum factor in ``[0, 1)``; ``0`` disables momentum.
        dampening: Dampening for momentum. The incoming gradient is scaled
            by ``1 - dampening`` when accumulated into an existing buffer.
            The first buffer is a plain copy of the gradient.
        weight_decay: Weight-decay factor; must be non-negative.
        nesterov: Use Nesterov momentum. Requires ``momentum > 0`` and
            ``dampening == 0``.
        decoupled: If True, decay the parameters directly
            (``p *= 1 - lr * weight_decay``) instead of adding
            ``weight_decay * p`` to the gradient.
        centralize: If True, subtract the gradient's mean over
            ``centralize_dim`` for parameters with more than one dimension.
        centralize_dim: Dimension handed to ``_compute_mean_over_dim``.
            That helper is presumably provided by ``PadoOptimizer``; it is
            not defined in this class.
        lookahead: If True, keep Lookahead slow weights per parameter.
        lookahead_period: Steps between slow/fast synchronizations; must be
            ``>= 1`` when ``lookahead`` is True.
        lookahead_alpha: Interpolation weight of the fast weights during a
            synchronization, in ``(0, 1)``.

    Raises:
        ValueError: If a hyper-parameter is out of range.
    """

    def __init__(self,
                 params,
                 lr: float = 0.1,
                 momentum: float = 0.0,
                 dampening: float = 0.0,
                 weight_decay: float = 0.0,
                 nesterov: bool = False,
                 *, decoupled: bool = False,
                 centralize: bool = False,  # gradient centralization
                 centralize_dim: int = 0,
                 lookahead: bool = False,
                 lookahead_period: int = 5,
                 lookahead_alpha: float = 0.5,
                 ) -> None:
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not (0.0 <= momentum < 1.0):
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening.")
        if lookahead:
            if not (0.0 < lookahead_alpha < 1.0):
                raise ValueError(f"Invalid lookahead alpha value: {lookahead_alpha}")
            # step() evaluates ``state["step"] % lookahead_period``. A period
            # of 0 raises ZeroDivisionError there, and negative periods give
            # meaningless synchronization schedules, so reject them up front.
            if lookahead_period < 1:
                raise ValueError(f"Invalid lookahead period value: {lookahead_period}")

        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
            decoupled=decoupled,
            centralize=centralize,
            centralize_dim=centralize_dim,
            lookahead=lookahead,
            lookahead_period=lookahead_period,
            lookahead_alpha=lookahead_alpha,
        )
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step over every parameter group.

        Parameters whose ``.grad`` is ``None`` are skipped. ``p.grad`` itself
        is never modified; all gradient transformations are out of place.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is invoked with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was
            given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            momentum = group["momentum"]
            dampening = group["dampening"]
            weight_decay = group["weight_decay"]
            nesterov = group["nesterov"]
            decoupled = group["decoupled"]
            centralize = group["centralize"]
            centralize_dim = group["centralize_dim"]
            lookahead = group["lookahead"]
            lookahead_period = group["lookahead_period"]
            lookahead_alpha = group["lookahead_alpha"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                # -------------------------------- #
                # State (lazy init)
                # -------------------------------- #
                # Created before any parameter update so that the Lookahead
                # slow weights are a snapshot of the parameter as it was
                # before this first step. Decoupled weight decay below
                # modifies ``p`` in place, so the order matters.
                state = self.state[p]
                if "step" not in state:
                    state["step"] = 0
                if lookahead and ("lookahead_weight" not in state):
                    state["lookahead_weight"] = torch.clone(p).detach()

                grad = p.grad

                # -------------------------------- #
                # Centralize
                # -------------------------------- #
                # Out of place, consistent with the coupled weight-decay
                # branch below, so the caller's ``p.grad`` is left untouched.
                if centralize and (p.ndim > 1):
                    grad_mean = self._compute_mean_over_dim(grad, dim=centralize_dim)
                    grad = grad.sub(grad_mean)

                # -------------------------------- #
                # Weight decay
                # -------------------------------- #
                if weight_decay != 0:
                    if decoupled:
                        p.mul_(1 - lr * weight_decay)
                    else:
                        grad = grad.add(p, alpha=weight_decay)

                # -------------------------------- #
                # Gradient momentum
                # -------------------------------- #
                if momentum != 0:
                    if "momentum_buffer" not in state:
                        # The first buffer is an undampened copy of the
                        # (transformed) gradient.
                        buf = state["momentum_buffer"] = torch.clone(grad).detach()
                    else:
                        buf = state["momentum_buffer"]
                        buf.mul_(momentum).add_(grad, alpha=1 - dampening)

                    if nesterov:
                        grad = grad.add(buf, alpha=momentum)
                    else:
                        grad = buf

                state["step"] += 1

                # -------------------------------- #
                # Update
                # -------------------------------- #
                p.add_(grad, alpha=-lr)

                # -------------------------------- #
                # Lookahead
                # -------------------------------- #
                # Every ``lookahead_period`` steps:
                #   p    <- alpha * p + (1 - alpha) * slow
                #   slow <- p
                if lookahead and (state["step"] % lookahead_period == 0):
                    p.mul_(lookahead_alpha).add_(state["lookahead_weight"], alpha=1 - lookahead_alpha)
                    state["lookahead_weight"].copy_(p)

        return loss

    @classmethod
    def from_config(cls, cfg: DictConfig, params) -> "SGD":
        """Build an ``SGD`` optimizer from an OmegaConf config.

        Args:
            cfg: Config converted to a plain container with interpolations
                resolved. Its keys are passed as keyword arguments to
                ``__init__``, so any key that is not a constructor argument
                raises ``TypeError``.
            params: Parameters or parameter groups to optimize.

        Returns:
            A new ``SGD`` instance.
        """
        cfg = OmegaConf.to_container(cfg, resolve=True)
        return cls(params, **cfg)
