
"""
This script is used to evaluate MLLMs.
"""
import argparse
import logging
import random
import time
import os
import sys
sys.path.append('/root/project/code/baselines/Med_LVLMs')
from einops import rearrange
import numpy as np

import torch
import json
from utils import *


# Model key -> local checkpoint directory, or config-file path for the
# MiniGPT-4 / Shikra entries (their values end in .yaml / .py).
PATH = {
    # Checkpoint directories under /root/project/huggingface.
    'llava_med_v1.5': '/root/project/huggingface/llava-med-v1.5-mistral-7b',
    'exgra': "/root/project/huggingface/exgra-med-dci",
    'llava_v1.5_7B_lht': '/root/project/huggingface/llava-v1.5-7b-liuhaotian',
    'mplug_owl2_7B': '/root/project/huggingface/mplug-owl2-llama2-7b',
    'sharegpt4v_7B': '/root/project/huggingface/sharegpt4v-7b',
    'instructblip_7B': '/root/project/huggingface/instructblip-vicuna-7b-old',
    'instructblip_7B_new': '/root/project/huggingface/instructblip-vicuna-7b',
    'qwen_vl_7B': '/root/project/huggingface/qwen-vl',
    'qwen2_vl_7B': '/root/project/huggingface/qwen2-vl',
    'qwen_vl_7B_chat': '/root/project/huggingface/qwen-vl-chat',
    'cogvlm_17B': '/root/project/huggingface/cogvlm-base-224-hf',
    # Paths under /root/wtb/multimodal_alignment: one checkpoint directory
    # followed by model config files.
    'llava_v1.5_7B_hacl': '/root/wtb/multimodal_alignment/mPLUG-HalOwl-main/hacl/checkpoints/llava_sft',
    'minigpt4_vicuna_7B': '/root/wtb/multimodal_alignment/mm_iti/models/minigpt4/minigpt4_eval.yaml',
    'minigptv2_llama2_7B': '/root/wtb/multimodal_alignment/mm_iti/models/minigpt4/eval_configs/minigptv2_eval.yaml',
    'minigpt4_llama2_7B': '/root/wtb/multimodal_alignment/mm_iti/models/minigpt4/eval_configs/minigpt4_llama2_eval.yaml',
    'shikra_7B': '/root/wtb/multimodal_alignment/mm_iti/models/shikra_model/shikra_config.py',
}

# Dataset key -> root folder of that dataset's images. The *_chest / *_other
# subsets of VQA-RAD and SLAKE share the image folder of their parent dataset.
IMAGE_PATH = {
    'mimic_cxr': '/root/project/datasets/mimic_cxr_jpg/files',
    'xray': '/root/project/datasets/iu_xray/images',
    **dict.fromkeys(('rad', 'rad_chest', 'rad_other'), '/root/project/datasets/VQA_RAD'),
    **dict.fromkeys(('slake', 'slake_chest', 'slake_other'), '/root/project/datasets/Slake1.0/imgs'),
    'harvard': "/root/project/datasets/harvard/images",
    'pmc': "/root/project/datasets/pmc_oa/caption_T060_filtered_top4_sep_v0_subfigures",
}

