import os, sys, json, gc, pickle, random, warnings, time, logging, traceback
from typing import Dict, Tuple
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from natsort import natsorted

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
from brainmodels.utils import preprocess_images
from brainmodels.models import load_model

import argparse
# Command-line interface: both --subject and --region are mandatory.
parser = argparse.ArgumentParser()
parser.add_argument("--subject", required=True)
parser.add_argument("--region", required=True)
# NOTE: parsing happens at import time, so importing this module also consumes sys.argv.
args = parser.parse_args()
# Module-level aliases; used below in directory names and as default arguments.
subject, region = args.subject, args.region

# Hide every UserWarning for the rest of the process.
# NOTE: warnings.warn() without an explicit category emits UserWarning, so
# such calls made later in this file are hidden by this filter as well.
warnings.filterwarnings("ignore", category=UserWarning)
# Let cuDNN benchmark and pick convolution algorithms.
torch.backends.cudnn.benchmark = True
# One fixed seed shared by the Python, NumPy and PyTorch random generators.
SEED = 123
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Output directory (under the current working directory) for this subject/region pair.
SAVE_DIR = os.path.join(os.getcwd(), f"fim_{subject}_{region}")
os.makedirs(SAVE_DIR, exist_ok=True)

# NOTE(review): ROOT is the empty string, so the two directories below resolve to
# '/generated_attacks_single' and '/transfer_results_single' at the filesystem
# root — confirm that this is intended.
ROOT = ''
GENERATED_ROOT = f'{ROOT}/generated_attacks_single'
TRANSFER_ROOT  = f'{ROOT}/transfer_results_single'
os.makedirs(GENERATED_ROOT, exist_ok=True)
os.makedirs(TRANSFER_ROOT, exist_ok=True)

# Stimulus images and the index sets used by split_data() to select train/test rows.
images_np = np.load('../nsd_processed/nsd_stimuli1000.npy')
train_idx = torch.load('../nsd_processed/485_unique.pt')
test_idx  = torch.load('../nsd_processed/515_shared.pt')

# Pre-computed lookup tables. pickle.load can execute arbitrary code, so these
# files must come from a trusted source.
# SUBJECT_REGION_TO_TOP50 is indexed with a (subject, region) tuple and sliced to 50 entries later on.
with open('../transfer_attacks/saved/subject_region_to_top50_global.pkl', 'rb') as f:
    SUBJECT_REGION_TO_TOP50 = pickle.load(f)
with open('../transfer_attacks/saved/subject_region_to_global_indices.pkl', 'rb') as f:
    SUBJECT_REGION_TO_GLOBAL_IDX = pickle.load(f)
# BEST_LAYERS is indexed as BEST_LAYERS[model_name][subject, region] and yields a layer name.
with open('../transfer_attacks/saved/best_layers_per_subj_region.pkl', 'rb') as f:
    BEST_LAYERS = pickle.load(f)

# One response tensor per (subject, region) file, in the fixed order listed here.
brain_paths = [
    '../nsd_processed/s1_FFA_t7.pt','../nsd_processed/s2_FFA_t7.pt','../nsd_processed/s5_FFA_t7.pt','../nsd_processed/s7_FFA_t7.pt',
    '../nsd_processed/s1_EBA_t7.pt','../nsd_processed/s2_EBA_t7.pt','../nsd_processed/s5_EBA_t7.pt','../nsd_processed/s7_EBA_t7.pt',
    '../nsd_processed/s1_PPA_t7.pt','../nsd_processed/s2_PPA_t7.pt','../nsd_processed/s5_PPA_t7.pt','../nsd_processed/s7_PPA_t7.pt'
]
# Concatenate all files along dim 1 and convert to NumPy.
# presumably rows are stimuli and columns are voxels — TODO confirm
BRAIN = torch.cat([torch.load(p) for p in brain_paths], dim=1).numpy()

# Register 'VGG16-robust-l2-3' with the same keys as 'vgg16', but with every
# entry forced to the 'avgpool' layer.
BEST_LAYERS['VGG16-robust-l2-3'] = BEST_LAYERS['vgg16'].copy()
for k in list(BEST_LAYERS['VGG16-robust-l2-3'].keys()):
    BEST_LAYERS['VGG16-robust-l2-3'][k] = 'avgpool'

# Image indices shared by the attack and FIM analyses (cast to int).
SELECTED_IMAGE_INDICES = np.load(os.path.expanduser('../transfer_attacks/fixed_test_indices.npy')).astype(int)

# Default attack settings. l2_attack_pixel() divides the L2 budget and the step
# size by 255 before applying them to images in [0, 1].
EPS_L2_PIX = 1000.0
STEP_PIX   = 1000.0
# Number of gradient steps per attack.
STEPS = 1

def split_data(acts_np: np.ndarray, brain_np: np.ndarray):
    """Split activations and brain responses with the module-level index sets.

    Both arrays are indexed along their first axis with ``train_idx`` and
    ``test_idx``.

    Returns:
        ``(acts_train, acts_test, brain_train, brain_test)``.
    """
    acts_train = acts_np[train_idx]
    acts_test = acts_np[test_idx]
    brain_train = brain_np[train_idx]
    brain_test = brain_np[test_idx]
    return (acts_train, acts_test, brain_train, brain_test)

def denorm(img_norm: torch.Tensor, mean, std):
    """Undo per-channel normalisation and clip the result to [0, 1].

    Args:
        img_norm: Normalised image tensor; must broadcast against (1, 3, 1, 1).
        mean: Three per-channel means.
        std: Three per-channel standard deviations.

    Returns:
        ``img_norm * std + mean`` clamped to the range [0, 1].
    """
    target_device = img_norm.device
    shift = torch.tensor(mean, device=target_device).view(1, 3, 1, 1)
    scale = torch.tensor(std, device=target_device).view(1, 3, 1, 1)
    pixels = img_norm * scale + shift
    return pixels.clamp(0, 1)

def renorm(img_pixel01: torch.Tensor, mean, std):
    """Apply per-channel normalisation ``(x - mean) / std`` to an image.

    Args:
        img_pixel01: Image tensor; must broadcast against (1, 3, 1, 1).
        mean: Three per-channel means.
        std: Three per-channel standard deviations.

    Returns:
        The normalised tensor (no clamping is applied).
    """
    target_device = img_pixel01.device
    shift = torch.tensor(mean, device=target_device).view(1, 3, 1, 1)
    scale = torch.tensor(std, device=target_device).view(1, 3, 1, 1)
    centred = img_pixel01 - shift
    return centred / scale

