import math
import torch as th
import numpy as np
from torch.optim.optimizer import Optimizer, required
from collections import OrderedDict, defaultdict

def Ranger(params, alpha=0.5, k=6, *args, **kwargs):
    """Create a Ranger optimizer: RAdam wrapped by Lookahead.

    Args:
        params: parameters or param-group dicts, forwarded to ``RAdam``.
        alpha: Lookahead slow-weight interpolation factor.
        k: number of RAdam steps between Lookahead slow updates.
        *args, **kwargs: forwarded unchanged to ``RAdam``.

    Returns:
        The ``Lookahead`` instance wrapping the newly built ``RAdam``.
    """
    inner_optimizer = RAdam(params, *args, **kwargs)
    wrapped = Lookahead(inner_optimizer, alpha=alpha, k=k)
    return wrapped

""" Lookahead Optimizer Wrapper.
Implementation modified from: https://github.com/alphadl/lookahead.pytorch
Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610
"""
class Lookahead(Optimizer):
    """Lookahead wrapper around another optimizer.

    Every call to :meth:`step` first runs ``base_optimizer.step``, which updates
    the fast parameters. Every ``k`` steps of a param group, each parameter of
    that group that has a gradient is pulled toward a slow copy with
    ``slow += alpha * (fast - slow)``. The fast parameter is then overwritten
    with the slow value.

    ``param_groups`` is the base optimizer's list, shared by reference. The
    lookahead hyper-parameters (``lookahead_alpha``, ``lookahead_k``,
    ``lookahead_step``) are stored in every group and also merged into the base
    optimizer's ``defaults``.

    Args:
        base_optimizer: optimizer that produces the fast updates.
        alpha: slow-weight interpolation factor, in ``[0, 1]``.
        k: number of fast steps between slow updates, ``>= 1``.

    Raises:
        ValueError: if ``alpha`` or ``k`` is out of range.
    """

    def __init__(self, base_optimizer, alpha=0.5, k=6):
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
        # Optimizer.__init__ is not called (param_groups are borrowed from the
        # base optimizer), so create the hook registries that torch's
        # step()/state_dict()/load_state_dict() machinery reads on newer versions.
        self._optimizer_step_pre_hooks = OrderedDict()
        self._optimizer_step_post_hooks = OrderedDict()
        self._optimizer_state_dict_pre_hooks = OrderedDict()
        self._optimizer_state_dict_post_hooks = OrderedDict()
        self._optimizer_load_state_dict_pre_hooks = OrderedDict()
        self._optimizer_load_state_dict_post_hooks = OrderedDict()
        self.base_optimizer = base_optimizer
        self.param_groups = self.base_optimizer.param_groups
        self.defaults = base_optimizer.defaults
        self.defaults.update(defaults)
        # Per-parameter slow buffers ('slow_buffer'), separate from the base
        # optimizer's state.
        self.state = defaultdict(dict)
        # Manually add our defaults to the param groups.
        for name, default in defaults.items():
            for group in self.param_groups:
                group.setdefault(name, default)

    def update_slow(self, group):
        """Interpolate the slow weights of ``group`` toward its fast weights.

        Parameters whose ``grad`` is ``None`` are skipped. A parameter's slow
        buffer is created lazily from its current fast value on the first call.
        """
        for fast_p in group["params"]:
            if fast_p.grad is None:
                continue
            param_state = self.state[fast_p]
            if 'slow_buffer' not in param_state:
                param_state['slow_buffer'] = fast_p.detach().clone()
            slow = param_state['slow_buffer']
            slow.add_(fast_p.data - slow, alpha=group['lookahead_alpha'])
            fast_p.data.copy_(slow)

    def sync_lookahead(self):
        """Apply a slow update to every param group now, ignoring ``lookahead_step``."""
        for group in self.param_groups:
            self.update_slow(group)

    def step(self, closure=None):
        """Run one base-optimizer step, then a slow update on groups that reached ``k`` steps.

        Returns:
            Whatever ``base_optimizer.step(closure)`` returns.
        """
        loss = self.base_optimizer.step(closure)
        for group in self.param_groups:
            group['lookahead_step'] += 1
            if group['lookahead_step'] % group['lookahead_k'] == 0:
                self.update_slow(group)
        return loss

    def _param_index_map(self):
        """Map ``id(param)`` to the packed index used by ``Optimizer.state_dict``.

        Mirrors torch's packing: the index is the flat position across all
        param groups, and the first occurrence of a parameter wins.
        """
        mapping = {}
        position = 0
        for group in self.param_groups:
            for p in group['params']:
                mapping.setdefault(id(p), position)
                position += 1
        return mapping

    def state_dict(self):
        """Return the base optimizer's state plus the slow buffers under ``'slow_state'``.

        Slow-state entries are keyed by the same packed parameter indices as the
        base state. This lets ``Optimizer.load_state_dict`` map them back to
        tensors when loading.
        """
        fast_state_dict = self.base_optimizer.state_dict()
        index_of = self._param_index_map()
        slow_state = {
            (index_of.get(id(k), id(k)) if isinstance(k, th.Tensor) else k): v
            for k, v in self.state.items()
        }
        return {
            'state': fast_state_dict['state'],
            'slow_state': slow_state,
            'param_groups': fast_state_dict['param_groups'],
        }

    def load_state_dict(self, state_dict):
        """Restore base-optimizer state and, when present, the slow buffers.

        A state dict without ``'slow_state'`` (saved from a plain optimizer) is
        accepted; in that case any missing lookahead keys are re-added to the
        param groups. The caller's ``state_dict`` is not modified.
        """
        fast_state_dict = {
            'state': state_dict['state'],
            'param_groups': state_dict['param_groups'],
        }
        self.base_optimizer.load_state_dict(fast_state_dict)

        slow_state_new = 'slow_state' not in state_dict
        if slow_state_new:
            print('Loading state_dict from optimizer without Lookahead applied.')
            slow_state = defaultdict(dict)
        else:
            slow_state = state_dict['slow_state']
        # Optimizer.load_state_dict remaps packed indices to parameters; it
        # requires 'param_groups', which is redundant here.
        super(Lookahead, self).load_state_dict({
            'state': slow_state,
            'param_groups': state_dict['param_groups'],
        })
        # The call above rebuilt self.param_groups; share the base list again.
        self.param_groups = self.base_optimizer.param_groups
        if slow_state_new:
            # Reapply defaults to catch missing lookahead-specific keys.
            for name, default in self.defaults.items():
                for group in self.param_groups:
                    group.setdefault(name, default)

