
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat


class AttentionBase:
    """Base attention editor: plain scaled-dot-product attention plus step/layer bookkeeping.

    An instance is called once per hooked attention layer. Once ``num_att_layers``
    calls have been made it rolls the layer counter back to zero, advances
    ``cur_step`` and runs :meth:`after_step`.
    """

    def __init__(self):
        # denoising-step counter, advanced once every ``num_att_layers`` calls
        self.cur_step = 0
        # number of hooked attention layers; -1 until a registration routine sets it
        # (while it stays -1 the layer counter never wraps and cur_step never advances)
        self.num_att_layers = -1
        # index of the attention layer that will serve the next call
        self.cur_att_layer = 0

    def after_step(self):
        """Hook invoked after every completed step; does nothing by default."""

    def __call__(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
        """Run :meth:`forward` for the current layer, then update the layer/step counters."""
        result = self.forward(q, k, v, is_cross, place_in_unet, num_heads, **kwargs)
        self.cur_att_layer += 1
        if self.cur_att_layer != self.num_att_layers:
            return result
        # every hooked layer has now run once: the step is complete
        self.cur_att_layer = 0
        self.cur_step += 1
        self.after_step()
        return result

    def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
        """Plain attention over ``(b, h, n, d)`` inputs, returned with heads merged as ``(b, n, h*d)``.

        ``kwargs`` (e.g. ``scale``) is accepted but unused; SDPA applies its default scaling.
        """
        attended = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
        return rearrange(attended, 'b h n d -> b n (h d)')

    def reset(self):
        """Rewind the step and layer counters (``num_att_layers`` is kept)."""
        self.cur_step, self.cur_att_layer = 0, 0


class MutualSelfAttentionControl(AttentionBase):
    """Mutual self-attention control for Stable Diffusion.

    For the selected denoising steps and attention layers, the self-attention
    queries of both images in a (source, target) pair attend to the keys and
    values of the source image (element 0 of the pair). Every other call falls
    back to the plain attention of ``AttentionBase.forward``.
    """

    def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, guidance_scale=7.5):
        """
        Mutual self-attention control for Stable-Diffusion model
        Args:
            start_step: the step to start mutual self-attention control; only used to
                build the default ``step_idx``
            start_layer: the layer to start mutual self-attention control; only used to
                build the default ``layer_idx``
            layer_idx: list of the layers to apply mutual self-attention control
                (default: ``list(range(start_layer, 16))``)
            step_idx: list of the steps to apply mutual self-attention control
                (default: ``list(range(start_step, total_steps))``)
            total_steps: the total number of steps
            guidance_scale: classifier-free guidance scale; when > 1.0 the batch is
                treated as an unconditional pair followed by a conditional pair
        """
        super().__init__()
        self.total_steps = total_steps
        self.start_step = start_step
        self.start_layer = start_layer
        self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, 16))
        self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps))
        # store the guidance scale to decide whether there is an unconditional branch
        self.guidance_scale = guidance_scale
        print("step_idx: ", self.step_idx)
        print("layer_idx: ", self.layer_idx)

    @staticmethod
    def _mutual_attention(q, k, v):
        """Let the first two queries of a group attend to the keys/values of element 0.

        Args:
            q, k, v: ``(b, h, n, d)`` tensors; index 0 is the source, index 1 the target.

        Returns:
            Heads-merged ``(b n (h d))`` output: the source rows followed by the target rows.
        """
        # merge queries of source and target branch into one so we can use torch API
        q = torch.cat([q[0:1], q[1:2]], dim=2)
        out = F.scaled_dot_product_attention(q, k[0:1], v[0:1], attn_mask=None, dropout_p=0.0, is_causal=False)
        out = torch.cat(out.chunk(2, dim=2), dim=0)  # split the queries into source and target batch
        return rearrange(out, 'b h n d -> b n (h d)')

    def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
        """
        Attention forward function.

        Mutual attention applies only to self-attention calls whose step is in
        ``step_idx`` and whose ``cur_att_layer // 2`` is in ``layer_idx``. All other
        calls use ``AttentionBase.forward``.

        NOTE(review): ``kwargs`` (e.g. ``scale``) is not used on the mutual path;
        SDPA applies its default scaling.
        """
        # ``// 2`` presumably maps the per-module counter to a transformer-block index
        # (one self- and one cross-attention module per block) -- TODO confirm
        if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
            return super().forward(q, k, v, is_cross, place_in_unet, num_heads, **kwargs)

        if self.guidance_scale > 1.0:
            # batch layout implied by the slicing: [uncond src, uncond tgt, cond src, cond tgt]
            out_u = self._mutual_attention(q[0:2], k[0:2], v[0:2])
            out_c = self._mutual_attention(q[2:4], k[2:4], v[2:4])
            return torch.cat([out_u, out_c], dim=0)
        return self._mutual_attention(q, k, v)

