import torch
from torch import nn

class ConvFusionBlock(nn.Module):
    """Fuse point-cloud features with a window of visual features.

    The point-cloud features go through a linear projection. The visual
    features go through a single ``Conv2d`` whose square kernel has side
    ``2 * window_radius + 1`` and no padding. The two projections are
    concatenated into an ``output_size``-wide vector. That vector goes
    through a two-layer MLP with a residual connection, then through
    ``BatchNorm1d``.

    Args:
        pc_hidden_size: feature size of the point-cloud input.
        vis_hidden_size: channel count of the visual input.
        output_size: width of the fused output. Odd values are supported;
            the visual branch then receives the extra channel.
        window_radius: radius of the visual window; the convolution kernel
            side is ``2 * window_radius + 1``.
    """

    def __init__(self, pc_hidden_size, vis_hidden_size, output_size, window_radius):
        super().__init__()

        # Split the output channels so the two branches always concatenate
        # to exactly ``output_size``. Using ``output_size // 2`` for both
        # halves would drop a channel for odd sizes and break ``self.fusion``.
        pc_size = output_size // 2
        vis_size = output_size - pc_size

        self.pc_proj = nn.Linear(pc_hidden_size, pc_size)
        self.vis_proj = nn.Conv2d(vis_hidden_size, vis_size, kernel_size=window_radius * 2 + 1)

        # Position-wise MLP with a 2x hidden expansion; applied residually in forward().
        self.fusion = nn.Sequential(
            nn.Linear(output_size, output_size * 2),
            nn.ReLU(),
            nn.Linear(output_size * 2, output_size),
        )

        self.norm = nn.BatchNorm1d(output_size)

    def forward(self, feats_pc, feats_vis):
        """Fuse one point-cloud feature vector and one visual window per sample.

        Args:
            feats_pc: point-cloud features with last axis ``pc_hidden_size``.
                Presumably ``(B, pc_hidden_size)``, since ``BatchNorm1d`` is
                applied to the result; TODO confirm with callers.
            feats_vis: visual features ``(B, vis_hidden_size, H, W)``.
                Assumes ``H == W == 2 * window_radius + 1``, so the convolution
                output is ``1x1`` and the two squeezes leave ``(B, C)``.
                TODO confirm.

        Returns:
            Tensor of width ``output_size``: the batch-normalised sum of the
            concatenated projections and their MLP transform.
        """
        feats_pc = self.pc_proj(feats_pc)
        # Drop the trailing 1x1 spatial dims produced by the unpadded conv.
        feats_vis = self.vis_proj(feats_vis).squeeze(-1).squeeze(-1)

        input_states = torch.cat([feats_pc, feats_vis], dim=-1)
        hidden_states = self.fusion(input_states)
        # Residual connection followed by normalisation.
        hidden_states = self.norm(input_states + hidden_states)

        return hidden_states
    
class LearnableFourierPositionalEncoding(nn.Module):
    """Learnable Fourier-feature encoding of position vectors.

    A bias-free linear layer maps each ``M``-dimensional position to
    ``F_dim // 2`` phases. The cosine and sine of those phases are stacked
    along a new leading axis. Every feature is then duplicated along the last
    axis.

    Args:
        M: size of the last axis of the input positions.
        dim: fallback value for ``F_dim`` when it is not given.
        F_dim: Fourier feature size; the projection has ``F_dim // 2``
            outputs. Defaults to ``dim``.
        gamma: initialisation scale. The projection weights are drawn from a
            normal distribution with mean 0 and standard deviation
            ``gamma ** -2``.
    """

    def __init__(self, M: int, dim: int, F_dim: int = None,
                 gamma: float = 1.0) -> None:
        super().__init__()
        if F_dim is None:
            F_dim = dim
        self.gamma = gamma
        self.Wr = nn.Linear(M, F_dim // 2, bias=False)
        # NOTE(review): the learnable-Fourier-features formulation samples Wr
        # from N(0, gamma**-2), i.e. variance gamma**-2 / std gamma**-1.
        # Here std=gamma**-2 is used. The two are identical for the default
        # gamma=1.0; confirm which is intended before changing gamma.
        nn.init.normal_(self.Wr.weight, mean=0, std=self.gamma ** -2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode position vectors as cos/sin Fourier features.

        Args:
            x: positions whose last axis has size ``M``. Presumably
                ``(B, N, M)``; TODO confirm with callers.

        Returns:
            Cosines at index 0 and sines at index 1 of a new leading axis.
            A singleton axis is inserted third from the end. Each feature is
            repeated twice along the last axis. For a ``(B, N, M)`` input the
            shape is ``(2, B, 1, N, 2 * (F_dim // 2))``.
        """
        phase = self.Wr(x)
        fourier = torch.stack((phase.cos(), phase.sin()), dim=0)
        # Singleton axis before N, presumably to broadcast over attention heads.
        fourier = fourier.unsqueeze(-3)
        return fourier.repeat_interleave(2, dim=-1)