import torch
import json
import os
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

def get_all_hidden_states(model_path, data_path, save_dir):
    """Extract per-layer hidden states at chunk-boundary positions and save them.

    Each item of the JSON dataset at ``data_path`` is turned into a thinking
    text (via ``get_position``) and run through the causal LM at
    ``model_path`` with ``output_hidden_states=True``. For every layer output
    the model returns, the hidden vectors at the boundary positions preceding
    execution / reflection / transition chunks are gathered.

    The saved object is a list with one dict per layer output. Each dict maps
    the item's index in ``data`` to
    ``{"exec": Tensor, "ref": Tensor, "tran": Tensor}``. Each tensor holds one
    row per recorded position and lives on the CPU. The list is written to
    ``<save_dir>/hidden.pt`` with ``torch.save``.

    Args:
        model_path: Path or hub id passed to ``from_pretrained``.
        data_path: Path to a UTF-8 JSON file holding a list of items in the
            format accepted by ``get_position``.
        save_dir: Output directory; created if it does not exist.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_path, device_map="auto", torch_dtype=torch.bfloat16
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")

    with open(data_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Create the output directory up front so a bad path fails before the
    # expensive forward passes rather than after them.
    os.makedirs(save_dir, exist_ok=True)

    # transformers returns the embedding output plus one tensor per decoder
    # layer when output_hidden_states=True, hence the +1.
    layer_num = model.config.num_hidden_layers + 1
    hidden_dict = [{} for _ in range(layer_num)]

    for k, item in enumerate(tqdm(data)):
        exec_pos, ref_pos, tran_pos, text = get_position(item, tokenizer)
        inputs = tokenizer(text, return_tensors="pt").to(model.device)

        # Index tensors are built once per item on the CPU; they are moved to
        # each layer's device below because device_map="auto" may place
        # layers on different devices.
        index_cpu = {
            "exec": torch.tensor(exec_pos, dtype=torch.long),
            "ref": torch.tensor(ref_pos, dtype=torch.long),
            "tran": torch.tensor(tran_pos, dtype=torch.long),
        }

        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)

            # Gather the few rows we need while still on the model's device,
            # then copy only those rows to the CPU. Copying every full layer
            # output to the CPU first wastes host memory and transfer time.
            for i, layer_h in enumerate(outputs.hidden_states):
                h = layer_h[0]  # single sequence in the batch
                hidden_dict[i][k] = {
                    name: h[idx.to(h.device)].cpu()
                    for name, idx in index_cpu.items()
                }

        # Drop this item's activations before processing the next one.
        del outputs

    torch.save(hidden_dict, os.path.join(save_dir, "hidden.pt"))

def get_position(item, tokenizer):
    """Rebuild the thinking text for *item* and locate chunk-boundary tokens.

    The text begins as ``item["problem"] + "<think>"``. Each entry of
    ``item["thinking_chunks"]`` is then appended, followed by ``"\\n\\n"``.
    After each chunk is appended, the index of the last token of the text so
    far is computed by tokenizing that prefix with ``tokenizer``. That index
    is recorded in one of three lists, chosen by the ``"type"`` of the *next*
    chunk: ``"execution"``, ``"reflection"`` or ``"transition"``. Any other
    type is ignored. The final chunk has no successor, so it is never
    recorded.

    Returns:
        ``(exec_pos, ref_pos, tran_pos, text)``: three lists of token indices
        and the full reconstructed text.
    """
    text = item["problem"] + "<think>"
    chunks = item["thinking_chunks"]
    buckets = {"execution": [], "reflection": [], "transition": []}

    for idx, chunk in enumerate(chunks):
        text += chunk["chunk"] + "\n\n"
        # NOTE(review): the caller tokenizes the full text, so these indices
        # assume the prefix tokenization is a prefix of the full one.
        last_token = len(tokenizer(text)["input_ids"]) - 1

        if idx + 1 >= len(chunks):
            continue
        next_type = chunks[idx + 1]["type"]
        for label, bucket in buckets.items():
            if next_type == label:
                bucket.append(last_token)
                break

    return buckets["execution"], buckets["reflection"], buckets["transition"], text

if __name__ == "__main__":
    import argparse

    # The defaults are the original placeholder paths. Override them on the
    # command line instead of editing the source.
    parser = argparse.ArgumentParser(
        description="Extract per-layer hidden states at thinking-chunk boundaries."
    )
    parser.add_argument(
        "--model_path", default="/path/to/model/DeepSeek-R1-Distill-Qwen-14B"
    )
    parser.add_argument("--data_path", default="/path/to/seal_train_data")
    parser.add_argument("--save_dir", default="/path/to/save/hidden_states")
    args = parser.parse_args()

    get_all_hidden_states(args.model_path, args.data_path, args.save_dir)