import os, sys, gc, random, pickle
import numpy as np
import torch
import torch.nn as nn
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from natsort import natsorted
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import make_scorer
from PIL import Image

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
from brainmodels.models import load_model

# Reproducibility: force deterministic cuDNN kernels and seed the Python, NumPy and torch RNGs.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(1)
random.seed(1)
torch.manual_seed(1)

# Directory walked recursively by sample_image_paths() for candidate stimulus images.
IMAGE_ROOT = '/activations/transfer_attacks/object_images'
# Model names passed to load_model(); each contributes an image row and a diff row per figure.
MODELS = ['L2-RN50-robust-5']
# Subject key used to look up best layers and top-50 voxel indices.
SUBJECT = 's1'
# Regions processed one at a time; each writes its figures to its own "newest_<region>" directory.
REGIONS = ['FFA', 'EBA', 'PPA']
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Upper bound on gradient steps per attack direction in l2_step_sensitivity().
MAX_STEPS = 200
# An attack stops early once |change in predicted response| exceeds this value.
THRESHOLD = 3.0
# Response-change levels at which the perturbed image is snapshotted, for each direction.
TARGET_SIS_ONE_SIDE = (0.2, 0.4, 0.6, 0.8, 1.0)
# Ridge regularisation strengths searched by GridSearchCV in prepare_encoderbm().
alphas = np.logspace(-2, 6, 20)


def ensure_dir(p):
    """Create directory ``p`` (including missing parents); do nothing if it already exists."""
    target = p
    os.makedirs(target, mode=0o777, exist_ok=True)


def sample_image_paths(root_dir, k=100, exts=('.jpg', '.jpeg', '.png', '.bmp', '.webp')):
    """Randomly pick up to ``k`` image paths found anywhere under ``root_dir``.

    Args:
        root_dir: Directory searched recursively.
        k: Maximum number of paths to return (all files are returned if fewer exist).
        exts: Accepted file extensions, matched case-insensitively. A single string is
            treated as one extension.

    Returns:
        A list of ``min(k, n_found)`` distinct paths drawn with the ``random`` module.

    Raises:
        RuntimeError: If no matching file exists under ``root_dir``.
    """
    if isinstance(exts, str):
        exts = (exts,)
    suffixes = tuple(e.lower() for e in exts)
    # Sort so the candidate order does not depend on filesystem enumeration order;
    # otherwise the seeded random.sample below would not be reproducible.
    files = sorted(
        os.path.join(dirpath, name)
        for dirpath, _, names in os.walk(root_dir)
        for name in names
        if name.lower().endswith(suffixes)
    )
    if not files:
        raise RuntimeError(f"No images under {root_dir}")
    return random.sample(files, min(k, len(files)))


def pix_to_norm(mean, std, img_t, device):
    """Map a pixel-space image batch into normalised model-input space.

    Computes ``(img_t - mean) / std`` channel-wise, with ``mean`` and ``std``
    broadcast as (1, 3, 1, 1) tensors created on ``device``.
    """
    channel_shape = (1, 3, 1, 1)
    centre = torch.tensor(mean, device=device).reshape(channel_shape)
    scale = torch.tensor(std, device=device).reshape(channel_shape)
    shifted = img_t - centre
    return shifted / scale


def denormalize(img, mean, std, device):
    """Undo channel-wise normalisation, returning ``img * std + mean``.

    ``mean`` and ``std`` are turned into (1, 3, 1, 1) tensors on ``device`` so they
    broadcast over an NCHW batch.
    """
    def as_channels(values):
        return torch.tensor(values, device=device).reshape(1, 3, 1, 1)

    scaled = img * as_channels(std)
    return scaled + as_channels(mean)


