#ifndef CUDA_RASTERIZER_H_INCLUDED
#define CUDA_RASTERIZER_H_INCLUDED

#include <vector>
#include <functional>
#include <torch/extension.h>

// #include <pybind11/pybind11.h>
// #include <pybind11/stl.h>
// Either of the above two headers could make the compilation successful; 
// however, <pybind11/pybind11.h> will make the compilation very slow.

namespace CudaRasterizer
{
	class Rasterizer
	{
	public:

		static void markVisible(
			int P,
			float* means3D,
			float* viewmatrix,
			float* projmatrix,
			bool* present);


		/////////////////////////////// Preprocess ///////////////////////////////



		// Forward rendering procedure for differentiable rasterization
		// of Gaussians.
		static int preprocessForward(
			float2* means2D,
			float* depths,
			int* radii,
			float* cov3D,
			float4* conic_opacity,
			float* rgb,
			bool* clamped,//the above are all per-Gaussian intemediate results.
			const int P, int D, int M,
			const int width, int height,
			const float* means3D,
			const float* scales,
			const float* rotations,
			const float* shs,
			const float* opacities,//3dgs parameters
			const float scale_modifier,
			const float* viewmatrix,
			const float* projmatrix,
			const float* cam_pos,
			const float tan_fovx, float tan_fovy,
			const bool prefiltered,
			bool debug,//raster_settings
			const pybind11::dict &args);

		static void preprocessBackward(
			const int* radii,
			const float* cov3D,
			const bool* clamped,//the above are all per-Gaussian intemediate results.
			const int P, int D, int M, int R,
			const int width, int height,//rasterization setting.
			const float* means3D,
			const float* scales,
			const float* rotations,
			const float* shs,//input of this operator
			const float scale_modifier,
			const float* viewmatrix,
			const float* projmatrix,
			const float* campos,
			const float tan_fovx, float tan_fovy,//rasterization setting.
			const float* dL_dmean2D,
			const float* dL_dconic,
			float* dL_dcolor,//gradients of output of this operator
			float* dL_dmean3D,
			float* dL_dcov3D,
			float* dL_dscale,
			float* dL_drot,
			float* dL_dsh,//gradients of input of this operator
			bool debug,
			const pybind11::dict &args);
		

		////////////////////// GetDistributionStrategy ////////////////////////

		static void getDistributionStrategy(
			std::function<char* (size_t)> distBuffer,
			const int P,
			const int width, int height,
			float2* means2D,
			int* radii,
			bool* compute_locally,
			bool debug,
			const pybind11::dict &args);



		/////////////////////////////// Render ///////////////////////////////




		static int renderForward(
			std::function<char* (size_t)> geometryBuffer,
			std::function<char* (size_t)> binningBuffer,
			std::function<char* (size_t)> imageBuffer,
			const int P,
			const float* background,
			const int width, int height,
			float2* means2D,
			float* depths,
			int* radii,
			float4* conic_opacity,
			float* rgb,
			float* semantic_feature,
			float* entity_feature,
			bool* compute_locally,
			float* out_color,
			float* out_feature_map, 
			float* out_entity_map, 
			float* out_opacity_map, 
			int* n_render,
			int* n_consider,// If your uint32_t array contains values higher than 2,147,483,647, they will overflow when converted to int.
			int* n_contrib,//array of results for this function. 
			bool debug,
			const pybind11::dict &args);

		static void renderBackward(
			const int P, int R,
			const float* background,
			const int width, int height,//rasterization settings. 
			char* geom_buffer,
			char* binning_buffer,
			char* img_buffer,//buffer that contains intermedia results
			bool* compute_locally,
			const float* dL_dpix,//gradient of output
			const float* dL_dpix_semantic, 
			const float* dL_dpix_entity, 
			const float* dL_dpix_opacity, 
			float* dL_dmean2D,
			float* dL_dconic,
			float* dL_dopacity,
			float* dL_dcolor,//gradient of inputs
			float* dL_dsemantic_feature, 
			float* dL_dentity_feature, 
			const float2* means2D,
			const float4* conic_opacity,
			const float* rgb,//inputs
			const float* semantic_feature,//inputs
			const float* entity_feature,//inputs
			bool debug,
			const pybind11::dict &args);


	};
};

#endif