"""Python bindings for custom Cuda functions"""

from typing import Optional

import torch
from jaxtyping import Float, Int
from torch import Tensor
from torch.autograd import Function

import gsplat.cuda as _C
from .utils import bin_and_sort_gaussians, compute_cumulative_intersects, bin_and_sort_gaussians_batch


def rasterize_gaussians_sum(
    xys: Float[Tensor, "*batch 2"],
    depths: Float[Tensor, "*batch 1"],
    radii: Float[Tensor, "*batch 1"],
    conics: Float[Tensor, "*batch 3"],
    num_tiles_hit: Int[Tensor, "*batch 1"],
    colors: Float[Tensor, "*batch channels"],
    opacity: Float[Tensor, "*batch 1"],
    img_height: int,
    img_width: int,
    BLOCK_H: int=16,
    BLOCK_W: int=16, 
    background: Optional[Float[Tensor, "channels"]] = None,
    return_alpha: Optional[bool] = False,
) -> Tensor:
    """Rasterizes 2D gaussians by sorting and binning gaussian intersections for each tile and returns an N-dimensional output using alpha-compositing.

    Note:
        This function is differentiable w.r.t the xys, conics, colors, and opacity inputs.

    Args:
        xys (Tensor): xy coords of 2D gaussians.
        depths (Tensor): depths of 2D gaussians.
        radii (Tensor): radii of 2D gaussians
        conics (Tensor): conics (inverse of covariance) of 2D gaussians in upper triangular format
        num_tiles_hit (Tensor): number of tiles hit per gaussian
        colors (Tensor): N-dimensional features associated with the gaussians.
        opacity (Tensor): opacity associated with the gaussians.
        img_height (int): height of the rendered image.
        img_width (int): width of the rendered image.
        background (Tensor): background color
        return_alpha (bool): whether to return alpha channel

    Returns:
        A Tensor:

        - **out_img** (Tensor): N-dimensional rendered output image.
        - **out_alpha** (Optional[Tensor]): Alpha channel of the rendered output image.
    """
    if colors.dtype == torch.uint8:
        
        colors = colors.float() / 255

    if background is not None:
        assert (
            background.shape[0] == colors.shape[-1]
        ), f"incorrect shape of background color tensor, expected shape {colors.shape[-1]}"
    else:
        background = torch.ones(
            colors.shape[-1], dtype=torch.float32, device=colors.device
        )

    if xys.ndimension() != 2 or xys.size(1) != 2:
        raise ValueError("xys must have dimensions (N, 2)")

    if colors.ndimension() != 2:
        raise ValueError("colors must have dimensions (N, D)")

    return _RasterizeGaussiansSum.apply(
        xys.contiguous(),
        depths.contiguous(),
        radii.contiguous(),
        conics.contiguous(),
        num_tiles_hit.contiguous(),
        colors.contiguous(),
        opacity.contiguous(),
        img_height,
        img_width,
        BLOCK_H, 
        BLOCK_W,
        background.contiguous(),
        return_alpha,
    )


