
import json
import os
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import glob
import argparse


def scatter_plot(save_dir, result_file_name, metrics):
    """Scatter two overall metrics of every evaluation record against each other.

    Args:
        save_dir: Directory holding the result file; the figure is written here too.
        result_file_name: JSON-lines file, one evaluation record per line. Each
            record must provide ``overall_metrics`` and ``answer_file``.
        metrics: Sequence of metric names; element 0 goes on the x axis and
            element 1 on the y axis.

    The image is always written to ``scatter_CHAIRs_CHAIRi.png`` inside
    ``save_dir``, whatever metric names are passed.
    """
    print("== Visualizing the chair result using scatter plot===")

    records = []
    with open(os.path.join(save_dir, result_file_name), 'r') as handle:
        for raw_line in handle:
            records.append(json.loads(raw_line))

    x_values = [record['overall_metrics'][metrics[0]] for record in records]
    y_values = [record['overall_metrics'][metrics[1]] for record in records]
    answer_files = [record['answer_file'] for record in records]

    plt.figure(figsize=(20, 10))
    sns.set_theme(style="whitegrid")

    # One point per record, coloured by the answer file it came from.
    points = pd.DataFrame(
        {metrics[0]: x_values, metrics[1]: y_values, 'label': answer_files})
    sns.scatterplot(data=points, x=metrics[0], y=metrics[1], hue='label')

    out_path = os.path.join(save_dir, 'scatter_CHAIRs_CHAIRi.png')
    plt.savefig(out_path)
    print("Scatter plot result saved in ", out_path)


def load_eval_jsonl(save_dir, result_file_name, eval_metrics):
    """Load one run's JSON-lines evaluation file into a per-cfg DataFrame.

    Only records whose ``answer_file`` contains ``'7b'`` or ``'mm'`` are kept.
    The guidance strength ``cfg`` is parsed from the text between the last
    ``'-cfg'`` and the following ``'.json'`` in ``answer_file``.

    Args:
        save_dir: Directory containing the result file. Its last path component
            is split on ``'_'`` to derive the ``label`` (second piece) and
            ``name`` (last piece) columns.
        result_file_name: Name of the JSON-lines file inside ``save_dir``.
        eval_metrics: Metric family understood by ``get_metrics``.

    Returns:
        DataFrame with columns ``cfg``, ``label``, one column per metric, and
        ``name``.

    Raises:
        NotImplementedError: If ``eval_metrics`` is not a known family.
    """
    metrics = get_metrics(eval_metrics)

    with open(os.path.join(save_dir, result_file_name), 'r') as f:
        # Skip blank lines (e.g. a trailing newline) instead of failing to parse them.
        results = [json.loads(line) for line in f if line.strip()]
    print(len(results))

    rows = []
    for result in results:
        answer_file = result['answer_file']
        # TODO: remove mm
        if '7b' not in answer_file and 'mm' not in answer_file:
            continue
        row = {
            'cfg': float(answer_file.split('-cfg')[-1].split('.json')[0]),
            'label': answer_file,
        }
        for metric in metrics:
            row[metric] = result['overall_metrics'][metric]
        rows.append(row)

    df_7b = pd.DataFrame(rows, columns=['cfg', 'label'] + metrics)

    # A trailing '/' would otherwise produce an empty directory name.
    dir_name = save_dir.rstrip('/').split('/')[-1]
    parts = dir_name.split('_')
    # Directory names without '_' (e.g. "eval") have no second piece; fall
    # back to the whole name rather than raising IndexError.
    label_tag = parts[1] if len(parts) > 1 else dir_name
    df_7b['label'] = '7b-' + label_tag
    df_7b['name'] = '7b-' + parts[-1]

    return df_7b


