import ujson as json

import numpy as np
from transformers import AutoTokenizer
from collections import defaultdict
# import matplotlib.pyplot as plt
from tqdm import tqdm


def extract_rollout_logprobs(rollout, answer_len, suffix_len, think_end_id):
    """Turn one rollout's probe records into ``(token_id, answer_nll)`` pairs.

    Each ``item`` in ``rollout`` is a dict with ``prompt_ids`` and ``logprobs``.
    The token of interest is the last id before the trailing ``suffix_len`` ids
    of ``prompt_ids`` (presumably the forced "Final Answer" suffix appended after
    the reasoning prefix -- confirm against the data producer). Its score is the
    negated mean of the last ``answer_len`` entries of ``item["logprobs"]``.

    Iteration stops at the first item whose token of interest is
    ``think_end_id``. The last ``suffix_len`` collected pairs are then dropped,
    matching the original script's truncation.

    Args:
        rollout: Iterable of dicts with ``"prompt_ids"`` and ``"logprobs"`` lists.
        answer_len: Number of trailing logprob entries to average.
        suffix_len: Number of trailing prompt ids to skip (must be > 0).
        think_end_id: Token id that ends the reasoning span.

    Returns:
        List of ``(token_id, -mean(logprobs))`` tuples.
    """
    pairs = []
    for item in rollout:
        token_id = item["prompt_ids"][:-suffix_len][-1]
        if token_id == think_end_id:
            break
        # Negated mean log-prob. This is >= 0 assuming the stored values are
        # log-probabilities (<= 0) -- TODO confirm.
        pairs.append((token_id, -np.mean(item["logprobs"][-answer_len:])))
    return pairs[:-suffix_len]


def compute_logprob_changes(all_logprobs, reflection_token_ids, reflection_pattern_words):
    """Measure relative changes between adjacent answer-NLL values.

    For every adjacent pair ``(prev, curr)`` in each sequence, the relative
    change is ``abs(curr - prev) / prev``. Pairs with ``prev == 0`` are skipped
    to avoid division by zero.

    When the token sequence starting at ``curr`` exactly matches a reflection
    word's token ids, the script also records the relative change from
    ``prev`` (just before the word) to the value of the token right AFTER the
    word. Matches that reach the end of the sequence have no following token,
    so they are skipped. The original code raised ``IndexError`` here.

    Args:
        all_logprobs: Iterable of sequences of ``(token_id, value)`` pairs.
        reflection_token_ids: Token-id lists, parallel to ``reflection_pattern_words``.
        reflection_pattern_words: Labels used as keys of the reflection result.

    Returns:
        ``(reflection_changes, all_token_changes, token_changes)``:
        a word -> list-of-changes mapping, a flat list of every relative change,
        and a token_id -> list-of-changes mapping keyed by ``curr``'s token id.
    """
    reflection_changes = defaultdict(list)
    all_token_changes = []
    token_changes = defaultdict(list)

    for logprob_seq in all_logprobs:
        seq_len = len(logprob_seq)
        if seq_len < 2:
            continue

        for i in range(1, seq_len):
            prev_logprob = logprob_seq[i - 1][1]
            curr_token_id, curr_logprob = logprob_seq[i]

            if prev_logprob == 0:  # avoid division by zero
                continue

            change = abs(curr_logprob - prev_logprob) / prev_logprob
            all_token_changes.append(change)
            token_changes[curr_token_id].append(change)

            # Check whether a reflection word starts at position i.
            for word, reflection_ids in zip(reflection_pattern_words, reflection_token_ids):
                if not reflection_ids or curr_token_id != reflection_ids[0]:
                    continue
                end = i + len(reflection_ids)
                # A slice shorter than the word (truncated sequence) never compares equal.
                if [tok for tok, _ in logprob_seq[i:end]] != list(reflection_ids):
                    continue
                if end >= seq_len:  # no token after the word to compare against
                    continue
                reflection_changes[word].append(
                    abs(logprob_seq[end][1] - prev_logprob) / prev_logprob
                )

    return reflection_changes, all_token_changes, token_changes


