import argparse
import os
import numpy as np
import sys
import yaml
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.optim as optim

# Project configuration. The path is relative, so it resolves against the
# current working directory, not against this file's location.
# Alternative location kept for reference (presumably for running from a
# nested directory):
# config_file = './../../env.yml'
config_file = './env.yml'
with open(config_file, 'r') as stream:
    yamlfile = yaml.safe_load(stream)
    # env.yml must provide at least these two keys.
    root_dir = yamlfile['root_dir']  # base dir; main() reads data from <root_dir>/cifar100/data
    src_dir = yamlfile['src_dir']    # project source tree; appended to sys.path below

# Make the project source tree importable. These appends must stay before the
# project imports that follow, which presumably resolve through these paths.
sys.path.append(src_dir)
sys.path.append(os.path.join(src_dir, 'attack'))
sys.path.append(os.path.join(src_dir, 'models'))
from attack.lira import lira_attack
from utils import mkdir_p, AverageMeter, accuracy, print_acc_conf
from cifar_utils import transform_train, transform_test, Cifardata, DistillCifardata, WarmUpLR, ModelwNorm, \
    transform_train_aug
from cifar100.models.model_selector import get_network

# Default compute device; main() rebinds this module-level global to the GPU
# selected by --cuda, and undefendtest() reads it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def undefendtest(testloader, model, criterion, len_data, args):
    """Run ``model`` over ``testloader`` and collect per-sample softmax outputs.

    The model is put in eval mode and evaluated without gradient tracking.

    Args:
        testloader: iterable of ``(inputs, targets)`` batches. It should not
            shuffle if the returned rows are to line up with a label array.
        model: network mapping a batch of inputs to logits; assumed to return
            shape ``(batch, args.num_class)`` -- TODO confirm for all models.
        criterion: loss applied to ``(outputs, targets)``.
        len_data: total number of samples yielded by ``testloader``; sizes the
            output buffer.
        args: namespace; ``args.num_class`` sizes the output buffer.
            ``args.batch_size`` is no longer needed for row placement.

    Returns:
        Tuple ``(avg_loss, infer_np)``. ``avg_loss`` is ``AverageMeter.avg``
        after updating it with each batch's loss and batch size. ``infer_np``
        is a ``(len_data, num_class)`` array of softmax probabilities in loader
        order.
    """
    # switch to evaluate mode
    model.eval()

    num_class = args.num_class
    losses = AverageMeter()
    infer_np = np.zeros((len_data, num_class))

    # Place rows at a running offset instead of batch_ind * args.batch_size so
    # the output stays aligned even if the loader's batch size differs from args.
    offset = 0
    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs = inputs.to(device, torch.float)
            targets = targets.to(device, torch.long)
            n = inputs.shape[0]

            outputs = model(inputs)
            infer_np[offset:offset + n] = F.softmax(outputs, dim=1).cpu().numpy()
            offset += n

            loss = criterion(outputs, targets)
            losses.update(loss.item(), n)

    return (losses.avg, infer_np)


