/*
 * Copyright (C) SSQR Kernel.2025 Elvir Crncevic (elvircrn@gmail.com)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>

#include <climits>
#include <cstdio>
#include <cstdlib>
#include <type_traits>
#include <utility>

// Shorthand for device-only helpers that must always be inlined.
#define DEVICE_INLINE __forceinline__ __device__

// Fallbacks: if a CUDA qualifier is not already defined as a macro, define it
// as empty so declarations that use it still parse.
// NOTE(review): presumably meant for translation units built without the CUDA
// headers providing these macros — confirm.
#ifndef __host__
#define __host__
#endif

#ifndef __device__
#define __device__
#endif

#ifndef __forceinline__
#define __forceinline__
#endif

// Short aliases for the built-in integer types. The numeric suffix names the
// intended width; nothing in this file enforces it with a static_assert.
using u64 = unsigned long long;
using s32 = int;
using s64 = long long int;
using u32 = unsigned int;
using u8 = unsigned char;
using u16 = unsigned short;

// Width in bits of the field that swizzle_weights_3bit extracts as `w2`; it is
// also the shift applied to `w1` when the first output word is assembled.
static constexpr u64 SECOND_ORDER_FRAGMENT_SIZE_BITS = 8ull;

/**
 * Returns the storage width of `T` in bits.
 *
 * The width is derived from `sizeof(T)`, so `int`/`unsigned int` yield 32 and
 * `long long`/`unsigned long long` yield 64 as before, while narrower types
 * such as `unsigned char` or `unsigned short` now report their real width
 * instead of 64.
 *
 * @tparam T type whose width is queried.
 * @return number of bits occupied by one object of type `T`.
 */
template<class T> __host__ __device__ constexpr int get_bits() {
  return static_cast<int>(sizeof(T) * CHAR_BIT);
}

/**
 * Sequential writer that packs BITS-wide values into consecutive `Bit_t` words.
 *
 * `ptr` keeps the address handed to the constructor; the private write cursor
 * advances by one word for every word emitted by push().
 */
template<class Bit_t, unsigned BITS> struct TileArray {
private:
  // Next word to be written.
  Bit_t *buffer{};

public:
  // Address originally passed to the constructor (never advanced here).
  Bit_t *ptr{};

  explicit TileArray(Bit_t *_ptr) : buffer(_ptr), ptr(_ptr) {}

  /**
   * Appends `n` values from `x` to the stream.
   *
   * The first emitted word carries `s` in slot 0 and `z` in slot 1, so the
   * values start at slot 2 of that word; every following word is filled from
   * slot 0. `buff` is OR-ed into each emitted word. Nothing is written when
   * `n <= 0`.
   *
   * @param s    value placed in the lowest slot of the first word.
   * @param z    value placed in the second slot of the first word.
   * @param x    array of at least `n` values to pack.
   * @param n    number of values to take from `x`.
   * @param buff extra bits OR-ed into every word that is written.
   */
  void push(Bit_t s, Bit_t z, Bit_t *x, int n, Bit_t buff = Bit_t{}) {
    constexpr int VALUES_PER_ADDR = get_bits<Bit_t>() / BITS;

    Bit_t word = s | (z << BITS);
    int slot = 2;
    int consumed = 0;

    while (consumed < n) {
      // Fill the remaining slots of the current word.
      while (slot < VALUES_PER_ADDR && consumed < n) {
        word |= (x[consumed] << (slot * BITS));
        ++slot;
        ++consumed;
      }

      *buffer++ = word | buff;

      // Subsequent words start empty at slot 0.
      word = 0;
      slot = 0;
    }
  }
};

/**
 * Read-only view over words that each hold 18 fields of `bits` bits.
 *
 * `out` is initialised to nullptr and is not used by any member of this struct.
 */
