import torch
import torch.nn as nn


class Estimator(nn.Module):
    """Fully-connected regressor: stem -> three residual Blocks -> 3-wide head.

    Input width is ``num_blocks * num_ops + num_pex + num_pey + num_rf + num_df``.
    NOTE(review): the parameter names suggest a per-block op encoding
    concatenated with hardware-parameter encodings, and the 3 outputs are
    presumably predicted cost metrics -- confirm against the callers.

    Layout of ``self.layers`` (indices determine ``state_dict`` keys):
        0: Sequential(Linear(in_features, 256), ReLU())
        1-3: Block(256)
        4: Linear(256, 3)
    """

    def __init__(self, num_blocks=9, num_ops=7, num_pex=3, num_pey=5, num_rf=5, num_df=3):
        super().__init__()
        hw_width = num_pex + num_pey + num_rf + num_df
        in_features = num_blocks * num_ops + hw_width
        hidden = 256

        # Modules are created in the same order as before (stem, blocks, head)
        # so that random parameter initialization is reproduced exactly.
        stem = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
        )
        trunk = [Block(hidden) for _ in range(3)]
        head = nn.Linear(hidden, 3)

        self.layers = nn.ModuleList([stem, *trunk, head])

    def forward(self, x):
        """Feed ``x`` through every entry of ``self.layers`` in order."""
        out = x
        for module in self.layers:
            out = module(out)
        return out


class Block(nn.Module):
    """Residual fully-connected block computing ``ReLU(Linear(x)) + x``.

    Args:
        num_features: width of both the input and the output; input and
            output widths must match for the skip connection to be added.
    """

    def __init__(self, num_features):
        super(Block, self).__init__()
        self.layer = nn.Sequential(
            nn.Linear(num_features, num_features),
            nn.ReLU(),
        )

    def forward(self, x):
        """Return ``self.layer(x) + x``.

        The addition is deliberately out-of-place. ``nn.ReLU`` saves its
        output for the backward pass, so modifying that output in place
        (``out += x``) would make autograd raise on ``backward()``.
        """
        # Previously this returned ``x``, so the block was an identity map
        # and its Linear weights never received gradients.
        return self.layer(x) + x

