/*
 * Copyright (C) SSQR Kernel.2025 Elvir Crncevic (elvircrn@gmail.com)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "common.cuh"
#include <ATen/cuda/CUDAContext.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <memory>
#include <torch/python.h>
#include <torch/script.h> // One-stop header.


namespace matryoshka {

// Re-arranges the packed bit-plane buffers into the nibble-interleaved words stored in v*.
//
// For every index in [0, buffer0_size):
//   - buffer0[i] (u64) is split into its low/high 32-bit halves, each half is shuffled and
//     written to v0_lower[i] / v0_higher[i];
//   - when bits >= 3, buffer1[i] is shuffled into v1[i];
//   - when bits >= 4, buffer2[i] is shuffled into v2[i];
//   - when bits >= 6, buffer3[i] is handled like buffer0 into v3_lower / v3_higher;
//   - when bits >= 8, buffer4[i] is handled like buffer0 into v4_lower / v4_higher.
// Buffers belonging to a disabled bit level are never dereferenced.
void matryoshka_swizzle(u32 bits, int buffer0_size, const u64 *buffer0,
                        const u32 *buffer1, const u32 *buffer2, const u64 *buffer3, const u64 *buffer4, u32 *v0_lower,
                        u32 *v0_higher, u32 *v1, u32 *v2, u32 *v3_lower, u32 *v3_higher, u32 *v4_lower, u32 *v4_higher) {
  // Shuffle for the 2-bit planes: the k-th 2-bit field of the low 16 bits lands in bits
  // [4k, 4k+1] of the result, the k-th field of the high 16 bits lands in bits [4k+2, 4k+3].
  auto spread_pairs = [](u32 word) -> u32 {
    u32 packed = 0;
    for (u32 k = 0; k < 8u; ++k) {
      const u32 low_pair = (word >> (2u * k)) & 0b11u;
      const u32 high_pair = (word >> (16u + 2u * k)) & 0b11u;
      packed |= (low_pair << (4u * k)) | (high_pair << (4u * k + 2u));
    }
    return packed;
  };

  // Shuffle for the 1-bit planes: bit k of input byte B (B = 0 is the least significant byte)
  // lands in bit 4k + (3 - B) of the result, so each output nibble collects one bit per byte.
  auto gather_bytes = [](u32 word) {
    u32 packed = 0;
    for (u32 byte = 0; byte < 4u; ++byte) {
      for (u32 bit = 0; bit < 8u; ++bit) {
        const u32 src = 8u * byte + bit;
        const u32 dst = 4u * bit + (3u - byte);
        packed |= ((word >> src) & 1u) << dst;
      }
    }
    return packed;
  };

  // The loop index is unsigned, so the size is compared as an unsigned value.
  const u32 count = static_cast<u32>(buffer0_size);
  for (u32 idx = 0; idx < count; ++idx) {
    const u64 plane0 = buffer0[idx];
    v0_lower[idx] = spread_pairs(static_cast<u32>(plane0));
    v0_higher[idx] = spread_pairs(static_cast<u32>(plane0 >> 32ull));
    if (bits < 3) {
      continue;
    }

    v1[idx] = gather_bytes(buffer1[idx]);
    if (bits < 4) {
      continue;
    }

    v2[idx] = gather_bytes(buffer2[idx]);
    if (bits < 6) {
      continue;
    }

    const u64 plane3 = buffer3[idx];
    v3_lower[idx] = spread_pairs(static_cast<u32>(plane3));
    v3_higher[idx] = spread_pairs(static_cast<u32>(plane3 >> 32ull));
    if (bits < 8) {
      continue;
    }

    const u64 plane4 = buffer4[idx];
    v4_lower[idx] = spread_pairs(static_cast<u32>(plane4));
    v4_higher[idx] = spread_pairs(static_cast<u32>(plane4 >> 32ull));
  }
}

// Copies quantization scales into the order written to `reordered_scales`.
//
// group_size == -1: the first m entries of `scales` are copied unchanged.
// Otherwise `scales` is indexed as a row-major [m x (n / group_size)] array and the output
// is emitted tile by tile: for every block of 16 rows, and for every group column, the 16
// row values of that column are written contiguously. Rows i..i+15 are read unconditionally,
// so m is presumably a multiple of 16 - confirm against the caller.
void shuffle_scales(int m, int n, int group_size, half* scales, half* reordered_scales) {
  constexpr int kTileRows = 16;

  if (group_size == -1) {
    // TODO: Also reorder this case, but later.
    for (int row = 0; row < m; ++row) {
      reordered_scales[row] = scales[row];
    }
    return;
  }

  int out_idx = 0;
  for (int tile_row = 0; tile_row < m; tile_row += kTileRows) {
    // Computed inside the loop on purpose: nothing is divided when there are no rows.
    const int groups_per_row = n / group_size;
    for (int group = 0; group < groups_per_row; ++group) {
      for (int lane = 0; lane < kTileRows; ++lane) {
        reordered_scales[out_idx++] = scales[(tile_row + lane) * groups_per_row + group];
      }
    }
  }
}

// Splits per-weight bytes into packed bit-plane buffers.
//
// The input `___w` holds one byte per weight and is consumed in groups of 4 * 16 * 16 bytes
// (four 16x16 tiles). For each group and each of 32 lane ids, 8 consecutive bytes are read
// from each of the 4 tiles, giving 32 values per lane. Value number `slot = t * 8 + j`
// (tile t, byte j) contributes to the current output word of each buffer as follows:
//   buffer0 (u64): value bits 0-1, stored at bit offset 2 * slot            (always)
//   buffer1 (u32): value bit 2,    stored at bit offset slot                (always)
//   buffer2 (u32): value bit 3,    stored at bit offset slot                (bits >= 4)
//   buffer3 (u64): value bits 4-5, stored at bit offset 2 * slot            (bits >= 6)
//   buffer4 (u64): value bits 6-7, stored at bit offset 2 * slot            (bits >= 8)
// After each lane every enabled buffer pointer advances by one word.
//
// Bits are OR-ed into the outputs, so all buffers must be zeroed before calling.
// `is_bf16` is accepted for interface compatibility and is not used here.
void _split_up_n_bits(const u8 *___w, u32 bits, int m, int n, u64 *buffer0, u32 *buffer1, u32 *buffer2, u64 *buffer3, u64 *buffer4, int is_bf16) {
  constexpr u64 kTileElems = 16 * 16;
  constexpr u64 kTilesPerGroup = 4;
  constexpr u64 kValuesPerLane = 8;

  // 64-bit arithmetic so that m * n and the byte offsets below cannot overflow a 32-bit int.
  const u64 num_groups = (static_cast<u64>(m) * static_cast<u64>(n)) / (kTilesPerGroup * kTileElems);

  for (u64 group = 0; group < num_groups; group++) {
    // We organize data lane_id-wise.
    // This needs to be done on a per-gpu basis.
    const u8 *group_base = ___w + group * kTilesPerGroup * kTileElems;
    for (int lane_id = 0; lane_id < 32; lane_id++) {
      // Each lane id covers 8 consecutive values of a tile.
      const u8 *lane_base = group_base + lane_id * 8;
      // A single tensor core warp tile load is represented with 16 bits of b0, b1 and b2
      // ... but we need four of these to fill up the 1-bit buffers.
      for (u64 t = 0; t < kTilesPerGroup; t++) {
        const u8 *w = lane_base + t * kTileElems;
        for (u64 j = 0; j < kValuesPerLane; j++) {
          const u64 val = *(w++);
          const u64 slot = t * kValuesPerLane + j;  // 0..31

          const u64 bits01 = val & 0b11ull;
          const u32 bit2 = static_cast<u32>((val >> 2u) & 0b1ull);
          (*buffer0) |= (bits01 << (slot * 2));
          (*buffer1) |= (bit2 << slot);
          if (bits >= 4) {
            const u32 bit3 = static_cast<u32>((val >> 3u) & 0b1ull);
            (*buffer2) |= (bit3 << slot);
            if (bits >= 6) {
              // The 2-bit fields must be 64-bit before shifting: slot * 2 reaches 62, and
              // shifting a 32-bit value by 32 or more is undefined and drops the upper half.
              const u64 bits45 = (val >> 4u) & 0b11ull;
              (*buffer3) |= (bits45 << (slot * 2));
              if (bits >= 8) {
                const u64 bits67 = (val >> 6u) & 0b11ull;
                (*buffer4) |= (bits67 << (slot * 2));
              }
            }
          }
        }
      }
      buffer0++;
      buffer1++;
      if (bits >= 4) buffer2++;
      if (bits >= 6) buffer3++;
      if (bits >= 8) buffer4++;
    }
  }
}
}


namespace matryoshka {
// Forward declaration of the batched matmul entry point; no definition is visible in this
// file. It is called below from gptq_matryoshka_mul_batched with raw tensor data pointers
// and the current CUDA stream.
//
// The int return value is ignored by that caller; presumably it is a status code (see the
// "propagate error" TODO at the call site) - confirm against the definition.
// The trailing parameters default to a null stream, no measurements buffer and a
// zero-initialized Features value.
int gptq_matryoshka_matmul_batched(
    // W and meta
    int bits, int prob_m, int prob_n, int prob_k, int group_size,
    // Quantization
    const void *v_buffer, const void *scales,
    // 32-bit
    int row_offsets_len, void *row_offsets,
    // 32-bit
    void *col_vals, int nnz,
    // 16-bit
    // Input
    void *X,
    // Output
    void *y,
    // GPU meta
    cudaStream_t stream = nullptr, void *measurements = nullptr, Features features = Features{._ = 0u});
}

// Converts an integer to half precision: the value is first cast to float and that float is
// then narrowed to __half with the round-down conversion __float2half_rd.
__half int2half_rd(const int value) {
  return __float2half_rd(static_cast<float>(value));
}

// Host-side dequantization: returns s * (q - z), with q converted to Scalar_t first.
// For Scalar_t == half the quantized value goes through int2half_rd (int -> float -> half,
// rounding down); for every other scalar type a plain Scalar_t(q) conversion is used.
template <class Bit_t, class Scalar_t> Scalar_t host_dequantize(Bit_t q, Scalar_t s, Scalar_t z) {
  // TODO: Clean up these ifs.
  if constexpr (std::is_same_v<Scalar_t, half>) {
    return s * (int2half_rd(static_cast<const int>(q)) - z);
  } else {
    return s * (Scalar_t(q) - z);
  }
}

// Non-owning 2D view over a flat array, addressed in row-major order.
template <class Weight_t> struct Weights2D {
  int m;        // Not read by the accessor; presumably the row count.
  int n;        // Row stride used by operator(): element (i, j) lives at index i * n + j.
  Weight_t *w;  // Pointer to the first element; the view never allocates or frees it.

  // Returns a mutable reference to element (i, j). No bounds checking is performed.
  Weight_t &operator()(int i, int j) {
    const int flat_index = i * n + j;
    return w[flat_index];
  }
};


// Torch-facing wrapper around matryoshka::gptq_matryoshka_matmul_batched.
//
// Forwards the problem sizes, the raw data pointers of the tensors and the current CUDA
// stream of v_buffer's device. When nnz is zero, the row-offset and column-value pointers
// are passed as nullptr and the row-offset length as 0. The result pointer handed to the
// callee is `out`; `Y` is not used by this wrapper. The callee's int return value is
// currently discarded.
#pragma torch_expose
void gptq_matryoshka_mul_batched(s64 m, s64 n, s64 k, s64 bits, int group_size, const torch::Tensor &v_buffer, const torch::Tensor &scales,
                                 const torch::Tensor &row_offsets, const torch::Tensor &col_val_ptr, s64 nnz, const torch::Tensor &X,
                                 s64 _feature_flag, const torch::Tensor &Y, torch::Tensor &out) {
  const u32 feature_bits = static_cast<u32>(_feature_flag);
  const int device_index = v_buffer.get_device();

  // The sparse inputs are only touched when there is at least one non-zero entry.
  void *sparse_row_offsets = nullptr;
  void *sparse_col_vals = nullptr;
  s64 sparse_row_offsets_len = 0;
  if (nnz) {
    sparse_row_offsets = row_offsets.data_ptr();
    sparse_col_vals = col_val_ptr.data_ptr();
    sparse_row_offsets_len = row_offsets.sizes()[0];
  }

  // No measurements buffer is collected from this entry point.
  void *no_measurements = nullptr;

  // TODO: Propagate error one layer up.
  matryoshka::gptq_matryoshka_matmul_batched(
      bits, m, n, k, group_size,
      v_buffer.data_ptr(), scales.data_ptr(),
      sparse_row_offsets_len, sparse_row_offsets,
      sparse_col_vals, nnz,
      X.data_ptr(),
      out.data_ptr(),
      at::cuda::getCurrentCUDAStream(device_index), no_measurements, Features{._ = feature_bits});
}

// Builds the per-tile offsets of the padded sparse layout from CSR row offsets.
//
// Rows are grouped into tiles of 16. Each tile occupies 16 * (longest row in the tile)
// entries, and new_row_offsets[t + 1] receives the running total after tile t, so
// new_row_offsets[t] is where tile t starts.
//
//   m                number of rows; h_row_offsets must hold m + 1 entries.
//   h_row_offsets    CSR row offsets (row r spans [h_row_offsets[r], h_row_offsets[r + 1])).
//   new_row_offsets  output with one entry per tile plus one.
//
// The leading entry is written as 0 explicitly so the result does not depend on the caller
// having zero-initialized the output. When m is not a multiple of 16, the last tile only
// inspects the rows that exist instead of reading past the end of h_row_offsets.
// Nothing is written when m <= 0.
void populate_row_offsets(int m, const int *h_row_offsets, int *new_row_offsets) {
  constexpr int kTileRows = 16;
  if (m <= 0) {
    return;
  }

  new_row_offsets[0] = 0;
  int total = 0;
  for (int tile_start = 0; tile_start < m; tile_start += kTileRows) {
    // Clamp the final, possibly partial, tile to the rows that actually exist.
    const int tile_end = (m - tile_start < kTileRows) ? m : tile_start + kTileRows;
    int longest_row = 0;
    for (int row = tile_start; row < tile_end; row++) {
      longest_row = std::max(longest_row, h_row_offsets[row + 1] - h_row_offsets[row]);
    }
    // Every tile is padded to a full 16 lanes per round.
    total += longest_row * kTileRows;
    new_row_offsets[tile_start / kTileRows + 1] = total;
  }
}

// Re-packs CSR column/value entries into the padded, tile-interleaved layout.
//
// The first `cnt` output entries are cleared to zero. Then, for every tile of 16 rows, the
// output region starting at new_col_vals[new_row_offsets[tile]] is filled round by round:
// round r holds 16 consecutive slots, slot j receiving the r-th entry of row (tile * 16 + j)
// if that row has one, and otherwise keeping its zero padding. The number of rounds equals
// the length of the longest row in the tile.
//
// Rows i..i+15 are read unconditionally for each tile, so m is presumably a multiple of 16
// - confirm against the caller.
void convert_to_ptcsr_v2(int m, int cnt, const ColVal *h_col_vals, const int *h_row_offsets, int *new_row_offsets, ColVal *new_col_vals) {
  constexpr int kTileRows = 16;

  // Start from an all-zero output so that unused slots act as padding.
  for (int idx = 0; idx < cnt; ++idx) {
    new_col_vals[idx]._ = 0;
  }

  for (int tile_row = 0; tile_row < m; tile_row += kTileRows) {
    ColVal *dst = new_col_vals + new_row_offsets[tile_row / kTileRows];

    // The longest row of the tile determines how many 16-slot rounds are emitted.
    int longest_row = 0;
    for (int lane = 0; lane < kTileRows; ++lane) {
      const int row_len = h_row_offsets[tile_row + lane + 1] - h_row_offsets[tile_row + lane];
      if (row_len > longest_row) {
        longest_row = row_len;
      }
    }

    for (int round = 0; round < longest_row; ++round) {
      for (int lane = 0; lane < kTileRows; ++lane, ++dst) {
        const int row_begin = h_row_offsets[tile_row + lane];
        const int row_end = h_row_offsets[tile_row + lane + 1];
        // Rows that are already exhausted leave their slot as zero padding.
        if (row_begin + round < row_end) {
          *dst = h_col_vals[row_begin + round];
        }
      }
    }
  }
}


// Packs raw per-weight bytes into the layout stored in `v`, and prepares the companion
// metadata (tile row offsets and reordered scales).
//
// Steps, in order:
//   1. swizzle_weights writes a reordered copy of w_bits into a temporary m * n byte buffer.
//   2. _split_up_n_bits splits that copy into the bit-plane buffers buffer0..buffer4
//      (it ORs bits in, so those tensors must be zero-filled by the caller).
//   3. populate_row_offsets derives row_offsets_v2 from csr_row_offsets.
//   4. matryoshka_swizzle writes the shuffled planes into `v`, which is treated as eight
//      consecutive segments of (m * n) / 32 u32 words each:
//      v0_lower, v0_higher, v1, v2, v3_lower, v3_higher, v4_lower, v4_higher.
//   5. shuffle_scales fills reordered_scales from scales.
//
// All tensors are accessed through data_ptr() on the host, so they are presumably CPU
// tensors - confirm against the Python caller.
#pragma torch_expose
void matryoshka_pack(int bits, int m, int n, const torch::Tensor &w_bits, const torch::Tensor &buffer0, const torch::Tensor &buffer1, const torch::Tensor &buffer2, const torch::Tensor &buffer3, const torch::Tensor &buffer4, int group_size, const torch::Tensor &scales, const torch::Tensor &reordered_scales,
  const torch::Tensor &v, int is_bf16, const torch::Tensor &csr_row_offsets, const torch::Tensor &row_offsets_v2) {
  // Owned by a unique_ptr so the scratch buffer is released even if one of the calls below
  // throws (for example a failing Tensor::data_ptr()).
  std::unique_ptr<u8[]> w_swizzled(new u8[m * n]);
  swizzle_weights(m, n, (u8*) w_bits.data_ptr(), w_swizzled.get(), is_bf16, bits);

  u64 *plane0 = (u64 *) buffer0.data_ptr();
  u32 *plane1 = (u32 *) buffer1.data_ptr();
  u32 *plane2 = (u32 *) buffer2.data_ptr();
  u64 *plane3 = (u64 *) buffer3.data_ptr();
  u64 *plane4 = (u64 *) buffer4.data_ptr();
  matryoshka::_split_up_n_bits(w_swizzled.get(), bits, m, n, plane0, plane1, plane2, plane3, plane4, is_bf16);

  // Number of u32 words in each of the eight output segments of `v`.
  int offset = (m * n) / 32;
  u32 *v_base = (u32*) v.data_ptr();

  populate_row_offsets(m, (int*) csr_row_offsets.data_ptr(), (int*) row_offsets_v2.data_ptr());

  matryoshka::matryoshka_swizzle(bits, offset, plane0, plane1, plane2, plane3, plane4,
    v_base, v_base + offset, v_base + 2 * offset, v_base + 3 * offset, v_base + 4 * offset, v_base + 5 * offset, v_base + 6 * offset, v_base + 7 * offset);
  matryoshka::shuffle_scales(m, n, group_size, (half*) scales.data_ptr(), (half*) reordered_scales.data_ptr());
}

// Torch-facing wrapper around convert_to_ptcsr_v2.
//
// Reads the total number of padded entries from row_offsets_v2[m / 16] and re-packs
// csr_col_vals into new_col_vals accordingly. All tensors are accessed through data_ptr()
// on the host; row_offsets_v2 must already have been filled in.
#pragma torch_expose
void post_process_sparsity(int m, const torch::Tensor &csr_row_offsets, const torch::Tensor &csr_col_vals, const torch::Tensor &row_offsets_v2,
  const torch::Tensor &new_col_vals) {
  int *tile_offsets = static_cast<int *>(row_offsets_v2.data_ptr());
  const int *csr_offsets = static_cast<const int *>(csr_row_offsets.data_ptr());
  const ColVal *csr_entries = static_cast<const ColVal *>(csr_col_vals.data_ptr());
  ColVal *padded_entries = static_cast<ColVal *>(new_col_vals.data_ptr());

  // The offset stored at index m / 16 is passed on as the entry count to clear and fill.
  const int padded_count = tile_offsets[m / 16];
  convert_to_ptcsr_v2(m, padded_count, csr_entries, csr_offsets, tile_offsets, padded_entries);
}


// Expands TORCH_EXPOSE_DEFINITIONS here when the build defines it. Presumably this macro
// carries generated binding code for the functions marked `#pragma torch_expose` above -
// confirm against the build tooling.
#ifdef TORCH_EXPOSE_DEFINITIONS
TORCH_EXPOSE_DEFINITIONS
#endif
