#ifndef CLARI_TREE_HPP
#define CLARI_TREE_HPP
#pragma once

#include <cstddef>
#include <fstream>
#include <sstream>
#include <string>
#include <tuple>
#include <vector>

#include <Eigen/Dense>

//
// ========== Node ==========
// Represents one node in the regression tree. Internal nodes hold a split
// (feature_idx, threshold) and two children; leaf nodes have no children and
// hold ridge regression coefficients (Greedy fits ridge regression in each leaf).
//
// Special members: the destructor and copy operations are user-declared and
// defined out of line, which suggests Node manages its children itself
// (presumably recursive delete / deep copy -- confirm in the implementation).
// Because they are user-declared, the compiler does NOT generate move
// operations: moving a Node falls back to the copy constructor/assignment.
//
class Node {
public:
    // Default member initializers only take effect where the out-of-line
    // constructors do not set a member themselves; they ensure a Node never
    // exposes indeterminate values (e.g. garbage child pointers).
    Node* left = nullptr;          // left child (nullptr if leaf)
    Node* right = nullptr;         // right child (nullptr if leaf)
    bool is_leaf = true;           // flag for leaf node; true is consistent with the null children above
    double obj = 0.0;              // objective value at this node
    double threshold = 0.0;        // threshold used for splitting
    int feature_idx = -1;          // feature index used for splitting (-1: no split assigned yet)
    Eigen::VectorXd coefficients;  // ridge regression coefficients

    Node();
    ~Node();

    Node& operator=(const Node& other);   // copy assignment
    Node(const Node& other);              // copy constructor

    // Returns a textual rendering of the subtree rooted at this node.
    // `indentation` is presumably the starting indent level -- confirm.
    std::string print_tree(int indentation = 0);
};

//
// ========== Greedy ==========
// Greedy regression tree with ridge regression in each leaf
//
class Greedy {
public:
    // Members below keep their original declaration order (it fixes the
    // initialization order used by the out-of-line constructors).
    // Members that are not constructor parameters get in-class defaults so an
    // unfitted Greedy never holds indeterminate values; constructors that set
    // them explicitly override these defaults.
    Eigen::MatrixXd X;          // feature matrix
    Eigen::VectorXd y;          // target vector
    double kappa;               // ridge regularization parameter
    bool verbose;               // verbosity flag
    int depth;                  // maximum depth
    unsigned long int n = 0;    // number of samples
    unsigned long int m = 0;    // number of features
    double lambda;              // penalty on number of leaf nodes
    double scaled_kappa = 0.0;  // kappa scaled by n
    double scaled_lambda = 0.0; // lambda scaled by TSS
    int stride;                 // to speed up traversal or align with quantiles, we can skip some data points
    std::vector<int> continuous_idx_;  // indices in X (excluding intercept col 0)
    std::vector<int> binary_idx_;      // indices in X (excluding intercept col 0)
    std::vector<int> categorical_idx_; // optional, provided by caller (excluding intercept col 0)
    Eigen::MatrixXd X_reg_;     // [n x p_reg] = [1 | continuous features]
    int p_reg_ = 0;             // = 1 + continuous_idx_.size()
    int p_split_ = 0;           // = X.cols() (original with intercept)
    Node* root = nullptr;       // root node; nullptr rather than an indeterminate pointer
                                // (presumably owned by this object -- confirm in the .cpp)

    // Stores the hyper-parameters; lambda, stride and verbose are optional.
    Greedy(double kappa, int depth, double lambda = 0.0, int stride = 1, bool verbose = true);
    // Virtual because subclasses override recursive_fit and may be destroyed
    // through a Greedy pointer.
    virtual ~Greedy();

    // User-declared copy operations (presumably deep-copying the tree). No move
    // operations are declared, so moving a Greedy performs a copy.
    Greedy& operator=(const Greedy& other);
    Greedy(const Greedy& other);

    // Fits the tree on (X, y) and returns the objective. categorical_idx lists
    // caller-designated categorical columns (excluding intercept col 0).
    double fit(Eigen::MatrixXd X, Eigen::VectorXd y, const std::vector<int>& categorical_idx = {});
    // Presumably fits the ridge coefficients stored in `node` from (X, y) -- confirm.
    void fit_coefficients(Node* node, Eigen::MatrixXd X, Eigen::VectorXd y);

    // Grows the subtree rooted at `node` with at most `depth_remaining` more
    // levels and returns its objective. Subclasses override this with their
    // own splitting strategy. Assumed meaning of the statistics (confirm):
    // sorted_indices = per-feature sample indices ordered by feature value,
    // llt = Cholesky factor of the regularized Gram matrix, b = X^T y,
    // y_sum_sq = sum of squared targets for the samples at this node.
    virtual double recursive_fit(
        std::vector<std::vector<unsigned long int>> sorted_indices,
        Eigen::LLT<Eigen::MatrixXd> llt,
        Eigen::VectorXd b,
        double y_sum_sq,
        Node* node,
        int depth_remaining
    );

