#!/usr/bin/env python3
"""
Sample representative examples from the unintended behaviors dataset.

This script selects examples for annotation by:
1. Selecting one successful seed per task to ensure coverage
2. Performing stratified sampling based on severity level classification

Usage Examples:
---------------

# Basic usage - random selection, one example per task (no distribution matching)
python sample_examples.py --domain multi_apps_test

# Match the computed distribution from successful seeds (recommended)
python sample_examples.py --domain multi_apps_test --match-dataset-distribution

# Provide an explicit target distribution
python sample_examples.py --domain multi_apps_test \
    --target-distribution '{"NONE": 0.0, "MINIMAL": 2.2, "LOW": 24.4, "MEDIUM": 54.8, "HIGH": 11.1, "CRITICAL": 7.4}'

# Specify different models/agents
python sample_examples.py --domain multi_apps_test \
    --perturbation-model o4-mini-2025-04-16 \
    --refinement-model us_anthropic_claude-haiku-4-5-20251001-v1_0 \
    --execution-agent claude-haiku-4-5-20251001 \
    --match-dataset-distribution

# Limit the total number of selected samples (applied after the per-task selection)
python sample_examples.py --domain multi_apps_test --match-dataset-distribution --max-samples 20

# Save results to a JSON file (defaults to sampling/ directory)
python sample_examples.py --domain multi_apps_test --match-dataset-distribution \
    --output-json samples.json

# Verbose mode with full example details
python sample_examples.py --domain multi_apps_test --match-dataset-distribution --verbose

# Quiet mode (only JSON output, useful for piping)
python sample_examples.py --domain multi_apps_test --match-dataset-distribution --quiet

# Set random seed for reproducibility
python sample_examples.py --domain multi_apps_test --match-dataset-distribution --seed 123

# Copy selected examples to the data directory for annotation
python sample_examples.py --domain multi_apps_test --match-dataset-distribution --copy-to-data

# Copy to a custom data directory
python sample_examples.py --domain multi_apps_test --match-dataset-distribution \
    --copy-to-data --data-dir /path/to/custom/data

# Include ALL successful examples (no sampling)
python sample_examples.py --domain multi_apps_test --include-all --copy-to-data

# Filter out specific task IDs based on a filter file (e.g., harm threshold filter)
python sample_examples.py --domain multi_apps_test --match-dataset-distribution \
    --filter-file ../data/data_filter_harm_threshold_0%.json

# Combine filtering with copying to a separate data directory for comparison
python sample_examples.py --domain multi_apps_test --match-dataset-distribution \
    --filter-file ../data/data_filter_harm_threshold_0%_final.json \
    --copy-to-data --data-dir ../data_filtered_stratified_final
"""

import argparse
import json
import os
import random
import shutil
import sys
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple


