/**
 * @file
 * @brief Warp-scope maps on shared vectors.
 */

#pragma once

#include "../../../../common/common.metal"
#include "../../../../types/types.metal"

namespace kittens {
namespace ore {

/**
 * @brief Reduces all elements of a shared (threadgroup) vector to a single value across a SIMD-group.
 *
 * Lane `laneid` first folds the elements at indices `laneid`, `laneid + SIMD_THREADS`, ...
 * into a per-lane partial using `op`. The partials of the lanes that own data (the first
 * min(length, SIMD_THREADS) lanes) are then combined with a shuffle-down tree. Lane 0's
 * result is broadcast so every lane of the SIMD-group receives the same value in `dst_accum`.
 *
 * Lanes whose index is >= the vector length own no data. They still take part in every
 * shuffle, but their partials are never folded into the result. As a result the reduction is
 * correct for any vector length >= 1 and for any `op`, with no identity element required.
 *
 * @tparam op    Reduction operation; must provide a static `op<T>(a, b)` member template.
 * @tparam SV    Shared vector type; must satisfy `ducks::is_shared_vector`.
 * @tparam reset If true, `src_accum` is ignored and the result is the reduction of `src` alone.
 *               If false, the reduction of `src` is additionally combined with `src_accum`.
 * @param[out] dst_accum Result of the reduction, identical on every lane of the SIMD-group.
 * @param[in]  src       The shared vector to reduce.
 * @param[in]  src_accum Value combined into the result when `reset` is false; unused otherwise.
 * @param[in]  laneid    Lane index of the calling thread within its SIMD-group; selects which
 *                       elements of `src` this thread loads.
 */
template<typename op, typename SV, bool reset>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
reduce(thread typename SV::dtype &dst_accum, threadgroup const SV &src, thread const typename SV::dtype &src_accum, const ushort laneid) {
    using T = typename SV::dtype;
    static_assert(SV::length > 0, "cannot reduce an empty shared vector");

    // Number of lanes that own at least one element of src.
    const int active_lanes = SV::length < kittens::ore::SIMD_THREADS ? SV::length : kittens::ore::SIMD_THREADS;

    // Lanes past the end of src load element 0 only so that accum is never
    // uninitialized; their value is never folded into the result (see below).
    T accum = src[laneid < SV::length ? laneid : 0];

    metal::simdgroup_barrier(metal::mem_flags::mem_none);
    // Strided per-lane accumulation of the elements beyond the first SIMD_THREADS.
    for(int i = laneid + kittens::ore::SIMD_THREADS; i < SV::length; i += kittens::ore::SIMD_THREADS) {
        accum = op::template op<T>(accum, src[i]);
    }
    metal::simdgroup_barrier(metal::mem_flags::mem_none);

    // Shuffle-down tree over the active lanes. After the step with offset d, lane i holds
    // the reduction of lanes [i, i + 2d) clipped to [0, active_lanes), so lane 0 ends up
    // holding the reduction of every element.
    for(int offset = 1; offset < active_lanes; offset <<= 1) {
        // Every lane must participate in the shuffle, but a lane only folds in its
        // neighbour's partial when that neighbour actually owns data. This replaces
        // zero-padding, which is wrong for max/min/mul and missed uninitialized lanes.
        const T other = shfl_down_sync<T>(accum, offset);
        if(laneid + offset < active_lanes) accum = op::template op<T>(accum, other);
        metal::simdgroup_barrier(metal::mem_flags::mem_none);
    }

    if(!reset) accum = op::template op<T>(accum, src_accum);
    // Broadcast lane 0's complete result to all threads in the SIMD-group.
    dst_accum = shfl_sync(accum, 0);
}

/* ----------  WRAPPERS FOR PRETTINESS  ---------- */

/**
 * @brief Computes the largest element of a shared memory vector, SIMD-group wide.
 *
 * Every lane of the SIMD-group receives the same maximum in `max_val`.
 *
 * @tparam SV Shared vector type; must satisfy `ducks::is_shared_vector`.
 * @param[out] max_val Receives the maximum element of `src`.
 * @param[in]  src     The shared memory vector to scan.
 * @param[in]  laneid  Lane index of the calling thread within its SIMD-group.
 */
template<typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
max(thread typename SV::dtype &max_val, threadgroup const SV &src, const ushort laneid) {
    using reduce_op = base_ops::max;
    // reset = true: the accumulator argument is ignored, so max_val fills that slot.
    reduce<reduce_op, SV, true>(max_val, src, max_val, laneid);
}

/**
 * @brief Finds the minimum element in a shared memory vector.
 *
 * Every lane of the SIMD-group receives the same minimum in `min_val`.
 *
 * @tparam SV The type of the shared memory vector. Must satisfy `ducks::is_shared_vector`.
 * @param[out] min_val The minimum value found in the vector.
 * @param[in]  src     The shared memory vector to find the minimum in.
 * @param[in]  laneid  Lane index of the calling thread within its SIMD-group.
 */
template<typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
min(thread typename SV::dtype &min_val, threadgroup const SV &src, const ushort laneid) {
    // reset = true: min_val is only written, never read as an initial value, so it can
    // double as the (unused) src_accum argument. laneid was previously not forwarded,
    // which made this overload fail to compile when instantiated.
    reduce<base_ops::min, SV, true>(min_val, src, min_val, laneid);
}

/**
 * @brief Adds up every element of a shared memory vector, SIMD-group wide.
 *
 * Every lane of the SIMD-group receives the same total in `sum_val`.
 *
 * @tparam SV Shared vector type; must satisfy `ducks::is_shared_vector`.
 * @param[out] sum_val Receives the sum of all elements of `src`.
 * @param[in]  src     The shared memory vector to add up.
 * @param[in]  laneid  Lane index of the calling thread within its SIMD-group.
 */
template<typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
sum(thread typename SV::dtype &sum_val, threadgroup const SV &src, const ushort laneid) {
    using reduce_op = base_ops::sum;
    // reset = true: the accumulator argument is ignored, so sum_val fills that slot.
    reduce<reduce_op, SV, true>(sum_val, src, sum_val, laneid);
}

/**
 * @brief Multiplies together every element of a shared memory vector, SIMD-group wide.
 *
 * Every lane of the SIMD-group receives the same product in `prod_val`.
 *
 * @tparam SV Shared vector type; must satisfy `ducks::is_shared_vector`.
 * @param[out] prod_val Receives the product of all elements of `src`.
 * @param[in]  src      The shared memory vector to multiply out.
 * @param[in]  laneid   Lane index of the calling thread within its SIMD-group.
 */
template<typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
prod(thread typename SV::dtype &prod_val, threadgroup const SV &src, const ushort laneid) {
    using reduce_op = base_ops::mul;
    // reset = true: the accumulator argument is ignored, so prod_val fills that slot.
    reduce<reduce_op, SV, true>(prod_val, src, prod_val, laneid);
}

// Three operand versions.

/**
 * @brief Computes max(src_accum, largest element of src), SIMD-group wide.
 *
 * @tparam SV Shared vector type; must satisfy `ducks::is_shared_vector`.
 * @param[out] max_val   Receives the maximum of `src_accum` and every element of `src`.
 * @param[in]  src       The shared memory vector to scan.
 * @param[in]  src_accum Running maximum to fold into the result.
 * @param[in]  laneid    Lane index of the calling thread within its SIMD-group.
 */
template<typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
max(thread typename SV::dtype &max_val, threadgroup const SV &src, thread const typename SV::dtype &src_accum, const ushort laneid) {
    using reduce_op = base_ops::max;
    // reset = false: src_accum participates in the reduction.
    reduce<reduce_op, SV, false>(max_val, src, src_accum, laneid);
}

/**
 * @brief Computes min(src_accum, smallest element of src), SIMD-group wide.
 *
 * @tparam SV Shared vector type; must satisfy `ducks::is_shared_vector`.
 * @param[out] min_val   Receives the minimum of `src_accum` and every element of `src`.
 * @param[in]  src       The shared memory vector to scan.
 * @param[in]  src_accum Running minimum to fold into the result.
 * @param[in]  laneid    Lane index of the calling thread within its SIMD-group.
 */
template<typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
min(thread typename SV::dtype &min_val, threadgroup const SV &src, thread const typename SV::dtype &src_accum, const ushort laneid) {
    using reduce_op = base_ops::min;
    // reset = false: src_accum participates in the reduction.
    reduce<reduce_op, SV, false>(min_val, src, src_accum, laneid);
}

/**
 * @brief Calculates the sum of elements in a shared memory vector and accumulates it with src_accum.
 *
 * Every lane of the SIMD-group receives the same result in `sum_val`.
 *
 * @tparam SV The type of the shared memory vector. Must satisfy `ducks::is_shared_vector`.
 * @param[out] sum_val   The sum of the values in the vector, accumulated with src_accum.
 * @param[in]  src       The shared memory vector to sum.
 * @param[in]  src_accum The initial value to accumulate with the sum of the vector.
 * @param[in]  laneid    Lane index of the calling thread within its SIMD-group.
 */
template<typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
sum(thread typename SV::dtype &sum_val, threadgroup const SV &src, thread const typename SV::dtype &src_accum, const ushort laneid) {
    reduce<base_ops::sum, SV, false>(sum_val, src, src_accum, laneid);
}

/**
 * @brief Overload of the three-operand `sum` for an accumulator that lives in threadgroup memory.
 *
 * Kept for backward compatibility with callers that pass a threadgroup-space `src_accum`.
 * `reduce` takes its accumulator by `thread const&`, so the value is first copied into a
 * thread-space local.
 *
 * @tparam SV The type of the shared memory vector. Must satisfy `ducks::is_shared_vector`.
 * @param[out] sum_val   The sum of the values in the vector, accumulated with src_accum.
 * @param[in]  src       The shared memory vector to sum.
 * @param[in]  src_accum Threadgroup-resident initial value to accumulate with the sum.
 * @param[in]  laneid    Lane index of the calling thread within its SIMD-group.
 */
template<typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
sum(thread typename SV::dtype &sum_val, threadgroup const SV &src, threadgroup const typename SV::dtype &src_accum, const ushort laneid) {
    const typename SV::dtype accum_copy = src_accum; // thread-space copy for reduce's thread const& parameter
    reduce<base_ops::sum, SV, false>(sum_val, src, accum_copy, laneid);
}

/**
 * @brief Computes src_accum times the product of every element of src, SIMD-group wide.
 *
 * @tparam SV Shared vector type; must satisfy `ducks::is_shared_vector`.
 * @param[out] prod_val  Receives `src_accum` multiplied by every element of `src`.
 * @param[in]  src       The shared memory vector to multiply out.
 * @param[in]  src_accum Running product to fold into the result.
 * @param[in]  laneid    Lane index of the calling thread within its SIMD-group.
 */
template<typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
prod(thread typename SV::dtype &prod_val, threadgroup const SV &src, thread const typename SV::dtype &src_accum, const ushort laneid) {
    using reduce_op = base_ops::mul;
    // reset = false: src_accum participates in the reduction.
    reduce<reduce_op, SV, false>(prod_val, src, src_accum, laneid);
}
}
}


