"""
This package includes functions for tokenizing all datasets available in our repo
"""

import functools
import os
from copy import deepcopy
from itertools import chain
from typing import Dict, Tuple, Union

import numpy as np

# from mistral_common.protocol.instruct.messages import UserMessage
# from mistral_common.protocol.instruct.request import ChatCompletionRequest
# from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
import torch
from torch import col_indices_copy
from transformers import (
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
)

from datasets import Dataset, DatasetDict, load_dataset

from .constants import (
    BASE_DIR,
    GLUE_DATASETS,
    SUPERGLUE_DATASETS,
    TASK_TO_KEYS,
    TRAIN_SPLIT,
    VAL_SPLIT,
)


def preprocess_boolq(row):
    """
    Builds the text-to-text prompt for a single BoolQ example

    The prompt uses the same `<task> <key>:<value> ...` layout (keys in
    alphabetical order) as the other SuperGLUE preprocessors in this module.
    Previously this returned an empty string, so every BoolQ example was
    tokenized as an empty prompt.

    Args:
        row: a mapping with at least the `passage` and `question` keys.

    Returns:
        The prompt string for this example.
    """
    return f"boolq passage:{row['passage']} question:{row['question']}"


def preprocess_wic(row):
    """
    Builds the text-to-text prompt for a single WiC example

    Args:
        row: a mapping with the `sentence1`, `sentence2` and `word` keys.

    Returns:
        The prompt string `wic sentence1:... sentence2:... word:...`.
    """
    first, second, target = row["sentence1"], row["sentence2"], row["word"]
    return "wic sentence1:{} sentence2:{} word:{}".format(first, second, target)


def preprocess_cb(row):
    """
    Builds the text-to-text prompt for a single CB example

    Args:
        row: a mapping with the `hypothesis` and `premise` keys.

    Returns:
        The prompt string `cb hypothesis:... premise:...`.
    """
    hypothesis, premise = row["hypothesis"], row["premise"]
    return "cb hypothesis:{} premise:{}".format(hypothesis, premise)


def preprocess_copa(row):
    """
    Builds the text-to-text prompt for a single COPA example

    Args:
        row: a mapping with the `choice1`, `choice2`, `premise` and
            `question` keys.

    Returns:
        The prompt string
        `copa choice1:... choice2:... premise:... question:...`.
    """
    # the fields are emitted in this fixed order, separated by single spaces
    field_order = ("choice1", "choice2", "premise", "question")
    body = " ".join(f"{name}:{row[name]}" for name in field_order)
    return "copa " + body


# Prompt templates taken from the T5 paper / repo:
# https://arxiv.org/pdf/1910.10683 (google-research/text-to-text-transfer-transformer)
#
# Maps a task name to the function that turns ONE example row into a single
# prompt string. `tokenize_glue` uses this registry to decide whether a task
# goes through a prompt builder or through the plain sentence-pair path.
SUPERGLUE_PROCESSORS = {
    "wic": preprocess_wic,
    "cb": preprocess_cb,
    "copa": preprocess_copa,
    "boolq": preprocess_boolq,
}


