
/**
 * @file
 * @brief Functions for a warpgroup to collaboratively transfer data directly between global memory and registers, in both directions.
 */

/**
 * @brief Collaboratively loads data into register vectors from a source array in global memory.
 *
 * Each warp of the warpgroup loads its own register vector from a separate
 * slice of the source array: warp `w` (as reported by `warpid(threadIdx)`)
 * starts reading at element `w * RV::outer_dim * kittens::ore::TILE_DIM`.
 * The load itself is delegated to the warp-level ::kittens::ore::load.
 *
 * @tparam RV The register vector type.
 * @tparam U The data type of the source array.
 * @param[out] dst The destination register vector to load data into.
 * @param[in] _src Base of the source array in device (global) memory, covering the
 *                 data for the whole warpgroup rather than just the calling warp's slice.
 * @param[in] threadIdx Thread index used to derive the warp index (via `warpid`) and the
 *                      lane index (via `simd_laneid`). Presumably the thread's index within
 *                      the threadgroup — TODO confirm against callers.
 */
template<typename RV, typename U>
METAL_FUNC static typename metal::enable_if<ducks::is_register_vector<RV>(), void>::type
load(thread RV &dst, device const U *_src, const int threadIdx) {
    // Advance to this warp's slice of the source so the warpgroup load reduces to a
    // single-warp load on a smaller array. The stride RV::outer_dim * TILE_DIM is
    // presumably the element length of one RV — TODO confirm.
    device const U *src = &_src[warpid(threadIdx) * RV::outer_dim * kittens::ore::TILE_DIM];

    // Call warp-level load.
    ::kittens::ore::load<RV, U>(dst, src, simd_laneid(threadIdx));
}

/**
 * @brief Collaboratively stores data from register vectors to a destination array in global memory.
 *
 * Each warp of the warpgroup stores its own register vector into a separate
 * slice of the destination array: warp `w` (as reported by `warpid(threadIdx)`)
 * starts writing at element `w * RV::outer_dim * kittens::ore::TILE_DIM`.
 * The store itself is delegated to the warp-level ::kittens::ore::store.
 *
 * @tparam RV The register vector type.
 * @tparam U The data type of the destination array.
 * @param[out] _dst Base of the destination array in device (global) memory, covering the
 *                  data for the whole warpgroup rather than just the calling warp's slice.
 * @param[in] src The source register vector to store data from.
 * @param[in] threadIdx Thread index used to derive the warp index (via `warpid`) and the
 *                      lane index (via `simd_laneid`). Presumably the thread's index within
 *                      the threadgroup — TODO confirm against callers.
 */
template<typename RV, typename U>
METAL_FUNC static typename metal::enable_if<ducks::is_register_vector<RV>(), void>::type
store(device U *_dst, thread const RV &src, const int threadIdx) {
    // Advance to this warp's slice of the destination so the warpgroup store reduces to a
    // single-warp store on a smaller array. The stride RV::outer_dim * TILE_DIM is
    // presumably the element length of one RV — TODO confirm.
    device U *dst = &_dst[warpid(threadIdx) * RV::outer_dim * kittens::ore::TILE_DIM];

    // Call warp-level store.
    ::kittens::ore::store<RV, U>(dst, src, simd_laneid(threadIdx));
}
