import os
import numpy as np  # type: ignore
from icecream import ic  # type: ignore
from ucimlrepo import fetch_ucirepo  # type: ignore
from sklearn.model_selection import train_test_split  # type: ignore
from sklearn.ensemble import HistGradientBoostingClassifier  # type: ignore
from sklearn.linear_model import LogisticRegression  # type: ignore
import pickle
from .evalue_common import *


# Number of stream positions that are monitored; realized losses and features
# are truncated to their first N entries before the e-processes are run.
N = 100_000

# Labelling budget as a fraction of the stream: 300 labelled positions out of N.
pi = 300 / N

# Stream positions (in increasing order) whose labels are revealed. The fixed
# seed keeps the subsample reproducible from run to run.
low_budget_subsample = np.random.default_rng(0).choice(
    N, size=round(pi * N), replace=False
)
low_budget_subsample.sort()
ic(low_budget_subsample.shape[0])

# Expensive part of the script: fit the models, simulate the clean and the
# poisoned deployment streams and roll out all four e-processes. The results
# are cached in a pickle, so re-running the script only redraws the figure.
# NOTE(review): the cache is not keyed on any parameter. Delete it after
# changing the experiment (N, TOL, the poisoning schedule, ...). Results cached
# before the poisoned-loss fix below are stale as well.
if not os.path.exists("datasets/data.fig2.pickle"):
    # UCI ML repository dataset 31 (its target column is "Cover_Type").
    dataset = fetch_ucirepo(id=31)
    ic(dataset.data.features)
    ic(dataset.data.targets)
    x_full = dataset.data.features.to_numpy()
    # Shift class labels down by one (presumably 1-based -> 0-based).
    y_full = dataset.data.targets["Cover_Type"].to_numpy() - 1
    ic(np.unique(y_full))

    # 60% / 40% split into (train + test) and the "later" deployment stream,
    # then 75% / 25% of the former into train and test.
    x_traintest, x_later, y_traintest, y_later = train_test_split(
        x_full, y_full, test_size=0.4, random_state=0
    )
    x_train, x_test, y_train, y_test = train_test_split(
        x_traintest, y_traintest, test_size=0.25, random_state=0
    )
    ic(len(y_later))

    model = RandomForestClassifier(random_state=0, n_jobs=-2).fit(x_train, y_train)
    # Predict the test set once. The mistakes are both the residual model's
    # training target and the basis of the reference risk.
    test_mistakes = model.predict(x_test) != y_test
    # Auxiliary classifier that estimates, from x alone, whether `model` errs.
    residual_model = RandomForestClassifier(random_state=0, n_jobs=-2).fit(
        x_test, test_mistakes
    )

    risk_at_test = ic(np.mean(test_mistakes))
    # Slack added to the test risk to form the candidate mean being tested.
    TOL = 0.05

    def enpoison(x, y):
        """Replace a growing share of the stream (x, y) with "problematic" rows.

        Problematic rows are those of (x, y) on which ``residual_model``
        predicts a mistake with probability >= 0.5. Row i sits at time
        t_i = linspace(0, len(y) / N, len(y))[i], so N rows span [0, 1]. From
        t = 0.2 on, each row is independently replaced, with probability
        ((t + 1) / 5 + 0.1) ** 2, by a problematic (x, y) pair resampled with
        replacement.

        Returns the poisoned ``(x, y)``, with the same shapes as the inputs.
        Raises ValueError if no row is flagged as problematic.
        """
        n_rows = len(y)
        # Evaluate the residual model once and reuse the mask for x and y.
        is_problematic = residual_model.predict_proba(x)[:, 1] >= 0.5
        if not is_problematic.any():
            raise ValueError(
                "residual_model flags no sample as problematic; cannot poison"
            )
        problematic_x = x[is_problematic]
        problematic_y = y[is_problematic]

        # NOTE(review): both draws below use a fresh default_rng(0), so the
        # resampled indices and the poisoning coin flips come from the same
        # random stream and are correlated. A single shared generator would
        # avoid this, but would change the generated data.
        resample_idx = np.random.default_rng(0).choice(
            len(problematic_y), size=n_rows, replace=True
        )

        # Use the argument's length (not a global) so the schedule always
        # matches the stream being poisoned.
        ts = np.linspace(0, n_rows / N, n_rows)
        poisoning_probs = np.where(ts >= 0.2, ((ts + 1) / 5 + 0.1) ** 2, 0)
        should_poison = (
            np.random.default_rng(0).random(size=n_rows) <= poisoning_probs
        )

        # Broadcast the per-row mask across all feature columns.
        poisoned_x = np.where(
            should_poison[:, None], problematic_x[resample_idx], x
        )
        poisoned_y = np.where(should_poison, problematic_y[resample_idx], y)
        return poisoned_x, poisoned_y

    x_later_poisoned, y_later_poisoned = enpoison(x_later, y_later)

    realized_losses_later = model.predict(x_later) != y_later
    # Poisoned rows carry a resampled (x, y) pair, so the loss must be computed
    # on the poisoned inputs. Pairing the clean x_later with the poisoned
    # labels would mismatch features and labels, and would make the losses
    # inconsistent with the x_later_poisoned features fed to the PPI e-process.
    realized_losses_later_poisoned = (
        model.predict(x_later_poisoned) != y_later_poisoned
    )

    def roll_out_both(losses, x):
        """Roll out the labelled-only and the conformal-PPI e-processes.

        Both use the candidate mean ``risk_at_test + TOL`` on the first ``N``
        stream positions. The labelled-only e-process sees the losses at
        ``low_budget_subsample`` only. The PPI variant additionally gets the
        features ``x`` and ``residual_model``.
        Returns ``(labelled_only, ppi)``.
        """
        labelled_only = rollout_eprocess(
            OneSidedMeanSingleEvalues(
                candidate_mean=risk_at_test + TOL,
                bettor=StandardAGrapaBettor(),
            ),
            losses[:N][low_budget_subsample],
        )
        ppi = rollout_eprocess_cppi(
            OneSidedMeanSingleEvalues(
                candidate_mean=risk_at_test + TOL,
                bettor=StandardEwmaAGrapaBettor(momentum=0.01),
            ),
            losses[:N],
            x[:N],
            regressor=RandomForest(),
            conformal_err=0.001,
            model=residual_model,
            labelled_samples=low_budget_subsample,
        )
        return labelled_only, ppi

    # Same call order as before: poisoned stream first, then the clean one.
    eprocess_poisoned, eprocess_poisoned_ppi = roll_out_both(
        realized_losses_later_poisoned, x_later_poisoned
    )
    ic(eprocess_poisoned)
    ic(eprocess_poisoned_ppi)

    eprocess, eprocess_ppi = roll_out_both(realized_losses_later, x_later)
    ic(eprocess)
    ic(eprocess_ppi)

    # Make sure the cache directory exists; otherwise all the work above would
    # be lost to a FileNotFoundError at the very last step.
    os.makedirs("datasets", exist_ok=True)
    with open("datasets/data.fig2.pickle", "wb") as file:
        pickle.dump(
            (
                eprocess,
                eprocess_poisoned,
                eprocess_ppi,
                eprocess_poisoned_ppi,
                realized_losses_later,
                realized_losses_later_poisoned,
            ),
            file,
        )
