import argparse
import os
import math
import shutil
import random
import distutils.util
import numpy as np
import pandas as pd
import sys
import yaml

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.optim as optim

# from cifar100.MemGuard.memguard_run import model

# Environment config: `root_dir` is the data root (the dataset lives under
# <root_dir>/cifar100/data) and `src_dir` is added to sys.path below.
# The path is resolved against the current working directory; an earlier
# variant of this script read './../../env.yml' instead.
config_file = './env.yml'
with open(config_file, 'r') as stream:
    yamlfile = yaml.safe_load(stream)
root_dir, src_dir = yamlfile['root_dir'], yamlfile['src_dir']

# Make the source tree and its `attack` / `models` sub-directories importable.
sys.path.extend([src_dir, os.path.join(src_dir, 'attack'), os.path.join(src_dir, 'models')])
from attack.dsq_attack import system_attack
from utils import mkdir_p, AverageMeter, accuracy, print_acc_conf, TrainRecorder
from cifar_utils import transform_train, transform_train_aug, transform_test, Cifardata, DistillCifardata, WarmUpLR, \
    ModelwNorm
from cifar100.models.model_selector import get_network

# Default torch device: CUDA when a GPU is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


def find_rank(arr, targets):
    """Return the position of each target value inside ``arr``.

    Args:
        arr: 1-D array-like of hashable values (e.g. a ranking of sample indices).
        targets: Iterable of values to locate in ``arr``.

    Returns:
        list[int]: ``ranks[k]`` is the index at which the k-th target occurs in ``arr``.

    Raises:
        ValueError: if ``arr`` is not 1-D, or a target is missing from ``arr`` or
            occurs in it more than once (its rank would then be ambiguous).
    """
    values = np.asarray(arr)
    if values.ndim != 1:
        raise ValueError(f'arr must be 1-D, got shape {values.shape}')

    # One pass builds a value -> position map, so every lookup is O(1) instead of
    # an O(len(arr)) scan per target. `tolist()` yields plain Python scalars; numpy
    # scalar targets hash/compare equal to them, so lookups still match.
    position = {}
    duplicated = set()
    for pos, value in enumerate(values.tolist()):
        if value in position:
            duplicated.add(value)
        else:
            position[value] = pos

    ranks = []
    for target in targets:
        if target in duplicated:
            raise ValueError(f'{target!r} occurs more than once in arr')
        if target not in position:
            raise ValueError(f'{target!r} not found in arr')
        ranks.append(position[target])
    return ranks


def get_bin_value(bins, num_samples, rank):
    """Map ``rank`` to the first bin whose cutoff covers it.

    Each entry of ``bins`` is a fraction; its cutoff is ``fraction * num_samples``.
    The first fraction whose cutoff is >= ``rank`` is returned. When no cutoff
    covers ``rank``, the last entry of ``bins`` is returned instead.
    An empty ``bins`` raises IndexError.
    """
    for fraction in bins:
        if rank <= fraction * num_samples:
            return fraction
    # No cutoff reached: clamp to the last bin.
    return bins[-1]


def overlap_percentage(set_list):
    """Return the percentage of the first collection's elements found in every collection.

    The denominator is the size of ``set_list[0]`` (not the union of all
    collections). With equally sized sets this is therefore the share of each
    set that is common to all of them.

    Args:
        set_list: Non-empty sequence of sets; any iterables (e.g. frozensets,
            lists) are accepted.

    Returns:
        float in [0, 100]; 0.0 when the first collection is empty.

    Raises:
        ValueError: if ``set_list`` is empty.
    """
    if not set_list:
        raise ValueError('overlap_percentage() needs at least one set')
    first = set(set_list[0])
    if not first:
        # Nothing to overlap; avoid a ZeroDivisionError.
        return 0.0
    # Building from `first` (rather than `set.intersection(*set_list)`) also
    # works when the elements are frozensets or other iterables.
    common = first.intersection(*set_list[1:])
    return len(common) / len(first) * 100


def overlap_samples(set_list):
    """Return a new set of the elements present in every collection of ``set_list``.

    Args:
        set_list: Non-empty sequence of sets; any iterables (e.g. frozensets,
            lists) are accepted.

    Returns:
        set: the intersection. It is always a fresh set, never one of the inputs.

    Raises:
        ValueError: if ``set_list`` is empty.
    """
    if not set_list:
        raise ValueError('overlap_samples() needs at least one set')
    # Building from the first element (rather than `set.intersection(*set_list)`)
    # also works when the elements are frozensets or other iterables.
    return set(set_list[0]).intersection(*set_list[1:])


