"""
This script is used to evaluate MLLMs.
"""
import argparse
import logging
import random
import time
import os
from einops import rearrange
import numpy as np
import sys
sys.path.append('/root/project/code/baselines/Med_LVLMs')
import torch
import json
from utils import *


# Maps a model identifier (the value of --model) to the local checkpoint
# directory or config file handed to that model's wrapper class.
# NOTE: main() in this file only implements the 'llava_med_v1.5' entry; every
# other model name ends in NotImplementedError there.
PATH = {
    'llava_med_v1.5': '/root/project/huggingface/llava-med-v1.5-mistral-7b',  
    'llava_v1.5_7B_lht': '/root/project/huggingface/llava-v1.5-7b-liuhaotian',
    'mplug_owl2_7B': '/root/project/huggingface/mplug-owl2-llama2-7b',  
    'sharegpt4v_7B': '/root/project/huggingface/sharegpt4v-7b', 
    'instructblip_7B': '/root/project/huggingface/instructblip-vicuna-7b-old',
    'qwen_vl_7B': '/root/project/huggingface/qwen-vl',
    'minigpt4_vicuna_7B': '/root/project/models/minigpt4/minigpt4_eval.yaml',
    'minigptv2_llama2_7B': '/root/project/models/minigpt4/eval_configs/minigptv2_eval.yaml',
    'shikra_7B': '/root/project/models/shikra_model/shikra_config.py'
}

# Maps an evaluation dataset name to the root folder of its images. main()
# looks the folder up by dataset name and stores it in args.image_folder.
# The *_chest / *_other variants share the image folder of their parent dataset.
IMAGE_PATH = {
    'mimic_cxr': '/root/project/datasets/mimic_cxr_jpg/files',
    'xray': '/root/project/datasets/iu_xray/images',
    'rad': '/root/project/datasets/VQA_RAD',
    'rad_chest': '/root/project/datasets/VQA_RAD',
    'rad_other': '/root/project/datasets/VQA_RAD',
    'slake': '/root/project/datasets/Slake1.0/imgs',
    'slake_chest': '/root/project/datasets/Slake1.0/imgs',
    'slake_other': '/root/project/datasets/Slake1.0/imgs',
    'harvard': "/root/project/datasets/harvard/images",
    'pmc': "/root/project/datasets/pmc_oa/caption_T060_filtered_top4_sep_v0_subfigures"
}


def seed_all(seed=8888):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU plus all
    CUDA devices) so that a run can be repeated with the same ``seed``.

    Args:
        seed: Integer seed shared by all generators. It must be acceptable to
            ``np.random.seed`` (a non-negative integer below 2**32).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Safe to call without a GPU: the call is ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    
def evaluate_batch_with_intervention(model, args, data, interventions, intervention_fn):
    """Query ``model`` on every sample with the intervention and save the answers.

    For each sample in ``data`` the model's ``evaluate_with_intervention`` is
    called with the sample's ``'prompt'`` and ``'img_url'`` entries. A copy of
    the sample, extended with a ``'response'`` key, is printed and collected.
    After the whole batch has been processed, all records are written to
    ``args.save_path`` as JSON lines.

    Args:
        model: Object exposing ``evaluate_with_intervention(prompt, image,
            interventions, intervention_fn)``.
        args: Namespace providing ``save_path``.
        data: Iterable of samples with ``'prompt'`` and ``'img_url'`` keys.
        interventions: Forwarded unchanged to the model.
        intervention_fn: Forwarded unchanged to the model.
    """
    collected = []
    for entry in tqdm(data):
        # Copy first so the input sample is never touched by the new key.
        record = entry.copy()
        record['response'] = model.evaluate_with_intervention(
            entry['prompt'],
            entry['img_url'],
            interventions,
            intervention_fn,
        )
        print(record)
        collected.append(record)

    # Results are only flushed once the full batch has been generated.
    with jsonlines.open(args.save_path, 'w') as sink:
        sink.write_all(collected)