def get_activation_dir(model_name: str, layer_name: str):
    """Locate the directory holding saved activations for a model layer.

    Two spellings of the layer are tried in order: ``layer_name`` itself, then
    the same name with a leading ``'model.'`` added (if absent) or removed
    (if present).

    Returns:
        The first candidate path that is an existing directory.

    Raises:
        FileNotFoundError: If neither candidate directory exists.
    """
    if layer_name.startswith('model.'):
        alternate = layer_name.split('model.', 1)[1]
    else:
        alternate = 'model.' + layer_name
    candidates = [layer_name, alternate]
    for layer_dir in candidates:
        path = f"{ROOT}/activations/{model_name}/{layer_dir}/"
        if os.path.isdir(path):
            return path
    raise FileNotFoundError(f"No activation dir for {model_name}:{layer_name} (tried {candidates})")

class EncoderBM(nn.Module):
    """Encoder plus readout: feeds one hooked encoder layer into ``brainmodel``.

    A forward hook on the sub-module named ``hook_layer`` stores that module's
    output (the first element if the output is a tuple) in ``self.activation``.
    ``forward`` runs the encoder only to trigger the hook, then applies
    ``brainmodel`` to the captured features.
    """

    def __init__(self, encoder, brainmodel, hook_layer):
        super().__init__()
        self.encoder = encoder
        self.brainmodel = brainmodel
        self.hook_layer = hook_layer
        self.activation = {}
        # A KeyError here means the encoder has no sub-module with this name.
        modules_by_name = dict(self.encoder.named_modules())
        modules_by_name[self.hook_layer].register_forward_hook(self._get_hook())

    def _get_hook(self):
        """Return the forward hook that records the hooked layer's output."""
        def hook(_module, _inputs, output):
            captured = output[0] if isinstance(output, tuple) else output
            self.activation[self.hook_layer] = captured
        return hook

    def forward(self, x):
        """Run the encoder on ``x`` and return the readout of the hooked layer."""
        # The encoder's own return value is discarded; only the hook matters.
        if hasattr(self.encoder, "encode_image"):
            self.encoder.encode_image(x)
        else:
            self.encoder(x)
        feats = self.activation[self.hook_layer]
        if feats.ndim == 3:
            # Rank-3 features are averaged over dim 1.
            feats = feats.mean(dim=1)
        else:
            feats = torch.flatten(feats, start_dim=1)
        return self.brainmodel(feats.to(torch.float32))

def build_brainmodel_from_linear_model(lm, input_dim, output_dim):
    """Wrap a fitted linear model in a one-layer ``nn.Sequential``.

    The layer's weight and bias are replaced by float32 parameters built from
    ``lm.coef_`` and ``lm.intercept_``.

    Returns:
        The module moved to the module-level ``device`` and set to eval mode.
    """
    linear = nn.Linear(input_dim, output_dim)
    linear.weight = nn.Parameter(torch.tensor(lm.coef_, dtype=torch.float32))
    linear.bias = nn.Parameter(torch.tensor(lm.intercept_, dtype=torch.float32))
    readout = nn.Sequential(linear)
    return readout.to(device).eval()

# In-process caches.
# ENCODERS: model name -> (encoder, (MEAN, STD)); filled by load_encoder_cached().
ENCODERS: Dict[str, Tuple[torch.nn.Module, Tuple[list,list]]] = {}
# READOUTS: (model name, subject, region) -> (EncoderBM, MEAN, STD); filled by get_or_build_readout().
READOUTS: Dict[Tuple[str,str,str], Tuple[torch.nn.Module, list, list]] = {}

def load_encoder_cached(model_name: str):
    """Return ``(encoder, (MEAN, STD))`` for ``model_name``, loading it at most once.

    On a miss the encoder comes from ``load_model``, is moved to ``device``,
    put in eval mode and stored in the ``ENCODERS`` cache.
    """
    try:
        return ENCODERS[model_name]
    except KeyError:
        pass
    encoder, (mean, std) = load_model(model_name)
    entry = (encoder.to(device).eval(), (mean, std))
    ENCODERS[model_name] = entry
    return entry

from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV, KFold
from sklearn.metrics import make_scorer
from scipy.stats import pearsonr

# Ridge penalties searched by GridSearchCV: 10 values log-spaced from 1e-2 to 1e6.
ALPHAS = np.logspace(-2, 6, 10)
# Cross-validation score: Pearson correlation between targets and predictions
# (both squeezed first); higher is better.
SCORER = make_scorer(lambda a,b: pearsonr(a.squeeze(), b.squeeze())[0], greater_is_better=True)

def _readout_cache_paths(model_name, subj, reg):
    """Return the cache file paths for one single-output readout.

    The cache directory is created if it does not exist.

    Returns:
        Dict with keys ``'weights'``, ``'intercepts'`` and ``'alpha'``.
    """
    cache_dir = os.path.join(ROOT, 'transfer_attacks', 'attack_results', model_name, f'{subj}_{reg}')
    os.makedirs(cache_dir, exist_ok=True)
    file_names = {
        'weights': 'weights.npy',
        'intercepts': 'intercepts.npy',
        'alpha': 'best_alpha.npy',
    }
    return {key: os.path.join(cache_dir, fname) for key, fname in file_names.items()}

