import torch
import argparse
import os
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
import numpy as np
import pandas as pd
import itertools
import math
from math import sqrt
import random
import shap
import copy
from tqdm import tqdm
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from scipy.stats import pearsonr, spearmanr
from scipy.stats import norm, uniform
from scipy.stats import kstest, ks_2samp
from sklearn.svm import LinearSVC
from sklearn.neural_network import MLPClassifier
from sklearn.utils import shuffle
import networkx as nx
from sklearn.svm import SVC
from sklearn import svm, datasets
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.naive_bayes import GaussianNB
import sklearn
from sklearn.cluster import KMeans
from torch.utils.data import TensorDataset, DataLoader
import seaborn as sns
import torch.optim as optim
from train_utils import train_mlp, test_mlp, train_sklearn, test_sklearn, get_savefldr, get_loadfile
from dataloader import get_folktables_dataset
from models import get_model, flatten_weights
from helper import init_to_name_dict, init_to_color_dict, id_to_marker_dict, id_to_arch_dict
from meta_fairness import equal_opp_binary, fair_loss_binary, avg_odds_binary, acc_diff_binary

from fairtorch_local import DemographicParityLoss, EqualiedOddsLoss

from neb import neb_path_finder

import warnings
# Install a process-wide filter that ignores every warning raised after this point.
warnings.filterwarnings("ignore")

# Turn on autograd anomaly detection globally. PyTorch documents this as a
# debugging aid that slows execution, so every training call below pays for it.
torch.autograd.set_detect_anomaly(True)

def plot_decouple_boxplot(model, input_size, outplot, ckptlosstype, use_cuda, protected_class, survey_year, state, ckptfldr, plotdata):
    """Print test-set average odds and F-score for a sweep over initialisation seeds.

    ``model`` is a ``(we_init, arch_id, init_seed, shuffle_seed)`` tuple; only
    ``we_init``, ``arch_id`` and ``shuffle_seed`` (for the dataset split) are
    read from it. For init seeds 0..49 with the shuffle seed fixed at 0, the
    checkpoint saved under index 299 is loaded, evaluated on the test loader
    and one ``avg_odds fscore`` line is printed.

    ``outplot`` and ``plotdata`` are accepted for signature parity with the
    other plotting entry points but are not used here.
    """
    we_init, arch_id, _, data_seed = model

    trainloader, _, testloader = get_folktables_dataset(
        train_valid_test_split=[0.7, 0.1, 0.2],
        protected_class=protected_class,
        survey_year=survey_year,
        states=[state],
        shuffle_seed=data_seed,
        dataloadershuffle=False,
        dataset_type='acsemployment',
    )

    # A figure is opened but nothing is drawn on it in this function.
    plt.subplots(figsize=(8, 5))

    # Only the initialisation seed varies; the data-order seed stays at 0.
    seed_pairs = [(iseed, 0) for iseed in range(50)]

    fscores = []
    odds = []
    for iseed, sseed in tqdm(seed_pairs):
        ckpt_path = get_loadfile(ckptfldr, protected_class, ckptlosstype, we_init, arch_id + '_' + str(iseed), sseed, 299)
        net = get_model(arch_id, input_size, id_to_arch_dict[arch_id], we_init, ckpt=ckpt_path, cuda=use_cuda, dataloader=trainloader)
        result = test_mlp(net, testloader, cuda=use_cuda, fairness_criteria='avgodds', return_outputs=True)
        fscore, avg_odds = result[0], result[1]
        fscores.append(fscore)
        odds.append(avg_odds)
        print(avg_odds, fscore)


def plot_decouple_lineplot(model, input_size, outplot, ckptlosstype, use_cuda, protected_class, survey_year, state, ckptfldr, plotdata):
    """Print, per epoch, the spread of test-set average odds over a shuffle-seed sweep.

    ``model`` is a ``(we_init, arch_id, init_seed, shuffle_seed)`` tuple. For
    every epoch in 1..299 the checkpoints of the runs with init seed 0 and
    shuffle seeds 0..49 are evaluated, and one line
    ``epoch median p75 p25 max min`` is printed (after the header
    ``x y1 y2 y3 y4 y5``).

    ``outplot`` and ``plotdata`` are not used by this function.
    """
    we_init, arch_id, _, data_seed = model

    # NOTE(review): 'celeba' is passed to the folktables loader while the sibling
    # functions use 'acsemployment' / 'acsincome' -- confirm this is intended.
    trainloader, _, testloader = get_folktables_dataset(
        train_valid_test_split=[0.7, 0.1, 0.2],
        protected_class=protected_class,
        survey_year=survey_year,
        states=[state],
        shuffle_seed=data_seed,
        dataloadershuffle=False,
        dataset_type='celeba',
    )

    # A figure is opened but nothing is drawn on it in this function.
    plt.subplots(figsize=(8, 5))

    # Only the data-order seed varies; the initialisation seed stays at 0.
    seed_pairs = [(0, sseed) for sseed in range(50)]

    print("x y1 y2 y3 y4 y5")
    for epoch in range(1, 300):
        odds_this_epoch = []
        for iseed, sseed in tqdm(seed_pairs):
            ckpt_path = get_loadfile(ckptfldr, protected_class, ckptlosstype, we_init, arch_id + '_' + str(iseed), sseed, epoch)
            net = get_model(arch_id, input_size, id_to_arch_dict[arch_id], we_init, ckpt=ckpt_path, cuda=use_cuda, dataloader=trainloader)
            result = test_mlp(net, testloader, cuda=use_cuda, fairness_criteria='avgodds', return_outputs=True)
            odds_this_epoch.append(result[1])
        print(
            epoch,
            np.median(odds_this_epoch),
            np.percentile(odds_this_epoch, 75),
            np.percentile(odds_this_epoch, 25),
            np.max(odds_this_epoch),
            np.min(odds_this_epoch),
        )


def plot_variance_epochs_boxplot(model, input_size, outplot, ckptlosstype, use_cuda, protected_class, survey_year, state, ckptfldr, plotdata):
    """Print spread statistics of fairness and per-group metrics over 50 paired seeds.

    ``model`` is a ``(we_init, arch_id, init_seed, shuffle_seed)`` tuple. For
    seeds ``(s, s)`` with ``s`` in 0..49 the checkpoint saved under index 299
    is evaluated on the test loader. Five lines are printed, each formatted as
    ``median p75 p25 max min``, in this order: average odds, group-1 TPR,
    group-0 TPR, group-1 accuracy, group-0 accuracy (the last four in percent).

    ``outplot`` and ``plotdata`` are not used by this function.
    """
    we_init, arch_id, _, data_seed = model

    trainloader, _, testloader = get_folktables_dataset(
        train_valid_test_split=[0.7, 0.1, 0.2],
        protected_class=protected_class,
        survey_year=survey_year,
        states=[state],
        shuffle_seed=data_seed,
        dataloadershuffle=False,
        dataset_type='acsincome',
    )

    # A figure is opened but nothing is drawn on it in this function.
    plt.subplots(figsize=(8, 5))

    def print_spread(values):
        # Same column order for every metric: median, 75th pct, 25th pct, max, min.
        print(np.median(values), np.percentile(values, 75), np.percentile(values, 25), np.max(values), np.min(values))

    # Initialisation and data-order seeds move together.
    seed_pairs = [(s, s) for s in range(50)]

    odds_all, tpr_g1, tpr_g0, acc_g1, acc_g0 = [], [], [], [], []
    for iseed, sseed in tqdm(seed_pairs):
        ckpt_path = get_loadfile(ckptfldr, protected_class, ckptlosstype, we_init, arch_id + '_' + str(iseed), sseed, 299)
        net = get_model(arch_id, input_size, id_to_arch_dict[arch_id], we_init, ckpt=ckpt_path, cuda=use_cuda, dataloader=trainloader)
        fscore, avg_odds, labels, preds, groups = test_mlp(net, testloader, cuda=use_cuda, fairness_criteria='avgodds', return_outputs=True)

        # Split predictions and labels by group membership (1 = minority, 0 = majority).
        y_g1, p_g1 = labels[groups==1], preds[groups==1]
        y_g0, p_g0 = labels[groups==0], preds[groups==0]

        # TPR: predicted positives among the true positives of the group.
        minority_tpr = np.sum(p_g1[y_g1==1])/np.sum(y_g1)
        majority_tpr = np.sum(p_g0[y_g0==1])/np.sum(y_g0)
        minority_acc = np.sum(p_g1==y_g1)/len(y_g1)
        majority_acc = np.sum(p_g0==y_g0)/len(y_g0)

        odds_all.append(avg_odds)
        tpr_g1.append(minority_tpr*100)
        tpr_g0.append(majority_tpr*100)
        acc_g1.append(minority_acc*100)
        acc_g0.append(majority_acc*100)

    for series in (odds_all, tpr_g1, tpr_g0, acc_g1, acc_g0):
        print_spread(series)


