import torch
import string
import re
import numpy as np
import torch.nn.functional as F
from tqdm.auto import tqdm
from .prompt_utils import *
from .model_utils import *
from .intervention_utils import *

def compute_top_k_accuracy(target_token_ranks, k=10) -> float:
    """
    Fraction of predictions whose target token landed within the top ``k``.

    Parameters:
    target_token_ranks: sequence (or array) of 0-based target-token ranks, one per example
    k: cutoff; a rank r counts as a hit when r < k

    Return:
    ``(number of ranks < k) / len(target_token_ranks)``. The hit count is reduced
    along axis 0, so a 2-D input yields one accuracy per column.
    """
    ranks = np.array(target_token_ranks)
    hits = np.sum(ranks < k, axis=0)
    return hits / len(ranks)

def compute_individual_token_rank(prob_dist, target_id) -> int:
    """
    Rank (0 = highest score) of a single token id within one score distribution.

    Parameters:
    prob_dist: torch tensor of scores over the vocabulary; size-1 dimensions are squeezed away
    target_id: token id to look up; when a list is given only its first element is used

    Return:
    The 0-based position of target_id once the scores are sorted in descending order
    """
    token = target_id[0] if isinstance(target_id, list) else target_id
    descending_order = torch.argsort(prob_dist.squeeze(), descending=True)
    matches = torch.where(descending_order == token)
    return matches[0].item()


def compute_best_token_rank(prob_dist, target_ids) -> int:
    """
    Return the best (lowest) rank reached by any candidate in ``target_ids`` under ``prob_dist``.

    Each candidate is ranked with ``compute_individual_token_rank``. An empty
    ``target_ids`` raises ValueError (``min`` of an empty sequence).
    """
    return min(compute_individual_token_rank(prob_dist, candidate) for candidate in target_ids)

def compute_top_k_elements(x, K=10) -> list:
    """
    Computes the top K elements of a torch tensor (x) and returns them as index tuples.

    Parameters:
    x: torch tensor of any shape; it need not be contiguous and may live on any device
    K: number of elements to return; clamped to x.numel() so small tensors do not raise

    Returns:
    list of tuples ``(*index, value)`` ordered from largest to smallest value, where
    ``index`` is the multi-dimensional position of the element in ``x`` and ``value``
    is rounded to 4 decimals.
    """
    h_shape = x.shape
    k = min(K, x.numel())
    # reshape (not view) so non-contiguous tensors, e.g. transposes, are accepted
    topk_vals, topk_inds = torch.topk(x.reshape(-1), k=k, largest=True)
    # np.unravel_index needs host memory: move the flat indices off the device first
    coords = np.unravel_index(topk_inds.cpu().numpy(), h_shape)
    rounded_vals = [round(v.item(), 4) for v in topk_vals]
    return list(zip(*coords, rounded_vals))

def decode_to_vocab(prob_dist, tokenizer, k=5) -> list:
    """
    Decodes and returns the top K words of a distribution over the vocabulary.

    Parameters:
    prob_dist: model logits over the vocabulary, either a (vocab,) vector or a batch whose
               first row is used (e.g. shape (1, vocab)); softmax is applied here.
               Non-tensor input (e.g. nested lists) is converted with torch.Tensor.
    tokenizer: huggingface model tokenizer (only .decode is used)
    k: number of vocabulary words to include

    Returns:
    list of (decoded_token, probability) pairs for the top K tokens, probabilities rounded to 5 decimals
    """
    if not isinstance(prob_dist, torch.Tensor):
        prob_dist = torch.Tensor(prob_dist)
    # Accept an unbatched (vocab,) vector as well as a (1, vocab) row.
    if prob_dist.dim() == 1:
        prob_dist = prob_dist.unsqueeze(0)

    # softmax + topk computed once (previously evaluated twice)
    top = torch.topk(torch.softmax(prob_dist, dim=-1), dim=-1, k=k)
    return [(tokenizer.decode(idx), round(prob.item(), 5))
            for idx, prob in zip(top.indices[0], top.values[0])]

def get_answer_id(query, answer, tokenizer):
    """
    Token ids of ``answer`` as they are tokenized when it directly follows ``query``.

    The query alone and the query+answer concatenation are both tokenized (no truncation,
    no padding); the ids beyond the query's length are the answer's contextualized tokens.

    Parameters:
    query (str): query as a string
    answer (str): expected answer as a string
    tokenizer: huggingface tokenizer

    Returns:
    answer_ids (list): the trailing token ids contributed by the answer

    Raises:
    AssertionError: if the answer adds no tokens, or the combined sequence is not shorter
    than tokenizer.model_max_length
    """
    def _ids(text):
        return tokenizer(text, truncation=False, padding=False).input_ids

    prompt_ids = _ids(query)
    full_ids = _ids(query + answer)
    assert len(prompt_ids) < len(full_ids) < tokenizer.model_max_length
    return full_ids[len(prompt_ids):]

def fv_to_vocab(function_vector, model, model_config, tokenizer, n_tokens=10):
    """
    Decodes a provided function vector into the model's vocabulary embedding space.

    The vector is reshaped to (1, 1, resid_dim) and passed through the model's final
    norm, its unembedding (LM head), and a softmax.

    Parameters:
    function_vector: torch vector extracted from ICL contexts that represents a particular function
    model: huggingface model
    model_config: dict with model information; 'name_or_path' selects the decoder modules
                  (matched case-insensitively) and 'resid_dim' sets the reshape
    tokenizer: huggingface tokenizer
    n_tokens: number of top tokens to include in the decoding

    Returns:
    decoded_tokens: list of tuples of the form [(token, probability), ...], most probable first

    Raises:
    ValueError: if the model family is not recognized
    """
    name = model_config['name_or_path'].lower()
    # GPT-NeoX / Pythia names may contain "gpt" but use different module names,
    # so they are checked before the generic GPT branch.
    if 'gpt-neox' in name or 'pythia' in name:
        decoder_modules = (model.gpt_neox.final_layer_norm, model.embed_out)
    elif 'gpt' in name:
        decoder_modules = (model.transformer.ln_f, model.lm_head)
    elif 'llama' in name:
        decoder_modules = (model.model.norm, model.lm_head)
    else:
        raise ValueError("Model not yet supported")

    decoder = torch.nn.Sequential(*decoder_modules, torch.nn.Softmax(dim=-1))
    # Only python floats/strings are returned, so no autograd graph is needed.
    with torch.no_grad():
        d_out = decoder(function_vector.reshape(1, 1, model_config['resid_dim']).to(model.device))

    vals, inds = torch.topk(d_out, k=n_tokens, largest=True)
    # reshape(-1) (not squeeze) keeps a 1-D sequence even when n_tokens == 1
    decoded_tokens = [(tokenizer.decode(idx), round(val.item(), 4))
                      for idx, val in zip(inds.reshape(-1), vals.reshape(-1))]
    return decoded_tokens

