import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    """LeNet-style convolutional classifier with 10 output logits.

    The layer sizes are wired for single-channel 28x28 inputs: the first
    convolution keeps the spatial size (padding=2), each 2x2 max pooling
    halves it, and the unpadded second convolution shrinks it by 4, so the
    first linear layer receives 16 * 5 * 5 features per sample.
    """

    def __init__(self):
        super().__init__()
        # 1 input image channel -> 6 feature maps, 5x5 square kernel;
        # padding=2 preserves the spatial size.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6,
                               kernel_size=5, stride=1, padding=2, bias=True)
        self.relu1 = nn.ReLU(inplace=False)
        # 6 -> 16 feature maps, 5x5 square kernel, no padding.
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16,
                               kernel_size=5, stride=1, padding=0, bias=True)
        self.relu2 = nn.ReLU(inplace=False)
        # Affine operations: y = Wx + b.
        # 16 channels * 5 * 5 spatial positions after the second pooling.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.relu3 = nn.ReLU(inplace=False)
        self.fc2 = nn.Linear(120, 84)
        self.relu4 = nn.ReLU(inplace=False)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of images to class logits.

        Args:
            x: input batch; the layer sizes expect shape (N, 1, 28, 28).

        Returns:
            Tensor of shape (N, 10) holding the un-normalized class scores.
        """
        # Conv -> ReLU -> max pooling over a 2x2 window, twice.
        x = F.max_pool2d(self.relu1(self.conv1(x)), 2)
        x = F.max_pool2d(self.relu2(self.conv2(x)), 2)
        # Flatten everything except the batch dimension. Unlike ``view``,
        # ``flatten`` also accepts activations whose strides are not
        # view-compatible (it copies only when it has to).
        x = torch.flatten(x, 1)
        x = self.relu3(self.fc1(x))
        x = self.relu4(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x: torch.Tensor) -> int:
        """Return the number of elements per sample of ``x``.

        This is the product of all dimensions except the first (batch)
        dimension; it is 1 when ``x`` has no further dimensions.
        """
        return x.size()[1:].numel()
    

class LeNet_BN(nn.Module):
    """LeNet-style classifier with batch normalization before every ReLU.

    The layer sizes are wired for single-channel 28x28 inputs: the first
    convolution keeps the spatial size (padding=2), each 2x2 max pooling
    halves it, and the unpadded second convolution shrinks it by 4, so the
    first linear layer receives 16 * 5 * 5 features per sample.

    Note: in training mode the 1-D batch-norm layers need more than one
    sample per batch to compute batch statistics.
    """

    def __init__(self):
        super().__init__()
        # 1 input image channel -> 6 feature maps, 5x5 square kernel;
        # padding=2 preserves the spatial size.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6,
                               kernel_size=5, stride=1, padding=2, bias=True)
        self.bn1 = nn.BatchNorm2d(6)
        self.relu1 = nn.ReLU(inplace=False)
        # 6 -> 16 feature maps, 5x5 square kernel, no padding.
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16,
                               kernel_size=5, stride=1, padding=0, bias=True)
        self.bn2 = nn.BatchNorm2d(16)
        self.relu2 = nn.ReLU(inplace=False)
        # Affine operations: y = Wx + b.
        # 16 channels * 5 * 5 spatial positions after the second pooling.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.bn3 = nn.BatchNorm1d(120)
        self.relu3 = nn.ReLU(inplace=False)
        self.fc2 = nn.Linear(120, 84)
        self.bn4 = nn.BatchNorm1d(84)
        self.relu4 = nn.ReLU(inplace=False)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of images to class logits.

        Args:
            x: input batch; the layer sizes expect shape (N, 1, 28, 28).

        Returns:
            Tensor of shape (N, 10) holding the un-normalized class scores.
        """
        # Conv -> BN -> ReLU -> max pooling over a 2x2 window, twice.
        x = F.max_pool2d(self.relu1(self.bn1(self.conv1(x))), 2)
        x = F.max_pool2d(self.relu2(self.bn2(self.conv2(x))), 2)
        # Flatten everything except the batch dimension. Unlike ``view``,
        # ``flatten`` also accepts activations whose strides are not
        # view-compatible (it copies only when it has to).
        x = torch.flatten(x, 1)
        x = self.relu3(self.bn3(self.fc1(x)))
        x = self.relu4(self.bn4(self.fc2(x)))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x: torch.Tensor) -> int:
        """Return the number of elements per sample of ``x``.

        This is the product of all dimensions except the first (batch)
        dimension; it is 1 when ``x`` has no further dimensions.
        """
        return x.size()[1:].numel()
    

class LeNetHomo(LeNet):
    """Bias-free variant of LeNet.

    After the parent constructor has run, every layer is replaced by a
    freshly constructed one with the same configuration, except that the
    convolution and linear layers are built with ``bias=False``. The forward
    pass is inherited unchanged.
    """

    def __init__(self):
        super().__init__()
        # Convolutional stage: 5x5 kernels, 1 -> 6 -> 16 channels, no bias.
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2,
                               bias=False)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0,
                               bias=False)
        self.relu2 = nn.ReLU()
        # Fully connected stage: pure linear maps y = Wx (no bias term).
        # 400 = 16 channels * 5 * 5 spatial positions.
        self.fc1 = nn.Linear(in_features=400, out_features=120, bias=False)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(in_features=120, out_features=84, bias=False)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(in_features=84, out_features=10, bias=False)


def test():
    """Check empirically that LeNetHomo is positively homogeneous.

    Builds a double-precision LeNetHomo with random weights, feeds it a
    random 1x1x28x28 input ``X`` and the scaled input ``alpha * X``, and
    prints the model followed by whether ``alpha * net(X)`` matches
    ``net(alpha * X)`` according to ``torch.allclose``.
    """
    net = LeNetHomo().double()
    print(net)

    # Reinitialize the model with random weights. The in-place copy is done
    # under no_grad instead of through the discouraged ``.data`` attribute.
    with torch.no_grad():
        for param in net.parameters():
            param.copy_(torch.randn_like(param) * torch.rand_like(param))

    # Random input in double precision, matching the model's dtype.
    X = torch.rand(1, 1, 28, 28).double()

    net.eval()
    alpha = 10
    # Only the outputs are compared, so no autograd graph is needed.
    with torch.no_grad():
        Y = net(X)
        Y_ = net(X * alpha)
    print(torch.allclose(Y * alpha, Y_))


# Run the homogeneity check only when this file is executed as a script.
if __name__ == '__main__':
    test()