

#include "kittens.cuh"
#include <tuple>
#include <cooperative_groups.h>
namespace cg = cooperative_groups;

#ifdef TORCH_COMPILE
#define TK_COMPILE_CYLON_LINEAR
#endif

// Eight warps per CTA; the kernel treats them as two warpgroups of four warps each.
#define NUM_WORKERS (8)
// Threads per CTA (8 warps * 32 lanes); used both for __launch_bounds__ and the launch.
#define NUM_THREADS (NUM_WORKERS*kittens::WARP_THREADS)
// Number of warpgroups per CTA (2 with the values above).
#define NUM_WARPGROUPS (NUM_WORKERS/kittens::WARPGROUP_WARPS)

using namespace kittens;

template<ducks::sv::all SV, ducks::st::all ST>
__device__ inline void cumulative_add(SV &dst, const ST &src) {
    // Adds the column sums of a bf16 shared tile onto a float shared vector:
    //     dst[c] += sum_r src[r][c]
    // Intended to be executed by all 128 threads of one warpgroup; thread t (mod 128)
    // owns column t. No synchronization is performed here -- callers must barrier
    // before other threads read `dst`.
    static_assert(ST::cols <= 128);
    static_assert(ST::cols == SV::length);
    const int col = threadIdx.x % 128;
    if (col >= ST::cols) return;   // threads beyond the tile width have nothing to do

    float running = dst[col];
    for (int row = 0; row < ST::rows; ++row) {
        running += __bfloat162float(src[{row, col}]);
    }
    dst[col] = running;
}

template<ducks::rt::all RT>
__device__ inline void softmax_featuremap_inplace(RT &tile) {
    // Row-wise base-2 softmax, in place: tile[r][c] = 2^(x - max_r) / sum_c' 2^(x' - max_r).
    // Callers in this file pre-multiply by log2(e) so the result equals a natural-base softmax.
    col_vec<RT> row_stat;            // one value per row; reused for the max and then the sum
    row_max(row_stat, tile);
    sub_row(tile, tile, row_stat);   // every entry is now <= 0, so exp2 cannot overflow
    exp2(tile, tile);
    row_sum(row_stat, tile);
    div_row(tile, tile, row_stat);
}

// Tokens per sequence chunk: the row count of the q/k/v/o shared tiles, and the step of the
// kernel's main loop (N / CHUNK_SIZE iterations).
#define CHUNK_SIZE 64
// Value / output head dimension (columns of the v and o tiles); also the row count of the
// feature-map tiles.
#define ATTN_D 128
// Width of the q/k tiles and the total featurized dimension (rows of the KV state).
#define ATTN_F 128
// Feature-map output columns handled by one warpgroup (two warpgroups cover ATTN_F).
#define HALF_ATTN_F 64

// Kernel arguments: ThunderKittens global-memory layouts (all four dims sized at runtime, -1)
// plus a raw pointer to the per-head normalizer term. Built by cylon_linear_sn_init().
struct cylon_linear_sn_globals { 
    // shapes of the tiles each layout is loaded/stored with
    using q_tile = st_bf<CHUNK_SIZE, ATTN_F>;
    using k_tile = st_bf<CHUNK_SIZE, ATTN_F>;
    using v_tile = st_bf<CHUNK_SIZE, ATTN_D>;
    using o_tile = st_bf<CHUNK_SIZE, ATTN_D>;
    using kv_state_tile = st_fl<ATTN_F, ATTN_D>;   // full fp32 KV state for one (batch, head)
    using k_state_vec = sv_fl<ATTN_D>;             // running sum of featurized k (length 128)
    using qk_map_tile = st_bf<ATTN_D, HALF_ATTN_F>; // one feature-map projection (128 -> 64)

    // global layouts
    using q_gl = gl<bf16,  -1, -1, -1, -1, q_tile>;
    using k_gl = gl<bf16,  -1, -1, -1, -1, k_tile>;
    using v_gl = gl<bf16,  -1, -1, -1, -1, v_tile>;
    using qmap_gl = gl<bf16,  -1, -1, -1, -1, qk_map_tile>;
    using kmap_gl = gl<bf16,  -1, -1, -1, -1, qk_map_tile>;
    using k_state_gl = gl<float, -1, -1, -1, -1, k_state_vec>;
    using kv_state_gl = gl<float, -1, -1, -1, -1, kv_state_tile>;
    using o_gl = gl<bf16,  -1, -1, -1, -1, o_tile>;

