# ORIGINAL LICENSE
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Modified by Jiale Xu
# The modifications are subject to the same license as the original.

import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F
import spaces

from .utils.renderer import generate_planes, project_onto_planes, sample_from_planes


class OSGDecoder(nn.Module):
    """
    Triplane decoder that predicts SDF, RGB and material values from sampled features.
    Using ReLU here instead of Softplus in the original implementation.

    Every head is an MLP whose input is the features of all three planes
    concatenated per sample point (input width ``3 * n_features``), except
    ``net_weight``, which consumes the concatenated features of 8 points:

    - ``net_sdf``: 1 output per point.
    - ``net_rgb``: 3 outputs per point, squashed to (-0.001, 1.001).
    - ``net_material``: 2 outputs per point (metallic, roughness).
    - ``net_deformation``: 3 outputs per point (only if ``use_deformation_weight``).
    - ``net_weight``: 21 outputs per group of 8 points (only if ``use_deformation_weight``).

    Args:
        n_features: Number of channels of each triplane plane.
        hidden_dim: Width of every hidden layer.
        num_layers: Number of Linear layers per head; ``num_layers - 2``
            hidden-to-hidden layers sit between the input and output layers,
            so at least 2 Linear layers are always built.
        activation: Activation class, instantiated with no arguments after every
            non-final Linear layer (e.g. ``nn.ReLU``).
        use_deformation_weight: Whether to build ``net_deformation`` and
            ``net_weight``, which ``get_geometry_prediction`` requires.

    Reference:
    EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112
    """
    def __init__(self, n_features: int,
                 hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU, use_deformation_weight: bool = True):
        super().__init__()

        in_dim = 3 * n_features  # features of the three planes, concatenated per point

        # Construction order matches the historical layout so parameter
        # initialisation (RNG consumption) and state_dict keys are unchanged.
        self.net_sdf = self._build_mlp(in_dim, hidden_dim, 1, num_layers, activation)
        self.net_rgb = self._build_mlp(in_dim, hidden_dim, 3, num_layers, activation)
        self.net_material = self._build_mlp(in_dim, hidden_dim, 2, num_layers, activation)

        if use_deformation_weight:
            self.net_deformation = self._build_mlp(in_dim, hidden_dim, 3, num_layers, activation)
            # Consumes the concatenated features of 8 gathered points
            # (see get_geometry_prediction).
            self.net_weight = self._build_mlp(8 * in_dim, hidden_dim, 21, num_layers, activation)

        # init all bias to zero
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.zeros_(m.bias)

    @staticmethod
    def _build_mlp(in_dim, hidden_dim, out_dim, num_layers, activation):
        """Build Linear/activation stack: in -> hidden (x num_layers-2 extra) -> out."""
        layers = [nn.Linear(in_dim, hidden_dim), activation()]
        for _ in range(num_layers - 2):
            layers += [nn.Linear(hidden_dim, hidden_dim), activation()]
        layers.append(nn.Linear(hidden_dim, out_dim))
        return nn.Sequential(*layers)

    @staticmethod
    def _flatten_plane_features(sampled_features):
        """Reshape (N, n_planes, M, C) features to (N, M, n_planes * C)."""
        _N, n_planes, _M, _C = sampled_features.shape
        return sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes * _C)

    def get_geometry_prediction(self, sampled_features, flexicubes_indices):
        """
        Predict SDF, deformation and per-group weights.

        Args:
            sampled_features: Tensor of shape (N, n_planes, M, C).
            flexicubes_indices: Index tensor of shape (K, 8) into the M points;
                each row selects the 8 points whose features feed ``net_weight``
                (presumably the corners of one FlexiCubes cube -- confirm with caller).

        Returns:
            sdf of shape (N, M, 1), deformation of shape (N, M, 3) and weight of
            shape (N, K, 21) (the raw ``net_weight`` output scaled by 0.1).

        Raises:
            AttributeError: if the decoder was built with ``use_deformation_weight=False``.
        """
        # getattr with a default works even though nn.Module.__getattr__ raises
        # AttributeError for unregistered submodules.
        if getattr(self, 'net_deformation', None) is None or getattr(self, 'net_weight', None) is None:
            raise AttributeError(
                "OSGDecoder.get_geometry_prediction requires the deformation and weight heads; "
                "construct the decoder with use_deformation_weight=True.")

        features = self._flatten_plane_features(sampled_features)

        sdf = self.net_sdf(features)
        deformation = self.net_deformation(features)

        # Gather the features of the 8 points of every row of flexicubes_indices and
        # concatenate them into one vector per row: (N, K, 8 * n_planes * C).
        grid_features = torch.index_select(input=features, index=flexicubes_indices.reshape(-1), dim=1)
        grid_features = grid_features.reshape(
            features.shape[0], flexicubes_indices.shape[0], flexicubes_indices.shape[1] * features.shape[-1])
        # Scaled by 0.1, presumably to keep the initial weights small.
        weight = self.net_weight(grid_features) * 0.1

        return sdf, deformation, weight

    @spaces.GPU
    def get_texture_prediction(self, sampled_features):
        """
        Predict RGB, metallic and roughness.

        Args:
            sampled_features: Tensor of shape (N, n_planes, M, C).

        Returns:
            rgb of shape (N, M, 3) in (-0.001, 1.001), metallic of shape (N, M)
            in (0, 1) and roughness of shape (N, M) in (0.04**2, 1.0).
        """
        features = self._flatten_plane_features(sampled_features)

        rgb = self.net_rgb(features)
        rgb = torch.sigmoid(rgb)*(1 + 2*0.001) - 0.001  # Uses sigmoid clamping from MipNeRF

        materials = torch.sigmoid(self.net_material(features))
        metallic, roughness = materials[..., 0], materials[..., 1]
        # Remap roughness from (0, 1) to (rmin, rmax).
        rmax, rmin = 1.0, 0.04 ** 2
        roughness = roughness * (rmax - rmin) + rmin

        return rgb, metallic, roughness

    def get_sdf_prediction(self, sampled_features):
        """
        Apply the SDF head directly.

        Unlike the other ``get_*`` methods this does NOT flatten the planes:
        the input's last dimension must already be ``3 * n_features``.
        """
        return self.net_sdf(sampled_features)

