#include "common.h"
#include "vec.h"

namespace {

// Numerically stable softmax over SIZE contiguous elements:
//
//   out[i] = exp(input[i] - max(input)) / sum_j exp(input[j] - max(input))
//
// `input` holds SIZE values of scalar_t, `out` receives SIZE floats. `out` is also used as
// scratch between the three passes (converted inputs -> exponentials -> normalized values).
//
// Supported SIZE values (enforced at compile time):
//   * SIZE <  bVec::size(): SIZE must fit in one float vector (SIZE <= fVec::size()), because only
//     the first float vector returned by convert_to_float is consumed on that path;
//   * SIZE >= bVec::size(): SIZE must be a multiple of bVec::size(), because the loops always
//     load/store whole vectors and have no tail handling.
template <typename scalar_t, int SIZE>
inline void softmax(float* __restrict__ out, const scalar_t* __restrict__ input) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;

  constexpr int kVecSize = bVec::size();

  // step 1: get max (and stash the float-converted inputs in `out`)
  fVec max_fvec = fVec(-std::numeric_limits<float>::infinity());
  if constexpr (SIZE < kVecSize) {
    // Partial vector: only x_fvec0 (the first float vector) is used, x_fvec1 is ignored.
    static_assert(
        SIZE <= fVec::size(), "softmax: a SIZE below the reduced-precision vector width must fit in one float vector");
    bVec x_bvec = bVec::loadu(input, SIZE);
    fVec x_fvec0, x_fvec1;
    std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
    // Keep the first SIZE lanes and pad the remaining lanes with -inf so they cannot win the max.
    x_fvec0 = fVec::set(max_fvec, x_fvec0, SIZE);
    max_fvec = at::vec::maximum(max_fvec, x_fvec0);
    x_fvec0.store(out, SIZE);
  } else {
    static_assert(SIZE % kVecSize == 0, "softmax: SIZE must be a multiple of the reduced-precision vector width");
    for (int d = 0; d < SIZE; d += kVecSize) {
      bVec x_bvec = bVec::loadu(input + d);
      fVec x_fvec0, x_fvec1;
      std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);

      max_fvec = at::vec::maximum(max_fvec, x_fvec0);
      max_fvec = at::vec::maximum(max_fvec, x_fvec1);
      x_fvec0.store(out + d);
      x_fvec1.store(out + d + fVec::size());
    }
  }
  float max_val = vec_reduce_max(max_fvec);
  max_fvec = fVec(max_val);

  // step 2: sum of (x - max).exp()
  fVec sum_fvec = fVec(float(0));
  if constexpr (SIZE < fVec::size()) {
    // Partial float vector: lanes >= SIZE are replaced by 0 so they do not contribute to the sum.
    fVec x_fvec = (fVec::loadu(out, SIZE) - max_fvec).exp_u20();
    x_fvec = fVec::set(sum_fvec, x_fvec, SIZE);
    sum_fvec += x_fvec;
    x_fvec.store(out, SIZE);
  } else {
    // Whole float vectors only; the same requirement applies to the loop in step 3.
    static_assert(SIZE % fVec::size() == 0, "softmax: SIZE must be a multiple of the float vector width");
    for (int d = 0; d < SIZE; d += fVec::size()) {
      fVec x_fvec = (fVec::loadu(out + d) - max_fvec).exp_u20();
      sum_fvec += x_fvec;
      x_fvec.store(out + d);
    }
  }
  float sum_val = vec_reduce_sum(sum_fvec);

  // step 3: x * (1 / sum)
  sum_fvec = fVec(1.f / sum_val);
  if constexpr (SIZE < fVec::size()) {
    fVec out_fvec = fVec::loadu(out, SIZE) * sum_fvec;
    out_fvec.store(out, SIZE);
  } else {
    for (int d = 0; d < SIZE; d += fVec::size()) {
      fVec out_fvec = fVec::loadu(out + d) * sum_fvec;
      out_fvec.store(out + d);
    }
  }
}

