import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import einsum, rearrange, repeat


"""
Wrapper module for converting real valued input/output to
spike trains and rates respectively
"""


class FourierFeatureLayer(nn.Module):
    """Lift every scalar input value to a vector of Fourier features.

    Each element ``v`` of the input is mapped to ``num_features_per_channel``
    features: ``cos(phase)`` followed by ``sin(phase)``, where
    ``phase = 2 * pi * v * freq + freq_bias`` and ``freq`` / ``freq_bias``
    are learnable. The feature vectors of each channel are stacked along the
    channel axis, so an input of shape ``(B, num_channels, L)`` becomes
    ``(B, num_channels * num_features_per_channel, L)``.

    Args:
        num_channels: number of input channels (dim 1 of the input); the
            output rearrangement fails if the input does not match it.
        num_features_per_channel: number of output features per input
            channel; must be even because features come in cos/sin pairs.
    """

    def __init__(self, num_channels, num_features_per_channel):
        super().__init__()
        self.num_channels = num_channels
        self.num_features_per_channel = num_features_per_channel
        assert (
            num_features_per_channel % 2 == 0
        ), "num_features_per_channel must be even"

        num_freqs = num_features_per_channel // 2
        # Learnable frequencies, initialised from N(0, 1) / 20, and zero
        # phase offsets. Both have shape (1, num_freqs), so they are shared
        # by all channels.
        # NOTE(review): per-channel frequencies would need shape
        # (num_channels, num_freqs). Confirm that sharing is intended.
        self.freq = nn.Parameter(torch.randn(1, num_freqs) / 20)
        self.freq_bias = nn.Parameter(torch.zeros(1, num_freqs))

    def forward(self, x):
        """Map ``x`` of shape (B, C, L) to Fourier features of shape (B, C*C1, L).

        Here C1 is ``num_features_per_channel``.
        """
        batch_size = x.size(0)
        # Give every scalar its own row, so one matmul with the (1, C1//2)
        # frequency row yields every phase at once.
        x = rearrange(x, "b c l -> (b c l) 1")
        # Compute the phase once and reuse it for both cos and sin.
        phase = 2 * torch.pi * x @ self.freq + self.freq_bias  # ((b c l), C1//2)
        fouriers = torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
        # Undo the flattening. The C1 features of input channel c end up at
        # output channels c * C1 ... c * C1 + C1 - 1.
        return rearrange(
            fouriers,
            "(b c l) c1 -> b (c c1) l",
            b=batch_size,
            c=self.num_channels,
            c1=self.num_features_per_channel,
        )


class CountWrapper(nn.Module):
    """Wrap an autoencoder so that its log-rate output is turned into rates.

    The wrapped ``ae_net`` must return ``(logrates, z)`` from its forward
    pass and expose ``encode`` / ``decode``. This wrapper applies a softplus
    to the log-rates, so the returned rates are non-negative.

    Args:
        ae_net: autoencoder module. When ``use_sin_enc`` is True it must also
            expose integer attributes ``C_in`` and ``C`` and a replaceable
            ``encoder_in`` submodule.
        use_sin_enc: if True, replace ``ae_net.encoder_in`` with a
            ``FourierFeatureLayer`` producing ``C_in * (C // C_in)`` channels.

    Raises:
        ValueError: if ``use_sin_enc`` is True and ``ae_net.C`` is not a
            multiple of ``ae_net.C_in``.
        AssertionError: from ``FourierFeatureLayer`` if ``C // C_in`` is odd.

    NOTE(review): with ``use_sin_enc=True`` the Fourier layer's parameters
    now live under ``ae_net.encoder_in.*`` in the state dict. Checkpoints that
    stored them under ``encoder_in.*`` will not load unchanged.
    """

    def __init__(self, ae_net, use_sin_enc=False):
        super().__init__()
        self.ae_net = ae_net
        self.use_sin_enc = use_sin_enc
        if self.use_sin_enc:
            # Replace ae_net.encoder_in, which was
            # nn.Conv1d(C_in, C, 1, groups=in_groups). The replacement
            # presumably has to emit exactly C channels to stay
            # shape-compatible with the rest of the encoder.
            if ae_net.C % ae_net.C_in != 0:
                raise ValueError(
                    f"ae_net.C ({ae_net.C}) must be divisible by "
                    f"ae_net.C_in ({ae_net.C_in}) to use the sinusoidal encoding"
                )
            # Install the layer on ae_net itself, not on the wrapper:
            # forward/encode delegate to ae_net, so only ae_net's own
            # submodule is ever called.
            ae_net.encoder_in = FourierFeatureLayer(
                ae_net.C_in, ae_net.C // ae_net.C_in
            )

    @property
    def encoder_in(self):
        """Input encoder of the wrapped autoencoder (read-only alias)."""
        return self.ae_net.encoder_in

    def forward(self, x):
        """Run the autoencoder and return ``(softplus(logrates), z)``.

        x: [B, C_in, L]
        rates: [B, C_in, L]
        """
        logrates, z = self.ae_net(x)
        return F.softplus(logrates), z

    def encode(self, x):
        """Delegate to ``ae_net.encode`` unchanged."""
        return self.ae_net.encode(x)

    def decode(self, z):
        """Decode ``z`` with ``ae_net`` and map the log-rates to rates via softplus."""
        return F.softplus(self.ae_net.decode(z))


if __name__ == "__main__":

    # test CountWrapper
    from ntldm.networks import AutoEncoder

    autoencoder = AutoEncoder(C_in=8, C=256, C_latent=8, L=500, kernel="s4")
    count_wrapper = CountWrapper(autoencoder, use_sin_enc=True)
    x = torch.randn(10, 8, 500)
    rates, z = count_wrapper(x)
    print(rates.shape, z.shape)
    print(count_wrapper)
    print("CountWrapper test passed")