def tokenize_glue(
    tokenizer: PreTrainedTokenizer,
    task: str,
    model: PreTrainedModel,
    should_pad: bool,
    max_length: int,
    full_train=False,
):
    """
    Tokenizes a GLUE (or SuperGLUE) dataset for a given task

    Besides tokenizing, this function overwrites `model.config.label2id` and
    `model.config.id2label` for every non-regression task so that the model's
    label mapping agrees with the dataset.

    Args:
        tokenizer: tokenizer applied to the prompts / sentence pairs.
        task: task name. `boolq` is loaded from `google/boolq`, tasks listed in
            `SUPERGLUE_DATASETS` from `super_glue`, everything else from
            `nyu-mll/glue`.
        model: model whose `config` label mappings are read and updated.
        should_pad: forwarded to the tokenizer as `padding`.
        max_length: forwarded to the tokenizer as `max_length` (truncation is
            always enabled).
        full_train: when true, use the full `train` split and the task's
            validation split (`validation_matched` for `mnli`); otherwise carve
            both sets out of `train` using the percentages in `TRAIN_SPLIT` and
            `VAL_SPLIT`.

    Returns:
        A `(processed_train, processed_eval)` tuple of tokenized datasets with
        a `labels` column.

    Raises:
        RuntimeError: if an example does not contain the label column.
    """
    if full_train:
        split = [
            "train",
            "validation_matched" if task == "mnli" else "validation",
        ]
    else:
        split = [
            f"train[:{TRAIN_SPLIT[task]}%]",
            f"train[-{VAL_SPLIT[task]}%:]",
        ]

    if task == "boolq":
        train_dataset, eval_dataset = load_dataset(
            "google/boolq",
            split=split,
        )
    else:
        train_dataset, eval_dataset = load_dataset(
            "super_glue" if task in SUPERGLUE_DATASETS else "nyu-mll/glue",
            task,
            split=split,
        )

    assert isinstance(train_dataset, Dataset)
    assert isinstance(eval_dataset, Dataset)

    # this is from the LoftQ study for preprocessing the dataset, see here:
    # https://github.com/yxli2123/LoftQ/blob/ed5ba19e285b598109c7915586434d81e3d34748/glue/run_glue.py#L323
    is_regression = task == "stsb"
    label_key = "answer" if task == "boolq" else "label"
    # only populated for classification tasks; stays None for regression
    label_list = None
    if task == "boolq":
        num_labels = 2
        label_list = [0, 1]
    elif not is_regression:
        label_list = train_dataset.features[label_key].names
        num_labels = len(label_list)
    else:
        num_labels = 1

    label_to_id = None
    if (
        model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
        and not is_regression
    ):
        # the model ships a non-default mapping: reuse it if its label names
        # match the dataset's label names
        label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
        if sorted(label_name_to_id.keys()) == sorted(label_list):
            label_to_id = {
                i: label_name_to_id[label_list[i]] for i in range(num_labels)
            }
    elif not is_regression:
        label_to_id = {v: i for i, v in enumerate(label_list)}

    # TODO: analyze the coverage of these cases better
    if not is_regression:
        if label_to_id is None:
            # the model's mapping could not be matched; fall back to the
            # dataset's own label order
            label_to_id = {label: i for i, label in enumerate(label_list)}
        model.config.label2id = label_to_id
        model.config.id2label = {
            idx: label for label, idx in model.config.label2id.items()
        }

    sentence1, sentence2 = TASK_TO_KEYS.get(task, (None, None))

    # this is also from LoftQ, but is almost the exact same as the setup from
    # Amazon's design spaces paper
    def preprocess_function(examples):
        # Tokenize the texts
        if task in SUPERGLUE_PROCESSORS:
            args_ex = SUPERGLUE_PROCESSORS[task](examples)
            result = tokenizer(
                # no "wsc" processor is registered in SUPERGLUE_PROCESSORS in
                # this module, so the first arm is only reachable if one is added
                args_ex["inputs"] if task == "wsc" else args_ex,
                padding=should_pad,
                max_length=max_length,
                truncation=True,
            )
        else:
            args_ex = (
                (examples[sentence1],)
                if sentence2 is None
                else (examples[sentence1], examples[sentence2])
            )
            result = tokenizer(
                *args_ex, padding=should_pad, max_length=max_length, truncation=True
            )

        if label_key in examples:
            result["labels"] = examples[label_key]
            if task == "boolq":
                # cast the boolq `answer` values to plain ints
                result["labels"] = np.array(result["labels"], dtype=int).tolist()
        else:
            raise RuntimeError("Labels not found.")
        return result

    # the SuperGLUE prompt builders work on one row at a time, so batching is
    # only enabled for tasks listed in GLUE_DATASETS
    batched = task in GLUE_DATASETS
    processed_train = train_dataset.map(
        preprocess_function,
        batched=batched,
        remove_columns=train_dataset.column_names,
        desc="Running tokenizer on train",
    )
    processed_eval = eval_dataset.map(
        preprocess_function,
        batched=batched,
        # use the eval split's own columns rather than assuming they match train
        remove_columns=eval_dataset.column_names,
        desc="Running tokenizer on eval",
    )

    return processed_train, processed_eval


