import argparse
import torch
import sys
import math
from keras.utils import to_categorical
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
import pickle
import sklearn
from sklearn.preprocessing import LabelEncoder
from sklearn.datasets import fetch_openml
from sklearn import datasets
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
from torch.utils.tensorboard import SummaryWriter
from Dataloaders import TablePandasDataset, TablePandasDataset_Noise_Estimate
from Misc import privatization, proj_simplex, compute_pi, get_weight_dict
from network.networks import *

################################################################################################################################
### Auto
### gender as the sensitive attribute

### load the auto dataset
def load_auto(seed, privacy, ratio, subset):
    """Read the auto CSV, encode/standardise it and split it into train/test/val.

    File selection (``privacy`` is only interpolated into the file name):
      * ``ratio == 1`` and ``subset == 1``  -> ``auto/auto_{privacy}.csv``
      * ``ratio == 1`` and ``subset != 1``  -> same file under ``auto_subset_{subset}/``
      * ``ratio != 1`` and ``subset == 1``  -> ``auto_uneven/auto_uneven_{ratio}-1_{privacy}.csv``
      * ``ratio != 1`` and ``subset != 1``  -> uneven file under ``auto_subset_{subset}/``

    Args:
        seed: ``random_state`` used for both splits.
        privacy: value embedded in the CSV file name.
        ratio: ``1`` selects the regular file, anything else the "uneven" one.
        subset: ``1`` selects the top-level data directory, anything else a
            ``auto_subset_{subset}`` sub-directory.

    Returns:
        ``(data_train, data_test, data_val, col_tags)`` where ``col_tags`` is
        the list of covariate column names (every column except the
        Gender / Gender_noisy / OUTCOME columns and their ``*_cat`` codes).
    """
    # Pick the source file; the four literals are the four (ratio, subset) cases.
    if ratio == 1:
        if subset == 1:
            csv_path = f'~/Desktop/DFP-1/Code/Data/auto/auto_{privacy}.csv'
        else:
            csv_path = f'~/Desktop/DFP-1/Code/Data/auto/auto_subset_{subset}/auto_{privacy}.csv'
    else:
        if subset == 1:
            csv_path = f'~/Desktop/DFP-1/Code/Data/auto/auto_uneven/auto_uneven_{ratio}-1_{privacy}.csv'
        else:
            csv_path = f'~/Desktop/DFP-1/Code/Data/auto/auto_subset_{subset}/auto_uneven/auto_uneven_{ratio}-1_{privacy}.csv'
    frame = pd.read_csv(csv_path).copy()

    # Integer-coded copies of the sensitive, noisy-sensitive and target columns.
    for label in ('Gender', 'Gender_noisy', 'OUTCOME'):
        frame[label + '_cat'] = pd.Categorical(frame[label]).codes

    # Covariates = everything that is not a label column (raw or coded).
    label_columns = ('Gender', 'Gender_cat', 'Gender_noisy', 'Gender_noisy_cat',
                     'OUTCOME', 'OUTCOME_cat')
    col_tags = [name for name in frame.columns if name not in label_columns]

    # Replace categorical columns by their integer codes, in place.
    coded_columns = ('OUTCOME', 'AGE', 'RACE', 'DRIVING_EXPERIENCE', 'EDUCATION',
                     'INCOME', 'VEHICLE_YEAR', 'MARRIED', 'CHILDREN',
                     'POSTAL_CODE', 'VEHICLE_TYPE')
    for name in coded_columns:
        frame[name] = pd.Categorical(frame[name]).codes

    # NOTE(review): the scaler is fitted on the full frame, i.e. before the
    # train/test/val split below.
    frame[col_tags] = StandardScaler().fit_transform(frame[col_tags])

    # 20% test, then 12.5% of the remaining 80% as validation; both splits are
    # stratified on the coded outcome.
    data_train, data_test = train_test_split(
        frame, test_size=0.2, random_state=seed, stratify=frame['OUTCOME_cat'])
    data_train, data_val = train_test_split(
        data_train, test_size=0.125, random_state=seed,
        stratify=data_train['OUTCOME_cat'])

    return data_train, data_test, data_val, col_tags



