import torch
import torch.nn as nn
import torch.nn.functional as F
import math 
from qdiff.base.base_quantizer import StaticQuantizer, DynamicQuantizer
from qdiff.base.mixed_precision_quantizer import MixedPrecisionStaticQuantizer, MixedPrecisionDynamicQuantizer
from omegaconf import OmegaConf, ListConfig

from models.attn_eval_utils import evaluate_attention_maps

class SparseAttentionMap(torch.nn.Module):
    """
    N:M-style sparsifier for the image-image part of an attention map.

    Along the last dim, every contiguous group of ``N`` entries keeps its ``M``
    largest values and the remaining ``N - M`` are set to zero. ``N`` and ``M``
    come from ``quant_config.attn.sparse``; only ``type == "N_M"`` is supported.
    """
    def __init__(
        self,
        quant_config: dict,
    ) -> None:
        super().__init__()

        self.quant_config = quant_config
        if self.quant_config.attn.sparse.type == "N_M":
            self.N = self.quant_config.attn.sparse.N  # group size along the last dim
            self.M = self.quant_config.attn.sparse.M  # entries kept per group

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        '''
        Sparsify the image-image attention block of ``x`` in place and return ``x``.

        x: [BS, head_per_split_num, N_token, N_token]. The first
           ``quant_config.model.n_text_tokens`` rows/cols (text tokens) are untouched.
        Only the leading ``floor(N_image_token / N) * N`` image tokens are processed.
        The remaining tokens are left dense (drop-last, to preserve quality).
        In each group of ``N`` consecutive entries, the ``N - M`` smallest are zeroed.

        Raises NotImplementedError if the configured sparse type is not "N_M".
        '''
        if self.quant_config.attn.sparse.type != "N_M":
            raise NotImplementedError(
                f"unsupported sparse type: {self.quant_config.attn.sparse.type}"
            )

        BS, head_per_split_num, N_token, _ = x.shape
        N_text_token = self.quant_config.model.n_text_tokens
        N_image_token = N_token - N_text_token

        group_size = self.N
        n_drop = self.N - self.M
        # INFO: largest multiple of the group size not exceeding the number of IMAGE tokens.
        nearest_divisible_n = (N_image_token // group_size) * group_size
        if nearest_divisible_n == 0:
            # Fewer image tokens than one group: nothing to sparsify.
            return x

        end = N_text_token + nearest_divisible_n
        attn_map_image = x[:, :, N_text_token:end, N_text_token:end]
        grouped = attn_map_image.reshape(
            [BS, head_per_split_num, nearest_divisible_n, nearest_divisible_n // group_size, group_size]
        )
        # Indices of the n_drop smallest entries in every group, then zero them.
        _, indices = torch.topk(grouped, k=n_drop, largest=False, dim=-1)
        grouped = grouped.scatter(-1, indices, 0.)

        x[:, :, N_text_token:end, N_text_token:end] = grouped.reshape(
            [BS, head_per_split_num, nearest_divisible_n, nearest_divisible_n]
        )
        return x
    
class EmptyHeadAttentionMap(torch.nn.Module):
    """
    Overwrite the image-image attention of "empty" heads.

    Empty-head flags are read from ``plan['empty'][i_block]`` and indexed by the
    global head indices ``split_range[0]:split_range[1]``. ``forward`` reads
    ``split_range`` and ``i_block``. These are not set in ``__init__``; the caller
    assigns them during attention inference.
    """
    def __init__(
        self,
        quant_config: dict,
        sparse_plan: dict,
    ) -> None:
        super().__init__()

        self.quant_config = quant_config
        self.sparse_plan = sparse_plan
        # NOTE: the config key is spelled "empty_head_procssor" (sic); kept as-is for compatibility.
        self.empty_head_processor = self.quant_config.attn.sparse.empty_head_procssor

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        '''
        Find the empty heads of this split and overwrite their image-image attention.

        x: [BS, head_per_split_num, N_token, N_token]. The first
           ``quant_config.model.n_text_tokens`` rows/cols are left untouched.
        'uniform': every image row is replaced by its own mean value.
        'zero':    the image-image block is set to 0.
        Modifies ``x`` in place and returns it.
        Raises NotImplementedError for another processor when an empty head exists.
        '''
        N_token = x.shape[-1]
        N_text_token = self.quant_config.model.n_text_tokens
        N_image_token = N_token - N_text_token
        attn_map_image = x[:, :, N_text_token:, N_text_token:]

        # INFO: get which heads are empty. `permute_plan` is never set in __init__;
        # use it if the caller assigned one, else fall back to the plan given at construction.
        plan = getattr(self, "permute_plan", None)
        if plan is None:
            plan = self.sparse_plan
        head_flags = torch.as_tensor(plan['empty'][self.i_block])
        # Index the flags on their own device, then move the result to x's device.
        split_indices = torch.arange(self.split_range[0], self.split_range[1], device=head_flags.device)
        if_head_empty = head_flags[split_indices]
        # Positions within this split, i.e. local head indices into x's dim 1.
        empty_indices = torch.nonzero(if_head_empty).squeeze(-1).to(x.device)

        if empty_indices.numel() > 0:
            print('Block {}, has empty heads {}'.format(self.i_block, empty_indices.tolist()))
            if self.empty_head_processor == 'uniform':
                empty_heads = attn_map_image.index_select(dim=1, index=empty_indices)
                x[:, empty_indices, N_text_token:, N_text_token:] = (
                    empty_heads.mean(-1, keepdim=True).expand(-1, -1, -1, N_image_token)
                )
            elif self.empty_head_processor == 'zero':
                # Assign through indexing: a tensor index yields a copy, so `.fill_` on it would not touch x.
                x[:, empty_indices, N_text_token:, N_text_token:] = 0.
            else:
                raise NotImplementedError

        return x

# INFO: utility for permute.
# Lookup table of the 6 orderings of the three axes labelled F, H, W
# (F=0, H=1, W=2; presumably frame/height/width). Row ``i`` is the axis
# order for permutation id ``i``, named in the trailing comment.
permutations = torch.tensor(
    [
        [0, 1, 2],  # 0: FHW
        [0, 2, 1],  # 1: FWH
        [1, 2, 0],  # 2: HWF
        [1, 0, 2],  # 3: HFW
        [2, 1, 0],  # 4: WHF
        [2, 0, 1],  # 5: WFH
    ],
    dtype=torch.int64,
)

class PAROAttentionMap(torch.nn.Module):
    """
    The PARO sparse processor for attention maps: block_sparse + empty-head handling.

    ``forward`` reads attributes that are NOT set in ``__init__``; the caller must
    assign them before each call:
      * ``split_range``: start/end of the global head indices covered by ``x``;
      * ``i_block``: transformer block index used to look up the plans;
      * ``i_timestep``: only when not ``online``; used to pick the calibrated
        timestep slot (presumably a progress value in [0, 1] -- TODO confirm).
    ``sparse_plan['sparse']`` is indexed as [timestep, block, head] -> [N_block, N_block];
    ``permute_plan['empty'][i_block]`` holds per-head empty flags (global head index).
    """

    # Pre-softmax masking writes this value so the post-softmax probability is ~0.
    # Kept moderate rather than e.g. -1e10 -- presumably to stay finite in fp16; TODO confirm.
    _LARGE_NEGATIVE_VALUE = -1.e3

    def __init__(
        self,
        quant_config: dict,
        sparse_plan: dict,
        permute_plan: dict,
    ) -> None:
        super().__init__()

        self.quant_config = quant_config
        self.sparse_plan = sparse_plan
        self.permute_plan = permute_plan
        # NOTE: the config key is spelled "empty_head_procssor" (sic); kept as-is for compatibility.
        self.empty_head_processor = self.quant_config.attn.sparse.empty_head_procssor
        # online=True: derive the block mask from the current attention map instead of sparse_plan.
        self.online = quant_config.attn.sparse.get('online', False)
        # One dense rate (kept-block ratio, empty heads counted as dropped) per forward call.
        self.dense_rate_accumulator = []

    def get_sparse_mask(self, attn_map_image):
        '''
        Compute a per-block keep-mask from an image-only attention map.

        input:  attn_map_image [BS, N_head, N_mask_token, N_mask_token], where
                N_mask_token == self.N_block_sparse * self.block_sparse_size
                (both attributes are set by ``forward``).
        output: boolean mask [BS, N_head, N_block_sparse, N_block_sparse]; True marks
                a block to keep.
        Criterion (``quant_config.sparse_plan.sparse.sparse_type``):
          "mean":      block sum > (per-head mean block sum) / block_sum_k
          "threshold": block sum > sum_threshold
        Raises NotImplementedError for any other type.
        '''
        BS, head_per_split_num, _, _ = attn_map_image.shape
        sparse_cfg = self.quant_config.sparse_plan.sparse

        # [BS, H, Nb, bs, Nb, bs] -> [BS, H, Nb, Nb, bs, bs]: group tokens into blocks.
        attn_map_image_blocks = attn_map_image.reshape([
            BS, head_per_split_num, self.N_block_sparse, self.block_sparse_size, self.N_block_sparse, self.block_sparse_size
        ]).permute([0, 1, 2, 4, 3, 5])
        sparse_block_sum = attn_map_image_blocks.sum(dim=(-2, -1))  # [BS, H, Nb, Nb]

        if sparse_cfg.sparse_type == "mean":
            sparse_block_sum_mean = sparse_block_sum.mean(dim=(2, 3), keepdim=True)  # [BS, H, 1, 1]
            sparse_mask = sparse_block_sum > sparse_block_sum_mean / sparse_cfg.block_sum_k
        elif sparse_cfg.sparse_type == "threshold":
            sparse_mask = sparse_block_sum > sparse_cfg.sum_threshold
        else:
            raise NotImplementedError

        return sparse_mask

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        '''
        Apply the block-sparse mask, then the empty-head processing, to ``x`` in place.

        x: [BS, head_per_split_num, N_token, N_token]. The first
           ``quant_config.model.n_text_tokens`` rows/cols (text tokens) are untouched.
        Masked blocks are set to 0, or to ``_LARGE_NEGATIVE_VALUE`` when
        ``attn.sparse.pre_softmax`` is set. Empty heads are then overwritten
        ('uniform': each image row becomes its mean; 'zero': set to 0).
        Side effects: sets ``self.N_block_sparse`` / ``self.block_sparse_size`` and
        appends the dense rate to ``self.dense_rate_accumulator``. Returns ``x``.
        '''
        BS, head_per_split_num, N_token, _ = x.shape
        N_text_token = self.quant_config.model.n_text_tokens
        N_image_token = N_token - N_text_token
        # View into x: writes through it modify x in place.
        attn_map_image = x[:, :, N_text_token:, N_text_token:]

        assert isinstance(self.quant_config.calib_data.attn_ds_rate, int)

        # INFO: block geometry. "part": the block count comes from the plan; otherwise the block size is attn_ds_rate.
        if self.quant_config.calib_data.get("attn_ds_type", None) == "part":
            self.N_block_sparse = self.sparse_plan['sparse'].shape[-1]
            self.block_sparse_size = N_image_token // self.N_block_sparse
        else:
            self.block_sparse_size = self.quant_config.calib_data.attn_ds_rate
            self.N_block_sparse = N_image_token // self.block_sparse_size
        N_block_sparse = self.N_block_sparse
        block_size = self.block_sparse_size
        # When not divisible this is smaller than N_image_token; trailing tokens stay unmasked.
        N_mask_token = N_block_sparse * block_size
        attn_map_image_ = attn_map_image[:, :, :N_mask_token, :N_mask_token]

        # INFO: block-sparse masks, [BS, head_per_split_num, N_block_sparse, N_block_sparse] (local head order).
        if self.online:
            block_sparse_masks = self.get_sparse_mask(attn_map_image_)
        else:
            plan_sparse = self.sparse_plan['sparse']
            N_timestep_in_calib_data = plan_sparse.shape[0]
            i_timestep_in_calib_data = int(self.i_timestep // (1 / N_timestep_in_calib_data))
            # Clamp so i_timestep == 1.0 (or float round-up) cannot index past the last calibrated slot.
            i_timestep_in_calib_data = min(i_timestep_in_calib_data, N_timestep_in_calib_data - 1)
            head_indices = torch.arange(self.split_range[0], self.split_range[1], device=plan_sparse.device)
            block_sparse_masks = plan_sparse[i_timestep_in_calib_data, self.i_block, head_indices]
            block_sparse_masks = block_sparse_masks.to(x.device).unsqueeze(0).expand(BS, -1, -1, -1)

        pre_softmax = self.quant_config.attn.sparse.get("pre_softmax", False)
        # One local head at a time: the reshape below may copy, and all heads at
        # once would need a full-size temporary of the attention map.
        for head in range(head_per_split_num):
            block_sparse_mask = block_sparse_masks[:, head].reshape(
                [BS, N_block_sparse, 1, N_block_sparse, 1]
            )
            attn_map_image_head = attn_map_image_[:, head].reshape(
                [BS, N_block_sparse, block_size, N_block_sparse, block_size]
            )
            if pre_softmax:
                # Masked logits -> large negative value so softmax sends them to ~0.
                masked = attn_map_image_head.masked_fill(block_sparse_mask == 0, self._LARGE_NEGATIVE_VALUE)
            else:
                masked = attn_map_image_head * block_sparse_mask
            attn_map_image[:, head, :N_mask_token, :N_mask_token] = masked.reshape(
                [BS, N_mask_token, N_mask_token]
            )

        # INFO: empty heads; this overwrites the block-sparse result for those heads.
        head_flags = torch.as_tensor(self.permute_plan['empty'][self.i_block])
        split_indices = torch.arange(self.split_range[0], self.split_range[1], device=head_flags.device)
        if_head_empty = head_flags[split_indices]  # one flag per local head
        empty_indices = torch.nonzero(if_head_empty).squeeze(-1).to(x.device)

        if empty_indices.numel() > 0:
            if self.empty_head_processor == 'uniform':
                # Mean of each row as it is now, i.e. after the block-sparse masking above.
                empty_heads = attn_map_image.index_select(dim=1, index=empty_indices)
                x[:, empty_indices, N_text_token:, N_text_token:] = (
                    empty_heads.mean(-1, keepdim=True).expand(-1, -1, -1, N_image_token)
                )
            elif self.empty_head_processor == 'zero':
                # Assign through indexing: a tensor index yields a copy, so `.fill_` on it would not touch x.
                x[:, empty_indices, N_text_token:, N_text_token:] = 0.
            else:
                raise NotImplementedError

        # Dense rate: fraction of kept blocks in this split, counting empty heads as fully dropped.
        keep_heads = torch.logical_not(if_head_empty.bool()).to(block_sparse_masks.device)
        block_sparse_mask_with_empty = block_sparse_masks * keep_heads.reshape([1, head_per_split_num, 1, 1])
        dense_rate_with_empty = (block_sparse_mask_with_empty.sum() / block_sparse_mask_with_empty.numel()).item()
        self.dense_rate_accumulator.append(dense_rate_with_empty)

        return x

    # copied from customize_cogvideox_attn_processor.py SaveActivationHook.attn_map_downsample
    def attn_map_downsample(self, data):
        '''
        Debug helper: average-pool the image-image part of an attention map in
        ``attn_ds_rate`` x ``attn_ds_rate`` blocks, dropping trailing tokens when
        N_image_token is not divisible.

        Requires the caller to set ``self.type = 'attn'`` first.
        input:  data [BS, N_head, N_token, N_token]
        output: [BS, N_head, N_image_token // attn_ds_rate, N_image_token // attn_ds_rate]
        '''
        assert self.type == 'attn'
        BS, head_per_split_num, N_token, _ = data.shape
        self.attn_ds_rate = self.quant_config.calib_data.attn_ds_rate

        # INFO: plain block downsample of the last two dims. This may make a dim vanish when the
        # ds size exceeds F, H or W; after permute the data is assumed to be locally aggregated,
        # and plain blocks suit efficient kernels better.
        N_text_token = self.quant_config.model.n_text_tokens
        N_image_token = N_token - N_text_token
        data = data[:, :, N_text_token:, N_text_token:]

        N_remainder = N_image_token % self.attn_ds_rate
        if N_remainder != 0:
            data = data[:, :, :-N_remainder, :-N_remainder]
        data = data.reshape([
            BS, head_per_split_num, N_image_token // self.attn_ds_rate, self.attn_ds_rate, N_image_token // self.attn_ds_rate, self.attn_ds_rate
        ])
        # mean over dim 3 shifts the second block dim from 5 to 4.
        return data.mean(dim=3).mean(dim=4)
        