﻿#ifndef RENDER_H
#define RENDER_H

#include "sdf.cginc"
#include "shading.cginc"
#include "render_defs.cginc"

// Builds an orthonormal camera basis from an eye position, a look-at target and a roll angle.
//   ro: camera (ray origin) position
//   ta: point the camera looks at
//   cr: roll angle, fed to sin/cos to build the "up" hint vector
// Returns a float3x3 constructed from (right, up, forward), in that order.
float3x3 setCamera(in float3 ro, in float3 ta, float cr)
{
    // viewing direction
    float3 forward = normalize(ta - ro);
    // rolled "up" hint lying in the xy-plane; only used to derive the right vector
    float3 upHint = float3(sin(cr), cos(cr), 0.);
    float3 right = normalize(cross(forward, upHint));
    // right and forward are unit length and perpendicular, so no normalize is needed here
    float3 up = cross(right, forward);
    return float3x3(right, up, forward);
}

#ifdef NO_CTSS
// Single-sample sphere tracer (this variant is compiled when NO_CTSS is defined).
//   ro:        ray origin
//   rd:        ray direction
//   rdx, rdy:  ray differentials, forwarded unchanged to shadedColor
//   tan_theta: scales ray depth into the cone radius used by the hit test
// Returns the shaded color of the first hard hit, or the background color when
// the march ends without one.
float3 render(in float3 ro, in float3 rd, in float3 rdx, in float3 rdy, const float tan_theta)
{
    float farLimit = TMAX;

    #ifdef BBOX_FLOOR
    // intersect the ray with the plane y = BBOX_FLOOR_Y; do not march past it
    float floorDepth = (BBOX_FLOOR_Y - ro.y) / rd.y;
    if (floorDepth > 0.)
    {
        farLimit = min(farLimit, floorDepth);
    }
    #endif

    // march along the ray, starting at TMIN
    float t = TMIN;
    for (int i = 0; i < MAX_RAYMARCH_STEPS && t < farLimit; i++)
    {
        float3 p = ro + rd * t;
        float2 h = sdf(p);

        // cone intersection criterion: compare the SDF value against the cone radius at this depth
        float coneRad = t * tan_theta;
        float coneOcclusion = (1. - h.x / coneRad) / 2.;

        if (OCCLUSION_HARD < coneOcclusion)
        {
            // hard hit: x = ray depth, y = second sdf component (object code)
            float2 hit = float2(t, h.y);

#ifdef DEBUG_DEPTH
            float depth = hit.x / TMAX;
            return float3(depth * depth, 0., 0.);
#endif

            float3 pos = ro + rd * hit.x;
            float3 nor = calcNormal(pos);

#ifdef DEBUG_NORMAL
            return normal2color(nor);
#endif

            return shadedColor(ro, rd, rdx, rdy, hit, pos, nor);
        }

        // no hit yet: advance by the scaled SDF distance
        float stride = h.x * SPHERE_TRACE_STEPSIZE_MULT;
        #ifdef GRID_TRACE
        // cap the move by the distance to the grid cell boundary (at least 0.01)
        float d2cell = abs(sd2dBox(opRep(p, GRID).xz, 0.5*GRID.xz));
        t += sign(stride) * min(max(d2cell, 0.01), abs(stride));
        #else
        t += stride;
        #endif
    }

    // nothing was hit: background
    #ifdef WHITE_BG
    return float3(1., 1., 1.);
    #else
    return float3(clamp(FOG_COLOR - max(rd.y, 0.) * 0.3, 0., 1.));
    #endif
}

#else

// Normalized heights (y / r) of sample lines 1..5 used by getVisibilityMask:
// 2 * i / 5 - 1 for i = 1..5 (line 0 sits at -1).
#define TWO_OVER_FIVE_MINUS_ONE (2. / 5. - 1.)
#define FOUR_OVER_FIVE_MINUS_ONE (4. / 5. - 1.)
#define SIX_OVER_FIVE_MINUS_ONE (6. / 5. - 1.)
#define EIGHT_OVER_FIVE_MINUS_ONE (8. / 5. - 1.)
#define TEN_OVER_FIVE_MINUS_ONE (10. / 5. - 1.)

