import os
from sched import scheduler
import config
import shutil
import numpy as np
import random
from copy import deepcopy
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

from dataloader import get_dataloader

import sys
sys.path.insert(0, "../")

from classifier_models import PreActResNet18, ResNet18
from networks.models import UnetGenerator
from utils.dataloader import PostTensorTransform
from utils.utils import progress_bar

def create_dir(path_dir):
    """Create ``path_dir`` together with any missing parent directories.

    Works for relative paths (with or without a leading ``./``), ``../``
    paths and absolute paths alike. Directories that already exist are left
    untouched, so calling this repeatedly is safe.

    Args:
        path_dir: Directory path to create. An empty string is a no-op.

    Raises:
        OSError: if the directory cannot be created (e.g. permission denied
            or a regular file is in the way). Such errors were previously
            swallowed silently.
    """
    if not path_dir:
        return
    # exist_ok=True keeps the old "already exists is fine" behaviour while
    # still surfacing genuine failures instead of hiding them.
    os.makedirs(path_dir, exist_ok=True)

def backdoor(clean_x, opt):
    """Return a copy of ``clean_x`` with the trigger selected by ``opt.attack_name``.

    Args:
        clean_x: Batch of images. Indexing below assumes ``(N, C, H, W)``;
            the same trigger is applied to every sample.
        opt: Options namespace; only ``opt.attack_name`` is read.

    Returns:
        A new tensor with the same shape, dtype and device as ``clean_x``.
        ``clean_x`` itself is not modified.

    Triggers:
        * ``"badnets"``: every channel of a ``4 x 4`` patch is set to ``1``.
          The patch spans rows/cols ``[size-5, size-1)``, so it stops one
          pixel short of the bottom-right border.
        * ``"narcisuss"``: adds the array stored in
          ``./triggers/narcisuss.npy``. That array must broadcast against
          ``clean_x``.
        * anything else: adds ``./triggers/<attack_name>.png``, mapped from
          ``[0, 1]`` to ``[-1, 1]`` and transposed to channels-first.

    No clamping is applied after the additive triggers.
    """
    if opt.attack_name == "badnets":
        pat_size = 4
        output = torch.clone(clean_x)
        height, width = clean_x.shape[-2], clean_x.shape[-1]
        # For 32x32 inputs this is exactly the original [27:31, 27:31] patch.
        output[..., height - 1 - pat_size:height - 1, width - 1 - pat_size:width - 1] = 1
        return output

    if opt.attack_name == "narcisuss":
        trimg = torch.from_numpy(np.load(os.path.join('./triggers', opt.attack_name + '.npy')))
    else:
        trimg = np.transpose(plt.imread(os.path.join('./triggers', opt.attack_name + '.png')), (2, 0, 1))
        trimg = torch.from_numpy((trimg * 2) - np.ones_like(trimg))

    # Match the batch so the addition neither fails on a device mismatch nor
    # promotes the images to float64 (np.load's default for float arrays).
    trimg = trimg.to(device=clean_x.device, dtype=clean_x.dtype)
    # Broadcasting adds the trigger to every sample. The old narcisuss branch
    # referenced an undefined loop index and always raised NameError.
    return clean_x + trimg


def get_model(opt):
    """Build the classifier, its optimizer and the two cyclic LR schedulers.

    ``cifar10`` uses ``PreActResNet18()``. Any other dataset uses
    ``ResNet18(num_classes=opt.num_classes)``. The model is moved to
    ``opt.device``.

    Returns:
        ``(netC, optimizerC, schedulerC_p1, schedulerC_p2)``:

        * ``optimizerC`` is Nesterov SGD (momentum 0.9, weight decay 5e-4)
          starting at ``opt.lr_C``.
        * Both schedulers are triangular ``CyclicLR`` instances on that same
          optimizer, sharing ``base_lr=opt.lr_C`` and
          ``step_size_up=opt.scheduler_step_size``.
        * They differ only in peak LR: ``opt.lr_C_max1`` for p1 and
          ``opt.lr_C_max2`` for p2.
    """
    if opt.dataset == "cifar10":
        model = PreActResNet18()
    else:
        model = ResNet18(num_classes=opt.num_classes)
    netC = model.to(opt.device)

    optimizerC = torch.optim.SGD(
        netC.parameters(), opt.lr_C, momentum=0.9, weight_decay=5e-4, nesterov=True
    )
    # Constructed in order (p1 first, then p2), exactly as before.
    schedulerC_p1, schedulerC_p2 = (
        torch.optim.lr_scheduler.CyclicLR(
            optimizerC,
            base_lr=opt.lr_C,
            max_lr=peak_lr,
            step_size_up=opt.scheduler_step_size,
            mode="triangular",
        )
        for peak_lr in (opt.lr_C_max1, opt.lr_C_max2)
    )
    return netC, optimizerC, schedulerC_p1, schedulerC_p2

