import torch
import torch.nn as nn
import torch.nn.functional as F

# class ConvSmall(nn.Module):
#     def __init__(self):
#         super(ConvSmall, self).__init__()
#         self.n_weights = 11330
#         self.weight = torch.nn.Parameter(torch.zeros(self.n_weights), requires_grad=True)

#         # initialization
#         self.weight.data.normal_(0.0, 0.02)

#     def forward(self, x):
#         conv1_weight, conv1_bias = self.weight[0:250].view(10,1,5,5), self.weight[250:260].view(10)
#         conv2_weight, conv2_bias = self.weight[260:2760].view(10,10,5,5), self.weight[2760:2770].view(10)
#         fc1_weight, fc1_bias = self.weight[2770:10770].view(50,160), self.weight[10770:10820].view(50)
#         fc2_weight, fc2_bias = self.weight[10820:11320].view(10,50), self.weight[11320:11330].view(10)
#         out = x
#         out = F.conv2d(out, conv1_weight, bias=conv1_bias)
#         out = F.relu(F.max_pool2d(out, 2))
#         out = F.conv2d(out, conv2_weight, bias=conv2_bias)
#         out = F.relu(F.max_pool2d(out, 2))
#         print (out.shape)
#         out = F.linear(out.view(-1, 160), fc1_weight, bias=fc1_bias)
#         out = F.relu(out)
#         out = F.linear(out, fc2_weight, bias=fc2_bias)
#         return out
class ConvSmall(nn.Module):
  """Small LeNet-style CNN: two conv/pool stages followed by three linear layers.

  Each conv stage is a 5x5 valid convolution, ReLU, then 2x2 max pooling.
  The fully connected head (fc1) is sized for a 16x4x4 feature map, which is
  what a 28x28 (or 29x29) spatial input produces after both stages.

  Args:
    in_channels: number of channels of the input images (default 1).
    num_classes: number of output logits produced by the final layer
      (default 10).
  """

  # (C, H, W) feature-map shape produced by the conv stack that fc1 expects.
  _CONV_OUT_SHAPE = (16, 4, 4)

  def __init__(self, in_channels=1, num_classes=10):
    super().__init__()

    # define layers
    self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=10, kernel_size=5)
    self.conv2 = nn.Conv2d(in_channels=10, out_channels=16, kernel_size=5)

    self.fc1 = nn.Linear(in_features=16 * 4 * 4, out_features=120)
    self.fc2 = nn.Linear(in_features=120, out_features=60)
    self.out = nn.Linear(in_features=60, out_features=num_classes)

  def forward(self, t):
    """Compute class logits for a batch of images.

    Args:
      t: image tensor of shape (N, in_channels, H, W), or a single unbatched
        image of shape (in_channels, H, W). H and W must reduce to 4x4 after
        the two conv/pool stages (e.g. 28x28).

    Returns:
      Raw (un-normalized) logits of shape (N, num_classes); an unbatched
      input yields shape (1, num_classes).

    Raises:
      RuntimeError: if the conv stack does not produce a 16x4x4 feature map,
        i.e. the input spatial size does not match what fc1 was built for.
    """
    # conv 1
    t = self.conv1(t)
    t = F.relu(t)
    t = F.max_pool2d(t, kernel_size=2, stride=2)

    # conv 2
    t = self.conv2(t)
    t = F.relu(t)
    t = F.max_pool2d(t, kernel_size=2, stride=2)

    # Validate before flattening: reshape(-1, 256) would otherwise silently
    # regroup values across samples when the spatial size is wrong, returning
    # the wrong number of rows instead of failing.
    feature_shape = tuple(t.shape[-3:])
    if feature_shape != self._CONV_OUT_SHAPE:
      raise RuntimeError(
          f"ConvSmall expected conv features of shape {self._CONV_OUT_SHAPE}, "
          f"got {feature_shape}; check the input spatial size (e.g. 28x28)."
      )

    # fc1
    t = t.reshape(-1, 16 * 4 * 4)
    t = self.fc1(t)
    t = F.relu(t)

    # fc2
    t = self.fc2(t)
    t = F.relu(t)

    # output
    t = self.out(t)
    # No softmax here: this returns raw logits, intended for a cross-entropy
    # loss (nn.CrossEntropyLoss / F.cross_entropy), which applies log-softmax
    # internally.

    return t