r"""Functional interface"""
from collections import defaultdict
import math
import torch
from torch import Tensor
from typing import List
import torch
from torch.optim import Optimizer

def _make_sparse(grad, grad_indices, values):
    """Build a sparse COO tensor shaped like ``grad`` from indices and values.

    Args:
        grad: tensor whose ``size()`` gives the shape of the result.
        grad_indices: index tensor passed to ``torch.sparse_coo_tensor``.
        values: value tensor passed to ``torch.sparse_coo_tensor``.

    Returns:
        ``torch.sparse_coo_tensor(grad_indices, values, grad.size())`` when
        both ``grad_indices`` and ``values`` hold at least one element;
        otherwise the result of ``torch.empty_like(grad)``.
    """
    has_entries = grad_indices.numel() != 0 and values.numel() != 0
    if has_entries:
        return torch.sparse_coo_tensor(grad_indices, values, grad.size())
    # Nothing to scatter: hand back a tensor laid out like ``grad``.
    return torch.empty_like(grad)


# def adagrad(params: List[Tensor],
#             grads: List[Tensor],
#             state_sums: List[Tensor],
#             state_steps: List[int],
#             lr: float,
#             weight_decay: float,
#             lr_decay: float,
#             eps: float):
#     r"""Functional API that performs Adagrad algorithm computation.

#     See :class:`~torch.optim.Adagrad` for details.
#     """

#     for (param, grad, state_sum, step) in zip(params, grads, state_sums, state_steps):
#         if weight_decay != 0:
#             if grad.is_sparse:
#                 raise RuntimeError("weight_decay option is not compatible with sparse gradients")
#             grad = grad.add(param, alpha=weight_decay)

#         clr = lr / (1 + (step - 1) * lr_decay)

#         if grad.is_sparse:
#             grad = grad.coalesce()  # the update is non-linear so indices must be unique
#             grad_indices = grad._indices()
#             grad_values = grad._values()
#             size = grad.size()

#             state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))
#             std = state_sum.sparse_mask(grad)
#             std_values = std._values().sqrt_().add_(eps)
#             param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)
#         else:
#             state_sum.addcmul_(grad, grad, value=1)
#             std = state_sum.sqrt().add_(eps)
#             param.addcdiv_(grad, std, value=-clr)


def adam(params: List[Tensor],
         grads: List[Tensor],
         exp_avgs: List[Tensor],
         exp_avg_sqs: List[Tensor],
         max_exp_avg_sqs: List[Tensor],
         state_steps: List[int],
         amsgrad: bool,
         beta1: float,
         beta2: float,
         lr: float,
         weight_decay: float,
         eps: float,
         statistic: defaultdict(dict)):
    r"""Functional Adam update that also counts momentum/gradient sign agreement.

    All tensor arguments are parallel lists indexed by parameter; ``params``,
    ``exp_avgs``, ``exp_avg_sqs`` and (with ``amsgrad``) ``max_exp_avg_sqs``
    are modified in place.

    Arguments:
        params: parameters to update.
        grads: gradient for each parameter.
        exp_avgs: first-moment running averages.
        exp_avg_sqs: second-moment running averages.
        max_exp_avg_sqs: running maxima of the second moment; only indexed
            when ``amsgrad`` is true.
        state_steps: step count per parameter, used for bias correction.
        amsgrad: normalise by the running maximum of the second moment.
        beta1: decay rate of the first moment.
        beta2: decay rate of the second moment.
        lr: learning rate.
        weight_decay: coefficient of ``param`` added to the gradient (skipped
            when zero).
        eps: constant added to the denominator.
        statistic: mapping of counters. Every existing key is set to 0 on
            entry; ``'positive_num_before_wd'``, ``'total_num'`` and
            ``'positive_num'`` are then incremented with ``+=``, so missing
            keys must default to a number (e.g. ``defaultdict(int)``).

    See :class:`~torch.optim.Adam` for details.
    """
    # Counters describe this call only, so clear whatever was there before.
    for key in statistic.keys():
        statistic[key] = 0

    for idx, param in enumerate(params):
        grad = grads[idx]
        first_moment = exp_avgs[idx]
        second_moment = exp_avg_sqs[idx]
        step = state_steps[idx]
        if amsgrad:
            running_max = max_exp_avg_sqs[idx]

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step

        # Agreement between the previous first moment and the raw gradient,
        # measured before weight decay is folded into the gradient.
        aligned_before = first_moment.mul(grad) > 0
        statistic['positive_num_before_wd'] += torch.sum(aligned_before).item()
        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # Update the first moment, then measure agreement again with the
        # refreshed moment and the (possibly decayed) gradient.
        first_moment.mul_(beta1).add_(grad, alpha=1 - beta1)
        aligned = first_moment.mul(grad) > 0
        statistic['total_num'] += aligned.numel()
        statistic['positive_num'] += aligned.sum().item()

        second_moment.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        if amsgrad:
            # Keep the element-wise maximum of every second moment seen so far
            # and normalise with it instead of the current value.
            torch.maximum(running_max, second_moment, out=running_max)
            normaliser = running_max
        else:
            normaliser = second_moment
        denom = (normaliser.sqrt() / math.sqrt(bias_correction2)).add_(eps)

        param.addcdiv_(first_moment, denom, value=-(lr / bias_correction1))

