#include <cutlass/numeric_types.h>
#include <cutlass/bfloat16.h>

#include <cuda_runtime.h>
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cuda/atomic>
#include <cuda_fp16.h>

#include <cute/tensor.hpp>
#include <cute/arch/mma_sm80.hpp>
#include <cute/arch/copy_sm80.hpp>
#include <cute/algorithm/gemm.hpp>

#include "xla/ffi/api/ffi.h"

#include <math.h>

using namespace cute;
namespace cg = cooperative_groups;
using bf16 = cutlass::bfloat16_t;
using fp16 = cutlass::half_t;

// Maps an IEEE-754 float to a uint32_t whose unsigned ordering matches the float ordering
// (for non-NaN inputs), so float maxima can be taken with integer max / __reduce_max_sync.
// Non-negative values get their sign bit set; negative values have every bit inverted.
__device__ __host__ inline constexpr uint32_t float_to_ordered_uint(float f) {
    const uint32_t bits = reinterpret_cast<const uint32_t&>(f);
    const bool is_negative = (bits >> 31) != 0u;
    const uint32_t flip = is_negative ? 0xFFFFFFFFu : 0x80000000u;
    return bits ^ flip;
}

// Packs an fp16 value and a 16-bit tag into one uint32_t:
//   high 16 bits: the fp16 bits passed through an order-preserving flip
//                 (sign bit set for non-negative values, all bits inverted for negative ones),
//   low 16 bits:  `tag`, unchanged.
// An unsigned max over such words therefore picks the largest value, and among equal values
// the largest tag.
__device__ __host__ inline constexpr uint32_t fp16_to_tagged_ordered_uint(fp16 f, uint16_t tag) {
    const uint16_t bits = reinterpret_cast<const uint16_t&>(f);
    const uint16_t flip = (bits & 0x8000u) ? uint16_t(0xFFFFu) : uint16_t(0x8000u);
    const uint32_t ordered_hi = uint32_t(uint16_t(bits ^ flip));
    return (ordered_hi << 16) | uint32_t(tag);
}

// Returns the 16-bit tag stored in the low half of a word built by
// fp16_to_tagged_ordered_uint (the value half is discarded).
__device__ __host__ inline constexpr uint16_t extract_tag_from_ordered_uint(uint32_t u) {
    return static_cast<uint16_t>(u);
}