def load_png_image(img_path, mean, std, device):
    """Load an image file as a (normalised, pixel-space) tensor pair.

    The file is converted to RGB and resized to 224x224 with bilinear resampling.
    Pixel values are scaled to [0, 1].

    Returns:
        ``(x_norm, x_pix)``, both shaped (1, 3, 224, 224) on ``device``. ``x_norm`` is the
        image after ``pix_to_norm``; ``x_pix`` is ``x_norm`` mapped back by
        ``denormalize`` and clamped to [0, 1].
    """
    # Close the file handle deterministically. Pillow leaves multi-frame files
    # (e.g. animated WebP) open after load(), so relying on lazy closing leaks handles.
    with Image.open(img_path) as src:
        img = src.convert("RGB").resize((224, 224), Image.BILINEAR)
    x = np.asarray(img, dtype=np.float32) / 255.0
    x = torch.tensor(x.transpose(2, 0, 1), device=device).unsqueeze(0)  # HWC -> 1CHW
    x_norm = pix_to_norm(mean, std, x, device)
    # Round-trip through the normalisation so x_pix matches denormalize() output used elsewhere.
    x_pix = denormalize(x_norm, mean, std, device).clamp(0, 1)
    return x_norm, x_pix


def resolve_hook_name(encoder, requested):
    """Map a requested layer name onto an actual submodule name of ``encoder``.

    Resolution order:
      1. ``requested`` exactly as given.
      2. ``requested`` with a leading ``'model.'`` removed or added, then the same for
         ``'module.'`` (checked in that fixed order).
      3. Modules whose dotted name ends with the last two components of ``requested``,
         then with the last component only. Matching is on whole dotted components, so
         ``'b2'`` matches ``'enc.b2'`` but not ``'enc.b12'``. When several match, the
         first in ``named_modules()`` order wins.

    Raises:
        KeyError: If no submodule can be matched.
    """
    name2mod = dict(encoder.named_modules())
    if requested in name2mod:
        return requested
    # Fixed order (not a set) so the result does not depend on string-hash randomisation.
    for prefix in ('model.', 'module.'):
        candidate = requested[len(prefix):] if requested.startswith(prefix) else prefix + requested
        if candidate in name2mod:
            return candidate

    def suffix_matches(suffix):
        # Require a '.' boundary so e.g. '2.mlp' does not match '12.mlp'.
        return [name for name in name2mod if name == suffix or name.endswith('.' + suffix)]

    toks = requested.split('.')
    matches = suffix_matches('.'.join(toks[-2:])) or suffix_matches(toks[-1])
    if matches:
        return matches[0]
    raise KeyError(f"Cannot resolve hook '{requested}'.")


class EncoderBM(nn.Module):
    """Feed an encoder's intermediate activation into a brain-response head.

    A forward hook on ``layer_name`` (resolved with ``resolve_hook_name``) captures that
    layer's output during the encoder pass. The capture is flattened, or averaged over
    dim 1 when it is 3-D, then cast to float32 and passed to ``brainmodel``.
    """

    def __init__(self, encoder, brainmodel, layer_name):
        super().__init__()
        self.encoder = encoder
        self.brainmodel = brainmodel
        self.layer_name = resolve_hook_name(self.encoder, layer_name)
        self.activation = {}
        target = dict(self.encoder.named_modules())[self.layer_name]
        # Keep the handle so the hook can be detached later (see remove_hook).
        self._hook_handle = target.register_forward_hook(self._hook())

    def _hook(self):
        """Return a forward hook that stores the module output (first element if a tuple)."""
        def h(_, __, out):
            self.activation[self.layer_name] = out[0] if isinstance(out, tuple) else out
        return h

    def remove_hook(self):
        """Detach the forward hook from the encoder; safe to call more than once."""
        handle = getattr(self, "_hook_handle", None)
        if handle is not None:
            handle.remove()
            self._hook_handle = None

    def forward(self, x):
        # Drop any previous capture so a hook that fails to fire cannot silently reuse stale data.
        self.activation.pop(self.layer_name, None)
        _ = self.encoder.encode_image(x) if hasattr(self.encoder, "encode_image") else self.encoder(x)
        try:
            # pop() so the activation (and its autograd graph) is not retained after this call.
            feats = self.activation.pop(self.layer_name)
        except KeyError:
            raise KeyError(f"Forward hook on '{self.layer_name}' did not fire during the encoder pass.") from None
        if getattr(self, "mapping", None) == "cnn":
            feats = feats.to(torch.float32)
        elif feats.ndim == 3:
            # presumably (batch, tokens, dim): average over tokens — confirm for new encoders
            feats = feats.mean(dim=1).to(torch.float32)
        else:
            feats = torch.flatten(feats, start_dim=1).to(torch.float32)
        return self.brainmodel(feats)


