import json
from utils.eval_yesno import evaluate_yes_no
from utils.eval_multichoice import eval_mc
from utils.type1_utils import eval_single, avg_acc, avg_acc_all, eval_closed, eval_closed_with_false, eval_closed_only_yesno
import re
import numpy as np
import argparse
import pandas as pd

from utils.type1_utils import eval_all
import os

# Named groups accepted as the first token of --datasets; when one is given,
# the whole token list is replaced by the group's members.
DATASET_GROUPS = {
    'MM': ['slake', 'rad'],
    'CXR': ['mimic_cxr', 'xray'],
    'ALL': ['slake', 'rad', 'mimic_cxr', 'xray'],
}

# The benchmark data path below is hard-coded to the visual-misinterpretation
# close-ended set, so that is the only hallucination type this script scores.
SUPPORTED_HALLU_TYPE = 'visual_misinterpretation'
ORIGIN_DATA_DIR = '/root/project/benchmark_data/Visual_Misinterpretation_Hallucination/close-ended'


def expand_datasets(spec):
    """Parse a --datasets value into ``(tag, dataset_names)``.

    ``spec`` is a whitespace-separated list of dataset names. If its first
    token is a key of ``DATASET_GROUPS``, the whole list is replaced by that
    group's members (any further tokens are ignored). ``tag`` is the first
    token exactly as given (e.g. ``'MM'``); it names the summary CSV so that a
    group run does not overwrite the CSV of a single-dataset run.

    Raises:
        ValueError: if ``spec`` contains no dataset names.
    """
    tokens = spec.split()  # split() (not split(' ')) ignores repeated spaces
    if not tokens:
        raise ValueError('no datasets given')
    tag = tokens[0]
    return tag, list(DATASET_GROUPS.get(tag, tokens))


def to_pct(fraction):
    """Return ``fraction`` as a percentage rounded to one decimal place."""
    return float("{:.1f}".format(fraction * 100))


def weighted_mean(values, weights):
    """Return ``sum(values * weights) / sum(weights)`` as a float.

    Accepts lists or numpy arrays. Returns NaN when the weights sum to zero.
    """
    values = np.asarray(values, dtype=float)
    weights = np.asarray(weights, dtype=float)
    total = weights.sum()
    if total == 0:
        return float('nan')
    return float(np.sum(values * weights) / total)


def main():
    """Score close-ended answers per dataset and write a summary CSV.

    For each dataset, this loads the reference QA pairs and the model's
    ``.jsonl`` answers. It scores them with ``eval_closed`` and prints the
    per-type accuracy/F1 with weighted averages. It then writes one summary
    row per dataset to ``/root/project/summaries/<hallu_type>/``.
    """
    parser = argparse.ArgumentParser()
    # NOTE(review): --origin-data-dir, --result-dir and --answers-file are
    # parsed but not read anywhere; the paths below are hard-coded.
    parser.add_argument("--origin-data-dir", type=str, default=None)
    parser.add_argument("--result-dir", type=str, default=None)
    parser.add_argument("--model", type=str, default="")
    parser.add_argument("--subfix", type=str, default="")
    parser.add_argument("--answers-file", type=str, default="answer.jsonl")
    parser.add_argument("--datasets", type=str, default="")
    parser.add_argument("--hallu_type", type=str, default='')
    args = parser.parse_args()

    # Validate before touching the filesystem (was an assert inside the loop).
    if args.hallu_type != SUPPORTED_HALLU_TYPE:
        parser.error(f'--hallu_type must be {SUPPORTED_HALLU_TYPE!r}, got {args.hallu_type!r}')
    try:
        tag, datasets = expand_datasets(args.datasets)
    except ValueError as err:
        parser.error(str(err))

    summary_dir = f'/root/project/summaries/{args.hallu_type}'
    os.makedirs(summary_dir, exist_ok=True)
    summary_path = os.path.join(summary_dir, f'{args.model}_close_{tag}{args.subfix}.csv')

    summaries = []
    for dataset_name in datasets:
        ori_path = os.path.join(ORIGIN_DATA_DIR, f'{dataset_name}_close_pairs.json')
        infer_path = f'/root/project/results/{args.hallu_type}/{args.model}_close_{dataset_name}{args.subfix}.jsonl'

        with open(ori_path, 'r', encoding='utf-8') as file:
            ori = json.load(file)
        id_to_ori = {item['qid']: item for item in ori}

        with open(infer_path, 'r', encoding='utf-8') as f:
            # Skip blank lines (e.g. a trailing newline) rather than crash in json.loads.
            infer_results = [json.loads(line) for line in f if line.strip()]

        # Alternative scorers from utils.type1_utils, swap in as needed:
        #   eval_closed_only_yesno(ori, id_to_ori, infer_results)
        #   eval_closed_with_false(ori, id_to_ori, infer_results, false_path), where
        #     false_path = f'/root/project/results_false/{hallu_type}/{model}_close_{dataset}{subfix}.jsonl'
        #     (that one returns 4 values, without f_f1s).
        f_accs, f_f1s, f_lens, o_accs, o_lens = eval_closed(ori, id_to_ori, infer_results)
        f_avg_acc = weighted_mean(f_accs, f_lens)
        f_avg_f1 = weighted_mean(f_f1s, f_lens)

        summaries.append({
            'dataset': dataset_name,
            'type1': to_pct(f_accs[0]),
            'type2': to_pct(f_accs[1]),
            'type3': to_pct(f_accs[2]),
            'type4': to_pct(f_accs[3]),
            'avg_acc': to_pct(f_avg_acc),
            'avg_f1': to_pct(f_avg_f1),
        })
        print(dataset_name)
        print([to_pct(acc) for acc in f_accs], to_pct(f_avg_acc))
        print([to_pct(f1) for f1 in f_f1s], to_pct(f_avg_f1))

    pd.DataFrame(summaries).to_csv(summary_path)


if __name__ == "__main__":
    main()