def seed_all(seed=8888):
    """Seed Python's, NumPy's and PyTorch's global random number generators.

    ``main`` seeds NumPy as well, so NumPy is seeded here too; otherwise
    anything drawing from ``np.random`` would not be reproducible.
    ``torch.manual_seed`` seeds the generators of all devices, CUDA included.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

def evaluate_batch(model, args, data):
    """Run *model* on every sample in *data* and save the results as JSON Lines.

    For each sample, ``model.evaluate(sample['prompt'], sample['img_url'])`` is
    called. A shallow copy of the sample with an added ``'response'`` key is
    written to ``args.save_path`` as soon as it is produced. Records gathered
    before an exception are therefore still on disk: the writer is closed by the
    ``with`` block either way.

    Args:
        model: object exposing ``evaluate(prompt, image)``.
        args: namespace providing ``save_path``; ``verbose`` is optional and
            defaults to True (print each record) when absent.
        data: iterable of sample dicts with ``'prompt'`` and ``'img_url'`` keys.

    Returns:
        The list of result dicts, in input order.
    """
    response_list = []
    verbose = getattr(args, 'verbose', True)
    with jsonlines.open(args.save_path, 'w') as writer:
        for sample in tqdm(data):
            prompt = sample['prompt']
            image = sample['img_url']
            res = sample.copy()
            res['response'] = model.evaluate(prompt, image)
            if verbose:
                print(res)
            # Write incrementally so a failure late in a long run does not
            # discard every response generated so far.
            writer.write(res)
            response_list.append(res)
    return response_list

def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``yes``/``no``/``1``/``0``.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for any
    non-empty string (including ``"False"``), so boolean options need this.
    An empty string stays False, matching ``bool('')``.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def get_args(argv=None):
    """Parse the evaluation script's command-line options.

    Args:
        argv: optional list of argument strings; ``None`` parses ``sys.argv``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument('--model', type=str, default='llava_v1.5_7B', help="specifies the model to be evaluated.")
    parser.add_argument('--openai', type=str, default=None, help="specifies the api_key")
    parser.add_argument('--tokenizer', type=str, default=None, help='specifies the tokenizer to be used.')

    parser.add_argument('--probe_dataset', type=str, default='spa_vl', help='feature bank for training probes')
    parser.add_argument('--validate_datasets', type=str, default='MLLMGuard',
                        help="space-separated dataset names, or a preset name (MM, CXR, ALL) as the first token")
    parser.add_argument('--log_file', type=str, default='logs/default.log', help='specifies the name of the log file')
    parser.add_argument('--save_path', type=str, default='results/toxicity_en.jsonl', help='specifies the path to save the results.')
    # main() reads these two options, so they are exposed on the command line.
    parser.add_argument('--hallu_type', type=str, default='knowledge_deficiency',
                        choices=['visual_misinterpretation', 'knowledge_deficiency', 'context_misalignment'],
                        help='specifies the hallucination benchmark to evaluate.')
    parser.add_argument('--answer_type', type=str, default='open', choices=['open', 'close'],
                        help='specifies whether to use open- or close-ended questions.')

    # _str2bool instead of type=bool, which would turn "False" into True.
    parser.add_argument('--verbose', type=_str2bool, default=True, help='specifies whether to display verbose outputs.')

    parser.add_argument('--project_name', type=str, default='mllmguard', help='specifies the project name in wandb.')
    parser.add_argument('--entity_name', type=str, default='entity_name', help='specifies the entity name in wandb.')
    parser.add_argument('--seed', type=int, default=42, help='seed')

    parser.add_argument('--device', type=str, default='0')
    parser.add_argument('--categories', type=str, default='all')
    parser.add_argument('--subfix', type=str, default='')

    return parser.parse_args(argv)

# Datasets whose question files live under the *_Hallucination benchmark folders.
# 'rad_other' replaces a duplicated 'slake_chest' in the original list; it
# mirrors 'slake_other' and has an IMAGE_PATH entry.
_MEDHEVAL_DATASETS = ('slake_chest', 'slake_other', 'rad_chest', 'rad_other', 'mimic_cxr', 'xray')
# Datasets whose question files live in per-dataset folders.
_HARVARD_PMC_DATASETS = ('harvard', 'pmc')


def _expand_datasets(spec):
    """Turn a whitespace-separated dataset spec into a list of dataset names.

    If the first token is a preset name (MM, CXR, ALL), the whole spec is
    replaced by that preset's list. A list or tuple is accepted as-is.

    NOTE(review): the 'slake' and 'rad' entries of the MM/ALL presets have no
    question-file mapping in ``_question_file``, so those presets are rejected
    there. Confirm which question files they should use.
    """
    presets = {
        'MM': ['slake', 'rad'],
        'CXR': ['mimic_cxr', 'xray'],
        'ALL': ['slake', 'rad', 'mimic_cxr', 'xray'],
    }
    # split() rather than split(' ') so repeated spaces do not yield empty names.
    names = spec.split() if isinstance(spec, str) else list(spec)
    if names and names[0] in presets:
        return list(presets[names[0]])
    return names


