"""Dataset construction for SWE-Smith domain.

Generates training and validation datasets for software engineering patch generation with advisor feedback.
Uses SWE-Smith harness for evaluation. Saves datasets as JSON files (JSONL format).

Example usage:
    uv run python -m advisor_models.swe_smith.construct_dataset \
        --output_dir data/swe_smith \
        --repo jd__tenacity \
        --train_size 100 \
        --val_size 49 \
        --suffix efficient
"""

from __future__ import annotations

import argparse
import json
import random
from pathlib import Path
from typing import List, Dict, Any
from tqdm import tqdm
from datasets import load_dataset

from .config import build_advisor_prompt

# SWE-Smith dataset on HuggingFace
SWE_SMITH_DATASET = "SWE-bench/SWE-smith"


def load_swe_smith_data(
    split: str = "train", repo_filter: str = None
) -> List[Dict[str, Any]]:
    """Load SWE-Smith data from HuggingFace.

    Args:
        split: Dataset split to load ('train' or 'test')
        repo_filter: Optional repo name to filter by (e.g., 'jd__tenacity')

    Returns:
        List of problem instances
    """
    print(f"Loading SWE-Smith {split} data...")

    try:
        dataset = load_dataset(SWE_SMITH_DATASET, split=split)
        instances = []

        for item in dataset:
            # Filter by repo if specified
            if repo_filter and repo_filter not in item.get("repo", ""):
                continue

            if item.get("problem_statement", "") == "":
                continue

            instances.append(dict(item))

        print(
            f"Loaded {len(instances)} instances"
            + (f" for repo {repo_filter}" if repo_filter else "")
        )
        return instances

    except Exception as e:
        print(f"Error loading SWE-Smith data: {e}")
        return []


def build_dataset_row(instance: Dict[str, Any]) -> Dict[str, Any]:
    """Build a dataset row for a single SWE-Smith instance.

    Args:
        instance: SWE-Smith problem instance

    Returns:
        Dataset row dict
    """
    # Build advisor prompt
    info_dict = {
        "problem_statement": instance.get("problem_statement", ""),
        "repo": instance.get("repo", ""),
    }
    advisor_prompt_text = build_advisor_prompt(info_dict)

    # Create the prompt structure for training
    prompt = [{"role": "user", "content": advisor_prompt_text}]

    # Serialize the entire instance as JSON to avoid schema issues with list fields
    # The instance will be deserialized in the environment
    # Convert to dict and explicitly handle any Arrow types from HuggingFace datasets
    instance_dict = {
        "instance_id": str(instance.get("instance_id", "")),
        "patch": str(instance.get("patch", "")),
        "repo": str(instance.get("repo", "")),
        "problem_statement": str(instance.get("problem_statement", "")),
        "image_name": str(instance.get("image_name", "")),
        "FAIL_TO_PASS": list(instance.get("FAIL_TO_PASS", [])),
        "PASS_TO_PASS": list(instance.get("PASS_TO_PASS", [])),
    }
    instance_json = json.dumps(instance_dict)

    return {
        "prompt": prompt,
        "env_class": "swe_smith",
        "reward_spec": {"ground_truth_json": instance_json},  # Store as JSON string
        "original_question": instance.get("problem_statement", ""),
    }


def process_instances(instances: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Process multiple instances into dataset rows.

    Args:
        instances: List of SWE-Smith instances

    Returns:
        List of dataset rows
    """
    print(f"Processing {len(instances)} instances...")

    rows = []
    for instance in tqdm(instances, desc="Building dataset"):
        try:
            row = build_dataset_row(instance)
            rows.append(row)
        except Exception as e:
            print(
                f"Error processing instance {instance.get('instance_id', 'unknown')}: {e}"
            )

    return rows


def main():
    parser = argparse.ArgumentParser(description="Construct SWE-Smith domain dataset")
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Output directory for dataset files",
    )
    parser.add_argument(
        "--repo",
        type=str,
        default=None,
        help="Filter by specific repo (e.g., 'jd__tenacity')",
    )
    parser.add_argument(
        "--train_size",
        type=int,
        default=None,
        help="Limit training set size",
    )
    parser.add_argument(
        "--val_size",
        type=int,
        default=None,
        help="Limit validation set size",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for sampling",
    )
    parser.add_argument(
        "--suffix",
        type=str,
        default="",
        help="Suffix to add to output filenames",
    )

    args = parser.parse_args()
    random.seed(args.seed)

    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load datasets
    train_instances = load_swe_smith_data(split="train", repo_filter=args.repo)

    # For validation, we can use a subset of train or load test split
    # For now, we'll split train data
    if len(train_instances) > 0:
        # Shuffle and split
        random.shuffle(train_instances)
        val_split_idx = int(len(train_instances) * 0.8)
        val_instances = train_instances[val_split_idx:]
        train_instances = train_instances[:val_split_idx]
    else:
        val_instances = []

    # Limit sizes if specified
    if args.train_size and len(train_instances) > args.train_size:
        train_instances = random.sample(train_instances, args.train_size)
    if args.val_size and len(val_instances) > args.val_size:
        val_instances = random.sample(val_instances, args.val_size)

    # Process instances
    train_rows = process_instances(train_instances)
    val_rows = process_instances(val_instances)

    # Save datasets as JSON (JSONL format - one JSON object per line)
    if train_rows:
        repo_suffix = f"_{args.repo}" if args.repo else ""
        suffix_str = f"_{args.suffix}" if args.suffix else ""
        filename = f"train{repo_suffix}{suffix_str}.json"
        with open(output_dir / filename, "w") as f:
            for row in train_rows:
                f.write(json.dumps(row) + "\n")
        print(f"Saved {len(train_rows)} training examples to {filename}")

    if val_rows:
        repo_suffix = f"_{args.repo}" if args.repo else ""
        suffix_str = f"_{args.suffix}" if args.suffix else ""
        filename = f"validation{repo_suffix}{suffix_str}.json"
        with open(output_dir / filename, "w") as f:
            for row in val_rows:
                f.write(json.dumps(row) + "\n")
        print(f"Saved {len(val_rows)} validation examples to {filename}")


if __name__ == "__main__":
    main()
