#%%
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns


#%%
# Load the two result tables compared in the figure. Both CSV files sit in
# the same experiment directory; the first is plotted under the label
# 'Perez et al.' and the second under 'Ours'.
perez_csv_path = 'tree_no_depth_limit_minsample15_calbinning20percentall/concatenated_sanity_check-y_residuals_predict_proba_HistGradientBoostingClassifier.csv'
residuals_csv_path = 'tree_no_depth_limit_minsample15_calbinning20percentall/concatenated_sanity_check-y_residuals_predict_proba.csv'

df_perez_sim = pd.read_csv(perez_csv_path)
df_residuals_sim = pd.read_csv(residuals_csv_path)
# df_residuals_real = pd.read_csv('tree_depth/concatenated_sanity-check-y_residuals_real.csv')

# Filter df_residuals_sim to only include data after x=20000
# df_residuals_sim = df_residuals_sim[df_residuals_sim['N samples'] >= 20000]

#%%
# Plot configuration: which result tables are drawn and how they are styled.
# `colors` is indexed by the position of each entry in `dfs`; `markers` is
# only referenced by a disabled `marker=` option of the mean line.
colors = ['C0', 'C1']
markers = ['o', 's']

# Legend label -> (results table, name of the column holding the GL estimate).
dfs = {}
dfs['Perez et al.'] = (df_perez_sim, 'GL')
dfs['Ours'] = (df_residuals_sim, 'GL')
# dfs['Residuals Real'] = (df_residuals_real, 'GL2')

# One wide, short axes (5 x 2 inches) holds every curve.
fig, ax = plt.subplots(figsize=(5, 2))
# fig.suptitle('Recalibration method = Sigmoid, min_samples_leaf = 10', fontsize=16)

# Disabled: gray shaded area for x < 20000
# ax.axvspan(4000, 20000, alpha=0.4, color='lightgray', label='Training region with constant\ncalibration set size (4000)')

# Draw every method in `dfs`: a mean-GL curve against sample size, a
# mean ± std band around it, and faint per-depth markers showing how the
# individual tree depths scatter around the pooled mean.
for idx, (method_label, (results, gl_column)) in enumerate(dfs.items()):
    line_color = colors[idx]
    # The second method uses 'x' markers, every other one '+'.
    depth_marker = 'x' if idx == 1 else '+'

    # Mean and std of GL over every row sharing the same 'N samples' value.
    # Note: this pools all runs for that sample size, including rows with
    # different 'Depth' values.
    summary = (
        results.groupby('N samples')[gl_column]
        .agg(GL_mean='mean', GL_std='std')
        .reset_index()
    )
    sample_sizes = summary['N samples']
    gl_mean = summary['GL_mean']
    gl_std = summary['GL_std']

    # Mean curve (carries the legend entry for this method).
    # Disabled option: marker=markers[idx]
    ax.plot(
        sample_sizes,
        gl_mean,
        linestyle='-',
        linewidth=2,
        color=line_color,
        label=method_label,
    )

    # Shaded band spanning one standard deviation either side of the mean.
    ax.fill_between(
        sample_sizes,
        gl_mean - gl_std,
        gl_mean + gl_std,
        alpha=0.2,
        color=line_color,
    )

    # Disabled: scatter all individual data points
    # ax.scatter(
    #     results['N samples'],
    #     results[gl_column],
    #     alpha=0.3,
    #     s=20,
    #     color=line_color,
    #     marker='x',
    #     label=f"{method_label} (individual points)" if idx == 0 else None
    # )

    # One set of unlabeled markers per depth: the mean GL of that depth at
    # each sample size.
    for depth in sorted(results['Depth'].unique()):
        rows_at_depth = results.loc[results['Depth'] == depth]
        depth_means = rows_at_depth.groupby('N samples', as_index=False)[gl_column].mean()
        ax.scatter(
            depth_means['N samples'],
            depth_means[gl_column],
            alpha=0.3,
            s=20,
            color=line_color,
            marker=depth_marker,
        )

# Dotted horizontal reference line at the ground-truth GL value (y=0.0041),
# matching the value quoted in its legend label.
ax.axhline(
    y=0.0041,
    color='red',
    linestyle='dotted',
    label='Ground truth GL (0.0041)',
)

# Axis labels and a light dashed grid.
ax.set_xlabel('$n_{samples}$ available to fit the estimator', fontsize=12)
ax.set_ylabel('Grouping Loss', fontsize=12)
ax.grid(True, linestyle='--', alpha=0.7)

# Axis limits: the y-range and the lower x-bound are fixed by hand; the upper
# x-bound is 1000 past the largest 'N samples' value in `df_residuals_sim`.
y_min, y_max = 0, 0.005
x_min = 4000
x_max = df_residuals_sim['N samples'].max() + 1000  # max(df_perez_sim['N samples'].max()+100, df_residuals_sim['N samples'].max()+1000)
# Disabled: vertical line at x = 4000 for minimum calibration samples
# ax.axvline(x=4000, color='forestgreen', linestyle='--', linewidth=2, label='Min calibration samples for the residual method (4000)')
ax.set_ylim(y_min, y_max)
ax.set_xlim(x_min, x_max)

# Switch the x-axis to a log scale when the x-range spans more than a
# factor of ten.
spans_over_a_decade = x_max / x_min > 10
if spans_over_a_decade:
    ax.set_xscale('log')

# Add the legend, finalize the layout, then write the figure to disk and
# display it.
ax.legend(loc='lower right', fontsize=9)

# The layout adjustment must happen BEFORE savefig: anything applied after
# the file is written only affects the interactive window, so the saved PDF
# would not match the figure that is shown.
# plt.tight_layout()
plt.subplots_adjust(top=0.9)

# bbox_inches='tight' crops the saved page to the drawn content.
plt.savefig('n_samples_evolution.pdf', bbox_inches='tight', dpi=800)
plt.show()


# %%
