# This code is taken from diffusers. All our modifications are enclosed between `######` marker lines.
import torch
import functools
import math
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.attention import AttentionMixin, FeedForward
from diffusers.models.attention_dispatch import dispatch_attention_fn
from diffusers.models.attention_processor import Attention
from diffusers.models.cache_utils import CacheMixin
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm
from diffusers.models.transformers.transformer_qwenimage import apply_rotary_emb_qwen


def mask_attn(attn_weight: torch.Tensor, src_indices: torch.Tensor, tgt_indices: torch.Tensor) -> None:
    """Block attention from every source position to every target position, in place.

    The Cartesian product of ``src_indices`` (query rows, third dimension of ``attn_weight``) and
    ``tgt_indices`` (key columns, fourth dimension) is overwritten with ``-inf`` across the two leading
    dimensions, so those pairs receive zero probability once a softmax is applied over the last dimension.

    Args:
        attn_weight: Pre-softmax attention logits; modified in place.
        src_indices: Index tensor of query positions.
        tgt_indices: Index tensor of key positions.
    """
    # 'ij' indexing: rows follow the sources, columns follow the targets.
    query_rows, key_cols = torch.meshgrid(src_indices, tgt_indices, indexing='ij')
    blocked_value = float('-inf')
    attn_weight[:, :, query_rows, key_cols] = blocked_value


