import os, sys, gc, pickle
import numpy as np
import torch
import torch.nn as nn
from natsort import natsorted
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import make_scorer
from scipy.stats import pearsonr
import pandas as pd
from PIL import Image

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
from brainmodels.utils import preprocess_images
from brainmodels.models import load_model

# -----------------------
# Args: source model only
# -----------------------
# The single command-line option names the source model, i.e. the model whose
# encoder the attacks below are computed on. All other settings are
# hard-coded in this file.
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, required=True)  # source model
args = parser.parse_args()
# Passed to prepare_encoderbm() and embedded in every output file name.
SRC_MODEL = args.model_name

# ---------------------------------
# Model sets (edit as needed)
# ---------------------------------
# Full list (order preserved; duplicates removed)
# Each name is passed to load_model() and used as a key into best_layers.
# Names are also parsed further down: the substring "robust" and the prefix
# "L2-" decide which target set a model belongs to, so keep that naming
# scheme when adding entries.
MODEL_LIST = [
    "CLIP-RN50",
    "RN50-0",
    "densenet201_imagenet",
    "mobilenet_v2",
    "CORnet_RT",
    "inception_v3",
    "squeezenet1_1",
    "dinov2",
    "dreamsim_vitb32",
    "google_vit",
    "blip2",
    "nomic",
    "vgg16",
    "alexnet",
    "RN50-robust-0.5",
    "RN50-robust-1",
    "RN50-robust-2",
    "RN50-robust-4",
    "RN50-robust-8",
    "L2-RN50-robust-0.1",
    "L2-RN50-robust-1",
    "L2-RN50-robust-3",
    "L2-RN50-robust-5",
]

def resolve_hook_name(encoder, requested):
    """
    Map a canonical best_layers string (e.g., 'blocks.8.norm2') to an actual
    key in encoder.named_modules().

    Strategies are tried in this order; the first one that succeeds wins:
      1. exact match                                            -> 'exact'
      2. add / strip a 'model.' or 'module.' wrapper prefix     -> 'prefixed'
      3. ViT aliases for names starting with 'blocks.'          -> 'alias'
      4. exactly one module ends with the last two components   -> 'suffix2'
      5. exactly one module ends with the last component
         (only tried when step 4 found nothing)                 -> 'suffix1'
      6. several candidates remain: first in named_modules()
         order is returned                                      -> 'ambiguous'

    Suffix matching works on whole dotted components, so '1.conv' matches
    'features.1.conv' but never 'features.21.conv'.

    Args:
        encoder: object exposing named_modules() (an nn.Module).
        requested: dotted module name to look up.

    Returns:
        (resolved_name, status). resolved_name is None and status is
        'unresolved' when no strategy finds a candidate.
    """
    name2mod = dict(encoder.named_modules())
    if requested in name2mod:
        return requested, 'exact'

    # Try add/remove common wrappers. A list (not a set) keeps the probing
    # order deterministic when both wrapped variants exist.
    variants = []
    for prefix in ('model.', 'module.'):
        if requested.startswith(prefix):
            variants.append(requested[len(prefix):])
        else:
            variants.append(prefix + requested)

    for v in variants:
        if v in name2mod:
            return v, 'prefixed'

    # CLIP ViT alias (OpenAI): blocks.i.* -> visual.transformer.resblocks.i.*
    if requested.startswith('blocks.'):
        cand = 'visual.transformer.resblocks.' + requested[len('blocks.'):]
        if cand in name2mod:
            return cand, 'alias'

        # HF ViT alias: blocks.i.norm{1,2} -> encoder.layer.i.layernorm_{before,after}
        parts = requested.split('.')
        if len(parts) >= 3 and parts[1].isdigit() and parts[2].startswith('norm'):
            i = int(parts[1])
            which = parts[2]
            hf = f'encoder.layer.{i}.layernorm_after' if which == 'norm2' \
                 else f'encoder.layer.{i}.layernorm_before'
            if hf in name2mod:
                return hf, 'alias'

    def has_suffix(name, tail):
        # Match only at a '.' boundary so partial tokens never count.
        return name == tail or name.endswith('.' + tail)

    # Generic suffix fallback
    toks = requested.split('.')
    tail2 = '.'.join(toks[-2:])
    tail1 = toks[-1]
    matches = [k for k in name2mod if has_suffix(k, tail2)]
    if len(matches) == 1:
        return matches[0], 'suffix2'
    if not matches:
        matches = [k for k in name2mod if has_suffix(k, tail1)]
        if len(matches) == 1:
            return matches[0], 'suffix1'
    if matches:
        return matches[0], 'ambiguous'  # still usable

    return None, 'unresolved'