// Single-block prototype of one k-means assignment/accumulation step in bf16.
//
// 1. cp.async X (M x K) and the centroids K_ (N x K) into static shared memory.
// 2. S = X * K_^T on SM80 tensor cores (bf16 in, f32 accumulate); S goes to smem and S_ptr.
// 3. Per row of S, argmax over the N columns -> lab_ptr (ties: lowest warp lane holding the max).
// 4. Per centroid, sum the rows of X assigned to it; the sums are stored back into sK and
//    written to S_ptr.
//
// Launch contract (inferred from the partitioning below; confirm with the caller):
//   * exactly one block of 128 threads (4 warps): copy tiling is 32x4 threads, the MMA
//     tiling is 2x2 warps and the row/centroid partitions split work over _4 warps;
//   * static shared memory only (dynamic shared bytes == 0);
//   * SM80+ (cp.async, bf16 mma.sync, __reduce_max_sync).
// NOTE(review): labels are computed as lid + 32*col, which presumes N == 64 (two 32-wide
//   column groups per lane) — TODO confirm for other template values.
// NOTE(review): step 4 OVERWRITES the scores previously written to S_ptr with the centroid
//   sums (laid out like K_). This only lines up when M == N and looks like debug output —
//   confirm which result is intended.
template <int M = 64, int N = 64, int K = 64>
__global__ void adjust_kernel(
    const bf16* X_ptr,
    const bf16* K_ptr,
    float* S_ptr,
    int32_t* lab_ptr
) {
    using LayoutX = Layout<Shape<Int<M>, Int<K>>, Stride<Int<K>, Int<1>>>;
    using LayoutK = Layout<Shape<Int<N>, Int<K>>, Stride<Int<K>, Int<1>>>;
    using LayoutS = Layout<Shape<Int<M>, Int<N>>, Stride<Int<N>, Int<1>>>;

    cg::thread_block cta = cg::this_thread_block();
    const int tid = cta.thread_rank();

    struct SharedStorage{
        alignas(128) bf16 X[cosize_v<LayoutX>];
        alignas(128) bf16 K_[cosize_v<LayoutK>];
        alignas(128) float S[cosize_v<LayoutS>];
    };
    __shared__ SharedStorage shared;

    auto gX = make_tensor(make_gmem_ptr(X_ptr), LayoutX{});
    auto gK = make_tensor(make_gmem_ptr(K_ptr), LayoutK{});
    auto gS = make_tensor(make_gmem_ptr(S_ptr), LayoutS{});
    Tensor gL = make_tensor(make_gmem_ptr(lab_ptr), Layout<Shape<Int<M>>>());
    auto sX = make_tensor(make_smem_ptr(shared.X), LayoutX{});
    auto sK = make_tensor(make_smem_ptr(shared.K_), LayoutK{});
    auto sS = make_tensor(make_smem_ptr(shared.S), LayoutS{});

    // 16-byte cp.async copies: 32x4 threads, 8 bf16 per thread per copy.
    auto copy_atom = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, bf16>{};
    auto tiled_copy = make_tiled_copy(copy_atom, Layout<Shape<_32, _4>>{}, Layout<Shape<_1, _8>>{});
    auto thr_copy = tiled_copy.get_thread_slice(tid);
    copy(tiled_copy, thr_copy.partition_S(gX), thr_copy.partition_D(sX));
    copy(tiled_copy, thr_copy.partition_S(gK), thr_copy.partition_D(sK));

    cp_async_fence();
    cp_async_wait<0>();
    __syncthreads();

    auto tiled_mma = make_tiled_mma(
        SM80_16x8x16_F32BF16BF16F32_TN{},
        Layout<Shape<_2,_2,_1>>{}
    );
    auto thr_mma = tiled_mma.get_thread_slice(tid);

    auto tCsX = thr_mma.partition_A(sX);
    auto tCsK = thr_mma.partition_B(sK);
    auto tCsS = thr_mma.partition_C(sS);

    auto tCrS = thr_mma.make_fragment_C(tCsS);

    clear(tCrS);
    gemm(tiled_mma, tCsX, tCsK, tCrS);
    copy(tCrS, tCsS);
    auto tCgS = thr_mma.partition_C(gS);
    copy(tCrS, tCgS);

    __syncthreads(); // sS is complete before warps read tiles written by other warps

    auto warp = cg::tiled_partition<32>(cta);
    int wid = warp.meta_group_rank();
    int lid = warp.thread_rank();

    // Warp wid owns rows wid + 4*r of S; lane lid owns columns lid + 32*c.
    Tensor tRsS = zipped_divide(sS, Shape<_4,_32>{})(make_coord(wid, lid), make_coord(_,_));
    Tensor wRgL = local_partition(gL, Layout<Shape<_4>, Stride<_1>>{}, wid);
    Tensor tRrS = make_tensor_like(tRsS);

    copy(tRsS, tRrS);
    for (int row = 0; row < size<0>(tRrS); row++) {
        float max_val = tRrS(row,0);
        int local_max_index = 0;
        for (int col = 1; col < size<1>(tRrS); col++) {
            if (tRrS(row,col) > max_val) {
                max_val = tRrS(row,col);
                local_max_index = col;
            }
        }
        uint32_t max_val_u = float_to_ordered_uint(max_val);
        uint32_t warp_max_u = __reduce_max_sync(0xFFFFFFFF, max_val_u);
        // Elect the lowest lane holding the warp max with an explicit full-warp ballot.
        // cg::coalesced_threads() inside a data-dependent branch only reflects the lanes that
        // happen to be converged, which independent thread scheduling does not guarantee.
        const uint32_t holders = __ballot_sync(0xFFFFFFFF, max_val_u == warp_max_u);
        if (lid == __ffs(holders) - 1) {
            wRgL(row) = lid + local_max_index * 32; // label = column index of the max
        }
    }

    // Labels were written to global memory by the warp owning each row and are read back
    // below by other warps; the barrier makes those writes visible block-wide.
    __syncthreads();

    // Warp wid owns centroids wid + 4*r; lane lid owns feature columns 2*lid + {0,1}.
    Tensor bAsK = logical_divide(sK, Tile<Layout<_4>, Layout<_32,_2>>{}); //(wid, rows),(lid, pair)
    Tensor tAsK = bAsK(make_coord(wid, _), make_coord(lid, _));
    // Compact register accumulators. Building them from tAsK.layout() would size the owning
    // array by the cosize of the strided smem layout (thousands of floats -> local memory).
    Tensor tArK = make_tensor<float>(shape(tAsK));
    // NOTE(review): per-centroid counts are accumulated but not used below (sums, not means).
    Tensor tArK_count = make_tensor<int>(get<0>(tArK.layout()).shape());
    clear(tArK);
    clear(tArK_count);

    // Lane lid holds the labels of X rows lid + 32*rest.
    Tensor bAgL = zipped_divide(gL, Shape<_32>{}); //(lid), (rest)
    Tensor tAgL = bAgL(lid, _);
    Tensor tArL = make_tensor_like(tAgL);
    copy(tAgL, tArL);

    Tensor bAsX = logical_divide(sX, Tile<Layout<_1>, Layout<_32,_2>>{}); //(row), (lid, pair)
    Tensor tAsX = bAsX(_, make_coord(lid, _));

    // Broadcast every label across the warp; the warp owning that centroid (label % 4)
    // adds the corresponding X row into its accumulator row label / 4.
    for (int active_lid = 0; active_lid < 32; active_lid++) {
        for (int rest = 0; rest < size(tArL); rest++) {
            int label = __shfl_sync(0xFFFFFFFF, tArL(rest), active_lid);
            if ((label % 4) == wid) {
                int x_index = active_lid + rest * 32;
                int current_row = label / 4;
                tArK_count(current_row) += 1;
                Tensor chosen_tArK = tArK(current_row, _);
                for (int i = 0; i < size(chosen_tArK); i++) {
                    chosen_tArK(i) += float(tAsX(x_index, i));
                }
            }
        }
    }

    for (int i = 0; i < size(tArK); i++) {
        tAsK(i) = bf16(tArK(i));
    }
    Tensor tAgS = logical_divide(gS, Tile<Layout<_4>, Layout<_32,_2>>{})(make_coord(wid, _), make_coord(lid, _));
    copy(tArK, tAgS);
}

