#include <pybind11/functional.h>
#include <pybind11/stl.h>
#include <pybind11/type_caster_pyobject_ptr.h>

#include "pybind11_tests.h"

#include <cstddef>
#include <vector>

namespace {

// Returns the results of calling `ValueHolder(93)` and `ValueHolder(186)` as raw
// `PyObject *` pointers, each carrying one strong reference owned by the returned vector.
// What happens to those references afterwards depends on the return_value_policy the
// function is bound with.
//
// Exception safety: the new objects are first held by owning `py::object`s. If a later
// `ValueHolder(...)` call (or vector growth) throws, every object created so far is
// released by RAII instead of leaking. The references are handed to the raw pointers
// only after `reserve()`, when nothing that could throw remains.
std::vector<PyObject *> make_vector_pyobject_ptr(const py::object &ValueHolder) {
    std::vector<py::object> owners;
    for (int i = 1; i < 3; i++) {
        owners.push_back(ValueHolder(i * 93));
    }
    std::vector<PyObject *> vec_obj;
    vec_obj.reserve(owners.size());
    for (auto &owner : owners) {
        // reserve() above guarantees push_back does not reallocate, so no exception can
        // occur between release() and storing the pointer.
        vec_obj.push_back(owner.release().ptr());
    }
    // This vector now owns the refcounts.
    return vec_obj;
}

} // namespace

TEST_SUBMODULE(type_caster_pyobject_ptr, m) {
    // Casts a fresh reference (from PyLong_FromLongLong) to py::object, letting the
    // resulting object take ownership of it.
    m.def("cast_from_pyobject_ptr", []() {
        PyObject *ptr = PyLong_FromLongLong(6758L);
        return py::cast(ptr, py::return_value_policy::take_ownership);
    });
    // Checks that py::cast<PyObject *> from a (non-owning) handle adds one reference.
    // The returned pointer is adopted by `owner` right away, so that extra reference is
    // dropped even when the refcount check fails and we return early.
    m.def("cast_handle_to_pyobject_ptr", [](py::handle obj) {
        auto rc1 = obj.ref_count();
        auto owner = py::reinterpret_steal<py::object>(py::cast<PyObject *>(obj));
        auto rc2 = obj.ref_count();
        if (rc2 != rc1 + 1) {
            return -1;
        }
        return 100 - owner.attr("value").cast<int>();
    });
    // Checks that py::cast<PyObject *> from an rvalue py::object transfers the existing
    // reference (refcount unchanged). `owner` adopts it before the check so the early
    // return cannot leak it.
    m.def("cast_object_to_pyobject_ptr", [](py::object obj) {
        py::handle hdl = obj;
        auto rc1 = hdl.ref_count();
        auto owner = py::reinterpret_steal<py::object>(py::cast<PyObject *>(std::move(obj)));
        auto rc2 = hdl.ref_count();
        if (rc2 != rc1) {
            return -1;
        }
        return 300 - owner.attr("value").cast<int>();
    });
    // Same as above, but starting from py::list: this covers types that are implicitly
    // convertible to py::object.
    m.def("cast_list_to_pyobject_ptr", [](py::list lst) {
        py::handle hdl = lst;
        auto rc1 = hdl.ref_count();
        auto owner = py::reinterpret_steal<py::list>(py::cast<PyObject *>(std::move(lst)));
        auto rc2 = hdl.ref_count();
        if (rc2 != rc1) {
            return -1;
        }
        return 400 - static_cast<int>(py::len(owner));
    });

    // Returning a raw new reference; take_ownership lets pybind11 adopt it.
    m.def(
        "return_pyobject_ptr",
        []() { return PyLong_FromLongLong(2314L); },
        py::return_value_policy::take_ownership);
    // PyObject * argument: the lambda only borrows it.
    m.def("pass_pyobject_ptr", [](PyObject *ptr) {
        return 200 - py::reinterpret_borrow<py::object>(ptr).attr("value").cast<int>();
    });

    // Callbacks wrapped by std::function with py::object / PyObject * in their signatures.
    m.def("call_callback_with_object_return",
          [](const std::function<py::object(int)> &cb, int value) { return cb(value); });
    m.def(
        "call_callback_with_pyobject_ptr_return",
        [](const std::function<PyObject *(int)> &cb, int value) { return cb(value); },
        py::return_value_policy::take_ownership);
    m.def(
        "call_callback_with_pyobject_ptr_arg",
        [](const std::function<int(PyObject *)> &cb, py::handle obj) { return cb(obj.ptr()); },
        py::arg("cb"), // This triggers return_value_policy::automatic_reference
        py::arg("obj"));

    // Casts a null PyObject *, optionally with a Python error already set beforehand.
    // NOTE(review): the caster presumably raises in both cases; the Python side checks
    // which exception surfaces.
    m.def("cast_to_pyobject_ptr_nullptr", [](bool set_error) {
        if (set_error) {
            py::set_error(PyExc_RuntimeError, "Reflective of healthy error handling.");
        }
        PyObject *ptr = nullptr;
        py::cast(ptr);
    });

    // Casts a valid pointer while a Python error is pending (an invalid state that the
    // caster is presumably expected to report).
    m.def("cast_to_pyobject_ptr_non_nullptr_with_error_set", []() {
        py::set_error(PyExc_RuntimeError, "Reflective of unhealthy error handling.");
        py::cast(Py_None);
    });

    // Folds the `value` attribute of each (borrowed) element into one int, treating
    // each value as a base-1000 "digit" in order.
    m.def("pass_list_pyobject_ptr", [](const std::vector<PyObject *> &vec_obj) {
        int acc = 0;
        for (const auto &ptr : vec_obj) {
            acc = acc * 1000 + py::reinterpret_borrow<py::object>(ptr).attr("value").cast<int>();
        }
        return acc;
    });

    m.def("return_list_pyobject_ptr_take_ownership",
          make_vector_pyobject_ptr,
          // Ownership is transferred one-by-one when the vector is converted to a Python list.
          py::return_value_policy::take_ownership);

    m.def("return_list_pyobject_ptr_reference",
          make_vector_pyobject_ptr,
          // Ownership is not transferred.
          py::return_value_policy::reference);

    // Drops one reference from each element and returns how many were dropped. It stops
    // early if an element's refcount is below 2, because decrementing it then could free
    // an object that is still in use.
    m.def("dec_ref_each_pyobject_ptr", [](const std::vector<PyObject *> &vec_obj) {
        std::size_t i = 0;
        for (; i < vec_obj.size(); i++) {
            py::handle h(vec_obj[i]);
            if (static_cast<std::size_t>(h.ref_count()) < 2) {
                break; // Something is badly wrong.
            }
            h.dec_ref();
        }
        return i;
    });

    m.def("pass_pyobject_ptr_and_int", [](PyObject *, int) {});

    // Excluded from normal builds; the macro name suggests this snippet is expected NOT
    // to compile (TODO confirm).
#ifdef PYBIND11_NO_COMPILE_SECTION // Change to ifndef for manual testing.
    {
        PyObject *ptr = nullptr;
        (void) py::cast(*ptr);
    }
#endif
}
