#!/usr/bin/env python3
"""
analyze_similarity.py

Held-out generator sampling + target-likeness evaluation across checkpoints.

For each checkpoint t ∈ {t0, ...}:
  1) Sample bugs from the generator on a fixed set of held-out tasks (same tasks across checkpoints)
  2) Compute target similarity s+(b) and negative similarity s-(b) against held-out pools
  3) Compute margin Δ(b) = s+(b) - s-(b)

Plots:
  - Distribution (violin/box) of s+, s-, Δ across checkpoints
  - CDF plots comparing checkpoints
  - Summary statistics table

Key design choices:
  - Use held-out evaluation pools (disjoint from any target bugs used for training reward/mixing)
  - Optionally evaluate with a DIFFERENT embedder than training (e.g., train with voyage-code-3,
    evaluate with a local model) to test for reward hacking

Modes:
  - offline: Load each checkpoint directly with vLLM (no server required)
  - server: Use external vLLM server (requires manual server restart per checkpoint)

Usage (offline mode - recommended):
    python -m examples.bugs_refactor.analyze_similarity \
        --mode offline \
        --ckpt_dir /path/to/checkpoints \
        --ckpts global_step_10,global_step_20,global_step_30 \
        --base_model Qwen/Qwen2.5-Coder-7B-Instruct \
        --source_dataset bigcodebench \
        --target_pool_path /path/to/held_out_target_pool \
        --negative_pool_path /path/to/held_out_negative_pool \
        --n_tasks 100 \
        --output_dir ./similarity_analysis

Usage (server mode):
    python -m examples.bugs_refactor.analyze_similarity \
        --mode server \
        --ckpt_dir /path/to/checkpoints \
        --base_url http://localhost:30000/v1 \
        ...
"""

from __future__ import annotations

import argparse
import asyncio
import gc
import json
import os
import random
import re
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple

import numpy as np

from rllm.data.dataset import DatasetRegistry

from examples.bugs.code_embedding import (
    CodeEmbeddingConfig,
    CodeEmbedder,
    KNNBugSimilarity,
    ReferencePool,
)

# Placeholders for the server-mode dependencies. They start as None and are
# rebound to the real classes by _lazy_import_server_deps(), so importing this
# module does not itself import rllm.engine or examples.bugs_refactor.components.
OpenAIEngine = None
BugGenerator = None
BugGeneratorConfig = None


def _lazy_import_server_deps():
    """Bind the server-mode classes to the module-level placeholders.

    The imports run only while ``OpenAIEngine`` is still ``None``; later calls
    return immediately.
    """
    global OpenAIEngine, BugGenerator, BugGeneratorConfig

    if OpenAIEngine is not None:
        return

    from rllm.engine import OpenAIEngine as engine_cls
    from examples.bugs_refactor.components import BugGenerator as generator_cls
    from examples.bugs_refactor.components import BugGeneratorConfig as config_cls

    OpenAIEngine, BugGenerator, BugGeneratorConfig = engine_cls, generator_cls, config_cls


# ---------------------------
# Task schema helpers
# ---------------------------

def _get_problem(task: Dict[str, Any]) -> str:
    extra_info = task.get("extra_info", {})
    if isinstance(extra_info, dict):
        for key in (
            "question", "instruct_prompt", "complete_prompt", "prompt",
            "text", "problem", "description", "code_prompt",
        ):
            v = extra_info.get(key)
            if isinstance(v, str) and v.strip():
                return v
    for key in (
        "question", "instruct_prompt", "complete_prompt", "prompt",
        "text", "problem", "description", "code_prompt",
    ):
        v = task.get(key)
        if isinstance(v, str) and v.strip():
            return v
    return ""


def _get_reference_solution(task: Dict[str, Any]) -> str:
    extra_info = task.get("extra_info", {})
    if isinstance(extra_info, dict):
        for key in ("reference_solution", "canonical_solution", "solution", "code", "correct_code"):
            v = extra_info.get(key)
            if isinstance(v, str) and v.strip():
                return v
    for key in ("reference_solution", "canonical_solution", "solution", "code", "correct_code"):
        v = task.get(key)
        if isinstance(v, str) and v.strip():
            return v
    return ""


def _get_buggy_solution(task: Dict[str, Any]) -> Optional[str]:
    """Extract buggy solution from task (for building evaluation pools)."""
    extra_info = task.get("extra_info", {})
    if isinstance(extra_info, dict):
        for key in ("buggy_solution", "buggy_sampled_solution", "buggy", "buggy_code", "bug"):
            v = extra_info.get(key)
            if isinstance(v, str) and v.strip():
                return v
    for key in ("buggy_solution", "buggy_sampled_solution", "buggy", "buggy_code", "bug"):
        v = task.get(key)
        if isinstance(v, str) and v.strip():
            return v
    return None


def _get_task_id(task: Dict[str, Any]) -> Optional[str]:
    extra_info = task.get("extra_info", {})
    candidates = []
    if isinstance(extra_info, dict):
        candidates.extend([
            extra_info.get("task_id"), extra_info.get("uid"),
            extra_info.get("id"), extra_info.get("problem_id"),
        ])
    candidates.extend([
        task.get("task_id"), task.get("uid"),
        task.get("id"), task.get("problem_id"),
    ])
    for v in candidates:
        if isinstance(v, (str, int)) and str(v).strip():
            return str(v).strip()
    return None


# ---------------------------
# Dataset parsing helpers
# ---------------------------

def parse_dataset_specs(spec_str: str, default_split: str = "test") -> List[Tuple[str, str]]:
    """
    Parse dataset specification string into list of (dataset_name, split) tuples.

    Formats supported:
      - "dataset1:split1,dataset2:split2"
      - "dataset1,dataset2" (uses default_split)
      - "[dataset1:split1,dataset2:split2]" (with brackets)

    Blank items (e.g. from a trailing comma) are ignored. An item whose split
    part is missing or blank ("dataset1" or "dataset1:") gets ``default_split``.
    An empty or ``None`` ``spec_str`` yields an empty list.
    """
    if not spec_str:
        return []

    # Remove one pair of surrounding brackets if present
    cleaned = spec_str.strip()
    if cleaned.startswith("[") and cleaned.endswith("]"):
        cleaned = cleaned[1:-1]

    specs: List[Tuple[str, str]] = []
    for raw_item in cleaned.split(","):
        item = raw_item.strip()
        if not item:
            continue
        # Split on the first ":" only, so the split part may itself contain ":".
        name, _, split = item.partition(":")
        # A dangling "name:" would otherwise produce an unusable empty split.
        specs.append((name.strip(), split.strip() or default_split))

    return specs