def plot_forgetting_total_epochs_lineplot(model, input_size, outplot, ckptlosstype, use_cuda, protected_class, survey_year, state, ckptfldr, plotdata):
    """Print the cumulative share of training samples forgotten at least once, per group and label.

    ``model`` is a ``(we_init, arch_id, init_seed, shuffle_seed)`` tuple.
    Checkpoints 99..299 of that single run are evaluated on the training
    loader. Between consecutive checkpoints a sample counts as "forgotten"
    when its prediction changed and the previous prediction equalled the
    label. After the header ``x y1 y2 y3 y4`` one line is printed per epoch
    100..299: ``epoch min_pos min_neg maj_pos maj_neg`` (percentages), where
    ``min``/``maj`` are ``groups == 1`` / ``groups == 0``.

    ``outplot`` and ``plotdata`` are not used by this function.
    """
    we_init, arch_id, init_seed, shuffle_seed = model

    trainloader, validloader, testloader = get_folktables_dataset(train_valid_test_split=[0.7, 0.1, 0.2],
        protected_class=protected_class, survey_year=survey_year, states=[state], shuffle_seed=shuffle_seed, dataloadershuffle=False, dataset_type='acsemployment')

    # The figure is created but nothing is drawn on it in this function.
    fig, ax = plt.subplots(figsize=(8, 5))

    # Baseline predictions from checkpoint 99, the one just before the analysed range.
    loadfile = get_loadfile(ckptfldr, protected_class, ckptlosstype, we_init, arch_id + '_' + str(init_seed), shuffle_seed, 99)
    model = get_model(arch_id, input_size, id_to_arch_dict[arch_id], we_init, ckpt=loadfile, cuda=use_cuda, dataloader=trainloader)
    fscore, avg_odds, prevlabels, prevpreds, prevgroups = test_mlp(model, trainloader, cuda=use_cuda, fairness_criteria='avgodds', return_outputs=True)

    print("x y1 y2 y3 y4")
    # One boolean array per epoch; *_for_* = forgetting events, *_mem_* = (re)learning events.
    min_for_neg, maj_for_neg, min_for_pos, maj_for_pos = [], [], [], []
    min_mem_neg, maj_mem_neg, min_mem_pos, maj_mem_pos = [], [], [], []
    for epoch in tqdm(range(100, 300)):
        loadfile = get_loadfile(ckptfldr, protected_class, ckptlosstype, we_init, arch_id + '_' + str(init_seed), shuffle_seed, epoch)
        model = get_model(arch_id, input_size, id_to_arch_dict[arch_id], we_init, ckpt=loadfile, cuda=use_cuda, dataloader=trainloader)
        fscore, avg_odds, labels, preds, groups = test_mlp(model, trainloader, cuda=use_cuda, fairness_criteria='avgodds', return_outputs=True)

        # Forgetting: prediction changed since the previous epoch AND the previous prediction matched the label.
        # NOTE(review): comparing preds with prevpreds element-wise assumes the loader yields samples
        # in the same order on every pass (it is built with dataloadershuffle=False) -- confirm.
        min_for_pos.append(np.logical_and(labels[groups==1]==1,
                                            np.logical_and((preds[groups==1] != prevpreds[groups==1]), (labels[groups==1]==prevpreds[groups==1]))))
        maj_for_pos.append(np.logical_and(labels[groups==0]==1,
                                            np.logical_and((preds[groups==0] != prevpreds[groups==0]), (labels[groups==0]==prevpreds[groups==0]))))
        min_for_neg.append(np.logical_and(labels[groups==1]==0,
                                            np.logical_and((preds[groups==1] != prevpreds[groups==1]), (labels[groups==1]==prevpreds[groups==1]))))
        maj_for_neg.append(np.logical_and(labels[groups==0]==0,
                                            np.logical_and((preds[groups==0] != prevpreds[groups==0]), (labels[groups==0]==prevpreds[groups==0]))))

        # Memorisation: prediction changed since the previous epoch AND the new prediction matches the label.
        # These lists are only consumed by the commented-out period analysis further down.
        min_mem_pos.append(np.logical_and(labels[groups==1]==1,
                                            np.logical_and((preds[groups==1] != prevpreds[groups==1]), (labels[groups==1]==preds[groups==1]))))
        maj_mem_pos.append(np.logical_and(labels[groups==0]==1,
                                            np.logical_and((preds[groups==0] != prevpreds[groups==0]), (labels[groups==0]==preds[groups==0]))))
        min_mem_neg.append(np.logical_and(labels[groups==1]==0,
                                            np.logical_and((preds[groups==1] != prevpreds[groups==1]), (labels[groups==1]==preds[groups==1]))))
        maj_mem_neg.append(np.logical_and(labels[groups==0]==0,
                                            np.logical_and((preds[groups==0] != prevpreds[groups==0]), (labels[groups==0]==preds[groups==0]))))

        # The current epoch becomes the reference for the next comparison.
        prevlabels, prevpreds, prevgroups = labels, preds, groups

    def forgetting_ttl(forgetting_curve, start_index=0):
        # Entry i counts the samples flagged in at least one epoch of [start_index, i]:
        # the max over the epoch axis acts as a logical OR on the boolean flags.
        return np.array([np.sum(np.max(forgetting_curve[start_index:i+1], axis=0)) for i in range(len(forgetting_curve))])

    min_for_neg_ttl, maj_for_neg_ttl, min_for_pos_ttl, maj_for_pos_ttl = forgetting_ttl(min_for_neg), forgetting_ttl(maj_for_neg), forgetting_ttl(min_for_pos), forgetting_ttl(maj_for_pos)
    # Normalise to percentages using labels/groups from the last evaluated epoch.
    # NOTE(review): np.sum(groups[...]) and 1-groups as counts assume groups is a 0/1 array -- confirm.
    min_for_neg_ttl = 100*min_for_neg_ttl/np.sum(groups[labels==0])
    min_for_pos_ttl = 100*min_for_pos_ttl/np.sum(groups[labels==1])
    maj_for_neg_ttl = 100*maj_for_neg_ttl/np.sum(1-groups[labels==0])
    maj_for_pos_ttl = 100*maj_for_pos_ttl/np.sum(1-groups[labels==1])
    for i, epoch in enumerate(range(100, 300)):
        print(epoch, min_for_pos_ttl[i], min_for_neg_ttl[i], maj_for_pos_ttl[i], maj_for_neg_ttl[i])

    # Disabled analysis: histogram of the gaps between forgetting and memorisation events.
    # def get_period(curve_forget, curve_memorize):
    #     period = []
    #     for i in range(len(curve_forget[0])):
    #         curve1ind = np.where(np.array(curve_forget[:, i])==True)[0]
    #         curve3ind = np.where(np.array(curve_memorize[:, i])==True)[0]
    #         if len(curve1ind) > len(curve3ind):
    #             curve1ind = curve1ind[1:]
    #         if len(curve1ind) < len(curve3ind):
    #             curve3ind = curve3ind[:-1]
    #         if len(curve1ind)==0 or len(curve3ind)==0:
    #             continue
    #         if curve1ind[0] < curve3ind[0]:
    #             curve1ind = curve1ind[1:]
    #             curve3ind = curve3ind[:-1]
    #         period.extend(curve1ind - curve3ind)
    #     return np.array(period)
    #
    # min_per_neg = get_period(np.array(min_for_neg), np.array(min_mem_neg))
    # maj_per_neg = get_period(np.array(maj_for_neg), np.array(maj_mem_neg))
    # min_per_pos = get_period(np.array(min_for_pos), np.array(min_mem_pos))
    # maj_per_pos = get_period(np.array(maj_for_pos), np.array(maj_mem_pos))
    #
    # def data_to_hist(curve):
    #     heights, bins = np.histogram(curve, bins=sorted(list(set(curve))))
    #     heights = list(100*heights/np.sum(heights))
    #     while len(heights) < 200:
    #         heights.append(0.)
    #     return heights
    #
    # min_per_neg, maj_per_neg, min_per_pos, maj_per_pos = data_to_hist(min_per_neg), data_to_hist(maj_per_neg), data_to_hist(min_per_pos), data_to_hist(maj_per_pos)
    # for i, period in enumerate(range(200)):
    #     print(period, min_per_pos[i], min_per_neg[i], maj_per_pos[i], maj_per_neg[i])


