import json
import re
from tqdm import tqdm
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import random
import numpy as np
import matplotlib.pyplot as plt


def extract_last_digit(text):
    """Return the number that occurs last in *text*, or ``None`` if there is none.

    Two kinds of numbers are recognised:
      * a maximal run of decimal digits (regex ``\\d+``), parsed with ``int``;
      * a single Chinese numeral character from the table below
        (零 through 九, plus 两 = 2 and 仨 = 3).

    Chinese numerals are read one character at a time; characters outside the
    table (e.g. 十) are ignored, so "十二" yields 2.
    """
    numeral_values = {
        "零": 0, "一": 1, "二": 2, "三": 3, "四": 4,
        "五": 5, "六": 6, "七": 7, "八": 8, "九": 9,
        "两": 2, "仨": 3,
    }

    # (start offset, value) for every number found, in either notation.
    candidates = [(m.start(), int(m.group())) for m in re.finditer(r'\d+', text)]
    candidates.extend(
        (offset, numeral_values[ch])
        for offset, ch in enumerate(text)
        if ch in numeral_values
    )

    if not candidates:
        return None
    # Offsets are unique (a digit run never starts on a Chinese numeral), so the
    # candidate with the largest offset is the last number in the text.
    _, last_value = max(candidates, key=lambda cand: cand[0])
    return last_value

def simple_split_generated_answers(model_answer):
    """Split *model_answer* into sentences ending at "。" or "？".

    Each delimiter stays attached to the sentence it terminates. Whatever follows
    the final delimiter is returned as the last element, even when it is empty,
    so the result always has one more element than there are delimiters.
    """
    sentences = []
    pending = []
    for ch in model_answer:
        pending.append(ch)
        if ch in ("。", "？"):
            sentences.append("".join(pending))
            pending = []
    # Trailing remainder (possibly "") after the last delimiter.
    sentences.append("".join(pending))
    return sentences

def extract_spelling(text, word):
    """Return candidate spellings found in *text*: ASCII-letter runs of length >= 3.

    Steps, in order:
      1. delete every occurrence of *word* itself, then spaces, ",", "、", "-",
         "." and newlines, which glues spelled-out letters such as "s-t-r" together;
      2. delete full-width and ASCII parenthesised spans, then all ASCII digits;
      3. collect runs of ASCII letters and keep those with at least 3 letters.
    """
    for removable in (word, " ", ",", "、", "-", ".", "\n"):
        text = text.replace(removable, "")

    for pattern in (r'（[^）]*）', r'\([^\)]*\)', r'[0-9]'):
        text = re.sub(pattern, '', text)

    return [run for run in re.findall(r'[a-zA-Z]+', text) if len(run) >= 3]


def analyze_last_logits(outputs, tokenizer, direct_answer, correct_answer):
    """Return the next-token probabilities of two candidate answers.

    The logits at the last position of batch row 0 are turned into a probability
    distribution. For each answer, the result is the probability that the next
    token is the first token of ``str(answer)``, encoded without special tokens.

    Only the first token is scored. Summing the probabilities of every token of
    a multi-token answer would double count repeated tokens: "11" encoding to
    ["1", "1"] would give 2 * P("1"), which can exceed 1. It would also mix in
    unrelated mass for answers like "10" (P("1") + P("0")).

    NOTE(review): for a multi-token answer this is an upper bound on the
    probability of generating the full answer next, because the first token may
    be shared with other numbers (e.g. "1" starts both 1 and 12).

    Args:
        outputs: model output exposing ``.logits`` indexed as (batch, position,
            vocab); only row 0 is read (batch size 1 assumed).
        tokenizer: object with ``encode(text, add_special_tokens=False)``.
        direct_answer: the "direct" answer; converted with ``str``.
        correct_answer: the correct answer; converted with ``str``.

    Returns:
        ``(direct_prob, correct_prob)`` as floats. An answer that encodes to
        no tokens gets 0.0.
    """
    # Softmax in float32: the logits may be bf16, which would quantise the
    # resulting probabilities.
    probs = torch.softmax(outputs.logits[0, -1, :].float(), dim=-1)

    def first_token_prob(answer):
        token_ids = tokenizer.encode(str(answer), add_special_tokens=False)
        if not token_ids:
            return 0.0
        return probs[token_ids[0]].item()

    return first_token_prob(direct_answer), first_token_prob(correct_answer)