def tokenize_arc(
    *,
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    version: str = "ARC-Easy",
    validation_set: str = "validation",
    **kwargs,
) -> Tuple[Dataset, Dataset]:
    """
    Tokenizes the ARC dataset (`allenai/ai2_arc`) as multiple-choice prompts

    Every kept example becomes a prompt of the form
    `Question: ...\\n\\nChoices:\\n0: ...\\n...\\nAnswer: ` and its label is the
    index of the answer key after mapping `1/A -> 0 ... 4/D -> 3`.

    Args:
        tokenizer: tokenizer applied to the prompts (with `truncation=True`).
        version: `ARC-Easy` or `ARC-Challenge`.
        validation_set: which split to return as the evaluation set,
            `validation` or `test`.
        **kwargs: forwarded to the tokenizer call.

    Returns:
        A `(train, eval)` tuple of tokenized datasets with a `labels` column.
    """
    assert version in [
        "ARC-Easy",
        "ARC-Challenge",
    ], f"Invalid ARC version type {version}"
    assert validation_set in [
        "validation",
        "test",
    ], f"Invalid validation set {validation_set}"

    raw_dataset = load_dataset(
        "allenai/ai2_arc",
        name=version,
        trust_remote_code=True,
    )

    keys_to_labels = {"1": 0, "A": 0, "2": 1, "B": 1, "3": 2, "C": 2, "4": 3, "D": 3}

    def preprocess_function(examples):
        inputs = []
        labels = []

        for question, choices, answer in zip(
            examples["question"], examples["choices"], examples["answerKey"]
        ):
            # skip examples that do not fit the four-way format: any option
            # label (e.g. "E") or answer key outside `keys_to_labels` would
            # otherwise raise a KeyError below
            if answer not in keys_to_labels or any(
                key not in keys_to_labels for key in choices["label"]
            ):
                continue

            option_lines = "".join(
                f"{keys_to_labels[key]}: {value}\n"
                for key, value in zip(choices["label"], choices["text"])
            )
            inputs.append(
                f"Question: {question}\n\nChoices:\n{option_lines}\nAnswer: "
            )
            labels.append(keys_to_labels[answer])

        return {"input_text": inputs, "label": labels}

    # rows may be dropped above, so all original columns have to be removed
    processed_dataset = raw_dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=raw_dataset["train"].column_names,
    )

    def tokenize_function(examples):
        result = tokenizer(
            examples["input_text"],
            truncation=True,
            **kwargs,
        )
        result["labels"] = examples["label"]
        return result

    tokenized_dataset = processed_dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=processed_dataset["train"].column_names,
    )

    return tokenized_dataset["train"], tokenized_dataset[validation_set]


def tokenize_hellaswag(
    *,
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    validation_set: str = "validation",
    **kwargs,
) -> Tuple[Dataset, Dataset]:
    """
    Tokenizes the HellaSwag dataset (`Rowan/hellaswag`) as multiple-choice prompts

    Each example becomes `<activity_label>\\n\\n<ctx>\\n\\n0: ...\\n...\\nAnswer: `
    with the integer value of `label` as the target.

    Args:
        tokenizer: tokenizer applied to the prompts (with `truncation=True`).
        validation_set: `validation` re-splits the tokenized train split with a
            fixed seed (15% train / 30% eval); `test` returns the full train
            split together with the dataset's `validation` split.
        **kwargs: forwarded to the tokenizer call. A `should_pad` entry, if
            present, is discarded and not forwarded.

    Returns:
        A `(train, eval)` tuple of tokenized datasets with a `labels` column.
    """
    assert validation_set in [
        "validation",
        "test",
    ], f"Invalid validation set {validation_set}"

    raw_train, raw_eval = load_dataset(
        "Rowan/hellaswag",
        split=[
            "train",
            "validation",
        ],
    )

    def preprocess_function(examples):
        # examples without a label produce no output columns
        if len(examples["label"]) == 0:
            return {}

        choices = "".join(
            f"{i}: {choice}\n" for i, choice in enumerate(examples["endings"])
        )
        input_text = (
            f"{examples['activity_label']}\n\n"
            f"{examples['ctx']}\n\n"
            f"{choices}"
            "\nAnswer: "
        )

        return {"input_text": input_text, "label": int(examples["label"])}

    processed_train = raw_train.map(
        preprocess_function,
        remove_columns=raw_train.column_names,
    )
    processed_eval = raw_eval.map(
        preprocess_function,
        # use the eval split's own columns rather than assuming they match train
        remove_columns=raw_eval.column_names,
    )

    # `should_pad` is not forwarded to the tokenizer; drop it
    kwargs.pop("should_pad", None)

    def tokenize_function(examples):
        result = tokenizer(
            examples["input_text"],
            truncation=True,
            **kwargs,
        )
        result["labels"] = examples["label"]
        return result

    tokenized_train = processed_train.map(
        tokenize_function,
        batched=True,
        remove_columns=processed_train.column_names,
    )
    tokenized_eval = processed_eval.map(
        tokenize_function,
        batched=True,
        remove_columns=processed_eval.column_names,
    )

    if validation_set == "validation":
        # pass in seed to make sure we get the same split every time
        split = tokenized_train.train_test_split(train_size=0.15, test_size=0.3, seed=0)
        tokenized_train, tokenized_eval = split["train"], split["test"]

    return tokenized_train, tokenized_eval