def plot_immediate_data_order_boxplot(model, input_size, outplot, ckptlosstype, use_cuda, protected_class, survey_year, state, ckptfldr, plotdata):
    """Print the spread of test-set average odds before and after a short extra training pass.

    ``model`` is a ``(we_init, arch_id, init_seed, shuffle_seed)`` tuple. For
    seeds ``(s, s)`` with ``s`` in 0..49, a random subset of the checkpoints
    100..299 is loaded and evaluated, then trained for one epoch on the last
    20 batches of the training loader and evaluated again. Two lines are
    printed, each ``median p75 p25 max min``: first the "before" values, then
    the "after" values.

    ``outplot`` and ``plotdata`` are not used by this function.
    """
    we_init, arch_id, init_seed, shuffle_seed = model

    # Unlike the sibling functions, the split seed is fixed to 0 and loader shuffling is enabled.
    trainloader, validloader, testloader = get_folktables_dataset(train_valid_test_split=[0.7, 0.1, 0.2],
        protected_class=protected_class, survey_year=survey_year, states=[state], shuffle_seed=0, dataloadershuffle=True, dataset_type='acsincome')

    # The figure is created but nothing is drawn on it in this function.
    fig, ax = plt.subplots(figsize=(8, 5))

    # Initialisation and data-order seeds move together: (0, 0), (1, 1), ..., (49, 49).
    seed_tuple_list = []
    seed_tuple_list.append((0, 0))
    # for iseed in range(1, 50):
    #     seed_tuple_list.append((iseed, 0))
    # for sseed in range(1, 50):
    #     seed_tuple_list.append((0, sseed))
    for s in range(1, 50):
        seed_tuple_list.append((s, s))

    avg_odds_arr_before = []
    avg_odds_arr_after = []
    for (init_seed, shuffle_seed) in tqdm(seed_tuple_list):
        for epoch in range(100, 300):
            # randint is inclusive on both ends, so each checkpoint is sampled with probability 1/11.
            # The global `random` state is not seeded here, so the sampled subset differs between runs.
            if random.randint(0, 10)==2:
                loadfile = get_loadfile(ckptfldr, protected_class, ckptlosstype, we_init, arch_id + '_' + str(init_seed), shuffle_seed, epoch)
                model = get_model(arch_id, input_size, id_to_arch_dict[arch_id], we_init, ckpt=loadfile, cuda=use_cuda, dataloader=trainloader)
                fscore, avg_odds, labels, preds, groups = test_mlp(model, testloader, cuda=use_cuda, fairness_criteria='avgodds', return_outputs=True)
                avg_odds_arr_before.append(avg_odds)
                # print("Epoch : ", epoch)
                # print("Before : ", avg_odds)

                # torch.manual_seed(154)
                torch.manual_seed(172)
                # NOTE(review): this rebinds `trainloader` to a plain list. After the first sampled
                # checkpoint the batch order is frozen, and later get_model(...) calls receive a
                # list instead of the original loader -- confirm this is intended.
                trainloader = list(trainloader)
                # Fine-tune for one epoch on the final 20 batches only.
                model = train_mlp(model, trainloader[-20:], None, save_ite=False, cuda=use_cuda, losstype=ckptlosstype, epochs=1)
                fscore, avg_odds, labels, preds, groups = test_mlp(model, testloader, cuda=use_cuda, fairness_criteria='avgodds', return_outputs=True)
                avg_odds_arr_after.append(avg_odds)
                # print("After : ", avg_odds)

    # print(avg_odds_arr_before)
    # print(np.argmin(avg_odds_arr_before))
    # print(np.argmax(avg_odds_arr_before))
    # Column order: median, 75th percentile, 25th percentile, max, min.
    print(np.median(avg_odds_arr_before), np.percentile(avg_odds_arr_before, 75), np.percentile(avg_odds_arr_before, 25), np.max(avg_odds_arr_before), np.min(avg_odds_arr_before))
    print(np.median(avg_odds_arr_after), np.percentile(avg_odds_arr_after, 75), np.percentile(avg_odds_arr_after, 25), np.max(avg_odds_arr_after), np.min(avg_odds_arr_after))


def dataloader_to_inputlist(data_loader, return_all=False):
    """Flatten a batched loader into per-sample lists of NumPy arrays.

    ``data_loader`` must yield ``(inputs, labels, groups)`` triples of tensors.
    Each tensor is detached, moved to the CPU and converted to NumPy, and its
    leading axis is unrolled into the matching output list.

    Returns the list of inputs, or the tuple ``(inputs, labels, groups)`` of
    lists when ``return_all`` is true.
    """
    # One accumulator per field, in the same order the loader yields them.
    columns = ([], [], [])
    for batch in data_loader:
        inputs, labels, groups = batch
        for column, tensor in zip(columns, (inputs, labels, groups)):
            column.extend(tensor.detach().cpu().numpy())

    return columns if return_all else columns[0]

def get_ordered_dataloader(input_list, label_list, group_list, batch_size=16):
    """Build a non-shuffling DataLoader from consecutive chunks of samples.

    Each argument is a sequence of NumPy arrays; the chunks are concatenated
    along axis 0 in the order given, so the resulting loader replays the
    samples in exactly that order.

    Args:
        input_list: chunks of input features.
        label_list: chunks of labels, aligned with ``input_list``.
        group_list: chunks of group ids, aligned with ``input_list``.
        batch_size: samples per batch. Defaults to 16, the value that was
            previously hard-coded, so existing callers are unaffected.

    Returns:
        A ``DataLoader`` over a ``TensorDataset`` of (inputs, labels, groups)
        with ``shuffle=False`` and ``drop_last=False``.
    """
    # torch.from_numpy shares memory with the concatenated arrays (no extra copy).
    tensors = [
        torch.from_numpy(np.concatenate(chunks, axis=0))
        for chunks in (input_list, label_list, group_list)
    ]
    dataset = TensorDataset(*tensors)

    # shuffle=False is the point of this helper: the caller controls the order.
    return DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=False)

