import numpy as np

class BaseModel:
    '''
    Base class for batch predictors that map inputs x of shape [nbatch, nX]
    to outputs y of shape [nbatch, nY].

    Subclasses override ``pred`` (and ``fit`` if they support it); ``contain``
    is built on top of ``pred``.
    '''

    def __init__(self):
        pass

    def fit(self):
        '''Fit the model. Not implemented in the base class.'''
        raise NotImplementedError()

    def pred(self, x):
        '''
        Predict outputs for a batch of inputs. Must be overridden.

        Args:
        - x:    [ nbatch, nX ] np
        Return:
        - y:    [ nbatch, nY ] np
        '''
        raise NotImplementedError()

    def contain(self, x, y, lam):
        '''
        Per sample, test whether y lies strictly inside the L-infinity ball
        of radius lam centred at the prediction pred(x).

        Args:
        - x:    [ nbatch, nX ] np
        - y:    [ nbatch, nY ] np
        - lam:  [ nbatch ] np, or a scalar used for every sample
        Return:
        - mask: [ nbatch ] np bool; mask[i] is True iff
                max_j |y[i, j] - pred(x)[i, j]| < lam[i]
        '''
        ypred = self.pred(x)  # [ nbatch, nY ] np
        # ord=np.inf over axis=1 is the per-row maximum absolute deviation.
        err = np.linalg.norm(y - ypred, ord=np.inf, axis=1)  # [ nbatch ] np
        # The comparison already produces a bool array; no cast is required.
        return err < lam
    
class OffsetScalarPredictor(BaseModel):
    '''
    Constant predictor: ignores the values of x and predicts ``offset`` for
    every sample in the batch.
    '''

    def __init__(self, offset):
        '''
        Args:
        - offset:   [ nY ] list or array, or a scalar that is broadcast
                    across all output columns
        '''
        super().__init__()
        self.offset = np.asarray(offset)

    def pred(self, x, nY=None):
        '''
        Args:
        - x:    [ nbatch, nX ] np; only its shape is used
        - nY:   number of output columns. When None, defaults to len(offset)
                if offset has more than one entry; for a scalar / length-1
                offset it falls back to x.shape[1], broadcasting the single
                value across the columns.
        Return:
        - y:    [ nbatch, nY ] np, every row equal to offset
        '''
        # Promote a 0-d (scalar) offset to 1-D so it can be indexed row-wise.
        offset = np.atleast_1d(self.offset)
        if nY is None:
            # The output width is determined by the offset, not by the input
            # width nX; x.shape[1] is only a fallback for a single value.
            nY = offset.shape[0] if offset.shape[0] > 1 else x.shape[1]
        nbatch = x.shape[0]
        # Adding to a zeros array broadcasts offset over the batch and returns
        # a fresh array with the promoted dtype.
        return np.zeros([nbatch, nY]) + offset[None, :]