
#include <math.h>
#include <torch/extension.h>
#include <cstdio>
#include <sstream>
#include <iostream>
#include <tuple>
#include <stdio.h>
#include <cuda_runtime_api.h>
#include <memory>
#include "cuda_rasterizer/config.h"
#include "cuda_rasterizer/rasterizer.h"
#include <fstream>
#include <string>
#include <functional>
#include "cuda_rasterizer/auxiliary.h"


// Builds a byte-allocator callback backed by the tensor `t`.
//
// The returned callable resizes `t` to exactly N elements and returns a raw
// `char*` to its (contiguous) storage. `t` is captured BY REFERENCE, so the
// caller must keep that tensor alive for as long as the callback may be invoked.
std::function<char*(size_t N)> resizeFunctional(torch::Tensor& t) {
    return [&t](size_t N) -> char* {
        const long long requested = (long long)N;
        t.resize_({requested});
        void* storage = t.contiguous().data_ptr();
        return static_cast<char*>(storage);
    };
}

// Forward pass of the Gaussian rasterizer, exposed to Python as a torch extension.
//
// Allocates every output on the same device/options as `means3D` (zero-filled).
// When there is at least one Gaussian (P > 0), it delegates to
// CudaRasterizer::Rasterizer::forward, which writes the outputs. That call also
// grows the three byte scratch buffers on demand through resizeFunctional.
//
// Parameters (only `means3D` is validated here; other shapes are assumptions,
// TODO confirm against the kernels):
//   means3D          (P, 3) Gaussian centres.
//   sh               SH coefficients; M = sh.size(1) when sh is non-empty, else 0.
//                    Presumably (P, M, 3), matching dL_dsh in the backward pass.
//   nb_src_images    number of source views. It sizes out_cam_feat
//                    (4 channels per view) and out_warped_image (3 channels per view).
//   image_height/width  output resolution H x W.
//
// Returns, in order:
//   rendered                      int returned by Rasterizer::forward (0 when P == 0)
//   out_color                     (NUM_CHANNELS, H, W) float
//   radii                         (P,) int32
//   out_normal_map                (NUM_NORMAL_CHANNELS, H, W) float
//   out_median_intersected_depth  (1, H, W) float
//   out_cam_feat                  (4 * nb_src_images, H, W) float
//   out_warped_image              (3 * nb_src_images, H, W) float
//   out_min_depth_diff            (1, H, W) float
//   out_camera_ray                (3, H, W) float
//   out_use_first_src_frame_mask  (1, H, W) int32
//   geomBuffer, binningBuffer, imgBuffer  byte scratch buffers that
//                    RasterizeGaussiansBackwardCUDA takes back as inputs.
std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeGaussiansCUDA(
	const torch::Tensor& background,
	const torch::Tensor& means3D,
	const torch::Tensor& colors,
	const torch::Tensor& opacity,
	const torch::Tensor& scales,
	const torch::Tensor& rotations,
	const float scale_modifier,
	const torch::Tensor& cov3D_precomp,
	const torch::Tensor& all_map,
	const torch::Tensor& viewmatrix,
	const torch::Tensor& projmatrix,
	const torch::Tensor& ref_to_src_list,
	const torch::Tensor& src_cam_pos,
	const torch::Tensor& src_images,
	const torch::Tensor& src_rendered_depths,
	const int nb_src_images,
	const int buffer_length,
	const float depth_error_threshold,
	const float tan_fovx,
	const float tan_fovy,
	const int image_height,
	const int image_width,
	const torch::Tensor& sh,
	const int degree,
	const torch::Tensor& campos,
	const bool prefiltered,
	const bool render_geo,
	const bool render_depth_only,
	const bool debug)
{
  TORCH_CHECK(means3D.ndimension() == 2 && means3D.size(1) == 3,
              "means3D must have dimensions (num_points, 3)");
  TORCH_CHECK(nb_src_images >= 0, "nb_src_images must be non-negative");

  const int P = means3D.size(0);
  const int H = image_height;
  const int W = image_width;

  auto int_opts = means3D.options().dtype(torch::kInt32);
  auto float_opts = means3D.options().dtype(torch::kFloat32);

  // Outputs are freshly allocated, so they are contiguous. Their data_ptr is
  // therefore handed to the kernel directly: calling .contiguous() on a
  // non-contiguous output would create a copy and silently drop the kernel's writes.
  torch::Tensor out_color = torch::zeros({NUM_CHANNELS, H, W}, float_opts);
  torch::Tensor radii = torch::zeros({P}, int_opts);
  torch::Tensor out_normal_map = torch::zeros({NUM_NORMAL_CHANNELS, H, W}, float_opts);
  torch::Tensor out_median_intersected_depth = torch::zeros({1, H, W}, float_opts);
  // --- Color aggregation feature ---
  // NOTE(review): sized per source view (4 / 3 channels per view). The original
  // referenced an undeclared `M` here; nb_src_images is the evident intent.
  // Confirm against the kernel's indexing.
  torch::Tensor out_cam_feat = torch::zeros({4 * nb_src_images, H, W}, float_opts);
  torch::Tensor out_warped_image = torch::zeros({3 * nb_src_images, H, W}, float_opts);
  torch::Tensor out_min_depth_diff = torch::zeros({1, H, W}, float_opts);
  torch::Tensor out_camera_ray = torch::zeros({3, H, W}, float_opts); // (This is optional)
  // --- Extra params for exposure ---
  torch::Tensor out_use_first_src_frame_mask = torch::zeros({1, H, W}, int_opts);

  // Byte scratch buffers start empty; the rasterizer resizes them via the
  // callbacks below. The callbacks capture the tensors by reference, and the
  // tensors outlive the forward call because they are locals of this function.
  torch::Device device(torch::kCUDA);
  torch::TensorOptions options(torch::kByte);
  torch::Tensor geomBuffer = torch::empty({0}, options.device(device));
  torch::Tensor binningBuffer = torch::empty({0}, options.device(device));
  torch::Tensor imgBuffer = torch::empty({0}, options.device(device));
  std::function<char*(size_t)> geomFunc = resizeFunctional(geomBuffer);
  std::function<char*(size_t)> binningFunc = resizeFunctional(binningBuffer);
  std::function<char*(size_t)> imgFunc = resizeFunctional(imgBuffer);

  int rendered = 0;
  if (P != 0)
  {
    // Number of SH coefficients per Gaussian (0 when colors are precomputed).
    int M = 0;
    if (sh.size(0) != 0)
    {
      M = sh.size(1);
    }

    // Inputs use .contiguous() so the kernel sees a dense layout. Any temporary
    // copy lives until the end of this full-expression, i.e. through the call.
    rendered = CudaRasterizer::Rasterizer::forward(
      geomFunc,
      binningFunc,
      imgFunc,
      P, degree, M,
      background.contiguous().data_ptr<float>(),
      W, H,
      means3D.contiguous().data_ptr<float>(),
      sh.contiguous().data_ptr<float>(),
      colors.contiguous().data_ptr<float>(),
      opacity.contiguous().data_ptr<float>(),
      scales.contiguous().data_ptr<float>(),
      scale_modifier,
      rotations.contiguous().data_ptr<float>(),
      cov3D_precomp.contiguous().data_ptr<float>(),
      all_map.contiguous().data_ptr<float>(),
      viewmatrix.contiguous().data_ptr<float>(),
      projmatrix.contiguous().data_ptr<float>(),
      ref_to_src_list.contiguous().data_ptr<float>(),
      src_cam_pos.contiguous().data_ptr<float>(),
      src_images.contiguous().data_ptr<float>(),
      src_rendered_depths.contiguous().data_ptr<float>(),
      nb_src_images,
      buffer_length,
      depth_error_threshold,
      campos.contiguous().data_ptr<float>(),
      tan_fovx,
      tan_fovy,
      prefiltered,
      out_color.data_ptr<float>(),
      radii.data_ptr<int>(),
      out_normal_map.data_ptr<float>(),
      out_median_intersected_depth.data_ptr<float>(),
      out_cam_feat.data_ptr<float>(),
      out_warped_image.data_ptr<float>(),
      out_min_depth_diff.data_ptr<float>(),
      out_camera_ray.data_ptr<float>(), // (This is optional)
      // --- Extra params for exposure ---
      out_use_first_src_frame_mask.data_ptr<int>(),
      render_geo,
      render_depth_only,
      debug);
  }

  return std::make_tuple(rendered, out_color, radii,
                         out_normal_map,
                         out_median_intersected_depth,
                         out_cam_feat, out_warped_image,
                         out_min_depth_diff,
                         out_camera_ray, out_use_first_src_frame_mask,
                         geomBuffer, binningBuffer, imgBuffer);
}

