#!/usr/bin/env python3
"""
eval_kodcode_reference_solutions.py

Evaluates all ground truth/reference solutions in the KodCode dataset
by running them against their test suites.

This helps verify dataset quality and identify any problematic tasks
where the reference solution doesn't pass all tests.
"""

from __future__ import annotations

import argparse
import json
import multiprocessing
import os
import time
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Any, Dict, List, Optional, Tuple

from tqdm import tqdm


def extract_code_from_markdown(code_str: str) -> str:
    """Return the code inside the last fenced Markdown block of ``code_str``.

    The opening fence may carry any info string (``python``, ``python3``,
    ``c++``, a tag followed by trailing spaces, or nothing at all); it only has
    to be followed by a line break. When no fenced block is present, the whole
    input is returned with surrounding whitespace stripped.

    Args:
        code_str: Raw solution text, possibly wrapped in Markdown fences.

    Returns:
        The extracted code, stripped; ``""`` for empty or ``None`` input.
    """
    import re

    if not code_str:
        return ""

    # `[^`\n]*` accepts any info string on the opening fence (a bare `\w+`
    # would reject tags like "c++", "python " with trailing spaces, or a "\r"
    # from CRLF line endings) while refusing to run across inline backticks on
    # the same line. DOTALL lets the body span multiple lines.
    code_blocks = re.findall(r"```[^`\n]*\n(.*?)```", code_str, re.DOTALL)
    if code_blocks:
        # When several blocks are present, the last one is taken as the solution.
        return code_blocks[-1].strip()

    return code_str.strip()


def run_single_task(task: Dict[str, Any], timeout_per_test: int = 5) -> Dict[str, Any]:
    """
    Run a single task's reference solution against its tests.
    
    Returns a dict with:
        - task_id: The task identifier
        - passed: Whether all tests passed
        - error: Error message if any
        - output: Test output
        - subset: The subset category (Algorithm, Data Structures, etc.)
    """
    from rllm.rewards.code_reward import kodcode_check_correctness
    
    task_id = task.get("uid", task.get("task_id", task.get("index", "unknown")))
    subset = task.get("subset", task.get("metadata", {}).get("subset", "unknown"))
    
    reference_solution = task.get("reference_solution", "")
    ground_truth = task.get("ground_truth", "")
    
    if not reference_solution:
        return {
            "task_id": task_id,
            "passed": False,
            "error": "No reference solution found",
            "output": "",
            "subset": subset,
        }
    
    if not ground_truth:
        return {
            "task_id": task_id,
            "passed": False,
            "error": "No tests found",
            "output": "",
            "subset": subset,
        }
    
    # Extract code from markdown if needed
    code = extract_code_from_markdown(reference_solution)
    
    if not code.strip():
        return {
            "task_id": task_id,
            "passed": False,
            "error": "Empty code after extraction",
            "output": "",
            "subset": subset,
        }
    
    try:
        passed, details = kodcode_check_correctness(ground_truth, code, timeout_per_test=timeout_per_test)
        return {
            "task_id": task_id,
            "passed": passed,
            "error": None if passed else details.get("output", "Test failed"),
            "output": details.get("output", ""),
            "subset": subset,
            "total_tests": details.get("total_tests", 0),
        }
    except Exception as e:
        return {
            "task_id": task_id,
            "passed": False,
            "error": f"Exception: {str(e)}",
            "output": "",
            "subset": subset,
        }


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description="Evaluate all ground truth solutions in KodCode dataset"
    )
    parser.add_argument(
        "--split",
        type=str,
        default="train",
        help="Dataset split to evaluate (default: train)",
    )
    parser.add_argument(
        "--n_tasks",
        type=int,
        default=-1,
        help="Number of tasks to evaluate (-1 for all, default: -1)",
    )
    parser.add_argument(
        "--n_workers",
        type=int,
        default=32,
        help="Number of parallel workers (default: 32)",
    )
    parser.add_argument(
        "--timeout",
        type=int,
        default=5,
        help="Timeout per test in seconds (default: 5)",
    )
    parser.add_argument(
        "--save_failures",
        type=str,
        default=None,
        help="Path to save failed tasks as JSON (optional)",
    )
    parser.add_argument(
        "--save_results",
        type=str,
        default=None,
        help="Path to save all results as JSON (optional)",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Print details of failed tasks",
    )
    parser.add_argument(
        "--subset",
        type=str,
        default=None,
        help="Filter to specific subset (e.g., 'Algorithm', 'Data Structures')",
    )
    return parser


