import torch
import torch.nn as nn
import torch.nn.functional as F


class Inception(nn.Module):
    """Simplified Inception block: four parallel branches concatenated on channels.

    Each branch produces ``n_kernels`` feature maps and keeps the input's
    height and width. All convolutions use "same" padding and the max-pool
    has stride 1. The output therefore has ``4 * n_kernels`` channels.

    Branch order in the output (dim=1):
      1. 1x1 convolution
      2. 1x1 reduction -> 3x3 convolution
      3. 1x1 reduction -> 5x5 convolution
      4. 3x3 max-pool (stride 1) -> 1x1 convolution

    ReLU is applied once at the end of every branch. There is no activation
    between the 1x1 reduction and the larger convolution.

    Args:
        in_channels: Number of channels in the input tensor.
        n_kernels: Number of output channels contributed by each branch.
    """

    def __init__(self, in_channels, n_kernels):
        super().__init__()

        def pointwise(c_in):
            # 1x1 convolution mapping c_in -> n_kernels channels.
            return nn.Conv2d(c_in, n_kernels, kernel_size=1)

        def reduce_then_conv(size):
            # 1x1 reduction followed by a size x size conv; padding=size//2
            # keeps H/W unchanged for the odd kernel sizes used here.
            return nn.Sequential(
                pointwise(in_channels),
                nn.Conv2d(n_kernels, n_kernels, kernel_size=size, padding=size // 2),
            )

        # Construction order matches the branch order so parameter
        # initialisation consumes the RNG in the same sequence.
        self.conv1x1 = pointwise(in_channels)
        self.conv1x1_3x3 = reduce_then_conv(3)
        self.conv1x1_5x5 = reduce_then_conv(5)
        self.pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            pointwise(in_channels),
        )

    def forward(self, x):
        """Run all four branches on ``x`` and concatenate them along channels."""
        branches = (self.conv1x1, self.conv1x1_3x3, self.conv1x1_5x5, self.pool)
        return torch.cat([F.relu(branch(x)) for branch in branches], dim=1)



class GoogleNet(nn.Module):
    """Small CNN classifier: three conv layers, adaptive pooling, three FC layers.

    NOTE(review): despite its name, this model does not use the ``Inception``
    block from this file. It is a plain stacked-convolution network.

    Pipeline:
      conv1 (7x7, stride 2) -> ReLU -> max-pool (3x3, stride 2)
      conv2 (3x3) -> ReLU -> conv3 (3x3) -> ReLU
      adaptive average pool to 7x7 -> flatten
      fc1 (-> 1000) -> ReLU -> fc2 (-> 500) -> ReLU -> fc3 (-> out_dim)

    The adaptive pool fixes the spatial size at 7x7 before the FC layers, so
    fc1's input width does not depend on the input image size.

    Args:
        in_channels: Number of channels in the input image tensor.
        n_kernels: Base width. conv1/conv2/conv3 output ``n_kernels``,
            ``2 * n_kernels`` and ``4 * n_kernels`` channels respectively.
        out_dim: Number of output features produced by ``fc3``. No
            activation is applied to them.
    """

    def __init__(self, in_channels=3, n_kernels=16, out_dim=10):
        super(GoogleNet, self).__init__()

        # Layer construction order is unchanged so seeded initialisation and
        # state_dict keys stay compatible with existing checkpoints.
        self.conv1 = nn.Conv2d(in_channels, n_kernels, kernel_size=7, stride=2, padding=3)
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(n_kernels, 2 * n_kernels, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(2 * n_kernels, 4 * n_kernels, kernel_size=3, stride=1, padding=1)

        self.pool2 = nn.AdaptiveAvgPool2d((7, 7))

        self.fc1 = nn.Linear(4 * n_kernels * 7 * 7, 1000)
        self.fc2 = nn.Linear(1000, 500)
        self.fc3 = nn.Linear(500, out_dim)

    def forward(self, x):
        """Map a batch ``x`` shaped (N, in_channels, H, W) to (N, out_dim)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = F.relu(self.conv2(x))
        x = self.pool2(F.relu(self.conv3(x)))  # (N, 4 * n_kernels, 7, 7)

        # torch.flatten falls back to a copy when the strides do not permit a
        # view (e.g. channels_last activations). In that case the previous
        # x.view(x.shape[0], -1) raised. Element order is (C, H, W) either way.
        x = torch.flatten(x, start_dim=1)  # (N, 4 * n_kernels * 7 * 7)

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

