"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""

from typing import Callable, List, Union

import numpy as np
import torch
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd

import torch.nn as nn
import math

# tinycudann is imported unconditionally at module load. If the import fails,
# the error and an installation hint are printed and the interpreter is
# terminated via exit(), so nothing below this point runs without it.
try:
    import tinycudann as tcnn
except ImportError as e:
    print(
        f"Error: {e}! "
        "Please install tinycudann by: "
        "pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch"
    )
    exit()

def cart2sph(coord):
    """Convert Cartesian coordinates to two scaled spherical angles.

    Args:
        coord: Tensor whose last dimension holds at least the ``(x, y, z)``
            components; any leading dimensions are preserved.

    Returns:
        Tensor of shape ``coord.shape[:-1] + (2,)``:

        * channel 0: azimuth ``atan2(y, x)`` divided by pi, i.e. in [-1, 1];
        * channel 1: elevation ``atan2(z, sqrt(x^2 + y^2))`` mapped through
          ``-4 * elevation / pi - 1``.
    """
    # Keep the trailing axis (slices, not integer indices) so the results can
    # be concatenated along dim=-1 without an extra unsqueeze.
    x = coord[..., 0:1]
    y = coord[..., 1:2]
    z = coord[..., 2:3]

    azimuth = torch.atan2(y, x)
    elevation = torch.atan2(z, torch.sqrt(x * x + y * y))
    # The radius is intentionally not computed: it is not part of the output.
    return torch.cat(
        [azimuth / math.pi, elevation * (-4) / math.pi - 1], dim=-1
    )