### load data to estimate pi from the data (n1 = 1, 2, 4)
def load_auto_NE(seed, privacy, n1, split, ratio, subset):
    """Load the auto data used to fit the noise-rate estimator and split it.

    With ``n1 == 1`` the same files as the full dataset are read (regular or
    "uneven" depending on ``ratio``, optionally under ``auto_subset_{subset}``).
    With ``n1 != 1`` a pre-split file
    ``auto_split_{ratio}/auto_{ratio}_{privacy}_{n1}_{split}.csv`` is read
    instead, again optionally under ``auto_subset_{subset}``.

    Args:
        seed: ``random_state`` used for both splits.
        privacy, n1, split, ratio, subset: values that select (and are
            interpolated into) the CSV path as described above.

    Returns:
        ``(data_train, data_test, data_val, col_tags)``; both splits are
        stratified on ``Gender_noisy_cat``.
    """
    top_level = subset == 1
    if n1 != 1:
        # Pre-split files: the path has the same shape for every ratio.
        if top_level:
            csv_path = f'~/Desktop/DFP-1/Code/Data/auto/auto_split_{ratio}/auto_{ratio}_{privacy}_{n1}_{split}.csv'
        else:
            csv_path = f'~/Desktop/DFP-1/Code/Data/auto/auto_subset_{subset}/auto_split_{ratio}/auto_{ratio}_{privacy}_{n1}_{split}.csv'
    elif ratio == 1:
        # Balanced case, whole file.
        if top_level:
            csv_path = f'~/Desktop/DFP-1/Code/Data/auto/auto_{privacy}.csv'
        else:
            csv_path = f'~/Desktop/DFP-1/Code/Data/auto/auto_subset_{subset}/auto_{privacy}.csv'
    else:
        # Unbalanced case, whole file.
        if top_level:
            csv_path = f'~/Desktop/DFP-1/Code/Data/auto/auto_uneven/auto_uneven_{ratio}-1_{privacy}.csv'
        else:
            csv_path = f'~/Desktop/DFP-1/Code/Data/auto/auto_subset_{subset}/auto_uneven/auto_uneven_{ratio}-1_{privacy}.csv'
    frame = pd.read_csv(csv_path).copy()

    # Integer-coded copies of the label columns.
    for label in ('Gender', 'Gender_noisy', 'OUTCOME'):
        frame[label + '_cat'] = pd.Categorical(frame[label]).codes

    # Covariates are all columns except the raw/coded label columns.
    label_columns = ('Gender', 'Gender_cat', 'Gender_noisy', 'Gender_noisy_cat',
                     'OUTCOME', 'OUTCOME_cat')
    col_tags = [name for name in frame.columns if name not in label_columns]

    # Overwrite categorical columns with their integer codes.
    for name in ('OUTCOME', 'AGE', 'RACE', 'DRIVING_EXPERIENCE', 'EDUCATION',
                 'INCOME', 'VEHICLE_YEAR', 'MARRIED', 'CHILDREN',
                 'POSTAL_CODE', 'VEHICLE_TYPE'):
        frame[name] = pd.Categorical(frame[name]).codes

    # NOTE(review): scaling statistics come from the whole frame (pre-split).
    scaler = StandardScaler()
    frame[col_tags] = scaler.fit_transform(frame[col_tags])

    # 20% test, then 12.5% of the remainder as validation, stratified on the
    # noisy gender code (the estimator's target).
    data_train, data_test = train_test_split(
        frame, test_size=0.2, random_state=seed,
        stratify=frame['Gender_noisy_cat'])
    data_train, data_val = train_test_split(
        data_train, test_size=0.125, random_state=seed,
        stratify=data_train['Gender_noisy_cat'])

    return data_train, data_test, data_val, col_tags



### used to compute Pi matrix with the trained model
def load_auto_NE_Pi(privacy, ratio, split, n1, subset):
    """Load the (unsplit) auto data on which the Pi matrix is computed.

    The CSV is selected by ``n1``, ``ratio`` and ``subset``: ``n1 == 1`` reads
    the whole regular/"uneven" file, ``n1 != 1`` reads the pre-split file
    ``auto_split_{ratio}/auto_{ratio}_{privacy}_{n1}_{split}.csv``;
    ``subset != 1`` prefixes the path with ``auto_subset_{subset}/``.

    Returns:
        ``(data_Pi, col_tags)``: the encoded, standardised frame (no
        train/test split) and the list of covariate column names.
    """
    in_subset_dir = subset != 1
    if n1 == 1:
        if ratio == 1:
            # Balanced case, whole file.
            if in_subset_dir:
                csv_path = f'~/Desktop/DFP-1/Code/Data/auto/auto_subset_{subset}/auto_{privacy}.csv'
            else:
                csv_path = f'~/Desktop/DFP-1/Code/Data/auto/auto_{privacy}.csv'
        else:
            # Unbalanced case, whole file.
            if in_subset_dir:
                csv_path = f'~/Desktop/DFP-1/Code/Data/auto/auto_subset_{subset}/auto_uneven/auto_uneven_{ratio}-1_{privacy}.csv'
            else:
                csv_path = f'~/Desktop/DFP-1/Code/Data/auto/auto_uneven/auto_uneven_{ratio}-1_{privacy}.csv'
    else:
        # n1 = 2, 4: pre-split files, same path shape for every ratio.
        if in_subset_dir:
            csv_path = f'~/Desktop/DFP-1/Code/Data/auto/auto_subset_{subset}/auto_split_{ratio}/auto_{ratio}_{privacy}_{n1}_{split}.csv'
        else:
            csv_path = f'~/Desktop/DFP-1/Code/Data/auto/auto_split_{ratio}/auto_{ratio}_{privacy}_{n1}_{split}.csv'
    data_Pi = pd.read_csv(csv_path).copy()

    # Integer-coded copies of the label columns.
    for label in ('Gender', 'Gender_noisy', 'OUTCOME'):
        data_Pi[label + '_cat'] = pd.Categorical(data_Pi[label]).codes

    # Covariates are all columns except the raw/coded label columns.
    label_columns = ('Gender', 'Gender_cat', 'Gender_noisy', 'Gender_noisy_cat',
                     'OUTCOME', 'OUTCOME_cat')
    col_tags = [name for name in data_Pi.columns if name not in label_columns]

    # Overwrite categorical columns with their integer codes.
    for name in ('OUTCOME', 'AGE', 'RACE', 'DRIVING_EXPERIENCE', 'EDUCATION',
                 'INCOME', 'VEHICLE_YEAR', 'MARRIED', 'CHILDREN',
                 'POSTAL_CODE', 'VEHICLE_TYPE'):
        data_Pi[name] = pd.Categorical(data_Pi[name]).codes

    # Standardise every covariate column.
    data_Pi[col_tags] = StandardScaler().fit_transform(data_Pi[col_tags])

    return data_Pi, col_tags



