import json
from utils.eval_yesno import evaluate_yes_no
from utils.type1_utils import eval_single, avg_acc, avg_acc_all
import argparse
import numpy as np
import os

# Machine-specific absolute base directory for the hallucination experiments.
# NOTE(review): no active code visible in this file reads this constant —
# confirm it is unused elsewhere before removing it.
root_path = r"/data/aofei/hallucination"

def eval_yes_no(results):
    """Score yes/no predictions with ``evaluate_yes_no``.

    Field names are detected from the first record only: the prediction is
    read from ``'response'`` when that key is present, otherwise from
    ``'model_answer'``; the label is read from ``'ground_truth'`` when
    present, otherwise from ``'gt_ans'``.

    Args:
        results: Non-empty sequence of prediction dicts. An empty sequence
            raises ``IndexError`` when the first record is inspected.

    Returns:
        Whatever ``evaluate_yes_no(answers, labels)`` returns.
    """
    first = results[0]
    answer_key = "response" if "response" in first else "model_answer"
    label_key = "ground_truth" if "ground_truth" in first else "gt_ans"

    # evaluate_yes_no receives each answer wrapped as {"text": ...}.
    answers = [{"text": record[answer_key]} for record in results]
    labels = [record[label_key] for record in results]
    return evaluate_yes_no(answers, labels)

def eval_closed_single(ori_data, id_to_ori, inference_res):
    """Evaluate the yes/no subset of one closed-ended prediction run.

    Benchmark entries are classified by question id:

    * entries without a ``'ground_truth_type'`` key are treated as yes/no;
    * entries whose ``'ground_truth_type'`` is ``"binary"`` are yes/no;
    * otherwise, entries whose ``'question_type'`` is ``"multi-choice"`` are
      multi-choice.

    Only the yes/no predictions are scored. Multi-choice predictions are not
    scored here; they are merely annotated in place with the ``'choices'`` and
    ``'question_type'`` of their benchmark entry.

    Args:
        ori_data: Iterable of benchmark entries (dicts with a ``'qid'`` key).
        id_to_ori: Mapping from ``qid`` to the benchmark entry.
        inference_res: Iterable of prediction dicts with a ``'question_id'``
            key. Multi-choice records in it are mutated as described above.

    Returns:
        The two values returned by ``eval_yes_no`` for the yes/no predictions,
        unpacked here as ``(yn_acc_type, yn_f1_type)``.
    """
    # Sets make the per-prediction membership tests below O(1) instead of a
    # linear scan over every benchmark id.
    yn_ids, mc_ids = set(), set()
    for item in ori_data:
        if "ground_truth_type" not in item:
            yn_ids.add(item['qid'])
        elif item['ground_truth_type'] == "binary":
            yn_ids.add(item['qid'])
        # NOTE(review): this branch tests 'question_type' while the one above
        # tests 'ground_truth_type' — confirm the asymmetry is intended.
        elif item['question_type'] == "multi-choice":
            mc_ids.add(item['qid'])

    yn_results = []
    for pred in inference_res:
        question_id = pred['question_id']
        if question_id in yn_ids:
            yn_results.append(pred)
        elif question_id in mc_ids:
            # Multi-choice predictions are annotated in place but not scored.
            ori_item = id_to_ori[question_id]
            pred['choices'] = ori_item['choices']
            pred['question_type'] = ori_item['question_type']

    yn_acc_type, yn_f1_type = eval_yes_no(yn_results)

    return yn_acc_type, yn_f1_type


# Benchmark question files for MIMIC-CXR, loaded once at import time.
# Each file is a JSON array of entries carrying a 'qid' key; alongside the raw
# list we keep a qid -> entry lookup table.

# Knowledge-deficiency (close-ended) benchmark.
mimic_ori2_path = "/root/project/benchmark_data/Knowledge_Deficiency_Hallucination/close-ended/mimic_cxr_close_pairs.json"
with open(mimic_ori2_path, 'r') as fh:
    mimic_ori2 = json.load(fh)
mimic_id_to_ori2 = {entry['qid']: entry for entry in mimic_ori2}

# Context-misalignment benchmark.
mimic_ori3_path = '/root/project/benchmark_data/Context_Misalignment_Hallucination/MIMIC-CXR_pairs.json'
with open(mimic_ori3_path, 'r') as fh:
    mimic_ori3 = json.load(fh)
