#!/usr/bin/env python3
"""
Load "test_all" and "test" splits from multiple BugBench datasets, find, separately for
each split, the intersection of task_ids across all datasets, filter each dataset's split
to only include those common task_ids, and push the filtered datasets to custom repo_ids
(or {repo_id}_new by default).

For each dataset, also creates a "train" split containing all task_ids from the original
dataset that are NOT in test or test_all splits AND are also in bigcodebench train.

When pushing to HuggingFace, splits are renamed:
- test_all -> test
- test -> test_small
- train -> train (unchanged)

Set NEW_DATASET_NAMES to specify custom repo_ids for each dataset.
"""
import argparse
import json
from typing import Any, Dict, List, Optional, Set

from datasets import load_dataset, DatasetDict, concatenate_datasets

from rllm.data.dataset import DatasetRegistry

# Dataset names from sanity_check_bugbench_splits.py
# Each name is turned into a HuggingFace repo ID by _get_repo_id(); this list is
# also the default value of the --datasets CLI argument.
DATASET_NAMES = [
    "bugbench",
    "bugbench_human",
    "bugbench_qwen7b_sampled",
    "bugbench_gpt-oss-20b_sampled",
    "bugbench_adversarial",
]

# Splits loaded from every dataset; default value of the --splits CLI argument.
SPLITS = ["test_all", "test"]

# Optional: Map dataset names to custom new repo_ids when pushing to HuggingFace.
# If a dataset name is not in this dict, it will default to {original_repo_id}_new
NEW_DATASET_NAMES = {
    "bugbench": "anonymous/bugbench_new",
    "bugbench_human": "anonymous/bugbench_human_new",
    "bugbench_qwen7b_sampled": "anonymous/bugbench_qwen7b_sampled_new",
    "bugbench_gpt-oss-20b_sampled": "anonymous/bugbench_gpt-oss-20b_sampled_new",
    "bugbench_adversarial": "anonymous/bugbench_adversarial_new",
}


def _get_task_id(example: dict) -> str:
    """Return the normalized task identifier of ``example``.

    The ``task_id`` field is preferred; ``uid`` is used as a fallback when
    ``task_id`` is missing or does not hold a usable value.

    Args:
        example: A dataset row (mapping of column name -> value).

    Returns:
        The identifier converted to ``str`` with surrounding whitespace
        stripped, or ``""`` when neither field holds a usable value (missing,
        ``None``, blank, or the literal string ``"None"``).
    """
    for key in ("task_id", "uid"):
        value = example.get(key)
        # Compare against None explicitly: a falsy-but-valid id such as the
        # integer 0 must not be mistaken for a missing value.
        if value is None:
            continue
        text = str(value).strip()
        # "None" is what a stringified missing value looks like; treat it like
        # a blank id and try the next field.
        if text and text != "None":
            return text
    return ""


def _get_repo_id(dataset_name: str) -> str:
    """Convert a dataset name to its HuggingFace repo ID.

    The repo ID is the dataset name prefixed with the ``anonymous/`` namespace.
    No per-dataset special case is needed: the former ``adversarial_bugbench``
    branch produced exactly the same value as this general rule.

    Args:
        dataset_name: Bare dataset name, e.g. ``"bugbench_human"``.

    Returns:
        The repo ID, e.g. ``"anonymous/bugbench_human"``.
    """
    return f"anonymous/{dataset_name}"


def _collect_task_ids_from_registry_dataset(ds: Any) -> Set[str]:
    """Collect the normalized task_ids contained in a DatasetRegistry dataset.

    Both access paths funnel values through ``_get_task_id`` so the result only
    ever contains stripped, non-empty strings. This matters because the ids
    are later intersected with string ids extracted from the BugBench
    datasets; raw (e.g. integer or ``None``) values would never match.

    Args:
        ds: Dataset object. If it exposes a ``task_id`` attribute it is read
            column-wise via ``ds["task_id"]``; otherwise it is iterated and
            each element is treated as a row mapping.

    Returns:
        Set of usable task_id strings (possibly empty).

    Raises:
        RuntimeError: If reading the dataset fails; the original error is
            chained as ``__cause__``.
    """
    try:
        if hasattr(ds, "task_id"):
            # Column-style access: wrap each raw value in a one-key row so it
            # goes through the same normalization as full rows.
            normalized = (_get_task_id({"task_id": value}) for value in ds["task_id"])
        else:
            # Row-style access: each element is expected to be a mapping.
            normalized = (_get_task_id(row) for row in ds)
        # _get_task_id returns "" for unusable ids; drop those.
        return {task_id for task_id in normalized if task_id}
    except Exception as e:
        raise RuntimeError(f"Failed to collect task_ids from dataset: {e}") from e


