import torch
from torch.optim import Optimizer
import torch.autograd as autograd

class SGD(Optimizer):
    """Momentum SGD that aggregates gradients buffered in ``N`` node slots.

    The usage visible in this class is two-phase:

    * :meth:`step` copies the current gradient of every parameter (plus the
      optional L2 weight-decay term) into slot ``node`` of a per-parameter
      buffer ``state['update_w']`` of shape ``(N, *p.shape)``.
    * :meth:`step_w` combines the slots ``1..N-1`` marked active, weighted by
      ``self.x``, applies momentum and updates the parameters.

    Slot 0 is never read by :meth:`step_w`, and its default weight ``x[0]`` is
    0. Presumably it represents a server/coordinator; confirm with callers.

    Args:
        params: iterable of parameters or param-group dicts.
        lr: learning rate used by :meth:`step_w`.
        momentum: momentum factor ``beta``; ``m <- (1-beta)*g + beta*m``.
        weight_decay: L2 coefficient added to each gradient in :meth:`step`.
        N: number of node slots.
        x_lr: stored as ``self.x_lr``; not used by the methods of this class.
        lamb_lr: stored in the param-group defaults; not used by this class.
        init_x: initial per-node weight vector. When omitted, ``x[1:]`` is
            ``1/(N-1)`` and ``x[0]`` is 0.
        device: device for the default ``self.x`` and for ``self.xupdate``.

    Raises:
        ValueError: if ``lr``, ``momentum`` or ``weight_decay`` is negative.
    """

    def __init__(self, params, lr=0.0, momentum=0, weight_decay=0, N=4,
                 x_lr=1e-3, lamb_lr=1e-2, init_x=None, device='cpu'):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay,
                        N=N, lamb_lr=lamb_lr)
        if init_x is None:
            # Uniform weight over nodes 1..N-1; node 0 gets no weight.
            self.x = torch.ones(N, device=device) / (N - 1)
            self.x[0] = 0
        else:
            self.x = init_x

        self.N = N
        # Not read by this class's methods. Presumably used by external code
        # that updates self.x (together with x_lr); confirm with callers.
        self.xupdate = torch.zeros(N, device=device)
        self.x_lr = x_lr

        super().__init__(params, defaults)
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                # Momentum buffer with the parameter's shape/dtype/device.
                state['m'] = torch.zeros_like(p.data)
                # One gradient slot per node, allocated directly rather than
                # by stacking N clones of the parameter and zeroing them.
                state['update_w'] = torch.zeros(
                    (N,) + tuple(p.shape), dtype=p.dtype, device=p.device)

    def step(self, closure=None, node=None):
        """Store the current gradients (plus weight decay) in slot ``node``.

        Parameters are not modified here; :meth:`step_w` applies the update.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.
            node: index of the slot to write. NOTE(review): with the default
                ``None``, ``update_w[None] = g`` broadcasts ``g`` into every
                slot. Confirm whether callers rely on that.

        Returns:
            The closure's loss, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.detach()
                if weight_decay != 0:
                    # Out-of-place, so p.grad is left untouched. The former
                    # in-place ``grad.add_(wd, p)`` used a deprecated overload
                    # and stacked the decay into p.grad on repeated calls.
                    grad = grad.add(p.detach(), alpha=weight_decay)
                # Indexed assignment copies into the buffer; no clone needed.
                self.state[p]['update_w'][node] = grad

        return loss

    def step_w(self, active):
        """Apply the weighted, momentum-smoothed aggregate of active slots.

        For each parameter that has a gradient::

            g = (N-1)/k * sum_{i in 1..N-1, active[i] == 1} x[i] * update_w[i]
            m = (1 - momentum) * g + momentum * m
            p = p - lr * m

        where ``k = sum(active[1:])``.

        Args:
            active: indexable tensor (presumably 1-D of length ``N``). Entry
                ``i == 1`` marks node ``i`` as participating; entry 0 is
                ignored.

        If ``k == 0`` (no node in ``1..N-1`` active), the call is a no-op.
        Dividing by ``k`` would otherwise turn every parameter into NaN.
        """
        N = self.N
        # The active count does not depend on the parameter, so compute it once.
        total = torch.sum(active[1:]).float()
        if total.item() == 0:
            return
        active_nodes = [i for i in range(1, N) if active[i] == 1]

        for group in self.param_groups:
            momentum = group['momentum']
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                agg = 0
                for i in active_nodes:
                    agg = agg + self.x[i] * state['update_w'][i]
                # Rescale so the sum behaves as if all N-1 nodes contributed.
                agg = agg * (N - 1) / total
                state['m'] = ((1 - momentum) * agg
                              + momentum * state['m']).detach()
                p.data.add_(-group['lr'] * state['m'])
