// astar.cpp

#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h> // Include this header for py::cast with STL containers
#include <queue>
#include <vector>
#include <unordered_map>
#include <limits>
#include <algorithm> // For std::reverse

namespace py = pybind11;

// One entry of the A* open list: a node index together with the cost paid to
// reach it (g) and its search priority f = g + h. Only f_cost takes part in
// the ordering, so std::priority_queue<Node, std::vector<Node>,
// std::greater<Node>> yields the entry with the smallest f_cost first.
struct Node {
    int index;      // Position of the node in the graph arrays
    double g_cost;  // Accumulated cost from the start node
    double f_cost;  // Priority: g_cost plus heuristic estimate

    Node(int node, double cost_so_far, double priority)
        : index{node}, g_cost{cost_so_far}, f_cost{priority} {}

    // Strict "greater than" on f_cost; std::greater<Node> uses it to build a min-heap.
    bool operator>(const Node& rhs) const { return rhs.f_cost < f_cost; }
};

// Follows came_from links back from `goal` until the -1 sentinel and returns
// the visited node indices ordered start -> goal.
static std::vector<int> reconstruct_path(const std::vector<int>& came_from, int goal) {
    std::vector<int> path;
    for (int idx = goal; idx != -1; idx = came_from[idx]) {
        path.push_back(idx);
    }
    std::reverse(path.begin(), path.end());
    return path;
}

/**
 * @brief A* search over a graph stored as fixed-width (M, E) adjacency arrays.
 *
 * @param start_idx          Index of the start node; must lie in [0, M).
 * @param end_idx            Index of the goal node; must lie in [0, M).
 * @param node_positions     Only shape[0] is read, to obtain the node count M.
 * @param edge_idx           (M, E) neighbour index per edge slot; entries outside
 *                           [0, M) (e.g. -1 padding) are ignored.
 * @param edge_invalid_mask  (M, E) edge slots whose value equals 1 are skipped.
 * @param edge_cost          (M, E) cost of traversing each edge slot.
 * @param heuristic          (M) heuristic value added to g to form the priority.
 * @return Python list of node indices from start to goal (inclusive), or None
 *         if the goal is unreachable.
 * @throws IndexError if start_idx or end_idx is outside [0, M).
 * @throws ValueError if an array has the wrong number of dimensions or is
 *         smaller than (M, E) / (M,).
 *
 * NOTE: expanded nodes are never reopened. The result is therefore optimal
 * only for non-negative edge costs and a consistent heuristic, which the
 * caller must ensure.
 */
py::object astar(
    int start_idx,
    int end_idx,
    py::array_t<double> node_positions,  // (M, N)
    py::array_t<int> edge_idx,           // (M, E)
    py::array_t<int> edge_invalid_mask,  // (M, E)
    py::array_t<double> edge_cost,       // (M, E)
    py::array_t<double> heuristic        // (M)
) {
    // unchecked<D>() verifies ndim (throwing std::domain_error, surfaced as
    // ValueError) but performs no bounds checks on element access, so the
    // extents are validated explicitly below before any indexing happens.
    auto edge_idx_ = edge_idx.unchecked<2>();
    auto edge_invalid_mask_ = edge_invalid_mask.unchecked<2>();
    auto edge_cost_ = edge_cost.unchecked<2>();
    auto heuristic_ = heuristic.unchecked<1>();

    const py::ssize_t M = node_positions.shape(0); // Number of nodes
    const py::ssize_t E = edge_idx_.shape(1);      // Edge slots per node

    if (edge_idx_.shape(0) < M || heuristic_.shape(0) < M ||
        edge_invalid_mask_.shape(0) < M || edge_invalid_mask_.shape(1) < E ||
        edge_cost_.shape(0) < M || edge_cost_.shape(1) < E) {
        throw py::value_error(
            "astar: edge_idx, edge_invalid_mask and edge_cost must be at least (M, E) "
            "and heuristic at least (M,), where M = node_positions.shape[0]");
    }
    if (start_idx < 0 || start_idx >= M) {
        throw py::index_error("astar: start_idx out of range");
    }
    if (end_idx < 0 || end_idx >= M) {
        throw py::index_error("astar: end_idx out of range");
    }

    // Min-heap on f_cost (see Node::operator>).
    std::priority_queue<Node, std::vector<Node>, std::greater<Node>> open_list;

    const auto node_count = static_cast<std::size_t>(M);
    std::vector<double> g_costs(node_count, std::numeric_limits<double>::infinity());
    std::vector<int> came_from(node_count, -1);
    std::vector<char> closed_set(node_count, 0); // char avoids vector<bool> bit proxies

    g_costs[start_idx] = 0.0;
    open_list.emplace(start_idx, 0.0, heuristic_(start_idx));

    while (!open_list.empty()) {
        const Node current = open_list.top();
        open_list.pop();

        // No decrease-key: the heap can hold stale duplicates pushed by earlier,
        // worse relaxations. Skip anything that has already been expanded.
        if (closed_set[current.index]) {
            continue;
        }
        closed_set[current.index] = 1;

        if (current.index == end_idx) {
            return py::cast(reconstruct_path(came_from, end_idx)); // Python list
        }

        for (py::ssize_t i = 0; i < E; ++i) {
            const int neighbor_idx = edge_idx_(current.index, i);
            if (neighbor_idx < 0 || neighbor_idx >= M) {
                continue; // Padding or otherwise invalid index
            }
            if (edge_invalid_mask_(current.index, i) == 1) {
                continue; // Edge masked out by the caller
            }
            if (closed_set[neighbor_idx]) {
                continue; // Already expanded
            }

            const double tentative_g_cost = g_costs[current.index] + edge_cost_(current.index, i);
            if (tentative_g_cost < g_costs[neighbor_idx]) {
                g_costs[neighbor_idx] = tentative_g_cost;
                came_from[neighbor_idx] = current.index;
                open_list.emplace(neighbor_idx, tentative_g_cost,
                                  tentative_g_cost + heuristic_(neighbor_idx));
            }
        }
    }

    // Open list exhausted without reaching the goal.
    return py::none();
}

// Python extension module `astar_module`, exposing a single function `astar`
// whose keyword argument names mirror the C++ parameter names.
PYBIND11_MODULE(astar_module, m) {
    using namespace pybind11::literals; // enables the "name"_a shorthand for py::arg

    m.doc() = "A* search algorithm implemented in C++ with Python bindings";
    m.def("astar", &astar, "A* search algorithm",
          "start_idx"_a, "end_idx"_a, "node_positions"_a, "edge_idx"_a,
          "edge_invalid_mask"_a, "edge_cost"_a, "heuristic"_a);
}
