import os
import torch
from datasets import load_dataset
from tqdm import tqdm
import numpy as np
import pickle
from utils import *
import argparse
import json

# Short model identifiers (the values accepted by --model_name) mapped to the
# location each one is loaded from.

# Entries written in "owner/name" form; these appear to be Hugging Face Hub
# repository ids.
_HUB_CHECKPOINTS = {
    'llama_7B': 'baffo32/decapoda-research-llama-7B-hf',
    'alpaca_7B': 'circulus/alpaca-7b',
    'vicuna_7B': 'AlekseyKorshuk/vicuna-7b',
    'llama2_chat_7B': 'meta-llama/Llama-2-7b-chat-hf',
    'llama2_chat_13B': 'meta-llama/Llama-2-13b-chat-hf',
    'llama2_chat_70B': 'meta-llama/Llama-2-70b-chat-hf',
}

# Entries given as absolute paths on the local machine: checkpoint directories,
# plus a .py / .yaml config file for the Shikra and MiniGPT variants.
_LOCAL_CHECKPOINTS = {
    'llava_v1.5_7B': '/data/huggingface/llava-v1.5-7b-hf',
    'llava_v1.5_7B_lht': '/data/huggingface/llava-v1.5-7b-liuhaotian',
    'mplug_owl2_7B': '/data/huggingface/mplug-owl2-llama2-7b',
    'sharegpt4v_7B': '/data/huggingface/sharegpt4v-7b',
    'instructblip_7B': '/data/huggingface/instructblip-vicuna-7b-old',
    'qwen_vl_7B': '/data/huggingface/qwen-vl',
    'shikra_7B': '/data/multimodal_alignment/mm_iti/models/shikra_model/shikra_config.py',
    'minigpt4_vicuna_7B': '/data/multimodal_alignment/mm_iti/models/minigpt4/minigpt4_eval.yaml',
    'minigptv2_llama2_7B': '/data/multimodal_alignment/mm_iti/models/minigpt4/eval_configs/minigptv2_eval.yaml',
}

# Merged in the same order as before: Hub-style ids first, local paths after.
HF_NAMES = {**_HUB_CHECKPOINTS, **_LOCAL_CHECKPOINTS}

def _feature_path(args, suffix):
    """Return the ``features/`` output path for one artifact of this run."""
    return f'features/{args.model_name}_{args.dataset_name}_{suffix}.npy'


def _save_array(args, suffix, description, values):
    """Announce and write one collected list of activations with ``np.save``."""
    print(f"Saving {description}")
    np.save(_feature_path(args, suffix), values)


def _load_json(path):
    """Parse a JSON file, closing the file handle afterwards."""
    with open(path, 'r') as f:
        return json.load(f)


def _extraction_kind(mode):
    """Map a ``--mode`` string to the extraction routine that handles it.

    The checks run in the same priority order the script has always used, so
    a mode containing ``'+'`` or ``'query'`` is handled as a last-position
    extraction even if it also contains ``'pope'``.

    Raises:
        ValueError: if ``mode`` matches none of the known patterns.
    """
    if mode == 'answer' or 'query' in mode or '+' in mode:
        return 'last_token'
    if mode == 'image':
        return 'image'
    if mode == 'image_to_text':
        return 'image_to_text'
    if 'pope' in mode:
        return 'pope'
    raise ValueError(f"Invalid mode: {mode!r}")