def main():
    """Evaluate saved undefended CIFAR-100 checkpoints and run the LiRA attack.

    For every run ``i`` in ``1..--num_runs`` this does three things:

    - loads
      ``<load_path>/cifar100/<model>/undefend/<aug|no_aug>/<i>/model_last.pth.tar``;
    - collects softmax outputs on each data partition, printing accuracy and
      confidence via ``print_acc_conf``;
    - calls ``lira_attack``.

    The (fpr, tpr, thresholds) curves returned for all runs are averaged
    point-wise and written to
    ``<save_path>/csv_save/cifar100/<model>/undefend/<aug|no_aug>/lira_auc-roc_<num_runs>.csv``.

    Raises:
        FileNotFoundError: if a run's checkpoint file is missing.
    """

    def _str2bool(value):
        # argparse's ``type=bool`` maps every non-empty string (including
        # "False") to True, so parse common boolean spellings explicitly.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = argparse.ArgumentParser(description='setting for cifar100')
    parser.add_argument('--cuda', type=int, default=0)
    parser.add_argument('--model', type=str, default='mobilenetv3_small')
    parser.add_argument('--attack_epochs', type=int, default=150, help='attack epochs in NN attack')
    parser.add_argument('--print_epoch', type=int, default=5, help='print single model training stats per print_epoch training')
    parser.add_argument('--batch_size', type=int, default=256, help='batch size')
    parser.add_argument('--warmup', type=int, default=1, help='warm up epochs')
    parser.add_argument('--num_worker', type=int, default=1, help='number workers')
    parser.add_argument('--num_class', type=int, default=100, help='num class')
    parser.add_argument('--num_runs', type=int, default=1)
    # conf
    parser.add_argument('--data_aug', type=_str2bool, default=True, help='turn on data augmentation')
    parser.add_argument('--save_path', default='save_checkpoints/', type=str, help='folder to save the checkpoints')
    parser.add_argument('--load_path', default='save_checkpoints/', type=str, help='folder to load the checkpoints')

    args = parser.parse_args()
    print(dict(args._get_kwargs()))

    num_runs = args.num_runs
    if num_runs < 1:
        # Without at least one run there is no curve to aggregate or save.
        parser.error('--num_runs must be at least 1')

    # undefendtest() reads the module-level ``device``; point it at the requested GPU.
    global device
    device = torch.device(f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu")

    attack_epochs = args.attack_epochs
    batch_size = args.batch_size
    num_worker = args.num_worker
    aug_dir = 'aug' if args.data_aug else 'no_aug'

    dataset_path = os.path.join(root_dir, 'cifar100', 'data')
    checkpoint_path = os.path.join(args.load_path, 'cifar100', args.model, 'undefend', aug_dir)
    save_checkpoint_path = os.path.join(args.save_path, 'csv_save', 'cifar100', args.model, 'undefend', aug_dir)
    print(checkpoint_path)

    partition_dir = os.path.join(dataset_path, 'partition')

    def _load(name):
        # Load one pre-computed partition array, e.g. 'tr_data' -> partition/tr_data.npy.
        return np.load(os.path.join(partition_dir, name + '.npy'))

    train_data_tr_attack, train_label_tr_attack = _load('tr_data'), _load('tr_label')
    train_data_te_attack, train_label_te_attack = _load('te_data'), _load('te_label')
    train_data, train_label = _load('train_data'), _load('train_label')
    test_data, test_label = _load('test_data'), _load('test_label')
    ref_data, ref_label = _load('ref_data'), _load('ref_label')
    all_test_data, all_test_label = _load('all_test_data'), _load('all_test_label')

    # print first 20 labels for each subset, for checking with other experiments
    for labels in (train_label_tr_attack, train_label_te_attack, test_label, ref_label):
        print(labels[:20])

    def _make_loader(data_arr, label_arr):
        # Deterministic transform and no shuffling, so that row k of each
        # prediction array lines up with label k.
        dataset = Cifardata(data_arr, label_arr, transform_test)
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_worker)

    traintestloader = _make_loader(train_data, train_label)
    trloader = _make_loader(train_data_tr_attack, train_label_tr_attack)
    alltestloader = _make_loader(all_test_data, all_test_label)
    teloader = _make_loader(train_data_te_attack, train_label_te_attack)
    testloader = _make_loader(test_data, test_label)
    refloader = _make_loader(ref_data, ref_label)

    def _evaluate(title, loader, labels, net, criterion):
        # Collect softmax outputs over ``loader`` and print/return accuracy and confidence.
        print(title)
        _, infer = undefendtest(loader, net, criterion, len(loader.dataset), args)
        acc, conf = print_acc_conf(infer, labels)
        return infer, acc, conf

    # Per-run metrics; collected for reporting but not currently written to disk.
    train_accs, test_accs = np.zeros(num_runs), np.zeros(num_runs)
    lira_acc = np.zeros(num_runs)
    lira_fpr, lira_tpr, lira_thresholds = [], [], []
    for i in range(1, num_runs + 1):
        cur_cp = os.path.join(checkpoint_path, str(i))
        criterion = nn.CrossEntropyLoss().to(device, torch.float)
        net = ModelwNorm(get_network(args.model, num_classes=args.num_class))
        resume = cur_cp + '/model_last.pth.tar'
        print('==> Resuming from checkpoint' + resume)
        # Explicit raise (not assert) so the check survives ``python -O``.
        if not os.path.isfile(resume):
            raise FileNotFoundError(f'Error: no checkpoint file found at {resume}')
        checkpoint = torch.load(resume, map_location='cpu')
        net.load_state_dict(checkpoint['state_dict'])
        net = net.to(device, torch.float)

        print("Attack Training: # of train data: {:d}, # of ref data: {:d}".format(
            len(train_data_tr_attack), len(ref_data)))
        print("Attack Testing: # of train data: {:d}, # of test data: {:d}".format(
            len(train_data_te_attack), len(test_data)))

        _, train_acc, train_conf = _evaluate("training set", traintestloader, train_label, net, criterion)
        infer_train_conf_tr, tr_acc, tr_conf = _evaluate("tr set", trloader, train_label_tr_attack, net, criterion)
        _, all_test_acc, all_test_conf = _evaluate("all test set", alltestloader, all_test_label, net, criterion)
        infer_train_conf_te, te_acc, te_conf = _evaluate("te set", teloader, train_label_te_attack, net, criterion)
        infer_test_conf, test_acc, test_conf = _evaluate("test set", testloader, test_label, net, criterion)
        infer_ref_conf, ref_acc, ref_conf = _evaluate("reference set", refloader, ref_label, net, criterion)

        print("For comparison on undefend output")
        print("avg acc  on train/all test/tr/te/test/reference set: {:.4f}/{:.4f}/{:.4f}/{:.4f}/{:.4f}/{:.4f}".format(
            train_acc, all_test_acc, tr_acc, te_acc, test_acc, ref_acc))
        print("avg conf on train/all_test/tr/te/test/reference set: {:.4f}/{:.4f}/{:.4f}/{:.4f}/{:.4f}/{:.4f}".format(
            train_conf, all_test_conf, tr_conf, te_conf, test_conf, ref_conf))

        train_accs[i - 1], test_accs[i - 1] = train_acc, test_acc
        # NOTE: the attack's batch size is fixed at 256, independent of --batch_size.
        lira_acc[i - 1], lira_aucs = lira_attack(
            infer_train_conf_tr, train_label_tr_attack, infer_train_conf_te, train_label_te_attack, infer_ref_conf,
            ref_label, infer_test_conf, test_label, num_class=args.num_class, attack_epochs=attack_epochs,
            batch_size=256)

        # Record this run's curve inside the loop so every run contributes
        # to the average (not just the last one).
        run_fpr, run_tpr, run_thresholds = lira_aucs
        lira_fpr.append(run_fpr)
        lira_tpr.append(run_tpr)
        lira_thresholds.append(run_thresholds)

    # np.stack requires every run's arrays to have the same length; this
    # assumes lira_attack returns fixed-length curves -- TODO confirm.
    lira_fpr = np.stack(lira_fpr)
    lira_tpr = np.stack(lira_tpr)
    lira_thresholds = np.stack(lira_thresholds)

    os.makedirs(save_checkpoint_path, exist_ok=True)
    # AUC
    cur_scp = os.path.join(save_checkpoint_path, f'lira_auc-roc_{num_runs}.csv')
    df = generate_multi_line_dataframe(lira_fpr, lira_tpr, lira_thresholds)
    df.to_csv(cur_scp, index=False)


