"""LLM-as-judge for evaluating bug similarity between generated and target bugs.

This module provides:
- BugSimilarityJudge: Scores how similar bug patterns are between two buggy code samples
- Test runner to evaluate similarity across different datasets and scenarios
"""

from __future__ import annotations

import asyncio
import json
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

from examples.bugs.prompts import _build_bug_similarity_judge_prompt


@dataclass
class BugSimilarityJudgeConfig:
    """Settings for the LLM-based bug similarity judge.

    Attributes:
        enabled: On/off flag for the judge, off by default. It is not read
            anywhere in this module; presumably callers consult it.
        reward_weight: Weight for combining the judge score with the main
            generator reward.
        system_prompt: Optional system message sent ahead of the judge
            prompt. No system message is sent when it is unset or empty.
    """

    enabled: bool = False
    reward_weight: float = 0.5
    system_prompt: Optional[str] = None


class BugSimilarityJudge:
    """LLM-as-judge for scoring similarity between generated and target bugs.

    The judge builds a comparison prompt, sends it to the rollout engine and
    parses the JSON verdict found in the reply into a score in [0, 1].
    """

    def __init__(self, rollout_engine, config: Optional[BugSimilarityJudgeConfig] = None):
        """Initialize the judge.

        Args:
            rollout_engine: Engine for getting LLM responses (e.g., OpenAIEngine).
                It must expose an awaitable ``get_model_response(messages)``
                whose result has a ``content`` attribute.
            config: Optional configuration for the judge; a default
                configuration is used when omitted.
        """
        self.rollout_engine = rollout_engine
        self.config = config or BugSimilarityJudgeConfig()

    async def score_similarity(
        self,
        generated_problem: str,
        generated_bug: str,
        target_problem: str,
        target_bug: str,
        uid: str,
        generated_ground_truth: str = "",
        target_ground_truth: str = "",
    ) -> Tuple[float, Dict[str, Any]]:
        """Score how similar the generated bug pattern is to the target bug pattern.

        Args:
            generated_problem: Problem description for the generated bug
            generated_bug: The machine-generated buggy code
            target_problem: Problem description for the target/reference bug
            target_bug: The human-written buggy code to compare against
            uid: Unique identifier for this comparison. Accepted for interface
                compatibility; it is not used when scoring.
            generated_ground_truth: Correct solution for the generated bug's problem
            target_ground_truth: Correct solution for the target bug's problem

        Returns:
            Tuple of (normalized_score in [0,1], metadata dict). On success the
            metadata holds ``raw_score``, ``normalized_score``, ``reasoning``,
            ``edit_size_comparison`` and ``raw_response``. If the engine call
            or the parsing fails, the score is 0.0 and the metadata holds
            ``error`` and ``raw_response`` instead.
        """
        prompt = _build_bug_similarity_judge_prompt(
            generated_problem=generated_problem,
            generated_bug=generated_bug,
            target_problem=target_problem,
            target_bug=target_bug,
            generated_ground_truth=generated_ground_truth,
            target_ground_truth=target_ground_truth,
        )

        messages: List[Dict[str, str]] = []
        if self.config.system_prompt:
            messages.append({"role": "system", "content": self.config.system_prompt})
        messages.append({"role": "user", "content": prompt})

        # Initialised up front so the error path can always report what was
        # received, even when the engine call itself fails.
        raw_response: Any = ""
        try:
            model_output = await self.rollout_engine.get_model_response(messages)
            raw_response = model_output.content
            return self._parse_response(raw_response)
        except Exception as e:
            # Deliberate best-effort: a failed judge call must not abort the
            # caller, so every failure is reported as a zero score plus the
            # error text.
            return 0.0, {
                "error": str(e),
                "raw_response": raw_response,
            }

    @staticmethod
    def _extract_json_object(text: str) -> Dict[str, Any]:
        """Return the judge's JSON verdict embedded in ``text``.

        Every ``{`` is tried as the start of a JSON document, so the verdict is
        found even when it is wrapped in prose or a markdown code fence, or
        when its string values themselves contain braces. The first object
        that has a ``"score"`` key wins; if none has one, the first JSON
        object found is returned.

        Raises:
            ValueError: If ``text`` contains no JSON object at all.
        """
        decoder = json.JSONDecoder()
        fallback: Optional[Dict[str, Any]] = None
        start = text.find("{")
        while start != -1:
            try:
                candidate, _ = decoder.raw_decode(text, start)
            except ValueError:  # json.JSONDecodeError is a ValueError
                candidate = None
            if isinstance(candidate, dict):
                if "score" in candidate:
                    return candidate
                if fallback is None:
                    fallback = candidate
            start = text.find("{", start + 1)
        if fallback is None:
            raise ValueError("No JSON object found in judge response")
        return fallback

    @classmethod
    def _parse_response(cls, raw_response: Any) -> Tuple[float, Dict[str, Any]]:
        """Turn the raw model reply into ``(normalized_score, metadata)``.

        The raw score is divided by 10 and clamped to [0, 1].

        Raises:
            ValueError: If the reply is not text, holds no JSON object, or the
                score is not a finite number.
            TypeError: If the score value cannot be converted to ``float``.
        """
        import math

        if not isinstance(raw_response, str):
            raise ValueError("Judge response has no text content")

        result = cls._extract_json_object(raw_response.strip())
        raw_score = float(result.get("score", 0))
        # json accepts NaN/Infinity; without this check NaN would clamp to 1.0
        # because min(1.0, nan) returns 1.0.
        if not math.isfinite(raw_score):
            raise ValueError(f"Judge returned a non-finite score: {raw_score}")

        # Normalize to [0, 1]
        normalized_score = max(0.0, min(1.0, raw_score / 10.0))

        return normalized_score, {
            "raw_score": raw_score,
            "normalized_score": normalized_score,
            "reasoning": str(result.get("reasoning", "")),
            "edit_size_comparison": str(result.get("edit_size_comparison", "")),
            "raw_response": raw_response,
        }