def line_plot(save_dir, df, eval_metrics):
    """Plot every metric of one run against ``cfg`` and save the figure.

    One subplot is drawn per metric. If the frame contains a ``cfg == 0`` row,
    its value is drawn as a dashed grey horizontal reference line. Every point
    is annotated with its value rounded to three decimals.

    Args:
        save_dir: Directory the figure ``{eval_metrics}-cfg.png`` is written to.
        df: DataFrame with ``cfg``, ``name`` and one column per metric.
        eval_metrics: Metric family understood by ``get_metrics``.

    Returns:
        The input ``df``, unchanged.
    """
    metrics = get_metrics(eval_metrics)

    fig = plt.figure(figsize=(6 * len(metrics), 6))
    sns.set_theme(style="whitegrid")

    for position, metric in enumerate(metrics, start=1):
        plt.subplot(1, len(metrics), position)
        sns.lineplot(data=df, x='cfg', y=metric, hue='name')

        baseline = df.loc[df['cfg'] == 0, metric]
        if not baseline.empty:
            plt.axhline(y=baseline.iloc[0], color='grey', linestyle='--')

        # Annotate each point; iterate values directly so the frame's index
        # labels do not matter.
        for cfg, value in zip(df['cfg'], df[metric]):
            plt.text(cfg, value, round(value, 3),
                     ha='center', va='bottom', fontsize=10)

        plt.title(metric)
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper right', borderaxespad=0.)

    plt.suptitle(f'{eval_metrics} with different cfg ({save_dir})')

    plt.savefig(os.path.join(save_dir, f'{eval_metrics}-cfg.png'))
    # Release the figure: this function is called once per run directory and
    # open figures would otherwise accumulate.
    plt.close(fig)
    return df


def line_plots(save_dir, df_ls, eval_metrics):
    """Overlay the metric-vs-cfg curves of several runs in one figure.

    One subplot is drawn per metric; each frame in ``df_ls`` contributes a
    line, an optional dashed baseline at its ``cfg == 0`` value, and value
    annotations rounded to three decimals.

    Args:
        save_dir: Directory the figure ``{eval_metrics}-cfg.png`` is written to.
        df_ls: Re-iterable collection of DataFrames with ``cfg``, ``name`` and
            one column per metric.
        eval_metrics: Metric family understood by ``get_metrics``.

    Returns:
        The last DataFrame of ``df_ls`` (as before), or ``None`` when ``df_ls``
        is empty instead of raising ``UnboundLocalError``.
    """
    print("=== Visualizing the chair result using line plot ===")

    metrics = get_metrics(eval_metrics)

    fig = plt.figure(figsize=(6 * len(metrics), 6))

    last_df = None
    for position, metric in enumerate(metrics, start=1):
        plt.subplot(1, len(metrics), position)
        for df in df_ls:
            last_df = df
            sns.lineplot(data=df, x='cfg', y=metric, hue='name')

            baseline = df.loc[df['cfg'] == 0, metric]
            if not baseline.empty:
                plt.axhline(y=baseline.iloc[0], color='grey', linestyle='--')

            # Annotate each point; iterate values directly so the frame's
            # index labels do not matter.
            for cfg, value in zip(df['cfg'], df[metric]):
                plt.text(cfg, value, round(value, 3),
                         ha='center', va='bottom', fontsize=10)

        plt.title(metric)
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper right', borderaxespad=0.)

    plt.suptitle(f'{eval_metrics} with different cfg ({save_dir})')

    out_path = os.path.join(save_dir, f'{eval_metrics}-cfg.png')
    plt.savefig(out_path)
    plt.close(fig)  # do not leave the figure open after saving
    print("lineplots saved in ", out_path)
    return last_df


def line_plot_with_error_bars(save_dir, df_ls, eval_metrics):
    """Plot the mean metric per ``cfg`` across runs with std-dev error bars.

    The frames in ``df_ls`` are concatenated, the ``name`` column is dropped,
    and mean / standard deviation are computed per ``(cfg, label)`` group. One
    subplot is drawn per metric with one error-bar line per label.

    Args:
        save_dir: Directory the figure ``{eval_metrics}-cfg.png`` is written to.
        df_ls: List of DataFrames with ``cfg``, ``label``, ``name`` and one
            column per metric.
        eval_metrics: Metric family understood by ``get_metrics``.

    Returns:
        The per-``(cfg, label)`` mean DataFrame.
    """
    print("=== Visualizing the chair result of a group of exp using line plot with error bars===")
    metrics = get_metrics(eval_metrics)

    full_df = pd.concat(df_ls).drop(columns=['name'])

    grouped = full_df.groupby(['cfg', 'label'], as_index=False)
    mean_df = grouped.mean()
    std_df = grouped.std()

    print(f"mean_df: \n {mean_df}")
    fig = plt.figure(figsize=(6 * len(metrics), 6))

    labels = mean_df['label'].unique()
    for position, metric in enumerate(metrics, start=1):
        plt.subplot(1, len(metrics), position)
        for label in labels:
            subset = mean_df[mean_df['label'] == label]
            std_err = std_df[std_df['label'] == label][metric]
            plt.errorbar(subset['cfg'], subset[metric],
                         yerr=std_err, fmt='-', capsize=5, label=label)

            # Baseline reference for this label. The check lives inside the
            # loop so it never relies on a leftover loop variable (which was
            # undefined when there were no labels and only covered the last
            # label otherwise).
            baseline = subset.loc[subset['cfg'] == 0, metric]
            if not baseline.empty:
                plt.axhline(y=baseline.iloc[0], color='grey', linestyle='--')

        # Annotate every mean value on the plot.
        for cfg, value in zip(mean_df['cfg'], mean_df[metric]):
            plt.text(cfg, value, round(value, 3),
                     ha='center', va='bottom', fontsize=10)

        plt.title(metric)
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper right', borderaxespad=0.)

    plt.suptitle(f'{eval_metrics} with different cfg ({save_dir})')
    out_path = os.path.join(save_dir, f'{eval_metrics}-cfg.png')
    plt.savefig(out_path)
    plt.close(fig)  # do not leave the figure open after saving
    print(
        f"Lineplot saved in {out_path} \n")

    return mean_df

