#!/usr/bin/env python3
"""
Phase 1.1: Training Data Labeling (For Predictor) - vLLM Accelerated Version
- Input: GSM8K Train, MATH Train
- Operation: Use vLLM to accelerate Qwen2.5-Math-7B inference (max_new_tokens=4096)
- Filter: Keep only samples whose final answer is verified correct
- Label: Compute the token-ID length of each generated text
- Output: data/cache/train_labeled_8k.parquet (see OUTPUT_PATH)
"""

import os
import json
import pandas as pd
import numpy as np
from tqdm import tqdm
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from math_verify.metric import math_metric
from math_verify.parser import LatexExtractionConfig, ExprExtractionConfig

# ============== Configuration ==============
# Local model path; both the vLLM engine and the length-measuring tokenizer load from here.
MODEL_NAME = "./models/Qwen2.5-Math-7B"
# Alternative model, kept for reference.
# MODEL_NAME = "./models/Qwen3-30B-A3B-Instruct-2507"
GSM8K_TRAIN_PATH = "./data/gsm8k_train_split.parquet"  # GSM8K data (skipped with a message if missing)
MATH_TRAIN_PATH = "./data/math_train.parquet"  # MATH train data (read unconditionally)
# OUTPUT_PATH = "./data/cache/train_labeled_30B_16k.parquet"
# NOTE(review): the "8k" in this file name does not match MAX_NEW_TOKENS (4096);
# presumably it refers to max_model_len=8192 used for vLLM — confirm before relying on it.
OUTPUT_PATH = "./data/cache/train_labeled_8k.parquet"

# Per-sample generation budget, passed to SamplingParams(max_tokens=...).
MAX_NEW_TOKENS = 4096
# Number of GPUs vLLM shards the model across. It must divide the model's attention-head
# count; the divisors of 28 are 1, 2, 4, 7, 14, 28.
# NOTE(review): "28 heads" comes from the original comment — confirm against the model's config.json.
TENSOR_PARALLEL_SIZE = 4  # Qwen2.5-Math-7B has 28 attention heads, can only use factors of 28 (1,2,4,7)

# Prompt Template - Must match testing phase
# Filled via str.format(question=...); the doubled braces "{{}}" render as a literal
# "{}", so the model sees "\boxed{}".
PROMPT_TEMPLATE = """Solve the following math problem step by step. Show your work clearly and put the final answer inside \\boxed{{}}.

Question: {question}

Answer:"""


# ============== Data Loading ==============
def load_train_data():
    """Collect GSM8K and MATH training problems into one flat list.

    GSM8K is optional: if ``GSM8K_TRAIN_PATH`` does not exist, a message is
    printed and only MATH is loaded. ``MATH_TRAIN_PATH`` is read without an
    existence check, so a missing file propagates the error from
    ``pd.read_parquet``.

    Returns:
        list[dict]: GSM8K records first, then MATH records. Each record has
        ``question``, ``gold_answer`` and ``dataset``. MATH records also
        carry ``level`` and ``type``, defaulting to ``'Unknown'`` when the
        column is absent.
    """
    print("Loading training datasets...")

    records = []

    print("  Loading GSM8K...")
    if not os.path.exists(GSM8K_TRAIN_PATH):
        print(f"    GSM8K file not found: {GSM8K_TRAIN_PATH}")
    else:
        gsm8k_frame = pd.read_parquet(GSM8K_TRAIN_PATH)
        records.extend(
            {
                'question': row['question'],
                'gold_answer': row['answer'],
                'dataset': 'gsm8k',
            }
            for _, row in gsm8k_frame.iterrows()
        )
        print(f"    GSM8K: {len(gsm8k_frame)} samples")

    print("  Loading MATH train...")
    math_frame = pd.read_parquet(MATH_TRAIN_PATH)
    records.extend(
        {
            'question': row['problem'],
            'gold_answer': row['solution'],
            'dataset': 'math',
            # Series.get falls back to the default when the column is missing.
            'level': row.get('level', 'Unknown'),
            'type': row.get('type', 'Unknown'),
        }
        for _, row in math_frame.iterrows()
    )
    print(f"    MATH train: {len(math_frame)} samples")

    print(f"  Total training samples: {len(records)}")
    return records


# ============== Math Verification ==============
def create_verify_func():
    """Build the math_verify scorer used to grade generated answers.

    Gold answers and predictions are both parsed with LaTeX extraction and
    then plain-expression extraction. The best score over candidate matches
    is kept (``max``), and numbers are compared at 6-digit precision.
    """
    def _extraction_targets():
        # A fresh pair of config objects per side, matching the original
        # construction (gold first, then prediction).
        return (LatexExtractionConfig(), ExprExtractionConfig())

    return math_metric(
        gold_extraction_target=_extraction_targets(),
        pred_extraction_target=_extraction_targets(),
        aggregation_function=max,
        precision=6,
    )


