#include <cutlass/numeric_types.h>
#include <cutlass/bfloat16.h>

#include <cuda_runtime.h>
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cuda/atomic>
#include <cuda_fp16.h>

#include <cute/tensor.hpp>
#include <cute/arch/mma_sm80.hpp>
#include <cute/arch/copy_sm80.hpp>
#include <cute/algorithm/gemm.hpp>

#include "xla/ffi/api/ffi.h"

#include <math.h>
#include <string.h>

using namespace cute;
namespace cg = cooperative_groups;
using bf16 = cutlass::bfloat16_t;
using fp16 = cutlass::half_t;

// Maps a 32-bit float onto a uint32_t key whose unsigned ordering matches the
// ordering of the original floats (negative values sort below positive ones).
//
// The bit pattern is obtained with memcpy instead of a reinterpret_cast through a
// uint32_t reference: the cast reads a float object through an unrelated type, and
// reinterpret_cast is never permitted in a constant expression, so the previous
// `constexpr` specifier could not be honoured by any caller. It is dropped here.
//
// f: value to convert.
// Returns: the order-preserving key.
__device__ __host__ inline uint32_t float_to_ordered_uint(float f) {
    static_assert(sizeof(float) == sizeof(uint32_t), "float must be exactly 32 bits wide");
    uint32_t u;
    memcpy(&u, &f, sizeof(u));
    // Sign bit set   -> mask is all ones: every bit is flipped, so larger magnitudes
    //                   produce smaller keys.
    // Sign bit clear -> mask is only the top bit: the key moves above every negative one.
    const uint32_t mask = (0u - (u >> 31)) | 0x80000000u;
    return u ^ mask;
}

// Packs an fp16 value and a 16-bit tag into a single uint32_t key.
//
// The upper 16 bits hold the order-preserving encoding of `f` (same scheme as
// float_to_ordered_uint, applied to 16 bits); the lower 16 bits hold `tag`.
// Comparing two keys as unsigned integers therefore compares the fp16 values first
// and falls back to the tag only when the values encode identically.
//
// The raw bits are read from half_t's own `storage` member rather than through a
// reinterpret_cast to an unrelated reference type, which is not allowed inside a
// constexpr function and type-puns the object.
//
// f:   value to encode.
// tag: payload carried in the low half of the key (a cluster index at the call site).
// Returns: the combined key.
__device__ __host__ inline constexpr uint32_t fp16_to_tagged_ordered_uint(fp16 f, uint16_t tag) {
    const uint16_t u = f.storage;
    // 0xFFFF for negative inputs (flip every bit), 0x8000 otherwise (flip the sign bit only).
    const uint16_t mask = -int16_t(u >> 15) | 0x8000;
    return (uint32_t(u ^ mask) << 16) | uint32_t(tag);
}

// Returns the 16-bit tag stored in the low half of a key built by
// fp16_to_tagged_ordered_uint; the upper half (the encoded value) is discarded.
__device__ __host__ inline constexpr uint16_t extract_tag_from_ordered_uint(uint32_t u) {
    constexpr uint32_t low_half = 0xFFFF;
    return static_cast<uint16_t>(u & low_half);
}