def _build_model(args, model_path):
    """Instantiate the wrapper class whose key is a substring of ``args.model_name``.

    Order matters: more specific keys ('llava' + 'lht', 'qwen_vl',
    'mplug_owl2', 'minigptv2') are tested before the shorter keys they contain.
    Wrapper modules are imported lazily so only the selected one is loaded.

    Raises:
        ValueError: if no known key occurs in ``args.model_name``.
    """
    name = args.model_name
    if 'llava' in name and 'lht' in name:
        from models.llava_inference_lht import Llava_lht
        return Llava_lht(model_path)
    if 'llava' in name:
        from models.llava_inference import Llava
        return Llava(model_path)
    if 'shikra' in name:
        from models.shikra_inference import Shikra
        return Shikra(model_path)
    if 'qwen_vl' in name:
        from models.qwen import QwenVL
        return QwenVL(model_path)
    if 'qwen' in name:
        from models.qwen import Qwen
        return Qwen(model_path)
    if 'cogvlm' in name:
        from models.cogvlm import CogVLM
        return CogVLM(model_path, args.tokenizer)
    if 'yi' in name:
        from models.yi import YIVL
        return YIVL(model_path)
    if 'deepseek' in name:
        from models.deepseek import DeepSeek
        return DeepSeek(model_path)
    if 'mplug_owl2' in name:
        from models.mplug import mPLUG_Owl2
        return mPLUG_Owl2(model_path)
    if 'mplug_owl' in name:
        from models.mplug import mPLUG_Owl
        return mPLUG_Owl(model_path)
    if 'seed_llama_14B' in name:
        from models.seed import SeedLLaMA14B
        return SeedLLaMA14B(model_path)
    if 'seed_llama_8B' in name:
        from models.seed import SeedLLaMA8B
        return SeedLLaMA8B(model_path)
    if 'minigptv2' in name:
        from models.minigptv2 import MiniGPTV2
        return MiniGPTV2(model_path, args.device, name)
    if 'minigpt' in name:
        from models.minigpt4_inference import MiniGPT4
        return MiniGPT4(model_path, args.device, name)
    if 'sharegpt' in name:
        from models.sharegpt4v import ShareGPT
        # NOTE(review): `data` is not defined anywhere in this file; unless
        # `from utils import *` provides it, this line raises NameError.
        return ShareGPT(model_path, data)
    if 'xcomposer' in name:
        from models.xcomposer import Xcomposer
        # NOTE(review): same undefined `data` as the ShareGPT branch above.
        return Xcomposer(model_path, data)
    if 'instructblip' in name:
        from models.instructblip_inference import InstructBlip
        return InstructBlip(model_path)
    raise ValueError(f"Unsupported model name: {name!r}")


def _forward(model, prompt, filepath):
    """Run one prompt and return ``(layer_wise, head_wise, mlp)`` activations.

    When ``filepath`` does not exist on disk the model's text-only entry point
    is used instead of the image + text one.
    """
    if os.path.exists(filepath):
        layer_wise, head_wise, mlp, _img_idx = model.get_activations(prompt, filepath)
    else:
        layer_wise, head_wise, mlp = model.get_activations_only_text(prompt)
    return layer_wise, head_wise, mlp


