import torch
import re
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
import json


# Local checkpoint to analyse and the greedy-decoding result files to read.
model_path = "path/to/model/DeepSeek-R1-Distill-Qwen-14B"
input_data_paths = [
    "path/to/LongCoT/CharCount/hiddenStates/greedyAnswers/qwen14_zh_low_same_ratio_2-3-3_greedy.json",
    "path/to/LongCoT/CharCount/hiddenStates/greedyAnswers/qwen14_zh_low_same_ratio_2-4-4_greedy.json",
    "path/to/LongCoT/CharCount/hiddenStates/greedyAnswers/qwen14_zh_low_same_ratio_3-4-4_greedy.json",
]

tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side='left')
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    low_cpu_mem_usage=True,
    local_files_only=True,
    # The analysis reads the attention probabilities back from the forward
    # pass. Fused SDPA / flash-attention kernels do not materialise them, so
    # request the eager implementation explicitly rather than relying on a
    # library-version-dependent fallback.
    attn_implementation="eager",
    output_attentions=True,
)
model.eval()  # inference only: disable dropout

def get_category_attention(token_idx, tokens, attention_scores, user_assistant_content):
    """Average the attention one query position pays to each token category.

    Context positions ``0..token_idx`` are considered, except positions 0-2,
    which are ignored entirely. This includes the ``'<｜Assistant｜>'`` check,
    so that marker is not recognised if it sits at one of those positions.
    Each remaining position is put in one bucket:

    * ``'Question'``: before the ``'<｜Assistant｜>'`` token, and the token
      string is a substring of ``user_assistant_content``.
    * ``'Mid Results'``: at or after ``'<｜Assistant｜>'``, and the token
      contains a digit or a Chinese numeral followed by ``个``.
    * ``'Others'``: everything else.

    Args:
        token_idx: Index of the query position whose attention row is given.
        tokens: Decoded token strings for the whole sequence.
        attention_scores: Attention weights indexed by context position.
            Presumably the row for ``token_idx``, of length ``token_idx + 1``.
        user_assistant_content: Question text used for the substring test.

    Returns:
        dict mapping each category name to the mean attention of its bucket,
        or ``0`` when the bucket is empty. Keys are ordered
        'Question', 'Mid Results', 'Others'.
    """
    numeral_pattern = r'\d|[一二三四五六七八九十]个'
    buckets = {'Question': [], 'Mid Results': [], 'Others': []}
    in_answer = False

    window = tokens[:token_idx + 1]
    for pos in range(3, len(window)):
        piece = window[pos]
        if piece == '<｜Assistant｜>':
            in_answer = True

        if in_answer:
            bucket = 'Mid Results' if re.search(numeral_pattern, piece) else 'Others'
        else:
            bucket = 'Question' if piece in user_assistant_content else 'Others'
        buckets[bucket].append(attention_scores[pos])

    return {name: (np.mean(vals) if vals else 0) for name, vals in buckets.items()}

def split_first_reflection(text):
    """Cut ``text`` just before its first reflection keyword.

    Keywords are tried in order: "对吗", then "吗". Only the first keyword
    that occurs anywhere in ``text`` is used.

    The text preceding that keyword's first occurrence is returned if the
    phrase "所以，总共" ("so, in total") appears within the prefix's last 20
    characters. Otherwise ``None`` is returned, and the remaining keywords are
    not tried. ``None`` is also returned when no keyword occurs at all.
    """
    summary_phrase = "所以，总共"
    for keyword in ("对吗", "吗"):
        cut = text.find(keyword)
        if cut == -1:
            continue
        head = text[:cut]
        return head if summary_phrase in head[-20:] else None
    return None

# Per-sample results: one row per kept sample, columns ordered as
# ['Question', 'Mid Results', 'Others'].
#   all_last_scores: normalized attention of the final token of the truncated
#                    text, i.e. the position just before the reflection keyword.
#   all_avg_scores:  the same quantity averaged over the earlier thinking tokens.
all_last_scores = []
all_avg_scores = []

# Attention is averaged over decoder layers with index >= first_layer.
# NOTE(review): 40 is hard-coded for this checkpoint; the layer count is
# model-specific.
first_layer = 40
category_order = ['Question', 'Mid Results', 'Others']

