import torch
import torch.nn as nn

class MLP(nn.Module):
    """Fully connected network: three hidden ReLU layers and a linear output layer.

    With the default sizes it maps 28 * 28 = 784 input features to 10 outputs.
    The output layer has no activation, so ``forward`` returns unnormalized
    scores.

    Args:
        in_features: number of features per sample after flattening.
        hidden_features: width of each of the three hidden layers.
        num_classes: number of output features.
    """

    def __init__(
        self,
        in_features: int = 28 * 28,
        hidden_features: int = 512,
        num_classes: int = 10,
    ):
        super().__init__()

        # Keep the Sequential layout stable: its indices (layers.0/2/4/6)
        # form the state_dict keys.
        self.layers = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(inplace=False),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(inplace=False),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(inplace=False),
            nn.Linear(hidden_features, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten all dims after the batch dim and apply the layers.

        Args:
            x: tensor whose first dim is the batch; the remaining dims must
                multiply to ``in_features``.

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        # Unlike ``x.view(x.size(0), -1)``, torch.flatten also handles
        # non-contiguous inputs and an empty batch.
        return self.layers(torch.flatten(x, start_dim=1))


class MLPHomo(nn.Module):
    """Bias-free fully connected network with ReLU activations.

    Every ``nn.Linear`` has ``bias=False`` and ReLU is the only nonlinearity,
    so the network is positively homogeneous of degree 1:
    ``f(alpha * x) == alpha * f(x)`` for any ``alpha >= 0``. With the default
    sizes it maps 28 * 28 = 784 input features to 10 outputs.

    Args:
        in_features: number of features per sample after flattening.
        hidden_features: width of each of the three hidden layers.
        num_classes: number of output features.
    """

    def __init__(
        self,
        in_features: int = 28 * 28,
        hidden_features: int = 512,
        num_classes: int = 10,
    ):
        super().__init__()

        # Keep the Sequential layout stable: its indices (layers.0/2/4/6)
        # form the state_dict keys. No biases -- they would break homogeneity.
        self.layers = nn.Sequential(
            nn.Linear(in_features, hidden_features, bias=False),
            nn.ReLU(inplace=False),
            nn.Linear(hidden_features, hidden_features, bias=False),
            nn.ReLU(inplace=False),
            nn.Linear(hidden_features, hidden_features, bias=False),
            nn.ReLU(inplace=False),
            nn.Linear(hidden_features, num_classes, bias=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten all dims after the batch dim and apply the layers.

        Args:
            x: tensor whose first dim is the batch; the remaining dims must
                multiply to ``in_features``.

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        # Unlike ``x.view(x.size(0), -1)``, torch.flatten also handles
        # non-contiguous inputs and an empty batch.
        return self.layers(torch.flatten(x, start_dim=1))
  
    
def test():
    """Numerically check that ``MLPHomo`` is positively homogeneous.

    A bias-free network whose only nonlinearity is ReLU satisfies
    ``f(alpha * x) == alpha * f(x)`` for ``alpha > 0``: ReLU commutes with
    positive scaling, and so does every bias-free linear layer. This prints
    the model, then prints ``True`` if the identity holds (to ``allclose``
    tolerance) for one random input and ``alpha = 10``.
    """
    # Use float64 so rounding error does not mask the comparison.
    net = MLPHomo().double()
    print(net)

    # Re-initialize every weight with random values. Do it in-place under
    # no_grad rather than through the discouraged ``.data`` attribute.
    with torch.no_grad():
        for param in net.parameters():
            param.copy_(torch.randn_like(param) * torch.rand_like(param))

    # Random input: a batch of one 1x28x28 sample.
    X = torch.rand(1, 1, 28, 28).double()

    net.eval()
    alpha = 10
    # Gradients are not needed for this check, so skip building a graph.
    with torch.no_grad():
        Y = net(X)
        Y_ = net(X * alpha)
    print(torch.allclose(Y * alpha, Y_))


# Script entry point: run the homogeneity check only when executed directly.
if __name__ == "__main__":
    test()