import os
from sched import scheduler
import config
import shutil
import numpy as np
import random
from copy import deepcopy
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from dataloader import get_dataloader

import sys
sys.path.insert(0,'../')
from classifier_models import PreActResNet18
from networks.models import UnetGenerator
from utils.dataloader import PostTensorTransform
from utils.utils import progress_bar
from utils.dct import *

def create_dir(path_dir):
    """Create ``path_dir`` together with any missing parent directories.

    Does nothing if the directory already exists.

    Args:
        path_dir: relative or absolute directory path.

    Raises:
        OSError: if a directory cannot be created (e.g. permission denied, or a
            path component already exists as a regular file).
    """
    # os.makedirs handles relative paths with or without a leading './',
    # '../' prefixes and absolute paths, and only ignores "already exists".
    os.makedirs(path_dir, exist_ok=True)

def backdoor(clean_x, opt):
    """Return a poisoned copy of a batch of images for the configured attack.

    ``clean_x`` itself is never modified; a new tensor is returned.

    Args:
        clean_x: image batch; assumed (N, C, H, W) — TODO confirm against the
            dataloader.
        opt: config; reads ``attack_name``.

    Returns:
        Tensor with the same shape, dtype and device as ``clean_x``.

    Supported attacks:
        * ``"badnets"``: sets a ``pat_size`` x ``pat_size`` square to 1, placed
          one pixel in from the bottom-right corner.
        * ``"narcisuss"``: adds the trigger stored in
          ``./triggers/narcisuss.npy``.
        * anything else: adds the RGB trigger image
          ``./triggers/<attack_name>.png``, rescaled from [0, 1] to [-1, 1].
    """
    if opt.attack_name == "badnets":
        pat_size = 4
        output = torch.clone(clean_x)
        height, width = output.shape[-2], output.shape[-1]
        # The patch covers [size-1-pat_size, size-1) on both spatial axes,
        # i.e. it stops one pixel short of the border.
        output[..., height - 1 - pat_size:height - 1,
               width - 1 - pat_size:width - 1] = 1
        return output

    trigger_dir = './triggers'
    if opt.attack_name == "narcisuss":
        trimg = torch.from_numpy(np.load(os.path.join(trigger_dir, opt.attack_name + '.npy')))
    else:
        # plt.imread returns PNGs as floats in [0, 1]; move channels first (HWC -> CHW).
        trimg = np.transpose(plt.imread(os.path.join(trigger_dir, opt.attack_name + '.png')), (2, 0, 1))
        trimg = torch.from_numpy((trimg * 2) - np.ones_like(trimg))

    # Match the batch's device and dtype: the addition must not fail across
    # devices or silently promote to float64 (np.load may yield float64).
    trimg = trimg.to(device=clean_x.device, dtype=clean_x.dtype)
    # The single trigger broadcasts over the batch dimension.
    return clean_x + trimg

def get_model(opt):
    """Build two classifiers (C and T) with their SGD optimizers and LR schedulers.

    Args:
        opt: config; reads ``dataset``, ``device``, ``lr_C``, ``lr_T``,
            ``schedulerC_milestones``, ``schedulerC_lambda``,
            ``schedulerT_milestones`` and ``schedulerT_lambda``.

    Returns:
        Tuple ``(netC, optimizerC, schedulerC, netT, optimizerT, schedulerT)``.

    Raises:
        ValueError: if ``opt.dataset`` is not supported (only ``'cifar10'``).
    """
    if opt.dataset != 'cifar10':
        # Fail loudly here instead of crashing later on ``None.parameters()``.
        raise ValueError("Unsupported dataset: {}".format(opt.dataset))

    netC = PreActResNet18().to(opt.device)
    netT = PreActResNet18().to(opt.device)

    # The *_lambda options are passed as MultiStepLR's third positional
    # argument, i.e. the multiplicative decay factor ``gamma``.
    optimizerC = torch.optim.SGD(netC.parameters(), opt.lr_C, momentum=0.9, weight_decay=5e-4, nesterov=True)
    schedulerC = torch.optim.lr_scheduler.MultiStepLR(optimizerC, opt.schedulerC_milestones, opt.schedulerC_lambda)
    optimizerT = torch.optim.SGD(netT.parameters(), opt.lr_T, momentum=0.9, weight_decay=1e-4, nesterov=True)
    schedulerT = torch.optim.lr_scheduler.MultiStepLR(optimizerT, opt.schedulerT_milestones, opt.schedulerT_lambda)

    return netC, optimizerC, schedulerC, netT, optimizerT, schedulerT