def load_all_splits(dataset_names: List[str], splits: List[str]) -> Dict[str, Dict[str, Any]]:
    """
    Load all specified splits from all datasets.

    Loading is best-effort: a split that cannot be loaded is reported on
    stdout and recorded as ``None`` so the remaining splits and datasets are
    still processed.

    Args:
        dataset_names: Dataset names; each is mapped to a repo ID with
            ``_get_repo_id``.
        splits: Split names to load from every dataset.

    Returns:
        Dict mapping dataset_name -> Dict mapping split -> dataset
        (``None`` for a split that failed to load).
    """
    all_datasets = {}

    for dataset_name in dataset_names:
        repo_id = _get_repo_id(dataset_name)
        loaded_splits = {}

        print(f"\n[Loading] {repo_id}")
        for split in splits:
            try:
                ds = load_dataset(repo_id, split=split)
            except Exception as e:
                # Deliberately broad: any loading failure (missing repo,
                # missing split, network error) must not abort the other loads.
                print(f"  {split}: ERROR - {e}")
                loaded_splits[split] = None
            else:
                print(f"  {split}: n={len(ds)}")
                loaded_splits[split] = ds

        all_datasets[dataset_name] = loaded_splits

    return all_datasets


def extract_task_ids(all_datasets: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Set[str]]]:
    """
    Extract task_ids from all loaded datasets.

    Args:
        all_datasets: Dict mapping dataset_name -> Dict mapping split ->
            dataset, where a dataset may be ``None`` if it failed to load.

    Returns:
        Dict mapping dataset_name -> Dict mapping split -> set of task_ids.
        Splits whose dataset is ``None`` map to an empty set.
    """
    task_ids_by_dataset = {}

    for dataset_name, splits_dict in all_datasets.items():
        per_split = {}
        for split, ds in splits_dict.items():
            if ds is None:
                # Failed load: keep the key so callers see every split.
                per_split[split] = set()
                continue

            # _get_task_id returns "" for rows without a usable id; skip those.
            task_ids = {task_id for task_id in map(_get_task_id, ds) if task_id}

            per_split[split] = task_ids
            print(f"[{dataset_name}/{split}] task_ids: n={len(task_ids)}")

        task_ids_by_dataset[dataset_name] = per_split

    return task_ids_by_dataset


def find_intersection(task_ids_by_dataset: Dict[str, Dict[str, Set[str]]]) -> Dict[str, Set[str]]:
    """
    Compute, per split, the task_ids shared by every contributing dataset.

    Splits are handled independently: the 'test_all' result is built only from
    the 'test_all' sets, the 'test' result only from the 'test' sets. A dataset
    whose set for a split is empty does not take part in that split's
    intersection.

    Returns:
        Dict mapping split -> set of common task_ids for that split. Empty
        dict when no dataset contributed a non-empty set.
    """
    # split -> [(dataset_name, task_ids), ...] in first-seen order
    contributors_by_split = {}
    for name, split_to_ids in task_ids_by_dataset.items():
        for split, ids in split_to_ids.items():
            if not ids:
                continue  # empty sets do not participate
            contributors_by_split.setdefault(split, []).append((name, ids))
            print(f"[{name}/{split}] contributing {len(ids)} task_ids")

    if not contributors_by_split:
        print("[WARNING] No task_id sets found!")
        return {}

    result = {}
    for split, contributors in contributors_by_split.items():
        print(f"\n[Intersection/{split}] Finding intersection across {len(contributors)} datasets")

        # Every list holds at least one entry because it is created on append.
        (seed_name, running), *remaining = contributors
        print(f"  Starting with {seed_name} ({len(running)} task_ids)")

        for name, ids in remaining:
            size_before = len(running)
            running = running & ids
            print(f"  Intersecting with {name}: {size_before} -> {len(running)} task_ids")

        # Drop blank ids and the "None" placeholder from the final result.
        cleaned = set()
        for tid in running:
            if not tid or tid == "None" or not tid.strip():
                continue
            cleaned.add(tid)

        print(f"  [Intersection/{split}] Final common task_ids: n={len(cleaned)}")
        if cleaned:
            print(f"    Sample task_ids: {sorted(cleaned)[:5]}")

        result[split] = cleaned

    return result


