import inspect
import math
import logging
from typing import Optional
import omegaconf
from omegaconf import OmegaConf

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.attention import Attention

from qdiff.base.base_quantizer import DynamicQuantizer
from qdiff.base.quant_attn import QuantizedAttentionMap

from models.sparse_attn import SparseAttentionMap, EmptyHeadAttentionMap, PAROAttentionMap
from models.attn_eval_utils import evaluate_attention_maps

import time
logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

class SaveActivationHook:
    """Forward hook that records (optionally downsampled) activations for calibration export.

    Attach it with ``add_hook_to_module_``. On every forward of the hooked module, the first
    positional input is reshaped using ``self.original_shape``; the caller refreshes that shape
    before each forward. The tensor is optionally downsampled according to
    ``quant_config.calib_data``, moved to CPU, and appended to ``self.outputs``.

    Args:
        type: ``'qk'`` or ``'v'`` for Q/K/V activations, ``'attn'`` for attention maps.
        original_shape: 4-element shape used to un-flatten the hooked input.
        quant_config: config exposing ``calib_data`` (``attn_ds_rate``, ``qkv_ds_rate``,
            ``attn_ds_type``, ``qkv_ds_type``) and ``model.n_text_tokens``. When ``None``, no
            downsampling is configured.
    """

    # (frames, height, width) of the latent video token grid used by the per-axis attention-map
    # downsample. The number of image tokens must equal their product. Override on the instance
    # or in a subclass for other resolutions.
    VIDEO_FHW = (13, 30, 45)

    def __init__(self, type=None, original_shape=None, quant_config=None):
        self.hook_handle = None  # filled in by add_hook_to_module_
        self.type = type
        self.original_shape = original_shape
        self.quant_config = quant_config
        self.outputs = []
        calib_data = quant_config.calib_data if quant_config is not None else {}
        self.attn_ds_rate = calib_data.get('attn_ds_rate', None)
        self.qkv_ds_rate = calib_data.get('qkv_ds_rate', None)

    @staticmethod
    def _is_rate_sequence(rate):
        """Return True when ``rate`` is a per-axis ``[f, h, w]`` rate rather than a scalar."""
        # list/tuple are checked first so omegaconf is only consulted for other types.
        return isinstance(rate, (list, tuple)) or isinstance(rate, omegaconf.listconfig.ListConfig)

    def attn_map_downsample(self, data):
        """Downsample the image-to-image block of an attention map.

        The text rows/columns (the first ``n_text_tokens``) are dropped first. Then:

        * If ``attn_ds_rate`` is a sequence ``[f, h, w]``, the map is sum-pooled over
          non-overlapping ``h x w`` windows of the ``VIDEO_FHW`` grid on both axes.
        * If ``attn_ds_rate`` is a scalar ``r`` and ``attn_ds_type == 'reduce_sum'`` (the default),
          the map is cropped to a multiple of ``r`` and summed over ``r x r`` blocks.
        * If ``attn_ds_rate`` is a scalar and ``attn_ds_type == 'part'``, only the leading
          ``N_image // r`` rows and columns are kept.

        Raises:
            AssertionError: for a per-axis rate that does not evenly divide H/W, or has ``f != 1``.
            NotImplementedError: for an unknown ``attn_ds_type``.
        """
        assert self.type == 'attn'
        BS, head_per_split_num, _, N_token = self.original_shape
        N_text_token = self.quant_config.model.n_text_tokens
        N_image_token = N_token - N_text_token
        data = data[:, :, N_text_token:, N_text_token:]

        if self._is_rate_sequence(self.attn_ds_rate):
            return self._attn_map_downsample_fhw(data, BS, head_per_split_num, N_image_token)

        rate = self.attn_ds_rate
        attn_ds_type = self.quant_config.calib_data.get("attn_ds_type", "reduce_sum")  # default is reduce_sum
        if attn_ds_type == "reduce_sum":
            # Plain 1-D blocking of the token axis (used for the sparse plan). The tokens are assumed
            # to be already locally aggregated by the permutation, and 1-D blocks match what efficient
            # kernels process. The trailing N_image_token % rate tokens are dropped.
            n_blocks = N_image_token // rate
            kept = n_blocks * rate
            blocks = data[:, :, :kept, :kept].reshape(
                [BS, head_per_split_num, n_blocks, rate, n_blocks, rate]
            )
            return blocks.sum(dim=(3, 5))
        if attn_ds_type == "part":
            N_part = N_image_token // rate
            return data[:, :, :N_part, :N_part]
        raise NotImplementedError

    def _attn_map_downsample_fhw(self, data, BS, head_per_split_num, N_image_token):
        """Sum-pool an image-to-image map over ``h x w`` windows of the (F, H, W) grid on both axes.

        Pooling per axis keeps the map usable for choosing token permutations. Aggressive 1-D
        downsampling of the flattened last axis would distort it.
        """
        n_f, n_h, n_w = self.VIDEO_FHW
        assert data.shape[2] == data.shape[3] == N_image_token
        assert N_image_token == n_f * n_h * n_w
        f_rate, h_rate, w_rate = self.attn_ds_rate[0], self.attn_ds_rate[1], self.attn_ds_rate[2]
        # Each spatial axis must be divisible on its own. Divisibility of the product f*h*w is not
        # enough: the window reshape below needs H % h == 0 and W % w == 0. (An unreachable
        # zero-padding variant for indivisible rates used to follow the raise and was removed.)
        if n_h % h_rate != 0 or n_w % w_rate != 0:
            raise AssertionError("We only use attn_map to determine permute, so [1,5,5] is more suitable. ")
        assert f_rate == 1, "downsampling along the frame axis is not supported"

        h_out, w_out = n_h // h_rate, n_w // w_rate
        grid = data.reshape([
            BS,
            head_per_split_num,
            n_f, h_out, h_rate, w_out, w_rate,  # query axis
            n_f, h_out, h_rate, w_out, w_rate,  # key axis
        ])
        pooled = grid.sum(dim=(4, 6, 9, 11))  # reduce the h_rate / w_rate window axes
        n_out = n_f * h_out * w_out
        return pooled.reshape([BS, head_per_split_num, n_out, n_out])

    def qkv_downsample(self, data):
        """Downsample Q/K/V activations along the token axis.

        With ``qkv_ds_type == 'reduce_max'`` (the default), the text tokens are dropped. The image
        tokens are then cropped to a multiple of ``qkv_ds_rate`` and max-reduced over blocks of that
        size. With ``'part'``, the first ``N_token // qkv_ds_rate`` tokens are kept.

        Raises:
            NotImplementedError: for an unknown ``qkv_ds_type``.
        """
        assert self.type in ['qk', 'v']

        ds_type = self.quant_config.calib_data.get("qkv_ds_type", "reduce_max")  # default reduce_max

        BS, head_per_split_num, N_token, N_dim = self.original_shape
        N_text_token = self.quant_config.model.n_text_tokens
        N_image_token = N_token - N_text_token
        rate = self.qkv_ds_rate
        if ds_type == "reduce_max":
            n_blocks = N_image_token // rate
            # Always drop the text tokens. Previously they were only dropped when a remainder had to
            # be cropped, so evenly divisible token counts failed in the reshape below.
            data = data[:, :, N_text_token:N_text_token + n_blocks * rate, :]
            data = data.reshape([BS, head_per_split_num, n_blocks, rate, N_dim])
            return data.max(dim=3)[0]
        if ds_type == "part":
            N_part = N_token // rate  # how many tokens to keep
            return data[:, :, :N_part, :]
        raise NotImplementedError

    def __call__(self, module, module_in, module_out):
        """Forward-hook entry point: capture ``module_in[0]``, downsample it, and store it on CPU."""
        if self.type == 'qk':
            data = module_in[0].reshape(self.original_shape)
            if self.qkv_ds_rate is not None:
                data = self.qkv_downsample(data)
        elif self.type == 'v':
            BS, head_per_split_num, N_token, N_dim = self.original_shape
            # The hooked input is laid out as [.., N_dim, N_token]; transpose back to tokens-first.
            data = module_in[0].reshape([BS, head_per_split_num, N_dim, N_token]).permute(0, 1, 3, 2)
            if self.qkv_ds_rate is not None:
                data = self.qkv_downsample(data)
        elif self.type == 'attn':
            data = module_in[0].reshape(self.original_shape)
            if self.attn_ds_rate is not None:
                data = self.attn_map_downsample(data)
        else:
            raise NotImplementedError

        # Keep captured activations off the GPU.
        self.outputs.append(data.to('cpu'))

    def clear(self):
        """Drop every captured activation."""
        self.outputs = []

