import os
from sched import scheduler
import config
import shutil
import numpy as np
import random
from copy import deepcopy
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

from sam import SAM

from dataloader import get_dataloader

import sys
sys.path.insert(0, "../")

from classifier_models import PreActResNet18
from networks.models import UnetGenerator
from utils.dataloader import PostTensorTransform
from utils.utils import progress_bar


def create_dir(path_dir):
    """Create ``path_dir`` together with any missing parent directories.

    Existing directories are left untouched and an empty path is a no-op.

    Args:
        path_dir: relative or absolute directory path.

    Raises:
        OSError: if the directory cannot be created (e.g. a regular file
            already occupies the path or permissions are insufficient).
    """
    if not path_dir:
        return
    # os.makedirs handles relative, './'-prefixed, '../' and absolute paths
    # uniformly; the previous manual '/'-split crashed on paths without a
    # leading './' and re-rooted absolute/'../' paths under './'.
    os.makedirs(path_dir, exist_ok=True)


def get_model(opt):
    """Build two PreActResNet18 networks with their optimizers and LR schedules.

    ``netC`` is trained with plain Nesterov SGD. ``netT`` is trained with SAM
    wrapping Nesterov SGD. Both share the same learning rate
    (``opt.lr_C``), milestones (``opt.schedulerC_milestones``) and decay
    factor (``opt.schedulerC_lambda``).

    Args:
        opt: namespace providing ``dataset``, ``device``, ``lr_C``,
            ``schedulerC_milestones`` and ``schedulerC_lambda``.

    Returns:
        ``(netC, optimizerC, schedulerC, netT, optimizerT, schedulerT)``.

    Raises:
        ValueError: if ``opt.dataset`` is not ``'cifar10'``.
    """
    if opt.dataset != 'cifar10':
        # Fail clearly instead of falling through with netC = None and
        # crashing on netC.parameters().
        raise ValueError("Unsupported dataset: {}".format(opt.dataset))

    netC = PreActResNet18().to(opt.device)
    netT = PreActResNet18().to(opt.device)

    optimizerC = torch.optim.SGD(netC.parameters(), opt.lr_C, momentum=0.9, weight_decay=5e-4, nesterov=True)
    schedulerC = torch.optim.lr_scheduler.MultiStepLR(optimizerC, opt.schedulerC_milestones, opt.schedulerC_lambda)

    base_optimizerT = torch.optim.SGD
    # NOTE(review): assumes sam.SAM has the common signature
    # SAM(params, base_optimizer, rho=0.05, adaptive=False, **kwargs), where
    # extra kwargs are forwarded to the base optimizer. With that signature a
    # positional third argument would be taken as `rho`, so the learning rate
    # must be passed by keyword. Verify against the local sam.py.
    optimizerT = SAM(netT.parameters(), base_optimizerT, lr=opt.lr_C, momentum=0.9, weight_decay=5e-4, nesterov=True)
    schedulerT = torch.optim.lr_scheduler.MultiStepLR(optimizerT, opt.schedulerC_milestones, opt.schedulerC_lambda)

    return netC, optimizerC, schedulerC, netT, optimizerT, schedulerT

def backdoor(clean_x, opt):
    """Return a triggered copy of ``clean_x`` for the attack ``opt.attack_name``.

    The input tensor is never modified.

    Attacks:
        * ``"badnets"``: sets a 4x4 square (rows/cols 27..30) to 1 in every
          channel of every sample.
        * ``"narcisuss"``: adds the trigger stored in
          ``./triggers/narcisuss.npy``.
        * anything else: adds the image ``./triggers/<attack_name>.png``
          (HWC, transposed to CHW) after mapping its values ``v`` to ``2*v - 1``.

    Args:
        clean_x: batch of images. The indexing assumes (N, C, H, W) with
            32x32 spatial size for the BadNets patch (TODO confirm for other
            datasets).
        opt: namespace providing ``attack_name``.

    Returns:
        A new tensor with the same shape, device and dtype as ``clean_x``.
    """
    if opt.attack_name == "badnets":
        pat_size = 4
        output = torch.clone(clean_x)
        # The patch ends one pixel short of the bottom-right corner
        # (slice stop is 32-1, exclusive).
        output[:, :, 32 - 1 - pat_size:32 - 1, 32 - 1 - pat_size:32 - 1] = 1
        return output

    if opt.attack_name == "narcisuss":
        trimg = torch.from_numpy(np.load(os.path.join('./triggers', opt.attack_name + '.npy')))
    else:
        trimg = np.transpose(plt.imread(os.path.join('./triggers', opt.attack_name + '.png')), (2, 0, 1))
        trimg = torch.from_numpy((trimg * 2) - np.ones_like(trimg))

    # Match the batch's device and dtype: np.load may return float64, which
    # would otherwise promote the whole batch, and the trigger must live on
    # the same device as the inputs.
    trimg = trimg.to(device=clean_x.device, dtype=clean_x.dtype)
    # Broadcasting adds the same trigger to every sample in the batch.
    return clean_x + trimg

