// Compile-time configuration.
// KERNEL_EVAL: expression in `x` that phi() returns for each term summed in out_volt_fixed.
//   NOTE(review): currently the identity; presumably substituted per experiment by a generator — confirm.
#define KERNEL_EVAL x
// Number of trains: second grid dimension, and the size of the weights / inv_widths /
// neg_delays_inv_widths arrays and of the innermost axis of V.
#define TRAINS_OUT 100
// When true, out_volt_fixed prints diagnostics whenever an accumulator or output becomes NaN.
#define DEBUG true

// Threads per block. The kernel's shuffle reduction (offset 16, FULL_MASK) assumes one full warp.
#define THREADS 32
// Participation mask for the warp shuffles: all 32 lanes.
#define FULL_MASK 0xffffffff

// Problem sizes used by main(): N (samples), Ti (input times per train), To (evaluation times).
#define N_SAMPLES 1000
#define N_TI 100
#define N_TO 300

#include <sm_60_atomic_functions.h>
#include <cstdio>

#include <cstdio>
#include <vector>

// Per-term function applied inside out_volt_fixed to each scaled time difference.
// The body is whatever expression in `x` the KERNEL_EVAL macro expands to; with the
// current `#define KERNEL_EVAL x` it is the identity.
// NOTE(review): presumably KERNEL_EVAL is swapped out per experiment — confirm with the generator.
__device__ __inline__
float phi(float x) {
    const float value = KERNEL_EVAL;
    return value;
}

// Kernel: one Block per (n, i, to) of (THREADS,)
// Kernel: computes a single output value V[n, to_i, i] per block.
//
// Expected launch layout (as set up by launch_computeV):
//   grid  = (N, TRAINS_OUT, To)  -> blockIdx.x = sample n, blockIdx.y = train i,
//                                   blockIdx.z = evaluation-time index to_i
//   block = (THREADS,)           -> exactly one warp (enforced below); no shared memory.
//
// The block's threads stride over the Ti entries of row S[n, i, :] in steps of
// THREADS and accumulate phi((to - ti) * inv_widths[i] + neg_delays_inv_widths[i])
// for every ti < to. Rows of S must be sorted ascending, so each thread can stop at
// its first ti >= to. Partials are summed with warp shuffles; lane 0 writes
// acc * weights[i].
extern "C" __global__ void out_volt_fixed(
        const float *__restrict__ S,      // [N * TRAINS_OUT * Ti], sorted along ti
        const float *__restrict__ weights,      // [TRAINS_OUT]
        const float *__restrict__ neg_delays_inv_widths,      // [TRAINS_OUT]
        const float *__restrict__ inv_widths,      // [TRAINS_OUT]
        const float *__restrict__ T,      // [To] evaluation times
        float *V,             // [N * To * TRAINS_OUT]
        int Ti,
        int To
) {
    // The reduction below uses FULL_MASK and a log2(32)-step tree, i.e. it is only
    // correct when the block is exactly one full warp.
    static_assert(THREADS == 32, "out_volt_fixed's warp reduction requires THREADS == 32");

    const unsigned int n = blockIdx.x;
    const unsigned int i = blockIdx.y;
    const unsigned int to_i = blockIdx.z;
    const unsigned int tid = threadIdx.x;

    const float to = T[to_i];
    const float iwi = inv_widths[i];
    const float di = neg_delays_inv_widths[i];

    // 64-bit offsets: N * TRAINS_OUT * Ti (and N * To * TRAINS_OUT) can exceed 2^32
    // for large inputs, which would silently wrap in 32-bit arithmetic.
    const float *S_row = S + ((size_t) n * TRAINS_OUT + i) * (size_t) Ti;

    float acc = 0.0f;

    // Signed index so a non-positive Ti simply yields zero iterations.
    for (int ti_i = (int) tid; ti_i < Ti; ti_i += THREADS) {
        const float ti = S_row[ti_i];
        // The row is sorted, so every later element visited by this thread is also >= to.
        if (ti >= to) {
            break;
        }
        const float arg = __fmaf_rn(to - ti, iwi, di);
        acc += phi(arg);
        if (DEBUG && isnan(acc)) {
            printf("(%d, %d, %d: %d) evaluating (%f-%f) * %f + %f=%f=> %f\n", n, i, to_i, tid, to, ti, iwi, di,
                   arg, phi(arg));
        }
    }

    // Tree reduction across the warp; the *_sync intrinsic itself re-converges all
    // FULL_MASK lanes, so no separate __syncwarp() is needed. Lane 0 ends with the total.
    for (int offset = THREADS / 2; offset > 0; offset /= 2) {
        acc += __shfl_down_sync(FULL_MASK, acc, offset);
    }
    if (tid == 0) {
        const float out = acc * weights[i];
        V[((size_t) n * To + to_i) * TRAINS_OUT + i] = out;
        if (DEBUG && isnan(out)) {
            printf("(%d, %d, %d: %d), %f, %f\n", n, i, to_i, tid, acc, weights[i]);
        }
    }
}



