import os
import shutil
from functools import partial

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision
import torchvision.transforms.functional as fn
from torch import nn

import matplotlib.pyplot as plt

import config

from classifier_models import PreActResNet18, ResNet18

from networks.models import (
    AE,
    Denormalizer,
    Normalizer,
    UnetGenerator,
)
from utils.dataloader import PostTensorTransform
from utils.dataloader_ft import get_dataloader
from utils.utils import progress_bar
from utils.dct import *

from torch.utils.tensorboard import SummaryWriter

def create_targets_bd(targets, opt):
    """Build all-to-one backdoor labels for a batch.

    Args:
        targets: label tensor whose shape/dtype the result mirrors.
        opt: options object; reads ``target_label`` and ``device``.

    Returns:
        A tensor shaped like ``targets`` in which every entry equals
        ``opt.target_label`` (computed as ``ones_like(targets) * target_label``),
        moved to ``opt.device``.
    """
    target_label = opt.target_label
    filled = target_label * torch.ones_like(targets)
    return filled.to(opt.device)


def get_model(opt):
    """Build the classifier together with its optimizer and two LR schedulers.

    The network is ``PreActResNet18()`` for ``opt.dataset == "cifar10"`` and
    ``ResNet18(num_classes=opt.num_classes)`` otherwise, placed on
    ``opt.device``. Optimisation uses Nesterov SGD (momentum 0.9, weight decay
    5e-4) at ``opt.lr_C``. Both schedulers are triangular ``CyclicLR`` over the
    same optimizer, cycling from ``opt.lr_C`` up to ``opt.lr_C_max1`` (phase 1)
    or ``opt.lr_C_max2`` (phase 2) with ``opt.scheduler_step_size`` steps up.

    Returns:
        Tuple ``(netC, optimizerC, schedulerC_p1, schedulerC_p2)``.
    """
    if opt.dataset == "cifar10":
        model = PreActResNet18()
    else:
        model = ResNet18(num_classes=opt.num_classes)
    model = model.to(opt.device)

    optimizer = torch.optim.SGD(
        model.parameters(), opt.lr_C, momentum=0.9, weight_decay=5e-4, nesterov=True
    )

    def _cyclic(peak_lr):
        # Both phases share the base LR and step size; only the peak differs.
        return torch.optim.lr_scheduler.CyclicLR(
            optimizer,
            base_lr=opt.lr_C,
            max_lr=peak_lr,
            step_size_up=opt.scheduler_step_size,
            mode="triangular",
        )

    # Tuple elements evaluate left to right: phase-1 scheduler is built first.
    return model, optimizer, _cyclic(opt.lr_C_max1), _cyclic(opt.lr_C_max2)



def create_dir(path_dir):
    """Create ``path_dir`` and any missing parent directories.

    Directories that already exist are left untouched, so repeated calls are
    safe. Relative paths (with or without a leading ``./``) are resolved
    against the current working directory, and absolute paths are honoured.

    Args:
        path_dir: directory path to create. An empty string is a no-op.

    Raises:
        OSError: if a directory cannot be created, e.g. permission denied or
            a path component already exists as a regular file.
    """
    if not path_dir:
        return
    os.makedirs(path_dir, exist_ok=True)


def backdoor(clean_x, opt, pat_size=4):
    """Return a poisoned copy of a batch of images for ``opt.attack_name``.

    The input tensor is never modified, and no clamping is applied to the result.

    Attacks:
        * ``"badnets"``: every channel of a ``pat_size`` x ``pat_size`` square
          is set to 1. The square ends one pixel before the bottom-right corner
          and is placed relative to the real image height/width.
        * ``"narcisuss"``: the array stored in ``./triggers/narcisuss.npy`` is
          added to every image.
        * anything else: ``./triggers/<attack_name>.png`` is loaded, transposed
          from HWC to CHW, mapped with ``x * 2 - 1`` and added to every image.
          This assumes the PNG is read as floats in [0, 1]; confirm for your
          trigger files.

    Args:
        clean_x: image batch, indexed as ``(N, C, H, W)``.
        opt: options object; reads ``attack_name``.
        pat_size: side length of the BadNets patch (default 4).

    Returns:
        Tensor with the same shape, dtype and device as ``clean_x``.
    """
    output = torch.clone(clean_x)

    if opt.attack_name == "badnets":
        # Derive the corner from the tensor instead of assuming 32x32 inputs.
        height, width = output.shape[-2], output.shape[-1]
        output[..., height - 1 - pat_size:height - 1, width - 1 - pat_size:width - 1] = 1
        return output

    if opt.attack_name == "narcisuss":
        trimg = torch.from_numpy(np.load(os.path.join('./triggers', opt.attack_name + '.npy')))
    else:
        trimg = np.transpose(plt.imread(os.path.join('./triggers', opt.attack_name + '.png')), (2, 0, 1))
        trimg = torch.from_numpy((trimg * 2) - np.ones_like(trimg))

    # Match the batch's device/dtype, then broadcast the trigger over the batch.
    # The in-place add keeps the output shape/dtype identical to clean_x.
    trigger = trimg.to(device=clean_x.device, dtype=clean_x.dtype)
    output += trigger
    return output

