#!/usr/bin/env python
# -*- coding: utf-8 -*-


"""AdaDQH v0
- Update based on adahd_v2_eps
- add weight_decouple, default to `True`
- add fixed_decay, default to `False`
"""


from xmlrpc.client import boolean
import numpy as np
import torch
from torch.optim.optimizer import Optimizer

from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union

from torch import Tensor

# Type aliases for optimizer signatures.

# Parameter specification: an iterable of tensors, or an iterable of
# string-keyed dicts.
Params = Union[Iterable[Tensor], Iterable[Dict[str, Any]]]

# Zero-argument callable returning a float, and its optional form.
LossClosure = Callable[[], float]
OptLossClosure = Optional[LossClosure]
# Pair of floats.
Betas2 = Tuple[float, float]
# String-keyed mapping with arbitrary values.
State = Dict[str, Any]
# A float or None.
OptFloat = Optional[float]
# Pair of floats (same shape as Betas2).
Nus2 = Tuple[float, float]

# Only the optimizer class is exported by a star-import of this module.
__all__ = ("AdaDQH",)


class AdaDQH(Optimizer):
    """Adam-style optimizer scaled by the change of the first moment.

    For every parameter two buffers are kept:

    * ``exp_avg``: exponential moving average (EMA) of the gradient.
    * ``exp_avg_sq``: EMA of ``h_t ** 2``, where ``h_t`` is the
      bias-corrected ``exp_avg`` on the first step and, on later steps, the
      difference between the current and the previous bias-corrected
      ``exp_avg``.

    The parameter update is::

        denom = max(sqrt(exp_avg_sq / bias_correction2), eps)
        p    -= lr * (exp_avg / bias_correction1) / denom

    Args:
        params: Iterable of tensors to optimize, or of dicts defining
            parameter groups.
        lr: Learning rate; must be strictly positive.
        betas: Decay rates ``(beta1, beta2)`` of the two EMAs, each in
            ``[0, 1)``.
        eps: Element-wise lower bound of the denominator (it is a floor,
            not an additive term); must be non-negative.
        weight_decay: Weight-decay coefficient; must be non-negative.
        weight_decouple: If ``True`` the parameters are shrunk directly
            before the update; if ``False`` ``weight_decay * p`` is added
            to the gradient instead.
        fixed_decay: Only used with ``weight_decouple=True``. If ``True``
            the shrink factor is ``1 - weight_decay``; otherwise it is
            ``1 - lr * weight_decay``.

    Raises:
        ValueError: If ``lr``, ``eps``, ``betas`` or ``weight_decay`` is
            outside the range given above.

    Note:
        ``weight_decouple`` and ``fixed_decay`` are stored on the optimizer
        instance, so they apply to every parameter group.
    """

    def __init__(
        self,
        params: Union[Iterable[Tensor], Iterable[Dict[str, Any]]],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-14,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
    ) -> None:
        if lr <= 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if eps < 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        self.weight_decouple = weight_decouple
        self.fixed_decay = fixed_decay
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss. It is called once, before any update.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure
            was given.

        Raises:
            RuntimeError: If a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            lr = group["lr"]
            eps = group["eps"]
            weight_decay = group["weight_decay"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        "AdaDQH does not support sparse gradients, "
                        "please consider SparseAdam instead"
                    )

                if weight_decay != 0:
                    if not self.weight_decouple:
                        # Out-of-place on purpose: p.grad must stay intact so
                        # a repeated step() (or gradient accumulation) does
                        # not compound the decay term.
                        grad = grad.add(p.data, alpha=weight_decay)
                    elif self.fixed_decay:
                        p.data.mul_(1.0 - weight_decay)
                    else:
                        p.data.mul_(1.0 - lr * weight_decay)

                state = self.state[p]
                # Lazy state initialization
                if not state:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    # Exponential moving average of squared h_t values
                    state["exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

                exp_avg = state["exp_avg"]
                exp_avg_sq = state["exp_avg_sq"]

                state["step"] += 1
                step = state["step"]
                bias_correction1 = 1 - beta1 ** step
                bias_correction2 = 1 - beta2 ** step

                # Bias-corrected first moment of the previous step; must be
                # captured before exp_avg is updated in place.
                prev_m_hat = None
                if step > 1:
                    prev_m_hat = exp_avg / (1 - beta1 ** (step - 1))

                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                h_t = exp_avg / bias_correction1
                if prev_m_hat is not None:
                    h_t.sub_(prev_m_hat)
                exp_avg_sq.mul_(beta2).addcmul_(h_t, h_t, value=1 - beta2)

                # Floor the denominator at eps. clamp_ works on whatever
                # device/dtype the parameter lives on, unlike a branch on
                # torch.cuda.is_available().
                denom = exp_avg_sq.sqrt().div_(float(np.sqrt(bias_correction2)))
                denom.clamp_(min=eps)

                # In-place update keeps the parameter's storage (and any
                # views sharing it) valid.
                update = exp_avg / denom.mul_(bias_correction1)
                p.data.add_(update, alpha=-lr)
        return loss
