import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from supplementary.util import clt_confidence_interval
from supplementary.util import (
    conformal_interval_quantile,
    power_interval_quantile,
    ppipp_interval_quantile,
    fab_interval_quantile,
    classical_interval_quantile,
)
import xgboost as xgb
import matplotlib.pyplot as plt


def generate_mock_data(n_samples=50000, noise_std=1, random_state=42,
                       x_low=-1.0, x_high=1.0):
    """Draw a synthetic 1-D regression dataset ``Y = sin(pi * X) + noise``.

    X is sampled uniformly on ``[x_low, x_high)`` and Y adds i.i.d. Gaussian
    noise with standard deviation ``noise_std``.

    Parameters
    ----------
    n_samples : int
        Number of (X, Y) pairs to generate.
    noise_std : float
        Standard deviation of the additive Gaussian noise on Y.
    random_state : int
        Seed passed to ``np.random.seed``. Note that this seeds NumPy's
        *global* RNG, so it also affects any later ``np.random`` calls.
    x_low, x_high : float
        Bounds of the uniform distribution X is drawn from.

    Returns
    -------
    X : ndarray of shape (n_samples, 1)
    Y : ndarray of shape (n_samples,)

    Raises
    ------
    ValueError
        If ``x_low > x_high`` (NumPy leaves that case undefined).
    """
    if x_low > x_high:
        raise ValueError(f"x_low ({x_low}) must not exceed x_high ({x_high})")
    np.random.seed(random_state)
    # The bounds used to be (42, 1): low > high, which NumPy does not define.
    # In practice that drew X from [1, 42] instead of the intended [-1, 1].
    X = np.random.uniform(x_low, x_high, size=(n_samples, 1))
    Y = np.sin(np.pi * X).ravel() + np.random.normal(0, noise_std, size=n_samples)
    return X, Y

# --- Inference configuration --------------------------------------------------
ALPHA = 0.01                # CI miscoverage level (presumably 1 - ALPHA coverage; confirm in CI_CONSTRUCTOR)
CI_CONSTRUCTOR = clt_confidence_interval
BIG_M = 1e99                # not referenced by this script as shown; kept for compatibility
q = 0.5                     # target quantile level used by psi (0.5 -> median)
M = 1.0                     # forwarded as M= to conformal_interval_quantile; meaning defined there


def ind(y, theta):
    """Return the elementwise 0/1 indicator ``1{y <= theta}`` as integers.

    ``y`` may be a NumPy array, a list, or a scalar; it is converted with
    ``np.asarray`` so ``.astype`` is always available.
    """
    return (np.asarray(y) - theta <= 0).astype(int)


def psi(y, theta):
    """Quantile estimating function ``1{y <= theta} - q``.

    Its expectation over Y is zero exactly when theta is the q-quantile of Y.
    """
    return ind(y, theta) - q


# --- Sample sizes -------------------------------------------------------------
n_cal = 1000
n_test = 10000
n_train = 10000
n_samples = n_cal + n_test + n_train


NOISE_STD = 0.1
gamma = 0.01                # forwarded as err= to conformal_interval_quantile

N_boosts = np.arange(1, 60, 1)   # boosting-round counts evaluated: 1..59
max_boosts = 100                 # rounds actually trained; must cover N_boosts.max()

# One slot per entry of N_boosts; NaN means "not computed / empty set".
train_errors = np.full(len(N_boosts), np.nan, dtype=float)
cal_errors = np.full(len(N_boosts), np.nan, dtype=float)
test_errors = np.full(len(N_boosts), np.nan, dtype=float)
interval_wid = np.full(len(N_boosts), np.nan, dtype=float)


# Build a single data pool, then carve it into train / calibration / test.
X, Y = generate_mock_data(
    n_samples=n_samples, noise_std=NOISE_STD, random_state=42
)

# Split 1: exactly n_train rows go to training; the rest (n_cal + n_test) is held out.
X_train, X_temp, Y_train, Y_temp = train_test_split(
    X, Y, random_state=0, train_size=n_train
)

# Split 2: exactly n_test held-out rows go to test; the remaining n_cal rows calibrate.
X_cal, X_test, Y_cal, Y_test = train_test_split(
    X_temp, Y_temp, random_state=0, test_size=n_test
)

# Candidate values of the quantile parameter theta, padded by 3 noise std per side.
# NOTE(review): the grid range is taken from the test *labels*; if the test split
# is meant to be unlabeled for the inference method, Y_cal may be the cleaner source.
theta_min = float(Y_test.min() - 3 * NOISE_STD)
theta_max = float(Y_test.max() + 3 * NOISE_STD)
thetas = np.linspace(start=theta_min, stop=theta_max, num=200)