# Helper: order-preserving de-duplication, handy when the lists are pasted/extended
def dedup_keep_order(seq):
    """Return the items of seq as a list, keeping only the first occurrence of each."""
    # dict keys are unique and remember insertion order.
    return list(dict.fromkeys(seq))

# Split MODEL_LIST into groups using only substrings of the model name.
# Baselines: "robust" appears nowhere in the name.
BASELINES = [name for name in MODEL_LIST if "robust" not in name]

# L2-robust models: the name carries the "L2-" prefix.
ROBUST_L2 = [name for name in MODEL_LIST if name.startswith("L2-")]

# L∞-robust models: the name mentions "robust" but lacks the "L2-" prefix.
ROBUST_LINF = [
    name for name in MODEL_LIST
    if "robust" in name and not name.startswith("L2-")
]

# Target sets per attack: the baselines plus the matching robust family.
FGSM_TARGETS = dedup_keep_order(BASELINES + ROBUST_LINF)
L2_TARGETS = dedup_keep_order(BASELINES + ROBUST_L2)

print("FGSM_TARGETS =", FGSM_TARGETS)
print("L2_TARGETS   =", L2_TARGETS)

# -----------------------
# Config / outputs
# -----------------------
# RESULTS_TAG is baked into every output directory name, so changing it
# redirects all outputs of a run.
RESULTS_TAG = 'sensitivity_transfer'
# Both result directories are resolved against the current working directory.
results_dir    = os.path.join(os.getcwd(), f'corrs_{RESULTS_TAG}')
results_folder = os.path.join(os.getcwd(), f'results_{RESULTS_TAG}')  # CSV outputs of the main loop

# Image dumps: clean inputs, adversarial images, and shuffled-perturbation controls.
SAVE_ROOT = f'perturbations/{RESULTS_TAG}'
DIR_CLEAN = os.path.join(SAVE_ROOT, 'clean')
DIR_ADV   = os.path.join(SAVE_ROOT, 'adv')
DIR_CTRL  = os.path.join(SAVE_ROOT, 'ctrl')
os.makedirs(DIR_CLEAN, exist_ok=True)
os.makedirs(DIR_ADV,   exist_ok=True)
os.makedirs(DIR_CTRL,  exist_ok=True)

# NOTE(review): 'voxel_corrs' is created here but nothing visible in this file
# writes into it — confirm whether it is still needed.
os.makedirs(os.path.join(results_dir, 'voxel_corrs'), exist_ok=True)
os.makedirs(results_folder, exist_ok=True)

# -----------------------
# Data & shared resources
# -----------------------
# Every (subject, region) pair below gets its own source brain model and
# its own set of attacked images in the main loop.
subjects = ['s1', 's2', 's5', 's7']
regions  = ['FFA', 'EBA', 'PPA']
n_images = 50  # number of entries taken from fixed_test_indices.npy
threshold = 1.5  # target |r' - r| on source

# Lookup tables produced by earlier pipeline stages.
# NOTE: pickle.load can execute arbitrary code — only load files you trust.
with open('saved/subject_region_to_top50_global.pkl','rb') as f:
    # keyed by (subject, region); values are used as column indices into brain_data
    subject_region_to_top50_global = pickle.load(f)
with open('saved/subject_region_to_global_indices.pkl','rb') as f:
    # not referenced anywhere else in the visible part of this file
    subject_region_to_global_indices = pickle.load(f)
with open('saved/best_layers_per_subj_region.pkl','rb') as f:
    # best_layers[model_name][subject, region] -> layer name to hook
    best_layers = pickle.load(f)

