import argparse
import logging
import os
import pdb
import wandb

from datasets import *
from imbalanced_datasets import imb_cifar10_dataloader, imb_svhn_dataloader, imb_stl10_dataloader
from model_zoo import *
from trainer import BatTrainer
from utils.general_utils import write_csv_rows, setup_seed, print_network
from utils.lamb import Lamb
from utils.loading_bar import Log
from utils.math_utils import smooth_crossentropy, dlr_loss
from utils.step_lr import *

# Order CUDA devices by PCI bus id so CUDA_VISIBLE_DEVICES indices match nvidia-smi.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"


def build_parser():
    """Return the command-line parser for (imbalanced) adversarial training.

    ``--attack_eps`` and ``--attack_lr`` are given on the 0-255 pixel scale;
    ``configure_attack_step_sizes`` converts them to the 0-1 scale.
    """
    parser = argparse.ArgumentParser("Adversarial Training")

    ########################## basic setting ##########################
    parser.add_argument('--device', default="cuda:0", help="The name of the device you want to use (default: cuda:0)")
    parser.add_argument('--time_stamp', default="",
                        help="The time stamp that helps identify different trails.")
    parser.add_argument('--dataset', default="CIFAR10",
                        choices=["CIFAR10", "SVHN", "GTSRB", "STL10"])
    parser.add_argument('--dataset_val_ratio', default=0., type=float)
    parser.add_argument('--mode', default='ours', type=str,
                        choices=["ours", "ours++", "fast_at", "fast_bat", "fast_at_ga", "pgd"],
                        help="fast_at: pgd-at, fast_bat: bi-level at with kkt, fast_at_ga: gradient alignment")
    parser.add_argument('--data_dir', default='./data/', type=str, help="The folder where you store your dataset")
    parser.add_argument('--model_prefix', default='checkpoints/',
                        help='File folders where you want to store your checkpoints (default: results/checkpoints/)')
    parser.add_argument('--csv_prefix', default='accuracy/',
                        help='File folders where you want to put your results (default: results/accuracy/)')
    parser.add_argument('--random_seed', default=2023, type=int,
                        help='Random seed (default: 2023)')
    parser.add_argument('--pretrained_model', default=None, help="The path of pretrained model")
    parser.add_argument('--pretrained_epochs', default=0, type=int,
                        help="Epochs already completed by --pretrained_model; training resumes at this epoch.")

    ########################## training setting ##########################
    parser.add_argument("--batch_size", default=128, type=int,
                        help="Batch size used in the training and validation loop.")
    parser.add_argument("--epochs", default=30, type=int, help="Total number of epochs. Default=30")
    parser.add_argument("--threads", default=2, type=int, help="Number of CPU threads for dataloaders.")
    parser.add_argument("--optimizer", default="SGD", choices=['SGD', 'Adam', 'Lamb'])
    parser.add_argument("--momentum", default=0.9, type=float, help="SGD Momentum.")
    parser.add_argument("--weight_decay", default=0.0005, type=float, help="L2 weight decay.")
    parser.add_argument("--dropout", default=0.1, type=float, help="Dropout rate.")

    ########################## learning scheduler ##########################
    parser.add_argument('--lr_scheduler', default='cyclic',
                        choices=['cyclic', 'multistep'])
    parser.add_argument("--key_epochs", nargs="+", type=int, default=[100, 150],
                        help="Epochs where learning rate decays, this is for multi-step scheduler only.")
    parser.add_argument("--lr_decay_rate", default=0.1, type=float, help="This is for multi-step scheduler only.")
    parser.add_argument("--cyclic_milestone", default=10, type=int,
                        help="Key epoch for cyclic scheduler. This is for cyclic scheduler only.")
    parser.add_argument('--lr_min', default=0.0002, type=float,
                        help="Min lr for cyclic scheduler. This is for cyclic scheduler only. Default=0.0002")
    parser.add_argument('--lr_max', default=0.2, type=float,
                        help="Max lr for cyclic scheduler. This is for cyclic scheduler only. Default=0.2")
    parser.add_argument('--lr_initial', default=0.01, type=float,
                        help="Initial lr for multi-step scheduler. This is for multi-step scheduler only. Default=0.01")

    ########################## model setting ##########################
    parser.add_argument('--train_loss', default="ce", choices=["ce", "sce", "n_dlr"],
                        help="ce for cross entropy, sce for label-smoothed ce, n_dlr for negative dlr loss")
    parser.add_argument('--act_fn', default="relu", choices=["relu", "softplus", "swish"],
                        help="choose the activation function for your model")
    parser.add_argument("--model_type", default="ResNet", choices=['ResNet', 'PreActResNet', 'WideResNet'])
    parser.add_argument("--width_factor", default=0, type=int, help="Parameter for WideResNet only.")
    parser.add_argument("--depth", default=18, type=int, help="Parameter for all model types.")

    ########################## attack setting ##########################
    parser.add_argument('--attack_step', default=10, type=int,
                        help='attack steps for training (default: 10)')
    parser.add_argument('--attack_step_test', default=20, type=int,
                        help='attack steps for evaluation (default: 20)')
    parser.add_argument('--attack_eps', default=8, type=float,
                        help='attack constraint for training (default: 8/255)')
    parser.add_argument('--attack_rs', default=1, type=int,
                        help='attack restart number')
    parser.add_argument('--attack_lr', default=2., type=float,
                        help='attack learning rate on the 0-255 scale (default: 2./255). Only used for training with '
                             '--mode pgd when --attack_step is neither 2 nor 10; every other setting derives its own '
                             'step size. The attack lr is always set to attack_eps / 4 when evaluating.')
    parser.add_argument('--attack_rs_test', default=1, type=int,
                        help='attack restart number for evaluation')

    ############################### others ###################################
    parser.add_argument('--kl_coef', default=10.0, type=float,
                        help="KL regularizer coefficient; decayed at --klc_milestone1 and --klc_milestone2.")
    parser.add_argument('--mu', default=0.01, type=float, help="Log barrier function coefficient.")
    parser.add_argument('--ir', default=1.0, type=float, help="Imbalance ratio passed to the CIFAR10/SVHN loaders.")
    parser.add_argument('--klc_milestone1', default=15, type=int, help="First epoch number to decay kl coef.")
    parser.add_argument('--klc_milestone2', default=20, type=int, help='Second epoch number to decay kl coef.')
    parser.add_argument('--wandb', type=str, default='AT_prelim',
                        help="Wandb project name prefix. Pass '' or 'none' to disable wandb logging")
    parser.add_argument('--wandb_project', type=str, default=None, help='Wandb project name. If None, will be constructed from relevant paramters')
    parser.add_argument('--run_name', type=str, default=None, help='Wandb run name. If None, name will be random')
    parser.add_argument('--wandb_key', type=str, default=None, help='Wandb token key for login. If None, shell login assumed')

    ############################### fast-bat options ###################################
    parser.add_argument('--lmbda', default=10.0, type=float, help="The parameter lambda for Fast-BAT.")

    ############################### grad alignment ##################################
    parser.add_argument('--ga_coef', default=0.0, type=float,
                        help="coefficient of the cosine gradient alignment regularizer")
    return parser


