# GeoNeRF is a generalizable NeRF model that renders novel views
# without requiring per-scene optimization. This software is the 
# implementation of the paper "GeoNeRF: Generalizing NeRF with 
# Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin,
# and Francois Fleuret.

# Copyright (c) 2022 ams International AG

# This file is part of GeoNeRF.
# GeoNeRF is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 3 as
# published by the Free Software Foundation.

# GeoNeRF is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with GeoNeRF. If not, see <http://www.gnu.org/licenses/>.

# This file incorporates work covered by the following copyright and  
# permission notice:

    # MIT License

    # Copyright (c) 2021 apchenstu

    # Permission is hereby granted, free of charge, to any person obtaining a copy
    # of this software and associated documentation files (the "Software"), to deal
    # in the Software without restriction, including without limitation the rights
    # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    # copies of the Software, and to permit persons to whom the Software is
    # furnished to do so, subject to the following conditions:

    # The above copyright notice and this permission notice shall be included in all
    # copies or substantial portions of the Software.

    # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    # SOFTWARE.

import torch
import torch.nn.functional as F

from utils.utils import normal_vect, interpolate_3D, interpolate_2D


class Embedder:
    """Sinusoidal (NeRF-style) positional encoder.

    Configured through keyword arguments, all required:
        include_input (bool): prepend the raw input to the encoding.
        max_freq_log2 (float): log2 of the highest frequency band.
        num_freqs (int): number of frequency bands.
        log_sampling (bool): if True the bands are
            ``2 ** linspace(0, max_freq_log2, num_freqs)``, otherwise
            ``linspace(1, 2 ** max_freq_log2, num_freqs)``.
        periodic_fns (list of callables): functions applied to the scaled
            input, e.g. ``[torch.sin, torch.cos]``.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        """Compute the frequency bands and the per-band embedding callables."""
        embed_fns = []

        if self.kwargs["include_input"]:
            embed_fns.append(lambda x: x)

        max_freq = self.kwargs["max_freq_log2"]
        N_freqs = self.kwargs["num_freqs"]

        if self.kwargs["log_sampling"]:
            freq_bands = 2.0 ** torch.linspace(0.0, max_freq, steps=N_freqs)
        else:
            freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, steps=N_freqs)

        # Shape (1, num_freqs, 1) so it broadcasts against (..., 1, D) inputs.
        # Kept on the GPU when one exists (as before), but CUDA is no longer
        # required: `embed` moves the bands to the input's device on demand.
        bands = freq_bands.reshape(1, -1, 1)
        if torch.cuda.is_available():
            bands = bands.cuda()
        self.freq_bands = bands

        # Per-band callables, kept for API compatibility (`embed` uses a
        # vectorised path). Default args bind p_fn/freq at definition time
        # so every lambda keeps its own band.
        for freq in freq_bands:
            for p_fn in self.kwargs["periodic_fns"]:
                embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))

        self.embed_fns = embed_fns

    def embed(self, inputs):
        """Encode ``inputs`` along its last dimension.

        For an input of shape (..., D) the output has shape
        (..., D * (include_input + num_freqs * len(periodic_fns))), laid out as
        ``[inputs, p_0(f_0 x), ..., p_0(f_{K-1} x), p_1(f_0 x), ...]`` where each
        ``p_j(f_k x)`` chunk holds the D scaled components.
        """
        bands = self.freq_bands.to(inputs.device)
        n_lead = inputs.dim() - 1
        # (..., 1, D) * (1, ..., 1, K, 1) -> (..., K, D) -> (..., K * D)
        inputs_scaled = (
            inputs.unsqueeze(-2) * bands.view(*[1] * n_lead, -1, 1)
        ).reshape(*inputs.shape[:-1], -1)
        parts = [inputs] if self.kwargs["include_input"] else []
        parts.extend(p_fn(inputs_scaled) for p_fn in self.kwargs["periodic_fns"])
        return torch.cat(parts, dim=-1)


def get_embedder(multires=4):
    """Return a callable that positionally encodes a tensor with `Embedder`.

    The encoder uses ``multires`` log-spaced bands ``2**0 ... 2**(multires - 1)``,
    keeps the raw input and applies both sin and cos.
    """
    embedder_obj = Embedder(
        include_input=True,
        max_freq_log2=multires - 1,
        num_freqs=multires,
        log_sampling=True,
        periodic_fns=[torch.sin, torch.cos],
    )

    def embed(x, eo=embedder_obj):
        return eo.embed(x)

    return embed


def sigma2weights(sigma, return_T=False, return_alpha=False, return_all=False):
    """Turn per-sample densities into volume-rendering weights along the last axis.

    ``alpha = 1 - exp(-sigma)``; ``sigma`` is used directly as optical depth
    (presumably already scaled by the sample spacing upstream — verify).
    ``T`` is the exclusive cumulative product of ``1 - alpha`` (with a 1e-10
    guard), so the first sample always has ``T == 1``; ``weights = alpha * T``.

    Flags are checked in priority order: ``return_alpha`` -> alpha,
    ``return_T`` -> T, ``return_all`` -> (alpha, T, weights); otherwise weights.
    """
    alpha = 1.0 - torch.exp(-sigma)
    if return_alpha:
        return alpha

    # Prepend a column of ones so the cumulative product is exclusive.
    leading_ones = torch.ones(*alpha.shape[:-1], 1, device=alpha.device)
    survival = torch.cat([leading_ones, 1.0 - alpha + 1e-10], -1)
    T = torch.cumprod(survival, -1)[..., :-1]
    if return_T:
        return T

    weights = alpha * T
    return (alpha, T, weights) if return_all else weights


def volume_rendering(rgb_sigma, pts_depth, white_bkgd=False, have_dist=False, gene_mask=False, have_gene=False, have_B=False, have_style=False):
    """Composite per-sample colours/densities along rays into colours and depths.

    Args:
        rgb_sigma: network output; RGB is ``[..., :3]`` and density ``[..., 3]``.
            Colours are reduced over axis -2 and densities over axis -1.
            Depending on the flags it is wrapped in nested pairs, unpacked in
            this order:
            ``have_dist``  -> ``(inner, dist_args)`` with keys 'dist_mean'/'dist_var';
            ``have_gene``  -> ``(inner, gene_rgb_sigma)``;
            ``have_B``     -> ``(rgb_sigma_per_level, B_per_level)``, dicts keyed
            ``level_0``..``level_2``;
            ``have_style`` -> ``(inner, style_rgb)`` where ``style_rgb`` is a
            tensor or a dict with keys 'c' and 'c+s'.
        pts_depth: depth of every sample, reduced over its last axis.
        white_bkgd: add ``1 - accumulated opacity`` to the rendered colour
            (per level when ``have_B``).
        gene_mask: mask mode: channel 3 is summed over dim 1 into a per-ray
            count and each sample gets weight ``1 / count`` (0 when the count
            is 0). All other flags, including ``white_bkgd``, are ignored.

    Returns:
        ``(rgb, depth)`` plus extras chosen by the first true flag among
        have_dist / have_gene / have_B (rendered B and NaN counts) / have_style.
    """
    if gene_mask:
        S = rgb_sigma[..., 3].shape[1]
        rgb = rgb_sigma[..., :3]

        ray_cnt = torch.sum(rgb_sigma[..., 3], dim=1).reshape(-1, 1).repeat(1, S)
        weights = torch.zeros_like(ray_cnt)
        weights[ray_cnt != 0] = 1 / ray_cnt[ray_cnt != 0]

        rendered_rgb = torch.sum(weights[..., None] * rgb, -2)
        rendered_depth = torch.sum(weights * pts_depth, -1)
        return rendered_rgb, rendered_depth

    if have_dist:
        rgb_sigma, dist_args = rgb_sigma
        mean = dist_args['dist_mean']  # for rgb
        var = dist_args['dist_var']  # for rgb

    if have_gene:
        rgb_sigma, gene_rgb_sigma = rgb_sigma

    if have_B:
        rgb_sigma_all, B_all = rgb_sigma

    if have_style:
        rgb_sigma, style_rgb = rgb_sigma

    if have_B:
        rendered_rgb, rendered_depth, rendered_B = {}, {}, {}
        nan_len = {'rgb': 0, 'sigma': 0, 'B': 0}
        for l in range(3):
            key = f"level_{l}"
            level_rgb_sigma = rgb_sigma_all[key]
            rgb = level_rgb_sigma[..., :3]
            weights = sigma2weights(level_rgb_sigma[..., 3])

            rendered_rgb[key] = torch.sum(weights[..., None] * rgb, -2)
            rendered_depth[key] = torch.sum(weights * pts_depth, -1)
            rendered_B[key] = torch.sum(weights * B_all[key].squeeze(-1), -1)
            if white_bkgd:
                # Each level has its own opacity, so composite per level.
                acc_map = torch.sum(weights, -1)
                rendered_rgb[key] = rendered_rgb[key] + (1.0 - acc_map[..., None])

            nan_len['rgb'] += len(torch.isnan(rgb).nonzero())
            nan_len['sigma'] += len(torch.isnan(level_rgb_sigma[..., 3]).nonzero())
            nan_len['B'] += len(torch.isnan(B_all[key]).nonzero())

    else:
        rgb = rgb_sigma[..., :3]
        weights = sigma2weights(rgb_sigma[..., 3])

        rendered_rgb = torch.sum(weights[..., None] * rgb, -2)
        rendered_depth = torch.sum(weights * pts_depth, -1)
        if have_dist:
            rendered_rgb_mean = torch.sum(weights[..., None] * mean, -2)
            rendered_rgb_var = torch.sum((weights[..., None] ** 2) * var.squeeze(-1), -2)

        if have_gene:
            # Separate name: reusing `weights` here would make the style
            # rendering and white background below use the gene field's weights.
            gene_weights = sigma2weights(gene_rgb_sigma[..., 3])
            gene_rendered_rgb = torch.sum(gene_weights[..., None] * gene_rgb_sigma[..., :3], -2)
            gene_rendered_depth = torch.sum(gene_weights * pts_depth, -1)

        if have_style:
            if isinstance(style_rgb, dict):
                rendered_style_rgb = {
                    'c': torch.sum(weights[..., None] * style_rgb['c'], -2),
                    'c+s': torch.sum(weights[..., None] * style_rgb['c+s'], -2),
                }
            else:
                rendered_style_rgb = torch.sum(weights[..., None] * style_rgb, -2)

        if white_bkgd:
            acc_map = torch.sum(weights, -1)
            rendered_rgb = rendered_rgb + (1.0 - acc_map[..., None])

    if have_dist:
        return rendered_rgb, rendered_depth, {'rendered_rgb_mean': rendered_rgb_mean, 'rendered_rgb_var': rendered_rgb_var}
    elif have_gene:
        return rendered_rgb, rendered_depth, {'gene_rendered_rgb': gene_rendered_rgb, 'gene_rendered_depth': gene_rendered_depth}
    elif have_B:
        return rendered_rgb, rendered_depth, rendered_B, nan_len
    elif have_style:
        return rendered_rgb, rendered_depth, rendered_style_rgb
    else:
        return rendered_rgb, rendered_depth


def get_angle_wrt_src_cams(c2ws, rays_pts, rays_dir_unit, return_angle_cos=False):
    """Angle between each target ray and the rays from each source camera.

    For every sample point, the unit direction from each source camera centre
    (``c2ws[:, :3, 3]``) to the point is dotted with the target ray direction.

    Args:
        c2ws: source camera-to-world matrices.
        rays_pts: sample points; ``unsqueeze(2)`` inserts the view axis.
        rays_dir_unit: unit target directions, reshaped to (n_rays, 1, 1, 3).
        return_angle_cos: return the cosine instead of the angle proxy.

    Returns:
        A tensor with a trailing singleton axis (callers annotate it as
        (bs, nb_sample, nb_view, 1)), holding either the cosine or
        ``|sin(angle)|``, which approximates the angle when it is small.
    """
    n_rays = rays_pts.shape[0]
    cam_centers = c2ws[:, :3, 3][None, None]

    # Unit vectors from every source camera to every sample point.
    cam_to_pts = normal_vect(rays_pts.unsqueeze(2) - cam_centers)
    target_dirs = rays_dir_unit.reshape(n_rays, 1, 1, 3)
    angle_cos = (cam_to_pts * target_dirs).sum(dim=-1, keepdim=True)

    if return_angle_cos:
        return angle_cos
    # sin = sqrt(1 - cos^2); abs() guards tiny negative values from rounding.
    return (1 - angle_cos.pow(2)).abs().sqrt()


def _interp_3d_levels(vols, rays_pts_ndc, view_idx):
    """Sample view `view_idx` of the level_0..2 volumes in `vols` via interpolate_3D at each level's NDC points."""
    return [
        interpolate_3D(vols[f"level_{l}"][:, view_idx], rays_pts_ndc[f"level_{l}"][:, :, view_idx])
        for l in range(3)
    ]