def cadam(params: List[Tensor],
          grads: List[Tensor],
          exp_avgs: List[Tensor],
          exp_avg_sqs: List[Tensor],
          max_exp_avg_sqs: List[Tensor],
          state_steps: List[int],
          amsgrad: bool,
          beta1: float,
          beta2: float,
          lr: float,
          weight_decay: float,
          eps: float,
          statistic: dict,
          counters: List[Tensor] = None,
          ):
    r"""Functional API that performs the CAdam update and records sign statistics.

    The moments are maintained exactly as in Adam, but the parameter step is
    applied only at coordinates where the refreshed first moment and the
    gradient have a strictly positive product; every other coordinate is left
    untouched for this step.

    All tensor arguments are parallel lists indexed by parameter; ``params``,
    ``exp_avgs``, ``exp_avg_sqs``, ``counters`` and (with ``amsgrad``)
    ``max_exp_avg_sqs`` are modified in place.

    Arguments:
        params: parameters to update.
        grads: gradient for each parameter.
        exp_avgs: first-moment running averages.
        exp_avg_sqs: second-moment running averages.
        max_exp_avg_sqs: running maxima of the second moment; only indexed
            when ``amsgrad`` is true.
        state_steps: step count per parameter, used for bias correction.
        amsgrad: normalise by the running maximum of the second moment.
        beta1: decay rate of the first moment.
        beta2: decay rate of the second moment.
        lr: learning rate.
        weight_decay: coefficient of ``param`` added to the gradient (skipped
            when zero).
        eps: constant added to the denominator.
        statistic: mapping of integer counters. Every existing key is set to 0
            on entry and the counters of this call are then added; missing
            keys are treated as 0.
        counters: per-parameter integer tensors. The magnitude is the number
            of consecutive calls the agreement mask has kept its current value
            and the sign tells which value (positive = agreeing). A stored 0
            is treated as "not agreeing". When ``None``, zero-initialised
            ``int32`` counters are allocated for this call only, so streak
            lengths do not carry over between calls.

    See :class:`~torch.optim.Adam` for details on the moment estimates.
    """
    # Counters describe this call only, so clear whatever was there before.
    for key in statistic:
        statistic[key] = 0

    if counters is None:
        # The default used to crash on ``counters[i]``; fall back to fresh,
        # call-local streak counters instead.
        counters = [torch.zeros_like(p, dtype=torch.int32) for p in params]

    for i, param in enumerate(params):
        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]
        counter = counters[i]

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step

        # Agreement between the previous first moment and the raw gradient,
        # measured before weight decay is folded into the gradient.
        agree_before_wd = (exp_avg.mul(grad) > 0).sum()
        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # Decay the first moment, then keep it only where it agrees in sign
        # with the gradient; the masked copy is what drives the update.
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        agrees = exp_avg.mul(grad) > 0
        exp_true = exp_avg.masked_fill(~agrees, 0)

        # Transition counts compare the new mask with the sign stored in the
        # counter *before* it is updated below.
        was_positive = counter > 0
        transitions = [
            (agrees & was_positive).sum(),    # pos -> pos
            (~agrees & ~was_positive).sum(),  # neg -> neg
            (agrees & ~was_positive).sum(),   # neg -> pos
            (~agrees & was_positive).sum(),   # pos -> neg
        ]

        # +1 / -1 per coordinate, built on the counter's own device and dtype.
        sign = agrees.to(counter.dtype) * 2 - 1
        flipped = agrees != was_positive
        # A flipped coordinate restarts its streak at +/-1; otherwise the
        # streak grows by one in the direction it already had.
        counter.copy_(torch.where(flipped, sign, counter + sign))

        # One host transfer for all tensor-valued counters of this parameter.
        (before_wd, pos_to_pos, neg_to_neg, neg_to_pos, pos_to_neg,
         duration_all, duration_pos, duration_neg, agree_count) = torch.stack([
            agree_before_wd,
            *transitions,
            counter.abs().sum(),
            counter.clamp(min=0).sum(),
            counter.clamp(max=0).sum().neg(),
            agrees.sum(),
        ]).tolist()

        step_stats = {
            'positive_num_before_wd': before_wd,
            'num_transitions_pos_to_pos': pos_to_pos,
            'num_transitions_neg_to_neg': neg_to_neg,
            'num_transitions_neg_to_pos': neg_to_pos,
            'num_transitions_pos_to_neg': pos_to_neg,
            'sum_duration_same_sign': duration_all,
            'sum_duration_positive': duration_pos,
            'sum_duration_negative': duration_neg,
            'total_num': agrees.numel(),
            'positive_num': agree_count,
        }
        for key, value in step_stats.items():
            statistic[key] = statistic.get(key, 0) + value

        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        if amsgrad:
            max_exp_avg_sq = max_exp_avg_sqs[i]
            # Maintains the maximum of all 2nd moment running avg. till now
            torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
            # Use the max. for normalizing running avg. of gradient
            denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
        else:
            denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

        step_size = lr / bias_correction1

        param.addcdiv_(exp_true, denom, value=-step_size)

