import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from supplementary.util import clt_confidence_interval
from supplementary.util import (
    conformal_interval_quantile,
    power_interval_quantile,
    ppipp_interval_quantile,
    fab_interval_quantile,
    classical_interval_quantile,
)
import xgboost as xgb
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from icecream import ic


# ------------------------------------------
# 1) Data and helpers
# ------------------------------------------
# def generate_mock_data(n_samples=50000, noise_std=1.0, random_state=42, d=4):
#     rng = np.random.RandomState(random_state)
#     X = rng.normal(loc=0.0, scale=3.0, size=(n_samples, d))
#     c = rng.normal(loc=0.0, scale=1.0, size=(d,))
#     noise = rng.normal(loc=0.0, scale=noise_std, size=n_samples)
#     Y = X @ c + noise
#     return X, Y


def generate_mock_data(n_samples=50000, noise_std=1.0, random_state=1, d=5):
    """Draw a synthetic regression data set together with its oracle predictor.

    Covariates are i.i.d. Gaussian with mean 0 and standard deviation 10 in
    each of the ``d`` coordinates. The response is
    ``Y = phi(X) @ c + noise`` where ``phi`` is the feature map defined in the
    body (currently the identity: only the raw components of ``X``), ``c`` is
    a standard-normal coefficient vector and ``noise`` is Gaussian with
    standard deviation ``noise_std``.

    Parameters
    ----------
    n_samples : int
        Number of rows to draw.
    noise_std : float
        Standard deviation of the additive noise on ``Y`` (0 gives a
        noise-free response).
    random_state : int
        Seed for ``np.random.RandomState``; the same seed reproduces the same
        ``X``, ``c`` and noise.
    d : int
        Number of covariates.

    Returns
    -------
    X : ndarray of shape (n_samples, d)
    Y : ndarray of shape (n_samples,)
    optimal_model : callable
        Maps covariates to the noise-free regression function
        ``phi(x) @ c``.
    """
    rng = np.random.RandomState(random_state)

    # NOTE: the draws below must stay in this order (X, c, noise) so that a
    # given seed keeps producing the same data set.
    X = rng.normal(loc=0.0, scale=10.0, size=(n_samples, d))

    def phi(x):
        # Feature map of the data-generating process. Only the raw components
        # are active; any extra basis (ReLU, powers, random sine features, ...)
        # must be appended here so that Y and optimal_model stay consistent.
        feats = [x]
        # axis=-1 keeps single-sample (1-D) inputs working as well.
        return np.concatenate(feats, axis=-1)

    Phi = phi(X)  # shape (n_samples, p)

    # Coefficients of the basis.
    c = rng.normal(loc=0.0, scale=1.0, size=(Phi.shape[1],))

    # Additive noise.
    noise = rng.normal(loc=0.0, scale=noise_std, size=n_samples)

    def optimal_model(x):
        # Oracle predictor: the regression function without the noise term.
        return phi(x) @ c

    Y = Phi @ c + noise

    return X, Y, optimal_model


# Confidence-interval machinery shared by every method compared below.
CI_CONSTRUCTOR = clt_confidence_interval
ALPHA = 0.01  # CI level, passed as `alpha` to each interval routine

# Target of inference: the q-quantile of Y.
q = 0.5  # 0.5 -> median
M = 1.0  # range of psi in the quantile case, passed as `M`

BIG_M = 1e99  # very large constant

def ind(y, theta):
    """Return 1 where ``y - theta <= 0`` and 0 elsewhere, as an integer array."""
    diff = y - theta
    return (diff <= 0).astype(int)


def psi(y, theta):
    """Return the indicator ``ind(y, theta)`` shifted by the quantile level ``q``."""
    return ind(y, theta) - q


def width(arr):
    """Return ``arr[-1] - arr[0]``, or NaN when ``arr`` is None or has < 2 items."""
    if arr is None or len(arr) < 2:
        return np.nan
    return arr[-1] - arr[0]

