from folktables import ACSDataSource, ACSTravelTime, generate_categories, BasicProblem
import folktables
import requests
import pandas as pd
import numpy as np
from clustering_utils import *
import os

filename_demo = "demographics.npy"
filename_A = "matrices.npy"
filename_phi = "Phi.npy"
# filename_size = "size.npy"

if os.path.isfile(filename_A):
    # load matrix from file
    A_all = np.load(filename_A)
    Phi = np.load(filename_phi)
    # array_sizes = np.load(filename_size)
    demographics = np.load(filename_demo)

else:
    def preprocess(group):
        """Turn one group of ACS person records into a feature frame and a label.

        Parameters
        ----------
        group : pandas.DataFrame
            Rows of one group. Must contain the columns AGEP, SCHL, MAR, DIS,
            ESP, MIG, ST, CIT, POVPIP and the label column JWMNP.

        Returns
        -------
        (pandas.DataFrame, pandas.Series)
            The float64 feature frame and the untouched ``group['JWMNP']``.
            SCHL, MAR, MIG and CIT are one-hot encoded; the dummy columns
            depend on the levels present in ``group``, so the frame width can
            differ between groups.

        The input ``group`` is never modified.
        """

        def _p_schl(x):
            """Bucket the SCHL code into 0 / 1 / 2 (NaN is passed through)."""
            if x <= 15:
                return 0  # less than HS
            if x <= 19:
                return 1  # more than HS, less than undergraduate
            if x >= 20:
                return 2  # college and above
            return x

        def _p_jwtr(x):
            """Bucket the JWTR (means of transportation) code.

            Only needed when the JWTR feature is enabled (currently disabled).
            """
            if x == 1:
                return 0  # driving
            if x <= 7:
                return 1  # public transport
            if x in [8, 9, 10]:  # biked or walked
                return 2
            if x == 11:
                return 3  # worked from home
            if x == 12:
                return 4  # another method
            return x

        def _p_rac1p(x):
            """Re-index the RAC1P code to a contiguous 0-based range.

            Only needed when the RAC1P feature is enabled (currently disabled).
            """
            if x <= 4:
                return x - 1
            if x == 5:
                # combine Native and Alaska Native Only into one category
                # since Alaska Native Only contains just 9 individuals
                # and decision phi that achieves minimum risk is not defined
                return 3
            if x > 5:
                return x - 2
            return x

        def postprocess(df, label_df):
            """Recode, one-hot encode and scale the features; label is untouched."""
            # Work on an explicit copy: `df` is a column selection of the
            # caller's group, and writing into it in place triggers pandas'
            # SettingWithCopyWarning.
            df = df.copy()

            df['SCHL'] = df['SCHL'].apply(_p_schl)
            # marital status -> binary (0 when the code is 1, else 1)
            df['MAR'] = df['MAR'].apply(lambda x: 0 if x == 1 else 1)
            # shift 1-based codes to 0-based
            for column in ('DIS', 'MIG', 'CIT'):
                df[column] = df[column] - 1
            # Disabled recodings (SEX, RAC1P, JWTR, ST are not features here):
            #   SEX -> x-1, RAC1P -> _p_rac1p, JWTR -> _p_jwtr, ST -> x-1

            categorical_features = ['SCHL', 'MAR', 'MIG', 'CIT']  # , 'SEX', 'RAC1P']
            # One call replaces the encoded columns with their dummies; the
            # remaining columns keep their order and the dummy blocks follow
            # in the order of `categorical_features`.
            df = pd.get_dummies(df, columns=categorical_features, prefix_sep="_")

            # scale numerical variables for better conditioning
            df['POVPIP'] = df['POVPIP'] / 500
            df['AGEP'] = df['AGEP'] / 80

            # NOTE(review): this used to read `np.nan_to_num(x, -1)`, where the
            # positional -1 is bound to `copy`, not `nan`, so NaNs have always
            # been replaced by 0.0. That behaviour is kept and made explicit;
            # confirm whether -1 was the intended fill value.
            df = df.apply(lambda column: np.nan_to_num(column, nan=0.0))
            df = df.astype(np.float64)

            # Label is returned unchanged (a log transform here is disabled).
            return df, label_df

        feature_columns = [
            'AGEP',
            'SCHL',
            'MAR',
            # 'SEX',
            'DIS',
            'ESP',
            'MIG',
            # 'RELP',
            # 'RAC1P',
            # 'PUMA',
            'ST',
            'CIT',
            # 'OCCP',
            # 'JWTR',
            # 'POWPUMA',
            'POVPIP',
        ]

        features, label = postprocess(group[feature_columns], group['JWMNP'])
        return features, label

    # state_list = ['CA']

    state_list = ['AL', 'AK', 'AZ', 'AR', 'CA', 'CO', 'CT', 'DE', 'FL', 'GA', 'HI',
                'ID', 'IL', 'IN', 'IA', 'KS', 'KY', 'LA', 'ME', 'MD', 'MA', 'MI',
                'MN', 'MS', 'MO', 'MT', 'NE', 'NV', 'NH', 'NJ', 'NM', 'NY', 'NC',
                'ND', 'OH', 'OK', 'OR', 'PA', 'RI', 'SC', 'SD', 'TN', 'TX', 'UT',
                'VT', 'VA', 'WA', 'WV', 'WI', 'WY', 'PR']

    # Download (or reuse a local copy of) the 2021 1-Year ACS person survey.
    data_source = ACSDataSource(survey_year='2021', horizon='1-Year', survey='person')
    acs_definition = data_source.get_definitions(download=True)
    acs_data = data_source.get_data(states=state_list, download=True)

    # Label transform: 10 * log(1 + JWMNP). NaN labels stay NaN and are
    # dropped on the next line.
    acs_data.loc[:,'JWMNP'] = acs_data['JWMNP'].apply(lambda x: 10*np.log(1+x))
    acs_data = acs_data.loc[acs_data['JWMNP'].notnull()]

    # NOTE(review): PUMA codes are presumably only unique within a state, so
    # grouping without 'ST' may merge unrelated areas from different states.
    # The key layout (PUMA, SEX, RAC1P) is relied on further down, so it is
    # left unchanged here -- confirm this is intended.
    groups = acs_data.groupby(['PUMA', 'SEX', 'RAC1P'])

    print(len(groups))

    import statistics

    # Number of rows in every group.
    array_sizes = np.array([len(group) for _, group in groups])
    print(array_sizes)

    # Compute the statistics on plain Python ints: when fed NumPy integers the
    # `statistics` functions coerce their result back to the NumPy integer
    # type, which can silently truncate the mean / standard deviation.
    group_sizes = array_sizes.tolist()
    mean_size = statistics.mean(group_sizes)
    median_size = statistics.median(group_sizes)
    std_dev_size = statistics.stdev(group_sizes)
    min_size = min(group_sizes)
    max_size = max(group_sizes)

    print("Array size statistics:")
    print(f"Mean size: {mean_size}")
    print(f"Median size: {median_size}")
    print(f"Standard deviation of size: {std_dev_size}")
    print(f"Minimum size: {min_size}")
    print(f"Maximum size: {max_size}")

    # for name, group in groups:
        # print('Group:', name)
        # print(group)

    def create_subpop(X, y, n_components=10):
        """Compute the least-squares solution and feature covariance of one group.

        The features are first projected onto the leading right singular
        vectors of ``X``.

        Parameters
        ----------
        X : numpy.ndarray, shape (n_samples, n_features)
            Feature matrix of the group.
        y : numpy.ndarray, shape (n_samples,)
            Labels of the group.
        n_components : int, optional
            Number of leading right singular vectors to project onto
            (default 10, the value that used to be hard-coded).

        Returns
        -------
        phi : numpy.ndarray
            ``pinv(X_projected) @ y``.
        A : numpy.ndarray
            ``X_projected.T @ X_projected / n_samples``, the empirical
            (uncentred) covariance of the projected features.
        """
        n_samples, n_features = X.shape
        # Only Vt is used. For tall matrices the reduced SVD yields the same
        # (n_features x n_features) Vt without materialising the
        # n_samples x n_samples U. For wide matrices the full SVD is kept so
        # Vt still has n_features rows and the output shape is unchanged.
        _, _, Vt = np.linalg.svd(X, full_matrices=n_samples < n_features)

        # project the data onto the first `n_components` principal directions
        X = X.dot(Vt[:n_components, :].T)

        A = np.matmul(X.T, X) / X.shape[0]  # Empirical feature covariance
        X_dagger = np.linalg.pinv(X)
        phi = np.matmul(X_dagger, y)
        return phi, A

    def create_dataset(groups):
        """Build the per-group arrays from an iterable of ``(key, frame)`` pairs.

        Each group is run through ``preprocess`` and ``create_subpop``. Groups
        whose covariance matrix is not full rank are skipped.

        Returns
        -------
        (numpy.ndarray, numpy.ndarray, numpy.ndarray)
            Stacked ``phi`` vectors, stacked covariance matrices and the keys
            of the groups that were kept, in iteration order.
        """
        kept_phi, kept_cov, kept_keys = [], [], []
        for key, frame in groups:
            features, target = preprocess(frame)
            phi, cov = create_subpop(np.array(features), np.array(target))

            # Skip groups with a rank-deficient covariance matrix.
            is_full_rank = np.linalg.matrix_rank(cov) == min(cov.shape)
            if not is_full_rank:
                continue

            kept_cov.append(cov)
            kept_phi.append(phi)
            kept_keys.append(key)

        return np.array(kept_phi), np.array(kept_cov), np.array(kept_keys)

    # Build the arrays from the grouped survey data ...
    Phi, A_all, demographics = create_dataset(groups)
    print(Phi.shape)

    # ... and cache them on disk so later runs can skip the rebuild.
    for _cache_path, _cache_array in (
        (filename_A, A_all),
        (filename_phi, Phi),
        # (filename_size, array_sizes),
        (filename_demo, demographics),
    ):
        np.save(_cache_path, _cache_array)


