"""
This script is used to evaluate MLLMs.
"""
import argparse
import logging
import random
import time
import wandb
import os
from einops import rearrange
import numpy as np

import torch
import json
from utils import *


# Local model locations keyed by the exact --model name (lookups use the
# name as given, i.e. case-sensitively).
PATH = {
    # Checkpoints under /data/huggingface.
    'llava_v1.5_7B': '/data/huggingface/llava-v1.5-7b-hf',
    'llava_v1.5_7B_lht': '/data/huggingface/llava-v1.5-7b-liuhaotian',
    'mplug_owl2_7B': '/data/huggingface/mplug-owl2-llama2-7b',
    'sharegpt4v_7B': '/data/huggingface/sharegpt4v-7b',
    'instructblip_7B': '/data/huggingface/instructblip-vicuna-7b-old',
    'instructblip_7B_new': '/data/huggingface/instructblip-vicuna-7b',
    'qwen_vl_7B': '/data/huggingface/qwen-vl',
    'qwen2_vl_7B': '/data/huggingface/qwen2-vl',
    'qwen_vl_7B_chat': '/data/huggingface/qwen-vl-chat',
    'cogvlm_17B': '/data/huggingface/cogvlm-base-224-hf',
    # Locally fine-tuned checkpoint.
    'llava_v1.5_7B_hacl': '/data/multimodal_alignment/mPLUG-HalOwl-main/hacl/checkpoints/llava_sft',
    # These entries point at configuration files (.yaml / .py), not directories.
    'minigpt4_vicuna_7B': '/data/multimodal_alignment/mm_iti/models/minigpt4/minigpt4_eval.yaml',
    'minigptv2_llama2_7B': '/data/multimodal_alignment/mm_iti/models/minigpt4/eval_configs/minigptv2_eval.yaml',
    'minigpt4_llama2_7B': '/data/multimodal_alignment/mm_iti/models/minigpt4/eval_configs/minigpt4_llama2_eval.yaml',
    'shikra_7B': '/data/multimodal_alignment/mm_iti/models/shikra_model/shikra_config.py',
}

def seed_all(seed=8888):
    """Seed every random number generator used by this script.

    Seeds Python's ``random`` module, NumPy's global RNG, and torch on the CPU
    and on all CUDA devices, so that repeated runs draw the same numbers.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Safe without a GPU: torch silently ignores this when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    
def evaluate(func):
    """Decorator that calls *func* with the given arguments and returns its result.

    It is currently a pure pass-through. ``functools.wraps`` keeps the wrapped
    function's ``__name__``, ``__doc__`` and ``__wrapped__``, so logs and
    tracebacks still show the real function.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return wrapper

@evaluate
def evaluate_model(model, args, data):
    """Run ``model.batch_evaluate(args, data)`` and return its result.

    Args:
        model: model wrapper exposing ``batch_evaluate(args, data)``.
        args: parsed CLI namespace forwarded to the model. Callers set
            ``args.save_path`` before each call; presumably the model writes
            its outputs there (verify against the model wrappers).
        data: validation samples produced by one of the ``process_data_*`` loaders.

    Returns:
        Whatever ``batch_evaluate`` returns.
    """
    return model.batch_evaluate(args, data)

@evaluate
def evaluate_model_with_intervention(model, args, data, interventions=None, intervention_fn=None):
    """Run ``model.batch_evaluate_with_intervention`` on *data* and return its result.

    Args:
        model: model wrapper exposing ``batch_evaluate_with_intervention``.
        args: parsed CLI namespace forwarded to the model.
        data: validation samples.
        interventions: dict forwarded to the model. ``None`` (the default)
            means an empty dict. A fresh dict is created on every call, so
            mutations made during one call cannot leak into the next; a shared
            ``{}`` default would allow that.
        intervention_fn: optional callable, forwarded unchanged.

    Returns:
        Whatever ``batch_evaluate_with_intervention`` returns.
    """
    if interventions is None:
        interventions = {}
    return model.batch_evaluate_with_intervention(args, data, interventions, intervention_fn)