def load_filter_file(filter_path: Path, execution_agent: str) -> Set[str]:
    """
    Load a filter file and return the set of task IDs to exclude for the given agent.

    Filter file format:
    {
        "agent_<execution_agent>": ["task_id_1", "task_id_2", ...],
        ...
    }

    Any problem with the file (missing, unreadable, not valid JSON, or not
    shaped as described above) is reported as a warning on stderr and treated
    as "no exclusions". Warnings go to stderr so that stdout stays clean for
    the JSON summary the script prints.

    Args:
        filter_path: Path to the filter JSON file
        execution_agent: The execution agent name (without 'agent_' prefix)

    Returns:
        Set of task IDs to exclude, or empty set if the agent is not listed
        in the filter or the filter could not be used
    """
    try:
        with open(filter_path, "r", encoding="utf-8") as f:
            filter_config = json.load(f)
    except (OSError, ValueError) as e:
        # OSError covers missing files, directories and permission problems;
        # ValueError covers both JSONDecodeError and UnicodeDecodeError.
        print(f"Warning: Could not load filter file {filter_path}: {e}", file=sys.stderr)
        return set()

    if not isinstance(filter_config, dict):
        print(
            f"Warning: Filter file {filter_path} must contain a JSON object; ignoring it",
            file=sys.stderr,
        )
        return set()

    # The filter file uses keys like "agent_claude-haiku-4-5-20251001"
    agent_key = f"agent_{execution_agent}"
    task_ids = filter_config.get(agent_key)
    if task_ids is None:
        return set()

    # Only a JSON array is accepted: calling set() on a string would silently
    # produce a set of single characters, and on an object a set of its keys.
    if not isinstance(task_ids, list):
        print(
            f"Warning: Entry '{agent_key}' in {filter_path} is not a list; ignoring it",
            file=sys.stderr,
        )
        return set()

    # Task IDs are compared against directory names, so normalise to str.
    return {str(task_id) for task_id in task_ids}


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """
    Build the command-line parser and parse the arguments.

    Args:
        argv: Argument list to parse (without the program name). When None,
            argparse falls back to ``sys.argv[1:]``, which is what the script
            entry point relies on. Passing an explicit list lets the function
            be used programmatically.

    Returns:
        The populated argparse namespace. ``--domain`` is the only required
        option; argparse exits the process if it is missing or if an option
        value is invalid.
    """
    parser = argparse.ArgumentParser(
        description="Sample representative examples for annotation"
    )
    parser.add_argument(
        "--source-dir",
        type=str,
        default="/local/scratch/jones.6278/unintended_behaviors/OSWorld/perturbation_generation/perturbed_queries",
        help="Path to the perturbed_queries source directory",
    )
    parser.add_argument(
        "--domain",
        type=str,
        required=True,
        help="Domain to sample from (e.g., 'multi_apps_test', 'os')",
    )
    parser.add_argument(
        "--perturbation-model",
        type=str,
        default="o4-mini-2025-04-16",
        help="Perturbation model used",
    )
    parser.add_argument(
        "--refinement-model",
        type=str,
        default="us_anthropic_claude-haiku-4-5-20251001-v1_0",
        help="Refinement model used",
    )
    parser.add_argument(
        "--execution-agent",
        type=str,
        default="claude-haiku-4-5-20251001",
        help="Execution agent used",
    )
    parser.add_argument(
        "--target-distribution",
        type=str,
        default=None,
        help=(
            "Target severity distribution as JSON. "
            'Example: \'{"NONE": 0.0, "MINIMAL": 2.2, "LOW": 24.4, "MEDIUM": 54.8, "HIGH": 11.1, "CRITICAL": 7.4}\''
        ),
    )
    parser.add_argument(
        "--match-dataset-distribution",
        action="store_true",
        help=(
            "Set target distribution to match the computed distribution of successful seeds. "
            "Ignored if --target-distribution is provided."
        ),
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        default=None,
        help="Maximum number of samples to select (if not specified, one per task)",
    )
    parser.add_argument(
        "--include-all",
        action="store_true",
        help="Include ALL successful examples instead of sampling one per task",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducibility",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Print detailed information about the sampling process",
    )
    parser.add_argument(
        "--output-json",
        type=str,
        default=None,
        help="Path to save the sampling results as JSON. Relative paths are saved in the sampling/ directory.",
    )
    parser.add_argument(
        "--quiet",
        action="store_true",
        help="Suppress output except for the JSON summary",
    )
    parser.add_argument(
        "--copy-to-data",
        action="store_true",
        help="Copy selected examples to the data directory for annotation",
    )
    parser.add_argument(
        "--data-dir",
        type=str,
        default=None,
        help="Path to the annotation data directory. Defaults to ../data relative to the script.",
    )
    parser.add_argument(
        "--filter-file",
        type=str,
        default=None,
        help=(
            "Path to a JSON file specifying task IDs to exclude per agent. "
            "Format: {\"agent_<name>\": [\"task_id_1\", ...]}. "
            "Task IDs listed for the specified --execution-agent will be excluded from sampling."
        ),
    )
    # argv=None makes argparse read sys.argv[1:], preserving the CLI behaviour.
    return parser.parse_args(argv)


def load_refinement_summary(summary_path: Path) -> Optional[Dict]:
    """
    Load and parse a refinement_summary.json file.

    Args:
        summary_path: Path to the JSON file to read.

    Returns:
        The parsed JSON object, or None when the file cannot be read, is not
        valid JSON, or does not contain a JSON object at the top level. In
        each failure case a warning is written to stderr (not stdout, which
        carries the script's JSON summary).
    """
    try:
        with open(summary_path, "r", encoding="utf-8") as f:
            summary = json.load(f)
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable path (including directories);
        # ValueError: JSONDecodeError and UnicodeDecodeError.
        print(f"Warning: Could not load {summary_path}: {e}", file=sys.stderr)
        return None

    # Callers use dict methods on the result, so reject any other JSON type
    # here instead of letting them fail later with an AttributeError.
    if not isinstance(summary, dict):
        print(
            f"Warning: Could not load {summary_path}: top-level JSON value is not an object",
            file=sys.stderr,
        )
        return None

    return summary


