import torch
import torch.nn as nn

class DNN(nn.Module):
    """Fully connected feed-forward network (MLP) with ReLU activations.

    The network is ``Linear -> ReLU`` for every entry of ``hidden_dims``,
    followed by a final ``Linear`` that maps to ``output_dim`` (no activation
    after the last layer). The layers are stored in ``self.model`` as an
    ``nn.Sequential``.

    Args:
        input_dim: Number of input features per sample after flattening.
            If ``None``, the first layer is an ``nn.LazyLinear`` whose input
            size is inferred on the first forward pass.
        hidden_dims: Sequence of hidden layer widths. May be empty, in which
            case the network is a single linear layer.
        output_dim: Number of output features.
    """

    def __init__(self, input_dim, hidden_dims, output_dim):
        super().__init__()
        layers = []
        prev_dim = input_dim
        for h_dim in hidden_dims:
            layers.append(self._make_linear(prev_dim, h_dim))
            layers.append(nn.ReLU(inplace=False))
            prev_dim = h_dim
        # With an empty ``hidden_dims`` this is the only layer, so it must
        # still honour ``input_dim is None``.
        layers.append(self._make_linear(prev_dim, output_dim))
        self.model = nn.Sequential(*layers)

    @staticmethod
    def _make_linear(in_dim, out_dim):
        """Return ``nn.LazyLinear`` when ``in_dim`` is None, else ``nn.Linear``."""
        if in_dim is None:
            return nn.LazyLinear(out_dim)
        return nn.Linear(in_dim, out_dim)

    def forward(self, x):
        """Flatten all non-batch dimensions of ``x`` and apply the network.

        Args:
            x: Tensor whose first dimension is treated as the batch dimension.

        Returns:
            The output of ``self.model`` applied to ``x`` flattened to two
            dimensions ``(x.size(0), -1)``.
        """
        # ``reshape`` instead of ``view``: ``view`` raises on non-contiguous
        # inputs (e.g. after ``transpose``/``permute``), while ``reshape``
        # copies only when a view is not possible.
        x = x.reshape(x.size(0), -1)
        return self.model(x)

def DNN_small(input_dim, num_classes):
    """Build a ``DNN`` with two hidden layers of widths 512 and 256.

    Args:
        input_dim: Forwarded to ``DNN`` as ``input_dim``.
        num_classes: Forwarded to ``DNN`` as ``output_dim``.

    Returns:
        The constructed ``DNN`` instance.
    """
    hidden_widths = [512, 256]
    network = DNN(
        input_dim=input_dim,
        hidden_dims=hidden_widths,
        output_dim=num_classes,
    )
    return network