#include <torch/extension.h>

// Forward declaration of the first LAMB stage, bound to Python below as
// "multi_tensor_lamb_compute_update_term" ("Computes update term for LAMB
// optimizer"). The definition lives in a separate (presumably CUDA) source file.
//
// NOTE(review): this prototype must match that definition exactly. All tensor
// arguments and `tensor_lists` are taken by value, so change both sides together.
//
// Parameter notes are inferred from names only. TODO confirm against the .cu file:
//   chunk_size        - presumably the number of elements processed per chunk.
//   noop_flag         - presumably a flag tensor that lets the kernel skip work.
//   tensor_lists      - groups of tensors (e.g. grads/params/moments); the exact
//                       list order is defined by the implementation.
//   per_tensor_*      - presumably one value per tensor in `tensor_lists`.
//   step              - optimizer step, stored as a tensor.
//   mode              - integer mode selector; meaning not visible here.
//   global_scale, global_grad_norm, max_grad_norm
//                     - presumably used to rescale/clip gradients globally.
void multi_tensor_lamb_compute_update_term_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor per_tensor_beta1,
  at::Tensor per_tensor_beta2,
  at::Tensor per_tensor_beta3,
  at::Tensor per_tensor_bias_correction,
  at::Tensor step,
  at::Tensor per_tensor_epsilon,
  const int mode,
  at::Tensor per_tensor_decay,
  at::Tensor global_scale,
  at::Tensor global_grad_norm,
  const float max_grad_norm);

// Forward declaration of the second LAMB stage, bound to Python below as
// "multi_tensor_lamb_update_weights" ("Applies update term for LAMB optimizer").
// The definition lives in a separate (presumably CUDA) source file.
//
// NOTE(review): this prototype must match that definition exactly. All tensor
// arguments and `tensor_lists` are taken by value, so change both sides together.
//
// Parameter notes are inferred from names only. TODO confirm against the .cu file:
//   chunk_size             - presumably the number of elements processed per chunk.
//   noop_flag              - presumably a flag tensor that lets the kernel skip work.
//   tensor_lists           - groups of tensors (e.g. params/update terms); the
//                            exact list order is defined by the implementation.
//   per_tensor_param_norm,
//   per_tensor_update_norm - presumably per-tensor norms, used for a LAMB-style
//                            trust ratio.
//   update_norm_offset     - offset into the update-norm tensor; semantics unverified.
//   learning_rate          - learning rate, stored as a tensor.
//   per_tensor_decay       - presumably per-tensor weight-decay coefficients.
//   global_grad_norm       - global gradient norm, stored as a tensor.
//   use_nvlamb             - toggles an "NVLAMB" variant; behavior not visible here.
void multi_tensor_lamb_update_weights_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor per_tensor_param_norm,
  at::Tensor per_tensor_update_norm,
  at::Tensor update_norm_offset,
  at::Tensor learning_rate,
  at::Tensor per_tensor_decay,
  at::Tensor global_grad_norm,
  bool use_nvlamb);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, module) {
  // Every binding releases the GIL for the duration of the native call.
  // NOTE(review): this assumes the native functions never touch Python objects;
  // confirm this stays true if their implementations change.
  using release_gil = py::call_guard<py::gil_scoped_release>;

  // Stage 1: compute the LAMB update term.
  module.def("multi_tensor_lamb_compute_update_term",
             &multi_tensor_lamb_compute_update_term_cuda,
             "Computes update term for LAMB optimizer",
             release_gil());

  // Stage 2: apply the previously computed update term to the weights.
  module.def("multi_tensor_lamb_update_weights",
             &multi_tensor_lamb_update_weights_cuda,
             "Applies update term for LAMB optimizer",
             release_gil());
}