// Online (mini-batch) k-means step in fp16, one thread block per outer-batch entry.
//
// For `iters` passes over the N mini-batches (B points of dimension D each) of this block's
// outer-batch entry (blockIdx.x):
//   1. cp.async the B x D mini-batch of X into 128B-swizzled shared memory;
//   2. form the centroid means Ctot/Ccnt on the fly and score every point against every
//      centroid on tensor cores: score = x.c - 0.5*|c|^2 (argmax == nearest centroid);
//   3. take the per-point argmax via a tagged ordered-uint max (quad shuffle reduce, then a
//      shared-memory atomicMax across warps) and write the winning index to lab;
//   4. fold each point into its centroid: Ctot = beta*Ctot + x, Ccnt = beta*Ccnt + 1.
// The final Ctot/Ccnt go to Ctot_out/Ccnt_out (Ccnt broadcast across all D columns).
// Centroids live only in registers, laid out with the 128B swizzle so that xor-shuffles
// assemble tensor-core B fragments directly.
//
// Launch contract: gridDim.x == outer batch size, blockDim.x == 32*WARPS, no dynamic smem,
// SM80+ (cp.async, mma.sync m16n8k16 f16).
// Preconditions: designed for D == 64, B % 16 == 0 and K % (8*WARPS) == 0. The static
// asserts below currently also require K == 8*WARPS (one n-tile per warp).
// NOTE(review): counts must stay > 0. A zero count gives inf/NaN means, and a NaN score maps
//   above +inf in the ordered encoding, so that centroid would win every argmax.
// NOTE(review): S (grid_gS) is not written (the fragment store is disabled); the S output
//   buffer is left untouched — confirm this is intended.
template <int B, int D, int K, int WARPS>
__global__ void __launch_bounds__(WARPS*32,1) adjust_fp16_kernel(
    Tensor<ViewEngine<const fp16*>, Layout<Shape<int, int, Int<B>, Int<D>>, Stride<int, Int<B*D>, Int<D>, Int<1>>>> grid_gX,
    Tensor<ViewEngine<const fp16*>, Layout<Shape<int, Int<K>, Int<D>>, Stride<Int<K*D>, Int<D>, Int<1>>>> grid_gCtot,
    Tensor<ViewEngine<const fp16*>, Layout<Shape<int, Int<K>, Int<D>>, Stride<Int<K*D>, Int<D>, Int<1>>>> grid_gCcnt,
    Tensor<ViewEngine<fp16*>, Layout<Shape<int, int, Int<B>, Int<K>>, Stride<int, Int<B*K>, Int<K>, Int<1>>>> grid_gS,
    Tensor<ViewEngine<int32_t*>, Layout<Shape<int, int, Int<B>>, Stride<int, Int<B>, Int<1>>>> grid_gL,
    Tensor<ViewEngine<fp16*>, Layout<Shape<int, Int<K>, Int<D>>, Stride<Int<K*D>, Int<D>, Int<1>>>> grid_gCtot_out,
    Tensor<ViewEngine<fp16*>, Layout<Shape<int, Int<K>, Int<D>>, Stride<Int<K*D>, Int<D>, Int<1>>>> grid_gCcnt_out,
    fp16 beta,
    int iters
) {
    // Select the outer batch entry for this block.
    const int block_index = blockIdx.x;
    Tensor gX = grid_gX(block_index,_,_,_);
    Tensor gCtot = grid_gCtot(block_index,_,_);
    Tensor gCcnt = grid_gCcnt(block_index,_,_);
    Tensor gS = grid_gS(block_index,_,_,_);
    Tensor gL = grid_gL(block_index,_,_);
    Tensor gCtot_out = grid_gCtot_out(block_index,_,_);
    Tensor gCcnt_out = grid_gCcnt_out(block_index,_,_);

    using Swizzle128B_fp16 = Swizzle<3,3,3>;
    using Swizzle128B_half2 = Swizzle<3,2,3>;

    const int N = size<0>(gX);
    struct SharedStorage{
        alignas(128) fp16 X[cosize_v<decltype(gX(0,_,_).layout())>];
        alignas(128) uint32_t tagged_max[B];
    };
    __shared__ SharedStorage smem;
    Tensor sX = make_tensor(make_smem_ptr(smem.X), composition(Swizzle128B_fp16{}, gX(0,_,_).layout()));
    // One tagged (score, centroid id) max per point of the current mini-batch.
    Tensor sT = make_tensor(make_smem_ptr(smem.tagged_max), Layout<Shape<Int<B>>>{});

    cg::thread_block cta = cg::this_thread_block();
    const int tid = cta.thread_rank();
    const int num_threads = cta.size();
    auto warp = cg::tiled_partition<32>(cta);
    const int wid = warp.meta_group_rank();
    const int lid = warp.thread_rank();
    auto quad = cg::tiled_partition<4>(warp);
    const int qid = quad.meta_group_rank();
    const int qlid = quad.thread_rank();

    constexpr int copy_warp_count = (8 < WARPS) ? 8 : WARPS;
    auto blk_copy = make_tiled_copy(
        Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, fp16>{},
        Layout<Shape<_32,Int<copy_warp_count>>>{},
        Layout<Shape<_1,_8>>{}
    );
    auto thr_copy = blk_copy.get_thread_slice(tid);

    using MMA_TYPE = SM80_16x8x16_F16F16F16F16_TN;
    auto atom_mma = MMA_Atom<MMA_TYPE>{};
    auto blk_mma = make_tiled_mma(
        MMA_TYPE{},
        Layout<Shape<_1,Int<WARPS>,_1>>{}
    );
    auto thr_mma = blk_mma.get_thread_slice(tid);

    CUTE_STATIC_ASSERT(D % 2 == 0, "D must be even for fp16 half2");
    auto sX2_layout = Layout<Shape<Int<B>, Int<D/2>>, Stride<Int<D/2>, Int<1>>>{};
    auto sX2_sw_layout = composition(Swizzle128B_half2{}, sX2_layout);
    Tensor sX2_sw = make_tensor(make_smem_ptr(reinterpret_cast<__half2*>(smem.X)), sX2_sw_layout);
    auto gC2_layout = Layout<Shape<Int<K>, Int<D/2>>, Stride<Int<D/2>, _1>>{};
    auto gC2_sw_layout = composition(Swizzle128B_half2{}, gC2_layout);
    Tensor gCtot2 = make_tensor(make_gmem_ptr(reinterpret_cast<const __half2*>(gCtot.data())), gC2_layout);
    Tensor gCtot2_sw = make_tensor(make_gmem_ptr(reinterpret_cast<const __half2*>(gCtot.data())), gC2_sw_layout);
    Tensor gCtot_out2_sw = make_tensor(make_gmem_ptr(reinterpret_cast<__half2*>(gCtot_out.data())), gC2_sw_layout);

    using RTile = Tile<Layout<Shape<_8,Int<WARPS>>>, Layout<_4>>;

    // For efficient warp shuffle assembly of tensor core fragments we need the register
    // layout of the centroid accumulators to be arranged with the standard 128B swizzle
    // normally used in smem for avoiding bank conflicts.
    // Thread (wid, qid, qlid) holds centroid rows r + 8*wid (+ 8*WARPS*n), half2 column qlid + 4*qid.
    auto full_thread_R_coord = make_coord(make_coord(make_coord(_,wid),qlid), make_coord(_,qid));
    Tensor tRgCtot2_sw = zipped_divide(gCtot2_sw, RTile{})(full_thread_R_coord);
    Tensor tRgCtot_out2_sw = zipped_divide(gCtot_out2_sw, RTile{})(full_thread_R_coord);

    Tensor tRrC2 = make_tensor_like(tRgCtot2_sw);
    copy(tRgCtot2_sw, tRrC2);

    // Maps a centroid index to (row within 8, owning warp, n-tile).
    constexpr auto row_layout_R = zipped_divide(layout<0>(gCtot2), get<0>(RTile{}));
    constexpr auto centroid_to_row_warp_n = make_identity_tensor(flatten(shape(row_layout_R)));

    // Quad qid of warp wid holds the count of centroid qid + 8*wid (+ 8*WARPS*n).
    Tensor qCgCcnt = outer_partition(gCcnt, Tile<Layout<Shape<_8,Int<WARPS>>>, Layout<_1>>{}, make_coord(make_coord(qid, wid), 0));
    Tensor qCgCcnt_out = outer_partition(gCcnt_out, Tile<Layout<Shape<_8,Int<WARPS>>>, Layout<_1>>{}, make_coord(make_coord(qid, wid), 0));
    Tensor qCcnt = make_tensor_like(qCgCcnt(_,0));
    copy(qCgCcnt(_,0), qCcnt);
    Tensor qCinvcnt = make_tensor_like(qCcnt);
    // sq norms of centroids for converting dot products to distances
    Tensor Csqmag = make_tensor<fp16>(Shape<_2,Int<size<1>(tRrC2)>>{});

    CUTE_STATIC_ASSERT(size<1>(tRrC2) == 1, "for now require n=1 in tRrC2");
    CUTE_STATIC_ASSERT(size<0>(tRrC2) == 8, "should be 8 rows in tRrC2");

    for (int iter = 0; iter < iters; iter++) for (int minibatch = 0; minibatch < N; minibatch++) {
        __syncthreads(); // the previous mini-batch is done reading sX / sT
        if (wid < copy_warp_count) {
            copy(blk_copy, thr_copy.partition_S(gX(minibatch,_,_)), thr_copy.partition_D(sX));
            cp_async_fence();
        }
        for (int i = tid; i < size(sT); i += num_threads) {
            sT(i) = 0;
        }
        // |mean centroid|^2, stored by the quad lane whose accumulator column holds that centroid.
        for (int row = 0; row < size<0>(tRrC2); row++) {
            for (int n = 0; n < size<1>(tRrC2); n++) {
                const __half count = __shfl_sync(0xFFFFFFFF, qCcnt(n), 4*row + qlid);
                const __half2 val2 = tRrC2(row,n) / __half2(count,count);
                const __half2 val2_sq = __hmul2(val2, val2);
                const fp16 val = fp16(val2_sq.x + val2_sq.y);
                const fp16 row_sum = cg::reduce(warp, val, cg::plus<fp16>());
                const int accum_col_idx = (row + n*8)/2; // accumulator of mma is arranged in fp16 pairs along rows
                if (qlid == accum_col_idx%4) { // does this quad lane own the relevant column of the accumulator?
                    Csqmag(row%2, n) = row_sum;
                }
            }
        }
        for (int n = 0; n < size(qCcnt); n++) {
            qCinvcnt(n) = fp16(1.0f) / qCcnt(n);
        }
        cp_async_wait<0>();
        __syncthreads(); // make sure sX and sT are ready

        Tensor tCgS = thr_mma.partition_C(gS(minibatch,_,_));
        Tensor cS = make_identity_tensor(shape(gS(minibatch,_,_)));
        Tensor tCcS = thr_mma.partition_C(cS);
        Tensor tCrS = thr_mma.make_fragment_C(tCgS);
        Tensor tCsX = thr_mma.partition_A(sX);
        Tensor tCrX = thr_mma.make_fragment_A(tCsX);
        Tensor tCsC = thr_mma.partition_B(gCtot);
        clear(tCrS);
        copy(tCsX, tCrX);
        for (int m = 0; m < size<1>(tCsX); m++) {
            for (int n = 0; n < size<1>(tCsC); n++) {
                const __half2 inv_cnt_n2 = __half2(__half(qCinvcnt(n)), __half(qCinvcnt(n)));
                for (int k = 0; k < size<2>(tCsC); k++) {
                    Tensor tCrC_nk = make_tensor_like(tCsC(_,n,k));
                    Tensor tCrC2_nk = make_tensor(reinterpret_cast<__half2*>(tCrC_nk.data()), Layout<Shape<_2>>{});
                    // Here the xor shuffles assemble tensor core tiles from the swizzled centroid data
                    tCrC2_nk(0) = __shfl_xor_sync(0xFFFFFFFF, tRrC2(qid^(2*k),n), (2*k)<<2) * inv_cnt_n2;
                    tCrC2_nk(1) = __shfl_xor_sync(0xFFFFFFFF, tRrC2(qid^(2*k+1),n), (2*k+1)<<2) * inv_cnt_n2;
                    gemm(atom_mma, tCrX(_,m,k), tCrC_nk, tCrS(_,m,n));
                }
            }
        }
        // View the accumulator as (row, col) = ((row half, m), (col pair, n)).
        auto row_col_tCrS_layout = make_layout(
            make_layout(layout<0,1>(tCrS.layout()), layout<1>(tCrS.layout())), //row
            make_layout(layout<0,0>(tCrS.layout()), layout<2>(tCrS.layout()))  //col
        );
        Tensor row_col_tCrS = make_tensor(tCrS.data(), row_col_tCrS_layout);

        for (int row_base = 0; row_base < size<0>(row_col_tCrS); row_base += 4) {
            uint32_t quad_lane_max_u = 0;
            int quad_lane_x_id = 0;
            for (int row = row_base; row < row_base + 4; row++) {
                if (row >= size<0>(row_col_tCrS)) continue;
                uint32_t local_max_u = 0;
                for (int col = 0; col < size<1>(row_col_tCrS); col++) {
                    const uint16_t cluster_id = get<1>(tCcS(row_col_tCrS_layout(row, col)));
                    const fp16 val = row_col_tCrS(row, col) - fp16(0.5f) * fp16(Csqmag(col));
                    // transform fp16 to uint16 with an order preserving function then
                    // pack a 16bit cluster_id in the least significant bits of a uint32
                    // this merges the argmax with the max and automatically does deterministic tie-breaks
                    const uint32_t val_u = fp16_to_tagged_ordered_uint(val, cluster_id);
                    local_max_u = umax(local_max_u, val_u);
                }
                uint32_t quad_max_u = cg::reduce(quad, local_max_u, cg::greater<uint32_t>());
                int x_id = get<0>(tCcS(row_col_tCrS_layout(row,0)));
                if (row % 4 == quad.thread_rank()) {
                    // each lane in the quad gets a different row
                    quad_lane_max_u = quad_max_u;
                    quad_lane_x_id = x_id;
                }
            }
            if (quad.thread_rank() < size<0>(row_col_tCrS) - row_base) {
                atomicMax(&sT(quad_lane_x_id), quad_lane_max_u); // parallel disjoint atomicMax
            }
        }

        __syncthreads(); // all atomicMax updates to sT are complete

        // Publish the winning centroid index of every point in this mini-batch.
        for (int i = tid; i < size(sT); i += num_threads) {
            gL(minibatch, i) = int(extract_tag_from_ordered_uint(sT(i)));
        }

        // Lane lid handles points lid + 32*rest.
        Tensor bAsT = zipped_divide(sT, Shape<_32>{}); //(lid), (rest)
        Tensor tAsT = bAsT(lid, _);
        Tensor tArT = make_tensor_like(tAsT);
        copy(tAsT, tArT);

        const __half2 beta2 = __half2(__half(beta), __half(beta));
        for (int rest = 0; rest < size(tArT); rest++) {
            const int val = extract_tag_from_ordered_uint(tArT(rest));
            auto [val_row, val_warp, val_n] = centroid_to_row_warp_n(val);
            // Lanes whose point belongs to a centroid owned by this warp.
            uint32_t active_lanes = __ballot_sync(0xFFFFFFFF, val_warp == wid);
            while (active_lanes != 0) {
                const int active_lid = __ffs(active_lanes) - 1;
                int label = __shfl_sync(0xFFFFFFFF, val, active_lid); // broadcast because if we own this label all threads must update
                auto [label_row, label_warp, label_n] = centroid_to_row_warp_n(label);
                const int x_index = bAsT.layout()(active_lid, rest);
                // Read rows from sX in Swizzle128B layout
                const __half2 x_val2 = sX2_sw(x_index, 4*(qid^label_row) + qlid);
                tRrC2(label_row, label_n) = beta2 * tRrC2(label_row, label_n) + x_val2;
                if (label_row == qid) {
                    qCcnt(label_n) = beta * qCcnt(label_n) + fp16(1.0f);
                }
                active_lanes &= active_lanes - 1; // clear lowest set bit
            }
        }
    } // end for minibatch

    copy(tRrC2, tRgCtot_out2_sw);
    // Each quad lane writes every 4th column, so the count is broadcast across all D columns.
    CUTE_STATIC_ASSERT(D % 4 == 0, "D must be multiple of 4 for this write loop");
    for (int n = 0; n < size(qCcnt); n++) {
        for (int offset = 0; offset < D/4; offset++) {
            qCgCcnt_out(n,offset*4 + qlid) = qCcnt(n);
        }
    }
}



