import torch
import torch.nn as nn

from kappamodules.layers import ContinuousSincosEmbed, LinearProjection

class Patchify(nn.Module):
    """Embed field values and their continuous positions into one token each.

    The field values go through a linear projection to ``hidden_dim``. The
    positions go through ``ContinuousSincosEmbed`` (presumably a sinusoidal
    encoding of continuous coordinates, per its name) followed by a linear
    projection. The two embeddings are summed and fed to a final linear
    projection.

    NOTE(review): despite the class name, no grouping of points into patches
    happens in this module; each input element is embedded independently of
    the others as far as this code shows.

    Args:
        hidden_dim: width of every embedding produced by this module.
        field_dim: input width handed to the field projection; assumed to be
            the size of the last axis of ``field`` (TODO confirm).
        ndim: number of coordinates per position handed to the sin-cos
            embedding; assumed to be the last axis of ``input_pos``.
    """

    def __init__(
        self,
        hidden_dim: int,
        field_dim: int,
        ndim: int,
    ):
        super().__init__()

        # Submodules are created in a fixed order: it determines their
        # registration (and hence state_dict) order, and, if these layers
        # initialise their weights randomly, the order of RNG consumption.
        self.field_proj = LinearProjection(field_dim, hidden_dim)
        self.pos_embed = ContinuousSincosEmbed(dim=hidden_dim, ndim=ndim)
        self.pos_proj = LinearProjection(hidden_dim, hidden_dim)
        # Deliberately wrapped in nn.Sequential: parameters are stored under
        # "out_proj.0.*", so unwrapping would rename checkpoint keys. No
        # activation is applied between the summed embedding and this layer.
        self.out_proj = nn.Sequential(LinearProjection(hidden_dim, hidden_dim))

    def forward(
        self,
        field: torch.Tensor,
        input_pos: torch.Tensor,
    ):
        """Return ``out_proj(pos_proj(pos_embed(input_pos)) + field_proj(field))``.

        Args:
            field: per-element field values; presumably shaped
                ``(..., field_dim)`` (TODO confirm against callers).
            input_pos: per-element coordinates; presumably shaped
                ``(..., ndim)`` with leading axes matching ``field``.

        Returns:
            The projected sum of the position and field embeddings. The two
            embeddings must be broadcast-compatible for the addition.
        """
        pos_tokens = self.pos_proj(self.pos_embed(input_pos))
        field_tokens = self.field_proj(field)
        return self.out_proj(pos_tokens + field_tokens)
        