from datasets import load_dataset
from helper.thirdparty.tofu.data_module import TextDatasetQA

def get_hf_dataset(dataset_name,
                   tokenizer,
                   max_seq_length=1024,
                   model_family="llama2-7b",
                   ):
    """Load a supported Hugging Face dataset and tokenize it.

    Args:
        dataset_name: Which dataset to load. Supported values are
            ``"ag_news"`` and ``"tofu"``.
        tokenizer: Tokenizer used to encode the text. For ``"ag_news"`` its
            ``truncation_side`` is set to ``"left"`` and is left that way
            after this call returns.
        max_seq_length: Maximum number of tokens per example; longer inputs
            are truncated.
        model_family: Model family forwarded to ``TextDatasetQA`` for the
            ``"tofu"`` dataset (unused for ``"ag_news"``). Defaults to
            ``"llama2-7b"``, the value that was previously hard-coded.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple.

        For ``"ag_news"``, these are the ``"train"`` and ``"test"`` splits,
        with the tokenizer outputs added as columns.

        For ``"tofu"``, both names refer to the same ``TextDatasetQA`` built
        from the ``"full"`` config's ``"train"`` split. No separate test split
        is loaded.

    Raises:
        NotImplementedError: If ``dataset_name`` is not supported.
    """
    if dataset_name == "ag_news":
        dataset = load_dataset(dataset_name)
        train_dataset = dataset["train"]
        test_dataset = dataset["test"]

        # Keep the *end* of over-long texts. This is set once, up front,
        # rather than inside the map function: when `map` reuses a cached
        # result the function is never called, so the assignment would
        # otherwise happen only on a cache miss.
        tokenizer.truncation_side = "left"

        def tokenize_function(examples):
            # No `return_tensors`: without padding, the batch is ragged, and
            # `map` stores the outputs as Arrow lists anyway.
            return tokenizer(examples["text"],
                             truncation=True,
                             max_length=max_seq_length)

        train_dataset = train_dataset.map(tokenize_function, batched=True)
        test_dataset = test_dataset.map(tokenize_function, batched=True)
    elif dataset_name == "tofu":
        dataset = load_dataset("locuslab/TOFU", "full")
        train_dataset = TextDatasetQA(dataset["train"],
                                      tokenizer,
                                      model_family=model_family,
                                      max_length=max_seq_length,
                                      question_key='question',
                                      answer_key='answer')
        # Only the train split is loaded here, so evaluation reuses it.
        test_dataset = train_dataset
    else:
        raise NotImplementedError(
            f"Unsupported dataset_name {dataset_name!r}; "
            "expected 'ag_news' or 'tofu'."
        )

    return train_dataset, test_dataset