def build_brainmodel_from_linear_model(lm, input_dim, output_dim, device):
    """Copy a fitted linear model's ``coef_``/``intercept_`` into an ``nn.Sequential(nn.Linear)``.

    ``coef_`` is reshaped to (output_dim, input_dim) and ``intercept_`` is broadcast to
    (output_dim,). This covers 1-D coefficients from single-target fits and the scalar
    intercept produced when no intercept is fitted.

    Returns:
        The module, moved to ``device``, with float32 parameters.

    Raises:
        ValueError: If ``coef_`` does not hold ``output_dim * input_dim`` values.
    """
    weight = np.asarray(lm.coef_, dtype=np.float32).reshape(output_dim, input_dim)
    bias = np.broadcast_to(np.asarray(lm.intercept_, dtype=np.float32), (output_dim,)).copy()
    bm = nn.Sequential(nn.Linear(input_dim, output_dim))
    bm[0].weight = nn.Parameter(torch.tensor(weight, dtype=torch.float32))
    bm[0].bias = nn.Parameter(torch.tensor(bias, dtype=torch.float32))
    bm.to(device)
    return bm


def load_dir_for_layer(model_name, layer_name):
    """Return the activation directory for ``layer_name`` of ``model_name``.

    Uses the directory named after ``layer_name`` when it contains ``best_alphas.npy``.
    Otherwise it falls back to the name with a leading ``'model.'`` toggled (removed if
    present, added if absent). The fallback path is not checked for existence.
    """
    primary = f'/activations/{model_name}/{layer_name}/' # or where activations are
    if os.path.exists(os.path.join(primary, 'best_alphas.npy')):
        return primary
    prefix = 'model.'
    if layer_name.startswith(prefix):
        alt = layer_name.split(prefix, 1)[1]
    else:
        alt = prefix + layer_name
    return f'/activations/{model_name}/{alt}/'


def pearson_r(y_true, y_pred):
    """Return the Pearson correlation coefficient between targets and predictions.

    Both arrays are squeezed first, so (n, 1) column arrays are accepted. The p-value is
    discarded; this is meant to be wrapped with ``make_scorer``.
    """
    from scipy.stats import pearsonr

    truth = y_true.squeeze()
    preds = y_pred.squeeze()
    coefficient, _p_value = pearsonr(truth, preds)
    return coefficient


# Scoring function for the ridge GridSearchCV in prepare_encoderbm(): Pearson r on each CV fold, higher is better.
scorer = make_scorer(pearson_r, greater_is_better=True)


