import os

import numpy as np
import pandas as pd
from xgboost import XGBClassifier
from icecream import ic
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt
from supplementary.util import clt_confidence_interval
from scipy.special import softmax


# fetch dataset
df = pd.read_csv("supplementary/datasets/Thyroid_Diff.csv")

# Binary target: True when the cancer recurred.
Y = df["Recurred"] == "Yes"
X = df.drop("Recurred", axis=1)
ic(X.shape)
# Integer-encode the non-numeric (categorical) columns for XGBoost. Numeric
# columns (presumably just "Age" in this CSV) are left as-is: turning them into
# category codes would replace ages in years with rank indices, while the
# age-group analysis further down selects patients by age in years.
for column in X.columns:
    if pd.api.types.is_numeric_dtype(X[column]):
        continue
    X[column] = X[column].astype("category").cat.codes

X.info()
ic(Y)

# 40% of rows go to training + calibration, 60% to the test set. Below, Y_test
# is only used for the accuracy / ROC AUC diagnostics.
X_traincal, X_test, Y_traincal, Y_test = train_test_split(
    X, Y, test_size=0.6, random_state=0
)
# Of the training + calibration rows, 70% fit the model and 30% calibrate the
# conformal threshold.
X_train, X_cal, Y_train, Y_cal = train_test_split(
    X_traincal, Y_traincal, test_size=0.3, random_state=1
)

# Debug output: training frame and the size of each split.
ic(X_train)
ic(X_train.shape[0])
ic(X_cal.shape[0])
ic(X_test.shape[0])

# Train the model
model = XGBClassifier(random_state=0).fit(X_train, Y_train)
# Held-out diagnostics: test accuracy and ROC AUC of the class-1 probability.
ic(np.mean(model.predict(X_test) == Y_test))
ic(roc_auc_score(Y_test, model.predict_proba(X_test)[:, 1]))

# Conformal calibration.
# Score of a calibration point: minus the model's probability for its true
# label, so a label is accepted later when its score is <= threshold.
proba_cal = model.predict_proba(X_cal)
scores_cal = np.where(Y_cal == 1, -proba_cal[:, 1], -proba_cal[:, 0])
# Finite stand-in for +inf in the augmented score set.
BIG_M = 1e99
# Target miscoverage rate (alpha).
ERR = 0.025

# Split-conformal threshold: the ceil((n + 1) * (1 - ERR))-th smallest of the n
# calibration scores augmented with +inf. "inverted_cdf" returns exactly that
# order statistic; the default linear interpolation can fall between two scores
# below it and so slightly undercover.
threshold = np.quantile(np.append(scores_cal, BIG_M), 1 - ERR, method="inverted_cdf")
ic(threshold)


def get_qtilde(n, alpha, gamma, epsilon, m):
    """Return the inflated quantile level used by ``private_quantile``.

    The level is the sum of a finite-sample term
    ``(n + 1) * (1 - alpha) / (n * (1 - gamma * alpha))`` and a privacy slack
    ``2 / (epsilon * n) * log(m / (gamma * alpha))``, capped at ``1 - 1e-12``
    so that ``1 - qtilde`` stays positive.

    Args:
        n: number of calibration scores.
        alpha: target miscoverage rate.
        gamma: split parameter; expected in (0, 1).
        epsilon: privacy parameter.
        m: number of candidate bins.

    Returns:
        The capped quantile level.
    """
    coverage_term = (n + 1) * (1 - alpha) / (n * (1 - gamma * alpha))
    privacy_term = 2 / (epsilon * n) * np.log(m / (gamma * alpha))
    return min(coverage_term + privacy_term, 1 - 1e-12)