# Numbers of centers to evaluate. Alternatives tried previously:
# X = [5, 6, 7, 8, 9, 10, 15, 20, 25, 50, 75, 100, 200, 300, 400, 500]
# X = [5, 10, 15, 20, 25, 30, 35, 40, 45, 50]
# X = [10, 20, 30, 40]
X = list(range(5, 51))

# Columns 1 and 2 of `demographics` hold the second and third component of
# each group key.
sex = demographics[:,1].reshape(-1)
race = demographics[:,2].reshape(-1)
num_trials = 50

from collections import Counter

# Groups whose race code is listed in `values_set` are dropped.
values_set = [4, 5, 7]
indices_to_remove = [i for i, value in enumerate(race) if value in values_set]
filtered_arr = [value for value in race if value not in values_set]
print(len(indices_to_remove))

print(A_all.shape)
print(Phi.shape)
print(race.shape)

# Drop the selected groups from all three arrays with one boolean mask.
# The previous per-element `i not in indices_to_remove` test scanned a list
# for every group (quadratic) and was repeated for each array.
keep_mask = np.ones(len(race), dtype=bool)
keep_mask[np.asarray(indices_to_remove, dtype=int)] = False
race = race[keep_mask]
A_all = A_all[keep_mask]
Phi = Phi[keep_mask]
categories = race