### dataloader for
# 1: unawareness            | input: X
# 2: best estimate          | input: X,A
# 3: GSM                    | input: X,A
# 4: GSM-LDP                | input: X,Z 
# 5: Noise Rate Estimate    | input: X
def get_dataloaders_auto(sampler = True,
                            secret_tag ='Gender_cat', 
                            utility_tag ='OUTCOME_cat',
                            noisy_tag = 'Gender_noisy_cat',
                            balanced_tag = 'Gender_cat',
                            shuffle_train = True,
                            shuffle_val = True,
                            drop_last = True,
                            seed = None,
                            privacy = None, 
                            ratio = None,
                            model = None,
                            data_type = None,
                            task = None,
                            subset = None):
    """Build train/test/val DataLoaders and a fresh classifier for the auto data.

    Args:
        sampler: if truthy, the training loader draws rows through a
            ``WeightedRandomSampler`` weighted by ``balanced_tag``; otherwise
            it uses ``shuffle_train`` and prints a notice.
        secret_tag, utility_tag, noisy_tag: source columns copied/encoded into
            ``secret_cat`` (one-hot), ``utility_cat`` and ``noisy_cat``
            (one-hot) on each split.
        balanced_tag: column whose values index the weight dictionary returned
            by ``get_weight_dict``.
        shuffle_train, shuffle_val: shuffle flags for the unweighted training
            loader and the validation loader.
        drop_last: forwarded to every DataLoader.
        seed, privacy, ratio, subset: forwarded to ``load_auto``.
        model: one of ``'unawareness'``, ``'best-estimate'``, ``'GSM'``,
            ``'GSM-LDP'``.
        data_type: ``'X'`` sizes the network input from the covariate count;
            any other value uses a fixed width (22 for best-estimate, 20
            otherwise) — presumably for a different input representation,
            confirm with the caller.
        task: accepted for interface compatibility; not used here.

    Returns:
        ``(train_dataloader, test_dataloader, val_dataloader, classifier_network)``.

    Raises:
        ValueError: if ``model`` is not one of the supported names. Previously
            this surfaced as an ``UnboundLocalError`` at the ``return`` after
            all data had already been loaded.
    """
    supported_models = ('unawareness', 'best-estimate', 'GSM', 'GSM-LDP')
    if model not in supported_models:
        raise ValueError(
            f"Unknown model {model!r}; expected one of {supported_models}")

    data_train, data_test, data_val, col_tags = load_auto(seed, privacy, ratio, subset)

    # One eighth of the training rows per batch.
    batch_size = int(data_train.shape[0] / 8)

    # Class counts are taken from the training split and reused for all splits
    # so the one-hot width is identical everywhere.
    n_secret = data_train[secret_tag].nunique()
    n_noisy = data_train[noisy_tag].nunique()

    ## Dataframe rename
    for frame in (data_train, data_test, data_val):
        frame['secret_cat'] = frame[secret_tag].apply(lambda x: to_categorical(x, num_classes=n_secret))
        frame['noisy_cat'] = frame[noisy_tag].apply(lambda x: to_categorical(x, num_classes=n_noisy))
        frame['utility_cat'] = frame[utility_tag]

    # Per-row sampling weights looked up from the balanced_tag value.
    weight_dic = get_weight_dict(data_train, balanced_tag)
    train_weights = torch.DoubleTensor(data_train[balanced_tag].apply(lambda x: weight_dic[x]).values)
    train_sampler = torch.utils.data.sampler.WeightedRandomSampler(train_weights, len(train_weights))

    def make_dataset(frame):
        # Same dataset configuration for every split.
        return TablePandasDataset(pd = frame, cov_list = col_tags,
                                  utility_tag = 'utility_cat', sensitive_tag = 'secret_cat',
                                  noisy_sensitive_tag = 'noisy_cat',
                                  transform = None)

    train_dataset = make_dataset(data_train)
    if sampler:
        train_dataloader = DataLoader(train_dataset,
                                      batch_size = batch_size,
                                      sampler = train_sampler, pin_memory = True, drop_last = drop_last)
    else:
        train_dataloader = DataLoader(train_dataset,
                                      batch_size = batch_size,
                                      shuffle = shuffle_train, pin_memory = True, drop_last = drop_last)
        print('Not using balanced sampling!')

    val_dataloader = DataLoader(make_dataset(data_val),
                                batch_size = batch_size,
                                shuffle = shuffle_val, pin_memory = True, drop_last = drop_last)

    test_dataloader = DataLoader(make_dataset(data_test),
                                 batch_size = batch_size,
                                 shuffle = False, pin_memory = True, drop_last = drop_last)

    # Network input width: covariate count for data_type 'X' (plus 2 extra
    # inputs for best-estimate), otherwise the fixed widths below.
    uses_covariates = data_type == 'X'
    n_features = len(col_tags)
    if model == 'unawareness':
        classifier_network = Unawareness(input_size = n_features, hidden_size = 20)
    elif model == 'best-estimate':
        classifier_network = Best_Estimate(input_size = n_features + 2 if uses_covariates else 22,
                                           hidden_size = 20)
    elif model == 'GSM':
        classifier_network = GSM(input_size = n_features if uses_covariates else 20,
                                 hidden_size = 20)
    else:  # 'GSM-LDP' (validated above)
        classifier_network = GSM_LDP(input_size = n_features if uses_covariates else 20,
                                     hidden_size = 20)

    return train_dataloader, test_dataloader, val_dataloader, classifier_network



