import argparse
import os
import numpy as np
import sys
import yaml
import distutils.util
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.optim as optim

# Project paths are read from a YAML config resolved relative to the current
# working directory. Alternative location: config_file = './../../env.yml'
config_file = './env.yml'
with open(config_file, 'r') as stream:
    yamlfile = yaml.safe_load(stream)
root_dir = yamlfile['root_dir']
src_dir = yamlfile['src_dir']

# Make the project source tree, plus its attack/ and models/ sub-directories,
# importable for the project-local imports that follow.
sys.path.extend([
    src_dir,
    os.path.join(src_dir, 'attack'),
    os.path.join(src_dir, 'models'),
])
from attack.dsq_attack import system_attack
from utils import mkdir_p, AverageMeter, accuracy, print_acc_conf
from tinyimagenet_utils import transform_train, transform_test, TINdata, DistillTINdata, WarmUpLR, ModelwNorm, \
    transform_train_aug
from tinyimagenet.models.model_selector import get_network

# Default compute device; main() rebinds this global to the GPU chosen via --cuda.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def undefendtest(testloader, model, criterion, len_data, args):
    """Run *model* over *testloader* and collect softmax outputs and the mean loss.

    Args:
        testloader: iterable of ``(inputs, targets)`` batches, consumed in order.
        model: network whose outputs are treated as logits (softmax over dim 1).
        criterion: loss applied to ``(outputs, targets)``.
        len_data: total number of samples; number of rows in the returned array.
        args: namespace; only ``args.num_class`` is read here.

    Returns:
        ``(avg_loss, infer_np)``. ``avg_loss`` is ``AverageMeter.avg`` over all
        batches, weighted by batch size. ``infer_np`` is a
        ``(len_data, num_class)`` array of softmax probabilities, with rows in
        loader order.
    """
    # switch to evaluate mode
    model.eval()

    losses = AverageMeter()
    infer_np = np.zeros((len_data, args.num_class))

    # Track the write position explicitly rather than assuming every batch
    # (except the last) has exactly args.batch_size samples.
    offset = 0
    # Inference only: skip building the autograd graph to save memory and time.
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs = inputs.to(device, torch.float)
            targets = targets.to(device, torch.long)
            n = inputs.shape[0]

            outputs = model(inputs)
            infer_np[offset:offset + n] = F.softmax(outputs, dim=1).cpu().numpy()
            offset += n

            loss = criterion(outputs, targets)
            losses.update(loss.item(), n)

    return losses.avg, infer_np


def _load_partition(dataset_path, name):
    """Load ``<dataset_path>/partition/<name>.npy``."""
    return np.load(os.path.join(dataset_path, 'partition', f'{name}.npy'))


def _evaluate_split(title, loader, labels, net, criterion, args):
    """Print *title*, run ``undefendtest`` on *loader* and report accuracy/confidence.

    Returns ``(softmax_outputs, acc, conf)``, where ``acc`` and ``conf`` are the
    values returned by ``print_acc_conf`` for this split.
    """
    print(title)
    _, infer_conf = undefendtest(loader, net, criterion, len(loader.dataset), args)
    acc, conf = print_acc_conf(infer_conf, labels)
    return infer_conf, acc, conf


