import math
import numpy as np
import torch
import copy
from functools import reduce
from torch.optim.optimizer import Optimizer, required
    
class Adam_CAM_HD(Optimizer):
    """Adam whose learning rate is adapted online by hypergradient descent (HD).

    Depending on ``args["level"]`` the rate is tracked at several granularities
    and the levels are mixed with weights that are themselves adapted:

    * ``"para_layer_global"`` - per-element (``lr_list_para``), per-parameter
      tensor (``lr_list``) and global (``lr_global``) rates, mixed with
      ``gamma_1``, ``gamma_2``, ``gamma_3`` into ``lr_list_para_1``;
    * ``"layer_global"`` - per-parameter tensor (``lr_list``) and global
      (``lr_global``) rates, mixed with ``gamma_1``, ``gamma_2`` into
      ``lr_list_1``;
    * ``None`` - plain Adam with the fixed rate ``lr``.

    The ``"layer"``, ``"unit"`` and ``"para"`` levels only allocate rate
    buffers in ``__init__``; ``step`` prints "level not defined!" for them and
    leaves the parameters unchanged.

    NOTE(review): the per-parameter buffers are sized from the first param
    group only, so a single param group is assumed.
    """

    def __init__(self, params, args):
        """Read hyper-parameters from ``args`` and allocate the rate buffers.

        Args:
            params: iterable of parameters (or param-group dicts) to optimise.
            args: dict that must provide ``lr``, ``betas``, ``eps``,
                ``weight_decay``, ``hypergrad_lr``, ``level``,
                ``reg_lr_layer``, ``reg_lr_unit``, ``reg_lr_ts``, ``kappa``,
                ``print_lr``, ``cuda``, ``update_from_combination``,
                ``gamma_1``, ``gamma_2``, ``gamma_3``, ``delta``,
                ``print_gamma``, ``hd_decay``, ``hd_decay_coef`` and
                ``method``.
        """
        self.lr = args["lr"]
        self.betas = args["betas"]
        self.eps = args["eps"]
        self.weight_decay = args["weight_decay"]
        self.hypergrad_lr = args["hypergrad_lr"]  # step size of the HD rate updates
        self.level = args["level"]
        self.reg_lr_layer = args["reg_lr_layer"]
        self.reg_lr_unit = args["reg_lr_unit"]
        self.reg_lr_ts = args["reg_lr_ts"]
        self.kappa = args["kappa"]
        self.print_lr = args["print_lr"]
        self.cuda = args["cuda"]
        self.update_from_combination = args["update_from_combination"]

        # Mixing weights of the element / layer / global rates, and the step
        # size ``delta`` of their own updates.
        self.gamma_1 = args["gamma_1"]
        self.gamma_2 = args["gamma_2"]
        self.gamma_3 = args["gamma_3"]
        self.delta = args["delta"]
        self.print_gamma = args["print_gamma"]
        self.hd_decay = args["hd_decay"]
        self.hd_decay_coef = args["hd_decay_coef"]
        self.method = args["method"]
        self.timestep = 0

        defaults = dict(lr=self.lr, betas=self.betas, eps=self.eps,
                        weight_decay=self.weight_decay, hypergrad_lr=self.hypergrad_lr)
        super(Adam_CAM_HD, self).__init__(params, defaults)
        self._params = self.param_groups[0]['params']
        self._params_numel = reduce(lambda total, p: total + p.numel(), self._params, 0)

        self.lr = torch.tensor(self.lr)
        # Independent copy: ``step`` updates ``lr_global`` in place (``+=``),
        # which must not alter the base rate ``self.lr``.
        self.lr_global = self.lr.clone()
        n_params = len(self._params)

        if self.level == "layer" or self.level == "layer_global":
            self.lr_list = self.lr * torch.ones(n_params)
            self.lr_list_1 = self.lr * torch.ones(n_params)
            self.lr_list_2 = copy.deepcopy(self.lr_list_1)
        elif self.level == "unit":
            self.lr_list = []
            for group in self.param_groups:
                for p in group['params']:
                    # One rate per slice along dim 1 when it exists, else one per element.
                    if p.data.dim() >= 2:
                        lr_item = self.lr * torch.ones(p.data.shape[1])
                    else:
                        lr_item = self.lr * torch.ones(p.data.shape)
                    self.lr_list.append(lr_item)
        elif self.level == "para":
            self.lr_list_para = []
            self.lr_list_para_1 = []
            for group in self.param_groups:
                for p in group['params']:
                    lr_item = self.lr * torch.ones(p.data.shape)
                    self.lr_list_para.append(lr_item)
                    self.lr_list_para_1.append(lr_item.clone())
        elif self.level == "para_layer_global":
            self.lr_list = self.lr * torch.ones(n_params)
            if self.cuda:
                self.lr_list = self.lr_list.cuda()

            self.lr_list_para = []
            self.lr_list_para_1 = []
            for group in self.param_groups:
                for p in group['params']:
                    lr_item = self.lr * torch.ones(p.data.shape)
                    if self.cuda:
                        lr_item = lr_item.cuda()
                    self.lr_list_para.append(lr_item)
                    # Separate tensor so in-place updates of one list never leak into the other.
                    self.lr_list_para_1.append(lr_item.clone())
            self.lr_list_para_2 = copy.deepcopy(self.lr_list_para_1)

        self.reg_lr_layer_list = self.reg_lr_layer * torch.ones(n_params)
        if self.cuda:
            self.reg_lr_layer_list = self.reg_lr_layer_list.cuda()

    def get_global_h(self):
        """Return the global hypergradient and per-parameter element counts.

        Call this before ``state['step']`` is incremented for the current step:
        the moment buffers still hold the previous step's values, whose bias
        corrections are ``1 - beta ** state['step']``. Parameters without a
        gradient or without optimizer state yet are skipped.

        Returns:
            tuple: ``(h_global, h_len_list)`` for the last param group.
            ``h_global`` is the sum over parameters of
            ``<grad, m / (sqrt(v) + eps)> * sqrt(1 - beta2**t) / (1 - beta1**t)``
            (``0`` if nothing contributed; ``grad`` includes weight decay when
            it is non-zero), and ``h_len_list`` holds each contributing
            parameter's element count, in order.
        """
        h_global, h_len_list = 0, []
        for group in self.param_groups:
            # Only the last group's values are returned (one group is assumed).
            h = 0
            h_len_list = []
            beta1, beta2 = group['betas']
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                if 'exp_avg' not in state:
                    # No moment estimates yet: the parameter has never been stepped.
                    continue
                grad = p.grad.data
                if group['weight_decay'] != 0:
                    grad = grad.add(p.data, alpha=group['weight_decay'])

                prev_bias_correction1 = 1 - beta1 ** state['step']
                prev_bias_correction2 = 1 - beta2 ** state['step']
                ratio = torch.div(state['exp_avg'], state['exp_avg_sq'].sqrt().add_(group['eps']))
                h += torch.dot(grad.reshape(-1), ratio.reshape(-1)) * math.sqrt(prev_bias_correction2) / prev_bias_correction1
                h_len_list.append(grad.numel())
            h_global = h
        return h_global, h_len_list

    def get_layer_h_list(self):
        """Return one hypergradient (Python float) per parameter tensor.

        This is the per-parameter term of ``get_global_h`` without the sum. Call it
        at the same point, i.e. before ``state['step']`` is incremented; the
        bias corrections therefore use ``state['step']``. Parameters without a
        gradient or without optimizer state are skipped.
        """
        h_layer_list = []
        for group in self.param_groups:
            beta1, beta2 = group['betas']
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                if 'exp_avg' not in state:
                    continue
                grad = p.grad.data
                if group['weight_decay'] != 0:
                    grad = grad.add(p.data, alpha=group['weight_decay'])

                prev_bias_correction1 = 1 - beta1 ** state['step']
                prev_bias_correction2 = 1 - beta2 ** state['step']
                ratio = torch.div(state['exp_avg'], state['exp_avg_sq'].sqrt().add_(group['eps']))
                h = torch.dot(grad.reshape(-1), ratio.reshape(-1)) * math.sqrt(prev_bias_correction2) / prev_bias_correction1
                h_layer_list.append(h.item())
        return h_layer_list

    def get_para_h_list(self):
        """Return element-wise hypergradients (one tensor per parameter).

        Call before ``state['step']`` is incremented, like ``get_global_h``.
        Only the last param group's entries are returned. Parameters without
        a gradient or without optimizer state are skipped.
        """
        h_para_list = []
        for group in self.param_groups:
            h_para_list = []
            beta1, beta2 = group['betas']
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                if 'exp_avg' not in state:
                    continue
                grad = p.grad.data
                if group['weight_decay'] != 0:
                    grad = grad.add(p.data, alpha=group['weight_decay'])

                prev_bias_correction1 = 1 - beta1 ** state['step']
                prev_bias_correction2 = 1 - beta2 ** state['step']
                ratio = torch.div(state['exp_avg'], state['exp_avg_sq'].sqrt().add_(group['eps']))
                h_para_list.append(grad * ratio * math.sqrt(prev_bias_correction2) / prev_bias_correction1)
        return h_para_list

    @staticmethod
    def _clip_unit(x):
        """Clamp a mixing weight to ``[0, 1]`` (returns ``0``/``1`` at the bounds)."""
        if x < 0:
            return 0
        if x > 1:
            return 1
        return x

    def _hd_para_layer_global(self, group, p, i, h_para, h_layer, h_global, total_numel, tau):
        """HD update of parameter ``i``'s element/layer/global rates and of ``gamma_1..3``.

        Returns the new ``[gamma_1, gamma_2, gamma_3]``.
        """
        n_para_layer = p.data.numel()
        hd_lr = group['hypergrad_lr']

        if not self.update_from_combination:
            self.lr_list_para[i] += tau * hd_lr * h_para * self.gamma_1 * n_para_layer * total_numel / p.numel()
            self.lr_list[i] += tau * hd_lr * h_layer * self.gamma_2 * total_numel / p.numel()
        else:
            # Restart every level from the combined rate of the previous step.
            self.lr_list_para[i] = self.lr_list_para_1[i] + tau * hd_lr * h_para * self.gamma_1 * n_para_layer * total_numel / p.numel()
            self.lr_list[i] = torch.mean(self.lr_list_para_1[i]) + tau * hd_lr * h_layer * self.gamma_2 * total_numel / p.numel()
            self.lr_global = torch.mean(torch.tensor([torch.mean(lr) for lr in self.lr_list_para_1])) + tau * hd_lr * h_global * self.gamma_3

        ones = torch.ones(p.data.shape)
        if self.cuda:
            ones = ones.cuda()
        self.lr_list_para_1[i] = (self.gamma_1 * self.lr_list_para[i]
                                  + self.gamma_2 * self.lr_list[i] * ones
                                  + self.gamma_3 * self.lr_global * ones)

        if self.hd_decay:
            # Blend the combined rate with the base rate; the weight ``tau``
            # shrinks over time when ``hd_decay_coef`` > 0.
            self.lr_list_para_2[i] = tau * self.lr_list_para_1[i] + (1 - tau) * self.lr

        self.gamma_1 = self.gamma_1 + tau * torch.sum(self.delta * h_para * self.lr_list_para[i])
        self.gamma_2 = self.gamma_2 + tau * torch.sum(self.delta * h_para * self.lr_list[i])
        self.gamma_3 = self.gamma_3 + tau * torch.sum(self.delta * h_para * self.lr_global)

        sum_gamma = self.gamma_1 + self.gamma_2 + self.gamma_3
        self.gamma_1 = self._clip_unit(self.gamma_1 / sum_gamma)
        self.gamma_2 = self._clip_unit(self.gamma_2 / sum_gamma)
        self.gamma_3 = self._clip_unit(self.gamma_3 / sum_gamma)
        return [self.gamma_1, self.gamma_2, self.gamma_3]

    def _hd_layer_global(self, group, p, i, h_layer, h_global, total_numel, tau):
        """HD update of parameter ``i``'s layer rate, the global rate and ``gamma_1..2``.

        Returns the new ``[gamma_1, gamma_2]``.
        """
        hd_lr = group['hypergrad_lr']

        if not self.update_from_combination:
            self.lr_list[i] += tau * hd_lr * h_layer * self.gamma_1 * total_numel / p.numel()
        else:
            self.lr_list[i] = self.lr_list_1[i] + tau * hd_lr * h_layer * self.gamma_1 * total_numel / p.numel()
            self.lr_global = torch.mean(self.lr_list_1[i]) + tau * hd_lr * h_global * self.gamma_2

        self.lr_list_1[i] = self.gamma_1 * self.lr_list[i] + self.gamma_2 * self.lr_global

        if self.hd_decay:
            self.lr_list_2[i] = tau * self.lr_list_1[i] + (1 - tau) * self.lr

        self.gamma_1 = self.gamma_1 + tau * self.delta * h_layer * self.lr_list[i]
        self.gamma_2 = self.gamma_2 + tau * self.delta * h_layer * self.lr_global

        # Take the sum *before* rescaling so both weights share one denominator.
        sum_gamma = self.gamma_1 + self.gamma_2
        self.gamma_1 = self._clip_unit(self.gamma_1 / sum_gamma)
        self.gamma_2 = self._clip_unit(self.gamma_2 / sum_gamma)
        return [self.gamma_1, self.gamma_2]

    def _apply_update(self, p, i, exp_avg, denom, bias_correction1, bias_correction2):
        """Apply the bias-corrected Adam step to ``p`` using the active level's rate."""
        if self.level == "para_layer_global":
            lr = self.lr_list_para_2[i] if self.hd_decay else self.lr_list_para_1[i]
            step_size = lr * math.sqrt(bias_correction2) / bias_correction1
            if self.cuda:
                step_size = step_size.cuda()
        elif self.level == "layer_global":
            lr = self.lr_list_2[i] if self.hd_decay else self.lr_list_1[i]
            step_size = lr * math.sqrt(bias_correction2) / bias_correction1
        elif self.level is None:
            step_size = self.lr * math.sqrt(bias_correction2) / bias_correction1
        else:
            print("level not defined!")
            return
        p.data.add_(-step_size * (exp_avg / denom))

    def _print_lr(self):
        """Print the combined per-parameter rates of the active level."""
        if hasattr(self, "lr_list_1"):
            print("lr_list_1", self.lr_list_1)
        elif hasattr(self, "lr_list_para_1"):
            print("lr_list_para_1", self.lr_list_para_1)
        else:
            print("lr", self.lr)

    def step(self, closure=None):
        """Perform one optimisation step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss.

        Returns:
            tuple: ``(loss, gamma_list)``. ``loss`` is the closure's result (or
            None). ``gamma_list`` holds the mixing weights after the last
            hypergradient update in this step; it is empty if no update happened.

        Raises:
            RuntimeError: if a gradient is sparse.
        """
        loss = None
        if closure is not None:
            loss = closure()

        # Weight of the HD updates; with ``hd_decay`` it is exp(-coef * timestep).
        tau = np.exp(-self.hd_decay_coef * self.timestep) if self.hd_decay else 1

        h_global, h_len_list = 0, []
        if self.timestep > 0:
            h_global, h_len_list = self.get_global_h()
            if self.level == "para_layer_global":
                self.lr_global += tau * self.hypergrad_lr * h_global * self.gamma_3
            if self.level == "layer_global":
                self.lr_global += tau * self.hypergrad_lr * h_global * self.gamma_2
        # Elements that contributed to ``h_global``; per-tensor HD updates are
        # rescaled by ``total_numel / p.numel()``.
        total_numel = sum(h_len_list)

        gamma_list = []
        for group in self.param_groups:
            beta1, beta2 = group['betas']
            # ``i`` indexes the per-parameter rate buffers, so it must advance
            # for every parameter, including those skipped for lack of a gradient.
            for i, p in enumerate(group['params']):
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p.data)
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                state['step'] += 1

                if group['weight_decay'] != 0:
                    grad = grad.add(p.data, alpha=group['weight_decay'])

                if state['step'] > 1 and self.level in ("para_layer_global", "layer_global"):
                    # Pair the current gradient with the previous step's Adam
                    # direction (moments not yet updated), hence ``step - 1``.
                    prev_bias_correction1 = 1 - beta1 ** (state['step'] - 1)
                    prev_bias_correction2 = 1 - beta2 ** (state['step'] - 1)
                    ratio = torch.div(exp_avg, exp_avg_sq.sqrt().add_(group['eps']))
                    h_layer = torch.dot(grad.reshape(-1), ratio.reshape(-1)) * math.sqrt(prev_bias_correction2) / prev_bias_correction1
                    if self.level == "para_layer_global":
                        h_para = grad * ratio * math.sqrt(prev_bias_correction2) / prev_bias_correction1
                        gamma_list = self._hd_para_layer_global(group, p, i, h_para, h_layer, h_global, total_numel, tau)
                    else:
                        gamma_list = self._hd_layer_global(group, p, i, h_layer, h_global, total_numel, tau)

                # Standard Adam moment updates.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                self._apply_update(p, i, exp_avg, denom, bias_correction1, bias_correction2)

            if self.print_gamma:
                print("self.gamma_1, self.gamma_2", self.gamma_1, self.gamma_2)
            if self.print_lr:
                self._print_lr()

        self.timestep = self.timestep + 1
        return loss, gamma_list
    