def _str2bool(value):
    """Convert a command-line string such as 'True', 'false', '1' or 'no' to a bool.

    ``argparse`` with ``type=bool`` is a trap: ``bool('False')`` is True, so
    every non-empty value would turn the flag on.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean word.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args(argv=None):
    """Parse command-line options and configure file logging.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``; passing a list allows programmatic use and tests.

    Returns:
        argparse.Namespace with the parsed options. Each option is also logged
        at INFO level to ``--log_file``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument('--model', type=str, default='llava_v1.5_7B', help="specifies the model to be evaluated.")
    parser.add_argument('--openai', type=str, default=None, help="specifies the api_key")
    # main() reads args.organization together with args.openai for --model gpt4v.
    parser.add_argument('--organization', type=str, default=None, help="specifies the OpenAI organization (used with --openai for gpt4v).")
    parser.add_argument('--tokenizer', type=str, default=None, help='specifies the tokenizer to be used.')

    parser.add_argument('--probe_dataset', type=str, default='spa_vl', help='feature bank for training probes')
    parser.add_argument('--validate_dataset', type=str, default='MLLMGuard', help="specifies the path to the data")
    parser.add_argument('--log_file', type=str, default='logs/default.log', help='specifies the name of the log file')
    parser.add_argument('--save_path', type=str, default='results/toxicity_en.jsonl', help='specifies the path to save the results.')

    # _str2bool instead of bool, so that "--verbose False" really disables it.
    parser.add_argument('--verbose', type=_str2bool, default=True, help='specifies whether to display verbose outputs.')

    parser.add_argument('--project_name', type=str, default='mllmguard', help='specifies the project name in wandb.')
    parser.add_argument('--entity_name', type=str, default='entity_name', help='specifies the entity name in wandb.')
    parser.add_argument('--seed', type=int, default=42, help='seed')

    parser.add_argument('--device', type=str, default='0')
    parser.add_argument('--categories', type=str, default='all')
    parser.add_argument('--subfix', type=str, default='')

    args = parser.parse_args(argv)

    # logging.basicConfig opens the file immediately and raises
    # FileNotFoundError if its directory does not exist yet.
    log_dir = os.path.dirname(args.log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    logging.basicConfig(
        filename=args.log_file,
        filemode="w+",
        format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s",
        datefmt="%Y-%m-%d-%H-%M",
        level=logging.INFO,
    )

    for arg in vars(args):
        logging.info(f"{arg}: {getattr(args, arg)}")
    return args

# Root directories for benchmark data and evaluation outputs.
_MM_ITI_ROOT = '/data/multimodal_alignment/mm_iti'
_DATA_ROOT = f'{_MM_ITI_ROOT}/data'
_RESULTS_ROOT = f'{_MM_ITI_ROOT}/results'

# Categories evaluated for each category-based benchmark when --categories is 'all'.
_DEFAULT_CATEGORIES = {
    'MLLMGuard': ['privacy', 'bias', 'toxicity', 'hallucination', 'legality'],
    'MM-SafetyBench': [
        "01-Illegal_Activitiy",
        "02-HateSpeech",
        "03-Malware_Generation",
        "04-Physical_Harm",
        "05-EconomicHarm",
        "06-Fraud",
        "07-Sex",
        "08-Political_Lobbying",
        "09-Privacy_Violence",
        "10-Legal_Opinion",
        "11-Financial_Advice",
        "12-Health_Consultation",
        "13-Gov_Decision",
    ],
    'POPE': ['adversarial', 'popular', 'random'],
    'AMBER': ['gen', 'dis-attribute_sub', 'dis-existence_sub', 'dis-relation_sub'],
    'MME': ['color', 'count', 'existence', 'position', 'commonsense_reasoning'],
    'MME_general': [
        'artwork', 'celebrity', 'code_reasoning', 'commonsense_reasoning',
        'landmark', 'numerical_calculation', 'OCR', 'posters', 'scene',
        'text_translation',
    ],
}


def _resolve_categories(categories, default):
    """Turn the --categories value into a list: 'all' -> *default*, else split on whitespace."""
    if categories == 'all':
        return list(default)
    # split() rather than split(' '), so repeated spaces don't yield '' categories.
    return categories.split()


def _validation_jobs(args):
    """Plan the evaluation runs for ``args.validate_dataset``.

    Returns:
        A list of ``(save_name, load_data)`` pairs. ``save_name`` is the output
        path without the ``.jsonl`` suffix; ``load_data()`` loads that run's
        samples. Loading is deferred so only one category's data is loaded at
        a time. For category-based benchmarks, ``args.categories`` is replaced
        by the resolved list.

    Raises:
        ValueError: if ``args.validate_dataset`` is not a supported benchmark.
    """
    from functools import partial  # partial binds loop values eagerly (no late-binding bug)

    name = args.validate_dataset
    results_dir = f'{_RESULTS_ROOT}/{name}'
    single_run = f'{results_dir}/{args.model}'

    if name == 'MLLMGuard':
        args.categories = _resolve_categories(args.categories, _DEFAULT_CATEGORIES[name])
        return [(f'{results_dir}/{args.model}_{c}',
                 partial(process_data_mllmguard, f'{_DATA_ROOT}/MLLMGuard/{c}'))
                for c in args.categories]
    if name == 'MM-HarmfulBench':
        return [(single_run,
                 partial(process_data_mmharmfulbench, f'{_DATA_ROOT}/MM-HarmfulBench'))]
    if name == 'MM-SafetyBench':
        args.categories = _resolve_categories(args.categories, _DEFAULT_CATEGORIES[name])
        # Output names keep only the part after the last '-', e.g. '06-Fraud' -> 'Fraud'.
        return [(f"{results_dir}/{args.model}_{c.split('-')[-1]}",
                 partial(process_data_mmsafetybench, f'{_DATA_ROOT}/MM-SafetyBench', c))
                for c in args.categories]
    if name == 'SafeBench':
        return [(single_run, partial(process_data_safebench, f'{_DATA_ROOT}/SafeBench'))]
    if 'POPE' in name:
        # The subset is the text after the last '_' (e.g. a name like 'POPE_coco' -> 'coco');
        # its results go under results/POPE/<subset>/ rather than results/<name>/.
        subset = name.split('_')[-1]
        args.categories = _resolve_categories(args.categories, _DEFAULT_CATEGORIES['POPE'])
        return [(f'{_RESULTS_ROOT}/POPE/{subset}/{args.model}_{c}',
                 partial(process_data_pope, f'{_DATA_ROOT}/POPE', subset, c))
                for c in args.categories]
    if name == 'CHAIR':
        return [(single_run, partial(process_data_chair, f'{_DATA_ROOT}/CHAIR'))]
    if name == 'SEED-Bench':
        return [(single_run, partial(process_data_seed_bench, f'{_DATA_ROOT}/SEED-Bench'))]
    if name == 'MMHalBench':
        return [(single_run, partial(process_data_mmhalbench, f'{_DATA_ROOT}/MMHal-Bench'))]
    if name == 'AMBER':
        args.categories = _resolve_categories(args.categories, _DEFAULT_CATEGORIES[name])
        return [(f'{results_dir}/{args.model}_{c}',
                 partial(process_data_amber, f'{_DATA_ROOT}/AMBER', c))
                for c in args.categories]
    if name in ('MME', 'MME_general'):
        # Both variants read from the same MME data directory.
        args.categories = _resolve_categories(args.categories, _DEFAULT_CATEGORIES[name])
        return [(f'{results_dir}/{args.model}_{c}',
                 partial(process_data_mme, f'{_DATA_ROOT}/MME', c))
                for c in args.categories]
    raise ValueError(f'Unknown validate_dataset: {name!r}')


def _model_path(args):
    """Return ``PATH[args.model]``, raising a KeyError that names the missing model."""
    try:
        return PATH[args.model]
    except KeyError:
        raise KeyError(f"No checkpoint path registered in PATH for model '{args.model}'.") from None


def _build_model(args, model_name):
    """Instantiate the model wrapper selected by ``--model``.

    API-backed models (gpt4v, yi_plus, gemini) are built from the CLI
    credentials or model name; local models are loaded from ``PATH[args.model]``.
    The substring checks are order-sensitive: 'qwen2_vl' before 'qwen_vl',
    'mplug_owl2' before 'mplug_owl', 'minigptv2' before 'minigpt', and
    'yi_plus' before 'yi'.

    Raises:
        ValueError: gpt4v was selected without --openai and --organization.
        KeyError: a local model has no entry in PATH.
        NotImplementedError: the model is not supported by this script.
    """
    if model_name == 'gpt4v':
        from apis.gpt4v import GPT4V
        organization = getattr(args, 'organization', None)
        if not (args.openai and organization):
            raise ValueError('OpenAI API Key or organization is not specified.')
        return GPT4V(args.openai, organization)
    if 'yi_plus' in model_name:
        from apis.yi_vl_plus import Yi_VL
        return Yi_VL(args.model)
    if 'gemini' in model_name:
        from apis.geminipro import Gemini
        return Gemini(args.model)

    if 'llava' in model_name and ('lht' in model_name or 'hacl' in model_name):
        from models.llava_inference_lht import Llava_lht
        return Llava_lht(_model_path(args))
    if 'llava' in model_name:
        from models.llava_inference import Llava
        return Llava(_model_path(args))
    if 'shikra' in model_name:
        from models.shikra_inference import Shikra
        return Shikra(_model_path(args))
    if 'qwen2_vl' in model_name:
        from models.qwen2 import Qwen2VL
        return Qwen2VL(_model_path(args))
    if 'qwen_vl' in model_name and 'chat' in model_name:
        from models.qwen import QwenVL_Chat
        return QwenVL_Chat(_model_path(args))
    if 'qwen_vl' in model_name:
        from models.qwen import QwenVL
        return QwenVL(_model_path(args))
    if 'qwen' in model_name:
        from models.qwen import Qwen
        return Qwen(_model_path(args))
    if 'cogvlm' in model_name:
        from models.cogvlm import CogVLM
        return CogVLM(_model_path(args), args.tokenizer)
    if 'yi' in model_name:
        from models.yi import YIVL
        return YIVL(_model_path(args))
    if 'deepseek' in model_name:
        from models.deepseek import DeepSeek
        return DeepSeek(_model_path(args))
    if 'mplug_owl2' in model_name:
        from models.mplug import mPLUG_Owl2
        return mPLUG_Owl2(_model_path(args))
    if 'mplug_owl' in model_name:
        from models.mplug import mPLUG_Owl
        return mPLUG_Owl(_model_path(args))
    # model_name is lowercased, so these keys must be lowercase to ever match.
    if 'seed_llama_14b' in model_name:
        from models.seed import SeedLLaMA14B
        return SeedLLaMA14B(_model_path(args))
    if 'seed_llama_8b' in model_name:
        from models.seed import SeedLLaMA8B
        return SeedLLaMA8B(_model_path(args))
    if 'minigptv2' in model_name:
        from models.minigptv2 import MiniGPTV2
        return MiniGPTV2(_model_path(args), args.device, model_name)
    if 'minigpt' in model_name:
        from models.minigpt4_inference import MiniGPT4
        return MiniGPT4(_model_path(args), args.device, model_name)
    if 'sharegpt' in model_name or 'xcomposer' in model_name:
        # TODO(review): these constructors take the validation data as a second
        # argument, but the model is built once before any data is loaded.
        # Decide how they should receive data before enabling them here.
        raise NotImplementedError(
            f'Model {model_name} expects validation data at construction time, '
            'which main() does not provide.'
        )
    if 'instructblip' in model_name:
        from models.instructblip_inference import InstructBlip
        return InstructBlip(_model_path(args))
    raise NotImplementedError(
        f'Model {model_name} has not been implemented.'
    )


def _run_evaluation(model, model_name, args, data):
    """Evaluate *model* on *data* for the run whose output path is ``args.save_path``."""
    if model_name == 'gpt4v':
        # GPT4V.batch_evaluate takes (data, args): the reverse of the local wrappers.
        start_time = time.time()
        model.batch_evaluate(data, args)
        end_time = time.time()
        logging.info(f'Result has been saved to {args.save_path}. Time used: {end_time - start_time}.')
    else:
        evaluate_model(model, args, data)


def main(args):
    """Evaluate ``args.model`` on ``args.validate_dataset``.

    Steps:
        1. Seed the torch and NumPy RNGs.
        2. Plan the benchmark runs. This fails fast on an unknown dataset,
           before any model is loaded.
        3. Build the model once.
        4. Evaluate it on each run, setting ``args.save_path`` per run.

    ``--probe_dataset`` is not used by this evaluation path.

    Raises:
        ValueError: unknown ``--validate_dataset`` or missing gpt4v credentials.
        KeyError: the selected local model has no entry in PATH.
        NotImplementedError: the selected model is not supported.
    """
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    jobs = _validation_jobs(args)
    model_name = args.model.lower()
    model = _build_model(args, model_name)

    for save_name, load_data in jobs:
        args.save_path = f'{save_name}.jsonl'
        # Create the run's actual output directory; for POPE it differs from results/<validate_dataset>.
        os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
        _run_evaluation(model, model_name, args, load_data())
         
if __name__ == "__main__":
    args = get_args()
    # To log the run to Weights & Biases, enable:
    # wandb.init(project=args.project_name, entity=args.entity_name, config=args)

    # Honour --model / --validate_dataset / --categories / --device as given on the
    # command line; they are no longer overwritten with hard-coded values here.
    # CUDA_VISIBLE_DEVICES must be set before CUDA is first used in this process.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device
    main(args)