# https://github.com/DaLi-Jack/SSR-code/blob/main/ssr/model/ray_sampler.py

import abc
from tkinter.messagebox import NO
import torch.nn as nn
import torch
import nerfacc 
from src.utils.typing import *

from einops import rearrange, repeat

def ray_sample(cam2world_matrix, intrinsics, resolution):
    """
    Build one ray per cell of a ``resolution x resolution`` grid, per camera.

    Grid coordinates are cell centres ``(i + 0.5) / resolution``, so they lie
    in (0, 1). They are un-projected with the given intrinsics at depth 1 and
    mapped to world space with the camera-to-world pose.
    NOTE(review): because the grid is in (0, 1), the intrinsics presumably need
    to be expressed in the same normalized units — confirm with callers.

    The grid is built with ``indexing="ij"`` and its first channel is used as
    x, so x is the slow-varying coordinate in the flattened ray order.

    cam2world_matrix: (N, 4, 4)
    intrinsics: (N, 3, 3)
    resolution: int

    Returns ``(ray_origins, ray_dirs)``:
    ray_origins: (N, M, 3), the camera position repeated for each ray
    ray_dirs: (N, M, 3), unit-length directions, with M = resolution**2
    """
    device = cam2world_matrix.device
    num_cams = cam2world_matrix.shape[0]
    num_rays = resolution**2

    # Camera centres in world space: the translation column of each pose.
    origins = cam2world_matrix[:, :3, 3]

    # Per-camera intrinsic entries as (N, 1) columns so they broadcast over rays.
    fx = intrinsics[:, 0, 0].unsqueeze(-1)
    fy = intrinsics[:, 1, 1].unsqueeze(-1)
    cx = intrinsics[:, 0, 2].unsqueeze(-1)
    cy = intrinsics[:, 1, 2].unsqueeze(-1)
    sk = intrinsics[:, 0, 1].unsqueeze(-1)

    # Cell-centre coordinates of the grid, shared by every camera.
    ticks = torch.arange(resolution, dtype=torch.float32, device=device)
    grid = torch.stack(torch.meshgrid(ticks, ticks, indexing="ij"))
    grid = grid * (1.0 / resolution) + (0.5 / resolution)
    grid = repeat(grid, "c h w -> b (h w) c", b=num_cams)

    x_cam, y_cam = grid.unbind(dim=-1)
    z_cam = torch.ones((num_cams, num_rays), device=device)

    # Invert the (skewed) pinhole projection at depth z_cam.
    x_numer = x_cam - cx + cy * sk / fy - sk * y_cam / fy
    x_lift = x_numer / fx * z_cam
    y_lift = (y_cam - cy) / fy * z_cam

    # Homogeneous camera-space points, one per ray: (N, M, 4).
    homog_points = torch.stack(
        (x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1
    )

    # Apply the pose to every point, then drop the homogeneous coordinate.
    world_points = torch.bmm(cam2world_matrix, homog_points.transpose(1, 2))
    world_points = world_points.transpose(1, 2)[..., :3]

    directions = world_points - origins[:, None, :]
    directions = torch.nn.functional.normalize(directions, dim=2)

    ray_origins = origins.unsqueeze(1).repeat(1, directions.shape[1], 1)

    return ray_origins, directions

class RaySampler(nn.Module):
    """Base class for ray depth samplers.

    Subclasses override ``get_z_vals``. The optional ``near`` / ``far`` bounds
    are stored as plain attributes, so a subclass can forward them with
    ``super().__init__(near, far)`` and read ``self.near`` / ``self.far``
    later (``ErrorBoundSampler`` in this file does exactly that). Both default
    to ``None``, so constructing the sampler with no arguments still works.
    """

    def __init__(self, near=None, far=None):
        super().__init__()
        self.near = near
        self.far = far

    def get_z_vals(self, ray_dirs, cam_loc, model):
        """Return sample depths along each ray; must be provided by subclasses."""
        raise NotImplementedError


class UniformSampler(nn.Module):
    """Evenly spaced depth sampler between per-ray near/far limits of a box.

    The limits come from ``nerfacc.get_ray_limits_box`` evaluated against
    ``aabb``. When ``perturb`` is set and the module is in training mode, each
    depth is shifted by uniform noise of up to half a sample spacing.
    """

    def __init__(self, aabb, N_samples, perturb = False):
        super().__init__()
        self.aabb = aabb
        self.N_samples = N_samples
        self.perturb = perturb

    # Original author's note: currently this is used for replica scannet and T&T
    def get_z_vals(self, 
        ray_origins: Float[Tensor, "N_rays 3"],
        ray_dirs: Float[Tensor, "N_rays 3"]
    ):
        """Return ``(z_vals, nears, fars)`` for the given rays.

        ``z_vals`` is ``nears + (fars - nears) * t`` with ``t`` being
        ``N_samples`` evenly spaced values in [0, 1], shaped (N_samples, 1).
        NOTE(review): the resulting shape depends on how ``nears`` / ``fars``
        broadcast against (N_samples, 1); the original comments claim
        (N_samples, 1) — confirm against the nerfacc return shapes.
        """
        # The third return value (hit mask) is not used here.
        nears, fars, _hits = nerfacc.get_ray_limits_box(ray_origins, ray_dirs, self.aabb)
        span = fars - nears

        fractions = torch.linspace(
            0.0, 1.0, self.N_samples, device=ray_origins.device
        ).unsqueeze(1)
        z_vals = nears + span * fractions

        if self.perturb and self.training:
            # Jitter in [-0.5, 0.5) of one nominal sample spacing.
            spacing = span / self.N_samples
            jitter = torch.rand(z_vals.shape, device=ray_origins.device) - 0.5
            z_vals = z_vals + jitter * spacing

        return z_vals, nears, fars


class ErrorBoundSampler(RaySampler):
    """Iterative depth sampler that refines samples using an opacity error bound.

    Starting from the depths of an internal uniform sampler, ``get_z_vals``
    repeatedly evaluates the model's SDF at newly added depths, lowers a
    per-ray ``beta`` by bisection, and adds more depths where the error bound
    is large. The loop stops once every ray's ``beta`` is no larger than the
    model's own beta or ``max_total_iters`` is reached. The final depths are
    drawn by inverting a CDF built from the rendering weights.

    The "Lemma 2" / "Theorem 1" / "Algorithm 1" references in the comments
    presumably point to the VolSDF paper — TODO confirm.

    NOTE(review): several things this class relies on are not visible in this
    file and should be confirmed before use:
      * it builds ``UniformSampler`` with five positional arguments and calls
        its ``get_z_vals`` with ``(ray_dirs, cam_loc, model)``;
      * it references ``sdf_util`` and ``rend_util``, for which no import is
        visible here;
      * it reads ``self.near`` / ``self.far``, which it never assigns itself
        (presumably the base-class constructor is expected to set them);
      * it allocates helper tensors with ``.cuda()``, so it assumes CUDA inputs.
    """
    def __init__(self, scene_bounding_sphere, near, far, N_samples, N_samples_eval, N_samples_extra,
                 eps, beta_iters, max_total_iters, take_sphere_intersection, encoder,
                 inverse_sphere_bg=False, N_samples_inverse_sphere=0, add_tiny=1.0e-6):
        """Store the sampling configuration and build the helper samplers.

        :param scene_bounding_sphere: radius used for the default far bound, the
            sphere intersection and the inverse-sphere scaling
        :param near: near bound forwarded to the base class and uniform sampler
        :param far: far bound; ``-1`` selects ``2.0 * scene_bounding_sphere * 1.75``
        :param N_samples: number of depths drawn in the final sampling pass
        :param N_samples_eval: number of depths drawn in each refinement pass
            (also passed to the uniform sampler)
        :param N_samples_extra: number of already-visited depths appended to the
            final set
        :param eps: error-bound threshold used for the beta search
        :param beta_iters: number of bisection steps per refinement pass
        :param max_total_iters: maximum number of refinement passes
        :param take_sphere_intersection: stored and forwarded to the uniform sampler
        :param encoder: stored on the instance; not used inside this class
        :param inverse_sphere_bg: when true, also create and use an inverse-sphere
            sampler
        :param N_samples_inverse_sphere: sample count for the inverse-sphere sampler
        :param add_tiny: small constant added to the refinement pdf before
            normalizing
        """
        #super().__init__(near, 2.0 * scene_bounding_sphere)
        # NOTE(review): verify that RaySampler.__init__ accepts (near, far).
        super().__init__(near, 2.0 * scene_bounding_sphere * 1.75 if far == -1 else far)  # default far is 2 * 1.75 * R (= 3.5 * R) when far == -1
        
        self.N_samples = N_samples
        self.N_samples_eval = N_samples_eval
        # NOTE(review): five positional arguments are passed here; confirm this matches the UniformSampler constructor in use.
        self.uniform_sampler = UniformSampler(scene_bounding_sphere, near, N_samples_eval, take_sphere_intersection, far) # replica scannet and T&T courtroom

        self.N_samples_extra = N_samples_extra
        self.take_sphere_intersection = take_sphere_intersection

        self.encoder = encoder

        self.eps = eps
        self.beta_iters = beta_iters
        self.max_total_iters = max_total_iters
        self.scene_bounding_sphere = scene_bounding_sphere
        self.add_tiny = add_tiny

        self.inverse_sphere_bg = inverse_sphere_bg
        if inverse_sphere_bg:
            # NOTE(review): same constructor-signature caveat as for self.uniform_sampler above.
            self.inverse_sphere_sampler = UniformSampler(1.0, 0.0, N_samples_inverse_sphere, False, far=1.0)

    def get_z_vals(self, ray_dirs_obj, cam_loc_obj, model, latent_feature, cat_feature, global_latent, uv, image_shape, model_input):
        """
        Compute the sample depths for each ray with the error-bound refinement.

        Uses more features and deformable attention, and is suitable for both
        the 'add' and 'cat' architecture modes (original author's description).
        Shapes below are taken from the original author's notes — TODO confirm.

        :param ray_dirs_obj: ray directions in object coordinates, (B*N_uv, 3)
        :param cam_loc_obj: ray origins in object coordinates, (B*N_uv, 3)
        :param model: must provide ``density`` (callable, with ``get_beta()``),
            ``implicit_network.get_sdf_vals`` and the ``training`` flag
        :param latent_feature: in 'add': img_feature + global_feature + rot_feature;
            in 'cat': img_feature, (B*N_uv, 256)
        :param cat_feature: in 'add': None; in 'cat': cat(global_feature, rot_feature),
            (B, 256+39)
        :param global_latent: (B, 256); not used in this method
        :param uv: sampled pixels, (B, N_uv, 2); only its shape is read
        :param image_shape: not used in this method
        :param model_input: dict read for 'scene_scale', 'centroid' and
            'none_equal_scale'
        :return: ``(z_vals, z_samples_eik)``. ``z_vals`` is the sorted union of
            the final samples, near/far and optional extra depths; when
            ``inverse_sphere_bg`` is set it is a tuple
            ``(z_vals, z_vals_inverse_sphere)``. ``z_samples_eik`` holds one
            randomly chosen depth per ray.
        """
        beta0 = model.density.get_beta().detach()
        # batch_size and N_uv are unpacked here but not used afterwards.
        batch_size, N_uv, _ = uv.shape


        # Start with uniform sampling
        # NOTE(review): confirm the uniform sampler's get_z_vals accepts (ray_dirs, cam_loc, model).
        z_vals, near, far = self.uniform_sampler.get_z_vals(ray_dirs_obj, cam_loc_obj, model)
        # samples_idx stays None until the first merge of new samples into z_vals.
        samples, samples_idx = z_vals, None

        # Get maximum beta from the upper bound (Lemma 2)
        dists = z_vals[:, 1:] - z_vals[:, :-1]
        bound = (1.0 / (4.0 * torch.log(torch.tensor(self.eps + 1.0)))) * (dists ** 2.).sum(-1)
        beta = torch.sqrt(bound)

        total_iters, not_converge = 0, True

        # Algorithm 1
        while not_converge and total_iters < self.max_total_iters:
            # get obj coords
            points_obj = cam_loc_obj.unsqueeze(1) + samples.unsqueeze(2) * ray_dirs_obj.unsqueeze(1)            # samples: (B*N_uv, N_pts_per_ray), cam_loc_obj: (B*N_uv, 3), ray_dirs_obj: (B*N_uv, 3)
            points_obj = points_obj.reshape(-1, 3)

            # Calculating the SDF only for the new sampled points
            with torch.no_grad():
                # transfer to cube coords
                # NOTE(review): no import of sdf_util is visible in this file — confirm it is in scope.
                cube_coords = sdf_util.scene_obj2cube_coords(points_obj, model_input['scene_scale'], model_input['centroid'], model_input['none_equal_scale'])    # (B*N_uv*N_pts_per_ray, 3)
                samples_sdf = model.implicit_network.get_sdf_vals(cube_coords, latent_feature, cat_feature)
            if samples_idx is not None:
                # Append the new SDF values to the previous ones, then reorder them
                # with the sort permutation so they line up with the sorted z_vals.
                sdf_merge = torch.cat([sdf.reshape(-1, z_vals.shape[1] - samples.shape[1]),
                                       samples_sdf.reshape(-1, samples.shape[1])], -1)
                sdf = torch.gather(sdf_merge, 1, samples_idx).reshape(-1, 1)
            else:
                sdf = samples_sdf

            # Calculating the bound d* (Theorem 1)
            d = sdf.reshape(z_vals.shape)
            dists = z_vals[:, 1:] - z_vals[:, :-1]
            # a: interval length, b / c: |sdf| at the interval's start / end.
            a, b, c = dists, d[:, :-1].abs(), d[:, 1:].abs()
            first_cond = a.pow(2) + b.pow(2) <= c.pow(2)
            second_cond = a.pow(2) + c.pow(2) <= b.pow(2)
            d_star = torch.zeros(z_vals.shape[0], z_vals.shape[1] - 1).cuda()
            d_star[first_cond] = b[first_cond]
            d_star[second_cond] = c[second_cond]
            # Remaining intervals: 2 * triangle area / a, with the area from
            # Heron's formula on side lengths (a, b, c).
            s = (a + b + c) / 2.0
            area_before_sqrt = s * (s - a) * (s - b) * (s - c)
            mask = ~first_cond & ~second_cond & (b + c - a > 0)
            d_star[mask] = (2.0 * torch.sqrt(area_before_sqrt[mask])) / (a[mask])
            # Zero out d_star unless both endpoint SDF values have the same non-zero sign.
            d_star = (d[:, 1:].sign() * d[:, :-1].sign() == 1) * d_star  # Fixing the sign

            # Updating beta using line search
            curr_error = self.get_error_bound(beta0, model, sdf, z_vals, dists, d_star)
            beta[curr_error <= self.eps] = beta0
            # beta_max is the same tensor object as beta (no copy), so the masked
            # writes to beta_max below also modify beta in place.
            beta_min, beta_max = beta0.unsqueeze(0).repeat(z_vals.shape[0]), beta
            for j in range(self.beta_iters):
                # Bisection: keep the bound <= eps at beta_max and > eps at beta_min.
                beta_mid = (beta_min + beta_max) / 2.
                curr_error = self.get_error_bound(beta_mid.unsqueeze(-1), model, sdf, z_vals, dists, d_star)
                beta_max[curr_error <= self.eps] = beta_mid[curr_error <= self.eps]
                beta_min[curr_error > self.eps] = beta_mid[curr_error > self.eps]
            beta = beta_max

            # Upsample more points
            density = model.density(sdf.reshape(z_vals.shape), beta=beta.unsqueeze(-1))

            # Append a very large last interval so dists has one entry per sample.
            dists = torch.cat([dists, torch.tensor([1e10]).cuda().unsqueeze(0).repeat(dists.shape[0], 1)], -1)
            free_energy = dists * density
            shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1).cuda(), free_energy[:, :-1]], dim=-1)
            alpha = 1 - torch.exp(-free_energy)
            transmittance = torch.exp(-torch.cumsum(shifted_free_energy, dim=-1))
            weights = alpha * transmittance  # probability of the ray hits something here

            #  Check if we are done and this is the last sampling
            total_iters += 1
            not_converge = beta.max() > beta0

            if not_converge and total_iters < self.max_total_iters:
                ''' Sample more points proportional to the current error bound'''

                N = self.N_samples_eval

                bins = z_vals
                # dists[:, :-1] drops the 1e10 padding interval appended above.
                error_per_section = torch.exp(-d_star / beta.unsqueeze(-1)) * (dists[:,:-1] ** 2.) / (4 * beta.unsqueeze(-1) ** 2)
                error_integral = torch.cumsum(error_per_section, dim=-1)
                bound_opacity = (torch.clamp(torch.exp(error_integral),max=1.e6) - 1.0) * transmittance[:,:-1]

                pdf = bound_opacity + self.add_tiny
                pdf = pdf / torch.sum(pdf, -1, keepdim=True)
                cdf = torch.cumsum(pdf, -1)
                cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)

            else:
                ''' Sample the final sample set to be used in the volume rendering integral '''

                N = self.N_samples

                bins = z_vals
                pdf = weights[..., :-1]
                pdf = pdf + 1e-5  # prevent nans
                pdf = pdf / torch.sum(pdf, -1, keepdim=True)
                cdf = torch.cumsum(pdf, -1)
                cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)  # (batch, len(bins))

            # Invert CDF
            # Evenly spaced u during refinement passes and whenever the model is
            # not training; random u only for the final pass while training.
            if (not_converge and total_iters < self.max_total_iters) or (not model.training):
                u = torch.linspace(0., 1., steps=N).cuda().unsqueeze(0).repeat(cdf.shape[0], 1)
            else:
                u = torch.rand(list(cdf.shape[:-1]) + [N]).cuda()
            u = u.contiguous()

            inds = torch.searchsorted(cdf, u, right=True)
            # Clamp the bracketing bin indices to the valid range [0, len(cdf) - 1].
            below = torch.max(torch.zeros_like(inds - 1), inds - 1)
            above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
            inds_g = torch.stack([below, above], -1)  # (batch, N_samples, 2)

            matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
            cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
            bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

            # Linear interpolation inside the bracketing bin; near-zero CDF
            # differences are replaced by 1 to avoid dividing by ~0.
            denom = (cdf_g[..., 1] - cdf_g[..., 0])
            denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
            t = (u - cdf_g[..., 0]) / denom
            samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

            # Adding samples if we not converged
            if not_converge and total_iters < self.max_total_iters:
                # samples_idx is the sort permutation, reused next pass to merge SDF values.
                z_vals, samples_idx = torch.sort(torch.cat([z_vals, samples], -1), -1)

        z_samples = samples
        
        #TODO Use near and far from intersection
        # Overrides the near/far returned by the uniform sampler with constant per-ray columns.
        near, far = self.near * torch.ones(ray_dirs_obj.shape[0], 1).cuda(), self.far * torch.ones(ray_dirs_obj.shape[0],1).cuda()
        if self.inverse_sphere_bg: # if inverse sphere then need to add the far sphere intersection
            # NOTE(review): no import of rend_util is visible in this file — confirm it is in scope.
            far = rend_util.get_sphere_intersections(cam_loc_obj, ray_dirs_obj, r=self.scene_bounding_sphere)[:,1:]

        if self.N_samples_extra > 0:
            if model.training:
                sampling_idx = torch.randperm(z_vals.shape[1])[:self.N_samples_extra]           # z_vals include uniform sample z_vals
            else:
                sampling_idx = torch.linspace(0, z_vals.shape[1]-1, self.N_samples_extra).long()
            z_vals_extra = torch.cat([near, far, z_vals[:,sampling_idx]], -1)
        else:
            z_vals_extra = torch.cat([near, far], -1)

        z_vals, _ = torch.sort(torch.cat([z_samples, z_vals_extra], -1), -1)


        # add some of the near surface points
        idx = torch.randint(z_vals.shape[-1], (z_vals.shape[0],)).cuda()
        z_samples_eik = torch.gather(z_vals, 1, idx.unsqueeze(-1))              # one random value on a ray

        if self.inverse_sphere_bg:
            # NOTE(review): confirm the inverse-sphere sampler's get_z_vals accepts (ray_dirs, cam_loc, model).
            z_vals_inverse_sphere, _, _ = self.inverse_sphere_sampler.get_z_vals(ray_dirs_obj, cam_loc_obj, model)
            z_vals_inverse_sphere = z_vals_inverse_sphere * (1./self.scene_bounding_sphere)
            # In this mode the first return value becomes a 2-tuple.
            z_vals = (z_vals, z_vals_inverse_sphere)

        return z_vals, z_samples_eik

    def get_error_bound(self, beta, model, sdf, z_vals, dists, d_star):
        """Return, per ray, the maximum of the opacity error bound over its intervals.

        :param beta: value passed to ``model.density`` as ``beta`` and used in the
            bound; must broadcast against ``d_star`` / ``dists``
        :param model: must provide a callable ``density(sdf, beta=...)``
        :param sdf: SDF values, reshaped here to ``z_vals.shape``
        :param z_vals: sample depths; only the shape is used
        :param dists: distances between consecutive depths (one fewer column
            than ``z_vals``)
        :param d_star: per-interval distance bound with the same shape as ``dists``
        :return: tensor with one value per row (max over the last dimension)
        """
        density = model.density(sdf.reshape(z_vals.shape), beta=beta)
        # Prepend a zero column so the cumulative sum at index i covers intervals before i.
        shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1).cuda(), dists * density[:, :-1]], dim=-1)
        integral_estimation = torch.cumsum(shifted_free_energy, dim=-1)
        error_per_section = torch.exp(-d_star / beta) * (dists ** 2.) / (4 * beta ** 2)
        error_integral = torch.cumsum(error_per_section, dim=-1)
        # exp() is clamped at 1e6 to keep the bound finite.
        bound_opacity = (torch.clamp(torch.exp(error_integral), max=1.e6) - 1.0) * torch.exp(-integral_estimation[:, :-1])

        return bound_opacity.max(-1)[0]