# This code is taken from diffusers. All our modifications are noted within ######
import inspect
import math
from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn

from scipy.ndimage import binary_fill_holes

from diffusers.image_processor import IPAdapterMaskProcessor
from diffusers.utils import deprecate, is_torch_xla_available, logging
from diffusers.utils.import_utils import is_torch_npu_available, is_torch_xla_version, is_xformers_available
from diffusers.utils.torch_utils import is_torch_version, maybe_allow_in_graph
from diffusers.models.attention_processor import Attention


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

if is_torch_npu_available():
    import torch_npu

if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None

if is_torch_xla_available():
    # flash attention pallas kernel is introduced in the torch_xla 2.3 release.
    if is_torch_xla_version(">", "2.2"):
        from torch_xla.experimental.custom_kernel import flash_attention
        from torch_xla.runtime import is_spmd
    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


def mask_attn(attn_weight, src_indices, tgt_indices):
    """Forbid attention from ``src_indices`` (query rows) to ``tgt_indices`` (key columns).

    Writes ``-inf`` in place into ``attn_weight[:, :, i, j]`` for every pair
    ``i in src_indices`` and ``j in tgt_indices``, for all leading batch/head
    entries, so those logits vanish after a softmax. Returns None.
    """
    # Column vector of rows x row vector of columns: broadcast advanced indexing
    # addresses the whole (len(src) x len(tgt)) block, same cells as an 'ij' meshgrid.
    rows = src_indices.unsqueeze(-1)
    cols = tgt_indices.unsqueeze(0)
    attn_weight[:, :, rows, cols] = float('-inf')