class _TruncExp(Function):  # pylint: disable=abstract-method
    """Exponential whose gradient is computed from a clamped exponent.

    The forward pass is a plain ``exp``. The backward pass evaluates
    ``exp(min(x, 15))`` instead of ``exp(x)`` so the gradient cannot grow
    beyond ``exp(15)`` times the incoming gradient.

    Implementation from torch-ngp:
    https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, x):  # pylint: disable=arguments-differ
        # The input is stashed as-is; the clamp is applied only in backward.
        ctx.save_for_backward(x)
        return x.exp()

    @staticmethod
    @custom_bwd
    def backward(ctx, g):  # pylint: disable=arguments-differ
        (saved_input,) = ctx.saved_tensors
        return g * saved_input.clamp(max=15).exp()


# Functional entry point for the activation defined above.
trunc_exp = _TruncExp.apply


def contract_to_unisphere(
    x: torch.Tensor,
    aabb: torch.Tensor,
    ord: Union[str, int] = 2,
    eps: float = 1e-6,
    derivative: bool = False,
):
    """Contract unbounded coordinates into the unit cube ``[0, 1]``.

    The input is first normalised so that ``aabb`` spans ``[-1, 1]`` on every
    axis. Points whose norm is at most 1 are kept as they are; points with a
    larger norm are pulled in to radius ``2 - 1 / norm`` along the same
    direction. The result is finally remapped with ``x / 4 + 0.5``.

    Args:
        x: Coordinates; the last dimension is split against ``aabb``.
        aabb: Box bounds whose last dimension is ``(min_xyz, max_xyz)``; it is
            split into two chunks of size 3.
        ord: Norm order forwarded to ``torch.linalg.norm``.
        eps: Lower clamp applied to the derivative term.
        derivative: When True, return the per-axis derivative term of the
            contraction instead of the contracted coordinates.

    Returns:
        The contracted coordinates, or (``derivative=True``) a tensor shaped
        like the coordinates that is 1.0 wherever the norm is at most 1 and is
        clamped to be at least ``eps`` everywhere.
    """
    lower, upper = aabb.split(3, dim=-1)
    # Normalise so that the box occupies [-1, 1] on each axis.
    unit = (x - lower) / (upper - lower) * 2 - 1
    norm = torch.linalg.norm(unit, ord=ord, dim=-1, keepdim=True)
    outside = norm.squeeze(-1) > 1

    if not derivative:
        # Only points beyond the unit ball are touched; masked assignment keeps
        # the (possibly zero-norm) interior points out of the division.
        unit[outside] = (2 - 1 / norm[outside]) * (unit[outside] / norm[outside])
        # Everything now fits in [0, 1].
        return unit / 4 + 0.5

    deriv = (2 * norm - 1) / norm**2 + 2 * unit**2 * (
        1 / norm**3 - (2 * norm - 1) / norm**4
    )
    # Inside the unit ball the mapping is the identity.
    deriv[~outside] = 1.0
    return torch.clamp(deriv, min=eps)


class NGPRadianceField(torch.nn.Module):
    """Instant-NGP style radiance field with per-plane density features.

    Density features are produced by three 2D hash-grid networks evaluated on
    the ``(0, 1)``, ``(0, 2)`` and ``(1, 2)`` coordinate pairs of the
    normalised position. Their outputs are multiplied element-wise and fed to
    a small PyTorch MLP (``self.network``), whose output is split into one
    pre-activation density channel and ``geo_feat_dim`` feature channels.
    Colour is predicted by ``self.mlp_head`` from those features and, when
    ``use_viewdirs`` is set, a spherical-harmonics encoding of the direction.

    Args:
        aabb: Scene bounds ``(min..., max...)`` as a tensor or list of floats;
            stored as the ``aabb`` buffer.
        num_dim: Number of position coordinates.
        use_viewdirs: Whether colour depends on the view direction.
        density_activation: Maps the raw density channel to a density.
        unbounded: If True, positions are passed through
            ``contract_to_unisphere`` instead of plain box normalisation.
        base_resolution: Coarsest hash-grid resolution.
        max_resolution: Finest hash-grid resolution.
        geo_feat_dim: Number of feature channels handed to the colour head.
        n_levels: Number of hash-grid levels.
        log2_hashmap_size: log2 of the hash table size of the plane networks.
    """

    def __init__(
        self,
        aabb: Union[torch.Tensor, List[float]],
        num_dim: int = 3,
        use_viewdirs: bool = True,
        density_activation: Callable = lambda x: trunc_exp(x - 1),
        unbounded: bool = False,
        base_resolution: int = 16,
        max_resolution: int = 4096,
        geo_feat_dim: int = 15,
        n_levels: int = 16,
        log2_hashmap_size: int = 19,
    ) -> None:
        super().__init__()
        if not isinstance(aabb, torch.Tensor):
            aabb = torch.tensor(aabb, dtype=torch.float32)
        self.register_buffer("aabb", aabb)
        self.num_dim = num_dim
        self.use_viewdirs = use_viewdirs
        self.density_activation = density_activation
        self.unbounded = unbounded
        self.base_resolution = base_resolution
        self.max_resolution = max_resolution
        self.geo_feat_dim = geo_feat_dim
        self.n_levels = n_levels
        self.log2_hashmap_size = log2_hashmap_size

        # Geometric growth factor taking base_resolution to max_resolution
        # over n_levels levels.
        per_level_scale = np.exp(
            (np.log(max_resolution) - np.log(base_resolution)) / (n_levels - 1)
        ).tolist()

        def hash_grid_config(table_log2_size):
            # Fresh dict per call so no config object is shared between modules.
            return {
                "otype": "HashGrid",
                "n_levels": n_levels,
                "n_features_per_level": 2,
                "log2_hashmap_size": table_log2_size,
                "base_resolution": base_resolution,
                "per_level_scale": per_level_scale,
            }

        def small_mlp_config():
            return {
                "otype": "FullyFusedMLP",
                "activation": "ReLU",
                "output_activation": "None",
                "n_neurons": 64,
                "n_hidden_layers": 1,
            }

        def plane_network():
            # query_density feeds each plane network a two-column slice of the
            # position, so the declared input width must be 2.
            return tcnn.NetworkWithInputEncoding(
                n_input_dims=2,
                n_output_dims=1 + self.geo_feat_dim,
                encoding_config=hash_grid_config(log2_hashmap_size),
                network_config=small_mlp_config(),
            )

        if self.use_viewdirs:
            self.direction_encoding = tcnn.Encoding(
                n_input_dims=num_dim,
                encoding_config={
                    "otype": "Composite",
                    "nested": [
                        {
                            "n_dims_to_encode": 3,
                            "otype": "SphericalHarmonics",
                            "degree": 4,
                        },
                    ],
                },
            )

        # NOTE: mlp_base, coordnet and peri are not referenced by
        # query_density / _query_rgb / forward in this class. They are still
        # registered (in the original order) so the module's parameters and
        # state_dict keys stay the same.
        # mlp_base uses a fixed 2**21-entry table, independent of
        # log2_hashmap_size.
        self.mlp_base = tcnn.NetworkWithInputEncoding(
            n_input_dims=num_dim,
            n_output_dims=1 + self.geo_feat_dim,
            encoding_config=hash_grid_config(21),
            network_config=small_mlp_config(),
        )
        self.coordnet = nn.Sequential(
            nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 3), nn.Tanh()
        )

        # One network per coordinate pair; encoding1 was previously declared
        # with n_input_dims=4, which does not match its two-column input and
        # disagreed with encoding2 / encoding3.
        self.encoding1 = plane_network()
        self.encoding2 = plane_network()
        self.encoding3 = plane_network()

        # Fuses the element-wise product of the three plane outputs.
        self.network = nn.Sequential(
            nn.Linear(self.encoding1.n_output_dims, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 1 + self.geo_feat_dim),
        )
        self.peri = torch.nn.Parameter(
            torch.cat(
                [torch.ones((1, 15, 10, 3)), torch.zeros((1, 15, 10, 3))],
                dim=1,
            )
        )

        if self.geo_feat_dim > 0:
            self.mlp_head = tcnn.Network(
                n_input_dims=(
                    (
                        self.direction_encoding.n_output_dims
                        if self.use_viewdirs
                        else 0
                    )
                    + self.geo_feat_dim
                ),
                n_output_dims=3,
                network_config={
                    "otype": "FullyFusedMLP",
                    "activation": "ReLU",
                    "output_activation": "None",
                    "n_neurons": 64,
                    "n_hidden_layers": 2,
                },
            )

    def query_density(self, x, return_feat: bool = False):
        """Evaluate density (and optionally the colour features) at ``x``.

        Args:
            x: Positions whose last dimension has ``num_dim`` entries.
            return_feat: If True, also return the ``geo_feat_dim`` feature
                channels produced alongside the density.

        Returns:
            ``density`` with a trailing dimension of 1, or the tuple
            ``(density, features)`` when ``return_feat`` is True. Density is
            forced to zero for points outside the open unit cube after
            normalisation.
        """
        if self.unbounded:
            x = contract_to_unisphere(x, self.aabb)
        else:
            aabb_min, aabb_max = torch.split(self.aabb, self.num_dim, dim=-1)
            x = (x - aabb_min) / (aabb_max - aabb_min)
        selector = ((x > 0.0) & (x < 1.0)).all(dim=-1)

        leading_shape = list(x.shape[:-1])
        flat = x.view(-1, self.num_dim)
        # Element-wise product of the three per-plane network outputs.
        plane_product = (
            self.encoding1(flat[:, [0, 1]])
            * self.encoding2(flat[:, [0, 2]])
            * self.encoding3(flat[:, [1, 2]])
        )
        # Cast to float32 before the PyTorch MLP, whose weights are float32.
        fused = self.network(plane_product.float())
        fused = fused.view(leading_shape + [1 + self.geo_feat_dim])

        density_before_activation, base_mlp_out = torch.split(
            fused, [1, self.geo_feat_dim], dim=-1
        )
        density = (
            self.density_activation(density_before_activation)
            * selector[..., None]
        )
        if return_feat:
            return density, base_mlp_out
        return density

    def _query_rgb(self, dir, embedding, apply_act: bool = True):
        """Predict colour from direction and density-branch features.

        Args:
            dir: View directions; only read when ``use_viewdirs`` is True.
            embedding: Feature tensor with ``geo_feat_dim`` trailing channels.
            apply_act: Apply a sigmoid to the raw network output.

        Returns:
            Tensor shaped ``embedding.shape[:-1] + (3,)`` in the dtype/device
            of ``embedding``.
        """
        if self.use_viewdirs:
            # tcnn requires directions in the range [0, 1]
            dir = (dir + 1.0) / 2.0
            d = self.direction_encoding(dir.reshape(-1, dir.shape[-1]))
            h = torch.cat(
                [d, embedding.reshape(-1, self.geo_feat_dim)], dim=-1
            )
        else:
            h = embedding.reshape(-1, self.geo_feat_dim)

        rgb = (
            self.mlp_head(h)
            .reshape(list(embedding.shape[:-1]) + [3])
            .to(embedding)
        )
        if apply_act:
            rgb = torch.sigmoid(rgb)
        return rgb

    def forward(
        self,
        positions: torch.Tensor,
        directions: Union[torch.Tensor, None] = None,
    ):
        """Return ``(rgb, density)`` for the given sample positions.

        Args:
            positions: Sample positions.
            directions: View directions with the same shape as ``positions``.
                Required when ``use_viewdirs`` is True; ignored otherwise.

        Raises:
            ValueError: If ``use_viewdirs`` is True and ``directions`` is None.
            AssertionError: If the shapes of ``positions`` and ``directions``
                differ while ``use_viewdirs`` is True.
        """
        if self.use_viewdirs:
            if directions is None:
                raise ValueError(
                    "`directions` must be provided when `use_viewdirs` is True."
                )
            assert (
                positions.shape == directions.shape
            ), f"{positions.shape} v.s. {directions.shape}"
        # Always evaluate both heads so the return values are defined on every
        # path (previously they were only assigned inside the view-dir branch).
        density, embedding = self.query_density(positions, return_feat=True)
        rgb = self._query_rgb(directions, embedding=embedding)
        return rgb, density


class NGPDensityField(torch.nn.Module):
    """Instant-NGP density field used for resampling.

    A single hash-grid network (``mlp_base``) maps normalised positions to one
    raw density value, which is passed through ``density_activation`` and
    zeroed for points outside the open unit cube.

    Args:
        aabb: Scene bounds ``(min..., max...)`` as a tensor or list of floats;
            stored as the ``aabb`` buffer.
        num_dim: Number of position coordinates.
        density_activation: Maps the raw network output to a density.
        unbounded: If True, positions are passed through
            ``contract_to_unisphere`` instead of plain box normalisation.
        base_resolution: Coarsest hash-grid resolution.
        max_resolution: Finest hash-grid resolution.
        n_levels: Number of hash-grid levels.
        log2_hashmap_size: log2 of the hash table size.
    """

    def __init__(
        self,
        aabb: Union[torch.Tensor, List[float]],
        num_dim: int = 3,
        density_activation: Callable = lambda x: trunc_exp(x - 1),
        unbounded: bool = False,
        base_resolution: int = 16,
        max_resolution: int = 128,
        n_levels: int = 5,
        log2_hashmap_size: int = 17,
    ) -> None:
        super().__init__()
        bounds = (
            aabb
            if isinstance(aabb, torch.Tensor)
            else torch.tensor(aabb, dtype=torch.float32)
        )
        self.register_buffer("aabb", bounds)

        self.num_dim = num_dim
        self.density_activation = density_activation
        self.unbounded = unbounded
        self.base_resolution = base_resolution
        self.max_resolution = max_resolution
        self.n_levels = n_levels
        self.log2_hashmap_size = log2_hashmap_size

        # Geometric growth factor taking base_resolution to max_resolution
        # over n_levels levels.
        log_span = np.log(max_resolution) - np.log(base_resolution)
        per_level_scale = np.exp(log_span / (n_levels - 1)).tolist()

        grid_config = {
            "otype": "HashGrid",
            "n_levels": n_levels,
            "n_features_per_level": 2,
            "log2_hashmap_size": log2_hashmap_size,
            "base_resolution": base_resolution,
            "per_level_scale": per_level_scale,
        }
        mlp_config = {
            "otype": "FullyFusedMLP",
            "activation": "ReLU",
            "output_activation": "None",
            "n_neurons": 64,
            "n_hidden_layers": 1,
        }
        self.mlp_base = tcnn.NetworkWithInputEncoding(
            n_input_dims=num_dim,
            n_output_dims=1,
            encoding_config=grid_config,
            network_config=mlp_config,
        )

        # NOTE: coordnet is registered here but not referenced by forward().
        dense_grid_config = {
            "otype": "Grid",
            "type": "Dense",
            "n_levels": 1,
            "n_features_per_level": 4,
            "base_resolution": 128,
        }
        self.coordnet = tcnn.Encoding(3, dense_grid_config)

    def forward(self, positions: torch.Tensor):
        """Return the density at ``positions`` with a trailing dimension of 1."""
        if self.unbounded:
            normalized = contract_to_unisphere(positions, self.aabb)
        else:
            box_min, box_max = self.aabb.split(self.num_dim, dim=-1)
            normalized = (positions - box_min) / (box_max - box_min)

        # Samples on or outside the unit-cube boundary get zero density.
        inside = ((normalized > 0.0) & (normalized < 1.0)).all(dim=-1)

        raw = self.mlp_base(normalized.view(-1, self.num_dim))
        # Restore the leading shape and match the dtype/device of the input.
        raw = raw.view(*normalized.shape[:-1], 1).to(normalized)

        return self.density_activation(raw) * inside.unsqueeze(-1)