# def cadamw(params: List[Tensor],
#          grads: List[Tensor],
#          exp_avgs: List[Tensor],
#          exp_avg_sqs: List[Tensor],
#          max_exp_avg_sqs: List[Tensor],
#          state_steps: List[int],
#          amsgrad: bool,
#          beta1: float,
#          beta2: float,
#          lr: float,
#          weight_decay: float,
#          eps: float,
#         decay_all: bool,):
#     r"""Functional API that performs Adam algorithm computation.

#     See :class:`~torch.optim.Adam` for details.
#     """

#     for i, param in enumerate(params):

#         grad = grads[i]
#         exp_avg = exp_avgs[i]
#         exp_avg_sq = exp_avg_sqs[i]
#         step = state_steps[i]

#         if amsgrad:
#             max_exp_avg_sq = max_exp_avg_sqs[i]

#         bias_correction1 = 1 - beta1 ** step
#         bias_correction2 = 1 - beta2 ** step

#         # Decay the first and second moment running average coefficient
        
#         exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
#         greater = exp_avg.mul(grad)
#         exp_true = torch.where(greater > 0, exp_avg, torch.zeros_like(exp_avg, memory_format=torch.preserve_format))
#         # m_true_accm = torch.where(greater > 0, torch.ones_like(exp_avg), torch.zeros_like(exp_avg))
#         exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
#         if amsgrad:
#             # Maintains the maximum of all 2nd moment running avg. till now
#             torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
#             # Use the max. for normalizing running avg. of gradient
#             denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
#         else:
#             denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