// Lower bound applied to the absolute value of each 2D normal component in getVisibilityMask.
#define MIN_NORMAL 0.001f

// Shifts x left by n bits when n >= 0, otherwise right by -n bits.
// Every argument use is parenthesized so that compound expressions such as
// SHIFT(a | b, k - 3) expand with the intended precedence.
// NOTE: n is evaluated more than once; do not pass expressions with side effects.
#define SHIFT(x, n) ((0 <= (n)) ? ((x) << (n)) : ((x) >> -(n)))

/**
    Builds a 32-bit sub-pixel visibility (occupancy) mask from the current cone radius,
    cone occlusion and projected surface normal.

    The mask covers a ~6x6 sub-pixel grid sampled along 6 horizontal lines.
    The lines contribute 4, 6, 6, 6, 6 and 4 bits (32 bits total); the four corner
    cells are left out because a full 6x6 grid would need 36 bits.
    Fully visible mask (X marks the omitted corner cells):
        X1111X
        111111
        111111
        111111
        111111
        X1111X
    Packed into one int, line 0 in the highest bits:
        1111 111111 111111 111111 111111 1111

    coneOcclusion: cone occlusion value of the hit
    r:             cone radius of the hit
    nor:           2D surface normal (x, y)
    Returns the packed visibility mask.
*/
int getVisibilityMask(in float coneOcclusion, in float r, in float2 nor)
{
    // keep both normal components away from zero: clamp the magnitude to at least
    // MIN_NORMAL, restore the sign, then renormalize (they are divided by below)
    float signX = (nor.x < 0.) ? -1. : 1.;
    float signY = (nor.y < 0.) ? -1. : 1.;
    float2 n = normalize(float2(
        clamp(abs(nor.x), MIN_NORMAL, 1.) * signX,
        clamp(abs(nor.y), MIN_NORMAL, 1.) * signY
    ));

    // negated offset derived from cone radius and occlusion
    float offset = -(r * (1. - 2. * coneOcclusion));

    // the surface edge inside the cone is modelled as the line y = slope * x + intercept
    float slope = -n.x / n.y;
    float intercept = offset * (n.y * n.y + n.x * n.x) / n.y;

    // For a sample line at height py (py = 2/5 * r * i - r, i = 0..5):
    //   vis = ((py - intercept) / slope / r + 1) / 2 = (py - intercept) * scale + 0.5
    float scale = 0.5f / slope / r;

    // when n.x is positive, vis - 1 is used instead of vis (value is -1 or 0)
    float visBias = (0 < n.x) ? -1. : 0.;

    // per-line visibility, saturated to 0..1, converted to a shift amount in -6..6
    int s0 = (int)(6. * (saturate((-r - intercept) * scale + 0.5) + visBias));
    int s1 = (int)(6. * (saturate((r * TWO_OVER_FIVE_MINUS_ONE - intercept) * scale + 0.5) + visBias));
    int s2 = (int)(6. * (saturate((r * FOUR_OVER_FIVE_MINUS_ONE - intercept) * scale + 0.5) + visBias));
    int s3 = (int)(6. * (saturate((r * SIX_OVER_FIVE_MINUS_ONE - intercept) * scale + 0.5) + visBias));
    int s4 = (int)(6. * (saturate((r * EIGHT_OVER_FIVE_MINUS_ONE - intercept) * scale + 0.5) + visBias));
    int s5 = (int)(6. * (saturate((r * TEN_OVER_FIVE_MINUS_ONE - intercept) * scale + 0.5) + visBias));

    int edgeBits = 0x1E; // 0b00011110: lines 0 and 5 drop their two side bits
    int fullBits = 0x3F; // 0b00111111: lines 1..4 use all six bits

    // shift a full row by the line's amount, keep the bits valid for that line,
    // then move the row to its slot in the packed word
    int row0 = (SHIFT(fullBits, s0) & edgeBits) << 27; // bits 28..31
    int row1 = (SHIFT(fullBits, s1) & fullBits) << 22; // bits 22..27
    int row2 = (SHIFT(fullBits, s2) & fullBits) << 16; // bits 16..21
    int row3 = (SHIFT(fullBits, s3) & fullBits) << 10; // bits 10..15
    int row4 = (SHIFT(fullBits, s4) & fullBits) << 4;  // bits 4..9
    int row5 = (SHIFT(fullBits, s5) & edgeBits) >> 1;  // bits 0..3

    return row0 | row1 | row2 | row3 | row4 | row5;
}

