
import torch
from brainmodels.utils import get_images, preprocess_images, get_activations
from brainmodels.models import EncoderBM, build_brainmodel
from brainmodels import models
from brainmodels.params import Params
import os
import numpy as np


# Run on the GPU when PyTorch reports one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
from fgsm_small import (
    load_encoder_and_images,
    load_brain_data,
    prepare_activations,
    split_data,
    build_brainmodel_from_linear_model
)

# Encoder name -> name of the module whose activations are read out for that encoder.
layer_map = dict([
    ('RN50-0', 'layer4.2.conv2'),
    ('densenet201_imagenet', 'denseblock3.denselayer40.conv1'),
    ('mobilenet_v2', '16.conv.3'),
    ('CORnet_RT', 'IT'),
    ('inception_v3', 'Mixed_7a'),
    ('vgg16', 'features.30'),
    ('alexnet', 'features.10'),
    ('squeezenet1_1', 'features.11.expand3x3_activation'),
    ('dinov2', 'blocks.11.mlp.fc1'),
    ('dreamsim_vitb32', 'model.blocks.8.attn.proj'),
    ('google_vit', 'encoder.layer.12.output.dense'),
    ('nomic', 'layers.6.mlp.fc12'),
    ('CLIP-RN50', 'layer4.0.conv1'),
])
# Every RN50-robust-* variant reads out from the same layer as RN50-0.
layer_map.update(
    (f'RN50-robust-{eps}', layer_map['RN50-0'])
    for eps in ('8', '4', '2', '1', '0.5')
)





# Stage 1: for every encoder, fit per-voxel ridge encoding models and keep each
# voxel's cross-validated Pearson r so voxels can later be ranked across encoders.
from scipy.stats import pearsonr
from sklearn.linear_model import Ridge
from sklearn.model_selection import KFold
from tqdm import tqdm

voxel_scores_final = []


modellist = ['RN50-0', 'densenet201_imagenet', 'mobilenet_v2', 'CORnet_RT', 'inception_v3', 'vgg16', 'alexnet', 'squeezenet1_1', 'dinov2', 'dreamsim_vitb32', 'google_vit', 'nomic', 'CLIP-RN50', 'RN50-robust-8', 'RN50-robust-4', 'RN50-robust-2', 'RN50-robust-1', 'RN50-robust-0.5']

# The brain data, the train/test split and the CV settings do not depend on the
# encoder, so they are prepared once instead of once per loop iteration.
brain_data_path = 'nsd_processed/s1_FFA_t7.pt'
brain_data_file = brain_data_path.split('/')[-1]
brain_data = load_brain_data(brain_data_path)

# Test images are the indices stored in random_image_idcs.npy; all others are train.
image_idcs = np.load('random_image_idcs.npy', allow_pickle=True)
not_image_idcs = np.setdiff1d(np.arange(brain_data.shape[0]), image_idcs)

brain_data_train = brain_data[not_image_idcs]
brain_data_test = brain_data[image_idcs]
num_voxels = brain_data_train.shape[1]

num_folds = 5
num_alphas = 20
alphas = np.logspace(-1, 6, num_alphas)
# A fixed integer seed makes every kf.split(...) call yield the same folds.
kf = KFold(n_splits=num_folds, shuffle=True, random_state=0)

