import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Line width passed to ax.plot (as `linewidth`) for the mean curve of every series
LINE_WIDTH = 6

# CSV file paths
# csv_path_1 is loaded into df_1 and drawn on ax2 (titled "Logistic Objective");
# csv_path_2 is loaded into df_2 and drawn on ax1 (titled "Quadratic Objective").
# The integer after "DIM_" in each name is parsed out by extract_dim_from_filename
# and reused in the output image filename.
csv_path_1 = "logistic_loss_DIM_8_NDATA_50_NITERS_1000000_LR_1e-05_MU_0.0001_ZOOBATCH_16_NRUNS_16_PROCESSES_8.csv"
csv_path_2 = "loss_DIM_64_NITERS_1000000_LR_0.001_MU_0.0001_ZOOBATCH_16_NRUNS_16_PROCESSES_8.csv"

# Extract DIM value from CSV filename
def extract_dim_from_filename(filename):
    """Return the integer that follows the first ``DIM_`` tag in *filename*.

    Args:
        filename: Name (or path) of a CSV file, e.g. ``"loss_DIM_64_...csv"``.

    Returns:
        The digits after the first ``DIM_`` occurrence as an ``int``, or
        ``None`` when the name contains no ``DIM_<digits>`` tag.
    """
    import re

    found = re.search(r'DIM_(\d+)', filename)
    return int(found.group(1)) if found else None

# Parse the DIM tag out of both input filenames in one pass; a value is None
# when its filename carries no "DIM_<digits>" tag.
DIM_VALUE_1, DIM_VALUE_2 = map(extract_dim_from_filename, (csv_path_1, csv_path_2))
print(f"Extracted DIM values: {DIM_VALUE_1}, {DIM_VALUE_2}")

# ========== SMOOTHING HYPERPARAMETERS ==========

SMOOTHING_ENABLED = True          # Global switch; when False smooth_curve returns its input untouched
SMOOTHING_WINDOW = 5000           # Window size (in samples) for smoothing
SMOOTHING_METHOD = "moving_avg"   # Options: "moving_avg", "exponential", "savgol"
SAVGOL_POLYORDER = 3              # Polynomial order for Savitzky-Golay (if using savgol method)


def smooth_curve(data, method="moving_avg", window=50, polyorder=3):
    """
    Apply smoothing to a 1D array.

    The input is returned unchanged (same object, no copy) when the global
    ``SMOOTHING_ENABLED`` flag is False, when ``data`` holds fewer than
    ``window`` samples, or when ``method`` is not one of the known names.

    Args:
        data: 1D numpy array to smooth
        method: smoothing method ("moving_avg", "exponential", "savgol")
        window: window size for smoothing; for "exponential" it sets the
            decay through ``alpha = 2 / (window + 1)``
        polyorder: polynomial order for Savitzky-Golay

    Returns:
        Smoothed 1D numpy array. Integer input yields floating-point output
        for every method, so smoothed values are never truncated.
    """
    if not SMOOTHING_ENABLED or len(data) < window:
        return data

    if method == "moving_avg":
        # Simple moving average; np.convolve always returns floats here
        smoothed = np.convolve(data, np.ones(window) / window, mode='same')
        # mode='same' implicitly zero-pads past both ends, which drags the edge
        # values toward zero. Replace them with the mean of the samples that
        # actually fall inside the (truncated) window.
        half_window = window // 2
        for i in range(half_window):
            smoothed[i] = np.mean(data[:i + half_window + 1])
            smoothed[-(i + 1)] = np.mean(data[-(i + half_window + 1):])
        return smoothed

    if method == "exponential":
        # Exponential moving average
        alpha = 2.0 / (window + 1)
        values = np.asarray(data)
        # Allocate a floating-point buffer: inheriting an integer dtype from the
        # input (as zeros_like would) silently truncates every smoothed value.
        if np.issubdtype(values.dtype, np.inexact):
            out_dtype = values.dtype
        else:
            out_dtype = np.float64
        smoothed = np.empty(values.shape, dtype=out_dtype)
        smoothed[0] = values[0]
        for i in range(1, len(values)):
            smoothed[i] = alpha * values[i] + (1 - alpha) * smoothed[i - 1]
        return smoothed

    if method == "savgol":
        # Savitzky-Golay filter
        from scipy.signal import savgol_filter
        if window % 2 == 0:
            window += 1  # savgol requires odd window size
        if window > len(data):
            # Rounding an even window up can overshoot the data length even
            # though the guard above passed; fall back to the largest odd
            # window that still fits.
            window -= 2
        return savgol_filter(data, window, polyorder)

    # Unknown method: leave the data as-is
    return data

# Number of leading rows kept from each CSV; one row is treated as one
# iteration when the x-axis values are built from the row count.
cutoff_iter = 500000

# Read each result table and truncate it to the first `cutoff_iter` rows
df_1 = pd.read_csv(csv_path_1).iloc[:cutoff_iter]
df_2 = pd.read_csv(csv_path_2).iloc[:cutoff_iter]

# Three-colour "Set2" palette; entries 1 and 2 are the ones used for the series
palette = sns.color_palette("Set2", 3)

# One figure with two axes laid out side by side (ax1 left, ax2 right)
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(14, 7))