# Train once with the largest ensemble; smaller ensembles are evaluated below by
# truncating the tree sequence through `iteration_range`.
model = xgb.XGBRegressor(n_estimators=max_boosts)
model.fit(X_train, Y_train)

for idx, n_est in enumerate(tqdm(N_boosts, desc="n_estimators")):
    # Predict with only the first n_est boosting rounds.
    rounds = (0, int(n_est))
    Yhat_train = model.predict(X_train, iteration_range=rounds)
    Yhat_cal = model.predict(X_cal, iteration_range=rounds)
    Yhat_test = model.predict(X_test, iteration_range=rounds)

    # Absolute calibration residuals: used both as the calibration MAE and as
    # the conformal scores, so compute them once.
    scores_cal = np.abs(Y_cal - Yhat_cal)

    train_errors[idx] = np.mean(np.abs(Y_train - Yhat_train))
    cal_errors[idx] = np.mean(scores_cal)
    test_errors[idx] = np.mean(np.abs(Y_test - Yhat_test))

    CPPI = conformal_interval_quantile(
        psi=psi,
        scores_cal=scores_cal,
        err=gamma,
        Yhat_test=Yhat_test,
        thetas=thetas,
        ci_constructor=CI_CONSTRUCTOR,
        alpha=ALPHA,
        M=M,
    )
    # Width = sup C - inf C. max/min does not rely on the returned values being
    # sorted, unlike CPPI[-1] - CPPI[0].
    # NOTE(review): assumes CPPI is a 1-D collection of theta values (or a
    # (lower, upper) pair) — confirm against conformal_interval_quantile.
    if CPPI is not None and len(CPPI) >= 2:
        interval_wid[idx] = float(np.max(CPPI) - np.min(CPPI))
    else:
        interval_wid[idx] = np.nan



# Figure 1: MAE on every split versus the number of boosting rounds used.
_xlabel_boost = 'Number of boosters (n_estimators used)'

plt.figure(figsize=(8, 5))
for _errs, _mark, _split in (
    (train_errors, 'o', 'Train'),
    (cal_errors, 's', 'Calibration'),
    (test_errors, '^', 'Test'),
):
    plt.plot(N_boosts, _errs, marker=_mark, label=_split)
plt.xlabel(_xlabel_boost)
plt.ylabel('MAE')
plt.title(f'MAE vs n_estimators\n(noise_std={NOISE_STD}, gamma={gamma})')
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()

# Figure 2: width of the CPPI set versus the number of boosting rounds used.
plt.figure(figsize=(8, 5))
plt.plot(N_boosts, interval_wid, marker='o')
plt.xlabel(_xlabel_boost)
plt.ylabel('CPPI set width (sup C - inf C)')
plt.title(f'CPPI width vs n_estimators\n(noise_std={NOISE_STD}, gamma={gamma})')
plt.grid(True)
plt.tight_layout()
plt.show()

# Figure 3: calibration MAE (left axis) overlaid with CPPI width (right axis).
fig, ax1 = plt.subplots(figsize=(8, 5))

# Same x-label as the other figures, so all three share one vocabulary.
ax1.set_xlabel('Number of boosters (n_estimators used)')
ax1.set_ylabel('Calibration MAE', color='tab:blue')
(line_mae,) = ax1.plot(N_boosts, cal_errors, marker='o', linestyle='-',
                       color='tab:blue', label='Calibration MAE')
ax1.tick_params(axis='y', labelcolor='tab:blue')
ax1.grid(True)

ax2 = ax1.twinx()
ax2.set_ylabel('CPPI width', color='tab:orange')
(line_wid,) = ax2.plot(N_boosts, interval_wid, marker='s', linestyle='--',
                       color='tab:orange', label='CPPI width')
ax2.tick_params(axis='y', labelcolor='tab:orange')

# Twin axes keep separate legend registries; merge the handles into one legend
# (drawn on ax2 so it sits above both lines). Without this the labels are unused.
ax2.legend(handles=[line_mae, line_wid], loc='best')

# Attach the title explicitly to the primary axes instead of relying on
# pyplot's "current axes", which is ax2 after twinx().
ax1.set_title(f'Calibration error vs CPPI width\n'
              f'(noise_std={NOISE_STD}, gamma={gamma})')
fig.tight_layout()
plt.show()