from __future__ import print_function

from fedem.models import get_mobilenet
from module_torch.model.classifier import MultiClassifierBase


class MobileNetCIFAR10(MultiClassifierBase):
    """MobileNet classifier built with a 10-class output (CIFAR-10).

    All constructor arguments are forwarded unchanged to
    ``MultiClassifierBase``; the wrapped network is obtained from
    ``fedem.models.get_mobilenet``, and ``forward`` simply delegates to it.
    """

    def __init__(self, *args, **kwargs):
        # Initialise the base class first, before any attribute assignment.
        super(MobileNetCIFAR10, self).__init__(*args, **kwargs)
        network = get_mobilenet(n_classes=10)
        # NOTE(review): "moblienet" is a misspelling of "mobilenet", but the
        # name is kept as-is because other code or saved checkpoints may refer
        # to it (e.g. state_dict keys, if the base class is a torch
        # nn.Module) — confirm before renaming.
        self.moblienet = network

    def forward(self, *args, **kwargs):
        """Pass every positional and keyword argument to the wrapped MobileNet
        and return its result unchanged."""
        net = self.moblienet
        return net(*args, **kwargs)