### Fit model to estimate noise rate with n_1 = 1, 2, 4
def get_dataloaders_auto_NE(sampler=True,
                                utility_tag = 'Gender_noisy_cat',
                                balanced_tag = 'Gender_noisy_cat',
                                shuffle_train = True,
                                shuffle_val = True,
                                drop_last = True,
                                seed = None,
                                n1 = None,
                                split = None,
                                privacy = None, 
                                ratio = None,
                                subset = None):
    """Build the DataLoaders and network used to fit the noise-rate estimator.

    Args:
        sampler: if truthy, the training loader uses a
            ``WeightedRandomSampler`` weighted by ``balanced_tag``; otherwise
            it uses ``shuffle_train`` and prints a notice.
        utility_tag: column copied to ``utility_cat`` (the estimator target).
        balanced_tag: column whose values index the weight dictionary returned
            by ``get_weight_dict``.
        shuffle_train, shuffle_val: shuffle flags for the unweighted training
            loader and the validation loader.
        drop_last: forwarded to every DataLoader.
        seed, privacy, n1, split, ratio, subset: forwarded to ``load_auto_NE``.

    Returns:
        ``(train_dataloader, test_dataloader, val_dataloader, classifier_network)``
        where the network is a new ``Noise_Estimate`` sized to the covariates.
    """
    data_train, data_test, data_val, col_tags = load_auto_NE(seed, privacy, n1, split, ratio, subset)

    # One eighth of the training rows per batch.
    batch_size = int(data_train.shape[0] / 8)

    for frame in (data_train, data_test, data_val):
        frame['utility_cat'] = frame[utility_tag]

    # Per-row sampling weights looked up from the balanced_tag value.
    weight_dic = get_weight_dict(data_train, balanced_tag)
    train_weights = torch.DoubleTensor(data_train[balanced_tag].apply(lambda x: weight_dic[x]).values)
    train_sampler = torch.utils.data.sampler.WeightedRandomSampler(train_weights, len(train_weights))

    def make_dataset(frame):
        # Same dataset configuration for every split.
        return TablePandasDataset_Noise_Estimate(pd = frame, cov_list = col_tags,
                                                 utility_tag = 'utility_cat',
                                                 transform = None)

    train_dataset = make_dataset(data_train)
    if sampler:
        train_dataloader = DataLoader(train_dataset,
                                      batch_size = batch_size,
                                      sampler = train_sampler, pin_memory = True, drop_last = drop_last)
    else:
        train_dataloader = DataLoader(train_dataset,
                                      batch_size = batch_size,
                                      shuffle = shuffle_train, pin_memory = True, drop_last = drop_last)
        print('Not using balanced sampling!')

    val_dataloader = DataLoader(make_dataset(data_val),
                                batch_size = batch_size,
                                shuffle = shuffle_val, pin_memory = True, drop_last = drop_last)

    test_dataloader = DataLoader(make_dataset(data_test),
                                 batch_size = batch_size,
                                 shuffle = False, pin_memory = True, drop_last = drop_last)

    classifier_network = Noise_Estimate(input_size = len(col_tags), hidden_size = 20)
    return train_dataloader, test_dataloader, val_dataloader, classifier_network



### Compute noise rate pi using the fitted model
def get_dataloaders_auto_NE_Pi(sampler=True,
                                utility_tag = 'Gender_noisy_cat',
                                balanced_tag = 'Gender_noisy_cat',
                                shuffle_train = True,
                                shuffle_val = True,
                                drop_last = True,
                                privacy = None, 
                                seed = None,
                                n1 = None,
                                ratio = None,
                                split = None,
                                subset = None):
    """Return a single-batch DataLoader plus the saved noise-rate model.

    The loader covers the whole frame from ``load_auto_NE_Pi`` in one batch
    (``batch_size`` equals the row count) and always draws rows through a
    ``WeightedRandomSampler`` weighted by ``balanced_tag``.

    NOTE(review): the ``sampler``, ``shuffle_train`` and ``shuffle_val``
    arguments are accepted but have no effect — the incoming ``sampler`` value
    is overwritten before it is ever read.

    Args:
        utility_tag: column copied to ``utility_cat``.
        balanced_tag: column whose values index the weight dictionary returned
            by ``get_weight_dict``.
        drop_last: forwarded to the DataLoader.
        privacy, ratio, split, n1, subset: forwarded to ``load_auto_NE_Pi``
            and, together with ``seed``, interpolated into the model file name.

    Returns:
        ``(dataloader_Pi, classifier_network)``.
    """
    frame, col_tags = load_auto_NE_Pi(privacy, ratio, split, n1, subset)

    n_rows = frame.shape[0]
    frame['utility_cat'] = frame[utility_tag]

    # Per-row sampling weights looked up from the balanced_tag value.
    weight_dic = get_weight_dict(frame, balanced_tag)
    row_weights = torch.DoubleTensor(frame[balanced_tag].apply(lambda x: weight_dic[x]).values)
    # Rebinding the parameter name mirrors the long-standing behaviour: the
    # caller-supplied value is discarded here.
    sampler = torch.utils.data.sampler.WeightedRandomSampler(row_weights, len(row_weights))

    pi_dataset = TablePandasDataset_Noise_Estimate(pd = frame, cov_list = col_tags,
                                                   utility_tag = 'utility_cat',
                                                   transform = None)
    dataloader_Pi = DataLoader(pi_dataset,
                               batch_size = n_rows,
                               sampler = sampler, pin_memory = True, drop_last = drop_last)

    # Saved models for subsets live in their own Auto_subset_{subset} directory.
    if subset == 1:
        model_file = f'~/Desktop/DFP-1/Code/Models/Auto/NE/auto_NE_n{n1}_r{ratio}_p{privacy}_s{seed}_sp{split}.pth'
    else:
        model_file = f'~/Desktop/DFP-1/Code/Models/Auto/Auto_subset_{subset}/NE/auto_NE_n{n1}_r{ratio}_p{privacy}_s{seed}_sp{split}.pth'
    save_path = os.path.expanduser(model_file)
    # NOTE(review): torch.load unpickles the file; only load checkpoints that
    # were produced by this project.
    classifier_network = torch.load(save_path)

    return dataloader_Pi, classifier_network
