#include <algorithm>
#include <cmath>
#include <cstring>
#include <limits>
#include <type_traits>
#include <vector_types.h>
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cuda/std/atomic>
#include <ATen/Tensor.h>
#include <ATen/ops/empty.h>
#include <Python.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_bf16.h>
#include <torch/library.h>
#include <vector>

namespace detail
{
    /// Selects how memcpy_as moves each CopyType-sized chunk.
    enum class TransferMode {
        DEFAULT,    ///< plain assignment; callable from host and device
        LDG         ///< load through `__ldg`; device only
    };

    /// Moves a single value from `src` to `dst`; specialized per TransferMode.
    template<TransferMode Mode>
    struct Transfer;

    template<>
    struct Transfer<TransferMode::DEFAULT> {
        template<class T>
        __host__ __device__ static void call(T* dst, const T* src) {
            const T value = *src;
            *dst = value;
        }
    };

    template<>
    struct Transfer<TransferMode::LDG> {
        template<class T>
        __device__ static void call(T* dst, const T* src) {
            const T value = __ldg(src);
            *dst = value;
        }
    };

    /// Copies `NBytes` bytes from `src` to `dst`, reinterpreting both buffers as arrays of `CopyType`
    /// and moving one CopyType chunk at a time with the given TransferMode.
    /// Both pointers must be suitably aligned for CopyType (not checked here).
    template<class CopyType, int NBytes, TransferMode Mode, class TrueType>
    __host__ __device__ void memcpy_as(TrueType* __restrict__ dst, const TrueType* __restrict__ src) {
        static_assert(NBytes % sizeof(TrueType) == 0, "Number of bytes must be a multiple of the true type size");
        static_assert(NBytes % sizeof(CopyType) == 0, "Number of bytes must be a multiple of the copy type size");
        // Copying the raw bytes through a different type is only legitimate for memcpy-compatible types.
        static_assert(std::is_trivially_copyable_v<TrueType>, "TrueType must be trivially copyable");

        constexpr int chunk_count = NBytes / static_cast<int>(sizeof(CopyType));
        const CopyType* from = reinterpret_cast<const CopyType*>(src);
        CopyType* to = reinterpret_cast<CopyType*>(dst);
#pragma unroll
        for (int chunk = 0; chunk < chunk_count; ++chunk) {
            Transfer<Mode>::call(to + chunk, from + chunk);
        }
    }

    /// Largest power of two, capped at 16, that divides `size`.
    /// Returns 1 for odd sizes and 16 for zero or any multiple of 16.
    constexpr __host__ __device__ std::size_t alignment_from_size(std::size_t size) {
        std::size_t alignment = 1;
        while (alignment < 16 && size % (alignment * 2) == 0) {
            alignment *= 2;
        }
        return alignment;
    }
}// namespace detail
/// Copies `Count` consecutive `T` values from `src` to `dst` with the widest of int4 / int2 / int1 / short1 / char1
/// whose size evenly divides the total byte count.
/// NOTE: both pointers must be aligned for the selected chunk type (e.g. 16 bytes when int4 is picked); this is not
/// checked. The trailing integral_constant parameter only exists to allow deducing `Count` from an argument.
template<std::size_t Count, detail::TransferMode Mode, class T>
__host__ __device__ void memcpy_aligned(T* dst, const T* src, std::integral_constant<std::size_t, Count> = {}) {
    static_assert(std::is_trivially_copyable_v<T>, "T must be trivially copyable");

    constexpr const int NBytes = sizeof(T) * Count;
    // Pick the widest chunk type that tiles NBytes exactly.
    using CopyType =
        std::conditional_t<NBytes % sizeof(int4) == 0, int4,
        std::conditional_t<NBytes % sizeof(int2) == 0, int2,
        std::conditional_t<NBytes % sizeof(int1) == 0, int1,
        std::conditional_t<NBytes % sizeof(short1) == 0, short1, char1>>>>;
    detail::memcpy_as<CopyType, NBytes, Mode>(dst, src);
}

/// Fixed-size array of `ElementCount` trivial elements.
///
/// The type is aligned to the largest power of two (at most 16 bytes) that divides its byte size. This lets
/// load / load_ldg / store move the whole vector with memcpy_aligned, i.e. with as few and as wide memory
/// operations as possible.
/// NOTE: the pointers passed to load / load_ldg / store must be aligned like the vector itself; this is not checked.
template<class ElementType, std::size_t ElementCount>
class alignas(detail::alignment_from_size(sizeof(ElementType) * ElementCount)) GenericVector {
    static_assert(std::is_trivial_v<ElementType>, "Only trivial types are supported");

public:
    /// Leaves the elements uninitialized (trivial default construction).
    GenericVector() = default;

    /// Vector with every element set to `value`.
    constexpr static __host__ __device__ GenericVector constant(ElementType value) {
        GenericVector result;
        for (std::size_t k = 0; k < size; ++k) {
            result.values[k] = value;
        }
        return result;
    }

    /// Vector of all zeros (0.f converted to ElementType).
    constexpr static __host__ __device__ GenericVector zeros() {
        return constant(0.f);
    }

    /// Vector of all ones (1.f converted to ElementType).
    constexpr static __host__ __device__ GenericVector ones() {
        return constant(1.f);
    }

    /// Unchecked element access.
    constexpr __host__ __device__ ElementType& operator[](int index) {
        return values[index];
    }

    /// Unchecked element access.
    constexpr __host__ __device__ const ElementType& operator[](int index) const {
        return values[index];
    }

    static constexpr const std::size_t size = ElementCount;
    static constexpr const std::size_t bytes = ElementCount * sizeof(ElementType);

