# %%
# Saliency analysis (Chefer attention-gradients, Grad-CAM, RISE) for a brain
# (fMRI) encoder attached to a CLIP model.
import os
# Restrict this process to GPU 0. CUDA reads this variable when it initializes,
# so it must be set before the first CUDA call; doing it before importing torch
# is the safe ordering.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import yaml
import argparse
import json
import numpy as np
import pandas as pd
import scipy.io as sio
import random
from copy import deepcopy
from tqdm import tqdm, trange
import matplotlib.pyplot as plt
from PIL import Image
import cortex

import torch
import utils
import nibabel as nib
import h5py
from scipy import stats

# tf32 data type is faster than standard float32 (lower matmul precision on GPUs that support it)
torch.backends.cuda.matmul.allow_tf32 = True
# Prefer reproducible cuDNN kernels over autotuned ones.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.benchmark = True # fixes Conv3D if used
# torch.autograd.set_detect_anomaly(True)
# %%
# Command-line arguments used when running interactively; the next cell parses
# these instead of sys.argv.
# if utils.is_interactive():
if True:
    _jupyter_flags = [
        "--name ddp2-0.01-qwen25vl-s1-CLIP-ViT-H-14-mask=brain-blocks=12-patch=conv_14-pt_dp=0-at_dp=0-bo_dp=0-txt_scl=1-m_pct=0-bs=128-lr=0.0003-bf=4096-norm_nii",
        "--checkpoints_dir ../checkpoints/",
        "--seed 42",
        "--batch_size 32",
        "--subj_list 1",
        "--nsddir /opt/data/private/dataset/nsd",
        "--space MNI_2mm",
        "--func betas_fithrf_GLMdenoise_RR",
        "--clip_model CLIP-ViT-H-14",
        "--norm_nii",
        "--num_blocks 12",
        "--patch_size 14",
        "--patch_type conv",
        "--patch_drop 0",
        "--attn_drop 0",
        "--block_drop 0",
        "--nii_mask brain",
        "--subj 1",
        "--tag last",
    ]
    # Each flag is followed by one space, so the echoed line ends with a space.
    jupyter_args = "".join(flag + " " for flag in _jupyter_flags)
    del _jupyter_flags
    print(jupyter_args)
    jupyter_args = jupyter_args.split()
# %%
parser = argparse.ArgumentParser(description="test model")

# (flag, add_argument keyword arguments), registered in this order.
_ARG_SPECS = (
    ("--name", {"type": str}),
    ("--checkpoints_dir", {"type": str, "default": "checkpoints/"}),
    ("--seed", {"type": int, "default": 42}),
    ("--batch_size", {"type": int}),
    # data
    ("--subj_list", {"nargs": "+", "type": int}),
    ("--nsddir", {"type": str}),
    ("--space", {"type": str}),
    ("--func", {"type": str}),
    ("--clip_model", {"type": str}),
    ("--norm_nii", {"action": "store_true"}),
    # model
    ("--num_blocks", {"type": int}),
    ("--patch_size", {"type": int}),
    ("--patch_type", {"type": str, "default": "conv"}),
    ("--patch_drop", {"type": float, "default": 0.0}),
    ("--attn_drop", {"type": float, "default": 0.0}),
    ("--block_drop", {"type": float, "default": 0.0}),
    ("--nii_mask", {"type": str, "default": "brain"}),
    # evaluation target
    ("--subj", {"type": int}),
    ("--tag", {"type": str}),
)
for _flag, _kwargs in _ARG_SPECS:
    parser.add_argument(_flag, **_kwargs)
del _flag, _kwargs

# True: parse the interactive jupyter_args; False: parse sys.argv
# (the commented-out alternative was utils.is_interactive()).
_USE_JUPYTER_ARGS = True
args = parser.parse_args(jupyter_args if _USE_JUPYTER_ARGS else None)

