# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttnProcessor(nn.Module):
    r"""
    Default attention processor built on the ``attn`` module's helper methods.

    Queries are projected from ``hidden_states``; keys and values are projected from
    ``encoder_hidden_states`` (or from ``hidden_states`` itself when no context is
    given). Heads are folded into the batch dimension with ``attn.head_to_batch_dim``,
    the probabilities from ``attn.get_attention_scores`` are applied to the values
    with ``torch.bmm``, and the result goes through ``attn.to_out``.

    Args:
        hidden_size: Unused; presumably accepted so this class can be constructed with
            the same arguments as the IP-Adapter processors.
        cross_attention_dim: Unused; accepted for the same reason.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        """Run attention for one layer.

        Args:
            attn: The owning attention module. It supplies the projections
                (``to_q``/``to_k``/``to_v``/``to_out``), the optional norms and the
                head-reshaping / scoring helpers.
            hidden_states: Query tokens, either ``(batch, seq, dim)`` or a
                ``(batch, channel, height, width)`` feature map.
            encoder_hidden_states: Optional cross-attention context. When None, the
                layer attends over ``hidden_states``.
            attention_mask: Optional mask, passed through ``attn.prepare_attention_mask``.
            temb: Optional embedding forwarded to ``attn.spatial_norm``.

        Returns:
            The attention output, reshaped back to 4-D when the input was 4-D.
        """
        shortcut = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        was_feature_map = hidden_states.ndim == 4
        if was_feature_map:
            # Flatten the (B, C, H, W) map into a (B, H*W, C) token sequence.
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # The mask is sized against the key/value sequence.
        kv_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = kv_source.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.to_k(context)
        value = attn.to_v(context)

        query, key, value = (attn.head_to_batch_dim(t) for t in (query, key, value))

        probs = attn.get_attention_scores(query, key, attention_mask)
        out = attn.batch_to_head_dim(torch.bmm(probs, value))

        # to_out[0] is the linear projection, to_out[1] the dropout.
        projection, dropout = attn.to_out[0], attn.to_out[1]
        out = dropout(projection(out))

        if was_feature_map:
            out = out.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            out = out + shortcut

        return out / attn.rescale_output_factor


class IPAttnProcessor(nn.Module):
    r"""
    Attention processor for IP-Adapter.

    Cross-attention handling:
    - The last ``num_tokens`` entries of ``encoder_hidden_states`` are treated as
      image-prompt tokens.
    - They get their own key/value projections (``to_k_ip`` / ``to_v_ip``) and are
      attended to separately from the text tokens.
    - The result is added to the text attention output with weight ``scale``, before
      ``attn.to_out`` is applied.

    For self-attention (``encoder_hidden_states`` is None) there are no image tokens,
    and the processor behaves like plain attention.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the `encoder_hidden_states`. ``hidden_size`` is
            used when it is not given.
        scale (`float`, defaults to 1.0):
            The weight of the image-prompt attention output.
        num_tokens (`int`, defaults to 4; 16 for ip_adapter_plus):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        """Run text attention plus (for cross-attention) image-prompt attention.

        Args:
            attn: The owning attention module, which supplies projections, norms and
                head-reshaping / scoring helpers.
            hidden_states: Query tokens, ``(batch, seq, dim)`` or a
                ``(batch, channel, height, width)`` feature map.
            encoder_hidden_states: Optional context. When given, its last
                ``self.num_tokens`` entries are the image-prompt tokens.
            attention_mask: Optional mask, passed through ``attn.prepare_attention_mask``.
            temb: Optional embedding forwarded to ``attn.spatial_norm``.

        Returns:
            The attention output, reshaped back to 4-D when the input was 4-D.

        Side effects:
            On cross-attention calls, stores the image-token attention probabilities
            in ``self.attn_map``.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # NOTE(review): sequence_length still includes the trailing image tokens, but the
        # text keys below exclude them, so a non-None attention_mask may not line up — confirm.
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Only cross-attention carries image-prompt tokens. This stays None for
        # self-attention, which previously crashed with UnboundLocalError.
        ip_hidden_states = None
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # Split the context into text tokens and the trailing num_tokens image tokens.
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        if ip_hidden_states is not None:
            # Image-prompt branch: same queries, dedicated key/value projections, no mask.
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            ip_key = attn.head_to_batch_dim(ip_key)
            ip_value = attn.head_to_batch_dim(ip_value)

            ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
            # Kept so callers can inspect the image-token attention map.
            self.attn_map = ip_attention_probs
            ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
            ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

            hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class AttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor backed by ``F.scaled_dot_product_attention`` (PyTorch >= 2.0).

    Args:
        hidden_size: Unused; presumably accepted so this class can be constructed with
            the same arguments as the IP-Adapter processors.
        cross_attention_dim: Unused; accepted for the same reason.

    Raises:
        ImportError: If ``torch.nn.functional`` has no ``scaled_dot_product_attention``.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()
        sdpa_available = hasattr(F, "scaled_dot_product_attention")
        if not sdpa_available:
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        """Run attention for one layer with fused scaled-dot-product attention.

        Args:
            attn: The owning attention module (projections, norms, ``heads``, ...).
            hidden_states: Query tokens, ``(batch, seq, dim)`` or a
                ``(batch, channel, height, width)`` feature map.
            encoder_hidden_states: Optional cross-attention context. When None, the
                layer attends over ``hidden_states``.
            attention_mask: Optional mask. It is prepared by ``attn`` and then viewed
                as ``(batch, heads, -1, key_len)``.
            temb: Optional embedding forwarded to ``attn.spatial_norm``.

        Returns:
            The attention output, reshaped back to 4-D when the input was 4-D.
        """
        shortcut = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        was_feature_map = hidden_states.ndim == 4
        if was_feature_map:
            # Flatten the (B, C, H, W) map into a (B, H*W, C) token sequence.
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        kv_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = kv_source.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # SDPA expects the mask to broadcast to (batch, heads, query_len, key_len).
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.to_k(context)
        value = attn.to_v(context)

        head_dim = key.shape[-1] // attn.heads

        def split_heads(tensor):
            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
            return tensor.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        query, key, value = split_heads(query), split_heads(key), split_heads(value)

        # TODO: add support for attn.scale when we move to Torch 2.1
        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
        out = attended.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        out = out.to(query.dtype)

        # to_out[0] is the linear projection, to_out[1] the dropout.
        out = attn.to_out[1](attn.to_out[0](out))

        if was_feature_map:
            out = out.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            out = out + shortcut

        return out / attn.rescale_output_factor


class IPAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for IP-Adapater for PyTorch 2.0.
    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self, name, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, denoise_step=0):
        """Create the image-prompt projections and per-layer bookkeeping state.

        Args:
            name: Identifier stored as ``self.name`` (presumably the attention layer's
                name; TODO confirm against the caller).
            hidden_size: Output width of ``to_k_ip`` / ``to_v_ip``.
            cross_attention_dim: Input width of ``to_k_ip`` / ``to_v_ip``. Falls back to
                ``hidden_size`` when falsy.
            scale: Weight of the image-prompt attention, stored as ``self.scale``.
            num_tokens: Length of each trailing token segment that ``__call__`` slices
                out of ``encoder_hidden_states``.
            denoise_step: Initial value of the call counter ``self.denoise_step``.

        Raises:
            ImportError: If ``F.scaled_dot_product_attention`` is unavailable.
        """
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("IPAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.name = name
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens
        # Call counter; __call__ increments it once per invocation.
        self.denoise_step = denoise_step
        # Dedicated (bias-free) key/value projections for the image-prompt tokens.
        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        # NOTE(review): state for optional style/entropy experiments. It starts empty and
        # is presumably populated inside __call__; confirm usage before removing.
        self.style_matrix = None
        self.style_matrix1 = None
        self.scale_entropy = None
        self.wk = []

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        residual = hidden_states
        self.denoise_step += 1
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            print("attn.group_norm")
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query1 = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos1 = encoder_hidden_states.shape[1] - self.num_tokens * 4
            end_pos2 = encoder_hidden_states.shape[1] - self.num_tokens * 3
            end_pos3 = encoder_hidden_states.shape[1] - self.num_tokens * 2
            end_pos4 = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, encoder_content_hidden_states, encoder_content_hidden_states1, ip_hidden_states, ip_hidden_states_style, nums , numc = (
                encoder_hidden_states[:, :77, :],
                encoder_hidden_states[:, 77:154, :],
                encoder_hidden_states[:, 154:end_pos1, :],
                encoder_hidden_states[:, end_pos1:end_pos2, :],
                encoder_hidden_states[:, end_pos2:end_pos3, :],
                encoder_hidden_states[:, end_pos3:end_pos4, :],
                encoder_hidden_states[:, end_pos4:, :],
            )

            # if self.denoise_step > 30 * 0.6:
            #     encoder_hidden_states = encoder_hidden_states1

            # if self.denoise_step >= 10:
            #     encoder_hidden_states = encoder_hidden_states0
            if attn.norm_cross:
                print("attn.norm_cross")
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # print(encoder_hidden_states.shape)
        key = attn.to_k(encoder_hidden_states)
        # print(key.shape)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query1.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)


        # #################################################################################################
        # key = attn.to_k(encoder_content_hidden_states1)
        # value = attn.to_v(encoder_content_hidden_states1)
        #
        # inner_dim = key.shape[-1]
        # head_dim = inner_dim // attn.heads
        #
        # query = query1.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        #
        # key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        # value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        #
        # # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # # TODO: add support for attn.scale when we move to Torch 2.1
        # hidden_states1 = F.scaled_dot_product_attention(
        #     query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        # )
        # hidden_states1 = hidden_states1.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        # hidden_states1 = hidden_states1.to(query.dtype)
        #
        # key = attn.to_k(encoder_content_hidden_states)
        # value = attn.to_v(encoder_content_hidden_states)
        #
        # inner_dim = key.shape[-1]
        # head_dim = inner_dim // attn.heads
        #
        # query = query1.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        #
        # key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        # value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        #
        # # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # # TODO: add support for attn.scale when we move to Torch 2.1
        # hidden_states2 = F.scaled_dot_product_attention(
        #     query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        # )
        # hidden_states2 = hidden_states2.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        # hidden_states2 = hidden_states2.to(query.dtype)
        #
        #
        # ip_key = self.to_k_ip(ip_hidden_states)
        # ip_value = self.to_v_ip(ip_hidden_states)
        # ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        # ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        # ip_hidden_states1 = F.scaled_dot_product_attention(
        #     query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
        # )
        # ip_hidden_states1 = ip_hidden_states1.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        # ip_hidden_states1 = ip_hidden_states1.to(query.dtype)
        #
        #
        #
        # hidden_states = hidden_states + 0.6 * (ip_hidden_states1 - hidden_states2 - hidden_states1)
        # ################################################################################################

        pass
        # ################################################################################################


        # if self.style_matrix == None:
        #     self.style_matrix = torch.linalg.pinv(attn.to_k.weight.float())
        #     self.style_matrix1 = torch.linalg.pinv(self.to_k_ip.weight.float())


        # if self.denoise_step == 1:
        #     # self.wk.append(self.to_k_ip(ip_hidden_states))
        #     print(self.name, self.to_k_ip.weight.shape)
        #     from sklearn.manifold import TSNE
        #     import matplotlib.pyplot as plt
        #     from sklearn.neighbors import NearestNeighbors
        #     from sklearn.decomposition import PCA
        #     import numpy as np
        #     import umap
            # 假设你的向量如下
            # features = torch.randn(640, 2048)  # 示例数据，实际用你的 K 向量
            # features_np = self.to_k_ip.weight.cpu().numpy() - self.style_matrix.T.cpu().numpy()
            # features_np = self.style_matrix.cpu().numpy()
            # features_np = attn.to_k.weight.cpu().numpy()
            # features_np = np.linalg.inv(self.to_k_ip.weight.float().cpu().numpy() @ self.to_k_ip.weight.float().T.cpu().numpy())
            # features_np = self.to_k_ip.weight.float().cpu().numpy() @ self.to_k_ip.weight.float().T.cpu().numpy()
            # t-SNE 降维
            # import matplotlib.pyplot as plt
            # import numpy as np
            # from sklearn.decomposition import PCA
            #
            # # PCA on both datasets
            # pca_text = PCA()
            # pca_text.fit(self.style_matrix.cpu().numpy())
            # explained_text = pca_text.explained_variance_ratio_
            #
            # pca_image = PCA()
            # pca_image.fit(self.style_matrix1.cpu().numpy())
            # explained_image = pca_image.explained_variance_ratio_
            #
            # # Compute difference and indices
            # diff = explained_text - explained_image
            # indices = np.arange(1, len(diff) + 1)  # Define component indices starting from 1
            #
            # # Plot the bar chart
            # plt.figure(figsize=(14, 6))
            # plt.bar(indices, diff, color=np.where(diff >= 0, '#1f77b4', '#ff7f0e'), alpha=0.8)
            #
            # plt.axhline(0, color='gray', linewidth=1, linestyle='--')
            # plt.xlabel('Principal Component Index', fontsize=14)
            # plt.ylabel('Variance Difference (Text - Image)', fontsize=14)
            # plt.grid(True, axis='y', alpha=0.3)
            # plt.tight_layout()
            # plt.savefig(f"/root/project1/IP-Adapter/log/log1/1/{self.name}_pca_variance_difference_bar.png",
            #             dpi=1000, bbox_inches='tight')
            # plt.close()

            # # 图表装饰
            # # plt.title('Explained Variance Ratio per Principal Component', fontsize=14, pad=20)
            # plt.xlabel('Principal Component Index', fontsize=12)
            # plt.ylabel('Variance Ratio', fontsize=12)
            # plt.legend(fontsize=12)
            # plt.grid(True, alpha=0.3)
            # plt.xticks(fontsize=10)
            # plt.yticks(fontsize=10)
            #
            # # 保存结果
            # plt.tight_layout()
            # plt.savefig(
            #     f"/root/project1/IP-Adapter/log/log1/1/{self.name}_pca_variance_ratio.png",
            #     dpi=500,
            #     bbox_inches='tight'
            # )
            # plt.close()

            # 初始化画布
            # plt.figure(figsize=(12, 6))
            #
            # # ---------------------------- 主成分方差比分析 ----------------------------
            # # 第一个数据集（Text）
            # pca_text = PCA(n_components=50)
            # pca_text.fit(self.style_matrix.cpu().numpy())
            # explained_text = pca_text.explained_variance_ratio_
            # plt.plot(range(1, 51), explained_text, label="Text", marker='o', linestyle='-', color='#1f77b4')
            #
            # # 第二个数据集（Image）
            # pca_image = PCA(n_components=50)
            # pca_image.fit(self.style_matrix1.cpu().numpy())
            # explained_image = pca_image.explained_variance_ratio_
            # plt.plot(range(1, 51), explained_image, label="Image", marker='s', linestyle='--', color='#ff7f0e')
            #
            # # 图表装饰
            # plt.title('Explained Variance Ratio per Principal Component', fontsize=14, pad=20)
            # plt.xlabel('Principal Component Index', fontsize=12)
            # plt.ylabel('Variance Ratio', fontsize=12)
            # plt.legend(fontsize=12)
            # plt.grid(True, alpha=0.3)
            # plt.xticks(fontsize=10)
            # plt.yticks(fontsize=10)
            #
            # # 保存结果
            # plt.tight_layout()
            # plt.savefig(
            #     f"/root/project1/IP-Adapter/log/log1/1/{self.name}_pca_variance_ratio.png",
            #     dpi=500,
            #     bbox_inches='tight'
            # )
            # plt.close()

            # import matplotlib.pyplot as plt
            # import numpy as np
            # from sklearn.decomposition import PCA
            #
            # plt.figure(figsize=(10, 6))
            #
            # # 第一个 style_matrix
            # pca = PCA(n_components=50)
            # pca.fit(self.style_matrix.cpu().numpy())
            # explained1 = pca.explained_variance_ratio_
            # plt.plot(range(1, 51), explained1, label="Text", marker='o')
            #
            # # 第二个 style_matrix1
            # pca = PCA(n_components=50)
            # pca.fit(self.style_matrix1.cpu().numpy())
            # explained2 = pca.explained_variance_ratio_
            # plt.plot(range(1, 51), explained2, label="Image", marker='s')
            #
            # # 图像设置
            # plt.title('Explained Variance Ratio of Each PCA Component')
            # plt.xlabel('Principal Component Index')
            # plt.ylabel('Explained Variance Ratio')
            # plt.legend()
            # plt.grid(True)
            # plt.tight_layout()
            #
            # # 保存
            # plt.savefig(f"/root/project1/IP-Adapter/log/log1/1/{self.name}_per_component.png", dpi=500,
            #             bbox_inches='tight')

            # tsne = TSNE(n_components=2, perplexity=30, random_state=42)
            #
            # # tsne = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, metric='cosine', random_state=42)
            # features_2d = tsne.fit_transform(features_np)
            # plt.figure(figsize=(12, 10))
            #
            # nbrs = NearestNeighbors(n_neighbors=10).fit(features_2d)
            # distances, _ = nbrs.kneighbors(features_2d)
            #
            # # Step 2: 局部密度 = 1 / 平均距离（越密集越小）
            # local_density = 1. / (distances[:, 1:].mean(axis=1) + 1e-5)  # 跳过自己，避免为0
            # local_density = (local_density - local_density.min()) / (
            #             local_density.max() - local_density.min())  # 归一化到 [0,1]
            #
            # # Step 3: 可视化，颜色由局部密度决定（最亮的点在最上层）
            # # 根据 local_density 排序（从低到高）
            # sorted_idx = np.argsort(local_density)
            # features_sorted = features_2d[sorted_idx]
            # density_sorted = local_density[sorted_idx]
            # plt.scatter(features_sorted[:, 0], features_sorted[:, 1],
            #             c=density_sorted, cmap='inferno', s=30, alpha=1)
            # plt.colorbar(label="Local Density")
            # plt.savefig(f"/root/project1/IP-Adapter/log/log1/1/{self.name}_density.png", dpi=500, bbox_inches='tight')
            # plt.close()
        #
        #
        #
        # if self.style_matrix == None:
        #     self.style_matrix = torch.linalg.pinv(self.to_k_ip.weight.float())
        #     # self.style_matrix = torch.linalg.pinv(attn.to_k.weight.float())
        #     #
        #     embeddings = F.softmax((ip_hidden_states), dim=-1)  # 转为概率分布
        #     eps = 1e-3  # 避免 log(0)
        #     entropy1 = -torch.sum(embeddings * torch.log(embeddings + eps), dim=-1)
        #
        #     embeddings = F.softmax((ip_hidden_states_style), dim=-1)  # 转为概率分布
        #     entropy2 = -torch.sum(embeddings * torch.log(embeddings + eps), dim=-1)
        #
        #     self.scale_entropy = (entropy1.sum() / (entropy1.sum() + entropy2.sum()))
        #
        # # if self.scale_entropy == None:
        # #     embeddings = F.softmax((ip_hidden_states), dim=-1)  # 转为概率分布
        # #     eps = 1e-3  # 避免 log(0)
        # #     entropy1 = -torch.sum(embeddings * torch.log(embeddings + eps), dim=-1)
        # #
        # #     embeddings = F.softmax((ip_hidden_states_style), dim=-1)  # 转为概率分布
        # #     entropy2 = -torch.sum(embeddings * torch.log(embeddings + eps), dim=-1)
        # #
        # #     self.scale_entropy = (entropy1.sum() / (entropy1.sum() + entropy2.sum()))
        #
        # t1 = 1
        # t2 = 1
        # #ab1
        # # scale1 = 0.55
        # # scale2 = 0.65
        #
        # #ab3
        # scale1 = 0.55
        # scale2 = 0.7
        # #ab2
        # # scale1 = 0.6
        # # scale2 = 0.6
        # #
        # # if self.denoise_step <= 50 * self.scale: #:
        # if self.denoise_step <= 50 * self.scale_entropy + 1: #:
        # # if self.denoise_step > 30 * self.scale_entropy + 1: #:
        # # if self.denoise_step <= 30 * 0.5 + 1: #:
        # # if self.denoise_step < 30 * self.scale_entropy: #:
        # # if self.denoise_step < 60: #:/
        #     if t1 == 0:
        #         ip_key = self.to_k_ip(ip_hidden_states)
        #         ip_value = self.to_v_ip(ip_hidden_states)
        #         ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        #         ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        #         ip_hidden_states1 = F.scaled_dot_product_attention(
        #             query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
        #         )
        #         ip_hidden_states1 = ip_hidden_states1.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        #         ip_hidden_states1 = ip_hidden_states1.to(query.dtype)
        #         hidden_states += scale1 * ip_hidden_states1
        #     elif t1 == 1:
        #         # === 图像内容分支 ===
        #         key_img_content = self.to_k_ip(ip_hidden_states) - ip_hidden_states @ self.style_matrix.to(hidden_states.dtype)
        #         key_fused_content = key_img_content
        #         value_content = self.to_v_ip(ip_hidden_states)
        #
        #         # === 图像风格分支 ===
        #         key_img_style = self.to_k_ip(ip_hidden_states_style)
        #         key_img_style1 = ip_hidden_states_style @ self.style_matrix.to(hidden_states.dtype)
        #         key_txt_style = attn.to_k(encoder_content_hidden_states1)[:, 1: int(nums[0,0,0]) + 1, :]
        #         key_txt_style = key_txt_style.mean(dim = 1).unsqueeze(1).repeat(1, self.num_tokens, 1)
        #
        #         key_fused_style = key_img_style + key_txt_style - key_img_style1
        #         key_fused_style = key_fused_style
        #
        #         value_style = self.to_v_ip(ip_hidden_states_style)
        #
        #         ip_key = torch.cat([key_fused_content, key_fused_style],1)
        #         ip_value = torch.cat([value_content, value_style],1)
        #
        #         ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        #         ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        #         ip_hidden_states1 = F.scaled_dot_product_attention(
        #             query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
        #         )
        #         ip_hidden_states1 = ip_hidden_states1.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        #         ip_hidden_states1 = ip_hidden_states1.to(query.dtype)
        #         hidden_states += scale1 * ip_hidden_states1
        #     elif t1 == 2:
        #         ip_key = self.to_k_ip(ip_hidden_states) - ip_hidden_states @ self.style_matrix.to(hidden_states.dtype)
        #         ip_value = self.to_v_ip(ip_hidden_states)
        #         ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        #         ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        #         ip_hidden_states1 = F.scaled_dot_product_attention(
        #             query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
        #         )
        #         ip_hidden_states1 = ip_hidden_states1.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        #         ip_hidden_states1 = ip_hidden_states1.to(query.dtype)
        #         hidden_states += scale1 * ip_hidden_states1
        #     elif t1 == 3:
        #         # === 图像内容分支 ===
        #         key_img_content = self.to_k_ip(ip_hidden_states)
        #         key_fused_content = key_img_content
        #         value_content = self.to_v_ip(ip_hidden_states)
        #
        #         # === 图像风格分支 ===
        #         key_img_style = self.to_k_ip(ip_hidden_states_style)
        #         key_img_style1 = ip_hidden_states_style @ self.style_matrix.to(hidden_states.dtype)
        #         key_txt_style = attn.to_k(encoder_content_hidden_states1)[:, 1: int(nums[0,0,0]) + 1, :]
        #         key_txt_style = key_txt_style.mean(dim = 1).unsqueeze(1).repeat(1, self.num_tokens, 1)
        #
        #         key_fused_style = key_img_style + key_txt_style
        #         key_fused_style = key_fused_style
        #
        #         value_style = self.to_v_ip(ip_hidden_states_style)
        #
        #         ip_key = torch.cat([key_fused_content, key_fused_style],1)
        #         ip_value = torch.cat([value_content, value_style],1)
        #
        #         ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        #         ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        #         ip_hidden_states1 = F.scaled_dot_product_attention(
        #             query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
        #         )
        #         ip_hidden_states1 = ip_hidden_states1.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        #         ip_hidden_states1 = ip_hidden_states1.to(query.dtype)
        #         hidden_states += scale1 * ip_hidden_states1
        #     else:
        #         pass
        # else:
        #     if t2 == 0:
        #         ip_key = self.to_k_ip(ip_hidden_states_style)
        #         ip_value = self.to_v_ip(ip_hidden_states_style)
        #         ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        #         ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        #         ip_hidden_states1 = F.scaled_dot_product_attention(
        #             query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
        #         )
        #         ip_hidden_states1 = ip_hidden_states1.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        #         ip_hidden_states1 = ip_hidden_states1.to(query.dtype)
        #         hidden_states += scale2 * ip_hidden_states1
        #     elif t2 == 1:
        #         # === 图像内容分支 ===
        #         ip_key = ip_hidden_states_style @ self.style_matrix.to(hidden_states.dtype)
        #         ip_value = self.to_v_ip(ip_hidden_states_style)
        #         ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        #         ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        #         ip_hidden_states1 = F.scaled_dot_product_attention(
        #             query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
        #         )
        #         ip_hidden_states1 = ip_hidden_states1.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        #         ip_hidden_states1 = ip_hidden_states1.to(query.dtype)
        #         hidden_states += scale2 * ip_hidden_states1



        # Image-Driven Style Transfer
        # if self.style_matrix == None:
        #     self.style_matrix = torch.linalg.pinv(self.to_k_ip.weight.float())
        #     # self.style_matrix = torch.linalg.pinv(attn.to_k.weight.float())
        #     #
        #     embeddings = F.softmax((ip_hidden_states), dim=-1)  # 转为概率分布
        #     eps = 1e-3  # 避免 log(0)
        #     entropy1 = -torch.sum(embeddings * torch.log(embeddings + eps), dim=-1)
        #
        #     embeddings = F.softmax((ip_hidden_states_style), dim=-1)  # 转为概率分布
        #     entropy2 = -torch.sum(embeddings * torch.log(embeddings + eps), dim=-1)
        #
        #     self.scale_entropy = (entropy1.sum() / (entropy1.sum() + entropy2.sum()))
        # # if self.denoise_step <= 30 * self.scale_entropy + 1: #:
        #     # === 图像内容分支 ===
        #
        # if self.denoise_step <= 50 * self.scale_entropy + 1: #:
        #     ip_key = self.to_k_ip(ip_hidden_states) - ip_hidden_states @ self.style_matrix.to(hidden_states.dtype)
        #     ip_value = self.to_v_ip(ip_hidden_states)
        #     ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        #     ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        #     ip_hidden_states1 = F.scaled_dot_product_attention(
        #         query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
        #     )
        #     ip_hidden_states1 = ip_hidden_states1.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        #     ip_hidden_states1 = ip_hidden_states1.to(query.dtype)
        #     hidden_states = ip_hidden_states1
        # else:
        #     # # === 图像内容分支 ===
        #     ip_key = ip_hidden_states_style @ self.style_matrix.to(hidden_states.dtype)
        #     ip_value = self.to_v_ip(ip_hidden_states_style)
        #     ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        #     ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        #     ip_hidden_states1 = F.scaled_dot_product_attention(
        #         query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
        #     )
        #     ip_hidden_states1 = ip_hidden_states1.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        #     ip_hidden_states1 = ip_hidden_states1.to(query.dtype)
        #     hidden_states = ip_hidden_states1
        #
        #     ip_key = self.to_k_ip(ip_hidden_states) - ip_hidden_states @ self.style_matrix.to(hidden_states.dtype)
        #     ip_value = self.to_v_ip(ip_hidden_states)
        #     ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        #     ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        #     ip_hidden_states1 = F.scaled_dot_product_attention(
        #         query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
        #     )
        #     ip_hidden_states1 = ip_hidden_states1.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        #     ip_hidden_states1 = ip_hidden_states1.to(query.dtype)
        #     hidden_states += 0.6 * ip_hidden_states1



        #########attention score map
        if self.style_matrix == None:
            # self.style_matrix = torch.linalg.pinv(self.to_k_ip.weight.float())
            U, S, V = torch.linalg.svd(self.to_k_ip.weight.float())
            S = 1/S
            A = self.to_k_ip.weight.float()
            Sigma = torch.zeros_like(A)  # 创建和 A 同形状的零矩阵
            min_dim = min(A.shape)  # 取矩阵的最小维度（秩的上限）
            Sigma[:min_dim, :min_dim] = torch.diag(S[:min_dim])  # 填充奇异值到 Sigma

            self.style_matrix = (U @ Sigma @ V).T  # 矩阵乘法还原
            # print(S)
            # U, S1, V = torch.linalg.svd(self.style_matrix.T.float())
            # print(S-1/S)
            # print("========================================")
            # self.style_matrix = torch.linalg.pinv(attn.to_k.weight.float())
            #
            embeddings = F.softmax((ip_hidden_states), dim=-1)  # 转为概率分布
            eps = 1e-3  # 避免 log(0)
            entropy1 = -torch.sum(embeddings * torch.log(embeddings + eps), dim=-1)

            embeddings = F.softmax((ip_hidden_states_style), dim=-1)  # 转为概率分布
            entropy2 = -torch.sum(embeddings * torch.log(embeddings + eps), dim=-1)

            self.scale_entropy = (entropy1.sum() / (entropy1.sum() + entropy2.sum()))
        #
        ###1
        # key_img_content = self.to_k_ip(ip_hidden_states) - ip_hidden_states @ self.style_matrix.to(hidden_states.dtype)
        # key_fused_content = key_img_content
        # value_content = self.to_v_ip(ip_hidden_states)
        #
        # key_img_style = self.to_k_ip(ip_hidden_states_style)
        # key_img_style1 = ip_hidden_states_style @ self.style_matrix.to(hidden_states.dtype)
        # key_txt_style = attn.to_k(encoder_content_hidden_states1)[:, 1: int(nums[0, 0, 0]) + 1, :]
        # key_txt_style = key_txt_style.mean(dim=1).unsqueeze(1).repeat(1, self.num_tokens, 1)
        # key_fused_style = key_img_style + key_txt_style - key_img_style1
        # key_fused_style = key_fused_style
        # value_style = self.to_v_ip(ip_hidden_states_style)
        #
        # ip_key = torch.cat([key_fused_content, key_fused_style], 1)
        # ip_value = torch.cat([value_content, value_style], 1)

        ###2
        # ip_key = self.to_k_ip(ip_hidden_states)
        # ip_value = self.to_v_ip(ip_hidden_states)

        ###3
        # if self.denoise_step > 50 * 0.3:
        #     ip_key = ip_hidden_states_style @ self.style_matrix.to(hidden_states.dtype)
        #     # ip_key = self.to_k_ip(ip_hidden_states_style)
        #     # ip_key = attn.to_k(ip_hidden_states_style)
        #     ip_value = self.to_v_ip(ip_hidden_states_style)
        #
        #     ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        #     ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        #     ip_hidden_states1 = F.scaled_dot_product_attention(
        #         query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
        #     )
        #     with torch.no_grad():
        #         self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
        #
        #     ip_hidden_states1 = ip_hidden_states1.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        #     ip_hidden_states1 = ip_hidden_states1.to(query.dtype)
        #     hidden_states += 0.6 * ip_hidden_states1


        # with torch.no_grad():
        #     self.attn_map = query @ key.transpose(-2, -1).softmax(dim=-1)[:,:,:,1:9]













        if self.denoise_step <= 30 * self.scale_entropy + 1: #:
        # if self.denoise_step < 30 * self.scale_entropy:
            # === 图像内容分支 ===
            key_img_content = self.to_k_ip(ip_hidden_states) - ip_hidden_states @ self.style_matrix.to(hidden_states.dtype)
            key_fused_content = key_img_content
            value_content = self.to_v_ip(ip_hidden_states)

            # === 图像风格分支 ===
            key_img_style = self.to_k_ip(ip_hidden_states_style)
            key_img_style1 = ip_hidden_states_style @ self.style_matrix.to(hidden_states.dtype)
            key_txt_style = attn.to_k(encoder_content_hidden_states1)[:, 1: int(nums[0,0,0]) + 1, :]
            key_txt_style = key_txt_style.mean(dim = 1).unsqueeze(1).repeat(1, self.num_tokens, 1)

            key_fused_style = key_img_style + key_txt_style - key_img_style1
            key_fused_style = key_fused_style

            value_style = self.to_v_ip(ip_hidden_states_style)

            ip_key = torch.cat([key_fused_content, key_fused_style],1)
            ip_value = torch.cat([value_content, value_style],1)

            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_hidden_states1 = F.scaled_dot_product_attention(
                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )
            # with torch.no_grad():
            #     self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
            ip_hidden_states1 = ip_hidden_states1.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
            ip_hidden_states1 = ip_hidden_states1.to(query.dtype)
            hidden_states += 0.6 * ip_hidden_states1
        else:
            # === 图像内容分支 ===
            ip_key = ip_hidden_states_style @ self.style_matrix.to(hidden_states.dtype)
            ip_value = self.to_v_ip(ip_hidden_states_style)
            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_hidden_states1 = F.scaled_dot_product_attention(
                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )
            ip_hidden_states1 = ip_hidden_states1.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
            ip_hidden_states1 = ip_hidden_states1.to(query.dtype)
            hidden_states += 0.6 * ip_hidden_states1


        # if self.style_matrix == None:
        #     self.style_matrix = torch.linalg.pinv(self.to_k_ip.weight.float())
        #
        #     if int(nums[0,0,0]) <= 5:
        #         embeddings = F.softmax((encoder_hidden_states[:, 1: int(self.num_tokens) + 1, :]), dim=-1)
        #     else:
        #         embeddings = F.softmax((encoder_hidden_states[:, 1: int(nums[0,0,0]) + 1, :]), dim=-1)
        #     eps = 1e-3
        #     entropy1 = -torch.sum(embeddings * torch.log(embeddings + eps), dim=-1)
        #     embeddings = F.softmax((ip_hidden_states_style), dim=-1)
        #     entropy2 = -torch.sum(embeddings * torch.log(embeddings + eps), dim=-1)
        #     # print(entropy1.sum() ,entropy2.sum() )
        #     self.scale_entropy = (entropy1.sum() / (entropy1.sum() + entropy2.sum()))
        #     # print(self.scale_entropy)
        # # if self.denoise_step <= 50 * self.scale_entropy:
        # # ip_hidden_states_style = ip_hidden_states_style * self.scale + ip_hidden_states * (1 - self.scale)
        # if self.denoise_step <= 50 * self.scale_entropy:
        #
        #     ip_key = self.to_k_ip(ip_hidden_states_style) - ip_hidden_states_style @ self.style_matrix.to(hidden_states.dtype)
        #     ip_value = self.to_v_ip(ip_hidden_states_style)
        #
        #
        #     ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        #     ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        #     ip_hidden_states1 = F.scaled_dot_product_attention(
        #         query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
        #     )
        #     ip_hidden_states1 = ip_hidden_states1.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        #     ip_hidden_states1 = ip_hidden_states1.to(query.dtype)
        #     # hidden_states += 0.35 * ip_hidden_states1
        #     hidden_states += 0.4 * ip_hidden_states1
        # else:
        #     # === 图像内容分支 ===
        #     ip_key = ip_hidden_states_style @ self.style_matrix.to(hidden_states.dtype)
        #     ip_value = self.to_v_ip(ip_hidden_states_style)
        #     ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        #     ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        #     ip_hidden_states1 = F.scaled_dot_product_attention(
        #         query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
        #     )
        #     ip_hidden_states1 = ip_hidden_states1.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        #     ip_hidden_states1 = ip_hidden_states1.to(query.dtype)
        #     hidden_states += 0.7 * ip_hidden_states1



        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


## for controlnet
class CNAttnProcessor:
    r"""
    ControlNet attention processor that ignores the trailing image-prompt tokens.

    For cross-attention (``encoder_hidden_states`` given), the last ``num_tokens``
    tokens along dim 1 of ``encoder_hidden_states`` are dropped before keys/values
    are computed, so only the leading (text) tokens are attended to. Self-attention
    (``encoder_hidden_states is None``) is computed normally.
    """

    def __init__(self, num_tokens=4):
        # Number of trailing tokens of ``encoder_hidden_states`` to discard.
        self.num_tokens = num_tokens

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, *args, **kwargs,):
        """Run attention with the trailing ``num_tokens`` context tokens removed.

        Args:
            attn: attention module supplying ``to_q``/``to_k``/``to_v``/``to_out``,
                ``prepare_attention_mask``, ``head_to_batch_dim``, ``batch_to_head_dim``,
                ``get_attention_scores`` and the norm/residual settings used below.
            hidden_states: query input, either 3-D ``(batch, seq, dim)`` or 4-D
                ``(batch, channel, height, width)`` (flattened to tokens internally).
            encoder_hidden_states: optional context; when given, its last
                ``self.num_tokens`` tokens are dropped.
            attention_mask: optional mask. When the context is truncated, the
                prepared mask's last (key) axis is truncated to the same length.
            temb: forwarded to ``attn.spatial_norm`` when that is set.

        Returns:
            Tensor in the same layout (3-D or 4-D) as ``hidden_states``.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        # Prepared against the full (un-truncated) context length; the key axis is
        # sliced below if trailing image tokens are dropped.
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states = encoder_hidden_states[:, :end_pos]  # only use text
            if attention_mask is not None:
                # Keep the mask aligned with the truncated keys; otherwise its last
                # dimension still covers the dropped tokens and the score add fails.
                attention_mask = attention_mask[..., :end_pos]
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class CNAttnProcessor2_0:
    r"""
    ControlNet attention processor built on ``F.scaled_dot_product_attention`` (PyTorch 2.0+).

    For cross-attention (``encoder_hidden_states`` given), the last ``num_tokens``
    tokens along dim 1 of ``encoder_hidden_states`` are dropped before keys/values
    are computed, so only the leading (text) tokens are attended to. Self-attention
    (``encoder_hidden_states is None``) is computed normally.
    """

    def __init__(self, num_tokens=4):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CNAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        # Number of trailing tokens of ``encoder_hidden_states`` to discard.
        self.num_tokens = num_tokens

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        """Run SDPA attention with the trailing ``num_tokens`` context tokens removed.

        Args:
            attn: attention module supplying ``to_q``/``to_k``/``to_v``/``to_out``,
                ``heads``, ``prepare_attention_mask`` and the norm/residual settings used below.
            hidden_states: query input, either 3-D ``(batch, seq, dim)`` or 4-D
                ``(batch, channel, height, width)`` (flattened to tokens internally).
            encoder_hidden_states: optional context; when given, its last
                ``self.num_tokens`` tokens are dropped.
            attention_mask: optional mask. When the context is truncated, the
                prepared mask's last (key) axis is truncated to the same length.
            temb: forwarded to ``attn.spatial_norm`` when that is set.

        Returns:
            Tensor in the same layout (3-D or 4-D) as ``hidden_states``.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            # Prepared against the full (un-truncated) context length; the key axis
            # is sliced below if trailing image tokens are dropped.
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states = encoder_hidden_states[:, :end_pos]  # only use text
            if attention_mask is not None:
                # Keep the mask aligned with the truncated keys; otherwise SDPA cannot
                # broadcast a mask whose last dim still covers the dropped tokens.
                attention_mask = attention_mask[..., :end_pos]
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # NOTE: no `scale` is passed, so SDPA uses its default 1/sqrt(head_dim);
        # attn.scale is not forwarded.
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