    /// Reads `size` consecutive elements starting at `address`.
    static __host__ __device__ GenericVector load(const ElementType* address) {
        GenericVector result;
        memcpy_aligned<size, detail::TransferMode::DEFAULT>(result.values, address);
        return result;
    }

    /// Like load(), but every chunk is read through `__ldg` (device only).
    static __device__ GenericVector load_ldg(const ElementType* address) {
        GenericVector result;
        memcpy_aligned<size, detail::TransferMode::LDG>(result.values, address);
        return result;
    }

    /// Writes the `size` elements to `dst`. The vector itself is not modified, so this is const and usable on
    /// const vectors / references.
    __host__ __device__ void store(ElementType* dst) const {
        memcpy_aligned<size, detail::TransferMode::DEFAULT>(dst, values);
    }

private:
    ElementType values[size];
};


/// Problem dimensions passed from the host launcher to the attention kernel.
/// Field order matters: hogwild_attention_tpl builds it by aggregate initialization as {F, W, Hq, Hkv, E, Ev, S}.
struct Shape {
    int F;          // number of KV fragments
    int W;          // number of workers
    int Hq;         // number of query heads
    int Hkv;        // number of key/value heads
    int E;          // query/key head dimension
    int Ev;         // value head dimension
    int S;          // query sequence length (query positions per worker)
};

// Short alias for the cooperative groups API (this_thread_block, tiled_partition, reduce) used by the kernel.
namespace cg = cooperative_groups;

// One record per (worker, query head, query position) in the split-k workspace of the attention kernel.
// It chains the thread blocks of the different splits:
//   lse     - log-sum-exp of everything merged so far (written by each split, read by the next one)
//   counter - index of the split whose turn it is to merge; the workspace must start zeroed, and the last
//             split resets it to 0 afterwards.
struct BlockResult {
    float lse;
    cuda::atomic<int, cuda::thread_scope_device> counter;
};

// Number of lanes in the attention kernel's sub-warp tiles; one sub-warp cooperates on one key position at a time.
constexpr const int SubWarpSize = 16;

