import os
from sched import scheduler
import config
import shutil
import numpy as np
import random
from copy import deepcopy
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

from dataloader import get_dataloader
import sys
sys.path.insert(0,'../')
from classifier_models import PreActResNet18
from networks.models import UnetGenerator
from utils.dataloader import PostTensorTransform
from utils.dct import *
#from utils.utils import progress_bar

def create_dir(path_dir):
    """Create directory ``path_dir`` together with any missing parents.

    Does nothing if the directory already exists or ``path_dir`` is empty.
    Absolute and relative paths are used exactly as given. Other failures
    (e.g. permission denied, or an existing path component that is a regular
    file) are raised instead of being silently ignored.
    """
    if not path_dir:
        return
    os.makedirs(path_dir, exist_ok=True)

def backdoor(clean_x, opt):
    """Return a triggered copy of an image batch for attack ``opt.attack_name``.

    - ``"badnets"``: a 4x4 patch set to 1 in every channel, ending one pixel
      before the bottom-right corner (rows/cols ``H-5 .. H-2`` / ``W-5 .. W-2``).
    - ``"narcisuss"``: adds the trigger stored in ``./triggers/narcisuss.npy``.
    - anything else: adds ``./triggers/<attack_name>.png``, transposed from
      (H, W, C) to (C, H, W) and rescaled with ``2 * x - 1``.

    Args:
        clean_x: image batch indexed as (batch, channel, height, width); the
            additive triggers are broadcast over the leading batch dimension.
            NOTE(review): assumes trigger files have shape (C, H, W) matching
            each image (a PNG with an alpha channel would have 4 channels).
        opt: config namespace; reads ``attack_name``.

    Returns:
        A new tensor with the same shape and dtype as ``clean_x`` on the same
        device; ``clean_x`` itself is not modified.
    """
    if opt.attack_name == "badnets":
        pat_size = 4
        height, width = clean_x.shape[-2], clean_x.shape[-1]
        output = torch.clone(clean_x)
        output[..., height - 1 - pat_size:height - 1, width - 1 - pat_size:width - 1] = 1
        return output

    if opt.attack_name == "narcisuss":
        trimg = torch.from_numpy(np.load(os.path.join('./triggers', opt.attack_name + '.npy')))
    else:
        trimg = np.transpose(plt.imread(os.path.join('./triggers', opt.attack_name + '.png')), (2, 0, 1))
        trimg = torch.from_numpy((trimg * 2) - np.ones_like(trimg))

    # Match the batch's device and dtype so the sum neither fails across
    # devices nor silently promotes (e.g. a float64 .npy to float64 output).
    trimg = trimg.to(device=clean_x.device, dtype=clean_x.dtype)
    return clean_x + trimg

def get_model(opt):
    """Build two PreActResNet18 networks with SGD optimizers and MultiStepLR schedulers.

    Args:
        opt: config namespace; reads ``dataset``, ``device``, ``lr_C``,
            ``lr_T``, ``schedulerC_milestones``, ``schedulerC_lambda``,
            ``schedulerT_milestones`` and ``schedulerT_lambda``.

    Returns:
        ``(netC, optimizerC, schedulerC, netT, optimizerT, schedulerT)``.

    Raises:
        ValueError: if ``opt.dataset`` is not ``'cifar10'`` (the only dataset
            with a model defined here).
    """
    if opt.dataset != 'cifar10':
        raise ValueError("Unsupported dataset for get_model: {!r}".format(opt.dataset))

    # Model (netC is constructed first, then netT, as before, so random
    # initialisation order is unchanged)
    netC = PreActResNet18().to(opt.device)
    netT = PreActResNet18().to(opt.device)

    # Optimizers and step-wise learning-rate schedules
    optimizerC = torch.optim.SGD(netC.parameters(), opt.lr_C, momentum=0.9, weight_decay=1e-4, nesterov=True)
    schedulerC = torch.optim.lr_scheduler.MultiStepLR(optimizerC, opt.schedulerC_milestones, opt.schedulerC_lambda)
    optimizerT = torch.optim.SGD(netT.parameters(), opt.lr_T, momentum=0.9, weight_decay=1e-4, nesterov=True)
    schedulerT = torch.optim.lr_scheduler.MultiStepLR(optimizerT, opt.schedulerT_milestones, opt.schedulerT_lambda)

    return netC, optimizerC, schedulerC, netT, optimizerT, schedulerT