def create_targets_bd(targets, opt):
    """Return a tensor shaped like ``targets`` filled with ``opt.target_label``.

    The result is placed on ``opt.device``.
    """
    filled = opt.target_label * torch.ones_like(targets)
    return filled.to(opt.device)

def train(netT, optimizerT, schedulerT, train_dl, epoch, opt):
    """Run one epoch of cross-entropy training of ``netT``, then step ``schedulerT``.

    Each batch is moved to ``opt.device`` and passed through
    ``PostTensorTransform(opt)`` before the forward pass.

    Args:
        netT: model to train (switched to train mode here).
        optimizerT: optimizer over ``netT``'s parameters.
        schedulerT: LR scheduler, stepped once at the end of the epoch.
        train_dl: iterable of ``(inputs, targets)`` batches.
        epoch: current epoch index (unused; kept for interface compatibility).
        opt: config; reads ``device`` (and whatever PostTensorTransform reads).
    """
    # NOTE(review): anomaly detection slows every backward pass considerably;
    # it looks like a debugging aid and could be removed once NaNs are ruled out.
    torch.autograd.set_detect_anomaly(True)
    print(" Train:")
    netT.train()

    criterion_CE = torch.nn.CrossEntropyLoss()
    transform = PostTensorTransform(opt)

    total_sample = 0
    total_loss_ce = 0.
    total_correct = 0

    for batch_idx, (inputs, targets) in enumerate(train_dl):
        optimizerT.zero_grad()
        inputs, targets = inputs.to(opt.device), targets.to(opt.device)
        bs = inputs.shape[0]
        inputs = transform(inputs)
        preds = netT(inputs)
        loss_ce = criterion_CE(preds, targets)
        if torch.isnan(preds).any() or torch.isnan(targets).any():
            print(preds, targets)
        loss_ce.backward()
        optimizerT.step()

        total_sample += bs
        # CrossEntropyLoss returns the batch *mean*; weight it by the batch size
        # so that dividing by total_sample gives the true per-sample average.
        total_loss_ce += loss_ce.detach() * bs
        total_correct += torch.sum(torch.argmax(preds, dim=1) == targets)
        avg_acc = total_correct * 100. / total_sample
        avg_loss_ce = total_loss_ce / total_sample
        progress_bar(batch_idx, len(train_dl), 'CE Loss: {:.4f} | Acc: {:.4f}'.format(avg_loss_ce, avg_acc))

    schedulerT.step()

def eval(netT, optimizerT, schedulerT, test_dl, best_clean_acc, best_bd_acc, epoch, opt):
    """Measure clean and backdoor accuracy; checkpoint when clean accuracy improves.

    NOTE(review): this name shadows the ``eval`` builtin within this module.

    Backdoor accuracy is the percentage of ``backdoor(inputs, opt)`` samples
    predicted as ``opt.target_label``. Samples whose true label already equals
    the target are not excluded.

    Args:
        netT: model to evaluate (switched to eval mode here).
        optimizerT, schedulerT: saved into the checkpoint alongside ``netT``.
        test_dl: iterable of ``(inputs, targets)`` batches.
        best_clean_acc, best_bd_acc: best values seen so far.
        epoch: stored in the checkpoint as ``epoch_current``.
        opt: config; reads ``device``, ``ckpt_path`` and whatever
            ``backdoor`` / ``create_targets_bd`` read.

    Returns:
        Updated ``(best_clean_acc, best_bd_acc)``. When clean accuracy does not
        improve, or ``test_dl`` yields no batches, the inputs are returned
        unchanged.
    """
    print(" Eval:")
    netT.eval()

    total_sample = 0
    total_clean_correct = 0
    total_bd_correct = 0

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_dl):
            inputs, targets = inputs.to(opt.device), targets.to(opt.device)
            bs = inputs.shape[0]
            total_sample += bs
            # Evaluate Clean
            preds_clean = netT(inputs)
            total_clean_correct += torch.sum(torch.argmax(preds_clean, 1) == targets)

            # Evaluate Backdoor
            inputs_bd = backdoor(inputs, opt)
            targets_bd = create_targets_bd(targets, opt)
            preds_bd = netT(inputs_bd)
            total_bd_correct += torch.sum(torch.argmax(preds_bd, 1) == targets_bd)

            acc_clean = total_clean_correct * 100. / total_sample
            acc_bd = total_bd_correct * 100. / total_sample

            info_string = "Clean Acc: {:.4f} - Best: {:.4f} | Bd Acc: {:.4f} - Best: {:.4f}".format(acc_clean, best_clean_acc, acc_bd, best_bd_acc)
            progress_bar(batch_idx, len(test_dl), info_string)

    if total_sample == 0:
        # Nothing was evaluated, so there is no accuracy to compare or save.
        return best_clean_acc, best_bd_acc

    # Save checkpoint; the backdoor accuracy is tracked together with the
    # clean accuracy, not independently.
    if acc_clean > best_clean_acc:
        print(' Saving...')
        best_clean_acc = acc_clean
        best_bd_acc = acc_bd
        state_dict = {'netT': netT.state_dict(),
                      'schedulerT': schedulerT.state_dict(),
                      'optimizerT': optimizerT.state_dict(),
                      'best_clean_acc': acc_clean,
                      'best_bd_acc': acc_bd,
                      'epoch_current': epoch}
        torch.save(state_dict, opt.ckpt_path)
    return best_clean_acc, best_bd_acc