def get_or_build_readout(model_name: str, subj: str, reg: str):
    """Return ``(model, MEAN, STD)`` for a single-output readout, building it if needed.

    The readout is a ridge regression from the flattened activations of the
    model's best layer to the mean response of the first 50 entries of
    ``SUBJECT_REGION_TO_TOP50[(subj, reg)]``. Results are memoised in
    ``READOUTS`` and the fitted weights are cached on disk.

    Saved activations are loaded only when the ridge model has to be fitted;
    with cached weights on disk they are not needed and are left untouched.

    Args:
        model_name: Key into ``BEST_LAYERS`` and name passed to the encoder loader.
        subj: Subject identifier.
        reg: Region identifier.

    Returns:
        Tuple ``(EncoderBM in eval mode on ``device``, MEAN, STD)``.

    Raises:
        FileNotFoundError: If the readout must be fitted and no activation
            directory exists for the layer.
    """
    key = (model_name, subj, reg)
    if key in READOUTS:
        return READOUTS[key]
    enc, (MEAN, STD) = load_encoder_cached(model_name)
    raw_layer = BEST_LAYERS[model_name][subj, reg]
    # Except for 'dreamsim_vitb32', a leading 'model.' is stripped to obtain the
    # module name that is hooked on the encoder.
    hook_layer = raw_layer if (model_name == 'dreamsim_vitb32' or not raw_layer.startswith('model.')) else raw_layer.split('model.', 1)[1]
    paths = _readout_cache_paths(model_name, subj, reg)
    have = all(os.path.exists(paths[k]) for k in ('weights', 'intercepts', 'alpha'))
    if have:
        # Rebuild the estimator from disk; no training data is required.
        ridge = Ridge(alpha=float(np.load(paths['alpha'])))
        ridge.coef_ = np.load(paths['weights'])
        ridge.intercept_ = np.load(paths['intercepts'])
    else:
        act_dir = get_activation_dir(model_name, raw_layer)
        acts = np.concatenate(
            [np.load(os.path.join(act_dir, f)) for f in natsorted(os.listdir(act_dir)) if f.startswith('batch')],
            axis=0,
        )
        acts = acts.reshape(acts.shape[0], -1).astype(np.float32)
        Xtr, _, Ytr, _ = split_data(acts, BRAIN)
        top50 = SUBJECT_REGION_TO_TOP50[(subj, reg)][:50]
        # Regression target: mean over the selected columns, kept 2-D.
        ytr = Ytr[:, top50].mean(axis=1, keepdims=True)
        ridge = GridSearchCV(Ridge(), {'alpha': ALPHAS}, scoring=SCORER, cv=5).fit(Xtr, ytr).best_estimator_
        np.save(paths['weights'], ridge.coef_)
        np.save(paths['intercepts'], ridge.intercept_)
        np.save(paths['alpha'], ridge.alpha)
    ridge.coef_ = ridge.coef_.reshape(1, -1)
    brainmodel = build_brainmodel_from_linear_model(ridge, ridge.coef_.shape[1], 1)
    model = EncoderBM(enc, brainmodel, hook_layer).eval().to(device)
    READOUTS[key] = (model, MEAN, STD)
    return READOUTS[key]

def _l2_norm(x):
    """Return the per-sample L2 norm of ``x`` with shape ``(x.size(0), 1)``."""
    per_sample = x.view(x.size(0), -1)
    return per_sample.norm(p=2, dim=1, keepdim=True)

def l2_attack_pixel(model, x_pix01, MEAN, STD, eps_l2_pix: float, step_pix: float, steps: int, direction: str = 'maximize'):
    """Run a projected L2 gradient attack in pixel space on one image.

    Each step moves along the L2-normalised gradient of the model output,
    projects the perturbation back onto the L2 ball of radius
    ``eps_l2_pix / 255`` around the clean image, and clamps to [0, 1].

    Args:
        model: Callable taking a normalised image; its squeezed output must be
            a scalar.
        x_pix01: Clean image in [0, 1].
        MEAN, STD: Normalisation statistics passed to ``renorm``.
        eps_l2_pix: L2 budget on the 0-255 scale.
        step_pix: Step size on the 0-255 scale.
        steps: Number of gradient steps.
        direction: ``'minimize'`` lowers the output; any other value raises it.

    Returns:
        ``(adv, adv_resp, clean_resp, achieved)``: the adversarial image, the
        model output on it, the model output on the clean image, and the
        achieved perturbation L2 norm on the 0-255 scale.
    """
    x0 = x_pix01.detach()
    r_total = eps_l2_pix / 255.0
    alpha = step_pix / 255.0
    adv = x0.clone()
    for _ in range(steps):
        adv.requires_grad_(True)
        out = model(renorm(adv, MEAN, STD)).squeeze()
        loss = -out if direction == 'minimize' else out
        # Differentiate w.r.t. the image only; backward() would also
        # accumulate gradients into every model parameter on each step.
        (g,) = torch.autograd.grad(loss, adv)
        gnorm = _l2_norm(g).clamp(min=1e-12)
        adv = adv.detach() + alpha * (g / gnorm.view(-1, 1, 1, 1))
        delta = adv - x0
        # Guard the division so a zero perturbation with a zero budget gives
        # a scale of 0 instead of NaN.
        dnorm = _l2_norm(delta).clamp(min=1e-12)
        scale = (r_total / dnorm).clamp(max=1.0)
        adv = (x0 + delta * scale.view(-1, 1, 1, 1)).clamp(0, 1).detach()
    # The final evaluations need no gradients.
    with torch.no_grad():
        clean_resp = float(model(renorm(x0, MEAN, STD)).squeeze().item())
        adv_resp = float(model(renorm(adv, MEAN, STD)).squeeze().item())
        achieved = float(_l2_norm(adv - x0) * 255.0)
    return adv, adv_resp, clean_resp, achieved

def attack_model_for_images(model_name, img_indices, eps_l2_pix, step_pix, steps, subj, reg, direction='minimize'):
    """Attack the selected images for one model and save one file per image.

    Every image is attacked with ``l2_attack_pixel`` and the result is written
    to ``GENERATED_ROOT/<model_name>/eps_<int(eps_l2_pix)>/attack_image_<idx>.pt``.

    Returns:
        List of result dicts (one per image) with keys ``image_idx``,
        ``model_name``, ``eps_l2_pix``, ``eps_achieved_pix``, ``clean``,
        ``adv``, ``x_pix`` and ``adv_pix`` (the last two are CPU tensors).
    """
    model, MEAN, STD = get_or_build_readout(model_name, subj, reg)
    imgs_norm = preprocess_images(images_np, MEAN, STD, device=device)
    out_dir = os.path.join(GENERATED_ROOT, f'{model_name}', f'eps_{int(eps_l2_pix)}')
    os.makedirs(out_dir, exist_ok=True)
    results = []
    for raw_idx in img_indices:
        idx = int(raw_idx)
        x_pix = denorm(imgs_norm[idx].unsqueeze(0), MEAN, STD)
        adv_pix, adv_r, clean_r, eps_ach = l2_attack_pixel(
            model, x_pix, MEAN, STD, eps_l2_pix, step_pix, steps, direction
        )
        payload = {
            'image_idx': idx,
            'model_name': model_name,
            'eps_l2_pix': float(eps_l2_pix),
            'eps_achieved_pix': float(eps_ach),
            'clean': float(clean_r),
            'adv': float(adv_r),
            'x_pix': x_pix.detach().cpu(),
            'adv_pix': adv_pix.detach().cpu(),
        }
        results.append(payload)
        torch.save(payload, os.path.join(out_dir, f'attack_image_{idx}.pt'))
        # Free cached GPU memory and collect garbage after every image.
        torch.cuda.empty_cache()
        gc.collect()
    return results