def is_nontrivial_prefix(prediction: str, target: str) -> bool:
    """
    Return True when ``prediction`` is a non-empty, case-insensitive prefix of ``target``.

    Both strings are lower-cased and stripped of surrounding whitespace before comparing;
    a prediction that is empty after stripping never counts.
    """
    normalized_target = target.lower().strip()
    normalized_pred = prediction.lower().strip()
    if not normalized_pred:
        return False
    return normalized_target.startswith(normalized_pred)

# Evaluate a sentence
def sentence_eval(sentence, target, model, model_config, tokenizer, edit_layer=None,
                    compute_nll=True, generate_str=False, pred_file=None, metric_fn=None,
                    max_new_tokens=16):
    """
    Evaluate a single sentence completion for a model, comparing to the given target.

    Parameters:
    sentence: prompt string for the model to process and continue
    target: expected completion string
    model: huggingface model
    model_config: dict with model information; 'layer_hook_names' is read here
    tokenizer: huggingface tokenizer
    edit_layer: layer index; when not None, an activation from hook layer edit_layer+1 is also returned
    compute_nll: whether to compute the teacher-forced loss of `target` following `sentence` (used for computing PPL)
    generate_str: whether to generate a string of tokens (scored with metric_fn) instead of predicting a single token.
                  May be combined with compute_nll; both are then computed.
    pred_file: open writable file handle; when given, the parsed generation is written to it (generate_str mode only)
    metric_fn: metric passed to parse_generation for longer generations (F1, exact match, etc.)
    max_new_tokens: maximum number of tokens to generate in generate_str mode

    Returns:
    tuple: (score,) if generate_str else (clean_output,) where clean_output are the logits for
    the token after the prompt; followed by clean_nll if compute_nll; followed by the activation
    if edit_layer is not None.
    """
    # Clean Run, No Intervention:
    device = model.device
    inputs = tokenizer(sentence, return_tensors='pt').to(device)
    prompt_len = inputs.input_ids.shape[1]
    original_pred_idx = prompt_len - 1
    active = edit_layer is not None

    with TraceDict(model, layers=model_config['layer_hook_names'], retain_input=True, retain_output=True) as td:
        if compute_nll:
            nll_inputs = tokenizer(sentence + target, return_tensors='pt').to(device)
            nll_targets = nll_inputs.input_ids.clone()
            # -100 is the ignore_index of nn.CrossEntropyLoss: mask the prompt so only
            # the target tokens contribute to the loss.
            nll_targets[:, :prompt_len] = -100

            output = model(**nll_inputs, labels=nll_targets)

            clean_nll = output.loss.item()
            clean_output = output.logits[:, original_pred_idx, :]
        if generate_str:
            output = model.generate(inputs.input_ids, top_p=0.9, temperature=0.1,
                                    max_new_tokens=max_new_tokens,
                                    pad_token_id=tokenizer.eos_token_id)
            # Decode only the newly generated tokens. Slicing the last max_new_tokens
            # would pull prompt tokens back in whenever generation stops early at EOS.
            output_str = tokenizer.decode(output[0, prompt_len:])
            parsed_str, score = parse_generation(output_str, target, metric_fn)
            if pred_file:
                pred_file.write(f"{parsed_str.strip()}\n")
        elif not compute_nll:
            clean_output = model(**inputs).logits[:, -1, :]

    if active:
        # NOTE: TraceDict retains the activations of the LAST forward pass made above
        # (the generation if generate_str, else the NLL forward if compute_nll).
        layer_td = model_config['layer_hook_names'][edit_layer + 1]
        traced = td[layer_td]
        traced_in = traced.input
        if isinstance(traced_in, torch.Tensor):
            input_act = traced_in[:, -1]
        elif len(traced_in) > 0:
            # several positional inputs were retained; the first one is the hidden state
            input_act = traced_in[0][:, -1]
        else:
            input_act = 0
        # assumes the hooked layer returns a tuple whose first item is the hidden state — TODO confirm per model
        output_act = traced.output[0][:, -1] - input_act

    clean_results = (score,) if generate_str else (clean_output,)
    if compute_nll:
        clean_results += (clean_nll,)
    if active:
        clean_results += (output_act,)
    return clean_results

