#!/usr/bin/env python3
"""
vit with prompt: a clean version with the default settings of VPT
"""
import math
import numpy as np
import torch
import torch.nn as nn
import torch.fft
import torchvision as tv
import random
import math
from functools import reduce
from operator import mul
from torch.nn.modules.utils import _pair
from torch.nn import Conv2d, Dropout
from scipy import ndimage

from ..vit_backbones.vit import CONFIGS, Transformer, VisionTransformer, np2th
from ...utils import logging
from einops import rearrange
logger = logging.get_logger("visual_prompt")

# class ParameterWrapper(nn.Parameter):
#     def __init__(self, data):
#         super(ParameterWrapper, self).__init__(data)
        
#     def register_forward_hook(self, hook):
#         self._forward_hooks.clear()
#         handle = self._register_forward_hook(hook)
#         return handle

# import torch
# import torch.nn as nn
# import torchquantum as tq
# import torchquantum.functional as tqf

# class VQC(tq.QuantumModule):
#     def __init__(self, n_wires: int = 8, n_qlayers: int = 1):
#         super().__init__()
#         self.n_wires = n_wires
#         self.n_qlayers = n_qlayers
#         # Setting up tensor product encoder
#         enc_cnt = list()
#         for i in range(self.n_wires):
#             cnt = {'input_idx': [i], 'func': 'ry', 'wires': [i]}
#             enc_cnt.append(cnt)
#         self.encoder = tq.GeneralEncoder(enc_cnt)
        
#         # Trainable parameters for quantum rotation gates
#         self.params_rx_dct = tq.QuantumModuleDict()
#         self.params_rx_dct_1 = tq.QuantumModuleDict()
#         # Initialize rotation gate parameters for each layer
#         for k in range(self.n_qlayers):
#             for i in range(self.n_wires):
#                 self.params_rx_dct[str(i + k * self.n_wires)] = tq.RX(has_params=True, trainable=True)
#         for i in range(self.n_wires):
#                 self.params_rx_dct_1[str(i + k * self.n_wires)] = tq.RX(has_params=True, trainable=True)
#         # Observable: PauliZ
#         self.measure = tq.MeasureAll(tq.PauliZ)

#     @tq.static_support
#     def forward(self, x: torch.Tensor):
#         self.q_device = tq.QuantumDevice(n_wires=self.n_wires, bsz=x.shape[0], device=x.device)
#         self.encoder(self.q_device, x)

#         # Variational quantum layers (BasicEntanglerLayers style)
#         for k in range(self.n_qlayers):
#             for i in range(self.n_wires):
#                 self.params_rx_dct[str(i + k * self.n_wires)](self.q_device, wires=i)

#             # CNOT entanglement (ring structure)
#             for i in range(self.n_wires):
#                 tqf.cnot(self.q_device, wires=[i, (i + 1) % self.n_wires])
                
#         for i in range(self.n_wires):
#                 self.params_rx_dct_1[str(i + k * self.n_wires)](self.q_device, wires=i)
#         # Measurement
#         return self.measure(self.q_device)    

# class QNetBlock(nn.Module):
#     def __init__(self, num_qubits=8, num_qlayers=2):
#         super(QNetBlock, self).__init__()
#         self.avg_pool = nn.AdaptiveAvgPool2d(1)
#         self.num_qubits = num_qubits
#         self.fc = nn.Sequential(
#             nn.Linear(768, num_qubits, bias=False),
#             nn.ReLU(),
#             VQC(n_wires=num_qubits, n_qlayers=num_qlayers),
#             nn.Linear(num_qubits, 768, bias=False),
#             nn.Sigmoid()
#         )  
        
#     def forward(self, x):
#         b, t, d= x.size()
#         res = x
#         x = x.reshape(-1,d)
#         y = self.fc(x).reshape(b,t,d)
#         y += res
#         return y

# class LearnableUnitaryBlock(nn.Module):
#     def __init__(self, dim_size=768):
#         super().__init__()
#         # Parameterized unitary matrix: U = e^{iH} (H is Hermitian matrix)
#         self.H_real = nn.Parameter(torch.randn(dim_size, dim_size))
#         self.H_imag = nn.Parameter(torch.randn(dim_size, dim_size))
        
#     # def forward(self, x):
#     #     # Construct Hermitian matrix H = H_real + i H_imag
#     #     H = torch.complex(self.H_real, self.H_imag)
#     #     H = (H + H.conj().transpose(-1, -2)) / 2  # Force Hermitian
#     #     U = torch.matrix_exp(1j * H)  # Matrix exponential generates unitary matrix
        
#     #     # Apply transformation (process last dimension)
#     #     x_complex = torch.view_as_complex(x.unsqueeze(-1))
#     #     x_transformed = torch.matmul(x_complex, U)
#     #     return torch.view_as_real(x_transformed).real
#     def forward(self, x):
#         # 1. Construct Hermitian matrix H = H_real + i*H_imag
#         H = torch.complex(self.H_real, self.H_imag)
#         H = (H + H.conj().T) / 2  # Ensure H is Hermitian matrix
        
#         # 2. Generate unitary matrix through matrix exponential U = e^{iH}
#         U = torch.matrix_exp(1j * H)
        
#         # 3. Convert input x to complex tensor (imaginary part set to 0)
#         x_complex = torch.complex(x, torch.zeros_like(x))  # Directly generate complex, no dimension adjustment needed
        
#         # 4. Apply unitary transformation U, maintaining dimension compatibility
#         x_transformed = torch.matmul(x_complex, U)
        
#         # 5. Return real part, consistent with original FNetBlock output dimensions
#         return x_transformed.real

# import torch_dct as dct

# class DCTBlock(nn.Module):
#     def __init__(self):
#         super().__init__()

#     def forward(self, x):
#         # 2D DCT transformation (last and second-to-last dimensions)
#         x = dct.dct(x, norm='ortho')  # DCT-II with orthogonal normalization
#         x = dct.dct(x.transpose(-1, -2), norm='ortho').transpose(-1, -2)
#         return x







import torch
import torch.nn as nn
import math


class GivensRotation(nn.Module):
    """Learnable product of Givens (plane) rotations applied to the last axis.

    The module is built lazily: the rotated index pairs and the angle
    parameter are created on the first ``forward`` call from the width of the
    incoming tensor.  The ``dim`` constructor argument is only recorded in
    ``init_dim`` and does not size anything.

    NOTE(review): ``theta_param`` does not exist until the first forward
    pass, so an optimizer built from ``model.parameters()`` before that pass
    will not contain it -- confirm the training loop accounts for this.

    Args:
        dim: Stored as ``init_dim`` for reference only.
        num_rotations: How many ``(i, j)`` planes to rotate.  ``None`` (or
            ``0``) falls back to the width of the input.
    """

    def __init__(self, dim=None, num_rotations=None):
        super().__init__()
        self.init_dim = dim
        self.num_rotations = num_rotations
        self.initialized = False

    def _initialize(self, dim, device, dtype):
        """Sample the rotation planes and create one angle per plane.

        Args:
            dim: Width of the axis being rotated.
            device: Device for the index buffers and the angle parameter.
            dtype: Dtype of the angle parameter.

        Raises:
            ValueError: If no rotation plane could be selected.
        """
        if dim < 2:
            # A width below 2 has no plane to rotate in; register empty state
            # so that forward() becomes the identity.
            self.register_buffer('i_indices', torch.empty(0, dtype=torch.long, device=device))
            self.register_buffer('j_indices', torch.empty(0, dtype=torch.long, device=device))
            self.theta_param = nn.Parameter(torch.empty(0, dtype=dtype, device=device))
            self.initialized = True
            return

        self.dim = dim
        pairs = [(i, j) for i in range(dim) for j in range(i + 1, dim)]
        num_rotations = self.num_rotations or dim
        # Uses Python's global ``random`` state; seed it for reproducible planes.
        selected = random.sample(pairs, min(len(pairs), num_rotations))

        if not selected:
            raise ValueError("No valid Givens rotation pairs selected.")

        i_indices, j_indices = zip(*selected)
        self.register_buffer('i_indices', torch.tensor(i_indices, dtype=torch.long, device=device))
        self.register_buffer('j_indices', torch.tensor(j_indices, dtype=torch.long, device=device))
        self.theta_param = nn.Parameter(torch.randn(len(i_indices), dtype=dtype, device=device))
        self.initialized = True

    def forward(self, x):  # x: (B, N, D)
        """Rotate the last axis of ``x``.

        Args:
            x: Tensor of shape ``(B, N, D)``.

        Returns:
            Tensor of shape ``(B, N, D)``.  The rotations are applied one after
            another, so the overall map is orthogonal and preserves the L2
            norm along the last axis.
        """
        B, N, D = x.shape
        flat = x.reshape(-1, D)

        if not self.initialized:
            self._initialize(D, flat.device, flat.dtype)

        if len(self.theta_param) == 0:
            return flat.reshape(B, N, D)

        # tanh keeps every angle inside (-pi, pi).
        theta = math.pi * torch.tanh(self.theta_param)
        cos_t, sin_t = torch.cos(theta), torch.sin(theta)

        # Work on a Python list of columns: each rotation reads the *current*
        # columns (so planes sharing an index compose correctly instead of
        # overwriting each other) and no tensor is modified in place, which
        # keeps autograd happy.
        columns = list(flat.unbind(dim=1))
        # One host transfer for all indices instead of one per rotation.
        planes = zip(self.i_indices.tolist(), self.j_indices.tolist())
        for idx, (i, j) in enumerate(planes):
            c, s = cos_t[idx], sin_t[idx]
            col_i, col_j = columns[i], columns[j]
            columns[i] = c * col_i - s * col_j
            columns[j] = s * col_i + c * col_j

        return torch.stack(columns, dim=1).reshape(B, N, D)