template<class Bit_t> struct _BitArray {
  Bit_t *w{};
  const int bits;
  Bit_t *out;

  _BitArray(Bit_t *w, int bits) : w(w), bits(bits), out(nullptr) {}

  /**
   * Extracts field number `w_id`.
   *
   * @param w_id flat field index; word `w_id / 18` is read and the field at
   *             position `w_id % 18` inside it is returned.
   * @return the `bits`-wide value, right-aligned.
   */
  __host__ __device__ Bit_t operator[](int w_id) {
    constexpr int kFieldsPerWord = 18;

    const int word_index = w_id / kFieldsPerWord;
    const unsigned long long field_in_word = w_id % static_cast<unsigned long long>(kFieldsPerWord);
    const unsigned long long shift = field_in_word * bits;
    const unsigned long long mask = (1ull << bits) - 1ull;

    return (w[word_index] >> shift) & mask;
  }
};

// Record that can be accessed either as one raw `u32` word (`_`) or as an
// `unsigned short` field `c` followed by a `half` field `v`.
// NOTE(review): `c`/`v` presumably stand for column index and value, since the
// type is used for the `col_vals` arrays of SparsityV2 — confirm against the
// kernels.
union ColVal {
  // Raw view; to_sparsity_v2_buggy clears entries through this member.
  u32 _;

  struct {
    unsigned short c;
    half v;
  } members;
};

// Same layout idea as ColVal, but with the value type supplied as a template
// parameter: a raw `u32` view (`_`) overlaid on an `unsigned short c` followed
// by a `T v`.
// NOTE(review): the raw view only spans the whole record when
// sizeof(members) == sizeof(u32); nothing here asserts that for a given T.
template<class T>
union ColValT {
  u32 _;

  struct {
    unsigned short c;
    T v;
  } members;
};

// Two `half2` values (`ss`, `zz`) overlaid on a single raw `u64` word (`v`).
// NOTE(review): judging by the accessor names, `ss`/`zz` presumably hold
// second-order scale and zero statistics — confirm against the kernels.
union SecondOrder {
  // Raw 64-bit view of the two half2 members.
  u64 v;

  struct SO {
    half2 ss;
    half2 zz;
  } members;

  // Returns the `ss` pair.
  __device__ __forceinline__ half2 get_sws2() const { return members.ss; }

  // Returns the `zz` pair.
  __device__ __forceinline__ half2 get_swz2() const { return members.zz; }
};

// Evaluates a CUDA runtime call once and terminates the process with exit
// code 1 if it did not return cudaSuccess, after printing the file, line,
// error string and numeric code to stderr.
//
// The body is wrapped in do { ... } while (0) so the macro behaves like a
// single statement (safe inside an unbraced if/else), and the temporary has a
// macro-specific name so that an argument mentioning a caller variable called
// `status` is not shadowed.
#define CHECK_CUDA(func)                                                       \
  do {                                                                         \
    const cudaError_t check_cuda_status_ = (func);                             \
    if (check_cuda_status_ != cudaSuccess) {                                   \
      fprintf(stderr, "CUDA API failed at %s:%d with error: %s (%d)\n",        \
              __FILE__, __LINE__, cudaGetErrorString(check_cuda_status_),      \
              static_cast<int>(check_cuda_status_));                           \
      exit(1);                                                                 \
    }                                                                          \
  } while (0)

/**
 * Measures elapsed time on a CUDA stream with a pair of events.
 *
 * Both events are created in the constructor and destroyed in the destructor.
 * The object cannot be copied or moved, so the event handles have exactly one
 * owner. Every runtime call except the ones in the destructor goes through
 * CHECK_CUDA, so a failing call aborts the process instead of silently
 * producing a meaningless timing.
 */
struct Timer {
  cudaEvent_t ce_start{}, ce_stop{};
  cudaStream_t stream;

  /** Records the start event on `stream`. */
  inline void start() { CHECK_CUDA(cudaEventRecord(ce_start, stream)); }