#         step_size = lr / bias_correction1

#         param.addcdiv_(exp_true, denom, value=-step_size)
#         if weight_decay != 0 and decay_all:
#             param.add_(param, alpha=-lr * weight_decay)
#         if weight_decay != 0 and not decay_all:
#             # if exp_true, then minus param*lr * weight_decay
#             # else do nothing in this dimension
#             param.add_(torch.where(greater > 0, param, torch.zeros_like(param)), alpha=-(lr * weight_decay))
            

# def agd(params: List[Tensor],
#          grads: List[Tensor],
#          exp_avgs: List[Tensor],
#          exp_avg_sqs: List[Tensor],
#          max_exp_avg_sqs: List[Tensor],
#          state_steps: List[int],
#          amsgrad: bool,
#          beta1: float,
#          beta2: float,
#          lr: float,
#          weight_decay: float,
#          eps: float,
#          statistic: defaultdict(dict)):
#     r"""Functional API that performs Adam algorithm computation.

#     See :class:`~torch.optim.Adam` for details.
#     """

#     for i, param in enumerate(params):

#         grad = grads[i]
#         exp_avg = exp_avgs[i]
#         exp_avg_sq = exp_avg_sqs[i]
#         step = state_steps[i]
#         if amsgrad:
#             max_exp_avg_sq = max_exp_avg_sqs[i]

#         bias_correction1 = 1 - beta1 ** step
#         bias_correction2 = 1 - beta2 ** step

#         if weight_decay != 0:
#             grad = grad.add(param, alpha=weight_decay)

#         # Decay the first and second moment running average coefficient
#         exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
#         greater = exp_avg.mul(grad)
#         exp_true = torch.where(greater > 0, exp_avg, torch.zeros_like(exp_avg, memory_format=torch.preserve_format))
#         m_true_accm = torch.where(greater > 0, torch.ones_like(exp_avg), torch.zeros_like(exp_avg))
#         # statistic['sum_up'] += torch.sum(torch.ones_like(m_true_accm, memory_format=torch.preserve_format))
#         # statistic['real_up'] += torch.sum(m_true_accm)
#         exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
#         if amsgrad:
#             # Maintains the maximum of all 2nd moment running avg. till now
#             torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
#             # Use the max. for normalizing running avg. of gradient
#             denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
#         else:
#             denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

#         step_size = lr / bias_correction1

#         param.addcdiv_(exp_avg, denom, value=-step_size)