def create_targets_bd(targets, opt):
    """Return a tensor shaped like ``targets``, every entry ``opt.target_label``.

    The result keeps the dtype of ``targets`` (after scalar multiplication)
    and is moved to ``opt.device``.
    """
    label_fill = opt.target_label
    filled = targets.new_ones(targets.shape) * label_fill
    return filled.to(opt.device)

def train(netT, optimizerT, schedulerT_p1, schedulerT_p2, train_dl, epoch, opt):
    """Train ``netT`` for one epoch on ``train_dl``, then step one LR scheduler.

    Each batch is moved to ``opt.device`` and passed through
    ``PostTensorTransform(opt)``. It is then optimised with cross-entropy
    against its own labels.

    At the end of the epoch the per-sample mean CE loss and accuracy are
    printed. Then ``schedulerT_p1`` is stepped while
    ``epoch < opt.n_iters / 2``, and ``schedulerT_p2`` afterwards.

    Args:
        netT: Model to train (put into train mode here).
        optimizerT: Optimizer over ``netT``'s parameters.
        schedulerT_p1: LR scheduler used in the first half of training.
        schedulerT_p2: LR scheduler used in the second half of training.
        train_dl: Iterable of ``(inputs, targets)`` batches.
        epoch: Zero-based epoch index.
        opt: Options namespace; reads ``device`` and ``n_iters``.
    """
    # NOTE(review): anomaly detection is a global autograd switch that slows
    # every backward pass noticeably; it looks like a debugging leftover.
    # Kept to preserve behaviour.
    torch.autograd.set_detect_anomaly(True)
    print(" Train:")
    netT.train()

    criterion_CE = torch.nn.CrossEntropyLoss()
    transform = PostTensorTransform(opt)

    total_sample = 0
    total_loss_ce = 0.
    total_correct = 0

    for inputs, targets in train_dl:
        optimizerT.zero_grad()
        inputs, targets = inputs.to(opt.device), targets.to(opt.device)
        bs = inputs.shape[0]
        inputs = transform(inputs)
        preds = netT(inputs)
        loss_ce = criterion_CE(preds, targets)
        if torch.isnan(preds).any() or torch.isnan(targets).any():
            print(preds, targets)
        loss_ce.backward()
        optimizerT.step()

        total_sample += bs
        # CrossEntropyLoss returns the batch *mean*. Weight it by the batch
        # size so dividing by total_sample gives the per-sample mean loss.
        # Previously the sum of means was divided by the sample count.
        total_loss_ce += loss_ce.detach() * bs
        total_correct += torch.sum(torch.argmax(preds, dim=1) == targets)

    if total_sample > 0:
        avg_acc = total_correct * 100. / total_sample
        avg_loss_ce = total_loss_ce / total_sample
        print('CE Loss: {:.4f} | Acc: {:.4f}'.format(avg_loss_ce, avg_acc))
    else:
        # Guards the UnboundLocalError an empty loader used to cause.
        print('CE Loss: n/a | Acc: n/a (empty training loader)')

    if epoch < (opt.n_iters / 2):
        schedulerT_p1.step()
    else:
        schedulerT_p2.step()

def eval(netT, optimizerT, schedulerT_p1, schedulerT_p2, test_dl, best_clean_acc, best_bd_acc, epoch, opt):
    """Measure clean accuracy and backdoor accuracy of ``netT`` on ``test_dl``.

    Clean accuracy counts every sample. Backdoor accuracy uses only samples
    whose label differs from ``opt.target_label``: those are triggered with
    ``backdoor`` and count as hits when predicted as ``opt.target_label``.

    Running figures are reported through ``progress_bar`` after every batch,
    next to ``best_clean_acc`` / ``best_bd_acc``, which are display-only.
    ``optimizerT``, both schedulers and ``epoch`` are accepted but unused.

    Returns:
        ``(acc_clean, acc_bd)`` in percent, accumulated over the whole loader.
    """
    print(" Eval:")
    netT.eval()

    n_seen = 0
    n_seen_bd = 0
    hits_clean = 0
    hits_bd = 0

    with torch.no_grad():
        for inputs, targets in test_dl:
            inputs = inputs.to(opt.device)
            targets = targets.to(opt.device)
            n_seen += inputs.shape[0]

            # Clean accuracy over the full batch.
            clean_logits = netT(inputs)
            hits_clean += torch.sum(torch.argmax(clean_logits, 1) == targets)

            # Backdoor accuracy: poison only samples not already in the target class.
            keep = (targets != opt.target_label).nonzero()[:, 0]
            poisoned = backdoor(inputs[keep], opt)
            wanted = create_targets_bd(targets[keep], opt)
            bd_logits = netT(poisoned)

            n_seen_bd += len(keep)
            hits_bd += torch.sum(torch.argmax(bd_logits, 1) == wanted)

            acc_clean = hits_clean * 100.0 / n_seen
            acc_bd = hits_bd * 100.0 / n_seen_bd

            progress_bar(
                "Clean Acc: {:.4f} - Best: {:.4f} | Bd Acc: {:.4f} - Best: {:.4f}".format(
                    acc_clean, best_clean_acc, acc_bd, best_bd_acc
                )
            )

    return acc_clean, acc_bd

