import torch
import torch.nn as nn
import json
import os
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.utils.data import TensorDataset, DataLoader

class MLPProbe(nn.Module):
    """Probe that maps a feature vector to a single sigmoid probability.

    When ``hidden_dim == 0`` the probe is one linear layer stored as
    ``self.classifier``. Otherwise it is Linear -> ReLU -> Linear stored as
    ``self.mlp``. The attribute names matter: ``state_dict`` keys are derived
    from them, and checkpoints are loaded with ``load_state_dict``.

    Args:
        input_dim: Size of the last axis of the input features.
        hidden_dim: Width of the hidden layer; 0 selects the linear probe.
    """

    def __init__(self, input_dim, hidden_dim=0):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        if hidden_dim != 0:
            self.mlp = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, 1),
            )
        else:
            self.classifier = nn.Linear(input_dim, 1)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid probabilities with the trailing size-1 axis dropped."""
        # Only one of the two heads exists, chosen by hidden_dim at build time.
        head = self.classifier if self.hidden_dim == 0 else self.mlp
        scores = head(x).squeeze(-1)
        return self.sigmoid(scores)

def get_hidden_state(model, tokenizer, text, layer_idx=-1):
    """Return the hidden state at the final sequence position of one layer.

    The text is tokenized to PyTorch tensors on ``model.device``. The model
    runs with gradients disabled and ``output_hidden_states=True``. Entry
    ``layer_idx`` of ``outputs.hidden_states`` is then sliced at the last
    position, and the result is moved to the CPU.

    NOTE(review): with ``padding=True`` and several texts, position -1 is the
    real last token only if the tokenizer pads on the left. The ``__main__``
    section of this file loads the tokenizer with ``padding_side="left"``.

    Args:
        model: Callable model whose output has a ``hidden_states`` sequence.
        tokenizer: Tokenizer returning an object with ``.to(device)``.
        text: Input string (or batch of strings).
        layer_idx: Index into ``hidden_states``; -1 selects the last entry.

    Returns:
        A CPU tensor with the sequence axis removed; presumably shaped
        (batch, hidden) under the usual (batch, seq, hidden) layout.
    """
    encoded = tokenizer(text, return_tensors="pt", padding=True)
    encoded = encoded.to(model.device)
    with torch.no_grad():
        layer_states = model(**encoded, output_hidden_states=True).hidden_states
        selected_layer = layer_states[layer_idx]
        final_position = selected_layer[:, -1, :]
    return final_position.cpu()

# Number of data items per output slice in evaluate_truncation: the data list
# is walked in steps of this size and each slice is written to its own
# "{start}_{end}.json" file.
mini_batch = 10

def _truncate_item(model, tokenizer, probe_model, item, layer_idx, threshold):
    """Build the truncated-thinking record for one data item.

    Chunks are appended to ``problem + "<think>"`` one at a time. Every chunk
    except the last is followed by a blank line. After each non-final chunk the
    probe scores the hidden state of the text so far, and the thinking is cut
    as soon as that score is ``>= threshold``. The final chunk is appended
    without probing, so an item the probe never fires on keeps its full
    thinking. An item with no chunks keeps only the prompt prefix and a
    ``medium_answer`` of -1.
    """
    problem = item["problem"]
    chunks = item["thinking_chunks"]
    current_text = problem + "<think>"
    medium_answer = -1
    last_index = len(chunks) - 1

    for j, chunk in enumerate(chunks):
        medium_answer = chunk["medium_answer"]
        if j == last_index:
            current_text += chunk["chunk"]
            break
        current_text += chunk["chunk"] + "\n\n"
        hidden_state = get_hidden_state(model, tokenizer, current_text, layer_idx)
        # Inference only: don't let the probe build an autograd graph.
        with torch.no_grad():
            pred = probe_model(hidden_state.float()).item()
        del hidden_state
        if pred >= threshold:
            break

    return {
        "problem": problem,
        "truncated_thinking": current_text,
        "truncated_length": len(tokenizer(current_text)["input_ids"]),
        "correct_answer": item["correct_answer"],
        "medium_answer": medium_answer,
    }


def evaluate_truncation(model, tokenizer, probe_model, data_path, layer_idx=48,
                        threshold=0.85, output_dir="/path/to/probe_results",
                        batch_size=None):
    """Truncate each item's thinking where the probe fires and save the results.

    ``data_path`` must hold a JSON list of items. Each item has ``"problem"``,
    ``"correct_answer"`` and ``"thinking_chunks"``, where the last is a list of
    dicts with ``"chunk"`` and ``"medium_answer"``. Items are processed in
    slices of ``batch_size``. Each slice is written to
    ``<output_dir>/<start>_<end>.json``, and a slice whose file already exists
    is skipped, so an interrupted run can be resumed.

    Args:
        model: Language model passed to ``get_hidden_state``.
        tokenizer: Tokenizer for ``model``; also used to count tokens.
        probe_model: Callable mapping a hidden state to a probability tensor.
        data_path: Path of the input JSON file.
        layer_idx: Hidden-state layer fed to the probe.
        threshold: Probe score at or above which the thinking is truncated.
        output_dir: Directory for the per-slice JSON files (created if needed).
        batch_size: Items per output file; defaults to the module's
            ``mini_batch``.
    """
    if batch_size is None:
        batch_size = mini_batch

    with open(data_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    for start_pos in range(0, len(data), batch_size):
        end_pos = min(start_pos + batch_size, len(data))
        save_path = os.path.join(output_dir, f"{start_pos}_{end_pos}.json")
        if os.path.exists(save_path):
            print(f"File {save_path} already exists, skipping...")
            continue
        print(f"Processing data {start_pos} to {end_pos}...")

        new_data = [
            _truncate_item(model, tokenizer, probe_model, item, layer_idx, threshold)
            for item in tqdm(data[start_pos:end_pos])
        ]

        os.makedirs(output_dir, exist_ok=True)
        # Write to a temp file and rename atomically. The skip check above
        # trusts any existing file, so a half-written one must never appear
        # under the final name.
        tmp_path = save_path + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(new_data, f, ensure_ascii=False, indent=4)
        os.replace(tmp_path, save_path)


if __name__ == "__main__":
    model_path = "/path/to/model/DeepSeek-R1-Distill-Qwen-14B"
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", torch_dtype=torch.bfloat16)
    # Left padding keeps the real last token at position -1 in get_hidden_state.
    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")

    probe_model_path = "/path/to/your/probe/model/model.pth"
    # The probe's input width must equal the LM hidden size. Read it from the
    # model config instead of hard-coding it (it was 5120 for this 14B model).
    input_dim = model.config.hidden_size
    probe_model = MLPProbe(input_dim=input_dim, hidden_dim=0)
    # map_location="cpu": the probe runs on CPU tensors (get_hidden_state moves
    # features to CPU), so a checkpoint saved on GPU still loads anywhere.
    # weights_only=True: a state dict needs only tensors, so refuse arbitrary
    # pickled objects.
    state_dict = torch.load(probe_model_path, map_location="cpu", weights_only=True)
    probe_model.load_state_dict(state_dict)
    probe_model.eval()

    data_path = "/path/to/probe_eval_data"
    evaluate_truncation(model, tokenizer, probe_model, data_path, layer_idx=48)