from encoder.base import base_encoder
from spikingjelly.clock_driven import encoding, neuron
import torch
import numpy as np


class PoissonEncoder(base_encoder):
    """Encoder whose kernel is spikingjelly's ``encoding.PoissonEncoder``.

    ``T`` and any extra keyword arguments are forwarded unchanged to
    ``base_encoder.__init__``.
    """

    def __init__(self, T, **kwargs):
        # ``T`` is a named parameter, so it can never already be in kwargs.
        kwargs["T"] = T
        super().__init__(**kwargs)
        self.kernel = encoding.PoissonEncoder()


class IntegratingEncoder(base_encoder):
    """Encoder whose kernel is a spikingjelly integrate-and-fire neuron.

    Args:
        T: Forwarded to ``base_encoder.__init__``.
        tau: Time constant for the neuron. Leaving it at ``np.inf`` selects a
            plain ``neuron.IFNode``; any other value selects
            ``neuron.LIFNode(tau=tau)``.
    """

    def __init__(self, T, tau=np.inf):
        super().__init__(T=T)
        self.kernel = neuron.IFNode() if tau == np.inf else neuron.LIFNode(tau=tau)


class NormalizedPoissonEncoder(base_encoder):
    """Poisson encoder with per-channel normalization and signed spikes.

    Inputs are standardized as ``(img - mean) / std`` per channel. The
    magnitude of each normalized value is then Poisson-encoded, and every
    spike takes the sign of its value. Negative inputs therefore yield
    negative spikes, and exact zeros yield none.

    Args:
        T: Number of encoding steps. ``None`` encodes a single step and
            returns one tensor; otherwise ``encode`` calls ``self.reset()``
            and returns a list of ``T`` spike tensors.
        mean: Per-channel means used for normalization.
        std: Per-channel standard deviations used for normalization.
    """

    def __init__(self, T, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
        super().__init__(T=T)
        self.kernel = encoding.PoissonEncoder()
        self.mean = torch.tensor(mean)
        self.std = torch.tensor(std)

    def encode(self, img):
        """Normalize ``img`` and Poisson-encode it into signed spikes.

        The statistics are reshaped to ``(1, C, 1, 1)``, so channels are
        taken from dim 1 of a 4-D input (e.g. NCHW).
        """
        # .to(img) moves the statistics to the input's dtype and device.
        mean = self.mean.reshape(1, -1, 1, 1).to(img)
        std = self.std.reshape(1, -1, 1, 1).to(img)
        img = (img - mean) / std

        # spikingjelly's PoissonEncoder fires where rand_like(x) <= x, so a
        # negative value would never spike. Encode the magnitude instead and
        # restore the sign afterwards. Both are loop-invariant.
        magnitude = img.abs()
        sign = torch.sign(img)

        if self.T is None:
            return self.kernel(magnitude).float() * sign
        self.reset()
        return [self.kernel(magnitude).float() * sign for _ in range(self.T)]


if __name__ == "__main__":
    # Smoke test: build a NormalizedPoissonEncoder from statistics given as a
    # JSON string and print the stacked output shape. With T=10 and an
    # (11, 3, 32, 32) input, it is expected to be (10, 11, 3, 32, 32).
    import json

    config_json = (
        '{"mean": [0.4914, 0.4822, 0.4465],'
        ' "std": [0.557, 0.549, 0.5534]}'
    )

    x = torch.rand(11, 3, 32, 32)
    e = NormalizedPoissonEncoder(T=10, **json.loads(config_json))
    print(torch.stack(e.encode(x)).shape)
