"""
Field for compound nerf model, adds scene contraction and image embeddings to instant ngp
"""

from typing import Dict, Literal, Optional, Tuple

import torch
from torch import Tensor, nn

from nerfstudio.cameras.rays import RaySamples
from nerfstudio.data.scene_box import SceneBox
from nerfstudio.field_components.activations import trunc_exp
from nerfstudio.field_components.embedding import Embedding
from nerfstudio.field_components.field_heads import (
    FieldHeadNames,
    PredNormalsFieldHead,
    SemanticFieldHead,
    TransientDensityFieldHead,
    TransientRGBFieldHead,
    UncertaintyFieldHead,
)
from nerfstudio.field_components.spatial_distortions import SpatialDistortion
from nerfstudio.fields.base_field import Field, get_normalized_directions
from nerfstudio.field_components.encodings import SHEncoding

from eks.field.mlp import MLP, MLPWithHashEncoding
from eks.knnx.knn_algorithms import BaseKNN


class EksField(Field):
    """Radiance field with a KNN-driven density MLP and a view-dependent color head.

    Density and a geometry feature vector come from ``MLPWithHashEncoding`` (built around
    ``knn_algorithm`` and optional ``seed_points``). RGB comes from an MLP applied to a
    spherical-harmonics encoding of the view direction concatenated with those features.

    Args:
        aabb: parameters of scene aabb bounds
        num_images: number of images in the dataset (stored as ``self.num_images``; not
            otherwise used by this class)
        knn_algorithm: KNN algorithm to use for nearest neighbor search
        num_layers: number of hidden layers of the density MLP
        hidden_dim: dimension of hidden layers of the density MLP
        geo_feat_dim: output geo feat dimensions
        num_layers_color: number of hidden layers for color network
        n_features_per_gauss: number of features per Gaussian in the encoding
        hidden_dim_color: dimension of hidden layers for color network
        spatial_distortion: spatial distortion to apply to the scene
        implementation: backend for the SH direction encoding and the color MLP ("tcnn" or "torch")
        seed_points: seed points for the encoding
        densify: forwarded unchanged to ``MLPWithHashEncoding``
        prune: forwarded unchanged to ``MLPWithHashEncoding``
        unfreeze_means: forwarded unchanged to ``MLPWithHashEncoding``
    """

    aabb: Tensor

    def __init__(
        self,
        aabb: Tensor,
        num_images: int,
        knn_algorithm: BaseKNN,
        num_layers: int = 2,
        hidden_dim: int = 64,
        geo_feat_dim: int = 15,
        num_layers_color: int = 3,
        n_features_per_gauss: int = 32,
        hidden_dim_color: int = 64,
        spatial_distortion: Optional[SpatialDistortion] = None,
        implementation: Literal["tcnn", "torch"] = "tcnn",
        seed_points: Optional[Tensor] = None,
        densify: bool = True,
        prune: bool = True,
        unfreeze_means: bool = False,
    ) -> None:
        super().__init__()

        # Buffers move with the module across devices and are saved in its state_dict.
        self.register_buffer("aabb", aabb)
        self.register_buffer("n_features_per_gauss", torch.tensor(n_features_per_gauss))

        self.geo_feat_dim = geo_feat_dim
        self.spatial_distortion = spatial_distortion
        self.num_images = num_images

        # NOTE(review): not read anywhere in this class; presumably updated/used by callers.
        self.step = 0

        self.direction_encoding = SHEncoding(
            levels=4,
            implementation=implementation,
        )

        # Outputs 1 density logit followed by geo_feat_dim geometry features.
        self.mlp_base = MLPWithHashEncoding(
            knn_algorithm=knn_algorithm,
            n_features_per_gauss=n_features_per_gauss,
            num_layers=num_layers,
            layer_width=hidden_dim,
            out_dim=1 + self.geo_feat_dim,
            activation=nn.ReLU(),
            out_activation=None,
            seed_points=seed_points,
            densify=densify,
            prune=prune,
            unfreeze_means=unfreeze_means,
            spatial_distortion=self.spatial_distortion,
        )

        # Color head: (SH-encoded direction ++ geometry features) -> RGB in (0, 1).
        self.mlp_head = MLP(
            in_dim=self.direction_encoding.get_out_dim() + self.geo_feat_dim,
            num_layers=num_layers_color,
            layer_width=hidden_dim_color,
            out_dim=3,
            activation=nn.ReLU(),
            out_activation=nn.Sigmoid(),
            implementation=implementation,
        )

    @staticmethod
    def _inside_unit_cube(positions: Tensor) -> Tensor:
        """Return a boolean mask (shape ``positions.shape[:-1]``) of points strictly inside (0, 1) on every coordinate."""
        return ((positions > 0.0) & (positions < 1.0)).all(dim=-1)

    def get_sampling_positions(self, ray_samples: RaySamples) -> Tuple[Tensor, Tensor]:
        """Compute the normalized sample positions and the positions fed to the density MLP.

        Returns:
            positions: sample positions mapped to [0, 1]; any point not strictly inside the
                unit cube is replaced by zeros.
            uncontracted_positions: with a spatial distortion, the raw sample positions from
                ``ray_samples.frustums.get_positions()``; without one, the aabb-normalized
                positions before masking.
        """
        sample_points = ray_samples.frustums.get_positions()
        if self.spatial_distortion is not None:
            uncontracted_positions = sample_points
            # Remap the distorted coordinates from [-2, 2] to [0, 1] (assumes the distortion's
            # output range is [-2, 2], as with nerfstudio's SceneContraction).
            positions = (self.spatial_distortion(uncontracted_positions) + 2.0) / 4.0
        else:
            positions = SceneBox.get_normalized_positions(sample_points, self.aabb)
            uncontracted_positions = positions

        # Make sure the tcnn gets inputs between 0 and 1.
        positions = positions * self._inside_unit_cube(positions)[..., None]

        return positions, uncontracted_positions

    def get_density(self, ray_samples: RaySamples, enable_normals: bool = False) -> Tuple[Tensor, Tensor]:
        """Compute the densities and the geometry features for each sample.

        Args:
            ray_samples: samples to evaluate.
            enable_normals: when True, the positions fed to ``mlp_base`` are marked as
                requiring grad so that ``get_normals()`` can differentiate the density
                with respect to them. This only has an effect when grad mode is enabled.

        Returns:
            density: ``trunc_exp(logit - 1)``, zeroed for samples outside the unit cube;
                shape ``(*ray_samples.frustums.shape, 1)``.
            base_mlp_out: geometry features, shape ``(*ray_samples.frustums.shape, geo_feat_dim)``.
        """
        positions, uncontracted_positions = self.get_sampling_positions(ray_samples)
        assert positions.numel() > 0, "positions is empty."

        # Field.get_normals() differentiates self._density_before_activation with respect to
        # self._sample_locations. The recorded locations must therefore be the very tensor
        # fed to mlp_base. The masked `positions` never enters the MLP, so a gradient taken
        # w.r.t. it would be undefined.
        if enable_normals and not uncontracted_positions.requires_grad:
            uncontracted_positions.requires_grad_(True)
        self._sample_locations = uncontracted_positions
        uncontracted_positions_flat = uncontracted_positions.view(-1, 3)

        assert uncontracted_positions_flat.numel() > 0, "uncontracted_positions_flat is empty."
        h = self.mlp_base(uncontracted_positions_flat).view(*ray_samples.frustums.shape, -1)
        density_before_activation, base_mlp_out = torch.split(h, [1, self.geo_feat_dim], dim=-1)
        self._density_before_activation = density_before_activation

        # Rectifying the density with an exponential is much more stable than a ReLU or
        # softplus, because it enables high post-activation (float32) density outputs
        # from smaller internal (float16) parameters. The "- 1" shifts a zero logit to
        # exp(-1); NOTE(review): presumably an intentional initial-density bias.
        density = trunc_exp(density_before_activation.to(positions) - 1)
        # Samples outside the unit cube contribute no density.
        density = density * self._inside_unit_cube(positions)[..., None]
        return density, base_mlp_out

    def get_outputs(
        self,
        ray_samples: RaySamples,
        density_embedding: Optional[Tensor] = None,
        direction_transform: Optional[torch.Tensor] = None,
    ) -> Dict[FieldHeadNames, Tensor]:
        """Compute the view-dependent RGB for each sample.

        Args:
            ray_samples: samples whose ``frustums.directions`` condition the color head.
            density_embedding: geometry features returned by ``get_density``; required.
            direction_transform: optional matrices applied to each direction with
                ``einsum("...ij,...j->...i")`` before encoding. These are presumably rotation
                matrices broadcastable against the directions (TODO confirm).

        Returns:
            ``{FieldHeadNames.RGB: rgb}``, with rgb of shape ``(*directions.shape[:-1], 3)``.

        Raises:
            AttributeError: if ``ray_samples.camera_indices`` is None.
        """
        assert density_embedding is not None
        if ray_samples.camera_indices is None:
            # NOTE(review): camera indices are not used below; check kept for compatibility.
            raise AttributeError("Camera indices are not provided.")

        directions = ray_samples.frustums.directions
        if direction_transform is not None:
            directions = torch.einsum("...ij,...j->...i", direction_transform, directions)

        # get_normalized_directions remaps direction components into [0, 1] for the encoder.
        directions = get_normalized_directions(directions)
        d = self.direction_encoding(directions.view(-1, 3))

        outputs_shape = ray_samples.frustums.directions.shape[:-1]

        h = torch.cat([d, density_embedding.view(-1, self.geo_feat_dim)], dim=-1)
        rgb = self.mlp_head(h).view(*outputs_shape, -1).to(directions)
        return {FieldHeadNames.RGB: rgb}

    def forward(
        self,
        ray_samples: RaySamples,
        direction_transform: Optional[torch.Tensor] = None,
        compute_normals: bool = False,
    ) -> Dict[FieldHeadNames, Tensor]:
        """Evaluates the field at points along the ray.

        Args:
            ray_samples: Samples to evaluate field on.
            direction_transform: optional per-sample matrices applied to the view directions
                (see ``get_outputs``).
            compute_normals: if True, also return ``FieldHeadNames.NORMALS``. These are
                computed as the gradient of the pre-activation density with respect to the
                positions fed to the density MLP.

        Returns:
            Dict with ``RGB`` and ``DENSITY`` entries, plus ``NORMALS`` when requested.
        """
        if compute_normals:
            # Normals need an autograd graph even when the caller runs under torch.no_grad().
            with torch.enable_grad():
                density, density_embedding = self.get_density(ray_samples, enable_normals=True)
        else:
            density, density_embedding = self.get_density(ray_samples)

        field_outputs = self.get_outputs(
            ray_samples, density_embedding=density_embedding, direction_transform=direction_transform
        )
        field_outputs[FieldHeadNames.DENSITY] = density  # type: ignore

        if compute_normals:
            with torch.enable_grad():
                normals = self.get_normals()
            field_outputs[FieldHeadNames.NORMALS] = normals  # type: ignore
        return field_outputs