def _str_to_bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse``'s ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this converter interprets the text instead.
    The empty string maps to ``False``, as it did with ``type=bool``.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def get_args():
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace: The parsed options. Argument names and defaults
        are unchanged; only the ``--verbose`` parsing and the help texts
        were corrected.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument('--model', type=str, default='llava_v1.5_7B', help="specifies the model to be evaluated.")
    parser.add_argument('--probe_dataset', type=str, default='spa_vl', help='feature bank for training probes')
    parser.add_argument('--validate_datasets', type=str, default='MLLMGuard',
                        help='space-separated evaluation dataset names, or one of the shortcuts MM, CXR, ALL')
    parser.add_argument('--save_path', type=str, default='results/toxicity_en.jsonl', help='specifies the path to save the results.')

    # type=bool would turn "--verbose False" into True, so parse the text explicitly.
    # A bare "--verbose" (no value) is accepted and means True.
    parser.add_argument('--verbose', type=_str_to_bool, nargs='?', const=True, default=True,
                        help='specifies whether to display verbose outputs.')

    parser.add_argument('--neg', type=str, default=None, help='name of the negative-sample feature set (part of the activation file name)')
    parser.add_argument('--pos', type=str, default=None, help='name of the positive-sample feature set (part of the activation file name)')
    parser.add_argument('--hallu_type', type=str, default=None,
                        help='hallucination type: visual_misinterpretation, knowledge_deficiency or context_misalignment')
    parser.add_argument('--answer_type', type=str, default=None, help='answer type used to select the question file, e.g. close')
    parser.add_argument('--num_heads', type=int, default=48, help='K, number of top heads to intervene on')
    parser.add_argument('--alpha', type=int, default=15, help='alpha, intervention strength')
    parser.add_argument('--val_ratio', type=float, help='ratio of validation set size to development set size', default=0.2)
    parser.add_argument('--use_center_of_mass', action='store_true', help='use center of mass direction', default=False)
    parser.add_argument('--use_random_dir', action='store_true', help='use random direction', default=False)
    parser.add_argument('--seed', type=int, default=42, help='seed')
    parser.add_argument('--num_sample', type=int, default=0, help='number of samples requested from the activation loader')

    parser.add_argument('--device', type=str, default='0')
    parser.add_argument('--categories', type=str, default='all')
    parser.add_argument('--subfix', type=str, default='all', help='suffix appended to the results file name')

    parser.add_argument('--start_layer', type=int, default=5, help='start layer passed to the head selection')
    parser.add_argument('--end_layer', type=int, default=31, help='end layer passed to the head selection')

    return parser.parse_args()

def _load_validation_data(args, dataset_name):
    """Pick the question file for ``dataset_name`` and load its samples.

    Stores the chosen file in ``args.question_file`` and reads
    ``args.image_folder``, ``args.hallu_type`` and ``args.answer_type``.

    Returns:
        Whatever ``process_data_medheval`` / ``process_data_harvard_pmc``
        returns for the selected question file.

    Raises:
        ValueError: If the dataset / hallucination type / answer type
            combination has no question file. Previously such combinations
            fell through and either crashed with an unbound variable or
            silently reused the data of the previous dataset.
    """
    # 'rad_other' was missing here (the list repeated 'slake_chest' instead).
    # NOTE(review): plain 'slake' and 'rad' (produced by the MM / ALL
    # shortcuts) have no question-file mapping in this script — confirm
    # whether they should be routed through one of the branches below.
    medheval_datasets = ('slake_chest', 'slake_other', 'rad_chest', 'rad_other', 'mimic_cxr', 'xray')

    if dataset_name in medheval_datasets:
        if args.hallu_type == 'visual_misinterpretation':
            args.question_file = f'/root/project/benchmark_data/Visual_Misinterpretation_Hallucination/{args.answer_type}-ended/{dataset_name}_{args.answer_type}_pairs.json'
        elif args.hallu_type == 'knowledge_deficiency':
            if dataset_name != 'mimic_cxr':
                raise ValueError(
                    f"hallu_type 'knowledge_deficiency' only supports 'mimic_cxr', got {dataset_name!r}."
                )
            args.question_file = f'/root/project/benchmark_data/Knowledge_Deficiency_Hallucination/{args.answer_type}-ended/{dataset_name}_{args.answer_type}_pairs.json'
        elif args.hallu_type == 'context_misalignment':
            if dataset_name != 'mimic_cxr':
                raise ValueError(
                    f"hallu_type 'context_misalignment' only supports 'mimic_cxr', got {dataset_name!r}."
                )
            if args.answer_type != 'close':
                raise ValueError(
                    f"hallu_type 'context_misalignment' only supports answer_type 'close', got {args.answer_type!r}."
                )
            args.question_file = '/root/project/benchmark_data/Context_Misalignment_Hallucination/MIMIC-CXR_pairs.json'
        else:
            raise ValueError(f'Unsupported hallu_type {args.hallu_type!r} for dataset {dataset_name!r}.')
        return process_data_medheval(args.question_file, args.image_folder, args.answer_type)

    if dataset_name in ('harvard', 'pmc'):
        args.question_file = f"/root/project/benchmark_data/{dataset_name}/{dataset_name}_question_disease_{args.answer_type}.jsonl"
        return process_data_harvard_pmc(args.question_file, args.image_folder, args.answer_type)

    raise ValueError(f'No question file is defined for dataset {dataset_name!r}.')