def save_mean_df(mean_df, save_dir, eval_metrics, type):
    """Append the averaged results to a JSON-lines file.

    A single JSON object ``{label: {"cfg": [...], metric: [...], ...}}`` is
    appended as one line to ``{eval_metrics}_{type}_results.json`` inside
    ``save_dir``. The key is the ``label`` value at index ``0`` of ``mean_df``;
    all rows of the frame are written under that key.

    Args:
        mean_df: DataFrame with ``label``, ``cfg`` and one column per metric.
        save_dir: Directory of the output file.
        eval_metrics: Metric family understood by ``get_metrics``.
        type: Tag used in the output file name.
    """
    metric_names = get_metrics(eval_metrics)

    record_key = mean_df['label'][0]
    payload = {'cfg': mean_df['cfg'].tolist()}
    for metric_name in metric_names:
        payload[metric_name] = mean_df[metric_name].tolist()

    out_path = os.path.join(save_dir, f'{eval_metrics}_{type}_results.json')
    # Append mode: earlier records in the file are preserved.
    with open(out_path, 'a+') as handle:
        json.dump({record_key: payload}, handle)
        handle.write('\n')

def compare_line_plots_with_error_bars(save_dir, df_ls_dict, eval_metrics, type):
    """Compare several models in one figure using mean lines with error bars.

    For every entry of ``df_ls_dict`` the frames are concatenated, the ``name``
    column is dropped, and mean / standard deviation per ``(cfg, label)`` are
    plotted in one subplot per metric. Each model's mean frame is also appended
    to the results file via ``save_mean_df``. The figure is written as PNG and
    PDF to ``<save_dir>/figs/all-{type}-{eval_metrics}-cfg.*``.

    Args:
        save_dir: Root output directory.
        df_ls_dict: Mapping ``model -> list of DataFrames`` with ``cfg``,
            ``label``, ``name`` and one column per metric.
        eval_metrics: Metric family understood by ``get_metrics``; also the key
            used to look up a palette in ``COLOR_ls``.
        type: Tag used in the output file names.

    Returns:
        The concatenated frame of the last plotted model, or ``None`` when no
        model had data.
    """
    print("=== Visualizing the comparison of results using line plots with error bars ===")

    metrics = get_metrics(eval_metrics)
    # Colours are best-effort: when no palette is registered for this metric
    # family, or it has fewer entries than models, fall back to matplotlib's
    # default colour cycle. This replaces a bare ``except`` that hid every
    # error raised while plotting.
    palette = COLOR_ls.get(eval_metrics, [])

    plt.figure(figsize=(6 * len(metrics), 6))

    full_df = None
    for idx, (model, df_ls) in enumerate(df_ls_dict.items()):
        print(f"{model}")
        if not df_ls:
            # pd.concat refuses an empty list; nothing to draw for this model.
            print(f"No data frames for {model}, skipping.")
            continue

        # Combine the runs, drop 'name', and aggregate per (cfg, label).
        full_df = pd.concat(df_ls).drop(columns=['name'])
        grouped = full_df.groupby(['cfg', 'label'], as_index=False)
        mean_df = grouped.mean()
        std_df = grouped.std()

        labels = mean_df['label'].unique()

        line_style = {'fmt': '-', 'capsize': 5}
        if idx < len(palette):
            line_style['color'] = palette[idx]

        for i, metric in enumerate(metrics):
            plt.subplot(1, len(metrics), i + 1)
            for label in labels:
                subset = mean_df[mean_df['label'] == label]
                std_err = std_df[std_df['label'] == label][metric]
                plt.errorbar(subset['cfg'], subset[metric],
                             yerr=std_err, label=f"{label}", **line_style)

            plt.xlabel('controllable strength', fontsize=16)
            plt.ylabel(metric, fontsize=16)
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper right', borderaxespad=0., fontsize=16)

        save_mean_df(mean_df, save_dir, eval_metrics, type)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    fig_dir = os.path.join(save_dir, 'figs')
    os.makedirs(fig_dir, exist_ok=True)
    plt.savefig(os.path.join(fig_dir, f'all-{type}-{eval_metrics}-cfg.png'))
    pdf_path = os.path.join(fig_dir, f'all-{type}-{eval_metrics}-cfg.pdf')
    plt.savefig(pdf_path)
    print(f"Comparison lineplot with error bars saved in {pdf_path}\n")
    plt.close()

    return full_df

