#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include <vector>
#include <cmath>
#include <limits>
#include <random>
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#include <cuda_bf16.h>
#include <sampled_gs_quant.h>

#define MAX_N_GRID 16
#define THREADS_PER_BLOCK 512

// Reports a failing CUDA runtime call; `ans` is an expression yielding a
// cudaError_t. The do { } while (0) wrapper makes the macro a single
// statement, so `if (c) gpuErrchk(x); else ...` parses correctly.
#define gpuErrchk(ans) do { gpuAssert((ans), __FILE__, __LINE__); } while (0)

// Prints the CUDA error string for `code`, with the call site, to stderr when
// `code` is not cudaSuccess. If `abort` is true the process then exits with
// `code`; with the default abort=false the error is only printed and
// execution continues.
inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=false)
{
   if (code != cudaSuccess)
   {
      fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
      if (abort) exit(code);
   }
}

// Mixes two ints into a pseudo-random int; used to choose which tile a
// logical sample tile maps to. The multiplications are performed on unsigned
// values so they wrap modulo 2^32 (well-defined) rather than overflowing a
// signed int (undefined behaviour). The returned bit pattern is that
// two's-complement wrap, so the result may be negative. Host-callable so the
// sampling pattern can be reproduced on the CPU.
__host__ __device__ __forceinline__ int hash(int x, int y)
{
    const unsigned int hx = static_cast<unsigned int>(x) * 73856093u;
    const unsigned int hy = static_cast<unsigned int>(y) * 19349663u;
    return static_cast<int>(hx ^ hy);
}

// Atomically replaces *address with fmaxf(*address, val) and returns the value
// that was stored there immediately before the successful update.
// The float slot is reinterpreted as an int so atomicCAS can be used; the
// compare-and-swap is retried until no other thread changed the slot between
// our read and our write. fmaxf ignores a NaN operand, so a NaN `val` never
// replaces a number.
__device__ static float atomicMaxFloat(float* address, float val)
{
    int* const slot = reinterpret_cast<int*>(address);
    int observed = *slot;
    for (;;)
    {
        const int desired = __float_as_int(::fmaxf(val, __int_as_float(observed)));
        const int previous = atomicCAS(slot, observed, desired);
        if (previous == observed)
        {
            return __int_as_float(previous);
        }
        observed = previous;
    }
}

// Folds max(|x|) over a random tile sample of the flat bf16 buffer `matrix`
// (length `numel`) into *global_max.
//
// Launch: 1-D grid with exactly THREADS_PER_BLOCK threads per block (the
// BlockReduce template depends on it), covering tile_count * tile_length
// sample slots; each thread reads at most one element. Sample s belongs to
// logical tile s / tile_length, which is mapped to a physical tile of
// `matrix` by hashing (logical tile, seed). *global_max must be initialised
// by the caller; it is updated with an atomic max, so it can only grow.
__global__ void sampled_max_abs_kernel(const __nv_bfloat16 *__restrict__ matrix, const int numel,
                                       const int tile_count, const int tile_length, const int seed, float *global_max)
{
    // Reduce in float: bf16 -> float is exact and max is exact, so the result
    // equals a bf16 reduction.
    using BlockReduce = cub::BlockReduce<float, THREADS_PER_BLOCK>;
    __shared__ typename BlockReduce::TempStorage reduce_storage;

    const int global_sample_idx = blockIdx.x * blockDim.x + threadIdx.x;
    // 64-bit intermediate so numel close to INT_MAX cannot overflow.
    const int total_tiles = static_cast<int>((static_cast<long long>(numel) + tile_length - 1) / tile_length);

    float val = 0.0f;

    // total_tiles == 0 (empty input) would make the modulo below undefined.
    if (total_tiles > 0 && global_sample_idx < tile_count * tile_length)
    {
        // |hash| in unsigned arithmetic: abs(INT_MIN) overflows and would give
        // a negative tile index (and an out-of-bounds read). For every other
        // hash value this equals abs().
        const int h = hash(global_sample_idx / tile_length, seed);
        const unsigned int h_mag = (h < 0) ? 0u - static_cast<unsigned int>(h)
                                           : static_cast<unsigned int>(h);
        const long long tile_idx = h_mag % static_cast<unsigned int>(total_tiles);
        const long long idx = tile_idx * tile_length + (global_sample_idx % tile_length);

        // The last tile may be partial.
        if (idx < numel)
            val = __bfloat162float(__habs(matrix[idx]));
    }

    // Collective: every thread of the block must reach this call.
    const float block_max = BlockReduce(reduce_storage).Reduce(val, cub::Max());

    if (threadIdx.x == 0)
    {
        atomicMaxFloat(global_max, block_max);
    }
}

