import json
import matplotlib.pyplot as plt
from collections import defaultdict
import os
import numpy as np
import seaborn as sns 
import re
from transformers import AutoTokenizer

sns.set(style="whitegrid")

def extract_abcd_options(s):
    """Return only the option letters found in ``s``.

    Every element of ``s`` that is exactly 'A', 'B', 'C' or 'D' is kept,
    in its original order and with duplicates preserved; everything else
    is dropped. The result is a single concatenated string.
    """
    option_letters = frozenset("ABCD")
    return "".join(filter(option_letters.__contains__, s))

def extract_answer(text):
    """Extract the option letters from the first 【...】 span in ``text``.

    Returns ``None`` when ``text`` contains no 【...】 span. Otherwise
    returns a string holding each of 'A', 'B', 'C', 'D' that occurs inside
    the span, at most once each and always in alphabetical order (so the
    result may be the empty string).
    """
    bracketed = re.search(r'【([^】]*)】', text)
    if bracketed is None:
        return None

    inner = bracketed.group(1)
    return "".join(letter for letter in "ABCD" if letter in inner)

def first_answer_length(text):
    """Return the leading part of ``text`` up to its first tentative conclusion.

    ``text`` is split into paragraphs on blank lines ("\\n\\n"). A paragraph
    counts as a first conclusion when it contains a conclusion word
    ("所以" or "因此") together with an answer word ("答案" or "选项"), and
    the paragraph right after it contains a hesitation word ("或者", "不过",
    "但是" or "可是"). The paragraphs up to and including that conclusion
    are returned, re-joined with "\\n\\n".

    Despite the function's name, the value returned is text, not a length.
    An empty string is returned when no such paragraph pair exists.
    """
    conclusion_words = ("所以", "因此")
    answer_words = ("答案", "选项")
    hesitation_words = ("或者", "不过", "但是", "可是")

    paragraphs = text.split("\n\n")
    # Pair each paragraph with its successor; the last paragraph has no
    # successor and therefore can never qualify.
    for idx, (current, upcoming) in enumerate(zip(paragraphs, paragraphs[1:])):
        states_conclusion = any(word in current for word in conclusion_words)
        names_answer = any(word in current for word in answer_words)
        if not (states_conclusion and names_answer):
            continue
        if any(word in upcoming for word in hesitation_words):
            return "\n\n".join(paragraphs[:idx + 1])

    return ""

