import config 
import torchvision 
import torch
import os
import shutil
import numpy as np
import torch.nn.functional as F
import json

from utils.dataloader import get_dataloader, PostTensorTransform
from utils.utils import progress_bar
from classifier_models import PreActResNet18
from networks.models import Normalizer, Denormalizer, NetC_MNIST
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import RandomErasing
from time import time


def create_dir(path_dir):
    """Create ``path_dir`` and any missing parent directories.

    Directories that already exist are left untouched, so calling this more
    than once is safe. An empty path is a no-op.

    Args:
        path_dir: directory path to create; may be relative or absolute.
    """
    if not path_dir:
        return
    # os.makedirs handles absolute paths, '../' prefixes, trailing slashes and
    # paths without a leading './'. The previous manual split-and-mkdir loop
    # re-rooted absolute paths under './' and raised ValueError when the path
    # had no leading './'.
    os.makedirs(path_dir, exist_ok=True)
        
        
def get_model(opt):
    """Build the classifier, its SGD optimizer and its MultiStepLR scheduler.

    Args:
        opt: parsed options. Reads ``dataset``, ``device``, ``num_classes``
            (gtsrb only), ``lr_C``, ``schedulerC_milestones`` and
            ``schedulerC_lambda``.

    Returns:
        Tuple ``(netC, optimizerC, schedulerC)``.

    Raises:
        ValueError: if ``opt.dataset`` is not ``cifar10``, ``gtsrb`` or ``mnist``.
    """
    if opt.dataset == 'cifar10':
        netC = PreActResNet18().to(opt.device)
    elif opt.dataset == 'gtsrb':
        netC = PreActResNet18(num_classes=opt.num_classes).to(opt.device)
    elif opt.dataset == 'mnist':
        netC = NetC_MNIST().to(opt.device)
    else:
        # Fail fast instead of falling through with netC = None and crashing
        # later with an opaque AttributeError on netC.parameters().
        raise ValueError('Invalid dataset: {}'.format(opt.dataset))

    optimizerC = torch.optim.SGD(netC.parameters(), opt.lr_C, momentum=0.9, weight_decay=5e-4)

    # The third positional argument of MultiStepLR is `gamma`, the factor by
    # which the LR is multiplied at each milestone.
    schedulerC = torch.optim.lr_scheduler.MultiStepLR(optimizerC, opt.schedulerC_milestones, opt.schedulerC_lambda)

    return netC, optimizerC, schedulerC


