

#include "kittens.cuh"
#include "pyutils/pyutils.cuh"
#include <tuple>
#include <cooperative_groups.h>
namespace cg = cooperative_groups;

// When built with TORCH_COMPILE, also define TK_COMPILE_CYLON_LINEAR (referenced only by the
// commented-out torch binding section at the bottom of this file).
#ifdef TORCH_COMPILE
#define TK_COMPILE_CYLON_LINEAR
#endif

// CTA shape: 8 warps. The kernel treats them as 2 warpgroups (warpid/4), each of which owns
// one half of the feature map.
#define NUM_WORKERS (8)
#define NUM_THREADS (NUM_WORKERS*kittens::WARP_THREADS)
#define NUM_WARPGROUPS (NUM_WORKERS/kittens::WARPGROUP_WARPS) // not referenced in this file's visible code

using namespace kittens;

// Column-wise accumulation: dst[c] += sum over rows r of src[r][c].
//
// Intended to be called by a warpgroup. Thread (threadIdx.x % 128) handles column c of that
// index, and threads whose index is >= ST::cols do nothing. The elements of `src` are read as
// bf16 (__bfloat162float), and the sum is carried in a float. This function performs no
// synchronization itself.
template<ducks::sv::all SV, ducks::st::all ST>
__device__ inline void cumulative_add(SV &dst, const ST &src) {
    static_assert(ST::cols <= 128);
    static_assert(ST::cols == SV::length);

    const int col = threadIdx.x % 128;
    if (col >= ST::cols) return; // this thread owns no column

    float running = dst[col];
    int row = 0;
    while (row < ST::rows) {
        running += __bfloat162float(src[{row, col}]);
        ++row;
    }
    dst[col] = running;
}

// In-place, row-wise base-2 softmax of a register tile:
//   tile[r][c] <- 2^(tile[r][c] - max_c tile[r]) / sum_c 2^(tile[r][c] - max_c tile[r])
// Call sites in this file pre-scale the tile by +/-log2(e) (1.44269504089f), which turns this
// into a natural-base softmax of +/- the original values.
template<ducks::rt::all RT>
__device__ inline void softmax_featuremap_inplace(RT &tile) {
    col_vec<RT> row_peak;
    col_vec<RT> row_total;

    // Shift each row by its maximum so exp2 only ever sees values <= 0 (no overflow).
    warp::row_max(row_peak, tile);
    warp::sub_row(tile, tile, row_peak);
    warp::exp2(tile, tile);

    // Normalize every row to sum to one.
    warp::row_sum(row_total, tile);
    warp::div_row(tile, tile, row_total);
}

// Compile-time problem dimensions.
#define CHUNK_SIZE 64   // sequence rows processed per loop iteration of the kernel
#define ATTN_D 128      // width of the V and O tiles
#define ATTN_F 128      // width of the Q and K tiles; also the total featurized width (two halves)
#define HALF_ATTN_F 64  // featurized width produced by one half of the feature map

// Kernel arguments for cylon_linear_attention_smd together with its launch configuration
// (grid/block/dynamic shared memory), presumably consumed by kittens::py::bind_kernel.
struct cylon_globals {
    static_assert(2 * HALF_ATTN_F == ATTN_F, "the two feature-map halves must tile ATTN_F");

    // Tile shapes; each global layout below is declared around one of them.
    using q_tile        = st_bf<CHUNK_SIZE, ATTN_F>;
    using k_tile        = st_bf<CHUNK_SIZE, ATTN_F>;
    using v_tile        = st_bf<CHUNK_SIZE, ATTN_D>;
    using o_tile        = st_bf<CHUNK_SIZE, ATTN_D>;
    using kv_state_tile = st_fl<ATTN_F, ATTN_D>;      // featurized-K^T V: one row per feature
    using k_state_vec   = sv_fl<ATTN_F>;              // running sum of featurized K, one entry per feature
    using qk_map_tile   = st_bf<ATTN_F, HALF_ATTN_F>; // projects a Q/K row onto one feature half

