import argparse
import json
from typing import Optional, Set, Tuple, Union

from datasets import load_dataset

from rllm.data.dataset import DatasetRegistry
from rllm.data.utils import fetch_live_code_bench_system_prompt


# Values of the dataset's `subset` column that are kept by the subset filters
# below. Rows with any other non-null subset value are dropped; rows whose
# subset is None are let through by those filters.
ALLOWED_SUBSETS = {
    "Prefill",
    "Docs",
    "Package",
    "Filter",
    "Algorithm",
    "Data Structures",
}

# Default path to the failures JSON from eval_kodcode_reference_solutions.py.
# This is an absolute path to one specific timestamped file; it is the default
# for `filter_and_push_kodcode(failures_path=...)` and the `--failures-path`
# CLI flag, so override it when that file is not present.
DEFAULT_FAILURES_PATH = "/data/user/rllm/logs/kodcode_eval/kodcode_failures_20260111_120434.json"


def _ensure_fenced_python(code: str) -> str:
    """
    Return `code` formatted as a markdown python code block.

    A falsy input is treated as the empty string, and leading/trailing
    newlines (only newlines, not other whitespace) are trimmed. Text that
    already contains a triple-backtick fence anywhere is returned trimmed
    but otherwise untouched, so it is never fenced twice.
    """
    body = ("" if not code else code).strip("\n")
    already_fenced = body.find("```") != -1
    if already_fenced:
        return body
    return "```python\n" + body + "\n```"


def _normalize_tests_to_string(tests_raw) -> str:
    """
    Normalize tests to a string format expected by kodcode_check_correctness.
    KodCode uses pytest-style tests as strings.
    """
    if tests_raw is None:
        return ""
    if isinstance(tests_raw, list):
        return "\n\n".join(str(t) for t in tests_raw)
    return str(tests_raw)