def _load_rank_idx(run_dir):
    """Return the ``idx`` array stored in ``<run_dir>/train.npz``."""
    # np.load on an .npz returns an NpzFile that keeps the archive open until it
    # is closed. Indexing reads the array into memory, so the returned array
    # stays valid after the `with` block closes the file.
    with np.load(os.path.join(run_dir, 'train.npz')) as rank_data:
        return rank_data['idx']


def main():
    """Bin the e2a_mentr rank of every sample that survives pruning in all runs.

    Steps:
      1. For each run ``i`` in ``1..num_run``, read ``idx`` from
         ``<load_path>/cifar100/<model>/e2a_mentr_rl/<aug|no_aug>/<conf>/<i>/train.npz``.
         Keep its first ``int(data_retain * len(idx))`` entries.
      2. Collect the indices kept in every run.
      3. Locate each of them in the ``idx`` arrays of the
         ``.../e2a_mentr/<aug|no_aug>/<i>`` runs and average its position over runs.
      4. Map each average position to a bin fraction with ``get_bin_value``.
      5. Print the bins and write them to
         ``<out_dir>/olwhere_rl<conf>_<data_retain>.csv``.
    """
    parser = argparse.ArgumentParser(description='setting for cifar100')
    parser.add_argument('--model', type=str, default='mobilenetv3_small_50')
    parser.add_argument('--data_aug', type=distutils.util.strtobool, default=True, help='turn on data augmentation')
    parser.add_argument('--num_run', type=int, default=100, help='idx running')
    parser.add_argument('--load_path', default='save_checkpoints/', type=str, help='folder to load the checkpoints')

    parser.add_argument('--data_retain', type=float, default=1, help='retain rate')
    parser.add_argument('--conf', type=str, default='', help='')
    # Previously a hard-coded absolute path; the default keeps the old location.
    parser.add_argument('--out_dir', type=str,
                        default='/home/xfang23/PycharmProjects/Privacy_Plot/DataValuation/cifar100',
                        help='folder to write the bin csv to')
    args = parser.parse_args()
    # Same output as the private `args._get_kwargs()` (sorted name/value pairs).
    print(dict(sorted(vars(args).items())))

    data_retain_rate = args.data_retain
    aug_dir = 'aug' if args.data_aug else 'no_aug'
    model_dir = os.path.join(args.load_path, 'cifar100', args.model)
    # Rankings from the e2a_mentr_rl runs, used to decide which samples are retained.
    load_checkpoint_path = os.path.join(model_dir, 'e2a_mentr_rl', aug_dir, args.conf)
    # Rankings from the e2a_mentr runs, in which the retained samples are located.
    load_cpt_vanilla = os.path.join(model_dir, 'e2a_mentr', aug_dir)
    print(load_checkpoint_path)

    # Print the first 20 labels of each subset, for checking the partition against
    # other experiments. Only these labels are needed here; the image arrays and
    # the other splits are not loaded.
    partition_dir = os.path.join(root_dir, 'cifar100', 'data', 'partition')
    for label_file in ('tr_label.npy', 'te_label.npy', 'test_label.npy', 'ref_label.npy'):
        print(np.load(os.path.join(partition_dir, label_file))[:20])

    # Indices kept after pruning (first `data_retain` fraction of each ranking)
    # in every run. Sorted so the CSV row order is reproducible.
    retained_sets = []
    for run in range(1, args.num_run + 1):
        rank_idx = _load_rank_idx(f'{load_checkpoint_path}/{run}')
        num_retain = int(data_retain_rate * len(rank_idx))
        retained_sets.append(set(rank_idx[:num_retain]))
    ols = sorted(overlap_samples(retained_sets))

    vanilla_rankings = [_load_rank_idx(f'{load_cpt_vanilla}/{run}') for run in range(1, args.num_run + 1)]

    # ranks[r, j] = position of ols[j] in run r's ranking; shape (num_run, len(ols)).
    ranks = np.asarray([find_rank(ranking, ols) for ranking in vanilla_rankings], dtype=np.float64)
    # Sum the exact integer positions first, then divide once, to get each
    # sample's average position over runs.
    avg_ranks = ranks.sum(axis=0) / args.num_run

    num_samples = len(vanilla_rankings[0])
    # Cumulative cutoffs, as fractions of the ranking length.
    bins = [0.25, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    v_list = [get_bin_value(bins, num_samples, avg_rank) for avg_rank in avg_ranks]
    df = pd.DataFrame(v_list, columns=['bin'])
    print(df)
    os.makedirs(args.out_dir, exist_ok=True)
    df.to_csv(os.path.join(args.out_dir, f'olwhere_rl{args.conf}_{args.data_retain}.csv'), index=False)


if __name__ == '__main__':
    main()