    // Global layouts. The kernel indexes q/k/v/o/kv_state/k_state as {batch, head, tile, 0}
    // and qmap/kmap as {feature map, head, 0, 0}.
    using q_gl = gl<bf16,  -1, -1, -1, -1, q_tile>;
    using k_gl = gl<bf16,  -1, -1, -1, -1, k_tile>;
    using v_gl = gl<bf16,  -1, -1, -1, -1, v_tile>;
    using qmap_gl = gl<bf16,  -1, -1, -1, -1, qk_map_tile>;
    using kmap_gl = gl<bf16,  -1, -1, -1, -1, qk_map_tile>;
    using k_state_gl = gl<float, -1, -1, -1, -1, k_state_vec>;
    using kv_state_gl = gl<float, -1, -1, -1, -1, kv_state_tile>;
    using o_gl = gl<bf16,  -1, -1, -1, -1, o_tile>;

    // Kernel arguments.
    q_gl q;
    k_gl k;
    v_gl v;
    qmap_gl qmap;
    kmap_gl kmap;
    k_state_gl k_state;
    kv_state_gl kv_state;
    o_gl o;
    int N; // sequence length; only the first (N / CHUNK_SIZE) * CHUNK_SIZE rows are processed

    // Dynamic shared memory in bytes. The kernel's tiles need about 192 KiB plus a few small
    // vectors and allocator alignment padding. Presumably this value was chosen to stay under
    // the per-block opt-in limit of the target GPU.
    int dynamic_shared_memory() { return 227000; }
    // One CTA per (head, batch, feature map): x = heads (q.depth()), y = batch, z = feature maps.
    dim3 grid()  { return dim3(q.depth(), q.batch(), qmap.batch()); }
    // Must agree with the kernel's __launch_bounds__(NUM_THREADS, 1).
    dim3 block() { return dim3(NUM_THREADS); }
};

