"""

HuggingFaceDataset Class
=========================

TextAttack allows users to provide their own dataset or load from HuggingFace.


"""

import collections

import datasets

import textattack

from .dataset import Dataset


def _cb(s):
    """Return ``str(s)`` passed through TextAttack's colorizer as blue ANSI text."""
    text = str(s)
    return textattack.shared.utils.color_text(text, color="blue", method="ansi")


def get_datasets_dataset_columns(dataset):
    """Infer input/output columns from common dataset-hub schemas.

    The dataset's ``column_names`` are compared against a list of known
    layouts, in priority order. The first layout whose columns are all
    present in the dataset wins; extra columns in the dataset are ignored.

    Args:
        dataset: Any object exposing a ``column_names`` iterable of strings.

    Returns:
        A pair ``(input_columns, output_column)`` where ``input_columns`` is a
        tuple of column names and ``output_column`` is a single column name.

    Raises:
        ValueError: If no known layout matches the dataset's columns.
    """
    schema = set(dataset.column_names)

    # Order matters: two-input layouts must be tried before the single-input
    # layouts that share a column with them (e.g. "sentence").
    known_schemas = (
        (("premise", "hypothesis"), "label"),
        (("question", "sentence"), "label"),
        (("sentence1", "sentence2"), "label"),
        (("question1", "question2"), "label"),
        # Common schema for SQUAD dataset
        (("title", "context", "question"), "answers"),
        (("text",), "label"),
        (("sentence",), "label"),
        (("document",), "summary"),
        (("content",), "summary"),
        (("review",), "label"),
    )

    for input_columns, output_column in known_schemas:
        if output_column in schema and schema.issuperset(input_columns):
            return input_columns, output_column

    raise ValueError(
        f"Unsupported dataset schema {schema}. Try passing your own `dataset_columns` argument."
    )