print(A_all.shape)
print(Phi.shape)
print(race.shape)
def count_categories(original_list):
    """Return, for every element, how often its value occurs in the input.

    Prints the input and the per-value tally as a side effect.

    Parameters
    ----------
    original_list : iterable of hashable values

    Returns
    -------
    numpy.ndarray
        Same length as ``original_list``; entry ``i`` is the number of
        occurrences of ``original_list[i]``.
    """
    print(original_list)
    # Tally the occurrences of each distinct value.
    category_counts = Counter(original_list)
    print(category_counts)
    # Look up the tally of each element, keeping the input order.
    per_element = list(map(category_counts.__getitem__, original_list))
    return np.array(per_element)
# Per-group weight: number of kept groups sharing that group's race code.
race_size = count_categories(race)
print(race_size)

# np.save() does not create missing directories; without this the run would
# only fail at the first save, after all the work for that `m` is done.
os.makedirs("logs", exist_ok=True)

# (result key, printed label, initialiser). The initialisers are not defined
# in this file (presumably they come from the `clustering_utils` star import).
# They are listed in the order they are called, which is kept in case they
# draw from shared random state.
_methods = [
    ("acquire", "AcQUIre", initialize_multiple),
    ("random", "Uniform", initialize_multiple_random_users),
    ("greedy", "Greedy", initialize_multiple_greedy),
    ("epsilon_greedy", "Epsilon Greedy", initialize_multiple_epsilon_greedy),
]
_labels = {key: label for key, label, _ in _methods}
# (key prefix, label prefix, `sizes` argument): unweighted run, then the
# "fair" run weighted by `race_size`.
_variants = [("", "", None), ("fair_", "Fair ", race_size)]
# Order in which results are reported / stored.
_avg_report_order = ["acquire", "greedy", "epsilon_greedy", "random"]
_result_order = ["random", "acquire", "greedy", "epsilon_greedy"]

for m in X:
    print("Number of centers : ", m)

    # All unweighted initialisations first, then all weighted ones.
    thetas = {}
    for key_prefix, _, sizes in _variants:
        for key, _, initialise in _methods:
            thetas[key_prefix + key] = initialise(num_trials, Phi, A_all, m, sizes=sizes)

    print("Average Stats")
    losses = {}
    for key in _avg_report_order:
        for key_prefix, label_prefix, _ in _variants:
            name = key_prefix + key
            losses[name] = losses_across_trials(thetas[name], Phi, A_all)
            print(label_prefix + _labels[key] + " : ", np.mean(losses[name]))

    print("Fairness Stats")
    fair_objectives = {}
    for key in _result_order:
        for key_prefix, label_prefix, _ in _variants:
            name = key_prefix + key
            fair_objectives[name] = calculate_fair_objective_multiple(losses[name], categories)
            print(label_prefix + _labels[key] + " : ", np.mean(fair_objectives[name]))

    # Same key order as before: the four unweighted entries, then the four
    # "fair_" entries.
    avg_dict = {
        key_prefix + key: losses[key_prefix + key]
        for key_prefix, _, _ in _variants
        for key in _result_order
    }
    fair_dict = {
        key_prefix + key: fair_objectives[key_prefix + key]
        for key_prefix, _, _ in _variants
        for key in _result_order
    }

    # save dictionary to npy file
    np.save("logs/icml_avg_baseline_{}.npy".format(m), avg_dict)
    np.save("logs/icml_fair_baseline_{}.npy".format(m), fair_dict)

