// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
// This file contains CUDA utility kernels for block-sparse attention.

#include <assert.h>

#include <stdexcept>
#include <string>

#include <cuda.h>
#include <torch/all.h>

// Append the start offset of every block_size-wide block in
// [range_start, min(range_end, kv_seqlen)) to block_offset, beginning at slot
// input_block_count. Nothing is written when range_start >= kv_seqlen or when
// block_size <= 0.
//
// Returns the updated block count: input_block_count plus the number of
// offsets written. No capacity check is done here; the caller must make sure
// block_offset has room for every offset written.
//
// Declared __host__ __device__ so it can also be exercised from host code.
__host__ __device__ int64_t save_blocks(
    int* block_offset,
    int64_t range_start,
    int64_t range_end,
    int64_t block_size,
    int64_t input_block_count,
    int64_t kv_seqlen) {
  // A non-positive stride would never advance the loop below.
  if (block_size <= 0 || range_start >= kv_seqlen) {
    return input_block_count;
  }
  if (range_end > kv_seqlen) {
    range_end = kv_seqlen;
  }
  int64_t current_block_count = input_block_count;
  // 64-bit induction variable: the bounds are int64_t, and an `int` counter
  // would truncate them in the loop comparison.
  for (int64_t idx = range_start; idx < range_end; idx += block_size) {
    block_offset[current_block_count++] = static_cast<int>(idx);
  }
  return current_block_count;
}