def plot_group_forgetting_lineplot(model, input_size, outplot, ckptlosstype, use_cuda, protected_class, survey_year, state, ckptfldr, plotdata):
    """Print how accuracy on a cluster depends on where its samples sit in the training order.

    ``model`` is a ``(we_init, arch_id, init_seed, shuffle_seed)`` tuple.
    Cluster assignments are read from ``temp_clusters1070.npy`` in the current
    working directory. For each chosen cluster, a growing tail of the
    (shuffled) cluster samples is placed at the end of the training order, the
    rest at the start, and all other samples in between. The checkpoint saved
    under index 299 is reloaded and trained for one epoch on that order.

    After a ``Chosen Cluster`` line and the header ``x y z``, each printed row
    is ``tail_size acc_on_head acc_on_tail`` with accuracies in percent.

    ``outplot`` and ``plotdata`` are not used by this function.
    """
    we_init, arch_id, init_seed, shuffle_seed = model

    trainloader, validloader, testloader = get_folktables_dataset(train_valid_test_split=[0.7, 0.1, 0.2],
        protected_class=protected_class, survey_year=survey_year, states=[state], shuffle_seed=shuffle_seed, dataloadershuffle=False)

    input_list, label_list, group_list = dataloader_to_inputlist(trainloader, return_all=True)
    input_list, label_list, group_list = np.array(input_list), np.array(label_list), np.array(group_list)

    # The file holds three arrays written back to back; the first two are read
    # only to advance the file position to the per-sample cluster labels.
    with open('temp_clusters1070.npy', 'rb') as f:
        _group_ids = np.load(f)
        _group_counts = np.load(f)
        kmeans_labels = np.load(f)

    chosen_cluster_arr = [800, 522, 378, 231]
    for chosen_cluster in chosen_cluster_arr:
        print("Chosen Cluster : ", chosen_cluster)

        # Compute each membership mask once instead of once per indexed array.
        in_cluster = kmeans_labels == chosen_cluster
        out_cluster = kmeans_labels != chosen_cluster
        chosen_input, chosen_label, chosen_group = input_list[in_cluster], label_list[in_cluster], group_list[in_cluster]
        middle_input, middle_label, middle_group = input_list[out_cluster], label_list[out_cluster], group_list[out_cluster]

        chosen_input, chosen_label, chosen_group = shuffle(chosen_input, chosen_label, chosen_group, random_state=0)
        print("x y z")
        # Step of 16 matches the default batch size of get_ordered_dataloader.
        for move_to_front in range(16, len(chosen_input), 16):
            # head = cluster samples trained first, tail = cluster samples trained last.
            head_input, tail_input = chosen_input[:-move_to_front], chosen_input[-move_to_front:]
            head_label, tail_label = chosen_label[:-move_to_front], chosen_label[-move_to_front:]
            head_group, tail_group = chosen_group[:-move_to_front], chosen_group[-move_to_front:]

            trainloadershifted = get_ordered_dataloader(
                [head_input, middle_input, tail_input],
                [head_label, middle_label, tail_label],
                [head_group, middle_group, tail_group])
            chosenloaderafter = get_ordered_dataloader([tail_input], [tail_label], [tail_group])
            chosenloaderbefore = get_ordered_dataloader([head_input], [head_label], [head_group])

            # Reload the checkpoint every time so each ordering starts from the same weights.
            loadfile = get_loadfile(ckptfldr, protected_class, ckptlosstype, we_init, arch_id + '_' + str(init_seed), shuffle_seed, 299)
            model = get_model(arch_id, input_size, id_to_arch_dict[arch_id], we_init, ckpt=loadfile, cuda=use_cuda, dataloader=None)
            model = train_mlp(model, trainloadershifted, None, save_ite=False, cuda=use_cuda, losstype=ckptlosstype, epochs=1)

            fscore, labels, preds, groups = test_mlp(model, chosenloaderbefore, cuda=use_cuda, fairness_criteria=None, return_outputs=True)
            acc_before = np.sum(preds==labels)/len(labels)
            fscore, labels, preds, groups = test_mlp(model, chosenloaderafter, cuda=use_cuda, fairness_criteria=None, return_outputs=True)
            acc_after = np.sum(preds==labels)/len(labels)

            print(move_to_front, acc_before*100, acc_after*100)


def plot_fairorder_single_epoch(model, input_size, outplot, ckptlosstype, use_cuda, protected_class, survey_year, state, ckptfldr, plotdata):
    """Print F-score and average-odds spread after one extra epoch on a re-balanced loader.

    ``model`` is a ``(we_init, arch_id, init_seed, shuffle_seed)`` tuple. For
    seeds ``(s, s)`` with ``s`` in 0..49 the checkpoint saved under index 299
    is loaded, trained for one epoch with ``losstype='fairce'`` on a loader
    produced by ``make_trainloader_ratio`` with ratios
    ``[0.1, 0.3, 0.3, 0.3]``, and evaluated on the test loader.

    Two lines are printed, each ``p25 median p75 min max``: F-score first,
    average odds second.

    ``outplot`` and ``plotdata`` are not used by this function.
    """
    we_init, arch_id, init_seed, shuffle_seed = model

    trainloader, validloader, testloader = get_folktables_dataset(train_valid_test_split=[0.7, 0.1, 0.2],
        protected_class=protected_class, survey_year=survey_year, states=[state], shuffle_seed=shuffle_seed, dataloadershuffle=False, dataset_type='acsincome')

    # Initialisation and data-order seeds move together: (0, 0), (1, 1), ..., (49, 49).
    seed_tuple_list = [(s, s) for s in range(50)]

    fscore_arr = []
    avg_odds_arr = []
    for (init_seed, shuffle_seed) in tqdm(seed_tuple_list):
        loadfile = get_loadfile(ckptfldr, protected_class, ckptlosstype, we_init, arch_id + '_' + str(init_seed), shuffle_seed, 299)
        model = get_model(arch_id, input_size, id_to_arch_dict[arch_id], we_init, ckpt=loadfile, cuda=use_cuda, dataloader=None)

        # Always derive the re-balanced loader from the ORIGINAL training loader.
        # Previously the result overwrote `trainloader`, so every seed after the
        # first re-balanced the previous seed's already re-balanced loader.
        # NOTE(review): make_trainloader_ratio is defined outside this view; the
        # ratio list presumably maps to four group/label cells -- confirm.
        ratio_trainloader = make_trainloader_ratio(trainloader, [0.1, 0.3, 0.3, 0.3])
        model = train_mlp(model, ratio_trainloader, None, save_ite=False, cuda=use_cuda, losstype='fairce', epochs=1)
        fscore, avg_odds, labels, preds, groups = test_mlp(model, testloader, cuda=use_cuda, fairness_criteria='avgodds', return_outputs=True)
        fscore_arr.append(fscore)
        avg_odds_arr.append(avg_odds)

    # Column order: 25th percentile, median, 75th percentile, min, max.
    print(np.percentile(fscore_arr, 25), np.median(fscore_arr), np.percentile(fscore_arr, 75), np.min(fscore_arr), np.max(fscore_arr))
    print(np.percentile(avg_odds_arr, 25), np.median(avg_odds_arr), np.percentile(avg_odds_arr, 75), np.min(avg_odds_arr), np.max(avg_odds_arr))