def _question_file(dataset_name, hallu_type, answer_type):
    """Return the benchmark question-file path for *dataset_name*.

    Raises:
        ValueError: if no question file is defined for this
            dataset / hallu_type / answer_type combination.
    """
    if dataset_name in _HARVARD_PMC_DATASETS:
        # hallu_type does not select the file for these datasets.
        return f"/root/project/benchmark_data/{dataset_name}/{dataset_name}_question_disease_{answer_type}.jsonl"
    if dataset_name not in _MEDHEVAL_DATASETS:
        supported = sorted(_MEDHEVAL_DATASETS + _HARVARD_PMC_DATASETS)
        raise ValueError(f'No question file is configured for dataset {dataset_name!r}; supported: {supported}')
    if hallu_type == 'visual_misinterpretation':
        return f'/root/project/benchmark_data/Visual_Misinterpretation_Hallucination/{answer_type}-ended/{dataset_name}_{answer_type}_pairs.json'
    if hallu_type == 'knowledge_deficiency':
        if dataset_name != 'mimic_cxr':
            raise ValueError(f"hallu_type 'knowledge_deficiency' only supports 'mimic_cxr', got {dataset_name!r}")
        return f'/root/project/benchmark_data/Knowledge_Deficiency_Hallucination/{answer_type}-ended/{dataset_name}_{answer_type}_pairs.json'
    if hallu_type == 'context_misalignment':
        if dataset_name != 'mimic_cxr' or answer_type != 'close':
            raise ValueError(
                "hallu_type 'context_misalignment' only supports dataset 'mimic_cxr' with answer_type 'close', "
                f"got {dataset_name!r} / {answer_type!r}"
            )
        return '/root/project/benchmark_data/Context_Misalignment_Hallucination/MIMIC-CXR_pairs.json'
    raise ValueError(f'Unknown hallu_type {hallu_type!r}')


def _build_model(args):
    """Instantiate the model wrapper selected by ``args.model`` on ``args.device``.

    Raises:
        NotImplementedError: if ``args.model`` matches no supported wrapper.
    """
    model_name = args.model.lower()
    if 'llava_med_v1.5' in model_name:
        from baselines.Med_LVLMs.llava_med_v15_inference import Llava_med_v15
        # NOTE(review): appended after the import above, so it only affects
        # imports performed later (e.g. inside the wrapper). Confirm intended.
        sys.path.append('/root/project/code/baselines/Med_LVLMs/llava_med_v15')
        return Llava_med_v15(PATH[args.model], args.device)
    if 'exgra' in model_name:
        from baselines.Med_LVLMs.llava_med_inference import Llava_med_v1
        sys.path.append('/root/project/code/baselines/Med_LVLMs/llava_med_v1')
        return Llava_med_v1(PATH[args.model], args.device)
    raise NotImplementedError(
        f'Model {model_name} has not been implemented.'
    )


def main(args):
    """Evaluate the selected model on each requested dataset and save its responses.

    Reads ``args.seed``, ``args.model``, ``args.device``, ``args.validate_datasets``,
    ``args.hallu_type`` and ``args.answer_type``. It mutates *args*:
    ``validate_datasets`` becomes the expanded list, and ``image_folder``,
    ``save_path`` and ``question_file`` are set per dataset before calling
    ``evaluate_batch``. Results go to
    ``/root/project/results/<hallu_type>/<model>_<answer_type>_<dataset>_wo.jsonl``.

    Raises:
        ValueError: for an unsupported dataset / hallu_type / answer_type combination.
        NotImplementedError: if ``args.model`` is not a supported model.
    """
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    datasets = _expand_datasets(args.validate_datasets)
    # Resolve every question file before loading the model, so a bad
    # combination fails fast instead of after an expensive model load.
    question_files = {
        name: _question_file(name, args.hallu_type, args.answer_type) for name in datasets
    }
    args.validate_datasets = datasets

    model = _build_model(args)

    results_dir = f"/root/project/results/{args.hallu_type}"
    os.makedirs(results_dir, exist_ok=True)

    for dataset_name in datasets:
        args.image_folder = IMAGE_PATH[dataset_name]
        args.save_path = f'{results_dir}/{args.model}_{args.answer_type}_{dataset_name}_wo.jsonl'
        args.question_file = question_files[dataset_name]
        if dataset_name in _HARVARD_PMC_DATASETS:
            loader = process_data_harvard_pmc
        else:
            loader = process_data_medheval
        validate_data = loader(args.question_file, args.image_folder, args.answer_type)
        evaluate_batch(model, args, validate_data)
        
   
if __name__ == "__main__":
    args = get_args()
    # wandb.init(project=args.project_name, entity=args.entity_name, config=args)
    
    args.validate_datasets = 'harvard' # harvard pmc
    args.model = 'llava_med_v1.5'
    args.device = '0'
    args.hallu_type = 'knowledge_deficiency' # visual_misinterpretation  knowledge_deficiency context_misalignment
    args.answer_type = 'open'

    main(args)
    # seed_all(5555)