// Prototype single-block clustering step on one tile.
//
// Launch requirements: exactly one block of 128 threads (4 warps). The copy, MMA and
// warp/lane partitions below hard-code 4 warps of 32 lanes and were written for the
// default M = N = K = 64. No dynamic shared memory is needed; the tiles live in a
// statically sized __shared__ struct. Uses cp.async and bf16 tensor-core MMA atoms
// (SM80 variants) and __reduce_max_sync.
//
// Steps, in order:
//   1. Copy X (M x K, row-major) and the centroids (N x K, row-major) from global to
//      shared memory with cp.async.
//   2. Compute S = X * centroids^T with fp32 accumulation; store S to shared memory
//      and to S_ptr.
//   3. Row-wise argmax over S; one lane per row writes the winning column to lab_ptr.
//   4. Every warp reads back all labels and sums the rows of X assigned to the
//      centroids it owns (centroid c is owned by warp c % 4, accumulator row c / 4).
//   5. Store the sums to the shared centroid tile (as bf16) and to S_ptr (as float).
//      NOTE(review): for the default sizes this second store covers the whole S tile,
//      so the scores written in step 2 do not survive in S_ptr - confirm intended.
//
// X_ptr:   M*K bf16 inputs.
// K_ptr:   N*K bf16 centroids.
// S_ptr:   M*N floats, written in steps 2 and 5.
// lab_ptr: M int32 labels, written in step 3 and read in step 4.
//
// Define ADJUST_KERNEL_DEBUG to have thread 0 print the partition layouts.
template <int M = 64, int N = 64, int K = 64>
__global__ void adjust_kernel(
    const bf16* X_ptr,
    const bf16* K_ptr,
    float* S_ptr,
    int32_t* lab_ptr
) {
    using LayoutX = Layout<Shape<Int<M>, Int<K>>, Stride<Int<K>, Int<1>>>;
    using LayoutK = Layout<Shape<Int<N>, Int<K>>, Stride<Int<K>, Int<1>>>;
    using LayoutS = Layout<Shape<Int<M>, Int<N>>, Stride<Int<N>, Int<1>>>;

    cg::thread_block cta = cg::this_thread_block();
    const int tid = cta.thread_rank();

    struct SharedStorage{
        alignas(128) bf16 X[cosize_v<LayoutX>];
        alignas(128) bf16 K_[cosize_v<LayoutK>];
        alignas(128) float S[cosize_v<LayoutS>];
    };
    __shared__ SharedStorage shared;

    auto gX = make_tensor(make_gmem_ptr(X_ptr), LayoutX{});
    auto gK = make_tensor(make_gmem_ptr(K_ptr), LayoutK{});
    auto gS = make_tensor(make_gmem_ptr(S_ptr), LayoutS{});
    Tensor gL = make_tensor(make_gmem_ptr(lab_ptr), Layout<Shape<Int<M>>>());
    auto sX = make_tensor(make_smem_ptr(shared.X), LayoutX{});
    auto sK = make_tensor(make_smem_ptr(shared.K_), LayoutK{});
    auto sS = make_tensor(make_smem_ptr(shared.S), LayoutS{});

    // Step 1: asynchronous 16-byte copies, 8 bf16 values per thread per instruction.
    auto copy_atom = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, bf16>{};
    auto tiled_copy = make_tiled_copy(copy_atom, Layout<Shape<_32, _4>>{}, Layout<Shape<_1, _8>>{});
    auto thr_copy = tiled_copy.get_thread_slice(tid);
    copy(tiled_copy, thr_copy.partition_S(gX), thr_copy.partition_D(sX));
    copy(tiled_copy, thr_copy.partition_S(gK), thr_copy.partition_D(sK));

    cp_async_fence();
    cp_async_wait<0>();
    __syncthreads();

    // Step 2: S = X * centroids^T on a 2x2 arrangement of warps.
    auto tiled_mma = make_tiled_mma(
        SM80_16x8x16_F32BF16BF16F32_TN{},
        Layout<Shape<_2,_2,_1>>{}
    );
    auto thr_mma = tiled_mma.get_thread_slice(tid);

    auto tCsX = thr_mma.partition_A(sX);
    auto tCsK = thr_mma.partition_B(sK);
    auto tCsS = thr_mma.partition_C(sS);

    auto tCrS = thr_mma.make_fragment_C(tCsS);

    clear(tCrS);
    gemm(tiled_mma, tCsX, tCsK, tCrS);
    copy(tCrS, tCsS);
    auto tCgS = thr_mma.partition_C(gS);
    copy(tCrS, tCgS);

    // The argmax below reads score elements produced by other threads.
    __syncthreads();

    auto warp = cg::tiled_partition<32>(cta);
    int wid = warp.meta_group_rank();
    int lid = warp.thread_rank();

    // Step 3: warp `wid` handles rows wid, wid+4, ...; lane `lid` holds columns lid and lid+32.
    Tensor tRsS = zipped_divide(sS, Shape<_4,_32>{})(make_coord(wid, lid), make_coord(_,_));
    Tensor wRgL = local_partition(gL, Layout<Shape<_4>, Stride<_1>>{}, wid);
    Tensor tRrS = make_tensor_like(tRsS);

    copy(tRsS, tRrS);
    for (int row = 0; row < size<0>(tRrS); row++) {
        float max_val = tRrS(row,0);
        int local_max_index = 0;
        for (int col = 1; col < size<1>(tRrS); col++) {
            if (tRrS(row,col) > max_val) {
                max_val = tRrS(row,col);
                local_max_index = col;
            }
        }
        // Order-preserving integer keys allow an integer warp-wide max reduction.
        uint32_t max_val_u = float_to_ordered_uint(max_val);
        uint32_t warp_max_u = __reduce_max_sync(0xFFFFFFFF, max_val_u);
        if (max_val_u == warp_max_u) {
            if (cg::coalesced_threads().thread_rank() == 0) {
                // Lowest-ranked lane among those holding the maximum writes the label.
                wRgL(row) = lid + local_max_index * 32; // set label to index of max value
            }
        }
    }

    // Every warp is about to read labels written by the other warps; the block-wide
    // barrier orders those global-memory writes before the reads below.
    __syncthreads();

    // Step 4: per-thread accumulators for the centroid elements this thread owns.
    Tensor bAsK = logical_divide(sK, Tile<Layout<_4>, Layout<_32,_2>>{}); //(wid, rows),(lid, pair)
    Tensor tAsK = bAsK(make_coord(wid, _), make_coord(lid, _));
    // Built from the shape only, giving a compact owning tensor. Reusing the strided
    // shared-memory layout would size the per-thread array by that layout's cosize.
    Tensor tArK = make_tensor<float>(shape(tAsK));
    // Number of points assigned to each owned centroid; accumulated but not consumed yet.
    Tensor tArK_count = make_tensor<int>(get<0>(tArK.layout()).shape());
    clear(tArK);
    clear(tArK_count);

    Tensor bAgL = zipped_divide(gL, Shape<_32>{}); //(lid), (rest)
    Tensor tAgL = bAgL(lid, _);
    Tensor tArL = make_tensor_like(tAgL);
    copy(tAgL, tArL);

    Tensor bAsX = logical_divide(sX, Tile<Layout<_1>, Layout<_32,_2>>{}); //(row), (lid, pair)
    Tensor tAsX = bAsX(_, make_coord(lid, _));

#ifdef ADJUST_KERNEL_DEBUG
    if (thread0()) {print("bAsK: "); print(bAsK); print("\n");}
    if (thread0()) {print("tAsK: "); print(tAsK); print("\n");}
    if (thread0()) {print("tArK: "); print(tArK); print("\n");}
    if (thread0()) {print("tArK_count: "); print(tArK_count); print("\n");}
    if (thread0()) {print("bAgL: "); print(bAgL); print("\n");}
    if (thread0()) {print("tAgL: "); print(tAgL); print("\n");}
    if (thread0()) {print("tArL: "); print(tArL); print("\n");}
    if (thread0()) {print("bAsX: "); print(bAsX); print("\n");}
    if (thread0()) {print("tAsX: "); print(tAsX); print("\n");}
#endif

    // Each lane holds the labels of rows lid and lid+32. Broadcast them one at a time
    // so the whole warp sees every label and accumulates the rows it owns.
    for (int active_lid = 0; active_lid < 32; active_lid++) {
        for (int rest = 0; rest < size(tArL); rest++) {
            int label = __shfl_sync(0xFFFFFFFF, tArL(rest), active_lid);
            if ((label % 4) == wid) {
                int x_index = active_lid + rest * 32;
                int current_row = label / 4;
                tArK_count(current_row) += 1;
                Tensor chosen_tArK = tArK(current_row, _);
                for (int i = 0; i < size(chosen_tArK); i++) {
                    chosen_tArK(i) += float(tAsX(x_index, i));
                }
            }
        }
    }

    // Step 5: publish the sums to the shared centroid tile and to global memory.
    for (int i = 0; i < size(tArK); i++) {
        tAsK(i) = bf16(tArK(i));
    }
    Tensor tAgS = logical_divide(gS, Tile<Layout<_4>, Layout<_32,_2>>{})(make_coord(wid, _), make_coord(lid, _));
    copy(tArK, tAgS);
}

