// Compile-time configuration. The __REPL_* tokens appear to be placeholders that are
// textually substituted before this source is compiled (presumably by a host-side
// loader) — TODO confirm. As used by the kernel below:
//   TRAINS_IN           - number of input spike trains per sample (bound of the j loop
//                         and row count of S per sample)
//   TRAINS_OUT          - number of output trains per sample (row count of f and C per
//                         sample)
//   ACCUMULATE_ABSOLUTE - nonzero: accumulate the absolute value of each per-input-train
//                         sum; zero: accumulate the signed sum
//   DEBUG               - nonzero: the kernel prints the results of block (n=0, i=2)
#define TRAINS_IN __REPL_TRAINS_IN
#define TRAINS_OUT __REPL_TRAINS_OUT
#define ACCUMULATE_ABSOLUTE __REPL_ACCUMULATE_ABSOLUTE
#define DEBUG __REPL_DEBUG

#include <sm_60_atomic_functions.h>
#include <cstdint>
#include <cstdio>

// Kernel: for every sample n and output train i, accumulates, for each lag tau in
// [0, D_max),
//
//   C[n, i, tau] = sum_j g( sum_{s in S[n, j, :], s <= tau < s + H} f[n, i, tau - s] )
//
// where g is fabsf when ACCUMULATE_ABSOLUTE is nonzero and the identity otherwise.
//
// Launch layout (fixed by the indexing below):
//   gridDim.x  = N           -> blockIdx.x is the sample index n
//   gridDim.y  = TRAINS_OUT  -> blockIdx.y is the output-train index i
//   blockDim.x = any         -> threads cooperatively stride over Ti, H and D_max
//
// Dynamic shared memory (3rd launch argument), in bytes:
//   Ti * sizeof(uint16_t) rounded up to a multiple of 4, plus (H + D_max) * sizeof(float)
//
// Preconditions (not checked):
//   - every row S[n, j, :] is monotonically sorted. Ascending and descending order are
//     both accepted; the order is detected per row from its first and last entries.
//   - the dynamic shared-memory allocation is at least the size given above.
extern "C" __global__ void tau_corr_fixed(
        const uint16_t *__restrict__ S,   // [N * TRAINS_IN * Ti], each row sorted along ti; spike times in units of dt
        const float *__restrict__ f,      // [N * TRAINS_OUT * H]; entry h is added once per spike at lag h
        float *C,                         // [N * TRAINS_OUT * D_max] output
        int Ti,                           // spikes per input train (row length of S)
        int T,                            // not used by this kernel
        int H,                            // row length of f (number of lags a spike contributes to)
        int D_max                         // number of output lags (row length of C)
) {
    const uint n = blockIdx.x;   // sample index
    const uint i = blockIdx.y;   // output-train index
    const uint tid = threadIdx.x;

    // Negative sizes would otherwise become huge unsigned loop bounds; treat them as empty.
    const uint n_ti = Ti > 0 ? static_cast<uint>(Ti) : 0u;
    const uint n_h = H > 0 ? static_cast<uint>(H) : 0u;
    const uint n_tau = D_max > 0 ? static_cast<uint>(D_max) : 0u;

    // Shared layout: [spikes: Ti x uint16][pad to 4 bytes][f row: H x float][results: D_max x float]
    extern __shared__ unsigned char S_buffer[];
    auto *spikes_buffer = reinterpret_cast<uint16_t *>(S_buffer);
    const size_t bytes_u16 = static_cast<size_t>(n_ti) * sizeof(uint16_t);
    const size_t pad = (4 - (bytes_u16 & 3)) & 3;  // smallest pad to 4-byte alignment
    float *f_buffer = reinterpret_cast<float *>(S_buffer + bytes_u16 + pad);
    float *Res_buffer = f_buffer + n_h;

    // 64-bit base offsets: N * TRAINS_* * row_length can exceed 32 bits.
    const size_t f_row = (static_cast<size_t>(n) * (TRAINS_OUT) + i) * n_h;
    const size_t C_row = (static_cast<size_t>(n) * (TRAINS_OUT) + i) * n_tau;
    const size_t S_sample = static_cast<size_t>(n) * (TRAINS_IN) * n_ti;

    // Clear result accumulators. Each slot is only ever touched by the thread that clears it.
    for (uint tau_i = tid; tau_i < n_tau; tau_i += blockDim.x) {
        Res_buffer[tau_i] = 0.0f;
    }
    // Stage this (n, i) row of f. It is first read after the __syncthreads() in the j loop.
    for (uint h_i = tid; h_i < n_h; h_i += blockDim.x) {
        f_buffer[h_i] = f[f_row + h_i];
    }

    for (int j = 0; j < (TRAINS_IN); ++j) {
        const uint16_t *S_row = S + S_sample + static_cast<size_t>(j) * n_ti;
        for (uint ti_i = tid; ti_i < n_ti; ti_i += blockDim.x) {
            spikes_buffer[ti_i] = S_row[ti_i];
        }
        __syncthreads();  // spikes (and, on the first pass, f_buffer) are now visible block-wide

        // Detect the sort order so the early exit below is valid for either order.
        const bool ascending = n_ti == 0 || spikes_buffer[0] <= spikes_buffer[n_ti - 1];

        for (uint tau_i = tid; tau_i < n_tau; tau_i += blockDim.x) {
            float acc = 0.0f;
            for (uint ti_i = 0; ti_i < n_ti; ++ti_i) {
                const uint ti = spikes_buffer[ti_i];
                if (ti > tau_i) {
                    // Spike lies after this lag. In ascending order all remaining spikes do too.
                    if (ascending) break;
                    continue;
                }
                const uint lag = tau_i - ti;
                if (lag >= n_h) {
                    // Spike is too far in the past. In descending order all remaining ones are too.
                    if (!ascending) break;
                    continue;
                }
                acc += f_buffer[lag];
            }
            if (ACCUMULATE_ABSOLUTE) {
                Res_buffer[tau_i] += fabsf(acc);
            } else {
                Res_buffer[tau_i] += acc;
            }
        }
        __syncthreads();  // all threads are done reading spikes_buffer before it is overwritten
    }

    for (uint tau_i = tid; tau_i < n_tau; tau_i += blockDim.x) {
        if ((DEBUG) && n == 0 && i == 2) {
            printf("tau_i=%u, Res_buffer[tau_i]=%f\n", tau_i, Res_buffer[tau_i]);
        }
        C[C_row + tau_i] = Res_buffer[tau_i];
    }
}

