import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class GaussianFourierFeatureTransform(nn.Module):
    """Encode inputs as random Fourier features with a fixed Gaussian projection.

    The trailing (channel) dimension of the input is projected through a
    random matrix ``B`` of shape ``[in_channels, mapping_size]`` whose entries
    are drawn from a standard normal distribution and multiplied by ``scale``.
    The output is ``cat([sin(2*pi * x @ B), cos(2*pi * x @ B)], dim=-1)``.

    Shapes:
        input:  ``[..., in_channels]`` with at least 3 dimensions
        output: ``[..., mapping_size * 2]`` (leading dimensions unchanged)

    Args:
        in_channels: Size of the last dimension of the input.
        mapping_size: Number of random projection directions; the output has
            twice as many channels (one sine and one cosine per direction).
        scale: Multiplier applied to the standard-normal entries of ``B``.

    Attributes:
        out_channels: Size of the last dimension of the output,
            ``mapping_size * 2``.
    """

    def __init__(self, in_channels: int, mapping_size: int = 256, scale: float = 10):
        super().__init__()

        self._num_input_channels = in_channels
        self._mapping_size = mapping_size
        self.out_channels = mapping_size * 2
        # Registered as a buffer (not a parameter): the projection is sampled
        # once here, is not returned by ``parameters()``, and is saved in the
        # state dict so a reloaded module reproduces the same encoding.
        self.register_buffer("_B", torch.randn((in_channels, mapping_size)) * scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the Fourier feature encoding of ``x``.

        Args:
            x: Tensor of shape ``[..., in_channels]`` with at least 3
                dimensions.

        Returns:
            Tensor of shape ``[..., mapping_size * 2]``; the first
            ``mapping_size`` entries of the last dimension are the sines and
            the remaining ``mapping_size`` entries are the cosines.

        Raises:
            ValueError: If ``x`` has fewer than 3 dimensions or its last
                dimension does not equal ``in_channels``.
        """
        # Explicit raises instead of ``assert`` so the checks survive ``python -O``.
        if x.dim() < 3:
            raise ValueError(
                f"Expected an input with at least 3 dimensions, got shape {tuple(x.shape)}"
            )
        if x.shape[-1] != self._num_input_channels:
            raise ValueError(
                f"Expected the last dimension to be {self._num_input_channels} "
                f"(in_channels), got shape {tuple(x.shape)}"
            )

        # Project the channel dimension: [..., in_channels] @ [in_channels, mapping_size].
        x = x @ self._B
        x = 2 * math.pi * x
        return torch.cat([torch.sin(x), torch.cos(x)], dim=-1)