// Mini-batch centroid update kernel in fp16.
//
// Each block processes one outer-batch slice (selected by blockIdx.x) and is launched
// with WARPS*32 threads. The centroid totals and counts for that slice are loaded once
// into thread-local tensors, then for `iters` passes over the N minibatches of B rows:
//   1. the minibatch of X is copied into swizzled shared memory with cp.async while
//      the per-row tagged-max slots are reset to 0;
//   2. inverse counts and the squared magnitudes of the count-normalised centroids
//      are computed;
//   3. X is multiplied against the count-normalised centroids with fp16 MMA atoms and
//      every row's best centroid is selected with an atomicMax on a key that packs
//      the score and the centroid index;
//   4. each centroid total that won a row is updated as beta * total + x, and its
//      count as beta * count + 1.
// After the loops the totals and counts are written to the output tensors.
//
// Template parameters:
//   B     - rows per minibatch.
//   D     - feature dimension (must be even, see the static assert below).
//   K     - number of centroids.
//   WARPS - warps per block.
// Parameters:
//   grid_gX        - inputs, shape (outer, N, B, D).
//   grid_gCtot     - centroid totals in, shape (outer, K, D).
//   grid_gCcnt     - centroid counts in, shape (outer, K).
//   grid_gCtot_out - centroid totals out, shape (outer, K, D).
//   grid_gCcnt_out - centroid counts out, shape (outer, K).
//   beta           - multiplier applied to a total/count before a new row is added.
//   iters          - number of passes over the N minibatches.
template <int B, int D, int K, int WARPS>
__global__ void __launch_bounds__(WARPS*32,1) adjust_fp16_kernel(
    Tensor<ViewEngine<const fp16*>, Layout<Shape<int, int, Int<B>, Int<D>>, Stride<int, Int<B*D>, Int<D>, Int<1>>>> grid_gX,
    Tensor<ViewEngine<const fp16*>, Layout<Shape<int, Int<K>, Int<D>>, Stride<Int<K*D>, Int<D>, Int<1>>>> grid_gCtot,
    //Tensor<ViewEngine<const fp16*>, Layout<Shape<int, Int<K>, Int<D>>, Stride<Int<K*D>, Int<D>, Int<1>>>> grid_gCcnt,
    Tensor<ViewEngine<const fp16*>, Layout<Shape<int, Int<K>>, Stride<Int<K>, Int<1>>>> grid_gCcnt,
    //Tensor<ViewEngine<fp16*>, Layout<Shape<int, int, Int<B>, Int<K>>, Stride<int, Int<B*K>, Int<K>, Int<1>>>> grid_gS,
    //Tensor<ViewEngine<int32_t*>, Layout<Shape<int, int, Int<B>>, Stride<int, Int<B>, Int<1>>>> grid_gL,
    Tensor<ViewEngine<fp16*>, Layout<Shape<int, Int<K>, Int<D>>, Stride<Int<K*D>, Int<D>, Int<1>>>> grid_gCtot_out,
    //Tensor<ViewEngine<fp16*>, Layout<Shape<int, Int<K>, Int<D>>, Stride<Int<K*D>, Int<D>, Int<1>>>> grid_gCcnt_out,
    Tensor<ViewEngine<fp16*>, Layout<Shape<int, Int<K>>, Stride<Int<K>, Int<1>>>> grid_gCcnt_out,
    fp16 beta,
    int iters
) {
    //Note: Designed for D=64, B%16==0 (e.g. 64), K%(8*WARPS) == 0 (e.g. K=64, WARPS=4)
    //D=64 corresponds to standard LLM head size.
    //Centroids are kept only in registers, leaving all smem for data and allowing B and K to optimized ~separately.

    //Select outer batch for this block
    const int block_index = blockIdx.x;
    Tensor gX = grid_gX(block_index,_,_,_);
    Tensor gCtot = grid_gCtot(block_index,_,_);
    Tensor gCcnt = grid_gCcnt(block_index,_);
    //Tensor gS = grid_gS(block_index,_,_,_);
    //Tensor gL = grid_gL(block_index,_,_);
    Tensor gCtot_out = grid_gCtot_out(block_index,_,_);
    Tensor gCcnt_out = grid_gCcnt_out(block_index,_);

    // Swizzle patterns applied to the fp16 view and to the __half2 view of the same tile.
    using Swizzle128B_fp16 = Swizzle<3,3,3>;
    using Swizzle128B_half2 = Swizzle<3,2,3>;

    // Number of minibatches in this block's slice of X.
    const int N = size<0>(gX);
    // Static shared memory: one minibatch of X plus one tagged-max slot per row.
    struct SharedStorage{
        alignas(128) fp16 X[cosize_v<decltype(gX(0,_,_).layout())>];
        //alignas(128) fp16 C[cosize_v<decltype(layout(gCtot))>];
        alignas(128) uint32_t tagged_max[B];
    };
    __shared__ SharedStorage smem;
    Tensor sX = make_tensor(make_smem_ptr(smem.X), composition(Swizzle128B_fp16{}, gX(0,_,_).layout()));
    //Tensor sC = make_tensor(make_smem_ptr(smem.C), composition(Swizzle128B_fp16{}, layout(gCtot)));
    //Tensor sX = make_tensor(make_smem_ptr(smem.X), gX(0,_,_).layout());
    Tensor sT = make_tensor(make_smem_ptr(smem.tagged_max), Layout<Shape<Int<B>>>{});

    // Thread coordinates: block rank, warp index / lane, and quad (4-lane group) index / lane.
    cg::thread_block cta = cg::this_thread_block();
    const int tid = cta.thread_rank();
    const int num_threads = cta.size();
    auto warp = cg::tiled_partition<32>(cta);
    const int wid = warp.meta_group_rank();
    const int lid = warp.thread_rank();
    const int num_warps = warp.meta_group_size();
    auto quad = cg::tiled_partition<4>(warp);
    const int qid = quad.meta_group_rank();
    const int qlid = quad.thread_rank();

    // At most 8 warps take part in the global->shared copy; each thread moves 8 fp16
    // values (16 bytes) per cp.async instruction.
    constexpr int copy_warp_count = (8 < WARPS) ? 8 : WARPS;
    auto blk_copy = make_tiled_copy(
        Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, fp16>{},
        make_layout(
            make_shape(Int<4*copy_warp_count>{}, _8{}),
            LayoutRight{}
        ),
        Layout<Shape<_1,_8>>{}
    );
    auto thr_copy = blk_copy.get_thread_slice(tid);

    // fp16 MMA with fp16 accumulation; the warps are laid out along the N (centroid) mode.
    using MMA_TYPE = SM80_16x8x16_F16F16F16F16_TN;
    auto atom_mma = MMA_Atom<MMA_TYPE>{};
    auto blk_mma = make_tiled_mma(
        MMA_TYPE{},
        Layout<Shape<_1,Int<WARPS>,_1>>{}
    );
    auto thr_mma = blk_mma.get_thread_slice(tid);


    // __half2 views (pairs of adjacent fp16 values) over the X tile in shared memory
    // and over the centroid totals in global memory, with and without swizzle.
    CUTE_STATIC_ASSERT(D % 2 == 0, "D must be even for fp16 half2");
    auto sX2_layout = Layout<Shape<Int<B>, Int<D/2>>, Stride<Int<D/2>, Int<1>>>{};
    auto sX2_sw_layout = composition(Swizzle128B_half2{}, sX2_layout);
    Tensor sX2 = make_tensor(make_smem_ptr(reinterpret_cast<__half2*>(smem.X)), sX2_layout);
    Tensor sX2_sw = make_tensor(make_smem_ptr(reinterpret_cast<__half2*>(smem.X)), sX2_sw_layout);
    auto gC2_layout = Layout<Shape<Int<K>, Int<D/2>>, Stride<Int<D/2>, _1>>{};
    auto gC2_sw_layout = composition(Swizzle128B_half2{}, gC2_layout);
    Tensor gCtot2 = make_tensor(make_gmem_ptr(reinterpret_cast<const __half2*>(gCtot.data())), gC2_layout);
    Tensor gCtot_out2 = make_tensor(make_gmem_ptr(reinterpret_cast<__half2*>(gCtot_out.data())), gC2_layout);
    Tensor gCtot2_sw = make_tensor(make_gmem_ptr(reinterpret_cast<const __half2*>(gCtot.data())), gC2_sw_layout);
    Tensor gCtot_out2_sw = make_tensor(make_gmem_ptr(reinterpret_cast<__half2*>(gCtot_out.data())), gC2_sw_layout);
    //Tensor sC2 = make_tensor(make_smem_ptr(reinterpret_cast<__half2*>(smem.C)), gC2_layout);

    // Tile used to split centroid rows over (8 quads, WARPS warps) and half2 columns over 4 quad lanes.
    using RTile = Tile<Layout<Shape<_8,Int<WARPS>>>, Layout<_4>>;

    // For efficient warp shuffle assembly of tensor core fragments we need the register
    // layout of the centroid accumulators to be arranged with the standard 128B swizzle
    // normally used in smem for avoiding bank conflicts.

    //auto full_thread_R_coord = make_coord(make_coord(make_coord(qid,wid),qlid), make_coord(_,qid));
    auto full_thread_R_coord = make_coord(make_coord(make_coord(_,wid),qlid), make_coord(_,qid));
    Tensor tRgCtot2_sw = zipped_divide(gCtot2_sw, RTile{})(full_thread_R_coord);
    Tensor tRgCtot_out2_sw = zipped_divide(gCtot_out2_sw, RTile{})(full_thread_R_coord);
    Tensor tRgCtot_out2 = zipped_divide(gCtot_out2, RTile{})(full_thread_R_coord);

    //Tensor tRsC2 = zipped_divide(sC2, RTile{})(full_thread_R_coord);

    // Currently an empty branch: the shared-memory staging of centroids is disabled.
    if (wid < copy_warp_count) {
        //copy(gCtot, sC);
        //copy(blk_copy, thr_copy.partition_S(gCtot), thr_copy.partition_D(sC));
        //cp_async_fence();
    }
    //cp_async_wait<0>();
    //__syncthreads();

    //Tensor tRrC2 = make_tensor_like(tRgCtot2_sw);
    //copy(tRgCtot2_sw, tRrC2);
    //copy(tRsC2, tRrC2);

    // Coordinate helpers: centroid_to_row_warp_n turns a centroid index into the three
    // coordinates of the row tiling above (used in the update loop to find the owner).
    constexpr auto row_layout_R = zipped_divide(layout<0>(gCtot2), get<0>(RTile{}));
    constexpr auto centroid_to_row_warp_n = make_identity_tensor(flatten(shape(row_layout_R)));
    auto cC2 = make_identity_tensor(shape(gCtot2));
    auto tRcC2 = zipped_divide(cC2, RTile{})(full_thread_R_coord);

    // Thread-local copy of the counts for the centroids owned by this (quad, warp).
    //Tensor qCgCcnt = outer_partition(gCcnt, Tile<Layout<Shape<_8,Int<WARPS>>>, Layout<_1>>{}, make_coord(make_coord(qid, wid), 0));
    Tensor qCgCcnt = outer_partition(gCcnt, Tile<Layout<Shape<_8,Int<WARPS>>>>{}, make_coord(make_coord(qid, wid)));
    Tensor qCgCcnt_out = outer_partition(gCcnt_out, Tile<Layout<Shape<_8,Int<WARPS>>>>{}, make_coord(make_coord(qid, wid)));
    Tensor qCcnt = make_tensor_like(qCgCcnt);
    copy(qCgCcnt, qCcnt);
    Tensor qCinvcnt = make_tensor_like(qCcnt);
    // sq norms of centroids for converting dot products to distances
    Tensor Csqmag = make_tensor<fp16>(Shape<_2,Int<size<1>(tRcC2)>>{});

    // Thread-local copy of this thread's MMA B-operand partition of the centroid
    // totals, loaded once from global memory and updated in place by the loop below.
    //Tensor tCsC = thr_mma.partition_B(sC);
    //Tensor tCrC = thr_mma.make_fragment_B(gCtot);
    //copy(tCsC, tCrC);
    Tensor tCgCtot = thr_mma.partition_B(gCtot);
    Tensor tCrC = make_tensor_like(tCgCtot);
    copy(tCgCtot, tCrC);
    Tensor tCrC2 = recast<__half2>(tCrC);



    //constexpr auto row_layout_R = zipped_divide(Layout<Shape<Int<K>>>{}, Tile<Layout<Shape<_8,_4>>>{});

    for (int iter = 0; iter < iters; iter++) for (int minibatch = 0; minibatch < N; minibatch++) {
        // Wait until every thread has finished reading sX / sT from the previous pass.
        __syncthreads();
        if (wid < copy_warp_count) {
            copy(blk_copy, thr_copy.partition_S(gX(minibatch,_,_)), thr_copy.partition_D(sX));
            cp_async_fence();
        }
        // Reset the per-row tagged maxima while the asynchronous copy is in flight.
        for (int i = tid; i < size(sT); i += num_threads) {
            sT(i) = 0;
        }
        // NOTE(review): a count of zero is not guarded against here - confirm callers
        // never pass empty centroids.
        for (int n = 0; n < size(qCcnt); n++) {
            qCinvcnt(n) = fp16(1.0f) / qCcnt(n);
        }
        // Sum of squares of (total * inverse count) per centroid: partial sums per
        // thread, reduced over the quad, then fetched from lane 0 of quads 2*qlid and
        // 2*qlid+1 into Csqmag.
        //Tensor tCrC2 = recast<__half2>(tCrC);
        for (int n = 0; n < size<1>(tCrC2); n++) {
            Tensor tCrC2_n = tCrC2(_,n,_);
            const __half2 inv_cnt_n2 = __half2(__half(qCinvcnt(n)), __half(qCinvcnt(n)));
            __half2 row_sum2 = __half2(0.0f, 0.0f);
            for (int i = 0; i < size(tCrC2_n); i++) {
                const __half2 val2 = tCrC2_n(i) * inv_cnt_n2;
                row_sum2 = row_sum2 + val2 * val2;
            }
            row_sum2 = cg::reduce(quad, row_sum2, cg::plus<__half2>());
            const fp16 row_sum = fp16(row_sum2.x + row_sum2.y);
            Csqmag(0, n) = __shfl_sync(0xFFFFFFFF, row_sum, (qlid*2)*4);
            Csqmag(1, n) = __shfl_sync(0xFFFFFFFF, row_sum, (qlid*2+1)*4);
        }

        cp_async_wait<0>();
        __syncthreads(); // make sure Csqmag, sX and sT are ready

        //Tensor tCgS = thr_mma.partition_C(gS(minibatch,_,_));
        auto shape_S = Shape<Int<B>, Int<K>>{};
        //Tensor cS = make_identity_tensor(shape(gS(minibatch,_,_)));
        // Identity tensor over the (B, K) score tile: gives the (row, centroid)
        // coordinate of every accumulator element this thread holds.
        Tensor cS = make_identity_tensor(shape_S);
        Tensor tCcS = thr_mma.partition_C(cS);
        //Tensor tCrS = thr_mma.make_fragment_C(tCgS);
        Tensor tCsX = thr_mma.partition_A(sX);
        //Tensor tCrX = thr_mma.make_fragment_A(tCsX);
        //Tensor tCsC = thr_mma.partition_B(gCtot);
        //Tensor tCsC = thr_mma.partition_B(sC);
        //Tensor tCrC = thr_mma.make_fragment_B(tCsC);
        //copy(tCsC, tCrC);
        //clear(tCrS);
        //copy(tCsX, tCrX);
        // One MMA row tile of X at a time: scores against every centroid tile, then argmax.
        for (int m = 0; m < size<1>(tCsX); m++) {
            //Tensor tCrS_m = make_tensor_like(tCgS(_,m,_));
            Tensor tCrS_m = make_tensor<fp16>(shape(tCcS(_,m,_)));
            clear(tCrS_m);
            Tensor tCrX_m = make_tensor_like(tCsX(_,m,_));
            copy(tCsX(_,m,_), tCrX_m);
            auto tCcS_m = tCcS(_,m,_);
            uint32_t quad_lane_max_u = 0;
            int quad_lane_x_id = 0;
            for (int n = 0; n < size<1>(tCrC); n++) {
                const __half2 inv_cnt_n2 = __half2(__half(qCinvcnt(n)), __half(qCinvcnt(n)));
                for (int k = 0; k < size<2>(tCrC); k++) {
                    // Scale a temporary copy of the totals by the inverse count so the
                    // stored totals stay unnormalised.
                    Tensor tCrC_nk = make_tensor_like(tCrC(_,n,k));
                    Tensor tCrC2_nk = make_tensor(reinterpret_cast<__half2*>(tCrC_nk.data()), Layout<Shape<_2>>{});
                    copy(tCrC(_,n,k), tCrC_nk);
                    for (int i = 0; i < size(tCrC2_nk); i++) {
                        tCrC2_nk(i) = inv_cnt_n2 * tCrC2_nk(i);
                    }
                    gemm(atom_mma, tCrX_m(_,k), tCrC_nk, tCrS_m(_,n));
                }
            }
            // argmax over tCrS(_,m,_)
            for (int inner_row = 0; inner_row < size<0,1>(tCrS_m); inner_row++) {
                uint32_t local_max_u = 0;
                for (int n = 0; n < size<1>(tCrS_m); n++) {
                    for (int inner_col = 0; inner_col < size<0,0>(tCrS_m); inner_col++) {
                        auto coord = make_coord(make_coord(inner_col,inner_row), n);
                        const uint16_t cluster_id = get<1>(tCcS_m(coord));
                        // Score = dot product minus half the centroid's squared
                        // magnitude; for a fixed row, the largest score corresponds to
                        // the smallest squared distance to the centroid.
                        const fp16 val = tCrS_m(coord) - fp16(0.5f) * fp16(Csqmag(inner_col + n*2));
                        // The centroid index rides in the low 16 bits of the key, so the
                        // max reduction carries the argmax along with the value.
                        const uint32_t val_u = fp16_to_tagged_ordered_uint(val, cluster_id);
                        local_max_u = umax(local_max_u, val_u);
                    }
                }
                uint32_t quad_max_u = cg::reduce(quad, local_max_u, cg::greater<uint32_t>());
                int x_id = get<0>(tCcS_m(make_coord(make_coord(0,inner_row),0)));
                //if (quad.thread_rank() == (2*m + inner_row) % 4) {
                // Quad lane `inner_row` keeps the result for that row; the other lanes
                // keep their initial (0, 0) pair.
                if (quad.thread_rank() == inner_row) {
                    quad_lane_max_u = quad_max_u;
                    quad_lane_x_id = x_id;
                }
            }
            // only writes from lanes 0,1 are useful here, but doing this predicated is slower for some reason
            atomicMax(&sT(quad_lane_x_id), quad_lane_max_u);
        } // end for m in gemm loop


        __syncthreads();
        // The label is extracted but not stored: the write to gL is commented out.
        for (int i = tid; i < size(sT); i += num_threads) {
            uint16_t label = extract_tag_from_ordered_uint(sT(i));
            //gL(minibatch,i) = int(label);
        }
        //__syncthreads();

        // Each lane loads the tagged maxima of rows lid, lid+32, ... into a thread-local tensor.
        Tensor bAsT = zipped_divide(sT, Shape<_32>{}); //(lid), (rest)
        Tensor tAsT = bAsT(lid, _);
        Tensor tArT = make_tensor_like(tAsT);
        copy(tAsT, tArT);

        //Tensor tCrC2 = recast<__half2>(tCrC);
        const __half2 beta2 = __half2(__half(beta), __half(beta));

        for (int rest = 0; rest < size(tArT); rest++) {
            //continue;
            const int val = extract_tag_from_ordered_uint(tArT(rest));
            auto [val_row, val_warp, val_n] = centroid_to_row_warp_n(val);
            uint32_t active_lanes = __ballot_sync(0xFFFFFFFF, val_warp == wid);
            // Every quad ends up with the set of lanes whose label belongs to this
            // warp and to that quad's row (val_row == qid).
            for (int i = 0; i < 8; i++) {
                uint32_t quad_i_active_lanes = __ballot_sync(0xFFFFFFFF, (val_warp == wid) && (val_row == i));
                if (qid == i) {
                    active_lanes = quad_i_active_lanes;
                }
            }
            // All 32 lanes iterate together (the shuffle uses the full mask) until
            // every quad has drained its own lane set.
            while (__any_sync(0xFFFFFFFF, active_lanes != 0)) {
                const int active_lid = __ffs(active_lanes) -1;
                int label = __shfl_sync(0xFFFFFFFF, val, active_lid); // broadcast because if we own this label all threads must update
                if (active_lanes == 0) continue;
                auto [label_quad, label_warp, label_n] = centroid_to_row_warp_n(label);
                const int x_index = bAsT.layout()(active_lid, rest);
                if (qid == label_quad) {
                    //Tensor tCrC_n = tCrC(_,label_n,_);
                    Tensor tCrC2_n = tCrC2(_,label_n,_);
                    // Decayed update: total = beta * total + x, count = beta * count + 1.
                    for (int i = 0; i < size(tCrC2_n); i++) {
                        const __half2 x_val2 = sX2_sw(x_index, 4*i + qlid);
                        tCrC2_n(i) = beta2 * tCrC2_n(i) + x_val2;
                    }
                    qCcnt(label_n) = beta * qCcnt(label_n) + fp16(1.0f);
                }
                active_lanes &= active_lanes - 1; // clear lowest set bit
            }
        }
        //copy(tCrC, tCsC);
    } // end for minibatch
    // Write the updated totals back through the same MMA B-operand partition.
    Tensor tCgCtot_out = thr_mma.partition_B(gCtot_out);
    copy(tCrC, tCgCtot_out); // this write is not coalesced
    //__syncthreads();
    //if (wid < copy_warp_count) {
    //    copy(thr_copy.partition_S(sC), thr_copy.partition_D(gCtot_out));
    //}
    // One lane per quad writes that quad's counts.
    for (int n = 0; n < size(qCcnt); n++) {
        if (qlid == 0) {
            qCgCcnt_out(n) = qCcnt(n);
        }
        //CUTE_STATIC_ASSERT(D % 4 == 0, "D must be multiple of 4 for this write loop");
        //for (int offset = 0; offset < D/4; offset++) {
        //    qCgCcnt_out(n,offset*4 + qlid) = qCcnt(n);
        //}
    }
}



