import torch.nn as nn
from torchvision.models import googlenet


def _trunc_normal_weight(layer):
    """Re-initialise ``layer.weight`` in place with the truncated normal used below."""
    nn.init.trunc_normal_(layer.weight, mean=0.0, std=0.01, a=-2, b=2)
    return layer


def _resized_linear(old_fc, num_classes):
    """Return a fresh ``nn.Linear`` with ``old_fc``'s input width and ``num_classes`` outputs."""
    new_fc = nn.Linear(
        old_fc.in_features, num_classes, bias=old_fc.bias is not None
    )
    # Keep the replacement on the same device/dtype as the layer it replaces.
    new_fc = new_fc.to(device=old_fc.weight.device, dtype=old_fc.weight.dtype)
    return _trunc_normal_weight(new_fc)


def reshape_googlenet(net, input_shape=(3, 224, 224), num_classes=1000):
    """Adapt a GoogLeNet-style network to a new channel count and/or class count.

    The network is modified in place and also returned.

    Args:
        net: A module exposing ``conv1.conv`` (an ``nn.Conv2d``) and ``fc``
            (an ``nn.Linear``). Optional ``aux1`` / ``aux2`` attributes with
            an ``fc2`` linear layer are resized too when present.
        input_shape: ``(channels, height, width)``; only ``input_shape[0]``
            is read. The stem conv is rebuilt when it differs from 3.
        num_classes: Number of output classes. The classifier head(s) are
            rebuilt when it differs from 1000.

    Returns:
        The same ``net`` object, after modification.
    """
    in_channels = input_shape[0]

    if in_channels != 3:
        old_conv = net.conv1.conv
        # Mirror the existing stem's hyperparameters (including whether it has
        # a bias) instead of hard-coding them, so the new layer differs from
        # the old one only in its number of input channels.
        new_conv = nn.Conv2d(
            in_channels,
            old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=old_conv.stride,
            padding=old_conv.padding,
            dilation=old_conv.dilation,
            bias=old_conv.bias is not None,
        )
        new_conv = new_conv.to(
            device=old_conv.weight.device, dtype=old_conv.weight.dtype
        )
        net.conv1.conv = _trunc_normal_weight(new_conv)

        # NOTE(review): torchvision's optional input transform re-normalises
        # three ImageNet channels; it is disabled here because it cannot apply
        # to a different channel count — confirm against the torchvision
        # version in use.
        if getattr(net, "transform_input", False):
            net.transform_input = False

    if num_classes != 1000:
        net.fc = _resized_linear(net.fc, num_classes)

        # Auxiliary heads (absent or None when the net was built without
        # them) must predict the same number of classes as the main head.
        for aux_name in ("aux1", "aux2"):
            aux = getattr(net, aux_name, None)
            if aux is not None and getattr(aux, "fc2", None) is not None:
                aux.fc2 = _resized_linear(aux.fc2, num_classes)

    return net


def GoogLeNet(input_shape=(3, 224, 224), num_classes=1000):
    """Build a GoogLeNet without auxiliary heads, reshaped for the given task.

    Args:
        input_shape: ``(channels, height, width)`` forwarded to
            ``reshape_googlenet``.
        num_classes: Number of output classes forwarded to
            ``reshape_googlenet``.

    Returns:
        The network returned by ``reshape_googlenet``.
    """
    base_model = googlenet(aux_logits=False)
    return reshape_googlenet(
        base_model,
        input_shape,
        num_classes,
    )