  /**
   * Records the stop event on `stream`, blocks until it has completed and
   * returns the time between the start and stop events.
   *
   * @return elapsed time in milliseconds.
   */
  inline float end_and_measure() {
    float time_ms{};
    CHECK_CUDA(cudaEventRecord(ce_stop, stream));
    CHECK_CUDA(cudaEventSynchronize(ce_stop));
    CHECK_CUDA(cudaEventElapsedTime(&time_ms, ce_start, ce_stop));
    return time_ms;
  }

  /**
   * Creates the two events used for the measurement.
   *
   * @param stream stream on which start/stop events will be recorded.
   */
  inline Timer(cudaStream_t stream) : stream(stream) {
    CHECK_CUDA(cudaEventCreate(&ce_start));
    CHECK_CUDA(cudaEventCreate(&ce_stop));
  }

  inline Timer(Timer &&timer) = delete;

  inline Timer(const Timer &timer) = delete;

  ~Timer() {
    // Deliberately unchecked: terminating the process from a destructor
    // (possibly during stack unwinding or shutdown) would hide the original
    // problem, so cleanup is best effort.
    (void)cudaEventDestroy(ce_start);
    (void)cudaEventDestroy(ce_stop);
  }
};

// Compile-time feature switches (0 = off, 1 = on). Within this part of the
// file only IS_PROFILER_MODE is consumed (it selects the benchmark constants
// below); the other switches are presumably read by the sources that include
// this header — TODO confirm.
#define IS_PROFILER_MODE 0
#define PERSISTENT_CACHE 0
#define ENABLE_LORA 0
#define WITH_SWIZZLE 1
#define ENABLE_SPARSITY_V2 1
#define TRANSPOSE_SPQR 0
#define ENABLE_STREAM_K 0

// Benchmark loop sizes. NOTE(review): judging by the names these are the
// total number of runs and how many of them count as warm-up — confirm with
// the benchmark driver.
#if IS_PROFILER_MODE
static constexpr int BENCHMARK_SPQR_NUM_RUNS = 16;
static constexpr int BENCHMARK_SPQR_WARMUPS = 0;
#else
static constexpr int BENCHMARK_SPQR_NUM_RUNS = 2000;
static constexpr int BENCHMARK_SPQR_WARMUPS = 1900;
#endif

// Identifies the layout used for the sparse part of a matrix. In the visible
// part of this file only PTCSR_V2 is produced (by to_sparsity_v2_buggy).
// NOTE(review): CSR presumably means plain compressed sparse row; the meaning
// of the PTCSR variants is defined by the consuming kernels — confirm there.
enum class SparseCompressionStrategy { CSR = 0, PTCSR = 1, PTCSR_V2 = 2 };

/**
 * Reorders a row-major m x n byte matrix into a tile-wise layout.
 *
 * Step 1 walks the matrix in 16x16 tiles (tile rows outermost). Inside a tile
 * it visits the upper 8 rows and, for each row, 4 column pairs; every visit
 * emits 8 bytes: the pair at (row, col), then (row + 8, col), then
 * (row, col + 8), then (row + 8, col + 8).
 *
 * Step 2 runs only when `bits < 6`: every consecutive group of 8 output bytes
 * is permuted in place. The permutation differs between fp16 and bf16; per the
 * original author it exists to enable fast int -> fp16 / int -> bf16
 * conversion.
 *
 * Preconditions (not checked): `m` and `n` are multiples of 16, `w_original`
 * holds m * n bytes, and `w_host_updated` has room for m * n bytes and does
 * not alias `w_original`.
 *
 * @param m              number of rows.
 * @param n              number of columns.
 * @param w_original     source matrix, row-major.
 * @param w_host_updated destination buffer.
 * @param is_bf16        selects the bf16 permutation in step 2 (fp16 otherwise).
 * @param bits           step 2 is skipped when this is >= 6.
 * @return the start of the destination buffer (`w_host_updated` as passed in)
 *         on every path.
 */