def resolve_device(requested):
    """Map the ``--device`` request to the device string actually used.

    Any non-CPU request becomes ``"cuda:0"`` when CUDA is available (select the
    physical GPU with CUDA_VISIBLE_DEVICES). A ``cuda*`` request on a machine
    without CUDA falls back to ``"cpu"`` instead of failing later in ``.to()``.
    """
    if requested == "cpu":
        return "cpu"
    if torch.cuda.is_available():
        return "cuda:0"
    if requested.startswith("cuda"):
        # print, not logging: logging.basicConfig() has not been called yet and
        # an early logging call would install a default handler and disable it.
        print(f"CUDA is not available; falling back to CPU instead of '{requested}'.")
        return "cpu"
    return requested


def normalize_wandb_flag(args):
    """Normalise ``args.wandb`` so one ``is not None`` check decides wandb usage.

    An empty or ``none`` (any case) value is turned into ``None``.

    Returns:
        True if wandb logging is enabled.
    """
    if args.wandb is not None and args.wandb.strip().lower() in ("", "none"):
        args.wandb = None
    return args.wandb is not None


def prepare_output_dirs(args):
    """Create the result, log, checkpoint and CSV directories if missing.

    Returns:
        ``(log_path, model_dir, csv_dir)``.
    """
    result_path = "./results/"
    log_path = "./log/"
    model_dir = os.path.join(result_path, args.model_prefix)
    csv_dir = os.path.join(result_path, args.csv_prefix)
    for directory in (result_path, log_path, model_dir, csv_dir):
        # makedirs also handles nested prefixes such as "exp1/checkpoints/".
        os.makedirs(directory, exist_ok=True)
    return log_path, model_dir, csv_dir