using namespace xla::ffi;

// XLA FFI entry point that launches adjust_kernel<64, 64, 64> as a single block of
// 4 warps on `stream`.
//
// The kernel unconditionally reads 64x64 elements from X and K and writes 64x64
// floats to S and 64 labels, so buffers that are too small are rejected up front
// instead of letting the kernel access memory out of bounds.
//
// stream:  CUDA stream to launch on.
// X_buf:   bf16 inputs, at least 64*64 elements.
// K_buf:   bf16 centroids, at least 64*64 elements.
// S_buf:   float output, at least 64*64 elements.
// lab_buf: int32 label output, at least 64 elements.
// Returns: InvalidArgument for an undersized buffer, Internal when the launch
//          reports a CUDA error, Success otherwise. The launch is asynchronous, so
//          faults inside the kernel are not reported here.
Error adjust_host(
    cudaStream_t stream,
    Buffer<BF16> X_buf,
    Buffer<BF16> K_buf,
    ResultBuffer<F32> S_buf,
    ResultBuffer<S32> lab_buf
) {
    constexpr int M = 64;
    constexpr int N = 64;
    constexpr int K = 64;
    constexpr int num_warps = 4;

    // Product of all dimensions; a rank-0 buffer counts as one element.
    const auto element_count = [](const auto& dims) {
        size_t total = 1;
        for (size_t i = 0; i < dims.size(); i++) total *= static_cast<size_t>(dims[i]);
        return total;
    };
    if (element_count(X_buf.dimensions()) < size_t(M) * K) return Error::InvalidArgument("X must hold at least 64x64 elements");
    if (element_count(K_buf.dimensions()) < size_t(N) * K) return Error::InvalidArgument("K must hold at least 64x64 elements");
    if (element_count(S_buf->dimensions()) < size_t(M) * N) return Error::InvalidArgument("S must hold at least 64x64 elements");
    if (element_count(lab_buf->dimensions()) < size_t(M)) return Error::InvalidArgument("lab must hold at least 64 elements");

    // The kernel keeps its tiles in statically sized shared memory, so no dynamic
    // shared memory is requested.
    const size_t shared_bytes = 0;
    dim3 grid(1);
    dim3 block(32 * num_warps);
    adjust_kernel<M,N,K><<<grid, block, shared_bytes, stream>>>(
        reinterpret_cast<const bf16*>(X_buf.typed_data()),
        reinterpret_cast<const bf16*>(K_buf.typed_data()),
        S_buf->typed_data(),
        lab_buf->typed_data()
    );
    cudaError_t last_err = cudaGetLastError();
    if (last_err != cudaSuccess) return Error::Internal(std::string("CUDA error: ") + cudaGetErrorString(last_err));
    return Error::Success();
}