# One tensor per (region, subject) file, concatenated along dim 1.
brain_data_paths = [
    '../nsd_processed/s1_FFA_t7.pt','../nsd_processed/s2_FFA_t7.pt','../nsd_processed/s5_FFA_t7.pt','../nsd_processed/s7_FFA_t7.pt',
    '../nsd_processed/s1_EBA_t7.pt','../nsd_processed/s2_EBA_t7.pt','../nsd_processed/s5_EBA_t7.pt','../nsd_processed/s7_EBA_t7.pt',
    '../nsd_processed/s1_PPA_t7.pt','../nsd_processed/s2_PPA_t7.pt','../nsd_processed/s5_PPA_t7.pt','../nsd_processed/s7_PPA_t7.pt'
]
brain_datas = [torch.load(p) for p in brain_data_paths]
brain_data  = torch.cat(brain_datas, dim=1)  # [1000, 3514]
# Row indices of the train / test split; split_data() partitions with the same indices.
train_idx = torch.load('../nsd_processed/485_unique.pt')
test_idx  = torch.load('../nsd_processed/515_shared.pt')
images_np = np.load('../nsd_processed/nsd_stimuli1000.npy')
# Positions into test_idx (not global image ids); the main loop maps them
# through test_idx to index images_np.
selected_image_indices = np.load(os.path.expanduser('fixed_test_indices.npy'))[:n_images]
# -----------------------
# Utilities
# -----------------------
def split_data(acts, brain):
    """
    Split activations and brain responses into train / test rows.

    Uses the module-level train_idx / test_idx tensors, which are loaded
    once at start-up from the same files this function used to re-read on
    every call. This avoids repeated disk I/O and guarantees the split is
    the one the main loop uses to pick test images.

    Args:
        acts:  array indexable along its first axis by an integer index array.
        brain: array with the same first-axis layout as acts.

    Returns:
        (acts_train, acts_test, brain_train, brain_test)
    """
    train = train_idx.cpu().numpy()
    test = test_idx.cpu().numpy()
    return (acts[train], acts[test], brain[train], brain[test])

def pearson_r(y_true, y_pred):
    """Pearson correlation between targets and predictions after squeezing singleton axes."""
    flat_true = y_true.squeeze()
    flat_pred = y_pred.squeeze()
    r_value, _p_value = pearsonr(flat_true, flat_pred)
    return r_value
# Scoring function for the ridge grid search: a higher Pearson r is better.
scorer = make_scorer(pearson_r, greater_is_better=True)
# Candidate ridge penalties: 20 log-spaced values from 1e-2 to 1e6.
alphas = np.logspace(-2, 6, 20)

# Single device shared by every model and tensor in this script.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class EncoderBM(nn.Module):
    """
    Encoder followed by a linear "brain model" read out from one hooked layer.

    A forward hook captures the output of the layer named `layer_name`
    (resolved through resolve_hook_name). forward() runs the encoder only to
    trigger that hook, reshapes the captured activation to 2-D and feeds it
    to `brainmodel`.

    Args:
        encoder:    nn.Module; if it has `encode_image`, that method is
                    called instead of the module itself.
        brainmodel: nn.Module applied to the reshaped activation.
        layer_name: requested hook location (may need prefix/alias resolution).

    Raises:
        KeyError: from __init__ if the layer cannot be resolved; from
                  forward() if the hooked layer produced no output.
    """

    def __init__(self, encoder, brainmodel, layer_name):
        super().__init__()
        self.encoder = encoder
        self.brainmodel = brainmodel

        resolved, _how = resolve_hook_name(self.encoder, layer_name)
        if resolved is None:
            raise KeyError(f"Could not resolve hook '{layer_name}' on this encoder.")
        self.layer_name = resolved

        # Latest captured output, keyed by the resolved layer name.
        self.activation = {}
        layer = dict(self.encoder.named_modules())[self.layer_name]
        # Keep the handle so the hook can be detached with
        # self._hook_handle.remove() if the encoder is reused elsewhere.
        self._hook_handle = layer.register_forward_hook(self.get_activation())

    def get_activation(self):
        """Return the forward hook that stores the hooked layer's output."""
        def hook(m, inp, out):
            # Some layers return tuples; the first element is taken as the activation.
            self.activation[self.layer_name] = out[0] if isinstance(out, tuple) else out
        return hook

    def forward(self, x):
        # Drop any activation left over from a previous call so that a hook
        # that does not fire cannot silently feed stale features forward.
        self.activation.pop(self.layer_name, None)
        if hasattr(self.encoder, "encode_image"):
            _ = self.encoder.encode_image(x)
        else:
            _ = self.encoder(x)
        if self.layer_name not in self.activation:
            raise KeyError(
                f"Hook on '{self.layer_name}' did not fire during the encoder forward pass.")
        x = self.activation[self.layer_name]
        # `mapping` is never set in this class; the branch only triggers if a
        # caller assigns it externally.
        if hasattr(self, "mapping") and self.mapping == "cnn":
            x = x.to(torch.float32)
        elif x.ndim == 3:
            x = x.mean(dim=1).to(torch.float32)  # ViT tokens avg
        else:
            x = torch.flatten(x, start_dim=1).to(torch.float32)
        x = self.brainmodel(x)
        return x

