import rpy2.robjects as ro
from rpy2.robjects import r as R
import rpy2.robjects.numpy2ri
import os
import numpy as np
import itertools
import multiprocessing as mp
from tqdm import tqdm
import random

# Mixture proportions (weights) of the four components; they sum to 1.
ratio = np.array([0.2, 0.3, 0.3, 0.2])

# Inverse CDFs (quantile functions) of the components. Each maps a
# Uniform(0, 1) draw to a location in [0, 1] (inverse-CDF sampling).
truth_icdf = [
    lambda x: x,               # Uniform(0, 1)
    lambda x: x ** 4,
    lambda x: x ** (1 / 4),
    lambda x: x / 3 + 2 / 3,   # Uniform(2/3, 1)
]

# CDFs of the same components on [0, 1]; truth_cdf[i] is the inverse of
# truth_icdf[i] there.
truth_cdf = [
    lambda x: x,
    lambda x: x ** (1 / 4),
    lambda x: x ** 4,
    lambda x: np.where(x > 2 / 3, x * 3 - 2, 0),  # zero up to 2/3, then linear
]

# Fixed arguments shared by every experiment run
args_fixed = {
    "ratio": ratio,
    "truth_icdf": truth_icdf,
    "truth_cdf": truth_cdf
}

# Wrapper for calling the R function ComputeMLE.
# Returns the support points and the estimated per-label sub-distribution CDFs (one column per label).

def find_MLE(questions, responses):
    """Compute the MLE by calling the R function ``ComputeMLE``.

    Observations are ordered by question (ties broken by response) and
    handed to R as an ``(n, 2)`` matrix whose columns are question and
    response. ``ComputeMLE`` must already exist in the embedded R session;
    the script sources ``icm.R`` and activates numpy2ri conversion before
    calling this.

    Returns:
        ``(support, cdfs)``: ``support`` is column 0 of the first element of
        the R result, and ``cdfs`` holds its columns 1 through the
        second-to-last. NOTE(review): the meaning of the dropped last column
        is defined in icm.R; confirm there.
    """
    ordered_pairs = sorted(zip(questions, responses))
    sorted_q, sorted_r = zip(*ordered_pairs)
    mle_table = np.array(R.ComputeMLE(np.stack((sorted_q, sorted_r)).T)[0])
    return mle_table[:, 0], mle_table[:, 1:-1]

# Step function interpolation

def step_interp(x, xp, fp):
    """Evaluate a right-continuous step function at ``x``.

    The function takes value ``fp[i]`` on ``[xp[i], xp[i + 1])``. Points
    left of ``xp[0]`` get ``fp[0]`` and points at or beyond the last
    breakpoint get the last value. ``xp`` must be sorted ascending. If
    ``fp`` is 2-D, whole rows are selected.
    """
    query = np.asarray(x)
    breakpoints = np.asarray(xp)
    levels = np.asarray(fp)
    # Number of breakpoints <= each query point; minus one is the step index.
    positions = np.searchsorted(breakpoints, query, side='right')
    last = len(levels) - 1
    step_index = np.minimum(np.maximum(positions - 1, 0), last)
    return levels[step_index]

# Divide-and-conquer (DAC) MLE: fit on random disjoint subsets of the (already privatized) data and average the estimates

def find_MLE_DAC(questions, responses, DAC_n):
    """Divide-and-conquer (DAC) MLE: fit disjoint random groups, then average.

    The observations are shuffled with NumPy's global RNG and split into
    ``DAC_n`` disjoint groups of (nearly) equal size. Each group is fitted
    separately with :func:`find_MLE`, and the per-group step-function
    estimates are averaged on the union of the groups' support points.

    Args:
        questions: 1-D array-like of question values.
        responses: 1-D array-like of integer responses aligned with
            ``questions`` (0 = censored, ``1..K`` = labels).
        DAC_n: number of groups; ``1`` fits a single MLE on all data.

    Returns:
        ``(support, values)``: ``support`` is the sorted union of the group
        supports. ``values[j, k]`` is the mean over groups of each group's
        column-``k`` estimate, evaluated as a right-continuous step function
        at ``support[j]``.

    Raises:
        ValueError: if ``DAC_n < 1``, the inputs differ in length, or there
            are fewer observations than groups.
    """
    if DAC_n < 1:
        raise ValueError(f"DAC_n must be at least 1, got {DAC_n}")
    if DAC_n == 1:
        return find_MLE(questions, responses)

    questions = np.asarray(questions)
    responses = np.asarray(responses)
    if len(questions) != len(responses):
        raise ValueError("questions and responses must have the same length")
    if len(questions) < DAC_n:
        raise ValueError(
            f"need at least DAC_n={DAC_n} observations, got {len(questions)}"
        )

    # Shuffle so that each group is a random subset of the data.
    indices = np.arange(len(questions))
    np.random.shuffle(indices)
    questions = questions[indices]
    responses = responses[indices]

    # np.array_split yields exactly DAC_n groups and keeps every observation.
    # Slicing by a floor-divided group size would leave a remainder chunk that
    # is never fitted when len(questions) % DAC_n != 0. When the length divides
    # evenly, the groups are identical to that slicing.
    groups = zip(np.array_split(questions, DAC_n), np.array_split(responses, DAC_n))
    DACs_MLE = [find_MLE(group_q, group_r) for group_q, group_r in groups]

    # A group lacking some labels can return fewer columns than the largest
    # label K. Right-pad with zero columns so all groups share one width.
    # NOTE(review): assumes the missing columns are the trailing labels; confirm against icm.R.
    K = np.max(responses)
    width = max(K, max(group_cdfs.shape[1] for _, group_cdfs in DACs_MLE))
    DACs_MLE = [
        (group_support, np.pad(group_cdfs, ((0, 0), (0, width - group_cdfs.shape[1]))))
        for group_support, group_cdfs in DACs_MLE
    ]

    # Evaluate every group's step function on the pooled support and average.
    support = np.unique(np.concatenate([group_support for group_support, _ in DACs_MLE]))
    values = np.zeros((len(support), width))
    for k in range(width):
        values[:, k] = np.mean(
            [step_interp(support, group_support, group_cdfs[:, k])
             for group_support, group_cdfs in DACs_MLE],
            axis=0,
        )

    return support, values

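# Invert the ACRR mechanism's scaling: each response is kept with probability r
# (otherwise reported as censored), so the estimated CDFs are divided by r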

def DP_invert(cdfs, r):
    """Undo the ACRR shrinkage of estimated CDFs.

    Each response survives the mechanism with probability ``r`` and is
    otherwise reported as censored, so the estimates are divided by ``r``.
    When ``r == 1`` the input object itself is returned.
    """
    if r != 1:
        return cdfs / r
    return cdfs

# Compute error metrics against ground truth

def find_infty_error(support, cdfs, truth_cdf, ratio, relative=True):
    """Sup-norm error of estimated sub-distribution CDFs against the truth.

    Column ``i`` of ``cdfs`` is compared with ``truth_cdf[i](support) * ratio[i]``.

    Args:
        support: 1-D array of support points (one per row of ``cdfs``).
        cdfs: ``(len(support), len(truth_cdf))`` array of estimates.
        truth_cdf: list of true component CDFs, callable on arrays.
        ratio: mixture weights, one per component.
        relative: if True, divide each column's error by ``ratio[i]``.

    Returns:
        ``(error_on_support, error)``:
        - ``error_on_support`` is the maximum absolute (optionally relative)
          error at the support points.
        - ``error`` is the maximum of that and the error between the estimate
          at each support point and the truth at the previous support point.
          NOTE(review): that pairing corresponds to a left-continuous
          estimate, while step_interp treats estimates as right-continuous;
          confirm the intended convention.

        A term with nothing to compare (e.g. a single support point) is 0.0.

    Raises:
        ValueError: if ``cdfs`` is not 2-D with one column per true CDF.
    """
    cdfs = np.asarray(cdfs, dtype=float)
    support = np.asarray(support)
    if cdfs.ndim != 2 or cdfs.shape[1] != len(truth_cdf):
        raise ValueError(
            f"cdfs must have shape (n_support, {len(truth_cdf)}), got {cdfs.shape}"
        )

    # Use an explicit float buffer. np.empty_like(cdfs) would inherit an integer
    # dtype (truncating the truth) and could leave columns uninitialized.
    truth_on_support = np.empty(cdfs.shape, dtype=float)
    for i, cdf_i in enumerate(truth_cdf):
        truth_on_support[:, i] = cdf_i(support) * ratio[i]

    scale = np.asarray(ratio) if relative else 1.0
    # initial=0.0 makes empty comparisons (e.g. one support point) yield 0
    # instead of raising; NaNs still propagate.
    error_on_support = np.max(np.abs(cdfs - truth_on_support) / scale, initial=0.0)
    error_on_next_support = np.max(
        np.abs(cdfs[1:] - truth_on_support[:-1]) / scale, initial=0.0
    )
    return error_on_support, max(error_on_support, error_on_next_support)

# Run a single experiment with synthetic data

def _stop_growth_at_one(cdfs):
    """Return a copy of ``cdfs`` whose rows stop changing once total mass exceeds 1.

    Rows are support points and columns are labels. The last row whose
    row-sum is <= 1 is copied into every later row. If no row qualifies,
    every row is set to zero.
    """
    capped = np.copy(cdfs)
    admissible = np.flatnonzero(capped.sum(axis=1) <= 1)
    if admissible.size == 0:
        # Do not index with -1 here: that wraps around to the final,
        # unconstrained row and would copy it over every support point.
        capped[:] = 0
    else:
        last = admissible.max()
        capped[last + 1:] = capped[last]
    return capped


def experiment(n, ratio, truth_icdf, truth_cdf, ep, DAC_n):
    """Run one simulation: generate data, privatize with ACRR, fit, and score.

    Args:
        n: number of observations.
        ratio: mixture weights of the label components.
        truth_icdf: per-component inverse CDFs, used to sample locations.
        truth_cdf: per-component CDFs, used as ground truth for the error.
        ep: privacy parameter; the ACRR retention probability is
            ``r = 1 - exp(-ep)``.
        DAC_n: number of divide-and-conquer groups for the MLE.

    Returns:
        ``(n, ep, DAC_n, oracle_error, stop_at_one_error)``. Both errors are
        the second value returned by ``find_infty_error`` (relative mode) for
        the oracle-clipped and the stop-at-1 corrected estimates.
    """
    r = 1 - np.exp(-ep)  # ACRR retention probability

    # Reseed from OS entropy on every call; pool workers started by fork
    # would otherwise inherit identical RNG state from the parent.
    random.seed(int.from_bytes(os.urandom(16), 'big'))
    np.random.seed(int.from_bytes(os.urandom(4), 'big'))

    K = len(ratio)
    n_labels = np.random.multinomial(n, ratio)
    labels = []
    locations = []

    # Generate data from the mixture model via inverse-CDF sampling
    for i in range(K):
        labels += [i] * int(n_labels[i])
        locations += list(truth_icdf[i](np.random.rand(int(n_labels[i]))))
    locations = np.array(locations)
    labels = np.array(labels) + 1  # shift so that 0 can denote "censored"
    questions = np.random.rand(n)
    responses = labels.copy()
    responses[questions < locations] = 0  # label hidden when question < location

    # ACRR mechanism: keep each response with probability r, otherwise report 0
    rnd_for_DP = np.random.rand(n)
    responses = np.where(rnd_for_DP < r, responses, 0)

    # Estimate the MLE with DAC and undo the ACRR scaling
    MLE_support, MLE_cdfs = find_MLE_DAC(questions, responses, DAC_n)
    MLE_cdfs_no_correction = DP_invert(MLE_cdfs, r)

    # Oracle correction: cap each label's estimate at its true mixture weight
    MLE_cdfs_oracle_clip = np.minimum(MLE_cdfs_no_correction, ratio)
    _, oracle_error = find_infty_error(MLE_support, MLE_cdfs_oracle_clip, truth_cdf, ratio)

    # Data-driven correction: freeze the estimate once total mass exceeds 1
    MLE_cdfs_stop_at_1 = _stop_growth_at_one(MLE_cdfs_no_correction)
    _, stop_at_one_error = find_infty_error(MLE_support, MLE_cdfs_stop_at_1, truth_cdf, ratio)

    return n, ep, DAC_n, oracle_error, stop_at_one_error


def _init_worker():
    """Pool initializer: prepare the embedded R session in each worker.

    Under the spawn/forkserver start methods, workers do not inherit the
    parent's R state. This function therefore activates numpy-to-R
    conversion and sources icm.R, which is expected to provide the
    ``ComputeMLE`` function that find_MLE calls.
    """
    rpy2.robjects.numpy2ri.activate()
    R.source("icm.R")


def run_task(task):
    """Run one experiment for a ``(repeat_index, n, ep, DAC_n)`` task tuple.

    Defined at module level (not under the ``__main__`` guard) so that pool
    workers can unpickle it by reference under any start method.
    """
    _, s, ep, DAC_n = task
    return experiment(s, args_fixed["ratio"], args_fixed["truth_icdf"], args_fixed["truth_cdf"], ep=ep, DAC_n=DAC_n)


if __name__ == '__main__':
    # Activate numpy to R conversion and load the R script in the parent as
    # well, so a missing or broken icm.R fails before any worker starts.
    rpy2.robjects.numpy2ri.activate()
    R.source("icm.R")

    # Parameter settings for the experiment
    sample_sizes = [1000, 2000, 5000, 10000, 20000, 50000, 100000, 200000, 500000]
    eps = [1, 2, 3]  # Privacy levels
    DAC_candidate = [4]  # Number of DAC splits
    repeat = 100  # Repetitions

    # Generate the task list; shuffle it so long and short runs interleave
    tasks = list(itertools.product(range(repeat), sample_sizes, eps, DAC_candidate))
    random.shuffle(tasks)

    # Run experiments in parallel; each worker sets up its own R session
    with mp.Pool(initializer=_init_worker) as pool:
        results = list(tqdm(pool.imap_unordered(run_task, tasks, chunksize=1), total=len(tasks)))

    # Rows are (n, ep, DAC_n, oracle_error, stop_at_one_error), in completion order
    results = np.array(results)
    np.save("experiment.npy", results)  # Save output to disk