// Cylon linear-attention forward kernel.
//
// Launch: grid = (HEADS, BATCH, NUM_FEATURE_MAPS), so blockIdx.x = head, blockIdx.y = batch and
// blockIdx.z = feature map. block = NUM_THREADS (8 warps = 2 warpgroups). Dynamic shared
// memory = cylon_globals::dynamic_shared_memory(). Relies on TMA and warpgroup MMA, so it needs
// a GPU that supports them (Hopper-class).
//
// The sequence is walked in CHUNK_SIZE-row chunks, with Q/K/V double-buffered through TMA.
// Each warpgroup owns one half of the featurization: warpgroup 0 uses softmax(+x*map) and
// warpgroup 1 uses softmax(-x*map). For every chunk the output is
//   (causal intra-chunk term + phi(Q) KV_prev) / (intra-chunk row sums + phi(Q) . sum phi(K_prev)),
// with numerator and denominator summed over both halves. KV_prev and sum phi(K_prev) cover
// all earlier chunks.
//
// g.o, g.kv_state and g.k_state are written with TMA reduce-add. The NUM_FEATURE_MAPS CTAs of
// one (batch, head) are therefore summed into whatever those buffers already hold (callers
// presumably zero them first).
//
// Precondition: g.N should be a multiple of CHUNK_SIZE. The trailing g.N % CHUNK_SIZE rows are
// not processed.
__global__ __launch_bounds__(NUM_THREADS, 1)
void cylon_linear_attention_smd (
    const __grid_constant__ cylon_globals g
)  {
    extern __shared__ int __shm[]; // dynamic shared memory backing every tile below
    tma_swizzle_allocator al((int*)&__shm[0]);

    const int batch = blockIdx.y;
    const int head  = blockIdx.x;
    const int feature_map_id = blockIdx.z;

    // Reuse the host-side aliases so the shared tiles always match the TMA descriptors in g.
    using q_tile        = cylon_globals::q_tile;
    using k_tile        = cylon_globals::k_tile;
    using v_tile        = cylon_globals::v_tile;
    using o_tile        = cylon_globals::o_tile;
    using kv_state_tile = cylon_globals::kv_state_tile;
    using k_state_vec   = cylon_globals::k_state_vec;
    using qk_map_tile   = cylon_globals::qk_map_tile;

    // ---- shared memory (allocation order defines the layout relied on at the end) ----
    q_tile (&q_smem)[2] = al.allocate<q_tile, 2>(); // 32k, double-buffered (tic/toc)
    k_tile (&k_smem)[2] = al.allocate<k_tile, 2>(); // 32k, double-buffered (tic/toc)
    v_tile (&v_smem)[2] = al.allocate<v_tile, 2>(); // 32k, double-buffered (tic/toc)
    o_tile (&o_smem)    = al.allocate<o_tile>   (); // 16k, warpgroup exchange + output staging
    qk_map_tile (&qf_map) = al.allocate<qk_map_tile>(); // 16k, query feature-map weights
    qk_map_tile (&kf_map) = al.allocate<qk_map_tile>(); // 16k, key feature-map weights

    // Per-warpgroup state and scratch, indexed by warpgroupid.
    // kv_smem holds a bf16 copy of local_kv (HALF_ATTN_F x ATTN_D); its declared shape matches
    // because CHUNK_SIZE == HALF_ATTN_F and ATTN_F == ATTN_D here.
    st_bf<CHUNK_SIZE, ATTN_F> (&kv_smem)[2]        = al.allocate<st_bf<CHUNK_SIZE, ATTN_F>, 2>(); // 32k
    st_bf<CHUNK_SIZE, 4*16>   (&k_scratch_smem)[2] = al.allocate<st_bf<CHUNK_SIZE, 4*16>, 2>();   // 16k, featurized K of the current chunk
    row_vec<st_fl<CHUNK_SIZE,4*16>> (&cumsum_k_smem)[2] = al.allocate<row_vec<st_fl<CHUNK_SIZE,4*16>>, 2>(); // running column sums of featurized K
    col_vec<st_fl<CHUNK_SIZE,4*16>> (&norm_exchange)[1] = al.allocate<col_vec<st_fl<CHUNK_SIZE,4*16>>, 1>(); // warpgroup 1 -> 0 normalizer hand-off

    const int warpid = kittens::warpid();
    const int warpgroupid = warpid/4;
    const int warpgroup_warpid = warpgroup::warpid();
    const int blocks = g.N / (q_tile::rows); // number of full CHUNK_SIZE-row chunks

    int tic = 0, toc = 1;

    __shared__ semaphore qkv_semaphore;
    if (warpid == 0) {
        warp::init_semaphore(qkv_semaphore, 0, 1);
        // Skip the prologue loads when there is no full chunk: nothing would ever wait on them,
        // and the state write-out below reuses q_smem/k_smem as scratch.
        if (blocks > 0) {
            warp::tma::expect_bytes(qkv_semaphore,
                size_bytes<typeof(q_smem[0])> +
                size_bytes<typeof(k_smem[0])> +
                size_bytes<typeof(v_smem[0])> +
                // the feature maps arrive on the same semaphore as the first chunk
                size_bytes<typeof(qf_map)> +
                size_bytes<typeof(kf_map)>
            );
            warp::tma::load_async(qf_map, g.qmap, {feature_map_id, head, 0, 0}, qkv_semaphore);
            warp::tma::load_async(kf_map, g.kmap, {feature_map_id, head, 0, 0}, qkv_semaphore);
            warp::tma::load_async(q_smem[tic], g.q, {batch, head, 0, 0}, qkv_semaphore);
            warp::tma::load_async(k_smem[tic], g.k, {batch, head, 0, 0}, qkv_semaphore);
            warp::tma::load_async(v_smem[tic], g.v, {batch, head, 0, 0}, qkv_semaphore);
        }
    }

    // Running KV state over all previous chunks. Each warpgroup holds its own feature half
    // (features x ATTN_D), and each warp holds 16 of those feature rows.
    rt_fl<1*16, 8*16> local_kv; // 64 registers
    // TODO let's also define a persistent register tile for k_cumsum?

    warp::zero(local_kv);
    warpgroup::zero(cumsum_k_smem[warpgroupid]);

    __syncthreads();

    for (int block = 0; block < blocks; block++, tic^=1, toc^=1) {

        wait(qkv_semaphore, tic); // chunk `block` (and on the first pass the feature maps) has landed
        // Warp 0 reduce-stored the previous chunk's output from o_smem. Wait for that TMA store
        // to finish before the barrier, since o_smem is overwritten later in this iteration.
        if (warpid == 0) tma::store_async_wait();
        __syncthreads();

        // Prefetch chunk block+1 into the buffers that were consumed last iteration.
        if (warpid == 0 && block < blocks-1) {
            warp::tma::expect_bytes(qkv_semaphore,
                size_bytes<typeof(q_smem[0])> +
                size_bytes<typeof(k_smem[0])> +
                size_bytes<typeof(v_smem[0])>
            );
            warp::tma::load_async(q_smem[toc], g.q, {batch, head, block+1, 0}, qkv_semaphore);
            warp::tma::load_async(k_smem[toc], g.k, {batch, head, block+1, 0}, qkv_semaphore);
            warp::tma::load_async(v_smem[toc], g.v, {batch, head, block+1, 0}, qkv_semaphore);
        }
        __syncthreads();

        // Per-warp results for this chunk. Each warpgroup covers all CHUNK_SIZE rows for its half.
        rt_fl<1*16, 8*16> linear_o; // 64 registers
        rt_fl<1*16, 4*16>::col_vec linear_norm_vec; // 2 registers
        warp::zero(linear_norm_vec);

        // ******* linear attn ******** //

        // (1) FEATURIZE K: base-2 softmax of (+/- log2(e)) * K * kf_map for this warpgroup's half.
        rt_fl<1*16, 4*16> linear_k; // 32 registers
        warpgroup::mm_AB(linear_k, k_smem[tic], kf_map);
        warpgroup::mma_async_wait();
        if(warpgroupid) warp::mul(linear_k, linear_k, -1.44269504089f);
        else            warp::mul(linear_k, linear_k,  1.44269504089f);
        softmax_featuremap_inplace(linear_k);
        warpgroup::store(k_scratch_smem[warpgroupid], linear_k); // read again by steps (3) and (6)

        // (2) FEATURIZE Q the same way, then convert to bf16 for the matmuls.
        rt_fl<1*16, 4*16> linear_q; // 32 registers
        rt_bf<1*16, 4*16> linear_q_bf; // 16 registers
        warpgroup::mm_AB(linear_q, q_smem[tic], qf_map);
        warpgroup::mma_async_wait();
        if(warpgroupid) warp::mul(linear_q, linear_q, -1.44269504089f);
        else            warp::mul(linear_q, linear_q,  1.44269504089f);
        softmax_featuremap_inplace(linear_q);
        warp::copy(linear_q_bf, linear_q);

        // (3) CAUSAL INTRA-CHUNK ("QUADRATIC") TERM
        rt_fl<1*16, 4*16> scores;
        rt_bf<1*16, 4*16> scores_bf;
        rt_fl<1*16, 8*16> quad_o;
        rt_fl<1*16, 4*16>::col_vec quad_norm_vec;

        warpgroup::sync(warpgroup::groupid()); // k_scratch_smem[warpgroupid] is fully written
        warpgroup::mm_ABt(scores, linear_q_bf, k_scratch_smem[warpgroupid]); // this feature half only
        warpgroup::mma_async_wait();

        // Keep only c <= row: warp w of the warpgroup holds chunk rows [16w, 16w+16).
        warp::apply(scores, scores, [warpgroup_warpid]__device__(int r, int c, float val) {
            return c <= (warpgroup_warpid*16 + r) ? val : 0.0f;
        });

        warp::row_sum(quad_norm_vec, scores); // intra-chunk part of the normalizer
        warp::copy(scores_bf, scores);
        warpgroup::mm_AB(quad_o, scores_bf, v_smem[tic]);
        warpgroup::mma_async_wait();

        // (4) INTER-CHUNK ("LINEAR") TERM from the state of all previous chunks
        if (block == 0) {
            warp::zero(linear_o); // no previous chunks; linear_norm_vec is already zero
        } else {
            // Stage local_kv in shared memory (bf16) so it can feed the warpgroup matmul.
            warpgroup::store(kv_smem[warpgroupid], local_kv);
            __syncthreads(); // uniform: `block` is the same in every thread (could be a warpgroup sync)
            warpgroup::mm_AB(linear_o, linear_q_bf, kv_smem[warpgroupid]);
            warpgroup::mma_async_wait();

            // Normalizer: phi(q_row) . sum(phi(K_prev)). Broadcast the running K sum into every
            // row of a tile so a 16x16 MMA yields the dot product in each column.
            row_vec<rt_bf<1*16,4*16>> cumsum_k_reg;
            warp::load(cumsum_k_reg, cumsum_k_smem[warpgroupid]);
            rt_bf<1*16,4*16> cumsum_k_reg_tile;
            warp::broadcast_col(cumsum_k_reg_tile, cumsum_k_reg);
            rt_fl<1*16,1*16> norm_tile;
            warp::zero(norm_tile);
            warp::mma_ABt(norm_tile, linear_q_bf, cumsum_k_reg_tile, norm_tile);
            warp::row_max(linear_norm_vec, norm_tile); // all columns are equal; any reduction works
        }

        // (5) COMBINE BOTH FEATURE HALVES AND EMIT THE CHUNK
        // Warpgroup 1 publishes its partial output and normalizer through shared memory.
        // Warpgroup 0 adds them to its own, normalizes, and writes the final chunk to o_smem.
        // Every __syncthreads() sits outside the warpgroup-dependent branches, so all threads
        // reach the same barriers.
        if (warpgroupid == 1) {
            warp::add(linear_o, linear_o, quad_o);
            warp::add(linear_norm_vec, linear_norm_vec, quad_norm_vec);
            warpgroup::store(o_smem, linear_o);
            warpgroup::store(norm_exchange[0], linear_norm_vec);
        } else {
            warp::add(quad_norm_vec, quad_norm_vec, linear_norm_vec);
            warp::add(quad_o, quad_o, linear_o);
        }
        __syncthreads(); // warpgroup 1's partials are visible

        if (warpgroupid == 0) {
            warpgroup::load(linear_o, o_smem);
            warpgroup::load(linear_norm_vec, norm_exchange[0]);
            warp::add(quad_o, quad_o, linear_o);
            warp::add(quad_norm_vec, quad_norm_vec, linear_norm_vec);
            warp::div_row(quad_o, quad_o, quad_norm_vec);
            warpgroup::store(o_smem, quad_o); // final normalized chunk
        }
        __syncthreads(); // normalized chunk is complete in o_smem

        // Reduce-add the chunk into g.o once per CTA (summing across feature maps). The wait at
        // the top of the next iteration keeps o_smem intact until this store has finished.
        if (warpid == 0) {
            warp::tma::store_add_async(g.o, o_smem, {batch, head, block, 0});
            warp::tma::store_commit_group();
        }

        // (6) UPDATE THE RUNNING STATE WITH THIS CHUNK (used from the next iteration on)
        cumulative_add(cumsum_k_smem[warpgroupid], k_scratch_smem[warpgroupid]);
        warpgroup::mma_AtB(local_kv, k_scratch_smem[warpgroupid], v_smem[tic]);
        warpgroup::mma_async_wait();
    }
    tma::store_async_wait(); // drain the last output reduce-store

    // ---- write out the recurrent state ----
    // cumsum_k_smem[0] and [1] are one allocated array, so view them as a single ATTN_F-long
    // vector and store both halves with one TMA call.
    k_state_vec (&k_state_smem) = *reinterpret_cast<k_state_vec*>(&cumsum_k_smem[0].data[0]);
    // Stage the float KV state starting at q_smem (it extends into the buffers that follow).
    // Those buffers are no longer needed and no TMA loads into them are outstanding.
    kv_state_tile (&kv_state_smem) = reinterpret_cast<kv_state_tile&>(q_smem[0].data[0]);
    group<8>::store(kv_state_smem, local_kv); // all 8 warps store their own rows
    __syncthreads();
    if(warpid == 0){
        warp::tma::store_add_async(g.kv_state, kv_state_smem, {batch, head, 0, 0});
        warp::tma::store_add_async(g.k_state, k_state_smem, {batch, head, 0, 0});
        warp::tma::store_commit_group();
    }
    __syncthreads();
    tma::store_async_wait();
}