def train(
    netC,
    optimizerC,
    schedulerC_p1,
    schedulerC_p2,
    train_dl,
    tf_writer,
    epoch,
    opt,
):
    """Train the classifier for one epoch, injecting backdoor samples on even epochs.

    On even epochs, the first ``int(bs * opt.pc)`` samples of every batch are
    replaced by ``backdoor(...)`` images relabelled to ``opt.target_label``.
    The mixed batch is then passed through ``PostTensorTransform``. On odd
    epochs the untouched clean batch is used, and every sample counts as clean
    in the accuracy statistics.

    Running clean/backdoor accuracy is shown in the progress bar. It is written
    to TensorBoard once, at the end of the epoch, under global step ``epoch``.
    On batches whose index is a multiple of 50 and that contain backdoor
    samples, those images are saved to ``<opt.temps>/backdoor_image.png``.
    The LR scheduler is stepped once per epoch: ``schedulerC_p1`` before
    epoch 50 and ``schedulerC_p2`` from then on.

    Args:
        netC: classifier being trained.
        optimizerC: optimizer over ``netC``'s parameters.
        schedulerC_p1, schedulerC_p2: LR schedulers for the two training phases.
        train_dl: iterable of ``(inputs, targets)`` batches.
        tf_writer: TensorBoard ``SummaryWriter``.
        epoch: zero-based epoch index.
        opt: options; reads ``device``, ``pc``, ``target_label``, ``temps``,
            plus whatever ``backdoor`` / ``PostTensorTransform`` need.
    """
    # NOTE(review): anomaly detection is a global autograd debug switch that
    # slows down every backward pass; kept as-is, consider removing once the
    # NaN investigation below is finished.
    torch.autograd.set_detect_anomaly(True)
    print(" Train:")
    netC.train()

    criterion_CE = torch.nn.CrossEntropyLoss()
    transforms = PostTensorTransform(opt)
    poison_epoch = epoch % 2 == 0

    total_clean = 0
    total_bd = 0
    total_clean_correct = 0
    total_bd_correct = 0
    # Initialised so the end-of-epoch logging also works for an empty loader.
    avg_acc_clean = 0.0
    avg_acc_bd = 0.0

    for batch_idx, (inputs, targets) in enumerate(train_dl):
        inputs, targets = inputs.to(opt.device), targets.to(opt.device)
        bs = inputs.shape[0]
        # Odd epochs train on clean data only, so no sample is a backdoor sample.
        num_bd = int(bs * opt.pc) if poison_epoch else 0

        optimizerC.zero_grad()

        if poison_epoch:
            targets_bd = torch.ones_like(targets[:num_bd]) * opt.target_label
            inputs_bd = backdoor(inputs[:num_bd], opt)
            total_inputs = transforms(torch.cat([inputs_bd, inputs[num_bd:]], dim=0))
            total_targets = torch.cat([targets_bd, targets[num_bd:]], dim=0)
        else:
            inputs_bd = inputs[:0]
            targets_bd = targets[:0]
            total_inputs = inputs
            total_targets = targets

        total_preds = netC(total_inputs)

        loss_ce = criterion_CE(total_preds, total_targets)
        if torch.isnan(total_preds).any() or torch.isnan(total_targets).any():
            print(total_preds, total_targets)
        loss_ce.backward()
        optimizerC.step()

        # The first num_bd rows of the batch are backdoor samples; the rest are clean.
        pred_labels = torch.argmax(total_preds, dim=1)
        total_clean += bs - num_bd
        total_bd += num_bd
        total_clean_correct += torch.sum(pred_labels[num_bd:] == total_targets[num_bd:]).item()
        total_bd_correct += torch.sum(pred_labels[:num_bd] == targets_bd).item()

        # Avoid 0/0 when a split is empty (odd epochs, or bs * pc < 1).
        avg_acc_clean = total_clean_correct * 100.0 / total_clean if total_clean else 0.0
        avg_acc_bd = total_bd_correct * 100.0 / total_bd if total_bd else 0.0

        # save_image cannot build a grid from an empty batch, so skip when num_bd == 0.
        if batch_idx % 50 == 0 and inputs_bd.shape[0] > 0:
            os.makedirs(opt.temps, exist_ok=True)
            path = os.path.join(opt.temps, "backdoor_image.png")
            torchvision.utils.save_image(inputs_bd, path, normalize=True)

        progress_bar(
            batch_idx,
            len(train_dl),
            "Clean Acc: {:.4f} | Bd Acc: {:.4f}".format(avg_acc_clean, avg_acc_bd),
        )

    # One TensorBoard point per epoch (the global step is the epoch index).
    tf_writer.add_scalars(
        "Clean Accuracy",
        {
            "Clean": avg_acc_clean,
            "Bd": avg_acc_bd,
        },
        epoch,
    )

    # Hard-coded switch between the two cyclic LR schedules at epoch 50.
    if epoch < 50:
        schedulerC_p1.step()
    else:
        schedulerC_p2.step()