class TriplaneSynthesizer(nn.Module):
    """
    Samples triplane features at query coordinates and decodes them with an
    ``OSGDecoder`` into geometry (SDF, deformation, weights), texture
    (RGB, metallic, roughness) or SDF-only predictions.

    Within this class only ``rendering_kwargs['box_warp']`` is read; the other
    entries (ray/sampling settings inherited from EG3D) may be consumed elsewhere.

    Args:
        triplane_dim: Channels per triplane plane (the decoder's ``n_features``).
        samples_per_ray: Split with floor division by 2 into ``depth_resolution``
            and ``depth_resolution_importance``.
        use_deformation_weight: Forwarded to ``OSGDecoder``; must be True for
            ``get_geometry_prediction`` to work.

    Reference:
    EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19
    """

    DEFAULT_RENDERING_KWARGS = {
        'ray_start': 'auto',
        'ray_end': 'auto',
        'box_warp': 2.,
        'white_back': True,
        'disparity_space_sampling': False,
        'clamp_mode': 'softplus',
        'sampler_bbox_min': -1.,
        'sampler_bbox_max': 1.,
    }

    def __init__(self, triplane_dim: int, samples_per_ray: int, use_deformation_weight: bool = True):
        super().__init__()

        # attributes
        self.triplane_dim = triplane_dim
        self.rendering_kwargs = {
            **self.DEFAULT_RENDERING_KWARGS,
            'depth_resolution': samples_per_ray // 2,
            'depth_resolution_importance': samples_per_ray // 2,
        }

        # modules
        # plane_axes is a plain attribute rather than a registered buffer, so it is
        # not part of state_dict; it is moved to the planes' device on every call.
        self.plane_axes = generate_planes()
        self.decoder = OSGDecoder(n_features=triplane_dim, use_deformation_weight=use_deformation_weight)

    def _sample_features(self, planes, sample_coordinates):
        """Sample ``planes`` at ``sample_coordinates`` (zero padding, configured box_warp)."""
        plane_axes = self.plane_axes.to(planes.device)
        return sample_from_planes(
            plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=self.rendering_kwargs['box_warp'])

    def get_geometry_prediction(self, planes, sample_coordinates, flexicubes_indices):
        """Return ``(sdf, deformation, weight)`` from ``OSGDecoder.get_geometry_prediction``."""
        sampled_features = self._sample_features(planes, sample_coordinates)
        sdf, deformation, weight = self.decoder.get_geometry_prediction(sampled_features, flexicubes_indices)
        return sdf, deformation, weight

    @spaces.GPU
    def get_texture_prediction(self, planes, sample_coordinates):
        """Return ``(rgb, metallic, roughness)`` from ``OSGDecoder.get_texture_prediction``."""
        sampled_features = self._sample_features(planes, sample_coordinates)
        rgb, metallic, roughness = self.decoder.get_texture_prediction(sampled_features)
        return rgb, metallic, roughness

    def get_sdf_prediction(self, planes, sample_coordinates):
        """
        for eikonal loss
        Args:
            planes: triplane tensor passed to ``sample_from_planes``.
            sample_coordinates: query coordinates passed to ``sample_from_planes``.

        Returns:
            sdf value of shape (N, M, 1)
        """
        sampled_features = self._sample_features(planes, sample_coordinates)

        # (N, n_planes, M, C) -> (N, M, n_planes * C); the decoder's SDF entry point
        # expects pre-flattened features.
        _N, n_planes, _M, _C = sampled_features.shape
        sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C)

        return self.decoder.get_sdf_prediction(sampled_features)
