# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# modified from DUSt3R

import torch
import torch.nn as nn

from itertools import repeat
import collections.abc
from torch.nn.functional import scaled_dot_product_attention
from functools import partial


def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return x
        return tuple(repeat(x, n))

    return parse


to_2tuple = _ntuple(2)


def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
    """Randomly zero whole samples of ``x`` (stochastic depth).

    Meant for the residual branch of a block: every sample along dim 0 is
    either kept or zeroed as a unit.

    Args:
        x: input tensor whose first dimension indexes samples.
        drop_prob: probability of zeroing a sample.
        training: when False (or when ``drop_prob == 0``) ``x`` is returned as-is.
        scale_by_keep: divide kept samples by the keep probability so the
            expected value of the output matches ``x``.

    Returns:
        ``x`` itself when inactive, otherwise a new tensor with dropped samples zeroed.
    """
    active = training and drop_prob != 0.0
    if not active:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over every remaining dimension.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    keep_mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if scale_by_keep and keep_prob > 0.0:
        keep_mask.div_(keep_prob)
    return x * keep_mask


class DropPath(nn.Module):
    """Module wrapper around ``drop_path`` (per-sample stochastic depth).

    Active only in training mode; in eval mode ``forward`` returns its input unchanged.
    """

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(
            x,
            drop_prob=self.drop_prob,
            training=self.training,
            scale_by_keep=self.scale_by_keep,
        )

    def extra_repr(self):
        rounded = round(self.drop_prob, 3)
        return "drop_prob={:0.3f}".format(rounded)


class Mlp(nn.Module):
    """Two-layer feed-forward network used inside transformer blocks.

    Computes ``Dropout(Linear(Dropout(act(Linear(x)))))``. ``bias`` and
    ``drop`` may each be a scalar (shared by both layers) or a pair giving
    separate settings for the first and second layer.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        bias=True,
        drop=0.0,
    ):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        biases = to_2tuple(bias)
        dropouts = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias=biases[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(dropouts[0])
        self.fc2 = nn.Linear(hidden_features, out_features, bias=biases[1])
        self.drop2 = nn.Dropout(dropouts[1])

    def forward(self, x):
        hidden = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(hidden))

class DecodeAttention(nn.Module):

    def __init__(
        self, dim, rope=None, rope3d=None, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.rope = rope.float() if rope is not None else None
        self.rope3d = rope3d.float() if rope3d is not None else None

    def forward(self, mask, x, xpos, point3r_tag=False):
        B, N, C = x.shape

        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .transpose(1, 3)
        )
        q, k, v = [qkv[:, :, i] for i in range(3)]

        if xpos is not None:
            q_type = q.dtype
            k_type = k.dtype
            if self.rope3d is not None:
                q = q.float()
                k = k.float()
                with torch.autocast(device_type="cuda", enabled=False):
                    q = self.rope3d(q, xpos)
                    k = self.rope3d(k, xpos)
                q = q.to(q_type)
                k = k.to(k_type)
        x = (
            scaled_dot_product_attention(
                query=q, key=k, value=v, dropout_p=self.attn_drop.p, scale=self.scale
            )
            .transpose(1, 2)
            .reshape(B, N, C)
        )

        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class Attention(nn.Module):

    def __init__(
        self, dim, rope=None, rope3d=None, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.rope = rope.float() if rope is not None else None
        self.rope3d = rope3d.float() if rope3d is not None else None

    def forward(self, x, xpos, point3r_tag=False):
        B, N, C = x.shape

        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .transpose(1, 3)
        )
        q, k, v = [qkv[:, :, i] for i in range(3)]

        q_type = q.dtype
        k_type = k.dtype
        if self.rope is not None:
            q = q.float()
            k = k.float()
            with torch.autocast(device_type="cuda", enabled=False):
                q = self.rope(q, xpos)
                k = self.rope(k, xpos)
            q = q.to(q_type)
            k = k.to(k_type)

        x = (
            scaled_dot_product_attention(
                query=q, key=k, value=v, dropout_p=self.attn_drop.p, scale=self.scale
            )
            .transpose(1, 2)
            .reshape(B, N, C)
        )

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
    
class MemoryAttention(nn.Module):

    def __init__(
        self, dim, rope=None, rope3d=None, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.rope = rope.float() if rope is not None else None
        self.rope3d = rope3d.float() if rope3d is not None else None

    def forward(self, mask, x, xpos, point3r_tag=False):
        B, N, C = x.shape

        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .transpose(1, 3)
        )
        q, k, v = [qkv[:, :, i] for i in range(3)]
        
        if xpos is not None:
            if xpos.shape[-1] == 2:
                q_type = q.dtype
                k_type = k.dtype
                if self.rope is not None:
                    q = q.float()
                    k = k.float()
                    with torch.autocast(device_type="cuda", enabled=False):
                        q = self.rope(q, xpos)
                        k = self.rope(k, xpos) # bs, num_heads, N, C // num_heads
                    q = q.to(q_type)
                    k = k.to(k_type)
            else:
                q_type = q.dtype
                k_type = k.dtype
                if self.rope3d is not None:
                    q = q.float()
                    k = k.float()
                    with torch.autocast(device_type="cuda", enabled=False):
                        q = self.rope3d(q, xpos)
                        k = self.rope3d(k, xpos)
                    q = q.to(q_type)
                    k = k.to(k_type)
        if mask is not None:
            mask = mask.bool()
            attn_mask = mask.unsqueeze(-1) & mask.unsqueeze(1)
            attn_mask = attn_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
            attn_mask = attn_mask.bool()
        else:
            attn_mask = None
        x = (
            scaled_dot_product_attention(
                query=q, key=k, value=v, attn_mask=attn_mask, dropout_p=self.attn_drop.p, scale=self.scale
            )
            .transpose(1, 2)
            .reshape(B, N, C)
        )

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):
    """Pre-norm transformer encoder block: self-attention, then an MLP.

    Each sub-layer sits in a residual branch. Both branches share one
    DropPath, which is ``nn.Identity`` when ``drop_path`` is 0.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        rope=None,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            rope=rope,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x, xpos):
        attn_out = self.attn(self.norm1(x), xpos)
        x = x + self.drop_path(attn_out)
        mlp_out = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_out)