def load_buggy_tasks_from_datasets(
    dataset_specs: List[Tuple[str, str]],
) -> Tuple[List[Dict[str, Any]], List[str]]:
    """
    Collect every task that carries a buggy solution from the given datasets.

    Each ``(name, split)`` pair is loaded through ``DatasetRegistry``; pairs
    that fail to load are reported and skipped.

    Returns: (all_tasks, dataset_labels), where each label is "name:split"
    for a dataset that loaded.
    """
    collected: List[Dict[str, Any]] = []
    loaded_labels: List[str] = []

    for name, split in dataset_specs:
        dataset = DatasetRegistry.load_dataset(name, split)
        if dataset is None:
            print(f"    [WARN] Could not load {name}:{split}, skipping")
            continue

        records = list(dataset.get_data())
        # Keep only records for which a buggy solution can be extracted.
        buggy_records = list(filter(_get_buggy_solution, records))

        print(f"    {name}:{split}: {len(buggy_records)}/{len(records)} tasks with buggy solutions")
        collected += buggy_records
        loaded_labels.append(f"{name}:{split}")

    return collected, loaded_labels


# ---------------------------
# Data structures
# ---------------------------

@dataclass
class CheckpointScores:
    """Per-checkpoint container for the scores of its generated bugs.

    Attributes:
        checkpoint_name: Name of the checkpoint.
        step: Step number of the checkpoint.
        target_sims: Target similarity s+(b), one entry per scored bug.
        negative_sims: Negative similarity s-(b), one entry per scored bug.
        margins: Margin Δ(b) = s+(b) - s-(b), one entry per scored bug.
        normalized_scores: sigmoid(margin), one entry per scored bug.
        task_ids: Task identifiers for the scored bugs.
        generation_time: Generation time (0.0 until set).
        n_failed: Number of failed generations (0 until set).
    """

    checkpoint_name: str
    step: int
    # Every list gets its own fresh instance per object via default_factory.
    target_sims: List[float] = field(default_factory=list)
    negative_sims: List[float] = field(default_factory=list)
    margins: List[float] = field(default_factory=list)
    normalized_scores: List[float] = field(default_factory=list)
    task_ids: List[str] = field(default_factory=list)
    generation_time: float = 0.0
    n_failed: int = 0

# ---------------------------
# Bug generation (server mode)
# ---------------------------

async def generate_bugs_batch_server(
    generator,  # BugGenerator
    tasks: List[Dict[str, Any]],
    n_parallel: int = 16,
) -> List[Tuple[Dict[str, Any], str, bool]]:
    """Generate one bug per task through the server-backed generator.

    At most ``n_parallel`` ``generator.generate_bug`` calls are in flight at
    once. Results are ``(task, buggy_code, ok)`` tuples in task order; a call
    that raises is logged and reported as ``(task, "", False)``.
    """
    gate = asyncio.Semaphore(n_parallel)

    async def _attempt(position: int, task: Dict[str, Any]) -> Tuple[Dict[str, Any], str, bool]:
        async with gate:
            try:
                trajectory = await generator.generate_bug(task, f"gen_{position}")
                # The buggy code is the action of the first step, if there is one.
                code = trajectory.steps[0].action if trajectory.steps else ""
                return (task, code, True)
            except Exception as e:
                print(f"  [WARN] Failed to generate bug for task {position}: {e}")
                return (task, "", False)

    gathered = await asyncio.gather(
        *(_attempt(position, task) for position, task in enumerate(tasks))
    )
    return list(gathered)


# ---------------------------
# Bug generation (offline mode - no server required)
# ---------------------------

# System prompt for bug generation; build_bug_prompt() falls back to it when
# the caller does not supply a system prompt of its own.
DEFAULT_BUG_GENERATOR_PROMPT = """You are a code mutation expert. Given a problem description and a correct solution, introduce a subtle bug that:
1. Makes the code fail on some (but not all) test cases
2. Is realistic - the kind of mistake a programmer might make
3. Is not immediately obvious

Return ONLY the buggy code, no explanations."""


def build_bug_prompt(task: Dict[str, Any], system_prompt: Optional[str] = None) -> List[Dict[str, str]]:
    """Build the two-message chat prompt (system + user) for bug generation.

    The user message embeds the task's problem statement and its reference
    solution inside a python code fence. A falsy ``system_prompt`` falls back
    to ``DEFAULT_BUG_GENERATOR_PROMPT``.
    """
    problem_text = _get_problem(task)
    reference_code = _get_reference_solution(task)

    # One entry per output line; "" entries become the blank separator lines.
    user_lines = [
        "Problem:",
        problem_text,
        "",
        "Correct Solution:",
        "```python",
        reference_code,
        "```",
        "",
        "Generate a buggy version of this code with a subtle bug.",
    ]

    system_message = {"role": "system", "content": system_prompt or DEFAULT_BUG_GENERATOR_PROMPT}
    user_message = {"role": "user", "content": "\n".join(user_lines)}
    return [system_message, user_message]


# A complete fenced block. The opening fence may carry any language tag
# (python, py, python3, c++, ...), not just "python".
_CLOSED_FENCE_RE = re.compile(r"```[\w+#.-]*\s*\n(.*?)```", re.DOTALL)
# An opening fence on its own, used when the closing fence never arrived.
_OPEN_FENCE_RE = re.compile(r"```[\w+#.-]*\s*\n")


def extract_code_from_response(response: str) -> str:
    """Extract code from model response, handling markdown code blocks.

    Resolution order:
      1. The contents of the first complete fenced block, whatever its
         language tag.
      2. If a fence is opened but never closed (e.g. the generation was cut
         off), everything after the opening fence line.
      3. Otherwise the whole response, assumed to be raw code.

    The result is always stripped of surrounding whitespace.
    """
    closed_blocks = _CLOSED_FENCE_RE.findall(response)
    if closed_blocks:
        return closed_blocks[0].strip()

    # Unterminated fence: drop the "```lang" header instead of returning it as code.
    opening = _OPEN_FENCE_RE.search(response)
    if opening:
        return response[opening.end():].strip()

    # If no code blocks, return the whole response (might be raw code)
    return response.strip()


def check_checkpoint_format(ckpt_path: str) -> Tuple[bool, Optional[str]]:
    """
    Classify a checkpoint directory as directly loadable or needing a merge.

    Returns: (is_valid, shard_dir)
      - (True, None) when ``config.json`` (HF format) or ``params.json``
        (Mistral format) is present.
      - (False, "<ckpt>/actor") when an ``actor`` subdirectory exists.
      - (False, "<ckpt>") when ``*.distcp`` or ``__*_0`` files sit directly
        in the directory.
      - (False, None) otherwise.
    """
    root = Path(ckpt_path)

    # Either marker file means the checkpoint can be loaded as-is.
    if any((root / marker).exists() for marker in ("config.json", "params.json")):
        return True, None

    # An actor/ subdirectory is treated as an FSDP checkpoint that needs merging.
    actor = root / "actor"
    if actor.is_dir():
        return False, str(actor)

    # Shard files stored directly in the checkpoint directory.
    has_shards = any(root.glob("*.distcp")) or any(root.glob("__*_0"))
    return (False, str(root)) if has_shards else (False, None)