def configure_attack_step_sizes(args):
    """Set ``args.attack_lr`` for the chosen mode and rescale both values to [0, 1] units.

    On entry ``args.attack_eps`` (and ``args.attack_lr``) are on the 0-255
    scale. Except for ``pgd`` with a step count other than 2 or 10, the
    user-supplied ``--attack_lr`` is replaced by a mode-specific value.
    Mutates ``args`` in place and returns it.
    """
    eps = args.attack_eps  # still on the 0-255 scale here
    if args.mode in ("fast_at", "fast_at_ga"):
        args.attack_lr = eps * 1.25 / 255
    elif args.mode == "pgd":
        if args.attack_step == 2:
            args.attack_lr = eps * 0.5 / 255
        elif args.attack_step == 10:
            args.attack_lr = 2.0 / 255
        else:
            args.attack_lr = args.attack_lr / 255
    elif args.mode == "fast_bat":
        # 5000 is good for 2 attack steps during training, 2000 is good for 10 attack steps.
        args.attack_lr = (5000 if eps <= 8 else 2000) / 255
    else:  # "ours" / "ours++"
        args.attack_lr = (2. if args.attack_step >= 8 else eps * 1.25) / 255
    args.attack_eps = eps / 255
    return args


def kl_coef_for_epoch(epoch, kl_coef, milestone1, milestone2, num_classes):
    """Return the KL coefficient to use at ``epoch``.

    At ``milestone1`` the coefficient drops to ``kl_coef / 5`` (``/ 1`` when
    ``num_classes > 10``); at ``milestone2`` it drops to ``kl_coef / 100``
    (``/ 2``). The value is computed from the epoch rather than updated on
    equality, so a run resumed past a milestone still gets the decayed value.
    Negative milestones are never reached.
    """
    after_first = kl_coef / 5. if num_classes <= 10 else kl_coef / 1.
    after_second = kl_coef / 100. if num_classes <= 10 else kl_coef / 2.
    coef = kl_coef
    # Apply milestones chronologically so the most recent one wins (on a tie,
    # milestone2 wins, as in the per-epoch updates this replaces).
    for milestone, value in sorted(((milestone1, after_first), (milestone2, after_second)),
                                   key=lambda item: item[0]):
        if 0 <= milestone <= epoch:
            coef = value
    return coef


def scheduler_is_active(epoch, args):
    """Return whether the LR scheduler is handed to the trainer in ``epoch``.

    The cyclic schedule is only used for its first full cycle
    (``2 * cyclic_milestone`` epochs); the multistep schedule is always used.
    """
    return epoch < 2 * args.cyclic_milestone or args.lr_scheduler == "multistep"


