class base_encoder:
    """Base class for encoders that map an input to one or more float outputs.

    Subclasses either assign ``self.kernel`` (a callable applied to the
    input) or override :meth:`encode`. ``T`` selects the output form:
    ``None`` produces a single result, an integer produces a list of ``T``
    results.
    """

    def __init__(self, T=None, **kwargs):
        # Extra keyword arguments are accepted and ignored, so subclasses can
        # forward their **kwargs unconditionally.
        self.T = T
        self.kernel = None

    def reset(self):
        """Reset the kernel's internal state, if the kernel supports it."""
        kernel = self.kernel
        if not hasattr(kernel, 'reset'):
            return
        kernel.reset()

    def encode(self, img):
        """Apply the kernel to ``img`` and cast the result(s) with ``.float()``.

        With ``T`` set, the kernel is reset once and then called ``T`` times
        on the same input, and the list of results is returned. With ``T``
        left as ``None``, the single kernel result is returned. The kernel's
        output must support ``.float()`` (e.g. a torch tensor).
        """
        if self.T is not None:
            self.reset()
            return [self.kernel(img).float() for _step in range(self.T)]
        return self.kernel(img).float()


class ImgEncoder(base_encoder):
    """Identity encoder: ``encode`` returns the input cast with ``.float()``.

    Args:
        T: ``None`` (default) returns a single float copy of the input per
            ``encode`` call; an integer returns a list of ``T`` copies.
        **kwargs: Forwarded to the base class.
    """

    def __init__(self, T=None, **kwargs):
        # Forward T to the base initializer, as EventEncoder does, instead of
        # overwriting it afterwards. Defaulting to None matches the base
        # class, so ``ImgEncoder()`` works and selects single-output mode.
        super().__init__(T=T, **kwargs)
        # Identity kernel: has no ``reset`` attribute, so reset() is a no-op.
        self.kernel = lambda x: x

class EventEncoder(base_encoder):
    """Encoder for inputs that already contain ``T`` time steps.

    ``encode`` expects a 5-D tensor whose dimension 1 has length ``T``. It
    swaps dimensions 0 and 1 and casts the result with ``.float()``.
    Presumably the layout is (batch, T, ...) -> (T, batch, ...); confirm
    against callers.
    """

    def __init__(self, T, **kwargs):
        super().__init__(T = T, **kwargs)

    def encode(self, img_series):
        """Return ``img_series`` with dims 0 and 1 swapped, as float.

        Raises:
            AssertionError: if ``img_series`` is not 5-D or its dimension 1
                does not equal ``self.T``.
        """
        shape = img_series.shape
        # Check rank first: reading shape[1] on a tensor with fewer than two
        # dimensions would raise IndexError before the intended check.
        # NOTE(review): assertions are stripped under ``python -O``; consider
        # explicit exceptions if callers do not rely on AssertionError.
        assert len(shape) == 5, (
            f"expected a 5-D tensor, got shape {tuple(shape)}"
        )
        assert shape[1] == self.T, (
            f"expected dimension 1 to equal T={self.T}, got shape {tuple(shape)}"
        )
        return img_series.permute(1, 0, 2, 3, 4).float()


if __name__ == "__main__":
    import torch

    # Smoke demo. ImgEncoder's kernel is the identity, so each encode() call
    # prints the input cast to float. T is passed explicitly because
    # ImgEncoder's constructor requires it in some versions of this module.
    e = ImgEncoder(T=None)

    x = torch.rand(10, 9, 3)
    for _ in range(10):
        print(e.encode(x))

    # The identity kernel has no reset(), so this is a no-op; kept to show the API.
    e.reset()
    x = torch.rand(10, 9, 4)
    for _ in range(10):
        print(e.encode(x))