import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

def smooth_curve(data, window_size=20):
    """Return *data* smoothed with a centred moving average.

    Each output point is the mean of the (up to) ``window_size`` values
    centred on it. Near the edges the window is truncated. Because
    ``min_periods=1``, the mean is taken over whatever values are available
    there, so the output keeps the input's length and index and the edges are
    not NaN merely because the window is short.
    """
    window = data.rolling(window_size, min_periods=1, center=True)
    return window.mean()

# Load the CSV file
# df = pd.read_csv('logs/analysis/data/dynamics_pedi.csv')  # Pedipulation dynamics

# ----------------------------------------------
# filename = 'logs/analysis/plots/parkour_jump_with_INV_weight_source_comparison_individual_runs.csv'

# labels = ["RL with INV (Pretrained on pedipulation initial data)", 
#           "RL with INV (Pretrained on exploration data)", 
#           "RL with INV (Random Init)",]

# runs_to_vis_per_config = 6

# runs_per_config = [8, 6, 8]
# subset_idx_start_idx = [0] + list(np.cumsum(runs_per_config)[:-1])

# ----------------------------------------------
filename = 'logs/analysis/plots/pedi_with_INV.csv'

# Legend text for each configuration. labels[i] is paired with the i-th
# entry of runs_per_config / subset_idx_start_idx when plotting.
labels = [
    "RL with INV (Pretrained on initial pedipulation data)",
    "RL with INV (Random Init)",
    "Baseline",
    "RL with INV (Pretrained on exploration data)",
]

# How many runs of each configuration are averaged into its curve.
runs_to_vis_per_config = 3

# Number of runs logged per configuration, and the run index at which each
# configuration's block of runs begins (exclusive prefix sum of the counts).
runs_per_config = [3, 3, 3, 3]
subset_idx_start_idx = [0, *np.cumsum(runs_per_config)[:-1]]
# ----------------------------------------------

df = pd.read_csv(filename)
# The figure is named after the CSV file: its basename minus a trailing ".csv".
# Only the suffix is removed, so ".csv" appearing elsewhere in the name is kept.
csv_basename = filename.split('/')[-1]
if csv_basename.endswith('.csv'):
    figure_name = csv_basename[:-len('.csv')]
else:
    figure_name = csv_basename

step = df.iloc[:, 0]  # first column is used as the x-axis for every curve
num_configs = 6  # NOTE(review): unused; only the commented-out loop below refers to it
three = 3 # wandb output csv has three columns per run, mean, min, max
colors = ['blue', 'red', 'green', 'purple', 'brown']  # High contrast colors

# Set up the plot
plt.figure(figsize=(10, 6))

# One curve per configuration: the smoothed mean over its first
# `runs_to_vis_per_config` runs, plus a shaded +/-1 std band across those runs.
for i, idx in enumerate(subset_idx_start_idx):
    # Column 0 is the step column. After it, each run spans `three` columns, so
    # striding by `three` selects one column per run (presumably the per-run
    # value column of the wandb export, per the comment on `three`).
    # Clamp to this config's own run count so a too-large
    # runs_to_vis_per_config cannot spill into the next config's columns.
    n_runs = min(runs_to_vis_per_config, runs_per_config[i])
    start_idx = 1 + idx * three
    end_idx = start_idx + three * n_runs
    runs = df.iloc[:, start_idx:end_idx:three]

    mean_vals = smooth_curve(runs.mean(axis=1))
    # Sample std (pandas default ddof=1). It is NaN when only one run is
    # selected, and the band is then empty.
    # NOTE(review): the std is not smoothed, so the band can look noisier
    # than the mean curve it surrounds.
    std_vals = runs.std(axis=1)

    # Cycle through the palette so adding configurations cannot raise IndexError.
    color = colors[i % len(colors)]
    plt.plot(step, mean_vals, label=labels[i], color=color, linewidth=2)
    plt.fill_between(step, mean_vals - std_vals, mean_vals + std_vals,
                     alpha=0.3,
                     color=color,
                     label="_nolegend_")

# run_names = df.columns.tolist()
# run_names_for_legend = run_names[1::runs_per_config]
# legend = ["RL with trained dynamics module - Frozen",
#             "random_init - Frozen", 
#             "random_init - Unfrozen",
#             "Trained (velocity tracking, initial) - Unfrozen",
#             "3-layer MLP baseline", 
#             "RL with trained dynamics module - Trainable",
#           ]


# plt.xlim(0, 300)
# plt.ylim(-20, 20)

# legend = ["Backend 100% from trained pedipulation models",
#             "Backend 50% from trained pedipulation models", 
#             "Backend 100% from trained locomotion models",
#             "Backend 50% from trained locomotion models",
#             "Baseline", 
#           ]

# legend = [legend[i] for i in subset_idx]

plt.xlabel('Iteration Number')
plt.ylabel('Reward')
# plt.title('Training Curves with Mean ± Std Dev')

# Legend entries come back in plotting order; the "_nolegend_" bands are skipped.
handles, legend_labels = plt.gca().get_legend_handles_labels()

# Display order for the legend, as indices into the plotting order.
# With the four labels configured near the top of this script, the order is:
# Baseline, Random Init, pretrained-on-exploration, pretrained-on-pedipulation.
order = [2, 1, 3, 0]
if sorted(order) == list(range(len(handles))):
    plt.legend([handles[i] for i in order], [legend_labels[i] for i in order])
else:
    # The hard-coded order no longer matches the plotted configs (e.g. one was
    # added or removed). Fall back to plotting order instead of raising
    # IndexError or silently dropping entries.
    plt.legend(handles, legend_labels)

plt.grid(True)
plt.tight_layout()
output_path = 'logs/analysis/plots/' + figure_name + '.pdf'
plt.savefig(output_path)
print(f'Plot saved as "{output_path}"')