#include <torch/extension.h>
#include <gemm.h>
#include "fused_quantize_host.h"
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>
#include <vector>
#include <iostream>
#include <utility>


namespace QUTLASS::quartet {

/// @brief Python-facing wrapper around matmul_host_mxf4_bf16_tn.
///
/// Validates the operands, allocates a BFloat16 result on A's device and
/// forwards everything to the host routine.
///
/// @param A      Left operand; its leading dimension gives the output rows (M).
/// @param B      Right operand; its leading dimension gives the output columns (N).
/// @param A_sf   Scale-factor tensor accompanying A (forwarded unchanged).
/// @param B_sf   Scale-factor tensor accompanying B (forwarded unchanged).
/// @param alpha  Scalar forwarded unchanged to the host routine.
/// @return Newly allocated {M, N} BFloat16 tensor filled by the host routine.
///
/// All four tensors must be contiguous CUDA tensors living on the same GPU;
/// the torch check helpers raise an error otherwise.
torch::Tensor matmul_mxf4_bf16_tn(
                        torch::Tensor const& A,
                        torch::Tensor const& B,
                        torch::Tensor const& A_sf,
                        torch::Tensor const& B_sf,
                        float alpha)
{
    const char* const op_name = "matmul_mxf4_bf16_tn";

    // The scale factors are handed to the kernel as well, so they are subject
    // to the same layout/device requirements as the operands themselves.
    torch::checkAllContiguous(op_name, {{A,    "A",    0},
                                        {B,    "B",    1},
                                        {A_sf, "A_sf", 2},
                                        {B_sf, "B_sf", 3}});
    torch::checkDeviceType(op_name, {A, B, A_sf, B_sf}, at::DeviceType::CUDA);
    torch::checkAllSameGPU(op_name, {{A,    "A",    0},
                                     {B,    "B",    1},
                                     {A_sf, "A_sf", 2},
                                     {B_sf, "B_sf", 3}});

    // Make the tensors' GPU the current device for the duration of the call;
    // without this the work would target whatever device happens to be
    // current in a multi-GPU process.
    const c10::cuda::CUDAGuard device_guard(A.device());

    // Keep the sizes as int64_t (the type size() returns) to avoid truncation.
    const int64_t M = A.size(0);
    const int64_t N = B.size(0);
    auto OUT = torch::empty({M, N}, torch::dtype(torch::kBFloat16).device(A.device()));

    matmul_host_mxf4_bf16_tn(OUT, A, B, A_sf, B_sf, alpha);

    return OUT;
}

/// @brief Python-facing wrapper around matmul_host_mxf8_bf16_tn.
///
/// Validates the operands, allocates a BFloat16 result on A's device and
/// forwards everything to the host routine.
///
/// @param A      Left operand; its leading dimension gives the output rows (M).
/// @param B      Right operand; its leading dimension gives the output columns (N).
/// @param A_sf   Scale-factor tensor accompanying A (forwarded unchanged).
/// @param B_sf   Scale-factor tensor accompanying B (forwarded unchanged).
/// @param alpha  Scalar forwarded unchanged to the host routine.
/// @return Newly allocated {M, N} BFloat16 tensor filled by the host routine.
///
/// All four tensors must be contiguous CUDA tensors living on the same GPU;
/// the torch check helpers raise an error otherwise.
torch::Tensor matmul_mxf8_bf16_tn(
                        torch::Tensor const& A,
                        torch::Tensor const& B,
                        torch::Tensor const& A_sf,
                        torch::Tensor const& B_sf,
                        float alpha)
{
    const char* const op_name = "matmul_mxf8_bf16_tn";

    // The scale factors are handed to the kernel as well, so they are subject
    // to the same layout/device requirements as the operands themselves.
    torch::checkAllContiguous(op_name, {{A,    "A",    0},
                                        {B,    "B",    1},
                                        {A_sf, "A_sf", 2},
                                        {B_sf, "B_sf", 3}});
    torch::checkDeviceType(op_name, {A, B, A_sf, B_sf}, at::DeviceType::CUDA);
    torch::checkAllSameGPU(op_name, {{A,    "A",    0},
                                     {B,    "B",    1},
                                     {A_sf, "A_sf", 2},
                                     {B_sf, "B_sf", 3}});

    // Make the tensors' GPU the current device for the duration of the call;
    // without this the work would target whatever device happens to be
    // current in a multi-GPU process.
    const c10::cuda::CUDAGuard device_guard(A.device());

    // Keep the sizes as int64_t (the type size() returns) to avoid truncation.
    const int64_t M = A.size(0);
    const int64_t N = B.size(0);
    auto OUT = torch::empty({M, N}, torch::dtype(torch::kBFloat16).device(A.device()));

    matmul_host_mxf8_bf16_tn(OUT, A, B, A_sf, B_sf, alpha);

    return OUT;
}

/// @brief Python-facing wrapper around fusedQuantize_host.
///
/// Validates the inputs and the caller-provided output buffers, then lets the
/// host routine fill the outputs in place.
///
/// @param A         First input tensor, forwarded to the host routine.
/// @param B         Second input tensor, forwarded to the host routine.
/// @param OUT       Caller-allocated output buffer, written in place.
/// @param OUT_sf    Caller-allocated scale-factor buffer, written in place.
/// @param OUT_mask  Caller-allocated mask buffer, written in place.
/// @return Tuple (OUT, OUT_sf, OUT_mask) referring to the same tensors that
///         were passed in.
///
/// Every tensor must be a contiguous CUDA tensor on the same GPU; the torch
/// check helpers raise an error otherwise.
///
/// NOTE(review): an earlier revision allocated the outputs here, with
/// M = A.size(0)*A.size(1)/32 and N = B.size(1), as OUT {M, N/2} kByte,
/// OUT_sf {M*N/32, 1} kFloat8_e8m0fnu and OUT_mask {M, N/8} kByte. Callers now
/// own the allocation — confirm the expected shapes against the host kernel.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> fusedQuantize(
                        torch::Tensor const& A,
                        torch::Tensor const& B,
                        torch::Tensor& OUT,
                        torch::Tensor& OUT_sf,
                        torch::Tensor& OUT_mask)
{
    const char* const op_name = "fusedQuantize";

    // The outputs are written by the kernel, so they need the same
    // layout/device guarantees as the inputs.
    torch::checkAllContiguous(op_name, {{A,        "A",        0},
                                        {B,        "B",        1},
                                        {OUT,      "OUT",      2},
                                        {OUT_sf,   "OUT_sf",   3},
                                        {OUT_mask, "OUT_mask", 4}});
    torch::checkDeviceType(op_name, {A, B, OUT, OUT_sf, OUT_mask}, at::DeviceType::CUDA);
    torch::checkAllSameGPU(op_name, {{A,        "A",        0},
                                     {B,        "B",        1},
                                     {OUT,      "OUT",      2},
                                     {OUT_sf,   "OUT_sf",   3},
                                     {OUT_mask, "OUT_mask", 4}});

    // Make the tensors' GPU the current device for the duration of the call.
    const c10::cuda::CUDAGuard device_guard(A.device());

    fusedQuantize_host(OUT, OUT_sf, OUT_mask, A, B);

    return std::make_tuple(OUT, OUT_sf, OUT_mask);
}