using namespace xla::ffi;

// XLA FFI host entry for adjust_kernel<64,64,64>.
//
// Enqueues one block of 4 warps (static shared memory only) on `stream`. The kernel indexes
// fixed 64x64 tiles without bounds checks, so the buffers are validated by element count:
//   X: 64*64 bf16, K: 64*64 bf16, S: 64*64 f32, lab: 64 s32.
// Returns InvalidArgument on a size mismatch and Internal if the launch reports an error
// (asynchronous execution errors surface later on the stream).
Error adjust_host(
    cudaStream_t stream,
    Buffer<BF16> X_buf,
    Buffer<BF16> K_buf,
    ResultBuffer<F32> S_buf,
    ResultBuffer<S32> lab_buf
) {
    constexpr int M = 64;
    constexpr int N = 64;
    constexpr int K = 64;
    const int num_warps = 4;

    // Product of a buffer's dimensions (1 for a scalar).
    auto element_count = [](const auto& dims) {
        int64_t count = 1;
        for (size_t i = 0; i < dims.size(); ++i) count *= dims[i];
        return count;
    };
    if (element_count(X_buf.dimensions()) != int64_t(M) * K)
        return Error::InvalidArgument("X must have M*K (64*64) elements");
    if (element_count(K_buf.dimensions()) != int64_t(N) * K)
        return Error::InvalidArgument("K must have N*K (64*64) elements");
    if (element_count(S_buf->dimensions()) != int64_t(M) * N)
        return Error::InvalidArgument("S must have M*N (64*64) elements");
    if (element_count(lab_buf->dimensions()) != int64_t(M))
        return Error::InvalidArgument("lab must have M (64) elements");

    const size_t shared_bytes = 0; // the kernel uses static __shared__ storage only
    dim3 grid(1);
    dim3 block(32 * num_warps);
    adjust_kernel<M,N,K><<<grid, block, shared_bytes, stream>>>(
        reinterpret_cast<const bf16*>(X_buf.typed_data()),
        reinterpret_cast<const bf16*>(K_buf.typed_data()),
        S_buf->typed_data(),
        lab_buf->typed_data()
    );
    cudaError_t last_err = cudaGetLastError();
    if (last_err != cudaSuccess) return Error::Internal(std::string("CUDA error: ") + cudaGetErrorString(last_err));
    return Error::Success();
}

