
import xformers.ops
import torch
import math
import sys
import torch.nn.functional as F

def register_attention_control(model, controller):
    """Replace the ``forward`` of attention layers in ``model`` with hooked versions.

    Layers are located by class name inside every direct child of ``model``
    whose attribute name contains ``'DiTMoudle'`` (matched literally;
    presumably this mirrors the attribute spelling used by the model --
    verify against the model definition):

    * ``MultiHeadCrossAttention`` layers get a forward that passes the list
      of per-sample attention weights through ``controller(attn_list)``.
    * ``Attention`` layers get a forward that passes the pre-softmax scores
      through ``controller.SAforward(attn)``.

    Args:
        model: Object exposing ``named_children()``; matching children are
            iterated, and each element is searched recursively.
        controller: Object providing ``__call__(attn_list)``,
            ``SAforward(attn)`` and a writable ``num_att_layers`` attribute.
            ``None`` installs a pass-through controller.

    Side effects:
        Overwrites ``forward`` on the matched layers, sets
        ``controller.num_att_layers`` to the number of patched
        cross-attention layers, and prints that number.
    """

    def ca_forward_sa(self):
        """Build the hooked self-attention forward bound to layer ``self``."""

        def forward(x: torch.Tensor) -> torch.Tensor:
            B, N, C = x.shape
            qkv = self.qkv(x)
            # (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
            qkv = qkv.view(B, N, 3, self.num_heads, self.head_dim)
            q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

            # The legacy flag only changes the order of rotary embedding
            # and q/k normalisation.
            if self.qk_norm_legacy:
                if self.rope:
                    q = self.rotary_emb(q)
                    k = self.rotary_emb(k)
                q, k = self.q_norm(q), self.k_norm(k)
            else:
                q, k = self.q_norm(q), self.k_norm(k)
                if self.rope:
                    q = self.rotary_emb(q)
                    k = self.rotary_emb(k)

            dtype = q.dtype
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            # The controller sees (and may replace) the pre-softmax scores.
            attn = controller.SAforward(attn)
            # Softmax runs in float32, then the result is cast back.
            attn = attn.to(torch.float32)
            attn = attn.softmax(dim=-1)
            attn = attn.to(dtype)
            attn = self.attn_drop(attn)

            x = attn @ v
            # (B, heads, N, head_dim) -> (B, N, heads, head_dim) -> (B, N, C)
            x = x.transpose(1, 2).reshape(B, N, C)
            x = self.proj(x)
            x = self.proj_drop(x)
            return x

        return forward

    def ca_forward(self):
        """Build the hooked cross-attention forward bound to layer ``self``."""

        def forward(x, cond, mask=None):
            """Cross-attend ``x`` (B, N, C) to ``cond``.

            ``mask`` is indexed per sample to get that sample's number of
            keys. When it is omitted, ``cond`` must have batch size ``B``
            and every sample attends to all tokens of its own row. When it
            is given, sample ``b`` uses the key segment that starts where
            sample ``b - 1`` ended along the token axis, which fits a
            ``cond`` holding all samples' tokens concatenated.
            NOTE(review): confirm callers passing ``mask`` use that layout.
            """
            B, N, C = x.shape
            # Remember whether the key lengths were derived from a
            # per-sample ``cond``; in that case every sample's keys start
            # at offset 0 of its own row rather than at a running offset.
            per_sample_cond = mask is None
            if per_sample_cond:
                Bc, Nc, _ = cond.shape
                assert Bc == B, "cond batch size must match x when mask is None"
                mask = [Nc] * B

            q = self.q_linear(x).view(B, N, self.num_heads, self.head_dim)

            # A single condition row is shared by every sample.
            if cond.shape[0] == 1 and B > 1:
                cond = cond.repeat(B, 1, 1)
            kv = self.kv_linear(cond).view(
                B, -1, 2, self.num_heads, self.head_dim
            )
            k, v = kv.unbind(2)
            q = q * (1.0 / math.sqrt(self.head_dim))

            # (B, tokens, heads, head_dim) -> (B, heads, tokens, head_dim)
            q = q.transpose(1, 2)
            k = k.transpose(1, 2)
            v = v.transpose(1, 2)

            attn_list = []
            v_b_list = []
            key_start = 0
            for b in range(B):
                key_len = mask[b]
                start = 0 if per_sample_cond else key_start
                k_b = k[b, :, start:start + key_len, :]
                v_b = v[b, :, start:start + key_len, :]
                v_b_list.append(v_b)
                attn_scores = torch.matmul(q[b], k_b.transpose(-2, -1))
                attn_weights = torch.softmax(attn_scores, dim=-1)
                attn_list.append(self.attn_drop(attn_weights))
                key_start += key_len

            # The controller receives all samples' weights at once and
            # returns the (possibly edited) list.
            attn_list = controller(attn_list)

            outputs = [
                torch.matmul(attn_weights, v_b)
                for attn_weights, v_b in zip(attn_list[:B], v_b_list)
            ]
            out = torch.stack(outputs, dim=0)
            out = out.transpose(1, 2).reshape(B, N, C)
            out = self.proj(out)
            out = self.proj_drop(out)
            return out

        return forward

    class DummyController:
        """Pass-through controller used when the caller supplies none."""

        def __init__(self):
            self.num_att_layers = 0

        def __call__(self, *args):
            return args[0]

        def SAforward(self, attn):
            return attn

    if controller is None:
        controller = DummyController()

    def register_recr(net_, count):
        """Patch ``net_`` or its descendants; return ``count`` plus the
        number of cross-attention layers patched."""
        layer_type = net_.__class__.__name__
        if layer_type == 'MultiHeadCrossAttention':
            net_.forward = ca_forward(net_)
            return count + 1
        if layer_type == 'Attention':
            # Self-attention layers are patched but not counted.
            net_.forward = ca_forward_sa(net_)
            return count
        if hasattr(net_, 'children'):
            for child in net_.children():
                count = register_recr(child, count)
        return count

    cross_att_count = 0
    for name, child in model.named_children():
        if 'DiTMoudle' in name:
            for layer in child:
                cross_att_count += register_recr(layer, 0)

    controller.num_att_layers = cross_att_count
    print(f"cross_attention:{cross_att_count}")