/// Hogwild! attention: every (worker w, query position s) attends to all F KV fragments, with the key positions
/// split across `gridDim.z` thread blocks (split-k).
///
/// Launch layout (as set up by hogwild_attention_gpu):
///   blockIdx.x -> kv head `hkv`; the block serves the query heads [hkv * GQA, (hkv + 1) * GQA)
///   blockIdx.y -> w + W * s
///   blockIdx.z -> split index; gridDim.z = number of splits
///   384 threads (see __launch_bounds__), viewed both as 16-lane sub-warps and as 32-lane warps
///   dynamic shared memory `scratch`: Ev floats per sub-warp (also reused for per-sub-warp (max, lse) pairs)
///
/// Algorithm per query head:
///   1. each sub-warp keeps a running maximum, lse and weighted value sum over its key positions (online softmax);
///   2. the sub-warps of the block are combined through `scratch`;
///   3. the blocks of the different splits merge one after another through `workspace`. Split k spins until
///      BlockResult::counter == k, merges into the float accumulator and bumps the counter. The last split
///      converts to scalar_t, writes `out` and resets the counter to 0.
/// NOTE(review): step 3 assumes blocks with a lower split index can always make progress while higher splits spin;
/// CUDA does not formally guarantee this scheduling order.
/// NOTE(review): gridDim.z must be >= 2. With a single split, split 0 only writes the float accumulator and
/// never produces `out`.
template<int E, int Ev, int GQA, class scalar_t>
__global__ __launch_bounds__(384) void hogwild_attention_gpu_kernel12(
        scalar_t* out, char* workspace, float scale,
        const int* locations, const scalar_t* queries,
        const int* fragment_lengths,
        const scalar_t* const* key_fragments,
        const scalar_t* const* value_fragments,
        Shape shape) {
    // Input:   keys: [Hkv, fragment_lengths[i], E] for i in [F]
    //          values: [Hkv, fragment_lengths[i], Ev] for i in [F]
    //          fragment_lengths: [F]
    //          queries: [F, W, Hq, S, E]
    //          locations [F, W, S]
    // Scratch: workspace [W, Hq, S, Ev] (in float32, iff scalar_t != float32) + [W, Hq, S] BlockResult
    // Output:  [W, Hq, S, Ev]
    // attention mask: within fragment f, query s of worker w attends to key position l iff
    //                 locations[f, w, s] >= l (i.e., shifted causal masking)

    int W = shape.W;
    int Hq = shape.Hq;
    int S = shape.S;
    assert(E == shape.E);
    assert(Ev == shape.Ev);

    auto block = cg::this_thread_block();
    auto warp = cg::tiled_partition<32>(block);
    auto sub_warp = cg::tiled_partition<SubWarpSize>(block);

    ptrdiff_t q_stride = E * S * Hq * W;   // distance between consecutive fragments in `queries`
    extern __shared__ float scratch[];     // sized by the launcher as Ev floats per sub-warp

    int hkv = blockIdx.x;
    int w = blockIdx.y % W;
    int s = blockIdx.y / W;
    int split = blockIdx.z;
    int splits = gridDim.z;

    int hq = hkv * GQA;     // first query head served by this kv head
    ptrdiff_t q_offset = ((w * Hq + hq) * S + s) * E;

    constexpr const int VecSize = 4;
    constexpr int VPH_k = E / (SubWarpSize * VecSize);   // key vectors per head per thread
    constexpr int VPH_v = Ev / (SubWarpSize * VecSize);   // value vectors per head per thread

    using vec_t = GenericVector<scalar_t, VecSize>;
    using fvec_t = GenericVector<float, VecSize>;
    // each lane holds its E / SubWarpSize query elements for each of the GQA query heads
    using q_cache_t = GenericVector<scalar_t, E / SubWarpSize>;
    q_cache_t q_cache[GQA];

    // combine values
    // per-lane running (unnormalized) weighted value sum, maximum score, and sum of exponentials per query head
    using v_cache_t = GenericVector<float, Ev / SubWarpSize>;
    v_cache_t v_cache[GQA];
    float maximum[GQA];
    for (int gqa = 0; gqa < GQA; ++gqa) {
        v_cache[gqa] = v_cache_t::zeros();
        maximum[gqa] = std::numeric_limits<float>::lowest();
    }

    // determine maximum and online logsumexp
    float lse[GQA] = {};
    {
        for (int f = 0; f < shape.F; ++f) {
            int q_loc = locations[(f * W + w) * S + s];
            int L = fragment_lengths[f];
            int maxL = std::min(L, q_loc + 1);   // causal limit: only positions l <= q_loc are visible

            // (re)load the queries of this fragment for all GQA heads
            for (int gqa = 0; gqa < GQA; ++gqa) {
                for (int ee = 0; ee < E / (SubWarpSize * VecSize); ++ee) {
                    int e = (ee * SubWarpSize + sub_warp.thread_rank()) * VecSize;
                    vec_t qv = vec_t::load(queries + f * q_stride + q_offset + gqa * S * E + e);
                    for (int j = 0; j < VecSize; ++j) {
                        q_cache[gqa][ee * VecSize + j] = qv[j];
                    }
                }
            }

            const scalar_t *value_fragment = value_fragments[f];
            const scalar_t *key_fragment = key_fragments[f];

            // sub-warp r of this split handles positions r * splits + split, r * splits + split + (#sub-warps * splits), ...
            for (int l = sub_warp.meta_group_rank() * splits + split;
                 l < maxL; l += sub_warp.meta_group_size() * splits) {
                ptrdiff_t k_offset = (hkv * L + l) * E;
                ptrdiff_t v_offset = (hkv * L + l) * Ev;

                vec_t keys[VPH_k];
                vec_t vals[VPH_v];
                #pragma unroll
                for (int ee = 0; ee < VPH_k; ++ee) {
                    int e = (ee * SubWarpSize + sub_warp.thread_rank()) * VecSize;
                    keys[ee] = vec_t::load_ldg(key_fragment + k_offset + e);
                }

                #pragma unroll
                for (int ee = 0; ee < VPH_v; ++ee) {
                    int e = (ee * SubWarpSize + sub_warp.thread_rank()) * VecSize;
                    vals[ee] = vec_t::load_ldg(value_fragment + v_offset + e);
                }

                #pragma unroll
                for (int gqa = 0; gqa < GQA; ++gqa) {
                    float qk = 0;
                    for (int ee = 0; ee < VPH_k; ++ee) {
                        vec_t kv = keys[ee];
                        for (int j = 0; j < VecSize; ++j) {
                            qk += (float) q_cache[gqa][ee * VecSize + j] * (float) kv[j];
                        }
                    }
                    // complete the dot product by summing the partial products of all sub-warp lanes
                    qk = cg::reduce(sub_warp, qk, cg::plus<float>{});
                    if (qk > maximum[gqa]) {
                        // new maximum: rescale the running sums so they stay relative to it
                        float rescale = std::exp(scale * (maximum[gqa] - qk));
                        for (int j = 0; j < v_cache_t::size; ++j) {
                            v_cache[gqa][j] *= rescale;
                        }
                        lse[gqa] *= rescale;
                        maximum[gqa] = qk;
                    }
                    float att = std::exp(scale * (qk - maximum[gqa]));
                    // NOTE(review): this recomputes exactly `att`; it could be reused
                    lse[gqa] += std::exp(scale * (qk - maximum[gqa]));

                    for (int ee = 0; ee < VPH_v; ++ee) {
                        vec_t vv = vals[ee];
                        for (int j = 0; j < VecSize; ++j) {
                            v_cache[gqa][ee * VecSize + j] += att * (float) vv[j];
                        }
                    }
                }
            }
        }
    }

    // Per query head: combine the sub-warps of this block, then merge with the other splits via `workspace`.
    #pragma unroll
    for (int gqa = 0; gqa < GQA; ++gqa) {
        using m_lse_t = GenericVector<float, 2>;
        m_lse_t data;
        int h = hkv * GQA + gqa;
        // combine split-k results
        // lane 0 of each sub-warp publishes its (maximum, lse) pair
        if (sub_warp.thread_rank() == 0) {
            data[0] = maximum[gqa];
            data[1] = lse[gqa];
            data.store(scratch + 2 * sub_warp.meta_group_rank());
        }

        __syncthreads();
        // lane i < #sub-warps of every warp picks up sub-warp i's pair; all warps compute the same block result
        float r_max = maximum[gqa];
        float l_max = maximum[gqa];
        float r_lse = 0;
        if (warp.thread_rank() < sub_warp.meta_group_size()) {
            data = m_lse_t::load(scratch + 2*warp.thread_rank());
            r_max = data[0];
            r_lse = data[1];
        }

        maximum[gqa] = cg::reduce(warp, r_max, cg::greater<float>{});
        r_lse *= std::exp(scale * (r_max - maximum[gqa]));
        lse[gqa] = cg::reduce(warp, r_lse, cg::plus<float>{});
        // Note: It *is* possible that no thread in this warp had any valid position (due to causal masking),
        // which would lead to division by zero -> 0 * inf = NaN here.
        // Otherwise: bring this lane's value sum to the block maximum and normalize by the block lse.
        if(lse[gqa] != 0) {
            float rescale = std::exp(scale * (l_max - maximum[gqa])) / lse[gqa];
            for (int j = 0; j < v_cache_t::size; ++j) {
                v_cache[gqa][j] *= rescale;
            }
        }
        __syncthreads();

        // each sub-warp parks its normalized value sum in its own Ev-float row of `scratch`
        for (int ee = 0; ee < Ev / (SubWarpSize * VecSize); ++ee) {
            int e = (ee * SubWarpSize + sub_warp.thread_rank()) * VecSize;
            fvec_t store;
            for (int j = 0; j < VecSize; ++j) {
                store[j] = v_cache[gqa][ee * VecSize + j];
            }
            store.store(scratch + e + Ev * sub_warp.meta_group_rank());
        }

        int res_idx = ((w * Hq + h) * S + s);
        float *global_accumulator = nullptr;
        BlockResult *my_result_info = nullptr;
        // only warp 0 performs the cross-block merge; the workspace layout must match hogwild_attention_gpu
        if (warp.meta_group_rank() == 0) {
            if constexpr (std::is_same_v<scalar_t, float>) {
                global_accumulator = out;
                my_result_info = reinterpret_cast<BlockResult *>(workspace) + res_idx;
            } else {
                global_accumulator = reinterpret_cast<float *>(workspace);
                float *acc_end = global_accumulator + W * Hq * S * Ev;
                my_result_info = reinterpret_cast<BlockResult *>(acc_end) + res_idx;
            }
        }

        __syncthreads();
        if (warp.meta_group_rank() == 0) {
            // only first block doesn't need to wait
            if (split != 0) {
                // wait in a single thread until our block has its turn
                if (threadIdx.x == 0) {
                    while (my_result_info->counter.load() != split) {
                    }
                }
                __syncwarp();
            }
            __threadfence();    // make sure the results written by the previous block are visible.
            float res_lse = (split != 0) ? my_result_info->lse : std::numeric_limits<float>::lowest();
            // natural-log lse of this block including the maximum: log(sum exp(scale * qk))
            lse[gqa] = std::log(lse[gqa]) + scale * maximum[gqa];
            float max = std::max(lse[gqa], res_lse);
            // relative weights of this block (sa) and of everything merged so far (sb)
            float sa = std::exp(lse[gqa] - max);
            float sb = std::exp(res_lse - max);

            // write result
            for (int e = VecSize * warp.thread_rank(); e < Ev; e += VecSize * warp.size()) {
                // merge the local results
                fvec_t res = fvec_t::zeros();
                for (int j = 0; j < sub_warp.meta_group_size(); ++j) {
                    fvec_t sv = fvec_t::load(scratch + e + Ev * j);
                    for (int jj = 0; jj < VecSize; ++jj) {
                        res[jj] += sv[jj];
                    }
                }

                if (split == 0) {
                    // first one just writes
                    res.store(global_accumulator + res_idx * Ev + e);
                } else if (split == splits - 1) {
                    // last one does final conversion
                    fvec_t old = fvec_t::load(global_accumulator + res_idx * Ev + e);
                    vec_t cv;
                    for (int j = 0; j < VecSize; ++j) {
                        cv[j] = (scalar_t) ((res[j] * sa + old[j] * sb) / (sa + sb));
                    }
                    cv.store(out + res_idx * Ev + e);
                } else {
                    // everyone else merges
                    fvec_t old = fvec_t::load(global_accumulator + res_idx * Ev + e);
                    for (int j = 0; j < VecSize; ++j) {
                        res[j] = (res[j] * sa + old[j] * sb) / (sa + sb);
                    }
                    res.store(global_accumulator + res_idx * Ev + e);
                }
            }
            // lse of everything merged up to and including this split
            if (threadIdx.x == 0) {
                my_result_info->lse = std::log(sa + sb) + max;
            }
            __syncwarp();
            __threadfence();  // ensure everything is written before we signal
            if (threadIdx.x == 0) {
                if(split != splits - 1) {
                    // signal the next block to go
                    my_result_info->counter.store(split + 1);
                } else {
                    // last block resets memory
                    my_result_info->counter.store(0);
                }
            }
        }
        __syncthreads();
    }
}