# Marker style looked up by the position of a model key in TYPES.
MARKER_ls = ['o', '^', 'o', '^', 'o', '^']


def compare_line_plots(save_dir, df_ls_dict, eval_metrics, type):
    """Save one line-plot PDF per metric comparing all entries of ``df_ls_dict``.

    For each entry the frames are concatenated, ``name`` is dropped and the
    mean per ``(cfg, label)`` is computed and appended to the results file via
    ``save_mean_df`` (once per entry). When a label has no ``cfg == 0`` row, one
    is synthesised from ``get_0_value`` so every curve starts at the baseline.
    Each metric is written to ``<save_dir>/figs/ap-{type}-{metric}.pdf``; only
    the ``Recall`` figure carries a legend.

    Args:
        save_dir: Root output directory.
        df_ls_dict: Mapping ``key -> list of DataFrames``. Every key must be an
            entry of ``TYPES``; its position selects colour and marker.
        eval_metrics: Metric family understood by ``get_metrics``; also the key
            used to look up a palette in ``COLOR_ls``.
        type: Row key passed to ``get_0_value`` for baselines; also used in the
            output file names.

    Returns:
        The concatenated frame of the last entry with data, or ``None`` when no
        entry had data.

    Raises:
        ValueError: If a key of ``df_ls_dict`` is not listed in ``TYPES``.
    """
    print("=== Visualizing the comparison of results using line plots with error bars ===")

    metrics = get_metrics(eval_metrics)
    # Fall back to the default colour cycle when no palette is registered.
    palette = COLOR_ls.get(eval_metrics, [])

    # Aggregate every model once, up front. Doing this inside the metric loop
    # appended the same record to the results file once per metric.
    full_df = None
    mean_frames = []
    for model, df_ls in df_ls_dict.items():
        if not df_ls:
            # pd.concat refuses an empty list; nothing to draw for this model.
            print(f"No data frames for {model}, skipping.")
            continue
        full_df = pd.concat(df_ls).drop(columns=['name'])
        mean_df = full_df.groupby(['cfg', 'label'], as_index=False).mean()
        save_mean_df(mean_df, save_dir, eval_metrics, type)
        mean_frames.append((model, mean_df))

    fig_dir = os.path.join(save_dir, 'figs')
    os.makedirs(fig_dir, exist_ok=True)

    for metric in metrics:
        plt.rcParams.update({'font.size': 12, 'figure.figsize': (6, 6),
                             'axes.spines.right': False, 'axes.spines.top': False})
        plt.setp(plt.gca().lines, linewidth=2)
        plt.grid(True)

        with_legend = metric == 'Recall'

        for model, mean_df in mean_frames:
            print(f"model: {model}")
            type_idx = TYPES.index(model)

            for label in mean_df['label'].unique():
                subset = mean_df[mean_df['label'] == label]

                # If cfg=0 is not in the subset, add the known baseline row.
                if 0 not in subset['cfg'].values:
                    baseline_row = {'cfg': 0, 'label': label}
                    for name in metrics:
                        baseline_row[name] = get_0_value(name, type)
                    subset = pd.concat(
                        [subset, pd.DataFrame(baseline_row, index=[0])],
                        ignore_index=True)

                line_kwargs = {'linewidth': 2, 'marker': MARKER_ls[type_idx],
                               'markersize': 6}
                if type_idx < len(palette):
                    line_kwargs['color'] = palette[type_idx]
                if with_legend:
                    # Label the line only where a legend is drawn; the line is
                    # drawn once instead of once labelled plus once unlabelled.
                    line_kwargs['label'] = label
                sns.lineplot(data=subset, x='cfg', y=metric, **line_kwargs)

        # Only show [0, 0.5, 1.0] on the x axis and hide both axis titles.
        plt.xticks([0, 0.5, 1.0], fontsize=12)
        plt.yticks(fontsize=12)
        plt.xlabel('')
        plt.ylabel('')

        if with_legend:
            plt.legend(bbox_to_anchor=(1.05, 1), loc='upper right', borderaxespad=1, fontsize=16)

        plt.tight_layout(rect=[0, 0.03, 1, 0.95])

        out_path = os.path.join(fig_dir, f'ap-{type}-{metric}.pdf')
        plt.savefig(out_path)
        print(f"Saved in {out_path}\n")
        plt.close()

    return full_df