    // pointers
    q_gl q;
    k_gl k;
    v_gl v;
    qmap_gl qmap;          // indexed by {feature_map_id, head, ...} in the kernel
    kmap_gl kmap;          // indexed by {feature_map_id, head, ...} in the kernel
    k_state_gl k_state;    // output, written with TMA store-add
    kv_state_gl kv_state;  // output, written with TMA store-add
    o_gl o;                // output, written with TMA store-add

    float *shared_norm;    // one float per head, read as shared_norm[head]
};


// Chunked causal linear attention with fused learned feature maps.
//
// Launch configuration (matches dispatch_cylon_linear_sn):
//   grid  = (HEADS, BATCH, NUM_FEATURE_MAPS): blockIdx.x = head, .y = batch, .z = feature map
//   block = NUM_THREADS (256) threads = two warpgroups of four warps
//   dynamic shared memory = kittens::MAX_SHARED_MEMORY
// Warpgroup 0 featurizes q/k with a positive scale, warpgroup 1 with a negative scale; each
// owns HALF_ATTN_F of the feature dimension. All outputs (o, kv_state, k_state) are written
// with TMA store-add (several feature maps add into the same outputs), so they must be zeroed
// before launch. N is expected to be a multiple of CHUNK_SIZE; trailing tokens are ignored.
__global__ __launch_bounds__(NUM_THREADS, 1)
void cylon_linear_attention_smd_sn (
    const __grid_constant__ cylon_linear_sn_globals g, int N
)  {
    extern __shared__ int __shm[]; // dynamic shared memory, carved up by the allocator below
    tma_swizzle_allocator al((int*)&__shm[0]);

    const int batch = blockIdx.y;
    const int head  = blockIdx.x;
    const int feature_map_id = blockIdx.z;

    // per-head additive term for the output normalizer
    float shared_norm = g.shared_norm[head];

    // smem
    using q_tile = st_bf<CHUNK_SIZE, ATTN_F>;
    using k_tile = st_bf<CHUNK_SIZE, ATTN_F>;
    using v_tile = st_bf<CHUNK_SIZE, ATTN_D>;
    using o_tile = st_bf<CHUNK_SIZE, ATTN_D>;
    using kv_state_tile = st_fl<ATTN_F, ATTN_D>;
    using k_state_vec = sv_fl<ATTN_D>;
    using qk_map_tile = st_bf<ATTN_D, HALF_ATTN_F>;
    q_tile (&q_smem)[2] = al.allocate<q_tile, 2>(); // 32k, (tic/toc)*16k
    k_tile (&k_smem)[2] = al.allocate<k_tile, 2>(); // 32k, (tic/toc)*16k
    v_tile (&v_smem)[2] = al.allocate<v_tile, 2>(); // 32k, (tic/toc)*16k
    o_tile (&o_smem)    = al.allocate<o_tile>   (); // 16k
    qk_map_tile (&qf_map) = al.allocate<qk_map_tile>(); // 16k, for fusing featuremap computation
    qk_map_tile (&kf_map) = al.allocate<qk_map_tile>(); // 16k, for fusing featuremap computation

    // norm stuff
    st_bf<CHUNK_SIZE, ATTN_F> (&kv_smem)[2] = al.allocate<st_bf<CHUNK_SIZE, ATTN_F>, 2>(); // 32k, bf16 copy of each warpgroup's KV state
    row_vec<st_fl<CHUNK_SIZE,4*16>> (&cumsum_k_smem)[2] = al.allocate<row_vec<st_fl<CHUNK_SIZE,4*16>>, 2>(); // running sum of featurized k, one per warpgroup
    col_vec<st_fl<CHUNK_SIZE,4*16>> (&norm_exchange)[1] = al.allocate<col_vec<st_fl<CHUNK_SIZE,4*16>>, 1>(); // warpgroup 1 -> warpgroup 0 normalizer hand-off
    // Scratch for featurized k. NOTE: k_scratch_smem[0] AND k_scratch_smem[1] both live inside
    // kv_smem[0], so any write to kv_smem[0] must be ordered against both warpgroups' scratch use.
    st_bf<CHUNK_SIZE, 4*16> (*k_scratch_smem)     = reinterpret_cast<st_bf<CHUNK_SIZE, 4*16>*>(&kv_smem[0].data[0]);

    int warpid = kittens::warpid();
    int warpgroupid = warpid/4;
    int warpgroup_warpid = warpgroup::warpid();
    int blocks = N / (q_tile::rows);

    // Named-barrier ids private to this warpgroup. Barrier 0 is the one __syncthreads() uses,
    // and a 128-thread bar.sync shared by BOTH warpgroups releases whichever four warps arrive
    // first, which does not order a warpgroup's own stores. So each warpgroup gets its own id.
    const int wg_scratch_barrier = 1 + warpgroupid; // ids 1, 2
    const int wg_state_barrier   = 3 + warpgroupid; // ids 3, 4

    int tic = 0, toc = 1;

    __shared__ semaphore qkv_semaphore;
    if (warpid == 0) {
        init_semaphore(qkv_semaphore, 0, 1);
        tma::expect_bytes(qkv_semaphore, 
            size_bytes<typeof(q_smem[0])> + 
            size_bytes<typeof(k_smem[0])> + 
            size_bytes<typeof(v_smem[0])> +
            // the feature maps are loaded only once, together with the first chunk
            size_bytes<typeof(qf_map)> +
            size_bytes<typeof(kf_map)>
        );
        // feature maps for this (feature map, head)
        tma::load_async(qf_map, g.qmap, {feature_map_id, head, 0, 0}, qkv_semaphore);
        tma::load_async(kf_map, g.kmap, {feature_map_id, head, 0, 0}, qkv_semaphore);
        // first chunk of q/k/v
        tma::load_async(q_smem[tic], g.q, {batch, head, 0, 0}, qkv_semaphore);
        tma::load_async(k_smem[tic], g.k, {batch, head, 0, 0}, qkv_semaphore);
        tma::load_async(v_smem[tic], g.v, {batch, head, 0, 0}, qkv_semaphore);
    }

    // Persistent KV state, split across the two warpgroups (each holds its half of the features).
    // At the top of iteration `block` it contains the contribution of all previous chunks.
    rt_fl<1*16, 8*16> local_kv;

    zero(local_kv);
    warpgroup::zero(cumsum_k_smem[warpgroupid]);

    __syncthreads();

    for (int block = 0; block < blocks; block++, tic^=1, toc^=1) {

        wait(qkv_semaphore, tic);  // chunk `block` has arrived in the tic buffers
        __syncthreads();

        // prefetch the next chunk into the toc buffers (last read in the previous iteration)
        if (warpid == 0 && block < blocks-1) {
            tma::expect_bytes(qkv_semaphore,
                size_bytes<typeof(q_smem[0])> + 
                size_bytes<typeof(k_smem[0])> + 
                size_bytes<typeof(v_smem[0])>
            );
            tma::load_async(q_smem[toc], g.q, {batch, head, block+1, 0}, qkv_semaphore); 
            tma::load_async(k_smem[toc], g.k, {batch, head, block+1, 0}, qkv_semaphore); 
            tma::load_async(v_smem[toc], g.v, {batch, head, block+1, 0}, qkv_semaphore);
        }
        __syncthreads();

        rt_fl<1*16, 8*16> linear_o; // this is partitioned across the two warpgroups.
        rt_fl<1*16, 4*16>::col_vec linear_norm_vec;
        zero(linear_norm_vec);

        // ******* linear attn ******** // 

        // (1) GENERATE LINEAR_Q: softmax(+/- log2(e) * q @ qf_map) per warpgroup half
        rt_fl<1*16, 4*16> linear_q;
        rt_bf<1*16, 4*16> linear_q_bf;

        warpgroup::mm_AB(linear_q, q_smem[tic], qf_map);
        warpgroup::mma_async_wait();
        if(warpgroupid) mul(linear_q, linear_q, -1.44269504089f);
        else            mul(linear_q, linear_q,  1.44269504089f);
        softmax_featuremap_inplace(linear_q);
        copy(linear_q_bf, linear_q); // now to bf16

        // (2) GENERATE LINEAR_K
        rt_fl<1*16, 4*16> linear_k;
        
        warpgroup::mm_AB(linear_k, k_smem[tic], kf_map);
        warpgroup::mma_async_wait();
        if(warpgroupid) mul(linear_k, linear_k, -1.44269504089f);
        else            mul(linear_k, linear_k,  1.44269504089f);
        softmax_featuremap_inplace(linear_k);

        // (3) COMPUTE QUADRATIC (intra-chunk) BLOCK
        rt_fl<1*16, 4*16> scores;
        rt_bf<1*16, 4*16> scores_bf;
        rt_fl<1*16, 8*16> quad_o;
        rt_fl<1*16, 4*16>::col_vec quad_norm_vec;

        warpgroup::store(k_scratch_smem[warpgroupid], linear_k);
        warpgroup::sync(wg_scratch_barrier); // this warpgroup's scratch tile is complete

        warpgroup::mm_ABt(scores, linear_q_bf, k_scratch_smem[warpgroupid]);     // contains scores for half of feature map only
        warpgroup::mma_async_wait();

        // Make quadratic block causal: warp w of the warpgroup owns query rows 16w..16w+15,
        // so key subtiles to the right of the diagonal are zeroed and the diagonal is masked.
        for (int j = 0; j < 4; j++) {
            auto &scores_subtile = reinterpret_cast<rt_fl<1*16, 1*16>&>(scores.tiles[0][j]);
            if (j>warpgroup_warpid) zero(scores_subtile);       // need to use the warpid within the warpgroup. Not within the CTA
            else if (j==warpgroup_warpid) make_causal(scores_subtile, scores_subtile, kittens::base_types::constants<float>::zero());
        }

        // Compute quad_norm_vec
        row_sum(quad_norm_vec, scores);

        copy(scores_bf, scores);

        // Compute quadratic block outputs
        warpgroup::mm_AB(quad_o, scores_bf, v_smem[tic]);
        warpgroup::mma_async_wait();

        // (4) INTER-CHUNK (linear) CONTRIBUTION from the state of previous chunks
        if (block == 0) {
            zero(linear_o);
        } else {
            // `block` is uniform across the CTA, so these barriers are reached by every thread.
            // kv_smem[0] aliases BOTH warpgroups' k_scratch tiles, which warpgroup 1 may still be
            // reading in its mm_ABt above: wait for everyone before overwriting it.
            __syncthreads();
            warpgroup::store(kv_smem[warpgroupid], local_kv);
            __syncthreads();
            warpgroup::mm_AB(linear_o, linear_q_bf, kv_smem[warpgroupid]);
            warpgroup::mma_async_wait();

            // Normalizer: q . sum(k over previous chunks).
            // first we load sum(k) from smem to registers.
            row_vec<rt_bf<1*16,4*16>> cumsum_k_reg;
            load(cumsum_k_reg, cumsum_k_smem[warpgroupid]);
            // broadcast along the column axis (every row gets the same vector)
            rt_bf<1*16,4*16> cumsum_k_reg_tile;
            broadcast_col(cumsum_k_reg_tile, cumsum_k_reg);
            // every column of norm_tile is the same per-row dot product
            rt_fl<1*16,1*16> norm_tile;
            zero(norm_tile);
            mma_ABt(norm_tile, linear_q_bf, cumsum_k_reg_tile, norm_tile);
            row_max(linear_norm_vec, norm_tile); // any column would do; all columns are equal
        }

        // (5) COMBINE AND STORE OUTPUTS
        tma::store_async_wait(); // previous chunk's TMA store out of o_smem must have drained
        __syncthreads();

        // Warpgroup 1 publishes its partial output and normalizer through shared memory ...
        if (warpgroupid != 0) {
            add(linear_o, linear_o, quad_o);
            add(linear_norm_vec, linear_norm_vec, quad_norm_vec);

            warpgroup::store(o_smem, linear_o);
            warpgroup::store(norm_exchange[0], linear_norm_vec);
        }

        // ... a single CTA-uniform barrier (never __syncthreads() inside divergent branches) ...
        __syncthreads();

        // ... and warpgroup 0 merges both halves, normalizes, and stages the result in o_smem.
        if (warpgroupid == 0) {
            add(quad_norm_vec, quad_norm_vec, linear_norm_vec);
            add(quad_o, quad_o, linear_o);

            warpgroup::load(linear_o, o_smem);
            warpgroup::load(linear_norm_vec, norm_exchange[0]);

            add(quad_o, quad_o, linear_o);
            add(quad_norm_vec, quad_norm_vec, linear_norm_vec);
            
            add(quad_norm_vec, quad_norm_vec, shared_norm);

            div_row(quad_o, quad_o, quad_norm_vec);

            warpgroup::store(o_smem, quad_o);
        }

        __syncthreads();

        if(warpid == 0) {
            tma::store_add_async(g.o, o_smem, {batch, head, block, 0});
        }

        // (6) UPDATE KV AND CUMSUM K STATES so they include this chunk (used by the next one)
        warpgroup::store(k_scratch_smem[warpgroupid], linear_k);
        warpgroup::sync(wg_state_barrier); // this warpgroup's scratch tile is complete
        cumulative_add(cumsum_k_smem[warpgroupid], k_scratch_smem[warpgroupid]);
        warpgroup::mma_AtB(local_kv, k_scratch_smem[warpgroupid], v_smem[tic]); // local_kv += k^T v
        warpgroup::mma_async_wait();
    }
    tma::store_async_wait();

    // Finally write out the kv state and the k state.
    // The two length-64 cumsum vectors are contiguous, so view them as one length-128 vector
    // to save a TMA call.
    k_state_vec (&k_state_smem) = *reinterpret_cast<k_state_vec*>(&cumsum_k_smem[0].data[0]);
    // Stage the fp32 kv state over the q/k buffers, which are no longer needed.
    kv_state_tile (&kv_state_smem) = reinterpret_cast<kv_state_tile&>(q_smem[0].data[0]);
    group<8>::store(kv_state_smem, local_kv); // all 8 warps store their own chunk.
    __syncthreads();
    // write out kv state
    if(warpid == 0){
        tma::store_add_async(g.kv_state, kv_state_smem, {batch, head, 0, 0});
        tma::store_add_async(g.k_state, k_state_smem, {batch, head, 0, 0});
        tma::store_commit_group();
    }
    __syncthreads();
    tma::store_async_wait();
}