/// Host-side launcher for hogwild_attention_gpu_kernel12.
///
/// All pointers are device pointers:
///   out [W, Hq, S, Ev], queries [F, W, Hq, S, E], locations [F, W, S], fragment_lengths [F], and
///   key_fragments / value_fragments are device arrays of F pointers to [Hkv, L_f, E] / [Hkv, L_f, Ev] buffers.
///
/// The launch is enqueued on the current PyTorch CUDA stream of the current device. Only E == Ev == 128 with
/// Hq == 5 * Hkv is implemented; any other configuration raises a c10::Error. CUDA API failures raise as well.
///
/// NOTE: the split-k workspace is a process-wide static cache per scalar_t. It is not thread safe, assumes a
/// single device, and must not be used concurrently from different streams.
template<class scalar_t>
void hogwild_attention_gpu(scalar_t* out, float scale,
                           const int* locations, const scalar_t* queries,
                           const int* fragment_lengths,
                           const scalar_t** key_fragments,
                           const scalar_t** value_fragments,
                           Shape shape) {
    constexpr int kHeadDim = 128;
    constexpr int kGQA = 5;
    constexpr unsigned kBlockSize = 384;   // must match __launch_bounds__ of the kernel

    // Fail loudly instead of silently leaving `out` unwritten.
    TORCH_CHECK(shape.E == kHeadDim && shape.Ev == kHeadDim && shape.Hq == shape.Hkv * kGQA,
                "hogwild_attention: unsupported configuration E=", shape.E, ", Ev=", shape.Ev,
                ", Hq=", shape.Hq, ", Hkv=", shape.Hkv, "; only E == Ev == ", kHeadDim,
                " with Hq == ", kGQA, " * Hkv is implemented");

    const int problem_size = shape.Hkv * shape.W * shape.S;
    if (problem_size == 0) {
        // Empty output: nothing to compute. This also avoids the division by zero below and an invalid empty grid.
        return;
    }
    TORCH_CHECK(static_cast<long long>(shape.W) * shape.S <= 65535,
                "hogwild_attention: W * S = ", static_cast<long long>(shape.W) * shape.S,
                " exceeds the maximum grid y-dimension of 65535");

    // Enqueue on PyTorch's current stream so we are ordered after the work that produced our inputs
    // (PyTorch's side streams do not synchronize with the legacy default stream).
    const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

    int device = -1;
    C10_CUDA_CHECK(cudaGetDevice(&device));
    int sms = -1;
    C10_CUDA_CHECK(cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, device));
    // Note: The kernel will **not** work with a single split; split 0 alone never writes `out`.
    const int splits = std::max(2, sms / problem_size);

    const dim3 grid_dim{(unsigned)shape.Hkv, (unsigned)shape.W * (unsigned)shape.S, (unsigned)splits};
    const dim3 block_dim{kBlockSize, 1, 1};
    // one Ev-float row per 16-lane sub-warp
    const size_t smem = shape.Ev * sizeof(float) * block_dim.x / SubWarpSize;

    // Split-k workspace layout (must match the kernel):
    //   scalar_t != float: [W, Hq, S, Ev] float accumulator, followed by [W, Hq, S] BlockResult
    //   scalar_t == float: [W, Hq, S] BlockResult only (the accumulator lives in `out`)
    // The kernel expects every BlockResult::counter to be 0 on entry; the last split resets it afterwards.
    static char* workspace = nullptr;
    static std::size_t workspace_bytes = 0;
    static std::size_t counters_offset = 0;   // byte offset of the BlockResult array the buffer was zeroed for

    const std::size_t num_results = static_cast<std::size_t>(shape.W) * shape.Hq * shape.S;
    std::size_t accumulator_bytes = 0;
    if constexpr (!std::is_same_v<scalar_t, float>) {
        accumulator_bytes = sizeof(float) * num_results * shape.Ev;
    }
    const std::size_t required_bytes = accumulator_bytes + sizeof(BlockResult) * num_results;

    // If the accumulator size changed, the BlockResult array starts at a different offset and would overlap
    // stale accumulator data, so the counters must be zeroed again.
    bool needs_reset = counters_offset != accumulator_bytes;
    if (workspace_bytes < required_bytes) {
        if (workspace != nullptr) {
            // cudaFree implicitly synchronizes the device, so no in-flight kernel still uses the old buffer.
            C10_CUDA_CHECK(cudaFree(workspace));
            workspace = nullptr;
            workspace_bytes = 0;
        }
        C10_CUDA_CHECK(cudaMalloc(&workspace, required_bytes));
        workspace_bytes = required_bytes;
        needs_reset = true;
    }
    if (needs_reset) {
        C10_CUDA_CHECK(cudaMemsetAsync(workspace, 0, required_bytes, stream));
        counters_offset = accumulator_bytes;
    }

    hogwild_attention_gpu_kernel12<kHeadDim, kHeadDim, kGQA><<<grid_dim, block_dim, smem, stream>>>(
            out, workspace, scale, locations, queries, fragment_lengths, key_fragments, value_fragments, shape);
}