def n_shot_eval(dataset, fv_vector, edit_layer: int, n_shots: int, model, model_config, tokenizer, shuffle_labels:bool=False,
                filter_set=None, prefixes=None, separators=None, generate_str=False, pred_filepath=None, mean_act=False,
                metric="f1_score", test_split='test', context_split='train', dataname=""):
    """
    Evaluate a model and FV intervention on the model using the provided ICL dataset.

    Each selected test example becomes an n-shot prompt. Context examples are drawn without
    replacement from ``context_split`` via ``np.random.choice``. The prompt is then evaluated
    both with and without the function vector.

    Parameters:
    dataset: ICL dataset indexable by split name
    fv_vector: torch vector that triggers execution of a task when added to a particular layer
    edit_layer: layer index
    n_shots: the number of ICL examples in each in-context prompt (capped at the size of context_split)
    model: huggingface model
    model_config: contains model config information (n layers, n heads, 'prepend_bos', etc.)
    tokenizer: huggingface tokenizer
    shuffle_labels: Whether to shuffle the ICL labels or not
    filter_set: indices of test examples to evaluate (e.g. those the model gets correct via ICL); None means all
    prefixes: dict of ICL template prefixes for each ICL component (input, output, instructions)
    separators: dict of ICL template separators for each ICL component (input, output, instructions)
    generate_str: whether to generate a string of tokens (scored with `metric`) or rank a single token
    pred_filepath: filepath to save intermediate generations for debugging (written in generate_str mode)
    mean_act: whether to also collect the activations from the layer following the FV intervention (token mode only)
    metric: metric for longer generations: "f1_score", "exact_match_score" or "first_word_score"
    test_split: dataset split to use for query
    context_split: dataset split to use for context examples
    dataname: only for tqdm logging

    Returns:
    results: dict. In generate_str mode it holds the "clean_score" and "intervention_score" lists.
    Otherwise it holds "clean_topk"/"intervention_topk" as [(K, accuracy)] for K=1..5, plus the
    per-example "clean_rank_list"/"intervention_rank_list". Any collected activations are
    concatenated along dim 0 under "activation".

    Raises:
    ValueError: if generate_str is set and `metric` is not recognized.
    """
    assert(test_split != context_split)
    clean_rank_list = []
    intervention_rank_list = []
    clean_score_list = []
    intervention_score_list = []
    mact_chunks = []

    # Resolve the generation metric once, up front, instead of on every example.
    metric_fn = None
    if generate_str:
        if metric == "f1_score":
            metric_fn = f1_score
        elif metric == "exact_match_score":
            metric_fn = exact_match_score
        elif metric == "first_word_score":
            metric_fn = first_word_score
        else:
            raise ValueError(f"Unknown metric: {metric}. Recognized metrics: [\"f1_score\", \"exact_match_score\", \"first_word_score\"]")

    # If the model already prepends a bos token by default, we don't want to add one
    prepend_bos = not model_config['prepend_bos']

    n_test = len(dataset[test_split])
    if filter_set is None:
        filter_set = np.arange(n_test)

    pred_file = open(pred_filepath, 'w') if pred_filepath else None
    try:
        for j in tqdm(range(n_test), total=n_test, ncols=50, desc=dataname, leave=False):
            if j not in filter_set:
                continue
            if n_shots == 0:
                word_pairs = {'input': [], 'output': []}
            else:
                n_context = len(dataset[context_split])
                word_pairs = dataset[context_split][np.random.choice(n_context, min(n_shots, n_context), replace=False)]
            word_pairs_test = dataset[test_split][j]

            if prefixes is not None and separators is not None:
                prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=word_pairs_test, prepend_bos_token=prepend_bos,
                                                        shuffle_labels=shuffle_labels, prefixes=prefixes, separators=separators)
            else:
                prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=word_pairs_test, prepend_bos_token=prepend_bos,
                                                        shuffle_labels=shuffle_labels)

            target = prompt_data['query_target']['output']
            sentence = [create_prompt(prompt_data)]

            if generate_str:
                target = target if isinstance(target, list) else [target]
                clean_output, intervention_output = function_vector_intervention(sentence, target=target, edit_layer=edit_layer,
                                                                                 function_vector=fv_vector,
                                                                                 model=model, model_config=model_config, tokenizer=tokenizer,
                                                                                 compute_nll=False, generate_str=generate_str)
                clean_parsed_str, clean_score = parse_generation(clean_output, target, metric_fn)
                intervention_parsed_str, intervention_score = parse_generation(intervention_output, target, metric_fn)

                clean_score_list.append(clean_score)
                intervention_score_list.append(intervention_score)

                if pred_file:
                    pred_file.write(f"{clean_parsed_str.strip()}\t|||\t{intervention_parsed_str}\n")
            else:
                target = target[0] if isinstance(target, list) else target
                # Token id(s) of the answer as tokenized right after the prompt. Only computed in
                # single-token mode: get_answer_id concatenates strings, so a list target would fail.
                target_token_id = get_answer_id(sentence[0], target, tokenizer)
                all_outputs = function_vector_intervention(sentence, target=[target], edit_layer=edit_layer,
                                                           function_vector=fv_vector, model=model,
                                                           model_config=model_config, tokenizer=tokenizer,
                                                           compute_nll=False, active=mean_act)
                clean_output, intervention_output = all_outputs[:2]
                if mean_act:
                    mact_chunks.append(all_outputs[-1])

                clean_rank_list.append(compute_individual_token_rank(clean_output, target_token_id))
                intervention_rank_list.append(compute_individual_token_rank(intervention_output, target_token_id))
    finally:
        # close the debug file even if evaluation raises part-way through
        if pred_file is not None:
            pred_file.close()

    if generate_str:
        results = {"clean_score": clean_score_list,
                   "intervention_score": intervention_score_list}
    else:
        results = {"clean_topk": [(K, compute_top_k_accuracy(clean_rank_list, K)) for K in range(1, 6)],
                   "clean_rank_list": clean_rank_list,

                   "intervention_topk": [(K, compute_top_k_accuracy(intervention_rank_list, K)) for K in range(1, 6)],
                   "intervention_rank_list": intervention_rank_list}
    if mact_chunks:
        results["activation"] = mact_chunks[0] if len(mact_chunks) == 1 else torch.cat(mact_chunks, 0)

    return results