// Exposes adjust_host as the XLA FFI handler symbol `adjust`.
// Binding order matches the adjust_host signature: the platform CUDA stream, two
// bf16 argument buffers, then an f32 result buffer and an s32 result buffer.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    adjust,
    adjust_host,
    Ffi::Bind()
        .Ctx<PlatformStream<cudaStream_t>>()
        .Arg<Buffer<BF16>>()
        .Arg<Buffer<BF16>>()
        .Ret<Buffer<F32>>()
        .Ret<Buffer<S32>>(),
    {Traits::kCmdBufferCompatible}
);


// Validates the buffers and launches adjust_fp16_kernel<64, 64, K, num_warps> with
// one block per outer-batch entry.
//
// X is interpreted as [batch dims..., total_N, D] with D = 64 and total_N a multiple
// of B = 64; all leading dimensions are flattened into the outer batch. The centroid
// totals/counts and both outputs are addressed as (outer, K, D) / (outer, K), so each
// must hold at least that many elements.
//
// Template parameters:
//   K         - number of centroids.
//   num_warps - warps per block used for the launch.
// Parameters:
//   stream       - CUDA stream to launch on.
//   X_buf        - fp16 inputs.
//   Ctot_buf     - fp16 centroid totals (input).
//   Ccnt_buf     - fp16 centroid counts (input).
//   Ctot_out_buf - fp16 centroid totals (output).
//   Ccnt_out_buf - fp16 centroid counts (output).
//   beta         - decay factor, converted to fp16 for the kernel.
//   iters        - number of passes over the minibatches.
// Returns: InvalidArgument when a shape check fails, Internal when the launch reports
//          a CUDA error, Success otherwise (including for an empty outer batch, where
//          nothing is launched).
template <int K, int num_warps>
Error adjust_fp16_host(
    cudaStream_t stream,
    Buffer<F16> X_buf,
    Buffer<F16> Ctot_buf,
    Buffer<F16> Ccnt_buf,
    ResultBuffer<F16> Ctot_out_buf,
    ResultBuffer<F16> Ccnt_out_buf,
    float beta,
    int iters
) {
    constexpr int B = 64;
    constexpr int D = 64;
    // The kernel's tensor layouts use `int` extents and strides, so every element
    // offset has to fit in 32 signed bits.
    constexpr int64_t max_int = 0x7fffffff;

    // Product of all dimensions; a rank-0 buffer counts as one element.
    const auto element_count = [](const auto& dims) {
        int64_t total = 1;
        for (size_t i = 0; i < dims.size(); i++) total *= dims[i];
        return total;
    };

    const auto X_dims = X_buf.dimensions();
    const int num_batch_dims = static_cast<int>(X_dims.size()) - 2;
    if (num_batch_dims < 0) return Error::InvalidArgument("X must be rank 2 plus batch dims");
    if (X_dims[num_batch_dims + 1] != D) return Error::InvalidArgument("X last dimension must equal D (64)");

    int64_t outer_batch_size = 1;
    for (int i = 0; i < num_batch_dims; i++) outer_batch_size *= X_dims[i];
    //check N matches
    const int64_t total_N = X_dims[num_batch_dims];
    if (total_N % B != 0) return Error::InvalidArgument("X first dimension must be multiple of B");

    const int64_t centroid_rows = outer_batch_size * K;
    if (outer_batch_size * total_N * D > max_int || centroid_rows * D > max_int) {
        return Error::InvalidArgument("buffers are too large for the kernel's 32-bit indexing");
    }

    // The kernel indexes every centroid buffer by block index, so each one needs a
    // slice per outer-batch entry.
    if (element_count(Ctot_buf.dimensions()) < centroid_rows * D) return Error::InvalidArgument("Ctot is too small for the batch dims of X");
    if (element_count(Ccnt_buf.dimensions()) < centroid_rows) return Error::InvalidArgument("Ccnt is too small for the batch dims of X");
    if (element_count(Ctot_out_buf->dimensions()) < centroid_rows * D) return Error::InvalidArgument("Ctot_out is too small for the batch dims of X");
    if (element_count(Ccnt_out_buf->dimensions()) < centroid_rows) return Error::InvalidArgument("Ccnt_out is too small for the batch dims of X");

    // A zero-sized grid is an invalid launch configuration; with no batch entries
    // there is nothing to compute.
    if (outer_batch_size == 0) return Error::Success();

    const int O = static_cast<int>(outer_batch_size);
    const int N = static_cast<int>(total_N / B);

    // The kernel keeps its tiles in statically sized shared memory.
    const size_t shared_bytes = 0;
    dim3 grid(static_cast<unsigned int>(outer_batch_size));
    dim3 block(32 * num_warps);
    auto X_layout = make_layout(make_shape(O, N, Int<B>{}, Int<D>{}), LayoutRight{});
    auto C_layout = make_layout(make_shape(O, Int<K>{}, Int<D>{}), LayoutRight{});
    auto cnt_layout = make_layout(make_shape(O, Int<K>{}), LayoutRight{});
    adjust_fp16_kernel<B,D,K,num_warps><<<grid, block, shared_bytes, stream>>>(
        make_tensor(reinterpret_cast<const fp16*>(X_buf.typed_data()), X_layout),
        make_tensor(reinterpret_cast<const fp16*>(Ctot_buf.typed_data()), C_layout),
        make_tensor(reinterpret_cast<const fp16*>(Ccnt_buf.typed_data()), cnt_layout),
        make_tensor(reinterpret_cast<fp16*>(Ctot_out_buf->typed_data()), C_layout),
        make_tensor(reinterpret_cast<fp16*>(Ccnt_out_buf->typed_data()), cnt_layout),
        fp16(beta),
        iters
    );
    cudaError_t last_err = cudaGetLastError();
    if (last_err != cudaSuccess) return Error::Internal(std::string("CUDA error: ") + cudaGetErrorString(last_err));
    return Error::Success();
}