def prepare_kodcode_data(
    *,
    repo_id: str = "KodCode/KodCode-V1",
    registry_name: Optional[str] = "kodcode",
    split_half: bool = False,
    split_half_shuffle: bool = True,
    split_half_seed: int = 0,
) -> Union[object, Tuple[object, object]]:
    """
    Preprocess a KodCode dataset from HuggingFace into rLLM's expected format
    and register it via `DatasetRegistry`.

    Only the `train` split of `repo_id` is loaded. If the dataset has a
    `subset` column, rows whose subset is set but not in `ALLOWED_SUBSETS`
    are dropped. Rows whose subset is None are kept and labelled "unknown".

    Output keys (per example):
        - question: first non-empty of question/problem/prompt/instruction;
          when starter code is present it is passed through
          `fetch_live_code_bench_system_prompt` together with that code
        - reference_solution: first non-empty of
          solution/code/canonical_solution/reference_solution, fenced as a
          markdown python block
        - ground_truth: first non-empty of test/tests/test_code/unit_tests,
          normalized to a single string
        - data_source: always "kodcode"
        - uid: task_id or id from the row, else "kodcode_<index>"
        - index: position of the row AFTER subset filtering
        - starter_code: starter_code or code_prompt from the row, else ""
        - subset: the row's subset, or "unknown" when missing
        - metadata: {"subset": <subset>}

    Args:
        repo_id: HuggingFace dataset id to load.
        registry_name: Name to register under. None means "use the last
            path segment of `repo_id`", falling back to "kodcode".
        split_half: If True, register two halves instead of one dataset.
        split_half_shuffle: Shuffle before splitting (only if split_half).
        split_half_seed: Seed for that shuffle.

    Returns:
        With split_half=False, the result of registering the processed
        dataset under split name "train". With split_half=True, a tuple
        (train_half, test_half) registered as "train0.5" and "test0.5";
        the first n // 2 rows go to train and the remainder to test, so
        the test half holds the extra row when n is odd.

    Raises:
        ValueError: If `repo_id` is empty.
    """
    if not repo_id:
        raise ValueError("repo_id must be a non-empty HuggingFace dataset id (e.g., 'KodCode/KodCode-V1').")

    if registry_name is None:
        registry_name = str(repo_id).split("/")[-1] or "kodcode"

    dataset = load_dataset(str(repo_id), split="train")

    def filter_by_subset(example):
        subset_val = example.get("subset", None)
        # Rows without a subset label are deliberately kept.
        if subset_val is None:
            return True
        return subset_val in ALLOWED_SUBSETS

    if "subset" in dataset.column_names:
        dataset = dataset.filter(filter_by_subset, num_proc=8)

    def first_truthy(example, keys, default=""):
        # Return the first truthy value among `keys`, mirroring an `or` chain.
        for key in keys:
            value = example.get(key, None)
            if value:
                return value
        return default

    def preprocess_fn(example, idx):
        # Column names vary between dataset variants, so probe several aliases.
        question_raw = first_truthy(example, ("question", "problem", "prompt", "instruction"))
        solution_raw = first_truthy(example, ("solution", "code", "canonical_solution", "reference_solution"))
        tests_raw = first_truthy(example, ("test", "tests", "test_code", "unit_tests"))
        starter_code = first_truthy(example, ("starter_code", "code_prompt"))

        # --------
        # Normalize to rLLM format
        # --------
        tests = _normalize_tests_to_string(tests_raw)

        if starter_code:
            question = fetch_live_code_bench_system_prompt(question_raw, starter_code)
        else:
            question = question_raw

        reference_solution = _ensure_fenced_python(solution_raw)

        # `or` (not a .get default) so that a present-but-None subset, which
        # the filter above lets through, is also labelled "unknown".
        subset = example.get("subset", None) or "unknown"
        task_id = first_truthy(example, ("task_id", "id"), default=f"kodcode_{idx}")

        return {
            "question": question,
            "reference_solution": reference_solution,
            "ground_truth": tests,
            "data_source": "kodcode",
            "uid": str(task_id),
            "index": int(idx),
            "starter_code": starter_code,
            "subset": subset,
            "metadata": {"subset": subset},
        }

    processed = dataset.map(
        preprocess_fn,
        with_indices=True,
        writer_batch_size=10,
        num_proc=16,
        remove_columns=dataset.column_names,
    )

    if not split_half:
        return DatasetRegistry.register_dataset(str(registry_name), processed, "train")

    n = len(processed)
    if n == 0:
        # Nothing to shuffle or slice: register the empty dataset under both names.
        train_ds = test_ds = processed
    else:
        if split_half_shuffle:
            processed = processed.shuffle(seed=int(split_half_seed))
        mid = n // 2
        train_ds = processed.select(range(0, mid))
        test_ds = processed.select(range(mid, n))

    train_half = DatasetRegistry.register_dataset(str(registry_name), train_ds, "train0.5")
    test_half = DatasetRegistry.register_dataset(str(registry_name), test_ds, "test0.5")
    return train_half, test_half


def load_failed_task_ids(failures_path: str) -> Set[int]:
    """
    Load failed task IDs from a JSON file and return their numeric indices.

    The JSON document is iterated entry by entry. Two entry forms are
    accepted:
        - strings such as "kodcode_14", which contribute the integer after
          the "kodcode_" prefix (14);
        - plain integers, which are used as-is.

    Every other entry (unparseable suffix, other strings, booleans, null,
    nested values) is reported with a printed warning and skipped, so a
    malformed failures file is visible instead of silently shrinking the
    result.

    Args:
        failures_path: Path to the JSON file, read as UTF-8.

    Returns:
        The set of integer indices found in the file.
    """
    prefix = "kodcode_"

    with open(failures_path, "r", encoding="utf-8") as f:
        failed_ids = json.load(f)

    failed_indices: Set[int] = set()
    for task_id in failed_ids:
        if isinstance(task_id, str) and task_id.startswith(prefix):
            # Strip the prefix only at the start; str.replace would also
            # remove later occurrences inside the ID.
            try:
                idx = int(task_id[len(prefix):])
            except ValueError:
                print(f"Warning: Could not parse task ID: {task_id}")
            else:
                failed_indices.add(idx)
        elif isinstance(task_id, int) and not isinstance(task_id, bool):
            # bool is a subclass of int; JSON true/false is not a task index.
            failed_indices.add(task_id)
        else:
            print(f"Warning: Skipping unrecognized task ID: {task_id!r}")

    return failed_indices