def _task_subset(task: Dict[str, Any], default: Any) -> Any:
    """Return ``task["subset"]``, else ``task["metadata"]["subset"]``, else ``default``.

    A missing, None or non-dict ``metadata`` field is treated as empty.
    """
    metadata = task.get("metadata")
    if not isinstance(metadata, dict):
        metadata = {}
    return task.get("subset", metadata.get("subset", default))


def _load_tasks(split: str) -> List[Dict[str, Any]]:
    """Load the KodCode ``split`` from the dataset registry.

    If the registry load fails for any reason, the dataset is prepared from the
    ``KodCode/KodCode-V1`` HuggingFace repo and the load is retried once.
    """
    from rllm.data.dataset import DatasetRegistry

    try:
        dataset = DatasetRegistry.load_dataset("kodcode", split)
        return list(dataset.get_data())
    except Exception as e:
        print(f"Failed to load from registry: {e}")
        print("Trying to prepare dataset from HuggingFace...")

        from examples.bugs.data_processing.prepare_kodcode_data import prepare_kodcode_data

        prepare_kodcode_data(repo_id="KodCode/KodCode-V1", registry_name="kodcode")

        dataset = DatasetRegistry.load_dataset("kodcode", split)
        return list(dataset.get_data())


def _evaluate_tasks(
    tasks: List[Dict[str, Any]], n_workers: int, timeout: int
) -> List[Dict[str, Any]]:
    """Evaluate ``tasks`` in a process pool.

    Returns one result per task in the same order as ``tasks``, regardless of
    completion order, so saved JSON is deterministic.
    """
    results: List[Optional[Dict[str, Any]]] = [None] * len(tasks)

    with ProcessPoolExecutor(max_workers=n_workers) as executor:
        futures = {
            executor.submit(run_single_task, task, timeout): i
            for i, task in enumerate(tasks)
        }

        for future in tqdm(as_completed(futures), total=len(futures), desc="Evaluating"):
            idx = futures[future]
            try:
                # as_completed() only yields finished futures, so result() returns
                # immediately; a timeout argument here could never fire. The
                # per-test timeout is forwarded to the checker via run_single_task.
                # NOTE(review): a worker that hangs despite that timeout still
                # stalls this loop, since running futures cannot be cancelled.
                results[idx] = future.result()
            except Exception as e:
                task = tasks[idx]
                # Resolve id/subset the same way run_single_task does.
                results[idx] = {
                    "task_id": task.get("uid", task.get("task_id", task.get("index", str(idx)))),
                    "passed": False,
                    "error": f"Worker exception: {str(e)}",
                    "output": "",
                    "subset": _task_subset(task, "unknown"),
                }

    # as_completed yields every submitted future exactly once, so every slot is filled.
    return [r for r in results if r is not None]


def _print_summary(results: List[Dict[str, Any]], elapsed: float) -> None:
    """Print overall and per-subset pass counts and rates for ``results``."""
    total = len(results)
    passed = sum(1 for r in results if r["passed"])
    failed = total - passed

    subset_stats: Dict[str, Dict[str, int]] = defaultdict(lambda: {"total": 0, "passed": 0})
    for r in results:
        # str() so a None or non-string subset cannot break sorting or ":20s" formatting.
        stats = subset_stats[str(r.get("subset", "unknown"))]
        stats["total"] += 1
        if r["passed"]:
            stats["passed"] += 1

    denom = max(1, total)  # guards the percentages against an empty result list
    print("\n" + "=" * 80)
    print("📊 EVALUATION RESULTS")
    print("=" * 80)
    print(f"\nTotal tasks evaluated: {total}")
    print(f"Time elapsed: {elapsed:.1f}s ({elapsed/denom:.2f}s per task)")
    print("\nOverall Results:")
    print(f"  ✅ Passed: {passed}/{total} ({100*passed/denom:.1f}%)")
    print(f"  ❌ Failed: {failed}/{total} ({100*failed/denom:.1f}%)")

    if subset_stats:
        print("\nResults by Subset:")
        for subset in sorted(subset_stats):
            stats = subset_stats[subset]
            pct = 100 * stats["passed"] / stats["total"] if stats["total"] > 0 else 0
            print(f"  {subset:20s}: {stats['passed']:4d}/{stats['total']:4d} ({pct:5.1f}%)")