def merge_fsdp_checkpoint(
    ckpt_path: str,
    base_model: str,
    output_dir: str,
) -> str:
    """
    Merge FSDP sharded checkpoint to HF format.

    The merged model is written to ``<output_dir>/<name>_merged``, where
    ``<name>`` is the checkpoint directory name (the parent directory's name
    when ``ckpt_path`` itself is an ``actor`` directory). If that directory
    already contains ``config.json`` the merge is skipped.

    Args:
        ckpt_path: Directory holding the FSDP shards.
        base_model: Accepted for interface compatibility; not used by the merge.
        output_dir: Directory under which the merged checkpoint is created.

    Returns: Path to merged checkpoint.

    Raises:
        RuntimeError: If the merger subprocess exits non-zero, or if it cannot
            be launched and the in-process fallback is unavailable.
    """
    import subprocess
    import sys

    source = Path(ckpt_path)
    ckpt_name = source.parent.name if source.name == "actor" else source.name
    merged_path = Path(output_dir) / f"{ckpt_name}_merged"

    # Check if already merged
    if (merged_path / "config.json").exists():
        print(f"  Using previously merged checkpoint: {merged_path}")
        return str(merged_path)

    print(f"  Merging FSDP checkpoint: {ckpt_path}")
    print(f"  Output: {merged_path}")

    # Create output directory
    merged_path.mkdir(parents=True, exist_ok=True)

    # Run the verl model merger with the interpreter running this script, so
    # it sees the same installed packages; a bare "python" on PATH may belong
    # to a different environment.
    cmd = [
        sys.executable or "python", "-m", "verl.model_merger", "merge",
        "--backend", "fsdp",
        "--local_dir", str(ckpt_path),
        "--target_dir", str(merged_path),
    ]

    try:
        subprocess.run(cmd, capture_output=True, text=True, check=True)
        print("  Merge completed successfully")
    except subprocess.CalledProcessError as e:
        print(f"  [ERROR] Merge failed: {e.stderr}")
        raise RuntimeError(f"Failed to merge checkpoint: {e.stderr}") from e
    except FileNotFoundError:
        # The interpreter executable itself could not be launched; try merging in-process.
        print("  [WARN] Could not launch the verl.model_merger subprocess, trying in-process merge...")
        try:
            # NOTE(review): assumes verl.model_merger exposes merge_fsdp_checkpoint — confirm.
            from verl.model_merger import merge_fsdp_checkpoint as verl_merge
            verl_merge(str(ckpt_path), str(merged_path))
        except ImportError as e:
            raise RuntimeError(
                f"Checkpoint {ckpt_path} appears to be in FSDP format and needs merging.\n"
                f"Please merge manually:\n"
                f"  python -m verl.model_merger merge --backend fsdp "
                f"--local_dir {ckpt_path} --target_dir {merged_path}"
            ) from e

    return str(merged_path)


def generate_bugs_offline(
    ckpt_path: str,
    base_model: str,
    tasks: List[Dict[str, Any]],
    temperature: float = 0.6,
    top_p: float = 0.95,
    max_tokens: int = 4096,
    system_prompt: Optional[str] = None,
    tensor_parallel_size: int = 1,
    gpu_memory_utilization: float = 0.85,
    merged_ckpt_dir: Optional[str] = None,
    skip_merge: bool = False,
) -> List[Tuple[Dict[str, Any], str, bool]]:
    """
    Generate bugs using vLLM offline mode (no server required).

    Loads the checkpoint, generates all bugs, then cleans up GPU memory (also
    when loading or generation fails). Automatically merges FSDP checkpoints
    if needed (unless skip_merge=True).

    Args:
        ckpt_path: Checkpoint directory (HF/Mistral format, or FSDP shards).
        base_model: Model whose tokenizer (and chat template) is used, and
            which is passed through to the FSDP merge.
        tasks: Tasks to generate one bug each for.
        temperature, top_p, max_tokens: Sampling parameters.
        system_prompt: Optional override for the generator system prompt.
        tensor_parallel_size, gpu_memory_utilization: Passed to vLLM's ``LLM``.
        merged_ckpt_dir: Where merged checkpoints are written; defaults to the
            parent directory of ``ckpt_path``.
        skip_merge: If True, refuse to merge and raise instead.

    Returns:
        One ``(task, buggy_code, ok)`` tuple per task, in task order. ``ok`` is
        False when the model produced no output or the extracted code is blank.

    Raises:
        RuntimeError: If vLLM/transformers cannot be imported, or the
            checkpoint is not loadable and cannot (or may not) be merged.
    """
    try:
        from vllm import LLM, SamplingParams
        from transformers import AutoTokenizer
    except ImportError as e:
        raise RuntimeError(f"vLLM or transformers not available: {e}") from e

    # Check if checkpoint needs merging
    is_valid, actor_subdir = check_checkpoint_format(ckpt_path)

    if not is_valid:
        if skip_merge:
            # Point at the directory that actually holds the shards when it is
            # known; otherwise keep the conventional <ckpt>/actor hint.
            local_dir = actor_subdir or f"{ckpt_path}/actor"
            raise RuntimeError(
                f"Checkpoint {ckpt_path} is not in HF format and --skip_merge is set.\n"
                f"Please merge manually:\n"
                f"  python -m verl.model_merger merge --backend fsdp "
                f"--local_dir {local_dir} --target_dir <output_dir>"
            )
        if actor_subdir:
            print("  Detected FSDP checkpoint, merging required...")
            # normpath drops a trailing slash; without it dirname() would return
            # the checkpoint directory itself instead of its parent.
            merge_output = merged_ckpt_dir or os.path.dirname(os.path.normpath(ckpt_path))
            ckpt_path = merge_fsdp_checkpoint(actor_subdir, base_model, merge_output)
        else:
            raise RuntimeError(
                f"Checkpoint {ckpt_path} is not in a recognized format.\n"
                f"Expected config.json (HF) or params.json (Mistral), or actor/ subdirectory (FSDP)."
            )

    print(f"  Loading checkpoint: {ckpt_path}")

    # Load tokenizer from base model (for chat template)
    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)

    # Build prompts for all tasks
    prompts: List[str] = []
    for task in tasks:
        messages = build_bug_prompt(task, system_prompt)
        if hasattr(tokenizer, "apply_chat_template"):
            prompt = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        else:
            # Fallback for tokenizers without chat template
            prompt = f"{messages[0]['content']}\n\nUser: {messages[1]['content']}\n\nAssistant:"
        prompts.append(prompt)

    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
    )

    llm = None
    try:
        # Load model
        llm = LLM(
            model=ckpt_path,
            tokenizer=base_model,  # Use base model tokenizer
            trust_remote_code=True,
            tensor_parallel_size=tensor_parallel_size,
            gpu_memory_utilization=gpu_memory_utilization,
            max_model_len=8192,
        )

        print(f"  Generating {len(prompts)} bugs...")
        outputs = llm.generate(prompts, sampling_params)
    finally:
        # Clean up GPU memory even when loading or generation failed, so the
        # next checkpoint does not start with this model still resident.
        del llm
        gc.collect()
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
        except ImportError:
            pass

    # Extract results
    results: List[Tuple[Dict[str, Any], str, bool]] = []
    for task, output in zip(tasks, outputs):
        if output.outputs:
            response = output.outputs[0].text
            buggy_code = extract_code_from_response(response)
            results.append((task, buggy_code, bool(buggy_code.strip())))
        else:
            results.append((task, "", False))

    return results