# Configure root logging at INFO level and create the logger used by the attack suite.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("attack-suite")

def attack_suite_and_visualize_safe(model_list, img_indices=SELECTED_IMAGE_INDICES, eps_l2_pix=EPS_L2_PIX, step_pix=STEP_PIX, steps=STEPS, subj=subject, reg=region, direction='minimize', save_partial=True):
    """Run the attack for every model, isolating failures per model.

    A model that raises is logged, recorded under ``skipped`` and its traceback
    is written to ``error.txt`` in the model's output directory; the remaining
    models are still processed.

    Args:
        model_list: Model names to attack.
        img_indices: Image indices to attack.
        eps_l2_pix, step_pix, steps, direction: Forwarded to the attack.
        subj, reg: Subject and region identifying the readout.
        save_partial: When true, write ``suite_results.pt`` and a JSON summary
            for each model as soon as it finishes.

    Returns:
        Dict with ``results`` (model name -> list of per-image results) and
        ``skipped`` (model name -> dict with ``err`` and ``traceback``).
    """
    suite_results, skipped = {}, {}
    for m in model_list:
        start_t = time.time()
        res = None
        try:
            logger.info(f"Starting model: {m}")
            res = attack_model_for_images(m, img_indices, eps_l2_pix, step_pix, steps, subj, reg, direction)
            suite_results[m] = res
            if save_partial:
                out_fp = os.path.join(GENERATED_ROOT, f"{m}", f"eps_{int(eps_l2_pix)}", "suite_results.pt")
                os.makedirs(os.path.dirname(out_fp), exist_ok=True)
                torch.save(res, out_fp)
                summary = [{"image_idx": r["image_idx"], "clean": r["clean"], "adv": r["adv"], "eps_achieved": r["eps_achieved_pix"]} for r in res]
                # Swap only the extension; str.replace would also rewrite any
                # other ".pt" occurring earlier in the path.
                json_fp = os.path.splitext(out_fp)[0] + ".json"
                with open(json_fp, "w") as f:
                    json.dump(summary, f)
            logger.info(f"Finished model: {m} in {time.time()-start_t:.1f}s")
        except Exception as e:
            # Per-model boundary: record the failure and carry on.
            tb = traceback.format_exc()
            warnings.warn(f"[{m}] skipped: {e}")
            logger.exception(f"Error on model {m}: {e}")
            skipped[m] = dict(err=str(e), traceback=tb)
            err_fp = os.path.join(GENERATED_ROOT, f"{m}", f"eps_{int(eps_l2_pix)}", "error.txt")
            os.makedirs(os.path.dirname(err_fp), exist_ok=True)
            with open(err_fp, "w") as f:
                f.write(tb)
        finally:
            # Drop the local reference (it may never have been assigned) and
            # release cached GPU memory before the next model.
            res = None
            torch.cuda.empty_cache()
            gc.collect()
    return dict(results=suite_results, skipped=skipped)

# Names of the models processed by the attack suite and the FIM analysis.
MODEL_LIST = [
    "CLIP-RN50","RN50-0","L2-RN50-robust-0.1","L2-RN50-robust-1","L2-RN50-robust-3","L2-RN50-robust-5",
    "densenet201_imagenet","mobilenet_v2","CORnet_RT","inception_v3","squeezenet1_1",
    "dinov2","dreamsim_vitb32","google_vit","blip2","nomic","vgg16","alexnet"
]

# Defaults for the multi-output readout: ridge penalties (10 values log-spaced
# from 1e-2 to 1e6), number of CV folds, and number of targets kept.
ALPHAS_50 = np.logspace(-2, 6, 10)
KFOLDS = 5
TOPK = 50

def _mt_save_dir(model_name, subj, reg):
    """Create (if needed) and return the multi-output readout cache directory.

    The directory is ``multi50_<model_name>_<subj>_<reg>`` under the current
    working directory.
    """
    folder_name = f'multi50_{model_name}_{subj}_{reg}'
    target = os.path.join(os.getcwd(), folder_name)
    os.makedirs(target, exist_ok=True)
    return target

def _mt_paths(model_name, subj, reg):
    """Return the cache file paths for one multi-output readout.

    Returns:
        Dict with keys ``'cv_scores'``, ``'alphas'``, ``'best_alpha_idx'``,
        ``'weights'`` and ``'intercepts'``, all inside ``_mt_save_dir(...)``.
    """
    base = _mt_save_dir(model_name, subj, reg)
    file_names = {
        'cv_scores': 'cv_scores.npy',
        'alphas': 'alphas.npy',
        'best_alpha_idx': 'best_alpha_idx_per_voxel.npy',
        'weights': 'weights.npy',
        'intercepts': 'intercepts.npy',
    }
    return {key: os.path.join(base, fname) for key, fname in file_names.items()}

def _flatten_acts(acts: np.ndarray):
    """Reshape ``acts`` to 2-D, keeping the first axis and merging the rest."""
    n_rows = acts.shape[0]
    return acts.reshape(n_rows, -1)

def _build_linear_from_weights(W: np.ndarray, b: np.ndarray) -> nn.Module:
    """Create an ``nn.Linear`` whose parameters are copied from ``W`` and ``b``.

    ``W.shape[1]`` is the input size and ``W.shape[0]`` the output size.

    Returns:
        The layer moved to the module-level ``device`` and set to eval mode.
    """
    n_in = W.shape[1]
    n_out = W.shape[0]
    layer = nn.Linear(n_in, n_out, bias=True)
    weight_t = torch.tensor(W, dtype=torch.float32)
    bias_t = torch.tensor(b, dtype=torch.float32)
    with torch.no_grad():
        layer.weight.copy_(weight_t)
        layer.bias.copy_(bias_t)
    return layer.to(device).eval()

