import math
from typing import List

import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer

class MultiTensorApply(object):
    """Callable wrapper that forwards a fixed ``chunk_size`` to a multi-tensor op.

    Calling an instance invokes
    ``op(chunk_size, noop_flag_buffer, tensor_lists, *args)`` and returns its
    result.
    """

    # Class-level flags; ``available`` is set to True once an instance exists.
    available = False
    warned = False

    def __init__(self, chunk_size):
        """Store ``chunk_size`` to pass as the first argument of every op call."""
        # Nothing here can fail to import, so the wrapper is always usable.
        MultiTensorApply.available = True
        self.chunk_size = chunk_size

    def __call__(self, op, noop_flag_buffer, tensor_lists, *args):
        """Run ``op`` with the stored chunk size prepended to the arguments."""
        return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args)

class SoftSignSGD(Optimizer):
    """SoftSignSGD optimizer.

    Each parameter keeps two float32 state buffers, ``exp_avg`` and
    ``exp_avg_abs``. In the pure-PyTorch (single-tensor and foreach) code
    paths, one step with gradient ``g`` computes::

        m = beta * m + (1 - beta) * g                   # exp_avg
        b = beta * b + (1 - beta) * |g| ** power        # exp_avg_abs
        p = p * (1 - lr * weight_decay)
        p = p - lr * (beta * m + (1 - beta) * g)
                / ((beta * b + (1 - beta) * |g| ** power) ** (1 / power) + eps)

    Args:
        params (iterable): parameters to optimize or dicts defining
            parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        beta (float, optional): decay rate of both running averages; must
            lie in [0, 1). (default: 0.9)
        eps (float, optional): term added to the denominator.
            (default: 1e-8)
        weight_decay (float, optional): decoupled weight decay factor; must
            be non-negative. (default: 0.0)
        power (float, optional): exponent applied to ``|g|`` before
            averaging; the averaged value is raised to ``1 / power``. Must
            be positive. (default: 1.0)
        foreach (bool, optional): if True would use torch._foreach
            implementation. It's faster but uses slightly more memory.
            (default: False)
        fused (bool, optional): whether the fused CUDA implementation from
            the optional ``fused_softsignsgd`` extension is used. All
            parameters with gradients must then be CUDA tensors.
            (default: False)
    """

    def __init__(self,
                 params,
                 lr=1e-3,
                 beta=0.9,
                 eps=1e-8,
                 weight_decay=0.0,
                 power=1.0,
                 foreach: bool = False,
                 fused: bool = False):
        if not 0.0 <= lr:
            raise ValueError('Invalid learning rate: {}'.format(lr))
        if not 0.0 <= eps:
            raise ValueError('Invalid epsilon value: {}'.format(eps))
        if not 0.0 <= beta < 1.0:
            raise ValueError('Invalid beta parameter at index 0: {}'.format(
                beta))
        if not 0.0 <= weight_decay:
            raise ValueError(
                'Invalid weight_decay value: {}'.format(weight_decay))
        # The update raises the averaged |g| ** power to 1 / power, so a
        # non-positive power is either a division by zero or meaningless.
        if not 0.0 < power:
            raise ValueError('Invalid power value: {}'.format(power))
        if fused:
            _check_fused_available()

        defaults = dict(lr=lr,
                        beta=beta,
                        eps=eps,
                        weight_decay=weight_decay,
                        power=power,
                        foreach=foreach,
                        fused=fused)
        super().__init__(params, defaults)

    @torch.no_grad()
    def restart_opt(self):
        """Reset every group's step counter and zero the running averages.

        Only parameters with ``requires_grad`` get (re)initialized state.
        """
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                if p.requires_grad:
                    state = self.state[p]
                    # Exponential moving average of gradient values.
                    state['exp_avg'] = torch.zeros_like(p, dtype=torch.float32)
                    # Exponential moving average of |grad| ** power.
                    state['exp_avg_abs'] = torch.zeros_like(
                        p, dtype=torch.float32)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or None.

        Raises:
            ValueError: if a group uses ``fused=True`` and any parameter with
                a gradient is not a CUDA tensor.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_abss = []
            # Assume the same step across the group to keep things simple;
            # a per-parameter step could be supported by making it a tensor
            # or passing a list into the kernel.
            group['step'] = group.get('step', 0) + 1

            for p in group['params']:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                grads.append(p.grad)

                state = self.state[p]
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p, dtype=torch.float32)
                    state['exp_avg_abs'] = torch.zeros_like(
                        p, dtype=torch.float32)

                exp_avgs.append(state['exp_avg'])
                exp_avg_abss.append(state['exp_avg_abs'])

            if not params_with_grad:
                continue

            kwargs = dict(
                params=params_with_grad,
                grads=grads,
                exp_avgs=exp_avgs,
                exp_avg_abss=exp_avg_abss,
                beta=group['beta'],
                lr=group['lr'],
                weight_decay=group['weight_decay'],
                eps=group['eps'],
                power=group['power'],
            )

            if group['fused']:
                # The fused kernels are CUDA-only: check the tensors'
                # actual device, not just whether some GPU exists.
                if not all(p.is_cuda for p in params_with_grad):
                    raise ValueError('Fused softsignsgd does not support CPU')
                if group['foreach']:
                    _fused_softsignsgd_multi_tensor(**kwargs)
                else:
                    _fused_softsignsgd_single_tensor(**kwargs)
            elif group['foreach']:
                _multi_tensor_softsignsgd(**kwargs)
            else:
                _single_tensor_softsignsgd(**kwargs)

        return loss
    
