// Build-time configuration macros.
// The __REPL_* tokens are not defined anywhere in this file.
// NOTE(review): presumably the host replaces them textually before compiling this source — confirm.
//   KERNEL_EVAL: expression returned by phi(float x) (presumably written in terms of `x`).
//   TRAINS_IN:   number of input trains per sample; used as a loop bound and in the S / A layouts.
//   DEBUG:       when nonzero, thread 0 of every block with n == 0 emits device printf traces.
#define KERNEL_EVAL __REPL_KERNEL_EVAL
#define TRAINS_IN __REPL_TRAINS_IN
#define DEBUG __REPL_DEBUG

#include <sm_60_atomic_functions.h>
#include <cstdio>

// Basis function that fit_contrib_fixed evaluates once per contributing spike.
// Its body is the build-time expression KERNEL_EVAL (see the #define at the top of the file).
// NOTE(review): KERNEL_EVAL is expected to be an expression in `x` whose value converts to
// float — confirm against the code that substitutes it.
__inline__ __device__
float phi(float x) {
    return KERNEL_EVAL;
}

// Kernel: fills the per-input-train columns of the design matrix A.
//
// Launch layout (as implied by the indexing below):
//   grid  = (N, To): blockIdx.x selects sample n, blockIdx.y selects evaluation-time index to_i.
//   block = (any,):  threads stride over the TRAINS_IN input trains (j = tid, tid + blockDim.x, ...).
//   No shared memory is used.
//
// For every (n, to_i, j) it writes
//   A[n, to_i, j + 1] = sum over ti in S[n, j, :] with ti <= T[to_i] of
//                       phi((T[to_i] - ti) * iwi_arr[0] + di_arr[0])
// Column 0 of each A row is not written here.
// NOTE(review): presumably column 0 is filled by the caller — confirm.
//
// Preconditions:
//   - Each S[n, j, :] is sorted ascending; the inner loop stops at the first time > T[to_i].
//   - A holds at least N * To * (TRAINS_IN + 1) floats.
// All flat offsets are computed in 64 bits so large N * To * TRAINS_IN * Ti do not wrap.
extern "C" __global__ void fit_contrib_fixed(
        const float *__restrict__ S,       // [N * TRAINS_IN * Ti], sorted along ti
        const float *__restrict__ T,       // [To] evaluation times
        float *A,                          // Design Matrix: [N * To * (1 + TRAINS_IN)]
        const float *__restrict__ di_arr,  // di_arr[0]: additive offset inside phi's argument
        const float *__restrict__ iwi_arr, // iwi_arr[0]: scale applied to (to - ti)
        const int Ti,
        const int To
) {
    const unsigned int n = blockIdx.x;
    const unsigned int to_i = blockIdx.y;
    const unsigned int tid = threadIdx.x;

    // Guard against a grid taller than the number of evaluation times (also covers To <= 0).
    if (static_cast<int>(to_i) >= To) {
        return;
    }

    const float to = T[to_i];
    const float ndiw = di_arr[0];
    const float iwi = iwi_arr[0];

    if (DEBUG && tid == 0 && n == 0) {
        printf("n: %u, to: %u/%f\n", n, to_i, to);
        printf("tid: %u\n", tid);
        printf("d=%f, w=%f\n", -ndiw / iwi, 1.0f / iwi);
    }

    // A negative Ti would otherwise become a huge unsigned loop bound; treat it as empty.
    const int ti_len = Ti > 0 ? Ti : 0;
    const unsigned long long row_len = TRAINS_IN + 1ULL;
    // Start of the A row for (n, to_i): (n * To + to_i) * (TRAINS_IN + 1).
    const unsigned long long a_row = ((unsigned long long) n * To + to_i) * row_len;

    for (unsigned int j = tid; j < TRAINS_IN; j += blockDim.x) {
        // Spike times of input train j of sample n.
        const float *spikes = S + ((unsigned long long) n * TRAINS_IN + j) * ti_len;
        float acc = 0.0f;
        for (int ti_i = 0; ti_i < ti_len; ++ti_i) {
            const float ti = spikes[ti_i];
            if (ti > to) {
                break;  // S is sorted, so every later spike also lies after `to`.
            }
            acc += phi(__fmaf_rn(to - ti, iwi, ndiw));
            if (DEBUG && tid == 0 && n == 0) {
                printf("ti: %f, acc: %f\n", ti, acc);
            }
        }
        A[a_row + j + 1] = acc;
        if (DEBUG && tid == 0 && n == 0) {
            printf("flat: %llu, %llu, %u\n",
                   (unsigned long long) n * To * row_len,
                   (unsigned long long) to_i * row_len,
                   j + 1);
            printf("A[%u, %u, %u] = %f\n", n, to_i, j + 1, acc);
        }
    }
}

/*
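// Legacy host launcher (disabled): it launches computeV_kernel and uses TRAINS_OUT / THREADS,
// none of which are defined in this file, so it cannot launch fit_contrib_fixed as written.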
__host__ void launch_computeV(
        const float *h_S,   // [N * TRAINS_IN * Ti], sorted per (n,j)
        const float *h_c,   // [TRAINS_OUT * TRAINS_IN]
        const float *h_d,   // [TRAINS_OUT]
        const float *h_w,   // [TRAINS_OUT]
        const float *h_T,   // [To]
        float *h_V,
        int N,
        int Ti,
        int To
) {
    size_t szS = sizeof(float) * N * TRAINS_IN * Ti;
    size_t szC = sizeof(float) * TRAINS_OUT * TRAINS_IN;
    size_t szD = sizeof(float) * TRAINS_OUT;
    size_t szW = sizeof(float) * TRAINS_OUT;
    size_t szT = sizeof(float) * To;
    size_t szV = sizeof(float) * N * TRAINS_OUT * To;

    float *d_S, *d_c, *d_d, *d_w, *d_T, *d_V;
    cudaMalloc(&d_S, szS);
    cudaMalloc(&d_c, szC);
    cudaMalloc(&d_d, szD);
    cudaMalloc(&d_w, szW);
    cudaMalloc(&d_T, szT);
    cudaMalloc(&d_V, szV);

    cudaMemcpy(d_S, h_S, szS, cudaMemcpyHostToDevice);
    cudaMemcpy(d_c, h_c, szC, cudaMemcpyHostToDevice);
    cudaMemcpy(d_d, h_d, szD, cudaMemcpyHostToDevice);
    cudaMemcpy(d_w, h_w, szW, cudaMemcpyHostToDevice);
    cudaMemcpy(d_T, h_T, szT, cudaMemcpyHostToDevice);

    dim3 grid(N, To);
    dim3 block(THREADS);
    printf("Launching \n");
    // Launch the kernel
    computeV_kernel<<<grid, block, (Ti + TRAINS_OUT) * sizeof(float)>>>(
            d_S, d_c, d_d, d_w, d_T, d_V, Ti, To
    );
    auto err = cudaDeviceSynchronize();
    printf("Finished %d \n", err);
    cudaMemcpy(h_V, d_V, szV, cudaMemcpyDeviceToHost);

    // Cleanup
    cudaFree(d_S);
    cudaFree(d_c);
    cudaFree(d_d);
    cudaFree(d_w);
    cudaFree(d_T);
    cudaFree(d_V);
}


__host__ int main() {
    // Dimensions
    const int N = N_SAMPLES;
    const int Ti = N_TI;
    const int To = N_TO;

    // Host arrays
    std::vector<float> h_S(N * TRAINS_IN * Ti);
    std::vector<float> h_c(TRAINS_OUT * TRAINS_IN);
    std::vector<float> h_ndiw(TRAINS_OUT);
    std::vector<float> h_iw(TRAINS_OUT);
    std::vector<float> h_T(To);
    std::vector<float> h_V(N * TRAINS_OUT * To);

    // Fill h_S with sorted example data: S[n,j,ti] = ti + j*10 + n*100
    for (int n = 0; n < N; ++n) {
        for (int j = 0; j < TRAINS_IN; ++j) {
            for (int ti = 0; ti < Ti; ++ti) {
                h_S[(n * TRAINS_IN + j) * Ti + ti] = float(n * 100 + j * 10 + ti);
            }
        }
    }

    // Simple weights c[i,j] = 1
    for (int i = 0; i < TRAINS_OUT; ++i)
        for (int j = 0; j < TRAINS_IN; ++j)
            h_c[i * TRAINS_IN + j] = 1.0f;

    // Negated delays ndiw[i] = -i, inverse widths iw[i] = 1.0
    for (int i = 0; i < TRAINS_OUT; ++i) {
        h_ndiw[i] = -float(i);
        h_iw[i] = 1.0f;
    }

    // Evaluation times T[to] = to * 5.0
    for (int to = 0; to < To; ++to) {
        h_T[to] = float(to * 5);
    }

    launch_computeV(h_S.data(), h_c.data(), h_ndiw.data(), h_iw.data(), h_T.data(), h_V.data(), N, Ti, To);

}*/