class _RasterizeGaussiansSum(Function):
    """Rasterizes 2D gaussians"""

    @staticmethod
    def forward(
        ctx,
        xys: Float[Tensor, "*batch 2"],
        depths: Float[Tensor, "*batch 1"],
        radii: Float[Tensor, "*batch 1"],
        conics: Float[Tensor, "*batch 3"],
        num_tiles_hit: Int[Tensor, "*batch 1"],
        colors: Float[Tensor, "*batch channels"],
        opacity: Float[Tensor, "*batch 1"],
        img_height: int,
        img_width: int,
        BLOCK_H: int=16,
        BLOCK_W: int=16, 
        background: Optional[Float[Tensor, "channels"]] = None,
        return_alpha: Optional[bool] = False,
    ) -> Tensor:
        num_points = xys.size(0)
        BLOCK_X, BLOCK_Y = BLOCK_W, BLOCK_H
        tile_bounds = (
            (img_width + BLOCK_X - 1) // BLOCK_X,
            (img_height + BLOCK_Y - 1) // BLOCK_Y,
            1,
        )
        block = (BLOCK_X, BLOCK_Y, 1)
        img_size = (img_width, img_height, 1)

        num_intersects, cum_tiles_hit = compute_cumulative_intersects(num_tiles_hit)

        if num_intersects < 1:
            out_img = (
                torch.ones(img_height, img_width, colors.shape[-1], device=xys.device)
                * background
            )
            gaussian_ids_sorted = torch.zeros(0, 1, device=xys.device)
            tile_bins = torch.zeros(0, 2, device=xys.device)
            final_Ts = torch.zeros(img_height, img_width, device=xys.device)
            final_idx = torch.zeros(img_height, img_width, device=xys.device)
        else:
            (
                isect_ids_unsorted,
                gaussian_ids_unsorted,
                isect_ids_sorted,
                gaussian_ids_sorted,
                tile_bins,
            ) = bin_and_sort_gaussians(
                num_points,
                num_intersects,
                xys,
                depths,
                radii,
                cum_tiles_hit,
                tile_bounds,
            )
            if colors.shape[-1] == 3:
                rasterize_fn = _C.rasterize_sum_forward
            else:
                rasterize_fn = _C.nd_rasterize_sum_forward

            out_img, final_Ts, final_idx = rasterize_fn(
                tile_bounds,
                block,
                img_size,
                gaussian_ids_sorted,
                tile_bins,
                xys,
                conics,
                colors,
                opacity,
                background,
            )

        ctx.img_width = img_width
        ctx.img_height = img_height
        ctx.BLOCK_H = BLOCK_H
        ctx.BLOCK_W = BLOCK_W
        ctx.num_intersects = num_intersects
        ctx.save_for_backward(
            gaussian_ids_sorted,
            tile_bins,
            xys,
            conics,
            colors,
            opacity,
            background,
            final_Ts,
            final_idx,
        )

        if return_alpha:
            out_alpha = 1 - final_Ts
            return out_img, out_alpha
        else:
            return out_img

    @staticmethod
    def backward(ctx, v_out_img, v_out_alpha=None):
        img_height = ctx.img_height
        img_width = ctx.img_width
        BLOCK_H = ctx.BLOCK_H
        BLOCK_W = ctx.BLOCK_W
        num_intersects = ctx.num_intersects

        if v_out_alpha is None:
            v_out_alpha = torch.zeros_like(v_out_img[..., 0])

        (
            gaussian_ids_sorted,
            tile_bins,
            xys,
            conics,
            colors,
            opacity,
            background,
            final_Ts,
            final_idx,
        ) = ctx.saved_tensors

        if num_intersects < 1:
            v_xy = torch.zeros_like(xys)
            v_conic = torch.zeros_like(conics)
            v_colors = torch.zeros_like(colors)
            v_opacity = torch.zeros_like(opacity)

        else:
            if colors.shape[-1] == 3:
                rasterize_fn = _C.rasterize_sum_backward
            else:
                rasterize_fn = _C.nd_rasterize_sum_backward
            v_xy, v_conic, v_colors, v_opacity = rasterize_fn(
                img_height,
                img_width,
                BLOCK_H,
                BLOCK_W,
                gaussian_ids_sorted,
                tile_bins,
                xys,
                conics,
                colors,
                opacity,
                background,
                final_Ts,
                final_idx,
                v_out_img,
                v_out_alpha,
            )

        return (
            v_xy,  
            None,  
            None,  
            v_conic,  
            None,  
            v_colors,  
            v_opacity,  
            None,  
            None,  
            None,  
            None,  
            None,  
            None,  
        )


