from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

def smooth_curve(data, window_size=20):
    """Return a centred moving average of ``data``.

    Each output point is the mean of the ``window_size`` samples (approximately)
    centred on it. Near the edges the window is incomplete, but whatever
    samples are present are still averaged (``min_periods=1``). As a result,
    edge points are not NaN merely because the window is short, and the output
    keeps the input's length and index.

    Args:
        data: pandas Series or DataFrame to smooth (anything exposing
            ``.rolling``).
        window_size: Number of samples in the averaging window.

    Returns:
        An object of the same kind as ``data`` holding the smoothed values.
    """
    windows = data.rolling(window_size, min_periods=1, center=True)
    return windows.mean()

# Load the learning-curve CSV. Column 0 is used as the x-axis (iteration
# number). The remaining columns are read as consecutive blocks of
# `runs_per_config` runs, one block per configuration.
# df = pd.read_csv('logs/analysis/data/dynamics_pedi.csv')  # Pedipulation dynamics

df = pd.read_csv('p4rl_assets/invdynamics_dev/cdac_weights_reuse_learning_curves.csv')

step = df.iloc[:, 0]
num_configs = 6
runs_per_config = 3
colors = ['blue', 'orange', 'green', 'red', 'purple', 'brown']  # High contrast colors

# Legend label for each configuration, in the same order as the CSV column
# blocks (entry i labels config i).
legend = [
    "Backend 100% from trained pedipulation models",
    "Backend 50% from trained pedipulation models",
    "Backend 100% from trained locomotion models",
    "Backend 50% from trained locomotion models",
    "Baseline",
]
# Earlier 6-configuration label set, kept for reference:
# legend = [
#     "RL with trained dynamics module - Frozen",
#     "random_init - Frozen",
#     "random_init - Unfrozen",
#     "Trained (velocity tracking, initial) - Unfrozen",
#     "3-layer MLP baseline",
#     "RL with trained dynamics module - Trainable",
# ]
# Alternatively, label each config by the header of its first run column:
# legend = df.columns.tolist()[1::runs_per_config]

# 0-based indices of the configurations to plot; use range(num_configs) for all.
subset_idx = range(0, 5)

# When True, shade a +/- one standard deviation band (across runs, around the
# unsmoothed per-step mean) behind each curve.
show_std = False

# Fail loudly on a mismatched CSV or label list. Out-of-range `iloc` slices are
# silently truncated, which would plot partial or all-NaN curves. A short
# label/colour list would raise IndexError midway through plotting.
last_config = max(subset_idx, default=-1)
required_cols = 1 + (last_config + 1) * runs_per_config
if required_cols > df.shape[1]:
    raise ValueError(
        f"Plotting config {last_config + 1} needs {required_cols} CSV columns "
        f"(step + {runs_per_config} runs per config), but the file has "
        f"{df.shape[1]}."
    )
if last_config >= min(len(legend), len(colors)):
    raise ValueError(
        f"subset_idx selects config {last_config + 1}, but only {len(legend)} "
        f"legend labels and {len(colors)} colors are defined."
    )

# Set up the plot
plt.figure(figsize=(10, 6))

for i in subset_idx:
    start_idx = 1 + i * runs_per_config
    end_idx = start_idx + runs_per_config
    runs = df.iloc[:, start_idx:end_idx]

    # Per-step mean across this config's runs, smoothed for readability.
    mean_vals = runs.mean(axis=1)
    # Attach the label to the line itself so the legend cannot drift out of
    # sync with the plotted artists (e.g. when the std band is enabled).
    plt.plot(step, smooth_curve(mean_vals), label=legend[i], color=colors[i])

    if show_std:
        std_vals = runs.std(axis=1)
        plt.fill_between(step, mean_vals - std_vals, mean_vals + std_vals,
                         alpha=0.3,
                         color=colors[i],
                         label="_nolegend_")

plt.xlim(0, 300)
plt.ylim(-20, 20)

plt.xlabel('Iteration Number')
plt.ylabel('Reward')
# plt.title('Training Curves with Mean ± Std Dev')
plt.legend()
plt.grid(True)
plt.tight_layout()

# savefig does not create missing directories, so create the target folder first.
output_path = Path('logs/analysis/plots/training_curves_task_transfer.pdf')
output_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(output_path)