for modelname in modellist:
    encoder_name = modelname
    layer_name = layer_map[encoder_name]
    image_path = None
    encoder, preprocess, preprocessed_images, params = load_encoder_and_images(
        encoder_name, layer_name, device, image_path
    )

    activations_dir = f'activations/{encoder_name}_{layer_name}'
    activations_file = os.path.join(activations_dir, brain_data_file)

    # Test for the cache *file*: the directory can exist without this file in it
    # (for example after an interrupted run), and loading would then fail.
    if os.path.exists(activations_file):
        print(f'Activations already exist in {activations_dir}. Loading...')
        activations_flattened = torch.load(activations_file, weights_only=False)
        print(f'Loaded activations from {activations_dir}')
    else:
        print(f'Activations do not exist in {activations_dir}. Computing...')
        activations, activations_flattened = prepare_activations(preprocessed_images, encoder, params)
        os.makedirs(activations_dir, exist_ok=True)
        torch.save(activations_flattened, activations_file)
        print(f'Saved activations to {activations_file}')

    activations_train = activations_flattened[not_image_idcs]
    activations_test = activations_flattened[image_idcs]

    num_features = activations_train.shape[1]

    best_alpha_file = f'hyperparameters/{encoder_name}_{layer_name}_{brain_data_file}_l2_best_alpha.pt'

    if os.path.exists(best_alpha_file):
        # Per-voxel alphas are cached: only re-measure the CV score at those alphas.
        best_alphas = torch.load(best_alpha_file, weights_only=False)
        print(f'Loaded best alpha from {best_alpha_file}')
        voxel_scores_all = np.zeros((num_voxels, num_folds))

        for fold_idx, (train_idx, val_idx) in enumerate(kf.split(activations_train)):
            X_train, X_val = activations_train[train_idx], activations_train[val_idx]
            Y_train, Y_val = brain_data_train[train_idx], brain_data_train[val_idx]

            for v in range(num_voxels):
                model = Ridge(alpha=best_alphas[v], max_iter=5000)
                model.fit(X_train, Y_train[:, v])
                Y_pred = model.predict(X_val)
                r, _ = pearsonr(Y_val[:, v], Y_pred)
                voxel_scores_all[v, fold_idx] = r

        # Collapse folds: one mean score per voxel (1-D from here on).
        voxel_scores_all = voxel_scores_all.mean(axis=1)
        best_alpha_indices = np.array([np.argmin(np.abs(alphas - a)) for a in best_alphas])
    else:
        print('No best alpha file found! Performing cross-validation.')
        # Accumulates the summed fold scores for every (voxel, alpha) pair.
        voxel_scores_all = np.zeros((num_voxels, num_alphas))

        for train_idx, val_idx in kf.split(activations_train):
            X_train, X_val = activations_train[train_idx], activations_train[val_idx]
            Y_train, Y_val = brain_data_train[train_idx], brain_data_train[val_idx]

            for a_idx, alpha in enumerate(alphas):
                # One multi-output fit per alpha, scored voxel by voxel.
                model = Ridge(alpha=alpha, max_iter=5000)
                model.fit(X_train, Y_train)
                Y_pred = model.predict(X_val)
                for v in range(num_voxels):
                    r, _ = pearsonr(Y_val[:, v], Y_pred[:, v])
                    voxel_scores_all[v, a_idx] += r

        voxel_scores_all /= num_folds
        best_alpha_indices = np.argmax(voxel_scores_all, axis=1)
        best_alphas = alphas[best_alpha_indices]
        # The output directory is not guaranteed to exist on a fresh checkout.
        os.makedirs(os.path.dirname(best_alpha_file), exist_ok=True)
        torch.save(best_alphas, best_alpha_file)
        print(f'Saved best alphas to {best_alpha_file}')

    # Refit every voxel on the full training set with its selected alpha.
    final_weights = np.zeros((num_voxels, num_features))
    final_intercepts = np.zeros(num_voxels)

    for v in tqdm(range(num_voxels)):
        alpha = best_alphas[v]
        model = Ridge(alpha=alpha, max_iter=5000)
        model.fit(activations_train, brain_data_train[:, v])
        final_weights[v] = model.coef_
        final_intercepts[v] = model.intercept_

    if voxel_scores_all.ndim == 2:
        # Fresh CV: take each voxel's score at its own best alpha.
        voxel_scores = voxel_scores_all[np.arange(num_voxels), best_alpha_indices]
    else:
        voxel_scores = voxel_scores_all  # already mean scores per voxel
    voxel_scores_final.append(voxel_scores)




# Average each voxel's cross-validated score over all encoders.
avg_models_voxel_scores_final = np.mean(voxel_scores_final, axis=0)