def _get_problem(task: Dict[str, Any]) -> str:
    """Return the first usable problem statement stored in ``task``.

    Candidate keys are checked in a fixed priority order. A value qualifies
    only if it is a string with non-whitespace content; it is returned
    unmodified. An empty string is returned when no key qualifies.
    """
    candidate_keys = (
        "question",
        "instruct_prompt",
        "complete_prompt",
        "prompt",
        "text",
        "problem",
        "description",
        "code_prompt",
    )
    values = (task.get(name) for name in candidate_keys)
    return next((text for text in values if isinstance(text, str) and text.strip()), "")


def _get_buggy_solution(task: Dict[str, Any]) -> Optional[str]:
    """Return the buggy code stored in ``task``, or ``None`` if there is none.

    Candidate keys are checked in a fixed priority order. A value qualifies
    only if it is a string with non-whitespace content; it is returned
    unmodified.
    """
    candidate_keys = (
        "buggy_solution",
        "buggy_sampled_solution",
        "buggy",
        "buggy_code",
        "bug",
    )
    values = (task.get(name) for name in candidate_keys)
    return next((code for code in values if isinstance(code, str) and code.strip()), None)


def _get_reference_solution(task: Dict[str, Any]) -> str:
    """Return the reference (correct) solution stored in ``task``.

    Candidate keys are checked in a fixed priority order. A value qualifies
    only if it is a string with non-whitespace content; it is returned
    unmodified. An empty string is returned when no key qualifies.
    """
    candidate_keys = (
        "reference_solution",
        "canonical_solution",
        "solution",
        "code",
        "correct_code",
    )
    values = (task.get(name) for name in candidate_keys)
    return next((code for code in values if isinstance(code, str) and code.strip()), "")