// Builds the kernel argument struct from raw device pointers.
// Expected layouts (matching the gl extents below):
//   q, k: (B, H, N, ATTN_F) bf16     v, o: (B, H, N, ATTN_D) bf16
//   qmap, kmap: (NUM_FEATURE_MAPS, H, ATTN_F, HALF_ATTN_F) bf16
//   kv_state: (B, H, ATTN_F, ATTN_D) fp32     k_state: (B, H, 1, ATTN_D) fp32
//   shared_norm: H floats
// Outside TORCH_COMPILE builds, the output/state buffers are zeroed here, because the kernel
// accumulates into them with TMA store-add. (The torch path allocates them with torch::zeros.)
cylon_linear_sn_globals cylon_linear_sn_init(
    bf16 *d_q, bf16 *d_k, bf16 *d_v, bf16 *d_o,
    bf16 *d_qmap, bf16 *d_kmap,
    float *d_k_state, float *d_kv_state,
    float *d_shared_norm,
    int ATTN_B, int ATTN_H, int ATTN_N,
    int NUM_FEATURE_MAPS
) {
    // global pointers. 
    using q_tile = st_bf<CHUNK_SIZE, ATTN_F>;
    using k_tile = st_bf<CHUNK_SIZE, ATTN_F>;
    using v_tile = st_bf<CHUNK_SIZE, ATTN_D>;
    using o_tile = st_bf<CHUNK_SIZE, ATTN_D>;
    using kv_state_tile = st_fl<ATTN_F, ATTN_D>;
    using k_state_vec = sv_fl<ATTN_D>;
    using qk_map_tile = st_bf<ATTN_D, HALF_ATTN_F>;
    
    using q_global = gl<bf16, -1, -1, -1, -1, q_tile>;
    using k_global = gl<bf16, -1, -1, -1, -1, k_tile>;
    using v_global = gl<bf16, -1, -1, -1, -1, v_tile>;
    using o_global = gl<bf16, -1, -1, -1, -1, o_tile>;
    using kv_state_global = gl<float, -1, -1, -1, -1, kv_state_tile>;
    using k_state_global = gl<float, -1, -1, -1, -1, k_state_vec>;
    using qmap_global = gl<bf16, -1, -1, -1, -1, qk_map_tile>;
    using kmap_global = gl<bf16, -1, -1, -1, -1, qk_map_tile>;
    
    using globals = cylon_linear_sn_globals;
    q_global q_arg{d_q, ATTN_B, ATTN_H, ATTN_N, ATTN_F};
    k_global k_arg{d_k, ATTN_B, ATTN_H, ATTN_N, ATTN_F};
    v_global v_arg{d_v, ATTN_B, ATTN_H, ATTN_N, ATTN_D};
    o_global o_arg{d_o, ATTN_B, ATTN_H, ATTN_N, ATTN_D};
    qmap_global qmap_arg{d_qmap, NUM_FEATURE_MAPS, ATTN_H, ATTN_F, HALF_ATTN_F};
    kmap_global kmap_arg{d_kmap, NUM_FEATURE_MAPS, ATTN_H, ATTN_F, HALF_ATTN_F};
    kv_state_global kv_state_arg{d_kv_state, ATTN_B, ATTN_H, ATTN_F, ATTN_D};
    k_state_global k_state_arg{d_k_state, ATTN_B, ATTN_H, 1, ATTN_D};

    // Initialize output and state tensors to zero (the kernel accumulates into them).
    #ifndef TORCH_COMPILE
    // Byte counts are computed in size_t: B*H*N*D easily exceeds INT_MAX for long sequences,
    // and the original int product would overflow before being widened by sizeof().
    const size_t batch_heads = static_cast<size_t>(ATTN_B) * static_cast<size_t>(ATTN_H);
    cudaMemset(d_o, 0, batch_heads * static_cast<size_t>(ATTN_N) * ATTN_D * sizeof(bf16));
    cudaMemset(d_kv_state, 0, batch_heads * ATTN_F * ATTN_D * sizeof(float));
    cudaMemset(d_k_state, 0, batch_heads * ATTN_D * sizeof(float));
    #endif

    globals g{
        q_arg, k_arg, v_arg, 
        qmap_arg, kmap_arg,
        k_state_arg, kv_state_arg,
        o_arg, d_shared_norm
    };
    return g;
}