def main():
    """Evaluate Adversarial-Regularization TinyImageNet checkpoints and attack them.

    For each run ``i`` in ``1..--num_runs`` this loads
    ``<load_path>/tinyimagenet/<model>/advreg/<aug|no_aug>/<alpha*100>_<k>/<i>/model_last.pth.tar``,
    computes softmax outputs on every data partition and feeds them to
    ``system_attack``. Per-run accuracies and ROC curves are summarised
    (mean/std over runs) and written as CSV files under
    ``<save_path>/csv_save/tinyimagenet/<model>/rw_advreg/<aug|no_aug>/<alpha*100>_<k>``.

    Raises:
        FileNotFoundError: if a run's checkpoint file is missing.
    """
    parser = argparse.ArgumentParser(description='setting for tinyimagenet')
    parser.add_argument('--cuda', type=int, default=0)
    parser.add_argument('--model', type=str, default='mobilenetv3_small')
    parser.add_argument('--attack_epochs', type=int, default=150, help='attack epochs in NN attack')
    parser.add_argument('--print_epoch', type=int, default=5, help='print single model training stats per print_epoch training')
    parser.add_argument('--batch_size', type=int, default=256, help='batch size')
    parser.add_argument('--warmup', type=int, default=1, help='warm up epochs')
    parser.add_argument('--num_worker', type=int, default=1, help='number workers')
    parser.add_argument('--num_class', type=int, default=200, help='num class')
    parser.add_argument('--num_runs', type=int, default=1)
    # conf
    parser.add_argument('--data_aug', type=distutils.util.strtobool, default=True, help='turn on data augmentation')
    parser.add_argument('--save_path', default='save_checkpoints/', type=str, help='folder to save the checkpoints')
    parser.add_argument('--load_path', default='save_checkpoints/', type=str, help='folder to load the checkpoints')
    # defence conf
    parser.add_argument('--alpha', type=float, default=2.0, help='para for Adversarial Regularization')
    parser.add_argument('--k', type=int, default=5, help='k steps for Adversarial Regularization')

    args = parser.parse_args()
    print(dict(args._get_kwargs()))
    if args.num_runs < 1:
        # Without at least one run there is nothing to aggregate or save.
        parser.error('--num_runs must be at least 1')

    global device
    device = torch.device(f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu")

    attack_epochs = args.attack_epochs
    batch_size = args.batch_size
    num_class = args.num_class
    num_worker = args.num_worker
    num_runs = args.num_runs

    # --data_aug, --alpha and --k only select which checkpoint/result folder is used.
    aug_dir = 'aug' if args.data_aug else 'no_aug'
    defence_dir = f'{int(args.alpha * 100)}_{args.k}'
    DATASET_PATH = os.path.join(root_dir, 'tinyimagenet', 'data')
    checkpoint_path = os.path.join(args.load_path, 'tinyimagenet', args.model, 'advreg', aug_dir, defence_dir)
    save_checkpoint_path = os.path.join(args.save_path, 'csv_save', 'tinyimagenet', args.model, 'rw_advreg',
                                        aug_dir, defence_dir)
    print(checkpoint_path)

    train_data_tr_attack = _load_partition(DATASET_PATH, 'tr_data')
    train_label_tr_attack = _load_partition(DATASET_PATH, 'tr_label')
    train_data_te_attack = _load_partition(DATASET_PATH, 'te_data')
    train_label_te_attack = _load_partition(DATASET_PATH, 'te_label')
    train_data = _load_partition(DATASET_PATH, 'train_data')
    train_label = _load_partition(DATASET_PATH, 'train_label')
    test_data = _load_partition(DATASET_PATH, 'test_data')
    test_label = _load_partition(DATASET_PATH, 'test_label')
    ref_data = _load_partition(DATASET_PATH, 'ref_data')
    ref_label = _load_partition(DATASET_PATH, 'ref_label')
    all_test_data = _load_partition(DATASET_PATH, 'all_test_data')
    all_test_label = _load_partition(DATASET_PATH, 'all_test_label')

    # print first 20 labels for each subset, for checking with other experiments
    print(train_label_tr_attack[:20])
    print(train_label_te_attack[:20])
    print(test_label[:20])
    print(ref_label[:20])

    def make_loader(inputs, labels):
        """Unshuffled loader over ``(inputs, labels)`` with test-time transforms."""
        dataset = TINdata(inputs, labels, transform_test)
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False,
                                           num_workers=num_worker)

    # Only evaluation loaders are needed: outputs must stay aligned with labels.
    traintestloader = make_loader(train_data, train_label)
    trloader = make_loader(train_data_tr_attack, train_label_tr_attack)
    alltestloader = make_loader(all_test_data, all_test_label)
    teloader = make_loader(train_data_te_attack, train_label_te_attack)
    testloader = make_loader(test_data, test_label)
    refloader = make_loader(ref_data, ref_label)

    metric_names = ('entr', 'mentr', 'conf', 'corr')
    train_accs, test_accs = np.zeros(num_runs), np.zeros(num_runs)
    # Attack accuracy per run, keyed by the CSV file prefix (order = output order).
    attack_accs = {name: np.zeros(num_runs) for name in (*metric_names, 'nn')}
    # Per-metric ROC curves: ([fpr per run], [tpr per run], [thresholds per run]).
    roc_curves = {name: ([], [], []) for name in metric_names}

    criterion = nn.CrossEntropyLoss().to(device, torch.float)
    for run_idx in range(num_runs):
        resume = os.path.join(checkpoint_path, str(run_idx + 1), 'model_last.pth.tar')
        print('==> Resuming from checkpoint' + resume)
        if not os.path.isfile(resume):
            raise FileNotFoundError(f'Error: no checkpoint found at {resume}')
        net = get_network(args.model, num_classes=num_class)
        checkpoint = torch.load(resume, map_location='cpu')
        net.load_state_dict(checkpoint['state_dict'])
        net = net.to(device, torch.float)

        print("Attack Training: # of train data: {:d}, # of ref data: {:d}".format(len(train_data_tr_attack),
                                                                                   len(ref_data)))
        print("Attack Testing: # of train data: {:d}, # of test data: {:d}".format(len(train_data_te_attack),
                                                                                   len(test_data)))

        _, train_acc, train_conf = _evaluate_split("training set", traintestloader, train_label,
                                                   net, criterion, args)
        infer_tr, tr_acc, tr_conf = _evaluate_split("tr set", trloader, train_label_tr_attack,
                                                    net, criterion, args)
        _, all_test_acc, all_test_conf = _evaluate_split("all test set", alltestloader, all_test_label,
                                                         net, criterion, args)
        infer_te, te_acc, te_conf = _evaluate_split("te set", teloader, train_label_te_attack,
                                                    net, criterion, args)
        infer_test, test_acc, test_conf = _evaluate_split("test set", testloader, test_label,
                                                          net, criterion, args)
        infer_ref, ref_acc, ref_conf = _evaluate_split("reference set", refloader, ref_label,
                                                       net, criterion, args)

        print("For comparison on undefend output")
        print("avg acc  on train/all test/tr/te/test/reference set: {:.4f}/{:.4f}/{:.4f}/{:.4f}/{:.4f}/{:.4f}".format(
            train_acc, all_test_acc, tr_acc, te_acc, test_acc, ref_acc))
        print("avg conf on train/all_test/tr/te/test/reference set: {:.4f}/{:.4f}/{:.4f}/{:.4f}/{:.4f}/{:.4f}".format(
            train_conf, all_test_conf, tr_conf, te_conf, test_conf, ref_conf))

        train_accs[run_idx], test_accs[run_idx] = train_acc, test_acc
        (entr_acc, mentr_acc, conf_acc, corr_acc, nn_acc,
         entr_roc, mentr_roc, conf_roc, corr_roc) = system_attack(
            infer_tr, train_label_tr_attack, infer_te, train_label_te_attack, infer_ref,
            ref_label, infer_test, test_label, num_class=num_class, attack_epochs=attack_epochs,
            batch_size=256)  # attack-model batch size is fixed, independent of --batch_size

        for name, acc in zip(attack_accs, (entr_acc, mentr_acc, conf_acc, corr_acc, nn_acc)):
            attack_accs[name][run_idx] = acc
        # Collect inside the loop so every run contributes to the mean/std curves.
        for name, roc in zip(metric_names, (entr_roc, mentr_roc, conf_roc, corr_roc)):
            fpr, tpr, thresholds = roc
            fprs, tprs, threshold_runs = roc_curves[name]
            fprs.append(fpr)
            tprs.append(tpr)
            threshold_runs.append(thresholds)

    # Stack runs along axis 0 so generate_multi_line_dataframe averages over runs.
    # NOTE(review): np.stack needs every run's curve to have the same length --
    # confirm system_attack returns fixed-length curves when num_runs > 1.
    stacked_rocs = {name: tuple(np.stack(per_run) for per_run in curves)
                    for name, curves in roc_curves.items()}

    from pathlib import Path
    Path(save_checkpoint_path).mkdir(parents=True, exist_ok=True)
    # ACC: one mean/std summary row per quantity.
    for prefix, values in (('train', train_accs), ('test', test_accs), *attack_accs.items()):
        generate_dataframe(values).to_csv(f'{save_checkpoint_path}/{prefix}_{num_runs}.csv', index=False)
    # AUC: mean/std ROC curve per attack metric.
    for prefix, (fpr, tpr, thresholds) in stacked_rocs.items():
        generate_multi_line_dataframe(fpr, tpr, thresholds).to_csv(
            f'{save_checkpoint_path}/{prefix}_auc-roc_{num_runs}.csv', index=False)