def compare_line_plots_ab2(save_dir, df_ls_dict, eval_metrics, type):
    """Save one line-plot PDF per metric (ablation styling) for ``df_ls_dict``.

    This is the compact-figure variant of ``compare_line_plots``: 4x4 inch
    figures, thicker lines, larger markers, output files named
    ``<save_dir>/figs/ab-{type}-{metric}.pdf``, and any key containing
    ``"0.95"`` is shown in the legend as ``MARINE-DETR``. Nothing is written to
    the results file.

    The previous implementation could not run: it stopped in a leftover
    ``pdb`` breakpoint, replaced ``df_ls_dict`` with the list returned by
    ``read_mean_df()`` and then called ``.items()`` on it, and returned an
    undefined name. The ``df_ls_dict`` argument is now used as passed.

    Args:
        save_dir: Root output directory.
        df_ls_dict: Mapping ``key -> frames`` where ``frames`` is either an
            already aggregated DataFrame (``cfg``, ``label``, metric columns)
            or a list of per-run DataFrames that is averaged per
            ``(cfg, label)``. Every key must be an entry of ``TYPES``.
        eval_metrics: Metric family understood by ``get_metrics``; also the key
            used to look up a palette in ``COLOR_ls``.
        type: Row key passed to ``get_0_value`` for baselines; also used in the
            output file names.

    Returns:
        The aggregated frame of the last entry with data, or ``None`` when no
        entry had data.

    Raises:
        ValueError: If a key of ``df_ls_dict`` is not listed in ``TYPES``.
    """
    print("=== Visualizing the comparison of results using line plots with error bars ===")

    metrics = get_metrics(eval_metrics)
    # Fall back to the default colour cycle when no palette is registered.
    palette = COLOR_ls.get(eval_metrics, [])

    last_mean_df = None
    mean_frames = []
    for model, frames in df_ls_dict.items():
        if isinstance(frames, pd.DataFrame):
            mean_df = frames
        elif frames:
            combined = pd.concat(frames).drop(columns=['name'], errors='ignore')
            mean_df = combined.groupby(['cfg', 'label'], as_index=False).mean()
        else:
            print(f"No data frames for {model}, skipping.")
            continue
        last_mean_df = mean_df
        mean_frames.append((model, mean_df))

    fig_dir = os.path.join(save_dir, 'figs')
    os.makedirs(fig_dir, exist_ok=True)

    for metric in metrics:
        plt.rcParams.update({'font.size': 14, 'figure.figsize': (4, 4),
                             'axes.spines.right': False, 'axes.spines.top': False})
        plt.setp(plt.gca().lines, linewidth=3)
        plt.grid(True)

        with_legend = metric == 'Recall'

        for model, mean_df in mean_frames:
            print(f"model: {model}")
            type_idx = TYPES.index(model)

            for label in mean_df['label'].unique():
                subset = mean_df[mean_df['label'] == label]

                # If cfg=0 is not in the subset, add the known baseline row.
                if 0 not in subset['cfg'].values:
                    baseline_row = {'cfg': 0, 'label': label}
                    for name in metrics:
                        baseline_row[name] = get_0_value(name, type)
                    subset = pd.concat(
                        [subset, pd.DataFrame(baseline_row, index=[0])],
                        ignore_index=True)

                # Legend text only; the data selection above used the raw label.
                if "0.95" in model:
                    label = "MARINE"+'-'+"DETR"

                line_kwargs = {'linewidth': 3, 'marker': MARKER_ls[type_idx],
                               'markersize': 10}
                if type_idx < len(palette):
                    line_kwargs['color'] = palette[type_idx]
                if with_legend:
                    line_kwargs['label'] = label
                sns.lineplot(data=subset, x='cfg', y=metric, **line_kwargs)

        # Only show [0, 0.5, 1.0] on the x axis and hide both axis titles.
        plt.xticks([0, 0.5, 1.0], fontsize=14)
        plt.yticks(fontsize=14)
        plt.xlabel('')
        plt.ylabel('')

        if with_legend:
            plt.legend(bbox_to_anchor=(1.05, 1), loc='upper right', borderaxespad=1, fontsize=16)

        plt.tight_layout(rect=[0, 0.03, 1, 0.95])

        out_path = os.path.join(fig_dir, f'ab-{type}-{metric}.pdf')
        plt.savefig(out_path)
        print(f"Saved in {out_path}\n")
        plt.close()

    return last_mean_df