class Adam_CAM_HD_1(Optimizer):
    
    def __init__(self, params, args):
        """Read hyper-parameters from ``args`` and allocate the learning-rate buffers.

        Args:
            params: iterable of parameters (or param-group dicts) to optimise.
            args: dict that must provide ``lr``, ``betas``, ``eps``,
                ``weight_decay``, ``hypergrad_lr``, ``level``,
                ``reg_lr_layer``, ``reg_lr_unit``, ``reg_lr_ts``, ``kappa``,
                ``cuda``, ``update_from_combination``, ``gamma_1``,
                ``gamma_2``, ``gamma_3``, ``delta``, ``print_gamma`` and
                ``hd_decay``.

        Buffers by ``level``: ``"layer"``/``"layer_global"`` -> ``lr_list`` and
        ``lr_list_1`` (one rate per parameter tensor); ``"unit"`` -> ``lr_list``
        (per slice along dim 1, or per element for tensors with < 2 dims);
        ``"para"`` and ``"para_layer_global"`` -> per-element ``lr_list_para``
        and ``lr_list_para_1`` (plus a per-tensor ``lr_list`` for the latter).
        """
        self.lr = args["lr"]
        self.betas = args["betas"]
        self.eps = args["eps"]
        self.weight_decay = args["weight_decay"]
        self.hypergrad_lr = args["hypergrad_lr"]  # step size of the HD rate updates
        self.level = args["level"]
        self.reg_lr_layer = args["reg_lr_layer"]
        self.reg_lr_unit = args["reg_lr_unit"]
        self.reg_lr_ts = args["reg_lr_ts"]
        self.kappa = args["kappa"]
        self.cuda = args["cuda"]
        self.update_from_combination = args["update_from_combination"]

        # Mixing weights of the element / layer / global rates and the step
        # size ``delta`` of their own updates.
        self.gamma_1 = args["gamma_1"]
        self.gamma_2 = args["gamma_2"]
        self.gamma_3 = args["gamma_3"]
        self.delta = args["delta"]
        self.print_gamma = args["print_gamma"]
        self.hd_decay = args["hd_decay"]
        self.timestep = 0

        defaults = dict(lr=self.lr, betas=self.betas, eps=self.eps,
                        weight_decay=self.weight_decay, hypergrad_lr=self.hypergrad_lr)
        super(Adam_CAM_HD_1, self).__init__(params, defaults)
        self._params = self.param_groups[0]['params']
        self._params_numel = reduce(lambda total, p: total + p.numel(), self._params, 0)

        self.lr_global = self.lr
        if self.level == "layer" or self.level == "layer_global":
            self.lr_list = self.lr * torch.ones(len(self._params))
            self.lr_list_1 = self.lr * torch.ones(len(self._params))
        elif self.level == "unit":
            self.lr_list = []
            for group in self.param_groups:
                for p in group['params']:
                    # One rate per slice along dim 1 when it exists, else one per element.
                    if p.data.dim() >= 2:
                        lr_item = self.lr * torch.ones(p.data.shape[1])
                    else:
                        lr_item = self.lr * torch.ones(p.data.shape)
                    self.lr_list.append(lr_item)
        elif self.level == "para":
            self.lr_list_para = []
            self.lr_list_para_1 = []
            for group in self.param_groups:
                for p in group['params']:
                    lr_item = self.lr * torch.ones(p.data.shape)
                    # NOTE: the same tensor object is stored in both lists.
                    self.lr_list_para.append(lr_item)
                    self.lr_list_para_1.append(lr_item)
        elif self.level == "para_layer_global":
            self.lr_list = self.lr * torch.ones(len(self._params))
            if self.cuda:
                self.lr_list = self.lr_list.cuda()

            self.lr_list_para = []
            self.lr_list_para_1 = []
            for group in self.param_groups:
                for p in group['params']:
                    lr_item = self.lr * torch.ones(p.data.shape)
                    if self.cuda:
                        lr_item = lr_item.cuda()
                    # NOTE: the same tensor object is stored in both lists;
                    # ``step`` later rebinds ``lr_list_para_1[i]`` to a new tensor.
                    self.lr_list_para.append(lr_item)
                    self.lr_list_para_1.append(lr_item)

        self.reg_lr_layer_list = self.reg_lr_layer * torch.ones(len(self._params))
        if self.cuda:
            self.reg_lr_layer_list = self.reg_lr_layer_list.cuda()
        
    def get_global_h(self):
        """Return the global hypergradient and the element counts it was built from.

        For every parameter with a gradient, optimizer state is initialised if
        missing (``step`` = 1, zero moment buffers). Parameters with
        ``state['step'] > 1`` contribute
        ``<grad, m / (sqrt(v) + eps)> * sqrt(1 - beta2**(t-1)) / (1 - beta1**(t-1))``.
        Here ``grad`` includes weight decay when it is non-zero.

        Returns:
            tuple: ``(h_global, h_len_list)`` for the last param group.
            ``h_global`` is the summed hypergradient (``0`` if nothing
            contributed). ``h_len_list`` is the element count of each
            contributing parameter, in order.

        Raises:
            RuntimeError: if a gradient is sparse.
        """
        h_global, h_len_list = 0, []
        for group in self.param_groups:
            # Only the last group's values are returned (one group is assumed).
            h = 0
            h_len_list = []
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 1
                    state['exp_avg'] = torch.zeros_like(p.data)
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                if group['weight_decay'] != 0:
                    grad = grad.add(p.data, alpha=group['weight_decay'])

                if state['step'] > 1:
                    prev_bias_correction1 = 1 - beta1 ** (state['step'] - 1)
                    prev_bias_correction2 = 1 - beta2 ** (state['step'] - 1)
                    grad_flat = grad.reshape(-1)
                    # Previous Adam direction m / (sqrt(v) + eps), computed once.
                    direction = torch.div(exp_avg, exp_avg_sq.sqrt().add_(group['eps'])).reshape(-1)
                    h += torch.dot(grad_flat, direction) * math.sqrt(prev_bias_correction2) / prev_bias_correction1
                    h_len_list.append(grad_flat.numel())
            h_global = h
        return h_global, h_len_list
        
    def get_layer_h_list(self):
        """Collect one scalar hypergradient per parameter tensor.

        Parameters without a gradient are skipped. A parameter with no
        optimizer state yet gets it initialised here (``step`` = 1, zero moment
        buffers). Only parameters whose ``state['step']`` exceeds 1 contribute.

        NOTE(review): unlike ``get_global_h``, weight decay is not folded into
        the gradient here — confirm that this is intended.

        Returns:
            list[float]: for each contributing parameter, across all param
            groups and in order,
            ``<grad, m / (sqrt(v) + eps)> * sqrt(1 - beta2**(t-1)) / (1 - beta1**(t-1))``.
        """
        layer_hs = []
        for group in self.param_groups:
            beta1, beta2 = group['betas']
            for param in group['params']:
                if param.grad is None:
                    continue
                g = param.grad.data
                st = self.state[param]
                if not st:
                    st['step'] = 1
                    st['exp_avg'] = torch.zeros_like(param.data)
                    st['exp_avg_sq'] = torch.zeros_like(param.data)

                t = st['step']
                if t <= 1:
                    # No previous Adam direction to pair the gradient with.
                    continue

                corr1 = 1 - beta1 ** (t - 1)
                corr2 = 1 - beta2 ** (t - 1)
                adam_dir = torch.div(st['exp_avg'], st['exp_avg_sq'].sqrt().add_(group['eps'])).view(-1)
                h = torch.dot(g.view(-1), adam_dir) * math.sqrt(corr2) / corr1
                layer_hs.append(h.item())
        return layer_hs
    
    def get_unit_h_list(self):
        """Return per-unit hypergradients for each parameter of the last param group.

        A "unit" is a slice along dim 1 (a column of a 2-D weight). This matches
        the ``level == "unit"`` buffers built in ``__init__``: ``shape[1]`` rates
        for tensors with at least two dims, one rate per element otherwise.
        For each unit the value is ``sum(grad * m / (sqrt(v) + eps))`` over the
        unit's elements, scaled by ``sqrt(1 - beta2**(t-1)) / (1 - beta1**(t-1))``.

        Parameters without a gradient are skipped. A parameter with no state
        gets it initialised (``step`` = 0, zero moment buffers). Only
        parameters with ``state['step'] > 1`` contribute an entry, so the
        result may be shorter than the parameter list.

        Returns:
            list[torch.Tensor]: one tensor per contributing parameter.
        """
        h_layer_list = []
        for group in self.param_groups:
            # Only the last group's entries are returned (one group is assumed).
            h_layer_list = []
            beta1, beta2 = group['betas']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p.data)
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                if state['step'] <= 1:
                    # No previous Adam direction yet; contribute nothing.
                    continue

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                prev_bias_correction1 = 1 - beta1 ** (state['step'] - 1)
                prev_bias_correction2 = 1 - beta2 ** (state['step'] - 1)

                h_elem = grad * torch.div(exp_avg, exp_avg_sq.sqrt().add_(group['eps']))
                if grad.dim() >= 2:
                    # Reduce every axis except dim 1 -> one value per unit.
                    reduce_dims = tuple(d for d in range(grad.dim()) if d != 1)
                    h_unit = h_elem.sum(dim=reduce_dims)
                else:
                    h_unit = h_elem
                h_layer_list.append(h_unit * math.sqrt(prev_bias_correction2) / prev_bias_correction1)
        return h_layer_list
        
    def get_para_h_list(self):
        """Return element-wise hypergradients for the parameters of the last group.

        For each parameter with a gradient and ``state['step'] > 1`` the entry
        is ``grad * m / (sqrt(v) + eps) * sqrt(1 - beta2**(t-1)) / (1 - beta1**(t-1))``.
        It has the same shape as the parameter, and weight decay is not applied.
        A parameter lacking optimizer state gets it initialised (``step`` = 0,
        zero moment buffers).

        Returns:
            list[torch.Tensor]: entries for the last param group only.
        """
        result = []
        for group in self.param_groups:
            result = []  # earlier groups are discarded; a single group is assumed
            beta1, beta2 = group['betas']
            for param in group['params']:
                if param.grad is None:
                    continue
                st = self.state[param]
                if not st:
                    st['step'] = 0
                    st['exp_avg'] = torch.zeros_like(param.data)
                    st['exp_avg_sq'] = torch.zeros_like(param.data)

                t = st['step']
                if t <= 1:
                    continue

                corr1 = 1 - beta1 ** (t - 1)
                corr2 = 1 - beta2 ** (t - 1)
                direction = torch.div(st['exp_avg'], st['exp_avg_sq'].sqrt().add_(group['eps']))
                result.append(param.grad.data * direction * math.sqrt(corr2) / corr1)
        return result
        
    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()
            
        h_global, h_len_list = self.get_global_h()
        h_layer_list = self.get_layer_h_list()
        h_para_list = self.get_para_h_list()

        for group in self.param_groups:    
            
            i = 0
            for p in group['params']:
                
                # print("p.data.shape", p.data.shape)
                
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 1
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                if group['weight_decay'] != 0:
                    grad = grad.add(group['weight_decay'], p.data)

                if state['step'] > 1:
                    prev_bias_correction1 = 1 - beta1 ** (state['step'] - 1)
                    prev_bias_correction2 = 1 - beta2 ** (state['step'] - 1)
                    # Hypergradient for Adam:
                    
                    if self.level == "para_layer_global":
                        
                        #if self.print_gamma == True:
                            #print("self.gamma_1, self.gamma_2, self.gamma_3", self.gamma_1, self.gamma_2, self.gamma_3)
                        
                        n_para_layer = p.data.numel()
                        
                        if self.update_from_combination == False:
                            
                            d_gamma_1_term = np.exp(self.gamma_1)/(np.exp(self.gamma_1) + np.exp(self.gamma_2) + np.exp(self.gamma_3))
                            d_gamma_1 = d_gamma_1_term*(1-d_gamma_1_term)
                            d_gamma_2_term = np.exp(self.gamma_2)/(np.exp(self.gamma_1) + np.exp(self.gamma_2) + np.exp(self.gamma_3))
                            d_gamma_2 = d_gamma_2_term*(1-d_gamma_2_term)
                            d_gamma_3_term = np.exp(self.gamma_3)/(np.exp(self.gamma_1) + np.exp(self.gamma_2) + np.exp(self.gamma_3))
                            d_gamma_3 = d_gamma_3_term*(1-d_gamma_3_term)
                            
                            d_lr_para = d_gamma_1_term
                            d_lr_list = d_gamma_2_term
                            d_lr_global = d_gamma_3_term
                            
                            self.lr_list_para[i] += group['hypergrad_lr'] * h_para_list[i] * d_lr_para * n_para_layer * np.sum(h_len_list)/h_len_list[i]
                            self.lr_list[i] += group['hypergrad_lr'] * h_layer_list[i]  * d_lr_list * np.sum(h_len_list)/h_len_list[i]
                            self.lr_global += group['hypergrad_lr'] * h_global * d_lr_global
                        else:
                            self.lr_list_para[i] = self.lr_list_para_1[i] + group['hypergrad_lr'] * h_para_list[i] * self.gamma_1 * n_para_layer * np.sum(h_len_list)/h_len_list[i]
                            self.lr_list[i] = torch.mean(self.lr_list_para_1[i]) + group['hypergrad_lr'] * h_layer_list[i]  * self.gamma_2 * np.sum(h_len_list)/h_len_list[i]
                            self.lr_global =  torch.mean(torch.tensor([torch.mean(self.lr_list_para_1[i]) for i in range(len(self.lr_list_para_1))])) + group['hypergrad_lr'] * h_global * self.gamma_3
                    
                        if self.cuda:
                            self.lr_list_para_1[i] = d_gamma_1_term*self.lr_list_para[i] + d_gamma_2_term*self.lr_list[i] * torch.ones(p.data.shape).cuda() + d_gamma_3_term * self.lr_global * torch.ones(p.data.shape).cuda()                             
                        else:
                            self.lr_list_para_1[i] = self.gamma_1*self.lr_list_para[i] + self.gamma_2*self.lr_list[i] * torch.ones(p.data.shape) + self.gamma_3 * self.lr_global * torch.ones(p.data.shape) 
                            
                        if self.hd_decay == True:
                            self.lr_list_para_1[i] = np.exp(-self.timestep) * self.lr_list_para_1[i]
                            + (1 - np.exp(-self.timestep)) * self.lr
                    
                        self.gamma_1 = self.gamma_1 + torch.sum(self.delta*h_para_list[i]*d_gamma_1) #-self.lr_global)
                        self.gamma_2 = self.gamma_2 + torch.sum(self.delta*h_para_list[i]*d_gamma_2)
                        self.gamma_3 = self.gamma_3 + torch.sum(self.delta*h_para_list[i]*d_gamma_3)
                    
                    if self.level == "layer_global": # try to only involve fully connected layer in CNNs
      
                        if self.update_from_combination == False:
                            
                            try:
                                d_gamma_1_term = np.exp(self.gamma_1)/(np.exp(self.gamma_1) + np.exp(self.gamma_2))            
                                d_gamma_2_term = np.exp(self.gamma_2)/(np.exp(self.gamma_1) + np.exp(self.gamma_2))
                            except:
                                d_gamma_1_term = torch.exp(self.gamma_1)/(torch.exp(self.gamma_1) + torch.exp(self.gamma_2))            
                                d_gamma_2_term = torch.exp(self.gamma_2)/(torch.exp(self.gamma_1) + torch.exp(self.gamma_2))                                
                                
                            d_gamma_1 = d_gamma_1_term*(1-d_gamma_1_term)
                            d_gamma_2 = d_gamma_2_term*(1-d_gamma_2_term)
                            
                            d_lr_list = d_gamma_1_term
                            d_lr_global = d_gamma_2_term
    
                            self.lr_list[i] += group['hypergrad_lr'] * h_layer_list[i] * d_lr_list * np.sum(h_len_list)/h_len_list[i]
                            self.lr_global += group['hypergrad_lr'] * h_global * d_lr_global
                        else:
                            self.lr_list[i] = self.lr_list_1[i] + group['hypergrad_lr'] * h_layer_list[i] * d_lr_list * np.sum(h_len_list)/h_len_list[i]
                            self.lr_global = torch.mean(self.lr_list_1[i]) + group['hypergrad_lr'] * h_global * d_lr_global
                        
                        if self.cuda:
                            self.lr_list_1[i] = d_gamma_1_term* self.lr_list[i].cuda() + d_gamma_2_term*self.lr_global.cuda()
                        else:
                            self.lr_list_1[i] = d_gamma_1_term * self.lr_list[i] + d_gamma_2_term * self.lr_global
                        
                        self.gamma_1 = self.gamma_1 + self.delta*h_layer_list[i]*d_gamma_1*self.lr_list[i] 
                        self.gamma_2 = self.gamma_2 + self.delta*h_layer_list[i]*d_gamma_2*self.lr_global
                            
                    if self.level == "layer":
                        self.lr_list[i] += group['hypergrad_lr'] * h_layer_list[i]
                        self.lr_list_1[i] = self.lr_list[i]

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                
                if self.level == "para_layer_global":
                    step_size = self.lr_list_para_1[i] * math.sqrt(bias_correction2) / bias_correction1
                    if self.cuda:
                        step_size = step_size.cuda()
                    p.data.add_(- step_size* (exp_avg/denom))  
                else:
                    step_size = self.lr_list_1[i] * math.sqrt(bias_correction2) / bias_correction1
                    p.data.addcdiv_(-step_size, exp_avg, denom)
                
                i = i + 1
                state['step'] += 1
            
            if self.print_gamma == True:
                print("self.gamma_1, self.gamma_2", self.gamma_1, self.gamma_2)    
        self.timestep = self.timestep + 1

        return loss
    

    