def get_or_build_readout_multi50(model_name: str, subj: str, reg: str, alphas=ALPHAS_50, kfolds: int = KFOLDS):
    """Return ``(model, MEAN, STD)`` for a multi-output ridge readout.

    One ridge regression per target column is fitted from the flattened
    activations of the model's best layer to the first ``TOPK`` entries of
    ``SUBJECT_REGION_TO_TOP50[(subj, reg)]``. The penalty is chosen per target
    by K-fold cross-validated Pearson correlation. Weights, intercepts and the
    CV artefacts are cached on disk.

    Saved activations are loaded only when a fit is required; when all cache
    files exist the readout is rebuilt from them directly.

    Args:
        model_name: Key into ``BEST_LAYERS`` and name passed to the encoder loader.
        subj: Subject identifier.
        reg: Region identifier.
        alphas: Candidate ridge penalties.
        kfolds: Number of cross-validation folds.

    Returns:
        Tuple ``(EncoderBM in eval mode on ``device``, MEAN, STD)``.

    Raises:
        FileNotFoundError: If a fit is required and no activation directory
            exists for the layer.
    """
    paths = _mt_paths(model_name, subj, reg)
    cache_ok = all(os.path.exists(paths[p]) for p in ['weights', 'intercepts', 'cv_scores', 'alphas', 'best_alpha_idx'])
    enc, (MEAN, STD) = load_encoder_cached(model_name)
    raw_layer = BEST_LAYERS[model_name][subj, reg]
    # Except for 'dreamsim_vitb32', a leading 'model.' is stripped to obtain the
    # module name that is hooked on the encoder.
    hook_layer = raw_layer if (model_name == 'dreamsim_vitb32' or not raw_layer.startswith('model.')) else raw_layer.split('model.', 1)[1]
    if cache_ok:
        # Cached weights are sufficient; skip loading the activations.
        W = np.load(paths['weights'])
        b = np.load(paths['intercepts'])
        brainmodel = _build_linear_from_weights(W, b)
        return (EncoderBM(enc, brainmodel, hook_layer).eval().to(device), MEAN, STD)

    act_dir = get_activation_dir(model_name, raw_layer)
    acts = np.concatenate(
        [np.load(os.path.join(act_dir, f)) for f in natsorted(os.listdir(act_dir)) if f.startswith('batch')],
        axis=0,
    )
    X = _flatten_acts(acts)
    top_idx = SUBJECT_REGION_TO_TOP50[(subj, reg)][:TOPK]
    Y = BRAIN[:, top_idx]
    Xtr, _, Ytr, _ = split_data(X, Y)

    # Cross-validated Pearson r for every (alpha, fold, target).
    A, V = len(alphas), Ytr.shape[1]
    scores = np.zeros((A, kfolds, V), dtype=np.float32)
    kf = KFold(n_splits=kfolds, shuffle=True, random_state=SEED)
    # With a fixed integer random_state the folds are the same on every call
    # to split(), so they are generated once and reused for all alphas.
    folds = list(kf.split(np.arange(Xtr.shape[0])))
    for ai, alpha in enumerate(alphas):
        for fi, (tr, va) in enumerate(folds):
            mdl = Ridge(alpha=alpha).fit(Xtr[tr], Ytr[tr])
            Yhat = mdl.predict(Xtr[va])
            for v in range(V):
                r = pearsonr(Ytr[va, v].squeeze(), Yhat[:, v].squeeze())[0]
                # An undefined correlation (NaN) is scored as 0.
                scores[ai, fi, v] = 0.0 if np.isnan(r) else r
    np.save(paths['cv_scores'], scores)
    np.save(paths['alphas'], alphas)

    # Per target, pick the alpha with the best fold-averaged score.
    best_alpha_idx = scores.mean(axis=1).argmax(axis=0)
    np.save(paths['best_alpha_idx'], best_alpha_idx)

    # Refit every target on the full training split with its own alpha.
    D = Xtr.shape[1]
    W = np.zeros((V, D), dtype=np.float32)
    b = np.zeros((V,), dtype=np.float32)
    for v in range(V):
        a = float(alphas[best_alpha_idx[v]])
        mdl_v = Ridge(alpha=a).fit(Xtr, Ytr[:, v])
        W[v, :] = mdl_v.coef_.astype(np.float32, copy=False)
        b[v] = np.float32(mdl_v.intercept_)
    np.save(paths['weights'], W)
    np.save(paths['intercepts'], b)
    brainmodel = _build_linear_from_weights(W, b)
    return (EncoderBM(enc, brainmodel, hook_layer).eval().to(device), MEAN, STD)

@torch.no_grad()
def _flatten_img(x):
    """Return ``x`` viewed as a 1-D tensor; runs with gradient tracking disabled."""
    flat = x.view(-1)
    return flat

def jacobian_50_wrt_pixels(model, x_pix01, MEAN, STD):
    """Compute the Jacobian of the model outputs with respect to the input pixels.

    Args:
        model: Callable taking a normalised image; after ``squeeze(0)`` its
            output is iterated along dim 0, one scalar per Jacobian row.
        x_pix01: Input image in pixel space (a detached copy is differentiated).
        MEAN, STD: Normalisation statistics passed to ``renorm``.

    Returns:
        Tensor with one row per model output; each row is the flattened
        gradient of that output with respect to the image.
    """
    x = x_pix01.clone().detach().requires_grad_(True)
    y = model(renorm(x, MEAN, STD)).squeeze(0)
    n_out = y.shape[0]
    rows = []
    for i in range(n_out):
        # Differentiate w.r.t. the image only: backward() would also
        # accumulate gradients into every model parameter on each pass.
        # The graph is kept until the last output has been processed.
        (g,) = torch.autograd.grad(y[i], x, retain_graph=(i < n_out - 1))
        rows.append(_flatten_img(g.detach()))
    return torch.stack(rows, dim=0)

def fim_spectrum_for_image(model, MEAN, STD, x_pix01):
    """Eigen-decompose the Gram matrix ``J @ J.T`` of the pixel Jacobian.

    Returns:
        ``(evals, evecs, v_top)``: eigenvalues in descending order (clamped at
        0), the matching eigenvectors as columns, and the unit-norm pixel-space
        direction ``J.T @ evecs[:, 0]`` of the leading eigenvector.
    """
    J = jacobian_50_wrt_pixels(model, x_pix01, MEAN, STD)
    gram = J @ J.t()
    # eigh returns ascending eigenvalues; reverse both outputs to descending.
    ascending_vals, ascending_vecs = torch.linalg.eigh(gram)
    evals = torch.flip(ascending_vals, dims=[0]).clamp(min=0)
    evecs = torch.flip(ascending_vecs, dims=[1])
    top_direction = (J.t() @ evecs[:, 0]).contiguous()
    top_direction = top_direction / (top_direction.norm(p=2) + 1e-12)
    return evals, evecs, top_direction