def should_evaluate(epoch, args):
    """Return whether to evaluate after ``epoch``: every 5th epoch, and every epoch after the first cycle."""
    return (epoch + 1) % 5 == 0 or epoch >= 2 * args.cyclic_milestone


def fast_forward_scheduler(optimizer, scheduler, args, batches_per_epoch):
    """Replay the per-batch scheduler steps for the ``args.pretrained_epochs`` already trained.

    Only epochs for which ``scheduler_is_active`` holds are replayed, which
    mirrors when training hands the scheduler to the trainer. Without this, a
    resumed cyclic run is advanced past the point where training stops
    stepping it.
    """
    for epoch in range(args.pretrained_epochs):
        if not scheduler_is_active(epoch, args):
            continue
        for _ in range(batches_per_epoch):
            # optimizer.step() precedes scheduler.step() as PyTorch expects;
            # presumably a no-op since no backward pass has produced gradients.
            optimizer.step()
            scheduler.step()


def build_dataloaders(args):
    """Return ``(train_dl, val_dl, test_dl, norm_layer, num_classes)`` for ``args.dataset``.

    NOTE(review): only CIFAR10 and SVHN use the imbalanced loaders (``--ir``).
    GTSRB and STL10 ignore ``--ir``, and STL10 also ignores
    ``--dataset_val_ratio`` — confirm this is intended.
    """
    if args.dataset == "CIFAR10":
        return imb_cifar10_dataloader(data_dir=args.data_dir,
                                      batch_size=args.batch_size,
                                      val_ratio=args.dataset_val_ratio,
                                      imratio=args.ir)
    if args.dataset == "SVHN":
        return imb_svhn_dataloader(data_dir=args.data_dir,
                                   batch_size=args.batch_size,
                                   val_ratio=args.dataset_val_ratio,
                                   imratio=args.ir)
    if args.dataset == "GTSRB":
        return gtsrb_dataloader(data_dir=args.data_dir,
                                batch_size=args.batch_size,
                                val_ratio=args.dataset_val_ratio)
    if args.dataset == "STL10":
        return stl10_dataloader(data_dir=args.data_dir,
                                batch_size=args.batch_size)
    raise NotImplementedError("Invalid Dataset")


def build_model(args, num_classes):
    """Instantiate the network selected by ``--act_fn``, ``--model_type`` and ``--depth``.

    ``--width_factor`` is not passed to any constructor; it is only logged.
    For (PreAct)ResNet, any depth other than 18 or 34 yields the 50-layer model.
    """
    if args.act_fn == "relu":
        activation_fn = nn.ReLU
    elif args.act_fn == "softplus":
        activation_fn = nn.Softplus
    elif args.act_fn == "swish":
        activation_fn = Swish
    else:
        raise NotImplementedError("Unsupported activation function!")

    if args.model_type == "WideResNet":
        if args.depth == 16:
            return WRN_16_8(num_classes=num_classes, dropout=args.dropout, activation_fn=activation_fn)
        if args.depth == 28:
            return WRN_28_10(num_classes=num_classes, dropout=args.dropout, activation_fn=activation_fn)
        if args.depth == 34:
            return WRN_34_10(num_classes=num_classes, dropout=args.dropout, activation_fn=activation_fn)
        if args.depth == 70:
            return WRN_70_16(num_classes=num_classes, dropout=args.dropout, activation_fn=activation_fn)
        raise NotImplementedError("Unsupported WideResNet!")
    if args.model_type == "PreActResNet":
        if args.depth == 18:
            return PreActResNet18(num_classes=num_classes, activation_fn=activation_fn)
        if args.depth == 34:
            return PreActResNet34(num_classes=num_classes, activation_fn=activation_fn)
        return PreActResNet50(num_classes=num_classes, activation_fn=activation_fn)
    if args.model_type == "ResNet":
        if args.depth == 18:
            return ResNet18(num_classes=num_classes, activation_fn=activation_fn)
        if args.depth == 34:
            return ResNet34(num_classes=num_classes, activation_fn=activation_fn)
        return ResNet50(num_classes=num_classes, activation_fn=activation_fn)
    raise NotImplementedError("Unsupported Model Type!")


