import transformers
from datasets import (
    get_dataset_split_names,
    load_dataset,
    concatenate_datasets
)
from loguru import logger

# Short dataset name -> Hugging Face Hub dataset id, as passed to
# `load_dataset` / `get_dataset_split_names` by `from_name`.
#
# "imdb-large" is not a separate Hub repo: it points at "imdb", and
# `from_name` special-cases it to pool the train and unsupervised splits.
#
# NOTE(review): "toxigen/toxigen-data" appears to expose several configs on
# the Hub; loading it without a config name may fail — confirm.
DATASET_REGISTRY = dict(
    [
        ("rotten_tomatoes", "rotten_tomatoes"),
        ("sst2", "stanfordnlp/sst2"),
        ("yelp_review_full", "yelp_review_full"),
        ("imdb", "imdb"),
        ("imdb-large", "imdb"),
        ("wiki_toxic", "OxAISH-AL-LLM/wiki_toxic"),
        ("toxigen", "toxigen/toxigen-data"),
        ("bias_in_bios", "LabHC/bias_in_bios"),
        ("polarity", "fancyzhx/amazon_polarity"),
        ("emotion", "dair-ai/emotion"),
        ("snli", "stanfordnlp/snli"),
        ("medical", "medical_questions_pairs"),
    ]
)

# Hub ids whose text lives in a differently named column: hub id -> column.
_TEXT_COLUMN_ALIASES = {
    "OxAISH-AL-LLM/wiki_toxic": "comment_text",
    "fancyzhx/amazon_polarity": "content",
    "stanfordnlp/sst2": "sentence",
    "LabHC/bias_in_bios": "hard_text",
}

# Sentence-pair hub ids: hub id -> (first column, second column). The two
# columns are joined with a single space into one "text" string.
_PAIR_COLUMNS = {
    "stanfordnlp/snli": ("premise", "hypothesis"),
    "medical_questions_pairs": ("question_1", "question_2"),
}


def _to_text_only(hf_name: str, dataset):
    """Return *dataset* reduced to a single ``text`` column.

    Args:
        hf_name: Hub id the dataset was loaded from; selects the column
            mapping from ``_TEXT_COLUMN_ALIASES`` / ``_PAIR_COLUMNS``.
        dataset: A loaded dataset split.
    """
    if hf_name in _TEXT_COLUMN_ALIASES:
        dataset = dataset.rename_column(_TEXT_COLUMN_ALIASES[hf_name], "text")
    elif hf_name in _PAIR_COLUMNS:
        first, second = _PAIR_COLUMNS[hf_name]

        def join_pair(batch):
            # Build a fresh column instead of mutating the batch in place.
            return {
                "text": [a + " " + b for a, b in zip(batch[first], batch[second])]
            }

        # Drop the source columns during the map so they are not carried
        # through to the output only to be discarded below.
        dataset = dataset.map(
            join_pair, batched=True, remove_columns=dataset.column_names
        )
    return dataset.select_columns(["text"])


def from_name(name: str) -> dict:
    """Load a registered dataset as text-only splits.

    Args:
        name: A key of ``DATASET_REGISTRY``.

    Returns:
        A dict mapping split name to a dataset with exactly one column,
        ``text``. For ``"imdb-large"`` the splits are ``"train"`` (IMDB train
        followed by the unsupervised reviews) and ``"test"``. For every other
        name, they are the splits returned by ``get_dataset_split_names``.

    Raises:
        ValueError: If ``name`` is not a key of ``DATASET_REGISTRY``.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if name not in DATASET_REGISTRY:
        raise ValueError(
            f"Unknown dataset {name!r}; expected one of {sorted(DATASET_REGISTRY)}"
        )
    hf_name = DATASET_REGISTRY[name]
    splits = get_dataset_split_names(hf_name)
    logger.info(f"Loading {hf_name} — splits {splits}")

    data = dict()
    if name == "imdb-large":
        # Reduce both inputs to "text" *before* concatenating so they share one
        # schema; the labels are discarded in the output anyway.
        train = load_dataset(hf_name, split="train").select_columns(["text"])
        unsupervised = load_dataset(hf_name, split="unsupervised").select_columns(
            ["text"]
        )
        data["train"] = concatenate_datasets([train, unsupervised])
        data["test"] = load_dataset(hf_name, split="test").select_columns(["text"])
    else:
        for split in splits:
            data[split] = _to_text_only(hf_name, load_dataset(hf_name, split=split))

    # Log every produced split, including the imdb-large ones.
    for split, dataset in data.items():
        logger.info(
            f"Extract split: {split} | {dataset.shape} | {dataset.column_names}"
        )

    return data