################################################################################################################################



################################################################################################################################
### Health
### gender as the sensitive attribute

### load health dataset
def load_health(seed, privacy, ratio, subset):
    """Read the health CSV, encode/standardise it and split it into train/test/val.

    File selection (``privacy`` is only interpolated into the file name):
      * ``ratio == 1`` and ``subset == 1``  -> ``health/health_{privacy}.csv``
      * ``ratio == 1`` and ``subset != 1``  -> same file under ``health_subset_{subset}/``
      * ``ratio != 1`` and ``subset == 1``  -> ``health_uneven/health_uneven_{ratio}-1_{privacy}.csv``
      * ``ratio != 1`` and ``subset != 1``  -> uneven file under ``health_subset_{subset}/``

    Args:
        seed: ``random_state`` used for both (unstratified) splits.
        privacy, ratio, subset: values selecting the CSV as described above.

    Returns:
        ``(data_train, data_test, data_val, col_tags)`` where ``col_tags``
        lists every column except the Gender / Gender_noisy columns, their
        ``*_cat`` codes and ``charges``.
    """
    if ratio == 1:
        if subset == 1:
            csv_path = f'~/Desktop/DFP-1/Code/Data/health/health_{privacy}.csv'
        else:
            csv_path = f'~/Desktop/DFP-1/Code/Data/health/health_subset_{subset}/health_{privacy}.csv'
    else:
        if subset == 1:
            csv_path = f'~/Desktop/DFP-1/Code/Data/health/health_uneven/health_uneven_{ratio}-1_{privacy}.csv'
        else:
            csv_path = f'~/Desktop/DFP-1/Code/Data/health/health_subset_{subset}/health_uneven/health_uneven_{ratio}-1_{privacy}.csv'
    frame = pd.read_csv(csv_path).copy()

    # Integer-coded copies of the sensitive and noisy-sensitive columns.
    for label in ('Gender', 'Gender_noisy'):
        frame[label + '_cat'] = pd.Categorical(frame[label]).codes

    # Covariates = everything that is neither a gender column nor the target.
    non_covariates = ('Gender', 'Gender_cat', 'Gender_noisy', 'Gender_noisy_cat', 'charges')
    col_tags = [name for name in frame.columns if name not in non_covariates]

    # Overwrite the categorical covariates with their integer codes.
    for name in ('region', 'smoker'):
        frame[name] = pd.Categorical(frame[name]).codes

    # NOTE(review): the scaler is fitted on the full frame, before splitting.
    frame[col_tags] = StandardScaler().fit_transform(frame[col_tags])

    # 20% test, then 12.5% of the remaining 80% as validation (no stratification).
    data_train, data_test = train_test_split(frame, test_size=0.2, random_state=seed)
    data_train, data_val = train_test_split(data_train, test_size=0.125, random_state=seed)

    return data_train, data_test, data_val, col_tags



### load subset of data to estimate pi from the data (n1 = 1, 2, 4)
def load_health_NE(seed, privacy, n1, split, ratio, subset):
    """Load the health data used to fit the noise-rate estimator and split it.

    With ``n1 == 1`` the same files as the full dataset are read (regular or
    "uneven" depending on ``ratio``, optionally under
    ``health_subset_{subset}``). With ``n1 != 1`` a pre-split file
    ``health_split_{ratio}/health_{ratio}_{privacy}_{n1}_{split}.csv`` is read
    instead, again optionally under ``health_subset_{subset}``.

    Args:
        seed: ``random_state`` used for both splits.
        privacy, n1, split, ratio, subset: values that select (and are
            interpolated into) the CSV path as described above.

    Returns:
        ``(data_train, data_test, data_val, col_tags)``; both splits are
        stratified on ``Gender_noisy_cat``.
    """
    top_level = subset == 1
    if n1 != 1:
        # Pre-split files: the path has the same shape for every ratio.
        if top_level:
            csv_path = f'~/Desktop/DFP-1/Code/Data/health/health_split_{ratio}/health_{ratio}_{privacy}_{n1}_{split}.csv'
        else:
            csv_path = f'~/Desktop/DFP-1/Code/Data/health/health_subset_{subset}/health_split_{ratio}/health_{ratio}_{privacy}_{n1}_{split}.csv'
    elif ratio == 1:
        # Balanced case, whole file.
        if top_level:
            csv_path = f'~/Desktop/DFP-1/Code/Data/health/health_{privacy}.csv'
        else:
            csv_path = f'~/Desktop/DFP-1/Code/Data/health/health_subset_{subset}/health_{privacy}.csv'
    else:
        # Unbalanced case, whole file.
        if top_level:
            csv_path = f'~/Desktop/DFP-1/Code/Data/health/health_uneven/health_uneven_{ratio}-1_{privacy}.csv'
        else:
            csv_path = f'~/Desktop/DFP-1/Code/Data/health/health_subset_{subset}/health_uneven/health_uneven_{ratio}-1_{privacy}.csv'
    frame = pd.read_csv(csv_path).copy()

    # Integer-coded copies of the sensitive and noisy-sensitive columns.
    for label in ('Gender', 'Gender_noisy'):
        frame[label + '_cat'] = pd.Categorical(frame[label]).codes

    # Covariates = everything that is neither a gender column nor 'charges'.
    non_covariates = ('Gender', 'Gender_cat', 'Gender_noisy', 'Gender_noisy_cat', 'charges')
    col_tags = [name for name in frame.columns if name not in non_covariates]

    # Overwrite the categorical covariates with their integer codes.
    for name in ('region', 'smoker'):
        frame[name] = pd.Categorical(frame[name]).codes

    # NOTE(review): scaling statistics come from the whole frame (pre-split).
    scaler = StandardScaler()
    frame[col_tags] = scaler.fit_transform(frame[col_tags])

    # 20% test, then 12.5% of the remainder as validation, stratified on the
    # noisy gender code.
    data_train, data_test = train_test_split(
        frame, test_size=0.2, random_state=seed,
        stratify=frame['Gender_noisy_cat'])
    data_train, data_val = train_test_split(
        data_train, test_size=0.125, random_state=seed,
        stratify=data_train['Gender_noisy_cat'])

    return data_train, data_test, data_val, col_tags



