#!/usr/bin/env python3
"""
Join HF datasets against `anonymous/bigcodebench` by `task_id` and add:
  - instruct_prompt
  - complete_prompt

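This is useful for datasets that only store `task_id` (plus other columns), but are missing
the prompt fields needed by downstream codegen / bug-fixing workflows.

Note: the ID column read from the source dataset is chosen with `--source_id_col`
(default: `uid`) and is matched against `--ref_id_col` (default: `task_id`) in the
reference dataset. Pass `--source_id_col task_id` if the source stores the ID under
`task_id`.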

Example:
  python3 -m examples.bugs.hf_dataset_formatting.add_prompts_from_bigcodebench \
    --source_repo anonymous/bugbench_qwen7b_sampled \
    --repo_id anonymous/bugbench_qwen7b_sampled_with_prompts \
    --splits test

Auth:
- If you are not logged in via `huggingface-cli login`, pass `--token` or set
  `HF_TOKEN` / `HUGGINGFACEHUB_API_TOKEN`.
"""

from __future__ import annotations

import argparse
import os
from typing import Any, Dict, Iterable, Optional, Tuple

from datasets import Dataset, DatasetDict, load_dataset
from huggingface_hub import HfApi


def _resolve_token(explicit: Optional[str]) -> Optional[str]:
    """
    Pick the Hugging Face token to use.

    Precedence: the explicitly supplied value, then the `HF_TOKEN` environment
    variable, then `HUGGINGFACEHUB_API_TOKEN`. Empty strings count as "not set".
    Returns None when no candidate is non-empty.
    """
    candidates = (
        explicit,
        os.environ.get("HF_TOKEN"),
        os.environ.get("HUGGINGFACEHUB_API_TOKEN"),
    )
    # First truthy candidate wins; None / "" are skipped.
    return next((value for value in candidates if value), None)


def _build_prompt_index(
    ref: Dataset,
    *,
    ref_id_col: str = "task_id",
    instruct_col: str = "instruct_prompt",
    complete_col: str = "complete_prompt",
) -> Dict[str, Tuple[str, str]]:
    """
    Index the reference rows by ID.

    Iterates `ref` (each row is read with `.get`) and returns a dict mapping
    `str(row[ref_id_col])` to `(instruct_prompt, complete_prompt)`.

    - Rows whose ID is None (or absent) are skipped.
    - Missing / falsy prompt values become the empty string; everything else
      is converted with `str`.
    - When an ID appears more than once, the first row wins.
    """
    index: Dict[str, Tuple[str, str]] = {}
    for row in ref:
        raw_id = row.get(ref_id_col)
        if raw_id is not None:
            key = str(raw_id)
            instruct_value = row.get(instruct_col, "") or ""
            complete_value = row.get(complete_col, "") or ""
            if key in index:
                # Duplicate ID: the earlier entry is kept untouched.
                continue
            index[key] = (str(instruct_value), str(complete_value))
    return index


def _add_columns_from_index(
    ds: Dataset,
    *,
    prompt_index: Dict[str, Tuple[str, str]],
    source_id_col: str = "task_id",
    instruct_col: str = "instruct_prompt",
    complete_col: str = "complete_prompt",
    overwrite: bool = True,
) -> Dataset:
    """
    Add (or refresh) the two prompt columns on `ds` by looking up each row's ID.

    Args:
        ds: Dataset to augment; must support `.map(fn, batched=True, desc=...)`.
        prompt_index: Mapping of `str(id)` -> `(instruct_prompt, complete_prompt)`.
        source_id_col: Column in `ds` holding the ID; values are compared via `str`.
        instruct_col: Name of the instruct-prompt column to write.
        complete_col: Name of the complete-prompt column to write.
        overwrite: When True, values always come from `prompt_index` (empty string
            for unknown IDs). When False, an existing value that is not None and
            not blank is kept; otherwise the indexed value is used.

    Returns:
        The dataset returned by `ds.map`, with both prompt columns present.

    Also prints how many rows had an ID that is absent from `prompt_index`.
    """

    def _existing_text(column: Any, pos: int) -> Optional[str]:
        """Return the existing non-blank cell at `pos` as text, else None."""
        if not isinstance(column, list) or pos >= len(column):
            return None
        cell = column[pos]
        # A null cell must not be kept: str(None) would yield the literal "None".
        if cell is None:
            return None
        text = str(cell)
        return text if text.strip() else None

    def fn(batch: Dict[str, Any]) -> Dict[str, Any]:
        ids: Iterable[Any] = batch.get(source_id_col, [])
        out_instruct = []
        out_complete = []

        # Existing columns only matter when we are asked to preserve them.
        existing_instruct = None if overwrite else batch.get(instruct_col, None)
        existing_complete = None if overwrite else batch.get(complete_col, None)

        for i, _id in enumerate(ids):
            instr, comp = prompt_index.get(str(_id), ("", ""))

            if not overwrite:
                kept_instruct = _existing_text(existing_instruct, i)
                if kept_instruct is not None:
                    instr = kept_instruct
                kept_complete = _existing_text(existing_complete, i)
                if kept_complete is not None:
                    comp = kept_complete

            out_instruct.append(instr)
            out_complete.append(comp)

        return {instruct_col: out_instruct, complete_col: out_complete}

    ds2 = ds.map(fn, batched=True, desc="Adding prompts from anonymous/bigcodebench")
    total = len(ds2)

    # Count misses from the ID column itself rather than from a counter mutated
    # inside `fn`: `map` may return a cached result without ever calling `fn`,
    # which would leave such a counter at zero.
    id_values = ds[source_id_col] if source_id_col in ds.column_names else []
    missing = sum(1 for _id in id_values if str(_id) not in prompt_index)

    print(f"[add_prompts] missing ids: {missing}/{total} ({(100.0*missing/total) if total else 0.0:.2f}%)")
    return ds2


def main() -> None:
    """
    CLI entry point: load the reference and source datasets, join prompts, push.

    Steps:
      1. Parse arguments and resolve a Hugging Face token.
      2. Load the reference split and build an ID -> prompts index.
      3. Load the source repo, validate the requested splits and ID column.
      4. Add the prompt columns to every requested split.
      5. Optionally create the destination repo, then push the result.

    Raises:
        ValueError: if no token can be found, the reference dataset lacks a
            required column, a requested split does not exist, or a split
            lacks the source ID column.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "--source_repo",
        required=True,
        help="HF repo to augment (must have the --source_id_col column)",
    )
    ap.add_argument("--repo_id", required=True, help="Output HF repo_id to push to")
    ap.add_argument("--splits", nargs="*", default=None, help="Splits to process (default: all splits in source)")

    ap.add_argument("--ref_repo", default="anonymous/bigcodebench", help="Reference repo (default: anonymous/bigcodebench)")
    ap.add_argument("--ref_split", default="v0.1.0_hf", help="Reference split to index (default: v0.1.0_hf)")

    ap.add_argument("--source_id_col", default="uid", help="ID column in source dataset (default: uid)")
    ap.add_argument("--ref_id_col", default="task_id", help="ID column in reference dataset (usually task_id)")
    ap.add_argument("--instruct_col", default="instruct_prompt")
    ap.add_argument("--complete_col", default="complete_prompt")
    ap.add_argument("--overwrite", action="store_true", help="Overwrite existing prompt columns if present (default: False)")
    ap.add_argument("--token", default=None, help="HF token (or set HF_TOKEN / HUGGINGFACEHUB_API_TOKEN)")
    ap.add_argument("--private", action="store_true", help="Push dataset as private")
    ap.add_argument(
        "--create_repo",
        action="store_true",
        help="Create the destination dataset repo_id on the hub if it doesn't exist",
    )

    args = ap.parse_args()

    token = _resolve_token(args.token)
    if not token:
        # The error message below offers `huggingface-cli login` as an option,
        # so honour a locally cached login before giving up.
        try:
            from huggingface_hub import get_token
        except ImportError:
            # Older huggingface_hub without `get_token`: keep the previous behaviour.
            pass
        else:
            token = get_token()
    if not token:
        raise ValueError(
            "No Hugging Face token found. Please run `huggingface-cli login` OR pass --token OR set HF_TOKEN / HUGGINGFACEHUB_API_TOKEN."
        )

    print(f"[add_prompts] loading reference: {args.ref_repo} split={args.ref_split}")
    ref = load_dataset(args.ref_repo, split=str(args.ref_split))
    # Reduce memory: keep only needed columns if present
    keep_cols = [args.ref_id_col, args.instruct_col, args.complete_col]
    missing_cols = [c for c in keep_cols if c not in ref.column_names]
    if missing_cols:
        raise ValueError(f"Reference dataset missing columns: {missing_cols}. Available: {ref.column_names}")
    ref = ref.select_columns(keep_cols)

    prompt_index = _build_prompt_index(
        ref,
        ref_id_col=args.ref_id_col,
        instruct_col=args.instruct_col,
        complete_col=args.complete_col,
    )
    print(f"[add_prompts] built prompt index: {len(prompt_index)} task_ids")

    print(f"[add_prompts] loading source: {args.source_repo}")
    src_any = load_dataset(args.source_repo)
    if isinstance(src_any, DatasetDict):
        src_dd = src_any
    else:
        # Some repos may only have a single split returned as Dataset; normalize.
        src_dd = DatasetDict({"default": src_any})

    # `--splits` given with no values yields []; treat that like "not given"
    # instead of pushing an empty DatasetDict. dict.fromkeys drops duplicate
    # split names while preserving order.
    if args.splits:
        splits = list(dict.fromkeys(str(s) for s in args.splits))
    else:
        splits = list(src_dd.keys())

    # Validate everything up front so a bad split/column fails before any
    # (potentially slow) processing starts.
    for s in splits:
        if s not in src_dd:
            raise ValueError(f"Split {s!r} not found in {args.source_repo}. Available: {list(src_dd.keys())}")
        if args.source_id_col not in src_dd[s].column_names:
            raise ValueError(
                f"Source split {s!r} missing {args.source_id_col!r}. Available: {src_dd[s].column_names}"
            )

    out: Dict[str, Dataset] = {}
    for split in splits:
        print(f"[add_prompts] processing split={split} (n={len(src_dd[split])})")
        out[split] = _add_columns_from_index(
            src_dd[split],
            prompt_index=prompt_index,
            source_id_col=args.source_id_col,
            instruct_col=args.instruct_col,
            complete_col=args.complete_col,
            overwrite=bool(args.overwrite),
        )

    out_dd = DatasetDict(out)

    print(f"[add_prompts] pushing to hub: {args.repo_id} (private={bool(args.private)})")
    if args.create_repo:
        print(f"[add_prompts] ensuring destination repo exists: {args.repo_id}")
        HfApi().create_repo(
            repo_id=str(args.repo_id),
            repo_type="dataset",
            token=token,
            private=bool(args.private),
            exist_ok=True,
        )
    out_dd.push_to_hub(
        repo_id=str(args.repo_id),
        token=token,
        private=bool(args.private),
    )
    print("[add_prompts] done")


# Run the CLI only when this file is executed as a script (or via `python -m`),
# not when it is imported as a module.
if __name__ == "__main__":
    main()