def train(netC, optimizerC, schedulerC, train_dl, noise_grid, identity_grid, tf_writer, epoch, opt):
    """Run one training epoch on a mix of backdoor, noise-mode and clean samples.

    Each batch of size ``bs`` is split into three consecutive parts:
      * the first ``int(bs * opt.pc)`` samples are warped with the fixed
        backdoor grid and relabelled to ``opt.target_label``;
      * the next ``int(num_bd * opt.noise_ratio)`` samples are warped with the
        backdoor grid plus a random per-sample offset ("noise" mode) and keep
        their original labels;
      * the remaining samples are used unchanged.
    All samples then go through ``PostTensorTransform`` before the forward pass.

    Args:
        netC: classifier being trained.
        optimizerC: optimizer for ``netC``.
        schedulerC: LR scheduler, stepped once at the end of the epoch.
        train_dl: dataloader yielding ``(inputs, targets)`` batches.
        noise_grid: warping offsets (in main() built as a (1, H, H, 2) tensor).
        identity_grid: identity sampling grid (in main() built as (1, H, H, 2)).
        tf_writer: TensorBoard SummaryWriter.
        epoch: current epoch index, used as the TensorBoard step.
        opt: parsed options; reads ``pc``, ``noise_ratio``, ``s``,
            ``input_height``, ``target_label`` and ``device``.
    """
    print(" Train:")
    netC.train()
    rate_bd = opt.pc
    total_loss_ce = 0
    total_sample = 0

    total_clean = 0
    total_bd = 0
    total_noise = 0
    total_clean_correct = 0
    total_bd_correct = 0
    total_noise_correct = 0
    criterion_CE = torch.nn.CrossEntropyLoss()

    denormalizer = Denormalizer(opt)
    transforms = PostTensorTransform(opt).to(opt.device)

    # The backdoor warping grid is the same for every batch, so build it once.
    grid_temps = identity_grid + opt.s * noise_grid / opt.input_height
    grid_temps = torch.clamp(grid_temps, -1, 1)

    # Batch whose images are sent to TensorBoard: the second-to-last one as
    # before, falling back to batch 0 for single-batch loaders (previously
    # `grid` was never assigned in that case and add_image raised NameError).
    image_batch_idx = max(len(train_dl) - 2, 0)
    grid = None
    avg_acc_clean = avg_acc_bd = avg_acc_noise = None

    for batch_idx, (inputs, targets) in enumerate(train_dl):
        optimizerC.zero_grad()

        inputs, targets = inputs.to(opt.device), targets.to(opt.device)
        bs = inputs.shape[0]

        # Sizes of the backdoor and noise-mode slices of this batch.
        num_bd = int(bs * rate_bd)
        num_noise = int(num_bd * opt.noise_ratio)

        # Noise mode: add a random offset in [-1/H, 1/H] to every grid point.
        ins = torch.rand(num_noise, opt.input_height, opt.input_height, 2).to(opt.device) * 2 - 1
        grid_temps2 = grid_temps.repeat(num_noise, 1, 1, 1) + ins / opt.input_height
        grid_temps2 = torch.clamp(grid_temps2, -1, 1)

        inputs_bd = F.grid_sample(inputs[:num_bd], grid_temps.repeat(num_bd, 1, 1, 1), align_corners=True)
        targets_bd = torch.ones_like(targets[:num_bd]) * opt.target_label

        inputs_noise = F.grid_sample(inputs[num_bd:(num_bd + num_noise)], grid_temps2, align_corners=True)

        total_inputs = torch.cat([inputs_bd, inputs_noise, inputs[(num_bd + num_noise):]], dim=0)
        total_inputs = transforms(total_inputs)
        # Only the backdoor slice is relabelled; noise and clean samples keep
        # their true labels.
        total_targets = torch.cat([targets_bd, targets[num_bd:]], dim=0)
        total_preds = netC(total_inputs)

        loss_ce = criterion_CE(total_preds, total_targets)
        loss_ce.backward()
        optimizerC.step()

        total_sample += bs
        # loss_ce is the mean over the batch; weight it by the batch size so
        # dividing by total_sample below yields a true per-sample average
        # (previously the sum of batch means was divided by the sample count).
        total_loss_ce += loss_ce.detach() * bs

        total_clean += bs - num_bd - num_noise
        total_bd += num_bd
        total_noise += num_noise
        total_clean_correct += torch.sum(torch.argmax(total_preds[(num_bd + num_noise):], dim=1) == total_targets[(num_bd + num_noise):])
        total_bd_correct += torch.sum(torch.argmax(total_preds[:num_bd], dim=1) == targets_bd)
        total_noise_correct += torch.sum(torch.argmax(total_preds[num_bd:(num_bd + num_noise)], dim=1) == total_targets[num_bd:(num_bd + num_noise)])

        avg_acc_clean = total_clean_correct * 100. / total_clean
        avg_acc_bd = total_bd_correct * 100. / total_bd
        avg_acc_noise = total_noise_correct * 100. / total_noise
        avg_loss_ce = total_loss_ce / total_sample
        progress_bar(batch_idx, len(train_dl), 'CE Loss: {:.4f} | Clean Acc: {:.4f} | Bd Acc: {:.4f} | Noise Acc: {:.4f}'.format(avg_loss_ce,
                                                                                                            avg_acc_clean,
                                                                                                            avg_acc_bd, avg_acc_noise))

        # Image for tensorboard: original | warped | transformed | residual,
        # concatenated along dim 2.
        if batch_idx == image_batch_idx:
            residual = inputs_bd - inputs[:num_bd]
            batch_img = torch.cat([inputs[:num_bd], inputs_bd, total_inputs[:num_bd], residual], dim=2)
            batch_img = denormalizer(batch_img)
            batch_img = F.interpolate(batch_img, scale_factor=(4, 4))
            grid = torchvision.utils.make_grid(batch_img, normalize=True)

    # for tensorboard; skipped when the loader produced no batches.
    if avg_acc_clean is not None:
        tf_writer.add_scalars('Clean Accuracy', {'Clean': avg_acc_clean,
                                                 'Bd': avg_acc_bd, 'Noise': avg_acc_noise}, epoch)
    if grid is not None:
        tf_writer.add_image('Images', grid, global_step=epoch)

    schedulerC.step()