# ---------------------------
# Checkpoint discovery
# ---------------------------

def find_checkpoints(ckpt_dir: str, ckpt_names: Optional[List[str]] = None) -> List[Tuple[str, int]]:
    """
    Locate checkpoints under ``ckpt_dir``.

    With ``ckpt_names`` given, exactly those entries are looked up (missing
    ones are reported and skipped); otherwise every ``global_step_*``
    subdirectory is used.

    Returns: List of (ckpt_path, step_number) sorted by step.

    Raises:
        RuntimeError: If ``ckpt_dir`` does not exist.
    """
    root = Path(ckpt_dir)
    if not root.exists():
        raise RuntimeError(f"Checkpoint directory does not exist: {root}")

    found: List[Tuple[str, int]] = []

    if ckpt_names:
        # Explicit selection: resolve each requested name against the root.
        for name in ckpt_names:
            candidate = root / name
            if not candidate.exists():
                print(f"  [WARN] Checkpoint not found: {candidate}")
                continue
            found.append((str(candidate), extract_step_number(name)))
    else:
        # Auto-discovery: directories named global_step_*.
        found.extend(
            (str(entry), extract_step_number(entry.name))
            for entry in root.iterdir()
            if entry.is_dir() and entry.name.startswith("global_step_")
        )

    return sorted(found, key=lambda pair: pair[1])


def extract_step_number(name: str) -> int:
    """Extract step number from checkpoint name like 'global_step_100'.

    The digits that directly follow "step" (optionally separated by "_" or
    "-") are preferred, so a name such as 'run2_global_step_30' gives 30
    rather than 2. Names without such a marker fall back to their first run
    of digits; names without any digits give 0.
    """
    tagged = re.search(r"step[_-]?(\d+)", name, flags=re.IGNORECASE)
    if tagged:
        return int(tagged.group(1))

    first_digits = re.search(r"\d+", name)
    return int(first_digits.group()) if first_digits else 0


# ---------------------------
# Data loading
# ---------------------------

def load_tasks_with_solutions(dataset_name: str, split: str) -> List[Dict[str, Any]]:
    """Load a dataset split and keep only tasks that have a reference solution.

    Raises:
        RuntimeError: If ``DatasetRegistry`` returns ``None`` for the split.
    """
    dataset = DatasetRegistry.load_dataset(dataset_name, split)
    if dataset is None:
        raise RuntimeError(f"Could not load dataset={dataset_name!r} split={split!r}")

    # Bug generation needs a correct solution to mutate, so drop tasks without one.
    return [record for record in dataset.get_data() if _get_reference_solution(record)]


def sample_fixed_tasks(
    tasks: List[Dict[str, Any]],
    n: int,
    seed: int,
) -> List[Dict[str, Any]]:
    """Pick ``n`` tasks reproducibly.

    When ``n`` covers the whole list, a shallow copy in the original order is
    returned; otherwise the tasks are drawn without replacement from a
    ``random.Random(seed)`` generator, so the same seed gives the same subset.
    """
    if len(tasks) <= n:
        return list(tasks)
    return random.Random(seed).sample(tasks, n)


# ---------------------------
# Scoring
# ---------------------------

def score_bugs(
    knn: KNNBugSimilarity,
    task_bug_pairs: List[Tuple[Dict[str, Any], str]],
) -> Tuple[List[float], List[float], List[float], List[float]]:
    """
    Score every (task, buggy_code) pair with ``knn.score_similarity``.

    The raw similarities are read from the metadata returned with each score:
      - target:   "avg_cosine_target", else "target_sim", else the score itself
      - negative: "avg_cosine_negative", else "negative_sim", else 0.0
      - margin:   "margin", else target - negative

    Returns: (target_sims, negative_sims, margins, normalized_scores)
    """

    def _first_available(meta, primary: str, secondary: str, fallback: float) -> float:
        # Prefer the primary key, then the secondary one, else the fallback.
        if primary in meta:
            return float(meta[primary])
        if secondary in meta:
            return float(meta[secondary])
        return fallback

    target_sims: List[float] = []
    negative_sims: List[float] = []
    margins: List[float] = []
    normalized_scores: List[float] = []

    for task, buggy_code in task_bug_pairs:
        score, meta = knn.score_similarity(
            _get_problem(task),
            buggy_code,
            correct_code=_get_reference_solution(task),
        )
        normalized = float(score)
        positive = _first_available(meta, "avg_cosine_target", "target_sim", normalized)
        negative = _first_available(meta, "avg_cosine_negative", "negative_sim", 0.0)
        margin = float(meta["margin"]) if "margin" in meta else positive - negative

        normalized_scores.append(normalized)
        target_sims.append(positive)
        negative_sims.append(negative)
        margins.append(margin)

    return target_sims, negative_sims, margins, normalized_scores


# ---------------------------
# Plotting
# ---------------------------

def plot_violin_comparison(
    all_scores: List[CheckpointScores],
    output_dir: str,
    metric: str = "margin",
) -> None:
    """Plot violin plot comparing score distributions across checkpoints."""
    try:
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt
    except ImportError:
        print("  [WARN] matplotlib not available, skipping violin plot")
        return
    
    # Prepare data
    data = []
    labels = []
    positions = []
    
    for i, scores in enumerate(all_scores):
        if metric == "target":
            vals = scores.target_sims
        elif metric == "negative":
            vals = scores.negative_sims
        elif metric == "margin":
            vals = scores.margins
        else:
            vals = scores.normalized_scores
        
        if vals:
            data.append(vals)
            labels.append(f"Step {scores.step}")
            positions.append(i)
    
    if not data:
        print(f"  [WARN] No data for violin plot ({metric})")
        return
    
    # Create figure
    fig, ax = plt.subplots(figsize=(max(10, len(data) * 1.5), 6))
    
    # Violin plot
    parts = ax.violinplot(data, positions=positions, showmeans=True, showmedians=True)
    
    # Color the violins
    cmap = plt.cm.viridis
    colors = [cmap(i / max(1, len(data) - 1)) for i in range(len(data))]
    
    for i, pc in enumerate(parts['bodies']):
        pc.set_facecolor(colors[i])
        pc.set_alpha(0.7)
    
    # Add box plot overlay for IQR
    bp = ax.boxplot(data, positions=positions, widths=0.15, patch_artist=True,
                    showfliers=False, showcaps=False, showmeans=False)
    for patch in bp['boxes']:
        patch.set_facecolor('white')
        patch.set_alpha(0.8)
    
    # Labels
    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=45, ha='right')
    
    metric_labels = {
        "target": "Target Similarity s⁺(b)",
        "negative": "Negative Similarity s⁻(b)",
        "margin": "Margin Δ(b) = s⁺(b) - s⁻(b)",
        "normalized": "Normalized Score (sigmoid)",
    }
    ax.set_ylabel(metric_labels.get(metric, metric))
    ax.set_title(f"Bug Embedding Similarity Across Checkpoints ({metric})")
    
    # Add horizontal line at 0 for margin
    if metric == "margin":
        ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
    
    plt.tight_layout()
    
    output_path = os.path.join(output_dir, f"violin_{metric}.png")
    plt.savefig(output_path, dpi=150)
    plt.close()
    print(f"  Saved: {output_path}")


