import argparse
import os
import numpy as np
import sys
import yaml
import distutils.util

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.optim as optim

# Project locations are read from a YAML file resolved relative to the current
# working directory; the file must provide the keys 'root_dir' and 'src_dir'.
# config_file = './../../env.yml'
config_file = './env.yml'
with open(config_file, 'r') as stream:
    yamlfile = yaml.safe_load(stream)
    root_dir = yamlfile['root_dir']
    src_dir = yamlfile['src_dir']

# Extend the import search path BEFORE the project imports below: those
# modules are resolved through 'src_dir' and its 'attack' / 'models' subfolders.
sys.path.append(src_dir)
sys.path.append(os.path.join(src_dir, 'attack'))
sys.path.append(os.path.join(src_dir, 'models'))
from attack.dsq_attack import system_attack, get_entropy, get_mentropy
from utils import mkdir_p, AverageMeter, accuracy, print_acc_conf
from cifar_utils import transform_train, transform_test, Cifardata, DistillCifardata, WarmUpLR, ModelwNorm, \
    transform_train_aug
from cifar100.models.model_selector import get_network

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def undefendtest(testloader, model, criterion, len_data, args):
    """Run ``model`` over ``testloader`` and collect its softmax outputs.

    Args:
        testloader: Iterable yielding ``(inputs, targets)`` tensor batches.
            Rows are stored in iteration order, so the loader should not
            shuffle if the caller needs to match rows to dataset indices.
        model: Network to evaluate; it is switched to eval mode.
        criterion: Callable ``criterion(outputs, targets)`` returning a
            scalar loss tensor.
        len_data: Number of rows to allocate for the prediction array
            (the total number of samples the loader yields).
        args: Namespace; only ``args.num_class`` is read.

    Returns:
        A tuple ``(avg_loss, infer_np)``. ``avg_loss`` is the ``avg``
        attribute of an ``AverageMeter`` updated with each batch loss and
        batch size. ``infer_np`` is a ``(len_data, num_class)`` NumPy array
        of softmax probabilities.
    """
    # switch to evaluate mode
    model.eval()

    losses = AverageMeter()
    infer_np = np.zeros((len_data, args.num_class))

    # A running offset places every batch correctly even if the loader's
    # batch size differs from the value configured in ``args``.
    offset = 0

    # Pure inference: no autograd graph is needed, so skip building one.
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs = inputs.to(device, torch.float)
            targets = targets.to(device, torch.long)

            outputs = model(inputs)
            n_samples = inputs.shape[0]
            infer_np[offset: offset + n_samples] = F.softmax(outputs, dim=1).cpu().numpy()
            offset += n_samples

            loss = criterion(outputs, targets)
            losses.update(loss.item(), n_samples)

    return (losses.avg, infer_np)


def _thre_setting(tr_values, te_values):
    """Select the best threshold.

    Every value of ``tr_values`` and ``te_values`` is tried as a candidate
    threshold. A candidate ``v`` is scored with the balanced accuracy of the
    rule "``x <= v`` belongs to the first group, ``x > v`` to the second":
    ``0.5 * (mean(tr_values <= v) + mean(te_values > v))``.

    Args:
        tr_values: 1-D array-like of values that should fall at or below the
            threshold.
        te_values: 1-D array-like of values that should fall above it.

    Returns:
        The first candidate (``tr_values`` first, then ``te_values``) that
        reaches the highest score, or ``0`` when no candidate scores above
        zero (for example when both inputs are empty).

    Note:
        Inputs are assumed to be NaN-free; NaN handling of the previous
        element-wise loop is not reproduced.
    """
    tr_values = np.asarray(tr_values)
    te_values = np.asarray(te_values)
    candidates = np.concatenate((tr_values, te_values))
    if candidates.size == 0:
        return 0

    # Binary search on sorted copies counts "<= candidate" for all candidates
    # at once (O(n log n)) instead of one full comparison pass per candidate.
    tr_le = np.searchsorted(np.sort(tr_values), candidates, side='right')
    te_le = np.searchsorted(np.sort(te_values), candidates, side='right')
    te_gt = len(te_values) - te_le

    acc = 0.5 * (tr_le / (len(tr_values) + 0.0) + te_gt / (len(te_values) + 0.0))

    # argmax returns the first maximum, matching a strict ">" scan.
    best = int(np.argmax(acc))
    # Fall back to 0 when nothing beats a score of zero (also covers NaN
    # scores produced when exactly one of the inputs is empty).
    return candidates[best] if acc[best] > 0 else 0