def load_original_dataset_all_splits(dataset_name: str) -> Any:
    """
    Load the original dataset with all splits to get all task_ids.

    Strategy:
      1. Load the repo without a split. If that yields a ``DatasetDict``, all
         of its splits are concatenated into one dataset.
      2. If step 1 fails for any reason (including a failed concatenation),
         fall back to the single splits "test_all" and then "train".

    Args:
        dataset_name: Dataset name; mapped to a repo ID with ``_get_repo_id``.

    Returns:
        The combined (or single-split) dataset, or ``None`` when nothing could
        be loaded or the ``DatasetDict`` has no splits.
    """
    repo_id = _get_repo_id(dataset_name)

    try:
        loaded = load_dataset(repo_id)
        if isinstance(loaded, DatasetDict):
            parts = list(loaded.values())
            return concatenate_datasets(parts) if parts else None
        # Not a DatasetDict: already a single dataset.
        return loaded
    except Exception as e:
        # ``except Exception`` rather than a bare ``except`` so Ctrl-C and
        # SystemExit still propagate. Report instead of failing silently.
        print(f"  [WARN] Could not load all splits of {repo_id}: {e}; trying single splits")

    last_error = None
    # "test_all" first (most common, usually contains all examples), then "train".
    for split in ("test_all", "train"):
        try:
            return load_dataset(repo_id, split=split)
        except Exception as e:
            last_error = e

    print(f"  [WARN] Could not load original dataset for {dataset_name}: {last_error}")
    return None


