import torch
import torch.nn as nn


class CNN(nn.Module):
    def __init__(self, num_classes=10):
        super(CNN, self).__init__()
        # 第一个卷积组：Conv -> ReLU -> MaxPool
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # 第二个卷积组：Conv -> ReLU -> MaxPool
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # 全连接层：需要根据卷积结果来计算特征尺寸
        # CIFAR图像尺寸 32×32，经两次2×2池化后尺寸变为 8×8
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        # 输出层：类别数（比如 CIFAR-10 为10）
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        # 第一卷积层 + 池化
        x = self.pool1(torch.relu(self.conv1(x)))
        # 第二卷积层 + 池化
        x = self.pool2(torch.relu(self.conv2(x)))

        # 展平
        x = x.view(x.size(0), -1)

        # 全连接层
        x = torch.relu(self.fc1(x))

        # 最后一层输出 raw logits
        x = self.fc2(x)
        # 如果用 nn.CrossEntropyLoss，则不需要额外调用 SoftMax，
        # 因为 CrossEntropyLoss 内部会处理 logits 的 SoftMax
        return x