/// Rotary position embedding of the (shared) queries, once per fragment.
///
/// Launch layout (see rope_gpu): blockIdx.x = f * S + s, blockIdx.y = h, blockIdx.z = w, blockDim.x = E / 2.
/// queries: [W, Hq, S, E]; cosines / sines: [F, W, S, E] (float); rotated_queries: [F, W, Hq, S, E].
/// Thread e rotates the pair (e, e + E/2), the "rotate-half" layout:
///   out[e]       = x[e]       * cos[e]       - x[e + E/2] * sin[e]
///   out[e + E/2] = x[e + E/2] * cos[e + E/2] + x[e]       * sin[e + E/2]
/// `F` is not read; the fragment index comes from blockIdx.x.
template<class scalar_t>
__global__ void rope_kernel(
        scalar_t* rotated_queries, const scalar_t* queries, const float* cosines, const float* sines,
        int F, int W, int Hq, int S, int E)
{
    int f = blockIdx.x / S;
    int s = blockIdx.x % S;
    int h = blockIdx.y;
    int w = blockIdx.z;

    const scalar_t* query = queries + ((w * Hq + h) * S + s) * E;
    scalar_t* result = rotated_queries + (((f * W + w) * Hq + h) * S + s) * E;
    int e = threadIdx.x;
    float x1 = query[e];
    float x2 = query[e + E/2];

    // fetch a tuple of activations, which we imagine as a complex number
    int offset = (((f*W + w) * S + s) * E);

    result[e] = x1 * cosines[offset + e] - x2 * sines[offset + e];
    result[e + E/2] = x2 * cosines[offset + e + E/2] + x1 * sines[offset + e + E/2];
}

/// Launches rope_kernel on the current PyTorch CUDA stream (the same stream hogwild_attention_gpu uses, so a
/// subsequent attention call sees the rotated queries). Expects E to be even and all dimensions non-zero.
template<class scalar_t>
void rope_gpu(
        scalar_t* rotated_queries, const scalar_t* queries, const float* cosines, const float* sines,
        int F, int W, int Hq, int S, int E) {
    const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
    dim3 grid_dim(F * S, Hq, W);
    dim3 block_dim(E / 2, 1, 1);
    rope_kernel<<<grid_dim, block_dim, 0, stream>>>(rotated_queries, queries, cosines, sines, F, W, Hq, S, E);
}