def get_severity_from_summary(summary: Dict) -> Optional[str]:
    """
    Extract the final severity assessment from a refinement summary.

    The severity is read from the last entry of ``execution_history``, under
    ``trajectory_evaluation.severity_assessment``.

    Args:
        summary: Parsed refinement summary.

    Returns:
        The severity label, stripped and upper-cased, or None when the summary
        is not marked successful, has no execution history, or the severity is
        missing, empty, or not a string.
    """
    if not summary.get("success", False):
        return None

    execution_history = summary.get("execution_history")
    if not isinstance(execution_history, list) or not execution_history:
        return None

    # Get the last execution iteration
    last_execution = execution_history[-1]
    if not isinstance(last_execution, dict):
        return None

    # `or {}` also covers an explicit JSON null, which .get(key, {}) would
    # pass through as None.
    trajectory_eval = last_execution.get("trajectory_evaluation") or {}
    if not isinstance(trajectory_eval, dict):
        return None

    severity = trajectory_eval.get("severity_assessment")
    if not isinstance(severity, str):
        return None

    # Normalise so that labels compare equal to the canonical upper-case names.
    severity = severity.strip().upper()
    return severity or None


def collect_all_examples(
    source_dir: Path,
    domain: str,
    perturbation_model: str,
    refinement_model: str,
    execution_agent: str,
    excluded_task_ids: Optional[Set[str]] = None,
) -> Dict[str, List[Dict]]:
    """
    Collect all successful examples organized by task.

    Walks the layout
    ``<source_dir>/<domain>/<task_id>/<perturbation_model>/perturbed_query_<id>/``
    ``iterative_refinement_<refinement_model>/agent_<execution_agent>/refinement_summary.json``
    and keeps every summary that is marked successful and has a severity.

    Directories are visited in sorted order. ``Path.iterdir()`` yields entries
    in arbitrary order, and the downstream seeded sampling depends on the
    order of tasks and examples, so sorting is what makes a fixed seed
    reproducible across machines and runs.

    Args:
        source_dir: Path to the perturbed_queries source directory
        domain: Domain to sample from
        perturbation_model: Perturbation model used
        refinement_model: Refinement model used
        execution_agent: Execution agent used
        excluded_task_ids: Optional set of task IDs to exclude from collection

    Returns:
        A dict: {task_id: [list of example dicts]}. Tasks without any
        successful example are absent.

    Raises:
        ValueError: If the domain directory does not exist.
    """
    domain_dir = source_dir / domain
    if not domain_dir.exists():
        raise ValueError(f"Domain directory not found: {domain_dir}")

    if excluded_task_ids is None:
        excluded_task_ids = set()

    query_prefix = "perturbed_query_"
    examples_by_task = defaultdict(list)

    for task_dir in sorted(domain_dir.iterdir()):
        if not task_dir.is_dir():
            continue

        task_id = task_dir.name

        # Skip excluded task IDs
        if task_id in excluded_task_ids:
            continue

        model_dir = task_dir / perturbation_model
        if not model_dir.exists():
            continue

        for perturbed_dir in sorted(model_dir.iterdir()):
            if not perturbed_dir.is_dir() or not perturbed_dir.name.startswith(query_prefix):
                continue

            # Strip only the leading prefix; str.replace would also remove
            # later occurrences and change the ID used for the copy target.
            perturbed_id = perturbed_dir.name[len(query_prefix):]

            agent_dir = (
                perturbed_dir
                / f"iterative_refinement_{refinement_model}"
                / f"agent_{execution_agent}"
            )
            summary_path = agent_dir / "refinement_summary.json"
            # One existence check suffices: the file cannot exist unless both
            # parent directories do.
            if not summary_path.exists():
                continue

            summary = load_refinement_summary(summary_path)
            if summary is None:
                continue

            # Check if this is a successful elicitation
            if not summary.get("success", False):
                continue

            severity = get_severity_from_summary(summary)
            if severity is None:
                continue

            # Get additional info for display
            execution_history = summary.get("execution_history") or []
            last_execution = execution_history[-1] if execution_history else {}
            # `or {}` guards against an explicit JSON null for this key.
            trajectory_eval = last_execution.get("trajectory_evaluation") or {}

            examples_by_task[task_id].append(
                {
                    "task_id": task_id,
                    "perturbed_id": perturbed_id,
                    "severity": severity,
                    "source_path": str(agent_dir),
                    "perturbed_instruction": last_execution.get("perturbed_instruction", "N/A"),
                    "elicitation_strategy": last_execution.get("elicitation_strategy", "N/A"),
                    "behavior_match": trajectory_eval.get("matches_expected_behavior", "N/A"),
                    "final_score": summary.get("final_score", 0),
                    "harmful_actions": trajectory_eval.get("harmful_actions_observed", []),
                }
            )

    return dict(examples_by_task)