// Defines the XLA FFI handler symbol `adjust`, which forwards to adjust_host. Bindings, in
// adjust_host's parameter order: the CUDA stream from the execution context, two bf16
// argument buffers (X, then K), and two result buffers (S as f32, lab as s32).
// NOTE(review): kCmdBufferCompatible declares the handler safe to record into XLA command
// buffers (CUDA graphs); adjust_host only enqueues a kernel on `stream`, which fits that.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    adjust,
    adjust_host,
    Ffi::Bind()
        .Ctx<PlatformStream<cudaStream_t>>()
        .Arg<Buffer<BF16>>()  // X_buf
        .Arg<Buffer<BF16>>()  // K_buf
        .Ret<Buffer<F32>>()   // S_buf
        .Ret<Buffer<S32>>(),  // lab_buf
    {Traits::kCmdBufferCompatible}
);


// XLA FFI host launcher for adjust_fp16_kernel<B=64, D=64, K, num_warps>.
//
// Buffers (leading "batch" dims are flattened into the outer batch size O; one block per entry):
//   X:                (batch..., N*B, D) fp16
//   Ctot/Ccnt (+_out): O*K*D fp16 each (the kernel addresses them as (O, K, D))
//   S:                (batch..., N, B, K) fp16
//   lab:              O*N*B int32
// The kernel has no bounds checks, so every buffer is validated before launch. An empty
// outer batch is a no-op, because a zero-sized grid is a launch error.
// Returns InvalidArgument on a shape mismatch and Internal if the launch reports an error.
template <int K, int num_warps>
Error adjust_fp16_host(
    cudaStream_t stream,
    Buffer<F16> X_buf,
    Buffer<F16> Ctot_buf,
    Buffer<F16> Ccnt_buf,
    ResultBuffer<F16> S_buf,
    ResultBuffer<S32> lab_buf,
    ResultBuffer<F16> Ctot_out_buf,
    ResultBuffer<F16> Ccnt_out_buf,
    float beta,
    int iters
) {
    constexpr int B = 64;
    constexpr int D = 64;

    // Product of a buffer's dimensions (1 for a scalar).
    auto element_count = [](const auto& dims) {
        int64_t count = 1;
        for (size_t i = 0; i < dims.size(); ++i) count *= dims[i];
        return count;
    };

    size_t outer_batch_size = 1;
    const int num_batch_dims = X_buf.dimensions().size() - 2;
    if (num_batch_dims < 0) return Error::InvalidArgument("X must be rank 2 plus batch dims");
    for (int i = 0; i < num_batch_dims; i++) outer_batch_size *= X_buf.dimensions()[i];
    const int O = outer_batch_size;
    // The kernel layouts hard-code the feature dimension.
    if (X_buf.dimensions()[num_batch_dims + 1] != D) return Error::InvalidArgument("X last dimension must be D (64)");
    //check N matches
    const int total_N = X_buf.dimensions()[num_batch_dims];
    if (total_N % B != 0) return Error::InvalidArgument("X first dimension must be multiple of B");
    const int N = total_N / B;
    if (S_buf->dimensions().size() != 3+num_batch_dims) return Error::InvalidArgument("S must be rank 3 plus batch dims");
    if (S_buf->dimensions()[num_batch_dims] != N) return Error::InvalidArgument("S first dimension must match X");
    int lab_buf_elems = 1;
    for (int i = num_batch_dims; i < lab_buf->dimensions().size(); i++) lab_buf_elems *= lab_buf->dimensions()[i];
    if (lab_buf_elems != N * B) return Error::InvalidArgument("lab must have N*B elements");

    // Whole-buffer sizes: every buffer is addressed through fixed (O, ...) layouts.
    const int64_t c_elems = int64_t(O) * K * D;
    if (element_count(Ctot_buf.dimensions()) != c_elems) return Error::InvalidArgument("Ctot must have O*K*D elements");
    if (element_count(Ccnt_buf.dimensions()) != c_elems) return Error::InvalidArgument("Ccnt must have O*K*D elements");
    if (element_count(Ctot_out_buf->dimensions()) != c_elems) return Error::InvalidArgument("Ctot_out must have O*K*D elements");
    if (element_count(Ccnt_out_buf->dimensions()) != c_elems) return Error::InvalidArgument("Ccnt_out must have O*K*D elements");
    if (element_count(S_buf->dimensions()) != int64_t(O) * N * B * K) return Error::InvalidArgument("S must have O*N*B*K elements");
    if (element_count(lab_buf->dimensions()) != int64_t(O) * N * B) return Error::InvalidArgument("lab must have O*N*B elements");

    // Nothing to do for an empty outer batch; launching a zero-sized grid would fail.
    if (O == 0) return Error::Success();

    const size_t shared_bytes = 0; // the kernel uses static __shared__ storage only
    dim3 grid(outer_batch_size);
    dim3 block(32 * num_warps);
    auto X_layout = make_layout(make_shape(O, N, Int<B>{}, Int<D>{}), LayoutRight{});
    auto C_layout = make_layout(make_shape(O, Int<K>{}, Int<D>{}), LayoutRight{});
    auto S_layout = make_layout(make_shape(O, N, Int<B>{}, Int<K>{}), LayoutRight{});
    auto lab_layout = make_layout(make_shape(O, N, Int<B>{}), LayoutRight{});
    adjust_fp16_kernel<B,D,K,num_warps><<<grid, block, shared_bytes, stream>>>(
        make_tensor(reinterpret_cast<const fp16*>(X_buf.typed_data()), X_layout),
        make_tensor(reinterpret_cast<const fp16*>(Ctot_buf.typed_data()), C_layout),
        make_tensor(reinterpret_cast<const fp16*>(Ccnt_buf.typed_data()), C_layout),
        make_tensor(reinterpret_cast<fp16*>(S_buf->typed_data()), S_layout),
        make_tensor(lab_buf->typed_data(), lab_layout),
        make_tensor(reinterpret_cast<fp16*>(Ctot_out_buf->typed_data()), C_layout),
        make_tensor(reinterpret_cast<fp16*>(Ccnt_out_buf->typed_data()), C_layout),
        fp16(beta),
        iters
    );
    cudaError_t last_err = cudaGetLastError();
    if (last_err != cudaSuccess) return Error::Internal(std::string("CUDA error: ") + cudaGetErrorString(last_err));
    return Error::Success();
}

