# Copyright (C) king.com Ltd 2025
# License: Apache 2.0
import torch


class MLP(torch.nn.Module):
    """Fully connected feed-forward network built on ``torch.nn.Sequential``.

    Child modules of ``self.network``, in order::

        input                 Linear(input_dim, hidden_dim)
        input_activation      activation()
        hidden_i              Linear(hidden_dim, hidden_dim)  for i in range(num_hidden_layers)
        hidden_i_activation   activation()
        output                Linear(hidden_dim, output_dim)
        output_squash         Tanh()                          only when squashing is enabled

    ``num_hidden_layers`` counts the hidden->hidden layers *in addition to* the
    input layer, so the network contains ``num_hidden_layers + 2`` Linear
    layers in total (for non-negative ``num_hidden_layers``).

    Args:
        input_dim: Size of the last dimension of the input tensor.
        output_dim: Size of the last dimension of the output tensor.
        hidden_dim: Width of every hidden layer.
        num_hidden_layers: Number of extra ``hidden_dim -> hidden_dim`` layers.
        activation: Zero-argument callable (e.g. an ``nn.Module`` class) that
            returns a fresh activation module. It is called once per
            activation slot, so layers never share an activation instance.
        sqaush_output: Legacy, misspelled name for ``squash_output``. It is
            kept so existing callers and attribute readers keep working.
        squash_output: Keyword-only. If truthy, a ``Tanh`` is appended so
            outputs are bounded to [-1, 1]. When not ``None`` it takes
            precedence over ``sqaush_output``.
    """

    def __init__(self,
                 input_dim,
                 output_dim,
                 hidden_dim=64,
                 num_hidden_layers=2,
                 activation=torch.nn.ReLU,
                 sqaush_output=False,
                 *,
                 squash_output=None
                 ):
        super().__init__()
        # Prefer the correctly spelled keyword; fall back to the legacy one.
        squash = sqaush_output if squash_output is None else squash_output

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_hidden_layers = num_hidden_layers
        self.activation = activation
        self.squash_output = squash
        # Legacy attribute name, kept for code that reads the old spelling.
        self.sqaush_output = squash

        # Child names below define the state_dict keys; do not rename them,
        # or previously saved checkpoints will no longer load.
        self.network = torch.nn.Sequential()
        self.network.add_module('input', torch.nn.Linear(input_dim, hidden_dim))
        self.network.add_module('input_activation', activation())
        for i in range(num_hidden_layers):
            self.network.add_module(f'hidden_{i}', torch.nn.Linear(hidden_dim, hidden_dim))
            self.network.add_module(f'hidden_{i}_activation', activation())

        self.network.add_module('output', torch.nn.Linear(hidden_dim, output_dim))
        if squash:
            self.network.add_module('output_squash', torch.nn.Tanh())

    def forward(self, x):
        """Apply the network to ``x``.

        The Linear layers act on the last dimension, so ``x`` must have
        ``input_dim`` features there. The result has ``output_dim`` features
        in its last dimension.
        """
        return self.network(x)