// Python entry point: module `cylon_linear` exposing the kernel as `cylon`. It is bound with the
// cylon_globals members listed below; their order presumably defines the Python argument order.
PYBIND11_MODULE(cylon_linear, m) {
    using G = cylon_globals;
    kittens::py::bind_kernel<cylon_linear_attention_smd>(m, "cylon",
        &G::q, &G::k, &G::v,
        &G::qmap, &G::kmap,
        &G::k_state, &G::kv_state,
        &G::o,
        &G::N
    );
    m.doc() = "cylon_linear python module";
}



// cylon_globals cylon_init(
//     bf16 *d_q, bf16 *d_k, bf16 *d_v, bf16 *d_o,
//     bf16 *d_qmap, bf16 *d_kmap,
//     float *d_k_state, float *d_kv_state,
//     uint32_t ATTN_B, uint32_t ATTN_H, uint32_t ATTN_N,
//     uint32_t NUM_FEATURE_MAPS
// ) {
//     // global pointers. 
//     using q_tile = st_bf<CHUNK_SIZE, ATTN_F>;
//     using k_tile = st_bf<CHUNK_SIZE, ATTN_F>;
//     using v_tile = st_bf<CHUNK_SIZE, ATTN_D>;
//     using o_tile = st_bf<CHUNK_SIZE, ATTN_D>;
//     using kv_state_tile = st_fl<ATTN_F, ATTN_D>;
//     using k_state_vec = sv_fl<ATTN_D>;
//     using qk_map_tile = st_bf<ATTN_D, HALF_ATTN_F>;
    
