#pragma once

#include <Eigen/Dense>
#include <string>
#include <vector>

/*
Minimal public interface for constant-leaf trees, covering both
greedy splits (GreedyConst) and CLARITree splits (CLARITreeConst).
The implementation lives in clari_tree_const.cpp.
*/

/// A single node of a constant-leaf tree.
///
/// All fields are public. A destructor, copy constructor and copy assignment
/// are declared here and defined in the implementation file.
/// NOTE(review): that set of special members suggests a node manages its
/// children's lifetime -- confirm against the implementation.
class ConstNode {
public:
    /// Left child, if this node is not a leaf.
    ConstNode* left;
    /// Right child, if this node is not a leaf.
    ConstNode* right;
    /// Flag marking this node as a leaf.
    bool is_leaf;
    /// Objective value at this node.
    double obj;
    /// Threshold used by the split.
    double threshold;
    /// Index of the feature used by the split.
    int feature_idx;
    /// Constant prediction at a leaf.
    double prediction;

    // Construction, copying and destruction (rule of three).
    ConstNode();
    ConstNode(const ConstNode& other);
    ConstNode& operator=(const ConstNode& other);
    ~ConstNode();

    /// Pretty-print the tree as a string.
    /// @param indentation indentation argument, defaults to 0
    ///        (presumably the nesting level of this node -- confirm in the .cpp).
    std::string print_tree(int indentation = 0);
};

/// Constant-leaf tree built with greedy splits.
///
/// Serves as the base class for CLARITreeConst, which overrides the
/// recursive builder. Method definitions are in the implementation file.
class GreedyConst {
public:
    // ---- Data (copied on fit) ----
    Eigen::MatrixXd X;
    Eigen::VectorXd y;

    /// Verbosity flag.
    bool verbose;
    /// Maximum depth; stored as a char.
    /// NOTE(review): plain char has implementation-defined signedness.
    char depth;
    /// Number of samples.
    unsigned long n;
    /// Number of features.
    unsigned long m;
    /// User penalty in [0,1].
    double lambda;
    /// Dataset-specific penalty per leaf.
    double scaled_lambda;
    /// Some data points can be skipped to speed up traversal or to align
    /// with quantiles.
    int stride;
    /// Pointer to the root node.
    ConstNode* root;

    // ---- Construction, copying and destruction (rule of three) ----
    explicit GreedyConst(char depth, double lambda = 0.0, int stride = 1, bool verbose = true);
    GreedyConst(const GreedyConst& other);
    GreedyConst& operator=(const GreedyConst& other);
    virtual ~GreedyConst();

    // ---- Fit and predict ----

    /// Fit the tree to (X, y); both arguments are taken by value.
    /// NOTE(review): the meaning of the returned double is not visible in
    /// this header -- confirm in the implementation.
    double fit(Eigen::MatrixXd X, Eigen::VectorXd y);

    /// Predict for a matrix X; returns a vector.
    Eigen::VectorXd predict(Eigen::MatrixXd X);

    /// Predict for a single vector x; returns a scalar.
    double predict_row(Eigen::VectorXd x);

    // ---- Debug ----

    /// Pretty-print the tree as a string.
    std::string print_tree();

    /// Leaf count (presumably for the tree under `root` -- confirm in the .cpp).
    std::size_t n_leaves() const;

protected:
    /// Train-time helper operating on a node with data (X, y).
    void fit_coefficients(ConstNode* node, Eigen::MatrixXd X, Eigen::VectorXd y);

    /// Greedy recursive builder; virtual so that CLARITreeConst can override it.
    virtual double recursive_fit(std::vector<std::vector<unsigned long>> sorted_indices,
                                 int n,
                                 double y_sum,
                                 double y_sum_sq,
                                 ConstNode* node,
                                 char depth_remaining);

    /// Sum-of-squares loss from the mean.
    static double loss(int n, double sum, double sum_sq);

private:
    /// Helper taking a node pointer and returning a count
    /// (presumably the leaves under `n` -- confirm in the .cpp).
    static std::size_t count_leaves(const ConstNode* n);
};

/// CLARITreeConst: the special case of CLARITree for constant-leaf trees.
///
/// Derives publicly from GreedyConst and replaces only the recursive builder;
/// the constructor takes the same parameters and defaults as GreedyConst's.
class CLARITreeConst : public GreedyConst {
public:
    explicit CLARITreeConst(char depth, double lambda = 0.0, int stride = 1, bool verbose = true);

protected:
    /// Override of GreedyConst::recursive_fit.
    double recursive_fit(std::vector<std::vector<unsigned long>> sorted_indices,
                         int n,
                         double y_sum,
                         double y_sum_sq,
                         ConstNode* node,
                         char depth_remaining) override;
};



