import os
import numpy as np
import torch
import gc
import random
from natsort import natsorted
from PIL import Image
from sklearn.linear_model import Ridge
from scipy.stats import pearsonr
import pickle
import sys
import pandas as pd
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import make_scorer


sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
import torch.nn as nn
from brainmodels.utils import preprocess_images
from brainmodels.models import load_model

import argparse

# Command-line interface: the only option selects which encoder is attacked.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--model_name',
    type=str,
    required=True,
)
args = parser.parse_args()
MODEL_NAME = args.model_name

def pearson_r(y_true, y_pred):
    """Return Pearson's r between targets and predictions.

    Both inputs are squeezed first, so column vectors of shape (n, 1) are
    accepted. Wrapped with ``make_scorer`` to rank ridge penalties in CV.
    """
    true_flat = y_true.squeeze()
    pred_flat = y_pred.squeeze()
    r, _ = pearsonr(true_flat, pred_flat)
    return r

# Cross-validation scorer: larger Pearson r on the held-out fold is better.
scorer = make_scorer(pearson_r, greater_is_better=True)
# Ridge penalty grid: 20 log-spaced values from 1e-2 to 1e6.
alphas = np.logspace(-2, 6, num=20)


# Experiment size: how many selected test images are attacked, and how many of
# the top-ranked voxels per (subject, region) are averaged into one target.
n_images = 50
n_voxels = 50


RESULTS_TAG = 'epsilon_curve_new'
results_dir = os.path.join(os.getcwd(), f'corrs_{RESULTS_TAG}')
root_img_dir = os.path.join(results_dir, 'attack_results', MODEL_NAME)
results_folder = os.path.join(os.getcwd(), f'results_{RESULTS_TAG}')
os.makedirs(os.path.join(results_dir, 'voxel_corrs'), exist_ok=True)
os.makedirs(root_img_dir, exist_ok=True)
os.makedirs(results_folder, exist_ok=True)


subjects = ['s1', 's2', 's5', 's7']
regions = ['FFA', 'EBA', 'PPA']


