#!/usr/bin/env python3
"""
Batched progressive evaluation script for finding challenging problems in NuminaMath dataset.

This version processes data in batches to handle large datasets (100K+ examples)
without running out of memory. It's designed for distributed evaluation where
each job might handle thousands of examples.

Key features:
- Processes data in configurable batch sizes
- Saves intermediate results to disk regularly
- Tracks progress across batches
- Memory-efficient for large-scale evaluation
"""

import json
import os
import numpy as np
from tqdm import tqdm
import multiprocessing
from functools import partial
import re
import unicodedata
from typing import Dict, List, Tuple, Optional
import time
import gc

import datasets
import vllm
import transformers
import torch

# Import chat template handling
from chat_templates import set_chat_template

# Default instruction for math problems: asks the model to reason step by step
# and to wrap its final answer in \boxed{}.
# NOTE(review): not referenced by the code visible in this file — presumably
# appended to prompts upstream; confirm before removing.
INSTRUCTION = ' Think step by step, and put your final answer within \\boxed{}.'

# Regex used previously to find the answer in the \boxed{} command (kept for reference).
# It captures everything up to the first '}', so it cannot return the full
# content of a box with nested braces such as \boxed{\frac{1}{2}}.
_BOX_REGEX = re.compile(r"\\boxed\s*{([^}]*)}")

def _normalize_for_compare(text: str) -> str:
    """Lightweight LaTeX-ish normalization to reduce false negatives.
    """
    if text is None:
        return ""
    s = str(text)
    s = s.replace("\\dfrac", "\\frac").replace("\\tfrac", "\\frac")
    s = re.sub(r"^\$+|\$+$", "", s)
    s = s.replace("\\left", "").replace("\\right", "")
    # Remove common spacing commands and tildes
    for tok in ["\\,", "\\!", "\\:", "\\;", "\\quad", "\\qquad", "~", "\\ "]:
        s = s.replace(tok, "")
    # Normalize some unicode variants
    s = s.replace("−", "-").replace("–", "-").replace("—", "-")
    s = s.replace("°", "^{\\circ}")
    # Strip simple wrappers that only affect styling
    try:
        s = re.sub(r"\\(mathrm|operatorname|text)\s*\{([^{}]*)\}", r"\\2", s)
    except Exception:
        pass
    # Trim trailing punctuation/spaces
    s = re.sub(r"[.,;:\\s]+$", "", s)
    s = re.sub(r"\s+", "", s)
    return s

def _extract_last_boxed(text: str) -> str:
    """Extract content of the last \\boxed{...} handling nested braces.

    This scans for the last "\\boxed{" (allowing whitespace) and then walks
    forward tracking brace depth to find the matching closing brace. Correctly
    handles nested LaTeX like \\frac{2}{5}.
    """
    if not text:
        return ""
    # Support optional arguments inside []: \boxed[...]{...}
    #start_pattern = re.compile(r"\\boxed(?:\\[[^\\]]*\\])?\\s*\\{")
    start_pattern = re.compile(r"\\boxed\s*\{")
    #start_pattern = re.compile(r"\\boxed\\s*\\{") # bad
    matches = list(start_pattern.finditer(text))
    if not matches:
        return ""
    open_brace_index = matches[-1].end() - 1  # position of '{'
    i = open_brace_index + 1
    depth = 1
    n = len(text)
    while i < n and depth > 0:
        ch = text[i]
        if ch == '{':
            depth += 1
        elif ch == '}':
            depth -= 1
        i += 1
    if depth != 0:
        return ""
    return text[open_brace_index + 1:i - 1].strip()

# Try to import math_verify for better grading. MATH_VERIFY_AVAILABLE records
# whether the import succeeded so graders can choose between symbolic
# verification and normalized string comparison.
try:
    from math_verify.parser import LatexExtractionConfig, ExprExtractionConfig
    from math_verify import parse, verify
    MATH_VERIFY_AVAILABLE = True
except ImportError:
    MATH_VERIFY_AVAILABLE = False

# Report the outcome only when EVAL_DEBUG is set in the environment, so that
# normal runs stay quiet.
if os.environ.get("EVAL_DEBUG"):
    if MATH_VERIFY_AVAILABLE:
        print("math_verify available")
    else:
        print("Warning: math_verify not available, falling back to simple string matching")