def create_targets_bd(targets, opt):
    """Return backdoor labels: a tensor shaped like ``targets`` filled with ``opt.target_label``.

    The result is placed on ``opt.device``.
    """
    template = torch.ones_like(targets)
    filled = template * opt.target_label
    return filled.to(opt.device)

class AT(nn.Module):
    """Attention-transfer loss between two feature maps.

    From "Paying More Attention to Attention: Improving the Performance of
    Convolutional Neural Networks via Attention Transfer"
    (https://arxiv.org/pdf/1612.03928.pdf).

    Args:
        p: exponent applied to absolute activations before summing over
            channels.
    """

    def __init__(self, p):
        super().__init__()
        self.p = p

    def forward(self, fm_s, fm_t):
        """Mean squared error between the normalised attention maps of ``fm_s`` and ``fm_t``."""
        student_map = self.attention_map(fm_s)
        teacher_map = self.attention_map(fm_t)
        return F.mse_loss(student_map, teacher_map)

    def attention_map(self, fm, eps=1e-6):
        """Collapse channels (dim 1) of ``|fm| ** p`` and L2-normalise over dims 2 and 3.

        ``eps`` keeps the division finite when a map is all zeros.
        """
        energy = fm.abs().pow(self.p).sum(dim=1, keepdim=True)
        scale = energy.norm(dim=(2, 3), keepdim=True) + eps
        return energy / scale

def _truncated(net, n_drop, device):
    """Return an ``nn.Sequential`` running all but the last ``n_drop`` direct children of ``net``.

    The container reuses ``net``'s own child modules (same parameters and
    buffers); it only selects which prefix of the children is executed.
    """
    head = nn.Sequential(*list(net.children())[:-n_drop])
    head.to(device)
    return head


def train(netS, netT, optimizerS, schedulerS, train_dl, epoch, opt):
    """Train the student for one epoch with Neural Attention Distillation (NAD).

    The loss is cross-entropy on the student's logits plus three
    attention-transfer (AT) terms. Each AT term matches a student
    intermediate activation to the teacher's, weighted by ``opt.beta1``,
    ``opt.beta2`` and ``opt.beta3``. Intermediate activations come from
    running prefixes of each network's direct children, dropping the last 1,
    2 and 3 children respectively.

    Args:
        netS: student network; updated in place through ``optimizerS``.
        netT: teacher network; only used for forward passes.
        optimizerS: optimizer over the student's parameters.
        schedulerS: LR scheduler, stepped once at the end of the epoch.
        train_dl: iterable of ``(inputs, targets)`` batches.
        epoch: current epoch index (unused; kept for interface compatibility).
        opt: config namespace; reads ``dataset``, ``device``, ``power``,
            ``beta1``, ``beta2``, ``beta3`` and whatever ``PostTensorTransform``
            reads.

    Raises:
        ValueError: if ``opt.dataset`` is not ``'cifar10'``.
    """
    # The truncation points below were chosen for the PreActResNet18 layout
    # used for cifar10; no other dataset has a defined layer split.
    if opt.dataset != 'cifar10':
        raise ValueError("train() only supports dataset 'cifar10', got {!r}".format(opt.dataset))

    print(" Train:")
    netS.train()
    # NOTE(review): netT is never switched to eval() here or in main(). Any
    # BatchNorm layers it has therefore run in training mode and update
    # running statistics on every batch. Confirm whether the teacher should
    # be frozen with netT.eval().

    criterion_CE = torch.nn.CrossEntropyLoss()
    criterion_AT = AT(opt.power)
    transform = PostTensorTransform(opt)

    # The prefix views share modules with netS/netT, so build them once per
    # epoch instead of once per batch. Order: drop the last 1, 2, 3 children.
    student_heads = [_truncated(netS, k, opt.device) for k in (1, 2, 3)]
    teacher_heads = [_truncated(netT, k, opt.device) for k in (1, 2, 3)]

    total_sample = 0
    total_loss = 0.
    total_correct = 0
    last_batch_idx = -1

    for batch_idx, (inputs, targets) in enumerate(train_dl):
        last_batch_idx = batch_idx
        optimizerS.zero_grad()
        inputs, targets = inputs.to(opt.device), targets.to(opt.device)
        bs = inputs.shape[0]
        inputs = transform(inputs)

        # Keep the forward order (full student, student prefixes, teacher
        # prefixes): each train-mode pass can update normalisation statistics.
        outputs_s = netS(inputs)
        activation3_s, activation2_s, activation1_s = [head(inputs) for head in student_heads]
        # Teacher activations are fixed targets. no_grad avoids building an
        # autograd graph that would only be detached and thrown away.
        with torch.no_grad():
            activation3_t, activation2_t, activation1_t = [head(inputs) for head in teacher_heads]

        cls_loss = criterion_CE(outputs_s, targets)
        at3_loss = criterion_AT(activation3_s, activation3_t) * opt.beta3
        at2_loss = criterion_AT(activation2_s, activation2_t) * opt.beta2
        at1_loss = criterion_AT(activation1_s, activation1_t) * opt.beta1

        loss = at1_loss + at2_loss + at3_loss + cls_loss

        loss.backward()
        optimizerS.step()

        total_sample += bs
        # CrossEntropyLoss/mse_loss return batch means; weight by batch size
        # so the epoch average below is a true per-sample mean.
        total_loss += loss.detach() * bs
        total_correct += torch.sum(torch.argmax(outputs_s, dim=1) == targets)

    if total_sample > 0:
        avg_acc = float(total_correct) * 100. / total_sample
        avg_loss = float(total_loss) / total_sample
        print(last_batch_idx, len(train_dl), 'Loss: {:.4f} | Acc: {:.4f}'.format(avg_loss, avg_acc))
    else:
        print(' No training batches.')

    schedulerS.step()