// Returns a 1-element float32 tensor, on the device of `matrix`, holding
// max(|x|) over a random sample of tile_count tiles of tile_length contiguous
// elements (see sampled_max_abs_kernel). The value is 0 when nothing is
// sampled (empty tensor or tile_count == 0).
//
// The kernel runs asynchronously on PyTorch's current CUDA stream; launch
// errors raise a c10::Error, and execution faults surface at the next
// synchronisation (e.g. .item() on the result).
torch::Tensor sampled_max_abs(torch::Tensor matrix, int tile_count, int tile_length, int seed)
{
    TORCH_CHECK(matrix.is_cuda(), "Input tensor 'matrix' must be on CUDA device.");
    TORCH_CHECK(matrix.dtype() == torch::kBFloat16, "Input tensor 'matrix' must be of type BFloat16.");
    TORCH_CHECK(matrix.is_contiguous(), "Input tensor 'matrix' must be contiguous.");
    TORCH_CHECK(tile_length > 0, "'tile_length' must be positive.");
    TORCH_CHECK(tile_count >= 0, "'tile_count' must be non-negative.");
    // The kernel indexes with 32-bit ints; refuse instead of silently truncating.
    TORCH_CHECK(matrix.numel() <= std::numeric_limits<int>::max(),
                "Input tensor 'matrix' has more than INT_MAX elements.");
    const int64_t total_samples64 = static_cast<int64_t>(tile_count) * tile_length;
    TORCH_CHECK(total_samples64 <= std::numeric_limits<int>::max(),
                "'tile_count * tile_length' must fit in a 32-bit int.");

    const int numel = static_cast<int>(matrix.numel());
    const int total_samples = static_cast<int>(total_samples64);

    // Launch on the tensor's device (and its current stream), not whatever
    // device happens to be current.
    const c10::cuda::CUDAGuard device_guard(matrix.device());

    torch::Tensor global_max = torch::zeros({1}, torch::device(matrix.device()).dtype(torch::kFloat));

    // Nothing to sample: a zero-block launch is an invalid configuration, and
    // the result is simply the initial 0.
    if (numel == 0 || total_samples == 0)
        return global_max;

    const int num_blocks = static_cast<int>((total_samples64 + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK);

    sampled_max_abs_kernel<<<num_blocks, THREADS_PER_BLOCK, 0, at::cuda::getCurrentCUDAStream()>>>(
        reinterpret_cast<const __nv_bfloat16 *>(matrix.data_ptr<at::BFloat16>()), numel,
        tile_count, tile_length, seed,
        global_max.data_ptr<float>());
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return global_max;
}

// For each candidate scale s_i = absmax / 127 * (i + 1) / n_grid
// (i = 0..n_grid-1), adds the sum over this block's sampled elements of
// (q * s_i - x)^2 into global_scale_errs[i], where
// q = hrint(clamp(x / s_i, -127, 127)) computed in bf16.
//
// Launch and sampling contract match sampled_max_abs_kernel: same block size,
// and the same tiles for the same (tile_count, tile_length, seed).
// global_scale_errs must hold n_grid floats initialised by the caller.
// Unused sample slots contribute x = 0, i.e. zero error.
__global__ void scale_grid_search_kernel(const __nv_bfloat16 *__restrict__ matrix, const int numel, const int n_grid, const __nv_bfloat16 absmax,
                                         const int tile_count, const int tile_length, const int seed, float *global_scale_errs)
{
    using BlockReduce = cub::BlockReduce<float, THREADS_PER_BLOCK>;
    __shared__ typename BlockReduce::TempStorage reduce_storage;

    const int global_sample_idx = blockIdx.x * blockDim.x + threadIdx.x;
    // 64-bit intermediate so numel close to INT_MAX cannot overflow.
    const int total_tiles = static_cast<int>((static_cast<long long>(numel) + tile_length - 1) / tile_length);

    const __nv_bfloat16 upper_bound = __float2bfloat16(127.0f);
    const __nv_bfloat16 lower_bound = __float2bfloat16(-127.0f);
    const float absmax_f = __bfloat162float(absmax);

    __nv_bfloat16 sampled = __float2bfloat16(0.0f);
    // total_tiles == 0 (empty input) would make the modulo below undefined.
    if (total_tiles > 0 && global_sample_idx < tile_count * tile_length)
    {
        // |hash| in unsigned arithmetic: abs(INT_MIN) overflows and would give
        // a negative tile index (and an out-of-bounds read). For every other
        // hash value this equals abs().
        const int h = hash(global_sample_idx / tile_length, seed);
        const unsigned int h_mag = (h < 0) ? 0u - static_cast<unsigned int>(h)
                                           : static_cast<unsigned int>(h);
        const long long tile_idx = h_mag % static_cast<unsigned int>(total_tiles);
        const long long idx = tile_idx * tile_length + (global_sample_idx % tile_length);

        // The last tile may be partial.
        if (idx < numel)
            sampled = matrix[idx];
    }
    const __nv_bfloat16 val = sampled;
    const float val_f = __bfloat162float(val);

    for (int scale_idx = 0; scale_idx < n_grid; scale_idx++)
    {
        const __nv_bfloat16 scale = __float2bfloat16(absmax_f / 127.0f * (scale_idx+1) / n_grid);
        __nv_bfloat16 q_val = val / scale;
        q_val = __hmin(upper_bound, __hmax(lower_bound, q_val));
        q_val = hrint(q_val);
        // The reconstruction q * s stays a bf16 product, but the difference and
        // its square are taken in float. Rounding them to bf16 (8-bit mantissa)
        // would make nearby grid points tie or swap order.
        const float diff = __bfloat162float(q_val * scale) - val_f;
        const float sq_err = diff * diff;
        const float block_sum = BlockReduce(reduce_storage).Reduce(sq_err, cub::Sum());
        // reduce_storage is reused by the next iteration's reduction.
        __syncthreads();

        if (threadIdx.x == 0)
        {
            atomicAdd(&global_scale_errs[scale_idx], block_sum);
        }
    }
}

// Evaluates n_grid candidate int8 scales absmax / 127 * (i + 1) / n_grid on a
// random tile sample of `matrix` (see scale_grid_search_kernel).
//
// Returns:
//   - a float32 tensor of length n_grid with the summed squared reconstruction
//     error per candidate;
//   - a 0-dim float32 tensor holding the candidate with the smallest error
//     (first one on ties).
// Both tensors live on the device of `matrix`. Blocks the host while reading
// the argmin back with .item(). If nothing is sampled, all errors stay 0 and
// the first candidate is returned.
std::tuple<torch::Tensor, torch::Tensor> sampled_scale_grid_search(
        torch::Tensor matrix, float absmax, int n_grid, int tile_count, int tile_length, int seed)
{
    TORCH_CHECK(matrix.is_cuda(), "Input tensor 'matrix' must be on CUDA device.");
    TORCH_CHECK(matrix.dtype() == torch::kBFloat16, "Input tensor 'matrix' must be of type BFloat16.");
    TORCH_CHECK(matrix.is_contiguous(), "Input tensor 'matrix' must be contiguous.");
    TORCH_CHECK(n_grid > 0, "'n_grid' must be positive.");
    TORCH_CHECK(tile_length > 0, "'tile_length' must be positive.");
    TORCH_CHECK(tile_count >= 0, "'tile_count' must be non-negative.");
    // The kernel indexes with 32-bit ints; refuse instead of silently truncating.
    TORCH_CHECK(matrix.numel() <= std::numeric_limits<int>::max(),
                "Input tensor 'matrix' has more than INT_MAX elements.");
    const int64_t total_samples64 = static_cast<int64_t>(tile_count) * tile_length;
    TORCH_CHECK(total_samples64 <= std::numeric_limits<int>::max(),
                "'tile_count * tile_length' must fit in a 32-bit int.");

    const int numel = static_cast<int>(matrix.numel());
    const int total_samples = static_cast<int>(total_samples64);

    // Launch on the tensor's device (and its current stream).
    const c10::cuda::CUDAGuard device_guard(matrix.device());

    torch::Tensor global_scale_errs = torch::zeros({n_grid}, torch::device(matrix.device()).dtype(torch::kFloat));

    // A zero-block launch is an invalid configuration; with nothing to sample
    // every candidate simply keeps error 0.
    if (numel > 0 && total_samples > 0)
    {
        const int num_blocks = static_cast<int>((total_samples64 + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK);

        scale_grid_search_kernel<<<num_blocks, THREADS_PER_BLOCK, 0, at::cuda::getCurrentCUDAStream()>>>(
            reinterpret_cast<const __nv_bfloat16 *>(matrix.data_ptr<at::BFloat16>()), numel, n_grid, __float2bfloat16(absmax),
            tile_count, tile_length, seed,
            global_scale_errs.data_ptr<float>());
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    }

    // .item() waits for the kernel on the current stream.
    const int min_scale_idx = global_scale_errs.argmin().item<int>();
    // Same expression order as the kernel so the returned scale matches the
    // evaluated candidate.
    const float optimal_scale = (absmax / 127.0f) * (min_scale_idx + 1) / n_grid;
    at::Tensor optimal_scale_tensor = torch::tensor(optimal_scale, torch::dtype(torch::kFloat32).device(matrix.device()));
    return std::make_tuple(global_scale_errs, optimal_scale_tensor);
}

// Element-wise int8 quantization:
// output[i] = round_nearest_even(clamp(input[i] / scale, -127, 127)).
// Division and clamping are done in bf16. __hmax/__hmin return the non-NaN
// operand, so a NaN quotient (e.g. 0 / 0 when scale == 0) becomes -127.
// Launch: 1-D grid with at least num_elements threads, one element per thread.
__global__ void quantize_tensor_kernel(
    const __nv_bfloat16* __restrict__ input, int8_t* __restrict__ output,
    const __nv_bfloat16 scale, const int num_elements)
{
    // 64-bit index: the tail threads of the last block would overflow a 32-bit
    // int when num_elements is close to INT_MAX.
    const long long idx = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
    const __nv_bfloat16 upper_bound = __float2bfloat16(127.0f);
    const __nv_bfloat16 lower_bound = __float2bfloat16(-127.0f);

    if (idx < num_elements)
    {
        const __nv_bfloat16 clamped = __hmin(upper_bound, __hmax(lower_bound, input[idx] / scale));
        // Already within [-127, 127], so narrowing to int8_t is lossless.
        output[idx] = static_cast<int8_t>(__bfloat162short_rn(clamped));
    }
}

// Quantizes a contiguous bf16 CUDA tensor to int8 with a single scale (see
// quantize_tensor_kernel) and returns an int8 tensor shaped like input_tensor.
// The kernel runs asynchronously on PyTorch's current CUDA stream; launch
// errors raise a c10::Error.
// NOTE(review): `scale` is not validated. scale == 0 yields 0/0 = NaN, which
// the kernel's clamp maps to -127.
torch::Tensor quantize_tensor(torch::Tensor input_tensor, float scale)
{
    TORCH_CHECK(input_tensor.is_cuda(), "Input tensor 'input_tensor' must be on CUDA device.");
    TORCH_CHECK(input_tensor.dtype() == torch::kBFloat16, "Input tensor 'input_tensor' must be of type BFloat16.");
    TORCH_CHECK(input_tensor.is_contiguous(), "Input tensor 'input_tensor' must be contiguous.");
    // The kernel takes a 32-bit element count; refuse instead of truncating.
    TORCH_CHECK(input_tensor.numel() <= std::numeric_limits<int>::max(),
                "Input tensor 'input_tensor' has more than INT_MAX elements.");
    const int num_elements = static_cast<int>(input_tensor.numel());

    // Launch on the tensor's device (and its current stream).
    const c10::cuda::CUDAGuard device_guard(input_tensor.device());

    auto output_tensor = torch::empty_like(input_tensor, torch::dtype(torch::kInt8));

    // A zero-block launch is an invalid configuration; an empty input has
    // nothing to quantize.
    if (num_elements == 0)
        return output_tensor;

    const int num_blocks = static_cast<int>(
        (static_cast<int64_t>(num_elements) + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK);

    const __nv_bfloat16 scale_bf16 = __float2bfloat16(scale);

    quantize_tensor_kernel<<<num_blocks, THREADS_PER_BLOCK, 0, at::cuda::getCurrentCUDAStream()>>>(
        reinterpret_cast<const __nv_bfloat16 *>(input_tensor.data_ptr<at::BFloat16>()),
        output_tensor.data_ptr<int8_t>(),
        scale_bf16, num_elements);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output_tensor;
}

// Chooses an int8 scale for `input_tensor` by a sampled grid search and
// optionally quantizes the whole tensor with it.
//
// About `sampling * numel` elements (at least one tile) are sampled in tiles
// of 128 contiguous elements chosen with `seed`. The sample's absmax bounds
// the n_grid candidate scales, and sampled_scale_grid_search picks the one
// with the least squared reconstruction error.
//
// Returns (int8 tensor, scale) when do_quant is true, otherwise
// (input_tensor, scale). The scale is a 0-dim float32 tensor on the input's
// device.
std::tuple<torch::Tensor, torch::Tensor> grid_search_quant_int8(
    torch::Tensor input_tensor, int n_grid, float sampling, int seed, bool do_quant)
{
    TORCH_CHECK(input_tensor.is_cuda(), "Input tensor 'input_tensor' must be on CUDA device.");
    TORCH_CHECK(input_tensor.dtype() == torch::kBFloat16, "Input tensor 'input_tensor' must be of type BFloat16.");
    TORCH_CHECK(input_tensor.is_contiguous(), "Input tensor 'input_tensor' must be contiguous.");
    TORCH_CHECK(std::isfinite(sampling), "'sampling' must be a finite number.");
    const int tile_length = 128;
    const auto scale_options = torch::dtype(torch::kFloat32).device(input_tensor.device());

    // An empty tensor has no tiles to sample. Any positive scale is valid, so
    // return a unit scale without launching anything.
    if (input_tensor.numel() == 0)
    {
        at::Tensor unit_scale_tensor = torch::tensor(1.0f, scale_options);
        if (do_quant)
            return std::make_tuple(torch::empty_like(input_tensor, torch::dtype(torch::kInt8)), unit_scale_tensor);
        return std::make_tuple(input_tensor, unit_scale_tensor);
    }

    // Computed in double (single-precision float loses integer precision for
    // large tensors), at least one tile, and capped so that
    // tile_count * tile_length fits in an int.
    const double requested_tiles = static_cast<double>(input_tensor.numel()) * sampling / tile_length;
    const double max_tiles = static_cast<double>(std::numeric_limits<int>::max() / tile_length);
    const int tile_count = static_cast<int>(std::min(max_tiles, std::max(1.0, requested_tiles)));

    const float absmax = sampled_max_abs(input_tensor, tile_count, tile_length, seed).item<float>();

    // Only zeros were sampled, so every candidate scale would be 0. Quantizing
    // with scale 0 computes 0/0 = NaN, which the clamp maps to -127. Use a unit
    // scale instead so zeros quantize to 0.
    // NOTE(review): 1.0 is an arbitrary positive choice; confirm with callers.
    if (!(absmax > 0.0f))
    {
        const float unit_scale = 1.0f;
        at::Tensor unit_scale_tensor = torch::tensor(unit_scale, scale_options);
        if (do_quant)
            return std::make_tuple(quantize_tensor(input_tensor, unit_scale), unit_scale_tensor);
        return std::make_tuple(input_tensor, unit_scale_tensor);
    }

    at::Tensor optimal_scale_tensor = std::get<1>(sampled_scale_grid_search(
            input_tensor, absmax, n_grid, tile_count, tile_length, seed));
    if (do_quant) {
        const float optimal_scale = optimal_scale_tensor.item<float>();
        auto output_tensor = quantize_tensor(input_tensor, optimal_scale);
        return std::make_tuple(output_tensor, optimal_scale_tensor);
    }
    return std::make_tuple(input_tensor, optimal_scale_tensor);
}

// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
// {
//     m.def("sampled_max_abs", &sampled_max_abs, "Find the max absolute value using sampling",
//           py::arg("matrix"), py::arg("tile_count"), py::arg("tile_length"), py::arg("seed"));
//     m.def("sampled_scale_grid_search", &sampled_scale_grid_search, "Find the optimal scale using grid search",
//           py::arg("matrix"), py::arg("absmax"), py::arg("n_grid"), py::arg("tile_count"), py::arg("tile_length"), py::arg("seed"));
//     m.def("quantize_tensor", &quantize_tensor, "Quantize tensor to int8 (CUDA)",
//           py::arg("input_tensor"), py::arg("scale"));
//     m.def("grid_search_quant_int8", &grid_search_quant_int8, "Quantize tensor to int8 using grid search",
//             py::arg("input_tensor"), py::arg("n_grid")=10, py::arg("sampling")=0.1, py::arg("seed")=42, py::arg("do_quant")=true);
// }