class CAdam(Optimizer):
    r"""Implements CAdam algorithm.

    Per-parameter state (``step``, ``exp_avg``, ``exp_avg_sq``, optionally
    ``max_exp_avg_sq``, and an ``int32`` ``counter``) is created lazily and the
    numerical update is delegated to :func:`cadam`.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default: 0)
        amsgrad (boolean, optional): whether to normalise by the running
            maximum of the second moment (default: False)
        decay_all (boolean, optional): accepted but not stored or used by this
            optimizer (default: False)

    Attributes:
        statistic: ``defaultdict(int)`` holding the counters produced by the
            most recent :meth:`step`, summed over all parameter groups.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False, decay_all=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super().__init__(params, defaults)
        self.statistic = defaultdict(int)

    def __setstate__(self, state):
        super().__setstate__(state)
        # Groups restored from a state that predates the option get the default.
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.

        Raises:
            RuntimeError: if any parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # ``cadam`` zeroes ``self.statistic`` every time it is called, i.e. once
        # per parameter group. Bank each group's numbers here so that the
        # statistics left behind cover the whole step, not just the last group.
        totals = defaultdict(int)

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            max_exp_avg_sqs = []
            counters = []
            state_steps = []
            for p in group['params']:
                if p.grad is not None:
                    params_with_grad.append(p)
                    if p.grad.is_sparse:
                        raise RuntimeError('CAdam does not support sparse gradients, please consider SparseAdam instead')
                    grads.append(p.grad)

                    state = self.state[p]
                    # Lazy state initialization
                    if len(state) == 0:
                        state['step'] = 0
                        # Exponential moving average of gradient values
                        state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        # Exponential moving average of squared gradient values
                        state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        if group['amsgrad']:
                            # Maintains max of all exp. moving avg. of sq. grad. values
                            state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if 'counter' not in state:
                        # Signed streak counter; checked separately so that a
                        # restored state lacking it still works.
                        state['counter'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.int32)
                    exp_avgs.append(state['exp_avg'])
                    exp_avg_sqs.append(state['exp_avg_sq'])
                    counters.append(state['counter'])

                    if group['amsgrad']:
                        max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                    # update the steps for each param group update
                    state['step'] += 1
                    # record the step after step update
                    state_steps.append(state['step'])

            beta1, beta2 = group['betas']
            cadam(params_with_grad,
                  grads,
                  exp_avgs,
                  exp_avg_sqs,
                  max_exp_avg_sqs,
                  state_steps,
                  group['amsgrad'],
                  beta1,
                  beta2,
                  group['lr'],
                  group['weight_decay'],
                  group['eps'],
                  self.statistic,
                  counters
                  )
            for key, value in self.statistic.items():
                totals[key] += value

        # Update in place so external references to ``statistic`` stay valid.
        self.statistic.update(totals)
        return loss

class Adam(Optimizer):
    r"""Implements Adam algorithm with sign-agreement statistics.

    Per-parameter state (``step``, ``exp_avg``, ``exp_avg_sq`` and optionally
    ``max_exp_avg_sq``) is created lazily and the numerical update is
    delegated to :func:`adam`.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default: 0)
        amsgrad (boolean, optional): whether to normalise by the running
            maximum of the second moment (default: False)
        decay_all (boolean, optional): accepted but not stored or used by this
            optimizer (default: False)

    Attributes:
        statistic: ``defaultdict(int)`` holding the counters produced by the
            most recent :meth:`step`, summed over all parameter groups.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False, decay_all=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super().__init__(params, defaults)
        self.statistic = defaultdict(int)

    def __setstate__(self, state):
        super().__setstate__(state)
        # Groups restored from a state that predates the option get the default.
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.

        Raises:
            RuntimeError: if any parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        if len(self.statistic) == 0:
            self.statistic['total_num'] = 0
            self.statistic['positive_num'] = 0

        # ``adam`` zeroes ``self.statistic`` every time it is called, i.e. once
        # per parameter group. Bank each group's numbers here so that the
        # statistics left behind cover the whole step, not just the last group.
        totals = defaultdict(int)

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            max_exp_avg_sqs = []
            state_steps = []
            for p in group['params']:
                if p.grad is not None:
                    params_with_grad.append(p)
                    if p.grad.is_sparse:
                        # The message used to name CAdam, which is a different class.
                        raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                    grads.append(p.grad)

                    state = self.state[p]
                    # Lazy state initialization
                    if len(state) == 0:
                        state['step'] = 0
                        # Exponential moving average of gradient values
                        state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        # Exponential moving average of squared gradient values
                        state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        if group['amsgrad']:
                            # Maintains max of all exp. moving avg. of sq. grad. values
                            state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                    exp_avgs.append(state['exp_avg'])
                    exp_avg_sqs.append(state['exp_avg_sq'])

                    if group['amsgrad']:
                        max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                    # update the steps for each param group update
                    state['step'] += 1
                    # record the step after step update
                    state_steps.append(state['step'])

            beta1, beta2 = group['betas']
            adam(params_with_grad,
                 grads,
                 exp_avgs,
                 exp_avg_sqs,
                 max_exp_avg_sqs,
                 state_steps,
                 group['amsgrad'],
                 beta1,
                 beta2,
                 group['lr'],
                 group['weight_decay'],
                 group['eps'],
                 self.statistic
                 )
            for key, value in self.statistic.items():
                totals[key] += value

        # Update in place so external references to ``statistic`` stay valid.
        self.statistic.update(totals)
        return loss