// Host launcher
// Reports a failed CUDA runtime call on stderr; returns true iff `err` is cudaSuccess.
static bool launch_computeV_check(cudaError_t err, const char *what) {
    if (err != cudaSuccess) {
        fprintf(stderr, "CUDA error in %s: %s\n", what, cudaGetErrorString(err));
        return false;
    }
    return true;
}

// Host launcher: copies the inputs to the device, runs out_volt_fixed over a
// (N, TRAINS_OUT, To) grid of THREADS-thread blocks on the default stream,
// waits for it, and copies the result back into h_V.
//
//   h_S : [N * TRAINS_OUT * Ti] input times, sorted ascending per (n, train)
//   h_c : [TRAINS_OUT] passed to the kernel as `weights`
//   h_d : [TRAINS_OUT] passed to the kernel as `neg_delays_inv_widths`
//   h_w : [TRAINS_OUT] passed to the kernel as `inv_widths`
//   h_T : [To] evaluation times
//   h_V : [N * To * TRAINS_OUT] output, laid out as V[n][to][i]
//
// Every CUDA call is checked. On any failure the error is reported on stderr,
// all device memory is released, and h_V is left untouched.
__host__ void launch_computeV(
        const float *h_S,
        const float *h_c,
        const float *h_d,
        const float *h_w,
        const float *h_T,
        float *h_V,
        int N,
        int Ti,
        int To
) {
    if (N < 0 || Ti < 0 || To < 0) {
        fprintf(stderr, "launch_computeV: negative size (N=%d, Ti=%d, To=%d)\n", N, Ti, To);
        return;
    }
    // A zero grid dimension is an invalid launch configuration, and there is no output to produce.
    if (N == 0 || To == 0) {
        return;
    }

    // sizeof(float) comes first so the products are evaluated in size_t.
    const size_t szS = sizeof(float) * N * TRAINS_OUT * Ti;
    const size_t szC = sizeof(float) * TRAINS_OUT;
    const size_t szD = sizeof(float) * TRAINS_OUT;
    const size_t szW = sizeof(float) * TRAINS_OUT;
    const size_t szT = sizeof(float) * To;
    const size_t szV = sizeof(float) * N * TRAINS_OUT * To;

    float *d_S = nullptr, *d_c = nullptr, *d_d = nullptr, *d_w = nullptr, *d_T = nullptr, *d_V = nullptr;
    // cudaFree(nullptr) is a no-op, so this is safe after a partial allocation.
    auto release = [&]() {
        cudaFree(d_S);
        cudaFree(d_c);
        cudaFree(d_d);
        cudaFree(d_w);
        cudaFree(d_T);
        cudaFree(d_V);
    };

    // && short-circuits: the first failing call stops the remaining staging steps.
    const bool staged =
            launch_computeV_check(cudaMalloc(&d_S, szS), "cudaMalloc(S)") &&
            launch_computeV_check(cudaMalloc(&d_c, szC), "cudaMalloc(c)") &&
            launch_computeV_check(cudaMalloc(&d_d, szD), "cudaMalloc(d)") &&
            launch_computeV_check(cudaMalloc(&d_w, szW), "cudaMalloc(w)") &&
            launch_computeV_check(cudaMalloc(&d_T, szT), "cudaMalloc(T)") &&
            launch_computeV_check(cudaMalloc(&d_V, szV), "cudaMalloc(V)") &&
            launch_computeV_check(cudaMemcpy(d_S, h_S, szS, cudaMemcpyHostToDevice), "cudaMemcpy(S)") &&
            launch_computeV_check(cudaMemcpy(d_c, h_c, szC, cudaMemcpyHostToDevice), "cudaMemcpy(c)") &&
            launch_computeV_check(cudaMemcpy(d_d, h_d, szD, cudaMemcpyHostToDevice), "cudaMemcpy(d)") &&
            launch_computeV_check(cudaMemcpy(d_w, h_w, szW, cudaMemcpyHostToDevice), "cudaMemcpy(w)") &&
            launch_computeV_check(cudaMemcpy(d_T, h_T, szT, cudaMemcpyHostToDevice), "cudaMemcpy(T)");
    if (!staged) {
        release();
        return;
    }

    dim3 grid(N, TRAINS_OUT, To);
    dim3 block(THREADS);
    printf("Launching \n");
    out_volt_fixed<<<grid, block>>>(
            d_S, d_c, d_d, d_w, d_T, d_V, Ti, To
    );
    // Launch-configuration errors are reported by cudaGetLastError; faults raised
    // while the kernel runs surface at the synchronize.
    cudaError_t err = cudaGetLastError();
    if (err == cudaSuccess) {
        err = cudaDeviceSynchronize();
    }
    printf("Finished %d \n", (int) err);

    // Only copy back results from a kernel that actually completed.
    if (launch_computeV_check(err, "out_volt_fixed")) {
        launch_computeV_check(cudaMemcpy(h_V, d_V, szV, cudaMemcpyDeviceToHost), "cudaMemcpy(V)");
    }

    release();
}


