#include <algorithm>
#include <array>
#include <chrono>
#include <iomanip>
#include <iostream>
#include <numeric>
#include <set>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include <Eigen/Dense>

#include "clari_tree.hpp"

using namespace Eigen;
using namespace Eigen::indexing;
using namespace std;

// constructor
Node::Node()
    : left(nullptr), right(nullptr), is_leaf(true), obj(0), threshold(0),
      feature_idx(0), coefficients(Eigen::VectorXd::Zero(1)) {}

// Destructor: the node owns both children, so release each subtree in turn
// (left first, then right); deleting a null child is a no-op.
Node::~Node()
{
    for (Node *child : {left, right})
        delete child;
}

// Assignment operator: deep copy implemented as copy-and-swap.
Node &Node::operator=(const Node &other)
{
    if (this != &other)
    {
        // Build the complete deep copy first. `other` may live inside this node's own
        // subtree (e.g. `*node = *node->left`), so our current children must not be freed
        // until everything has been copied out of it. Copying first also means an
        // allocation failure leaves *this unchanged.
        Node copy(other);

        // Take over the copy's children; `copy` receives our old children and frees
        // them in its destructor at the end of this scope.
        std::swap(left, copy.left);
        std::swap(right, copy.right);
        is_leaf = copy.is_leaf;
        obj = copy.obj;
        threshold = copy.threshold;
        feature_idx = copy.feature_idx;
        coefficients.swap(copy.coefficients);
    }
    return *this;
}

// Copy constructor: deep-copies both subtrees.
Node::Node(const Node &other)
    : left(nullptr), right(nullptr), is_leaf(other.is_leaf), obj(other.obj),
      threshold(other.threshold), feature_idx(other.feature_idx),
      coefficients(other.coefficients)
{
    // A constructor that throws never runs the destructor, so if copying the right
    // subtree fails we must release the already-copied left subtree ourselves.
    try
    {
        left = other.left ? new Node(*other.left) : nullptr;
        right = other.right ? new Node(*other.right) : nullptr;
    }
    catch (...)
    {
        delete left;
        throw;
    }
}

// Render the subtree rooted here as indented text: leaves report their objective,
// internal nodes print their split, then the left subtree, "Else:", then the right
// subtree, each child indented two more spaces.
std::string Node::print_tree(int indentation)
{
    const std::string pad(indentation, ' ');
    if (is_leaf)
        return pad + "Ridge reg on points in leaf; gives objective " + std::to_string(obj) + "\n";

    std::string out = pad + "If feature " + std::to_string(feature_idx) +
                      " <= " + std::to_string(threshold) + ":\n";
    if (left != nullptr)
        out += left->print_tree(indentation + 2);
    out += pad + "Else:\n";
    if (right != nullptr)
        out += right->print_tree(indentation + 2);
    return out;
}

// Constructor: stores the hyper-parameters and starts from a single default leaf.
Greedy::Greedy(double kappa, int depth, double lambda, int stride, bool verbose)
    : kappa(kappa), depth(depth), verbose(verbose), root(new Node()), lambda(lambda), stride(stride)
{
    // Give the fitted-state members defined values before fit() runs, so copying an
    // unfitted model never reads indeterminate values. p_reg_ = 1 ("intercept only")
    // matches the single zero coefficient of a default Node, so predict() on an
    // unfitted model evaluates to 0 instead of sizing a vector from garbage.
    scaled_kappa = 0.0;
    scaled_lambda = 0.0;
    n = 0;
    m = 0;
    p_reg_ = 1;
}

// Destructor: the model owns its tree through `root`; Node's destructor releases
// the children recursively.
Greedy::~Greedy()
{
    delete this->root;
}

// Assignment operator: deep-copies the tree and every piece of fitted state.
Greedy &Greedy::operator=(const Greedy &other)
{
    if (this == &other)
        return *this;

    // Copy the tree before releasing ours so an allocation failure leaves *this intact.
    Node *new_root = other.root ? new Node(*other.root) : nullptr;
    delete root;
    root = new_root;

    // Hyper-parameters and training data
    X = other.X;
    y = other.y;
    kappa = other.kappa;
    scaled_kappa = other.scaled_kappa;
    verbose = other.verbose;
    depth = other.depth;
    lambda = other.lambda;
    scaled_lambda = other.scaled_lambda;
    n = other.n;
    m = other.m;
    stride = other.stride;

    // Feature layout computed in fit(); predict()/fit_coefficients() index through
    // continuous_idx_ and size their design vectors with p_reg_, so these must travel
    // with the tree.
    categorical_idx_ = other.categorical_idx_;
    binary_idx_ = other.binary_idx_;
    continuous_idx_ = other.continuous_idx_;
    p_reg_ = other.p_reg_;
    X_reg_ = other.X_reg_;

    return *this;
}

// Copy constructor: deep-copies the tree and every piece of fitted state.
Greedy::Greedy(const Greedy &other)
    : X(other.X), y(other.y), kappa(other.kappa), verbose(other.verbose),
      depth(other.depth), n(other.n), m(other.m), lambda(other.lambda), stride(other.stride)
{
    // Scaled penalties and the feature layout computed in fit(). predict() and
    // fit_coefficients() read p_reg_/continuous_idx_, and recursive_fit reads the
    // scaled penalties and binary_idx_, so a copy must carry them too.
    scaled_kappa = other.scaled_kappa;
    scaled_lambda = other.scaled_lambda;
    categorical_idx_ = other.categorical_idx_;
    binary_idx_ = other.binary_idx_;
    continuous_idx_ = other.continuous_idx_;
    p_reg_ = other.p_reg_;
    X_reg_ = other.X_reg_;
    root = other.root ? new Node(*other.root) : nullptr;
}

double Greedy::fit(MatrixXd X, VectorXd y, const std::vector<int>& categorical_idx)
{
    /*
    Fit a greedy regression tree whose leaves hold ridge-regression models.
    Parameters:
    - X: design matrix. Column 0 is treated as the intercept: it is never split on and
      is excluded from feature-type detection. NOTE(review): the intercept is assumed to
      be present already; this function does not add it.
    - y: response, one entry per row of X.
    - categorical_idx: categorical column indices counted WITHOUT the intercept column;
      they are shifted by one here to index into X.
    Returns the objective of the fitted tree: the sum over leaves of the leaf's ridge
    loss plus scaled_lambda.
    Throws std::runtime_error if X is empty or X and y differ in length, or if no
    continuous column is found.
    */
    if (X.rows() == 0 || X.rows() != y.size())
    {
        throw std::runtime_error("fit: X and y must be non-empty and have the same number of rows.");
    }

    this->X = X;
    this->y = y;
    this->n = X.rows();
    this->m = X.cols();

    // Caller indices exclude the intercept column, so shift them by one.
    this->categorical_idx_.clear();
    this->categorical_idx_.reserve(categorical_idx.size());
    for (int j_raw : categorical_idx)
    {
        this->categorical_idx_.push_back(j_raw + 1);
    }

    // lambda (charged once per leaf) is scaled by the total sum of squares of y;
    // kappa (the ridge penalty) is scaled by the sample size.
    const double mean_y = y.mean();
    const double tss = (y.array() - mean_y).matrix().squaredNorm();
    this->scaled_lambda = this->lambda * tss;
    this->scaled_kappa = this->n * this->kappa;

    // Classify columns and build X_reg_ = [1 | continuous columns].
    detect_feature_types();
    if (p_reg_ <= 1) {
        throw std::runtime_error("Provide continuous columns (no non-binary, non-categorical numeric features found).");
    }

    // Root Cholesky factor built on X_reg_ (not on the full X). The intercept's diagonal
    // entry is penalised by 1e-12 instead of scaled_kappa.
    MatrixXd gram = X_reg_.transpose() * X_reg_ + scaled_kappa * MatrixXd::Identity(p_reg_, p_reg_);
    gram(0, 0) -= scaled_kappa - 1e-12;
    LLT<MatrixXd> lltOfA(gram);
    VectorXd b = X_reg_.transpose() * y;
    double y_sum_sq = y.squaredNorm();
    double parent_loss = Greedy::loss(lltOfA, b, y_sum_sq);

    // Replace any previous tree; allocate first so a failure leaves the old tree valid.
    Node *fresh_root = new Node();
    delete this->root;
    this->root = fresh_root;
    this->root->obj = parent_loss + this->scaled_lambda;

    // For each feature >= 1, the row indices ordered by that feature's value.
    vector<vector<unsigned long int>> sorted_indices(this->m, vector<unsigned long int>(this->n));
    for (unsigned long int feature = 1; feature < this->m; feature++)
    {
        vector<unsigned long int> &indices = sorted_indices[feature];
        iota(indices.begin(), indices.end(), 0); // fill with 0, 1, ..., n-1
        sort(indices.begin(), indices.end(), [&](unsigned long int lhs, unsigned long int rhs)
             { return X(lhs, feature) < X(rhs, feature); });
    }

    // learn partitioning structure and resulting loss
    double objective = recursive_fit(sorted_indices, lltOfA, b, y_sum_sq, this->root, this->depth);
    // Now learn the coefficient vectors for each of the leaves.
    fit_coefficients(this->root, this->X, this->y);

    return objective;
}

void Greedy::fit_coefficients(Node *node, MatrixXd X, VectorXd y)
{
    /*
    Fit the ridge-regression coefficients stored in every leaf below `node`.
    Parameters:
    - node: root of the subtree; internal nodes route rows by feature_idx/threshold.
    - X: the rows of the full design matrix (all columns) that reach `node`.
    - y: the matching responses.
    A row goes left when X(i, feature_idx) <= threshold. This is the same rule
    predict_row uses, so each leaf is fitted on exactly the training rows that
    prediction routes to it.
    */
    if (node->is_leaf)
    {
        // Leaf design matrix: intercept column followed by the continuous columns.
        MatrixXd Xloc(X.rows(), p_reg_);
        Xloc.col(0).setOnes();
        for (int k = 0; k < (int)continuous_idx_.size(); ++k) {
            Xloc.col(1 + k) = X.col(continuous_idx_[k]);
        }

        // Ridge solve with the same penalised Gram matrix used during the split search:
        // scaled_kappa on every coefficient except the intercept, which gets only 1e-12.
        MatrixXd gram = Xloc.transpose() * Xloc + scaled_kappa * MatrixXd::Identity(Xloc.cols(), Xloc.cols());
        gram(0, 0) -= scaled_kappa - 1e-12;
        LLT<MatrixXd> llt(gram);
        node->coefficients = llt.solve(Xloc.transpose() * y);
        return;
    }

    // Partition the rows with the same comparison predict_row uses.
    vector<Eigen::Index> left_indices;
    vector<Eigen::Index> right_indices;
    for (Eigen::Index i = 0; i < X.rows(); ++i)
    {
        if (X(i, node->feature_idx) <= node->threshold)
        {
            left_indices.push_back(i);
        }
        else
        {
            right_indices.push_back(i);
        }
    }

    // fit coefficients for children
    fit_coefficients(node->left, X(left_indices, Eigen::all), y(left_indices));
    fit_coefficients(node->right, X(right_indices, Eigen::all), y(right_indices));
}

double Greedy::recursive_fit(vector<vector<unsigned long int>> sorted_indices, LLT<MatrixXd> llt, VectorXd b, double y_sum_sq, Node *node, int depth_remaining)
{
    /*
    Finds a greedy tree from the current node, stopping if depth is 0 or if no valid split is found.
    Returns the objective of the greedy tree from this node.
    Assumes current node has loss value filled in with its loss if it were a leaf.

    Parameters:
    - sorted_indices: for each feature >= 1, the training rows that reach this node,
      ordered by that feature's value (every entry holds the same set of rows).
    - llt, b, y_sum_sq: Cholesky factorisation of the penalised Gram matrix of the
      regression features, X_reg^T y, and the sum of y^2, all over the rows reaching
      this node.
    - node: the node to grow; node->obj must already hold its leaf objective.
    - depth_remaining: how many more levels of splits are allowed below this node.
    */
    if (depth_remaining == 0 || node->obj <= 2 * this->scaled_lambda)
    {
        node->is_leaf = true;
        return node->obj;
    }

    // find best split
    bool split_flag = false;
    unsigned long int best_feature = 0;
    unsigned long int best_feature_index = 0; // position in sorted_indices[best_feature]; unused for binary splits
    bool best_is_binary = false;
    // samples_past_quota shared across features; initialize to account for skipping feature 0 (intercept)
    // In old code, feature 0 would increment it n times but never split, so we start with n
    // Use size of any feature (they all have same size = number of rows in current node)
    int samples_past_quota = sorted_indices.size() > 1 ? static_cast<int>(sorted_indices[1].size()) : 0;
    double min_obj = node->obj; // start with parent loss
    LLT<MatrixXd> best_llt_right;
    LLT<MatrixXd> best_llt_left;
    VectorXd best_b_left;
    VectorXd best_b_right;
    vector<int> best_indices_left;
    double best_left_obj = 0.0, best_right_obj = 0.0;
    double best_y_sum_sq_left = 0.0, best_y_sum_sq_right = 0.0;

    for (unsigned long int feature = 1; feature < this->m; feature++)
    {
        // rows reaching this node, ordered by this feature's value
        const vector<unsigned long int> &order = sorted_indices[feature];
        const bool is_bin = std::find(binary_idx_.begin(), binary_idx_.end(), (int)feature) != binary_idx_.end();
        const bool is_cont = std::find(continuous_idx_.begin(), continuous_idx_.end(), (int)feature) != continuous_idx_.end();

        // --- 1) Handle binary feature (0/1) ---------------------------------
        if (is_bin)
        {
            // Partition only the rows that reach this node. Scanning all n training rows
            // would score the split, and seed the children's statistics, on the wrong
            // data at every depth below the root.
            std::vector<int> left_rows, right_rows;
            left_rows.reserve(order.size());
            right_rows.reserve(order.size());
            for (unsigned long int row : order)
            {
                if (this->X(row, feature) <= 0.5)
                    left_rows.push_back(static_cast<int>(row));
                else
                    right_rows.push_back(static_cast<int>(row));
            }
            if (left_rows.empty() || right_rows.empty())
                continue; // not splittable

            auto [lltL, bL, yssL] = recompute_stats_from_rows(left_rows);
            auto [lltR, bR, yssR] = recompute_stats_from_rows(right_rows);

            const double left_obj = loss(lltL, bL, yssL) + this->scaled_lambda;
            const double right_obj = loss(lltR, bR, yssR) + this->scaled_lambda;

            if (left_obj + right_obj < min_obj)
            {
                split_flag = true;
                min_obj = left_obj + right_obj;
                best_feature = feature;
                best_feature_index = -1; // flag: binary split
                best_is_binary = true;
                best_llt_left = lltL;
                best_llt_right = lltR;
                best_b_left = bL;
                best_b_right = bR;
                best_left_obj = left_obj;
                best_right_obj = right_obj;
                best_indices_left.assign(left_rows.begin(), left_rows.end());
                best_y_sum_sq_left = yssL;
                best_y_sum_sq_right = yssR;
            }
            continue; // done with binary feature
        }

        // --- 2) Skip non-continuous (categorical not one-hot) ---------------
        if (!is_cont)
            continue;

        // Left child starts empty (penalty only); right child starts as the whole node.
        MatrixXd gram_left = this->scaled_kappa * MatrixXd::Identity(p_reg_, p_reg_);
        gram_left(0, 0) = 1e-12;
        VectorXd b_left = VectorXd::Zero(p_reg_);
        LLT<MatrixXd> llt_left(gram_left);
        double y_sum_sq_left = 0; // sum of squares of y values in left child
        vector<int> left_indices;

        VectorXd b_right = b;             // copy parent b
        LLT<MatrixXd> llt_right = llt;    // use parent llt
        double y_sum_sq_right = y_sum_sq; // copy parent y sum squared
        for (unsigned long int feature_idx = 0; feature_idx < order.size(); feature_idx++)
        {
            // Move one row from the right child to the left child.
            const int row = order[feature_idx];
            const RowVectorXd x_row = reg_row(row);
            const double y_row = this->y(row);

            b_left += x_row.transpose() * y_row;
            llt_left.rankUpdate(x_row, 1); // updates the factor in place
            y_sum_sq_left += y_row * y_row;
            left_indices.push_back(row);

            b_right -= x_row.transpose() * y_row;
            llt_right.rankUpdate(x_row, -1); // in-place downdate
            y_sum_sq_right -= y_row * y_row;

            samples_past_quota++;
            if (samples_past_quota < this->stride)
            {
                continue; // skip if we haven't seen enough samples since last split
            }
            if (feature_idx == order.size() - 1)
            {
                continue; // every row is on the left: not a split
            }
            if (this->X(row, feature) == this->X(order[feature_idx + 1], feature))
            {
                continue; // tied values cannot be separated by a threshold
            }
            samples_past_quota -= this->stride;

            const double left_obj = loss(llt_left, b_left, y_sum_sq_left) + this->scaled_lambda;
            const double right_obj = loss(llt_right, b_right, y_sum_sq_right) + this->scaled_lambda;
            if (left_obj + right_obj < min_obj)
            {
                split_flag = true;
                min_obj = left_obj + right_obj;
                best_feature = feature;
                best_feature_index = feature_idx;
                // A continuous split now wins; a previously recorded binary split must
                // not leave its 0.5 threshold behind.
                best_is_binary = false;
                best_llt_left = llt_left;
                best_llt_right = llt_right;
                best_b_left = b_left;
                best_b_right = b_right;
                best_left_obj = left_obj;
                best_right_obj = right_obj;
                best_indices_left = left_indices;
                best_y_sum_sq_left = y_sum_sq_left;
                best_y_sum_sq_right = y_sum_sq_right;
            }
        }
    }

    if (!split_flag)
    {
        node->is_leaf = true; // no valid split found, this is a leaf
        return node->obj;     // return the loss at this node
    }

    // Threshold of the chosen split: 0.5 for a 0/1 feature, otherwise the midpoint
    // between the last value sent left and the first value sent right.
    auto split_threshold = [&]() -> double {
        if (best_is_binary)
            return 0.5;
        const vector<unsigned long int> &best_order = sorted_indices[best_feature];
        return (this->X(best_order[best_feature_index], best_feature) +
                this->X(best_order[best_feature_index + 1], best_feature)) /
               2.0;
    };

    // With one level left the children are leaves, so no sorted indices are needed.
    if (depth_remaining == 1)
    {
        node->is_leaf = false;
        node->feature_idx = best_feature;
        node->threshold = split_threshold();

        node->left = new Node();
        node->right = new Node();
        node->left->obj = best_left_obj;
        node->right->obj = best_right_obj;

        node->obj = best_left_obj + best_right_obj;
        return node->obj;
    }

    // if valid split found, create children nodes to see if it's worth continuing to split
    node->left = new Node();
    node->left->obj = best_left_obj;
    node->right = new Node();
    node->right->obj = best_right_obj;

    // update sorted indices for children (order within each feature is preserved)
    set<unsigned long int> left_indices_set(best_indices_left.begin(), best_indices_left.end());
    vector<vector<unsigned long int>> sorted_indices_left(this->m, vector<unsigned long int>());
    vector<vector<unsigned long int>> sorted_indices_right(this->m, vector<unsigned long int>());
    for (unsigned long int feature = 1; feature < this->m; feature++)
    {
        for (unsigned long int idx : sorted_indices[feature])
        {
            if (left_indices_set.find(idx) != left_indices_set.end())
            {
                sorted_indices_left[feature].push_back(idx);
            }
            else
            {
                sorted_indices_right[feature].push_back(idx);
            }
        }
    }

    // now replace losses based on greedy completions
    double final_left_objective = Greedy::recursive_fit(sorted_indices_left, best_llt_left, best_b_left, best_y_sum_sq_left, node->left, depth_remaining - 1);
    double final_right_objective = Greedy::recursive_fit(sorted_indices_right, best_llt_right, best_b_right, best_y_sum_sq_right, node->right, depth_remaining - 1);

    if (final_left_objective + final_right_objective < node->obj)
    {
        node->is_leaf = false;
        node->feature_idx = best_feature;
        node->threshold = split_threshold();
        node->obj = final_left_objective + final_right_objective;
        return node->obj;
    }

    // Split not worth it: discard the tentative children. The pointers must be nulled
    // because ~Node deletes left/right unconditionally (otherwise: double free).
    delete node->left;
    delete node->right;
    node->left = nullptr;
    node->right = nullptr;
    node->is_leaf = true;
    return node->obj;
}
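/*
Ridge loss computed from a Cholesky decomposition.
Parameters:
- llt, the Cholesky (LLT) decomposition of A = X^T X + kappa I, A = L L^T
  (the intercept's diagonal entry is penalised by 1e-12 instead of kappa;
  the commented-out variant below takes the lower-triangular factor L directly)
- b, equal to X^T y
- y_sum_sq, sum of squares of y values

Returns \|y\|^2 - b^T A^{-1} b = \|y\|^2 - \|L^{-1} b\|^2
*/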
// double Greedy::loss(MatrixXd L, VectorXd b, double y_sum_sq)
// {
//     VectorXd z = L.triangularView<Lower>().solve(b);
//     // Now z = L^{-1} b, so ||z||^2 = b^T A^{-1} b
//     return y_sum_sq - z.squaredNorm();
// }

double Greedy::loss(const LLT<MatrixXd>& llt, const VectorXd& b, double y_sum_sq) {
    // z = A^{-1} b via Cholesky solve (no matrixL() materialization)
    VectorXd z = llt.solve(b);       
    return y_sum_sq - b.dot(z);      
}



VectorXd Greedy::predict(MatrixXd X)
{
    /*
    Predict every row of X with the fitted tree (see predict_row).
    Returns one prediction per row, in row order.
    */
    const Eigen::Index num_rows = X.rows();
    VectorXd out(num_rows);
    for (Eigen::Index r = 0; r < num_rows; ++r)
        out(r) = predict_row(X.row(r));
    return out;
}

double Greedy::predict_row(VectorXd x)
{
    // Walk from the root to a leaf, going left whenever the split feature is <= the
    // node's threshold, then evaluate that leaf's ridge model on [1, x_cont].
    for (const Node *node = this->root; node != nullptr;
         node = (x(node->feature_idx) <= node->threshold) ? node->left : node->right)
    {
        if (!node->is_leaf)
            continue;

        VectorXd features(p_reg_);
        features(0) = 1.0; // intercept
        for (int k = 0; k < static_cast<int>(continuous_idx_.size()); ++k)
            features(k + 1) = x(continuous_idx_[k]);
        return node->coefficients.dot(features);
    }
    // Fell off the tree (null root or missing child).
    throw runtime_error("Invalid tree structure");
}

string Greedy::print_tree()
{
    // Text rendering of the fitted tree, or a placeholder when nothing has been fit.
    return this->root ? this->root->print_tree() : "No tree currently fit!";
}

std::size_t Greedy::count_leaves(const Node* n) {
    // An empty subtree has no leaves; a leaf counts once; internal nodes sum their children.
    if (n == nullptr)
        return 0;
    return n->is_leaf ? std::size_t{1} : count_leaves(n->left) + count_leaves(n->right);
}

std::size_t Greedy::n_leaves() const {
    return count_leaves(this->root);
}

bool Greedy::is_binary_column(const Eigen::VectorXd& col) {
    // True iff every entry is exactly 0.0 or 1.0 (vacuously true for an empty column;
    // NaN entries make it false).
    return ((col.array() == 0.0) || (col.array() == 1.0)).all();
}

void Greedy::detect_feature_types() {
    const int P = this->X.cols();     // includes intercept at 0

    // Mark which columns are categorical (we expect categorical_idx_ is already shifted in fit)
    std::vector<bool> is_categorical(P, false);
    for (int j : categorical_idx_) {
        if (j > 0 && j < P)            // ignore intercept (0) and out-of-range
            is_categorical[j] = true;
    }

    // scan features (skip intercept at 0)
    binary_idx_.clear();
    continuous_idx_.clear();
    for (int j = 1; j < P; ++j) {
        Eigen::VectorXd col = this->X.col(j);
        if (is_binary_column(col)) {
            binary_idx_.push_back(j);
        } else if (!is_categorical[j]) {
            continuous_idx_.push_back(j);
        } // else: multi-class categorical (not one-hot), will be skipped for splitting
    }

    // build X_reg_ = [1 | continuous columns]
    p_reg_ = 1 + (int)continuous_idx_.size();
    X_reg_.resize(this->n, p_reg_);
    X_reg_.col(0) = Eigen::VectorXd::Ones(this->n);
    for (int k = 0; k < (int)continuous_idx_.size(); ++k) {
        X_reg_.col(1 + k) = this->X.col(continuous_idx_[k]);
    }
}

Eigen::RowVectorXd Greedy::reg_row(int i) const {
    // Regression features of training row i: intercept followed by the continuous columns.
    Eigen::RowVectorXd features(p_reg_);
    features(0) = 1.0;
    int out = 1;
    for (const auto col : continuous_idx_)
        features(out++) = X(i, col);
    return features;
}


auto Greedy::recompute_stats_from_rows(const std::vector<int>& rows)
    -> std::tuple<Eigen::LLT<Eigen::MatrixXd, Eigen::Lower>, Eigen::VectorXd, double>
{
    // Rebuild, from scratch, the sufficient statistics of a ridge fit on `rows`:
    // the Cholesky factor of (penalty + X_reg^T X_reg), X_reg^T y, and sum of y^2.
    // The penalty is scaled_kappa on the diagonal, except 1e-12 for the intercept.
    MatrixXd penalized_gram = this->scaled_kappa * MatrixXd::Identity(p_reg_, p_reg_);
    penalized_gram(0, 0) = 1e-12;
    VectorXd xty = VectorXd::Zero(p_reg_);
    double y_sq_total = 0.0;
    for (const int r : rows) {
        const Eigen::RowVectorXd features = reg_row(r);
        const double target = this->y(r);
        xty.noalias() += features.transpose() * target;
        penalized_gram.noalias() += features.transpose() * features;
        y_sq_total += target * target;
    }
    return {LLT<MatrixXd>(penalized_gram), xty, y_sq_total};
}

// Constructor: CLARITree adds no state of its own; all hyper-parameters are
// forwarded unchanged to Greedy.
CLARITree::CLARITree(double kappa, int depth, double lambda, int stride, bool verbose)
    : Greedy(kappa, depth, lambda, stride, verbose)
{
}

double CLARITree::recursive_fit(vector<vector<unsigned long int>> sorted_indices, LLT<MatrixXd> llt, VectorXd b, double y_sum_sq, Node *node, int depth_remaining)
{
    /*
    For every possible single step split, compute the loss using greedy,
    then pick the split locally the best with that heuristic.
    Then, replace those greedy calls with another CLARITree call
    */
    if (depth_remaining == 0 || node->obj <= 2 * this->scaled_lambda)
    {
        node->is_leaf = true;
        return node->obj; // return the objective at this node
    }

    // find best split
    bool split_flag = false;
    bool best_is_binary = false; 
    unsigned long int best_feature = 0;
    unsigned long int best_feature_index;
    // samples_past_quota shared across features; initialize to account for skipping feature 0 (intercept)
    // In old code, feature 0 would increment it n times but never split, so we start with n
    // Use size of any feature (they all have same size = number of rows in current node)
    int samples_past_quota = sorted_indices.size() > 1 ? static_cast<int>(sorted_indices[1].size()) : 0;
    double min_obj = node->obj; // start with parent loss
    LLT<MatrixXd> best_llt_right;
    LLT<MatrixXd> best_llt_left;
    VectorXd best_b_left;
    VectorXd best_b_right;
    vector<unsigned long int> best_indices_left;
    double best_left_leaf_obj, best_right_leaf_obj;
    double best_y_sum_sq_left, best_y_sum_sq_right;

    for (unsigned long int feature = 1; feature < this->m; feature++)
    {   
        bool is_bin = std::find(binary_idx_.begin(), binary_idx_.end(), (int)feature) != binary_idx_.end();
        bool is_cont = std::find(continuous_idx_.begin(), continuous_idx_.end(), (int)feature) != continuous_idx_.end();

        // --- 1) Handle binary feature (0/1) ---------------------------------
        if (is_bin)
        {
            std::vector<int> left_rows, right_rows;
            left_rows.reserve(this->n);
            right_rows.reserve(this->n);

            // Split by 0/1
            for (int i = 0; i < this->n; ++i)
            {
                if (this->X(i, feature) <= 0.5)
                    left_rows.push_back(i);
                else
                    right_rows.push_back(i);
            }
            if (left_rows.empty() || right_rows.empty())
                continue; // not splittable

            auto [lltL, bL, yssL] = recompute_stats_from_rows(left_rows);
            auto [lltR, bR, yssR] = recompute_stats_from_rows(right_rows);

            double left_obj = loss(lltL, bL, yssL) + this->scaled_lambda;
            double right_obj = loss(lltR, bR, yssR) + this->scaled_lambda;

            if (left_obj + right_obj < min_obj)
            {
                split_flag = true;
                min_obj = left_obj + right_obj;
                best_is_binary = true; 
                best_feature = feature;
                best_feature_index = -1; // flag: binary split
                best_llt_left = lltL;
                best_llt_right = lltR;
                best_b_left = bL;
                best_b_right = bR;
                best_left_leaf_obj = left_obj;         
                best_right_leaf_obj = right_obj;
                best_indices_left.assign(left_rows.begin(), left_rows.end());
                best_y_sum_sq_left = yssL;
                best_y_sum_sq_right = yssR;
            }
            continue; // done with binary feature
        }

        // --- 2) Skip non-continuous (categorical not one-hot) ---------------
        if (!is_cont)
            continue;

        MatrixXd gram_left = this->scaled_kappa * MatrixXd::Identity(p_reg_, p_reg_);
        gram_left(0, 0) = 1e-12;
        VectorXd b_left = VectorXd::Zero(p_reg_);
        LLT<MatrixXd> llt_left(gram_left);
        double y_sum_sq_left = 0; // sum of squares of y values in left child
        vector<unsigned long int> left_indices = {};
        VectorXd b_right = b;             // copy parent b
        LLT<MatrixXd> llt_right = llt;    // use parent llt
        double y_sum_sq_right = y_sum_sq; // copy parent y sum squared
        for (unsigned long int feature_idx = 0; feature_idx < sorted_indices[feature].size(); feature_idx++)
        {
            int row = sorted_indices[feature][feature_idx]; // get the row index for this feature & threshold
            b_left += reg_row(row).transpose() * this->y(row);
            llt_left = llt_left.rankUpdate(reg_row(row), 1);
            y_sum_sq_left += this->y(row) * this->y(row);
            left_indices.push_back(row);

            b_right -= reg_row(row).transpose() * this->y(row);
            llt_right = llt_right.rankUpdate(reg_row(row), -1);
            y_sum_sq_right -= this->y(row) * this->y(row);
            samples_past_quota++;
            if (samples_past_quota < this->stride)
            {
                continue; // skip if we haven't seen enough samples since last split
            }
            if (feature_idx == sorted_indices[feature].size() - 1)
            {
                // if this is the last feature index, we can't split further
                continue;
            }
            if (this->X(row, feature) == this->X(sorted_indices[feature][feature_idx + 1], feature))
            {
                // skip if this is not a valid split point
                continue;
            }

            samples_past_quota -= this->stride;
            // loss estimate based on greedy completion
            // double left_obj = loss(llt_left.matrixL(), b_left, y_sum_sq_left) + this->scaled_lambda;
            // double right_obj = loss(llt_right.matrixL(), b_right, y_sum_sq_right) + this->scaled_lambda;
            double left_obj = loss(llt_left, b_left, y_sum_sq_left) + this->scaled_lambda;
            double right_obj = loss(llt_right, b_right, y_sum_sq_right) + this->scaled_lambda;

            if (depth_remaining == 1)
            {
                // if remaining depth=1 one can find the solution directly
                if (left_obj + right_obj < min_obj)
                {
                    split_flag = true;
                    min_obj = left_obj + right_obj;
                    best_feature = feature;
                    best_feature_index = feature_idx;
                    best_llt_left = llt_left;
                    best_llt_right = llt_right;
                    best_b_left = b_left;
                    best_b_right = b_right;
                    best_left_leaf_obj = left_obj;
                    best_right_leaf_obj = right_obj;
                    best_indices_left = left_indices;
                    best_y_sum_sq_left = y_sum_sq_left;
                    best_y_sum_sq_right = y_sum_sq_right;
                }
            }
            else
            {
                // compute artifacts needed for greedy call.
                set<unsigned long int> left_indices_set(left_indices.begin(), left_indices.end());

                delete node->left;  // delete left child if it was created
                delete node->right; // delete right child if it was created
                node->left = new Node();
                node->left->obj = left_obj; // set left obj to minimum obj found
                node->right = new Node();
                node->right->obj = right_obj; // set right obj to minimum obj found

                // update sorted indices for children
                vector<vector<unsigned long int>> sorted_indices_left(this->m, vector<unsigned long int>());
                vector<vector<unsigned long int>> sorted_indices_right(this->m, vector<unsigned long int>());
                for (unsigned long int feature = 1; feature < this->m; feature++)
                {
                    // for each feature, emplace back for left or right child based on which it's in
                    for (unsigned long int idx : sorted_indices[feature])
                    {
                        if (left_indices_set.find(idx) != left_indices_set.end())
                        {
                            sorted_indices_left[feature].push_back(idx);
                        }
                        else
                        {
                            sorted_indices_right[feature].push_back(idx);
                        }
                    }
                }

                double greedy_completion_left_obj = Greedy::recursive_fit(sorted_indices_left, llt_left, b_left, y_sum_sq_left, node->left, depth_remaining - 1);
                double greedy_completion_right_obj = Greedy::recursive_fit(sorted_indices_right, llt_right, b_right, y_sum_sq_right, node->right, depth_remaining - 1);
                // lookahead tree in other papers, will this be much worse than us?
                // double greedy_completion_left_obj = Greedy::recursive_fit(sorted_indices_left, llt_left, b_left, y_sum_sq_left, node->left, 1);
                // double greedy_completion_right_obj = Greedy::recursive_fit(sorted_indices_right, llt_right, b_right, y_sum_sq_right, node->right, 1);

                if (greedy_completion_left_obj + greedy_completion_right_obj < min_obj)
                {
                    split_flag = true;
                    min_obj = greedy_completion_left_obj + greedy_completion_right_obj;
                    best_feature = feature;
                    best_feature_index = feature_idx;
                    best_llt_left = llt_left;
                    best_llt_right = llt_right;
                    best_b_left = b_left;
                    best_b_right = b_right;
                    best_left_leaf_obj = left_obj;
                    best_right_leaf_obj = right_obj;
                    best_indices_left = left_indices;
                    best_y_sum_sq_left = y_sum_sq_left;
                    best_y_sum_sq_right = y_sum_sq_right;
                }
            }
        }
    }

    if (!split_flag)
    {
        node->is_leaf = true; // no valid split found, this is a leaf
        return node->obj;     // return the loss at this node
    }

    if (depth_remaining == 1)
    {
        node->is_leaf = false;
        node->feature_idx = best_feature;
        if (best_is_binary) {
            node->threshold = 0.5;               // <-- binary threshold
        } else {
            const auto &order = sorted_indices[best_feature];
            node->threshold =
                (this->X(order[best_feature_index], best_feature) +
                 this->X(order[best_feature_index + 1], best_feature)) / 2.0;
        }

        node->left = new Node();
        node->right = new Node();
        node->left->obj = best_left_leaf_obj;
        node->right->obj = best_right_leaf_obj;

        node->obj = node->left->obj + node->right->obj;
        return node->obj;
    }

    // if valid split found, create children nodes to see if it's worth continuing to split
    delete node->left;  // delete left child if it was created
    delete node->right; // delete right child if it was created
    node->left = new Node();
    node->left->obj = best_left_leaf_obj; // set left obj to minimum leaf obj found
    node->right = new Node();
    node->right->obj = best_right_leaf_obj; // set right obj to minimum leaf obj found

    // update sorted indices for children
    set<unsigned long int> left_indices_set(best_indices_left.begin(), best_indices_left.end());
    vector<vector<unsigned long int>> sorted_indices_left(this->m, vector<unsigned long int>());
    vector<vector<unsigned long int>> sorted_indices_right(this->m, vector<unsigned long int>());
    for (unsigned long int feature = 1; feature < this->m; feature++)
    {
        // for each feature, emplace back for left or right child based on which it's in
        for (unsigned long int idx : sorted_indices[feature])
        {
            if (left_indices_set.find(idx) != left_indices_set.end())
            {
                sorted_indices_left[feature].push_back(idx);
            }
            else
            {
                sorted_indices_right[feature].push_back(idx);
            }
        }
    }

    // now replace losses based on recursive completions (using CLARITree's approach, not the full greedy Greedy)
    double final_left_obj = recursive_fit(sorted_indices_left, best_llt_left, best_b_left, best_y_sum_sq_left, node->left, depth_remaining - 1);
    double final_right_obj = recursive_fit(sorted_indices_right, best_llt_right, best_b_right, best_y_sum_sq_right, node->right, depth_remaining - 1);

    if (final_left_obj + final_right_obj < node->obj)
    {
        node->is_leaf = false; // this node is not a leaf, we found a valid split
        node->feature_idx = best_feature;
        if (best_is_binary) {
            node->threshold = 0.5;               // <-- binary threshold here too
        } else {
            node->threshold =
                (this->X(sorted_indices[best_feature][best_feature_index], best_feature) +
                 this->X(sorted_indices[best_feature][best_feature_index + 1], best_feature)) / 2.0;
        }
        // node->threshold = (this->X(sorted_indices[best_feature][best_feature_index], best_feature) + this->X(sorted_indices[best_feature][best_feature_index + 1], best_feature)) / 2.0;
        node->obj = final_left_obj + final_right_obj;
        return final_left_obj + final_right_obj; // return the obj at this node
    }
    else
    {
        delete node->left;    // delete left child if it was created
        delete node->right;   // delete right child if it was created
        node->is_leaf = true; // no valid split found, this is a leaf
        return node->obj;     // return the obj at this node
    }
}

// add the full update version, which means don't contain the cholesky update.
// GreedyFull
double GreedyFull::recursive_fit(vector<vector<unsigned long int>> sorted_indices, LLT<MatrixXd> llt, VectorXd b, double y_sum_sq, Node *node, int depth_remaining)
{
    /*
    Greedy tree fit from `node` using full (non-incremental) statistics.

    At each node the split with the lowest sum of child leaf objectives is
    chosen among binary features (threshold 0.5) and continuous features
    (midpoints between consecutive distinct sorted values, thinned by
    `stride`). Ridge statistics are recomputed from scratch for every
    candidate, so the `llt`, `b` and `y_sum_sq` arguments are not read; they
    are kept only for interface compatibility with the other variants.

    Stops when depth_remaining is 0, when node->obj <= 2 * scaled_lambda, or
    when no candidate improves on node->obj. Assumes node->obj already holds
    this node's objective as a leaf. Returns the objective of the subtree
    rooted at node; node->left / node->right are owned by node and are
    always either valid or nullptr when this function returns.
    */
    if (depth_remaining == 0 || node->obj <= 2 * this->scaled_lambda)
    {
        node->is_leaf = true;
        return node->obj;
    }

    // Frees both children and resets the pointers so that no later delete
    // (including Node::~Node) can touch freed memory.
    auto release_children = [](Node *n)
    {
        delete n->left;
        delete n->right;
        n->left = nullptr;
        n->right = nullptr;
    };

    // Full recompute of the ridge statistics for a set of rows (no rank-one updates).
    auto recompute_stats = [&](const vector<int> &rows)
    {
        const int p = p_reg_;
        MatrixXd gram = this->scaled_kappa * MatrixXd::Identity(p, p);
        gram(0, 0) = 1e-12; // tiny jitter for numerical stability (bias anchor)
        VectorXd bb = VectorXd::Zero(p);
        double yss = 0.0;
        for (int r : rows)
        {
            auto xr = reg_row(r);
            bb.noalias() += xr.transpose() * this->y(r);
            gram.noalias() += xr.transpose() * xr;
            yss += this->y(r) * this->y(r);
        }
        LLT<MatrixXd> lltmp(gram);
        return tuple<LLT<MatrixXd>, VectorXd, double>(lltmp, bb, yss);
    };

    // Best split found so far (min_obj starts at the parent's leaf objective).
    bool split_flag = false;
    bool best_is_binary = false;
    unsigned long int best_feature = 0;
    unsigned long int best_feature_index = 0;
    double min_obj = node->obj;
    LLT<MatrixXd> best_llt_left;
    LLT<MatrixXd> best_llt_right;
    VectorXd best_b_left;
    VectorXd best_b_right;
    vector<int> best_indices_left;
    double best_left_obj = 0.0, best_right_obj = 0.0;
    double best_y_sum_sq_left = 0.0, best_y_sum_sq_right = 0.0;

    // Scores one candidate split by its two leaf objectives and records it if
    // it beats the best so far. best_is_binary is set for every winner, so a
    // continuous winner never inherits the 0.5 threshold of an earlier binary one.
    auto consider_split = [&](unsigned long int feature, unsigned long int feature_idx, bool is_binary,
                              const vector<int> &left_rows, const vector<int> &right_rows)
    {
        auto [lltL, bL, yssL] = recompute_stats(left_rows);
        auto [lltR, bR, yssR] = recompute_stats(right_rows);
        const double left_obj = loss(lltL, bL, yssL) + this->scaled_lambda;
        const double right_obj = loss(lltR, bR, yssR) + this->scaled_lambda;
        if (left_obj + right_obj < min_obj)
        {
            split_flag = true;
            min_obj = left_obj + right_obj;
            best_feature = feature;
            best_feature_index = feature_idx;
            best_is_binary = is_binary;
            best_llt_left = lltL;
            best_llt_right = lltR;
            best_b_left = bL;
            best_b_right = bR;
            best_left_obj = left_obj;
            best_right_obj = right_obj;
            best_indices_left = left_rows;
            best_y_sum_sq_left = yssL;
            best_y_sum_sq_right = yssR;
        }
    };

    // samples_past_quota is shared across features; it starts at the node size
    // so that the first continuous feature can split as soon as allowed.
    int samples_past_quota = sorted_indices.size() > 1 ? static_cast<int>(sorted_indices[1].size()) : 0;

    for (unsigned long int feature = 1; feature < this->m; feature++)
    {
        const bool is_bin = std::find(binary_idx_.begin(), binary_idx_.end(), (int)feature) != binary_idx_.end();
        const bool is_cont = std::find(continuous_idx_.begin(), continuous_idx_.end(), (int)feature) != continuous_idx_.end();
        const vector<unsigned long int> &order = sorted_indices[feature];

        // Binary (0/1) feature: a single candidate split at 0.5.
        if (is_bin)
        {
            vector<int> left_rows, right_rows;
            for (unsigned long int idx : order) // only this node's samples
            {
                const int i = (int)idx;
                if (this->X(i, feature) <= 0.5)
                    left_rows.push_back(i);
                else
                    right_rows.push_back(i);
            }
            if (!left_rows.empty() && !right_rows.empty())
                consider_split(feature, 0, true, left_rows, right_rows);
            continue;
        }

        // Non-continuous, non-binary (e.g. categorical not one-hot): skip.
        if (!is_cont)
            continue;

        // Continuous feature: every `stride`-th boundary between distinct values.
        const unsigned long int n_rows = order.size();
        vector<int> left_rows;
        left_rows.reserve(n_rows);
        for (unsigned long int feature_idx = 0; feature_idx < n_rows; feature_idx++)
        {
            const int row = (int)order[feature_idx];
            left_rows.push_back(row);

            samples_past_quota++;
            if (samples_past_quota < this->stride)
                continue; // not enough samples since the last evaluated split
            if (feature_idx + 1 == n_rows)
                continue; // nothing would go right
            if (this->X(row, feature) == this->X(order[feature_idx + 1], feature))
                continue; // tied values: not a valid split point
            samples_past_quota -= this->stride;

            vector<int> right_rows;
            right_rows.reserve(n_rows - feature_idx - 1);
            for (unsigned long int j = feature_idx + 1; j < n_rows; ++j)
                right_rows.push_back((int)order[j]);

            consider_split(feature, feature_idx, false, left_rows, right_rows);
        }
    }

    if (!split_flag)
    {
        node->is_leaf = true; // no valid split found, this is a leaf
        return node->obj;
    }

    // Threshold of the chosen split: 0.5 for binary features, otherwise the
    // midpoint between the last left value and the first right value.
    double threshold = 0.5;
    if (!best_is_binary)
    {
        const vector<unsigned long int> &order = sorted_indices[best_feature];
        threshold = (this->X(order[best_feature_index], best_feature) +
                     this->X(order[best_feature_index + 1], best_feature)) / 2.0;
    }

    // Replace any existing children (avoids leaking them) with fresh leaves.
    release_children(node);
    node->left = new Node();
    node->left->obj = best_left_obj;
    node->right = new Node();
    node->right->obj = best_right_obj;

    // With one level left the children stay leaves; no need to sort indices.
    if (depth_remaining == 1)
    {
        node->is_leaf = false;
        node->feature_idx = best_feature;
        node->threshold = threshold;
        node->obj = best_left_obj + best_right_obj;
        return node->obj;
    }

    // Split every per-feature sorted order between the two children.
    const set<unsigned long int> left_indices_set(best_indices_left.begin(), best_indices_left.end());
    vector<vector<unsigned long int>> sorted_indices_left(this->m, vector<unsigned long int>());
    vector<vector<unsigned long int>> sorted_indices_right(this->m, vector<unsigned long int>());
    for (unsigned long int feature = 1; feature < this->m; feature++)
    {
        for (unsigned long int idx : sorted_indices[feature])
        {
            if (left_indices_set.count(idx))
                sorted_indices_left[feature].push_back(idx);
            else
                sorted_indices_right[feature].push_back(idx);
        }
    }

    // Replace the children's leaf objectives with their greedy completions.
    const double final_left_objective = GreedyFull::recursive_fit(sorted_indices_left, best_llt_left, best_b_left, best_y_sum_sq_left, node->left, depth_remaining - 1);
    const double final_right_objective = GreedyFull::recursive_fit(sorted_indices_right, best_llt_right, best_b_right, best_y_sum_sq_right, node->right, depth_remaining - 1);

    if (final_left_objective + final_right_objective < node->obj)
    {
        node->is_leaf = false;
        node->feature_idx = best_feature;
        node->threshold = threshold;
        node->obj = final_left_objective + final_right_objective;
        return node->obj;
    }

    // Splitting does not pay off: drop the children and stay a leaf.
    release_children(node);
    node->is_leaf = true;
    return node->obj;
}



double CLARITreeFull::recursive_fit(vector<vector<unsigned long int>> sorted_indices, LLT<MatrixXd> llt, VectorXd b, double y_sum_sq, Node *node, int depth_remaining)
{
    /*
    CLARITree search using full (non-incremental) statistics.

    Every candidate single-step split is scored as follows:
      - Binary features are split at 0.5.
      - Continuous features are split at midpoints between consecutive
        distinct sorted values, thinned by `stride`.
      - With depth_remaining == 1, the score is the sum of the two child leaf
        objectives.
      - Otherwise, the score is the sum of GreedyFull completions of depth
        depth_remaining - 1 on each side. Binary and continuous candidates are
        scored the same way, so they compete on equal terms.

    The locally best split is kept, and its children are then refit
    recursively with CLARITreeFull itself.

    Statistics are recomputed from scratch for every candidate, so the `llt`,
    `b` and `y_sum_sq` arguments are not read. They are kept only for
    interface compatibility.

    Assumes node->obj already holds this node's objective as a leaf. Returns
    the objective of the subtree rooted at node. node->left and node->right
    are owned by node and are always either valid or nullptr on return.
    */
    if (depth_remaining == 0 || node->obj <= 2 * this->scaled_lambda)
    {
        node->is_leaf = true;
        return node->obj;
    }

    // Frees both children and resets the pointers so that no later delete
    // (including Node::~Node) can touch freed memory.
    auto release_children = [](Node *n)
    {
        delete n->left;
        delete n->right;
        n->left = nullptr;
        n->right = nullptr;
    };

    // Full recompute of the ridge statistics for a set of rows (no rank-one updates).
    auto recompute_stats = [&](const vector<unsigned long int> &rows)
    {
        const int p = p_reg_;
        MatrixXd gram = this->scaled_kappa * MatrixXd::Identity(p, p);
        gram(0, 0) = 1e-12; // tiny jitter for numerical stability (bias anchor)
        VectorXd bb = VectorXd::Zero(p);
        double yss = 0.0;
        for (unsigned long int r : rows)
        {
            auto xr = reg_row((int)r);
            bb.noalias() += xr.transpose() * this->y((int)r);
            gram.noalias() += xr.transpose() * xr;
            yss += this->y((int)r) * this->y((int)r);
        }
        LLT<MatrixXd> lltmp(gram);
        return tuple<LLT<MatrixXd>, VectorXd, double>(lltmp, bb, yss);
    };

    // Splits every per-feature sorted order into rows in `left_rows` and the
    // rest, preserving sorted order. Feature 0 (intercept) is left empty.
    auto partition_orders = [&](const vector<unsigned long int> &left_rows,
                                vector<vector<unsigned long int>> &orders_left,
                                vector<vector<unsigned long int>> &orders_right)
    {
        const set<unsigned long int> left_set(left_rows.begin(), left_rows.end());
        orders_left.assign(this->m, vector<unsigned long int>());
        orders_right.assign(this->m, vector<unsigned long int>());
        for (unsigned long int f = 1; f < this->m; f++)
        {
            for (unsigned long int idx : sorted_indices[f])
            {
                if (left_set.count(idx))
                    orders_left[f].push_back(idx);
                else
                    orders_right[f].push_back(idx);
            }
        }
    };

    // One GreedyFull helper is reused for every completion. GreedyFull::recursive_fit,
    // as written in this file, only reads its members. The data is copied once
    // per call, and only when completions are actually needed, instead of twice
    // per candidate split.
    GreedyFull completer(this->kappa, this->depth, this->lambda, this->stride, this->verbose);
    completer.root = nullptr;
    if (depth_remaining > 1)
    {
        completer.X = this->X;
        completer.y = this->y;
        completer.n = this->n;
        completer.m = this->m;
        completer.scaled_kappa = this->scaled_kappa;
        completer.scaled_lambda = this->scaled_lambda;
        completer.p_reg_ = this->p_reg_;
        completer.X_reg_ = this->X_reg_;
        completer.categorical_idx_ = this->categorical_idx_;
        completer.continuous_idx_ = this->continuous_idx_;
        completer.binary_idx_ = this->binary_idx_;
    }

    // Best split found so far (min_obj starts at the parent's leaf objective).
    bool split_flag = false;
    bool best_is_binary = false;
    unsigned long int best_feature = 0;
    unsigned long int best_feature_index = 0;
    double min_obj = node->obj;
    LLT<MatrixXd> best_llt_left;
    LLT<MatrixXd> best_llt_right;
    VectorXd best_b_left;
    VectorXd best_b_right;
    vector<unsigned long int> best_indices_left;
    double best_left_leaf_obj = 0.0, best_right_leaf_obj = 0.0;
    double best_y_sum_sq_left = 0.0, best_y_sum_sq_right = 0.0;

    // Scores one candidate split and records it if it beats the best so far.
    // best_is_binary is set for every winner, so a continuous winner never
    // inherits the 0.5 threshold of an earlier binary one.
    auto consider_split = [&](unsigned long int feature, unsigned long int feature_idx, bool is_binary,
                              const vector<unsigned long int> &left_rows,
                              const vector<unsigned long int> &right_rows)
    {
        auto [llt_left, b_left, y_sum_sq_left] = recompute_stats(left_rows);
        auto [llt_right, b_right, y_sum_sq_right] = recompute_stats(right_rows);
        const double left_obj = loss(llt_left, b_left, y_sum_sq_left) + this->scaled_lambda;
        const double right_obj = loss(llt_right, b_right, y_sum_sq_right) + this->scaled_lambda;

        double score = left_obj + right_obj;
        if (depth_remaining > 1)
        {
            // Score by greedy completions. The probe nodes live on the stack
            // and free whatever subtree GreedyFull builds when they go out of
            // scope, leaving `node` untouched during the search.
            vector<vector<unsigned long int>> orders_left, orders_right;
            partition_orders(left_rows, orders_left, orders_right);
            Node probe_left;
            probe_left.obj = left_obj;
            Node probe_right;
            probe_right.obj = right_obj;
            score = completer.recursive_fit(orders_left, llt_left, b_left, y_sum_sq_left, &probe_left, depth_remaining - 1) +
                    completer.recursive_fit(orders_right, llt_right, b_right, y_sum_sq_right, &probe_right, depth_remaining - 1);
        }

        if (score < min_obj)
        {
            split_flag = true;
            min_obj = score;
            best_feature = feature;
            best_feature_index = feature_idx;
            best_is_binary = is_binary;
            best_llt_left = llt_left;
            best_llt_right = llt_right;
            best_b_left = b_left;
            best_b_right = b_right;
            best_left_leaf_obj = left_obj;
            best_right_leaf_obj = right_obj;
            best_indices_left = left_rows;
            best_y_sum_sq_left = y_sum_sq_left;
            best_y_sum_sq_right = y_sum_sq_right;
        }
    };

    // samples_past_quota is shared across features; it starts at the node size
    // so that the first continuous feature can split as soon as allowed.
    int samples_past_quota = sorted_indices.size() > 1 ? static_cast<int>(sorted_indices[1].size()) : 0;

    for (unsigned long int feature = 1; feature < this->m; feature++)
    {
        const bool is_bin = std::find(binary_idx_.begin(), binary_idx_.end(), (int)feature) != binary_idx_.end();
        const bool is_cont = std::find(continuous_idx_.begin(), continuous_idx_.end(), (int)feature) != continuous_idx_.end();
        const vector<unsigned long int> &order = sorted_indices[feature];

        // Binary (0/1) feature: a single candidate split at 0.5.
        if (is_bin)
        {
            vector<unsigned long int> left_rows, right_rows;
            for (unsigned long int idx : order) // only this node's samples
            {
                if (this->X((int)idx, feature) <= 0.5)
                    left_rows.push_back(idx);
                else
                    right_rows.push_back(idx);
            }
            if (!left_rows.empty() && !right_rows.empty())
                consider_split(feature, 0, true, left_rows, right_rows);
            continue;
        }

        // Non-continuous, non-binary (e.g. categorical not one-hot): skip.
        if (!is_cont)
            continue;

        // Continuous feature: every `stride`-th boundary between distinct values.
        const unsigned long int n_rows = order.size();
        vector<unsigned long int> left_rows;
        left_rows.reserve(n_rows);
        for (unsigned long int feature_idx = 0; feature_idx < n_rows; feature_idx++)
        {
            const unsigned long int row = order[feature_idx];
            left_rows.push_back(row);

            samples_past_quota++;
            if (samples_past_quota < this->stride)
                continue; // not enough samples since the last evaluated split
            if (feature_idx + 1 == n_rows)
                continue; // nothing would go right
            if (this->X(row, feature) == this->X(order[feature_idx + 1], feature))
                continue; // tied values: not a valid split point
            samples_past_quota -= this->stride;

            vector<unsigned long int> right_rows;
            right_rows.reserve(n_rows - feature_idx - 1);
            for (unsigned long int j = feature_idx + 1; j < n_rows; ++j)
                right_rows.push_back(order[j]);

            consider_split(feature, feature_idx, false, left_rows, right_rows);
        }
    }

    if (!split_flag)
    {
        node->is_leaf = true;
        return node->obj;
    }

    // Threshold of the chosen split: 0.5 for binary features, otherwise the
    // midpoint between the last left value and the first right value.
    double threshold = 0.5;
    if (!best_is_binary)
    {
        const vector<unsigned long int> &order = sorted_indices[best_feature];
        threshold = (this->X(order[best_feature_index], best_feature) +
                     this->X(order[best_feature_index + 1], best_feature)) / 2.0;
    }

    // Replace any existing children (avoids leaking them) with fresh leaves.
    release_children(node);
    node->left = new Node();
    node->left->obj = best_left_leaf_obj;
    node->right = new Node();
    node->right->obj = best_right_leaf_obj;

    if (depth_remaining == 1)
    {
        node->is_leaf = false;
        node->feature_idx = best_feature;
        node->threshold = threshold;
        node->obj = best_left_leaf_obj + best_right_leaf_obj;
        return node->obj;
    }

    vector<vector<unsigned long int>> sorted_indices_left, sorted_indices_right;
    partition_orders(best_indices_left, sorted_indices_left, sorted_indices_right);

    // Refit the chosen children with CLARITreeFull itself.
    const double final_left_obj = recursive_fit(sorted_indices_left, best_llt_left, best_b_left, best_y_sum_sq_left, node->left, depth_remaining - 1);
    const double final_right_obj = recursive_fit(sorted_indices_right, best_llt_right, best_b_right, best_y_sum_sq_right, node->right, depth_remaining - 1);

    if (final_left_obj + final_right_obj < node->obj)
    {
        node->is_leaf = false;
        node->feature_idx = best_feature;
        node->threshold = threshold;
        node->obj = final_left_obj + final_right_obj;
        return node->obj;
    }

    // Splitting does not pay off: drop the children and stay a leaf.
    release_children(node);
    node->is_leaf = true;
    return node->obj;
}

// ---- CSV I/O implementations ----

std::vector<double> parseLine(const std::string &line, char delimiter)
{
    /*
    Splits one delimited text line into numeric cells.

    Each cell is trimmed of spaces, tabs, CR and LF, then parsed with std::stod.
    A cell that is not a number, or whose value is out of range for double
    (e.g. "1e999"), produces a warning on stderr and is stored as 0.0 rather
    than aborting the whole read. An empty line yields an empty vector.
    */
    const char *const whitespace = " \t\r\n";
    std::vector<double> result;
    std::stringstream ss(line);
    std::string cell;

    while (getline(ss, cell, delimiter))
    {
        // Trim whitespace (an all-blank cell becomes empty)
        cell.erase(0, cell.find_first_not_of(whitespace));
        cell.erase(cell.find_last_not_of(whitespace) + 1);

        try
        {
            result.push_back(std::stod(cell));
        }
        catch (const std::invalid_argument &)
        {
            std::cerr << "Warning: Could not parse '" << cell << "' as number, using 0" << std::endl;
            result.push_back(0.0);
        }
        catch (const std::out_of_range &)
        {
            // std::stod also throws this on overflow/underflow; previously uncaught.
            std::cerr << "Warning: Value '" << cell << "' is out of range for double, using 0" << std::endl;
            result.push_back(0.0);
        }
    }
    return result;
}


bool readCSV(const std::string &filename,
             Eigen::MatrixXd &X,
             Eigen::VectorXd &y,
             bool has_header,
             char delimiter)
{
    std::ifstream file(filename);
    if (!file.is_open())
    {
        std::cerr << "Error: Could not open file " << filename << std::endl;
        return false;
    }

    std::vector<std::vector<double>> data;
    std::string line;
    bool first_row = true;

    while (getline(file, line))
    {
        if (has_header && first_row)
        {
            first_row = false;
            continue;
        }
        std::vector<double> row = parseLine(line, delimiter);
        if (!row.empty())
        {
            data.push_back(row);
        }
    }
    file.close();

    if (data.empty())
    {
        std::cerr << "Error: No data found in file" << std::endl;
        return false;
    }

    int num_rows = data.size();
    int num_cols = data[0].size();

    if (num_cols < 2)
    {
        std::cerr << "Error: Need at least 2 columns (features + target)" << std::endl;
        return false;
    }

    // Initialize matrices
    X.resize(num_rows, num_cols - 1);
    y.resize(num_rows);

    for (int i = 0; i < num_rows; i++)
    {
        if (data[i].size() != num_cols)
        {
            std::cerr << "Error: Inconsistent number of columns in row " << i << std::endl;
            return false;
        }
        for (int j = 0; j < num_cols - 1; j++)
        {
            X(i, j) = data[i][j];
        }
        y(i) = data[i][num_cols - 1];
    }

    // Add intercept column
    Eigen::MatrixXd X_with_intercept(X.rows(), X.cols() + 1);
    X_with_intercept.leftCols(1) = Eigen::VectorXd::Ones(X.rows());
    X_with_intercept.rightCols(X.cols()) = X;
    X = X_with_intercept;

    return true;
}