#ifdef TK_COMPILE_CYLON_LINEAR
#include "pyutils/torch_helpers.cuh"
#include <iostream>
// Builds the kernel arguments and launches one CTA per (head, batch, feature map) on the
// default stream, with the maximum dynamic shared memory the kernel's allocator may use.
// Launch/config failures are reported through CHECK_CUDA_ERROR.
void dispatch_cylon_linear_sn( 
    bf16 *d_q, bf16 *d_k, bf16 *d_v, bf16 *d_o,
    bf16 *d_qmap, bf16 *d_kmap,
    float *d_k_state, float *d_kv_state,
    float *d_shared_norm,
    int ATTN_B, int ATTN_H, int ATTN_N,
    int NUM_FEATURE_MAPS
){
    cylon_linear_sn_globals g = cylon_linear_sn_init(
        d_q, d_k, d_v, d_o,
        d_qmap, d_kmap,
        d_k_state, d_kv_state,
        d_shared_norm,
        ATTN_B, ATTN_H, ATTN_N,
        NUM_FEATURE_MAPS
    );

    // Opt in to more than the default 48 KB of dynamic shared memory. If this fails, the
    // launch below would fail too, so surface the real cause here.
    unsigned long mem_size = kittens::MAX_SHARED_MEMORY;
    CHECK_CUDA_ERROR(cudaFuncSetAttribute(
        cylon_linear_attention_smd_sn,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        mem_size
    ));
    // grid layout must match the kernel: x = head, y = batch, z = feature map
    dim3 grid(ATTN_H, ATTN_B, NUM_FEATURE_MAPS);
    cylon_linear_attention_smd_sn<<<grid,NUM_THREADS,mem_size>>>(g, ATTN_N);
    CHECK_CUDA_ERROR(cudaGetLastError());
}