/// Typed const data pointer of `tensor` for float / half / nv_bfloat16 (nullptr for other scalar_t).
/// The underlying const_data_ptr<> call raises if the tensor's dtype does not match.
template<class scalar_t>
const scalar_t* torch_get_pointer(const at::Tensor& tensor) {
    if constexpr (std::is_same_v<scalar_t, float>) {
        return tensor.const_data_ptr<float>();
    } else if constexpr (std::is_same_v<scalar_t, half>) {
        return reinterpret_cast<const half*>(tensor.const_data_ptr<at::Half>());
    } else if constexpr (std::is_same_v<scalar_t, nv_bfloat16>) {
        return reinterpret_cast<const nv_bfloat16*>(tensor.const_data_ptr<at::BFloat16>());
    } else {
        return nullptr;
    }
}

/// Typed mutable data pointer of `tensor` for float / half / nv_bfloat16 (nullptr for other scalar_t).
/// The underlying data_ptr<> call raises if the tensor's dtype does not match.
template<class scalar_t>
scalar_t* torch_get_pointer(at::Tensor& tensor) {
    if constexpr (std::is_same_v<scalar_t, float>) {
        return tensor.data_ptr<float>();
    } else if constexpr (std::is_same_v<scalar_t, half>) {
        return reinterpret_cast<half*>(tensor.data_ptr<at::Half>());
    } else if constexpr (std::is_same_v<scalar_t, nv_bfloat16>) {
        return reinterpret_cast<nv_bfloat16*>(tensor.data_ptr<at::BFloat16>());
    } else {
        return nullptr;
    }
}

/// Validates the tensors of a hogwild_sdpa call, uploads the fragment pointer table and launches the kernel.
///
/// out: [W, Hq, S, Ev]; locations: [F, W, S] int32; queries: [F, W, Hq, S, E];
/// fragment_lengths: [F] int32 (device); key_fragments[f] / value_fragments[f]: [(1,) Hkv, L_f, E] / [(1,) Hkv, L_f, Ev].
/// All shape/layout violations raise (c10::Error) instead of aborting the process.
/// NOTE(review): fragment_lengths[f] is read on the device and must equal L_f; it cannot be checked here
/// without a device synchronization.
template<class scalar_t>
void hogwild_attention_tpl(
        at::Tensor& out, double scale, const at::Tensor& locations, const at::Tensor& queries,
        const at::Tensor& fragment_lengths, const std::vector<at::Tensor>& key_fragments,
        const std::vector<at::Tensor>& value_fragments)
{
    // Custom ops get no automatic device guard: make `out`'s device current so the stream, allocations and the
    // launch all target the device the tensors live on.
    const c10::cuda::CUDAGuard device_guard(out.device());
    const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

    // out: [W, Hq, S, Ev]
    TORCH_CHECK(out.dim() == 4, "hogwild_sdpa: output must be [W, Hq, S, Ev], got ", out.sizes());
    TORCH_CHECK(out.is_contiguous());
    const int W = out.size(0);
    const int Hq = out.size(1);
    const int S = out.size(2);
    const int Ev = out.size(3);
    scalar_t* out_ptr = torch_get_pointer<scalar_t>(out);

    // locations: [F, W, S]
    TORCH_CHECK(locations.dim() == 3 && locations.size(1) == W && locations.size(2) == S,
                "hogwild_sdpa: locations must be [F, ", W, ", ", S, "], got ", locations.sizes());
    TORCH_CHECK(locations.is_contiguous());
    const int F = locations.size(0);
    TORCH_CHECK(F > 0, "hogwild_sdpa: at least one fragment is required");
    const int* loc_ptr = locations.const_data_ptr<int>();

    // queries: [F, W, Hq, S, E]
    TORCH_CHECK(queries.dim() == 5 && queries.size(0) == F && queries.size(1) == W &&
                queries.size(2) == Hq && queries.size(3) == S,
                "hogwild_sdpa: queries must be [", F, ", ", W, ", ", Hq, ", ", S, ", E], got ", queries.sizes());
    TORCH_CHECK(queries.is_contiguous() && "SDPA");
    const int E = queries.size(4);
    const scalar_t* query_ptr = torch_get_pointer<scalar_t>(queries);

    // fragment_lengths: [F]
    TORCH_CHECK(fragment_lengths.dim() == 1 && fragment_lengths.size(0) == F,
                "hogwild_sdpa: fragment_lengths must be [", F, "], got ", fragment_lengths.sizes());
    TORCH_CHECK(fragment_lengths.is_contiguous());
    const int* fl_ptr = fragment_lengths.const_data_ptr<int>();

    // check key and value fragments
    TORCH_CHECK(key_fragments.size() == static_cast<std::size_t>(F),
                "hogwild_sdpa: expected ", F, " key fragments, got ", key_fragments.size());
    TORCH_CHECK(value_fragments.size() == static_cast<std::size_t>(F),
                "hogwild_sdpa: expected ", F, " value fragments, got ", value_fragments.size());

    // Make exactly one cached device allocation to store the key and value pointers in.
    // NOTE: This is neither thread safe, nor will this memory ever be released again; it is also shared between
    // streams and devices.
    constexpr int kMaxFragmentPointers = 1024;
    TORCH_CHECK(2 * F <= kMaxFragmentPointers,
                "hogwild_sdpa: at most ", kMaxFragmentPointers / 2, " fragments are supported, got ", F);
    static const scalar_t** frag_ptrs = nullptr;
    if (frag_ptrs == nullptr) {
        C10_CUDA_CHECK(cudaMalloc(&frag_ptrs, sizeof(void*) * kMaxFragmentPointers));
    }

    // Fragments are [Hkv, L, E] or, with a leading size-1 batch dimension, [1, Hkv, L, E].
    const bool has_batch_dim = key_fragments[0].dim() == 4;
    const int fo = has_batch_dim ? 1 : 0;
    const int Hkv = key_fragments[0].size(fo);
    std::vector<const scalar_t*> frag_ptrs_host(2 * F);   // [keys..., values...]
    for (int f = 0; f < F; ++f) {
        const at::Tensor& key = key_fragments[f];
        const at::Tensor& value = value_fragments[f];
        TORCH_CHECK(key.dim() == 3 + fo && value.dim() == 3 + fo,
                    "hogwild_sdpa: all key/value fragments must have ", 3 + fo, " dimensions; fragment ", f,
                    " has key ", key.sizes(), " and value ", value.sizes());
        if (has_batch_dim) {
            // the kernel only reads the first batch entry
            TORCH_CHECK(key.size(0) == 1 && value.size(0) == 1,
                        "hogwild_sdpa: fragment batch dimension must be 1 (fragment ", f, ")");
        }
        TORCH_CHECK(key.size(fo) == Hkv && value.size(fo) == Hkv,
                    "hogwild_sdpa: fragment ", f, " must have ", Hkv, " kv heads");
        TORCH_CHECK(value.size(fo + 1) == key.size(fo + 1),
                    "hogwild_sdpa: key and value lengths differ in fragment ", f);
        TORCH_CHECK(key.size(fo + 2) == E && value.size(fo + 2) == Ev,
                    "hogwild_sdpa: fragment ", f, " head dimensions must be E=", E, " / Ev=", Ev);

        TORCH_CHECK(key.is_contiguous() && "SDPA");
        TORCH_CHECK(value.is_contiguous());

        frag_ptrs_host[f] = torch_get_pointer<scalar_t>(key);
        frag_ptrs_host[F + f] = torch_get_pointer<scalar_t>(value);
    }

    // Same stream as the kernel launch, so the kernel sees this call's pointer table.
    C10_CUDA_CHECK(cudaMemcpyAsync(frag_ptrs, frag_ptrs_host.data(), 2 * sizeof(void*) * F,
                                   cudaMemcpyHostToDevice, stream));

    // finally, launch
    Shape shape = {F, W, Hq, Hkv, E, Ev, S};
    hogwild_attention_gpu(out_ptr, (float)scale, loc_ptr, query_ptr, fl_ptr,
                          frag_ptrs, frag_ptrs + F, shape);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}

