"""
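Evaluate multimodal LLMs (MLLMs) on safety and hallucination benchmarks,
optionally steering locally loaded models with activation-level interventions.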
"""
import argparse
import logging
import random
import time
import wandb
import os
from einops import rearrange
import numpy as np

import torch
import json
from utils import *


# Local checkpoint directory for each model that is loaded from disk; looked up
# with the exact (case-sensitive) --model value, i.e. PATH[args.model].
PATH = {
    "llava_v1.5_7B": "/data/huggingface/llava-v1.5-7b-hf",
    "mplug_owl2_7B": "/data/huggingface/mplug-owl2-llama2-7b",
    "sharegpt4v_7B": "/data/huggingface/sharegpt4v-7b",
}

def seed_all(seed=8888):
    """Seed every RNG this script draws from, for reproducible runs.

    Seeds Python's ``random``, NumPy's global RNG (used e.g. for the random
    steering direction) and torch on the CPU and on all CUDA devices, matching
    the seeding performed at the start of ``main``.

    Args:
        seed: integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Lazily applied, so this is safe on machines without CUDA.
    torch.cuda.manual_seed_all(seed)
    
def evaluate(func):
    """Decorator hook around an evaluation entry point.

    The wrapper is currently a transparent pass-through: it calls ``func``
    with the same arguments and returns its result. ``functools.wraps``
    preserves the wrapped function's ``__name__``, ``__doc__`` and
    ``__wrapped__`` so logs and introspection still see the real function.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return wrapper

@evaluate
def evaluate_model(model, args, data):
    """Evaluate ``model`` on ``data`` without any intervention.

    Args:
        model: wrapper exposing ``batch_evaluate(args, data)``.
        args: parsed command-line namespace, forwarded unchanged.
        data: validation samples, forwarded unchanged.

    Returns:
        Whatever ``model.batch_evaluate`` returns (previously discarded).
    """
    return model.batch_evaluate(args, data)

@evaluate
def evaluate_model_with_intervention(model, args, data, interventions=None, intervention_fn=None):
    """Evaluate ``model`` on ``data`` while applying activation interventions.

    Args:
        model: wrapper exposing
            ``batch_evaluate_with_intervention(args, data, interventions, intervention_fn)``.
        args: parsed command-line namespace, forwarded unchanged.
        data: validation samples, forwarded unchanged.
        interventions: intervention specification forwarded to the model;
            ``None`` (the default) is replaced by a fresh empty dict.
        intervention_fn: callable that applies an intervention, forwarded as-is.

    Returns:
        Whatever ``model.batch_evaluate_with_intervention`` returns.
    """
    if interventions is None:
        # A new dict per call: a shared mutable default would carry over any
        # mutation made by one evaluation into the next.
        interventions = {}
    return model.batch_evaluate_with_intervention(args, data, interventions, intervention_fn)

def _str2bool(value):
    """Convert a command-line string such as 'true', 'False', '1' or 'no' to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(string)``, which is ``True``
    for every non-empty string (including ``'False'``), so it cannot be used.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if normalized in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def get_args():
    """Parse the command line, set up file logging and log every argument.

    Log records go to ``--log_file`` (truncated on each run); the file's
    parent directory is created if it does not exist yet.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument('--model', type=str, default='llava_v1.5_7B', help="specifies the model to be evaluated.")
    parser.add_argument('--openai', type=str, default=None, help="specifies the OpenAI API key (used by gpt4v).")
    # main() reads args.organization for gpt4v, so the flag must exist.
    parser.add_argument('--organization', type=str, default=None, help="specifies the OpenAI organization (used by gpt4v).")
    parser.add_argument('--tokenizer', type=str, default=None, help='specifies the tokenizer to be used.')

    parser.add_argument('--probe_dataset', type=str, default='spa_vl', help='feature bank for training probes')
    parser.add_argument('--validate_dataset', type=str, default='MLLMGuard',
                        help="benchmark to evaluate on: MLLMGuard, MM-HarmfulBench, MM-SafetyBench, SafeBench, POPE or CHAIR.")
    parser.add_argument('--log_file', type=str, default='logs/default.log', help='specifies the name of the log file')
    parser.add_argument('--save_path', type=str, default='results/toxicity_en.jsonl', help='specifies the path to save the results.')

    # type=bool would turn any non-empty string, even "False", into True.
    parser.add_argument('--verbose', type=_str2bool, default=True,
                        help='specifies whether to display verbose outputs (true/false).')

    parser.add_argument('--project_name', type=str, default='mllmguard', help='specifies the project name in wandb.')
    parser.add_argument('--entity_name', type=str, default='entity_name', help='specifies the entity name in wandb.')

    parser.add_argument('--num_heads', type=int, default=48, help='K, number of top heads to intervene on')
    parser.add_argument('--alpha', type=int, default=15, help='alpha, intervention strength')
    parser.add_argument('--val_ratio', type=float, help='ratio of validation set size to development set size', default=0.2)
    parser.add_argument('--use_center_of_mass', action='store_true', help='use center of mass direction', default=False)
    parser.add_argument('--use_random_dir', action='store_true', help='use random direction', default=False)
    parser.add_argument('--use_all_captions', action='store_true',
                        help='use all captions when computing the image-to-text direction', default=False)
    parser.add_argument('--seed', type=int, default=42, help='seed')

    parser.add_argument('--device', type=str, default='0', help='value assigned to CUDA_VISIBLE_DEVICES')
    parser.add_argument('--categories', type=str, default='all', help="'all' or a space-separated list of categories")
    parser.add_argument('--subfix', type=str, default='all', help='suffix appended to result file names')

    args = parser.parse_args()

    # logging's FileHandler cannot create missing directories; do it up front.
    log_dir = os.path.dirname(args.log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    logging.basicConfig(
        filename=args.log_file,
        filemode="w+",
        format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s",
        datefmt="%Y-%m-%d-%H-%M",
        level=logging.INFO,
    )

    for arg, value in vars(args).items():
        logging.info(f"{arg}: {value}")
    return args

# Root directory holding evaluation data, cached activations and results.
_MM_ITI_ROOT = "/data/multimodal_alignment/mm_iti"

# Category lists used when --categories is 'all'.
_MLLMGUARD_CATEGORIES = ('privacy', 'bias', 'toxicity', 'hallucination', 'legality')
_MM_SAFETYBENCH_CATEGORIES = (
    "01-Illegal_Activitiy",
    "02-HateSpeech",
    "03-Malware_Generation",
    "04-Physical_Harm",
    "05-EconomicHarm",
    "06-Fraud",
    "07-Sex",
    "08-Political_Lobbying",
    "09-Privacy_Violence",
    "10-Legal_Opinion",
    "11-Financial_Advice",
    "12-Health_Consultation",
    "13-Gov_Decision",
)
_POPE_CATEGORIES = ('adversarial', 'popular', 'random')

# Benchmarks evaluated through model.batch_evaluate_with_i2t with the
# activation-derived direction; all others use evaluate_model_with_intervention.
_I2T_DATASETS = ('POPE', 'CHAIR')


def _parse_categories(args, default):
    """Normalise ``args.categories`` to a list, store it back on ``args`` and return it.

    ``'all'`` selects ``default``; any other string is split on whitespace. A
    value that is already a list (e.g. from an earlier call) is kept as-is.
    """
    if args.categories == 'all':
        args.categories = list(default)
    elif isinstance(args.categories, str):
        args.categories = args.categories.split()
    return args.categories


def _result_path(args, stem, com_suffix=True, rand_suffix=False):
    """Build ``<root>/results/<dataset>/<stem>[_com][_rand]<subfix>.jsonl``.

    ``_com`` / ``_rand`` are appended only when the suffix is enabled for the
    benchmark *and* the matching ``--use_center_of_mass`` / ``--use_random_dir``
    flag is set.
    """
    name = f"{_MM_ITI_ROOT}/results/{args.validate_dataset}/{stem}"
    if com_suffix and args.use_center_of_mass:
        name += '_com'
    if rand_suffix and args.use_random_dir:
        name += '_rand'
    return f"{name}{args.subfix}.jsonl"


def _validation_jobs(args):
    """Describe every evaluation run for ``args.validate_dataset``.

    Returns:
        list of ``(save_path, loader, loader_args)``; a job's data is obtained
        with ``loader(*loader_args)``. Loading is deferred so that an unknown
        dataset name is reported before any model weights are loaded.

    Raises:
        ValueError: if ``args.validate_dataset`` is not a supported benchmark.
    """
    dataset = args.validate_dataset
    data_root = f"{_MM_ITI_ROOT}/data"
    model_tag, alpha = args.model, args.alpha

    if dataset == 'MLLMGuard':
        return [
            (_result_path(args, f"{model_tag}_{c}_{alpha}", rand_suffix=True),
             process_data_mllmguard, (f"{data_root}/MLLMGuard/{c}",))
            for c in _parse_categories(args, _MLLMGUARD_CATEGORIES)
        ]
    if dataset == 'MM-HarmfulBench':
        return [(_result_path(args, f"{model_tag}_{alpha}"),
                 process_data_mmharmfulbench, (f"{data_root}/MM-HarmfulBench",))]
    if dataset == 'MM-SafetyBench':
        # File names keep only the part after the numeric prefix,
        # e.g. "02-HateSpeech" -> "HateSpeech".
        return [
            (_result_path(args, f"{model_tag}_{c.split('-')[-1]}_{alpha}"),
             process_data_mmsafetybench, (f"{data_root}/MM-SafetyBench", c))
            for c in _parse_categories(args, _MM_SAFETYBENCH_CATEGORIES)
        ]
    if dataset == 'SafeBench':
        return [(_result_path(args, f"{model_tag}_{alpha}"),
                 process_data_safebench, (f"{data_root}/SafeBench",))]
    if dataset == 'POPE':
        return [
            (_result_path(args, f"{model_tag}_{c}_{alpha}", com_suffix=False),
             process_data_pope, (f"{data_root}/POPE", c))
            for c in _parse_categories(args, _POPE_CATEGORIES)
        ]
    if dataset == 'CHAIR':
        return [(_result_path(args, f"{model_tag}_{alpha}", com_suffix=False),
                 process_data_chair, (f"{data_root}/CHAIR",))]
    raise ValueError(f"Unsupported validate_dataset: {dataset!r}")


def _load_api_model(args, model_name):
    """Return a client for an API-backed model, or ``None`` for a local model.

    Raises:
        ValueError: gpt4v was requested without an API key or organization.
    """
    if model_name == 'gpt4v':
        from apis.gpt4v import GPT4V
        # getattr: ``args`` may come from a parser without --organization.
        organization = getattr(args, 'organization', None)
        if not (args.openai and organization):
            raise ValueError('OpenAI API Key or organization is not specified.')
        return GPT4V(args.openai, organization)
    if 'yi_plus' in model_name:
        from apis.yi_vl_plus import Yi_VL
        # NOTE(review): the model name is what was passed here before, although
        # the error guarding this call spoke of an API key — confirm Yi_VL's signature.
        return Yi_VL(args.model)
    if 'gemini' in model_name:
        from apis.geminipro import Gemini
        return Gemini(args.model)
    return None


def _model_path(args):
    """Return the checkpoint directory configured in ``PATH`` for ``args.model``."""
    try:
        return PATH[args.model]
    except KeyError:
        raise KeyError(
            f"No checkpoint path configured for model {args.model!r}; "
            f"add it to PATH (configured: {sorted(PATH)})."
        ) from None


def _single_job_data(jobs, model_name):
    """Load the data of the only job, for wrappers that take data in their constructor.

    Raises:
        NotImplementedError: if several splits were requested, since it is not
            known which split such a wrapper should be constructed with.
    """
    if len(jobs) != 1:
        raise NotImplementedError(
            f'Model {model_name} receives its evaluation data at construction time; '
            f'run it on a single split (e.g. pass one --categories value).'
        )
    _, loader, loader_args = jobs[0]
    return loader(*loader_args)


def _load_local_model(args, model_name, jobs):
    """Instantiate the local model wrapper matching ``model_name``.

    ``model_name`` is lower-cased, so every substring checked here is lower-case.

    Raises:
        NotImplementedError: no wrapper matches ``model_name``.
    """
    if 'llava' in model_name:
        from models.llava_inference import Llava
        return Llava(_model_path(args))
    if 'qwen_vl' in model_name:
        from models.qwen import QwenVL
        return QwenVL(_model_path(args))
    if 'qwen' in model_name:
        from models.qwen import Qwen
        return Qwen(_model_path(args))
    if 'cogvlm' in model_name:
        from models.cogvlm import CogVLM
        return CogVLM(_model_path(args), args.tokenizer)
    if 'yi' in model_name:
        from models.yi import YIVL
        return YIVL(_model_path(args))
    if 'deepseek' in model_name:
        from models.deepseek import DeepSeek
        return DeepSeek(_model_path(args))
    if 'mplug_owl2' in model_name:
        from models.mplug import mPLUG_Owl2
        return mPLUG_Owl2(_model_path(args))
    if 'mplug_owl' in model_name:
        from models.mplug import mPLUG_Owl
        return mPLUG_Owl(_model_path(args))
    if 'seed_llama_14b' in model_name:
        from models.seed import SeedLLaMA14B
        return SeedLLaMA14B(_model_path(args))
    if 'seed_llama_8b' in model_name:
        from models.seed import SeedLLaMA8B
        return SeedLLaMA8B(_model_path(args))
    if 'minigptv2' in model_name:
        from models.minigptv2 import MiniGPTV2
        return MiniGPTV2(_model_path(args), _single_job_data(jobs, model_name))
    if 'sharegpt' in model_name:
        from models.sharegpt4v import ShareGPT
        return ShareGPT(_model_path(args), _single_job_data(jobs, model_name))
    if 'xcomposer' in model_name:
        from models.xcomposer import Xcomposer
        return Xcomposer(_model_path(args), _single_job_data(jobs, model_name))
    raise NotImplementedError(
        f'Model {model_name} has not been implemented.'
    )


def main(args):
    """Evaluate ``args.model`` on ``args.validate_dataset``.

    API-backed models (``gpt4v``, names containing ``yi_plus`` or ``gemini``)
    are evaluated without intervention. Every other model is loaded from
    ``PATH[args.model]`` and evaluated with an intervention. POPE and CHAIR use
    the direction returned by ``get_com_directions_i2t`` on the cached probe
    activations, or a random direction with ``--use_random_dir``. The safety
    benchmarks go through ``evaluate_model_with_intervention``.

    Each category/split gets its own ``.jsonl`` path, set on
    ``args.save_path`` before it is evaluated. ``args.categories`` is replaced
    by the parsed list.

    Raises:
        ValueError: unknown ``--validate_dataset`` or missing OpenAI credentials.
        NotImplementedError: unknown model, or a data-at-construction model
            combined with more than one split.
        KeyError: no ``PATH`` entry for a local model.
    """
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    jobs = _validation_jobs(args)
    model_name = args.model.lower()

    api_model = _load_api_model(args, model_name)
    if api_model is not None:
        # No local activations to steer: evaluate each split as-is, then stop.
        for save_path, loader, loader_args in jobs:
            args.save_path = save_path
            validate_data = loader(*loader_args)
            start_time = time.time()
            if model_name == 'gpt4v':
                # NOTE(review): (data, args) argument order kept from the previous
                # GPT4V call; it differs from batch_evaluate(args, data) used in
                # evaluate_model — confirm against apis/gpt4v.py.
                api_model.batch_evaluate(validate_data, args)
            else:
                evaluate_model(api_model, args, validate_data)
            end_time = time.time()
            logging.info(f'Result has been saved to {args.save_path}. Time used: {end_time - start_time}.')
        return

    model = _load_local_model(args, model_name, jobs)

    # Direction from activations cached for the probe dataset.
    feature_prefix = f"{_MM_ITI_ROOT}/features/{args.model}_{args.probe_dataset}"
    whole_image_activations = np.load(f"{feature_prefix}_whole_image_activations.npy")
    whole_text_activations = np.load(f"{feature_prefix}_whole_text_activations.npy", allow_pickle=True)
    com_directions = get_com_directions_i2t(whole_image_activations, whole_text_activations, args.use_all_captions)

    if args.use_random_dir:
        # Random-direction baseline. 4096 presumably matches the hidden size of
        # the supported 7B models — TODO confirm for other models.
        com_directions = np.random.normal(size=(4096,))

    for save_path, loader, loader_args in jobs:
        args.save_path = save_path
        validate_data = loader(*loader_args)
        if args.validate_dataset in _I2T_DATASETS:
            intervention = {'alpha': args.alpha, 'direction': com_directions}
            model.batch_evaluate_with_i2t(args, validate_data, intervention)
        else:
            # NOTE(review): `interventions` is not defined in this file; it must be
            # provided by `from utils import *`, otherwise this raises NameError.
            evaluate_model_with_intervention(
                model, args, validate_data,
                interventions=interventions,
                intervention_fn=lt_modulated_vector_add,
            )
    
         
if __name__ == "__main__":
    # get_args() parses the command line and also configures file logging.
    args = get_args()
    # Pin the visible GPU(s). This only takes effect if CUDA has not been
    # initialised yet, so it has to happen before main() builds a model.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device
    print(args)
    main(args)