# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GAplanes Field with shared features"""

from typing import Dict, Optional

import torch
from torch import Tensor, nn
from torch.nn.parameter import Parameter

from nerfstudio.cameras.rays import RaySamples
from nerfstudio.data.scene_box import SceneBox
from nerfstudio.field_components.encodings import Encoding, Identity, SHEncoding
from nerfstudio.field_components.field_heads import FieldHeadNames, RGBFieldHead
from nerfstudio.field_components.mlp import MLP
from nerfstudio.fields.base_field import Field
from nerfstudio.field_components.activations import trunc_exp
import ipdb

class GAplanesSharedField(Field):
    """GAplanes field whose density and color heads share one spatial encoding.

    Sample positions are normalized with respect to ``aabb`` and passed through a
    single encoding (``density_rgb_encoding``). The resulting features feed both:

    * a two-layer density MLP followed by ``trunc_exp(x - 1)``, and
    * a color MLP (features concatenated with the encoded view direction)
      followed by an RGB head with a sigmoid activation.

    Args:
        aabb: Bounding box of the scene, used to normalize sample positions.
            Stored as a non-trainable parameter.
        direction_encoding: Encoding applied to ray directions. When omitted, a
            fresh ``Identity(in_dim=3)`` is created for this instance.
        density_rgb_encoding: Encoding of normalized positions whose output is
            shared by the density and color heads. When omitted, a fresh
            ``Identity(in_dim=3)`` is created for this instance.
        head_mlp_num_layers: Number of layers of the color MLP.
        head_mlp_layer_width: Layer width of the color MLP.
        density_mlp_layer_width: Hidden width of the density MLP.
    """

    def __init__(
        self,
        aabb: Tensor,
        direction_encoding: Optional[Encoding] = None,
        density_rgb_encoding: Optional[Encoding] = None,
        head_mlp_num_layers: int = 2,
        head_mlp_layer_width: int = 128,
        density_mlp_layer_width: int = 128,
    ) -> None:
        super().__init__()
        # Build default encodings per instance: a module used as a default
        # argument value would be created once and shared by every field.
        if direction_encoding is None:
            direction_encoding = Identity(in_dim=3)
        if density_rgb_encoding is None:
            density_rgb_encoding = Identity(in_dim=3)

        self.aabb = Parameter(aabb, requires_grad=False)
        self.direction_encoding = direction_encoding  # e.g. a Fourier encoding, as in K-Planes
        self.density_rgb_encoding = density_rgb_encoding  # spatial features shared by both heads

        feature_dim = self.density_rgb_encoding.get_out_dim()
        self.density_mlp = nn.Sequential(
            nn.Linear(feature_dim, density_mlp_layer_width),
            nn.ReLU(inplace=True),
            nn.Linear(density_mlp_layer_width, 1),
        )

        self.color_mlp = MLP(
            in_dim=feature_dim + self.direction_encoding.get_out_dim(),
            num_layers=head_mlp_num_layers,
            layer_width=head_mlp_layer_width,
            activation=nn.ReLU(),
            # Output ReLU effectively adds a layer, since RGBFieldHead follows.
            out_activation=nn.ReLU(),
        )

        self.field_output_rgb = RGBFieldHead(in_dim=self.color_mlp.get_out_dim(), activation=nn.Sigmoid())

    def _shared_features(self, ray_samples: RaySamples) -> Tensor:
        """Encode the sample positions into the features shared by both heads."""
        positions = SceneBox.get_normalized_positions(ray_samples.frustums.get_positions(), self.aabb)
        # Affine remap of the normalized positions; presumably [0, 1] -> [-1, 1].
        positions = positions * 2 - 1
        return self.density_rgb_encoding(positions)

    def _density_from_features(self, features: Tensor) -> Tensor:
        """Map shared features to density via the density MLP and a shifted trunc_exp."""
        density = self.density_mlp(features)
        return trunc_exp(density - 1)  # trunc exp as in K-Planes

    def _rgb_from_features(self, features: Tensor, directions: Tensor) -> Tensor:
        """Map shared features plus encoded view directions to RGB."""
        d_encoded = self.direction_encoding(directions)
        out = self.color_mlp(torch.cat([features, d_encoded], dim=-1))
        return self.field_output_rgb(out)

    def get_density(self, ray_samples: RaySamples) -> Tensor:
        """Return the density for ``ray_samples`` (last dimension of size 1)."""
        return self._density_from_features(self._shared_features(ray_samples))

    def get_outputs(self, ray_samples: RaySamples) -> Tensor:
        """Return the RGB color for ``ray_samples``."""
        features = self._shared_features(ray_samples)
        return self._rgb_from_features(features, ray_samples.frustums.directions)

    def forward(
        self,
        ray_samples: RaySamples,
        compute_normals: bool = False,
        mask: Optional[Tensor] = None,
        bg_color: Optional[Tensor] = None,
    ) -> Dict[FieldHeadNames, Tensor]:
        """Evaluate density and RGB, computing the shared encoding once.

        Args:
            ray_samples: Samples to evaluate.
            compute_normals: Unsupported; must be falsy.
            mask: Optional boolean mask over the leading dimension(s) of
                ``ray_samples`` (presumably one entry per ray, as in nerfstudio's
                TensoRF model -- confirm against caller). Only selected samples
                are evaluated; used only when ``bg_color`` is also given.
            bg_color: Color assigned to samples not selected by ``mask``.

        Returns:
            Dict with ``FieldHeadNames.DENSITY`` and ``FieldHeadNames.RGB``.

        Raises:
            ValueError: If ``compute_normals`` is requested.
        """
        if compute_normals:
            raise ValueError("Surface normals are not currently supported with GAplanes")

        if mask is None or bg_color is None:
            features = self._shared_features(ray_samples)
            density = self._density_from_features(features)
            rgb = self._rgb_from_features(features, ray_samples.frustums.directions)
            return {FieldHeadNames.DENSITY: density, FieldHeadNames.RGB: rgb}

        # Masked path: unselected samples keep zero density and the background color.
        batch_shape = (*ray_samples.shape, 1)
        density = torch.zeros(batch_shape, device=mask.device)
        rgb = bg_color.repeat(batch_shape)
        if mask.any():
            selected = ray_samples[mask, :]
            features = self._shared_features(selected)
            density[mask] = self._density_from_features(features)
            rgb[mask] = self._rgb_from_features(features, selected.frustums.directions)

            # Kept from the original implementation: mark outputs as requiring grad.
            density.requires_grad_()
            rgb.requires_grad_()

        return {FieldHeadNames.DENSITY: density, FieldHeadNames.RGB: rgb}