def _mem_inf_thre(t_tr_values, t_te_values):
    """MIA by thresholding overall feature values.

    Picks a threshold with ``_thre_setting`` and prints the balanced attack
    accuracy obtained with it.

    Args:
        t_tr_values: NumPy array of feature values for the member samples.
        t_te_values: NumPy array of feature values for the non-member samples.

    Returns:
        The selected threshold.
    """
    thre = _thre_setting(t_tr_values, t_te_values)

    # Score with the same decision rule _thre_setting optimises
    # (member: value <= thre, non-member: value > thre) so the printed
    # accuracy is the accuracy of the threshold that was actually selected.
    t_tr_mem = np.sum(t_tr_values <= thre)
    t_te_non_mem = np.sum(t_te_values > thre)

    mem_inf_acc = 0.5 * (t_tr_mem / (len(t_tr_values) + 0.0) + t_te_non_mem / (len(t_te_values) + 0.0))
    info = 'MIA (general threshold): the attack acc is {acc:.3f}'.format(acc=mem_inf_acc)
    print(info)
    return thre  # , mem_inf_acc


def sort_with_indexes(arr):
    """Sort ``arr`` in ascending order and return the permutation used.

    Args:
        arr: Array-like to sort. Sorting happens along the last axis.

    Returns:
        A tuple ``(sorted_arr, sorted_indices)`` where ``sorted_indices`` is
        the result of ``np.argsort`` and ``sorted_arr`` holds the values of
        ``arr`` reordered by those indices.
    """
    arr = np.asarray(arr)
    # Get the indices that would sort the array
    sorted_indices = np.argsort(arr)
    # Reuse the permutation instead of sorting a second time; this also
    # guarantees sorted_arr[i] is exactly the element at sorted_indices[i].
    sorted_arr = np.take_along_axis(arr, sorted_indices, axis=-1)
    return sorted_arr, sorted_indices


def create_directory_structure(directory_path):
    """Create ``directory_path`` (including parents) and report the outcome.

    Args:
        directory_path: Path of the directory to create.

    Prints whether the directory was created or already existed. Errors
    other than "already exists" raised by ``os.makedirs`` propagate.
    """
    # Attempt the creation directly rather than checking first, so a
    # directory created concurrently between check and creation cannot
    # crash the script.
    try:
        os.makedirs(directory_path)
    except FileExistsError:
        print(f"Directory '{directory_path}' already exists.")
    else:
        print(f"Directory '{directory_path}' created successfully.")