else:
    # The cache is written by this script; never point this at untrusted pickles.
    with open("datasets/data.fig2.pickle", "rb") as file2:
        (
            eprocess,
            eprocess_poisoned,
            eprocess_ppi,
            eprocess_poisoned_ppi,
            realized_losses_later,
            realized_losses_later_poisoned,
        ) = pickle.load(file2)


def fixup(array):
    """Clamp values from above at 1e150 so they remain finite and plottable.

    Entries greater than or equal to the cap (including ``inf``) are replaced
    by the cap. Everything else, NaN included, passes through unchanged. A new
    array is returned and the input is left untouched.
    """
    ceiling = 1e150
    exceeds = array >= ceiling
    return np.where(exceeds, ceiling, array)


# Cap the four e-processes so huge (or infinite) values stay plottable.
(
    eprocess,
    eprocess_poisoned,
    eprocess_ppi,
    eprocess_poisoned_ppi,
) = [
    fixup(trajectory)
    for trajectory in (
        eprocess,
        eprocess_poisoned,
        eprocess_ppi,
        eprocess_poisoned_ppi,
    )
]


def _smooth_losses(losses):
    # Exponentially weighted moving average (half-life of 60 positions) of the
    # first N realized losses.
    return pd.Series(losses[:N]).ewm(halflife=60).mean()


smoothed_samples = _smooth_losses(realized_losses_later)
smoothed_samples_poisoned = _smooth_losses(realized_losses_later_poisoned)

# Qualitative colour palette. Despite its name it holds 21 entries; the last
# two are greys.
PALETTE_20 = [
    "#e6194b", "#3cb44b", "#ffe119", "#4363d8", "#f58231",
    "#911eb4", "#46f0f0", "#f032e6", "#bcf60c", "#fabebe",
    "#008080", "#e6beff", "#9a6324", "#fffac8", "#800000",
    "#aaffc3", "#808000", "#ffd8b1", "#000075", "#808080",
    "#a9a9a9",
]

# Three accent colours taken from the palette: green, navy and grey.
COLORS = [PALETTE_20[i] for i in (1, -3, -1)]