def _cadamw(params: List[Tensor],
            grads: List[Tensor],
            exp_avgs: List[Tensor],
            exp_avg_sqs: List[Tensor],
            max_exp_avg_sqs: List[Tensor],
            state_steps: List[int],
            amsgrad: bool,
            beta1: float,
            beta2: float,
            lr: float,
            weight_decay: float,
            eps: float,
            decay_all: bool):
    r"""Apply one CAdamW update in place to every parameter in ``params``.

    The moments are maintained as in Adam. The parameter step uses the first
    moment only at coordinates where its product with the gradient is strictly
    positive. Weight decay is decoupled and applied after the step: to every
    coordinate when ``decay_all`` is true, otherwise only to the coordinates
    that were just updated.
    """
    for i, param in enumerate(params):
        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        agrees = exp_avg.mul(grad) > 0
        # Momentum is used only where it points the same way as the gradient.
        exp_true = exp_avg.masked_fill(~agrees, 0)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        if amsgrad:
            max_exp_avg_sq = max_exp_avg_sqs[i]
            # Maintains the maximum of all 2nd moment running avg. till now
            torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
            # Use the max. for normalizing running avg. of gradient
            denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
        else:
            denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

        step_size = lr / bias_correction1
        param.addcdiv_(exp_true, denom, value=-step_size)

        if weight_decay != 0:
            if decay_all:
                param.mul_(1 - lr * weight_decay)
            else:
                # Shrink only the coordinates that received an update above.
                param.add_(param.masked_fill(~agrees, 0), alpha=-(lr * weight_decay))


class CAdamW(Optimizer):
    r"""Implements CAdamW algorithm.

    The parameter step is applied only where the first moment and the gradient
    have a strictly positive product, and weight decay is decoupled from the
    gradient (applied directly to the parameters after the step).

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): decoupled weight decay coefficient
            (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
        decay_all (boolean, optional): if True, weight decay shrinks every
            coordinate; if False, only coordinates updated in the current step
            (default: False)

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False, decay_all=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad, decay_all=decay_all)
        super().__init__(params, defaults)
        # Kept for parity with the sibling optimizers; this class records nothing in it.
        self.statistic = defaultdict(int)

    def __setstate__(self, state):
        super().__setstate__(state)
        # Groups restored from a state that predates an option get its default.
        for group in self.param_groups:
            group.setdefault('amsgrad', False)
            group.setdefault('decay_all', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.

        Raises:
            RuntimeError: if any parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            max_exp_avg_sqs = []
            state_steps = []
            for p in group['params']:
                if p.grad is not None:
                    params_with_grad.append(p)
                    if p.grad.is_sparse:
                        # The message used to name CAdam, which is a different class.
                        raise RuntimeError('CAdamW does not support sparse gradients, please consider SparseAdam instead')
                    grads.append(p.grad)

                    state = self.state[p]
                    # Lazy state initialization
                    if len(state) == 0:
                        state['step'] = 0
                        # Exponential moving average of gradient values
                        state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        # Exponential moving average of squared gradient values
                        state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        if group['amsgrad']:
                            # Maintains max of all exp. moving avg. of sq. grad. values
                            state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                    exp_avgs.append(state['exp_avg'])
                    exp_avg_sqs.append(state['exp_avg_sq'])

                    if group['amsgrad']:
                        max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                    # update the steps for each param group update
                    state['step'] += 1
                    # record the step after step update
                    state_steps.append(state['step'])

            beta1, beta2 = group['betas']
            # The module-level ``cadamw`` this used to call exists only as
            # commented-out code, so the update lives in the helper above.
            _cadamw(params_with_grad,
                    grads,
                    exp_avgs,
                    exp_avg_sqs,
                    max_exp_avg_sqs,
                    state_steps,
                    group['amsgrad'],
                    beta1,
                    beta2,
                    group['lr'],
                    group['weight_decay'],
                    group['eps'],
                    group['decay_all'],
                    )
        return loss