# ============== vLLM Generation ==============
def run_vllm_generation(samples):
    """Generate one greedy completion per sample with vLLM and grade it.

    Args:
        samples: list of dicts as built by ``load_train_data``. Each has
            ``question``, ``gold_answer`` and ``dataset`` keys and may have
            ``level`` / ``type``.

    Returns:
        list[dict]: one record per sample, in input order, containing:
        the sample metadata; ``generated_text``; ``gen_length`` (token count
        from re-encoding ``generated_text`` with the model tokenizer, without
        special tokens); ``is_correct``; and ``finish_reason``, vLLM's reason
        the completion ended (``"length"`` means ``MAX_NEW_TOKENS`` was hit,
        so ``gen_length`` is a truncated lower bound).

    Grading exceptions are not raised. The sample is marked incorrect and a
    summary count is printed at the end.
    """
    print(f"\nInitializing vLLM with tensor_parallel_size={TENSOR_PARALLEL_SIZE}...")

    llm = LLM(
        model=MODEL_NAME,
        tensor_parallel_size=TENSOR_PARALLEL_SIZE,
        dtype="bfloat16",
        trust_remote_code=True,
        max_model_len=8192,  # Room for the prompt plus up to MAX_NEW_TOKENS of output
    )

    # HF tokenizer used only to measure the length of the generated text.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    sampling_params = SamplingParams(
        temperature=0,  # Greedy decoding
        max_tokens=MAX_NEW_TOKENS,
        stop=None,
    )

    print("Preparing prompts...")
    prompts = [PROMPT_TEMPLATE.format(question=s['question']) for s in samples]

    print(f"Generating responses for {len(prompts)} samples...")
    outputs = llm.generate(prompts, sampling_params)

    print("Processing outputs...")
    results = []
    verify_func = create_verify_func()
    verify_errors = 0
    truncated = 0

    # LLM.generate returns outputs in the same order as the prompts, so zip
    # pairs each output with the sample that produced it.
    for sample, output in zip(samples, tqdm(outputs, desc="Processing")):
        completion = output.outputs[0]
        generated_text = completion.text
        finish_reason = completion.finish_reason
        if finish_reason == "length":
            truncated += 1

        # Calculate generation length (using Qwen tokenizer)
        gen_length = len(tokenizer.encode(generated_text, add_special_tokens=False))

        # A grading failure (e.g. unparsable output) should not abort a long
        # generation run. Score it as incorrect, but count it instead of
        # swallowing it silently.
        try:
            grade, _ = verify_func([sample['gold_answer']], [generated_text])
            is_correct = (grade == 1)
        except Exception:
            verify_errors += 1
            is_correct = False

        results.append({
            'question': sample['question'],
            'gold_answer': sample['gold_answer'],
            'dataset': sample['dataset'],
            'level': sample.get('level', ''),
            'type': sample.get('type', ''),
            'generated_text': generated_text,
            'gen_length': gen_length,
            'is_correct': is_correct,
            'finish_reason': finish_reason,
        })

    if truncated:
        print(f"  {truncated} generations hit the {MAX_NEW_TOKENS}-token limit "
              f"(finish_reason='length')")
    if verify_errors:
        print(f"  Warning: answer verification raised for {verify_errors} samples "
              f"(counted as incorrect)")

    return results


# ============== Main ==============
def main():
    """Run the labeling pipeline: load, generate, filter correct, save, report.

    Only samples whose answer was verified correct are written to
    ``OUTPUT_PATH``, together with a ``log_length = log1p(gen_length)``
    column. If nothing is correct, no file is written, so an existing output
    is left untouched.
    """
    print("=" * 60)
    print("Phase 1.1: Training Data Label Generation (vLLM)")
    print("=" * 60)

    # 1. Load data
    samples = load_train_data()

    # 2. Run vLLM generation
    results = run_vllm_generation(samples)

    # 3. Filter correct samples
    correct_results = [r for r in results if r['is_correct']]
    print("\nFiltering results...")
    print(f"  Total generated: {len(results)}")
    print(f"  Correct answers: {len(correct_results)}")
    if results:
        print(f"  Accuracy: {len(correct_results)/len(results)*100:.1f}%")
    else:
        # Guard against ZeroDivisionError when no samples were loaded.
        print("  Accuracy: n/a (no samples generated)")

    if not correct_results:
        # pd.DataFrame([]) has no 'gen_length' column, so the code below would
        # raise KeyError. Stop cleanly instead.
        print("\nNo correct samples to save; skipping output.")
        return

    # 4. Save results
    df = pd.DataFrame(correct_results)

    # Add log-scale length (for training)
    df['log_length'] = np.log1p(df['gen_length'])

    output_dir = os.path.dirname(OUTPUT_PATH)
    if output_dir:  # os.makedirs('') raises for a bare file name
        os.makedirs(output_dir, exist_ok=True)
    df.to_parquet(OUTPUT_PATH, index=False)
    print(f"\nSaved to: {OUTPUT_PATH}")

    # Statistics
    lengths = df['gen_length']
    print("\n=== Statistics ===")
    print(f"Total correct samples: {len(df)}")
    print(f"  GSM8K: {int((df['dataset'] == 'gsm8k').sum())}")
    print(f"  MATH: {int((df['dataset'] == 'math').sum())}")
    print("\nLength distribution:")
    print(f"  Mean: {lengths.mean():.1f}")
    print(f"  Median: {lengths.median():.1f}")
    print(f"  Min: {lengths.min()}")
    print(f"  Max: {lengths.max()}")
    print(f"  Std: {lengths.std():.1f}")


# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()