# np.argsort places NaN last, so a voxel whose averaged score is NaN would be
# selected as a "top" voxel. Rank NaNs lowest instead; the averaged scores
# themselves are left untouched.
ranking_scores = np.where(
    np.isnan(avg_models_voxel_scores_final), -np.inf, avg_models_voxel_scores_final
)
# Best-first indices of the 20 highest-scoring voxels. The [::-1] view has a
# negative stride, so it is materialised as a contiguous array before indexing.
top20_voxel_indices = np.ascontiguousarray(np.argsort(ranking_scores)[-20:][::-1])


# Regression target: mean response over the selected voxels, as a column vector.
brain_data_train_new = brain_data_train[:, top20_voxel_indices].mean(axis=1).reshape(-1, 1)
activations_train_new = activations_train
brain_data_test_new = brain_data_test[:, top20_voxel_indices].mean(axis=1).reshape(-1, 1)
activations_test_new = activations_test

# now, train 


# Stage 2: for every encoder, cross-validate alpha and fit one ridge model that
# predicts the mean response of the shared top-20 voxels.
voxel_scores_final_2 = []
ridgemodels = []
modellist = [
    'RN50-0', 'densenet201_imagenet', 'mobilenet_v2', 'CORnet_RT',
    'inception_v3', 'vgg16', 'alexnet', 'squeezenet1_1', 'dinov2',
    'dreamsim_vitb32', 'google_vit', 'nomic', 'CLIP-RN50',
    'RN50-robust-8', 'RN50-robust-4', 'RN50-robust-2', 'RN50-robust-1',
    'RN50-robust-0.5',
]
for modelname in modellist:
    encoder_name = modelname
    layer_name = layer_map[encoder_name]
    image_path = None
    encoder, preprocess, preprocessed_images, params = load_encoder_and_images(
        encoder_name, layer_name, device, image_path
    )
    brain_data_path = 'nsd_processed/s1_FFA_t7.pt'
    brain_data = load_brain_data(brain_data_path)

    activations_dir = f'activations/{encoder_name}_{layer_name}'
    activations_file = os.path.join(activations_dir, brain_data_path.split('/')[-1])

    # Compute and cache the activations only when no cached file exists yet.
    if not os.path.exists(activations_file):
        activations, activations_flattened = prepare_activations(preprocessed_images, encoder, params)
        os.makedirs(activations_dir, exist_ok=True)
        torch.save(activations_flattened, activations_file)
    else:
        activations_flattened = torch.load(activations_file, weights_only=False)

    # Training rows are all images not listed in random_image_idcs.npy.
    image_idcs = np.load('random_image_idcs.npy', allow_pickle=True)
    every_image_idx = np.arange(brain_data.shape[0])
    not_image_idcs = np.setdiff1d(every_image_idx, image_idcs)

    brain_data_train = brain_data[not_image_idcs]
    activations_train = activations_flattened[not_image_idcs]

    # use same top20_voxel_indices across models
    selected_voxels = brain_data_train[:, top20_voxel_indices.copy()]
    brain_data_train_new = selected_voxels.mean(axis=1).reshape(-1, 1)

    num_folds = 5
    num_alphas = 20
    alphas = np.logspace(-1, 6, num_alphas)
    kf = KFold(n_splits=num_folds, shuffle=True, random_state=0)

    # Mean validation Pearson r over the folds, one entry per candidate alpha.
    cv_scores = np.zeros(num_alphas)

    for a_idx, alpha in enumerate(alphas):
        scores = []
        for train_idx, val_idx in kf.split(activations_train):
            X_train = activations_train[train_idx]
            X_val = activations_train[val_idx]
            y_train = brain_data_train_new[train_idx]
            y_val = brain_data_train_new[val_idx]
            model = Ridge(alpha=alpha, max_iter=5000)
            model.fit(X_train, y_train)
            y_pred = model.predict(X_val)
            scores.append(pearsonr(y_val.flatten(), y_pred.flatten())[0])
        cv_scores[a_idx] = np.mean(scores)

    best_alpha = alphas[np.argmax(cv_scores)]
    best_score = np.max(cv_scores)

    # Refit on the whole training set with the selected alpha and keep the model.
    final_model = Ridge(alpha=best_alpha, max_iter=5000)
    final_model.fit(activations_train, brain_data_train_new)
    ridgemodels.append(final_model)
    voxel_scores_final_2.append(best_score)
    print(f"{modelname}: best alpha = {best_alpha:.4f}, CV Pearson r = {best_score:.4f}")


