/*
<%
setup_pybind11(cfg)
cfg["compiler_args"] = ["-O3", "-std=c++17"]
import os
cfg["include_dirs"] += [os.path.join(os.environ["CONDA_PREFIX"], "include")]
%>
*/
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>

#include <lemon/smart_graph.h>
#include <lemon/network_simplex.h>
#include <lemon/preflow.h>
#include <lemon/tolerance.h>

#include <vector>
#include <cmath>
#include <stdexcept>

namespace py = pybind11;
using Graph = lemon::SmartDigraph;

// Out-of-class definitions of the static `def_epsilon` members of the
// lemon::Tolerance specializations for float, double and long double.
// NOTE(review): the build config at the top of this file only adds an include
// directory (no LEMON library is linked), so these are presumably provided here
// to satisfy the linker for header-only use — confirm before linking against a
// compiled LEMON library, which would likely define the same symbols.
namespace lemon {
    float       Tolerance<float>::def_epsilon       = static_cast<float>(1e-4);
    double      Tolerance<double>::def_epsilon      = 1e-10;
    long double Tolerance<long double>::def_epsilon = 1e-14;
}

/* ===================== MIN COST FLOW ===================== */

/**
 * @brief Solve a minimum-cost flow problem with lemon::NetworkSimplex.
 *
 * Arc e goes from node src[e] to node dst[e], has cost cost[e] and upper
 * bound cap[e]; supply[i] is handed to the solver as the supply of node i.
 *
 * @param n       Number of nodes; must be > 0.
 * @param src     1D array of m source-node indices, each in [0, n).
 * @param dst     1D array of m target-node indices, each in [0, n).
 * @param cost    1D array of m arc costs; NaN is rejected.
 * @param cap     1D array of m arc upper bounds; must be nonnegative and not NaN.
 * @param supply  1D array of n node supplies; NaN is rejected.
 * @param tol     Absolute tolerance used for the "at_capacity" test.
 *
 * @return A dict with the keys
 *   - "status":       the value returned by NetworkSimplex::run(), cast to int;
 *   - "flow":         length-m array of arc flows;
 *   - "potential":    length-n array of node potentials;
 *   - "reduced_cost": length-m array, cost[e] + potential[src[e]] - potential[dst[e]];
 *   - "at_capacity":  length-m bool array, |flow[e] - cap[e]| <= tol;
 *   - "total_cost":   the solver's total cost.
 *   When the status is not OPTIMAL, every array is filled with zeros/false and
 *   "total_cost" is 0.0.
 *
 * @throws std::runtime_error if any argument fails validation.
 */