def calculate_severity_distribution(examples: List[Dict]) -> Dict[str, float]:
    """
    Compute the share of each severity level among ``examples`` as a percentage.

    The result always has the six keys NONE, MINIMAL, LOW, MEDIUM, HIGH and
    CRITICAL, in that order. Examples carrying any other severity label count
    towards the total but get no entry of their own. An empty input yields
    0.0 for every level.
    """
    levels = ("NONE", "MINIMAL", "LOW", "MEDIUM", "HIGH", "CRITICAL")
    n_examples = len(examples)

    # Tally how often each label occurs.
    tally = {}
    for item in examples:
        label = item["severity"]
        tally[label] = tally.get(label, 0) + 1

    if n_examples > 0:
        return {level: tally.get(level, 0) / n_examples * 100 for level in levels}
    return {level: 0.0 for level in levels}


def select_one_per_task(
    examples_by_task: Dict[str, List[Dict]],
    target_distribution: Optional[Dict[str, float]] = None,
    random_seed: int = 42,
) -> List[Dict]:
    """
    Pick at most one example for every task.

    The module-level random generator is seeded with ``random_seed`` first.
    With no target distribution, each non-empty task contributes one uniformly
    drawn example. With a target distribution (severity -> percentage), the
    percentages are turned into per-severity quotas over the number of tasks
    and the selection proceeds in two phases:

    1. Severities are visited in a fixed priority order; for each one, tasks
       that offer it and are still unassigned are shuffled and assigned until
       the quota is met.
    2. Every task still unassigned gets the example whose severity is
       currently furthest below its quota (first example wins ties).
    """
    random.seed(random_seed)

    # Nothing to balance: one uniform draw per non-empty task.
    if target_distribution is None:
        return [random.choice(pool) for pool in examples_by_task.values() if pool]

    # Translate percentages into absolute quotas over the task count.
    task_total = len(examples_by_task)
    quota = {}
    for level, percent in target_distribution.items():
        quota[level] = int(round(percent * task_total / 100))

    filled = defaultdict(int)

    # Index: severity -> tasks that have at least one example of it.
    tasks_offering = defaultdict(list)
    for task_id, pool in examples_by_task.items():
        for level in {candidate["severity"] for candidate in pool}:
            tasks_offering[level].append(task_id)

    picks = []
    handled = set()

    # Phase 1: work through the severities in priority order, filling quotas.
    for level in ("CRITICAL", "HIGH", "LOW", "MINIMAL", "MEDIUM", "NONE"):
        shortfall = quota.get(level, 0) - filled[level]
        open_tasks = [tid for tid in tasks_offering[level] if tid not in handled]
        # Shuffle unconditionally so the random stream advances the same way
        # whether or not anything is taken for this severity.
        random.shuffle(open_tasks)

        for task_id in open_tasks[:max(0, shortfall)]:
            of_level = [
                candidate
                for candidate in examples_by_task[task_id]
                if candidate["severity"] == level
            ]
            if not of_level:
                continue
            picks.append(random.choice(of_level))
            handled.add(task_id)
            filled[level] += 1

    def remaining_need(candidate: Dict) -> int:
        """How far the candidate's severity currently is below its quota."""
        level = candidate["severity"]
        return quota.get(level, 0) - filled[level]

    # Phase 2: give every leftover task its most under-represented example.
    for task_id, pool in examples_by_task.items():
        if task_id in handled or not pool:
            continue
        # max() keeps the first of several equally good candidates.
        best = max(pool, key=remaining_need)
        picks.append(best)
        handled.add(task_id)
        filled[best["severity"]] += 1

    return picks


def get_successful_iteration_number(summary: Dict) -> int:
    """
    Return the zero-based index of the last entry in ``execution_history``.

    The last recorded execution is taken to be the successful one. A summary
    with a missing or empty history yields 0.
    """
    history = summary.get("execution_history", [])
    return len(history) - 1 if history else 0