inline u8 *swizzle_weights(int m, int n, const u8 *w_original, u8 *w_host_updated, bool is_bf16, int bits) {
  u8 *const out_begin = w_host_updated;

  // Nothing to reorder; also keeps the element count below from wrapping.
  if (m <= 0 || n <= 0) {
    return out_begin;
  }

  const u64 row_stride = static_cast<u64>(n);

  // Step 1: gather each 16x16 tile in (row, row + 8) x (col, col + 8) order.
  u8 *out = out_begin;
  for (int tile_row = 0; tile_row < m; tile_row += 16) {
    for (int tile_col = 0; tile_col < n; tile_col += 16) {
      for (int r = 0; r < 8; r++) {
        for (int c = 0; c < 4; c++) {
          const u8 *top = w_original + row_stride * static_cast<u64>(tile_row + r) + tile_col + 2 * c;
          const u8 *bottom = top + 8 * row_stride;

          *out++ = top[0];
          *out++ = top[1];

          *out++ = bottom[0];
          *out++ = bottom[1];

          *out++ = top[8];
          *out++ = top[9];

          *out++ = bottom[8];
          *out++ = bottom[9];
        }
      }
    }
  }

  if (bits >= 6) {
    return out_begin;
  }

  // Step 2: permute every group of 8 bytes. Entry p of a table names the
  // position (within the group, before the permutation) whose byte ends up at
  // position p.
  //
  // fp16: even positions first, then odd positions.
  constexpr int kFp16Source[8] = {0, 2, 4, 6, 1, 3, 5, 7};
  // bf16: layout chosen by the original author for fast int -> bf16
  // conversion; original sketch kept for reference:
  //    7    6    5    4    3    2     1   0
  // 0b0000 0000 1111 0000 0000 0000 1111 0000
  //              1                    0
  //         3                    2
  //     4                   5
  //                    7                   6
  constexpr int kBf16Source[8] = {6, 0, 2, 4, 7, 1, 3, 5};

  const int *source = is_bf16 ? kBf16Source : kFp16Source;

  // 64-bit count so that large matrices cannot overflow `int`.
  const u64 total = static_cast<u64>(m) * static_cast<u64>(n);
  for (u64 offset = 0; offset < total; offset += 8) {
    u8 *group = out_begin + offset;
    const u8 before[8] = {group[0], group[1], group[2], group[3],
                          group[4], group[5], group[6], group[7]};
    for (int p = 0; p < 8; p++) {
      group[p] = before[source[p]];
    }
  }

  return out_begin;
}