// Grouped top-k routing on softmax scores.
//
// For each token: take the softmax of its NUM_EXPERTS gating values, score every one of the
// `num_groups` expert groups by its largest probability, keep the `topk_group` best groups, then
// select the `topk` highest-probability experts among the kept groups. The selected probabilities
// are written to topk_weights[token * topk + j] and the expert indices to topk_ids[token * topk + j].
// With `renormalize`, the selected probabilities of each token are rescaled to sum to 1.
//
// The token range is split into chunks via at::parallel_for. Arguments are not validated here.
template <typename scalar_t, int NUM_EXPERTS>
void grouped_topk_kernel_impl(
    float* __restrict__ topk_weights,
    int32_t* __restrict__ topk_ids,
    const scalar_t* __restrict__ gating_output,
    int64_t num_tokens,
    int64_t topk,
    int64_t num_groups,
    int64_t topk_group,
    bool renormalize) {
  const int64_t group_size = NUM_EXPERTS / num_groups;

  at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
    // (score, index) pair; the index is a group id or an expert id depending on the container.
    using scored_t = std::pair<float, int32_t>;
    const auto higher_score_first = [](const scored_t& lhs, const scored_t& rhs) -> bool {
      return lhs.first > rhs.first;
    };

    // Scratch local to this chunk, reused for every token in [begin, end).
    alignas(64) float probs[NUM_EXPERTS];
    std::vector<scored_t> group_best(num_groups);
    std::vector<scored_t> candidates(topk_group * group_size);

    for (int64_t token = begin; token < end; ++token) {
      softmax<scalar_t, NUM_EXPERTS>(probs, gating_output + token * NUM_EXPERTS);

      // Score each group by its largest probability.
      for (int64_t g = 0; g < num_groups; ++g) {
        const int64_t first_expert = g * group_size;
        float best = -std::numeric_limits<float>::infinity();
        for (int64_t e = 0; e < group_size; ++e) {
          best = std::max(best, probs[first_expert + e]);
        }
        group_best[g] = {best, g};
      }

      // Move the topk_group best groups to the front.
      std::partial_sort(group_best.begin(), group_best.begin() + topk_group, group_best.end(), higher_score_first);

      // Every expert of a kept group becomes a candidate.
      for (int64_t g = 0; g < topk_group; ++g) {
        const int32_t group_idx = group_best[g].second;
        const int64_t slot_base = g * group_size;
        for (int64_t e = 0; e < group_size; ++e) {
          const int32_t expert_idx = group_idx * group_size + e;
          candidates[slot_base + e] = {probs[expert_idx], expert_idx};
        }
      }

      // Move the topk best candidates to the front.
      std::partial_sort(candidates.begin(), candidates.begin() + topk, candidates.end(), higher_score_first);

      float* weights_row = topk_weights + token * topk;
      int32_t* ids_row = topk_ids + token * topk;
      for (int64_t j = 0; j < topk; ++j) {
        weights_row[j] = candidates[j].first;
        ids_row[j] = candidates[j].second;
      }

      if (!renormalize) {
        continue;
      }

      // Rescale the selected weights so they sum to 1.
      float total = 0.f;
      for (int64_t j = 0; j < topk; ++j) {
        total += weights_row[j];
      }
      const float inv_total = 1.f / total;
      for (int64_t j = 0; j < topk; ++j) {
        weights_row[j] *= inv_total;
      }
    }
  });
}

