import json
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Input data: each results file is JSON Lines (one JSON object per line).
data1_path = "./POPE/chair/overall_group_ab1_0130/chair_instructblip_results.json"
data2_path = "./POPE/chair/overall_group_ab1_0130/chair_llava_results.json"
data3_path = "./POPE/chair/overall_group_ab1_0130/chair_mplug-owl2_results.json"


def _load_jsonl(path):
    """Return the list of JSON objects stored one per line in *path*.

    Blank / whitespace-only lines (e.g. an extra empty line at the end of the
    file) are skipped instead of making ``json.loads`` raise.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


data1 = _load_jsonl(data1_path)
data2 = _load_jsonl(data2_path)
data3 = _load_jsonl(data3_path)

# Earlier inline sample data, kept for reference only (commented out, not used).
# data1 = {
#     "th": 1, 
#     "cfg": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
#     "Recall": [0.31237661351556567, 0.3342444950645406, 0.3366742596810934, 0.3406226271829917, 0.34441913439635535, 0.34882308276385726, 0.35611237661351564, 0.36795747911921034, 0.37934700075930144, 0.3901290812452543, 0.40592255125284743]
# }

# data2 = {
#     "th": 0.95, 
#     "cfg": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
#     "Recall": [0.3350037965072134, 0.3362186788154897, 0.3412300683371298, 0.34365983295368263, 0.3460895975702354, 0.35186028853454826, 0.3582384206529992, 0.36795747911921034, 0.3743356112376614, 0.37835990888382687]
# }

# Build one DataFrame per model from its JSON records (one row per record).
df1 = pd.DataFrame(data1)
df2 = pd.DataFrame(data2)
df3 = pd.DataFrame(data3)
# Models to plot; df_ls[k] is saved under the name model_name_ls[k].
df_ls = [df3]
model_name_ls = ["mplug_owl2"]
# model_name_ls = ["instructblip", "llava"]
metrics_ls = ["CHAIRs", "CHAIRi", "Recall"]
COLOR_ls = ['#810f7c', '#8c96c6', '#f4a582', '#67001f', '#f7f7f7']
marker_ls = ["o", "^", "D", "v", "x"]
# x-axis categories: row i of each DataFrame is plotted at th_ls[i].
th_ls = ["0.5", "0.7", "0.9", "0.95", "Truth"]
# cfg_ls[j] is read from position cfg_idx_ls[j] of each row's metric list.
# NOTE(review): -6 / -1 select gamma 0.5 / 1.0 only if every list holds the
# 11 values for cfg = 0.0, 0.1, ..., 1.0 -- confirm against the JSON files.
cfg_ls = [0.5, 1.0]
cfg_idx_ls = [-6, -1]
out_dir = "./POPE/chair/overall_group_ab3"

if len(model_name_ls) != len(df_ls):
    raise ValueError(
        f"model_name_ls has {len(model_name_ls)} names but df_ls has {len(df_ls)} DataFrames"
    )

# Global style, set once before any figure/axes exists so every plot
# (spines, tick labels) picks it up.
plt.rcParams.update({'font.size': 14, 'figure.figsize': (4, 4), 'axes.spines.right': False, 'axes.spines.top': False})
# savefig does not create missing directories.
os.makedirs(out_dir, exist_ok=True)

for model_name, df in zip(model_name_ls, df_ls):
    if len(df) < len(th_ls):
        raise ValueError(
            f"{model_name}: expected at least {len(th_ls)} result rows "
            f"(one per entry of th_ls), got {len(df)}"
        )
    for metrics in metrics_ls:
        # For each selected gamma, collect this metric across the threshold rows.
        rows = df[metrics].iloc[:len(th_ls)]
        values_by_cfg = {
            cfg: [row_vals[cfg_idx] for row_vals in rows]
            for cfg, cfg_idx in zip(cfg_ls, cfg_idx_ls)
        }

        fig = plt.figure(figsize=(4, 4))
        plt.grid(True)
        for j, cfg in enumerate(cfg_ls):
            df_plot = pd.DataFrame({"th": th_ls, f"{cfg}": values_by_cfg[cfg]})
            # Only the Recall panel carries line labels (and therefore the legend).
            label_kw = {"label": r'$\gamma$' + f"={cfg}"} if metrics == "Recall" else {}
            sns.lineplot(
                data=df_plot, x="th", y=f"{cfg}", marker=marker_ls[j],
                linewidth=3, markersize=10, color=COLOR_ls[j], **label_kw,
            )

        if metrics == "Recall":
            plt.legend(bbox_to_anchor=(1.05, 1), loc='upper right', borderaxespad=0.)
        plt.xticks(fontsize=12, rotation=45)
        plt.yticks(fontsize=12)
        # Axis names are intentionally left blank.
        plt.xlabel('')
        plt.ylabel('')
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])

        out_path = f"{out_dir}/ab2_{model_name}_{metrics}.pdf"
        fig.savefig(out_path)
        print(f"Saved {out_path}")
        # Release the figure; without this every iteration leaves one open.
        plt.close(fig)