import rpy2.robjects as ro
from rpy2.robjects import r as R
import rpy2.robjects.numpy2ri
import os
import numpy as np
import itertools
import multiprocessing as mp
from tqdm import tqdm
import random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from main import find_infty_error, find_MLE_DAC, DP_invert
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt



def _stop_at_one(cdfs):
    """Return a copy of *cdfs* held constant once the total mass exceeds 1.

    Rows of ``cdfs`` are summed across axis 1 (presumably one row per MLE
    support point and one column per category -- confirm against
    ``find_MLE_DAC``). Every row after the last row whose sum is at most 1 is
    overwritten with that row. When no row qualifies, every row is set to 0
    (the value of a CDF before its first support point) instead of wrapping
    around to the last row through index ``-1``.
    """
    capped = np.copy(cdfs)
    within_mass = np.flatnonzero(capped.sum(axis=1) <= 1)
    if within_mass.size == 0:
        capped[...] = 0
    else:
        last_ok = within_mass[-1]
        capped[last_ok + 1:] = capped[last_ok]
    return capped


def experiment(df, ratio, truth_cdf, ep, DAC_n, n=None):
    """Run one censored, locally-private experiment and score two corrections.

    A random subsample of ``n`` rows is drawn from ``df``. Each row reports
    its ``race_code`` only when a uniform "question" value is at least its
    ``standard_salary`` (otherwise the response is 0). Each response is then
    kept with probability ``r = 1 - exp(-ep)`` and replaced by 0 otherwise.
    The DAC MLE is fitted to these responses and inverted with ``DP_invert``.
    Its error against ``truth_cdf`` (as computed by ``find_infty_error``) is
    measured after two corrections:

    * oracle: elementwise clipping at the true ``ratio``;
    * stop-at-one: freezing the estimate once its summed mass exceeds 1.

    Args:
        df: DataFrame with ``race_code`` and ``standard_salary`` columns.
        ratio: True per-category proportions; used for the oracle clip and
            passed to ``find_infty_error``.
        truth_cdf: True CDFs, passed to ``find_infty_error``.
        ep: Privacy parameter epsilon.
        DAC_n: Number of DAC splits, passed to ``find_MLE_DAC``.
        n: Sample size; defaults to ``len(df)`` (a full shuffle of ``df``).

    Returns:
        Tuple ``(n, ep, DAC_n, oracle_error, stop_at_one_error)``.

    Raises:
        ValueError: If ``n`` is larger than ``len(df)``.
    """
    r = 1 - np.exp(-ep)  # ACRR retention probability
    # Reseed from OS entropy on every call. Pool workers forked from the same
    # parent would otherwise start from identical RNG states.
    random.seed(int.from_bytes(os.urandom(16), 'big'))
    np.random.seed(int.from_bytes(os.urandom(4), 'big'))

    if n is None:
        n = len(df)
    if n > len(df):
        raise ValueError(f"n={n} exceeds the number of rows in df ({len(df)})")
    # With random_state=None pandas samples from NumPy's global RNG, so this
    # draw is governed by the seed set above.
    df = df.sample(n)

    questions = np.random.rand(n)
    responses = np.array(df['race_code'])
    responses[questions < np.array(df['standard_salary'])] = 0  # Censor
    # Randomised response: keep each answer with probability r, else report 0.
    rnd_for_DP = np.random.rand(n)
    responses = np.where(rnd_for_DP < r, responses, 0)

    MLE_support, MLE_cdfs = find_MLE_DAC(questions, responses, DAC_n)
    MLE_cdfs_no_correction = DP_invert(MLE_cdfs, r)

    MLE_cdfs_oracle_clip = np.minimum(MLE_cdfs_no_correction, ratio)
    _, oracle_error = find_infty_error(MLE_support, MLE_cdfs_oracle_clip, truth_cdf, ratio, False)

    MLE_cdfs_stop_at_1 = _stop_at_one(MLE_cdfs_no_correction)
    _, stop_at_one_error = find_infty_error(MLE_support, MLE_cdfs_stop_at_1, truth_cdf, ratio, False)

    return n, ep, DAC_n, float(oracle_error), float(stop_at_one_error)


