#include <torch/extension.h>
#include <iostream>
#include <vector>
#include <pybind11/pybind11.h>
#include "tqdm/tqdm.h"

using std::vector;

/**
 * Builds a dense Hessian one row at a time by differentiating each entry of
 * `grads` with respect to every tensor in `params`.
 *
 * For each row, the per-parameter second derivatives are flattened and written
 * back-to-back (in `params` order) into that row of the result.
 *
 * @param pcount  Number of rows and columns of the result. The flattened
 *                second derivatives of all `params` must add up to exactly
 *                this many entries.
 * @param params  Tensors to differentiate with respect to. Every one of them
 *                must be used in the graph of each `grads[row]`
 *                (allow_unused is false).
 * @param grads   Indexed along dim 0 by row; each `grads[row]` must still be
 *                attached to an autograd graph that reaches `params`.
 *                NOTE(review): presumably the flattened first-order gradient
 *                computed with create_graph=True -- confirm against the caller.
 * @return A `pcount x pcount` tensor allocated with `torch::empty`'s default
 *         options; derivative values are converted to that dtype/device when
 *         they are copied in.
 * @throws c10::Error if the second derivatives do not fill a row exactly, in
 *         addition to anything autograd itself raises.
 */
torch::Tensor _get_hessian(int64_t pcount, vector<torch::Tensor> params, torch::Tensor grads) {
    // Run autograd without holding the GIL. The guard's destructor restores the
    // GIL on every exit path (including exceptions), so no explicit re-acquire
    // is needed before returning.
    py::gil_scoped_release release;

    torch::Tensor hessian = torch::empty({pcount, pcount});
    // No explicit grad_outputs are supplied; the same empty list serves every row.
    const vector<torch::Tensor> grad_outputs;

    for (int64_t row : tqdm::range(pcount)) {
        const vector<torch::Tensor> outputs = {grads[row]};
        // retain_graph=true: the first-order graph is differentiated again for
        // the next row. create_graph=false: the results are only copied out, so
        // recording a higher-order graph would just waste memory.
        const vector<torch::Tensor> grads_2nd = torch::autograd::grad(
            outputs, params, grad_outputs,
            /*retain_graph=*/true, /*create_graph=*/false, /*allow_unused=*/false);

        // View onto the destination row: derivatives are written straight into
        // the result instead of going through a temporary row tensor.
        torch::Tensor hessian_row = hessian.select(0, row);
        int64_t low = 0;
        for (const torch::Tensor& g : grads_2nd) {
            const int64_t high = low + g.numel();
            TORCH_CHECK(high <= pcount,
                        "_get_hessian: second derivatives for row ", row,
                        " need at least ", high, " entries but pcount is ", pcount);
            // copy_ performs any device/dtype conversion into the result tensor.
            hessian_row.slice(0, low, high).copy_(g.detach().reshape(-1));
            low = high;
        }
        // torch::empty leaves memory uninitialised, so a short row would
        // silently expose garbage values to the caller.
        TORCH_CHECK(low == pcount,
                    "_get_hessian: second derivatives for row ", row,
                    " fill only ", low, " of ", pcount, " entries");
    }
    return hessian;
}

// Python bindings for this extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Exposed to Python under the same name as the C++ function.
    m.def(
        "_get_hessian",
        &_get_hessian,
        "Hessian calculation.");
}