for input_data_path in input_data_paths:
    with open(input_data_path, "r", encoding='utf-8') as f:
        data = json.load(f)

    cnt = 0  # samples kept from the current file
    for item in data:
        text = f"<｜User｜>{item['question']}<｜Assistant｜><think>{item['model_answer']}"
        # Keep only the text before the first reflection keyword; skip samples
        # for which split_first_reflection returns nothing.
        text = split_first_reflection(text)
        if not text:
            continue
        cnt += 1
        print(f"Processing sample {cnt}")

        inputs = tokenizer(text, return_tensors="pt").to(model.device)
        input_ids = inputs["input_ids"][0].tolist()
        tokens = [tokenizer.decode([token_id]) for token_id in input_ids]

        # Only positions after "<think>" are scored. Without a standalone
        # "<think>" token, or with nothing after it, there is nothing to
        # measure, so skip the sample (and its forward pass) instead of crashing.
        try:
            think_pos = tokens.index("<think>")
        except ValueError:
            print(f"Skipping sample {cnt}: no standalone '<think>' token")
            continue
        if think_pos == len(tokens) - 1:
            print(f"Skipping sample {cnt}: no tokens after '<think>'")
            continue

        with torch.no_grad():
            outputs = model(**inputs)

        # outputs.attentions holds one tensor per layer, presumably shaped
        # (batch, heads, seq, seq). Average over the selected layers (dim 0 of
        # the stack) and heads (dim 2), then take batch element 0 (batch size 1).
        layer_attentions = outputs.attentions[first_layer:]
        averaged_attention = torch.stack(layer_attentions).mean(dim=(0, 2))[0]
        # One device-to-host copy per sample instead of one per token.
        attention_matrix = averaged_attention.cpu().float().numpy()
        # Free the per-layer attention tensors before the next forward pass so
        # two samples' attentions are never resident at once.
        del outputs, layer_attentions, averaged_attention

        # Question text as reconstructed from the decoded tokens; used by
        # get_category_attention to recognise question tokens.
        question_match = re.search(
            r"<｜User｜>(.*?)<｜Assistant｜>",
            "".join(tokens),
            re.DOTALL,
        )
        user_assistant_content = question_match.group(1) if question_match else ""

        # Category attention for every query position after "<think>".
        # Each value is divided by the uniform-attention baseline 1/context_len
        # (i.e. multiplied by context_len), so 1.0 means "uniform attention".
        all_mean_scores = []
        for token_idx in range(think_pos + 1, len(tokens)):
            attention_scores = attention_matrix[token_idx, :token_idx + 1]
            mean_scores = get_category_attention(
                token_idx, tokens, attention_scores, user_assistant_content
            )
            context_len = token_idx + 1
            all_mean_scores.append({
                key: float(value) * context_len if value > 0 else 0
                for key, value in mean_scores.items()
            })

        # The final position is compared against the mean over all earlier ones.
        *earlier_scores, last_token_scores = all_mean_scores
        # NOTE(review): a category with no tokens yet scores 0 and still counts
        # toward this average.
        avg_all_tokens = {
            key: (sum(s[key] for s in earlier_scores) / len(earlier_scores)) if earlier_scores else 0
            for key in category_order
        }

        all_last_scores.append([last_token_scores[c] for c in category_order])
        all_avg_scores.append([avg_all_tokens[c] for c in category_order])

# Aggregate across samples. Rows are samples; columns follow
# ['Question', 'Mid Results', 'Others'].
if not all_last_scores:
    # np.mean over an empty list would silently produce NaN bars.
    raise SystemExit("No samples were kept by split_first_reflection; nothing to plot.")

last_matrix = np.asarray(all_last_scores, dtype=float)
avg_matrix = np.asarray(all_avg_scores, dtype=float)

final_last_scores = last_matrix.mean(axis=0).tolist()
final_avg_scores = avg_matrix.mean(axis=0).tolist()
# Population standard deviation (ddof=0) across samples, one per category.
last_stds = last_matrix.std(axis=0)
avg_stds = avg_matrix.std(axis=0)

categories = ['Question', 'Mid Results', 'Others']
x = np.arange(len(categories))
width = 0.35

fig, ax = plt.subplots(figsize=(8, 5))

# Dark bars: the last token of the truncated text (just before the reflection keyword).
rects1 = ax.bar(x - width/2, final_last_scores, width,
                yerr=last_stds, capsize=5,
                label='Reflection Tokens', color='navy', edgecolor='black')

# Light bars: average over all earlier thinking tokens.
rects2 = ax.bar(x + width/2, final_avg_scores, width,
                yerr=avg_stds, capsize=5,
                label='Other Tokens', color='lightblue', edgecolor='black')

ax.set_ylabel('Normalized Attention Score', fontsize=14)
ax.set_xticks(x)
ax.set_xticklabels(categories)
ax.tick_params(axis='x', labelsize=12)
ax.tick_params(axis='y', labelsize=12)

ax.legend(loc='upper left', fontsize=14)

def autolabel(rects, ax=None):
    """Annotate each bar with its height formatted as a multiplier (e.g. ``1.23×``).

    The label is anchored at the bar's top-left corner and offset 1 pt right
    and 3 pt up, left-aligned.

    Args:
        rects: Iterable of bar patches, e.g. the container returned by ``Axes.bar``.
        ax: Axes to draw the annotations on. When omitted, each bar's own
            ``.axes`` is used, so no module-level ``ax`` is required.
    """
    for rect in rects:
        target = ax if ax is not None else rect.axes
        height = rect.get_height()
        target.annotate(
            f'{height:.2f}×',
            xy=(rect.get_x(), height),
            xytext=(1, 3),
            textcoords="offset points",
            ha='left',
            va='bottom',
            fontsize=9,
            color='black',
        )

# Label both bar groups with their values, then write the figure to disk.
for bar_group in (rects1, rects2):
    autolabel(bar_group)

output_file = "path/to/LongCoT/CharCount/hiddenStates/comparison_normalized_category_attention_bars.png"
plt.tight_layout()
plt.savefig(output_file, dpi=300)
print(f"Saved comparison plot with multi-color legend to {output_file}")