mimic_id_to_ori3 = {entry['qid']: entry for entry in mimic_ori3}

def _load_jsonl(path):
    """Read a JSON Lines file into a list, skipping blank lines."""
    with open(path, 'r') as f:
        # A blank (e.g. trailing) line would make json.loads raise, so skip it.
        return [json.loads(line) for line in f if line.strip()]


def eval_mimic_single(type2_infer_path, type3_infer_path):
    """Evaluate MIMIC-CXR predictions for two benchmark files.

    Args:
        type2_infer_path: Path to a JSONL prediction file evaluated against
            ``mimic_ori2`` / ``mimic_id_to_ori2``.
        type3_infer_path: Path to a JSONL prediction file evaluated against
            ``mimic_ori3`` / ``mimic_id_to_ori3``.

    Returns:
        A pair ``(type2_result, type3_result)``; each element is the value
        returned by ``eval_closed_single`` for the corresponding file.
    """
    type2_results = _load_jsonl(type2_infer_path)
    type2_scores = eval_closed_single(mimic_ori2, mimic_id_to_ori2, type2_results)

    type3_results = _load_jsonl(type3_infer_path)
    type3_scores = eval_closed_single(mimic_ori3, mimic_id_to_ori3, type3_results)

    return type2_scores, type3_scores


import pandas as pd
if __name__ == "__main__":
    # Command-line entry point: score one closed-ended prediction file and
    # write a one-row CSV summary (accuracy and F1 as percentages rounded to
    # one decimal place).
    parser = argparse.ArgumentParser()
    # --origin-data-dir, --result-dir and --answers-file are parsed but not
    # read below; they are kept so existing command lines keep working.
    parser.add_argument("--origin-data-dir", type=str, default=None)
    parser.add_argument("--result-dir", type=str, default=None)
    parser.add_argument("--model", type=str, default="")
    parser.add_argument("--subfix", type=str, default="")
    parser.add_argument("--answers-file", type=str, default="answer.jsonl")
    parser.add_argument("--datasets", type=str, default="")
    parser.add_argument("--hallu_type", type=str, default='')
    args = parser.parse_args()

    dataset_name = args.datasets
    hallu_type = args.hallu_type
    run_name = f'{args.model}_close_{dataset_name}{args.subfix}'

    # The prediction file location follows the same pattern for every dataset.
    infer_path = f'/root/project/results/{hallu_type}/{run_name}.jsonl'

    if dataset_name in ("harvard", "pmc"):
        ori_path = f'/root/project/benchmark_data/{dataset_name}/{dataset_name}_question_disease_close.json'
    elif hallu_type == 'knowledge_deficiency':
        ori_path = '/root/project/benchmark_data/Knowledge_Deficiency_Hallucination/close-ended/mimic_cxr_close_pairs.json'
    elif hallu_type == 'context_misalignment':
        ori_path = '/root/project/benchmark_data/Context_Misalignment_Hallucination/MIMIC-CXR_pairs.json'
    else:
        # Without this branch ori_path would be unbound and the script would
        # fail later with a NameError; report the bad argument instead.
        parser.error(
            f"unsupported --hallu_type {hallu_type!r} for dataset {dataset_name!r}; "
            "expected 'knowledge_deficiency' or 'context_misalignment'"
        )

    with open(ori_path, 'r') as file:
        ori = json.load(file)
    id_to_ori = {entry['qid']: entry for entry in ori}

    with open(infer_path, 'r') as f:
        # Skip blank lines (e.g. a trailing newline) that json.loads rejects.
        mimic_results = [json.loads(line) for line in f if line.strip()]

    summary_dir = f'/root/project/summaries/{hallu_type}'
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(summary_dir, exist_ok=True)
    summary_path = f'{summary_dir}/{run_name}.csv'

    f_accs, f_f1s = eval_closed_single(ori, id_to_ori, mimic_results)
    summary = {
        'dataset': dataset_name,
        'avg_accs': float("{:.1f}".format(f_accs * 100)),
        'avg_f1s': float("{:.1f}".format(f_f1s * 100)),
    }
    print(summary)

    summary_df = pd.DataFrame([summary])
    summary_df.to_csv(summary_path)