import inspect
import math
import logging
from typing import Optional
import omegaconf
from omegaconf import OmegaConf

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.attention import Attention

from qdiff.base.base_quantizer import DynamicQuantizer
from qdiff.base.quant_attn import QuantizedAttentionMap

from models.sparse_attn import SparseAttentionMap, EmptyHeadAttentionMap

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

class CogVideoXSageAttnProcessor:
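    r"""
    Attention processor for the CogVideoX model with optional quantization of the query/key/value tensors and of
    the attention map, plus optional sparse post-processing of the attention map.

    It applies a rotary embedding on query and key vectors, but does not include spatial normalization.
    `convert_quant` is expected to be called before the processor is used: `__call__` reads `self.quant_config`
    whenever `self.customize_attention` is True, which is the value set in `__init__`.
    """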

    def __init__(self):
        self.customize_attention = True
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def convert_quant(self, quant_config):
        
        self.quant_config = quant_config
        
        if self.quant_config.get('export_calib_data', False):
            self.apply_hooks = True
            self.hooks = {}  # hook_handle to store activations
        else:
            self.apply_hooks = False
            
        # INFO: the sparse cfgs
        if self.quant_config.attn.get('sparse', None) is not None:
            if self.quant_config.attn.sparse.get('plan', None) is not None:
                self.sparse_plan = torch.load(self.quant_config.attn.sparse.plan, weights_only=True, map_location='cuda')
        
        # INFO: reorder could be used without quant
        # also, block-wise quant without reorder also requires this file
        # DIRTY: fix this file path 
        if self.quant_config.attn.get('qk', None) is not None:
            if self.quant_config.attn.qk.get('reorder_file_path', None) is not None:
                reorder_file = self.quant_config.attn.qk.reorder_file_path
                self.optimal_reorder = torch.load(reorder_file, weights_only=True, map_location='cuda')
        
        # default set as None
        self.q_quantizer = nn.Identity()
        self.k_quantizer = nn.Identity()
        self.v_quantizer = nn.Identity()
        self.pre_softmax_attn_map_quantizer = nn.Identity()
        self.attn_map_quantizer = nn.Identity()
        self.attn_map_sparse_processor = nn.Identity()

        if self.quant_config.attn.get('qk', None) is not None:
            self.q_quantizer = DynamicQuantizer(self.quant_config.attn.qk)
            self.k_quantizer = DynamicQuantizer(self.quant_config.attn.qk)

            if self.quant_config.attn.qk.get('mixed_precision_cfg_path', None) is not None:
                self.pre_softmax_attn_map_quantizer = QuantizedAttentionMap(self.quant_config)
                # process from list into whole tensor
                self.pre_softmax_attn_map_quantizer.mixed_precision_cfg = self.pre_softmax_attn_map_mixed_precision_cfg

        if self.quant_config.attn.get('v', None) is not None:
            self.v_quantizer = DynamicQuantizer(self.quant_config.attn.v)
                
        if self.quant_config.attn.get('attn_map', None) is not None:
            self.attn_map_quantizer = QuantizedAttentionMap(self.quant_config)
            if self.quant_config.attn.attn_map.get('mixed_precision_cfg_path', None) is not None:
                self.attn_map_quantizer.mixed_precision_cfg = self.post_softmax_attn_map_mixed_precision_cfg

        # INFO: sparse support for baseline methods
        if self.quant_config.attn.get('sparse', None) is not None:
            # self.attn_map_sparse_processor = SparseAttentionMap(self.quant_config)  # the deprecated sparse scheme. 
            
            # INFO: apply empty head sparse processing.
            if self.quant_config.attn.sparse.get('empty_head', None) is not None:
                self.attn_map_sparse_processor = EmptyHeadAttentionMap(self.quant_config, sparse_plan=self.sparse_plan)
        else:
            self.attn_map_sparse_processor = nn.Identity()
            
        if self.apply_hooks:
            if self.quant_config.calib_data.qkv:
                self.hooks['q'] = add_hook_to_module_(self.q_quantizer, SaveActivationHook, type='qk', quant_config=self.quant_config)
                self.hooks['k'] = add_hook_to_module_(self.k_quantizer, SaveActivationHook, type='qk', quant_config=self.quant_config)
                self.hooks['v'] = add_hook_to_module_(self.v_quantizer, SaveActivationHook, type='v', quant_config=self.quant_config)
            if self.quant_config.calib_data.attn_map:
                self.hooks['attn_map'] = add_hook_to_module_(self.attn_map_quantizer, SaveActivationHook, type='attn', quant_config=self.quant_config)

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        text_seq_length = encoder_hidden_states.size(1)

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        
        #print(query.shape)
        #print(key.shape)
        #print(value.shape)
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
            if not attn.is_cross_attention:
                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)

        '''
        INFO: the attention map quantizer
        '''
        if self.customize_attention:
            hidden_states = self.customize_scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, \
                head_split_num=self.quant_config.attn.head_split_num, use_exp_of_two_softmax=self.quant_config.attn.exp_of_two_softmax, 
            )
        else:
            hidden_states = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        encoder_hidden_states, hidden_states = hidden_states.split(
            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
        )
        return hidden_states, encoder_hidden_states
    
    def customize_scaled_dot_product_attention(self, query, key, value, attn_mask=None, dropout_p=0.0,
        is_causal=False, scale=None, enable_gqa=False, head_split_num=16, use_exp_of_two_softmax=False,\
        ) -> torch.Tensor:
        
        self.device = query.device
            
        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
        attn_bias = torch.zeros(L, S, dtype=query.dtype).to(query.device)
        if is_causal:
            assert attn_mask is None
            temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
            attn_bias.to(query.dtype)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask

        if enable_gqa:
            key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
            value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)

        assert L == S

        head_num = query.shape[1]
        self.head_split_num = head_split_num
        head_per_split_num = head_num // head_split_num
        assert head_num % head_split_num == 0
        attn_output = torch.zeros_like(query)
        softmax_func = exp_of_two_softmax if use_exp_of_two_softmax else torch.softmax
        
        '''
        INFO: the qk reorder (token-level) based on existing reordering order.
        '''
        if self.quant_config.attn.get('qk', None) is not None:
            if self.quant_config.attn.qk.get('reorder', False):
                query, key, value = self.reorder_qk(query, key, value)
        
        for i in range(head_split_num):
            slice_range = (slice(None), slice(i * head_per_split_num, (i + 1) * head_per_split_num), slice(None), slice(None))
            query_part = query[slice_range]
            key_part = key[slice_range]
            value_part = value[slice_range]
        
            '''Quantization: qkv'''            
            BS, head_per_split_num, N_token, N_dim = query_part.shape
            if self.apply_hooks:
                if self.quant_config.calib_data.qkv:
                    self.hooks['q'].original_shape = [BS, head_per_split_num, N_token, N_dim]
            #print(f"N_token:{N_token}")
            query_part = query_part.view(BS, head_per_split_num, N_token//16, 16, N_dim)
            query_part = self.q_quantizer(query_part.reshape([-1, 16*N_dim])).reshape([BS, head_per_split_num, N_token, N_dim])

            if self.apply_hooks:
                if self.quant_config.calib_data.qkv:
                    self.hooks['k'].original_shape = [BS, head_per_split_num, N_token, N_dim]
            key_part = key_part.view(BS, head_per_split_num, N_token//16, 16, N_dim)
            key_part = self.k_quantizer(key_part.reshape([-1,N_dim])).reshape([BS, head_per_split_num, N_token, N_dim])

            if self.apply_hooks:
                if self.quant_config.calib_data.qkv:
                    self.hooks['v'].original_shape = [BS, head_per_split_num, N_token, N_dim]
            value_part = self.v_quantizer(
                value_part.permute([0,1,3,2]).reshape([-1, N_token])  # all tokens share the same quant_params.
                ).reshape([BS, head_per_split_num, N_dim, N_token]).permute([0,1,3,2])

            attn_map_pre_softmax_part = query_part @ key_part.transpose(-2, -1) * scale_factor
            attn_map_pre_softmax_part += attn_bias
            
            self.pre_softmax_attn_map_quantizer.i_block = self.i_block
            self.pre_softmax_attn_map_quantizer.split_range = [i * head_per_split_num, (i + 1) * head_per_split_num]
            attn_map_pre_softmax_part = self.pre_softmax_attn_map_quantizer(attn_map_pre_softmax_part)
                
            attn_map_post_softmax_part = softmax_func(attn_map_pre_softmax_part, dim=-1)
            attn_map_post_softmax_part = torch.dropout(attn_map_post_softmax_part, dropout_p, train=True)

            '''Quantization: apply attention map quantizer before multiplying V'''
            BS, head_per_split_num, N_token, N_token = attn_map_post_softmax_part.shape
            if self.apply_hooks:
                if self.quant_config.calib_data.attn_map:
                    self.hooks['attn_map'].original_shape = [BS, head_per_split_num, N_token, N_token]
                
            # INFO: the attention map quantization, reshaped within the attn_map_quantizer
            # self.attn_map_quantizer.quant_mode = False
            self.attn_map_quantizer.i_block = self.i_block
            self.attn_map_quantizer.split_range = [i * head_per_split_num, (i + 1) * head_per_split_num]
            attn_map_post_softmax_part = self.attn_map_quantizer(attn_map_post_softmax_part)
                     
            self.attn_map_sparse_processor.i_block = self.i_block
            self.attn_map_sparse_processor.split_range = [i * head_per_split_num, (i + 1) * head_per_split_num]
            attn_map_post_softmax_part = self.attn_map_sparse_processor(attn_map_post_softmax_part)
            
            attn_output[slice_range] = attn_map_post_softmax_part @ value_part
        
        # INFO: unpack the hooks (deprecated, move this logic to quant_inference.py)
        # if self.apply_hooks:
        #     for k_ in self.hooks:
        #         save_data = torch.cat(self.hooks[k_].outputs, dim=1)
        #         torch.save(save_data, f'./visualization/savedz_{k_}s.pth')
        #     import ipdb; ipdb.set_trace()
        
        # INFO: reorder back the output
        attn_output_ = attn_output
        if self.quant_config.attn.get('qk', None) is not None:
            if self.quant_config.attn.qk.get('reorder', False):
                attn_output = self.reorder_attn_out(attn_output)
            
        """ the same as the following code, only split to lower the memory footprint
        attn_weight = query @ key.transpose(-2, -1) * scale_factor
        attn_weight += attn_bias
        attn_weight = torch.softmax(attn_weight, dim=-1)
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
        attn_score = attn_weight @ value
        """
        assert attn_output.dtype == torch.bfloat16
        return attn_output