# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GAplanes Field with shared features"""

from typing import Dict, Optional, Tuple

import ipdb
import torch
from torch import Tensor, nn
from torch.nn.parameter import Parameter
from torchtyping import TensorType

from nerfstudio.cameras.rays import RaySamples
from nerfstudio.data.scene_box import SceneBox
from nerfstudio.field_components.activations import trunc_exp
from nerfstudio.field_components.encodings import Encoding, Identity, SHEncoding
from nerfstudio.field_components.field_heads import FieldHeadNames, RGBFieldHead
from nerfstudio.field_components.mlp import MLP
from nerfstudio.fields.base_field import Field, get_normalized_directions

class GAplanesV3Field(Field):
    """GAplanes Field (shared features).

    A single positional encoding feeds one MLP that emits both the
    pre-activation density and a feature vector; that feature vector is then
    concatenated with the encoded view direction and decoded to RGB by a
    second MLP.

    Args:
        aabb: the aabb bounding box of the dataset.
        direction_encoding: the encoding method used for ray direction.
        density_rgb_encoding: the tensor encoding method used for scene
            density and color.
        geo_feat_dim: number of feature channels the density MLP hands to the
            color MLP. Defaults to 15, the value previously hard-coded.

    Raises:
        ValueError: if ``geo_feat_dim`` is not a positive integer.
    """

    def __init__(
        self,
        aabb: Tensor,
        direction_encoding: Encoding = Identity(in_dim=3),
        density_rgb_encoding: Encoding = Identity(in_dim=3),
        geo_feat_dim: int = 15,
    ) -> None:
        if geo_feat_dim <= 0:
            raise ValueError(f"geo_feat_dim must be positive, got {geo_feat_dim}")

        super().__init__()
        # Stored as a frozen Parameter so it travels with the module
        # (device moves, state_dict) without being optimized.
        self.aabb = Parameter(aabb, requires_grad=False)
        self.direction_encoding = direction_encoding  # e.g. a Fourier encoding (like k-planes)
        self.density_rgb_encoding = density_rgb_encoding  # produces the shared density/color features
        # Single source of truth for the width shared by both MLPs and the
        # split in get_density.
        self.geo_feat_dim = geo_feat_dim

        # Emits `geo_feat_dim` shared features plus one pre-activation density
        # channel. NOTE: a separate density-only network is an alternative
        # design that has not been adopted here.
        self.density_mlp = MLP(
            in_dim=self.density_rgb_encoding.get_out_dim(),
            out_dim=self.geo_feat_dim + 1,
            activation=nn.ReLU(),
            num_layers=2,
            layer_width=128,
        )

        # Decodes [encoded direction, shared features] to RGB; the sigmoid
        # output activation replaces a separate RGB field head.
        self.color_mlp = MLP(
            in_dim=self.geo_feat_dim + self.direction_encoding.get_out_dim(),
            num_layers=3,
            out_dim=3,
            layer_width=128,
            activation=nn.ReLU(),
            out_activation=nn.Sigmoid(),
        )

    def get_density(self, ray_samples: RaySamples) -> Tuple[Tensor, Tensor]:
        """Compute density and the shared feature vector for each sample.

        Args:
            ray_samples: samples whose frustum positions are evaluated.

        Returns:
            A ``(density, features)`` tuple: ``density`` has one channel in
            its last dimension and ``features`` has ``geo_feat_dim`` channels.
        """
        positions = SceneBox.get_normalized_positions(ray_samples.frustums.get_positions(), self.aabb)
        # Rescale the normalized coordinates (presumably [0, 1] inside the
        # aabb -- confirm against SceneBox) to a range centered on zero.
        positions = positions * 2 - 1

        mlp_out = self.density_mlp(self.density_rgb_encoding(positions))
        features, density_before_activation = torch.split(mlp_out, [self.geo_feat_dim, 1], dim=-1)

        # Match the dtype/device of `positions`, then apply the truncated
        # exponential with a -1 shift, as in k-planes.
        density = trunc_exp(density_before_activation.to(positions) - 1)

        return density, features

    def get_outputs(
        self, ray_samples: RaySamples, density_embedding: Optional[TensorType] = None
    ) -> Dict[FieldHeadNames, TensorType]:
        """Compute view-dependent color for each sample.

        Args:
            ray_samples: samples whose frustum directions are encoded.
            density_embedding: shared features as returned by ``get_density``.
                When omitted they are recomputed from ``ray_samples``.

        Returns:
            A dict mapping ``FieldHeadNames.RGB`` to the predicted color.
        """
        if density_embedding is None:
            # The signature allows None, but torch.cat cannot take it;
            # derive the features instead of failing with an opaque error.
            _, density_embedding = self.get_density(ray_samples)

        directions = get_normalized_directions(ray_samples.frustums.directions)
        d_encoded = self.direction_encoding(directions)
        color_features = torch.cat([d_encoded, density_embedding], dim=-1)

        rgb = self.color_mlp(color_features)

        return {FieldHeadNames.RGB: rgb}

    # def forward(
    #     self,
    #     ray_samples: RaySamples,

    #     mask: Optional[Tensor] = None,
    #     bg_color: Optional[Tensor] = None,
    # ) -> Dict[FieldHeadNames, Tensor]:

    #     if mask is not None and bg_color is not None:
    #         base_density = torch.zeros(ray_samples.shape)[:, :, None].to(mask.device)
    #         base_rgb = bg_color.repeat(ray_samples[:, :, None].shape)
    #         if mask.any():
    #             # print("here")
    #             input_rays = ray_samples[mask, :]
    #             density = self.get_density(input_rays)
    #             rgb = self.get_outputs(input_rays)
    #             # ipdb.set_trace()
    #             base_density[mask] = density
    #             base_rgb[mask] = rgb.to(base_rgb)

    #             base_density.requires_grad_()
    #             base_rgb.requires_grad_()

    #         density = base_density
    #         rgb = base_rgb
    #     else:
    #         density = self.get_density(ray_samples)
    #         rgb = self.get_outputs(ray_samples)

    #     return {FieldHeadNames.DENSITY: density, FieldHeadNames.RGB: rgb}
