import torch
from einops import rearrange
import numpy as np
import pickle
import os
from tqdm import tqdm
import pandas as pd
import numpy as np
import argparse
from datasets import load_dataset
import wandb
import pickle

import sys
sys.path.append('../')
from utils import get_interventions_dict, get_top_heads, get_jb_results, get_com_directions
import llama

# Number of leading samples kept when the script is run with --is_test.
N_TEST_SAMPLES = 20

# Maps each CLI model name to the identifier handed to `from_pretrained`.
# The `honest_*` entries point at relative `results_dump/` paths instead of
# hub-style identifiers.
HF_NAMES = {
    # LLaMA 7B
    "llama_7B": "baffo32/decapoda-research-llama-7B-hf",
    "honest_llama_7B": "results_dump/llama_7B_seed_42_top_48_heads_alpha_15",
    # Alpaca 7B
    "alpaca_7B": "circulus/alpaca-7b",
    "honest_alpaca_7B": "results_dump/alpaca_7B_seed_42_top_48_heads_alpha_15",
    # Vicuna 7B
    "vicuna_7B": "AlekseyKorshuk/vicuna-7b",
    "honest_vicuna_7B": "results_dump/vicuna_7B_seed_42_top_48_heads_alpha_15",
    # Llama-2 chat 7B
    "llama2_chat_7B": "meta-llama/Llama-2-7b-chat-hf",
    "honest_llama2_chat_7B": "results_dump/llama2_chat_7B_seed_42_top_48_heads_alpha_15",
    # Llama-2 chat 70B
    "llama2_chat_70B": "meta-llama/Llama-2-70b-chat-hf",
    "honest_llama2_chat_70B": "results_dump/llama2_chat_70B_seed_42_top_48_heads_alpha_15",
}

