from torch import nn
import torch


class MLP(nn.Module):
    """Fully connected network mapping ``(x, t)`` to ``input_dim`` output features.

    The time input ``t`` is concatenated to ``x`` along dim 1, so the first
    linear layer takes ``input_dim + 1`` features. Each hidden layer is a
    ``Linear`` followed by ``Tanh`` (default) or ``ReLU`` (``relu=True``);
    the final layer is a plain ``Linear`` back to ``input_dim`` features.

    Args:
        input_dim: Number of features in ``x`` and in the output.
        structure_str: Hidden layer widths joined by ``'-'``, e.g.
            ``'100-100-100'``. An empty string yields no hidden layers.
        relu: Use ``ReLU`` activations instead of ``Tanh``.
    """

    def __init__(self, input_dim=2,
                 structure_str='100-100-100',
                 relu=False,
                 ):
        super().__init__()
        self.relu = relu
        activation = nn.ReLU if relu else nn.Tanh
        # Skip empty tokens so '' means "no hidden layers" instead of int('') raising.
        hidden_dims = [int(s) for s in structure_str.split('-') if s]

        current_dim = input_dim + 1  # +1 for the single time feature appended in forward
        # Deliberately a plain list (not nn.ModuleList): the parameters are
        # registered only once, through ``self.model`` below, which keeps the
        # state_dict keys (``model.<i>...``) stable for existing checkpoints.
        self.layers = []
        for hidden_dim in hidden_dims:
            self.layers.append(nn.Sequential(
                nn.Linear(current_dim, hidden_dim, bias=True),
                activation(),
            ))
            current_dim = hidden_dim
        self.layers.append(nn.Linear(current_dim, input_dim))
        self.model = nn.Sequential(*self.layers)

    def _initialize_weights(self):
        """Kaiming-normal init for every ``nn.Linear`` weight; zero the biases.

        Not called by ``__init__``; callers must invoke it explicitly.
        NOTE(review): always uses the ReLU gain, even when the network was
        built with ``Tanh`` activations (``relu=False``) — confirm intended.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    @staticmethod
    def _as_time_column(t, x_input):
        """Return ``t`` as a ``(batch, 1)`` tensor matching ``x_input``'s batch size."""
        if not torch.is_tensor(t):
            t = torch.as_tensor(t, dtype=x_input.dtype, device=x_input.device)
        if t.dim() < 2:
            # () or (batch,) -> (n, 1)
            t = t.reshape(-1, 1)
        if t.shape[0] == 1 and x_input.shape[0] != 1:
            # A single time value shared by the whole batch.
            t = t.expand(x_input.shape[0], -1)
        return t

    def forward(self, x_input, t):
        """Evaluate the network on ``x_input`` at time(s) ``t``.

        Args:
            x_input: Tensor of shape ``(batch, input_dim)``.
            t: Time as a tensor of shape ``(batch, 1)``, ``(batch,)``,
                ``(1, 1)``, ``(1,)`` or ``()``, or a Python number. Lower-rank
                or single-row values are reshaped/broadcast to ``(batch, 1)``.

        Returns:
            Tensor of shape ``(batch, input_dim)``.
        """
        t = self._as_time_column(t, x_input)
        inputs = torch.cat([x_input, t], dim=1)
        # self.model wraps exactly the modules in self.layers, in the same order.
        return self.model(inputs)


class ModelWrapper(nn.Module):
    """Thin ``nn.Module`` wrapper whose ``forward`` delegates to ``model``.

    The wrapped module is registered as the ``model`` submodule, so its
    parameters appear under the ``model.`` prefix in the state_dict.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, *args, **kwargs):
        """Call the wrapped model with all positional and keyword arguments unchanged."""
        return self.model(*args, **kwargs)