import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import random

# Hex fill colours keyed by single-letter colour codes; the plotting helpers in
# this file translate their short codes ('r', 'b', 'g') into these shades.
color_maps = dict(
    r="#edc1a8",
    b="#a5b8d7",
    g="#a9d3b3",
)

def draw_dist_plot(labels, score_pairs, file_name, legend=False, density=False):
    """Draw overlaid shaded KDE curves, one per label, and save them as a PDF.

    Draws on the current matplotlib axes (``plt.gca()``) and writes
    ``f"{len(labels)}_dist_{file_name}.pdf"``.

    Args:
        labels: Series names; at most four are drawn (one per line style / colour).
        score_pairs: Rows of scores; column ``i`` of every row is the sample for
            ``labels[i]``. It is scanned once per label, so it must be re-iterable
            (e.g. a list), not a one-shot generator.
        file_name: Suffix used in the output file name.
        legend: If true, add a legend in the upper-right corner.
        density: If true, label the y axis "Density"; otherwise leave it blank.
    """
    # NOTE: mutates global rcParams, affecting every figure created afterwards.
    plt.rcParams['figure.dpi'] = 300
    plt.rcParams['savefig.dpi'] = 300

    # Resolve the target axes up front so the styling below works even when
    # `labels` is empty (the original relied on the loop variable and crashed).
    ax = plt.gca()

    line_styles = ['-', '-', '--', '--']
    # The original 4th code 'o' was neither a key of color_maps nor a valid
    # matplotlib single-letter colour, so a 4th series raised KeyError;
    # presumably orange was intended.
    palette = [color_maps['r'], color_maps['b'], color_maps['g'], "tab:orange"]

    for i, (label, line_style, color) in enumerate(zip(labels, line_styles, palette)):
        one_dist = [res[i] for res in score_pairs]
        # NOTE(review): distplot and kde_kws 'shade' are deprecated in recent seaborn
        # (kdeplot with fill=True is the replacement); kept to preserve output.
        sns.distplot(one_dist, hist=False, kde=True, ax=ax,
                     kde_kws={'shade': True,
                              'linewidth': 1,
                              'linestyle': line_style,
                              'color': color},
                     label=label)

    if legend:
        ax.legend(fontsize=15, loc='upper right')

    if density:
        ax.set_ylabel("Density", fontsize=30)
    else:
        ax.set_ylabel("")

    # Pin the tick positions before labelling them: set_xticklabels alone relabels
    # whatever ticks the auto-locator chose, so the printed values did not match
    # their positions.
    xticks = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticks, fontsize=20)
    # Hide y tick labels without installing a FixedFormatter on an auto locator.
    ax.tick_params(axis='y', labelleft=False)

    ax.set_xlim(0, 1.0)
    ax.set_ylim(0, 10)
    ax.figure.savefig(f"{len(labels)}_dist_{file_name}.pdf", format='pdf')
    