def _interp_2d_levels(feats, imgs, rays_pts_ndc, view_idx):
    """Sample view `view_idx` of the level_0..2 maps in `feats` via interpolate_2D at the level-0 NDC points."""
    pts_ndc_0 = rays_pts_ndc["level_0"][:, :, view_idx]
    sampled = []
    for l in range(3):
        feat_l = feats[f"level_{l}"].unsqueeze(0)  # add a leading batch axis of size 1
        ray_feat_l, _, _ = interpolate_2D(feat_l[:, view_idx], imgs[:, view_idx], pts_ndc_0)
        sampled.append(ray_feat_l)
    return sampled


def interpolate_pts_feats(imgs, feats_fpn, feats_vol, rays_pts_ndc, other_kwarg=None, occ_masks=None, two_feats=False, feat3d_range_mask=False, content_style_feat=None):
    """Gather per-sample features from every source view and stack them.

    For each view ``i`` (``feats_fpn.shape[1]`` views) the features below are
    concatenated on the last axis, in this order:
      1. the three cost-volume levels from ``feats_vol`` (with ``two_feats``,
         a per-sample blend of ``feats_vol["A"]``/``["B"]`` selected by
         ``occ_masks``);
      2. FPN features, colours and validity masks from ``interpolate_2D``;
      3. optional extras from ``other_kwarg`` ('disocc_confi', 'texFeat',
         'feat_SA', 'global_geoFeat'), the 3D in-range masks when
         ``feat3d_range_mask``, content/style features, and 'style3D'.
    The per-view results are stacked on axis 2.

    Args:
        other_kwarg: optional dict of extra features; None means no extras.
        occ_masks: boolean masks indexed ``[:, :, i]``; required when ``two_feats``.
        content_style_feat: dict with a 'style' entry (itself holding
            'style_feat' and 'domain') or None.
    """
    if other_kwarg is None:
        other_kwarg = {}
    nb_views = feats_fpn.shape[1]
    N, S = rays_pts_ndc["level_0"].shape[:2]
    interpolated_feats = []

    for i in range(nb_views):
        pts_ndc_0 = rays_pts_ndc["level_0"][:, :, i]

        if two_feats:
            feats_A = _interp_3d_levels(feats_vol["A"], rays_pts_ndc, i)
            feats_B = _interp_3d_levels(feats_vol["B"], rays_pts_ndc, i)
            occ_masks_v = occ_masks[:, :, i]  # 0: occ; 1: disocc
            dis_occ_masks_v = ~occ_masks_v
            ray_feats = [occ_masks_v * a + dis_occ_masks_v * b for a, b in zip(feats_A, feats_B)]
        else:
            ray_feats = _interp_3d_levels(feats_vol, rays_pts_ndc, i)

        ray_feats_fpn, ray_colors, ray_masks = interpolate_2D(feats_fpn[:, i], imgs[:, i], pts_ndc_0)
        feat_to_cat = [*ray_feats, ray_feats_fpn, ray_colors, ray_masks]

        if 'disocc_confi' in other_kwarg:
            feat_to_cat.extend(
                confi.unsqueeze(-1)
                for confi in _interp_3d_levels(other_kwarg['disocc_confi'], rays_pts_ndc, i)
            )

        if 'texFeat' in other_kwarg:
            feat_to_cat.extend(_interp_3d_levels(other_kwarg['texFeat'], rays_pts_ndc, i))

        if 'feat_SA' in other_kwarg:
            feat_SA, _, _ = interpolate_2D(other_kwarg['feat_SA'][:, i], imgs[:, i], pts_ndc_0)
            feat_to_cat.append(feat_SA)

        if 'global_geoFeat' in other_kwarg:
            feat_to_cat.append(
                interpolate_3D(other_kwarg['global_geoFeat'][:, i], rays_pts_ndc["level_2"][:, :, i])
            )

        if feat3d_range_mask:
            for l in range(3):
                with torch.no_grad():
                    # Map NDC to [-1, 1] and flag points strictly inside on all of x, y, z.
                    grid = rays_pts_ndc[f"level_{l}"][:, :, i].unsqueeze(0) * 2 - 1.0  # [1 H W 3] (x,y,z)
                    inside = (grid > -1.0) & (grid < 1.0)
                    in_mask_3d = (inside[..., 0] & inside[..., 1] & inside[..., 2]).float().permute(1, 2, 0)
                feat_to_cat.append(in_mask_3d)

        if 'common_feat' in other_kwarg and 'special_feat' in other_kwarg:
            feat_to_cat.extend(_interp_2d_levels(other_kwarg['common_feat'], imgs, rays_pts_ndc, i))
            feat_to_cat.extend(_interp_2d_levels(other_kwarg['special_feat'], imgs, rays_pts_ndc, i))

            if content_style_feat["style"] is None:
                # Temporary fallback for testing: random style/domain codes,
                # created on the same device as the other features.
                ref = feat_to_cat[0]
                style = torch.randn(ref.shape[0], ref.shape[1], 8, device=ref.device)
                domain = torch.randn(ref.shape[0], ref.shape[1], 5, device=ref.device)
            else:
                style = content_style_feat["style"]["style_feat"]
                domain = content_style_feat["style"]["domain"]
            feat_to_cat.append(style.expand(N, S, -1))
            feat_to_cat.append(domain.expand(N, S, -1))
        elif content_style_feat is not None:
            style = content_style_feat["style"]["style_feat"]
            domain = content_style_feat["style"]["domain"]
            feat_to_cat.extend(_interp_2d_levels(other_kwarg['feats'], imgs, rays_pts_ndc, i))
            feat_to_cat.append(style.expand(N, S, -1))
            feat_to_cat.append(domain.expand(N, S, -1))

        if 'style3D' in other_kwarg:
            feat_to_cat.extend(_interp_3d_levels(other_kwarg['style3D'], rays_pts_ndc, i))

        interpolated_feats.append(torch.cat(feat_to_cat, dim=-1))

    return torch.stack(interpolated_feats, dim=2)


