import torch
from torch.optim import Optimizer
import torch.autograd as autograd

class SGD(Optimizer):
    """Multi-node SGD variant with mixing weights ``x`` and multipliers ``lamb``.

    For every parameter the optimizer keeps ``N`` stacked slots (slot 0 plus
    nodes ``1..N-1``) of recorded gradients (``update_w``), multipliers
    (``lamb``) and second-order terms (``update_grad``). The weight vector
    ``self.x`` (length ``N``) is updated by :meth:`step_x`, which projects
    ``x[1:]`` onto the probability simplex. Model parameters are updated by
    :meth:`step_w` from the ``x``-weighted gradients of the active nodes.
    Every method skips parameters whose ``p.grad`` is ``None``.

    Args:
        params: iterable of parameters or param-group dicts.
        lr: step size for parameters and multipliers (must be >= 0).
        momentum: momentum factor used in :meth:`step_w` (must be >= 0).
        weight_decay: L2 coefficient added to the gradient in :meth:`step`.
        N: number of slots including slot 0 (must be >= 2).
        x_lr: step size of the ``x`` update in :meth:`step_x`.
        lamb_lr: stored in the param-group defaults; not read by this class.
        init_x: optional initial ``x`` of length ``N``; copied onto ``device``.
            By default ``x`` is uniform over slots ``1..N-1`` with ``x[0] = 0``.
        device: device on which ``x`` and ``xupdate`` are allocated.
    """

    def __init__(self, params, lr=0.0, momentum=0, weight_decay=0, N=4,
                 x_lr=1e-3, lamb_lr=1e-2, init_x=None, device='cpu'):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if N < 2:
            # x is initialised as 1/(N-1) and updates are rescaled by
            # (N-1)/n_active, so at least one node besides slot 0 is needed.
            raise ValueError("Invalid number of nodes N: {}".format(N))

        defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay,
                        N=N, lamb_lr=lamb_lr)
        if init_x is None:
            # Uniform weight over nodes 1..N-1; slot 0 gets no weight.
            self.x = torch.ones(N, device=device) / (N - 1)
            self.x[0] = 0
        else:
            # Private copy: step_x updates x in place and must not mutate the
            # caller's tensor; keep it on the same device as xupdate.
            self.x = torch.as_tensor(init_x, device=device).detach().clone()

        self.N = N
        self.xupdate = torch.zeros(N, device=device)
        self.x_lr = x_lr
        self.device = device
        super().__init__(params, defaults)
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                # N copies of the current parameter value (stack allocates
                # fresh storage, so no extra clone is needed).
                stacked = torch.stack([p.detach()] * N)
                state['m'] = torch.zeros_like(p.detach())
                state['w'] = stacked
                state['lamb'] = torch.zeros_like(stacked)
                state['update_grad'] = torch.zeros_like(stacked)
                state['update_w'] = torch.zeros_like(stacked)

    def _active_mask(self, active):
        """Return ``active`` as a boolean tensor; ``None`` means all N slots are active."""
        if active is None:
            return torch.ones(self.N, dtype=torch.bool, device=self.device)
        # A 0/1 integer mask must become boolean, otherwise tensor indexing
        # would treat it as a list of indices instead of a mask.
        return torch.as_tensor(active).bool()

    def step(self, closure=None, node=None):
        """Record the (weight-decayed) gradient of every parameter in slot ``node``.

        Parameters and ``p.grad`` are not modified; see :meth:`step_w`.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss.
            node: slot of ``state['update_w']`` to write.
                NOTE(review): with the default ``None`` the write
                ``update_w[None] = grad`` broadcasts the gradient into every
                slot. Confirm that callers rely on this.

        Returns:
            The value returned by ``closure``, or ``None``.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.detach()
                if weight_decay != 0:
                    # Out-of-place, so the caller's p.grad is left untouched.
                    grad = grad.add(p.detach(), alpha=weight_decay)
                # Slice assignment copies the values into the state buffer.
                self.state[p]['update_w'][node] = grad
        return loss

    def local_step_w1(self, loss, node=1, active=None):
        """Accumulate Hessian-multiplier products into ``state['update_grad']``.

        For each k in 1..N-1 this forms the scalar
        ``tmp[k] = sum_p <d loss / d p, lamb_p[k]>`` (kept differentiable via
        ``create_graph=True``). It then adds
        ``x[k] * (N-1) / n_active * d tmp[k] / d p`` to ``update_grad_p[k]``.

        Args:
            loss: scalar loss whose autograd graph is still alive.
            node: not read by this method; kept for interface compatibility.
            active: mask over the N slots; ``active[1:]`` gives n_active.
                ``None`` means every slot is active.
        """
        n = self.N
        mask = self._active_mask(active)
        tmp = torch.zeros(n, device=self.device)
        for group in self.param_groups:
            params = group['params']
            # autograd.grad returns one gradient per entry of ``params``.
            # zip keeps the pairing right even when a parameter is skipped.
            first_grads = autograd.grad(loss, params, retain_graph=True, create_graph=True)
            for p, g in zip(params, first_grads):
                if p.grad is None:
                    continue
                lamb = self.state[p]['lamb']
                for k in range(1, n):
                    tmp[k] = tmp[k] + torch.sum(g * lamb[k])

        total = torch.sum(mask[1:]).float()
        for k in range(1, n):
            for group in self.param_groups:
                params = group['params']
                second_grads = autograd.grad(tmp[k], params, retain_graph=True)
                for p, h in zip(params, second_grads):
                    if p.grad is None:
                        continue
                    state = self.state[p]
                    state['update_grad'][k] = (
                        state['update_grad'][k] + self.x[k] * (n - 1) / total * h
                    ).detach()
        # NOTE(review): this accumulates d tmp[1] / d p into every p.grad.
        # Presumably it is here to release the graph retained above. Confirm.
        tmp[1].backward()

    def update_mul(self, active):
        """Gradient step on the multipliers of every active node k >= 1.

        ``lamb[k] = lamb[k] - lr * update_grad[k] - lr * update_w[k]``.
        """
        mask = self._active_mask(active)
        for group in self.param_groups:
            lr = group['lr']
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                lamb = state['lamb']
                for k in range(1, self.N):
                    if mask[k]:
                        lamb[k] = (lamb[k] - lr * state['update_grad'][k]
                                   - lr * state['update_w'][k]).detach()

    def pre_step_x(self):
        """Reset every multiplier buffer ``lamb`` to zero."""
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                self.state[p]['lamb'].zero_()

    def pre_step_w1(self):
        """Reset every ``update_grad`` buffer to zero."""
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                self.state[p]['update_grad'].zero_()

    def gen_update_x(self, node=1):
        """Set ``xupdate[node]`` to ``sum_p <update_w_p[0], lamb_p[node]>``."""
        self.xupdate[node] = 0
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                self.xupdate[node] += torch.sum(state['update_w'][0] * state['lamb'][node])
        self.xupdate = self.xupdate.detach()

    def step_x(self, active):
        """Step ``x`` on the active slots, then project ``x[1:]`` onto the simplex.

        The step is ``x[a] -= x_lr * (N-1) / n_active * xupdate[a]`` for active
        slots ``a``. The projection onto ``{v >= 0, sum(v) = 1}`` is sort-based:
        it finds the longest descending prefix with ``u_i + (1 - cumsum_i) / i > 0``
        and shifts ``x[1:]`` by the resulting threshold. Finally, negative
        entries of the whole vector (including ``x[0]``) are clamped to 0.

        Args:
            active: mask over the N slots (bool, or 0/1 integers).
        """
        mask = self._active_mask(active)
        total = torch.sum(mask[1:]).float()
        self.x[mask] = self.x[mask] - self.x_lr * (self.N - 1) / total * self.xupdate[mask]

        ranked, _ = torch.sort(self.x[1:], descending=True)
        running = 0
        shift = 0
        for i in range(self.N - 1):
            running = running + ranked[i]
            if ranked[i] + (1 - running) / (i + 1) > 0:
                shift = (1 - running) / (i + 1)
        self.x[1:] = self.x[1:] + shift
        self.x[self.x < 0] = 0
        self.x = self.x.detach()

    def step_w(self, active):
        """Momentum update of the parameters from the active nodes' gradients.

        ``d = (N-1) / n_active * sum_{k >= 1 active} x[k] * update_w[k]``,
        ``m = (1 - momentum) * d + momentum * m``, ``p -= lr * m``.
        NOTE(review): if no slot in ``active[1:]`` is set, n_active is 0 and the
        update becomes NaN. Confirm callers always pass an active node.
        """
        mask = self._active_mask(active)
        n = self.N
        total = torch.sum(mask[1:]).float()  # loop-invariant
        for group in self.param_groups:
            lr = group['lr']
            momentum = group['momentum']
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                direction = 0
                for k in range(1, n):
                    if mask[k]:
                        direction = direction + self.x[k] * state['update_w'][k]
                direction = direction * (n - 1) / total
                state['m'] = ((1 - momentum) * direction + momentum * state['m']).detach()
                p.data.add_(-lr * state['m'])
