import math
from collections import defaultdict
import torch
from torch.optim.optimizer import Optimizer


class Adam(Optimizer):
    """Adam optimizer whose per-parameter update can be passed through a projection.

    The moment estimates follow the usual Adam recurrences (see ``get_update``).
    When a parameter group has ``svd`` enabled and a transform has been built
    for a parameter (``get_eigens`` followed by ``get_transforms``), the Adam
    update of that parameter is multiplied by the stored transform before it
    is added to the weights. Parameters whose name (looked up in
    ``param_name_map``) contains ``'fit'`` receive ``update - projected``
    instead of ``projected``.

    Arguments:
        params: iterable of parameters or dicts defining parameter groups.
        lr (float): learning rate; must be >= 0.
        betas (tuple): coefficients of the running averages of the gradient
            and of its square; each must lie in ``[0, 1)``.
        eps (float): term added to the denominator; must be >= 0.
        svd (bool): enable the projected update for the group.
        thres (float): cumulative singular-value ratio used by
            ``get_transforms`` to choose how many basis vectors to keep.
            ``1.0`` disables transform construction for the group.
        weight_decay (float): L2 coefficient added to the gradient.
        amsgrad (bool): normalise with the running maximum of the second
            moment instead of its current value.
        param_name_map (mapping, optional): parameter -> name. A fresh
            ``defaultdict(dict)`` is created per optimizer when omitted.
        name_param_map (mapping, optional): stored on the instance; not read
            by any method of this class. A fresh ``defaultdict(dict)`` is
            created per optimizer when omitted.

    Raises:
        ValueError: if ``lr``, ``eps`` or either beta is out of range.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, svd=False, thres=1.001,
                 weight_decay=0, amsgrad=False, param_name_map=None, name_param_map=None):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad, svd=svd,
                        thres=thres)
        super(Adam, self).__init__(params, defaults)

        # parameter -> {'eigen_value': ..., 'eigen_vector': ...}, filled by get_eigens
        self.eigens = defaultdict(dict)
        # parameter -> projection matrix, filled by get_transforms
        self.transforms = defaultdict(dict)
        # The maps are created here rather than as default argument values so
        # that separate optimizers never share (and mutate) the same dict.
        self.param_name_map = defaultdict(dict) if param_name_map is None else param_name_map
        self.name_param_map = defaultdict(dict) if name_param_map is None else name_param_map

    def __setstate__(self, state):
        """Restore optimizer state, filling in group options that may be absent."""
        super(Adam, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)
            group.setdefault('svd', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            svd = group['svd']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'Adam does not support sparse gradients, please consider SparseAdam instead')

                update = self.get_update(group, grad, p)

                # Membership is tested per parameter: indexing the defaultdict
                # for a parameter that get_transforms skipped would insert an
                # empty dict and make the matrix product fail.
                if svd and p in self.transforms:
                    name = self.param_name_map.get(p, 'unknown_param')
                    update = self._project_update(update, self.transforms[p], name)
                p.data.add_(update)
        return loss

    @staticmethod
    def _project_update(update, transform, name):
        """Multiply ``update`` by ``transform``; return the residual for 'fit' names."""
        if len(update.shape) == 4:
            # the transpose of the manuscript
            flat = update.view(update.size(0), -1)
            projected = torch.mm(flat, transform).view_as(update)
        elif transform.shape[0] == update.shape[0]:
            projected = torch.mm(transform, update)
        else:
            projected = torch.mm(update, transform)

        if 'fit' in name:
            return update - projected
        return projected

    def get_transforms(self):
        """Build the projection matrix of every eligible parameter.

        For each parameter of a group with ``svd`` enabled (skipping
        parameters that do not require grad and groups with ``thres == 1.0``)
        the leading columns of the stored ``eigen_vector`` are kept until the
        cumulative share of ``eigen_value`` first reaches ``thres``. The
        transform is ``basis @ basis.T`` divided by its ``torch.norm``.
        ``get_eigens`` must have been called first.
        """
        for group in self.param_groups:
            if not group['svd']:
                continue

            thres = group['thres']
            for p in group['params']:
                if not p.requires_grad or thres == 1.0:
                    continue
                eigen_values = self.eigens[p]['eigen_value']
                cumulative_sum = eigen_values.cumsum(dim=0) / eigen_values.sum()
                reached = (cumulative_sum >= thres).nonzero(as_tuple=True)[0]
                if reached.numel() > 0:
                    num_vectors = int(reached[0]) + 1
                else:
                    # No prefix reaches the threshold (thres > 1, such as the
                    # default 1.001, or round-off just below it): keep every
                    # vector instead of failing on an empty index.
                    num_vectors = eigen_values.shape[0]
                name = self.param_name_map.get(p, 'unknown_param')
                print('[{}] reserving basis {}/{}; cond: {}, ratio:{}'.format(
                    name,
                    num_vectors, eigen_values.shape[0],
                    eigen_values[0] / eigen_values[-1],
                    cumulative_sum[num_vectors - 1]
                ))
                basis = self.eigens[p]['eigen_vector'][:, :num_vectors]
                transform = torch.mm(basis, basis.transpose(1, 0))
                self.transforms[p] = (transform / torch.norm(transform)).detach()

    def get_eigens(self, fea_in):
        """Decompose ``fea_in[p]`` for every trainable parameter of svd groups.

        Arguments:
            fea_in (mapping): parameter -> 2-D tensor passed to ``torch.svd``.

        The second and third results of ``torch.svd(..., some=False)`` (the
        singular values and ``V``) are stored under ``'eigen_value'`` and
        ``'eigen_vector'`` in ``self.eigens[p]``.
        """
        for group in self.param_groups:
            if not group['svd']:
                continue
            for p in group['params']:
                if not p.requires_grad:
                    continue
                eigen = self.eigens[p]
                _, eigen_value, eigen_vector = torch.svd(fea_in[p], some=False)
                eigen['eigen_value'] = eigen_value
                eigen['eigen_vector'] = eigen_vector

    def get_update(self, group, grad, p):
        """Advance the Adam state of ``p`` and return the additive update.

        Arguments:
            group (dict): the parameter group ``p`` belongs to.
            grad (Tensor): dense gradient of ``p``; it is not modified.
            p (Tensor): the parameter being updated.

        Returns:
            Tensor: ``-step_size * exp_avg / denom``, to be added to ``p``.
        """
        amsgrad = group['amsgrad']
        state = self.state[p]
        # State initialization
        if len(state) == 0:
            state['step'] = 0
            # Exponential moving average of gradient values
            state['exp_avg'] = torch.zeros_like(p.data)
            # Exponential moving average of squared gradient values
            state['exp_avg_sq'] = torch.zeros_like(p.data)
            if amsgrad:
                # Maintains max of all exp. moving avg. of sq. grad. values
                state['max_exp_avg_sq'] = torch.zeros_like(p.data)

        exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
        beta1, beta2 = group['betas']

        state['step'] += 1

        if group['weight_decay'] != 0:
            # Out of place, so the caller's p.grad keeps its original value.
            grad = grad.add(p.data, alpha=group['weight_decay'])

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        if amsgrad:
            max_exp_avg_sq = state['max_exp_avg_sq']
            # Maintains the maximum of all 2nd moment running avg. till now
            torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
            # Use the max. for normalizing running avg. of gradient
            denom = max_exp_avg_sq.sqrt().add_(group['eps'])
        else:
            denom = exp_avg_sq.sqrt().add_(group['eps'])

        bias_correction1 = 1 - beta1 ** state['step']
        bias_correction2 = 1 - beta2 ** state['step']
        step_size = group['lr'] * \
            math.sqrt(bias_correction2) / bias_correction1
        return - step_size * exp_avg / denom