import multiprocessing as mp
from tqdm import tqdm
import numpy as np, pandas as pd
from sklearn.model_selection import train_test_split
from supplementary.util import clt_confidence_interval
from supplementary.util import (
    conformal_interval_mean,
    power_interval_mean,
    ppipp_interval_mean,
    fab_interval_mean,
    classical_interval_mean,
)
import xgboost as xgb
import os

# Load the phishing dataset from a local .npz archive (nothing is downloaded).
# Both the dataset path and the result path are relative to the current
# working directory.
dataset_name = "phishing_dataset_numeric"
result_file = f"results/mean_estimation_{dataset_name}.csv"

# np.load on an .npz archive returns an NpzFile that keeps the underlying file
# open; the context manager closes it once the arrays have been read.
# `data` stays bound afterwards, but refers to the closed archive.
with np.load(f"supplementary/datasets/{dataset_name}.npz") as data:
    X_full = data["X"]
    Y_full = data["y"].ravel()

true_theta = Y_full.mean()  # we are interested in the prevalence of phishing domains
full_size = Y_full.shape[0]  # total number of samples in the dataset

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------

# Repetitions and data budget
N_SEEDS = 100  # how many seeds the experiment is repeated for
TRAIN_RATIO = 0.995  # fraction of all samples placed in the training set
CAL_SIZE = 300  # number of samples in the calibration set

# Interval construction
ALPHA = 0.01  # significance level of the confidence intervals for the mean
CI_CONSTRUCTOR = clt_confidence_interval  # routine used to build each confidence interval
ERR_MULT = 1.01  # multiplier for the error of the conformal prediction intervals
M = 1.0  # width of the target's support, sup(y) - inf(y); y is binary, so M = 1

BIG_M = 1e99  # large constant for conformal scores

def width(arr):
    """Return the width of an interval given as a sequence of bounds.

    The width is the last element minus the first one. An empty sequence
    has width 0.
    """
    return (arr[-1] - arr[0]) if len(arr) else 0


def cover(arr):
    """Return whether the interval ``arr`` contains the global ``true_theta``.

    The lower bound is ``arr[0]`` and the upper bound is ``arr[-1]``, the
    same convention ``width`` uses. Both bounds are inclusive.
    """
    return arr[0] <= true_theta <= arr[-1]