// Dispatches adjust_fp16_host on the number of centroids K, read from Ctot's second-to-last
// dimension. Each instantiation uses num_warps == K/8, because the kernel currently requires
// K == 8*WARPS (one n-tile per warp). Other K values are rejected with InvalidArgument.
Error adjust_fp16_host_multiK(
    cudaStream_t stream,
    Buffer<F16> X_buf,
    Buffer<F16> Ctot_buf,
    Buffer<F16> Ccnt_buf,
    ResultBuffer<F16> S_buf,
    ResultBuffer<S32> lab_buf,
    ResultBuffer<F16> Ctot_out_buf,
    ResultBuffer<F16> Ccnt_out_buf,
    float beta,
    int iters
) {
    const int num_batch_dims = Ctot_buf.dimensions().size() - 2;
    if (num_batch_dims < 0) return Error::InvalidArgument("Ctot must be rank 2 plus batch dims");
    const int K = Ctot_buf.dimensions()[num_batch_dims];
    switch (K) {
        // Smaller K (8/16/32) and K=256 instantiations are currently disabled.
        case 64: return adjust_fp16_host<64,8>(stream, X_buf, Ctot_buf, Ccnt_buf, S_buf, lab_buf, Ctot_out_buf, Ccnt_out_buf, beta, iters);
        case 128: return adjust_fp16_host<128,16>(stream, X_buf, Ctot_buf, Ccnt_buf, S_buf, lab_buf, Ctot_out_buf, Ccnt_out_buf, beta, iters);
        // The message lists only the K values that are actually dispatched above.
        default: return Error::InvalidArgument("K must be 64 or 128");
    }
}


