import torch
import torch.nn as nn
import numpy as np

class Flatten(nn.Module):
    '''
    Flatten every dimension except the leading batch dimension.

    Maps an input of shape (N, d1, d2, ...) to (N, d1 * d2 * ...). For the
    usual (N, C, H, W) image batch this yields one C*H*W vector per image.
    '''
    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        # torch.flatten returns a view when possible and copies otherwise, so
        # unlike Tensor.view it also accepts non-contiguous inputs (e.g. after
        # permute). It is not restricted to 4-D tensors, and it handles an
        # empty batch (N == 0), where view(N, -1) would be ambiguous.
        return torch.flatten(x, start_dim=1)
    

class Classifier(nn.Module):
    '''
    MNIST digit classifier.

    Three stride-2 3x3 convolutions (with LeakyReLU) followed by a 4x4
    convolution form the feature extractor. For a 28x28 input the spatial
    size goes 28 -> 14 -> 7 -> 4 -> 1, so the extractor emits an
    (N, 128, 1, 1) tensor. The classifier head flattens that to (N, 128)
    and maps it to class probabilities. Inputs whose final feature map is
    not 1x1 will not match the Linear(128, ...) layer.

    Args:
        input_channel: number of channels in the input images (1 for MNIST).
        use_gpu: if True, cast all parameters to ``torch.cuda.FloatTensor``
            at construction time; this fails when CUDA is unavailable.
        feature: if True, ``forward`` returns the raw convolutional features
            instead of class probabilities.
        num_classes: number of output classes (10 for MNIST digits).
    '''
    def __init__(self, input_channel=1, use_gpu=True, feature=False,
                 num_classes=10):
        super(Classifier, self).__init__()
        self.use_gpu = use_gpu
        self.feature_extract = feature
        self.num_classes = num_classes

        self.feature = nn.Sequential(
            nn.Conv2d(input_channel, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            # 4x4 kernel with no padding collapses a 4x4 map to 1x1.
            nn.Conv2d(256, 128, kernel_size=4, stride=1, padding=0),
        )

        self.classify = nn.Sequential(
            # Not in-place: the input is the feature extractor's output tensor.
            nn.LeakyReLU(0.2, inplace=False),
            Flatten(),
            nn.Linear(128, num_classes),
            # Explicit dim avoids the deprecated implicit-dim inference. The
            # result is the same for the 2-D (N, num_classes) input.
            # NOTE(review): this emits probabilities; if training uses
            # nn.CrossEntropyLoss (which applies log-softmax itself), logits
            # would be expected instead. Confirm against the training code.
            nn.Softmax(dim=1),
        )

        if self.use_gpu:
            self.type(torch.cuda.FloatTensor)

    def forward(self, input):
        '''
        Run the network on a batch of images.

        Returns the convolutional feature map when ``feature_extract`` is set
        (shape (N, 128, 1, 1) for 28x28 inputs). Otherwise returns
        (N, num_classes) class probabilities.
        '''
        features = self.feature(input)
        if self.feature_extract:
            return features
        return self.classify(features)
