import torch
from torch import nn


class NeuralNetwork(nn.Module):
    """Fully connected network: three linear layers with ReLU in between.

    The layers are stored in ``self.net`` as an ``nn.Sequential`` laid out as
    ``Linear(in_size, hidden_size) -> ReLU -> Linear(hidden_size, hidden_size)
    -> ReLU -> Linear(hidden_size, out_size)``.

    Args:
        in_size: Number of input features of the first linear layer.
        hidden_size: Width of both hidden layers.
        out_size: Number of output features of the last linear layer.
        dtype: Optional dtype passed to every ``nn.Linear``. When ``None``,
            no dtype argument is passed and the layers use PyTorch's default.
    """

    def __init__(self, in_size=50, hidden_size=1000, out_size=1, dtype=None):
        super().__init__()

        # Forward ``dtype`` only when it was given, so the ``dtype=None`` path
        # calls ``nn.Linear`` exactly as it would without the argument.
        linear_kwargs = {} if dtype is None else {"dtype": dtype}

        self.net = nn.Sequential(
            nn.Linear(in_size, hidden_size, **linear_kwargs),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size, **linear_kwargs),
            nn.ReLU(),
            nn.Linear(hidden_size, out_size, **linear_kwargs),
        )

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the result."""
        return self.net(x)


class ParallelNetworks(nn.Module):
    """Container that runs several independently built networks side by side.

    ``num_models`` instances of ``model_class`` are created, each with the
    same ``model_class_init_args``, and stored in ``self.nets``.

    Args:
        num_models: How many sub-networks to create.
        model_class: Callable returning an ``nn.Module``; called once per
            sub-network.
        **model_class_init_args: Keyword arguments forwarded to every
            ``model_class`` call.
    """

    def __init__(self, num_models, model_class, **model_class_init_args):
        super().__init__()
        self.nets = nn.ModuleList(
            [model_class(**model_class_init_args) for _ in range(num_models)]
        )

    def forward(self, xs):
        """Feed ``xs[i]`` to the ``i``-th network and stack the outputs.

        Args:
            xs: Input whose first dimension has one entry per sub-network.

        Returns:
            A tensor with a new leading dimension of size ``len(self.nets)``;
            entry ``i`` is the output of ``self.nets[i](xs[i])``. The result
            keeps the dtype and device of the sub-network outputs.

        Raises:
            AssertionError: If ``xs.shape[0]`` differs from the number of
                sub-networks.
            ValueError: If the container holds no sub-networks.
        """
        assert xs.shape[0] == len(self.nets), (
            "first dimension of xs must equal the number of sub-networks"
        )
        if len(self.nets) == 0:
            # Without any sub-network there is no output shape to stack.
            raise ValueError("ParallelNetworks has no sub-networks to evaluate")

        # torch.stack keeps the outputs' dtype; a default-dtype zeros buffer
        # would silently cast e.g. float64 results down to float32.
        return torch.stack([net(xs[i]) for i, net in enumerate(self.nets)])