// Element-wise logistic sigmoid: out[i] = 1 / (1 + exp(-input[i])) for i in [0, SIZE).
//
// `input` holds SIZE values of scalar_t, `out` receives SIZE floats. Each iteration loads one whole
// scalar_t vector and stores the two float vectors converted from it. There is no tail handling,
// so SIZE must be a multiple of bVec::size() (checked at compile time) to keep every load and
// store inside the SIZE-element buffers.
template <typename scalar_t, int SIZE>
inline void sigmoid(float* __restrict__ out, const scalar_t* __restrict__ input) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;

  const fVec one = fVec(1.f);

  constexpr int kVecSize = bVec::size();
  static_assert(SIZE % kVecSize == 0, "sigmoid: SIZE must be a multiple of the reduced-precision vector width");

  for (int d = 0; d < SIZE; d += kVecSize) {
    bVec x_bvec = bVec::loadu(input + d);
    fVec x_fvec0, x_fvec1;
    std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);

    x_fvec0 = one / (one + x_fvec0.neg().exp_u20());
    x_fvec1 = one / (one + x_fvec1.neg().exp_u20());

    x_fvec0.store(out + d);
    x_fvec1.store(out + d + fVec::size());
  }
}

// Adds a per-element bias to float scores: scores2[i] = scores[i] + float(bias[i]) for i in [0, SIZE).
//
// `scores` and `scores2` hold SIZE floats, `bias` holds SIZE values of scalar_t. Each iteration
// handles one whole scalar_t vector of bias (two float vectors of scores). There is no tail
// handling, so SIZE must be a multiple of bVec::size() (checked at compile time).
template <typename scalar_t, int SIZE>
inline void
apply_bias(float* __restrict__ scores2, const float* __restrict__ scores, const scalar_t* __restrict__ bias) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;

  constexpr int kVecSize = bVec::size();
  static_assert(SIZE % kVecSize == 0, "apply_bias: SIZE must be a multiple of the reduced-precision vector width");

  for (int d = 0; d < SIZE; d += kVecSize) {
    bVec bias_vec = bVec::loadu(bias + d);
    fVec bias0, bias1;
    std::tie(bias0, bias1) = at::vec::convert_to_float(bias_vec);

    fVec x0 = fVec::loadu(scores + d) + bias0;
    fVec x1 = fVec::loadu(scores + d + fVec::size()) + bias1;
    x0.store(scores2 + d);
    x1.store(scores2 + d + fVec::size());
  }
}

