"""
Load models and datasets for testing.
Used by test_llm.py
"""

from transformers import (
    BertForSequenceClassification,
    BertTokenizerFast,
    GPT2LMHeadModel,
    GPT2TokenizerFast,
)
from datasets import load_dataset
from pyrootutils.pyrootutils import setup_root
import torch

# Resolve the project root relative to this file; it is used below to build the
# on-disk model and dataset paths.
# NOTE(review): presumably pyrootutils searches upward for a ".root" marker file
# and, with pythonpath=True, adds the root to sys.path — confirm against its docs.
root = setup_root(__file__, ".root", pythonpath=True)

# Number of training rows each loader selects for its train split.
TRAIN_LEN = 25000
# Number of validation rows load_gpt2 selects for its eval split.
EVAL_LEN = 500


def load_bert():
    """Load the local BERT classifier, its tokenizer and a tokenized IMDB subset.

    The tokenizer and checkpoint are read from ``{root}/models/bert`` and the
    data from ``{root}/datasets/imdb/train.parquet``.

    Returns:
        A ``(model, tokenizer, train_dataset)`` tuple:

        * ``model`` -- ``BertForSequenceClassification`` with two labels,
          converted to ``torch.bfloat16``.
        * ``tokenizer`` -- the ``BertTokenizerFast`` used to encode the data.
        * ``train_dataset`` -- at most ``TRAIN_LEN`` rows, shuffled with a
          fixed seed, with the ``text`` column replaced by the tokenizer
          output (truncated to 512 tokens, not padded).
    """
    tokenizer = BertTokenizerFast.from_pretrained(f"{root}/models/bert")
    dataset = load_dataset(
        "parquet", data_files=[f"{root}/datasets/imdb/train.parquet"]
    )["train"]

    # Subsample *before* tokenizing so only the rows that are kept get
    # tokenized. The fixed seed keeps the selection reproducible, and the
    # clamp avoids an IndexError when the file has fewer than TRAIN_LEN rows.
    subset_len = min(TRAIN_LEN, len(dataset))
    dataset = dataset.shuffle(seed=42).select(range(subset_len))

    def tokenize(examples):
        # padding=False leaves sequences at their own length, so batching
        # has to pad later (e.g. in a data collator).
        return tokenizer(
            examples["text"], max_length=512, truncation=True, padding=False
        )

    train_dataset = dataset.map(
        tokenize,
        batched=True,
        remove_columns=["text"],
    )

    model = BertForSequenceClassification.from_pretrained(
        f"{root}/models/bert", num_labels=2
    )
    model = model.to(torch.bfloat16)

    return model, tokenizer, train_dataset


def load_gpt2():
    """Load pretrained GPT-2 with tokenized WikiText-2 train/validation subsets.

    Each example is truncated or padded to exactly 64 tokens and carries a
    ``labels`` column that is a copy of ``input_ids``.

    Returns:
        A ``(model, train_dataset, eval_dataset)`` tuple:

        * ``model`` -- ``GPT2LMHeadModel`` loaded from the ``"gpt2"`` checkpoint.
        * ``train_dataset`` -- at most ``TRAIN_LEN`` leading rows of the
          ``train`` split.
        * ``eval_dataset`` -- at most ``EVAL_LEN`` leading rows of the
          ``validation`` split.

        Both datasets are formatted as torch tensors exposing only
        ``input_ids``, ``attention_mask`` and ``labels``.
    """
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token

    def tok(batch):
        tokens = tokenizer(
            batch["text"], padding="max_length", truncation=True, max_length=64
        )
        # NOTE(review): labels mirror input_ids, so padded (EOS) positions are
        # kept as targets rather than masked out -- confirm this is intended.
        tokens["labels"] = tokens["input_ids"].copy()
        return tokens

    def prepare(split, limit):
        # Select first so that only the rows actually returned are tokenized
        # (the unused test split and surplus rows are skipped entirely), and
        # clamp so a split shorter than `limit` does not raise IndexError.
        subset = split.select(range(min(limit, len(split))))
        subset = subset.map(tok, batched=True)
        subset.set_format(
            type="torch", columns=["input_ids", "attention_mask", "labels"]
        )
        return subset

    train_dataset = prepare(dataset["train"], TRAIN_LEN)
    eval_dataset = prepare(dataset["validation"], EVAL_LEN)

    model = GPT2LMHeadModel.from_pretrained("gpt2")

    return model, train_dataset, eval_dataset


def load_model(type):
    """Dispatch to the loader for the requested model family.

    Args:
        type: ``"bert"`` or ``"gpt2"``.

    Returns:
        Whatever the selected loader returns (``load_bert()`` or
        ``load_gpt2()``).

    Raises:
        ValueError: If ``type`` is not one of the supported names.
    """
    # Reject unsupported names up front; the loaders are only touched for
    # known model types.
    if type not in ("bert", "gpt2"):
        raise ValueError(f"Unknown model type: {type}")
    return load_bert() if type == "bert" else load_gpt2()