import torch.nn as nn
class EncoderBM_new(nn.Module):
    """Encoder followed by a read-out ("brain model") on one hooked layer.

    A forward hook stores the output of the layer named ``params.layer_name``
    each time the encoder runs; ``forward`` feeds that stored activation to
    ``brainmodel``. A backward hook stores the gradient at the same layer.

    Args:
        encoder: Object exposing ``encode_image`` and, depending on
            ``params.encoder``, ``named_modules`` either directly or on its
            ``visual`` / ``model`` / ``encoder`` attribute.
        brainmodel: Callable module applied to the processed activation.
        params: Object with ``encoder`` (encoder name), ``layer_name`` and
            ``mapping`` attributes.

    Raises:
        ValueError: If ``params.encoder`` is not a supported encoder name.
        KeyError: If the layer name is not found among the encoder's modules.
    """

    # Encoders whose layer is looked up on the encoder object itself.
    _DIRECT_ENCODERS = frozenset({
        "densenet201_imagenet",
        "mobilenet_v2",
        "CORnet_RT",
        "inception_v3",
        "inception_v2",
        "vgg16",
        "alexnet",
        "squeezenet1_1",
        "bagnet33",
        "dinov2",
        "blip2",
        "nomic",
    })

    def __init__(self, encoder, brainmodel, params):
        super().__init__()
        self.encoder = encoder
        self.brainmodel = brainmodel
        self.layer_name = params.layer_name
        self.mapping = params.mapping
        self.activation = {}
        self.gradients = {}

        layer = self._resolve_layer(params.encoder)

        layer.register_forward_hook(self.get_activation(self.layer_name, self.activation))
        # NOTE(review): register_backward_hook is deprecated in recent PyTorch in
        # favour of register_full_backward_hook. It is kept because the two do not
        # have identical semantics -- confirm before migrating.
        layer.register_backward_hook(self.get_gradient(self.layer_name, self.gradients))

    def _resolve_layer(self, encoder_name):
        """Return the module to hook for ``encoder_name``.

        Raises ValueError for an unsupported encoder name rather than leaving
        the layer undefined.
        """
        if encoder_name == "CLIP-RN50":
            root, name = self.encoder.visual, self.layer_name
        elif encoder_name.startswith(("RN50", "L2-RN50")):
            root, name = self.encoder.model, self.layer_name
        elif encoder_name == "google_vit":
            # The lookup starts at encoder.encoder, so drop the first "encoder.".
            root = self.encoder.encoder
            name = self.layer_name.replace("encoder.", "", 1)
        elif encoder_name == "dreamsim_vitb32":
            # The lookup starts at encoder.model, so drop the first "model.".
            root = self.encoder.model
            name = self.layer_name.replace("model.", "", 1)
        elif encoder_name in self._DIRECT_ENCODERS:
            root, name = self.encoder, self.layer_name
        else:
            raise ValueError(f"Unsupported encoder: {encoder_name!r}")
        return dict(root.named_modules())[name]

    def get_activation(self, name, activations):
        """Return a forward hook that stores the layer output in ``activations[name]``."""
        def hook(model, input, output):
            # For tuple outputs only the first element is kept.
            if isinstance(output, tuple):
                activations[name] = output[0]
            else:
                activations[name] = output

        return hook

    def get_gradient(self, name, gradients):
        """Return a backward hook that stores ``grad_output[0]`` in ``gradients[name]``."""
        def hook(model, grad_input, grad_output):
            gradients[name] = grad_output[0]
        return hook

    def forward(self, x):
        """Run the encoder on ``x`` and map the hooked activation through ``brainmodel``.

        The return value of ``encoder.encode_image`` is discarded; the call is
        made so that the forward hook captures the layer's activation.
        """
        _ = self.encoder.encode_image(x)
        x = self.activation[self.layer_name]
        if self.mapping == "cnn":
            # The "cnn" mapping receives the activation without flattening.
            x = x.to(torch.float32)
        elif x.ndim == 3:
            x = x.mean(dim=1).to(torch.float32)  # for ViT token averaging
        else:
            x = torch.flatten(x, start_dim=1).to(torch.float32)

        x = self.brainmodel(x)
        return x


