import pandas as pd
import torch 
from transformers import AutoTokenizer, AutoModel
import numpy as np
import os

# Index into model.layers of the layer whose forward output is captured.
layer_th = 14
# CSV file read by the main script; each row's "text" column is processed.
input_file = "train_4096.csv"
# Directory the per-sentence .npy files are written to, and the default
# directory load_embeddings() reads from. It is keyed by the hooked layer index.
output_path = f"C4/{layer_th}"

def load_embeddings(input_file_path=output_path, num_embeddings=8):
    """Load saved per-sentence embeddings from a directory.

    Reads ``sentence_{i}_embedding.npy`` for ``i`` in ``0 .. num_embeddings - 1``
    from ``input_file_path``, in index order.

    Args:
        input_file_path: Directory containing the ``.npy`` files. Defaults to
            the module-level ``output_path``.
        num_embeddings: How many consecutive sentence files to load, starting
            at index 0. Defaults to 8, the count that was previously
            hard-coded.

    Returns:
        A list of NumPy arrays, one per file, ordered by sentence index.

    Raises:
        FileNotFoundError: If one of the expected files does not exist.
    """
    return [
        np.load(os.path.join(input_file_path, f"sentence_{i}_embedding.npy"))
        for i in range(num_embeddings)
    ]


if __name__ == "__main__":
    # Load the input sentences; every row must provide a "text" column.
    df = pd.read_csv(input_file)

    model_path = "./llama3.2"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModel.from_pretrained(model_path)
    model.eval()

    print(model)

    # Create the destination up front; otherwise the first open() below fails
    # when the directory does not exist yet.
    os.makedirs(output_path, exist_ok=True)

    # Filled by the forward hook with the hooked layer's output on each pass.
    hidden_states = []

    def forward_hook(module, inputs, output):
        """Record the hooked module's forward output."""
        hidden_states.append(output)

    # Hook a different module here (e.g. model.get_input_embeddings()) to
    # capture a different set of hidden states.
    embed_layer = model.layers[layer_th]
    hook_handle = embed_layer.register_forward_hook(forward_hook)

    try:
        # Inference only: skip building an autograd graph for every sentence.
        with torch.no_grad():
            for i, row in df.iterrows():
                input_text = row["text"]
                input_ids = tokenizer(input_text, return_tensors="pt").input_ids

                # Start from an empty buffer so index 0 below is always the
                # capture from this forward pass.
                hidden_states.clear()

                # Run the model; the hook captures the output of
                # model.layers[layer_th] as a side effect.
                _ = model(input_ids)

                # hidden_states[0] is the single captured output; the second
                # [0] indexes into it.
                # NOTE(review): assumes the layer returns a tuple whose first
                # element is the hidden-state tensor - confirm for the
                # installed transformers version.
                embedding_output = hidden_states[0][0]
                embedding_output = embedding_output.detach().cpu().numpy()

                file_name = os.path.join(output_path, f"sentence_{i}_embedding.npy")
                print(f"process {i} sentence")
                with open(file_name, 'wb') as f:
                    np.save(f, embedding_output)
                    print(f"save {file_name}")
                print(f"process finish {i} sentence")
    finally:
        # Always detach the hook so the layer is left unmodified.
        hook_handle.remove()
        hidden_states.clear()
        
        
        

        

    