// Counts how many of the 32 bits of num are set.
// Exactly 32 shift steps are performed, so each bit is looked at once.
int bitCountOnes(int num) {
    int ones = 0;
    int bits = num;
    for (int remaining = 32; remaining > 0; --remaining)
    {
        // add the lowest bit, then drop it
        ones += (bits & 1);
        bits >>= 1;
    }
    return ones;
}

/**
    Cone-traced multi-sample renderer (this variant is compiled when NO_CTSS is not defined).

    Marches along the ray for at most MAX_RAYMARCH_STEPS steps. At every step the SDF value
    is compared with the cone radius (t * tan_theta) to obtain a cone occlusion value, which
    is classified against OCCLUSION_SOFT / OCCLUSION_HARD / OCCLUSION_STOP. A sample position
    is recorded when a soft or hard hit is entered; the sample is shaded when the soft hit is
    left or a full hit is reached. Samples are accumulated with a weight (1.0, or the visible
    sub-pixel ratio when USE_CTSS_WEIGHTED is defined). If the march ends without a full hit,
    the background is blended in as well.

    ro:        ray origin
    rd:        ray direction
    rdx, rdy:  ray differentials, forwarded unchanged to shadedColor
    tan_theta: scales ray depth into cone radius
    Returns the weighted average of all accumulated colors.
*/
float3 render(in float3 ro, in float3 rd, in float3 rdx, in float3 rdy, const float tan_theta)
{
    float tmax = TMAX;
    #ifdef BBOX_FLOOR
    // raytrace floor plane: do not march past y = BBOX_FLOOR_Y
    float tp1 = (BBOX_FLOOR_Y - ro.y) / rd.y;
    if (tp1 > 0.)
    {
        tmax = min(tmax, tp1);
    }
    #endif

    // CTSS variables
    int numSamples = 0;
    bool sampleBg = true; // whether the background still needs to be sampled
    bool hasHardHit = false; // whether the current soft-hit group already contains a hard hit

    float3 colorTotal = float3(0., 0., 0.);
    float weightTotal = 0.;

    #ifdef USE_CTSS_WEIGHTED
    float maxConeOcclusion = 0.;
    float coneRadiusAtMaxOcclusion = 0.;
    int visibilityMask = 0; // sub-pixel bits already covered by earlier samples
    #endif

    // sphere tracing
    float t = TMIN;
    float tP = TMIN; // previous ray depth

    bool softHitP = false; // previous
    bool hardHitP = false; // previous
    bool hitSampled = false; // if the current hit was already sampled

    // ray depth and object code
    // hP is initialized because it is read on a hard-hit entry, which can already
    // happen on the first iteration (hardHitP starts out false)
    float2 hP = float2(0., 0.); // previous (step-scaled) SDF result
    float2 hSample = float2(0., 0.); // sample

    for (int i = 0; i < MAX_RAYMARCH_STEPS; i++)
    {
        float3 p = ro + rd * t;
        float2 h = sdf(p);

        float coneRad = t * tan_theta;
        float coneOcclusion = (1. - h.x / coneRad) / 2.;

        bool softHit = OCCLUSION_SOFT < coneOcclusion;
        bool hardHit = OCCLUSION_HARD < coneOcclusion;
        bool fullHit = OCCLUSION_STOP < coneOcclusion;

        // transitions relative to the previous step
        bool hardHitEntry = hardHit && !hardHitP;
        bool softHitEntry = softHit && !softHitP;
        bool softHitExit = !softHit && softHitP;

        if (softHit)
        {
            if (softHitEntry)
            {
                // a new hit group starts: reset per-group state
                hitSampled = false;
                hasHardHit = false;

#ifdef USE_CTSS_WEIGHTED
                maxConeOcclusion = coneOcclusion;
                coneRadiusAtMaxOcclusion = coneRad;
#endif
            }

            // update position for soft and hard entry points
            if (softHitEntry)
            {
                // record backtraced cone intersection
                // current ray depth minus safe cone radius
                hSample = float2(CONE_BACKTRACE(t, tan_theta), h.y);
            }

            if (hardHitEntry && !hasHardHit)
            {
                // record backtraced hard hit intersection
                // previous ray depth + previous SDF value
                hSample = float2(tP + hP.x, hP.y);
            }

            hasHardHit = hasHardHit || hardHitEntry;

            #ifdef USE_CTSS_WEIGHTED
            // track the strongest occlusion seen within this hit group
            if (maxConeOcclusion < coneOcclusion)
            {
                maxConeOcclusion = coneOcclusion;
                coneRadiusAtMaxOcclusion = coneRad;
            }
            #endif
        }

        // check if need to sample: budget not exhausted, and the hit group ended or the cone is fully blocked
        bool needSample = (numSamples < CTSS_NUM_SAMPLES - 1) && (softHitExit || fullHit);
        if (!hitSampled && needSample)
        {
            float3 pos = ro + rd * hSample.x;
            float3 nor = calcNormal(pos);

#ifdef USE_CTSS_WEIGHTED
            // project the normal from world space to camera space; rd is the camera ray
            // assumes camera is level with the horizon
            float3 right = normalize(cross(rd, float3(0., 1., 0.)));
            float3 up = normalize(cross(right, rd));
            float2 nor2d = normalize(float2(dot(right, nor), dot(up, nor)));

            // compute current visibility mask given normal, cone occlusion, and cone radius of current hit
            int visibilityMaskCrt = getVisibilityMask(maxConeOcclusion, coneRadiusAtMaxOcclusion, nor2d);
            visibilityMaskCrt &= ~visibilityMask; // correlation with previous hits, removes invisible bits
            visibilityMask |= visibilityMaskCrt; // update visibility mask given current visibility, adds visible bits

            float visibility = bitCountOnes(visibilityMaskCrt) / 32.;  // visible bit ratio
            float weight = max(MIN_SAMPLE_WEIGHT, visibility);

            // check if need to stop tracing based on the updated visibility mask (if no more possible visible bits)
            fullHit = fullHit || (bitCountOnes(visibilityMask) == 32);
#else
            float weight = 1.0;
#endif

            colorTotal += weight * shadedColor(ro, rd, rdx, rdy, hSample, pos, nor);
            weightTotal += weight;

            // reset vars and increment sample counter
            hitSampled = true;
            hasHardHit = false;
            numSamples++;
        }

        #ifdef DEBUG_DEPTH
        return float3(CONE_BACKTRACE(hSample.x, tan_theta) / TMAX, 0., 0.);
        #endif

        // once the cone is fully blocked the background can no longer contribute
        sampleBg = sampleBg && !fullHit;

        if (fullHit || tmax < t)
        {
            break;
        }

        // update previous
        softHitP = softHit;
        hardHitP = hardHit;

        h.x *= SPHERE_TRACE_STEPSIZE_MULT;

        // update previous h and t
        hP = h;
        tP = t;

        // increment depth along ray
        #ifdef GRID_TRACE
        // if on a grid with randomized objects (for each cell),
        // need to cap maximum move distance to distance from the point within cell to the boundary of the cell,
        // otherwise, the trace may miss geometry and have visual artifacts
        float d2cell = abs(sd2dBox(opRep(p, GRID).xz, 0.5*GRID.xz));
        t += max(min(max(d2cell, 0.01), abs(h.x)), CONE_RAD_STEP_MULT*coneRad);
        #else
        t += max(h.x, CONE_RAD_STEP_MULT*coneRad); // speed up sphere tracing near hard hits (sdf ~= 0)
        #endif
    }

    // background (sky)
    if (sampleBg)
    {
#ifdef USE_CTSS_WEIGHTED
        // weight the background by the fraction of sub-pixel bits not covered by any sample
        float bgWeight = bitCountOnes(~visibilityMask) / 32.0;
#else
        float bgWeight = 1.0;
#endif

#ifdef WHITE_BG
        colorTotal += bgWeight * float3(1., 1., 1.);
#else
        colorTotal += bgWeight * float3(clamp(FOG_COLOR - max(rd.y, 0.) * 0.3, 0., 1.));
#endif
        weightTotal += bgWeight;
    }

    // normalize; the lower bound keeps the divisor away from zero
    colorTotal /= max(MIN_SAMPLE_WEIGHT, weightTotal);

#ifdef DEBUG_NUM_SAMPLES
    float v = float(numSamples) / float(CTSS_NUM_SAMPLES + 1);
    colorTotal = lerp(colorTotal, v * v, 0.9975);
#endif

    return colorTotal;
}
#endif

