from re import split
from datasets import load_dataset, load_metric

from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM


def prep_dataset(
    dataset_name, dataset_config, dataset_split, size, tokenizer, switch_lang=False
):
    """Load a supported dataset and turn it into a list of ``(input, target)`` pairs.

    Args:
        dataset_name: One of ``"daily_dialog"``, ``"multi_woz_v22"``,
            ``"silicone"``, ``"movieqa"``, ``"wmt16"``, ``"news_commentary"``,
            ``"qanastek/EMEA-V3"``, ``"Helsinki-NLP/tatoeba_mt"``,
            ``"europarl_bilingual"``, ``"amazon_reviews_multi"`` or
            ``"multi_eurlex"``.
        dataset_config: Configuration name. Most branches pass it to
            ``load_dataset``. The translation branches split it on ``"-"`` to
            get the source/target language codes: wmt16-family, europarl and
            multi_eurlex, with europarl and multi_eurlex using it only for
            that.
        dataset_split: Name of the split to read, e.g. ``"train"``.
        size: Maximum number of pairs returned, applied as a ``[:size]`` slice.
        tokenizer: Only its ``eos_token`` is used. It is appended to dialogue
            turns for ``daily_dialog`` and ``multi_woz_v22``.
        switch_lang: For the translation datasets, swap the direction so pairs
            are ``(target_language_text, source_language_text)``.

    Returns:
        A list of ``(input_text, target_text)`` tuples.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset.
    """
    _dataset = []

    if dataset_name == "daily_dialog":
        dataset = load_dataset(dataset_name)
        dialogs = dataset[dataset_split][:size]["dialog"]
        for element in dialogs:
            acc = ""
            for sentence in element:
                acc += sentence + tokenizer.eos_token
                # NOTE(review): the input context already ends with `sentence`,
                # which is also the target of this pair; confirm this is
                # intended (multi_woz pairs a turn with the *next* turn).
                _dataset.append((acc, sentence))
        _dataset = _dataset[:size]

    elif dataset_name == "multi_woz_v22":
        dataset = load_dataset(dataset_name, dataset_config, ignore_verifications=True)
        for e in dataset[dataset_split]:
            utterances = e["turns"]["utterance"]
            # Pair each utterance with the utterance that immediately follows it.
            for sentence, reply in zip(utterances, utterances[1:]):
                _dataset.append((sentence + tokenizer.eos_token, reply))
        _dataset = _dataset[:size]

    elif dataset_name == "silicone":
        dataset = load_dataset(dataset_name, dataset_config, ignore_verifications=True)
        dataset = dataset[dataset_split]
        # Only the utterance is used; "a" is a constant placeholder target.
        _dataset = [(e["Utterance"], "a") for e in dataset][:size]

    elif dataset_name == "movieqa":
        dataset = load_dataset("wiki_movies", ignore_verifications=True)
        dataset = dataset[dataset_split]["text"][:size]
        for element in dataset:
            q, a = element.split("\t")
            # Drops the first character of the question; presumably a leading
            # line index in the raw text -- TODO confirm.
            q = q[1:]
            _dataset.append((q, a))

    elif dataset_name in ("wmt16", "news_commentary", "qanastek/EMEA-V3"):
        dataset = load_dataset(dataset_name, dataset_config)
        src, tgt = dataset_config.split("-")
        if switch_lang:
            src, tgt = tgt, src
        # Slice rows before selecting the column so only `size` rows are read.
        rows = dataset[dataset_split][:size]["translation"]
        _dataset = [(row[src], row[tgt]) for row in rows]

    elif dataset_name == "Helsinki-NLP/tatoeba_mt":
        dataset = load_dataset(dataset_name, dataset_config, ignore_verifications=True)
        rows = dataset[dataset_split][:size]
        sources, targets = rows["sourceString"], rows["targetString"]
        if switch_lang:
            sources, targets = targets, sources
        _dataset = list(zip(sources, targets))

    elif dataset_name == "europarl_bilingual":
        lang1, lang2 = dataset_config.split("-")
        dataset = load_dataset(
            dataset_name, lang1=lang1, lang2=lang2, ignore_verifications=True
        )
        # Swap only after loading: load_dataset needs the configured order.
        if switch_lang:
            lang1, lang2 = lang2, lang1
        rows = dataset[dataset_split][:size]["translation"]
        _dataset = [(row[lang1], row[lang2]) for row in rows]

    elif dataset_name == "amazon_reviews_multi":
        dataset = load_dataset(dataset_name, dataset_config, ignore_verifications=True)
        dataset = dataset[dataset_split]
        # NOTE(review): input and target are both the review title; confirm
        # this is intended rather than e.g. body -> title.
        _dataset = [(d["review_title"], d["review_title"]) for d in dataset][:size]

    elif dataset_name == "multi_eurlex":
        dataset = load_dataset(dataset_name, "all_languages", ignore_verifications=True)
        dataset = dataset[dataset_split]
        lang1, lang2 = dataset_config.split("-")
        # Honour switch_lang like the other translation datasets do.
        if switch_lang:
            lang1, lang2 = lang2, lang1
        # Skip documents missing text in either language; slice after filtering
        # so up to `size` usable pairs are returned.
        _dataset = [
            (d["text"][lang1], d["text"][lang2])
            for d in dataset
            if d["text"][lang1] and d["text"][lang2]
        ][:size]

    else:
        raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")

    return _dataset


def prep_model(model_name):
    """Load a pretrained model and its tokenizer by Hugging Face model name.

    The two DialoGPT checkpoints (``microsoft/DialoGPT-medium`` and
    ``tosin/dialogpt_mwoz``) are loaded as causal language models. Their
    tokenizer is set to truncate from the left with ``model_max_length = 50``.
    Any other name is loaded as a sequence-to-sequence model with an
    unmodified tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    dialog_checkpoints = ("microsoft/DialoGPT-medium", "tosin/dialogpt_mwoz")

    tok = AutoTokenizer.from_pretrained(model_name)
    if model_name not in dialog_checkpoints:
        return AutoModelForSeq2SeqLM.from_pretrained(model_name), tok

    # Truncate from the left so the end of a long dialogue history is kept.
    tok.truncation_side = "left"
    tok.model_max_length = 50
    causal_lm = AutoModelForCausalLM.from_pretrained(model_name)
    return causal_lm, tok


def prep_inputs(x, tokenizer, dataset_name):
    """Tokenize ``x`` into PyTorch tensors with truncation enabled.

    Args:
        x: Text (or batch of texts) to tokenize.
        tokenizer: Callable tokenizer invoked as
            ``tokenizer(x, return_tensors="pt", truncation=True)``.
        dataset_name: Kept for interface compatibility. Every dataset is
            tokenized the same way, so it does not affect the result.

    Returns:
        Whatever ``tokenizer`` returns for the call above.
    """
    # Previously branched on dataset_name, but both branches were identical.
    return tokenizer(x, return_tensors="pt", truncation=True)