def filter_and_push(
    all_datasets: Dict[str, Dict[str, Any]],
    intersection_by_split: Dict[str, Set[str]],
    push_to_hub: bool = True,
    new_repo_ids: Optional[Dict[str, str]] = None,
    bigcodebench_registry_name: str = "bigcodebench_new",
    bigcodebench_train_split: str = "train",
    bigcodebench_test_split: str = "test_all",
) -> None:
    """
    Filter each dataset's splits to only include common task_ids for that specific split,
    add a train split with remaining task_ids that are also in bigcodebench train, and push
    to the specified new repo_id.

    Each split uses its own intersection (e.g., test_all only uses test_all intersection).
    Train split contains all task_ids from the original dataset that are:
    - NOT in the test or test_all intersection
    - AND in bigcodebench train (skipped with a warning if bigcodebench train cannot be loaded)

    When pushing to HuggingFace, splits are renamed:
    - test_all -> test
    - test -> test_small
    - train -> train (unchanged)

    Args:
        all_datasets: Dict mapping dataset_name -> Dict mapping split -> dataset
            (a dataset may be None if it failed to load)
        intersection_by_split: Dict mapping split -> set of common task_ids
        push_to_hub: Whether to push to HuggingFace; when False only a dry-run
            summary is printed
        new_repo_ids: Optional dict mapping dataset_name -> new repo_id. If not provided
                     or if a dataset_name is missing, defaults to {original_repo_id}_new
        bigcodebench_registry_name: Registry name for bigcodebench dataset
        bigcodebench_train_split: Split name for bigcodebench train
        bigcodebench_test_split: Split name for bigcodebench test

    Raises:
        ValueError: If the filtered "test" task_ids of a dataset overlap with
            the bigcodebench train task_ids.
    """
    if new_repo_ids is None:
        new_repo_ids = {}

    # Load bigcodebench train and test to get their task_ids. Both loads are
    # best-effort: on failure the corresponding id set is None and the checks
    # that depend on it are skipped.
    print(f"\n[Loading BigCodeBench] {bigcodebench_registry_name}/{bigcodebench_train_split}...")
    try:
        bcb_train = DatasetRegistry.load_dataset(bigcodebench_registry_name, bigcodebench_train_split)
        bcb_train_task_ids = _collect_task_ids_from_registry_dataset(bcb_train)
        print(f"  [BigCodeBench train] {len(bcb_train_task_ids)} task_ids")
    except Exception as e:
        print(f"  [WARN] Failed to load BigCodeBench train: {e}")
        print("  [WARN] Train split will not be filtered by BigCodeBench train task_ids")
        bcb_train_task_ids = None

    # Report the split that is actually loaded (it is configurable and
    # defaults to "test_all", not "test").
    print(f"\n[Loading BigCodeBench] {bigcodebench_registry_name}/{bigcodebench_test_split}...")
    try:
        bcb_test = DatasetRegistry.load_dataset(bigcodebench_registry_name, bigcodebench_test_split)
        bcb_test_task_ids = _collect_task_ids_from_registry_dataset(bcb_test)
        print(f"  [BigCodeBench test] {len(bcb_test_task_ids)} task_ids")
    except Exception as e:
        print(f"  [WARN] Failed to load BigCodeBench test: {e}")
        print("  [WARN] Cannot validate test split intersection")
        bcb_test_task_ids = None

    for dataset_name, splits_dict in all_datasets.items():
        repo_id = _get_repo_id(dataset_name)
        # Use custom repo_id if provided, otherwise default to _new suffix
        new_repo_id = new_repo_ids.get(dataset_name, f"{repo_id}_new")

        print(f"\n[Processing] {dataset_name} -> {new_repo_id}")

        filtered_splits = {}

        # First, filter test and test_all splits
        test_task_ids = set()
        test_all_task_ids = set()

        for split, ds in splits_dict.items():
            if ds is None:
                print(f"  {split}: SKIPPED (failed to load)")
                continue

            # Get intersection for this specific split
            common_task_ids = intersection_by_split.get(split, set())
            if not common_task_ids:
                print(f"  {split}: SKIPPED (no intersection found for this split)")
                continue

            # Keep only rows whose id is in this split's intersection.
            # _get_task_id never returns "None" (it maps it to ""), so no
            # separate check for that placeholder is needed here.
            filtered_ds = ds.filter(
                lambda ex: _get_task_id(ex) in common_task_ids,
                num_proc=16,
            )

            print(f"  {split}: {len(ds)} -> {len(filtered_ds)} (filtered using {split}-only intersection)")
            filtered_splits[split] = filtered_ds

            # Track task_ids for train split computation, dropping blank ids
            # and the "None" placeholder in case the caller passed them in.
            if split in ("test", "test_all"):
                clean_ids = {tid for tid in common_task_ids if tid and tid != "None" and tid.strip()}
                if split == "test":
                    test_task_ids = clean_ids
                else:
                    test_all_task_ids = clean_ids

        # Validate test split: ensure test_task_ids don't intersect with bigcodebench train
        # and only match bigcodebench test
        if test_task_ids and (bcb_train_task_ids is not None or bcb_test_task_ids is not None):
            print("\n  [Validation] Checking test split task_ids...")

            # Check intersection with bigcodebench train (should be empty)
            if bcb_train_task_ids is not None:
                test_in_train = test_task_ids & bcb_train_task_ids
                if test_in_train:
                    print(f"  [ERROR] {len(test_in_train)} test task_ids intersect with BigCodeBench train!")
                    print(f"    Sample overlapping task_ids: {sorted(test_in_train)[:10]}")
                    raise ValueError(
                        f"BugBench test split contains {len(test_in_train)} task_ids that are in BigCodeBench train. "
                        f"This violates the requirement that test should only match BigCodeBench test."
                    )
                print("  [OK] Test task_ids do not intersect with BigCodeBench train (0 overlaps)")

            # Check intersection with bigcodebench test (should match); a
            # mismatch is only reported, not treated as fatal.
            if bcb_test_task_ids is not None:
                test_in_bcb_test = test_task_ids & bcb_test_task_ids
                test_not_in_bcb_test = test_task_ids - bcb_test_task_ids

                if test_not_in_bcb_test:
                    print(f"  [WARN] {len(test_not_in_bcb_test)} test task_ids are NOT in BigCodeBench test!")
                    print(f"    Sample non-matching task_ids: {sorted(test_not_in_bcb_test)[:10]}")
                    print(f"  [INFO] {len(test_in_bcb_test)} test task_ids ARE in BigCodeBench test")
                else:
                    print(f"  [OK] All {len(test_task_ids)} test task_ids are in BigCodeBench test")

        # Load original dataset to get all task_ids
        print("  [Loading original dataset] to compute train split...")
        original_ds = load_original_dataset_all_splits(dataset_name)

        if original_ds is None:
            print(f"  [WARN] Could not load original dataset for {dataset_name}, skipping train split")
        else:
            # Extract all usable task_ids from original dataset
            all_original_task_ids = {
                task_id for task_id in map(_get_task_id, original_ds) if task_id
            }

            print(f"  [Original dataset] Total task_ids: {len(all_original_task_ids)}")

            # Train split = (all original task_ids - test_all_task_ids - test_task_ids) ∩ bigcodebench_train_task_ids
            # (test_all typically contains test, so we subtract both to be safe)
            # NOTE(review): only ids in the cross-dataset intersection are
            # excluded here. Ids that were in this dataset's own test/test_all
            # but dropped from the intersection stay train candidates unless
            # the BigCodeBench train filter below removes them — confirm this
            # is intended.
            train_task_ids = all_original_task_ids - test_all_task_ids - test_task_ids

            # Intersect with bigcodebench train task_ids if available
            if bcb_train_task_ids is not None:
                before_intersect = len(train_task_ids)
                train_task_ids = train_task_ids & bcb_train_task_ids
                after_intersect = len(train_task_ids)
                print(f"  [Train split] After excluding test/test_all: {before_intersect} task_ids")
                print(f"  [Train split] After intersecting with BigCodeBench train: {after_intersect} task_ids")
            else:
                print(f"  [Train split] Task_ids (excluding test/test_all): {len(train_task_ids)}")

            if train_task_ids:
                print(f"  [Train split] Final task_ids: {len(train_task_ids)} (all: {len(all_original_task_ids)}, "
                      f"test_all: {len(test_all_task_ids)}, test: {len(test_task_ids)}, "
                      f"bcb_train: {len(bcb_train_task_ids) if bcb_train_task_ids is not None else 'N/A'})")

                # Filter original dataset to get train split.
                # NOTE(review): original_ds may be a concatenation of several
                # splits, so a row present in more than one split could appear
                # more than once in train — verify whether deduplication is needed.
                train_ds = original_ds.filter(
                    lambda ex: _get_task_id(ex) in train_task_ids,
                    num_proc=16,
                )
                print(f"  [Train split] Examples: {len(train_ds)}")
                filtered_splits["train"] = train_ds
                print("  [Train split] Added to filtered_splits - will be pushed to HF")
            else:
                print("  [Train split] No task_ids remaining for train split")

        if not filtered_splits:
            print(f"  [SKIP] No valid splits to push for {dataset_name}")
            continue

        # Rename splits for output: test_all -> test, test -> test_small
        split_rename_map = {"test_all": "test", "test": "test_small"}
        renamed_splits = {
            split_rename_map.get(old_name, old_name): ds
            for old_name, ds in filtered_splits.items()
        }

        # Print summary of all splits (using renamed names)
        print(f"\n  [Summary] {dataset_name} splits:")
        for split_name in ["train", "test_small", "test"]:
            if split_name in renamed_splits:
                print(f"    {split_name}: {len(renamed_splits[split_name])} examples")

        splits_list = sorted(renamed_splits)

        # Push to HuggingFace
        if push_to_hub:
            dd = DatasetDict(renamed_splits)
            print(f"  [Pushing] to {new_repo_id}")
            print(f"  [Pushing] Splits: {splits_list}")
            if "train" in splits_list:
                print(f"  [Pushing] Train split: {len(renamed_splits['train'])} examples")
            dd.push_to_hub(repo_id=new_repo_id)
            print(f"  [Done] Pushed {len(renamed_splits)} splits ({', '.join(splits_list)}) to {new_repo_id}")
            if "train" in splits_list:
                print(f"  [Done] Train split successfully pushed to {new_repo_id}")
        else:
            print(f"  [DRY RUN] Would push {len(renamed_splits)} splits ({', '.join(splits_list)}) to {new_repo_id}")