class MultiBlockGivensRotation(nn.Module):
    """Rotate equal-width slices of the last axis with separate ``GivensRotation``s.

    The per-slice modules are created lazily on the first ``forward`` call,
    using the width of the tensor that arrives; ``dim`` is only stored in
    ``init_dim``.

    Args:
        dim: Stored as ``init_dim`` for reference only.
        num_blocks: Number of slices the last axis is divided into.
        num_rotations_per_block: Rotations per slice; when falsy,
            ``block_dim // 8`` is passed to each ``GivensRotation`` instead.
    """

    def __init__(self, dim=None, num_blocks=None, num_rotations_per_block=None):
        super().__init__()
        self.init_dim = dim
        self.num_blocks = num_blocks
        self.num_rotations_per_block = num_rotations_per_block
        self.initialized = False

    def _initialize(self, dim, device, dtype):
        """Create one ``GivensRotation`` per slice of width ``dim // num_blocks``.

        ``device`` and ``dtype`` are accepted but not used here.
        """
        assert dim % self.num_blocks == 0, f"dim {dim} must be divisible by num_blocks {self.num_blocks}"
        self.block_dim = dim // self.num_blocks

        rotations = nn.ModuleList()
        for _ in range(self.num_blocks):
            per_block = self.num_rotations_per_block or self.block_dim // 8
            rotations.append(GivensRotation(dim=self.block_dim, num_rotations=per_block))
        self.blocks = rotations
        self.initialized = True

    def forward(self, x):  # (B, N, D)
        """Split ``x`` along dim 2, rotate every slice, and concatenate the results."""
        _, _, width = x.shape
        if not self.initialized:
            self._initialize(width, x.device, x.dtype)

        rotated = []
        for rotation, piece in zip(self.blocks, torch.split(x, self.block_dim, dim=2)):
            rotated.append(rotation(piece))
        return torch.cat(rotated, dim=2)
    
   
class GeneralOrthogonalBlock(nn.Module):
    """Givens-rotation mixing over the embedding axis, then over the token axis.

    Args:
        token_dim: Passed to the token-axis ``GivensRotation`` (which sizes
            itself from the first input it sees).
        embed_dim: Embedding width used to decide how the embedding axis is
            blocked.
        embed_block_size: Width of each embedding slice.  With the defaults
            (768 / 4) this yields 192 slices of width 4.
        embed_num_rotations_per_block: Rotations applied inside each slice.
    """

    def __init__(self, token_dim=None, embed_dim=768,
                 embed_block_size=4, embed_num_rotations_per_block=2):
        super().__init__()

        # Block the embedding axis only when it divides evenly; otherwise fall
        # back to a single un-blocked rotation module.
        can_block = embed_dim is not None and embed_dim % embed_block_size == 0
        if can_block:
            self.ortho_embed = MultiBlockGivensRotation(
                dim=embed_dim,
                num_blocks=embed_dim // embed_block_size,
                num_rotations_per_block=embed_num_rotations_per_block,
            )
        else:
            self.ortho_embed = GivensRotation(embed_dim)

        # The token axis is never blocked.
        self.ortho_token = GivensRotation(token_dim)

    def forward(self, x):  # x: (B, N, D)
        """Return ``x`` after the embedding-axis and token-axis transforms."""
        mixed = self.ortho_embed(x)                 # (B, N, D)
        tokens_last = mixed.transpose(1, 2)         # (B, D, N)
        tokens_last = self.ortho_token(tokens_last)
        return tokens_last.transpose(1, 2)          # (B, N, D)

def orthogonality_loss_from_model(model: "GeneralOrthogonalBlock"):
    """Sum the orthogonality penalties exposed by a block's sub-modules.

    The embed transform always contributes; the token transform contributes
    only when it reports itself as initialized.  A sub-module that does not
    define ``orthogonality_loss()`` contributes nothing instead of raising
    ``AttributeError`` (none of the rotation modules in this file define it).

    Args:
        model: Object with an ``ortho_embed`` attribute and, optionally, an
            ``ortho_token`` attribute.

    Returns:
        The accumulated loss; ``0.0`` when no sub-module provides a penalty.
    """
    loss = 0.

    embed_penalty = getattr(model.ortho_embed, "orthogonality_loss", None)
    if callable(embed_penalty):
        loss += embed_penalty()

    token = getattr(model, "ortho_token", None)
    # Modules without an ``initialized`` flag have no lazy state and are
    # treated as ready.
    if token is not None and getattr(token, "initialized", True):
        token_penalty = getattr(token, "orthogonality_loss", None)
        if callable(token_penalty):
            loss += token_penalty()

    return loss

class StackedOrtho(nn.Module):
    """Chain of ``num_layers`` default-configured ``GeneralOrthogonalBlock``s."""

    def __init__(self, num_layers=3):
        super().__init__()
        stack = []
        for _ in range(num_layers):
            stack.append(GeneralOrthogonalBlock())
        self.layers = nn.Sequential(*stack)

    def forward(self, x):  # x: (B, N, D)
        """Feed ``x`` through every block in order and return the result."""
        out = self.layers(x)
        return out
    
class NonLinearOrthoBlock(nn.Module):
    """Two ``GeneralOrthogonalBlock``s with a GELU between them."""

    def __init__(self):
        super().__init__()
        self.ortho1 = GeneralOrthogonalBlock()
        self.act = nn.GELU()
        self.ortho2 = GeneralOrthogonalBlock()

    def forward(self, x):
        """Compute ``ortho2(GELU(ortho1(x)))``."""
        hidden = self.ortho1(x)
        hidden = self.act(hidden)
        return self.ortho2(hidden)
# class WaveletBlock(nn.Module):
#     def __init__(self, wavelet='db1', level=3):
#         super().__init__()
#         self.wavelet = wavelet
#         self.level = level

#     def forward(self, x):  # x: (B, N, D)
#         B, N, D = x.shape
#         x = x.to(torch.float32)  # Ensure input is float32 type
#         x_out = []

#         for b in range(B):
#             channels_out = []
#             for n in range(N):
#                 signal = x[b, n]
#                 coeffs = ptwt.wavedec(signal, wavelet=self.wavelet, level=self.level, mode='zero')
#                 approx = coeffs[0]
#                 # Pad with zeros if dimensions don't match
#                 if approx.numel() < D:
#                     pad = D - approx.numel()
#                     approx = torch.nn.functional.pad(approx, (0, pad))
#                 else:
#                     approx = approx[:D]
#                 channels_out.append(approx)
#             x_out.append(torch.stack(channels_out, dim=0))
#         return torch.stack(x_out, dim=0)  # Output shape is (B, N, D)
import torch
import torch.nn as nn

