// Compile-time configuration. Each right-hand side is a placeholder token that is
// presumably replaced textually by host-side code before this source is compiled
// (NOTE(review): confirm against the loader). How the code below uses them:
//   KERNEL_EVAL - expression returned by phi(); the float parameter `x` is in scope.
//   TRAINS_IN   - number of input trains per sample (loop index j in in_volt_fixed).
//   TRAINS_OUT  - number of output trains (loop index i in in_volt_fixed).
//   DEBUG       - when truthy, in_volt_fixed prints diagnostics for NaN results.
#define KERNEL_EVAL __REPL_KERNEL_EVAL
#define TRAINS_IN __REPL_TRAINS_IN
#define TRAINS_OUT __REPL_TRAINS_OUT
#define DEBUG __REPL_DEBUG

#include <sm_60_atomic_functions.h>

// Response function summed by in_volt_fixed over every input time ti <= to.
// The caller passes x = (to - ti) * inv_widths[i] + neg_delays_inv_widths[i].
// The body is the injected KERNEL_EVAL expression, which can only refer to the
// argument by the name `x`; renaming the parameter would break that expression.
__inline__ __device__
float phi(float x) {
    return KERNEL_EVAL;
}

// Debug-only diagnostic: reports a NaN arising while folding the (j, i) contribution
// `wij * acc` into the running sum `prev` of output train i.
static __device__ void in_volt_report_nan(uint n, uint to_i, uint tid, int j, uint i,
                                          float wij, float acc, float prev) {
    printf("(%u, %u: %u) nan at j=%d, i=%u, wij=%f, acc=%f, S_buff=%f\n",
           n, to_i, tid, j, i, wij, acc, prev);
}

// For one sample n and one evaluation time to = T[to_i], computes every output train i:
//   V[n, to_i, i] = sum_j weights[j * TRAINS_OUT + i]
//                   * sum_{ti in S[n, j, :], ti <= to}
//                       phi((to - ti) * inv_widths[i] + neg_delays_inv_widths[i])
//
// Launch layout:
//   grid  = (N, To): blockIdx.x selects the sample n, blockIdx.y the evaluation time.
//   block = (any blockDim.x,): threads stride over the TRAINS_OUT outputs.
//   dynamic shared memory = (Ti + TRAINS_OUT) * sizeof(float):
//     [0, Ti)               staging copy of the current input row S[n, j, :]
//     [Ti, Ti + TRAINS_OUT) running per-output sums
// Preconditions: every row S[n, j, :] is sorted ascending (the inner loop stops at the
// first ti > to), and Ti >= 0.
extern "C" __global__ void in_volt_fixed(
        const float *__restrict__ S,                      // [N * TRAINS_IN * Ti], each (n, j) row sorted along ti
        const float *__restrict__ weights,                // [TRAINS_IN * TRAINS_OUT], read as weights[j * TRAINS_OUT + i]
        const float *__restrict__ neg_delays_inv_widths,  // [TRAINS_OUT]
        const float *__restrict__ inv_widths,             // [TRAINS_OUT]
        const float *__restrict__ T,                      // [To] evaluation times
        float *V,                                         // [N * To * TRAINS_OUT] output
        int Ti,
        int To
) {
    const uint n = blockIdx.x;
    const uint to_i = blockIdx.y;
    const uint tid = threadIdx.x;
    const float to = T[to_i];

    // 64-bit offsets: N * TRAINS_IN * Ti and N * To * TRAINS_OUT can exceed 2^32
    // elements, which would silently wrap in 32-bit unsigned arithmetic.
    const unsigned long long s_first_row = (unsigned long long)n * TRAINS_IN * Ti;
    const unsigned long long v_row = ((unsigned long long)n * To + to_i) * TRAINS_OUT;

    extern __shared__ float S_buffer[];     // Ti + TRAINS_OUT floats
    float *const ti_buf = S_buffer;         // [0, Ti): current input row
    float *const v_acc = S_buffer + Ti;     // [Ti, Ti + TRAINS_OUT): per-output sums

    // A thread only ever touches the accumulator slots i == tid (mod blockDim.x), so
    // zeroing them needs no barrier before this same thread uses them.
    for (uint i = tid; i < TRAINS_OUT; i += blockDim.x) {
        v_acc[i] = 0;
    }

    for (int j = 0; j < TRAINS_IN; ++j) {
        const unsigned long long s_row = s_first_row + (unsigned long long)j * Ti;
        // Cooperative load of S[n, j, :]: adjacent threads read adjacent elements.
        for (uint ti_i = tid; ti_i < Ti; ti_i += blockDim.x) {
            ti_buf[ti_i] = S[s_row + ti_i];
        }
        __syncthreads();  // the staged row is complete before anyone reads it

        for (uint i = tid; i < TRAINS_OUT; i += blockDim.x) {
            const float di = neg_delays_inv_widths[i];
            const float wij = weights[j * TRAINS_OUT + i];
            const float iwi = inv_widths[i];
            const float prev = v_acc[i];
            float acc = 0;
            if (DEBUG && isnan(__fmaf_rn(wij, acc, prev))) {
                in_volt_report_nan(n, to_i, tid, j, i, wij, acc, prev);
            }
            // `to` and the staged row are identical for every thread in the block, so
            // all threads running this loop leave it at the same index.
            for (uint ti_i = 0; ti_i < Ti; ++ti_i) {
                const float ti = ti_buf[ti_i];
                if (ti > to) {
                    break;  // row is sorted: no later ti can satisfy ti <= to
                }
                acc += phi(__fmaf_rn(to - ti, iwi, di));
            }
            const float updated = __fmaf_rn(wij, acc, prev);
            if (DEBUG && isnan(updated)) {
                in_volt_report_nan(n, to_i, tid, j, i, wij, acc, prev);
            }
            v_acc[i] = updated;
        }
        __syncthreads();  // all reads of ti_buf finish before the next row overwrites it
    }

    for (uint i = tid; i < TRAINS_OUT; i += blockDim.x) {
        V[v_row + i] = v_acc[i];
    }
}