def create_final_successful_query(summary: Dict) -> Dict:
    """
    Build the payload for final_successful_query.json from a refinement summary.

    All per-run fields come from the last entry of ``execution_history``; the
    iteration and refinement counters come from the summary itself. Missing
    fields fall back to the defaults spelled out below. A summary without any
    execution history produces an empty dict.
    """
    history = summary.get("execution_history", [])
    if not history:
        return {}

    final_run = history[-1]
    evaluation = final_run.get("trajectory_evaluation", {})
    run_result = final_run.get("execution_result", {})

    query = {}
    query["id"] = run_result.get("perturbed_id", "unknown")
    query["perturbed_instruction"] = final_run.get("perturbed_instruction", "")
    query["elicitation_strategy"] = final_run.get("elicitation_strategy", "")
    query["behavior_elicitation_score"] = evaluation.get("behavior_elicitation_score", 0)
    query["execution_iterations_required"] = summary.get("execution_iterations", 1)
    query["quality_refinements_required"] = summary.get("total_quality_refinements", 0)
    query["recommendation"] = evaluation.get("recommendation", "COLLECT")
    query["timestamp"] = final_run.get("timestamp", "")
    return query


def copy_example_to_data(
    example: Dict,
    data_dir: Path,
    domain: str,
    perturbation_model: str,
    refinement_model: str,
    execution_agent: str,
    quiet: bool = False,
) -> bool:
    """
    Copy a selected example to the data directory with the correct structure.

    Source structure:
        perturbed_queries/<domain>/<task_id>/<perturbation_model>/perturbed_query_<id>/
            iterative_refinement_<refinement_model>/agent_<execution_agent>/
            ├── refinement_summary.json
            └── iteration_<N>/
                ├── traj.jsonl
                ├── trajectory_summary.md
                ├── trajectory_evaluation.json
                └── step_*.png

    Destination structure:
        data/<domain>/<perturbation_model>/<task_id>/perturbed_query_<id>/
            iterative_refinement_<refinement_model>/agent_<execution_agent>/
            ├── final_successful_query.json
            ├── refinement_summary.json
            └── trajectory/
                ├── traj.jsonl
                ├── trajectory_summary.md
                ├── trajectory_evaluation.json
                └── step_*.png

    An existing destination directory is treated as "already copied" and
    skipped. Because of that, a copy that fails part-way removes the
    destination directory it created before re-raising; otherwise the
    incomplete directory would be skipped on every later run.

    Args:
        example: Example dict with "source_path", "task_id" and "perturbed_id".
        data_dir: Root of the annotation data directory.
        domain: Domain name used in the destination path.
        perturbation_model: Perturbation model name used in the destination path.
        refinement_model: Refinement model name used in the destination path.
        execution_agent: Execution agent name used in the destination path.
        quiet: Suppress the per-example progress lines. Error messages are
            always written, to stderr.

    Returns:
        True if the example was copied; False if it was skipped because the
        destination already exists or the source is missing its summary or
        iteration directory.

    Raises:
        Any exception raised while writing the destination (for example
        OSError from shutil) propagates after the partial copy is removed.
    """
    source_path = Path(example["source_path"])
    task_id = example["task_id"]
    perturbed_id = example["perturbed_id"]

    # Build destination path (note the different hierarchy: domain/perturbation_model/task_id)
    dest_path = (
        data_dir / domain / perturbation_model / task_id /
        f"perturbed_query_{perturbed_id}" /
        f"iterative_refinement_{refinement_model}" /
        f"agent_{execution_agent}"
    )

    # Check if already exists
    if dest_path.exists():
        if not quiet:
            print(f"  Skipping {task_id[:8]}...:{perturbed_id} (already exists)")
        return False

    # Load refinement summary to get iteration info
    summary_path = source_path / "refinement_summary.json"
    if not summary_path.exists():
        print(f"  Error: refinement_summary.json not found at {summary_path}", file=sys.stderr)
        return False

    summary = load_refinement_summary(summary_path)
    if summary is None:
        return False

    # Determine the successful iteration
    iteration_num = get_successful_iteration_number(summary)
    iteration_dir = source_path / f"iteration_{iteration_num}"

    if not iteration_dir.exists():
        print(f"  Error: iteration_{iteration_num} not found at {iteration_dir}", file=sys.stderr)
        return False

    # Create destination directory
    dest_path.mkdir(parents=True, exist_ok=True)

    try:
        # Copy refinement_summary.json
        shutil.copy2(summary_path, dest_path / "refinement_summary.json")

        # Check if final_successful_query.json exists in source, if so copy it
        source_final_query = source_path / "final_successful_query.json"
        if source_final_query.exists():
            shutil.copy2(source_final_query, dest_path / "final_successful_query.json")
        else:
            # Create final_successful_query.json from summary
            final_query = create_final_successful_query(summary)
            with open(dest_path / "final_successful_query.json", "w", encoding="utf-8") as f:
                json.dump(final_query, f, indent=2)

        # Create trajectory directory and copy contents
        trajectory_dest = dest_path / "trajectory"
        trajectory_dest.mkdir(parents=True, exist_ok=True)

        # Files to copy from iteration folder to trajectory folder
        # (each one is optional in the source)
        for filename in ("traj.jsonl", "trajectory_summary.md", "trajectory_evaluation.json"):
            src_file = iteration_dir / filename
            if src_file.exists():
                shutil.copy2(src_file, trajectory_dest / filename)

        # Copy step_*.png files
        for png_file in iteration_dir.glob("step_*.png"):
            shutil.copy2(png_file, trajectory_dest / png_file.name)
    except BaseException:
        # dest_path did not exist when we started, so everything under it is
        # ours. Remove it (also on Ctrl-C) so a rerun retries this example
        # instead of skipping a half-copied directory, then re-raise.
        shutil.rmtree(dest_path, ignore_errors=True)
        raise

    if not quiet:
        print(f"  Copied {task_id[:8]}...:{perturbed_id} (iteration {iteration_num})")

    return True


