"""
MLP models for SDF neural networks - simplified for LAMP generation
"""
import torch
import torch.nn.functional as F
from torch import nn
from .embedder import Embedder


class MLP3D(nn.Module):
    """3D MLP for SDF representation with positional encoding.

    Input coordinates are passed through a sin/cos ``Embedder`` (which also
    keeps the raw input), then through a stack of ``nn.Linear`` layers. ReLU
    (or leaky ReLU) is applied between layers; the final layer is linear.
    """

    def __init__(
        self,
        out_size=1,
        hidden_neurons=(1024, 1024, 1024),
        use_leaky_relu=False,
        use_bias=True,
        multires=10,
        **kwargs,
    ):
        """Build the positional embedder and the linear layer stack.

        Args:
            out_size: Width of the final (linear) output layer.
            hidden_neurons: Widths of the hidden layers, in order. Any
                iterable of ints is accepted; it must be non-empty.
            use_leaky_relu: Use ``F.leaky_relu`` instead of ``F.relu``
                between layers.
            use_bias: Whether every ``nn.Linear`` layer has a bias term.
            multires: Passed to ``Embedder`` as ``num_freqs``; the embedder's
                ``max_freq_log2`` is ``multires - 1`` with log sampling.
            **kwargs: Accepted and ignored (presumably so a shared config
                dict can be splatted in — verify against callers).

        Raises:
            ValueError: If ``hidden_neurons`` is empty.
        """
        super().__init__()
        # Copy into a list: avoids a shared mutable default and lets callers
        # pass any iterable (tuple, list, generator).
        hidden_neurons = list(hidden_neurons)
        if not hidden_neurons:
            raise ValueError("hidden_neurons must contain at least one layer width")

        self.embedder = Embedder(
            include_input=True,
            input_dims=3,
            max_freq_log2=multires - 1,
            num_freqs=multires,
            log_sampling=True,
            periodic_fns=[torch.sin, torch.cos],
        )
        self.use_leaky_relu = use_leaky_relu

        # Consecutive widths: embedded input -> hidden layers -> output.
        # Layers are created in the same order as before, so parameter
        # initialisation under a fixed seed and state_dict keys are unchanged.
        widths = [self.embedder.out_dim, *hidden_neurons, out_size]
        self.layers = nn.ModuleList(
            [
                nn.Linear(width_in, width_out, bias=use_bias)
                for width_in, width_out in zip(widths[:-1], widths[1:])
            ]
        )

    def forward(self, model_input):
        """Evaluate the network on a batch of coordinates.

        Args:
            model_input: Either a coordinate tensor or a dict holding it
                under ``"coords"``. The last dim is assumed to be 3 to match
                the embedder's ``input_dims`` — TODO confirm against callers.

        Returns:
            ``{"model_out": out}`` when given a dict, otherwise ``out``.
        """
        is_dict = isinstance(model_input, dict)
        coords = model_input["coords"] if is_dict else model_input
        # Fresh leaf that requires grad, so the output can be differentiated
        # w.r.t. the input coordinates without touching the caller's tensor.
        coords_org = coords.clone().detach().requires_grad_(True)

        activation = F.leaky_relu if self.use_leaky_relu else F.relu
        x = self.embedder.embed(coords_org)
        for layer in self.layers[:-1]:
            x = activation(layer(x))
        x = self.layers[-1](x)  # final layer stays linear (no activation)

        return {"model_out": x} if is_dict else x