def analogy_eval(dataset, fv_vector, edit_layer: int, model, model_config, tokenizer, shuffle_labels:bool=False,
                filter_set=None, prefixes=None, separators=None, generate_str=False, pred_filepath=None, mean_act=False,
                metric="f1_score", test_split='test', context_split='gold', dataname=""):
    """
    Evaluate a model and FV intervention on the model's analogies.

    Every selected test example is paired, one at a time, with each example of ``context_split``
    as a single-shot demonstration. Each resulting prompt is evaluated with and without the
    function vector.

    Parameters:
    dataset: ICL dataset indexable by split name
    fv_vector: torch vector that triggers execution of a task when added to a particular layer
    edit_layer: layer index
    model: huggingface model
    model_config: contains model config information (n layers, n heads, 'prepend_bos', etc.)
    tokenizer: huggingface tokenizer
    shuffle_labels: Whether to shuffle the ICL labels or not
    filter_set: indices of test examples to evaluate (e.g. those the model gets correct via ICL); None means all
    prefixes: dict of ICL template prefixes for each ICL component (input, output, instructions)
    separators: dict of ICL template separators for each ICL component (input, output, instructions)
    generate_str: whether to generate a string of tokens (scored with `metric`) or rank a single token
    pred_filepath: filepath to save intermediate generations for debugging (written in generate_str mode)
    mean_act: whether to also collect the activations from the layer following the FV intervention (token mode only)
    metric: metric for longer generations: "f1_score", "exact_match_score" or "first_word_score"
    test_split: dataset split to use for query
    context_split: dataset split whose examples are used, one at a time, as the context
    dataname: only for tqdm logging

    Returns:
    results: dict with one entry per (test example, context example) prompt. In generate_str mode
    it holds "clean_score" and "intervention_score". Otherwise it holds
    "clean_topk"/"intervention_topk" as [(K, accuracy)] for K=1..5, plus
    "clean_rank_list"/"intervention_rank_list". Any collected activations are concatenated along
    dim 0 under "activation".

    Raises:
    ValueError: if generate_str is set and `metric` is not recognized.
    """
    assert(test_split != context_split)
    clean_rank_list = []
    intervention_rank_list = []
    clean_score_list = []
    intervention_score_list = []
    mact_chunks = []

    # Resolve the generation metric once, up front, instead of on every prompt.
    metric_fn = None
    if generate_str:
        if metric == "f1_score":
            metric_fn = f1_score
        elif metric == "exact_match_score":
            metric_fn = exact_match_score
        elif metric == "first_word_score":
            metric_fn = first_word_score
        else:
            raise ValueError(f"Unknown metric: {metric}. Recognized metrics: [\"f1_score\", \"exact_match_score\", \"first_word_score\"]")

    # If the model already prepends a bos token by default, we don't want to add one
    prepend_bos = not model_config['prepend_bos']

    n_test = len(dataset[test_split])
    if filter_set is None:
        filter_set = np.arange(n_test)

    pred_file = open(pred_filepath, 'w') if pred_filepath else None
    try:
        for j in tqdm(range(n_test), total=n_test, ncols=50, desc=dataname, leave=False):
            if j not in filter_set:
                continue
            word_pairs_test = dataset[test_split][j]
            for k in range(len(dataset[context_split])):
                # a single context example, kept in batched (sliced) form
                word_pairs = dataset[context_split][k:k + 1]

                if prefixes is not None and separators is not None:
                    prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=word_pairs_test, prepend_bos_token=prepend_bos,
                                                            shuffle_labels=shuffle_labels, prefixes=prefixes, separators=separators)
                else:
                    prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=word_pairs_test, prepend_bos_token=prepend_bos,
                                                            shuffle_labels=shuffle_labels)

                target = prompt_data['query_target']['output']
                sentence = [create_prompt(prompt_data)]

                if generate_str:
                    target = target if isinstance(target, list) else [target]
                    clean_output, intervention_output = function_vector_intervention(sentence, target=target, edit_layer=edit_layer,
                                                                                     function_vector=fv_vector,
                                                                                     model=model, model_config=model_config, tokenizer=tokenizer,
                                                                                     compute_nll=False, generate_str=generate_str)
                    clean_parsed_str, clean_score = parse_generation(clean_output, target, metric_fn)
                    intervention_parsed_str, intervention_score = parse_generation(intervention_output, target, metric_fn)

                    clean_score_list.append(clean_score)
                    intervention_score_list.append(intervention_score)

                    if pred_file:
                        pred_file.write(f"{clean_parsed_str.strip()}\t|||\t{intervention_parsed_str}\n")
                else:
                    target = target[0] if isinstance(target, list) else target
                    # Token id(s) of the answer as tokenized right after the prompt. Only computed in
                    # single-token mode: get_answer_id concatenates strings, so a list target would fail.
                    target_token_id = get_answer_id(sentence[0], target, tokenizer)
                    fv_outputs = function_vector_intervention(sentence, target=[target], edit_layer=edit_layer,
                                                              function_vector=fv_vector, model=model,
                                                              model_config=model_config, tokenizer=tokenizer,
                                                              compute_nll=False, active=mean_act)
                    clean_output, intervention_output = fv_outputs[:2]
                    if mean_act:
                        mact_chunks.append(fv_outputs[-1])

                    clean_rank_list.append(compute_individual_token_rank(clean_output, target_token_id))
                    intervention_rank_list.append(compute_individual_token_rank(intervention_output, target_token_id))
    finally:
        # close the debug file even if evaluation raises part-way through
        if pred_file is not None:
            pred_file.close()

    if generate_str:
        results = {"clean_score": clean_score_list,
                   "intervention_score": intervention_score_list}
    else:
        results = {"clean_topk": [(K, compute_top_k_accuracy(clean_rank_list, K)) for K in range(1, 6)],
                   "clean_rank_list": clean_rank_list,

                   "intervention_topk": [(K, compute_top_k_accuracy(intervention_rank_list, K)) for K in range(1, 6)],
                   "intervention_rank_list": intervention_rank_list}
    if mact_chunks:
        results["activation"] = mact_chunks[0] if len(mact_chunks) == 1 else torch.cat(mact_chunks, 0)

    return results