def plot_cdf_comparison(
    all_scores: List[CheckpointScores],
    output_dir: str,
    metric: str = "margin",
) -> None:
    """Plot CDF comparison across checkpoints."""
    try:
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt
    except ImportError:
        print("  [WARN] matplotlib not available, skipping CDF plot")
        return
    
    fig, ax = plt.subplots(figsize=(10, 6))
    
    cmap = plt.cm.viridis
    
    for i, scores in enumerate(all_scores):
        if metric == "target":
            vals = sorted(scores.target_sims)
        elif metric == "negative":
            vals = sorted(scores.negative_sims)
        elif metric == "margin":
            vals = sorted(scores.margins)
        else:
            vals = sorted(scores.normalized_scores)
        
        if not vals:
            continue
        
        # Compute CDF
        y = np.arange(1, len(vals) + 1) / len(vals)
        color = cmap(i / max(1, len(all_scores) - 1))
        ax.plot(vals, y, label=f"Step {scores.step}", color=color, linewidth=2)
    
    metric_labels = {
        "target": "Target Similarity s⁺(b)",
        "negative": "Negative Similarity s⁻(b)",
        "margin": "Margin Δ(b) = s⁺(b) - s⁻(b)",
        "normalized": "Normalized Score",
    }
    ax.set_xlabel(metric_labels.get(metric, metric))
    ax.set_ylabel("CDF")
    ax.set_title(f"CDF of Bug Embedding Similarity Across Checkpoints")
    ax.legend(loc='lower right')
    ax.grid(True, alpha=0.3)
    
    if metric == "margin":
        ax.axvline(x=0, color='gray', linestyle='--', alpha=0.5)
    
    plt.tight_layout()
    
    output_path = os.path.join(output_dir, f"cdf_{metric}.png")
    plt.savefig(output_path, dpi=150)
    plt.close()
    print(f"  Saved: {output_path}")


def plot_summary_progression(
    all_scores: List[CheckpointScores],
    output_dir: str,
) -> None:
    """Plot summary statistics progression across checkpoints."""
    try:
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt
    except ImportError:
        print("  [WARN] matplotlib not available, skipping progression plot")
        return
    
    steps = [s.step for s in all_scores]
    
    # Compute medians and IQRs
    target_medians = [np.median(s.target_sims) if s.target_sims else 0 for s in all_scores]
    negative_medians = [np.median(s.negative_sims) if s.negative_sims else 0 for s in all_scores]
    margin_medians = [np.median(s.margins) if s.margins else 0 for s in all_scores]
    
    target_q25 = [np.percentile(s.target_sims, 25) if s.target_sims else 0 for s in all_scores]
    target_q75 = [np.percentile(s.target_sims, 75) if s.target_sims else 0 for s in all_scores]
    negative_q25 = [np.percentile(s.negative_sims, 25) if s.negative_sims else 0 for s in all_scores]
    negative_q75 = [np.percentile(s.negative_sims, 75) if s.negative_sims else 0 for s in all_scores]
    margin_q25 = [np.percentile(s.margins, 25) if s.margins else 0 for s in all_scores]
    margin_q75 = [np.percentile(s.margins, 75) if s.margins else 0 for s in all_scores]
    
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    # Target similarity
    ax = axes[0]
    ax.plot(steps, target_medians, 'o-', color='green', linewidth=2, label='Median')
    ax.fill_between(steps, target_q25, target_q75, alpha=0.3, color='green', label='IQR')
    ax.set_xlabel("Training Step")
    ax.set_ylabel("Target Similarity s⁺(b)")
    ax.set_title("Target Similarity Progression")
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # Negative similarity
    ax = axes[1]
    ax.plot(steps, negative_medians, 'o-', color='red', linewidth=2, label='Median')
    ax.fill_between(steps, negative_q25, negative_q75, alpha=0.3, color='red', label='IQR')
    ax.set_xlabel("Training Step")
    ax.set_ylabel("Negative Similarity s⁻(b)")
    ax.set_title("Negative Similarity Progression")
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # Margin
    ax = axes[2]
    ax.plot(steps, margin_medians, 'o-', color='blue', linewidth=2, label='Median')
    ax.fill_between(steps, margin_q25, margin_q75, alpha=0.3, color='blue', label='IQR')
    ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
    ax.set_xlabel("Training Step")
    ax.set_ylabel("Margin Δ(b)")
    ax.set_title("Margin Progression")
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    
    output_path = os.path.join(output_dir, "progression.png")
    plt.savefig(output_path, dpi=150)
    plt.close()
    print(f"  Saved: {output_path}")


def save_statistics_csv(
    all_scores: List[CheckpointScores],
    output_dir: str,
) -> None:
    """Save summary statistics to CSV."""
    rows = []
    for s in all_scores:
        row = {
            "checkpoint": s.checkpoint_name,
            "step": s.step,
            "n_samples": len(s.margins),
            "n_failed": s.n_failed,
            "generation_time_s": s.generation_time,
        }
        
        if s.target_sims:
            arr = np.array(s.target_sims)
            row.update({
                "target_mean": arr.mean(),
                "target_std": arr.std(),
                "target_median": np.median(arr),
                "target_q25": np.percentile(arr, 25),
                "target_q75": np.percentile(arr, 75),
            })
        
        if s.negative_sims:
            arr = np.array(s.negative_sims)
            row.update({
                "negative_mean": arr.mean(),
                "negative_std": arr.std(),
                "negative_median": np.median(arr),
                "negative_q25": np.percentile(arr, 25),
                "negative_q75": np.percentile(arr, 75),
            })
        
        if s.margins:
            arr = np.array(s.margins)
            row.update({
                "margin_mean": arr.mean(),
                "margin_std": arr.std(),
                "margin_median": np.median(arr),
                "margin_q25": np.percentile(arr, 25),
                "margin_q75": np.percentile(arr, 75),
                "margin_positive_frac": (arr > 0).mean(),
            })
        
        rows.append(row)
    
    # Write CSV
    output_path = os.path.join(output_dir, "statistics.csv")
    if rows:
        import csv
        with open(output_path, 'w', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=rows[0].keys())
            writer.writeheader()
            writer.writerows(rows)
        print(f"  Saved: {output_path}")


