import argparse
import os
import shutil
import random
import distutils.util
import numpy as np
import pandas as pd
import sys
import yaml

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.optim as optim

# Project locations (data root and source tree) come from a YAML file that is
# looked up relative to the current working directory.
config_file = './env.yml'
with open(config_file, 'r') as stream:
    yamlfile = yaml.safe_load(stream)
    root_dir, src_dir = yamlfile['root_dir'], yamlfile['src_dir']

# Make the source tree and its attack/models sub-directories importable.
sys.path.extend([
    src_dir,
    os.path.join(src_dir, 'attack'),
    os.path.join(src_dir, 'models'),
])
sys.path.append(os.path.join(src_dir, 'models'))
from attack.dsq_attack import system_attack
from utils import mkdir_p, AverageMeter, accuracy, print_acc_conf, TrainRecorder
from cifar_utils import transform_train, transform_train_aug, transform_test, Cifardata, DistillCifardata, WarmUpLR, \
    ModelwNorm
from cifar100.models.model_selector import get_network
from cifar100.PrivacyDV.cethr import CrossEntropyThr

# Use the default CUDA device when available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def eval_dist(model, dataloader):
    """Compute the decision-boundary margin of every sample in ``dataloader``.

    The model is put in eval mode and each batch is scored with
    ``distance_to_decision_boundary`` without tracking gradients.

    Args:
        model: network already placed on ``device``.
        dataloader: iterable yielding ``(inputs, targets)`` batches.

    Returns:
        ``(pred_list, true_list)``: per-sample margins and labels, each a
        list of 0-d CPU tensors in the order the dataloader yields them.
    """
    pred_list = []
    true_list = []

    model.eval()
    # No backward pass is ever taken, so disable autograd instead of keeping
    # activations alive for every batch.
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device, torch.float)
            dists = distance_to_decision_boundary(model, inputs)
            # Iterating a 1-D tensor yields one 0-d tensor per sample.
            pred_list.extend(dists.cpu())
            # Labels are only collected, so there is no need to ship them to the GPU.
            true_list.extend(targets.to('cpu', torch.long))
    # Release cached allocator blocks once per call rather than after every batch.
    torch.cuda.empty_cache()
    return pred_list, true_list


def distance_to_decision_boundary(model, x):
    """Return the softmax confidence margin (top-1 minus top-2 probability).

    The margin serves as a proxy for the distance of each input to the
    model's decision boundary: it is 0 when the two most likely classes tie.

    Args:
        model: callable mapping ``x`` to logits of shape ``(batch, classes)``
            (assumed 2-D; at least two classes are required).
        x: input batch. Its ``requires_grad`` flag is left untouched.

    Returns:
        1-D tensor of non-negative margins, one per sample, not attached to
        any autograd graph.
    """
    # Gradients are never used, so run the forward pass without autograd.
    with torch.no_grad():
        probs = torch.softmax(model(x), dim=-1)
        # top-1 minus top-2 equals the 2nd smallest of |max - p| over classes.
        top2 = torch.topk(probs, k=2, dim=1).values
        return top2[:, 0] - top2[:, 1]


def main():
    """Dump per-sample confidence margins of every saved PrivacyDV checkpoint.

    For each threshold in ``thresholds`` and each run ``1 .. num_runs`` the
    checkpoint ``<checkpoint_path>/<run>/model_thr_<threshold>.pth.tar``
    (``model_last.pth.tar`` for threshold ``'0'``) is loaded and scored with
    ``eval_dist`` on the training and test partitions. The margins are
    written to ``train_<threshold>.csv`` / ``test_<threshold>.csv`` in the
    result directory, with one column per run and one row per sample.

    All command-line options default to the resnet18 / hp1 / augmented
    experiment layout under ``/2nd_disk``.
    """
    parser = argparse.ArgumentParser(description='setting for cifar100')
    parser.add_argument('--model', default='resnet18')
    parser.add_argument('--conf', default='hp1')
    parser.add_argument('--no_data_aug', dest='data_aug', action='store_false',
                        help='use checkpoints trained without data augmentation')
    parser.add_argument('--save_path', default='/2nd_disk/experiment/p_dv',
                        help='root directory of the trained checkpoints')
    parser.add_argument('--result_path', default='/2nd_disk/result/p_dv',
                        help='root directory for the output CSV files')
    parser.add_argument('--num_runs', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=1024)
    args = parser.parse_args()
    print(dict(args._get_kwargs()))

    # '0' selects the final checkpoint (model_last) instead of a threshold one.
    thresholds = ['2.0', '1.5', '1.0', '0']
    num_class = 100
    num_worker = 1
    aug_dir = 'aug' if args.data_aug else 'no_aug'
    # Sub-path shared by the checkpoint tree and the result tree.
    exp_subdir = os.path.join('cifar100', args.model, 'privacydv', 'init', aug_dir, args.conf)

    partition_dir = os.path.join(root_dir, 'cifar100', 'data', 'partition')
    checkpoint_path = os.path.join(args.save_path, exp_subdir)
    print(checkpoint_path)

    def load_partition(name):
        return np.load(os.path.join(partition_dir, f'{name}.npy'))

    train_data = load_partition('train_data')
    train_label = load_partition('train_label')
    test_data = load_partition('test_data')
    test_label = load_partition('test_label')

    # print first 20 labels for each subset, for checking with other experiments
    print(load_partition('tr_label')[:20])
    print(load_partition('te_label')[:20])
    print(test_label[:20])
    print(load_partition('ref_label')[:20])

    # Margins are stored one column per run and compared row-wise across runs.
    # Both splits must therefore use the deterministic test transform and a
    # fixed order. A shuffled, augmented loader would make row j a different,
    # randomly augmented image in every run.
    traintestset = Cifardata(train_data, train_label, transform_test)
    testset = Cifardata(test_data, test_label, transform_test)
    traintestloader = torch.utils.data.DataLoader(traintestset, batch_size=args.batch_size, shuffle=False,
                                                  num_workers=num_worker)
    testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=False,
                                             num_workers=num_worker)

    filepath = os.path.join(args.result_path, exp_subdir)
    os.makedirs(filepath, exist_ok=True)

    model = ModelwNorm(get_network(arch=args.model, num_classes=num_class))
    for threshold in thresholds:
        ckpt_name = 'model_last.pth.tar' if threshold == '0' else f'model_thr_{threshold}.pth.tar'
        train_cols = {}
        test_cols = {}
        for i in range(1, args.num_runs + 1):
            cur_cpt_path = f'{checkpoint_path}/{i}/{ckpt_name}'
            # map_location lets GPU-saved checkpoints load on CPU-only hosts.
            state = torch.load(cur_cpt_path, map_location=device)['state_dict']
            model.load_state_dict(state)
            model.to(device)
            train_pred_list, _ = eval_dist(model, traintestloader)
            test_pred_list, _ = eval_dist(model, testloader)
            train_cols[f'{i}'] = pd.Series([float(d) for d in train_pred_list], dtype=float)
            test_cols[f'{i}'] = pd.Series([float(d) for d in test_pred_list], dtype=float)

        print(threshold, 'complete')
        # Build each table once rather than re-concatenating it for every run.
        pd.DataFrame(train_cols).to_csv(f'{filepath}/train_{threshold}.csv')
        pd.DataFrame(test_cols).to_csv(f'{filepath}/test_{threshold}.csv')


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
