
import torch
import torch.nn as nn
class DenseBlock(nn.Module):
    """Five-layer densely connected convolutional block.

    The input is batch-normalised once and then fed through five 3x3
    convolutions (stride 1, padding 1, so height and width are preserved),
    each followed by ReLU. The first convolution sees the normalised input,
    the second sees the first one's output, and every later convolution sees
    the channel-wise concatenation of all previous convolution outputs.
    The block input itself is *not* part of any concatenation, so the output
    always has ``5 * growth_rate`` channels regardless of ``in_channels``.

    Args:
        in_channels: Number of channels of the input tensor.
        growth_rate: Number of channels produced by each convolution.
            Defaults to 32, which gives the fixed layout this block has
            always had (conv inputs 32/64/96/128, 160 output channels).
    """

    def __init__(self, in_channels, growth_rate=32):
        super(DenseBlock, self).__init__()
        self.relu = nn.ReLU(inplace=True)
        self.bn = nn.BatchNorm2d(num_features=in_channels)
        # conv_k (k >= 2) consumes the concatenation of the k-1 previous
        # outputs, i.e. (k - 1) * growth_rate input channels.
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=growth_rate,
                               kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=growth_rate, out_channels=growth_rate,
                               kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=2 * growth_rate, out_channels=growth_rate,
                               kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=3 * growth_rate, out_channels=growth_rate,
                               kernel_size=3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(in_channels=4 * growth_rate, out_channels=growth_rate,
                               kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Return the concatenation of all five conv outputs along dim 1."""
        bn = self.bn(x)
        conv1 = self.relu(self.conv1(bn))
        conv2 = self.relu(self.conv2(conv1))
        # Every conv output has already been through ReLU, so the
        # concatenations below are non-negative. A further ReLU on them
        # would be an identity op costing an extra pass over the tensor.
        c2_dense = torch.cat([conv1, conv2], 1)
        conv3 = self.relu(self.conv3(c2_dense))
        c3_dense = torch.cat([conv1, conv2, conv3], 1)
        conv4 = self.relu(self.conv4(c3_dense))
        c4_dense = torch.cat([conv1, conv2, conv3, conv4], 1)
        conv5 = self.relu(self.conv5(c4_dense))
        return torch.cat([conv1, conv2, conv3, conv4, conv5], 1)

# Define the Transition (down-sampling) layer
class Transition(nn.Module):
    """Down-sampling layer placed between two dense blocks.

    Pipeline: 1x1 convolution without bias (changes the channel count),
    then ReLU, then batch normalisation over ``out_channels``, and finally a
    2x2 average pool with stride 2, which halves height and width.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the 1x1 convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Registration order (relu, bn, conv, meanpool) is kept stable so
        # state_dict ordering and parameter initialisation are unchanged.
        self.relu = nn.ReLU(inplace=True)
        self.bn = nn.BatchNorm2d(num_features=out_channels)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.meanpool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)

    def forward(self, x):
        """Project channels, normalise, and halve the spatial resolution."""
        features = self.conv(x)
        features = self.relu(features)
        features = self.bn(features)
        return self.meanpool(features)

# Define the DenseNet model class
class DenseNet(nn.Module):
    """Small DenseNet-style image classifier.

    Architecture (shapes as HxWxC for a 32x32x3 input):
    a 7x7 stem convolution + ReLU, then three DenseBlock/Transition pairs,
    a final batch norm, and a two-layer fully connected head. The head's
    input size is fixed at 64 * 4 * 4, so the feature map reaching it must
    be 4x4x64 — which is what a 32x32 input produces.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=10):
        super(DenseNet, self).__init__()

        self.lowconv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, padding=3, bias=False)
        self.relu = nn.ReLU()
        # Make Dense Blocks (each outputs 160 channels, see forward()).
        self.DenseBlock1 = self._make_dense_block(DenseBlock, 64)
        self.DenseBlock2 = self._make_dense_block(DenseBlock, 128)
        self.DenseBlock3 = self._make_dense_block(DenseBlock, 128)
        # Make transition Layers (160 -> reduced channels, halve H and W).
        self.Transition1 = self._make_transition_layer(Transition, in_channels=160, out_channels=128)
        self.Transition2 = self._make_transition_layer(Transition, in_channels=160, out_channels=128)
        self.Transition3 = self._make_transition_layer(Transition, in_channels=160, out_channels=64)
        # Classifier
        self.bn = nn.BatchNorm2d(num_features=64)
        self.pre_classifier = nn.Linear(64 * 4 * 4, 512)
        self.classifier = nn.Linear(512, num_classes)

    def _make_dense_block(self, block, in_channels):
        """Wrap ``block(in_channels)`` in a one-element ``nn.Sequential``.

        The Sequential wrapper is kept so state_dict keys stay of the form
        ``DenseBlockN.0.<param>``.
        """
        return nn.Sequential(block(in_channels))

    def _make_transition_layer(self, layer, in_channels, out_channels):
        """Wrap ``layer(in_channels, out_channels)`` in a one-element ``nn.Sequential``."""
        return nn.Sequential(layer(in_channels, out_channels))

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        # Convolutional feature extractor.
        out = self.relu(self.lowconv(x))  # 32x32x3 -> 32x32x64
        out = self.DenseBlock1(out)  # 32x32x64 -> 32x32x160
        out = self.Transition1(out)  # 32x32x160 -> 16x16x128
        out = self.DenseBlock2(out)  # 16x16x128 -> 16x16x160
        out = self.Transition2(out)  # 16x16x160 -> 8x8x128
        out = self.DenseBlock3(out)  # 8x8x128 -> 8x8x160
        out = self.Transition3(out)  # 8x8x160 -> 4x4x64
        out = self.bn(out)  # 4x4x64 -> 4x4x64
        # Flatten per sample, keeping the batch dimension intact.
        # view(-1, 1024) would silently fold extra spatial positions into
        # the batch dimension for wrongly sized inputs. With flatten, such
        # inputs fail loudly in pre_classifier instead.
        out = torch.flatten(out, 1)  # 4x4x64 -> 1024
        # Fully connected classifier head.
        # NOTE(review): there is no activation between pre_classifier and
        # classifier, so the two Linear layers compose to one affine map.
        # Adding one would change outputs of existing checkpoints, so it is
        # left unchanged — confirm whether this is intended.
        out = self.pre_classifier(out)  # 1024 -> 512
        out = self.classifier(out)  # 512 -> num_classes
        return out

