#pragma once
#include <common.h>

namespace QUTLASS {

/// Declares `fusedQuantizeMxQuest_host`; the definition is not in this header.
///
/// NOTE(review): the name suggests a fused quantization into an "Mx"
/// (presumably microscaling) format using a "Quest" scaling rule — confirm
/// against the implementation.
///
/// @param D     Taken by mutable reference; presumably the quantized output — TODO confirm.
/// @param D_sf  Taken by mutable reference; presumably the scale factors for D — TODO confirm.
/// @param A     Read-only; presumably the tensor to quantize — TODO confirm.
/// @param B     Read-only; presumably the transform matrix applied to A — TODO confirm.
void fusedQuantizeMxQuest_host(torch::Tensor&       D,
                               torch::Tensor&       D_sf,
                               torch::Tensor const& A,
                               torch::Tensor const& B);

/// Declares `fusedQuantizeMxQuestWithMask_host`; the definition is not in
/// this header.
///
/// NOTE(review): the name suggests an "Mx" (presumably microscaling) "Quest"
/// quantization that additionally produces a mask — confirm the mask's
/// meaning and layout against the implementation.
///
/// @param D       Taken by mutable reference; presumably the quantized output — TODO confirm.
/// @param D_sf    Taken by mutable reference; presumably the scale factors for D — TODO confirm.
/// @param D_mask  Taken by mutable reference; presumably a mask written by the kernel — TODO confirm.
/// @param A       Read-only; presumably the tensor to quantize — TODO confirm.
/// @param B       Read-only; presumably the transform matrix applied to A — TODO confirm.
void fusedQuantizeMxQuestWithMask_host(torch::Tensor&       D,
                                       torch::Tensor&       D_sf,
                                       torch::Tensor&       D_mask,
                                       torch::Tensor const& A,
                                       torch::Tensor const& B);

/// Declares `fusedQuantizeMxAbsMax_host`; the definition is not in this header.
///
/// NOTE(review): the name suggests a fused quantization into an "Mx"
/// (presumably microscaling) format with abs-max based scaling — confirm
/// against the implementation.
///
/// @param D     Taken by mutable reference; presumably the quantized output — TODO confirm.
/// @param D_sf  Taken by mutable reference; presumably the scale factors for D — TODO confirm.
/// @param A     Read-only; presumably the tensor to quantize — TODO confirm.
/// @param B     Read-only; presumably the transform matrix applied to A — TODO confirm.
void fusedQuantizeMxAbsMax_host(torch::Tensor&       D,
                                torch::Tensor&       D_sf,
                                torch::Tensor const& A,
                                torch::Tensor const& B);

/// Declares `fusedQuantizeMxQuestHad64_host`; the definition is not in this
/// header.
///
/// NOTE(review): the name suggests the "Mx"/"Quest" quantization specialised
/// for a transform of size 64 ("Had64", presumably Hadamard) — confirm
/// against the implementation.
///
/// @param D     Taken by mutable reference; presumably the quantized output — TODO confirm.
/// @param D_sf  Taken by mutable reference; presumably the scale factors for D — TODO confirm.
/// @param A     Read-only; presumably the tensor to quantize — TODO confirm.
/// @param B     Read-only; presumably the transform matrix applied to A — TODO confirm.
void fusedQuantizeMxQuestHad64_host(torch::Tensor&       D,
                                    torch::Tensor&       D_sf,
                                    torch::Tensor const& A,
                                    torch::Tensor const& B);

/// Declares `fusedQuantizeMxAbsMaxHad64_host`; the definition is not in this
/// header.
///
/// NOTE(review): the name suggests the "Mx" abs-max quantization specialised
/// for a transform of size 64 ("Had64", presumably Hadamard) — confirm
/// against the implementation.
///
/// @param D     Taken by mutable reference; presumably the quantized output — TODO confirm.
/// @param D_sf  Taken by mutable reference; presumably the scale factors for D — TODO confirm.
/// @param A     Read-only; presumably the tensor to quantize — TODO confirm.
/// @param B     Read-only; presumably the transform matrix applied to A — TODO confirm.
void fusedQuantizeMxAbsMaxHad64_host(torch::Tensor&       D,
                                     torch::Tensor&       D_sf,
                                     torch::Tensor const& A,
                                     torch::Tensor const& B);

/// Declares `fusedQuantizeMxQuestHad128_host`; the definition is not in this
/// header.
///
/// NOTE(review): the name suggests the "Mx"/"Quest" quantization specialised
/// for a transform of size 128 ("Had128", presumably Hadamard) — confirm
/// against the implementation.
///
/// @param D     Taken by mutable reference; presumably the quantized output — TODO confirm.
/// @param D_sf  Taken by mutable reference; presumably the scale factors for D — TODO confirm.
/// @param A     Read-only; presumably the tensor to quantize — TODO confirm.
/// @param B     Read-only; presumably the transform matrix applied to A — TODO confirm.
void fusedQuantizeMxQuestHad128_host(torch::Tensor&       D,
                                     torch::Tensor&       D_sf,
                                     torch::Tensor const& A,
                                     torch::Tensor const& B);

/// Declares `fusedQuantizeMxAbsMaxHad128_host`; the definition is not in this
/// header.
///
/// NOTE(review): the name suggests the "Mx" abs-max quantization specialised
/// for a transform of size 128 ("Had128", presumably Hadamard) — confirm
/// against the implementation.
///
/// @param D     Taken by mutable reference; presumably the quantized output — TODO confirm.
/// @param D_sf  Taken by mutable reference; presumably the scale factors for D — TODO confirm.
/// @param A     Read-only; presumably the tensor to quantize — TODO confirm.
/// @param B     Read-only; presumably the transform matrix applied to A — TODO confirm.
void fusedQuantizeMxAbsMaxHad128_host(torch::Tensor&       D,
                                      torch::Tensor&       D_sf,
                                      torch::Tensor const& A,
                                      torch::Tensor const& B);

/// Declares `fusedQuantizeNvQuest_host`; the definition is not in this header.
///
/// NOTE(review): the name suggests a fused quantization into an "Nv"
/// (presumably an NVIDIA FP4-style) format using a "Quest" scaling rule —
/// confirm against the implementation.
///
/// @param D             Taken by mutable reference; presumably the quantized output — TODO confirm.
/// @param D_sf          Taken by mutable reference; presumably the scale factors for D — TODO confirm.
/// @param A             Read-only; presumably the tensor to quantize — TODO confirm.
/// @param B             Read-only; presumably the transform matrix applied to A — TODO confirm.
/// @param global_scale  Read-only; presumably a tensor-wide scale used alongside D_sf — TODO confirm.
void fusedQuantizeNvQuest_host(torch::Tensor&       D,
                               torch::Tensor&       D_sf,
                               torch::Tensor const& A,
                               torch::Tensor const& B,
                               torch::Tensor const& global_scale);

/// Declares `fusedQuantizeNvQuestHad32_host`; the definition is not in this
/// header.
///
/// NOTE(review): the name suggests the "Nv"/"Quest" quantization specialised
/// for a transform of size 32 ("Had32", presumably Hadamard) — confirm
/// against the implementation.
///
/// @param D             Taken by mutable reference; presumably the quantized output — TODO confirm.
/// @param D_sf          Taken by mutable reference; presumably the scale factors for D — TODO confirm.
/// @param A             Read-only; presumably the tensor to quantize — TODO confirm.
/// @param B             Read-only; presumably the transform matrix applied to A — TODO confirm.
/// @param global_scale  Read-only; presumably a tensor-wide scale used alongside D_sf — TODO confirm.
void fusedQuantizeNvQuestHad32_host(torch::Tensor&       D,
                                    torch::Tensor&       D_sf,
                                    torch::Tensor const& A,
                                    torch::Tensor const& B,
                                    torch::Tensor const& global_scale);

/// Declares `fusedQuantizeNvQuestHad64_host`; the definition is not in this
/// header.
///
/// NOTE(review): the name suggests the "Nv"/"Quest" quantization specialised
/// for a transform of size 64 ("Had64", presumably Hadamard) — confirm
/// against the implementation.
///
/// @param D             Taken by mutable reference; presumably the quantized output — TODO confirm.
/// @param D_sf          Taken by mutable reference; presumably the scale factors for D — TODO confirm.
/// @param A             Read-only; presumably the tensor to quantize — TODO confirm.
/// @param B             Read-only; presumably the transform matrix applied to A — TODO confirm.
/// @param global_scale  Read-only; presumably a tensor-wide scale used alongside D_sf — TODO confirm.
void fusedQuantizeNvQuestHad64_host(torch::Tensor&       D,
                                    torch::Tensor&       D_sf,
                                    torch::Tensor const& A,
                                    torch::Tensor const& B,
                                    torch::Tensor const& global_scale);

/// Declares `fusedQuantizeNvQuestHad128_host`; the definition is not in this
/// header.
///
/// NOTE(review): the name suggests the "Nv"/"Quest" quantization specialised
/// for a transform of size 128 ("Had128", presumably Hadamard) — confirm
/// against the implementation.
///
/// @param D             Taken by mutable reference; presumably the quantized output — TODO confirm.
/// @param D_sf          Taken by mutable reference; presumably the scale factors for D — TODO confirm.
/// @param A             Read-only; presumably the tensor to quantize — TODO confirm.
/// @param B             Read-only; presumably the transform matrix applied to A — TODO confirm.
/// @param global_scale  Read-only; presumably a tensor-wide scale used alongside D_sf — TODO confirm.
void fusedQuantizeNvQuestHad128_host(torch::Tensor&       D,
                                     torch::Tensor&       D_sf,
                                     torch::Tensor const& A,
                                     torch::Tensor const& B,
                                     torch::Tensor const& global_scale);

/// Declares `fusedQuantizeNvAbsMax_host`; the definition is not in this header.
///
/// NOTE(review): the name suggests a fused quantization into an "Nv"
/// (presumably an NVIDIA FP4-style) format with abs-max based scaling —
/// confirm against the implementation.
///
/// @param D             Taken by mutable reference; presumably the quantized output — TODO confirm.
/// @param D_sf          Taken by mutable reference; presumably the scale factors for D — TODO confirm.
/// @param A             Read-only; presumably the tensor to quantize — TODO confirm.
/// @param B             Read-only; presumably the transform matrix applied to A — TODO confirm.
/// @param global_scale  Read-only; presumably a tensor-wide scale used alongside D_sf — TODO confirm.
void fusedQuantizeNvAbsMax_host(torch::Tensor&       D,
                                torch::Tensor&       D_sf,
                                torch::Tensor const& A,
                                torch::Tensor const& B,
                                torch::Tensor const& global_scale);

/// Declares `fusedQuantizeNvAbsMaxHad32_host`; the definition is not in this
/// header.
///
/// NOTE(review): the name suggests the "Nv" abs-max quantization specialised
/// for a transform of size 32 ("Had32", presumably Hadamard) — confirm
/// against the implementation.
///
/// @param D             Taken by mutable reference; presumably the quantized output — TODO confirm.
/// @param D_sf          Taken by mutable reference; presumably the scale factors for D — TODO confirm.
/// @param A             Read-only; presumably the tensor to quantize — TODO confirm.
/// @param B             Read-only; presumably the transform matrix applied to A — TODO confirm.
/// @param global_scale  Read-only; presumably a tensor-wide scale used alongside D_sf — TODO confirm.
void fusedQuantizeNvAbsMaxHad32_host(torch::Tensor&       D,
                                     torch::Tensor&       D_sf,
                                     torch::Tensor const& A,
                                     torch::Tensor const& B,
                                     torch::Tensor const& global_scale);

/// Declares `fusedQuantizeNvAbsMaxHad64_host`; the definition is not in this
/// header.
///
/// NOTE(review): the name suggests the "Nv" abs-max quantization specialised
/// for a transform of size 64 ("Had64", presumably Hadamard) — confirm
/// against the implementation.
///
/// @param D             Taken by mutable reference; presumably the quantized output — TODO confirm.
/// @param D_sf          Taken by mutable reference; presumably the scale factors for D — TODO confirm.
/// @param A             Read-only; presumably the tensor to quantize — TODO confirm.
/// @param B             Read-only; presumably the transform matrix applied to A — TODO confirm.
/// @param global_scale  Read-only; presumably a tensor-wide scale used alongside D_sf — TODO confirm.
void fusedQuantizeNvAbsMaxHad64_host(torch::Tensor&       D,
                                     torch::Tensor&       D_sf,
                                     torch::Tensor const& A,
                                     torch::Tensor const& B,
                                     torch::Tensor const& global_scale);

/// Declares `fusedQuantizeNvAbsMaxHad128_host`; the definition is not in this
/// header.
///
/// NOTE(review): the name suggests the "Nv" abs-max quantization specialised
/// for a transform of size 128 ("Had128", presumably Hadamard) — confirm
/// against the implementation.
///
/// @param D             Taken by mutable reference; presumably the quantized output — TODO confirm.
/// @param D_sf          Taken by mutable reference; presumably the scale factors for D — TODO confirm.
/// @param A             Read-only; presumably the tensor to quantize — TODO confirm.
/// @param B             Read-only; presumably the transform matrix applied to A — TODO confirm.
/// @param global_scale  Read-only; presumably a tensor-wide scale used alongside D_sf — TODO confirm.
void fusedQuantizeNvAbsMaxHad128_host(torch::Tensor&       D,
                                      torch::Tensor&       D_sf,
                                      torch::Tensor const& A,
                                      torch::Tensor const& B,
                                      torch::Tensor const& global_scale);

/// Declares `fusedQuantizeMxAbsMax_host_sm100`; the definition is not in this
/// header.
///
/// NOTE(review): the `_sm100` suffix suggests a variant of the "Mx" abs-max
/// quantization built for the SM100 GPU architecture — confirm against the
/// implementation.
/// NOTE(review): unlike `fusedQuantizeMxAbsMax_host` declared in this header,
/// this signature also takes `global_scale`; verify that this is intentional
/// and how the Mx path uses it.
///
/// @param D             Taken by mutable reference; presumably the quantized output — TODO confirm.
/// @param D_sf          Taken by mutable reference; presumably the scale factors for D — TODO confirm.
/// @param A             Read-only; presumably the tensor to quantize — TODO confirm.
/// @param B             Read-only; presumably the transform matrix applied to A — TODO confirm.
/// @param global_scale  Read-only; role in the Mx variant is unclear from this header — TODO confirm.
void fusedQuantizeMxAbsMax_host_sm100(torch::Tensor&       D,
                                      torch::Tensor&       D_sf,
                                      torch::Tensor const& A,
                                      torch::Tensor const& B,
                                      torch::Tensor const& global_scale);

/// Declares `fusedQuantizeNvAbsMax_host_sm100`; the definition is not in this
/// header.
///
/// NOTE(review): the `_sm100` suffix suggests a variant of the "Nv" abs-max
/// quantization built for the SM100 GPU architecture — confirm against the
/// implementation.
///
/// @param D             Taken by mutable reference; presumably the quantized output — TODO confirm.
/// @param D_sf          Taken by mutable reference; presumably the scale factors for D — TODO confirm.
/// @param A             Read-only; presumably the tensor to quantize — TODO confirm.
/// @param B             Read-only; presumably the transform matrix applied to A — TODO confirm.
/// @param global_scale  Read-only; presumably a tensor-wide scale used alongside D_sf — TODO confirm.
void fusedQuantizeNvAbsMax_host_sm100(torch::Tensor&       D,
                                      torch::Tensor&       D_sf,
                                      torch::Tensor const& A,
                                      torch::Tensor const& B,
                                      torch::Tensor const& global_scale);

/// Declares `fusedQuantizeWushMxAbsMax_host`; the definition is not in this
/// header.
///
/// NOTE(review): the name suggests an "Mx" (presumably microscaling) abs-max
/// quantization using a "Wush" transform variant; what "Wush" denotes is not
/// visible from this header — confirm against the implementation.
///
/// @param D     Taken by mutable reference; presumably the quantized output — TODO confirm.
/// @param D_sf  Taken by mutable reference; presumably the scale factors for D — TODO confirm.
/// @param A     Read-only; presumably the tensor to quantize — TODO confirm.
/// @param B     Read-only; presumably the transform matrix applied to A — TODO confirm.
void fusedQuantizeWushMxAbsMax_host(torch::Tensor&       D,
                                    torch::Tensor&       D_sf,
                                    torch::Tensor const& A,
                                    torch::Tensor const& B);

}  // namespace QUTLASS
