#include <memory>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/extension.h>

#include "attention.h"

using namespace GPT;
namespace py = pybind11;

/// Python bindings for the GPT chunk-attention extension, exposed as the
/// extension module `chunk_attn_c` with two classes: `Trace` and `Attention`.
PYBIND11_MODULE(chunk_attn_c, m) {
    // Trace: a record with one bool flag passed at construction and two further
    // read/write fields. The `_t` suffixes suggest kernel timing data.
    // NOTE(review): confirm field meaning/types against attention.h.
    py::class_<Trace>(m, "Trace")
      .def(py::init<bool>(), py::arg("record_kernel_t") = false)
      .def_readwrite("record_kernel_t", &Trace::record_kernel_t)
      .def_readwrite("chunk_kernel_t", &Trace::chunk_kernel_t)
      .def_readwrite("seq_kernel_t", &Trace::seq_kernel_t);

    // Attention instances are managed on the Python side through a
    // std::shared_ptr holder.
    py::class_<Attention, std::shared_ptr<Attention>>(m, "Attention")
      // Factory constructor. `dtype` and `device` may be None (use the defaults
      // below) or any object accepted by torch's Python dtype/device parsers.
      .def(py::init([](int n_head,
                       int d_embed,
                       int chunk_size,
                       int memory_mb,
                       bool share_prefix,  // bool to match its `true` default; ints still convert
                       py::object dtype,
                       py::object device) {
               // None -> torch's global default dtype.
               torch::Dtype dtype_c = torch::get_default_dtype_as_scalartype();
               if (!dtype.is_none()) {
                   dtype_c = torch::python::detail::py_object_to_dtype(dtype);
               }
               // None -> the device of a default-constructed TensorOptions, i.e.
               // the same default the C++ tensor factories fall back to. Querying
               // it this way avoids materializing a tensor: torch::randn(1) would
               // allocate and also advance the global CPU RNG as a side effect.
               torch::Device device_c = torch::TensorOptions().device();
               if (!device.is_none()) {
                   device_c = torch::python::detail::py_object_to_device(device);
               }
               // Return the holder type directly rather than a raw owning pointer.
               return std::make_shared<Attention>(
                 n_head, d_embed, chunk_size, memory_mb, share_prefix, dtype_c, device_c);
           }),
           py::arg("n_head") = 12,
           py::arg("d_embed") = 64,
           py::arg("chunk_size") = 64,
           py::arg("memory_mb") = 1024,
           py::arg("share_prefix") = true,
           py::arg("dtype") = py::none(),
           py::arg("device") = py::none())
      // forward(q, partition=0, trace=None); a nullptr default maps to None in Python.
      .def("forward",
           &GPT::Attention::forward,
           py::arg("q"),
           py::arg("partition") = 0,
           py::arg("trace") = nullptr)
      // add_prompt(tokens, k, v)
      .def("add_prompt", &GPT::Attention::add_prompt, py::arg("tokens"), py::arg("k"), py::arg("v"))
      // append_completions(tokens, k, v)
      .def("append_completions",
           &GPT::Attention::append_completions,
           py::arg("tokens"),
           py::arg("k"),
           py::arg("v"))
      // duplicate(seq_idx, copies)
      .def("duplicate", &GPT::Attention::duplicate, py::arg("seq_idx"), py::arg("copies"))
      // remove(seq_idx)
      .def("remove", &GPT::Attention::remove, py::arg("seq_idx"))
      // print(root=None, level=0); both arguments are optional from Python.
      .def("print", &GPT::Attention::print, py::arg("root") = nullptr, py::arg("level") = 0)
      // get_chunks_raw(): bound without named arguments.
      .def("get_chunks_raw", &GPT::Attention::get_chunks_raw);
}