def analyze_answers(folder_path):
    """Collect per-item statistics from result files, keyed by inconsistency rate.

    Every file in ``folder_path`` whose name ends with ".json" is loaded and
    iterated as a sequence of items. The following keys are read from each
    item: 'direct_answers', 'answer' (optional; falls back to
    'correct_answer'), 'model_answer', 'reasoning' (optional; falls back to
    'model_answer'), and 'model_answer_length'.

    For each item the inconsistency rate is the fraction of direct answers
    whose extracted option letters differ from the option letters extracted
    from the 【...】 span of 'model_answer'.

    Relies on the module-level ``tokenizer`` to measure the first-answer
    segment in tokens.

    Args:
        folder_path: Directory containing the result files.

    Returns:
        A 3-tuple of ``defaultdict(list)`` objects, each keyed by
        inconsistency rate:
        1. the 'model_answer_length' values,
        2. the total number of keyword occurrences in the reasoning text,
        3. the token length of the first-answer segment (or of the whole
           reasoning text when no such segment is detected).

    Items whose 'direct_answers' is empty are skipped, because neither a
    majority answer nor an inconsistency rate can be computed for them.
    """
    same_answer_error_stats = defaultdict(list)
    same_answer_keyword_stats = defaultdict(list)
    same_answer_first_answer_length = defaultdict(list)

    # Substrings counted anywhere in the reasoning text.
    # A stricter alternative is to count only paragraph-initial markers, e.g.
    # "\n\n不过", "\n\n或者", "\n\n等等", "\n\n但是", "\n\n等一下", "\n\n不", "\n\n不对".
    keywords = [
        "不过",
        "等等",
        "但是",
        "不对",
    ]

    # Majority-vote tally of the direct answers against the reference answer.
    # Only consumed by the optional diagnostic print near the end.
    direct_corr_cnt = direct_incorr_cnt = 0

    # os.listdir returns entries in arbitrary order; sort for reproducible output.
    for file in sorted(os.listdir(folder_path)):
        if not file.endswith(".json"):
            continue
        print(file)
        with open(os.path.join(folder_path, file), "r", encoding='utf-8') as f:
            data = json.load(f)

        for item in data:
            directs = item['direct_answers']
            if not directs:
                # Without direct answers max() below would raise and the
                # rate would be a division by zero.
                print(f"skipping item without direct answers in {file}")
                continue

            correct = item.get("answer") or item['correct_answer']
            correct_answer = "".join(correct)

            # Prefer the separate reasoning trace when present; the short
            # answer is always taken from 'model_answer' itself.
            model_answer = item.get("reasoning", item['model_answer'])
            model_answer_length = item['model_answer_length']
            model_short_answer = extract_answer(item['model_answer'])

            # Only the text before the first closing bracket of each direct
            # answer is scanned for option letters.
            direct_short_answers = [
                extract_abcd_options(da.split("】")[0]) for da in directs
            ]
            ttl_num = len(direct_short_answers)
            correct_num = sum(
                ans == model_short_answer for ans in direct_short_answers
            )

            answers = defaultdict(int)
            for ans in direct_short_answers:
                answers[ans] += 1
            most = max(answers, key=answers.get)
            if most == correct_answer:
                direct_corr_cnt += 1
            else:
                direct_incorr_cnt += 1

            first_answer = first_answer_length(model_answer)
            if first_answer == "":
                # No early conclusion detected: treat the whole text as the first answer.
                first_answer = model_answer
            first_answer_len = len(tokenizer.encode(first_answer, add_special_tokens=False))

            # Compute the inconsistency (error) rate.
            same_answer_error_rate = (ttl_num - correct_num) / ttl_num

            # Store the per-item measurements under that rate.
            same_answer_error_stats[same_answer_error_rate].append(model_answer_length)
            same_answer_keyword_stats[same_answer_error_rate].append(
                sum(model_answer.count(keyword) for keyword in keywords)
            )
            same_answer_first_answer_length[same_answer_error_rate].append(first_answer_len)

    # Optional diagnostic:
    # print("Direct Answers Accuracy:", direct_corr_cnt / (direct_corr_cnt + direct_incorr_cnt))
    for k, v in same_answer_error_stats.items():
        print(k, sum(v)/len(v), len(v))
    return same_answer_error_stats, same_answer_keyword_stats, same_answer_first_answer_length


def group_into_bins(stats, num_bins=10):
    """Pool the values of ``stats`` into equal-width bins over [0, 1].

    ``stats`` maps a rate to a collection of values. The interval [0, 1] is
    cut into ``num_bins`` equal bins; each bin is half-open ``[left, right)``
    except that a rate of exactly 1.0 is also placed in the final bin.
    Rates outside [0, 1] are ignored.

    Returns a list with one ``(bin_center, values)`` tuple per bin, in
    ascending bin order. ``values`` is a list (possibly empty) holding the
    pooled values in the iteration order of ``stats``.
    """
    # Snapshot the input once so every bin sees the same materialised lists.
    samples_by_rate = [(rate, list(samples)) for rate, samples in stats.items()]

    edges = np.linspace(0, 1, num_bins + 1)
    lower_edges, upper_edges = edges[:-1], edges[1:]
    centers = (lower_edges + upper_edges) / 2

    binned = []
    for center, lower, upper in zip(centers, lower_edges, upper_edges):
        closes_at_one = upper == 1.0
        bucket = []
        for rate, samples in samples_by_rate:
            if lower <= rate < upper or (closes_at_one and rate == 1.0):
                bucket.extend(samples)
        binned.append((center, bucket))

    return binned


