# Copyright 2022 The Nerfstudio Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Fields for K-Planes (https://sarafridov.github.io/K-Planes/).
"""

from typing import Dict, Iterable, List, Optional, Tuple, Sequence, Literal

import torch
from rich.console import Console
from torch import nn
from torchtyping import TensorType

from nerfstudio.cameras.rays import RaySamples, Frustums
from nerfstudio.data.scene_box import SceneBox
from nerfstudio.field_components.activations import trunc_exp
from nerfstudio.field_components.embedding import Embedding
from nerfstudio.field_components.encodings import KPlanesEncoding, SHEncoding, NeRFEncoding, GaplaneMultiresEncoding, GaplaneEncoding
from nerfstudio.field_components.field_heads import FieldHeadNames
from nerfstudio.field_components.spatial_distortions import SpatialDistortion
from nerfstudio.fields.base_field import Field, get_normalized_directions

from nerfstudio.field_components.mlp import MLP

# Optional dependency: tinycudann is imported when it is installed and silently
# skipped otherwise. NOTE(review): nothing in the visible part of this file
# references `tcnn` -- confirm whether the import is still needed.
try:
    import tinycudann as tcnn
except ImportError:
    # tinycudann module doesn't exist
    pass

# Module-wide rich console (120 columns wide) used for log messages.
CONSOLE = Console(width=120)


class GAPlanesV2Field(Field):
    """GA-Planes field w/ prop sampling.

    Positions are encoded with a ``GaplaneMultiresEncoding``; a small MLP turns the
    encoded features into one density value plus ``geo_feat_dim`` geometry features,
    and a second MLP maps those geometry features together with NeRF-encoded view
    directions to RGB.

    Args:
        aabb: Parameters of scene aabb bounds
        geo_feat_dim: Dimension of 'geometry' features. Controls output dimension of sigma network
        concat_across_scales: Whether to concatenate features at different scales. Only stored on
            the instance in this class; it is not forwarded to the encoding.
        grid_base_resolution: Base grid resolution
        grid_feature_dim: Dimension of feature vectors stored in grid
        multiscale_res: Multiscale grid resolutions
        spatial_distortion: Spatial distortion to apply to the scene
        reduce: Reduction mode ("concat" or "product") forwarded to the feature encoding

    """

    def __init__(
        self,
        aabb: TensorType,
        geo_feat_dim: int = 15,  # TODO: This should be removed
        concat_across_scales: bool = True,  # TODO: Maybe this should be removed
        grid_base_resolution: Sequence[int] = (128, 128, 64),
        grid_feature_dim: Sequence[int] = (32, 32, 16),
        multiscale_res: Sequence[int] = (1, 2, 4),
        spatial_distortion: Optional[SpatialDistortion] = None,
        reduce: Literal["concat", "product"] = "product",
    ) -> None:
        super().__init__()

        self.register_buffer("aabb", aabb)
        self.geo_feat_dim = geo_feat_dim
        self.grid_base_resolution = list(grid_base_resolution)
        self.concat_across_scales = concat_across_scales
        self.spatial_distortion = spatial_distortion
        self.reduce = reduce

        self.feature_encoding = GaplaneMultiresEncoding(
            resolution=grid_base_resolution,
            num_components=grid_feature_dim,
            multiscale_res=multiscale_res,
            reduce=reduce,
        )
        self.feature_dim = self.feature_encoding.get_out_dim()
        # Report through the module console (as GAPlanesDensityField does) rather than
        # a bare print() to stdout.
        CONSOLE.log(f"Initialized GAPlanesV2Field with feature_dim={self.feature_dim}")

        # Density head: the output is split into geo_feat_dim geometry features plus
        # one pre-activation density channel (see get_density).
        self.sigma_net = MLP(
            in_dim=self.feature_dim,
            out_dim=self.geo_feat_dim + 1,
            activation=nn.ReLU(),
            num_layers=2,
            layer_width=64,
        )

        # NOTE: SHEncoding(levels=4) is an alternative direction encoding worth trying.
        self.direction_encoding = NeRFEncoding(
            in_dim=3, num_frequencies=4, min_freq_exp=0.0, max_freq_exp=4.0, include_input=True
        )

        # The color head consumes the encoded view direction concatenated with the
        # geometry features produced by sigma_net.
        in_dim_color = self.direction_encoding.get_out_dim() + self.geo_feat_dim

        self.color_net = MLP(
            in_dim=in_dim_color,
            out_dim=3,
            activation=nn.ReLU(),
            out_activation=nn.Sigmoid(),
            num_layers=3,
            layer_width=64,
        )

    def get_density(self, ray_samples: RaySamples) -> Tuple[TensorType, TensorType]:
        """Computes and returns the densities.

        Args:
            ray_samples: Samples whose frustum positions are queried.

        Returns:
            A ``(density, features)`` tuple: the activated density and the
            ``geo_feat_dim`` geometry features that ``get_outputs`` consumes as
            ``density_embedding``.
        """
        positions = ray_samples.frustums.get_positions()
        if self.spatial_distortion is not None:
            positions = self.spatial_distortion(positions)
            positions = positions / 2  # from [-2, 2] to [-1, 1]
        else:
            # From [0, 1] to [-1, 1]
            positions = SceneBox.get_normalized_positions(positions, self.aabb) * 2.0 - 1.0

        features = self.feature_encoding(positions)
        if len(features) < 1:
            # Empty batch: hand sigma_net an empty tensor whose last dimension matches
            # its input width. (A width of 1 only works when feature_dim happens to be 1.)
            features = torch.zeros((0, self.feature_dim), device=features.device, requires_grad=True)

        features = self.sigma_net(features)
        # NOTE: separate networks for geometry features and density are worth trying too.
        features, density_before_activation = torch.split(features, [self.geo_feat_dim, 1], dim=-1)

        # .to(positions) matches the dtype/device of the positions tensor before the
        # shifted truncated-exp activation is applied.
        density = trunc_exp(density_before_activation.to(positions) - 1)
        return density, features

    def get_outputs(
        self, ray_samples: RaySamples, density_embedding: Optional[TensorType] = None
    ) -> Dict[FieldHeadNames, TensorType]:
        """Computes the RGB output for the given samples.

        Args:
            ray_samples: Samples providing the view directions and the output shape.
            density_embedding: Geometry features returned by ``get_density``; required.

        Returns:
            A dict mapping ``FieldHeadNames.RGB`` to colors shaped like the sample
            frustums with a trailing channel dimension.
        """
        assert density_embedding is not None

        output_shape = ray_samples.frustums.shape
        directions = ray_samples.frustums.directions.reshape(-1, 3)

        directions = get_normalized_directions(directions)
        d = self.direction_encoding(directions)
        # Flatten the geometry features so they line up row-for-row with the directions.
        color_features = torch.cat([d, density_embedding.view(-1, self.geo_feat_dim)], dim=-1)
        rgb = self.color_net(color_features).view(*output_shape, -1)

        return {FieldHeadNames.RGB: rgb}


class GAPlanesDensityField(Field):
    """A lightweight density field module.

    A single ``GaplaneEncoding`` feeds a two-layer MLP that outputs one density
    channel; no color or other outputs are produced.

    Args:
        aabb: Parameters of scene aabb bounds
        resolution: Grid resolution
        grid_feature_dim: Dimension of grid feature vectors
        spatial_distortion: Spatial distortion to apply to the scene
        reduce: Reduction mode ("concat" or "product") forwarded to the feature encoding

    """

    def __init__(
        self,
        aabb: TensorType,
        resolution: List[int], # Nl-Np-Nv
        grid_feature_dim: List[int], # Cl-Cp-Cv
        spatial_distortion: Optional[SpatialDistortion] = None,
        reduce: Literal["concat", "product"] = "product",
    ):
        super().__init__()

        self.register_buffer("aabb", aabb)
        self.spatial_distortion = spatial_distortion

        encoding = GaplaneEncoding(resolution=resolution, num_components=grid_feature_dim, reduce=reduce)
        self.feature_encoding = encoding
        self.feature_dim = encoding.get_out_dim()

        # Single-channel head: the only output is the pre-activation density.
        self.sigma_net = MLP(
            in_dim=self.feature_dim,
            num_layers=2,
            layer_width=64,
            out_dim=1,
            activation=nn.ReLU(),
        )

        CONSOLE.log(f"Initialized GAPlaneDensityField. with resolution={resolution}")

    # pylint: disable=arguments-differ
    def density_fn(self, positions: TensorType["bs":..., 3]) -> TensorType["bs":..., 1]:
        """Returns only the density at the given positions.

        The positions are wrapped in a ``RaySamples`` whose frustums have zero
        ``starts``/``ends`` and are then passed to ``get_density``.
        NOTE(review): presumably the zero-length frustums make ``get_positions()``
        return the origins unchanged -- confirm against ``Frustums``.

        Args:
            positions: the origin of the samples/frustums
        """
        last_channel = positions[..., :1]
        frustums = Frustums(
            origins=positions,
            directions=torch.ones_like(positions),
            starts=torch.zeros_like(last_channel),
            ends=torch.zeros_like(last_channel),
            pixel_area=torch.ones_like(last_channel),
        )
        return self.get_density(RaySamples(frustums=frustums))[0]

    def get_density(self, ray_samples: RaySamples) -> Tuple[TensorType, None]:
        """Computes and returns the densities.

        Returns:
            A ``(density, None)`` tuple; this field produces no feature embedding.
        """
        positions = ray_samples.frustums.get_positions()
        if self.spatial_distortion is None:
            # From [0, 1] to [-1, 1]
            positions = SceneBox.get_normalized_positions(positions, self.aabb) * 2.0 - 1.0
        else:
            # from [-2, 2] to [-1, 1]
            positions = self.spatial_distortion(positions) / 2

        raw_density = self.sigma_net(self.feature_encoding(positions))
        # Match the dtype/device of the positions, shift by -1, then apply the activation.
        return trunc_exp(raw_density.to(positions) - 1), None

    def get_outputs(self, ray_samples: RaySamples, density_embedding: Optional[TensorType] = None) -> dict:
        """Returns an empty dict: this field exposes density only."""
        return {}