#!/usr/bin/env python3
"""
CLI Tool for Preparing HuggingFace Datasets for VERL Training

Automates the process of:
1. Loading datasets from HuggingFace Hub
2. Applying custom prompt templates
3. Converting to VERL parquet format
4. Organizing train/eval splits
"""

import argparse
import os
from pathlib import Path
from typing import List, Optional
from hf_integration import HFIntegration, SYSTEM_PROMPT_SFT, SYSTEM_PROMPT_RSFT


def _consolidate_train_files(train_paths: List[str], output_dir: str) -> str:
    """Expose the prepared training parquet file(s) as ``train_rl.parquet``.

    With a single file, ``train_rl.parquet`` becomes a relative symlink to it
    (falling back to a plain copy where a symlink cannot be created). With
    several files, their rows are concatenated into a new parquet file.

    Args:
        train_paths: Paths of the per-dataset training parquet files.
        output_dir: Directory in which ``train_rl.parquet`` is created.

    Returns:
        Path of the consolidated ``train_rl.parquet``.
    """
    consolidated_path = os.path.join(output_dir, "train_rl.parquet")

    # Clear out any previous consolidated file first. lexists() (unlike
    # exists()) is also True for a dangling symlink, which would otherwise make
    # os.symlink fail. Removing it before the multi-dataset write also prevents
    # to_parquet from writing *through* a leftover symlink and clobbering the
    # train_*.parquet file that the link points to.
    if os.path.lexists(consolidated_path):
        os.remove(consolidated_path)

    if len(train_paths) == 1:
        source = train_paths[0]
        try:
            # A relative target keeps the link valid if output_dir is moved; it
            # equals the basename when the source lives inside output_dir.
            os.symlink(os.path.relpath(source, output_dir), consolidated_path)
        except (OSError, ValueError):
            # Symlinks may be unavailable (e.g. no privilege on Windows) or no
            # relative path may exist (different drives); copy instead.
            import shutil
            shutil.copyfile(source, consolidated_path)
            print(f"\n✓ Copied {os.path.basename(source)} → train_rl.parquet")
        else:
            print(f"\n✓ Created train_rl.parquet → {os.path.basename(source)}")
    else:
        import pandas as pd
        combined_df = pd.concat(
            [pd.read_parquet(p) for p in train_paths], ignore_index=True
        )
        combined_df.to_parquet(consolidated_path, index=False)
        print(
            f"\n✓ Concatenated {len(train_paths)} datasets → train_rl.parquet "
            f"({len(combined_df)} examples)"
        )

    return consolidated_path


def prepare_datasets(
    train_datasets: List[str],
    eval_datasets: List[str],
    output_dir: str,
    prompt_template: Optional[str] = None,
    data_source: Optional[str] = None,
    cache_dir: Optional[str] = None,
    no_cache: bool = False,
    system_prompt: Optional[str] = None
) -> None:
    """
    Prepare HuggingFace datasets for VERL training.

    Writes into ``output_dir``:
      * ``train_{idx}.parquet`` for each training dataset (``"train"`` split),
        plus a consolidated ``train_rl.parquet`` (symlink for one dataset,
        concatenation for several);
      * ``eval_{slug}.parquet`` for each evaluation dataset (``"test"`` split),
        where ``slug`` is the dataset name with ``/`` and ``-`` replaced by ``_``.

    Args:
        train_datasets: List of HF training dataset names
        eval_datasets: List of HF evaluation dataset names
        output_dir: Output directory for parquet files (created if missing)
        prompt_template: Path to prompt template file
        data_source: Task identifier for the training datasets (auto-detected
            if None). Evaluation datasets always use their own slug as
            data_source so validation metrics are reported per dataset.
        cache_dir: Cache directory for preprocessed data
        no_cache: Disable caching
        system_prompt: System prompt text to include in messages (``main``
            expands the "rsft"/"sft" shorthands before calling this)
    """
    # Initialize HF integration
    integration = HFIntegration(cache_dir=cache_dir)

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    print("=" * 80)
    print("PREPARING HUGGINGFACE DATASETS FOR VERL")
    print("=" * 80)

    if system_prompt:
        # Only add an ellipsis when the preview is actually truncated.
        preview = system_prompt[:80] + ("..." if len(system_prompt) > 80 else "")
        print(f"System prompt: {preview}")

    # Process training datasets
    if train_datasets:
        print(f"\nProcessing {len(train_datasets)} training dataset(s)...")
        train_paths = []

        for idx, dataset_name in enumerate(train_datasets):
            print(f"\n[{idx+1}/{len(train_datasets)}] Loading {dataset_name}...")

            output_path = os.path.join(output_dir, f"train_{idx}.parquet")

            path = integration.load_hf_dataset_to_parquet(
                dataset_name=dataset_name,
                output_path=output_path,
                prompt_template_path=prompt_template,
                data_source=data_source,
                split="train",
                use_cache=not no_cache,
                system_prompt=system_prompt
            )

            train_paths.append(path)
            print(f"✓ Saved to {path}")

        # Create consolidated train file for easier config
        _consolidate_train_files(train_paths, output_dir)

    # Process evaluation datasets
    eval_paths = []
    if eval_datasets:
        print(f"\nProcessing {len(eval_datasets)} evaluation dataset(s)...")

        for idx, dataset_name in enumerate(eval_datasets):
            print(f"\n[{idx+1}/{len(eval_datasets)}] Loading {dataset_name}...")

            # Use dataset slug as unique data_source for per-dataset metrics,
            # e.g. "anon-neurips26/bridges_5x5de_intformat"
            #   -> "anon_neurips26_bridges_5x5de_intformat"
            dataset_slug = dataset_name.replace('/', '_').replace('-', '_')
            output_path = os.path.join(output_dir, f"eval_{dataset_slug}.parquet")

            path = integration.load_hf_dataset_to_parquet(
                dataset_name=dataset_name,
                output_path=output_path,
                prompt_template_path=prompt_template,
                data_source=dataset_slug,  # Unique slug for per-dataset validation metrics
                split="test",  # Use test split for eval
                use_cache=not no_cache,
                system_prompt=system_prompt
            )

            eval_paths.append(path)
            print(f"✓ Saved to {path}")

    print("\n" + "=" * 80)
    print("DATASET PREPARATION COMPLETE")
    print("=" * 80)
    print(f"\nOutput directory: {output_dir}")
    if train_datasets:
        print("Training data: train_rl.parquet")
    if eval_datasets:
        print(f"Evaluation data: {len(eval_paths)} eval_*.parquet files")
    print("\nYou can now use these files in VERL training!")