def l2_step_sensitivity(encbm, x_norm, mean, std, clamp_min, clamp_max, max_steps=30, threshold=1.0, direction='maximize', device=torch.device('cpu'), targets=(0.0, 0.5, 1.0, 1.5)):
    """Push ``x_norm`` step by step to raise or lower ``encbm``'s output, snapshotting along the way.

    Each step follows the input gradient of ``encbm``'s output (negated unless
    ``direction == 'maximize'``). The gradient is rescaled by ``std`` into pixel space,
    normalised to unit L2 norm there, mapped back, and the result is clipped to
    ``[clamp_min, clamp_max]``. After every step the sensitivity
    ``si = |encbm(x) - encbm(x_norm)|`` is measured. The first time ``si >= t`` for a
    target ``t``, a snapshot is stored. The loop stops early once ``si > threshold``.

    Args:
        encbm: Module returning a single scalar for the input (assumes batch size 1).
        x_norm: Starting image in normalised space, e.g. (1, 3, H, W).
        mean, std: Per-channel normalisation constants used for pixel-space conversion.
        clamp_min, clamp_max: Normalised-space bounds matching pixel values 0 and 1.
        max_steps: Maximum number of steps; values < 1 perform no steps.
        threshold: Early-stopping sensitivity.
        direction: ``'maximize'``, or anything else to minimise.
        device: Device for the auxiliary tensors.
        targets: Sensitivity levels to snapshot. ``0.0`` records the clean image.

    Returns:
        ``(final_norm, snapshots)``. ``snapshots[t]`` is
        ``(pixel image C x H x W on CPU in [0, 1], achieved si, step index)``. Targets never
        reached receive the final image, final si and last step index (0 if no steps ran).
    """
    x0 = x_norm.detach()
    d = torch.zeros_like(x_norm, device=device)
    std_t = torch.tensor(std, device=device).view(1, 3, 1, 1)

    def score(x):
        # Pure evaluation: no autograd graph is needed, which saves memory and time.
        with torch.no_grad():
            return encbm(x).squeeze().cpu().item()

    def to_pixels(x):
        return denormalize(x, mean, std, device).clamp(0, 1).cpu().squeeze(0)

    r_orig = score(x0)
    targets_sorted = sorted(targets)
    snapshots = {}
    if 0.0 in targets_sorted:
        snapshots[0.0] = (to_pixels(x0).clone(), 0.0, 0)

    last_step = 0
    final_si = None
    for step in range(1, max_steps + 1):
        x_adv = (x0 + d).clamp(clamp_min, clamp_max).clone().detach().requires_grad_(True)
        out = encbm(x_adv).squeeze()
        loss = out if direction == 'maximize' else -out
        grad_x = torch.autograd.grad(loss, x_adv)[0]
        # Normalise the step to unit L2 norm in *pixel* space, then map it back to normalised space.
        g_pix = grad_x * std_t
        gp_norm = g_pix.reshape(g_pix.size(0), -1).norm(p=2, dim=1, keepdim=True).clamp_min(1e-8).view(-1, 1, 1, 1)
        step_norm = (g_pix / gp_norm) / std_t
        with torch.no_grad():
            d = (d + step_norm).detach()
            # Project back so that x0 + d stays inside the valid pixel range.
            d = ((x0 + d).clamp(clamp_min, clamp_max) - x0).detach()
        x_now = (x0 + d).clamp(clamp_min, clamp_max).detach()
        si = abs(r_orig - score(x_now))
        last_step, final_si = step, float(si)
        for t in targets_sorted:
            if t not in snapshots and si >= t:
                snapshots[t] = (to_pixels(x_now).clone(), float(si), step)
        if si > threshold:
            break

    final_norm = (x0 + d).clamp(clamp_min, clamp_max)
    if final_si is None:
        # No step ran; score the (clamped) starting point.
        final_si = float(abs(r_orig - score(final_norm)))
    final_pix = to_pixels(final_norm)
    for t in targets_sorted:
        if t not in snapshots:
            snapshots[t] = (final_pix.clone(), float(final_si), last_step)
    torch.cuda.empty_cache(); gc.collect()
    return final_norm, snapshots


def load_image_for_model(mean, std, device, png_path):
    """Adapter: load ``png_path`` with a model's normalisation constants.

    Returns the ``(x_norm, x_pix)`` pair produced by ``load_png_image``.
    """
    x_norm, x_pix = load_png_image(png_path, mean, std, device)
    return x_norm, x_pix


