import pdb

import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleConvNet(nn.Module):
  """Small two-stage convolutional network with a 10-way linear head.

  Each stage applies a 5x5 convolution with padding 2 (spatial size is kept),
  a ReLU, and a 2x2 max-pool (height and width are halved). The head expects
  7*7*32 flattened features, i.e. single-channel inputs whose spatial size
  shrinks to 7x7 after the two pools (for example 28x28 images).
  """

  def __init__(self):
    super(SimpleConvNet, self).__init__()

    # Stage 1: 1 -> 16 channels.
    first_stage = [
        nn.Conv2d(1, 16, kernel_size=5, padding=2),
        # nn.BatchNorm2d(16),
        nn.ReLU(),
        nn.MaxPool2d(2),
    ]
    self.layer1 = nn.Sequential(*first_stage)

    # Stage 2: 16 -> 32 channels.
    second_stage = [
        nn.Conv2d(16, 32, kernel_size=5, padding=2),
        # nn.BatchNorm2d(32),
        nn.ReLU(),
        nn.MaxPool2d(2),
    ]
    self.layer2 = nn.Sequential(*second_stage)

    # Linear head over the flattened 32 x 7 x 7 feature map.
    self.fc = nn.Linear(7*7*32, 10)

  def forward(self, x):
    """Run both convolutional stages, flatten per sample, and apply the head.

    Args:
      x: Batch of single-channel images; the first dimension is the batch.

    Returns:
      Tensor with one row of 10 values per input sample.
    """
    features = self.layer2(self.layer1(x))
    batch_size = features.size(0)
    # Collapse channel and spatial dimensions into one feature axis.
    flat = features.view(batch_size, -1)
    return self.fc(flat)


class MLP(nn.Module):
  """Fully connected network applied to flattened inputs.

  The network is `nlayers` hidden blocks of Linear -> ReLU -> Dropout,
  followed by a final Linear layer producing `nout` values. With
  `nlayers == 0` it is a single Linear map from `ninp` to `nout`.

  Args:
    ninp: Number of features per sample after flattening.
    nhid: Width of every hidden layer.
    nout: Number of output values per sample.
    nlayers: Number of hidden layers; must be >= 0.
    dropout: Probability passed to the nn.Dropout after each hidden ReLU.
    use_bias: Whether the Linear layers have a bias term.

  Raises:
    ValueError: If `nlayers` is negative.
  """

  def __init__(self, ninp=784, nhid=100, nout=10, nlayers=0, dropout=0,
               use_bias=True):
    super(MLP, self).__init__()

    # A negative count used to fall through to the hidden-layer branch and
    # silently build a one-hidden-layer network; reject it explicitly.
    if nlayers < 0:
      raise ValueError(
          'nlayers must be non-negative, got {}'.format(nlayers))

    self.ninp = ninp
    self.nhid = nhid
    self.nout = nout
    self.nlayers = nlayers
    self.dropout = dropout
    self.use_bias = use_bias

    # Feature widths from input to output: [ninp, nhid, ..., nhid, nout].
    widths = [ninp] + [nhid] * nlayers + [nout]

    modules = []
    # Hidden blocks: every transition except the last one.
    for fan_in, fan_out in zip(widths[:-2], widths[1:-1]):
      modules += [
          nn.Linear(fan_in, fan_out, bias=use_bias),
          nn.ReLU(),
          nn.Dropout(dropout),
      ]
    # Output projection (no activation or dropout after it).
    modules.append(nn.Linear(widths[-2], widths[-1], bias=use_bias))

    self.model = nn.Sequential(*modules)

  def forward(self, input):
    """Flatten each sample to one feature vector and run the network.

    Args:
      input: Tensor whose first dimension is the batch; the remaining
        dimensions must multiply to `ninp`.

    Returns:
      Tensor with one row of `nout` values per sample.
    """
    flat = input.view(input.size(0), -1)
    return self.model(flat)