# read mean_df from json
def read_mean_df(mean_df_path="./POPE/chair/overall_group_ab1_0130/chair_instructblip_results.json"):
    """Read saved mean results and collect the metrics at cfg 0.5 and cfg 1.0.

    The file is expected to hold one JSON object per line in the shape written
    by ``save_mean_df``: ``{label: {"cfg": [...], "CHAIRs": [...],
    "CHAIRi": [...], "Recall": [...]}}``, where the metric lists are aligned
    with the ``cfg`` list.

    Args:
        mean_df_path: Path of the JSON-lines results file. Defaults to the
            previously hard-coded location.

    Returns:
        A two-element list ``[df_cfg_0_5, df_cfg_1_0]``. Each DataFrame has the
        columns ``CHAIRs``, ``CHAIRi`` and ``Recall`` with one row per record
        entry found at that cfg value, in file order.
    """
    metric_names = ("CHAIRs", "CHAIRi", "Recall")

    with open(mean_df_path, 'r') as f:
        # Skip blank lines (e.g. a trailing newline) instead of failing on them.
        records = [json.loads(line) for line in f if line.strip()]

    collected = {
        0.5: {name: [] for name in metric_names},
        1.0: {name: [] for name in metric_names},
    }
    for record in records:
        for values in record.values():
            # 'cfg' is a list; pick the metric entries at the matching positions.
            for position, cfg in enumerate(values['cfg']):
                bucket = collected.get(cfg)
                if bucket is None:
                    continue
                for name in metric_names:
                    bucket[name].append(values[name][position])

    data_df_ls = [pd.DataFrame(collected[0.5]), pd.DataFrame(collected[1.0])]
    print(data_df_ls)
    return data_df_ls

def get_0_value(metric, model):
    """Look up the hard-coded baseline (cfg = 0) value of a metric for a model.

    Args:
        metric: One of ``CHAIRs``, ``CHAIRi``, ``Recall``, ``Accuracy``,
            ``Yes_ratio`` or ``F1``.
        model: One of ``llava``, ``llava2``, ``minigptv2``, ``mplug-owl2``,
            ``instructblip`` or ``llama-adapter-v2``.

    Returns:
        The table entry for that metric column and model row.

    Raises:
        KeyError: If ``metric`` or ``model`` is not in the table.
    """
    import pandas as pd

    metric_columns = ['CHAIRs', 'CHAIRi', 'Recall', 'Accuracy', 'Yes_ratio', 'F1']

    # One row per model, values ordered like ``metric_columns``.
    per_model = {
        'llava': [0.261, 0.108, 0.461, 0.53513333, 0.95703333, 0.68096667],
        'llava2': [0.075, 0.042, 0.4087, 0.7897, 0.6123, 0.8109],
        'minigptv2': [0.116, 0.068, 0.365, 0.802, 0.4449, 0.79016667],
        'mplug-owl2': [0.059, 0.035, 0.361, 0.715, 0.71923333, 0.76626667],
        'instructblip': [0.053, 0.036, 0.312, 0.71646667, 0.6142, 0.7467],
        'llama-adapter-v2': [0.270667, 0.104874, 0.50858, 0.55966667, 0.91386667, 0.68856667],
    }

    baseline_table = pd.DataFrame.from_dict(
        per_model, orient='index', columns=metric_columns)
    metric_column = baseline_table[metric]
    return metric_column[model]


def get_metrics(eval_metrics):
    """Return the metric column names of an evaluation family.

    Args:
        eval_metrics: ``'chair'`` or ``'pope'``, matched case-insensitively.

    Returns:
        A new list of three metric names.

    Raises:
        NotImplementedError: For any other family name.
    """
    family = eval_metrics.lower()
    if family == 'chair':
        return ['CHAIRs', 'CHAIRi', 'Recall']
    if family == 'pope':
        return ['Accuracy', 'Yes_ratio', 'F1']
    raise NotImplementedError
        
