r"""Functional interface"""
from collections import defaultdict
import math
import torch
from torch import Tensor
from typing import List

# TODO: use foreach API in optim.functional to do all the computation

def _make_sparse(grad, grad_indices, values):
    """Build a sparse COO tensor with the same size as ``grad``.

    When either ``grad_indices`` or ``values`` holds no elements, the result
    of ``torch.empty_like(grad)`` is returned instead of constructing a COO
    tensor from them.
    """
    nothing_to_place = 0 in (grad_indices.numel(), values.numel())
    if nothing_to_place:
        return torch.empty_like(grad)
    return torch.sparse_coo_tensor(grad_indices, values, grad.size())


def adagrad(params: List[Tensor],
            grads: List[Tensor],
            state_sums: List[Tensor],
            state_steps: List[int],
            lr: float,
            weight_decay: float,
            lr_decay: float,
            eps: float):
    r"""Functional API that performs Adagrad algorithm computation.

    Every tensor in ``params`` and ``state_sums`` is updated in place; the
    tensors in ``grads`` are not modified.

    See :class:`~torch.optim.Adagrad` for details.

    Args:
        params: tensors to update in place.
        grads: gradient for each entry of ``params``; may be sparse.
        state_sums: running sum of squared gradients, one per parameter.
        state_steps: step count per parameter, used to decay the learning rate.
        lr: base learning rate.
        weight_decay: multiple of ``param`` added to the gradient; ``0``
            disables it.
        lr_decay: learning-rate decay factor.
        eps: value added to the square root of the state sum before dividing.

    Raises:
        RuntimeError: if ``weight_decay`` is non-zero and a gradient is sparse.
    """

    for (param, grad, state_sum, step) in zip(params, grads, state_sums, state_steps):
        if weight_decay != 0:
            if grad.is_sparse:
                raise RuntimeError("weight_decay option is not compatible with sparse gradients")
            # Out-of-place add so the caller's gradient tensor is left untouched.
            grad = grad.add(param, alpha=weight_decay)

        # Learning rate for this step after decay.
        clr = lr / (1 + (step - 1) * lr_decay)

        if grad.is_sparse:
            grad = grad.coalesce()  # the update is non-linear so indices must be unique
            grad_indices = grad._indices()
            grad_values = grad._values()

            state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))
            # Restrict the accumulated sum to the entries present in this gradient.
            std = state_sum.sparse_mask(grad)
            std_values = std._values().sqrt_().add_(eps)
            param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)
        else:
            state_sum.addcmul_(grad, grad, value=1)
            std = state_sum.sqrt().add_(eps)
            param.addcdiv_(grad, std, value=-clr)


def adam(params: List[Tensor],
         grads: List[Tensor],
         exp_avgs: List[Tensor],
         exp_avg_sqs: List[Tensor],
         max_exp_avg_sqs: List[Tensor],
         state_steps: List[int],
         amsgrad: bool,
         beta1: float,
         beta2: float,
         lr: float,
         weight_decay: float,
         eps: float,
         statistic: defaultdict):
    r"""Functional API that performs Adam algorithm computation.

    See :class:`~torch.optim.Adam` for details.

    Besides the Adam update, this function records in ``statistic`` how often
    the updated first-moment estimate and the gradient share a sign. The
    counters do not influence the parameter update.

    Args:
        params: tensors to update in place.
        grads: gradient for each entry of ``params``; not modified.
        exp_avgs: first-moment estimates, updated in place.
        exp_avg_sqs: second-moment estimates, updated in place.
        max_exp_avg_sqs: running element-wise maxima of the second-moment
            estimates; only indexed (and updated in place) when ``amsgrad``
            is true.
        state_steps: step count per parameter, used for bias correction.
        amsgrad: normalise with the running maximum of the second moment
            instead of its current value.
        beta1: decay coefficient of the first-moment estimate.
        beta2: decay coefficient of the second-moment estimate.
        lr: learning rate.
        weight_decay: multiple of ``param`` added to the gradient; ``0``
            disables it.
        eps: value added to the denominator.
        statistic: mapping updated in place. ``statistic['sum_up']`` grows by
            the number of elements of each first-moment tensor and
            ``statistic['real_up']`` by the number of elements where the
            updated first moment times the gradient is strictly positive.
            Both keys are read by ``+=``, so they must already hold numeric
            values (or the mapping must supply a numeric default).
    """

    for i, param in enumerate(params):

        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]
        if amsgrad:
            max_exp_avg_sq = max_exp_avg_sqs[i]

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step

        if weight_decay != 0:
            # Out-of-place add so the caller's gradient tensor is left untouched.
            grad = grad.add(param, alpha=weight_decay)

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

        # Bookkeeping only: count elements whose momentum agrees in sign with
        # the gradient. The mask is computed once and not used for the step.
        same_sign = exp_avg.mul(grad) > 0
        m_true_accm = torch.where(same_sign, torch.ones_like(exp_avg), torch.zeros_like(exp_avg))
        statistic['sum_up'] += torch.sum(torch.ones_like(m_true_accm, memory_format=torch.preserve_format))
        statistic['real_up'] += torch.sum(m_true_accm)

        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        if amsgrad:
            # Maintains the maximum of all 2nd moment running avg. till now
            torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
            # Use the max. for normalizing running avg. of gradient
            denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
        else:
            denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

        step_size = lr / bias_correction1

        param.addcdiv_(exp_avg, denom, value=-step_size)