### used to compute Pi matrix with the trained model
def load_health_NE_Pi(privacy, ratio, split, n1, subset):
    """Load the (unsplit) health data on which the Pi matrix is computed.

    The CSV is selected by ``n1``, ``ratio`` and ``subset``: ``n1 == 1`` reads
    the whole regular/"uneven" file, ``n1 != 1`` reads the pre-split file
    ``health_split_{ratio}/health_{ratio}_{privacy}_{n1}_{split}.csv``;
    ``subset != 1`` prefixes the path with ``health_subset_{subset}/``.

    Returns:
        ``(data_Pi, col_tags)``: the encoded, standardised frame (no
        train/test split) and the list of covariate column names.
    """
    in_subset_dir = subset != 1
    if n1 == 1:
        if ratio == 1:
            # Balanced case, whole file.
            if in_subset_dir:
                csv_path = f'~/Desktop/DFP-1/Code/Data/health/health_subset_{subset}/health_{privacy}.csv'
            else:
                csv_path = f'~/Desktop/DFP-1/Code/Data/health/health_{privacy}.csv'
        else:
            # Unbalanced case, whole file.
            if in_subset_dir:
                csv_path = f'~/Desktop/DFP-1/Code/Data/health/health_subset_{subset}/health_uneven/health_uneven_{ratio}-1_{privacy}.csv'
            else:
                csv_path = f'~/Desktop/DFP-1/Code/Data/health/health_uneven/health_uneven_{ratio}-1_{privacy}.csv'
    else:
        # n1 = 2, 4: pre-split files, same path shape for every ratio.
        if in_subset_dir:
            csv_path = f'~/Desktop/DFP-1/Code/Data/health/health_subset_{subset}/health_split_{ratio}/health_{ratio}_{privacy}_{n1}_{split}.csv'
        else:
            csv_path = f'~/Desktop/DFP-1/Code/Data/health/health_split_{ratio}/health_{ratio}_{privacy}_{n1}_{split}.csv'
    data_Pi = pd.read_csv(csv_path).copy()

    # Integer-coded copies of the sensitive and noisy-sensitive columns.
    for label in ('Gender', 'Gender_noisy'):
        data_Pi[label + '_cat'] = pd.Categorical(data_Pi[label]).codes

    # Covariates = everything that is neither a gender column nor 'charges'.
    non_covariates = ('Gender', 'Gender_cat', 'Gender_noisy', 'Gender_noisy_cat', 'charges')
    col_tags = [name for name in data_Pi.columns if name not in non_covariates]

    # Overwrite the categorical covariates with their integer codes.
    for name in ('region', 'smoker'):
        data_Pi[name] = pd.Categorical(data_Pi[name]).codes

    # Standardise every covariate column.
    data_Pi[col_tags] = StandardScaler().fit_transform(data_Pi[col_tags])

    return data_Pi, col_tags