def load_eval_results(save_dir, result_file_name, eval_metrics, yaml_file):
    """Load the runs listed in a YAML file and plot one comparison per group.

    The YAML document must contain ``results: {group: {model: {save_dir: ...,
    seed_ls: [...]}}}``. For every model, the sub-directories of its
    ``save_dir`` whose path contains any of the seeds (as a substring) are
    loaded with ``load_eval_jsonl``. Each group with data is then plotted with
    ``compare_line_plots_with_error_bars``.

    Args:
        save_dir: Root output directory for the figures.
        result_file_name: Name of the JSON-lines result file inside each run
            directory.
        eval_metrics: Metric family understood by ``get_metrics``.
        yaml_file: Path of the YAML description.

    Raises:
        yaml.YAMLError: If the YAML file cannot be parsed. The error is printed
            and re-raised; previously it was swallowed and the function then
            failed on an undefined variable.
    """
    import yaml
    with open(yaml_file, 'r') as stream:
        try:
            data = yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)
            raise

    data = data['results']
    for group, results in data.items():
        frames_by_model = {}
        for model, result in results.items():
            print(group, model)
            root_save_dir = result['save_dir']
            seed_ls = result['seed_ls']

            # Keep sub-directories only, then those matching any seed.
            eval_dirs = [
                eval_dir for eval_dir in glob.glob(root_save_dir + '/*')
                if not os.path.isfile(eval_dir)]
            seed_dirs = [eval_dir for eval_dir in eval_dirs if any(
                str(seed) in eval_dir for seed in seed_ls)]
            if not seed_dirs:
                print(f"No eval dirs found in {eval_dirs}!")
                continue

            df_ls = []
            for eval_dir in seed_dirs:
                df = load_eval_jsonl(eval_dir, result_file_name, eval_metrics)
                df['name'] = group
                # str() keeps this working when YAML parses a key as a number.
                df['label'] = df['label'] + '-' + str(group) + '-' + str(model)
                df_ls.append(df)
            frames_by_model[model] = df_ls

        if not frames_by_model:
            # Every model of this group was skipped; there is nothing to plot.
            print(f"No results loaded for {group}, skipping plot.")
            continue
        compare_line_plots_with_error_bars(
            save_dir, frames_by_model, eval_metrics, group)
        