# Fixed few-shot prefix that `tokenize_mmlu` and `calibrate_mmlu` prepend to
# every MMLU prompt. Answers are the zero-based index of the correct option.
# NOTE(review): these demonstrations omit the "Choices:" line and the blank
# lines that the generated prompt uses for the actual question — confirm
# whether that formatting difference is intentional.
MMLU_FEWSHOT = """
Question: What is the capital of France?
0: Berlin
1: Madrid
2: Paris
3: Rome
Answer: 2

Question: What is the largest planet in our solar system?
0: Earth
1: Mars
2: Jupiter
3: Saturn
Answer: 2

Question: What is the powerhouse of the cell?
0: Nucleus
1: Mitochondria
2: Ribosome
3: Endoplasmic Reticulum
Answer: 1
"""


def tokenize_mmlu(
    *,
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    validation_set: str = "validation",
    **kwargs,
) -> Tuple[Dataset, Dataset]:
    """
    Tokenizes MMLU (`cais/mmlu`, config `all`) as few-shot multiple-choice prompts

    Every prompt is `MMLU_FEWSHOT` followed by the question, its enumerated
    choices and a trailing `Answer: `; the label is the dataset's `answer`.

    Args:
        tokenizer: tokenizer applied to the prompts (with `truncation=True`).
        validation_set: `validation` shuffles the `auxiliary_train` split and
            draws a seeded 10% train / 10% eval split from it; `test` returns
            all of `auxiliary_train` together with the `test` split.
        **kwargs: forwarded to the tokenizer call.

    Returns:
        A `(train, eval)` tuple of tokenized datasets with a `labels` column.
    """
    assert validation_set in [
        "validation",
        "test",
    ], f"Invalid validation set {validation_set}"

    raw_dataset = load_dataset("cais/mmlu", "all")
    assert isinstance(raw_dataset, DatasetDict)

    def build_prompt(question, options):
        # one "<index>: <option>" line per choice
        option_block = "".join(
            f"{index}: {text}\n" for index, text in enumerate(options)
        )
        body = f"Question: {question}\n\nChoices:\n" + option_block + "\nAnswer: "
        return MMLU_FEWSHOT + body

    def preprocess_function(batch):
        prompts = []
        targets = []
        rows = zip(batch["question"], batch["choices"], batch["answer"])
        for question, options, answer in rows:
            prompts.append(build_prompt(question, options))
            targets.append(answer)
        return {"input_text": prompts, "label": targets}

    raw_columns = raw_dataset["auxiliary_train"].column_names
    processed_dataset = raw_dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=raw_columns,
    )

    def tokenize_function(batch):
        encoded = tokenizer(
            batch["input_text"],
            truncation=True,
            **kwargs,
        )
        encoded["labels"] = batch["label"]
        return encoded

    prompt_columns = processed_dataset["auxiliary_train"].column_names
    tokenized_dataset = processed_dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=prompt_columns,
    )

    # expose the training data under the conventional "train" key
    tokenized_dataset["train"] = tokenized_dataset["auxiliary_train"]

    if validation_set == "validation":
        # both halves of the validation setup come from the training data
        shuffled_train = tokenized_dataset["train"].shuffle(seed=0)
        resplit = shuffled_train.train_test_split(
            train_size=0.10,
            test_size=0.10,
            seed=0,
        )
        return resplit["train"], resplit["test"]

    return tokenized_dataset["train"], tokenized_dataset[validation_set]