"""
Liyuan Liu, Haoming Jiang, Pengcheng He, Weizhu Chen, Xiaodong Liu, Jianfeng Gao, and Jiawei Han (2020).
On the Variance of the Adaptive Learning Rate and Beyond. In the Eighth International Conference on Learning
Representations (ICLR 2020).
"""
class RAdam(Optimizer):
    """Rectified Adam optimizer.

    Runs Adam's adaptive update only once the rectification term ``N_sma``
    reaches 5. Before that, parameters are either left unchanged or, when
    ``degenerated_to_sgd`` is true, updated with bias-corrected momentum only.

    Args:
        params: parameters or param-group dicts.
        lr: learning rate (``>= 0``).
        betas: decay rates of the running averages of the gradient and its
            square; each must be in ``[0, 1)``.
        eps: value added to the denominator (``>= 0``).
        weight_decay: L2 factor, applied as ``p -= weight_decay * lr * p`` on
            steps that update ``p``.
        degenerated_to_sgd: if true, take a momentum step while ``N_sma < 5``.

    Raises:
        ValueError: if ``lr``, ``eps`` or ``betas`` is out of range.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16, weight_decay=0, degenerated_to_sgd=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        self.degenerated_to_sgd = degenerated_to_sgd
        # The default 'buffer' list object is shared by every group that does not
        # set one. Groups with their own betas get a private cache, because the
        # cached values depend on beta1 and beta2.
        if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
            for param in params:
                if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
                    param['buffer'] = [[None, None, None] for _ in range(10)]
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)])
        super(RAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RAdam, self).__setstate__(state)

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss.

        Returns:
            The value returned by ``closure``, or ``None``.

        Raises:
            RuntimeError: if a gradient is sparse.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            lr = group['lr']
            weight_decay = group['weight_decay']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                # fp32 working copy (the same tensor when p is already fp32), so
                # reduced-precision params keep fp32 optimizer state.
                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = th.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = th.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state['step'] += 1
                # Ten-slot cache of [step, N_sma, step_size]: params reaching the
                # same step reuse the values instead of recomputing them.
                buffered = group['buffer'][int(state['step'] % 10)]
                if state['step'] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # More conservative threshold, since N_sma is an approximation.
                    if N_sma >= 5:
                        step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    elif self.degenerated_to_sgd:
                        step_size = 1.0 / (1 - beta1 ** state['step'])
                    else:
                        step_size = -1  # sentinel: skip the update this step
                    buffered[2] = step_size

                if N_sma >= 5:
                    # Rectified adaptive step.
                    if weight_decay != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-weight_decay * lr)
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * lr)
                    p.data.copy_(p_data_fp32)
                elif step_size > 0:
                    # Variance not yet tractable: momentum-only (SGD-like) step.
                    if weight_decay != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-weight_decay * lr)
                    p_data_fp32.add_(exp_avg, alpha=-step_size * lr)
                    p.data.copy_(p_data_fp32)

        return loss


class SDRMSprop(Optimizer):
    """RMSprop whose step is scaled by the squared running average of gradient signs.

    For every parameter ``p`` with a gradient, let
    ``g = grad + weight_decay * p``. Then::

        v <- alpha * v + (1 - alpha) * g**2
        s <- beta * s + (1 - beta) * sign(g)
        p <- p - lr * s**2 * g / (sqrt(v) + eps)

    ``s`` stays in ``[-1, 1]``, so ``s**2`` shrinks the step for parameters whose
    gradient sign keeps flipping.

    Args:
        params: parameters or param-group dicts.
        lr: learning rate.
        alpha: decay rate of the squared-gradient average ``v``.
        beta: decay rate of the sign average ``s``.
        eps: value added to the denominator.
        weight_decay: L2 factor added to the gradient, as in ``torch.optim.RMSprop``.
    """

    def __init__(self, params, lr=1e-3, alpha=0.99, beta=0.9, eps=1e-8, weight_decay=0):

        defaults = dict(lr=lr, alpha=alpha, beta=beta, eps=eps, weight_decay=weight_decay)
        super(SDRMSprop, self).__init__(params, defaults)

    def reset_state(self):
        """Replace the ``v`` and ``s`` buffers of every parameter with zeros."""
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['v'] = th.zeros_like(p, memory_format=th.preserve_format)
                state['s'] = th.zeros_like(p, memory_format=th.preserve_format)

    @th.no_grad()
    def step(self, debug=False, closure=None):
        """Perform a single optimization step.

        Args:
            debug: unused. If a callable is passed here and ``closure`` is
                ``None``, it is treated as the closure. Wrappers such as
                ``Lookahead.step`` call ``step(closure)`` positionally.
            closure: optional callable that re-evaluates the model and returns
                the loss. It runs with gradients enabled.

        Returns:
            The value returned by the closure, or ``None``.
        """
        if closure is None and callable(debug):
            closure = debug
        loss = None
        if closure is not None:
            with th.enable_grad():
                loss = closure()

        for group in self.param_groups:
            a = group['alpha']
            b = group['beta']
            lr = group['lr']
            eps = group['eps']
            weight_decay = group['weight_decay']

            for p in group['params']:
                if p.grad is None:
                    continue

                g = p.grad
                if weight_decay != 0:
                    # Coupled L2 penalty; p.grad itself is left untouched.
                    g = g.add(p, alpha=weight_decay)

                state = self.state[p]

                # Lazy state initialization
                if len(state) == 0:
                    state['v'] = th.zeros_like(p, memory_format=th.preserve_format)
                    state['s'] = th.zeros_like(p, memory_format=th.preserve_format)

                # In-place versions of a*v + (1-a)*g**2 and b*s + (1-b)*sign(g).
                v = state['v'].mul_(a).add_((1 - a) * g**2)
                s = state['s'].mul_(b).add_((1 - b) * th.sign(g))

                p.add_(-1 * lr * s**2 * g / (th.sqrt(v) + eps))

        return loss



class SDAdam(Optimizer):
    """Adam whose step is scaled by the squared, bias-corrected running average of gradient signs.

    For every parameter ``p`` with a gradient, let
    ``g = grad + weight_decay * p``. With step count ``t``::

        m <- b1 * m + (1 - b1) * g
        v <- b2 * v + (1 - b2) * g**2
        s <- b3 * s + (1 - b3) * sign(g)
        p <- p - lr * s_hat**2 * m_hat / (sqrt(v_hat) + eps)

    where ``x_hat = x / (1 - b**t)`` is the bias-corrected average.

    Args:
        params: parameters or param-group dicts.
        lr: learning rate.
        betas: ``(b1, b2, b3)`` decay rates for ``m``, ``v`` and ``s``.
        eps: value added to the denominator.
        weight_decay: L2 factor added to the gradient, as in ``torch.optim.Adam``.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999, 0.9), eps=1e-8, weight_decay=0):

        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(SDAdam, self).__init__(params, defaults)

    def reset_state(self):
        """Reset the step count and the ``m``/``v``/``s`` buffers of every parameter."""
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step']  = 0
                state['m']     = th.zeros_like(p, memory_format=th.preserve_format)
                state['v']     = th.zeros_like(p, memory_format=th.preserve_format)
                state['s']     = th.zeros_like(p, memory_format=th.preserve_format)

    @th.no_grad()
    def step(self, debug=False, closure=None):
        """Perform a single optimization step.

        Args:
            debug: unused. If a callable is passed here and ``closure`` is
                ``None``, it is treated as the closure. Wrappers such as
                ``Lookahead.step`` call ``step(closure)`` positionally.
            closure: optional callable that re-evaluates the model and returns
                the loss. It runs with gradients enabled.

        Returns:
            The value returned by the closure, or ``None``.
        """
        if closure is None and callable(debug):
            closure = debug
        loss = None
        if closure is not None:
            with th.enable_grad():
                loss = closure()

        for group in self.param_groups:
            b1, b2, b3 = group['betas']
            lr = group['lr']
            eps = group['eps']
            weight_decay = group['weight_decay']

            for p in group['params']:
                if p.grad is None:
                    continue

                g = p.grad
                if weight_decay != 0:
                    # Coupled L2 penalty; p.grad itself is left untouched.
                    g = g.add(p, alpha=weight_decay)

                state = self.state[p]

                # Lazy state initialization
                if len(state) == 0:
                    state['step']  = 0
                    state['m']     = th.zeros_like(p, memory_format=th.preserve_format)
                    state['v']     = th.zeros_like(p, memory_format=th.preserve_format)
                    state['s']     = th.zeros_like(p, memory_format=th.preserve_format)

                state['step'] += 1
                t = state['step']

                # In-place exponential moving averages.
                m = state['m'].mul_(b1).add_((1 - b1) * g)
                v = state['v'].mul_(b2).add_((1 - b2) * g**2)
                s = state['s'].mul_(b3).add_((1 - b3) * th.sign(g))

                # Bias correction: all three averages start at zero.
                m_hat = m / (1 - b1**t)
                v_hat = v / (1 - b2**t)
                s_hat = s / (1 - b3**t)

                p.add_(-1 * lr * s_hat**2 * m_hat / (th.sqrt(v_hat) + eps))

        return loss