/*
// Host launcher
__host__ void launch_computeV(
        const float *h_S,   // [N * TRAINS_IN * Ti], sorted per (n,j)
        const float *h_c,   // [TRAINS_OUT * TRAINS_IN]
        const float *h_d,   // [TRAINS_OUT]
        const float *h_w,   // [TRAINS_OUT]
        const float *h_T,   // [To]
        float *h_V,
        int N,
        int Ti,
        int To
) {
    size_t szS = sizeof(float) * N * TRAINS_IN * Ti;
    size_t szC = sizeof(float) * TRAINS_OUT * TRAINS_IN;
    size_t szD = sizeof(float) * TRAINS_OUT;
    size_t szW = sizeof(float) * TRAINS_OUT;
    size_t szT = sizeof(float) * To;
    size_t szV = sizeof(float) * N * TRAINS_OUT * To;

    float *d_S, *d_c, *d_d, *d_w, *d_T, *d_V;
    cudaMalloc(&d_S, szS);
    cudaMalloc(&d_c, szC);
    cudaMalloc(&d_d, szD);
    cudaMalloc(&d_w, szW);
    cudaMalloc(&d_T, szT);
    cudaMalloc(&d_V, szV);

    cudaMemcpy(d_S, h_S, szS, cudaMemcpyHostToDevice);
    cudaMemcpy(d_c, h_c, szC, cudaMemcpyHostToDevice);
    cudaMemcpy(d_d, h_d, szD, cudaMemcpyHostToDevice);
    cudaMemcpy(d_w, h_w, szW, cudaMemcpyHostToDevice);
    cudaMemcpy(d_T, h_T, szT, cudaMemcpyHostToDevice);

    dim3 grid(N, To);
    dim3 block(THREADS);
    printf("Launching \n");
    // Launch the kernel
    computeV_kernel<<<grid, block, (Ti + TRAINS_OUT) * sizeof(float)>>>(
            d_S, d_c, d_d, d_w, d_T, d_V, Ti, To
    );
    auto err = cudaDeviceSynchronize();
    printf("Finished %d \n", err);
    cudaMemcpy(h_V, d_V, szV, cudaMemcpyDeviceToHost);

    // Cleanup
    cudaFree(d_S);
    cudaFree(d_c);
    cudaFree(d_d);
    cudaFree(d_w);
    cudaFree(d_T);
    cudaFree(d_V);
}


__host__ int main() {
    // Dimensions
    const int N = N_SAMPLES;
    const int Ti = N_TI;
    const int To = N_TO;

    // Host arrays
    std::vector<float> h_S(N * TRAINS_IN * Ti);
    std::vector<float> h_c(TRAINS_OUT * TRAINS_IN);
    std::vector<float> h_ndiw(TRAINS_OUT);
    std::vector<float> h_iw(TRAINS_OUT);
    std::vector<float> h_T(To);
    std::vector<float> h_V(N * TRAINS_OUT * To);

    // Fill h_S with sorted example data: S[n,j,ti] = ti + j*10 + n*100
    for (int n = 0; n < N; ++n) {
        for (int j = 0; j < TRAINS_IN; ++j) {
            for (int ti = 0; ti < Ti; ++ti) {
                h_S[(n * TRAINS_IN + j) * Ti + ti] = float(n * 100 + j * 10 + ti);
            }
        }
    }

    // Simple weights: c[i,j] = 1 for every (i, j)
    for (int i = 0; i < TRAINS_OUT; ++i)
        for (int j = 0; j < TRAINS_IN; ++j)
            h_c[i * TRAINS_IN + j] = 1.0f;

    // Delays d[i] = i, widths w[i] = 1 (passed as -d[i]/w[i] and 1/w[i])
    for (int i = 0; i < TRAINS_OUT; ++i) {
        h_ndiw[i] = -float(i);
        h_iw[i] = 1.0f;
    }

    // Evaluation times T[to] = to * 5.0
    for (int to = 0; to < To; ++to) {
        h_T[to] = float(to * 5);
    }

    launch_computeV(h_S.data(), h_c.data(), h_ndiw.data(), h_iw.data(), h_T.data(), h_V.data(), N, Ti, To);

}*/