#include "forward2d.cuh"
#include "helpers.cuh"
#include <algorithm>
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <iostream>
namespace cg = cooperative_groups;





/**
 * Forward projection of 2D gaussians whose covariance is given by a
 * lower-triangular factor L = [[l11, 0], [l21, l22]].
 *
 * Launch layout: one thread per gaussian. The gaussian index is the thread's
 * rank within the whole grid; threads whose rank is >= num_points exit
 * without touching memory. No shared memory is used.
 *
 * @param num_points     Number of gaussians; non-positive values make the
 *                       kernel a no-op.
 * @param means2d        Per-gaussian centre. It is mapped to pixels by
 *                       0.5 * size * m + 0.5 * size, i.e. -1 maps to 0 and
 *                       +1 maps to img_size.
 * @param L_elements     Per-gaussian (l11, l21, l22) packed as (x, y, z).
 * @param img_size       Image extent; only .x and .y are read.
 * @param tile_bounds    Forwarded unchanged to get_tile_bbox.
 * @param clip_thresh    Unused by this kernel; kept for interface
 *                       compatibility.
 * @param xys            [out] Pixel-space centre. Written only when
 *                       compute_cov2d_bounds succeeds.
 * @param depths         [out] Always set to 0 for every processed gaussian.
 * @param radii          [out] Radius truncated to int; 0 when
 *                       compute_cov2d_bounds fails.
 * @param conics         [out] Conic produced by compute_cov2d_bounds. Written
 *                       only when that call succeeds.
 * @param num_tiles_hit  [out] Area of the tile bounding box; 0 when the
 *                       gaussian is rejected or the box is empty.
 */
__global__ void project_gaussians_2d_forward_kernel(
    const int num_points,
    const float2* __restrict__ means2d,
    const float3* __restrict__ L_elements,
    const dim3 img_size,
    const dim3 tile_bounds,
    const float clip_thresh,
    float2* __restrict__ xys,
    float* __restrict__ depths,
    int* __restrict__ radii,
    float3* __restrict__ conics,
    int32_t* __restrict__ num_tiles_hit
) {
    // Keep the grid rank at full width until after the guard, and reject a
    // non-positive count explicitly: comparing an unsigned index against a
    // negative int would convert the int to a huge unsigned value and let
    // every thread through.
    const unsigned long long rank = cg::this_grid().thread_rank();
    if (num_points <= 0 ||
        rank >= static_cast<unsigned long long>(num_points)) {
        return;
    }
    const unsigned idx = static_cast<unsigned>(rank);

    // Reset the scalar outputs up front so that every early exit below
    // leaves them in a defined state.
    radii[idx] = 0;
    num_tiles_hit[idx] = 0;
    depths[idx] = 0.0f; // this kernel has no depth input; it always emits 0

    // Read each input element once instead of once per component.
    const float2 mean = means2d[idx];
    const float3 chol = L_elements[idx];

    // Affine map of the centre into pixel coordinates.
    float2 center = {0.5f * img_size.x * mean.x + 0.5f * img_size.x,
                     0.5f * img_size.y * mean.y + 0.5f * img_size.y};

    const float l11 = chol.x;
    const float l21 = chol.y;
    const float l22 = chol.z;

    // Unique entries (xx, xy, yy) of the symmetric product L * L^T.
    float3 cov2d = make_float3(l11 * l11, l11 * l21, l21 * l21 + l22 * l22);

    // Pad the diagonal by a constant variance, then floor it.
    // NOTE(review): 1/12 is presumably the variance of a unit-wide box
    // (pixel footprint) -- confirm the intent.
    const float px_var = 1.0f / 12.0f;
    cov2d.x += px_var;
    cov2d.z += px_var;
    const float min_var = 0.25f * 0.25f;
    cov2d.x = fmaxf(cov2d.x, min_var);
    cov2d.z = fmaxf(cov2d.z, min_var);

    float3 conic;
    float radius;
    const bool ok = compute_cov2d_bounds(cov2d, conic, radius);
    if (!ok) {
        // Rejected by the helper: radii / num_tiles_hit / depths stay 0,
        // conics and xys are left untouched.
        return;
    }

    conics[idx] = conic;
    xys[idx] = center;
    radii[idx] = (int)radius;

    uint2 tile_min, tile_max;
    get_tile_bbox(center, radius, tile_bounds, tile_min, tile_max);
    // The extents are computed on unsigned components.
    // NOTE(review): this relies on get_tile_bbox returning
    // tile_max >= tile_min component-wise -- confirm in helpers.cuh.
    const int32_t tile_area =
        (tile_max.x - tile_min.x) * (tile_max.y - tile_min.y);
    if (tile_area <= 0) {
        // Empty tile box: num_tiles_hit stays 0.
        return;
    }
    num_tiles_hit[idx] = tile_area;
}