def main():
    """Fine-tune a teacher network initialised from a pretrained student checkpoint.

    Steps:
        1. Parse CLI options via ``config.get_arguments()``.
        2. Build the train/test dataloaders.
        3. Load the student checkpoint from
           ``<student_checkpoints>/<saving_prefix>/<dataset>/``.
        4. Copy its ``netC`` weights into the teacher network.
        5. Wipe and recreate the teacher checkpoint folder.
        6. Train for ``opt.n_iters`` epochs, keeping the best checkpoint by
           clean accuracy.

    Exits with status 1 if the student checkpoint is missing.
    """
    opt = config.get_arguments().parse_args()
    if opt.dataset == 'cifar10':
        opt.input_height = 32
        opt.input_width = 32
        opt.input_channel = 3
    else:
        raise Exception("Invalid Dataset")

    # Dataset
    train_dl = get_dataloader(opt, True)
    print(len(train_dl.dataset))
    test_dl = get_dataloader(opt, False)
    print(len(test_dl.dataset))

    # Prepare student model
    netS, optimizerS, schedulerS, netT, optimizerT, schedulerT = get_model(opt)
    opt.student_ckpt_folder = os.path.join(opt.student_checkpoints, opt.saving_prefix, opt.dataset)
    opt.student_ckpt_path = os.path.join(opt.student_ckpt_folder, '{}_{}.pth.tar'.format(opt.dataset, opt.saving_prefix))
    opt.student_log_dir = os.path.join(opt.student_ckpt_folder, 'log_dir')

    if not os.path.exists(opt.student_ckpt_path):
        print('Pretrained student model doesnt exist')
        # Non-zero status so shell scripts / schedulers see the failure
        # (a bare exit() reports success).
        sys.exit(1)

    print('Load pretrained student model')
    # map_location lets a checkpoint saved on another device load on opt.device.
    state_dict_S = torch.load(opt.student_ckpt_path, map_location=opt.device)
    netS.load_state_dict(state_dict_S['netC'])
    optimizerS.load_state_dict(state_dict_S['optimizerC'])
    schedulerS.load_state_dict(state_dict_S['schedulerC'])
    print(state_dict_S['best_clean_acc'], state_dict_S['best_bd_acc'])

    # Finetune to get teacher model: start from the student's classifier weights.
    netT.load_state_dict(state_dict_S['netC'])

    opt.ckpt_folder = os.path.join(opt.checkpoints, '{}_teacher'.format(opt.saving_prefix), opt.dataset)
    opt.ckpt_path = os.path.join(opt.ckpt_folder, '{}_{}_teacher.pth.tar'.format(opt.dataset, opt.saving_prefix))
    opt.log_dir = os.path.join(opt.ckpt_folder, 'log_dir')

    print('Train teacher model!!!')
    best_clean_acc = 0.
    best_bd_acc = 0.
    epoch_current = 0
    # WARNING: deletes any earlier teacher checkpoint for this prefix/dataset.
    shutil.rmtree(opt.ckpt_folder, ignore_errors=True)
    create_dir(opt.log_dir)

    for epoch in range(epoch_current, opt.n_iters):
        print('Epoch {}:'.format(epoch + 1))
        train(netT, optimizerT, schedulerT, train_dl, epoch, opt)
        best_clean_acc, best_bd_acc = eval(netT,
                                           optimizerT,
                                           schedulerT,
                                           test_dl,
                                           best_clean_acc,
                                           best_bd_acc,
                                           epoch, opt)
    
    
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()