# forward function for default attention processor
# modified from __call__ function of AttnProcessor in diffusers
def override_attn_proc_forward(attn, editor, place_in_unet):
    """Build a replacement ``forward`` for an attention module that routes attention through ``editor``.

    Args:
        attn: the attention module being patched; its ``to_q``/``to_k``/``to_v``/``to_out``,
            ``heads`` and ``scale`` are used (and ``norm_cross`` when present).
        editor: callable invoked as
            ``editor(q, k, v, is_cross, place_in_unet, num_heads, scale=...)`` that
            returns the heads-merged attention output.
        place_in_unet: location tag forwarded to the editor.

    Returns:
        A closure suitable for assigning to ``attn.forward``.
    """
    def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
        """
        The attention is similar to the original implementation of LDM CrossAttention class
        except adding some modifications on the attention

        NOTE: ``attention_mask``/``mask`` are accepted for signature compatibility but are
        not forwarded to the editor, so they have no effect.
        """
        if encoder_hidden_states is not None:
            context = encoder_hidden_states

        heads = attn.heads
        q = attn.to_q(x)
        is_cross = context is not None
        if not is_cross:
            context = x
        elif getattr(attn, "norm_cross", None):
            # same encoder-state normalization the LoRA variant applies
            context = attn.norm_encoder_hidden_states(context)
        k = attn.to_k(context)
        v = attn.to_v(context)

        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h=heads) for t in (q, k, v))

        # the only difference: the editor computes the attention
        out = editor(q, k, v, is_cross, place_in_unet, heads, scale=attn.scale)

        if isinstance(attn.to_out, nn.ModuleList):
            # output projection followed by dropout, applied in order as the LoRA variant does
            for layer in attn.to_out:
                out = layer(out)
            return out
        return attn.to_out(out)

    return forward

# forward function for lora attention processor
# modified from __call__ function of LoRAAttnProcessor2_0 in diffusers v0.17.1
def override_lora_attn_proc_forward(attn, editor, place_in_unet):
    """Build a replacement ``forward`` for a LoRA-processor attention module that routes
    attention through ``editor``.

    Args:
        attn: the attention module being patched; besides the base projections it reads
            the LoRA layers ``attn.processor.to_q_lora``/``to_k_lora``/``to_v_lora``/``to_out_lora``.
        editor: callable invoked as
            ``editor(q, k, v, is_cross, place_in_unet, num_heads, scale=...)``.
        place_in_unet: location tag forwarded to the editor.

    Returns:
        A closure suitable for assigning to ``attn.forward``.
    """
    def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, lora_scale=1.0, scale=None):
        """
        Args:
            hidden_states: 3-D token input, or 4-D ``(b, c, h, w)`` input that is flattened
                to tokens and restored on output.
            encoder_hidden_states: cross-attention context; ``None`` means self-attention.
            attention_mask: prepared and reshaped as in diffusers, but NOT forwarded to the
                editor, so it has no effect on the result.
            lora_scale: weight of the LoRA branches.
            scale: alias for ``lora_scale`` using the keyword name of diffusers'
                ``LoRAAttnProcessor2_0``, so callers passing ``scale=`` (as
                ``cross_attention_kwargs`` does for the stock processor) do not hit a
                TypeError; overrides ``lora_scale`` when given.
        """
        if scale is not None:
            lora_scale = scale

        residual = hidden_states
        input_ndim = hidden_states.ndim
        is_cross = encoder_hidden_states is not None

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            # NOTE(review): the prepared mask is never passed to the editor below
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states) + lora_scale * attn.processor.to_q_lora(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states) + lora_scale * attn.processor.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + lora_scale * attn.processor.to_v_lora(encoder_hidden_states)

        query, key, value = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=attn.heads), (query, key, value))

        # the only difference
        hidden_states = editor(
            query, key, value, is_cross, place_in_unet,
            attn.heads, scale=attn.scale)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + lora_scale * attn.processor.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

    return forward

def register_attention_editor_diffusers(model, editor: "AttentionBase", attn_processor='attn_proc'):
    """
    Register an attention editor to a Diffusers pipeline, adapted from [Prompt-to-Prompt].

    Every module whose class is named ``Attention`` inside the down/mid/up children of
    ``model.unet`` gets its ``forward`` replaced by a closure that routes attention
    through ``editor``. ``editor.num_att_layers`` is then set to the number of patched
    modules.

    Args:
        model: pipeline-like object exposing a ``unet`` attribute.
        editor: attention editor invoked by the patched modules.
        attn_processor: ``'attn_proc'`` for the default processor layout or
            ``'lora_attn_proc'`` for the LoRA processor layout.

    Raises:
        NotImplementedError: if ``attn_processor`` is not a supported name (checked
            before any module is patched).
    """
    # resolve the factory up front so an unsupported name fails even when no Attention exists
    if attn_processor == 'attn_proc':
        make_forward = override_attn_proc_forward
    elif attn_processor == 'lora_attn_proc':
        make_forward = override_lora_attn_proc_forward
    else:
        raise NotImplementedError("not implemented")

    def register_editor(net, count, place_in_unet):
        """Patch every ``Attention`` module under ``net``; return ``count`` plus the number patched."""
        if net.__class__.__name__ == 'Attention':  # spatial Transformer layer
            net.forward = make_forward(net, editor, place_in_unet)
            return count + 1
        for subnet in net.children():
            count = register_editor(subnet, count, place_in_unet)
        return count

    att_count = 0
    for net_name, net in model.unet.named_children():
        # first matching tag wins, in the order down -> mid -> up
        for place in ("down", "mid", "up"):
            if place in net_name:
                att_count += register_editor(net, 0, place)
                break
    editor.num_att_layers = att_count
