/**
 * @file
 * @brief Conversions on vectors stored in registers.
 */

#pragma once // done

#include "../../../../common/common.metal"
#include "../../../../types/types.metal"

namespace kittens {
namespace ore {

namespace detail {
    // Lane id -> first of the two adjacent column indices used for that lane
    // (the caller takes `colstart` and `colstart + 1`).
    // Odd lanes contribute 2; lanes with an odd (laneid / 8) contribute 4.
    static METAL_FUNC int colstart_from_laneid(const int laneid) { // rowvec
        const int parity_offset = (laneid % 2) * 2;       // 0 or 2 for non-negative lanes
        const int octet_offset  = ((laneid / 8) % 2) * 4; // 0 or 4 for non-negative lanes
        return parity_offset + octet_offset;
    }
    // Column index -> lane chosen to supply that column's value.
    //   col:    0, 1, 2, 3, 4,  5, 6,  7
    //   leader: 0, 2, 1, 3, 8, 10, 9, 11
    static METAL_FUNC int leader_from_col(const int col) { // rowvec
        const int group_base    = (col / 4) * 8; // columns 4..7 map into lanes 8..11
        const int pair_select   = (col / 2) % 2; // which column pair inside the group of four
        const int parity_offset = (col % 2) * 2; // odd columns sit two lanes further on
        return group_base + pair_select + parity_offset;
    }
    // Leader lane -> which of its two elements (0 or 1) it should send.
    //   leader: 0, 2, 1, 3, 8, 10, 9, 11
    //   index:  0, 1, 0, 1, 0,  1, 0,  1
    static METAL_FUNC int idx_from_colleader(const int laneid) { // rowvec
        const int lane_in_octet = laneid % 8;
        const int pair_in_octet = lane_in_octet / 2;
        // The trailing % 2 keeps the result a valid element index even when
        // the function is evaluated on a lane that is not a leader.
        return pair_in_octet % 2;
    }
    
    // Lane id -> row index (0..7 for lanes 0..31).
    // Lanes 0..15 cover rows 0..3 and lanes 16..31 cover rows 4..7, so that
    // row_from_laneid(leader_from_row(r)) == r for every r in 0..7.
    // The upper-half step must therefore be 4 rows; a step of 16 would yield
    // rows 16..19, which leader_from_col() maps to lanes >= 32.
    static METAL_FUNC int row_from_laneid(const int laneid) { // rowvec
        return (laneid / 2) % 4 + (laneid / 16) * 4;
    }
    // Row index -> lane chosen to supply that row's value.
    //   row:    0, 1, 2, 3,  4,  5,  6,  7
    //   leader: 0, 2, 4, 6, 16, 18, 20, 22
    static METAL_FUNC int leader_from_row(const int row) { // rowvec
        const int half_base   = (row / 4) * 16; // rows 4..7 start at lane 16
        const int within_half = (row % 4) * 2;  // consecutive rows are two lanes apart
        return half_base + within_half;
    }
    
}
/**
 * @brief Copies data from one register vector to another, converting the element type
 *        and, where required, the layout.
 *
 * Exactly one of the following paths is selected at compile time:
 *  - identical layouts: element-wise type conversion, no shuffles;
 *  - dst inner_dim 1, src inner_dim 2 (row vector -> col vector): one shuffle per outer element;
 *  - dst inner_dim 2, src inner_dim 1 (col vector -> row vector): two shuffles per outer element.
 *
 * @tparam RV1 The type of the destination register vector.
 * @tparam RV2 The type of the source register vector.
 * @param dst[out] The destination register vector.
 * @param src[in] The source register vector to copy from.
 * @param laneid[in] Lane index of the caller; used to derive the source lane(s) passed to shfl_sync.
 */
template<typename RV1, typename RV2>
static METAL_FUNC typename metal::enable_if<ducks::is_register_vector<RV1>() && ducks::is_register_vector<RV2>(), void>::type
copy(thread RV1 &dst, thread const RV2 &src, const ushort laneid) {
    static_assert(RV1::outer_dim == RV2::outer_dim, "Outer dimensions of the register vectors must be the same.");
    using D1 = typename RV1::dtype;
    using D2 = typename RV2::dtype;
    // The dispatch must be `if constexpr`: with a plain `if` every branch is instantiated,
    // so the static_assert in the final `else` would reject every mixed-layout copy and the
    // unused branches would index elements the vector types do not have.
    if constexpr (metal::is_same_v<typename RV1::layout, typename RV2::layout>) {
        #pragma clang loop unroll(full)
        for(int i = 0; i < RV1::outer_dim; i++) {
            #pragma clang loop unroll(full)
            for(int j = 0; j < RV1::inner_dim; j++) {
                dst[i][j] = base_types::convertor<D1, D2>::convert(src[i][j]);
            }
        }
    }
    // row vector -> col vector
    else if constexpr (RV1::inner_dim == 1 && RV2::inner_dim == 2) {
        // This lane wants the value whose column index equals its own row index.
        const int row        = detail::row_from_laneid(laneid);
        const int laneid_src = detail::leader_from_col(row);
        // Element this lane offers to the shuffle, should it be read as a leader.
        const int send_idx   = detail::idx_from_colleader(laneid);
        #pragma clang loop unroll(full)
        for(int i = 0; i < RV1::outer_dim; i++) {
            dst[i][0] = base_types::convertor<D1,D2>::convert(shfl_sync<D2>(src[i][send_idx], laneid_src));
        }
    }
    // col vector -> row vector
    else if constexpr (RV1::inner_dim == 2 && RV2::inner_dim == 1) {
        // This lane needs two adjacent entries; each is fetched from the lane leading that row.
        const int col1 = detail::colstart_from_laneid(laneid);
        const int col2 = col1 + 1;
        const int laneid_src1 = detail::leader_from_row(col1);
        const int laneid_src2 = detail::leader_from_row(col2);
        #pragma clang loop unroll(full)
        for(int i = 0; i < RV1::outer_dim; i++) {
            dst[i][0] = base_types::convertor<D1,D2>::convert(shfl_sync<D2>(src[i][0], laneid_src1));
            dst[i][1] = base_types::convertor<D1,D2>::convert(shfl_sync<D2>(src[i][0], laneid_src2));
        }
    }
    else {
        // NOTE(review): two distinct layouts with equal inner_dim pass this assert and copy
        // nothing — confirm whether such layout pairs exist and need a real conversion.
        static_assert(RV1::inner_dim == RV2::inner_dim, "Something has gone deeply wrong with how register vectors were instantiated.");
    }
}

}
}
