"""
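HuggingFace Integration for VERL Training

Provides utilities to:
- Load datasets from a local data/sft or data/eval copy when one is present,
  otherwise from the HuggingFace Hub
- Convert them to VERL parquet format
- Download pretrained LoRA checkpoints from the HuggingFace Hub
- Cache the converted parquet data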
"""

import os
import json
import hashlib
from pathlib import Path
from typing import List, Dict, Optional, Tuple
from datasets import load_dataset, Dataset, DatasetDict
import pandas as pd

# System prompts matching the Unsloth SFT/RSFT training scripts
# These MUST match what the model was trained with for proper output format
#
# SYSTEM_PROMPT_SFT:  answer-only format (final solution inside <answer> tags).
# SYSTEM_PROMPT_RSFT: reasoning + answer format (<reasoning> and <answer> tags);
#                     the loaders' docstrings recommend it for RSFT-trained models.
#
# NOTE(review): each prompt is two adjacent string literals that Python joins
# with NO separating space ("...solves it.Place ..." / "...reasoning.Provide ...").
# Left byte-identical because of the must-match requirement above -- confirm
# against the training scripts before "fixing" the spacing.
SYSTEM_PROMPT_SFT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it."
    "Place the final solution within <answer> answer here </answer>"
)
SYSTEM_PROMPT_RSFT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it step by step by reasoning."
    "Provide the reasoning in <reasoning> reasoning here </reasoning> and the final solution within <answer> answer here </answer>"
)


class HFIntegration:
    """Handles HuggingFace dataset and model integration for VERL.

    * Resolves datasets from a repo-local ``data/sft`` / ``data/eval`` copy when
      one exists, otherwise from the HuggingFace Hub.
    * Converts them to the VERL parquet schema (see ``_convert_to_verl_format``).
    * Caches converted parquet files under ``cache_dir``.
    * Downloads LoRA checkpoints from the Hub into ``cache_dir``.
    """

    # Columns probed, in priority order, for the model input.
    _PROMPT_KEYS = ('prompt', 'question', 'input', 'initial_state', 'problem')
    # Columns probed, in priority order, for the expected output.
    _GROUND_TRUTH_KEYS = ('solution', 'answer', 'target', 'expected_solution')
    # First present column is copied to extra_info['initial_state'] because
    # reward functions doing partial scoring need the starting state.
    _INITIAL_STATE_KEYS = ('initial_state', 'problem')
    # Columns consumed as prompt / ground truth; every other column is copied
    # into extra_info unchanged.
    _EXCLUDED_KEYS = frozenset(_PROMPT_KEYS + _GROUND_TRUTH_KEYS)
    # Substrings searched for (in the lower-cased dataset name) to infer data_source.
    _KNOWN_DATA_SOURCES = ('bridges', 'cryptarithm', 'nonogram', 'sudoku', 'galaxies', 'graph')
    # Placeholders recognised in prompt templates, in priority order.
    _TEMPLATE_PLACEHOLDERS = ('{}', '{input}', '{puzzle}')

    def __init__(self, cache_dir: Optional[str] = None):
        """
        Initialize HF integration.

        Args:
            cache_dir: Directory for converted parquet files and Hub downloads.
                Defaults to ``~/.cache/hf_verl``; created if it does not exist.
        """
        self.cache_dir = cache_dir or os.path.expanduser("~/.cache/hf_verl")
        Path(self.cache_dir).mkdir(parents=True, exist_ok=True)

    def _get_cache_key(self, dataset_name: str, prompt_template: Optional[str] = None) -> str:
        """Return an MD5 hex key for a dataset + prompt-template combination.

        Kept for backward compatibility only. ``load_hf_dataset_to_parquet``
        keys its cache via ``_dataset_cache_path``, which also covers
        data_source, ability and the system prompt.
        """
        key_str = f"{dataset_name}_{prompt_template or 'default'}"
        return hashlib.md5(key_str.encode()).hexdigest()

    def _load_prompt_template(self, prompt_template_path: Optional[str]) -> Optional[str]:
        """Read a prompt template file.

        Returns:
            The stripped file contents (decoded as UTF-8), or None when no path
            is given or the file does not exist. A warning is printed when the
            file is missing.
        """
        if not prompt_template_path:
            return None

        if not os.path.exists(prompt_template_path):
            print(f"Warning: Prompt template not found at {prompt_template_path}")
            return None

        with open(prompt_template_path, 'r', encoding='utf-8') as f:
            return f.read().strip()

    @staticmethod
    def _first_present(example, keys):
        """Return ``example[key]`` for the first ``key`` present in ``example``, else None."""
        for key in keys:
            if key in example:
                return example[key]
        return None

    @classmethod
    def _apply_prompt_template(cls, prompt_template: str, puzzle_str: str) -> str:
        """Insert ``puzzle_str`` into ``prompt_template``.

        The first placeholder found among ``{}``, ``{input}``, ``{puzzle}`` (in
        that order) has every occurrence replaced. If none is present, the
        puzzle is appended to the template on a new line.
        """
        for placeholder in cls._TEMPLATE_PLACEHOLDERS:
            if placeholder in prompt_template:
                return prompt_template.replace(placeholder, puzzle_str)
        return prompt_template + "\n" + puzzle_str

    def _convert_to_verl_format(
        self,
        dataset: "Dataset",
        prompt_template: Optional[str] = None,
        data_source: str = "bridges",
        ability: str = "puzzle",
        system_prompt: Optional[str] = None
    ) -> pd.DataFrame:
        """
        Convert a HuggingFace dataset (any iterable of dict-like rows) to VERL format.

        Expected dataset columns:
        - prompt / question / input / initial_state / problem: the input
          (first present column wins)
        - solution / answer / target / expected_solution: the expected output
          (first present column wins)
        - Other columns become extra_info; initial_state / problem (if present)
          is also stored, stringified, as extra_info['initial_state']

        Rows lacking an input or expected-output column (or whose value is None)
        are skipped with a printed warning.

        VERL format:
        - data_source: Task identifier
        - prompt: List of chat messages (optionally with system prompt)
        - ability: Task category
        - reward_model: Dict with style ('rule') and ground_truth (str)
        - extra_info: Dict with metadata

        Args:
            dataset: HuggingFace dataset
            prompt_template: Optional custom prompt template with a {}, {input}
                or {puzzle} placeholder; without one the input is appended.
            data_source: Task identifier for reward routing
            ability: Task category
            system_prompt: Optional system prompt to prepend to messages.
                Use SYSTEM_PROMPT_RSFT for models trained with reasoning format.
        """
        records = []

        for idx, example in enumerate(dataset):
            prompt_text = self._first_present(example, self._PROMPT_KEYS)
            if prompt_text is None:
                print(f"Warning: No prompt found in example {idx}, skipping")
                continue

            if prompt_template:
                prompt_text = self._apply_prompt_template(prompt_template, str(prompt_text))

            ground_truth = self._first_present(example, self._GROUND_TRUTH_KEYS)
            if ground_truth is None:
                print(f"Warning: No ground truth found in example {idx}, skipping")
                continue

            extra_info = {
                key: value
                for key, value in example.items()
                if key not in self._EXCLUDED_KEYS
            }
            # Keyed on presence (not value) so a None state is still recorded as 'None'.
            state_key = next((key for key in self._INITIAL_STATE_KEYS if key in example), None)
            if state_key is not None:
                extra_info['initial_state'] = str(example[state_key])

            messages = []
            if system_prompt:
                messages.append({'role': 'system', 'content': system_prompt})
            messages.append({'role': 'user', 'content': str(prompt_text)})

            records.append({
                'data_source': data_source,
                'prompt': messages,
                'ability': ability,
                'reward_model': {
                    'style': 'rule',
                    'ground_truth': str(ground_truth)
                },
                'extra_info': extra_info
            })

        return pd.DataFrame(records)

    @classmethod
    def _infer_data_source(cls, dataset_name: str) -> str:
        """Guess a data_source from a dataset name.

        Returns the first known puzzle name contained in the lower-cased name,
        otherwise the text before the first '_' of the name's last '/' segment.
        """
        name_lower = dataset_name.lower()
        for puzzle in cls._KNOWN_DATA_SOURCES:
            if puzzle in name_lower:
                return puzzle
        return dataset_name.split('/')[-1].split('_')[0]

    def _dataset_cache_path(
        self,
        dataset_name: str,
        split: str,
        prompt_template: Optional[str],
        data_source: str,
        ability: str,
        system_prompt: Optional[str]
    ) -> str:
        """Return the cache file path for a converted dataset.

        The MD5 key covers every input that changes the converted rows: the
        dataset name, the prompt template *text*, data_source, ability and the
        system prompt. The split is appended to the file name. Cache files
        written under older key schemes are simply never hit.
        """
        key_payload = json.dumps(
            {
                'dataset': dataset_name,
                'prompt_template': prompt_template,
                'data_source': data_source,
                'ability': ability,
                'system_prompt': system_prompt,
            },
            sort_keys=True,
        )
        cache_key = hashlib.md5(key_payload.encode('utf-8')).hexdigest()
        return os.path.join(self.cache_dir, f"{cache_key}_{split}.parquet")

    @staticmethod
    def _ensure_parent_dir(path: str) -> None:
        """Create the parent directory of ``path`` if it has one.

        ``os.makedirs('')`` raises, so bare file names (current directory) are skipped.
        """
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    @staticmethod
    def _select_split_files(local_dir: str, pq_files: List[str], split: str) -> List[str]:
        """Keep the parquet files under ``local_dir`` that belong to ``split``.

        A file belongs to the split when a directory between ``local_dir`` and
        the file is named ``split``, or its file name starts with ``split``
        followed by '-' or '.' (e.g. ``data/train-00000-of-00001.parquet``).
        When nothing matches (single-file layouts, sliced specs such as
        ``train[:100]``), all files are returned unchanged.
        """
        def belongs(path: str) -> bool:
            rel = Path(os.path.relpath(path, local_dir))
            return split in rel.parts[:-1] or rel.name.startswith((f"{split}-", f"{split}."))

        matching = [path for path in pq_files if belongs(path)]
        return matching or pq_files

    def _load_local_dataset(self, dataset_name: str, split: str):
        """Load ``dataset_name`` from a repo-local copy, if one exists.

        Looks for ``data/sft/<basename>`` then ``data/eval/<basename>`` under the
        directory two levels above this file's directory, where ``<basename>`` is
        the part of ``dataset_name`` after the last '/'. Each candidate is first
        loaded as a dataset directory. If that fails, its ``*.parquet`` files
        (filtered to ``split`` where possible) are read directly. This lets the
        supplementary material ship everything in-zip.

        Returns:
            The loaded dataset, or None when no local copy could be loaded.
        """
        import glob

        basename = dataset_name.split('/')[-1]
        candidate_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
        for subdir in ('sft', 'eval'):
            local_dir = os.path.join(candidate_root, 'data', subdir, basename)
            if not os.path.isdir(local_dir):
                continue

            # Typical HF-style layout: <local_dir>/data/*.parquet
            try:
                dataset = load_dataset(local_dir, split=split)
            except Exception as exc:
                print(f"Warning: could not load {local_dir} as a dataset ({exc}); "
                      f"trying its parquet files directly")
            else:
                print(f"Loading dataset {dataset_name} (split: {split}) from local: {local_dir}")
                return dataset

            pq_files = sorted(glob.glob(os.path.join(local_dir, '**', '*.parquet'), recursive=True))
            pq_files = self._select_split_files(local_dir, pq_files, split)
            if not pq_files:
                continue
            try:
                # The 'parquet' builder maps all given data_files onto one 'train' split.
                dataset = load_dataset('parquet', data_files=pq_files, split='train')
            except Exception as exc:
                print(f"Warning: could not read local parquet files under {local_dir}: {exc}")
                continue
            print(f"Loading dataset {dataset_name} (split: {split}) from local parquet: {pq_files}")
            return dataset

        return None

    def load_hf_dataset_to_parquet(
        self,
        dataset_name: str,
        output_path: str,
        prompt_template_path: Optional[str] = None,
        data_source: Optional[str] = None,
        ability: str = "puzzle",
        split: str = "train",
        use_cache: bool = True,
        system_prompt: Optional[str] = None
    ) -> str:
        """
        Load a dataset and write it to ``output_path`` in VERL parquet format.

        The dataset is read from a repo-local ``data/sft/<name>`` or
        ``data/eval/<name>`` copy when one exists, otherwise from the
        HuggingFace Hub.

        Args:
            dataset_name: HF dataset name (e.g., "anon-neurips26/bridges_5x5dm_grpo_train_5k_intformat")
            output_path: Path to save parquet file (parent dirs are created)
            prompt_template_path: Optional path to prompt template file
            data_source: Task identifier (defaults to one inferred from the dataset name)
            ability: Task category
            split: Dataset split to load
            use_cache: Whether to reuse / store converted data in ``cache_dir``
            system_prompt: Optional system prompt to include. Use SYSTEM_PROMPT_RSFT
                          for models trained with the RSFT format (reasoning + answer tags).

        Returns:
            Path to created parquet file (``output_path``).

        Raises:
            Whatever ``datasets.load_dataset`` raises when the Hub load fails.
        """
        import shutil

        if data_source is None:
            data_source = self._infer_data_source(dataset_name)

        prompt_template_str = self._load_prompt_template(prompt_template_path)

        # Key on the template *contents* plus every conversion parameter so that
        # changing any of them never serves stale cached rows.
        cached_path = self._dataset_cache_path(
            dataset_name, split, prompt_template_str, data_source, ability, system_prompt
        )

        if use_cache and os.path.exists(cached_path):
            print(f"Loading cached dataset from {cached_path}")
            self._ensure_parent_dir(output_path)
            shutil.copy(cached_path, output_path)
            return output_path

        dataset = self._load_local_dataset(dataset_name, split)
        # `is None`, not truthiness: an empty Dataset has len 0 and is falsy.
        if dataset is None:
            print(f"Loading dataset {dataset_name} (split: {split}) from HuggingFace...")
            try:
                dataset = load_dataset(dataset_name, split=split)
            except Exception as e:
                print(f"Error loading dataset {dataset_name}: {e}")
                raise

        print(f"Loaded {len(dataset)} examples")

        print("Converting to VERL parquet format...")
        if system_prompt:
            print(f"Using system prompt: {system_prompt[:80]}...")
        df = self._convert_to_verl_format(
            dataset,
            prompt_template=prompt_template_str,
            data_source=data_source,
            ability=ability,
            system_prompt=system_prompt
        )

        print(f"Converted {len(df)} examples")

        self._ensure_parent_dir(output_path)
        df.to_parquet(output_path, index=False)
        print(f"Saved to {output_path}")

        if use_cache:
            if df.empty:
                # An empty result usually means a column-name mismatch; caching it
                # would keep serving the empty file even after the data is fixed.
                print("Warning: no examples were converted; not caching the empty result")
            else:
                self._ensure_parent_dir(cached_path)
                df.to_parquet(cached_path, index=False)
                print(f"Cached to {cached_path}")

        return output_path

    def load_multiple_datasets(
        self,
        dataset_names: List[str],
        output_dir: str,
        prompt_template_path: Optional[str] = None,
        data_source: Optional[str] = None,
        ability: str = "puzzle",
        split: str = "train",
        use_cache: bool = True,
        system_prompt: Optional[str] = None
    ) -> List[str]:
        """
        Load multiple datasets and convert each to VERL parquet format.

        The i-th dataset is written to ``<output_dir>/<split>_<i>.parquet``.

        Args:
            dataset_names: List of HF dataset names
            output_dir: Directory to save parquet files
            prompt_template_path: Optional path to prompt template file
            data_source: Task identifier (defaults to per-dataset inference)
            ability: Task category
            split: Dataset split to load
            use_cache: Whether to use cached preprocessed data
            system_prompt: Optional system prompt to include

        Returns:
            List of paths to created parquet files, in input order
        """
        parquet_paths = []

        for idx, dataset_name in enumerate(dataset_names):
            output_path = os.path.join(output_dir, f"{split}_{idx}.parquet")
            parquet_paths.append(self.load_hf_dataset_to_parquet(
                dataset_name=dataset_name,
                output_path=output_path,
                prompt_template_path=prompt_template_path,
                data_source=data_source,
                ability=ability,
                split=split,
                use_cache=use_cache,
                system_prompt=system_prompt
            ))

        return parquet_paths

    @staticmethod
    def _parse_hf_checkpoint_path(hf_checkpoint_path: str) -> Tuple[str, Optional[str]]:
        """Split ``user/repo[/sub/folder]`` into ``(repo_id, subfolder or None)``.

        Empty segments (leading, trailing or doubled slashes) are ignored, so
        ``"user/repo/sub/"`` yields subfolder ``"sub"`` rather than ``"sub/"``.

        Raises:
            ValueError: if fewer than two non-empty segments are present.
        """
        parts = [part for part in hf_checkpoint_path.split('/') if part]
        if len(parts) < 2:
            raise ValueError(f"Invalid HF checkpoint path: {hf_checkpoint_path}")
        return f"{parts[0]}/{parts[1]}", ('/'.join(parts[2:]) or None)

    def load_pretrained_lora_for_verl(
        self,
        hf_checkpoint_path: str,
        base_model_path: str = "Qwen/Qwen2.5-7B-Instruct"
    ) -> Tuple[str, str]:
        """
        Download a pretrained LoRA checkpoint from the HF Hub for VERL training.

        Args:
            hf_checkpoint_path: HF Hub path to LoRA checkpoint
                              Format: "username/repo/subfolder" or "username/repo"
            base_model_path: Base model path (returned unchanged)

        Returns:
            Tuple of (base_model_path, local path of the downloaded checkpoint,
            pointing inside the subfolder when one was given)

        Raises:
            ValueError: if ``hf_checkpoint_path`` lacks a "username/repo" prefix.
        """
        import sys
        from huggingface_hub import snapshot_download

        repo_id, subfolder = self._parse_hf_checkpoint_path(hf_checkpoint_path)

        print(f"Downloading LoRA checkpoint from {repo_id} (subfolder: {subfolder})...", file=sys.stderr)

        local_path = snapshot_download(
            repo_id=repo_id,
            allow_patterns=f"{subfolder}/*" if subfolder else "*",
            cache_dir=self.cache_dir
        )

        if subfolder:
            local_path = os.path.join(local_path, subfolder)

        print(f"Downloaded to {local_path}", file=sys.stderr)

        return base_model_path, local_path


def main():
    """CLI for testing HF integration.

    Converts one dataset split to a VERL parquet file. ``--system_prompt``
    selects which of the module's system prompts (if any) is prepended to every
    example; ``--ability`` sets the task category. The defaults ("none" and
    "puzzle") reproduce the previous CLI behavior.
    """
    import argparse

    # CLI choice -> system prompt text passed to the loader.
    system_prompts = {
        'none': None,
        'sft': SYSTEM_PROMPT_SFT,
        'rsft': SYSTEM_PROMPT_RSFT,
    }

    parser = argparse.ArgumentParser(description="HuggingFace Integration for VERL")
    parser.add_argument("--dataset", type=str, required=True, help="HF dataset name")
    parser.add_argument("--output", type=str, required=True, help="Output parquet path")
    parser.add_argument("--prompt_template", type=str, help="Path to prompt template")
    parser.add_argument("--data_source", type=str, help="Task identifier")
    parser.add_argument("--ability", type=str, default="puzzle", help="Task category")
    parser.add_argument("--split", type=str, default="train", help="Dataset split")
    parser.add_argument(
        "--system_prompt",
        choices=sorted(system_prompts),
        default="none",
        help="System prompt to prepend: none (default), sft, or rsft",
    )
    parser.add_argument("--no_cache", action="store_true", help="Disable caching")

    args = parser.parse_args()

    integration = HFIntegration()
    integration.load_hf_dataset_to_parquet(
        dataset_name=args.dataset,
        output_path=args.output,
        prompt_template_path=args.prompt_template,
        data_source=args.data_source,
        ability=args.ability,
        split=args.split,
        use_cache=not args.no_cache,
        system_prompt=system_prompts[args.system_prompt],
    )

    print("Done!")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