// PyTorch entry point.
// Shapes (the kernel is compiled for fixed tile sizes; everything below is checked):
//   q, k        : (B, H, N, ATTN_F)            bf16
//   v           : (B, H, N, ATTN_D)            bf16
//   qmap, kmap  : (K, H, ATTN_F, HALF_ATTN_F)  bf16, K = number of feature maps
//   shared_norm : (H,)                         float32
// N must be a positive multiple of CHUNK_SIZE.
// Returns (out (B,H,N,DV) bf16, kv_state (B,H,2*FD,DV) fp32, k_state (B,H,1,2*FD) fp32).
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> cylon_linear_sn(
    const torch::Tensor q, 
    const torch::Tensor k,
    const torch::Tensor v,
    const torch::Tensor qmap,
    const torch::Tensor kmap,
    const torch::Tensor shared_norm
) {
    CHECK_INPUT(q);
    CHECK_INPUT(k);
    CHECK_INPUT(v);
    CHECK_INPUT(qmap);
    CHECK_INPUT(kmap);
    CHECK_INPUT(shared_norm);

    int B = q.size(0);
    int H = q.size(1);
    int N = q.size(2);
    int DQ = q.size(3);
    int DV = v.size(3);

    int K = qmap.size(0);
    int FD = qmap.size(3);

    // checks
    TORCH_CHECK(k.size(0) == B, "k batch?");
    TORCH_CHECK(k.size(1) == H, "k heads?");
    TORCH_CHECK(k.size(2) == N, "k length?");
    TORCH_CHECK(k.size(3) == DQ, "k head dim?");

    TORCH_CHECK(v.size(0) == B, "v batch?");
    TORCH_CHECK(v.size(1) == H, "v heads?");
    TORCH_CHECK(v.size(2) == N, "v length?");

    TORCH_CHECK(qmap.size(1) == H, "qmap heads?");
    TORCH_CHECK(qmap.size(2) == DQ, "qmap in dim?");
    TORCH_CHECK(qmap.size(3) == FD, "qmap feature dim?");

    // kmap must provide the same number of feature maps as qmap.
    TORCH_CHECK(kmap.size(0) == K, "kmap k?");
    TORCH_CHECK(kmap.size(1) == H, "kmap heads?");
    TORCH_CHECK(kmap.size(2) == DQ, "kmap in dim?");
    TORCH_CHECK(kmap.size(3) == FD, "kmap feature dim?");

    TORCH_CHECK(shared_norm.size(0) == H, "shared norm dim?");

    // The kernel's tiles are hard-coded; other sizes would make the TMA loads/stores run out of
    // bounds (the outputs below are sized from FD/DV) or silently drop trailing tokens.
    TORCH_CHECK(DQ == ATTN_F, "q/k head dim must be ", ATTN_F);
    TORCH_CHECK(DV == ATTN_D, "v head dim must be ", ATTN_D);
    TORCH_CHECK(FD == HALF_ATTN_F, "feature map output dim must be ", HALF_ATTN_F);
    TORCH_CHECK(N > 0 && N % CHUNK_SIZE == 0,
                "sequence length must be a positive multiple of ", CHUNK_SIZE);

    // allocate outputs (zeroed: the kernel accumulates into them)
    torch::Tensor out = torch::zeros({B, H, N, DV}, v.options());
    torch::Tensor kv_state = torch::zeros({B, H, FD*2, DV}, torch::dtype(torch::kFloat32).device(v.device()));
    torch::Tensor k_state = torch::zeros({B, H, 1, FD*2}, torch::dtype(torch::kFloat32).device(v.device()));

    // reinterpret torch bf16 storage as kittens bf16 (no conversion is performed)
    c10::BFloat16 *q_bf16 = q.data_ptr<c10::BFloat16>();
    c10::BFloat16 *k_bf16 = k.data_ptr<c10::BFloat16>();
    c10::BFloat16 *v_bf16 = v.data_ptr<c10::BFloat16>();
    c10::BFloat16 *qmap_bf16 = qmap.data_ptr<c10::BFloat16>();
    c10::BFloat16 *kmap_bf16 = kmap.data_ptr<c10::BFloat16>();
    
    bf16 *d_q = reinterpret_cast<bf16*>(q_bf16);
    bf16 *d_k = reinterpret_cast<bf16*>(k_bf16);
    bf16 *d_v = reinterpret_cast<bf16*>(v_bf16);
    bf16 *d_qmap = reinterpret_cast<bf16*>(qmap_bf16);
    bf16 *d_kmap = reinterpret_cast<bf16*>(kmap_bf16);
    bf16 *d_o = reinterpret_cast<bf16*>(out.data_ptr<c10::BFloat16>());
    float *d_k_state = k_state.data_ptr<float>();
    float *d_kv_state = kv_state.data_ptr<float>();
    float *d_shared_norm = shared_norm.data_ptr<float>();

    dispatch_cylon_linear_sn(
        d_q, d_k, d_v, d_o, 
        d_qmap, d_kmap,
        d_k_state, d_kv_state,
        d_shared_norm,
        B, H, N, K
    );

    CHECK_CUDA_ERROR(cudaGetLastError());
    return std::make_tuple(out, kv_state, k_state);
}
#else
#include "harness.impl"
#endif