# Provided by Ximeng Sun
import torch
import torch.nn as nn



class Classifier(nn.Module):
    """3-D convolutional classifier that maps a volume to per-class logits.

    Five bias-free ``Conv3d`` layers with LeakyReLU(0.2) activations (and
    BatchNorm on the middle three) reduce a ``(N, n_channels, D, H, W)``
    tensor to ``(N, num_class, D', H', W')``.

    Shape effects:
        * The first four convs use stride ``(1, 2, 2)``. They halve height
          and width, while depth shrinks only through the kernel size and
          padding.
        * The final 4x4x4 conv uses stride 1.
        * Example: an input with ``D=10, H=W=64`` yields ``D'=H'=W'=1``.

    Args:
        n_channels: Number of input channels.
        num_class: Number of output channels (logits) of the final conv.
        ndf: Base feature width; the hidden convs produce ``ndf``,
            ``2*ndf``, ``4*ndf`` and ``8*ndf`` channels.
    """

    def __init__(self, n_channels, num_class=1, ndf=32):
        super().__init__()

        # NOTE: layer order/indices are part of the state_dict keys
        # (``main.<idx>.weight``); keep them stable for checkpoint loading.
        self.main = nn.Sequential(
            nn.Conv3d(n_channels, ndf, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv3d(ndf, ndf * 2, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False),
            nn.BatchNorm3d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv3d(ndf * 2, ndf * 4, 4, stride=(1, 2, 2), padding=(1, 1, 1), bias=False),
            nn.BatchNorm3d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv3d(ndf * 4, ndf * 8, 4, stride=(1, 2, 2), padding=(1, 1, 1), bias=False),
            nn.BatchNorm3d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv3d(ndf * 8, num_class, 4, 1, padding=(1, 0, 0), bias=False),
        )

    def forward(self, input):
        """Return the logits with singleton non-batch dimensions removed.

        Args:
            input: 5-D tensor ``(N, n_channels, D, H, W)``; ``BatchNorm3d``
                rejects other ranks.

        Returns:
            The raw output of ``self.main``, with every size-1 dimension after
            the batch dimension dropped. For a ``1x1x1`` spatial output this
            is ``(N,)`` when ``num_class == 1`` and ``(N, num_class)``
            otherwise. The batch dimension is always kept, even when
            ``N == 1``.
        """
        h = self.main(input)
        # A bare ``h.squeeze()`` would also drop the batch dim when N == 1,
        # returning a 0-d tensor that no longer lines up with a (1,) target.
        # Only squeeze dimensions after the batch axis.
        kept = [size for size in h.shape[1:] if size != 1]
        # Penultimate-layer features, if needed, are ``self.main[:-1](input)``.
        return h.reshape(h.size(0), *kept)