// Repacks 64-bit input words into two 32-bit output planes.
//
// Reads (m * n) / 16 words from `w_original` in groups of 32 and returns a
// buffer of 2 * ((m * n) / 16) u32 values allocated with new[]; nothing in
// this function frees it. The first half of the buffer receives one word per
// (group, j, k) built from `w2`, `w1` and the low 18 bits of the gathered
// fields; the second half receives the remaining 30 bits.
//
// NOTE(review): presumably every input word carries 16 three-bit weights plus
// per-word metadata (hence the division by 16) — confirm with the producer.
// NOTE(review): each outer iteration touches a full group of 32 words, so
// (m * n) / 16 looks like it must be a multiple of 32 — confirm at call sites.
inline u32 *swizzle_weights_3bit(int m, int n, u64 *w_original) {
  int num_w_64 = (m * n) / 16;
  u32 *w_host_updated = new u32[2 * num_w_64];
  // 54: bit position at which the `w2` field starts.
  constexpr static unsigned long long int NUM_USEFUL_BITS = 18ull * static_cast<u64>(3);
  // Returns 6-bit field number `i` of `_w`, skipping the lowest 6 bits.
  auto idx = [](u64 _w, u64 i) {
    _w >>= (6ull * i + 6ull);
    return _w & 0b111111ull;
  };
  auto w_host_ptr = w_original;
  // `i` is the index of the first word of the current group of 32.
  for (int i = 0; i < num_w_64; i += 32) {
    for (u64 j = 0; j < 8; j++) {
      for (u64 k = 0; k < 4; k++) {
        // Keep bits 0..5 and bits 54..61 of word 4 * j + k unchanged.
        u64 swizzled = (w_host_ptr[4 * j + k] & ((1ull << 6) - 1ull)) | (((w_host_ptr[4 * j + k] >> 54ull) & 0xFFull) << 54ull);
        // Gather eight 6-bit fields: field 4 * l + k of words j, j + 8,
        // j + 16 and j + 24, placed at field position 4 * o + 2 * l + p.
        for (u64 o = 0; o < 2; o++) {
          for (u64 l = 0; l < 2; l++) {
            for (u64 p = 0; p < 2; p++) {
              u64 b = idx(w_host_ptr[16 * o + j + 8ull * p], 4ull * l + k);
              swizzled |= (b << (6ull + (4ull * o + 2ull * l + p) * 6ull));
            }
          }
        }

        // w2: 8 bits starting at bit 54; w1: bits 0..5; w: the 48 gathered
        // bits (bits 6..53).
        u32 w2 = u32(u64(swizzled >> NUM_USEFUL_BITS) & ((1ull << SECOND_ORDER_FRAGMENT_SIZE_BITS) - 1ull));
        u32 w1 = swizzled & 0b111111ull;
        u64 w = (swizzled >> 6ull) & ((1ull << 48ull) - 1ull);

        // Split the 48 bits into an 18-bit and a 30-bit part.
        u64 w0_lower = w & ((1ull << 18ull) - 1ull);
        u64 w0_higher = (w >> 18ull) & ((1ull << 30ull) - 1ull);

        // First plane: w2 in bits 0..7, w1 in bits 8..13, w0_lower in bits
        // 14..31. Second plane (offset num_w_64): w0_higher.
        w_host_updated[i + 4 * j + k] = w2 | (w1 << SECOND_ORDER_FRAGMENT_SIZE_BITS) | (w0_lower << 14ull);
        w_host_updated[i + num_w_64 + 4 * j + k] = w0_higher;
      }
    }
    w_host_ptr += 32;
  }
  return w_host_updated;
}

// Non-owning description of a sparse matrix part: the layout tag, the offsets
// array, the packed entries and the number of entries. The struct has no
// destructor, so whoever allocated the arrays is responsible for freeing them.
struct SparsityV2 {
  SparseCompressionStrategy sparse_compression_strategy;
  const int *row_offsets;
  const ColVal *col_vals;
  // to_sparsity_v2_buggy stores the total number of entries here, including
  // the zero-filled padding slots.
  int nnz;
};

/**
 * Converts a CSR-style description (`row_offsets_ptr`, `col_vals_ptr`) into a
 * padded, row-interleaved layout organised in tiles of `beta1` rows.
 *
 * For every tile, the longest row determines the number of rounds. Round `k`
 * stores, for each of the `beta1` rows in order, that row's k-th entry, or a
 * zeroed entry when the row has fewer than k + 1 entries. A tile therefore
 * occupies `longest_row * beta1` slots, and the returned offsets array has one
 * entry per tile plus a final total.
 *
 * The tile height is taken from `beta1` everywhere (sizing, offsets and
 * filling). A trailing partial tile (when `m` is not a multiple of `beta1`)
 * is padded with zeroed entries instead of reading past row `m`. For
 * `beta1 == 16` and `m` divisible by 16 the produced arrays are the same as
 * before.
 *
 * Both returned arrays are allocated with new[]; nothing in this function or
 * in SparsityV2 frees them.
 *
 * @param beta1           number of rows per tile; values <= 0 yield an empty
 *                        result.
 * @param m               number of rows; `row_offsets_ptr` must hold m + 1
 *                        entries.
 * @param row_offsets_ptr CSR row offsets.
 * @param col_vals_ptr    CSR entries addressed by the row offsets.
 * @return SparsityV2 tagged PTCSR_V2 with per-tile offsets, the padded entries
 *         and the total slot count in `nnz`.
 */