// CUDA kernel: convert sparse vertical/slash indices to block/column offsets.
//
// Launch layout (see convert_vertical_slash_indexes_64x64): blockIdx.x = head,
// blockIdx.y = batch, and blockIdx.z * blockDim.x + threadIdx.x = query row
// block m. Each thread handles query rows [m * BLOCK_SIZE_M,
// (m + 1) * BLOCK_SIZE_M) of one (batch, head). Threads with m >= N_ROWS, or
// whose first row is >= q_seqlens[batch], exit without writing.
//
// Per row block it writes:
//   block_offset[..., :block_count]   BLOCK_SIZE_N-strided kv start offsets
//                                     (clipped to kv_seqlen by save_blocks) of
//                                     the merged kv ranges derived from the
//                                     slash indexes;
//   column_index[..., :column_count]  vertical indexes that fall before the
//                                     current range (plus, when !causal, a
//                                     final vertical < kv_seqlen).
//
// An empty vertical or slash list (NNZ_V == 0 / NNZ_S == 0) is never
// dereferenced.
//
// Assumes (TODO confirm with callers): vertical indexes are ascending and
// slash indexes descending, which is the order the merge loop walks them in.
// NOTE(review): the number of block_offset entries written is not checked
// against the NNZ_S row stride.
__global__ void convert_vertical_slash_indexes_kernel(
    const int* q_seqlens,         // [BATCH, ]
    const int* kv_seqlens,        // [BATCH, ]
    const int* vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
    const int* slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
    int* block_count,             // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* block_offset,            // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
    int* column_count,            // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* column_index,            // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
    int64_t N_HEADS,
    int64_t N_ROWS,
    int64_t BLOCK_SIZE_M,
    int64_t BLOCK_SIZE_N,
    int64_t NNZ_V,
    int64_t NNZ_S,
    bool causal  // True for intra, False for succ
) {
  const int batch_idx = blockIdx.y;
  const int head_idx = blockIdx.x;
  const int group_idx = blockIdx.z;

  int64_t q_seqlen = q_seqlens[batch_idx];
  int64_t kv_seqlen = kv_seqlens[batch_idx];
  int64_t block_idx_m = (int64_t)group_idx * blockDim.x + threadIdx.x;
  int64_t start_m = block_idx_m * BLOCK_SIZE_M;
  // The grid is rounded up to a multiple of blockDim.x, so trailing threads
  // can map past the last row of the output buffers.
  if (block_idx_m >= N_ROWS || start_m >= q_seqlen) {
    return;
  }
  int64_t end_m = start_m + BLOCK_SIZE_M;
  const int64_t head_offset = batch_idx * N_HEADS + head_idx;
  vertical_indexes += head_offset * NNZ_V;
  slash_indexes += head_offset * NNZ_S;
  int64_t row_offset = head_offset * N_ROWS + block_idx_m;
  block_count += row_offset;
  block_offset += row_offset * NNZ_S;
  column_count += row_offset;
  column_index += row_offset * NNZ_V;

  // Shift used to map query-relative positions into kv positions in the
  // slash and sentinel arithmetic below.
  const int64_t diag_offset = causal ? (kv_seqlen - q_seqlen) : kv_seqlen;
  // Value v_idx takes once every vertical index has been consumed.
  const int64_t v_sentinel = end_m + BLOCK_SIZE_N + diag_offset;

  bool has_slash = NNZ_S > 0;
  int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0;
  int64_t s = 0, v = 0;
  // Guarded first reads: an empty index list must not be dereferenced.
  int64_t v_idx = (NNZ_V > 0) ? vertical_indexes[v++] : v_sentinel;
  int64_t s_idx = BLOCK_SIZE_M;  // placeholder; overwritten below when NNZ_S > 0
  if (has_slash) {
    s_idx = slash_indexes[s++];
    // Skip slash indexes at or beyond end_m + diag_offset.
    while (s_idx >= end_m + diag_offset && s < NNZ_S) {
      s_idx = slash_indexes[s++];
    }
    if (s_idx > end_m + diag_offset) has_slash = false;
    // Convert the slash index into the end of its kv range.
    s_idx = max(diag_offset + end_m - s_idx, BLOCK_SIZE_M);
  }

  int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
  if (!has_slash) {
    if (causal) {
      range_start = diag_offset + end_m;
      range_end = diag_offset + end_m + BLOCK_SIZE_N;
    } else {
      range_start = kv_seqlen;
      range_end = kv_seqlen + BLOCK_SIZE_N;
    }
  }

  bool slash_finished = false;
  while (true) {
    if (v_idx < range_end) {
      // Verticals before the current range become explicit columns; those
      // inside [range_start, range_end) are covered by the range itself.
      if (v_idx < range_start) {
        column_index[tmp_col_cnt++] = v_idx;
      }
      v_idx = (v < NNZ_V) ? vertical_indexes[v++] : v_sentinel;
    } else {
      if (s < NNZ_S && (causal || slash_indexes[s] >= start_m)) {
        // Next slash (non-causal: only while slash_indexes[s] >= start_m).
        s_idx = max(diag_offset + end_m - slash_indexes[s++], BLOCK_SIZE_M);
      } else if (slash_finished || v == NNZ_V || (v_idx > range_start && causal)) {
        // No more slashes to merge. If slash_finished is already set, the
        // range was parked on a previous pass and nothing in the branch below
        // would change state again, so repeating it could never terminate.
        if (v == NNZ_V && !causal && v_idx < kv_seqlen) {
          // add the last vertical if no more slash
          column_index[tmp_col_cnt++] = v_idx;
        }
        tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
        break;
      } else {
        if (causal) {
          range_start = diag_offset + end_m;
          range_end = diag_offset + end_m + BLOCK_SIZE_N;
        } else {
          // Slashes are finished but verticals remain: save the current range,
          // then park it at kv_seqlen (save_blocks emits nothing from there)
          // so remaining verticals below kv_seqlen are collected as columns.
          tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
          range_start = kv_seqlen;
          range_end = kv_seqlen + BLOCK_SIZE_N;
        }
        slash_finished = true;
      }
      if (!slash_finished) {
        if (s_idx > range_end + BLOCK_SIZE_M) {
          // The new slash range is disjoint from the current one: flush the
          // current range and start a new one.
          tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
          range_start = s_idx - BLOCK_SIZE_M;
          range_end = s_idx;
        } else if (s_idx > range_end) {
          // Within BLOCK_SIZE_M of the current end: extend the current range.
          range_end += BLOCK_SIZE_M;
        }
      }
    }
  }

  block_count[0] = tmp_blk_cnt;
  column_count[0] = tmp_col_cnt;
}