// Runtime dispatcher for adjust_fp16_host: reads the centroid count K from the
// second-to-last dimension of Ctot and forwards every argument unchanged to the
// instantiation compiled for that K (each K is paired with a fixed warp count).
// Returns InvalidArgument when Ctot has fewer than two dimensions or K is not one
// of the supported values.
Error adjust_fp16_host_multiK(
    cudaStream_t stream,
    Buffer<F16> X_buf,
    Buffer<F16> Ctot_buf,
    Buffer<F16> Ccnt_buf,
    ResultBuffer<F16> Ctot_out_buf,
    ResultBuffer<F16> Ccnt_out_buf,
    float beta,
    int iters
) {
    using Launcher = Error (*)(cudaStream_t, Buffer<F16>, Buffer<F16>, Buffer<F16>,
                               ResultBuffer<F16>, ResultBuffer<F16>, float, int);
    struct Variant {
        int centroid_count;
        Launcher launch;
    };
    // Supported (K, warps-per-block) pairs.
    static const Variant variants[] = {
        {8,    &adjust_fp16_host<8, 1>},
        {16,   &adjust_fp16_host<16, 2>},
        {32,   &adjust_fp16_host<32, 4>},
        {64,   &adjust_fp16_host<64, 4>},
        {128,  &adjust_fp16_host<128, 8>},
        {256,  &adjust_fp16_host<256, 16>},
        {512,  &adjust_fp16_host<512, 16>},
        {1024, &adjust_fp16_host<1024, 16>},
        {2048, &adjust_fp16_host<2048, 16>},
    };

    const int leading_dims = Ctot_buf.dimensions().size() - 2;
    if (leading_dims < 0) return Error::InvalidArgument("Ctot must be rank 2 plus batch dims");

    const int requested = Ctot_buf.dimensions()[leading_dims];
    for (const Variant& variant : variants) {
        if (variant.centroid_count == requested) {
            return variant.launch(stream, X_buf, Ctot_buf, Ccnt_buf, Ctot_out_buf, Ccnt_out_buf, beta, iters);
        }
    }
    return Error::InvalidArgument("K must be 8/16/32/64/128/256/512/1024/2048");
}


// Exposes adjust_fp16_host_multiK as the XLA FFI handler symbol `adjust_fp16`.
// Binding order matches its signature: the platform CUDA stream, three f16 argument
// buffers (X, Ctot, Ccnt), two f16 result buffers (Ctot_out, Ccnt_out), then the
// `beta` and `iters` attributes.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    adjust_fp16,
    adjust_fp16_host_multiK,
    Ffi::Bind()
        .Ctx<PlatformStream<cudaStream_t>>()
        .Arg<Buffer<F16>>()
        .Arg<Buffer<F16>>()
        .Arg<Buffer<F16>>()
        //.Ret<Buffer<F16>>()
        //.Ret<Buffer<S32>>()
        .Ret<Buffer<F16>>()
        .Ret<Buffer<F16>>()
        .Attr<float>("beta")
        .Attr<int>("iters"),
    {Traits::kCmdBufferCompatible}
);