def main():
    """Evaluate jailbreak rates with attention-head interventions applied.

    Steps, in order:

    1. Parse CLI arguments and (unless ``--no_wandb`` is set) start a wandb run.
    2. Seed torch / numpy.
    3. Load tokenized prompts, labels and cached activations from
       ``../features/``.
    4. Split the sample indices into a train and a validation subset.
    5. Obtain the heads to intervene on via ``get_top_heads`` and assemble
       the per-layer interventions via ``get_interventions_dict``.
    6. Load the model and call ``get_jb_results`` separately on the prompts
       labelled 0 ("unattacked") and 1 ("attacked"), printing the mean of
       each result and logging it to wandb when a run is active.
    """
    parser = argparse.ArgumentParser()
    # nargs='?' makes the positional optional so that the declared default is
    # actually honoured; argparse ignores `default` on a required positional.
    parser.add_argument("model_name", type=str, nargs='?', default='llama_7B', choices=HF_NAMES.keys(), help='model name')
    parser.add_argument('--dataset_name', type=str, default='tqa_mc2', help='feature bank for training probes')
    parser.add_argument('--alpha', type=float, default=15, help='alpha, intervention strength')
    parser.add_argument('--val_ratio', type=float, help='ratio of validation set size to development set size', default=0.2)
    parser.add_argument('--use_center_of_mass', action='store_true', help='use center of mass direction', default=False)
    parser.add_argument('--use_random_dir', action='store_true', help='use random direction', default=False)
    parser.add_argument('--seed', type=int, default=42, help='seed')
    parser.add_argument('--is_test', type=int, required=False, default=0)
    parser.add_argument('--no_wandb', type=int, required=False, default=0)
    args = parser.parse_args()

    if not args.no_wandb:
        wandb.init(project="adv-llama",
                   entity="llmattacks",
                   config=args)

    # set seeds
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    features_prefix = f'../features/{args.model_name}_{args.dataset_name}'

    # NOTE: unpickling can execute arbitrary code; only load feature dumps
    # that were produced by a trusted run.
    with open(f'{features_prefix}_tokenized_prompts.pkl', 'rb') as prompts_file:
        tokenized_prompts = pickle.load(prompts_file)

    labels = np.load(f'{features_prefix}_labels.npy')
    # NOTE(review): the variable is called head-wise, yet the file loaded is
    # the `*_layer_wise.npy` dump — confirm this is the intended feature file.
    head_wise_activations = np.load(f'{features_prefix}_layer_wise.npy')

    if args.is_test:
        # Smoke-test mode: keep only the first few samples.
        labels = labels[:N_TEST_SAMPLES]
        head_wise_activations = head_wise_activations[:N_TEST_SAMPLES]

    # define number of layers and heads
    # NOTE(review): hard-coded; presumably these should match
    # model.config.num_hidden_layers / num_attention_heads — verify for
    # every entry in HF_NAMES.
    num_layers = 32
    num_heads = 32

    # TODO understand if this is actually correct
    # Split the last axis into (head, per-head dim).
    head_wise_activations = rearrange(head_wise_activations, 'b l (h d) -> b l h d', h=num_heads)

    train_idxs = np.arange(len(labels))

    # Random (1 - val_ratio) fraction for training; the remainder is validation.
    train_set_idxs = np.random.choice(train_idxs, size=int(len(train_idxs)*(1-args.val_ratio)), replace=False)
    # setdiff1d returns the sorted indices not drawn above, i.e. the same
    # order a filtered pass over the ascending `train_idxs` would give.
    val_set_idxs = np.setdiff1d(train_idxs, train_set_idxs)

    # get directions
    if args.use_center_of_mass:
        com_directions = get_com_directions(num_layers, num_heads, train_set_idxs, val_set_idxs, head_wise_activations, labels)
    else:
        com_directions = None

    # compute the top heads
    top_heads, probes = get_top_heads(train_set_idxs, val_set_idxs, head_wise_activations, labels, num_layers, num_heads, args.seed, args.use_random_dir)
    print("Heads intervened: ", len(top_heads))

    # assemble interventions
    interventions = get_interventions_dict(top_heads, probes, head_wise_activations, num_heads, args.use_center_of_mass, args.use_random_dir, com_directions)

    # Token position from which the intervention is applied. Hard-coded; a
    # previous (now disabled) version derived it as the first position at
    # which the first four tokenized prompts stop sharing the same token.
    start_edit_token = 134

    print("start_edit_token: ", start_edit_token)

    def lt_modulated_vector_add(head_output, layer_name, start_edit_location="lt"):
        """Shift the outputs of the selected heads of `layer_name`.

        `head_output` is viewed as (b, s, h, d). For every
        (head, direction, proj_val_std) entry registered for `layer_name`,
        `alpha * proj_val_std * direction` is added in place either to the
        last sequence position (`start_edit_location == 'lt'`) or to every
        position from `start_edit_location` onward. The tensor is returned
        reshaped back to (b, s, h*d).
        """
        head_output = rearrange(head_output, 'b s (h d) -> b s h d', h=num_heads)
        for head, direction, proj_val_std in interventions[layer_name]:
            # Move to the exact device of the activations; using the device
            # object (rather than its index) also works when index is None.
            direction_to_add = torch.tensor(direction).to(head_output.device)
            if start_edit_location == 'lt':
                head_output[:, -1, head, :] += args.alpha * proj_val_std * direction_to_add
            else:
                head_output[:, start_edit_location:, head, :] += args.alpha * proj_val_std * direction_to_add
        head_output = rearrange(head_output, 'b s h d -> b s (h d)')
        return head_output

    # Group prompts by label: 0 -> "unattacked", 1 -> "attacked".
    evals = {"unattacked_prompts": [tokenized_prompts[i] for i in train_idxs if labels[i] == 0],
             "attacked_prompts": [tokenized_prompts[i] for i in train_idxs if labels[i] == 1]
    }

    # create model
    model_name = HF_NAMES[args.model_name]
    model = llama.LlamaForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16, device_map="auto")

    # run evals
    for eval_name, prompt_list in evals.items():
        jb_results = get_jb_results(prompt_list, model, interventions=interventions, intervention_fn=lt_modulated_vector_add, start_edit_location=start_edit_token)
        jb_rate = np.mean(jb_results)
        print(f"jailbroken for {eval_name}: {jb_rate}")
        # Only log when a wandb run was actually started (see --no_wandb).
        if wandb.run:
            wandb.log({f"jailbroken_{eval_name}": jb_rate})

# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
