/**
 * @file
 * @brief Templated layouts for global memory.
 */
 
#pragma once

#include "../../common/common.metal"
#include "../shared/shared.metal"
#include "../register/register.metal"
#include "util.metal"


namespace kittens {
namespace ore {
/* ----------   Associative dictionary for global layouts  ---------- */

namespace detail {
/**
 * @brief Placeholder dictionary of per-layout descriptors.
 *
 * Holds no data and every constructor is a no-op; the template parameter
 * pack is not used by the body.
 * NOTE(review): presumably a stand-in for TMA descriptor storage on other
 * backends — confirm before relying on it.
 */
template <typename... DescTypes>
struct descriptor_dict {
    /// Empty construction; nothing to initialize.
    METAL_FUNC descriptor_dict() {}

    /// Accepts (pointer, batch, depth, rows, cols) and discards every argument.
    template <typename PtrT>
    METAL_FUNC descriptor_dict(PtrT, int, int, int, int) {}

    /// Copying carries no state.
    METAL_FUNC descriptor_dict(thread const descriptor_dict &) {}
};
} // namespace detail

/* ----------  Global layout descriptor  ---------- */

namespace ducks {
/// Tag types for global-layout objects.
namespace gl {
/// Tag type exposed as `identifier` inside kittens::ore::gl.
struct identifier {};
} // namespace gl

/**
 * @brief Compile-time predicate: is T a tile?
 * @return true when T is either a register tile or a shared tile.
 */
template <typename T>
static constexpr bool is_tile() {
    return kittens::ore::ducks::is_register_tile<T>()
        || kittens::ore::ducks::is_shared_tile<T>();
}

/**
 * @brief Compile-time predicate: is T a vector?
 * @return true when T is either a register vector or a shared vector.
 */
template <typename T>
static constexpr bool is_vec() {
    return kittens::ore::ducks::is_register_vector<T>()
        || kittens::ore::ducks::is_shared_vector<T>();
}
} // namespace ducks


/**
 * @brief Global-memory layout: a typed view over a 4-D tensor in device memory.
 *
 * Elements are addressed row-major in (batch, depth, row, col) order with
 * `cols` as the fastest-varying extent, so element (b, d, r, c) lives at flat
 * offset ((b*depth + d)*rows + r)*cols + c from raw_ptr.
 *
 * The view does not own the buffer (there is no destructor) and performs no
 * bounds checking.
 *
 * @tparam _T        Element type as supplied; the buffer is accessed as its
 *                   unpacked form `T` (see base_types::packing).
 * @tparam b,d,r,c   Extents forwarded to ducks::g::make_dim_t / make_arg_t.
 *                   NOTE(review): presumably a sentinel value selects a
 *                   runtime extent instead of a compile-time one — confirm in
 *                   util.metal.
 * @tparam TMA_Types Forwarded to detail::descriptor_dict; not otherwise used here.
 *
 * NOTE(review): offsets are computed in the promoted type of coord's fields
 * and the extent types; if these are 32-bit, tensors with more than 2^31
 * elements would overflow — confirm the types.
 */
template<typename _T, int b, int d, int r, int c, typename... TMA_Types>
struct gl {
    using identifier = ducks::gl::identifier; // Tag marking this type as a global layout.

    using T     = typename base_types::packing<_T>::unpacked_type;
    using T2    = typename base_types::packing<_T>::packed_type;
    using dtype = T;

    device T* raw_ptr; // Base of the tensor in device memory (not owned).

    ducks::g::make_dim_t<b> batch;
    ducks::g::make_dim_t<d> depth;
    ducks::g::make_dim_t<r> rows;
    ducks::g::make_dim_t<c> cols;

    detail::descriptor_dict<TMA_Types...> tma_descs;

    /**
     * @brief Wrap a device buffer with the given extents.
     * @param _data  Base pointer of the tensor (not owned).
     * @param _batch,_depth,_rows,_cols  Extents of the four dimensions.
     */
    METAL_FUNC gl(device T *_data,
                  ducks::g::make_arg_t<b> _batch,
                  ducks::g::make_arg_t<d> _depth,
                  ducks::g::make_arg_t<r> _rows,
                  ducks::g::make_arg_t<c> _cols) :
        raw_ptr(_data), batch(_batch), depth(_depth), rows(_rows), cols(_cols),
        // Built directly in the init list (tma_descs is declared after the
        // members it reads) instead of default-construct + assign, which
        // relied on descriptor_dict's deprecated implicit copy-assignment.
        tma_descs(raw_ptr, batch, depth, rows, cols) {}