# Wrap every fitted ridge model as a read-out and pair it with its encoder.
modellist = ['RN50-0', 'densenet201_imagenet', 'mobilenet_v2', 'CORnet_RT', 'inception_v3', 'vgg16', 'alexnet', 'squeezenet1_1', 'dinov2', 'dreamsim_vitb32', 'google_vit', 'nomic', 'CLIP-RN50', 'RN50-robust-8', 'RN50-robust-4', 'RN50-robust-2', 'RN50-robust-1', 'RN50-robust-0.5']


encoderbms = {}

for idx, encoder_name in enumerate(modellist):
    # Force coef_ to a 2-D (1, n_features) array before building the read-out.
    ridgemodels[idx].coef_ = ridgemodels[idx].coef_.reshape(1, -1)
    # NOTE(review): after the reshape coef_.shape[0] is always 1 -- confirm that
    # this is the value build_brainmodel_from_linear_model expects here.
    brainmodel = build_brainmodel_from_linear_model(
        ridgemodels[idx],
        ridgemodels[idx].coef_.shape[0],
        brain_data_train_new.shape[1],
        device
    )
    # Load each encoder with its own layer, not the name left over from an
    # earlier loop.
    layer_name = layer_map[encoder_name]
    encoder, preprocess, preprocessed_images, params = load_encoder_and_images(
        encoder_name, layer_name, device, image_path
    )
    params.layer_name = layer_name
    encoderbms[encoder_name] = EncoderBM_new(encoder, brainmodel, params)


def fgsm_single_direction(model, image, device, epsilon, al, iterations, voxel_idx, direction):
    """Run an iterative sign-gradient attack that pushes the model output one way.

    Args:
        model: Module mapping a ``(1, 3, H, W)`` batch to an output that
            squeezes to a scalar.
        image: Single image tensor of shape ``(3, H, W)``.
        device: Device on which the attack runs.
        epsilon: Maximum absolute value of any perturbation entry.
        al: Step size applied to the gradient sign on every iteration.
        iterations: Number of gradient steps.
        voxel_idx: Unused; kept so existing call sites keep working.
        direction: ``'minimize'`` lowers the output; any other value raises it.

    Returns:
        Tuple ``(perturbed_image, perturbed_response, old_response, delta)``.
        All four tensors are detached from the autograd graph.
    """
    image = image.to(device)
    model = model.to(device)
    test_image_pp = image.unsqueeze(0)  # Shape: (1, 3, H, W)
    delta = torch.zeros_like(test_image_pp, requires_grad=True)

    # Per-channel bounds that raw pixel values 0 and 1 map to after
    # normalisation (constants are presumably the ImageNet mean/std -- confirm).
    channel_mean = torch.tensor([0.485, 0.456, 0.406])
    channel_std = torch.tensor([0.229, 0.224, 0.225])
    clamp_min = ((torch.tensor([0.0, 0.0, 0.0]) - channel_mean) / channel_std).view(1, 3, 1, 1).to(device)
    clamp_max = ((torch.tensor([1.0, 1.0, 1.0]) - channel_mean) / channel_std).view(1, 3, 1, 1).to(device)

    # The baseline response needs no gradient, so no graph is built for it.
    with torch.no_grad():
        old_response = model(test_image_pp).squeeze()

    for _ in range(iterations):
        input_ = torch.clamp(test_image_pp + delta, clamp_min, clamp_max)
        output = model(input_).squeeze()
        loss = -output if direction == 'minimize' else output
        # Differentiate with respect to delta only, so no .grad is accumulated
        # on the model's parameters.
        (grad,) = torch.autograd.grad(loss, delta)
        with torch.no_grad():
            # Ascend the loss, then project back into the epsilon box.
            delta.add_(al * grad.sign())
            delta.clamp_(-epsilon, epsilon)

    with torch.no_grad():
        perturbed_image = torch.clamp(test_image_pp + delta, clamp_min, clamp_max)
        perturbed_response = model(perturbed_image).squeeze()

    return perturbed_image, perturbed_response, old_response, delta.detach()