/// Validates the tensors of a hogwild_rope call and launches the rotary-embedding kernel.
///
/// out: [F, W, Hq, S, E]; queries: [W, Hq, S, E]; cosines / sines: [F, W, S, E] float32.
/// Shape violations raise (c10::Error) instead of aborting the process; an empty output is a no-op.
template<class scalar_t>
void hogwild_rope_tpl(
        at::Tensor& out, const at::Tensor& queries, const at::Tensor& cosines,
        const at::Tensor& sines)
{
    // Custom ops get no automatic device guard: run on the device that holds the output.
    const c10::cuda::CUDAGuard device_guard(out.device());

    // extract pointers and sizes
    TORCH_CHECK(out.dim() == 5, "hogwild_rope: output must be [F, W, Hq, S, E], got ", out.sizes());
    const int F = out.size(0);
    const int W = out.size(1);
    const int Hq = out.size(2);
    const int S = out.size(3);
    const int E = out.size(4);
    TORCH_CHECK(out.is_contiguous());
    TORCH_CHECK(queries.is_contiguous() && "ROPE");
    TORCH_CHECK(cosines.is_contiguous());
    TORCH_CHECK(sines.is_contiguous());

    // The kernel runs one thread per (e, e + E/2) pair, i.e. E/2 threads per block.
    TORCH_CHECK(E % 2 == 0 && E / 2 <= 1024,
                "hogwild_rope: head dimension must be even and at most 2048, got ", E);

    TORCH_CHECK(queries.dim() == 4 && queries.size(0) == W && queries.size(1) == Hq &&
                queries.size(2) == S && queries.size(3) == E,
                "hogwild_rope: queries must be [", W, ", ", Hq, ", ", S, ", ", E, "], got ", queries.sizes());

    const auto check_table = [&](const at::Tensor& table, const char* name) {
        TORCH_CHECK(table.dim() == 4 && table.size(0) == F && table.size(1) == W &&
                    table.size(2) == S && table.size(3) == E,
                    "hogwild_rope: ", name, " must be [", F, ", ", W, ", ", S, ", ", E, "], got ", table.sizes());
    };
    check_table(cosines, "cosines");
    check_table(sines, "sines");

    if (out.numel() == 0) {
        return;   // an empty grid/block would be an invalid launch configuration
    }

    rope_gpu(torch_get_pointer<scalar_t>(out), torch_get_pointer<scalar_t>(queries),
             torch_get_pointer<float>(cosines), torch_get_pointer<float>(sines),
             F, W, Hq, S, E);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}