def get_pareto_front(Xs, Ys, maxX=True, maxY=True):
    """Extract a Pareto front from paired coordinates.

    Points are sorted by ``(x, y)`` (descending when ``maxX`` is true,
    ascending otherwise). Starting from the first point, a point joins the
    front when its ``y`` is at least (``maxY``) or at most (``not maxY``) the
    ``y`` of the last point already on the front. Ties are kept.

    Args:
        Xs: sequence of x-coordinates (list or 1-D array).
        Ys: sequence of y-coordinates, indexable for every position of ``Xs``.
        maxX: whether larger ``x`` is better.
        maxY: whether larger ``y`` is better.

    Returns:
        ``(front_xs, front_ys)`` as two lists in sorted order. Both are empty
        when ``Xs`` is empty (this used to raise ``IndexError``).
    """
    # len() rather than truthiness: callers pass NumPy arrays, whose truth
    # value is ambiguous for more than one element.
    if len(Xs) == 0:
        return [], []

    ordered = sorted(([x, Ys[i]] for i, x in enumerate(Xs)), reverse=maxX)

    front = [ordered[0]]
    for pair in ordered[1:]:
        last_y = front[-1][1]
        on_front = pair[1] >= last_y if maxY else pair[1] <= last_y
        if on_front:
            front.append(pair)

    return [point[0] for point in front], [point[1] for point in front]

def pareto_distance_score(Xp, Yp, Xo, Yo):
    """Return the largest nearest-neighbour distance from set P to set O.

    For every point ``(Xp[i], Yp[i])`` the Euclidean distance to its closest
    point in ``(Xo, Yo)`` is taken, capped at ``1e10`` (also the value used
    when O is empty). The maximum of those per-point distances is returned.
    """
    nearest = []
    for px, py in zip(Xp, Yp):
        gaps = [sqrt((px - ox)**2 + (py - oy)**2) for ox, oy in zip(Xo, Yo)]
        # The leading 1e10 acts as the cap and as the fallback for an empty O.
        nearest.append(min([1e10] + gaps))
    return np.max(nearest)

def figure_sampling_heatmap(model, input_size, outplot, ckptlosstype, use_cuda, protected_class, survey_year, state, ckptfldr, plotdata):
    """Save a heatmap of how well seed/epoch subsets recover the full Pareto front.

    ``model`` is a ``(we_init, arch_id, init_seed, shuffle_seed)`` tuple.
    Checkpoints 250..299 of 200 runs (init seed == shuffle seed) are evaluated
    on the test loader. The Pareto front over all (fairness, F-score) points
    is the reference. For 50 random permutations of the runs, and for every
    subset made of the first ``k`` runs and the last ``m`` epochs, the
    symmetric max-min distance between the subset front and the reference
    front is computed. The mean over permutations is drawn with ``imshow``
    and written to ``outplot``.

    ``plotdata`` is not used by this function.
    """
    we_init, arch_id, init_seed, shuffle_seed = model

    trainloader, validloader, testloader = get_folktables_dataset(train_valid_test_split=[0.7, 0.1, 0.2],
        protected_class=protected_class, survey_year=survey_year, states=[state], shuffle_seed=shuffle_seed, dataset_type='acsemployment')

    end_all_seed = 200
    end_epoch = 50
    epoch_offset = 250

    # fscore_mat[seed][epoch] / eopp_mat[seed][epoch] hold the test metrics per checkpoint.
    fscore_mat = []
    eopp_mat = []
    for all_seed in tqdm(list(range(end_all_seed))):
        fscore_arr = []
        eopp_arr = []
        for epoch in range(end_epoch):
            loadfile = get_loadfile(ckptfldr, protected_class, ckptlosstype, we_init, arch_id + '_' + str(all_seed), all_seed, epoch + epoch_offset)
            model = get_model(arch_id, input_size, id_to_arch_dict[arch_id], we_init, ckpt=loadfile, cuda=use_cuda, dataloader=trainloader)
            fscore, eopp_per = test_mlp(model, testloader, cuda=use_cuda, fairness_criteria='avgodds')

            fscore_arr.append(fscore)
            eopp_arr.append(eopp_per)

        fscore_mat.append(fscore_arr)
        eopp_mat.append(eopp_arr)

    # Reference front over every evaluated checkpoint (minimise fairness gap, maximise F-score).
    eopp_pareto, fscore_pareto = get_pareto_front(np.array(eopp_mat).flatten(), np.array(fscore_mat).flatten(), maxX=False)

    overall_heatmap = []
    for repeats in tqdm(range(50)):
        # Permute the runs jointly so the metric rows stay paired.
        temp = list(zip(fscore_mat, eopp_mat))
        random.shuffle(temp)
        fscore_mat, eopp_mat = zip(*temp)

        # Convert once per permutation; the original rebuilt both arrays for
        # every heatmap cell even though they are invariant inside the nested loops.
        fscore_np = np.array(fscore_mat)
        eopp_np = np.array(eopp_mat)

        heatmap_mat = []
        # NOTE(review): the seed axis is bounded by end_epoch (50), not end_all_seed (200),
        # which yields a 50x50 map -- confirm that is the intended subset size.
        for all_seed in range(end_epoch):
            heatmap_arr = []
            for epoch in range(end_epoch):
                # First (all_seed + 1) runs, last (epoch + 1) epochs.
                eopp_subset = eopp_np[:(all_seed+1), (-1*epoch-1):].flatten()
                fscore_subset = fscore_np[:(all_seed+1), (-1*epoch-1):].flatten()

                eopp_subset, fscore_subset = get_pareto_front(eopp_subset, fscore_subset, maxX=False)

                # Symmetric score: the worse of the two directed max-min distances.
                ds = max(pareto_distance_score(eopp_pareto, fscore_pareto, eopp_subset, fscore_subset),
                         pareto_distance_score(eopp_subset, fscore_subset, eopp_pareto, fscore_pareto))
                heatmap_arr.append(ds)
            heatmap_mat.append(heatmap_arr)
        overall_heatmap.append(heatmap_mat)

    overall_heatmap_avg = np.mean(overall_heatmap, axis=0)

    font = {'size'   : 22}
    plt.rc('font', **font)

    fig, ax = plt.subplots()
    fmt = lambda x, pos: '{:.1f}'.format(x)
    im = ax.imshow(overall_heatmap_avg, cmap="YlGn")
    cbar = ax.figure.colorbar(im, ax=ax, format=FuncFormatter(fmt))
    cbar.ax.set_ylabel("Hausdorff Distance", rotation=-90, va="bottom")

    ax.set_xlabel("Epochs")
    ax.set_ylabel("Seeds")
    ax.invert_yaxis()

    plt.tight_layout()
    plt.savefig(outplot)
    plt.clf()