    // Loss computed from the factorized sufficient statistics (llt, b, y_sum_sq).
    double loss(const Eigen::LLT<Eigen::MatrixXd>& llt,
                const Eigen::VectorXd& b,
                double y_sum_sq);
    // Predictions for every row of X / for a single row x.
    Eigen::VectorXd predict(Eigen::MatrixXd X);
    double predict_row(Eigen::VectorXd x);

    // Textual rendering of the fitted tree and its number of leaves.
    std::string print_tree();
    std::size_t n_leaves() const;

    // --- Feature-type helpers ---
    static bool is_binary_column(const Eigen::VectorXd& col);
    void detect_feature_types(); // fills *_idx_, builds X_reg_, sets p_reg_, p_split_

    // Row i in regression space [1 | x_cont]
    Eigen::RowVectorXd reg_row(int i) const;

    // Full recompute (regression view) of the sufficient statistics for a
    // subset of rows; tuple elements mirror recursive_fit's (llt, b, y_sum_sq).
    std::tuple<Eigen::LLT<Eigen::MatrixXd>, Eigen::VectorXd, double>
    recompute_stats_from_rows(const std::vector<int>& rows);

private:
    // Counts leaves in the subtree rooted at n (helper for n_leaves()).
    static std::size_t count_leaves(const Node* n);
};

//
// ========== CLARITree ==========
// Main CLARITree algorithm with recursive splitting strategy
//
class CLARITree : public Greedy {
public:
    // Same parameters and defaults as Greedy's constructor; defined out of line.
    CLARITree(double kappa,
              int depth,
              double lambda = 0.0,
              int stride = 1,
              bool verbose = true);

    // CLARITree's splitting strategy; replaces Greedy::recursive_fit and
    // receives the same per-node statistics.
    double recursive_fit(std::vector<std::vector<unsigned long int>> sorted_indices,
                         Eigen::LLT<Eigen::MatrixXd> llt,
                         Eigen::VectorXd b,
                         double y_sum_sq,
                         Node* node,
                         int depth_remaining) override;
};
// ============================================================
// GreedyFull: ablation greedy (no rank-one updates anywhere)
// ============================================================
class GreedyFull : public Greedy {
public:
    // Forwards every argument unchanged to Greedy; adds no state of its own.
    GreedyFull(double kappa,
               int depth,
               double lambda = 0.0,
               int stride = 1,
               bool verbose = true)
        : Greedy{kappa, depth, lambda, stride, verbose}
    {
    }

    // Ablation override of Greedy::recursive_fit. The llt / b / y_sum_sq
    // statistics are deliberately left unnamed: this variant is documented as
    // doing no rank-one updates.
    double recursive_fit(std::vector<std::vector<unsigned long int>> sorted_indices,
                         Eigen::LLT<Eigen::MatrixXd> /* llt: unused */,
                         Eigen::VectorXd /* b: unused */,
                         double /* y_sum_sq: unused */,
                         Node* node,
                         int depth_remaining) override;
};


// ============================================================
// CLARITreeFull: ablation version (no rank-one updates)
// ============================================================
class CLARITreeFull : public Greedy {
public:
    // Forwards every argument unchanged to Greedy. lambda, stride and verbose
    // now default to the same values as in Greedy, GreedyFull and CLARITree
    // (0.0, 1, true), so CLARITreeFull(kappa, depth) works like its siblings;
    // existing five-argument calls are unaffected.
    CLARITreeFull(double kappa, int depth, double lambda = 0.0, int stride = 1, bool verbose = true)
        : Greedy(kappa, depth, lambda, stride, verbose) {}

    // Ablation override of Greedy::recursive_fit (no rank-one updates of the
    // per-node statistics, per this class's banner).
    double recursive_fit(std::vector<std::vector<unsigned long int>> sorted_indices,
                         Eigen::LLT<Eigen::MatrixXd> llt,
                         Eigen::VectorXd b,
                         double y_sum_sq,
                         Node* node,
                         int depth_remaining) override;
};

//
// ========== I/O Utilities ==========
// Parse CSV file into X (features) and y (target)
// Last column is y, other columns are features
//
// Splits one text line on `delimiter` and returns the fields as doubles.
auto parseLine(const std::string& line, char delimiter) -> std::vector<double>;

// Loads `filename` into X (all columns but the last) and y (last column).
// has_header presumably skips the first line; the bool result presumably
// reports success -- confirm in the implementation.
auto readCSV(const std::string& filename,
             Eigen::MatrixXd& X,
             Eigen::VectorXd& y,
             bool has_header = true,
             char delimiter = ',') -> bool;

#endif // CLARI_TREE_HPP