def fgsm(model, preprocess, image, label, device, epsilon, al, iterations, voxel_idx):
    """Attack in both directions and return the one that moves the response most.

    Runs ``fgsm_single_direction`` once with ``direction='minimize'`` and once
    with ``direction='maximize'``, then keeps the run with the larger absolute
    change in response. A tie goes to the minimising run.

    Args:
        model: Model under attack, passed through to ``fgsm_single_direction``.
        preprocess: Unused; kept so existing call sites keep working.
        image: Single image tensor; a batch axis is added in the returned copy.
        label: Value exposing ``.item()``; returned as a Python scalar.
        device: Device on which the attack runs.
        epsilon: Perturbation bound, passed through.
        al: Step size, passed through.
        iterations: Number of attack steps, passed through.
        voxel_idx: Passed through to ``fgsm_single_direction``.

    Returns:
        Tuple ``(label, old_response, perturbed_response, original_image,
        perturbed_image)``. The first three are Python scalars and the last two
        are tensors with a leading batch axis.
    """
    perturbed_image_min, perturbed_response_min, old_response_min, _ = fgsm_single_direction(
        model, image, device, epsilon, al, iterations, voxel_idx, direction='minimize'
    )
    perturbed_image_max, perturbed_response_max, old_response_max, _ = fgsm_single_direction(
        model, image, device, epsilon, al, iterations, voxel_idx, direction='maximize'
    )

    diff_min = abs(perturbed_response_min.item() - old_response_min.item())
    diff_max = abs(perturbed_response_max.item() - old_response_max.item())

    if diff_min >= diff_max:
        perturbed_image = perturbed_image_min
        perturbed_response = perturbed_response_min
        old_response = old_response_min
    else:
        perturbed_image = perturbed_image_max
        perturbed_response = perturbed_response_max
        old_response = old_response_max

    return label.item(), old_response.item(), perturbed_response.item(), image.unsqueeze(0), perturbed_image



# Fixed-budget sweep: attack every model with the same epsilon and measure how
# much each resulting perturbation changes every other model's response.

# Per-channel bounds that raw pixel values 0 and 1 map to after normalisation
# (constants are presumably the ImageNet mean/std -- confirm).
CLAMP_MIN = (torch.tensor([0.0, 0.0, 0.0]) - torch.tensor([0.485, 0.456, 0.406])) / torch.tensor([0.229, 0.224, 0.225])
CLAMP_MAX = (torch.tensor([1.0, 1.0, 1.0]) - torch.tensor([0.485, 0.456, 0.406])) / torch.tensor([0.229, 0.224, 0.225])
CLAMP_MIN = CLAMP_MIN.view(1, 3, 1, 1).to(device)
CLAMP_MAX = CLAMP_MAX.view(1, 3, 1, 1).to(device)

encoderbms_all = encoderbms.copy()

model_names = list(encoderbms_all.keys())
n = len(model_names)
num_images = 20

epsilon = 3 / 255
al = 3 / 255
iterations = 1

transfer_matrices = []