async def run_similarity_comparison(
    judge: BugSimilarityJudge,
    bug1_task: Dict[str, Any],
    bug2_task: Dict[str, Any],
    bug1_label: str,
    bug2_label: str,
) -> Dict[str, Any]:
    """Run a single similarity comparison between two bugs.

    Args:
        judge: Judge used to score the pair.
        bug1_task: Task dict holding the first bug (passed to the judge as the
            "generated" side).
        bug2_task: Task dict holding the second bug (passed to the judge as
            the "target" side).
        bug1_label: Display label for the first bug.
        bug2_label: Display label for the second bug.

    Returns:
        A result dict that always contains ``bug1_label`` and ``bug2_label``.

        * If either task has no buggy solution: ``error``, ``bug1_present``
          and ``bug2_present``; the judge is not called.
        * If the judge reports a failure: ``score``, ``error``,
          ``raw_response`` and the problem previews. ``raw_score`` is left
          out on purpose so that failed calls are not averaged in as genuine
          zero scores.
        * Otherwise: ``score``, ``raw_score``, ``reasoning``,
          ``edit_size_comparison`` and the problem previews.
    """
    problem1 = _get_problem(bug1_task)
    bug1 = _get_buggy_solution(bug1_task)
    gt1 = _get_reference_solution(bug1_task)
    problem2 = _get_problem(bug2_task)
    bug2 = _get_buggy_solution(bug2_task)
    gt2 = _get_reference_solution(bug2_task)

    if not bug1 or not bug2:
        return {
            "bug1_label": bug1_label,
            "bug2_label": bug2_label,
            "error": "Missing buggy solution",
            "bug1_present": bool(bug1),
            "bug2_present": bool(bug2),
        }

    uid = f"{bug1_label}_vs_{bug2_label}"
    score, meta = await judge.score_similarity(
        generated_problem=problem1,
        generated_bug=bug1,
        target_problem=problem2,
        target_bug=bug2,
        uid=uid,
        generated_ground_truth=gt1,
        target_ground_truth=gt2,
    )

    result: Dict[str, Any] = {
        "bug1_label": bug1_label,
        "bug2_label": bug2_label,
        "score": score,
    }
    if "error" in meta:
        # Surface the failure instead of disguising it as a raw score of 0.
        result["error"] = meta["error"]
        result["raw_response"] = str(meta.get("raw_response") or "")
    else:
        result["raw_score"] = meta.get("raw_score", 0)
        result["reasoning"] = meta.get("reasoning", "")
        result["edit_size_comparison"] = meta.get("edit_size_comparison", "")

    # Previews are capped at 200 characters, with an ellipsis when truncated.
    result["problem1_preview"] = (problem1[:200] + "...") if len(problem1) > 200 else problem1
    result["problem2_preview"] = (problem2[:200] + "...") if len(problem2) > 200 else problem2
    return result


def _task_identifier(task: Dict[str, Any]) -> Optional[str]:
    """Return the task's identifier as a string, or ``None`` if it has none.

    ``task_id``, ``uid`` and ``index`` are checked in that order. Only ``None``
    and the empty string count as missing, so a numeric id of 0 is kept.
    """
    for key in ("task_id", "uid", "index"):
        value = task.get(key)
        if value is not None and value != "":
            return str(value)
    return None


def _record_comparison(
    result: Dict[str, Any],
    bucket: List[Dict[str, Any]],
    scenario: str,
    label1: str,
    label2: str,
    indent: str,
    extra: Optional[Dict[str, Any]] = None,
) -> None:
    """Tag a comparison result, store it in ``bucket`` and print a short report.

    The labels are passed in explicitly rather than read from ``result`` so
    that error results lacking label keys can still be reported.
    """
    result["scenario"] = scenario
    if extra:
        result.update(extra)
    bucket.append(result)

    detail_indent = indent + "  "
    print(f"{indent}{label1} vs {label2}: score={result.get('raw_score', 'N/A')}")
    if result.get("error"):
        print(f"{detail_indent}Error: {result['error']}")
    if result.get("edit_size_comparison"):
        print(f"{detail_indent}Edit size: {result['edit_size_comparison'][:100]}...")
    if result.get("reasoning"):
        print(f"{detail_indent}Reasoning: {result['reasoning'][:100]}...")


def _print_scenario_average(number: int, scenario_results: List[Dict[str, Any]]) -> None:
    """Print the mean raw score of a scenario, ignoring results without one."""
    scores = [r["raw_score"] for r in scenario_results if "raw_score" in r]
    if scores:
        print(f"\n  📊 Scenario {number} Average: {sum(scores)/len(scores):.2f} / 10")


