import torch
import torch.nn as nn


class DNN(nn.Module):
    """Fully connected network: Linear -> ReLU -> Linear -> ReLU -> Linear.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        hidden_dim1: Width of the first hidden layer.
        hidden_dim2: Width of the second hidden layer.
        output_dim: Size of the last dimension of the output tensor.
    """

    def __init__(self, input_dim=5, hidden_dim1=10, hidden_dim2=10, output_dim=1):
        super().__init__()
        # Submodules are registered in this exact order. That keeps the
        # state_dict keys (layer1..layer3) and the RNG consumption of the
        # default weight initialisation unchanged.
        self.layer1 = nn.Linear(input_dim, hidden_dim1)
        self.layer2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.layer3 = nn.Linear(hidden_dim2, output_dim)

        # One stateless ReLU instance is shared by both hidden layers.
        self.activation = nn.ReLU()

    def forward(self, x):
        """Run *x* through both hidden layers; the final layer is left linear."""
        hidden = self.activation(self.layer1(x))
        hidden = self.activation(self.layer2(hidden))
        return self.layer3(hidden)


# Example usage
if __name__ == "__main__":
    demo_net = DNN(input_dim=5, hidden_dim1=10, hidden_dim2=10, output_dim=1)
    print(demo_net)

    # A random batch of 3 samples, each with 5 features.
    batch = torch.randn(3, 5)
    prediction = demo_net(batch)

    # Labels mean "input shape", "output shape" and "sample output".
    for label, value in (
        ("\n输入形状:", batch.shape),
        ("输出形状:", prediction.shape),
        ("示例输出:", prediction),
    ):
        print(label, value)
