import math
from typing import Callable, Dict, List, Optional, Union

import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import Attention
from diffusers.models.unets import UNet2DConditionModel
from diffusers.utils import deprecate, logging
from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
from einops import rearrange, repeat
from torch import nn


def default_set_attn_proc_func(
    name: str,
    hidden_size: int,
    cross_attention_dim: Optional[int],
    ori_attn_proc: object,
) -> object:
    """Identity processor hook: keep the module's current attention processor.

    This is the default for the ``set_*_attn_proc_func`` hooks of
    ``set_unet_2d_condition_attn_processor``. The ``name``, ``hidden_size`` and
    ``cross_attention_dim`` arguments exist only to match that hook signature
    and are ignored.

    Returns:
        ``ori_attn_proc``, unchanged.
    """
    unchanged_processor = ori_attn_proc
    return unchanged_processor


def set_unet_2d_condition_attn_processor(
    unet: UNet2DConditionModel,
    set_self_attn_proc_func: Callable = default_set_attn_proc_func,
    set_cross_attn_proc_func: Callable = default_set_attn_proc_func,
    set_custom_attn_proc_func: Callable = default_set_attn_proc_func,
    set_self_attn_module_names: Optional[List[str]] = None,
    set_cross_attn_module_names: Optional[List[str]] = None,
    set_custom_attn_module_names: Optional[List[str]] = None,
) -> None:
    """Rebuild every attention processor of ``unet`` through per-category hooks.

    Each entry of ``unet.attn_processors`` is classified as:

    * custom: the name contains ``"attn_mid_blocks"`` or ``"attn_post_blocks"``;
    * self-attention: the name ends with ``"attn1.processor"``, contains
      ``"motion_modules"``, or ``unet.config.cross_attention_dim`` is ``None``;
    * cross-attention: everything else.

    For a selected processor, the matching hook is called as
    ``hook(name, hidden_size, cross_attention_dim, original_processor)`` and its
    result becomes the new processor. Custom hooks always receive ``None`` as
    ``cross_attention_dim``. ``hidden_size`` is taken from
    ``unet.config.block_out_channels`` for the block the processor lives in.

    Args:
        unet: UNet whose processors are replaced via ``unet.set_attn_processor``.
        set_self_attn_proc_func: hook for self-attention processors.
        set_cross_attn_proc_func: hook for cross-attention processors.
        set_custom_attn_proc_func: hook for custom processors.
        set_self_attn_module_names: name prefixes selecting which self-attention
            processors to rebuild; ``None`` selects all.
        set_cross_attn_module_names: same, for cross-attention processors.
        set_custom_attn_module_names: same, for custom processors.

    Unselected processors are kept as they are.
    """

    def _selected(name: str, module_names: Optional[List[str]]) -> bool:
        # ``None`` selects every processor; otherwise match on name prefix.
        if module_names is None:
            return True
        return any(name.startswith(prefix) for prefix in module_names)

    block_out_channels = unet.config.block_out_channels
    attn_procs = {}
    for name, attn_processor in unet.attn_processors.items():
        if name.startswith("mid_block"):
            hidden_size = block_out_channels[-1]
        elif name.startswith("up_blocks"):
            # Parse the whole index segment; reading a single character would
            # misparse block indices >= 10.
            block_id = int(name.split(".")[1])
            hidden_size = list(reversed(block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name.split(".")[1])
            hidden_size = block_out_channels[block_id]
        # NOTE(review): any other prefix reuses ``hidden_size`` from the previous
        # iteration (or raises NameError on the first one) - confirm intended.

        if "attn_mid_blocks" in name or "attn_post_blocks" in name:
            attn_procs[name] = (
                set_custom_attn_proc_func(name, hidden_size, None, attn_processor)
                if _selected(name, set_custom_attn_module_names)
                else attn_processor
            )
            continue

        cross_attention_dim = (
            None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        )
        if cross_attention_dim is None or "motion_modules" in name:
            # self attention
            proc_func, module_names = set_self_attn_proc_func, set_self_attn_module_names
        else:
            # cross attention
            proc_func, module_names = set_cross_attn_proc_func, set_cross_attn_module_names

        attn_procs[name] = (
            proc_func(name, hidden_size, cross_attention_dim, attn_processor)
            if _selected(name, module_names)
            else attn_processor
        )

    unet.set_attn_processor(attn_procs)


class Jigsaw3DAttnProcessor(torch.nn.Module):
    r"""
    Attention processor for decoupled row-wise multi-view self-attention and
    reference-image cross-attention, built on PyTorch 2.0
    ``F.scaled_dot_product_attention``. (Formerly named ``MVAdapterAttnProcessor``.)

    The wrapped ``Attention`` module's own attention always runs. Two optional
    branches, each with its own projections, can be added to its output:

    * multi-view (``use_mv``): the batch axis holds ``num_views`` consecutive
      views per sample. A token in row ``ih`` of one view attends to row ``ih``
      of every view of the same sample. Weighted by ``mv_scale``.
    * reference (``use_ref``): tokens attend to ``ref_hidden_states[self.name]``.
      Weighted by ``ref_scale``.

    Args:
        query_dim: feature size of the incoming hidden states.
        inner_dim: output size of the decoupled q/k/v projections.
        num_views: number of views per sample packed along the batch axis.
        name: processor name; the key used in ``cache_hidden_states`` and
            ``ref_hidden_states``.
        use_mv: create the multi-view branch layers (``to_*_mv``).
        use_ref: create the reference branch layers (``to_*_ref``).
        is_ref_branch: stored on the instance; not read by ``__call__``.
    """

    def __init__(
        self,
        query_dim: int,
        inner_dim: int,
        num_views: int = 1,
        name: Optional[str] = None,
        use_mv: bool = True,
        use_ref: bool = False,
        is_ref_branch: bool = False,
    ):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "Jigsaw3DAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

        super().__init__()

        self.num_views = num_views
        self.name = name  # key into the cache/reference hidden-state dicts
        self.use_mv = use_mv
        self.use_ref = use_ref
        self.is_ref_branch = is_ref_branch

        # Layer attribute names and creation order are kept stable so that
        # state_dict keys and parameter initialisation stay unchanged.
        if self.use_mv:
            self.to_q_mv = nn.Linear(
                in_features=query_dim, out_features=inner_dim, bias=False
            )
            self.to_k_mv = nn.Linear(
                in_features=query_dim, out_features=inner_dim, bias=False
            )
            self.to_v_mv = nn.Linear(
                in_features=query_dim, out_features=inner_dim, bias=False
            )
            self.to_out_mv = nn.ModuleList(
                [
                    nn.Linear(in_features=inner_dim, out_features=query_dim, bias=True),
                    nn.Dropout(0.0),
                ]
            )

        if self.use_ref:
            self.to_q_ref = nn.Linear(
                in_features=query_dim, out_features=inner_dim, bias=False
            )
            self.to_k_ref = nn.Linear(
                in_features=query_dim, out_features=inner_dim, bias=False
            )
            self.to_v_ref = nn.Linear(
                in_features=query_dim, out_features=inner_dim, bias=False
            )
            self.to_out_ref = nn.ModuleList(
                [
                    nn.Linear(in_features=inner_dim, out_features=query_dim, bias=True),
                    nn.Dropout(0.0),
                ]
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        mv_scale: float = 1.0,
        ref_hidden_states: Optional[Dict[str, torch.FloatTensor]] = None,
        ref_scale: float = 1.0,
        cache_hidden_states: Optional[Dict[str, torch.FloatTensor]] = None,
        use_mv: bool = True,
        use_ref: bool = True,
        is_ref_branch: bool = False,
        num_views: Optional[int] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        """
        Run the wrapped attention plus the enabled decoupled branches.

        New args:
            mv_scale (float): weight of the multi-view self-attention output.
            ref_hidden_states (Dict[str, torch.FloatTensor]): hidden states from
                the reference unet, looked up by ``self.name``. Required when the
                reference branch is active.
            ref_scale (float): weight of the image cross-attention output.
            cache_hidden_states (Dict[str, torch.FloatTensor]): if given, a clone
                of the incoming ``hidden_states`` is stored under ``self.name``
                (used to cache hidden states from the reference unet).
            use_mv / use_ref (bool): per-call switches, ANDed with the flags
                given at construction.
            is_ref_branch (bool): accepted for interface compatibility; unused.
            num_views (int): if given, overwrites ``self.num_views``. The new
                value persists across calls.

        Returns:
            4D inputs come back as ``(batch, channel, height, width)``; 3D inputs
            as ``(batch, tokens, features)``.

        Raises:
            ValueError: the reference branch is active but ``ref_hidden_states``
                is None, or a 3D input's token count is not a perfect square
                while the multi-view branch is active.
        """
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        if num_views is not None:
            self.num_views = num_views

        # Record this layer's raw input in the caller-supplied cache.
        if cache_hidden_states is not None:
            cache_hidden_states[self.name] = hidden_states.clone()

        # A branch runs only if it is enabled at construction AND for this call.
        use_mv = self.use_mv and use_mv
        use_ref = self.use_ref and use_ref
        if use_ref and ref_hidden_states is None:
            raise ValueError(
                "Jigsaw3DAttnProcessor: `ref_hidden_states` must be provided "
                "when the reference branch is enabled."
            )

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        # Query token count; the multi-view branch lays these out on a 2D grid.
        num_query_tokens = hidden_states.shape[1]

        query = attn.to_q(hidden_states)
        # The decoupled branches project the same input with their own weights.
        query_mv = self.to_q_mv(hidden_states) if use_mv else None
        query_ref = self.to_q_ref(hidden_states) if use_ref else None

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        if use_mv:
            # Dedicated names: `height`/`width` are still needed for the final
            # 4D reshape and must not be overwritten here.
            if input_ndim == 4:
                # The spatial size is known exactly, which also covers non-square inputs.
                grid_height, grid_width = height, width
            else:
                # A flat token sequence is assumed to come from a square grid.
                grid_height = grid_width = math.isqrt(num_query_tokens)
                if grid_height * grid_width != num_query_tokens:
                    raise ValueError(
                        "Jigsaw3DAttnProcessor: multi-view attention expects a "
                        f"square token grid, got {num_query_tokens} tokens."
                    )
            hidden_states_mv = self._multiview_row_attention(
                attn,
                query_mv,
                encoder_hidden_states,
                batch_size,
                head_dim,
                grid_height,
                grid_width,
                query.dtype,
            )

        if use_ref:
            hidden_states_ref = self._reference_attention(
                attn, query_ref, ref_hidden_states, batch_size, head_dim, query.dtype
            )

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if use_mv:
            hidden_states = hidden_states + hidden_states_mv * mv_scale

        if use_ref:
            hidden_states = hidden_states + hidden_states_ref * ref_scale

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

    def _rows_across_views(
        self,
        tensor: torch.Tensor,
        batch_size: int,
        heads: int,
        head_dim: int,
        grid_height: int,
        grid_width: int,
    ) -> torch.Tensor:
        """Regroup keys/values so that each row spans all views of its sample.

        Maps ``(b*nv, ih*iw, h, c)`` to ``(b*nv*ih, h, nv*iw, c)``. The rows are
        duplicated once per view so that every (sample, view, row) query group
        has a matching key/value group.
        """
        tensor = rearrange(
            tensor,
            "(b nv) (ih iw) h c -> b ih (nv iw) h c",
            nv=self.num_views,
            ih=grid_height,
            iw=grid_width,
        )
        return (
            tensor.repeat_interleave(self.num_views, dim=0)
            .view(batch_size * grid_height, -1, heads, head_dim)
            .transpose(1, 2)
        )

    def _multiview_row_attention(
        self,
        attn: Attention,
        query_mv: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        batch_size: int,
        head_dim: int,
        grid_height: int,
        grid_width: int,
        out_dtype: torch.dtype,
    ) -> torch.Tensor:
        """Decoupled row-wise self-attention across views.

        Returns the branch output after ``to_out_mv``, with shape
        ``(batch_size, grid_height * grid_width, features)``.
        """
        heads = attn.heads
        key_mv = self.to_k_mv(encoder_hidden_states)
        value_mv = self.to_v_mv(encoder_hidden_states)

        query_mv = query_mv.view(batch_size, -1, heads, head_dim)
        key_mv = key_mv.view(batch_size, -1, heads, head_dim)
        value_mv = value_mv.view(batch_size, -1, heads, head_dim)

        # One query group per (sample, view, row): (b*nv*ih, h, iw, c).
        query_mv = rearrange(
            query_mv,
            "(b nv) (ih iw) h c -> (b nv ih) iw h c",
            nv=self.num_views,
            ih=grid_height,
            iw=grid_width,
        ).transpose(1, 2)
        key_mv = self._rows_across_views(
            key_mv, batch_size, heads, head_dim, grid_height, grid_width
        )
        value_mv = self._rows_across_views(
            value_mv, batch_size, heads, head_dim, grid_height, grid_width
        )

        hidden_states_mv = F.scaled_dot_product_attention(
            query_mv,
            key_mv,
            value_mv,
            dropout_p=0.0,
            is_causal=False,
        )
        hidden_states_mv = rearrange(
            hidden_states_mv,
            "(b nv ih) h iw c -> (b nv) (ih iw) (h c)",
            nv=self.num_views,
            ih=grid_height,
        )
        hidden_states_mv = hidden_states_mv.to(out_dtype)

        # linear proj
        hidden_states_mv = self.to_out_mv[0](hidden_states_mv)
        # dropout
        return self.to_out_mv[1](hidden_states_mv)

    def _reference_attention(
        self,
        attn: Attention,
        query_ref: torch.Tensor,
        ref_hidden_states: Dict[str, torch.FloatTensor],
        batch_size: int,
        head_dim: int,
        out_dtype: torch.dtype,
    ) -> torch.Tensor:
        """Decoupled cross-attention to ``ref_hidden_states[self.name]``.

        Returns the branch output after ``to_out_ref``.
        """
        heads = attn.heads
        reference_hidden_states = ref_hidden_states[self.name]

        key_ref = self.to_k_ref(reference_hidden_states)
        value_ref = self.to_v_ref(reference_hidden_states)

        # NOTE(review): the reference batch is expected to equal `batch_size`.
        # If it differs but the element count still divides, `.view` silently
        # regroups tokens instead of failing.
        query_ref = query_ref.view(batch_size, -1, heads, head_dim).transpose(1, 2)
        key_ref = key_ref.view(batch_size, -1, heads, head_dim).transpose(1, 2)
        value_ref = value_ref.view(batch_size, -1, heads, head_dim).transpose(1, 2)

        hidden_states_ref = F.scaled_dot_product_attention(
            query_ref, key_ref, value_ref, dropout_p=0.0, is_causal=False
        )
        hidden_states_ref = hidden_states_ref.transpose(1, 2).reshape(
            batch_size, -1, heads * head_dim
        )
        hidden_states_ref = hidden_states_ref.to(out_dtype)

        # linear proj
        hidden_states_ref = self.to_out_ref[0](hidden_states_ref)
        # dropout
        return self.to_out_ref[1](hidden_states_ref)

    def set_num_views(self, num_views: int) -> None:
        """Set the number of views per sample used by the multi-view branch."""
        self.num_views = num_views