async def run_scenario_tests(
    judge: BugSimilarityJudge,
    datasets: Dict[str, List[Dict[str, Any]]],
    n_samples: int = 3,
) -> Dict[str, List[Dict[str, Any]]]:
    """Run bug similarity tests across different scenarios.

    Three scenarios are sampled at random and reported on stdout:

    1. Same dataset, different problems (per dataset with at least 2 tasks).
    2. Same problem (matched by task id) across each pair of datasets.
    3. Different problems drawn from two different datasets.

    Args:
        judge: The BugSimilarityJudge instance
        datasets: Dict mapping dataset names to lists of tasks with bugs
        n_samples: Number of samples per scenario

    Returns:
        Dict with results for each scenario
    """
    import random
    from itertools import combinations

    results: Dict[str, List[Dict[str, Any]]] = {
        "same_dataset_different_problems": [],
        "same_problem_different_datasets": [],
        "different_problems_different_datasets": [],
    }

    dataset_names = list(datasets.keys())

    # Scenario 1: Same dataset, different problems
    scenario = "same_dataset_different_problems"
    print("\n" + "=" * 80)
    print("📊 SCENARIO 1: Same dataset, different problems")
    print("=" * 80)
    for ds_name, tasks in datasets.items():
        if len(tasks) < 2:
            continue
        print(f"\n  Dataset: {ds_name}")
        sampled_pairs = [
            tuple(random.sample(range(len(tasks)), 2))
            for _ in range(min(n_samples, len(tasks) - 1))
        ]

        for idx1, idx2 in sampled_pairs:
            label1 = f"{ds_name}[{idx1}]"
            label2 = f"{ds_name}[{idx2}]"
            result = await run_similarity_comparison(
                judge, tasks[idx1], tasks[idx2], label1, label2
            )
            _record_comparison(
                result, results[scenario], scenario, label1, label2,
                indent="    ", extra={"dataset": ds_name},
            )

    _print_scenario_average(1, results[scenario])

    # Scenario 2: Same problem across different dataset versions
    # (Requires matching by task_id or uid)
    scenario = "same_problem_different_datasets"
    print("\n" + "=" * 80)
    print("📊 SCENARIO 2: Same problem, different dataset versions")
    print("=" * 80)
    if len(dataset_names) >= 2:
        # Build index by task_id/uid for each dataset
        task_indices: Dict[str, Dict[str, int]] = {}
        for ds_name, tasks in datasets.items():
            index: Dict[str, int] = {}
            for i, task in enumerate(tasks):
                tid = _task_identifier(task)
                if tid is not None:
                    index[tid] = i
            task_indices[ds_name] = index

        for ds1, ds2 in combinations(dataset_names, 2):
            # Sorted so that sampling is reproducible under a seeded RNG
            # (set iteration order is not stable across runs).
            common = sorted(task_indices[ds1].keys() & task_indices[ds2].keys())
            if not common:
                continue
            print(f"\n  Found {len(common)} common task IDs between {ds1} and {ds2}")
            sample_size = max(0, min(n_samples, len(common)))
            for tid in random.sample(common, sample_size):
                label1 = f"{ds1}[{tid}]"
                label2 = f"{ds2}[{tid}]"
                result = await run_similarity_comparison(
                    judge,
                    datasets[ds1][task_indices[ds1][tid]],
                    datasets[ds2][task_indices[ds2][tid]],
                    label1,
                    label2,
                )
                _record_comparison(
                    result, results[scenario], scenario, label1, label2,
                    indent="    ", extra={"task_id": tid},
                )

    _print_scenario_average(2, results[scenario])

    # Scenario 3: Different problems, different datasets
    scenario = "different_problems_different_datasets"
    print("\n" + "=" * 80)
    print("📊 SCENARIO 3: Different problems, different datasets")
    print("=" * 80)
    # Empty datasets cannot be sampled from (random.randint would raise).
    non_empty = [name for name in dataset_names if datasets[name]]
    max_attempts = 10
    if len(non_empty) >= 2:
        for _ in range(n_samples):
            ds1, ds2 = random.sample(non_empty, 2)
            pair = None
            # Re-draw when both picks carry the same task id: that would be the
            # same problem, which belongs to scenario 2, not this one.
            for _attempt in range(max_attempts):
                idx1 = random.randint(0, len(datasets[ds1]) - 1)
                idx2 = random.randint(0, len(datasets[ds2]) - 1)
                id1 = _task_identifier(datasets[ds1][idx1])
                id2 = _task_identifier(datasets[ds2][idx2])
                if id1 is None or id1 != id2:
                    pair = (idx1, idx2)
                    break
            if pair is None:
                print(f"  Skipping {ds1} vs {ds2}: could not draw two different problems")
                continue

            idx1, idx2 = pair
            label1 = f"{ds1}[{idx1}]"
            label2 = f"{ds2}[{idx2}]"
            result = await run_similarity_comparison(
                judge, datasets[ds1][idx1], datasets[ds2][idx2], label1, label2
            )
            _record_comparison(
                result, results[scenario], scenario, label1, label2, indent="  "
            )

    _print_scenario_average(3, results[scenario])

    return results


def print_summary(results: Dict[str, List[Dict[str, Any]]]):
    """Write per-scenario score statistics to stdout.

    For each scenario, only result entries that carry a ``raw_score`` key are
    counted. The sample count, mean, minimum and maximum of those scores are
    printed. Scenarios with no entries, or with no scored entries, get a
    one-line notice instead.
    """
    banner = "=" * 80
    print("\n" + banner)
    print("📈 SUMMARY STATISTICS")
    print(banner)

    for name, entries in results.items():
        if entries:
            valid = [entry["raw_score"] for entry in entries if "raw_score" in entry]
            if valid:
                mean = sum(valid) / len(valid)
                lowest, highest = min(valid), max(valid)

                print(f"\n{name}:")
                print(f"  Samples: {len(valid)}")
                print(f"  Average score: {mean:.2f} / 10")
                print(f"  Min score: {lowest:.2f}, Max score: {highest:.2f}")
            else:
                print(f"\n{name}: No valid scores")
        else:
            print(f"\n{name}: No results")


