from transformers import AutoTokenizer, LlamaTokenizer
from datasets import Dataset
from ..configs.MMLU.general import load_data


if __name__ == "__main__":
    # Script entry point: build the four MMLU datasets (FLAN-T5 and Alpaca
    # tokenizers, each with an "unobserved" and an "observed" variant) and
    # write each one to disk under ./catastrophic_forgetting/data/.
    #
    # NOTE(review): load_data is pulled in through a relative import at the
    # top of the file, so this presumably has to be launched as a module
    # (python -m <package>.<module>) rather than as a plain script -- confirm.
    mmlu_dir = "./data/MMLU/"

    # Both tokenizers are created before any dataset is loaded, FLAN first.
    tok_flan = AutoTokenizer.from_pretrained("google/flan-t5-xl")
    tok_alpaca = LlamaTokenizer.from_pretrained("./models/alpaca-7b")

    # One row per dataset:
    #   (tokenizer, second argument handed to load_data, output directory)
    plan = (
        (tok_flan, "MMLU/Unobserved/1",
         "./catastrophic_forgetting/data/flan_mmlu_unobserved"),
        (tok_flan, "NIV2/QA/1",
         "./catastrophic_forgetting/data/flan_mmlu_observed"),
        (tok_alpaca, "Alpaca/Unobserved/1",
         "./catastrophic_forgetting/data/alpaca_mmlu_unobserved"),
        (tok_alpaca, "Alpaca/QA/1",
         "./catastrophic_forgetting/data/alpaca_mmlu_observed"),
    )

    # Load every dataset first; nothing is written until all four loads have
    # succeeded. The positional 0 and False are forwarded to load_data as-is
    # (their meaning is defined by load_data, which is not visible here).
    built = [
        (load_data(mmlu_dir, source, 0, False, tok), target)
        for tok, source, target in plan
    ]

    # Persist the datasets in the same order they were loaded.
    for dataset, target in built:
        dataset.save_to_disk(target)
    
    