def tokenize_primevul(
    *,
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    validation_set: str = "validation",
    **kwargs,
) -> Tuple[Dataset, Dataset]:
    """
    Tokenizes the local PrimeVul dataset (JSON files under `BASE_DIR`)

    The `code` column is tokenized and the `label` column is copied to
    `labels`. Examples with 2048 or more tokens are dropped.

    Args:
        tokenizer: a HuggingFace tokenizer, or a `mistral_common`
            `MistralTokenizer` when that optional package is installed.
        validation_set: `validation` draws a seeded 80/20 split from the
            filtered train split; `test` returns the train and test splits.
        **kwargs: forwarded to the HuggingFace tokenizer call (ignored on the
            Mistral path).

    Returns:
        A `(train, eval)` tuple of tokenized datasets.
    """
    assert validation_set in [
        "validation",
        "test",
    ], f'Invalid key: "{validation_set}".'

    # `mistral_common` is an optional dependency: the module-level imports are
    # commented out, so resolve the names here and fall back to the regular
    # tokenizer path when the package is missing (previously any call raised
    # a NameError on `MistralTokenizer`).
    try:
        from mistral_common.protocol.instruct.messages import UserMessage
        from mistral_common.protocol.instruct.request import ChatCompletionRequest
        from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
    except ImportError:
        UserMessage = ChatCompletionRequest = MistralTokenizer = None

    uses_mistral = MistralTokenizer is not None and isinstance(
        tokenizer, MistralTokenizer
    )

    data_dir = os.path.join(BASE_DIR, "datasets", "PrimeVul_Data/cleaned")
    raw_dataset = load_dataset(
        "json",
        data_dir=data_dir,
    )
    assert isinstance(raw_dataset, DatasetDict)

    max_tokens = 2048

    def tokenize_example(examples):
        # Codestral is weird so need to preprocess data differently
        if uses_mistral:
            input_ids = []
            for example in examples["code"]:
                completion_request = ChatCompletionRequest(
                    messages=[UserMessage(content=example)]
                )
                tokens = tokenizer.encode_chat_completion(completion_request).tokens
                input_ids.append(tokens)
            result = {"input_ids": input_ids}
        else:
            result = tokenizer(
                examples["code"],
                truncation=True,
                **kwargs,
            )

        result["labels"] = examples["label"]

        return result

    tokenized_dataset = raw_dataset.map(
        tokenize_example,
        batched=True,
        remove_columns=raw_dataset["train"].column_names,
        desc="Tokenizing dataset",
    )

    tokenized_dataset = tokenized_dataset.filter(
        lambda x: len(x["input_ids"]) < max_tokens,
        batched=False,
        desc="Filtering dataset",
    )

    if validation_set == "validation":
        # pass in seed to make sure we get the same split every time
        split = tokenized_dataset["train"].train_test_split(test_size=0.20, seed=0)
        return split["train"], split["test"]

    return tokenized_dataset["train"], tokenized_dataset["test"]


def tokenize_cve(
    *,
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    validation_set: str = "validation",
    **kwargs,
) -> Tuple[Dataset, Dataset]:
    """
    Tokenizes the local C++ vulnerability dataset (JSON files under `BASE_DIR`)

    The `code` column is tokenized without padding and each `label` is mapped
    to `1` (truthy) or `0` (falsy) in the `labels` column.

    Args:
        tokenizer: a HuggingFace tokenizer, or a `mistral_common`
            `MistralTokenizer` when that optional package is installed.
        validation_set: `validation` draws a seeded 80/20 split from the train
            split; any other value returns the train and test splits.
        **kwargs: forwarded to the HuggingFace tokenizer call (ignored on the
            Mistral path). `padding` and `truncation` are always forced to
            `False` / `True`.

    Returns:
        A `(train, eval)` tuple of tokenized datasets.
    """
    # `mistral_common` is an optional dependency: the module-level imports are
    # commented out, so resolve the names here and fall back to the regular
    # tokenizer path when the package is missing (previously any call raised
    # a NameError on `MistralTokenizer`).
    try:
        from mistral_common.protocol.instruct.messages import UserMessage
        from mistral_common.protocol.instruct.request import ChatCompletionRequest
        from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
    except ImportError:
        UserMessage = ChatCompletionRequest = MistralTokenizer = None

    uses_mistral = MistralTokenizer is not None and isinstance(
        tokenizer, MistralTokenizer
    )

    raw_data = load_dataset(
        "json",
        data_dir=os.path.join(BASE_DIR, "datasets/cpp_vulnerabilities"),
    )
    assert isinstance(raw_data, DatasetDict)
    train_data = raw_data["train"]
    test_data = raw_data["test"]

    # merge so that a caller-supplied `padding`/`truncation` cannot collide
    # with the values this function enforces
    tokenizer_kwargs = {**kwargs, "padding": False, "truncation": True}

    def preprocess(examples):
        if uses_mistral:
            input_ids = []
            for example in examples["code"]:
                completion_request = ChatCompletionRequest(
                    messages=[UserMessage(content=example)]
                )
                tokens = tokenizer.encode_chat_completion(completion_request).tokens
                input_ids.append(tokens)
            result = {"input_ids": input_ids}
        else:
            # tokenize the whole batch at once so every output column has one
            # entry per example (the old loop kept only the last example)
            result = tokenizer(examples["code"], **tokenizer_kwargs)

        # `map` runs batched, so emit one label per example rather than a
        # single scalar for the whole batch
        result["labels"] = [1 if label else 0 for label in examples["label"]]
        return result

    tokenized_train = train_data.map(
        preprocess,
        batched=True,
        remove_columns=train_data.column_names,
        desc="Running tokenizer on train",
    )
    tokenized_test = test_data.map(
        preprocess,
        batched=True,
        # use the test split's own columns rather than assuming they match train
        remove_columns=test_data.column_names,
        desc="Running tokenizer on eval",
    )

    if validation_set == "validation":
        # pass in seed to make sure we get the same split every time
        split = tokenized_train.train_test_split(test_size=0.20, seed=0)
        tokenized_train, tokenized_test = split["train"], split["test"]

    return tokenized_train, tokenized_test


