from enum import StrEnum, unique

from .qa_datasets import (
    MAQA,
    AbstractQADataset,
    AmazonReviewsDataset,
    AmbigQA,
    CnnDailyMailDataset,
    FolkTextsDataset,
    GeminiNumericQuestionsDataset,
    GSM8KDataset,
    KGSummariesDataset,
    MAQA_Simplex,
    SATABENCHDataset,
    SimpleQAVerifiedDataset,
    STSBDataset,
    ToxigenDataset,
    TriviaQADataset,
    WMT19Dataset,
    XSumDataset,
)


@unique
class DatasetName(StrEnum):
    """String identifiers for the datasets that ``build_dataset`` can construct.

    Every member's value is spelled exactly like its name. Because this is a
    ``StrEnum``, members compare equal to the corresponding plain string, so
    either a member or its string value can be used as a dataset name.
    ``@unique`` makes the class definition fail if two members share a value.
    """

    TRIVIAQA = "TRIVIAQA"
    AMBIGQA = "AMBIGQA"
    MAQA = "MAQA"
    CNN_DAILYMAIL = "CNN_DAILYMAIL"
    XSUM = "XSUM"
    KG_SUMMARIES = "KG_SUMMARIES"
    SATA_BENCH = "SATA_BENCH"
    STSB = "STSB"
    AMAZON_REVIEWS = "AMAZON_REVIEWS"
    SIMPLE_QA_VERIFIED = "SIMPLE_QA_VERIFIED"
    # The two WMT19 entries map to the same dataset class in build_dataset,
    # differing only in the language pair ("de-en" vs. "fi-en").
    WMT19_DEEN = "WMT19_DEEN"
    WMT19_FIEN = "WMT19_FIEN"
    GSM8K = "GSM8K"
    GEMINI_NUMERICAL_QA = "GEMINI_NUMERICAL_QA"
    FOLKTEXTS = "FOLKTEXTS"
    MAQA_SIMPLEX = "MAQA_SIMPLEX"
    TOXIGEN = "TOXIGEN"
    # Built by build_dataset from the same class as GEMINI_NUMERICAL_QA, but
    # with an explicit dataset_path.
    GEMINI_QUERY_30 = "GEMINI_QUERY_30"


def build_dataset(dataset_name: str, limit: int | None = None) -> AbstractQADataset:
    """Construct and optionally trim a QA dataset by name.

    Args:
        dataset_name (str): Identifier for which dataset implementation to load
            (``"TRIVIAQA"``, ``"AMBIGQA"``, or ``"MAQA"``).
        limit (int, optional): Maximum number of samples to retain. When provided,
            the dataset is truncated to the first ``limit`` entries.

    Returns:
        AbstractQADataset: Instantiated dataset aligned to the requested name.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported options.
    """
    match dataset_name:
        case DatasetName.TRIVIAQA:
            dataset = TriviaQADataset()
        case DatasetName.AMBIGQA:
            dataset = AmbigQA()
        case DatasetName.MAQA:
            dataset = MAQA()
        case DatasetName.CNN_DAILYMAIL:
            dataset = CnnDailyMailDataset()
        case DatasetName.XSUM:
            dataset = XSumDataset()
        case DatasetName.KG_SUMMARIES:
            dataset = KGSummariesDataset()
        case DatasetName.SATA_BENCH:
            dataset = SATABENCHDataset()
        case DatasetName.STSB:
            dataset = STSBDataset()
        case DatasetName.AMAZON_REVIEWS:
            dataset = AmazonReviewsDataset()
        case DatasetName.SIMPLE_QA_VERIFIED:
            dataset = SimpleQAVerifiedDataset()
        case DatasetName.WMT19_DEEN:
            dataset = WMT19Dataset(language_pair="de-en")
        case DatasetName.WMT19_FIEN:
            dataset = WMT19Dataset(language_pair="fi-en")
        case DatasetName.GSM8K:
            dataset = GSM8KDataset()
        case DatasetName.GEMINI_NUMERICAL_QA:
            dataset = GeminiNumericQuestionsDataset()
        case DatasetName.FOLKTEXTS:
            dataset = FolkTextsDataset()
        case DatasetName.MAQA_SIMPLEX:
            dataset = MAQA_Simplex()
        case DatasetName.TOXIGEN:
            dataset = ToxigenDataset()
        case DatasetName.GEMINI_QUERY_30:
            dataset = GeminiNumericQuestionsDataset(
                dataset_path="numerical_questions/hf_dataset_queries_30"
            )
        case _:
            raise ValueError(f"Unknown dataset: {dataset_name}")

    if limit is not None:
        dataset.dataset = dataset.dataset.select(range(limit))
    return dataset
