import os
import glob
import json
import seaborn as sns
import matplotlib.pyplot as plt
import argparse
import re

def extract_step_number(folder_path):
    """Return the integer step encoded in a checkpoint folder path.

    The first ``step-<digits>_ck`` substring found anywhere in *folder_path*
    is used, e.g. ``"runs/step-00002000_ck"`` -> ``2000``. Paths without
    such a substring map to ``0``.
    """
    found = re.search(r'step-(\d+)_ck', folder_path)
    if found is None:
        return 0
    return int(found.group(1))

def process_checkpoints(directory):
    # Initialize dictionaries to store scores for each dataset
    scores = {
        'ArguAna': [],
        'NFCorpus': [],
        'SciFact': []
    }
    steps = []
    
    # Get all checkpoint folders and sort them
    ckpt_folders = glob.glob(os.path.join(directory, 'step-*_ck'))
    ckpt_folders.sort(key=extract_step_number)
    
    # Process each checkpoint folder
    for folder in ckpt_folders:
        step = extract_step_number(folder)
        steps.append(step)
        
        # Process each dataset JSON file
        for dataset in scores.keys():
            json_path = os.path.join(folder, f'{dataset}.json')
            try:
                with open(json_path, 'r') as f:
                    data = json.load(f)
                    scores[dataset].append(data['scores']['test'][0]['main_score'])
            except FileNotFoundError:
                print(f"Warning: {json_path} not found")
                scores[dataset].append(None)
            except KeyError:
                print(f"Warning: main_score not found in {json_path}")
                scores[dataset].append(None)
    
    return steps, scores

def create_plot(steps, scores, output_dir):
    """Draw one line subplot per dataset and save it as ``scores_plot.jpg``.

    Args:
        steps: x-axis values (checkpoint step numbers), one per checkpoint.
        scores: mapping of dataset name -> list of scores aligned with
            *steps*; ``None`` entries mark missing results.
        output_dir: directory the JPEG is written into.

    Returns:
        The path of the saved image.
    """
    # Must run before the axes are created for the style to apply to them.
    sns.set_style("whitegrid")

    # At least one row so an empty mapping still yields a (blank) figure
    # instead of plt.subplots rejecting nrows=0.
    n_rows = max(len(scores), 1)
    # squeeze=False keeps `axes` 2-D for any row count, including 1.
    fig, axes = plt.subplots(n_rows, 1, figsize=(10, 4 * n_rows), squeeze=False)
    fig.suptitle('Main Scores Across Checkpoints', fontsize=16)

    colors = ['blue', 'green', 'red']
    for i, (dataset, scores_list) in enumerate(scores.items()):
        ax = axes[i, 0]
        # Replace None with NaN so the plotting library sees a missing float
        # value rather than a Python object.
        y_values = [float('nan') if s is None else s for s in scores_list]
        sns.lineplot(x=steps, y=y_values, ax=ax,
                     color=colors[i % len(colors)], marker='o')
        ax.set_title(f'{dataset} Scores')
        ax.set_xlabel('Steps')
        ax.set_ylabel('Main Score')

    fig.tight_layout()
    output_path = os.path.join(output_dir, 'scores_plot.jpg')
    fig.savefig(output_path)
    # pyplot keeps figures alive until explicitly closed.
    plt.close(fig)
    print(f"Plot saved to: {output_path}")
    return output_path

def main():
    """CLI entry point: collect scores from ``--result_dir`` and plot them there."""
    parser = argparse.ArgumentParser(description='Process checkpoint folders and create score plots')
    # Required: without a value, args.result_dir would be None and the run
    # would later fail inside os.path.join with an unhelpful TypeError.
    parser.add_argument('--result_dir', required=True, help='Directory containing checkpoint folders')
    args = parser.parse_args()
    if not os.path.isdir(args.result_dir):
        parser.error(f"--result_dir {args.result_dir!r} is not a directory")

    # Process the checkpoints and save the plot next to them.
    steps, scores = process_checkpoints(args.result_dir)
    create_plot(steps, scores, args.result_dir)

if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()