// Computes the final color of one screen pixel.
//   pixel:        pixel coordinates
//   ro:           camera position (ray origin)
//   ta:           camera look-at target
//   tan_theta:    forwarded unchanged to render
//   focal_length: z component of the un-normalized camera-space ray direction
// When AA > 1, an AA x AA grid of sub-pixel rays is rendered and averaged;
// otherwise a single ray is rendered. Gain and gamma are applied to the result.
float3 pixelColor(in float2 pixel, in float3 ro, in float3 ta, in float tan_theta, in float focal_length)
{
    // camera-to-world transformation (no roll)
    float3x3 ca = setCamera(ro, ta, 0.);

    // ray differentials: rays through the pixels one step to the right and one step up
    float2 uvStepX = (2. * (pixel + float2(1., 0.)) - _ScreenParams.xy) / _ScreenParams.y;
    float2 uvStepY = (2. * (pixel + float2(0., 1.)) - _ScreenParams.xy) / _ScreenParams.y;
    float3 rdx = mul(normalize(float3(uvStepX, focal_length)), ca);
    float3 rdy = mul(normalize(float3(uvStepY, focal_length)), ca);

    float3 accum = float3(0., 0., 0.);

    #if AA>1
    for( int sx=0; sx<AA; sx++ )
    for( int sy=0; sy<AA; sy++ )
    {
        // sub-pixel offset of this sample, centered around the pixel
        float2 subOffset = float2(float(sx), float(sy)) / float(AA) - (0.5 - 1. / float(2 * AA));
        float2 uv = (2. * (pixel + subOffset) - _ScreenParams.xy) / _ScreenParams.y;
    #else
        float2 uv = (2. * pixel - _ScreenParams.xy) / _ScreenParams.y;
    #endif

        // camera-space direction, then transformed by the camera basis
        float3 viewDir = normalize(float3(uv, focal_length));
        float3 rd = mul(viewDir, ca);

        accum += render(ro, rd, rdx, rdy, tan_theta);
    #if AA>1
    }
    // average over all sub-pixel samples
    accum /= float(AA*AA);
    #endif

    // gain
    accum = accum * 3. / (2.5 + accum);

    // gamma
    return pow(accum, float3(0.4545, 0.4545, 0.4545));
}

#endif
