import torch.nn as nn
from rff.layers import GaussianEncoding


class RFFNet(nn.Module):
    """MLP preceded by a random-Fourier-feature (Gaussian) input encoding.

    Module order inside ``self.net``::

        GaussianEncoding -> Linear -> ReLU
        -> [Linear -> ReLU] * hidden_layers
        -> Linear [-> ReLU when ``outermost_linear`` is False]

    Args:
        in_features: Input size given to ``GaussianEncoding``. Presumably this
            is the last dimension of ``coords`` (TODO confirm against rff docs).
        out_features: Output size of the final Linear layer.
        hidden_features: Width of the hidden Linear layers. The encoding is
            asked for ``hidden_features // 2`` frequencies and returns twice
            that many features (its cos and sin parts). For odd values this
            is ``hidden_features - 1``, and the first Linear layer is sized to
            match.
        hidden_layers: Number of extra ``Linear -> ReLU`` pairs inserted after
            the first hidden layer.
        outermost_linear: If True (default), the last layer is a plain Linear.
            If False, a ReLU is applied to its output.
        sigma: Passed through as ``GaussianEncoding``'s ``sigma``.

    Raises:
        ValueError: If ``hidden_features < 2``, which would leave the Fourier
            encoding with no features.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden_features: int,
        hidden_layers: int,
        outermost_linear: bool = True,
        sigma: float = 1.0,
    ):
        super().__init__()

        if hidden_features < 2:
            raise ValueError(
                f"hidden_features must be >= 2, got {hidden_features}"
            )

        num_frequencies = hidden_features // 2
        # The encoding emits one cos and one sin feature per frequency, so its
        # width is even. Sizing the first Linear from it (rather than from
        # hidden_features) keeps odd widths from causing a shape mismatch.
        encoded_features = 2 * num_frequencies

        # Construction order (encoding first, then Linears) is kept identical
        # to the original so that seeded RNG draws and state_dict indices
        # ("net.0", "net.1", ...) are unchanged.
        layers = [
            GaussianEncoding(
                input_size=in_features,
                encoded_size=num_frequencies,
                sigma=sigma,
            ),
            nn.Linear(encoded_features, hidden_features),
            nn.ReLU(),
        ]

        for _ in range(hidden_layers):
            layers.append(nn.Linear(hidden_features, hidden_features))
            layers.append(nn.ReLU())

        layers.append(nn.Linear(hidden_features, out_features))
        if not outermost_linear:
            # Honour the flag: a non-linear output layer gets the same
            # activation used by the hidden layers.
            layers.append(nn.ReLU())

        self.net = nn.Sequential(*layers)

    def forward(self, coords):
        """Encode ``coords`` with random Fourier features and run the MLP.

        Args:
            coords: Tensor passed straight to ``self.net``. It presumably has
                ``in_features`` as its last dimension (TODO confirm).

        Returns:
            The output of the final layer of ``self.net``.
        """
        return self.net(coords)
