"""
Enhanced dataset handlers with truly extensible OOP structure.
All dataset-specific logic is contained within handler classes.
"""

from abc import ABC, abstractmethod
from typing import Dict, List, Any, Optional, Union, Tuple, Callable
from datasets import load_dataset, Dataset
from datasets.formatting.formatting import LazyRow
import torch
import re

class BaseDataset(ABC):
    """
    Abstract base class for all dataset handlers.

    A handler owns every dataset-specific step: locating the raw data
    (`load_raw_dataset`), turning a raw row into chat messages
    (`sample_to_input_dialogue`), per-method preprocessing
    (`preprocess_sft` / `preprocess_dpo` / `preprocess_ppo` / `preprocess_grpo`)
    and the train/validation assembly dispatched by `load_for_training`.
    """

    # Length assigned by `_compute_prompt_length` to prompts the chat template
    # cannot render, so that any length filter drops them.
    _UNTOKENIZABLE_PROMPT_LENGTH = 999999

    def __init__(self, tokenizer=None, dataset_num_proc: int = 1):
        """
        Initialize the dataset handler.

        Args:
            tokenizer: Tokenizer used by the PPO and prompt-length helpers. It may
                be left as None and supplied later (e.g. by `load_for_training`).
            dataset_num_proc: Number of worker processes passed as `num_proc`
                to the `Dataset.map` calls made by this handler.
        """
        self.tokenizer = tokenizer
        self.dataset_num_proc = dataset_num_proc

    @abstractmethod
    def get_dataset_pattern(self) -> str:
        """
        Return a pattern that matches this dataset.
        Can be an exact name or a regex pattern.
        """
        pass

    @abstractmethod
    def load_raw_dataset(self, split: str = 'train', **kwargs):
        """Load the raw dataset split from the HuggingFace Hub or a local path."""
        pass

    @abstractmethod
    def sample_to_input_dialogue(self, sample: Any, multi_turn: bool = False) -> List[Dict[str, str]]:
        """Convert a single raw sample into a list of chat messages (the model input)."""
        pass

    @abstractmethod
    def preprocess_sft(self, examples: Dict, multi_turn: bool = False) -> Dict:
        """Preprocess a batch of examples for SFT training."""
        pass

    @abstractmethod
    def preprocess_dpo(self, examples: Dict) -> Dict:
        """Preprocess a batch of examples for DPO training."""
        pass

    @abstractmethod
    def preprocess_ppo(self, element: Dict, multi_turn: bool = False) -> Dict:
        """
        Preprocess a single element for PPO training.

        `prepare_ppo_dataset` always passes `multi_turn`, so implementations must
        accept it. The result must contain an `input_ids` list and its `lengths`.
        """
        pass

    @abstractmethod
    def preprocess_grpo(self, examples: Dict) -> Dict:
        """Preprocess a batch of examples for GRPO training."""
        pass

    def load_evaluation_dataset(self, dataset_split: str = 'validation', max_prompt_length: Optional[int] = None, base_model_name: Optional[str] = None, multi_turn: bool = False, **kwargs):
        """
        Load a dataset split for evaluation, optionally dropping over-long prompts.

        Default implementation; subclasses may override it.

        Args:
            dataset_split: Dataset split to load.
            max_prompt_length: Optional maximum prompt length in tokens. Filtering
                only happens when both this and `base_model_name` are given.
            base_model_name: Model name used to load the tokenizer for filtering.
            multi_turn: Whether to build prompts in the multi-turn dialogue format.
            **kwargs: Forwarded to `load_raw_dataset`.

        Returns:
            The (possibly filtered) raw dataset, with its original columns unchanged.

        Note:
            When filtering is active, `self.tokenizer` is replaced by the
            tokenizer loaded from `base_model_name`.
        """
        dataset = self.load_raw_dataset(split=dataset_split, **kwargs)

        if max_prompt_length is None or base_model_name is None:
            return dataset

        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        original_size = len(dataset)

        def prompt_fits(sample) -> bool:
            # Measure the prompt without adding temporary columns, so a raw
            # column that happens to be named 'prompt' is left untouched.
            example = {'prompt': self.sample_to_input_dialogue(sample, multi_turn=multi_turn)}
            return self._compute_prompt_length(example)['prompt_length'] <= max_prompt_length

        dataset = dataset.filter(prompt_fits)
        print(f"Filtered eval dataset: {original_size - len(dataset)} samples removed (kept {len(dataset)}/{original_size})")

        return dataset

    def convert_sample_to_prompt(self, sample: Dict, multi_turn: bool) -> List[Dict[str, str]]:
        """
        Convert a sample to prompt format for evaluation.
        The default implementation delegates to `sample_to_input_dialogue`.
        """
        return self.sample_to_input_dialogue(sample, multi_turn=multi_turn)

    def load_normalized_samples(self, multi_turn: bool, max_prompt_length: Optional[int] = None, base_model_name: Optional[str] = None, data_dir: Optional[str] = None) -> List[Dict]:
        """
        Load the train split as a list of normalized samples.

        Default implementation; subclasses may override it.

        Args:
            multi_turn: Whether to use the multi-turn dialogue format.
            max_prompt_length: Optional maximum prompt length in tokens. Filtering
                only happens when both this and `base_model_name` are given.
            base_model_name: Model name used to load the tokenizer for filtering.
            data_dir: Optional data directory, forwarded to `load_raw_dataset`
                only when given (e.g. HH-RLHF subsets).

        Returns:
            A list of dicts, each with 'input' (chat messages) and 'raw_text'
            (the `str()` of the raw sample). Samples whose prompt cannot be
            tokenized are dropped when filtering, and a warning reports how many.
        """
        load_kwargs = {} if data_dir is None else {'data_dir': data_dir}
        ds = self.load_raw_dataset(split='train', **load_kwargs)
        normalized = [
            {
                'input': self.sample_to_input_dialogue(sample, multi_turn=multi_turn),
                'raw_text': str(sample),
            }
            for sample in ds
        ]

        if max_prompt_length is None or base_model_name is None:
            return normalized

        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        original_size = len(normalized)
        filtered = []
        failed = 0
        for sample in normalized:
            try:
                n_tokens = self._count_prompt_tokens(tokenizer, sample['input'])
            except Exception:
                # Best effort: drop samples the chat template cannot render,
                # but count them so the loss is visible instead of silent.
                failed += 1
                continue
            if n_tokens <= max_prompt_length:
                filtered.append(sample)
        normalized = filtered
        print(f"Filtered {original_size - len(normalized)} samples by prompt length (kept {len(normalized)}/{original_size})")
        if failed:
            print(f"Warning: {failed} samples were dropped because their prompt could not be tokenized")

        return normalized

    def prepare_ppo_dataset(self, dataset, max_length: int = 512, multi_turn: bool = False):
        """
        Prepare a dataset for PPO training.

        Each row is mapped through `preprocess_ppo` (receiving `multi_turn`).
        Rows whose query is empty or longer than `max_length` tokens are
        dropped, and the result is returned in torch format.

        Raises:
            ValueError: if no tokenizer is set, or if the first remaining query
                ends with the tokenizer's EOS token.
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer must be set to prepare PPO dataset")

        ds = dataset.map(
            self.preprocess_ppo,
            remove_columns=dataset.column_names,
            num_proc=self.dataset_num_proc,
            fn_kwargs={"multi_turn": multi_turn},
            load_from_cache_file=False,
        )

        # Empty queries give generation nothing to condition on, and an empty
        # input_ids row would make the EOS check below index an empty tensor.
        ds = ds.filter(lambda x: 0 < x["lengths"] <= max_length)
        ds.set_format(type="torch")

        if len(ds) > 0 and ds[0]["input_ids"][-1] == self.tokenizer.eos_token_id:
            raise ValueError("PPO dataset should not end with EOS token")

        return ds

    @staticmethod
    def _count_prompt_tokens(tokenizer, messages: List[Dict[str, str]]) -> int:
        """Return the token count of `messages` rendered with the chat template plus generation prompt."""
        token_ids = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            padding=False,
            add_generation_prompt=True,
        )
        return len(token_ids)

    def _compute_prompt_length(self, example: Dict) -> Dict:
        """
        Add the tokenized length of `example['prompt']` as `example['prompt_length']`.

        An empty prompt has length 0. A prompt the chat template fails to render
        gets `_UNTOKENIZABLE_PROMPT_LENGTH`, so length filters drop it, and a
        warning is printed.

        Args:
            example: Dict whose 'prompt' field is a list of message dicts.

        Returns:
            The same dict, mutated in place, with 'prompt_length' added.

        Raises:
            ValueError: if no tokenizer is set.
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer must be set to compute prompt lengths")

        prompt_messages = example['prompt']
        if not prompt_messages:
            example['prompt_length'] = 0
            return example

        try:
            example['prompt_length'] = self._count_prompt_tokens(self.tokenizer, prompt_messages)
        except Exception as e:
            print(f"Warning: Failed to tokenize prompt: {e}")
            example['prompt_length'] = self._UNTOKENIZABLE_PROMPT_LENGTH

        return example

    def load_for_training(self, method: str, args: Any, tokenizer: Any, val_dataset_size: int = 100) -> Tuple[Dataset, Dataset]:
        """
        Load and prepare the datasets for one training method.

        This is the main entry point. It stores `tokenizer` on the handler and
        dispatches to the matching `_load_for_<method>` implementation.

        Args:
            method: Training method ('sft', 'dpo', 'ppo', 'grpo').
            args: Arguments object with the dataset configuration.
            tokenizer: Tokenizer to use.
            val_dataset_size: Maximum validation dataset size.

        Returns:
            Tuple of (train_dataset, val_dataset).

        Raises:
            NotImplementedError: for an unknown method.
        """
        self.tokenizer = tokenizer

        loaders = {
            'sft': self._load_for_sft,
            'dpo': self._load_for_dpo,
            'ppo': self._load_for_ppo,
            'grpo': self._load_for_grpo,
        }
        loader = loaders.get(method)
        if loader is None:
            raise NotImplementedError(f"Method {method} not implemented for {self.get_dataset_pattern()}")
        return loader(args, tokenizer, val_dataset_size)

    @abstractmethod
    def _load_for_sft(self, args, tokenizer, val_dataset_size):
        """Load and prepare the dataset for SFT training."""
        pass

    @abstractmethod
    def _load_for_dpo(self, args, tokenizer, val_dataset_size):
        """Load and prepare the dataset for DPO training."""
        pass

    @abstractmethod
    def _load_for_ppo(self, args, tokenizer, val_dataset_size):
        """Load and prepare the dataset for PPO training."""
        pass

    @abstractmethod
    def _load_for_grpo(self, args, tokenizer, val_dataset_size):
        """Load and prepare the dataset for GRPO training."""
        pass


