"""
This script is used to evaluate MLLMs.
"""
import argparse
import logging
import random
import time
import wandb
import os
from einops import rearrange
import numpy as np

import torch
import json
from utils import *


# Local checkpoint directories, keyed by the value passed to ``--model``.
# ``main`` looks up ``PATH[args.model]`` when it constructs a local model, so a
# model name without an entry here fails with ``KeyError``.
PATH = {
    'llava_v1.5_7B': '/data/huggingface/llava-v1.5-7b-hf', 
    'mplug_owl2_7B': '/data/huggingface/mplug-owl2-llama2-7b',  
    'sharegpt4v_7B': '/data/huggingface/sharegpt4v-7b', 
}

def seed_all(seed = 8888):
    """Seed every random number generator this script draws from.

    Covers Python's ``random``, NumPy's global generator (used for the
    train/validation split in ``main``), and PyTorch on CPU and on all CUDA
    devices.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Silently ignored by PyTorch when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    
def evaluate(func):
    """Decorator that forwards every call to ``func`` and returns its result.

    The wrapper adds no behaviour of its own; it is a hook point for
    instrumentation around evaluation entry points. ``functools.wraps`` keeps
    the wrapped function's ``__name__`` and ``__doc__`` intact so logs and
    tracebacks still show the real function.

    Args:
        func: The callable to wrap.

    Returns:
        A callable with the same signature and return value as ``func``.
    """
    # Function-scope import keeps this block self-contained.
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return wrapper

@evaluate
def evaluate_model(model, args, data):
    """Run the model's plain batch evaluation on ``data``.

    Delegates to ``model.batch_evaluate(args, data)``; its return value is
    discarded, so this function always returns ``None``.
    """
    run_batch = model.batch_evaluate
    run_batch(args, data)

@evaluate
def evaluate_model_with_intervention(model, args, data, interventions=None, intervention_fn=None):
    """Run the model's batch evaluation with activation interventions applied.

    Delegates to ``model.batch_evaluate_with_intervention``; its return value
    is discarded, so this function always returns ``None``.

    Args:
        model: Object exposing ``batch_evaluate_with_intervention``.
        args: Parsed command-line namespace forwarded to the model.
        data: Validation data forwarded to the model.
        interventions: Mapping describing the interventions, forwarded
            unchanged. ``None`` (the default) means an empty mapping.
        intervention_fn: Callable forwarded to the model to apply the
            interventions, or ``None``.
    """
    # A fresh dict per call avoids sharing one mutable default across calls.
    if interventions is None:
        interventions = {}
    model.batch_evaluate_with_intervention(args, data, interventions, intervention_fn)

def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this converter parses the text instead.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def get_args(argv=None):
    """Parse command-line options and configure file logging.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            ``argparse`` read ``sys.argv[1:]``, as before.

    Returns:
        The parsed ``argparse.Namespace``. Every option is also written to the
        log file named by ``--log_file``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument('--model', type=str, default='llava_v1.5_7B', help="specifies the model to be evaluated.")
    parser.add_argument('--openai', type=str, default=None, help="specifies the api_key")
    # ``main`` reads ``args.organization`` for the GPT-4V path; declare it so
    # the attribute always exists.
    parser.add_argument('--organization', type=str, default=None, help='specifies the OpenAI organization (used together with --openai).')
    parser.add_argument('--tokenizer', type=str, default=None, help='specifies the tokenizer to be used.')

    parser.add_argument('--probe_dataset', type=str, default='spa_vl', help='feature bank for training probes')
    parser.add_argument('--validate_dataset', type=str, default='MLLMGuard', help="specifies the path to the data")
    parser.add_argument('--log_file', type=str, default='logs/default.log', help='specifies the name of the log file')
    parser.add_argument('--save_path', type=str, default='results/toxicity_en.jsonl', help='specifies the path to save the results.')

    parser.add_argument('--verbose', type=_str2bool, default=True, help='specifies whether to display verbose outputs.')

    parser.add_argument('--probe_mode', type=str, default=None,
                        help="';'-separated feature tags used to train probes: 'neg;pos' or 'neg;pos;sup'.")
    parser.add_argument('--entity_name', type=str, default='entity_name', help='specifies the entity name in wandb.')

    parser.add_argument('--num_heads', type=int, default=48, help='K, number of top heads to intervene on')
    parser.add_argument('--alpha', type=int, default=15, help='alpha, intervention strength')
    # The default is fractional, so the option must be parsed as a float.
    parser.add_argument('--gamma', type=float, default=0.1,
                        help='gamma, passed to the activation loader when a supplementary feature set is given')
    parser.add_argument('--val_ratio', type=float, help='ratio of validation set size to development set size', default=0.2)
    parser.add_argument('--use_center_of_mass', action='store_true', help='use center of mass direction', default=False)
    parser.add_argument('--use_random_dir', action='store_true', help='use random direction', default=False)
    parser.add_argument('--seed', type=int, default=42, help='seed')

    parser.add_argument('--device', type=str, default='0')
    parser.add_argument('--categories', type=str, default='all')
    parser.add_argument('--subfix', type=str, default='all')

    args = parser.parse_args(argv)

    # ``logging.basicConfig`` does not create missing directories, so the
    # default ``logs/default.log`` would fail on a fresh checkout.
    log_dir = os.path.dirname(args.log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    logging.basicConfig(
        filename=args.log_file,
        filemode="w+",
        format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s",
        datefmt="%Y-%m-%d-%H-%M",
        level=logging.INFO,
    )

    for arg in vars(args):
        logging.info(f"{arg}: {getattr(args, arg)}")
    return args

# Head-wise feature files produced for the POPE training split; the two slots
# are the model name and a feature tag.
_FEATURE_TEMPLATE = '/data/multimodal_alignment/mm_iti/features/{}_POPE_train_{}_head_wise.npy'
# Variants substituted for the literal 'neg' inside a feature tag.
_NEG_VARIANTS = ['miss', 'blur', 'error']


def _checkpoint_path(model_key):
    """Return the checkpoint directory registered in ``PATH`` for ``model_key``."""
    try:
        return PATH[model_key]
    except KeyError:
        raise KeyError(
            f'No checkpoint path registered for model {model_key!r}; add it to PATH.'
        ) from None


def _load_local_model(args, model_name):
    """Instantiate the local MLLM wrapper selected by the lower-cased ``model_name``.

    Branches are tested in the same order as the original dispatch chain.

    Raises:
        NotImplementedError: For API-served models, for wrappers that need
            validation data at construction time, and for unknown names.
        KeyError: If ``args.model`` has no entry in ``PATH``.
    """
    if model_name == 'gpt4v' or 'yi_plus' in model_name or 'gemini' in model_name:
        # These branches used to read ``validate_data`` before it was assigned
        # and never produced the ``model`` object the pipeline below needs.
        raise NotImplementedError(
            f'Model {model_name} is served through a remote API; this script '
            'needs a local model whose head outputs can be intervened on.'
        )
    if 'llava' in model_name:
        from models.llava_inference import Llava
        return Llava(_checkpoint_path(args.model))
    if 'qwen_vl' in model_name:
        from models.qwen import QwenVL
        return QwenVL(_checkpoint_path(args.model))
    if 'qwen' in model_name:
        from models.qwen import Qwen
        return Qwen(_checkpoint_path(args.model))
    if 'cogvlm' in model_name:
        from models.cogvlm import CogVLM
        return CogVLM(_checkpoint_path(args.model), args.tokenizer)
    if 'yi' in model_name:
        from models.yi import YIVL
        return YIVL(_checkpoint_path(args.model))
    if 'deepseek' in model_name:
        from models.deepseek import DeepSeek
        return DeepSeek(_checkpoint_path(args.model))
    if 'mplug_owl2' in model_name:
        from models.mplug import mPLUG_Owl2
        return mPLUG_Owl2(_checkpoint_path(args.model))
    if 'mplug_owl' in model_name:
        from models.mplug import mPLUG_Owl
        return mPLUG_Owl(_checkpoint_path(args.model))
    # ``model_name`` is lower-cased, so the tags must be lower-case as well;
    # the former 'seed_llama_14B' / 'seed_llama_8B' checks could never match.
    if 'seed_llama_14b' in model_name:
        from models.seed import SeedLLaMA14B
        return SeedLLaMA14B(_checkpoint_path(args.model))
    if 'seed_llama_8b' in model_name:
        from models.seed import SeedLLaMA8B
        return SeedLLaMA8B(_checkpoint_path(args.model))
    if any(tag in model_name for tag in ('minigptv2', 'sharegpt', 'xcomposer')):
        # NOTE(review): these wrappers were constructed with ``validate_data``,
        # which is only loaded per benchmark at the end of ``main`` (the old
        # code raised UnboundLocalError here). Decide which data they should
        # receive before re-enabling them.
        raise NotImplementedError(
            f'Model {model_name} needs validation data at construction time, '
            'but that data is only loaded after the probes are trained.'
        )
    raise NotImplementedError(
        f'Model {model_name} has not been implemented.'
    )


def _probe_activation_sources(args):
    """Translate ``--probe_mode`` into keyword arguments for the activation loader.

    ``'neg;pos'`` yields ``pos_path`` and ``neg_paths``; ``'neg;pos;sup'``
    additionally yields ``sup_paths`` and ``gamma``. A tag containing the
    literal ``'neg'`` is expanded into one path per entry of ``_NEG_VARIANTS``
    (for the negative tag in two-part mode and the supplementary tag in
    three-part mode).

    Raises:
        ValueError: If ``--probe_mode`` is missing or has an unsupported
            number of parts.
    """
    if not args.probe_mode:
        raise ValueError("--probe_mode is required: pass 'neg;pos' or 'neg;pos;sup'.")
    parts = args.probe_mode.split(';')

    if len(parts) == 2:  # negative and positive samples only
        neg, pos = parts
        if 'neg' in neg:
            neg_paths = [_FEATURE_TEMPLATE.format(args.model, neg.replace('neg', n)) for n in _NEG_VARIANTS]
        else:
            # Kept as a single string, exactly as before — presumably the
            # loader accepts either form; verify against utils.
            neg_paths = _FEATURE_TEMPLATE.format(args.model, neg)
        return {
            'pos_path': _FEATURE_TEMPLATE.format(args.model, pos),
            'neg_paths': neg_paths,
        }

    if len(parts) == 3:  # negative, positive and supplementary samples
        neg, pos, sup = parts
        if 'neg' in sup:
            sup_paths = [_FEATURE_TEMPLATE.format(args.model, sup.replace('neg', s)) for s in _NEG_VARIANTS]
        else:
            sup_paths = [_FEATURE_TEMPLATE.format(args.model, sup)]
        return {
            'pos_path': _FEATURE_TEMPLATE.format(args.model, pos),
            'neg_paths': _FEATURE_TEMPLATE.format(args.model, neg),
            'sup_paths': sup_paths,
            'gamma': args.gamma,
        }

    raise ValueError(
        f"--probe_mode must have 2 or 3 ';'-separated parts, got {args.probe_mode!r}."
    )


def _resolve_categories(args, defaults):
    """Replace ``args.categories`` with a list and return it.

    ``'all'`` selects ``defaults``; any other value is split on single spaces.
    """
    if args.categories == 'all':
        args.categories = list(defaults)
    else:
        args.categories = args.categories.split(' ')
    return args.categories


def main(args):
    """Train head-wise probes, build interventions and evaluate a local MLLM.

    Steps:
      1. Seed NumPy and PyTorch from ``args.seed``.
      2. Validate ``--probe_mode`` / ``--probe_dataset`` and load the model.
      3. Load head-wise activations, split them into train/validation sets,
         train the probes (and optionally centre-of-mass directions), and
         persist both under ``/data/multimodal_alignment/mm_iti/probes``.
      4. Select the top heads on the validation split and build the
         intervention dictionary.
      5. Run ``evaluate_model_with_intervention`` on the benchmark named by
         ``--validate_dataset``, setting ``args.save_path`` for each run.

    Args:
        args: Namespace produced by ``get_args``. ``args.categories`` and
            ``args.save_path`` are overwritten during evaluation.

    Raises:
        ValueError: For an invalid ``--probe_mode`` or ``--probe_dataset``.
        NotImplementedError: For models this pipeline cannot load.
    """
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    model_name = args.model.lower()

    # Check the cheap arguments first so a bad flag fails before the
    # (expensive) model load instead of after it.
    activation_sources = _probe_activation_sources(args)
    if args.probe_dataset == "spa_vl":
        # The SPA-VL meta file used to be loaded here, but its contents were
        # never used; only the split range matters.
        split_range = 2
    elif 'POPE' in args.probe_dataset:
        split_range = 12
    else:
        raise ValueError(
            f"Unsupported --probe_dataset {args.probe_dataset!r}; expected 'spa_vl' or a name containing 'POPE'."
        )

    model = _load_local_model(args, model_name)

    # Hard-coded rather than read from the model config
    # (model.model.language_model.config.num_hidden_layers / num_attention_heads).
    num_layers = 32
    num_heads = 32

    # Load activations (positive and negative samples combined).
    head_wise_activations, labels = load_and_conbine_activations(**activation_sources)
    head_wise_activations = rearrange(head_wise_activations, 'b l (h d) -> b l h d', h = num_heads)

    separated_head_wise_activations, separated_labels, _ = get_separated_activations(labels, head_wise_activations, split_range)

    train_idxs = np.arange(len(separated_head_wise_activations))
    train_set_idxs = np.random.choice(train_idxs, size=int(len(train_idxs)*(1-args.val_ratio)), replace=False)
    # Sorted complement of the training indices; same values and order as the
    # former per-element membership scan, without the quadratic cost.
    val_set_idxs = np.setdiff1d(train_idxs, train_set_idxs)

    # Get directions.
    if args.use_center_of_mass:
        com_directions = get_com_directions(num_layers, num_heads, train_set_idxs, val_set_idxs, separated_head_wise_activations, separated_labels)
    else:
        com_directions = None
    top_heads, probes = get_top_heads(train_set_idxs, val_set_idxs, separated_head_wise_activations, separated_labels, num_layers, num_heads, args.seed, args.num_heads, args.use_random_dir)

    probes_file = f'/data/multimodal_alignment/mm_iti/probes/probes_{args.model}_{args.probe_dataset}_{args.probe_mode}.pkl'
    coms_file = f'/data/multimodal_alignment/mm_iti/probes/coms_{args.model}_{args.probe_dataset}_{args.probe_mode}.pkl'
    save_probes(probes, probes_file)
    save_probes(com_directions, coms_file)

    # Read the artefacts back so evaluation uses exactly what was persisted.
    if args.use_center_of_mass:
        com_directions = load_probes(coms_file)
    else:
        com_directions = None
    probes = load_probes(probes_file)

    # Top heads are re-selected on the validation indices of the same data.
    top_accs, top_heads = val_probes_2(args.seed, val_set_idxs, separated_head_wise_activations, separated_labels,
                        probes, num_layers, num_heads, args.num_heads)

    print("Heads intervened: ", sorted(top_heads))

    interventions = get_interventions_dict(model_name, top_heads, probes, 0, num_heads, args.use_center_of_mass, args.use_random_dir, com_directions)

    def lt_modulated_vector_add(head_output, layer_name, start_edit_location='lt'):
        """Shift the selected heads of ``head_output`` along their directions.

        With ``start_edit_location='lt'`` only the last sequence position is
        edited; otherwise every position from that index onward is edited.
        """
        head_output = rearrange(head_output, 'b s (h d) -> b s h d', h=num_heads)
        for head, direction, proj_val_std in interventions[layer_name]:
            # Use the device object itself: ``device.index`` is None on CPU.
            direction_to_add = torch.tensor(direction).to(head_output.device)
            if start_edit_location == 'lt':
                head_output[:, -1, head, :] += args.alpha * proj_val_std * direction_to_add
            else:
                head_output[:, start_edit_location:, head, :] += args.alpha * proj_val_std * direction_to_add
        head_output = rearrange(head_output, 'b s h d -> b s (h d)')
        return head_output

    def run_intervened(validate_data):
        """Evaluate ``validate_data`` with the interventions built above."""
        evaluate_model_with_intervention(model, args, validate_data,
            interventions=interventions,
            intervention_fn=lt_modulated_vector_add)

    # NOTE(review): the result-file suffixes ('_com', '_rand', --subfix) differ
    # between benchmarks; they are reproduced as-is to keep file names stable.
    if args.validate_dataset == 'MLLMGuard':
        for c in _resolve_categories(args, ['privacy', 'bias', 'toxicity', 'hallucination', 'legality']):
            dataset_path = f"/data/multimodal_alignment/mm_iti/data/MLLMGuard/{c}"
            save_name = f"/data/multimodal_alignment/mm_iti/results/{args.validate_dataset}/{args.model}_{c}_{args.num_heads}_{args.alpha}"
            if args.use_center_of_mass:
                save_name += '_com'
            if args.use_random_dir:
                save_name += '_rand'
            save_name += args.subfix
            args.save_path = f'{save_name}.jsonl'
            run_intervened(process_data_mllmguard(dataset_path))

    elif args.validate_dataset == 'MM-HarmfulBench':
        dataset_path = f"/data/multimodal_alignment/mm_iti/data/MM-HarmfulBench"
        save_name = f"/data/multimodal_alignment/mm_iti/results/{args.validate_dataset}/{args.model}_{args.num_heads}_{args.alpha}"
        if args.use_center_of_mass:
            save_name += '_com'
        args.save_path = f'{save_name}.jsonl'
        run_intervened(process_data_mmharmfulbench(dataset_path))

    elif args.validate_dataset == 'MM-SafetyBench':
        safety_categories = ["01-Illegal_Activitiy",
                             "02-HateSpeech",
                             "03-Malware_Generation",
                             "04-Physical_Harm",
                             "05-EconomicHarm",
                             "06-Fraud",
                             "07-Sex",
                             "08-Political_Lobbying",
                             "09-Privacy_Violence",
                             "10-Legal_Opinion",
                             "11-Financial_Advice",
                             "12-Health_Consultation",
                             "13-Gov_Decision"]
        for c in _resolve_categories(args, safety_categories):
            dataset_path = f"/data/multimodal_alignment/mm_iti/data/MM-SafetyBench"
            save_name = f"/data/multimodal_alignment/mm_iti/results/{args.validate_dataset}/{args.model}_{c.split('-')[-1]}_{args.num_heads}_{args.alpha}"
            if args.use_center_of_mass:
                save_name += '_com'
            args.save_path = f'{save_name}.jsonl'
            run_intervened(process_data_mmsafetybench(dataset_path, c))

    elif args.validate_dataset == 'SafeBench':
        dataset_path = f"/data/multimodal_alignment/mm_iti/data/SafeBench"
        save_name = f"/data/multimodal_alignment/mm_iti/results/{args.validate_dataset}/{args.model}_{args.num_heads}_{args.alpha}"
        if args.use_center_of_mass:
            save_name += '_com'
        args.save_path = f'{save_name}.jsonl'
        run_intervened(process_data_safebench(dataset_path))

    elif args.validate_dataset == 'POPE':
        dataset_path = "/data/multimodal_alignment/mm_iti/data/POPE"
        for c in _resolve_categories(args, ['adversarial', 'popular', 'random']):
            save_name = f"/data/multimodal_alignment/mm_iti/results/{args.validate_dataset}/{args.model}_{c}_{args.num_heads}_{args.alpha}"
            save_name += args.subfix
            args.save_path = f'{save_name}.jsonl'
            run_intervened(process_data_pope(dataset_path, c))

    elif args.validate_dataset == 'CHAIR':
        dataset_path = "/data/multimodal_alignment/mm_iti/data/CHAIR"
        save_name = f"/data/multimodal_alignment/mm_iti/results/{args.validate_dataset}/{args.model}_{args.num_heads}_{args.alpha}"
        save_name += args.subfix
        args.save_path = f'{save_name}.jsonl'
        run_intervened(process_data_chair(dataset_path))

    else:
        # Probes have still been trained and saved; only evaluation is skipped.
        logging.warning("Unknown --validate_dataset %r; no evaluation was run.", args.validate_dataset)
         
# Script entry point: parse the CLI, pin the visible GPUs, then run main().
if __name__ == "__main__":
    args = get_args()
    # wandb.init(project=args.project_name, entity=args.entity_name, config=args)
    
    # Manual overrides kept for interactive debugging; uncomment to bypass
    # the command line.
    # args.validate_dataset = 'POPE'
    # args.categories = 'adversarial'
    # args.probe_mode = 'I+Q;C_p2+Q_best'
    # args.probe_dataset = 'POPE_train'
    # args.model = 'llava_v1.5_7B'
    # args.use_center_of_mass = True
    # args.num_heads = 32
    # args.alpha = 3
    # args.subfix = 'tmp'
    # print(args.save_path)
    # args.device = '3'
    
    # args.use_random_dir = True
    
    # Restrict the process to the GPU id(s) given by --device before main()
    # loads the model.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device
    main(args)
    # seed_all(5555)