def _load_pickle(path):
    """Unpickle *path*.

    Only use on trusted, locally produced files: unpickling can execute code.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


# Precomputed lookups. The main loop indexes the top-50 map by (subject, region)
# and best_layers by model name, then (subject, region).
subject_region_to_top50_global = _load_pickle('saved/subject_region_to_top50_global.pkl')
subject_region_to_global_indices = _load_pickle('saved/subject_region_to_global_indices.pkl')
best_layers = _load_pickle('saved/best_layers_per_subj_region.pkl')


# One response tensor per (region, subject), concatenated along dim 1 in
# region-major order: FFA s1, s2, s5, s7, then EBA ..., then PPA ...
brain_data_paths = [
    f'../nsd_processed/{subj}_{roi}_t7.pt'
    for roi in ('FFA', 'EBA', 'PPA')
    for subj in ('s1', 's2', 's5', 's7')
]
brain_datas = list(map(torch.load, brain_data_paths))
brain_data = torch.cat(brain_datas, dim=1)  # shape [1000, 3514]

# Stimulus index splits: train_idx for fitting, test_idx for evaluation/attacks.
train_idx = torch.load('../nsd_processed/485_unique.pt')
test_idx = torch.load('../nsd_processed/515_shared.pt')

images = np.load('../nsd_processed/nsd_stimuli1000.npy')

def split_data(acts, brain, train_idx=None, test_idx=None):
    """Split stimulus-aligned features and brain responses into train/test rows.

    Args:
        acts: per-stimulus features, indexed along the first axis.
        brain: per-stimulus brain responses, indexed along the first axis.
        train_idx, test_idx: optional row indices (torch tensor, array or list).
            When omitted they are loaded from ``../nsd_processed/485_unique.pt``
            and ``../nsd_processed/515_shared.pt``. Pass already-loaded indices
            to avoid re-reading those files on every call.

    Returns:
        ``(acts_train, acts_test, brain_train, brain_test)``.
    """
    if train_idx is None:
        train_idx = torch.load('../nsd_processed/485_unique.pt')
    if test_idx is None:
        test_idx = torch.load('../nsd_processed/515_shared.pt')
    # Convert to plain NumPy arrays so indexing NumPy inputs does not depend on
    # how NumPy interprets torch tensors used as indices.
    train_rows = np.asarray(train_idx)
    test_rows = np.asarray(test_idx)
    return (acts[train_rows], acts[test_rows], brain[train_rows], brain[test_rows])


encoder, (MEAN, STD) = load_model(MODEL_NAME)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Normalized-space bounds of a valid image: pixel values 0 and 1 mapped through
# (x - MEAN) / STD per channel, shaped (1, 3, 1, 1) so they broadcast over 4-D
# batches whose channel axis is dim 1.
_norm_mean = torch.tensor(MEAN).view(1, 3, 1, 1).to(device)
_norm_std = torch.tensor(STD).view(1, 3, 1, 1).to(device)
CLAMP_MIN = (0.0 - _norm_mean) / _norm_std
CLAMP_MAX = (1.0 - _norm_mean) / _norm_std

images_tensor = preprocess_images(images, MEAN, STD, device=device)
images_test_tensor = images_tensor[test_idx]
# Positions within the test split (used to index images_test_tensor and brain_test_sub).
selected_image_indices = np.load('fixed_test_indices.npy')[:n_images]

def build_brainmodel_from_linear_model(lm, input_dim, output_dim, device):
    """Wrap a fitted linear model as ``torch.nn.Sequential(torch.nn.Linear)``.

    Args:
        lm: fitted model exposing ``coef_`` and ``intercept_`` (e.g. sklearn
            ``Ridge``). ``coef_`` may be 1-D (single target) or 2-D, and
            ``intercept_`` may be a scalar or 1-D. Both are reshaped to the
            ``(output_dim, input_dim)`` / ``(output_dim,)`` layout that
            ``torch.nn.Linear`` expects.
        input_dim: number of input features.
        output_dim: number of predicted targets.
        device: device the returned module is moved to.

    Returns:
        The module on ``device``, with float32 copies of the coefficients.

    Raises:
        ValueError: if ``coef_`` / ``intercept_`` do not hold
            ``output_dim * input_dim`` / ``output_dim`` values.
    """
    weight = np.asarray(lm.coef_, dtype=np.float32).reshape(output_dim, input_dim)
    bias = np.asarray(lm.intercept_, dtype=np.float32).reshape(output_dim)
    bm = torch.nn.Sequential(torch.nn.Linear(input_dim, output_dim))
    # torch.tensor copies, so the module never shares memory with lm's arrays.
    bm[0].weight = torch.nn.Parameter(torch.tensor(weight))
    bm[0].bias = torch.nn.Parameter(torch.tensor(bias))
    return bm.to(device)

class EncoderBM(torch.nn.Module):
    """Encoder plus a linear brain model read out from one intermediate layer.

    ``forward`` runs the encoder (``encode_image`` when it exists, otherwise a
    plain call), captures the output of the submodule named ``layer_name``,
    reduces it to a 2-D float32 feature matrix and feeds it to ``brainmodel``.

    The capture hook is attached only while ``forward`` runs. A permanent hook
    would leak: the script wraps the same ``encoder`` once per (subject, region),
    so permanent hooks pile up on the shared encoder, keep every earlier wrapper
    (and its brain model) alive, and all fire on every encoder pass.
    """

    def __init__(self, encoder, brainmodel, layer_name):
        super().__init__()
        self.encoder = encoder
        self.brainmodel = brainmodel
        self.layer_name = layer_name
        self.activation = {}  # layer_name -> output captured during forward
        # Validate eagerly: raises KeyError if the encoder has no such submodule.
        self._target_layer()

    def _target_layer(self):
        """Return the encoder submodule named ``self.layer_name`` (KeyError if missing)."""
        return dict(self.encoder.named_modules())[self.layer_name]

    def get_activation(self):
        """Build a forward hook that stores the layer output under ``self.layer_name``."""
        def hook(model, input, output):
            # If output is a tuple, take the first element
            self.activation[self.layer_name] = output[0] if isinstance(output, tuple) else output
        return hook

    def forward(self, x):
        # Drop any earlier capture so that a hook which never fires raises
        # KeyError below instead of silently reusing stale features.
        self.activation.pop(self.layer_name, None)
        handle = self._target_layer().register_forward_hook(self.get_activation())
        try:
            # Call the right entry point for each model
            if hasattr(self.encoder, "encode_image"):
                self.encoder.encode_image(x)
            else:
                self.encoder(x)
        finally:
            handle.remove()
        feats = self.activation.pop(self.layer_name)
        if hasattr(self, "mapping") and self.mapping == "cnn":
            feats = feats.to(torch.float32)
        elif feats.ndim == 3:
            feats = feats.mean(dim=1).to(torch.float32)  # average over tokens (ViT-style outputs)
        else:
            feats = torch.flatten(feats, start_dim=1).to(torch.float32)
        return self.brainmodel(feats)

def denormalize(img, device, mean=None, std=None):
    """Undo per-channel normalization, returning ``img * std + mean``.

    Args:
        img: normalized image batch with channels on dim 1 (the statistics are
            reshaped to ``(1, 3, 1, 1)`` for broadcasting).
        device: device on which the channel statistics are placed.
        mean, std: optional 3-element per-channel statistics. They default to
            the ImageNet values this helper always used. Pass the model's
            ``MEAN`` / ``STD`` when the encoder normalizes differently.
    """
    if mean is None:
        mean = [0.485, 0.456, 0.406]
    if std is None:
        std = [0.229, 0.224, 0.225]
    mean_t = torch.as_tensor(mean, dtype=torch.float32).view(1, 3, 1, 1).to(device)
    std_t = torch.as_tensor(std, dtype=torch.float32).view(1, 3, 1, 1).to(device)
    return img * std_t + mean_t

def fgsm_single_direction(model, image, device, epsilon, alpha, iterations, voxel_idx, direction,
                          clamp_min=None, clamp_max=None):
    """Sign-gradient (FGSM / L-inf PGD) attack that pushes a scalar response up or down.

    Args:
        model: module mapping an image batch to one response; ``out.squeeze()``
            must be 0-d so it can be differentiated.
        image: normalized input batch.
        device: device the image is moved to.
        epsilon: L-inf budget on the perturbation, in the same (normalized)
            units as ``image``; a scalar or a tensor broadcastable to it.
        alpha: step size per iteration, same units as ``epsilon``.
        iterations: number of sign-gradient steps.
        voxel_idx: unused; kept so every attack function shares one signature.
        direction: ``'minimize'`` lowers the response; any other value raises it.
        clamp_min, clamp_max: valid input range; default to the module-level
            ``CLAMP_MIN`` / ``CLAMP_MAX``.

    Returns:
        ``(perturbed_image, perturbed_response, original_response)``. The image is
        detached and both responses are Python floats.
    """
    lo = CLAMP_MIN if clamp_min is None else clamp_min
    hi = CLAMP_MAX if clamp_max is None else clamp_max
    image = image.to(device)
    delta = torch.zeros_like(image, device=device)
    for _ in range(iterations):
        delta.requires_grad_()
        x_adv = (image + delta).clamp(lo, hi)   # clamp each step
        val = model(x_adv).squeeze()
        loss = -val if direction == 'minimize' else val
        # Differentiate w.r.t. delta only: unlike loss.backward(), this does not
        # compute or store .grad on the (possibly huge) model parameters.
        (grad,) = torch.autograd.grad(loss, delta)
        with torch.no_grad():
            delta += alpha * grad.sign()
            delta.clamp_(-epsilon, epsilon)
    with torch.no_grad():
        pert_img = (image + delta).clamp(lo, hi)
        perturbed_resp = model(pert_img).squeeze().item()
        original_resp = model(image).squeeze().item()
    return pert_img.detach(), perturbed_resp, original_resp



def fgsm_params(epsilon):
    """Return ``(eps, alpha, iterations, attack_fn)`` for a single-step FGSM attack.

    ``epsilon`` is given in 8-bit pixel units (e.g. 4 means 4/255). It is turned
    into a per-channel budget in normalized input space by dividing by ``STD``.
    The step size is the full budget and exactly one iteration is run.
    """
    pixel_budget = torch.tensor([epsilon / 255] * 3).view(1, 3, 1, 1).to(device)
    channel_std = torch.tensor(STD).view(1, 3, 1, 1).to(device)
    eps = pixel_budget / channel_std
    return eps, eps, 1, fgsm_single_direction


def l2_attack_single_direction(model, image, device, epsilon, alpha, iterations, voxel_idx, direction,
                               clamp_min=None, clamp_max=None, std=None):
    """
    Pixel-space L2-PGD:
    - epsilon, alpha are scalars in pixel units ([0,1])
    - projection uses ||STD ⊙ delta||_2 ≤ epsilon

    The model consumes normalized inputs x_norm = (x_pix - mean) / std, so a
    normalized-space perturbation delta corresponds to std * delta in pixels.

    Args:
        model: module mapping an image batch to one response; ``out.squeeze()``
            must be 0-d so it can be differentiated.
        image: normalized input batch with channels on dim 1.
        device: device the image is moved to.
        epsilon: pixel-space L2 radius.
        alpha: pixel-space L2 length of each step.
        iterations: number of projected gradient steps.
        voxel_idx: unused; kept so every attack function shares one signature.
        direction: ``'minimize'`` lowers the response; any other value raises it.
        clamp_min, clamp_max: valid normalized input range; default to the
            module-level ``CLAMP_MIN`` / ``CLAMP_MAX``.
        std: 3-element per-channel normalization std; defaults to ``STD``.

    Returns:
        ``(perturbed_image, perturbed_response, original_response)``. The image is
        detached and both responses are Python floats.
    """
    lo = CLAMP_MIN if clamp_min is None else clamp_min
    hi = CLAMP_MAX if clamp_max is None else clamp_max
    image = image.to(device)
    delta = torch.zeros_like(image, device=device)
    std_t = torch.as_tensor(STD if std is None else std, device=device, dtype=image.dtype).view(1, 3, 1, 1)

    for _ in range(iterations):
        delta.requires_grad_(True)
        val = model((image + delta).clamp(lo, hi)).squeeze()
        loss = -val if direction == 'minimize' else val
        # Gradient w.r.t. delta only, so no .grad is written to model parameters.
        (grad_norm,) = torch.autograd.grad(loss, delta)   # dL/dx_norm

        with torch.no_grad():
            # Chain rule: x_norm = (x_pix - mean) / std  =>  dL/dx_pix = dL/dx_norm / std.
            grad_pix = grad_norm / std_t
            g_len = grad_pix.flatten(1).norm(p=2, dim=1).clamp_min(1e-8).view(-1, 1, 1, 1)
            # Step of L2 length alpha along the pixel-space gradient, mapped back to normalized space.
            delta += (alpha * grad_pix / g_len) / std_t

            # Project onto the pixel-space L2 ball of radius epsilon.
            d_len = (delta * std_t).flatten(1).norm(p=2, dim=1).clamp_min(1e-8).view(-1, 1, 1, 1)
            delta *= (epsilon / d_len).clamp(max=1.0)

    with torch.no_grad():
        pert_img = (image + delta).clamp(lo, hi)
        perturbed_resp = model(pert_img).squeeze().item()
        original_resp = model(image).squeeze().item()
    return pert_img.detach(), perturbed_resp, original_resp


def l2_params(epsilon):
    """Return ``(eps, alpha, iterations, attack_fn)`` for a single-step pixel-space L2 attack.

    ``epsilon`` is the L2 radius on the [0, 1] pixel scale. The step length
    equals that radius and exactly one iteration is run.
    """
    radius = torch.tensor(float(epsilon), device=device)
    step_length = torch.tensor(float(epsilon), device=device)
    n_steps = 1
    return radius, step_length, n_steps, l2_attack_single_direction


def get_shuffled_control(image, delta_norm, clamp_min, clamp_max, device):
    """
    Build a control image by randomly permuting an applied perturbation.

    The perturbation is converted to pixel space (multiplied by ``STD``). All of
    its elements (channels included) are shuffled independently for each batch
    item, converted back to normalized space, added to ``image`` and clamped.
    image, delta_norm: normalized tensors [B,3,H,W]
    """
    std = torch.tensor(STD, device=device, dtype=image.dtype).view(1,3,1,1)
    pix = delta_norm * std
    shuffled_rows = []
    for sample in pix:
        flat = sample.reshape(-1)
        perm = torch.randperm(flat.numel(), device=device)
        shuffled_rows.append(flat[perm].view_as(sample))
    shuffled = torch.stack(shuffled_rows) if shuffled_rows else torch.empty_like(pix)
    return (image + shuffled / std).clamp(clamp_min, clamp_max)


fgsm_epsilons = [1, 2, 3, 4, 5, 6, 7, 8]
l2_epsilons = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

RESULT_COLUMNS = [
    'image_idx', 'voxel_idx', 'attack_type', 'label', 'old_response',
    'perturbed_response', 'direction', 'control_response', 'epsilon',
]


def _load_layer_activations(act_dir):
    """Concatenate the ``batch*`` .npy files in *act_dir* (natural sort order), one flattened row per stimulus."""
    batch_files = natsorted([f for f in os.listdir(act_dir) if f.startswith('batch')])
    activations = np.concatenate([np.load(os.path.join(act_dir, f)) for f in batch_files], axis=0)
    return activations.reshape(activations.shape[0], -1)


def _fit_ridge(acts_train, y_train, acts_test, y_test, subject, region):
    """Grid-search a Ridge penalty (5-fold CV, Pearson r), save the alpha and test r, return the best model."""
    grid = GridSearchCV(Ridge(), {'alpha': alphas}, scoring=scorer, cv=5)
    grid.fit(acts_train, y_train)
    ridgemodel = grid.best_estimator_
    r, _ = pearsonr(y_test.squeeze(), ridgemodel.predict(acts_test).squeeze())
    np.save(os.path.join(results_dir, f'{MODEL_NAME}_{subject}_{region}_best_alpha.npy'), grid.best_params_['alpha'])
    np.savez(os.path.join(results_dir, 'voxel_corrs', f'{MODEL_NAME}_{subject}_{region}_voxel_corr.npz'),
             voxel_corr=np.array([r]))
    return ridgemodel


def _attack_rows(model, image, image_idx, voxel_idx, label, attack_type, params_fn, epsilons):
    """Attack *image* in both directions at every epsilon; return one result row per attack.

    For each epsilon, the 'maximize' attack runs first, then 'minimize'. Each is
    paired with a shuffled-perturbation control. Both rows report the unperturbed
    response returned by the 'maximize' call.
    """
    rows = []
    for epsilon in epsilons:
        eps, alpha, its, attack_fn = params_fn(epsilon)
        orig_resp = None
        for direction in ('maximize', 'minimize'):
            adv_img, adv_resp, base_resp = attack_fn(
                model, image, device, eps, alpha, its, voxel_idx, direction=direction
            )
            if orig_resp is None:
                orig_resp = base_resp
            control_img = get_shuffled_control(image, adv_img - image, CLAMP_MIN, CLAMP_MAX, device)
            with torch.no_grad():  # evaluation only; no graph needed
                control_resp = model(control_img).squeeze().item()
            rows.append((image_idx, voxel_idx, attack_type, label, float(orig_resp),
                         float(adv_resp), direction, float(control_resp), epsilon))
    return rows


for subject in subjects:
    torch.cuda.empty_cache()
    gc.collect()
    for region in regions:
        csv_path = os.path.join(results_folder, f'{MODEL_NAME}_{subject}_{region}_attack_results.csv')
        # Check before loading any activations so finished pairs are skipped cheaply.
        if os.path.exists(csv_path):
            print(f"Skipping {subject} {region} — results already exist.")
            continue
        print(f'=== Running {subject} {region} ===')

        layer_name = best_layers[MODEL_NAME][subject, region]
        act_dir = f'CAMERA_READY/attacks/activations/{MODEL_NAME}/{layer_name}/'
        if MODEL_NAME != 'dreamsim_vitb32' and layer_name.startswith('model.'):
            # Strip only the leading prefix; split('model.')[1] would truncate
            # names that contain 'model.' more than once.
            layer_name = layer_name[len('model.'):]

        acts_train, acts_test, brain_train, brain_test = split_data(
            _load_layer_activations(act_dir), brain_data.numpy()
        )

        # Regression target: mean response of the top voxels for this subject/ROI.
        top50_voxels = subject_region_to_top50_global[(subject, region)][:n_voxels]
        brain_train_sub = brain_train[:, top50_voxels].mean(axis=1, keepdims=True)
        brain_test_sub = brain_test[:, top50_voxels].mean(axis=1, keepdims=True)
        print(f'Brain data shape: {brain_train_sub.shape}, {brain_test_sub.shape}')

        ridgemodel = _fit_ridge(acts_train, brain_train_sub, acts_test, brain_test_sub, subject, region)
        ridgemodel.coef_ = ridgemodel.coef_.reshape(1, -1)  # Ensure the shape is correct for the brain model
        brainmodel = build_brainmodel_from_linear_model(
            ridgemodel, ridgemodel.coef_.shape[1], brain_train_sub.shape[1], device
        )
        encoderbm = EncoderBM(encoder, brainmodel, layer_name).to(device)
        encoderbm.eval()

        voxel_idx = 0
        results = []
        for image_count, image_idx in enumerate(selected_image_indices):
            image = images_test_tensor[image_idx].unsqueeze(0)
            label = brain_test_sub[image_idx, voxel_idx]
            results += _attack_rows(encoderbm, image, image_idx, voxel_idx, label,
                                    'fgsm', fgsm_params, fgsm_epsilons)
            results += _attack_rows(encoderbm, image, image_idx, voxel_idx, label,
                                    'l2', l2_params, l2_epsilons)
            print(f"Done {image_count+1}/{n_images} images.")

        df_results = pd.DataFrame(results, columns=RESULT_COLUMNS)
        df_results.to_csv(csv_path, index=False)
        torch.cuda.empty_cache()
        gc.collect()
        print(f'Done. Results saved for {subject}, {region}.')

print('All done.')