from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

def smooth_curve(data, window_size=20):
    """Return a centred moving average of *data*.

    Each output point is the mean of up to ``window_size`` samples centred on
    it. Near the edges, where the full window does not fit, the mean is taken
    over whatever samples are available (``min_periods=1``), so the result has
    the same length and index as the input.

    Args:
        data: A pandas Series or DataFrame (anything exposing ``.rolling``).
        window_size: Number of samples in the averaging window.

    Returns:
        The smoothed values, as the same pandas type as ``data``.
    """
    centred_window = data.rolling(window_size, min_periods=1, center=True)
    smoothed = centred_window.mean()
    return smoothed

# Input metrics table and output figure locations.
CSV_PATH = 'logs/analysis/pre_dyna_GRU_comparison.csv'
OUTPUT_PATH = 'logs/analysis/plots/training_curves_dynamics_pedi.png'

# Load the CSV file. Layout, as indexed below: column 0 holds the step, then
# ``runs_per_config`` consecutive columns per configuration, so config ``i``
# occupies columns ``1 + i * runs_per_config`` .. ``(i + 1) * runs_per_config``.
df = pd.read_csv(CSV_PATH)

step = df.iloc[:, 0]
num_configs = 6  # configurations laid out in the CSV
runs_per_config = 3  # repeated runs averaged per configuration
colors = ['blue', 'red', 'purple', 'brown']  # High contrast colors

# Descriptive legend labels keyed by zero-based config index; configs without
# an entry fall back to "Config N".
config_labels = {
    0: "Pretrained dynamics module (Not frozen)",
    1: "Pretrained dynamics module (frozen in initial 200 iterations)",
    2: "Baseline",
}

# Step at which the frozen dynamics module is unfrozen; drawn as a dashed line.
unfreeze_step = 200

# Configs to draw, in plotting (and therefore legend) order.
subset_idx = [1, 0, 2]

# Set up the plot
plt.figure(figsize=(10, 6))

for i in subset_idx:
    if not 0 <= i < num_configs:
        raise ValueError(f"config index {i} is outside 0..{num_configs - 1}")
    start_idx = 1 + i * runs_per_config
    end_idx = start_idx + runs_per_config
    # iloc silently truncates out-of-range slices, which would average fewer
    # (or zero) runs without any error -- fail loudly instead.
    if end_idx > df.shape[1]:
        raise ValueError(
            f"config {i} needs columns {start_idx}..{end_idx - 1}, "
            f"but the CSV only has {df.shape[1]} columns"
        )
    runs = df.iloc[:, start_idx:end_idx]

    # Smooth mean and std with the same window so the shaded band is centred
    # on the plotted (smoothed) line rather than on the raw, noisy mean.
    mean_smoothed = smooth_curve(runs.mean(axis=1))
    std_smoothed = smooth_curve(runs.std(axis=1))

    # Colours repeat if more configs are plotted than colours are listed.
    color = colors[i % len(colors)]

    # Label artists directly: passing a bare list of strings to plt.legend()
    # pairs labels with artists purely by creation order, which is fragile.
    plt.plot(step, mean_smoothed,
             label=config_labels.get(i, f'Config {i+1}'),
             color=color)
    plt.fill_between(step,
                     mean_smoothed - std_smoothed,
                     mean_smoothed + std_smoothed,
                     alpha=0.3,
                     color=color,
                     label="_nolegend_")

plt.axvline(x=unfreeze_step, linestyle="--", color="red", alpha=0.5,
            label="unfrozen timepoint")

plt.xlabel('Step')
plt.ylabel('Metric')
plt.title('Training Curves with Mean ± Std Dev')
plt.legend()
plt.grid(True)
plt.tight_layout()

# savefig does not create missing directories.
Path(OUTPUT_PATH).parent.mkdir(parents=True, exist_ok=True)
plt.savefig(OUTPUT_PATH)