// pybind11 headers come first: they pull in Python.h, which must precede standard headers.
#include <pybind11/pybind11.h>
#include <pybind11/eigen.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include "clari_tree.hpp"
#include "clari_tree_const.hpp"

#include <algorithm>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <limits>
#include <memory>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>

namespace py = pybind11;

// -------------------------------------------------------------
// Coefficient of determination R^2 between two vectors
// -------------------------------------------------------------
/// @brief Compute R^2 = 1 - SS_res / SS_tot.
///
/// SS_res = sum((y_true - y_pred)^2) and SS_tot = sum((y_true - mean(y_true))^2).
///
/// When y_true is constant (SS_tot == 0) R^2 is undefined. Following the
/// scikit-learn `r2_score` convention, this returns 1.0 for a perfect
/// prediction and 0.0 otherwise, rather than a perfect score for any prediction.
///
/// @param y_true ground-truth values
/// @param y_pred predicted values (same length as y_true)
/// @return the R^2 score
/// @throws std::invalid_argument if the vectors differ in length or are empty
double compute_R2(const Eigen::Ref<const Eigen::VectorXd>& y_true,
                  const Eigen::Ref<const Eigen::VectorXd>& y_pred) {
    if (y_true.size() != y_pred.size()) {
        throw std::invalid_argument("y_true and y_pred must have the same length");
    }
    // Eigen's mean() runs a redux over the coefficients, which is invalid on an
    // empty vector (asserts in debug builds), so reject empty input up front.
    if (y_true.size() == 0) {
        throw std::invalid_argument("y_true and y_pred must not be empty");
    }

    const double ss_res = (y_true - y_pred).squaredNorm();
    const double mean_y = y_true.mean();
    const double ss_tot = (y_true.array() - mean_y).matrix().squaredNorm();

    if (ss_tot > 0.0) {
        return 1.0 - ss_res / ss_tot;
    }
    // Constant target: only an exact prediction deserves a perfect score.
    return ss_res == 0.0 ? 1.0 : 0.0;
}

namespace {

// The constant-leaf tree constructors are passed the depth as a `char`.
// Validate the Python-side int before narrowing, so an out-of-range value
// raises ValueError (pybind11 maps std::invalid_argument) instead of
// silently wrapping around.
char checked_depth(int depth) {
    if (depth < static_cast<int>(std::numeric_limits<char>::min()) ||
        depth > static_cast<int>(std::numeric_limits<char>::max())) {
        throw std::invalid_argument("depth is out of range: " + std::to_string(depth));
    }
    return static_cast<char>(depth);
}

}  // namespace

