"""Measure the mean tokenized sequence length of a sample of the Pajama validation split."""

import os
import numpy as np
from functools import partial
from transformers import AutoTokenizer
from datasets import DatasetDict
from latte_trans.preproc.slim_pajama import SlimPajama


def tokenize(tokenizer, elem):
    """Tokenize the ``"text"`` field of a dataset row without truncation or padding.

    Designed to be bound with ``functools.partial(tokenize, tokenizer)`` and
    passed to ``Dataset.map``.

    Args:
        tokenizer: Callable tokenizer (e.g. a Hugging Face ``AutoTokenizer``)
            whose output is indexable by ``"input_ids"``.
        elem: Row mapping with a ``"text"`` entry. This is a single string for
            un-batched ``map`` calls, or a list of strings when ``batched=True``.

    Returns:
        dict with ``"input_ids"`` (exactly as produced by *tokenizer*) and
        ``"length"``. For a single string, ``"length"`` is a plain ``int``; for
        a batch, it is a list with one count per example.
    """
    text = elem["text"]
    encoded = tokenizer(text, truncation=False, padding=False)
    input_ids = encoded["input_ids"]
    # Count tokens from the ids instead of using ``return_length=True``.
    # For a single un-batched string, fast tokenizers wrap ``length`` in a
    # one-element list, while slow tokenizers return a bare int. With no
    # truncation or padding, ``len(input_ids)`` is the same count and always
    # has a consistent type.
    if isinstance(text, str):
        length = len(input_ids)
    else:  # batched call: one id list per example
        length = [len(ids) for ids in input_ids]
    return {"input_ids": input_ids, "length": length}


def main(
    base_dir="/data_user/data/",
    data_path="/user_all_data/data/input/pajama_raw",
    model_name="google/gemma-2-2b",
    num_proc=4,
):
    """Print the mean token length of a prefix of the validation split.

    Loads the ``DatasetDict`` stored at *data_path* and trims every split to a
    small prefix. It then tokenizes the validation prefix with *model_name*'s
    tokenizer (no truncation or padding) and prints the tokenized dataset and
    the mean of its ``length`` column.

    Args:
        base_dir: Root directory holding the tokenizer download cache
            (``input/cache_hugg``) and the ``map`` cache (``input/test_pajama``).
        data_path: Directory readable by ``DatasetDict.load_from_disk``. It
            must contain ``train``, ``validation`` and ``test`` splits.
        model_name: Hub id or local path passed to
            ``AutoTokenizer.from_pretrained``.
        num_proc: Number of worker processes used by ``Dataset.map``.
    """
    cache_dir = os.path.join(base_dir, "input/test_pajama")
    # ``Dataset.map`` writes ``cache_file_name`` inside this directory.
    # Create it up front so a first run on a fresh machine does not fail on a
    # missing path.
    os.makedirs(cache_dir, exist_ok=True)

    raw_data = DatasetDict.load_from_disk(data_path)

    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        cache_dir=os.path.join(base_dir, "input/cache_hugg"),
        truncation_side="right",
        padding_side="right",
    )

    # Keep only the first rows of each split. Clamp to the split size so that
    # a split shorter than the requested prefix does not make ``select``
    # raise on out-of-range indices.
    for split, size in (("train", 100), ("validation", 2000), ("test", 1000)):
        subset = raw_data[split]
        raw_data[split] = subset.select(range(min(size, len(subset))))

    tokenized_data = raw_data["validation"].map(
        partial(tokenize, tokenizer),
        num_proc=num_proc,
        # Both columns must exist in the split, otherwise ``map`` raises.
        remove_columns=["text", "meta"],
        cache_file_name=os.path.join(cache_dir, "test_len.bin"),
    )
    print(tokenized_data)
    print("Mean data: ", np.mean(tokenized_data["length"]))


if __name__ == "__main__":
    # Script entry point. Invoke it as a module, e.g.:
    #   pdm run python3 -m latte_trans.tests.ideas.pajama_seq
    main()