def denormalize(img, mean, std, device):
    """Undo per-channel normalisation of a 3-channel NCHW batch: img * std + mean."""
    shift, scale = (
        torch.tensor(stat, device=device).view(1, 3, 1, 1)
        for stat in (mean, std)
    )
    return img * scale + shift

def shuffle_delta_spatial(delta_pix: torch.Tensor) -> torch.Tensor:
    """
    Scramble the spatial layout of a (B, C, H, W) perturbation.

    One random permutation of the H*W positions is drawn and applied to
    every batch element and channel alike, so per-position channel triples
    stay together and the multiset of values per channel is unchanged.
    """
    n_batch, n_chan, height, width = delta_pix.shape
    order = torch.randperm(height * width, device=delta_pix.device)
    flat = delta_pix.reshape(n_batch, n_chan, height * width)
    return flat.index_select(2, order).reshape(n_batch, n_chan, height, width)

def save_tensor_image(x_pix: torch.Tensor, path: str):
    """
    Write a single pixel-space image tensor to `path` via PIL.

    Args:
        x_pix: tensor shaped (1, C, H, W) or (C, H, W); values are clamped to
               [0, 1], scaled by 255 and truncated (not rounded) to uint8.
        path:  destination file; missing parent directories are created.
    """
    parent = os.path.dirname(path)
    # A bare file name has no parent component, and os.makedirs('') raises.
    if parent:
        os.makedirs(parent, exist_ok=True)
    arr = (x_pix.squeeze(0).detach().clamp(0,1).permute(1,2,0).cpu().numpy() * 255).astype(np.uint8)
    Image.fromarray(arr).save(path)

def build_brainmodel_from_linear_model(lm, input_dim, output_dim, device):
    """
    Copy a fitted linear estimator into an nn.Sequential(nn.Linear).

    The estimator's parameters are reshaped to the shapes nn.Linear declares,
    so a 1-D `coef_` or a scalar `intercept_` still yields a consistent layer.
    An element-count mismatch raises immediately instead of surfacing later
    as a shape error inside forward().

    Args:
        lm:         object with `coef_` and `intercept_` (e.g. a fitted Ridge).
        input_dim:  number of input features.
        output_dim: number of outputs.
        device:     device the module is moved to.

    Returns:
        nn.Sequential holding one nn.Linear(input_dim, output_dim) in float32.
    """
    bm = nn.Sequential(nn.Linear(input_dim, output_dim))
    # torch.tensor copies, so the module never aliases the estimator's arrays.
    weight = torch.tensor(lm.coef_, dtype=torch.float32).reshape(output_dim, input_dim)
    bias = torch.tensor(lm.intercept_, dtype=torch.float32).reshape(-1)
    if bias.numel() == 1 and output_dim != 1:
        # A scalar intercept applies to every output.
        bias = bias.expand(output_dim).clone()
    bias = bias.reshape(output_dim)
    bm[0].weight = nn.Parameter(weight)
    bm[0].bias   = nn.Parameter(bias)
    bm.to(device)
    return bm

def load_dir_for_layer(model_name, layer_name):
    """
    Return the activations directory for a model/layer pair.

    The directory named after `layer_name` is used when it contains
    'best_alphas.npy'. Otherwise the name is retried with the 'model.'
    prefix toggled (stripped if present, added if absent), and that
    directory is returned without checking that it exists.
    """
    primary = f'activations/{model_name}/{layer_name}/'
    if os.path.exists(os.path.join(primary, 'best_alphas.npy')):
        return primary

    prefix = 'model.'
    swapped = layer_name[len(prefix):] if layer_name.startswith(prefix) else prefix + layer_name
    return f'activations/{model_name}/{swapped}/'

