import json
import matplotlib.pyplot as plt
from collections import defaultdict
import os
import numpy as np
import seaborn as sns
import re
from tqdm import tqdm
from transformers import AutoTokenizer

# Apply seaborn's "whitegrid" theme to the matplotlib figures drawn below.
# ``sns.set`` is only a legacy alias of ``sns.set_theme`` (seaborn >= 0.11).
sns.set_theme(style="whitegrid")

# Single Chinese characters that are read as one-digit numbers.
# Compound numerals such as "十二" are not supported ("十" is not listed).
_CHINESE_DIGITS = {
    "零": 0, "一": 1, "二": 2, "三": 3, "四": 4,
    "五": 5, "六": 6, "七": 7, "八": 8, "九": 9,
    "两": 2, "仨": 3,
}

# A maximal run of decimal digits, or one character from _CHINESE_DIGITS.
# The two alternatives never overlap, so finditer/findall yield numbers
# in the order they appear in the text.
_NUMBER_TOKEN = re.compile(r"\d+|[" + "".join(_CHINESE_DIGITS) + "]")


def _numbers_in_order(text):
    """Return every number mentioned in *text*, in order of appearance.

    A run of decimal digits (e.g. "12") counts as one number. Each character
    in ``_CHINESE_DIGITS`` counts as a single-digit number.
    """
    return [
        _CHINESE_DIGITS[token] if token in _CHINESE_DIGITS else int(token)
        for token in _NUMBER_TOKEN.findall(text)
    ]


def extract_last_digit(text):
    """Return the last number mentioned in *text*, or None if there is none.

    Arabic digit runs and single Chinese numeral characters are both
    recognised (see ``_numbers_in_order``).
    """
    numbers = _numbers_in_order(text)
    return numbers[-1] if numbers else None


def extract_first_digit(text):
    """Return the first number mentioned in *text*, or None if there is none.

    Arabic digit runs and single Chinese numeral characters are both
    recognised (see ``_numbers_in_order``).
    """
    numbers = _numbers_in_order(text)
    return numbers[0] if numbers else None

def first_answer_length(text, tok=None):
    """Return the token length of *text* before the first stated final answer.

    The final answer is taken to be the last number in *text*
    (``extract_last_digit``). Only answers 1-6 are handled. Each one is
    mapped to its Chinese "N个" phrase ("一个", "两个", ..., "六个"). The
    function finds the first occurrence of that phrase that looks like an
    answer rather than a passing mention, i.e. one that is

    * immediately followed by a character for which ``str.isalpha()`` is
      true (this includes CJK characters), or
    * preceded, within the previous 10 characters, by a summary word
      ("总", "共", "所以", "综上", "其实"), or
    * followed, within the next 5 characters, by a question marker
      ("？", "吗").

    Args:
        text: The model's reasoning/answer text.
        tok: Optional object with an ``encode(str)`` method returning a
            sequence. Defaults to the module-level ``tokenizer``.

    Returns:
        ``len(tok.encode(prefix))``, where *prefix* is the text before that
        occurrence. Whether this count includes special tokens depends on the
        tokenizer. Returns None if the final answer is not in 1-6 or no
        occurrence qualifies.
    """
    answer_map = {
        1: "一个",
        2: "两个",
        3: "三个",
        4: "四个",
        5: "五个",
        6: "六个",
    }
    kw = answer_map.get(extract_last_digit(text))
    if kw is None:
        return None

    summary_words = ["总", "共", "所以", "综上", "其实"]
    question_words = ["？", "吗"]
    summary_re = re.compile("|".join(map(re.escape, summary_words)))
    question_re = re.compile("|".join(map(re.escape, question_words)))

    start_index = 0
    while True:
        kw_index = text.find(kw, start_index)
        if kw_index == -1:
            return None
        kw_end = kw_index + len(kw)

        # Phrase directly followed by a letter-like char (e.g. "两个r", "两个字").
        if kw_end < len(text) and text[kw_end].isalpha():
            break
        # Summary word in the preceding 10 characters (slice clamps at 0).
        if summary_re.search(text[max(0, kw_index - 10):kw_index]):
            break
        # Question marker in the following 5 characters (slice clamps at end).
        if question_re.search(text[kw_end:kw_end + 5]):
            break

        start_index = kw_end

    if tok is None:
        # Module-level tokenizer created by the script section of this file.
        tok = tokenizer
    return len(tok.encode(text[:kw_index]))

