"""
Ultra-clean dataset setup module with zero dataset-specific logic.
ALL dataset-specific code is in the handlers. This file is pure delegation.
"""

from .dataset_handlers import DatasetFactory
from typing import Any, Dict, List, Tuple
from datasets import Dataset

def load_normalized_dataset_samples(
    dataset_name: str,
    multi_turn: bool = False,
    max_prompt_length: int = None,
    base_model_name: str = None,
    data_dir: str = None,
) -> List[Dict]:
    """Return normalized samples for ``dataset_name`` via its registered handler.

    The handler is looked up through ``DatasetFactory.get_dataset_handler`` and
    all remaining arguments are forwarded unchanged to the handler's
    ``load_normalized_samples`` method.

    Args:
        dataset_name: Name of the dataset (e.g., 'Anthropic/hh-rlhf').
        multi_turn: Whether to use the multi-turn dialogue format.
        max_prompt_length: Optional max token length for filtering prompts.
        base_model_name: Model name for the tokenizer (required if
            max_prompt_length is set).
        data_dir: Optional data directory for datasets that support it
            (e.g., 'helpful-base' for HH-RLHF).

    Returns:
        Whatever the handler's ``load_normalized_samples`` returns.
    """
    # Keyword options are passed through as-is; this module applies no
    # defaults or validation of its own.
    forwarded_options = {
        "max_prompt_length": max_prompt_length,
        "base_model_name": base_model_name,
        "data_dir": data_dir,
    }
    dataset_handler = DatasetFactory.get_dataset_handler(dataset_name)
    return dataset_handler.load_normalized_samples(multi_turn, **forwarded_options)

def load_evaluation_dataset(
    dataset_name: str,
    dataset_split: str,
    max_prompt_length: int = None,
    base_model_name: str = None,
    multi_turn: bool = False,
) -> Dataset:
    """Load the evaluation split of ``dataset_name`` through its handler.

    Prints the dataset name before loading and the number of loaded entries
    afterwards.

    Args:
        dataset_name: Name used to look up the handler in ``DatasetFactory``.
        dataset_split: Split identifier, passed positionally to the handler.
        max_prompt_length: Forwarded to the handler unchanged.
        base_model_name: Forwarded to the handler unchanged.
        multi_turn: Forwarded to the handler unchanged.

    Returns:
        The object returned by the handler's ``load_evaluation_dataset``.
    """
    print(f"\nLoading dataset: {dataset_name}")

    eval_handler = DatasetFactory.get_dataset_handler(dataset_name)
    eval_data = eval_handler.load_evaluation_dataset(
        dataset_split,
        max_prompt_length=max_prompt_length,
        base_model_name=base_model_name,
        multi_turn=multi_turn,
    )

    num_entries = len(eval_data)
    print(f"Dataset size: {num_entries}")
    return eval_data

def convert_sample_to_prompt(sample: Dict, dataset_name: str, multi_turn: bool) -> List[Dict[str, str]]:
    """Turn ``sample`` into a chat-style prompt using the dataset's handler.

    If a ``ValueError`` is raised while resolving the handler or while the
    handler converts the sample, a generic single-message prompt is built
    instead: for a dict, the value of the first present key among 'prompt',
    'input', 'text' is used; otherwise ``str(sample)`` is used.

    NOTE(review): the fallback is described as being for unknown datasets, but
    it also fires when the handler's own conversion raises ``ValueError``.
    Confirm whether that is intended.

    Args:
        sample: The raw sample to convert.
        dataset_name: Name used to look up the handler in ``DatasetFactory``.
        multi_turn: Forwarded to the handler's ``convert_sample_to_prompt``.

    Returns:
        The handler's result, or a one-element list holding a 'user' message.
    """
    try:
        return DatasetFactory.get_dataset_handler(dataset_name).convert_sample_to_prompt(
            sample, multi_turn
        )
    except ValueError:
        # Fallback for unknown datasets
        if isinstance(sample, dict):
            # Stop at the first candidate key that exists in the sample.
            matched_key = next(
                (candidate for candidate in ("prompt", "input", "text") if candidate in sample),
                None,
            )
            if matched_key is not None:
                return [{"role": "user", "content": sample[matched_key]}]
        # Non-dict samples and dicts without a known key are stringified whole.
        return [{"role": "user", "content": str(sample)}]

def load_datasets(args: Any, tokenizer: Any, val_dataset_size: int = 100) -> Tuple[Dataset, Dataset]:
    """
    Load and preprocess the datasets for a training run.

    Contains no dataset-specific branching: the handler matching
    ``args.dataset_name`` is obtained from ``DatasetFactory`` and its
    ``load_for_training`` method does all the work.

    Args:
        args: Arguments object with:
            - dataset_name: Name/path of dataset
            - method: Training method ('sft', 'dpo', 'ppo', 'grpo')
            - dataset_num_proc: Number of processes for data loading
              (optional; 1 is used when the attribute is absent)
            - Other dataset-specific args
        tokenizer: Tokenizer to use; given both to the handler lookup and to
            the handler's loading method
        val_dataset_size: Maximum validation dataset size

    Returns:
        Tuple of (train_dataset, val_dataset)
    """
    # Resolve the lookup inputs first; dataset_num_proc may be missing on args.
    requested_dataset = args.dataset_name
    worker_count = getattr(args, 'dataset_num_proc', 1)

    training_handler = DatasetFactory.get_dataset_handler(
        requested_dataset,
        tokenizer=tokenizer,
        dataset_num_proc=worker_count,
    )

    # The handler picks the loading path for the requested training method.
    train_and_val = training_handler.load_for_training(
        method=args.method,
        args=args,
        tokenizer=tokenizer,
        val_dataset_size=val_dataset_size,
    )
    return train_and_val

