import multiprocessing as mp
from tqdm import tqdm
import numpy as np, pandas as pd
from sklearn.model_selection import train_test_split
from supplementary.util import clt_confidence_interval
from supplementary.util import (
    conformal_interval_quantile,
    power_interval_quantile,
    ppipp_interval_quantile,
    fab_interval_quantile,
    classical_interval_quantile,
)
import os

# Load the (locally stored) gene expression dataset.
dataset_name = "gene_expression"
result_file = f"results/quantile_estimation_{dataset_name}.csv"

# We use the transformer model developed and trained by Vaishnav et al.
# to predict gene expression level: "Yhat" holds the model predictions and
# "Y" the observed labels.
# ``np.load`` on an ``.npz`` archive returns an ``NpzFile`` that keeps the
# underlying file open; the context manager closes it once both arrays have
# been read into memory.
with np.load(f"supplementary/datasets/{dataset_name}.npz") as data:
    Yhat_full = data["Yhat"].ravel()
    Y_full = data["Y"].ravel()
full_size = Y_full.shape[0]

# Define the range of possible thetas: an evenly spaced grid spanning the
# observed label range, shared by every interval method below.
N_THETAS = 120  # number of grid points (resolution of the theta grid)
thetas_global = np.linspace(Y_full.min(), Y_full.max(), N_THETAS)

# hyperparameters
Q = 0.5  # quantile to estimate, i.e. the median
ALPHA = 0.01  # significance level for the confidence intervals of the median
ERR_MULT = 1.01  # conformal error level is err = ERR_MULT / CAL_SIZE
CI_CONSTRUCTOR = clt_confidence_interval  # confidence interval used
N_SEEDS = 100  # number of seeds (replicates) for the experiment
CAL_SIZE = 8  # number of labelled points in the calibration set
# Range of \psi(y), i.e. sup(\psi(y)) - inf(\psi(y)). For quantile estimation
# \psi(y) = 1{y <= theta} - Q takes values in {-Q, 1 - Q}, so M = 1.
M = 1.0

true_theta = np.quantile(Y_full, Q)  # true Q-quantile (median) of the full data

# NOTE(review): not referenced anywhere else in this file; possibly a leftover.
BIG_M = 1e99


def width(arr):
    """Return the width ``arr[-1] - arr[0]`` of an interval array, or 0 if empty.

    Assumes ``arr`` is sorted ascending (e.g. a subset of the theta grid).
    """
    return (arr[-1] - arr[0]) if len(arr) else 0


def cover(arr, theta=None):
    """Return True if ``theta`` lies in the closed interval ``[arr[0], arr[-1]]``.

    Parameters
    ----------
    arr : sequence of float
        Interval array, assumed sorted ascending.
    theta : float, optional
        Value to test; defaults to the module-level ``true_theta``.

    Returns
    -------
    bool
        False for an empty interval, which cannot contain any value
        (consistent with ``width`` returning 0 for it).
    """
    if len(arr) == 0:
        return False
    if theta is None:
        theta = true_theta
    return bool((arr[0] <= theta) & (theta <= arr[-1]))


def _interval_bounds(interval):
    """Return ``(lower, upper)`` of an interval array, or ``(nan, nan)`` if empty."""
    if len(interval) == 0:
        return np.nan, np.nan
    return interval[0], interval[-1]