def fim_effective_dim(evals: torch.Tensor):
    """Return ``(sum of evals)**2 / (sum of evals**2 + 1e-16)`` as a float.

    The computation is done in float64 on the CPU.
    """
    spectrum = evals.detach().cpu().double().numpy()
    total = spectrum.sum()
    # Small constant keeps the ratio finite for an all-zero spectrum.
    sum_of_squares = (spectrum ** 2).sum() + 1e-16
    return float((total * total) / sum_of_squares)

def build_multi50_models(model_list, subj=subject, reg=region):
    """Build the multi-output readout for every model that can be built.

    A model whose construction raises is omitted from the result and reported
    with a ``RuntimeWarning``. The default warning category (``UserWarning``)
    is ignored process-wide by this script, which would hide the failure.

    Args:
        model_list: Model names to build.
        subj: Subject identifier.
        reg: Region identifier.

    Returns:
        Dict mapping each successfully built model name to the value returned
        by ``get_or_build_readout_multi50``.
    """
    out = {}
    for m in model_list:
        try:
            out[m] = get_or_build_readout_multi50(m, subj, reg)
        except Exception as e:
            # Best effort: keep going with the remaining models.
            warnings.warn(f"[{m}] multi50 build failed: {e}", RuntimeWarning)
    return out

def random_subspace(P, k, device='cpu'):
    """Return the Q factor of the reduced QR of a ``P x k`` standard-normal matrix."""
    gaussian = torch.randn(P, k, device=device)
    basis = torch.linalg.qr(gaussian, mode='reduced')[0]
    return basis

def cosine_matrix(v_list):
    """Return the float32 matrix of absolute pairwise dot products ``|v_i . v_j|``.

    The result equals absolute cosine similarity only if the vectors have unit
    norm; no normalisation is applied here.
    """
    count = len(v_list)
    result = torch.zeros((count, count), dtype=torch.float32, device=v_list[0].device)
    for row, left in enumerate(v_list):
        for col, right in enumerate(v_list):
            result[row, col] = torch.dot(left, right).abs()
    return result

def energy_overlap_matrix(v_list, Q_list):
    """Return the matrix ``E`` with ``E[i, j] = ||Q_j.T @ v_i||**2`` (float32).

    Entry ``[i, j]`` is the squared norm of the coefficients of vector ``i``
    in the columns of basis ``j``.
    """
    count = len(v_list)
    result = torch.zeros((count, count), dtype=torch.float32, device=v_list[0].device)
    for row in range(count):
        vec = v_list[row]
        for col in range(count):
            projected = Q_list[col].T @ vec
            result[row, col] = (projected * projected).sum()
    return result

def subspace_similarity_matrix(Q_list):
    """Return the matrix of mean singular values of ``Q_i.T @ Q_j`` (float32)."""
    count = len(Q_list)
    result = torch.zeros((count, count), dtype=torch.float32, device=Q_list[0].device)
    for row, basis_a in enumerate(Q_list):
        for col, basis_b in enumerate(Q_list):
            singular_values = torch.linalg.svdvals(basis_a.T @ basis_b)
            result[row, col] = singular_values.mean()
    return result

def fim_null_across_models(M, P, k, device='cpu'):
    """Build null similarity matrices from random directions and subspaces.

    Draws ``M`` random unit vectors of length ``P`` and ``M`` random ``P x k``
    bases, then computes the same three matrices as for real models.

    Returns:
        Dict with keys ``'cosine'``, ``'energy'`` and ``'subspace'``.
    """
    directions = []
    for _ in range(M):
        raw = torch.randn(P, device=device)
        directions.append(raw / (raw.norm() + 1e-12))
    bases = []
    for _ in range(M):
        bases.append(random_subspace(P, k, device=device))
    cosine = cosine_matrix(directions)
    energy = energy_overlap_matrix(directions, bases)
    subspace = subspace_similarity_matrix(bases)
    return {'cosine': cosine, 'energy': energy, 'subspace': subspace}

def reorder_by_diag_and_robust(mat, labels, robust_keywords=("robust", "adv")):
    """Permute rows and columns of ``mat`` and the labels with one shared order.

    Labels that contain none of ``robust_keywords`` (case-insensitive) come
    first, sorted by descending diagonal value. Labels that contain a keyword
    follow, sorted alphabetically by label.

    Returns:
        ``(reordered matrix, reordered labels)``.
    """
    def is_robust(position):
        lowered = labels[position].lower()
        return any(keyword.lower() in lowered for keyword in robust_keywords)

    by_diag_desc = np.argsort(-mat.diag().cpu().numpy())
    plain, robust = [], []
    for position in by_diag_desc:
        bucket = robust if is_robust(position) else plain
        bucket.append(position)
    robust.sort(key=lambda position: labels[position])
    new_order = plain + robust
    reordered = mat[new_order][:, new_order]
    return reordered, [labels[position] for position in new_order]

def plot_heatmap(mat, labels, title, vmin=None, vmax=None):
    """Draw ``mat`` (transposed) as a labelled heat map on a new figure.

    The figure is left open as the current pyplot figure so that the caller
    can save it; this function neither saves nor closes it.

    Args:
        mat: Square tensor to display.
        labels: Tick labels for both axes.
        title: Figure title.
        vmin, vmax: Colour-scale limits passed to ``imshow``.
    """
    values = mat.detach().cpu().numpy()
    n_labels = len(labels)
    # Figure size grows with the number of labels, with a minimum of 8 x 6.
    extent = 0.45 * n_labels + 3
    plt.figure(figsize=(max(8, extent), max(6, extent)))
    plt.gca().set_aspect('equal', adjustable='box')
    image = plt.imshow(values.T, aspect='auto', vmin=vmin, vmax=vmax, interpolation='nearest', cmap="gist_heat_r")
    plt.colorbar(image, fraction=0.046, pad=0.04)
    tick_positions = np.arange(n_labels)
    plt.xticks(ticks=tick_positions, labels=labels, rotation=60, ha='right')
    plt.yticks(ticks=tick_positions, labels=labels)
    plt.title(title)
    plt.tight_layout()