def main(args):
    """Run the intervention evaluation configured by ``args``.

    Steps, in order:
      1. Seed the RNGs and build the model wrapper (only 'llava_med_v1.5'
         is implemented).
      2. Load the positive/negative head-wise activations and split them
         into train/validation index sets.
      3. Load the pre-computed center-of-mass directions (when
         ``--use_center_of_mass`` is set), select the heads to intervene on
         and build the interventions dict.
      4. For every validation dataset, load its questions, run the model
         with the intervention and write a ``.jsonl`` results file under
         ``/root/project/results/<hallu_type>/``.

    ``args`` is mutated: ``probe_dataset``, ``validate_datasets``,
    ``image_folder``, ``save_path`` and ``question_file`` are (re)assigned.

    Raises:
        NotImplementedError: If ``args.model`` is not a supported model.
        ValueError: If the probe dataset name is not recognised or a
            validation dataset has no question file.
    """
    # Seed every RNG in use; Python's `random` was previously left unseeded.
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    model_name = args.model.lower()

    if 'llava_med_v1.5' in model_name:
        from baselines.Med_LVLMs.llava_med_v15_inference import Llava_med_v15
        sys.path.append('/root/project/code/baselines/Med_LVLMs/llava_med_v15')
        model = Llava_med_v15(PATH[args.model], args.device)
    else:
        raise NotImplementedError(
            f'Model {model_name} has not been implemented.'
        )

    # Layer/head counts are hard-coded; the commented lines read them from the model config instead.
    # num_layers = model.model.language_model.config.num_hidden_layers
    # num_heads = model.model.language_model.config.num_attention_heads
    num_layers = 32
    num_heads = 32

    # Load activations (positive and negative samples combined).
    template = '/root/project/features/{}_{}_{}_head_wise.npy'
    neg_path = template.format(args.model, args.probe_dataset, args.neg)
    pos_path = template.format(args.model, args.probe_dataset, args.pos)
    head_wise_activations, labels = load_and_conbine_activations(
        pos_path=pos_path,
        neg_path=neg_path,
        num=args.num_sample,
    )
    # From here on the probe dataset name also encodes the neg/pos feature sets;
    # it is reused in the probe file name and the results file name below.
    args.probe_dataset = f'{args.probe_dataset}_{args.neg};{args.pos}'

    # head_wise_activations = np.load(f"/root/project/features/{args.model}_{args.probe_dataset}_head_wise.npy")
    # labels = np.load(f"/root/project/features/{args.model}_{args.probe_dataset}_labels.npy")

    # Split the flat per-layer feature vector into one slice per attention head.
    head_wise_activations = rearrange(head_wise_activations, 'b l (h d) -> b l h d', h=num_heads)

    # Fail with a clear message instead of a NameError on an unknown probe dataset.
    if any(tag in args.probe_dataset for tag in ('GEMeX', 'SLAKE', 'Mimic')):
        split_range = 2
    else:
        raise ValueError(
            f"Cannot determine split_range for probe dataset {args.probe_dataset!r}; "
            "expected a name containing 'GEMeX', 'SLAKE' or 'Mimic'."
        )

    separated_head_wise_activations, separated_labels, idxs_to_split_at = get_separated_activations(labels, head_wise_activations, split_range)

    train_idxs = np.arange(len(separated_head_wise_activations))
    # Kept even though only the commented-out probe code consumes the split:
    # the draw advances NumPy's global RNG exactly as before.
    train_set_idxs = np.random.choice(train_idxs, size=int(len(train_idxs)*(1-args.val_ratio)), replace=False)
    val_set_idxs = np.setdiff1d(train_idxs, train_set_idxs)

    # Direction / probe generation (disabled; the directions are loaded from disk below).
    # # get directions
    # if args.use_center_of_mass:
    #     com_directions = get_com_directions(num_layers, num_heads, train_set_idxs, val_set_idxs, separated_head_wise_activations, separated_labels)
    # else:
    #     com_directions = None
    # top_heads, probes = get_top_heads(train_set_idxs, val_set_idxs, separated_head_wise_activations, separated_labels, num_layers, num_heads, args.seed, args.num_heads, args.use_random_dir)

    # save_probes(probes, f'/root/project/probes/probes_{args.model}_{args.probe_dataset}_{args.num_sample}.pkl')
    # save_probes(com_directions, f'/root/project/probes/coms_{args.model}_{args.probe_dataset}_{args.num_sample}.pkl')
    # return

    if args.use_center_of_mass:
        com_path = f'/root/project/probes/coms_{args.model}_{args.probe_dataset}_{args.num_sample}.pkl'
        com_directions = load_probes(com_path)
        print(com_path)
    else:
        com_directions = None
    probes = None

    # Alternative head-selection strategies (disabled).
    # probes = load_probes(f'/root/project/probes/probes_{args.model}_{args.probe_dataset}.pkl')
    # val_separated_head_wise_activations = separated_head_wise_activations
    # val_separated_labels = separated_labels
    # top_accs, top_heads = val_probes(args.seed, val_set_idxs, val_separated_head_wise_activations, val_separated_labels,
    #                     probes, num_layers, num_heads, args.num_heads)
    # sorted_idx = np.load('/root/project/features/idx_A_A-.npy')
    # top_heads = sorted_idx[:args.num_heads]
    # top_heads = [flattened_idx_to_layer_head(idx, num_heads) for idx in top_heads]

    # Select heads from the loaded directions.
    # NOTE(review): com_directions is None unless --use_center_of_mass is set;
    # confirm that sort_direction_len accepts None.
    top_heads = sort_direction_len(com_directions, num_layers, num_heads, args.num_heads, args.start_layer, args.end_layer)

    print("Heads intervened: ", sorted(top_heads))

    interventions = get_interventions_dict(model_name, top_heads, probes, 0, num_heads, args.use_center_of_mass, args.use_random_dir, com_directions)

    def lt_modulated_vector_add(head_output, layer_name, start_edit_location='lt'):
        """Shift the selected heads of ``head_output`` along their directions.

        With ``start_edit_location == 'lt'`` only the last sequence position is
        edited; otherwise every position from ``start_edit_location`` onwards.
        The shift is ``args.alpha * proj_val_std * direction`` per head.
        """
        head_output = rearrange(head_output, 'b s (h d) -> b s h d', h=num_heads)
        for head, direction, proj_val_std in interventions[layer_name]:
            # Follow the activation's device directly rather than via its CUDA index.
            direction_to_add = torch.tensor(direction).to(head_output.device)
            if start_edit_location == 'lt':
                head_output[:, -1, head, :] += args.alpha * proj_val_std * direction_to_add
            else:
                head_output[:, start_edit_location:, head, :] += args.alpha * proj_val_std * direction_to_add
        head_output = rearrange(head_output, 'b s h d -> b s (h d)')
        return head_output

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(f"/root/project/results/{args.hallu_type}", exist_ok=True)

    # Expand the dataset shortcuts; any other value is a space-separated list of names.
    dataset_shortcuts = {
        'MM': ['slake', 'rad'],
        'CXR': ['mimic_cxr', 'xray'],
        'ALL': ['slake', 'rad', 'mimic_cxr', 'xray'],
    }
    args.validate_datasets = args.validate_datasets.split(' ')
    args.validate_datasets = dataset_shortcuts.get(args.validate_datasets[0], args.validate_datasets)

    for dataset_name in args.validate_datasets:
        args.image_folder = IMAGE_PATH[dataset_name]
        args.save_path = f'/root/project/results/{args.hallu_type}/{args.model}_{args.answer_type}_{dataset_name}_{args.num_heads}_{args.alpha}_{args.probe_dataset}'
        args.save_path += args.subfix + '.jsonl'

        validate_data = _load_validation_data(args, dataset_name)

        evaluate_batch_with_intervention(
            model, args, validate_data,
            interventions=interventions,
            intervention_fn=lt_modulated_vector_add,
        )
         
if __name__ == "__main__":
    args = get_args()
    # wandb.init(project=args.project_name, entity=args.entity_name, config=args)
    
    # args.validate_datasets = 'mimic_cxr'
    # args.hallu_type = 'visual_misinterpretation' # visual_misinterpretation  knowledge_deficiency context_misalignment
    # args.answer_type = 'close'
    # args.probe_dataset = 'Mimic_Knowledge'
    # args.pos = 'I+Q+RD'
    # args.neg = 'I+Q_onlyr'
    # args.model = 'llava_med_v1.5'
    # args.use_center_of_mass = True
    # args.num_heads = 8
    # args.alpha = 1
    # args.num_sample = 30
    # args.subfix = ''
    # print(args.save_path)
    # args.device = '6'
    
    # args.use_random_dir = True
    main(args)
    # seed_all(5555)