// Example driver: fills synthetic, sorted input data, runs launch_computeV once,
// and prints "Finished". The computed h_V is not inspected.
__host__ int main() {
    // Dimensions
    const int N = N_SAMPLES;
    const int Ti = N_TI;
    const int To = N_TO;

    // Host arrays. Sizes are computed in size_t so enlarging N_SAMPLES / N_TI / N_TO
    // cannot overflow int.
    std::vector<float> h_S(static_cast<size_t>(N) * TRAINS_OUT * Ti);   // [N][TRAINS_OUT][Ti]
    std::vector<float> h_c(TRAINS_OUT);      // kernel `weights`
    std::vector<float> h_ndiw(TRAINS_OUT);   // kernel `neg_delays_inv_widths`
    std::vector<float> h_iw(TRAINS_OUT);     // kernel `inv_widths`
    std::vector<float> h_T(To);              // evaluation times
    std::vector<float> h_V(static_cast<size_t>(N) * TRAINS_OUT * To);   // output V[n][to][i]

    // Fill h_S with sorted example data: S[n,j,ti] = ti + j*10 + n*100
    for (int n = 0; n < N; ++n) {
        for (int j = 0; j < TRAINS_OUT; ++j) {
            for (int ti = 0; ti < Ti; ++ti) {
                h_S[(static_cast<size_t>(n) * TRAINS_OUT + j) * Ti + ti] = float(n * 100 + j * 10 + ti);
            }
        }
    }

    // Unit weight for every train: c[i] = 1
    for (int i = 0; i < TRAINS_OUT; ++i)
        h_c[i] = 1.0f;

    // neg_delays_inv_widths[i] = -i and inv_widths[i] = 1, so the kernel evaluates
    // phi((to - ti) * 1 - i) for train i.
    for (int i = 0; i < TRAINS_OUT; ++i) {
        h_ndiw[i] = -float(i);
        h_iw[i] = 1.0f;
    }

    // Evaluation times T[to] = to * 5.0
    for (int to = 0; to < To; ++to) {
        h_T[to] = float(to * 5);
    }

    launch_computeV(h_S.data(), h_c.data(), h_ndiw.data(), h_iw.data(), h_T.data(), h_V.data(), N, Ti, To);
    printf("Finished\n");

    return 0;
}