def _best_layer_for(model_name, subject, region):
    """Return the encoder layer to read for (subject, region); robust models always use 'avgpool'."""
    with open('saved/best_layers_per_subj_region.pkl', 'rb') as f:
        best_layers = pickle.load(f)
    best_layers['VGG16-robust-l2-3'] = best_layers.get('vgg16', best_layers.get(model_name, {}))
    # .get() so models whose layer is fixed by the override below need no pickle entry.
    layer_name = best_layers.get(model_name, {}).get((subject, region))
    if 'robust' in model_name.lower():
        layer_name = 'avgpool'
    if layer_name is None:
        raise KeyError(f"No best layer recorded for model {model_name!r} and (subject, region)={(subject, region)!r}")
    return layer_name


def _load_layer_activations(dir_):
    """Concatenate the natsorted 'batch*' files in ``dir_`` (via np.load) and flatten to float32 (n, features)."""
    batch_files = natsorted([f for f in os.listdir(dir_) if f.startswith('batch')])
    acts = np.concatenate([np.load(os.path.join(dir_, f)) for f in batch_files], axis=0)
    return acts.reshape(acts.shape[0], -1).astype(np.float32)


def _load_roi_targets(subject, region, train_idx):
    """Return the mean over the first 50 listed voxels for (subject, region) at rows ``train_idx``, shape (n, 1)."""
    brain_data_paths = [
        '../nsd_processed/s1_FFA_t7.pt', '../nsd_processed/s2_FFA_t7.pt', '../nsd_processed/s5_FFA_t7.pt', '../nsd_processed/s7_FFA_t7.pt',
        '../nsd_processed/s1_EBA_t7.pt', '../nsd_processed/s2_EBA_t7.pt', '../nsd_processed/s5_EBA_t7.pt', '../nsd_processed/s7_EBA_t7.pt',
        '../nsd_processed/s1_PPA_t7.pt', '../nsd_processed/s2_PPA_t7.pt', '../nsd_processed/s5_PPA_t7.pt', '../nsd_processed/s7_PPA_t7.pt'
    ]
    # map_location='cpu' keeps .numpy() valid even for tensors saved from a GPU. The
    # temporary list is released right after concatenation (the old `del b` loop was a no-op).
    brain_data = torch.cat([torch.load(p, map_location='cpu') for p in brain_data_paths], dim=1).numpy()
    brain_train = brain_data[train_idx]
    del brain_data
    with open('saved/subject_region_to_top50_global.pkl', 'rb') as f:
        subject_region_to_top50_global = pickle.load(f)
    top50 = subject_region_to_top50_global[(subject, region)][:50]
    return brain_train[:, top50].mean(axis=1, keepdims=True)


def _fit_ridge_brainmodel(acts_train, brain_train_sub, device):
    """Grid-search Ridge alpha (5-fold CV, Pearson-r scoring) and wrap the best fit as a torch linear head."""
    grid = GridSearchCV(Ridge(), {'alpha': alphas}, scoring=scorer, cv=5)
    grid.fit(acts_train, brain_train_sub)
    ridgemodel = grid.best_estimator_
    ridgemodel.coef_ = ridgemodel.coef_.reshape(1, -1)
    return build_brainmodel_from_linear_model(ridgemodel, ridgemodel.coef_.shape[1], brain_train_sub.shape[1], device)


def _normalized_pixel_bounds(mean, std, device):
    """Return (clamp_min, clamp_max): pixel values 0 and 1 mapped into normalised space, shape (1, 3, 1, 1)."""
    mean_t = torch.tensor(mean, device=device).view(1, 3, 1, 1)
    std_t = torch.tensor(std, device=device).view(1, 3, 1, 1)
    clamp_min = (torch.tensor([0, 0, 0], device=device).view(1, 3, 1, 1) - mean_t) / std_t
    clamp_max = (torch.tensor([1, 1, 1], device=device).view(1, 3, 1, 1) - mean_t) / std_t
    return clamp_min, clamp_max


