import os, re, json
import math

import torch, numpy as np
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer

import argparse

from nethook import TraceDict, get_module, get_parameter

# Include prompt creation helper functions
from utils.prompt_utils import *
from utils.intervention_utils import *
from utils.model_utils import *

def gather_activations(prompt_data, layers, dummy_labels, model, tokenizer):
    """
    Run a single forward pass over an ICL prompt and record the inputs of the requested layers.

    Parameters:
    prompt_data: prompt dict; prompt_data['query_target']['input'] is used as the query when building the prompt string
    layers: names of the modules whose inputs should be captured
    dummy_labels: token labels of a baseline prompt with the same number of example pairs, used to align token positions
    model: huggingface model
    tokenizer: huggingface tokenizer

    Returns (as a 3-tuple):
    TraceDict holding the captured layer inputs (layer outputs are not retained)
    idx_map: map of token indices to their respective averaged token indices
    idx_avg: dict containing the token index spans of multi-token words
    """
    query_input = prompt_data['query_target']['input']
    token_labels, prompt_string = get_token_meta_labels(prompt_data, tokenizer, query_input)

    # Tokenize as a batch containing one prompt and move it onto the model's device.
    model_inputs = tokenizer([prompt_string], return_tensors='pt').to(model.device)
    idx_map, idx_avg = compute_duplicated_labels(token_labels, dummy_labels)

    # Only the hooked layers' inputs matter here; the forward pass's return value is discarded.
    with TraceDict(model, layers=layers, retain_input=True, retain_output=False) as traced:
        model(**model_inputs)

    return traced, idx_map, idx_avg

def get_mean_head_activations(dataset, model, model_config, tokenizer, n_icl_examples = 10, N_TRIALS = 100, shuffle_labels=False, prefixes=None, separators=None, filter_set=None):
    """
    Average per-head attention activations over randomly sampled in-context prompts.

    For each trial, the function:
    - builds a prompt from n_icl_examples 'train' pairs plus one 'valid' query;
    - captures the inputs of the attention hooks;
    - splits those inputs by head;
    - collapses each multi-token prompt segment into a single slot by averaging its tokens.

    Parameters:
    dataset: object indexed with 'train' and 'valid'; shots come from 'train', the query from 'valid'
    model: huggingface model
    model_config: dict read for 'n_heads', 'n_layers', 'resid_dim', 'attn_hook_names' and 'name_or_path'
    tokenizer: huggingface tokenizer
    n_icl_examples: number of shots in each in-context prompt
    N_TRIALS: number of in-context prompts to average over
    shuffle_labels: forwarded to word_pairs_to_prompt_data
    prefixes, separators: prompt template pieces; forwarded only when BOTH are given
    filter_set: indices of dataset['valid'] the query may be drawn from (default: all of them)

    Returns:
    mean_activations: tensor of shape (n_layers, n_heads, len(dummy_labels), resid_dim // n_heads),
        averaged over the N_TRIALS prompts
    """
    n_heads = model_config['n_heads']
    head_dim = model_config['resid_dim'] // n_heads
    hook_names = model_config['attn_hook_names']

    def to_heads(hidden):
        # (..., resid_dim) -> (..., n_heads, head_dim)
        return hidden.view(*hidden.size()[:-1], n_heads, head_dim)

    # Template arguments are passed through only when both prefixes and separators are supplied.
    template_kwargs = {}
    if prefixes is not None and separators is not None:
        template_kwargs = {'prefixes': prefixes, 'separators': separators}

    dummy_labels = get_dummy_token_labels(n_icl_examples, tokenizer=tokenizer, **template_kwargs)
    activation_storage = torch.zeros(N_TRIALS, model_config['n_layers'], n_heads, len(dummy_labels), head_dim)

    if filter_set is None:
        filter_set = np.arange(len(dataset['valid']))

    # No explicit BOS token is requested for llama models.
    # NOTE(review): presumably their tokenizer already adds one; confirm.
    prepend_bos = 'llama' not in model_config['name_or_path']

    n_query_examples = 1
    for trial in range(N_TRIALS):
        train_idx = np.random.choice(len(dataset['train']), n_icl_examples, replace=False)
        icl_pairs = dataset['train'][train_idx]
        query_pair = dataset['valid'][np.random.choice(filter_set, n_query_examples, replace=False)]

        prompt_data = word_pairs_to_prompt_data(icl_pairs, query_target_pair=query_pair, prepend_bos_token=prepend_bos,
                                                shuffle_labels=shuffle_labels, **template_kwargs)
        traced, idx_map, idx_avg = gather_activations(prompt_data=prompt_data,
                                                      layers=hook_names,
                                                      dummy_labels=dummy_labels,
                                                      model=model,
                                                      tokenizer=tokenizer)

        # Stack one tensor per layer, then swap the token and head axes so heads precede tokens.
        # Assumes each hooked input is (1, n_tokens, resid_dim), giving (n_layers, n_heads, n_tokens, head_dim).
        per_layer = [to_heads(traced[name].input) for name in hook_names]
        stacked = torch.vstack(per_layer).permute(0, 2, 1, 3)

        condensed = stacked[:, :, list(idx_map.keys())]
        for start, end in idx_avg.values():
            # Replace a multi-token segment's slot with the mean over all of its tokens.
            condensed[:, :, idx_map[start]] = stacked[:, :, start:end + 1].mean(dim=2)

        activation_storage[trial] = condensed

    return activation_storage.mean(dim=0)


