from torch import nn
from utils.fmodule import FModule

class Model(FModule):
    """Small CNN classifier for CIFAR-10.

    The network is split into two ``nn.Sequential`` parts:

    * ``shared_model`` -- two stages of 5x5 convolution (no padding), ReLU
      and 2x2 max-pooling. They map a 3-channel 32x32 image to a 64x5x5
      feature map (spatially 32 -> 28 -> 14 -> 10 -> 5).
    * ``local_model`` -- a three-layer MLP head mapping the 64*5*5 = 1600
      flattened features to 10 logits.

    NOTE(review): the shared/local naming presumably reflects which part is
    aggregated across clients and which stays per client. Confirm this
    against the code that consumes these attributes.
    """

    def __init__(self):
        super().__init__()
        # Channels-first input spec; -1 presumably marks a free spatial dim.
        # Note that the 1600-wide first Linear layer only matches 32x32 inputs.
        self.input_require_shape = [3, -1, -1]
        self.task = 'cifar10'
        self.shared_model = nn.Sequential(
            nn.Conv2d(3, 64, 5),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 64, 5),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Dropout placed before ReLU gives the same result as placing it after.
        # ReLU commutes with dropout's zeroing and its positive rescaling.
        self.local_model = nn.Sequential(
            nn.Linear(1600, 384),
            nn.Dropout(0.2),
            nn.ReLU(),
            nn.Linear(384, 192),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(192, 10)
        )

    def forward(self, x):
        """Return class logits for ``x``.

        Args:
            x: image tensor of shape (N, 3, 32, 32), or a single unbatched
               (3, 32, 32) image.

        Returns:
            Logits of shape (N, 10). An unbatched input yields shape (1, 10).
        """
        output = self.shared_model(x)
        if output.dim() == 3:
            # Unbatched input: add the leading batch axis that the previous
            # view(-1, 1600) implicitly produced, so callers see (1, 10).
            output = output.unsqueeze(0)
        # Flatten each sample while keeping the batch axis. view(-1, 1600)
        # could silently regroup features across samples for non-32x32
        # inputs. Such inputs now raise a shape error in the first Linear.
        output = output.flatten(1)
        return self.local_model(output)
