
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    """LeNet-style CNN for single-channel, square images.

    Two conv(5x5) + ReLU + max-pool(2) stages feed a three-layer MLP that
    emits 10 logits. ``forward`` also returns a detached copy of the feature
    map produced by the second conv stage.

    Args:
        input_size: Height/width of the square input images. The default of
            28 gives ``fc1`` an input dim of ``16 * 4 * 4`` (the original
            hard-coded value). Pass 32 for the classic LeNet-5 geometry
            (``16 * 5 * 5``).

    Raises:
        ValueError: If ``input_size`` is too small to survive both
            conv + pool stages (must be >= 16).
    """

    def __init__(self, input_size=28):
        super().__init__()
        # Each stage is conv(kernel=5, no padding) followed by max_pool(2):
        # spatial size n -> n - 4 -> (n - 4) // 2.
        after_stage1 = (input_size - 4) // 2
        after_stage2 = (after_stage1 - 4) // 2
        if after_stage2 < 1:
            raise ValueError(
                f"input_size={input_size} is too small for LeNet; need >= 16"
            )
        # 1 -> 6 channels, 5x5 kernel (e.g. 28x28 -> 24x24 for the default).
        self.conv1 = nn.Conv2d(1, 6, 5)
        # 6 -> 16 channels, 5x5 kernel (e.g. 12x12 -> 8x8 for the default).
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Flattened conv features -> 120 (16*4*4 = 256 for 28x28 input).
        self.fc1 = nn.Linear(16 * after_stage2 * after_stage2, 120)
        # 120 -> 84
        self.fc2 = nn.Linear(120, 84)
        # 84 -> 10 output logits
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Input batch, presumably shaped (N, 1, input_size, input_size).
                ``conv1`` requires exactly one input channel, and ``fc1`` is
                sized for ``input_size`` square images.

        Returns:
            A tuple ``(logits, embedding)``:
            - ``logits``: tensor of shape (N, 10).
            - ``embedding``: detached copy of the second conv stage output,
              shape (N, 16, s, s); (N, 16, 4, 4) for the default 28x28 input.
        """
        # conv1 -> ReLU -> 2x2 max-pool (default input: 28 -> 24 -> 12).
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        # conv2 -> ReLU -> 2x2 max-pool (default input: 12 -> 8 -> 4).
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Detach before cloning so the copy is not recorded by autograd.
        # The clone keeps the embedding independent of x's storage.
        embedding = x.detach().clone()

        # Flatten all dimensions except the leading batch dimension.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x, embedding