def main():
    """Parse command-line arguments and run ``prepare_datasets``.

    ``--system_prompt`` accepts the shorthands ``rsft`` and ``sft``, which are
    expanded to the preset prompts imported from ``hf_integration``; any other
    value is passed through verbatim as the prompt text.
    """
    usage_examples = """
Examples:
  # Single training dataset
  python prepare_hf_datasets.py \\
      --train_datasets anon-neurips26/bridges_5x5dm_grpo_train_5k_intformat \\
      --output_dir data/bridges/hf_preprocessed

  # Multiple training datasets
  python prepare_hf_datasets.py \\
      --train_datasets anon-neurips26/bridges_5x5de anon-neurips26/bridges_7x7dm \\
      --output_dir data/bridges/hf_preprocessed

  # With evaluation datasets and RSFT system prompt (for models trained with reasoning)
  python prepare_hf_datasets.py \\
      --train_datasets anon-neurips26/bridges_5x5dm_grpo_train_5k_intformat \\
      --eval_datasets anon-neurips26/bridges_5x5de_intformat anon-neurips26/bridges_5x5dm_test200_intformat \\
      --prompt_template prompts/bridges_intformat.txt \\
      --system_prompt rsft \\
      --output_dir data/bridges/hf_preprocessed

  # With custom prompt template
  python prepare_hf_datasets.py \\
      --train_datasets anon-neurips26/bridges_5x5de_intformat \\
      --prompt_template prompts/bridges/intformat.txt \\
      --output_dir data/bridges/hf_preprocessed
        """
    parser = argparse.ArgumentParser(
        description="Prepare HuggingFace datasets for VERL training",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )

    add = parser.add_argument
    add("--train_datasets", nargs="+", help="HuggingFace training dataset name(s)")
    add("--eval_datasets", nargs="+", help="HuggingFace evaluation dataset name(s)")
    add("--output_dir", required=True, help="Output directory for parquet files")
    add("--prompt_template", help="Path to prompt template file (optional)")
    add("--data_source", help="Task identifier (auto-detected if not specified)")
    add(
        "--cache_dir",
        default=None,
        help="Cache directory for preprocessed data (default: ~/.cache/hf_verl)",
    )
    add("--no_cache", action="store_true", help="Disable caching (force re-download)")
    add(
        "--system_prompt",
        default=None,
        help=(
            "System prompt to include. Use 'rsft' for RSFT format (reasoning+answer tags), "
            "'sft' for SFT format (answer only), or provide a custom string. "
            "IMPORTANT: This must match how the model was trained!"
        ),
    )

    args = parser.parse_args()

    # At least one kind of dataset is required for there to be any work.
    if not (args.train_datasets or args.eval_datasets):
        parser.error("Must specify at least one of --train_datasets or --eval_datasets")

    # Expand preset shorthands; any other value is used as a literal prompt.
    presets = {
        "rsft": (SYSTEM_PROMPT_RSFT, "Using RSFT system prompt (reasoning + answer tags)"),
        "sft": (SYSTEM_PROMPT_SFT, "Using SFT system prompt (answer only)"),
    }
    system_prompt = args.system_prompt
    if system_prompt in presets:
        system_prompt, notice = presets[system_prompt]
        print(notice)

    prepare_datasets(
        train_datasets=args.train_datasets or [],
        eval_datasets=args.eval_datasets or [],
        output_dir=args.output_dir,
        prompt_template=args.prompt_template,
        data_source=args.data_source,
        cache_dir=args.cache_dir,
        no_cache=args.no_cache,
        system_prompt=system_prompt,
    )


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
