import torch
import torch.nn as nn
from ldm.modules.attention import BasicTransformerBlock
from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder
import torch.nn.functional as F



class PositionNet(nn.Module):
    """Project grounding entities (feature + optional location) into tokens.

    Each entity's feature vector (``positive_embeddings``) is optionally
    concatenated with a Fourier embedding of its location -- a box (xyxy) or a
    point (c_x, c_y) -- and passed through a 3-layer SiLU MLP. Slots whose mask
    is 0 (padding) first have both inputs replaced by learnable "null" vectors.

    Args:
        in_dim: size of the last axis of ``positive_embeddings``.
        out_dim: size of the last axis of the output.
        mid_dim: hidden width of the MLP.
        fourier_freqs: number of frequencies passed to ``FourierEmbedder``.
        position_net_point_or_box: ``'box'``, ``'point'`` or ``None`` (no
            location input; only the entity features are projected).

    Raises:
        ValueError: if ``position_net_point_or_box`` is not one of the above.
    """

    # Position features per Fourier frequency: 2 (sin & cos) x number of coords.
    # NOTE(review): assumes FourierEmbedder emits fourier_freqs*2 features per
    # input coordinate, so its output width matches position_dim -- confirm.
    _POSITION_FEATURES_PER_FREQ = {
        'box': 2 * 4,    # 4 coords: x0, y0, x1, y1
        'point': 2 * 2,  # 2 coords: c_x, c_y
        None: 0,         # no position input
    }

    def __init__(self, in_dim, out_dim, mid_dim=512, fourier_freqs=8, position_net_point_or_box='box'):
        super().__init__()
        # Validate first: an unknown mode used to surface later as an opaque
        # AttributeError (position_dim never set).
        if position_net_point_or_box not in self._POSITION_FEATURES_PER_FREQ:
            raise ValueError(
                "position_net_point_or_box must be 'box', 'point' or None, "
                f"got {position_net_point_or_box!r}"
            )
        self.in_dim = in_dim
        self.out_dim = out_dim

        self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
        self.position_net_point_or_box = position_net_point_or_box
        self.position_dim = fourier_freqs * self._POSITION_FEATURES_PER_FREQ[position_net_point_or_box]

        self.linears = nn.Sequential(
            nn.Linear(self.in_dim + self.position_dim, mid_dim),
            nn.SiLU(),
            nn.Linear(mid_dim, mid_dim),
            nn.SiLU(),
            nn.Linear(mid_dim, out_dim),
        )

        # Learnable substitutes for padded (mask == 0) entity slots.
        self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.in_dim]))
        self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))

    def forward(self, boxes, centers, masks, positive_embeddings):
        """Embed entities, swapping padded slots for the learnable null vectors.

        Args:
            boxes: box coordinates, read only in ``'box'`` mode; presumably
                (B, N, 4) -- TODO confirm.
            centers: point coordinates, read only in ``'point'`` mode;
                presumably (B, N, 2) -- TODO confirm.
            masks: per-entity weight, 1 keeps the entity and 0 replaces it with
                the null embedding; presumably (B, N). Bool masks are accepted.
            positive_embeddings: entity features whose last axis is ``in_dim``.

        Returns:
            The MLP output; its last axis has size ``out_dim``.
        """
        mode = self.position_net_point_or_box
        if mode == 'box':
            # embedding position (it may include padding as placeholder)
            xyxy_embedding = self.fourier_embedder(boxes)  # B*N*4 --> B*N*C
        elif mode == 'point':
            # embedding position (it may include padding as placeholder)
            xyxy_embedding = self.fourier_embedder(centers)  # B*N*2 --> B*N*C
        elif mode is None:
            xyxy_embedding = None
        else:
            raise ValueError(
                "position_net_point_or_box must be 'box', 'point' or None, "
                f"got {mode!r}"
            )

        # Trailing axis so the mask broadcasts over the feature dimension.
        masks = masks.unsqueeze(-1)
        if masks.dtype == torch.bool:
            # `1 - masks` is not defined for bool tensors.
            masks = masks.to(positive_embeddings.dtype)

        # replace padding with learnable null embedding
        positive_null = self.null_positive_feature.view(1, 1, -1)
        positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null
        if xyxy_embedding is None:
            return self.linears(positive_embeddings)

        xyxy_null = self.null_position_feature.view(1, 1, -1)
        xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
        return self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))