def eval(netC, optimizerC, schedulerC, test_dl, noise_grid, identity_grid, best_clean_acc, best_bd_acc, best_noise_acc, tf_writer, epoch, opt):
    """Measure clean, backdoor and noise-mode test accuracy; checkpoint on improvement.

    For every test batch the classifier is evaluated on:
      * the clean inputs against their true labels;
      * the inputs warped by the backdoor grid against ``opt.target_label``;
      * the inputs warped by the backdoor grid plus a random offset against
        their true labels.

    A checkpoint is saved when clean accuracy improves, or when clean accuracy
    is within 0.1 points of the best and backdoor accuracy improves. The
    checkpoint holds the model/optimizer/scheduler state, the accuracies, the
    epoch and both grids, and is written to ``opt.ckpt_path``. A JSON summary
    is also written to ``opt.ckpt_folder/results.txt``.

    The name shadows the builtin ``eval``; it is kept for existing callers.

    Returns:
        Updated ``(best_clean_acc, best_bd_acc, best_noise_acc)``. They are
        returned unchanged if ``test_dl`` yields no batches.
    """
    print(" Eval:")
    netC.eval()

    total_sample = 0
    total_clean_correct = 0
    total_bd_correct = 0
    total_noise_correct = 0

    with torch.no_grad():
        # The backdoor warping grid is identical for every batch; build it once.
        grid_temps = identity_grid + opt.s * noise_grid / opt.input_height
        grid_temps = torch.clamp(grid_temps, -1, 1)

        for batch_idx, (inputs, targets) in enumerate(test_dl):
            inputs, targets = inputs.to(opt.device), targets.to(opt.device)
            bs = inputs.shape[0]
            total_sample += bs

            # Evaluate Clean
            preds_clean = netC(inputs)
            total_clean_correct += torch.sum(torch.argmax(preds_clean, 1) == targets)

            # Evaluate Backdoor: the fixed grid, every sample relabelled to the target.
            ins = torch.rand(bs, opt.input_height, opt.input_height, 2).to(opt.device) * 2 - 1
            grid_temps2 = grid_temps.repeat(bs, 1, 1, 1) + ins / opt.input_height
            grid_temps2 = torch.clamp(grid_temps2, -1, 1)

            inputs_bd = F.grid_sample(inputs, grid_temps.repeat(bs, 1, 1, 1), align_corners=True)
            targets_bd = torch.ones_like(targets) * opt.target_label
            preds_bd = netC(inputs_bd)
            total_bd_correct += torch.sum(torch.argmax(preds_bd, 1) == targets_bd)

            # Evaluate noise mode: the jittered grid should keep the true label.
            inputs_noise = F.grid_sample(inputs, grid_temps2, align_corners=True)
            preds_noise = netC(inputs_noise)
            total_noise_correct += torch.sum(torch.argmax(preds_noise, 1) == targets)

            acc_clean = total_clean_correct * 100. / total_sample
            acc_bd = total_bd_correct * 100. / total_sample
            acc_noise = total_noise_correct * 100. / total_sample

            # Six placeholders for six values: the previous string had only
            # five, so best_noise_acc was silently dropped.
            info_string = "Clean Acc: {:.4f} - Best: {:.4f} | Bd Acc: {:.4f} - Best: {:.4f} | Noise: {:.4f} - Best: {:.4f}".format(
                acc_clean, best_clean_acc, acc_bd, best_bd_acc, acc_noise, best_noise_acc)
            progress_bar(batch_idx, len(test_dl), info_string)

    if total_sample == 0:
        # No batches were evaluated, so there are no accuracies to log or save.
        return best_clean_acc, best_bd_acc, best_noise_acc

    # tensorboard
    tf_writer.add_scalars('Test Accuracy', {'Clean': acc_clean,
                                            'Bd': acc_bd,
                                            'Noise': acc_noise}, epoch)

    # Save checkpoint
    if acc_clean > best_clean_acc or (acc_clean > best_clean_acc - 0.1 and acc_bd > best_bd_acc):
        print(' Saving...')
        best_clean_acc = acc_clean
        best_bd_acc = acc_bd
        best_noise_acc = acc_noise
        state_dict = {'netC': netC.state_dict(),
                      'schedulerC': schedulerC.state_dict(),
                      'optimizerC': optimizerC.state_dict(),
                      'best_clean_acc': acc_clean,
                      'best_bd_acc': acc_bd,
                      'best_cross_acc': acc_noise,
                      'epoch_current': epoch,
                      'identity_grid': identity_grid,
                      'noise_grid': noise_grid}
        torch.save(state_dict, opt.ckpt_path)
        with open(os.path.join(opt.ckpt_folder, 'results.txt'), 'w') as f:
            results_dict = {'clean_acc': best_clean_acc.item(),
                            'bd_acc': best_bd_acc.item(),
                            'noise_acc': best_noise_acc.item()}
            json.dump(results_dict, f, indent=2)

    return best_clean_acc, best_bd_acc, best_noise_acc
    

