import torch


class SklearnModelWrapper:
    """
    Wrap a scikit-learn-style estimator so it exposes the same interface as
    the torch-based models (``predict`` / ``fit`` returning torch tensors).

    Contains:
    - Prediction (``predict``)
    - Fitting procedure (``fit``)

    ``forward`` and ``loss_func`` have no meaning for a sklearn model and
    raise ``NotImplementedError``.
    """

    def __init__(self, model):
        """
        :param model: estimator object providing ``fit(X, y)``, ``predict(X)``
            and ``score(X, y)`` on numpy arrays (e.g. a scikit-learn estimator)
        """
        super().__init__()
        self.model = model

    @staticmethod
    def _to_numpy(tensor):
        """Convert a torch tensor to a numpy array (detached, on CPU)."""
        # .numpy() alone raises for tensors that require grad or live on a GPU.
        return tensor.detach().cpu().numpy()

    @classmethod
    def _prepare_input(cls, x):
        """Return input features as numpy, turning 1-D input into one feature column."""
        # sklearn expects a 2-D (n_samples, n_features) array
        x = x.reshape(-1, 1) if len(x.shape) == 1 else x
        return cls._to_numpy(x)

    def forward(self, x):
        """
        Not supported: a sklearn model has no forward pass.

        :param x: input data
        :raises NotImplementedError: always
        """
        raise NotImplementedError

    def predict(self, dataloader):
        """
        Predict the labels for the features in ``dataloader``.

        :param dataloader: tuple ``(x, y)`` of tensors; ``y`` is ignored.
            1-D ``x`` is treated as a single feature column.
        :return: dict with ``'predictions'`` (None), ``'mean'`` (float32 tensor
            of model predictions) and ``'stddev'`` (``tensor([-1])``, since no
            uncertainty estimate is available)
        :raises NotImplementedError: if ``dataloader`` is not a tuple
            (e.g. a pytorch DataLoader)
        """
        if not isinstance(dataloader, tuple):
            raise NotImplementedError("This model does not predict with pytorch dataloaders.")

        x, _ = dataloader
        predictions_mean = self.model.predict(self._prepare_input(x))

        return {'predictions': None,
                'mean': torch.tensor(predictions_mean, dtype=torch.float32),
                'stddev': torch.tensor([-1])}

    def loss_func(self):
        """
        Not supported: a sklearn model has no torch loss function.

        :raises NotImplementedError: always
        """
        raise NotImplementedError

    def fit(self, train_data, args=None, debug=False):
        """
        Fit (train) the wrapped model on the data ``(x, y)``.

        :param train_data: tuple (or list) ``(x, y)`` of tensors with the
            features / input data and the labels / output data
        :param args: unused; kept for interface compatibility
        :param debug: unused; kept for interface compatibility
        :return: ``(score, [score], None)`` where ``score`` is
            ``self.model.score(x, y)`` on the training data.
            NOTE(review): for sklearn estimators this is typically R^2 or
            accuracy (higher is better), not a loss.
        :raises NotImplementedError: if ``train_data`` is not a tuple/list
            (e.g. a pytorch DataLoader)
        """
        # A DataLoader is iterable, so plain ``x, y = train_data`` would take
        # its first two batches as x and y; reject it explicitly instead.
        if not isinstance(train_data, (tuple, list)):
            raise NotImplementedError("This model does not fit with pytorch dataloaders.")

        x, y = train_data
        x_np = self._prepare_input(x)
        y_np = self._to_numpy(y)

        self.model.fit(x_np, y_np)
        loss = self.model.score(x_np, y_np)
        return loss, [loss], None