def composite_eval(dataset, fv_vectors, edit_layer: int, model, model_config, tokenizer,
                   intervention_prob_dict, shuffle_labels:bool=False,
                   filter_set=None, prefixes=None, separators=None, generate_str=False,
                   test_split='test', soft=False, dataname=""):
    """
    Score zero-shot prompts from the test split under each of several function vectors.

    For every selected test example, a zero-shot prompt is built and the target answer's token
    id(s) are located. The (log-)probability of those tokens is then averaged under the FV
    intervention, once per vector in ``fv_vectors``.

    Parameters:
    dataset: ICL dataset indexable by split name
    fv_vectors: iterable of torch vectors, each added at edit_layer in turn
    edit_layer: layer index
    model: huggingface model
    model_config: contains model config information (n layers, n heads, 'prepend_bos', etc.)
    tokenizer: huggingface tokenizer
    intervention_prob_dict: dict to accumulate results in (mutated in place and returned)
    shuffle_labels: Whether to shuffle the ICL labels or not
    filter_set: indices of test examples to evaluate; None means all
    prefixes: dict of ICL template prefixes for each ICL component (input, output, instructions)
    separators: dict of ICL template separators for each ICL component (input, output, instructions)
    generate_str: kept for interface compatibility; this evaluation always scores the target
                  tokens, so the flag has no effect (it previously crashed on the list-wrapped target)
    test_split: dataset split to use for query
    soft: whether to use softmax, otherwise use log_softmax
    dataname: only for tqdm logging

    Returns:
    intervention_prob_dict: for every evaluated "query:target" key, one tensor is appended per
    call. The tensor stacks, for each FV in order, the mean (log-)probability of the answer tokens.
    """
    # If the model already prepends a bos token by default, we don't want to add one
    prepend_bos = not model_config['prepend_bos']

    n_test = len(dataset[test_split])
    if filter_set is None:
        filter_set = np.arange(n_test)

    # loop-invariant choice of normalisation
    fsoft = F.softmax if soft else F.log_softmax

    word_pairs = {'input': [], 'output': []}  # zero-shot: no in-context examples
    for j in tqdm(range(n_test), total=n_test, ncols=50, desc=dataname, leave=False):
        if j not in filter_set:
            continue
        word_pairs_test = dataset[test_split][j]

        if prefixes is not None and separators is not None:
            prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=word_pairs_test, prepend_bos_token=prepend_bos,
                                                    shuffle_labels=shuffle_labels, prefixes=prefixes, separators=separators)
        else:
            prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=word_pairs_test, prepend_bos_token=prepend_bos,
                                                    shuffle_labels=shuffle_labels)

        # Get relevant parts of the Prompt (always as plain strings: they are stripped,
        # concatenated and tokenized below)
        query, target = prompt_data['query_target']['input'], prompt_data['query_target']['output']
        query = query[0] if isinstance(query, list) else query
        target = target[0] if isinstance(target, list) else target

        pair = f'{query.strip()}:{target.strip()}'
        intervention_prob_dict.setdefault(pair, [])

        sentence = [create_prompt(prompt_data)]

        # Token id(s) of the answer as tokenized right after the prompt
        target_token_id = get_answer_id(sentence[0], target, tokenizer)

        inter_probs = []
        for fv_vector in fv_vectors:
            _, intervention_output = function_vector_intervention(sentence, target=[target], edit_layer=edit_layer,
                                                                  function_vector=fv_vector, model=model,
                                                                  model_config=model_config, tokenizer=tokenizer,
                                                                  compute_nll=False)
            scores = fsoft(intervention_output.squeeze(), dim=0)
            # average over all answer tokens
            inter_probs.append(torch.mean(scores[target_token_id]))
        intervention_prob_dict[pair].append(torch.stack(inter_probs))

    return intervention_prob_dict

def one_pair_active(data, fv_vector, edit_layer: int, model, model_config, tokenizer, shuffle_labels:bool=False,
                prefixes=None, separators=None, generate_str=False, pred_filepath=None, context=None):
    """
    Evaluate a model and FV intervention on each word pair (or analogy) in ``data``, always
    collecting the activation from the layer following the intervention.

    Parameters:
    data: indexable collection of query/target pairs; each is passed as query_target_pair
    fv_vector: torch vector that triggers execution of a task when added to a particular layer
    edit_layer: layer index
    model: huggingface model
    model_config: contains model config information (n layers, n heads, 'prepend_bos', etc.)
    tokenizer: huggingface tokenizer
    shuffle_labels: Whether to shuffle the ICL labels or not
    prefixes: dict of ICL template prefixes for each ICL component (input, output, instructions)
    separators: dict of ICL template separators for each ICL component (input, output, instructions)
    generate_str: kept for interface compatibility; this evaluation always ranks the target token,
                  so the flag has no effect (it previously crashed on the list-wrapped target)
    pred_filepath: if given, the file is opened for writing (truncated) and closed; nothing is written here
    context: optional single example used as a one-shot demonstration (each of its fields is
             wrapped in a list); None means zero-shot. Presumably keyed like {'input', 'output'} — confirm.

    Returns:
    results: dict with
      "clean_topk"/"intervention_topk": [(K, 1.0 or 0.0)] for K=1..5, 1.0 when at least one
          pair's target rank is < K
      "clean_rank_list"/"intervention_rank_list": per-pair target ranks
      "clean_words"/"intervention_words": top-5 decoded tokens for the LAST pair ([] if data is empty)
      "activation": activations concatenated along dim 0 ([] if data is empty)
    """
    clean_rank_list = []
    intervention_rank_list = []
    mact_chunks = []
    clean_output = intervention_output = None

    # If the model already prepends a bos token by default, we don't want to add one
    prepend_bos = not model_config['prepend_bos']

    pred_file = open(pred_filepath, 'w') if pred_filepath else None
    try:
        for j in range(len(data)):
            word_pairs = {'input': [], 'output': []} if context is None \
                else {c: [context[c]] for c in context}
            word_pairs_test = data[j]

            if prefixes is not None and separators is not None:
                prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=word_pairs_test, prepend_bos_token=prepend_bos,
                                                        shuffle_labels=shuffle_labels, prefixes=prefixes, separators=separators)
            else:
                prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=word_pairs_test, prepend_bos_token=prepend_bos,
                                                        shuffle_labels=shuffle_labels)

            # single-token ranking needs a plain string target
            target = prompt_data['query_target']['output']
            target = target[0] if isinstance(target, list) else target

            sentence = [create_prompt(prompt_data)]

            # Token id(s) of the answer as tokenized right after the prompt
            target_token_id = get_answer_id(sentence[0], target, tokenizer)
            clean_output, intervention_output, mact = function_vector_intervention(sentence, target=[target],
                                                                                   edit_layer=edit_layer, function_vector=fv_vector,
                                                                                   model=model, model_config=model_config, tokenizer=tokenizer,
                                                                                   compute_nll=False, active=True)
            mact_chunks.append(mact)

            clean_rank_list.append(compute_individual_token_rank(clean_output, target_token_id))
            intervention_rank_list.append(compute_individual_token_rank(intervention_output, target_token_id))
    finally:
        if pred_file is not None:
            pred_file.close()

    # Decoded words describe only the last evaluated pair; nothing to decode for empty data.
    if clean_output is None:
        clean_words, intervention_words = [], []
    else:
        clean_words = [de[0] for de in decode_to_vocab(clean_output, tokenizer, 5)]
        intervention_words = [de[0] for de in decode_to_vocab(intervention_output, tokenizer, 5)]

    if not mact_chunks:
        activation = []
    elif len(mact_chunks) == 1:
        activation = mact_chunks[0]
    else:
        activation = torch.cat(mact_chunks, 0)

    results = {"clean_topk": [(K, float(compute_top_k_accuracy(clean_rank_list, K) > 0.0)) for K in range(1, 6)],
               "clean_rank_list": clean_rank_list,
               "clean_words": clean_words,

               "intervention_topk": [(K, float(compute_top_k_accuracy(intervention_rank_list, K) > 0.0)) for K in range(1, 6)],
               "intervention_rank_list": intervention_rank_list,
               "intervention_words": intervention_words,
               "activation": activation}

    return results