def prepare_encoderbm(model_name, subject, region, device):
    """Build an EncoderBM predicting the (subject, region) response from ``model_name``.

    Steps: load the encoder, pick its layer, fit a ridge map from that layer's cached
    activations (rows in the '485_unique' index file) to the region target, and wrap
    encoder and ridge head into one differentiable module.

    Returns:
        ``(encbm, MEAN, STD, clamp_min, clamp_max)``: the module in eval mode on
        ``device``, the normalisation constants from ``load_model``, and the
        normalised-space bounds corresponding to pixel values 0 and 1.

    Raises:
        KeyError: If no layer is recorded for a non-robust model/subject/region.
    """
    encoder, (MEAN, STD) = load_model(model_name)
    encoder.to(device).eval()
    layer_name = _best_layer_for(model_name, subject, region)
    dir_ = load_dir_for_layer(model_name, layer_name)
    train_idx = torch.load('../nsd_processed/485_unique.pt', map_location='cpu').cpu().numpy()
    # Index straight away so the full activation matrix is freed before the brain data loads.
    acts_train = _load_layer_activations(dir_)[train_idx]
    gc.collect()
    brain_train_sub = _load_roi_targets(subject, region, train_idx)
    gc.collect()
    brainmodel = _fit_ridge_brainmodel(acts_train, brain_train_sub, device)
    del acts_train, brain_train_sub
    gc.collect()
    hook_name = layer_name.split('model.', 1)[1] if layer_name.startswith('model.') else layer_name
    encbm = EncoderBM(encoder, brainmodel, hook_name).to(device).eval()
    clamp_min, clamp_max = _normalized_pixel_bounds(MEAN, STD, device)
    torch.cuda.empty_cache(); gc.collect()
    return encbm, MEAN, STD, clamp_min, clamp_max


def _attack_panels(snaps, targets, clean_img, label):
    """Turn one attack direction's snapshots into display panels, in ``targets`` order.

    Returns four parallel lists ``(imgs, diffs, heatmaps, titles)``:
    - the HWC snapshot image;
    - its difference from ``clean_img``, min-max scaled to [0, 1] (zeros if constant);
    - the per-pixel L2 magnitude of that difference, scaled to [0, 1];
    - a title string.
    """
    imgs, diffs, heatmaps, titles = [], [], [], []
    for t in targets:
        img_t, si_t, step_idx = snaps[t]
        imgs.append(img_t.permute(1, 2, 0).numpy())
        titles.append(f"{label}: si={t} (ach {si_t:.3f}, step={step_idx})")
        diff = (img_t - clean_img).permute(1, 2, 0).numpy()
        dmin, dmax = diff.min(), diff.max()
        diffs.append(np.zeros_like(diff) if abs(dmax - dmin) < 1e-8 else (diff - dmin) / (dmax - dmin))
        mag = np.linalg.norm(diff, axis=2)
        heatmaps.append((mag - mag.min()) / (mag.max() - mag.min() + 1e-8))
    return imgs, diffs, heatmaps, titles


