import torch
from torch.optim import Optimizer
import torch.autograd as autograd

class SGD(Optimizer):
    """Momentum-SGD optimizer that aggregates update directions from N nodes.

    Node 0's direction is the (weight-decayed) gradient captured by ``step``;
    nodes 1..N-1 contribute through ``local_step_w`` / ``local_step_w1``.
    Per-node weights live in ``self.x`` (length N) and are updated by
    ``step_x``, which projects ``x[1:]`` back onto the probability simplex.
    A per-parameter auxiliary tensor ``lamb`` is updated by ``step_lamb``.

    Methods taking ``active`` expect a length-N boolean (or 0/1) tensor that
    marks which nodes participate in the current round. Only entries 1..N-1
    are counted when normalizing by the number of active nodes.

    Per-parameter state (``self.state[p]``):
        m            momentum buffer used by ``step_w``
        w            N stacked copies of the parameter at construction time
                     (not read by this class)
        lamb, lambm  auxiliary tensor and its latest increment (``step_lamb``)
        update_grad  per-node gradients written by ``local_step_w``
        update_w     per-node directions; row 0 from ``step``, other rows
                     from ``local_step_w1``
    """

    def __init__(self, params, lr=0.0, momentum=0, weight_decay=0, N=4,
                 x_lr=1e-3, lamb_lr=1e-2, init_x=None, device='cpu', Gamma=1):
        """Create the optimizer and allocate per-parameter node state.

        Args:
            params: iterable of parameters or param-group dicts.
            lr: step size used by ``step_w`` (defaults to 0.0, i.e. no move).
            momentum: mixing factor for the momentum buffer ``m``.
            weight_decay: L2 coefficient added to gradients in ``step``.
            N: number of nodes (node 0 plus N-1 others); must be >= 2.
            x_lr: step size for the node weights in ``step_x``.
            lamb_lr: step size for ``lamb`` in ``step_lamb``.
            init_x: optional initial node-weight tensor of length N.
            device: device for ``x`` / ``xupdate`` when created here.
            Gamma: coefficient on ``update_grad`` terms in ``step_w``.

        Raises:
            ValueError: on negative momentum/weight_decay or N < 2.
        """
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if N < 2:
            # The default x divides by N-1, and every normalization uses
            # nodes 1..N-1, so at least one non-zero node is required.
            raise ValueError("Invalid N value: {}".format(N))

        defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay,
                        N=N, lamb_lr=lamb_lr, Gamma=Gamma)
        if init_x is None:
            # Node 0 starts with zero weight; nodes 1..N-1 share 1/(N-1).
            self.x = torch.ones(N, device=device) / (N - 1)
            self.x[0] = 0
        else:
            self.x = init_x

        self.N = N
        self.xupdate = torch.zeros(N, device=device)
        self.x_lr = x_lr

        super(SGD, self).__init__(params, defaults)
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                # torch.stack copies, so each row is independent of p.
                stacked = torch.stack([p.data] * N)
                state['m'] = torch.zeros_like(p.data)
                state['w'] = stacked
                state['lamb'] = torch.zeros_like(p.data)
                state['lambm'] = torch.zeros_like(p.data)
                state['update_grad'] = torch.zeros_like(stacked)
                state['update_w'] = torch.zeros_like(stacked)

    @staticmethod
    def _active_mask(active):
        """Return ``active`` as a boolean mask (accepts bool or 0/1 input)."""
        return torch.as_tensor(active).bool()

    @staticmethod
    def _count_active(mask):
        """Return the number of active nodes among 1..N-1 as a float tensor.

        Raises:
            ValueError: if no node in 1..N-1 is active (the callers would
                otherwise divide by zero and write NaN/inf into state).
        """
        total = mask[1:].sum().float()
        if total == 0:
            raise ValueError("At least one of nodes 1..N-1 must be active")
        return total

    def step(self, closure=None):
        """Record node 0's direction: the current gradient plus weight decay.

        Parameters are NOT moved here; ``step_w`` applies the update.
        Weight decay is added to ``p.grad`` in place.

        Returns:
            The closure's loss, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if weight_decay != 0:
                    # Keyword form; add_(Number, Tensor) is deprecated.
                    grad.add_(p.data, alpha=weight_decay)
                self.state[p]['update_w'][0] = grad.detach().clone()
        return loss

    def local_step_w(self, loss, node=1):
        """Store ``node``'s gradient of ``loss`` and its inner product with ``lamb``.

        Writes ``update_grad[node]`` per parameter and sets ``xupdate[node]``
        to sum(grad * lamb) over parameters, then calls ``loss.backward()``
        (which accumulates into ``p.grad``). Parameters whose ``.grad`` is
        None are skipped.
        """
        self.xupdate[node] = 0
        for group in self.param_groups:
            local_grad = autograd.grad(loss, group['params'], retain_graph=True)
            # zip keeps each gradient paired with its own parameter even
            # when some parameters are skipped below.
            for p, g in zip(group['params'], local_grad):
                if p.grad is None:
                    continue
                state = self.state[p]
                self.xupdate[node] = self.xupdate[node] + torch.sum(g * state['lamb'])
                state['update_grad'][node] = g.detach()
        loss.backward()
        self.xupdate = self.xupdate.detach()

    def local_step_w1(self, loss, node=1):
        """Store in ``update_w[node]`` the gradient of sum(grad(loss) * lamb).

        This is a second-order quantity: the gradient of ``loss`` is built
        with ``create_graph=True`` and then differentiated again. The scalar
        is also back-propagated, accumulating into ``p.grad``. Parameters
        whose ``.grad`` is None are skipped.
        """
        inner = 0
        for group in self.param_groups:
            local_grad = autograd.grad(loss, group['params'],
                                       retain_graph=True, create_graph=True)
            for p, g in zip(group['params'], local_grad):
                if p.grad is None:
                    continue
                inner = inner + torch.sum(g * self.state[p]['lamb'])
        for group in self.param_groups:
            local_grad = autograd.grad(inner, group['params'], retain_graph=True)
            for p, g in zip(group['params'], local_grad):
                if p.grad is None:
                    continue
                self.state[p]['update_w'][node] = g
        inner.backward()

    def step_lamb(self, active):
        """Update ``lamb`` with the x-weighted mean of active nodes' gradients.

        lambm = (N-1)/#active * sum_{active i>=1} x[i] * update_grad[i]
        lamb  = lamb + lamb_lr * lambm

        Raises:
            ValueError: if no node in 1..N-1 is active.
        """
        mask = self._active_mask(active)
        total = self._count_active(mask)
        for group in self.param_groups:
            N = group['N']
            active_nodes = [i for i in range(1, N) if mask[i]]
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                acc = 0
                for i in active_nodes:
                    acc = acc + self.x[i] * state['update_grad'][i]
                state['lambm'] = acc * (N - 1) / total
                state['lamb'] = (state['lamb'] + group['lamb_lr'] * state['lambm']).detach()

    def _project_tail_onto_simplex(self):
        """Euclidean-project x[1:] onto the probability simplex, then clamp x at 0.

        Sort descending, take the largest rank rho with
        u_rho + (1 - sum_{j<=rho} u_j)/rho > 0, and shift by that threshold.
        """
        u, _ = torch.sort(self.x[1:], descending=True)
        cumsum = torch.cumsum(u, dim=0)
        ranks = torch.arange(1, u.numel() + 1, device=u.device, dtype=u.dtype)
        candidates = (1 - cumsum) / ranks
        hits = torch.nonzero(u + candidates > 0)
        # hits is non-empty for finite x (rank 1 always qualifies).
        shift = candidates[int(hits[-1])] if hits.numel() > 0 else 0
        self.x[1:] = self.x[1:] + shift
        self.x[self.x < 0] = 0

    def step_x(self, active):
        """Take a gradient step on the active node weights, then re-project x.

        The step uses ``xupdate`` scaled by x_lr*(N-1)/#active. It is skipped
        when no node in 1..N-1 is active (it would divide by zero); the
        projection of x[1:] onto the simplex is always applied.
        """
        mask = self._active_mask(active)
        total = mask[1:].sum().float()
        if total > 0:
            self.x[mask] = self.x[mask] - self.x_lr * (self.N - 1) / total * self.xupdate[mask]
        self._project_tail_onto_simplex()
        self.x = self.x.detach()

    def step_w(self, active):
        """Move each parameter along the aggregated multi-node direction with momentum.

        direction = (N-1)/#active * (update_w[0]
                    + sum_{active i>=1} x[i]*update_w[i] + Gamma*x[i]*update_grad[i])
        m = (1-momentum)*direction + momentum*m;  p -= lr * m

        Raises:
            ValueError: if no node in 1..N-1 is active.
        """
        mask = self._active_mask(active)
        total = self._count_active(mask)
        N = self.N
        active_nodes = [i for i in range(1, N) if mask[i]]
        for group in self.param_groups:
            gamma = group['Gamma']
            momentum = group['momentum']
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                direction = state['update_w'][0]
                for i in active_nodes:
                    direction = (direction
                                 + self.x[i] * state['update_w'][i]
                                 + gamma * self.x[i] * state['update_grad'][i])
                direction = direction * (N - 1) / total
                state['m'] = ((1 - momentum) * direction + momentum * state['m']).detach()
                p.data.add_(-group['lr'] * state['m'])