def one_pair_no_intervention(data, model, model_config, tokenizer, active_layer=12, compute_ppl=True, shuffle_labels=False,
                             prefixes=None, separators=None, pred_filepath=None, context=None):
    """
    Evaluate a model (without any interventions) on each query/target pair in `data`.

    Every item of `data` becomes the query of its own prompt. If `context` is given, its
    values are wrapped in one-element lists and used as the in-context word pairs of every
    prompt; otherwise the prompt has no in-context examples.

    Parameters:
    data: indexable sequence of query/target pairs; each item is passed to
          `word_pairs_to_prompt_data` as `query_target_pair`
    model: huggingface model
    model_config: contains model config information (n layers, n heads, etc.); must provide 'prepend_bos'
    tokenizer: huggingface tokenizer
    active_layer: layer index forwarded to `sentence_eval` as `edit_layer`; when not None, the last
                  element returned by `sentence_eval` is collected as an activation
    compute_ppl: whether to compute perplexity of the teacher-forced correct completion
    shuffle_labels: Whether to shuffle the ICL labels or not
    prefixes: dict of ICL template prefixes for each ICL component (input, output, instructions)
    separators: dict of ICL template separators for each ICL component (input, output, instructions)
              (prefixes/separators are only used when BOTH are given)
    pred_filepath: if given, this file is opened for writing (and therefore truncated) for the
                   duration of the call; nothing is currently written to it
    context: word pair to use for context (dict whose values become one-element lists)

    Returns:
    results: dict with
        "clean_topk": list of (K, 1.0 or 0.0) for K=1..5; 1.0 when at least one item's
                      (0-based) target rank is below K
        "clean_rank_list": rank of the target token for each item, in order
        "clean_words": top-5 decoded tokens of the LAST item's output ([] if `data` is empty)
        "clean_ppl": mean of exp(NLL) over items (only when compute_ppl)
        "activation": collected activations concatenated along dim 0 (only when any were collected)
    """
    clean_rank_list = []
    clean_nll_list = []
    mact_chunks = []  # per-item activations; concatenated once after the loop
    clean_output = None  # stays None when `data` is empty
    active = active_layer is not None

    # If the model already prepends a bos token by default, we don't want to add one
    prepend_bos = not model_config['prepend_bos']

    # Template overrides are only forwarded when both pieces are supplied
    template_kwargs = {}
    if prefixes is not None and separators is not None:
        template_kwargs = {'prefixes': prefixes, 'separators': separators}

    pred_file = open(pred_filepath, 'w') if pred_filepath else None
    try:
        for j in range(len(data)):
            # Build a fresh word-pair dict per item so no state is shared between prompts
            word_pairs = {'input': [], 'output': []} if context is None \
                else {c: [context[c]] for c in context}
            word_pairs_test = data[j]

            prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=word_pairs_test,
                                                    prepend_bos_token=prepend_bos, shuffle_labels=shuffle_labels,
                                                    **template_kwargs)

            # Single-token evaluation: reduce a list of acceptable targets to its first entry
            target = prompt_data['query_target']['output']
            target = target[0] if isinstance(target, list) else target

            sentence = [create_prompt(prompt_data)]

            # Figure out tokens of interest
            target_token_id = get_answer_id(sentence[0], target, tokenizer)

            clean_outputs = sentence_eval(sentence, target=[target], edit_layer=active_layer,
                                          model=model, tokenizer=tokenizer, model_config=model_config,
                                          compute_nll=compute_ppl)
            clean_output = clean_outputs[0]
            if compute_ppl:
                clean_nll_list.append(clean_outputs[1])
            if active:
                # The activation is the last element returned by sentence_eval
                mact_chunks.append(clean_outputs[-1])

            clean_rank_list.append(compute_individual_token_rank(clean_output, target_token_id))
    finally:
        # Close even if evaluation raises, so the handle is never leaked
        if pred_file is not None:
            pred_file.close()

    # Decoded top-5 tokens of the final item only; guarded so empty `data` does not raise NameError
    clean_words = [] if clean_output is None else [de[0] for de in decode_to_vocab(clean_output, tokenizer, 5)]

    results = {"clean_topk": [(K, float(compute_top_k_accuracy(clean_rank_list, K) > 0)) for K in range(1, 6)],
               "clean_rank_list": clean_rank_list,
               "clean_words": clean_words}
    if compute_ppl:
        results['clean_ppl'] = np.exp(clean_nll_list).mean()
    if mact_chunks:
        # A single activation is returned as-is; several are concatenated once (avoids quadratic re-copying)
        results["activation"] = mact_chunks[0] if len(mact_chunks) == 1 else torch.cat(mact_chunks, 0)

    return results

