import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module):
    def __init__(self, num_classes):
        super(LeNet, self).__init__()
        # 1）前两层保持原样：conv1→pool→conv2
        self.conv1 = nn.Conv2d(3, 6, kernel_size=5, padding=2)
        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)    # 第一次池化
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, padding=0)
        # 2）第二次池化改为自适应固定输出 6×6
        #    不管输入是 224×224、128×128、还是 64×64，
        #    在 conv2 后, adaptive_pool2 会把空间维拉成 6×6
        self.adaptive_pool2 = nn.AdaptiveAvgPool2d((6, 6))
        # 3）此时全连接层输入特征数就是 16*6*6=576
        self.fc1 = nn.Linear(16 * 6 * 6, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        # x: [batch, 3, H, W]，H, W 任意只要 >= (kernel+stride) 都行
        x = F.relu(self.conv1(x))      # → [batch, 6, H,   W]
        x = self.pool1(x)              # → [batch, 6, H/2, W/2]
        x = F.relu(self.conv2(x))      # → [batch,16, H/2-4, W/2-4]（空间大小随输入变化）
        # 自适应池化到 [batch, 16, 6,  6]
        x = self.adaptive_pool2(x)
        x = x.view(x.size(0), -1)      # → [batch, 16*6*6=576]
        x = F.relu(self.fc1(x))        # → [batch, 120]
        x = F.relu(self.fc2(x))        # → [batch,  84]
        x = self.fc3(x)                # → [batch, num_class]
        return x

#