def main():
    """Rank train/reference samples by their distance to an MIA threshold.

    Steps, as implemented below:
      1. Parse the command-line options and derive dataset/checkpoint paths.
      2. Load the pre-partitioned ``.npy`` arrays and wrap them in
         ``Cifardata`` datasets and unshuffled ``DataLoader`` objects.
      3. Restore the undefended model from ``model_last.pth.tar``.
      4. Run the model over the train and reference sets, compute a
         per-sample statistic with ``get_mentropy`` and select a threshold
         with ``_mem_inf_thre``.
      5. Sort the signed distances to that threshold and save the sorted
         values and their original indices to ``train.npz`` and ``ref.npz``
         under ``checkpoint_path``.
    """
    parser = argparse.ArgumentParser(description='setting for cifar100')
    parser.add_argument('--model', type=str, default='mobilenetv3_small')
    parser.add_argument('--attack_epochs', type=int, default=150, help='attack epochs in NN attack')
    parser.add_argument('--print_epoch', type=int, default=5,
                        help='print single model training stats per print_epoch training')
    parser.add_argument('--batch_size', type=int, default=256, help='batch size')
    parser.add_argument('--batch_step', type=int, default=4, help='batch accumulation steps')
    parser.add_argument('--warmup', type=int, default=1, help='warm up epochs')
    parser.add_argument('--num_worker', type=int, default=1, help='number workers')
    parser.add_argument('--num_class', type=int, default=100, help='num class')
    # parser.add_argument('--num_runs', type=int, default=1)
    parser.add_argument('--run_idx', type=int, default=100, help='idx running')
    # conf
    # NOTE(review): distutils was removed from the standard library in
    # Python 3.12; this converter will need a replacement on newer interpreters.
    parser.add_argument('--data_aug', type=distutils.util.strtobool, default=True, help='turn on data augmentation')
    parser.add_argument('--save_path', default='save_checkpoints/', type=str, help='folder to save the checkpoints')
    parser.add_argument('--load_path', default='save_checkpoints/', type=str, help='folder to load the checkpoints')

    args = parser.parse_args()
    print(dict(args._get_kwargs()))

    # NOTE(review): attack_epochs, num_class and warmup are not used again in
    # this function.
    attack_epochs = args.attack_epochs
    batch_size = args.batch_size
    num_class = args.num_class
    warmup = args.warmup
    num_worker = args.num_worker

    # Input data lives under <root_dir>/cifar100/data. The model is read from
    # the 'undefend' checkpoint tree and results go to the 'e2a_mentr' tree,
    # both keyed by model name, augmentation flag and run index.
    DATASET_PATH = os.path.join(root_dir, 'cifar100', 'data')
    load_checkpoint_path = os.path.join(args.load_path, 'cifar100', args.model, 'undefend',
                                        'aug' if args.data_aug else 'no_aug', str(args.run_idx))
    checkpoint_path = os.path.join(args.save_path, 'cifar100', args.model, 'e2a_mentr',
                                   'aug' if args.data_aug else 'no_aug', str(args.run_idx))
    print(checkpoint_path)

    # Load every partition file; only the train and ref splits feed the
    # threshold computation further down.
    train_data_tr_attack = np.load(os.path.join(DATASET_PATH, 'partition', 'tr_data.npy'))
    train_label_tr_attack = np.load(os.path.join(DATASET_PATH, 'partition', 'tr_label.npy'))
    train_data_te_attack = np.load(os.path.join(DATASET_PATH, 'partition', 'te_data.npy'))
    train_label_te_attack = np.load(os.path.join(DATASET_PATH, 'partition', 'te_label.npy'))
    train_data = np.load(os.path.join(DATASET_PATH, 'partition', 'train_data.npy'))
    train_label = np.load(os.path.join(DATASET_PATH, 'partition', 'train_label.npy'))
    test_data = np.load(os.path.join(DATASET_PATH, 'partition', 'test_data.npy'))
    test_label = np.load(os.path.join(DATASET_PATH, 'partition', 'test_label.npy'))
    ref_data = np.load(os.path.join(DATASET_PATH, 'partition', 'ref_data.npy'))
    ref_label = np.load(os.path.join(DATASET_PATH, 'partition', 'ref_label.npy'))
    all_test_data = np.load(os.path.join(DATASET_PATH, 'partition', 'all_test_data.npy'))
    all_test_label = np.load(os.path.join(DATASET_PATH, 'partition', 'all_test_label.npy'))

    # print first 20 labels for each subset, for checking with other experiments
    print(train_label_tr_attack[:20])
    print(train_label_te_attack[:20])
    print(test_label[:20])
    print(ref_label[:20])

    # if data augmented
    # if args.data_aug:
    #     trainset = Cifardata(train_data, train_label, transform_train_aug)
    # else:
    #     trainset = Cifardata(train_data, train_label, transform_train)
    # trainset = Cifardata(train_data, train_label, transform_train)
    # All datasets use the test-time transform: the model is only evaluated here.
    traintestset = Cifardata(train_data, train_label, transform_test)
    testset = Cifardata(test_data, test_label, transform_test)
    refset = Cifardata(ref_data, ref_label, transform_test)

    trset = Cifardata(train_data_tr_attack, train_label_tr_attack, transform_test)
    teset = Cifardata(train_data_te_attack, train_label_te_attack, transform_test)
    alltestset = Cifardata(all_test_data, all_test_label, transform_test)

    # shuffle=False keeps loader order equal to array order, so row i of the
    # predictions corresponds to sample i of the loaded arrays.
    # NOTE(review): trloader, teloader, alltestloader and testloader are not
    # used again in this function.
    trloader = torch.utils.data.DataLoader(trset, batch_size=batch_size, shuffle=False, num_workers=num_worker)
    teloader = torch.utils.data.DataLoader(teset, batch_size=batch_size, shuffle=False, num_workers=num_worker)
    alltestloader = torch.utils.data.DataLoader(alltestset, batch_size=batch_size, shuffle=False,num_workers=num_worker)
    # trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_worker)
    traintestloader = torch.utils.data.DataLoader(traintestset, batch_size=batch_size, shuffle=False,num_workers=num_worker)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_worker)
    refloader = torch.utils.data.DataLoader(refset, batch_size=batch_size, shuffle=False, num_workers=num_worker)

    # Rebuild the network, wrap it in ModelwNorm and restore the weights
    # stored under the 'state_dict' key of the last checkpoint.
    cur_cp = load_checkpoint_path
    criterion = nn.CrossEntropyLoss().to(device, torch.float)
    net_1 = get_network(args.model, num_classes=args.num_class)
    net = ModelwNorm(net_1)
    resume = cur_cp + '/model_last.pth.tar'
    print('==> Resuming from checkpoint' + resume)
    # NOTE(review): assert is stripped under `python -O`; the check then
    # disappears and a missing file surfaces inside torch.load instead.
    assert os.path.isfile(resume), 'Error: no checkpoint directory found!'
    # Load onto the CPU first; the model is moved to `device` afterwards.
    checkpoint = torch.load(resume, map_location='cpu')
    net.load_state_dict(checkpoint['state_dict'])
    net = net.to(device, torch.float)

    # Softmax outputs for the training set (members) and the reference set
    # (non-members); the returned losses are not used further.
    tr_loss, infer_train_conf_tr = undefendtest(traintestloader, net, criterion, len(traintestset), args)
    # tr_acc, tr_conf = print_acc_conf(infer_train_conf_tr, train_label)
    ref_loss, infer_ref_conf = undefendtest(refloader, net, criterion, len(refset), args)
    # ref_acc, ref_conf = print_acc_conf(infer_ref_conf, ref_label)
    train_member_pred, train_member_label = infer_train_conf_tr, train_label
    train_nonmember_pred, train_nonmember_label = infer_ref_conf, ref_label
    # Per-sample statistic from predictions and labels
    # (presumably the modified prediction entropy - confirm in attack.dsq_attack).
    train_mem_stat = get_mentropy(train_member_pred, train_member_label)
    train_nonmem_stat = get_mentropy(train_nonmember_pred, train_nonmember_label)
    thre = _mem_inf_thre(train_mem_stat, train_nonmem_stat)
    print('the thres:', thre)
    # compute and set
    # Signed distances to the threshold: positive when a train sample lies
    # below it, respectively when a reference sample lies above it.
    train_arr = (thre - train_mem_stat) # / thre
    ref_arr = (train_nonmem_stat - thre) # / thre
    # Ascending sort; the index arrays map each sorted value back to its
    # position in the original (unshuffled) dataset.
    sorted_train_arr, sorted_train_idx = sort_with_indexes(train_arr)
    sorted_ref_arr, sorted_ref_idx = sort_with_indexes(ref_arr)
    print(sorted_train_arr[:10], sorted_train_arr[-10:])
    print(sorted_ref_arr[:10], sorted_ref_arr[-10:])
    # Save the sorted arrays
    create_directory_structure(checkpoint_path)
    np.savez(f'{checkpoint_path}/train.npz', val=sorted_train_arr, idx=sorted_train_idx)
    np.savez(f'{checkpoint_path}/ref.npz', val=sorted_ref_arr, idx=sorted_ref_idx)


# Run the pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()
