import os
import numpy as np
import torch
import gc
from natsort import natsorted
from sklearn.linear_model import Ridge
from scipy.stats import pearsonr
import pickle
import sys
import pandas as pd

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
from brainmodels.utils import preprocess_images
from brainmodels.models import load_model

import argparse
# ---- command line ------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, required=True)
args = parser.parse_args()
MODEL_NAME = args.model_name

# ---- experiment configuration -----------------------------------------------
n_images = 50  # how many entries of fixed_test_indices.npy are attacked
n_voxels = 50  # how many of the precomputed top voxels are attacked per subject/region
subjects = ['s1', 's2', 's5', 's7']
regions = ['FFA', 'EBA', 'PPA']


def _load_pickle(path):
    """Unpickle a locally produced artefact.

    pickle can execute arbitrary code, so only use this on files this project
    wrote itself (the ``saved/`` directory).
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


# Lookups indexed by (subject, region); best_layers is indexed by model name first.
subject_region_to_top50_global = _load_pickle('saved/subject_region_to_top50_global.pkl')
subject_region_to_global_indices = _load_pickle('saved/subject_region_to_global_indices.pkl')
best_layers = _load_pickle('saved/best_layers_per_subj_region.pkl')

# NOTE(review): only s1_FFA is loaded here, yet the main loop slices this tensor
# with "global" voxel indices for every subject/region. The concatenation along
# dim=1 suggests several files were intended -- confirm this list is complete.
brain_data_paths = ['../nsd_processed/s1_FFA_t7.pt']
# Concatenate without keeping the per-file tensors alive alongside the result
# (the original kept them in a module-level list, holding the data twice).
brain_data = torch.cat([torch.load(path) for path in brain_data_paths], dim=1)

# Fixed train/test split: row indices into brain_data and images.
train_idx = torch.load('../nsd_processed/485_unique.pt')
test_idx = torch.load('../nsd_processed/515_shared.pt')
images = np.load('../nsd_processed/nsd_stimuli1000.npy')

def split_data(acts, brain, train_indices=None, test_indices=None):
    """Split activations and brain responses into train and test rows.

    Args:
        acts: array indexed along its first axis by stimulus.
        brain: array indexed along its first axis by stimulus.
        train_indices: row indices for the training split. Defaults to the
            module-level ``train_idx`` (loaded once from
            ``../nsd_processed/485_unique.pt``) instead of re-reading the file
            on every call.
        test_indices: row indices for the test split. Defaults to the
            module-level ``test_idx`` (``../nsd_processed/515_shared.pt``).

    Returns:
        ``(acts_train, acts_test, brain_train, brain_test)``.
    """
    if train_indices is None:
        train_indices = train_idx
    if test_indices is None:
        test_indices = test_idx
    return (acts[train_indices], acts[test_indices],
            brain[train_indices], brain[test_indices])

encoder, (MEAN, STD) = load_model(MODEL_NAME)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Per-channel normalisation constants shaped (1, 3, 1, 1) for broadcasting.
_mean_t = torch.tensor(MEAN).view(1, 3, 1, 1).to(device)
_std_t = torch.tensor(STD).view(1, 3, 1, 1).to(device)
# The [0, 1] pixel range mapped through (x - MEAN) / STD: bounds used to keep
# perturbed inputs inside the valid image range in normalised space.
_pixel_zero = torch.zeros(1, 3, 1, 1, device=device)
CLAMP_MIN = (_pixel_zero - _mean_t) / _std_t
CLAMP_MAX = (_pixel_zero + 1.0 - _mean_t) / _std_t

# presumably preprocess_images normalises with MEAN/STD (project helper) -- confirm.
images_tensor = preprocess_images(images, MEAN, STD, device=device)
images_test_tensor = images_tensor[test_idx]
# Positions within the test split that will be attacked.
selected_image_indices = np.load(os.path.expanduser('fixed_test_indices.npy'))[:n_images]

def build_brainmodel_from_linear_model(lm, input_dim, output_dim, device):
    """Wrap a fitted scikit-learn linear model as ``Sequential(Linear)``.

    Args:
        lm: fitted estimator exposing ``coef_`` -- shape
            ``(output_dim, input_dim)``, or ``(input_dim,)`` for a single
            target -- and ``intercept_`` (length ``output_dim`` or a scalar,
            e.g. ``0.0`` when fitted without an intercept).
        input_dim: number of input features of the layer.
        output_dim: number of outputs (targets) of the layer.
        device: torch device the module is moved to.

    Returns:
        ``torch.nn.Sequential`` holding one ``Linear`` whose float32 weight
        and bias are copied from ``lm``.

    Raises:
        ValueError: if ``coef_``/``intercept_`` do not match
            ``(output_dim, input_dim)`` / ``(output_dim,)``.
    """
    bm = torch.nn.Sequential(torch.nn.Linear(input_dim, output_dim))

    weight = torch.as_tensor(np.asarray(lm.coef_), dtype=torch.float32)
    if weight.ndim == 1:  # single-target estimators expose a 1-D coef_
        weight = weight.unsqueeze(0)
    if tuple(weight.shape) != (output_dim, input_dim):
        raise ValueError(
            f"coef_ has shape {tuple(weight.shape)}, expected ({output_dim}, {input_dim})")

    bias = torch.as_tensor(np.asarray(lm.intercept_), dtype=torch.float32).reshape(-1)
    if bias.numel() == 1:  # scalar intercept (single target or fit_intercept=False)
        bias = bias.expand(output_dim)
    if tuple(bias.shape) != (output_dim,):
        raise ValueError(
            f"intercept_ has shape {tuple(bias.shape)}, expected ({output_dim},)")

    # Copy into the existing parameters so shapes are enforced; the original
    # swapped in new Parameters of whatever shape coef_ happened to have.
    with torch.no_grad():
        bm[0].weight.copy_(weight)
        bm[0].bias.copy_(bias)
    return bm.to(device)

class EncoderBM(torch.nn.Module):
    """Chain an image encoder with a linear brain model.

    A forward hook on the encoder sub-module named ``layer_name`` captures
    that layer's output. ``forward`` runs the encoder, reduces the captured
    activation to 2-D and feeds it to ``brainmodel``.

    The hook does not reference ``self``. It is removed automatically when the
    wrapper is garbage-collected, or explicitly via ``remove_hook()``.
    Otherwise a long-lived, shared encoder would accumulate one hook per
    wrapper and keep every old wrapper (and its brain model) alive.
    """

    def __init__(self, encoder, brainmodel, layer_name):
        super().__init__()
        import weakref

        self.encoder = encoder
        self.brainmodel = brainmodel
        self.layer_name = layer_name
        self.activation = {}
        layer = dict([*self.encoder.named_modules()])[self.layer_name]
        handle = layer.register_forward_hook(self.get_activation())
        # The finalizer holds only the handle, not self, so it does not keep
        # the wrapper alive; calling it is idempotent.
        self._hook_finalizer = weakref.finalize(self, handle.remove)

    def get_activation(self):
        """Return a forward hook that stores the layer output in ``self.activation``.

        Tuple outputs are reduced to their first element. The hook closes over
        the dict itself (not ``self``) to avoid an encoder -> hook -> wrapper
        reference chain.
        """
        store = self.activation
        key = self.layer_name

        def hook(model, input, output):
            store[key] = output[0] if isinstance(output, tuple) else output

        return hook

    def remove_hook(self):
        """Detach the forward hook from the encoder now (safe to call twice)."""
        self._hook_finalizer()

    def forward(self, x):
        # Encoders exposing encode_image (e.g. CLIP-style) are driven through it;
        # only the hooked activation is used, not the encoder's return value.
        _ = self.encoder.encode_image(x) if hasattr(self.encoder, "encode_image") else self.encoder(x)
        x = self.activation[self.layer_name]
        if hasattr(self, "mapping") and self.mapping == "cnn":
            x = x.to(torch.float32)
        elif x.ndim == 3:
            # 3-D activations are mean-pooled over dim 1 (presumably tokens -- confirm).
            x = x.mean(dim=1).to(torch.float32)
        else:
            x = torch.flatten(x, start_dim=1).to(torch.float32)
        return self.brainmodel(x)

def fgsm_single_direction(model, image, device, epsilon, alpha, iterations, voxel_idx, direction,
                          clamp_min=None, clamp_max=None):
    """Iterative signed-gradient (L-inf, BIM-style) attack on a single voxel.

    Starting from a zero perturbation, each iteration moves ``delta`` by
    ``alpha * sign(grad)`` and clips it to ``[-epsilon, epsilon]``. The model
    is always evaluated on ``(image + delta)`` clamped to the valid range.

    Args:
        model: module mapping an image batch to ``(batch, n_voxels)`` outputs.
        image: input batch; a single image is expected (``.item()`` is used).
        device: device to run on.
        epsilon: L-inf budget on ``delta`` in the model's input space (float
            or tensor broadcastable to ``image``).
        alpha: per-iteration step size, same space as ``epsilon``.
        iterations: number of gradient steps.
        voxel_idx: column of the model output that is attacked.
        direction: ``'maximize'`` or ``'minimize'`` the voxel response.
        clamp_min, clamp_max: valid input range; default to the module-level
            ``CLAMP_MIN`` / ``CLAMP_MAX``.

    Returns:
        ``(perturbed_image, perturbed_response, original_response)``.

    Raises:
        ValueError: for an unknown ``direction``.
    """
    if direction not in ('maximize', 'minimize'):
        raise ValueError(f"direction must be 'maximize' or 'minimize', got {direction!r}")
    lo = CLAMP_MIN if clamp_min is None else clamp_min
    hi = CLAMP_MAX if clamp_max is None else clamp_max
    sign = -1.0 if direction == 'minimize' else 1.0

    def voxel_response(x):
        # Index the voxel column explicitly: squeeze() would collapse a
        # single-voxel output to 0-d and make [voxel_idx] fail.
        out = model(x)
        return out.reshape(out.shape[0], -1)[:, voxel_idx].sum()

    image = image.to(device)
    delta = torch.zeros_like(image, device=device)
    for _ in range(iterations):
        delta.requires_grad_(True)
        loss = sign * voxel_response((image + delta).clamp(lo, hi))
        # Gradient w.r.t. delta only: parameter .grad buffers are never filled,
        # so there is nothing to clear between iterations.
        (grad,) = torch.autograd.grad(loss, delta)
        with torch.no_grad():
            delta += alpha * grad.sign()
            delta.clamp_(-epsilon, epsilon)

    # Pure evaluation: no autograd graph needed.
    with torch.no_grad():
        pert_img = (image + delta).clamp(lo, hi)
        return pert_img, voxel_response(pert_img).item(), voxel_response(image).item()

def fgsm_params():
    """Return ``(epsilon, alpha, iterations, attack_fn)`` for the L-inf attack.

    The budget (3/255) and step (1/255) are specified per channel in [0, 1]
    pixel units and divided by the per-channel ``STD``, moving them into the
    model's normalised input space. Three iterations are used.
    """
    channel_std = torch.tensor(STD).view(1, 3, 1, 1).to(device)

    def per_channel(value):
        return torch.tensor([value, value, value]).view(1, 3, 1, 1).to(device)

    budget = per_channel(3 / 255) / channel_std
    step = per_channel(1 / 255) / channel_std
    return budget, step, 3, fgsm_single_direction

def l2_attack_single_direction(model, image, device, epsilon, alpha, iterations, voxel_idx, direction,
                               clamp_min=None, clamp_max=None, std=None):
    """Projected L2 gradient attack on a single voxel, with the L2 ball in pixel space.

    ``delta`` lives in the model's normalised input space, where
    ``x_norm = (x_pix - mean) / std``. Each step:
      1. converts the gradient to pixel space (``grad / std``, chain rule);
      2. takes a step of pixel L2 length ``alpha`` along it;
      3. maps the step back to normalised space (``/ std``);
      4. rescales ``delta`` so that its pixel-space L2 norm is at most
         ``epsilon``, per sample.
    The original code multiplied by ``std`` in step 1, which tilts the step
    direction whenever the channels have different std.

    Args:
        model: module mapping an image batch to ``(batch, n_voxels)`` outputs.
        image: input batch; a single image is expected (``.item()`` is used).
        device: device to run on.
        epsilon: pixel-space L2 radius.
        alpha: pixel-space L2 length of each step.
        iterations: number of gradient steps.
        voxel_idx: column of the model output that is attacked.
        direction: ``'maximize'`` or ``'minimize'`` the voxel response.
        clamp_min, clamp_max: valid input range; default to the module-level
            ``CLAMP_MIN`` / ``CLAMP_MAX``.
        std: per-channel normalisation std (3 values); defaults to ``STD``.

    Returns:
        ``(perturbed_image, perturbed_response, original_response)``.

    Raises:
        ValueError: for an unknown ``direction``.
    """
    if direction not in ('maximize', 'minimize'):
        raise ValueError(f"direction must be 'maximize' or 'minimize', got {direction!r}")
    lo = CLAMP_MIN if clamp_min is None else clamp_min
    hi = CLAMP_MAX if clamp_max is None else clamp_max
    sign = -1.0 if direction == 'minimize' else 1.0

    image = image.to(device)
    channel_std = torch.as_tensor(STD if std is None else std, device=device,
                                  dtype=image.dtype).view(1, 3, 1, 1)

    def voxel_response(x):
        # Explicit column indexing also works for single-voxel outputs.
        out = model(x)
        return out.reshape(out.shape[0], -1)[:, voxel_idx].sum()

    def per_sample_l2(t):
        # L2 norm of each sample, shaped for broadcasting; floored to avoid /0.
        return t.flatten(1).norm(p=2, dim=1).clamp_min(1e-8).view(-1, 1, 1, 1)

    delta = torch.zeros_like(image, device=device)
    for _ in range(iterations):
        delta.requires_grad_(True)
        loss = sign * voxel_response((image + delta).clamp(lo, hi))
        (grad,) = torch.autograd.grad(loss, delta)
        with torch.no_grad():
            # d loss / d delta_pix = (d loss / d delta_norm) / std
            grad_pix = grad / channel_std
            step_pix = alpha * grad_pix / per_sample_l2(grad_pix)
            delta += step_pix / channel_std
            # Project back onto the pixel-space L2 ball of radius epsilon.
            delta *= (epsilon / per_sample_l2(delta * channel_std)).clamp(max=1.0)

    with torch.no_grad():
        pert_img = (image + delta).clamp(lo, hi)
        return pert_img, voxel_response(pert_img).item(), voxel_response(image).item()

def l2_params():
    """Return ``(epsilon, alpha, iterations, attack_fn)`` for the L2 attack.

    Radius 5.0 and step 1.0 (both scalar tensors on ``device``), five
    iterations, using ``l2_attack_single_direction``.
    """
    radius, step_size = (torch.tensor(value, device=device) for value in (5.0, 1.0))
    return radius, step_size, 5, l2_attack_single_direction

def get_shuffled_control(image, delta_norm, clamp_min, clamp_max, device):
    """Build a control image from a randomly permuted copy of a perturbation.

    ``delta_norm`` (normalised input space) is converted to pixel space with
    ``STD``. Within each sample, all of its values are permuted across
    channels and positions, then converted back and added to ``image``. Before
    the final clamp, the control perturbation therefore has the same set of
    pixel-space values as the original, but a random arrangement.

    Returns:
        ``image`` plus the shuffled perturbation, clamped to
        ``[clamp_min, clamp_max]``.
    """
    channel_std = torch.tensor(STD, device=device, dtype=image.dtype).view(1, 3, 1, 1)
    pixel_delta = delta_norm * channel_std
    permuted = torch.empty_like(pixel_delta)
    for i, sample in enumerate(pixel_delta):
        values = sample.reshape(-1)
        order = torch.randperm(values.numel(), device=device)
        permuted[i] = values[order].view_as(sample)
    return (image + permuted / channel_std).clamp(clamp_min, clamp_max)

# ---- loop-invariant setup (previously repeated per region / per voxel) ------
RESULTS_TAG = '_'
results_dir = os.path.join(os.getcwd(), f'corrs_{RESULTS_TAG}')
results_folder = os.path.join(os.getcwd(), f'results_{RESULTS_TAG}')
os.makedirs(os.path.join(results_dir, 'voxel_corrs'), exist_ok=True)
os.makedirs(results_folder, exist_ok=True)
act_root = 'CAMERA_READY/attacks/activations'
# (attack_type, (epsilon, alpha, iterations, attack_fn)); these were rebuilt for every voxel.
attack_setups = (('fgsm', fgsm_params()), ('l2', l2_params()))
# Converted once instead of once per (subject, region).
brain_data_np = brain_data.numpy()


def _attack_with_control(model, image, voxel_local_idx, setup, direction):
    """Attack one voxel in one direction and score a shuffled-perturbation control.

    Returns ``(original_response, perturbed_response, control_response)`` as floats.
    """
    epsilon, step, iterations, attack_fn = setup
    adv_img, adv_resp, orig_resp = attack_fn(model, image, device, epsilon, step, iterations,
                                             voxel_local_idx, direction)
    control_img = get_shuffled_control(image, adv_img - image, CLAMP_MIN, CLAMP_MAX, device)
    with torch.no_grad():  # evaluation only; no autograd graph needed
        control_resp = model(control_img).squeeze()[voxel_local_idx].item()
    return float(orig_resp), float(adv_resp), float(control_resp)


for subject in subjects:
    torch.cuda.empty_cache()
    gc.collect()
    for region in regions:
        csv_path = os.path.join(results_folder, f'{MODEL_NAME}_{subject}_{region}_attack_results.csv')
        if os.path.exists(csv_path):
            print(f"Skipping {subject} {region} — results already exist.")
            continue
        print(f'=== Running {subject} {region} ===')

        layer_name = best_layers[MODEL_NAME][subject, region]
        dir_ = os.path.join(act_root, MODEL_NAME, layer_name)
        best_alphas = np.load(os.path.join(dir_, 'best_alphas.npy'))
        # The saved directory uses the full name; for the hook lookup a leading
        # 'model.' is stripped (except for dreamsim_vitb32) -- presumably a
        # wrapper prefix, confirm.
        if MODEL_NAME != 'dreamsim_vitb32' and layer_name.startswith('model.'):
            layer_name = layer_name.split('model.')[1]

        # Fit the ridge encoder on precomputed activations.
        batch_files = natsorted([f for f in os.listdir(dir_) if f.startswith('batch')])
        activations = np.concatenate([np.load(os.path.join(dir_, f)) for f in batch_files], axis=0)
        activations_flattened = activations.reshape(activations.shape[0], -1)
        acts_train, acts_test, brain_train, brain_test = split_data(activations_flattened, brain_data_np)

        indices = subject_region_to_global_indices[(subject, region)]
        brain_train_sub = brain_train[:, indices]
        brain_test_sub = brain_test[:, indices]
        ridgemodel = Ridge(alpha=best_alphas[indices], max_iter=5000)
        ridgemodel.fit(acts_train, brain_train_sub)
        y_pred = ridgemodel.predict(acts_test)
        voxel_corr = np.array([pearsonr(brain_test_sub[:, v], y_pred[:, v])[0] for v in range(y_pred.shape[1])])
        np.savez(os.path.join(results_dir, 'voxel_corrs', f'{MODEL_NAME}_{subject}_{region}_voxel_corr.npz'), voxel_corr=voxel_corr)

        # (global voxel id, column in brain_*_sub) resolved once, not per image.
        top50_voxels = subject_region_to_top50_global[(subject, region)][:n_voxels]
        voxel_targets = [(voxel_idx, indices.index(voxel_idx)) for voxel_idx in top50_voxels]
        brainmodel = build_brainmodel_from_linear_model(ridgemodel, ridgemodel.coef_.shape[1], brain_train_sub.shape[1], device)
        encoderbm = EncoderBM(encoder, brainmodel, layer_name).to(device).eval()

        # Rows are appended in the original order (fgsm max/min, then l2
        # max/min), and the RNG draws in get_shuffled_control happen in the same
        # order. Both rows of an attack report the original response from the
        # 'maximize' run, as before.
        results = []
        for image_count, image_idx in enumerate(selected_image_indices):
            image = images_test_tensor[image_idx].unsqueeze(0)
            for voxel_idx, voxel_local_idx in voxel_targets:
                label = brain_test_sub[image_idx, voxel_local_idx]
                for attack_type, setup in attack_setups:
                    orig_resp, max_resp, max_ctrl = _attack_with_control(encoderbm, image, voxel_local_idx, setup, 'maximize')
                    _, min_resp, min_ctrl = _attack_with_control(encoderbm, image, voxel_local_idx, setup, 'minimize')
                    results.append((image_idx, voxel_idx, attack_type, label, orig_resp, max_resp, 'maximize', max_ctrl))
                    results.append((image_idx, voxel_idx, attack_type, label, orig_resp, min_resp, 'minimize', min_ctrl))
            print(f"Done {image_count+1}/{n_images} images.")

        df_results = pd.DataFrame(results, columns=['image_idx', 'voxel_idx', 'attack_type', 'label', 'old_response', 'perturbed_response', 'direction', 'control_response'])
        df_results.to_csv(csv_path, index=False)
        # Drop per-region models/arrays so the collection below can free them.
        del encoderbm, brainmodel, ridgemodel, activations, activations_flattened
        torch.cuda.empty_cache()
        gc.collect()
        print(f'Done. Results saved for {subject}, {region}.')

print('All done.')