def eval(
    netC,
    optimizerC,
    schedulerC_p1,
    schedulerC_p2,
    test_dl,
    best_clean_acc,
    best_bd_acc,
    tf_writer,
    epoch,
    opt,
):
    """Measure clean and backdoor accuracy on ``test_dl`` and checkpoint on improvement.

    Clean accuracy uses every test sample. Backdoor accuracy uses only samples
    whose true label differs from ``opt.target_label``; these are passed
    through ``backdoor`` and scored against the target label.

    A checkpoint is written to ``opt.ckpt_path`` in either of two cases:
    clean accuracy beats ``best_clean_acc``, or it is within 0.5 points of it
    and backdoor accuracy beats ``best_bd_acc``. The checkpoint holds the model,
    the optimizer and both scheduler states, the epoch index and the best
    accuracies. In both cases the "best" values become the current ones, so
    ``best_clean_acc`` can decrease slightly.

    Note: this function shadows the builtin ``eval`` within this module.

    Returns:
        Tuple ``(best_clean_acc, best_bd_acc)`` after the possible update.
        Updated values are percentages stored as Python floats.
    """
    print(" Eval:")
    netC.eval()

    total_clean_sample = 0
    total_bd_sample = 0
    total_clean_correct = 0
    total_bd_correct = 0
    # Defaults keep logging and the checkpoint test defined for an empty loader.
    acc_clean = 0.0
    acc_bd = 0.0

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_dl):
            inputs, targets = inputs.to(opt.device), targets.to(opt.device)

            # Evaluate clean
            preds_clean = netC(inputs)
            total_clean_sample += len(inputs)
            total_clean_correct += torch.sum(torch.argmax(preds_clean, 1) == targets).item()

            # Evaluate backdoor, only on samples not already in the target class.
            ntrg_ind = (targets != opt.target_label).nonzero()[:, 0]
            if len(ntrg_ind) > 0:
                inputs_bd = backdoor(inputs[ntrg_ind], opt)
                targets_bd = create_targets_bd(targets[ntrg_ind], opt)
                preds_bd = netC(inputs_bd)
                total_bd_sample += len(ntrg_ind)
                total_bd_correct += torch.sum(torch.argmax(preds_bd, 1) == targets_bd).item()

            # Plain floats (not device tensors) so they serialise portably into the checkpoint.
            acc_clean = total_clean_correct * 100.0 / total_clean_sample if total_clean_sample else 0.0
            acc_bd = total_bd_correct * 100.0 / total_bd_sample if total_bd_sample else 0.0

            info_string = "Clean Acc: {:.4f} - Best: {:.4f} | Bd Acc: {:.4f} - Best: {:.4f}".format(
                acc_clean,
                best_clean_acc,
                acc_bd,
                best_bd_acc,
            )
            progress_bar(batch_idx, len(test_dl), info_string)

    # tensorboard
    tf_writer.add_scalars(
        "Test Accuracy",
        {
            "Clean": acc_clean,
            "Bd": acc_bd,
        },
        epoch,
    )

    # Save checkpoint
    if acc_clean > best_clean_acc or (acc_clean > best_clean_acc - 0.5 and acc_bd > best_bd_acc):
        print(" Saving...")
        best_clean_acc = acc_clean
        best_bd_acc = acc_bd
        state_dict = {
            "netC": netC.state_dict(),
            "schedulerC_p1": schedulerC_p1.state_dict(),
            "schedulerC_p2": schedulerC_p2.state_dict(),
            "optimizerC": optimizerC.state_dict(),
            "epoch_current": epoch,
            "best_clean_acc": best_clean_acc,
            "best_bd_acc": best_bd_acc
        }
        torch.save(state_dict, opt.ckpt_path)
    return (
        best_clean_acc,
        best_bd_acc,
    )


