import torch.nn as nn
import torch
import torch.nn.init as init


class StaticEncoder(nn.Module):
    """Encode fixed-width feature vectors with a two-layer MLP.

    The network is ``Linear(in_dim, hidden_dim) -> ReLU -> Dropout ->
    Linear(hidden_dim, out_dim)``. ``nn.Linear`` acts on the last dimension,
    so any number of leading dimensions is passed through unchanged.

    Inputs whose last dimension differs from ``in_dim`` are coerced to it:
    narrower inputs are zero-padded on the right, and wider inputs are
    truncated, which silently drops the trailing features.

    Args:
        in_dim: Expected size of the input's last dimension.
        out_dim: Size of the output's last dimension.
        hidden_dim: Width of the hidden layer.
        dropout: Dropout probability applied after the hidden ReLU
            (only active in training mode, per ``nn.Dropout``).
    """

    def __init__(
        self,
        in_dim: int = 10,
        out_dim: int = 32,
        hidden_dim: int = 64,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.hidden_dim = hidden_dim

        # Keep this exact layer order: the ``mlp.<index>`` names in the
        # state_dict depend on it, so reordering would break checkpoints.
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, out_dim),
        )

        self._initialize_weights()

    def _initialize_weights(self) -> None:
        """Apply Xavier-uniform init to every Linear weight and zero its bias."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    init.constant_(module.bias, 0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Fit ``x``'s last dimension to ``in_dim``, then run the MLP.

        Args:
            x: Tensor with at least one dimension; its last dimension is
                treated as the feature axis.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``out_dim``.
        """
        width = x.shape[-1]
        if width < self.in_dim:
            # Right-pad only the last dimension with zeros; the result keeps
            # the dtype and device of ``x``.
            x = nn.functional.pad(x, (0, self.in_dim - width))
        elif width > self.in_dim:
            # Extra trailing features are dropped without warning.
            x = x[..., : self.in_dim]
        return self.mlp(x)