/**
 * @file
 * @brief Group (collaborative warp) ops for loading shared vectors from and storing to global memory.
 */

/**
 * @brief Collaboratively loads data from global memory into a shared memory vector.
 *
 * Copies `SV::length` elements starting at `src` into the shared vector `dst`, moving
 * `sizeof(float4)` bytes per access (`sizeof(float4) / sizeof(SV::dtype)` elements per transfer).
 * Transfers are distributed across the group: the thread with rank `r = threadIdx % GROUP_THREADS`
 * handles transfers r, r + GROUP_THREADS, r + 2*GROUP_THREADS, ..., so adjacent ranks touch
 * adjacent 16-byte chunks.
 *
 * @tparam SV Shared vector type; must satisfy ducks::is_shared_vector<SV>().
 * @param dst Reference to the shared vector where the data will be loaded.
 * @param src Pointer to the global memory location from where the data will be loaded.
 *            NOTE(review): reinterpreted as float4, so presumably it must be 16-byte aligned — confirm with callers.
 * @param threadIdx Index of the calling thread; reduced modulo GROUP_THREADS to get its rank within the group.
 */
template<typename SV>
METAL_FUNC static typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
load(threadgroup SV &dst, device const typename SV::dtype *src, const int threadIdx) {
    constexpr int elem_per_transfer = sizeof(float4) / sizeof(typename SV::dtype);
    constexpr int total_calls = SV::length / elem_per_transfer; // guaranteed to divide
    // Synchronizes only the calling SIMD-group; mem_none means no memory fence is implied.
    metal::simdgroup_barrier(metal::mem_flags::mem_none);
    // Start from the rank within the whole group (matching `store`), not the SIMD lane id:
    // with a GROUP_THREADS stride, a per-SIMD-group lane id would skip transfers whenever
    // the group spans more than one SIMD-group.
    #pragma clang loop unroll(full)
    for(int i = threadIdx % GROUP_THREADS; i < total_calls; i += GROUP_THREADS) {
        if(i * elem_per_transfer < dst.length)
            *(threadgroup float4*)&dst[i*elem_per_transfer] = *(device const float4*)&src[i*elem_per_transfer];
    }
}
 
/**
 * @brief Collaboratively stores data from a shared memory vector to global memory.
 *
 * Copies `SV::length` elements of the shared vector `src` to global memory starting at `dst`,
 * moving `sizeof(float4)` bytes per access (`sizeof(float4) / sizeof(SV::dtype)` elements per transfer).
 * Transfers are distributed across the group: the thread with rank `r = threadIdx % GROUP_THREADS`
 * handles transfers r, r + GROUP_THREADS, r + 2*GROUP_THREADS, ..., so adjacent ranks touch
 * adjacent 16-byte chunks.
 *
 * @tparam SV Shared vector type; must satisfy ducks::is_shared_vector<SV>().
 * @param dst Pointer to the global memory location where the data will be stored.
 *            NOTE(review): reinterpreted as float4, so presumably it must be 16-byte aligned — confirm with callers.
 * @param src Reference to the shared vector from where the data will be stored.
 * @param threadIdx Index of the calling thread; reduced modulo GROUP_THREADS to get its rank within the group.
 */
template<typename SV>
METAL_FUNC static typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
store(device typename SV::dtype *dst, threadgroup const SV &src, const int threadIdx) {
    constexpr int elem_per_transfer = sizeof(float4) / sizeof(typename SV::dtype);
    constexpr int total_calls = SV::length / elem_per_transfer; // guaranteed to divide
    // Synchronizes only the calling SIMD-group; mem_none means no memory fence is implied.
    metal::simdgroup_barrier(metal::mem_flags::mem_none);
    // Mirror image of `load`: same rank/stride distribution, source and destination swapped.
    #pragma clang loop unroll(full)
    for(int i = threadIdx % GROUP_THREADS; i < total_calls; i += GROUP_THREADS) {
        if(i * elem_per_transfer < src.length)
            *(device float4*)&dst[i*elem_per_transfer] = *(threadgroup const float4*)&src[i*elem_per_transfer];
    }
}