def private_quantile(scores, alpha, epsilon, gamma, bins, rng=None):
    """Sample a private upper quantile of ``scores`` from the grid ``bins``.

    Each score is rounded up to the next grid point (clipped to the last bin).
    Every candidate bin is then given the cost
    ``max(#scores below / qtilde, #scores above / (1 - qtilde))``, where
    ``qtilde = get_qtilde(...)``. One bin is drawn with probability
    ``softmax(-(epsilon * min(alpha, 1 - alpha) / 2) * cost)``.

    Args:
        scores: calibration scores; squeezed to 1-D.
        alpha: target miscoverage rate.
        epsilon: privacy parameter.
        gamma: split parameter forwarded to ``get_qtilde``.
        bins: 1-D grid of candidate thresholds (assumed increasing).
        rng: randomness for the draw. ``None`` (default) reproduces the
            original behaviour of a fresh ``np.random.default_rng(1)``.
            Otherwise an int seed or a ``np.random.Generator``.

    Returns:
        One element of ``bins``.

    Note:
        With a fixed seed the draw is a deterministic function of the data.
        That makes it reproducible, but it does not by itself give a
        differential-privacy guarantee; pass a Generator seeded from fresh
        entropy for that.
    """
    n = scores.shape[0]
    epsilon_normed = epsilon * min(alpha, 1 - alpha)
    n_bins = bins.shape[0]
    qtilde = get_qtilde(n, alpha, gamma, epsilon, n_bins)
    scores = scores.squeeze()
    # Round each score up onto the grid (scores past the last edge clip to it).
    binned_scores = bins[np.minimum(np.digitize(scores, bins), n_bins - 1)]
    w1 = np.clip(np.digitize(binned_scores, bins), 0, n_bins - 1)
    w2 = np.clip(np.digitize(binned_scores, bins, right=True), 0, n_bins - 1)
    lower_mass = np.bincount(w1, minlength=n_bins).cumsum() / qtilde
    upper_mass = (n - np.bincount(w2, minlength=n_bins).cumsum()) / (1 - qtilde)
    cost = np.maximum(lower_mass, upper_mass)
    sampling_probabilities = softmax(-(epsilon_normed / 2) * cost)
    # softmax already normalises; renormalise defensively so the vector sums
    # to 1 within Generator.choice's tolerance.
    sampling_probabilities = sampling_probabilities / sampling_probabilities.sum()
    generator = np.random.default_rng(1 if rng is None else rng)
    return generator.choice(bins, p=sampling_probabilities)


def get_optimal_gamma(scores, n, alpha, epsilon, bins=None):
    """Choose between two candidate ``gamma`` values for ``private_quantile``.

    The candidates are the roots of ``a * g**2 + b * g + c = 0`` with
    ``a = alpha**2``, ``b = -(alpha * epsilon * (n + 1) * (1 - alpha) / 2 + 2 * alpha)``
    and ``c = 1``, each clipped into ``[1e-12, 1 - 1e-12]``. A private
    quantile is drawn for each candidate, and the one giving the smaller
    threshold is returned.

    Args:
        scores: calibration scores.
        n: number of calibration scores.
        alpha: target miscoverage rate.
        epsilon: privacy parameter.
        bins: candidate thresholds. Defaults to ``np.linspace(-1, 0, 10)``,
            the grid this script uses for its negated-probability scores.

    Returns:
        ``(gamma, qhat)`` for the winning candidate.

    NOTE(review): both draws read ``scores``, so under standard composition
    this spends the privacy budget twice; confirm the accounting.
    """
    a = alpha**2
    b = -(alpha * epsilon * (n + 1) * (1 - alpha) / 2 + 2 * alpha)
    c = 1
    sqrt_disc = np.sqrt(b**2 - 4 * a * c)
    gamma1 = (-b + sqrt_disc) / (2 * a)
    gamma2 = (-b - sqrt_disc) / (2 * a)

    # get_qtilde needs gamma strictly inside (0, 1).
    gamma1 = min(max(gamma1, 1e-12), 1 - 1e-12)
    gamma2 = min(max(gamma2, 1e-12), 1 - 1e-12)

    if bins is None:
        bins = np.linspace(-1, 0, 10)
    q1 = private_quantile(scores, alpha, epsilon, gamma1, bins)
    q2 = private_quantile(scores, alpha, epsilon, gamma2, bins)

    return (gamma1, q1) if q1 < q2 else (gamma2, q2)