def figure_sampling_violin(model, input_size, outplot, ckptlosstype, use_cuda, protected_class, survey_year, state, ckptfldr, plotdata):
    """Save a violin plot comparing average odds across seeds with average odds across epochs.

    ``model`` is a ``(we_init, arch_id, init_seed, shuffle_seed)`` tuple.
    Six samples of test-set average odds are collected, in this order:
    200 runs (init seed == shuffle seed) at checkpoints 199, 249 and 299,
    then checkpoints 100..299 of the runs with seed 0, 1 and 2. They are
    drawn as six violins and the figure is written to ``outplot``.

    ``plotdata`` is not used by this function.
    """
    we_init, arch_id, _, data_seed = model

    trainloader, _, testloader = get_folktables_dataset(
        train_valid_test_split=[0.7, 0.1, 0.2],
        protected_class=protected_class,
        survey_year=survey_year,
        states=[state],
        shuffle_seed=data_seed,
        dataset_type='acsincome',
    )

    seed_range = 200
    epoch_start = 100
    epoch_end = 300

    def avg_odds_at(seed, epoch):
        # Load the checkpoint of run (seed, seed) at `epoch` and return its test-set average odds.
        ckpt_path = get_loadfile(ckptfldr, protected_class, ckptlosstype, we_init, arch_id + '_' + str(seed), seed, epoch)
        net = get_model(arch_id, input_size, id_to_arch_dict[arch_id], we_init, ckpt=ckpt_path, cuda=use_cuda, dataloader=trainloader)
        _, odds = test_mlp(net, testloader, cuda=use_cuda, fairness_criteria='avgodds')
        return odds

    # (violin label, checkpoint index) for the across-seed samples.
    seed_sweeps = [
        ('Across Seeds \n(Epoch : 200)', 199),
        ('Across Seeds \n(Epoch : 250)', 249),
        ('Across Seeds \n(Epoch : 300)', 299),
    ]
    # (violin label, seed) for the across-epoch samples.
    epoch_sweeps = [
        ('Across Epochs \n(Seed : 0)', 0),
        ('Across Epochs \n(Seed : 1)', 1),
        ('Across Epochs \n(Seed : 2)', 2),
    ]

    series = []
    for label, ckpt_epoch in seed_sweeps:
        values = [avg_odds_at(seed, ckpt_epoch) for seed in tqdm(range(seed_range))]
        series.append((label, values))
    for label, fixed_seed in epoch_sweeps:
        values = [avg_odds_at(fixed_seed, epoch) for epoch in tqdm(range(epoch_start, epoch_end))]
        series.append((label, values))

    # Long-format columns: one label per measured value.
    all_labels = np.concatenate([[label] * len(values) for label, values in series])
    all_avgodds = np.concatenate([values for _, values in series])

    plt.rc('font', **{'size'   : 22})
    df = pd.DataFrame(data={'Average Odds': all_avgodds, 'Sampling': all_labels})
    sns.set(rc={'figure.figsize':(11,7), 'axes.facecolor':'white', 'figure.facecolor':'white', 'axes.linewidth': 1, 'axes.edgecolor':'black'})

    # One colour per sampling family: blue across seeds, red across epochs.
    my_pal = {label: '#1aa3ff' for label, _ in seed_sweeps}
    my_pal.update({label: '#ff4d4d' for label, _ in epoch_sweeps})

    ax = sns.violinplot(x="Sampling", y="Average Odds", data=df, inner='quartile', palette=my_pal)
    for spine in ax.spines.values():
        spine.set_visible(True)

    ax.get_figure().savefig(outplot)


def figure_sampling_kstest(model, input_size, outplot, ckptlosstype, use_cuda, protected_class, survey_year, state, ckptfldr, plotdata):
    """Print a KS test comparing across-seed and across-epoch average odds.

    Two samples of the 'avgodds' score returned by ``test_mlp`` on the test
    split are collected:

    * one value per seed for seeds 0..199, each taken at checkpoint index 299;
    * one value per checkpoint index 100..299 of the seed-0 run.

    Both samples are passed to ``scipy.stats.kstest`` and the result is printed.

    Args:
        model: 4-item sequence ``(we_init, arch_id, init_seed, shuffle_seed)``;
            ``init_seed`` is not used by this function.
        input_size: forwarded to ``get_model``.
        outplot: unused; kept so every experiment entry point shares a signature.
        ckptlosstype: loss-type tag forwarded to ``get_loadfile``.
        use_cuda: forwarded as ``cuda=`` to ``get_model`` and ``test_mlp``.
        protected_class, survey_year, state: dataset selectors forwarded to
            ``get_folktables_dataset`` (``state`` is wrapped in a list).
        ckptfldr: checkpoint folder forwarded to ``get_loadfile``.
        plotdata: unused.
    """
    we_init, arch_id, _, shuffle_seed = model

    trainloader, _, testloader = get_folktables_dataset(train_valid_test_split=[0.7, 0.1, 0.2],
        protected_class=protected_class, survey_year=survey_year, states=[state], shuffle_seed=shuffle_seed)

    def avgodds_of(seed, epoch):
        # Load the checkpoint of run `seed` at `epoch` and score it on the test split.
        ckpt = get_loadfile(ckptfldr, protected_class, ckptlosstype, we_init, arch_id + '_' + str(seed), seed, epoch)
        net = get_model(arch_id, input_size, id_to_arch_dict[arch_id], we_init, ckpt=ckpt, cuda=use_cuda, dataloader=trainloader)
        _, score = test_mlp(net, testloader, cuda=use_cuda, fairness_criteria='avgodds')
        return score

    # Sample 1: final checkpoint of 200 different seeds.
    across_seeds = [avgodds_of(seed, 299) for seed in range(200)]
    # Sample 2: 200 consecutive checkpoints of the seed-0 run.
    across_epochs = [avgodds_of(0, epoch) for epoch in range(100, 300)]

    # Passing a second sample (rather than a CDF) makes this a two-sample test.
    print(kstest(across_seeds, across_epochs))


def figure_forgetting_location_kstest(model, input_size, outplot, ckptlosstype, use_cuda, protected_class, survey_year, state, ckptfldr, plotdata):
    """Test whether forgotten examples are uniformly placed within a shuffled epoch.

    For checkpoint indices 100..299 of one run, a training example counts as
    *forgotten* at an epoch when its prediction differs from the previous
    checkpoint's prediction and that previous prediction equalled the label.
    For every such event the example's relative position (index divided by the
    number of training inputs) in a shuffled pass over the training data is
    recorded. The positions are compared against the uniform distribution with
    a one-sample KS test (result printed) and drawn as a 50-bin histogram that
    is saved to ``outplot``.

    Args:
        model: 4-item sequence ``(we_init, arch_id, init_seed, shuffle_seed)``.
        input_size: forwarded to ``get_model``.
        outplot: path the histogram image is written to.
        ckptlosstype: loss-type tag forwarded to ``get_loadfile``.
        use_cuda: forwarded as ``cuda=`` to ``get_model`` and ``test_mlp``.
        protected_class, survey_year, state: dataset selectors forwarded to
            ``get_folktables_dataset`` (``state`` is wrapped in a list).
        ckptfldr: checkpoint folder forwarded to ``get_loadfile``.
        plotdata: unused.
    """
    we_init, arch_id, init_seed, shuffle_seed = model
    ckpt_name = arch_id + '_' + str(init_seed)

    # Unshuffled loader: gives every training example a stable row index.
    trainloader, _, _ = get_folktables_dataset(train_valid_test_split=[0.7, 0.1, 0.2],
        protected_class=protected_class, survey_year=survey_year, states=[state], shuffle_seed=shuffle_seed, dataloadershuffle=False)

    start_epoch = 100
    end_epoch = 300

    inputlist = dataloader_to_inputlist(trainloader)

    def outputs_at(epoch):
        # Labels and predictions of the checkpoint saved at `epoch`, in loader order.
        loadfile = get_loadfile(ckptfldr, protected_class, ckptlosstype, we_init, ckpt_name, shuffle_seed, epoch)
        net = get_model(arch_id, input_size, id_to_arch_dict[arch_id], we_init, ckpt=loadfile, cuda=use_cuda, dataloader=trainloader)
        _, _, labels, preds, _ = test_mlp(net, trainloader, cuda=use_cuda, fairness_criteria='avgodds', return_outputs=True)
        return labels, preds

    _, prevpreds = outputs_at(start_epoch - 1)

    forgotten, memorized = [], []
    for epoch in tqdm(range(start_epoch, end_epoch)):
        labels, preds = outputs_at(epoch)

        flipped = preds != prevpreds
        # Forgotten: was right at the previous checkpoint, prediction changed.
        forgotten.append(np.logical_and(flipped, labels == prevpreds))
        # Memorized: prediction changed and is right now.
        memorized.append(np.logical_and(flipped, labels == preds))

        prevpreds = preds

    # Rows index examples, columns index (epoch - start_epoch).
    forgotten, memorized = np.array(forgotten).transpose(), np.array(memorized).transpose()

    # Map each input vector to its row in `forgotten`. Identical input vectors
    # collapse onto the last row seen; rows are limited to those that have
    # per-epoch outcomes so later lookups stay in bounds.
    input_index_dict = {tuple(inp): row for row, inp in zip(range(len(forgotten)), inputlist)}

    trainloadershuffle, _, _ = get_folktables_dataset(train_valid_test_split=[0.7, 0.1, 0.2],
        protected_class=protected_class, survey_year=survey_year, states=[state], shuffle_seed=shuffle_seed, dataloadershuffle=True)

    n_examples = len(inputlist)
    location = []
    for epoch in tqdm(range(start_epoch, end_epoch)):
        # NOTE(review): presumably reproduces the shuffle order the training run
        # used for this epoch -- confirm against how train_mlp seeds each epoch.
        torch.manual_seed(epoch-1)
        shuffled_inputs = dataloader_to_inputlist(trainloadershuffle)
        epoch_col = epoch - start_epoch

        for position, inp in enumerate(shuffled_inputs):
            row = input_index_dict.get(tuple(inp))
            if row is None:
                continue
            # Use `memorized[row, epoch_col]` here to study memorization instead.
            if forgotten[row, epoch_col]:
                location.append(position / n_examples)

    kstest_score = kstest(location, uniform.cdf)
    print(kstest_score)
    plt.hist(location, bins=50)
    plt.xlabel('Data Location When Forgotten')
    plt.ylabel('Number of Examples')
    plt.tight_layout()
    # Honour the caller-supplied output path instead of a fixed file name.
    plt.savefig(outplot)