// Host function: launches convert_vertical_slash_indexes_kernel with 64
// threads per block.
//
// Grid layout: x = head, y = batch, z = cdiv(N_ROWS, 64) groups of row
// blocks; thread t of group z handles row block z * 64 + t. All pointers must
// be device pointers laid out as annotated.
//
// Returns without launching when the grid would be empty. Throws
// std::runtime_error if the launch itself fails (e.g. a grid dimension above
// the device limit); errors raised while the kernel runs surface at the next
// synchronizing CUDA call.
//
// NOTE(review): the kernel is launched on the default stream, not PyTorch's
// current stream — confirm this is intended.
void convert_vertical_slash_indexes_64x64(
    const int* q_seqlens,         // [BATCH, ]
    const int* kv_seqlens,        // [BATCH, ]
    const int* vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
    const int* slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
    int* block_count,             // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* block_offset,            // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
    int* column_count,            // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* column_index,            // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
    int64_t BATCH_SIZE,
    int64_t N_HEADS,
    int64_t N_ROWS,
    int64_t BLOCK_SIZE_M,
    int64_t BLOCK_SIZE_N,
    int64_t NNZ_V,
    int64_t NNZ_S,
    bool causal) {
  constexpr int64_t N_THREADS = 64;
  // A zero-sized grid dimension is an invalid launch configuration, and there
  // is nothing to compute anyway.
  if (BATCH_SIZE <= 0 || N_HEADS <= 0 || N_ROWS <= 0) {
    return;
  }
  const dim3 dimBlock(static_cast<unsigned int>(N_THREADS));
  const dim3 dimGrid(
      static_cast<unsigned int>(N_HEADS),
      static_cast<unsigned int>(BATCH_SIZE),
      static_cast<unsigned int>((N_ROWS + N_THREADS - 1) / N_THREADS));
  convert_vertical_slash_indexes_kernel<<<dimGrid, dimBlock>>>(
      q_seqlens,
      kv_seqlens,
      vertical_indexes,
      slash_indexes,
      block_count,
      block_offset,
      column_count,
      column_index,
      N_HEADS,
      N_ROWS,
      BLOCK_SIZE_M,
      BLOCK_SIZE_N,
      NNZ_V,
      NNZ_S,
      causal);
  // Kernel launches do not return errors; query them explicitly.
  const cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    throw std::runtime_error(
        std::string("convert_vertical_slash_indexes_kernel launch failed: ") + cudaGetErrorString(err));
  }
}

// Host function: validates the tensors and launches the conversion kernel.
//
// BATCH, N_HEADS and NNZ_S are read from slash_indexes.size(0..2), NNZ_V from
// vertical_indexes.size(2), and NUM_ROWS = cdiv(context_size, block_size_M).
// Every tensor must be a contiguous CUDA tensor (data_ptr<int>() additionally
// requires int32). Each must hold at least as many elements as its annotated
// shape implies, because the kernel indexes them as dense arrays.
//
// Throws std::invalid_argument on invalid arguments and std::runtime_error if
// cudaSetDevice fails.
//
// NOTE(review): cudaSetDevice switches the calling thread's current device and
// does not restore it afterwards.
void convert_vertical_slash_indexes(
    torch::Tensor& block_count,      // [BATCH, N_HEADS, NUM_ROWS]
    torch::Tensor& block_offset,     // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
    torch::Tensor& column_count,     // [BATCH, N_HEADS, NUM_ROWS]
    torch::Tensor& column_index,     // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
    torch::Tensor q_seqlens,         // [BATCH, ]
    torch::Tensor kv_seqlens,        // [BATCH, ]
    torch::Tensor vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
    torch::Tensor slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
    int64_t context_size,
    int64_t block_size_M,
    int64_t block_size_N,
    bool causal) {
  const std::string where = "convert_vertical_slash_indexes: ";
  if (block_size_M <= 0 || block_size_N <= 0) {
    throw std::invalid_argument(where + "block_size_M and block_size_N must be positive");
  }
  if (context_size < 0) {
    throw std::invalid_argument(where + "context_size must be non-negative");
  }
  // The kernel addresses every tensor as a dense row-major device array.
  auto check_tensor = [&where](const torch::Tensor& t, int64_t min_numel, const char* name) {
    if (!t.is_cuda() || !t.is_contiguous()) {
      throw std::invalid_argument(where + name + " must be a contiguous CUDA tensor");
    }
    if (t.numel() < min_numel) {
      throw std::invalid_argument(where + name + " has fewer elements than its expected shape");
    }
  };

  int64_t batch_size = slash_indexes.size(0);
  int64_t num_heads = slash_indexes.size(1);
  int64_t nnz_slash = slash_indexes.size(2);
  int64_t nnz_vertical = vertical_indexes.size(2);
  int64_t num_rows = (context_size + block_size_M - 1) / block_size_M;
  const int64_t num_row_blocks = batch_size * num_heads * num_rows;

  check_tensor(q_seqlens, batch_size, "q_seqlens");
  check_tensor(kv_seqlens, batch_size, "kv_seqlens");
  check_tensor(vertical_indexes, batch_size * num_heads * nnz_vertical, "vertical_indexes");
  check_tensor(slash_indexes, batch_size * num_heads * nnz_slash, "slash_indexes");
  check_tensor(block_count, num_row_blocks, "block_count");
  check_tensor(block_offset, num_row_blocks * nnz_slash, "block_offset");
  check_tensor(column_count, num_row_blocks, "column_count");
  check_tensor(column_index, num_row_blocks * nnz_vertical, "column_index");

  const cudaError_t err = cudaSetDevice(q_seqlens.get_device());
  if (err != cudaSuccess) {
    throw std::runtime_error(where + "cudaSetDevice failed: " + cudaGetErrorString(err));
  }

  convert_vertical_slash_indexes_64x64(
      q_seqlens.data_ptr<int>(),
      kv_seqlens.data_ptr<int>(),
      vertical_indexes.data_ptr<int>(),
      slash_indexes.data_ptr<int>(),
      block_count.data_ptr<int>(),
      block_offset.data_ptr<int>(),
      column_count.data_ptr<int>(),
      column_index.data_ptr<int>(),
      batch_size,
      num_heads,
      num_rows,
      block_size_M,
      block_size_N,
      nnz_vertical,
      nnz_slash,
      causal);
}

