/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

namespace flash {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<bool Varlen = true>
struct BlockInfo {

    template<typename Params>
    __device__ BlockInfo(const Params& params, const int bidb):
        sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]),
        sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb]),
        actual_seqlen_q(params.actual_seqlen_q == nullptr ?
                            (!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q :
                                                                         params.cu_seqlens_q[bidb + 1] - sum_s_q) :
                            params.actual_seqlen_q[bidb]),
        actual_seqlen_k(params.actual_seqlen_k == nullptr ?
                            (!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k :
                                                                         params.cu_seqlens_k[bidb + 1] - sum_s_k) :
                            params.actual_seqlen_k[bidb])
    {
    }

    template<typename index_t>
    inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const
    {
        return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
    }

    template<typename index_t>
    inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const
    {
        return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
    }

    const int sum_s_q;
    const int sum_s_k;
    const int actual_seqlen_q;
    const int actual_seqlen_k;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace flash
