# Copyright 2022 The Nerfstudio Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Fields for GA-Planes, which can alternatively use K-Planes
(https://sarafridov.github.io/K-Planes/) or TensoRF feature encodings.
"""

from typing import Dict, Iterable, List, Optional, Tuple, Sequence, Literal

import torch
from rich.console import Console
from torch import nn, Tensor
from torchtyping import TensorType

from nerfstudio.cameras.rays import RaySamples, Frustums
from nerfstudio.data.scene_box import SceneBox
from nerfstudio.field_components.activations import trunc_exp
from nerfstudio.field_components.embedding import Embedding
from nerfstudio.field_components.encodings import KPlanesEncoding, SHEncoding, NeRFEncoding, GaplaneMultiresEncoding, GaplaneEncoding, Tensorfv4Encoding, KplaneV4Encoding
from nerfstudio.field_components.field_heads import FieldHeadNames
from nerfstudio.field_components.spatial_distortions import SpatialDistortion
from nerfstudio.fields.base_field import Field, get_normalized_directions

from nerfstudio.field_components.mlp import MLP

try:
    import tinycudann as tcnn
except ImportError:
    # tinycudann module doesn't exist
    pass

# Module-level rich console with a fixed 120-column width.
CONSOLE: Console = Console(width=120)


class GAPlanesV4Field(Field):
    """GA-Planes multires field.

    Sample positions are normalized to [-1, 1] and encoded once by a grid-based
    feature encoding chosen by ``method``. The shared features are decoded by
    ``sigma_net`` (density) and by ``color_net`` (RGB, conditioned on the
    NeRF-encoded ray direction).

    Args:
        aabb: Parameters of scene aabb bounds
        grid_base_resolution: Base grid resolution
        grid_feature_dim: Dimension of feature vectors stored in grid
        multiscale_res: Multiscale grid resolutions
        spatial_distortion: Spatial distortion to apply to the scene
        reduce: Reduction mode ("concat" or "product") passed to the GA-Planes
            encoding; the other encodings do not receive it.
        method: Feature encoding to build. "gaplane" and "kplane" select the
            matching encodings; any name containing "tensorf" builds the
            TensoRF encoding, with ``use_vm`` enabled when the name contains
            "vm". "volume" appears in the type hint but is not implemented.

    Raises:
        ValueError: If ``method`` does not select one of the encodings above.
    """

    def __init__(
        self,
        aabb: TensorType,
        grid_base_resolution: Sequence[int] = (128, 128, 64),
        grid_feature_dim: Sequence[int] = (32, 32, 16),
        multiscale_res: Sequence[int] = (1, 2, 4),
        spatial_distortion: Optional[SpatialDistortion] = None,
        reduce: Literal["concat", "product"] = "product",
        method: Literal["gaplane", "kplane", "tensorf-vm", "tensorf-cp", "volume"] = "gaplane",
    ) -> None:
        super().__init__()

        self.register_buffer("aabb", aabb)

        self.grid_base_resolution = list(grid_base_resolution)

        self.spatial_distortion = spatial_distortion
        self.reduce = reduce

        if method == "gaplane":
            self.feature_encoding = GaplaneMultiresEncoding(
                resolution=grid_base_resolution,
                num_components=grid_feature_dim,
                multiscale_res=multiscale_res,
                reduce=reduce,
            )
        elif method == "kplane":
            self.feature_encoding = KplaneV4Encoding(
                resolution=grid_base_resolution,
                num_components=grid_feature_dim,
                multiscale_res=multiscale_res,
            )
        elif "tensorf" in method:
            use_vm = "vm" in method
            self.feature_encoding = Tensorfv4Encoding(
                resolution=grid_base_resolution,
                num_components=grid_feature_dim,
                multiscale_res=multiscale_res,
                use_vm=use_vm,
            )
        else:
            # Fail here with a clear message instead of an AttributeError on
            # ``self.feature_encoding`` a few lines below (this covers "volume").
            raise ValueError(
                f"Unsupported method {method!r}; expected 'gaplane', 'kplane', "
                "'tensorf-vm' or 'tensorf-cp'."
            )

        self.feature_dim = self.feature_encoding.get_out_dim()

        self.sigma_net = MLP(
            in_dim=self.feature_dim,
            out_dim=1,
            activation=nn.ReLU(),
            num_layers=2,
            layer_width=128,
        )

        # self.direction_encoding = SHEncoding(levels=4) ## also try nerf-encoding
        self.direction_encoding = NeRFEncoding(
            in_dim=3, num_frequencies=4, min_freq_exp=0.0, max_freq_exp=4.0, include_input=True
        )

        in_dim_color = self.direction_encoding.get_out_dim() + self.feature_dim

        self.color_net = MLP(
            in_dim=in_dim_color,
            out_dim=3,
            activation=nn.ReLU(),
            out_activation=nn.Sigmoid(),
            num_layers=4,
            layer_width=128,
        )

    def _normalize_positions(self, positions: Tensor) -> Tensor:
        """Map world-space sample positions into the [-1, 1] cube used by the grids."""
        if self.spatial_distortion is not None:
            # The distortion output spans [-2, 2]; halve it to get [-1, 1].
            return self.spatial_distortion(positions) / 2
        # SceneBox normalization yields [0, 1]; rescale to [-1, 1].
        return SceneBox.get_normalized_positions(positions, self.aabb) * 2.0 - 1.0

    def _encode(self, ray_samples: RaySamples) -> Tuple[Tensor, Tensor]:
        """Return the normalized sample positions and their grid features.

        The features are the single source for both the density and the color
        heads. This keeps the two heads reading identical grid locations.
        """
        positions = self._normalize_positions(ray_samples.frustums.get_positions())
        features = self.feature_encoding(positions)

        if len(features) < 1:
            # Empty batch: substitute an empty leaf tensor whose last dimension is
            # the MLP input width (``feature_dim``) so ``sigma_net``/``color_net``
            # accept it. requires_grad=True keeps the outputs attached to autograd.
            features = features.new_zeros(
                (*features.shape[:-1], self.feature_dim), requires_grad=True
            )
        return positions, features

    def _density_from_features(self, features: Tensor, positions: Tensor) -> Tensor:
        """Decode densities from grid features (trailing dim 1, since sigma_net has out_dim=1)."""
        density_before_activation = self.sigma_net(features)
        # Cast to the positions' dtype/device, then shift the pre-activation by -1
        # before the exponential activation (scales the density by e^-1).
        return trunc_exp(density_before_activation.to(positions) - 1)

    def _rgb_from_features(self, features: Tensor, directions: Tensor) -> Tensor:
        """Decode RGB from grid features concatenated with NeRF-encoded directions."""
        d_encoded = self.direction_encoding(directions)
        return self.color_net(torch.cat([features, d_encoded], dim=-1))

    def _density_and_rgb(self, ray_samples: RaySamples) -> Tuple[Tensor, Tensor]:
        """Compute density and RGB from a single pass of the feature encoding."""
        positions, features = self._encode(ray_samples)
        density = self._density_from_features(features, positions)
        rgb = self._rgb_from_features(features, ray_samples.frustums.directions)
        return density, rgb

    def get_density(self, ray_samples: RaySamples) -> TensorType:
        """Computes and returns the densities for ``ray_samples``."""
        positions, features = self._encode(ray_samples)
        return self._density_from_features(features, positions)

    def get_outputs(self, ray_samples: RaySamples) -> Tensor:
        """Computes view-dependent RGB for ``ray_samples``.

        Positions are normalized exactly as in ``get_density`` (including the
        spatial distortion, when configured), so color and density agree.
        """
        _, features = self._encode(ray_samples)
        return self._rgb_from_features(features, ray_samples.frustums.directions)

    def forward(
        self,
        ray_samples: RaySamples,
        compute_normals: bool = False,
        mask: Optional[Tensor] = None,
        bg_color: Optional[Tensor] = None,
    ) -> Dict[FieldHeadNames, Tensor]:
        """Evaluate density and RGB for ``ray_samples``.

        When both ``mask`` and ``bg_color`` are given, only masked samples are
        evaluated. Unmasked samples get zero density and ``bg_color``.

        Raises:
            ValueError: If ``compute_normals`` is requested (not supported).
        """
        if compute_normals:
            raise ValueError("Surface normals are not currently supported with GA-Planes")
        if mask is not None and bg_color is not None:
            # Allocate directly on the mask's device instead of CPU + copy.
            base_density = torch.zeros((*ray_samples.shape, 1), device=mask.device)
            base_rgb = bg_color.repeat(*ray_samples.shape, 1)
            if mask.any():
                density, rgb = self._density_and_rgb(ray_samples[mask, :])

                base_density[mask] = density
                base_rgb[mask] = rgb

                base_density.requires_grad_()
                base_rgb.requires_grad_()

            density = base_density
            rgb = base_rgb
        else:
            density, rgb = self._density_and_rgb(ray_samples)

        return {FieldHeadNames.DENSITY: density, FieldHeadNames.RGB: rgb}



   