def main():
    """Command-line entry point.

    Pipeline:
      1. Load the requested splits of every requested dataset.
      2. Extract the task_ids of each loaded split.
      3. Intersect the task_ids across datasets, separately per split.
      4. Write each intersection to ``common_indices_{split}.json`` in the
         current working directory.
      5. Filter the datasets and push them (or dry-run with ``--no-push``).

    The function prints an error and returns early when no dataset could be
    loaded or when every intersection is empty.
    """
    parser = argparse.ArgumentParser(
        # Target repos come from NEW_DATASET_NAMES or default to the "_new"
        # suffix; the description must say so.
        description=(
            "Find common task_ids across BugBench datasets and push filtered versions "
            "to new repos (custom repo_ids, or {repo_id}_new by default)."
        )
    )
    parser.add_argument(
        "--datasets",
        nargs="+",
        default=DATASET_NAMES,
        help=f"Dataset names to process (default: {DATASET_NAMES})",
    )
    parser.add_argument(
        "--splits",
        nargs="+",
        default=SPLITS,
        help=f"Splits to load (default: {SPLITS})",
    )
    parser.add_argument(
        "--no-push",
        action="store_true",
        help="Don't push to HuggingFace (dry run)",
    )
    parser.add_argument(
        "--bigcodebench-registry-name",
        type=str,
        default="bigcodebench_new",
        help="Registry name for bigcodebench dataset (default: bigcodebench_new)",
    )
    parser.add_argument(
        "--bigcodebench-train-split",
        type=str,
        default="train",
        help="Split name for bigcodebench train (default: train)",
    )
    parser.add_argument(
        "--bigcodebench-test-split",
        type=str,
        default="test_all",
        help="Split name for bigcodebench test (default: test_all)",
    )

    args = parser.parse_args()

    print("=" * 80)
    print("Gathering Common Indices")
    print("=" * 80)
    print(f"Datasets: {args.datasets}")
    print(f"Splits: {args.splits}")
    print(f"Push to hub: {not args.no_push}")
    print("=" * 80)

    # Step 1: Load all splits
    print("\n[Step 1] Loading all splits...")
    all_datasets = load_all_splits(args.datasets, args.splits)

    # Validate that we have at least some valid datasets (a dataset counts as
    # valid when at least one of its splits loaded).
    valid_datasets = {
        name: splits
        for name, splits in all_datasets.items()
        if any(ds is not None for ds in splits.values())
    }

    if not valid_datasets:
        print("\n[ERROR] No valid datasets loaded! Cannot proceed.")
        return

    print(f"\n[Valid datasets] {len(valid_datasets)}/{len(args.datasets)} datasets loaded successfully")

    # Step 2: Extract task_ids
    print("\n[Step 2] Extracting task_ids...")
    task_ids_by_dataset = extract_task_ids(all_datasets)

    # Step 3: Find intersection (separately for each split)
    print("\n[Step 3] Finding intersection for each split...")
    intersection_by_split = find_intersection(task_ids_by_dataset)

    if not intersection_by_split:
        print("\n[ERROR] No intersection found for any split! Cannot proceed.")
        return

    # Check if we have at least one non-empty intersection
    if not any(intersection_by_split.values()):
        print("\n[ERROR] All intersections are empty! Cannot proceed.")
        return

    # Save common indices to JSON files (one file per split)
    print("\n[Saving] Common indices to JSON files...")
    for split, task_ids in intersection_by_split.items():
        task_ids_list = sorted(task_ids)  # sets are not JSON serializable; sort for stable output
        output_file = f"common_indices_{split}.json"
        # Explicit encoding so the file does not depend on the platform locale.
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(task_ids_list, f, indent=2)
        print(f"  [Saved] {split}: {len(task_ids_list)} task_ids -> {output_file}")

    # Step 4: Filter and push
    print("\n[Step 4] Filtering and pushing...")
    if NEW_DATASET_NAMES:
        print(f"[Using custom repo_ids] {NEW_DATASET_NAMES}")
    filter_and_push(
        all_datasets,
        intersection_by_split,
        push_to_hub=not args.no_push,
        new_repo_ids=NEW_DATASET_NAMES,
        bigcodebench_registry_name=args.bigcodebench_registry_name,
        bigcodebench_train_split=args.bigcodebench_train_split,
        bigcodebench_test_split=args.bigcodebench_test_split,
    )

    print("\n" + "=" * 80)
    print("Done!")
    print("=" * 80)


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