def print_selected_examples(selected: List[Dict], verbose: bool = False):
    """
    Write a human-readable listing of the selected examples to stdout.

    Examples are ordered from most to least severe (unknown labels last) and
    then by task ID, with a heading line whenever the severity changes. Each
    entry shows the shortened task ID, perturbed ID, behavior match and final
    score; ``verbose`` adds truncated instruction and strategy text, the
    number of harmful actions (when there are any) and the source path.
    """
    banner = "=" * 80
    print("\n" + banner)
    print("SELECTED EXAMPLES")
    print(banner)

    rank = {"CRITICAL": 0, "HIGH": 1, "MEDIUM": 2, "LOW": 3, "MINIMAL": 4, "NONE": 5}

    def listing_order(item: Dict):
        # Severity first (unknown labels sort after NONE), then task ID.
        return (rank.get(item["severity"], 6), item["task_id"])

    heading = None
    for item in sorted(selected, key=listing_order):
        level = item["severity"]
        if level != heading:
            heading = level
            print(f"\n--- {heading} ---")

        label = item["task_id"][:8] + "..."
        print(
            f"  {label}:{item['perturbed_id']} "
            f"(match: {item['behavior_match']}, score: {item['final_score']})"
        )

        if not verbose:
            continue

        print(f"    Instruction: {item['perturbed_instruction'][:80]}...")
        print(f"    Strategy: {item['elicitation_strategy'][:50]}...")
        harmful = item["harmful_actions"]
        if harmful:
            print(f"    Harmful actions: {len(harmful)}")
        print(f"    Path: {item['source_path']}")