def save_raw_scores(
    all_scores: List[CheckpointScores],
    output_dir: str,
) -> None:
    """Save raw scores for each checkpoint."""
    for s in all_scores:
        output_path = os.path.join(output_dir, f"scores_step{s.step}.json")
        data = {
            "checkpoint": s.checkpoint_name,
            "step": s.step,
            "task_ids": s.task_ids,
            "target_sims": s.target_sims,
            "negative_sims": s.negative_sims,
            "margins": s.margins,
            "normalized_scores": s.normalized_scores,
        }
        with open(output_path, 'w') as f:
            json.dump(data, f, indent=2)
    print(f"  Saved raw scores for {len(all_scores)} checkpoints")


# ---------------------------
# Main
# ---------------------------

def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the checkpoint similarity analysis."""
    ap = argparse.ArgumentParser(
        description="Analyze bug generator similarity across checkpoints"
    )

    # Mode selection
    ap.add_argument("--mode", type=str, default="offline", choices=["offline", "server"],
                    help="Generation mode: 'offline' loads each checkpoint directly (recommended), "
                         "'server' requires external vLLM server")

    # Checkpoint configuration
    ap.add_argument("--ckpt_dir", type=str, required=True,
                    help="Directory containing checkpoints")
    ap.add_argument("--ckpts", type=str, default=None,
                    help="Comma-separated list of checkpoint names (e.g., 'global_step_10,global_step_20'). "
                         "If not specified, auto-discovers global_step_* directories.")
    ap.add_argument("--base_model", type=str, default="Qwen/Qwen2.5-Coder-7B-Instruct",
                    help="Base model path (required for offline mode tokenizer)")

    # Server configuration (only for server mode)
    ap.add_argument("--base_url", type=str, default="http://localhost:30000/v1",
                    help="Base URL for vLLM server (server mode only)")
    ap.add_argument("--api_key", type=str, default=None,
                    help="API key (default: OPENAI_API_KEY or 'EMPTY')")

    # Generation configuration
    ap.add_argument("--temperature", type=float, default=0.6)
    ap.add_argument("--top_p", type=float, default=0.95)
    ap.add_argument("--max_tokens", type=int, default=4096,
                    help="Max tokens for generation")
    ap.add_argument("--generator_system_prompt", type=str, default=None)

    # Offline mode configuration
    ap.add_argument("--tensor_parallel_size", type=int, default=1,
                    help="Tensor parallel size for offline vLLM (offline mode only)")
    ap.add_argument("--gpu_memory_utilization", type=float, default=0.85,
                    help="GPU memory utilization for offline vLLM (offline mode only)")
    # The fallback used by main() is --output_dir, so the help text says so.
    ap.add_argument("--merged_ckpt_dir", type=str, default=None,
                    help="Directory to store merged FSDP checkpoints (default: same as output_dir)")
    ap.add_argument("--skip_merge", action="store_true", default=False,
                    help="Skip automatic FSDP merging (error if checkpoint not in HF format)")

    # Source dataset (held-out tasks for generation)
    ap.add_argument("--source_dataset", type=str, default="bigcodebench",
                    help="Dataset to sample held-out tasks from")
    ap.add_argument("--source_split", type=str, default="train")
    ap.add_argument("--n_tasks", type=int, default=100,
                    help="Number of held-out tasks to use (same across all checkpoints)")
    ap.add_argument("--task_seed", type=int, default=42,
                    help="Seed for task sampling (ensures same tasks across checkpoints)")

    # Evaluation pools (HELD-OUT, disjoint from training)
    # Option 1: Load from pre-computed pool files
    ap.add_argument("--target_pool_path", type=str, default=None,
                    help="Path to pre-computed target pool (if not provided, builds from --target_datasets)")
    ap.add_argument("--negative_pool_path", type=str, default=None,
                    help="Path to pre-computed negative pool (if not provided, builds from --negative_datasets)")
    # Option 2: Build pools from datasets (supports multiple datasets)
    # Format: "dataset1:split1,dataset2:split2" or just "dataset1,dataset2" (uses default split)
    ap.add_argument("--target_datasets", type=str, default="bugbench_human:test",
                    help="Comma-separated list of datasets to build target pool from. "
                         "Format: 'ds1:split1,ds2:split2' or 'ds1,ds2' (default split: test)")
    ap.add_argument("--negative_datasets", type=str, default=None,
                    help="Comma-separated list of datasets to build negative pool from. "
                         "Format: 'ds1:split1,ds2:split2' (e.g., 'bugbench_qwen7b_sampled:test')")
    ap.add_argument("--default_pool_split", type=str, default="test",
                    help="Default split to use when not specified in dataset string")
    # Option to save built pools for future use
    ap.add_argument("--save_pools", action="store_true", default=False,
                    help="Save built pools to output_dir for future use")

    # Evaluation embedder (use DIFFERENT model than training for robustness)
    ap.add_argument("--eval_embed_model", type=str, default="voyage-code-3",
                    help="Embedding model for evaluation (recommend different from training)")
    ap.add_argument("--eval_embed_mode", type=str, default="buggy", choices=["diff", "buggy"])
    ap.add_argument("--eval_include_problem", action="store_true", default=False)
    ap.add_argument("--eval_top_k", type=int, default=20)
    ap.add_argument("--eval_margin_temperature", type=float, default=10.0)
    ap.add_argument("--device", type=str, default="cuda")

    # Execution
    ap.add_argument("--n_parallel", type=int, default=16,
                    help="Number of parallel generation requests (server mode only)")

    # Output
    ap.add_argument("--output_dir", type=str, default="./similarity_analysis",
                    help="Directory to save results and plots")
    ap.add_argument("--save_raw_scores", action="store_true", default=False,
                    help="Save raw scores for each checkpoint")

    return ap


def main():
    """Run the checkpoint similarity analysis from the command line.

    Steps:
      1. Parse arguments and validate that a target pool source is configured.
      2. Find the checkpoints and sample a fixed set of held-out tasks.
      3. Load or build the target pool and the optional negative pool.
      4. For every checkpoint, generate bugs (offline or via server) and score them.
      5. Write plots, the statistics CSV and (optionally) raw scores, then print
         a progression table.

    Returns early (with an ``[ERROR]`` message) when no checkpoints are found
    or none could be evaluated.

    Raises:
        RuntimeError: If neither ``--target_pool_path`` nor a non-empty
            ``--target_datasets`` is given, or if the target datasets contain
            no buggy solutions.
    """
    args = _build_arg_parser().parse_args()

    # Resolve API key
    api_key = args.api_key
    if api_key is None:
        api_key = os.getenv("OPENAI_API_KEY", "")
    if not api_key.strip():
        api_key = "EMPTY"

    # Parse dataset specs
    target_specs = parse_dataset_specs(args.target_datasets, args.default_pool_split) if args.target_datasets else []
    negative_specs = parse_dataset_specs(args.negative_datasets, args.default_pool_split) if args.negative_datasets else []

    # Fail fast: without a target pool source nothing below can succeed, so do
    # not create directories, load tasks or start the embedder first.
    if not args.target_pool_path and not target_specs:
        raise RuntimeError("Must specify --target_pool_path or --target_datasets")

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    print("=" * 80)
    print("Bug Generator Checkpoint Similarity Analysis")
    print("=" * 80)
    print(f"Mode: {args.mode}")
    print(f"Checkpoint dir: {args.ckpt_dir}")
    print(f"Base model: {args.base_model}")
    print(f"Source dataset: {args.source_dataset}:{args.source_split} (n={args.n_tasks})")

    # Display target pool source
    if args.target_pool_path:
        print(f"Target pool: {args.target_pool_path} (pre-computed)")
    else:
        print(f"Target pool: {', '.join(f'{d}:{s}' for d, s in target_specs)} (will build)")

    # Display negative pool source
    if args.negative_pool_path:
        print(f"Negative pool: {args.negative_pool_path} (pre-computed)")
    elif negative_specs:
        print(f"Negative pool: {', '.join(f'{d}:{s}' for d, s in negative_specs)} (will build)")
    else:
        print("Negative pool: None (using absolute scoring)")

    print(f"Eval embedder: {args.eval_embed_model} (mode={args.eval_embed_mode})")
    print(f"Output: {args.output_dir}")
    if args.mode == "offline":
        print(f"Tensor parallel: {args.tensor_parallel_size}, GPU util: {args.gpu_memory_utilization}")
    print()

    # ---------------------------
    # Find checkpoints
    # ---------------------------
    print("Finding checkpoints...")
    ckpt_names = None
    if args.ckpts:
        # Tolerate "a, b" and trailing commas; an all-empty list means
        # "not specified" and falls back to auto-discovery.
        ckpt_names = [name.strip() for name in args.ckpts.split(",") if name.strip()] or None
    checkpoints = find_checkpoints(args.ckpt_dir, ckpt_names)

    if not checkpoints:
        print("  [ERROR] No checkpoints found!")
        return

    print(f"  Found {len(checkpoints)} checkpoints:")
    for ckpt_path, step in checkpoints:
        print(f"    Step {step}: {ckpt_path}")
    print()

    # ---------------------------
    # Load held-out tasks
    # ---------------------------
    print("Loading held-out tasks...")
    all_tasks = load_tasks_with_solutions(args.source_dataset, args.source_split)
    print(f"  Loaded {len(all_tasks)} tasks with solutions")

    held_out_tasks = sample_fixed_tasks(all_tasks, args.n_tasks, args.task_seed)
    print(f"  Sampled {len(held_out_tasks)} held-out tasks (seed={args.task_seed})")
    print()

    # ---------------------------
    # Load or build evaluation pools
    # ---------------------------
    print("Setting up evaluation pools...")

    # Determine if we'll have a negative pool (for relative scoring).
    # NOTE(review): this is decided from the CLI flags only. If the negative
    # datasets later turn out to hold no buggy solutions (see the [WARN]
    # below), use_relative_score stays True although no negative pool was
    # built -- confirm the scorer handles that combination.
    has_negative = bool(args.negative_pool_path or negative_specs)

    eval_cfg = CodeEmbeddingConfig(
        model_name=args.eval_embed_model,
        embed_mode=args.eval_embed_mode,
        include_problem=args.eval_include_problem,
        top_k=args.eval_top_k,
        device=args.device,
        use_relative_score=has_negative,
        margin_temperature=args.eval_margin_temperature,
    )
    embedder = CodeEmbedder(eval_cfg)
    knn = KNNBugSimilarity(embedder, top_k=args.eval_top_k)

    # Load or build target pool (a source is guaranteed by the check above)
    if args.target_pool_path:
        print(f"  Loading target pool from: {args.target_pool_path}")
        knn.target_pool = ReferencePool.load(args.target_pool_path)
    else:
        print(f"  Building target pool from {len(target_specs)} dataset(s):")
        target_tasks_with_bugs, _ = load_buggy_tasks_from_datasets(target_specs)
        if not target_tasks_with_bugs:
            raise RuntimeError(f"No buggy solutions found in target datasets: {target_specs}")
        print(f"  Total: {len(target_tasks_with_bugs)} tasks with buggy solutions")
        knn.build_target_pool(target_tasks_with_bugs)

        # Save pool if requested
        if args.save_pools:
            pool_name = "_".join(d.replace("/", "-") for d, _ in target_specs)
            pool_path = os.path.join(args.output_dir, f"target_pool_{pool_name}")
            knn.target_pool.save(pool_path)
            print(f"    Saved target pool to: {pool_path}")

    print(f"  Target pool: {len(knn.target_pool)} embeddings")

    # Load or build negative pool (optional)
    if args.negative_pool_path:
        print(f"  Loading negative pool from: {args.negative_pool_path}")
        negative_pool = ReferencePool.load(args.negative_pool_path)
        knn.negative_pool = negative_pool
        print(f"  Negative pool: {len(negative_pool)} embeddings")
    elif negative_specs:
        print(f"  Building negative pool from {len(negative_specs)} dataset(s):")
        negative_tasks_with_bugs, _ = load_buggy_tasks_from_datasets(negative_specs)
        if negative_tasks_with_bugs:
            print(f"  Total: {len(negative_tasks_with_bugs)} tasks with buggy solutions")
            knn.build_negative_pool(negative_tasks_with_bugs)
            print(f"  Negative pool: {len(knn.negative_pool)} embeddings")

            # Save pool if requested
            if args.save_pools:
                pool_name = "_".join(d.replace("/", "-") for d, _ in negative_specs)
                pool_path = os.path.join(args.output_dir, f"negative_pool_{pool_name}")
                knn.negative_pool.save(pool_path)
                print(f"    Saved negative pool to: {pool_path}")
        else:
            print("  [WARN] No buggy solutions in negative datasets, using absolute scoring")
    else:
        print("  Negative pool: None (using absolute scoring)")

    print()

    # ---------------------------
    # Evaluate each checkpoint
    # ---------------------------
    all_checkpoint_scores: List[CheckpointScores] = []

    for ckpt_path, step in checkpoints:
        print(f"\n{'='*60}")
        print(f"Evaluating: Step {step}")
        print(f"  Path: {ckpt_path}")
        print(f"{'='*60}")

        # Generate bugs
        print(f"\n  🐛 Generating bugs for {len(held_out_tasks)} tasks...")
        start_time = time.time()

        try:
            if args.mode == "offline":
                # Offline mode: load checkpoint directly with vLLM.
                # Merged checkpoints go to --merged_ckpt_dir, else --output_dir.
                merged_dir = args.merged_ckpt_dir or args.output_dir
                results = generate_bugs_offline(
                    ckpt_path=ckpt_path,
                    base_model=args.base_model,
                    tasks=held_out_tasks,
                    temperature=args.temperature,
                    top_p=args.top_p,
                    max_tokens=args.max_tokens,
                    system_prompt=args.generator_system_prompt,
                    tensor_parallel_size=args.tensor_parallel_size,
                    gpu_memory_utilization=args.gpu_memory_utilization,
                    merged_ckpt_dir=merged_dir,
                    skip_merge=args.skip_merge,
                )
            else:
                # Server mode: use external vLLM server
                _lazy_import_server_deps()

                print(f"\n  ⚠️  Make sure vLLM server is running with checkpoint: {ckpt_path}")
                print(f"     Base URL: {args.base_url}")

                model_name = ckpt_path
                generator_engine = OpenAIEngine(
                    model=model_name,
                    tokenizer=None,
                    base_url=args.base_url,
                    api_key=api_key,
                    max_prompt_length=8192,
                    max_response_length=args.max_tokens,
                    sampling_params={
                        "temperature": args.temperature,
                        "top_p": args.top_p,
                    },
                    verbose=False,
                )

                generator = BugGenerator(
                    generator_engine,
                    BugGeneratorConfig(system_prompt=args.generator_system_prompt),
                )

                results = asyncio.run(generate_bugs_batch_server(
                    generator, held_out_tasks, args.n_parallel
                ))

        except Exception as e:
            # Best effort: one failing checkpoint must not abort the whole
            # sweep, so report it with a traceback and move on.
            print(f"  [ERROR] Generation failed: {e}")
            import traceback
            traceback.print_exc()
            if args.mode == "server":
                print("  [INFO] Make sure vLLM server is running with the correct checkpoint")
            continue

        gen_time = time.time() - start_time
        # Keep only results flagged successful with a non-blank bug; a missing
        # (None) bug counts as a failure instead of crashing on .strip().
        successful = [(t, b) for t, b, s in results if s and b and b.strip()]
        n_failed = len(results) - len(successful)

        print(f"  Generated {len(successful)} bugs in {gen_time:.1f}s ({n_failed} failed)")

        if not successful:
            print("  [WARN] No successful generations, skipping checkpoint")
            continue

        # Score bugs
        print("  🔢 Scoring bugs...")
        target_sims, negative_sims, margins, normalized_scores = score_bugs(knn, successful)

        # Store results
        scores = CheckpointScores(
            checkpoint_name=os.path.basename(ckpt_path),
            step=step,
            target_sims=target_sims,
            negative_sims=negative_sims,
            margins=margins,
            normalized_scores=normalized_scores,
            task_ids=[_get_task_id(t) or str(i) for i, (t, _) in enumerate(successful)],
            generation_time=gen_time,
            n_failed=n_failed,
        )
        all_checkpoint_scores.append(scores)

        # Print summary. Each line is guarded on its own list so an empty list
        # does not produce NaN output and NumPy runtime warnings.
        if margins:
            m_arr = np.array(margins)
            print("\n  📊 Summary:")
            if target_sims:
                t_arr = np.array(target_sims)
                print(f"     Target sim:   median={np.median(t_arr):.4f}, mean={t_arr.mean():.4f}")
            if negative_sims:
                n_arr = np.array(negative_sims)
                print(f"     Negative sim: median={np.median(n_arr):.4f}, mean={n_arr.mean():.4f}")
            print(f"     Margin:       median={np.median(m_arr):+.4f}, mean={m_arr.mean():+.4f}")
            print(f"     Margin > 0:   {(m_arr > 0).mean()*100:.1f}%")

    if not all_checkpoint_scores:
        print("\n[ERROR] No checkpoints were successfully evaluated!")
        return

    # ---------------------------
    # Generate plots and save results
    # ---------------------------
    print(f"\n{'='*60}")
    print("Generating plots and saving results...")
    print(f"{'='*60}")

    # Violin plots
    for metric in ["target", "negative", "margin"]:
        plot_violin_comparison(all_checkpoint_scores, args.output_dir, metric)

    # CDF plots
    for metric in ["target", "negative", "margin"]:
        plot_cdf_comparison(all_checkpoint_scores, args.output_dir, metric)

    # Progression plot
    plot_summary_progression(all_checkpoint_scores, args.output_dir)

    # Statistics CSV
    save_statistics_csv(all_checkpoint_scores, args.output_dir)

    # Raw scores (optional)
    if args.save_raw_scores:
        save_raw_scores(all_checkpoint_scores, args.output_dir)

    # ---------------------------
    # Print final summary
    # ---------------------------
    print(f"\n{'='*80}")
    print("📊 FINAL SUMMARY")
    print(f"{'='*80}")

    # Report the number of tasks actually sampled, which can differ from the
    # requested --n_tasks.
    print(f"\nEvaluated {len(all_checkpoint_scores)} checkpoints on {len(held_out_tasks)} held-out tasks")
    print(f"Evaluation embedder: {args.eval_embed_model}")
    print()

    print("Checkpoint progression:")
    print("-" * 70)
    print(f"{'Step':>8} | {'Target (med)':>12} | {'Neg (med)':>12} | {'Margin (med)':>12} | {'Margin>0':>10}")
    print("-" * 70)

    for s in all_checkpoint_scores:
        t_med = np.median(s.target_sims) if s.target_sims else 0
        n_med = np.median(s.negative_sims) if s.negative_sims else 0
        m_med = np.median(s.margins) if s.margins else 0
        m_pos = (np.array(s.margins) > 0).mean() * 100 if s.margins else 0
        print(f"{s.step:>8} | {t_med:>12.4f} | {n_med:>12.4f} | {m_med:>+12.4f} | {m_pos:>9.1f}%")

    print("-" * 70)

    # Check if there's improvement (first vs last evaluated checkpoint)
    if len(all_checkpoint_scores) >= 2:
        first = all_checkpoint_scores[0]
        last = all_checkpoint_scores[-1]

        first_margin = np.median(first.margins) if first.margins else 0
        last_margin = np.median(last.margins) if last.margins else 0

        delta = last_margin - first_margin
        print(f"\nMargin improvement (step {last.step} vs step {first.step}): {delta:+.4f}")

        # A median-margin change within +/-0.01 is reported as "no change".
        if delta > 0.01:
            print("✅ Later checkpoints generate bugs that look MORE like target and LESS like negative!")
        elif delta < -0.01:
            print("⚠️ Later checkpoints generate bugs that look LESS like target")
        else:
            print("➡️ No significant change in target-likeness across checkpoints")

    print(f"\n📁 Results saved to: {args.output_dir}")
    print("\n✅ Done!")


# Script entry point: run the analysis only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
