import torch

from .base import Model


class FeedForwardNet(Model):
    """Fully connected ReLU network that emits one scalar per input row.

    The layers are not created in ``__init__``; they are built in ``setup``
    from ``self.feat_dim``.
    NOTE(review): ``feat_dim`` is not defined in this class; presumably the
    ``Model`` base class provides it -- confirm.
    """

    name = "ff_net"

    def __init__(self,
                 dim=None,
                 weights=None,
                 **kwargs):
        """Store the architecture settings.

        Args:
            dim: Sizes of the hidden layers, in order. ``None`` means no
                hidden layers (a purely linear model).
            weights: Optional initial parameters forwarded to ``build_net``
                (bias first, then one coefficient per input feature).
            **kwargs: Passed through unchanged to the base class.
        """
        super().__init__(**kwargs)

        if dim is None:
            dim = []
        self.dim = dim
        self.weights = weights

        # Created in ``setup``.
        self.net = None

    @staticmethod
    def build_net(input_dim, hidden_dims, output_dim=1, weights=None):
        """Build the layer stack as a ``torch.nn.ModuleList``.

        Each hidden size contributes a ``Linear`` followed by a ``ReLU``; a
        final ``Linear`` maps to ``output_dim``.

        Args:
            input_dim: Number of input features.
            hidden_dims: Sizes of the hidden layers, in order.
            output_dim: Number of outputs of the last layer.
            weights: Optional initial parameters for a linear model:
                ``weights[0]`` is the bias and ``weights[1:]`` is written
                into the first row of the weight matrix. May be a tensor or
                anything ``torch.as_tensor`` accepts.

        Returns:
            The ``torch.nn.ModuleList`` holding the layers.

        Raises:
            NotImplementedError: If ``weights`` is given together with
                hidden layers.
        """
        layers = []
        in_features = input_dim
        for hidden_dim in hidden_dims:
            layers.append(torch.nn.Linear(in_features, hidden_dim))
            layers.append(torch.nn.ReLU())
            in_features = hidden_dim
        layers.append(torch.nn.Linear(in_features, output_dim))
        net = torch.nn.ModuleList(layers)

        if weights is not None:
            if len(hidden_dims) != 0:
                raise NotImplementedError(
                    "Initial weights are only supported without hidden layers."
                )

            # Only works for linear models (and is generally unadvisable)
            linear = net[0]
            values = torch.as_tensor(
                weights, dtype=linear.weight.dtype, device=linear.weight.device
            )
            # Copy the values into the existing parameters instead of
            # rebinding ``.data``: rebinding would make the bias a 0-dim view
            # of the caller's tensor (wrong shape, foreign dtype/device, and
            # later in-place updates would write into ``weights``).
            with torch.no_grad():
                linear.bias.copy_(values[0])
                linear.weight[0, :].copy_(values[1:])

        return net

    def setup(self, stage):
        """Build the network when ``stage`` is ``"fit"``; otherwise do nothing."""
        if stage != "fit":
            return

        self.net = self.build_net(self.feat_dim, self.dim, output_dim=1, weights=self.weights)

    def forward(self, feat, _sens):
        """Apply the layers in order and return index 0 of the last dimension.

        Args:
            feat: Input passed to the first layer.
            _sens: Unused.

        Raises:
            RuntimeError: If the network has not been built yet.
        """
        if self.net is None:
            raise RuntimeError(
                "The network has not been built; call setup('fit') before forward()."
            )
        for layer in self.net:
            feat = layer(feat)
        return feat[..., 0]