def filter_and_push_kodcode(
    *,
    source_repo_id: str = "KodCode/KodCode-V1",
    target_repo_id: str = "anonymous/KodCode-V1-filtered",
    failures_path: str = DEFAULT_FAILURES_PATH,
    dry_run: bool = False,
) -> None:
    """
    Filter out failed task IDs from a KodCode dataset and push the result to HuggingFace.

    Steps:
        1. Load the failed indices from `failures_path`.
        2. Load the `train` split of `source_repo_id`.
        3. If a `subset` column exists, drop rows whose subset is set but
           not in `ALLOWED_SUBSETS` (rows with a None subset are kept).
        4. Drop rows whose position in the subset-filtered dataset is one
           of the failed indices.
        5. Push the remaining rows to `target_repo_id` as a public dataset,
           unless `dry_run` is set.

    No columns are added or removed, so the pushed dataset keeps the source
    dataset's columns.

    NOTE(review): failed indices are matched against row positions AFTER
    subset filtering. This assumes the failures file was produced from a
    dataset with that same filtering and ordering - confirm against the
    evaluation script.

    Args:
        source_repo_id: Source HuggingFace dataset repo (default: KodCode/KodCode-V1)
        target_repo_id: Target HuggingFace repo to push to (default: anonymous/KodCode-V1-filtered)
        failures_path: Path to JSON file containing failed task IDs
        dry_run: If True, don't actually push to HuggingFace
    """
    print(f"Loading failed task IDs from: {failures_path}")
    failed_indices = load_failed_task_ids(failures_path)
    print(f"  Found {len(failed_indices)} failed task IDs to filter out")

    print(f"\nLoading source dataset: {source_repo_id}")
    dataset = load_dataset(source_repo_id, split="train")
    original_count = len(dataset)
    print(f"  Loaded {original_count} examples")

    # Filter by allowed subsets first (same as prepare_kodcode_data)
    def filter_by_subset(example):
        subset_val = example.get("subset", None)
        if subset_val is None:
            return True
        return subset_val in ALLOWED_SUBSETS

    if "subset" in dataset.column_names:
        dataset = dataset.filter(filter_by_subset, num_proc=8)
        print(f"  After subset filtering: {len(dataset)} examples")

    # Baseline for the "removed" statistic: only rows dropped by the
    # failed-index filter count as failed reference solutions.
    subset_count = len(dataset)

    # Filter out failed tasks by index
    def filter_failed_tasks(example, idx):
        return idx not in failed_indices

    filtered_dataset = dataset.filter(
        filter_failed_tasks,
        with_indices=True,
        num_proc=8,
    )

    filtered_count = len(filtered_dataset)
    removed_count = subset_count - filtered_count

    print("\nFiltering results:")
    print(f"  Original: {original_count} examples")
    print(f"  In allowed subsets: {subset_count} examples")
    print(f"  Filtered: {filtered_count} examples")
    print(f"  Removed:  {removed_count} examples (failed reference solutions)")

    if dry_run:
        print(f"\n[DRY RUN] Would push to: {target_repo_id}")
        print("Skipping actual push.")
    else:
        print(f"\nPushing to HuggingFace: {target_repo_id}")
        filtered_dataset.push_to_hub(target_repo_id, private=False)
        print(f"  ✅ Successfully pushed {filtered_count} examples to {target_repo_id}")

    # Print sample to verify format
    print("\nSample example (first row):")
    if filtered_count > 0:
        sample = filtered_dataset[0]
        for key in list(sample.keys())[:8]:  # Show first 8 columns
            text = str(sample[key])
            # Truncate long values so the preview stays readable.
            val_str = text[:100] + "..." if len(text) > 100 else text
            print(f"  {key}: {val_str}")