def _gap_bucket(gap):
    """Map an absolute gap to the bucket label used by the plotting code."""
    if gap <= 0.5:
        return "0 ~ 0.5"
    if gap <= 1:
        return "0.5 ~ 1"
    if gap <= 2:
        return "1 ~ 2"
    return ">2"


def _mean_direct_answer(direct_answers):
    """Return the mean of the first number found in each direct answer.

    Direct answers that contain no number are ignored. Returns None when none
    of them contains a number, so callers never divide by zero.
    """
    values = [extract_first_digit(answer) for answer in direct_answers]
    values = [value for value in values if value is not None]
    if not values:
        return None
    return sum(values) / len(values)


def analyze_answers(folder_path):
    """Bucket samples by how far the direct answers are from the correct answer.

    Every ``*.json`` file in *folder_path* is loaded as a list of items. Each
    item must have ``model_answer``, ``correct_answer``, ``model_answer_length``
    and ``direct_answers``. When ``model_reasoning`` is present it is analysed
    instead of ``model_answer``.

    For each item the gap ``|mean(first number of each direct answer) -
    correct_answer|`` is placed in one of the buckets "0 ~ 0.5", "0.5 ~ 1",
    "1 ~ 2" or ">2". ``correct_answer`` is assumed to be numeric; confirm this
    against the data files.

    Items are skipped when ``first_answer_length`` returns None, or when no
    direct answer contains a number (the gap is undefined in that case).

    The number of samples per bucket is printed.

    Returns:
        A 3-tuple of dicts keyed by bucket label, each mapping to a list of:
        ``model_answer_length`` values; counts of the keywords below in the
        analysed text; and ``first_answer_length`` token counts.
    """
    same_answer_accuracy_stats = defaultdict(list)
    same_answer_keyword_stats = defaultdict(list)
    same_answer_first_answer_length = defaultdict(list)

    # "however", "or", "wait", "but", "that's wrong"
    keywords = ["不过", "或者", "等等", "但是", "不对"]

    # Sorted so that the per-bucket list order is reproducible across runs.
    for file in tqdm(sorted(os.listdir(folder_path))):
        if not file.endswith(".json"):
            continue
        file_path = os.path.join(folder_path, file)
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        for item in data:
            # "model_answer" is required; the reasoning trace wins when present.
            model_answer = item.get("model_reasoning", item["model_answer"])
            correct_answer = item["correct_answer"]
            model_answer_length = item["model_answer_length"]
            direct_answers = item["direct_answers"]

            first_answer_len = first_answer_length(model_answer)
            # 0 is a valid prefix length; only None means "no answer found".
            if first_answer_len is None:
                continue

            mean_direct = _mean_direct_answer(direct_answers)
            if mean_direct is None:
                continue

            key = _gap_bucket(abs(mean_direct - correct_answer))

            same_answer_accuracy_stats[key].append(model_answer_length)
            same_answer_keyword_stats[key].append(
                sum(model_answer.count(keyword) for keyword in keywords)
            )
            same_answer_first_answer_length[key].append(first_answer_len)

    for bucket, lengths in same_answer_accuracy_stats.items():
        print(bucket, len(lengths))
    return (
        same_answer_accuracy_stats,
        same_answer_keyword_stats,
        same_answer_first_answer_length,
    )