# get_optimal_gamma already returns private_quantile(scores_cal, ERR, 1, gamma,
# np.linspace(-1, 0, 10)) for the gamma it picks. Reuse that draw: calling
# private_quantile again would re-query the calibration scores (more privacy
# budget) only to reproduce the same value under the fixed seed.
optimal_gamma, private_threshold = get_optimal_gamma(
    scores_cal, len(scores_cal), ERR, 1
)

# Conformal PPI
# Test-set prediction sets: label y is included when -P(y | x) <= threshold.
# A set may hold both labels, one label, or neither ("empty").
proba_test = model.predict_proba(X_test)
contains1_test = -proba_test[:, 1] <= threshold
contains0_test = -proba_test[:, 0] <= threshold
empty = (~contains1_test) & (~contains0_test)
ic(np.mean(contains1_test & contains0_test))
ic(np.mean(contains1_test & (~contains0_test)))
ic(np.mean((~contains1_test) & contains0_test))
ic(np.mean(empty))

# The same prediction sets built from the private threshold.
private_contains1_test = -proba_test[:, 1] <= private_threshold
private_contains0_test = -proba_test[:, 0] <= private_threshold
private_empty = (~private_contains1_test) & (~private_contains0_test)

# Age brackets, selected as half-open ranges [lower, upper) of X_test["Age"].
AGE_GROUPS = [
    (0, 20),
    (20, 30),
    (30, 40),
    (40, 50),
]

# Codes of the encoded "Gender" column. pandas assigns category codes in
# sorted order, so presumably "F" -> 0 and "M" -> 1; TODO confirm against CSV.
MALE = 1
FEMALE = 0
ic(np.sum(X_test["Gender"] == MALE))
ic(np.sum(X_test["Gender"] == FEMALE))

# Multiplier on ERR added to each side of the prediction-powered interval.
M = 1


def _ppi_interval(predictions, empty_mask, contains0, contains1):
    """Return clipped (lower, upper) prediction-powered prevalence bounds.

    For the lower bound, each unit is imputed with the smallest label in its
    prediction set (0 if the set contains 0, else 1). For the upper bound it
    is the largest (1 if the set contains 1, else 0). Units with an empty set
    fall back to the model prediction. The first element of
    clt_confidence_interval on the lower imputations and the third element on
    the upper imputations are widened by M * ERR and clipped to [0, 1].
    (Assumes clt_confidence_interval returns (lower, centre, upper); confirm
    in supplementary.util.)
    """
    lower_mean, _, _ = clt_confidence_interval(
        np.where(empty_mask, predictions, np.where(contains0, 0.0, 1.0))
    )
    _, _, upper_mean = clt_confidence_interval(
        np.where(empty_mask, predictions, np.where(contains1, 1.0, 0.0))
    )
    return max(lower_mean - M * ERR, 0), min(upper_mean + M * ERR, 1)