def single_epoch_data_order(model, input_size, outplot, ckptlosstype, use_cuda, protected_class, survey_year, state, ckptfldr, plotdata):
    """Fine-tune a saved checkpoint one epoch at a time and print test average odds.

    The checkpoint at index 299 of the requested run is loaded, its 'avgodds'
    score on the test split is printed, and then the model is trained for 20
    rounds of a single epoch each, printing the score after every round.

    Args:
        model: 4-item sequence ``(we_init, arch_id, init_seed, shuffle_seed)``.
        input_size: forwarded to ``get_model``.
        outplot: unused.
        ckptlosstype: loss-type tag used both to locate the checkpoint and as
            ``losstype=`` for ``train_mlp``.
        use_cuda: forwarded as ``cuda=`` to the model/train/test helpers.
        protected_class, survey_year, state: dataset selectors forwarded to
            ``get_folktables_dataset`` (``state`` is wrapped in a list).
        ckptfldr: checkpoint folder forwarded to ``get_loadfile``.
        plotdata: unused.
    """
    we_init, arch_id, init_seed, shuffle_seed = model

    ckpt_path = get_loadfile(ckptfldr, protected_class, ckptlosstype, we_init, arch_id + '_' + str(init_seed), shuffle_seed, 299)
    net = get_model(arch_id, input_size, id_to_arch_dict[arch_id], we_init, ckpt=ckpt_path, cuda=use_cuda, dataloader=None)

    # The loaded model is handed to the dataset builder together with
    # fairbatch=True; the data split uses a fixed shuffle seed of 500.
    trainloader, _, testloader = get_folktables_dataset(train_valid_test_split=[0.7, 0.1, 0.2],
        protected_class=protected_class, survey_year=survey_year, states=[state], shuffle_seed=500, dataloadershuffle=False, dataset_type='acsincome',
        fairbatch=True, model=net)

    def report(tag, current):
        # Score `current` on the test split and print its average odds.
        _, avg_odds, _, _, _ = test_mlp(current, testloader, cuda=use_cuda, fairness_criteria='avgodds', return_outputs=True)
        print(tag, avg_odds)

    report("Average Odds Before : ", net)

    for _ in range(20):
        net = train_mlp(net, trainloader, None, save_ite=False, cuda=use_cuda, losstype=ckptlosstype, epochs=1)
        report("Average Odds After : ", net)


def make_trainloader_ratio(trainloader, ratioarr):
    """Rebuild a training loader so its last batches follow fixed subgroup ratios.

    All ``(inputs, labels, groups)`` batches of ``trainloader`` are gathered and
    split into four subgroups: group 1 / label 1, group 1 / label 0,
    group 0 / label 1 and group 0 / label 0. Batch-sized chunks are then
    assembled by taking ``ratio * batch_size`` consecutive examples from each
    subgroup (bounds truncated with ``int``). As soon as one subgroup cannot
    supply its share, everything not yet consumed is pooled, shuffled with a
    fixed seed, appended, and assembly stops. Finally the whole sequence is
    reversed, so the ratio-controlled chunks come last in iteration order.

    Args:
        trainloader: iterable yielding ``(inputs, labels, groups)`` tensors.
        ratioarr: four fractions, in the order
            ``[group1_pos, group1_neg, group0_pos, group0_neg]``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the reordered data with
        ``shuffle=False`` and ``drop_last=True``, batched with the batch size
        of ``trainloader``.

    Raises:
        ValueError: if ``trainloader`` yields no batches.
    """
    ratios = (ratioarr[0], ratioarr[1], ratioarr[2], ratioarr[3])

    all_inputs, all_labels, all_groups = [], [], []
    batch_size = 0
    for inputs, labels, groups in trainloader:
        inputs, labels, groups = inputs.cpu().detach().numpy(), labels.cpu().detach().numpy(), groups.cpu().detach().numpy()
        # A loader that keeps its ragged final batch ends on a short batch, so
        # the true batch size is the largest one seen, not the last one.
        batch_size = max(batch_size, len(inputs))
        all_inputs.extend(inputs)
        all_labels.extend(labels)
        all_groups.extend(groups)

    if batch_size == 0:
        raise ValueError("trainloader yielded no batches; cannot infer a batch size")

    all_inputs, all_labels, all_groups = np.array(all_inputs), np.array(all_labels), np.array(all_groups)

    def select(group_value, label_value):
        # (inputs, labels, groups) of one subgroup, in original order.
        in_group = all_groups == group_value
        g_inputs, g_labels, g_groups = all_inputs[in_group], all_labels[in_group], all_groups[in_group]
        has_label = g_labels == label_value
        return g_inputs[has_label], g_labels[has_label], g_groups[has_label]

    # Same order as `ratios`: group1/pos, group1/neg, group0/pos, group0/neg.
    subgroups = [select(1, 1), select(1, 0), select(0, 1), select(0, 0)]

    # Index 0/1/2 of each subgroup tuple feeds the matching output list.
    collected = ([], [], [])
    for start in range(0, len(all_inputs), batch_size):
        lows = [int(r * start) for r in ratios]
        highs = [int(r * (start + batch_size)) for r in ratios]

        if any(high > len(sub[0]) for high, sub in zip(highs, subgroups)):
            # At least one subgroup is exhausted: pool whatever is left of all
            # four, shuffle it deterministically, append it and stop.
            leftovers = [
                np.concatenate([sub[field][low:] for low, sub in zip(lows, subgroups)], axis=0)
                for field in range(3)
            ]
            leftovers = shuffle(leftovers[0], leftovers[1], leftovers[2], random_state=0)
            for sink, values in zip(collected, leftovers):
                sink.extend(values)
            break

        for field, sink in enumerate(collected):
            for low, high, sub in zip(lows, highs, subgroups):
                sink.extend(sub[field][low:high])

    # Reverse so that the ratio-controlled chunks are iterated last.
    new_inputs, new_labels, new_groups = (np.array(values[::-1]) for values in collected)
    train_dataset = torch.utils.data.TensorDataset(
        torch.from_numpy(new_inputs),
        torch.from_numpy(new_labels),
        torch.from_numpy(new_groups),
    )

    return torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