# ------------------------------------------
# 2) Dataset sizes
# ------------------------------------------
n_train = 100_000  # rows in the training split
n_cal = 200  # rows in the calibration split
n_test = 10_000  # rows in the test split
n_samples = sum((n_train, n_cal, n_test))  # total rows generated per data set

# ------------------------------------------
# 3) Experiment hyperparameters
# ------------------------------------------
n_estimators = 400  # fixed number of boosters for XGBoost

# 'err' parameter for CPPI, tied to the calibration-set size
# (a constant 0.01 is the alternative that was tried).
gamma = 1.1 / n_cal
ic(gamma)

# Noise levels swept by the experiment; they form the x-axis of both figures.
noise_stds = np.linspace(0.0, 0.5, num=10)

# Number of independent repetitions per noise level.
N_REPEATS = 20

# One (noise level, repetition) matrix per quantity, NaN until filled in:
# the interval width of each method, then the test-set MSE of the predictor.
cppi_widths, ppi_widths, cci_widths, ppipp_widths, fab_widths, mse_test = (
    np.full((noise_stds.size, N_REPEATS), np.nan, dtype=float) for _ in range(6)
)

# ------------------------------------------
# 4) Loop over noise_std
# ------------------------------------------
# Which predictor produces Yhat on the calibration and test sets.
#   True  -> the oracle regression function returned by generate_mock_data.
#            No model is trained, which avoids one 400-tree XGBoost fit on the
#            whole training split per iteration whose output was never read.
#   False -> an XGBoost regressor fitted on the training split.
USE_ORACLE_PREDICTOR = True

rng_generator = np.random.default_rng(0)
for i, NOISE_STD in enumerate(tqdm(noise_stds, desc="noise_std")):
    for i_repeat in range(N_REPEATS):
        # Fresh data set for this (noise level, repetition). Seeds come from
        # rng_generator, so the whole sweep is reproducible.
        X, Y, optimal_model = generate_mock_data(
            n_samples=n_samples,
            noise_std=NOISE_STD,
            random_state=rng_generator.integers(100_000),
        )

        # Train / calibration / test split. It is done even when the training
        # part goes unused so that the calibration and test sets do not depend
        # on the predictor choice.
        X_train, X_temp, Y_train, Y_temp = train_test_split(
            X, Y, train_size=n_train, random_state=0
        )
        X_cal, X_test, Y_cal, Y_test = train_test_split(
            X_temp, Y_temp, test_size=n_test, random_state=0
        )

        # Point predictor.
        if USE_ORACLE_PREDICTOR:
            predict = optimal_model
        else:
            model = xgb.XGBRegressor(n_estimators=n_estimators)
            model.fit(X_train, Y_train)
            predict = model.predict

        Yhat_cal = predict(X_cal)
        Yhat_test = predict(X_test)

        mse_test[i, i_repeat] = np.mean((Y_test - Yhat_test) ** 2)

        # Conformal scores: plain absolute calibration residuals (they are not
        # rescaled by any per-point scale estimate).
        scores_cal = np.abs(Y_cal - Yhat_cal)

        # Theta grid adapted to the observed support (margin 3 * NOISE_STD)
        theta_min = float(Y_test.min() - 3 * NOISE_STD)
        theta_max = float(Y_test.max() + 3 * NOISE_STD)
        thetas = np.linspace(theta_min, theta_max, 1000)

        # ---- Intervals for this noise_std ----
        # Keyword arguments common to every interval routine.
        ci_kwargs = dict(
            thetas=thetas,
            ci_constructor=CI_CONSTRUCTOR,
            alpha=ALPHA,
            M=M,
        )
        # Labels and predictions consumed by PPI, PPI++ and FAB.
        pred_kwargs = dict(Y_cal=Y_cal, Yhat_cal=Yhat_cal, Yhat_test=Yhat_test)

        # CPPI
        CPPI = conformal_interval_quantile(
            psi=psi,
            scores_cal=scores_cal,
            err=gamma,
            Yhat_test=Yhat_test,
            **ci_kwargs,
        )

        # PPI
        PPI = power_interval_quantile(psi=psi, **pred_kwargs, **ci_kwargs)

        # CCI (classical, only labeled data)
        CCI = classical_interval_quantile(psi=psi, Y_cal=Y_cal, **ci_kwargs)

        # PPIPP (split version)
        PPIPP_split = ppipp_interval_quantile(
            psi=psi, split=True, **pred_kwargs, **ci_kwargs
        )

        # FAB (takes the quantile level directly instead of psi)
        FAB = fab_interval_quantile(q=q, **pred_kwargs, **ci_kwargs)

        # Store widths (NaN when a routine returns None or fewer than 2 points)
        cppi_widths[i, i_repeat] = width(CPPI)
        ppi_widths[i, i_repeat] = width(PPI)
        cci_widths[i, i_repeat] = width(CCI)
        ppipp_widths[i, i_repeat] = width(PPIPP_split)
        fab_widths[i, i_repeat] = width(FAB)