def main():
    """Extract model activations for a dataset and save them under ``features/``.

    Command-line options:
        --model_name    Key of ``HF_NAMES``; also selects the model wrapper class.
        --dataset_name  "spa_vl", "seed_bench", "safebench_zh", "flickr30k", or
                        any name containing "POPE" together with one of
                        "train" / "val" / "sample" to pick the data file.
        --mode          Extraction mode. It is overridden for every dataset
                        except the POPE ones, which use the value given here.
        --device        Value exported as ``CUDA_VISIBLE_DEVICES``.
        --model_dir     Checkpoint location used instead of the ``HF_NAMES`` entry.
        --tokenizer     Second argument passed to the CogVLM wrapper.

    The defaults reproduce the configuration that used to be hard-coded in
    this function, so running without arguments behaves as before.

    Raises:
        ValueError: for an unknown dataset, POPE split, model name or mode.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default='llava_v1.5_7B_lht')
    parser.add_argument('--dataset_name', type=str, default='POPE_sample2_YR_C_p2+Q')
    # Kept as a string: it is written to os.environ, which only accepts str.
    parser.add_argument('--device', type=str, default='6')
    parser.add_argument("--model_dir", type=str, default=None, help='local directory with model data')
    parser.add_argument('--mode', type=str, default='YR_C_p2+Q')
    parser.add_argument('--tokenizer', type=str, default=None, help='tokenizer passed to the CogVLM wrapper')
    args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = args.device

    if args.model_dir:
        model_path = args.model_dir
    elif args.model_name in HF_NAMES:
        model_path = HF_NAMES[args.model_name]
    else:
        raise ValueError(
            f"Unknown model name {args.model_name!r}; pass --model_dir or use one of {sorted(HF_NAMES)}"
        )

    # Only the flickr30k loader fills `annotations` / `sentences`; every other
    # loader fills `prompts` / `labels` / `filepaths`.
    prompts = labels = filepaths = annotations = sentences = None

    if args.dataset_name == "spa_vl":
        dataset = _load_json('/data/multimodal_alignment/SPA-VL/train/meta.json')
        dataset = dataset[:2000]
        args.mode = 'answer'
        prompts, labels, filepaths = process_data_spa_vl(dataset)
    elif args.dataset_name == "seed_bench":
        dataset = _load_json('/data/multimodal_alignment/SEED-Bench/SEED-Bench.json')['questions']
        dataset = dataset[:2000]
        args.mode = 'image'
        prompts, labels, filepaths = process_data_seed_bench(dataset)
    elif args.dataset_name == 'safebench_zh':
        data_path = '/data/multimodal_alignment/safebench_zh/final_bench'
        args.mode = 'image'
        prompts, labels, filepaths = process_data_safebench_zh(data_path)
    elif args.dataset_name == 'flickr30k':
        data_path = '/data/multimodal_alignment/mm_iti/data/Flickr30k'
        args.mode = 'image_to_text'
        filepaths, annotations, sentences = process_data_flickr30k(data_path)
    elif 'POPE' in args.dataset_name:
        data_path = '/data/multimodal_alignment/mm_iti/data/POPE'
        if 'train' in args.dataset_name:
            datafile_path = os.path.join(data_path, 'train_3k_complete_qs_with_simplecap.json')
        elif 'val' in args.dataset_name:
            datafile_path = os.path.join(data_path, 'val_500.json')
        elif 'sample' in args.dataset_name:
            datafile_path = os.path.join(data_path, 'captions_selected_2.json')
        else:
            raise ValueError(
                f"POPE dataset name {args.dataset_name!r} must contain 'train', 'val' or 'sample'"
            )
        prompts, labels, filepaths = process_data_pope_activation(datafile_path, args.model_name, args.mode)
    else:
        raise ValueError("Invalid dataset name")

    # Validate the final mode before the (expensive) model load.
    kind = _extraction_kind(args.mode)
    if kind == 'image_to_text' and sentences is None:
        raise ValueError("mode 'image_to_text' is only supported for the flickr30k dataset")

    model = _build_model(args, model_path)

    os.makedirs('features', exist_ok=True)

    print("Getting activations")
    if kind == 'last_token':
        all_layer_wise_activations = []
        all_head_wise_activations = []
        for prompt, filepath in tqdm(zip(prompts, filepaths)):
            layer_wise, head_wise, _mlp = _forward(model, prompt, filepath)
            # Keep only the last position along axis 1 (presumably the token axis).
            all_layer_wise_activations.append(layer_wise[:, -1, :].copy())
            all_head_wise_activations.append(head_wise[:, -1, :].copy())

        # Labels and MLP activations are intentionally not written in this mode.
        _save_array(args, 'layer_wise', 'layer wise activations', all_layer_wise_activations)
        _save_array(args, 'head_wise', 'head wise activations', all_head_wise_activations)

    elif kind == 'image':
        all_layer_wise_activations = []
        all_head_wise_activations = []
        for prompt, filepath in tqdm(zip(prompts, filepaths)):
            layer_wise, head_wise, _, img_idx = model.get_activations(prompt, filepath)
            # Index axis 1 at the first entry of the img_idx value returned by the model.
            all_layer_wise_activations.append(layer_wise[:, img_idx[0], :].copy())
            all_head_wise_activations.append(head_wise[:, img_idx[0], :].copy())

        _save_array(args, 'labels', 'labels', labels)
        _save_array(args, 'layer_wise', 'layer wise activations', all_layer_wise_activations)
        _save_array(args, 'head_wise', 'head wise activations', all_head_wise_activations)

    elif kind == 'image_to_text':
        whole_image_activations = []
        whole_text_activations = []
        # `annotations` stays in the zip so iteration length is unchanged.
        for filepath, _annotation, sentence in tqdm(zip(filepaths, annotations, sentences)):
            text_activations = []
            for s in sentence:
                image_activation, text_activation, _ = model.get_projected_activations(s['sentence'], filepath)
                # Store the image activation once per image: on its first sentence.
                if not text_activations:
                    whole_image_activations.append(image_activation)
                text_activations.append(text_activation)
            whole_text_activations.append(text_activations)

        _save_array(args, 'whole_image_activations', 'whole_image_activations', whole_image_activations)

        print("Saving whole_text_activations")
        # Written with pickle, not np.save, despite the .npy suffix; the file
        # name is kept unchanged for existing consumers.
        with open(_feature_path(args, 'whole_text_activations'), 'wb') as f:
            pickle.dump(whole_text_activations, f)

    else:  # kind == 'pope'
        all_head_wise_activations = []
        all_mlp_activations = []
        for prompt, filepath in tqdm(zip(prompts, filepaths)):
            _layer_wise, head_wise, mlp = _forward(model, prompt, filepath)
            all_head_wise_activations.append(head_wise[:, -1, :].copy())
            all_mlp_activations.append(mlp[:, -1, :].copy())

        # Layer-wise activations are intentionally not written in this mode.
        _save_array(args, 'labels', 'labels', labels)
        _save_array(args, 'mlp', 'mlp activations', all_mlp_activations)
        _save_array(args, 'head_wise', 'head wise activations', all_head_wise_activations)

        
# Run the extraction only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