# Experiment type keys. A key's position in this list (TYPES.index) selects
# its colour from COLOR_ls and its marker from MARKER_ls.
TYPES = ['black', 'grey_0.95', 'grey_0.9', 'grey_0.7', 'grey']
# Earlier palette candidates, kept for reference:
# color='#2166ac'  color='#b2182b'
# COLOR_ls = {'chair': ['#2166ac', '#b2182b', '#2166ac', '#b2182b', '#4d004b'], "pope": ['#ccece6', '#99d8c9', '#66c2a4', '#2ca25f', '#006d2c']}
# COLOR_ls = {'chair': ['#810f7c', '#8c96c6', '#8856a7', '#810f7c', '#4d004b'], "pope": ['#b3cde3', '#8c96c6', '#8856a7', '#810f7c', '#4d004b']}
# COLOR_ls = {'chair': ['#b3cde3', '#8c96c6', '#810f7c', '#4d004b', '#ccece6'], "pope": ['#ccece6', '#99d8c9', '#66c2a4', '#2ca25f', '#006d2c']}
# Active palette: five colours per metric family, listed from dark to light.
# NOTE: only 'chair' (lowercase) has an entry.
COLOR_ls = {'chair': ['#4d004b', '#810f7c', '#8c96c6', '#b3cde3', '#ccece6']}
def load_group_eval_results(save_dir, result_file_name, eval_metrics, yaml_file):
    """Load the runs listed in a YAML file and plot one comparison per model.

    The YAML document must contain ``results: {model: {<model>: {<type>:
    {save_dir: ..., seed_ls: [...]}}}}``. For every type, the sub-directories
    of its ``save_dir`` whose path contains any of the seeds (as a substring)
    are loaded with ``load_eval_jsonl`` and relabelled:

    * types containing ``'black'`` become ``MARINE-Truth``;
    * other types become ``MARINE-DETR-<threshold>``, where the threshold is
      the first of 0.7, 0.95, 0.9, 0.5 whose text occurs in the type name.

    Each model with data is then plotted with ``compare_line_plots``.

    Args:
        save_dir: Root output directory for the figures.
        result_file_name: Name of the JSON-lines result file inside each run
            directory.
        eval_metrics: Metric family understood by ``get_metrics``.
        yaml_file: Path of the YAML description.

    Raises:
        yaml.YAMLError: If the YAML file cannot be parsed. The error is printed
            and re-raised; previously it was swallowed and the function then
            failed on an undefined variable.
    """
    import yaml
    with open(yaml_file, 'r') as stream:
        try:
            data = yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)
            raise

    # 0.95 is listed before 0.9 so that "grey_0.95" is not claimed by "0.9".
    th_ls = [0.7, 0.95, 0.9, 0.5]

    data = data['results']['model']
    for model, results in data.items():
        frames_by_type = {}
        for type_key, result in results.items():
            print(type_key, model)
            type_name = str(type_key)
            root_save_dir = result['save_dir']
            seed_ls = result['seed_ls']

            # Keep sub-directories only, then those matching any seed.
            eval_dirs = [
                eval_dir for eval_dir in glob.glob(root_save_dir + '/*')
                if not os.path.isfile(eval_dir)]
            seed_dirs = [eval_dir for eval_dir in eval_dirs if any(
                str(seed) in eval_dir for seed in seed_ls)]
            if not seed_dirs:
                print(f"No eval dirs found in {eval_dirs}!")
                # An empty frame list would make the plotting step fail in
                # pd.concat, so leave this type out entirely.
                continue

            if 'black' in type_name:
                label = "MARINE"+'-'+"Truth"
            else:
                # When no threshold appears in the type name (e.g. "grey") the
                # last entry, 0.5, is used. This keeps the original loop's
                # fall-through result; presumably 0.5 is the default threshold
                # - TODO confirm.
                th = next((candidate for candidate in th_ls
                           if str(candidate) in type_name), th_ls[-1])
                label = "MARINE"+'-'+"DETR"+'-'+str(th)

            df_ls = []
            for eval_dir in seed_dirs:
                df = load_eval_jsonl(eval_dir, result_file_name, eval_metrics)
                df['name'] = type_key
                df['label'] = label
                df_ls.append(df)
            frames_by_type[type_name] = df_ls

        if not frames_by_type:
            print(f"No results loaded for {model}, skipping plot.")
            continue
        compare_line_plots(save_dir, frames_by_type, eval_metrics, model)
        
        
            
if __name__ == '__main__':
    # Command-line entry point.
    #   --duplicate: treat --save_dir as a root containing one sub-directory per
    #                repeated run; plot each run, then the mean with error bars.
    #   otherwise:   plot the single run stored directly in --save_dir.

    parser = argparse.ArgumentParser()
    parser.add_argument("--duplicate", action='store_true')
    parser.add_argument("--save_dir", type=str,
                        default='../POPE/llava/black/answer_1219/eval')
    parser.add_argument("--result_file_name", type=str, default='eval.json')
    parser.add_argument("--metrics", type=str, default='CHAIR')
    args = parser.parse_args()

    if args.duplicate:
        root_save_dir = args.save_dir
        # glob already returns paths that include root_save_dir. Joining them
        # with the root again duplicated the prefix and produced non-existent
        # paths for most relative roots, so every run was silently skipped.
        eval_dirs = sorted(
            eval_dir for eval_dir in glob.glob(root_save_dir + '/*')
            if not os.path.isfile(eval_dir))

        if not eval_dirs:
            raise ValueError("No eval dirs found!")

        df_ls = []
        for eval_dir in eval_dirs:
            # Skip run directories that have no result file.
            if not os.path.isfile(os.path.join(eval_dir, args.result_file_name)):
                continue
            print(f"=== Evaluating {eval_dir} ===")

            df = load_eval_jsonl(eval_dir, args.result_file_name, args.metrics)
            if df is None or df.empty:
                # Nothing to plot or average for this run.
                print(f"{eval_dir} is empty!")
                continue
            print(df)
            line_plot(eval_dir, df, args.metrics)
            df_ls.append(df)
        if not df_ls:
            raise ValueError("No eval dirs found!")
        line_plot_with_error_bars(root_save_dir, df_ls, args.metrics)

    else:
        df = load_eval_jsonl(args.save_dir, args.result_file_name, args.metrics)
        line_plot(args.save_dir, df, args.metrics)