def build_optimizer(args, model):
    """Create the optimizer named by ``--optimizer`` with ``lr=--lr_initial``.

    NOTE(review): with the cyclic scheduler this lr is presumably overwritten
    by CyclicLR's ``base_lr`` — confirm against the PyTorch version in use.
    """
    if args.optimizer == "Adam":
        return torch.optim.Adam(model.parameters(), lr=args.lr_initial,
                                weight_decay=args.weight_decay)
    if args.optimizer == "SGD":
        return torch.optim.SGD(model.parameters(), lr=args.lr_initial,
                               weight_decay=args.weight_decay,
                               momentum=args.momentum)
    if args.optimizer == "Lamb":
        return Lamb(model.parameters(), lr=args.lr_initial,
                    weight_decay=args.weight_decay,
                    betas=(0.9, 0.999))
    raise NotImplementedError("Unsupported optimizer!")


def build_scheduler(args, optimizer, batches_per_epoch):
    """Create a per-batch LR scheduler.

    The cyclic schedule rises and falls over ``cyclic_milestone`` epochs each.
    The multistep schedule decays by ``lr_decay_rate`` at each epoch in
    ``key_epochs``.
    """
    if args.lr_scheduler == "cyclic":
        half_cycle = int(args.cyclic_milestone * batches_per_epoch)
        return torch.optim.lr_scheduler.CyclicLR(optimizer,
                                                 base_lr=args.lr_min,
                                                 max_lr=args.lr_max,
                                                 step_size_up=half_cycle,
                                                 step_size_down=half_cycle)
    if args.lr_scheduler == "multistep":
        return torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                    milestones=[batches_per_epoch * e for e in args.key_epochs],
                                                    gamma=args.lr_decay_rate)
    raise NotImplementedError("Unsupported Scheduler!")


def n_dlr(predictions, labels):
    """Negative DLR loss (module level so it stays picklable)."""
    return -dlr_loss(predictions, labels)


def build_train_loss(name):
    """Return the training loss callable for ``--train_loss`` (``ce``, ``sce`` or ``n_dlr``)."""
    if name == "sce":
        return smooth_crossentropy
    if name == "ce":
        return torch.nn.CrossEntropyLoss(reduction="sum")
    if name == "n_dlr":
        return n_dlr
    raise NotImplementedError("Unsupported Loss Function!")


def format_param_info(args, training_type, model_name, num_classes):
    """Return the multi-line run summary that is logged and printed.

    Expects ``args`` after ``configure_attack_step_sizes`` (0-1 scale).
    """
    return (
        f"training type: {training_type}\n"
        f"device: {args.device}\n"
        f"model name: {model_name}\n"
        f"epoch number: {args.epochs}\n"
        f"random seed: {args.random_seed}\n"
        f"key epoch: {args.key_epochs}\n"
        f"batch size: {args.batch_size}\n"
        f"validation set ratio: {args.dataset_val_ratio}\n"
        f"model type: {args.model_type}\n"
        f"model depth: {args.depth}\n"
        f"model width: {args.width_factor}\n"
        f"scheduler: {args.lr_scheduler}\n"
        f"learning rate decay rate for multi-step: {args.lr_decay_rate}\n"
        f"max learning rate: {args.lr_max}\n"
        f"weight_decay: {args.weight_decay}\n"
        f"momentum: {args.momentum}\n"
        f"dropout: {args.dropout}\n"
        f"attack learning rate: {args.attack_lr * 255} / 255\n"
        f"attack epsilon: {args.attack_eps * 255} / 255\n"
        f"attack step: {args.attack_step}\n"
        f"attack restart: {args.attack_rs}\n"
        f"evaluation attack step: {args.attack_step_test}\n"
        f"evaluation attack restart: {args.attack_rs_test}\n"
        f"pretrained model: {args.pretrained_model}\n"
        f"pretrained epochs: {args.pretrained_epochs}\n"
        f"lambda: {args.lmbda}\n"
        f"gradient alignment cosine coefficient: {args.ga_coef}\n"
        f"KL regularizer coeff: {args.kl_coef}\n"
        f"Log barrier function coeff: {args.mu}\n"
        f"Max lr: {args.lr_max}\n"
        f"Imbalanced ratio: {args.ir}\n"
        f"Number of classes: {num_classes}\n"
    )


