# Builds story embeddings with a language model (Llama 2) over a sweep of layers and context lengths, saving them for later use.
from regression.lm_embeddings.embeddings_store import SessionStoryEmbeddingsFeatureLoader, MEGFeatureMapStore
from regression.lm_embeddings.story_to_embeddings import BuildStoryEmbeddings, n_previous_words_context
from regression.session_story_configs import subject_configs, subject_test_configs, subject_train_configs
from transformers import AutoConfig, LlamaForCausalLM, AutoTokenizer
from transformers import GPT2LMHeadModel, AutoConfig, AutoTokenizer

import os
from tqdm.auto import tqdm
import numpy as np

# --- Sweep configuration -----------------------------------------------------
model = "llama2"  # model tag; only used to name the output directories
context_lens = list(range(1, 20))  # context lengths 1..19 (range end is exclusive)
layers = [3]  # layer index(es) handed to BuildStoryEmbeddings as `layer`
dataset_loc = "./data"
embeddings_loc = "./embeddings"
model_loc = "./llama2-7b-hf"  # local Hugging Face checkpoint directory

# One story per entry in clean_words_and_times, named by the filename up to its
# first ".". Hidden entries (e.g. ".DS_Store") are skipped because they would
# produce an empty story name. A set removes stems that appear with several
# extensions, so no story is built twice. Sorting makes the processing order
# deterministic, since os.listdir returns entries in arbitrary order.
stories = sorted(
    {
        entry.split(".")[0]
        for entry in os.listdir(f"{dataset_loc}/clean_words_and_times")
        if not entry.startswith(".")
    }
)

# Load config, weights and tokenizer strictly from the local checkpoint.
# The tokenizer now also uses local_files_only, matching the config and model.
lm_config = AutoConfig.from_pretrained(model_loc, local_files_only=True)
# NOTE(review): the model is placed on CPU here, although the sweep passes
# device="cuda" to BuildStoryEmbeddings. Presumably the builder moves the
# model/inputs to that device itself; confirm against story_to_embeddings.
lm = LlamaForCausalLM.from_pretrained(model_loc, config=lm_config, local_files_only=True).to("cpu")
tokenizer = AutoTokenizer.from_pretrained(model_loc, local_files_only=True)

# Build one full set of story embeddings for every (layer, context length)
# pair. Each pair is written under its own directory, keyed by model tag,
# layer and context length.
for layer in layers:
    for context_len in context_lens:
        # The context function gets context_len - 1. Presumably this counts
        # the preceding words, so the window including the current word spans
        # context_len words; TODO confirm in story_to_embeddings.
        context_fn = n_previous_words_context(context_len - 1)
        out_dir = (
            f"{embeddings_loc}/embeddings_sweep/{model}"
            f"/layer_{layer}_context_{context_len}"
        )
        builder = BuildStoryEmbeddings(
            tokenizer,
            lm,
            context_fn,
            out_dir,
            layer=layer,
            batch_size=16,
            device="cuda",
        )
        builder.make_all_embeddings(stories)