def prepare_encoderbm(model_name, subject, region):
    """
    Build the encoder + ridge read-out for one model and one (subject, region).

    Steps:
      1. load the encoder and its normalisation statistics via load_model;
      2. load the stored activations of the layer listed in best_layers;
      3. fit a ridge regression (alpha chosen by 5-fold CV on Pearson r) from
         the flattened training activations to the mean response of the
         first 50 entries of subject_region_to_top50_global[(subject, region)];
      4. wrap the encoder and the fitted weights in an EncoderBM.

    Args:
        model_name: key understood by load_model and best_layers.
        subject:    subject id, e.g. 's1'.
        region:     region name, e.g. 'FFA'.

    Returns:
        (encbm, MEAN, STD, clamp_min, clamp_max), where clamp_min / clamp_max
        are the normalised-space images of pixel values 0 and 1, shaped
        (1, 3, 1, 1) on `device`.

    Raises:
        FileNotFoundError: if the activation directory holds no 'batch*' files.
    """
    encoder, (MEAN, STD) = load_model(model_name)
    encoder.to(device).eval()
    layer_name = best_layers[model_name][subject, region]
    dir_ = load_dir_for_layer(model_name, layer_name)
    batch_files = natsorted([f for f in os.listdir(dir_) if f.startswith('batch')])
    if not batch_files:
        # np.concatenate([]) would fail with a message that hides the real cause.
        raise FileNotFoundError(
            f"No 'batch*' activation files found in '{dir_}' for model '{model_name}'.")
    acts = np.concatenate([np.load(os.path.join(dir_, f)) for f in batch_files], axis=0)
    acts_flat = acts.reshape(acts.shape[0], -1)

    # Only the training rows are needed to fit the read-out.
    acts_train, _acts_test, brain_train, _brain_test = split_data(acts_flat, brain_data.numpy())
    # mean-of-top-50 target
    top50 = subject_region_to_top50_global[(subject, region)][:50]
    brain_train_sub = brain_train[:, top50].mean(axis=1, keepdims=True)

    ridge = Ridge()
    grid = GridSearchCV(ridge, {'alpha': alphas}, scoring=scorer, cv=5)
    grid.fit(acts_train, brain_train_sub)
    ridgemodel = grid.best_estimator_
    ridgemodel.coef_ = ridgemodel.coef_.reshape(1, -1)

    brainmodel = build_brainmodel_from_linear_model(ridgemodel, ridgemodel.coef_.shape[1], brain_train_sub.shape[1], device)
    # Stored layer names may carry a 'model.' wrapper prefix that the live encoder lacks.
    hook_name = layer_name.split('model.',1)[1] if layer_name.startswith('model.') else layer_name
    encbm = EncoderBM(encoder, brainmodel, hook_name).to(device).eval()

    # Valid pixel range [0, 1] expressed in this model's normalised space.
    mean_t = torch.tensor(MEAN, device=device).view(1,3,1,1)
    std_t  = torch.tensor(STD,  device=device).view(1,3,1,1)
    clamp_min = (torch.tensor([0,0,0], device=device).view(1,3,1,1) - mean_t) / std_t
    clamp_max = (torch.tensor([1,1,1], device=device).view(1,3,1,1) - mean_t) / std_t
    return encbm, MEAN, STD, clamp_min, clamp_max