def mask_input_word(inputs, tokenizer, input_word):
    """Zero the attention mask over every occurrence of *input_word* in row 0.

    *input_word* is tokenized on its own (``tokenize`` + ``convert_tokens_to_ids``).
    Every position where that exact id sequence appears in ``inputs["input_ids"][0]``
    gets attention-mask value 0. Overlapping occurrences are all masked.

    NOTE(review): only exact id-sequence matches are masked. If the word is
    tokenized differently in context (e.g. with a leading space, or merged with
    neighbouring characters), those occurrences are left visible.

    Args:
        inputs: mapping with "input_ids" and "attention_mask" tensors indexed
            as (batch, position); only batch row 0 is inspected and edited.
        tokenizer: object providing ``tokenize`` and ``convert_tokens_to_ids``.
        input_word: text whose occurrences should be hidden.

    Returns:
        *inputs*, with ``inputs["attention_mask"]`` replaced by row 0 re-expanded
        to a batch of one. The original attention-mask tensor is modified in
        place as well, so callers that need the unmasked version must pass a copy.
    """
    input_ids = inputs["input_ids"][0]
    attention_mask = inputs["attention_mask"][0]  # view: edits reach the caller's tensor

    word_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(input_word))
    width = len(word_ids)

    if width:
        # Convert once instead of converting a tensor slice for every window.
        ids = input_ids.tolist()
        for start in range(len(ids) - width + 1):
            if ids[start:start + width] == word_ids:
                attention_mask[start:start + width] = 0

    inputs["attention_mask"] = attention_mask.unsqueeze(0)
    return inputs


# Local checkpoint to probe. model_idx is a short tag for it and is not
# referenced anywhere else in this script.
model_path = "path/to/model/DeepSeek-R1-Distill-Qwen-14B"
model_idx = "qwen14"
print(f"正在加载模型: {model_path} ...")

# Tokenizer and weights come from the same checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side='left')

# bf16 weights, automatic device placement, and no Hub access for the model files.
# Module.eval() returns the module itself, so it can be chained here.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    low_cpu_mem_usage=True,
    local_files_only=True,
).eval()

# Active run: the high-bias case set and its output figure.
# For the low-bias run, swap in the commented-out pair below instead.
input_cases_path = "path/to/LongCoT/CharCount/hiddenStates/greedyAnswers/high_bias.json"
output_image_path = "path/to/LongCoT/CharCount/hiddenStates/high_bias.png"
# input_cases_path = "path/to/LongCoT/CharCount/hiddenStates/greedyAnswers/2-3-3_greedy.json"
# output_image_path = "path/to/LongCoT/CharCount/hiddenStates/low_bias.png"

with open(input_cases_path, encoding='utf-8') as cases_file:
    selected_cases = json.loads(cases_file.read())

# Per-question probability trajectories. "direct"/"correct" are measured on the
# unmodified prompt; the "masked_" variants hide the queried word through the
# attention mask (see mask_input_word).
all_question_trends = {
    "direct": [],
    "correct": [],
    "masked_direct": [],
    "masked_correct": []
}
# Number of consecutive reasoning sentences probed per question. plot_results
# draws the first 10 points, so probing more sentences would be discarded work.
num_probe_sentences = 10