class SDAMSGrad(Optimizer):
    """AMSGrad whose step is scaled by the squared, bias-corrected running average of gradient signs.

    This is the same as ``SDAdam``, except that the denominator uses the running
    elementwise maximum ``v_max`` of the (uncorrected) second-moment average
    ``v``::

        v_max <- max(v_max, v)
        p <- p - lr * s_hat**2 * m_hat / (sqrt(v_max / (1 - b2**t)) + eps)

    Here ``g = grad + weight_decay * p`` feeds ``m``, ``v`` and ``s``.

    Args:
        params: parameters or param-group dicts.
        lr: learning rate.
        betas: ``(b1, b2, b3)`` decay rates for ``m``, ``v`` and ``s``.
        eps: value added to the denominator.
        weight_decay: L2 factor added to the gradient, as in ``torch.optim.Adam``.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999, 0.9), eps=1e-8, weight_decay=0):

        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(SDAMSGrad, self).__init__(params, defaults)

    def reset_state(self):
        """Reset the step count and the ``m``/``v``/``s``/``v_max`` buffers of every parameter."""
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step']  = 0
                state['m']     = th.zeros_like(p, memory_format=th.preserve_format)
                state['v']     = th.zeros_like(p, memory_format=th.preserve_format)
                state['s']     = th.zeros_like(p, memory_format=th.preserve_format)
                state['v_max'] = th.zeros_like(p, memory_format=th.preserve_format)

    @th.no_grad()
    def step(self, debug=False, closure=None):
        """Perform a single optimization step.

        Args:
            debug: unused. If a callable is passed here and ``closure`` is
                ``None``, it is treated as the closure. Wrappers such as
                ``Lookahead.step`` call ``step(closure)`` positionally.
            closure: optional callable that re-evaluates the model and returns
                the loss. It runs with gradients enabled.

        Returns:
            The value returned by the closure, or ``None``.
        """
        if closure is None and callable(debug):
            closure = debug
        loss = None
        if closure is not None:
            with th.enable_grad():
                loss = closure()

        for group in self.param_groups:
            b1, b2, b3 = group['betas']
            lr = group['lr']
            eps = group['eps']
            weight_decay = group['weight_decay']

            for p in group['params']:
                if p.grad is None:
                    continue

                g = p.grad
                if weight_decay != 0:
                    # Coupled L2 penalty; p.grad itself is left untouched.
                    g = g.add(p, alpha=weight_decay)

                state = self.state[p]

                # Lazy state initialization
                if len(state) == 0:
                    state['step']  = 0
                    state['m']     = th.zeros_like(p, memory_format=th.preserve_format)
                    state['v']     = th.zeros_like(p, memory_format=th.preserve_format)
                    state['s']     = th.zeros_like(p, memory_format=th.preserve_format)
                    state['v_max'] = th.zeros_like(p, memory_format=th.preserve_format)

                state['step'] += 1
                t = state['step']

                # In-place exponential moving averages.
                m = state['m'].mul_(b1).add_((1 - b1) * g)
                v = state['v'].mul_(b2).add_((1 - b2) * g**2)
                s = state['s'].mul_(b3).add_((1 - b3) * th.sign(g))
                v_max = state['v_max']
                th.max(v_max, v, out=v_max)

                # Bias correction: all averages start at zero.
                m_hat = m / (1 - b1**t)
                v_hat = v_max / (1 - b2**t)
                s_hat = s / (1 - b3**t)

                p.add_(-1 * lr * s_hat**2 * m_hat / (th.sqrt(v_hat) + eps))

        return loss


