import os, re, json
from tqdm import tqdm

import torch, numpy as np
import torch.nn as nn
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer

import argparse

from nethook import TraceDict, get_module, get_parameter

# Include prompt creation helper functions
from utils import make_valid_path_name
from utils.prompt_utils import *
from utils.intervention_utils import *
from utils.model_utils import *
from utils.eval_utils import *
from compute_average_activations import get_mean_head_activations
from compute_indirect_effect import compute_indirect_effect


def compute_function_pointer(mean_activations, indirect_effect, model, model_config, n_top_heads = 10, token_class_idx=-1):
    """
        Computes a "function pointer" vector that communicates the task observed in ICL examples used for downstream intervention.

        The vector is the sum, over the heads with the largest mean indirect effect, of each head's
        mean last-token activation pushed through its layer's attention output projection.

        Parameters:
        mean_activations: tensor of size (Layers, Heads, Tokens, head_dim) containing the average activation of each head for a particular task
        indirect_effect: tensor of size (N, Layers, Heads, class(optional)) containing the indirect_effect of each head across N trials
        model: huggingface model being used
        model_config: dict describing the model; the keys 'resid_dim', 'n_heads' and 'name_or_path' are read here
        n_top_heads: The number of heads to use when computing the summed function pointer
                     (clamped to the total number of heads available)
        token_class_idx: int indicating which token class to use, -1 is default for last token computations

        Returns:
        function_pointer: tensor of shape (1, 1, resid_dim) in model.dtype representing the communication of a particular task
        top_heads: list of the top influential heads represented as tuples [(L,H,S), ...], (L=Layer, H=Head, S= Avg. indirect_effect Score)

        Raises:
        ValueError: if indirect_effect does not have the expected number of dimensions, or if
                    model_config['name_or_path'] does not name a supported architecture
    """
    model_resid_dim = model_config['resid_dim']
    model_n_heads = model_config['n_heads']
    model_head_dim = model_resid_dim//model_n_heads
    device = model.device
    # Match architectures case-insensitively so e.g. a local "Llama-2-7b" checkpoint path is recognised.
    model_name = model_config['name_or_path'].lower()

    # Example:
    # GPT2-XL: resid_dim = 1600, n_heads = 25, => head_dim = 64
    # GPT-J: resid_dim = 4096, n_heads = 16, => head_dim = 256

    def _attn_out_proj(layer):
        # Returns the attention output-projection module of the given layer.
        if 'gpt2-xl' in model_name:
            return model.transformer.h[layer].attn.c_proj
        if 'gpt-j' in model_name:
            return model.transformer.h[layer].attn.out_proj
        if 'llama' in model_name:
            return model.model.layers[layer].self_attn.o_proj
        if 'gpt-neox' in model_name:
            return model.gpt_neox.layers[layer].attention.dense
        raise ValueError(f"Unsupported model for function pointer computation: {model_config['name_or_path']!r}")

    li_dims = len(indirect_effect.shape)

    if li_dims == 3 and token_class_idx == -1:
        mean_indirect_effect = indirect_effect.mean(dim=0)
    else:
        if li_dims != 4:
            raise ValueError(f"indirect_effect must have 4 dims (N, Layers, Heads, class) to select a token class, got {li_dims}")
        mean_indirect_effect = indirect_effect[:,:,:,token_class_idx].mean(dim=0) # Subset to token class of interest

    # Compute Top Influential Heads (L,H)
    h_shape = tuple(mean_indirect_effect.shape)
    flat_effect = mean_indirect_effect.reshape(-1)
    k = min(n_top_heads, flat_effect.numel())
    topk_vals, topk_inds = torch.topk(flat_effect, k=k, largest=True)
    # Move indices to host memory before handing them to numpy (required when they live on an accelerator).
    layer_inds, head_inds = np.unravel_index(topk_inds.cpu().numpy(), h_shape)
    top_heads = [(int(l), int(h), round(v.item(), 4)) for l, h, v in zip(layer_inds, head_inds, topk_vals)]

    # Compute Function Pointer Vector as sum of influential heads
    function_pointer = torch.zeros((1,1,model_resid_dim)).to(device)
    T = -1 # Intervention & values taken from last token

    for L,H,_ in top_heads:
        out_proj = _attn_out_proj(L)

        # Place this head's activation in its slot of an otherwise-zero vector so the
        # projection only sees the contribution of head H.
        x = torch.zeros(model_resid_dim)
        x[H*model_head_dim:(H+1)*model_head_dim] = mean_activations[L,H,T]
        # NOTE: calling the module applies its bias (if it has one) once per selected head.
        d_out = out_proj(x.reshape(1,1,model_resid_dim).to(device).to(model.dtype))

        function_pointer += d_out

    function_pointer = function_pointer.to(model.dtype)

    return function_pointer, top_heads

