import numpy as np
import torch
from pathlib import Path
import os, json
import einops
import matplotlib.pyplot as plt
from io import BytesIO
import seaborn as sns


def _list_init(length: int):
    return [[] for _ in range(length)]

class PRSLogger(object):
    """Collect per-layer token-norm and attention statistics through forward hooks.

    The ``compute_*`` methods are hook callbacks. Each stores a reduced,
    CPU-side numpy copy of the intercepted tensor and returns the tensor
    unchanged, so the forward pass is not altered. ``finalize`` turns the
    collected data into plots under ``logs_root``; ``reinit`` clears the
    buffers between passes.

    Hook callbacks receive ``layer`` and store data at index ``layer - 1``, so
    ``layer`` is presumably 1-based (TODO confirm against the hook manager).
    """

    def __init__(self, model, device, layer_info, cfg, textual_len=None):
        """
        Args:
            model: Model being inspected; only stored on the instance.
            device: Stored on the instance; not otherwise used by this class.
            layer_info: ``(n_visual_layers, n_textual_layers)`` pair.
            cfg: Config node providing ``TRAINER.NAME``, ``DATASET.NAME``,
                ``SEED``, ``TRAINER.MYMODEL.REP_LAYERS``,
                ``TRAINER.MYMODEL.N_REP_TOKENS`` and
                ``DATASET.SUBSAMPLE_CLASSES``.
            textual_len: Offset by ``N_REP_TOKENS`` to give the EOT position(s)
                used by the textual hooks. It must support ``+ int``, so it
                cannot actually be None. It is presumably a per-class tensor of
                EOT indices (it is indexed with ``[-1]`` when plotting).
        """
        self.device = device
        self.model_name = cfg.TRAINER.NAME
        self.dataset_name = cfg.DATASET.NAME
        self.seed = cfg.SEED
        self.logs_root = f"/Anonymous/Anonymous/logs/{self.model_name}/{self.dataset_name}/seed{self.seed}"
        # Not used elsewhere in this class and not created here.
        self.npz_root = f"/Anonymous/Anonymous/npz/{self.model_name}/{self.dataset_name}/seed{self.seed}"
        os.makedirs(self.logs_root, exist_ok=True)
        self.visual_layers, self.textual_layers = layer_info
        self.layers_length = len(cfg.TRAINER.MYMODEL.REP_LAYERS) + 1
        # From this layer on, the visual hooks treat tokens 1..repr_token_len as prompt tokens.
        self.first_layer = cfg.TRAINER.MYMODEL.REP_LAYERS[0]
        self.repr_token_len = cfg.TRAINER.MYMODEL.N_REP_TOKENS
        # Presumably the inserted prompt tokens shift EOT right by repr_token_len.
        self.eot_pos = textual_len + self.repr_token_len

        # Per-layer lists of per-batch numpy arrays, filled by the hook callbacks.
        self.visual_cls_token_after_attn = _list_init(self.visual_layers)
        self.visual_prompt_token_after_attn = _list_init(self.visual_layers)
        self.visual_content_token_after_attn = _list_init(self.visual_layers)
        self.visual_cls_token_post = _list_init(self.visual_layers)
        self.visual_prompt_token_post = _list_init(self.visual_layers)
        self.visual_content_token_post = _list_init(self.visual_layers)
        self.textual_eot_token_after_attn = _list_init(self.textual_layers)
        self.textual_content_token_after_attn = _list_init(self.textual_layers)
        self.textual_eot_token_post = _list_init(self.textual_layers)
        self.textual_content_token_post = _list_init(self.textual_layers)
        self.visual_attention = _list_init(self.visual_layers)

        # Per-batch norms captured around the final projections.
        self.visual_proj_post = []
        self.visual_proj_pre = []
        self.visual_proj_rep_post = []
        self.visual_proj_rep_pre = []
        self.textual_post_ln = []
        self.textual_post_proj = []

        self.sub_cls = cfg.DATASET.SUBSAMPLE_CLASSES
        self.cur_num = 0  # number of images seen so far (advanced by add())
        # Token norms and images are only stored while cur_num <= len_limit.
        self.len_limit = 1e9 if self.sub_cls == "base" else 4096
        self.img = []
        self.img0 = []
        self.img_path = []
        self.name_to_idx = None
        self.model = model

    # ------------------------------------------------------------------ hooks

    @torch.no_grad()
    def compute_visual_attentions_matrix(self, ret, layer):
        """Store the [b, n, n] attention weights of visual block ``layer``.

        Unlike the token-norm hooks, this ignores ``len_limit``. Every batch is
        kept so that ``finalize`` can average over all of them.
        """
        assert len(ret.shape) == 3, "Verify that you catch the attention weights correctly" # [b, n, n]
        self.visual_attention[layer - 1].append(ret.detach().cpu().numpy())  # [b, n, n]
        return ret

    def _record_visual_norms(self, ret, layer, cls_store, prompt_store, content_store):
        """Append per-token L2 norms of one visual residual-stream tensor.

        ``ret`` is permuted (1, 0, 2) before indexing token 0 per sample, so it
        is presumably sequence-first [n, b, d] (TODO confirm). For prompted
        layers, tokens 1..repr_token_len go to ``prompt_store`` and the rest to
        ``content_store``. Before the first prompted layer, a zero placeholder of
        shape [b, repr_token_len] is stored so every layer has the same layout.
        """
        if self.cur_num > self.len_limit:
            return
        tokens = ret.detach().permute(1, 0, 2)  # -> [b, n, d]
        slot = layer - 1
        cls_norm = tokens[:, 0, :].norm(dim=-1).cpu().numpy()  # [b]
        cls_store[slot].append(cls_norm)
        if layer >= self.first_layer:
            prompt_end = 1 + self.repr_token_len
            prompt_store[slot].append(tokens[:, 1:prompt_end, :].norm(dim=-1).cpu().numpy())  # [b, repr_n]
            content_store[slot].append(tokens[:, prompt_end:, :].norm(dim=-1).cpu().numpy())  # [b, n - 1 - repr_n]
        else:
            # For floating inputs the norm keeps the input dtype, so cls_norm.dtype
            # matches the tensor's dtype without copying the whole tensor to host.
            prompt_store[slot].append(np.zeros((tokens.shape[0], self.repr_token_len), dtype=cls_norm.dtype))
            content_store[slot].append(tokens[:, 1:, :].norm(dim=-1).cpu().numpy())  # [b, n - 1]

    @torch.no_grad()
    def compute_visual_sequence_after_attn(self, ret, layer):
        """Record visual token norms from the ``after_attn`` hook of block ``layer``."""
        assert len(ret.shape) == 3, "Verify that you catch the attention weights correctly" # [b, n, d]
        self._record_visual_norms(
            ret, layer,
            self.visual_cls_token_after_attn,
            self.visual_prompt_token_after_attn,
            self.visual_content_token_after_attn,
        )
        return ret

    @torch.no_grad()
    def compute_visual_sequence_post(self, ret, layer):
        """Record visual token norms from the ``post`` hook of block ``layer``."""
        assert len(ret.shape) == 3, "Verify that you catch the attention weights correctly" # [b, n, d]
        self._record_visual_norms(
            ret, layer,
            self.visual_cls_token_post,
            self.visual_prompt_token_post,
            self.visual_content_token_post,
        )
        return ret

    def _record_textual_norms(self, ret, layer, eot_store, content_store):
        """Append per-token norms (all tokens) and the EOT-token norms of one textual tensor.

        ``ret`` is permuted (1, 0, 2) to [n_cls, l, d]. The EOT token of class i
        is read at ``self.eot_pos[i]``.
        """
        if self.cur_num > self.len_limit:
            return
        norms = ret.detach().permute(1, 0, 2).norm(dim=-1)  # [n_cls, l]
        eot_store[layer - 1].append(norms[torch.arange(norms.shape[0]), self.eot_pos].cpu().numpy())  # [n_cls]
        content_store[layer - 1].append(norms.cpu().numpy())  # [n_cls, l]

    @torch.no_grad()
    def compute_textual_sequence_after_attn(self, ret, layer):
        """Record textual token norms from the ``after_attn`` hook of block ``layer``."""
        assert len(ret.shape) == 3, "Verify that you catch the attention weights correctly" # [n, l, d]
        self._record_textual_norms(
            ret, layer, self.textual_eot_token_after_attn, self.textual_content_token_after_attn
        )
        return ret

    @torch.no_grad()
    def compute_textual_sequence_post(self, ret, layer):
        """Record textual token norms from the ``post`` hook of block ``layer``."""
        assert len(ret.shape) == 3, "Verify that you catch the attention weights correctly" # [n, l, d]
        self._record_textual_norms(
            ret, layer, self.textual_eot_token_post, self.textual_content_token_post
        )
        return ret

    @staticmethod
    def _append_norm(store, ret, keepdim):
        """Append ``ret``'s L2 norm over the last dim (as numpy) to ``store``; return ``ret``."""
        store.append(ret.detach().norm(dim=-1, keepdim=keepdim).cpu().numpy())
        return ret

    def compute_visual_proj_rep_post(self, ret):
        """Store [b, 1] norms from the ``visual.proj_rep.post`` hook."""
        return self._append_norm(self.visual_proj_rep_post, ret, keepdim=True)

    def compute_visual_proj_rep_pre(self, ret):
        """Store [b, 1] norms from the ``visual.proj_rep.pre`` hook."""
        return self._append_norm(self.visual_proj_rep_pre, ret, keepdim=True)

    def compute_visual_proj_post(self, ret):
        """Store [b, 1] norms from the ``visual.proj.post`` hook."""
        return self._append_norm(self.visual_proj_post, ret, keepdim=True)

    def compute_visual_proj_pre(self, ret):
        """Store [b, 1] norms from the ``visual.proj.pre`` hook."""
        return self._append_norm(self.visual_proj_pre, ret, keepdim=True)

    def compute_textual_post_ln(self, ret):
        """Store [n_cls] norms from the ``textual_post.post_ln`` hook."""
        return self._append_norm(self.textual_post_ln, ret, keepdim=False)

    def compute_textual_post_proj(self, ret):
        """Store [n_cls] norms from the ``textual_post.post_proj`` hook."""
        return self._append_norm(self.textual_post_proj, ret, keepdim=False)

    # -------------------------------------------------------------- finalize

    def finalize(self, epoch, plot_samples=False):
        """Render the statistics collected during one pass.

        Always writes an averaged attention heatmap for every prompted visual
        layer (see ``show_attn``). When ``plot_samples`` is True, it also writes
        one multi-panel figure per selected image and layer.

        Args:
            epoch: Training epoch, or ``-1`` for the test pass. This selects the
                output sub-directory (``epoch_{epoch}`` vs ``test_{sub_cls}``)
                and which images get per-sample figures.
            plot_samples: Also render the per-image diagnostic figures.
        """
        if self.name_to_idx is None or epoch == -1:
            # Image indices are rebuilt on the test pass. During training they
            # stay fixed after the first call, so the same images are tracked.
            all_paths = [path for batch_paths in self.img_path for path in batch_paths]
            self.name_to_idx = {path: idx for idx, path in enumerate(all_paths)}
        self._plot_layer_attentions(epoch)
        if not plot_samples:
            return
        self._plot_samples(epoch)
        print(f'finish finalize {epoch}_{self.dataset_name}!')

    def _work_dir(self, epoch, layer):
        """Return (creating it if needed) the output directory for 0-based ``layer``."""
        stage = f"test_{self.sub_cls}" if epoch == -1 else f"epoch_{epoch}"
        work_dir = os.path.join(self.logs_root, stage, f"layer_{layer+1}")
        os.makedirs(work_dir, exist_ok=True)
        return work_dir

    def _plot_layer_attentions(self, epoch):
        """Average each layer's stored attention matrices over all samples and plot them."""
        for layer in range(self.visual_layers):
            batches = self.visual_attention[layer]
            if not batches:
                continue  # no attention was captured for this layer
            avg_matrix = np.concatenate(batches, axis=0).mean(axis=0)
            # ``layer`` is 0-based here while the hooks are 1-based, hence ``+ 1``
            # (matches the prompt check used by the per-sample plots).
            self.show_attn(avg_matrix, self._work_dir(epoch, layer), is_prompted=(layer + 1 >= self.first_layer))

    def _plot_samples(self, epoch):
        """Write ``<work_dir>/<image index>.png`` diagnostic figures per image and layer.

        On the test pass (``epoch == -1``), every 10th image is plotted. In
        training epochs, only images with index <= 3 are plotted.
        """
        for s, (images, perturbed, paths) in enumerate(zip(self.img, self.img0, self.img_path)):
            # On the test pass, the textual statistics of the first batch are used for all images.
            t = 0 if epoch == -1 else s
            for img_idx in range(images.shape[0]):
                img_p = paths[img_idx]
                img_p_idx = self.name_to_idx[img_p]
                if epoch == -1 and img_p_idx % 10 != 0:
                    continue
                if epoch != -1 and img_p_idx > 3:
                    continue
                for layer in range(self.visual_layers):
                    work_dir = self._work_dir(epoch, layer)
                    is_last = layer + 1 == self.visual_layers
                    is_prompted = layer + 1 >= self.first_layer

                    txt_post = self.textual_content_token_post[layer][t]  # [n_cls, l]
                    txt_attn = self.textual_content_token_after_attn[layer][t]  # [n_cls, l]
                    eot_post = self.textual_eot_token_post[layer][t]  # [n_cls]
                    eot_attn = self.textual_eot_token_after_attn[layer][t]  # [n_cls]

                    cls_post = self.visual_cls_token_post[layer][s][img_idx]
                    content_post = self.visual_content_token_post[layer][s][img_idx]  # [n]
                    cls_attn = self.visual_cls_token_after_attn[layer][s][img_idx]
                    content_attn = self.visual_content_token_after_attn[layer][s][img_idx]  # [n]

                    # Assumes 196 content tokens forming a 14x14 patch grid (TODO confirm).
                    grid_post = einops.rearrange(content_post, '(N M) -> N M', N=14, M=14)
                    grid_attn = einops.rearrange(content_attn, '(N M) -> N M', N=14, M=14)

                    if is_last:
                        fig, axes = plt.subplots(3, 5, figsize=(18, 12))
                    else:
                        fig, axes = plt.subplots(2, 5, figsize=(18, 8))
                    self.show_img(axes[0, 0], images[img_idx], "Original Image")
                    self.show_img(axes[0, 1], perturbed[img_idx], "Randomly Perturbed Image")
                    self.show_jet(axes[0, 2], grid_attn, "Visual Token After Attention Norm")
                    self.show_jet(axes[0, 3], grid_post, "Visual Token Post Norm")

                    self.show_bar(axes[0, 4], content_post, "Visual Token Post Norm")
                    self.bar_mark(axes[0, 4], cls_post, 'red', True)
                    self.show_bar(axes[1, 0], content_attn, "Visual Token After Attention Norm")
                    self.bar_mark(axes[1, 0], cls_attn, 'red', True)
                    if is_prompted:
                        prompt_post = self.visual_prompt_token_post[layer][s][img_idx]  # [repr_n]
                        prompt_attn = self.visual_prompt_token_after_attn[layer][s][img_idx]  # [repr_n]
                        self.bar_mark(axes[0, 4], prompt_post, 'blue')
                        self.bar_mark(axes[0, 4], prompt_post.mean(), 'green', True)
                        self.bar_mark(axes[1, 0], prompt_attn, 'blue')
                        self.bar_mark(axes[1, 0], prompt_attn.mean(), 'green', True)

                    self.show_bar(axes[1, 1], eot_post, "Textual EOT Token Post Norm")
                    self.show_bar(axes[1, 2], eot_attn, "Textual EOT Token After Attention Norm")
                    # Last class's sequence without position 0.
                    # NOTE(review): the [1:] slice shifts bars by one, so patch eot_pos[-1]
                    # is token eot_pos[-1] + 1. Confirm this is the intended highlight.
                    self.show_bar(axes[1, 3], txt_post[-1][1:], "Textual Token Post Norm")
                    self.set_bar_color(axes[1, 3], self.eot_pos[-1])
                    self.show_bar(axes[1, 4], txt_attn[-1][1:], "Textual Token After Attention Norm")
                    self.set_bar_color(axes[1, 4], self.eot_pos[-1])
                    fig.suptitle(f"{img_p}", fontsize=16, y=0.98)
                    if is_last:
                        visual_data = np.concatenate([
                            self.visual_proj_pre[s][img_idx],
                            self.visual_proj_post[s][img_idx],
                            self.visual_proj_rep_pre[s][img_idx],
                            self.visual_proj_rep_post[s][img_idx],
                        ], axis=0)
                        labels = ["Pre", "Post", "Pre[R]", "Post[R]"]
                        self.show_bar_labels(axes[2, 0], visual_data, labels, "Visual Proj and Proj Rep Norm")
                        self.show_bar(axes[2, 1], self.textual_post_ln[t], "Textual EOT Token Post LN Norm")
                        self.show_bar(axes[2, 2], self.textual_post_proj[t], "Textual EOT Token Post Proj Norm")
                    fig.savefig(os.path.join(work_dir, f"{img_p_idx}.png"), dpi=300, bbox_inches='tight')
                    plt.close(fig)

    # -------------------------------------------------------------- plotting

    @staticmethod
    def _select_attn_indices(data, n, fixed_n):
        """Pick the token indices that ``show_attn`` displays, in ascending order.

        If the matrix has more than ``fixed_n`` tokens, the result is tokens
        ``1 .. fixed_n - 1`` plus up to ``n - (fixed_n - 1)`` of the remaining
        tokens with the largest column sums (attention received). Otherwise,
        all tokens are returned.
        """
        total_tokens = data.shape[0]
        if total_tokens <= fixed_n:
            return np.arange(total_tokens)
        fixed = np.arange(1, fixed_n)
        remaining = np.arange(fixed_n, total_tokens)
        weights = data[:, remaining].sum(axis=0)
        # Clamp k to [0, len(remaining)]. A bare ``[-k:]`` would select every
        # token when k == 0 and the wrong slice when k < 0.
        k = max(0, min(n - len(fixed), len(remaining)))
        ranked = remaining[np.argsort(weights)]
        top_k = ranked[len(ranked) - k:]
        return np.sort(np.concatenate([fixed, top_k]))

    def show_attn(self, data, load_path, idx=None, n=18, is_prompted=False, fixed_n=6):
        """Save a log-scaled heatmap of a token subset of an attention matrix.

        Writes ``attn_data.npy`` (log attention of the selected sub-matrix) and
        ``attn.png`` (the same data centred on its median) into ``load_path``.
        Does nothing unless ``is_prompted``.

        Args:
            data: Square [n_tokens, n_tokens] attention matrix.
            load_path: Existing output directory.
            idx: Unused; kept for backward compatibility.
            n: Target number of tokens to show (fixed plus top-weighted).
            is_prompted: Only prompted layers are plotted.
            fixed_n: Tokens ``1 .. fixed_n - 1`` are always shown. These are
                presumably the prompt tokens after CLS (TODO confirm against
                ``N_REP_TOKENS``).
        """
        if not is_prompted:
            return
        selected = self._select_attn_indices(data, n, fixed_n)

        # Extract the sub-matrix; eps avoids log(0).
        sub_matrix = data[np.ix_(selected, selected)]
        log_attn = np.log(sub_matrix + 1e-8)
        np.save(os.path.join(load_path, "attn_data.npy"), log_attn)

        # Zero-point alignment: centre on the median so the diverging colormap is balanced.
        centred = log_attn - np.median(log_attn)

        # Visualize and save.
        sns.set_theme(style="white")
        fig = plt.figure(figsize=(8, 8))
        ax = sns.heatmap(centred, cmap="coolwarm", center=0, square=True, cbar=True)
        ax.axis('off')
        fig.savefig(os.path.join(load_path, "attn.png"), bbox_inches='tight', dpi=300)
        plt.close(fig)

    def bar_mark(self, ax, data, color, single=False):
        """Draw dashed horizontal line(s) on ``ax``: one at ``data`` if ``single``, else one per element."""
        if single:
            ax.axhline(data, color=color, ls='--', lw=1)
        else:
            for value in data:
                ax.axhline(value, color=color, ls='--', lw=1)

    def set_bar_color(self, ax, pos):
        """Colour bar number ``pos`` of the bar chart on ``ax`` red."""
        ax.patches[pos].set_facecolor('red')

    def show_bar_labels(self, ax, data, label, title):
        """Bar chart of ``data`` with per-bar tick labels."""
        ax.set_title(title)
        ax.bar(np.arange(data.shape[0]), data, tick_label=label)

    def show_bar(self, ax, data, title):
        """Bar chart of a 1-D array."""
        ax.set_title(title)
        ax.bar(np.arange(data.shape[0]), data)

    def show_img(self, ax, img, title):
        """Show an image without axes."""
        ax.imshow(img, aspect='auto')
        ax.set_title(title)
        ax.axis('off')

    def show_jet(self, ax, jet_map, title):
        """Show a 2-D map with the viridis colormap and no axes."""
        ax.imshow(jet_map, cmap='viridis', aspect='auto')
        ax.set_title(title)
        ax.axis('off')

    # ------------------------------------------------------------ bookkeeping

    def add(self, img0, img, img_path):
        """Register one batch of images and their paths.

        The images are permuted (0, 2, 3, 1), presumably [b, C, H, W] ->
        [b, H, W, C]. They are stored as numpy arrays only while ``cur_num``
        stays within ``len_limit``; paths are always kept.
        """
        self.cur_num += img.shape[0]
        if self.cur_num <= self.len_limit:
            # .cpu() lets CUDA batches be converted too; it is a no-op on CPU tensors.
            self.img.append(img.detach().permute(0, 2, 3, 1).cpu().numpy())
            self.img0.append(img0.detach().permute(0, 2, 3, 1).cpu().numpy())
        self.img_path.append(img_path)

    def test(self):
        """Limit storage to the first 50 images (switch to test mode)."""
        self.len_limit = 50

    def reinit(self):
        """Clear all per-pass buffers (``name_to_idx`` is kept so image indices stay stable)."""
        self.visual_cls_token_after_attn = _list_init(self.visual_layers)
        self.visual_prompt_token_after_attn = _list_init(self.visual_layers)
        self.visual_content_token_after_attn = _list_init(self.visual_layers)
        self.visual_cls_token_post = _list_init(self.visual_layers)
        self.visual_prompt_token_post = _list_init(self.visual_layers)
        self.visual_content_token_post = _list_init(self.visual_layers)
        self.textual_eot_token_after_attn = _list_init(self.textual_layers)
        self.textual_content_token_after_attn = _list_init(self.textual_layers)
        self.textual_eot_token_post = _list_init(self.textual_layers)
        self.textual_content_token_post = _list_init(self.textual_layers)
        # Without this, finalize() would average attention over every pass so far.
        self.visual_attention = _list_init(self.visual_layers)
        self.visual_proj_post = []
        self.visual_proj_pre = []
        self.visual_proj_rep_post = []
        self.visual_proj_rep_pre = []
        self.textual_post_ln = []
        self.textual_post_proj = []
        self.cur_num = 0
        self.img = []
        self.img0 = []
        self.img_path = []
        torch.cuda.empty_cache()


