import torch
import json
import os
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

def get_all_hidden_states(model_path, data_path, save_dir):
    """Collect per-layer hidden states at selected token positions and save them.

    For every item in the JSON dataset, ``get_position`` builds the full text
    and two lists of token positions ("correct" and "incorrect"). The model
    is run once on that text with ``output_hidden_states=True`` and, for each
    returned hidden-state tensor, the rows at those positions are kept.

    The result written to ``{save_dir}/hidden.pt`` is a list containing
    ``model.config.num_hidden_layers + 1`` dicts (one per hidden-state tensor
    the model returns). Each dict maps the item's index in the dataset to
    ``{"correct": Tensor, "incorrect": Tensor}``, both stored on CPU.

    Args:
        model_path: Location passed to ``from_pretrained`` for both the model
            and the tokenizer.
        data_path: Path to a UTF-8 JSON file; it is iterated item by item and
            each item is handed to ``get_position``.
        save_dir: Output directory; created if it does not exist.
    """
    # Do the cheap, failure-prone steps first so a bad path is reported
    # before the (expensive) model load rather than after it.
    os.makedirs(save_dir, exist_ok=True)

    with open(data_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")

    # One bucket per hidden-state tensor returned by the model.
    # NOTE(review): per the transformers docs this is the embedding output
    # plus one tensor per layer, hence the "+ 1" — confirm for this model.
    layer_num = model.config.num_hidden_layers + 1
    hidden_dict = [{} for _ in range(layer_num)]

    for k, item in enumerate(tqdm(data)):
        cor_pos, incor_pos, text = get_position(item, tokenizer)
        inputs = tokenizer(text, return_tensors="pt").to(model.device)

        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)

        # Empty position lists yield empty index tensors, which in turn
        # select zero rows (shape (0, hidden)) instead of failing.
        cor_idx = torch.tensor(cor_pos, dtype=torch.long)
        incor_idx = torch.tensor(incor_pos, dtype=torch.long)

        for i, layer_h in enumerate(outputs.hidden_states):
            # Batch size is 1, so take the only sequence in the batch.
            seq_h = layer_h[0].detach()
            # Select the few needed rows on the tensor's own device and only
            # then copy to CPU; this avoids transferring the full
            # (seq_len, hidden) activation of every layer to host memory.
            # With device_map="auto" layers may live on different devices,
            # so the index is moved to match each tensor.
            hidden_dict[i][k] = {
                "correct": seq_h[cor_idx.to(seq_h.device)].cpu(),
                "incorrect": seq_h[incor_idx.to(seq_h.device)].cpu(),
            }

        # Drop the reference to the full activations before the next pass.
        del outputs

    torch.save(hidden_dict, f"{save_dir}/hidden.pt")

def get_position(item, tokenizer):
    """Build the full reasoning text and locate the last token of each chunk.

    The text starts as ``item["problem"] + "<think>"``. Each entry of
    ``item["thinking_chunks"]`` is appended in order, followed by a blank
    line (``"\\n\\n"``). After every append the text so far is tokenized and
    the index of its final token is recorded.

    Args:
        item: Mapping with keys ``"problem"``, ``"thinking_chunks"`` and
            ``"correct_answer"``; every chunk is a mapping with ``"chunk"``
            and ``"medium_answer"``.
        tokenizer: Callable that takes a string and returns a mapping with
            an ``"input_ids"`` sequence.

    Returns:
        A tuple ``(cor_pos, incor_pos, text)``: positions of chunks whose
        ``"medium_answer"`` equals ``item["correct_answer"]``, positions of
        all other chunks, and the fully assembled text.
    """
    text = item["problem"] + "<think>"
    matched, unmatched = [], []

    for entry in item["thinking_chunks"]:
        text += entry["chunk"] + "\n\n"
        # Index of the last token of the text accumulated so far.
        last_token = len(tokenizer(text)["input_ids"]) - 1

        is_match = entry["medium_answer"] == item["correct_answer"]
        target = matched if is_match else unmatched
        target.append(last_token)

    return matched, unmatched, text

if __name__ == "__main__":
    # Placeholder locations: edit these to point at the real model
    # checkpoint, the probe training data (JSON) and the output directory.
    model_path = "/path/to/model/DeepSeek-R1-Distill-Qwen-14B"
    data_path = "/path/to/probe_train_data"
    save_dir = "/path/to/save/hidden_states"

    get_all_hidden_states(
        model_path=model_path,
        data_path=data_path,
        save_dir=save_dir,
    )