def run_for_one_model(encbm_src, MEAN_src, STD_src, CLAMP_MIN_src, CLAMP_MAX_src, model_name, device, png_path):
    """Run minimise and maximise attacks on one image and assemble the figure panels.

    Column order: minimise snapshots at descending targets, then the clean image, then
    maximise snapshots at ascending targets.

    Returns:
        A dict with ``model_name`` and parallel lists ``imgs``, ``diffs``, ``heatmaps``
        and ``titles`` (one entry per column).
    """
    x_norm, _ = load_image_for_model(MEAN_src, STD_src, device, png_path)
    attack_kwargs = dict(max_steps=MAX_STEPS, threshold=THRESHOLD, device=device)
    _, snaps_min = l2_step_sensitivity(encbm_src, x_norm, MEAN_src, STD_src, CLAMP_MIN_src, CLAMP_MAX_src, direction='minimize', targets=(0.0,) + TARGET_SIS_ONE_SIDE, **attack_kwargs)
    _, snaps_max = l2_step_sensitivity(encbm_src, x_norm, MEAN_src, STD_src, CLAMP_MIN_src, CLAMP_MAX_src, direction='maximize', targets=TARGET_SIS_ONE_SIDE, **attack_kwargs)
    # The 0.0 snapshot of the minimise run is the unperturbed image in pixel space.
    clean_img = snaps_min[0.0][0]
    min_imgs, min_diffs, min_heat, min_titles = _attack_panels(snaps_min, list(reversed(TARGET_SIS_ONE_SIDE)), clean_img, 'min')
    max_imgs, max_diffs, max_heat, max_titles = _attack_panels(snaps_max, list(TARGET_SIS_ONE_SIDE), clean_img, 'max')
    clean_np = clean_img.permute(1, 2, 0).numpy()
    imgs = min_imgs + [clean_np] + max_imgs
    diffs = min_diffs + [np.zeros_like(clean_np)] + max_diffs
    # Build the blank heatmap from the image itself so it works even with no target columns.
    heatmaps = min_heat + [np.zeros_like(clean_np[..., 0])] + max_heat
    titles = min_titles + ["clean (si=0.0)"] + max_titles
    del snaps_min, snaps_max, x_norm, clean_img
    torch.cuda.empty_cache(); gc.collect()
    return {'model_name': model_name, 'imgs': imgs, 'diffs': diffs, 'heatmaps': heatmaps, 'titles': titles}


def _save_comparison_figure(results, img_path, region, out_dir):
    """Plot each model's panels (image row + diff row) and save one PDF for ``img_path``."""
    ncols = len(TARGET_SIS_ONE_SIDE) * 2 + 1
    nrows = 2 * len(results)
    # squeeze=False always yields a 2-D axes array, even when ncols == 1.
    fig, axes = plt.subplots(nrows, ncols, figsize=(4.8 * ncols, 3.8 * nrows), squeeze=False)
    try:
        for mi, res in enumerate(results):
            img_row, diff_row = axes[2 * mi], axes[2 * mi + 1]
            for ci in range(ncols):
                img_row[ci].imshow(res['imgs'][ci])
                img_row[ci].set_title(f"{res['model_name']} | {res['titles'][ci]}", fontsize=12)
                img_row[ci].axis('off')
                # NOTE(review): diffs are 3-channel images, and matplotlib ignores cmap for RGB
                # input; res['heatmaps'] (single-channel) is computed but never plotted — confirm intent.
                diff_row[ci].imshow(res['diffs'][ci], cmap='bwr')
                diff_row[ci].axis('off')
        fig.tight_layout()
        base = os.path.splitext(os.path.basename(img_path))[0]
        fig.savefig(os.path.join(out_dir, f"{base}_{region}.pdf"), bbox_inches='tight', dpi=300)
    finally:
        plt.close(fig)


def main():
    """For each region, build every model's EncoderBM and save one attack figure per sampled image."""
    sampled_images = sample_image_paths(IMAGE_ROOT, k=100)
    for region in REGIONS:
        out_dir = f"newest_{region}"
        ensure_dir(out_dir)
        cache = {}
        try:
            for m in MODELS:
                cache[m] = prepare_encoderbm(m, SUBJECT, region, DEVICE)
            for img_path in sampled_images:
                results = []
                try:
                    for m in MODELS:
                        encbm_src, MEAN_src, STD_src, CLAMP_MIN_src, CLAMP_MAX_src = cache[m]
                        results.append(run_for_one_model(encbm_src, MEAN_src, STD_src, CLAMP_MIN_src, CLAMP_MAX_src, m, DEVICE, png_path=img_path))
                        torch.cuda.empty_cache(); gc.collect()
                    _save_comparison_figure(results, img_path, region, out_dir)
                finally:
                    # Drop per-image panel arrays before the next image.
                    results.clear()
                    torch.cuda.empty_cache(); gc.collect()
        finally:
            # Release this region's models before building the next region's.
            cache.clear()
            torch.cuda.empty_cache(); gc.collect()


if __name__ == "__main__":
    # Run the full sweep only when executed as a script, not when imported.
    main()
