# coding=utf-8
from . import utils, resnet_cond
import torch.nn as nn
import torch 

@utils.register_model(name='classifier')
class Classifier(nn.Module):
  """Model registered under the name ``'classifier'`` that wraps a conditional ResNet.

  The underlying network is built by ``resnet_cond.resnet_cond``. This wrapper
  adds one thing: the input may be given as a list of tensors, which are
  concatenated before being passed to the network.
  """

  def __init__(self, config):
    """Build the wrapped network.

    Args:
      config: Configuration object. It is passed whole to
        ``resnet_cond.resnet_cond``, and ``config.classifier.model`` is also
        passed as its ``model_type`` argument.
    """
    super().__init__()
    self.classifier = resnet_cond.resnet_cond(config, model_type=config.classifier.model)

  def forward(self, x, cond, embed=False):
    """Run the wrapped network on ``x``.

    Args:
      x: Input tensor, or a list of tensors. A list is first concatenated
        along dim 1 (presumably the channel dimension; confirm this against
        the input layout ``resnet_cond`` expects).
      cond: Conditioning input, forwarded unchanged to the wrapped network.
      embed: Flag forwarded unchanged to the wrapped network. Its meaning is
        defined in ``resnet_cond``, which is not visible here.

    Returns:
      Whatever the wrapped network returns for ``(x, cond, embed)``.
    """
    # isinstance, unlike an exact type() comparison, also accepts list subclasses.
    if isinstance(x, list):
      x = torch.cat(x, dim=1)
    return self.classifier(x, cond, embed)

