// Adds STA with a cross-shaped window pattern and a temporal feature; adapted to multiple resolutions (still being tested).
#include "kittens.cuh"
#include <cooperative_groups.h>
#include <iostream>
#include <stdio.h>

#include "pyutils/torch_helpers.cuh"
#include <ATen/cuda/CUDAContext.h>

// Clamp `v` into [lo, hi]. The lower bound is tested first, so when lo > hi
// a value below `lo` yields `lo` and any other value above `hi` yields `hi`.
// Arguments may be evaluated more than once; pass side-effect-free expressions.
#define CLAMP(v, lo, hi) \
    ((v) < (lo) ? (lo) : ((v) > (hi) ? (hi) : (v)))
// Absolute value of `v`. The argument may be evaluated twice.
#define ABS(v) ((v) < 0 ? -(v) : (v))

// Warpgroup layout of one thread block.
constexpr int CONSUMER_WARPGROUPS = 3;  // warpgroups that run the attention math
constexpr int PRODUCER_WARPGROUPS = 1;  // warpgroup that issues the K/V loads
constexpr int NUM_WARPGROUPS      = PRODUCER_WARPGROUPS + CONSUMER_WARPGROUPS;
// Total warps per block (original note: 4*4).
constexpr int NUM_WORKERS         = kittens::WARPGROUP_WARPS * NUM_WARPGROUPS;

using namespace kittens;
namespace cg = cooperative_groups;

// Compile-time tile geometry for the forward attention kernel, keyed by the
// head dimension D. Only D = 64 and D = 128 are specialised; the primary
// template is intentionally empty.
template<int D>
struct fwd_attend_ker_tile_dims {};

// Head dimension 64.
template<>
struct fwd_attend_ker_tile_dims<64> {
    static constexpr int tile_width = 64;      // columns per tile
    static constexpr int qo_height  = 4 * 16;  // rows per Q/O tile
    static constexpr int kv_height  = 6 * 16;  // rows per K/V tile
    static constexpr int stages     = 4;       // number of K/V staging buffers
};

// Head dimension 128.
template<>
struct fwd_attend_ker_tile_dims<128> {
    static constexpr int tile_width = 128;     // columns per tile
    static constexpr int qo_height  = 4 * 16;  // rows per Q/O tile
    static constexpr int kv_height  = 6 * 16;  // rows per K/V tile
    static constexpr int stages     = 2;       // number of K/V staging buffers
};

// Kernel argument bundle for the forward attention kernel with head dimension D.
// Member order matters: the host code brace-initialises this struct positionally.
template<int D>
struct fwd_globals {
    // Shared-memory tile shapes, derived from the per-D tile geometry.
    using q_tile    = st_bf<fwd_attend_ker_tile_dims<D>::qo_height,
                            fwd_attend_ker_tile_dims<D>::tile_width>;
    using k_tile    = st_bf<fwd_attend_ker_tile_dims<D>::kv_height,
                            fwd_attend_ker_tile_dims<D>::tile_width>;
    using v_tile    = st_bf<fwd_attend_ker_tile_dims<D>::kv_height,
                            fwd_attend_ker_tile_dims<D>::tile_width>;
    using l_col_vec = col_vec<st_fl<fwd_attend_ker_tile_dims<D>::qo_height,
                                    fwd_attend_ker_tile_dims<D>::tile_width>>;
    using o_tile    = st_bf<fwd_attend_ker_tile_dims<D>::qo_height,
                            fwd_attend_ker_tile_dims<D>::tile_width>;

    // Global-memory layouts with all four dimensions set at run time.
    using q_gl = gl<bf16,  -1, -1, -1, -1, q_tile>;
    using k_gl = gl<bf16,  -1, -1, -1, -1, k_tile>;
    using v_gl = gl<bf16,  -1, -1, -1, -1, v_tile>;
    using l_gl = gl<float, -1, -1, -1, -1, l_col_vec>;
    using o_gl = gl<bf16,  -1, -1, -1, -1, o_tile>;

    q_gl q;  // queries
    k_gl k;  // keys
    v_gl v;  // values
    l_gl l;  // per-row float output written by the kernel alongside O
    o_gl o;  // attention output

    const int N;       // sequence length (host passes seq_len)
    const int text_L;  // text length used for the non-pad mask (host passes text_length)
    const int hr;      // query heads per KV head (host passes qo_heads / kv_heads)

    // Tile extent: the kernel uses TT*TH*TW as the token count of one tile.
    const int TT;
    const int TH;
    const int TW;
    // Tile-grid extent: the kernel decomposes a tile index into (t, h, w) using CH and CW.
    const int CT;
    const int CH;
    const int CW;

    int* num_window_ptr;    // indexed by temporal tile distance: number of windows
    int* window_sizes_ptr;  // indexed as [dist*4 + window*2 + {0: DH, 1: DW}]
};

