#pragma once
#include <common.h>

/// Host-side entry point for a scaled matrix multiply (definition not in this
/// header).
///
/// NOTE(review): from the name only, this presumably multiplies MXFP4
/// operands with per-block scale factors and produces a BF16 result, with
/// "tn" naming the operand layout convention. Confirm against the definition.
///
/// @param D     Result tensor. It is taken by non-const reference, unlike every
///              other argument, so the callee presumably writes into it. Whether
///              it must be preallocated by the caller is TODO confirm.
/// @param A     First operand (read-only).
/// @param B     Second operand (read-only).
/// @param A_sf  Scale factors associated with A (read-only). Presumably one per
///              block of A; the layout is TODO confirm.
/// @param B_sf  Scale factors associated with B (read-only). Presumably one per
///              block of B; the layout is TODO confirm.
/// @param alpha Passed as a tensor rather than a host scalar. Presumably a
///              multiplier applied to the product; confirm.
void matmul_host_mxf4_bf16_tn(torch::Tensor& D,
                              torch::Tensor const& A,
                              torch::Tensor const& B,
                              torch::Tensor const& A_sf,
                              torch::Tensor const& B_sf,
                              torch::Tensor const& alpha);

/// Host-side entry point for a variant of the "mxf4_bf16_tn" scaled matrix
/// multiply (definition not in this header).
///
/// NOTE(review): the "ada" infix presumably selects a different backend or
/// target architecture. Confirm against the definition.
///
/// WARNING: the argument order differs from the other matmul_host_* functions
/// declared alongside it. Here the output tensor (`out`) is the 5th argument
/// rather than the 1st, and the operands are named `input`/`weight` instead of
/// `A`/`B`. Callers must not assume the sibling signature.
///
/// @param input     First operand (read-only).
/// @param weight    Second operand (read-only).
/// @param input_sf  Scale factors associated with `input` (read-only). The
///                  layout is TODO confirm.
/// @param weight_sf Scale factors associated with `weight` (read-only). The
///                  layout is TODO confirm.
/// @param out       Result tensor. It is the only non-const reference argument,
///                  so the callee presumably writes into it. Whether it must be
///                  preallocated is TODO confirm.
/// @param alpha     Passed as a tensor. Presumably a multiplier applied to the
///                  product; confirm.
void matmul_host_ada_mxf4_bf16_tn(torch::Tensor const& input,
                                  torch::Tensor const& weight,
                                  torch::Tensor const& input_sf,
                                  torch::Tensor const& weight_sf,
                                  torch::Tensor &out,
                                  torch::Tensor const& alpha);

/// Host-side entry point for a scaled matrix multiply (definition not in this
/// header).
///
/// NOTE(review): from the name only, this presumably multiplies NVFP4
/// operands with per-block scale factors and produces a BF16 result, with
/// "tn" naming the operand layout convention. Confirm against the definition.
///
/// @param D     Result tensor. It is taken by non-const reference, so the callee
///              presumably writes into it. Whether it must be preallocated is
///              TODO confirm.
/// @param A     First operand (read-only).
/// @param B     Second operand (read-only).
/// @param A_sf  Scale factors associated with A (read-only). The layout is
///              TODO confirm.
/// @param B_sf  Scale factors associated with B (read-only). The layout is
///              TODO confirm.
/// @param alpha Passed as a tensor. Presumably a multiplier applied to the
///              product; confirm.
void matmul_host_nvf4_bf16_tn(torch::Tensor& D,
                              torch::Tensor const& A,
                              torch::Tensor const& B,
                              torch::Tensor const& A_sf,
                              torch::Tensor const& B_sf,
                              torch::Tensor const& alpha);

/// Host-side entry point for a scaled matrix multiply (definition not in this
/// header).
///
/// NOTE(review): from the name only, this presumably multiplies MXFP8
/// operands with per-block scale factors and produces a BF16 result, with
/// "tn" naming the operand layout convention. Compare with the "_nn" variant
/// that follows. Confirm against the definition.
///
/// @param D     Result tensor. It is taken by non-const reference, so the callee
///              presumably writes into it. Whether it must be preallocated is
///              TODO confirm.
/// @param A     First operand (read-only).
/// @param B     Second operand (read-only).
/// @param A_sf  Scale factors associated with A (read-only). The layout is
///              TODO confirm.
/// @param B_sf  Scale factors associated with B (read-only). The layout is
///              TODO confirm.
/// @param alpha Passed as a tensor. Presumably a multiplier applied to the
///              product; confirm.
void matmul_host_mxf8_bf16_tn(torch::Tensor& D,
                              torch::Tensor const& A,
                              torch::Tensor const& B,
                              torch::Tensor const& A_sf,
                              torch::Tensor const& B_sf,
                              torch::Tensor const& alpha);

/// Host-side entry point for a scaled matrix multiply (definition not in this
/// header).
///
/// NOTE(review): the signature is identical to matmul_host_mxf8_bf16_tn. Only
/// the "nn" suffix differs, which presumably selects a different operand layout
/// convention for A/B. Confirm against the definition which layouts each
/// variant expects.
///
/// @param D     Result tensor. It is taken by non-const reference, so the callee
///              presumably writes into it. Whether it must be preallocated is
///              TODO confirm.
/// @param A     First operand (read-only).
/// @param B     Second operand (read-only).
/// @param A_sf  Scale factors associated with A (read-only). The layout is
///              TODO confirm.
/// @param B_sf  Scale factors associated with B (read-only). The layout is
///              TODO confirm.
/// @param alpha Passed as a tensor. Presumably a multiplier applied to the
///              product; confirm.
void matmul_host_mxf8_bf16_nn(torch::Tensor& D,
                              torch::Tensor const& A,
                              torch::Tensor const& B,
                              torch::Tensor const& A_sf,
                              torch::Tensor const& B_sf,
                              torch::Tensor const& alpha);