// --- mergehead kernels --- //

// Kernel: converts sparse vertical/slash indices to block/column offsets,
// where each head may use fewer vertical/slash entries than the buffer holds.
//
// NNZ_V / NNZ_S are the buffer widths (row strides) of vertical_indexes,
// slash_indexes, column_index and block_offset, and are used only for
// addressing. The number of entries actually consumed for head h is
// per_head_vertical_topkv[h] / per_head_slash_topkv[h], clamped to
// [0, NNZ_V] / [0, NNZ_S].
//
// Launch layout: blockIdx.x = head, blockIdx.y = batch, and
// blockIdx.z * blockDim.x + threadIdx.x = query row block m. Each thread
// handles query rows [m * BLOCK_SIZE_M, (m + 1) * BLOCK_SIZE_M) of one
// (batch, head). Threads with m >= N_ROWS, or whose first row is
// >= q_seqlens[batch], exit without writing.
//
// Per row block it writes:
//   block_offset[..., :block_count]   BLOCK_SIZE_N-strided kv start offsets
//                                     (clipped to kv_seqlen by save_blocks) of
//                                     the merged kv ranges derived from the
//                                     slash indexes;
//   column_index[..., :column_count]  vertical indexes that fall before the
//                                     current range (plus, when !causal, a
//                                     final vertical < kv_seqlen).
//
// Assumes (TODO confirm with callers): vertical indexes are ascending and
// slash indexes descending, which is the order the merge loop walks them in.
// NOTE(review): the number of block_offset entries written is not checked
// against the NNZ_S row stride.
__global__ void convert_vertical_slash_indexes_kernel_mergehead(
    const int* q_seqlens,         // [BATCH, ]
    const int* kv_seqlens,        // [BATCH, ]
    const int* vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
    const int* slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
    const int* per_head_vertical_topkv,
    const int* per_head_slash_topkv,
    int* block_count,   // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* block_offset,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
    int* column_count,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* column_index,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
    int64_t N_HEADS,
    int64_t N_ROWS,
    int64_t BLOCK_SIZE_M,
    int64_t BLOCK_SIZE_N,
    int64_t NNZ_V,
    int64_t NNZ_S,
    bool causal  // True for intra, False for succ
) {
  const int batch_idx = blockIdx.y;
  const int head_idx = blockIdx.x;
  const int group_idx = blockIdx.z;

  int64_t q_seqlen = q_seqlens[batch_idx];
  int64_t kv_seqlen = kv_seqlens[batch_idx];
  int64_t block_idx_m = (int64_t)group_idx * blockDim.x + threadIdx.x;
  int64_t start_m = block_idx_m * BLOCK_SIZE_M;
  // The grid is rounded up to a multiple of blockDim.x, so trailing threads
  // can map past the last row of the output buffers.
  if (block_idx_m >= N_ROWS || start_m >= q_seqlen) {
    return;
  }
  int64_t end_m = start_m + BLOCK_SIZE_M;
  const int64_t head_offset = batch_idx * N_HEADS + head_idx;
  vertical_indexes += head_offset * NNZ_V;
  slash_indexes += head_offset * NNZ_S;
  int64_t row_offset = head_offset * N_ROWS + block_idx_m;
  block_count += row_offset;
  block_offset += row_offset * NNZ_S;
  column_count += row_offset;
  column_index += row_offset * NNZ_V;

  // MergeHead: each head has its own max top-k counts. The NNZ_V / NNZ_S
  // values above are buffer widths used only to compute offsets; from here on
  // they hold the per-head counts, clamped so reads stay inside the buffer.
  const int64_t head_nnz_s = per_head_slash_topkv[head_idx];
  const int64_t head_nnz_v = per_head_vertical_topkv[head_idx];
  NNZ_S = head_nnz_s < 0 ? 0 : (head_nnz_s < NNZ_S ? head_nnz_s : NNZ_S);
  NNZ_V = head_nnz_v < 0 ? 0 : (head_nnz_v < NNZ_V ? head_nnz_v : NNZ_V);

  // Shift used to map query-relative positions into kv positions in the
  // slash and sentinel arithmetic below.
  const int64_t diag_offset = causal ? (kv_seqlen - q_seqlen) : kv_seqlen;
  // Value v_idx takes once every vertical index has been consumed.
  const int64_t v_sentinel = end_m + BLOCK_SIZE_N + diag_offset;

  bool has_slash = NNZ_S > 0;
  int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0;
  int64_t s = 0, v = 0;
  // Guarded first reads: an empty per-head list must not be dereferenced.
  int64_t v_idx = (NNZ_V > 0) ? vertical_indexes[v++] : v_sentinel;
  int64_t s_idx = BLOCK_SIZE_M;  // placeholder; overwritten below when NNZ_S > 0
  if (has_slash) {
    s_idx = slash_indexes[s++];
    // Skip slash indexes at or beyond end_m + diag_offset.
    while (s_idx >= end_m + diag_offset && s < NNZ_S) {
      s_idx = slash_indexes[s++];
    }
    if (s_idx > end_m + diag_offset) has_slash = false;
    // Convert the slash index into the end of its kv range.
    s_idx = max(diag_offset + end_m - s_idx, BLOCK_SIZE_M);
  }

  int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
  if (!has_slash) {
    if (causal) {
      range_start = diag_offset + end_m;
      range_end = diag_offset + end_m + BLOCK_SIZE_N;
    } else {
      range_start = kv_seqlen;
      range_end = kv_seqlen + BLOCK_SIZE_N;
    }
  }

  bool slash_finished = false;
  while (true) {
    if (v_idx < range_end) {
      // Verticals before the current range become explicit columns; those
      // inside [range_start, range_end) are covered by the range itself.
      if (v_idx < range_start) {
        column_index[tmp_col_cnt++] = v_idx;
      }
      v_idx = (v < NNZ_V) ? vertical_indexes[v++] : v_sentinel;
    } else {
      if (s < NNZ_S && (causal || slash_indexes[s] >= start_m)) {
        // Next slash (non-causal: only while slash_indexes[s] >= start_m).
        s_idx = max(diag_offset + end_m - slash_indexes[s++], BLOCK_SIZE_M);
      } else if (slash_finished || v == NNZ_V || (v_idx > range_start && causal)) {
        // No more slashes to merge. If slash_finished is already set, the
        // range was parked on a previous pass and nothing in the branch below
        // would change state again, so repeating it could never terminate.
        if (v == NNZ_V && !causal && v_idx < kv_seqlen) {
          // add the last vertical if no more slash
          column_index[tmp_col_cnt++] = v_idx;
        }
        tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
        break;
      } else {
        if (causal) {
          range_start = diag_offset + end_m;
          range_end = diag_offset + end_m + BLOCK_SIZE_N;
        } else {
          // Slashes are finished but verticals remain: save the current range,
          // then park it at kv_seqlen (save_blocks emits nothing from there)
          // so remaining verticals below kv_seqlen are collected as columns.
          tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
          range_start = kv_seqlen;
          range_end = kv_seqlen + BLOCK_SIZE_N;
        }
        slash_finished = true;
      }
      if (!slash_finished) {
        if (s_idx > range_end + BLOCK_SIZE_M) {
          // The new slash range is disjoint from the current one: flush the
          // current range and start a new one.
          tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
          range_start = s_idx - BLOCK_SIZE_M;
          range_end = s_idx;
        } else if (s_idx > range_end) {
          // Within BLOCK_SIZE_M of the current end: extend the current range.
          range_end += BLOCK_SIZE_M;
        }
      }
    }
  }

  block_count[0] = tmp_blk_cnt;
  column_count[0] = tmp_col_cnt;
}

