# Model output function.
# Base class that every model output implementation must inherit from.

class BaseModelOutputClass:
    '''
    Base interface for the model output function used in TRAK.

    Subclasses must override both static methods below. The base
    implementations raise ``NotImplementedError`` so that a missing
    override fails immediately instead of silently returning ``None``.
    '''

    def __init__(self):
        pass

    @staticmethod
    def model_output(data, model, *args, **kwargs):
        '''
        The model output function in TRAK paper.

        :param data: data for model input, not restricted for the data type
        :param model: model to be traced
        :raises NotImplementedError: always; subclasses must override this.
        '''
        raise NotImplementedError(
            "Subclasses of BaseModelOutputClass must implement model_output()."
        )

    @staticmethod
    def loss_grad_to_out(data, model, *args, **kwargs):
        '''
        The variable Q in TRAK paper.

        Judging by the method name, this is presumably the gradient of the
        loss with respect to the model output -- confirm against the
        subclasses / caller.

        :param data: data for model input, not restricted for the data type
        :param model: model to be traced
        :raises NotImplementedError: always; subclasses must override this.
        '''
        raise NotImplementedError(
            "Subclasses of BaseModelOutputClass must implement loss_grad_to_out()."
        )
