import argparse
import os
import numpy as np
import sys

import pandas as pd
import yaml
import distutils.util

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.optim as optim

# Project paths come from a YAML file. The path is relative, so it is resolved
# against the current working directory; the commented-out line below is an
# alternative location, presumably for running from a nested directory.
# config_file = './../../env.yml'
config_file = './env.yml'
with open(config_file, 'r') as stream:
    yamlfile = yaml.safe_load(stream)
    # 'root_dir': base directory of the dataset tree (main() builds
    # DATASET_PATH from it); 'src_dir': project source root.
    root_dir = yamlfile['root_dir']
    src_dir = yamlfile['src_dir']

# Make the project packages importable: the source root plus its 'attack' and
# 'models' subdirectories (needed by the project imports that follow).
sys.path.append(src_dir)
sys.path.append(os.path.join(src_dir, 'attack'))
sys.path.append(os.path.join(src_dir, 'models'))
from attack.dsq_attack import system_attack, get_entropy, get_mentropy
from utils import mkdir_p, AverageMeter, accuracy, print_acc_conf
from cifar_utils import transform_train, transform_test, Cifardata, DistillCifardata, WarmUpLR, ModelwNorm, \
    transform_train_aug
from cifar100.models.model_selector import get_network

# Compute device for the model and input batches: CUDA when available,
# otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def undefendtest(testloader, model, criterion, len_data, args):
    """Run ``model`` over ``testloader`` and collect per-sample softmax outputs.

    Args:
        testloader: iterable of ``(inputs, targets)`` batches. Rows of the
            returned matrix follow the loader's iteration order, so it should
            not shuffle if rows must line up with a label array.
        model: network whose output for a batch is a ``(batch, num_class)``
            logit tensor (required by the row assignment below).
        criterion: loss applied to ``(outputs, targets)``.
        len_data: total number of samples the loader yields; sizes the output.
        args: namespace providing ``num_class``.

    Returns:
        tuple ``(avg_loss, infer_np)``: the sample-weighted average loss from
        ``AverageMeter`` and a ``(len_data, num_class)`` array whose row ``k``
        is the softmax output for the ``k``-th sample yielded by the loader.
    """
    # switch to evaluate mode
    model.eval()

    num_class = args.num_class
    losses = AverageMeter()
    infer_np = np.zeros((len_data, num_class))

    # Write position is advanced by the real size of each batch, so rows stay
    # aligned even if the loader's batch size differs from args.batch_size or
    # a batch is short.
    offset = 0
    # Inference only: no_grad avoids building an autograd graph per batch.
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs = inputs.to(device, torch.float)
            targets = targets.to(device, torch.long)
            n_batch = inputs.shape[0]

            outputs = model(inputs)
            infer_np[offset:offset + n_batch] = F.softmax(outputs, dim=1).cpu().numpy()
            offset += n_batch

            loss = criterion(outputs, targets)
            losses.update(loss.item(), n_batch)

    return (losses.avg, infer_np)


def _thre_setting(tr_values, te_values):
    """Select the best threshold"""
    value_list = np.concatenate((tr_values, te_values))
    thre, max_acc = 0, 0
    for value in value_list:
        tr_ratio = np.sum(tr_values <= value) / (len(tr_values) + 0.0)
        te_ratio = np.sum(te_values > value) / (len(te_values) + 0.0)
        acc = 0.5 * (tr_ratio + te_ratio)
        if acc > max_acc:
            thre, max_acc = value, acc
    return thre


def _mem_inf_thre(t_tr_values, t_te_values):
    """MIA by thresholding overall feature values"""
    t_tr_mem, t_te_non_mem = 0, 0
    thre = _thre_setting(t_tr_values, t_te_values)
    t_tr_mem += np.sum(t_tr_values < thre)
    t_te_non_mem += np.sum(t_te_values >= thre)
    mem_inf_acc = 0.5 * (t_tr_mem / (len(t_tr_values) + 0.0) + t_te_non_mem / (len(t_te_values) + 0.0))
    info = 'MIA (general threshold): the attack acc is {acc:.3f}'.format(acc=mem_inf_acc)
    print(info)
    return thre  # , mem_inf_acc


def sort_with_indexes(arr):
    """Sort ``arr`` along its last axis and also return the sort indices.

    Args:
        arr: array-like to sort.

    Returns:
        tuple ``(sorted_arr, sorted_indices)``, where ``sorted_indices`` is
        ``np.argsort(arr)`` and ``sorted_arr`` holds the values those indices
        select (equal to ``np.sort(arr)``).
    """
    arr = np.asarray(arr)
    # Get the indices that would sort the array
    sorted_indices = np.argsort(arr)
    # Gather with those indices instead of running a second, independent sort;
    # take_along_axis keeps this correct for N-D input too.
    sorted_arr = np.take_along_axis(arr, sorted_indices, axis=-1)
    return sorted_arr, sorted_indices


def create_directory_structure(directory_path):
    """Ensure ``directory_path`` exists, creating missing parents as needed.

    Prints whether the path was newly created or was already present. An
    existing path of any kind (file or directory) is reported as existing
    and left untouched.
    """
    if os.path.exists(directory_path):
        print(f"Directory '{directory_path}' already exists.")
        return
    os.makedirs(directory_path)
    print(f"Directory '{directory_path}' created successfully.")


