#!/usr/bin/env python3
"""
Phase 1.2: Test Data Oracle Generation (for Evaluation) - vLLM-Accelerated Version
- Input: GSM8K (test), MATH500, AMC-23, AIME-24, AIME-25, Minerva, OlympiadBench
- Operation: use vLLM to accelerate greedy Qwen2.5-Math-7B inference (max_new_tokens=8192)
- Output: data/generations/test_ood_oracle_8k.parquet
"""

import os
import json
import pandas as pd
import numpy as np
from tqdm import tqdm
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from math_verify.metric import math_metric
from math_verify.parser import LatexExtractionConfig, ExprExtractionConfig

# ============== Configuration ==============
# Model checkpoint loaded by vLLM. An alternative checkpoint that has been used
# with this script: "./models/Qwen3-30B-A3B-Instruct-2507".
MODEL_NAME = "./models/Qwen2.5-Math-7B"

# Evaluation inputs; every benchmark file lives under one data directory.
_DATA_DIR = "./data"
GSM8K_TEST_PATH = f"{_DATA_DIR}/gsm8k_test_split.parquet"
MATH_PATH = f"{_DATA_DIR}/math_test.parquet"
AMC_PATH = f"{_DATA_DIR}/amc-23.parquet"
AIME_PATH = f"{_DATA_DIR}/aime-24.parquet"
AIME25_PATH = f"{_DATA_DIR}/aime-25.parquet"
MINERVA_PATH = f"{_DATA_DIR}/minerva.jsonl"
OLYMPIBENCH_PATH = f"{_DATA_DIR}/olympiad_bench.parquet"

# Destination parquet for the generated oracle records.
OUTPUT_PATH = f"{_DATA_DIR}/generations/test_ood_oracle_8k.parquet"

# Per-sample generation budget (the "8K" in the output file name).
MAX_NEW_TOKENS = 8 * 1024
# Tensor-parallel degree handed to vLLM's `tensor_parallel_size`.
TENSOR_PARALLEL_SIZE = 4

# Prompt wrapper; `{question}` is filled per sample via str.format, and the
# doubled braces render as a literal "\boxed{}" in the final prompt.
PROMPT_TEMPLATE = (
    "Solve the following math problem step by step. "
    "Show your work clearly and put the final answer inside \\boxed{{}}.\n"
    "\n"
    "Question: {question}\n"
    "\n"
    "Answer:"
)


# ============== Data Loading ==============
def _as_text(value):
    """Coerce a dataset cell into plain question text.

    Plain strings are returned unchanged. Chat-style prompts (a dict with a
    'content' key, or a list/array of such dicts) are reduced to their message
    content, preferring messages whose role is 'user'. This means the model
    sees the problem text rather than a Python/NumPy repr. Anything else falls
    back to ``str()``.

    NOTE(review): the 'prompt' columns of AIME-25 / OlympiadBench look
    verl-style (OlympiadBench carries 'reward_model'), where the prompt is
    usually a list of {'role', 'content'} messages. Confirm against the data.
    """
    if isinstance(value, str):
        return value
    if isinstance(value, dict):
        return str(value['content']) if 'content' in value else str(value)
    if isinstance(value, np.ndarray) and value.ndim == 0:
        return _as_text(value.item())
    if isinstance(value, (list, tuple, np.ndarray)):
        messages = list(value)
        user_messages = [
            m for m in messages if isinstance(m, dict) and m.get('role') == 'user'
        ]
        return "\n".join(_as_text(m) for m in (user_messages or messages))
    return str(value)


def _normalize_rows(rows, dataset, category, question_of, answer_of):
    """Map raw rows to the uniform {'question','gold_answer','dataset','category'} dicts."""
    return [
        {
            'question': _as_text(question_of(row)),
            'gold_answer': str(answer_of(row)),
            'dataset': dataset,
            'category': category,
        }
        for row in rows
    ]