# def sample_to_input_dialogue_tldr(sample_info):
#     """Convert Reddit TLDR sample to input dialogue format."""
#     handler = DatasetFactory.get_dataset_handler('openai/summarize_from_feedback')
#     return handler.sample_to_input_dialogue(sample_info)

# def sample_to_input_dialogue_hh(sample_chosen, multi_turn=False):
#     """Convert HH-RLHF sample to input dialogue format."""
#     handler = DatasetFactory.get_dataset_handler('Anthropic/hh-rlhf')
#     return handler.sample_to_input_dialogue(sample_chosen, multi_turn=multi_turn)

# def prepare_dataset(dataset: Dataset, tokenizer: Any, dataset_name: str) -> Dataset:
#     """Pre-tokenize the dataset for PPO training."""
#     handler = DatasetFactory.get_dataset_handler(dataset_name, tokenizer=tokenizer)
#     return handler.prepare_ppo_dataset(dataset)

# Simple wrapper functions for specific preprocessing methods
# def sft_preprocess_function(examples: Dict, tokenizer: Any, multi_turn: bool = True) -> Dict:
#     """Extract chosen responses from HH-RLHF dataset for SFT training."""
#     handler = DatasetFactory.get_dataset_handler('Anthropic/hh-rlhf', tokenizer=tokenizer)
#     return handler.preprocess_sft(examples, multi_turn=multi_turn)

# def sft_preprocess_function_tldr(examples: Dict, tokenizer: Any) -> Dict:
#     """Preprocess OpenAI summarize_from_feedback dataset for SFT training."""
#     handler = DatasetFactory.get_dataset_handler('openai/summarize_from_feedback', tokenizer=tokenizer)
#     return handler.preprocess_sft(examples, multi_turn=False)

# def grpo_preprocess_function_hh(examples: Dict, tokenizer: Any) -> Dict:
#     """Preprocess Anthropic HH-RLHF dataset for GRPO training."""
#     handler = DatasetFactory.get_dataset_handler('Anthropic/hh-rlhf', tokenizer=tokenizer)
#     return handler.preprocess_grpo(examples)

# def grpo_preprocess_function_tldr(examples: Dict, tokenizer: Any) -> Dict:
#     """Preprocess Reddit TLDR dataset for GRPO training."""
#     handler = DatasetFactory.get_dataset_handler('openai/summarize_from_feedback', tokenizer=tokenizer)
#     return handler.preprocess_grpo(examples)

# Legacy functions that might still be used somewhere
# def dpo_conversational_implicit_prompt_prefs(samples: Dict, tokenizer: Any) -> Dict:
#     """DPO preprocessing - legacy function."""
#     return {
#         'chosen': samples['chosen'],
#         'rejected': samples['rejected'],
#     }

# def sft_preprocess_function_conversation(examples: Dict, tokenizer: Any) -> Dict:
#     """Preprocessing for conversation format datasets - legacy function."""
#     return {"messages": examples['chosen']}

# def ppo_preprocess_function(examples: Dict, tokenizer: Any) -> Dict:
#     """Legacy PPO preprocessing function."""
#     prompts = []
#     for conversation in examples['chosen']:
#         last_assistant_turn_idx = conversation.rfind('\n\nAssistant:')
#         if last_assistant_turn_idx != -1:
#             prompt_end_idx = last_assistant_turn_idx + len('\n\nAssistant:')
#             prompts.append(conversation[:prompt_end_idx])
#         else:
#             prompts.append(conversation)
#     return tokenizer(prompts, truncation=True)

# def sft_preprocess_function_ultrafeedback(examples: Dict, tokenizer: Any) -> Dict:
#     """Preprocess UltraFeedback dataset for SFT training."""
#     handler = DatasetFactory.get_dataset_handler('openbmb/UltraFeedback', tokenizer=tokenizer)
#     return handler.preprocess_sft(examples, multi_turn=False)


"""
HOW TO ADD A NEW DATASET:

1. Create a new handler class in dataset_handlers.py:

```python
class YourDataset(BaseDataset):
    def get_dataset_pattern(self) -> str:
        return 'your_org/your_dataset'  # or a regex pattern

    def load_raw_dataset(self, split='train', **kwargs):
        return load_dataset(self.get_dataset_pattern(), split=split, **kwargs)

    def sample_to_input_dialogue(self, sample, multi_turn=False):
        # Convert to dialogue format
        return [{'role': 'user', 'content': sample['text']}]

    # Implement all required preprocessing methods...
    def preprocess_sft(self, examples, multi_turn=False): ...
    def preprocess_dpo(self, examples): ...
    def preprocess_ppo(self, element): ...
    def preprocess_grpo(self, examples): ...

    # Implement training loaders
    def _load_for_sft(self, args, tokenizer, val_dataset_size): ...
    def _load_for_dpo(self, args, tokenizer, val_dataset_size): ...
    def _load_for_ppo(self, args, tokenizer, val_dataset_size): ...
    def _load_for_grpo(self, args, tokenizer, val_dataset_size): ...
```

2. Register it in dataset_handlers.py:

```python
DatasetFactory.register_dataset('your_org/your_dataset', YourDataset, priority=10)
```

That's it! No changes needed in this file. The dataset will automatically be available
for all training methods through the generic load_datasets() function.
"""