def run(seed):
    """Run one repetition of the mean-estimation experiment.

    The data is split into training, calibration and test sets using ``seed``,
    an XGBoost classifier is fitted on the training set, and five confidence
    intervals for the mean of the target are built (CPPI, PPI, CCI,
    PPI++ with split, FAB) through the helpers imported from
    ``supplementary.util``.

    Args:
        seed: Random state used for both data splits.

    Returns:
        dict: One result row holding the seed, the calibration and test set
        sizes, the lower/upper bound, width and coverage flag of every
        interval, three width ratios (``CPPI_PPI``, ``CPPI_CCI``, ``PPI_CCI``;
        ``nan`` when the denominator is zero) and ``true_theta``.

    Raises:
        ValueError: If ``CAL_SIZE`` exceeds the number of samples left after
            the training split. ``train_test_split`` itself rejects a
            calibration size equal to that number, because the test set
            would be empty.
    """
    # Split our data into training, calibration and test sets.
    # The training set simulates already having a classifier for phishing domains.
    X_tr, X_caltest, Y_tr, Y_caltest = train_test_split(
        X_full, Y_full, test_size=1 - TRAIN_RATIO, random_state=seed
    )

    max_cal = Y_caltest.shape[0]
    if CAL_SIZE > max_cal:
        raise ValueError(f"cal_size {CAL_SIZE} > restante disponível {max_cal}")

    # Split the rest of the data into calibration and test sets.
    # For the sake of the experiment, the labels of the test set are dropped.
    # Each method uses the calibration and test sets differently.
    #
    # The calibration size is passed as an integer count. A float ratio
    # (1 - CAL_SIZE / max_cal) is rounded up by train_test_split, so
    # floating-point round-off could leave CAL_SIZE - 1 calibration points
    # while `err` and the reported `cal_size` still assume CAL_SIZE.
    X_cal, X_test, Y_cal, Y_test = train_test_split(
        X_caltest, Y_caltest, train_size=CAL_SIZE, random_state=seed
    )

    # Train the model on the training set.
    model = xgb.XGBClassifier(eval_metric="logloss")
    model.fit(X_tr, Y_tr)

    # Create imputed data for the calibration and test sets using the
    # predictive model (column 1 of predict_proba).
    Yhat_cal = model.predict_proba(X_cal)[:, 1]
    Yhat_test = model.predict_proba(X_test)[:, 1]

    # Conformal scores for the calibration set, where the labels are known:
    # the prediction itself when the label is 0, its complement otherwise.
    scores_cal = np.where(Y_cal == 0, Yhat_cal, 1 - Yhat_cal)

    # Error level wanted for the conformal prediction intervals.
    err = ERR_MULT / CAL_SIZE

    # Conformal power prediction interval.
    # We begin by calculating the threshold for the conformal scores; this
    # threshold is used to construct a conformal prediction model
    # C: X -> 2^{0,1}, which is where the labelled data is used in our method.
    # We then construct a confidence interval for both the mean of sup C(x)
    # and the mean of inf C(x) using the unlabelled data, and finally combine
    # both into a confidence interval for the mean of the target.
    CPPI = conformal_interval_mean(
        scores_cal=scores_cal,
        err=err,
        Yhat_test=Yhat_test,
        ci_constructor=CI_CONSTRUCTOR,
        alpha=ALPHA,
        M=M,
    )

    # Vanilla PPI instead uses the labelled data to construct a confidence
    # interval for the rectifier of the mean. The imputed data is then used to
    # construct a confidence interval for the mean on the unlabelled data, and
    # both are combined into a valid confidence interval for the mean of the
    # target.
    PPI = power_interval_mean(
        Y_cal=Y_cal,
        Yhat_cal=Yhat_cal,
        Yhat_test=Yhat_test,
        ci_constructor=CI_CONSTRUCTOR,
        alpha=ALPHA,
        M=M,
    )

    # Classical interval: only the labelled data is used.
    CCI = classical_interval_mean(
        Y_cal=Y_cal, ci_constructor=CI_CONSTRUCTOR, alpha=ALPHA, M=M
    )

    # PPI++ with a split for the tuning parameter (asymptotic).
    PPIPP_split = ppipp_interval_mean(
        Y_cal=Y_cal,
        Yhat_cal=Yhat_cal,
        Yhat_test=Yhat_test,
        ci_constructor=CI_CONSTRUCTOR,
        alpha=ALPHA,
        M=M,
        split=True,
    )

    # FAB
    FAB = fab_interval_mean(
        Y_cal=Y_cal,
        Yhat_cal=Yhat_cal,
        Yhat_test=Yhat_test,
        ci_constructor=CI_CONSTRUCTOR,
        alpha=ALPHA,
        M=M,
    )

    # Insertion order below fixes the column order of the result row.
    intervals = {
        "CPPI": CPPI,
        "PPI": PPI,
        "CCI": CCI,
        "PPIPP_split": PPIPP_split,
        "FAB": FAB,
    }
    widths = {name: width(ci) for name, ci in intervals.items()}

    row = {"seed": seed, "cal_size": CAL_SIZE, "test_size": Y_test.shape[0]}
    for name, ci in intervals.items():
        row[f"{name}_lower"] = ci[0]
        row[f"{name}_upper"] = ci[-1]
    for name, w in widths.items():
        row[f"{name}_width"] = w

    # Width ratios; nan when the denominator width is zero.
    row["CPPI_PPI"] = widths["CPPI"] / widths["PPI"] if widths["PPI"] else np.nan
    row["CPPI_CCI"] = widths["CPPI"] / widths["CCI"] if widths["CCI"] else np.nan
    row["PPI_CCI"] = widths["PPI"] / widths["CCI"] if widths["CCI"] else np.nan

    for name, ci in intervals.items():
        row[f"{name}_coverage"] = cover(ci)
    row["true_theta"] = true_theta
    return row


if __name__ == "__main__":

    # Create the output directory taken from `result_file`; exist_ok avoids
    # the check-then-create race of testing os.path.exists first.
    os.makedirs(os.path.dirname(result_file) or ".", exist_ok=True)

    n_workers = mp.cpu_count()
    # "fork" is not offered on every platform and set_start_method raises
    # ValueError for an unknown method, so it is only requested where it is
    # available. The sequential run below does not need it.
    if "fork" in mp.get_all_start_methods():
        mp.set_start_method("fork", force=True)

    # The seeds are processed sequentially. Parallel alternative:
    # with mp.Pool(processes=n_workers) as pool:
    #     results = list(tqdm(pool.imap(run, range(N_SEEDS)), total=N_SEEDS))
    results = [run(seed) for seed in tqdm(range(N_SEEDS))]

    # One row per seed. to_csv returns None when given a path, so nothing is
    # bound to its result.
    pd.DataFrame(results).to_csv(result_file, index=False)