def fim_topk_pixel_basis(model, MEAN, STD, x_pix01, k=50, skip_first=False):
    """Return the leading pixel-space direction and a basis of the top directions.

    The Gram matrix ``J @ J.T`` of the pixel Jacobian is eigen-decomposed, the
    selected eigenvectors are mapped to pixel space with ``J.T``, normalised
    column-wise and orthonormalised with a reduced QR.

    Args:
        model, MEAN, STD, x_pix01: Passed to ``jacobian_50_wrt_pixels``.
        k: Upper column index of the eigenvector slice.
        skip_first: When true, the leading eigenvector is skipped for both the
            top direction and the basis.
            NOTE(review): the basis then uses columns ``1:k``, i.e. ``k - 1``
            vectors — confirm that this is intended.

    Returns:
        ``(v_top, Q, evals)``, all detached; ``evals`` is in descending order
        and clamped at 0.
    """
    J = jacobian_50_wrt_pixels(model, x_pix01, MEAN, STD)
    gram = J @ J.t()
    # eigh returns ascending eigenvalues; reverse both outputs to descending.
    ascending_vals, ascending_vecs = torch.linalg.eigh(gram)
    evals = torch.flip(ascending_vals, dims=[0]).clamp(min=0.0)
    eigvecs = torch.flip(ascending_vecs, dims=[1])
    first = 1 if skip_first else 0
    top_direction = J.t() @ eigvecs[:, first]
    selected = eigvecs[:, first:k]
    top_direction = top_direction / (top_direction.norm(p=2) + 1e-12)
    pixel_dirs = J.t() @ selected
    pixel_dirs = pixel_dirs / (pixel_dirs.norm(dim=0, keepdim=True) + 1e-12)
    basis = torch.linalg.qr(pixel_dirs, mode='reduced')[0]
    return top_direction.detach(), basis.detach(), evals.detach()

def fim_similarity_across_models(model_objs, img_idx, topk=50, labels=None, do_plot=True, do_reorder=True):
    """Compare the leading FIM directions of several models on one image.

    For every model the leading pixel-space direction and a top-``topk`` basis
    are computed. Three model-by-model matrices are then built: absolute dot
    product of leading directions, energy of each leading direction inside each
    basis, and mean singular value between bases. Raw and (optionally)
    reordered matrices are saved as ``.npy`` files in ``SAVE_DIR``.

    Args:
        model_objs: Dict mapping model name to ``(model, MEAN, STD)``.
        img_idx: Index of the image in ``images_np``.
        topk: Basis size passed to ``fim_topk_pixel_basis``.
        labels: Axis labels, one per model. When missing, empty, or of the
            wrong length, the model names are used (a wrong length also emits
            a ``RuntimeWarning``) so labels can never be misaligned.
        do_plot: When true, save heat maps of the three matrices and of random
            null counterparts as PDFs. Each figure is closed after saving.
        do_reorder: When true, reorder each matrix with
            ``reorder_by_diag_and_robust`` before plotting and saving.

    Returns:
        Dict with ``cosine_raw``, ``energy_raw``, ``subspace_raw`` and ``names``.
    """
    names = list(model_objs.keys())
    if labels is None or len(labels) == 0:
        labels = names
    else:
        labels = list(labels)
        if len(labels) != len(names):
            warnings.warn(
                f"Got {len(labels)} labels for {len(names)} models; using model names instead.",
                RuntimeWarning,
            )
            labels = names
    img_idx = int(img_idx)

    def save_heatmap(mat, tick_labels, title, filename):
        # Plot, save, then close so figures do not pile up across images.
        plot_heatmap(mat, tick_labels, title, 0, 1)
        plt.savefig(os.path.join(SAVE_DIR, filename))
        plt.close()

    v_list, Q_list = [], []
    for name in names:
        model, MEAN, STD = model_objs[name]
        # NOTE(review): this preprocesses the whole image set to pick a single
        # image; whether preprocess_images can be given one image only cannot
        # be determined from this file.
        x_norm = preprocess_images(images_np, MEAN, STD, device=device)[img_idx].unsqueeze(0)
        x_pix = denorm(x_norm, MEAN, STD)
        v_top, Qk, _ = fim_topk_pixel_basis(model, MEAN, STD, x_pix, k=topk)
        v_list.append(v_top)
        Q_list.append(Qk)

    C = cosine_matrix(v_list)
    E = energy_overlap_matrix(v_list, Q_list)
    S = subspace_similarity_matrix(Q_list)
    out = dict(cosine_raw=C.clone(), energy_raw=E.clone(), subspace_raw=S.clone(), names=names)
    if do_reorder:
        C, labelsC = reorder_by_diag_and_robust(C, labels)
        E, labelsE = reorder_by_diag_and_robust(E, labels)
        S, labelsS = reorder_by_diag_and_robust(S, labels)
    else:
        labelsC = labelsE = labelsS = labels

    if do_plot:
        save_heatmap(C, labelsC, f"Top-direction cosine |v_i·v_j| (img {img_idx})", f"top_direction_cosine_img_{img_idx}.pdf")
        save_heatmap(E, labelsE, f"Energy overlap (top-{topk}) (img {img_idx})", f"energy_overlap_img_{img_idx}.pdf")
        save_heatmap(S, labelsS, f"Subspace similarity (k={topk}) (img {img_idx})", f"subspace_similarity_img_{img_idx}.pdf")
        # Null reference: random directions and subspaces of matching sizes.
        M, P, k = len(Q_list), Q_list[0].shape[0], Q_list[0].shape[1]
        nulls = fim_null_across_models(M, P, k, device=Q_list[0].device)
        save_heatmap(nulls['cosine'], labels, f"Null: cosine |v_i·v_j| (img {img_idx})", f"null_top_direction_cosine_img_{img_idx}.pdf")
        save_heatmap(nulls['energy'], labels, f"Null: energy overlap (img {img_idx})", f"null_energy_overlap_img_{img_idx}.pdf")
        save_heatmap(nulls['subspace'], labels, f"Null: subspace similarity (img {img_idx})", f"null_subspace_similarity_img_{img_idx}.pdf")

    np.save(os.path.join(SAVE_DIR, f"C_img{img_idx}_raw.npy"), out['cosine_raw'].detach().cpu().numpy())
    np.save(os.path.join(SAVE_DIR, f"E_img{img_idx}_raw.npy"), out['energy_raw'].detach().cpu().numpy())
    np.save(os.path.join(SAVE_DIR, f"S_img{img_idx}_raw.npy"), out['subspace_raw'].detach().cpu().numpy())
    np.save(os.path.join(SAVE_DIR, f"C_img{img_idx}.npy"), C.detach().cpu().numpy())
    np.save(os.path.join(SAVE_DIR, f"E_img{img_idx}.npy"), E.detach().cpu().numpy())
    np.save(os.path.join(SAVE_DIR, f"S_img{img_idx}.npy"), S.detach().cpu().numpy())
    return out

