# scripts/plot_learning_curves.py
import pandas as pd
import matplotlib.pyplot as plt

def plot_learning_curves(log_files, labels, output_file, *, show=True):
    """Plot average-cost and CVaR-cost training curves for several runs.

    For every log file two curves are drawn against the ``step`` column:
    a solid line for ``avg_cost`` and a dashed line for ``avg_cvar_cost``.
    The figure is saved to ``output_file`` and, by default, displayed.

    Args:
        log_files: CSV files, one per method (anything ``pd.read_csv``
            accepts). Each must contain the columns ``step``, ``avg_cost``
            and ``avg_cvar_cost``.
        labels: Legend labels, one per entry of ``log_files``.
        output_file: Destination handed to ``Figure.savefig``.
        show: If true (the default), call ``plt.show()`` after saving.
            If false, the figure is closed after saving instead, so the
            function can be used in batch jobs without accumulating
            open figures.

    Raises:
        ValueError: If ``log_files`` and ``labels`` differ in length.
        KeyError: If a log file lacks one of the required columns.
    """
    # Materialise both arguments so their lengths can be compared; a plain
    # zip() would silently drop the unmatched tail of the longer one.
    log_files = list(log_files)
    labels = list(labels)
    if len(log_files) != len(labels):
        raise ValueError(
            f'log_files and labels must have the same length '
            f'(got {len(log_files)} and {len(labels)})'
        )

    required_columns = ('step', 'avg_cost', 'avg_cvar_cost')

    # Use explicit figure/axes objects rather than pyplot's implicit
    # "current figure" so unrelated figures cannot be drawn on by accident.
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        for log_file, label in zip(log_files, labels):
            data = pd.read_csv(log_file)
            missing = [col for col in required_columns if col not in data.columns]
            if missing:
                # Same exception type a failed column lookup would raise,
                # but naming the offending file.
                raise KeyError(f'{log_file!r} is missing column(s): {missing}')
            ax.plot(data['step'], data['avg_cost'], label=f'{label} (Avg Cost)')
            ax.plot(data['step'], data['avg_cvar_cost'], '--', label=f'{label} (CVaR Cost)')
        ax.set_xlabel('Training Steps')
        ax.set_ylabel('Cost')
        ax.set_title('Learning Curves: Average vs CVaR Cost')
        ax.legend()
        fig.tight_layout()
        fig.savefig(output_file)
    except Exception:
        # Do not leave a half-built figure registered with pyplot.
        plt.close(fig)
        raise

    if show:
        plt.show()
    else:
        # Nobody will look at the figure, so release it right away.
        plt.close(fig)

# Example usage:
# plot_learning_curves(
#     log_files=['logs/FiRL_train_log.csv', 'logs/PPO_train_log.csv', 'logs/CVaR_PPO_train_log.csv'],
#     labels=['FiRL', 'PPO', 'CVaR-PPO'],
#     output_file='learning_curves.png'
# )