// Forward pass of the windowed ("cross pattern") attention kernel.
//
// Layout, as used by the code below:
//   * blockIdx.z and blockIdx.y are the first two tile coordinates for Q/O/L;
//     K/V use the head index blockIdx.y / g.hr.
//   * Each block handles CONSUMER_WARPGROUPS consecutive Q tiles starting at seq_idx.
//   * The last warpgroup is the producer: one of its warps issues the TMA loads for K/V.
//     Every other warpgroup is a consumer that owns one Q tile and runs an online softmax.
//   * Dynamic shared memory holds the Q, K, V and L staging buffers; O reuses the Q buffer.
//
// Template parameters:
//   D          head dimension; selects the tile geometry and the softmax scale.
//   is_causal  selects the causal kv_iters computation.
//   text_q     the block's Q tiles are the trailing text tiles; K/V tiles are streamed in order.
//   text_kv    the last text_block_num K/V tiles are text and are masked past g.text_L.
//
// For image queries (text_q == false) the K/V tiles visited are selected per temporal
// distance dist = |kt - qt|: g.num_window_ptr[dist] gives the number of windows and
// g.window_sizes_ptr[dist*4 + i*2 + {0,1}] gives the half-extents (DH, DW) of window i.
// The prologue, the producer and the consumer must all agree on the visited tile order
// and count; they communicate only through the k/v "arrived" and "compute_done" semaphores.
//
// NOTE(review): the prologue stores window extents in two-element arrays, so
// num_window values above 2 are presumably never supplied — confirm against the caller.
template<int D, bool is_causal, bool text_q, bool text_kv>
__global__  __launch_bounds__((NUM_WORKERS)*kittens::WARP_THREADS, 1)
void compact_attn_fwd_ker(const __grid_constant__ fwd_globals<D> g) {
    extern __shared__ int __shm[]; 
    tma_swizzle_allocator al((int*)&__shm[0]);
    int warpid = kittens::warpid(), warpgroupid = warpid/kittens::WARPGROUP_WARPS;

    using K = fwd_attend_ker_tile_dims<D>;

    using q_tile    =         st_bf<K::qo_height, K::tile_width>;
    using k_tile    =         st_bf<K::kv_height, K::tile_width>;
    using v_tile    =         st_bf<K::kv_height, K::tile_width>;
    using l_col_vec = col_vec<st_fl<K::qo_height, K::tile_width>>;
    using o_tile    =         st_bf<K::qo_height, K::tile_width>;
    
    q_tile    (&q_smem)[CONSUMER_WARPGROUPS] = al.allocate<q_tile, CONSUMER_WARPGROUPS>();
    k_tile    (&k_smem)[K::stages]           = al.allocate<k_tile, K::stages          >();
    v_tile    (&v_smem)[K::stages]           = al.allocate<v_tile, K::stages          >();
    l_col_vec (&l_smem)[CONSUMER_WARPGROUPS] = al.allocate<l_col_vec, CONSUMER_WARPGROUPS>();
    // The output tiles alias the Q tiles: Q is no longer needed once the KV loop is done.
    auto      (*o_smem)                      = reinterpret_cast<o_tile(*)>(q_smem);

    int TT, TH, TW;
    TT = g.TT;
    TH = g.TH;
    TW = g.TW;
    int CT, CH, CW;
    CT = g.CT;
    CH = g.CH;
    CW = g.CW;
    // Number of Q / KV blocks that make up one (TT x TH x TW) tile.
    int qblock_num_per_Tile = (TT * TH * TW) / K::qo_height;
    int kvblock_num_per_Tile = (TT * TH * TW) / K::kv_height;
    // Number of KV blocks reserved for text (384 text tokens at most).
    int text_block_num = (int)(384/K::kv_height);

    int img_kv_blocks;
    int kv_blocks   = g.N / (K::kv_height);
    if constexpr (text_kv) {
        img_kv_blocks = kv_blocks - text_block_num;
    } else {
        img_kv_blocks = kv_blocks;
    }
    int kv_head_idx = blockIdx.y / g.hr;
    int seq_idx;
    
    if constexpr (text_q) {
        // max text_L = text_block_num * kv_height = 384
        // Text queries start right after the image tokens, measured in Q tiles.
        seq_idx = (g.N - K::kv_height * text_block_num) / K::qo_height + blockIdx.x * CONSUMER_WARPGROUPS;
    } else {
        seq_idx = blockIdx.x * CONSUMER_WARPGROUPS; // 192/3=64 tokens as the smallest unit for q
    }
    __shared__ kittens::semaphore qsmem_semaphore, k_smem_arrived[K::stages], v_smem_arrived[K::stages], compute_done[K::stages];

    if (threadIdx.x == 0) { 
        // init semaphores
        init_semaphore(qsmem_semaphore, 0, 1); 
        for(int j = 0; j < K::stages; j++) {
            init_semaphore(k_smem_arrived[j], 0, 1); 
            init_semaphore(v_smem_arrived[j], 0, 1); 
            init_semaphore(compute_done[j], CONSUMER_WARPGROUPS, 0); 
        }
        tma::expect_bytes(qsmem_semaphore, sizeof(q_smem));
        
        // load first CONSUMER_WARPGROUPS blocks for q
        for (int wg = 0; wg < CONSUMER_WARPGROUPS; wg++) { 
            coord<q_tile> q_tile_idx = {blockIdx.z, blockIdx.y, (seq_idx) + wg, 0};
            tma::load_async(q_smem[wg], g.q, q_tile_idx, qsmem_semaphore);
        }
        // load first stage-1 blocks for kv
        if constexpr (text_q){
            // Text queries visit every KV block in order, so the first stages-1 blocks are 0..stages-2.
            for (int j = 0; j < K::stages - 1; j++) {
                coord<k_tile> kv_tile_idx = {blockIdx.z, kv_head_idx, j, 0};
                tma::expect_bytes(k_smem_arrived[j], sizeof(k_tile));
                tma::load_async(k_smem[j], g.k, kv_tile_idx, k_smem_arrived[j]);
                tma::expect_bytes(v_smem_arrived[j], sizeof(v_tile));
                tma::load_async(v_smem[j], g.v, kv_tile_idx, v_smem_arrived[j]);
            }
        } else {
            // Tile coordinates (t, h, w) of this block's first Q tile.
            int qt = seq_idx / qblock_num_per_Tile / (CH * CW);
            int qh = (seq_idx / qblock_num_per_Tile) % (CH * CW) / CW;
            int qw = (seq_idx / qblock_num_per_Tile) % CW;

            // Scan KV blocks in ascending order and load the first stages-1 blocks that
            // fall inside any window; the producer skips exactly these blocks later.
            int count = 0;
            int j = 0;
            while (count < K::stages - 1) { // start from 0 to find the first num of stages - 1 blocks
                bool mask = false;
                int kt = j / kvblock_num_per_Tile / (CH * CW);
                int kh = (j / kvblock_num_per_Tile) % (CH * CW) / CW;
                int kw = (j / kvblock_num_per_Tile) % CW;
                int dist = ABS(kt - qt);
                int num_window = (g.num_window_ptr)[dist];
                if (num_window == 0){
                    j += 1;
                    continue;
                }
                int DH[2] = {CH,CH};
                int DW[2] = {CW,CW};
                int qh_adjusted[2] = {0,0};
                int qw_adjusted[2] = {0,0};
                for (int i=0; i<num_window; i++){
                    DH[i] = (g.window_sizes_ptr)[dist*4+i*2+0];
                    DW[i] = (g.window_sizes_ptr)[dist*4+i*2+1];

                    // Shift the window centre inward so the window stays inside the tile grid.
                    // The width coordinate must be clamped against CW (not CH) so that this
                    // selection matches the producer's window bounds on non-square grids.
                    qh_adjusted[i] = CLAMP(qh, DH[i], CH-DH[i]-1);
                    qw_adjusted[i] = CLAMP(qw, DW[i], CW-DW[i]-1);
                    
                    mask = mask | ((ABS(qh_adjusted[i] - kh) <= DH[i]) && (ABS(qw_adjusted[i] - kw) <= DW[i]));
                    if (mask) {
                        break;
                    }
                }

                if (mask){
                    coord<k_tile> kv_tile_idx = {blockIdx.z, kv_head_idx, j, 0};
                    tma::expect_bytes(k_smem_arrived[count], sizeof(k_tile));
                    tma::load_async(k_smem[count], g.k, kv_tile_idx, k_smem_arrived[count]);
                    tma::expect_bytes(v_smem_arrived[count], sizeof(v_tile));
                    tma::load_async(v_smem[count], g.v, kv_tile_idx, v_smem_arrived[count]);
                    count += 1;
                }
                j += 1;
            }
        }
    }

    // Make the semaphore initialisation visible to every thread before anyone waits on them.
    __syncthreads(); 

    int pipe_idx = K::stages - 1; 

    if(warpgroupid == NUM_WARPGROUPS-1) {
        // ---------------- producer warpgroup ----------------
        warpgroup::decrease_registers<32>();      
        
        int kv_iters; 
        if constexpr (is_causal) {
            kv_iters = (seq_idx * (K::qo_height/kittens::TILE_ROW_DIM<bf16>)) - 1 + (CONSUMER_WARPGROUPS * (K::qo_height/kittens::TILE_ROW_DIM<bf16>)); 
            kv_iters = ((kv_iters / (K::kv_height/kittens::TILE_ROW_DIM<bf16>)) == 0) ? (0) : ((kv_iters / (K::kv_height/kittens::TILE_ROW_DIM<bf16>)) - 1);
        }
        else { kv_iters = kv_blocks-2;}

        // Only the first warp of the producer warpgroup issues loads.
        if(warpid == NUM_WORKERS-4) { 
            if constexpr (text_q){ // skip the first num of stages - 1 blocks and load the rest
                for (auto kv_idx = pipe_idx - 1; kv_idx <= kv_iters; kv_idx++) {
                    coord<k_tile> kv_tile_idx = {blockIdx.z, kv_head_idx, kv_idx + 1, 0};
                    tma::expect_bytes(k_smem_arrived[(kv_idx+1)%K::stages], sizeof(k_tile));
                    tma::load_async(k_smem[(kv_idx+1)%K::stages], g.k, kv_tile_idx, k_smem_arrived[(kv_idx+1)%K::stages]);
                    tma::expect_bytes(v_smem_arrived[(kv_idx+1)%K::stages], sizeof(v_tile));
                    tma::load_async(v_smem[(kv_idx+1)%K::stages], g.v, kv_tile_idx, v_smem_arrived[(kv_idx+1)%K::stages]);
                    // Wait until the consumers are done with the previous block before moving on.
                    kittens::wait(compute_done[(kv_idx)%K::stages], (kv_idx/K::stages)%2);
                }
            } else {
                int qt = seq_idx / qblock_num_per_Tile / (CH * CW);                                                                                                    
                int qh = (seq_idx / qblock_num_per_Tile) % (CH * CW) / CW;
                int qw = (seq_idx / qblock_num_per_Tile) % CW;

                // `count` is the running index of the visited KV block; blocks with
                // count < stages-1 were already loaded by the prologue and are skipped.
                int count = 0;
                for (int kt = 0; kt < CT; kt++) {
                    int dist = ABS(kt - qt);
                    int num_window = (g.num_window_ptr)[dist];
                    if (num_window == 0) {
                        continue;
                    }
                    else if (num_window == 1) {
                        // Single rectangular window of half-extents (DH, DW).
                        int DH = (g.window_sizes_ptr)[dist*4];
                        int DW = (g.window_sizes_ptr)[dist*4+1];

                        int adjusted_qh = CLAMP(qh, DH, CH-DH-1);
                        int adjusted_qw = CLAMP(qw, DW, CW-DW-1);
                        
                        int k_h_min = CLAMP(adjusted_qh-DH, 0, CH-1);
                        int k_h_max = CLAMP(adjusted_qh+DH, 0, CH-1);
                        int k_w_min = CLAMP(adjusted_qw-DW, 0, CW-1);
                        int k_w_max = CLAMP(adjusted_qw+DW, 0, CW-1);
                        for (int kh = k_h_min; kh <= k_h_max; kh++) {
                            for (int kw = k_w_min; kw <= k_w_max; kw++) {
                                for (int j = 0; j < kvblock_num_per_Tile; j++){
                                    if (count >= K::stages - 1) { // skip first num of (stage) kv blocks and load the rest 
                                        int index = ((kt * (CH * CW)) + (kh * CW) + kw) * kvblock_num_per_Tile + j;
                                        coord<k_tile> kv_tile_idx = {blockIdx.z, kv_head_idx, index, 0};
                                        tma::expect_bytes(k_smem_arrived[count%K::stages], sizeof(k_tile));
                                        tma::load_async(k_smem[count%K::stages], g.k, kv_tile_idx, k_smem_arrived[count%K::stages]);
                                        tma::expect_bytes(v_smem_arrived[count%K::stages], sizeof(v_tile));
                                        tma::load_async(v_smem[count%K::stages], g.v, kv_tile_idx, v_smem_arrived[count%K::stages]);
                                        kittens::wait(compute_done[(count - 1)%K::stages], ((count - 1)/K::stages)%2);
                                        count += 1;
                                    } else {
                                        count += 1;
                                    }
                                }
                            }
                        }
                    }
                    else if (num_window == 2) {
                        // Two windows forming a cross. The union is visited as three row bands
                        // (top, middle, bottom) so KV blocks are loaded in ascending index order.
                        int DH1_tmp = (g.window_sizes_ptr)[dist*4];
                        int DW1_tmp = (g.window_sizes_ptr)[dist*4+1];
                        int DH2_tmp = (g.window_sizes_ptr)[dist*4+2];
                        int DW2_tmp = (g.window_sizes_ptr)[dist*4+3];

                        int qh1 = CLAMP(qh, DH1_tmp, CH-DH1_tmp-1);
                        int qw1 = CLAMP(qw, DW1_tmp, CW-DW1_tmp-1);
                        int qh2 = CLAMP(qh, DH2_tmp, CH-DH2_tmp-1);
                        int qw2 = CLAMP(qw, DW2_tmp, CW-DW2_tmp-1);
                        
                        int k_h_min1 = CLAMP(qh1-DH1_tmp, 0, CH-1);
                        int k_h_max1 = CLAMP(qh1+DH1_tmp, 0, CH-1);
                        int k_w_min1 = CLAMP(qw1-DW1_tmp, 0, CW-1);
                        int k_w_max1 = CLAMP(qw1+DW1_tmp, 0, CW-1);
                        int k_h_min2 = CLAMP(qh2-DH2_tmp, 0, CH-1);
                        int k_h_max2 = CLAMP(qh2+DH2_tmp, 0, CH-1);
                        int k_w_min2 = CLAMP(qw2-DW2_tmp, 0, CW-1);
                        int k_w_max2 = CLAMP(qw2+DW2_tmp, 0, CW-1);

                        // Top band: rows covered only by the taller window, narrow column range.
                        int k_h_min = min(k_h_min1, k_h_min2);
                        int k_h_max = max(k_h_min1, k_h_min2) - 1;
                        int k_w_min = max(k_w_min1, k_w_min2);
                        int k_w_max = min(k_w_max1, k_w_max2);
                        
                        for (int kh = k_h_min; kh <= k_h_max; kh++) {
                            for (int kw = k_w_min; kw <= k_w_max; kw++) {
                                for (int j = 0; j < kvblock_num_per_Tile; j++){
                                    if (count >= K::stages - 1) { // skip first num of (stage) kv blocks 
                                        int index = ((kt * (CH * CW)) + (kh * CW) + kw) * kvblock_num_per_Tile + j;
                                        coord<k_tile> kv_tile_idx = {blockIdx.z, kv_head_idx, index, 0};
                                        tma::expect_bytes(k_smem_arrived[count%K::stages], sizeof(k_tile));
                                        tma::load_async(k_smem[count%K::stages], g.k, kv_tile_idx, k_smem_arrived[count%K::stages]);
                                        tma::expect_bytes(v_smem_arrived[count%K::stages], sizeof(v_tile));
                                        tma::load_async(v_smem[count%K::stages], g.v, kv_tile_idx, v_smem_arrived[count%K::stages]);
                                        kittens::wait(compute_done[(count - 1)%K::stages], ((count - 1)/K::stages)%2);
                                        count += 1;
                                    } else {
                                        count += 1;
                                    }
                                }
                            }
                        }
                        // Middle band: rows covered by both windows, wide column range.
                        k_h_min = max(k_h_min1, k_h_min2);
                        k_h_max = min(k_h_max1, k_h_max2);
                        k_w_min = min(k_w_min1, k_w_min2);
                        k_w_max = max(k_w_max1, k_w_max2);
                        for (int kh = k_h_min; kh <= k_h_max; kh++) {
                            for (int kw = k_w_min; kw <= k_w_max; kw++) {
                                for (int j = 0; j < kvblock_num_per_Tile; j++){
                                    if (count >= K::stages - 1) { // skip first num of (stage) kv blocks 
                                        int index = ((kt * (CH * CW)) + (kh * CW) + kw) * kvblock_num_per_Tile + j;
                                        coord<k_tile> kv_tile_idx = {blockIdx.z, kv_head_idx, index, 0};
                                        tma::expect_bytes(k_smem_arrived[count%K::stages], sizeof(k_tile));
                                        tma::load_async(k_smem[count%K::stages], g.k, kv_tile_idx, k_smem_arrived[count%K::stages]);
                                        tma::expect_bytes(v_smem_arrived[count%K::stages], sizeof(v_tile));
                                        tma::load_async(v_smem[count%K::stages], g.v, kv_tile_idx, v_smem_arrived[count%K::stages]);
                                        kittens::wait(compute_done[(count - 1)%K::stages], ((count - 1)/K::stages)%2);
                                        count += 1;
                                    } else {
                                        count += 1;
                                    }
                                }
                            }
                        }
                        // Bottom band: rows covered only by the taller window, narrow column range.
                        k_h_min = min(k_h_max1, k_h_max2) + 1;
                        k_h_max = max(k_h_max1, k_h_max2);
                        k_w_min = max(k_w_min1, k_w_min2);
                        k_w_max = min(k_w_max1, k_w_max2);
                        for (int kh = k_h_min; kh <= k_h_max; kh++) {
                            for (int kw = k_w_min; kw <= k_w_max; kw++) {
                                for (int j = 0; j < kvblock_num_per_Tile; j++){
                                    if (count >= K::stages - 1) { // skip first num of (stage) kv blocks 
                                        int index = ((kt * (CH * CW)) + (kh * CW) + kw) * kvblock_num_per_Tile + j;
                                        coord<k_tile> kv_tile_idx = {blockIdx.z, kv_head_idx, index, 0};
                                        tma::expect_bytes(k_smem_arrived[count%K::stages], sizeof(k_tile));
                                        tma::load_async(k_smem[count%K::stages], g.k, kv_tile_idx, k_smem_arrived[count%K::stages]);
                                        tma::expect_bytes(v_smem_arrived[count%K::stages], sizeof(v_tile));
                                        tma::load_async(v_smem[count%K::stages], g.v, kv_tile_idx, v_smem_arrived[count%K::stages]);
                                        kittens::wait(compute_done[(count - 1)%K::stages], ((count - 1)/K::stages)%2);
                                        count += 1;
                                    } else {
                                        count += 1;
                                    }
                                }
                            }
                        }
                    }
                }

                // for text kv blocks (empty range when text_kv is false, since img_kv_blocks == kv_blocks)
                for (int index = img_kv_blocks; index < kv_blocks; index++) {
                    coord<k_tile> kv_tile_idx = {blockIdx.z, kv_head_idx, index, 0};
                    tma::expect_bytes(k_smem_arrived[count%K::stages], sizeof(k_tile));
                    tma::load_async(k_smem[count%K::stages], g.k, kv_tile_idx, k_smem_arrived[count%K::stages]);
                    tma::expect_bytes(v_smem_arrived[count%K::stages], sizeof(v_tile));
                    tma::load_async(v_smem[count%K::stages], g.v, kv_tile_idx, v_smem_arrived[count%K::stages]);
                    kittens::wait(compute_done[(count - 1)%K::stages], ((count - 1)/K::stages)%2);
                    count += 1;
                }
            }
        }
    }
    else {
        // ---------------- consumer warpgroups ----------------
        warpgroup::increase_registers<160>();

        rt_fl<16, K::kv_height>  att_block;
        rt_bf<16, K::kv_height>  att_block_mma;
        rt_fl<16, K::tile_width> o_reg;
        
        col_vec<rt_fl<16, K::kv_height>> max_vec, norm_vec, max_vec_last_scaled, max_vec_scaled;
        
        // Online-softmax state: running row max, running normaliser, and output accumulator.
        neg_infty(max_vec);
        zero(norm_vec);
        zero(o_reg);

        // kv_iters is the index of the last non-text KV block this consumer processes.
        int kv_iters; 
        if constexpr (is_causal) {
            // NOTE(review): the hard-coded 4 and 8 differ from the producer's expression,
            // which derives both factors from K::qo_height / K::kv_height — confirm the
            // causal path is still expected to work with the current tile geometry.
            kv_iters = (seq_idx * 4) - 1 + (CONSUMER_WARPGROUPS * 4);
            kv_iters = (kv_iters/8);
        }
        else if constexpr (text_q){ 
            // the last text_block_num kv blocks are for text, we process them separately
            kv_iters = img_kv_blocks - 1;
        } else {
            // Count the KV blocks covered by the windows of every temporal distance.
            kv_iters = 0;
            int qt = seq_idx / qblock_num_per_Tile / (CH * CW);
            int qh = (seq_idx / qblock_num_per_Tile) % (CH * CW) / CW;
            int qw = (seq_idx / qblock_num_per_Tile) % CW;

            for (int kt = 0; kt < CT; kt++) {
                int dist = ABS(kt - qt);
                int num_window = (g.num_window_ptr)[dist];
                if (num_window == 0) {
                    continue;
                }
                else if (num_window == 1){
                    int DH1_tmp = (g.window_sizes_ptr)[dist*4];
                    int DW1_tmp = (g.window_sizes_ptr)[dist*4+1];
                    kv_iters += (CLAMP(DH1_tmp*2+1, 1, CH) * CLAMP(DW1_tmp*2+1, 1, CW) * kvblock_num_per_Tile);
                }
                else{
                    // Union of two windows: area1 + area2 - overlap.
                    int DH1_tmp = (g.window_sizes_ptr)[dist*4];
                    int DW1_tmp = (g.window_sizes_ptr)[dist*4+1];
                    int DH2_tmp = (g.window_sizes_ptr)[dist*4+2];
                    int DW2_tmp = (g.window_sizes_ptr)[dist*4+3];
                    int repeat = min(CLAMP(DH1_tmp*2+1, 1, CH),CLAMP(DH2_tmp*2+1, 1, CH)) 
                        * min(CLAMP(DW1_tmp*2+1, 1, CW),CLAMP(DW2_tmp*2+1, 1, CW));
                    kv_iters += (CLAMP(DH1_tmp*2+1, 1, CH) * CLAMP(DW1_tmp*2+1, 1, CW) 
                        + CLAMP(DH2_tmp*2+1, 1, CH) * CLAMP(DW2_tmp*2+1, 1, CW) - repeat) * kvblock_num_per_Tile;
                }
            }
            kv_iters -= 1;
        }

        kittens::wait(qsmem_semaphore, 0);
        
        for (auto kv_idx = 0; kv_idx <= kv_iters; kv_idx++) {

            kittens::wait(k_smem_arrived[(kv_idx)%K::stages], (kv_idx/K::stages)%2);
            warpgroup::mm_ABt(att_block, q_smem[warpgroupid], k_smem[(kv_idx)%K::stages]);
            
            // Scale factors are log2(e) / sqrt(D): 0.125 for D == 64, 0.0883... for D == 128.
            copy(max_vec_last_scaled, max_vec);
            if constexpr (D == 64) { mul(max_vec_last_scaled, max_vec_last_scaled, 1.44269504089f*0.125f); }
            else                   { mul(max_vec_last_scaled, max_vec_last_scaled, 1.44269504089f*0.08838834764f); }
            
            warpgroup::mma_async_wait();

            row_max(max_vec, att_block, max_vec);
            
            if constexpr (D == 64) { 
                mul(att_block, att_block,    1.44269504089f*0.125f); 
                mul(max_vec_scaled, max_vec, 1.44269504089f*0.125f);
            }
            else                   { 
                mul(att_block, att_block,    1.44269504089f*0.08838834764f); 
                mul(max_vec_scaled, max_vec, 1.44269504089f*0.08838834764f);
            }

            sub_row(att_block, att_block, max_vec_scaled);
            exp2(att_block, att_block);
            // Rescale the running normaliser and output by exp2(old_max - new_max).
            sub(max_vec_last_scaled, max_vec_last_scaled, max_vec_scaled);
            exp2(max_vec_last_scaled,       max_vec_last_scaled);
            mul(norm_vec,            norm_vec,     max_vec_last_scaled);
            row_sum(norm_vec,  att_block, norm_vec);
            add(att_block, att_block, 0.f);
            copy(att_block_mma, att_block); 
            mul_row(o_reg, o_reg, max_vec_last_scaled); 

            kittens::wait(v_smem_arrived[(kv_idx)%K::stages], (kv_idx/K::stages)%2); 

            warpgroup::mma_AB(o_reg, att_block_mma, v_smem[(kv_idx)%K::stages]);
            warpgroup::mma_async_wait();

            // Tell the producer this stage's K/V buffers may be overwritten.
            if(warpgroup::laneid() == 0) arrive(compute_done[(kv_idx)%K::stages], 1);
        }

        // the last text_block_num kv blocks are for text, we process them separately
        if constexpr(text_kv) {
            for (auto kv_idx = kv_iters + 1; kv_idx <= kv_iters + text_block_num; kv_idx++) {

                kittens::wait(k_smem_arrived[(kv_idx)%K::stages], (kv_idx/K::stages)%2);
                warpgroup::mm_ABt(att_block, q_smem[warpgroupid], k_smem[(kv_idx)%K::stages]);
                
                copy(max_vec_last_scaled, max_vec);
                if constexpr (D == 64) { mul(max_vec_last_scaled, max_vec_last_scaled, 1.44269504089f*0.125f); }
                else                   { mul(max_vec_last_scaled, max_vec_last_scaled, 1.44269504089f*0.08838834764f); }
                
                warpgroup::mma_async_wait();
                // apply non-pad mask: columns at or beyond the remaining text length become -inf
                int offset = g.text_L - (kv_idx - (kv_iters + 1)) * K::kv_height;
                right_fill(att_block, att_block, offset, base_types::constants<float>::neg_infty());


                row_max(max_vec, att_block, max_vec);
                
                if constexpr (D == 64) { 
                    mul(att_block, att_block,    1.44269504089f*0.125f); 
                    mul(max_vec_scaled, max_vec, 1.44269504089f*0.125f);
                }
                else                   { 
                    mul(att_block, att_block,    1.44269504089f*0.08838834764f); 
                    mul(max_vec_scaled, max_vec, 1.44269504089f*0.08838834764f);
                }

                sub_row(att_block, att_block, max_vec_scaled);
                exp2(att_block, att_block);
                sub(max_vec_last_scaled, max_vec_last_scaled, max_vec_scaled);
                exp2(max_vec_last_scaled,       max_vec_last_scaled);
                mul(norm_vec,            norm_vec,     max_vec_last_scaled);
                row_sum(norm_vec,  att_block, norm_vec);
                add(att_block, att_block, 0.f);
                copy(att_block_mma, att_block); 
                mul_row(o_reg, o_reg, max_vec_last_scaled); 

                kittens::wait(v_smem_arrived[(kv_idx)%K::stages], (kv_idx/K::stages)%2); 

                warpgroup::mma_AB(o_reg, att_block_mma, v_smem[(kv_idx)%K::stages]);
                warpgroup::mma_async_wait();

                if(warpgroup::laneid() == 0) arrive(compute_done[(kv_idx)%K::stages], 1);
            }
        }
        // Normalise the accumulator and stage it in shared memory for the TMA store.
        div_row(o_reg, o_reg, norm_vec);
        warpgroup::store(o_smem[warpgroupid], o_reg); 
        warpgroup::sync(warpgroupid+4);

        
        if (warpid % 4 == 0) {
            coord<o_tile> o_tile_idx = {blockIdx.z, blockIdx.y, (seq_idx) + warpgroupid, 0};
            tma::store_async(g.o, o_smem[warpgroupid], o_tile_idx);
        }

        // L = -sqrt(D) * (ln(norm) + max_scaled * ln 2)
        mul(max_vec_scaled,   max_vec_scaled, 0.69314718056f);
        log(norm_vec, norm_vec);
        add(norm_vec, norm_vec, max_vec_scaled);

        if constexpr (D == 64) { mul(norm_vec, norm_vec, -8.0f); }
        else                   { mul(norm_vec, norm_vec, -11.313708499f); }
        
        warpgroup::store(l_smem[warpgroupid], norm_vec);
        warpgroup::sync(warpgroupid+4);

        if (warpid % 4 == 0) {
            coord<l_col_vec> tile_idx = {blockIdx.z, blockIdx.y, 0, (seq_idx) + warpgroupid};
            tma::store_async(g.l, l_smem[warpgroupid], tile_idx);
        }
        
        tma::store_async_wait();
        
    }
}