# Evaluate few-shot dataset w/o intervention
def n_shot_eval_no_intervention(dataset, n_shots, model, model_config, tokenizer, compute_ppl=True, generate_str=False,
                                filter_set=None, shuffle_labels=False, prefixes=None, separators=None, active_layer=None,
                                test_split='test', context_split='train'):
    """
    Evaluate a model (without any interventions) on the provided ICL dataset.

    For each selected item of dataset[test_split], a prompt is built from `n_shots` examples
    drawn without replacement (via the global `np.random` state) from dataset[context_split].

    Parameters:
    dataset: ICL dataset, indexable by split name
    n_shots: the number of ICL examples in each in-context prompt (0 = no examples)
    model: huggingface model
    model_config: contains model config information (n layers, n heads, etc.); must provide 'prepend_bos'
    tokenizer: huggingface tokenizer
    compute_ppl: whether to compute perplexity of the teacher-forced correct completion
    generate_str: if True, the target is kept as a list of acceptable answers; otherwise a
                  list target is reduced to its first entry
    filter_set: indices of dataset[test_split] to evaluate (default: all of them)
    shuffle_labels: Whether to shuffle the ICL labels or not
    prefixes: dict of ICL template prefixes for each ICL component (input, output, instructions)
    separators: dict of ICL template separators for each ICL component (input, output, instructions)
              (prefixes/separators are only used when BOTH are given)
    active_layer: layer index forwarded to `sentence_eval` as `edit_layer`; when not None, the last
                  element returned by `sentence_eval` is collected as an activation
    test_split: the dataset split to use as the "test" dataset, typically set to 'test' or 'valid'
    context_split: the dataset split the in-context examples are sampled from

    Returns:
    results: dict with
        "clean_topk": list of (K, top-K accuracy) for K=1..5
        "clean_rank_list": rank of the target token for each evaluated item, in order
        "clean_ppl": mean of exp(NLL) over evaluated items (only when compute_ppl)
        "activation": collected activations concatenated along dim 0 (only when any were collected)
    """
    clean_rank_list = []
    clean_nll_list = []
    mact_chunks = []  # per-item activations; concatenated once after the loop
    active = active_layer is not None

    # If the model already prepends a bos token by default, we don't want to add one
    prepend_bos = not model_config['prepend_bos']

    if filter_set is None:
        filter_set = np.arange(len(dataset[test_split]))

    # Template overrides are only forwarded when both pieces are supplied
    template_kwargs = {}
    if prefixes is not None and separators is not None:
        template_kwargs = {'prefixes': prefixes, 'separators': separators}

    for j in range(len(dataset[test_split])):
        if j not in filter_set:
            continue
        if n_shots == 0:
            word_pairs = {'input': [], 'output': []}
        else:
            word_pairs = dataset[context_split][np.random.choice(len(dataset[context_split]), n_shots, replace=False)]
        word_pairs_test = dataset[test_split][j]
        prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=word_pairs_test,
                                                prepend_bos_token=prepend_bos, shuffle_labels=shuffle_labels,
                                                **template_kwargs)

        # Get the target; keep all answers for string generation, only the first for single-token eval
        target = prompt_data['query_target']['output']
        if generate_str:
            target = [target] if not isinstance(target, list) else target
        else:
            target = target[0] if isinstance(target, list) else target

        sentence = [create_prompt(prompt_data)]

        # Figure out tokens of interest
        target_token_id = get_answer_id(sentence[0], target, tokenizer)

        clean_outputs = sentence_eval(sentence, target=[target], edit_layer=active_layer,
                                      model=model, tokenizer=tokenizer, model_config=model_config,
                                      compute_nll=compute_ppl)
        clean_output = clean_outputs[0]
        if compute_ppl:
            clean_nll_list.append(clean_outputs[1])
        if active:
            # The activation is the last element returned by sentence_eval
            mact_chunks.append(clean_outputs[-1])

        clean_rank_list.append(compute_individual_token_rank(clean_output, target_token_id))

    results = {"clean_topk": [(K, compute_top_k_accuracy(clean_rank_list, K)) for K in range(1, 6)],
               "clean_rank_list": clean_rank_list}
    # Previously the NLLs were collected but silently dropped; report them like the sibling evaluators
    if compute_ppl:
        results['clean_ppl'] = np.exp(clean_nll_list).mean()
    if mact_chunks:
        # A single activation is returned as-is; several are concatenated once (avoids quadratic re-copying)
        results["activation"] = mact_chunks[0] if len(mact_chunks) == 1 else torch.cat(mact_chunks, 0)

    return results

def analogy_eval_no_intervention(dataset, model, model_config, tokenizer, compute_ppl=True, generate_str=False,
                                filter_set=None, shuffle_labels=False, prefixes=None, separators=None, pred_filepath=None,
                                active_layer=12, test_split='test', context_split='gold'):
    """
    Evaluate a model (without any interventions) on the provided ICL dataset, one context example at a time.

    Every selected item of dataset[test_split] is paired with every example of
    dataset[context_split] in turn: each prompt uses the one-row slice
    dataset[context_split][k:k+1] as its in-context word pairs, followed by the test query.
    Per-prompt results are appended to flat lists, ordered test-item-major.

    Parameters:
    dataset: ICL dataset, indexable by split name
    model: huggingface model
    model_config: contains model config information (n layers, n heads, etc.); must provide 'prepend_bos'
    tokenizer: huggingface tokenizer
    compute_ppl: whether to compute perplexity of the teacher-forced correct completion
    generate_str: if True, the target is kept as a list of acceptable answers and an (empty) "score"
                  entry is added to the results; otherwise a list target is reduced to its first entry
    filter_set: indices of dataset[test_split] to evaluate (default: all of them)
    shuffle_labels: Whether to shuffle the ICL labels or not
    prefixes: dict of ICL template prefixes for each ICL component (input, output, instructions)
    separators: dict of ICL template separators for each ICL component (input, output, instructions)
              (prefixes/separators are only used when BOTH are given)
    pred_filepath: if given, this file is opened for writing (and therefore truncated) for the
                   duration of the call; nothing is currently written to it
    active_layer: layer index forwarded to `sentence_eval` as `edit_layer`; when not None, the last
                  element returned by `sentence_eval` is collected as an activation
    test_split: the dataset split to use as the "test" dataset, typically set to 'test' or 'valid'
    context_split: the dataset split whose examples are used, one per prompt, as context

    Returns:
    results: dict with
        "clean_topk": list of (K, top-K accuracy) for K=1..5 over all (test item, context example) prompts
        "clean_rank_list": rank of the target token for each prompt, in order
        "score": only when generate_str; NOTE(review): never populated, always []
        "clean_ppl": mean of exp(NLL) over prompts (only when compute_ppl)
        "activation": collected activations concatenated along dim 0 (only when any were collected)
    """
    clean_rank_list = []
    clean_nll_list = []
    mact_chunks = []  # per-prompt activations; concatenated once after the loop
    active = active_layer is not None

    # If the model already prepends a bos token by default, we don't want to add one
    prepend_bos = not model_config['prepend_bos']

    if filter_set is None:
        filter_set = np.arange(len(dataset[test_split]))

    # Template overrides are only forwarded when both pieces are supplied
    template_kwargs = {}
    if prefixes is not None and separators is not None:
        template_kwargs = {'prefixes': prefixes, 'separators': separators}

    pred_file = open(pred_filepath, 'w') if pred_filepath else None
    try:
        for j in range(len(dataset[test_split])):
            if j not in filter_set:
                continue
            word_pairs_test = dataset[test_split][j]
            for k in range(len(dataset[context_split])):
                word_pairs = dataset[context_split][k:k + 1]
                prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=word_pairs_test,
                                                        prepend_bos_token=prepend_bos, shuffle_labels=shuffle_labels,
                                                        **template_kwargs)

                # Get the target; keep all answers for string generation, only the first for single-token eval
                target = prompt_data['query_target']['output']
                if generate_str:
                    target = [target] if not isinstance(target, list) else target
                else:
                    target = target[0] if isinstance(target, list) else target

                sentence = [create_prompt(prompt_data)]

                # Figure out tokens of interest
                target_token_id = get_answer_id(sentence[0], target, tokenizer)

                clean_outputs = sentence_eval(sentence, target=[target], edit_layer=active_layer,
                                              model=model, tokenizer=tokenizer, model_config=model_config,
                                              compute_nll=compute_ppl)
                clean_output = clean_outputs[0]
                if compute_ppl:
                    clean_nll_list.append(clean_outputs[1])
                if active:
                    # The activation is the last element returned by sentence_eval
                    mact_chunks.append(clean_outputs[-1])

                clean_rank_list.append(compute_individual_token_rank(clean_output, target_token_id))
    finally:
        # Close even if evaluation raises, so the handle is never leaked
        if pred_file is not None:
            pred_file.close()

    # Ranks are computed in both modes, so always report them (previously dropped when generate_str)
    results = {"clean_topk": [(K, compute_top_k_accuracy(clean_rank_list, K)) for K in range(1, 6)],
               "clean_rank_list": clean_rank_list}
    if generate_str:
        # NOTE(review): nothing ever appends to this list; kept only for backward compatibility
        results["score"] = []
    if compute_ppl:
        results['clean_ppl'] = np.exp(clean_nll_list).mean()
    if mact_chunks:
        # A single activation is returned as-is; several are concatenated once (avoids quadratic re-copying)
        results["activation"] = mact_chunks[0] if len(mact_chunks) == 1 else torch.cat(mact_chunks, 0)

    return results

