// This file is automatically generated.
// Do not edit by hand!

#pragma once

#include "common.h"
#include <array>
#include <cstdio>
#include <cstdlib>
#include <string>

// Declares the host-side entry point of one kernel variant.
//
// Every variant lives in its own namespace and exposes the same function
// template, `async_reasoning_attention_gpu`, with an identical parameter list,
// so the dispatcher below can select one by namespace name. Only declarations
// are emitted here; NOTE(review): the definitions/instantiations are presumably
// provided by per-version translation units — confirm against the build.
#define ARA_DECLARE_VERSION(version_ns)                                        \
  namespace version_ns {                                                       \
  template <class scalar_t>                                                    \
  cudaError_t async_reasoning_attention_gpu(                                   \
      scalar_t *out, float scale, const int *locations,                        \
      const scalar_t *queries, const int *fragment_lengths,                    \
      const scalar_t **key_fragments, const scalar_t **value_fragments,        \
      Shape shape);                                                            \
  }

// Baseline variants.
ARA_DECLARE_VERSION(v1)
ARA_DECLARE_VERSION(v2)
ARA_DECLARE_VERSION(v3)
ARA_DECLARE_VERSION(v4)
ARA_DECLARE_VERSION(v5)
ARA_DECLARE_VERSION(v6)
ARA_DECLARE_VERSION(v7)
ARA_DECLARE_VERSION(v8)
ARA_DECLARE_VERSION(v9)

// Variants that come in an "N" / "Nb" pair.
ARA_DECLARE_VERSION(v10)
ARA_DECLARE_VERSION(v10b)
ARA_DECLARE_VERSION(v11)
ARA_DECLARE_VERSION(v11b)
ARA_DECLARE_VERSION(v12)
ARA_DECLARE_VERSION(v12b)
ARA_DECLARE_VERSION(v13)
ARA_DECLARE_VERSION(v13b)
ARA_DECLARE_VERSION(v14)
ARA_DECLARE_VERSION(v14b)
ARA_DECLARE_VERSION(v15)
ARA_DECLARE_VERSION(v15b)

// Later single variants.
ARA_DECLARE_VERSION(v16)
ARA_DECLARE_VERSION(v17)
ARA_DECLARE_VERSION(v18)
ARA_DECLARE_VERSION(v19)
ARA_DECLARE_VERSION(v20)
ARA_DECLARE_VERSION(v21)
ARA_DECLARE_VERSION(v22)
ARA_DECLARE_VERSION(v23)
ARA_DECLARE_VERSION(v24)
ARA_DECLARE_VERSION(v25)
ARA_DECLARE_VERSION(v26)

// The helper macro is only needed for the declarations above.
#undef ARA_DECLARE_VERSION

// Forwards one call to the kernel variant selected by `version`.
//
// Each `vN` namespace declared above provides `async_reasoning_attention_gpu`
// with the same parameter list. This function passes every argument through
// unchanged and returns that function's cudaError_t. `shape` is copied, because
// the per-version functions take `Shape` by value.
//
// Parameters:
//   out, scale, locations, queries, fragment_lengths, key_fragments,
//   value_fragments, shape — forwarded verbatim; see the per-version kernels
//     for their meaning.
//   version — name of the variant to run; one of the strings listed by
//     get_all_versions() (e.g. "v1", "v10b").
//
// Returns the selected variant's cudaError_t. For an unknown `version`, prints
// a message to stderr and terminates the process with std::exit(1); it never
// returns in that case.
//
// ARA_DISPATCH_VERSION(ns) expands to one `if` that compares `version` with the
// stringified namespace name (#ns) and, on a match, forwards to ns::.
#define ARA_DISPATCH_VERSION(version_ns)                                       \
  if (version == #version_ns) {                                                \
    return version_ns::async_reasoning_attention_gpu(                          \
        out, scale, locations, queries, fragment_lengths, key_fragments,       \
        value_fragments, shape);                                               \
  }

template <class scalar_t>
cudaError_t async_reasoning_attention_gpu_dispatch(
    scalar_t *out, float scale, const int *locations, const scalar_t *queries,
    const int *fragment_lengths, const scalar_t **key_fragments,
    const scalar_t **value_fragments, const Shape &shape,
    const std::string &version) {
  ARA_DISPATCH_VERSION(v1)
  ARA_DISPATCH_VERSION(v2)
  ARA_DISPATCH_VERSION(v3)
  ARA_DISPATCH_VERSION(v4)
  ARA_DISPATCH_VERSION(v5)
  ARA_DISPATCH_VERSION(v6)
  ARA_DISPATCH_VERSION(v7)
  ARA_DISPATCH_VERSION(v8)
  ARA_DISPATCH_VERSION(v9)
  ARA_DISPATCH_VERSION(v10)
  ARA_DISPATCH_VERSION(v10b)
  ARA_DISPATCH_VERSION(v11)
  ARA_DISPATCH_VERSION(v11b)
  ARA_DISPATCH_VERSION(v12)
  ARA_DISPATCH_VERSION(v12b)
  ARA_DISPATCH_VERSION(v13)
  ARA_DISPATCH_VERSION(v13b)
  ARA_DISPATCH_VERSION(v14)
  ARA_DISPATCH_VERSION(v14b)
  ARA_DISPATCH_VERSION(v15)
  ARA_DISPATCH_VERSION(v15b)
  ARA_DISPATCH_VERSION(v16)
  ARA_DISPATCH_VERSION(v17)
  ARA_DISPATCH_VERSION(v18)
  ARA_DISPATCH_VERSION(v19)
  ARA_DISPATCH_VERSION(v20)
  ARA_DISPATCH_VERSION(v21)
  ARA_DISPATCH_VERSION(v22)
  ARA_DISPATCH_VERSION(v23)
  ARA_DISPATCH_VERSION(v24)
  ARA_DISPATCH_VERSION(v25)
  ARA_DISPATCH_VERSION(v26)

  // No variant matched. The trailing newline keeps the message from running
  // into whatever is written to the terminal next (e.g. the shell prompt).
  fprintf(stderr, "Invalid kernel version `%s`!\n", version.c_str());
  std::exit(1);  // std::exit is [[noreturn]], so no return statement is needed.
}

#undef ARA_DISPATCH_VERSION
// Number of entries returned by get_all_versions(). Keep in sync with the
// versions handled by async_reasoning_attention_gpu_dispatch().
constexpr int NUM_KERNEL_VERSIONS = 32;

// Returns the names of all kernel versions, in declaration order.
//
// Marked `inline` because this non-template function is defined in a header.
// Without it, every translation unit that includes the header emits its own
// external definition, and linking two such TUs fails with a
// multiple-definition error. With `inline`, all TUs share a single function,
// and therefore a single static table.
inline const std::array<std::string, NUM_KERNEL_VERSIONS> &get_all_versions() {
  // Initialized once on first call (thread-safe since C++11); never modified.
  static const std::array<std::string, NUM_KERNEL_VERSIONS> versions = {
      "v1",   "v2",  "v3",   "v4",  "v5",   "v6",  "v7",   "v8",
      "v9",   "v10", "v10b", "v11", "v11b", "v12", "v12b", "v13",
      "v13b", "v14", "v14b", "v15", "v15b", "v16", "v17",  "v18",
      "v19",  "v20", "v21",  "v22", "v23",  "v24", "v25",  "v26"};
  return versions;
}