inline SparsityV2 to_sparsity_v2_buggy(int beta1, int m, const int *row_offsets_ptr, const ColVal *col_vals_ptr) {
  const int tile_rows = beta1;
  const int num_tiles = (tile_rows > 0 && m > 0) ? (m + tile_rows - 1) / tile_rows : 0;

  // Pass 1: size every tile by its longest row.
  auto new_row_offsets = new int[num_tiles + 1];
  new_row_offsets[0] = 0;

  int cnt{};
  for (int tile = 0; tile < num_tiles; tile++) {
    const int first_row = tile * tile_rows;
    const int end_row = (first_row + tile_rows < m) ? first_row + tile_rows : m;

    int longest{};
    for (int row = first_row; row < end_row; row++) {
      const int row_len = row_offsets_ptr[row + 1] - row_offsets_ptr[row];
      if (row_len > longest) {
        longest = row_len;
      }
    }

    cnt += longest * tile_rows;
    new_row_offsets[tile + 1] = cnt;
  }

  // Padding slots stay zeroed.
  auto new_col_vals = new ColVal[cnt];
  for (int i = 0; i < cnt; i++) {
    new_col_vals[i]._ = 0;
  }

  // Pass 2: interleave the rows of each tile round by round.
  for (int tile = 0; tile < num_tiles; tile++) {
    const int first_row = tile * tile_rows;
    const int rounds = (new_row_offsets[tile + 1] - new_row_offsets[tile]) / tile_rows;

    ColVal *slot = new_col_vals + new_row_offsets[tile];
    for (int round = 0; round < rounds; round++) {
      for (int j = 0; j < tile_rows; j++, slot++) {
        const int row = first_row + j;
        if (row >= m) {
          continue;  // row beyond the matrix: keep the zeroed slot
        }

        const int src = row_offsets_ptr[row] + round;
        if (src < row_offsets_ptr[row + 1]) {
          *slot = col_vals_ptr[src];
        }
      }
    }
  }

  // Positional aggregate initialisation (member order: strategy, row_offsets,
  // col_vals, nnz) keeps this valid before C++20.
  return SparsityV2{SparseCompressionStrategy::PTCSR_V2, new_row_offsets, new_col_vals, cnt};
}


// Set of boolean feature flags that can also be read or written as one raw
// `u32` word (`_`).
union Features {
  u32 _;

  struct {
    u32 is_async : 1;
    u32 is_bf16 : 1;
    u32 special_scale: 1;
    // Remaining 29 bits, so the bit-fields add up to 32.
    u32 stub: 29;
  } flags;
};

// Instances of `Vec` are used to organize groups of >>registers<<, as needed
// for instance as inputs to tensor core operations. Consequently, all
// corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee
// this.
//
// T: element type; n: number of elements.
template <typename T, int n> struct Vec {
  T elems[n];

  // Mutable access to element `i`; no bounds check.
  DEVICE_INLINE T &operator[](int i) { return elems[i]; }

  // Read access to element `i`; returns a copy, no bounds check.
  DEVICE_INLINE const T operator[](int i) const { return elems[i]; }
};


namespace matryoshka {
// Type bundle for half-precision weights: a display name, the scalar type,
// the packed two-element type, and two Vec fragment types holding 4 and 2
// packed elements respectively.
struct Float16Weight {
  static constexpr const char* name = "fp16";
  using Scalar_t = half;
  using Group_t = half2;
  using FragA = Vec<half2, 4>;
  using FragB = Vec<half2, 2>;
};

// Same bundle as Float16Weight, with the bfloat16 scalar and packed types.
struct Bfloat16Weight {
  static constexpr const char* name = "bf16";
  using Scalar_t = nv_bfloat16;
  using Group_t = nv_bfloat162;
  using FragA = Vec<nv_bfloat162, 4>;
  using FragB = Vec<nv_bfloat162, 2>;
};

// Declaration only; the definition is not in this file.
// NOTE(review): judging by the name and signature this presumably splits the
// m x n byte matrix `w` into up to three u32 output buffers (the third being
// optional) — confirm against the definition.
void _split_up_n_bits(const u8 *w, int m, int n, u32 *buffer0, u32 *buffer1, u32 *buffer2 = nullptr);
}