def draw_hist_plot(axis, labels, score_pairs, file_name, legend=False, density=False):
    """Draw overlaid score histograms with a dashed median marker per series.

    Args:
        axis: Matplotlib Axes to draw into.
        labels: Legend label per series; at most three series are drawn.
        score_pairs: Sequence of per-series samples; ``score_pairs[i]`` is drawn
            with ``labels[i]``.
        file_name: Unused; kept for call compatibility.
        legend: If true, add a legend in the upper-right corner.
        density: If true, label the y axis "Density"; otherwise leave it blank.
    """
    # Annotation text next to each median line, indexed by series position
    # (independent of `labels`).
    median_labels = ["All", "Consistent", "Conflict"]
    # Height of each median annotation as a fraction of the current y-limit,
    # staggered so the annotations do not overlap (values of the original
    # per-index formulas: 0.4-0.1*1, 0.4-0.1*(1-2), 0.4-0.1*2).
    median_text_heights = [0.3, 0.5, 0.2]
    # Bar fill comes from color_maps; the median line uses the plain
    # matplotlib colour of the same letter.
    colors = ['b', 'r', 'g']
    bins = np.arange(0, 3, 1 / 20)  # 0.05-wide bins; values >= 2.95 fall outside

    for i, (label, color) in enumerate(zip(labels, colors)):
        one_dist = score_pairs[i]
        axis.hist(one_dist, bins=bins, alpha=0.9, label=label,
                  color=color_maps[color], linewidth=1.0, histtype="bar")

        values = np.asarray(one_dist)
        if values.size == 0:
            # np.median of an empty sample is NaN (with a RuntimeWarning);
            # there is no median to mark.
            continue
        median = np.median(values)
        axis.axvline(median, color=color, linestyle='dashed', linewidth=1)
        # Read the y-limit after this series was added so the label scales with it.
        max_ylim = axis.get_ylim()[1]
        axis.text(median * 1.05, max_ylim * median_text_heights[i],
                  f"{median_labels[i]}: {median:.2f}", fontsize=18)

    if legend:
        axis.legend(fontsize=18, loc='upper right')

    if density:
        axis.set_ylabel("Density", fontsize=30)
    else:
        axis.set_ylabel("")

    axis.tick_params(labelsize=15)
    axis.set_xticks(np.arange(0, 3.1, 3 / 10))
    axis.set_yticks([])

import json


def read_file(file_path):
    """Load three score lists stored as one JSON array per line.

    Line 1 fills ``alls``, line 2 ``wo_conflicts``, and line 3 ``w_conflicts``.
    Missing lines leave the corresponding list empty, and lines after the
    third are ignored.

    Returns:
        ``[alls, wo_conflicts, w_conflicts]``.
    """
    alls, wo_conflicts, w_conflicts = [], [], []
    scores = [alls, wo_conflicts, w_conflicts]
    # zip pulls from `scores` first, so iteration stops after three lines
    # without reading the rest of the file.
    with open(file_path, "r", encoding="utf-8") as fin:
        for bucket, line in zip(scores, fin):
            bucket.extend(json.loads(line))
    return scores


# Build one combined figure: one row per dataset, one histogram panel per model,
# and save it to scores.pdf.
dataset_nicknames = ["viquae", "infoseek"]
model_nicknames = ["llava-v1.6-vicuna-7b-hf", "llava-v1.6-vicuna-13b-hf", "llava-v1.6-34b-hf"]
labels = ['All', 'Consistent', 'Conflict']
legend = True
density = False
# Row titles; an unknown nickname is shown as-is rather than as "InfoSeek".
dataset_display_names = {"viquae": "ViQuAE", "infoseek": "InfoSeek"}

n_rows, n_cols = len(dataset_nicknames), len(model_nicknames)
figure, axis = plt.subplots(n_rows, 1, figsize=(18, 9), sharey=True)

# The row-level axes are invisible containers that only carry the dataset title;
# the per-model panels are added on top of them below. atleast_1d keeps the zip
# working when there is a single row (subplots then returns a bare Axes).
for d_nickname, big_ax in zip(dataset_nicknames, np.atleast_1d(axis)):
    big_ax.set_title(f"{dataset_display_names.get(d_nickname, d_nickname)}\n", fontsize=25)
    # tick_params expects booleans; the former 'off' strings are truthy values.
    big_ax.tick_params(labelcolor=(1., 1., 1., 0.0), top=False, bottom=False, left=False, right=False)
    big_ax.set_frame_on(False)  # public API instead of writing the private _frameon
    big_ax.set_xticks([])
    big_ax.set_yticks([])

for row, d_nickname in enumerate(dataset_nicknames):
    for column, m_nickname in enumerate(model_nicknames):
        # add_subplot positions are 1-based and run row-major across the grid.
        panel = figure.add_subplot(n_rows, n_cols, row * n_cols + column + 1)
        file_path = f"outputs/draw/{d_nickname}_{m_nickname}_scores.txt"
        panel.set_title(m_nickname, fontsize=20)
        draw_hist_plot(panel, labels, read_file(file_path), "file_name", legend, density)

figure.tight_layout()
figure.savefig("scores.pdf", format='pdf')