def main():
    """
    Run the sampling pipeline end to end and return a process exit code.

    Steps: parse the command line, optionally load a task-ID filter, collect
    all successful examples, choose a selection (everything with
    --include-all, otherwise one example per task, optionally steered towards
    a target severity distribution), optionally truncate it to --max-samples,
    optionally copy the selection into the annotation data directory, and
    emit a JSON summary (to --output-json and/or stdout).

    Error messages are written to stderr so that stdout carries only the
    report and the JSON summary, which matters for --quiet piping.

    Returns:
        0 on success; 1 for an invalid --target-distribution or
        --max-samples, a missing domain directory, or when no successful
        examples are found.
    """
    args = parse_args()
    quiet = args.quiet
    script_dir = Path(__file__).parent.resolve()
    severity_order = ["NONE", "MINIMAL", "LOW", "MEDIUM", "HIGH", "CRITICAL"]

    # A negative limit would otherwise surface as a ValueError traceback from
    # random.sample further down. Zero keeps its existing meaning: no limit.
    if args.max_samples is not None and args.max_samples < 0:
        print(
            f"Error: --max-samples must not be negative (got {args.max_samples})",
            file=sys.stderr,
        )
        return 1

    # Parse target distribution if provided
    target_distribution = None
    if args.target_distribution:
        try:
            parsed_target = json.loads(args.target_distribution)
        except json.JSONDecodeError as e:
            print(f"Error parsing target distribution: {e}", file=sys.stderr)
            return 1

        # bool is a subclass of int, so it has to be excluded explicitly.
        is_valid_target = isinstance(parsed_target, dict) and all(
            isinstance(pct, (int, float)) and not isinstance(pct, bool)
            for pct in parsed_target.values()
        )
        if not is_valid_target:
            print(
                "Error parsing target distribution: expected a JSON object "
                "mapping severity names to numeric percentages",
                file=sys.stderr,
            )
            return 1

        # Collected severities are upper-cased, so normalise the keys the
        # same way; otherwise e.g. "high" would silently never match.
        target_distribution = {name.upper(): pct for name, pct in parsed_target.items()}

    source_dir = Path(args.source_dir)

    # Load filter file if provided
    excluded_task_ids: Set[str] = set()
    filter_file_path: Optional[Path] = None
    if args.filter_file:
        filter_file_path = Path(args.filter_file)
        # If relative path, resolve relative to the script's directory
        if not filter_file_path.is_absolute():
            filter_file_path = script_dir / filter_file_path

        excluded_task_ids = load_filter_file(filter_file_path, args.execution_agent)
        if not quiet:
            print(f"\nFilter file loaded: {filter_file_path}")
            if excluded_task_ids:
                print(f"  Excluding {len(excluded_task_ids)} task IDs for agent '{args.execution_agent}'")
            else:
                print(f"  No exclusions found for agent '{args.execution_agent}'")

    if not quiet:
        print("\nConfiguration:")
        print(f"  Source Directory: {source_dir}")
        print(f"  Domain: {args.domain}")
        print(f"  Perturbation Model: {args.perturbation_model}")
        print(f"  Refinement Model: {args.refinement_model}")
        print(f"  Execution Agent: {args.execution_agent}")
        print(f"  Random Seed: {args.seed}")
        if args.filter_file:
            print(f"  Filter File: {filter_file_path}")
            print(f"  Excluded Task IDs: {len(excluded_task_ids)}")

        if target_distribution:
            print(f"  Target Distribution: {target_distribution}")

    # Collect all examples (excluding filtered task IDs)
    if not quiet:
        print("\nCollecting examples...")
        if excluded_task_ids:
            print(f"  (Excluding {len(excluded_task_ids)} filtered task IDs)")
    try:
        examples_by_task = collect_all_examples(
            source_dir,
            args.domain,
            args.perturbation_model,
            args.refinement_model,
            args.execution_agent,
            excluded_task_ids=excluded_task_ids,
        )
    except ValueError as e:
        # Raised when the domain directory does not exist; report it like the
        # other user-facing errors instead of dumping a traceback.
        print(f"Error: {e}", file=sys.stderr)
        return 1

    total_examples = sum(len(exs) for exs in examples_by_task.values())
    if not quiet:
        print(f"  Found {total_examples} successful examples across {len(examples_by_task)} tasks")

    if not examples_by_task:
        print("No examples found!", file=sys.stderr)
        return 1

    # Calculate original distribution
    all_examples = [ex for exs in examples_by_task.values() for ex in exs]
    original_distribution = calculate_severity_distribution(all_examples)

    if not quiet:
        print("\nOriginal Severity Distribution (all successful examples):")
        for severity, pct in original_distribution.items():
            count = sum(1 for ex in all_examples if ex["severity"] == severity)
            print(f"  {severity}: {pct:.1f}% ({count})")

    # If --match-dataset-distribution is set and no explicit target was provided,
    # use the computed distribution as the target
    if args.match_dataset_distribution and target_distribution is None:
        target_distribution = original_distribution.copy()
        if not quiet:
            print("\nUsing dataset distribution as target (--match-dataset-distribution)")

    # Select examples
    if args.include_all:
        # Include all successful examples
        selected = all_examples
        if not quiet:
            print("\nIncluding ALL successful examples (--include-all)")
    else:
        # Sample one per task
        selected = select_one_per_task(
            examples_by_task,
            target_distribution=target_distribution,
            random_seed=args.seed,
        )

    # Apply max_samples limit if specified (0 and None both mean "no limit")
    if args.max_samples and len(selected) > args.max_samples:
        random.seed(args.seed)
        selected = random.sample(selected, args.max_samples)

    # Calculate selected distribution
    selected_distribution = calculate_severity_distribution(selected)

    if not quiet:
        if args.include_all:
            print(f"\nIncluded {len(selected)} examples (all successful)")
        else:
            print(f"\nSelected {len(selected)} examples (one per task)")
        print("\nSelected Severity Distribution:")
        # Without a target, the selection is compared to the full dataset.
        reference_distribution = target_distribution if target_distribution else original_distribution
        for severity in severity_order:
            count = sum(1 for ex in selected if ex["severity"] == severity)
            pct = selected_distribution.get(severity, 0.0)
            target_pct = reference_distribution.get(severity, 0.0)
            diff = pct - target_pct
            diff_str = f"({diff:+.1f}%)" if abs(diff) > 0.1 else "(on target)"
            print(f"  {severity}: {pct:.1f}% ({count}) {diff_str}")

        # Compare to target
        if target_distribution:
            print("\nComparison to Target Distribution:")
            print(f"  {'Severity':<10} {'Target':>10} {'Actual':>10} {'Diff':>10}")
            print("  " + "-" * 40)
            for severity in severity_order:
                target = target_distribution.get(severity, 0.0)
                actual = selected_distribution.get(severity, 0.0)
                diff = actual - target
                print(f"  {severity:<10} {target:>9.1f}% {actual:>9.1f}% {diff:>+9.1f}%")

        # Print selected examples
        print_selected_examples(selected, verbose=args.verbose)

    output = {
        "configuration": {
            "domain": args.domain,
            "perturbation_model": args.perturbation_model,
            "refinement_model": args.refinement_model,
            "execution_agent": args.execution_agent,
            "random_seed": args.seed,
            "filter_file": str(filter_file_path) if filter_file_path else None,
            "excluded_task_ids_count": len(excluded_task_ids),
        },
        "summary": {
            "total_tasks": len(examples_by_task),
            "total_successful_examples": total_examples,
            "selected_count": len(selected),
            "excluded_by_filter": len(excluded_task_ids),
        },
        "distribution": {
            "original": original_distribution,
            "selected": selected_distribution,
            "target": target_distribution,
        },
        "selected_examples": [
            {
                "task_id": ex["task_id"],
                "perturbed_id": ex["perturbed_id"],
                "severity": ex["severity"],
                "source_path": ex["source_path"],
            }
            for ex in selected
        ],
    }

    # Copy to data directory if requested
    if args.copy_to_data:
        # Default to ../data relative to script
        data_dir = Path(args.data_dir) if args.data_dir else script_dir.parent / "data"

        if not quiet:
            print(f"\n{'=' * 80}")
            print("COPYING EXAMPLES TO DATA DIRECTORY")
            print(f"{'=' * 80}")
            print(f"Destination: {data_dir}")

        copied_count = 0
        skipped_count = 0
        error_count = 0

        for example in selected:
            # Deliberately broad: one failing example must not abort the rest
            # of the batch; failures are counted and reported instead.
            try:
                was_copied = copy_example_to_data(
                    example,
                    data_dir,
                    args.domain,
                    args.perturbation_model,
                    args.refinement_model,
                    args.execution_agent,
                    quiet=quiet,
                )
            except Exception as e:
                error_count += 1
                if not quiet:
                    print(f"  Error copying {example['task_id'][:8]}...:{example['perturbed_id']}: {e}")
            else:
                if was_copied:
                    copied_count += 1
                else:
                    skipped_count += 1

        if not quiet:
            print(f"\nCopy summary: {copied_count} copied, {skipped_count} skipped, {error_count} errors")

        # Add copy info to output
        output["copy_results"] = {
            "data_dir": str(data_dir),
            "copied": copied_count,
            "skipped": skipped_count,
            "errors": error_count,
        }

    # Save to file if requested
    if args.output_json:
        output_path = Path(args.output_json)
        # If relative path, resolve relative to the script's directory (sampling/)
        if not output_path.is_absolute():
            output_path = script_dir / output_path
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, "w") as f:
            json.dump(output, f, indent=2)
        if not quiet:
            print(f"\nSaved sampling results to: {output_path}")

    if not quiet:
        print("\n" + "=" * 80)
        print("JSON OUTPUT (for programmatic use)")
        print("=" * 80)
        print(json.dumps(output, indent=2))
    elif not args.output_json:
        # If quiet and no output file, still print JSON for piping
        print(json.dumps(output, indent=2))

    return 0


if __name__ == "__main__":
    # Use SystemExit directly: the bare `exit` name is an interactive helper
    # installed by the `site` module and is missing under `python -S`.
    raise SystemExit(main())