def main():
    """Parse options, build data and model, then alternate train/eval epochs.

    With ``opt.continue_training`` set, the model, optimizer and scheduler
    state and both warping grids are restored from ``opt.ckpt_path``.
    Training then resumes with the epoch after the checkpointed one.
    Otherwise the checkpoint folder is wiped and fresh random warping grids
    are generated. The options are dumped to ``opt.json`` and training starts
    at epoch 0.
    """
    opt = config.get_arguments().parse_args()

    # (input_height, input_width, input_channel, num_classes) for each dataset.
    dataset_specs = {
        'cifar10': (32, 32, 3, 10),
        'gtsrb': (32, 32, 3, 43),
        'mnist': (28, 28, 1, 10),
    }
    if opt.dataset not in dataset_specs:
        raise ValueError("Invalid Dataset")
    opt.input_height, opt.input_width, opt.input_channel, opt.num_classes = dataset_specs[opt.dataset]

    # Dataset
    train_dl = get_dataloader(opt, True)
    test_dl = get_dataloader(opt, False)

    # prepare model
    netC, optimizerC, schedulerC = get_model(opt)

    # Checkpoint / log locations
    opt.ckpt_folder = os.path.join(opt.checkpoints, 'k_{}'.format(opt.k), 's_{}'.format(opt.s), opt.dataset)
    opt.ckpt_path = os.path.join(opt.ckpt_folder, '{}_all2one_k{}_s{}.pth.tar'.format(opt.dataset, opt.k, opt.s))
    opt.log_dir = os.path.join(opt.ckpt_folder, 'log_dir')
    print(opt.ckpt_path)
    create_dir(opt.log_dir)

    if opt.continue_training:
        if not os.path.exists(opt.ckpt_path):
            print('Pretrained model doesnt exist')
            # Non-zero exit status so wrapping scripts can detect the failure.
            raise SystemExit(1)

        print('Continue training!!')
        # map_location allows resuming a checkpoint written on another device.
        state_dict = torch.load(opt.ckpt_path, map_location=opt.device)
        netC.load_state_dict(state_dict['netC'])
        # Also restore optimizer momentum and the LR schedule position. Both
        # are saved in the checkpoint; without them the MultiStepLR milestones
        # restarted from scratch on resume.
        optimizerC.load_state_dict(state_dict['optimizerC'])
        schedulerC.load_state_dict(state_dict['schedulerC'])

        best_clean_acc = state_dict['best_clean_acc']
        best_bd_acc = state_dict['best_bd_acc']
        best_noise_acc = state_dict['best_cross_acc']
        # The checkpoint is written after epoch `epoch_current` has finished
        # training and evaluation, so resume with the next epoch.
        epoch_current = state_dict['epoch_current'] + 1

        identity_grid = state_dict['identity_grid']
        noise_grid = state_dict['noise_grid']
    else:
        print('Train from scratch!!!')
        best_clean_acc = 0.
        best_bd_acc = 0.
        best_noise_acc = 0.
        epoch_current = 0

        # Prepare grid: a random 2 x k x k field scaled to mean |value| 1,
        # upsampled bicubically to the image size and moved to channels-last.
        ins = torch.rand(1, 2, opt.k, opt.k) * 2 - 1
        ins = ins / torch.mean(torch.abs(ins))
        noise_grid = F.interpolate(ins, size=opt.input_height, mode='bicubic', align_corners=True).permute(0, 2, 3, 1).to(opt.device)
        # Identity sampling grid over [-1, 1]. With meshgrid's default 'ij'
        # indexing, `y` varies along columns and `x` along rows, so (y, x)
        # matches grid_sample's (width, height) coordinate order.
        array1d = torch.linspace(-1, 1, steps=opt.input_height)
        x, y = torch.meshgrid(array1d, array1d)
        identity_grid = torch.stack((y, x), 2)[None, ...].to(opt.device)

        shutil.rmtree(opt.ckpt_folder, ignore_errors=True)
        create_dir(opt.log_dir)
        with open(os.path.join(opt.ckpt_folder, 'opt.json'), 'w') as f:
            json.dump(opt.__dict__, f, indent=2)

    tf_writer = SummaryWriter(log_dir=opt.log_dir)

    for epoch in range(epoch_current, opt.n_iters):
        print('Epoch {}:'.format(epoch + 1))
        train(netC, optimizerC, schedulerC, train_dl, noise_grid, identity_grid, tf_writer, epoch, opt)
        best_clean_acc, best_bd_acc, best_noise_acc = eval(netC,
                                                           optimizerC,
                                                           schedulerC,
                                                           test_dl,
                                                           noise_grid,
                                                           identity_grid,
                                                           best_clean_acc,
                                                           best_bd_acc, best_noise_acc, tf_writer, epoch, opt)

    # Flush any pending TensorBoard events.
    tf_writer.close()
    
    
if __name__ == '__main__':
    # Script entry point: train only when executed directly, never on import.
    main()