for item in tqdm(selected_cases):
    question = item['question']
    correct_answer = item['correct_answer']
    most_direct_answer = item['most_direct_answer']

    # The queried word is everything before this fixed Chinese question phrase.
    input_word = question.split("这个单词里面有几个字母")[0]
    input_text = f"<｜User｜>{question}<｜Assistant｜><think>"

    splitted_answer = item['splitted_answer']
    spelling_right_idx = item['spelling_right_idx']

    # Probing starts at sentence index spelling_right_idx[0] - 1. Every earlier
    # sentence forms the fixed thinking prefix. Slicing is used instead of an
    # accumulate-until-index loop, which never stops when the index is < 2.
    first = spelling_right_idx[0] - 1 if spelling_right_idx else -1
    if first < 0 or first + num_probe_sentences > len(splitted_answer):
        tqdm.write(
            f"Skipping case (spelling_right_idx={spelling_right_idx}, "
            f"{len(splitted_answer)} sentences): not enough sentences to probe"
        )
        continue
    thinking = "".join(splitted_answer[:first])

    case_trends = {key: [] for key in all_question_trends}
    for sentence in splitted_answer[first:first + num_probe_sentences]:
        thinking += sentence
        prompt = f"{input_text}{thinking}</think>\nThe answer is: "

        encoded = tokenizer(prompt, return_tensors="pt")
        # mask_input_word edits the attention mask in place, so hand it a copy.
        masked = mask_input_word({k: v.clone() for k, v in encoded.items()}, tokenizer, input_word)

        for key_prefix, model_inputs in (("", encoded), ("masked_", masked)):
            device_inputs = {k: v.to(model.device) for k, v in model_inputs.items()}
            with torch.no_grad():
                # Attention weights are never read, so output_attentions is not
                # requested; it would allocate per-layer attention tensors.
                outputs = model(**device_inputs)
            direct_prob, correct_prob = analyze_last_logits(outputs, tokenizer, most_direct_answer, correct_answer)
            # Drop the full logits before the next forward pass allocates its own.
            del outputs
            case_trends[key_prefix + "direct"].append(direct_prob)
            case_trends[key_prefix + "correct"].append(correct_prob)

    for key, trend in case_trends.items():
        all_question_trends[key].append(trend)


def calculate_averages(trends):
    """Average per-question trends element-wise.

    For a list of equal-length numeric sequences (one per question), returns a
    numpy array whose k-th entry is the mean of every sequence's k-th value.
    """
    stacked = np.asarray(trends)
    return stacked.mean(axis=0)

def plot_results(dashed_green, dashed_red, solid_green, solid_red, output_image_path, max_len=10):
    """Plot four per-sentence probability curves and save the figure.

    Dashed lines are labelled "_Original" and solid lines "_MASK". Green is the
    correct answer and red the direct answer. The x axis is the 1-based index
    of the probed sentence.

    Args:
        dashed_green, dashed_red, solid_green, solid_red: sequences of values,
            one per probed sentence.
        output_image_path: destination passed to ``plt.savefig``.
        max_len: maximum number of points drawn. All series are truncated to
            ``min(max_len, length of the shortest series)``, so shorter inputs
            no longer produce mismatched x/y lengths.

    Raises:
        ValueError: if there is no data point to draw.
    """
    series = [dashed_green, dashed_red, solid_green, solid_red]
    n_points = min([max_len] + [len(s) for s in series])
    if n_points < 1:
        raise ValueError("plot_results: no data points to plot")

    dashed_green, dashed_red, solid_green, solid_red = (s[:n_points] for s in series)
    x = list(range(1, n_points + 1))

    fig = plt.figure(figsize=(10, 6))

    plt.plot(x, dashed_green, linestyle='--', color='green', marker='o', markevery=1, label='Correct_Original')
    plt.plot(x, solid_green, linestyle='-', color='green', marker='o', markevery=1, label='Correct_MASK')
    plt.plot(x, dashed_red, linestyle='--', color='red', marker='^', markevery=1, label='Direct_Original')
    plt.plot(x, solid_red, linestyle='-', color='red', marker='^', markevery=1, label='Direct_MASK')

    plt.xlabel('Index of Sentence', fontsize=20)
    plt.ylabel('Probability', fontsize=20)

    plt.legend(loc='center right', fontsize=14)
    # Half-unit padding on both sides so the first and last markers are not clipped.
    plt.xlim(0.5, n_points + 0.5)
    plt.xticks(ticks=x, labels=[str(i) for i in x], fontsize=12)

    plt.grid(True, axis='y', linestyle='--', alpha=0.7)

    plt.tight_layout()

    plt.savefig(output_image_path, dpi=300, bbox_inches='tight')
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print(f"图片已保存到: {output_image_path}")

# Average every condition across questions, in the order direct, correct,
# masked_direct, masked_correct.
avg_direct, avg_correct, avg_masked_direct, avg_masked_correct = (
    calculate_averages(all_question_trends[condition])
    for condition in ("direct", "correct", "masked_direct", "masked_correct")
)

# Green = correct answer, red = direct answer; dashed = original prompt,
# solid = prompt with the queried word masked.
plot_results(
    dashed_green=avg_correct,
    solid_green=avg_masked_correct,
    dashed_red=avg_direct,
    solid_red=avg_masked_direct,
    output_image_path=output_image_path,
)