def plot_against_direct_answer_error_rate(
        same_answer_error_stats,
        same_answer_keyword_stats,
        same_answer_first_answer_length,
        num_bins=5,
        title='Knowlogic - QwQ-32B',
        output_path='Knowlogic_qwq.png',
    ):
    """Draw binned averages against inconsistency rate and save the figure.

    Each of the three ``stats`` mappings (rate -> list of values) is binned
    with ``group_into_bins``. Bins that contain no model-answer lengths are
    left out of the chart. For the remaining bins the figure shows two bar
    series on the left axis (average full length and average first-answer
    length) and one line series on a secondary right axis (average keyword
    count).

    Args:
        same_answer_error_stats: rate -> model answer lengths.
        same_answer_keyword_stats: rate -> keyword counts.
        same_answer_first_answer_length: rate -> first-answer lengths.
        num_bins: Number of equal-width bins over [0, 1].
        title: Figure title.
        output_path: File the figure is written to.

    Returns:
        None. The figure is saved to ``output_path`` and then closed.
    """
    binned_lengths = group_into_bins(same_answer_error_stats, num_bins)
    binned_keywords = group_into_bins(same_answer_keyword_stats, num_bins)
    binned_first_answers = group_into_bins(same_answer_first_answer_length, num_bins)

    # The populated bins of the length series decide which bins are plotted.
    valid_indices = [i for i, (_, vals) in enumerate(binned_lengths) if vals]

    def _bin_means(binned):
        # Use the same bin indices for every series so that bars, line and
        # tick labels always stay aligned; an empty bin contributes 0.
        return [
            np.mean(binned[i][1]) if binned[i][1] else 0
            for i in valid_indices
        ]

    avg_model_lengths = _bin_means(binned_lengths)
    avg_keyword_counts = _bin_means(binned_keywords)
    avg_first_answer_lengths = _bin_means(binned_first_answers)

    bin_edges = np.linspace(0, 1, num_bins + 1)
    bin_labels = [
        f"{int(bin_edges[i]*100)}~{int(bin_edges[i+1]*100)}%"
        for i in valid_indices
    ]

    bar_positions = np.arange(len(avg_model_lengths))
    bar_width = 0.3

    fig, ax1 = plt.subplots(figsize=(12, 8))

    # Two bar series sit side by side around each tick position.
    bars1 = ax1.bar(bar_positions - bar_width/2,
                    avg_model_lengths,
                    width=bar_width,
                    color='navy',
                    alpha=0.7,
                    label='Avg. Model Answer Length')

    bars2 = ax1.bar(bar_positions + bar_width/2,
                    avg_first_answer_lengths,
                    width=bar_width,
                    color='darkblue',
                    alpha=0.9,
                    label='Avg. First Answer Length')

    ax1.set_xlabel('Inconsistency Rate', fontsize=18, fontweight='bold')
    ax1.set_ylabel('Average Lengths', fontsize=18, color='navy')
    ax1.tick_params(axis='x', labelsize=16)
    ax1.tick_params(axis='y', labelcolor='navy', labelsize=16)
    ax1.set_xticks(bar_positions)
    ax1.set_xticklabels(bin_labels)
    ax1.grid(True, linestyle='--', linewidth=0.5, alpha=0.7)

    # Keyword counts live on a different scale, so they get a second y-axis.
    ax2 = ax1.twinx()
    line1 = ax2.plot(bar_positions, avg_keyword_counts,
                     marker='^', linestyle='--',
                     color='darkorange',
                     label='Avg. Keyword Count',
                     markersize=8,
                     linewidth=2)
    ax2.set_ylabel('Average Keyword Count', fontsize=18, color='darkorange')
    ax2.tick_params(axis='y', labelcolor='darkorange', labelsize=16)
    ax2.grid(True, linestyle='--', linewidth=0.5, alpha=0.7)

    # One combined legend for the artists of both axes, with explicit labels.
    legend = ax1.legend([bars1, bars2] + list(line1),
               ["Full Length", "First Reasoning", "Avg. Keyword Count"],
               loc='upper left', fontsize=14, frameon=True, shadow=True)
    legend.set_zorder(100)

    plt.title(title, fontsize=22, fontweight='bold', pad=20)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close()


# Main program entry point.
# These statements run at import time: they load a tokenizer, analyse the
# result folder and write the plot.

# The last uncommented assignment wins; the earlier one is overwritten and
# only serves as a quick way to switch models.
model_path = "path/to/model/DeepSeek-R1-Distill-Qwen-14B"
model_path = "path/to/model/QwQ-32B"
# model_path = "path/to/model/DeepSeek-V3-tokenizer"
print(f"正在加载模型: {model_path} ...")
# `tokenizer` must remain a module-level name: analyze_answers() reads it as a global.
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side='left')

# Same pattern as model_path: only the last uncommented assignment is used.
folder_path = "path/to/LongCoT/Knowlogic/direct_results"
folder_path = "path/to/LongCoT/Knowlogic/qwq_results"
# folder_path = "path/to/LongCoT/Knowlogic/dpsk_results"
# Passed below as the `num_bins` argument of the plotting function.
threshold = 4

same_answer_error_stats, same_answer_keyword_stats, same_answer_first_answer_length = analyze_answers(folder_path)
plot_against_direct_answer_error_rate(same_answer_error_stats, same_answer_keyword_stats, same_answer_first_answer_length, threshold)