"""
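Evaluate multimodal LLMs (MLLMs) on safety benchmarks (MLLMGuard, MM-HarmfulBench, MM-SafetyBench, SafeBench) while applying probe-gated interventions to selected attention heads.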
"""
import argparse
import logging
import random
import time
import wandb
import os
from einops import rearrange
import numpy as np
import copy

import torch
import json
from utils import *


# Local checkpoint directory for each open-weight model, keyed by the --model
# value. main() looks up PATH[args.model] when building a local model, so a
# model without an entry here fails with KeyError.
PATH = {
    'llava_v1.5_7B': '/data/huggingface/llava-v1.5-7b-hf',
    'mplug_owl2_7B': '/data/huggingface/mplug-owl2-llama2-7b',
    'sharegpt4v_7B': '/data/huggingface/sharegpt4v-7b',
}

# Probe dataset name (one '@'-separated entry of --probe_datasets) -> the
# edit-location key its interventions target. A name containing '&' denotes
# two feature banks merged into one. Each value is presumably the `loc` key
# the intervention hook looks up in special_tokens_location ('lt' = last
# token, 'img' = image-token span?) -- TODO confirm against the model wrappers.
EDIT_LOC_DICT = {
    'spa_vl': 'lt',
    'seed_bench&safebench_zh': 'img',
}

def seed_all(seed=8888):
    """Seed every RNG this script draws from: Python, NumPy and Torch (CPU + CUDA).

    Matches the seeding done at the top of main(), which also seeds NumPy
    (np.random.choice is used for the probe train/val split).

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    
def evaluate(func):
    """Wrap an evaluation entry point and pass its result straight through.

    Currently a transparent hook (a place to add timing/logging later).
    functools.wraps keeps the wrapped function's name, docstring and
    ``__wrapped__`` so logs and introspection still show the real function.

    Args:
        func: the evaluation function to wrap.

    Returns:
        A wrapper that calls ``func(*args, **kwargs)`` and returns its result.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return wrapper

@evaluate
def evaluate_model(model, args, data):
    """Run the model wrapper's batch evaluation on data, without interventions.

    Args:
        model: wrapper object exposing ``batch_evaluate(args, data)``.
        args: parsed CLI namespace; presumably the wrapper writes its results
            to ``args.save_path`` -- confirm in the wrapper implementation.
        data: validation samples produced by one of the process_data_* loaders.

    Returns:
        Whatever ``model.batch_evaluate`` returns.
    """
    return model.batch_evaluate(args, data)

@evaluate
def evaluate_model_with_intervention(model, args, data, interventions=None, intervention_fn=None, multiple=False):
    """Run the model wrapper's batch evaluation with activation interventions applied.

    Args:
        model: wrapper object exposing
            ``batch_evaluate_with_intervention(args, data, interventions, intervention_fn, multiple)``.
        args: parsed CLI namespace forwarded to the wrapper.
        data: validation samples produced by one of the process_data_* loaders.
        interventions: per-layer intervention spec consumed by ``intervention_fn``;
            ``None`` means an empty spec. A fresh ``{}`` is created per call so
            no state is shared between calls through a mutable default.
        intervention_fn: hook the wrapper applies to layer outputs.
        multiple: forwarded unchanged to the wrapper.

    Returns:
        Whatever ``model.batch_evaluate_with_intervention`` returns.
    """
    if interventions is None:
        interventions = {}
    return model.batch_evaluate_with_intervention(args, data, interventions, intervention_fn, multiple)

def _str2bool(value):
    """Parse a command-line boolean such as 'true'/'false', 'yes'/'no' or '1'/'0'.

    argparse's ``type=bool`` turns every non-empty string (including 'False')
    into True, so boolean options need an explicit parser.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def get_args():
    """Parse command-line options and configure file logging.

    Creates the parent directory of --log_file when needed, points the root
    logger at that file (mode "w+", so a previous log is truncated), and logs
    every parsed option with the OpenAI API key redacted.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument('--model', type=str, default='llava_v1.5_7B', help="specifies the model to be evaluated.")
    parser.add_argument('--openai', type=str, default=None, help="specifies the api_key")
    # main() reads args.organization for the gpt4v model.
    parser.add_argument('--organization', type=str, default=None, help='specifies the OpenAI organization (required for gpt4v).')
    parser.add_argument('--tokenizer', type=str, default=None, help='specifies the tokenizer to be used.')

    parser.add_argument('--probe_datasets', type=str, default='spa_vl', help='feature bank for training probes')
    parser.add_argument('--validate_dataset', type=str, default='MLLMGuard', help="specifies the path to the data")
    parser.add_argument('--log_file', type=str, default='logs/default.log', help='specifies the name of the log file')
    parser.add_argument('--save_path', type=str, default='results/toxicity_en.jsonl', help='specifies the path to save the results.')

    # Accepts `--verbose`, `--verbose true` and `--verbose false`.
    parser.add_argument('--verbose', type=_str2bool, nargs='?', const=True, default=True, help='specifies whether to display verbose outputs.')

    parser.add_argument('--project_name', type=str, default='mllmguard', help='specifies the project name in wandb.')
    parser.add_argument('--entity_name', type=str, default='entity_name', help='specifies the entity name in wandb.')

    parser.add_argument('--num_heads', type=int, default=48, help='K, number of top heads to intervene on')
    parser.add_argument('--alpha', type=int, default=15, help='alpha, intervention strength')
    parser.add_argument('--val_ratio', type=float, help='ratio of validation set size to development set size', default=0.2)
    parser.add_argument('--use_center_of_mass', action='store_true', help='use center of mass direction', default=False)
    parser.add_argument('--use_random_dir', action='store_true', help='use random direction', default=False)
    parser.add_argument('--seed', type=int, default=42, help='seed')

    parser.add_argument('--device', type=str, default='0')
    parser.add_argument('--categories', type=str, default='all')
    parser.add_argument('--subfix', type=str, default='all')

    args = parser.parse_args()

    # The FileHandler created by basicConfig cannot create missing
    # directories (the default log file lives under logs/).
    log_dir = os.path.dirname(args.log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    logging.basicConfig(
        filename=args.log_file,
        filemode="w+",
        format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s",
        datefmt="%Y-%m-%d-%H-%M",
        level=logging.INFO,
    )

    for arg in vars(args):
        value = getattr(args, arg)
        # Never write the API key into the log file.
        if arg == 'openai' and value is not None:
            value = '<redacted>'
        logging.info(f"{arg}: {value}")
    return args

# Root of the mm_iti workspace holding cached features, trained probes,
# benchmark data and evaluation results.
_MM_ITI_ROOT = '/data/multimodal_alignment/mm_iti'

# Categories evaluated when --categories is 'all'. The MM-SafetyBench names
# are passed verbatim to process_data_mmsafetybench, so keep their spelling.
_MLLMGUARD_CATEGORIES = ('privacy', 'bias', 'toxicity', 'hallucination', 'legality')
_MMSAFETYBENCH_CATEGORIES = (
    "01-Illegal_Activitiy",
    "02-HateSpeech",
    "03-Malware_Generation",
    "04-Physical_Harm",
    "05-EconomicHarm",
    "06-Fraud",
    "07-Sex",
    "08-Political_Lobbying",
    "09-Privacy_Violence",
    "10-Legal_Opinion",
    "11-Financial_Advice",
    "12-Health_Consultation",
    "13-Gov_Decision",
)


def _resolve_categories(args, default_categories):
    """Turn args.categories ('all' or a space-separated string) into a list.

    The resolved list is also stored back on args.categories.
    """
    if args.categories == 'all':
        args.categories = list(default_categories)
    else:
        args.categories = args.categories.split(' ')
    return args.categories


def _validation_jobs(args):
    """Build the evaluation plan for args.validate_dataset.

    Returns:
        A list of (save_path, load_fn, load_args) triples. The caller sets
        args.save_path and calls load_fn(*load_args) per job, so benchmark data
        is read lazily, one job at a time.

    Raises:
        ValueError: if args.validate_dataset is not a supported benchmark.
    """
    dataset = args.validate_dataset
    data_dir = f"{_MM_ITI_ROOT}/data"
    results_dir = f"{_MM_ITI_ROOT}/results/{dataset}"

    # Tag runs that use non-default directions so they never overwrite the
    # results of a default run with the same K / alpha.
    suffix = ''
    if args.use_center_of_mass:
        suffix += '_com'
    if args.use_random_dir:
        suffix += '_rand'

    if dataset == 'MLLMGuard':
        return [
            (f"{results_dir}/{args.model}_{c}_{args.num_heads}_{args.alpha}{suffix}{args.subfix}.jsonl",
             process_data_mllmguard, (f"{data_dir}/MLLMGuard/{c}",))
            for c in _resolve_categories(args, _MLLMGUARD_CATEGORIES)
        ]
    if dataset == 'MM-SafetyBench':
        return [
            (f"{results_dir}/{args.model}_{c.split('-')[-1]}_{args.num_heads}_{args.alpha}{suffix}.jsonl",
             process_data_mmsafetybench, (f"{data_dir}/MM-SafetyBench", c))
            for c in _resolve_categories(args, _MMSAFETYBENCH_CATEGORIES)
        ]
    single_save_path = f"{results_dir}/{args.model}_{args.num_heads}_{args.alpha}{suffix}.jsonl"
    if dataset == 'MM-HarmfulBench':
        return [(single_save_path, process_data_mmharmfulbench, (f"{data_dir}/MM-HarmfulBench",))]
    if dataset == 'SafeBench':
        return [(single_save_path, process_data_safebench, (f"{data_dir}/SafeBench",))]
    raise ValueError(
        f"Unsupported --validate_dataset {dataset!r}; expected one of "
        "MLLMGuard, MM-HarmfulBench, MM-SafetyBench, SafeBench."
    )


def _load_head_wise_activations(args, probe_dataset, num_heads):
    """Load cached head-wise activations and labels for one probe dataset.

    A name of the form 'd1&d2' merges two cached feature banks: both are
    truncated to the smaller label count and their samples are interleaved
    (d1, d2, d1, d2, ...), activations and labels alike.

    Returns:
        (activations rearranged from (b, l, h*d) to (b, l, h, d), labels).
    """
    feature_dir = f"{_MM_ITI_ROOT}/features"

    def load(name):
        activations = np.load(f"{feature_dir}/{args.model}_{name}_head_wise.npy")
        labels = np.load(f"{feature_dir}/{args.model}_{name}_labels.npy")
        return activations, labels

    if '&' not in probe_dataset:
        head_wise_activations, labels = load(probe_dataset)
    else:
        d1, d2 = probe_dataset.split('&')
        activations_1, labels_1 = load(d1)
        activations_2, labels_2 = load(d2)

        minsize = min(labels_1.size, labels_2.size)
        head_wise_activations = np.empty(
            (2 * minsize, activations_1.shape[1], activations_1.shape[2]),
            dtype=activations_1.dtype,
        )
        # Even rows come from d1, odd rows from d2.
        head_wise_activations[0::2] = activations_1[:minsize]
        head_wise_activations[1::2] = activations_2[:minsize]
        # Stacking to (2, minsize) and flattening column-major yields the same
        # d1/d2 interleaving. NOTE(review): assumes 1-D label arrays.
        labels = np.vstack((labels_1[:minsize], labels_2[:minsize])).reshape((-1,), order='F')
    return rearrange(head_wise_activations, 'b l (h d) -> b l h d', h=num_heads), labels


def _is_api_model(model_name):
    """Return True for hosted models that are evaluated through an API client."""
    return model_name == 'gpt4v' or 'yi_plus' in model_name or 'gemini' in model_name


def _evaluate_api_model(args, model_name, jobs):
    """Evaluate a hosted (API) model on every job, without interventions.

    Raises:
        ValueError: for gpt4v when the OpenAI key or organization is missing.
    """
    if model_name == 'gpt4v':
        from apis.gpt4v import GPT4V
        organization = getattr(args, 'organization', None)
        if not (args.openai and organization):
            raise ValueError('OpenAI API Key or organization is not specified.')
        gpt4v = GPT4V(args.openai, organization)

        def run(data):
            # NOTE(review): called as (data, args), the reverse of
            # evaluate_model's (args, data) -- confirm against apis.gpt4v.
            gpt4v.batch_evaluate(data, args)
    else:
        if 'yi_plus' in model_name:
            from apis.yi_vl_plus import Yi_VL
            api_model = Yi_VL(args.model)
        else:
            from apis.geminipro import Gemini
            api_model = Gemini(args.model)

        def run(data):
            evaluate_model(api_model, args, data)

    for save_path, load_fn, load_args in jobs:
        args.save_path = save_path
        validate_data = load_fn(*load_args)
        start_time = time.time()
        run(validate_data)
        logging.info(f'Result has been saved to {args.save_path}. Time used: {time.time() - start_time}.')


def _load_local_model(args, model_name, jobs):
    """Instantiate the open-weight model wrapper selected by model_name.

    model_name is args.model lower-cased; the checkpoint directory is
    PATH[args.model]. Checks run in order, so more specific names
    ('qwen_vl', 'mplug_owl2') come before their prefixes ('qwen', 'mplug_owl').

    Raises:
        NotImplementedError: if no wrapper matches model_name.
    """
    def first_validate_data():
        # NOTE(review): these wrappers take evaluation data at construction
        # time; the first job's data is passed (and re-loaded later for
        # evaluation). Confirm this is right when several categories are run.
        _, load_fn, load_args = jobs[0]
        return load_fn(*load_args)

    if 'llava' in model_name:
        from models.llava_inference import Llava
        return Llava(PATH[args.model])
    if 'qwen_vl' in model_name:
        from models.qwen import QwenVL
        return QwenVL(PATH[args.model])
    if 'qwen' in model_name:
        from models.qwen import Qwen
        return Qwen(PATH[args.model])
    if 'cogvlm' in model_name:
        from models.cogvlm import CogVLM
        return CogVLM(PATH[args.model], args.tokenizer)
    if 'yi' in model_name:
        from models.yi import YIVL
        return YIVL(PATH[args.model])
    if 'deepseek' in model_name:
        from models.deepseek import DeepSeek
        return DeepSeek(PATH[args.model])
    if 'mplug_owl2' in model_name:
        from models.mplug import mPLUG_Owl2
        return mPLUG_Owl2(PATH[args.model])
    if 'mplug_owl' in model_name:
        from models.mplug import mPLUG_Owl
        return mPLUG_Owl(PATH[args.model])
    # model_name is lower-cased, so these substrings must be lower-case too.
    if 'seed_llama_14b' in model_name:
        from models.seed import SeedLLaMA14B
        return SeedLLaMA14B(PATH[args.model])
    if 'seed_llama_8b' in model_name:
        from models.seed import SeedLLaMA8B
        return SeedLLaMA8B(PATH[args.model])
    if 'minigptv2' in model_name:
        from models.minigptv2 import MiniGPTV2
        return MiniGPTV2(PATH[args.model], first_validate_data())
    if 'sharegpt' in model_name:
        from models.sharegpt4v import ShareGPT
        return ShareGPT(PATH[args.model], first_validate_data())
    if 'xcomposer' in model_name:
        from models.xcomposer import Xcomposer
        return Xcomposer(PATH[args.model], first_validate_data())
    raise NotImplementedError(
        f'Model {model_name} has not been implemented.'
    )


def main(args):
    """Evaluate args.model on args.validate_dataset with probe-gated head interventions.

    Steps:
      1. Seed Torch / NumPy / CUDA RNGs and build the evaluation plan, so an
         unsupported --validate_dataset fails before anything heavy is loaded.
      2. Hosted API models are evaluated without interventions, then return:
         the hook below needs model.model.language_model internals.
      3. Otherwise load the local model and, for each '@'-separated probe
         dataset, load cached activations and pre-trained probes, select the
         top heads and build an intervention spec. Merge the specs and
         evaluate every job with lt_modulated_vector_add_multiple as the hook.

    Side effects: replaces args.probe_datasets (str -> list), may replace
    args.categories (see _resolve_categories) and sets args.save_path per job.
    """
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    args.probe_datasets = args.probe_datasets.split('@')
    model_name = args.model.lower()

    jobs = _validation_jobs(args)

    if _is_api_model(model_name):
        _evaluate_api_model(args, model_name, jobs)
        return

    model = _load_local_model(args, model_name, jobs)
    num_layers = model.model.language_model.config.num_hidden_layers
    num_heads = model.model.language_model.config.num_attention_heads

    # Fixed 2000-sample index space for the probe train/val split.
    # NOTE(review): assumes every feature bank holds at least 2000 samples.
    train_idxs = np.arange(2000)
    train_set_idxs = np.random.choice(train_idxs, size=int(len(train_idxs) * (1 - args.val_ratio)), replace=False)
    # Indices of train_idxs not drawn above, in ascending order.
    val_set_idxs = np.setdiff1d(train_idxs, train_set_idxs)

    probe_dir = f"{_MM_ITI_ROOT}/probes"
    interventions = []
    edit_locations = []
    for probe_dataset in args.probe_datasets:
        edit_locations.append(EDIT_LOC_DICT[probe_dataset])
        head_wise_activations, labels = _load_head_wise_activations(args, probe_dataset, num_heads)
        separated_head_wise_activations, separated_labels, _ = get_separated_activations(labels, head_wise_activations)

        # Probes and COM directions are loaded pre-trained. To regenerate them:
        #   com_directions = get_com_directions(num_layers, num_heads, train_set_idxs, val_set_idxs,
        #                                       separated_head_wise_activations, separated_labels)
        #   top_heads, probes = get_top_heads(train_set_idxs, val_set_idxs, separated_head_wise_activations,
        #                                     separated_labels, num_layers, num_heads, args.seed,
        #                                     args.num_heads, args.use_random_dir)
        #   save_probes(probes, f'{probe_dir}/probes_{args.model}_{probe_dataset}.pkl')
        #   save_probes(com_directions, f'{probe_dir}/coms_{args.model}_{probe_dataset}.pkl')
        if args.use_center_of_mass:
            com_directions = load_probes(f'{probe_dir}/coms_{args.model}_{probe_dataset}.pkl')
        else:
            com_directions = None
        probes = load_probes(f'{probe_dir}/probes_{args.model}_{probe_dataset}.pkl')

        # Presumably keeps the args.num_heads heads whose probes score best on
        # val_set_idxs -- see val_probes in utils.
        _, top_heads = val_probes(args.seed, val_set_idxs, separated_head_wise_activations, separated_labels,
                                  probes, num_layers, num_heads, args.num_heads)
        print("Heads intervened: ", sorted(top_heads))

        # get_interventions_dict(...) is the variant without probe gating; its
        # entries lack the probe element unpacked by the hook below.
        # NOTE(review): the literal 0 is passed positionally as in utils' signature.
        intervention = get_interventions_dict_withprobe(top_heads, probes, 0, num_heads, args.use_center_of_mass, args.use_random_dir, com_directions)
        interventions.append(intervention)

    interventions = merge_interventions(interventions, edit_locations)

    def lt_modulated_vector_add_multiple(head_output, layer_name, special_tokens_location=None):
        """Shift selected heads of one layer's output along their steering directions.

        Args:
            head_output: layer output laid out as (b, s, h*d); split into
                num_heads heads for the update and merged back on return.
            layer_name: key into the merged interventions mapping.
            special_tokens_location: maps each edit-location key to a sequence
                index or slice. For a slice, each token is scored by the head's
                probe and only tokens with sigmoid score < 0.5 (probe says
                "toxic") are shifted; an index is always shifted.

        Returns:
            The updated activations in (b, s, h*d) layout.
        """
        if special_tokens_location is None:
            special_tokens_location = {}
        head_output = rearrange(head_output, 'b s (h d) -> b s h d', h=num_heads)

        for head, direction, proj_val_std, probe, loc in interventions[layer_name]:
            token_loc = special_tokens_location[loc]
            # Target the activation's own device; device.index is None for CPU tensors.
            direction_to_add = torch.tensor(direction).to(head_output.device)
            max_idx = token_loc.stop if isinstance(token_loc, slice) else token_loc
            # Skip locations that do not fit in this forward pass's sequence.
            if abs(max_idx) > head_output.shape[1]:
                continue

            shift = args.alpha * proj_val_std * direction_to_add
            if isinstance(token_loc, slice):
                # NOTE(review): .squeeze() and the indexing below assume batch size 1.
                token_features = head_output[:, token_loc, head, :].squeeze()
                nontoxic_score = torch.sigmoid(probe(token_features.to(dtype=torch.double))).squeeze()
                # NOTE(review): `+ 1` maps slice-relative positions to sequence
                # positions only if the slice starts at 1 -- confirm; it may
                # need to be `+ token_loc.start`.
                toxic_idx = (nontoxic_score < 0.5).nonzero() + 1
                head_output[:, toxic_idx, head, :] += shift
            else:
                head_output[:, token_loc, head, :] += shift

        return rearrange(head_output, 'b s h d -> b s (h d)')

    for save_path, load_fn, load_args in jobs:
        args.save_path = save_path
        validate_data = load_fn(*load_args)
        evaluate_model_with_intervention(model, args, validate_data,
            interventions=interventions,
            intervention_fn=lt_modulated_vector_add_multiple, multiple=True)
    
         
if __name__ == "__main__":
    args = get_args()

    # Experiment tracking is currently disabled; to enable it:
    # wandb.init(project=args.project_name, entity=args.entity_name, config=args)
    # For quick debugging, fields of `args` (probe_datasets, validate_dataset,
    # categories, num_heads, alpha, ...) can be overridden here before main().

    # Restrict the visible GPUs before main() does any CUDA work. This only
    # takes effect if CUDA has not been initialised yet in this process.
    visible_devices = args.device
    print('args.device', visible_devices)
    os.environ['CUDA_VISIBLE_DEVICES'] = visible_devices

    main(args)