### dataloader for
# 1: unawareness            | input: X
# 2: best estimate          | input: X,A
# 3: GSM                    | input: X,A
# 4: GSM-LDP                | input: X,Z 
# 5: Noise Rate Estimate    | input: X
def get_dataloaders_health(sampler = True,
                            secret_tag ='Gender_cat', 
                            utility_tag ='charges',
                            noisy_tag = 'Gender_noisy_cat',
                            balanced_tag = 'Gender_cat',
                            shuffle_train = True,
                            shuffle_val = True,
                            drop_last = True,
                            seed = None,
                            privacy = None, 
                            ratio = None,
                            model = None,
                            data_type = None,
                            task = None,
                            subset = None):
    """Build train/test/val DataLoaders and a fresh network for the health data.

    Args:
        sampler: if truthy, the training loader draws rows through a
            ``WeightedRandomSampler`` weighted by ``balanced_tag``; otherwise
            it uses ``shuffle_train`` and prints a notice.
        secret_tag, utility_tag, noisy_tag: source columns copied/encoded into
            ``secret_cat`` (one-hot), ``utility_cat`` (copied unchanged) and
            ``noisy_cat`` (one-hot) on each split.
        balanced_tag: column whose values index the weight dictionary returned
            by ``get_weight_dict``.
        shuffle_train, shuffle_val: shuffle flags for the unweighted training
            loader and the validation loader.
        drop_last: forwarded to every DataLoader.
        seed, privacy, ratio, subset: forwarded to ``load_health``.
        model: one of ``'unawareness'``, ``'best-estimate'``, ``'GSM'``,
            ``'GSM-LDP'``.
        data_type: ``'X'`` sizes the network input from the covariate count;
            any other value uses a fixed width (10 for best-estimate, 8
            otherwise) — presumably for a different input representation,
            confirm with the caller.
        task: accepted for interface compatibility; not used here.

    Returns:
        ``(train_dataloader, test_dataloader, val_dataloader, classifier_network)``.

    Raises:
        ValueError: if ``model`` is not one of the supported names. Previously
            this surfaced as an ``UnboundLocalError`` at the ``return`` after
            all data had already been loaded.
    """
    supported_models = ('unawareness', 'best-estimate', 'GSM', 'GSM-LDP')
    if model not in supported_models:
        raise ValueError(
            f"Unknown model {model!r}; expected one of {supported_models}")

    data_train, data_test, data_val, col_tags = load_health(seed, privacy, ratio, subset)

    # One eighth of the training rows per batch.
    batch_size = int(data_train.shape[0] / 8)

    # Class counts are taken from the training split and reused for all splits
    # so the one-hot width is identical everywhere.
    n_secret = data_train[secret_tag].nunique()
    n_noisy = data_train[noisy_tag].nunique()

    ## Dataframe rename
    for frame in (data_train, data_test, data_val):
        frame['secret_cat'] = frame[secret_tag].apply(lambda x: to_categorical(x, num_classes=n_secret))
        frame['noisy_cat'] = frame[noisy_tag].apply(lambda x: to_categorical(x, num_classes=n_noisy))
        frame['utility_cat'] = frame[utility_tag]

    # Per-row sampling weights looked up from the balanced_tag value.
    weight_dic = get_weight_dict(data_train, balanced_tag)
    train_weights = torch.DoubleTensor(data_train[balanced_tag].apply(lambda x: weight_dic[x]).values)
    train_sampler = torch.utils.data.sampler.WeightedRandomSampler(train_weights, len(train_weights))

    def make_dataset(frame):
        # Same dataset configuration for every split.
        return TablePandasDataset(pd = frame, cov_list = col_tags,
                                  utility_tag = 'utility_cat', sensitive_tag = 'secret_cat',
                                  noisy_sensitive_tag = 'noisy_cat',
                                  transform = None)

    train_dataset = make_dataset(data_train)
    if sampler:
        train_dataloader = DataLoader(train_dataset,
                                      batch_size = batch_size,
                                      sampler = train_sampler, pin_memory = True, drop_last = drop_last)
    else:
        train_dataloader = DataLoader(train_dataset,
                                      batch_size = batch_size,
                                      shuffle = shuffle_train, pin_memory = True, drop_last = drop_last)
        print('Not using balanced sampling!')

    val_dataloader = DataLoader(make_dataset(data_val),
                                batch_size = batch_size,
                                shuffle = shuffle_val, pin_memory = True, drop_last = drop_last)

    test_dataloader = DataLoader(make_dataset(data_test),
                                 batch_size = batch_size,
                                 shuffle = False, pin_memory = True, drop_last = drop_last)

    # Network input width: covariate count for data_type 'X' (plus 2 extra
    # inputs for best-estimate), otherwise the fixed widths below.
    uses_covariates = data_type == 'X'
    n_features = len(col_tags)
    if model == 'unawareness':
        classifier_network = Unawareness(input_size = n_features, hidden_size = 8)
    elif model == 'best-estimate':
        classifier_network = Best_Estimate(input_size = n_features + 2 if uses_covariates else 10,
                                           hidden_size = 8)
    elif model == 'GSM':
        classifier_network = GSM(input_size = n_features if uses_covariates else 8,
                                 hidden_size = 8)
    else:  # 'GSM-LDP' (validated above)
        classifier_network = GSM_LDP(input_size = n_features if uses_covariates else 8,
                                     hidden_size = 8)

    return train_dataloader, test_dataloader, val_dataloader, classifier_network