# -----------------------
# Threshold-search attacks (source model)
# -----------------------
def fgsm_step(encbm, x_norm, std, clamp_min, clamp_max, eps_steps):
    """
    FGSM L∞ PGD: epsilon = eps_steps/255 (pixel), alpha = 1/255 (pixel), k = eps_steps.
    Try both directions; return the better (larger |Δ| on source).

    Args:
        encbm:     module mapping a normalised image batch of size 1 to a scalar response.
        x_norm:    normalised input image, shape (1, 3, H, W).
        std:       per-channel normalisation std, shape (1, 3, 1, 1); converts
                   pixel-space budgets into normalised space.
        clamp_min: lower bound of valid images in normalised space.
        clamp_max: upper bound of valid images in normalised space.
        eps_steps: number of PGD steps; also sets the budget to eps_steps/255.

    Returns:
        (x_adv, r_clean, r_adv, direction), with direction 'maximize' or
        'minimize'. On a tie, 'maximize' is returned.
    """
    # Budgets live on the input's device rather than a module-level global.
    dev = x_norm.device
    eps = torch.tensor([eps_steps/255]*3, device=dev).view(1,3,1,1) / std
    alpha = torch.tensor([1/255]*3, device=dev).view(1,3,1,1) / std

    def run(minimize=False):
        d = torch.zeros_like(x_norm)
        for _ in range(eps_steps):
            d.requires_grad_()
            x_adv = (x_norm + d).clamp(clamp_min, clamp_max)
            out = encbm(x_adv).squeeze()
            loss = -out if minimize else out
            # Differentiate w.r.t. the perturbation only; no parameter
            # gradients are produced, so nothing needs to be zeroed.
            (grad,) = torch.autograd.grad(loss, d)
            # Signed ascent step on the loss, then projection onto the L∞ ball.
            d = (d.detach() + alpha * grad.sign()).clamp(-eps, eps)
        return (x_norm + d).clamp(clamp_min, clamp_max).detach()

    def response(x):
        # Plain evaluation: no autograd graph required.
        with torch.no_grad():
            return encbm(x).squeeze().item()

    r0 = response(x_norm)
    xa = run(False); rA = response(xa)
    xb = run(True);  rB = response(xb)
    if abs(rA - r0) >= abs(rB - r0): return xa, r0, rA, 'maximize'
    else:                            return xb, r0, rB, 'minimize'

def l2_step(encbm, x_norm, std, clamp_min, clamp_max, eps_steps):
    """
    L2 PGD in pixel space: epsilon = eps_steps (pixel), alpha = 1 (pixel), k = eps_steps.
    Both directions are tried; the one with the larger |Δ| on the source wins.

    Args:
        encbm:     module mapping a normalised image batch of size 1 to a scalar response.
        x_norm:    normalised input image, shape (1, 3, H, W).
        std:       per-channel normalisation std, shape (1, 3, 1, 1); converts
                   between normalised and pixel-space perturbations.
        clamp_min: lower bound of valid images in normalised space.
        clamp_max: upper bound of valid images in normalised space.
        eps_steps: number of PGD steps; also the L2 radius in pixel units.

    Returns:
        (x_adv, r_clean, r_adv, direction), with direction 'maximize' or
        'minimize'. On a tie, 'maximize' is returned.
    """
    # Budgets live on the input's device rather than a module-level global.
    dev = x_norm.device
    eps = torch.tensor(float(eps_steps), device=dev)
    alpha = torch.tensor(1.0, device=dev)

    def flat_l2(t):
        # Per-sample L2 norm, shaped for broadcasting; floored to avoid division by zero.
        return t.reshape(t.size(0), -1).norm(p=2, dim=1, keepdim=True).clamp_min(1e-8).view(-1,1,1,1)

    def run(maximize=True):
        d = torch.zeros_like(x_norm)
        for _ in range(eps_steps):
            d.requires_grad_(True)
            x_adv = (x_norm + d).clamp(clamp_min, clamp_max)
            out = encbm(x_adv).squeeze()
            loss = out if maximize else -out
            # Differentiate w.r.t. the perturbation only; no parameter
            # gradients are produced, so nothing needs to be zeroed.
            (g,) = torch.autograd.grad(loss, d)
            with torch.no_grad():
                d = d.detach()
                # NOTE(review): by the chain rule the gradient w.r.t. the
                # pixel-space perturbation is g / std; g * std is kept
                # exactly as in the original — confirm this is intended.
                g_pix = g * std
                # Unit-norm step of length alpha in pixel space, mapped back to normalised space.
                d = d + (alpha * g_pix / flat_l2(g_pix)) / std
                # Project onto the pixel-space L2 ball of radius eps.
                d = d * (eps / flat_l2(d * std)).clamp(max=1.0)
        return (x_norm + d).clamp(clamp_min, clamp_max).detach()

    def response(x):
        # Plain evaluation: no autograd graph required.
        with torch.no_grad():
            return encbm(x).squeeze().item()

    r0 = response(x_norm)
    xa = run(True);  rA = response(xa)
    xb = run(False); rB = response(xb)
    if abs(rA - r0) >= abs(rB - r0): return xa, r0, rA, 'maximize'
    else:                            return xb, r0, rB, 'minimize'

