# Usage: python get_activations.py --model_name Qwen1.5-14B-Chat --dataset_name DRC
import os
import torch
import numpy as np
import pickle
from utils import get_activations_bau, tokenized_tqa, tokenized_tqa_gen_DRC, tokenized_tqa_gen_Shakespeare,tokenized_tqa_gen,tokenized_tqa_gen_zh,tokenized_tqa_gen_zh_all,tokenized_tqa_gen_all
import llama
import qwen2
import argparse
import json
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
def main():
    """
    Extract per-prompt activations from a Qwen2 causal LM and save them as .npy files.

    Command line options:
        --model_name    Key selecting the local checkpoint path
                        (only "Qwen1.5-14B-Chat" is configured).
        --dataset_name  One of the keys of ``dataset_specs`` below; selects the
                        JSON file under ``dataset/`` and the prompt formatter.
        --device        Parsed for compatibility but currently unused; the model
                        is placed with ``device_map="auto"`` and inputs use "cuda".
        --debug         If 1 (the default), work on a seeded random sample of at
                        most 1000 dataset entries.

    For every prompt, the slice ``[:, -1, :]`` of the layer-wise and head-wise
    activations returned by ``get_activations_bau`` is kept (presumably the last
    token position -- confirm against ``utils.get_activations_bau``).

    Outputs, written to ``features/<dataset_name>/``:
        <model_name>_<dataset_name>_labels.npy
        <model_name>_<dataset_name>_layer_wise.npy
        <model_name>_<dataset_name>_head_wise.npy

    Raises:
        ValueError: if the model name or dataset name is not recognised.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default='Qwen1.5-14B-Chat')
    parser.add_argument('--dataset_name', type=str, default='Daiyu')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument("--debug", type=int, default=1,
                        help='if set to 1, only use a random sample of up to 1000 entries for debugging')
    args = parser.parse_args()

    # Known checkpoints: model name -> local path.
    model_paths = {
        "Qwen1.5-14B-Chat": "/new_disk1/XXXX-3/projects/PretrainModels/Qwen1.5-14B-Chat",
    }

    # Known datasets: dataset name -> (JSON path, prompt formatter).
    # The "*_all" variants reuse the same file with a different formatter.
    dataset_specs = {
        "DRC": ("dataset/Train_DRC.json", tokenized_tqa_gen_DRC),
        "DRC_merge": ("dataset/Train_DRC_merge.json", tokenized_tqa_gen_DRC),
        "DRC_白话_question": ("dataset/Train_DRC_白话_question.json", tokenized_tqa_gen_DRC),
        "DRC_tqa_question": ("dataset/Train_DRC_tqa_question.json", tokenized_tqa_gen_DRC),
        "tqa_gen_zh": ("dataset/Train_tqa_gen_zh.json", tokenized_tqa_gen_zh),
        "tqa_gen_zh_all": ("dataset/Train_tqa_gen_zh.json", tokenized_tqa_gen_zh_all),
        "tqa_gen_zh_filter_acc": ("dataset/Train_tqa_gen_zh_filter_acc.json", tokenized_tqa_gen_zh),
        "tqa_gen_zh_filter_acc_all": ("dataset/Train_tqa_gen_zh_filter_acc.json", tokenized_tqa_gen_zh_all),
        "tqa_gen_zh_filter_score": ("dataset/Train_tqa_gen_zh_filter_score.json", tokenized_tqa_gen_zh),
        "tqa_gen_zh_filter_score_all": ("dataset/Train_tqa_gen_zh_filter_score.json", tokenized_tqa_gen_zh_all),
        "Shakespeare": ("dataset/Train_Shakespeare.json", tokenized_tqa_gen_Shakespeare),
        "tqa_gen": ("dataset/Train_tqa_gen.json", tokenized_tqa_gen),
        "tqa_gen_all": ("dataset/Train_tqa_gen.json", tokenized_tqa_gen_all),
    }

    # Validate both names before the expensive model load so typos fail fast.
    if args.model_name not in model_paths:
        raise ValueError(
            f"Invalid model name: {args.model_name!r} "
            f"(expected one of {sorted(model_paths)})"
        )
    if args.dataset_name not in dataset_specs:
        raise ValueError(
            f"Invalid dataset name: {args.dataset_name!r} "
            f"(expected one of {sorted(dataset_specs)})"
        )

    MODEL = model_paths[args.model_name]
    dataset_path, formatter = dataset_specs[args.dataset_name]

    with open(dataset_path, 'r', encoding='utf-8') as file:
        dataset = json.load(file)

    if args.debug == 1:
        # Seeded random subsample of at most 1000 entries; capping at
        # len(dataset) avoids random.sample raising on small datasets.
        import random
        random.seed(42)
        dataset = random.sample(dataset, min(1000, len(dataset)))

    tokenizer = qwen2.Qwen2Tokenizer.from_pretrained(MODEL)
    model = qwen2.Qwen2ForCausalLM.from_pretrained(MODEL, low_cpu_mem_usage=True, torch_dtype=torch.float16, device_map="auto")
    device = "cuda"

    print("Tokenizing prompts")
    print(len(dataset))
    prompts, labels = formatter(dataset, tokenizer)
    print(len(prompts), len(labels))
    all_layer_wise_activations = []
    all_head_wise_activations = []
    print("Getting activations")

    import gc
    for prompt in tqdm(prompts):
        layer_wise_activations, head_wise_activations, _ = get_activations_bau(model, prompt, device)
        # Keep only index -1 along the second axis (the original note says:
        # last token, excluding special tokens); copy so that the full
        # activation arrays can be released right away.
        layer_wise_activations_wanted = layer_wise_activations[:, -1, :].copy()
        head_wise_activations_wanted = head_wise_activations[:, -1, :].copy()
        del layer_wise_activations, head_wise_activations, _
        all_layer_wise_activations.append(layer_wise_activations_wanted)
        all_head_wise_activations.append(head_wise_activations_wanted)
        gc.collect()

    # The files are written into features/<dataset_name>/, so that directory
    # (not just "features") must exist before saving.
    output_dir = f'features/{args.dataset_name}'
    os.makedirs(output_dir, exist_ok=True)
    output_prefix = f'{output_dir}/{args.model_name}_{args.dataset_name}'

    print("Saving labels")
    np.save(f'{output_prefix}_labels.npy', labels)

    print("Saving layer wise activations")
    np.save(f'{output_prefix}_layer_wise.npy', all_layer_wise_activations)

    print("Saving head wise activations")
    np.save(f'{output_prefix}_head_wise.npy', all_head_wise_activations)
    
# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()