import math
import torch


class SMAG(torch.optim.Optimizer):
    r"""
        Single loop algorithm for solving WCSC Min-Max problems.

        For every parameter ``p`` with a gradient, one ``step()`` in ``'sgd'``
        mode performs (``w_buf`` is the per-parameter auxiliary buffer kept in
        ``self.state[p]['w_buffer']``)::

            p     <- p - lr_1 * (clamp(grad, -clip_value, clip_value) + (p - w_buf) / gamma)
            w_buf <- (1 - lr_0 / gamma) * w_buf + (lr_0 / gamma) * p

        where the second line uses the already-updated ``p``.

        Args:
            params: iterable of parameters to optimize; it is materialized
                into a list so a generator can be passed.
            mode (str): update rule, lower-cased on construction. Only
                ``'sgd'`` is implemented by ``step()``; any other value makes
                ``step()`` raise ``KeyError``.
            eta: accepted for interface compatibility; unused by this class.
            lr_0 (float): step size of the ``w_buf`` update (applied as
                ``lr_0 / gamma``).
            lr_1 (float): step size of the parameter update.
            clip_value (float): gradients are clamped element-wise to
                ``[-clip_value, clip_value]`` before use.
            weight_decay (float): must be non-negative; stored in the group
                defaults but not applied by the ``'sgd'`` update.
            epoch_decay (float): must be non-negative; when positive, the
                ``model_ref`` / ``model_acc`` tensor lists are allocated. They
                are stored in the group defaults but not used by ``step()``.
            betas, eps, amsgrad, momentum, nesterov, dampening: validated
                (``betas``, ``eps``) and stored in the group defaults; not
                used by the ``'sgd'`` update.
            verbose (bool): stored on the instance.
            device: device used for ``model_ref`` / ``model_acc``. When falsy,
                CUDA is chosen if available, otherwise CPU.
            gamma (float): scale of the proximal term ``(p - w_buf) / gamma``
                and of the buffer step ``lr_0 / gamma``.
            rand_init_wbuf (bool): if True, ``w_buf`` starts as N(0, 0.01^2)
                noise; otherwise it starts as a copy of the parameter.
    """

    def __init__(self,
                 params,
                 mode='sgd',
                 eta=1.0,
                 lr_0=0.01,
                 lr_1=0.1,
                 clip_value=1.0,
                 weight_decay=0,
                 epoch_decay=0,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 amsgrad=False,
                 momentum=0,
                 nesterov=False,
                 dampening=0,
                 verbose=False,
                 device=None,
                 gamma=100,
                 rand_init_wbuf=True,
                 **kwargs):

        if not device:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= epoch_decay:
            raise ValueError("Invalid epoch_decay value: {}".format(epoch_decay))

        self.params = list(params)  # support optimizing partial parameters of models
        self.lr_0 = lr_0
        self.lr_1 = lr_1
        self.mode = mode.lower()
        # Reference / accumulator tensors are only allocated when epoch_decay is enabled.
        self.model_ref = self.__init_model_ref__(self.params) if epoch_decay > 0 else None
        self.model_acc = self.__init_model_acc__(self.params) if epoch_decay > 0 else None
        self.T = 0  # for epoch_decay
        self.steps = 0  # total optimization steps
        self.verbose = verbose  # print updates for lr/regularizer
        self.epoch_decay = epoch_decay
        self.gamma = gamma
        self.rand_init_wbuf = rand_init_wbuf

        # NOTE: `mode` is not validated here; an unsupported mode surfaces as
        # a KeyError on the first call to step().

        defaults = dict(lr_0=lr_0, lr_1=lr_1, betas=betas, eps=eps, momentum=momentum,
                        nesterov=nesterov, dampening=dampening,
                        epoch_decay=epoch_decay, weight_decay=weight_decay, amsgrad=amsgrad,
                        clip_value=clip_value, model_ref=self.model_ref, model_acc=self.model_acc)
        super(SMAG, self).__init__(self.params, defaults)

    def __setstate__(self, state):
        r"""
        Restore optimizer state and fill in mode-specific group defaults.

        ``'sgd'`` groups get ``nesterov=False`` and ``'adam'`` groups get
        ``amsgrad=False`` when the key is missing.

        Raises:
            NotImplementedError: if ``self.mode`` is neither ``'sgd'`` nor
                ``'adam'`` and there is at least one parameter group.
        """
        super(SMAG, self).__setstate__(state)
        for group in self.param_groups:
            if self.mode == 'sgd':
                group.setdefault('nesterov', False)
            elif self.mode == 'adam':
                group.setdefault('amsgrad', False)
            else:
                # The original spelled this as a bare expression, which is a no-op.
                raise NotImplementedError(
                    "Unsupported optimizer mode: {!r}".format(self.mode))

    def __init_model_ref__(self, params):
        r"""
        Return one N(0, 0.01^2) tensor per non-None entry of ``params``.

        Each tensor has the entry's shape, is sampled on the CPU with the
        default dtype and then moved to ``self.device``.
        """
        if not isinstance(params, list):
            params = list(params)
        return [
            torch.empty(var.shape).normal_(mean=0, std=0.01).to(self.device)
            for var in params
            if var is not None
        ]

    def __init_model_acc__(self, params):
        r"""
        Return one float32 zero tensor per non-None entry of ``params``.

        Each tensor has the entry's shape and lives on ``self.device``.
        """
        if not isinstance(params, list):
            params = list(params)
        return [
            torch.zeros(var.shape, dtype=torch.float32, device=self.device, requires_grad=False)
            for var in params
            if var is not None
        ]

    def _init_w_buffer(self, p):
        r"""
        Create the initial auxiliary buffer for parameter ``p``.

        The buffer is always placed on ``p``'s own device so that it can be
        combined with ``p`` in ``step()``, and it never shares storage with
        ``p``.
        """
        if self.rand_init_wbuf:
            # Sampled on the CPU with the default dtype, then moved next to p.
            return torch.empty(p.shape).normal_(mean=0, std=0.01).to(p.device)
        # clone() is essential: a bare detach() (or a same-device .to()) would
        # alias p, so the in-place updates of p and the buffer would clobber
        # each other and the proximal term (p - buf) would always be zero.
        return p.detach().clone()

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was
            given.

        Raises:
            KeyError: if ``self.mode`` is not ``'sgd'``.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if self.mode != 'sgd':
                raise KeyError('Unknown optimizer mode.')

            clip_value = group['clip_value']
            lr_0 = group['lr_0']
            lr_1 = group['lr_1']

            for p in group['params']:
                if p.grad is None:
                    continue

                param_state = self.state[p]
                if 'w_buffer' not in param_state:
                    param_state['w_buffer'] = self._init_w_buffer(p)
                buf = param_state['w_buffer']

                # Clipped gradient plus the proximal pull of p towards the buffer.
                d_p = torch.clamp(p.grad.data, -clip_value, clip_value) + 1 / self.gamma * (p.data - buf)
                p.add_(d_p, alpha=-lr_1)
                # Move the buffer towards the freshly updated parameter.
                buf.mul_(1 - lr_0 / self.gamma).add_(p.data, alpha=lr_0 / self.gamma)

        self.steps += 1
        self.T += 1
        return loss

    def update_lr(self, decay_factor=None):
        r"""
        Divide ``lr_0`` and ``lr_1`` of the first parameter group by ``decay_factor``.

        Does nothing when ``decay_factor`` is ``None``. Only
        ``self.param_groups[0]`` is modified; the new values are printed
        together with the current step count.
        """
        if decay_factor is None:
            return
        first_group = self.param_groups[0]
        first_group['lr_0'] = first_group['lr_0'] / decay_factor  # for learning rate
        first_group['lr_1'] = first_group['lr_1'] / decay_factor  # for learning rate
        print('Reducing lr_0 to %.5f @ T=%s!' % (first_group['lr_0'], self.steps))
        print('Reducing lr_1 to %.5f @ T=%s!' % (first_group['lr_1'], self.steps))