class SummarizeFromFeedbackDataset(BaseDataset):
    """
    Handler for OpenAI's Summarize from Feedback (Reddit TLDR) dataset.

    Uses the 'comparisons' config. Each row has an `info` dict describing the
    source text, a list of candidate `summaries`, and a `choice` index naming
    the summary used as the SFT target.
    """

    # Upper bound on the number of distinct posts kept for PPO evaluation.
    ppo_eval_max_samples: int = 32

    def get_dataset_pattern(self) -> str:
        return 'openai/summarize_from_feedback'

    def load_raw_dataset(self, split: str = 'train', **kwargs):
        """Load one split of the 'comparisons' config."""
        return load_dataset(self.get_dataset_pattern(), 'comparisons', split=split, **kwargs)

    def sample_to_input_dialogue(self, sample: Any, multi_turn: bool = False) -> List[Dict[str, str]]:
        """
        Build the single user turn asking for a TL;DR of the source text.

        `sample` may be a full row (dict/LazyRow with an 'info' field) or the
        info dict itself. `multi_turn` is accepted for interface compatibility
        and ignored. Rows with no subreddit, or whose site is 'dailymail' or
        'cnn', use the 'article' text and the site name as the subreddit.
        """
        if isinstance(sample, (dict, LazyRow)) and 'info' in sample:
            info = sample['info']
        else:
            info = sample

        subreddit = info.get('subreddit', '')
        title = info.get('title', '')
        post = info.get('post', '')
        site = info.get('site', '')

        if (subreddit is None) or (site in ['dailymail', 'cnn']):
            article = info.get('article', '')
            query_text = f"SUBREDDIT: r/{site}\n\nTITLE: {title}\n\nPOST: {article}\n\nTL;DR:"
        else:
            query_text = f"SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:"

        return [{'role': 'user', 'content': query_text}]

    def preprocess_sft(self, examples: Dict, multi_turn: bool = False) -> Dict:
        """Batched SFT mapping: the prompt is the TL;DR request and the completion is the chosen summary."""
        prompts = []
        completions = []
        for info, summaries, choice in zip(examples['info'], examples['summaries'], examples['choice']):
            prompts.append(self.sample_to_input_dialogue(info))
            completions.append([{'role': 'assistant', 'content': summaries[choice]['text']}])
        return {"prompt": prompts, "completion": completions}

    def preprocess_dpo(self, examples: Dict) -> Dict:
        # TLDR dataset doesn't have natural chosen/rejected pairs
        raise NotImplementedError("DPO preprocessing for TLDR dataset needs special handling")

    def preprocess_ppo(self, element: Dict, multi_turn: bool = False) -> Dict:
        """Tokenize one row's prompt (with generation prompt) into `input_ids`, `lengths` and the decoded `query`."""
        if self.tokenizer is None:
            raise ValueError("Tokenizer must be set for PPO preprocessing")

        messages = self.sample_to_input_dialogue(element)
        input_ids = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            padding=False,
            add_generation_prompt=True,
        )

        return {
            "input_ids": input_ids,
            "lengths": len(input_ids),
            "query": self.tokenizer.decode(input_ids)
        }

    def preprocess_grpo(self, examples: Dict) -> Dict:
        """Batched GRPO mapping: one prompt (message list) per row."""
        return {"prompt": [self.sample_to_input_dialogue(info) for info in examples['info']]}

    def _load_train_and_val(self, val_dataset_size):
        """Load the train split plus a seed-0 shuffled subset of at most `val_dataset_size` validation rows."""
        train_dataset = self.load_raw_dataset(split="train")
        val_dataset = self.load_raw_dataset(split="validation")
        val_dataset = val_dataset.shuffle(seed=0).select(
            range(min(val_dataset_size, len(val_dataset)))
        )
        return train_dataset, val_dataset

    def _map_sft(self, dataset):
        """Apply `preprocess_sft` (single-turn) to a raw split and set the torch format."""
        mapped = dataset.map(
            self.preprocess_sft,
            batched=True,
            remove_columns=dataset.column_names,
            num_proc=self.dataset_num_proc,
            fn_kwargs={"multi_turn": False},
        )
        mapped.set_format(type="torch", columns=["prompt", "completion"])
        return mapped

    def _filter_long_prompts(self, dataset, max_prompt_length, split_name):
        """Drop rows whose 'prompt' exceeds `max_prompt_length` tokens and log the count removed."""
        size_before = len(dataset)
        dataset = dataset.map(
            self._compute_prompt_length,
            num_proc=self.dataset_num_proc,
            desc=f"Computing prompt lengths for {split_name} set"
        )
        dataset = dataset.filter(
            lambda x: x["prompt_length"] <= max_prompt_length,
            desc=f"Filtering long prompts from {split_name} set"
        )
        print(f"Filtered {split_name}: {size_before - len(dataset)} samples (kept {len(dataset)}/{size_before})")
        # The helper column is only needed for filtering.
        return dataset.remove_columns(["prompt_length"])

    def _load_for_sft(self, args, tokenizer, val_dataset_size):
        """Load Reddit TLDR for SFT training."""
        train_raw, val_raw = self._load_train_and_val(val_dataset_size)
        train_dataset = self._map_sft(train_raw)
        val_dataset = self._map_sft(val_raw)

        print(f"\nSample TLDR SFT data:")
        print(f"Prompt: {train_dataset[0]['prompt']}")
        print(f"Completion: {train_dataset[0]['completion']}")
        print(f"Training samples: {len(train_dataset)}, Validation: {len(val_dataset)}")

        return train_dataset.shuffle(seed=0), val_dataset

    def _load_for_dpo(self, args, tokenizer, val_dataset_size):
        raise NotImplementedError("DPO not implemented for TLDR dataset")

    def _load_for_ppo(self, args, tokenizer, val_dataset_size):
        """
        Load Reddit TLDR for PPO training.

        The evaluation set is made of the first rows of the validation split
        (at most `2 * val_dataset_size` are scanned), keeping one row per
        distinct source text and at most `min(ppo_eval_max_samples,
        val_dataset_size)` rows.
        """
        train_dataset = self.load_raw_dataset(split="train")
        validation = self.load_raw_dataset(split="validation")
        eval_dataset_raw = validation.select(
            range(min(val_dataset_size * 2, len(validation)))
        )

        # Several comparison rows share the same source text; keep the first of each.
        max_unique = min(self.ppo_eval_max_samples, val_dataset_size)
        seen_texts = set()
        unique_indices = []
        for i, sample in enumerate(eval_dataset_raw):
            info = sample["info"]
            # News-article rows (see sample_to_input_dialogue) keep their text in
            # 'article'; presumably 'post' is empty there, so fall back to the
            # article rather than collapsing all of them onto one key.
            source_text = info.get("post") or info.get("article")
            if source_text not in seen_texts:
                seen_texts.add(source_text)
                unique_indices.append(i)
                if len(unique_indices) >= max_unique:
                    break

        val_dataset = eval_dataset_raw.select(unique_indices)
        print(f"Selected {len(val_dataset)} unique evaluation samples from Reddit TLDR")

        max_length = getattr(args, 'dataset_max_length', 512)
        train_dataset = self.prepare_ppo_dataset(train_dataset, max_length=max_length)
        val_dataset = self.prepare_ppo_dataset(val_dataset, max_length=max_length)

        return train_dataset, val_dataset

    def _load_for_grpo(self, args, tokenizer, val_dataset_size):
        """Load Reddit TLDR for GRPO training, dropping prompts longer than `args.max_prompt_length` (default 512)."""
        train_raw, val_raw = self._load_train_and_val(val_dataset_size)
        max_prompt_length = getattr(args, 'max_prompt_length', 512)

        train_dataset = train_raw.map(
            self.preprocess_grpo,
            batched=True,
            remove_columns=train_raw.column_names,
            num_proc=self.dataset_num_proc,
        )
        val_dataset = val_raw.map(
            self.preprocess_grpo,
            batched=True,
            remove_columns=val_raw.column_names,
            num_proc=self.dataset_num_proc,
        )

        print(f"\nFiltering GRPO samples with max_prompt_length={max_prompt_length}")
        train_dataset = self._filter_long_prompts(train_dataset, max_prompt_length, "train")
        val_dataset = self._filter_long_prompts(val_dataset, max_prompt_length, "val")

        train_dataset.set_format(type="torch", columns=["prompt"])
        val_dataset.set_format(type="torch", columns=["prompt"])

        print(f"\nSample GRPO prompt: {train_dataset[0]['prompt'] if len(train_dataset) > 0 else 'N/A'}")
        print(f"Training samples: {len(train_dataset)}, Validation: {len(val_dataset)}")

        return train_dataset.shuffle(seed=0), val_dataset