//     using q_global = gl<bf16, -1, -1, -1, -1, q_tile>;
//     using k_global = gl<bf16, -1, -1, -1, -1, k_tile>;
//     using v_global = gl<bf16, -1, -1, -1, -1, v_tile>;
//     using o_global = gl<bf16, -1, -1, -1, -1, o_tile>;
//     using kv_state_global = gl<float, -1, -1, -1, -1, kv_state_tile>;
//     using k_state_global = gl<float, -1, -1, -1, -1, k_state_vec>;
//     using qmap_global = gl<bf16, -1, -1, -1, -1, qk_map_tile>;
//     using kmap_global = gl<bf16, -1, -1, -1, -1, qk_map_tile>;
    
//     using globals = cylon_globals;
//     q_global q_arg{d_q, ATTN_B, ATTN_H, ATTN_N, ATTN_F};
//     k_global k_arg{d_k, ATTN_B, ATTN_H, ATTN_N, ATTN_F};
//     v_global v_arg{d_v, ATTN_B, ATTN_H, ATTN_N, ATTN_D};
//     o_global o_arg{d_o, ATTN_B, ATTN_H, ATTN_N, ATTN_D};
//     qmap_global qmap_arg{d_qmap, NUM_FEATURE_MAPS, ATTN_H, ATTN_F, HALF_ATTN_F};
//     kmap_global kmap_arg{d_kmap, NUM_FEATURE_MAPS, ATTN_H, ATTN_F, HALF_ATTN_F};
//     kv_state_global kv_state_arg{d_kv_state, ATTN_B, ATTN_H, ATTN_F, ATTN_D};
//     k_state_global k_state_arg{d_k_state, ATTN_B, ATTN_H, 1, ATTN_D};