// Backward pass of the Gaussian rasterizer, exposed to Python as a torch extension.
//
// Allocates zero-initialised gradient tensors with the same options as `means3D`.
// When P > 0, it calls CudaRasterizer::Rasterizer::backward with the
// geometry/binning/image scratch buffers (geomBuffer, binningBuffer,
// imageBuffer) that RasterizeGaussiansCUDA returns.
//
// Notes:
//   * H and W are read from dL_dout_color.size(1) / size(2), i.e. it is indexed
//     as (C, H, W).
//   * M (SH coefficients per Gaussian) is sh.size(1) when sh is non-empty, else 0.
//   * R is forwarded unchanged. It is presumably the `rendered` count from the
//     forward pass (TODO confirm).
//
// Returns, in order:
//   dL_dmeans2D (P,3), dL_dmeans2D_abs (P,3), dL_dcolors (P,NUM_CHANNELS),
//   dL_dopacity (P,1), dL_dmeans3D (P,3), dL_dcov3D (P,6), dL_dsh (P,M,3),
//   dL_dscales (P,3), dL_drotations (P,4), dL_dall_map (P,NUM_PLANE_PARAMS).
//   dL_dconic (P,2,2) is passed to the kernel but not returned.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeGaussiansBackwardCUDA(
	const torch::Tensor& background,
	const torch::Tensor& normal_map_pixels,
	const torch::Tensor& intersected_depth_pixels,
	const torch::Tensor& warped_image_pixels,
	const torch::Tensor& means3D,
	const torch::Tensor& radii,
	const torch::Tensor& colors,
	const torch::Tensor& all_maps,
	const torch::Tensor& scales,
	const torch::Tensor& rotations,
	const float scale_modifier,
	const torch::Tensor& cov3D_precomp,
	const torch::Tensor& viewmatrix,
	const torch::Tensor& projmatrix,
	const torch::Tensor& ref_to_src_list,
	const torch::Tensor& src_cam_pos,
	const torch::Tensor& src_images,
	const torch::Tensor& src_rendered_depths,
	const int nb_src_images,
	const float tan_fovx,
	const float tan_fovy,
	const torch::Tensor& dL_dout_color,
	const torch::Tensor& dL_dout_normal_map,
	const torch::Tensor& dL_dout_median_intersected_depth,
	const torch::Tensor& dL_dout_warped_image,
	const torch::Tensor& sh,
	const int degree,
	const torch::Tensor& campos,
	const torch::Tensor& geomBuffer,
	const int R,
	const torch::Tensor& binningBuffer,
	const torch::Tensor& imageBuffer,
	const bool render_geo,
	const bool debug)
{
  const int P = means3D.size(0);
  const int H = dL_dout_color.size(1);
  const int W = dL_dout_color.size(2);

  int M = 0;
  if (sh.size(0) != 0)
  {
    M = sh.size(1);
  }

  const auto grad_opts = means3D.options();
  // Gradient tensors are freshly allocated (hence contiguous), so their
  // data_ptr is handed to the kernel directly.
  torch::Tensor dL_dmeans3D = torch::zeros({P, 3}, grad_opts);
  torch::Tensor dL_dmeans2D = torch::zeros({P, 3}, grad_opts);
  torch::Tensor dL_dmeans2D_abs = torch::zeros({P, 3}, grad_opts);
  torch::Tensor dL_dcolors = torch::zeros({P, NUM_CHANNELS}, grad_opts);
  torch::Tensor dL_dall_map = torch::zeros({P, NUM_PLANE_PARAMS}, grad_opts);
  torch::Tensor dL_dconic = torch::zeros({P, 2, 2}, grad_opts);
  torch::Tensor dL_dopacity = torch::zeros({P, 1}, grad_opts);
  torch::Tensor dL_dcov3D = torch::zeros({P, 6}, grad_opts);
  torch::Tensor dL_dsh = torch::zeros({P, M, 3}, grad_opts);
  torch::Tensor dL_dscales = torch::zeros({P, 3}, grad_opts);
  torch::Tensor dL_drotations = torch::zeros({P, 4}, grad_opts);

  if (P != 0)
  {
    // All inputs go through .contiguous(), including scales/rotations, which
    // the original passed raw. This matches the forward pass, so the kernel
    // always sees a dense layout. Temporary copies live until the call returns.
    CudaRasterizer::Rasterizer::backward(P, degree, M, R,
      background.contiguous().data_ptr<float>(),
      normal_map_pixels.contiguous().data_ptr<float>(),
      intersected_depth_pixels.contiguous().data_ptr<float>(),
      warped_image_pixels.contiguous().data_ptr<float>(),
      W, H,
      means3D.contiguous().data_ptr<float>(),
      sh.contiguous().data_ptr<float>(),
      colors.contiguous().data_ptr<float>(),
      all_maps.contiguous().data_ptr<float>(),
      scales.contiguous().data_ptr<float>(),
      scale_modifier,
      rotations.contiguous().data_ptr<float>(),
      cov3D_precomp.contiguous().data_ptr<float>(),
      viewmatrix.contiguous().data_ptr<float>(),
      projmatrix.contiguous().data_ptr<float>(),
      ref_to_src_list.contiguous().data_ptr<float>(),
      src_cam_pos.contiguous().data_ptr<float>(),
      src_images.contiguous().data_ptr<float>(),
      src_rendered_depths.contiguous().data_ptr<float>(),
      nb_src_images,
      campos.contiguous().data_ptr<float>(),
      tan_fovx,
      tan_fovy,
      radii.contiguous().data_ptr<int>(),
      reinterpret_cast<char*>(geomBuffer.contiguous().data_ptr()),
      reinterpret_cast<char*>(binningBuffer.contiguous().data_ptr()),
      reinterpret_cast<char*>(imageBuffer.contiguous().data_ptr()),
      dL_dout_color.contiguous().data_ptr<float>(),
      dL_dout_normal_map.contiguous().data_ptr<float>(),
      dL_dout_median_intersected_depth.contiguous().data_ptr<float>(),
      dL_dout_warped_image.contiguous().data_ptr<float>(),
      dL_dmeans2D.data_ptr<float>(),
      dL_dmeans2D_abs.data_ptr<float>(),
      dL_dconic.data_ptr<float>(),
      dL_dopacity.data_ptr<float>(),
      dL_dcolors.data_ptr<float>(),
      dL_dmeans3D.data_ptr<float>(),
      dL_dcov3D.data_ptr<float>(),
      dL_dsh.data_ptr<float>(),
      dL_dscales.data_ptr<float>(),
      dL_drotations.data_ptr<float>(),
      dL_dall_map.data_ptr<float>(),
      render_geo,
      debug);
  }

  return std::make_tuple(dL_dmeans2D, dL_dmeans2D_abs, dL_dcolors, dL_dopacity, dL_dmeans3D, dL_dcov3D, dL_dsh, dL_dscales, dL_drotations, dL_dall_map);
}

// Returns a (P,) bool tensor, allocated with means3D's device/options, with one
// flag per Gaussian. All flags start false. When P > 0 they are set by
// CudaRasterizer::Rasterizer::markVisible (presumably a view-frustum visibility
// test; the semantics live in that function).
torch::Tensor markVisible(
		torch::Tensor& means3D,
		torch::Tensor& viewmatrix,
		torch::Tensor& projmatrix)
{
  const int num_points = means3D.size(0);
  const auto flag_opts = means3D.options().dtype(at::kBool);
  torch::Tensor present = torch::full({num_points}, false, flag_opts);

  // No Gaussians: there is nothing to test, so return the all-false tensor.
  if (num_points == 0)
  {
    return present;
  }

  // Keep the .contiguous() calls inside the call expression so any temporary
  // copy stays alive until markVisible returns.
  CudaRasterizer::Rasterizer::markVisible(num_points,
    means3D.contiguous().data_ptr<float>(),
    viewmatrix.contiguous().data_ptr<float>(),
    projmatrix.contiguous().data_ptr<float>(),
    present.contiguous().data_ptr<bool>());

  return present;
}