/*
// Host launcher
__host__ void launch_computeV(
        const float *h_S,   // [N * TRAINS_IN * Ti], sorted per (n,j)
        const float *h_c,   // [TRAINS_OUT * TRAINS_IN]
        const float *h_d,   // [TRAINS_OUT]
        const float *h_w,   // [TRAINS_OUT]
        const float *h_T,   // [To]
        float *h_V,
        int N,
        int Ti,
        int To
) {
    size_t szS = sizeof(float) * N * TRAINS_IN * Ti;
    size_t szC = sizeof(float) * TRAINS_OUT * TRAINS_IN;
    size_t szD = sizeof(float) * TRAINS_OUT;
    size_t szW = sizeof(float) * TRAINS_OUT;
    size_t szT = sizeof(float) * To;
    size_t szV = sizeof(float) * N * TRAINS_OUT * To;

    float *d_S, *d_c, *d_d, *d_w, *d_T, *d_V;
    cudaMalloc(&d_S, szS);
    cudaMalloc(&d_c, szC);
    cudaMalloc(&d_d, szD);
    cudaMalloc(&d_w, szW);
    cudaMalloc(&d_T, szT);
    cudaMalloc(&d_V, szV);

    cudaMemcpy(d_S, h_S, szS, cudaMemcpyHostToDevice);
    cudaMemcpy(d_c, h_c, szC, cudaMemcpyHostToDevice);
    cudaMemcpy(d_d, h_d, szD, cudaMemcpyHostToDevice);
    cudaMemcpy(d_w, h_w, szW, cudaMemcpyHostToDevice);
    cudaMemcpy(d_T, h_T, szT, cudaMemcpyHostToDevice);

    dim3 grid(N, To);
    dim3 block(THREADS);
    printf("Launching \n");
    // Launch the kernel
    computeV_kernel<<<grid, block, (Ti + TRAINS_OUT) * sizeof(float)>>>(
            d_S, d_c, d_d, d_w, d_T, d_V, Ti, To
    );
    auto err = cudaDeviceSynchronize();
    printf("Finished %d \n", err);
    cudaMemcpy(h_V, d_V, szV, cudaMemcpyDeviceToHost);

    // Cleanup
    cudaFree(d_S);
    cudaFree(d_c);
    cudaFree(d_d);
    cudaFree(d_w);
    cudaFree(d_T);
    cudaFree(d_V);
}


__host__ int main() {
    // Dimensions
    const int N = N_SAMPLES;
    const int Ti = N_TI;
    const int To = N_TO;

    // Host arrays
    std::vector<float> h_S(N * TRAINS_IN * Ti);
    std::vector<float> h_c(TRAINS_OUT * TRAINS_IN);
    std::vector<float> h_ndiw(TRAINS_OUT);
    std::vector<float> h_iw(TRAINS_OUT);
    std::vector<float> h_T(To);
    std::vector<float> h_V(N * TRAINS_OUT * To);

    // Fill h_S with sorted example data: S[n,j,ti] = ti + j*10 + n*100
    for (int n = 0; n < N; ++n) {
        for (int j = 0; j < TRAINS_IN; ++j) {
            for (int ti = 0; ti < Ti; ++ti) {
                h_S[(n * TRAINS_IN + j) * Ti + ti] = float(n * 100 + j * 10 + ti);
            }
        }
    }

    // Simple weights c[i,j] = 1 (uniform)
    for (int i = 0; i < TRAINS_OUT; ++i)
        for (int j = 0; j < TRAINS_IN; ++j)
            h_c[i * TRAINS_IN + j] = 1.0f;

    // h_ndiw[i] = -i and h_iw[i] = 1.0 for every output train i
    for (int i = 0; i < TRAINS_OUT; ++i) {
        h_ndiw[i] = -float(i);
        h_iw[i] = 1.0f;
    }

    // Evaluation times T[to] = to * 5.0
    for (int to = 0; to < To; ++to) {
        h_T[to] = float(to * 5);
    }

    launch_computeV(h_S.data(), h_c.data(), h_ndiw.data(), h_iw.data(), h_T.data(), h_V.data(), N, Ti, To);

}*/