def main():
    """Fine-tune a previously backdoored classifier and log per-epoch accuracies.

    Steps:
      1. Parse options from ``config.get_arguments()``. Only ``cifar10`` is
         accepted.
      2. Build the model with ``get_model``.
      3. Load the checkpoint at
         ``<checkpoints>/<saving_prefix>/<dataset>/<dataset>_<saving_prefix>.pth.tar``.
         If it is missing, exit with status 1.
      4. Wipe and recreate the ``<saving_prefix>_super_finetuned`` result
         folder.
      5. For ``opt.n_iters`` epochs: train, evaluate, and write
         ``epoch<TAB>clean_acc<TAB>bd_acc`` to ``opt.result_file``.
    """
    opt = config.get_arguments().parse_args()
    if opt.dataset == 'cifar10':
        opt.input_height = 32
        opt.input_width = 32
        opt.input_channel = 3
    else:
        raise ValueError("Invalid Dataset")

    # Dataset
    train_dl = get_dataloader(opt, True)
    print(len(train_dl.dataset))
    test_dl = get_dataloader(opt, False)
    print(len(test_dl.dataset))

    # Prepare pretrained model
    netT, optimizerT, schedulerT_p1, schedulerT_p2 = get_model(opt)
    opt.backdoored_model_folder = os.path.join(opt.checkpoints, opt.saving_prefix, opt.dataset)
    opt.backdoored_model_path = os.path.join(opt.backdoored_model_folder, '{}_{}.pth.tar'.format(opt.dataset, opt.saving_prefix))
    opt.backdoored_model_log_dir = os.path.join(opt.backdoored_model_folder, 'log_dir')

    if not os.path.exists(opt.backdoored_model_path):
        print('Pretrained model doesnt exist', opt.backdoored_model_path)
        # Non-zero status so calling scripts can detect the failure. The
        # former bare exit() reported success (status 0).
        sys.exit(1)
    print('Load pretrained backdoored model')
    # map_location lets a checkpoint saved on GPU be loaded on the
    # configured device (e.g. a CPU-only machine) instead of failing.
    state_dict_C = torch.load(opt.backdoored_model_path, map_location=opt.device)
    print(state_dict_C['best_clean_acc'], state_dict_C['best_bd_acc'])

    # Finetune to get teacher model
    netT.load_state_dict(state_dict_C['netC'])

    opt.ckpt_folder = os.path.join(opt.result_checkpoints, '{}_super_finetuned'.format(opt.saving_prefix), opt.dataset)
    opt.ckpt_path = os.path.join(opt.ckpt_folder, '{}_{}_super_finetuned.pth.tar'.format(opt.dataset, opt.saving_prefix))
    opt.log_dir = os.path.join(opt.ckpt_folder, 'log_dir')

    print('Fine tune model!!!')
    best_clean_acc = 0.
    best_bd_acc = 0.
    shutil.rmtree(opt.ckpt_folder, ignore_errors=True)
    create_dir(opt.log_dir)

    # NOTE(review): opt.ckpt_path is computed but the fine-tuned weights are
    # never written anywhere in this script. Confirm whether a torch.save of
    # netT is missing.
    with open(opt.result_file, 'w') as f:
        for epoch in range(opt.n_iters):
            print('Epoch {}:'.format(epoch + 1))
            train(netT, optimizerT, schedulerT_p1, schedulerT_p2, train_dl, epoch, opt)
            acc_clean, acc_bd = eval(netT,
                                     optimizerT,
                                     schedulerT_p1,
                                     schedulerT_p2,
                                     test_dl,
                                     best_clean_acc,
                                     best_bd_acc,
                                     epoch, opt)
            # "Best" = the epoch with the highest clean accuracy, together
            # with that epoch's backdoor accuracy. Previously these variables
            # were simply overwritten with the latest epoch's numbers.
            # NOTE(review): confirm this is the intended selection rule.
            if acc_clean > best_clean_acc:
                best_clean_acc = float(acc_clean)
                best_bd_acc = float(acc_bd)
            # Per-epoch values, formatted exactly as before.
            f.write('%s\t%s\t%s\n' % (epoch,
                                      acc_clean.cpu().detach().numpy(),
                                      acc_bd.cpu().detach().numpy()))
            # Keep the log readable while training is still running.
            f.flush()
    
    
if __name__ == "__main__":
    # Only launch the fine-tuning pipeline when run as a script, not on import.
    main()