torch::Tensor 
ca_fwd(torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor o, int TT, int TH, int TW, int CT, int CH, int CW, torch::Tensor num_window, torch::Tensor window_sizes, int text_length, bool process_text, bool has_text)
{
    CHECK_INPUT(q);
    CHECK_INPUT(k);
    CHECK_INPUT(v);

    auto batch    = q.size(0);
    auto seq_len  = q.size(2); 
    auto head_dim = q.size(3);  
    auto qo_heads = q.size(1);
    auto kv_heads = k.size(1);

    // check to see that these dimensions match for all inputs
    TORCH_CHECK(q.size(0) == batch, "Q batch dimension - idx 0 - must match for all inputs");
    TORCH_CHECK(k.size(0) == batch, "K batch dimension - idx 0 - must match for all inputs");
    TORCH_CHECK(v.size(0) == batch, "V batch dimension - idx 0 - must match for all inputs");

    TORCH_CHECK(q.size(2) == seq_len, "Q sequence length dimension - idx 2 - must match for all inputs");
    TORCH_CHECK(k.size(2) == seq_len, "K sequence length dimension - idx 2 - must match for all inputs");
    TORCH_CHECK(v.size(2) == seq_len, "V sequence length dimension - idx 2 - must match for all inputs");

    TORCH_CHECK(q.size(3) == head_dim, "Q head dimension - idx 3 - must match for all non-vector inputs");
    TORCH_CHECK(k.size(3) == head_dim, "K head dimension - idx 3 - must match for all non-vector inputs");
    TORCH_CHECK(v.size(3) == head_dim, "V head dimension - idx 3 - must match for all non-vector inputs");

    TORCH_CHECK(qo_heads >= kv_heads, "QO heads must be greater than or equal to KV heads");
    TORCH_CHECK(qo_heads % kv_heads == 0, "QO heads must be divisible by KV heads");
    TORCH_CHECK(q.size(1) == qo_heads, "QO head dimension - idx 1 - must match for all inputs");
    TORCH_CHECK(k.size(1) == kv_heads, "KV head dimension - idx 1 - must match for all inputs");
    TORCH_CHECK(v.size(1) == kv_heads, "KV head dimension - idx 1 - must match for all inputs");  

    auto hr = qo_heads / kv_heads;

    c10::BFloat16* q_ptr = q.data_ptr<c10::BFloat16>();
    c10::BFloat16* k_ptr = k.data_ptr<c10::BFloat16>();
    c10::BFloat16* v_ptr = v.data_ptr<c10::BFloat16>();

    bf16*  d_q = reinterpret_cast<bf16*>(q_ptr);
    bf16*  d_k = reinterpret_cast<bf16*>(k_ptr);
    bf16*  d_v = reinterpret_cast<bf16*>(v_ptr);

    torch::Tensor l_vec = torch::empty({static_cast<const uint>(batch), 
                                        static_cast<const uint>(qo_heads), 
                                        static_cast<const uint>(seq_len), 
                                        static_cast<const uint>(1)}, 
                                        torch::TensorOptions().dtype(torch::kFloat).device(q.device()).memory_format(at::MemoryFormat::Contiguous));
        

    bf16*  o_ptr = reinterpret_cast<bf16*>(o.data_ptr<c10::BFloat16>());
    bf16*  d_o   = reinterpret_cast<bf16*>(o_ptr);

    float* l_ptr = reinterpret_cast<float*>(l_vec.data_ptr<float>());
    float* d_l   = reinterpret_cast<float*>(l_ptr);
    
    cudaDeviceSynchronize();
    auto stream = at::cuda::getCurrentCUDAStream().stream(); 

    if (head_dim == 128) {
        using q_tile    =         st_bf<fwd_attend_ker_tile_dims<128>::qo_height, fwd_attend_ker_tile_dims<128>::tile_width>;
        using k_tile    =         st_bf<fwd_attend_ker_tile_dims<128>::kv_height, fwd_attend_ker_tile_dims<128>::tile_width>;
        using v_tile    =         st_bf<fwd_attend_ker_tile_dims<128>::kv_height, fwd_attend_ker_tile_dims<128>::tile_width>;
        using l_col_vec = col_vec<st_fl<fwd_attend_ker_tile_dims<128>::qo_height, fwd_attend_ker_tile_dims<128>::tile_width>>;
        using o_tile    =         st_bf<fwd_attend_ker_tile_dims<128>::qo_height, fwd_attend_ker_tile_dims<128>::tile_width>;

        using q_global = gl<bf16,  -1, -1, -1, -1, q_tile>;
        using k_global = gl<bf16,  -1, -1, -1, -1, k_tile>;
        using v_global = gl<bf16,  -1, -1, -1, -1, v_tile>;
        using l_global = gl<float, -1, -1, -1, -1, l_col_vec>;
        using o_global = gl<bf16,  -1, -1, -1, -1, o_tile>;

        using globals      = fwd_globals<128>;

        q_global qg_arg{d_q, static_cast<unsigned int>(batch), static_cast<unsigned int>(qo_heads), static_cast<unsigned int>(seq_len), 128U};
        k_global kg_arg{d_k, static_cast<unsigned int>(batch), static_cast<unsigned int>(kv_heads), static_cast<unsigned int>(seq_len), 128U};
        v_global vg_arg{d_v, static_cast<unsigned int>(batch), static_cast<unsigned int>(kv_heads), static_cast<unsigned int>(seq_len), 128U};
        l_global lg_arg{d_l, static_cast<unsigned int>(batch), static_cast<unsigned int>(qo_heads), 1U,   static_cast<unsigned int>(seq_len)};
        o_global og_arg{d_o, static_cast<unsigned int>(batch), static_cast<unsigned int>(qo_heads), static_cast<unsigned int>(seq_len), 128U};

        
        int* window_sizes_ptr = reinterpret_cast<int*>(window_sizes.data_ptr<int>());
        int* num_window_ptr = reinterpret_cast<int*>(num_window.data_ptr<int>());

        globals g{qg_arg, kg_arg, vg_arg, lg_arg, og_arg, static_cast<int>(seq_len),  static_cast<int>(text_length), static_cast<int>(hr), 
            static_cast<int>(TT), static_cast<int>(TH), static_cast<int>(TW), 
            static_cast<int>(CT), static_cast<int>(CH), static_cast<int>(CW), 
            num_window_ptr, window_sizes_ptr};

        auto mem_size = kittens::MAX_SHARED_MEMORY;
        auto threads  = NUM_WORKERS * kittens::WARP_THREADS;

        if (has_text) {
            // TORCH_CHECK(seq_len % (CONSUMER_WARPGROUPS*kittens::TILE_DIM*4) == 0, "sequence length must be divisible by 192");
            TORCH_CHECK(seq_len % (CONSUMER_WARPGROUPS*kittens::TILE_ROW_DIM<bf16>*4) == 0, "sequence length must be divisible by 192");// for thread block is launched for every 192 tokens, CONSUMER_WARPGROUPS=3
            dim3 grid_image(seq_len/(CONSUMER_WARPGROUPS*kittens::TILE_ROW_DIM<bf16>*4)-2, qo_heads, batch);
            dim3 grid_text(2, qo_heads, batch);
            if (!process_text) {
                cudaFuncSetAttribute(
                    compact_attn_fwd_ker<128, false, false, true>,
                    cudaFuncAttributeMaxDynamicSharedMemorySize,
                    mem_size
                );
                compact_attn_fwd_ker<128, false, false, true><<<grid_image, (32*NUM_WORKERS), mem_size, stream>>>(g); // launch [<N/192,1,1>, 32*16]
            }
            else {
                cudaFuncSetAttribute(
                    compact_attn_fwd_ker<128, false, true, true>,
                    cudaFuncAttributeMaxDynamicSharedMemorySize,
                    mem_size
                );
                compact_attn_fwd_ker<128, false, true, true><<<grid_text, (32*NUM_WORKERS), mem_size, stream>>>(g);
            }
        } 
        else {
            TORCH_CHECK(seq_len % (CONSUMER_WARPGROUPS*kittens::TILE_ROW_DIM<bf16>*4) == 0, "sequence length must be divisible by 192");// for thread block is launched for every 192 tokens, CONSUMER_WARPGROUPS=3
            dim3 grid_image(seq_len/(CONSUMER_WARPGROUPS*kittens::TILE_ROW_DIM<bf16>*4), qo_heads, batch);
            cudaFuncSetAttribute(
                compact_attn_fwd_ker<128, false, false, false>,
                cudaFuncAttributeMaxDynamicSharedMemorySize,
                mem_size
            );
            compact_attn_fwd_ker<128, false, false, false><<<grid_image, (32*NUM_WORKERS), mem_size, stream>>>(g); // launch [<N/192,1,1>, 32*16]
        }
        
        CHECK_CUDA_ERROR(cudaGetLastError());
        cudaStreamSynchronize(stream);
    }
    else if(head_dim == 64){
        using q_tile    =         st_bf<fwd_attend_ker_tile_dims<64>::qo_height, fwd_attend_ker_tile_dims<64>::tile_width>;
        using k_tile    =         st_bf<fwd_attend_ker_tile_dims<64>::kv_height, fwd_attend_ker_tile_dims<64>::tile_width>;
        using v_tile    =         st_bf<fwd_attend_ker_tile_dims<64>::kv_height, fwd_attend_ker_tile_dims<64>::tile_width>;
        using l_col_vec = col_vec<st_fl<fwd_attend_ker_tile_dims<64>::qo_height, fwd_attend_ker_tile_dims<64>::tile_width>>;
        using o_tile    =         st_bf<fwd_attend_ker_tile_dims<64>::qo_height, fwd_attend_ker_tile_dims<64>::tile_width>;

        using q_global = gl<bf16,  -1, -1, -1, -1, q_tile>;
        using k_global = gl<bf16,  -1, -1, -1, -1, k_tile>;
        using v_global = gl<bf16,  -1, -1, -1, -1, v_tile>;
        using l_global = gl<float, -1, -1, -1, -1, l_col_vec>;
        using o_global = gl<bf16,  -1, -1, -1, -1, o_tile>;

        using globals      = fwd_globals<64>;

        q_global qg_arg{d_q, static_cast<unsigned int>(batch), static_cast<unsigned int>(qo_heads), static_cast<unsigned int>(seq_len), 64};
        k_global kg_arg{d_k, static_cast<unsigned int>(batch), static_cast<unsigned int>(kv_heads), static_cast<unsigned int>(seq_len), 64};
        v_global vg_arg{d_v, static_cast<unsigned int>(batch), static_cast<unsigned int>(kv_heads), static_cast<unsigned int>(seq_len), 64};
        l_global lg_arg{d_l, static_cast<unsigned int>(batch), static_cast<unsigned int>(qo_heads), 1U,   static_cast<unsigned int>(seq_len)};
        o_global og_arg{d_o, static_cast<unsigned int>(batch), static_cast<unsigned int>(qo_heads), static_cast<unsigned int>(seq_len), 64};

        
        int* window_sizes_ptr = reinterpret_cast<int*>(window_sizes.data_ptr<int>());
        int* num_window_ptr = reinterpret_cast<int*>(num_window.data_ptr<int>());

        globals g{qg_arg, kg_arg, vg_arg, lg_arg, og_arg, static_cast<int>(seq_len),  static_cast<int>(text_length), static_cast<int>(hr), 
            static_cast<int>(TT), static_cast<int>(TH), static_cast<int>(TW), 
            static_cast<int>(CT), static_cast<int>(CH), static_cast<int>(CW), 
            num_window_ptr, window_sizes_ptr};

        auto mem_size = kittens::MAX_SHARED_MEMORY;
        auto threads  = NUM_WORKERS * kittens::WARP_THREADS;

        if (has_text) {
            // TORCH_CHECK(seq_len % (CONSUMER_WARPGROUPS*kittens::TILE_DIM*4) == 0, "sequence length must be divisible by 192");
            TORCH_CHECK(seq_len % (CONSUMER_WARPGROUPS*kittens::TILE_ROW_DIM<bf16>*4) == 0, "sequence length must be divisible by 192");// for thread block is launched for every 192 tokens, CONSUMER_WARPGROUPS=3
            dim3 grid_image(seq_len/(CONSUMER_WARPGROUPS*kittens::TILE_ROW_DIM<bf16>*4)-2, qo_heads, batch);
            dim3 grid_text(2, qo_heads, batch);
            if (!process_text) {
                cudaFuncSetAttribute(
                    compact_attn_fwd_ker<64, false, false, true>,
                    cudaFuncAttributeMaxDynamicSharedMemorySize,
                    mem_size
                );
                compact_attn_fwd_ker<64, false, false, true><<<grid_image, (32*NUM_WORKERS), mem_size, stream>>>(g); // launch [<N/192,1,1>, 32*16]
            }
            else {
                cudaFuncSetAttribute(
                    compact_attn_fwd_ker<64, false, true, true>,
                    cudaFuncAttributeMaxDynamicSharedMemorySize,
                    mem_size
                );
                compact_attn_fwd_ker<64, false, true, true><<<grid_text, (32*NUM_WORKERS), mem_size, stream>>>(g);
            }
        } 
        else {
            TORCH_CHECK(seq_len % (CONSUMER_WARPGROUPS*kittens::TILE_ROW_DIM<bf16>*4) == 0, "sequence length must be divisible by 192");// for thread block is launched for every 192 tokens, CONSUMER_WARPGROUPS=3
            dim3 grid_image(seq_len/(CONSUMER_WARPGROUPS*kittens::TILE_ROW_DIM<bf16>*4), qo_heads, batch);
            cudaFuncSetAttribute(
                compact_attn_fwd_ker<64, false, false, false>,
                cudaFuncAttributeMaxDynamicSharedMemorySize,
                mem_size
            );
            compact_attn_fwd_ker<64, false, false, false><<<grid_image, (32*NUM_WORKERS), mem_size, stream>>>(g); // launch [<N/192,1,1>, 32*16]
        }
        
        CHECK_CUDA_ERROR(cudaGetLastError());
        cudaStreamSynchronize(stream);
    }

    return o;
    cudaDeviceSynchronize();
}