template <typename scalar_t, int NUM_EXPERTS, int TOPK>
void biased_grouped_topk_kernel_impl(
    float* __restrict__ topk_weights,
    int32_t* __restrict__ topk_ids,
    const scalar_t* __restrict__ gating_output,
    const scalar_t* __restrict__ bias,
    int64_t num_tokens,
    int64_t num_groups,
    int64_t topk_group,
    bool renormalize) {
  using Vec = at::vec::Vectorized<float>;

  const int64_t num_experts_per_group = NUM_EXPERTS / num_groups;
  at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
    // scores: sigmoid
    alignas(64) float scores[NUM_EXPERTS];
    // scores for choice: sigmoid + bias
    alignas(64) float scores2[NUM_EXPERTS];

    using elem_t = std::pair<float, int32_t>;
    std::vector<elem_t> queue(num_groups);
    std::vector<elem_t> queue2(topk_group * num_experts_per_group);

    for (int64_t i = begin; i < end; ++i) {
      // do sigmoid to get scores
      sigmoid<scalar_t, NUM_EXPERTS>(scores, gating_output + i * NUM_EXPERTS);
      apply_bias<scalar_t, NUM_EXPERTS>(scores2, scores, bias);

      for (int64_t g = 0; g < num_groups; ++g) {
        // find the max
        float gmax = at::vec::reduce_all<float>(
            [](Vec& x, Vec& y) { return at::vec::maximum(x, y); },
            scores2 + g * num_experts_per_group,
            num_experts_per_group);

        // find position of first max,
        // note that we may have multiple max values.
        int first_max_idx = -1;
        for (int64_t e = 0; e < num_experts_per_group; ++e) {
          if (scores2[g * num_experts_per_group + e] == gmax) {
            first_max_idx = g * num_experts_per_group + e;
            break;
          }
        }

        // find the 2nd max
        scores2[first_max_idx] = -std::numeric_limits<float>::infinity();
        float gmax2 = at::vec::reduce_all<float>(
            [](Vec& x, Vec& y) { return at::vec::maximum(x, y); },
            scores2 + g * num_experts_per_group,
            num_experts_per_group);
        // restore scores for choice
        scores2[first_max_idx] = gmax;

        queue[g] = {gmax + gmax2, g};
      }

      // find group topk
      std::partial_sort(
          queue.begin(), queue.begin() + topk_group, queue.end(), [](const elem_t& x, const elem_t& y) -> bool {
            return x.first > y.first;
          });

      for (int64_t g = 0; g < topk_group; ++g) {
        int32_t group_idx = queue[g].second;
        for (int64_t e = 0; e < num_experts_per_group; ++e) {
          int32_t expert_idx = group_idx * num_experts_per_group + e;
          queue2[g * num_experts_per_group + e] = {scores2[expert_idx], expert_idx};
        }
      }

      // find global topk
      std::partial_sort(
          queue2.begin(), queue2.begin() + TOPK, queue2.end(), [](const elem_t& x, const elem_t& y) -> bool {
            return x.first > y.first;
          });

      for (int j = 0; j < TOPK; ++j) {
        int32_t index = queue2[j].second;
        topk_ids[i * TOPK + j] = index;
        topk_weights[i * TOPK + j] = scores[index];
      }

#if defined(CPU_CAPABILITY_AVX512)
      if (renormalize) {
        __mmask16 mask = (1ULL << TOPK) - 1;
        __m512 x = _mm512_maskz_loadu_ps(mask, topk_weights + i * TOPK);
        float sum = _mm512_reduce_add_ps(x);
        __m512 vscale = _mm512_set1_ps(1.f / sum);
        __m512 y = _mm512_mul_ps(x, vscale);
        _mm512_mask_storeu_ps(topk_weights + i * TOPK, mask, y);
      }
#else
      if (renormalize) {
        float sum = 0.f;
        for (int64_t j = 0; j < TOPK; ++j) {
          sum += topk_weights[i * TOPK + j];
        }
        float scale = 1.f / sum;
        for (int64_t j = 0; j < TOPK; ++j) {
          topk_weights[i * TOPK + j] *= scale;
        }
      }
#endif
    }
  });
}

// Expands to a call of grouped_topk_kernel_impl<scalar_t, NE>, where NE is the compile-time
// number of experts. The expansion refers to `scalar_t`, `topk_weights`, `topk_ids`,
// `gating_output`, `num_tokens`, `topk`, `num_expert_group`, `topk_group` and `renormalize`,
// all of which must be in scope at the point of use. The expansion already ends with ';'.
#define LAUNCH_GROUPED_TOPK_KERNEL(NE)    \
  grouped_topk_kernel_impl<scalar_t, NE>( \
      topk_weights.data_ptr<float>(),     \
      topk_ids.data_ptr<int32_t>(),       \
      gating_output.data_ptr<scalar_t>(), \
      num_tokens,                         \
      topk,                               \
      num_expert_group,                   \
      topk_group,                         \
      renormalize);

// Expands to a call of biased_grouped_topk_kernel_impl<scalar_t, NE, NTOPK>, where NE is the
// compile-time number of experts and NTOPK the compile-time number of experts selected per token.
// The expansion refers to `scalar_t`, `topk_weights`, `topk_ids`, `gating_output`,
// `correction_bias`, `num_tokens`, `num_expert_group`, `topk_group` and `renormalize`, all of
// which must be in scope at the point of use. The expansion already ends with ';'.
#define LAUNCH_BIASED_GROUPED_TOPK_KERNEL(NE, NTOPK)    \
  biased_grouped_topk_kernel_impl<scalar_t, NE, NTOPK>( \
      topk_weights.data_ptr<float>(),                   \
      topk_ids.data_ptr<int32_t>(),                     \
      gating_output.data_ptr<scalar_t>(),               \
      correction_bias.data_ptr<scalar_t>(),             \
      num_tokens,                                       \
      num_expert_group,                                 \
      topk_group,                                       \
      renormalize);

}  // anonymous namespace