/// CUDA implementation of libhogatt::hogwild_sdpa: dispatches on the dtype of `out` (float16, float32 or
/// bfloat16) to hogwild_attention_tpl. Any other dtype raises instead of silently leaving `out` unwritten.
void hogwild_attention(
        at::Tensor& out, double scale, const at::Tensor& locations, const at::Tensor& queries,
        const at::Tensor& fragment_lengths, const std::vector<at::Tensor>& key_fragments,
        const std::vector<at::Tensor>& value_fragments)
{
    switch (out.scalar_type()) {
        case at::ScalarType::Half:
            hogwild_attention_tpl<half>(out, scale, locations, queries, fragment_lengths, key_fragments, value_fragments);
            break;
        case at::ScalarType::Float:
            hogwild_attention_tpl<float>(out, scale, locations, queries, fragment_lengths, key_fragments, value_fragments);
            break;
        case at::ScalarType::BFloat16:
            hogwild_attention_tpl<nv_bfloat16>(out, scale, locations, queries, fragment_lengths, key_fragments, value_fragments);
            break;
        default:
            TORCH_CHECK(false, "hogwild_sdpa: unsupported output dtype ", out.scalar_type(),
                        " (expected Half, Float or BFloat16)");
    }
}

/// CUDA implementation of libhogatt::hogwild_rope: dispatches on the dtype of `out` (float16, float32 or
/// bfloat16) to hogwild_rope_tpl. Any other dtype raises instead of silently leaving `out` unwritten.
void hogwild_rope(
        at::Tensor& out, const at::Tensor& queries, const at::Tensor& cosines, const at::Tensor& sines)
{
    switch (out.scalar_type()) {
        case at::ScalarType::Half:
            hogwild_rope_tpl<half>(out, queries, cosines, sines);
            break;
        case at::ScalarType::Float:
            hogwild_rope_tpl<float>(out, queries, cosines, sines);
            break;
        case at::ScalarType::BFloat16:
            hogwild_rope_tpl<nv_bfloat16>(out, queries, cosines, sines);
            break;
        default:
            TORCH_CHECK(false, "hogwild_rope: unsupported output dtype ", out.scalar_type(),
                        " (expected Half, Float or BFloat16)");
    }
}

/// CUDA implementation of libhogatt::hogwild_fused: rotates `queries` into `rotated_queries` (hogwild_rope),
/// then runs hogwild attention on the rotated queries. Key/value fragments are made contiguous first, because
/// the attention kernel requires contiguous fragments.
void hogwild_fused(
        at::Tensor& out, at::Tensor& rotated_queries, double scale, const at::Tensor& locations, const at::Tensor& queries,
        const at::Tensor& fragment_lengths, const std::vector<at::Tensor>& key_fragments,
        const std::vector<at::Tensor>& value_fragments,
        const at::Tensor& cosines, const at::Tensor& sines)
{
    // Must be checked before the loop below indexes value_fragments with key_fragments' indices.
    TORCH_CHECK(key_fragments.size() == value_fragments.size(),
                "hogwild_fused: got ", key_fragments.size(), " key fragments but ",
                value_fragments.size(), " value fragments");

    std::vector<at::Tensor> key_fragments_contiguous;
    std::vector<at::Tensor> val_fragments_contiguous;
    key_fragments_contiguous.reserve(key_fragments.size());
    val_fragments_contiguous.reserve(value_fragments.size());
    for (std::size_t i = 0; i < key_fragments.size(); ++i) {
        key_fragments_contiguous.push_back(key_fragments[i].contiguous());
        val_fragments_contiguous.push_back(value_fragments[i].contiguous());
    }
    hogwild_rope(rotated_queries, queries, cosines, sines);
    hogwild_attention(out, scale, locations, rotated_queries, fragment_lengths, key_fragments_contiguous, val_fragments_contiguous);
}

extern "C" {
/* Creates a dummy empty _C module that can be imported from Python.
   The import from Python will load the .so consisting of this file
   in this extension, so that the TORCH_LIBRARY static initializers
   below are run. */
PyObject* PyInit_libhogatt(void) {
    static struct PyModuleDef module_def = {
            PyModuleDef_HEAD_INIT,
            "libhogatt", /* name of module */
            NULL,            /* module documentation, may be NULL */
            -1,              /* size of per-interpreter state of the module,
                     or -1 if the module keeps state in global variables. */
            NULL,            /* methods */
    };
    return PyModule_Create(&module_def);
}
}

// Declare the operator schemas of the libhogatt namespace. Each op writes into caller-provided, mutable
// output tensors (annotated `(a!)` / `(b!)`). All ops are tagged needs_fixed_stride_order.
TORCH_LIBRARY(libhogatt, m) {
    const std::vector<at::Tag> tags{at::Tag::needs_fixed_stride_order};
    constexpr auto mode = torch::_RegisterOrVerify::REGISTER;

    m.def("hogwild_sdpa(Tensor(a!) output, float scale, Tensor locations, Tensor queries, "
          "Tensor fragment_lengths, Tensor[] key_fragments, Tensor[] value_fragments) -> ()",
          tags, mode);
    m.def("hogwild_rope(Tensor(a!) output, Tensor queries, Tensor cosines, Tensor sines) -> ()",
          tags, mode);
    m.def("hogwild_fused(Tensor(a!) output, Tensor(b!) rq, float scale, Tensor locations, Tensor queries, "
          "Tensor fragment_lengths, Tensor[] key_fragments, Tensor[] value_fragments, Tensor cosines, Tensor sines) -> ()",
          tags, mode);
}

// Bind the CUDA-dispatch-key implementations to the schemas declared in TORCH_LIBRARY(libhogatt, ...).
TORCH_LIBRARY_IMPL(libhogatt, CUDA, m) {
    m.impl("hogwild_sdpa", hogwild_attention);
    m.impl("hogwild_rope", hogwild_rope);
    m.impl("hogwild_fused", hogwild_fused);
}
