# This code is modified from https://github.com/dragen1860/MAML-Pytorch and https://github.com/katerakelly/pytorch-maml 

import backbone
import torch
import torch.nn as nn

from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
from methods.meta_template import MetaTemplate
from methods.min_norm_solvers import MinNormSolver, gradient_normalizers
import sys
sys.path.append('..')
import hypergrad as hg
import higher

class Constrained_implicit_1(MetaTemplate):
    """Plain-``nn.Linear`` twin of the ``Constrained_implicit`` network.

    ``Constrained_implicit.set_forward`` wraps a fresh instance of this class
    with ``higher.monkeypatch`` so that backbone + linear head can be evaluated
    functionally with an explicit list of parameters.
    """

    def __init__(self, model_func, n_way, n_support, approx=False):
        # ``approx`` is accepted only for signature parity; it is not used.
        super().__init__(model_func, n_way, n_support, change_way=False)

        self.loss_fn = nn.CrossEntropyLoss()
        head = nn.Linear(self.feat_dim, n_way)
        head.bias.data.fill_(0)  # start the head with a zero bias
        self.classifier = head

    def forward(self, x):
        """Return classifier logits: backbone features fed through the linear head."""
        return self.classifier.forward(self.feature.forward(x))

class Constrained_implicit(MetaTemplate):
    """Adversarially robust meta-learner trained with implicit hypergradients.

    For each task, a copy of the meta-parameters is adapted with
    ``task_update_num`` Adam steps on a loss that combines clean cross-entropy,
    a proximal penalty towards the meta-parameters (``meta_biased_reg``) and a
    softplus-gated adversarial (PGD) term. ``set_forward_loss`` then calls
    ``hg.CG`` to obtain the meta-gradient, and ``train_loop`` applies one
    optimizer step every ``n_task`` tasks.
    """

    def __init__(self, model_func, n_way, n_support, approx=False):
        super(Constrained_implicit, self).__init__(model_func, n_way, n_support, change_way=False)

        self.loss_fn = nn.CrossEntropyLoss()
        self.classifier = backbone.Linear_fw(self.feat_dim, n_way)
        self.classifier.bias.data.fill_(0)

        self.n_task = 4               # tasks whose meta-gradients are combined per optimizer step
        self.task_update_num = 100    # inner-loop Adam steps per task
        self.train_lr = 0.001         # inner-loop Adam learning rate
        self.meta_lambda = None       # proximal-regulariser weight; must be set before set_forward runs
        self.approx = approx          # first-order approximation flag (stored; not read in this class)

        self.weighting_mode = None    # 'SOML' or 'COML'; must be set before train_loop

        # Kept so set_forward can rebuild a plain (monkeypatchable) copy of the network.
        self.model_func = model_func
        self.n_way = n_way
        self.n_support = n_support

    def forward(self, x):
        """Return classifier logits for a batch of images."""
        out = self.feature.forward(x)
        scores = self.classifier.forward(out)
        return scores

    def meta_biased_reg(self, meta_parameters, parameter_faster):
        """Return the proximal penalty pulling adapted weights towards the meta weights.

        Computes ``meta_lambda * sum_i c_i * ||parameter_faster[i] - meta_parameters[i]||^2``
        where ``c_i = 1`` for every tensor except the LAST one, which gets
        ``c = 1 + 8 = 9``.

        Args:
            meta_parameters: list of meta (outer) parameter tensors.
            parameter_faster: list of adapted (inner) tensors in the same order.

        Raises:
            ValueError: if ``self.meta_lambda`` has not been set.
        """
        if self.meta_lambda is None:
            raise ValueError('meta_lambda must be set before computing the regulariser')
        last = len(meta_parameters) - 1
        bias_reg_loss = 0.0
        for i, (meta_p, fast_p) in enumerate(zip(meta_parameters, parameter_faster)):
            diff_norm = torch.norm(fast_p - meta_p)
            sq_norm = diff_norm * diff_norm
            bias_reg_loss += sq_norm
            if i == last:
                # Extra weight meant to mimic BOIL ("large regularization weight for the
                # linear layer"). NOTE(review): only the LAST tensor of the parameter list
                # gets it (presumably the classifier bias, not its weight) — confirm intent.
                bias_reg_loss += sq_norm * 8.0
        return bias_reg_loss * self.meta_lambda

    def set_forward(self, x, is_feature=False, robust=True, LLmode=False):
        """Adapt to one task and score its clean and adversarial query set.

        Args:
            x: task tensor laid out as ``(n_way, n_support + n_query, *image_shape)``;
                ``self.n_query`` must already be set by the caller.
            is_feature: must be False (pre-extracted features are not supported).
            robust, LLmode: accepted for interface compatibility; not read here.

        Returns:
            ``(scores, scores_robust, loss, loss_robust, loss_train_call,
            loss_val_call, fast_parameters)``. These are:
            - the query logits and losses of the adapted weights on clean and
              PGD inputs;
            - the inner (train) and outer (val) loss closures, which
              ``set_forward_loss`` passes to ``hg.CG``;
            - the detached adapted parameters.
        """
        assert not is_feature, 'MAML do not support fixed feature'
        # NOTE: self.zero_grad() is intentionally NOT called here. Nothing in this
        # method writes to the .grad of self.parameters() (PGD uses autograd.grad on
        # the inputs; adaptation runs on detached clones), and clearing it would
        # discard the meta-gradients train_loop accumulates across n_task tasks.
        x = x.cuda().detach()
        img_shape = x.size()[2:]
        x_a_i = x[:, :self.n_support].contiguous().view(self.n_way * self.n_support, *img_shape)  # support data
        x_b_i = x[:, self.n_support:].contiguous().view(self.n_way * self.n_query, *img_shape)    # query data
        y_a_i = torch.from_numpy(np.repeat(range(self.n_way), self.n_support)).to(torch.int64).cuda()
        y_b_i = torch.from_numpy(np.repeat(range(self.n_way), self.n_query)).to(torch.int64).cuda()

        # Adversarial copies are crafted against the current meta-model and then
        # detached: they are fixed inputs and need no gradient of their own.
        adv_input_a = self.test_PGD(x_a_i, y_a_i, step_num=7).detach()
        adv_input_b = self.test_PGD(x_b_i, y_b_i, step_num=7).detach()

        device = next(self.parameters()).device
        # Functional copy of the network; its own initial weights are irrelevant
        # because every call below passes explicit ``params``.
        model_adapt = Constrained_implicit_1(self.model_func, self.n_way, self.n_support)
        fmodel = higher.monkeypatch(model_adapt, device=device, copy_initial_weights=True)

        def loss_train_call(params, hparams):
            """Inner objective: clean CE + proximal penalty + gated adversarial excess."""
            reg_loss = self.meta_biased_reg(hparams, params)
            set_loss = self.loss_fn(fmodel(x_a_i, params=params), y_a_i)
            loss_robust = self.loss_fn(fmodel(adv_input_a, params=params), y_a_i)
            # Smooth hinge (beta=100): penalise robust loss exceeding 1.3x the clean loss.
            return set_loss + reg_loss + F.softplus(loss_robust - set_loss * 1.3, beta=100.0)

        def loss_val_call(params, hparams):
            """Outer objective: 0.8 * clean query CE + 0.2 * adversarial query CE."""
            set_loss = self.loss_fn(fmodel(x_b_i, params=params), y_b_i)
            loss_robust = self.loss_fn(fmodel(adv_input_b, params=params), y_b_i)
            return set_loss * 0.8 + loss_robust * 0.2

        def inner_loop_my(params, hparams, loss, n_steps=self.task_update_num):
            """Run ``n_steps`` Adam steps on ``params`` in place; return detached copies."""
            inner_opt = torch.optim.Adam(params, lr=self.train_lr, weight_decay=0.0)
            for _ in range(n_steps):
                step_loss = loss(params, hparams)
                inner_opt.zero_grad()
                # Each step builds a fresh graph, so it need not be retained.
                step_loss.backward()
                inner_opt.step()
            return [p.detach().clone().requires_grad_(False) for p in params]

        meta_params = list(self.parameters())
        fast_temp = [p.detach().clone().requires_grad_() for p in meta_params]
        meta_parameters_self = [p.detach().clone().requires_grad_(False) for p in meta_params]
        fast_parameters = inner_loop_my(fast_temp, meta_parameters_self, loss_train_call)

        scores = fmodel(x_b_i, params=fast_parameters)
        loss = self.loss_fn(scores, y_b_i)
        scores_robust = fmodel(adv_input_b, params=fast_parameters)
        loss_robust = self.loss_fn(scores_robust, y_b_i)

        return scores, scores_robust, loss, loss_robust, loss_train_call, loss_val_call, fast_parameters

    def set_forward_adaptation(self, x, is_feature=False):  # overwrite parent function
        """Not supported: adaptation depth is controlled by ``task_update_num``."""
        raise ValueError('MAML performs further adapation simply by increasing task_upate_num')

    def set_forward_loss(self, x, require_rob=True):
        """Adapt to one task and add its implicit meta-gradient to ``self.parameters()``.

        The hypergradient comes from ``hg.CG`` (5 conjugate-gradient iterations
        around a unit-step gradient-descent fixed-point map of the inner loss).
        Presumably ``hg.CG`` adds into ``.grad`` (hypertorch's CG does, via
        ``update_tensor_grads``) — confirm against the local ``hypergrad``.
        ``require_rob`` is accepted for interface compatibility and not read.

        Returns:
            ``(loss, loss_robust)``: clean and adversarial query losses.
        """
        _, _, loss, loss_robust, loss_train_call, loss_val_call, fast_parameters = \
            self.set_forward(x, is_feature=False, robust=True)
        cg_fp_map = hg.GradientDescent(loss_f=loss_train_call, step_size=1.)
        hg.CG(fast_parameters, list(self.parameters()), K=5, fp_map=cg_fp_map, outer_loss=loss_val_call)
        return loss, loss_robust

    def train_loop(self, epoch, train_loader, optimizer):  # overwrite parent function
        """Train for one epoch with the configured ``weighting_mode``.

        Raises:
            ValueError: if ``weighting_mode`` is not 'SOML' or 'COML' (previously
                such a configuration silently skipped training).
        """
        if self.weighting_mode not in ('SOML', 'COML'):
            raise ValueError('Unsupported weighting_mode: {!r} (expected SOML or COML)'.format(self.weighting_mode))
        # Both modes currently share the same procedure.
        self._meta_train_epoch(epoch, train_loader, optimizer)

    def _meta_train_epoch(self, epoch, train_loader, optimizer):
        """Run one epoch, summing meta-gradients over ``n_task`` tasks per step."""
        print_freq = 10
        avg_loss_acc = 0
        avg_loss_rob = 0
        task_count = 0
        optimizer.zero_grad()

        for i, (x, _) in enumerate(train_loader):
            self.n_query = x.size(1) - self.n_support
            assert self.n_way == x.size(0), "MAML do not support way change"
            loss, loss_robust = self.set_forward_loss(x, require_rob=True)
            avg_loss_acc += loss.item()
            avg_loss_rob += loss_robust.item()
            task_count += 1

            # Drop the accumulated gradients if ANY parameter's gradient went NaN.
            nan_list = [p.grad is not None and bool(torch.isnan(p.grad).any()) for p in self.parameters()]
            if any(nan_list):
                print(nan_list)
                print(i)
                optimizer.zero_grad()

            if task_count == self.n_task:  # MAML update several tasks at one time
                optimizer.step()
                # Zero only after stepping so gradients accumulate across the n_task tasks.
                optimizer.zero_grad()
                task_count = 0

            if i % print_freq == 0:
                print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f} Loss ADV {:f}'.format(
                    epoch, i, len(train_loader), avg_loss_acc / float(i + 1), avg_loss_rob / float(i + 1)))

    def test_loop(self, test_loader, return_std=False):  # overwrite parent function
        """Evaluate clean and PGD accuracy over ``test_loader`` and print a summary.

        Prints the following:
        - mean clean accuracy and mean robust accuracy, each with a
          1.96 * std / sqrt(n) half-width;
        - the "B score" (harmonic mean of clean and robust accuracy), computed
          both from the two means and per task.

        Returns:
            ``(clean_acc_mean, robust_acc_mean, mean_per_task_B_score)``.
            ``return_std`` is accepted for interface compatibility and ignored.
        """
        acc_all = []
        acc_all2 = []
        iter_num = len(test_loader)
        for x, _ in test_loader:
            self.n_query = x.size(1) - self.n_support
            assert self.n_way == x.size(0), "MAML do not support way change"
            scores, scores2 = self.set_forward(x, robust=True)[:2]
            y_query = np.repeat(range(self.n_way), self.n_query)
            pred = scores.detach().topk(1, 1, True, True)[1].cpu().numpy()[:, 0]
            pred2 = scores2.detach().topk(1, 1, True, True)[1].cpu().numpy()[:, 0]
            count_this = len(y_query)
            acc_all.append(float(np.sum(pred == y_query)) / count_this * 100)
            acc_all2.append(float(np.sum(pred2 == y_query)) / count_this * 100)

        acc_all = np.asarray(acc_all)
        acc_all2 = np.asarray(acc_all2)
        acc_mean, acc_std = np.mean(acc_all), np.std(acc_all)
        acc_mean2, acc_std2 = np.mean(acc_all2), np.std(acc_all2)

        # Harmonic mean of clean/robust accuracy, defined as 0 when both are 0
        # (the raw formula would divide by zero and produce NaN).
        denom = acc_all + acc_all2
        B_score = np.divide(2 * (acc_all * acc_all2), denom, out=np.zeros_like(denom), where=denom > 0)
        mean_sum = acc_mean + acc_mean2
        B_score2 = 2 * (acc_mean * acc_mean2) / mean_sum if mean_sum > 0 else 0.0

        print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
        print('%d Test Rob = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean2, 1.96 * acc_std2 / np.sqrt(iter_num)))
        print('%d Test B Acc = %4.2f%% and %4.2f%% +- %4.2f%%' % (
            iter_num, B_score2, np.mean(B_score), 1.96 * np.std(B_score) / np.sqrt(iter_num)))
        return acc_mean, acc_mean2, np.mean(B_score)

    def clamp(self, X, lower_limit, upper_limit):
        """Element-wise clamp of ``X`` into ``[lower_limit, upper_limit]`` (broadcasting)."""
        return torch.max(torch.min(X, upper_limit), lower_limit)

    def test_PGD(self, x, y, step_num=2):
        """Craft L_inf PGD adversarial examples against the current (meta) model.

        The budget is 2/255 per channel in [0, 1] pixel units, converted to the
        normalised input space by dividing by the ImageNet std. Each step moves
        ``1.5 / step_num`` of the budget along the gradient sign, then projects
        back onto the epsilon-ball around ``x`` and onto the valid normalised
        pixel range.

        Returns:
            The adversarial batch on CUDA, with ``requires_grad=True``.

        NOTE(review): the forward passes run in the module's current train/eval
        mode. If the backbone has BatchNorm in training mode, they also update
        its running statistics.
        """
        std = torch.FloatTensor([0.229, 0.224, 0.225]).cuda()
        eps = 2 / 255 * torch.FloatTensor([1.0, 1.0, 1.0]).cuda()
        epsilon = (eps / std).reshape(3, 1, 1)
        # (1 - mean) / std and (0 - mean) / std for ImageNet mean (0.485, 0.456, 0.406).
        upper_limit = torch.FloatTensor([2.2489, 2.4286, 2.6400]).reshape(3, 1, 1).cuda()
        lower_limit = torch.FloatTensor([-2.1179, -2.0357, -1.8044]).reshape(3, 1, 1).cuda()
        step_size = 1.5 / step_num * epsilon

        labels = y.detach().cuda().to(torch.int64)
        x_clean = x.detach().cuda()
        images = x_clean.detach().requires_grad_(True)
        for _ in range(step_num):
            loss = self.loss_fn(self.forward(images), labels)
            grad = torch.autograd.grad(loss, images, retain_graph=False, create_graph=False)[0]
            stepped = images.detach() + step_size * torch.sign(grad.detach())
            delta = self.clamp(stepped - x_clean, -epsilon, epsilon)
            images = self.clamp(x_clean + delta, lower_limit, upper_limit).detach().requires_grad_(True)
        return images


