import json
import numpy as np
from matplotlib import pyplot as plt, font_manager
from scipy.stats import entropy
import seaborn as sns
def js_divergence(p, q):
    """Return the Jensen-Shannon divergence (in bits) between two weight vectors.

    Both inputs are normalised to sum to 1 before the divergence is computed,
    so unnormalised non-negative weights are accepted.

    Args:
        p: 1-D sequence of non-negative weights.
        q: 1-D sequence of non-negative weights with the same length as ``p``.

    Returns:
        ``0.5 * KL(p || m) + 0.5 * KL(q || m)`` with ``m = (p + q) / 2`` and
        logarithms taken in base 2. The value is 0 for identical distributions
        and 1 for distributions with disjoint support.

    Raises:
        ValueError: if the inputs differ in shape, or if either input does not
            have a strictly positive sum (it cannot be normalised).
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    if p.shape != q.shape:
        # Without this check NumPy could silently broadcast mismatched inputs.
        raise ValueError(
            f"p and q must have the same shape, got {p.shape} and {q.shape}")

    p_total = p.sum()
    q_total = q.sum()
    # Written as `not (x > 0)` so that NaN totals are rejected as well.
    if not (p_total > 0 and q_total > 0):
        raise ValueError("p and q must each have a strictly positive sum")

    p = p / p_total
    q = q / q_total
    m = 0.5 * (p + q)

    # No smoothing term is needed: m is strictly positive wherever p or q is,
    # and scipy treats 0 * log(0 / 0) as 0. Adding an epsilon to m would bias
    # the result (e.g. JS(p, p) would no longer be exactly 0).
    kl_pm = entropy(p, m, base=2)
    kl_qm = entropy(q, m, base=2)

    return 0.5 * (kl_pm + kl_qm)

# Every ordering of the indices 0, 1, 2 (in lexicographic order), with index 3
# always kept in the last position:
#   [0,1,2,3], [0,2,1,3], [1,0,2,3], [1,2,0,3], [2,0,1,3], [2,1,0,3]
permutation_list = [
    [first, second, third, 3]
    for first in range(3)
    for second in range(3)
    for third in range(3)
    if len({first, second, third}) == 3
]
# Row k is the inverse of permutation_list[k], spelled with option letters:
# entry i is the letter of the position at which index i sits in
# permutation_list[k]. The resulting rows are:
#   ABCD, ACBD, BACD, CABD, BCAD, CBAD
permutation_list2 = [
    ['ABCD'[order.index(slot)] for slot in range(4)]
    for order in permutation_list
]

def _chose_distractor(choice, letters):
    """Return True if *choice* is the letter at index 1 or 2 of *letters*."""
    return choice == letters[1] or choice == letters[2]


def _fraction(hits, n_records, n_slots):
    """Return hits / (n_records * n_slots), or 0 when there are no records."""
    denominator = n_records * n_slots
    return hits / denominator if denominator else 0


def _mean_or_nan(values):
    """Return the mean of *values*; NaN (without a NumPy warning) if empty."""
    return np.mean(values) if values else np.nan


def compare(
        original,
        counterfactual
):
    """Compare an original run against a counterfactual run, slot by slot.

    Both arguments are paths to UTF-8 JSON files containing a list of
    records. The fields read from each record are ``'corrects'``,
    ``'choices'`` and ``'options_probs'`` (each indexed by permutation slot,
    one slot per row of ``permutation_list2``); counterfactual records are
    additionally expected to carry ``'scores'``. Records of the two files are
    paired positionally.

    Every (record pair, slot) is assigned to exactly one bucket:

    * ``tp`` - the original choice is the letter at index 1 or 2 of
      ``permutation_list2[slot]`` and the counterfactual answer is correct;
    * ``tn`` - same condition on the original, counterfactual not correct;
    * ``fp`` - the original choice is any other letter, counterfactual
      correct;
    * ``fn`` - the original choice is any other letter, counterfactual not
      correct.

    Index 0 of a permutation row is the letter used for
    ``'ground_truth_delta'``; indices 1 and 2 are presumably the distractor
    options - TODO confirm against the data generator.

    Args:
        original: path of the JSON file produced by the original run.
        counterfactual: path of the JSON file produced by the counterfactual
            run.

    Returns:
        A dict with the bucket dicts ``'all'``, ``'tp'``, ``'tn'``, ``'fn'``
        and ``'fp'`` plus ``'acc'`` and ``'hallu'`` (two-element lists,
        ``[original, counterfactual]``). Each bucket holds the lists
        ``'js_divergence'`` and ``'ground_truth_delta'`` (original minus
        counterfactual probability), the sample count ``'len'`` and the
        corresponding ``*_mean`` values; ``'tn'`` additionally holds
        ``'b_delta'``, ``'c_delta'``, ``'d_delta'`` and their means. Means of
        empty buckets are NaN.
    """
    statistics = {
        'all': {'js_divergence': [], 'ground_truth_delta': [], 'len': 0},
        'tp': {'js_divergence': [], 'ground_truth_delta': [], 'len': 0},
        'tn': {'js_divergence': [], 'ground_truth_delta': [], 'len': 0,
               'b_delta': [], 'c_delta': [], 'd_delta': []},
        'fn': {'js_divergence': [], 'ground_truth_delta': [], 'len': 0},
        'fp': {'js_divergence': [], 'ground_truth_delta': [], 'len': 0},
        'acc': [0, 0],
        'hallu': [0, 0],
    }
    n_slots = len(permutation_list2)

    with open(original, 'r', encoding='utf-8') as f:
        original_data = json.load(f)
    with open(counterfactual, 'r', encoding='utf-8') as f:
        counterfactual_data = json.load(f)

    # Count first and divide once, so rounding error does not build up over
    # thousands of tiny float additions.
    statistics['acc'][0] = _fraction(
        sum(sum(rec['corrects']) for rec in original_data),
        len(original_data), n_slots)
    statistics['hallu'][0] = _fraction(
        sum(
            1
            for rec in original_data
            for k, letters in enumerate(permutation_list2)
            if _chose_distractor(rec['choices'][k], letters)
        ),
        len(original_data), n_slots)
    statistics['acc'][1] = _fraction(
        sum(sum(rec['corrects']) for rec in counterfactual_data),
        len(counterfactual_data), n_slots)
    # NOTE(review): the counterfactual rate is derived from scores == -1 while
    # the original rate is derived from the chosen letter - confirm that the
    # two definitions are meant to differ.
    statistics['hallu'][1] = _fraction(
        sum(1 for rec in counterfactual_data
            for score in rec['scores'] if score == -1),
        len(counterfactual_data), n_slots)

    # NOTE(review): zip() pairs records by position and silently stops at the
    # shorter file - assumes both files list the same items in the same order.
    for orig_rec, cf_rec in zip(original_data, counterfactual_data):
        for k, letters in enumerate(permutation_list2):
            orig_probs = orig_rec['options_probs'][k]
            cf_probs = cf_rec['options_probs'][k]

            picked_distractor = _chose_distractor(orig_rec['choices'][k], letters)
            cf_correct = bool(cf_rec['corrects'][k])
            if picked_distractor:
                name = 'tp' if cf_correct else 'tn'
            else:
                name = 'fp' if cf_correct else 'fn'
            bucket = statistics[name]

            # Align both probability vectors on the same option keys instead
            # of trusting that the two dicts share an insertion order.
            option_keys = list(orig_probs)
            bucket['js_divergence'].append(js_divergence(
                [orig_probs[key] for key in option_keys],
                [cf_probs[key] for key in option_keys]))
            bucket['ground_truth_delta'].append(
                orig_probs[letters[0]] - cf_probs[letters[0]])
            bucket['len'] += 1

            if name == 'tn':
                # Probability shift of each remaining option of the row.
                for delta_key, letter in zip(
                        ('b_delta', 'c_delta', 'd_delta'), letters[1:4]):
                    bucket[delta_key].append(
                        orig_probs[letter] - cf_probs[letter])

    # 'all' is the concatenation of the four buckets, in this fixed order.
    merged = statistics['all']
    for name in ('tp', 'fp', 'tn', 'fn'):
        merged['js_divergence'] += statistics[name]['js_divergence']
        merged['ground_truth_delta'] += statistics[name]['ground_truth_delta']
        merged['len'] += statistics[name]['len']

    tn = statistics['tn']
    for delta_key in ('b_delta', 'c_delta', 'd_delta'):
        tn[delta_key + '_mean'] = _mean_or_nan(tn[delta_key])

    print('COUNT:', tn['len'])
    for value in statistics.values():
        if isinstance(value, dict):
            value['js_divergence_mean'] = _mean_or_nan(value['js_divergence'])
            value['ground_truth_delta_mean'] = _mean_or_nan(
                value['ground_truth_delta'])
    return statistics

def fix(file):
    """Rewrite a JSON result file next to the input, with ``_`` prefixed to its name.

    The file must contain a list of dicts that each have an
    ``'options_probs'`` key. That key is popped and re-inserted, which moves
    it to the end of every record; all values are left untouched.
    NOTE(review): this looks like a leftover from a key rename - confirm
    whether the function is still needed.

    Args:
        file: path (``str`` or ``os.PathLike``) of the JSON file to read.
            The result is written to the same directory as
            ``"_" + <file name>``.

    Raises:
        KeyError: if a record has no ``'options_probs'`` key.
    """
    import os

    with open(file, 'r', encoding='utf-8') as f:
        data = json.load(f)
    for record in data:
        record['options_probs'] = record.pop('options_probs')

    # Prefix only the file name: prefixing the whole path would turn
    # "dir/x.json" into the non-existent "_dir/x.json".
    directory, basename = os.path.split(os.fspath(file))
    with open(os.path.join(directory, "_" + basename), 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=2)


def _move_label_outside(ax, label, angle_deg, distance=1.3):
    """Move a pie label out to *distance* along *angle_deg* and draw a leader line to it."""
    angle_rad = np.deg2rad(angle_deg)
    new_x = distance * np.cos(angle_rad)
    new_y = distance * np.sin(angle_rad)
    label.set_position((new_x, new_y))

    # Leader line from the point at radius 1 (matplotlib's default pie
    # radius) out to the relocated label.
    ax.annotate(
        "",
        xy=(np.cos(angle_rad), np.sin(angle_rad)),
        xytext=(new_x, new_y),
        arrowprops=dict(
            arrowstyle="-",
            color="gray",
            alpha=0.7,
            connectionstyle="arc3,rad=0.1"
        )
    )
    # Labels on the left half of the pie grow leftwards, the others rightwards.
    if 90 < angle_deg < 270:
        label.set_horizontalalignment('right')
    else:
        label.set_horizontalalignment('left')


def format_output(model_name, statistics, ax):
    """Print the summary metrics of one model and draw its bucket-share pie on *ax*.

    Args:
        model_name: text printed as a header and used as the subplot title.
        statistics: dict in the layout returned by ``compare``. Every
            dict-valued entry other than ``'all'`` becomes one pie wedge whose
            size is its ``'len'`` divided by ``statistics['all']['len']``.
        ax: matplotlib Axes to draw the pie chart on.

    Side effects:
        Prints to stdout and sets the global seaborn style to "whitegrid".
        When ``statistics['all']['len']`` is 0 no pie is drawn (there is
        nothing to show); only the title is set.
    """
    print(model_name)
    labels = []
    sizes = []
    colors = ['#f37576', '#fe9550', '#c697e7', '#8eb8f2']
    sns.set_style("whitegrid")

    total = statistics['all']['len']
    for key, value in statistics.items():
        if isinstance(value, dict) and key != 'all':
            labels.append(key)
            # Guard against an empty run instead of raising ZeroDivisionError.
            sizes.append(value['len'] / total if total else 0.0)
            print(f"{key:<10}: {value['js_divergence_mean']:>10.2%} {value['ground_truth_delta_mean']:>10.2%}")

    if 'tn' in statistics:
        tn = statistics['tn']
        print(
            f"TN: {tn['ground_truth_delta_mean']} | {tn['b_delta_mean']} | {tn['c_delta_mean']} | {tn['d_delta_mean']}")

    if 'acc' in statistics:
        print(f"Accuracy: {statistics['acc'][0]} | {statistics['acc'][1]}")

    if 'hallu' in statistics:
        print(f"Hallucination: {statistics['hallu'][0]} | {statistics['hallu'][1]}")

    print("\n")
    research_font = {
        'family': 'Times New Roman',
        'size': 18,
        'weight': 'bold'
    }

    if not total:
        # A pie of all-zero shares cannot be drawn; keep the panel titled.
        ax.set_title(f"{model_name}", fontproperties=research_font, fontsize=14)
        ax.axis('equal')
        return

    combined_labels = [f"{label}\n{size * 100:.1f}%" for label, size in zip(labels, sizes)]
    wedges, texts = ax.pie(
        sizes,
        labels=combined_labels,
        colors=colors,
        startangle=90,
        textprops={
            'fontproperties': research_font,
            'color': 'black'
        },
        wedgeprops={'linewidth': 1, 'edgecolor': 'white'}
    )

    # Wedges below this share are too narrow for an adjacent label.
    threshold = 0.05
    for size, wedge, text in zip(sizes, wedges, texts):
        if size < threshold:
            # Angle (in degrees) of the wedge's centre line.
            _move_label_outside(ax, text, (wedge.theta2 + wedge.theta1) / 2)

    for size, wedge in zip(sizes, wedges):
        if size >= threshold:
            r = wedge.r * 0.7
            angle_rad = np.deg2rad((wedge.theta2 + wedge.theta1) / 2)
            # NOTE(review): fontsize=0 makes this in-wedge percentage
            # effectively invisible - presumably intentional, because the
            # outer label already shows the percentage. Confirm, then either
            # drop this text or give it a real size.
            ax.text(
                r * np.cos(angle_rad), r * np.sin(angle_rad),
                f"{size * 100:.1f}%",
                ha='center', va='center',
                fontproperties=research_font,
                color='white',
                fontsize=0,
                fontweight='bold'
            )

    ax.set_title(f"{model_name}", fontproperties=research_font, fontsize=14)

    ax.axis('equal')

if __name__ == '__main__':
    # One entry per subplot, left to right: (panel title, file passed to
    # compare() as `original`, file passed as `counterfactual`).
    runs = (
        ('InternVL3-14B:',
         "EXP2_internvl3-14b_False_shuffle_20250830_062048.json",
         "EXP2_internvl3-14b_True_shuffle_20250830_060222.json"),
        ('Qwen2.5-VL-7B:',
         "EXP2_qwen2.5-vl-7b_False_shuffle_20250830_050022.json",
         "EXP2_qwen2.5-vl-7b_True_shuffle_20250830_045101.json"),
        ('InternVL3-38B:',
         "EXP2_internvl3-38b_False_shuffle_20250830_071101.json",
         "EXP2_internvl3-38b_True_shuffle_20250830_064329.json"),
        ('Qwen2.5-VL-32B:',
         "EXP2_qwen2.5-vl-32b_False_shuffle_20250830_080812.json",
         "EXP2_qwen2.5-vl-32b_True_shuffle_20250830_074202.json"),
    )

    fig, axes = plt.subplots(1, 4, figsize=(14, 4))
    axes = axes.flatten()
    # Compute and draw one model at a time, each on its own panel.
    for panel, (title, original_file, counterfactual_file) in zip(axes, runs):
        statistics = compare(original_file, counterfactual_file)
        format_output(title, statistics, panel)

    plt.tight_layout()
    plt.show()