import backbone
import utils

import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
from torchvision import transforms
import pdb
import copy
import time


class BaselineTrainAdv(nn.Module):
    """Baseline / Baseline++ classifier with optional PGD adversarial training.

    The module holds two copies of the backbone built by ``model_func``:

    * ``feature``    -- the network that is optimised by gradient descent.
    * ``wa_feature`` -- a frozen copy whose parameters track an exponential
      moving average (momentum ``m``) of ``feature``'s parameters.

    All tensors are moved to the device on which the model's parameters
    live, so the class works both on GPU and on CPU.
    """

    def __init__(self, model_func, num_class, loss_type = 'softmax',dataset='miniImagenet'):
        """Build the backbone pair and the classification head.

        Args:
            model_func: callable returning the backbone. The backbone must
                expose ``final_feat_dim``; ``model_func.__name__`` is read
                for the 'miniImagenet' and 'cifar' datasets.
            num_class: number of output classes of the classifier.
            loss_type: 'softmax' for a linear head, 'dist' for
                ``backbone.distLinear`` (Baseline++).
            dataset: dataset name; selects hard-coded feature dimensions.
        """
        super(BaselineTrainAdv, self).__init__()

        self.feature = model_func()
        self.wa_feature = model_func()

        # Start the averaged copy from identical weights and keep it out of
        # the optimiser's reach: it is only updated by _momentum_update().
        for param_q, param_k in zip(self.feature.parameters(), self.wa_feature.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False
        self.loss_type = loss_type  # 'softmax' or 'dist'
        self.num_class = num_class
        self.loss_fn = nn.CrossEntropyLoss()

        self.DBval = False  # only set True for CUB dataset, see issue #31
        self.k = 7  # number of PGD iterations
        self.a = 2.0/255  # PGD step size
        self.epsilon = 8.0/255  # L-infinity radius of the perturbation
        self.dataset = dataset
        self.m = 0.999  # momentum of the weight average

        # Dataset/backbone specific overrides of the flattened feature size.
        if self.dataset=='miniImagenet':
            if model_func.__name__ == 'Conv4':
                self.feature.final_feat_dim = 1600
        elif self.dataset=='cifar':
            if model_func.__name__ == 'Conv4':
                self.feature.final_feat_dim = 256
            elif model_func.__name__ == 'R2D2':
                self.feature.final_feat_dim = 8192

        if loss_type == 'softmax':
            self.classifier = nn.Linear(self.feature.final_feat_dim, num_class)
            self.classifier.bias.data.fill_(0)
        elif loss_type == 'dist':  # Baseline ++
            self.classifier = backbone.distLinear(self.feature.final_feat_dim, num_class)

    def _device(self):
        """Return the device holding the model's parameters."""
        return next(self.parameters()).device

    @torch.no_grad()
    def _momentum_update(self):
        """Move ``wa_feature``'s parameters towards ``feature``'s.

        Computes ``k = m * k + (1 - m) * q`` for every parameter pair.
        Only parameters are averaged; buffers are not touched.
        """
        for param_q, param_k in zip(self.feature.parameters(), self.wa_feature.parameters()):
            param_k.mul_(self.m).add_(param_q.detach() * (1. - self.m))

    def forward_orig(self,x):
        """Return the classifier scores for the batch ``x``."""
        x = x.to(self._device())
        return self.classifier(self.feature(x))

    def forward_loss_adv(self,x_orig,y):
        """Craft PGD adversarial examples for ``(x_orig, y)``.

        Starts from a uniformly random point in the ``epsilon`` ball around
        the input, then takes ``k`` signed-gradient ascent steps of size
        ``a`` on the cross-entropy loss, projecting back into the ball and
        into the valid pixel range [0, 1] after every step.

        Args:
            x_orig: input batch; only its first three channels are used,
                any further channel is zero in the clean reference.
            y: integer class labels.

        Returns:
            Tensor of adversarial inputs (no autograd history) on the
            model's device.
        """
        device = self._device()

        x_nat = torch.zeros_like(x_orig, device=device)
        x_nat[:, :3, :, :] = x_orig[:, :3, :, :]

        # The noise is drawn on the CPU (as before) so the random stream does
        # not depend on the device the model runs on.
        noise = torch.empty(x_nat.shape, dtype=torch.float32).uniform_(-self.epsilon, self.epsilon)
        x = torch.clamp(x_nat + noise.to(device), 0, 1)

        y = y.to(device)
        lower = x_nat - self.epsilon
        upper = x_nat + self.epsilon
        for _ in range(self.k):
            x = x.detach().requires_grad_(True)
            loss = self.loss_fn(self.classifier(self.feature(x)), y)
            # Gradient w.r.t. the input only; parameter .grad stays untouched.
            grad = torch.autograd.grad(loss, x)[0]

            x = x.detach() + self.a * torch.sign(grad)
            x = torch.max(torch.min(x, upper), lower)  # project into the ball
            x = torch.clamp(x, 0, 1)  # ensure valid pixel range

            if x.is_cuda:
                torch.cuda.empty_cache()
        return x.clone()

    def forward_loss(self, x, y):
        """Return the cross-entropy loss of the model on ``(x, y)``."""
        scores = self.forward_orig(x)
        return self.loss_fn(scores, y.to(scores.device))

    def train_loop(self, epoch, train_loader, optimizer, is_adv,wandb = None):
        """Train for one pass over ``train_loader``.

        Args:
            epoch: epoch index, used for logging only.
            train_loader: iterable of ``(x, y)`` batches.
            optimizer: optimiser stepping the trainable parameters.
            is_adv: when true, train on PGD adversarial examples; otherwise
                train on the clean batch.
            wandb: optional wandb-like object providing ``log`` and ``config``.
        """
        print_freq = 1000
        avg_loss = 0
        for i, (x, y) in enumerate(train_loader):
            if self.dataset == 'CUB':
                # NOTE(review): label_map is not assigned anywhere in this
                # class; it must be set on the instance by the caller.
                y1 = y.clone()
                for l1 in torch.unique(y1): y[y1 == l1] = self.label_map[l1.item()]
            optimizer.zero_grad()

            # The same batch is used for the loss and for the averaged
            # network below, so it is bound for both training modes.
            x_train = self.forward_loss_adv(x, y) if is_adv else x
            loss = self.forward_loss(x_train, y)
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                self._momentum_update()
                # Forward pass through the averaged network; presumably done
                # to refresh its running statistics (e.g. BatchNorm), which
                # _momentum_update() does not average.
                _ = self.wa_feature(x_train.to(self._device()))

            avg_loss = avg_loss + loss.item()

            if i % print_freq==0:
                print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f}'.format(epoch, i, len(train_loader), avg_loss/float(i+1)  ))

                if wandb is not None:
                    wandb.log({'Epoch':epoch,'loss': avg_loss/float(i+1)})
                    if i==0:
                        wandb.config.adv_iter = self.k
                        wandb.config.alpha = self.a
                        wandb.config.epsilon = self.epsilon

    def test_loop(self, val_loader):
        """Return the validation score, or -1 when validation is disabled."""
        if self.DBval:
            return self.analysis_loop(val_loader)
        else:
            return -1   #no validation, just save model during iteration

    def analysis_loop(self, val_loader, record = None):
        """Return the inverse Davies-Bouldin index of the backbone features.

        Features of every batch in ``val_loader`` are grouped by label and
        passed to ``DBindex``. ``record`` is unused.
        """
        device = self._device()
        class_file = {}
        with torch.no_grad():
            for x, y in val_loader:
                feats = self.feature(x.to(device)).cpu().numpy()
                labels = y.cpu().numpy()
                for f, l in zip(feats, labels):
                    class_file.setdefault(l, []).append(f)

        for cl in class_file:
            class_file[cl] = np.array(class_file[cl])

        DB = DBindex(class_file)
        print('DB index = %4.2f' %(DB))
        return 1/DB #DB index: the lower the better