//     // Initialize output and state tensors to zero
//     #ifndef TORCH_COMPILE
//     cudaMemset(d_o, 0, ATTN_B * ATTN_H * ATTN_N * ATTN_D * sizeof(bf16));
//     cudaMemset(d_kv_state, 0, ATTN_B * ATTN_H * ATTN_F * ATTN_D * sizeof(float));
//     cudaMemset(d_k_state, 0, ATTN_B * ATTN_H * ATTN_D * sizeof(float));
//     #endif

//     globals g{
//         q_arg, k_arg, v_arg, 
//         qmap_arg, kmap_arg,
//         k_state_arg, kv_state_arg,
//         o_arg
//     };
//     return g;
// }

// #ifdef TK_COMPILE_CYLON_LINEAR
// #include "pyutils/torch_helpers.cuh"
// #include <iostream>
// void dispatch_cylon( 
//     bf16 *d_q, bf16 *d_k, bf16 *d_v, bf16 *d_o,
//     bf16 *d_qmap, bf16 *d_kmap,
//     float *d_k_state, float *d_kv_state,
//     int ATTN_B, int ATTN_H, int ATTN_N,
//     int NUM_FEATURE_MAPS
// ){
//     cylon_globals g = cylon_init(
//         d_q, d_k, d_v, d_o,
//         d_qmap, d_kmap,
//         d_k_state, d_kv_state,
//         ATTN_B, ATTN_H, ATTN_N,
//         NUM_FEATURE_MAPS
//     );