def get_occ_masks(depth_map_norm, rays_pts_ndc, visibility_thr=0.2):
    """Flag which sample points are visible from each source view.

    For every view, ``depth_map_norm["level_0"][:, view]`` is bilinearly
    sampled (border padding, align_corners=True) at the points' NDC xy. A point
    is marked True when its NDC z minus that sampled depth is below
    ``visibility_thr``.

    Args:
        depth_map_norm: dict whose "level_0" entry has the views on dim 1.
        rays_pts_ndc: dict whose "level_0" entry is indexed ``[:, :, view, :]``
            with x, y, z on the last axis.
        visibility_thr: tolerance on ``z_point - z_surface``.

    Returns:
        Boolean tensor shaped like ``rays_pts_ndc["level_0"][..., :1]``.
    """
    depth_maps = depth_map_norm["level_0"]
    pts_ndc = rays_pts_ndc["level_0"]

    def _depth_gap(view):
        # x * 2 - 1 maps NDC xy (presumably in [0, 1]) to grid_sample's [-1, 1].
        grid = pts_ndc[None, :, :, view, :2] * 2 - 1.0
        surface = F.grid_sample(
            depth_maps[:, view:view + 1],
            grid,
            mode="bilinear",
            padding_mode="border",
            align_corners=True,
        )[0, 0]
        return pts_ndc[:, :, view, 2] - surface

    z_diff = torch.stack([_depth_gap(v) for v in range(depth_maps.shape[1])], dim=2)
    return z_diff.unsqueeze(-1) < visibility_thr


