
import time
import random
import logging
import numpy as np
import torch
import torch.optim as optim
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.decomposition import PCA

def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))`` of ``x``.

    ``x`` is negated and handed to ``np.exp``, so the result is whatever
    scalar or array NumPy produces for that input.
    """
    decay = np.exp(-x)
    return 1 / (1 + decay)


class OptSwitcher(object):

    def __init__(
        self,
        init_step=10,
        sw_step=10,
        optimizer_space=('SGD', 'SGDM', 'Adagrad'),
        random_ratio=0.1,
        random_schedule=20,
        lr=0.01,
        alpha=0.1,
        device='cpu',
    ):
        """Store the hyper-parameters and create the per-optimizer bookkeeping.

        Args:
            init_step: number of switch rounds that rely on weighted-random
                selection before the regressors are fitted and consulted.
            sw_step: number of ``recommend_optimizer`` calls between two
                switch decisions.
            optimizer_space: names of the candidate optimizers (see
                ``_get_optimizer`` for the recognised names).
            random_ratio: weight of the random branch when choosing between
                random and regressor-based selection after warm-up.
            random_schedule: after warm-up, ``random_ratio`` and ``trans_w``
                are halved whenever ``cur_iter`` is a multiple of this value.
            lr: learning rate handed to every optimizer that is built.
            alpha: weight of the score standard deviation in ``_eva_score``.
            device: device the batch is moved to in ``_transferEst``.
        """
        self.lr = lr
        self.device = device
        self.alpha = alpha
        self.init_step = init_step
        self.sw_step = sw_step
        # Keep a private list: the default is immutable and the caller's
        # sequence is not aliased, so later mutations cannot leak in or out.
        self.optimizer_space = list(optimizer_space)
        self.cur_iter = 0  # completed switch rounds
        self.cur_step = 0  # recommend_optimizer calls since the last round
        # One regressor, one [inputs, scores] history and one sampling weight
        # per candidate optimizer.
        self.recommender = {k: GaussianProcessRegressor() for k in self.optimizer_space}
        self.history = {k: [[], []] for k in self.optimizer_space}
        self.losses = []
        self.random_ws = {k: 1 for k in self.optimizer_space}
        self.random_ratio = random_ratio
        self.random_schedule = random_schedule
        
    

    """initialize"""
    def init(self, net, net_name, train_loader, val_loader):
        """Prepare the switcher for ``net`` before training starts.

        Stores ``net_name``, picks the first optimizer through
        ``_random_recommend``, caches the compressed network vector produced
        by ``_compress`` and finally runs ``_transferEst``, whose result is
        returned.
        """
        # The name must be known before _compress selects the network head.
        self.net_name = net_name
        self._random_recommend(net)
        compressed = self._compress(net)
        self.net_vec = compressed
        estimate = self._transferEst(net, train_loader, val_loader)
        return estimate


    """Transferability Estimation"""
    def _transferEst(self, net, train_loader, val_loader):
        """Estimate how well ``net`` transfers, using one training batch.

        Runs a gradient-free forward pass on the first batch of
        ``train_loader``, logs the LogME, LEEP, NCE and accuracy scores and
        stores their combination in ``self.trans_w``.

        Args:
            net: model to evaluate. It is called on the batch as a whole and
                its child modules (all but the last) are also applied one
                after another to obtain the features fed to LogME.
            train_loader: iterable of ``(x, y)`` batches; only the first
                batch is read.
            val_loader: not used by this method.

        Returns:
            None. The result is stored in ``self.trans_w``.
        """
        # Project-local scorers, imported lazily as this method always did.
        from LogME.LogME import LogME
        from LogME.LEEP import LEEP
        from LogME.NCE import NCE

        # NOTE(review): the network's train/eval mode is left untouched here;
        # confirm that the caller sets the mode it wants before calling.
        with torch.no_grad():
            x, y = next(iter(train_loader))
            x, y = x.to(self.device), y.to(self.device)

            y_pred = net(x)

            # Features for LogME: feed the batch through every child module
            # except the last one. The child list is built once instead of
            # once per loop iteration.
            children = list(net.children())
            f = x
            for layer in children[:-1]:
                f = layer(f)

            # Convert labels and predictions to NumPy once and reuse them.
            y_np = y.cpu().numpy()
            y_pred_np = y_pred.cpu().numpy()

            # LogME
            logme = LogME(regression=False)
            logme_score = logme.fit(f.reshape(f.shape[0], -1).cpu().numpy(), y_np)
            logging.warning('- LogME: {}'.format(logme_score))

            # LEEP
            leep_score = LEEP(y_pred_np, y_np)
            logging.warning('- LEEP: {}'.format(leep_score))

            # NCE
            nce_score = NCE(np.argmax(y_pred_np, axis=1), y_np)
            logging.warning('- NCE: {}'.format(nce_score))

            # Accuracy on this batch
            acc_score = (y == torch.argmax(y_pred, 1)).sum().item() / len(y)
            logging.warning('- Acc: {}'.format(acc_score))

            # Combine the scores: squashed LogME and NCE get weight 0.1 each,
            # accuracy gets 0.8; the weight is one minus that combination.
            self.trans_w = 1 - ((sigmoid(logme_score / 100) + sigmoid(nce_score)) * 0.1 + acc_score * 0.8)
            logging.warning('- Trans.W: {}'.format(self.trans_w))

    


    def _compress(self, net):
        n_components = 2
        net_vec = np.array([])

        if 'resnet' in self.net_name:
            for layer in net.fc.parameters(recurse=True):           
                layer_matrix = layer.clone().detach().cpu().numpy()
                layer_shape = list(layer_matrix.shape)
                
                if len(layer_shape)>2:
                    layer_matrix = layer_matrix.reshape((layer_shape[0]*layer_shape[2], -1))
                
                if len(layer_shape)==1:
                    layer_vec = layer_matrix
                else:
                    layer_vec = PCA(n_components = n_components).fit_transform(layer_matrix)
                    layer_vec = PCA(n_components = n_components).fit_transform(layer_vec.T).reshape(-1)
                
                net_vec = np.concatenate((net_vec, layer_vec))

        elif 'densenet' in self.net_name:
            for layer in net.classifier.parameters(recurse=True):           
                layer_matrix = layer.clone().detach().numpy()
                layer_shape = list(layer_matrix.shape)
                
                if len(layer_shape)>2:
                    layer_matrix = layer_matrix.reshape((layer_shape[0]*layer_shape[2], -1))
                
                if len(layer_shape)==1:
                    layer_vec = layer_matrix
                else:
                    layer_vec = PCA(n_components = n_components).fit_transform(layer_matrix)
                    layer_vec = PCA(n_components = n_components).fit_transform(layer_vec.T).reshape(-1)
                
                net_vec = np.concatenate((net_vec, layer_vec))

        elif 'mobilenet' in self.net_name:
            for layer in net.classifier[1].parameters(recurse=True):           
                layer_matrix = layer.clone().detach().numpy()
                layer_shape = list(layer_matrix.shape)
                
                if len(layer_shape)>2:
                    layer_matrix = layer_matrix.reshape((layer_shape[0]*layer_shape[2], -1))
                
                if len(layer_shape)==1:
                    layer_vec = layer_matrix
                else:
                    layer_vec = PCA(n_components = n_components).fit_transform(layer_matrix)
                    layer_vec = PCA(n_components = n_components).fit_transform(layer_vec.T).reshape(-1)
                
                net_vec = np.concatenate((net_vec, layer_vec))
        
        return net_vec
    


    def _get_optimizer(self, net, name='SGD'):

        if name == 'SGD':
            return optim.SGD(net.parameters(), lr=self.lr)
        elif name == 'SGDM':
            return optim.SGD(net.parameters(), lr=self.lr, momentum=0.9)
        elif name == 'Adagrad':
            return optim.Adagrad(net.parameters(), lr=self.lr)
        elif name == 'RMSprop':
            return optim.RMSprop(net.parameters(), lr=self.lr)
        elif name == 'Adam':
            return optim.Adam(net.parameters(), lr=self.lr)
        else:
            return optim.SGD(net.parameters(), lr=self.lr)  
    


    def _random_recommend(self, net):
        rn_ws = [self.random_ws[k] for k in self.random_ws]
        if (self.cur_iter+2>self.init_step) and (1 in rn_ws):
            self.name = self.optimizer_space[rn_ws.index(1)]
        else:
            self.name = random.choices(self.optimizer_space, rn_ws)[0]
        self.optimizer = self._get_optimizer(net, self.name)
        return
    


    def _our_recommend(self, net):
        """Pick the optimizer with the highest acquisition score and build it.

        Each candidate's regressor predicts a mean and a standard deviation
        for the current ``self.net_vec``; ``_acq_score`` turns the pair into
        a single score and the candidate with the largest score is selected.
        Sets ``self.name`` and ``self.optimizer``.
        """
        scores = []
        for candidate in self.optimizer_space:
            prediction = self.recommender[candidate].predict([self.net_vec], return_std=True)
            # prediction[0] holds the means, prediction[1] the standard
            # deviations; a single vector was queried, so take element 0.
            mean, std = prediction[0][0], prediction[1][0]
            scores.append(self._acq_score(mean, std))

        best = np.argmax(scores)
        self.name = self.optimizer_space[best]
        self.optimizer = self._get_optimizer(net, self.name)

    
    
    def _update_recommender(self, name):
        """Refit the regressor of optimizer ``name`` on its recorded history.

        ``self.history[name]`` holds two parallel lists: the network vectors
        and the scores observed for them. When no score has been recorded
        yet (the optimizer was never selected, or only produced NaN scores)
        the regressor is left untouched instead of being fitted on empty
        data.

        Args:
            name: key into ``self.history`` and ``self.recommender``.
        """
        inputs, targets = self.history[name]
        if not targets:
            # Nothing to learn from yet.
            # NOTE(review): the regressor then stays unfitted; confirm that
            # predicting from an unfitted regressor is acceptable downstream.
            return
        self.recommender[name].fit(inputs, targets)

    
    def _eva_score(self):
        """Score the current optimizer from the losses recorded so far.

        NaN entries in ``self.losses`` are replaced by the first recorded
        loss (``np.nan_to_num`` also maps infinities to large finite values).
        The score combines the mean relative change between consecutive
        losses, the relative change from the first loss to the largest and
        to the smallest later loss, and ``self.alpha`` times the standard
        deviation of the consecutive changes, squashed with ``tanh``.

        At least two losses must have been recorded. As a side effect
        ``self.losses`` is replaced by the cleaned NumPy array.

        Returns:
            Tuple ``(score_mean, score_std, score)``.
        """
        # assumes the recorded losses are plain numbers (or convertible to
        # float) -- TODO confirm against what callers pass to ``step``.
        losses = np.asarray(self.losses, dtype=float)
        # A NaN loss would otherwise propagate into every statistic below.
        losses = np.nan_to_num(losses, nan=losses[0])
        self.losses = losses

        prev, curr = losses[:-1], losses[1:]
        # Relative change between consecutive losses.
        scores = (curr - prev) / np.maximum(curr, prev)
        score_mean = scores.mean()
        score_std = scores.std()

        first, largest, smallest = losses[0], curr.max(), curr.min()
        score_upper = (first - largest) / max(first, largest)
        score_lower = (first - smallest) / max(first, smallest)
        # NOTE(review): score_mean is negative when losses fall, whereas
        # score_upper/score_lower are positive in that case -- confirm that
        # mixing the two sign conventions is intended.
        score = np.tanh((score_mean + score_upper + score_lower) / 3 + self.alpha * score_std)

        return score_mean, score_std, score
    

    def _acq_score(self, s_mean, s_std):
        """Return the sigmoid of ``s_mean`` plus ``s_std`` weighted by ``self.trans_w``."""
        weighted_std = s_std * self.trans_w
        return sigmoid(s_mean + weighted_std)


    def recommend_optimizer(self, net):
        """Count one call and, every ``sw_step`` calls, re-select the optimizer.

        On a switch round the losses recorded since the previous round are
        scored with ``_eva_score`` and cleared. With more than one candidate
        optimizer, the score is stored in the history of the optimizer that
        produced it (unless it is NaN), the network vector is recomputed and
        the next optimizer is chosen:

        * before ``init_step`` rounds: weighted-random selection, after
          updating the sampling weight of the optimizer just used;
        * at exactly ``init_step`` rounds: all regressors are fitted and the
          regressor-based selection is used;
        * afterwards: the regressor of the optimizer just used is refitted
          and the selection is regressor-based or random, weighted by
          ``1 - random_ratio`` and ``random_ratio``.

        After warm-up, ``random_ratio`` and ``trans_w`` are halved whenever
        the round counter is a multiple of ``random_schedule``.

        Args:
            net: model whose parameters the newly built optimizer will own.
        """
        self.cur_step += 1
        if self.cur_step < self.sw_step:
            return

        self.cur_iter += 1
        self.cur_step = 0

        _, _, score = self._eva_score()
        self.losses = []

        if len(self.optimizer_space) == 1:
            # Nothing to choose between, so no history needs to be kept.
            return

        score_is_valid = not np.isnan(score)
        if score_is_valid:
            # Pair the vector the network had before this round with the
            # score it led to; the two history lists always stay aligned.
            self.history[self.name][0].append(self.net_vec)
            self.history[self.name][1].append(score)
        self.net_vec = self._compress(net)

        if self.cur_iter < self.init_step:
            if score_is_valid:
                # Map the score from (-1, 1) to (0, 1), floored at 0.2.
                # A NaN score is skipped: a NaN weight would make the
                # weight total passed to ``random.choices`` non-finite.
                self.random_ws[self.name] = max((score + 1) / 2, 0.2)
            self._random_recommend(net)
        elif self.cur_iter == self.init_step:
            for name in self.optimizer_space:
                self._update_recommender(name)
            self._our_recommend(net)
        else:
            self._update_recommender(self.name)

            use_recommender = random.choices([0, 1], [self.random_ratio, 1 - self.random_ratio])[0]
            if use_recommender:
                self._our_recommend(net)
            else:
                self._random_recommend(net)

        if self.cur_iter > self.init_step and self.cur_iter % self.random_schedule == 0:
            self.random_ratio = self.random_ratio / 2
            self.trans_w = self.trans_w / 2
    
    def zero_grad(self):
        self.optimizer.zero_grad()
    
    def step(self, loss):
        self.optimizer.step()
        self.losses.append(loss) 


class SWATSwitcher(object):
    """Two-phase optimizer switcher: Adam first, plain SGD afterwards.

    ``recommend_optimizer`` selects Adam while ``epoch < sw_step`` and SGD
    from then on. ``zero_grad`` and ``step`` forward to the optimizer that
    was selected last.
    """

    def __init__(self, sw_step=10, optimizer_type='SWATSW', optimizer_space=('SGD', 'Adam'), lr=0.01):
        """Store the switching point and the learning rate.

        Args:
            sw_step: first epoch at which SGD is used instead of Adam.
            optimizer_type: accepted but not used by this class.
            optimizer_space: optimizer names; stored as a list but not
                consulted by the switching rule.
            lr: learning rate handed to every optimizer that is built.
        """
        self.lr = lr
        self.sw_step = sw_step
        # Private copy: the default is immutable and the caller's sequence
        # is not aliased.
        self.optimizer_space = list(optimizer_space)

    def init(self, net):
        """Set the current optimizer name to 'SGD'.

        No optimizer object is created here; ``recommend_optimizer`` must be
        called before ``zero_grad`` or ``step``.
        """
        self.name = 'SGD'

    def _get_optimizer(self, net, name='SGD'):
        """Build a new ``torch.optim`` optimizer for ``net`` selected by name.

        Recognised names are 'SGD', 'SGDM' (SGD with momentum 0.9),
        'Adagrad', 'RMSprop' and 'Adam'. Any other name yields plain SGD.
        """
        if name == 'SGD':
            return optim.SGD(net.parameters(), lr=self.lr)
        elif name == 'SGDM':
            return optim.SGD(net.parameters(), lr=self.lr, momentum=0.9)
        elif name == 'Adagrad':
            return optim.Adagrad(net.parameters(), lr=self.lr)
        elif name == 'RMSprop':
            return optim.RMSprop(net.parameters(), lr=self.lr)
        elif name == 'Adam':
            return optim.Adam(net.parameters(), lr=self.lr)
        else:
            return optim.SGD(net.parameters(), lr=self.lr)

    def recommend_optimizer(self, net, epoch):
        """Select Adam before ``sw_step`` epochs, SGD afterwards, and build it.

        NOTE(review): a new optimizer object is created on every call, so
        any internal state of the previous optimizer is not carried over --
        confirm this is intended when calling once per epoch.
        """
        self.name = 'Adam' if epoch < self.sw_step else 'SGD'
        self.optimizer = self._get_optimizer(net, self.name)

    def zero_grad(self):
        """Reset the gradients through the currently selected optimizer."""
        self.optimizer.zero_grad()

    def step(self, loss=0):
        """Apply one optimizer update; ``loss`` is accepted but not used."""
        self.optimizer.step()



from torch.optim import *
import math

class Padam(Optimizer):
    """Implements the partially adaptive momentum estimation (Padam) algorithm.

    Args:
        params (iterable): iterable of parameters to optimize or dicts
            defining parameter groups
        lr (float, optional): learning rate (default: 1e-1)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (bool, optional): normalise by the running maximum of the
            second-moment estimate instead of its current value
            (default: True)
        partial (float, optional): partially adaptive parameter; the
            denominator is raised to the power ``2 * partial`` (default: 1/4)

    Raises:
        ValueError: if either beta lies outside ``[0, 1)``.
    """

    def __init__(self, params, lr=1e-1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=True, partial=1/4):
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad, partial=partial)
        super(Padam, self).__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or None when no closure is
            given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            # Per-group settings do not change while iterating its params.
            amsgrad = group['amsgrad']
            partial = group['partial']
            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Padam does not support sparse gradients')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                state['step'] += 1

                if group['weight_decay'] != 0:
                    # L2 penalty: grad + weight_decay * p (out of place, so
                    # the stored gradient is left untouched).
                    grad = grad.add(p.data, alpha=group['weight_decay'])

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                # Partially adaptive update: p -= step_size * m / denom^(2 * partial)
                p.data.addcdiv_(exp_avg, denom ** (partial * 2), value=-step_size)

        return loss
    


    
    