import argparse
import math
import os
import sys
import warnings

import numpy as np
import torch
from tqdm import tqdm

from utils import model_utils, data_utils, evaluate_utils, seed_utils, extract_utils

def build_token_type_map(num_examples=10, tokens_per_example=4):
    """Return the token-type id of every logical prompt position.

    Layout: one position of type 0, then ``tokens_per_example`` positions for
    each type ``1..num_examples``, then one position each of type
    ``num_examples + 1`` and ``num_examples + 2``. With the defaults this is
    the 43-entry map the ablation configurations below are written against.
    NOTE(review): types 1..10 presumably are the in-context demonstrations;
    confirm against ``model_utils``.
    """
    type_map = [0]
    for example_type in range(1, num_examples + 1):
        type_map.extend([example_type] * tokens_per_example)
    type_map.extend([num_examples + 1, num_examples + 2])
    return type_map


def edges_into(target_type, num_examples=10):
    """Return ``(i, target_type, "zero")`` for every example type ``i`` in ``1..num_examples``."""
    return [(src, target_type, "zero") for src in range(1, num_examples + 1)]


def fully_clean_edges(num_examples=10):
    """Return the "fully clean" zero-ablation edge list.

    First every edge ``(i, j, "zero")`` with ``1 <= i < j <= num_examples`` in
    row-major order, then every edge from an example type into type
    ``num_examples + 1``. The order matches the hand-written lists this
    replaces.
    """
    cross = [(src, dst, "zero")
             for src in range(1, num_examples + 1)
             for dst in range(src + 1, num_examples + 1)]
    return cross + edges_into(num_examples + 1, num_examples)


def build_ablation_config(sub_FV_name, top_lh):
    """Return ``(ablate_edges_map, corruption_ablate_edges_map, kv_ablation_tp)``.

    Args:
        sub_FV_name: Name of the sub function-vector configuration.
        top_lh: ``(layer, head)`` pairs ranked by mean causal indirect effect.
            Only the first entry is read, and only by the ``*_extract*``
            configurations.

    Unknown names emit a warning and yield an empty configuration (no
    ablation). This keeps the script's historical fall-through behaviour but
    no longer hides typos.
    """
    clean = fully_clean_edges()
    q_replace = [(None, None, "q_replace")]
    k_replace = [(None, None, "k_replace")]
    v_replace = [(None, None, "v_replace")]
    every_example = list(range(1, 11))

    # Extraction runs: optional clean ablation plus q/k/v extraction at the top head.
    extract_kinds = {
        'fully_clean_extract_all': ("q_extract", "k_extract", "v_extract"),
        'full_model_extract_all': ("q_extract", "k_extract", "v_extract"),
        'fully_clean_q_extract': ("q_extract",),
        'full_model_q_extract': ("q_extract",),
    }
    if sub_FV_name in extract_kinds:
        layer, head = top_lh[0]  # the single highest-CIE head
        extract = [(layer, head, kind) for kind in extract_kinds[sub_FV_name]]
        base = clean if sub_FV_name.startswith('fully_clean') else []
        return base + extract, [], None

    # name -> (ablate_edges_map, corruption_ablate_edges_map, kv_ablation_tp)
    configs = {
        'fully_clean': (clean, [], None),
        'fully_context': (edges_into(11), [], None),
        'full_model': ([], [], None),
        'full_model_q_only_context_examples': (q_replace, [(11, 12, "zero")], None),
        'full_model_q_only_context_xn': (q_replace, edges_into(12), None),
        'full_model_q_all': (q_replace, [], None),
        'fully_clean_q_full_model': (clean + q_replace, [], None),
        'full_model_q_fully_clean': (q_replace, clean, None),
        'fully_clean_k_all_clean_Ex4710': (clean + k_replace, clean, [4, 7, 10]),
        'fully_clean_k_all_clean_am_Ex4710': (clean + k_replace, clean, [1, 2, 3, 5, 6, 8, 9]),
        'fully_clean_k_all_clean_Ex369': (clean + k_replace, clean, [3, 6, 9]),
        'fully_clean_k_all_clean_am_Ex369': (clean + k_replace, clean, [1, 2, 4, 5, 7, 8, 10]),
        'fully_clean_v_all_full_model_Ex4710': (clean + v_replace, [], [4, 7, 10]),
        'full_model_v_all_clean_Ex4710': (v_replace, clean, [4, 7, 10]),
        'fully_clean_v_all_full_model_Ex369': (clean + v_replace, [], [3, 6, 9]),
        'fully_clean_v_all_full_model': (clean + v_replace, [], every_example),
        'full_model_v_all_clean_Ex369': (v_replace, clean, [3, 6, 9]),
        'full_model_v_all_clean': (v_replace, clean, every_example),
    }
    if sub_FV_name not in configs:
        warnings.warn(f"Unknown sub_FV_name {sub_FV_name!r}: no ablation edges will be applied.")
        return [], [], None

    ablate, corrupt, kv = configs[sub_FV_name]
    # Return fresh lists so callers never share/mutate the table's objects.
    return list(ablate), list(corrupt), (list(kv) if kv is not None else None)