# Helper: draw the across-column mean of the smoothed curves plus a shaded percentile band
def draw_band(ax, iters, arr: np.ndarray, color, label):
    """Plot the mean of smoothed curves with a shaded percentile band.

    Each column of ``arr`` is smoothed independently with ``smooth_curve``
    (using the module-level SMOOTHING_* settings); the mean and the band
    limits are then computed across columns (axis=1) at every row.

    Args:
        ax: Matplotlib axes to draw on.
        iters: x-axis values, one per row of ``arr``.
        arr: 2D array; one column per curve (the script passes one column
            per run).
        color: Colour used for both the mean line and the band.
        label: Legend label for the mean line. It also selects the band:
            labels containing "Rescale" get the 10th-90th percentile band,
            every other label gets the 20th-80th percentile band.

    Returns:
        The line object that ``ax.plot`` created for the mean curve.
    """
    # Use a floating-point buffer: inheriting an integer dtype from `arr`
    # would truncate the smoothed values on assignment.
    if np.issubdtype(arr.dtype, np.floating):
        buffer_dtype = arr.dtype
    else:
        buffer_dtype = float
    smoothed_arr = np.empty(arr.shape, dtype=buffer_dtype)

    # Apply smoothing to each column separately
    for i in range(arr.shape[1]):
        smoothed_arr[:, i] = smooth_curve(arr[:, i],
                                          method=SMOOTHING_METHOD,
                                          window=SMOOTHING_WINDOW,
                                          polyorder=SAVGOL_POLYORDER)

    # Calculate statistics on smoothed data
    mean = smoothed_arr.mean(axis=1)
    lower_pct, upper_pct = (10, 90) if "Rescale" in label else (20, 80)
    lower, upper = np.percentile(smoothed_arr, [lower_pct, upper_pct], axis=1)

    line = ax.plot(iters, mean, color=color, label=label, linewidth=LINE_WIDTH)[0]
    ax.fill_between(iters, lower, upper, color=color, alpha=0.25)
    return line

# Left panel (ax1): the runs stored in csv_path_2, i.e. the DIM_64 file
cols_direct_2 = [name for name in df_2.columns if name.startswith("zeroth_order_direct_")]
cols_reject_2 = [name for name in df_2.columns if name.startswith("zeroth_order_rejection_")]
iters_2 = np.arange(len(df_2))

# Keep the returned line handles: they feed the figure-level legend later on
line1 = draw_band(ax=ax1, iters=iters_2, arr=df_2[cols_direct_2].values,
                  color=palette[1], label="Rescale Sampling")
line2 = draw_band(ax=ax1, iters=iters_2, arr=df_2[cols_reject_2].values,
                  color=palette[2], label="Rejection Sampling")

ax1.set_title("Quadratic Objective", fontsize=20)
ax1.set_xlabel("Iteration", fontsize=18)
ax1.set_ylabel("Objective Value", fontsize=18)
ax1.tick_params(axis='both', which='major', labelsize=18)

# Tick every 100k iterations, labelled "100k" ... "500k" (reused by the other panel)
x_ticks = [100000 * step for step in range(1, 6)]
x_labels = [f"{tick // 1000}k" for tick in x_ticks]
ax1.set_xticks(x_ticks)
ax1.set_xticklabels(x_labels)

# Right panel (ax2): the runs stored in csv_path_1, i.e. the DIM_8 file
cols_direct_1 = [name for name in df_1.columns if name.startswith("zeroth_order_direct_")]
cols_reject_1 = [name for name in df_1.columns if name.startswith("zeroth_order_rejection_")]
iters_1 = np.arange(len(df_1))

# Return values are not needed here; the legend reuses the left panel's handles
draw_band(ax=ax2, iters=iters_1, arr=df_1[cols_direct_1].values,
          color=palette[1], label="Rescale Sampling")
draw_band(ax=ax2, iters=iters_1, arr=df_1[cols_reject_1].values,
          color=palette[2], label="Rejection Sampling")

ax2.set_title("Logistic Objective", fontsize=20)
ax2.set_xlabel("Iteration", fontsize=18)
ax2.set_ylabel("Objective Value", fontsize=18)
ax2.tick_params(axis='both', which='major', labelsize=18)

# Same tick positions and labels as the left panel
ax2.set_xticks(x_ticks)
ax2.set_xticklabels(x_labels)

# Remove spines (seaborn's despine with its default arguments) on both panels
sns.despine(ax=ax1)
sns.despine(ax=ax2)

# Create a single legend below both plots.
# Only the left panel's line handles are used: both panels draw the two series
# with the same colours and labels, so one legend describes both.
handles = [line1, line2]
labels = ["Rescale Sampling", "Rejection Sampling"]
fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.5, 0.025), 
           ncol=2, fontsize=18)

# Order matters: tight_layout is applied first, then the bottom margin is
# overridden so the figure-level legend has space underneath the axes.
plt.tight_layout()
plt.subplots_adjust(bottom=0.2)  # Make room for legend

# Create output filename from the DIM values parsed out of the CSV names
# (a value prints as "None" if its filename had no DIM tag)
output_filename = f"logistic_convergence_comparison_DIM_{DIM_VALUE_1}_and_{DIM_VALUE_2}.png"
# The image is written to disk before the interactive window is opened
plt.savefig(output_filename, dpi=300, bbox_inches="tight")
plt.show()

# Run summary: inputs, output path and the smoothing settings that were in effect
print(f"CSV files loaded: {csv_path_1}, {csv_path_2}")
print(f"Plot saved to: {output_filename}")
print(f"Smoothing: {'Enabled' if SMOOTHING_ENABLED else 'Disabled'}")
if SMOOTHING_ENABLED:
    print(f"Method: {SMOOTHING_METHOD}, Window: {SMOOTHING_WINDOW}") 