py::dict solve_mcf(
    long long n,
    py::array_t<long long, py::array::c_style | py::array::forcecast> src,
    py::array_t<long long, py::array::c_style | py::array::forcecast> dst,
    py::array_t<double,    py::array::c_style | py::array::forcecast> cost,
    py::array_t<double,    py::array::c_style | py::array::forcecast> cap,
    py::array_t<double,    py::array::c_style | py::array::forcecast> supply,
    double tol
) {
    if (n <= 0) throw std::runtime_error("n must be > 0");

    const auto srcb = src.request();
    const auto dstb = dst.request();
    const auto cb   = cost.request();
    const auto capb = cap.request();
    const auto bb   = supply.request();

    if (srcb.ndim != 1 || dstb.ndim != 1 || cb.ndim != 1 || capb.ndim != 1)
        throw std::runtime_error("src/dst/cost/cap must be 1D arrays");
    if (bb.ndim != 1) throw std::runtime_error("supply must be 1D array");

    const long long m = static_cast<long long>(srcb.size);
    if (static_cast<long long>(dstb.size) != m ||
        static_cast<long long>(cb.size) != m ||
        static_cast<long long>(capb.size) != m)
        throw std::runtime_error("edge arrays must all have the same length m");
    if (static_cast<long long>(bb.size) != n)
        throw std::runtime_error("supply must have length n");

    // Inputs are only read; the array_t parameters keep the buffers alive.
    const auto* srcp = static_cast<const long long*>(srcb.ptr);
    const auto* dstp = static_cast<const long long*>(dstb.ptr);
    const auto* cp   = static_cast<const double*>(cb.ptr);
    const auto* capp = static_cast<const double*>(capb.ptr);
    const auto* bp   = static_cast<const double*>(bb.ptr);

    Graph g;
    std::vector<Graph::Node> nodes(n);
    for (long long i = 0; i < n; ++i) nodes[i] = g.addNode();

    std::vector<Graph::Arc> arcs(m);
    for (long long e = 0; e < m; ++e) {
        if (srcp[e] < 0 || srcp[e] >= n) throw std::runtime_error("src index out of range");
        if (dstp[e] < 0 || dstp[e] >= n) throw std::runtime_error("dst index out of range");
        // NaN compares false against everything, so `< 0.0` alone would let it through.
        if (std::isnan(capp[e]) || capp[e] < 0.0)
            throw std::runtime_error("capacity must be nonnegative");
        if (std::isnan(cp[e])) throw std::runtime_error("cost must not be NaN");
        arcs[e] = g.addArc(nodes[srcp[e]], nodes[dstp[e]]);
    }

    Graph::ArcMap<double> capMap(g), costMap(g);
    Graph::NodeMap<double> supplyMap(g);

    for (long long e = 0; e < m; ++e) {
        capMap[arcs[e]]  = capp[e];
        costMap[arcs[e]] = cp[e];
    }
    for (long long i = 0; i < n; ++i) {
        if (std::isnan(bp[i])) throw std::runtime_error("supply must not be NaN");
        supplyMap[nodes[i]] = bp[i];
    }

    using Solver = lemon::NetworkSimplex<Graph, double, double>;
    Solver ns(g);
    const auto status = ns.upperMap(capMap).costMap(costMap).supplyMap(supplyMap).run();

    py::array_t<double> flow(m), potential(n), reduced_cost(m);
    py::array_t<bool> at_capacity(m);

    auto* flowp = static_cast<double*>(flow.request().ptr);
    auto* potp  = static_cast<double*>(potential.request().ptr);
    auto* redp  = static_cast<double*>(reduced_cost.request().ptr);
    auto* capm  = static_cast<bool*>(at_capacity.request().ptr);

    double total_cost = 0.0;

    if (status == Solver::OPTIMAL) {
        // Copy the solver's primal and dual solutions out into maps we can index.
        Graph::ArcMap<double> flowMap(g);
        Graph::NodeMap<double> potMap(g);
        ns.flowMap(flowMap);
        ns.potentialMap(potMap);
        total_cost = ns.totalCost();

        for (long long i = 0; i < n; ++i) potp[i] = potMap[nodes[i]];

        for (long long e = 0; e < m; ++e) {
            const auto a = arcs[e];
            const double x = flowMap[a];
            flowp[e] = x;

            // Reduced cost of the arc with respect to the node potentials.
            redp[e] = costMap[a] + potMap[g.source(a)] - potMap[g.target(a)];

            // Arc counts as saturated when its flow is within tol of the upper bound.
            capm[e] = (std::abs(x - capMap[a]) <= tol);
        }
    } else {
        // No optimal solution: return well-defined (zeroed) arrays, not uninitialized memory.
        for (long long e = 0; e < m; ++e) {
            flowp[e] = 0.0;
            redp[e]  = 0.0;
            capm[e]  = false;
        }
        for (long long i = 0; i < n; ++i) potp[i] = 0.0;
    }

    py::dict out;
    out["status"] = static_cast<int>(status);
    out["flow"] = flow;
    out["potential"] = potential;
    out["reduced_cost"] = reduced_cost;
    out["at_capacity"] = at_capacity;
    out["total_cost"] = total_cost;
    return out;
}

/* ===================== MAX FLOW ===================== */

