# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Encoding functions
"""

import itertools
from abc import abstractmethod
from typing import Literal, Optional, Sequence

import numpy as np
import torch
import torch.nn.functional as F
from jaxtyping import Float, Int, Shaped
from torch import Tensor, nn

from nerfstudio.field_components.base_field_component import FieldComponent
from nerfstudio.utils.external import TCNN_EXISTS, tcnn
from nerfstudio.utils.math import components_from_spherical_harmonics, expected_sin, generate_polyhedron_basis
from nerfstudio.utils.printing import print_tcnn_speed_warning

import ipdb
class Encoding(FieldComponent):
    """Base class for components that encode an input tensor. Meant to be subclassed.

    Args:
        in_dim: Input dimension of tensor
    """

    def __init__(self, in_dim: int) -> None:
        # Validate before the parent constructor stores the dimension.
        if in_dim <= 0:
            raise ValueError("Input dimension should be greater than zero")
        super().__init__(in_dim=in_dim)

    @classmethod
    def get_tcnn_encoding_config(cls) -> dict:
        """Return the tcnn encoding configuration.

        The base implementation always raises; subclasses with a tcnn backend override it.

        Raises:
            NotImplementedError: Always, for the base class.
        """
        raise NotImplementedError("Encoding does not have a TCNN implementation")

    @abstractmethod
    def forward(self, in_tensor: Shaped[Tensor, "*bs input_dim"]) -> Shaped[Tensor, "*bs output_dim"]:
        """Encode the input tensor and return the processed tensor.

        Args:
            in_tensor: the input tensor to process
        """
        raise NotImplementedError()


class Identity(Encoding):
    """Pass-through encoding: the output is the input, unchanged."""

    def get_out_dim(self) -> int:
        """Return the output dimension, which equals the input dimension.

        Raises:
            ValueError: If the input dimension is ``None``.
        """
        dim = self.in_dim
        if dim is None:
            raise ValueError("Input dimension has not been set")
        return dim

    def forward(self, in_tensor: Shaped[Tensor, "*bs input_dim"]) -> Shaped[Tensor, "*bs output_dim"]:
        """Return ``in_tensor`` itself (no copy is made)."""
        return in_tensor


class ScalingAndOffset(Encoding):
    """Affine encoding: multiplies the input by a scale and adds an offset.

    Args:
        in_dim: Input dimension of tensor
        scaling: Scaling applied to tensor.
        offset: Offset applied to tensor.
    """

    def __init__(self, in_dim: int, scaling: float = 1.0, offset: float = 0.0) -> None:
        super().__init__(in_dim)

        self.scaling = scaling
        self.offset = offset

    def get_out_dim(self) -> int:
        """Return the output dimension, which equals the input dimension.

        Raises:
            ValueError: If the input dimension is ``None``.
        """
        dim = self.in_dim
        if dim is None:
            raise ValueError("Input dimension has not been set")
        return dim

    def forward(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]:
        """Return ``scaling * in_tensor + offset``."""
        scaled = self.scaling * in_tensor
        return scaled + self.offset


class NeRFEncoding(Encoding):
    """Multi-scale sinusoidal encodings. Support ``integrated positional encodings`` if covariances are provided.
    Each axis is encoded with frequencies ranging from 2^min_freq_exp to 2^max_freq_exp.

    Args:
        in_dim: Input dimension of tensor
        num_frequencies: Number of encoded frequencies per axis
        min_freq_exp: Minimum frequency exponent
        max_freq_exp: Maximum frequency exponent
        include_input: Append the input coordinate to the encoding
        implementation: Backend to use. "tcnn" builds a ``tcnn.Encoding`` when tcnn is available; otherwise a
            speed warning is printed and the torch path is used.
    """

    def __init__(
        self,
        in_dim: int,
        num_frequencies: int,
        min_freq_exp: float,
        max_freq_exp: float,
        include_input: bool = False,
        implementation: Literal["tcnn", "torch"] = "torch",
    ) -> None:
        super().__init__(in_dim)

        self.num_frequencies = num_frequencies
        self.min_freq = min_freq_exp
        self.max_freq = max_freq_exp
        self.include_input = include_input

        # Stays None unless a tcnn backend is successfully requested below.
        self.tcnn_encoding = None
        if implementation == "tcnn":
            if not TCNN_EXISTS:
                print_tcnn_speed_warning("NeRFEncoding")
            else:
                assert min_freq_exp == 0, "tcnn only supports min_freq_exp = 0"
                assert max_freq_exp == num_frequencies - 1, "tcnn only supports max_freq_exp = num_frequencies - 1"
                self.tcnn_encoding = tcnn.Encoding(
                    n_input_dims=in_dim,
                    encoding_config=self.get_tcnn_encoding_config(num_frequencies=self.num_frequencies),
                )

    @classmethod
    def get_tcnn_encoding_config(cls, num_frequencies) -> dict:
        """Build the tcnn "Frequency" encoding configuration for the given number of frequencies."""
        return {"otype": "Frequency", "n_frequencies": num_frequencies}

    def get_out_dim(self) -> int:
        """Return ``in_dim * num_frequencies * 2``, plus ``in_dim`` when the input is appended.

        Raises:
            ValueError: If the input dimension is ``None``.
        """
        if self.in_dim is None:
            raise ValueError("Input dimension has not been set")
        encoded_dim = self.in_dim * self.num_frequencies * 2
        return encoded_dim + self.in_dim if self.include_input else encoded_dim

    def pytorch_fwd(
        self,
        in_tensor: Float[Tensor, "*bs input_dim"],
        covs: Optional[Float[Tensor, "*bs input_dim input_dim"]] = None,
    ) -> Float[Tensor, "*bs output_dim"]:
        """Torch implementation of the NeRF encoding. When covariances are given, the encoding is integrated
            as proposed in mip-NeRF.

        Args:
            in_tensor: For best performance, the input tensor should be between 0 and 1.
            covs: Covariances of input points.
        Returns:
            Output values will be between -1 and 1
        """
        freqs = 2 ** torch.linspace(self.min_freq, self.max_freq, self.num_frequencies, device=in_tensor.device)
        # Scale to [0, 2pi], then spread every axis over all frequencies: [..., input_dim, num_frequencies].
        phases = (2 * torch.pi * in_tensor)[..., None] * freqs
        phases = phases.view(*phases.shape[:-2], -1)  # [..., input_dim * num_frequencies]

        # The second half is shifted by pi/2 so that sin() of it yields the cosine terms.
        sin_cos_args = torch.cat([phases, phases + torch.pi / 2.0], dim=-1)
        if covs is None:
            return torch.sin(sin_cos_args)

        # Only the diagonal of the covariance is used; it is scaled by the squared frequencies.
        variances = torch.diagonal(covs, dim1=-2, dim2=-1)[..., :, None] * freqs[None, :] ** 2
        variances = variances.reshape((*variances.shape[:-2], -1))
        return expected_sin(sin_cos_args, torch.cat([variances, variances], dim=-1))

    def forward(
        self, in_tensor: Float[Tensor, "*bs input_dim"], covs: Optional[Float[Tensor, "*bs input_dim input_dim"]] = None
    ) -> Float[Tensor, "*bs output_dim"]:
        """Encode ``in_tensor`` with the tcnn backend if one was built, else with the torch path.

        ``covs`` is only forwarded to the torch path; the tcnn branch does not receive it.
        """
        if self.tcnn_encoding is None:
            encoded = self.pytorch_fwd(in_tensor, covs)
        else:
            encoded = self.tcnn_encoding(in_tensor)
        if not self.include_input:
            return encoded
        return torch.cat([encoded, in_tensor], dim=-1)


class FFEncoding(Encoding):
    """Fourier Feature encoding. Supports integrated encodings.

    Args:
        in_dim: Input dimension of tensor
        basis: Basis matrix from which to construct the Fourier features.
        num_frequencies: Number of encoded frequencies per axis
        min_freq_exp: Minimum frequency exponent
        max_freq_exp: Maximum frequency exponent
        include_input: Append the input coordinate to the encoding
    """

    def __init__(
        self,
        in_dim: int,
        basis: Float[Tensor, "M N"],
        num_frequencies: int,
        min_freq_exp: float,
        max_freq_exp: float,
        include_input: bool = False,
    ) -> None:
        super().__init__(in_dim)
        self.num_frequencies = num_frequencies
        self.min_freq = min_freq_exp
        self.max_freq = max_freq_exp
        # Registered as a buffer (not a parameter), exposed as ``self.b_matrix``.
        self.register_buffer(name="b_matrix", tensor=basis)
        self.include_input = include_input

    def get_out_dim(self) -> int:
        """Return ``basis_columns * num_frequencies * 2``, plus ``in_dim`` when the input is appended.

        Raises:
            ValueError: If the input dimension is ``None``.
        """
        if self.in_dim is None:
            raise ValueError("Input dimension has not been set")
        assert isinstance(self.b_matrix, Tensor)
        encoded_dim = self.b_matrix.shape[1] * self.num_frequencies * 2
        return encoded_dim + self.in_dim if self.include_input else encoded_dim

    def forward(
        self,
        in_tensor: Float[Tensor, "*bs input_dim"],
        covs: Optional[Float[Tensor, "*bs input_dim input_dim"]] = None,
    ) -> Float[Tensor, "*bs output_dim"]:
        """Compute the Fourier feature encoding. When covariances are given, the encoding is integrated
            as proposed in mip-NeRF.

        Args:
            in_tensor: For best performance, the input tensor should be between 0 and 1.
            covs: Covariances of input points.

        Returns:
            Output values will be between -1 and 1
        """
        # Scale to [0, 2pi] and project onto the basis columns: [..., basis_columns].
        projected = (2 * torch.pi * in_tensor) @ self.b_matrix
        freqs = 2 ** torch.linspace(self.min_freq, self.max_freq, self.num_frequencies, device=in_tensor.device)
        phases = projected[..., None] * freqs  # [..., basis_columns, num_frequencies]
        phases = phases.view(*phases.shape[:-2], -1)  # [..., basis_columns * num_frequencies]

        # The second half is shifted by pi/2 so that sin() of it yields the cosine terms.
        sin_cos_args = torch.cat([phases, phases + torch.pi / 2.0], dim=-1)
        if covs is None:
            encoded = torch.sin(sin_cos_args)
        else:
            # Per-basis-column variance, then scaled by the squared frequencies.
            variances = torch.sum((covs @ self.b_matrix) * self.b_matrix, -2)
            variances = variances[..., :, None] * freqs[None, :] ** 2
            variances = variances.reshape((*variances.shape[:-2], -1))
            encoded = expected_sin(sin_cos_args, torch.cat([variances, variances], dim=-1))

        if not self.include_input:
            return encoded
        return torch.cat([encoded, in_tensor], dim=-1)


class RFFEncoding(FFEncoding):
    """Random Fourier Feature encoding. Supports integrated encodings.

    The basis is drawn once at construction from a zero-mean Gaussian and handed to ``FFEncoding`` with a
    single frequency (exponent 0).

    Args:
        in_dim: Input dimension of tensor
        num_frequencies: Number of encoding frequencies
        scale: Std of Gaussian to sample frequencies. Must be greater than zero
        include_input: Append the input coordinate to the encoding
    """

    def __init__(self, in_dim: int, num_frequencies: int, scale: float, include_input: bool = False) -> None:
        scale_is_positive = scale > 0
        if not scale_is_positive:
            raise ValueError("RFF encoding scale should be greater than zero")

        random_basis = torch.normal(mean=0, std=scale, size=(in_dim, num_frequencies))
        super().__init__(
            in_dim=in_dim,
            basis=random_basis,
            num_frequencies=1,
            min_freq_exp=0.0,
            max_freq_exp=0.0,
            include_input=include_input,
        )


class PolyhedronFFEncoding(FFEncoding):
    """Fourier Feature encoding using polyhedron basis as proposed by mip-NeRF360. Supports integrated encodings.

    The input dimension is fixed to 3; the basis is the transpose of ``generate_polyhedron_basis(...)``.

    Args:
        num_frequencies: Number of encoded frequencies per axis
        min_freq_exp: Minimum frequency exponent
        max_freq_exp: Maximum frequency exponent
        basis_shape: Shape of polyhedron basis. Either "octahedron" or "icosahedron"
        basis_subdivisions: Number of times to tesselate the polyhedron.
        include_input: Append the input coordinate to the encoding
    """

    def __init__(
        self,
        num_frequencies: int,
        min_freq_exp: float,
        max_freq_exp: float,
        basis_shape: Literal["octahedron", "icosahedron"] = "octahedron",
        basis_subdivisions: int = 1,
        include_input: bool = False,
    ) -> None:
        polyhedron_basis = generate_polyhedron_basis(basis_shape, basis_subdivisions)
        super().__init__(
            in_dim=3,
            basis=polyhedron_basis.T,
            num_frequencies=num_frequencies,
            min_freq_exp=min_freq_exp,
            max_freq_exp=max_freq_exp,
            include_input=include_input,
        )


class HashEncoding(Encoding):
    """Multi-resolution hash grid encoding for 3D inputs.

    Args:
        num_levels: Number of feature grids.
        min_res: Resolution of smallest feature grid.
        max_res: Resolution of largest feature grid.
        log2_hashmap_size: Size of hash map is 2^log2_hashmap_size.
        features_per_level: Number of features per level.
        hash_init_scale: Value to initialize hash grid.
        implementation: Implementation of hash encoding. Fallback to torch if tcnn not available.
        interpolation: Interpolation override for tcnn hashgrid. Not supported for torch unless linear.
    """

    def __init__(
        self,
        num_levels: int = 16,
        min_res: int = 16,
        max_res: int = 1024,
        log2_hashmap_size: int = 19,
        features_per_level: int = 2,
        hash_init_scale: float = 0.001,
        implementation: Literal["tcnn", "torch"] = "tcnn",
        interpolation: Optional[Literal["Nearest", "Linear", "Smoothstep"]] = None,
    ) -> None:
        super().__init__(in_dim=3)
        self.num_levels = num_levels
        self.min_res = min_res
        self.features_per_level = features_per_level
        self.hash_init_scale = hash_init_scale
        self.log2_hashmap_size = log2_hashmap_size
        self.hash_table_size = 2**log2_hashmap_size

        levels = torch.arange(num_levels)
        # Geometric growth from min_res to max_res over the levels; a single level needs no growth.
        if num_levels > 1:
            self.growth_factor = np.exp((np.log(max_res) - np.log(min_res)) / (num_levels - 1))
        else:
            self.growth_factor = 1
        self.scalings = torch.floor(min_res * self.growth_factor**levels)

        # Every level owns a contiguous slice of the flat hash table.
        self.hash_offset = levels * self.hash_table_size

        self.tcnn_encoding = None
        self.hash_table = torch.empty(0)
        if implementation == "tcnn":
            if TCNN_EXISTS:
                self.tcnn_encoding = tcnn.Encoding(
                    n_input_dims=3,
                    encoding_config=self.get_tcnn_encoding_config(
                        num_levels=self.num_levels,
                        features_per_level=self.features_per_level,
                        log2_hashmap_size=self.log2_hashmap_size,
                        min_res=self.min_res,
                        growth_factor=self.growth_factor,
                        interpolation=interpolation,
                    ),
                )
            else:
                print_tcnn_speed_warning("HashEncoding")
                self.build_nn_modules()
        elif implementation == "torch":
            self.build_nn_modules()

        # The torch path only implements linear interpolation.
        if self.tcnn_encoding is None:
            assert interpolation in (
                None,
                "Linear",
            ), f"interpolation '{interpolation}' is not supported for torch encoding backend"

    def build_nn_modules(self) -> None:
        """Create the learnable hash table used by the torch implementation.

        Values are drawn uniformly from [-1, 1) and multiplied by ``hash_init_scale``.
        """
        table_shape = (self.hash_table_size * self.num_levels, self.features_per_level)
        self.hash_table = torch.rand(size=table_shape) * 2 - 1
        self.hash_table *= self.hash_init_scale
        self.hash_table = nn.Parameter(self.hash_table)

    @classmethod
    def get_tcnn_encoding_config(
        cls, num_levels, features_per_level, log2_hashmap_size, min_res, growth_factor, interpolation=None
    ) -> dict:
        """Build the tcnn "HashGrid" configuration; ``interpolation`` is only included when given."""
        config = {
            "otype": "HashGrid",
            "n_levels": num_levels,
            "n_features_per_level": features_per_level,
            "log2_hashmap_size": log2_hashmap_size,
            "base_resolution": min_res,
            "per_level_scale": growth_factor,
        }
        if interpolation is not None:
            config["interpolation"] = interpolation
        return config

    def get_out_dim(self) -> int:
        """Return ``num_levels * features_per_level``."""
        return self.num_levels * self.features_per_level

    def hash_fn(self, in_tensor: Int[Tensor, "*bs num_levels 3"]) -> Shaped[Tensor, "*bs num_levels"]:
        """Returns hash tensor using method described in Instant-NGP

        Each coordinate is multiplied by a fixed constant, the three products are XOR-ed together, reduced
        modulo the per-level table size and shifted by the level's offset into the flat table.

        Args:
            in_tensor: Tensor to be hashed
        """
        multipliers = torch.tensor([1, 2654435761, 805459861]).to(in_tensor.device)
        weighted = in_tensor * multipliers
        hashed = weighted[..., 0] ^ weighted[..., 1]
        hashed = hashed ^ weighted[..., 2]
        hashed %= self.hash_table_size
        hashed += self.hash_offset.to(hashed.device)
        return hashed

    def pytorch_fwd(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]:
        """Forward pass using pytorch. Significantly slower than TCNN implementation."""

        assert in_tensor.shape[-1] == 3
        # Scale the point by every level's resolution: [..., 1, 3] * [L, 1] -> [..., L, 3].
        per_level = in_tensor[..., None, :] * self.scalings.view(-1, 1).to(in_tensor.device)
        upper = torch.ceil(per_level).type(torch.int32)
        lower = torch.floor(per_level).type(torch.int32)

        # Fractional position inside the cell; used as the weight of the "ceil" corner on each axis.
        frac = per_level - lower
        w_x, w_y, w_z = frac[..., 0:1], frac[..., 1:2], frac[..., 2:3]

        def corner_features(x_src: Tensor, y_src: Tensor, z_src: Tensor) -> Tensor:
            """Look up the features of the cell corner taking x, y, z from the given ceil/floor tensors."""
            vertex = torch.cat([x_src[..., 0:1], y_src[..., 1:2], z_src[..., 2:3]], dim=-1)
            return self.hash_table[self.hash_fn(vertex)]  # [..., num_levels, features_per_level]

        def blend(at_ceil: Tensor, at_floor: Tensor, weight: Tensor) -> Tensor:
            """Linear interpolation with ``weight`` applied to the ceil-side value."""
            return at_ceil * weight + at_floor * (1 - weight)

        # Interpolate along x for each of the four (y, z) corner pairs.
        x_cc = blend(corner_features(upper, upper, upper), corner_features(lower, upper, upper), w_x)
        x_fc = blend(corner_features(upper, lower, upper), corner_features(lower, lower, upper), w_x)
        x_ff = blend(corner_features(upper, lower, lower), corner_features(lower, lower, lower), w_x)
        x_cf = blend(corner_features(upper, upper, lower), corner_features(lower, upper, lower), w_x)

        # Then along y for each z side, and finally along z.
        xy_c = blend(x_cc, x_fc, w_y)
        xy_f = blend(x_cf, x_ff, w_y)
        encoded_value = blend(xy_c, xy_f, w_z)  # [..., num_levels, features_per_level]

        return torch.flatten(encoded_value, start_dim=-2, end_dim=-1)  # [..., num_levels * features_per_level]

    def forward(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]:
        """Encode with the tcnn backend if one was built, otherwise with the torch implementation."""
        if self.tcnn_encoding is None:
            return self.pytorch_fwd(in_tensor)
        return self.tcnn_encoding(in_tensor)


class TensorCPEncoding(Encoding):
    """Learned CANDECOMP/PARFAC (CP) decomposition encoding used in TensoRF

    Args:
        resolution: Resolution of grid.
        num_components: Number of components per dimension.
        init_scale: Initialization scale.
    """

    def __init__(self, resolution: int = 256, num_components: int = 24, init_scale: float = 0.1) -> None:
        super().__init__(in_dim=3)

        self.resolution = resolution
        self.num_components = num_components

        # TODO Learning rates should be different for these
        initial_lines = init_scale * torch.randn((3, num_components, resolution, 1))
        self.line_coef = nn.Parameter(initial_lines)

    def get_out_dim(self) -> int:
        """Return the number of components."""
        return self.num_components

    def forward(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]:
        """Sample the three line grids at the input coordinates and multiply them together."""
        # One coordinate per line grid, in z, y, x order.
        axis_coord = torch.stack([in_tensor[..., axis] for axis in (2, 1, 0)])  # [3, ...]
        # grid_sample needs 2D sample positions; the first coordinate is fixed to zero.
        sample_grid = torch.stack([torch.zeros_like(axis_coord), axis_coord], dim=-1)  # [3, ..., 2]

        # Stop gradients from going to sampler
        sample_grid = sample_grid.view(3, -1, 1, 2).detach()

        sampled = F.grid_sample(self.line_coef, sample_grid, align_corners=True)  # [3, Components, -1, 1]

        combined = torch.prod(sampled, dim=0)
        per_point = combined.view(self.num_components, *in_tensor.shape[:-1])

        return torch.moveaxis(per_point, 0, -1)  # [..., Components]

    @torch.no_grad()
    def upsample_grid(self, resolution: int) -> None:
        """Upsamples underlying feature grid

        The line grid is bilinearly resized and stored as a new parameter.

        Args:
            resolution: Target resolution.
        """
        resized = F.interpolate(self.line_coef.data, size=(resolution, 1), mode="bilinear", align_corners=True)
        self.line_coef = torch.nn.Parameter(resized)

        self.resolution = resolution


class TensorVMEncoding(Encoding):
    """Learned vector-matrix encoding proposed by TensoRF

    Args:
        resolution: Resolution of grid.
        num_components: Number of components per dimension.
        init_scale: Initialization scale.
    """

    plane_coef: Float[Tensor, "3 num_components resolution resolution"]
    line_coef: Float[Tensor, "3 num_components resolution 1"]

    def __init__(
        self,
        resolution: int = 128,
        num_components: int = 24,
        init_scale: float = 0.1,
    ) -> None:
        super().__init__(in_dim=3)

        self.resolution = resolution
        self.num_components = num_components

        initial_planes = init_scale * torch.randn((3, num_components, resolution, resolution))
        self.plane_coef = nn.Parameter(initial_planes)
        initial_lines = init_scale * torch.randn((3, num_components, resolution, 1))
        self.line_coef = nn.Parameter(initial_lines)

    def get_out_dim(self) -> int:
        """Return three times the number of components (one group per plane/line pair)."""
        return self.num_components * 3

    def forward(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]:
        """Compute encoding for each position in in_positions

        Args:
            in_tensor: position inside bounds in range [-1,1],

        Returns: Encoded position
        """
        # Planes xy, xz, yz are paired with the lines along the remaining axis: z, y, x.
        plane_coord = torch.stack([in_tensor[..., list(axes)] for axes in ((0, 1), (0, 2), (1, 2))])  # [3,...,2]
        axis_coord = torch.stack([in_tensor[..., axis] for axis in (2, 1, 0)])  # [3, ...]
        line_coord = torch.stack([torch.zeros_like(axis_coord), axis_coord], dim=-1)  # [3, ...., 2]

        # Stop gradients from going to sampler
        plane_coord = plane_coord.view(3, -1, 1, 2).detach()
        line_coord = line_coord.view(3, -1, 1, 2).detach()

        sampled_planes = F.grid_sample(self.plane_coef, plane_coord, align_corners=True)  # [3, Components, -1, 1]
        sampled_lines = F.grid_sample(self.line_coef, line_coord, align_corners=True)  # [3, Components, -1, 1]

        product = sampled_planes * sampled_lines  # [3, Components, -1, 1]
        per_point = product.view(3 * self.num_components, *in_tensor.shape[:-1])

        return torch.moveaxis(per_point, 0, -1)  # [..., 3 * Components]

    @torch.no_grad()
    def upsample_grid(self, resolution: int) -> None:
        """Upsamples underlying feature grid

        Both the plane and the line grids are bilinearly resized and stored as new parameters.

        Args:
            resolution: Target resolution.
        """
        resized_planes = F.interpolate(
            self.plane_coef.data, size=(resolution, resolution), mode="bilinear", align_corners=True
        )
        resized_lines = F.interpolate(self.line_coef.data, size=(resolution, 1), mode="bilinear", align_corners=True)

        self.plane_coef = torch.nn.Parameter(resized_planes)
        self.line_coef = torch.nn.Parameter(resized_lines)
        self.resolution = resolution


class TriplaneEncoding(Encoding):
    """Learned triplane encoding

    The encoding at [i,j,k] is an n dimensional vector obtained by combining (element-wise product or sum,
    see ``reduce``) the three n dimensional vectors at plane_coeff[i,j], plane_coeff[i,k], and plane_coeff[j,k].

    This allows for marginally more expressivity than the TensorVMEncoding, and each component is self standing
    and symmetrical, unlike with VM decomposition where we needed one component with a vector along all the x, y, z
    directions for symmetry.

    This can be thought of as 3 planes of features perpendicular to the x, y, and z axes, respectively and intersecting
    at the origin, and the encoding being the combination of the elements at the projection of [i, j, k] on
    these planes.

    The use for this is in representing a tensor decomp of a 4D embedding tensor: (x, y, z, feature_size)

    This will return a tensor of shape (bs:..., num_components)

    Args:
        resolution: Resolution of grid.
        num_components: The number of scalar triplanes to use (ie: output feature size)
        init_scale: The scale of the initial values of the planes
        reduce: "product" multiplies the three plane features element-wise; any other value sums them
    """

    plane_coef: Float[Tensor, "3 num_components resolution resolution"]

    def __init__(
        self,
        resolution: int = 32,
        num_components: int = 64,
        init_scale: float = 0.1,
        reduce: Literal["sum", "product"] = "sum",
    ) -> None:
        super().__init__(in_dim=3)

        self.resolution = resolution
        self.num_components = num_components
        self.init_scale = init_scale
        self.reduce = reduce

        initial_planes = self.init_scale * torch.randn((3, self.num_components, self.resolution, self.resolution))
        self.plane_coef = nn.Parameter(initial_planes)

    def get_out_dim(self) -> int:
        """Return the number of components."""
        return self.num_components

    def forward(self, in_tensor: Float[Tensor, "*bs 3"]) -> Float[Tensor, "*bs num_components featuresize"]:
        """Sample features from this encoder.

        NOTE(review): coordinates are passed unchanged to ``F.grid_sample``, whose documented convention is
        normalized coordinates in [-1, 1]; an earlier docstring stated [0, resolution] - confirm with callers.
        """
        batch_shape = in_tensor.shape[:-1]
        points = in_tensor.reshape(-1, 3)

        # Project every point onto the xy, xz and yz planes.
        plane_coord = torch.stack([points[..., list(axes)] for axes in ((0, 1), (0, 2), (1, 2))], dim=0)

        # Stop gradients from going to sampler
        plane_coord = plane_coord.detach().view(3, -1, 1, 2)
        sampled = F.grid_sample(self.plane_coef, plane_coord, align_corners=True)  # [3, num_components, flattened_bs, 1]

        # Collapse the plane axis, then move the batch first: [flattened_bs, num_components]
        combine = torch.prod if self.reduce == "product" else torch.sum
        reduced = combine(sampled, dim=0).squeeze(-1).T

        return reduced.reshape(*batch_shape, self.num_components)

    @torch.no_grad()
    def upsample_grid(self, resolution: int) -> None:
        """Upsamples underlying feature grid

        The plane grid is bilinearly resized and stored as a new parameter.

        Args:
            resolution: Target resolution.
        """
        resized_planes = F.interpolate(
            self.plane_coef.data, size=(resolution, resolution), mode="bilinear", align_corners=True
        )

        self.plane_coef = torch.nn.Parameter(resized_planes)
        self.resolution = resolution



class GaplaneEncoding(Encoding):
    """Learned lowrank + lowres encoding built from three line grids, three plane grids and one volume grid.

    Args:
        resolution: Grid resolutions ``(Nl, Np, Nv)`` for the line, plane and volume grids (may all differ).
        num_components: Feature channels ``(Cl, Cp, Cv)`` for the line, plane and volume grids.
        rng: Accepted but not read by the current initialisation, which uses the fixed ranges
            0.4 (lines), 0.15 (planes) and 0.025 (volume).
        shft: Constant added to every uniformly sampled initial grid value.
        cat: If True, the three per-axis feature groups are kept side by side in the output; if False they
            are summed over the axis dimension.
        reduce: "product" outputs [line * line * line, plane * line, volume]; any other value ("concat")
            outputs [lines, planes, volume]. "product" requires ``Cl == Cp``.
    """

    volume_coef: Float[Tensor, "1 num_components resolution resolution resolution"]
    plane_coef: Float[Tensor, "3 num_components resolution resolution"]
    line_coef: Float[Tensor, "3 num_components resolution 1"]

    def __init__(
        self,
        resolution: Sequence[int] = (128, 128, 64),  # Nl-Np-Nv (can be different for all grids)
        num_components: Sequence[int] = (32, 32, 16),  # Cl-Cp-Cv
        rng: float = 0.4,
        shft: float = 0.1,
        cat: bool = True,
        reduce: Literal["concat", "product"] = "product",
    ) -> None:
        super().__init__(in_dim=3)

        # Unpacking also rejects sequences that do not have exactly three entries.
        num_line, num_plane, num_volume = num_components
        res_line, res_plane, res_volume = resolution

        if reduce == "product":
            assert num_line == num_plane, "Can't have Cl != Cp when multiplying features"

        self.resolution = resolution
        self.num_components = num_components
        self.reduce = reduce
        self.cat = cat

        # Uniform initialisation in [shft, shft + range) with a different range per grid type.
        self.line_coef = nn.Parameter(0.4 * torch.rand((3, num_line, res_line, 1)) + shft)
        self.plane_coef = nn.Parameter(0.15 * torch.rand((3, num_plane, res_plane, res_plane)) + shft)
        self.volume_coef = nn.Parameter(0.025 * torch.rand((1, num_volume, res_volume, res_volume, res_volume)) + shft)

    def get_out_dim(self) -> int:
        """Return the size of the last dimension produced by ``forward`` for the current ``cat``/``reduce``."""
        if not self.cat:
            return sum(self.num_components)
        num_line, num_plane, num_volume = self.num_components
        if self.reduce == "product":
            # Cl (line product) + 3 * Cl (plane * line per axis) + Cv
            return 4 * num_line + num_volume
        return 3 * (num_line + num_plane) + num_volume

    def sample_volume(self, coords3d: Tensor, plane: Tensor) -> Tensor:
        """Sample a 5D grid at ``coords3d`` with ``F.grid_sample`` (zero padding, align_corners=True).

        Args:
            coords3d: Sample coordinates; the first dimension is kept as the batch dimension and the last as
                the coordinate dimension, everything in between is flattened.
            plane: Grid to sample from.

        Returns:
            Sampled features with the channel dimension moved last: [N, num_samples, C].
        """
        sampled_features = F.grid_sample(
            plane,
            coords3d.reshape(coords3d.shape[0], 1, 1, -1, coords3d.shape[-1]),
            mode="bilinear",
            padding_mode="zeros",
            align_corners=True,
        )
        sampled_features = sampled_features.squeeze(2)
        N, C, H, W = sampled_features.shape
        return sampled_features.reshape(N, C, H * W).permute(0, 2, 1)

    def sample_plane(self, coords2d: Tensor, plane: Tensor) -> Tensor:
        """Sample a 4D grid at ``coords2d`` with ``F.grid_sample`` (zero padding, align_corners=True).

        Args:
            coords2d: Sample coordinates; must be 3-dimensional (batch, samples, coordinate).
            plane: Grid to sample from.

        Returns:
            Sampled features with the channel dimension moved last: [N, num_samples, C].
        """
        assert len(coords2d.shape) == 3, coords2d.shape
        sampled_features = F.grid_sample(
            plane,
            coords2d.reshape(coords2d.shape[0], 1, -1, coords2d.shape[-1]),
            mode="bilinear",
            padding_mode="zeros",
            align_corners=True,
        )
        N, C, H, W = sampled_features.shape
        return sampled_features.reshape(N, C, H * W).permute(0, 2, 1)

    def forward(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]:
        """Compute encoding for each position in in_positions

        Args:
            in_tensor: position inside bounds in range [-1,1],

        Returns: Encoded position
        """
        # Planes xy, xz, yz are paired with the lines along the remaining axis: z, y, x.
        plane_coord = torch.stack([in_tensor[..., [0, 1]], in_tensor[..., [0, 2]], in_tensor[..., [1, 2]]])  # [3,...,2]
        line_coord = torch.stack([in_tensor[..., 2], in_tensor[..., 1], in_tensor[..., 0]])  # [3, ...]
        # NOTE(review): the coordinate is duplicated (not zero-padded as in TensorVMEncoding); the line grid
        # has width 1, so the first entry is presumably irrelevant to the sample - confirm.
        line_coord = torch.stack([line_coord, line_coord], dim=-1)  # [3, ...., 2]

        # Stop gradients from going to sampler. reshape (not view) so that a non-contiguous in_tensor,
        # e.g. a slice of a larger tensor, is accepted as well.
        plane_coord = plane_coord.reshape(3, -1, 1, 2).detach()  # (num_planes, flattened_bs, 1, 2)
        line_coord = line_coord.reshape(3, -1, 1, 2).detach()
        volume_coord = in_tensor.reshape(1, -1, 1, 1, 3).detach()

        plane_features = F.grid_sample(self.plane_coef, plane_coord, align_corners=True)  # [3, Cp, -1, 1]
        line_features = F.grid_sample(self.line_coef, line_coord, align_corners=True)  # [3, Cl, -1, 1]
        volume_features = F.grid_sample(self.volume_coef, volume_coord, align_corners=True)  # [1, Cv, -1, 1, 1]
        volume_features = volume_features.squeeze(0).squeeze(-1)  # [Cv, -1, 1]

        num_line, num_plane, _ = self.num_components

        if self.reduce == "product":
            if self.cat:
                plane_line = (plane_features * line_features).reshape(3 * num_line, -1, 1)  # [3*Cl, -1, 1]
            else:
                plane_line = torch.sum(plane_features * line_features, dim=0)  # [Cl, -1, 1]

            line_line = line_features[0] * line_features[1] * line_features[2]  # [Cl, -1, 1]
            features = torch.cat([line_line, plane_line, volume_features], dim=0)
        elif self.cat:
            lines = line_features.reshape(3 * num_line, -1, 1)
            planes = plane_features.reshape(3 * num_plane, -1, 1)
            features = torch.cat([lines, planes, volume_features], dim=0)
        else:
            features = torch.cat(
                [torch.sum(line_features, dim=0), torch.sum(plane_features, dim=0), volume_features], dim=0
            )

        # features.shape[0] is the output dimension; restore the batch shape and move channels last.
        features = torch.moveaxis(features.reshape(features.shape[0], *in_tensor.shape[:-1]), 0, -1)

        return features  # [..., out_dim]

    @torch.no_grad()
    def upsample_grid(self, resolution: Sequence[int]) -> None:
        """Upsamples underlying feature grid

        Line and plane grids are resized bilinearly, the volume grid trilinearly; all three are stored as
        new parameters.

        Args:
            resolution: Target resolutions ``(Nl, Np, Nv)`` for the line, plane and volume grids.
        """
        plane_coef = F.interpolate(
            self.plane_coef.data, size=(resolution[1], resolution[1]), mode="bilinear", align_corners=True
        )
        volume_coef = F.interpolate(
            self.volume_coef.data,
            size=(resolution[2], resolution[2], resolution[2]),
            mode="trilinear",
            align_corners=True,
        )
        line_coef = F.interpolate(self.line_coef.data, size=(resolution[0], 1), mode="bilinear", align_corners=True)

        self.plane_coef, self.line_coef = torch.nn.Parameter(plane_coef), torch.nn.Parameter(line_coef)
        self.volume_coef = torch.nn.Parameter(volume_coef)
        self.resolution = resolution
        

class GaplaneMultiresEncoding(Encoding):
    """Multi-resolution line + plane encoding with one shared volume grid.

    For every entry of ``multiscale_res`` the module owns three line grids and
    three plane grids whose spatial size is the base resolution times that
    entry. A single dense volume grid (not scaled) is sampled once and appended
    after the per-level features.

    Args:
        resolution: Base grid sizes ``(Nl, Np, Nv)`` for lines, planes and volume.
        num_components: Base channel counts ``(Cl, Cp, Cv)`` for lines, planes and volume.
        multiscale_res: Integer multipliers applied to ``Nl`` and ``Np``; one level per entry.
        rng: Currently unused. The grids are initialised with fixed scales
            (0.4 for lines, 0.15 for planes, 0.025 for the volume).
        shft: Constant added to the uniform random initialisation of every grid.
        cat: If True, keep the three axes as separate channels (concatenate)
            instead of summing over them.
        adapt_Cs: If True, use a per-level channel table instead of the same
            channel count at every level. The table has three entries, so at
            most three levels are supported in this mode.
        reduce: ``"product"`` emits ``[line*line*line, plane*line]`` per level,
            anything else (``"concat"``) emits ``[lines, planes]`` per level.

    Raises:
        ValueError: If ``adapt_Cs`` is set and more than three levels are requested.
    """

    volume_coef: Float[Tensor, "1 num_components resolution resolution resolution"]
    plane_coef: Float[Tensor, "3 num_components resolution resolution"]
    line_coef: Float[Tensor, "3 num_components resolution 1"]

    def __init__(
        self,
        resolution: Sequence[int] = (128, 128, 64),  # Nl-Np-Nv (can be different for all grids)
        num_components: Sequence[int] = (32, 32, 16),  # Cl-Cp-Cv
        multiscale_res: Sequence[int] = (1, 2, 4),
        rng: float = 0.4,
        shft: float = 0.1,
        cat: bool = False,
        adapt_Cs: bool = False,
        reduce: Literal["concat", "product"] = "product",
    ) -> None:
        super().__init__(in_dim=3)

        self.resolution = resolution
        self.num_components = num_components
        self.reduce = reduce
        self.multiscale_res = multiscale_res
        self.cat = cat  # concatenate over the three axes instead of summing
        self.adapt_Cs = adapt_Cs

        if reduce == "product":
            assert num_components[0] == num_components[1], "Can't have Cl != Cp when multiplying features"

        Cl, Cp, Cv = num_components
        Nl, Np, Nv = resolution
        num_levels = len(multiscale_res)

        # Per-level channel counts. They are always stored so that get_out_dim()
        # and forward() derive widths from the same source of truth.
        if adapt_Cs:
            if reduce == "concat":
                line_channels = [Cl, Cl * 2, Cl * 5]
                plane_channels = [Cp, Cp * 2, int(Cp * 4.5)]
            else:
                # Lines and planes are multiplied element-wise, so widths must match.
                line_channels = [Cl, Cl, int(Cp * 2)]
                plane_channels = [Cp, Cp, int(Cp * 2)]
            if num_levels > len(line_channels):
                raise ValueError(
                    f"adapt_Cs supports at most {len(line_channels)} resolution levels, got {num_levels}"
                )
            line_channels = line_channels[:num_levels]
            plane_channels = plane_channels[:num_levels]
        else:
            line_channels = [Cl] * num_levels
            plane_channels = [Cp] * num_levels
        self.Cls = line_channels
        self.Cps = plane_channels

        self.line_grids = nn.ParameterList()
        self.plane_grids = nn.ParameterList()
        for scale, n_line, n_plane in zip(multiscale_res, self.Cls, self.Cps):
            line_coef = nn.Parameter(0.4 * torch.rand((3, n_line, scale * Nl, 1)) + shft)
            plane_coef = nn.Parameter(0.15 * torch.rand((3, n_plane, scale * Np, scale * Np)) + shft)
            self.line_grids.append(line_coef)
            self.plane_grids.append(plane_coef)

        # The volume grid is shared by all levels and keeps its base resolution.
        self.volume_coef = nn.Parameter(0.025 * torch.rand((1, Cv, Nv, Nv, Nv)) + shft)

    def get_out_dim(self) -> int:
        """Return the width of the last dimension produced by :meth:`forward`.

        Each level contributes its line/plane features; the volume channels are
        counted once because the volume grid is sampled once.
        """
        axis_copies = 3 if self.cat else 1
        if self.reduce == "product":
            # line*line*line (Cl) followed by plane*line (Cl, or 3*Cl when concatenated).
            per_level = [n_line + axis_copies * n_line for n_line in self.Cls]
        else:
            per_level = [axis_copies * (n_line + n_plane) for n_line, n_plane in zip(self.Cls, self.Cps)]
        return sum(per_level) + self.num_components[2]

    @staticmethod
    def _to_channels_last(features: Tensor, batch_shape: Sequence[int]) -> Tensor:
        """Reshape ``[C, N, ...]`` features to ``[*batch_shape, C]``."""
        return torch.moveaxis(features.reshape(features.shape[0], *batch_shape), 0, -1)

    def forward(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]:
        """Compute encoding for each position in in_positions

        Args:
            in_tensor: position inside bounds in range [-1,1],

        Returns: Encoded position, per-level features first and volume features last.
        """
        batch_shape = in_tensor.shape[:-1]

        plane_coord = torch.stack([in_tensor[..., [0, 1]], in_tensor[..., [0, 2]], in_tensor[..., [1, 2]]])  # [3,...,2]
        line_coord = torch.stack([in_tensor[..., 2], in_tensor[..., 1], in_tensor[..., 0]])  # [3, ...]
        # The line grids have width 1, so the duplicated first coordinate is irrelevant.
        line_coord = torch.stack([line_coord, line_coord], dim=-1)  # [3, ..., 2]

        # Stop gradients from going to sampler
        plane_coord = plane_coord.view(3, -1, 1, 2).detach()  # [3, N, 1, 2]
        line_coord = line_coord.view(3, -1, 1, 2).detach()  # [3, N, 1, 2]
        # reshape (not view) so that non-contiguous inputs are accepted as well
        volume_coord = in_tensor.reshape(1, -1, 1, 1, 3).detach()  # [1, N, 1, 1, 3]

        volume_features = F.grid_sample(self.volume_coef, volume_coord, align_corners=True)  # [1, Cv, N, 1, 1]
        volume_features = self._to_channels_last(volume_features[0], batch_shape)  # [..., Cv]

        multires_features = []
        for line_coef, plane_coef in zip(self.line_grids, self.plane_grids):
            plane_features = F.grid_sample(plane_coef, plane_coord, align_corners=True)  # [3, Cp, N, 1]
            line_features = F.grid_sample(line_coef, line_coord, align_corners=True)  # [3, Cl, N, 1]

            if self.reduce == "product":
                plane_line = plane_features * line_features  # [3, C, N, 1]
                if self.cat:
                    # Use the level's real channel count; it differs per level with adapt_Cs.
                    PL_features = plane_line.reshape(3 * plane_line.shape[1], -1, 1)  # [3*C, N, 1]
                else:
                    PL_features = torch.sum(plane_line, dim=0)  # [C, N, 1]
                LL_features = line_features[0] * line_features[1] * line_features[2]  # [C, N, 1]
                features = torch.cat([LL_features, PL_features], dim=0)
            else:
                if self.cat:
                    line = line_features.reshape(3 * line_features.shape[1], -1, 1)  # [3*Cl, N, 1]
                    plane = plane_features.reshape(3 * plane_features.shape[1], -1, 1)  # [3*Cp, N, 1]
                else:
                    line = torch.sum(line_features, dim=0)  # [Cl, N, 1]
                    plane = torch.sum(plane_features, dim=0)  # [Cp, N, 1]
                features = torch.cat([line, plane], dim=0)

            multires_features.append(self._to_channels_last(features, batch_shape))

        multires_features.append(volume_features)
        return torch.cat(multires_features, dim=-1)  # [..., out_dim]



class KplaneV4Encoding(Encoding):
    """Multi-resolution tri-plane encoding combined by element-wise product.

    Only the plane entries (index 1) of ``resolution`` and ``num_components``
    are used. One stack of three planes is created per entry of
    ``multiscale_res``; the per-level products are concatenated along the
    feature axis.

    Args:
        resolution: ``(Nl, Np, Nv)`` triple; only ``Np`` is read.
        num_components: ``(Cl, Cp, Cv)`` triple; only ``Cp`` is read.
        multiscale_res: Multipliers applied to ``Np``; one level per entry.
        rng: Not used by this class.
        shft: Constant added to the uniform random initialisation.
    """

    plane_coef: Float[Tensor, "3 num_components resolution resolution"]

    def __init__(
        self,
        resolution: Sequence[int] = (128, 128, 64),  # Nl-Np-Nv (can be different for all grids)
        num_components: Sequence[int] = (32, 32, 16),  # Cl-Cp-Cv
        multiscale_res: Sequence[int] = (1, 2, 4),
        rng: float = 0.4,
        shft: float = 0.1,
    ) -> None:
        super().__init__(in_dim=3)

        base_side = resolution[1]
        channels = num_components[1]

        self.resolution = base_side
        self.num_components = channels
        self.multiscale_res = multiscale_res

        self.plane_grids = nn.ParameterList()
        for scale in multiscale_res:
            side = scale * base_side
            initial = 0.4 * torch.rand((3, channels, side, side)) + shft
            self.plane_grids.append(nn.Parameter(initial))

    def get_out_dim(self) -> int:
        """Return ``Cp`` times the number of resolution levels."""
        return len(self.multiscale_res) * self.num_components

    def forward(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]:
        """Sample every level's planes and multiply the three planes together.

        Args:
            in_tensor: positions, expected inside [-1, 1] along the last axis.

        Returns: Per-level features concatenated along the last dimension.
        """
        batch_shape = in_tensor.shape[:-1]

        axis_pairs = ([0, 1], [0, 2], [1, 2])
        plane_coord = torch.stack([in_tensor[..., pair] for pair in axis_pairs])  # [3, ..., 2]
        # Detach so no gradient flows back into whatever produced the positions.
        plane_coord = plane_coord.view(3, -1, 1, 2).detach()  # [3, N, 1, 2]

        per_level = []
        for grid in self.plane_grids:
            sampled = F.grid_sample(grid, plane_coord, align_corners=True)  # [3, Cp, N, 1]
            first, second, third = sampled
            combined = first * second * third  # [Cp, N, 1]
            combined = combined.view(combined.shape[0], *batch_shape)
            per_level.append(torch.moveaxis(combined, 0, -1))  # [..., Cp]

        return torch.cat(per_level, dim=-1)  # [..., out_dim]



class Tensorfv4Encoding(Encoding):
    """Multi-resolution line (and optionally plane) factorised encoding.

    With ``use_vm`` each level outputs the sum over the three axes of
    ``plane * line``; without it each level outputs the product of the three
    line features. Levels are concatenated along the feature axis.

    Args:
        resolution: ``(Nl, Np, Nv)`` triple; ``Nl`` and ``Np`` are read.
        num_components: ``(Cl, Cp, Cv)`` triple; ``Cl`` and ``Cp`` are read.
        multiscale_res: Multipliers applied to the base resolutions; one level per entry.
        rng: Not used by this class.
        shft: Constant added to the uniform random initialisation.
        use_vm: Whether plane grids are created and multiplied with the lines.
    """

    plane_coef: Float[Tensor, "3 num_components resolution resolution"]
    line_coef: Float[Tensor, "3 num_components resolution 1"]

    def __init__(
        self,
        resolution: Sequence[int] = (128, 128, 64),  # Nl-Np-Nv (can be different for all grids)
        num_components: Sequence[int] = (32, 32, 16),  # Cl-Cp-Cv
        multiscale_res: Sequence[int] = (1, 2, 4),
        rng: float = 0.4,
        shft: float = 0.1,
        use_vm: bool = True,
    ) -> None:
        super().__init__(in_dim=3)

        self.resolution = resolution
        self.num_components = num_components
        self.multiscale_res = multiscale_res
        self.use_vm = use_vm

        if self.use_vm:
            assert num_components[0] == num_components[1], "Can't have Cl != Cp when multiplying features"

        self.line_grids = nn.ParameterList()
        self.plane_grids = nn.ParameterList()

        for scale in multiscale_res:
            line_len = scale * resolution[0]
            self.line_grids.append(nn.Parameter(0.4 * torch.rand((3, num_components[0], line_len, 1)) + shft))
            if not self.use_vm:
                continue
            plane_side = scale * resolution[1]
            plane_init = 0.15 * torch.rand((3, num_components[1], plane_side, plane_side)) + shft
            self.plane_grids.append(nn.Parameter(plane_init))

    def get_out_dim(self) -> int:
        """Return ``Cl`` times the number of resolution levels."""
        return len(self.multiscale_res) * self.num_components[0]

    def forward(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]:
        """Encode positions with every resolution level and concatenate the results.

        Args:
            in_tensor: positions, expected inside [-1, 1] along the last axis.

        Returns: Per-level features concatenated along the last dimension.
        """
        batch_shape = in_tensor.shape[:-1]

        # Lines are sampled along z, y, x (in that order), paired with the xy, xz, yz planes.
        axis_values = torch.stack([in_tensor[..., 2], in_tensor[..., 1], in_tensor[..., 0]])  # [3, ...]
        line_coord = torch.stack([axis_values, axis_values], dim=-1)  # [3, ..., 2]
        line_coord = line_coord.view(3, -1, 1, 2).detach()  # [3, N, 1, 2]

        if self.use_vm:
            plane_coord = torch.stack(
                [in_tensor[..., [0, 1]], in_tensor[..., [0, 2]], in_tensor[..., [1, 2]]]
            )  # [3, ..., 2]
            plane_coord = plane_coord.view(3, -1, 1, 2).detach()  # [3, N, 1, 2]

        per_level = []
        for level, line_grid in enumerate(self.line_grids):
            line_features = F.grid_sample(line_grid, line_coord, align_corners=True)  # [3, Cl, N, 1]

            if self.use_vm:
                plane_features = F.grid_sample(
                    self.plane_grids[level], plane_coord, align_corners=True
                )  # [3, Cp, N, 1]
                level_features = torch.sum(plane_features * line_features, dim=0)  # [C, N, 1]
            else:
                level_features = line_features[0] * line_features[1] * line_features[2]  # [C, N, 1]

            level_features = level_features.view(level_features.shape[0], *batch_shape)
            per_level.append(torch.moveaxis(level_features, 0, -1))  # [..., C]

        return torch.cat(per_level, dim=-1)  # [..., out_dim]

class KPlanesEncoding(Encoding):
    """Learned K-Planes encoding

    One feature plane is learned for every pair of input axes: three planes for
    3D (static) coordinates, six for 4D (dynamic) coordinates. A point is
    encoded by bilinearly sampling each plane at the matching pair of
    coordinates and combining the samples element-wise, either by product or
    by sum.

    Each axis may have its own resolution, since planes are stored separately
    rather than batched into one tensor.

    The output has shape (bs:..., num_components).

    Args:
        resolution: Resolution of the grid. Can be a sequence of 3 or 4 integers.
        num_components: The number of scalar planes to use (ie: output feature size)
        init_a: The lower-bound of the uniform distribution used to initialize the spatial planes
        init_b: The upper-bound of the uniform distribution used to initialize the spatial planes
        reduce: Whether to use the element-wise product of the planes or the sum
    """

    def __init__(
        self,
        resolution: Sequence[int] = (128, 128, 128),
        num_components: int = 64,
        init_a: float = 0.1,
        init_b: float = 0.5,
        reduce: Literal["sum", "product"] = "product",
    ) -> None:
        super().__init__(in_dim=len(resolution))

        self.resolution = resolution
        self.num_components = num_components
        self.reduce = reduce
        if self.in_dim not in {3, 4}:
            raise ValueError(
                f"The dimension of coordinates must be either 3 (static scenes) "
                f"or 4 (dynamic scenes). Found resolution with {self.in_dim} dimensions."
            )
        is_dynamic = self.in_dim == 4

        # Every unordered pair of axes gets its own plane. Planes live in a
        # ParameterList (not one stacked tensor) so axes can differ in size.
        self.coo_combs = list(itertools.combinations(range(self.in_dim), 2))
        self.plane_coefs = nn.ParameterList()
        for axes in self.coo_combs:
            # Spatial dims are stored in reversed axis order.
            plane_shape = [self.num_components, *(self.resolution[axis] for axis in reversed(axes))]
            plane = nn.Parameter(torch.empty(plane_shape))
            if is_dynamic and 3 in axes:
                # Planes involving the time axis start at one.
                nn.init.ones_(plane)
            else:
                nn.init.uniform_(plane, a=init_a, b=init_b)
            self.plane_coefs.append(plane)

    def get_out_dim(self) -> int:
        """Return the number of output features."""
        return self.num_components

    def forward(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]:
        """Sample features from this encoder. Expects ``in_tensor`` to be in range [-1, 1]"""
        leading_shape = in_tensor.shape[:-1]

        assert any(self.coo_combs)
        multiply = self.reduce == "product"
        # Start from the identity element of the chosen reduction.
        output = 1.0 if multiply else 0.0
        for plane, axes in zip(self.plane_coefs, self.coo_combs):
            grid = plane.unsqueeze(0)  # [1, feature_dim, reso1, reso2]
            coords = in_tensor[..., axes].view(1, 1, -1, 2)  # [1, 1, flattened_bs, 2]
            sampled = F.grid_sample(
                grid, coords, align_corners=True, padding_mode="border"
            )  # [1, output_dim, 1, flattened_bs]
            sampled = sampled.view(self.num_components, -1).T  # [flattened_bs, output_dim]
            output = output * sampled if multiply else output + sampled

        # Typing: output becomes a tensor after the first loop iteration
        assert isinstance(output, Tensor)
        return output.reshape(*leading_shape, self.num_components)


class SHEncoding(Encoding):
    """Spherical harmonic encoding

    Args:
        levels: Number of spherical harmonic levels to encode (1 to 4).
        implementation: ``"tcnn"`` to use the tiny-cuda-nn encoder when it is
            available, ``"torch"`` for the PyTorch implementation.
    """

    def __init__(self, levels: int = 4, implementation: Literal["tcnn", "torch"] = "torch") -> None:
        super().__init__(in_dim=3)

        if levels > 4 or levels <= 0:
            raise ValueError(f"Spherical harmonic encoding only supports 1 to 4 levels, requested {levels}")

        self.levels = levels

        # Stays None unless tcnn was requested and is available.
        self.tcnn_encoding = None
        if implementation == "tcnn":
            if TCNN_EXISTS:
                self.tcnn_encoding = tcnn.Encoding(
                    n_input_dims=3,
                    encoding_config=self.get_tcnn_encoding_config(levels=self.levels),
                )
            else:
                print_tcnn_speed_warning("SHEncoding")

    @classmethod
    def get_tcnn_encoding_config(cls, levels) -> dict:
        """Get the encoding configuration for tcnn if implemented"""
        return {
            "otype": "SphericalHarmonics",
            "degree": levels,
        }

    def get_out_dim(self) -> int:
        """Return the number of spherical harmonic components, ``levels ** 2``."""
        return self.levels * self.levels

    @torch.no_grad()
    def pytorch_fwd(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]:
        """Forward pass using pytorch. Significantly slower than TCNN implementation."""
        return components_from_spherical_harmonics(levels=self.levels, directions=in_tensor)

    def forward(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]:
        """Encode directions with tcnn when it was set up, otherwise with PyTorch."""
        if self.tcnn_encoding is None:
            return self.pytorch_fwd(in_tensor)
        return self.tcnn_encoding(in_tensor)