def setup_wandb(args, num_classes):
    """Log in (if a key is given), start the wandb run, and create per-class accuracy tables.

    Mutates ``args.wandb_project`` to its final prefixed name.

    Returns:
        ``(ra_tab, sa_tab)``: robust and standard accuracy tables.
    """
    if args.wandb_project is None:
        args.wandb_project = f'{args.wandb}_{args.dataset}_{args.model_type}'
    else:
        args.wandb_project = f'{args.wandb}_{args.wandb_project}'
    if args.wandb_key is not None:
        wandb.login(key=args.wandb_key)
        print('successfully logged in to wandb')

    group = f'{args.mode}_{args.attack_eps * 255}_{args.ir}'
    init_kwargs = dict(project=args.wandb_project, group=group, reinit=True)
    if args.run_name is not None:
        init_kwargs["name"] = args.run_name
    wandb.init(**init_kwargs)
    wandb.config.update(args)

    def columns():
        return ['avg', 'tail_30'] + [f'class_{c}' for c in range(num_classes)]

    return wandb.Table(columns=columns()), wandb.Table(columns=columns())


def main():
    """Parse arguments, build data/model/optimizer, then train with periodic evaluation."""
    args = build_parser().parse_args()
    device = resolve_device(args.device)
    use_wandb = normalize_wandb_flag(args)

    log_path, model_dir, csv_dir = prepare_output_dirs(args)

    setup_seed(seed=args.random_seed)
    training_type = args.mode.upper()
    # Built before rescaling, so the name carries the 0-255 epsilon.
    model_name = f"{args.dataset}_{training_type}_{args.model_type}-{args.depth}_Eps{args.attack_eps}_{args.time_stamp}"
    model_path = os.path.join(model_dir, model_name + '.pth')
    best_model_path = os.path.join(model_dir, model_name + '_best.pth')
    csv_path = os.path.join(csv_dir, model_name + '.csv')

    configure_attack_step_sizes(args)

    ############################## Logger #################################
    log = Log(log_each=2)
    logging.basicConfig(filename=os.path.join(log_path, f'{model_name}.log'), level=logging.INFO)
    logger = logging.getLogger("Adversarial Training")

    ########################## dataset and model ##########################
    # The test loader is not used below; evaluation runs on val_dl.
    train_dl, val_dl, _test_dl, norm_layer, num_classes = build_dataloaders(args)
    batches_per_epoch = len(train_dl)
    print(f'\nDataset: {args.dataset}. Number of trainining batches: {batches_per_epoch}\n')

    model = build_model(args, num_classes)
    model.normalize = norm_layer
    model = model.to(device)
    print_network(model)

    ########################## optimizer, scheduler, loss ##########################
    optimizer = build_optimizer(args, model)
    scheduler = build_scheduler(args, optimizer, batches_per_epoch)
    train_loss = build_train_loss(args.train_loss)

    trainer = BatTrainer(args=args, log=log)

    ########################## resume ##########################
    if args.pretrained_model:
        model.load(args.pretrained_model, map_location=device)
    fast_forward_scheduler(optimizer, scheduler, args, batches_per_epoch)

    param_info = format_param_info(args, training_type, model_name, num_classes)
    logger.info(param_info)
    print(param_info)

    ra_tab = sa_tab = None
    if use_wandb:
        ra_tab, sa_tab = setup_wandb(args, num_classes)

    # Each list is one CSV row: a header label followed by one value per evaluation.
    epoch_num_list = ['Epoch Number']
    # NOTE(review): the three training_* rows are never filled, so they stay
    # shorter than the test rows — the trainer does not return these stats here.
    training_sa_list = ['Training Standard Accuracy']
    training_ra_list = ['Training Robust Accuracy']
    training_loss_list = ['Training Loss']
    test_sa_list = ['Test Standard Accuracy']
    test_ra_list = ['Test Robust Accuracy']
    test_loss_list = ['Test Loss']

    best_acc = 0.0

    for epoch in range(args.pretrained_epochs, args.epochs):
        logger.info(f"\n========================Epoch : {epoch}========================")
        model.train()
        log.train(len_dataset=batches_per_epoch)

        klc = kl_coef_for_epoch(epoch, args.kl_coef, args.klc_milestone1, args.klc_milestone2, num_classes)
        if use_wandb and 'ours' in args.mode:
            wandb.log({"kl_coef": klc}, step=epoch * batches_per_epoch)

        model = trainer.train(model=model,
                              train_dl=train_dl,
                              opt=optimizer,
                              loss_func=train_loss,
                              scheduler=scheduler if scheduler_is_active(epoch, args) else None,
                              device=device,
                              kl_coef=klc,
                              epoch=epoch,
                              wandb=wandb if use_wandb else None)

        if not should_evaluate(epoch, args):
            continue

        model.eval()
        log.eval(len_dataset=len(val_dl))
        (correct_total, robust_total, total, test_loss, out_str,
         ra_30, sa_30, ra_per_class, sa_per_class) = trainer.eval_per_class(model=model,
                                                                            test_dl=val_dl,
                                                                            attack_eps=args.attack_eps,
                                                                            attack_steps=args.attack_step_test,
                                                                            attack_lr=args.attack_eps / 4,
                                                                            attack_rs=args.attack_rs_test,
                                                                            device=device,
                                                                            num_classes=num_classes)
        print(out_str)
        natural_acc = correct_total / total
        robust_acc = robust_total / total

        if use_wandb:
            wandb.log({"robust_acc": 100. * robust_acc,
                       "standard_acc": 100. * natural_acc,
                       "ra_tail_30": 100. * ra_30,
                       "sa_tail_30": 100. * sa_30},
                      step=(epoch + 1) * batches_per_epoch)
            ra_tab.add_data(100. * robust_acc, 100. * ra_30,
                            *(100. * ra_per_class[c] for c in range(num_classes)))
            sa_tab.add_data(100. * natural_acc, 100. * sa_30,
                            *(100. * sa_per_class[c] for c in range(num_classes)))

        # Write all rows accumulated so far into the CSV file.
        epoch_num_list.append(epoch)
        test_loss_list.append(test_loss)
        test_sa_list.append(100. * natural_acc)
        test_ra_list.append(100. * robust_acc)
        csv_row_list = [epoch_num_list,
                        training_loss_list,
                        training_sa_list,
                        training_ra_list,
                        test_loss_list,
                        test_sa_list,
                        test_ra_list]

        logger.info(f'For the epoch {epoch} the test loss is {test_loss}')
        logger.info(f'For the epoch {epoch} the standard accuracy is {natural_acc}')
        logger.info(f'For the epoch {epoch} the robust accuracy is {robust_acc}')

        model.save(model_path)
        write_csv_rows(csv_path, csv_row_list)

        # "Best" is selected on robust accuracy on val_dl.
        if robust_acc > best_acc:
            best_acc = robust_acc
            model.save(best_model_path)

    log.flush()
    if use_wandb:
        wandb.log({"best_acc": 100. * best_acc})
        wandb.log({"ra_table": ra_tab})
        wandb.log({"sa_table": sa_tab})
        wandb.finish()
    print('\n============= Training Over ================')


if __name__ == "__main__":
    main()