// grouped topk for DeepSeek V2
//
// Arguments:
//   hidden_states    - only its first dimension (num_tokens), dtype and tensor options are used
//   gating_output    - [num_tokens, num_experts] router outputs, same dtype as hidden_states
//   topk             - number of experts selected per token
//   renormalize      - rescale the selected weights of each token to sum to 1
//   num_expert_group - number of groups the experts are split into (must divide num_experts)
//   topk_group       - number of groups kept per token before the expert-level selection
//
// Returns (topk_weights [num_tokens, topk] float32, topk_ids [num_tokens, topk] int32).
//
// Fails via TORCH_CHECK on a shape mismatch, an unsupported num_experts, or group/topk settings
// that would make the kernel divide by zero or sort past the end of its per-token buffers.
std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
    at::Tensor& hidden_states,
    at::Tensor& gating_output,
    int64_t topk,
    bool renormalize,
    int64_t num_expert_group,
    int64_t topk_group) {
  RECORD_FUNCTION("sgl-kernel::grouped_topk_cpu", std::vector<c10::IValue>({hidden_states, gating_output}));
  CHECK_INPUT(gating_output);

  const auto st = hidden_states.scalar_type();
  CHECK_EQ(gating_output.scalar_type(), st);

  TORCH_CHECK(hidden_states.dim() >= 1, "hidden_states must have at least 1 dimension");
  TORCH_CHECK(
      gating_output.dim() == 2, "gating_output must be 2D [num_tokens, num_experts], got ", gating_output.dim(), "D");

  int64_t num_tokens = hidden_states.size(0);
  int64_t num_experts = gating_output.size(1);
  TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch");

  // The kernel computes num_experts / num_expert_group and runs std::partial_sort with
  // `topk_group` / `topk` as the middle offset; reject values that would divide by zero, drop
  // trailing experts, or place the middle iterator past the end of the sorted range.
  TORCH_CHECK(num_expert_group > 0, "num_expert_group must be positive, got ", num_expert_group);
  TORCH_CHECK(
      num_experts % num_expert_group == 0,
      "num_experts (",
      num_experts,
      ") must be divisible by num_expert_group (",
      num_expert_group,
      ")");
  TORCH_CHECK(
      topk_group >= 0 && topk_group <= num_expert_group,
      "topk_group must be in [0, num_expert_group], got topk_group=",
      topk_group,
      ", num_expert_group=",
      num_expert_group);
  const int64_t num_experts_per_group = num_experts / num_expert_group;
  TORCH_CHECK(
      topk >= 0 && topk <= topk_group * num_experts_per_group,
      "topk must be in [0, topk_group * num_experts / num_expert_group], got topk=",
      topk,
      ", limit=",
      topk_group * num_experts_per_group);

  at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
  at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));

  // The expert count is a template parameter of the kernel, so only these sizes are instantiated.
  AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "grouped_topk_kernel", [&] {
    switch (num_experts) {
      case 1:
        LAUNCH_GROUPED_TOPK_KERNEL(1);
        break;
      case 2:
        LAUNCH_GROUPED_TOPK_KERNEL(2);
        break;
      case 4:
        LAUNCH_GROUPED_TOPK_KERNEL(4);
        break;
      case 8:
        LAUNCH_GROUPED_TOPK_KERNEL(8);
        break;
      case 16:
        LAUNCH_GROUPED_TOPK_KERNEL(16);
        break;
      case 32:
        LAUNCH_GROUPED_TOPK_KERNEL(32);
        break;
      case 64:
        LAUNCH_GROUPED_TOPK_KERNEL(64);
        break;
      case 128:
        LAUNCH_GROUPED_TOPK_KERNEL(128);
        break;
      case 160:
        LAUNCH_GROUPED_TOPK_KERNEL(160);
        break;
      case 256:
        LAUNCH_GROUPED_TOPK_KERNEL(256);
        break;
      default:
        TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
    }
  });
  return std::make_tuple(topk_weights, topk_ids);
}

