import copy
import torch
from torch.optim import Optimizer


class SVRG_k(Optimizer):
    r"""Optimization class for calculating the gradient of one iteration.
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
    """
    def __init__(self, params, lr, weight_decay=0):
        print("Using optimizer: SVRG")
        self.u = None
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight decay: {weight_decay}")
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super(SVRG_k, self).__init__(params, defaults)
    
    def get_param_groups(self):
        return self.param_groups

    def set_u(self, new_u):
        """Set the mean gradient for the current epoch. 
        """
        if self.u is None:
            self.u = copy.deepcopy(new_u)
            # self.u = [{'params': [torch.zeros_like(p) for p in group['params']]} for group in new_u]
        for u_group, new_group in zip(self.u, new_u):  
            for u, new_u in zip(u_group['params'], new_group['params']):
                u.grad = new_u.grad.clone()
                # u.copy_(new_u.grad)

    def step(self, params):
        """Performs a single optimization step.
        """
        with torch.no_grad():
            for group, new_group, u_group in zip(self.param_groups, params, self.u):
                weight_decay = group['weight_decay']
                lr = group['lr']
                for p, q, u in zip(group['params'], new_group['params'], u_group['params']):
                    if p.grad is None or q.grad is None:
                        continue
                    # core SVRG gradient update 
                    new_d = p.grad.data - q.grad.data + u.grad.data
                    # new_d = p.grad.data - q.grad.data + u
                    
                    # print l1 norm of the p q u
                    # print(torch.norm(p).item(), torch.norm(q).item(), torch.norm(u).item())
                    # check p q u mean and std is highly different ?
                    # print('mean:', p.grad.data.mean(), q.grad.data.mean(), u.grad.data.mean())
                    # print('std:', p.grad.data.std(), q.grad.data.std(), u.grad.data.std())
                    if weight_decay != 0:
                        new_d.add_(p.data, alpha=weight_decay)
                    p.data.add_(new_d, alpha=-lr)


class SVRG_Snapshot(Optimizer):
    r"""Optimization class for calculating the mean gradient (snapshot) of all samples.
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
    """
    def __init__(self, params):
        defaults = {}
        super(SVRG_Snapshot, self).__init__(params, defaults)
      
    def get_param_groups(self):
            return self.param_groups
    
    def set_param_groups(self, new_params):
        """Copies the parameters from another optimizer. 
        """
        for group, new_group in zip(self.param_groups, new_params): 
            for p, q in zip(group['params'], new_group['params']):
                  p.data[:] = q.data[:]