/// @brief Python-facing wrapper around fusedQuantize_bwd_host.
///
/// Validates the inputs and the caller-provided output buffers, then lets the
/// host routine fill the outputs in place.
///
/// @param A       First input tensor, forwarded to the host routine.
/// @param B       Second input tensor, forwarded to the host routine.
/// @param OUT     Caller-allocated output buffer, written in place.
/// @param OUT_sf  Caller-allocated scale-factor buffer, written in place.
/// @return Tuple (OUT, OUT_sf) referring to the same tensors that were passed in.
///
/// Every tensor must be a contiguous CUDA tensor on the same GPU; the torch
/// check helpers raise an error otherwise.
///
/// NOTE(review): an earlier revision allocated the outputs here, with
/// M = A.size(0)*A.size(1)/32 and N = B.size(1), as OUT {M, N/2} kByte and
/// OUT_sf {M*N/32, 1} kFloat8_e8m0fnu. Callers now own the allocation —
/// confirm the expected shapes against the host kernel.
std::tuple<torch::Tensor, torch::Tensor> fusedQuantize_bwd(
                        torch::Tensor const& A,
                        torch::Tensor const& B,
                        torch::Tensor& OUT,
                        torch::Tensor& OUT_sf)
{
    // Use this function's own name so validation errors point at the right
    // entry point (previously reported as "fusedQuantize").
    const char* const op_name = "fusedQuantize_bwd";

    // The outputs are written by the kernel, so they need the same
    // layout/device guarantees as the inputs.
    torch::checkAllContiguous(op_name, {{A,      "A",      0},
                                        {B,      "B",      1},
                                        {OUT,    "OUT",    2},
                                        {OUT_sf, "OUT_sf", 3}});
    torch::checkDeviceType(op_name, {A, B, OUT, OUT_sf}, at::DeviceType::CUDA);
    torch::checkAllSameGPU(op_name, {{A,      "A",      0},
                                     {B,      "B",      1},
                                     {OUT,    "OUT",    2},
                                     {OUT_sf, "OUT_sf", 3}});

    // Make the tensors' GPU the current device for the duration of the call.
    const c10::cuda::CUDAGuard device_guard(A.device());

    fusedQuantize_bwd_host(OUT, OUT_sf, A, B);

    return std::make_tuple(OUT, OUT_sf);
}

//====== pybind ======

// Registers a free function on module `m` under its own C++ identifier and
// reuses that identifier as the Python docstring. The expansion already ends
// with ';', so call sites do not add one.
#define DEFINE_pybind(name) m.def(#name, &name, #name);

// Python module definition: exposes the four wrappers defined above. Names
// are resolved inside the enclosing namespace, so no qualification is needed.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    DEFINE_pybind(matmul_mxf4_bf16_tn)
    DEFINE_pybind(matmul_mxf8_bf16_tn)
    DEFINE_pybind(fusedQuantize)
    DEFINE_pybind(fusedQuantize_bwd)
}
}