# Logic from huggingface `evaluate` library
def normalize_answer(s):
    """
    Canonicalize an answer string for comparison.

    Steps, in order: lowercase, strip ASCII punctuation (string.punctuation),
    replace the standalone words "a", "an", "the" with a space, then collapse all
    runs of whitespace to single spaces and trim both ends.
    """
    punctuation = set(string.punctuation)

    text = s.lower()
    text = "".join(ch for ch in text if ch not in punctuation)
    # \b anchors ensure only whole-word articles are removed (e.g. "theater" survives)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

def f1_score(prediction, ground_truth):
    """
    Token-level F1 between a prediction and a ground-truth answer.

    Both strings are passed through `normalize_answer` and split on whitespace; the
    overlap is the multiset intersection of the two token lists.

    Returns:
    The harmonic mean of precision (overlap / #prediction tokens) and recall
    (overlap / #ground-truth tokens), or 0 when there is no overlap.
    """
    # Imported locally: no visible module-level import provides Counter
    from collections import Counter

    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    # No overlap also covers empty token lists, so the divisions below never hit zero
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1

def exact_match_score(prediction, ground_truth):
    """Return True iff prediction and ground truth are identical after `normalize_answer`."""
    normalized_prediction = normalize_answer(prediction)
    normalized_truth = normalize_answer(ground_truth)
    return normalized_prediction == normalized_truth

def first_word_score(prediction, ground_truth):
    """
    Compare only the first normalized word of prediction and ground truth.

    If either side normalizes to no words at all, the result is True only when
    both are empty.
    """
    pred_words = normalize_answer(prediction).split()
    gold_words = normalize_answer(ground_truth).split()
    if not pred_words or not gold_words:
        return len(pred_words) == len(gold_words)
    return pred_words[0] == gold_words[0]

def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """
    Score `prediction` against each acceptable answer and keep the best.

    metric_fn is called as metric_fn(prediction, ground_truth) for every entry of
    `ground_truths`; the maximum result is returned (the first one on ties).
    Raises ValueError (from max) when `ground_truths` is empty.
    """
    return max(metric_fn(prediction, truth) for truth in ground_truths)

def parse_generation(output_str, target, metric_fn):
    """
    Extract the first answer span from a generated string and score it against the target(s).

    The answer span is the first run of word characters, periods and spaces in
    `output_str`; a directly following run of newlines / 'Q' characters is consumed
    by the pattern but not captured.

    Parameters:
    output_str: generated text
    target: iterable of acceptable answers; each is scored as metric_fn(parsed_str, answer)
    metric_fn: a (prediction, ground_truth) -> score function

    Returns:
    (parsed_str, score): the extracted span and its best score over `target`.
    If no span is found, parsed_str is the empty list returned by `findall`
    (not an empty string) and score is 0.0.
    """
    # Raw string: "\w" in a normal literal is an invalid escape sequence
    # (DeprecationWarning, SyntaxWarning on Python 3.12+)
    ans_regex = re.compile(r"([\w. ]+)[\nQ]*")
    parsed_str = ans_regex.findall(output_str)
    if not parsed_str:
        return parsed_str, 0.0

    parsed_str = parsed_str[0]
    # Best score over all acceptable answers (same as metric_max_over_ground_truths)
    score = max(metric_fn(parsed_str, answer) for answer in target)
    return parsed_str, score

def make_valid_path_name(path: str) -> str:
    """
    Return `path` unchanged if nothing exists there; otherwise return the first
    free variant of the form "<name>_(N)<ext>" with N = 1, 2, ...

    Example: if "out.txt" exists, "out_(1).txt" is tried next, then "out_(2).txt", etc.
    Note the check is not atomic: another process may create the returned path before it is used.
    """
    # Imported locally: no visible module-level import provides `os`
    import os

    file_name, extension = os.path.splitext(path)
    counter = 1

    while os.path.exists(path):
        path = f"{file_name}_({counter}){extension}"
        counter += 1

    return path