def DBindex(cl_data_file):
    """Return the Davies-Bouldin index of per-class feature arrays.

    ``cl_data_file`` maps each class label to a 2-D array with one feature
    vector per row. For the definition of the index see
    https://en.wikipedia.org/wiki/Davies%E2%80%93Bouldin_index

    The index reflects the intra-class variation of the data. As
    baseline/baseline++ do not train a few-shot classifier during training,
    it is used as an alternative metric on the validation set. Empirically,
    this only works for the CUB dataset but not for miniImagenet.
    """
    labels = list(cl_data_file)
    n_classes = len(labels)

    # Per-class centroid and root-mean-square distance to that centroid.
    centroids = [np.mean(cl_data_file[label], axis=0) for label in labels]
    spreads = [
        np.sqrt(np.mean(np.sum(np.square(cl_data_file[label] - centre), axis=1)))
        for label, centre in zip(labels, centroids)
    ]

    # Pairwise Euclidean distances between centroids via broadcasting.
    centres = np.array(centroids)
    deltas = centres[np.newaxis, :, :] - centres[:, np.newaxis, :]
    centre_dists = np.sqrt(np.sum(np.square(deltas), axis=2))

    # For every class keep its worst (largest) similarity to another class.
    worst_ratios = []
    for a in range(n_classes):
        ratios = [
            (spreads[a] + spreads[b]) / centre_dists[a, b]
            for b in range(n_classes)
            if b != a
        ]
        worst_ratios.append(np.max(ratios))
    return np.mean(worst_ratios)