# Colour assigned to each method in the interval-width figure.
COLOR_CLASS = "#a9a9a9"  # labelled data only
COLOR_PPI = "#e6194B"  # vanilla PPI
COLOR_PPIPP = "#4363d8"  # PPI++
COLOR_FAB = "#000075"  # FAB
COLOR_CPPI = "#469990"  # conformal PPI


def plot_data(x, y, *, label, marker, color):
    """Draw the row-wise mean of ``y`` against ``x`` with a +/- one-std band.

    ``y`` is reduced along axis 1 with the NaN-ignoring mean and standard
    deviation. The band is drawn semi-transparent (alpha 0.2) and the mean
    curve on top of it, both in ``color``, on the current pyplot figure.

    Parameters
    ----------
    x : array-like
        Abscissa values, one per row of ``y``.
    y : ndarray
        Values to summarise; must have at least two dimensions.
    label : str
        Legend entry for the mean curve.
    marker : str
        Matplotlib marker used on the mean curve.
    color : str
        Colour shared by the band and the curve.
    """
    center = np.nanmean(y, axis=1)
    spread = np.nanstd(y, axis=1)
    lower, upper = center - spread, center + spread

    plt.fill_between(x, lower, upper, color=color, alpha=0.2)
    plt.plot(x, center, color=color, marker=marker, label=label)


# ------------------------------------------
# 5) Plot: interval widths vs noise_std
# ------------------------------------------
# (an explicit figsize=(8, 5) is the alternative that was tried)
plt.figure()

# One entry per curve: (width matrix, colour, marker, legend label).
# Each ic(...) call echoes the matrix for debugging and returns it unchanged.
# Vanilla PPI is deliberately left out of the figure:
#   (ic(ppi_widths), COLOR_PPI, "s", "Vanilla PPI")
curves = [
    (ic(cppi_widths), COLOR_CPPI, "o", "Conformal PPI (Ours)"),
    (ic(fab_widths), COLOR_FAB, "d", "FAB (Cortinovis & Caron, 2025)"),
    (ic(ppipp_widths), COLOR_PPIPP, "x", "PPI++ (Angelopoulos et al., 2023b)"),
    (ic(cci_widths), COLOR_CLASS, "^", "Only labelled data"),
]
for widths, color, marker, label in curves:
    plot_data(noise_stds, widths, color=color, marker=marker, label=label)

plt.xlabel("Standard deviation of added exogenous noise")
plt.ylabel("CI width")
# No title on this figure (one mentioning gamma and n_estimators was dropped).
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.savefig("out_noiseXwidth.png")

# Second figure: test-set MSE of the point predictor, averaged over the
# repetitions, for each noise level.
plt.figure()
plt.plot(noise_stds, np.mean(mse_test, axis=1), marker="o")
plt.xlabel("Noise standard deviation (noise_std)")
plt.ylabel("Predictive MSE on test set")
# The title names no model: mse_test is computed in the experiment loop from
# the oracle predictor's predictions, so the former "model: XGBoost,
# n_estimators=..." subtitle described a model that was not plotted.
plt.title("Predictive MSE vs noise_std")
plt.grid(True)
plt.tight_layout()
plt.savefig("out_mse.png")