class HuggingFaceDataset(Dataset):
    """Loads a dataset from 🤗 Datasets and prepares it as a TextAttack dataset.

    Args:
        name_or_dataset (:obj:`Union[str, datasets.Dataset]`):
            The dataset name as :obj:`str` or actual :obj:`datasets.Dataset` object.
            If it's your custom :obj:`datasets.Dataset` object, please pass the input and output columns via :obj:`dataset_columns` argument.
        subset (:obj:`str`, `optional`, defaults to :obj:`None`):
            The subset of the main dataset. Dataset will be loaded as :obj:`datasets.load_dataset(name, subset)`.
        split (:obj:`str`, `optional`, defaults to :obj:`"train"`):
            The split of the dataset.
        dataset_columns (:obj:`tuple(list[str], str))`, `optional`, defaults to :obj:`None`):
            Pair of :obj:`list[str]` representing list of input column names (e.g. :obj:`["premise", "hypothesis"]`)
            and :obj:`str` representing the output column name (e.g. :obj:`label`). If not set, we will try to automatically determine column names from known designs.
        label_map (:obj:`dict[int, int]`, `optional`, defaults to :obj:`None`):
            Mapping if output labels of the dataset should be re-mapped. Useful if model was trained with a different label arrangement.
            For example, if dataset's arrangement is 0 for `Negative` and 1 for `Positive`, but model's label
            arrangement is 1 for `Negative` and 0 for `Positive`, passing :obj:`{0: 1, 1: 0}` will remap the dataset's label to match with model's arrangements.
            Could also be used to remap literal labels to numerical labels (e.g. :obj:`{"positive": 1, "negative": 0}`).
        label_names (:obj:`list[str]`, `optional`, defaults to :obj:`None`):
            List of label names in corresponding order (e.g. :obj:`["World", "Sports", "Business", "Sci/Tech"]` for AG-News dataset).
            If not set, labels will be printed as is (e.g. "0", "1", ...). This should be set to :obj:`None` for non-classification datasets.
        output_scale_factor (:obj:`float`, `optional`, defaults to :obj:`None`):
            Factor to divide ground-truth outputs by. Generally, TextAttack goal functions require model outputs between 0 and 1.
            Some datasets are regression tasks, in which case this is necessary.
        shuffle (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether to shuffle the underlying dataset.

            .. note::
                Generally not recommended to shuffle the underlying dataset. Shuffling can be performed using DataLoader or by shuffling the order of indices we attack.
    """

    def __init__(
        self,
        name_or_dataset,
        subset=None,
        split="train",
        dataset_columns=None,
        label_map=None,
        label_names=None,
        output_scale_factor=None,
        shuffle=False,
    ):
        if isinstance(name_or_dataset, datasets.Dataset):
            self._dataset = name_or_dataset
        else:
            self._name = name_or_dataset
            self._subset = subset
            self._dataset = datasets.load_dataset(self._name, subset)[split]
            subset_print_str = f", subset {_cb(subset)}" if subset else ""
            textattack.shared.logger.info(
                f"Loading {_cb('datasets')} dataset {_cb(self._name)}{subset_print_str}, split {_cb(split)}."
            )
        # Input/output column order, like (('premise', 'hypothesis'), 'label')
        (
            self.input_columns,
            self.output_column,
        ) = dataset_columns or get_datasets_dataset_columns(self._dataset)

        if not isinstance(self.input_columns, (list, tuple)):
            raise ValueError(
                "First element of `dataset_columns` must be a list or a tuple."
            )

        self.label_map = label_map
        self.output_scale_factor = output_scale_factor
        if label_names:
            self.label_names = label_names
        else:
            try:
                self.label_names = self._dataset.features[self.output_column].names
            except (KeyError, AttributeError):
                # This happens when the dataset doesn't have 'features' or a 'label' column.
                self.label_names = None

        # If labels are remapped, the label names have to be remapped as well.
        # NOTE(review): entry k of the new list is the old name at index
        # `label_map[key]` for the k-th key in `label_map`'s iteration order.
        # That matches the remapping for a swap like {0: 1, 1: 0}; confirm it
        # is what is wanted for other mappings or key orders.
        if self.label_names and label_map:
            self.label_names = [
                self.label_names[self.label_map[i]] for i in self.label_map
            ]

        self.shuffled = shuffle
        if shuffle:
            # `datasets.Dataset.shuffle` returns a new dataset instead of
            # shuffling in place, so the result has to be kept.
            self._dataset = self._dataset.shuffle()

    def _format_as_dict(self, example):
        """Convert one raw example into an ``(inputs, output)`` pair.

        Args:
            example: Mapping from column name to value for a single row.

        Returns:
            A tuple ``(input_dict, output)``. ``input_dict`` is an
            ``OrderedDict`` holding the input columns in ``self.input_columns``
            order. ``output`` is the value of the output column, passed through
            ``self.label_map`` and divided by ``self.output_scale_factor`` when
            those are set.
        """
        input_dict = collections.OrderedDict(
            [(c, example[c]) for c in self.input_columns]
        )

        output = example[self.output_column]
        if self.label_map:
            output = self.label_map[output]
        if self.output_scale_factor:
            output = output / self.output_scale_factor

        return (input_dict, output)

    def filter_by_labels_(self, labels_to_keep):
        """Filter items by their labels for classification datasets. Performs
        in-place filtering.

        Args:
            labels_to_keep (:obj:`Union[Set, Tuple, List, Iterable]`):
                Set, tuple, list, or iterable of integers representing labels.
        """
        if not isinstance(labels_to_keep, set):
            labels_to_keep = set(labels_to_keep)
        self._dataset = self._dataset.filter(
            lambda x: x[self.output_column] in labels_to_keep
        )

    def __getitem__(self, i):
        """Return the i-th sample, or a list of samples when ``i`` is a slice.

        Slices follow normal Python sequence semantics: open-ended bounds,
        negative bounds and a step are all resolved against the dataset length.
        """
        if isinstance(i, slice):
            # `slice.indices` fills in missing bounds, clamps them to the
            # dataset length and honours the step.
            positions = range(*i.indices(len(self._dataset)))
            return [self._format_as_dict(self._dataset[j]) for j in positions]
        return self._format_as_dict(self._dataset[i])

    def shuffle(self):
        """Shuffle the underlying dataset and mark this dataset as shuffled."""
        # `datasets.Dataset.shuffle` returns a new dataset, so rebind it.
        self._dataset = self._dataset.shuffle()
        self.shuffled = True