def _print_failures(
    failed_results: List[Dict[str, Any]], limit: int = 20, max_error_chars: int = 500
) -> None:
    """Print the first ``limit`` failures, truncating each error to ``max_error_chars``."""
    print("\n" + "=" * 80)
    print(f"❌ FAILED TASKS ({len(failed_results)} total)")
    print("=" * 80)
    for i, r in enumerate(failed_results[:limit]):
        print(f"\n--- Task {i+1}: {r['task_id']} (subset: {r['subset']}) ---")
        # `or` also covers an explicit None error, which len() could not handle.
        error = str(r.get("error") or "Unknown error")
        if len(error) > max_error_chars:
            error = error[:max_error_chars] + "..."
        print(f"Error: {error}")

    if len(failed_results) > limit:
        print(f"\n... and {len(failed_results) - limit} more failed tasks")


def _save_json(path: str, payload: Any) -> None:
    """Write ``payload`` to ``path`` as indented JSON, creating parent directories."""
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)


def main() -> Optional[int]:
    """Evaluate KodCode reference solutions from the command line.

    Returns:
        0 when every evaluated task passed, 1 when any failed, and None when
        no tasks were selected (the caller treats None as a 0 exit code).
    """
    args = _build_arg_parser().parse_args()

    # Disable tokenizer parallelism warning
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    print("=" * 80)
    print("🧪 KODCODE REFERENCE SOLUTION EVALUATOR")
    print("=" * 80)

    print(f"\nLoading KodCode dataset (split: {args.split})...")
    tasks = _load_tasks(args.split)
    print(f"  Loaded {len(tasks)} tasks")

    # Filter by subset if specified
    if args.subset:
        original_count = len(tasks)
        tasks = [t for t in tasks if _task_subset(t, "") == args.subset]
        print(f"  Filtered to subset '{args.subset}': {len(tasks)}/{original_count} tasks")

    # Limit number of tasks if specified
    if args.n_tasks > 0:
        tasks = tasks[:args.n_tasks]
        print(f"  Limited to first {len(tasks)} tasks")

    if not tasks:
        print("No tasks to evaluate!")
        return None

    print(f"\n🚀 Running evaluation with {args.n_workers} workers...")
    print(f"   Timeout per test: {args.timeout}s")

    # perf_counter is monotonic, unlike wall-clock time.time().
    start_time = time.perf_counter()
    results = _evaluate_tasks(tasks, args.n_workers, args.timeout)
    elapsed = time.perf_counter() - start_time

    _print_summary(results, elapsed)

    failed_results = [r for r in results if not r["passed"]]
    if args.verbose and failed_results:
        _print_failures(failed_results)

    # Save failures if requested (just task IDs). The file is written even when
    # nothing failed, so a stale list from an earlier run is never left behind.
    if args.save_failures:
        failed_task_ids = [r["task_id"] for r in failed_results]
        _save_json(args.save_failures, failed_task_ids)
        print(f"\n💾 Failed task IDs saved to: {args.save_failures} ({len(failed_task_ids)} tasks)")

    # Save all results if requested
    if args.save_results:
        _save_json(args.save_results, results)
        print(f"💾 All results saved to: {args.save_results}")

    print("\n✅ Done!")

    # Return exit code based on pass rate
    return 1 if failed_results else 0


if __name__ == "__main__":
    # Raise SystemExit directly rather than calling the site-provided exit()
    # helper, which is intended for the interactive interpreter and is not
    # available under `python -S`. main() returns None when there is nothing to
    # evaluate, which maps to exit code 0.
    raise SystemExit(main() or 0)
