"""
NeRF differentiable renderer.
References:
https://github.com/bmild/nerf
https://github.com/kwea123/nerf_pl
"""
import torch
import torch.nn.functional as F
import util
import torch.autograd.profiler as profiler
from torch.nn import DataParallel
from dotmap import DotMap
import numpy as np
from .common import (
    get_mask, image_points_to_world, origin_to_world)


def save_pc(PC, PC_color, filename):
    """
    Write a colored point cloud to a PLY file.
    :param PC array with one row per point; expected to provide the x, y, z columns
    :param PC_color array with one row per point; expected to provide the
    red, green, blue columns (stored as unsigned 8-bit values)
    :param filename destination handed to PlyData.write
    """
    from plyfile import PlyData, PlyElement

    vertex_dtype = [
        ('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
        ('red', 'u1'), ('green', 'u1'), ('blue', 'u1'),
    ]
    # One tuple per point so numpy can build a structured (record) array
    rows = list(map(tuple, np.concatenate((PC, PC_color), axis=1)))
    vertices = np.array(rows, dtype=vertex_dtype)
    PlyData([PlyElement.describe(vertices, 'vertex')]).write(filename)

class _RenderWrapper(torch.nn.Module):
    """
    Module that pairs a fixed network with a renderer so that rendering can be
    triggered through a plain forward call (e.g. from DataParallel).
    """

    def __init__(self, net, renderer, simple_output):
        super().__init__()
        self.net = net
        self.renderer = renderer
        self.simple_output = simple_output

    def forward(self, rays, want_weights=False, **kwargs):
        """
        Render the given rays with the bound network.
        :param rays ray tensor forwarded to the renderer
        :param want_weights request compositing weights; ignored when
        simple_output is set
        :param kwargs forwarded to the renderer unchanged
        :return a pair of empty tensors ((0, 3) and (0,)) when rays has no
        entries along dim 0; otherwise the fine (or coarse, if the renderer is
        not using fine) output when simple_output is set, else the full render
        output converted to a dict
        """
        if rays.shape[0] == 0:
            # Nothing to render: hand back empty placeholders on the rays' device
            target = rays.device
            return (
                torch.zeros(0, 3, device=target),
                torch.zeros(0, device=target),
            )

        # Weights are never part of the simple output, so do not request them
        need_weights = want_weights and not self.simple_output
        outputs = self.renderer(
            self.net, rays, want_weights=need_weights, **kwargs
        )

        if not self.simple_output:
            # Make DotMap to dict to support DataParallel
            return outputs.toDict()
        return outputs.fine if self.renderer.using_fine else outputs.coarse

class _UniRenderWrapper(torch.nn.Module):
    """
    Like a plain render wrapper, but re-encodes the source views on the bound
    network before every render call.
    """

    def __init__(self, net, renderer, simple_output):
        super().__init__()
        self.net = net
        self.renderer = renderer
        self.simple_output = simple_output

    def forward(self, rays, src_images, src_poses,
                all_focals, c, want_weights=False, **kwargs):
        """
        Encode the source views, then render the given rays.
        :param rays ray tensor forwarded to the renderer
        :param src_images, src_poses, all_focals, c passed to net.encode
        :param want_weights request compositing weights; ignored when
        simple_output is set
        :param kwargs forwarded to the renderer unchanged
        :return a pair of empty tensors ((0, 3) and (0,)) when rays has no
        entries along dim 0 (encode is skipped in that case); otherwise the
        fine/coarse output when simple_output is set, else the full render
        output converted to a dict
        """
        if rays.shape[0] == 0:
            target = rays.device
            return (
                torch.zeros(0, 3, device=target),
                torch.zeros(0, device=target),
            )

        self.net.encode(src_images, src_poses, all_focals, c=c)

        # Weights are never part of the simple output, so do not request them
        need_weights = want_weights and not self.simple_output
        outputs = self.renderer(
            self.net, rays, want_weights=need_weights, **kwargs
        )

        if not self.simple_output:
            # Make DotMap to dict to support DataParallel
            return outputs.toDict()
        return outputs.fine if self.renderer.using_fine else outputs.coarse

class _UniPtsWrapper(torch.nn.Module):
    """
    Module that encodes the source views on the bound network and then
    evaluates the network directly on the given input.
    """

    def __init__(self, net):
        super().__init__()
        self.net = net

    def forward(self, rays, src_images, src_poses,
                all_focals, c, coarse, viewdirs, **kwargs):
        """
        :param rays first positional input handed to the network
        :param src_images, src_poses, all_focals, c passed to net.encode
        :param coarse, viewdirs forwarded to the network as keyword arguments
        :param kwargs forwarded to the network unchanged
        :return whatever the network returns
        """
        self.net.encode(src_images, src_poses, all_focals, c=c)
        return self.net(rays, coarse=coarse, viewdirs=viewdirs, **kwargs)

class _PtsWrapper(torch.nn.Module):
    """
    Thin module around a fixed network so that direct network evaluation can
    be wrapped (e.g. by DataParallel).
    """

    def __init__(self, net):
        super().__init__()
        self.net = net

    def forward(self, rays, coarse, viewdirs, **kwargs):
        """
        :param rays first positional input handed to the network
        :param coarse, viewdirs forwarded to the network as keyword arguments
        :param kwargs forwarded to the network unchanged
        :return whatever the network returns
        """
        return self.net(rays, coarse=coarse, viewdirs=viewdirs, **kwargs)

class NeRFRenderer(torch.nn.Module):
    """
    NeRF differentiable renderer
    :param n_coarse number of coarse (binned uniform) samples
    :param n_fine number of fine (importance) samples
    :param n_fine_depth number of expected depth samples
    :param noise_std noise to add to sigma. We do not use it
    :param depth_std noise for depth samples
    :param eval_batch_size ray batch size for evaluation
    :param white_bkgd if true, background color is white; else black
    :param lindisp whether to use samples linear in disparity instead of distance
    :param sched ray sampling schedule. List containing 3 lists of equal length:
    sched[0] is list of iteration numbers,
    sched[1] is list of coarse sample numbers,
    sched[2] is list of fine sample numbers
    :param n_classes number of segmentation channels, read from the last
    n_classes channels of the model output
    :param use_rgb_head if true, composite and return the rgb output
    :param use_seg_head if true, composite and return the seg output
    """

    def __init__(
        self,
        n_coarse=128,
        n_fine=0,
        n_fine_depth=0,
        noise_std=0.0,
        depth_std=0.01,
        eval_batch_size=100000,
        white_bkgd=False,
        lindisp=False,
        sched=None,  # ray sampling schedule for coarse and fine rays
        n_classes=2,
        use_rgb_head=True,
        use_seg_head=True,
    ):
        super().__init__()
        self.n_coarse = n_coarse
        self.n_fine = n_fine
        self.n_fine_depth = n_fine_depth

        self.noise_std = noise_std
        self.depth_std = depth_std

        self.eval_batch_size = eval_batch_size
        self.white_bkgd = white_bkgd
        self.lindisp = lindisp
        self.n_classes = n_classes
        self.use_rgb_head = use_rgb_head
        self.use_seg_head = use_seg_head

        if lindisp:
            print("Using linear displacement rays")
        self.using_fine = n_fine > 0
        self.sched = sched
        if sched is not None and len(sched) == 0:
            self.sched = None
        self.register_buffer(
            "iter_idx", torch.tensor(0, dtype=torch.long), persistent=True
        )
        self.register_buffer(
            "last_sched", torch.tensor(0, dtype=torch.long), persistent=True
        )

    def sample_coarse(self, rays):
        """
        Stratified sampling. Note this is different from original NeRF slightly.
        :param rays ray [origins (3), directions (3), near (1), far (1)] (B, 8)
        :return (B, Kc)
        """
        device = rays.device
        near, far = rays[:, -2:-1], rays[:, -1:]  # (B, 1)

        step = 1.0 / self.n_coarse
        B = rays.shape[0]
        z_steps = torch.linspace(0, 1 - step, self.n_coarse, device=device)  # (Kc)
        z_steps = z_steps.unsqueeze(0).repeat(B, 1)  # (B, Kc)
        z_steps += torch.rand_like(z_steps) * step
        if not self.lindisp:  # Use linear sampling in depth space
            return near * (1 - z_steps) + far * z_steps  # (B, Kf)
        else:  # Use linear sampling in disparity space
            return 1 / (1 / near * (1 - z_steps) + 1 / far * z_steps)  # (B, Kf)

        # Use linear sampling in depth space
        return near * (1 - z_steps) + far * z_steps  # (B, Kc)

    def sample_fine(self, rays, weights):
        """
        Weighted stratified (importance) sample
        :param rays ray [origins (3), directions (3), near (1), far (1)] (B, 8)
        :param weights (B, Kc)
        :return (B, Kf-Kfd)
        """
        device = rays.device
        B = rays.shape[0]

        weights = weights.detach() + 1e-5  # Prevent division by zero
        pdf = weights / torch.sum(weights, -1, keepdim=True)  # (B, Kc)
        cdf = torch.cumsum(pdf, -1)  # (B, Kc)
        cdf = torch.cat([torch.zeros_like(cdf[:, :1]), cdf], -1)  # (B, Kc+1)

        u = torch.rand(
            B, self.n_fine - self.n_fine_depth, dtype=torch.float32, device=device
        )  # (B, Kf)
        inds = torch.searchsorted(cdf, u, right=True).float() - 1.0  # (B, Kf)
        inds = torch.clamp_min(inds, 0.0)

        z_steps = (inds + torch.rand_like(inds)) / self.n_coarse  # (B, Kf)

        near, far = rays[:, -2:-1], rays[:, -1:]  # (B, 1)
        if not self.lindisp:  # Use linear sampling in depth space
            z_samp = near * (1 - z_steps) + far * z_steps  # (B, Kf)
        else:  # Use linear sampling in disparity space
            z_samp = 1 / (1 / near * (1 - z_steps) + 1 / far * z_steps)  # (B, Kf)
        return z_samp

    def sample_fine_depth(self, rays, depth):
        """
        Sample around specified depth
        :param rays ray [origins (3), directions (3), near (1), far (1)] (B, 8)
        :param depth (B)
        :return (B, Kfd)
        """
        z_samp = depth.unsqueeze(1).repeat((1, self.n_fine_depth))
        z_samp += torch.randn_like(z_samp) * self.depth_std
        # Clamp does not support tensor bounds
        z_samp = torch.max(torch.min(z_samp, rays[:, -1:]), rays[:, -2:-1])
        return z_samp

    def composite(self, model, rays, z_samp, coarse=True, sb=0):
        """
        Render RGB and depth for each ray using NeRF alpha-compositing formula,
        given sampled positions along each ray (see sample_*)
        :param model should return (B, (r, g, b, sigma)) when called with (B, (x, y, z))
        should also support 'coarse' boolean argument
        :param rays ray [origins (3), directions (3), near (1), far (1)] (B, 8)
        :param z_samp z positions sampled for each ray (B, K)
        :param coarse whether to evaluate using coarse NeRF
        :param sb super-batch dimension; 0 = disable
        :return weights (B, K), rgb (B, 3), depth (B)
        """
        with profiler.record_function("renderer_composite"):
            B, K = z_samp.shape

            deltas = z_samp[:, 1:] - z_samp[:, :-1]  # (B, K-1)
            #  if far:
            #      delta_inf = 1e10 * torch.ones_like(deltas[:, :1])  # infty (B, 1)
            delta_inf = rays[:, -1:] - z_samp[:, -1:]
            deltas = torch.cat([deltas, delta_inf], -1)  # (B, K)

            # (B, K, 3)
            points = rays[:, None, :3] + z_samp.unsqueeze(2) * rays[:, None, 3:6]
            points = points.reshape(-1, 3)  # (B*K, 3)

            use_viewdirs = hasattr(model, "use_viewdirs") and model.use_viewdirs
            # print(use_viewdirs)

            val_all = []
            if sb > 0:
                points = points.reshape(
                    sb, -1, 3
                )  # (SB, B'*K, 3) B' is real ray batch size
                eval_batch_size = (self.eval_batch_size - 1) // sb + 1
                eval_batch_dim = 1
            else:
                eval_batch_size = self.eval_batch_size
                eval_batch_dim = 0

            split_points = torch.split(points, eval_batch_size, dim=eval_batch_dim)
            if use_viewdirs:
                dim1 = K
                viewdirs = rays[:, None, 3:6].expand(-1, dim1, -1)  # (B, K, 3)
                if sb > 0:
                    viewdirs = viewdirs.reshape(sb, -1, 3)  # (SB, B'*K, 3)
                else:
                    viewdirs = viewdirs.reshape(-1, 3)  # (B*K, 3)
                split_viewdirs = torch.split(
                    viewdirs, eval_batch_size, dim=eval_batch_dim
                )
                for pnts, dirs in zip(split_points, split_viewdirs):
                    val_all.append(model(pnts, coarse=coarse, viewdirs=dirs))
            else:
                for pnts in split_points:
                    val_all.append(model(pnts, coarse=coarse))
            points = None
            viewdirs = None
            # (B*K, 4) OR (SB, B'*K, 4)
            out = torch.cat(val_all, dim=eval_batch_dim)
            out = out.reshape(B, K, -1)  # (B, K, 4 or 5)

            sigmas = out[..., 0]  # (B, K)
            if self.training and self.noise_std > 0.0:
                sigmas = sigmas + torch.randn_like(sigmas) * self.noise_std

            alphas = 1 - torch.exp(-deltas * torch.relu(sigmas))  # (B, K)
            deltas = None
            sigmas = None

            alphas_shifted = torch.cat(
                [torch.ones_like(alphas[:, :1]), 1 - alphas + 1e-10], -1
            )  # (B, K+1) = [1, a1, a2, ...]
            T = torch.cumprod(alphas_shifted, -1)  # (B)
            weights = alphas * T[:, :-1]  # (B, K)
            alphas = None
            alphas_shifted = None
            depth_final = torch.sum(weights * z_samp, -1)  # (B)

            out_dict = {"weights": weights,
                        "depth": depth_final}

            if self.use_rgb_head:
                rgbs = out[..., 1:4]  # (B, K, 3)
                rgb_final = torch.sum(weights.unsqueeze(-1) * rgbs, -2)  # (B, 3)
                if self.white_bkgd:
                    # White background
                    pix_alpha = weights.sum(dim=1)  # (B), pixel alpha
                    rgb_final = rgb_final + 1 - pix_alpha.unsqueeze(-1)  # (B, 3)
                out_dict["rgb"] = rgb_final

            if self.use_seg_head:  
                seg = out[..., -self.n_classes:] # (B, K, C)
                seg_final = torch.sum(weights.unsqueeze(-1) * seg, -2)  # (B, 3)
                out_dict["seg"] = seg_final


            return out_dict

    def forward(
        self, model, rays, want_weights=False, **kwargs
    ):
        """
        Render a super-batch of rays: a coarse pass always, followed by a fine
        pass over the merged coarse + fine samples when using_fine is set.
        :model nerf model evaluated on (SB, B, 3) point batches; for multi-object:
        SB = 'super-batch' = size of object batch,
        B  = size of per-object ray batch.
        Should also support 'coarse' boolean argument for coarse NeRF.
        :param rays ray spec [origins (3), directions (3), near (1), far (1)] (SB, B, 8)
        :param want_weights if true, returns compositing weights (SB, B, K)
        :param kwargs accepted but not used by this method
        :return render dict (DotMap) with a 'coarse' entry and, when the fine
        pass ran, a 'fine' entry
        """
        with profiler.record_function("renderer_forward"):
            if self.sched is not None and self.last_sched.item() > 0:
                # Re-apply the sample counts of the schedule stage reached so far
                stage = self.last_sched.item() - 1
                self.n_coarse = self.sched[1][stage]
                self.n_fine = self.sched[2][stage]

            assert len(rays.shape) == 3
            superbatch_size = rays.shape[0]
            rays = rays.reshape(-1, 8)  # (SB * B, 8)

            # Coarse pass
            z_coarse = self.sample_coarse(rays)  # (B, Kc)
            coarse_composite = self.composite(
                model, rays, z_coarse, coarse=True, sb=superbatch_size,
            )
            coarse_formatted = self._format_outputs(
                coarse_composite, superbatch_size, want_weights=want_weights,
            )
            outputs = DotMap(coarse=coarse_formatted)

            if not self.using_fine:
                return outputs

            # Fine pass: coarse samples plus importance / depth-centered samples
            all_samps = [z_coarse]
            if self.n_fine - self.n_fine_depth > 0:
                importance = self.sample_fine(
                    rays, coarse_composite["weights"].detach()
                )  # (B, Kf - Kfd)
                all_samps.append(importance)
            if self.n_fine_depth > 0:
                around_depth = self.sample_fine_depth(
                    rays, coarse_composite["depth"]
                )  # (B, Kfd)
                all_samps.append(around_depth)

            # Compositing needs the samples ordered along the ray
            z_combine_sorted, _ = torch.sort(
                torch.cat(all_samps, dim=-1), dim=-1
            )  # (B, Kc + Kf)
            fine_composite = self.composite(
                model, rays, z_combine_sorted, coarse=False, sb=superbatch_size,
            )
            outputs.fine = self._format_outputs(
                fine_composite, superbatch_size, want_weights=want_weights,
            )
            return outputs

    def _format_outputs(
        self, rendered_outputs, superbatch_size, want_weights=False,
    ):
        """
        Pack a composite() result into a DotMap, restoring the super-batch
        dimension when superbatch_size > 0.
        :param rendered_outputs dict with "depth" and, as needed, "weights",
        "rgb" and "seg"
        :param superbatch_size leading dimension to split the flat ray batch into
        :param want_weights if true, include the compositing weights
        :return DotMap with depth and the requested / enabled entries
        """
        unflatten = superbatch_size > 0

        depth = rendered_outputs["depth"]
        ret_dict = DotMap(
            depth=depth.reshape(superbatch_size, -1) if unflatten else depth
        )

        if want_weights:
            weights = rendered_outputs["weights"]
            if unflatten:
                weights = weights.reshape(superbatch_size, -1, weights.shape[-1])
            ret_dict.weights = weights

        if self.use_rgb_head:
            rgb = rendered_outputs["rgb"]
            ret_dict.rgb = rgb.reshape(superbatch_size, -1, 3) if unflatten else rgb

        if self.use_seg_head:
            seg = rendered_outputs["seg"]
            if unflatten:
                seg = seg.reshape(superbatch_size, -1, self.n_classes)
            ret_dict.seg = seg

        return ret_dict

    def sched_step(self, steps=1):
        """
        Called each training iteration to update sample numbers
        according to schedule
        """
        if self.sched is None:
            return
        self.iter_idx += steps
        while (
            self.last_sched.item() < len(self.sched[0])
            and self.iter_idx.item() >= self.sched[0][self.last_sched.item()]
        ):
            self.n_coarse = self.sched[1][self.last_sched.item()]
            self.n_fine = self.sched[2][self.last_sched.item()]
            print(
                "INFO: NeRF sampling resolution changed on schedule ==> c",
                self.n_coarse,
                "f",
                self.n_fine,
            )
            self.last_sched += 1

    @classmethod
    def from_conf(cls, conf, white_bkgd=False, lindisp=False, eval_batch_size=100000):
        return cls(
            conf.get_int("n_coarse", 128),
            conf.get_int("n_fine", 0),
            n_fine_depth=conf.get_int("n_fine_depth", 0),
            noise_std=conf.get_float("noise_std", 0.0),
            depth_std=conf.get_float("depth_std", 0.01),
            white_bkgd=conf.get_float("white_bkgd", white_bkgd),
            lindisp=lindisp,
            eval_batch_size=conf.get_int("eval_batch_size", eval_batch_size),
            sched=conf.get_list("sched", None),
            n_classes=conf.get_int("n_classes", 1),
            use_rgb_head=conf.get_bool("use_rgb_head", True),
            use_seg_head=conf.get_bool("use_seg_head", True),
        )

    def bind_parallel(self, net, gpus=None, simple_output=False):
        """
        Returns a wrapper module compatible with DataParallel.
        Specifically, it renders rays with this renderer
        but always using the given network instance.
        Specify a list of GPU ids in 'gpus' to apply DataParallel automatically.
        :param net A PixelNeRF network
        :param gpus list of GPU ids to parallelize to. If None or of length
        at most 1, does not parallelize
        :param simple_output only returns the fine (or coarse) render output
        instead of the full render output map. Saves data transfer cost.
        :return torch module
        """
        wrapped = _RenderWrapper(net, self, simple_output=simple_output)
        multi_gpu = gpus is not None and len(gpus) > 1
        if not multi_gpu:
            return wrapped
        print("Using multi-GPU", gpus)
        # Inputs are scattered along dim 1
        return torch.nn.DataParallel(wrapped, gpus, dim=1)

    def bind_pts_parallel(self, net, gpus=None):
        """
        Returns a wrapper module compatible with DataParallel that evaluates
        the given network instance directly (the renderer itself is not
        involved in the wrapped call).
        Specify a list of GPU ids in 'gpus' to apply DataParallel automatically.
        :param net A PixelNeRF network
        :param gpus list of GPU ids to parallelize to. If None or of length
        at most 1, does not parallelize
        :return torch module
        """
        wrapped = _PtsWrapper(net)
        multi_gpu = gpus is not None and len(gpus) > 1
        if not multi_gpu:
            return wrapped
        print("Using multi-GPU", gpus)
        # Inputs are scattered along dim 1
        return torch.nn.DataParallel(wrapped, gpus, dim=1)

class UnisurfRenderer(torch.nn.Module):
    def __init__(
        self,
        n_coarse=128,
        n_fine=0,
        n_fine_depth=0,
        noise_std=0.0,
        depth_std=0.01,
        eval_batch_size=100000,
        white_bkgd=False,
        lindisp=False,
        sched=None,  # ray sampling schedule for coarse and fine rays
        n_classes=2,
        use_rgb_head=True,
        use_seg_head=True,
    ):
        super().__init__()

        # self._device = device
        self.depth_range = [0.01, 4]
        self.n_max_network_queries = 64000
        self.white_background = True
        # self.cfg=cfg

        # self.model = model.to(device)
        # if model_bg is not None:
        #     self.model_bg = model.to(device)
        # else:
        #     self.model_bg = None


        self.n_coarse = n_coarse
        self.n_fine = n_fine
        self.n_fine_depth = n_fine_depth

        self.noise_std = noise_std
        self.depth_std = depth_std

        self.eval_batch_size = eval_batch_size
        self.white_bkgd = white_bkgd
        self.lindisp = lindisp
        self.n_classes = n_classes
        self.use_rgb_head = use_rgb_head
        self.use_seg_head = use_seg_head

        if lindisp:
            print("Using linear displacement rays")
        # self.using_fine = n_fine > 0
        self.using_fine = False #DONT USE FINE
        self.sched = sched
        if sched is not None and len(sched) == 0:
            self.sched = None
        self.register_buffer(
            "iter_idx", torch.tensor(0, dtype=torch.long), persistent=True
        )
        self.register_buffer(
            "last_sched", torch.tensor(0, dtype=torch.long), persistent=True
        )

    # def forward(self, model, pixels, add_noise=False, it=100000, eval_=False):
    def forward(self, model, pixels, want_weights=False, add_noise=False, it=100000, eval_=False, gt_img=None, **kwargs):
        """
        Render pixels by first locating a surface depth along every ray with
        ray_marching, then sampling points along the ray (in an interval around
        the surface where one was found, over the whole ray otherwise),
        querying the model at those points and alpha-compositing the results.

        :param model object providing pixelnerf.poses / pixelnerf.focal /
        pixelnerf.c, target_poses, gradient(points), and callable as
        model(points, viewdirs=..., return_addocc=True) returning a pair
        (dict with 'rgb' and/or 'seg', per-point alpha)
        :param pixels (batch_size, n_points, C); channel 0 is cast to long and
        used to index model.target_poses, the remaining channels are handed to
        image_points_to_world as pixel coordinates
        :param want_weights not read by this method
        :param add_noise if true, jitter the sample depths between the
        midpoints of neighbouring samples
        :param it iteration number; shrinks the sampling interval around the
        surface and, above 5000, can enable extra samples in front of it
        :param eval_ if true, skip the normal-difference term
        :param gt_img not read by this method
        :param kwargs not read by this method
        :return DotMap whose 'coarse' entry holds weights, depth, rgb and/or
        seg (depending on the enabled heads), diff_norm and mask_pred
        """
        epsilon = 1e-6
        device = model.pixelnerf.poses.device
        # Camera matrix assembled from the first entry of focal / c only;
        # .item() pulls the values out as Python floats.
        camera_mat = torch.FloatTensor([[1/model.pixelnerf.focal[0,0].item(), 0, -model.pixelnerf.c[0,0].item()/model.pixelnerf.focal[0,0].item(), 0],
                                      [0, -1/model.pixelnerf.focal[0,1].item(), model.pixelnerf.c[0,1].item()/model.pixelnerf.focal[0,1].item(), 0],
                                      [0, 0, -1, 0],
                                      [0, 0, 0, 1],
                                    ]).to(device)
        scale_mat = torch.eye(4).to(device)
        batch_size, n_points, _ = pixels.shape
        
        # Hard-coded sampling settings
        rad = 0.99  # sphere radius used for the ray/sphere intersection and ray marching
        ada_start = 2.0  # interval half-width at it = 0
        ada_end = 0.1  # lower bound of the interval half-width
        ada_grad = 0.000015  # exponential decay rate of the half-width per iteration
        steps = 64  # samples inside the interval around the surface
        steps_outside = 32  # optional extra samples in front of that interval
        ray_steps = 256  # step count handed to ray_marching
        depth_range = torch.tensor(self.depth_range)
        n_max_network_queries = self.n_max_network_queries
        
        # Prepare camera projection
        view_inds = pixels[...,0].long()
        pixels = pixels[...,1:]
        world_mat = model.target_poses[view_inds].to(device)
        # Tile the shared camera / scale matrices to match world_mat's two leading dims
        camera_mat = camera_mat.reshape(1,1,4,4).repeat(world_mat.shape[0], world_mat.shape[1], 1, 1)
        scale_mat = scale_mat.reshape(1,1,4,4).repeat(world_mat.shape[0], world_mat.shape[1], 1, 1)
        pixels_world = image_points_to_world(
            pixels, camera_mat, world_mat,scale_mat, invert=False
        )
        camera_world = origin_to_world(
            n_points, camera_mat, world_mat, scale_mat, invert=False
        )

        # Unit-length ray directions from the camera origin through each pixel
        ray_vector = (pixels_world - camera_world)
        ray_vector = ray_vector/ray_vector.norm(2,2).unsqueeze(-1)
        
        # Get sphere intersection
        # NOTE(review): get_sphere_intersection is not among the imports visible
        # at the top of this file - presumably defined elsewhere in the module.
        depth_intersect,_ = get_sphere_intersection(
            camera_world, ray_vector, r=rad
        )

        # Find surface
        with torch.no_grad():
            d_i = self.ray_marching(
                camera_world, ray_vector, model,
                n_secant_steps=8, 
                n_steps=[int(ray_steps),int(ray_steps)+1], 
                rad=rad
            )

            # Get mask for where first evaluation point is occupied
            mask_zero_occupied = d_i == 0
            d_i = d_i.detach()

            # Get mask for predicted depth
            mask_pred = get_mask(d_i).detach()
        
            # Still inside the no_grad block: depth per ray, 1 where no depth
            # was predicted and 0 where the first evaluation point is occupied
            dists =  torch.ones_like(d_i).to(device)
            dists[mask_pred] = d_i[mask_pred]
            dists[mask_zero_occupied] = 0.
            network_object_mask = mask_pred & ~mask_zero_occupied
            # NOTE(review): only index 0 of the leading dim is kept here, while
            # later tensors are sized batch_size * n_points - looks like this
            # assumes batch_size == 1; confirm.
            network_object_mask = network_object_mask[0]
            dists = dists[0]

        # Project depth to 3d points
        camera_world = camera_world.reshape(-1, 3)
        ray_vector = ray_vector.reshape(-1, 3)
       
        points = camera_world + ray_vector * dists.unsqueeze(-1)
        points = points.view(-1,3)

        # Define interval
        # Overwrite the first intersection value of every ray with 0
        depth_intersect[:,:,0] = torch.Tensor([0.0]).to(device)
        dists_intersect = depth_intersect.reshape(-1, 2)

        d_inter = dists[network_object_mask]
        d_sphere_surf = dists_intersect[network_object_mask][:,1]
        # Interval half-width: max(ada_start * exp(-ada_grad * it), ada_end)
        delta = torch.max(ada_start * torch.exp(-1 * ada_grad * it * torch.ones(1)),\
             ada_end * torch.ones(1)).to(device)

        # Near / far ends of the interval around the surface, limited to
        # depth_range[0] in front and the second sphere intersection behind
        dnp = d_inter - delta
        dfp = d_inter + delta
        dnp = torch.where(dnp < depth_range[0].float().to(device),\
            depth_range[0].float().to(device), dnp)
        dfp = torch.where(dfp >  d_sphere_surf,  d_sphere_surf, dfp)
        # Late in training (it > 5000), add steps_outside extra samples per ray
        if (dnp!=0.0).all() and it > 5000:
            full_steps = steps+steps_outside
        else:
            full_steps = steps

        d_nointer = dists_intersect[~network_object_mask]

        # Rays without a surface: full_steps depths spread linearly from
        # depth_range[0] to the second sphere intersection
        d2 = torch.linspace(0., 1., steps=full_steps, device=device)
        d2 = d2.view(1, 1, -1).repeat(batch_size, d_nointer.shape[0], 1)
        d2 = depth_range[0] * (1. - d2) + d_nointer[:,1].view(1, -1, 1)* d2

        if add_noise:
            # Jitter every depth uniformly between the midpoints to its neighbours
            di_mid = .5 * (d2[:, :, 1:] + d2[:, :, :-1])
            di_high = torch.cat([di_mid, d2[:, :, -1:]], dim=-1)
            di_low = torch.cat([d2[:, :, :1], di_mid], dim=-1)
            noise = torch.rand(batch_size, d2.shape[1], full_steps, device=device)
            d2 = di_low + (di_high - di_low) * noise 
        
        p_noiter = camera_world[~network_object_mask].unsqueeze(-2) \
            + ray_vector[~network_object_mask].unsqueeze(-2) * d2.unsqueeze(-1)
        p_noiter = p_noiter.reshape(-1, 3)

        # Sampling region with surface intersection        
        d_interval = torch.linspace(0., 1., steps=steps, device=device)
        d_interval = d_interval.view(1, 1, -1).repeat(batch_size, d_inter.shape[0], 1)        
        d_interval = (dnp).view(1, -1, 1) * (1. - d_interval) + (dfp).view(1, -1, 1) * d_interval

        if full_steps != steps:
            # Extra samples between depth_range[0] and the near end of the
            # interval, merged with the interval samples in sorted order
            d_binterval = torch.linspace(0., 1., steps=steps_outside, device=device)
            d_binterval = d_binterval.view(1, 1, -1).repeat(batch_size, d_inter.shape[0], 1)
            d_binterval =  depth_range[0] * (1. - d_binterval) + (dnp).view(1, -1, 1)* d_binterval
            d1,_ = torch.sort(torch.cat([d_binterval, d_interval],dim=-1), dim=-1)
        else:
            d1 = d_interval

        if add_noise:
            # Same midpoint jitter as for the rays without a surface
            di_mid = .5 * (d1[:, :, 1:] + d1[:, :, :-1])
            di_high = torch.cat([di_mid, d1[:, :, -1:]], dim=-1)
            di_low = torch.cat([d1[:, :, :1], di_mid], dim=-1)
            noise = torch.rand(batch_size, d1.shape[1], full_steps, device=device)
            d1 = di_low + (di_high - di_low) * noise 

        p_iter = camera_world[network_object_mask].unsqueeze(-2)\
             + ray_vector[network_object_mask].unsqueeze(-2) * d1.unsqueeze(-1)
        p_iter = p_iter.reshape(-1, 3)
        

        # Merge rendering points
        p_fg = torch.zeros(batch_size * n_points, full_steps, 3, device=device)
        p_fg[~network_object_mask] =  p_noiter.view(-1, full_steps,3)
        p_fg[network_object_mask] =  p_iter.view(-1, full_steps,3)
        p_fg = p_fg.reshape(1, -1, 3)
        # One (negated) ray direction per sample point, used as view direction
        ray_vector_fg = ray_vector.unsqueeze(-2).repeat(1, 1, full_steps, 1)
        ray_vector_fg = -1*ray_vector_fg.reshape(-1, 3)

        # Run Network
        # NOTE(review): `noise` is not read again after this assignment.
        noise = not eval_
        if self.use_rgb_head:
            rgb_fg = []
        if self.use_seg_head:
            seg_fg = []
        logits_alpha_fg = []
        # Query the model in chunks of at most n_max_network_queries points
        for i in range(0, p_fg.shape[1], n_max_network_queries):
            out_i, logits_alpha_i = model(
                p_fg[:,i:i+n_max_network_queries], 
                viewdirs=ray_vector_fg[i:i+n_max_network_queries].unsqueeze(0), 
                return_addocc=True
            )
            if self.use_rgb_head:
                rgb_fg.append(out_i['rgb'])
            if self.use_seg_head:
                seg_fg.append(out_i['seg'])
            logits_alpha_fg.append(logits_alpha_i)
        logits_alpha_fg = torch.cat(logits_alpha_fg, dim=1)
        alpha = logits_alpha_fg.view(batch_size * n_points, full_steps)
        # Compositing weights: alpha times the running product of
        # (1 - alpha + epsilon) over the preceding samples
        weights = alpha * torch.cumprod(
                                        torch.cat(
                                                  [torch.ones((alpha.shape[0], 1), device=device),
                                                   1.-alpha + epsilon],-1), -1)[:, :-1]

        # Replace NaN weights with zero (in place)
        weights[weights.isnan()] = 0.
        

        out_dict = DotMap()
        out_dict.coarse.weights = weights.unsqueeze(0)
        out_dict.coarse.depth = d_i
        if self.use_rgb_head:
            rgb_fg = torch.cat(rgb_fg, dim=1)
            rgb = rgb_fg.reshape(batch_size * n_points, full_steps, -1)
            rgb_values = torch.sum(weights.unsqueeze(-1) * rgb, dim=-2)
            if self.white_background:
                # Add (1 - accumulated weight) to every channel
                acc_map = torch.sum(weights, -1)
                rgb_values = rgb_values + (1. - acc_map.unsqueeze(-1))
            out_dict.coarse.rgb = rgb_values.reshape(batch_size, n_points, -1)
        if self.use_seg_head:
            seg_fg = torch.cat(seg_fg, dim=1)
            seg = seg_fg.reshape(batch_size * n_points, full_steps, -1)
            seg_values = torch.sum(weights.unsqueeze(-1) * seg, dim=-2)
            out_dict.coarse.seg = seg_values.reshape(batch_size, n_points, -1)

        if not eval_ and network_object_mask.sum() > 0:
            # Compare the normalized model gradient at every surface point with
            # the one at a randomly perturbed neighbour (offset within +-0.005
            # per axis)
            surface_mask = network_object_mask.view(-1)
            surface_points = points[surface_mask]
            N = surface_points.shape[0]
            surface_points_neig = surface_points + (torch.rand_like(surface_points) - 0.5) * 0.01      
            pp = torch.cat([surface_points, surface_points_neig], dim=0)
            g = model.gradient(pp.unsqueeze(0)).squeeze()
            normals_ = g / (g.norm(2, dim=1).unsqueeze(-1) + 10**(-5))
            diff_norm =  torch.norm(normals_[:N] - normals_[N:], dim=-1)
            out_dict.coarse.diff_norm = diff_norm
        else:
            out_dict.coarse.diff_norm = torch.tensor([0.0]).to(device)

        out_dict.coarse.mask_pred = network_object_mask
        return out_dict

    def ray_marching(self, ray0, ray_direction, model, c=None,
                             tau=0.5, n_steps=(128, 129), n_secant_steps=8,
                             depth_range=(0.01, 2.4), max_points=350000, rad=1.0):
        ''' Performs ray marching to detect surface points.

        Every ray is sampled between depth_range[0] and its far intersection
        with a sphere of radius rad. The occupancy (minus tau) is evaluated at
        all samples, the first sign change along each ray is located and then
        refined with the secant method. The returned d_i satisfy
            ray(d_i) = ray0 + d_i * ray_direction

        Args:
            ray0 (tensor): ray start points of dimension B x N x 3
            ray_direction (tensor): ray direction vectors of dim B x N x 3
            model (nn.Module): model to evaluate point occupancies; called as
                model(points, only_occupancy=True)
            c (tensor): latent conditioned code (accepted for interface
                compatibility; not used by this method)
            tau (float): occupancy threshold defining the surface
            n_steps (tuple): half-open interval [low, high) from which the
                number of evaluation steps is sampled
            n_secant_steps (int): number of secant refinement steps
            depth_range (tuple): only depth_range[0] is used, as the depth of
                the first sample on every ray
            max_points (int): max number of points evaluated by the model in
                one call
            rad (float): radius of the bounding sphere
        Returns:
            tensor: B x N depth values. Rays with a valid surface hit hold the
                refined depth, rays without one hold inf, and rays whose first
                sample is already occupied hold 0.
        '''
        batch_size, n_pts, _ = ray0.shape
        device = ray0.device
        n_steps = torch.randint(n_steps[0], n_steps[1], (1,)).item()

        # Far intersection with the bounding sphere is the last sample depth.
        depth_intersect, _ = get_sphere_intersection(ray0, ray_direction, r=rad)
        d_intersect = depth_intersect[..., 1]

        # Linear interpolation between the near depth and the far intersection.
        t = torch.linspace(0, 1, steps=n_steps, device=device).view(
            1, 1, n_steps, 1)
        d_proposal = depth_range[0] * (1. - t) + \
            d_intersect.reshape(batch_size, n_pts, 1, 1) * t

        # Broadcasting yields B x N x n_steps x 3 without explicit repeats.
        p_proposal = ray0.unsqueeze(2) + ray_direction.unsqueeze(2) * d_proposal

        # Evaluate all proposal points in chunks that fit into memory.
        points_per_split = max(1, int(max_points / batch_size))
        with torch.no_grad():
            p_splits = torch.split(
                p_proposal.reshape(batch_size, -1, 3), points_per_split, dim=1)
            val = torch.cat(
                [model(p_split, only_occupancy=True) - tau
                 for p_split in p_splits],
                dim=1).view(batch_size, -1, n_steps)

        # Rays whose very first sample is already inside are not valid.
        mask_0_not_occupied = val[:, :, 0] < 0

        # -1 where consecutive samples differ in sign; a trailing 1 (no sign
        # change) pads the last step.
        sign_matrix = torch.cat(
            [torch.sign(val[:, :, :-1] * val[:, :, 1:]),
             torch.ones(batch_size, n_pts, 1, device=device)],
            dim=-1)
        # Decreasing positive weights make the minimum select the FIRST sign
        # change along each ray.
        cost_matrix = sign_matrix * torch.arange(
            n_steps, 0, -1, device=device).float()
        values, indices = torch.min(cost_matrix, -1)

        def select(per_step, step_idx):
            # Pick one entry along the step axis for every ray (B x N).
            return torch.gather(
                per_step, -1, step_idx.unsqueeze(-1)).squeeze(-1)

        mask_sign_change = values < 0
        # The value before the sign change must be negative (outside -> inside).
        mask_neg_to_pos = select(val, indices) < 0

        # Define mask where a valid depth value is found
        mask = mask_sign_change & mask_neg_to_pos & mask_0_not_occupied

        # Depth and function values bracketing the sign change; these are the
        # starting interval for the secant method.
        d_steps = d_proposal.squeeze(-1)
        indices_high = torch.clamp(indices + 1, max=n_steps - 1)
        d_low = select(d_steps, indices)[mask]
        f_low = select(val, indices)[mask]
        d_high = select(d_steps, indices_high)[mask]
        f_high = select(val, indices_high)[mask]

        ray0_masked = ray0[mask]
        ray_direction_masked = ray_direction[mask]

        # Apply surface depth refinement step (secant method)
        d_pred = self.secant(
            model, f_low, f_high, d_low, d_high, n_secant_steps, ray0_masked,
            ray_direction_masked, tau)

        d_pred_out = torch.ones(batch_size, n_pts, device=device)
        d_pred_out[mask] = d_pred
        d_pred_out[~mask] = float('inf')
        d_pred_out[~mask_0_not_occupied] = 0
        return d_pred_out

    def secant(self, model, f_low, f_high, d_low, d_high, n_secant_steps,
                          ray0_masked, ray_direction_masked, tau, it=0):
        ''' Runs the secant method for interval [d_low, d_high].

        The input tensors are not modified; the bracketing interval is
        tracked in new tensors.

        Args:
            model (nn.Module): model to evaluate point occupancies; called as
                model(points, only_occupancy=True)
            f_low (tensor): function values (occupancy - tau) at d_low
            f_high (tensor): function values (occupancy - tau) at d_high
            d_low (tensor): start values for the interval
            d_high (tensor): end values for the interval
            n_secant_steps (int): number of steps
            ray0_masked (tensor): masked ray start points
            ray_direction_masked (tensor): masked ray direction vectors
            tau (float): threshold value subtracted from the model output
            it (int): unused
        Returns:
            tensor: refined depth estimate, one value per input ray
        '''
        d_pred = - f_low * (d_high - d_low) / (f_high - f_low) + d_low
        if d_pred.shape[0] == 0:
            # Nothing to refine; avoid calling the model on an empty batch.
            return d_pred

        for _ in range(n_secant_steps):
            p_mid = ray0_masked + d_pred.unsqueeze(-1) * ray_direction_masked
            with torch.no_grad():
                f_mid = model(p_mid.unsqueeze(0), only_occupancy=True)[0, ..., 0] - tau
            # Negative value: the estimate replaces the lower bound,
            # otherwise it replaces the upper bound.
            below = f_mid < 0
            d_low = torch.where(below, d_pred, d_low)
            f_low = torch.where(below, f_mid, f_low)
            d_high = torch.where(below, d_high, d_pred)
            f_high = torch.where(below, f_high, f_mid)

            d_pred = - f_low * (d_high - d_low) / (f_high - f_low) + d_low

        return d_pred
    
    def transform_to_homogenous(self, p):
        ''' Appends a homogeneous coordinate and divides by the point norm.

        Args:
            p (tensor): points of dimension B x N x D
        Returns:
            tensor: B x N x (D + 1) tensor holding [p, 1] / ||p||, where the
                norm is taken over the last dimension of p
        '''
        batch_size, num_points, _ = p.size()
        r = torch.sqrt(torch.sum(p**2, dim=2, keepdim=True))
        # Allocate on the input's device so the method works regardless of
        # whether (or where) the module itself was moved.
        ones = torch.ones(batch_size, num_points, 1, device=p.device)
        p_homo = torch.cat((p, ones), dim=2) / r
        return p_homo

    def to(self, device):
        ''' Moves the module to the given device and records that device.

        Args:
            device (device): pytorch device
        Returns:
            the object returned by the parent class' ``to``, with its
            ``_device`` attribute set to ``device``
        '''
        moved = super().to(device)
        # Store the requested device on the returned module as `_device`.
        moved._device = device
        return moved
    
    def sched_step(self, steps=1):
        """
        Advance the iteration counter by `steps` and apply every schedule
        stage whose iteration threshold has been reached.

        self.sched is indexed as [thresholds, n_coarse values, n_fine values];
        self.last_sched is the index of the next stage to apply. Does nothing
        when no schedule is set.
        """
        if self.sched is None:
            return
        self.iter_idx += steps
        while True:
            stage = self.last_sched.item()
            # Stop once all stages are consumed ...
            if stage >= len(self.sched[0]):
                break
            # ... or the next stage's threshold has not been reached yet.
            if self.iter_idx.item() < self.sched[0][stage]:
                break
            self.n_coarse = self.sched[1][stage]
            self.n_fine = self.sched[2][stage]
            print(
                "INFO: NeRF sampling resolution changed on schedule ==> c",
                self.n_coarse,
                "f",
                self.n_fine,
            )
            self.last_sched += 1

    @classmethod
    def from_conf(cls, conf, white_bkgd=False, lindisp=False, eval_batch_size=100000):
        """
        Build an instance from a config object.

        `conf` must provide get_int / get_float / get_bool / get_list, each
        taking (key, default). The `white_bkgd` and `eval_batch_size`
        arguments serve as defaults for the config keys of the same name;
        `lindisp` is passed through unchanged.
        """
        n_coarse = conf.get_int("n_coarse", 128)
        n_fine = conf.get_int("n_fine", 0)
        options = dict(
            n_fine_depth=conf.get_int("n_fine_depth", 0),
            noise_std=conf.get_float("noise_std", 0.0),
            depth_std=conf.get_float("depth_std", 0.01),
            # NOTE(review): read with get_float although the default is a
            # bool -- confirm what type the constructor expects.
            white_bkgd=conf.get_float("white_bkgd", white_bkgd),
            lindisp=lindisp,
            eval_batch_size=conf.get_int("eval_batch_size", eval_batch_size),
            sched=conf.get_list("sched", None),
            n_classes=conf.get_int("n_classes", 1),
            use_rgb_head=conf.get_bool("use_rgb_head", True),
            use_seg_head=conf.get_bool("use_seg_head", True),
        )
        return cls(n_coarse, n_fine, **options)

    def bind_parallel(self, net, gpus=None, simple_output=False):
        """
        Bind this renderer and a network into one module that can be used
        with DataParallel.
        The returned module wraps `net` and this renderer in a
        _UniRenderWrapper; when more than one GPU id is given it is
        additionally wrapped in DataParallel (split along dim 0).
        :param net the network instance to render with
        :param gpus list of GPU ids to parallelize over. If None or of
        length <= 1, no DataParallel wrapper is applied
        :param simple_output forwarded to _UniRenderWrapper
        :return torch module
        """
        wrapper = _UniRenderWrapper(net, self, simple_output=simple_output)
        multi_gpu = gpus is not None and len(gpus) > 1
        if not multi_gpu:
            return wrapper
        print("Using multi-GPU", gpus)
        return torch.nn.DataParallel(wrapper, gpus, dim=0)

    def bind_pts_parallel(self, net, gpus=None):
        """
        Wrap a network in a module that can be used with DataParallel.
        The returned module wraps `net` in a _UniPtsWrapper; when more than
        one GPU id is given it is additionally wrapped in DataParallel
        (split along dim 0).
        :param net the network instance to wrap
        :param gpus list of GPU ids to parallelize over. If None or of
        length <= 1, no DataParallel wrapper is applied
        :return torch module
        """
        wrapper = _UniPtsWrapper(net)
        multi_gpu = gpus is not None and len(gpus) > 1
        if not multi_gpu:
            return wrapper
        print("Using multi-GPU", gpus)
        return torch.nn.DataParallel(wrapper, gpus, dim=0)

def get_sphere_intersection(cam_loc, ray_directions, r = 1.0):
    """
    Intersect rays with an origin-centred sphere of radius r.

    Solves |cam_loc + d * ray_directions|^2 = r^2 for d under the assumption
    that the directions are unit length (no normalisation is applied here).

    :param cam_loc ray origins, n_images x n_rays x 3 (any shape that
    broadcasts against ray_directions works)
    :param ray_directions ray directions, n_images x n_rays x 3
    :param r sphere radius
    :return (sphere_intersections, mask_intersect):
    sphere_intersections is a float32 n_images x n_rays x 2 tensor with the
    (near, far) depths clamped to be >= 0, and zeros for rays that miss;
    mask_intersect is a bool n_images x n_rays tensor marking rays that hit
    """
    # Element-wise dot product keeps every image in the batch independent.
    ray_cam_dot = (ray_directions * cam_loc).sum(dim=-1)
    cam_sq_norm = (cam_loc ** 2).sum(dim=-1)
    under_sqrt = ray_cam_dot ** 2 - (cam_sq_norm - r ** 2)

    mask_intersect = under_sqrt > 0

    # Take the root only of valid discriminants so rays that miss cannot
    # introduce NaN values (or NaN gradients) into the result.
    safe_disc = torch.where(
        mask_intersect, under_sqrt, torch.ones_like(under_sqrt))
    half_chord = torch.sqrt(safe_disc)

    near_far = torch.stack(
        (-ray_cam_dot - half_chord, -ray_cam_dot + half_chord), dim=-1)
    near_far = torch.where(
        mask_intersect.unsqueeze(-1), near_far, torch.zeros_like(near_far))

    # Depths behind the ray origin are clamped to 0; float32 output is kept
    # for compatibility with existing callers.
    sphere_intersections = near_far.float().clamp_min(0.0)

    return sphere_intersections, mask_intersect