def compute_universal_function_pointer(mean_activations, model, model_config, n_top_heads=10):
    """
        Computes a "function pointer" vector that communicates the task observed in ICL examples used for downstream intervention,
        using a fixed ("universal") per-model ranking of attention heads instead of task-specific indirect effects.

        Parameters:
        mean_activations: tensor of size (Layers, Heads, Tokens, head_dim) containing the average activation of each head for a particular task
        model: huggingface model being used
        model_config: dict describing the model; the keys 'resid_dim', 'n_heads' and 'name_or_path' are read here
        n_top_heads: The number of heads (taken from the front of the hard-coded ranking) to sum over.
                     Applied to every supported model.

        Returns:
        function_pointer: tensor of shape (1, resid_dim) in model.dtype representing the communication of a particular task
        top_heads: list of the top influential heads represented as tuples [(L,H,S), ...], (L=Layer, H=Head, S= Avg. indirect_effect Score)

        Raises:
        ValueError: if model_config['name_or_path'] does not name a model with a hard-coded universal head set
    """
    model_resid_dim = model_config['resid_dim']
    model_n_heads = model_config['n_heads']
    model_head_dim = model_resid_dim//model_n_heads
    device = model.device
    # Match model names case-insensitively so head-set selection and module lookup always agree.
    model_name = model_config['name_or_path'].lower()

    # Universal Set of Heads (hard-coded rankings, ordered by decreasing score)

    if 'gpt-j' in model_name:
        # AIE = torch.load('results/gptj/gptj_AIE.pt')
        # top_heads = compute_top_k_elements(AIE, n_heads)
        top_heads = [(15, 5, 0.0587), (9, 14, 0.0584), (12, 10, 0.0526), (8, 1, 0.0445), (11, 0, 0.0445), (13, 13, 0.019), (8, 0, 0.0184), (14, 9, 0.016), (9, 2, 0.0127), (24, 6, 0.0113), (15, 11, 0.0092),
                     (6, 6, 0.0069), (14, 0, 0.0068), (17, 8, 0.0068), (21, 2, 0.0067), (10, 11, 0.0066), (11, 2, 0.0057), (17, 0, 0.0054), (20, 11, 0.0051), (23, 0, 0.0047), (20, 0, 0.0046), (15, 7, 0.0045),
                     (27, 2, 0.0045), (21, 15, 0.0044), (11, 4, 0.0044), (18, 6, 0.0043), (9, 6, 0.0042), (4, 12, 0.004), (11, 15, 0.004), (20, 2, 0.0036), (10, 0, 0.0035), (16, 9, 0.0031), (11, 14, 0.0031),
                     (12, 4, 0.003), (9, 7, 0.003), (18, 3, 0.003), (19, 5, 0.003), (22, 5, 0.0027), (25, 3, 0.0026), (18, 9, 0.0025)]

        def _attn_out_proj(layer):
            return model.transformer.h[layer].attn.out_proj
    elif 'llama-2-7b' in model_name:
        top_heads = [(14, 1, 0.0391), (11, 2, 0.0225), (9, 25, 0.02), (12, 15, 0.0196), (12, 28, 0.0191), (13, 7, 0.0171), (11, 18, 0.0152), (12, 18, 0.0113), (16, 10, 0.007), (14, 16, 0.007),
                     (14, 14, 0.0048), (16, 1, 0.0042), (18, 1, 0.0042), (19, 16, 0.0041), (13, 30, 0.0034), (18, 26, 0.0032), (14, 7, 0.0032), (16, 0, 0.0031), (16, 29, 0.003), (29, 30, 0.003),
                     (16, 6, 0.0029), (15, 11, 0.0027), (12, 11, 0.0026), (11, 22, 0.0023), (16, 19, 0.0021), (15, 23, 0.002), (16, 20, 0.0019), (15, 9, 0.0019), (17, 28, 0.0019), (14, 18, 0.0018),
                     (8, 26, 0.0018), (29, 26, 0.0018), (15, 8, 0.0018), (13, 13, 0.0017), (30, 9, 0.0017), (13, 23, 0.0017), (13, 10, 0.0016), (11, 30, 0.0016), (12, 26, 0.0015), (19, 27, 0.0015),
                     (14, 9, 0.0014), (14, 10, 0.0013), (31, 17, 0.0013), (31, 4, 0.0013), (15, 17, 0.0013), (10, 5, 0.0012), (14, 11, 0.0012), (19, 12, 0.0012), (16, 7, 0.0012), (15, 24, 0.0011),
                     (26, 28, 0.0011), (11, 15, 0.0011), (15, 25, 0.0011), (17, 12, 0.0011), (13, 2, 0.0011), (14, 5, 0.0011), (14, 3, 0.001), (26, 30, 0.001), (27, 29, 0.001), (25, 12, 0.0009),
                     (15, 13, 0.0009), (10, 14, 0.0009), (28, 13, 0.0009), (17, 19, 0.0008), (19, 2, 0.0008), (12, 23, 0.0008), (15, 26, 0.0008), (28, 21, 0.0008), (15, 10, 0.0008), (12, 0, 0.0007),
                     (6, 16, 0.0007), (7, 28, 0.0007), (27, 7, 0.0007), (11, 28, 0.0007), (29, 15, 0.0006), (13, 8, 0.0006), (13, 17, 0.0006), (8, 0, 0.0006), (22, 17, 0.0006), (22, 20, 0.0006),
                     (12, 2, 0.0006), (26, 9, 0.0006), (31, 26, 0.0006), (22, 27, 0.0005), (16, 26, 0.0005), (13, 1, 0.0005), (26, 2, 0.0005), (30, 10, 0.0005), (11, 25, 0.0005), (29, 20, 0.0005),
                     (19, 15, 0.0005), (12, 10, 0.0005), (12, 3, 0.0005), (30, 5, 0.0004), (6, 9, 0.0004), (15, 16, 0.0004), (23, 28, 0.0004), (22, 5, 0.0004), (31, 19, 0.0004), (26, 14, 0.0004)]

        def _attn_out_proj(layer):
            return model.model.layers[layer].self_attn.o_proj
    elif 'llama-2-13b' in model_name:
        top_heads = [(13, 13, 0.0402), (12, 17, 0.0332), (15, 38, 0.0269), (14, 34, 0.0209), (19, 2, 0.0116), (19, 36, 0.0106), (13, 4, 0.0106), (18, 11, 0.01), (10, 15, 0.0087), (13, 23, 0.0077),
                     (14, 7, 0.0074), (15, 36, 0.0046), (12, 8, 0.0046), (17, 7, 0.0044), (38, 29, 0.0043), (15, 32, 0.0037), (17, 18, 0.0034), (16, 9, 0.0033), (14, 23, 0.0032), (39, 13, 0.0029),
                     (39, 14, 0.0027), (18, 22, 0.0026), (21, 32, 0.0026), (15, 18, 0.0026), (13, 14, 0.0026), (11, 31, 0.0025), (14, 39, 0.0024), (19, 14, 0.0023), (36, 23, 0.0021), (21, 7, 0.0021),
                     (8, 23, 0.002), (18, 18, 0.002), (17, 28, 0.002), (17, 9, 0.0019), (13, 27, 0.0017), (13, 34, 0.0017), (13, 12, 0.0016), (21, 2, 0.0016), (16, 16, 0.0015), (15, 31, 0.0015),
                     (26, 35, 0.0015), (10, 18, 0.0014), (11, 27, 0.0014), (13, 25, 0.0014), (15, 26, 0.0013), (5, 32, 0.0013), (20, 12, 0.0013), (18, 15, 0.0013), (16, 23, 0.0013), (25, 5, 0.0013),
                     (34, 6, 0.0012), (15, 2, 0.0012), (15, 27, 0.0012), (18, 20, 0.0012), (16, 19, 0.0011), (37, 4, 0.001), (19, 7, 0.001), (19, 3, 0.0009), (38, 14, 0.0009), (20, 21, 0.0009),
                     (21, 30, 0.0009), (16, 11, 0.0009), (13, 24, 0.0009), (9, 31, 0.0008), (14, 13, 0.0008), (16, 29, 0.0008), (15, 17, 0.0008), (19, 6, 0.0008), (23, 36, 0.0008), (18, 17, 0.0007),
                     (15, 34, 0.0007), (14, 29, 0.0007), (15, 7, 0.0007), (13, 17, 0.0007), (20, 11, 0.0007), (35, 16, 0.0007), (39, 27, 0.0007), (29, 27, 0.0006), (30, 24, 0.0006), (19, 37, 0.0006),
                     (39, 21, 0.0006), (13, 36, 0.0006), (37, 30, 0.0006), (16, 36, 0.0006), (15, 3, 0.0006), (19, 13, 0.0006), (13, 10, 0.0006), (14, 19, 0.0005), (36, 3, 0.0005), (15, 25, 0.0005),
                     (16, 0, 0.0005), (16, 10, 0.0005), (20, 29, 0.0005), (25, 13, 0.0005), (14, 36, 0.0005), (36, 7, 0.0005), (17, 0, 0.0005), (11, 37, 0.0005), (23, 18, 0.0005), (35, 10, 0.0005)]

        def _attn_out_proj(layer):
            return model.model.layers[layer].self_attn.o_proj
    elif 'gpt-neox' in model_name:
        top_heads = [(9, 42, 0.0293), (12, 4, 0.0224), (9, 28, 0.019), (11, 57, 0.0079), (10, 43, 0.0073), (12, 14, 0.0069), (14, 31, 0.0065), (9, 23, 0.0057), (11, 21, 0.0054), (11, 4, 0.0052),
                     (9, 21, 0.0052), (18, 23, 0.005), (13, 9, 0.0048), (14, 49, 0.0048), (12, 20, 0.0047), (8, 30, 0.0045), (12, 59, 0.0043), (16, 42, 0.0039), (11, 34, 0.0038), (9, 33, 0.0038),
                     (9, 3, 0.0036), (11, 48, 0.0035), (14, 63, 0.0032), (18, 11, 0.0032), (13, 7, 0.003), (9, 27, 0.0029), (11, 23, 0.0029), (16, 30, 0.0027), (10, 17, 0.0026), (9, 55, 0.0024),
                     (11, 38, 0.0024), (11, 59, 0.0024), (20, 8, 0.0024), (15, 42, 0.0023), (11, 47, 0.0023), (9, 15, 0.0023), (8, 47, 0.0023), (10, 40, 0.0023), (18, 18, 0.0022), (9, 1, 0.0021),
                     (13, 12, 0.0021), (14, 5, 0.002), (16, 18, 0.0019), (13, 63, 0.0019), (9, 20, 0.0018), (26, 38, 0.0018), (21, 60, 0.0017), (17, 55, 0.0016), (17, 30, 0.0016), (10, 56, 0.0015),
                     (12, 3, 0.0015), (10, 16, 0.0014), (10, 0, 0.0013), (15, 62, 0.0013), (12, 15, 0.0013), (9, 34, 0.0013), (12, 18, 0.0013), (23, 46, 0.0012), (16, 53, 0.0012), (11, 1, 0.0011),
                     (9, 2, 0.0011), (10, 27, 0.0011), (23, 54, 0.0011), (16, 54, 0.0011), (12, 30, 0.0011), (11, 14, 0.0011), (16, 44, 0.001), (14, 27, 0.001), (26, 31, 0.001), (15, 0, 0.001),
                     (13, 46, 0.001), (15, 57, 0.001), (15, 17, 0.001), (19, 12, 0.0009), (9, 49, 0.0009), (10, 7, 0.0009), (19, 46, 0.0009), (8, 21, 0.0009), (25, 24, 0.0008), (19, 29, 0.0008),
                     (12, 21, 0.0008), (8, 18, 0.0008), (12, 35, 0.0008), (9, 10, 0.0008), (19, 40, 0.0008), (38, 5, 0.0008), (13, 31, 0.0007), (10, 38, 0.0007), (10, 12, 0.0007), (11, 31, 0.0007),
                     (10, 1, 0.0007), (23, 15, 0.0007), (13, 40, 0.0007), (9, 5, 0.0007), (22, 33, 0.0007), (13, 36, 0.0006), (8, 32, 0.0006), (16, 21, 0.0006), (14, 11, 0.0006), (13, 61, 0.0006)]

        def _attn_out_proj(layer):
            return model.gpt_neox.layers[layer].attention.dense
    else:
        raise ValueError(f"No universal head set is defined for model: {model_config['name_or_path']!r}")

    # Truncate for every model, not just some of them.
    top_heads = top_heads[:n_top_heads]

    # Compute Function Pointer Vector as sum of influential heads
    function_pointer = torch.zeros((1,1,model_resid_dim)).to(device)
    T = -1 # Intervention & values taken from last token

    for L,H,_ in top_heads:
        out_proj = _attn_out_proj(L)

        # Place this head's activation in its slot of an otherwise-zero vector so the
        # projection only sees the contribution of head H.
        x = torch.zeros(model_resid_dim)
        x[H*model_head_dim:(H+1)*model_head_dim] = mean_activations[L,H,T]
        # NOTE: calling the module applies its bias (if it has one) once per selected head.
        d_out = out_proj(x.reshape(1,1,model_resid_dim).to(device).to(model.dtype))

        function_pointer += d_out

    # Cast once after accumulating, so the running sum is not held in reduced precision.
    function_pointer = function_pointer.to(model.dtype)
    function_pointer = function_pointer.reshape(1, model_resid_dim)

    return function_pointer, top_heads