def _read_jsonl(path):
    """Read a JSON-Lines file, skipping blank lines and warning on unparsable ones."""
    rows = []
    with open(path, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                rows.append(json.loads(line))
            except json.JSONDecodeError as e:
                print(f"  Warning: JSON parse failed at line {line_num}, skipping: {e}")
    return rows


def load_test_data():
    """Load every evaluation benchmark into one flat list of samples.

    Returns:
        list of dicts with keys:
            'question': prompt text.
            'gold_answer': string form of the reference answer.
            'dataset': short benchmark id.
            'category': 'IID' for GSM8K, 'OOD' for everything else.
        Samples are in load order: GSM8K, MATH500, AMC-23, AIME-24, AIME-25,
        Minerva, OlympiadBench.
    """
    print("Loading test datasets...")
    all_samples = []

    def add(label, rows, dataset, category, question_of, answer_of):
        # Normalize one benchmark's rows, append them, and report the count.
        all_samples.extend(_normalize_rows(rows, dataset, category, question_of, answer_of))
        print(f"    {label}: {len(rows)} samples")

    print("  Loading GSM8K test...")
    # NOTE(review): GSM8K 'answer' may be the full rationale ending in "#### N";
    # grading relies on math_verify extracting the final value. Confirm.
    add("GSM8K test", pd.read_parquet(GSM8K_TEST_PATH).to_dict('records'),
        'gsm8k', 'IID', lambda r: r['question'], lambda r: r['answer'])

    print("  Loading MATH test...")
    # Gold is the worked 'solution' text, presumably graded via its \boxed{} answer.
    add("MATH test", pd.read_parquet(MATH_PATH).to_dict('records'),
        'math500', 'OOD', lambda r: r['problem'], lambda r: r['solution'])

    print("  Loading AMC-23...")
    add("AMC-23", pd.read_parquet(AMC_PATH).to_dict('records'),
        'amc23', 'OOD', lambda r: r['question'], lambda r: r['answer'])

    print("  Loading AIME-24...")
    add("AIME-24", pd.read_parquet(AIME_PATH).to_dict('records'),
        'aime24', 'OOD', lambda r: r['problem'], lambda r: r['answer'])

    print("  Loading AIME-25...")
    add("AIME-25", pd.read_parquet(AIME25_PATH).to_dict('records'),
        'aime25', 'OOD', lambda r: r['prompt'], lambda r: r['solution'])

    print("  Loading Minerva...")
    add("Minerva", _read_jsonl(MINERVA_PATH),
        'minerva', 'OOD', lambda r: r['question'], lambda r: r['answer'])

    print("  Loading Olympiad Bench...")
    add("Olympiad Bench", pd.read_parquet(OLYMPIBENCH_PATH).to_dict('records'),
        'olympiad_bench', 'OOD',
        lambda r: r['prompt'], lambda r: r['reward_model']['ground_truth'])

    print(f"  Total test samples: {len(all_samples)}")
    return all_samples


# ============== Math Verification ==============
def create_verify_func():
    """Build the math_verify grader used to score generations.

    Gold answers and predictions are each parsed with a LaTeX extractor and a
    plain-expression extractor. Scores are aggregated with ``max``, and
    ``precision=6`` is passed through to ``math_metric``. The returned callable
    is invoked as ``grader([gold], [prediction])``.
    """
    def fresh_extractors():
        # A new pair of extraction configs for each side, as originally built.
        return (LatexExtractionConfig(), ExprExtractionConfig())

    return math_metric(
        gold_extraction_target=fresh_extractors(),
        pred_extraction_target=fresh_extractors(),
        precision=6,
        aggregation_function=max,
    )


# ============== vLLM Generation ==============
def run_vllm_generation(samples):
    """Generate one greedy completion per sample with vLLM and grade it.

    Args:
        samples: list of dicts with at least 'question', 'gold_answer',
            'dataset' and 'category' keys (as built by load_test_data).
            Optional 'level' / 'type' keys are copied through and default to ''.

    Returns:
        list of result dicts, one per sample and in the same order. Each holds:
            the generated text;
            the generated token ids, JSON-encoded;
            'oracle_length', the count of those ids;
            'is_correct', a math_verify correctness flag.
    """
    print(f"\nInitializing vLLM with tensor_parallel_size={TENSOR_PARALLEL_SIZE}...")

    llm = LLM(
        model=MODEL_NAME,
        tensor_parallel_size=TENSOR_PARALLEL_SIZE,
        dtype="bfloat16",
        trust_remote_code=True,
        # Context window must hold the prompt plus MAX_NEW_TOKENS of output.
        max_model_len=16384,
    )

    sampling_params = SamplingParams(
        temperature=0,  # Greedy decoding
        max_tokens=MAX_NEW_TOKENS,
        stop=None,
    )

    print("Preparing prompts...")
    prompts = [PROMPT_TEMPLATE.format(question=s['question']) for s in samples]

    print(f"Generating responses for {len(prompts)} samples...")
    outputs = llm.generate(prompts, sampling_params)

    print("Processing outputs...")
    results = []
    verify_func = create_verify_func()
    verify_failures = 0

    for sample, output in tqdm(zip(samples, outputs), total=len(outputs), desc="Processing"):
        completion = output.outputs[0]
        generated_text = completion.text

        # Keep the ids the model actually produced. Re-encoding the decoded
        # text is not guaranteed to round-trip to the same ids.
        # NOTE(review): vLLM's token_ids may include a terminating EOS token
        # that the decoded text omits; confirm downstream length consumers.
        gen_token_ids = [int(t) for t in completion.token_ids]
        oracle_length = len(gen_token_ids)

        # Best-effort grading: any verifier error counts as incorrect, but is
        # tallied so it is not silently lost.
        try:
            grade, _ = verify_func([sample['gold_answer']], [generated_text])
            is_correct = (grade == 1)
        except Exception:
            verify_failures += 1
            is_correct = False

        results.append({
            'question': sample['question'],
            'gold_answer': sample['gold_answer'],
            'dataset': sample['dataset'],
            'category': sample['category'],
            'level': sample.get('level', ''),
            'type': sample.get('type', ''),
            'full_generation_text': generated_text,
            'full_token_ids': json.dumps(gen_token_ids),
            'oracle_length': oracle_length,
            'is_correct': is_correct
        })

    if verify_failures:
        print(f"Warning: answer verification raised for {verify_failures} samples; counted as incorrect.")

    return results


# ============== Main ==============
def _sanitize_text_columns(df):
    """Coerce free-text columns of ``df`` (in place) to plain ``str`` for parquet output."""
    text_columns = ['question', 'gold_answer', 'full_generation_text', 'dataset', 'category']

    print("Sanitizing dataframe columns...")
    for col in text_columns:
        if col not in df.columns:
            continue
        # Unwrap single-element NumPy arrays; stringify everything else.
        df[col] = df[col].apply(
            lambda x: x.item() if isinstance(x, np.ndarray) and x.size == 1 else str(x)
        ).astype(str)
        print(f"Sanitized dataframe columns: {col}")


def _print_statistics(df):
    """Print per-dataset and per-category accuracy and oracle-length summaries."""
    print("\n=== Statistics ===")
    print(f"Total samples: {len(df)}")
    if df.empty:
        return

    # unique() keeps first-appearance (load) order.
    for dataset in df['dataset'].unique():
        subset = df[df['dataset'] == dataset]
        correct = subset['is_correct'].sum()
        print(f"\n{dataset.upper()}:")
        print(f"  Samples: {len(subset)}")
        print(f"  Correct: {correct} ({correct/len(subset)*100:.1f}%)")
        print(f"  Oracle length - Mean: {subset['oracle_length'].mean():.1f}, Max: {subset['oracle_length'].max()}")

    print("\n=== By Category ===")
    # Report every category present (IID and OOD), not only OOD.
    for cat in df['category'].unique():
        subset = df[df['category'] == cat]
        correct = subset['is_correct'].sum()
        print(f"{cat}: {len(subset)} samples, {correct/len(subset)*100:.1f}% accuracy")


def main():
    """Load test sets, generate with vLLM, save to OUTPUT_PATH, and print a report."""
    print("=" * 60)
    print("Phase 1.2: Test Oracle Generation (vLLM - 8K)")
    print("=" * 60)

    # 1. Load data
    samples = load_test_data()

    # 2. Run vLLM generation
    results = run_vllm_generation(samples)

    # 3. Save results
    df = pd.DataFrame(results)
    _sanitize_text_columns(df)

    output_dir = os.path.dirname(OUTPUT_PATH)
    if output_dir:  # os.makedirs('') raises for a bare file name
        os.makedirs(output_dir, exist_ok=True)
    df.to_parquet(OUTPUT_PATH, index=False)
    print(f"\nSaved to: {OUTPUT_PATH}")

    # 4. Statistics
    _print_statistics(df)


if __name__ == "__main__":
    main()