results = {}
private_results = {}
for gender in [MALE, FEMALE, None]:
    if gender is None:
        label = "both genders"
    else:
        label = "male" if gender == MALE else "female"
    subresults = {}
    private_subresults = {}
    for age_group in AGE_GROUPS:
        age_lo, age_hi = age_group
        # Half-open bracket: an age on a boundary (e.g. 20) belongs to exactly
        # one group instead of being counted twice.
        selector = (age_lo <= X_test["Age"]) & (X_test["Age"] < age_hi)
        if gender is not None:
            selector = selector & (X_test["Gender"] == gender)

        if not selector.any():
            # No test patients in this cell; the model and CLT helpers would
            # only see an empty array, so record NaNs (matplotlib skips them).
            print(f"Age group {age_group}, {label}: no observations, skipped")
            subresults[age_group] = (np.nan, np.nan, np.nan)
            private_subresults[age_group] = (np.nan, np.nan, np.nan)
            continue

        predictions = model.predict(X_test[selector])
        imputed = np.mean(predictions)

        lower_bound, upper_bound = _ppi_interval(
            predictions,
            empty[selector],
            contains0_test[selector],
            contains1_test[selector],
        )
        private_lower_bound, private_upper_bound = _ppi_interval(
            predictions,
            private_empty[selector],
            private_contains0_test[selector],
            private_contains1_test[selector],
        )

        print(
            f"Age group {age_group}, {label}: [{lower_bound} .. {upper_bound}] {imputed} ({np.sum(selector)} observations)"
        )

        subresults[age_group] = lower_bound, imputed, upper_bound
        private_subresults[age_group] = (
            private_lower_bound,
            imputed,
            private_upper_bound,
        )

    results[gender] = subresults
    private_results[gender] = private_subresults

# Figure: for each gender and age bracket, a bar at the model-imputed
# prevalence plus a vertical line spanning the prediction-powered interval.
# The private estimate is drawn at x = pos2, the non-private one a unit to the
# right at x = pos.
GENDER_COLORS = {MALE: "#4363d8", FEMALE: "#e6194B"}

plt.figure(figsize=(10, 2.6))

for i, gender in enumerate([MALE, FEMALE]):
    color = GENDER_COLORS.get(gender, "#a9a9a9")
    for j, age_group in enumerate(AGE_GROUPS):
        # A stride of 4.6 slots per gender leaves a gap between gender groups.
        pos = 2.5 * (4.6 * i + j) + 1
        pos2 = 2.5 * (4.6 * i + j)

        plt.bar([pos], [results[gender][age_group][1]], alpha=0.2, color=color)
        plt.bar([pos2], [private_results[gender][age_group][1]], alpha=0.4, color=color)
        plt.plot(
            [pos, pos],
            [results[gender][age_group][0], results[gender][age_group][2]],
            color=color,
            alpha=0.6,
        )
        plt.plot(
            [pos2, pos2],
            [
                private_results[gender][age_group][0],
                private_results[gender][age_group][2],
            ],
            color=color,
        )

        # Bracket label centred under the pair of bars, just below y = 0.
        plt.text(
            (pos + pos2) / 2,
            -0.08,
            f"{age_group[0]}-{age_group[1]} yrs",
            horizontalalignment="center",
            alpha=0.8,
            size="small",
        )

# Gender headings centred over each group of four bracket pairs.
plt.text(
    2.5 * 1.5 + 0.5,
    0.97,
    "Male",
    horizontalalignment="center",
    color=GENDER_COLORS[MALE],
)
plt.text(
    4.6 * 2.5 + 2.5 * 1.5 + 0.5,
    0.97,
    "Female",
    horizontalalignment="center",
    color=GENDER_COLORS[FEMALE],
)

plt.tick_params(axis="x", which="both", bottom=False, top=False, labelbottom=False)

# Legend proxies: degenerate fills at y = -100, outside the y-limits set below,
# that exist only to provide the two legend entries.
plt.fill_between(
    [2, 3],
    [-100, -101],
    [-100, -101],
    color="k",
    alpha=0.4,
    label="Private Conformal Prediction-Powered CIs",
)
plt.fill_between(
    [2, 3],
    [-100, -101],
    [-100, -101],
    color="k",
    alpha=0.2,
    label="Non-private Conformal Prediction-Powered CIs",
)
plt.legend(loc="upper center", bbox_to_anchor=(0.5, 1.155), ncol=2, frameon=False, fontsize="small")

plt.ylabel("Prevalence")
plt.ylim(-0.11, 1.1)

# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs("results", exist_ok=True)

plt.savefig("results/fig3.png", dpi=300, bbox_inches="tight")
