import pandas as pd
import matplotlib.pyplot as plt
import os
import glob
import re

# Base directory containing all checkpoint folders
base_dir = "/fsx/training/output/user/evaluation_v3/L1_Mentalese/no_budget/Agentica24k_1.5b_mentalese_cot_lr_1e-6_SFT30k_ins_suffix_ckpt_8910_grpo_with_l1_exact_range_64_512_ray_multinode_1024_response"

# Captures the numeric step from a folder name such as "global_step_120".
_STEP_PATTERN = re.compile(r'global_step_(\d+)')


def checkpoint_step(path):
    """Return the global step encoded in the last component of *path*.

    Args:
        path: Filesystem path whose final component is expected to look
            like ``global_step_<digits>``.

    Returns:
        The step as an ``int``, or ``None`` when the final component does
        not contain ``global_step_`` followed by digits.
    """
    match = _STEP_PATTERN.search(os.path.basename(os.path.normpath(path)))
    return int(match.group(1)) if match else None


# Find all global_step directories, ordered by numeric step.
# The glob also matches names without a numeric suffix (for example
# "global_step_final"); those are dropped here so that sorting cannot fail
# on a missing regex match.
checkpoint_dirs = sorted(
    (
        path
        for path in glob.glob(os.path.join(base_dir, "global_step_*"))
        if checkpoint_step(path) is not None
    ),
    key=checkpoint_step,
)

print(f"Found {len(checkpoint_dirs)} checkpoint directories")

# Collect data from all checkpoints.
# Each entry is a dict with the keys 'checkpoint', 'dataset' and 'pass@1'.
all_data = []

for checkpoint_dir in checkpoint_dirs:
    # Extract step number; paths without a numeric step are skipped.
    step_match = re.search(r'global_step_(\d+)', checkpoint_dir)
    if not step_match:
        continue

    step = int(step_match.group(1))
    pass_csv_path = os.path.join(checkpoint_dir, "pass.csv")

    if not os.path.exists(pass_csv_path):
        print(f"pass.csv not found in {checkpoint_dir}")
        continue

    try:
        df = pd.read_csv(pass_csv_path)
        # Build this checkpoint's rows completely before touching all_data,
        # so a file that fails partway through contributes no partial rows.
        checkpoint_rows = [
            {
                'checkpoint': step,
                # Remove .parquet extension
                'dataset': dataset.replace('.parquet', ''),
                'pass@1': pass_1,
            }
            for dataset, pass_1 in zip(df['dataset'], df['pass@1'])
        ]
    except Exception as e:
        # Deliberately best-effort: report the unreadable or malformed file
        # and carry on with the remaining checkpoints.
        print(f"Error processing {pass_csv_path}: {e}")
    else:
        all_data.extend(checkpoint_rows)
        print(f"Processed checkpoint {step}")

# Convert the collected rows to a DataFrame
# (columns: 'checkpoint', 'dataset', 'pass@1').
results_df = pd.DataFrame(all_data)

if results_df.empty:
    print("No data found!")
    # Raise SystemExit directly: the bare `exit` helper is installed by the
    # `site` module and is not available in every interpreter configuration
    # (for example `python -S`).
    raise SystemExit(1)

print(f"Collected data for {len(results_df)} benchmark results")
# .tolist() converts NumPy integers to plain ints so the printed list reads
# as [10, 20, ...] rather than [np.int64(10), ...] under NumPy 2.
print(f"Checkpoints: {sorted(results_df['checkpoint'].unique().tolist())}")
print(f"Datasets: {sorted(results_df['dataset'].unique())}")

# Create the plot using explicit Figure/Axes objects rather than pyplot's
# implicit "current figure" state.
fig, ax = plt.subplots(figsize=(12, 8))

# Color palette for different benchmarks. The first five entries are the
# original palette; five more are added so that up to ten benchmarks get
# distinct colours.
colors = [
    '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
    '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
]
# Once the palette wraps around, the line style changes so that benchmarks
# sharing a colour can still be told apart. The first pass uses solid lines.
linestyles = ['-', '--', ':', '-.']

# Plot pass@1 for each dataset, one line per benchmark ordered by checkpoint
datasets = sorted(results_df['dataset'].unique())
for i, dataset in enumerate(datasets):
    dataset_data = results_df[results_df['dataset'] == dataset].sort_values('checkpoint')

    ax.plot(
        dataset_data['checkpoint'],
        dataset_data['pass@1'],
        marker='o',
        linewidth=3,
        markersize=8,
        color=colors[i % len(colors)],
        linestyle=linestyles[(i // len(colors)) % len(linestyles)],
        label=dataset,
    )

ax.set_title('Benchmark Accuracy vs Checkpoint (Pass@1)', fontsize=20, pad=20)
ax.set_xlabel('Checkpoint (Global Step)', fontsize=14)
ax.set_ylabel('Pass@1 Accuracy', fontsize=14)
ax.legend(fontsize=12, loc='upper left', bbox_to_anchor=(0.01, 0.99))
ax.grid(True, alpha=0.3)
# NOTE(review): the fixed 0-1 range assumes pass@1 is stored as a fraction;
# confirm against the contents of pass.csv.
ax.set_ylim(0, 1)
fig.tight_layout()

# Save the plot in the base directory
plot_path_png = os.path.join(base_dir, 'benchmark_accuracy_plot.png')
fig.savefig(plot_path_png, dpi=300, bbox_inches='tight')
plt.close(fig)

print(f"Plot saved as '{plot_path_png}'")

# Report the min, max and mean pass@1 per dataset, rounded to four decimals
print("\nSummary of collected data:")
pass_at_1_by_dataset = results_df.groupby('dataset')['pass@1']
summary_table = pass_at_1_by_dataset.agg(['min', 'max', 'mean'])
print(summary_table.round(4))
