"""
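This script evaluates MLLMs under inference-time intervention.

It loads cached head-wise activations for a probe dataset and derives per-head
steering directions from them. It then selects the heads to intervene on and
shifts those heads' outputs while running a benchmark (MLLMGuard,
MM-HarmfulBench, MM-SafetyBench, SafeBench, POPE, CHAIR or MMHal-Bench).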
"""
import argparse
import logging
import random
import time
import wandb
import os
from einops import rearrange
import numpy as np

import torch
import json
from utils import *


# Where each supported model lives on disk: a checkpoint directory under
# /data/huggingface, or an eval config file (MiniGPT-4, MiniGPT-v2, Shikra).
# Keys are looked up with the case-preserved value of --model.
PATH = {
    'llava_v1.5_7B': '/data/huggingface/llava-v1.5-7b-hf',
    'llava_v1.5_7B_lht': '/data/huggingface/llava-v1.5-7b-liuhaotian',
    'mplug_owl2_7B': '/data/huggingface/mplug-owl2-llama2-7b',
    'sharegpt4v_7B': '/data/huggingface/sharegpt4v-7b',
    'instructblip_7B': '/data/huggingface/instructblip-vicuna-7b-old',
    'qwen_vl_7B': '/data/huggingface/qwen-vl',
    # Config-file based models: the value is a YAML / Python config, not a weights dir.
    'minigpt4_vicuna_7B': '/data/multimodal_alignment/mm_iti/models/minigpt4/minigpt4_eval.yaml',
    'minigptv2_llama2_7B': '/data/multimodal_alignment/mm_iti/models/minigpt4/eval_configs/minigptv2_eval.yaml',
    'shikra_7B': '/data/multimodal_alignment/mm_iti/models/shikra_model/shikra_config.py',
}

def seed_all(seed=8888):
    """Seed every RNG the pipeline draws from: ``random``, NumPy and torch.

    Covers the same generators that ``main`` seeds (NumPy, torch CPU and all
    CUDA devices) plus Python's ``random`` module. A single call therefore
    makes ``random``, ``np.random`` and ``torch`` draws reproducible.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Safe without CUDA: torch silently ignores the call in that case.
    torch.cuda.manual_seed_all(seed)
    
def evaluate(func):
    """Decorate an evaluation entry point, forwarding its arguments and result.

    The wrapper is currently a pass-through hook. ``functools.wraps`` keeps the
    wrapped function's ``__name__``/``__doc__`` (and exposes ``__wrapped__``),
    so logs and tracebacks still show the real function.

    Args:
        func: the callable to wrap.

    Returns:
        A callable with the same signature that returns ``func``'s result.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper

@evaluate
def evaluate_model(model, args, data):
    """Evaluate ``model`` on ``data`` without any intervention.

    Args:
        model: wrapper object exposing ``batch_evaluate(args, data)``.
        args: parsed CLI namespace, forwarded unchanged to the wrapper.
        data: validation samples, as produced by one of the ``process_data_*``
            helpers (format defined by the model wrapper — not visible here).

    Returns:
        Whatever ``model.batch_evaluate`` returns, so the ``evaluate``
        decorator has a real value to pass through.
    """
    return model.batch_evaluate(args, data)

@evaluate
def evaluate_model_with_intervention(model, args, data, interventions=None, intervention_fn=None):
    """Evaluate ``model`` on ``data`` while steering selected attention heads.

    Args:
        model: wrapper object exposing
            ``batch_evaluate_with_intervention(args, data, interventions, intervention_fn)``.
        args: parsed CLI namespace, forwarded unchanged to the wrapper.
        data: validation samples from one of the ``process_data_*`` helpers.
        interventions: per-layer intervention spec. ``main`` builds it with
            ``get_interventions_dict`` and reads it as
            ``layer_name -> [(head, direction, proj_val_std), ...]``. When
            omitted, a fresh empty dict is used.
        intervention_fn: callable the wrapper applies to head outputs
            (e.g. ``lt_modulated_vector_add`` in ``main``), or ``None``.

    Returns:
        Whatever ``model.batch_evaluate_with_intervention`` returns.
    """
    if interventions is None:
        # A fresh dict per call: a literal `{}` default would be one object
        # shared (and possibly mutated) across every call.
        interventions = {}
    return model.batch_evaluate_with_intervention(args, data, interventions, intervention_fn)

def get_args():
    """Parse command-line options, set up file logging and return the namespace.

    Logging is configured with ``logging.basicConfig`` to write (mode ``w+``)
    to ``--log_file``; the file's parent directory is created first if it
    does not exist. Every parsed option is then logged at INFO level.

    Returns:
        argparse.Namespace with all options below.
    """
    def str2bool(value):
        """Parse a textual boolean ('true'/'false', 'yes'/'no', '1'/'0', ...).

        ``type=bool`` would treat any non-empty string, including "False",
        as True.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument('--model', type=str, default='llava_v1.5_7B', help="specifies the model to be evaluated.")
    parser.add_argument('--openai', type=str, default=None, help="specifies the api_key")
    parser.add_argument('--organization', type=str, default=None, help='specifies the OpenAI organization id.')
    parser.add_argument('--tokenizer', type=str, default=None, help='specifies the tokenizer to be used.')

    parser.add_argument('--probe_dataset', type=str, default='spa_vl', help='feature bank for training probes')
    parser.add_argument('--validate_dataset', type=str, default='MLLMGuard',
                        help='specifies the benchmark to evaluate on (MLLMGuard, MM-HarmfulBench, '
                             'MM-SafetyBench, SafeBench, POPE_<source>, CHAIR, MMHalBench).')
    parser.add_argument('--log_file', type=str, default='logs/default.log', help='specifies the name of the log file')
    parser.add_argument('--save_path', type=str, default='results/toxicity_en.jsonl', help='specifies the path to save the results.')

    parser.add_argument('--verbose', type=str2bool, default=True, help='specifies whether to display verbose outputs.')

    parser.add_argument('--neg_mode', type=str, default=None, help='specifies the negative-sample mode of the probe features (e.g. neg, miss).')
    parser.add_argument('--entity_name', type=str, default='entity_name', help='specifies the entity name in wandb.')

    parser.add_argument('--num_heads', type=int, default=48, help='K, number of top heads to intervene on')
    parser.add_argument('--alpha', type=int, default=15, help='alpha, intervention strength')
    parser.add_argument('--val_ratio', type=float, help='ratio of validation set size to development set size', default=0.2)
    parser.add_argument('--use_center_of_mass', action='store_true', help='use center of mass direction', default=False)
    parser.add_argument('--use_random_dir', action='store_true', help='use random direction', default=False)
    parser.add_argument('--seed', type=int, default=42, help='seed')

    parser.add_argument('--device', type=str, default='0', help='value exported as CUDA_VISIBLE_DEVICES.')
    parser.add_argument('--categories', type=str, default='all', help="space-separated benchmark categories, or 'all'.")
    parser.add_argument('--subfix', type=str, default='all', help='suffix appended to result file names.')

    parser.add_argument('--start_layer', type=int, default=5, help='lower layer bound used when selecting heads to intervene on')
    parser.add_argument('--end_layer', type=int, default=31, help='upper layer bound used when selecting heads to intervene on')

    args = parser.parse_args()

    # FileHandler does not create missing directories (e.g. the default 'logs/').
    log_dir = os.path.dirname(args.log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    logging.basicConfig(
        filename=args.log_file,
        filemode="w+",
        format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s",
        datefmt="%Y-%m-%d-%H-%M",
        level=logging.INFO,
    )

    for name, value in vars(args).items():
        logging.info("%s: %s", name, value)
    return args

def main(args):
    """Steer a local MLLM with head-wise directions and evaluate it on a benchmark.

    Pipeline:
      1. Seed ``random``, NumPy and torch with ``args.seed`` and build the
         model selected by ``args.model`` from ``PATH``.
      2. Load cached head-wise activations/labels for ``args.probe_dataset``
         and group them with ``get_separated_activations``.
      3. Randomly split the groups into train/val (``args.val_ratio``). With
         ``--use_center_of_mass``, compute and save center-of-mass directions.
      4. Select heads with ``sort_direction_len`` and build the intervention
         dict with ``get_interventions_dict``.
      5. For each category of ``args.validate_dataset``, point
         ``args.save_path`` at a result file and run
         ``evaluate_model_with_intervention``.

    ``args.categories`` and ``args.save_path`` are overwritten in place.

    Raises:
        NotImplementedError: for API-based models (gpt4v, yi_plus, gemini) and
            for unknown local models.
        ValueError: for an unsupported ``probe_dataset`` or ``validate_dataset``.
    """
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    model_name = args.model.lower()

    # Head-wise intervention needs local access to attention-head outputs.
    # API-hosted models do not expose these and have no entry in PATH. Also, no
    # validation data is loaded at this point, so fail fast with a clear error.
    if model_name == 'gpt4v' or 'yi_plus' in model_name or 'gemini' in model_name:
        raise NotImplementedError(
            f'Model {model_name} is API-based and does not support head-wise intervention.'
        )

    # Specific patterns must be tested before the more general ones they
    # contain ('qwen_vl' before 'qwen', 'mplug_owl2' before 'mplug_owl',
    # 'minigptv2' before 'minigpt'). model_name is lower-cased, so every
    # pattern must be lower-case too.
    if 'llava' in model_name and 'lht' in model_name:
        from models.llava_inference_lht import Llava_lht
        model = Llava_lht(PATH[args.model])
    elif 'llava' in model_name:
        from models.llava_inference import Llava
        model = Llava(PATH[args.model])
    elif 'shikra' in model_name:
        from models.shikra_inference import Shikra
        model = Shikra(PATH[args.model])
    elif 'qwen_vl' in model_name:
        from models.qwen import QwenVL
        model = QwenVL(PATH[args.model])
    elif 'qwen' in model_name:
        from models.qwen import Qwen
        model = Qwen(PATH[args.model])
    elif 'cogvlm' in model_name:
        from models.cogvlm import CogVLM
        model = CogVLM(PATH[args.model], args.tokenizer)
    elif 'yi' in model_name:
        from models.yi import YIVL
        model = YIVL(PATH[args.model])
    elif 'deepseek' in model_name:
        from models.deepseek import DeepSeek
        model = DeepSeek(PATH[args.model])
    elif 'mplug_owl2' in model_name:
        from models.mplug import mPLUG_Owl2
        model = mPLUG_Owl2(PATH[args.model])
    elif 'mplug_owl' in model_name:
        from models.mplug import mPLUG_Owl
        model = mPLUG_Owl(PATH[args.model])
    elif 'seed_llama_14b' in model_name:
        from models.seed import SeedLLaMA14B
        model = SeedLLaMA14B(PATH[args.model])
    elif 'seed_llama_8b' in model_name:
        from models.seed import SeedLLaMA8B
        model = SeedLLaMA8B(PATH[args.model])
    elif 'minigptv2' in model_name:
        from models.minigptv2 import MiniGPTV2
        model = MiniGPTV2(PATH[args.model], args.device, model_name)
    elif 'minigpt' in model_name:
        from models.minigpt4_inference import MiniGPT4
        model = MiniGPT4(PATH[args.model], args.device, model_name)
    elif 'sharegpt' in model_name:
        # NOTE(review): `validate_data` is only assigned further down (per
        # benchmark category), so this line raises UnboundLocalError.
        # Confirm what ShareGPT expects as its second argument.
        from models.sharegpt4v import ShareGPT
        model = ShareGPT(PATH[args.model], validate_data)
    elif 'xcomposer' in model_name:
        # NOTE(review): same unbound `validate_data` problem as ShareGPT above.
        from models.xcomposer import Xcomposer
        model = Xcomposer(PATH[args.model], validate_data)
    elif 'instructblip' in model_name:
        from models.instructblip_inference import InstructBlip
        model = InstructBlip(PATH[args.model])
    else:
        raise NotImplementedError(
            f'Model {model_name} has not been implemented.'
        )

    # TODO(review): read these from the loaded model's config
    # (e.g. model.model.language_model.config.num_hidden_layers /
    # num_attention_heads) instead of assuming 32 layers x 32 heads.
    num_layers = 32
    num_heads = 32

    # Cached activations: the last axis holds all heads concatenated and is
    # split into (heads, head_dim) here.
    feature_prefix = f"/data/multimodal_alignment/mm_iti/features/{args.model}_{args.probe_dataset}"
    head_wise_activations = np.load(f"{feature_prefix}_head_wise.npy")
    labels = np.load(f"{feature_prefix}_labels.npy")
    head_wise_activations = rearrange(head_wise_activations, 'b l (h d) -> b l h d', h=num_heads)

    # Group size passed to get_separated_activations for each probe dataset
    # (its exact meaning is defined in utils -- TODO confirm).
    if args.probe_dataset == "spa_vl":
        split_range = 2
    elif 'POPE' in args.probe_dataset:
        split_range = 12
    else:
        raise ValueError(f'Unsupported probe_dataset: {args.probe_dataset}')

    separated_head_wise_activations, separated_labels, idxs_to_split_at = get_separated_activations(labels, head_wise_activations, split_range)

    all_idxs = np.arange(len(separated_head_wise_activations))
    train_set_idxs = np.random.choice(all_idxs, size=int(len(all_idxs) * (1 - args.val_ratio)), replace=False)
    # Ascending complement of the training split, without a per-element
    # membership scan over train_set_idxs.
    val_set_idxs = np.setdiff1d(all_idxs, train_set_idxs)

    # Steering directions. They are saved only when actually computed, so a run
    # without --use_center_of_mass cannot overwrite an existing file with None.
    # (An alternative path that trains linear probes via get_top_heads /
    # val_probes is currently disabled, hence probes=None.)
    if args.use_center_of_mass:
        com_directions = get_com_directions(num_layers, num_heads, train_set_idxs, val_set_idxs, separated_head_wise_activations, separated_labels)
        save_probes(com_directions, f'/data/multimodal_alignment/mm_iti/probes/coms_{args.model}_{args.probe_dataset}.pkl')
    else:
        com_directions = None
    probes = None

    # NOTE(review): com_directions is None when --use_center_of_mass is off;
    # confirm sort_direction_len supports that case.
    top_heads = sort_direction_len(com_directions, num_layers, num_heads, args.num_heads, args.start_layer, args.end_layer)

    print("Heads intervened: ", sorted(top_heads))

    interventions = get_interventions_dict(model_name, top_heads, probes, 0, num_heads, args.use_center_of_mass, args.use_random_dir, com_directions)

    def lt_modulated_vector_add(head_output, layer_name, start_edit_location='lt'):
        """Shift the selected heads of one layer's output along their directions.

        The last axis of ``head_output`` is split into ``num_heads`` heads. For
        every ``(head, direction, proj_val_std)`` registered for ``layer_name``,
        that head is shifted by ``args.alpha * proj_val_std * direction``.
        With ``start_edit_location == 'lt'`` only the last position is
        shifted; otherwise every position from ``start_edit_location`` on is.
        The addition is in place on the rearranged tensor (which may be a view
        of the caller's tensor); the re-merged tensor is returned.
        """
        head_output = rearrange(head_output, 'b s (h d) -> b s h d', h=num_heads)
        for head, direction, proj_val_std in interventions[layer_name]:
            # Move to the tensor's own device: `device.index` is None on CPU.
            direction_to_add = torch.tensor(direction).to(head_output.device)
            if start_edit_location == 'lt':
                head_output[:, -1, head, :] += args.alpha * proj_val_std * direction_to_add
            else:
                head_output[:, start_edit_location:, head, :] += args.alpha * proj_val_std * direction_to_add
        head_output = rearrange(head_output, 'b s h d -> b s (h d)')
        return head_output

    results_dir = '/data/multimodal_alignment/mm_iti/results'

    def resolve_categories(default_categories):
        """Expand ``args.categories`` ('all' or space-separated names) in place and return it."""
        if args.categories == 'all':
            args.categories = default_categories
        else:
            args.categories = args.categories.split(' ')
        return args.categories

    def run(save_name, data):
        """Evaluate ``data`` with the interventions, writing to ``<save_name>.jsonl``."""
        args.save_path = f'{save_name}.jsonl'
        evaluate_model_with_intervention(model, args, data,
            interventions=interventions,
            intervention_fn=lt_modulated_vector_add)

    if args.validate_dataset == 'MLLMGuard':
        for c in resolve_categories(['privacy', 'bias', 'toxicity', 'hallucination', 'legality']):
            dataset_path = f"/data/multimodal_alignment/mm_iti/data/MLLMGuard/{c}"
            save_name = f"{results_dir}/{args.validate_dataset}/{args.model}_{c}_{args.num_heads}_{args.alpha}"
            if args.use_center_of_mass:
                save_name += '_com'
            if args.use_random_dir:
                save_name += '_rand'
            save_name += args.subfix
            validate_data = process_data_mllmguard(dataset_path)
            run(save_name, validate_data)

    elif args.validate_dataset == 'MM-HarmfulBench':
        dataset_path = "/data/multimodal_alignment/mm_iti/data/MM-HarmfulBench"
        save_name = f"{results_dir}/{args.validate_dataset}/{args.model}_{args.num_heads}_{args.alpha}"
        if args.use_center_of_mass:
            save_name += '_com'
        validate_data = process_data_mmharmfulbench(dataset_path)
        run(save_name, validate_data)

    elif args.validate_dataset == 'MM-SafetyBench':
        dataset_path = "/data/multimodal_alignment/mm_iti/data/MM-SafetyBench"
        categories = resolve_categories([
            "01-Illegal_Activitiy",  # spelling kept as-is; presumably matches the dataset's name
            "02-HateSpeech",
            "03-Malware_Generation",
            "04-Physical_Harm",
            "05-EconomicHarm",
            "06-Fraud",
            "07-Sex",
            "08-Political_Lobbying",
            "09-Privacy_Violence",
            "10-Legal_Opinion",
            "11-Financial_Advice",
            "12-Health_Consultation",
            "13-Gov_Decision",
        ])
        for c in categories:
            save_name = f"{results_dir}/{args.validate_dataset}/{args.model}_{c.split('-')[-1]}_{args.num_heads}_{args.alpha}"
            if args.use_center_of_mass:
                save_name += '_com'
            validate_data = process_data_mmsafetybench(dataset_path, c)
            run(save_name, validate_data)

    elif args.validate_dataset == 'SafeBench':
        dataset_path = "/data/multimodal_alignment/mm_iti/data/SafeBench"
        save_name = f"{results_dir}/{args.validate_dataset}/{args.model}_{args.num_heads}_{args.alpha}"
        if args.use_center_of_mass:
            save_name += '_com'
        validate_data = process_data_safebench(dataset_path)
        run(save_name, validate_data)

    elif 'POPE' in args.validate_dataset:
        # e.g. 'POPE_coco' -> 'coco'
        pope_source = args.validate_dataset.split('_')[-1]
        dataset_path = "/data/multimodal_alignment/mm_iti/data/POPE"
        for c in resolve_categories(['adversarial', 'popular', 'random']):
            save_name = f"{results_dir}/POPE/{pope_source}/{args.model}_{c}_{args.num_heads}_{args.alpha}"
            save_name += args.subfix
            print(save_name)
            validate_data = process_data_pope(dataset_path, pope_source, c)
            run(save_name, validate_data)

    elif args.validate_dataset == 'CHAIR':
        dataset_path = "/data/multimodal_alignment/mm_iti/data/CHAIR"
        save_name = f"{results_dir}/{args.validate_dataset}/{args.model}_{args.num_heads}_{args.alpha}"
        save_name += args.subfix
        validate_data = process_data_chair(dataset_path)
        run(save_name, validate_data)

    elif args.validate_dataset == 'MMHalBench':
        dataset_path = "/data/multimodal_alignment/mm_iti/data/MMHal-Bench"
        save_name = f"{results_dir}/{args.validate_dataset}/{args.model}_{args.num_heads}_{args.alpha}"
        save_name += args.subfix
        validate_data = process_data_mmhalbench(dataset_path)
        run(save_name, validate_data)

    else:
        raise ValueError(f'Unsupported validate_dataset: {args.validate_dataset}')
         
if __name__ == "__main__":
    # Parse flags first: get_args() also configures file logging.
    args = get_args()
    # Restrict the visible GPU(s) before main() runs; nothing earlier in this
    # file touches torch.cuda.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device
    main(args)