def _single_tensor_softsignsgd(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_abss: List[Tensor],
    *,
    beta: float,
    lr: float,
    weight_decay: float,
    eps: float,
    power: float,
):
    for i, param in enumerate(params):
        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_abs = exp_avg_abss[i]


        scale_grad = (1 - beta) * grad
        scale_grad_abs = (1 - beta) * ((grad.abs()).pow(power))

        exp_avg.mul_(beta).add_(scale_grad) # m_t
        exp_avg_abs.mul_(beta).add_(scale_grad_abs) # b_t


        if param.dtype == torch.float16:
            p_data_fp32 = param.data.float()
        else:
            p_data_fp32 = param
        
        p_data_fp32.mul_(1 - lr * weight_decay)

        numers = beta * exp_avg + scale_grad
        denoms = (beta * exp_avg_abs + scale_grad_abs).pow(1.0/power) + eps
        fractions = numers / denoms
    
        p_data_fp32.add_(fractions, alpha=-lr)
        
        if param.dtype == torch.float16:
            param.set_(p_data_fp32.half())
        


def _multi_tensor_softsignsgd(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_abss: List[Tensor],
    *,
    beta: float,
    lr: float,
    weight_decay: float,
    eps: float,
    power: float,
):
    if len(params) == 0:
        return

    # for memory saving, we use `neg_pre_grads`
    # to get some temp variable in a inplace way
    scale_grads = torch._foreach_mul(grads, (1 - beta))
    scale_grad_abss = torch._foreach_abs(grads)
    torch._foreach_pow_(scale_grad_abss, power)
    torch._foreach_mul_(scale_grad_abss, (1 - beta))

    torch._foreach_mul_(exp_avgs, beta)
    torch._foreach_add_(exp_avgs, scale_grads)  # m_t

    torch._foreach_mul_(exp_avg_abss, beta)
    torch._foreach_add_(exp_avg_abss, scale_grad_abss)  # b_t

    numers = torch._foreach_mul(exp_avgs, beta)
    torch._foreach_add_(numers, scale_grads)
    denoms = torch._foreach_mul(exp_avg_abss, beta)
    torch._foreach_add_(denoms, scale_grad_abss)
    torch._foreach_pow_(denoms, 1.0 / power)
    torch._foreach_add_(denoms, eps)
    
    fractions = torch._foreach_div(numers, denoms)


    if params[0].dtype == torch.float16:
        p_data_fp32s = []
        for param in params:
            p_data_fp32s.append(param.data.float())
    else:
        p_data_fp32s = params
    torch._foreach_mul_(p_data_fp32s, 1 - lr * weight_decay)
    torch._foreach_add_(p_data_fp32s, fractions, alpha=-lr)
    if params[0].dtype == torch.float16:
        for i, param in enumerate(params):
            param.set_(p_data_fp32s[i].half())


def _fused_softsignsgd_multi_tensor(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_abss: List[Tensor],
    *,
    beta: float,
    lr: float,
    weight_decay: float,
    eps: float,
    power: float,
):
    import fused_softsignsgd
    multi_tensor_applier = MultiTensorApply(2048 * 32)
    _dummy_overflow_buf = torch.cuda.IntTensor([0])
    multi_tensor_applier(
        fused_softsignsgd.softsignsgd_multi_tensor, _dummy_overflow_buf,
        [params, grads, exp_avgs, exp_avg_abss],
        beta, lr, weight_decay, eps, power)


def _fused_softsignsgd_single_tensor(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_abss: List[Tensor],
    *,
    beta: float,
    lr: float,
    weight_decay: float,
    eps: float,
    power: float,
):
    for i, param in enumerate(params):
        p_data_fp32 = param.data.float()
        out_p = param.data
        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_abs = exp_avg_abss[i]
        with torch.cuda.device(param.device):
            import fused_softsignsgd
            fused_softsignsgd.softsignsgd_single_tensor(
                p_data_fp32,
                out_p,
                grad,
                exp_avg,
                exp_avg_abs,
                beta,
                lr,
                weight_decay,
                eps, 
                power
            )




def _check_fused_available():
    try:
        import fused_softsignsgd
    except ImportError as exc:
        if torch.cuda.is_available():
            # The module should be available but isn't. Try to
            # help the user in this case.
            raise ImportError((
                str(exc)
                + (
                    '\nThis could be caused by not having compiled '
                    'the CUDA extension during package installation. '
                    'Please try to re-install the package with '
                    'the environment flag `FORCE_CUDA=1` set.'
                )
            ))
        else:
            raise ImportError(
                str(exc) + '\nFused softsignsgd does not support CPU.')