# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GAplanes Field"""

from typing import Dict, Optional

import torch
from torch import Tensor, nn
from torch.nn.parameter import Parameter

from nerfstudio.cameras.rays import RaySamples
from nerfstudio.data.scene_box import SceneBox
from nerfstudio.field_components.encodings import Encoding, Identity, SHEncoding
from nerfstudio.field_components.field_heads import FieldHeadNames, RGBFieldHead
from nerfstudio.field_components.mlp import MLP
from nerfstudio.fields.base_field import Field
from nerfstudio.field_components.activations import trunc_exp
import ipdb

class GAplanesField(Field):
    """GAplanes radiance field.

    Density is decoded from ``density_encoding`` features by a small MLP followed by ``exp``.
    Color is decoded either from ``color_encoding`` features concatenated with the encoded ray
    direction (color MLP + sigmoid RGB head), or, when ``use_sh`` is set, by contracting the
    color features directly against spherical-harmonic components of the ray direction.

    Args:
        aabb: The axis-aligned bounding box of the dataset, used to normalize sample positions.
        feature_encoding: Encoding for appearance features. Stored but not currently used.
        direction_encoding: Encoding applied to ray directions (non-SH color path only).
        density_encoding: Encoding that produces per-sample density features.
        color_encoding: Encoding that produces per-sample color features.
        appearance_dim: Size of the appearance embedding. Currently unused because no
            appearance projection is applied to the color features.
        head_mlp_num_layers: Number of layers of the color MLP.
        head_mlp_layer_width: Layer width of the color MLP (the density MLP uses a fixed
            hidden width of 128).
        use_sh: Whether to decode color with spherical harmonics instead of the color MLP.
        sh_levels: Number of spherical-harmonic levels used when ``use_sh`` is set.
    """

    def __init__(
        self,
        aabb: Tensor,
        feature_encoding: Encoding = Identity(in_dim=3),
        direction_encoding: Encoding = Identity(in_dim=3),
        density_encoding: Encoding = Identity(in_dim=3),
        color_encoding: Encoding = Identity(in_dim=3),
        appearance_dim: int = 27,
        head_mlp_num_layers: int = 2,
        head_mlp_layer_width: int = 128,
        use_sh: bool = False,
        sh_levels: int = 2,
    ) -> None:
        super().__init__()
        # Non-trainable Parameter so the box is stored in the module state and moves with .to().
        self.aabb = Parameter(aabb, requires_grad=False)
        self.feature_encoding = feature_encoding  # currently unused by this field
        self.direction_encoding = direction_encoding  # encodes ray directions (non-SH path)
        self.density_encoding = density_encoding  # produces density features
        self.color_encoding = color_encoding  # produces color features

        # Density decoder. Its hidden width is fixed at 128 and does not follow
        # head_mlp_layer_width; changing it would change this module's parameter shapes.
        self.density_mlp = nn.Sequential(
            nn.Linear(self.density_encoding.get_out_dim(), 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1),
        )

        # Color decoder over [color features, encoded direction]. The ReLU output activation
        # adds one more nonlinearity in front of field_output_rgb's own layer.
        self.color_mlp = MLP(
            in_dim=self.color_encoding.get_out_dim() + self.direction_encoding.get_out_dim(),
            num_layers=head_mlp_num_layers,
            layer_width=head_mlp_layer_width,
            activation=nn.ReLU(),
            out_activation=nn.ReLU(),
        )

        self.use_sh = use_sh
        if self.use_sh:
            self.sh = SHEncoding(sh_levels)
        # No appearance projection (``B``) is applied to the color features, so
        # ``appearance_dim`` is currently unused.

        # RGB head for the non-SH path; the sigmoid bounds its output to (0, 1).
        self.field_output_rgb = RGBFieldHead(in_dim=self.color_mlp.get_out_dim(), activation=nn.Sigmoid())

    def _normalized_positions(self, ray_samples: RaySamples) -> Tensor:
        """Return sample positions normalized by ``self.aabb`` and remapped with ``x * 2 - 1``.

        ``SceneBox.get_normalized_positions`` is expected to map the aabb to [0, 1], so the
        result lies in [-1, 1] for samples inside the box.
        """
        positions = SceneBox.get_normalized_positions(ray_samples.frustums.get_positions(), self.aabb)
        return positions * 2 - 1

    def get_density(self, ray_samples: RaySamples) -> Tensor:
        """Compute the volume density of each sample.

        Args:
            ray_samples: Samples whose positions are queried.

        Returns:
            Non-negative densities (``exp`` of the density MLP output) with a trailing
            dimension of size 1.
        """
        positions = self._normalized_positions(ray_samples)
        density_features = self.density_encoding(positions)
        density = self.density_mlp(density_features)
        # NOTE(review): plain ``torch.exp`` is unbounded and can overflow for large MLP outputs;
        # ``trunc_exp`` (imported in this module, the K-Planes choice) may be worth evaluating.
        return torch.exp(density)

    def get_outputs(self, ray_samples: RaySamples) -> Tensor:
        """Compute the RGB color of each sample.

        Args:
            ray_samples: Samples whose positions and view directions are queried.

        Returns:
            Per-sample RGB values: ``relu(SH contraction + 0.5)`` when ``use_sh`` is set,
            otherwise the sigmoid output of ``field_output_rgb``.
        """
        directions = ray_samples.frustums.directions
        positions = self._normalized_positions(ray_samples)
        rgb_features = self.color_encoding(positions)

        if self.use_sh:
            # No projection is applied first, so color_encoding must emit exactly
            # 3 * (number of SH components) features per sample for this view to succeed.
            # ``[..., None, :]`` keeps this valid for any number of leading batch dims,
            # including the flattened samples produced by forward()'s mask path.
            sh_mult = self.sh(directions)[..., None, :]  # (*batch, 1, num_components)
            rgb_sh = rgb_features.view(*sh_mult.shape[:-2], 3, sh_mult.shape[-1])  # (*batch, 3, num_components)
            return torch.relu(torch.sum(sh_mult * rgb_sh, dim=-1) + 0.5)

        d_encoded = self.direction_encoding(directions)
        out = self.color_mlp(torch.cat([rgb_features, d_encoded], dim=-1))
        return self.field_output_rgb(out)

    def forward(
        self,
        ray_samples: RaySamples,
        compute_normals: bool = False,
        mask: Optional[Tensor] = None,
        bg_color: Optional[Tensor] = None,
    ) -> Dict[FieldHeadNames, Tensor]:
        """Evaluate density and color for ``ray_samples``.

        When both ``mask`` and ``bg_color`` are given, only the masked samples are evaluated.
        Unmasked samples get zero density and ``bg_color``.

        Args:
            ray_samples: Samples to evaluate.
            compute_normals: Must be False; surface normals are not supported.
            mask: Boolean mask with the same shape as ``ray_samples`` selecting samples to evaluate.
            bg_color: Color assigned to samples outside ``mask``.

        Returns:
            Dict with ``FieldHeadNames.DENSITY`` and ``FieldHeadNames.RGB`` tensors.

        Raises:
            ValueError: If ``compute_normals`` is requested.
        """
        if compute_normals:
            raise ValueError("Surface normals are not currently supported with GAplanes")

        if mask is None or bg_color is None:
            return {
                FieldHeadNames.DENSITY: self.get_density(ray_samples),
                FieldHeadNames.RGB: self.get_outputs(ray_samples),
            }

        # Defaults for samples outside the mask: zero density and the background color.
        base_density = torch.zeros((*ray_samples.shape, 1), device=mask.device)
        base_rgb = bg_color.repeat(*ray_samples.shape, 1)
        if mask.any():
            input_rays = ray_samples[mask, :]
            base_density[mask] = self.get_density(input_rays)
            base_rgb[mask] = self.get_outputs(input_rays)

            # NOTE(review): preserved from the original implementation. Under torch.no_grad this
            # marks fresh leaf tensors as requiring grad; confirm whether it is still needed.
            base_density.requires_grad_()
            base_rgb.requires_grad_()

        return {FieldHeadNames.DENSITY: base_density, FieldHeadNames.RGB: base_rgb}
