import torch
import torch.nn as nn


class MLP(nn.Module):
    """Fully connected network: a stack of hidden Linear layers plus a regression head.

    The hidden layers map ``architecture[0] -> ... -> architecture[-2]``; each one is
    followed by ``activation``. The head is ``Linear(architecture[-2], architecture[-1])``
    followed by ``final_activation``.

    Args:
        architecture: Layer widths, input width first and output width last.
            Defaults to ``[1, 256, 128, 2]`` when ``None``. A copy is stored in
            ``self.architecture``.
        activation: Module applied after every hidden layer. A fresh
            ``nn.LeakyReLU()`` is created when ``None``.
        final_activation: Module applied after the output layer. A fresh
            ``nn.Tanh()`` is created when ``None``.
        bias: Passed to every ``nn.Linear``.
        residual: If true, every hidden layer after the first computes
            ``activation(layer(x) + x)``. Those layers must therefore keep the width.

    Raises:
        ValueError: If ``architecture`` has fewer than two entries, or if
            ``residual`` is set and a non-first hidden layer changes width.
    """

    def __init__(self, architecture=None, activation=None, final_activation=None,
                 bias=True, residual=False):
        super(MLP, self).__init__()
        # None sentinels: defaults written as `[...]` / `nn.LeakyReLU()` in the
        # signature would be created once and shared by every instance.
        if architecture is None:
            architecture = [1, 256, 128, 2]
        if activation is None:
            activation = nn.LeakyReLU()
        if final_activation is None:
            final_activation = nn.Tanh()
        # Defensive copy so later mutation of the caller's list cannot desync
        # the input-width check in forward().
        architecture = list(architecture)
        if len(architecture) < 2:
            raise ValueError(
                f"architecture needs at least an input and an output width, got {architecture}")

        self.activation = activation
        self.architecture = architecture
        self.final_activation = final_activation
        self.residual = residual

        hidden = nn.ModuleList()
        for i in range(1, len(architecture) - 1):
            in_features, out_features = architecture[i - 1], architecture[i]
            # Every hidden layer except the first gets a skip connection in
            # forward(), so its input and output widths must match.
            if residual and i != 1 and in_features != out_features:
                raise ValueError(
                    f"residual hidden layer {i - 1} must preserve width, "
                    f"got {in_features} -> {out_features}")
            hidden.append(nn.Linear(in_features, out_features, bias=bias))
        self.basis = hidden
        self.regressor = nn.Sequential(nn.Linear(architecture[-2], architecture[-1], bias=bias),
                                       final_activation)

    def forward(self, x):
        """Run ``x`` through the hidden layers and the regression head.

        ``x`` may have any number of leading dimensions; its last dimension must
        equal ``architecture[0]`` (``nn.Linear`` acts on the last dimension).

        Raises:
            ValueError: If ``x`` is 0-d or its last dimension has the wrong size.
        """
        if x.dim() == 0 or x.shape[-1] != self.architecture[0]:
            raise ValueError(
                f"expected input whose last dimension is {self.architecture[0]}, "
                f"got shape {tuple(x.shape)}")
        for i, layer in enumerate(self.basis):
            if self.residual and i > 0:
                x = self.activation(layer(x) + x)
            else:
                x = self.activation(layer(x))
        return self.regressor(x)

    def get_layers(self):
        """Return every ``nn.Linear`` of the network in order, output layer last."""
        layers = [layer for layer in self.basis if isinstance(layer, nn.Linear)]
        layers.append(self.regressor[0])
        return layers