if __name__ == "__main__":
    # Command-line entry point with two sub-commands:
    #   prepare          -> prepare_kodcode_data (register into DatasetRegistry)
    #   filter-and-push  -> filter_and_push_kodcode (drop failed tasks, push to hub)
    # Running without a sub-command prints the help text.
    cli = argparse.ArgumentParser(
        description="Preprocess and register the KodCode V1 dataset (filtered by subset) into RLLM's DatasetRegistry, "
                    "or filter out failed tasks and push to HuggingFace."
    )
    commands = cli.add_subparsers(dest="command", help="Command to run")

    # --- `prepare` sub-command -------------------------------------------
    prepare_cmd = commands.add_parser(
        "prepare",
        help="Preprocess and register KodCode dataset into RLLM's DatasetRegistry",
    )
    prepare_cmd.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="HuggingFace dataset repo id to load (e.g., 'KodCode/KodCode-V1').",
    )
    prepare_cmd.add_argument(
        "--registry-name",
        type=str,
        required=True,
        help="Name to register under in DatasetRegistry (e.g., 'kodcode').",
    )
    prepare_cmd.add_argument(
        "--split-half",
        action="store_true",
        help="If set, shuffle (optional) and split evenly into train0.5/test0.5 registrations.",
    )
    prepare_cmd.add_argument(
        "--no-split-half-shuffle",
        action="store_true",
        help="Disable shuffle before split-half (enabled by default).",
    )
    prepare_cmd.add_argument(
        "--split-half-seed",
        type=int,
        default=0,
        help="Seed used for shuffle when split-half is enabled.",
    )

    # --- `filter-and-push` sub-command -----------------------------------
    push_cmd = commands.add_parser(
        "filter-and-push",
        help="Filter out failed task IDs and push to HuggingFace",
    )
    push_cmd.add_argument(
        "--source-repo-id",
        type=str,
        default="KodCode/KodCode-V1",
        help="Source HuggingFace dataset repo (default: KodCode/KodCode-V1)",
    )
    push_cmd.add_argument(
        "--target-repo-id",
        type=str,
        default="anonymous/KodCode-V1-filtered",
        help="Target HuggingFace repo to push to (default: anonymous/KodCode-V1-filtered)",
    )
    push_cmd.add_argument(
        "--failures-path",
        type=str,
        default=DEFAULT_FAILURES_PATH,
        help=f"Path to JSON file containing failed task IDs (default: {DEFAULT_FAILURES_PATH})",
    )
    push_cmd.add_argument(
        "--dry-run",
        action="store_true",
        help="Don't actually push to HuggingFace, just show what would be done",
    )

    cli_args = cli.parse_args()

    if cli_args.command == "filter-and-push":
        filter_and_push_kodcode(
            source_repo_id=cli_args.source_repo_id,
            target_repo_id=cli_args.target_repo_id,
            failures_path=cli_args.failures_path,
            dry_run=cli_args.dry_run,
        )

    elif cli_args.command == "prepare":
        result = prepare_kodcode_data(
            repo_id=cli_args.repo_id,
            registry_name=cli_args.registry_name,
            split_half=bool(cli_args.split_half),
            split_half_shuffle=not bool(cli_args.no_split_half_shuffle),
            split_half_seed=int(cli_args.split_half_seed),
        )

        # A tuple means split-half mode returned (train_half, test_half).
        if isinstance(result, tuple):
            train_half, test_half = result
            print("Train0.5 dataset path:", train_half.get_data_path())
            print(f"Train0.5 dataset length: {len(train_half)}")
            print("Test0.5 dataset path:", test_half.get_data_path())
            print(f"Test0.5 dataset length: {len(test_half)}")

            # Preview from the test half when it has rows, else the train half.
            if len(test_half) > 0:
                preview_source = test_half
            else:
                preview_source = train_half
            if len(preview_source) > 0:
                print("\nSample example:")
                row = preview_source[0]
                for field in ["uid", "data_source", "subset", "question", "reference_solution", "ground_truth"]:
                    print(f"\n=== {field} ===\n{row.get(field)}")
        else:
            train_dataset = result
            print("Train dataset path:", train_dataset.get_data_path())
            print(f"Train dataset length: {len(train_dataset)}")

            if len(train_dataset) > 0:
                print("\nSample train example:")
                row = train_dataset[0]
                for field in ["uid", "data_source", "subset", "question", "reference_solution", "ground_truth"]:
                    print(f"\n=== {field} ===\n{row.get(field)}")

    else:
        cli.print_help()