import os
import numpy as np
import torch.nn as nn
import random
import argparse
import deepcore.nets as nets
from deepcore.datasets.cifar10_poison import CIFAR10_BD
from deepcore.datasets.gtsrb_poison import GTSRB_BD
from deepcore.datasets.tinyimagenet_poison import TinyImageNet_BD
import deepcore.methods as methods
from torchvision import transforms
from utils import *
from datetime import datetime
from time import sleep


def _build_arg_parser():
    """Create the command-line parser for the coreset-selection / backdoor experiments.

    Returns:
        argparse.ArgumentParser: parser with every option registered. Nothing is
        parsed here; the caller invokes ``parse_args()``.
    """
    parser = argparse.ArgumentParser(description='Parameter Processing')

    # Basic arguments
    parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset')
    parser.add_argument('--model', type=str, default='ResNet18', help='model')
    parser.add_argument('--selection', type=str, default="uniform", help="selection method")
    parser.add_argument('--num_exp', type=int, default=5, help='the number of experiments')
    parser.add_argument('--num_eval', type=int, default=10, help='the number of evaluating randomly initialized models')
    parser.add_argument('--epochs', default=150, type=int, help='number of total epochs to run')
    parser.add_argument('--data_path', type=str, default='/share', help='dataset path')
    parser.add_argument('--gpu', default=None, nargs="+", type=int, help='GPU id to use')
    # Help text previously advertised a default of 20 while the real default is 100.
    parser.add_argument('--print_freq', '-p', default=100, type=int, help='print frequency (default: 100)')
    parser.add_argument('--fraction', default=0.1, type=float, help='fraction of data to be selected (default: 0.1)')
    parser.add_argument('--seed', default=34, type=int, help="random seed")
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument("--cross", type=str, nargs="+", default=None, help="models for cross-architecture experiments")

    parser.add_argument("--save_model", action="store_true", default=False, help="save model checkpoints")

    # Optimizer and scheduler
    parser.add_argument('--optimizer', default="SGD", help='optimizer to use, e.g. SGD, Adam')
    parser.add_argument('--lr', type=float, default=0.1, help='learning rate for updating network parameters')
    parser.add_argument('--min_lr', type=float, default=1e-4, help='minimum learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum (default: 0.9)')
    # Help text previously advertised 5e-4 while the real default is 1e-4.
    parser.add_argument('-wd', '--weight_decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    parser.add_argument("--nesterov", default=True, type=str_to_bool, help="if set nesterov")
    parser.add_argument("--scheduler", default="CosineAnnealingLR", type=str,
                        help="Learning rate scheduler")
    parser.add_argument("--gamma", type=float, default=.1, help="Gamma value for StepLR")
    parser.add_argument("--step_size", type=float, default=50, help="Step size for StepLR")

    # Training
    parser.add_argument('--batch', '--batch-size', "-b", default=128, type=int, metavar='N',
                        help='mini-batch size (default: 128)')
    parser.add_argument("--train_batch", "-tb", default=None, type=int,
                        help="batch size for training, if not specified, it will equal to batch size in argument --batch")
    parser.add_argument("--selection_batch", "-sb", default=None, type=int,
                        help="batch size for selection, if not specified, it will equal to batch size in argument --batch")

    # String defaults are passed through ``type`` by argparse, so 'False' becomes a bool.
    parser.add_argument("--unlearn", default='False', type=str_to_bool, help='if unlearn intermediate subset after each epoch')
    parser.add_argument("--reg_gamma", default=0.001, type=float, help='regularization factor of unlearning')
    parser.add_argument("--unlearn_smooth", default=0.9, type=float, help='factor of label-smoothing')
    parser.add_argument("--adaptsize", default='False', type=str_to_bool, help='subset size by the threshold of mean uncertainty')

    # Testing
    parser.add_argument("--test_interval", '-ti', default=1, type=int,
                        help="the number of training epochs to be preformed between two test epochs; a value of 0 means no test will be run (default: 1)")
    parser.add_argument("--test_fraction", '-tf', type=float, default=1.,
                        help="proportion of test dataset used for evaluating the model (default: 1.)")

    # Selecting
    parser.add_argument("--selection_epochs", "-se", default=40, type=int,
                        help="number of epochs whiling performing selection on full dataset")
    parser.add_argument('--selection_momentum', '-sm', default=0.9, type=float, metavar='M',
                        help='momentum whiling performing selection (default: 0.9)')
    parser.add_argument('--selection_weight_decay', '-swd', default=5e-4, type=float,
                        metavar='W', help='weight decay whiling performing selection (default: 5e-4)',
                        dest='selection_weight_decay')
    parser.add_argument('--selection_optimizer', "-so", default="SGD",
                        help='optimizer to use whiling performing selection, e.g. SGD, Adam')
    parser.add_argument("--selection_nesterov", "-sn", default=True, type=str_to_bool,
                        help="if set nesterov whiling performing selection")
    parser.add_argument('--selection_lr', '-slr', type=float, default=0.1, help='learning rate for selection')
    parser.add_argument("--selection_test_interval", '-sti', default=1, type=int,
                        help="the number of training epochs to be preformed between two test epochs during selection (default: 1)")
    parser.add_argument("--selection_test_fraction", '-stf', type=float, default=1.,
                        help="proportion of test dataset used for evaluating the model while preforming selection (default: 1.)")
    parser.add_argument('--balance', default=True, type=str_to_bool,
                        help="whether balance selection is performed per class")

    parser.add_argument("--warmup_epochs", "-warmup", default=0, type=int,
                        help="number of epochs for warmup before performing selection on full dataset")

    parser.add_argument("--repeats", "-re", default=5, type=int,
                        help="number of cycles of performing selection on the full dataset")

    parser.add_argument('--accumulative', default=True, type=str_to_bool, help='use accumulative uncertainty')
    parser.add_argument('--prefix_coresize', default=True, type=str_to_bool, help='use coreset size based on warmup')
    parser.add_argument('--correct_only', default=False, type=str_to_bool, help='adaptive size with threshold = average uncertainty of correctly classified samples')
    parser.add_argument("--selection_dataaug", "-saug", default=False, type=str_to_bool, help="use data augmentation in selection")
    parser.add_argument('--eval_purif', "-evlpur", default=False, type=str_to_bool, help='Evaluate on poisoned set with clean labeling')

    parser.add_argument('--save_coreset', default=False, type=str_to_bool, help='save coreset indices')

    # Algorithm
    parser.add_argument('--submodular', default="FacilityLocation", help="specifiy submodular function to use", choices=["FacilityLocation", "GraphCat"])
    parser.add_argument('--submodular_greedy', default="LazyGreedy", help="specifiy greedy algorithm for submodular optimization")
    parser.add_argument('--uncertainty', default="Entropy", help="specifiy uncertanty score to use")
    parser.add_argument('--uncertainty_temperature', '-ut', default=1, type=float,
                        help='temperature used for smoothing prediction for uncertainty measurement')
    parser.add_argument('--norm_score', '-ns', default=False, type=str_to_bool, help='Normalize scores of each cycle')

    # Checkpoint and resumption
    parser.add_argument('--save_path', "-sp", type=str, default='', help='path to save results (default: do not save)')
    parser.add_argument('--resume', '-r', type=str, default='', help="path to latest checkpoint (default: do not load)")

    # Backdoor
    parser.add_argument('--inject_portion', type=float, default=0.1, help='ratio of backdoor samples')
    parser.add_argument('--target_label', type=int, default=0, help='class of target label')
    parser.add_argument('--trigger_type', type=str, default='gridTrigger', help='type of backdoor trigger')
    parser.add_argument('--target_type', type=str, default='all2one', help='type of backdoor label')
    parser.add_argument('--trig_w', type=int, default=3, help='width of trigger pattern')
    parser.add_argument('--trig_h', type=int, default=3, help='height of trigger pattern')

    # Options for the dynamic attack
    parser.add_argument("--temps", type=str, default="./temps")

    parser.add_argument("--input_height", type=int, default=32)
    parser.add_argument("--input_width", type=int, default=32)
    parser.add_argument("--input_channel", type=int, default=3)
    parser.add_argument("--num_classes", type=int, default=10)

    parser.add_argument("--batchsize", type=int, default=128)
    parser.add_argument("--lr_G", type=float, default=1e-2)
    parser.add_argument("--lr_C", type=float, default=1e-2)
    parser.add_argument("--lr_M", type=float, default=1e-2)
    # NOTE(review): type=list turns a command-line string into a list of single
    # characters, so only the defaults below are usable as-is.
    parser.add_argument("--schedulerG_milestones", type=list, default=[200, 300, 400, 500])
    parser.add_argument("--schedulerC_milestones", type=list, default=[100, 200, 300, 400])
    parser.add_argument("--schedulerM_milestones", type=list, default=[10, 20])
    parser.add_argument("--schedulerG_lambda", type=float, default=0.1)
    parser.add_argument("--schedulerC_lambda", type=float, default=0.1)
    parser.add_argument("--schedulerM_lambda", type=float, default=0.1)
    parser.add_argument("--n_iters", type=int, default=600)
    parser.add_argument("--lambda_div", type=float, default=1)
    parser.add_argument("--lambda_norm", type=float, default=100)
    parser.add_argument("--num_workers", type=float, default=4)

    parser.add_argument("--attack_mode", type=str, default="all2one", help="all2one or all2all")
    parser.add_argument("--p_attack", type=float, default=0.1)
    parser.add_argument("--p_cross", type=float, default=0.1)
    parser.add_argument("--mask_density", type=float, default=0.032)
    parser.add_argument("--EPSILON", type=float, default=1e-7)

    parser.add_argument("--random_rotation", type=int, default=10)
    parser.add_argument("--random_crop", type=int, default=5)

    return parser


def _load_resume_checkpoint(args):
    """Load ``args.resume`` (if set) and decide where the run restarts.

    Returns:
        tuple: ``(checkpoint, start_exp, start_epoch)``. ``checkpoint`` is an empty
        dict when nothing is resumed or when the file lacks the required keys.

    Explicit membership checks are used instead of ``assert`` so the validation
    still happens when Python runs with ``-O``.
    """
    if args.resume == "":
        return {}, 0, 0

    print("=> Loading checkpoint '{}'".format(args.resume))
    checkpoint = torch.load(args.resume, map_location=args.device)
    present = set(checkpoint.keys())

    # The selected subset is the minimum needed to resume anything.
    has_subset = {"exp", "subset", "sel_args"} <= present and 'indices' in checkpoint["subset"].keys()
    if not has_subset:
        print("=> Failed to load the checkpoint, an empty one will be created")
        return {}, 0, 0

    if {"epoch", "state_dict", "opt_dict", "best_acc1", "rec"} <= present:
        return checkpoint, checkpoint['exp'], checkpoint["epoch"]

    print("=> The checkpoint only contains the subset, training will start from the begining")
    return checkpoint, checkpoint['exp'], 0


def _apply_train_augmentation(dst_train, args):
    """Wrap ``dst_train.transform`` with the training-time augmentation for the dataset.

    When ``args.selection_dataaug`` is set the transform is left untouched.

    Raises:
        NameError: if no augmentation is defined for ``args.dataset``.
    """
    if args.dataset not in ("CIFAR10", "GTSRB", "TinyImageNet"):
        raise NameError(f"transforms for data augmentation is not defined yet for {args.dataset}")
    if args.selection_dataaug:
        return

    if args.dataset == "CIFAR10":
        extra = [
            transforms.RandomHorizontalFlip(0.5),
            transforms.RandomCrop(args.im_size, padding=4, padding_mode="reflect"),
        ]
    elif args.dataset == "GTSRB":
        extra = []
    else:  # TinyImageNet
        extra = [
            transforms.RandomCrop(args.im_size),
            transforms.RandomHorizontalFlip(0.5),
        ]
    dst_train.transform = transforms.Compose(extra + [dst_train.transform])


def main():
    """Run the coreset-selection and training experiments configured on the command line.

    For each experiment index the poisoned dataset is built, a subset is selected
    (or restored from a checkpoint), and every requested architecture is trained
    on that subset while being evaluated on the clean and poisoned test sets.
    A table with best accuracy, ASR, poison rate of the subset and subset size
    is printed at the end, one row per (experiment, model) pair.
    """
    args = _build_arg_parser().parse_args()
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print(args)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    if args.train_batch is None:
        args.train_batch = args.batch
    if args.selection_batch is None:
        args.selection_batch = args.batch
    if args.save_path != "" and not os.path.exists(args.save_path):
        # makedirs (rather than mkdir) so a nested save path can be created.
        os.makedirs(args.save_path, exist_ok=True)

    checkpoint, start_exp, start_epoch = _load_resume_checkpoint(args)

    # One entry per trained (experiment, model) pair; the four lists stay aligned.
    acc_list = []
    asr_list = []
    prate_list = []
    split_list = []

    dataset_builders = {
        "CIFAR10": CIFAR10_BD,
        "GTSRB": GTSRB_BD,
        "TinyImageNet": TinyImageNet_BD,
    }

    for exp in range(start_exp, args.num_exp):
        exp_start = datetime.now()

        if args.save_path != "":
            # Use the current experiment index (not start_exp) so each run is labelled correctly.
            checkpoint_name = "{dst}_{net}_{mtd}_exp{exp}_epoch{epc}_{dat}_{fr}_".format(dst=args.dataset,
                                                                                         net=args.model,
                                                                                         mtd=args.selection,
                                                                                         dat=datetime.now(),
                                                                                         exp=exp,
                                                                                         epc=args.epochs,
                                                                                         fr=args.fraction)

        print('\n============================== Exp %d ==============================\n' % exp)
        print("dataset: ", args.dataset, ", model: ", args.model, ", selection: ", args.selection, ", num_ex: ",
              args.num_exp, ", epochs: ", args.epochs, ", fraction: ", args.fraction, ", seed: ", args.seed,
              ", lr: ", args.lr, ", save_path: ", args.save_path, ", resume: ", args.resume, ", device: ", args.device,
              ", checkpoint_name: " + checkpoint_name if args.save_path != "" else "", "\n", sep="")

        if args.dataset not in dataset_builders:
            raise NameError(f"Dataset {args.dataset} is not defined yet.")
        (channel, im_size, num_classes, class_names, mean, std,
         dst_train, dst_test, dst_test_bad, dst_test_purif) = dataset_builders[args.dataset](args)
        args.channel, args.im_size, args.num_classes, args.class_names = channel, im_size, num_classes, class_names

        np.random.seed(args.seed)
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.random.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

        if "subset" in checkpoint.keys():
            subset = checkpoint['subset']
            selection_args = checkpoint["sel_args"]
        else:
            selection_args = dict(epochs=args.selection_epochs,
                                  selection_method=args.uncertainty,
                                  balance=args.balance,
                                  greedy=args.submodular_greedy,
                                  function=args.submodular,
                                  dst_test=dst_test,
                                  dst_test_bad=dst_test_bad
                                  )
            method = methods.__dict__[args.selection](dst_train, args, mean, std, args.fraction, args.seed, **selection_args)
            subset = method.select()

        subset_rate = len(subset['indices']) / len(dst_train)

        # Count per-class membership and how many selected samples are poisoned / noisy.
        num_subpoison = 0
        num_subnoise = 0
        cls_cnt = np.zeros(dst_train.num_classes)
        for idx in subset["indices"]:
            target = dst_train.targets[idx]
            cls_cnt[target] += 1
            if idx in dst_train.poison_ids:
                num_subpoison += 1
            if idx in dst_train.noise_ids:
                num_subnoise += 1
        prate_subset = num_subpoison/len(subset["indices"])

        print(f"Subset: class distribution {list(cls_cnt.astype('int'))}")
        print(f'Subset: {len(subset["indices"])} (p_num={num_subpoison}, p_rate={prate_subset:.4f}, n_num={num_subnoise})')

        exp_end_subset = datetime.now()
        print(f"\n=================================================")
        print(f" Subset selection duration time: {exp_end_subset - exp_start}")
        print(f"=================================================\n")

        if args.save_coreset:
            selected_indices = subset["indices"]
            subset_perc = int(len(subset["indices"])/len(dst_train.targets)*100)
            if not os.path.exists("./saved_coresets"):
                os.makedirs("./saved_coresets")
            np.save(f"./saved_coresets/{args.dataset}_top{subset_perc}.npy", selected_indices)

        # Augmentation for Datasets
        _apply_train_augmentation(dst_train, args)

        print("Train sampled subset with transforms: ", dst_train.transform)

        # Handle weighted subset
        if_weighted = "weights" in subset.keys()
        if if_weighted:
            dst_subset = WeightedSubset(dst_train, subset["indices"], subset["weights"])
        else:
            dst_subset = torch.utils.data.Subset(dst_train, subset["indices"])

        # BackgroundGenerator for ImageNet to speed up dataloaders
        if args.dataset == "ImageNet":
            # NOTE(review): unreachable as written, because the dataset dispatch above
            # raises for anything other than CIFAR10/GTSRB/TinyImageNet. It would also
            # leave test_bad_loader/test_purif_loader undefined for the evaluation below.
            train_loader = DataLoaderX(dst_subset, batch_size=args.train_batch, shuffle=True,
                                       num_workers=args.workers, pin_memory=True)
            test_loader = DataLoaderX(dst_test, batch_size=args.train_batch, shuffle=False,
                                      num_workers=args.workers, pin_memory=True)
        else:
            train_loader = torch.utils.data.DataLoader(dst_subset, batch_size=args.train_batch, shuffle=True,
                                                       num_workers=args.workers, pin_memory=True)
            test_loader = torch.utils.data.DataLoader(dst_test, batch_size=args.train_batch, shuffle=False,
                                                      num_workers=args.workers, pin_memory=True)
            test_bad_loader = torch.utils.data.DataLoader(dst_test_bad, batch_size=args.train_batch, shuffle=False,
                                                          num_workers=args.workers, pin_memory=True)
            test_purif_loader = torch.utils.data.DataLoader(dst_test_purif, batch_size=args.train_batch, shuffle=False,
                                                            num_workers=args.workers, pin_memory=True)

        # Listing cross-architecture experiment settings if specified.
        models = [args.model]
        if isinstance(args.cross, list):
            for model in args.cross:
                if model != args.model:
                    models.append(model)

        for model in models:
            if len(models) > 1:
                print("| Training on model %s" % model)

            # Checkpoints of cross-architecture models carry the model name in the file name.
            model_tag = "" if model == args.model else model + "_"

            network = nets.__dict__[model](channel, num_classes, im_size, mean=mean, std=std).to(args.device)

            if args.device == "cpu":
                print("Using CPU.")
            elif args.gpu is not None:
                torch.cuda.set_device(args.gpu[0])
                network = nets.nets_utils.MyDataParallel(network, device_ids=args.gpu)
            elif torch.cuda.device_count() > 1:
                network = nets.nets_utils.MyDataParallel(network).cuda()

            if "state_dict" in checkpoint.keys():
                # Loading model state_dict
                network.load_state_dict(checkpoint["state_dict"])

            criterion = nn.CrossEntropyLoss(reduction='none').to(args.device)

            # Optimizer
            if args.optimizer == "SGD":
                optimizer = torch.optim.SGD(network.parameters(), args.lr, momentum=args.momentum,
                                            weight_decay=args.weight_decay, nesterov=args.nesterov)
            elif args.optimizer == "Adam":
                optimizer = torch.optim.Adam(network.parameters(), args.lr, weight_decay=args.weight_decay)
            else:
                optimizer = torch.optim.__dict__[args.optimizer](network.parameters(), args.lr, momentum=args.momentum,
                                                                 weight_decay=args.weight_decay, nesterov=args.nesterov)

            # LR scheduler. Horizons are expressed in iterations (len(train_loader) * epochs);
            # presumably train() steps the scheduler once per batch (confirm in utils).
            if args.scheduler == "CosineAnnealingLR":
                scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(train_loader) * args.epochs,
                                                                       eta_min=args.min_lr)
            elif args.scheduler == "MultiStepLR":
                # MultiStepLR compares milestones with its integer step counter, so a
                # fractional milestone would never trigger a decay; truncate to int.
                steps_base = int(args.epochs*0.5)*int(len(subset["indices"])/args.train_batch)
                milestones = [
                    int(steps_base*0.5*0.5),
                    int(steps_base*0.5*0.75)
                ]
                scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=args.gamma)
            elif args.scheduler == "StepLR":
                scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=len(train_loader) * args.step_size,
                                                            gamma=args.gamma)
            else:
                scheduler = torch.optim.lr_scheduler.__dict__[args.scheduler](optimizer)
            scheduler.last_epoch = (start_epoch - 1) * len(train_loader)

            if "opt_dict" in checkpoint.keys():
                optimizer.load_state_dict(checkpoint["opt_dict"])

            # Log recorder
            if "rec" in checkpoint.keys():
                rec = checkpoint["rec"]
            else:
                rec = init_recorder()

            best_prec1 = checkpoint["best_acc1"] if "best_acc1" in checkpoint.keys() else 0.0
            # Reset per model so the report below never reads an unbound name (e.g. with
            # test_interval == 0) or a value left over from the previous model.
            asr_best = 0.0

            # Save the checkpoint with only the subset.
            if args.save_path != "" and args.resume == "" and args.save_model:
                save_checkpoint({"exp": exp,
                                 "subset": subset,
                                 "sel_args": selection_args},
                                os.path.join(args.save_path, checkpoint_name + model_tag + "unknown.ckpt"), 0, 0.)

            for epoch in range(start_epoch, args.epochs):
                # train for one epoch
                print(f"\nLearning rate = {optimizer.param_groups[0]['lr']:.6f}")

                train(train_loader, network, criterion, optimizer, scheduler, epoch, args, rec, if_weighted=if_weighted)

                # evaluate on validation set
                if args.test_interval > 0 and (epoch + 1) % args.test_interval == 0:

                    print(">> Evaluation on clean dataset:")
                    prec1 = test(test_loader, network, criterion, epoch, args, rec, print_clsacc=True)

                    print(">> Evaluation on poisoned dataset:")
                    prec1_bad = test(test_bad_loader, network, criterion, epoch, args, rec)

                    if args.eval_purif:
                        print(">> Evaluation on the purification robustness of poisoned dataset:")
                        test(test_purif_loader, network, criterion, epoch, args, rec)

                    # remember best prec@1 (and the ASR measured at that epoch) and save checkpoint
                    if prec1 > best_prec1:
                        best_prec1 = prec1
                        asr_best = prec1_bad
                        if args.save_path != "" and args.save_model:
                            rec = record_ckpt(rec, epoch)
                            save_checkpoint({"exp": exp,
                                             "epoch": epoch + 1,
                                             "state_dict": network.state_dict(),
                                             "opt_dict": optimizer.state_dict(),
                                             "best_acc1": best_prec1,
                                             "rec": rec,
                                             "subset": subset,
                                             "sel_args": selection_args},
                                            os.path.join(args.save_path, checkpoint_name + model_tag + "unknown.ckpt"),
                                            epoch=epoch, prec=best_prec1)

            # Prepare for the next checkpoint
            if args.save_path != "" and args.save_model:
                final_path = os.path.join(args.save_path, checkpoint_name + model_tag + "%f.ckpt" % best_prec1)
                try:
                    os.rename(os.path.join(args.save_path, checkpoint_name + model_tag + "unknown.ckpt"),
                              final_path)
                except OSError:
                    # No "unknown" checkpoint to rename (or the rename failed): write the
                    # final state directly instead.
                    save_checkpoint({"exp": exp,
                                     "epoch": args.epochs,
                                     "state_dict": network.state_dict(),
                                     "opt_dict": optimizer.state_dict(),
                                     "best_acc1": best_prec1,
                                     "rec": rec,
                                     "subset": subset,
                                     "sel_args": selection_args},
                                    final_path,
                                    epoch=args.epochs - 1,
                                    prec=best_prec1)

            print('| Best accuracy: ', best_prec1, " (ASR: ", round(asr_best, 2), ")", ", on model " + model if len(models) > 1 else "", end="\n\n")
            acc_list.append(best_prec1)
            asr_list.append(asr_best)
            # Recorded per model (not per experiment) so the summary rows line up when
            # cross-architecture models are trained on the same subset.
            prate_list.append(prate_subset)
            split_list.append(subset_rate)

            # Resumed state only applies to the first model of the first resumed experiment.
            start_epoch = 0
            checkpoint = {}

            exp_end = datetime.now()
            print(f"\n=================================================")
            print(f" Total duration: {exp_end - exp_start}")
            print(f"=================================================")

            sleep(2)

        # New seed for the next experiment; randint's upper bound is exclusive, so it lies in [1, 99].
        args.seed = np.random.randint(1, 100)

    print("\nSummary of all experiments in [%]:")
    print(f"ACC\tASR\tPRate\tSubRate")
    for acc, asr, prate, split in zip(acc_list, asr_list, prate_list, split_list):
        print(f"{acc:.2f}\t{asr:.2f}\t{prate*100.0:.2f}\t{split*100.0:.2f}")


# Script entry point: run the experiments only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