def grade_single_response(task: Tuple[int, str, str]) -> Tuple[int, bool]:
    """Grade a single response against ground truth answer.

    Args:
        task: ``(index, response, gt_answer)``. ``index`` is passed through
            unchanged so results can be matched up after unordered parallel
            grading; ``response`` is the raw model output; ``gt_answer`` is
            the reference answer.

    Returns:
        ``(index, is_correct)``. A response with no ``\\boxed{...}`` answer is
        always graded ``False``.

    Grading uses math_verify when it is importable. The normalized,
    case-insensitive string comparison is used instead when math_verify is
    unavailable, raises, or cannot parse one of the two answers (an empty
    parse result would otherwise be graded ``False`` without ever comparing
    the strings).
    """
    index, response, gt_answer = task

    # Normalize unicode characters
    if response:
        response = unicodedata.normalize('NFKC', response)

    # Extract answer from \boxed{} (supports nested LaTeX)
    model_answer = _extract_last_boxed(response)
    if not model_answer:
        return (index, False)

    if MATH_VERIFY_AVAILABLE:
        # Prefer symbolic verification; fall back to normalized string compare on any failure
        try:
            extraction_config = [ExprExtractionConfig(), LatexExtractionConfig()]
            gt_parsed = parse(f"${gt_answer}$", extraction_config=extraction_config)
            model_parsed = parse(f"${model_answer}$", extraction_config=extraction_config)
            if gt_parsed and model_parsed:
                return (index, bool(verify(gt_parsed, model_parsed)))
            fallback_reason = "math_verify could not parse an answer"
        except Exception:
            # Deliberately broad: any parser/verifier error should degrade to
            # string comparison rather than abort the grading worker.
            fallback_reason = "math_verify failed"
    else:
        fallback_reason = "math_verify not available"

    if os.environ.get("EVAL_DEBUG"):
        print(f"{fallback_reason}, falling back to normalized string compare")

    # Normalized string comparison (case-insensitive)
    is_correct = (
        _normalize_for_compare(model_answer).lower()
        == _normalize_for_compare(gt_answer).lower()
    )
    return (index, is_correct)


