import torch
from torch import nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from torch.optim import SGD
import pdb

class LARS(SGD):
    """SGD with layer-wise adaptive rate scaling (LARS).

    Slight modification of LARC optimizer from https://github.com/NVIDIA/apex/blob/d74fda260c403f775817470d87f810f816f3d615/apex/parallel/LARC.py
    Matches one from SimCLR implementation https://github.com/google-research/simclr/blob/master/lars_optimizer.py

    Before each SGD update, the group's weight decay is folded into every
    gradient. Each gradient is then rescaled in place by
    ``trust_coefficient * ||p|| / ||grad||`` when both norms are non-zero and
    the group's ``layer_adaptation`` flag is truthy. SGD does not apply weight
    decay itself: this class temporarily sets each group's ``weight_decay`` to
    0 while ``SGD.step`` runs and restores it afterwards.

    Args:
        param_groups: Iterable of parameters or of parameter-group dicts, as
            accepted by ``torch.optim.SGD``. A group may carry a
            ``layer_adaptation`` flag to enable or disable the trust-ratio
            scaling for that group (e.g. to exclude biases / norm layers).
            When the key is absent, scaling is applied.
        lr: Learning rate passed to SGD.
        momentum: Momentum factor passed to SGD.
        trust_coefficient: Trust coefficient for calculating the adaptive lr.
            See https://arxiv.org/abs/1708.03888
    """

    def __init__(self, param_groups, lr, momentum, trust_coefficient=0.001):
        super(LARS, self).__init__(param_groups, lr, momentum)
        self.trust_coefficient = trust_coefficient

    def _scale_grads(self, group, weight_decay):
        """Fold weight decay into each gradient of ``group`` and apply the trust ratio in place."""
        # Groups without the flag (e.g. a plain parameter list) get LARS scaling.
        use_adaptation = group.get('layer_adaptation', True)
        for p in group['params']:
            if p.grad is None:
                continue
            grad = p.grad
            if weight_decay != 0:
                # Weight decay is added before the norms are taken, so the
                # trust ratio accounts for it.
                grad.add_(p, alpha=weight_decay)
            param_norm = torch.norm(p)
            grad_norm = torch.norm(grad)
            if use_adaptation and param_norm != 0 and grad_norm != 0:
                grad.mul_(self.trust_coefficient * param_norm / grad_norm)

    def step(self, closure=None):
        """Perform a single LARS optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss, as in ``torch.optim.Optimizer.step``. It is evaluated
                before gradients are rescaled, so the rescaling is not
                overwritten by its backward pass.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        weight_decays = []
        try:
            with torch.no_grad():
                for group in self.param_groups:
                    # absorb weight decay control from optimizer
                    weight_decay = group.get('weight_decay', 0)
                    weight_decays.append(weight_decay)
                    group['weight_decay'] = 0
                    self._scale_grads(group, weight_decay)
            super().step()
        finally:
            # Return weight decay control to the optimizer even if the update
            # failed; otherwise it would stay disabled for all later steps.
            # zip() only restores groups that were actually zeroed.
            for group, weight_decay in zip(self.param_groups, weight_decays):
                group['weight_decay'] = weight_decay
        return loss