def plot_against_same_answer_accuracy(
        same_answer_accuracy_stats,
        same_answer_keyword_stats,
        same_answer_first_answer_length,
        title='CharCount (zh) - QwQ-32B',
        output_path='qwq_zh.png',
    ):
    """Plot per-bucket averages and save the figure.

    The left axis shows, for each gap bucket, two bars: the mean full answer
    length and the mean first-answer length. The right (twin) axis shows the
    mean keyword count as a dashed line. Buckets missing from a stats dict
    are drawn as 0.

    Args:
        same_answer_accuracy_stats: bucket label -> list of full answer lengths.
        same_answer_keyword_stats: bucket label -> list of keyword counts.
        same_answer_first_answer_length: bucket label -> list of first-answer
            lengths.
        title: Figure title.
        output_path: Where the PNG is written (300 dpi).
    """
    # Display order; must match the labels produced by analyze_answers.
    keys = ["0 ~ 0.5", "0.5 ~ 1", "1 ~ 2", ">2"]

    def bucket_means(stats):
        # Missing buckets become 0 so every key still gets a bar/point.
        return [np.mean(stats[key]) if key in stats else 0 for key in keys]

    avg_model_lengths = bucket_means(same_answer_accuracy_stats)
    avg_keyword_counts = bucket_means(same_answer_keyword_stats)
    avg_first_answer_lengths = bucket_means(same_answer_first_answer_length)

    bar_positions = np.arange(len(keys))
    bar_width = 0.3

    fig, ax1 = plt.subplots(figsize=(12, 8))

    bars1 = ax1.bar(bar_positions - bar_width/2,
                    avg_model_lengths,
                    width=bar_width,
                    color='navy',
                    alpha=0.7,
                    label='Avg. Model Answer Length')

    bars2 = ax1.bar(bar_positions + bar_width/2,
                    avg_first_answer_lengths,
                    width=bar_width,
                    color='darkblue',
                    alpha=0.9,
                    label='Avg. First Answer Length')

    ax1.set_xlabel('MAE', fontsize=18, fontweight='bold')
    ax1.set_ylabel('Average Lengths', fontsize=18, color='navy')
    ax1.tick_params(axis='x', labelsize=16)
    ax1.tick_params(axis='y', labelcolor='navy', labelsize=16)
    ax1.set_xticks(bar_positions)
    ax1.set_xticklabels(keys)
    ax1.grid(True, linestyle='--', linewidth=0.5, alpha=0.7)

    # Keyword counts are on a different scale, so they get a twin y-axis.
    ax2 = ax1.twinx()
    keyword_lines = ax2.plot(bar_positions, avg_keyword_counts,
                             marker='^', linestyle='--',
                             color='darkorange',
                             label='Avg. Keyword Count',
                             markersize=8,
                             linewidth=2)
    ax2.set_ylabel('Average Keyword Count', fontsize=18, color='darkorange')
    ax2.tick_params(axis='y', labelcolor='darkorange', labelsize=16)
    ax2.grid(True, linestyle='--', linewidth=0.5, alpha=0.7)

    # One legend for artists from both axes; the explicit labels override
    # the artists' own label= values.
    legend = ax1.legend([bars1, bars2] + list(keyword_lines),
                        ["Full Length", "First Reasoning", "Avg. Keyword Count"],
                        loc='upper left', fontsize=14, frameon=True, shadow=True)
    legend.set_zorder(5)

    plt.title(title, fontsize=22, fontweight='bold', pad=20)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close(fig)

# Script section: load the tokenizer, analyse one results folder, and plot.
# Alternatives are kept as comments; only the uncommented assignment is used.
# NOTE: the plot title/output filename in plot_against_same_answer_accuracy
# refer to QwQ-32B; keep them in sync when switching models.

# model_path = "path/to/model/DeepSeek-R1-Distill-Qwen-14B"
# model_path = "path/to/model/DeepSeek-V3-tokenizer"
model_path = "path/to/model/QwQ-32B"
print(f"正在加载模型: {model_path} ...")  # "Loading model: ..."

# Module-level name read by first_answer_length().
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    padding_side='left'
)

# folder_path = 'path/to/LongCoT/CharCount/results/qwen_zh_results_with_direct'
# folder_path = "path/to/LongCoT/CharCount/results/dpsk_zh_results"
folder_path = 'path/to/LongCoT/CharCount/results/QwQ_zh_results_with_direct'

same_answer_accuracy_stats, same_answer_keyword_stats, same_answer_first_answer_length = analyze_answers(folder_path)

plot_against_same_answer_accuracy(
    same_answer_accuracy_stats,
    same_answer_keyword_stats,
    same_answer_first_answer_length
)