import math
import torch
from torch.optim import Optimizer
from collections import defaultdict


class Lookahead(Optimizer):
    '''
    PyTorch implementation of the lookahead wrapper for two inner optimizers.
    Lookahead Optimizer: https://arxiv.org/abs/1907.08610

    The same lookahead rule is applied to the parameters of ``x_optimizer``
    and ``y_optimizer``. This wrapper never steps the inner optimizers itself:
    the caller steps them and then calls :meth:`step` once per inner update.
    Every ``k``-th call to :meth:`step` moves each parameter ("fast" weight)
    towards its cached "slow" copy and re-caches the result::

        fast <- alpha * fast + (1 - alpha) * slow
        slow <- fast
    '''

    def __init__(self, x_optimizer, y_optimizer, k=5, alpha=0.5,
                 pullback_momentum=None, device=None):
        '''
        :param x_optimizer: first inner optimizer.
        :param y_optimizer: second inner optimizer.
        :param k (int): number of :meth:`step` calls between slow-weight
            synchronisations; must be >= 1.
        :param alpha (float): linear interpolation factor in [0, 1]. 1.0
            recovers the inner optimizers.
        :param pullback_momentum (str or None): treatment of each inner
            optimizer's ``'momentum_buffer'`` state on synchronisation:
            ``'reset'`` replaces it with zeros, ``'pullback'`` interpolates it
            like the weights, ``None`` leaves it untouched.
        :param device: device to keep the cached slow weights on. ``None``
            (the default) keeps each cache on its parameter's own device, so
            CPU-only and multi-device setups work without extra arguments.
        :raises ValueError: if ``alpha``, ``k`` or ``pullback_momentum`` is
            out of range.
        '''
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        # A real exception rather than ``assert``: asserts vanish under -O.
        if pullback_momentum not in ('reset', 'pullback', None):
            raise ValueError(f'Invalid pullback momentum: {pullback_momentum}')
        self.x_optimizer = x_optimizer
        self.y_optimizer = y_optimizer
        self.x_param_groups = self.x_optimizer.param_groups
        self.y_param_groups = self.y_optimizer.param_groups
        self.k = k
        self.alpha = alpha
        self.k_counter = 0
        self.pullback_momentum = pullback_momentum
        # Per-parameter wrapper state ('cached_params', 'cached_mom',
        # 'backup_params'), keyed by the parameter tensor itself.
        self.state = defaultdict(dict)

        # Cache the starting parameters as the initial slow weights.
        for _, p in self._iter_params():
            cached = p.detach().clone()
            if device is not None:
                cached = cached.to(device)
            self.state[p]['cached_params'] = cached

    def _iter_params(self):
        """Yield ``(inner_optimizer, param)`` for every parameter, x first."""
        for optimizer in (self.x_optimizer, self.y_optimizer):
            for group in optimizer.param_groups:
                for p in group['params']:
                    yield optimizer, p

    def __getstate__(self):
        return {
            'state': self.state,
            'x_optimizer': self.x_optimizer,
            'y_optimizer': self.y_optimizer,
            'k': self.k,
            'alpha': self.alpha,
            'k_counter': self.k_counter,
            'pullback_momentum': self.pullback_momentum,
        }

    def __setstate__(self, state):
        """Restore from :meth:`__getstate__` (pickle / ``copy.deepcopy``).

        ``Optimizer.__init__`` is never called by this class, so the base
        ``Optimizer.__setstate__`` cannot be relied on; restore the plain
        attributes and re-derive the param-group aliases instead.
        """
        self.__dict__.update(state)
        self.x_param_groups = self.x_optimizer.param_groups
        self.y_param_groups = self.y_optimizer.param_groups

    def zero_grad(self, set_to_none=None):
        """Clear gradients of both inner optimizers.

        :param set_to_none: forwarded to the inner ``zero_grad`` when given;
            when ``None`` the inner optimizers use their own default.
        """
        kwargs = {} if set_to_none is None else {'set_to_none': set_to_none}
        self.x_optimizer.zero_grad(**kwargs)
        self.y_optimizer.zero_grad(**kwargs)

    def state_dict(self):
        """Return inner optimizer states plus the lookahead slow weights.

        The ``'lookahead_state'`` entry holds ``k_counter`` and, in
        :meth:`_iter_params` order, the cached slow weights and cached slow
        momenta (``None`` where absent). Without it a resumed run would
        interpolate towards whatever the cache held at construction time.
        Tensors are returned by reference, not copied.
        """
        params = [p for _, p in self._iter_params()]
        return {
            'x_state': self.x_optimizer.state_dict(),
            'y_state': self.y_optimizer.state_dict(),
            'lookahead_state': {
                'k_counter': self.k_counter,
                'cached_params': [self.state[p]['cached_params'] for p in params],
                'cached_mom': [self.state[p].get('cached_mom') for p in params],
            },
        }

    def load_state_dict(self, state_dict):
        """Load a dict produced by :meth:`state_dict`.

        Dicts without ``'lookahead_state'`` (older checkpoints) only restore
        the inner optimizers; the current slow-weight cache is kept.

        :raises ValueError: if the saved slow weights do not match the number
            of parameters managed by this wrapper.
        """
        lookahead_state = state_dict.get('lookahead_state')
        params = [p for _, p in self._iter_params()]
        # Validate before touching anything so a mismatch leaves state intact.
        if (lookahead_state is not None
                and len(lookahead_state['cached_params']) != len(params)):
            raise ValueError(
                'lookahead_state has '
                f"{len(lookahead_state['cached_params'])} cached params, "
                f'expected {len(params)}')

        self.x_optimizer.load_state_dict(state_dict['x_state'])
        self.y_optimizer.load_state_dict(state_dict['y_state'])
        if lookahead_state is None:
            return

        cached_moms = lookahead_state.get('cached_mom') or [None] * len(params)
        for p, cached, mom in zip(params, lookahead_state['cached_params'],
                                  cached_moms):
            param_state = self.state[p]
            # copy_ keeps the cache on the device chosen at construction.
            param_state['cached_params'].copy_(cached)
            if mom is None:
                param_state.pop('cached_mom', None)
            else:
                param_state['cached_mom'] = mom.detach().to(
                    device=p.device, copy=True)
        self.k_counter = lookahead_state['k_counter']

    def _backup_and_load_cache(self):
        """
        Useful for performing evaluation on the slow weights (which typically
        generalize better). Saves each current parameter value as
        'backup_params' and overwrites the parameter with its slow copy; undo
        with :meth:`_clear_and_load_backup`.
        """
        for _, p in self._iter_params():
            param_state = self.state[p]
            param_state['backup_params'] = p.data.clone()
            p.data.copy_(param_state['cached_params'])

    def _clear_and_load_backup(self):
        """Restore parameters saved by :meth:`_backup_and_load_cache`.

        Raises ``KeyError`` if no backup exists for a parameter.
        """
        for _, p in self._iter_params():
            param_state = self.state[p]
            p.data.copy_(param_state['backup_params'])
            del param_state['backup_params']

    def _sync_param(self, optimizer, p):
        """Pull ``p`` towards its slow copy and handle ``optimizer``'s momentum."""
        param_state = self.state[p]
        cached = param_state['cached_params']
        # .to() is a no-op when the cache already lives on p's device.
        p.data.mul_(self.alpha).add_(cached.to(p.device), alpha=1.0 - self.alpha)
        cached.copy_(p.data)

        if self.pullback_momentum == 'pullback':
            buf = optimizer.state.get(p, {}).get('momentum_buffer')
            if buf is None:
                # Nothing to pull back (e.g. momentum=0, or no inner state yet).
                return
            buf.mul_(self.alpha)
            slow_mom = param_state.get('cached_mom')
            # A missing slow momentum is treated as zero.
            if slow_mom is not None:
                buf.add_(slow_mom, alpha=1.0 - self.alpha)
            # Store a copy: the inner optimizer updates ``buf`` in place, so
            # an alias would follow the fast momentum instead of the slow one.
            param_state['cached_mom'] = buf.clone()
        elif self.pullback_momentum == 'reset':
            optimizer.state[p]['momentum_buffer'] = torch.zeros_like(p.data)

    def step(self, closure=None):
        """Performs a single Lookahead optimization step.

        Only advances the lookahead counter and, every ``k`` calls,
        synchronises the slow weights of both inner optimizers. The inner
        optimizers are not stepped here; call their ``step`` yourself.

        Arguments:
            closure (callable, optional): accepted for signature compatibility
                with ``torch.optim.Optimizer.step``; it is not called.
        """
        self.k_counter += 1
        if self.k_counter < self.k:
            return
        self.k_counter = 0
        for optimizer, p in self._iter_params():
            self._sync_param(optimizer, p)