class CrossAttention(nn.Module):

    def __init__(
        self, dim, rope=None, rope3d=None, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.projq = nn.Linear(dim, dim, bias=qkv_bias)
        self.projk = nn.Linear(dim, dim, bias=qkv_bias)
        self.projv = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.rope = rope.float() if rope is not None else None
        self.rope3d = rope3d.float() if rope3d is not None else None

    def forward(self, query, key, value, qpos, kpos, point3r_tag=False):
        B, Nq, C = query.shape
        Nk = key.shape[1]
        Nv = value.shape[1]

        q = (
            self.projq(query)
            .reshape(B, Nq, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )
        k = (
            self.projk(key)
            .reshape(B, Nk, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )
        v = (
            self.projv(value)
            .reshape(B, Nv, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )

        if kpos != None: 
            if kpos.shape[-1] == 2:
                q_type = q.dtype
                k_type = k.dtype
                if self.rope is not None:
                    if qpos is not None:
                        q = q.float()
                        with torch.autocast(device_type="cuda", enabled=False):
                            q = self.rope(q, qpos)
                        q = q.to(q_type)

                    if kpos is not None:
                        k = k.float()
                        with torch.autocast(device_type="cuda", enabled=False):
                            k = self.rope(k, kpos)
                        k = k.to(k_type)
            else:
                q_type = q.dtype
                k_type = k.dtype
                if self.rope is not None:
                    if qpos is not None:
                        q = q.float()
                        with torch.autocast(device_type="cuda", enabled=False):
                            q = self.rope(q, qpos)
                        q = q.to(q_type)
                if self.rope3d is not None:
                    if kpos is not None:
                        k = k.float()
                        with torch.autocast(device_type="cuda", enabled=False):
                            k = self.rope3d(k, kpos)
                        k = k.to(k_type)

        x = (
            scaled_dot_product_attention(
                query=q, key=k, value=v, dropout_p=self.attn_drop.p, scale=self.scale
            )
            .transpose(1, 2)
            .reshape(B, Nq, C)
        )

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
    
class DecodeCrossAttention(nn.Module):

    def __init__(
        self, dim, rope=None, rope3d=None, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.projq = nn.Linear(dim, dim, bias=qkv_bias)
        self.projk = nn.Linear(dim, dim, bias=qkv_bias)
        self.projv = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.rope = rope.float() if rope is not None else None
        self.rope3d = rope3d.float() if rope3d is not None else None

    def forward(self, mask, query, key, value, qpos, kpos, point3r_tag=False):
        B, Nq, C = query.shape
        Nk = key.shape[1]
        Nv = value.shape[1]

        q = (
            self.projq(query)
            .reshape(B, Nq, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )
        k = (
            self.projk(key)
            .reshape(B, Nk, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )
        v = (
            self.projv(value)
            .reshape(B, Nv, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )

        if kpos is not None: 
            if kpos.shape[-1] == 2:
                q_type = q.dtype
                k_type = k.dtype
                if self.rope is not None:
                    if qpos is not None:
                        q = q.float()
                        with torch.autocast(device_type="cuda", enabled=False):
                            q = self.rope(q, qpos)
                        q = q.to(q_type)

                    if kpos is not None:
                        k = k.float()
                        with torch.autocast(device_type="cuda", enabled=False):
                            k = self.rope(k, kpos)
                        k = k.to(k_type)
            else:
                
                q_type = q.dtype
                k_type = k.dtype
                if self.rope3d is not None:
                    if qpos is not None:
                        q = q.float()
                        with torch.autocast(device_type="cuda", enabled=False):
                            q = self.rope3d(q, qpos)
                        q = q.to(q_type)
                    if kpos is not None:
                        k = k.float()
                        with torch.autocast(device_type="cuda", enabled=False):
                            k = self.rope3d(k, kpos)
                        k = k.to(k_type)
        
        if mask is not None:
            mask = mask.bool()
            mask_img = torch.ones(B, Nq).to(mask.device).bool()
            attn_mask = mask_img.unsqueeze(-1) & mask.unsqueeze(1)
            attn_mask = attn_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
            attn_mask = attn_mask.bool()
        else:
            attn_mask = None

        x = (
            scaled_dot_product_attention(
                query=q, key=k, value=v, attn_mask=attn_mask, dropout_p=self.attn_drop.p, scale=self.scale
            )
            .transpose(1, 2)
            .reshape(B, Nq, C)
        )

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
    
class MemoryCrossAttention(nn.Module):

    def __init__(
        self, dim, rope=None, rope3d=None, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.projq = nn.Linear(dim, dim, bias=qkv_bias)
        self.projk = nn.Linear(dim, dim, bias=qkv_bias)
        self.projv = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.rope = rope.float() if rope is not None else None
        self.rope3d = rope3d.float() if rope3d is not None else None

    def forward(self, mask, query, key, value, qpos, kpos, point3r_tag=False):
        B, Nq, C = query.shape
        Nk = key.shape[1]
        Nv = value.shape[1]

        q = (
            self.projq(query)
            .reshape(B, Nq, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )
        k = (
            self.projk(key)
            .reshape(B, Nk, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )
        v = (
            self.projv(value)
            .reshape(B, Nv, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )

        if qpos is not None:
            if qpos.shape[-1] == 2:
                q_type = q.dtype
                k_type = k.dtype
                if self.rope is not None:
                    if qpos is not None:
                        q = q.float()
                        with torch.autocast(device_type="cuda", enabled=False):
                            q = self.rope(q, qpos)
                        q = q.to(q_type)

                    if kpos is not None:
                        k = k.float()
                        with torch.autocast(device_type="cuda", enabled=False):
                            k = self.rope(k, kpos)
                        k = k.to(k_type)
            else:
                q_type = q.dtype
                k_type = k.dtype
                if self.rope3d is not None:
                    if qpos is not None:
                        q = q.float()
                        with torch.autocast(device_type="cuda", enabled=False):
                            q = self.rope3d(q, qpos)
                        q = q.to(q_type)
                    if kpos is not None:
                        k = k.float()
                        with torch.autocast(device_type="cuda", enabled=False):
                            k = self.rope3d(k, kpos)
                        k = k.to(k_type)
        if mask is not None:
            mask = mask.bool()
            mask_img = torch.ones(B, Nk).to(mask.device).bool()
            attn_mask = mask.unsqueeze(-1) & mask_img.unsqueeze(1)
            attn_mask = attn_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
            attn_mask = attn_mask.bool()
        else:
            attn_mask = None
        x = (
            scaled_dot_product_attention(
                query=q, key=k, value=v, attn_mask=attn_mask, dropout_p=self.attn_drop.p, scale=self.scale
            )
            .transpose(1, 2)
            .reshape(B, Nq, C)
        )

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class DecoderBlock(nn.Module):
    """Pre-norm decoder block with three residual sub-layers.

    The sub-layers are:
    1. self-attention (DecodeAttention);
    2. cross-attention from ``x`` to ``y`` (DecodeCrossAttention);
    3. an MLP.
    All three share one DropPath, which is ``nn.Identity`` when ``drop_path``
    is 0.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        norm_mem=True,
        rope=None,
        rope3d=None,
    ):
        super().__init__()
        attn_kwargs = dict(
            rope=rope,
            rope3d=rope3d,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.norm1 = norm_layer(dim)
        self.attn = DecodeAttention(dim, **attn_kwargs)
        self.cross_attn = DecodeCrossAttention(dim, **attn_kwargs)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.norm3 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        # Optional normalisation of the cross-attention context ``y``.
        self.norm_y = norm_layer(dim) if norm_mem else nn.Identity()

    def forward(self, i, x, y, mask, xpos, ypos, point3r_tag=False):
        """Run one decoder step and return ``(x, y)``, with ``y`` unchanged.

        ``mask`` is forwarded to both attention layers. When ``i == 0`` every
        position argument is replaced by None.
        """
        if i == 0:
            self_pos, query_pos, key_pos = None, None, None
        else:
            self_pos, query_pos, key_pos = xpos, xpos, ypos

        x = x + self.drop_path(self.attn(mask, self.norm1(x), self_pos, point3r_tag))
        context = self.norm_y(y)
        cross_out = self.cross_attn(
            mask, self.norm2(x), context, context, query_pos, key_pos, point3r_tag
        )
        x = x + self.drop_path(cross_out)
        x = x + self.drop_path(self.mlp(self.norm3(x)))
        return x, y
    
class MemoryDecoderBlock(nn.Module):
    """Pre-norm decoder block with three residual sub-layers.

    The sub-layers are:
    1. masked self-attention (MemoryAttention);
    2. cross-attention from ``x`` to ``y`` (MemoryCrossAttention);
    3. an MLP.
    All three share one DropPath, which is ``nn.Identity`` when ``drop_path``
    is 0.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        norm_mem=True,
        rope=None,
        rope3d=None,
    ):
        super().__init__()
        attn_kwargs = dict(
            rope=rope,
            rope3d=rope3d,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.norm1 = norm_layer(dim)
        self.attn = MemoryAttention(dim, **attn_kwargs)
        self.cross_attn = MemoryCrossAttention(dim, **attn_kwargs)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.norm3 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        # Optional normalisation of the cross-attention context ``y``.
        self.norm_y = norm_layer(dim) if norm_mem else nn.Identity()

    def forward(self, i, x, y, mask, xpos, ypos, point3r_tag=False):
        """Run one decoder step and return ``(x, y)``, with ``y`` unchanged.

        ``mask`` is forwarded to both attention layers. When ``i == 0`` every
        position argument is replaced by None.
        """
        if i == 0:
            self_pos, query_pos, key_pos = None, None, None
        else:
            self_pos, query_pos, key_pos = xpos, xpos, ypos

        x = x + self.drop_path(self.attn(mask, self.norm1(x), self_pos, point3r_tag))
        context = self.norm_y(y)
        cross_out = self.cross_attn(
            mask, self.norm2(x), context, context, query_pos, key_pos, point3r_tag
        )
        x = x + self.drop_path(cross_out)
        x = x + self.drop_path(self.mlp(self.norm3(x)))
        return x, y


class CustomDecoderBlock(nn.Module):
    """Pre-norm decoder block whose cross-attention takes keys and values separately.

    Keys come from ``y`` and values from ``z``; each has its own optional
    normalisation. The residual sub-layers run in order: self-attention,
    cross-attention, MLP.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        norm_mem=True,
        rope=None,
    ):
        super().__init__()
        attn_kwargs = dict(
            rope=rope,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, **attn_kwargs)
        self.cross_attn = CrossAttention(dim, **attn_kwargs)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.norm3 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        self.norm_y = norm_layer(dim) if norm_mem else nn.Identity()
        self.norm_z = norm_layer(dim) if norm_mem else nn.Identity()

    def forward(self, x, y, z, xpos, ypos):
        """Update ``x`` and return ``(x, y)``, with ``y`` unchanged."""
        x = x + self.drop_path(self.attn(self.norm1(x), xpos))
        keys = self.norm_y(y)
        values = self.norm_z(z)
        cross_out = self.cross_attn(self.norm2(x), keys, values, xpos, ypos)
        x = x + self.drop_path(cross_out)
        x = x + self.drop_path(self.mlp(self.norm3(x)))
        return x, y


class ModLN(nn.Module):
    """Adaptive LayerNorm (adaLN).

    A LayerNorm whose output is shifted and scaled by values predicted from a
    conditioning vector.

    References:
    DiT: https://github.com/facebookresearch/DiT/blob/main/models.py#L101
    """

    def __init__(self, inner_dim: int, mod_dim: int, eps: float):
        super().__init__()
        self.norm = nn.LayerNorm(inner_dim, eps=eps)
        # Maps the condition to [shift | scale], each inner_dim wide.
        self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(mod_dim, inner_dim * 2))

    @staticmethod
    def modulate(x, shift, scale):
        """Return ``x * (1 + scale) + shift``; shift/scale are broadcast over dim 1."""
        scale = scale.unsqueeze(1)
        shift = shift.unsqueeze(1)
        return x * (1 + scale) + shift

    def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` [N, L, D] and modulate it with condition ``mod``."""
        shift, scale = self.mlp(mod).chunk(2, dim=-1)  # each [N, D]
        return self.modulate(self.norm(x), shift, scale)  # [N, L, D]


class ConditionModulationBlock(nn.Module):
    """Pre-norm transformer block whose norms are conditioned on ``mod``.

    Uses adaLN by default. Self-attention and the MLP each sit in a residual
    branch and share one DropPath.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=partial(ModLN, eps=1e-6),
        rope=None,
    ):
        super().__init__()
        # norm_layer is called as norm_layer(inner_dim, mod_dim); both are ``dim`` here.
        self.norm1 = norm_layer(dim, dim)
        self.attn = Attention(
            dim,
            rope=rope,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim, dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x, mod, xpos):
        normed = self.norm1(x, mod)
        x = x + self.drop_path(self.attn(normed, xpos))
        normed = self.norm2(x, mod)
        return x + self.drop_path(self.mlp(normed))


class PositionGetter(object):
    """Produce integer (row, col) coordinates for every cell of an h x w grid.

    Grids are cached per (h, w, device), so repeated calls reuse them.
    """

    def __init__(self):
        self.cache_positions = {}

    def __call__(self, b, h, w, device):
        """Return a (b, h*w, 2) tensor of (y, x) positions in row-major order.

        Args:
            b: batch size; every batch entry receives the same grid.
            h: grid height.
            w: grid width.
            device: device the returned tensor is placed on.
        """
        # The device is part of the cache key. Keying on (h, w) alone would
        # hand a grid built on one device (e.g. CPU) to a later caller on
        # another device (e.g. CUDA), causing a device mismatch downstream.
        key = (h, w, device)
        if key not in self.cache_positions:
            ys = torch.arange(h, device=device)
            xs = torch.arange(w, device=device)
            self.cache_positions[key] = torch.cartesian_prod(ys, xs)  # (h*w, 2)
        # clone() so callers can modify the result without corrupting the cache.
        return self.cache_positions[key].view(1, h * w, 2).expand(b, -1, 2).clone()


class PatchEmbed(nn.Module):
    """Embed non-overlapping image patches with a strided convolution.

    The Conv2d uses kernel size = stride = patch size. ``forward`` returns the
    patch tokens together with the (row, col) grid position of each patch.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        norm_layer=None,
        flatten=True,
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        grid_h = img_size[0] // patch_size[0]
        grid_w = img_size[1] // patch_size[1]
        self.grid_size = (grid_h, grid_w)
        self.num_patches = grid_h * grid_w
        self.flatten = flatten

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

        self.position_getter = PositionGetter()

    def forward(self, x):
        """Embed ``x`` (B, C, H, W); H and W must equal the configured ``img_size``.

        Returns:
            ``(tokens, pos)``. ``tokens`` is (B, N, embed_dim) when ``flatten``
            is set, otherwise (B, embed_dim, H', W'). ``pos`` is (B, N, 2).
        """
        B, _, H, W = x.shape
        torch._assert(
            H == self.img_size[0],
            f"Input image height ({H}) doesn't match model ({self.img_size[0]}).",
        )
        torch._assert(
            W == self.img_size[1],
            f"Input image width ({W}) doesn't match model ({self.img_size[1]}).",
        )
        tokens = self.proj(x)
        grid_h, grid_w = tokens.size(2), tokens.size(3)
        pos = self.position_getter(B, grid_h, grid_w, tokens.device)
        if self.flatten:
            tokens = tokens.flatten(2).transpose(1, 2)  # (B, C, H', W') -> (B, N, C)
        return self.norm(tokens), pos

    def _init_weights(self):
        # Xavier-uniform on the conv kernel viewed as an (out_channels, fan_in) matrix.
        kernel = self.proj.weight.data
        torch.nn.init.xavier_uniform_(kernel.view([kernel.shape[0], -1]))