import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

import pdb


class Classifier(nn.Module):
    """Fully connected classification head built from Linear + ReLU blocks.

    The network is stored in ``self.classifier`` as an ``nn.Sequential`` whose
    children are named ``hidden_layer0``, ``hidden_layer1``, ... followed by
    ``output_layer``. Each hidden child is itself an ``nn.Sequential`` of a
    ``nn.Linear`` and an in-place ``nn.ReLU``; ``output_layer`` is a bare
    ``nn.Linear`` producing ``num_class`` values.

    Args:
        hidden_layers_classifier: Controls the depth. One hidden block is
            always created; ``hidden_layers_classifier - 2`` further hidden
            blocks are appended when that number is positive, so values of
            1 and 2 yield the same architecture.
        input_dim: Number of input features of the first linear layer.
        feat_dim: Width of every hidden block.
        num_class: Number of outputs of the final linear layer.
    """

    def __init__(self, hidden_layers_classifier, input_dim=2, feat_dim=64, num_class=10):
        super().__init__()

        # Input width of each hidden block, in construction order: the first
        # block consumes the raw input, every later one consumes feat_dim.
        extra_blocks = max(hidden_layers_classifier - 2, 0)
        in_widths = [input_dim] + [feat_dim] * extra_blocks

        self.classifier = nn.Sequential()
        for idx, in_width in enumerate(in_widths):
            block = nn.Sequential(
                nn.Linear(in_width, feat_dim),
                nn.ReLU(inplace=True),
            )
            self.classifier.add_module('hidden_layer{}'.format(idx), block)

        # Final projection; no activation is applied, so raw scores are returned.
        self.classifier.add_module('output_layer', nn.Linear(feat_dim, num_class))

    def forward(self, z):
        """Run ``z`` through the layer stack and return the output-layer result.

        ``z`` must have ``input_dim`` features in its last dimension (required
        by the first ``nn.Linear``).
        """
        return self.classifier(z)