# Colour roles used by the figure.
PPI_COLOR = "#469990"  # teal; not part of PALETTE_20
PPI_ACTIVE_COLOR = COLORS[1]
CLASSICAL_COLOR = IMPUTED_COLOR = COLORS[2]
EWMA_COLOR = "#4363d8"

# E-value level shown by the dashed horizontal lines. The dotted vertical lines
# mark where each poisoned-stream e-process first reaches it.
EVALUE_THRESHOLD = 20


def _time_axis(length):
    """Return ``length`` evenly spaced x positions covering [0, 1].

    Every curve is drawn against this grid, so markers placed on a curve use
    the same index -> x mapping and line up with it exactly.
    """
    return np.linspace(0, 1, length)


def _mark_first_crossing(ax, evalues, color, threshold=EVALUE_THRESHOLD):
    """Draw a dotted vertical line where ``evalues`` first reaches ``threshold``.

    Nothing is drawn if the threshold is never reached. ``np.argmax`` of an
    all-False mask returns 0, which would otherwise put a spurious line at the
    very start of the stream.
    """
    crossed = np.asarray(evalues) >= threshold
    if not crossed.any():
        return
    first_crossing = int(np.argmax(crossed))
    ax.axvline(
        _time_axis(len(crossed))[first_crossing],
        alpha=0.7,
        linestyle="dotted",
        color=color,
    )


fig, axs = plt.subplots(2, 2, figsize=(12, 4.1), sharex=True, height_ratios=(1, 1.5))

# Bottom row: smoothed realized loss on the clean (left) and poisoned (right)
# streams, with the labelled positions marked on the curve.
clean_time = _time_axis(len(smoothed_samples))
axs[1, 0].plot(
    clean_time,
    smoothed_samples,
    alpha=0.5,
    label="Exp. weighted moving average of loss",
    c=EWMA_COLOR,
)
axs[1, 0].scatter(
    clean_time[low_budget_subsample],
    smoothed_samples.iloc[low_budget_subsample],
    marker="x",
    color="black",
    label="Labelled data",
)
poisoned_time = _time_axis(len(smoothed_samples_poisoned))
axs[1, 1].plot(
    poisoned_time,
    smoothed_samples_poisoned,
    alpha=0.5,
    c=EWMA_COLOR,
)
axs[1, 1].scatter(
    poisoned_time[low_budget_subsample],
    smoothed_samples_poisoned.iloc[low_budget_subsample],
    marker="x",
    color="black",
)
for ax in axs[1]:
    ax.set_ylim(0.0, 0.3)
    ax.set_xlabel("Time", fontsize=12)
    ax.set_ylabel("Risk")

# Top-left: e-processes on the clean stream (null hypothesis true).
axs[0, 0].plot(
    _time_axis(len(eprocess_ppi)),
    eprocess_ppi,
    color=PPI_COLOR,
    label="Conformal PPI (Ours)",
)
axs[0, 0].plot(
    _time_axis(len(eprocess)),
    eprocess,
    color=CLASSICAL_COLOR,
    label="Only labelled samples",
)
axs[0, 0].set_title("Null hypothesis is true (e-values should be small)")

# Top-right: e-processes on the poisoned stream (null hypothesis false).
axs[0, 1].plot(
    _time_axis(len(eprocess_poisoned)),
    eprocess_poisoned,
    color=CLASSICAL_COLOR,
)
axs[0, 1].plot(
    _time_axis(len(eprocess_poisoned_ppi)),
    eprocess_poisoned_ppi,
    color=PPI_COLOR,
)
_mark_first_crossing(axs[0, 1], eprocess_poisoned, CLASSICAL_COLOR)
_mark_first_crossing(axs[0, 1], eprocess_poisoned_ppi, PPI_COLOR)
axs[0, 1].set_title("Null hypothesis is false (e-values should be large)")

# Shared setup of both e-value panels.
for ax in axs[0]:
    ax.set_yscale("symlog", linthresh=0.01)
    ax.set_ylabel("E-values")
    ax.axhline(EVALUE_THRESHOLD, linestyle="--", color="black", alpha=0.4)
    ax.set_ylim(-0.1, 1e20)

fig.legend(
    loc="lower center", bbox_to_anchor=(0.5, -0.06), ncols=4, frameon=False, fontsize=12
)

fig.tight_layout()

# Fixed ticks for the symlog e-value axes; the automatic ones are cluttered.
for ax in axs[0]:
    ax.set_yticks([0, 1, 1e4, 1e8, 1e12, 1e16, 1e20])
    ax.tick_params(axis="y", which="major", labelsize=8)

# NOTE(review): the figure is saved as "fig4" while the data cache is named
# "fig2"; confirm which figure number is intended.
os.makedirs("results", exist_ok=True)
plt.savefig("results/fig4.png", bbox_inches="tight", dpi=300)