def run(seed):
    """Run one replicate of the quantile-estimation experiment.

    Splits the data into exactly ``CAL_SIZE`` labelled calibration points and a
    test set (the remainder). It then builds a confidence interval for the
    ``Q``-quantile with each method (CPPI, PPI, classical, PPI++ with split,
    FAB) and summarises the results.

    Parameters
    ----------
    seed : int
        Random state passed to ``train_test_split``.

    Returns
    -------
    dict
        One result row containing the split sizes, and for every method its
        lower/upper bound, width and coverage indicator, plus the width ratios
        CPPI/PPI, CPPI/CCI, PPI/CCI and ``true_theta``. Empty intervals get NaN
        bounds and a coverage of False.
    """
    # Since we already have a predictive model, we only split our data into
    # calibration and test sets. The calibration size is passed as an absolute
    # integer: a float ratio such as ``1 - CAL_SIZE / n`` is turned into a
    # count with ``ceil`` by scikit-learn, so floating-point error could
    # silently leave fewer than CAL_SIZE calibration points.
    Yhat_cal, Yhat_test, Y_cal, Y_test = train_test_split(
        Yhat_full, Y_full, train_size=CAL_SIZE, random_state=seed
    )

    # Error level for the conformal prediction intervals (len(Y_cal) == CAL_SIZE)
    err = ERR_MULT / Y_cal.shape[0]

    # Define quantile estimation as a Z-estimation problem:
    # psi(y, theta) = 1{y <= theta} - Q, whose root in theta is the Q-quantile.
    def ind(y, theta):
        return (y - theta <= 0).astype(int)

    def psi(y, theta):
        return ind(y, theta) - Q

    # Conformal scores (absolute residuals) on the calibration set
    scores_cal = np.abs(Y_cal - Yhat_cal)

    # Conformal power prediction intervals for the median
    CPPI = conformal_interval_quantile(
        psi=psi,
        scores_cal=scores_cal,
        err=err,
        Yhat_test=Yhat_test,
        thetas=thetas_global,
        ci_constructor=CI_CONSTRUCTOR,
        alpha=ALPHA,
        M=M,
    )

    # Vanilla power prediction intervals for the median
    PPI = power_interval_quantile(
        psi=psi,
        Y_cal=Y_cal,
        Yhat_cal=Yhat_cal,
        Yhat_test=Yhat_test,
        thetas=thetas_global,
        ci_constructor=CI_CONSTRUCTOR,
        alpha=ALPHA,
        M=M,
    )

    # Only labelled data: classical confidence interval for the median
    CCI = classical_interval_quantile(
        psi=psi,
        Y_cal=Y_cal,
        thetas=thetas_global,
        ci_constructor=CI_CONSTRUCTOR,
        alpha=ALPHA,
        M=M,
    )

    # PPI++ with a data split for the tuning parameter (asymptotic)
    PPIPP_split = ppipp_interval_quantile(
        psi=psi,
        Y_cal=Y_cal,
        Yhat_cal=Yhat_cal,
        Yhat_test=Yhat_test,
        thetas=thetas_global,
        ci_constructor=CI_CONSTRUCTOR,
        alpha=ALPHA,
        M=M,
        split=True,
    )

    # FAB
    FAB = fab_interval_quantile(
        q=Q,
        Y_cal=Y_cal,
        Yhat_cal=Yhat_cal,
        Yhat_test=Yhat_test,
        thetas=thetas_global,
        ci_constructor=CI_CONSTRUCTOR,
        alpha=ALPHA,
        M=M,
    )

    # Insertion order here fixes the column order of the output CSV.
    intervals = {
        "CPPI": CPPI,
        "PPI": PPI,
        "CCI": CCI,
        "PPIPP_split": PPIPP_split,
        "FAB": FAB,
    }
    widths = {name: width(arr) for name, arr in intervals.items()}

    row = dict(seed=seed, cal_size=CAL_SIZE, test_size=Y_test.shape[0])
    for name, arr in intervals.items():
        row[f"{name}_lower"], row[f"{name}_upper"] = _interval_bounds(arr)
    for name in intervals:
        row[f"{name}_width"] = widths[name]
    # Width ratios; NaN when the reference interval has zero width.
    for num, den in (("CPPI", "PPI"), ("CPPI", "CCI"), ("PPI", "CCI")):
        row[f"{num}_{den}"] = widths[num] / widths[den] if widths[den] else np.nan
    for name, arr in intervals.items():
        # An empty interval cannot contain the true parameter.
        row[f"{name}_coverage"] = cover(arr) if len(arr) else False
    row["true_theta"] = true_theta
    return row


if __name__ == "__main__":

    # Create the output directory ("results/") if it does not exist yet;
    # exist_ok avoids a check-then-create race.
    os.makedirs(os.path.dirname(result_file) or ".", exist_ok=True)

    # Replicates run sequentially, one per seed. (No start-method override is
    # needed since no process pool is used; forcing "fork" would raise on
    # platforms that lack it, e.g. Windows.)
    results = [run(seed) for seed in tqdm(range(N_SEEDS))]

    # DataFrame.to_csv writes the file and returns None, so nothing is kept.
    pd.DataFrame(results).to_csv(result_file, index=False)