def eval(netS, optimizerS, schedulerS, test_dl, best_clean_acc, best_bd_acc, epoch, opt):
    """Measure the student's clean and backdoor accuracy, then checkpoint it.

    NOTE: this name shadows the builtin ``eval`` inside this module; it is
    kept because ``main`` calls it by this name.

    Args:
        netS: student network to evaluate (switched to eval mode).
        optimizerS: student optimizer; its state is stored in the checkpoint.
        schedulerS: student LR scheduler; its state is stored in the checkpoint.
        test_dl: iterable of ``(inputs, targets)`` batches.
        best_clean_acc: value printed as the "Best" clean accuracy.
        best_bd_acc: value printed as the "Best" backdoor accuracy.
        epoch: stored in the checkpoint as ``epoch_current``.
        opt: config namespace; reads ``device``, ``ckpt_path`` and the fields
            used by ``backdoor`` and ``create_targets_bd``.

    Returns:
        ``(acc_clean, acc_bd)``: 0-dim tensors holding this evaluation's
        percentages. They stay tensors because ``main`` calls ``.cpu()`` on them.

    Raises:
        ValueError: if ``test_dl`` yields no samples.
    """
    print(" Eval:")
    netS.eval()

    total_sample = 0
    total_clean_correct = 0
    total_bd_correct = 0
    last_batch_idx = -1

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_dl):
            last_batch_idx = batch_idx
            inputs, targets = inputs.to(opt.device), targets.to(opt.device)
            total_sample += inputs.shape[0]

            # Clean accuracy: predictions on untouched inputs vs. true labels
            preds_clean = netS(inputs)
            total_clean_correct += torch.sum(torch.argmax(preds_clean, 1) == targets)

            # Backdoor accuracy: triggered inputs predicted as the target label
            inputs_bd = backdoor(inputs, opt)
            targets_bd = create_targets_bd(targets, opt)
            preds_bd = netS(inputs_bd)
            total_bd_correct += torch.sum(torch.argmax(preds_bd, 1) == targets_bd)

    if total_sample == 0:
        raise ValueError("test_dl yielded no samples; cannot compute accuracy")

    acc_clean = total_clean_correct * 100. / total_sample
    acc_bd = total_bd_correct * 100. / total_sample

    info_string = "Clean Acc: {:.4f} - Best: {:.4f} | Bd Acc: {:.4f} - Best: {:.4f}".format(acc_clean, best_clean_acc, acc_bd, best_bd_acc)
    print(last_batch_idx, len(test_dl), info_string)

    # Save checkpoint. The condition holds for any non-zero clean accuracy,
    # so in practice the checkpoint is overwritten after every evaluation;
    # the "best_*" fields hold this evaluation's accuracies.
    if acc_clean > 0:
        print(' Saving...')
        state_dict = {'netS': netS.state_dict(),
                      'schedulerS': schedulerS.state_dict(),
                      # NOTE(review): the key says "optimizerT" but it holds
                      # the student optimizer's state. It is left unchanged in
                      # case other code reads this key.
                      'optimizerT': optimizerS.state_dict(),
                      'best_clean_acc': acc_clean,
                      'best_bd_acc': acc_bd,
                      'epoch_current': epoch}
        torch.save(state_dict, opt.ckpt_path)
    return acc_clean, acc_bd

