#!/usr/bin/env python3
import os
import json
import numpy as np
import re
from typing import List, Dict, Any, Tuple

# Datasets to smoke-test, keyed by display name. Each config holds:
#   path        -- directory containing the per-model ``*.jsonl`` result files
#   type        -- 'mcq' (multiple choice) or 'judge' (free-form, judge-scored)
#   judge_field -- name of the judge score key expected in each sample
#                  (None for MCQ datasets)
DATASETS = {
    'MMLU Pro': dict(
        path='/fast/XXXX-3/forecasting/evals/mmlu_pro/mmlu_pro',
        type='mcq',
        judge_field=None,
    ),
    'GPQA': dict(
        path='/fast/XXXX-3/forecasting/evals/gpqa/gpqa_diamond',
        type='mcq',
        judge_field=None,
    ),
    'SimpleQA': dict(
        path='/fast/XXXX-3/forecasting/evals/freeform/SimpleQA/simpleqa-iclr',
        type='judge',
        judge_field='score_Llama_4_Scout',
    ),
}

def load_jsonl_file(file_path):
    """Load records from a JSON Lines file.

    Each non-blank line is parsed as one JSON value. Blank lines and lines
    that fail to parse as JSON are skipped, so one bad line does not abort
    loading the rest of the file.

    Args:
        file_path: Path to the ``.jsonl`` file.

    Returns:
        List of parsed records, in file order.
    """
    data = []
    # JSONL is UTF-8; don't depend on the platform's locale default encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError:
                # Best-effort: skip malformed lines rather than failing the file.
                continue
    return data

# Underscore-separated tokens that end the model name in free-form result
# filenames (``ModelName_eval_size_N_generations_M.jsonl``).
_FREEFORM_NAME_STOP_TOKENS = frozenset({'eval', 'train', 'test', 'size', 'generations'})


def _join_until_stop(parts, stop_tokens):
    """Join parts[0] and the following parts with '_' until a stop token or an all-digit part."""
    name_parts = parts[:1]
    for part in parts[1:]:
        if part in stop_tokens or part.isdigit():
            break
        name_parts.append(part)
    return '_'.join(name_parts)


def extract_model_info_from_filename(filename, dataset_type='mcq'):
    """Parse the model name and number of generations from a result filename.

    Expected layouts:
      * MCQ (MMLU Pro, GPQA): ``ModelName_<train|test>_size_N_generations_M.jsonl``
      * free-form (SimpleQA, ...): ``ModelName_eval_size_N_generations_M.jsonl``

    The model name may itself contain underscores; it ends at the first
    layout marker token (or all-digit token).

    Args:
        filename: Base name of the result file; a trailing ``.jsonl`` is ignored.
        dataset_type: ``'mcq'`` selects the MCQ layout; any other value selects
            the free-form layout.

    Returns:
        ``(model_name, num_generations)``. ``num_generations`` is 1 when the
        filename has no ``generations_M`` field.
    """
    # Strip only a trailing extension, not '.jsonl' occurring elsewhere in the name.
    stem = filename[:-len('.jsonl')] if filename.endswith('.jsonl') else filename
    parts = stem.split('_')

    if dataset_type == 'mcq':
        # Fast path: an underscore-free model name directly followed by the split.
        model_match = re.match(r'([^_]+)_(?:train|test)', stem)
        if model_match:
            model_name = model_match.group(1)
        else:
            model_name = _join_until_stop(parts, {'train', 'test', 'size'})
    else:
        model_name = _join_until_stop(parts, _FREEFORM_NAME_STOP_TOKENS)

    gen_match = re.search(r'generations_(\d+)', stem)
    num_generations = int(gen_match.group(1)) if gen_match else 1

    return model_name, num_generations

def test_dataset_processing(datasets=None, models=('Qwen3-4B', 'Qwen3-8B'), max_files=4):
    """Smoke-test filename parsing and JSONL loading on each configured dataset.

    For every dataset, lists the ``.jsonl`` files whose names contain one of
    ``models``. For the first ``max_files`` of them (in sorted order), it
    prints the parsed model name and generation count, the number of loaded
    records, and a few fields of the first record.

    Args:
        datasets: Mapping of display name -> config dict with ``path``,
            ``type`` and (optionally) ``judge_field`` keys. Defaults to
            ``DATASETS``.
        models: Substrings selecting which result files to inspect.
        max_files: Maximum number of files inspected per dataset.
    """
    if datasets is None:
        datasets = DATASETS

    for dataset_name, config in datasets.items():
        print(f"\n{'='*50}")
        print(f"Testing {dataset_name}")
        print(f"{'='*50}")

        dataset_path = config['path']
        dataset_type = config['type']

        # isdir rather than exists: os.listdir below raises on a non-directory path.
        if not os.path.isdir(dataset_path):
            print(f"Warning: Dataset path {dataset_path} does not exist or is not a directory")
            continue

        # os.listdir returns entries in arbitrary order; sort so the subset
        # inspected below is the same on every run.
        jsonl_files = sorted(
            f for f in os.listdir(dataset_path)
            if f.endswith('.jsonl') and any(model in f for model in models)
        )

        print(f"Found {len(jsonl_files)} relevant files:")

        for filename in jsonl_files[:max_files]:
            print(f"\n  Testing: {filename}")

            model_name, num_generations = extract_model_info_from_filename(filename, dataset_type)
            print(f"    Extracted model: {model_name}, generations: {num_generations}")

            # The whole file is loaded so the record count can be reported.
            file_path = os.path.join(dataset_path, filename)
            data = load_jsonl_file(file_path)
            print(f"    Loaded {len(data)} samples")

            if not data:
                continue

            sample = data[0]
            # Only JSON objects have keys to inspect; report anything else and move on.
            if not isinstance(sample, dict):
                print(f"    First sample is a {type(sample).__name__}, not an object")
                continue

            print(f"    Sample keys: {list(sample.keys())}")

            if dataset_type == 'mcq':
                if 'answer' in sample:
                    print(f"    Correct answer: {sample['answer']}")
            else:  # judge
                judge_fields = [k for k in sample if k.startswith('score_')]
                print(f"    Available judge fields: {judge_fields}")
                configured_field = config.get('judge_field')
                if configured_field:
                    print(f"    Configured judge field {configured_field!r} present: "
                          f"{configured_field in sample}")

            # Shown for both dataset types. NOTE(review): presumably a list with
            # one entry per generation -- confirm against the writer.
            extracted = sample.get('extracted_answer')
            if extracted:
                print(f"    First extracted_answer: {extracted[0]}")

if __name__ == '__main__':
    # Script entry point: run the dataset smoke checks.
    test_dataset_processing()