def build_arg_parser():
    """
    Build the command-line parser for computing and saving mean head activations.

    Numeric options carry explicit ``type=`` converters. argparse only applies
    ``type`` to values given on the command line (and to string defaults), so
    without them '--n_trials 50' would arrive as the string '50'.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--dataset_name', help='Name of the dataset to be loaded', type=str, required=True)
    parser.add_argument('--model_name', help='Name of model to be loaded', type=str, required=False, default='EleutherAI/gpt-j-6b')
    parser.add_argument('--root_data_dir', help='Root directory of data files', type=str, required=False, default='dataset_files')
    parser.add_argument('--save_path_root', help='File path to save mean activations to', type=str, required=False, default='results')
    parser.add_argument('--seed', help='Randomized seed', type=int, required=False, default=42)
    parser.add_argument('--n_shots', help="Number of shots in each in-context prompt", type=int, required=False, default=10)
    parser.add_argument('--n_trials', help="Number of in-context prompts to average over", type=int, required=False, default=100)
    parser.add_argument('--test_split', help="Percentage corresponding to test set split size", type=float, required=False, default=0.3)
    parser.add_argument('--device', help='Device to run on', required=False, default='cuda' if torch.cuda.is_available() else 'cpu')
    # JSON-encoded dicts on the command line; defaults are built fresh on every call.
    parser.add_argument('--prefixes', help='Prompt template prefixes to be used', type=json.loads, required=False, default={"input":"Q:", "output":"A:", "instructions":""})
    parser.add_argument('--separators', help='Prompt template separators to be used', type=json.loads, required=False, default={"input":"\n", "output":"\n\n", "instructions":""})
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    dataset_name = args.dataset_name
    model_name = args.model_name
    root_data_dir = args.root_data_dir
    # Results for each dataset go into their own sub-directory.
    save_path_root = f"{args.save_path_root}/{dataset_name}"
    seed = args.seed
    n_shots = args.n_shots
    n_trials = args.n_trials
    test_split = args.test_split
    device = args.device
    prefixes = args.prefixes
    separators = args.separators

    # Load Model & Tokenizer (inference only, so autograd is disabled globally)
    torch.set_grad_enabled(False)
    print("Loading Model")
    model, tokenizer, model_config = load_gpt_model_and_tokenizer(model_name, device=device)

    set_seed(seed)

    # Load the dataset
    print("Loading Dataset")
    dataset = load_dataset(dataset_name, root_data_dir=root_data_dir, test_size=test_split, seed=seed)

    print("Computing Mean Activations")
    mean_activations = get_mean_head_activations(dataset, model=model, model_config=model_config, tokenizer=tokenizer,
                                                 n_icl_examples=n_shots, N_TRIALS=n_trials, prefixes=prefixes, separators=separators)

    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(save_path_root, exist_ok=True)

    # Write args to file
    args.save_path_root = save_path_root # update for logging
    with open(f'{save_path_root}/mean_head_activation_args.txt', 'w') as arg_file:
        json.dump(vars(args), arg_file, indent=2)

    torch.save(mean_activations, f'{save_path_root}/{dataset_name}_mean_head_activations.pt')
    
