"""

TED Multi Translation Dataset Class
-----------------------------------
"""


import collections

import datasets
import numpy as np

from textattack.datasets import HuggingFaceDataset


class TedMultiTranslationDataset(HuggingFaceDataset):
    """Loads examples from the TED Talk multi-target translation dataset
    (``ted_multi``) using the `datasets` package.

    Each example pairs a talk's ``source_lang`` text, exposed as the single
    input column ``"Source"``, with its ``target_lang`` text as the output.

    dataset source: http://www.cs.jhu.edu/~kevinduh/a/multitarget-tedtalks/

    Args:
        source_lang (str): language code of the input text (default ``"en"``).
        target_lang (str): language code of the reference translation
            (default ``"de"``).
        split (str): dataset split to load (default ``"test"``).

    Raises:
        ValueError: if ``source_lang`` or ``target_lang`` does not occur in
            any example of the chosen split.
    """

    def __init__(self, source_lang="en", target_lang="de", split="test"):
        self._dataset = datasets.load_dataset("ted_multi")[split]
        self.examples = self._dataset["translations"]
        # The language list is stored per example (and varies from talk to
        # talk in ted_multi), so the valid choices are the union over the
        # whole split rather than just the first example's languages. This
        # also avoids an IndexError on an empty split.
        language_options = set()
        for example in self.examples:
            language_options.update(example["language"])
        if source_lang not in language_options:
            raise ValueError(
                f"Source language {source_lang} invalid. Choices: {sorted(language_options)}"
            )
        if target_lang not in language_options:
            raise ValueError(
                f"Target language {target_lang} invalid. Choices: {sorted(language_options)}"
            )
        self.source_lang = source_lang
        self.target_lang = target_lang
        # NOTE(review): HuggingFaceDataset.__init__ is not called, so state it
        # would normally set is absent. Presumably its iterator reads
        # ``self._i``; start it at the first example. Confirm against the
        # base class.
        self._i = 0

    def _format_raw_example(self, raw_example):
        """Convert one raw ``translations`` entry into ``(input_dict, target)``.

        Args:
            raw_example (dict): an entry of ``self.examples`` holding parallel
                ``"language"`` and ``"translation"`` sequences.

        Returns:
            tuple: ``(OrderedDict([("Source", source_text)]), target_text)``,
            using the first translation listed for each requested language.

        Raises:
            IndexError: if the example has no translation in ``source_lang``
                or ``target_lang``.
        """
        languages = list(raw_example["language"])
        translations = list(raw_example["translation"])
        source = self._translation_for(languages, translations, self.source_lang)
        target = self._translation_for(languages, translations, self.target_lang)
        source_dict = collections.OrderedDict([("Source", source)])
        return (source_dict, target)

    @staticmethod
    def _translation_for(languages, translations, lang):
        """Return the first entry of ``translations`` whose language is ``lang``."""
        try:
            return translations[languages.index(lang)]
        except ValueError:
            # Keep IndexError (the type the previous numpy-mask lookup raised)
            # but make the cause explicit.
            raise IndexError(
                f"Example has no '{lang}' translation "
                f"(available: {sorted(languages)})"
            ) from None