def tokenize_gsm8k(
    *,
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    validation_set: str = "validation",
    **kwargs,
):
    """
    Tokenizes GSM8K (`openai/gsm8k`, config `main`) for causal-LM fine-tuning

    Each example is encoded as prompt + answer. `labels` is a copy of
    `input_ids` in which the prompt positions are set to `-100`.

    Args:
        tokenizer: tokenizer used to encode prompts; tokenizers whose
            `name_or_path` contains `Meta-Llama-3-8B-Instruct` go through the
            chat template instead of the plain prompt.
        validation_set: `validation` returns the first 10% of the train split
            as the train set and the remainder as the eval set; `test` returns
            the train and test splits.
        **kwargs: only `max_length` is read (default 1024).

    Returns:
        A `(train, eval)` tuple of tokenized datasets.
    """
    assert validation_set in [
        "validation",
        "test",
    ], "please enter a valid `validation_set`"

    QUESTION_PROMPT = "\nAnswer the above question. First think step by step and then answer the final number.\n"

    raw_dataset = load_dataset(
        "openai/gsm8k",
        "main",
    )
    assert isinstance(raw_dataset, DatasetDict), "Error loading dataset"

    def encode(text):
        # returns the token ids of `text` as a 1-D tensor
        encoded = tokenizer(
            text,
            max_length=kwargs.get("max_length", 1024),
            truncation=True,
            return_tensors="pt",
        )
        return encoded["input_ids"][0]

    def tokenize_function(data_item):
        if "Meta-Llama-3-8B-Instruct" in tokenizer.name_or_path:
            # pretty bad workaround for llama-3, forgive me
            bos_length = len("<|begin_of_text|>")
            prompt_messages = [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": data_item["question"]},
            ]
            full_messages = prompt_messages + [
                {"role": "assistant", "content": data_item["answer"]},
            ]
            # we remove the BOS, otherwise there will be redundant BOS tokens.
            base_prompt = tokenizer.apply_chat_template(
                prompt_messages, tokenize=False
            )[bos_length:]
            base_input = (
                tokenizer.apply_chat_template(full_messages, tokenize=False)[
                    bos_length:
                ]
                + tokenizer.eos_token
            )
        else:
            # setup is from https://github.com/yxli2123/LoftQ/
            base_prompt = f"{data_item['question']}{QUESTION_PROMPT}"
            completion = f"{data_item['answer']}{tokenizer.eos_token}".replace(
                "####", "The final answer is: "
            )
            base_input = base_prompt + completion

        prompt_length = len(encode(base_prompt))
        input_ids = encode(base_input)

        # mask out the prompt so only the answer tokens carry a label
        masked_labels = deepcopy(input_ids)
        masked_labels[:prompt_length] = -100

        return {"input_ids": input_ids, "labels": masked_labels}

    tokenized_dataset = raw_dataset.map(
        tokenize_function,
        batched=False,
        remove_columns=raw_dataset["train"].column_names,
    )

    if validation_set != "validation":
        return tokenized_dataset["train"], tokenized_dataset[validation_set]

    full_train = tokenized_dataset["train"]
    percent_train = 0.10
    train_size = int(percent_train * len(full_train))
    train_set = full_train.select(range(0, train_size))
    eval_set = full_train.select(range(train_size, len(full_train)))
    return train_set, eval_set


# Registry from a short dataset name to the tokenization function that builds
# its (train, eval) pair. The two ARC entries share `tokenize_arc` and differ
# only in the pre-bound `version` argument.
# NOTE(review): `tokenize_gsm8k`, `tokenize_wikitext` and
# `tokenize_summarization` are not registered here — confirm that is intended.
DATALOADER_MAP = {
    "arc-e": functools.partial(tokenize_arc, version="ARC-Easy"),
    "arc-c": functools.partial(tokenize_arc, version="ARC-Challenge"),
    "mmlu": tokenize_mmlu,
    "hellaswag": tokenize_hellaswag,
    "primevul": tokenize_primevul,
    "cve": tokenize_cve,
}