### Fit model to estimate noise rate with n_1 = 1, 2, 4
def get_dataloaders_health_NE(sampler=True,
                                utility_tag = 'Gender_noisy_cat',
                                balanced_tag = 'Gender_noisy_cat',
                                shuffle_train = True,
                                shuffle_val = True,
                                drop_last = True,
                                seed = None,
                                n1 = None,
                                split = None,
                                privacy = None, 
                                ratio = None,
                                subset = None):
    """Build the train/test/val dataloaders and a new Noise_Estimate network.

    The three frames come from ``load_health_NE(seed, privacy, n1, split,
    ratio, subset)``. A ``'utility_cat'`` column (a copy of ``utility_tag``)
    is added to each frame in place and used as the dataset target.

    Args:
        sampler: If true, the training loader draws rows through a
            ``WeightedRandomSampler`` whose per-row weights come from
            ``get_weight_dict(data_train, balanced_tag)``. Otherwise the
            training loader uses ``shuffle_train`` and a notice is printed.
        utility_tag: Column copied into ``'utility_cat'``.
        balanced_tag: Column whose values select the per-row sampling weight.
            Only read when ``sampler`` is true.
        shuffle_train: ``shuffle`` flag of the training loader when
            ``sampler`` is false.
        shuffle_val: ``shuffle`` flag of the validation loader.
        drop_last: Passed to all three loaders.
        seed, n1, split, privacy, ratio, subset: Forwarded unchanged to
            ``load_health_NE``.

    Returns:
        ``(train_dataloader, test_dataloader, val_dataloader,
        classifier_network)`` -- note that test comes before val. The network
        is ``Noise_Estimate(input_size=len(col_tags), hidden_size=8)``.
    """
    data_train, data_test, data_val, col_tags = load_health_NE(seed, privacy, n1, split, ratio, subset)

    # One eighth of the training rows per batch. Clamp to 1 so that a training
    # frame with fewer than 8 rows does not produce batch_size == 0, which
    # DataLoader rejects.
    batch_size = max(1, data_train.shape[0] // 8)

    for frame in (data_train, data_test, data_val):
        frame['utility_cat'] = frame[utility_tag]

    def _make_dataset(frame):
        # All three splits share the same covariates, target column and
        # (absent) transform.
        return TablePandasDataset_Noise_Estimate(pd = frame, cov_list = col_tags,
                                                 utility_tag = 'utility_cat',
                                                 transform = None)

    if sampler:
        # Only build the per-row weights when they are actually used.
        weight_dic = get_weight_dict(data_train, balanced_tag)
        train_weights = torch.DoubleTensor(data_train[balanced_tag].apply(lambda x: weight_dic[x]).values)
        train_sampler = torch.utils.data.sampler.WeightedRandomSampler(train_weights, len(train_weights))
        train_dataloader = DataLoader(_make_dataset(data_train),
                                      batch_size = batch_size,
                                      sampler = train_sampler, pin_memory = True, drop_last = drop_last)
    else:
        train_dataloader = DataLoader(_make_dataset(data_train),
                                      batch_size = batch_size,
                                      shuffle = shuffle_train, pin_memory = True, drop_last = drop_last)
        print('Not using balanced sampling!')

    val_dataloader = DataLoader(_make_dataset(data_val),
                                batch_size = batch_size,
                                shuffle = shuffle_val, pin_memory = True, drop_last = drop_last)

    # The test loader is never shuffled.
    test_dataloader = DataLoader(_make_dataset(data_test),
                                 batch_size = batch_size,
                                 shuffle = False, pin_memory = True, drop_last = drop_last)

    classifier_network = Noise_Estimate(input_size = len(col_tags), hidden_size = 8)
    return train_dataloader, test_dataloader, val_dataloader, classifier_network



### Compute noise rate pi using the fitted model
def get_dataloaders_health_NE_Pi(sampler=True,
                                utility_tag = 'Gender_noisy_cat',
                                balanced_tag = 'Gender_noisy_cat',
                                shuffle_train = True,
                                shuffle_val = True,
                                drop_last = True,
                                privacy = None, 
                                seed = None,
                                n1 = None,
                                ratio = None,
                                split = None,
                                subset = None):
    """Return a single-batch dataloader and the saved noise-estimate model.

    The frame comes from ``load_health_NE_Pi(privacy, ratio, split, n1,
    subset)``. A ``'utility_cat'`` column (a copy of ``utility_tag``) is added
    to it in place. The loader's batch size equals the number of rows, and
    rows are always drawn through a ``WeightedRandomSampler`` weighted by
    ``get_weight_dict(data, balanced_tag)``.

    NOTE(review): the ``sampler``, ``shuffle_train`` and ``shuffle_val``
    arguments are accepted but never consulted; weighted sampling is
    unconditional. Confirm with callers whether ``sampler=False`` was meant
    to disable it.

    Args:
        utility_tag: Column copied into ``'utility_cat'``.
        balanced_tag: Column whose values select the per-row sampling weight.
        drop_last: Passed to the loader.
        privacy, ratio, split, n1, subset: Forwarded to ``load_health_NE_Pi``
            and, together with ``seed``, interpolated into the checkpoint
            file name.

    Returns:
        ``(dataloader_Pi, classifier_network)`` where ``classifier_network``
        is whatever ``torch.load`` returns for the checkpoint path.
    """
    data, col_tags = load_health_NE_Pi(privacy, ratio, split, n1, subset)
    data['utility_cat'] = data[utility_tag]

    # Per-row sampling weights looked up from the balanced_tag value.
    weight_dic = get_weight_dict(data, balanced_tag)
    row_weights = torch.DoubleTensor(data[balanced_tag].apply(lambda x: weight_dic[x]).values)
    weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(row_weights, len(row_weights))

    pi_dataset = TablePandasDataset_Noise_Estimate(pd = data, cov_list = col_tags,
                                                   utility_tag = 'utility_cat',
                                                   transform = None)
    # batch_size == number of rows.
    dataloader_Pi = DataLoader(pi_dataset,
                               batch_size = data.shape[0],
                               sampler = weighted_sampler, pin_memory = True, drop_last = drop_last)

    # Checkpoints for subset == 1 live directly under Health/NE; any other
    # subset value selects a Health_subset_<subset>/NE directory.
    if subset == 1:
        save_path = os.path.expanduser(f'~/Desktop/DFP-1/Code/Models/Health/NE/health_NE_n{n1}_r{ratio}_p{privacy}_s{seed}_sp{split}.pth')
    else:
        save_path = os.path.expanduser(f'~/Desktop/DFP-1/Code/Models/Health/Health_subset_{subset}/NE/health_NE_n{n1}_r{ratio}_p{privacy}_s{seed}_sp{split}.pth')

    # NOTE: torch.load unpickles the file, so only load checkpoints from a
    # trusted location.
    classifier_network = torch.load(save_path)

    return dataloader_Pi, classifier_network