def cadam(params: List[Tensor],
         grads: List[Tensor],
         exp_avgs: List[Tensor],
         exp_avg_sqs: List[Tensor],
         max_exp_avg_sqs: List[Tensor],
         state_steps: List[int],
         amsgrad: bool,
         beta1: float,
         beta2: float,
         lr: float,
         weight_decay: float,
         eps: float,
         statistic: defaultdict):
    r"""Functional API for an Adam variant with a sign-masked parameter step.

    The moment estimates are updated exactly as in Adam (see
    :class:`~torch.optim.Adam`), but the parameter step uses the first-moment
    estimate only at elements where that estimate and the gradient have the
    same sign (their product is strictly positive). At all other elements the
    step is zero, so the parameter element is left unchanged.

    Args:
        params: tensors to update in place.
        grads: gradient for each entry of ``params``; not modified.
        exp_avgs: first-moment estimates, updated in place at every element.
        exp_avg_sqs: second-moment estimates, updated in place at every
            element.
        max_exp_avg_sqs: running element-wise maxima of the second-moment
            estimates; only indexed (and updated in place) when ``amsgrad``
            is true.
        state_steps: step count per parameter, used for bias correction.
        amsgrad: normalise with the running maximum of the second moment
            instead of its current value.
        beta1: decay coefficient of the first-moment estimate.
        beta2: decay coefficient of the second-moment estimate.
        lr: learning rate.
        weight_decay: multiple of ``param`` added to the gradient; ``0``
            disables it.
        eps: value added to the denominator.
        statistic: mapping updated in place. ``statistic['sum_up']`` grows by
            the number of elements of each first-moment tensor and
            ``statistic['real_up']`` by the number of elements that pass the
            sign mask. Both keys are read by ``+=``, so they must already
            hold numeric values (or the mapping must supply a numeric
            default).
    """

    for i, param in enumerate(params):

        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]

        if amsgrad:
            max_exp_avg_sq = max_exp_avg_sqs[i]

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step

        if weight_decay != 0:
            # Out-of-place add so the caller's gradient tensor is left untouched.
            grad = grad.add(param, alpha=weight_decay)

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

        # Elements where the updated momentum and the gradient agree in sign.
        same_sign = exp_avg.mul(grad) > 0
        # Momentum with disagreeing elements zeroed; this drives the step.
        exp_true = torch.where(same_sign, exp_avg, torch.zeros_like(exp_avg, memory_format=torch.preserve_format))
        m_true_accm = torch.where(same_sign, torch.ones_like(exp_avg), torch.zeros_like(exp_avg))
        statistic['sum_up'] += torch.sum(torch.ones_like(m_true_accm, memory_format=torch.preserve_format))
        statistic['real_up'] += torch.sum(m_true_accm)

        # The second moment is updated everywhere, including masked elements.
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        if amsgrad:
            # Maintains the maximum of all 2nd moment running avg. till now
            torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
            # Use the max. for normalizing running avg. of gradient
            denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
        else:
            denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

        step_size = lr / bias_correction1

        param.addcdiv_(exp_true, denom, value=-step_size)

def agd(params: List[Tensor],
         grads: List[Tensor],
         exp_avgs: List[Tensor],
         exp_avg_sqs: List[Tensor],
         max_exp_avg_sqs: List[Tensor],
         state_steps: List[int],
         amsgrad: bool,
         beta1: float,
         beta2: float,
         lr: float,
         weight_decay: float,
         eps: float,
         statistic: defaultdict):
    r"""Functional API that performs an Adam-style update under the name ``agd``.

    See :class:`~torch.optim.Adam` for details of the update rule.

    NOTE(review): as written, the computation here is the plain Adam update
    plus the sign-agreement counters in ``statistic``; nothing in this body
    differs numerically from Adam. Confirm whether a distinct AGD rule was
    intended.

    Args:
        params: tensors to update in place.
        grads: gradient for each entry of ``params``; not modified.
        exp_avgs: first-moment estimates, updated in place.
        exp_avg_sqs: second-moment estimates, updated in place.
        max_exp_avg_sqs: running element-wise maxima of the second-moment
            estimates; only indexed (and updated in place) when ``amsgrad``
            is true.
        state_steps: step count per parameter, used for bias correction.
        amsgrad: normalise with the running maximum of the second moment
            instead of its current value.
        beta1: decay coefficient of the first-moment estimate.
        beta2: decay coefficient of the second-moment estimate.
        lr: learning rate.
        weight_decay: multiple of ``param`` added to the gradient; ``0``
            disables it.
        eps: value added to the denominator.
        statistic: mapping updated in place. ``statistic['sum_up']`` grows by
            the number of elements of each first-moment tensor and
            ``statistic['real_up']`` by the number of elements where the
            updated first moment times the gradient is strictly positive.
            Both keys are read by ``+=``, so they must already hold numeric
            values (or the mapping must supply a numeric default).
    """

    for i, param in enumerate(params):

        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]
        if amsgrad:
            max_exp_avg_sq = max_exp_avg_sqs[i]

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step

        if weight_decay != 0:
            # Out-of-place add so the caller's gradient tensor is left untouched.
            grad = grad.add(param, alpha=weight_decay)

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

        # Bookkeeping only: count elements whose momentum agrees in sign with
        # the gradient. The mask is computed once and not used for the step.
        same_sign = exp_avg.mul(grad) > 0
        m_true_accm = torch.where(same_sign, torch.ones_like(exp_avg), torch.zeros_like(exp_avg))
        statistic['sum_up'] += torch.sum(torch.ones_like(m_true_accm, memory_format=torch.preserve_format))
        statistic['real_up'] += torch.sum(m_true_accm)

        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        if amsgrad:
            # Maintains the maximum of all 2nd moment running avg. till now
            torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
            # Use the max. for normalizing running avg. of gradient
            denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
        else:
            denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

        step_size = lr / bias_correction1

        param.addcdiv_(exp_avg, denom, value=-step_size)