/**
 * Forward projection of 2D gaussians whose covariance is built from a
 * per-gaussian scale and rotation: cov = M * M^T with M = R * S.
 *
 * Launch layout: one thread per gaussian. The gaussian index is the thread's
 * rank within the whole grid; threads whose rank is >= num_points exit
 * without touching memory. No shared memory is used.
 *
 * @param num_points     Number of gaussians; non-positive values make the
 *                       kernel a no-op.
 * @param means2d        Per-gaussian centre. It is mapped to pixels by
 *                       0.5 * size * m + 0.5 * size, i.e. -1 maps to 0 and
 *                       +1 maps to img_size.
 * @param scales2d       Per-gaussian scale, passed to scale_to_mat2d.
 * @param rotation       Per-gaussian rotation parameter, passed to rotmat2d
 *                       (presumably an angle -- confirm units in helpers.cuh).
 * @param img_size       Image extent; only .x and .y are read.
 * @param tile_bounds    Forwarded unchanged to get_tile_bbox.
 * @param clip_thresh    Unused by this kernel; kept for interface
 *                       compatibility.
 * @param xys            [out] Pixel-space centre. Written only when
 *                       compute_cov2d_bounds succeeds.
 * @param depths         [out] Always set to 0 for every processed gaussian.
 * @param radii          [out] Radius truncated to int; 0 when
 *                       compute_cov2d_bounds fails.
 * @param conics         [out] Conic produced by compute_cov2d_bounds. Written
 *                       only when that call succeeds.
 * @param num_tiles_hit  [out] Area of the tile bounding box; 0 when the
 *                       gaussian is rejected or the box is empty.
 */
__global__ void project_gaussians_2d_scale_rot_forward_kernel(
    const int num_points,
    const float2* __restrict__ means2d,
    const float2* __restrict__ scales2d,
    const float* __restrict__ rotation,
    const dim3 img_size,
    const dim3 tile_bounds,
    const float clip_thresh,
    float2* __restrict__ xys,
    float* __restrict__ depths,
    int* __restrict__ radii,
    float3* __restrict__ conics,
    int32_t* __restrict__ num_tiles_hit
) {
    // Keep the grid rank at full width until after the guard, and reject a
    // non-positive count explicitly: comparing an unsigned index against a
    // negative int would convert the int to a huge unsigned value and let
    // every thread through.
    const unsigned long long rank = cg::this_grid().thread_rank();
    if (num_points <= 0 ||
        rank >= static_cast<unsigned long long>(num_points)) {
        return;
    }
    const unsigned idx = static_cast<unsigned>(rank);

    // Reset the scalar outputs up front so that every early exit below
    // leaves them in a defined state.
    radii[idx] = 0;
    num_tiles_hit[idx] = 0;
    depths[idx] = 0.0f; // this kernel has no depth input; it always emits 0

    // Read the centre once instead of once per component.
    const float2 mean = means2d[idx];

    // Affine map of the centre into pixel coordinates.
    float2 center = {0.5f * img_size.x * mean.x + 0.5f * img_size.x,
                     0.5f * img_size.y * mean.y + 0.5f * img_size.y};

    // Covariance = (R * S) * (R * S)^T. No variance padding or floor is
    // applied to it in this kernel.
    glm::mat2 R = rotmat2d(rotation[idx]);
    glm::mat2 S = scale_to_mat2d(scales2d[idx]);
    glm::mat2 M = R * S;
    glm::mat2 tmp = M * glm::transpose(M);

    // tmp is symmetric, so three entries describe it completely.
    float3 cov2d = make_float3(tmp[0][0], tmp[0][1], tmp[1][1]);

    float3 conic;
    float radius;
    const bool ok = compute_cov2d_bounds(cov2d, conic, radius);
    if (!ok) {
        // Rejected by the helper: radii / num_tiles_hit / depths stay 0,
        // conics and xys are left untouched.
        return;
    }

    conics[idx] = conic;
    xys[idx] = center;
    radii[idx] = (int)radius;

    uint2 tile_min, tile_max;
    get_tile_bbox(center, radius, tile_bounds, tile_min, tile_max);
    // The extents are computed on unsigned components.
    // NOTE(review): this relies on get_tile_bbox returning
    // tile_max >= tile_min component-wise -- confirm in helpers.cuh.
    const int32_t tile_area =
        (tile_max.x - tile_min.x) * (tile_max.y - tile_min.y);
    if (tile_area <= 0) {
        // Empty tile box: num_tiles_hit stays 0.
        return;
    }
    num_tiles_hit[idx] = tile_area;
}