def rank_tokens_by_change(token_changes, min_count=5):
    """Return ``(token_id, mean_change)`` pairs sorted by mean change, descending.

    Only tokens with at least ``min_count`` observations are kept. Ties keep
    the mapping's iteration order because ``sorted`` is stable.
    """
    averages = {
        token_id: np.mean(changes)
        for token_id, changes in token_changes.items()
        if len(changes) >= min_count
    }
    return sorted(averages.items(), key=lambda x: x[1], reverse=True)


if __name__ == "__main__":
    MODEL_PATH = "Qwen/QwQ-32B"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

    # NOTE(review): the words are encoded without a leading space. Mid-sentence
    # occurrences are usually tokenized with a leading-space marker, so they may
    # not match these ids -- confirm this is intended.
    reflection_pattern_words = [
        r"wait",
        r"Wait",
        r"retry",
        r"Retry",
        r"recheck",
        r"Recheck",
        r"however",
        r"However",
        r"alternatively",
        r"Alternatively",
        r"therefore",
        r"Therefore",
        r"hmm",
        r"Hmm"
    ]
    reflection_token_ids = [
        tokenizer.encode(word, add_special_tokens=False)
        for word in reflection_pattern_words
    ]
    # Loop-invariant: id of the token that closes the reasoning span.
    think_end_id = tokenizer.encode("</think>", add_special_tokens=False)[0]

    all_logprobs = []
    # Available question indices: 0, 5, 7, 8, 9, 11, 12, 15, 16 (5 is excluded here).
    for q_idx in tqdm([0, 7, 8, 9, 11, 12, 15, 16]):
        with open(f"qwq32b_aime24_logprobs_{q_idx}.json", "r", encoding="utf-8") as rfile:
            data = json.load(rfile)

        answer_len = len(tokenizer.encode(data["answer"], add_special_tokens=False))
        suffix_len = len(tokenizer.encode(
            "\n\n**Final Answer**\n\\boxed{" + data["answer"], add_special_tokens=False
        ))
        for rollout in data["logprobs"]:
            all_logprobs.append(
                extract_rollout_logprobs(rollout, answer_len, suffix_len, think_end_id)
            )

    # Collect relative logprob changes around reflection markers and for all tokens.
    reflection_changes, all_token_changes, token_changes = compute_logprob_changes(
        tqdm(all_logprobs), reflection_token_ids, reflection_pattern_words
    )

    # Overall mean / std of the relative change.
    overall_avg_change = np.mean(all_token_changes)
    overall_std_change = np.std(all_token_changes)

    print(f"所有token的平均相对变化量: {overall_avg_change:.4f} ± {overall_std_change:.4f}")
    print("\n反思标记的相对变化量:")

    # Iterate over every configured word so that words never found are
    # reported too. The defaultdict only holds words that were matched.
    for word in reflection_pattern_words:
        changes = reflection_changes.get(word, [])
        if len(changes) > 0:
            avg_change = np.mean(changes)
            std_change = np.std(changes)
            diff_from_overall = avg_change - overall_avg_change
            print(f"'{word}': {avg_change:.4f} ± {std_change:.4f} (与整体差异: {diff_from_overall:.4f})")
            print(f"  出现次数: {len(changes)}")
        else:
            print(f"'{word}': 未在序列中找到")

    # Rank tokens seen often enough by their mean relative change.
    sorted_tokens = rank_tokens_by_change(token_changes, min_count=5)

    # Print the top-100 tokens.
    print("\n平均相对变化量 Top-100 的 Token:")
    for rank, (token_id, avg_change) in enumerate(sorted_tokens[:100], start=1):
        token_text = tokenizer.decode([token_id])
        print(f"{rank}. Token ID: {token_id}, 文本: '{token_text}', 平均变化量: {avg_change:.4f}, 出现次数: {len(token_changes[token_id])}")