import os
import torch
from datasets import load_dataset
from tqdm import tqdm
import numpy as np
import pickle
from utils import *
import argparse
import json
import sys
sys.path.append('/root/project/code/baselines/Med_LVLMs')

# Lookup table from the short model name given on the command line to the
# location the model is loaded from. Values are either Hugging Face hub
# identifiers, local checkpoint directories, or local config files.
HF_NAMES = {
    # Medical LLaVA checkpoint (local directory).
    "llava_med_v1.5": "/root/project/huggingface/llava-med-v1.5-mistral-7b",
    # Text-only models referenced by hub identifier.
    "alpaca_7B": "circulus/alpaca-7b",
    "vicuna_7B": "AlekseyKorshuk/vicuna-7b",
    "llama2_chat_7B": "meta-llama/Llama-2-7b-chat-hf",
    "llama2_chat_13B": "meta-llama/Llama-2-13b-chat-hf",
    "llama2_chat_70B": "meta-llama/Llama-2-70b-chat-hf",
    # Vision-language checkpoints stored in local directories.
    "llava_v1.5_7B": "/root/project/huggingface/llava-v1.5-7b-hf",
    "llava_v1.5_7B_lht": "/root/project/huggingface/llava-v1.5-7b-liuhaotian",
    "mplug_owl2_7B": "/root/project/huggingface/mplug-owl2-llama2-7b",
    "sharegpt4v_7B": "/root/project/huggingface/sharegpt4v-7b",
    "instructblip_7B": "/root/project/huggingface/instructblip-vicuna-7b-old",
    "qwen_vl_7B": "/root/project/huggingface/qwen-vl",
    # Entries that point at a config file rather than a checkpoint directory.
    "shikra_7B": "/root/wtb/multimodal_alignment/mm_iti/models/shikra_model/shikra_config.py",
    "minigpt4_vicuna_7B": "/root/wtb/multimodal_alignment/mm_iti/models/minigpt4/minigpt4_eval.yaml",
    "minigptv2_llama2_7B": "/root/wtb/multimodal_alignment/mm_iti/models/minigpt4/eval_configs/minigptv2_eval.yaml",
}