def calibrate_arc(
    *,
    version: str = "ARC-Easy",
    validation_set: str = "validation",
) -> Dataset:
    """
    Builds the untokenized ARC prompts (`allenai/ai2_arc`) of the train split

    The prompts use the same `Question: ... Choices: ... Answer: ` layout as
    `tokenize_arc`, but are returned as raw text in a single `text` column.

    Args:
        version: `ARC-Easy` or `ARC-Challenge`.
        validation_set: validated (`validation` or `test`) but otherwise
            unused; the train split is always returned.

    Returns:
        The processed train split with a single `text` column.
    """
    assert version in [
        "ARC-Easy",
        "ARC-Challenge",
    ], f"Invalid ARC version type {version}"
    assert validation_set in [
        "validation",
        "test",
    ], f"Invalid validation set {validation_set}"

    raw_dataset = load_dataset(
        "allenai/ai2_arc",
        name=version,
        trust_remote_code=True,
    )

    keys_to_labels = {"1": 0, "A": 0, "2": 1, "B": 1, "3": 2, "C": 2, "4": 3, "D": 3}

    def preprocess_function(examples):
        inputs = []

        for question, choices in zip(examples["question"], examples["choices"]):
            # skip examples that do not fit the four-way format: any option
            # label (e.g. "E") outside `keys_to_labels` would otherwise raise
            # a KeyError below
            if any(key not in keys_to_labels for key in choices["label"]):
                continue

            option_lines = "".join(
                f"{keys_to_labels[key]}: {value}\n"
                for key, value in zip(choices["label"], choices["text"])
            )
            inputs.append(
                f"Question: {question}\n\nChoices:\n{option_lines}\nAnswer: "
            )

        # only the prompt text is returned, so the answer key is not needed
        return {"text": inputs}

    processed_dataset = raw_dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=raw_dataset["train"].column_names,
    )

    return processed_dataset["train"]


def calibrate_mmlu(
    *,
    validation_set: str = "validation",
) -> Dataset:
    """
    Builds the untokenized few-shot MMLU prompts (`cais/mmlu`, config `all`)

    The prompts use the same layout as `tokenize_mmlu` (`MMLU_FEWSHOT` followed
    by the question and its enumerated choices), but are returned as raw text
    in a single `text` column.

    Args:
        validation_set: validated (`validation` or `test`) but otherwise
            unused; the training data is always returned.

    Returns:
        The processed `auxiliary_train` split with a single `text` column.
    """
    assert validation_set in [
        "validation",
        "test",
    ], f"Invalid validation set {validation_set}"

    raw_dataset = load_dataset("cais/mmlu", "all")
    assert isinstance(raw_dataset, DatasetDict)

    def preprocess_function(examples):
        inputs = []

        for question, choice in zip(examples["question"], examples["choices"]):
            options = "".join(
                f"{label}: {option}\n" for label, option in enumerate(choice)
            )
            prompt = f"Question: {question}\n\nChoices:\n{options}\nAnswer: "
            inputs.append(MMLU_FEWSHOT + prompt)

        # only the prompt text is returned, so the answer is not needed
        return {"text": inputs}

    processed_dataset = raw_dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=raw_dataset["auxiliary_train"].column_names,
    )

    # the training data lives under "auxiliary_train"; nothing in this function
    # creates a "train" key, so indexing it raised a KeyError
    return processed_dataset["auxiliary_train"]


def tokenize_wikitext(
    *,
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    validation_set: str = "validation",
    block_size: int = 1024,
    **kwargs,
) -> Tuple[Dataset, Dataset]:
    """
    Tokenizes WikiText-2 (`Salesforce/wikitext`, `wikitext-2-raw-v1`) for language modeling

    Every row of the `text` column is tokenized on its own and `labels` is a
    copy of `input_ids`.

    Args:
        tokenizer: tokenizer applied to the `text` column.
        validation_set: `validation` returns the first 20% of the train split
            and the first 20% of the validation split; any other value is used
            as the name of the eval split and returned with the full train
            split.
        block_size: currently unused — rows are not concatenated or chunked
            into fixed-size blocks.
        **kwargs: accepted for signature compatibility; not forwarded to the
            tokenizer.

    Returns:
        A `(train, eval)` tuple of tokenized datasets.
    """
    raw_dataset = load_dataset(
        "Salesforce/wikitext",
        "wikitext-2-raw-v1",
    )
    assert isinstance(raw_dataset, DatasetDict)

    def tokenize_function(examples):
        model_inputs = tokenizer(examples["text"])
        model_inputs["labels"] = model_inputs["input_ids"].copy()
        return model_inputs

    tokenized_datasets = raw_dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=list(raw_dataset["train"].features),
        desc="Running tokenizer on dataset",
    )

    if validation_set != "validation":
        return (tokenized_datasets["train"], tokenized_datasets[validation_set])

    train_percent = 0.20
    eval_percent = 0.20
    train_split = tokenized_datasets["train"]
    eval_split = tokenized_datasets["validation"]
    # `len(dataset)` is the row count; this avoids loading the whole
    # `input_ids` column into memory just to measure it
    return (
        train_split.select(range(0, int(len(train_split) * train_percent))),
        eval_split.select(range(0, int(len(eval_split) * eval_percent))),
    )