def average_fim_similarities_across_images(model_objs, img_indices, topk=50, labels=None):
    """Average the three FIM similarity matrices over several images.

    ``fim_similarity_across_models`` is run for every image without plotting
    or reordering, and the raw matrices are averaged. The averages and the
    labels are saved to ``SAVE_DIR``. Reordered heat maps of the averages and
    heat maps of one random null draw are saved as PDFs, and each figure is
    closed after saving.

    Args:
        model_objs: Dict mapping model name to ``(model, MEAN, STD)``.
        img_indices: Image indices to average over.
        topk: Basis size used for the energy and subspace matrices.
        labels: Axis labels, one per model. When missing or of the wrong
            length, the model names are used (a wrong length also emits a
            ``RuntimeWarning``) so labels can never be misaligned.

    Returns:
        Dict with ``C_avg``, ``E_avg``, ``S_avg`` and ``labels``.

    Raises:
        ValueError: If ``model_objs`` or ``img_indices`` is empty.
    """
    img_indices = [int(i) for i in img_indices]
    if not model_objs:
        raise ValueError("model_objs is empty; nothing to compare")
    if not img_indices:
        raise ValueError("img_indices is empty; nothing to average")
    names = list(model_objs.keys())
    if labels is None:
        labels = names
    else:
        labels = list(labels)
        if len(labels) != len(names):
            warnings.warn(
                f"Got {len(labels)} labels for {len(names)} models; using model names instead.",
                RuntimeWarning,
            )
            labels = names

    def save_heatmap(mat, tick_labels, title, filename):
        # Plot, save, then close so figures do not stay open.
        plot_heatmap(mat, tick_labels, title, 0, 1)
        plt.savefig(os.path.join(SAVE_DIR, filename))
        plt.close()

    accC = accE = accS = None
    n = 0
    for idx in img_indices:
        out = fim_similarity_across_models(model_objs, img_idx=idx, topk=topk, labels=labels, do_plot=False, do_reorder=False)
        C, E, S = out['cosine_raw'].cpu(), out['energy_raw'].cpu(), out['subspace_raw'].cpu()
        if accC is None:
            accC, accE, accS = C.clone(), E.clone(), S.clone()
        else:
            accC += C
            accE += E
            accS += S
        n += 1
    C_avg, E_avg, S_avg = accC / n, accE / n, accS / n
    np.save(os.path.join(SAVE_DIR, f"C_avg_{n}imgs.npy"), C_avg.cpu().numpy().astype(np.float32))
    np.save(os.path.join(SAVE_DIR, f"E_avg_{n}imgs.npy"), E_avg.cpu().numpy().astype(np.float32))
    np.save(os.path.join(SAVE_DIR, f"S_avg_{n}imgs.npy"), S_avg.cpu().numpy().astype(np.float32))
    np.save(os.path.join(SAVE_DIR, "model_labels.npy"), np.array(labels, dtype=object))

    C_plot, labelsC = reorder_by_diag_and_robust(C_avg, labels)
    E_plot, labelsE = reorder_by_diag_and_robust(E_avg, labels)
    S_plot, labelsS = reorder_by_diag_and_robust(S_avg, labels)
    save_heatmap(C_plot, labelsC, f"AVG Top-direction cosine over {n} imgs", f"avg_top_direction_cosine_over_{n}_imgs.pdf")
    save_heatmap(E_plot, labelsE, f"AVG Energy overlap (k={topk}) over {n} imgs", f"avg_energy_overlap_k{topk}_over_{n}_imgs.pdf")
    save_heatmap(S_plot, labelsS, f"AVG Subspace similarity (k={topk}) over {n} imgs", f"avg_subspace_similarity_k{topk}_over_{n}_imgs.pdf")

    # One model/image pass is used only to obtain the pixel dimension and the
    # basis width for the random null matrices.
    any_model, MEAN, STD = next(iter(model_objs.values()))
    x_norm = preprocess_images(images_np, MEAN, STD, device=device)[int(img_indices[0])].unsqueeze(0)
    x_pix = denorm(x_norm, MEAN, STD)
    v_top, Qk, _ = fim_topk_pixel_basis(any_model, MEAN, STD, x_pix, k=topk)
    nulls = fim_null_across_models(len(names), v_top.numel(), Qk.shape[1], device=Qk.device)
    save_heatmap(nulls['cosine'], names, "AVG-Null: cosine |v_i·v_j|", "avg_null_top_direction_cosine.pdf")
    save_heatmap(nulls['energy'], names, f"AVG-Null: energy overlap (k={topk})", f"avg_null_energy_overlap_k{topk}.pdf")
    save_heatmap(nulls['subspace'], names, f"AVG-Null: subspace similarity (k={topk})", f"avg_null_subspace_similarity_k{topk}.pdf")
    return dict(C_avg=C_avg, E_avg=E_avg, S_avg=S_avg, labels=labels)

# Step 1: run the single-output attack for every model; failures are isolated per model.
suite_results = attack_suite_and_visualize_safe(
    MODEL_LIST,
    img_indices=SELECTED_IMAGE_INDICES,
    eps_l2_pix=EPS_L2_PIX,
    step_pix=STEP_PIX,
    steps=STEPS,
    subj=subject,
    reg=region,
    direction='minimize',
    save_partial=True
)

# Step 2: build the multi-output readouts. Models that fail to build are
# omitted, so the result may contain fewer entries than MODEL_LIST.
multi50_models = build_multi50_models(MODEL_LIST, subj=subject, reg=region)
if not multi50_models:
    raise RuntimeError(
        f"No multi50 readout could be built for subject={subject}, region={region}"
    )

# Step 3: average the FIM similarities. Labels are taken from the models that
# were actually built so that they always line up with the matrix rows.
avg_sim = average_fim_similarities_across_images(
    multi50_models,
    img_indices=SELECTED_IMAGE_INDICES,
    topk=50,
    labels=list(multi50_models)
)