class ExponentialMapOrthogonal(nn.Module):
    """Approximately orthogonal map built from a truncated matrix exponential.

    A free ``dim x dim`` parameter ``A_param`` is turned into the
    skew-symmetric matrix ``A = A_param - A_param^T``; ``exp(A)`` is then
    approximated by its Taylor series up to ``taylor_order`` and applied to
    the last axis of the input.
    """

    def __init__(self, dim, taylor_order=5):
        """
        Args:
            dim (int): Width D of the axis being transformed
            taylor_order (int): Highest power kept in the Taylor series
        """
        super().__init__()
        self.dim = dim
        self.taylor_order = taylor_order

        # Full (dim, dim) parameter; it is anti-symmetrized in forward().
        self.A_param = nn.Parameter(torch.randn(dim, dim) * 0.01)

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (B, N, D) with D == self.dim
        Returns:
            Tensor of shape (B, N, D)
        """
        _, _, width = x.shape
        assert width == self.dim

        # Skew-symmetric generator: skew^T == -skew.
        skew = self.A_param - self.A_param.T

        # Accumulate I + A + A^2/2! + ... ; `term` always holds A^k / k!.
        identity = torch.eye(self.dim, device=x.device, dtype=x.dtype)
        term = identity.clone()
        approx = identity.clone()
        order = 1
        while order <= self.taylor_order:
            term = torch.matmul(term, skew) / order
            approx = approx + term
            order += 1

        # Right-multiply by the transposed approximation: (B, N, D) x (D, D).
        return torch.matmul(x, approx.T)


class ExponentialOrthogonalBlock(nn.Module):
    """Block-wise exponential-map transform on embeddings plus an optional token transform."""

    def __init__(self, embed_dim=768, token_dim=7, taylor_order=6, block_size=16):
        """
        Args:
            embed_dim (int): Width of each token embedding
            token_dim (int): Number of tokens; the token-axis transform is
                skipped when this is None or smaller than 2
            taylor_order (int): Taylor order handed to every exponential map
            block_size (int): Width of each embedding slice (the token axis
                is never sliced)
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.token_dim = token_dim
        self.block_size = block_size

        assert embed_dim % block_size == 0, f"embed_dim {embed_dim} must be divisible by block_size {block_size}"

        # One independent map per embedding slice.
        num_slices = embed_dim // block_size
        slice_maps = []
        for _ in range(num_slices):
            slice_maps.append(ExponentialMapOrthogonal(dim=block_size, taylor_order=taylor_order))
        self.blocks = nn.ModuleList(slice_maps)

        # Optional map across tokens.
        wants_token_map = token_dim is not None and token_dim >= 2
        self.ortho_token = None
        if wants_token_map:
            self.ortho_token = ExponentialMapOrthogonal(dim=token_dim, taylor_order=taylor_order)

    def forward(self, x):  # x: (B, N, D)
        """Transform each embedding slice, then (if configured) the token axis."""
        _, _, width = x.shape
        assert width == self.embed_dim

        # (1) Slice the embedding axis and transform every slice on its own.
        pieces = torch.chunk(x, len(self.blocks), dim=2)  # each (B, N, block_size)
        transformed = []
        for slice_map, piece in zip(self.blocks, pieces):
            transformed.append(slice_map(piece))
        x = torch.cat(transformed, dim=2)  # (B, N, D)

        # (2) Token-axis transform, applied with tokens moved to the last axis.
        if self.ortho_token is None:
            return x
        tokens_last = x.transpose(1, 2)              # (B, D, N)
        tokens_last = self.ortho_token(tokens_last)  # (B, D, N)
        return tokens_last.transpose(1, 2)           # (B, N, D)


import torch
import torch.nn as nn
import math

# ----------- Exponential Map Orthogonal Transform (Taylor Approximation) ------------
# class ExponentialMapOrthogonal(nn.Module):
#     def __init__(self, dim, taylor_order=5):
#         super().__init__()
#         self.dim = dim
#         self.taylor_order = taylor_order
#         self.skew_param = nn.Parameter(torch.zeros(dim, dim))
#         nn.init.kaiming_uniform_(self.skew_param, a=math.sqrt(5))  # for stability

#     def forward(self, x):  # x: (B, N, D)
#         B, N, D = x.shape
#         assert D == self.dim

#         # Make matrix skew-symmetric: A = (A - A^T)/2
#         skew = self.skew_param - self.skew_param.T
#         skew = 0.5 * (skew - skew.T)

#         # Taylor approximation of exp(skew)
#         exp_skew = torch.eye(self.dim, device=x.device, dtype=x.dtype)
#         mat_power = torch.eye(self.dim, device=x.device, dtype=x.dtype)
#         factorial = 1.0

#         for k in range(1, self.taylor_order + 1):
#             mat_power = mat_power @ skew
#             factorial *= k
#             exp_skew = exp_skew + mat_power / factorial

#         # Linear transform: (B*N, D) x (D, D)
#         x_flat = x.reshape(-1, D)
#         out = x_flat @ exp_skew.T
#         return out.view(B, N, D)

# # ----------- Hadamard Block Mixing (cross-block entanglement) -----------------------
# class HadamardMixing(nn.Module):
#     def __init__(self, block_num, block_size):
#         super().__init__()
#         self.block_num = block_num
#         self.block_size = block_size
#         self.mix_vectors = nn.Parameter(torch.randn(block_num, block_size))  # learnable Hadamard vector

#     def forward(self, x):  # x: (B, N, D)
#         B, N, D = x.shape
#         x = x.view(B, N, self.block_num, self.block_size)  # (B, N, Bn, Bs)
#         x = x * self.mix_vectors.unsqueeze(0).unsqueeze(0)  # (1, 1, Bn, Bs)
#         return x.view(B, N, D)

# # ----------- Main Block: Exponential Orthogonal + Hadamard Mixing -------------------
# class ExponentialOrthogonalBlock(nn.Module):
#     def __init__(self, embed_dim=768, token_dim=7, taylor_order=5, block_size=8):
#         super().__init__()
#         self.embed_dim = embed_dim
#         self.token_dim = token_dim
#         self.block_size = block_size
#         assert embed_dim % block_size == 0, "embed_dim must be divisible by block_size"
#         self.block_num = embed_dim // block_size

#         # Create block-wise exponential orthogonal transforms
#         self.blocks = nn.ModuleList([
#             ExponentialMapOrthogonal(dim=block_size, taylor_order=taylor_order)
#             for _ in range(self.block_num)
#         ])

#         # Hadamard-based cross-block mixing
#         self.block_mixing = HadamardMixing(self.block_num, block_size)

#         # Optional token-dim orthogonal rotation
#         if token_dim is not None and token_dim >= 2:
#             self.ortho_token = ExponentialMapOrthogonal(dim=token_dim, taylor_order=taylor_order)
#         else:
#             self.ortho_token = None

#     def forward(self, x):  # x: (B, N, D)
#         B, N, D = x.shape
#         x_blocks = torch.chunk(x, self.block_num, dim=2)  # list of (B, N, block_size)
#         x_rotated = [blk(xb) for blk, xb in zip(self.blocks, x_blocks)]
#         x = torch.cat(x_rotated, dim=2)  # (B, N, D)

#         # Hadamard entanglement across blocks
#         x = self.block_mixing(x)

#         # Optional token-axis orthogonal rotation
#         if self.ortho_token is not None:
#             x = x.transpose(1, 2)  # (B, D, N)
#             x = self.ortho_token(x)
#             x = x.transpose(1, 2)

#         return x

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHouseholder(nn.Module):
    """Sequence of learnable Householder reflections over the last axis.

    Args:
        dim: Width of the axis being reflected.
        num_reflections: Number of reflection vectors to learn.
    """

    def __init__(self, dim, num_reflections=4):
        super().__init__()
        self.dim = dim
        self.num_reflections = num_reflections

        # One length-`dim` normal vector per reflection.
        normals = [nn.Parameter(torch.randn(dim)) for _ in range(num_reflections)]
        self.v_list = nn.ParameterList(normals)

    def forward(self, x):
        """
        Reflect ``x`` of shape (..., dim) across each learned hyperplane in turn.
        """
        reflected = x
        for raw_normal in self.v_list:
            unit = F.normalize(raw_normal, dim=0)  # rescale to unit length
            coeff = torch.einsum('...d,d->...', reflected, unit)  # component along the normal
            reflected = reflected - 2 * coeff.unsqueeze(-1) * unit  # x - 2 (x.v) v
        return reflected