def create_targets_bd(targets, opt):
    """Build the backdoor label tensor for a batch.

    Every entry of the result equals ``opt.target_label``. The result has the
    shape of ``targets``, the dtype given by ``ones_like(targets) * label``,
    and lives on ``opt.device``.
    """
    unit = torch.ones_like(targets)
    poisoned_labels = unit * opt.target_label
    return poisoned_labels.to(opt.device)

def train(netT, optimizerT, schedulerT, train_dl, epoch, opt):
    """Run one epoch of SAM fine-tuning of ``netT`` on ``train_dl``.

    Each batch is moved to ``opt.device`` and augmented with
    ``PostTensorTransform(opt)``. SAM needs two forward/backward passes per
    step:
      1. the pass below computes the gradient used to perturb the weights;
      2. ``optimizerT.step(closure)`` calls ``closure``, which must re-run
         the forward pass at the perturbed weights.

    The mean cross-entropy loss and accuracy over the epoch (from the first
    forward pass) are printed, then ``schedulerT`` is stepped once.

    Args:
        netT: model being fine-tuned (switched to train mode here).
        optimizerT: SAM-style optimizer whose ``step`` takes a closure.
        schedulerT: LR scheduler stepped once per epoch.
        train_dl: iterable of ``(inputs, targets)`` batches.
        epoch: current epoch index (unused; kept for interface compatibility).
        opt: namespace providing ``device`` (and whatever
            ``PostTensorTransform`` reads).
    """
    print(" Train:")
    netT.train()

    criterion_CE = torch.nn.CrossEntropyLoss()
    transform = PostTensorTransform(opt)

    total_sample = 0
    total_loss_ce = 0.
    total_correct = 0

    for inputs, targets in train_dl:
        optimizerT.zero_grad()
        inputs, targets = inputs.to(opt.device), targets.to(opt.device)
        bs = inputs.shape[0]
        inputs = transform(inputs)

        def closure():
            # Re-run the forward pass. The first pass's graph has already
            # been freed by loss_ce.backward(), and SAM's second gradient
            # must be taken at the perturbed weights anyway.
            # NOTE(review): BatchNorm running stats are updated by both
            # passes; consider disabling them during this second pass.
            loss = criterion_CE(netT(inputs), targets)
            loss.backward()
            return loss

        preds = netT(inputs)
        loss_ce = criterion_CE(preds, targets)
        if torch.isnan(preds).any() or torch.isnan(targets).any():
            print(preds, targets)
        loss_ce.backward()
        optimizerT.step(closure)

        total_sample += bs
        # CrossEntropyLoss averages over the batch; weight by batch size so
        # dividing by total_sample yields the per-sample mean loss.
        total_loss_ce += loss_ce.detach() * bs
        total_correct += torch.sum(torch.argmax(preds, dim=1) == targets)

    if total_sample > 0:
        avg_acc = float(total_correct) * 100. / total_sample
        avg_loss_ce = float(total_loss_ce) / total_sample
        print('CE Loss: {:.4f} | Acc: {:.4f}'.format(avg_loss_ce, avg_acc))

    schedulerT.step()

def eval(netT, optimizerT, schedulerT, test_dl, best_clean_acc, best_bd_acc, epoch, opt):
    """Measure clean accuracy and backdoor accuracy of ``netT`` on ``test_dl``.

    Clean accuracy is the percentage of untouched inputs whose prediction
    matches the true label. Backdoor accuracy is the percentage of inputs
    passed through ``backdoor`` that are classified as ``opt.target_label``.
    After every batch, the running figures and the caller-supplied
    ``best_clean_acc`` / ``best_bd_acc`` (display only) are reported through
    ``progress_bar``.

    ``optimizerT``, ``schedulerT`` and ``epoch`` are accepted but unused.

    Returns:
        ``(acc_clean, acc_bd)``: percentages over every evaluated sample, as
        produced from the tensor hit counts.
    """
    print(" Eval:")
    netT.eval()

    seen = 0
    clean_hits = 0
    bd_hits = 0

    with torch.no_grad():
        for inputs, targets in test_dl:
            inputs = inputs.to(opt.device)
            targets = targets.to(opt.device)
            seen += inputs.shape[0]

            # Clean accuracy on the untouched batch.
            clean_logits = netT(inputs)
            clean_hits += torch.sum(torch.argmax(clean_logits, 1) == targets)

            # Backdoor accuracy: triggered inputs should hit the target label.
            bd_inputs = backdoor(inputs, opt)
            bd_labels = create_targets_bd(targets, opt)
            bd_logits = netT(bd_inputs)
            bd_hits += torch.sum(torch.argmax(bd_logits, 1) == bd_labels)

            acc_clean = clean_hits * 100. / seen
            acc_bd = bd_hits * 100. / seen

            progress_bar(
                "Clean Acc: {:.4f} - Best: {:.4f} | Bd Acc: {:.4f} - Best: {:.4f}".format(
                    acc_clean, best_clean_acc, acc_bd, best_bd_acc
                )
            )

    return acc_clean, acc_bd

