import torch.nn as nn


class GaussianReadoutModel(nn.Module):
    """Model consisting of a CNN backbone (core) followed by a Gaussian readout.

    The core model is kept in evaluation mode at all times. It is switched to
    eval mode on construction, and ``train()`` re-applies that after toggling
    the rest of the model. As a result, mode-dependent layers inside the core
    (e.g. BatchNorm, Dropout) do not switch to training behaviour while the
    readout is trained.

    NOTE: this only controls the module *mode*. It does not disable gradients
    for the core's parameters; freeze them separately if that is intended.
    """

    def __init__(self, core_model, readout):
        """
        Args:
            core_model: instance of class CoreModel (the backbone).
            readout: instance of class FullGaussian2d, applied to the output
                of ``core_model``.
        """
        super().__init__()

        self.core_model = core_model
        self.core_model.eval()
        self.readout = readout

    def forward(self, x, return_features=False):
        """Run the core on ``x`` and feed its output to the readout.

        Args:
            x: input passed directly to ``core_model``.
            return_features: if True, the readout is called with
                ``return_features=True`` and its result is returned as-is.

        Returns:
            Whatever the readout returns for the core's output.
        """
        z = self.core_model(x)
        if return_features:
            return self.readout(z, return_features=True)
        # Called without the keyword so the readout's own default applies.
        return self.readout(z)

    def train(self, mode=True):
        """Set the training mode of the model while keeping the core in eval mode.

        ``nn.Module.eval()`` delegates to ``self.train(False)``, so ``eval()``
        is routed through here too. It therefore puts every submodule
        (readout and core) in eval mode and returns ``self``.

        Args:
            mode: True for training mode, False for evaluation mode.

        Returns:
            ``self``, matching the ``nn.Module.train`` / ``eval`` contract, so
            calls can be chained (``model = Model(...).eval()``).
        """
        super().train(mode)
        # The core is put in eval mode in __init__. Presumably it is a fixed
        # backbone, so re-apply that after super() toggled all children.
        self.core_model.eval()
        return self