class ConstrainedJointAttnProcessor2_0:
    """Attention processor used typically in processing the SD3-like self-attention projections.

    ###### Modified from diffusers' ``JointAttnProcessor2_0``: attention is computed
    explicitly (``constrained_scaled_dot_product_attention``) so that, when a
    bounding box is set in ``self.bb``, attention between text tokens and image
    tokens outside the box can be masked out. Optionally, a segmentation mask is
    extracted from one head's text-to-image attention and stored in
    ``self.bb['mask']``. ######
    """

    def __init__(self, block_id):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("JointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.block_id = block_id
        ######
        # Bounding-box / segmentation configuration dict; None disables all constraints.
        # Keys read: 'x_min', 'x_max', 'y_min', 'y_max' (inclusive latent-grid coordinates),
        # and, when save_maps is used, 'blur_kernel' (int) and 'segmentation_dict' of the
        # form {'head': int, 'block': int, 'thresh': float in (0, 1)}.
        # The extracted mask is written back to bb['mask'].
        self.bb = None
        # One-shot flag: set right before a forward pass to extract a mask; reset after use.
        self.save_maps = False
        # Fixed layout constants. The joint sequence is [height*width image tokens, text tokens];
        # token_start/token_end select which text tokens are averaged into the map.
        self.token_start = 0
        self.token_end = 333
        self.len_text_embedding = 333
        self.height = 32
        self.width = 32
        ######

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        """Joint self-attention over sample tokens and (optional) context tokens.

        Projects ``hidden_states`` and, if given, ``encoder_hidden_states``, then
        concatenates them along the sequence axis as [sample, context]. It runs
        ``constrained_scaled_dot_product_attention``, splits the result back, and
        applies the output projections. ``attention_mask`` is accepted for
        interface compatibility but is not used.

        Returns:
            ``hidden_states``, or ``(hidden_states, encoder_hidden_states)`` when
            context tokens were provided.
        """
        residual = hidden_states

        batch_size = hidden_states.shape[0]

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # `context` projections.
        if encoder_hidden_states is not None:
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # Sample tokens first, context tokens after: the bbox logic relies on this order.
            query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
            key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
            value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)

        ######
        hidden_states = self.constrained_scaled_dot_product_attention(
            query, key, value, dropout_p=0.0, is_causal=False
            )
        ######
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            # Split the attention outputs.
            hidden_states, encoder_hidden_states = (
                hidden_states[:, : residual.shape[1]],
                hidden_states[:, residual.shape[1] :],
            )
            if not attn.context_pre_only:
                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if encoder_hidden_states is not None:
            return hidden_states, encoder_hidden_states
        else:
            return hidden_states

    def constrained_scaled_dot_product_attention(self, query, key, value, attn_mask=None, dropout_p=0.0,
            is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
        """Explicit scaled dot-product attention with optional bounding-box constraints.

        Follows the reference implementation given in the PyTorch documentation of
        ``F.scaled_dot_product_attention`` (same arguments). ###### If ``self.bb`` is set,
        it additionally masks attention logits so that:
          1) image tokens outside the box do not attend to text tokens,
          2) text tokens do not attend to image tokens outside the box,
          3) image tokens inside the box do not attend to image tokens outside it.
        If ``self.save_maps`` is set and ``self.bb['segmentation_dict']['block']`` equals
        ``self.block_id``, a segmentation mask is computed from the post-softmax weights
        and stored in ``self.bb['mask']``. ``save_maps`` is then reset. ######

        Returns:
            ``softmax(query @ key^T * scale + bias) @ value``.

        Raises:
            ValueError: if ``self.bb`` is set and the sequence length is not
                ``len_text_embedding + height * width``, or if ``save_maps`` is set
                while ``self.bb`` is None.
        """
        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
        attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
        if is_causal:
            assert attn_mask is None
            # Create the mask on attn_bias's device; masked_fill_ requires matching devices.
            temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias = attn_mask + attn_bias

        if enable_gqa:
            key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
            value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)

        attn_weight = query @ key.transpose(-2, -1) * scale_factor
        attn_weight += attn_bias
        ######
        outside_bbox_indices = None
        if self.bb is not None:
            total_len = attn_weight.shape[-2]
            expected_len = self.len_text_embedding + self.width * self.height
            if total_len != expected_len:
                raise ValueError(
                    f"Expected joint sequence length {expected_len} "
                    f"(len_text_embedding + height * width), got {total_len}."
                )
            bbox_img_indices, outside_bbox_indices, text_indices = self._bbox_partition(attn_weight.device)

            # 1) From image outside bbox to text: set to -inf
            mask_attn(attn_weight, outside_bbox_indices, text_indices)
            # 2) From text to image outside bbox: set to -inf
            mask_attn(attn_weight, text_indices, outside_bbox_indices)
            # 3) From image inside bbox to image outside bbox: set to -inf
            mask_attn(attn_weight, bbox_img_indices, outside_bbox_indices)
        ######
        attn_weight = torch.softmax(attn_weight, dim=-1)
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
        ######
        # attn_weight: (batch_size, num_heads, seq_len_q, seq_len_kv); the text tokens come
        # AFTER the height*width image tokens in the joint sequence.
        if self.save_maps:
            if self.bb is None:
                raise ValueError("save_maps is set but self.bb is None; set self.bb before extracting maps.")
            if self.bb['segmentation_dict']['block'] == self.block_id:
                self.bb['mask'] = self._segmentation_mask(attn_weight, outside_bbox_indices)
            self.save_maps = False
        ######
        return attn_weight @ value

    ######
    def _bbox_partition(self, device):
        """Split token positions into (inside-bbox image, outside-bbox image, text) index tensors.

        Image token ``(y, x)`` has flat index ``y * width + x``, and the text tokens follow
        the ``height * width`` image tokens. Box coordinates are inclusive and are clipped
        to the latent grid, so an out-of-range box cannot wrap into the next row or into
        text positions. All returned tensors are int64 on ``device``.
        """
        num_img = self.height * self.width
        y_lo, y_hi = max(self.bb['y_min'], 0), min(self.bb['y_max'], self.height - 1)
        x_lo, x_hi = max(self.bb['x_min'], 0), min(self.bb['x_max'], self.width - 1)
        inside = torch.zeros(self.height, self.width, dtype=torch.bool, device=device)
        # Guard keeps slice bounds non-negative (a negative stop would select from the end).
        if y_lo <= y_hi and x_lo <= x_hi:
            inside[y_lo:y_hi + 1, x_lo:x_hi + 1] = True
        inside = inside.flatten()
        bbox_img_indices = torch.nonzero(inside, as_tuple=True)[0]
        outside_bbox_indices = torch.nonzero(~inside, as_tuple=True)[0]
        text_indices = torch.arange(num_img, num_img + self.len_text_embedding, dtype=torch.long, device=device)
        return bbox_img_indices, outside_bbox_indices, text_indices

    def _segmentation_mask(self, attn_weight, outside_bbox_indices):
        """Build a flat boolean (height * width) mask from one head's text-to-image attention.

        Averages the attention rows of text tokens ``token_start..token_end`` (first batch
        item, head ``segmentation_dict['head']``) over the image keys. It dilates the
        result with a ``blur_kernel`` max-pool and keeps the top ``thresh`` share of
        attention mass. Positions in ``outside_bbox_indices`` are then cleared.
        """
        seg = self.bb['segmentation_dict']
        num_img = self.height * self.width
        # Only batch item 0 is used; todo: support batching.
        token_maps = attn_weight[0, seg['head'], num_img + self.token_start:num_img + self.token_end, :num_img]
        mean_attn_map = token_maps.mean(dim=0).float()
        # Flat index is y * width + x, so the grid is (height, width).
        mean_attn_map = mean_attn_map.reshape(1, 1, self.height, self.width)
        kernel = self.bb['blur_kernel']
        # Max-pool dilation (not a smoothing blur); crop so even kernels keep the (H, W) size.
        dilated = F.max_pool2d(mean_attn_map, kernel_size=kernel, stride=1, padding=kernel // 2)
        dilated = dilated[..., :self.height, :self.width]
        top_attention = self.top_percent_attention_mask(dilated, seg['thresh']).flatten()
        top_attention[outside_bbox_indices] = False
        return top_attention

    def top_percent_attention_mask(self, attn_map: torch.Tensor, thr: float) -> torch.Tensor:
        """Select the largest entries of ``attn_map`` that together hold ``thr`` of its total mass.

        Entries are ranked by value (descending) and included until their cumulative sum
        reaches ``thr * total``. The entry that crosses the threshold is included. The
        returned boolean mask has the shape of ``attn_map``. This assumes non-negative
        values such as attention probabilities (needed for a monotonic cumsum); confirm
        before using it on other inputs.
        """
        flat = attn_map.reshape(-1)
        sorted_vals, order = torch.sort(flat, descending=True)

        # Cumulative attention mass of the top-k entries, for every k.
        cumsum = torch.cumsum(sorted_vals, dim=0)
        cutoff_idx = int(torch.searchsorted(cumsum, thr * cumsum[-1]))

        mask = torch.zeros_like(flat, dtype=torch.bool)
        mask[order[:cutoff_idx + 1]] = True  # include up to and including the cutoff
        return mask.view_as(attn_map)
    ######