# Command-line entry point: load the requested datasets, score sampled bug
# pairs with the judge and print (optionally save) the results.
if __name__ == "__main__":
    import argparse
    import os
    
    parser = argparse.ArgumentParser(description="Test bug similarity scoring across datasets")
    parser.add_argument("--model", type=str, default="openai/gpt-oss-120b",
                        help="Model to use for judging")
    parser.add_argument("--base_url", type=str, default="http://localhost:30000/v1",
                        help="API base URL")
    parser.add_argument("--api_key", type=str, default=os.getenv("OPENAI_API_KEY"), help="API key")
    parser.add_argument("--n_samples", type=int, default=3,
                        help="Number of samples per scenario")
    parser.add_argument("--datasets", type=str, nargs="+",
                        default=["bugbench", "bugbench_qwen7b_sampled", "bugbench_gpt-oss-20b_sampled"],
                        help="Datasets to compare")
    parser.add_argument("--split", type=str, default="test", help="Dataset split to use")
    parser.add_argument("--save_results", type=str, default=None,
                        help="Path to save results JSON")
    parser.add_argument("--skip_tokenizer", action="store_true", default=True,
                        help="Skip loading tokenizer (use for OpenAI-compatible APIs); this is the default")
    # --skip_tokenizer defaults to True and is store_true, so on its own it can
    # never be switched off. This companion flag makes the tokenizer-loading
    # branch below reachable.
    parser.add_argument("--load_tokenizer", dest="skip_tokenizer", action="store_false",
                        help="Load the model's tokenizer instead of skipping it")
    
    args = parser.parse_args()
    
    os.environ["TOKENIZERS_PARALLELISM"] = "true"
    
    print("=" * 80)
    print("🔬 BUG SIMILARITY JUDGE TEST")
    print("=" * 80)
    print(f"Model: {args.model}")
    print(f"Datasets: {args.datasets}")
    print(f"Split: {args.split}")
    print(f"Samples per scenario: {args.n_samples}")
    
    # Load datasets
    from rllm.data.dataset import DatasetRegistry
    
    datasets = {}
    for ds_name in args.datasets:
        print(f"\nLoading {ds_name}:{args.split}...")
        ds = DatasetRegistry.load_dataset(ds_name, args.split)
        if ds is None:
            print(f"  WARNING: Could not load {ds_name}:{args.split}")
            continue
        data = list(ds.get_data())
        # Filter to only tasks with buggy solutions
        data_with_bugs = [t for t in data if _get_buggy_solution(t)]
        print(f"  Loaded {len(data)} tasks, {len(data_with_bugs)} with buggy solutions")
        if data_with_bugs:
            datasets[ds_name] = data_with_bugs
    
    if not datasets:
        print("\nERROR: No datasets loaded with buggy solutions")
        raise SystemExit(1)
    
    # Initialize rollout engine and judge
    print("\nInitializing model...")
    from rllm.engine import OpenAIEngine
    
    # Try to load tokenizer, but fall back to None for API-only models
    # (when tokenizer=None, OpenAIEngine uses the chat completions API directly)
    tokenizer = None
    if args.skip_tokenizer:
        print("  Skipping tokenizer (using chat completions API)")
    else:
        try:
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained(args.model)
            print(f"  Loaded tokenizer for {args.model}")
        except Exception as e:
            # Best-effort: a missing tokenizer only changes which API path is used.
            print(f"  Could not load tokenizer for {args.model}: {e}")
            print("  Using chat completions API without tokenizer")
    
    rollout_engine = OpenAIEngine(
        model=args.model,
        tokenizer=tokenizer,
        max_prompt_length=8192,
        max_response_length=2048,
        base_url=args.base_url,
        api_key=args.api_key,
        sampling_params={"temperature": 0.3, "top_p": 0.95},
    )
    
    judge = BugSimilarityJudge(rollout_engine, BugSimilarityJudgeConfig(enabled=True))
    
    # Run tests
    print("\nRunning similarity comparisons...")
    results = asyncio.run(run_scenario_tests(judge, datasets, n_samples=args.n_samples))
    
    # Print summary
    print_summary(results)
    
    # Save results
    if args.save_results:
        # Explicit encoding so the output does not depend on the platform default.
        with open(args.save_results, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        print(f"\n💾 Results saved to: {args.save_results}")
    
    print("\n✅ Done!")

