import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.utils.import_utils import is_xformers_available
from torchvision import transforms
if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None
import matplotlib.pyplot as plt
import numpy as np


class MixIT_AttnProcessor(nn.Module):
    """Cross-attention over per-prompt text tokens mixed with image-prompt (IP) tokens.

    ``encoder_hidden_states`` is split by ``__call__`` as follows: the first
    ``prompt_len * TEXT_TOKENS`` tokens are text tokens (``TEXT_TOKENS`` per prompt),
    and the remaining ``prompt_len * TEXT_TOKENS * IP_CHUNKS`` tokens are IP tokens
    (``TEXT_TOKENS * IP_CHUNKS`` per prompt).

    For each prompt, the query attends to the concatenation of:

    * that prompt's text keys/values (``attn.to_k`` / ``attn.to_v``);
    * its IP keys/values (``self.to_k_SC`` / ``self.to_v_SC``), with the values
      multiplied by ``IP_VALUE_GAIN``.

    The per-prompt results are averaged.

    Args:
        hidden_size: Output width of ``to_k_SC`` / ``to_v_SC``.
        cross_attention_dim: Input width of ``to_k_SC`` / ``to_v_SC``; falls back to
            ``hidden_size`` when ``None``.
        text_context_len: Accepted for signature compatibility; not used.
        scale: Stored as ``self.scale``; not used by ``__call__``.
        use_orig_kv: Stored as ``self.use_orig_kv``; not used by ``__call__``.
    """

    # Text tokens per prompt.
    TEXT_TOKENS = 77
    # IP tokens per prompt, in units of TEXT_TOKENS.
    IP_CHUNKS = 6
    # Multiplier on the IP values to strengthen the image prompt.
    IP_VALUE_GAIN = 3

    def __init__(self, hidden_size, cross_attention_dim=None, text_context_len=77, scale=1.0, use_orig_kv=False):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.num = 0
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.use_orig_kv = use_orig_kv

        self.to_k_SC = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_SC = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    @staticmethod
    def _split_mask_per_prompt(attention_mask, prompt_len, text_len, ip_len):
        """Regroup a full-sequence mask into the per-prompt key order.

        ``attention_mask`` is ``(batch, heads, q, prompt_len * (text_len + ip_len))``,
        in the same text-then-IP order as the encoder sequence. The result has shape
        ``(batch, prompt_len, heads, q, text_len + ip_len)``, where each prompt's text
        mask is followed by its IP mask, matching how ``key``/``value`` are built.
        """
        lead = attention_mask.shape[:-1]
        n_text = prompt_len * text_len
        text_mask = attention_mask[..., :n_text].reshape(*lead, prompt_len, text_len)
        ip_mask = attention_mask[..., n_text:].reshape(*lead, prompt_len, ip_len)
        mask = torch.cat([text_mask, ip_mask], dim=-1)
        # (batch, heads, q, prompt_len, keys) -> (batch, prompt_len, heads, q, keys)
        return mask.permute(0, 3, 1, 2, 4)

    def __call__(
            self,
            attn,
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None,
            temb=None,
    ):
        """Apply mixed text + image-prompt cross-attention.

        Args:
            attn: Attention module. Its ``to_q``/``to_k``/``to_v``/``to_out``,
                optional norms, ``heads``, ``residual_connection`` and
                ``rescale_output_factor`` are used (presumably a diffusers
                ``Attention``).
            hidden_states: ``(batch, seq, dim)`` or ``(batch, channel, height, width)``.
            encoder_hidden_states: Text tokens followed by IP tokens (see class
                docstring). If ``None``, ``hidden_states`` is used instead, which only
                works when its length fits the same layout.
            attention_mask: Optional mask over the full encoder sequence. After
                ``attn.prepare_attention_mask`` it is regrouped into the per-prompt
                text+IP key order so that it lines up with each prompt's keys.
            temb: Forwarded to ``attn.spatial_norm`` when present.

        Returns:
            Tensor with the same layout as the input ``hidden_states``.

        Raises:
            ValueError: If the encoder sequence length is not a positive multiple of
                ``TEXT_TOKENS * (IP_CHUNKS + 1)``.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # (batch, heads, q or 1, encoder_len); regrouped per prompt further below.
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # Split the encoder sequence into per-prompt text and IP tokens.
        text_len = self.TEXT_TOKENS
        ip_len = self.TEXT_TOKENS * self.IP_CHUNKS
        sq_len = encoder_hidden_states.size(1)
        prompt_len, remainder = divmod(sq_len, text_len + ip_len)
        if prompt_len == 0 or remainder:
            raise ValueError(
                f"{type(self).__name__} expects an encoder sequence length that is a positive "
                f"multiple of {text_len + ip_len} ({text_len} text + {ip_len} image-prompt tokens "
                f"per prompt), got {sq_len}."
            )
        n_text = prompt_len * text_len
        text_hidden_states = encoder_hidden_states[:, :n_text, :].reshape(batch_size, prompt_len, text_len, -1)
        ip_hidden_states = encoder_hidden_states[:, n_text:, :].reshape(batch_size, prompt_len, ip_len, -1)

        key_text = attn.to_k(text_hidden_states)
        value_text = attn.to_v(text_hidden_states)
        key_ip = self.to_k_SC(ip_hidden_states)
        value_ip = self.to_v_SC(ip_hidden_states)

        # One copy of the query per prompt: (batch, prompt_len, q, inner_dim).
        query = query.unsqueeze(1).repeat(1, prompt_len, 1, 1)
        key = torch.cat([key_text, key_ip], dim=2)
        # Amplify the IP values relative to the text values.
        value = torch.cat([value_text, value_ip * self.IP_VALUE_GAIN], dim=2)

        inner_dim = key_text.shape[-1]
        head_dim = inner_dim // attn.heads

        # (batch, prompt_len, tokens, inner_dim) -> (batch, prompt_len, heads, tokens, head_dim)
        query = query.view(batch_size, prompt_len, -1, attn.heads, head_dim).transpose(2, 3)
        key = key.view(batch_size, prompt_len, -1, attn.heads, head_dim).transpose(2, 3)
        value = value.view(batch_size, prompt_len, -1, attn.heads, head_dim).transpose(2, 3)

        if attention_mask is not None:
            attention_mask = self._split_mask_per_prompt(attention_mask, prompt_len, text_len, ip_len)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # Merge heads, then average over prompts.
        hidden_states = hidden_states.transpose(2, 3).reshape(batch_size, prompt_len, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = hidden_states.mean(dim=1)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

##
class KiVt_AttnProcessor(nn.Module):
    """Cross-attention whose keys come from text tokens and whose values come from IP tokens.

    ``encoder_hidden_states`` is split by ``__call__`` as follows: the first
    ``prompt_len * TEXT_TOKENS`` tokens are text tokens (``TEXT_TOKENS`` per prompt),
    and the remaining tokens are image-prompt (IP) tokens
    (``TEXT_TOKENS * IP_CHUNKS`` per prompt).

    For each prompt:

    * its text tokens are tiled ``IP_CHUNKS`` times and projected with
      ``attn.to_k`` to form the keys;
    * its IP tokens are projected with ``self.to_v_SC`` to form the values.

    ``self.to_k_SC`` is created but not used by ``__call__``. The per-prompt outputs
    are averaged.

    Args:
        hidden_size: Output width of ``to_k_SC`` / ``to_v_SC``.
        cross_attention_dim: Input width of ``to_k_SC`` / ``to_v_SC``; falls back to
            ``hidden_size`` when ``None``.
        text_context_len: Accepted for signature compatibility; not used.
        scale: Stored as ``self.scale``; not used by ``__call__``.
        use_orig_kv: Stored as ``self.use_orig_kv``; not used by ``__call__``.
    """

    # Text tokens per prompt.
    TEXT_TOKENS = 77
    # IP tokens per prompt, in units of TEXT_TOKENS (also the text tiling factor).
    IP_CHUNKS = 6

    def __init__(self, hidden_size, cross_attention_dim=None, text_context_len=77, scale=1.0, use_orig_kv=False):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.num = 0
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.use_orig_kv = use_orig_kv

        self.to_k_SC = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_SC = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
            self,
            attn,
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None,
            temb=None,
    ):
        """Apply text-key / image-value cross-attention.

        Args:
            attn: Attention module. Its ``to_q``/``to_k``/``to_out``, optional norms,
                ``heads``, ``residual_connection`` and ``rescale_output_factor`` are used
                (presumably a diffusers ``Attention``).
            hidden_states: ``(batch, seq, dim)`` or ``(batch, channel, height, width)``.
            encoder_hidden_states: Text tokens followed by IP tokens (see class
                docstring). If ``None``, ``hidden_states`` is used instead, which only
                works when its length fits the same layout.
            attention_mask: Must be ``None``. A full-sequence mask has no defined
                mapping onto the tiled text keys / IP values.
            temb: Forwarded to ``attn.spatial_norm`` when present.

        Returns:
            Tensor with the same layout as the input ``hidden_states``.

        Raises:
            NotImplementedError: If ``attention_mask`` is not ``None``. Previously this
                failed inside SDPA with a broadcast error, because the prepared mask
                spans the full encoder length rather than the per-prompt keys.
            ValueError: If the encoder sequence length is not a positive multiple of
                ``TEXT_TOKENS * (IP_CHUNKS + 1)``.
        """
        if attention_mask is not None:
            raise NotImplementedError(f"{type(self).__name__} does not support attention_mask.")

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, _, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # Split the encoder sequence into per-prompt text and IP tokens.
        text_len = self.TEXT_TOKENS
        ip_len = self.TEXT_TOKENS * self.IP_CHUNKS
        sq_len = encoder_hidden_states.size(1)
        prompt_len, remainder = divmod(sq_len, text_len + ip_len)
        if prompt_len == 0 or remainder:
            raise ValueError(
                f"{type(self).__name__} expects an encoder sequence length that is a positive "
                f"multiple of {text_len + ip_len} ({text_len} text + {ip_len} image-prompt tokens "
                f"per prompt), got {sq_len}."
            )
        n_text = prompt_len * text_len
        # Tile each prompt's text tokens so there is one key per IP value.
        text_hidden_states = (
            encoder_hidden_states[:, :n_text, :]
            .reshape(batch_size, prompt_len, text_len, -1)
            .repeat(1, 1, self.IP_CHUNKS, 1)
        )
        ip_hidden_states = encoder_hidden_states[:, n_text:, :].reshape(batch_size, prompt_len, ip_len, -1)

        key = attn.to_k(text_hidden_states)
        value = self.to_v_SC(ip_hidden_states)

        # One copy of the query per prompt: (batch, prompt_len, q, inner_dim).
        query = query.unsqueeze(1).repeat(1, prompt_len, 1, 1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (batch, prompt_len, tokens, inner_dim) -> (batch, prompt_len, heads, tokens, head_dim)
        query = query.view(batch_size, prompt_len, -1, attn.heads, head_dim).transpose(2, 3)
        key = key.view(batch_size, prompt_len, -1, attn.heads, head_dim).transpose(2, 3)
        value = value.view(batch_size, prompt_len, -1, attn.heads, head_dim).transpose(2, 3)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
        )

        # Merge heads, then average over prompts.
        hidden_states = hidden_states.transpose(2, 3).reshape(batch_size, prompt_len, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = hidden_states.mean(dim=1)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class TMaskIP_AttnProcessor(nn.Module):
    """Image-prompt (IP) cross-attention gated by a text-relevance mask over query positions.

    ``encoder_hidden_states`` is split by ``__call__`` as follows: the first
    ``prompt_len * TEXT_TOKENS`` tokens are text tokens (``TEXT_TOKENS`` per prompt),
    and the remaining tokens are IP tokens (``TEXT_TOKENS * IP_CHUNKS`` per prompt).

    For each prompt:

    1. Compute the mean scaled text score of every query position against the
       prompt's text keys (``attn.to_k``), and softmax it across query positions.
    2. Set the gate to 1 for positions at or above ``MASK_THRESHOLD_RATIO`` times the
       global mean of that map, and 0 elsewhere.
    3. Compute ordinary attention over the prompt's IP keys/values
       (``self.to_k_SC`` / ``self.to_v_SC``), then multiply its probabilities by the
       gate.

    The per-prompt outputs are averaged.

    Args:
        hidden_size: Output width of ``to_k_SC`` / ``to_v_SC``.
        cross_attention_dim: Input width of ``to_k_SC`` / ``to_v_SC``; falls back to
            ``hidden_size`` when ``None``.
        text_context_len: Accepted for signature compatibility; not used.
        scale: Stored as ``self.scale``; not used by ``__call__``.
        use_orig_kv: Stored as ``self.use_orig_kv``; not used by ``__call__``.
    """

    # Text tokens per prompt.
    TEXT_TOKENS = 77
    # IP tokens per prompt, in units of TEXT_TOKENS.
    IP_CHUNKS = 6
    # Positions whose relevance is below this fraction of the mean are gated off.
    MASK_THRESHOLD_RATIO = 0.6

    def __init__(self, hidden_size, cross_attention_dim=None, text_context_len=77, scale=1.0, use_orig_kv=False):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.num = 0
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.use_orig_kv = use_orig_kv

        self.to_k_SC = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_SC = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
            self,
            attn,
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None,
            temb=None,
    ):
        """Apply text-gated image-prompt cross-attention.

        Args:
            attn: Attention module. Its ``to_q``/``to_k``/``to_out``, optional norms,
                ``heads``, ``residual_connection`` and ``rescale_output_factor`` are used
                (presumably a diffusers ``Attention``).
            hidden_states: ``(batch, seq, dim)`` or ``(batch, channel, height, width)``.
            encoder_hidden_states: Text tokens followed by IP tokens (see class
                docstring). If ``None``, ``hidden_states`` is used instead, which only
                works when its length fits the same layout.
            attention_mask: Accepted for interface compatibility but ignored; this
                processor never applied it.
            temb: Forwarded to ``attn.spatial_norm`` when present.

        Returns:
            Tensor with the same layout as the input ``hidden_states``.

        Raises:
            ValueError: If the encoder sequence length is not a positive multiple of
                ``TEXT_TOKENS * (IP_CHUNKS + 1)``.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, _, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # Split the encoder sequence into per-prompt text and IP tokens.
        text_len = self.TEXT_TOKENS
        ip_len = self.TEXT_TOKENS * self.IP_CHUNKS
        sq_len = encoder_hidden_states.size(1)
        prompt_len, remainder = divmod(sq_len, text_len + ip_len)
        if prompt_len == 0 or remainder:
            raise ValueError(
                f"{type(self).__name__} expects an encoder sequence length that is a positive "
                f"multiple of {text_len + ip_len} ({text_len} text + {ip_len} image-prompt tokens "
                f"per prompt), got {sq_len}."
            )
        n_text = prompt_len * text_len
        text_hidden_states = encoder_hidden_states[:, :n_text, :].reshape(batch_size, prompt_len, text_len, -1)
        ip_hidden_states = encoder_hidden_states[:, n_text:, :].reshape(batch_size, prompt_len, ip_len, -1)

        key_text = attn.to_k(text_hidden_states)
        key_ip = self.to_k_SC(ip_hidden_states)
        value_ip = self.to_v_SC(ip_hidden_states)

        inner_dim = key_text.shape[-1]
        head_dim = inner_dim // attn.heads

        # (batch, prompt_len, tokens, inner_dim) -> (batch, prompt_len, heads, tokens, head_dim)
        query = (
            query.unsqueeze(1)
            .repeat(1, prompt_len, 1, 1)
            .view(batch_size, prompt_len, -1, attn.heads, head_dim)
            .transpose(2, 3)
        )
        key_text = key_text.view(batch_size, prompt_len, -1, attn.heads, head_dim).transpose(2, 3)

        # Text-relevance map: mean text score per query position, softmaxed across
        # query positions (dim=-2). Shape (batch, prompt_len, heads, q, 1).
        text_scores = torch.matmul(query, key_text.transpose(-2, -1)) / math.sqrt(head_dim)
        relevance = F.softmax(text_scores.mean(dim=-1, keepdim=True), dim=-2)
        # Binarize into a fresh tensor. Writing 0/1 in place into the softmax output
        # would invalidate the tensor softmax's backward pass needs, making any
        # backward() through this layer raise.
        threshold = self.MASK_THRESHOLD_RATIO * relevance.mean()
        ip_gate = (relevance >= threshold).to(relevance.dtype)

        key = key_ip.view(batch_size, prompt_len, -1, attn.heads, head_dim).transpose(2, 3)
        value = value_ip.view(batch_size, prompt_len, -1, attn.heads, head_dim).transpose(2, 3)

        # Manual attention so the gate can be applied to the probabilities; gated-off
        # query positions receive no IP contribution.
        ip_scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(head_dim)
        ip_probs = F.softmax(ip_scores, dim=-1) * ip_gate
        hidden_states = torch.matmul(ip_probs, value)

        # Merge heads, then average over prompts.
        hidden_states = hidden_states.transpose(2, 3).reshape(batch_size, prompt_len, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = hidden_states.mean(dim=1)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