# NOTE(review): preprocessed_images is whatever the most recent
# load_encoder_and_images call returned, so every model is attacked on that
# encoder's preprocessing -- confirm this is intended.
for idx, image_idx in enumerate(image_idcs):
    image = preprocessed_images[image_idx].to(device)
    voxel_idx = 0
    transfer_matrix = np.zeros((n, n))
    deltas = []

    # Step 1: generate delta_i for each model i
    for name in model_names:
        model = encoderbms_all[name]
        _, old_r, pert_r, orig_img, pert_img = fgsm(
            model, None, image, brain_data_test_new[idx, voxel_idx], device, epsilon, al, iterations, voxel_idx
        )
        deltas.append((pert_img - orig_img).detach())

    # Step 2: apply each delta_i to each model j. The unperturbed image is the
    # same for every source model, so the baseline response of each target
    # model is computed once rather than once per source.
    base = image.unsqueeze(0)
    with torch.no_grad():
        for j, name in enumerate(model_names):
            model_j = encoderbms_all[name].to(device)
            old = model_j(base)[0, voxel_idx].item()
            for i, delta in enumerate(deltas):
                perturbed = torch.clamp(base + delta.to(device), CLAMP_MIN, CLAMP_MAX)
                new = model_j(perturbed)[0, voxel_idx].item()
                transfer_matrix[i, j] = abs(new - old)

    transfer_matrices.append(transfer_matrix)

transfer_matrices_array = np.stack(transfer_matrices)  # shape: (len(image_idcs), n, n)
average_transfer_matrix = transfer_matrices_array.mean(axis=0)


import seaborn as sns
from scipy.cluster.hierarchy import linkage, leaves_list
import matplotlib.pyplot as plt


# Cluster the rows of the averaged matrix (average linkage, Euclidean distance)
# and take the dendrogram leaf order.
link = linkage(average_transfer_matrix, method='average', metric='euclidean')
order = leaves_list(link)

# Apply the same permutation to rows, columns and labels.
sorted_matrix = average_transfer_matrix[np.ix_(order, order)]
sorted_names = [model_names[leaf] for leaf in order]

plt.figure(figsize=(12, 10))
# The matrix is transposed for plotting: columns are sources (i), rows are targets (j).
heatmap_options = dict(
    xticklabels=sorted_names,
    yticklabels=sorted_names,
    annot=True,
    fmt=".2f",
    annot_kws={"size": 10},
    cbar_kws={"label": "|Δ response|"},
    cmap='bone_r',
)
sns.heatmap(sorted_matrix.T, **heatmap_options)
plt.xticks(rotation=45, ha="right", fontsize=10)
plt.yticks(fontsize=10)
plt.xlabel("Source model (i)", fontsize=12)
plt.ylabel("Target model (j)", fontsize=12)
plt.title("Transferability matrix of FGSM attacks (|Δ response|)", fontsize=14)
plt.tight_layout()
plt.show()


# Persist the averaged matrix and the per-image matrices it was averaged from.
np.save('average_transfer_matrix.npy', average_transfer_matrix)
np.save('average_transfer_matrix_DATA.npy', transfer_matrices_array)


# # # what are the image indices?
# # print("Image indices used for FGSM:", image_idcs)

# # np.load('random_image_idcs.npy')

# # what are the voxel indices?
# print("Top 20 voxel indices used for FGSM:", top20_voxel_indices)

# np.load('voxel_indices.npy')


import torch
import numpy as np
from tqdm import tqdm

# Adaptive sweep: for each source model, grow the attack strength until its own
# response changes by at least `target_change`, then measure how that
# perturbation transfers to every other model.

# Per-channel bounds that raw pixel values 0 and 1 map to after normalisation
# (constants are presumably the ImageNet mean/std -- confirm).
CLAMP_MIN = (torch.tensor([0.0, 0.0, 0.0]) - torch.tensor([0.485, 0.456, 0.406])) / torch.tensor([0.229, 0.224, 0.225])
CLAMP_MAX = (torch.tensor([1.0, 1.0, 1.0]) - torch.tensor([0.485, 0.456, 0.406])) / torch.tensor([0.229, 0.224, 0.225])
CLAMP_MIN = CLAMP_MIN.view(1, 3, 1, 1).to(device)
CLAMP_MAX = CLAMP_MAX.view(1, 3, 1, 1).to(device)

