import einops
import torch
from torch import nn

from src.models.kappa_overrides.mlp import Mlp
from src.modules.positional_embeddings import ContinuousSincosEmbed


class MlpProbeEncoding(nn.Module):
    """Encode probe values together with probe positions.

    The probe values ``x`` and the probe positions ``pos`` are both passed
    through the same ``ContinuousSincosEmbed`` module. The two embeddings are
    concatenated along the last axis and projected by an ``Mlp`` whose output
    width is ``hidden_dim``.
    """

    def __init__(
            self,
            hidden_dim: int,
            ndim: int,
            field_dim: int,
    ):
        """Build the shared sin-cos embedding and the projection MLP.

        Args:
            hidden_dim: ``dim`` of the sin-cos embedding and ``out_dim`` of the MLP.
            ndim: ``ndim`` passed to ``ContinuousSincosEmbed``. Because the same
                embedding is applied to both ``x`` and ``pos`` in ``forward``,
                both presumably need ``ndim`` trailing features -- TODO confirm.
            field_dim: Currently unused. NOTE(review): it was possibly intended
                as the feature count of ``x``; kept only for interface
                compatibility.
        """
        super().__init__()

        # One embedding module is shared by the probe values and the positions.
        self.pos_embed = ContinuousSincosEmbed(dim=hidden_dim, ndim=ndim)

        # The MLP input is the concatenation of two hidden_dim-wide embeddings.
        self.mlp = Mlp(
            in_dim=2 * hidden_dim,
            hidden_dim=2 * hidden_dim,
            out_dim=hidden_dim,
        )

    def forward(
            self,
            x: torch.Tensor,
            pos: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Embed probe values and positions, then mix them with the MLP.

        Args:
            x: 4-D tensor laid out as ``(b, T, n, d)``. ``n`` indexes the probes;
                ``T`` is presumably a time/step axis and the values presumably
                velocities -- TODO confirm.
            pos: Tensor laid out as ``(b, n, d)``, i.e. one position per probe,
                shared across all ``T`` slices of ``x``.

        Returns:
            A tuple ``(encoded, pos)``. ``encoded`` is the MLP output for the
            concatenated embeddings, which are laid out as ``(b, T, n, 2 * hidden_dim)``.
            ``pos`` is returned unchanged.

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        if x.ndim != 4:
            raise ValueError(
                f"expected x with 4 dims (b, T, n, d), got shape {tuple(x.shape)}"
            )
        num_steps = x.shape[1]

        # Positions do not depend on T, so embed once and broadcast over T.
        pos_enc = einops.repeat(
            self.pos_embed(pos), "b n d -> b T n d", T=num_steps
        )

        # Fold T into the batch axis so the embedding sees (batch, n, d) input,
        # matching the layout used for pos above.
        flat_x = einops.rearrange(x, "b T n d -> (b T) n d")
        velocity_enc = einops.rearrange(
            self.pos_embed(flat_x), "(b T) n d -> b T n d", T=num_steps
        )

        # Value embedding first, position embedding second (order matters for
        # the MLP's input weights).
        x_cat = torch.cat([velocity_enc, pos_enc], dim=-1)

        return self.mlp(x_cat), pos