#!/usr/bin/env python3
"""
Dataset loaders for QR-Adaptor experiments.

Supports:
- Alpaca: General instruction tuning
- GSM8K: Math reasoning
"""

from typing import Dict, List, Optional
from datasets import load_dataset


# ========================================
# Alpaca Dataset
# ========================================

# Prompt layout for Alpaca records that carry a non-blank "input" field.
# Filled with str.format(); placeholders: {instruction}, {input}, {output}.
ALPACA_PROMPT_TEMPLATE = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Input:
{input}

### Response:
{output}"""

# Prompt layout for Alpaca records whose "input" field is missing or blank.
# Filled with str.format(); placeholders: {instruction}, {output}.
ALPACA_NO_INPUT_TEMPLATE = """Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:
{output}"""


def _format_alpaca_example(example: Dict) -> str:
    """Render one Alpaca record using the template that matches its fields."""
    # The "input" field may be absent, empty, or None; `.get(key, "")` only
    # covers the absent case, so normalise None to "" before calling strip().
    context = example.get("input") or ""
    if context.strip():
        return ALPACA_PROMPT_TEMPLATE.format(
            instruction=example["instruction"],
            input=context,
            output=example["output"],
        )
    return ALPACA_NO_INPUT_TEMPLATE.format(
        instruction=example["instruction"],
        output=example["output"],
    )


def load_alpaca_dataset(
    split: str = "train",
    max_samples: Optional[int] = None,
    seed: int = 42
) -> List[Dict]:
    """Load the Alpaca dataset and format each record as a prompt string.

    Records with a non-blank ``input`` field are rendered with
    ``ALPACA_PROMPT_TEMPLATE``; all others (missing, ``None``, empty or
    whitespace-only ``input``) use ``ALPACA_NO_INPUT_TEMPLATE``.

    Args:
        split: Dataset split passed to ``load_dataset`` (e.g. "train").
        max_samples: Optional cap on the number of examples. ``None`` or
            ``0`` keeps the whole split. The data is shuffled and truncated
            only when the split holds more than ``max_samples`` records.
        seed: Random seed for the shuffle; unused when no subsampling
            takes place.

    Returns:
        A list of ``{"text": <formatted prompt>}`` dictionaries.
    """
    print(f"Loading Alpaca dataset (split={split})...")
    dataset = load_dataset("tatsu-lab/alpaca", split=split)

    # Subsample only when a limit is set and the split actually exceeds it.
    if max_samples and len(dataset) > max_samples:
        dataset = dataset.shuffle(seed=seed).select(range(max_samples))

    formatted_data = [
        {"text": _format_alpaca_example(example)} for example in dataset
    ]

    print(f"Loaded {len(formatted_data)} Alpaca examples")
    return formatted_data


# ========================================
# GSM8K Dataset
# ========================================

# Prompt layout for GSM8K records: the question followed by the record's
# "answer" field. Filled with str.format(); placeholders: {question}, {answer}.
GSM8K_PROMPT_TEMPLATE = """Question: {question}

Answer: {answer}"""


def load_gsm8k_dataset(
    split: str = "train",
    max_samples: Optional[int] = None,
    seed: int = 42
) -> List[Dict]:
    """Fetch GSM8K ("main" config) and render each record as a prompt string.

    Args:
        split: Dataset split passed to ``load_dataset`` ("train" or "test").
        max_samples: Optional cap on the number of examples; a falsy value
            keeps the whole split. Shuffling and truncation happen only when
            the split holds more than ``max_samples`` records.
        seed: Random seed for the shuffle; unused when no subsampling
            takes place.

    Returns:
        A list of ``{"text": <formatted prompt>}`` dictionaries.
    """
    print(f"Loading GSM8K dataset (split={split})...")
    records = load_dataset("gsm8k", "main", split=split)

    if max_samples:
        if len(records) > max_samples:
            # Shuffle first so the retained subset is a seeded random sample.
            shuffled = records.shuffle(seed=seed)
            records = shuffled.select(range(max_samples))

    prompts = [
        {
            "text": GSM8K_PROMPT_TEMPLATE.format(
                question=row["question"],
                answer=row["answer"],
            )
        }
        for row in records
    ]

    print(f"Loaded {len(prompts)} GSM8K examples")
    return prompts


# ========================================
# Unified Loader
# ========================================

def load_dataset_for_training(
    dataset_name: str,
    split: str = "train",
    max_samples: Optional[int] = None,
    seed: int = 42
) -> List[Dict]:
    """Dispatch to the loader registered under ``dataset_name``.

    Args:
        dataset_name: Key of the loader to use: "alpaca" or "gsm8k".
        split: Dataset split, forwarded to the selected loader.
        max_samples: Optional sample limit, forwarded to the selected loader.
        seed: Random seed, forwarded to the selected loader.

    Returns:
        Whatever the selected loader returns (a list of formatted examples).

    Raises:
        ValueError: If ``dataset_name`` is not one of the registered keys.
    """
    registry = {
        "alpaca": load_alpaca_dataset,
        "gsm8k": load_gsm8k_dataset,
    }

    loader = registry.get(dataset_name)
    if loader is None:
        supported = list(registry.keys())
        raise ValueError(f"Unknown dataset: {dataset_name}. Supported: {supported}")

    return loader(split, max_samples, seed)


if __name__ == "__main__":
    # Smoke test: pull three examples from each loader and preview the
    # first 500 characters of the first formatted prompt.
    smoke_checks = (
        ("Testing Alpaca dataset...", load_alpaca_dataset),
        ("\nTesting GSM8K dataset...", load_gsm8k_dataset),
    )
    for banner, loader in smoke_checks:
        print(banner)
        examples = loader(max_samples=3)
        first_text = examples[0]["text"]
        print(f"Sample:\n{first_text[:500]}...")