def hook_prs_logger(model, device, layer_info=None, cfg=None, textual_len=None):
    """Create a ``PRSLogger`` and attach its callbacks to ``model.hook_manager``.

    Every callback is registered under its hook-name pattern, in a fixed order,
    via ``model.hook_manager.register(name, fn)``.

    Returns:
        The ``PRSLogger`` whose buffers the registered hooks fill.
    """
    prs_logger = PRSLogger(model, device, layer_info, cfg, textual_len)
    hook_table = (
        ("visual.transformer.resblocks.*.after_attn", prs_logger.compute_visual_sequence_after_attn),
        ("visual.transformer.resblocks.*.post", prs_logger.compute_visual_sequence_post),
        ("textual.resblocks.*.after_attn", prs_logger.compute_textual_sequence_after_attn),
        ("textual.resblocks.*.post", prs_logger.compute_textual_sequence_post),
        ("visual.proj_rep.post", prs_logger.compute_visual_proj_rep_post),
        ("visual.proj_rep.pre", prs_logger.compute_visual_proj_rep_pre),
        ("visual.proj.post", prs_logger.compute_visual_proj_post),
        ("visual.proj.pre", prs_logger.compute_visual_proj_pre),
        ("textual_post.post_ln", prs_logger.compute_textual_post_ln),
        ("textual_post.post_proj", prs_logger.compute_textual_post_proj),
        ("visual.transformer.resblocks.*.attn_weight", prs_logger.compute_visual_attentions_matrix),
    )
    for hook_name, callback in hook_table:
        model.hook_manager.register(hook_name, callback)
    return prs_logger
