import math
import numpy as np
import copy
import torch
from functools import reduce
from torch.optim.optimizer import Optimizer, required

class SGD_CAM_HD(Optimizer):
    """SGD whose learning rates are adapted online by hypergradient descent.

    Learning rates are kept at up to three granularities -- one per parameter
    element (``lr_list_para``), one per parameter tensor (``lr_list``) and one
    global value (``lr_global``) -- and, depending on ``args["level"]``, are
    mixed with the weights ``gamma_1`` / ``gamma_2`` / ``gamma_3``.  Each rate
    is moved by ``hypergrad_lr`` times the product of the current and the
    previous update direction (the hypergradient signal ``h``); the gammas are
    themselves moved by ``delta`` times a similar signal and renormalised.

    ``args`` is a dict that must contain every key read in ``__init__``:
    ``lr``, ``momentum``, ``weight_decay``, ``dampening``, ``nesterov``,
    ``hypergrad_lr``, ``level``, ``reg_lr_layer``, ``reg_lr_unit``,
    ``reg_lr_ts``, ``kappa``, ``gamma_1``, ``gamma_2``, ``gamma_3``, ``delta``,
    ``print_gamma``, ``update_from_combination``, ``hd_decay``,
    ``hd_decay_coef``, ``method`` and ``cuda``.  ``reg_lr_unit``,
    ``reg_lr_ts``, ``kappa`` and ``method`` are stored but not used here.

    ``step`` supports ``level`` in {"layer", "layer_global",
    "para_layer_global"}.  NOTE(review): "unit" and "para" build their own
    rate lists in ``__init__`` but ``step`` reads ``lr_list_1``/``lr_list_2``,
    which those levels never create.

    The per-layer bookkeeping (``self._params`` from group 0, per-group index
    ``i``) assumes a single parameter group.
    """

    def __init__(self, params, args):
        self.lr = args["lr"]
        self.momentum = args["momentum"]
        self.weight_decay = args["weight_decay"]
        self.dampening = args["dampening"]
        self.nesterov = args["nesterov"]

        self.hypergrad_lr = args["hypergrad_lr"]
        self.level = args["level"]
        self.reg_lr_layer = args["reg_lr_layer"]
        self.reg_lr_unit = args["reg_lr_unit"]
        self.reg_lr_ts = args["reg_lr_ts"]
        self.kappa = args["kappa"]

        # Mixing weights of the per-element / per-layer / global rates.
        self.gamma_1 = args["gamma_1"]
        self.gamma_2 = args["gamma_2"]
        self.gamma_3 = args["gamma_3"]
        # Step size used when the gammas themselves are updated in step().
        self.delta = args["delta"]
        self.print_gamma = args["print_gamma"]
        self.update_from_combination = args["update_from_combination"]

        self.hd_decay = args["hd_decay"]
        self.hd_decay_coef = args["hd_decay_coef"]
        self.method = args["method"]
        self.cuda = args["cuda"]

        if self.lr is not required and self.lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(self.lr))
        if self.momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(self.momentum))
        if self.weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(self.weight_decay))

        defaults = dict(lr=self.lr, momentum=self.momentum, dampening=self.dampening,
                        weight_decay=self.weight_decay, nesterov=self.nesterov,
                        hypergrad_lr=self.hypergrad_lr)

        # Previously referenced an undefined class name (SGD_HD_1), which made
        # construction fail with NameError.
        super(SGD_CAM_HD, self).__init__(params, defaults)
        self._params = self.param_groups[0]['params']
        self._params_numel = reduce(lambda total, p: total + p.numel(), self._params, 0)

        self.lr = torch.tensor(self.lr)
        # A separate tensor: lr_global evolves over time, while self.lr must
        # remain the initial rate (it anchors the hd_decay blend in step()).
        self.lr_global = self.lr.clone()
        self.timestep = 0
        n_params = len(self._params)

        if self.level in ("layer", "layer_global"):
            self.lr_list = self.lr * torch.ones(n_params)
            self.lr_list_1 = self.lr * torch.ones(n_params)
            self.lr_list_2 = copy.deepcopy(self.lr_list_1)
        elif self.level == "unit":
            self.lr_list = []
            for group in self.param_groups:
                for p in group['params']:
                    # One rate per slice along dim 1 for tensors with >= 2
                    # dims, otherwise one rate per element.
                    if p.data.dim() >= 2:
                        self.lr_list.append(self.lr * torch.ones(p.data.shape[1]))
                    else:
                        self.lr_list.append(self.lr * torch.ones(p.data.shape))
        elif self.level == "para":
            self.lr_list_para = []
            self.lr_list_para_1 = []
            for group in self.param_groups:
                for p in group['params']:
                    lr_item = self.lr * torch.ones(p.data.shape)
                    self.lr_list_para.append(lr_item)
                    # Independent copy so in-place updates of one list never
                    # leak into the other.
                    self.lr_list_para_1.append(lr_item.clone())
        elif self.level == "para_layer_global":
            self.lr_list = self.lr * torch.ones(n_params)
            if self.cuda:
                self.lr_list = self.lr_list.cuda()

            self.lr_list_para = []
            self.lr_list_para_1 = []
            for group in self.param_groups:
                for p in group['params']:
                    lr_item = self.lr * torch.ones(p.data.shape)
                    if self.cuda:
                        lr_item = lr_item.cuda()
                    self.lr_list_para.append(lr_item)
                    self.lr_list_para_1.append(lr_item.clone())

            self.lr_list_para_2 = copy.deepcopy(self.lr_list_para_1)

        self.reg_lr_layer_list = self.reg_lr_layer * torch.ones(n_params)
        if self.cuda:
            self.reg_lr_layer_list = self.reg_lr_layer_list.cuda()

    def _descent_direction(self, p, group, update_buffer):
        """Return the SGD direction for ``p`` (weight decay + momentum applied).

        When ``update_buffer`` is False the momentum buffer is only previewed
        (nothing is written to ``self.state``), so calling this from
        ``get_global_h`` before ``step`` does not advance momentum twice.

        Raises:
            RuntimeError: if ``p.grad`` is sparse.
        """
        d_p = p.grad.data
        if d_p.is_sparse:
            raise RuntimeError('SGD_CAM_HD does not support sparse gradients')

        weight_decay = group['weight_decay']
        momentum = group['momentum']
        dampening = group['dampening']

        if weight_decay != 0:
            d_p = d_p.add(p.data, alpha=weight_decay)

        if momentum != 0:
            state = self.state[p]
            buf = state.get('momentum_buffer')
            if buf is None:
                buf = torch.clone(d_p).detach()
                if update_buffer:
                    state['momentum_buffer'] = buf
            elif update_buffer:
                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
            else:
                buf = buf.mul(momentum).add(d_p, alpha=1 - dampening)
            if group['nesterov']:
                d_p = d_p.add(buf, alpha=momentum)
            else:
                d_p = buf
        return d_p

    @staticmethod
    def _clip_unit(value):
        """Clamp ``value`` into [0, 1] (returns the int 0/1 at the bounds)."""
        if value < 0:
            return 0
        if value > 1:
            return 1
        return value

    def get_global_h(self):
        """Return the global hypergradient and per-parameter element counts.

        For each parameter with a gradient, the direction ``step`` is about to
        apply is dotted with ``state['grad_prev_para']`` (zeros when absent).
        Optimizer state is not modified.

        Returns:
            tuple: ``(h_global, h_len_list)`` -- the summed dot products of the
            last parameter group (``0`` if no parameter has a gradient) and
            the element count of each contributing parameter.
        """
        h = 0
        h_len_list = []
        for group in self.param_groups:
            h = 0
            h_len_list = []
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = self._descent_direction(p, group, update_buffer=False)
                d_p_prev = self.state[p].get('grad_prev_para')
                if d_p_prev is None:
                    d_p_prev = torch.zeros_like(d_p)
                h_len_list.append(d_p.numel())
                h += torch.dot(d_p.reshape(-1), d_p_prev.reshape(-1))
        return h, h_len_list

    def get_layer_h_list(self):
        """Return ``dot(grad, previous grad)`` per parameter as Python floats.

        Only parameters whose ``state['step'] > 1`` contribute; parameters
        without a step counter get ``step = 1``.  The current gradient (copied)
        is stored as ``state['grad_prev_layer']`` for the next call.
        """
        h_layer_list = []
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]
                state.setdefault('step', 1)
                grad_prev = state.get('grad_prev_layer')
                if grad_prev is None:
                    grad_prev = torch.zeros_like(grad)
                if state['step'] > 1:
                    h_layer_list.append(torch.dot(grad.reshape(-1), grad_prev.reshape(-1)).item())
                # Copy: p.grad may be zeroed/accumulated in place later.
                state['grad_prev_layer'] = grad.clone()
        return h_layer_list

    def get_unit_h_list(self):
        """Return per-unit hypergradients for every parameter with a gradient.

        A unit is a slice along dim 1 for tensors with two or more dims
        (matching the per-unit rates built for level "unit") and a single
        element otherwise; each entry is the sum of ``grad * previous_grad``
        over that unit.  Only parameters whose ``state['step'] > 1`` produce an
        entry.  The current gradient (copied) is stored as
        ``state['grad_prev_unit']``.
        """
        h_layer_list = []
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]
                step = state.get('step', 1)
                grad_prev = state.get('grad_prev_unit')
                if step == 1 or grad_prev is None:
                    grad_prev = torch.zeros_like(grad)
                state['grad_prev_unit'] = grad.clone()

                if step > 1:
                    prod = grad * grad_prev
                    if prod.dim() >= 2:
                        other_dims = tuple(d for d in range(prod.dim()) if d != 1)
                        h_layer_list.append(prod.sum(dim=other_dims))
                    else:
                        h_layer_list.append(prod.reshape(-1))
        return h_layer_list

    def get_para_h_list(self):
        """Return element-wise ``grad * previous grad`` per parameter.

        Only parameters whose ``state['step'] > 1`` produce an entry.  The
        current gradient (copied) is stored as ``state['grad_prev_para']`` --
        the same key ``step`` uses for its previous update direction.
        """
        h_para_list = []
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]
                step = state.get('step', 1)
                grad_prev = state.get('grad_prev_para')
                if step == 1 or grad_prev is None:
                    grad_prev = torch.zeros_like(grad)
                if step > 1:
                    h_para_list.append(grad * grad_prev)
                state['grad_prev_para'] = grad.clone()
        return h_para_list

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            tuple: ``(loss, gamma_list)`` -- the closure's result (or ``None``)
            and the gammas after this step for levels "layer_global" /
            "para_layer_global" (an empty list otherwise, or before the first
            hypergradient update).
        """
        loss = None
        if closure is not None:
            loss = closure()

        # Optional exponential decay of the hypergradient influence over time.
        if self.hd_decay:
            tau = np.exp(-self.hd_decay_coef * self.timestep)
        else:
            tau = 1

        # Computed every step (no state side effects) so h_len_list is always
        # defined for the per-parameter updates below.
        h_global, h_len_list = self.get_global_h()
        if self.timestep > 0:
            # Out-of-place so the update never aliases another tensor.
            if self.level == "para_layer_global":
                self.lr_global = self.lr_global + tau * self.hypergrad_lr * h_global * self.gamma_3
            if self.level == "layer_global":
                self.lr_global = self.lr_global + tau * self.hypergrad_lr * h_global * self.gamma_2

        gamma_list = []

        for group in self.param_groups:
            hypergrad_lr = group['hypergrad_lr']
            i = 0

            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]
                state['step'] = state.get('step', 0) + 1
                d_p = self._descent_direction(p, group, update_buffer=True)

                if state['step'] > 1:
                    d_p_prev = state.get('grad_prev_para')
                    if d_p_prev is None:
                        d_p_prev = torch.zeros_like(d_p)

                    h_layer = torch.dot(d_p.reshape(-1), d_p_prev.reshape(-1))
                    h_para = d_p * d_p_prev
                    # Copy: d_p may be the momentum buffer or p.grad itself,
                    # both of which are modified in place on later steps.
                    state['grad_prev_para'] = d_p.clone()

                    if self.level == "para_layer_global":
                        n_para_layer = p.data.numel()
                        layer_scale = np.sum(h_len_list) / h_len_list[i]

                        if not self.update_from_combination:
                            self.lr_list_para[i] += tau * hypergrad_lr * h_para * self.gamma_1 * n_para_layer * layer_scale
                            self.lr_list[i] += tau * hypergrad_lr * h_layer * self.gamma_2 * layer_scale
                        else:
                            self.lr_list_para[i] = self.lr_list_para_1[i] + tau * hypergrad_lr * h_para * self.gamma_1 * n_para_layer * layer_scale
                            self.lr_list[i] = torch.mean(self.lr_list_para_1[i]) + tau * hypergrad_lr * h_layer * self.gamma_2 * layer_scale
                            mean_para_lr = torch.stack([lr.mean() for lr in self.lr_list_para_1]).mean()
                            self.lr_global = mean_para_lr + tau * hypergrad_lr * h_global * self.gamma_3

                        ones = torch.ones(p.data.shape)
                        if self.cuda:
                            ones = ones.cuda()
                        self.lr_list_para_1[i] = (self.gamma_1 * self.lr_list_para[i]
                                                  + self.gamma_2 * self.lr_list[i] * ones
                                                  + self.gamma_3 * self.lr_global * ones)

                        if self.hd_decay:
                            self.lr_list_para_2[i] = tau * self.lr_list_para_1[i] + (1 - tau) * self.lr

                        self.gamma_1 = self.gamma_1 + tau * torch.sum(self.delta * h_para * self.lr_list_para[i])
                        self.gamma_2 = self.gamma_2 + tau * torch.sum(self.delta * h_para * self.lr_list[i])
                        self.gamma_3 = self.gamma_3 + tau * torch.sum(self.delta * h_para * self.lr_global)

                        sum_gamma = self.gamma_1 + self.gamma_2 + self.gamma_3
                        self.gamma_1 = self._clip_unit(self.gamma_1 / sum_gamma)
                        self.gamma_2 = self._clip_unit(self.gamma_2 / sum_gamma)
                        self.gamma_3 = self._clip_unit(self.gamma_3 / sum_gamma)

                        gamma_list = [self.gamma_1, self.gamma_2, self.gamma_3]

                    if self.level == "layer_global":  # try to only involve fully connected layer in CNNs
                        layer_scale = np.sum(h_len_list) / h_len_list[i]
                        if not self.update_from_combination:
                            self.lr_list[i] += tau * hypergrad_lr * h_layer * self.gamma_1 * layer_scale
                        else:
                            self.lr_list[i] = self.lr_list_1[i] + tau * hypergrad_lr * h_layer * self.gamma_1 * layer_scale
                            self.lr_global = torch.mean(self.lr_list_1[i]) + tau * hypergrad_lr * h_global * self.gamma_2

                        self.lr_list_1[i] = self.gamma_1 * self.lr_list[i] + self.gamma_2 * self.lr_global

                        if self.hd_decay:
                            self.lr_list_2[i] = tau * self.lr_list_1[i] + (1 - tau) * self.lr

                        self.gamma_1 = self.gamma_1 + tau * self.delta * h_layer * self.lr_list[i]
                        self.gamma_2 = self.gamma_2 + tau * self.delta * h_layer * self.lr_global

                        # Normalise both with the same pre-normalisation sum.
                        sum_gamma = self.gamma_1 + self.gamma_2
                        self.gamma_1 = self._clip_unit(self.gamma_1 / sum_gamma)
                        self.gamma_2 = self._clip_unit(self.gamma_2 / sum_gamma)

                        gamma_list = [self.gamma_1, self.gamma_2]

                    if self.level == "layer":
                        # h_layer is this parameter's hypergradient (the old
                        # code referenced an undefined h_layer_list).
                        self.lr_list[i] += hypergrad_lr * h_layer
                        self.lr_list_1[i] = self.lr_list[i]
                        if self.hd_decay:
                            self.lr_list_2[i] = tau * self.lr_list_1[i] + (1 - tau) * self.lr

                # Stored for compatibility; not read anywhere in this class.
                state['grad_prev'] = d_p

                if self.level == "para_layer_global":
                    lr = self.lr_list_para_2[i] if self.hd_decay else self.lr_list_para_1[i]
                else:
                    lr = self.lr_list_2[i] if self.hd_decay else self.lr_list_1[i]
                p.data.add_(-lr * d_p.view_as(p.data))

                i += 1

            if self.print_gamma:
                print("self.gamma_1, self.gamma_2", self.gamma_1, self.gamma_2)

        self.timestep += 1

        return loss, gamma_list

