"""Dataset classes for EBM training.

This module provides dataset wrappers for CSV files and HuggingFace datasets,
along with a custom collate function for use with a PyTorch DataLoader. The
classes are designed to handle specific column structures for prompts, responses,
and various types of negative samples.
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Tuple, List, Union, Any

import pandas as pd
from torch.utils.data import Dataset
from .config import DataConfig

# Create a logger for this module
logger = logging.getLogger(__name__)


class PromptResponseDataset(Dataset):
    """Wraps a HuggingFace Dataset into a PyTorch-compatible format.

    The prompt and response columns named by the config are required. The
    human and GPT-2 columns are optional: when one is absent a warning is
    logged and that field is filled with empty strings, so `get_field` keeps
    working.

    Response and human entries are normalized while loading: a non-empty list
    contributes its first element, a string is kept unchanged, and anything
    else becomes an empty string. Prompt and GPT-2 values are stored exactly
    as the dataset returns them.

    Attributes:
        config (DataConfig): The configuration holding the column names.
        prompts: The prompt column as returned by `hf_ds[prompt_col]`.
        responses (List): The normalized golden responses.
    """

    def __init__(self,
                 hf_ds: Any,
                 data_config: DataConfig) -> None:
        """Initializes the PromptResponseDataset.

        Args:
            hf_ds (datasets.Dataset): The HuggingFace Dataset object to wrap.
                Only `column_names`, column indexing and `len()` are used.
            data_config (DataConfig): Configuration specifying column names.

        Raises:
            ValueError: If the required prompt or response columns are not found.
        """
        self.config = data_config

        required_cols = {self.config.prompt_col, self.config.response_col}
        if not required_cols.issubset(hf_ds.column_names):
            missing = required_cols - set(hf_ds.column_names)
            msg = f"Dataset must contain the following columns: {missing}"
            raise ValueError(msg)

        self.prompts = hf_ds[self.config.prompt_col]

        self._LLM_answers = [
            self._first_text(sample)
            for sample in hf_ds[self.config.response_col]
        ]
        self.responses = self._LLM_answers

        # Optional columns are loaded once here so that 'get_field' is a
        # plain list lookup.
        if self.config.human_col in hf_ds.column_names:
            self._human_answers = [
                self._first_text(sample)
                for sample in hf_ds[self.config.human_col]
            ]
        else:
            logger.warning(
                f"Human answers column ('{self.config.human_col}') not found. "
                "PromptResponseDataset.get_field() will return empty strings for this field."
            )
            self._human_answers = [""] * len(hf_ds)

        if self.config.gpt2_col in hf_ds.column_names:
            self._gpt2_answers = hf_ds[self.config.gpt2_col]
        else:
            logger.warning(
                f"GPT-2 column ('{self.config.gpt2_col}') not found. "
                "PromptResponseDataset.get_field() will return empty strings for this field."
            )
            self._gpt2_answers = [""] * len(hf_ds)

    @staticmethod
    def _first_text(sample: Any) -> Any:
        """Reduces one raw cell to a single answer.

        A non-empty list yields its first element, a string is returned
        unchanged, and every other value (including an empty list) yields "".
        """
        if isinstance(sample, list) and sample:
            return sample[0]
        if isinstance(sample, str):
            return sample
        return ""

    @staticmethod
    def _same_field(name_lower: str, column: Any) -> bool:
        """Tells whether a lowercased field name refers to a configured column.

        The configured name is lowercased too; otherwise a column such as
        "GPT2" could never be matched by a case-insensitive lookup. A
        non-string configured name never matches.
        """
        return isinstance(column, str) and column.lower() == name_lower

    def __len__(self) -> int:
        """Returns the total number of samples in the dataset."""
        return len(self.prompts)

    def __getitem__(self, idx: int) -> Tuple[int, str, str]:
        """Retrieves a sample from the dataset by index.

        Returns:
            Tuple[int, str, str]: The index itself, the prompt and the
            golden response stored at that index.
        """
        return idx, self.prompts[idx], self.responses[idx]

    def get_field(self, idx: int, name: str) -> str:
        """Retrieves the raw value from a supported field for a given index.

        Only the configured human and GPT-2 columns are supported.

        Args:
            idx (int): The row index of the sample.
            name (str): The case-insensitive field name to retrieve.

        Returns:
            str: The value stored for the specified field and row.

        Raises:
            KeyError: If `name` matches neither the human nor the GPT-2 column.
        """
        name_lower = name.lower()
        if self._same_field(name_lower, self.config.human_col):
            return self._human_answers[idx]
        if self._same_field(name_lower, self.config.gpt2_col):
            return self._gpt2_answers[idx]

        msg = f"Field '{name}' is not supported by PromptResponseDataset.get_field()."
        raise KeyError(msg)
    

class CSVDataset(Dataset):
    """Wraps a CSV file or pandas DataFrame into a PyTorch Dataset.

    The prompt and response columns named by the config are required and are
    read into plain lists of strings when the dataset is built. Missing values
    in those columns become empty strings, matching what `get_field` returns
    for a missing cell. Every other column stays reachable through
    `get_field`.

    Attributes:
        df (pd.DataFrame): The pandas DataFrame holding the dataset.
        config (DataConfig): The configuration holding the column names.
        prompts (List[str]): A list of all prompts.
        responses (List[str]): A list of all golden responses.
    """

    def __init__(self,
                 data: Union[str, Path, pd.DataFrame],
                 data_config: DataConfig) -> None:
        """Initializes the CSVDataset.

        Args:
            data (Union[str, Path, pd.DataFrame]): A path to a CSV file or a
                pre-loaded pandas DataFrame. A DataFrame is stored as given,
                without copying.
            data_config (DataConfig): Configuration object specifying column names.

        Raises:
            TypeError: If the input data is not a path or DataFrame.
            ValueError: If the configured prompt and response columns are
                not present in the data.
        """
        if isinstance(data, (str, Path)):
            self.df = pd.read_csv(data)
        elif isinstance(data, pd.DataFrame):
            self.df = data
        else:
            msg = "Input must be a pandas DataFrame or a path to a CSV file."
            raise TypeError(msg)

        self.config = data_config

        required_cols = {self.config.prompt_col, self.config.response_col}
        if not required_cols.issubset(self.df.columns):
            msg = f"CSV must contain '{self.config.prompt_col}' and '{self.config.response_col}' columns."
            raise ValueError(msg)

        self.prompts = self._text_column(self.config.prompt_col)
        self.responses = self._text_column(self.config.response_col)

    def _text_column(self, column: str) -> List[str]:
        """Returns one column as a list of strings, with "" for missing cells.

        A bare `astype(str)` would turn missing values into the literal text
        "nan"/"None", so the missing cells are blanked out explicitly.
        """
        series = self.df[column]
        return series.astype(str).where(series.notna(), "").tolist()

    def __len__(self) -> int:
        """Returns the total number of samples in the dataset."""
        return len(self.prompts)

    def __getitem__(self, idx: int) -> Tuple[int, str, str]:
        """Retrieves a sample from the dataset by index.

        Returns:
            Tuple[int, str, str]: The index itself, the prompt and the
            golden response stored at that index.
        """
        return idx, self.prompts[idx], self.responses[idx]

    def get_field(self, idx: int, name: str) -> str:
        """Retrieves the raw value from a specific column for a given index.

        Any column of the DataFrame can be read; the column name is matched
        case-insensitively and the first matching column wins.

        Args:
            idx (int): The positional row index of the sample.
            name (str): The case-insensitive column name to retrieve.

        Returns:
            str: The value from the specified column and row as a string.
                 Returns an empty string for NaN values.

        Raises:
            KeyError: If no column matches `name`.
        """
        name_lower = name.lower()
        # str() keeps the comparison safe for non-string column labels.
        position = next(
            (pos for pos, col in enumerate(self.df.columns)
             if str(col).lower() == name_lower),
            None,
        )
        if position is None:
            msg = f"Column '{name}' not found in CSV."
            raise KeyError(msg)

        # Read the single cell by position. Building the whole row first
        # would coerce its values to one common dtype (an int next to float
        # columns would come back as a float).
        value = self.df.iloc[idx, position]
        # Only scalars can be "missing"; pd.isna on a list-like cell would
        # return an array whose truth value is ambiguous.
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return ""
        # Convert to str so callers never receive numpy types.
        return str(value)


class HFDataset(Dataset):
    """Wraps a HuggingFace Dataset into a PyTorch-compatible format.

    The prompt, response and human columns named by the config are read
    unconditionally. The GPT-2 column is optional: when it is absent a warning
    is logged and the field is filled with empty strings.

    Response and human entries are normalized while loading: a non-empty list
    contributes its first element, a string is kept unchanged, and anything
    else becomes an empty string. Prompt and GPT-2 values are stored exactly
    as the dataset returns them.

    Attributes:
        config (DataConfig): The configuration holding the column names.
        prompts: The prompt column as returned by `hf_ds[prompt_col]`.
        responses (List): The normalized golden responses.
    """

    def __init__(self,
                 hf_ds: Any,
                 data_config: DataConfig) -> None:
        """Initializes the HFDataset.

        Args:
            hf_ds (datasets.Dataset): The HuggingFace Dataset object to wrap.
                Only `column_names` and column indexing are used.
            data_config (DataConfig): Configuration object specifying column names
                (e.g., prompt_col, response_col).
        """
        self.config = data_config

        self.prompts = hf_ds[self.config.prompt_col]

        self._LLM_answers = [
            self._first_text(sample)
            for sample in hf_ds[self.config.response_col]
        ]
        self.responses = self._LLM_answers

        # Loaded once here so that 'get_field' is a plain list lookup.
        self._human_answers = [
            self._first_text(sample)
            for sample in hf_ds[self.config.human_col]
        ]

        # Get GPT-2 answers directly from the dataset's column
        if self.config.gpt2_col in hf_ds.column_names:
            self._gpt2_answers = hf_ds[self.config.gpt2_col]
        else:
            # Fallback in case the column is missing
            logger.warning(
                "GPT2 column not found in HuggingFace dataset. Filling with empty strings."
            )
            self._gpt2_answers = [""] * len(self.prompts)

    @staticmethod
    def _first_text(sample: Any) -> Any:
        """Reduces one raw cell to a single answer.

        A non-empty list yields its first element, a string is returned
        unchanged, and every other value (including an empty list) yields "".
        """
        if isinstance(sample, list) and sample:
            return sample[0]
        if isinstance(sample, str):
            return sample
        return ""

    @staticmethod
    def _same_field(name_lower: str, column: Any) -> bool:
        """Tells whether a lowercased field name refers to a configured column.

        The configured name is lowercased too; otherwise a column such as
        "GPT2" could never be matched by a case-insensitive lookup. A
        non-string configured name never matches.
        """
        return isinstance(column, str) and column.lower() == name_lower

    def __len__(self) -> int:
        """Returns the total number of samples in the dataset."""
        return len(self.prompts)

    def __getitem__(self, idx: int) -> Tuple[int, str, str]:
        """Retrieves a sample from the dataset by index.

        Returns:
            Tuple[int, str, str]: The index itself, the prompt and the
            golden response stored at that index.
        """
        return idx, self.prompts[idx], self.responses[idx]

    def get_field(self, idx: int, name: str) -> str:
        """Retrieves the raw value from a supported field for a given index.

        Only the configured GPT-2 and human columns are supported.

        Args:
            idx (int): The row index of the sample.
            name (str): The case-insensitive field name to retrieve (e.g., "gpt2").

        Returns:
            str: The value stored for the specified field and row.

        Raises:
            KeyError: If `name` matches neither the GPT-2 nor the human column.
        """
        name_lower = name.lower()
        if self._same_field(name_lower, self.config.gpt2_col):
            return self._gpt2_answers[idx]
        if self._same_field(name_lower, self.config.human_col):
            return self._human_answers[idx]

        msg = f"Field '{name}' is not supported by HFDataset.get_field()."
        raise KeyError(msg)


def collate(
    batch: List[Tuple[int, str, str]]
) -> Tuple[List[int], List[str], List[str]]:
    """Collates samples from the Dataset into a batch.

    This function is designed to be used with a PyTorch DataLoader. It takes a
    list of individual samples (each a tuple of index, prompt, response) and
    restructures it into three separate lists for indices, prompts, and responses.

    Args:
        batch (List[Tuple[int, str, str]]): A list of samples, where each
            sample is a tuple from the dataset's __getitem__ method.

    Returns:
        Tuple[List[int], List[str], List[str]]: A tuple containing a list of
        all indices, a list of all prompts, and a list of all responses.
        An empty batch yields three empty lists.
    """
    # zip(*[]) produces nothing, so unpacking it into three names would raise
    # ValueError; an empty batch is answered explicitly instead.
    if not batch:
        return [], [], []

    indices, prompts, responses = zip(*batch)
    return list(indices), list(prompts), list(responses)