def _parse_args():
    """Parse the command-line options controlling activation extraction.

    The defaults for ``--model_name``, ``--dataset_name``, ``--device`` and
    ``--param`` are the values that were previously hard-coded inside
    ``main()``, so running the script without flags behaves as before while
    explicit flags are now honoured.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default='llava_med_v1.5')
    parser.add_argument('--dataset_name', type=str, default='Pmc_oa_D+Q')
    # Kept as a string: the model wrapper has always been handed the string
    # '7' here. NOTE(review): confirm the type Llava_med_v15 expects.
    parser.add_argument('--device', type=str, default='7')
    parser.add_argument("--model_dir", type=str, default=None, help='local directory with model data')
    # Parsed for backward compatibility; not read anywhere in this script.
    parser.add_argument('--mode', type=str, default='answer')
    parser.add_argument('--param', type=int, default=25,
                        help='numeric parameter forwarded to get_activations and '
                             'appended to the output name for addnoise/attn_qk runs')
    return parser.parse_args()


def _load_prompts(dataset_name, model_name):
    """Return ``(prompts, regions, filepaths)`` for the requested dataset.

    The dataset is selected by substring match on ``dataset_name``. Datasets
    whose loader does not provide regions get the placeholder string
    ``"None"`` for every prompt.

    Raises:
        ValueError: if ``dataset_name`` matches none of the known datasets.
    """
    if 'GEMeX' in dataset_name:
        datafile_path = '/root/project/datasets/mimic_cxr_jpg/GEMeX/processed_gemex_qab_resize2000.jsonl'
        prompts, regions, filepaths = process_data_gemex_activation(datafile_path, model_name, dataset_name)
    elif 'SLAKE' in dataset_name:
        datafile_path = '/root/project/datasets/Slake1.0/processed_slake_qab.jsonl'
        prompts, regions, filepaths = process_data_slake_activation(datafile_path, model_name, dataset_name)
    elif 'Mimic_Knowledge' in dataset_name:
        knowledge_path = '/root/project/disease_knowledge_dataset/disease_dataset_final.json'
        datafile_path = '/root/project/disease_knowledge_dataset/mimic_type1_dataset_polished.json'
        prompts, filepaths = process_data_mimick_activation(knowledge_path, datafile_path, model_name, dataset_name)
        regions = ["None"] * len(prompts)
    elif 'Harvard' in dataset_name:
        datafile_path = '/root/project/benchmark_data/harvard/harvard_question_disease_train.jsonl'
        prompts, filepaths = process_data_harvard_pmc_activation(
            datafile_path, model_name, dataset_name, "/root/project/datasets/harvard/images")
        regions = ["None"] * len(prompts)
    elif 'Pmc_oa' in dataset_name:
        datafile_path = '/root/project/benchmark_data/pmc/pmc_question_disease_train.jsonl'
        prompts, filepaths = process_data_harvard_pmc_activation(
            datafile_path, model_name, dataset_name,
            "/root/project/datasets/pmc_oa/caption_T060_filtered_top4_sep_v0_subfigures")
        regions = ["None"] * len(prompts)
    else:
        raise ValueError("Invalid dataset name")
    return prompts, regions, filepaths


def _resolve_pos_type(dataset_name, param):
    """Derive ``(pos_type, output_name)`` from markers in ``dataset_name``.

    ``pos_type`` is the mode string forwarded to ``get_activations`` (or
    ``None`` when no marker is present). ``output_name`` is the dataset name
    used in the saved file name; for the "addnoise" and "attn_qk" variants it
    has ``_{param}`` appended.
    """
    if "addnoise" in dataset_name:
        return "addnoise", f'{dataset_name}_{param}'
    if "attn_position" in dataset_name:
        return "attn_position", dataset_name
    if "attn_qk" in dataset_name:
        return "attn_qk", f'{dataset_name}_{param}'
    if "I+R+Q" in dataset_name:
        return "add_region_text", dataset_name
    return None, dataset_name


def main():
    """Extract head-wise activations for every prompt of a dataset and save them.

    Steps:
      1. Parse command-line options (see ``_parse_args``).
      2. Resolve the model location from ``--model_dir`` or ``HF_NAMES``.
      3. Load prompts, regions and image paths for the chosen dataset.
      4. Run ``model.get_activations`` on each sample and keep index ``-1``
         along the second axis of the head-wise activations (presumably the
         last token position -- confirm against ``get_activations``).
      5. Save the collected activations to
         ``features/{model_name}_{dataset_name}_head_wise.npy``.

    Only model names containing ``'llava_med_v1.5'`` are supported by this
    script.

    Raises:
        ValueError: for an unsupported model name, a model name that cannot
            be resolved to a path, or an unknown dataset name.
    """
    args = _parse_args()

    # Fail before any expensive work: previously an unsupported model name
    # only surfaced as a NameError on `model` after the dataset was loaded.
    if 'llava_med_v1.5' not in args.model_name:
        raise ValueError(
            f"Unsupported model name {args.model_name!r}: only 'llava_med_v1.5' models are handled")

    if args.model_dir:
        model_path = args.model_dir
    elif args.model_name in HF_NAMES:
        model_path = HF_NAMES[args.model_name]
    else:
        raise ValueError(
            f"Unknown model name {args.model_name!r}; pass --model_dir or use one of {sorted(HF_NAMES)}")

    # The loaders receive the un-suffixed dataset name; the suffix computed
    # below only affects the output file name.
    prompts, regions, filepaths = _load_prompts(args.dataset_name, args.model_name)
    pos_type, output_name = _resolve_pos_type(args.dataset_name, args.param)

    # Imported lazily so the heavy model dependencies are only loaded once
    # the arguments and dataset have been validated.
    from baselines.Med_LVLMs.llava_med_v15_inference import Llava_med_v15
    model = Llava_med_v15(model_path, args.device)

    print("Getting activations")
    all_head_wise_activations = []
    for prompt, region, filepath in tqdm(zip(prompts, regions, filepaths)):
        # get_activations returns (layer_wise, head_wise, mlp, img_idx); only
        # the head-wise activations are saved, so the others are not kept.
        _, head_wise_activations, _, _ = model.get_activations(
            prompt, filepath, region, pos_type, args.param)
        # .copy() detaches the slice from the full per-sample array so the
        # latter can be freed.
        all_head_wise_activations.append(head_wise_activations[:, -1, :].copy())

    print("Saving head wise activations")
    # np.save does not create missing directories; make sure the target
    # exists so a long extraction run is not lost at the final step.
    os.makedirs('features', exist_ok=True)
    np.save(f'features/{args.model_name}_{output_name}_head_wise.npy', all_head_wise_activations)

    
# Run the extraction only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()