model_names = list(encoderbms_all.keys())
n = len(model_names)
num_images = 3
voxel_idx = 0
target_change = 0.75
transfer_matrices = []

for idx, image_idx in enumerate(tqdm(image_idcs[:num_images])):
    image = preprocessed_images[image_idx].to(device)
    transfer_matrix = np.zeros((n, n))
    deltas = []

    # Step 1: find the delta_i for each model that changes its response by at
    # least target_change
    for name in model_names:
        model = encoderbms_all[name]
        for steps in range(1, 256):
            # Budget, step size and iteration count all grow with `steps`.
            epsilon = steps / 255
            al = steps / 255
            _, old_r, pert_r, orig_img, pert_img = fgsm(
                model, None, image, brain_data_test_new[idx, voxel_idx], device, epsilon, al, steps, voxel_idx
            )
            change = abs(pert_r - old_r)
            if change >= target_change:
                break
        else:
            print('not reached for model ', name, 'with steps', steps)
        # Always take the perturbation from this model's most recent attack.
        # If the target was never reached, that is its strongest attempt rather
        # than a delta left over from another model.
        deltas.append((pert_img - orig_img).detach())

    # Step 2: apply each delta_i to each model_j. The unperturbed image is the
    # same for every source model, so the baseline response of each target
    # model is computed once rather than once per source.
    base = image.unsqueeze(0)
    with torch.no_grad():
        for j, name in enumerate(model_names):
            model_j = encoderbms_all[name].to(device)
            old = model_j(base)[0, voxel_idx].item()
            for i, delta in enumerate(deltas):
                perturbed = torch.clamp(base + delta.to(device), CLAMP_MIN, CLAMP_MAX)
                new = model_j(perturbed)[0, voxel_idx].item()
                transfer_matrix[i, j] = abs(new - old)

    transfer_matrices.append(transfer_matrix)

transfer_matrices_array = np.stack(transfer_matrices)  # shape: (num_images, n, n)
average_transfer_matrix = transfer_matrices_array.mean(axis=0)


# custom_order = ['squeezenet1_1', 'alexnet', 'RN50-robust-2', 'RN50-robust-8']
# idx_order = [model_names.index(name) for name in custom_order]
# sorted_matrix = average_transfer_matrix[idx_order][:, idx_order]

plt.figure(figsize=(12, 10))
# The matrix is transposed for plotting: columns are sources (i), rows are
# targets (j). Both axes are labelled with model names in model_names order.
sns.heatmap(
    average_transfer_matrix.T,
    xticklabels=model_names,
    yticklabels=model_names,
    annot=True,
    fmt=".2f",
    annot_kws={"size": 14},
    cbar_kws={"label": "|Δ response|"},
    cmap='bone_r',
)
plt.xticks(rotation=45, ha="right", fontsize=10)
plt.yticks(fontsize=14)
plt.xlabel("Source model (i)", fontsize=14)
plt.ylabel("Target model (j)", fontsize=14)
plt.title("Transferability matrix of FGSM attacks (|Δ response|)", fontsize=14)
plt.tight_layout()
plt.show()

# save average_transfer_matrix
np.save('average_transfer_matrix_epsilon.npy', average_transfer_matrix)
np.save('average_transfer_matrix_epsilon_DATA.npy', transfer_matrices_array)


# save brain_data_train_new and brain_data_test_new
np.save('brain_data_train_new.npy', brain_data_train_new)
np.save('brain_data_test_new.npy', brain_data_test_new)
# save voxel indices
np.save('voxel_indices.npy', top20_voxel_indices)
# save the per-voxel CV scores averaged over models
np.save('voxel_scores_final.npy', avg_models_voxel_scores_final)