// Launches convert_vertical_slash_indexes_kernel_mergehead with 64 threads per
// block.
//
// Grid layout: x = head, y = batch, z = cdiv(N_ROWS, 64) groups of row
// blocks; thread t of group z handles row block z * 64 + t. NNZ_V / NNZ_S are
// the buffer widths. per_head_vertical_topkv / per_head_slash_topkv are read
// at the head index, so each must be a device array of at least N_HEADS
// entries.
//
// Returns without launching when the grid would be empty. Throws
// std::runtime_error if the launch itself fails (e.g. a grid dimension above
// the device limit); errors raised while the kernel runs surface at the next
// synchronizing CUDA call.
//
// NOTE(review): the kernel is launched on the default stream, not PyTorch's
// current stream — confirm this is intended.
void convert_vertical_slash_indexes_64x64_mergehead(
    const int* q_seqlens,         // [BATCH, ]
    const int* kv_seqlens,        // [BATCH, ]
    const int* vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
    const int* slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
    int* per_head_vertical_topkv,
    int* per_head_slash_topkv,
    int* block_count,   // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* block_offset,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
    int* column_count,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* column_index,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
    int64_t BATCH_SIZE,
    int64_t N_HEADS,
    int64_t N_ROWS,
    int64_t BLOCK_SIZE_M,
    int64_t BLOCK_SIZE_N,
    int64_t NNZ_V,
    int64_t NNZ_S,
    bool causal) {
  constexpr int64_t N_THREADS = 64;
  // A zero-sized grid dimension is an invalid launch configuration, and there
  // is nothing to compute anyway.
  if (BATCH_SIZE <= 0 || N_HEADS <= 0 || N_ROWS <= 0) {
    return;
  }
  // Explicit casts: dim3 takes unsigned int, so the int64_t sizes would
  // otherwise be narrowed implicitly.
  const dim3 dimBlock(static_cast<unsigned int>(N_THREADS));
  const dim3 dimGrid(
      static_cast<unsigned int>(N_HEADS),
      static_cast<unsigned int>(BATCH_SIZE),
      static_cast<unsigned int>((N_ROWS + N_THREADS - 1) / N_THREADS));
  convert_vertical_slash_indexes_kernel_mergehead<<<dimGrid, dimBlock>>>(
      q_seqlens,
      kv_seqlens,
      vertical_indexes,
      slash_indexes,
      per_head_vertical_topkv,
      per_head_slash_topkv,
      block_count,
      block_offset,
      column_count,
      column_index,
      N_HEADS,
      N_ROWS,
      BLOCK_SIZE_M,
      BLOCK_SIZE_N,
      NNZ_V,
      NNZ_S,
      causal);
  // Kernel launches do not return errors; query them explicitly.
  const cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    throw std::runtime_error(
        std::string("convert_vertical_slash_indexes_kernel_mergehead launch failed: ") + cudaGetErrorString(err));
  }
}