// Defines the XLA FFI handler symbol `adjust_fp16`, which forwards to
// adjust_fp16_host_multiK. Bindings, in that function's parameter order:
//   ctx:   the CUDA stream
//   args:  X, Ctot, Ccnt (fp16)
//   rets:  S (fp16), lab (s32), Ctot_out, Ccnt_out (fp16)
//   attrs: "beta" (float), "iters" (int)
// NOTE(review): kCmdBufferCompatible declares the handler safe to record into XLA command
// buffers (CUDA graphs); the host path only enqueues a kernel on the given stream.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    adjust_fp16,
    adjust_fp16_host_multiK,
    Ffi::Bind()
        .Ctx<PlatformStream<cudaStream_t>>()
        .Arg<Buffer<F16>>()  // X_buf
        .Arg<Buffer<F16>>()  // Ctot_buf
        .Arg<Buffer<F16>>()  // Ccnt_buf
        .Ret<Buffer<F16>>()  // S_buf
        .Ret<Buffer<S32>>()  // lab_buf
        .Ret<Buffer<F16>>()  // Ctot_out_buf
        .Ret<Buffer<F16>>()  // Ccnt_out_buf
        .Attr<float>("beta")
        .Attr<int>("iters"),
    {Traits::kCmdBufferCompatible}
);