/**
 * @brief Compute a maximum s-t flow with lemon::Preflow.
 *
 * Arc e goes from node src[e] to node dst[e] with capacity cap[e].
 *
 * @param n    Number of nodes; must be > 0.
 * @param src  1D array of m source-node indices, each in [0, n).
 * @param dst  1D array of m target-node indices, each in [0, n).
 * @param cap  1D array of m arc capacities; must be nonnegative and not NaN.
 * @param s    Index of the flow source, in [0, n).
 * @param t    Index of the flow sink, in [0, n); must differ from s.
 *
 * @return A dict with the keys
 *   - "value": the flow value reported by Preflow, as a double;
 *   - "flow":  length-m array of arc flows.
 *
 * @throws std::runtime_error if any argument fails validation.
 */
py::dict max_flow(
    long long n,
    py::array_t<long long, py::array::c_style | py::array::forcecast> src,
    py::array_t<long long, py::array::c_style | py::array::forcecast> dst,
    py::array_t<double,    py::array::c_style | py::array::forcecast> cap,
    long long s,
    long long t
) {
    if (n <= 0) throw std::runtime_error("n must be > 0");
    if (s < 0 || s >= n) throw std::runtime_error("source s out of range");
    if (t < 0 || t >= n) throw std::runtime_error("sink t out of range");
    if (s == t) throw std::runtime_error("source and sink must be different");

    const auto srcb = src.request();
    const auto dstb = dst.request();
    const auto capb = cap.request();
    if (srcb.ndim != 1 || dstb.ndim != 1 || capb.ndim != 1)
        throw std::runtime_error("src/dst/cap must be 1D arrays");

    const long long m = static_cast<long long>(srcb.size);
    if (static_cast<long long>(dstb.size) != m || static_cast<long long>(capb.size) != m)
        throw std::runtime_error("src/dst/cap must have the same length m");

    // Inputs are only read; the array_t parameters keep the buffers alive.
    const auto* srcp = static_cast<const long long*>(srcb.ptr);
    const auto* dstp = static_cast<const long long*>(dstb.ptr);
    const auto* capp = static_cast<const double*>(capb.ptr);

    Graph g;
    std::vector<Graph::Node> nodes(n);
    for (long long i = 0; i < n; ++i) nodes[i] = g.addNode();

    std::vector<Graph::Arc> arcs(m);
    for (long long e = 0; e < m; ++e) {
        if (srcp[e] < 0 || srcp[e] >= n) throw std::runtime_error("src index out of range");
        if (dstp[e] < 0 || dstp[e] >= n) throw std::runtime_error("dst index out of range");
        // NaN compares false against everything, so `< 0.0` alone would let it through.
        // NOTE(review): +infinity is still accepted here; confirm Preflow handles it as intended.
        if (std::isnan(capp[e]) || capp[e] < 0.0)
            throw std::runtime_error("capacity must be nonnegative");
        arcs[e] = g.addArc(nodes[srcp[e]], nodes[dstp[e]]);
    }

    Graph::ArcMap<double> capMap(g);
    for (long long e = 0; e < m; ++e) capMap[arcs[e]] = capp[e];

    lemon::Preflow<Graph, Graph::ArcMap<double>> pf(g, capMap, nodes[s], nodes[t]);
    pf.run();

    py::array_t<double> flow(m);
    auto* flowp = static_cast<double*>(flow.request().ptr);
    for (long long e = 0; e < m; ++e) flowp[e] = pf.flow(arcs[e]);

    py::dict out;
    out["value"] = static_cast<double>(pf.flowValue());
    out["flow"] = flow;
    return out;
}

// Python module `lemon_mcf`: exposes solve_mcf and max_flow under their C++ names,
// with every parameter available as a keyword argument.
PYBIND11_MODULE(lemon_mcf, m) {
    using namespace py::literals;  // "name"_a is shorthand for py::arg("name")

    // solve_mcf(n, src, dst, cost, cap, supply, tol=1e-9)
    m.def("solve_mcf", &solve_mcf,
          "n"_a, "src"_a, "dst"_a,
          "cost"_a, "cap"_a, "supply"_a,
          "tol"_a = 1e-9);

    // max_flow(n, src, dst, cap, s, t)
    m.def("max_flow", &max_flow,
          "n"_a, "src"_a, "dst"_a, "cap"_a,
          "s"_a, "t"_a);
}