class AnthropicHHRLHFDataset(BaseDataset):
    """
    Handler for Anthropic's HH-RLHF dataset.

    Conversations are stored as plain text made of blank-line separated blocks
    prefixed with 'Human:' or 'Assistant:'. Only the 'chosen' conversation of
    each row is used by this handler.
    """

    # Block prefixes that open a new turn, mapped to chat roles.
    _SPEAKER_PREFIXES = (("Human:", "user"), ("Assistant:", "assistant"))

    def get_dataset_pattern(self) -> str:
        return 'Anthropic/hh-rlhf'

    def load_raw_dataset(self, split: str = 'train', **kwargs):
        """Load one split. A truthy `data_dir` kwarg (from args.dataset_dirs) selects a subset directory."""
        data_dir = kwargs.pop('data_dir', None)
        if data_dir:
            return load_dataset(self.get_dataset_pattern(), data_dir=data_dir, split=split, **kwargs)
        return load_dataset(self.get_dataset_pattern(), split=split, **kwargs)

    @staticmethod
    def _make_message(role: str, blocks: List[str]) -> Dict[str, str]:
        """Build one chat message, joining the turn's text blocks with single spaces."""
        return {"role": role, "content": ' '.join(blocks).strip()}

    def _parse_conversation(self, text: str) -> List[Dict[str, str]]:
        """
        Parse HH-RLHF conversation text into chat messages.

        The text is split on blank lines ('\\n\\n'). A block starting with
        'Human:' or 'Assistant:' opens a new user/assistant turn. Any other
        block is appended to the current turn, and blocks that appear before the
        first speaker prefix are ignored.
        """
        messages = []
        turn = None  # (role, [text blocks]) of the turn being accumulated
        for block in text.split('\n\n'):
            for prefix, role in self._SPEAKER_PREFIXES:
                if block.startswith(prefix):
                    if turn is not None:
                        messages.append(self._make_message(*turn))
                    turn = (role, [block[len(prefix):].strip()])
                    break
            else:
                if turn is not None:
                    turn[1].append(block.strip())

        if turn is not None:
            messages.append(self._make_message(*turn))

        return messages

    @staticmethod
    def _split_at_last_assistant(messages: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
        """
        Multi-turn split: (history before the last assistant turn, [that turn]).

        If there is no assistant turn, or the only one is at index 0, the split
        falls back to (all but the last message, [last message]). A single
        message is used as both prompt and completion.
        """
        last_assistant_idx = next(
            (i for i in reversed(range(len(messages))) if messages[i]["role"] == "assistant"),
            -1,
        )
        if last_assistant_idx > 0:
            return messages[:last_assistant_idx], [messages[last_assistant_idx]]
        prompt = messages[:-1] if len(messages) > 1 else messages
        return prompt, messages[-1:]

    @staticmethod
    def _first_assistant(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """Return [first assistant message] in `messages`, or [] if there is none."""
        return next(([msg] for msg in messages if msg["role"] == "assistant"), [])

    def _split_first_exchange(self, messages: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
        """
        Single-turn split: ([first user message], [first assistant reply after it]).

        The prompt is [] when there is no user message. If no assistant message
        follows the first user message, the first assistant message at index >= 1
        is used instead, or [] if there is none.
        """
        first_user_idx = next(
            (i for i, msg in enumerate(messages) if msg["role"] == "user"), None
        )
        if first_user_idx is None:
            prompt, after_user = [], []
        else:
            prompt = [messages[first_user_idx]]
            after_user = messages[first_user_idx + 1:]
        completion = self._first_assistant(after_user) or self._first_assistant(messages[1:])
        return prompt, completion

    def sample_to_input_dialogue(self, sample: Any, multi_turn: bool = False) -> List[Dict[str, str]]:
        """
        Return the prompt messages for one sample.

        `sample` is a row (dict/LazyRow, whose 'chosen' text is used) or the
        conversation text itself. In multi-turn mode this is the history before
        the last assistant turn. Otherwise it is the first user message, or the
        first message of any role if no user message exists.
        """
        text = sample['chosen'] if isinstance(sample, (dict, LazyRow)) else sample
        messages = self._parse_conversation(text)

        if multi_turn:
            prompt, _ = self._split_at_last_assistant(messages)
            return prompt

        for msg in messages:
            if msg["role"] == "user":
                return [msg]
        return messages[:1]

    def preprocess_sft(self, examples: Dict, multi_turn: bool = False) -> Dict:
        """Batched SFT mapping of 'chosen' conversations into prompt/completion message lists."""
        split = self._split_at_last_assistant if multi_turn else self._split_first_exchange
        prompts = []
        completions = []
        for chosen in examples['chosen']:
            prompt_messages, completion_messages = split(self._parse_conversation(chosen))
            prompts.append(prompt_messages)
            completions.append(completion_messages)
        return {"prompt": prompts, "completion": completions}

    def preprocess_dpo(self, examples: Dict) -> Dict:
        raise NotImplementedError("DPO not implemented for HH-RLHF dataset")

    def preprocess_ppo(self, element: Dict, multi_turn: bool = False) -> Dict:
        """
        Tokenize one row's prompt (with generation prompt) for PPO.

        An empty prompt yields empty `input_ids` with `lengths` 0.
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer must be set for PPO preprocessing")

        messages = self.sample_to_input_dialogue(element, multi_turn=multi_turn)

        if not messages:
            return {"input_ids": [], "lengths": 0, "query": ""}

        input_ids = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            padding=False,
            add_generation_prompt=True,
        )

        return {
            "input_ids": input_ids,
            "lengths": len(input_ids),
            "query": self.tokenizer.decode(input_ids)
        }

    def preprocess_grpo(self, examples: Dict, multi_turn: bool = False) -> Dict:
        """Batched GRPO mapping: one prompt (message list) per 'chosen' conversation."""
        return {
            "prompt": [
                self.sample_to_input_dialogue(chosen_text, multi_turn=multi_turn)
                for chosen_text in examples['chosen']
            ]
        }

    def _map_sft(self, dataset, multi_turn: bool):
        """Apply `preprocess_sft` to a raw split and set the torch format."""
        mapped = dataset.map(
            self.preprocess_sft,
            batched=True,
            remove_columns=dataset.column_names,
            num_proc=self.dataset_num_proc,
            fn_kwargs={"multi_turn": multi_turn},
        )
        mapped.set_format(type="torch", columns=["prompt", "completion"])
        return mapped

    def _load_for_sft(self, args, tokenizer, val_dataset_size):
        """Load HH-RLHF for SFT; validation is a seed-0 shuffled subset of the test split."""
        train_raw = self.load_raw_dataset(split="train")
        test_raw = self.load_raw_dataset(split="test")
        test_raw = test_raw.shuffle(seed=0).select(
            range(min(val_dataset_size, len(test_raw)))
        )

        multi_turn = getattr(args, 'multi_turn', False)
        train_dataset = self._map_sft(train_raw, multi_turn)
        val_dataset = self._map_sft(test_raw, multi_turn)

        # Prompts/completions are message lists; stringify before truncating the log line.
        print(f"\nSample HH-RLHF SFT data:")
        print(f"Prompt: {str(train_dataset[0]['prompt'])[:200]}...")
        print(f"Completion: {str(train_dataset[0]['completion'])[:200]}...")
        print(f"Training samples: {len(train_dataset)}, Validation: {len(val_dataset)}")
        return train_dataset.shuffle(seed=0), val_dataset

    def _load_for_dpo(self, args, tokenizer, val_dataset_size):
        """Load HH-RLHF for DPO training."""
        raise NotImplementedError("DPO not implemented for HH-RLHF dataset")

    def _load_for_ppo(self, args, tokenizer, val_dataset_size):
        """Load HH-RLHF for PPO; validation is the first `val_dataset_size` test rows."""
        train_dataset = self.load_raw_dataset(split="train")
        test_raw = self.load_raw_dataset(split="test")
        val_dataset = test_raw.select(range(min(val_dataset_size, len(test_raw))))

        max_length = getattr(args, 'dataset_max_length', 512)
        multi_turn = getattr(args, 'multi_turn', False)
        train_dataset = self.prepare_ppo_dataset(train_dataset, max_length=max_length, multi_turn=multi_turn)
        val_dataset = self.prepare_ppo_dataset(val_dataset, max_length=max_length, multi_turn=multi_turn)

        return train_dataset, val_dataset

    def _load_for_grpo(self, args, tokenizer, val_dataset_size):
        """
        Load HH-RLHF for GRPO training.

        Prompts longer than `args.max_prompt_length` tokens (default 512) are
        dropped to limit memory use. The rest is shuffled (seed 0) and split
        into train/validation. The validation size is
        `min(val_dataset_size, 1% of the data)`, with a minimum of 1.

        Raises:
            ValueError: if fewer than 2 prompts survive filtering.
        """
        dataset = self.load_raw_dataset(split="train", data_dir=getattr(args, 'dataset_dirs', None))
        multi_turn = getattr(args, 'multi_turn', False)

        dataset = dataset.map(
            self.preprocess_grpo,
            batched=True,
            remove_columns=dataset.column_names,
            num_proc=self.dataset_num_proc,
            fn_kwargs={"multi_turn": multi_turn},
        )

        max_prompt_length = getattr(args, 'max_prompt_length', 512)

        print(f"\nFiltering GRPO prompts with max_prompt_length={max_prompt_length} (multi_turn={multi_turn})")
        original_size = len(dataset)

        dataset = dataset.map(
            self._compute_prompt_length,
            num_proc=self.dataset_num_proc,
            desc="Computing prompt lengths"
        )
        dataset = dataset.filter(
            lambda x: x["prompt_length"] <= max_prompt_length,
            desc="Filtering long prompts"
        )

        filtered_size = len(dataset)
        print(f"Filtered {original_size - filtered_size} samples (kept {filtered_size}/{original_size})")

        dataset = dataset.remove_columns(["prompt_length"])
        dataset.set_format(type="torch", columns=["prompt"])

        print(f"\nSample GRPO prompt: {dataset[0]['prompt'] if len(dataset) > 0 else 'N/A'}")
        print(f"Training samples: {len(dataset)}")

        if len(dataset) < 2:
            raise ValueError(
                f"Need at least 2 GRPO samples to build a train/validation split, got {len(dataset)}"
            )

        dataset = dataset.shuffle(seed=0)
        # train_test_split rejects an integer test_size of 0, which 1% of a
        # small dataset (or val_dataset_size=0) would otherwise produce.
        test_size = max(1, min(val_dataset_size, int(0.01 * len(dataset))))
        dataset = dataset.train_test_split(test_size=test_size, seed=0)

        return dataset["train"], dataset["test"]

class StanfordHumanPreferencesDataset(BaseDataset):
    """Handler for Stanford Human Preferences (SHP) dataset.

    Raw rows provide a post (``history``), two human-written responses
    (``human_ref_A`` / ``human_ref_B``) and ``labels`` (1 when A is
    preferred, 0 when B is preferred).
    """

    def get_dataset_pattern(self) -> str:
        return 'stanfordnlp/shp'

    def load_raw_dataset(self, split: str = 'train', **kwargs):
        """Load one split of SHP from the HuggingFace Hub.

        Args:
            split: Split name forwarded to ``load_dataset``.
            **kwargs: Extra ``load_dataset`` arguments. ``data_dir`` (or its
                alias ``dataset_dirs``) selects a subreddit directory; ``'all'``
                or an empty value loads every subreddit.
        """
        # Pop BOTH aliases unconditionally so neither is forwarded to
        # load_dataset as an unknown keyword argument.
        data_dir = kwargs.pop('data_dir', None)
        dataset_dirs = kwargs.pop('dataset_dirs', None)
        data_dir = data_dir or dataset_dirs

        # The SHP dataset uses the data_dir parameter for subreddit selection
        if data_dir and data_dir != 'all':
            return load_dataset(self.get_dataset_pattern(), data_dir=data_dir, split=split, **kwargs)
        # Load all subreddits if 'all' or no specific dir
        return load_dataset(self.get_dataset_pattern(), split=split, **kwargs)

    @staticmethod
    def _get_max_prompt_length(args, default: int) -> int:
        """Return ``args.max_prompt_length``, or *default* when it is missing or None."""
        value = getattr(args, 'max_prompt_length', None)
        return default if value is None else value

    def _deduplicate_by_post_id(self, dataset, dataset_name: str = "dataset"):
        """
        Remove duplicate samples based on post_id.
        Keeps only the first occurrence of each unique post_id.

        Args:
            dataset: HuggingFace dataset to deduplicate
            dataset_name: Name for logging purposes (e.g., "train", "val")

        Returns:
            Deduplicated dataset (returned unchanged when it has no
            ``post_id`` column or no duplicates)
        """
        # Without a post_id column every row would share the key None and the
        # whole split would collapse to a single sample.
        if 'post_id' not in dataset.column_names:
            print(f"Skipping deduplication of {dataset_name}: no 'post_id' column")
            return dataset

        size_before = len(dataset)
        seen_post_ids = set()
        unique_indices = []

        # Read just the post_id column instead of decoding full rows, which
        # also carry the long post and response texts.
        for i, post_id in enumerate(dataset['post_id']):
            if post_id not in seen_post_ids:
                seen_post_ids.add(post_id)
                unique_indices.append(i)

        if len(unique_indices) == size_before:
            return dataset

        dataset = dataset.select(unique_indices)
        duplicates_removed = size_before - len(dataset)
        print(f"Deduplicated {dataset_name}: removed {duplicates_removed} duplicates (kept {len(dataset)}/{size_before})")
        return dataset

    def sample_to_input_dialogue(self, sample: Any, multi_turn: bool = False) -> List[Dict[str, str]]:
        """Convert a single sample to input dialogue format."""
        # For preference datasets, we use the history/post as the user query
        history = sample.get('history', '')
        return [{'role': 'user', 'content': history}]

    def preprocess_sft(self, examples: Dict, multi_turn: bool = False) -> Dict:
        """Preprocess a batch for SFT: the preferred response becomes the completion.

        ``multi_turn`` is accepted for interface compatibility; SHP is single-turn.
        """
        prompts = []
        completions = []

        for history, ref_a, ref_b, label in zip(
            examples['history'], examples['human_ref_A'], examples['human_ref_B'], examples['labels']
        ):
            # labels=1 means A is preferred, labels=0 means B is preferred
            preferred_response = ref_a if label == 1 else ref_b
            prompts.append([{'role': 'user', 'content': history}])
            completions.append([{'role': 'assistant', 'content': preferred_response}])

        return {"prompt": prompts, "completion": completions}

    def preprocess_dpo(self, examples: Dict) -> Dict:
        """Preprocess a batch for DPO in conversational explicit-prompt format.

        ``prompt`` holds the user turn, while ``chosen`` / ``rejected`` hold only
        the assistant turn. The trainer therefore applies the chat template
        once, and the prompt is not duplicated inside the responses. This
        matches the message-list format produced by ``preprocess_sft``.
        """
        prompts = []
        chosen = []
        rejected = []

        for history, ref_a, ref_b, label in zip(
            examples['history'], examples['human_ref_A'], examples['human_ref_B'], examples['labels']
        ):
            # labels=1 means A is preferred, labels=0 means B is preferred
            if label == 1:
                chosen_text, rejected_text = ref_a, ref_b
            else:
                chosen_text, rejected_text = ref_b, ref_a

            prompts.append([{'role': 'user', 'content': history}])
            chosen.append([{'role': 'assistant', 'content': chosen_text}])
            rejected.append([{'role': 'assistant', 'content': rejected_text}])

        return {"prompt": prompts, "chosen": chosen, "rejected": rejected}

    def preprocess_ppo(self, element: Dict, multi_turn: bool = False) -> Dict:
        """Tokenize one row's post as a generation-ready chat prompt for PPO.

        Returns ``input_ids``, their count as ``lengths``, and the decoded
        prompt as ``query``.

        Raises:
            ValueError: if no tokenizer is set on the handler.
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer must be set for PPO preprocessing")

        input_ids = self.tokenizer.apply_chat_template(
            self.sample_to_input_dialogue(element),
            tokenize=True,
            padding=False,
            add_generation_prompt=True,
        )

        return {
            "input_ids": input_ids,
            "lengths": len(input_ids),
            "query": self.tokenizer.decode(input_ids)
        }

    def preprocess_grpo(self, examples: Dict, multi_turn: bool = False) -> Dict:
        """Preprocess a batch for GRPO: one single-turn user prompt per post.

        ``multi_turn`` is accepted so callers passing it via ``fn_kwargs`` work;
        SHP is single-turn, so it is ignored.
        """
        return {"prompt": [[{'role': 'user', 'content': history}] for history in examples['history']]}

    def _load_splits(self, args, val_dataset_size, dedup_label: Optional[str] = None):
        """Load raw train/validation splits and cap the validation size.

        Args:
            args: Training arguments; ``args.dataset_dirs`` (if present) picks
                the subreddit directory.
            val_dataset_size: Maximum number of validation rows to keep.
            dedup_label: When given (e.g. ``"PPO"``), both splits are
                deduplicated by ``post_id`` and the label is used in the log.

        Returns:
            ``(train_dataset, val_dataset)`` before any preprocessing.
        """
        data_dir = getattr(args, 'dataset_dirs', None)
        train_dataset = self.load_raw_dataset(split="train", data_dir=data_dir)
        val_dataset = self.load_raw_dataset(split="validation", data_dir=data_dir)

        if dedup_label is not None:
            # Deduplicate based on post_id to ensure unique prompts
            print(f"\nDeduplicating {dedup_label} samples based on post_id...")
            train_dataset = self._deduplicate_by_post_id(train_dataset, "train")
            val_dataset = self._deduplicate_by_post_id(val_dataset, "val")

        # Shuffle before truncating so the kept rows are not just the head of
        # the split. SHP data is organised per subreddit directory, so an
        # unshuffled prefix likely covers few subreddits.
        val_dataset = val_dataset.shuffle(seed=0).select(
            range(min(val_dataset_size, len(val_dataset)))
        )
        return train_dataset, val_dataset

    def _map_and_filter_by_prompt_length(self, dataset, preprocess_fn, max_prompt_length: int,
                                         split_name: str, fn_kwargs: Optional[Dict] = None):
        """Apply *preprocess_fn* (batched) and drop rows whose prompt is too long.

        Uses the inherited ``_compute_prompt_length`` to add a temporary
        ``prompt_length`` column, which is removed before returning.
        """
        dataset = dataset.map(
            preprocess_fn,
            batched=True,
            remove_columns=dataset.column_names,
            num_proc=self.dataset_num_proc,
            fn_kwargs=fn_kwargs,
        )
        size_before = len(dataset)

        dataset = dataset.map(
            self._compute_prompt_length,
            num_proc=self.dataset_num_proc,
            desc=f"Computing prompt lengths for {split_name} set"
        )
        dataset = dataset.filter(
            lambda x: x["prompt_length"] <= max_prompt_length,
            desc=f"Filtering long prompts from {split_name} set"
        )

        print(f"Filtered {split_name}: {size_before - len(dataset)} samples (kept {len(dataset)}/{size_before})")
        return dataset.remove_columns(["prompt_length"])

    def _load_for_sft(self, args, tokenizer, val_dataset_size):
        """Load SHP dataset for SFT training."""
        # Default prompt limit for SFT is 300 tokens
        max_prompt_length = self._get_max_prompt_length(args, 300)
        train_dataset, val_dataset = self._load_splits(args, val_dataset_size)

        print(f"\nFiltering SFT samples with max_prompt_length={max_prompt_length}")
        train_dataset = self._map_and_filter_by_prompt_length(
            train_dataset, self.preprocess_sft, max_prompt_length, "train", fn_kwargs={"multi_turn": False}
        )
        val_dataset = self._map_and_filter_by_prompt_length(
            val_dataset, self.preprocess_sft, max_prompt_length, "val", fn_kwargs={"multi_turn": False}
        )

        train_dataset.set_format(type="torch", columns=["prompt", "completion"])
        val_dataset.set_format(type="torch", columns=["prompt", "completion"])

        print(f"\nSample SHP SFT data:")
        if len(train_dataset) > 0:
            print(f"Prompt: {str(train_dataset[0]['prompt'])[:200]}...")
            print(f"Completion: {str(train_dataset[0]['completion'])[:200]}...")
        print(f"Training samples: {len(train_dataset)}, Validation: {len(val_dataset)}")

        return train_dataset.shuffle(seed=0), val_dataset

    def _load_for_dpo(self, args, tokenizer, val_dataset_size):
        """Load SHP dataset for DPO training."""
        # Default prompt limit for DPO is 512 tokens
        max_prompt_length = self._get_max_prompt_length(args, 512)
        train_dataset, val_dataset = self._load_splits(args, val_dataset_size)

        print(f"\nFiltering DPO samples with max_prompt_length={max_prompt_length}")
        train_dataset = self._map_and_filter_by_prompt_length(
            train_dataset, self.preprocess_dpo, max_prompt_length, "train"
        )
        val_dataset = self._map_and_filter_by_prompt_length(
            val_dataset, self.preprocess_dpo, max_prompt_length, "val"
        )

        train_dataset.set_format(type="torch", columns=["prompt", "chosen", "rejected"])
        val_dataset.set_format(type="torch", columns=["prompt", "chosen", "rejected"])

        print(f"\nSample SHP DPO data:")
        if len(train_dataset) > 0:
            print(f"Prompt: {str(train_dataset[0]['prompt'])[:200]}...")
            print(f"Chosen: {str(train_dataset[0]['chosen'])[:200]}...")
            print(f"Rejected: {str(train_dataset[0]['rejected'])[:200]}...")
        print(f"Training samples: {len(train_dataset)}, Validation: {len(val_dataset)}")

        return train_dataset.shuffle(seed=0), val_dataset

    def _load_for_ppo(self, args, tokenizer, val_dataset_size):
        """Load SHP dataset for PPO training (deduplicated by post_id)."""
        # Default prompt limit for PPO is 512 tokens
        max_prompt_length = self._get_max_prompt_length(args, 512)
        train_dataset, val_dataset = self._load_splits(args, val_dataset_size, dedup_label="PPO")

        # Prepare datasets for PPO (includes filtering by length)
        train_dataset = self.prepare_ppo_dataset(train_dataset, max_length=max_prompt_length)
        val_dataset = self.prepare_ppo_dataset(val_dataset, max_length=max_prompt_length)

        print(f"\nSHP PPO dataset loaded with max_prompt_length={max_prompt_length}")
        print(f"Training samples: {len(train_dataset)}, Validation: {len(val_dataset)}")

        return train_dataset, val_dataset

    def _load_for_grpo(self, args, tokenizer, val_dataset_size):
        """Load SHP dataset for GRPO training (deduplicated by post_id)."""
        # Default prompt limit for GRPO is 300 tokens
        max_prompt_length = self._get_max_prompt_length(args, 300)
        train_dataset, val_dataset = self._load_splits(args, val_dataset_size, dedup_label="GRPO")

        print(f"\nFiltering GRPO samples with max_prompt_length={max_prompt_length}")
        train_dataset = self._map_and_filter_by_prompt_length(
            train_dataset, self.preprocess_grpo, max_prompt_length, "train"
        )
        val_dataset = self._map_and_filter_by_prompt_length(
            val_dataset, self.preprocess_grpo, max_prompt_length, "val"
        )

        train_dataset.set_format(type="torch", columns=["prompt"])
        val_dataset.set_format(type="torch", columns=["prompt"])

        print(f"\nSample GRPO prompt: {train_dataset[0]['prompt'] if len(train_dataset) > 0 else 'N/A'}")
        print(f"Training samples: {len(train_dataset)}, Validation: {len(val_dataset)}")

        return train_dataset.shuffle(seed=0), val_dataset

class UltraFeedbackDataset(BaseDataset):
    """Handler for OpenBMB UltraFeedback dataset.

    Each row has an ``instruction`` and a list of ``completions``. Every
    completion is a dict carrying a ``response`` and a ``fine-grained_score``.
    """

    def get_dataset_pattern(self) -> str:
        return 'openbmb/UltraFeedback'

    def load_raw_dataset(self, split: str = 'train', **kwargs):
        """Load UltraFeedback; *split* is ignored because only ``train`` exists upstream."""
        # UltraFeedback only has train split
        return load_dataset(self.get_dataset_pattern(), split='train', **kwargs)

    def sample_to_input_dialogue(self, sample: Any, multi_turn: bool = False) -> List[Dict[str, str]]:
        """Return the row's instruction as a single user message."""
        instruction = sample.get('instruction', '')
        return [{'role': 'user', 'content': instruction}]

    @staticmethod
    def _best_response(completions_list) -> Optional[str]:
        """Return the response with the highest ``fine-grained_score``.

        Completions with a missing or None score are never preferred. If no
        completion is scored, the first completion's response is used.
        Returns None when ``completions_list`` is empty or None.
        """
        best_completion = None
        best_score = -float('inf')

        for comp in completions_list or []:
            score = comp.get('fine-grained_score')
            # Skip unscored completions; comparing None with a float would raise.
            if score is None:
                continue
            if score > best_score:
                best_score = score
                best_completion = comp.get('response', '')

        if best_completion is None and completions_list:
            best_completion = completions_list[0].get('response', '')
        return best_completion

    def preprocess_sft(self, examples: Dict, multi_turn: bool = False) -> Dict:
        """Build conversational prompt/completion pairs from the best-scored response.

        Uses the same message-list format as the other handlers' SFT
        preprocessing. Rows without any usable completion are dropped; a
        batched ``map`` that removes all input columns may emit fewer rows.
        """
        prompts = []
        completions = []

        for instruction, completions_list in zip(examples['instruction'], examples['completions']):
            best_completion = self._best_response(completions_list)
            if best_completion is None:
                continue
            prompts.append([{'role': 'user', 'content': instruction}])
            completions.append([{'role': 'assistant', 'content': best_completion}])

        return {"prompt": prompts, "completion": completions}

    def preprocess_dpo(self, examples: Dict) -> Dict:
        # Could create chosen/rejected from best/worst completions
        raise NotImplementedError("DPO preprocessing for UltraFeedback not implemented")

    def preprocess_ppo(self, element: Dict, multi_turn: bool = False) -> Dict:
        """Tokenize one row's instruction as a generation-ready chat prompt for PPO.

        Raises:
            ValueError: if no tokenizer is set on the handler.
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer must be set for PPO preprocessing")

        messages = self.sample_to_input_dialogue(element)
        input_ids = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            padding=False,
            add_generation_prompt=True,
        )

        return {
            "input_ids": input_ids,
            "lengths": len(input_ids),
            "query": self.tokenizer.decode(input_ids)
        }

    def preprocess_grpo(self, examples: Dict, multi_turn: bool = False) -> Dict:
        """Wrap each instruction as a single-turn user prompt.

        ``multi_turn`` is accepted for interface compatibility and is ignored.
        """
        return {"prompt": [[{'role': 'user', 'content': instruction}] for instruction in examples['instruction']]}

    def _preprocess_sft_split(self, dataset):
        """Run ``preprocess_sft`` over *dataset* and set the torch output format."""
        processed = dataset.map(
            self.preprocess_sft,
            batched=True,
            remove_columns=dataset.column_names,
            num_proc=self.dataset_num_proc,
            fn_kwargs={"multi_turn": False}
        )
        processed.set_format(type="torch", columns=["prompt", "completion"])
        return processed

    def _load_for_sft(self, args, tokenizer, val_dataset_size):
        """Load UltraFeedback for SFT training."""
        full_dataset = self.load_raw_dataset().shuffle(seed=0)

        # train_test_split rejects an integer test_size of 0, so keep at least one row.
        test_size = max(1, min(val_dataset_size, int(0.01 * len(full_dataset))))
        split_dataset = full_dataset.train_test_split(test_size=test_size, seed=0)

        train_dataset = self._preprocess_sft_split(split_dataset["train"])
        val_dataset = self._preprocess_sft_split(split_dataset["test"])

        print(f"\nSample UltraFeedback SFT data:")
        if len(train_dataset) > 0:
            print(f"Prompt: {str(train_dataset[0]['prompt'])[:300]}...")
            print(f"Completion: {str(train_dataset[0]['completion'])[:200]}...")
        print(f"Training samples: {len(train_dataset)}, Validation: {len(val_dataset)}")

        return train_dataset.shuffle(seed=0), val_dataset

    def _load_for_dpo(self, args, tokenizer, val_dataset_size):
        raise NotImplementedError("DPO not implemented for UltraFeedback")

    def _load_for_ppo(self, args, tokenizer, val_dataset_size):
        raise NotImplementedError("PPO not implemented for UltraFeedback")

    def _load_for_grpo(self, args, tokenizer, val_dataset_size):
        raise NotImplementedError("GRPO not implemented for UltraFeedback")


class LocalAlpacaGPT4Dataset(BaseDataset):
    """Handler for local Alpaca GPT4 10k instruction-following dataset.

    Records are read from a local JSON file containing a list of objects with
    ``instruction``, ``input`` and ``output`` fields. Set the path with the
    ``file_path`` class/instance attribute or ``load_raw_dataset(file_path=...)``.
    """

    # Path of the JSON file. Empty by default and must be configured;
    # load_raw_dataset raises a clear ValueError rather than calling open("").
    file_path: str = ""

    # Fraction of records, taken from the start of the file, used as 'train';
    # the remainder is returned for 'validation' / 'test'.
    train_fraction: float = 0.8

    def get_dataset_pattern(self) -> str:
        return 'local_alpaca_gpt4_10k'

    def load_raw_dataset(self, split: str = 'train', **kwargs):
        """Load the dataset from the local JSON file and return the requested split.

        Args:
            split: ``'train'`` returns the first ``train_fraction`` of records,
                ``'validation'`` / ``'test'`` the rest, any other value all records.
            **kwargs: ``file_path`` overrides the configured path.

        Raises:
            ValueError: if no file path is configured.
        """
        import json

        file_path = kwargs.get('file_path') or self.file_path
        if not file_path:
            raise ValueError(
                "LocalAlpacaGPT4Dataset.file_path is not set; point it at the "
                "Alpaca GPT4 JSON file or pass file_path=... to load_raw_dataset"
            )

        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Convert to HuggingFace Dataset format
        dataset = Dataset.from_list(data)

        # Deterministic, unshuffled split: head for training, tail for evaluation.
        train_size = int(self.train_fraction * len(dataset))
        if split == 'train':
            return dataset.select(range(train_size))
        if split in ('validation', 'test'):
            return dataset.select(range(train_size, len(dataset)))
        return dataset

    @staticmethod
    def _build_query(instruction: str, input_text: Optional[str]) -> str:
        """Append the optional ``input`` to the instruction when it is non-blank."""
        if input_text and input_text.strip():
            return f"{instruction}\n\nInput: {input_text}"
        return instruction

    def sample_to_input_dialogue(self, sample: Any, multi_turn: bool = False) -> List[Dict[str, str]]:
        """Convert a single sample to a one-message user dialogue."""
        query_text = self._build_query(sample.get('instruction', ''), sample.get('input', ''))
        return [{'role': 'user', 'content': query_text}]

    def preprocess_sft(self, examples: Dict, multi_turn: bool = False) -> Dict:
        """Preprocess a batch into conversational prompt/completion pairs for SFT."""
        prompts = []
        completions = []

        for instruction, input_text, output_text in zip(
            examples['instruction'], examples['input'], examples['output']
        ):
            prompts.append([{'role': 'user', 'content': self._build_query(instruction, input_text)}])
            completions.append([{'role': 'assistant', 'content': output_text}])

        return {"prompt": prompts, "completion": completions}

    def preprocess_dpo(self, examples: Dict) -> Dict:
        """DPO not applicable for this dataset (no preference pairs)."""
        raise NotImplementedError("DPO not applicable for Alpaca GPT4 dataset - no preference pairs")

    def preprocess_ppo(self, element: Dict, multi_turn: bool = False) -> Dict:
        """Tokenize one row as a generation-ready chat prompt for PPO.

        Raises:
            ValueError: if no tokenizer is set on the handler.
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer must be set for PPO preprocessing")

        messages = self.sample_to_input_dialogue(element)

        input_ids = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            padding=False,
            add_generation_prompt=True,
        )

        return {
            "input_ids": input_ids,
            "lengths": len(input_ids),
            "query": self.tokenizer.decode(input_ids)
        }

    def preprocess_grpo(self, examples: Dict, multi_turn: bool = False) -> Dict:
        """Preprocess a batch into single-turn user prompts for GRPO.

        ``multi_turn`` is accepted for interface compatibility and is ignored.
        """
        prompts = [
            [{'role': 'user', 'content': self._build_query(instruction, input_text)}]
            for instruction, input_text in zip(examples['instruction'], examples['input'])
        ]
        return {"prompt": prompts}

    def _load_limited_splits(self, val_dataset_size):
        """Return the train split and a shuffled validation split capped at *val_dataset_size*."""
        train_dataset = self.load_raw_dataset(split="train")
        val_dataset = self.load_raw_dataset(split="validation")
        val_dataset = val_dataset.shuffle(seed=0).select(
            range(min(val_dataset_size, len(val_dataset)))
        )
        return train_dataset, val_dataset

    def _load_for_sft(self, args, tokenizer, val_dataset_size):
        """Load Alpaca GPT4 for SFT training."""
        train_dataset, val_dataset = self._load_limited_splits(val_dataset_size)

        train_dataset = train_dataset.map(
            self.preprocess_sft,
            batched=True,
            remove_columns=train_dataset.column_names,
            num_proc=self.dataset_num_proc,
            fn_kwargs={"multi_turn": False}
        )
        train_dataset.set_format(type="torch", columns=["prompt", "completion"])

        val_dataset = val_dataset.map(
            self.preprocess_sft,
            batched=True,
            remove_columns=val_dataset.column_names,
            num_proc=self.dataset_num_proc,
            fn_kwargs={"multi_turn": False}
        )
        val_dataset.set_format(type="torch", columns=["prompt", "completion"])

        print(f"\nSample Alpaca GPT4 SFT data:")
        if len(train_dataset) > 0:
            print(f"Prompt: {str(train_dataset[0]['prompt'])[:200]}...")
            print(f"Completion: {str(train_dataset[0]['completion'])[:200]}...")
        print(f"Training samples: {len(train_dataset)}, Validation: {len(val_dataset)}")

        return train_dataset.shuffle(seed=0), val_dataset

    def _load_for_dpo(self, args, tokenizer, val_dataset_size):
        """DPO not applicable for this dataset."""
        raise NotImplementedError("DPO not applicable for Alpaca GPT4 dataset - no preference pairs")

    def _load_for_ppo(self, args, tokenizer, val_dataset_size):
        """Load Alpaca GPT4 for PPO training."""
        # Honour args.max_prompt_length like the other loaders (default 512 tokens)
        max_prompt_length = getattr(args, 'max_prompt_length', None)
        if max_prompt_length is None:
            max_prompt_length = 512

        train_dataset, val_dataset = self._load_limited_splits(val_dataset_size)

        # Prepare datasets for PPO
        train_dataset = self.prepare_ppo_dataset(train_dataset, max_length=max_prompt_length)
        val_dataset = self.prepare_ppo_dataset(val_dataset, max_length=max_prompt_length)

        print(f"\nAlpaca GPT4 PPO dataset loaded")
        print(f"Training samples: {len(train_dataset)}, Validation: {len(val_dataset)}")

        return train_dataset, val_dataset

    def _preprocess_grpo_split(self, dataset, max_prompt_length: int, split_name: str):
        """Build GRPO prompts for *dataset* and drop those longer than *max_prompt_length*.

        Uses the inherited ``_compute_prompt_length`` for a temporary
        ``prompt_length`` column, which is removed before returning.
        """
        dataset = dataset.map(
            self.preprocess_grpo,
            batched=True,
            remove_columns=dataset.column_names,
            num_proc=self.dataset_num_proc
        )
        size_before = len(dataset)

        dataset = dataset.map(
            self._compute_prompt_length,
            num_proc=self.dataset_num_proc,
            desc=f"Computing prompt lengths for {split_name} set"
        )
        dataset = dataset.filter(
            lambda x: x["prompt_length"] <= max_prompt_length,
            desc=f"Filtering long prompts from {split_name} set"
        )

        print(f"Filtered {split_name}: {size_before - len(dataset)} samples (kept {len(dataset)}/{size_before})")
        dataset = dataset.remove_columns(["prompt_length"])
        dataset.set_format(type="torch", columns=["prompt"])
        return dataset

    def _load_for_grpo(self, args, tokenizer, val_dataset_size):
        """Load Alpaca GPT4 for GRPO training."""
        train_dataset, val_dataset = self._load_limited_splits(val_dataset_size)

        # Get max_prompt_length from args (default 512 tokens)
        max_prompt_length = getattr(args, 'max_prompt_length', None)
        if max_prompt_length is None:
            max_prompt_length = 512

        print(f"\nFiltering GRPO samples with max_prompt_length={max_prompt_length}")
        train_dataset = self._preprocess_grpo_split(train_dataset, max_prompt_length, "train")
        val_dataset = self._preprocess_grpo_split(val_dataset, max_prompt_length, "val")

        print(f"\nSample GRPO prompt: {train_dataset[0]['prompt'] if len(train_dataset) > 0 else 'N/A'}")
        print(f"Training samples: {len(train_dataset)}, Validation: {len(val_dataset)}")

        return train_dataset.shuffle(seed=0), val_dataset


class SkyworkRewardPreferenceDataset(BaseDataset):
    """Handler for the Skywork Reward Preference 80K dataset.

    Each raw sample has a ``chosen`` and a ``rejected`` conversation, given
    either as a list of ``{"role": ..., "content": ...}`` message dicts or as
    a JSON string encoding such a list. The HuggingFace dataset only has a
    ``train`` split, so every ``_load_for_*`` method carves its own
    validation split out of it.
    """

    def get_dataset_pattern(self) -> str:
        """Return the exact HuggingFace dataset id served by this handler."""
        return 'Skywork/Skywork-Reward-Preference-80K-v0.2'

    def load_raw_dataset(self, split: str = 'train', **kwargs):
        """Load the raw dataset from HuggingFace.

        Args:
            split: Ignored -- the Skywork dataset only has a ``train`` split,
                which is always the one returned.
            **kwargs: Forwarded to ``datasets.load_dataset``.
        """
        return load_dataset(self.get_dataset_pattern(), split='train', **kwargs)

    # ------------------------------------------------------------------
    # Conversation helpers
    # ------------------------------------------------------------------

    def _parse_conversation(self, conversation: Union[str, List]) -> List[Dict[str, str]]:
        """Normalise a raw conversation field into a list of message dicts.

        Args:
            conversation: A list of message dicts (returned unchanged), a JSON
                string encoding such a list, or any other value.

        Returns:
            The message list. A string that is not valid JSON becomes a single
            user message with that text; any other value (including JSON that
            does not decode to a list) is stringified into a single user message.
        """
        import json

        if isinstance(conversation, str):
            try:
                conversation = json.loads(conversation)
            except json.JSONDecodeError:
                # Plain text rather than serialized messages.
                return [{'role': 'user', 'content': conversation}]

        if isinstance(conversation, list):
            return conversation

        # Fallback for unexpected formats
        return [{'role': 'user', 'content': str(conversation)}]

    @staticmethod
    def _first_message_with_role(messages: List[Dict[str, str]], role: str) -> Optional[Dict[str, str]]:
        """Return the first message whose ``role`` equals *role*, or ``None``."""
        return next((msg for msg in messages if msg.get("role") == role), None)

    @staticmethod
    def _split_at_last_assistant(
        messages: List[Dict[str, str]],
    ) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
        """Split a conversation into ``(prompt, completion)`` at its last assistant turn.

        If the last assistant message sits at index > 0, the prompt is every
        message before it and the completion is that single message. Otherwise
        (no assistant message, or only one at index 0) the final message is the
        completion and the preceding messages are the prompt; a one-message
        conversation is used as both prompt and completion, and an empty one
        yields two empty lists.
        """
        last_assistant_idx = next(
            (i for i in range(len(messages) - 1, -1, -1)
             if messages[i].get("role") == "assistant"),
            -1,
        )
        if last_assistant_idx > 0:
            return messages[:last_assistant_idx], [messages[last_assistant_idx]]
        prompt = messages[:-1] if len(messages) > 1 else messages
        return prompt, messages[-1:]

    @staticmethod
    def _first_exchange(
        messages: List[Dict[str, str]],
    ) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
        """Return ``([first user msg], [first assistant msg after it])``.

        Either list is empty when the corresponding message is missing;
        assistant messages appearing before the first user message are ignored.
        """
        prompt_messages: List[Dict[str, str]] = []
        completion_messages: List[Dict[str, str]] = []
        for msg in messages:
            role = msg.get("role")
            if role == "user" and not prompt_messages:
                prompt_messages.append(msg)
            elif role == "assistant" and prompt_messages:
                completion_messages.append(msg)
                break
        return prompt_messages, completion_messages

    # ------------------------------------------------------------------
    # Per-sample / per-batch preprocessing
    # ------------------------------------------------------------------

    def sample_to_input_dialogue(self, sample: Any, multi_turn: bool = False) -> List[Dict[str, str]]:
        """Build the prompt dialogue for one raw sample from its ``chosen`` conversation.

        Args:
            sample: A raw dataset row supporting ``.get('chosen')``.
            multi_turn: If True, return every message before the last assistant
                turn; otherwise return only the first user message (or the first
                message if there is no user message).
        """
        messages = self._parse_conversation(sample.get('chosen', []))

        if multi_turn:
            prompt_messages, _ = self._split_at_last_assistant(messages)
            return prompt_messages

        first_user = self._first_message_with_role(messages, "user")
        if first_user is not None:
            return [first_user]
        return messages[:1]

    def preprocess_sft(self, examples: Dict, multi_turn: bool = False,
                       response_key: str = 'chosen') -> Dict:
        """Preprocess a batch for SFT training.

        Args:
            examples: Batched columns; must contain ``response_key``.
            multi_turn: If True, the prompt is the whole history before the last
                assistant turn and the completion is that turn; otherwise the
                prompt is the first user message and the completion the first
                assistant reply after it.
            response_key: Which conversation column to train on. Defaults to
                ``'chosen'``; pass ``'rejected'`` to deliberately fine-tune on the
                dispreferred responses.

        Returns:
            ``{"prompt": [...], "completion": [...]}`` with one message list each
            per input row.
        """
        prompts = []
        completions = []

        for raw_conversation in examples[response_key]:
            messages = self._parse_conversation(raw_conversation)
            if multi_turn:
                prompt_messages, completion_messages = self._split_at_last_assistant(messages)
            else:
                prompt_messages, completion_messages = self._first_exchange(messages)
            prompts.append(prompt_messages)
            completions.append(completion_messages)

        return {"prompt": prompts, "completion": completions}

    def preprocess_dpo(self, examples: Dict) -> Dict:
        """Preprocess a batch for DPO training.

        The prompt is the first user message of the chosen conversation (falling
        back to the rejected one). ``chosen``/``rejected`` are the first
        assistant replies of each conversation, rendered together with the
        prompt via the tokenizer's chat template, or a plain
        ``"User: ...\\n\\nAssistant: ..."`` string when no tokenizer is set.

        NOTE(review): ``prompt`` is a message list while ``chosen``/``rejected``
        are full rendered strings that already contain the prompt -- confirm
        this matches the format the DPO trainer in use expects.
        """
        prompts = []
        chosen_responses = []
        rejected_responses = []

        for raw_chosen, raw_rejected in zip(examples['chosen'], examples['rejected']):
            chosen_conv = self._parse_conversation(raw_chosen)
            rejected_conv = self._parse_conversation(raw_rejected)

            # Shared user query: prefer the chosen conversation, fall back to rejected.
            user_msg = self._first_message_with_role(chosen_conv, "user")
            if user_msg is None:
                user_msg = self._first_message_with_role(rejected_conv, "user")
            prompt_messages = [user_msg] if user_msg is not None else []

            chosen_msg = self._first_message_with_role(chosen_conv, "assistant")
            rejected_msg = self._first_message_with_role(rejected_conv, "assistant")
            chosen_text = chosen_msg.get("content", "") if chosen_msg is not None else ""
            rejected_text = rejected_msg.get("content", "") if rejected_msg is not None else ""

            if self.tokenizer:
                chosen_formatted = self.tokenizer.apply_chat_template(
                    prompt_messages + [{"role": "assistant", "content": chosen_text}],
                    tokenize=False
                )
                rejected_formatted = self.tokenizer.apply_chat_template(
                    prompt_messages + [{"role": "assistant", "content": rejected_text}],
                    tokenize=False
                )
            else:
                # Fallback formatting
                prompt_text = prompt_messages[0]["content"] if prompt_messages else ""
                chosen_formatted = f"User: {prompt_text}\n\nAssistant: {chosen_text}"
                rejected_formatted = f"User: {prompt_text}\n\nAssistant: {rejected_text}"

            prompts.append(prompt_messages)
            chosen_responses.append(chosen_formatted)
            rejected_responses.append(rejected_formatted)

        return {"prompt": prompts, "chosen": chosen_responses, "rejected": rejected_responses}

    def preprocess_ppo(self, element: Dict, multi_turn: bool = False) -> Dict:
        """Tokenize one raw sample into a PPO query.

        Args:
            element: A single raw sample (must provide ``.get('chosen')``).
            multi_turn: Passed through to ``sample_to_input_dialogue``.

        Returns:
            ``{"input_ids", "lengths", "query"}``, where ``query`` is the decoded
            prompt including the generation prompt. A sample with no usable
            messages yields empty ids, length 0 and an empty query.

        Raises:
            ValueError: If no tokenizer was configured.
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer must be set for PPO preprocessing")

        messages = self.sample_to_input_dialogue(element, multi_turn=multi_turn)

        if not messages:
            return {"input_ids": [], "lengths": 0, "query": ""}

        input_ids = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            padding=False,
            add_generation_prompt=True,
        )

        return {
            "input_ids": input_ids,
            "lengths": len(input_ids),
            "query": self.tokenizer.decode(input_ids)
        }

    def preprocess_grpo(self, examples: Dict) -> Dict:
        """Preprocess a batch for GRPO: the prompt is the chosen conversation's first user message.

        Rows without a user message get an empty prompt list.
        """
        prompts = []
        for raw_conversation in examples['chosen']:
            user_msg = self._first_message_with_role(
                self._parse_conversation(raw_conversation), "user"
            )
            prompts.append([user_msg] if user_msg is not None else [])
        return {"prompt": prompts}

    # ------------------------------------------------------------------
    # Loading pipeline helpers
    # ------------------------------------------------------------------

    def _split_raw_dataset(self, val_dataset_size):
        """Load, shuffle and split the raw data into ``train``/``test`` parts.

        The validation size is ``val_dataset_size`` capped at 1% of the data,
        and clamped to at least 1 because ``train_test_split`` rejects an
        integer ``test_size`` of 0.
        """
        full_dataset = self.load_raw_dataset().shuffle(seed=0)
        test_size = max(1, min(val_dataset_size, int(0.01 * len(full_dataset))))
        return full_dataset.train_test_split(test_size=test_size, seed=0)

    def _map_splits(self, split_dataset, preprocess_fn, **fn_kwargs):
        """Apply a batched preprocessing function to both splits, dropping the raw columns.

        Returns:
            ``(train_dataset, val_dataset)``.
        """
        mapped = []
        for split_name in ("train", "test"):
            raw_split = split_dataset[split_name]
            mapped.append(raw_split.map(
                preprocess_fn,
                batched=True,
                remove_columns=raw_split.column_names,
                num_proc=self.dataset_num_proc,
                fn_kwargs=fn_kwargs,
            ))
        return mapped[0], mapped[1]

    def _filter_long_prompts(self, train_dataset, val_dataset, max_prompt_length, label):
        """Drop samples whose prompt exceeds ``max_prompt_length`` and report counts.

        Lengths come from the inherited ``_compute_prompt_length``, which adds a
        ``prompt_length`` column; that column is removed again before returning.

        Returns:
            ``(train_dataset, val_dataset)`` after filtering.
        """
        print(f"\nFiltering {label} samples with max_prompt_length={max_prompt_length}")
        filtered = []
        for split_name, dataset in (("train", train_dataset), ("val", val_dataset)):
            size_before = len(dataset)
            dataset = dataset.map(
                self._compute_prompt_length,
                num_proc=self.dataset_num_proc,
                desc=f"Computing prompt lengths for {split_name} set"
            )
            dataset = dataset.filter(
                lambda x: x["prompt_length"] <= max_prompt_length,
                desc=f"Filtering long prompts from {split_name} set"
            )
            print(f"Filtered {split_name}: {size_before - len(dataset)} samples (kept {len(dataset)}/{size_before})")
            filtered.append(dataset.remove_columns(["prompt_length"]))
        return filtered[0], filtered[1]

    # ------------------------------------------------------------------
    # Public loaders
    # ------------------------------------------------------------------

    def _load_for_sft(self, args, tokenizer, val_dataset_size):
        """Load Skywork for SFT training.

        Reads ``args.max_prompt_length`` (default 512), ``args.multi_turn``
        (default False) and ``args.sft_response_key`` (default ``'chosen'``).

        Returns:
            ``(shuffled_train_dataset, val_dataset)``.
        """
        max_prompt_length = getattr(args, 'max_prompt_length', 512)
        multi_turn = getattr(args, 'multi_turn', False)
        response_key = getattr(args, 'sft_response_key', 'chosen')

        split_dataset = self._split_raw_dataset(val_dataset_size)
        train_dataset, val_dataset = self._map_splits(
            split_dataset,
            self.preprocess_sft,
            multi_turn=multi_turn,
            response_key=response_key,
        )
        train_dataset, val_dataset = self._filter_long_prompts(
            train_dataset, val_dataset, max_prompt_length, "SFT"
        )

        train_dataset.set_format(type="torch", columns=["prompt", "completion"])
        val_dataset.set_format(type="torch", columns=["prompt", "completion"])

        print("\nSample Skywork SFT data:")
        if len(train_dataset) > 0:
            print(f"Prompt: {train_dataset[0]['prompt'][:200]}...")
            print(f"Completion: {train_dataset[0]['completion'][:200]}...")
        print(f"Training samples: {len(train_dataset)}, Validation: {len(val_dataset)}")

        return train_dataset.shuffle(seed=0), val_dataset

    def _load_for_dpo(self, args, tokenizer, val_dataset_size):
        """Load Skywork for DPO training (``args.max_prompt_length``, default 512).

        Returns:
            ``(shuffled_train_dataset, val_dataset)``.
        """
        max_prompt_length = getattr(args, 'max_prompt_length', 512)

        split_dataset = self._split_raw_dataset(val_dataset_size)
        train_dataset, val_dataset = self._map_splits(split_dataset, self.preprocess_dpo)
        train_dataset, val_dataset = self._filter_long_prompts(
            train_dataset, val_dataset, max_prompt_length, "DPO"
        )

        train_dataset.set_format(type="torch", columns=["prompt", "chosen", "rejected"])
        val_dataset.set_format(type="torch", columns=["prompt", "chosen", "rejected"])

        print("\nSample Skywork DPO data:")
        if len(train_dataset) > 0:
            print(f"Prompt: {train_dataset[0]['prompt'][:200]}...")
            print(f"Chosen: {train_dataset[0]['chosen'][:200]}...")
            print(f"Rejected: {train_dataset[0]['rejected'][:200]}...")
        print(f"Training samples: {len(train_dataset)}, Validation: {len(val_dataset)}")

        return train_dataset.shuffle(seed=0), val_dataset

    def _load_for_ppo(self, args, tokenizer, val_dataset_size):
        """Load Skywork for PPO training.

        The prompt-length cap comes from ``args.dataset_max_length`` (default
        512) and is applied by the inherited ``prepare_ppo_dataset``.

        Returns:
            ``(train_dataset, val_dataset)``; the train split is not reshuffled.
        """
        max_prompt_length = getattr(args, 'dataset_max_length', 512)
        multi_turn = getattr(args, 'multi_turn', False)

        split_dataset = self._split_raw_dataset(val_dataset_size)

        train_dataset = self.prepare_ppo_dataset(
            split_dataset["train"],
            max_length=max_prompt_length,
            multi_turn=multi_turn
        )
        val_dataset = self.prepare_ppo_dataset(
            split_dataset["test"],
            max_length=max_prompt_length,
            multi_turn=multi_turn
        )

        print(f"\nSkywork PPO dataset loaded with max_prompt_length={max_prompt_length}")
        print(f"Training samples: {len(train_dataset)}, Validation: {len(val_dataset)}")

        return train_dataset, val_dataset

    def _load_for_grpo(self, args, tokenizer, val_dataset_size):
        """Load Skywork for GRPO training (``args.max_prompt_length``, default 512).

        Returns:
            ``(shuffled_train_dataset, val_dataset)``.
        """
        max_prompt_length = getattr(args, 'max_prompt_length', 512)

        split_dataset = self._split_raw_dataset(val_dataset_size)
        train_dataset, val_dataset = self._map_splits(split_dataset, self.preprocess_grpo)
        train_dataset, val_dataset = self._filter_long_prompts(
            train_dataset, val_dataset, max_prompt_length, "GRPO"
        )

        train_dataset.set_format(type="torch", columns=["prompt"])
        val_dataset.set_format(type="torch", columns=["prompt"])

        print(f"\nSample GRPO prompt: {train_dataset[0]['prompt'] if len(train_dataset) > 0 else 'N/A'}")
        print(f"Training samples: {len(train_dataset)}, Validation: {len(val_dataset)}")

        return train_dataset.shuffle(seed=0), val_dataset


class DatasetFactory:
    """
    Factory class for creating dataset handlers.

    Handlers are looked up by exact dataset name in a class-level registry;
    ``register_dataset`` adds entries to that shared registry.
    """

    # Main registry - exact dataset name -> BaseDataset subclass.
    _datasets: Dict[str, type] = {
        'openai/summarize_from_feedback': SummarizeFromFeedbackDataset,
        'Anthropic/hh-rlhf': AnthropicHHRLHFDataset,
        'openbmb/UltraFeedback': UltraFeedbackDataset,
        'stanfordnlp/shp': StanfordHumanPreferencesDataset,
        'local_alpaca_gpt4_10k': LocalAlpacaGPT4Dataset,
        'Skywork/Skywork-Reward-Preference-80K-v0.2': SkyworkRewardPreferenceDataset,
    }

    @classmethod
    def register_dataset(cls, dataset_name: str, handler_class: type):
        """
        Register (or replace) a dataset handler.

        Args:
            dataset_name: Exact name of the dataset
            handler_class: Class that handles this dataset (must inherit from BaseDataset)

        Raises:
            ValueError: If ``handler_class`` is not a class deriving from BaseDataset.
        """
        # Check isinstance first: issubclass() raises TypeError on non-class input.
        if not (isinstance(handler_class, type) and issubclass(handler_class, BaseDataset)):
            raise ValueError(
                f"Handler class must inherit from BaseDataset, got {handler_class!r}"
            )
        cls._datasets[dataset_name] = handler_class

    @classmethod
    def get_dataset_handler(cls, dataset_name: str, tokenizer=None, **kwargs) -> BaseDataset:
        """
        Get a dataset handler instance.

        Args:
            dataset_name: Exact registered name of the dataset
            tokenizer: Tokenizer to use
            **kwargs: Additional arguments forwarded to the handler constructor

        Returns:
            Dataset handler instance

        Raises:
            ValueError: If dataset is not supported
        """
        handler_class = cls._datasets.get(dataset_name)
        if handler_class is not None:
            return handler_class(tokenizer=tokenizer, **kwargs)

        available = list(cls._datasets.keys())
        raise ValueError(
            f"No handler for dataset '{dataset_name}'.\n"
            f"Available datasets: {available}"
        )

    @classmethod
    def list_available_datasets(cls) -> List[str]:
        """Get list of registered dataset names, in registration order."""
        return list(cls._datasets.keys())