import torch
import torch.nn as nn


class ConvNet(nn.Module):
    """Plain all-convolutional classifier with global average pooling.

    The network is a stack of ``depth`` unpadded 3x3, stride-1 convolutions,
    each followed by a ReLU (the last one optionally, see ``final_relu``).
    The first layer maps ``num_input_channels`` to ``16 * widen_factor``
    channels, the hidden layers keep that width, and the last layer maps to
    ``num_classes`` channels, which are averaged over the spatial dimensions
    to give one logit per class.

    Because no padding is used, every convolution shrinks height and width
    by 2, so inputs must be larger than ``2 * depth`` pixels in each spatial
    dimension.

    Args:
        depth: Total number of convolutional layers; must be at least 2
            (an input layer and an output layer).
        widen_factor: Multiplier on the base width of 16 channels.
        num_classes: Number of output channels, i.e. number of logits.
        num_input_channels: Number of channels in the input images.
        final_relu: Whether to apply a ReLU after the last convolution.
            Defaults to ``True``, preserving the original architecture.
            Note that this clamps every logit to be non-negative, and a
            class whose final pre-activation is negative at every spatial
            position gets a logit of exactly 0 and no gradient through it;
            pass ``False`` for unconstrained logits.

    Raises:
        ValueError: If ``depth < 2``.
    """

    def __init__(self, depth, widen_factor, num_classes=10, num_input_channels=3,
                 *, final_relu=True):
        super().__init__()
        if depth < 2:
            # Without this check, depth 0 or 1 would silently build 2 layers.
            raise ValueError(f"depth must be >= 2, got {depth}")
        width = 16 * widen_factor
        # Input layer.
        layers = [nn.Conv2d(num_input_channels, width, kernel_size=3, stride=1), nn.ReLU()]
        # depth - 2 hidden layers of constant width.
        for _ in range(depth - 2):
            layers += [nn.Conv2d(width, width, kernel_size=3, stride=1), nn.ReLU()]
        # Output layer producing one channel per class.
        layers.append(nn.Conv2d(width, num_classes, kernel_size=3, stride=1))
        if final_relu:
            # Kept by default so layer order (and state_dict keys) match the original.
            layers.append(nn.ReLU())
        self.conv_layers = nn.Sequential(*layers)

    def forward(self, input_dict):
        """Compute class logits.

        Args:
            input_dict: Mapping with key ``'inputs'`` holding a batched image
                tensor of shape (batch, channels, height, width).

        Returns:
            dict with key ``'logits'``: tensor of shape (batch, num_classes).
        """
        out = self.conv_layers(input_dict['inputs'])
        # Global average pooling over the spatial (height, width) dimensions.
        logits = out.mean(dim=(2, 3))
        return {'logits': logits}