def generate_multi_line_dataframe(fpr, tpr, thresholds):
    """Summarise per-run ROC curves as a mean/std value per curve point.

    Args:
        fpr: false-positive rates stacked over runs along axis 0.
        tpr: true-positive rates stacked over runs along axis 0.
        thresholds: accepted for call-site symmetry; not used.

    Returns:
        DataFrame with columns ``fpr_avg, fpr_std, tpr_avg, tpr_std`` and one row
        per curve point (``std`` is numpy's population std). The frame is also
        printed.
    """
    # Order: mean(fpr), std(fpr), mean(tpr), std(tpr) -- matches the column names.
    column_stats = [reduce_fn(curve, axis=0)
                    for curve in (fpr, tpr)
                    for reduce_fn in (np.mean, np.std)]
    table = np.stack(column_stats, axis=1)
    frame = pd.DataFrame(table, columns=['fpr_avg', 'fpr_std', 'tpr_avg', 'tpr_std'])
    print(frame)
    return frame


def generate_dataframe(array):
    """Summarise *array* as a one-row DataFrame with columns ``avg`` and ``std``.

    ``std`` is numpy's population standard deviation (``ddof=0``). The frame is
    printed before being returned.
    """
    summary = np.array([[np.mean(array), np.std(array)]])
    frame = pd.DataFrame(summary, columns=['avg', 'std'])
    print(frame)
    return frame


# Run the evaluation only when executed as a script, not on import.
if __name__ == '__main__':
    main()
