import torch
import torch.nn as nn
import numpy as np


class PositionEncoder(nn.Module):
    """
    Encode a position + direction vector (7-D by default) with NeRF-style
    sinusoidal positional encoding, then map it to ``output_dims`` (64 by
    default) through a linear layer followed by LeakyReLU.

    For each frequency ``f_k = 2**k`` (k = 0 .. num_freq_bands-1) and each input
    component ``x_d``, the features ``sin(f_k * pi * x_d)`` and
    ``cos(f_k * pi * x_d)`` are produced. The encoded feature order is::

        [x (if include_input),
         sin(f_0 x_0), cos(f_0 x_0), sin(f_0 x_1), cos(f_0 x_1), ...,
         sin(f_{K-1} x_{D-1}), cos(f_{K-1} x_{D-1})]

    ``self.linear``'s weights are tied to this order, so it must not change.

    Args:
        input_dims: size of the last dimension of the input.
        output_dims: size of the last dimension of the output.
        num_freq_bands: number of octave-spaced frequencies K.
        include_input: if True, the raw input is prepended to the encoding.
    """

    def __init__(
        self, input_dims=7, output_dims=64, num_freq_bands=10, include_input=True
    ):
        super().__init__()
        self.input_dims = input_dims
        self.include_input = include_input
        self.num_freq_bands = num_freq_bands

        # Each input component yields one sin and one cos per frequency, plus
        # the raw component itself when include_input is set.
        self.encoded_dims = (
            input_dims * (1 + 2 * num_freq_bands)
            if include_input
            else input_dims * 2 * num_freq_bands
        )

        # Linear layer to map encoded vector to target dimension
        self.linear = nn.Linear(self.encoded_dims, output_dims)
        self.activation = nn.LeakyReLU()

        # Octave-spaced frequencies 1, 2, 4, ..., 2**(num_freq_bands - 1).
        # Registered as a non-persistent buffer so it follows the module across
        # .to(device)/.cuda() while staying out of state_dict (checkpoints
        # saved when this was a plain attribute still load unchanged).
        self.register_buffer(
            "frequencies",
            2.0 ** torch.linspace(0.0, num_freq_bands - 1, num_freq_bands),
            persistent=False,
        )

    def forward(self, x):
        """
        Encode ``x`` and project it to ``output_dims``.

        Args:
            x: tensor of shape [..., input_dims], e.g.
               [batch_size*max_camera_num, 7].

        Returns:
            Tensor of shape [..., output_dims], e.g.
            [batch_size*max_camera_num, 64].
        """
        # (K,) angular scales f_k * pi, formed before multiplying by x so the
        # arithmetic matches the per-frequency evaluation order.
        scales = self.frequencies * np.pi
        # [..., 1, D] * [K, 1] -> [..., K, D]
        angles = x.unsqueeze(-2) * scales.unsqueeze(-1)
        # [..., K, D, 2] -> [..., K*D*2]. The trailing (sin, cos) axis
        # interleaves the pair per component, reproducing the order
        # documented on the class.
        encoded = torch.stack(
            (torch.sin(angles), torch.cos(angles)), dim=-1
        ).flatten(-3)

        if self.include_input:
            encoded = torch.cat((x, encoded), dim=-1)

        # Through linear layer and LeakyReLU
        return self.activation(self.linear(encoded))