class HouseholderFNetBlock(nn.Module):
    """Householder reflections over the embedding axis, then over the token axis.

    Args:
        embed_dim: Width of the embedding axis.
        token_dim: Length of the token axis the second transform is sized for.
        num_reflections: Reflections used by each of the two transforms.
    """

    def __init__(self, embed_dim=768, token_dim=7, num_reflections=4):
        super().__init__()
        self.embed_transform = MultiHouseholder(dim=embed_dim, num_reflections=num_reflections)
        self.token_transform = MultiHouseholder(dim=token_dim, num_reflections=num_reflections)

    def forward(self, x):
        """
        x: (B, N, D) = (batch, token, embed); the output has the same layout.
        """
        embed_mixed = self.embed_transform(x)             # reflect along D
        tokens_last = embed_mixed.transpose(1, 2)         # (B, D, N)
        token_mixed = self.token_transform(tokens_last)   # reflect along N
        return token_mixed.transpose(1, 2)                # (B, N, D)


import torch
import torch.nn as nn
import math

class CayleyOrthogonalTransform(nn.Module):
    """Cayley-transform map generated by a low-rank skew-symmetric matrix.

    Args:
        dim: Width of the axis being transformed.
        rank: Number of columns in each of the two factors ``U`` and ``V``.
        eps: Diagonal regularizer added to the linear system that is solved.
    """

    def __init__(self, dim, rank=2, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.rank = rank
        self.eps = eps

        # Factors of the generator U V^T - V U^T; filled in by reset_parameters().
        self.U = nn.Parameter(torch.empty(dim, rank))
        self.V = nn.Parameter(torch.empty(dim, rank))

        self.reset_parameters()

    def reset_parameters(self):
        """Orthogonally initialize both factors and scale them by sqrt(1 / rank)."""
        scale = math.sqrt(1 / self.rank)
        nn.init.orthogonal_(self.U)
        nn.init.orthogonal_(self.V)
        with torch.no_grad():
            self.U.mul_(scale)
            self.V.mul_(scale)

    def skew_symmetric(self):
        """Return the skew-symmetric generator ``U V^T - V U^T``."""
        outer = self.U @ self.V.t()
        return outer - self.V @ self.U.t()

    def cayley_transform(self, A):
        """Solve ``(I - S + eps*I) X = (I + S)`` for X, with S the skew part of ``A``.

        Args:
            A: Tensor of shape (..., dim, dim).
        """
        identity = torch.eye(self.dim, dtype=A.dtype, device=A.device)
        skew = 0.5 * (A - A.transpose(-1, -2))

        rhs = identity + skew
        # The small diagonal shift is kept for numerical stability.
        lhs = (identity - skew) + getattr(self, "eps", 1e-6) * identity

        # Prefer the current API; fall back to the legacy one on old PyTorch.
        if hasattr(torch.linalg, "solve"):
            return torch.linalg.solve(lhs, rhs)
        solution, _ = torch.solve(rhs, lhs)
        return solution

    def forward(self, x):
        """Right-multiply ``x`` by the Cayley transform of the learned generator."""
        generator = self.skew_symmetric()
        transform = self.cayley_transform(generator)
        return x @ transform

class CayleyFNetBlock(nn.Module):
    """Cayley transforms over the sequence axis and then the feature axis.

    Args:
        d_model: Width of the feature axis.
        seq_len: Length of the sequence axis the first transform is sized for.
        rank: Rank argument forwarded to both transforms.
    """

    def __init__(self, d_model=768, seq_len=7, rank=2):
        super().__init__()
        # Creation order is kept: sequence transform first, then feature transform.
        self.seq_transform = CayleyOrthogonalTransform(seq_len, rank)
        self.feat_transform = CayleyOrthogonalTransform(d_model, rank)

    def forward(self, x):
        """Transform ``x`` of shape [batch, seq_len, d_model]; the layout is unchanged."""
        # Move the sequence axis last, transform it, and move it back.
        seq_last = x.transpose(-1, -2)
        seq_mixed = self.seq_transform(seq_last).transpose(-1, -2)

        # Features are already on the last axis.
        return self.feat_transform(seq_mixed)
    
class FNetBlock(nn.Module):
    """Parameter-free mixing: FFT over the last axis, FFT over the second-to-last, keep the real part."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the real part of the two successive FFTs of ``x``."""
        last_axis_spectrum = torch.fft.fft(x, dim=-1)
        full_spectrum = torch.fft.fft(last_axis_spectrum, dim=-2)
        return full_spectrum.real

import torch
import torch.nn as nn
import torch.fft

class SparseFFTBlock(nn.Module):
    """Soft-sparsify a 2-D spectrum and transform back.

    The input is taken to the frequency domain with a 2-D FFT, a per-sample
    magnitude threshold is derived from a position-weighted top-k, components
    are attenuated with a steep sigmoid mask around that threshold, and the
    result is inverse-transformed.

    Args:
        keep_ratio: Raw (pre-sigmoid) value controlling the kept fraction.
            The fraction actually used is ``sigmoid(keep_ratio)``, so the
            default ``0.3`` keeps roughly 57% of the components.
        grad_scale: Slope of the sigmoid used for the soft mask.
        learnable_threshold: Sets ``requires_grad`` on ``keep_ratio``.
            NOTE(review): ``keep_ratio`` only enters through an ``int(...)``
            count, so no gradient appears to reach it -- confirm intent.
    """

    def __init__(self,
                 keep_ratio=0.3,
                 grad_scale=1e3,
                 learnable_threshold=True):
        super().__init__()
        self.keep_ratio = nn.Parameter(torch.tensor(keep_ratio),
                                      requires_grad=learnable_threshold)
        self.grad_scale = grad_scale

        # Maps (row, column) coordinates in [-1, 1] to a per-position weight.
        self.freq_encoder = nn.Sequential(
            nn.Conv2d(2, 8, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(8, 1, kernel_size=3, padding=1)
        )

    def dynamic_threshold(self, x_fft):
        """Return the per-sample magnitude threshold, shape ``(B, 1, 1)``.

        Args:
            x_fft: Complex spectrum of shape ``(B, H, W)``.

        The threshold is the k-th largest position-weighted magnitude, where
        ``k = int(H * W * sigmoid(keep_ratio))`` clamped to ``[1, H * W]`` so
        that very small inputs do not produce an empty top-k.
        """
        magnitude = torch.abs(x_fft)
        B, H, W = magnitude.shape

        # Coordinate grids built by hand (no meshgrid) for compatibility
        # with older PyTorch versions.
        h_coord = torch.linspace(-1, 1, H, device=x_fft.device)
        w_coord = torch.linspace(-1, 1, W, device=x_fft.device)
        h_grid = h_coord.unsqueeze(1).repeat(1, W)
        w_grid = w_coord.unsqueeze(0).repeat(H, 1)

        # The coordinate map is identical for every sample, so encode it once
        # and let broadcasting apply it across the batch.
        coord_map = torch.stack([h_grid, w_grid], dim=0).unsqueeze(0)  # [1, 2, H, W]
        energy_weight = torch.sigmoid(self.freq_encoder(coord_map))   # [1, 1, H, W]
        weighted_magnitude = magnitude * energy_weight.squeeze(1)      # [B, H, W]

        k = int(H * W * torch.sigmoid(self.keep_ratio))
        # int() truncates toward zero; keep at least one and at most all entries.
        k = max(1, min(H * W, k))
        values, _ = torch.topk(weighted_magnitude.reshape(B, -1), k, dim=-1)
        threshold = values[:, -1].view(B, 1, 1)

        return threshold

    def sparse_masking(self, x_fft, threshold):
        """Attenuate spectrum components whose magnitude falls below ``threshold``.

        The mask is ``sigmoid(grad_scale * (|x_fft| - threshold))``; phase is
        left untouched.
        """
        magnitude = torch.abs(x_fft)
        phase = torch.angle(x_fft)

        # Soft (differentiable) mask instead of a hard cut-off.
        mask = torch.sigmoid(self.grad_scale * (magnitude - threshold))
        sparse_fft = torch.polar(magnitude * mask, phase)

        return sparse_fft

    def forward(self, x):
        """Sparsify ``x`` of shape ``[B, H, W]`` in the frequency domain.

        Returns:
            Real tensor with the same shape as ``x``.
        """
        # 2-D FFT over the last two axes.
        x_fft = torch.fft.fft2(x, dim=(-2, -1))

        threshold = self.dynamic_threshold(x_fft)
        sparse_fft = self.sparse_masking(x_fft, threshold)

        # Back to the original domain; the imaginary part is discarded.
        x_sparse = torch.fft.ifft2(sparse_fft).real

        return x_sparse


import torch
import torch.nn as nn
import torch.nn.functional as F

class FNetPlusBlock(nn.Module):
    """Pre-norm residual block that round-trips through the frequency domain.

    The normalized input is transformed with FFTs over the last two axes,
    random frequency components are zeroed while training, and the inverse
    transform is added back onto the input.

    Args:
        dim: Feature width used for the two ``LayerNorm`` modules.
        ffn_hidden_dim: Accepted but currently unused.
        dropout: Accepted but currently unused.
        freq_dropout: Probability of zeroing a frequency component in training.
    """

    def __init__(self, dim=768, ffn_hidden_dim=None, dropout=0.1, freq_dropout=0.2):
        super().__init__()
        self.dim = dim
        self.freq_dropout = freq_dropout

        self.norm1 = nn.LayerNorm(dim)
        # norm2 is created but not used by forward().
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        """Return ``x`` plus the (optionally frequency-dropped) normalized signal; x: (B, N, D)."""
        normed = self.norm1(x)

        # Forward transform: last axis first, then second-to-last.
        spectrum = torch.fft.fft(normed, dim=-1)
        spectrum = torch.fft.fft(spectrum, dim=-2)

        use_freq_dropout = self.training and self.freq_dropout > 0
        if use_freq_dropout:
            # One real-valued mask scales real and imaginary parts together.
            keep = (torch.rand_like(spectrum.real) > self.freq_dropout).float()
            spectrum = spectrum * keep

        # Inverse transform in the reverse axis order.
        restored = torch.fft.ifft(spectrum, dim=-2)
        restored = torch.fft.ifft(restored, dim=-1).real

        # Residual connection.
        return x + restored
class FNetBlock_hidden(nn.Module):
    """Parameter-free mixing: real part of the FFT over the last (hidden) axis."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``Re(FFT(x))`` taken along ``dim=-1``."""
        spectrum = torch.fft.fft(x, dim=-1)
        return spectrum.real

class FNetBlock_sequence(nn.Module):
    """Parameter-free mixing: real part of the FFT over the second-to-last (sequence) axis."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``Re(FFT(x))`` taken along ``dim=-2``."""
        spectrum = torch.fft.fft(x, dim=-2)
        return spectrum.real

class VPTBlock(nn.Module):
    """Identity transform: the input is handed back unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` itself, without copying or modifying it."""
        passthrough = x
        return passthrough
class PromptedTransformer_Fourier(Transformer):
    """Transformer that inserts learnable prompt tokens after the first token.

    Prompt tokens are placed between the first token of the embedded sequence
    and the remaining tokens. Depending on ``prompt_config``, the first
    ``givens_num_tokens`` prompts (or a random subset, or an additional set of
    prompts) are passed through ``self.FT`` before being inserted.

    Config fields read from ``prompt_config``: LOCATION, INITIATION,
    NUM_DEEP_LAYERS, DEEP_SHARED, NUM_TOKENS, FOURIER_TYPE, FOURIER_DIMENSION,
    FOURIER_PERCENTAGE, FOURIER_ADDITION, FOURIER_ADDITION_NUM, DROPOUT,
    PROJECT, DEEP, FOURIER_FIRST_LAYER, FOURIER_HALF, FOURIER_LOCATION, MIXED.
    """

    def __init__(self, prompt_config, config, img_size, vis):
        """Build the backbone, the transform block ``FT`` and the prompts.

        Args:
            prompt_config: Prompt settings (see the class docstring).
            config: ViT configuration; ``hidden_size``, ``patches["size"]``
                and ``transformer["num_layers"]`` are read here.
            img_size: Forwarded to the parent constructor.
            vis: Forwarded to the parent constructor.

        Raises:
            ValueError: If ``prompt_config.INITIATION`` is not ``"random"``
                (only reachable when assertions are disabled).
        """
        assert prompt_config.LOCATION == "prepend"
        assert prompt_config.INITIATION == "random"
        assert prompt_config.NUM_DEEP_LAYERS is None
        assert not prompt_config.DEEP_SHARED
        super(PromptedTransformer_Fourier, self).__init__(
            config, img_size, vis)

        self.prompt_config = prompt_config
        self.vit_config = config

        patch_size = _pair(config.patches["size"])

        num_tokens = self.prompt_config.NUM_TOKENS
        self.num_tokens = num_tokens  # number of prompted tokens

        # Select the transform applied to the "Fourier" share of the prompts.
        # An if/elif chain (rather than a lookup table) keeps the block class
        # names lazily resolved: only the selected one has to exist.
        if self.prompt_config.FOURIER_TYPE == "fixed_linear":
            # The layer has to exist before its weight can be initialised;
            # previously this branch failed with AttributeError.
            # NOTE(review): a hidden->hidden Linear is the only shape that fits
            # where FT is applied (after prompt_proj) — confirm this is the
            # intended layer and whether its weights should be frozen.
            self.FT = nn.Linear(config.hidden_size, config.hidden_size)
            nn.init.kaiming_normal_(self.FT.weight, a=0, mode='fan_out')
        elif self.prompt_config.FOURIER_DIMENSION == "vpt":
            self.FT = VPTBlock()
        elif self.prompt_config.FOURIER_DIMENSION == "sequence":
            self.FT = FNetBlock_sequence()
        elif self.prompt_config.FOURIER_DIMENSION == "hidden":
            self.FT = FNetBlock_hidden()
        elif self.prompt_config.FOURIER_DIMENSION == "all":
            self.FT = FNetBlock()
        elif self.prompt_config.FOURIER_DIMENSION == "quantum":
            self.FT = QNetBlock()
        elif self.prompt_config.FOURIER_DIMENSION == "Unitary":
            self.FT = LearnableUnitaryBlock()
        elif self.prompt_config.FOURIER_DIMENSION == "DCT":
            self.FT = DCTBlock()
        elif self.prompt_config.FOURIER_DIMENSION == "givens":
            self.FT = GeneralOrthogonalBlock()
        elif self.prompt_config.FOURIER_DIMENSION == "exp":
            self.FT = ExponentialOrthogonalBlock()
        elif self.prompt_config.FOURIER_DIMENSION == "house":
            self.FT = HouseholderFNetBlock()
        elif self.prompt_config.FOURIER_DIMENSION == "cayley":
            self.FT = CayleyFNetBlock()
        elif self.prompt_config.FOURIER_DIMENSION == "fnetplus":
            self.FT = FNetPlusBlock()
        elif self.prompt_config.FOURIER_DIMENSION == "sparsefft":
            self.FT = SparseFFTBlock()
        # Any other value leaves self.FT undefined; that is only a problem if
        # a code path that applies FT is taken later.

        # Number of prompt tokens (per layer) that are routed through FT.
        if self.prompt_config.FOURIER_PERCENTAGE != 0.0:
            self.givens_num_tokens = math.floor(
                self.prompt_config.NUM_TOKENS
                * self.prompt_config.FOURIER_PERCENTAGE)
        else:
            self.givens_num_tokens = 0

        # Number of extra FT prompts per layer when FOURIER_ADDITION is set.
        self.givens_addition_num_tokens = self.prompt_config.FOURIER_ADDITION_NUM

        self.prompt_dropout = Dropout(self.prompt_config.DROPOUT)

        # if project the prompt embeddings
        if self.prompt_config.PROJECT > -1:
            # only for prepend / add
            prompt_dim = self.prompt_config.PROJECT
            self.prompt_proj = nn.Linear(
                prompt_dim, config.hidden_size)
            nn.init.kaiming_normal_(
                self.prompt_proj.weight, a=0, mode='fan_out')
        else:
            prompt_dim = config.hidden_size
            self.prompt_proj = nn.Identity()

        # initiate prompt:
        if self.prompt_config.INITIATION != "random":
            raise ValueError("Other initiation scheme is not supported")

        # Bound of the uniform initialisation shared by all prompt tensors.
        val = math.sqrt(6. / float(3 * reduce(mul, patch_size, 1) + prompt_dim))  # noqa

        self.prompt_embeddings = nn.Parameter(torch.zeros(
            1, num_tokens, prompt_dim))
        # xavier_uniform initialization
        nn.init.uniform_(self.prompt_embeddings.data, -val, val)

        if self.prompt_config.DEEP:  # noqa
            # One prompt set for every layer after the first.
            total_d_layer = config.transformer["num_layers"] - 1

            self.deep_prompt_embeddings = nn.Parameter(torch.zeros(
                total_d_layer, num_tokens, prompt_dim))
            # xavier_uniform initialization
            nn.init.uniform_(self.deep_prompt_embeddings.data, -val, val)

            if self.prompt_config.FOURIER_ADDITION:
                self.deep_prompt_givens_embeddings = nn.Parameter(torch.zeros(
                    total_d_layer, self.givens_addition_num_tokens, prompt_dim))
                # xavier_uniform initialization
                nn.init.uniform_(self.deep_prompt_givens_embeddings.data, -val, val)

    def _project_prompt(self, prompt, batch_size):
        """Project ``prompt``, broadcast it over the batch, then apply dropout."""
        return self.prompt_dropout(
            self.prompt_proj(prompt).expand(batch_size, -1, -1))

    def _randomly_transformed_prompts(self, layer_prompts, batch_size):
        """Return all prompts of one layer with a random subset passed through FT.

        ``layer_prompts`` is indexed as (token, feature). ``givens_num_tokens``
        distinct token positions are drawn with ``random.sample``; each selected
        token is projected and transformed on its own and written over the
        corresponding slot of the plain projected prompts.
        """
        selected_indices = random.sample(
            range(self.num_tokens), self.givens_num_tokens)
        # clone() gives a writable tensor; expand() alone returns a view.
        mixed = self._project_prompt(layer_prompts, batch_size).clone()
        for index in selected_indices:
            mixed[:, index:index + 1, :] = self.FT(
                self._project_prompt(layer_prompts[index:index + 1], batch_size))
        return mixed

    def incorporate_prompt(self, x):
        """Embed ``x`` and insert the first-layer prompts after the first token.

        Returns a tensor laid out as
        (batch_size, cls_token + n_prompt + n_patches, hidden_dim).
        """
        # combine prompt embeddings with image-patch embeddings
        B = x.shape[0]
        # after CLS token, all before image patches
        x = self.embeddings(x)  # (batch_size, 1 + n_patches, hidden_dim)
        cls_token, patches = x[:, :1, :], x[:, 1:, :]
        cfg = self.prompt_config

        apply_ft = (cfg.FOURIER_FIRST_LAYER
                    and cfg.FOURIER_PERCENTAGE != 0.0
                    and cfg.FOURIER_HALF != "later")
        if not apply_ft:
            prompts = self._project_prompt(self.prompt_embeddings, B)
            return torch.cat((cls_token, prompts, patches), dim=1)

        first_layer_prompts = self.prompt_embeddings[0]
        split = self.givens_num_tokens
        # Computed before branching (even for the 'random' layout, which does
        # not use it) so the sequence of dropout calls is unchanged.
        ft_input = self._project_prompt(first_layer_prompts[:split], B)

        if split == self.num_tokens:
            # Every prompt goes through FT.
            prompts = (self.FT(ft_input),)
        elif cfg.FOURIER_LOCATION == 'prepend':
            # Transformed prompts first, plain prompts after them.
            transformed = self.FT(ft_input)
            plain = self._project_prompt(first_layer_prompts[split:], B)
            prompts = (transformed, plain)
        elif cfg.FOURIER_LOCATION == 'random':
            prompts = (self._randomly_transformed_prompts(first_layer_prompts, B),)
        else:
            # Plain prompts first, transformed prompts after them.
            plain = self._project_prompt(first_layer_prompts[split:], B)
            transformed = self.FT(ft_input)
            prompts = (plain, transformed)

        # (batch_size, cls_token + n_prompt + n_patches, hidden_dim)
        return torch.cat((cls_token, *prompts, patches), dim=1)

    def train(self, mode=True):
        """Switch train/eval mode, keeping the backbone in eval while training.

        In training mode only the prompt-related modules (projection, dropout
        and the FT block) are put in train mode; ``encoder`` and ``embeddings``
        stay in eval mode. In eval mode every child module is switched to eval.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        # set train status for this class: disable all but the prompt-related modules
        if mode:
            # training:
            self.encoder.eval()
            self.embeddings.eval()
            self.prompt_proj.train()
            self.prompt_dropout.train()
            # FT is put into eval mode by the branch below, so it must be
            # switched back here; otherwise train-only behaviour inside FT
            # (e.g. FNetPlusBlock's frequency dropout) stays off after the
            # first eval() call.
            ft_block = getattr(self, "FT", None)
            if ft_block is not None:
                ft_block.train()
        else:
            # eval:
            for module in self.children():
                module.train(mode)
        return self

    def _deep_prompts_for_layer(self, i, batch_size):
        """Return the tuple of prompt tensors to insert before encoder layer ``i``.

        The tensors are returned in the order in which they are concatenated
        into the hidden states.
        """
        cfg = self.prompt_config
        layer_prompts = self.deep_prompt_embeddings[i - 1]

        # Layers that use plain prompts only, without FT.
        # NOTE(review): the 6 / 5 split points look like the midpoint of a
        # 12-layer encoder — confirm before using other depths.
        plain_only = (
            cfg.FOURIER_PERCENTAGE == 0.0
            or (cfg.FOURIER_HALF == "former" and i >= 6)
            or (cfg.FOURIER_HALF == "later" and i <= 5)
            or (cfg.MIXED == True and i % 2 != 0)  # noqa: E712
        )
        if plain_only:
            return (self._project_prompt(layer_prompts, batch_size),)

        if cfg.FOURIER_ADDITION:
            # A separate set of prompts goes through FT and is inserted in
            # front of the plain prompts.
            # NOTE(review): this inserts givens_addition_num_tokens extra
            # tokens, while the caller's slice only removes num_tokens prompt
            # tokens, so the sequence appears to grow at every layer —
            # confirm this is intended.
            plain = self._project_prompt(layer_prompts, batch_size)
            extra = self._project_prompt(
                self.deep_prompt_givens_embeddings[i - 1], batch_size)
            return (self.FT(extra), plain)

        split = self.givens_num_tokens
        has_plain_part = split != self.num_tokens
        # The plain share is projected before the FT share; the order of the
        # dropout calls is kept as it was.
        if has_plain_part:
            plain = self._project_prompt(layer_prompts[split:], batch_size)
        ft_input = self._project_prompt(layer_prompts[:split], batch_size)

        if not has_plain_part:
            return (self.FT(ft_input),)
        if cfg.FOURIER_LOCATION == 'prepend':
            return (self.FT(ft_input), plain)
        if cfg.FOURIER_LOCATION == 'random':
            return (self._randomly_transformed_prompts(layer_prompts, batch_size),)
        return (plain, self.FT(ft_input))

    def forward_deep_prompt(self, embedding_output):
        """Run the encoder layer by layer, replacing prompts before each layer.

        For every layer after the first that has deep prompts, the
        ``num_tokens`` positions following the first token are replaced by that
        layer's prompts.

        Returns:
            ``(encoded, attn_weights)``: the normalised final hidden states
            and the per-layer attention weights (collected only when
            ``self.encoder.vis`` is truthy).
        """
        attn_weights = []
        hidden_states = None
        weights = None
        B = embedding_output.shape[0]
        num_layers = self.vit_config.transformer["num_layers"]

        for i in range(num_layers):
            if i == 0:
                hidden_states, weights = self.encoder.layer[i](embedding_output)
            else:
                if i <= self.deep_prompt_embeddings.shape[0]:
                    prompts = self._deep_prompts_for_layer(i, B)
                    hidden_states = torch.cat((
                        hidden_states[:, :1, :],
                        *prompts,
                        hidden_states[:, (1 + self.num_tokens):, :]
                    ), dim=1)

                hidden_states, weights = self.encoder.layer[i](hidden_states)

            if self.encoder.vis:
                attn_weights.append(weights)

        encoded = self.encoder.encoder_norm(hidden_states)
        return encoded, attn_weights

    def forward(self, x):
        """Insert prompts into ``x`` and encode it; returns ``(encoded, attn_weights)``."""
        # this is the default version:
        embedding_output = self.incorporate_prompt(x)

        if self.prompt_config.DEEP:
            encoded, attn_weights = self.forward_deep_prompt(
                embedding_output)
        else:
            encoded, attn_weights = self.encoder(embedding_output)

        return encoded, attn_weights


class PromptedVisionTransformer_Fourier(VisionTransformer):
    """VisionTransformer whose encoder is a ``PromptedTransformer_Fourier``."""

    def __init__(
        self, prompt_cfg, model_type,
        img_size=224, num_classes=21843, vis=False
    ):
        """Build the base model and replace its transformer with the prompted one.

        Args:
            prompt_cfg: Prompt configuration; must not be ``None``.
            model_type: Key into ``CONFIGS`` selecting the ViT configuration.
            img_size: Forwarded to the parent and to the prompted transformer.
            num_classes: Forwarded to the parent constructor.
            vis: Forwarded to the parent and to the prompted transformer.

        Raises:
            ValueError: If ``prompt_cfg`` is ``None``.
        """
        # Validate before the first attribute access; with the assert first,
        # a None config failed with AttributeError and this check was
        # unreachable.
        if prompt_cfg is None:
            raise ValueError("prompt_cfg cannot be None if using PromptedVisionTransformer")
        assert prompt_cfg.VIT_POOL_TYPE == "original"
        super(PromptedVisionTransformer_Fourier, self).__init__(
            model_type, img_size, num_classes, vis)
        self.prompt_cfg = prompt_cfg
        vit_cfg = CONFIGS[model_type]
        self.transformer = PromptedTransformer_Fourier(
            prompt_cfg, vit_cfg, img_size, vis)

    def forward(self, x, vis=False):
        """Classify ``x`` from the first output token.

        When ``prompt_cfg.VIS`` is set, the attention weights at index 11,
        averaged over axis 1, are written to
        ``./<prompt_cfg.VIS_JSON_FOURIER>.npz`` on every call.

        Returns:
            ``logits`` when ``vis`` is falsy, else ``(logits, attn_weights)``.
        """
        x, attn_weights = self.transformer(x)

        if self.prompt_cfg.VIS == True:  # noqa: E712
            # NOTE(review): index 11 is hard-coded (presumably the last layer
            # of a 12-layer encoder) and requires the encoder to have been
            # built with vis enabled so that attn_weights is populated.
            aw = attn_weights[11].cpu().detach().numpy().mean(axis=1)

            np.savez(f'./{self.prompt_cfg.VIS_JSON_FOURIER}.npz', aw)

        x = x[:, 0]

        logits = self.head(x)

        if not vis:
            return logits
        return logits, attn_weights
    
import numpy as np
import math
from scipy.linalg import dft
import matplotlib.pyplot as plt

def calc_spectrum(A, F):
    """Return ``F @ A @ F.T``, i.e. ``A`` transformed on both sides by ``F``."""
    return F @ A @ F.T


def plot_spectrum(a, len_tokens=None):
    """Return the centred per-row spectral magnitude of a square matrix.

    ``a`` is transformed on both sides with the unitary DFT matrix of size
    ``len_tokens``; the 2-norm of each row of the result is taken, and the
    resulting vector is reordered so the zero-frequency entry sits in the
    middle.

    Args:
        a: Square matrix with ``len_tokens`` rows and columns.
        len_tokens: Side length of ``a``. Defaults to the first dimension of
            ``a`` when omitted.

    Returns:
        1-D array with ``len_tokens`` entries.
    """
    if len_tokens is None:
        len_tokens = np.shape(a)[0]
    F = dft(len_tokens, scale='sqrtn')
    s = calc_spectrum(a, F)
    s = np.linalg.norm(s, ord=2, axis=1)
    # fftshift matches the previous manual reordering for even lengths and,
    # unlike it, keeps every bin for odd lengths (one entry used to be lost).
    return np.fft.fftshift(s)

def plot_6x12(attn_weights, path):
    """Plot attention spectra on a 12 x 6 subplot grid and save the figure.

    Row ``r`` of the grid uses ``attn_weights[r]`` and column ``c`` uses the
    slice ``[0, c, :, :]`` of it, so 12 entries with at least 6 slices along
    the second axis are required.

    Args:
        attn_weights: Indexable collection of tensors; each is indexed as
            ``[0, c, :, :]`` to obtain one attention matrix.
        path: Destination passed to ``plt.savefig``.
    """
    for r in range(1, 13):
        for c in range(1, 7):
            i = (r-1)*6 + c
            plt.subplot(12, 6, i)
            attn = attn_weights[r-1][0, c-1, :, :].cpu().detach().numpy()
            # plot_spectrum needs the matrix size; omitting it raised TypeError.
            s = plot_spectrum(attn, attn.shape[0])
            plt.plot(s)
            if r != 12:
                # Only the bottom row keeps its x tick labels.
                plt.xticks([])
            plt.tick_params(labelsize=3)

    plt.savefig(path, dpi=500)

import torch
import torchvision.transforms as T
from timm.models.vision_transformer import vit_small_patch16_224
import json
from PIL import Image, ImageDraw
import numpy as np
import matplotlib.pyplot as plt

def grid_show(to_shows, cols, pth):
    """Draw the first ``(image, title)`` pair of ``to_shows`` and save it.

    Only the first pair is rendered. ``cols`` is accepted but has no effect.
    The figure is written to ``pth`` with ``plt.savefig``.
    """
    first_image, first_title = next(iter(to_shows))

    plt.imshow(first_image, cmap='YlGnBu')
    plt.title(first_title)
    # Hide the tick marks on both axes.
    plt.yticks([])
    plt.xticks([])
    plt.colorbar()

    plt.savefig(pth)

def visualize_head(att_map):
    """Show ``att_map`` as a heatmap with a colorbar on the current axes."""
    axes = plt.gca()
    heatmap = axes.imshow(att_map)
    axes.figure.colorbar(heatmap, ax=axes)
    plt.show()
    
def visualize_heads(att_map, cols, pth):
    """Save a heatmap of ``att_map`` averaged over its first axis.

    ``att_map`` is squeezed first; the mean over axis 0 is then handed to
    ``grid_show`` under the title 'Head Average'. Individual heads are not
    plotted.
    """
    head_average = att_map.squeeze().mean(axis=0)
    grid_show([(head_average, 'Head Average')], cols=cols, pth=pth)

def gray2rgb(image):
    """Append a trailing axis to ``image`` and repeat it three times along axis 2."""
    with_channel_axis = image[..., np.newaxis]
    return np.repeat(with_channel_axis, repeats=3, axis=2)
    
def cls_padding(image, mask, cls_weight, grid_size):
    """Pad an image and its attention mask with a left-hand column for the CLS token.

    A strip one grid cell wide is added on the left of both the image and the
    mask. In the image, the strip is white with the text 'CLS' drawn in it. In
    the mask, the top cell of the strip holds the CLS weight and the rest holds
    the mask minimum. The mask and the CLS weight are divided by the same
    value, the larger of the mask maximum and the CLS weight.

    Args:
        image: Image convertible with ``np.array``; its first two dimensions
            are used as height and width.
        mask: Attention mask with the same height as ``image`` (array or
            anything ``np.asarray`` accepts).
        cls_weight: Scalar attention weight of the CLS token.
        grid_size: Int or ``(rows, cols)`` tuple giving the patch grid.

    Returns:
        ``(padded_image, padded_mask, meta_mask)``: the padded PIL image, the
        padded mask, and an array with four values per pixel that is 1 in the
        strip below the CLS cell and 0 elsewhere.
    """
    if not isinstance(grid_size, tuple):
        grid_size = (grid_size, grid_size)

    image = np.array(image)
    # Callers may pass a PIL image here; do the arithmetic on an array.
    mask = np.asarray(mask)

    H, W = image.shape[:2]
    delta_H = int(H / grid_size[0])
    delta_W = int(W / grid_size[1])

    # White strip, one grid cell wide and as tall as the image.
    image_padding = np.ones_like(image[:, :delta_W]) * 255

    padded_image = np.hstack((image_padding, image))
    padded_image = Image.fromarray(padded_image)
    draw = ImageDraw.Draw(padded_image)
    draw.text((int(delta_W/4), int(delta_H/4)), 'CLS', fill=(0, 0, 0))  # PIL.Image.size = (W,H) not (H,W)

    # Compute the denominator once so the mask and the CLS weight end up on
    # the same scale; recomputing it after normalising the mask would divide
    # the CLS weight by a different value.
    scale = max(np.max(mask), cls_weight)
    mask = mask / scale
    cls_weight = cls_weight / scale

    # Build the mask strip in the mask's own floating dtype. Reusing the
    # image-typed strip would truncate fractional values for integer images.
    mask_padding = np.full((H, delta_W), np.min(mask), dtype=mask.dtype)
    mask_padding[:delta_H, :delta_W] = cls_weight
    padded_mask = np.hstack((mask_padding, mask))

    # Marks the part of the strip below the CLS cell.
    meta_mask = np.zeros((padded_mask.shape[0], padded_mask.shape[1], 4))
    meta_mask[delta_H:, 0:delta_W, :] = 1

    return padded_image, padded_mask, meta_mask
    

def visualize_grid_to_grid_with_cls(att_map, grid_index, image, grid_size=14, alpha=0.6, pth=None):
    """Plot the attention of one token over the image, including the CLS column.

    The left panel shows the image (padded with a CLS column) with the queried
    grid cell outlined. The right panel overlays the attention mask on it.

    Args:
        att_map: Attention matrix; row ``grid_index`` is visualised. Entry 0
            of that row is used as the CLS weight and the remaining entries
            are reshaped to the patch grid.
        grid_index: Index of the querying token (0 is the CLS token).
        image: PIL image; its ``size`` is used to resize the mask.
        grid_size: Int or ``(rows, cols)`` tuple giving the patch grid.
        alpha: Opacity of the attention overlay.
        pth: Where to save the figure. When ``None`` the figure is shown with
            ``plt.show()`` instead of being saved.
    """
    if not isinstance(grid_size, tuple):
        grid_size = (grid_size, grid_size)

    attention_map = att_map[grid_index]
    cls_weight = attention_map[0]

    mask = attention_map[1:].reshape(grid_size[0], grid_size[1])
    mask = Image.fromarray(mask).resize((image.size))

    padded_image, padded_mask, meta_mask = cls_padding(image, mask, cls_weight, grid_size)

    if grid_index != 0:  # adjust grid_index since we pad our image
        # Each grid row gained one leading CLS-column cell.
        grid_index = grid_index + (grid_index-1) // grid_size[1]

    grid_image = highlight_grid(padded_image, [grid_index], (grid_size[0], grid_size[1]+1))

    fig, ax = plt.subplots(1, 2, figsize=(10, 7))
    fig.tight_layout()

    ax[0].imshow(grid_image)
    ax[0].axis('off')

    ax[1].imshow(grid_image)
    ax[1].imshow(padded_mask, alpha=alpha, cmap='rainbow')
    ax[1].imshow(meta_mask)
    ax[1].axis('off')

    # The default pth=None cannot be passed to savefig; show the figure
    # instead, as visualize_grid_to_grid does.
    if pth is None:
        plt.show()
    else:
        plt.savefig(pth)
    

def visualize_grid_to_grid(att_map, grid_index, image, grid_size=14, alpha=0.6):
    """Show the attention of one grid cell over the image (no CLS token).

    The left panel shows the image with cell ``grid_index`` outlined. The
    right panel overlays row ``grid_index`` of ``att_map``, reshaped to the
    grid, resized to the image and divided by its maximum.
    """
    if not isinstance(grid_size, tuple):
        grid_size = (grid_size, grid_size)

    # The unpacking only succeeds for a two-dimensional att_map.
    _, _ = att_map.shape

    highlighted = highlight_grid(image, [grid_index], grid_size)

    cell_attention = att_map[grid_index].reshape(grid_size[0], grid_size[1])
    overlay = Image.fromarray(cell_attention).resize((image.size))

    fig, panels = plt.subplots(1, 2, figsize=(10, 7))
    fig.tight_layout()
    left_panel, right_panel = panels[0], panels[1]

    left_panel.imshow(highlighted)
    left_panel.axis('off')

    right_panel.imshow(highlighted)
    right_panel.imshow(overlay/np.max(overlay), alpha=alpha, cmap='rainbow')
    right_panel.axis('off')
    plt.show()
    
def highlight_grid(image, grid_indexes, grid_size=14):
    """Return a copy of ``image`` with the given grid cells outlined in red.

    ``grid_size`` is an int or a ``(rows, cols)`` tuple; each entry of
    ``grid_indexes`` is a flat row-major index into that grid. The input image
    is not modified.
    """
    if not isinstance(grid_size, tuple):
        grid_size = (grid_size, grid_size)
    n_rows, n_cols = grid_size[0], grid_size[1]

    width, height = image.size  # PIL reports (W, H)
    cell_h = height / n_rows
    cell_w = width / n_cols

    outlined = image.copy()
    for flat_index in grid_indexes:
        row, col = np.unravel_index(flat_index, (n_rows, n_cols))
        left = col*cell_w
        top = row*cell_h
        pen = ImageDraw.ImageDraw(outlined)
        pen.rectangle([(left, top), (left+cell_w, top+cell_h)], fill=None, outline='red', width=2)
    return outlined

def vis_attn(x, attention_maps, pn, name):
    """Write several attention visualisations for ``attention_maps[11]``.

    Three kinds of output are produced in the working directory:
    a head-averaged heatmap for batch item 9 (``./layer_ave12``), a 3-D
    wireframe of selected query rows of batch item 9 (``./3d_attention``,
    titled ``name``), and one CLS-to-patch overlay per image for the first ten
    batch items (``./img_<i>``).

    Args:
        x: Batch of images indexed as ``x[i, :, :, :]``; at least ten items.
        attention_maps: Indexable collection; entry 11 is indexed as
            ``[batch, head, query, key]``.
        pn: Number of prompt tokens located after the first token; those rows
            and columns are removed before the per-image overlays are drawn.
        name: Title of the 3-D plot.
    """
    layer_attention = attention_maps[11]
    total = layer_attention.size()[2]

    visualize_heads(layer_attention[9, :, :, :].cpu().detach().numpy(), cols=4, pth=f"./layer_ave12")

    fig = plt.figure(figsize=(8, 6))
    fig.tight_layout()  # Or equivalently,  "plt.tight_layout()"
    plt.subplots_adjust(wspace=0.)
    ax = fig.add_subplot(1, 1, 1, projection='3d')

    # x positions at which the selected query rows are drawn.
    xbase_index = [0, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 23, 30, 37, 44, 51]
    num_tokens = len(xbase_index)
    xdata = np.array([xbase_index] * total)
    ydata = np.array([np.ones(num_tokens) * key for key in range(total)])

    # Query rows of batch item 9 that are plotted, stacked along a new axis 1.
    query_rows = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 60, 109, 158, 206]
    selected = torch.cat(
        [layer_attention[9, :, row, :].unsqueeze(1) for row in query_rows],
        dim=1)
    # Average over axis 0, then transpose to match the xdata/ydata layout.
    zdata = selected.cpu().detach().numpy().mean(axis=0).T
    ax.plot_wireframe(xdata, ydata, zdata, rstride=0, color="royalblue", linewidth=1)

    ax.set_title(name, fontsize=20, fontweight="bold", y=1.015)
    plt.savefig(f"./3d_attention")

    for i in range(10):
        # Convert one image tensor into a uint8 PIL image.
        pixels = x[i, :, :, :].cpu().clone()
        pixels = pixels.squeeze(0)
        pixels = pixels.permute(1, 2, 0)
        image = Image.fromarray((pixels.numpy() * 255).astype(np.uint8))

        # Drop the prompt rows, then the prompt columns, keeping index 0 and
        # everything from pn+1 onwards along both attention axes.
        item_attention = layer_attention[i]
        kept_rows = torch.cat((
            item_attention[:, :1, :],
            item_attention[:, pn+1:, :]
        ), dim=1)
        attention_map = torch.cat((
            kept_rows[:, :, :1],
            kept_rows[:, :, pn+1:]
        ), dim=2)

        visualize_grid_to_grid_with_cls(attention_map.cpu().detach().numpy().mean(axis=0), 0, image, pth=f"./img_{str(i)}")
    


    
    