# -----------------------
# MAIN (Option B structure)
# -----------------------
MAX_EPS_STEPS = 30  # largest attack budget (in steps) tried per image

# (key in the precomputed records, tag used in file names, attack function)
ATTACKS = (
    ('fgsm', 'FGSM', fgsm_step),
    ('l2',   'L2',   l2_step),
)


def pix_to_norm(mean, std, x_pix, dev):
    """Move a pixel-space image batch to `dev` and normalise it with per-channel mean / std."""
    mean_t = torch.tensor(mean, device=dev).view(1,3,1,1)
    std_t  = torch.tensor(std,  device=dev).view(1,3,1,1)
    return (x_pix.to(dev) - mean_t) / std_t


def _escalate_attack(step_fn, label, encbm, x_norm, std, clamp_min, clamp_max):
    """
    Re-run `step_fn` with budgets 1..MAX_EPS_STEPS until the source response
    moves by at least `threshold`.

    Returns (x_adv, r_clean, r_adv, direction, eps_steps). If the threshold is
    never reached, the result for the largest budget is returned and a
    warning is printed, so such images can be spotted in the logs.
    """
    for eps_steps in range(1, MAX_EPS_STEPS + 1):
        x_adv, r0, r1, dirn = step_fn(encbm, x_norm, std, clamp_min, clamp_max, eps_steps)
        if abs(r1 - r0) >= threshold:
            return x_adv, r0, r1, dirn, eps_steps
    print(f"[warn] {label}: |delta| = {abs(r1 - r0):.4f} stayed below threshold {threshold} "
          f"after {MAX_EPS_STEPS} steps; keeping the largest budget.")
    return x_adv, r0, r1, dirn, eps_steps


# selected_image_indices are positions in the test split; this maps them to global image ids.
test_idx_np = test_idx.cpu().numpy()