def table_group_accuracy_manipulate(model, input_size, outplot, ckptlosstype, use_cuda, protected_class, survey_year, state, ckptfldr, plotdata):
    """Print per-subgroup test accuracy after one epoch on ratio-controlled data.

    For each positive-to-negative ratio in ``[0.1, 0.2, 0.5, 1, 2, 5, 10]`` the
    training data is reordered with ``make_trainloader_ratio`` so that group 0
    is sampled with that ratio (group 1 keeps fixed fractions). Every one of
    50 checkpoints (seed pairs ``(s, s)`` for ``s`` in 0..49, checkpoint index
    299) is trained for a single epoch on the reordered data and evaluated on
    the test split. One line is printed per ratio: the ratio followed by
    ``mean, mean + std, mean - std`` (in percent, across checkpoints) for
    group 1 / label 1, group 1 / label 0, group 0 / label 1, group 0 / label 0
    and overall accuracy.

    The reordered loader is always derived from the untouched base loader, so
    every checkpoint at a given ratio sees exactly the same data order and no
    ratio inherits the reordering of a previous one.

    Args:
        model: 4-item sequence ``(we_init, arch_id, init_seed, shuffle_seed)``;
            only ``we_init`` and ``arch_id`` are used, seeds come from the
            internal seed list.
        input_size: forwarded to ``get_model``.
        outplot: unused.
        ckptlosstype: loss-type tag used to locate checkpoints and as
            ``losstype=`` for ``train_mlp``.
        use_cuda: forwarded as ``cuda=`` to the model/train/test helpers.
        protected_class, survey_year, state: dataset selectors forwarded to
            ``get_folktables_dataset`` (``state`` is wrapped in a list).
        ckptfldr: checkpoint folder forwarded to ``get_loadfile``.
        plotdata: unused.
    """
    we_init, arch_id, _, _ = model

    base_trainloader, _, testloader = get_folktables_dataset(train_valid_test_split=[0.7, 0.1, 0.2],
        protected_class=protected_class, survey_year=survey_year, states=[state], shuffle_seed=10, dataloadershuffle=False, dataset_type='acsemployment')

    def subgroup_accuracy(correct, labels, groups, group_value, label_value):
        # Fraction of correct predictions among one (group, label) subgroup.
        in_group = groups == group_value
        hits = correct[in_group][labels[in_group] == label_value]
        return np.sum(hits) / len(hits)

    # Matching init/shuffle seeds 0..49.
    seed_tuple_list = [(s, s) for s in range(50)]

    # Printing order of the subgroup columns: (group value, label value).
    subgroup_keys = ((1, 1), (1, 0), (0, 1), (0, 0))

    # Subgroup fractions used by earlier variants of this experiment, as
    # [group1_pos, group1_neg, group0_pos, group0_neg]:
    #   ACSIncome      [0.17, 0.31, 0.24, 0.28]
    #   CelebA         [0.31, 0.27, 0.17, 0.25]
    #   ACSEmployment  [0.22, 0.29, 0.24, 0.25]
    # Below, the ACSEmployment group-1 fractions are kept fixed and the
    # group-0 share (0.49 in total) is split according to `ratio`.
    for ratio in [0.1, 0.2, 0.5, 1, 2, 5, 10]:
        maj_neg = 0.49/(1+ratio)
        maj_pos = ratio*maj_neg
        # Built once per ratio from the base loader: the reordering does not
        # depend on the checkpoint seed.
        ratio_trainloader = make_trainloader_ratio(base_trainloader, [0.22, 0.29, maj_pos, maj_neg])

        # Four subgroup accuracy lists followed by the overall accuracy list.
        accuracies = [[] for _ in range(len(subgroup_keys) + 1)]
        for (init_seed, shuffle_seed) in tqdm(seed_tuple_list):
            loadfile = get_loadfile(ckptfldr, protected_class, ckptlosstype, we_init, arch_id + '_' + str(init_seed), shuffle_seed, 299)
            net = get_model(arch_id, input_size, id_to_arch_dict[arch_id], we_init, ckpt=loadfile, cuda=use_cuda, dataloader=base_trainloader)

            net = train_mlp(net, ratio_trainloader, None, save_ite=False, cuda=use_cuda, losstype=ckptlosstype, epochs=1)
            _, _, labels, preds, groups = test_mlp(net, testloader, cuda=use_cuda, fairness_criteria='avgodds', return_outputs=True)

            correct = labels == preds
            for bucket, (group_value, label_value) in zip(accuracies, subgroup_keys):
                bucket.append(subgroup_accuracy(correct, labels, groups, group_value, label_value))
            accuracies[-1].append(np.sum(correct)/len(correct))

        row = []
        for values in accuracies:
            avg, std = 100*np.mean(values), 100*np.std(values)
            row.extend((avg, avg + std, avg - std))

        print(ratio, *row)


# Command-line entry point: parse the options, select the visible GPUs and run
# the experiment function named by --mode.
parser = argparse.ArgumentParser()
parser.add_argument("--mode", default="scatter_variance", help="Set Experiment Mode")
parser.add_argument("--arch", default="mlp_64", help="Model Architecture")
parser.add_argument("--loss", default="ce", help="Loss Type for Training/Finetuning")
parser.add_argument("--pc", default="sex", help="Protected Class")
parser.add_argument("--year", type=int, default=2018, help="Dataset Year")
parser.add_argument("--state", default="CA", help="Dataset State")
parser.add_argument("--ckptfldr", default="folktables2018CAstability", help="Folder for Saving Model Files")
parser.add_argument("--ckptloss", default=None, help="Loss Type for Training")
parser.add_argument("--finetuneloss", default=None, help="Loss Type for Finetuning")
parser.add_argument("--shuffle_seed", type=int, default=0, help="Seed for mini-batch shuffling")
parser.add_argument("--init_seed", type=int, default=0, help="Seed for weight initialization")
parser.add_argument("--num_feat", type=int, default=10, help="Number of Input Features in Task")
parser.add_argument("--epochs", type=int, default=150, help="Number of Training Epochs")
parser.add_argument("--lr", type=float, default=0.001, help="Learning Rate for Training")
parser.add_argument("--outplot", default="noname.png", help="Name of Output Graph (If any)")
parser.add_argument("--gpus", default="0,1", help="GPU Device ID to use")
parser.add_argument("--cuda", action="store_true", help="Use CUDA")
parser.add_argument("--plotdata", default='train', help="Dataset to Use During Plotting Tools")

args = parser.parse_args()

os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

# Every --mode value is the name of a module-level function taking
# (model, input_size, outplot, ckptlosstype, use_cuda, protected_class,
#  survey_year, state, ckptfldr, plotdata).
EXPERIMENT_MODES = (
    'plot_decouple_boxplot',
    'plot_decouple_lineplot',
    'plot_variance_epochs_boxplot',
    'plot_forgetting_total_epochs_lineplot',
    'plot_immediate_data_order_boxplot',
    'plot_group_forgetting_lineplot',
    'plot_fairorder_single_epoch',
    'figure_sampling_heatmap',
    'figure_sampling_violin',
    'figure_sampling_kstest',
    'figure_forgetting_location_kstest',
    'single_epoch_data_order',
    'table_group_accuracy_manipulate',
)

if args.mode in EXPERIMENT_MODES:
    # NOTE(review): the weight init and architecture are fixed here; --arch is
    # parsed but not consulted by this dispatch -- confirm that is intended.
    model = ['kaiming_normal', 'mlp_64', args.init_seed, args.shuffle_seed]
    # Resolved by name only for the selected mode, so the other modes'
    # functions are never touched.
    run_experiment = globals()[args.mode]
    run_experiment(model, args.num_feat, args.outplot, args.ckptloss, args.cuda, args.pc, args.year, args.state, args.ckptfldr, args.plotdata)
else:
    # Previously an unrecognised mode exited silently without running anything.
    print("Unknown --mode '{}': nothing was run. Available modes: {}".format(args.mode, ", ".join(EXPERIMENT_MODES)))