if __name__ == "__main__":
    # Enable automatic NumPy <-> R conversion and load the R routines once,
    # before any data work.
    rpy2.robjects.numpy2ri.activate()
    R.source("icm.R")  # Load R source script

    # 1. Read CSV and keep relevant columns
    gov_census = pd.read_csv('gov_census.csv', usecols=['race', 'salary'])

    # 2. Normalise race labels to lower case.
    gov_census['race'] = gov_census['race'].str.lower()
    # Disabled option: collapse races outside ['black', 'white', 'asian'] into 'other'.
    # gov_census.loc[~gov_census['race'].isin(['black', 'white', 'asian']), 'race'] = 'other'

    # Keep rows that have a race label and a salary below 200k (a missing
    # salary fails the comparison). Unlabelled rows would otherwise get a NaN
    # race_code and make the per-race ratios below sum to less than 1.
    # .copy() gives an independent frame so later column assignments do not
    # act on a slice (avoids pandas' SettingWithCopyWarning).
    keep = gov_census['race'].notna() & (gov_census['salary'] < 200000)
    gov_census = gov_census[keep].copy()

    # 3. Find max and min salary
    max_salary = gov_census['salary'].max()
    min_salary = gov_census['salary'].min()
    print("Max salary:", max_salary)
    print("Min salary:", min_salary)

    # 4. Create standard_salary column: min-max scaled salary in [0, 1]
    gov_census['standard_salary'] = (gov_census['salary'] - min_salary) / (max_salary - min_salary)

    # 5. Count rows per race. Note: total_races is the number of labelled
    # rows, not the number of distinct races.
    race_counts = gov_census['race'].value_counts()
    total_races = race_counts.sum()
    print("Total races:", total_races)

    # Assign race codes 1..K in value_counts order (most frequent first).
    # Code 0 is left free because experiment() uses 0 for "no response".
    unique_races = race_counts.index.tolist()
    race_to_code = {race: code for code, race in enumerate(unique_races, start=1)}
    code_to_race = {code: race for race, code in race_to_code.items()}
    gov_census['race_code'] = gov_census['race'].map(race_to_code)

    # 6. Compute empirical CDF functions and population ratios per race,
    # ordered by race code.
    gov_census_truth_cdf = []
    gov_census_ratio = []
    total_samples = len(gov_census)

    def make_cdf_function(values):
        """Return the empirical CDF x -> P(value <= x) of *values* (vectorised over x)."""
        values_sorted = np.sort(values)

        def cdf_func(x):
            return np.searchsorted(values_sorted, x, side='right') / len(values_sorted)

        return cdf_func

    for race in race_to_code:
        race_values = gov_census.loc[gov_census['race'] == race, 'standard_salary'].values
        gov_census_truth_cdf.append(make_cdf_function(race_values))
        gov_census_ratio.append(len(race_values) / total_samples)

    # 7. Plot CDFs
    plt.figure(figsize=(10, 6))
    x_vals = np.linspace(0, 1, 500)
    for code, cdf_func in enumerate(gov_census_truth_cdf, start=1):
        plt.plot(x_vals, cdf_func(x_vals), label=code_to_race[code])
    plt.xlabel('Standardized Salary')
    plt.ylabel('Empirical CDF')
    plt.title('Empirical CDF of Standardized Salary by Race')
    plt.legend()
    plt.grid(True)
    plt.show()

    # Parameter settings for the experiment
    sample_sizes = [5000]
    eps = [1, 2, 3]  # Privacy levels
    DAC_candidate = [4]  # Number of DAC splits
    repeat = 100  # Repetitions

    # Generate task list: one task per (repetition, n, ep, DAC_n), shuffled
    # before dispatch.
    tasks = list(itertools.product(range(repeat), sample_sizes, eps, DAC_candidate))
    random.shuffle(tasks)

    def run_task(task):
        """Run one experiment for a (repetition, n, ep, DAC_n) task tuple."""
        _, s, ep, DAC_n = task
        return experiment(gov_census, gov_census_ratio, gov_census_truth_cdf, ep=ep, DAC_n=DAC_n, n=s)

    # Run experiments in parallel. run_task, the data and the CDF closures
    # are all defined inside this __main__ guard, so workers can only see them
    # when they are forked from this process. spawn/forkserver workers (the
    # default on macOS since Python 3.8 and on Linux since 3.14) re-import the
    # module without running this block and cannot find run_task.
    with mp.get_context("fork").Pool() as pool:
        results = list(tqdm(pool.imap_unordered(run_task, tasks, chunksize=1), total=len(tasks)))

    # Columns: n, ep, DAC_n, oracle_error, stop_at_one_error
    results = np.array(results)
    np.save("experimentReal_sm.npy", results)  # Save output to disk