## A generic interface for the machine learning models used in this work.
import torch
import torch.nn as nn
import typing as tp


class BaseMLModel(nn.Module):
    """ Generic interface for the (optionally data-weighted) models used in this work.

        Subclasses implement fitting, prediction and the parameter-update rules used for
        (soft) removal of training points. A weighted model additionally owns a frozen
        nn.Parameter ``data_weights_vector`` holding one weight per training point.
    """

    def __init__(self, weighted_model: bool, train_set_size: tp.Optional[int] = None):
        """ Init a base model with data weights if weighted_model = True. In this case,
            the model has an additional parameter data_weights_vector storing a weight for each data point
            in the train set.
            :param weighted_model: If True, register a data_weights_vector of ones of length train_set_size.
            :param train_set_size: Number of training points. Required when weighted_model is True,
                may be None otherwise.
        """
        super().__init__()

        self.weighted_model = weighted_model
        self.train_set_size = train_set_size

        if weighted_model:
            assert self.train_set_size is not None, "train_set_size is required for a weighted model."
            # Stored as a frozen nn.Parameter so it follows .to()/.double() and is part of the state_dict.
            self.data_weights_vector = nn.Parameter(torch.ones(self.train_set_size), requires_grad=False)
        else:
            self.data_weights_vector = None

    def reset_data_weights_and_size(self, new_data_size: int):
        """ Reset the data weights of the model to all ones with a data set size of "new_data_size".
            This function allows changing the length of the data weights vector associated with a weighted model.
            The new vector keeps the device and dtype of the current one.
            :param new_data_size: New size of the data weights vector in the model.
        """
        assert self.weighted_model, "Data weights can only be reset for a weighted model."
        old_weights = self.data_weights_vector
        self.train_set_size = new_data_size
        # Preserve device/dtype: the model may have been moved (e.g. .to(device) or .double())
        # since construction; a plain torch.ones(...) would silently fall back to float32 on CPU.
        self.data_weights_vector = nn.Parameter(
            torch.ones(new_data_size, device=old_weights.device, dtype=old_weights.dtype),
            requires_grad=False,
        )

    def fit(self, X: torch.Tensor, y: torch.Tensor):
        """ Fit the model to N records with D features and labels Y.
            :param X: data tensor of shape [N, D]
            :param y: label tensor of shape [N], labels are binary with values 0, 1
        """
        raise NotImplementedError("Method is not implemented in BaseMLModel.")

    def forward(self, Xtest: torch.Tensor) -> torch.Tensor:
        """ Prediction method of the models.
            Returns numeric values for regression models
            and probabilities for classification models.
            :param Xtest: input data of shape [M, D]
            :returns: [M]-tensor with the predictions.
        """
        raise NotImplementedError("Method is not implemented in BaseMLModel.")

    def predict_with_logits(self, Xtest: torch.Tensor) -> torch.Tensor:
        """ Return the pre-sigmoid activations for classification models.
            For regression methods, this function is equivalent to forward.
            :param Xtest: input data of shape [B, D]
            :returns: [B]-tensor with the predictions.
        """
        raise NotImplementedError("Method is not implemented in BaseMLModel.")

    def get_all_params(self) -> torch.Tensor:
        """ Return a [1, P] vectorized version of all model parameters (P = total number of parameters). """
        raise NotImplementedError("Method is not implemented in BaseMLModel.")

    def predict_from_parameters(self, x: torch.Tensor, parameters: torch.Tensor) -> torch.Tensor:
        """
            Predict logits using different model parameters (instead of the ones stored with this object)
            supplied as inputs.
            :param x: [B, D] inputs.
            :param parameters: C different weight settings to use, prepended by an additional batch dimension,
                 e.g. shape [C, R] for a weight of dimension [R]
                 or [C, R, Q] for a weight matrix of dimension [R, Q]
            :returns: [B, C]-matrix.
        """
        raise NotImplementedError("Method is not implemented in BaseMLModel.")

    def compute_parameters_from_data_weights(self, data_weights: torch.Tensor, X: torch.Tensor, y: torch.Tensor):
        """ Compute model weight change under soft removal of specific points.
            Note: this function should be differentiable w.r.t. data_weights to be useful in
            end-to-end gradient descent optimization.
            :param data_weights: weight for each training point.
            :param X: data tensor of shape [N, D]
            :param y: label tensor of shape [N]
        """
        raise NotImplementedError("Method is not implemented in BaseMLModel.")

    def parameter_change_under_removal(self, X: torch.Tensor, y: torch.Tensor, ind: torch.Tensor = Ellipsis):
        """
            Compute model weight change under removal of specific points. The updated parameters are the current ones
            + the change returned by this function.
            (this can be seen as a discrete implementation of theta (s))
            :param X: data tensor of shape [N, D]
            :param y: label tensor of shape [N]
            :param ind: indices of the points where the weight change after removal is desired.
                Index tensor of length [C]; the default Ellipsis selects all points.
            :returns: A tensor of shape [C, num_params], where each row corresponds to the weight change
                with the corresponding data point removed.
        """
        raise NotImplementedError("Method is not implemented in BaseMLModel.")

    def loss_objective(self, parameters: torch.Tensor, X: torch.Tensor, y: torch.Tensor,
                       data_weights: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        """ A differentiable loss function w.r.t. the data weights, the full input data
            and current model parameters. The data_weights vector is passed as an explicit parameter, for
            optimization with the infinitesimal jackknife. Usually pass model.data_weights_vector.
            :param parameters: model parameters to evaluate the loss at.
            :param X: data tensor of shape [N, D]
            :param y: label tensor of shape [N]
            :param data_weights: optional weight for each data point.
        """
        raise NotImplementedError("Method is not implemented in BaseMLModel.")

    def loss_objective_for_jackknife(self, X: torch.Tensor, y: torch.Tensor):
        """ A variant of the loss_objective that fulfills the jackknife interface.
            :param X: data tensor of shape [N, D]
            :param y: label tensor of shape [N]
            :returns: A function(weights, theta) that evaluates the loss
                (for a single data weights vector of shape (W) and thetas of shape (D+1)).
        """
        # theta is flattened into a single row because loss_objective expects batched parameters.
        return lambda al, th: self.loss_objective(th.reshape(1, -1), X, y, al)