def add_hook_to_module_(module, hook_cls, **kwargs):
    """Build ``hook_cls(**kwargs)``, register it as a forward hook on *module*, and return it.

    The removable handle returned by ``register_forward_hook`` is stored on the hook as
    ``hook.hook_handle`` so callers can detach it later with ``hook.hook_handle.remove()``.
    """
    hook_instance = hook_cls(**kwargs)
    handle = module.register_forward_hook(hook_instance)
    hook_instance.hook_handle = handle
    return hook_instance

def exp_of_two_softmax(x, dim=-1):
    """Base-2 softmax: ``2**x / sum(2**x)`` along ``dim``.

    The row maximum is subtracted first. The result is mathematically unchanged, but large logits
    no longer overflow to ``inf`` (which previously produced ``inf / inf = NaN``). Entries equal to
    ``-inf`` still map to 0, and a row that is entirely ``-inf`` yields NaN, as before.
    """
    shifted = x - x.amax(dim=dim, keepdim=True)
    exp_x = torch.pow(2, shifted)
    return exp_x / exp_x.sum(dim=dim, keepdim=True)

class CustomizeCogVideoXAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
    query and key vectors, but does not include spatial normalization.
    """

    def __init__(self):
        """Enable the head-split customized attention path; requires PyTorch >= 2.0 (SDPA)."""
        sdpa_available = hasattr(F, "scaled_dot_product_attention")
        if not sdpa_available:
            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        # Route attention through customize_scaled_dot_product_attention instead of the fused kernel.
        self.customize_attention = True
             
    def convert_quant(self, quant_config):
        """Install the Q/K/V, attention-map and sparse processors described by ``quant_config``.

        Every processor defaults to ``nn.Identity()`` and is replaced only when its config section is
        present. Quantizer names use ``self.i_block``. The sparse processors use ``self.sparse_plan``
        and, for PARO, ``self.permute_plan``; these attributes must already be set on the processor.
        When ``quant_config.export_calib_data`` is truthy, ``SaveActivationHook``s are attached and
        kept in ``self.hooks``.

        Raises:
            AssertionError: if a ``mixed_precision_cfg_path`` is configured for the QK or PV matmul.
        """
        self.quant_config = quant_config
        attn_cfg = self.quant_config.attn

        self.apply_hooks = bool(self.quant_config.get('export_calib_data', False))
        if self.apply_hooks:
            self.hooks = {}  # hook name -> SaveActivationHook

        # Pass-through defaults, each a fresh nn.Identity().
        for attr_name in (
            'q_quantizer',
            'k_quantizer',
            'v_quantizer',
            'pre_softmax_attn_map_quantizer',
            'attn_map_quantizer',
            'attn_map_sparse_processor',
        ):
            setattr(self, attr_name, nn.Identity())

        block_prefix = f"transformer_blocks.{self.i_block}.attn_map"

        if attn_cfg.get('qk', None) is not None:
            self.q_quantizer = DynamicQuantizer(attn_cfg.qk)
            self.q_quantizer.module_name = f"{block_prefix}.q"
            self.k_quantizer = DynamicQuantizer(attn_cfg.qk)
            self.k_quantizer.module_name = f"{block_prefix}.k"
            if attn_cfg.qk.get('mixed_precision_cfg_path', None) is not None:
                raise AssertionError("QK Matmul Mixed Precision is only supported for hardware accelerator. ")

        if attn_cfg.get('v', None) is not None:
            self.v_quantizer = DynamicQuantizer(attn_cfg.v)
            # Must differ from the name of the V projection (v_mapping).
            self.v_quantizer.module_name = f"{block_prefix}.v"

        if attn_cfg.get('attn_map', None) is not None:
            self.attn_map_quantizer = QuantizedAttentionMap(self.quant_config)
            if attn_cfg.attn_map.get('mixed_precision_cfg_path', None) is not None:
                raise AssertionError("PV Matmul Mixed Precision is only supported for hardware accelerator. ")

        # Sparse support for the baseline methods (the older SparseAttentionMap scheme is unused).
        sparse_cfg = attn_cfg.get('sparse', None)
        if sparse_cfg is None:
            self.attn_map_sparse_processor = nn.Identity()
        elif sparse_cfg.get('empty_head', None):
            if sparse_cfg.get('block_sparse', None):
                # PARO attention scheme: empty_head + block_sparse.
                self.attn_map_sparse_processor = PAROAttentionMap(self.quant_config, sparse_plan=self.sparse_plan, permute_plan=self.permute_plan)
            else:
                self.attn_map_sparse_processor = EmptyHeadAttentionMap(self.quant_config, sparse_plan=self.sparse_plan)

        if not self.apply_hooks:
            return

        calib_cfg = self.quant_config.calib_data
        if calib_cfg.qkv:
            # TODO: the Q and K hooks are currently disabled (original note: "remember to change");
            # only V activations are captured.
            self.hooks['v'] = add_hook_to_module_(self.v_quantizer, SaveActivationHook, type='v', quant_config=self.quant_config)
        if calib_cfg.attn_map:
            # Hooked on the sparse processor so the map it receives is what gets captured.
            self.hooks['attn_map'] = add_hook_to_module_(self.attn_map_sparse_processor, SaveActivationHook, type='attn', quant_config=self.quant_config)

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Joint text + video attention for one CogVideoX block.

        The text (``encoder_hidden_states``) tokens are placed before the video (``hidden_states``)
        tokens. The joint sequence is projected to Q/K/V, and rotary embeddings are applied to the
        video tokens only. Attention is computed through ``customize_scaled_dot_product_attention``
        when ``self.customize_attention`` is set, otherwise through ``F.scaled_dot_product_attention``.
        The projected output is split back into its two parts.

        Returns:
            ``(hidden_states, encoder_hidden_states)``: the video part first, then the text part.
        """
        n_text = encoder_hidden_states.size(1)
        joint = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        # `encoder_hidden_states` was already dereferenced above, so the `None` fallback never
        # triggers; batch size and mask length are taken from the text tokens.
        shape_source = joint if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = shape_source.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        q = attn.to_q(joint)
        k = attn.to_k(joint)
        v = attn.to_v(joint)

        head_dim = k.shape[-1] // attn.heads

        def _split_heads(t):
            # [B, N, heads * head_dim] -> [B, heads, N, head_dim]
            return t.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        q = _split_heads(q)
        k = _split_heads(k)
        v = _split_heads(v)

        if attn.norm_q is not None:
            q = attn.norm_q(q)
        if attn.norm_k is not None:
            k = attn.norm_k(k)

        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            # RoPE is written in place over the video tokens; the text tokens are left as-is.
            q[:, :, n_text:] = apply_rotary_emb(q[:, :, n_text:], image_rotary_emb)
            if not attn.is_cross_attention:
                k[:, :, n_text:] = apply_rotary_emb(k[:, :, n_text:], image_rotary_emb)

        if self.customize_attention:
            attn_cfg = self.quant_config.attn
            out = self.customize_scaled_dot_product_attention(
                q, k, v, attn_mask=attention_mask, dropout_p=0.0, is_causal=False,
                head_split_num=attn_cfg.head_split_num,
                use_exp_of_two_softmax=attn_cfg.get("exp_of_two_softmax", False),
            )
        else:
            out = F.scaled_dot_product_attention(
                q, k, v, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )

        # [B, heads, N, head_dim] -> [B, N, heads * head_dim]
        out = out.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        out = attn.to_out[0](out)  # linear projection
        out = attn.to_out[1](out)  # dropout

        n_video = out.size(1) - n_text
        text_out, video_out = out.split([n_text, n_video], dim=1)
        return video_out, text_out
    
    def customize_scaled_dot_product_attention(self, query, key, value, attn_mask=None, dropout_p=0.0,
        is_causal=False, scale=None, enable_gqa=False, head_split_num=16, use_exp_of_two_softmax=False,\
        ) -> torch.Tensor:
        """Head-chunked scaled dot-product attention with quantization, sparsity and calibration hooks.

        Computes ``softmax(query @ key^T * scale + bias) @ value`` like
        ``F.scaled_dot_product_attention``, but processes the heads in ``head_split_num`` chunks.
        Each chunk's attention map can then be fake-quantized, sparsified, rescaled and captured by
        the calibration hooks, while peak memory stays bounded.

        Args:
            query, key, value: ``[BS, N_head, N_token, N_dim]`` tensors. The first
                ``quant_config.model.n_text_tokens`` tokens are text and the rest are image tokens.
                They may be modified in place by the token permutation and the fake-quantized
                write-back.
            attn_mask: optional bool mask (True = keep) or additive float mask, folded into an
                ``[L, S]`` bias.
            dropout_p: dropout probability applied to the post-softmax map.
            is_causal: add a causal mask; must not be combined with ``attn_mask``.
            scale: softmax scale; defaults to ``1 / sqrt(N_dim)``.
            enable_gqa: repeat the K/V heads to match the number of query heads.
            head_split_num: number of head chunks; must divide ``N_head``.
            use_exp_of_two_softmax: use the base-2 ``exp_of_two_softmax`` instead of ``torch.softmax``.

        Returns:
            The attention output, shaped like ``query`` (asserted to be bfloat16).
        """
        self.device = query.device  # also read by permute_attn_out

        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
        attn_bias = self._sdpa_attn_bias(L, S, query, attn_mask, is_causal)

        if enable_gqa:
            key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
            value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)

        assert L == S

        head_num = query.shape[1]
        self.head_split_num = head_split_num
        head_per_split_num = head_num // head_split_num
        assert head_num % head_split_num == 0
        attn_output = torch.zeros_like(query)
        softmax_func = exp_of_two_softmax if use_exp_of_two_softmax else torch.softmax

        # Loop-invariant configuration, read once instead of once per head chunk.
        attn_cfg = self.quant_config.attn
        qk_cfg = attn_cfg.get('qk', None)
        v_cfg = attn_cfg.get('v', None)
        sparse_cfg = attn_cfg.get('sparse', None)
        use_sparse = bool(sparse_cfg)  # same truthiness as `attn.get("sparse", False)`
        n_text_tokens = self.quant_config.model.n_text_tokens
        # skip_text_quant: fake-quantize only the image tokens (n_text_tokens:), otherwise all tokens.
        quant_start = n_text_tokens if attn_cfg.get("skip_text_quant", False) else 0
        smooth_qk = qk_cfg is not None and bool(qk_cfg.get('smooth', False))
        per_group = 4096  # V is quantized in zero-padded groups of `per_group` tokens
        if v_cfg is not None and v_cfg.get("per_group", None) is not None:
            per_group = v_cfg.per_group
        save_v = self.apply_hooks and bool(self.quant_config.calib_data.qkv)
        save_attn_map = self.apply_hooks and bool(self.quant_config.calib_data.attn_map)

        # Token-level permutation of Q/K/V according to the precomputed permute plan.
        if sparse_cfg is not None and sparse_cfg.get('permute', False):
            query, key, value = self.permute_qk(query, key, value)

        for i in range(head_split_num):
            h0, h1 = i * head_per_split_num, (i + 1) * head_per_split_num
            slice_range = (slice(None), slice(h0, h1), slice(None), slice(None))
            query_part = query[slice_range]
            key_part = key[slice_range]
            value_part = value[slice_range]

            # ---------------- Quantization: Q / K / V ----------------
            query_part_ = query_part[:, :, quant_start:, :]
            key_part_ = key_part[:, :, quant_start:, :]
            value_part_ = value_part[:, :, quant_start:, :]
            BS, n_heads, N_token, N_dim = query_part_.shape

            if smooth_qk:
                # Center Q and K along the feature axis before quantization. Each centered row sums
                # to zero, so Q K^T = (Q - Qm)(K - Km)^T + Qm Km^T. The cross terms vanish, and
                # Qm Km^T is added back below. (The previous `Qm @ (K - Km)^T` is identically zero.)
                query_mean = torch.mean(query_part_, dim=-1, keepdim=True).expand(-1, -1, -1, N_dim)
                key_mean = torch.mean(key_part_, dim=-1, keepdim=True).expand(-1, -1, -1, N_dim)
                query_part_ = query_part_ - query_mean
                key_part_ = key_part_ - key_mean
                deltaS_part = query_mean @ key_mean.transpose(-2, -1)
            elif qk_cfg is not None and qk_cfg.get('onlyKsmooth', False):
                # NOTE(review): unlike `smooth`, this centering is not compensated afterwards; it
                # shifts (Q K^T)[i, j] by -mean(K_j) * sum(Q_i). Confirm this approximation is intended.
                key_mean = torch.mean(key_part_, dim=-1, keepdim=True).expand(-1, -1, -1, N_dim)
                key_part_ = key_part_ - key_mean

            qk_shape = [BS, n_heads, N_token, N_dim]
            # Write back from `quant_start` so the shapes match with and without skip_text_quant.
            query_part[:, :, quant_start:, :] = self.q_quantizer(query_part_.reshape([-1, N_dim])).reshape(qk_shape)
            key_part[:, :, quant_start:, :] = self.k_quantizer(key_part_.reshape([-1, N_dim])).reshape(qk_shape)

            if save_v:
                # NOTE(review): the v_quantizer input below holds N_quant (padded) tokens, but this shape
                # uses N_token. The hook's reshape only matches when N_token % per_group == 0. Confirm.
                self.hooks['v'].original_shape = [BS, n_heads, N_token, N_dim]

            n_group = -(-N_token // per_group)  # ceil division
            N_quant = n_group * per_group
            padded_value = torch.zeros(BS, n_heads, N_quant, N_dim, device=value_part_.device)
            padded_value[:, :, :N_token, :] = value_part_
            # [BS, h, n_group, per_group, N_dim] -> [BS, h, N_dim, n_group, per_group] -> rows of per_group tokens
            value_groups = padded_value.reshape([BS, n_heads, n_group, per_group, N_dim]).permute([0, 1, 4, 2, 3]).reshape([-1, per_group])
            value_part[:, :, quant_start:, :] = self.v_quantizer(
                value_groups
            ).reshape([BS, n_heads, N_dim, N_quant]).permute([0, 1, 3, 2])[:, :, :N_token, :]

            # ---------------- Pre-softmax attention map ----------------
            if smooth_qk:
                attn_map_pre_softmax_part = query_part @ key_part.transpose(-2, -1)
                attn_map_pre_softmax_part[:, :, quant_start:, quant_start:] += deltaS_part
                attn_map_pre_softmax_part *= scale_factor
            else:
                attn_map_pre_softmax_part = query_part @ key_part.transpose(-2, -1) * scale_factor
            attn_map_pre_softmax_part += attn_bias

            self.pre_softmax_attn_map_quantizer.i_block = self.i_block
            self.pre_softmax_attn_map_quantizer.split_range = [h0, h1]
            attn_map_pre_softmax_part = self.pre_softmax_attn_map_quantizer(attn_map_pre_softmax_part)

            # ---------------- Sparse processor ----------------
            # Sparsity is skipped for timesteps below `skip_timestep_percentage`.
            skip_sparse = False
            if use_sparse and sparse_cfg.get("skip_timestep_percentage", None) is not None:
                self.attn_map_sparse_processor.i_timestep = self.i_timestep
                if self.i_timestep < sparse_cfg.skip_timestep_percentage:
                    skip_sparse = True
            self.attn_map_sparse_processor.i_block = self.i_block
            self.attn_map_sparse_processor.split_range = [h0, h1]

            map_bs, map_heads, _, map_tokens = attn_map_pre_softmax_part.shape
            if save_attn_map:
                self.hooks['attn_map'].original_shape = [map_bs, map_heads, map_tokens, map_tokens]

            pre_softmax_sparse = use_sparse and sparse_cfg.get("pre_softmax", False)
            if pre_softmax_sparse and not skip_sparse:
                attn_map_pre_softmax_part = self.attn_map_sparse_processor(attn_map_pre_softmax_part)

            attn_map_post_softmax_part = softmax_func(attn_map_pre_softmax_part, dim=-1)
            attn_map_post_softmax_part = torch.dropout(attn_map_post_softmax_part, dropout_p, train=True)

            if use_sparse:
                if not pre_softmax_sparse and not skip_sparse:
                    attn_map_post_softmax_part = self.attn_map_sparse_processor(attn_map_post_softmax_part)
            else:
                # Without a sparse config the processor is nn.Identity(). It is still called so that
                # a calibration hook registered on it sees the post-softmax map.
                attn_map_post_softmax_part = self.attn_map_sparse_processor(attn_map_post_softmax_part)

            # (Optional) rescale the text<->image blocks of the post-softmax map.
            if use_sparse and sparse_cfg.get("rescale_text_embeds", False):
                self._rescale_text_embeds_(attn_map_post_softmax_part, h0, h1)

            # ---------------- Attention-map quantization, then P @ V ----------------
            self.attn_map_quantizer.i_block = self.i_block
            self.attn_map_quantizer.split_range = [h0, h1]
            attn_map_post_softmax_part = self.attn_map_quantizer(attn_map_post_softmax_part)
            if attn_cfg.get("FP8", False):
                # Emulate FP8 (e4m3) on the image-image block; probabilities are scaled by 448
                # (the e4m3 maximum) to use the full range.
                attn_map_fp8 = (attn_map_post_softmax_part[:, :, n_text_tokens:, n_text_tokens:] * 448).to(torch.float8_e4m3fn)
                value_fp8 = value_part[:, :, n_text_tokens:, :].to(torch.float8_e4m3fn)
                attn_map_post_softmax_part[:, :, n_text_tokens:, n_text_tokens:] = attn_map_fp8.to(torch.bfloat16) / 448
                value_part[:, :, n_text_tokens:, :] = value_fp8.to(torch.bfloat16)
            attn_output[slice_range] = attn_map_post_softmax_part @ value_part

        # Undo the token permutation on the output.
        if use_sparse and sparse_cfg.get('permute', False):
            attn_output = self.permute_attn_out(attn_output)

        # Unsplit reference (the chunked loop above only lowers the memory footprint):
        #   attn_weight = softmax(query @ key.transpose(-2, -1) * scale_factor + attn_bias)
        #   attn_output = dropout(attn_weight, dropout_p) @ value
        assert attn_output.dtype == torch.bfloat16
        return attn_output

    @staticmethod
    def _sdpa_attn_bias(L, S, query, attn_mask, is_causal):
        """Build the additive ``[L, S]`` bias (causal and/or user mask) on ``query``'s device and dtype."""
        # Allocated directly on the target device (previously built on the CPU and then copied).
        attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
        if is_causal:
            assert attn_mask is None
            causal = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
            attn_bias.masked_fill_(causal.logical_not(), float("-inf"))
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask
        return attn_bias

    def _rescale_text_embeds_(self, attn_map, h0, h1):
        """Scale, in place, the text/image cross blocks of a post-softmax map chunk (heads ``h0:h1``).

        Text-query rows x image-key columns are scaled per key block by ``sparse_plan['rescale_cols']``.
        Image-query rows x text-key columns are scaled per query block by ``sparse_plan['rescale_rows']``.
        The image tokens are blocked by ``calib_data.attn_ds_rate``, and any trailing partial block is
        left unscaled.
        """
        assert "rescale_rows" in self.sparse_plan.keys()
        n_timestep_calib = self.sparse_plan["rescale_rows"].shape[0]
        i_timestep_calib = int(self.i_timestep // (1 / n_timestep_calib))
        rescale_rows = self.sparse_plan["rescale_rows"][i_timestep_calib, self.i_block, h0:h1]
        rescale_cols = self.sparse_plan["rescale_cols"][i_timestep_calib, self.i_block, h0:h1]

        BS, n_heads, N_token, _ = attn_map.shape
        N_text = self.quant_config.model.n_text_tokens
        N_image = N_token - N_text
        block_size = self.quant_config.calib_data.attn_ds_rate
        N_block = N_image // block_size
        N_masked = block_size * N_block  # smaller than N_image when not divisible
        image_slice = slice(N_text, N_text + N_masked)

        # Columns: text queries attending to image keys.
        cols_region = attn_map[:, :, :N_text, image_slice]
        attn_map[:, :, :N_text, image_slice] = (
            cols_region.reshape([BS, n_heads, N_text, N_block, block_size])
            * rescale_cols.reshape([1, n_heads, 1, N_block, 1])
        ).reshape([BS, n_heads, N_text, N_masked])

        # Rows: image queries attending to text keys. The previous code read the column region here.
        rows_region = attn_map[:, :, image_slice, :N_text]
        attn_map[:, :, image_slice, :N_text] = (
            rows_region.reshape([BS, n_heads, N_block, block_size, N_text])
            * rescale_rows.reshape([1, n_heads, N_block, 1, 1])
        ).reshape([BS, n_heads, N_masked, N_text])
    
    def permute_qk(self, query, key, value):
        """Reorder the image tokens of Q/K/V per head according to ``self.permute_plan``.

        The image tokens (those after the first ``n_text_tokens``) must form a
        ``(frames=13, height=30, width=45)`` grid stored in F-H-W order (asserted below). Each head is
        re-flattened in one of six axis orders, selected by
        ``self.permute_plan['permute'][self.i_block]`` (one index per head). The tensors are updated
        in place and also returned.
        """
        n_text = self.quant_config.model.n_text_tokens
        BS, N_head, N_token, N_dim = query.shape
        image_parts = [t[:, :, n_text:, :] for t in (query, key, value)]

        N_image_token = N_token - n_text
        n_f, n_h, n_w = 13, 30, 45  # (17776 - 226) == 13 * 30 * 45
        assert N_image_token == n_f * n_w * n_h

        # plan index -> axis order over (F, H, W)
        axis_orders = torch.tensor([
            [0, 1, 2],  # 0: FHW
            [0, 2, 1],  # 1: FWH
            [1, 2, 0],  # 2: HWF
            [1, 0, 2],  # 3: HFW
            [2, 1, 0],  # 4: WHF
            [2, 0, 1],  # 5: WFH
        ])
        # i_block is assigned when the block is created (see transformer_3d.py).
        plan_indices = self.permute_plan['permute'][self.i_block]
        head_orders = torch.stack([axis_orders[idx.item()] for idx in plan_indices], dim=0)  # [N_head, 3]

        for h in range(N_head):
            # Keep batch (0) and feature (4) axes fixed; permute the three grid axes.
            dims = (0, *(head_orders[h] + 1).tolist(), 4)
            for part in image_parts:
                part[:, h, :, :] = (
                    part[:, h, :, :]
                    .reshape([BS, n_f, n_h, n_w, N_dim])
                    .permute(*dims)
                    .reshape([BS, N_image_token, N_dim])
                )

        for full, part in zip((query, key, value), image_parts):
            full[:, :, n_text:, :] = part

        return query, key, value
        
    def permute_attn_out(self, attn_out, latent_shape=(13, 30, 45)):
        """Undo the per-head image-token permutation applied by ``permute_qk``.

        The first ``self.quant_config.model.n_text_tokens`` tokens are left
        untouched. For each head, the image tokens are assumed to be laid out in
        the permuted ``(F, H, W)`` axis order selected by
        ``self.permute_plan['permute'][self.i_block][i_head]``. They are
        reshaped to that permuted grid, permuted back to ``(F, H, W)`` with the
        inverse permutation, and flattened again.

        Args:
            attn_out: tensor of shape ``[BS, N_head, N_token, N_dim]``.
            latent_shape: ``(F, H, W)`` sizes of the image-token grid. The
                default reproduces the previously hard-coded values.

        Returns:
            ``attn_out`` -- the same tensor object, modified in place.
        """
        # Permutation id -> axis order of the (F, H, W) grid used by permute_qk.
        permutations = (
            (0, 1, 2),  # 0: FHW (identity)
            (0, 2, 1),  # 1: FWH
            (1, 2, 0),  # 2: HWF
            (1, 0, 2),  # 3: HFW
            (2, 1, 0),  # 4: WHF
            (2, 0, 1),  # 5: WFH
        )
        n_text_tokens = self.quant_config.model.n_text_tokens
        BS, N_head, N_token, N_dim = attn_out.shape
        # Lower-case names avoid shadowing the module-level `F` (torch.nn.functional).
        n_f, n_h, n_w = latent_shape
        grid_sizes = (n_f, n_h, n_w)
        N_image_token = N_token - n_text_tokens
        assert N_image_token == n_f * n_h * n_w

        # i_block is initialized during creating block in `transformer_3d.py`
        permute_order_index = self.permute_plan['permute'][self.i_block]
        if torch.is_tensor(permute_order_index):
            # One host transfer for all heads instead of one `.item()` per head.
            permute_order_index = permute_order_index.tolist()
        head_orders = [permutations[int(i)] for i in permute_order_index]

        for i_head in range(N_head):
            order = head_orders[i_head]
            if order == (0, 1, 2):
                continue  # identity permutation: nothing to undo
            # Grid sizes in the permuted layout, computed on the host so no
            # device tensor / device->host sync is needed per head.
            permuted_grid = tuple(grid_sizes[axis] for axis in order)
            # inverse[axis] = position of original grid axis `axis` in `order`.
            inverse = tuple(order.index(axis) for axis in range(3))
            dims = (0, inverse[0] + 1, inverse[1] + 1, inverse[2] + 1, 4)
            image_part = attn_out[:, i_head, n_text_tokens:, :]  # view into attn_out
            image_part.copy_(
                image_part.reshape(BS, *permuted_grid, N_dim)
                .permute(dims)
                .reshape(BS, N_image_token, N_dim)
            )

        return attn_out
    
    # def permute_attn(self, attn_map, split_id):
    #     # convert back the attention map
    #     # the input attention map are splitted
    #     # for debug_only, DONOT Need during inference        
    #     # (F,W,H) -> (Frame, With, Height) 
    #     # (17776-226) == 13*30*45
    #     permutations_inv = torch.tensor([
    #             [0, 1, 2],  # 0: FHW
    #             [0, 2, 1],  # 1: FWH
    #             [2, 0, 1],  # 2: HWF
    #             [1, 0, 2],  # 3: HFW
    #             [2, 1, 0],  # 4: WHF
    #             [1, 2, 0],  # 5: WFH
    #     ])
    #     N_text_token = self.quant_config.model.n_text_tokens
    #     N_image_token = N_token - self.quant_config.model.n_text_tokens
    #     F = 13
    #     H = 30
    #     W = 45
    #     assert N_image_token == F*W*H
    
    #     permute_order_index = self.optimal_permute['permute_order_index'][self.i_block]  
    #     BS, head_per_split_num, N_token, N_dim = attn_map.shape
    #     N_head = head_per_split_num * self.head_split_num
    #     permute_order_index = permute_order_index[self.head_split_num*split_id : self.head_split_num*(split_id+1)]
    #     permute_orders = torch.stack([permutations_inv[i.item()] for i in permute_order_index], dim=0)  # [N_head,3]
        
    #     indices = torch.zeros([head_per_split_num, N_image_token], device=self.device)
        
    #     # TODO: finish the permute for test 
    #     for i_head in range(head_per_split_num):
    #         permute_dims_head = permute_orders[i_head]
    #         indices[i_head] = torch.arange(F*H*W, device=self.device).reshape([F,H,W]).permute(*permute_dims_head).reshape([N_image_token])
    #     indices_expanded = indices.reshape([1,N_head,N_image_token,1]).expand([BS,-1,-1,N_dim])
        
        