def rasterize_gaussians_sum_batch(
    batch_size: int,
    num_points_per_image: int,
    xys: Float[Tensor, "total_points 2"],
    depths: Float[Tensor, "total_points 1"],
    radii: Float[Tensor, "total_points 1"],
    conics: Float[Tensor, "total_points 3"],
    num_tiles_hit: Int[Tensor, "total_points 1"],
    colors: Float[Tensor, "total_points channels"],
    opacity: Float[Tensor, "total_points 1"],
    img_height: int,
    img_width: int,
    BLOCK_H: int = 16,
    BLOCK_W: int = 16,
    background: Optional[Float[Tensor, "batch channels"]] = None,
    return_alpha: Optional[bool] = False,
) -> Tensor:
    """(BATCH VERSION) Rasterizes 2D gaussians for a batch of images.

    Note:
        This function is differentiable w.r.t the xys, conics, colors, and opacity inputs.

    Args:
        batch_size (int): The number of images in the batch.
        num_points_per_image (int): The number of gaussians per image.
        xys (Tensor): xy coords of 2D gaussians for the entire batch.
        depths (Tensor): depths of 2D gaussians for the entire batch.
        radii (Tensor): radii of 2D gaussians for the entire batch.
        conics (Tensor): conics (inverse of covariance) of 2D gaussians.
        num_tiles_hit (Tensor): number of tiles hit per gaussian for the entire batch.
        colors (Tensor): N-dimensional features associated with the gaussians.
        opacity (Tensor): opacity associated with the gaussians.
        img_height (int): height of the rendered image.
        img_width (int): width of the rendered image.
        background (Tensor): background color for each image, shape (batch_size, channels).
        return_alpha (bool): whether to return alpha channel.

    Returns:
        A Tensor or Tuple of Tensors:
        - **out_img** (Tensor): N-dimensional rendered output image, shape (batch_size, H, W, C).
        - **out_alpha** (Optional[Tensor]): Alpha channel of the rendered output image, shape (batch_size, H, W).
    """
    if colors.dtype == torch.uint8:
        colors = colors.float() / 255.0

    if background is None:
        background = torch.ones(
            batch_size, colors.shape[-1], dtype=torch.float32, device=colors.device
        )
    
    assert background.shape == (batch_size, colors.shape[-1]), \
        f"Background tensor shape mismatch. Expected ({batch_size}, {colors.shape[-1]}), got {background.shape}"

    total_points = batch_size * num_points_per_image
    assert xys.shape[0] == total_points, "Input tensor dimensions do not match batch_size and num_points_per_image."

    return _RasterizeGaussiansSumBatch.apply(
        batch_size,
        num_points_per_image,
        xys.contiguous(),
        depths.contiguous(),
        radii.contiguous(),
        conics.contiguous(),
        num_tiles_hit.contiguous(),
        colors.contiguous(),
        opacity.contiguous(),
        img_height,
        img_width,
        BLOCK_H,
        BLOCK_W,
        background.contiguous(),
        return_alpha,
    )




