import torch.nn as nn
from models.base_model import SequentialModel


class ReLU_Net(SequentialModel):
    """Fully connected network of ``nn.Linear`` layers, each followed by a ReLU.

    Feature sizes run ``dims[0] -> widths[0] -> ... -> widths[-1] -> dims[1]``.
    With the default arguments, ``self.layers`` is equivalent to::

        nn.Sequential(
            nn.Linear(40, 16, bias=False),
            nn.ReLU(),
            nn.Linear(16, 64, bias=False),
            nn.ReLU(),
            nn.Linear(64, 10, bias=False),
            nn.ReLU(),
        )
    """

    def __init__(
        self,
        widths=None,
        dims=None,
        bias=False,
        leaky_output=False,
    ):
        """Build the linear/ReLU layer stack.

        Args:
            widths: Hidden-layer sizes, in order. Defaults to ``[16, 64]``.
                Any iterable of ints is accepted (list, tuple, ...); an empty
                one yields a single Linear + activation.
            dims: ``[input_dim, output_dim]``. Defaults to ``[40, 10]``.
                Only ``dims[0]`` and ``dims[1]`` are read. Stored as ``self.dims``.
            bias: Passed to every ``nn.Linear`` as its ``bias`` flag.
            leaky_output: If True, the final activation is ``nn.LeakyReLU()``
                (default negative slope) instead of ``nn.ReLU()``.
        """
        # nn.Module must be initialised before submodules are assigned.
        super().__init__()

        # None sentinels instead of list defaults: a list default would be one
        # shared object, and since it is stored on self.dims, mutating it via
        # any instance would change the default for every later instance.
        if widths is None:
            widths = [16, 64]
        if dims is None:
            dims = [40, 10]

        self.dims = dims

        # Consecutive feature sizes: input, hidden..., output.
        all_dims = [dims[0], *widths, dims[1]]

        layers = []
        for in_features, out_features in zip(all_dims[:-1], all_dims[1:]):
            layers.append(nn.Linear(in_features, out_features, bias=bias))
            layers.append(nn.ReLU())

        if leaky_output:
            # Only the final activation is swapped; hidden ones stay ReLU.
            layers[-1] = nn.LeakyReLU()

        self.layers = nn.Sequential(*layers)
        # get_layer_input_shapes is not defined in this class; presumably
        # provided by SequentialModel and reads self.layers — TODO confirm.
        self.layer_input_shapes = self.get_layer_input_shapes()