class BatchedProgressiveEvaluator:
    """Handles progressive evaluation of math problems in batches.

    Each round generates additional samples for every problem that has no
    correct answer so far. Problems with at least one correct answer are
    dropped from later rounds; problems still unsolved when the loop ends are
    written out as the "challenging" set.
    """

    def __init__(self,
                 model_path: str,
                 initial_samples: int = 64,
                 sample_increment: int = 64,
                 target_samples: int = 1024,
                 temperature: float = 0.6,
                 top_p: float = 0.95,
                 max_tokens: int = 8192,
                 tensor_parallel_size: int = 2,
                 n_samples_chunk: int = 8,
                 num_grading_workers: Optional[int] = None,
                 batch_size: int = 100,
                 resume_from_samples: int = 0,
                 debug: bool = False):
        """
        Initialize the batched progressive evaluator.

        Args:
            model_path: Passed to the tokenizer loader, the chat-template
                helper and the vLLM engine.
            initial_samples: Samples generated per problem in the first round
                of a fresh (non-resumed) run.
            sample_increment: Samples generated per problem in every other round.
            target_samples: Rounds continue while the running sample total is
                below this value.
            temperature: Sampling temperature.
            top_p: Nucleus sampling parameter.
            max_tokens: Used both as the per-sample generation limit and as
                the engine's ``max_model_len``.
            tensor_parallel_size: Forwarded to the vLLM engine.
            n_samples_chunk: Maximum number of samples requested per prompt in
                a single ``generate`` call.
            num_grading_workers: Grading pool size; defaults to the CPU count.
            batch_size: Number of problems to process at once (default: 100)
            resume_from_samples: Number of samples already completed (for
                resuming). Stored on the instance; the evaluation method takes
                its own argument of the same name.
            debug: Enable debug logging to diagnose issues

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        # Fail before the expensive model load: a zero batch size makes the
        # batch loop's range() raise, and a negative one yields no batches.
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        self.initial_samples = initial_samples
        self.sample_increment = sample_increment
        self.target_samples = target_samples
        self.temperature = temperature
        self.top_p = top_p
        self.max_tokens = max_tokens
        self.n_samples_chunk = n_samples_chunk
        self.num_grading_workers = num_grading_workers or multiprocessing.cpu_count()
        self.batch_size = batch_size
        self.resume_from_samples = resume_from_samples
        self.debug = debug

        # Load model and tokenizer
        print(f"Loading model from: {model_path}")
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

        # Set chat template if not already present
        set_chat_template(self.tokenizer, model_path)

        # Initialize vLLM model
        self.model = vllm.LLM(
            model=model_path,
            dtype="bfloat16",
            tensor_parallel_size=tensor_parallel_size,
            trust_remote_code=True,
            gpu_memory_utilization=0.90,
            enforce_eager=True,
            max_model_len=max_tokens
        )

    @staticmethod
    def _chunk_sizes(total_samples: int, chunk_size: int) -> List[int]:
        """Split ``total_samples`` into per-call counts of at most ``chunk_size``.

        The counts always sum to ``total_samples`` (the final chunk is
        shortened when needed). A non-positive ``chunk_size`` disables
        chunking and yields one call for everything.
        """
        if total_samples <= 0:
            return []
        if chunk_size <= 0:
            return [total_samples]
        full_chunks, remainder = divmod(total_samples, chunk_size)
        sizes = [chunk_size] * full_chunks
        if remainder:
            sizes.append(remainder)
        return sizes

    def process_batch(self, batch_data: List[dict], round_num: int, current_samples: int,
                     samples_this_round: int) -> Tuple[List[dict], List[int]]:
        """
        Process a single batch of problems.

        Generates exactly ``samples_this_round`` responses per problem (in
        calls of at most ``n_samples_chunk`` samples each), grades them in a
        process pool, and merges the grades with any ``previous_results``
        carried on the batch items.

        Args:
            batch_data: Items with at least ``index``, ``prompt``,
                ``gt_answer`` and ``problem`` keys; ``previous_results`` and
                ``original_prompt`` are optional.
            round_num: Stored in each result under ``round``.
            current_samples: Running sample total; not used here, kept for
                interface compatibility.
            samples_this_round: Number of new samples to draw per problem.

        Returns:
            Tuple of (results_list, indices_with_zero_pass)
        """
        if not batch_data:
            return [], []

        batch_prompts = [item['prompt'] for item in batch_data]
        batch_indices = [item['index'] for item in batch_data]
        idx_to_gt = {item['index']: item['gt_answer'] for item in batch_data}

        # Debug: Log first prompt to check formatting
        if self.debug and batch_prompts:
            print("\n=== DEBUG: First prompt being sent to model ===")
            print(batch_prompts[0][:500])  # First 500 chars
            print("=== END DEBUG ===\n")

        # Generate responses
        batch_responses = {idx: [] for idx in batch_indices}
        chunk_sizes = self._chunk_sizes(samples_this_round, self.n_samples_chunk)

        for chunk_idx, n_this_chunk in enumerate(chunk_sizes):
            # Built per chunk because the last chunk may request fewer samples.
            sample_params = vllm.SamplingParams(
                temperature=self.temperature,
                top_p=self.top_p,
                n=n_this_chunk,
                max_tokens=self.max_tokens,
            )
            vllm_outputs = self.model.generate(batch_prompts, sample_params)

            for idx, output in zip(batch_indices, vllm_outputs):
                texts = [o.text.strip() for o in output.outputs]
                batch_responses[idx].extend(texts)

                # Debug: Log first response from first batch
                if self.debug and chunk_idx == 0 and idx == batch_indices[0]:
                    print(f"\n=== DEBUG: Sample response for problem {idx} ===")
                    if texts:
                        print(texts[0][:1000])  # First 1000 chars of first response
                    else:
                        print("No response generated!")
                    print("=== END DEBUG ===\n")

        # Grade responses
        tasks = [
            (idx, resp, idx_to_gt[idx])
            for idx in batch_indices
            for resp in batch_responses[idx]
        ]

        # Parallel grading; results arrive in arbitrary order and are
        # regrouped by problem index.
        batch_results = {idx: [] for idx in batch_indices}
        if tasks:
            with multiprocessing.Pool(processes=self.num_grading_workers) as pool:
                for idx, is_correct in pool.imap_unordered(grade_single_response, tasks):
                    batch_results[idx].append(is_correct)

        # Prepare results and identify zero-pass problems
        results_list = []
        zero_pass_indices = []

        for item in batch_data:
            idx = item['index']
            # Combine with previous results if this is not the first round
            all_results = item.get('previous_results', []) + batch_results[idx]
            has_correct = any(all_results)

            result = {
                'index': idx,
                'problem': item['problem'],
                'original_prompt': item.get('original_prompt', None),
                'prompt': item.get('prompt', None),
                'gt_answer': item['gt_answer'],
                'n_samples': len(all_results),
                'exact_match': [bool(b) for b in all_results],
                'has_correct': bool(has_correct),
                'round': round_num
            }
            if not has_correct:
                # Include up to two responses from this round to aid debugging
                result['responses_sample'] = batch_responses.get(idx, [])[:2]
                zero_pass_indices.append(idx)
            results_list.append(result)

        return results_list, zero_pass_indices

    def evaluate_progressive_batched(self,
                                   dataset: datasets.Dataset,
                                   output_dir: str,
                                   job_id: int = 0,
                                   save_intermediate: bool = True,
                                   resume_from_samples: int = 0) -> None:
        """
        Perform progressive evaluation on the dataset using batched processing.

        This method processes the dataset in batches and saves results incrementally
        to handle large datasets without OOM issues.

        Args:
            dataset: Rows are read via ``index``, ``prompt[0]['content']``,
                ``reward_model['ground_truth']`` and, optionally,
                ``extra_info['original_problem']``.
            output_dir: Created if missing; receives the intermediate JSONL
                files and the final parquet file.
            job_id: Used only in output file names.
            save_intermediate: When true, the still-unsolved results of every
                batch are written to a per-round, per-batch JSONL file.
            resume_from_samples: Starting value of the running sample total.
                When positive, every round uses ``sample_increment``.

        Output:
            ``challenging_problems_job_{job_id}.parquet`` in ``output_dir``,
            holding the dataset rows (minus the ``index`` column) of problems
            that were still unsolved when the loop ended.
        """
        os.makedirs(output_dir, exist_ok=True)

        # Prepare initial problem data
        print("Preparing dataset...")
        problem_data = []

        for example in tqdm(dataset, desc="Processing dataset"):
            idx = example["index"]

            # Data is now in openai_math format
            prompt_content = example['prompt'][0]['content']
            # str(None) would be the truthy string "None", so map a missing
            # ground truth to "" explicitly for the skip check below.
            raw_gt = example['reward_model']['ground_truth']
            gt_answer = "" if raw_gt is None else str(raw_gt)

            # The original problem text can be retrieved from extra_info if needed
            # (extra_info may be absent or None).
            problem_text = (example.get('extra_info') or {}).get('original_problem', '')

            if not gt_answer:
                print(f"Warning: Skipping example {idx} - no ground truth answer found")
                continue

            # Create prompt using chat template
            prompt = self.tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt_content}],
                tokenize=False,
                add_generation_prompt=True
            )

            problem_data.append({
                'index': idx,
                'prompt': prompt,
                'gt_answer': gt_answer,
                'problem': problem_text,
                'original_prompt': prompt_content
            })

        print(f"Prepared {len(problem_data)} problems for evaluation")

        # Progressive evaluation with batching
        active_data = list(problem_data)
        current_samples = resume_from_samples
        round_num = 0

        # If resuming, calculate the starting round number
        if resume_from_samples > 0:
            # This is an approximation, but good for logging. Handle zero/negative increments safely.
            if self.sample_increment and self.sample_increment > 0:
                round_num = (resume_from_samples - self.initial_samples + self.sample_increment) // self.sample_increment
            else:
                round_num = 0
                print("sample_increment <= 0; treating as single-round run for resume logging.")
            print(f"Resuming evaluation from {resume_from_samples} samples (approx. round {round_num})")

        while current_samples < self.target_samples and active_data:
            round_start_time = time.time()
            # Determine samples for this round
            if round_num == 0 and resume_from_samples == 0:
                samples_this_round = self.initial_samples
            else:
                samples_this_round = self.sample_increment

            # Handle zero or negative increments to avoid infinite loops
            if samples_this_round <= 0:
                print(f"Sample increment is {samples_this_round}. Stopping progression at {current_samples} samples.")
                break

            current_samples += samples_this_round

            print(f"\n=== Round {round_num + 1}: Generating {samples_this_round} samples " +
                  f"(total: {current_samples}) for {len(active_data)} problems ===")

            # Process in batches. Only the unsolved items are carried forward;
            # full per-round results are not accumulated in memory.
            new_active_data = []

            for batch_start in tqdm(range(0, len(active_data), self.batch_size),
                                  desc="Processing batches"):
                batch = active_data[batch_start:batch_start + self.batch_size]

                # Process batch
                batch_results, zero_pass_indices = self.process_batch(
                    batch, round_num, current_samples, samples_this_round
                )

                # Update active data for next round. Set/dict lookups keep
                # this linear in the batch size.
                unsolved = set(zero_pass_indices)
                results_by_index = {r['index']: r for r in batch_results}
                for item in batch:
                    if item['index'] in unsolved:
                        # Carry the accumulated grades into the next round
                        item['previous_results'] = results_by_index[item['index']]['exact_match']
                        new_active_data.append(item)

                # Save intermediate results for this batch
                if save_intermediate:
                    intermediate_file = os.path.join(
                        output_dir,
                        f'intermediate_round_{round_num}_job_{job_id}_batch_{batch_start}.jsonl'
                    )
                    with open(intermediate_file, 'w', encoding='utf-8') as f:
                        for result in batch_results:
                            if not result['has_correct']:
                                f.write(json.dumps(result) + '\n')

                # Clear memory
                del batch_results, results_by_index
                gc.collect()

            round_duration = time.time() - round_start_time
            print(f"Round {round_num + 1} completed in {round_duration:.2f} seconds.")
            print(f"Problems with 0 pass@{current_samples}: {len(new_active_data)}/{len(active_data)}")

            active_data = new_active_data
            round_num += 1

        final_challenging_indices = {item['index'] for item in active_data}

        # Save final challenging problems to a parquet file
        print(f"\nSaving {len(final_challenging_indices)} challenging problems to parquet file...")

        # Filter the original job's dataset to get only the challenging problems
        challenging_dataset = dataset.filter(lambda example: example['index'] in final_challenging_indices)

        # The user wants all old columns. The 'index' column was added temporarily for processing.
        # Let's remove it before saving. The true original index is in extra_info['index'].
        if 'index' in challenging_dataset.column_names:
            challenging_dataset = challenging_dataset.remove_columns(['index'])

        challenging_file = os.path.join(output_dir, f'challenging_problems_job_{job_id}.parquet')

        challenging_dataset.to_parquet(challenging_file)

        print(f"Evaluation complete! Found {len(challenging_dataset)} challenging problems.")
        print(f"Results saved to: {challenging_file}")


def main():
    """Command-line entry point.

    Two modes:

    * ``--aggregate_only``: concatenate every
      ``challenging_problems_job_*.parquet`` under ``<output_dir>/<exp_name>``
      into ``challenging_problems_all.parquet`` and write ``summary.json``.
      No model is loaded.
    * default: load the parquet dataset, add a row-number ``index`` column,
      apply ``--limit``, select this job's contiguous shard when
      ``--n_jobs > 1``, then run the batched progressive evaluation.
    """
    # Imported here (like argparse) so the function also works when this
    # module is imported and main() is called directly, not only when the
    # file is executed as a script.
    import argparse
    import glob

    parser = argparse.ArgumentParser(description="Batched progressive evaluation for large datasets")

    # Model and generation parameters
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--temperature", type=float, default=0.6)
    parser.add_argument("--top_p", type=float, default=0.95)
    parser.add_argument("--max_tokens", type=int, default=8192)
    parser.add_argument("--tensor_parallel_size", type=int, default=2)

    # Progressive evaluation parameters
    parser.add_argument("--initial_samples", type=int, default=64)
    parser.add_argument("--sample_increment", type=int, default=64)
    parser.add_argument("--target_samples", type=int, default=1024)
    parser.add_argument("--n_samples_chunk", type=int, default=8)

    # Batching parameters
    parser.add_argument("--batch_size", type=int, default=100,
                       help="Number of problems to process at once")

    # Dataset parameters
    parser.add_argument("--dataset_path", type=str, required=True)
    parser.add_argument("--limit", type=int, default=None)

    # Output parameters
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--exp_name", type=str, required=True)

    # Distributed evaluation
    parser.add_argument("--n_jobs", type=int, default=1)
    parser.add_argument("--job_id", type=int, default=0)

    # System parameters
    parser.add_argument("--num_grading_workers", type=int, default=None)
    parser.add_argument("--save_intermediate", action="store_true")
    parser.add_argument("--resume_from_samples", type=int, default=0,
                        help="Number of samples already completed (for resuming).")
    parser.add_argument("--debug", action="store_true",
                        help="Enable debug logging to diagnose issues")

    # Aggregation mode
    parser.add_argument("--aggregate_only", action="store_true")

    args = parser.parse_args()

    # Create output directory
    output_path = os.path.join(args.output_dir, args.exp_name)
    os.makedirs(output_path, exist_ok=True)

    if args.aggregate_only:
        # Aggregation logic for parquet files
        print("=== AGGREGATION MODE ===")

        pattern = os.path.join(output_path, 'challenging_problems_job_*.parquet')
        # Sorted so the row order of the consolidated file is reproducible.
        files = sorted(glob.glob(pattern))

        if not files:
            print(f"No files found matching pattern: {pattern}")
            return

        print(f"Found {len(files)} files to aggregate")

        # Load all parquet files into a single datasets.Dataset object
        all_datasets = [datasets.load_dataset('parquet', data_files=f)['train'] for f in files]
        consolidated_dataset = datasets.concatenate_datasets(all_datasets)

        # Save consolidated results
        consolidated_path = os.path.join(output_path, 'challenging_problems_all.parquet')
        consolidated_dataset.to_parquet(consolidated_path)

        print(f"Found {len(consolidated_dataset)} challenging problems total")
        print(f"Results saved to: {consolidated_path}")

        # Save summary using the original index from extra_info
        summary = {
            'total_challenging_problems': len(consolidated_dataset),
            'problems_by_dataset_index': [
                row['extra_info']['index'] for row in consolidated_dataset
            ]
        }

        summary_path = os.path.join(output_path, 'summary.json')
        with open(summary_path, 'w') as f:
            json.dump(summary, f, indent=2)

        return

    # Load dataset
    print(f"Loading dataset from: {args.dataset_path}")
    dataset = datasets.load_dataset('parquet', data_files={'test': args.dataset_path})['test']

    # We will use the row number as the 'index' for this job, but the original
    # index from the full dataset is preserved in 'extra_info'.
    dataset = dataset.add_column('index', list(range(len(dataset))))

    if args.limit:
        dataset = dataset.select(range(min(args.limit, len(dataset))))

    # Handle distributed evaluation
    if args.n_jobs > 1:
        total = len(dataset)
        n = args.n_jobs
        j = args.job_id
        # A negative job id would produce negative start indices below and
        # select the wrong rows instead of failing.
        if j < 0:
            parser.error(f"--job_id must be non-negative, got {j}")
        base = total // n
        remainder = total % n
        # Distribute the remainder: first `remainder` jobs get one extra item
        start_idx = j * base + min(j, remainder)
        count = base + (1 if j < remainder else 0)
        end_idx = min(start_idx + count, total)
        if start_idx >= total or count == 0:
            print(f"Job {j}/{n}: No data assigned (dataset size {total}).")
            return
        dataset = dataset.select(range(start_idx, end_idx))
        print(f"Job {j}/{n}: Evaluating indices {start_idx} to {end_idx-1} ({end_idx - start_idx} items)")

    # Initialize evaluator
    evaluator = BatchedProgressiveEvaluator(
        model_path=args.model_path,
        initial_samples=args.initial_samples,
        sample_increment=args.sample_increment,
        target_samples=args.target_samples,
        temperature=args.temperature,
        top_p=args.top_p,
        max_tokens=args.max_tokens,
        tensor_parallel_size=args.tensor_parallel_size,
        n_samples_chunk=args.n_samples_chunk,
        num_grading_workers=args.num_grading_workers,
        batch_size=args.batch_size,
        resume_from_samples=args.resume_from_samples,
        debug=args.debug
    )

    # Run evaluation
    evaluator.evaluate_progressive_batched(
        dataset=dataset,
        output_dir=output_path,
        job_id=args.job_id,
        save_intermediate=args.save_intermediate,
        resume_from_samples=args.resume_from_samples
    )


# Script entry point: runs only when the file is executed directly.
if __name__ == "__main__":
    # Importing glob here binds it as a module-level name for direct runs.
    import glob
    main() 