def main():
    """Entry point: configure the dataset, build data and model, optionally resume, then train.

    Dataset-specific input sizes are written onto the parsed options. For
    CelebA and ImageNet-10 the worker count, class count or batch size are
    also overridden. When ``opt.continue_training`` is set, the checkpoint at
    ``opt.ckpt_path`` is loaded and training resumes at the epoch after the
    saved one. Otherwise the checkpoint folder is wiped and training starts at
    epoch 0. Each epoch runs ``train`` followed by ``eval``.

    Raises:
        ValueError: for an unsupported ``opt.dataset``.
        SystemExit: with status 1 if resuming but the checkpoint is missing.
    """
    opt = config.get_arguments().parse_args()
    if opt.dataset == "cifar10":
        opt.input_height = 32
        opt.input_width = 32
        opt.input_channel = 3
    elif opt.dataset == "celeba":
        opt.input_height = 64
        opt.input_width = 64
        opt.input_channel = 3
        opt.num_workers = 40
        opt.num_classes = 8
    elif opt.dataset == 'imagenet10':
        opt.input_height = 224
        opt.input_width = 224
        opt.input_channel = 3
        opt.num_classes = 10
        opt.bs = 64
    else:
        raise ValueError("Invalid Dataset")

    # Dataset
    train_dl = get_dataloader(opt, True)
    test_dl = get_dataloader(opt, False)
    # NOTE(review): ft_dl is built but never used below; kept in case building
    # it has side effects the pipeline relies on.
    ft_dl = get_dataloader(opt)

    print(len(train_dl.dataset), len(test_dl.dataset))

    # prepare model
    netC, optimizerC, schedulerC_p1, schedulerC_p2 = get_model(opt)

    # Checkpoint / log locations
    opt.ckpt_folder = os.path.join(opt.checkpoints, opt.saving_prefix, opt.dataset)
    opt.ckpt_path = os.path.join(opt.ckpt_folder, "{}_{}.pth.tar".format(opt.dataset, opt.saving_prefix))
    opt.log_dir = os.path.join(opt.ckpt_folder, "log_dir")
    create_dir(opt.log_dir)

    if opt.continue_training:
        if not os.path.exists(opt.ckpt_path):
            print("Pretrained model doesnt exist")
            # Non-zero status so calling scripts can detect the failure.
            raise SystemExit(1)
        print("Continue training!!")
        # map_location lets a checkpoint saved on one device load on another.
        state_dict = torch.load(opt.ckpt_path, map_location=opt.device)
        netC.load_state_dict(state_dict["netC"])
        optimizerC.load_state_dict(state_dict["optimizerC"])
        schedulerC_p1.load_state_dict(state_dict["schedulerC_p1"])
        schedulerC_p2.load_state_dict(state_dict["schedulerC_p2"])

        best_clean_acc = state_dict["best_clean_acc"]
        best_bd_acc = state_dict["best_bd_acc"]
        # The saved epoch was already trained (and its scheduler step taken)
        # before the checkpoint was written, so resume at the next one.
        epoch_current = state_dict["epoch_current"] + 1
    else:
        print("Train from scratch!!!")
        best_clean_acc = 0.0
        best_bd_acc = 0.0
        epoch_current = 0
        shutil.rmtree(opt.ckpt_folder, ignore_errors=True)
        create_dir(opt.log_dir)

    tf_writer = SummaryWriter(log_dir=opt.log_dir)

    for epoch in range(epoch_current, opt.n_iters):
        print("Epoch {}:".format(epoch + 1))
        train(
            netC,
            optimizerC,
            schedulerC_p1,
            schedulerC_p2,
            train_dl,
            tf_writer,
            epoch,
            opt,
        )
        (best_clean_acc, best_bd_acc) = eval(
            netC,
            optimizerC,
            schedulerC_p1,
            schedulerC_p2,
            test_dl,
            best_clean_acc,
            best_bd_acc,
            tf_writer,
            epoch,
            opt,
        )

# Run training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()