class ConstrainedQwenDoubleStreamAttnProcessor2_0:
    """
    Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor
    implements joint attention computation where text and image streams are processed together.

    On top of the joint attention, the processor can
      * confine attention to a bounding box stored in ``self.bb`` (see ``constrained_scaled_dot_product_attention``);
      * extract a boolean segmentation mask from the attention probabilities of one head of one block and
        store it in ``self.bb['mask']`` (armed by setting ``self.save_maps = True``).

    Keys read from ``self.bb`` (a dict, or ``None`` to disable the constraint):
      * ``x_min``, ``x_max``, ``y_min``, ``y_max``: inclusive box bounds on the ``height`` x ``width`` token grid
        (image tokens are addressed row-major as ``y * width + x``).
      * ``end_token``: exclusive end of the query rows averaged for the segmentation map.
      * ``blur_kernel``: kernel size of the max-pool applied to the averaged map.
      * ``segmentation_dict``: ``{'head': int, 'block': int, 'thresh': float (0,1)}``.
    """

    _attention_backend = None

    def __init__(self, block_id):
        """Create a processor tagged with ``block_id`` (compared against ``segmentation_dict['block']``)."""
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "QwenDoubleStreamAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )
        self.block_id = block_id
        ###### bounding box information
        self.bb = None
        ### segmentation map information. Expects self.bb to have dictionary 'segmentation_dict' in form {'head': int, 'block': int, 'thresh': float (0,1)}
        # should be set right before using
        self.save_maps = False
        # should stay fixed
        self.token_start = 0
        self.len_text_embedding = None
        self.height = 32
        self.width = 32
        ######

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.FloatTensor,  # Image stream
        encoder_hidden_states: torch.FloatTensor = None,  # Text stream
        encoder_hidden_states_mask: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Run joint text+image attention and return the two projected streams.

        Args:
            attn: Attention module supplying the projections (``to_q``/``to_k``/``to_v``, ``add_*_proj``),
                the optional QK norms, ``heads`` and the output projections.
            hidden_states: Image stream; the sequence axis is dim 1 (assumed ``(batch, seq, dim)``).
            encoder_hidden_states: Text stream, same layout. Required.
            encoder_hidden_states_mask: Accepted for interface compatibility; not used by this processor.
            attention_mask: Optional mask forwarded to ``constrained_scaled_dot_product_attention``.
            image_rotary_emb: Optional ``(img_freqs, txt_freqs)`` pair passed to ``apply_rotary_emb_qwen``.

        Returns:
            ``(img_attn_output, txt_attn_output)``.

        Raises:
            ValueError: If ``encoder_hidden_states`` is ``None``.

        Side effect: ``self.len_text_embedding`` is set to the text sequence length of this call.
        """
        if encoder_hidden_states is None:
            raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")

        seq_txt = encoder_hidden_states.shape[1]
        # Remembered so the constrained attention knows where image tokens start in the joint sequence.
        self.len_text_embedding = seq_txt

        # Compute QKV for image stream (sample projections)
        img_query = attn.to_q(hidden_states)
        img_key = attn.to_k(hidden_states)
        img_value = attn.to_v(hidden_states)

        # Compute QKV for text stream (context projections)
        txt_query = attn.add_q_proj(encoder_hidden_states)
        txt_key = attn.add_k_proj(encoder_hidden_states)
        txt_value = attn.add_v_proj(encoder_hidden_states)

        # Reshape for multi-head attention
        img_query = img_query.unflatten(-1, (attn.heads, -1))
        img_key = img_key.unflatten(-1, (attn.heads, -1))
        img_value = img_value.unflatten(-1, (attn.heads, -1))

        txt_query = txt_query.unflatten(-1, (attn.heads, -1))
        txt_key = txt_key.unflatten(-1, (attn.heads, -1))
        txt_value = txt_value.unflatten(-1, (attn.heads, -1))

        # Apply QK normalization
        if attn.norm_q is not None:
            img_query = attn.norm_q(img_query)
        if attn.norm_k is not None:
            img_key = attn.norm_k(img_key)
        if attn.norm_added_q is not None:
            txt_query = attn.norm_added_q(txt_query)
        if attn.norm_added_k is not None:
            txt_key = attn.norm_added_k(txt_key)

        # Apply RoPE
        if image_rotary_emb is not None:
            img_freqs, txt_freqs = image_rotary_emb
            img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
            img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
            txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
            txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)

        # Concatenate for joint attention
        # Order: [text, image]
        joint_query = torch.cat([txt_query, img_query], dim=1)
        joint_key = torch.cat([txt_key, img_key], dim=1)
        joint_value = torch.cat([txt_value, img_value], dim=1)

        ######
        # Move the head axis in front of the sequence axis for the explicit attention below.
        joint_query = joint_query.permute(0, 2, 1, 3)
        joint_key = joint_key.permute(0, 2, 1, 3)
        joint_value = joint_value.permute(0, 2, 1, 3)
        joint_hidden_states = self.constrained_scaled_dot_product_attention(
            joint_query, joint_key, joint_value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
        )
        joint_hidden_states = joint_hidden_states.permute(0, 2, 1, 3)
        ######

        # Reshape back
        joint_hidden_states = joint_hidden_states.flatten(2, 3)
        joint_hidden_states = joint_hidden_states.to(joint_query.dtype)

        # Split attention outputs back
        txt_attn_output = joint_hidden_states[:, :seq_txt, :]  # Text part
        img_attn_output = joint_hidden_states[:, seq_txt:, :]  # Image part

        # Apply output projections
        img_attn_output = attn.to_out[0](img_attn_output)
        if len(attn.to_out) > 1:
            img_attn_output = attn.to_out[1](img_attn_output)  # dropout

        txt_attn_output = attn.to_add_out(txt_attn_output)

        return img_attn_output, txt_attn_output

    def constrained_scaled_dot_product_attention(self, query, key, value, attn_mask=None, dropout_p=0.0,
            is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
        """Explicit scaled dot-product attention with optional bounding-box constraint.

        Args:
            query, key, value: Tensors whose last two axes are (sequence, head_dim); the axis before
                those is the head axis (used by ``enable_gqa``).
            attn_mask: Optional mask. A boolean mask keeps positions that are ``True``; any other dtype
                is added to the logits. Must broadcast against the attention logits.
            dropout_p: Dropout probability applied to the attention probabilities.
            is_causal: Apply a lower-triangular mask; mutually exclusive with ``attn_mask``.
            scale: Logit scale; defaults to ``1 / sqrt(head_dim)``.
            enable_gqa: Repeat key/value heads to match the number of query heads.

        Returns:
            ``softmax(logits) @ value``.

        Raises:
            ValueError: If ``self.save_maps`` is set while ``self.bb`` is ``None``, or the token range
                used for the segmentation map is empty or out of bounds.

        Side effects: when ``self.save_maps`` is set, it is reset to ``False`` and, if
        ``segmentation_dict['block']`` equals ``self.block_id``, ``self.bb['mask']`` is overwritten.
        """
        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
        attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
        if is_causal:
            assert attn_mask is None
            # Built on the query's device so the fill below also works off-CPU.
            temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                # Materialise the additive bias at the mask's own shape so masks that carry
                # batch/head dimensions broadcast instead of failing an in-place fill on (L, S).
                bool_bias = torch.zeros(attn_mask.shape, dtype=query.dtype, device=query.device)
                bool_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
                attn_bias = attn_bias + bool_bias
            else:
                attn_bias = attn_mask + attn_bias

        if enable_gqa:
            key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
            value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)

        attn_weight = query @ key.transpose(-2, -1) * scale_factor
        attn_weight += attn_bias
        ######
        inside_bbox = None
        if self.bb is not None:
            inside_bbox = self._bbox_inside_mask(attn_weight.device)
            self._apply_bbox_constraint(attn_weight, inside_bbox)
        ######
        attn_weight = torch.softmax(attn_weight, dim=-1)
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
        ######
        # (batch_size, num_heads, seq_len_q, seq_len_kv) --> text embeds at beginning of seq
        # save cutout maps
        if self.save_maps:
            if self.bb is None:
                raise ValueError("save_maps requires self.bb (with 'segmentation_dict') to be set")
            target_head_id = self.bb['segmentation_dict']['head']
            target_block_id = self.bb['segmentation_dict']['block']
            target_thresh = self.bb['segmentation_dict']['thresh']
            if target_block_id == self.block_id:
                self.bb['mask'] = self._extract_segmentation_mask(
                    attn_weight, inside_bbox, target_head_id, target_thresh
                )
            # One-shot flag: cleared even when this block is not the target block.
            self.save_maps = False
        ######
        return attn_weight @ value

    ######
    def _bbox_inside_mask(self, device) -> torch.Tensor:
        """Return a flat ``(height * width,)`` bool tensor, True for image tokens inside ``self.bb``."""
        inside = torch.zeros(self.height, self.width, dtype=torch.bool, device=device)
        # Clamp the inclusive box to the grid so an oversized box cannot wrap into neighbouring rows.
        y_lo, y_hi = max(self.bb['y_min'], 0), min(self.bb['y_max'], self.height - 1)
        x_lo, x_hi = max(self.bb['x_min'], 0), min(self.bb['x_max'], self.width - 1)
        if y_lo <= y_hi and x_lo <= x_hi:
            inside[y_lo:y_hi + 1, x_lo:x_hi + 1] = True
        # Row-major flattening matches the ``y * width + x`` token order.
        return inside.flatten()

    def _apply_bbox_constraint(self, attn_weight: torch.Tensor, inside_bbox: torch.Tensor) -> None:
        """Set the logits of forbidden (query, key) pairs to ``-inf`` in place.

        Forbidden pairs, with "outside" meaning image tokens outside the bounding box:
          1) outside image query -> text key
          2) text query -> outside image key
          3) inside image query -> outside image key
        """
        n_txt = self.len_text_embedding
        total_len = attn_weight.shape[-2]
        assert total_len == n_txt + self.width * self.height

        device = attn_weight.device
        is_text = torch.zeros(total_len, dtype=torch.bool, device=device)
        is_text[:n_txt] = True
        is_outside = torch.zeros(total_len, dtype=torch.bool, device=device)
        is_outside[n_txt:] = ~inside_bbox

        # Rule 1 for outside queries; rules 2 and 3 share the same key set, and their queries
        # (text or inside-box) are exactly the tokens that are not "outside".
        blocked = (is_outside[:, None] & is_text[None, :]) | (~is_outside[:, None] & is_outside[None, :])
        attn_weight.masked_fill_(blocked, float('-inf'))

    def _extract_segmentation_mask(self, attn_weight: torch.Tensor, inside_bbox: torch.Tensor,
                                   head_id: int, thresh: float) -> torch.Tensor:
        """Build the flat boolean segmentation mask from post-softmax attention of one head.

        Averages the attention that query rows ``[token_start, end_token)`` pay to the image tokens,
        max-pools the resulting ``height`` x ``width`` map, keeps the top ``thresh`` share of attention
        mass and clears every position outside the bounding box.
        """
        n_txt = self.len_text_embedding
        first, last = self.token_start, self.bb['end_token']
        if not 0 <= first < last <= attn_weight.shape[-2]:
            raise ValueError(
                f"invalid token range [{first}, {last}) for attention with {attn_weight.shape[-2]} query rows"
            )
        # todo maybe this can be fixed to support batching (only batch element 0 is used)
        token_rows = attn_weight[0, head_id, first:last, n_txt:]
        # average the maps
        mean_attn_map = token_rows.mean(dim=0).float()
        # Use the configured grid size instead of a fixed 32x32.
        mean_attn_map = mean_attn_map.reshape(1, 1, self.height, self.width)
        # Max-pool "blur": dilates high-attention regions.
        # NOTE(review): an even blur_kernel yields a (height + 1) x (width + 1) map with this padding,
        # so the stored mask is then longer than height * width — use odd kernels.
        kernel = self.bb['blur_kernel']
        blurred = F.max_pool2d(mean_attn_map, kernel_size=kernel, stride=1, padding=kernel // 2)
        # take the top attention
        top_attention = self.top_percent_attention_mask(blurred, thresh).flatten()
        # Clear everything outside the box in one indexed write (no per-element host sync).
        outside_positions = torch.nonzero(~inside_bbox, as_tuple=True)[0]
        top_attention[outside_positions] = False
        return top_attention

    def top_percent_attention_mask(self, attn_map: torch.Tensor, thr: float) -> torch.Tensor:
        """Return a bool mask (same shape as ``attn_map``) marking the largest values whose cumulative
        sum first reaches ``thr`` times the total sum; at least one element is always selected."""
        # Flatten and sort the attention values descending
        flat = attn_map.reshape(-1)
        sorted_vals, indices = torch.sort(flat, descending=True)

        # Compute cumulative sum (i.e., attention mass)
        cumsum = torch.cumsum(sorted_vals, dim=0)
        total = cumsum[-1]

        # Find how many elements to include to reach the threshold % of total mass
        cutoff_idx = int(torch.searchsorted(cumsum, thr * total))

        # Create a binary mask of the same shape as the original attention map
        mask = torch.zeros_like(flat, dtype=torch.bool)
        mask[indices[:cutoff_idx + 1]] = True  # include up to the cutoff

        return mask.view_as(attn_map)
    ######