if __name__ == "__main__":
    # Script entry point: build (or load) a function pointer vector for a dataset/model pair and
    # evaluate it with zero-shot and shuffled-label few-shot interventions, writing results as JSON.

    def _str2bool(value):
        # argparse's `type=bool` treats every non-empty string (including "False") as True,
        # so boolean options need an explicit converter.
        if isinstance(value, bool):
            return value
        lowered = str(value).strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()

    parser.add_argument('--dataset_name', help='Name of the dataset to be loaded', type=str, required=True)
    parser.add_argument('--n_top_heads', help='Number of attenion head outputs used to compute function pointer', required=False, type=int, default=10)
    parser.add_argument('--edit_layer', help='Layer for intervention. If -1, sweep over all layers', type=int, required=False, default=-1) # 
    parser.add_argument('--model_name', help='Name of model to be loaded', type=str, required=False, default='EleutherAI/gpt-j-6b')
    parser.add_argument('--root_data_dir', help='Root directory of data files', type=str, required=False, default='dataset_files')
    parser.add_argument('--save_path_root', help='File path to save to', type=str, required=False, default='results')
    parser.add_argument('--seed', help='Randomized seed', type=int, required=False, default=42)
    parser.add_argument('--device', help='Device to run on',type=str, required=False, default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--mean_activations_path', help='Path to file containing mean_head_activations for the specified task', required=False, type=str, default=None)
    parser.add_argument('--indirect_effect_path', help='Path to file containing indirect_effect scores for the specified task', required=False, type=str, default=None)    
    parser.add_argument('--test_split', help="Percentage corresponding to test set split size", type=float, required=False, default=0.3)    
    parser.add_argument('--n_shots', help="Number of shots in each in-context prompt", type=int, required=False, default=10)
    parser.add_argument('--n_trials', help="Number of in-context prompts to average over for indirect_effect", type=int, required=False, default=25)
    parser.add_argument('--prefixes', help='Prompt template prefixes to be used', type=json.loads, required=False, default={"input":"Q:", "output":"A:", "instructions":""})
    parser.add_argument('--separators', help='Prompt template separators to be used', type=json.loads, required=False, default={"input":"\n", "output":"\n\n", "instructions":""})    
    parser.add_argument('--compute_baseline', help='Whether to compute the model baseline 0-shot -> n-shot performance', type=_str2bool, required=False, default=True)
    parser.add_argument('--generate_str', help='Whether to generate long-form completions for the task', action='store_true', required=False)
    parser.add_argument("--metric", help="Metric to use when evaluating generated strings", type=str, required=False, default="f1_score")
    parser.add_argument("--universal_set", help="Flag for whether to evaluate using the univeral set of heads", action="store_true", required=False)
        
    args = parser.parse_args()  

    dataset_name = args.dataset_name
    model_name = args.model_name
    root_data_dir = args.root_data_dir
    save_path_root = f"{args.save_path_root}/{dataset_name}"
    seed = args.seed
    device = args.device
    mean_activations_path = args.mean_activations_path
    indirect_effect_path = args.indirect_effect_path
    n_top_heads = args.n_top_heads
    eval_edit_layer = args.edit_layer

    test_split = float(args.test_split)
    n_shots = args.n_shots
    n_trials = args.n_trials

    prefixes = args.prefixes 
    separators = args.separators
    compute_baseline = args.compute_baseline

    generate_str = args.generate_str
    metric = args.metric
    universal_set = args.universal_set

    print(args)

    # Load Model & Tokenizer
    torch.set_grad_enabled(False)
    print("Loading Model")
    model, tokenizer, model_config = load_gpt_model_and_tokenizer(model_name, device=device)

    if args.edit_layer == -1: # sweep over all layers if edit_layer=-1
        # Stored as a [start, stop) pair; the int-vs-list distinction selects the evaluation mode below.
        eval_edit_layer = [0, model_config['n_layers']]

    # Load the dataset
    print("Loading Dataset")
    set_seed(seed)
    dataset = load_dataset(dataset_name, root_data_dir=root_data_dir, test_size=test_split, seed=seed)

    os.makedirs(save_path_root, exist_ok=True)

    print(f"Filtering Dataset via {n_shots}-shot Eval")
    # 1. Compute Model 10-shot Baseline & 2. Filter test set to cases where model gets it correct

    fs_results_file_name = f'{save_path_root}/fs_results_layer_sweep.json'
    print(fs_results_file_name)
    if os.path.exists(fs_results_file_name):
        # Reuse cached few-shot results; no validation filter is available in this case.
        with open(fs_results_file_name, 'r') as indata:
            fs_results = json.load(indata)
        key = 'score' if generate_str else 'clean_rank_list'
        target_val = 1 if generate_str else 0
        filter_set = np.where(np.array(fs_results[key]) == target_val)[0]
        filter_set_validation = None
    elif generate_str:
        set_seed(seed+42)
        fs_results_validation = n_shot_eval_no_intervention(dataset=dataset, n_shots=n_shots, model=model, model_config=model_config, tokenizer=tokenizer, compute_ppl=False,
                                                 generate_str=True, metric=metric, test_split='valid')
        filter_set_validation = np.where(np.array(fs_results_validation['score']) == 1)[0]
        set_seed(seed)
        fs_results = n_shot_eval_no_intervention(dataset=dataset, n_shots=n_shots, model=model, model_config=model_config, tokenizer=tokenizer, compute_ppl=False,
                                                 generate_str=True, metric=metric)
        filter_set = np.where(np.array(fs_results['score']) == 1)[0]
    else:
        set_seed(seed+42)
        fs_results_validation = n_shot_eval_no_intervention(dataset=dataset, n_shots=n_shots, model=model, model_config=model_config, tokenizer=tokenizer, compute_ppl=True, test_split='valid')
        filter_set_validation = np.where(np.array(fs_results_validation['clean_rank_list']) == 0)[0]
        set_seed(seed)
        fs_results = n_shot_eval_no_intervention(dataset=dataset, n_shots=n_shots, model=model, model_config=model_config, tokenizer=tokenizer, compute_ppl=True)
        filter_set = np.where(np.array(fs_results['clean_rank_list']) == 0)[0]
    
    args.fs_results_file_name = fs_results_file_name
    with open(fs_results_file_name, 'w') as results_file:
        json.dump(fs_results, results_file, indent=2)

    set_seed(seed)
    # Load or Re-Compute mean_head_activations
    if mean_activations_path is not None and os.path.exists(mean_activations_path):
        mean_activations = torch.load(mean_activations_path)
    elif mean_activations_path is None and os.path.exists(f'{save_path_root}/{dataset_name}_mean_head_activations.pt'):
        mean_activations_path = f'{save_path_root}/{dataset_name}_mean_head_activations.pt'
        mean_activations = torch.load(mean_activations_path)        
    else:
        print("Computing Mean Activations")
        set_seed(seed)
        mean_activations = get_mean_head_activations(dataset, model=model, model_config=model_config, tokenizer=tokenizer, n_icl_examples=n_shots,
                                                     N_TRIALS=n_trials, prefixes=prefixes, separators=separators, filter_set=filter_set_validation)
        args.mean_activations_path = f'{save_path_root}/{dataset_name}_mean_head_activations.pt'
        torch.save(mean_activations, args.mean_activations_path)

    # Load or Re-Compute indirect_effect values
    if indirect_effect_path is not None and os.path.exists(indirect_effect_path):
        indirect_effect = torch.load(indirect_effect_path)
    elif indirect_effect_path is None and os.path.exists(f'{save_path_root}/{dataset_name}_indirect_effect.pt'):
        indirect_effect_path = f'{save_path_root}/{dataset_name}_indirect_effect.pt'
        indirect_effect = torch.load(indirect_effect_path) 
    elif not universal_set:     # Only compute indirect effects if we need to
        print("Computing Indirect Effects")
        set_seed(seed)
        indirect_effect = compute_indirect_effect(dataset, mean_activations, model=model, model_config=model_config, tokenizer=tokenizer, n_shots=n_shots,
                                                  n_trials=n_trials, last_token_only=True, prefixes=prefixes, separators=separators, filter_set=filter_set_validation)
        args.indirect_effect_path = f'{save_path_root}/{dataset_name}_indirect_effect.pt'
        torch.save(indirect_effect, args.indirect_effect_path)
        
    # Compute Function Pointer
    if universal_set:
        fp, top_heads = compute_universal_function_pointer(mean_activations, model, model_config=model_config, n_top_heads=n_top_heads)   
        # Save to different folder:
        save_path_root = f"{args.save_path_root}_universal/{dataset_name}"
        os.makedirs(save_path_root, exist_ok=True)

    else:
        fp, top_heads = compute_function_pointer(mean_activations, indirect_effect, model, model_config=model_config, n_top_heads=n_top_heads)   
    
    if isinstance(eval_edit_layer, int):
        # Single-layer evaluation
        print(f"Running ZS Eval with edit_layer={eval_edit_layer}")
        set_seed(seed)
        if generate_str:
            pred_filepath = f"results/preds/{model_config['name_or_path'].replace('/', '_')}_ZS_intervention_layer{eval_edit_layer}.txt"
            zs_results = n_shot_eval(dataset=dataset, fp_vector=fp, edit_layer=eval_edit_layer, n_shots=0,
                                     model=model, model_config=model_config, tokenizer=tokenizer, filter_set=filter_set,
                                     generate_str=generate_str, metric=metric, pred_filepath=pred_filepath)
        else:
            zs_results = n_shot_eval(dataset=dataset, fp_vector=fp, edit_layer=eval_edit_layer, n_shots=0,
                                    model=model, model_config=model_config, tokenizer=tokenizer, filter_set=filter_set)
        zs_results_file_suffix = f'_editlayer_{eval_edit_layer}.json'   


        print(f"Running {n_shots}-Shot Shuffled Eval")
        set_seed(seed)
        if generate_str:
            pred_filepath = f"results/preds/{model_config['name_or_path'].replace('/', '_')}_{n_shots}shots_shuffled_intervention_layer{eval_edit_layer}.txt"
            fs_shuffled_results = n_shot_eval(dataset=dataset, fp_vector=fp, edit_layer=eval_edit_layer, n_shots=n_shots, 
                                              model=model, model_config=model_config, tokenizer=tokenizer, filter_set=filter_set, shuffle_labels=True,
                                              generate_str=generate_str, metric=metric, pred_filepath=pred_filepath)
        else:
            fs_shuffled_results = n_shot_eval(dataset=dataset, fp_vector=fp, edit_layer=eval_edit_layer, n_shots=n_shots, 
                                              model=model, model_config=model_config, tokenizer=tokenizer, filter_set=filter_set, shuffle_labels=True)
        fs_shuffled_results_file_suffix = f'_editlayer_{eval_edit_layer}.json'   
        
    else:
        # Layer sweep: results are keyed by the edited layer index
        print(f"Running sweep over layers {eval_edit_layer}")
        zs_results = {}
        fs_shuffled_results = {}
        for edit_layer in range(eval_edit_layer[0], eval_edit_layer[1]):
            set_seed(seed)
            if generate_str:
                zs_results[edit_layer] = n_shot_eval(dataset=dataset, fp_vector=fp, edit_layer=edit_layer, n_shots=0, 
                                                    model=model, model_config=model_config, tokenizer=tokenizer, filter_set=filter_set,
                                                    generate_str=generate_str, metric=metric)
            else:
                zs_results[edit_layer] = n_shot_eval(dataset=dataset, fp_vector=fp, edit_layer=edit_layer, n_shots=0, 
                                                    model=model, model_config=model_config, tokenizer=tokenizer, filter_set=filter_set)
            set_seed(seed)
            if generate_str:
                fs_shuffled_results[edit_layer] = n_shot_eval(dataset=dataset, fp_vector=fp, edit_layer=edit_layer, n_shots=n_shots, 
                                                    model=model, model_config=model_config, tokenizer=tokenizer, filter_set = filter_set,
                                                    generate_str=generate_str, metric=metric, shuffle_labels=True)
            else:
                fs_shuffled_results[edit_layer] = n_shot_eval(dataset=dataset, fp_vector=fp, edit_layer=edit_layer, n_shots=n_shots, 
                                                    model=model, model_config=model_config, tokenizer=tokenizer, filter_set = filter_set, shuffle_labels=True)
        zs_results_file_suffix = '_layer_sweep.json'
        fs_shuffled_results_file_suffix = '_layer_sweep.json'


    # Save results to files
    zs_results_file_name = make_valid_path_name(f'{save_path_root}/zs_results' + zs_results_file_suffix)
    args.zs_results_file_name = zs_results_file_name
    with open(zs_results_file_name, 'w') as results_file:
        json.dump(zs_results, results_file, indent=2)
    
    fs_shuffled_results_file_name = make_valid_path_name(f'{save_path_root}/fs_shuffled_results' + fs_shuffled_results_file_suffix)
    args.fs_shuffled_results_file_name = fs_shuffled_results_file_name
    with open(fs_shuffled_results_file_name, 'w') as results_file:
        json.dump(fs_shuffled_results, results_file, indent=2)

    if compute_baseline:
        print(f"Computing model baseline results for {n_shots}-shots")
        baseline_results = compute_dataset_baseline(dataset, model, model_config, tokenizer, n_shots=n_shots, seed=seed)        
    
        baseline_file_name = make_valid_path_name(f'{save_path_root}/model_baseline.json')
        args.baseline_file_name = baseline_file_name
        with open(baseline_file_name, 'w') as results_file:
            json.dump(baseline_results, results_file, indent=2)

    # Write args to file
    args_file_name = make_valid_path_name(f'{save_path_root}/fp_eval_args.txt')
    with open(args_file_name, 'w') as arg_file:
        json.dump(args.__dict__, arg_file, indent=2)