import json
import multiprocessing as mp
from functools import partial
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
import os

from simplified_evaluator.eval import parse_prediction
from datasets import load_dataset

def process_batch(batch_questions, data, gt_dict):
    """Parse every sampled reply for each question in one batch.

    Args:
        batch_questions: Iterable of question keys to handle.
        data: Mapping from question key to an iterable of replies.
        gt_dict: Mapping from question key to a dict holding the reference
            answer under the ``'ground_truth'`` key.

    Returns:
        A dict mapping each successfully handled question to the list of
        ``parse_prediction(reply, ground_truth, 'math')`` results, one per
        reply. A question whose handling raises is reported on stdout and
        left out of the result instead of aborting the batch.
    """
    parsed = {}
    for question in batch_questions:
        try:
            reference = gt_dict[question]['ground_truth']
            answers = []
            for reply in data[question]:
                answers.append(parse_prediction(reply, reference, 'math'))
        except Exception as e:
            # Best effort: one bad question must not sink the rest of the batch.
            print(f"Error processing question: {question[:50]}..., {str(e)}")
            continue
        parsed[question] = answers
    return parsed

def main():
    """Parse all sampled replies in a samples JSON file and save the results.

    Steps:
      1. Load the samples JSON at ``JSON_DIRS``; its keys are used as the
         questions and its values as the replies handed to ``process_batch``.
      2. Load the test split of the dataset selected by ``dataset_name`` and
         build a ground-truth lookup keyed by problem text.
      3. Split the questions into batches and run ``process_batch`` on them
         in a process pool.
      4. Merge the per-batch results and write them next to the input file
         with a ``_parsed`` suffix before the extension.

    Raises:
        ValueError: If ``dataset_name`` is neither "math500" nor "gsm8k".
    """
    dataset_name = "math500"  # Change this to "math500" or "gsm8k" as needed
    # Load the samples json
    # JSON_DIRS = "Qwen2.5_1.5B_GSM8K_samples/samples.json"
    # JSON_DIRS = "Llama3.1_8B_MATH500_samples/samples.json"
    JSON_DIRS = "Qwen2.5_1.5B_MATH500_samples/samples.json"
    with open(JSON_DIRS, "r", encoding="utf-8") as f:
        data = json.load(f)
    questions = list(data.keys())

    if dataset_name == "math500":
        ds = load_dataset("HuggingFaceH4/MATH-500")['test']
        question_field = "problem"
    elif dataset_name == "gsm8k":
        ds = load_dataset("gsm8k", "main", split="test")
        question_field = "question"
    else:
        # Fail with a clear message instead of an unbound-variable error below.
        raise ValueError(
            f"Unsupported dataset_name {dataset_name!r}; expected 'math500' or 'gsm8k'"
        )

    gt_data = {}
    for row in ds:
        gt = row["answer"]
        if dataset_name == "gsm8k":
            # Keep only the text after the last '####' marker in the answer.
            gt = gt.split('####')[-1].strip()
        gt_data[row[question_field]] = {
            'ground_truth': gt
        }

    # Set up multiprocessing
    # os.cpu_count() may return None when the count cannot be determined.
    num_processes = max(1, (os.cpu_count() or 1) - 2)  # Leave some cores free
    total_questions = len(questions)
    batch_size = max(10, total_questions // (num_processes * 4))  # Aim for ~4 batches per process

    print(f"Using {num_processes} processes with batch size {batch_size} for {total_questions} questions")

    # Create batches of questions
    batches = [questions[i:i + batch_size] for i in range(0, total_questions, batch_size)]

    # Give every task only the entries its own batch needs. The callable and
    # its arguments are pickled for each submitted task, so binding the full
    # dicts to the callable would re-serialize all of them once per batch.
    batch_samples = [{q: data[q] for q in batch} for batch in batches]
    # Questions without ground truth are left out so process_batch still
    # reports them through its own KeyError handling.
    batch_truths = [
        {q: gt_data[q] for q in batch if q in gt_data} for batch in batches
    ]

    # Process batches in parallel
    with ProcessPoolExecutor(max_workers=num_processes) as executor:
        results = list(tqdm(
            executor.map(process_batch, batches, batch_samples, batch_truths),
            total=len(batches),
            desc="Processing batches"
        ))

    # Combine batch results
    parsed_answers = {}
    for batch_result in results:
        parsed_answers.update(batch_result)

    # Write to another json file. splitext guarantees the output path differs
    # from the input path even if the input does not end in ".json".
    root, ext = os.path.splitext(JSON_DIRS)
    new_json_path = f"{root}_parsed{ext}"
    with open(new_json_path, "w", encoding="utf-8") as f:
        json.dump(parsed_answers, f)

    print(f"Processed {len(parsed_answers)} questions out of {total_questions}. Results saved to {new_json_path}")

if __name__ == "__main__":
    # Entry-point guard: main() runs only when this file is executed as a
    # script, not when it is imported. This matters for the process pool:
    # with the "spawn" start method, worker processes re-import this module,
    # and the guard keeps them from running main() again. No start method is
    # set explicitly here, so ProcessPoolExecutor uses the platform default.
    main()