// biased grouped topk DeepSeek V3/R1
//
// Arguments:
//   hidden_states    - only its first dimension (num_tokens), dtype and tensor options are used
//   gating_output    - [num_tokens, num_experts] router outputs, same dtype as hidden_states
//   correction_bias  - num_experts bias values, same dtype as hidden_states; added to the sigmoid
//                      scores for selection only
//   topk             - number of experts selected per token (currently must be 8)
//   renormalize      - rescale the selected weights of each token to sum to 1
//   num_expert_group - number of groups the experts are split into (must divide num_experts)
//   topk_group       - number of groups kept per token before the expert-level selection
//
// Returns (topk_weights [num_tokens, topk] float32, topk_ids [num_tokens, topk] int32).
//
// Fails via TORCH_CHECK on a shape mismatch, an unsupported topk / num_experts, or group/topk
// settings that would make the kernel divide by zero or sort past the end of its per-token buffers.
std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
    at::Tensor& hidden_states,
    at::Tensor& gating_output,
    at::Tensor& correction_bias,
    int64_t topk,
    bool renormalize,
    int64_t num_expert_group,
    int64_t topk_group) {
  RECORD_FUNCTION(
      "sgl-kernel::biased_grouped_topk_cpu", std::vector<c10::IValue>({hidden_states, gating_output, correction_bias}));

  CHECK_INPUT(gating_output);
  CHECK_INPUT(correction_bias);

  const auto st = hidden_states.scalar_type();
  CHECK_EQ(gating_output.scalar_type(), st);
  CHECK_EQ(correction_bias.scalar_type(), st);

  TORCH_CHECK(hidden_states.dim() >= 1, "hidden_states must have at least 1 dimension");
  TORCH_CHECK(
      gating_output.dim() == 2, "gating_output must be 2D [num_tokens, num_experts], got ", gating_output.dim(), "D");

  int64_t num_tokens = hidden_states.size(0);
  int64_t num_experts = gating_output.size(1);
  TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch");
  TORCH_CHECK(correction_bias.numel() == num_experts, "Bias shape mismatch");

  // The kernel computes num_experts / num_expert_group and runs std::partial_sort with
  // `topk_group` / `topk` as the middle offset; reject values that would divide by zero, drop
  // trailing experts, or place the middle iterator past the end of the sorted range.
  TORCH_CHECK(num_expert_group > 0, "num_expert_group must be positive, got ", num_expert_group);
  TORCH_CHECK(
      num_experts % num_expert_group == 0,
      "num_experts (",
      num_experts,
      ") must be divisible by num_expert_group (",
      num_expert_group,
      ")");
  TORCH_CHECK(
      topk_group >= 0 && topk_group <= num_expert_group,
      "topk_group must be in [0, num_expert_group], got topk_group=",
      topk_group,
      ", num_expert_group=",
      num_expert_group);
  const int64_t num_experts_per_group = num_experts / num_expert_group;
  TORCH_CHECK(
      topk >= 0 && topk <= topk_group * num_experts_per_group,
      "topk must be in [0, topk_group * num_experts / num_expert_group], got topk=",
      topk,
      ", limit=",
      topk_group * num_experts_per_group);

  at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
  at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));

  AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "biased_grouped_topk_kernel", [&] {
    // NOW only support DSv3 configs
    TORCH_CHECK(topk == 8, "Unexpected topk: ", topk);
    switch (num_experts) {
      case 256:
        LAUNCH_BIASED_GROUPED_TOPK_KERNEL(256, 8);
        break;
      default:
        TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
    }
  });
  return std::make_tuple(topk_weights, topk_ids);
}
