"""
Data Loading and Preprocessing

Handles loading and preprocessing of preference datasets
(UltraFeedback, HH-RLHF, etc.)
"""

import torch
from datasets import load_dataset, Dataset
from typing import Dict, List, Optional
import numpy as np


class PreferenceDatasetLoader:
    """
    Load and preprocess preference datasets.

    Every loader returns a ``datasets.Dataset`` with string columns
    ``prompt``, ``chosen`` and ``rejected``; ``collate_fn`` turns a list of
    such rows into padded token tensors for the chosen and rejected
    continuations.
    """

    # Column names written by ``attach_difficulty_scores``.
    DIFFICULTY_KEYS = ('Csem', 'Upref', 'Rsem', 'Runc')

    # Completion fields tried, in order, when ranking UltraFeedback
    # completions. 'rating' is tried first so that data which already
    # carries that key is ranked exactly as before.
    # NOTE(review): 'fine-grained_score' / 'overall_score' follow the
    # openbmb/UltraFeedback dataset card -- confirm against the revision
    # actually loaded.
    ULTRAFEEDBACK_SCORE_KEYS = ('rating', 'fine-grained_score', 'overall_score')

    # Marker that introduces an assistant turn in HH-RLHF transcripts.
    HH_ASSISTANT_TAG = '\n\nAssistant:'

    def __init__(self, tokenizer, max_length: int = 512):
        """
        Args:
            tokenizer: Tokenizer for the model. It is called directly in
                ``collate_fn`` with ``padding``, ``truncation``,
                ``max_length`` and ``return_tensors='pt'``.
            max_length: Maximum tokenized length of prompt + response.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length

    @staticmethod
    def _subsample(dataset: Dataset, num_samples: Optional[int]) -> Dataset:
        """Keep the first ``num_samples`` rows, or every row when ``None``."""
        # Compare against None (not truthiness) so num_samples=0 yields an
        # empty dataset instead of silently loading the full split.
        if num_samples is None:
            return dataset
        return dataset.select(range(min(num_samples, len(dataset))))

    @classmethod
    def _completion_score(cls, completion: Dict) -> Optional[float]:
        """Return the first usable numeric score on a completion, else None."""
        for key in cls.ULTRAFEEDBACK_SCORE_KEYS:
            value = completion.get(key)
            if value is None:
                continue
            try:
                score = float(value)
            except (TypeError, ValueError):
                # Non-numeric value under this key; try the next one.
                continue
            if np.isnan(score):
                # NaN would make the ranking comparisons meaningless.
                continue
            return score
        return None

    @classmethod
    def _split_hh_transcript(cls, text: str) -> Optional[tuple]:
        """
        Split an HH-RLHF transcript at its final assistant turn.

        Returns:
            ``(prompt, response)``, where ``prompt`` is the whole dialogue up
            to and including the last assistant tag (followed by one space)
            and ``response`` is the stripped final assistant reply; ``None``
            when the text contains no assistant tag.
        """
        # rpartition splits on the LAST tag so that multi-turn dialogues keep
        # every earlier turn in the prompt.
        history, tag, reply = text.rpartition(cls.HH_ASSISTANT_TAG)
        if not tag:
            return None
        return history + tag + ' ', reply.strip()

    def load_ultrafeedback(
        self,
        split: str = "train",
        num_samples: Optional[int] = None
    ) -> Dataset:
        """
        Load UltraFeedback dataset as binary preference pairs.

        For each instruction, the highest-scored completion becomes
        ``chosen`` and the lowest-scored one ``rejected``. Rows are skipped
        when fewer than two completions have both a numeric score and a
        non-empty response, or when the best and worst scores tie (no
        preference signal).

        Args:
            split: 'train' or 'test'
            num_samples: Number of raw rows to read before filtering
                (None for all)

        Returns:
            Dataset with columns: prompt, chosen, rejected
        """
        dataset = self._subsample(
            load_dataset("openbmb/UltraFeedback", split=split), num_samples
        )

        processed = []
        for sample in dataset:
            scored = []
            for completion in sample.get('completions') or []:
                score = self._completion_score(completion)
                response = completion.get('response')
                if score is None or not response:
                    continue
                scored.append((score, response))

            if len(scored) < 2:
                continue

            # Stable sort: among equal scores, dataset order is preserved.
            scored.sort(key=lambda pair: pair[0], reverse=True)
            best_score, chosen = scored[0]
            worst_score, rejected = scored[-1]
            if best_score == worst_score:
                # Equally rated responses give no preference direction.
                continue

            processed.append({
                'prompt': sample['instruction'],
                'chosen': chosen,
                'rejected': rejected,
            })

        return Dataset.from_list(processed)

    def load_hh_rlhf(
        self,
        split: str = "train",
        num_samples: Optional[int] = None
    ) -> Dataset:
        """
        Load HH-RLHF dataset.

        The prompt is the shared dialogue history up to (and including) the
        final ``\\n\\nAssistant:`` tag of the chosen transcript; ``chosen``
        and ``rejected`` are the final assistant replies of the respective
        transcripts. Rows whose transcripts contain no assistant tag are
        skipped.

        Args:
            split: 'train' or 'test'
            num_samples: Number of raw rows to read (None for all)

        Returns:
            Dataset with columns: prompt, chosen, rejected
        """
        dataset = self._subsample(
            load_dataset("Anthropic/hh-rlhf", split=split), num_samples
        )

        processed = []
        for sample in dataset:
            chosen_parts = self._split_hh_transcript(sample['chosen'])
            rejected_parts = self._split_hh_transcript(sample['rejected'])
            if chosen_parts is None or rejected_parts is None:
                continue

            prompt, chosen_response = chosen_parts
            # NOTE(review): assumes the rejected transcript shares the chosen
            # transcript's history, so only its final reply is kept.
            _, rejected_response = rejected_parts

            processed.append({
                'prompt': prompt,
                'chosen': chosen_response,
                'rejected': rejected_response,
            })

        return Dataset.from_list(processed)

    def attach_difficulty_scores(
        self,
        dataset: Dataset,
        difficulty_scores: Dict[str, np.ndarray]
    ) -> Dataset:
        """
        Attach precomputed difficulty scores to dataset.

        Scores are matched to rows by position, so each array must have
        exactly one entry per row. Existing columns with the same names are
        overwritten.

        Args:
            dataset: Dataset
            difficulty_scores: Dict with keys 'Csem', 'Upref', 'Rsem', 'Runc',
                each a 1-D array-like of length ``len(dataset)``

        Returns:
            Dataset with difficulty scores attached as float columns

        Raises:
            KeyError: if one of the required keys is missing.
            ValueError: if a score array's length differs from the dataset's.
        """
        num_rows = len(dataset)
        columns = dataset.to_dict()
        for key in self.DIFFICULTY_KEYS:
            values = np.asarray(
                difficulty_scores[key], dtype=np.float64
            ).reshape(-1)
            if values.shape[0] != num_rows:
                # Positional alignment: a length mismatch would otherwise
                # crash mid-way or silently attach scores to the wrong rows.
                raise ValueError(
                    f"difficulty_scores[{key!r}] has {values.shape[0]} "
                    f"entries but the dataset has {num_rows} rows"
                )
            # tolist() yields plain Python floats, as float(...) did per row.
            columns[key] = values.tolist()

        return Dataset.from_dict(columns)

    def create_train_val_split(
        self,
        dataset: Dataset,
        val_ratio: float = 0.05,
        seed: int = 42
    ) -> tuple:
        """
        Split dataset into train and validation.

        Args:
            dataset: Full dataset
            val_ratio: Validation set ratio, passed to
                ``Dataset.train_test_split`` as ``test_size``
            seed: Random seed for the shuffle

        Returns:
            (train_dataset, val_dataset)
        """
        split = dataset.train_test_split(test_size=val_ratio, seed=seed)
        return split['train'], split['test']

    def _encode(self, texts: List[str]):
        """Tokenize ``texts`` into padded/truncated PyTorch tensors."""
        return self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )

    def collate_fn(self, batch: List[Dict]) -> Dict[str, torch.Tensor]:
        """
        Collate function for DataLoader.

        Each sample's prompt is concatenated directly with its chosen and
        rejected response, and the two sets of texts are tokenized
        separately.

        NOTE(review):
        - Chosen and rejected are padded independently, so their sequence
          dimensions may differ within a batch.
        - Truncation uses the tokenizer's configured side (right by default
          for HF tokenizers), which can cut the response rather than the
          prompt.
        - No separator is inserted between prompt and response. HH-RLHF
          prompts end with 'Assistant: ', but UltraFeedback instructions
          do not.

        Args:
            batch: List of samples with 'prompt', 'chosen', 'rejected' keys

        Returns:
            Batched tensors: input_ids / attention_mask for chosen and
            rejected sequences
        """
        prompts = [sample['prompt'] for sample in batch]
        chosen = [sample['chosen'] for sample in batch]
        rejected = [sample['rejected'] for sample in batch]

        chosen_encodings = self._encode(
            [prompt + response for prompt, response in zip(prompts, chosen)]
        )
        rejected_encodings = self._encode(
            [prompt + response for prompt, response in zip(prompts, rejected)]
        )

        return {
            'input_ids_chosen': chosen_encodings['input_ids'],
            'attention_mask_chosen': chosen_encodings['attention_mask'],
            'input_ids_rejected': rejected_encodings['input_ids'],
            'attention_mask_rejected': rejected_encodings['attention_mask'],
        }