def _load_pretrained(net, ckpt_path, key, role, device):
    """Load ``checkpoint[key]`` from ``ckpt_path`` into ``net`` and return the checkpoint dict.

    Exits the process with status 1 if the file does not exist. ``role``
    (e.g. "student" or "teacher") is only used in the printed messages.
    """
    if not os.path.exists(ckpt_path):
        print('Pretrained {} model doesnt exist'.format(role))
        sys.exit(1)
    print('Load pretrained {} model'.format(role))
    # map_location lets a checkpoint saved on another device load onto opt.device.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location=device)
    net.load_state_dict(checkpoint[key])
    return checkpoint


def main():
    """Fine-tune a student with NAD guided by a teacher, logging per-epoch accuracy.

    The student weights come from
    ``<student_checkpoints>/<saving_prefix>/<dataset>/`` (key ``netC``). The
    teacher weights come from ``<checkpoints>/<saving_prefix>_teacher/<dataset>/``
    (key ``netT``). Training then runs for ``opt.n_iters`` epochs, and after
    each epoch a tab-separated ``epoch, clean_acc, bd_acc`` line is written
    to ``opt.result_file``.
    """
    opt = config.get_arguments().parse_args()
    if opt.dataset == 'cifar10':
        opt.input_height = 32
        opt.input_width = 32
        opt.input_channel = 3
    else:
        raise ValueError("Invalid Dataset")

    # Dataset
    train_dl = get_dataloader(opt, True)
    test_dl = get_dataloader(opt, False)

    # Prepare pretrained models (optimizerT/schedulerT are not used below)
    netS, optimizerS, schedulerS, netT, optimizerT, schedulerT = get_model(opt)

    opt.student_ckpt_folder = os.path.join(opt.student_checkpoints, opt.saving_prefix, opt.dataset)
    opt.student_ckpt_path = os.path.join(opt.student_ckpt_folder, '{}_{}.pth.tar'.format(opt.dataset, opt.saving_prefix))
    state_dict_S = _load_pretrained(netS, opt.student_ckpt_path, 'netC', 'student', opt.device)
    print(state_dict_S['best_clean_acc'], state_dict_S['best_bd_acc'])

    opt.teacher_ckpt_folder = os.path.join(opt.checkpoints, '{}_teacher'.format(opt.saving_prefix), opt.dataset)
    opt.teacher_ckpt_path = os.path.join(opt.teacher_ckpt_folder, '{}_{}_teacher.pth.tar'.format(opt.dataset, opt.saving_prefix))
    _load_pretrained(netT, opt.teacher_ckpt_path, 'netT', 'teacher', opt.device)

    # NAD outputs (checkpoint and logs) live under their own folder, which is
    # wiped and recreated below.
    opt.ckpt_folder = os.path.join(opt.checkpoints, '{}_nad'.format(opt.saving_prefix), opt.dataset)
    opt.ckpt_path = os.path.join(opt.ckpt_folder, '{}_{}_nad.pth.tar'.format(opt.dataset, opt.saving_prefix))
    opt.log_dir = os.path.join(opt.ckpt_folder, 'log_dir')

    print('Train NAD model!!!')
    best_clean_acc = 0.
    best_bd_acc = 0.
    epoch_current = 0
    # log_dir is inside ckpt_folder, so creating it also recreates ckpt_folder.
    shutil.rmtree(opt.ckpt_folder, ignore_errors=True)
    create_dir(opt.log_dir)

    with open(opt.result_file, 'w') as f:
        for epoch in range(epoch_current, opt.n_iters):
            print('Epoch {}:'.format(epoch + 1))
            train(netS, netT, optimizerS, schedulerS, train_dl, epoch, opt)
            # eval returns this epoch's accuracies (0-dim tensors); they are
            # shown as "Best" in the next epoch's log line.
            best_clean_acc, best_bd_acc = eval(netS,
                                               optimizerS,
                                               schedulerS,
                                               test_dl,
                                               best_clean_acc,
                                               best_bd_acc,
                                               epoch, opt)
            acc_clean = best_clean_acc.cpu().detach().numpy()
            acc_bd = best_bd_acc.cpu().detach().numpy()
            f.write('%s\t%s\t%s\n' % (epoch, acc_clean, acc_bd))

        
    
# Run the NAD training pipeline only when executed as a script.
if __name__ == '__main__':
    main()