def render_rays(
    c2ws,
    rays_pts,
    rays_pts_ndc,
    pts_depth,
    rays_dir,
    feats_vol,
    feats_fpn,
    imgs,
    depth_map_norm,
    renderer_net,
    other_kwarg,
    have_dist=False,
    have_gene=False,
    use_att_3d=False,
    other_output=None,
    for_mask=None,
    two_feats=False,
    feat3d_range_mask=False,
    use_angle_cos=False,
    use_angle_both=False,
    have_B=False,
    content_style_feat=None,
    input_phi=None,
    return_what_style=None,
    input_z=None,
    alpha=None,
    delta_t_loss=False,
):
    """Render colour and depth for a batch of rays from multi-view source features.

    Steps:
      1. angle features between the target rays and the source cameras;
      2. occlusion masks from the normalised source depth maps;
      3. per-sample feature interpolation;
      4. ``renderer_net`` predicts RGB/density;
      5. volume rendering.

    Args (selected):
        c2ws: source camera-to-world matrices (``c2ws[:, :3, 3]`` = centres).
        rays_dir: target ray directions; normalised here.
        other_kwarg: dict of optional extra features for ``interpolate_pts_feats``.
        other_output: iterable of extra outputs to return, any of 'T', 'rgb',
            'alpha', 'angle_cos', 'input_color', 'D', 'density', 'mask',
            'nan_len'. 'nan_len' is None unless ``have_B``.
        for_mask: when not None, ``volume_rendering`` runs in mask mode.
        content_style_feat: when not None, the rendered style colour is added
            to the extra outputs (network output unpacked as
            ``(rgb_sigma, style_rgb)`` in the default branch).
        delta_t_loss: also render the style at a random phase in [0, 2*pi) and
            report cos/sin differences to ``renderer_net.t``.

    Returns:
        ``(rendered_rgb, rendered_depth, renderer_other_output)``, where the
        last item is a dict of flag-dependent and requested extras.
    """
    if other_output is None:
        other_output = ()
    renderer_other_output = {}
    have_style = content_style_feat is not None
    nan_len = None  # only produced by volume_rendering when have_B
    vr_kwargs = dict(
        have_dist=have_dist,
        gene_mask=for_mask is not None,
        have_gene=have_gene,
        have_B=have_B,
        have_style=have_style,
    )

    ## The angles between the ray and source camera vectors
    rays_dir_unit = rays_dir / torch.norm(rays_dir, dim=-1, keepdim=True)
    angles = get_angle_wrt_src_cams(c2ws, rays_pts, rays_dir_unit, return_angle_cos=use_angle_cos)

    ## Positional encoding (one embedder shared by all encodings below)
    embed = get_embedder()
    if use_angle_both:
        cos_angles = get_angle_wrt_src_cams(c2ws, rays_pts, rays_dir_unit, return_angle_cos=True)
        embedded_angles = {'PE': embed(angles), 'cos': cos_angles}  # PE: (bs, n_sample, nb_view, 1*9)
    elif use_angle_cos:
        embedded_angles = angles  # raw cosines, no positional encoding
    else:
        embedded_angles = embed(angles)  # angle diff (bs, n_sample, nb_view, 1*9)
    # Target direction repeated for every sample on the ray: (bs, n_sample, 3*9)
    embedded_novel_angles = embed(rays_dir_unit.reshape(-1, 1, 3).repeat(1, angles.shape[1], 1))

    ## Getting Occlusion Masks based on predicted depths
    occ_masks = get_occ_masks(depth_map_norm, rays_pts_ndc)

    ## Interpolate all features for sample points
    pts_feat = interpolate_pts_feats(imgs, feats_fpn, feats_vol, rays_pts_ndc, other_kwarg, occ_masks, two_feats=two_feats, feat3d_range_mask=feat3d_range_mask, content_style_feat=content_style_feat)

    ## rendering sigma and RGB values
    if use_att_3d:
        rendered_net_output = renderer_net(embedded_angles, pts_feat, occ_masks, embedded_novel_angles, for_att_3d={'tex_feat_2': other_kwarg['texFeat']['level_2'], 'rays_pts_ndc_2': rays_pts_ndc["level_2"]})
    elif return_what_style is not None:
        rendered_net_output = renderer_net(embedded_angles, pts_feat, occ_masks, embedded_novel_angles, for_mask=for_mask, input_phi=input_phi, return_what_style=return_what_style, z=input_z, alpha=alpha)
    else:
        rendered_net_output = renderer_net(embedded_angles, pts_feat, occ_masks, embedded_novel_angles, for_mask=for_mask, input_phi=input_phi, z=input_z, alpha=alpha)

    if delta_t_loss:
        import math
        t_rand = torch.rand(1).item() * (2 * math.pi)
        _rendered_net_output = renderer_net(embedded_angles, pts_feat, occ_masks, embedded_novel_angles, for_mask=for_mask, input_phi=t_rand, return_what_style=return_what_style, z=input_z, alpha=alpha)
        _, _, rendered_style_rgb_tRand = volume_rendering(_rendered_net_output, pts_depth, **vr_kwargs)
        renderer_other_output['rendered_style_rgb_tRand'] = rendered_style_rgb_tRand
        renderer_other_output['delta_cos'] = math.cos(renderer_net.t) - math.cos(t_rand)
        renderer_other_output['delta_sin'] = math.sin(renderer_net.t) - math.sin(t_rand)

    volume_rendering_output = volume_rendering(rendered_net_output, pts_depth, **vr_kwargs)
    if have_dist:
        rgb_sigma, _ = rendered_net_output
        rendered_rgb, rendered_depth, rendered_dist_args = volume_rendering_output
    elif have_gene:
        rgb_sigma, _ = rendered_net_output
        rendered_rgb, rendered_depth, gene_rendered = volume_rendering_output
    elif have_B:
        rgb_sigma, _ = rendered_net_output  # per-level dicts
        rgb_sigma = rgb_sigma["level_2"]  # finest level, used for the extras below
        rendered_rgb, rendered_depth, rendered_B, nan_len = volume_rendering_output  # dicts
    elif have_style:
        rgb_sigma, _ = rendered_net_output
        rendered_rgb, rendered_depth, rendered_style_rgb = volume_rendering_output
    else:
        rgb_sigma = rendered_net_output
        rendered_rgb, rendered_depth = volume_rendering_output

    ## other_output
    if have_dist:
        renderer_other_output['rendered_dist_args'] = rendered_dist_args
    if have_gene:
        renderer_other_output['gene_rendered'] = gene_rendered
    if have_B:
        renderer_other_output['rendered_B'] = rendered_B
    if have_style:
        renderer_other_output['rendered_style_rgb'] = rendered_style_rgb
    if 'T' in other_output:
        renderer_other_output['T'] = sigma2weights(rgb_sigma[..., 3], return_T=True)
    if 'rgb' in other_output:
        renderer_other_output['rgb'] = rgb_sigma[..., :3]
    if 'alpha' in other_output:
        renderer_other_output['alpha'] = sigma2weights(rgb_sigma[..., 3], return_alpha=True)
    if 'angle_cos' in other_output:
        renderer_other_output['angle_cos'] = get_angle_wrt_src_cams(c2ws, rays_pts, rays_dir_unit, return_angle_cos=True)  # (bs, nb_sample, nb_view, 1)
    # NOTE(review): the fixed offsets below assume a specific pts_feat layout
    # (presumably 3 x 8 volume channels + 8 FPN channels before the colours).
    # Verify against the feature extractor before changing channel counts.
    if 'input_color' in other_output:
        renderer_other_output['input_color'] = pts_feat[..., 24 + 8 : 24 + 8 + 3]
    if 'D' in other_output:
        renderer_other_output['D'] = pts_feat[..., 24 + 8 + 3 + 1 : 24 + 8 + 3 + 1 + 3]  # disocc_confi (level 0, 1, 2)
    if 'density' in other_output:
        renderer_other_output['density'] = rgb_sigma[..., 3]
    if 'mask' in other_output:
        renderer_other_output['mask'] = pts_feat[..., 24 + 8 + 3 : 24 + 8 + 3 + 1] * occ_masks  # (bs, nb_sample, nb_views, 1)
    if 'nan_len' in other_output:
        renderer_other_output['nan_len'] = nan_len

    return rendered_rgb, rendered_depth, renderer_other_output