// Host wrapper for the mergehead kernel: validates the tensors and launches
// convert_vertical_slash_indexes_kernel_mergehead.
//
// BATCH, N_HEADS and the buffer width NNZ_S are read from
// slash_indexes.size(0..2), the buffer width NNZ_V from
// vertical_indexes.size(2), and NUM_ROWS = cdiv(context_size, block_size_M).
// vertical_indices_count / slash_indices_count hold the per-head number of
// valid entries. The kernel reads them at the head index, so each needs at
// least N_HEADS elements.
//
// Every tensor must be a contiguous CUDA tensor (data_ptr<int>() additionally
// requires int32). Each must hold at least as many elements as its annotated
// shape implies, because the kernel indexes them as dense arrays.
//
// Throws std::invalid_argument on invalid arguments and std::runtime_error if
// cudaSetDevice fails.
//
// NOTE(review): cudaSetDevice switches the calling thread's current device and
// does not restore it afterwards.
void convert_vertical_slash_indexes_mergehead(
    torch::Tensor& block_count,            // [BATCH, N_HEADS, NUM_ROWS]
    torch::Tensor& block_offset,           // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
    torch::Tensor& column_count,           // [BATCH, N_HEADS, NUM_ROWS]
    torch::Tensor& column_index,           // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
    torch::Tensor q_seqlens,               // [BATCH, ]
    torch::Tensor kv_seqlens,              // [BATCH, ]
    torch::Tensor vertical_indexes,        // [BATCH, N_HEADS, NNZ_V]
    torch::Tensor slash_indexes,           // [BATCH, N_HEADS, NNZ_S]
    torch::Tensor vertical_indices_count,  // [N_HEADS, ]
    torch::Tensor slash_indices_count,     // [N_HEADS, ]
    int64_t context_size,
    int64_t block_size_M,
    int64_t block_size_N,
    bool causal) {
  const std::string where = "convert_vertical_slash_indexes_mergehead: ";
  if (block_size_M <= 0 || block_size_N <= 0) {
    throw std::invalid_argument(where + "block_size_M and block_size_N must be positive");
  }
  if (context_size < 0) {
    throw std::invalid_argument(where + "context_size must be non-negative");
  }
  // The kernel addresses every tensor as a dense row-major device array.
  auto check_tensor = [&where](const torch::Tensor& t, int64_t min_numel, const char* name) {
    if (!t.is_cuda() || !t.is_contiguous()) {
      throw std::invalid_argument(where + name + " must be a contiguous CUDA tensor");
    }
    if (t.numel() < min_numel) {
      throw std::invalid_argument(where + name + " has fewer elements than its expected shape");
    }
  };

  // int64_t throughout: narrowing these sizes to int could truncate them.
  int64_t batch_size = slash_indexes.size(0);
  int64_t num_heads = slash_indexes.size(1);
  int64_t nnz_slash = slash_indexes.size(2);
  int64_t nnz_vertical = vertical_indexes.size(2);
  int64_t num_rows = (context_size + block_size_M - 1) / block_size_M;
  const int64_t num_row_blocks = batch_size * num_heads * num_rows;

  check_tensor(q_seqlens, batch_size, "q_seqlens");
  check_tensor(kv_seqlens, batch_size, "kv_seqlens");
  check_tensor(vertical_indexes, batch_size * num_heads * nnz_vertical, "vertical_indexes");
  check_tensor(slash_indexes, batch_size * num_heads * nnz_slash, "slash_indexes");
  check_tensor(vertical_indices_count, num_heads, "vertical_indices_count");
  check_tensor(slash_indices_count, num_heads, "slash_indices_count");
  check_tensor(block_count, num_row_blocks, "block_count");
  check_tensor(block_offset, num_row_blocks * nnz_slash, "block_offset");
  check_tensor(column_count, num_row_blocks, "column_count");
  check_tensor(column_index, num_row_blocks * nnz_vertical, "column_index");

  const cudaError_t err = cudaSetDevice(q_seqlens.get_device());
  if (err != cudaSuccess) {
    throw std::runtime_error(where + "cudaSetDevice failed: " + cudaGetErrorString(err));
  }

  convert_vertical_slash_indexes_64x64_mergehead(
      q_seqlens.data_ptr<int>(),
      kv_seqlens.data_ptr<int>(),
      vertical_indexes.data_ptr<int>(),
      slash_indexes.data_ptr<int>(),
      vertical_indices_count.data_ptr<int>(),
      slash_indices_count.data_ptr<int>(),
      block_count.data_ptr<int>(),
      block_offset.data_ptr<int>(),
      column_count.data_ptr<int>(),
      column_index.data_ptr<int>(),
      batch_size,
      num_heads,
      num_rows,
      block_size_M,
      block_size_N,
      nnz_vertical,
      nnz_slash,
      causal);
}
