#!/usr/bin/env python3
import os
import json
import numpy as np
import re
from typing import List, Dict, Any, Tuple

def load_jsonl_file(file_path: str) -> List[Any]:
    """Load records from a JSON Lines file.

    Every line is parsed independently with ``json.loads``. Blank lines and
    lines that are not valid JSON are skipped silently, so the result may
    contain fewer entries than the file has lines.

    Args:
        file_path: Path of the JSONL file to read.

    Returns:
        The decoded value of each parseable line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file content is not valid UTF-8.
    """
    data = []
    # JSON text is UTF-8; being explicit avoids depending on the platform's
    # locale encoding, which would mis-decode non-ASCII content elsewhere.
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                # Blank separator / trailing lines carry no record.
                continue
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError:
                # Best-effort loader: a malformed line must not abort the
                # whole file, so it is dropped.
                continue
    return data

def extract_model_info_from_filename(filename: str, dataset_type: str = 'mcq') -> Tuple[str, int]:
    """Extract the model name and number of generations from a result filename.

    Expected layouts are ``ModelName_<split>_size_N_generations_M.jsonl``,
    where ``<split>`` is ``train``/``test`` for ``dataset_type == 'mcq'`` and
    ``test``/``eval`` for any other ``dataset_type``.

    The model name is the first underscore-delimited segment when it is
    directly followed by a split marker. Otherwise leading segments are
    joined with ``_`` until a stop word (the split markers, ``train``,
    ``test`` or ``size``) or a purely numeric segment is reached.

    Args:
        filename: File name (not a full path); a trailing ``.jsonl`` is
            ignored.
        dataset_type: ``'mcq'`` selects the train/test markers; every other
            value selects the test/eval markers.

    Returns:
        ``(model_name, num_generations)``. ``num_generations`` is taken from
        the first ``generations_<digits>`` occurrence and defaults to 1.
    """
    # Strip only the trailing extension; an interior ".jsonl" is part of the name.
    suffix = '.jsonl'
    stem = filename[:-len(suffix)] if filename.endswith(suffix) else filename

    # The two dataset families differ only in which words mark the split.
    if dataset_type == 'mcq':
        split_pattern = r'([^_]+)_(?:train|test)'
        stop_tokens = ('train', 'test', 'size')
    else:
        split_pattern = r'([^_]+)_(?:test|eval)'
        stop_tokens = ('train', 'test', 'eval', 'size')

    # A single `[^_]+` already covers hyphenated names; keeping the pattern
    # free of nested quantifiers avoids exponential backtracking when a name
    # with many hyphens does not match.
    head = re.match(split_pattern, stem)
    if head:
        model_name = head.group(1)
    else:
        # Fallback for names that contain underscores themselves: keep
        # segments until something that looks like metadata shows up.
        segments = stem.split('_')
        kept = [segments[0]]
        for segment in segments[1:]:
            if segment in stop_tokens or segment.isdigit():
                break
            kept.append(segment)
        model_name = '_'.join(kept)

    gen_match = re.search(r'generations_(\d+)', stem)
    num_generations = int(gen_match.group(1)) if gen_match else 1

    return model_name, num_generations

def test_simpleqa_processing():
    """Print a diagnostic report on the SimpleQA result files.

    Lists the hard-coded SimpleQA directory, keeps the ``.jsonl`` files whose
    name contains ``Qwen3-4B`` or ``Qwen3-8B``, and for each one prints the
    model info parsed from the filename, the number of loaded records, and --
    for the first record -- its keys, whether the expected judge field is
    present, and the first ``extracted_answer`` entry with its type.
    """
    simpleqa_path = '/fast/XXXX-3/forecasting/evals/freeform/SimpleQA/simpleqa-iclr'
    judge_field = 'score_Llama_4_Scout'
    wanted_models = ['Qwen3-4B', 'Qwen3-8B']

    print("Testing SimpleQA Data Processing")
    print("=" * 50)

    # Collect the JSONL files that belong to one of the models of interest.
    jsonl_files = []
    for entry in os.listdir(simpleqa_path):
        if not entry.endswith('.jsonl'):
            continue
        if any(tag in entry for tag in wanted_models):
            jsonl_files.append(entry)

    print(f"Found {len(jsonl_files)} relevant files:")

    for filename in jsonl_files:
        print(f"\n  Processing: {filename}")

        # Model info comes from the filename alone.
        model_name, num_generations = extract_model_info_from_filename(filename, 'judge')
        print(f"    Extracted model: {model_name}, generations: {num_generations}")

        data = load_jsonl_file(os.path.join(simpleqa_path, filename))
        print(f"    Loaded {len(data)} samples")

        if not data:
            # Nothing to inspect for an empty file.
            continue

        sample = data[0]
        print(f"    Sample keys: {list(sample.keys())}")

        # Report on the judge field, listing alternatives when it is absent.
        if judge_field not in sample:
            available_judges = [key for key in sample.keys() if key.startswith('score_')]
            print(f"    ✗ Judge field '{judge_field}' NOT found")
            print(f"    Available judge fields: {available_judges}")
        else:
            print(f"    ✓ Judge field '{judge_field}' found")
            scores = sample[judge_field]
            # Show at most the first two scores.
            preview = scores[:2] if len(scores) > 2 else scores
            print(f"    Sample judge scores: {preview}")

        # Show the first extracted answer and its type, when there is one.
        answers = sample['extracted_answer'] if 'extracted_answer' in sample else None
        if answers:
            print(f"    First extracted_answer: {answers[0]}")
            print(f"    Extracted answer type: {type(answers[0])}")

# Run the SimpleQA inspection only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    test_simpleqa_processing() 