// -------------------------------------------------------------
// Pybind11 module definition
// -------------------------------------------------------------
PYBIND11_MODULE(_core, m) {
    m.doc() = "clari_tree C++ core (CLARITree + CLARITreeConst + ...)";

    // ------------------- CSV read -------------------
    // Raises RuntimeError (std::runtime_error) when readCSV reports failure.
    m.def("read_csv",
          [](const std::string& filename,
             bool has_header,
             char delimiter) {
              Eigen::MatrixXd X;
              Eigen::VectorXd y;
              if (!readCSV(filename, X, y, has_header, delimiter)) {
                  throw std::runtime_error("Failed to read CSV file: " + filename);
              }
              // Move the buffers into the returned pair (a Python tuple) rather than copying them.
              return std::make_pair(std::move(X), std::move(y));
          },
          py::arg("filename"),
          py::arg("has_header") = true,
          py::arg("delimiter") = ',',
          "Read CSV file into (X, y).");

    // ------------------- Greedy -------------------
    py::class_<Greedy>(m, "Greedy")
        .def(py::init<double, int, double, int, bool>(),
             py::arg("kappa"),
             py::arg("depth"),
             py::arg("lambda_") = 0.0,
             py::arg("stride") = 1,
             py::arg("verbose") = true)

        // Eigen::Ref<const MatrixXd> binds a float64, column-major (Fortran-ordered)
        // array without copying. Any other dtype or layout (e.g. a default
        // C-ordered NumPy array) is first converted into a temporary copy by pybind11.
        // The GIL is released only for the C++ fit itself.
        .def("fit",
             [](Greedy& self,
                const Eigen::Ref<const Eigen::MatrixXd>& X,
                const Eigen::Ref<const Eigen::VectorXd>& y,
                const std::vector<int>& categorical_idx) -> double {
                 py::gil_scoped_release release;
                 return self.fit(X, y, categorical_idx);
             },
             py::arg("X"),
             py::arg("y"),
             py::arg("categorical_idx") = std::vector<int>(),
             "Fit the tree with (X, y). Returns objective (loss).\n"
             "Args:\n"
             "  X: feature matrix (with intercept column as first column)\n"
             "  y: target vector\n"
             "  categorical_idx: indices of categorical features (0-indexed, excluding intercept)")

        // Predict values for new inputs (GIL released during the C++ call).
        .def("predict",
             [](Greedy& self,
                const Eigen::Ref<const Eigen::MatrixXd>& X) {
                 py::gil_scoped_release release;
                 return self.predict(X);
             },
             py::arg("X"),
             "Predict values for X.")

        .def("print_tree", &Greedy::print_tree,
             "Pretty-print the trained tree.")

        .def("n_leaves", &Greedy::n_leaves,
             "Return the number of leaf nodes in the tree.");

    // ------------------- CLARITree -------------------
    // Inherits fit / predict / print_tree / n_leaves from the Greedy binding.
    py::class_<CLARITree, Greedy>(m, "CLARITree")
        .def(py::init<double, int, double, int, bool>(),
             py::arg("kappa"),
             py::arg("depth"),
             py::arg("lambda_") = 0.0,
             py::arg("stride") = 1,
             py::arg("verbose") = true,
             "CLARITree(kappa, depth, lambda_, stride, verbose)\n\n"
             "kappa: ...\n"
             "depth: maximum tree depth\n"
             "lambda_: ridge penalty\n"
             "stride: to align with quantiles, we can skip some data points\n"
             "verbose: print training details");

    // ------------------- CLARITreeFull -------------------
    py::class_<CLARITreeFull, Greedy>(m, "CLARITreeFull")
        .def(py::init<double, int, double, int, bool>(),
             py::arg("kappa"),
             py::arg("depth"),
             py::arg("lambda_") = 0.0,
             py::arg("stride") = 1,
             py::arg("verbose") = true,
             "CLARITreeFull(kappa, depth, lambda_, stride, verbose)\n\n"
             "Ablation version of CLARITree that recomputes full regressions "
             "at every possible split (no rank-one updates). "
             "All other interfaces are identical to CLARITree.");

    // ------------------- GreedyConst (constant-leaf, greedy) -------------------
    // The factory returns the class's default holder (std::unique_ptr), so
    // ownership passes to pybind11 without a naked `new`.
    py::class_<GreedyConst>(m, "GreedyConst")
        .def(py::init([](int depth, double lambda, int stride, bool verbose) {
                 return std::make_unique<GreedyConst>(checked_depth(depth), lambda, stride, verbose);
             }),
             py::arg("depth"),
             py::arg("lambda_") = 0.0,
             py::arg("stride") = 1,
             py::arg("verbose") = true,
             "Constant-leaf greedy regression tree.\n"
             "Args:\n"
             "  depth (int): max tree depth\n"
             "  lambda_ (float): per-leaf penalty (scaled by TSS inside fit)\n"
             "  stride (int): to speed up traversal or align with quantiles, we can skip some data points\n"
             "  verbose (bool)")

        .def("fit",
             [](GreedyConst& self,
                const Eigen::Ref<const Eigen::MatrixXd>& X,
                const Eigen::Ref<const Eigen::VectorXd>& y) {
                 py::gil_scoped_release release;
                 return self.fit(X, y);
             },
             py::arg("X"), py::arg("y"),
             "Fit the tree with (X, y). Returns objective (loss).")

        .def("predict",
             [](GreedyConst& self, const Eigen::Ref<const Eigen::MatrixXd>& X) {
                 py::gil_scoped_release release;
                 return self.predict(X);
             },
             py::arg("X"),
             "Predict values for X.")

        .def("print_tree", &GreedyConst::print_tree,
             "Pretty-print the trained tree.")

        .def("n_leaves", &GreedyConst::n_leaves,
             "Return the number of leaf nodes in the tree.");

    // ------------------- CLARITreeConst (constant-leaf, CLARITree splits) -------------------
    py::class_<CLARITreeConst, GreedyConst>(m, "CLARITreeConst")
        .def(py::init([](int depth, double lambda, int stride, bool verbose) {
                 return std::make_unique<CLARITreeConst>(checked_depth(depth), lambda, stride, verbose);
             }),
             py::arg("depth"),
             py::arg("lambda_") = 0.0,
             py::arg("stride") = 1,
             py::arg("verbose") = true,
             "Constant-leaf tree with CLARITree splits (special case of CLARITree).\n"
             "Args:\n"
             "  depth (int): max tree depth\n"
             "  lambda_ (float): per-leaf penalty (scaled by TSS inside fit)\n"
             "  stride (int): to speed up traversal or align with quantiles, we can skip some data points\n"
             "  verbose (bool)");

    // ------------------- Utility functions -------------------
    m.def("compute_R2", &compute_R2,
          py::arg("y_true"), py::arg("y_pred"),
          "Compute R^2 between y_true and y_pred.");
}