for subject in subjects:
    torch.cuda.empty_cache(); gc.collect()
    for region in regions:
        # SOURCE model
        encbm_src, MEAN_src, STD_src, CLAMP_MIN_src, CLAMP_MAX_src = prepare_encoderbm(SRC_MODEL, subject, region)
        std_src = torch.tensor(STD_src, device=device).view(1,3,1,1)

        # Build TARGET sets (ensure SRC on diagonal)
        fgsm_targets = FGSM_TARGETS[:]
        l2_targets   = L2_TARGETS[:]
        if SRC_MODEL not in fgsm_targets: fgsm_targets = [SRC_MODEL] + fgsm_targets
        if SRC_MODEL not in l2_targets:   l2_targets   = [SRC_MODEL] + l2_targets
        targets = {'fgsm': fgsm_targets, 'l2': l2_targets}

        # Precompute per-image FGSM/L2 at threshold on source; store CPU tensors + meta
        precomp = []
        eps_used = {'fgsm': [], 'l2': []}

        for image_idx in selected_image_indices:
            idx = int(image_idx)
            global_idx = int(test_idx_np[idx])
            x_norm = preprocess_images(images_np[global_idx:global_idx+1], MEAN_src, STD_src, device=device)
            x_pix_clean = denormalize(x_norm, MEAN_src, STD_src, device).clamp(0,1)
            stem = f'{SRC_MODEL}_{subject}_{region}_img{idx}'

            record = {'idx': idx, 'x_pix_clean': x_pix_clean.cpu()}
            for key, tag, step_fn in ATTACKS:
                # ---- escalate eps_steps until |Δ| >= threshold or MAX_EPS_STEPS is reached
                x_adv, r0, r1, dirn, eps_steps = _escalate_attack(
                    step_fn, f'{tag} {stem}', encbm_src, x_norm, std_src, CLAMP_MIN_src, CLAMP_MAX_src)
                eps_used[key].append((idx, eps_steps, dirn, r0, r1))

                x_pix_adv = denormalize(x_adv, MEAN_src, STD_src, device).clamp(0,1)
                # Control: same perturbation values, spatially shuffled.
                delta_ctrl = shuffle_delta_spatial(x_pix_adv - x_pix_clean)
                x_pix_ctrl = (x_pix_clean + delta_ctrl).clamp(0,1)

                fname = f'{tag}_{stem}_eps{eps_steps}_{dirn}.png'
                save_tensor_image(x_pix_adv,  os.path.join(DIR_ADV,  fname))
                save_tensor_image(x_pix_ctrl, os.path.join(DIR_CTRL, fname))

                record[key] = {'x_pix_adv': x_pix_adv.cpu(), 'x_pix_ctrl': x_pix_ctrl.cpu(),
                               'dir': dirn, 'eps': eps_steps}
                del x_adv, x_pix_adv, x_pix_ctrl, delta_ctrl

            clean_path = os.path.join(DIR_CLEAN, f'{stem}.png')
            if not os.path.exists(clean_path):
                save_tensor_image(x_pix_clean, clean_path)

            precomp.append(record)

            # free per-image GPU tensors
            del x_norm, x_pix_clean
            torch.cuda.empty_cache()

        # collectors
        rows       = {key: {m: [] for m in targets[key]} for key in targets}
        rows_ctrl  = {key: {m: [] for m in targets[key]} for key in targets}
        preds      = {'fgsm': [], 'l2': []}
        preds_ctrl = {'fgsm': [], 'l2': []}

        # Evaluate targets one at a time
        for m in dedup_keep_order(fgsm_targets + l2_targets):
            encbm_t, MEAN_t, STD_t, _cl_min_t, _cl_max_t = prepare_encoderbm(m, subject, region)
            try:
                with torch.no_grad():
                    for item in precomp:
                        idx = item['idx']
                        # The clean response is shared by both attacks; compute it once.
                        r_clean_t = encbm_t(pix_to_norm(MEAN_t, STD_t, item['x_pix_clean'], device)).squeeze().item()

                        for key, _tag, _step_fn in ATTACKS:
                            if m not in targets[key]:
                                continue
                            att = item[key]
                            r_adv_t  = encbm_t(pix_to_norm(MEAN_t, STD_t, att['x_pix_adv'],  device)).squeeze().item()
                            r_ctrl_t = encbm_t(pix_to_norm(MEAN_t, STD_t, att['x_pix_ctrl'], device)).squeeze().item()
                            rows[key][m].append(abs(r_adv_t - r_clean_t))
                            rows_ctrl[key][m].append(abs(r_ctrl_t - r_clean_t))
                            preds[key].append((idx, m, r_clean_t, r_adv_t, att['dir'], att['eps']))
                            preds_ctrl[key].append((idx, m, r_clean_t, r_ctrl_t, att['dir'], att['eps']))
            finally:
                encbm_t.activation.clear()
                del encbm_t
                torch.cuda.empty_cache(); gc.collect()

        # -----------------------
        # Save transfer rows + per-image predictions + eps used (ADV and CONTROL)
        # -----------------------
        for key, tag, _step_fn in ATTACKS:
            if not rows[key]:
                continue
            suffix = f'{SRC_MODEL}_{subject}_{region}.csv'
            cols = sorted(rows[key])
            vals_adv  = [float(np.mean(rows[key][c])) if len(rows[key][c]) else np.nan for c in cols]
            vals_ctrl = [float(np.mean(rows_ctrl[key][c])) if len(rows_ctrl[key][c]) else np.nan for c in cols]
            pd.DataFrame([vals_adv],  columns=cols, index=[SRC_MODEL]).to_csv(
                os.path.join(results_folder, f'transferThresh_{tag}_{suffix}'))
            pd.DataFrame([vals_ctrl], columns=cols, index=[SRC_MODEL]).to_csv(
                os.path.join(results_folder, f'transferThresh_{tag}_ctrl_{suffix}'))
            pd.DataFrame(preds[key], columns=['image_idx','model','r_clean','r_adv','direction','eps_steps']).to_csv(
                os.path.join(results_folder, f'predsThresh_{tag}_{suffix}'), index=False)
            pd.DataFrame(preds_ctrl[key], columns=['image_idx','model','r_clean','r_ctrl','direction','eps_steps']).to_csv(
                os.path.join(results_folder, f'predsThresh_{tag}_ctrl_{suffix}'), index=False)
            pd.DataFrame(eps_used[key], columns=['image_idx','eps_steps','direction','r_clean_src','r_adv_src']).to_csv(
                os.path.join(results_folder, f'epsUsed_{tag}_{suffix}'), index=False)

        # free source
        encbm_src.activation.clear()
        del encbm_src
        torch.cuda.empty_cache(); gc.collect()

print("All done.")