utils.seed_everything(args.seed)
# %%
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", device)
# Single-process evaluation.
num_devices = 1
# %%
from typing import Any, Callable, Optional, Union
from transformers import CLIPConfig, CLIPModel, CLIPProcessor, CLIPImageProcessor
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
class BrainCLIPModel(CLIPModel):
    """``CLIPModel`` extended with a third (brain / fMRI) tower.

    ``brain_model`` and ``brain_projection`` start as ``None`` and must be set
    by the caller before :meth:`get_brain_features` is used. In this script
    they are set to ``clip_fmri.vision_model`` and ``clip_fmri.visual_projection``.
    """

    def __init__(self, config: CLIPConfig):
        super().__init__(config)
        self.brain_model = None
        self.brain_projection = None

    def get_brain_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.FloatTensor:
        """Encode brain volumes and project the pooled output into the joint space.

        ``output_attentions`` / ``output_hidden_states`` fall back to the CLIP
        model's config when left as ``None``. Returns the projected,
        un-normalized features.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        encoded = self.brain_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        return self.brain_projection(encoded.pooler_output)

# Local snapshots of the supported CLIP checkpoints, keyed by --clip_model.
clip_model = {
    'CLIP-ViT-H-14': "/opt/data/private/huggingface/models--laion--CLIP-ViT-H-14-laion2B-s32B-b79K",
    'clip-vit-l-14': "/opt/data/private/huggingface/models--openai--clip-vit-large-patch14",
    "CLIP-ViT-bigG-14": "/opt/data/private/huggingface/models--laion--CLIP-ViT-bigG-14-laion2B-39B-b160k",
}

# Unknown --clip_model names silently fall back to the ViT-H-14 checkpoint.
clip_model_path = clip_model.get(args.clip_model, clip_model['CLIP-ViT-H-14'])
print("Loading CLIP model from ", clip_model_path)
processor = CLIPProcessor.from_pretrained(clip_model_path)
model = BrainCLIPModel.from_pretrained(clip_model_path).to(device)
# %%
# Template for one trial's beta file, filled with (space, func, subj, session, trial).
nii_path = os.path.join(args.nsddir, 'nsddata_betas', 'ppdata_split_pth', '{:s}', '{:s}', 'subj{:02d}_betas_session{:02d}', '{:03d}.pth')

# NSD experiment design and stimulus metadata.
nsd_expdir = os.path.join(args.nsddir, 'nsddata', 'experiments', 'nsd')
nsd_info = sio.loadmat(os.path.join(nsd_expdir, 'nsd_expdesign.mat'))
stim_info = pd.read_csv(os.path.join(nsd_expdir, 'nsd_stim_info_merged.csv'), index_col=0)
cocoId = stim_info['cocoId']
# Row s-1 presumably holds the 0-based stimulus index of every trial of subject s
# (the MATLAB indices are 1-based, hence the -1s) -- TODO confirm.
stim_sort = (nsd_info['subjectim'][:, nsd_info['masterordering'] - 1] - 1).squeeze()
# Per-subject session cut-off; trials at or beyond bound*750 are dropped later.
up_sess_bounds = [40, 40, 32, 30, 40, 32, 40, 30]

# Captions, looked up by str(COCO id).
annotation_dir = os.path.join(args.nsddir, 'nsddata_stimuli/stimuli/nsd/annotations')
with open(os.path.join(annotation_dir, 'new_captions.json'), 'r') as f:
    captions = json.load(f)

# ground truths
all_images = torch.load("../../Brain507/tools/all_images.pt")
# Stimulus images; the HDF5 dataset is kept open because get_batch reads from it.
stimulus = h5py.File(os.path.join(args.nsddir, 'nsddata_stimuli', 'stimuli', 'nsd', 'coco_images_224_float16.hdf5'), 'r')['images']
# %%
# Tokenized captions: load the cached file when possible, otherwise tokenize
# every caption now (padded/truncated to CLIP's 77-token context).
tokenized_captions = None
try:
    tokenized_captions = torch.load('../tools/tokenized_newcaptions.pt')
except Exception as err:
    # Best-effort cache: any load failure falls back to tokenizing, but
    # KeyboardInterrupt/SystemExit are no longer swallowed like with a bare except.
    print(f"Could not load cached tokenized captions ({err!r}); tokenizing from scratch.")
    tokenized_captions = {
        k: processor.tokenizer(v, padding="max_length", truncation=True, max_length=77, return_tensors="pt")
        for k, v in tqdm(captions.items())
    }
    # torch.save(tokenized_captions, '../tools/tokenized_newcaptions.pt')
# %%
# Test split: for every requested subject, one entry per NSD shared image with
# its trial beta files (trials at or beyond up_sess_bounds[subj-1]*750 are
# left out) and its caption. Optionally loads per-subject mean/std volumes.
test_dict = {}
if args.norm_nii:
    norm_dict = {}
# 0-based shared-image indices, built once instead of scanning the array for
# every stimulus (same membership result as `x in array`).
shared_stims = set((nsd_info['sharedix'] - 1).ravel().tolist())
for subj in args.subj_list:
    test_list = []
    subj_trials = stim_sort[subj - 1]
    max_trial = up_sess_bounds[subj - 1] * 750  # 750 trials per session
    # pd.unique keeps first-appearance order, so entries follow presentation order.
    for stim in pd.Series(subj_trials).unique():
        if stim not in shared_stims:
            continue
        nii_loc = np.where(subj_trials == stim)[0]
        nii_loc = nii_loc[nii_loc < max_trial]
        if len(nii_loc) == 0:
            continue
        test_list.append({
            'subj': subj,
            'stim_idx': stim,
            # session = trial // 750 + 1 (1-based), trial-in-session = trial % 750
            'nii': [nii_path.format(args.space, args.func, subj, (nloc // 750) + 1, nloc % 750) for nloc in nii_loc],
            'text': captions[str(cocoId[stim])],
            'cocoid': cocoId[stim],
        })
    print(f"subj {subj} test set size: {len(test_list)}")
    test_dict['subj{:02d}'.format(subj)] = test_list
    if args.norm_nii:
        norm_dict.update({'subj{:02d}_mean'.format(subj): torch.load(f'{args.nsddir}/nsddata_betas/mean_std/subj{subj}_{args.space}_{args.func}_mean.pth')})
        norm_dict.update({'subj{:02d}_std'.format(subj): torch.load(f'{args.nsddir}/nsddata_betas/mean_std/subj{subj}_{args.space}_{args.func}_std.pth')})
# %%
def cal_padding_list(patch_size, base_size=(91, 109, 91)):
    """Return the ``F.pad`` padding that makes a volume divisible by ``patch_size``.

    Args:
        patch_size: edge length of the cubic patch.
        base_size: un-padded volume shape in (dim0, dim1, dim2) order. The
            default (91, 109, 91) is the volume shape this script pads.

    Returns:
        ``[d2_lo, d2_hi, d1_lo, d1_hi, d0_lo, d0_hi]`` -- last axis first, the
        order ``torch.nn.functional.pad`` expects. When an axis needs an odd
        amount of padding the extra voxel goes on the high side.
    """
    padding_list = []
    for dim in reversed(base_size):
        pad_size = (patch_size - dim % patch_size) % patch_size
        low = pad_size // 2
        padding_list.extend([low, pad_size - low])
    return padding_list
padding_list = cal_padding_list(args.patch_size)

# Mask volume that decides which patch tokens the 3-D ViT keeps.
if args.nii_mask == 'brain':
    _mask_file = os.path.join("{:s}", "fsl_tmp", "MNI152_T1_2mm_brain_mask.nii.gz").format(args.nsddir)
elif args.nii_mask == 'roi':
    _mask_file = "{:s}/nsddata/ppdata/all/{:s}/roi/nsdgeneral.nii.gz".format(args.nsddir, args.space)
else:
    _mask_file = None

if _mask_file is None:
    token_ids = None
else:
    mask = torch.from_numpy(nib.load(_mask_file).get_fdata().astype(np.uint8))
    mask = torch.nn.functional.pad(mask, pad=padding_list, mode='constant', value=0)
    # Count mask voxels inside each non-overlapping patch_size^3 cube; cubes with a
    # non-zero count become tokens.
    # NOTE(review): .byte() wraps counts above 255, so a cube whose count is an exact
    # multiple of 256 is dropped. Left unchanged because the trained checkpoint's
    # token layout was presumably built with this same expression -- confirm before fixing.
    _cube = torch.ones((1, 1, args.patch_size, args.patch_size, args.patch_size)).float()
    token_mask = torch.nn.functional.conv3d(mask[None, None].float(), _cube, stride=(args.patch_size,) * 3).byte().flatten(2).squeeze()
    token_ids = torch.where(token_mask != 0)[0].tolist()
# %%
# Short aliases for the brain-encoder hyper-parameters.
PATCH_SIZE, NUM_BLOCKS, PATCH_TYPE = args.patch_size, args.num_blocks, args.patch_type
ATTN_DROP, BLOCK_DROP, PATCH_DROP = args.attn_drop, args.block_drop, args.patch_drop
# %%
from transformers import CLIPVisionConfig, CLIPVisionModelWithProjection
from vit3d import ViTEmbeddings3D, ViTConfig, calc_image_size

# Brain encoder config: a CLIP vision transformer sized like the loaded CLIP
# vision tower, with NUM_BLOCKS layers and this script's dropout rates.
_vision_cfg = model.config.vision_config
clip_cfg = CLIPVisionConfig(
    model_type="clip_vision_model",
    num_hidden_layers=NUM_BLOCKS,
    hidden_size=_vision_cfg.hidden_size,
    intermediate_size=_vision_cfg.intermediate_size,
    num_attention_heads=_vision_cfg.num_attention_heads,
    projection_dim=_vision_cfg.projection_dim,
    hidden_act=_vision_cfg.hidden_act,
    layer_norm_eps=_vision_cfg.layer_norm_eps,
    attention_dropout=ATTN_DROP,
    dropout=BLOCK_DROP,
    initializer_factor=1.0,
    initializer_range=0.02,
    # Only affects the 2-D patch embedding, which this script replaces with ViTEmbeddings3D.
    patch_size=14,
)

# Separate config object carrying the fields read by ViTEmbeddings3D.
model_args = deepcopy(clip_cfg)
model_args.image_size = calc_image_size(PATCH_SIZE)
model_args.patch_size = (PATCH_SIZE,) * 3
model_args.hidden_dropout_prob = PATCH_DROP
model_args.num_channels = 1
model_args.token_ids = token_ids
model_args.patch_type = PATCH_TYPE
# %%
# Brain tower: a CLIP vision encoder whose patch embedding is swapped for a 3-D one.
clip_fmri = CLIPVisionModelWithProjection(clip_cfg)
clip_fmri.vision_model.embeddings = ViTEmbeddings3D(model_args)
utils.count_params(clip_fmri)
# %%
# ================= load ckpt ================ #
checkpoint_dir = os.path.join(args.checkpoints_dir, args.name)
ckpt_file = checkpoint_dir + f'/{args.tag}.pth'
print(f"\n---loading {checkpoint_dir}/{args.tag}.pth ckpt---\n")

# Load weights on CPU, then move the whole tower to `device`.
state_dict = torch.load(ckpt_file, map_location='cpu')
clip_fmri.load_state_dict(state_dict, strict=True)
json_save_path = os.path.join(checkpoint_dir, f'{args.tag}.json')
with open(json_save_path, 'r') as json_file:
    state = json.load(json_file)
epoch_idx = state['epoch_idx']
print("Epoch", epoch_idx)
clip_fmri = clip_fmri.to(device)
# %%
# Plug the trained brain tower into the CLIP model used by the saliency methods.
model.brain_model = clip_fmri.vision_model
model.brain_projection = clip_fmri.visual_projection
# %%
def cal_norm_nii(nii, mean, std):
    """Z-score ``nii`` with precomputed ``mean`` / ``std`` (broadcast as usual).

    The statistics are moved to ``nii``'s device first, and ``1e-5`` is added
    to ``std`` to avoid division by zero.
    """
    centered = nii - mean.to(nii.device)
    return centered / (std.to(nii.device) + 1e-5)
# %%
def get_batch(idx):
    """Assemble the model inputs for test item ``idx`` of subject ``args.subj``.

    Returns a dict (keys in this order) with everything moved to ``device``:
      - ``'text'``: tokenizer ``input_ids`` / ``attention_mask`` of the caption (unpadded),
      - ``'vision'``: the stimulus after ``processor.image_processor``,
      - ``'brain'``: the repetition-averaged beta volume, scaled by 1/300,
        optionally z-scored, padded by ``padding_list``, with two leading
        singleton axes added.
    """
    subj_key = f'subj{args.subj:02d}'
    data = test_dict[subj_key][idx]
    text_ids = processor.tokenizer(captions[str(data['cocoid'])], return_tensors="pt")

    nii = [torch.load(n) for n in data['nii']]
    # Always stack three volumes: with one repetition it is repeated three
    # times; with two, the first one is reused (so it weighs double in the mean).
    if len(nii) == 1:
        nii = torch.stack(nii).repeat(3, 1, 1, 1)
    elif len(nii) == 2:
        nii = torch.stack([nii[0], nii[1], nii[0]])
    else:
        nii = torch.stack(nii)

    stim = torch.from_numpy(stimulus[data['stim_idx']])
    if processor is not None:
        stim = processor.image_processor(stim, return_tensors="pt", do_rescale=False).pixel_values

    # NOTE(review): fixed 1/300 scaling of the stored betas -- confirm it matches training.
    nii = nii.float() / 300.
    if args.norm_nii:
        # Use the statistics of the subject being evaluated; the previous code
        # read the leftover `subj` from the data-loading loop (its last subject).
        nii = cal_norm_nii(nii, norm_dict[f'{subj_key}_mean'].unsqueeze(0), norm_dict[f'{subj_key}_std'].unsqueeze(0))
    nii = torch.nn.functional.pad(nii, pad=padding_list, mode='constant', value=0)

    return {
        'text': {'input_ids': text_ids['input_ids'].to(device), 'attention_mask': text_ids['attention_mask'].to(device)},
        'vision': stim.to(device),
        'brain': nii.mean(0, keepdim=True).unsqueeze(0).to(device),
    }
# %%
def get_vector_norm(tensor: torch.Tensor) -> torch.Tensor:
    """Return the L2 norm of ``tensor`` along its last axis (axis kept, size 1)."""
    return torch.sum(tensor.square(), dim=-1, keepdim=True) ** 0.5
# %%
from M2IB.scripts.utils import normalize 
# ==== Base parameters ====
# Padded volume shape and the token grid it is cut into. For patch size 14 the
# grid is presumably (7, 8, 7), i.e. 392 tokens -- depends on calc_image_size.
volume_shape = calc_image_size(PATCH_SIZE)
patch_size = (PATCH_SIZE,) * 3
grid_shape = tuple(extent // step for extent, step in zip(volume_shape, patch_size))
num_tokens = np.prod(grid_shape)

def process_cams(cam):
    """Scale ``cam`` so its largest absolute value along the last axis is 1.

    Signs are preserved; a slice that is entirely zero yields NaNs (0/0).
    """
    peak = cam.abs().amax(dim=-1, keepdim=True)
    return cam / peak
def heatmap(saliency, target_model):
    """Turn per-token saliency into a normalized map for ``target_model``.

    - ``'text'``: returns ``normalize(saliency)``.
    - ``'vision'``: reshapes the patch saliencies to a square grid, upsamples
      bilinearly to 224x224 and returns the normalized CPU map.
    - ``'brain'``: scatters the saliencies of the kept tokens (global
      ``token_ids``) into the full token grid, upsamples trilinearly to the
      padded volume, multiplies by ``mask`` and normalizes. Returns
      ``(cropped, padded)`` where ``cropped`` has ``padding_list`` removed and
      its axes reversed.
    Any other ``target_model`` returns ``None``.
    """
    if target_model == 'text':
        return normalize(saliency)

    if target_model == 'vision':
        side = int(saliency.numel() ** 0.5)
        saliency = saliency.reshape(1, 1, side, side)
        saliency = torch.nn.functional.interpolate(saliency, size=224, mode='bilinear')
        saliency = saliency.squeeze().cpu().detach()
        return normalize(saliency)

    if target_model == 'brain':
        # Local index tensor; the module-level `token_ids` is no longer rebound
        # (and re-wrapped) on every call.
        kept_ids = torch.as_tensor(token_ids)
        saliency_grid = torch.full((num_tokens,), float(0))
        saliency_grid[kept_ids] = saliency
        saliency_volume = saliency_grid.reshape(grid_shape)
        saliency_volume = saliency_volume.unsqueeze(0).unsqueeze(0)
        saliency_volume = torch.nn.functional.interpolate(
            saliency_volume, size=volume_shape, mode="trilinear", align_corners=False
        )[0, 0]
        saliency_volume *= mask
        saliency_volume = normalize(saliency_volume)
        # padding_list is in F.pad order (last axis first); pair it and reverse
        # so pads[k] is the (low, high) padding of axis k.
        pads = [(padding_list[j], padding_list[j + 1]) for j in range(0, len(padding_list), 2)][::-1]
        # Explicit stop indices: a zero high-side pad must keep the whole axis,
        # whereas a slice ending at -0 would be empty.
        crop = tuple(slice(lo, extent - hi) for (lo, hi), extent in zip(pads, saliency_volume.shape))
        saliency = saliency_volume[crop].permute(2, 1, 0)
        return saliency, saliency_volume

    return None
# %%

def chefer(model, batch, target_model='brain'):
    """Attention x gradient relevance of the last encoder layer for one tower.

    Every modality in ``batch`` ('text', 'vision', 'brain') is embedded and
    L2-normalized; only the ``target_model`` tower returns attention maps. The
    explained scalar is the logit-scaled cosine similarity:
      - 'text': text vs. brain embedding,
      - 'vision': vision vs. brain embedding,
      - 'brain': mean of brain vs. text and brain vs. vision.
    Relevance is grad * attention of the last layer, clamped at 0 and averaged
    over heads. It is read from the last position's row for text (EOS for the
    unpadded captions produced by get_batch) and from the first (CLS) row for
    vision/brain, dropping the CLS column. For vision/brain the gradient's
    absolute value is used.

    Returns:
        1-D CPU tensor of per-token relevance.

    Raises:
        ValueError: if ``target_model`` is not 'text', 'vision' or 'brain'.
    """
    if target_model not in ('text', 'vision', 'brain'):
        raise ValueError(f"target_model must be 'text', 'vision' or 'brain', got {target_model!r}")
    model.eval()
    out = {}
    for k in batch.keys():
        if k == 'text':
            text_outputs = model.text_model(**batch[k], output_attentions=target_model == 'text')
            text_embeds = model.text_projection(text_outputs.pooler_output)
            text_embeds = text_embeds / get_vector_norm(text_embeds)
            out.update({'text': {'outputs': text_outputs, 'embeds': text_embeds}})
        elif k == 'vision':
            vision_outputs = model.vision_model(batch[k], output_attentions=target_model == 'vision')
            vision_embeds = model.visual_projection(vision_outputs.pooler_output)
            vision_embeds = vision_embeds / get_vector_norm(vision_embeds)
            out.update({'vision': {'outputs': vision_outputs, 'embeds': vision_embeds}})
        elif k == 'brain':
            brain_outputs = model.brain_model(batch[k], output_attentions=target_model == 'brain')
            brain_embeds = model.brain_projection(brain_outputs.pooler_output)
            brain_embeds = brain_embeds / get_vector_norm(brain_embeds)
            out.update({'brain': {'outputs': brain_outputs, 'embeds': brain_embeds}})

    model.zero_grad()
    # logit_scale is a model parameter, so it is already on the model's device;
    # no hard-coded .cuda() (which broke the CPU path).
    logit_scale = model.logit_scale.exp()
    if target_model == 'text':
        logits_per_text = torch.matmul(out['text']['embeds'], out['brain']['embeds'].t()) * logit_scale
        logit = logits_per_text[0, 0]
        A = out['text']['outputs'].attentions[-1]
        grad = torch.autograd.grad(logit, A)[0].detach()
        grad = (grad[0, :, -1] * A.detach()[0, :, -1]).clamp(min=0).mean(dim=0)
    else:
        if target_model == 'vision':
            logits = torch.matmul(out['vision']['embeds'], out['brain']['embeds'].t()) * logit_scale
        else:
            logits_t = torch.matmul(out['brain']['embeds'], out['text']['embeds'].t()) * logit_scale
            logits_v = torch.matmul(out['brain']['embeds'], out['vision']['embeds'].t()) * logit_scale
            logits = (logits_t + logits_v) / 2
        logit = logits[0, 0]
        A = out[target_model]['outputs'].attentions[-1]
        grad = torch.autograd.grad(logit, A)[0].detach().abs()
        grad = (grad[0, :, 0, 1:] * A.detach()[0, :, 0, 1:]).clamp(min=0).mean(dim=0)

    return grad.cpu().detach()
# %%
# Chefer saliency for every test item of args.subj. Each entry stores the
# saliency map, the modality's features and its raw input (all on CPU).
saliency_maps = {'brain':[], 'vision':[], 'text':[]}
for i in trange(len(test_dict[f'subj{args.subj:02d}'])):
    batch = get_batch(i)
    target_model = 'brain'  # 'brain', 'vision', 'text'
    saliency_map_i = chefer(model, batch, target_model)
    # heatmap returns (cropped, padded) for the brain; the padded volume is kept.
    _, saliency_map_i =heatmap(saliency_map_i, target_model)
    brain_feat = model.get_brain_features(batch['brain']).cpu().detach()
    saliency_maps['brain'].append({'saliency': saliency_map_i[None,None], 'feat': brain_feat, 'data': batch['brain'].cpu().detach()})

    target_model = 'vision'  # 'brain', 'vision', 'text'
    saliency_map_i = chefer(model, batch, target_model)
    saliency_map_i = heatmap(saliency_map_i, target_model)
    vision_feat = model.get_image_features(batch['vision']).cpu().detach()
    saliency_maps['vision'].append({'saliency': saliency_map_i[None,None], 'feat': vision_feat, 'data': batch['vision'].cpu().detach()})

    target_model = 'text'  # 'brain', 'vision', 'text'
    saliency_map_i = chefer(model, batch, target_model)
    saliency_map_i = heatmap(saliency_map_i, target_model)
    text_feat = model.get_text_features(**batch['text']).cpu().detach()
    # NOTE(review): text saliency gets one leading axis here but none in the Grad-CAM loop.
    saliency_maps['text'].append({'saliency': saliency_map_i[None], 'feat': text_feat, 'data': batch['text']['input_ids'].detach().cpu()})
# %%
class GradCAMBrainVLM(object):
    """Grad-CAM on one fixed encoder layer of the text, vision or brain tower.

    A forward hook stores the layer's first output and a backward hook stores
    the gradient w.r.t. it (both detached, on CPU). Both hooks are registered
    on construction; call :meth:`remove_hooks` when done.

    NOTE(review): the layer indices (text 21, vision 29, brain 9) are
    hard-coded for the models this script loads; other CLIP variants or
    NUM_BLOCKS values may need different indices.
    """

    def __init__(self, model, target_model='brain'):
        self.gradients = dict()
        self.activations = dict()

        def backward_hook(module, grad_input, grad_output):
            self.gradients['value'] = grad_output[0].detach().cpu()
            return None

        def forward_hook(module, input, output):
            self.activations['value'] = output[0].detach().cpu()
            return None

        if target_model == 'text':
            target_layer = model.text_model.encoder.layers[21]
        elif target_model == 'vision':
            target_layer = model.vision_model.encoder.layers[29]
        elif target_model == 'brain':
            target_layer = model.brain_model.encoder.layers[9]
        else:
            raise ValueError(f"target_model must be 'text', 'vision' or 'brain', got {target_model!r}")

        # Keep BOTH handles so remove_hooks() detaches both hooks; otherwise
        # every instance would leave a backward hook behind on the layer.
        self.handles = [
            target_layer.register_forward_hook(forward_hook),
            target_layer.register_backward_hook(backward_hook),
        ]
        self.model = model
        self.target_model = target_model

    def forward(self, batch):
        """Return the Grad-CAM map for one batch.

        Embeds every modality in ``batch`` (L2-normalized) and backpropagates the
        logit-scaled similarity: text vs. brain, vision vs. brain, or for the
        brain the mean of brain vs. text and brain vs. vision. The map is
        ``relu(sum over features of grad * activation)`` of the hooked layer,
        with the first (presumably CLS) position dropped for vision/brain.

        NOTE(review): ``logit.backward()`` also fills ``.grad`` of all model parameters.
        """
        out = {}
        for k in batch.keys():
            if k == 'text':
                text_embeds = self.model.get_text_features(**batch[k])
                text_embeds = text_embeds / get_vector_norm(text_embeds)
                out.update({'text': text_embeds})
            elif k == 'vision':
                vision_embeds = self.model.get_image_features(batch[k])
                vision_embeds = vision_embeds / get_vector_norm(vision_embeds)
                out.update({'vision': vision_embeds})
            elif k == 'brain':
                brain_embeds = self.model.get_brain_features(batch[k])
                brain_embeds = brain_embeds / get_vector_norm(brain_embeds)
                out.update({'brain': brain_embeds})
        self.model.zero_grad()
        # logit_scale is a model parameter and already on the model's device.
        logit_scale = self.model.logit_scale.exp()
        if self.target_model == 'text':
            logit = (torch.matmul(out['text'], out['brain'].t()) * logit_scale)[0, 0]
        elif self.target_model == 'vision':
            logit = (torch.matmul(out['vision'], out['brain'].t()) * logit_scale)[0, 0]
        else:
            logits_per_brain_t = torch.matmul(out['brain'], out['text'].t()) * logit_scale
            logits_per_brain_v = torch.matmul(out['brain'], out['vision'].t()) * logit_scale
            logit = ((logits_per_brain_t + logits_per_brain_v) / 2)[0, 0]
        logit.backward()
        gradients = self.gradients['value']
        activations = self.activations['value']
        saliency_map = (gradients * activations).sum(-1)
        saliency_map = torch.nn.functional.relu(saliency_map)
        saliency_map = saliency_map[:, 1:] if self.target_model != 'text' else saliency_map
        return saliency_map

    def __call__(self, batch):
        # forward() takes a single batch dict; the old (img, caption) signature
        # always failed with a TypeError.
        return self.forward(batch)

    def remove_hooks(self):
        """Detach the forward and backward hooks registered in ``__init__``."""
        for handle in self.handles:
            handle.remove()

def grad_cam(model, batch, target_model='brain'):
    """Run Grad-CAM once for ``target_model`` and return the CPU saliency tensor.

    The hooks are removed even if the forward/backward pass raises, so they do
    not stay attached to the model's layer.
    """
    cam = GradCAMBrainVLM(model, target_model=target_model)
    try:
        saliency = cam.forward(batch)
    finally:
        cam.remove_hooks()
    return saliency.cpu().detach()
# %%
# Grad-CAM saliency for every test item of args.subj (replaces the Chefer
# results held in `saliency_maps`). Entries have the same layout as above.
saliency_maps = {'brain':[], 'vision':[], 'text':[]}
for i in trange(len(test_dict[f'subj{args.subj:02d}'])):
    batch = get_batch(i)
    target_model = 'brain'  # 'brain', 'vision', 'text'
    saliency_map_i = grad_cam(model, batch, target_model)
    # heatmap returns (cropped, padded) for the brain; the padded volume is kept.
    _, saliency_map_i =heatmap(saliency_map_i, target_model)
    brain_feat = model.get_brain_features(batch['brain']).cpu().detach()
    saliency_maps['brain'].append({'saliency': saliency_map_i[None,None], 'feat': brain_feat, 'data': batch['brain'].cpu().detach()})

    target_model = 'vision'  # 'brain', 'vision', 'text'
    saliency_map_i = grad_cam(model, batch, target_model)
    saliency_map_i = heatmap(saliency_map_i, target_model)
    vision_feat = model.get_image_features(batch['vision']).cpu().detach()
    saliency_maps['vision'].append({'saliency': saliency_map_i[None,None], 'feat': vision_feat, 'data': batch['vision'].cpu().detach()})

    target_model = 'text'  # 'brain', 'vision', 'text'
    saliency_map_i = grad_cam(model, batch, target_model)
    saliency_map_i = heatmap(saliency_map_i, target_model)
    text_feat = model.get_text_features(**batch['text']).cpu().detach()
    # NOTE(review): no extra leading axis here, unlike the Chefer loop's text entry.
    saliency_maps['text'].append({'saliency': saliency_map_i, 'feat': text_feat, 'data': batch['text']['input_ids'].detach().cpu()})

# %%
from skimage.transform import resize
import copy
from torch import nn
class RISEV(nn.Module):
    """RISE saliency for images: random smooth 2-D masks scored against a target embedding.

    Args:
        model: object exposing ``get_image_features(pixels)``.
        input_size: (H, W) of the images the masks are generated for.
        gpu_batch: number of masked images encoded per call.
    Masks are kept on the GPU as fp16 (``.cuda()`` is required).
    """

    def __init__(self, model, input_size, gpu_batch=100):
        super(RISEV, self).__init__()
        self.model = model
        self.input_size = input_size
        self.gpu_batch = gpu_batch
        self.loss_fn = nn.CosineSimilarity(eps=1e-6)

    def generate_masks(self, N, s, p1, savepath='masks.npy'):
        """Create ``N`` masks from ``s``x``s`` binary grids (cell on with prob. ``p1``), save them, move to GPU."""
        cell_size = np.ceil(np.array(self.input_size) / s)
        up_size = (s + 1) * cell_size

        grid = np.random.rand(N, s, s) < p1
        grid = grid.astype('float32')

        self.masks = np.empty((N, *self.input_size))

        for i in tqdm(range(N), desc='Generating filters'):
            # Random shifts
            x = np.random.randint(0, cell_size[0])
            y = np.random.randint(0, cell_size[1])
            # Linear upsampling and cropping
            self.masks[i, :, :] = resize(grid[i], up_size, order=1, mode='reflect',
                                         anti_aliasing=False)[x:x + self.input_size[0], y:y + self.input_size[1]]
        self.masks = self.masks.reshape(-1, 1, *self.input_size)
        # np.save does not create missing directories (e.g. the default 'masks/').
        save_dir = os.path.dirname(savepath)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        np.save(savepath, self.masks)
        self.masks = torch.from_numpy(self.masks).half()
        self.masks = self.masks.cuda()
        self.N = N
        self.p1 = p1

    def load_masks(self, filepath, p1=0.1):
        """Load masks saved by :meth:`generate_masks`; ``p1`` is only used for normalization."""
        self.masks = np.load(filepath)
        self.masks = torch.from_numpy(self.masks).half().cuda()
        self.N = self.masks.shape[0]
        self.p1 = p1

    def forward(self, image, text_features):
        """Return the min-max normalized saliency map of shape (1, H, W) on CPU."""
        N = self.N
        _, _, H, W = image.size()
        p = []
        for i in range(0, N, self.gpu_batch):
            # Mask one chunk at a time instead of materializing all N masked images.
            stack = torch.mul(self.masks[i:min(i + self.gpu_batch, N)], image.data)
            image_features = self.model.get_image_features(stack)
            p.append(self.loss_fn(image_features, text_features).detach().cpu().unsqueeze(-1))
        p = torch.cat(p)
        # Score-weighted sum of the masks, computed where the masks live; avoids
        # copying all N masks to the host and an fp16 matmul on the CPU.
        sal = torch.matmul(p.data.transpose(0, 1).half().to(self.masks.device), self.masks.view(N, H * W)).cpu()
        sal = sal.view((1, H, W))
        sal = sal / N / self.p1
        sal = sal.mean(0, keepdim=True)
        sal = (sal - sal.min()) / (sal.max() - sal.min())
        return sal
    
def rise_v(model, image, brain_features, N=6000, s=14, p1=0.1, savepath='masks/masks_2d_{:d}.npy'):
    """RISE saliency of ``image`` (224x224) w.r.t. ``brain_features``.

    Masks are cached per grid size ``s`` at ``savepath.format(s)``; when the
    cache exists its mask count overrides ``N``.
    """
    exp = RISEV(model, (224, 224), gpu_batch=20)
    mask_file = savepath.format(s)
    if os.path.exists(mask_file):
        # Pass p1 through so cached masks use the caller's value, not load_masks' default.
        exp.load_masks(mask_file, p1=p1)
    else:
        exp.generate_masks(N, s, p1, savepath=mask_file)
    return exp(image, brain_features)



class RISET(nn.Module):
    """RISE-style saliency over caption tokens.

    For each random mask over (tokens x embedding dims), the token-embedding
    rows of the caption are multiplied by the mask, the caption is re-encoded
    and the cosine similarity to ``brain_features`` weights the mask.

    The model's token-embedding rows are edited in place during ``forward`` and
    restored afterwards (also on error). This replaces two full deep copies of
    the model.
    """

    def __init__(self, model):
        super(RISET, self).__init__()
        self.model = model
        self.loss_fn = nn.CosineSimilarity(eps=1e-6)

    def generate_masks(self, input_size, N, s, p1):
        """Return ``(masks, N, p1)`` with ``N`` masks of shape ``input_size`` on the GPU.

        A private RandomState seeded with 0 gives the same masks on every call.
        The global NumPy random state is left untouched.
        """
        cell_size = np.ceil(np.array(input_size) / s)
        up_size = (s + 1) * cell_size

        # Same stream as np.random.seed(0) + np.random.* calls, without resetting global state.
        rng = np.random.RandomState(0)
        grid = (rng.rand(N, s) < p1).astype('float32')

        masks = np.empty((N, *input_size))
        for i in range(N):
            x = rng.randint(0, cell_size[0])
            y = rng.randint(0, cell_size[1])
            masks[i, :, :] = resize(grid[i], up_size, order=1, mode='reflect',
                                    anti_aliasing=False)[x:x + input_size[0], y:y + input_size[1]]
        masks = masks.reshape(-1, 1, *input_size).squeeze()
        masks = torch.from_numpy(masks).float().cuda()
        return masks, N, p1

    def forward(self, tid, brain_features, N=200, s=8, p1=0.1):
        """Return min-max normalized per-token saliency for token ids ``tid`` (shape (1, T))."""
        masks, N, p1 = self.generate_masks((tid.shape[-1], brain_features.shape[-1]), N, s, p1)
        weight = self.model.text_model.embeddings.token_embedding.weight.data
        original_rows = weight[tid, :].clone()
        sal = list()
        try:
            for i in range(N):
                weight[tid, :] = original_rows * masks[i]
                text_features = self.model.get_text_features(tid)
                sal.append(self.loss_fn(brain_features, text_features).detach().cpu().unsqueeze(-1) * masks[i].cpu())
        finally:
            # Put the untouched embeddings back.
            weight[tid, :] = original_rows
        sal = torch.stack(sal)
        sal = sal.mean(0).sum(-1) / p1
        sal_min = sal.min()
        sal_max = sal.max()
        sal = (sal - sal_min) / (sal_max - sal_min)
        return sal



def rise_t(model, tid, brain_features):
    """RISE saliency of caption token ids ``tid`` w.r.t. ``brain_features`` (see ``RISET``)."""
    return RISET(model)(tid, brain_features)


class RISEB(nn.Module):
    """RISE saliency for brain volumes: random smooth 3-D masks scored against text and image embeddings.

    Args:
        model: object exposing ``get_brain_features(volumes)``.
        input_size: (H, W, D) of the (padded) volumes.
        gpu_batch: number of masked volumes encoded per call.
    Masks are kept on the GPU as fp16 (``.cuda()`` is required).
    """

    def __init__(self, model, input_size, gpu_batch=100):
        super(RISEB, self).__init__()
        self.model = model
        self.input_size = input_size
        self.gpu_batch = gpu_batch
        self.loss_fn = nn.CosineSimilarity(eps=1e-6)

    def generate_masks(self, N, s, p1, savepath='masks.npy'):
        """Create ``N`` masks from ``s``^3 binary grids (cell on with prob. ``p1``), save them, move to GPU."""
        cell_size = np.ceil(np.array(self.input_size) / s)
        up_size = (s + 1) * cell_size

        grid = np.random.rand(N, s, s, s) < p1
        grid = grid.astype('float16')

        self.masks = np.empty((N, *self.input_size), dtype='float16')

        for i in tqdm(range(N), desc='Generating filters'):
            # Random shifts
            x = np.random.randint(0, cell_size[0])
            y = np.random.randint(0, cell_size[1])
            z = np.random.randint(0, cell_size[2])
            # Linear upsampling and cropping
            self.masks[i, :, :, :] = resize(grid[i], up_size, order=1, mode='reflect',
                                            anti_aliasing=False)[x:x + self.input_size[0], y:y + self.input_size[1], z:z + self.input_size[2]]
        self.masks = self.masks.reshape(-1, 1, *self.input_size).astype('float16')
        # np.save does not create missing directories (e.g. the default 'masks/').
        save_dir = os.path.dirname(savepath)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        np.save(savepath, self.masks)
        self.masks = torch.from_numpy(self.masks).half()
        self.masks = self.masks.cuda()
        self.N = N
        self.p1 = p1

    def load_masks(self, filepath, p1=0.1):
        """Load masks saved by :meth:`generate_masks`; ``p1`` scales the result."""
        self.masks = np.load(filepath)
        self.masks = torch.from_numpy(self.masks).half().cuda()
        self.N = self.masks.shape[0]
        self.p1 = p1

    def forward(self, brain, text_features, image_features):
        """Return the un-normalized saliency volume of shape (1, H, W, D) on CPU.

        Each mask is scored by the mean of the masked volume's cosine
        similarity to ``text_features`` and to ``image_features``.
        """
        N = self.N
        _, _, H, W, D = brain.size()
        p = []
        for i in range(0, N, self.gpu_batch):
            # Mask one chunk at a time instead of materializing all N masked volumes.
            stack = torch.mul(self.masks[i:min(i + self.gpu_batch, N)], brain.data)
            brain_features = self.model.get_brain_features(stack)
            p.append((self.loss_fn(brain_features, text_features) / 2 + self.loss_fn(brain_features, image_features) / 2).unsqueeze(-1))
        p = torch.cat(p)
        sal = torch.matmul(p.data.transpose(0, 1).half(), self.masks.view(N, H * W * D)).cpu()
        sal = sal.view((1, H, W, D))
        sal = sal / N / self.p1
        sal = sal.mean(0, keepdim=True)
        # Left un-normalized; callers normalize after masking.
        return sal
    
def rise_b(model, brain, text_features, image_features, N=1000, s=14, p1=0.1, savepath='masks/masks_3d_{:d}.npy'):
    """Compute a RISE saliency volume for a brain input.

    Random 3-D masks are read from ``savepath.format(s)`` when that file exists. Otherwise ``N``
    masks on an ``s``-cell grid with keep probability ``p1`` are generated and saved there.

    Args:
        model: model passed to ``RISEB`` (it must expose ``get_brain_features``).
        brain: brain input tensor; its last three dims give the mask volume size.
        text_features: reference text embedding the masked brain features are scored against.
        image_features: reference image embedding the masked brain features are scored against.
        N: number of masks to generate (only used when no cached file exists).
        s: grid cells per axis; also selects the cache file name.
        p1: keep probability used for generation and for saliency normalisation.
        savepath: format string with one ``{:d}`` slot that receives ``s``.
    Returns:
        The saliency map produced by calling the ``RISEB`` explainer.
    """
    mask_path = savepath.format(s)
    # Size the masks from the argument itself rather than a notebook-level `batch` global.
    explainer = RISEB(model, brain.shape[-3:], gpu_batch=10)
    if os.path.exists(mask_path):
        # NOTE(review): the cache file name encodes only `s`, so a file generated with a different
        # N or p1 is reused as-is. The caller's p1 is forwarded so the normalisation at least
        # matches what this call asked for (the default 0.1 behaves exactly as before).
        explainer.load_masks(mask_path, p1=p1)
    else:
        mask_dir = os.path.dirname(mask_path)
        if mask_dir:
            # np.save inside generate_masks fails if the target directory is missing.
            os.makedirs(mask_dir, exist_ok=True)
        explainer.generate_masks(N, s, p1, savepath=mask_path)
    return explainer(brain, text_features, image_features)


def rise(model, batch, target_model, s=14):
    """Compute a RISE saliency map for one modality of ``batch``.

    The inputs of ``target_model`` are perturbed. Scores come from the L2-normalised embeddings
    of the other modality or modalities: 'brain' for text/vision targets, and 'text' plus
    'vision' for a brain target.

    Args:
        model: multimodal model exposing ``get_text_features``, ``get_image_features`` and
            ``get_brain_features``.
        batch: dict with 'text' (mapping with ``input_ids`` / ``attention_mask``), 'vision' and
            'brain' entries. Only the entries needed for ``target_model`` are read.
        target_model: 'text', 'vision' or 'brain'.
        s: mask grid size forwarded to the vision/brain RISE helpers.
    Returns:
        The saliency map, detached and on the CPU.
    Raises:
        ValueError: if ``target_model`` is not 'text', 'vision' or 'brain'.
    """
    # Which reference embeddings each target is scored against.
    references = {'text': ('brain',), 'vision': ('brain',), 'brain': ('text', 'vision')}
    if target_model not in references:
        # Fail fast instead of hitting an UnboundLocalError on `sal` after all the work.
        raise ValueError(f"target_model must be 'text', 'vision' or 'brain', got {target_model!r}")

    model.eval()
    with torch.no_grad(), torch.cuda.amp.autocast():
        encoders = {
            'text': lambda: model.get_text_features(**batch['text']),
            'vision': lambda: model.get_image_features(batch['vision']),
            'brain': lambda: model.get_brain_features(batch['brain']),
        }
        out = {}
        for key in references[target_model]:
            embeds = encoders[key]()
            out[key] = embeds / get_vector_norm(embeds)

        if target_model == 'text':
            text = batch['text']
            # Keep only the unpadded tokens of the first sequence, re-batched to size 1.
            input_ids = text['input_ids'][0][text['attention_mask'][0] == 1].unsqueeze(0)
            sal = rise_t(model, input_ids, out['brain'])
        elif target_model == 'vision':
            sal = rise_v(model, batch['vision'], out['brain'], s=s)
        else:
            sal = rise_b(model, batch['brain'], out['text'], out['vision'], s=s)
    return sal.cpu().detach()

# %%
# RISE saliency maps for the first sample only (trange(1) yields just i == 0), one map per
# modality. Each entry stores the saliency map, the model's raw (not L2-normalised) features and
# the raw input, all on the CPU, for metric_evaluation below.
# NOTE(review): the IBA cell further down reassigns `saliency_maps`, so whichever of the two
# cells ran last is what gets evaluated.
saliency_maps = {'brain':[], 'vision':[], 'text':[]}
for i in trange(1):
    batch = get_batch(i)
    target_model = 'brain'  # 'brain', 'vision', 'text'
    saliency_map_i = rise(model, batch, target_model)
    # _, saliency_map_i =heatmap(saliency_map_i, target_model)
    # `mask` and `normalize` are defined elsewhere in the notebook; presumably `mask` zeroes
    # out-of-brain voxels — TODO confirm.
    saliency_map_i *= mask
    saliency_map_i = normalize(saliency_map_i)
    # NOTE(review): these feature calls run without torch.no_grad(), so an autograd graph is
    # built and then discarded by .detach().
    brain_feat = model.get_brain_features(batch['brain']).cpu().detach()
    saliency_maps['brain'].append({'saliency': saliency_map_i[None], 'feat': brain_feat, 'data': batch['brain'].cpu().detach()})

    target_model = 'vision'  # 'brain', 'vision', 'text'
    saliency_map_i = rise(model, batch, target_model)
    saliency_map_i = heatmap(saliency_map_i, target_model)
    vision_feat = model.get_image_features(batch['vision']).cpu().detach()
    # Two leading singleton dims are added for the vision map (one for brain/text).
    saliency_maps['vision'].append({'saliency': saliency_map_i[None,None], 'feat': vision_feat, 'data': batch['vision'].cpu().detach()})

    target_model = 'text'  # 'brain', 'vision', 'text'
    saliency_map_i = rise(model, batch, target_model)
    saliency_map_i = heatmap(saliency_map_i, target_model)
    text_feat = model.get_text_features(**batch['text']).cpu().detach()
    saliency_maps['text'].append({'saliency': saliency_map_i[None], 'feat': text_feat, 'data': batch['text']['input_ids'].detach().cpu()})

# %%
from M2IB.scripts.iba import *
class IBAInterpreter:
    """Information Bottleneck Attribution (IBA) for one modality of a multimodal model.

    The hooked layer (``estim.get_layer()``) inside the target sub-model (``model.text_model``,
    ``model.vision_model`` or ``model.brain_model``) is temporarily replaced by
    ``mySequential(layer, bottleneck)``. The bottleneck parameters are then trained with Adam to
    minimise ``beta * mean(capacity) - cos_sim(target_features, other_features)``, summed over
    the other provided modalities. The trained ``buffer_capacity`` is the saliency.
    """

    # target_model name -> attribute of `self.model` holding that modality's encoder.
    _SUBMODEL_ATTR = {'text': 'text_model', 'vision': 'vision_model', 'brain': 'brain_model'}

    def __init__(self, model, estim: Estimator, beta, steps=10, lr=1, batch_size=10, progbar=False):
        """
        Args:
            model: multimodal model with ``text_model`` / ``vision_model`` / ``brain_model``
                sub-modules and ``get_text_features`` / ``get_image_features`` /
                ``get_brain_features`` methods. It is moved to CUDA if available.
            estim: compression estimator supplying the layer to hook, its output shape and the
                mean/std used to build the bottleneck.
            beta: weight of the compression (capacity) term in the loss.
            steps: number of Adam steps per attribution.
            lr: Adam learning rate.
            batch_size: number of copies of each input (along dim 0) used during training.
            progbar: show a tqdm progress bar while training.
        """
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model = model.to(self.device)
        self.original_layer = estim.get_layer()
        self.shape = estim.shape()
        self.beta = beta
        self.batch_size = batch_size
        self.fitting_estimator = torch.nn.CosineSimilarity(eps=1e-6)
        self.progbar = progbar
        self.lr = lr
        self.train_steps = steps
        self.bottleneck = InformationBottleneck(estim.mean(), estim.std(), device=self.device)
        # Alternative bottleneck: VarUB(estim.mean(), estim.std(), device=self.device)
        self.sequential = mySequential(self.original_layer, self.bottleneck)

    # NOTE(review): text_heatmap / brain_heatmap / vision_heatmap call `_run_text_training` and
    # `_run_vision_training`, which this class does not define, so they raise AttributeError
    # unless a subclass provides them. Use get_saliency() instead. brain_heatmap is also a copy of
    # vision_heatmap (it reshapes to a square 2-D map and resizes to 224).
    def text_heatmap(self, text_t, image_t):
        saliency, loss_c, loss_f, loss_t = self._run_text_training(text_t, image_t)
        saliency = torch.nansum(saliency, -1).cpu().detach().numpy()
        saliency = normalize(saliency)
        return normalize(saliency)

    def brain_heatmap(self, text_t, image_t):
        saliency, loss_c, loss_f, loss_t = self._run_vision_training(text_t, image_t)
        saliency = torch.nansum(saliency, -1)[1:] # Discard the first because it's the CLS token
        dim = int(saliency.numel() ** 0.5)
        saliency = saliency.reshape(1, 1, dim, dim)
        saliency = torch.nn.functional.interpolate(saliency, size=224, mode='bilinear')
        saliency = saliency.squeeze().cpu().detach().numpy()
        return normalize(saliency)

    def vision_heatmap(self, text_t, image_t):
        saliency, loss_c, loss_f, loss_t = self._run_vision_training(text_t, image_t)
        saliency = torch.nansum(saliency, -1)[1:] # Discard the first because it's the CLS token
        dim = int(saliency.numel() ** 0.5)
        saliency = saliency.reshape(1, 1, dim, dim)
        saliency = torch.nn.functional.interpolate(saliency, size=224, mode='bilinear')
        saliency = saliency.squeeze().cpu().detach().numpy()
        return normalize(saliency)

    def get_saliency(self, batch, target_model:str='brain'):
        """Return ``(saliency, loss_c, loss_f, loss_t)`` for ``target_model``; see ``_run_training``."""
        saliency, loss_c, loss_f, loss_t = self._run_training(batch, target_model)
        return saliency, loss_c, loss_f, loss_t

    def _run_training(self, batch, target_model: str = 'brain'):
        """Insert the bottleneck, train it and always restore the original layer.

        Args:
            batch: dict with 'text' / 'vision' / 'brain' entries (None = not provided). It is
                not modified.
            target_model: modality whose encoder receives the bottleneck.
        Returns:
            ``(capacity averaged over dim 0, loss_c, loss_f, loss_t)`` from the last training step.
        Raises:
            NotImplementedError: for an unknown ``target_model``.
        """
        assert batch[target_model] is not None, f"Please provide {target_model}"
        assert any(v is not None for k, v in batch.items() if k != target_model), f"must provide values for any other models except {target_model}"
        if target_model not in self._SUBMODEL_ATTR:
            raise NotImplementedError(f"target_model {target_model} not implemented!")

        submodel = getattr(self.model, self._SUBMODEL_ATTR[target_model])
        replace_layer(submodel, self.original_layer, self.sequential)
        try:
            loss_c, loss_f, loss_t = self._train_bottleneck(batch, target_model)
        finally:
            # Restore even if training fails, so the model is never left with the bottleneck wired in.
            replace_layer(submodel, self.sequential, self.original_layer)
        return self.bottleneck.buffer_capacity.mean(axis=0), loss_c, loss_f, loss_t

    def _expand_batch(self, batch):
        """Return a new dict where every non-None entry is repeated ``self.batch_size`` times along dim 0."""
        expanded = {}
        for k, v in batch.items():
            if v is None:
                continue
            if k == 'text':
                # Shallow copy so the caller's tokenizer output is left untouched.
                v = dict(v)
                v['input_ids'] = v['input_ids'].expand(self.batch_size, -1)
                v['attention_mask'] = v['attention_mask'].expand(self.batch_size, -1)
            elif k == 'vision':
                v = v.expand(self.batch_size, -1, -1, -1)
            elif k == 'brain':
                v = v.expand(self.batch_size, -1, -1, -1, -1)
            else:
                raise NotImplementedError(f"batch key {k} not implemented!")
            expanded[k] = v
        return expanded

    def _encode(self, inputs):
        """Return ``{modality: features}`` for every entry of the (already expanded) ``inputs``."""
        out = {}
        for k, v in inputs.items():
            if k == 'text':
                out[k] = self.model.get_text_features(**v)
            elif k == 'vision':
                out[k] = self.model.get_image_features(v)
            elif k == 'brain':
                out[k] = self.model.get_brain_features(v)
        return out

    def _train_bottleneck(self, batch, target_model: str = 'brain'):
        """Optimise the bottleneck parameters for ``self.train_steps`` Adam steps.

        Returns:
            ``(compression, fitting, total)`` loss terms of the last step, summed over the
            non-target modalities. All three are 0 when ``train_steps`` is 0.
        """
        optimizer = torch.optim.Adam(lr=self.lr, params=self.bottleneck.parameters())
        # Reset from previous run or modifications
        self.bottleneck.reset_alpha()
        self.model.eval()
        inputs = self._expand_batch(batch)
        loss_c, loss_f, loss_t = 0, 0, 0
        for _ in tqdm(range(self.train_steps), desc="Training Bottleneck",
                      disable=not self.progbar):
            optimizer.zero_grad()
            out = self._encode(inputs)
            loss_c, loss_f, loss_t = 0, 0, 0
            for k, feats in out.items():
                if k == target_model:
                    continue
                _loss_c, _loss_f, _loss_t = self.calc_loss(outputs=out[target_model], labels=feats)
                loss_c += _loss_c
                loss_f += _loss_f
                loss_t += _loss_t
            loss_t.backward()
            optimizer.step()
        return loss_c, loss_f, loss_t

    def calc_loss(self, outputs, labels):
        """ Calculate the combined loss expression for optimization of lambda

        Returns ``(compression_term, fitting_term, beta * compression_term - fitting_term)``,
        where the fitting term is the mean cosine similarity between ``outputs`` and ``labels``.
        """
        compression_term = self.bottleneck.buffer_capacity.mean()
        fitting_term = self.fitting_estimator(outputs, labels).mean()
        total =  self.beta * compression_term - fitting_term
        return compression_term, fitting_term, total
from M2IB.scripts.methods import extract_feature_map, extract_bert_layer, get_compression_estimator
def get_saliency_iba(batch, target_model, model, layer_idx, beta, var, lr=1, train_steps=10, progbar=True):
    """Run IBA on layer ``layer_idx`` of ``model.<target_model>_model``.

    The layer's features on ``batch[target_model]`` initialise the compression estimator. An
    ``IBAInterpreter`` is then trained for ``train_steps`` steps.

    Returns:
        ``(saliency, loss_c, loss_f, loss_t)`` as returned by ``IBAInterpreter.get_saliency``.
    """
    submodule = model.get_submodule(f'{target_model}_model')
    feature_map = extract_feature_map(submodule, layer_idx, batch[target_model])
    estimator = get_compression_estimator(var, extract_bert_layer(submodule, layer_idx), feature_map)
    interpreter = IBAInterpreter(model, estimator, beta=beta, lr=lr, steps=train_steps, progbar=progbar)
    return interpreter.get_saliency(batch, target_model)
# %%
# IBA saliency maps for the first sample only (trange(1) yields just i == 0), one per modality.
# Hyper-parameters are shared: layer 0, beta=0.2, var=2, 20 training steps.
# NOTE(review): this reassigns `saliency_maps`, replacing whatever the RISE cell produced.
saliency_maps = {'brain':[], 'vision':[], 'text':[]}
for i in trange(1):
    batch = get_batch(i)
    target_model = 'brain'  # 'brain', 'vision', 'text'
    saliency_map_i, _, _, _ = get_saliency_iba(batch, target_model, model, layer_idx=0, beta=0.2, var=2, train_steps=20, progbar=False)
    # Sum the capacity over the last (feature) axis and drop the first token — presumably a
    # CLS/global token, as in IBAInterpreter.vision_heatmap — TODO confirm for the brain encoder.
    saliency_map_i = torch.nansum(saliency_map_i, -1)[1:].cpu().detach()
    # For 'brain', heatmap returns a pair; only the second element is kept.
    _, saliency_map_i =heatmap(saliency_map_i, target_model)
    # Re-fetch a fresh batch: the IBA call may have modified `batch` in place (expanded tensors).
    batch = get_batch(i)
    brain_feat = model.get_brain_features(batch['brain']).cpu().detach()
    saliency_maps['brain'].append({'saliency': saliency_map_i[None,None], 'feat': brain_feat, 'data': batch['brain'].cpu().detach()})

    batch = get_batch(i)
    target_model = 'vision'  # 'brain', 'vision', 'text'
    # Drop text so the vision bottleneck is fitted only against the brain features.
    batch['text']=None
    saliency_map_i, _, _, _ = get_saliency_iba(batch, target_model, model, layer_idx=0, beta=0.2, var=2, train_steps=20, progbar=False)
    # Drop the first (CLS) token before building the image heatmap.
    saliency_map_i = torch.nansum(saliency_map_i, -1)[1:]
    saliency_map_i = heatmap(saliency_map_i, target_model)
    batch = get_batch(i)
    vision_feat = model.get_image_features(batch['vision']).cpu().detach()
    saliency_maps['vision'].append({'saliency': saliency_map_i[None,None], 'feat': vision_feat, 'data': batch['vision'].cpu().detach()})

    batch = get_batch(i)
    target_model = 'text'  # 'brain', 'vision', 'text'
    # Drop vision so the text bottleneck is fitted only against the brain features.
    batch['vision']=None
    saliency_map_i, _, _, _ = get_saliency_iba(batch, target_model, model, layer_idx=0, beta=0.2, var=2, train_steps=20, progbar=False)
    # Text keeps every token (no [1:] slice).
    saliency_map_i = torch.nansum(saliency_map_i, -1)
    saliency_map_i = heatmap(saliency_map_i, target_model)
    batch = get_batch(i)
    text_feat = model.get_text_features(**batch['text']).cpu().detach()
    saliency_maps['text'].append({'saliency': saliency_map_i[None], 'feat': text_feat, 'data': batch['text']['input_ids'].detach().cpu()})
# %%
import copy
from pytorch_grad_cam.metrics.cam_mult_image import DropInConfidence, IncreaseInConfidence
from eval_utils import CosSimilarity, ImageFeatureExtractor, TextFeatureExtractor, BrainFeatureExtractor, DropInConfidenceText, IncreaseInConfidenceText
def get_metrics_vt(brain_feat, brain_feature, image_feat,image_feature,text_id, text_feature, bmap, vmap, tmap, model):
    """Brain-saliency faithfulness metrics against the image and text embeddings.

    The brain input ``brain_feat`` is perturbed with the brain map ``bmap``, and the change in
    cosine similarity to ``image_feature`` ('bi*') and to ``text_feature`` ('bt*') is measured
    with ``DropInConfidence`` / ``IncreaseInConfidence``. Each value is index ``[0]`` of the
    per-target result, times 100.

    ``brain_feature``, ``image_feat``, ``text_id``, ``vmap`` and ``tmap`` are currently unused;
    they are kept for interface compatibility.

    Returns:
        dict with keys 'bidrop', 'biincr', 'btdrop', 'btincr'.
    """
    results = {}
    with torch.no_grad():
        # Target 0 = image embedding, target 1 = text embedding.
        btargets = [CosSimilarity(image_feature), CosSimilarity(text_feature)]
        extractor = BrainFeatureExtractor(model)
        # Each metric is evaluated once and both targets are read from the same result; the
        # original ran every metric twice with identical arguments.
        drop = DropInConfidence()(brain_feat, bmap, btargets, extractor)
        incr = IncreaseInConfidence()(brain_feat, bmap, btargets, extractor)
        results['bidrop'] = drop[0][0]*100
        results['biincr'] = incr[0][0]*100
        results['btdrop'] = drop[1][0]*100
        results['btincr'] = incr[1][0]*100

    return results

def metric_evaluation(model,saliency_maps):
    """Score every stored sample with ``get_metrics_vt``.

    ``saliency_maps`` maps 'brain', 'vision' and 'text' to lists of dicts with 'data', 'feat' and
    'saliency' tensors. Entries are matched by index, moved to the notebook-level ``device`` and
    passed to ``get_metrics_vt``. The number of samples is taken from the 'brain' list.

    Returns:
        List of per-sample result dicts, in sample order.
    """
    def _fetch(modality, idx):
        # Move one stored entry to `device`, in the order data, feat, saliency.
        entry = saliency_maps[modality][idx]
        return entry['data'].to(device), entry['feat'].to(device), entry['saliency'].to(device)

    collected = []
    for idx in trange(len(saliency_maps['brain'])):
        brain_data, brain_emb, brain_map = _fetch('brain', idx)
        image_data, image_emb, image_map = _fetch('vision', idx)
        text_ids, text_emb, text_map = _fetch('text', idx)
        collected.append(
            get_metrics_vt(brain_data, brain_emb, image_data, image_emb, text_ids, text_emb,
                           brain_map, image_map, text_map, model)
        )
    return collected


# %%
# Evaluate the saliency maps produced by whichever saliency cell ran last.
res = metric_evaluation(model,saliency_maps)

# Mean of each metric over all samples. The prefix "bi" means brain map scored against the image
# embedding and "bt" against the text embedding; "drop"/"incr" are DropInConfidence /
# IncreaseInConfidence (x100).
# NOTE(review): raises ZeroDivisionError if saliency_maps['brain'] is empty.
bidrop = sum(k['bidrop'] for k in res) / len(res)
biincr = sum(k['biincr'] for k in res) / len(res)
btdrop = sum(k['btdrop'] for k in res) / len(res)
btincr = sum(k['btincr'] for k in res) / len(res)
print("bidrop:", bidrop, "biincr:", biincr, "btdrop:", btdrop, "btincr:", btincr)