def generate_multi_line_dataframe(fpr, tpr, thresholds):
    """Summarise per-run curves as point-wise mean and standard deviation.

    Args:
        fpr: array whose axis 0 indexes runs; statistics are taken over it.
        tpr: array laid out like ``fpr``.
        thresholds: accepted for call-site symmetry but not used.

    Returns:
        DataFrame with columns ``fpr_avg``, ``fpr_std``, ``tpr_avg`` and
        ``tpr_std``, one row per point along the non-run axis. The frame is
        also printed.
    """
    summary_columns = [
        np.mean(fpr, axis=0),
        np.std(fpr, axis=0),
        np.mean(tpr, axis=0),
        np.std(tpr, axis=0),
    ]
    # Stacking along axis 1 puts one statistic per column directly.
    table = np.stack(summary_columns, axis=1)
    frame = pd.DataFrame(table, columns=['fpr_avg', 'fpr_std', 'tpr_avg', 'tpr_std'])
    print(frame)
    return frame


def generate_dataframe(array):
    """Return a one-row DataFrame with the mean and standard deviation of ``array``.

    Both statistics are taken over all elements of ``array``. The frame is
    also printed.
    """
    stats_row = np.array([np.mean(array), np.std(array)]).reshape(1, 2)
    summary = pd.DataFrame(stats_row, columns=['avg', 'std'])
    print(summary)
    return summary


if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not on import.
    main()