def main():
    """Fine-tune a pretrained backdoored CIFAR-10 model with SAM.

    Steps:
      1. Parse arguments from ``config`` and set the CIFAR-10 input geometry.
      2. Build the train/test loaders and the networks via ``get_model``.
      3. Load the backdoored checkpoint from
         ``<checkpoints>/<saving_prefix>/<dataset>/<dataset>_<saving_prefix>.pth.tar``.
         If it is missing, exit with status 1.
      4. Copy its ``netC`` weights into ``netT`` and fine-tune ``netT`` for
         ``opt.n_iters`` epochs.
      5. Write ``epoch<TAB>clean_acc<TAB>bd_acc`` per epoch to
         ``opt.result_file``.

    Raises:
        Exception: if ``opt.dataset`` is not ``'cifar10'``.
    """
    opt = config.get_arguments().parse_args()
    if opt.dataset == 'cifar10':
        opt.input_height = 32
        opt.input_width = 32
        opt.input_channel = 3
    else:
        raise Exception("Invalid Dataset")

    # Dataset
    train_dl = get_dataloader(opt, True)
    print(len(train_dl.dataset))
    test_dl = get_dataloader(opt, False)
    print(len(test_dl.dataset))

    # Prepare the backdoored (student) model and the model to fine-tune
    netS, optimizerS, schedulerS, netT, optimizerT, schedulerT = get_model(opt)
    opt.backdoored_model_folder = os.path.join(opt.checkpoints, opt.saving_prefix, opt.dataset)
    opt.backdoored_model_path = os.path.join(opt.backdoored_model_folder, '{}_{}.pth.tar'.format(opt.dataset, opt.saving_prefix))
    opt.backdoored_model_log_dir = os.path.join(opt.backdoored_model_folder, 'log_dir')

    if not os.path.exists(opt.backdoored_model_path):
        print('Pretrained model doesnt exist', opt.backdoored_model_path)
        # exit() is a REPL convenience that reports success (status 0);
        # signal the failure explicitly.
        sys.exit(1)

    print('Load pretrained backdoored model')
    # map_location lets a checkpoint saved on GPU load on opt.device
    # (assumes opt.device is a device string / torch.device).
    state_dict_S = torch.load(opt.backdoored_model_path, map_location=opt.device)
    netS.load_state_dict(state_dict_S['netC'])
    optimizerS.load_state_dict(state_dict_S['optimizerC'])
    schedulerS.load_state_dict(state_dict_S['schedulerC'])
    print(state_dict_S['best_clean_acc'], state_dict_S['best_bd_acc'])

    # Finetune to get teacher model
    netT.load_state_dict(state_dict_S['netC'])

    opt.ckpt_folder = os.path.join(opt.result_checkpoints, '{}_finetuned_sam'.format(opt.saving_prefix), opt.dataset)
    # NOTE(review): ckpt_path is computed but this function never saves the
    # fine-tuned model to it.
    opt.ckpt_path = os.path.join(opt.ckpt_folder, '{}_{}_finetuned_sam.pth.tar'.format(opt.dataset, opt.saving_prefix))
    opt.log_dir = os.path.join(opt.ckpt_folder, 'log_dir')

    print('Fine tune model!!!')
    best_clean_acc = 0.
    best_bd_acc = 0.
    epoch_current = 0
    shutil.rmtree(opt.ckpt_folder, ignore_errors=True)
    create_dir(opt.log_dir)

    with open(opt.result_file, 'w') as f:
        for epoch in range(epoch_current, opt.n_iters):
            print('Epoch {}:'.format(epoch + 1))
            train(netT, optimizerT, schedulerT, train_dl, epoch, opt)
            acc_clean, acc_bd = eval(netT,
                                     optimizerT,
                                     schedulerT,
                                     test_dl,
                                     best_clean_acc,
                                     best_bd_acc,
                                     epoch, opt)
            # eval returns this epoch's accuracies. Keep the best epoch
            # (selected by clean accuracy) so the "Best" figures eval shows
            # are actually the best seen so far, not the previous epoch's.
            if acc_clean > best_clean_acc:
                best_clean_acc = acc_clean
                best_bd_acc = acc_bd
            f.write('%s\t%s\t%s\n' % (epoch,
                                      acc_clean.cpu().detach().numpy(),
                                      acc_bd.cpu().detach().numpy()))
            # Persist progress each epoch so a crash does not lose the log.
            f.flush()
    
    
if __name__ == '__main__':
    # Run the fine-tuning pipeline only when executed as a script.
    main()