from datasets import load_dataset
import json


def merge_columns(example):
    """Rewrite ``example["text"]`` as a combined text/emotion prompt.

    ``example["label"]`` is used as an index into a fixed sequence of six
    emotion names. The record is modified in place and is also returned.
    """
    body = example["text"]
    emotion_names = ("sadness", "joy", "love", "anger", "fear", "surprise")
    emotion = emotion_names[example["label"]]
    # str.join, like ``+``, raises TypeError if any piece is not a string.
    example["text"] = "".join(("TEXT:", body, "\nEMOTION: ", emotion))
    return example


def load_and_preprocess_data(
    dataset_name,
    tokenizer,
    cache_dir: str,
    max_tokens_per_dataset_item: int,
    *,
    padding_length: int = 1024,
):
    """Load a dataset's train split, build prompts, filter by length, tokenize.

    Steps:
      1. Load the full ``train`` split of ``dataset_name``.
      2. Rewrite each row's ``text`` column with ``merge_columns``.
      3. Drop rows whose tokenized prompt is longer than
         ``max_tokens_per_dataset_item`` tokens.
      4. Tokenize the remaining rows, padding each to ``padding_length``,
         and remove the ``label`` column.

    Args:
        dataset_name: Dataset identifier forwarded to ``load_dataset``.
        tokenizer: Callable tokenizer that accepts a list of strings and
            returns a mapping containing ``"input_ids"``. Its ``pad_token``
            attribute is overwritten with its ``eos_token`` (side effect on
            the caller's object).
        cache_dir: Cache directory forwarded to ``load_dataset``.
        max_tokens_per_dataset_item: Maximum token count (inclusive) a
            prompt may have to be kept.
        padding_length: Length passed as ``max_length`` together with
            ``padding="max_length"``. Defaults to 1024, the value that was
            previously hard-coded.

    Returns:
        The filtered dataset with the tokenizer's output columns added and
        the ``label`` column removed.

    Note:
        No truncation is requested, so if ``max_tokens_per_dataset_item``
        exceeds ``padding_length`` the surviving rows longer than
        ``padding_length`` are presumably left unpadded at their own
        length -- confirm against the tokenizer in use.
    """
    data = load_dataset(dataset_name, cache_dir=cache_dir, split="train[:]")
    merged_data = data.map(merge_columns)

    # Padding below needs a pad token; reuse the end-of-sequence token.
    tokenizer.pad_token = tokenizer.eos_token

    def _fits_token_budget(batch):
        # Batched filter: return one boolean per row. Tokenizing a whole
        # batch at once avoids one tokenizer call per row.
        token_ids = tokenizer(batch["text"])["input_ids"]
        return [len(ids) <= max_tokens_per_dataset_item for ids in token_ids]

    def _tokenize_padded(batch):
        return tokenizer(
            batch["text"], padding="max_length", max_length=padding_length
        )

    short_enough = merged_data.filter(_fits_token_budget, batched=True)
    filtered_data = short_enough.map(
        _tokenize_padded,
        batched=True,
        remove_columns=["label"],
    )

    return filtered_data