//     // launch
//     unsigned long mem_size = kittens::MAX_SHARED_MEMORY;
//     cudaFuncSetAttribute(
//         cylon_linear_attention_smd,
//         cudaFuncAttributeMaxDynamicSharedMemorySize,
//         mem_size
//     );
//     dim3 grid(ATTN_H, ATTN_B, NUM_FEATURE_MAPS);
//     cylon_linear_attention_smd<<<grid,NUM_THREADS,mem_size>>>(g, ATTN_N);
//     CHECK_CUDA_ERROR(cudaGetLastError());
// }

// std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> cylon_linear(
//     const torch::Tensor q, 
//     const torch::Tensor k,
//     const torch::Tensor v,
//     const torch::Tensor qmap,
//     const torch::Tensor kmap
// ) {
//     CHECK_INPUT(q);
//     CHECK_INPUT(k);
//     CHECK_INPUT(v);
//     CHECK_INPUT(qmap);
//     CHECK_INPUT(kmap);

//     int B = q.size(0);
//     int H = q.size(1);
//     int N = q.size(2);
//     int DQ = q.size(3);
//     int DV = v.size(3);

//     int K = qmap.size(0);
//     int FD = qmap.size(3);

//     // checks
//     TORCH_CHECK(k.size(0) == B, "k batch?");
//     TORCH_CHECK(k.size(1) == H, "k heads?");
//     TORCH_CHECK(k.size(2) == N, "k length?");

//     TORCH_CHECK(v.size(0) == B, "v batch?");
//     TORCH_CHECK(v.size(1) == H, "v heads?");
//     TORCH_CHECK(v.size(2) == N, "v length?");

//     TORCH_CHECK(qmap.size(1) == H, "qmap heads?");
//     TORCH_CHECK(qmap.size(2) == DQ, "qmap in dim?");
//     TORCH_CHECK(qmap.size(3) == FD, "qmap feature dim?");

//     TORCH_CHECK(qmap.size(0) == K, "qmap k?");
//     TORCH_CHECK(kmap.size(1) == H, "kmap heads?");
//     TORCH_CHECK(kmap.size(2) == DQ, "kmap in dim?");
//     TORCH_CHECK(kmap.size(3) == FD, "kmap feature dim?");


//     // allocate outputs
//     torch::Tensor out = torch::zeros({B, H, N, DV}, v.options());
//     torch::Tensor kv_state = torch::zeros({B, H, FD*2, DV}, torch::dtype(torch::kFloat32).device(v.device()));
//     torch::Tensor k_state = torch::zeros({B, H, 1, FD*2}, torch::dtype(torch::kFloat32).device(v.device()));

//     // convert to bf16 
//     c10::BFloat16 *q_bf16 = q.data_ptr<c10::BFloat16>();
//     c10::BFloat16 *k_bf16 = k.data_ptr<c10::BFloat16>();
//     c10::BFloat16 *v_bf16 = v.data_ptr<c10::BFloat16>();
//     c10::BFloat16 *qmap_bf16 = qmap.data_ptr<c10::BFloat16>();
//     c10::BFloat16 *kmap_bf16 = kmap.data_ptr<c10::BFloat16>();
    
//     bf16 *d_q = reinterpret_cast<bf16*>(q_bf16);
//     bf16 *d_k = reinterpret_cast<bf16*>(k_bf16);
//     bf16 *d_v = reinterpret_cast<bf16*>(v_bf16);
//     bf16 *d_qmap = reinterpret_cast<bf16*>(qmap_bf16);
//     bf16 *d_kmap = reinterpret_cast<bf16*>(kmap_bf16);
//     bf16 *d_o = reinterpret_cast<bf16*>(out.data_ptr<c10::BFloat16>());
//     float *d_k_state = k_state.data_ptr<float>();
//     float *d_kv_state = kv_state.data_ptr<float>();

//     dispatch_cylon(
//         d_q, d_k, d_v, d_o, 
//         d_qmap, d_kmap,
//         d_k_state, d_kv_state,
//         B, H, N, K
//     );

//     CHECK_CUDA_ERROR(cudaGetLastError());
//     return std::make_tuple(out, kv_state, k_state);
// }
// #else
// #include "harness.impl"
// #endif