def vectors_to_save(sub_FV_name):
    """Return which of ``'q'``, ``'k'``, ``'v'`` an extraction run saves, or ``()`` if it is not one.

    Configured q-only extraction names end in ``q_extract``. The historical
    ``extract_q`` spelling is still accepted.
    """
    if 'extract_all' in sub_FV_name:
        return ('q', 'k', 'v')
    if 'q_extract' in sub_FV_name or 'extract_q' in sub_FV_name:
        return ('q',)
    return ()


if __name__ == "__main__":
    # Gated Hugging Face models may require logging in first, e.g.
    # huggingface_hub.login(token=os.getenv("HF_TOKEN")).

    parser = argparse.ArgumentParser()

    parser.add_argument('--model_name', type=str, required=True, help="Name of the model to load")
    parser.add_argument('--task_name', type=str, required=False, default='country-capital',
                        help="Name of the task to evaluate")
    parser.add_argument('--sub_FV_name', type=str, required=False, default='Ex1',
                        help='Name of the sub function vector')
    # Was `required=True`, which made the default unreachable.
    parser.add_argument('--token_types', type=int, required=False, default=6, help="Number of token types")
    parser.add_argument('--corrupted_dataset_name', type=str, required=False,
                        help="Name of the corrupted dataset for edge ablation")
    parser.add_argument('--train_dataset_name', type=str, required=False, help="Name of the training dataset")
    parser.add_argument('--inject_dataset_name', type=str, required=False,
                        help="Name of the dataset for FV intervention")
    # The script cannot run without this file, so fail early at parse time.
    parser.add_argument('--CIE_path', type=str, required=True,
                        help="Path (relative to the task's log directory) to the mean causal indirect effects file")
    parser.add_argument('--corruption_with_ablation', action='store_true',
                        help="Whether to apply corruption along with ablation")
    parser.add_argument('--device', type=str, required=False, default='cuda', help="Device to use for model inference")
    parser.add_argument('--seed', type=int, required=False, default=42, help="Random seed for reproducibility")
    parser.add_argument('--scaled_factors', type=float, nargs='+', required=False, default=[10.0, 15.0, 25.0],
                        help="FV scaling factors to evaluate (default tuned for Gemma2; Llama 3 used 1 2 4 12 24)")

    args = parser.parse_args()
    model_name = args.model_name
    task_name = args.task_name
    sub_FV_name = args.sub_FV_name
    TOKEN_TYPES = args.token_types
    corrupted_dataset_name = args.corrupted_dataset_name
    train_dataset_name = args.train_dataset_name
    inject_dataset_name = args.inject_dataset_name
    CIE_path = args.CIE_path
    corruption_with_ablation = args.corruption_with_ablation
    device = args.device
    seed = args.seed
    scaled_factor = args.scaled_factors

    # Last path component, so bare names ("gpt2") work as well as "org/model".
    model_name_split = model_name.rstrip('/').split('/')[-1]

    # Resolve the device once and use it for both default tensors and the model.
    if device.startswith('cuda') and not torch.cuda.is_available():
        warnings.warn(f"Device {device!r} requested but CUDA is not available; falling back to CPU.")
        device = 'cpu'
    torch.set_default_device(device)
    # Inference only: no gradients needed.
    torch.set_grad_enabled(False)

    seed_utils.set_seed(seed)

    corrupted_dataset = data_utils.load_dataset(corrupted_dataset_name)
    train_dataset = data_utils.load_dataset(train_dataset_name)
    inject_dataset = data_utils.load_dataset(inject_dataset_name)

    DATA_SAVE_DIR = f"./logs/{model_name_split}-10shot/{task_name}"
    DATA_READ_DIR = f"./logs/{model_name_split}/{task_name}"
    if 'ambiguous-2' in task_name:
        # The '-2' variant reads CIEs computed for the base ambiguous task.
        DATA_READ_DIR = f"./logs/{model_name_split}/{task_name.replace('-2', '')}"
    os.makedirs(DATA_SAVE_DIR, exist_ok=True)
    os.makedirs(DATA_READ_DIR, exist_ok=True)

    model_wrapper = model_utils.load_model_and_tokenizer(model_name, device=device)

    # Load mean causal indirect effects, shape (layers, heads), and rank heads.
    CIE_path = os.path.join(DATA_READ_DIR, CIE_path)
    if not os.path.exists(CIE_path):
        raise FileNotFoundError(f"{CIE_path} does not exist. Please provide a valid path.")
    mean_causal_indirect_effects = torch.load(CIE_path, map_location=device)
    h_shape = mean_causal_indirect_effects.shape
    flat_cie = mean_causal_indirect_effects.reshape(-1)
    _, topk_inds = torch.topk(flat_cie, k=min(20, flat_cie.numel()), largest=True)
    layer_idxs, head_idxs = np.unravel_index(topk_inds.cpu().numpy(), h_shape)
    top_lh = list(zip(layer_idxs, head_idxs))

    token_type_map = build_token_type_map()
    ablate_edges_map, corruption_ablate_edges_map, kv_ablation_tp = build_ablation_config(sub_FV_name, top_lh)

    model_wrapper.set_ablation_task_config(TOKEN_TYPES,
                                           ablate_edges_map=ablate_edges_map,
                                           token_type_map=token_type_map,
                                           kv_ablation_tp=kv_ablation_tp,
                                           corruption_with_ablation=corruption_with_ablation,
                                           corruption_ablate_edges_map=corruption_ablate_edges_map)
    model_wrapper.break_into()

    # Per-prompt head activations; unpacked below as
    # (batch_size, num_layers, num_heads, logical_seq_len, head_dim).
    _, per_prompt_head_activation, ablate_success_rate = extract_utils.get_ablated_mean_head_activation(
        model_wrapper, train_dataset, corrupted_dataset, sub_FV_name)

    # Note: 'full' also matches every 'fully_*' configuration.
    if 'full' in sub_FV_name:
        attn_score_path = os.path.join(DATA_SAVE_DIR, f"attn_scores_{sub_FV_name}_{corrupted_dataset_name}.pt")
        torch.save(model_wrapper.attention_scores, attn_score_path)

    # Extraction runs only save the requested q/k/v vectors and stop.
    vector_kinds = vectors_to_save(sub_FV_name)
    if vector_kinds:
        for kind in vector_kinds:
            vectors = getattr(model_wrapper, f"{kind}_vectors")
            if vectors is not None:
                vectors_path = os.path.join(
                    DATA_SAVE_DIR,
                    f"per_prompt_{kind}_vectors_{sub_FV_name}_{corrupted_dataset_name}_{train_dataset_name}.pt")
                torch.save(vectors, vectors_path)
        sys.exit(0)

    # The 5-way unpack also asserts the activation tensor is rank 5.
    batch_size, _, _, _, _ = per_prompt_head_activation.shape
    hidden_size = model_wrapper.model_config['hidden_size']  # residual stream width

    FV_path = os.path.join(DATA_SAVE_DIR,
                           f"per_prompt_function_vector_{train_dataset_name}_{corrupted_dataset_name}_{sub_FV_name}.pt")
    function_vectors_storage = torch.zeros((batch_size, 1, hidden_size), device=device)

    for i in tqdm(range(batch_size), desc="Extracting per prompt function vector", leave=False):
        function_vector, _ = extract_utils.get_function_vector(model_wrapper, per_prompt_head_activation[i],
                                                               mean_causal_indirect_effects, top_k=30)
        function_vectors_storage[i] = function_vector

    torch.save(function_vectors_storage, FV_path)

    if 'full' not in sub_FV_name:
        sys.exit(0)

    # Evaluate FV intervention effects with the original forward function restored.
    model_wrapper.break_out()

    num_layers = model_wrapper.model_config['n_layers']
    # Intervene on the first ceil(n_layers / 3) layers (Gemma2 setting;
    # earlier runs used math.ceil(2 * num_layers / 3) + 1).
    intervention_limit = math.ceil(num_layers / 3)
    intervened_layer_list = list(range(intervention_limit))

    intervened_df, clean_success_rate = evaluate_utils.evaluate_intervened_success_rate_on_per_prompt(
        model_wrapper,
        scaled_factor,
        function_vectors_storage,
        inject_dataset,
        intervened_layer_list=intervened_layer_list)

    intervened_df_test_dataset_save_path = os.path.join(
        DATA_SAVE_DIR,
        f'per_prompt_intervened_success_rate_{train_dataset_name}_{corrupted_dataset_name}_{sub_FV_name}.csv')

    intervened_df.to_csv(
        intervened_df_test_dataset_save_path,
        index=False,
        encoding='utf-8')