# Maps the short dataset name accepted by `tokenize_summarization` to the
# (path, config name) pair passed to `load_dataset`.
summarization_datasets = {
    "xsum": ("EdinburghNLP/xsum", None),
    "dailymail": ("abisee/cnn_dailymail", "3.0.0"),
}

# Maps the same short dataset name to its (source text column, summary column)
# pair. Keys must stay in sync with `summarization_datasets`.
summarization_name_mapping = {
    "dailymail": ("article", "highlights"),
    "xsum": ("document", "summary"),
}


def tokenize_summarization(
    *,
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    validation_set: str = "validation",
    dataset: str = "xsum",
    **kwargs,
) -> Tuple[Dataset, Dataset]:
    """
    Tokenizes a summarization dataset for causal-LM fine-tuning

    Each example is encoded as `<source> Summary: <target>`. `labels` is a copy
    of `input_ids` in which the positions covered by `<source> Summary: ` are
    set to `-100`.

    Args:
        tokenizer: tokenizer applied to the texts (with `return_tensors="pt"`).
        validation_set: `validation` loads `train[:100]` for both returned
            sets; any other value loads the full `train` and `test` splits.
        dataset: key of `summarization_datasets` (`xsum` or `dailymail`).
        **kwargs: forwarded to both tokenizer calls.

    Returns:
        A `(train, eval)` tuple of tokenized datasets.
    """
    assert dataset in summarization_datasets
    task, name = summarization_datasets[dataset]

    if validation_set == "validation":
        # NOTE(review): both entries are the same 100 training examples, so the
        # returned train and eval sets are identical — confirm this is intended
        split = [
            "train[:100]",
            "train[:100]",
        ]
    else:
        split = [
            "train",
            "test",
        ]

    train_dataset, test_dataset = load_dataset(
        task,
        name=name,
        split=split,
    )
    assert isinstance(train_dataset, Dataset)
    assert isinstance(test_dataset, Dataset)

    text_column, summary_column = summarization_name_mapping[dataset]

    def tokenize_function(example):
        """Tokenize and prepare inputs and labels for training."""
        IGNORE_INDEX = -100  # Define the ignore index for labels

        source_text = example[text_column] + " Summary: "
        full_text = source_text + example[summary_column]

        # Tokenize the concatenated source + summary
        examples_tokenized = tokenizer(
            full_text,
            return_tensors="pt",
            **kwargs,
        )

        # Tokenize the source alone to obtain its length
        sources_tokenized = tokenizer(
            source_text,
            return_tensors="pt",
            **kwargs,
        )

        input_ids = examples_tokenized["input_ids"][-1]
        labels = input_ids.clone()

        # Length of the source sequence, excluding padding tokens. Tokenizers
        # without a pad token cannot have produced padding, and comparing a
        # tensor against `None` raises a TypeError, so use the full length.
        source_ids = sources_tokenized["input_ids"][-1]
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            source_len = source_ids.shape[0]
        else:
            source_len = source_ids.ne(pad_token_id).sum()

        # Set labels corresponding to source tokens to IGNORE_INDEX
        labels[:source_len] = IGNORE_INDEX

        return {"input_ids": input_ids, "labels": labels}

    column_names = train_dataset.column_names
    train_dataset = train_dataset.map(
        tokenize_function,
        batched=False,
        remove_columns=column_names,
        desc="Running tokenizer on train dataset",
    )
    test_dataset = test_dataset.map(
        tokenize_function,
        batched=False,
        remove_columns=column_names,
        desc="Running tokenizer on test dataset",
    )

    return train_dataset, test_dataset