class _RasterizeGaussiansSumBatch(Function):
    """(BATCH VERSION) Autograd function for rasterizing 2D gaussians."""

    @staticmethod
    def forward(
        ctx,
        batch_size: int,
        num_points_per_image: int,
        xys: Float[Tensor, "total_points 2"],
        depths: Float[Tensor, "total_points 1"],
        radii: Float[Tensor, "total_points 1"],
        conics: Float[Tensor, "total_points 3"],
        num_tiles_hit: Int[Tensor, "total_points 1"],
        colors: Float[Tensor, "total_points channels"],
        opacity: Float[Tensor, "total_points 1"],
        img_height: int,
        img_width: int,
        BLOCK_H: int,
        BLOCK_W: int,
        background: Float[Tensor, "batch channels"],
        return_alpha: bool,
    ) -> Tensor:
        
        
        total_points = xys.size(0)
        BLOCK_X, BLOCK_Y = BLOCK_W, BLOCK_H
        tile_bounds = (
            (img_width + BLOCK_X - 1) // BLOCK_X,
            (img_height + BLOCK_Y - 1) // BLOCK_Y,
            1,
        )
        block = (BLOCK_X, BLOCK_Y, 1)
        img_size = (img_width, img_height, 1)

        
        num_intersects, cum_tiles_hit = compute_cumulative_intersects(num_tiles_hit)

        
        if num_intersects < 1:
            out_img = background.view(batch_size, 1, 1, -1).expand(batch_size, img_height, img_width, -1)
            final_Ts = torch.ones(batch_size, img_height, img_width, device=xys.device)
            gaussian_ids_sorted = torch.zeros(0, 1, device=xys.device, dtype=torch.int32)
            tile_bins = torch.zeros(0, 2, device=xys.device, dtype=torch.int32)
            final_idx = torch.zeros(batch_size, img_height, img_width, device=xys.device, dtype=torch.int32)
        else:
            
            (
                gaussian_ids_sorted,
                tile_bins,
            ) = bin_and_sort_gaussians_batch(
                batch_size,
                num_points_per_image,
                num_intersects,
                xys,
                depths,
                radii,
                cum_tiles_hit,
                tile_bounds,
            )

            
            if colors.shape[-1] == 3:
                rasterize_fn = _C.rasterize_batch_forward_sum
            else:
                rasterize_fn = _C.nd_rasterize_batch_forward_sum

            out_img, final_Ts, final_idx = rasterize_fn(
                batch_size,
                tile_bounds,
                block,
                img_size,
                gaussian_ids_sorted,
                tile_bins,
                xys,
                conics,
                colors,
                opacity,
                background,
            )

        
        ctx.img_width = img_width
        ctx.img_height = img_height
        ctx.BLOCK_H = BLOCK_H
        ctx.BLOCK_W = BLOCK_W
        ctx.num_intersects = num_intersects
        
        ctx.batch_size = batch_size

        ctx.save_for_backward(
            gaussian_ids_sorted,
            tile_bins,
            xys,
            conics,
            colors,
            opacity,
            background,
            final_Ts,
            final_idx,
        )

        
        if return_alpha:
            out_alpha = 1.0 - final_Ts
            return out_img, out_alpha
        else:
            return out_img

    @staticmethod
    def backward(ctx, v_out_img, v_out_alpha=None):
        
        
        batch_size = ctx.batch_size
        img_height, img_width = ctx.img_height, ctx.img_width
        BLOCK_H, BLOCK_W = ctx.BLOCK_H, ctx.BLOCK_W
        num_intersects = ctx.num_intersects
        (
            gaussian_ids_sorted, tile_bins, xys, conics, colors,
            opacity, background, final_Ts, final_idx
        ) = ctx.saved_tensors

        if num_intersects < 1:
            return (None, None, torch.zeros_like(xys), None, None, torch.zeros_like(conics), None,
                    torch.zeros_like(colors), torch.zeros_like(opacity), None, None, None, None, None, None)

        if v_out_alpha is None:
            v_out_alpha = torch.zeros_like(v_out_img[..., 0])

        
        if colors.shape[-1] == 3:
            rasterize_fn_bw = _C.rasterize_batch_backward_sum
        else:
            rasterize_fn_bw = _C.nd_rasterize_batch_backward_sum
        
        
        v_xy, v_conic, v_colors, v_opacity = rasterize_fn_bw(
            batch_size,
            img_height,
            img_width,
            BLOCK_H,
            BLOCK_W,
            gaussian_ids_sorted,
            tile_bins,
            xys,
            conics,
            colors,
            opacity,
            background,
            final_Ts,
            final_idx,
            v_out_img,
            v_out_alpha,
        )

        return (
            None, None,  
            v_xy,        
            None,        
            None,        
            v_conic,     
            None,        
            v_colors,    
            v_opacity,   
            None, None, None, None, None, None 
        )