def main():
    """Measure how often each training sample is flagged as a member.

    For every run ``i`` in ``1..num_runs``:

    1. Load the checkpoint
       ``<load_path>/cifar100/<model>/undefend_aug_8x/<i>/model_last.pth.tar``.
    2. Collect softmax outputs on the training partition (members) and on the
       reference partition (non-members).
    3. Pick a threshold on their ``get_mentropy`` statistics with
       ``_mem_inf_thre``.
    4. Mark every training sample whose statistic is ``<=`` the threshold as
       a predicted member. This is the same rule ``_thre_setting`` uses to
       choose the threshold.

    The per-sample fraction of runs predicting "member" is written to the
    ``acc`` column of
    ``<save_path>/cifar100/<model>/stat_mentr_aug_8x/stat_<num_runs>.csv``.
    """
    parser = argparse.ArgumentParser(description='setting for cifar100')
    parser.add_argument('--model', type=str, default='mobilenetv3_small')
    # NOTE: attack_epochs, print_epoch, batch_step, warmup and data_aug are
    # accepted on the command line but not read by this script.
    parser.add_argument('--attack_epochs', type=int, default=150, help='attack epochs in NN attack')
    parser.add_argument('--print_epoch', type=int, default=5,
                        help='print single model training stats per print_epoch training')
    parser.add_argument('--batch_size', type=int, default=256, help='batch size')
    parser.add_argument('--batch_step', type=int, default=4, help='batch accumulation steps')
    parser.add_argument('--warmup', type=int, default=1, help='warm up epochs')
    parser.add_argument('--num_worker', type=int, default=1, help='number workers')
    parser.add_argument('--num_class', type=int, default=100, help='num class')
    parser.add_argument('--num_runs', type=int, default=1)
    parser.add_argument('--data_aug', type=distutils.util.strtobool, default=True, help='turn on data augmentation')
    parser.add_argument('--save_path', default='save_checkpoints/', type=str, help='folder to save the checkpoints')
    parser.add_argument('--load_path', default='save_checkpoints/', type=str, help='folder to load the checkpoints')

    args = parser.parse_args()
    print(dict(args._get_kwargs()))
    if args.num_runs < 1:
        # At least one run is needed to produce per-sample rates.
        parser.error('--num_runs must be >= 1')

    batch_size = args.batch_size
    num_worker = args.num_worker
    num_runs = args.num_runs

    DATASET_PATH = os.path.join(root_dir, 'cifar100', 'data')
    partition_dir = os.path.join(DATASET_PATH, 'partition')
    load_checkpoint_path = os.path.join(args.load_path, 'cifar100', args.model, 'undefend_aug_8x')
    checkpoint_path = os.path.join(args.save_path, 'cifar100', args.model, 'stat_mentr_aug_8x')
    print(checkpoint_path)

    def load_partition(name):
        # Partition arrays live at <DATASET_PATH>/partition/<name>.npy.
        return np.load(os.path.join(partition_dir, name + '.npy'))

    # Only the arrays this script uses are loaded: labels for the sanity
    # printout, plus data/labels of the training and reference partitions.
    train_label_tr_attack = load_partition('tr_label')
    train_label_te_attack = load_partition('te_label')
    test_label = load_partition('test_label')
    train_data = load_partition('train_data')
    train_label = load_partition('train_label')
    ref_data = load_partition('ref_data')
    ref_label = load_partition('ref_label')

    # print first 20 labels for each subset, for checking with other experiments
    print(train_label_tr_attack[:20])
    print(train_label_te_attack[:20])
    print(test_label[:20])
    print(ref_label[:20])

    # Unshuffled loaders keep prediction rows aligned with train_label /
    # ref_label, which are passed to get_mentropy alongside them.
    traintestset = Cifardata(train_data, train_label, transform_test)
    refset = Cifardata(ref_data, ref_label, transform_test)
    traintestloader = torch.utils.data.DataLoader(traintestset, batch_size=batch_size, shuffle=False,
                                                  num_workers=num_worker)
    refloader = torch.utils.data.DataLoader(refset, batch_size=batch_size, shuffle=False, num_workers=num_worker)

    criterion = nn.CrossEntropyLoss().to(device, torch.float)
    # Running count, per training sample, of runs that predicted "member".
    member_counts = None
    for i in range(1, num_runs + 1):
        cur_cp = f'{load_checkpoint_path}/{i}'
        net_1 = get_network(args.model, num_classes=args.num_class)
        net = ModelwNorm(net_1)
        resume = cur_cp + '/model_last.pth.tar'
        print('==> Resuming from checkpoint' + resume)
        # Explicit check (an assert would be stripped under `python -O`).
        if not os.path.isfile(resume):
            raise FileNotFoundError(f'Error: no checkpoint directory found! ({resume})')
        checkpoint = torch.load(resume, map_location='cpu')
        net.load_state_dict(checkpoint['state_dict'])
        net = net.to(device, torch.float)

        _, infer_train_conf_tr = undefendtest(traintestloader, net, criterion, len(traintestset), args)
        _, infer_ref_conf = undefendtest(refloader, net, criterion, len(refset), args)
        train_mem_stat = get_mentropy(infer_train_conf_tr, train_label)
        train_nonmem_stat = get_mentropy(infer_ref_conf, ref_label)
        thre = _mem_inf_thre(train_mem_stat, train_nonmem_stat)
        print('the thres:', thre)

        # _thre_setting selects thre under "member iff stat <= thre"; use the
        # same inclusive rule so samples equal to thre are not flipped.
        predicted_member = (np.asarray(train_mem_stat) <= thre).astype(np.float64)
        member_counts = predicted_member if member_counts is None else member_counts + predicted_member

    # Per-sample fraction of runs in which the sample was predicted a member.
    df = pd.DataFrame({'acc': member_counts / num_runs})
    print(df['acc'].value_counts())
    # Save the per-sample membership-prediction rates
    create_directory_structure(checkpoint_path)
    df.to_csv(f'{checkpoint_path}/stat_{num_runs}.csv', index=False)


# Script entry point: parse CLI arguments and run the evaluation only when
# executed directly, not on import.
if __name__ == '__main__':
    main()