    /// Member-wise copy of the view; the underlying buffer is shared, not duplicated.
    METAL_FUNC gl(thread const gl &other) :
        raw_ptr(other.raw_ptr), batch(other.batch), depth(other.depth),
        rows(other.rows), cols(other.cols), tma_descs(other.tma_descs) {}

    /// Element access; idx.b, idx.d, idx.r, idx.c are all in element units.
    METAL_FUNC device T& operator[](const thread coord &idx) {
        return raw_ptr[((idx.b*depth + idx.d)*rows + idx.r)*cols + idx.c];
    }
    /// Read-only element access; idx.b, idx.d, idx.r, idx.c are all in element units.
    METAL_FUNC device const T& operator[](const thread coord &idx) const {
        return raw_ptr[((idx.b*depth + idx.d)*rows + idx.r)*cols + idx.c];
    }

    /**
     * @brief Reference to the first (top-left) element of a tile.
     *
     * idx.b and idx.d are element indices; idx.r and idx.c are tile indices,
     * scaled by TILE::rows and TILE::cols respectively.
     */
    template<typename TILE>
    METAL_FUNC typename metal::enable_if<ducks::is_tile<TILE>(), device T&>::type
    get(const thread coord &idx) {
        return raw_ptr[((idx.b*depth + idx.d)*rows + idx.r*TILE::rows)*cols + idx.c*TILE::cols];
    }
    /// Read-only overload of the tile-indexed get().
    template<typename TILE>
    METAL_FUNC typename metal::enable_if<ducks::is_tile<TILE>(), device const T&>::type
    get(const thread coord &idx) const {
        return raw_ptr[((idx.b*depth + idx.d)*rows + idx.r*TILE::rows)*cols + idx.c*TILE::cols];
    }

    /**
     * @brief Reference to the first element of a vector segment.
     *
     * idx.b, idx.d and idx.r are element indices; idx.c is a vector index,
     * scaled by VEC::length.
     */
    template<typename VEC>
    METAL_FUNC typename metal::enable_if<ducks::is_vec<VEC>(), device T&>::type
    get(const thread coord &idx) {
        return raw_ptr[((idx.b*depth + idx.d)*rows + idx.r)*cols + idx.c*VEC::length];
    }
    /// Read-only overload of the vector-indexed get().
    template<typename VEC>
    METAL_FUNC typename metal::enable_if<ducks::is_vec<VEC>(), device const T&>::type
    get(const thread coord &idx) const {
        return raw_ptr[((idx.b*depth + idx.d)*rows + idx.r)*cols + idx.c*VEC::length];
    }

    /// Distance, in elements, between the starts of consecutive rows.
    METAL_FUNC size_t row_stride() const { return cols; }
};

namespace ducks {
/**
 * @brief Trait: `value` is true iff T is a specialization of kittens::ore::gl.
 *
 * Only the exact, unqualified gl<...> type matches; cv- or reference-qualified
 * forms fall through to the primary template and yield false.
 */
template <typename T>
struct has_gl_identifier {
    static constant constexpr bool value = false; // Default case: not a gl.
};

// Specialize for every template instantiation of gl
template <typename _T, int b, int d, int r, int c, typename... TMA_Types>
struct has_gl_identifier<kittens::ore::gl<_T, b, d, r, c, TMA_Types ...>> {
    static constant constexpr bool value = true;
};

/// Compile-time predicate: is GL a global layout (a gl specialization)?
template <typename GL>
static constexpr bool is_global_layout() {
    return has_gl_identifier<GL>::value;
}
/// Fails compilation with a static_assert unless GL is a global layout.
template <typename GL>
static constexpr void assert_gl() {
    static_assert(is_global_layout<GL>(), "T must be a gl");
}
}

}
}
