#include <algorithm>
#include <array>
#include <chrono>
#include <cstddef>
#include <iomanip>
#include <iostream>
#include <memory>
#include <numeric>
#include <set>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <Eigen/Dense>

#include "clari_tree_const.hpp"
using namespace Eigen;
using namespace std;

// Default constructor: a childless leaf whose objective, threshold,
// feature index and prediction all start at zero.
ConstNode::ConstNode()
    : left(nullptr),
      right(nullptr),
      is_leaf(true),
      obj(0),
      threshold(0),
      feature_idx(0),
      prediction(0.0)
{
}

// Destructor: frees both children. Deleting a child runs that child's own
// destructor, so the entire subtree below this node is released.
ConstNode::~ConstNode()
{
    delete this->left;
    delete this->right;
}

// Copy-assignment operator for ConstNode: deep-copies `other` into this node.
//
// The replacement subtrees are cloned BEFORE the current children are
// destroyed. This matters in two situations:
//   * `other` lives inside one of this node's own subtrees (for example
//     `*node = *node->left`): deleting the children first would destroy
//     `other` before it has been read.
//   * a clone allocation throws: this node is then left untouched instead of
//     holding pointers to already-deleted children.
//
// Returns a reference to this node.
ConstNode& ConstNode::operator=(const ConstNode& other) {
    if (this == &other) {
        return *this;
    }

    // Clone first; unique_ptr releases the partial clone if the second
    // allocation throws.
    std::unique_ptr<ConstNode> cloned_left(other.left ? new ConstNode(*other.left) : nullptr);
    std::unique_ptr<ConstNode> cloned_right(other.right ? new ConstNode(*other.right) : nullptr);

    // Copy the plain fields while `other` is guaranteed to still be alive.
    is_leaf = other.is_leaf;
    obj = other.obj;
    threshold = other.threshold;
    feature_idx = other.feature_idx;
    prediction = other.prediction;

    // Swap in the new subtrees, then destroy the old ones last.
    ConstNode* stale_left = left;
    ConstNode* stale_right = right;
    left = cloned_left.release();
    right = cloned_right.release();
    delete stale_left;
    delete stale_right;

    return *this;
}
// Copy constructor for ConstNode: deep copy. Each existing child of `other`
// is cloned recursively; a missing child stays nullptr. All remaining fields
// are copied by value.
ConstNode::ConstNode(const ConstNode& other)
    : left(other.left ? new ConstNode(*other.left) : nullptr),
      right(other.right ? new ConstNode(*other.right) : nullptr),
      is_leaf(other.is_leaf),
      obj(other.obj),
      threshold(other.threshold),
      feature_idx(other.feature_idx),
      prediction(other.prediction)
{
}

std::string ConstNode::print_tree(int indentation) {
    string indent(indentation, ' ');
    if (is_leaf) {
        return indent + "predict " + to_string(this->prediction) + "; gives objective " + to_string(obj) + "\n";
    } else {
        string split_string = indent + "If feature " + to_string(feature_idx) + " <= " + to_string(threshold) + ":\n";
        if (left){
            split_string += left->print_tree(indentation + 2);
        }
        split_string += indent + "Else:\n";
        if (right){
            split_string += right->print_tree(indentation + 2);
        }
        return split_string;
    }
}


// =======================
// GreedyConst implementation
// =======================
// Constructs an unfitted tree.
//
// depth:   maximum number of split levels explored by fit().
// lambda:  per-leaf penalty coefficient; fit() rescales it into scaled_lambda.
// stride:  minimum number of samples scanned between two evaluated split
//          candidates (see recursive_fit).
// verbose: stored on the object; not read anywhere in this file.
//
// The root starts as a single default leaf. n, m and scaled_lambda are only
// meaningful after fit(), but they are zeroed here so that copying an
// unfitted tree never reads an indeterminate value.
GreedyConst::GreedyConst(char depth, double lambda, int stride, bool verbose)
    : depth(depth), verbose(verbose), root(new ConstNode()), lambda(lambda), stride(stride) {
    this->scaled_lambda = 0.0;
    this->n = 0;
    this->m = 0;
}

// Destructor: the tree owns its root node, and deleting the root releases
// every node below it.
GreedyConst::~GreedyConst()
{
    delete this->root;
}

// Copy-assignment operator: replaces this tree's data, hyper-parameters and
// node structure with deep copies of `other`'s.
//
// The new root is cloned before the old one is deleted. If cloning (or any
// of the member copies) throws, `root` still points at a valid tree rather
// than at freed memory, so the destructor stays safe.
//
// Returns a reference to this tree.
GreedyConst& GreedyConst::operator=(const GreedyConst& other) {
    if (this == &other) {
        return *this;
    }

    // Held by unique_ptr so the clone is released if a later copy throws.
    std::unique_ptr<ConstNode> cloned_root(other.root ? new ConstNode(*other.root) : nullptr);

    // Copy data from the other tree
    X = other.X;
    y = other.y;
    verbose = other.verbose;
    depth = other.depth;
    lambda = other.lambda;
    scaled_lambda = other.scaled_lambda;
    n = other.n;
    m = other.m;
    stride = other.stride;

    // Only now drop the old tree and take ownership of the clone.
    delete root;
    root = cloned_root.release();

    return *this;
}

// Copy constructor: copies the training data and hyper-parameters and
// deep-copies the node structure.
//
// scaled_lambda is copied as well, so the copy holds the same state that the
// copy-assignment operator would produce.
GreedyConst::GreedyConst(const GreedyConst& other)
    : X(other.X), y(other.y), verbose(other.verbose),
        depth(other.depth), n(other.n), m(other.m), lambda(other.lambda), stride(other.stride) {
    // Assigned in the body (not the initializer list) so the statement is
    // independent of the member declaration order in the header.
    scaled_lambda = other.scaled_lambda;
    root = other.root ? new ConstNode(*other.root) : nullptr;
}

// Fits the tree to (X, y) and returns the objective reported by recursive_fit
// for the root.
//
// X: feature matrix, one sample per row.
// y: target vector, one entry per row of X.
//
// Steps:
//   1. store the data and scale the leaf penalty: scaled_lambda = lambda * TSS,
//      where TSS is the total sum of squares of y around its mean;
//   2. replace the root by a fresh leaf whose objective is its leaf loss plus
//      scaled_lambda;
//   3. for every feature, build the list of row indices sorted by that
//      feature's value;
//   4. learn the split structure with recursive_fit;
//   5. fill in the leaf predictions with fit_coefficients.
//
// Throws std::invalid_argument when X and y disagree on the number of samples
// or when there are no samples at all (the mean of y would be undefined).
double GreedyConst::fit(MatrixXd X, VectorXd y) {
    if (X.rows() != y.size()) {
        throw std::invalid_argument("GreedyConst::fit: X and y must have the same number of samples");
    }
    if (X.rows() == 0) {
        throw std::invalid_argument("GreedyConst::fit: cannot fit on an empty dataset");
    }

    // The arguments are already private copies (passed by value), so move
    // them into the members instead of copying a second time.
    this->X = std::move(X);
    this->y = std::move(y);
    this->n = this->X.rows();
    this->m = this->X.cols();
    const MatrixXd& features = this->X;
    const VectorXd& targets = this->y;

    // scale the penalty by TSS
    const double mean_y = targets.mean();
    const double tss = (targets.array() - mean_y).matrix().squaredNorm();
    this->scaled_lambda = this->lambda * tss;

    /*
    Compute root loss and info
    */
    const double y_sum = targets.sum();
    const double y_sum_sq = targets.squaredNorm();
    const double parent_loss = GreedyConst::loss(this->n, y_sum, y_sum_sq);

    delete this->root; // delete old root if exists
    this->root = new ConstNode();
    this->root->obj = parent_loss + this->scaled_lambda;

    // get sorted feature indices
    // that is, for each feature, we want a row vector of indices sorted by feature value
    std::vector<std::vector<unsigned long int>> sorted_indices(this->m, std::vector<unsigned long int>(this->n));
    for (unsigned long int feature = 0; feature < this->m; feature++) {
        // Sort in place inside the table; no temporary vector is needed.
        std::vector<unsigned long int>& order = sorted_indices[feature];
        std::iota(order.begin(), order.end(), 0UL); // fill with 0, 1, ..., n-1
        std::sort(order.begin(), order.end(), [&features, feature](unsigned long int a, unsigned long int b) {
            return features(a, feature) < features(b, feature);
        });
    }

    // learn partitioning structure and resulting loss
    // (the table is not used again here, so hand it over without copying)
    const double objective = recursive_fit(std::move(sorted_indices), this->n, y_sum, y_sum_sq, this->root, this->depth);
    // Now learn the constant prediction for each of the leaves.
    fit_coefficients(this->root, this->X, this->y);

    return objective;
}

// Fills in the constant prediction of every leaf below `node`.
//
// node: root of the subtree to process; a null node is ignored.
// X, y: the samples that reach `node` (rows of X aligned with entries of y).
//
// A leaf predicts the mean of the targets that reach it. An internal node
// routes each sample to the left child when its split feature is <= the
// threshold and to the right child otherwise. This is the same rule that
// predict_row applies and print_tree displays, so every leaf is fitted on
// exactly the samples it will later be asked to predict.
void GreedyConst::fit_coefficients(ConstNode* node, MatrixXd X, VectorXd y) {
    if (node == nullptr) {
        return;
    }

    if (node->is_leaf) {
        // With no samples the mean would be 0/0 (NaN); keep the node's
        // current prediction in that case.
        if (y.size() > 0) {
            node->prediction = y.sum()/y.size();
        }
        return;
    }

    // Split the row indices according to the node's decision rule.
    const unsigned long int row_count = static_cast<unsigned long int>(X.rows());
    std::vector<unsigned long int> left_indices;
    std::vector<unsigned long int> right_indices;
    left_indices.reserve(row_count);
    right_indices.reserve(row_count);
    for (unsigned long int i = 0; i < row_count; i++) {
        if (X(i, node->feature_idx) <= node->threshold) {
            left_indices.push_back(i);
        } else {
            right_indices.push_back(i);
        }
    }

    // fit predictions for children
    fit_coefficients(node->left, X(left_indices, Eigen::all), y(left_indices));
    fit_coefficients(node->right, X(right_indices, Eigen::all), y(right_indices));
}

// Greedily grows the subtree rooted at `node` and returns its objective.
//
// sorted_indices:  for each feature, the rows that reach `node`, ordered by
//                  that feature's value.
// n:               number of rows that reach `node`.
// y_sum, y_sum_sq: sum and sum of squares of the targets of those rows.
// node:            node to grow; on entry node->obj must hold the objective
//                  the node would have as a leaf.
// depth_remaining: number of split levels still allowed below `node`.
//
// The node becomes a leaf when depth_remaining is 0, when its objective is at
// most 2*scaled_lambda, or when no evaluated split lowers the objective.
// Otherwise the split minimising (left leaf objective + right leaf objective)
// is chosen, both children are grown recursively, and the split is kept only
// if the grown children together beat the node's own leaf objective.
//
// Stride rule: samples_past_quota counts scanned samples minus `stride` per
// evaluated candidate, and a candidate is only evaluated once the counter
// reaches `stride`. The counter is shared by all features of this call (it is
// not reset when moving to the next feature).
double GreedyConst::recursive_fit(vector<vector<unsigned long int>> sorted_indices, int n, double y_sum, double y_sum_sq, ConstNode* node, char depth_remaining) {
    if (depth_remaining == 0 || node->obj <= 2*this->scaled_lambda) {
        node->is_leaf = true;
        return node->obj;
    }

    // Best split found so far. Everything is initialised so no indeterminate
    // value can ever be read; the values only matter once split_flag is true.
    bool split_flag = false;
    unsigned long int best_feature = 0;
    unsigned long int best_feature_index = 0;
    int samples_past_quota = 0;
    double min_obj = node->obj; // a split must beat the node's own leaf objective
    double best_left_obj = 0.0, best_right_obj = 0.0;
    double best_y_sum_sq_left = 0.0, best_y_sum_sq_right = 0.0;
    double best_y_sum_left = 0.0, best_y_sum_right = 0.0;
    int best_n_left = 0, best_n_right = 0;

    for (unsigned long int feature = 0; feature < this->m; feature++) {
        const std::vector<unsigned long int>& order = sorted_indices[feature];
        const Eigen::Index col = static_cast<Eigen::Index>(feature);

        // Running statistics: rows move one at a time from right to left.
        double y_sum_sq_left = 0;
        double y_sum_left = 0;
        int n_left = 0;

        double y_sum_sq_right = y_sum_sq;
        double y_sum_right = y_sum;
        int n_right = n;

        for (unsigned long int feature_idx = 0; feature_idx < order.size(); feature_idx++) {
            // Keep the full index width (no narrowing to int).
            const Eigen::Index row = static_cast<Eigen::Index>(order[feature_idx]);
            const double y_row = this->y(row);

            y_sum_sq_left += y_row * y_row;
            y_sum_left += y_row;
            n_left++;

            y_sum_sq_right -= y_row * y_row;
            y_sum_right -= y_row;
            n_right--;

            samples_past_quota++;
            if (samples_past_quota < this->stride) {
                continue; // not enough samples seen since the last evaluated split
            }
            if (feature_idx == order.size() - 1) {
                continue; // last row: the right side would be empty
            }
            const Eigen::Index next_row = static_cast<Eigen::Index>(order[feature_idx + 1]);
            if (this->X(row, col) == this->X(next_row, col)) {
                continue; // equal neighbouring values cannot be separated by a threshold
            }
            samples_past_quota -= this->stride;

            const double left_obj = loss(n_left, y_sum_left, y_sum_sq_left) + this->scaled_lambda;
            const double right_obj = loss(n_right, y_sum_right, y_sum_sq_right) + this->scaled_lambda;

            if (left_obj + right_obj < min_obj) {
                split_flag = true;
                min_obj = left_obj + right_obj;
                best_feature = feature;
                best_feature_index = feature_idx;
                best_left_obj = left_obj;
                best_right_obj = right_obj;
                best_y_sum_sq_left = y_sum_sq_left;
                best_y_sum_sq_right = y_sum_sq_right;
                best_y_sum_left = y_sum_left;
                best_y_sum_right = y_sum_right;
                best_n_left = n_left;
                best_n_right = n_right;
            }
        }
    }

    if (!split_flag) {
        node->is_leaf = true; // no valid split found, this is a leaf
        return node->obj;
    }

    // Threshold: midpoint between the last row sent left and the first row
    // sent right, in the order of the winning feature.
    const std::vector<unsigned long int>& best_order = sorted_indices[best_feature];
    const double split_threshold =
        (this->X(best_order[best_feature_index], best_feature) +
         this->X(best_order[best_feature_index + 1], best_feature)) / 2.0;

    // Fresh children carrying their leaf objectives (any previous children
    // are released first; deleting nullptr is a no-op).
    delete node->left;
    delete node->right;
    node->left = new ConstNode();
    node->left->obj = best_left_obj;
    node->right = new ConstNode();
    node->right->obj = best_right_obj;

    // If the depth remaining is 1, the children are final leaves: the best
    // split has been found and no index lists are needed for the children.
    if (depth_remaining == 1) {
        node->is_leaf = false;
        node->feature_idx = best_feature;
        node->threshold = split_threshold;
        node->obj = best_left_obj + best_right_obj;
        return node->obj;
    }

    // Rows going left are exactly the first (best_feature_index + 1) rows in
    // the winning feature's order, so the membership mask is rebuilt from
    // that prefix instead of being copied on every improvement.
    const std::size_t left_count = static_cast<std::size_t>(best_feature_index) + 1;
    const std::size_t right_count = best_order.size() - left_count;
    std::vector<char> goes_left(static_cast<std::size_t>(this->y.size()), 0);
    for (std::size_t k = 0; k < left_count; ++k) {
        goes_left[best_order[k]] = 1;
    }

    // Split every feature's ordering between the children, preserving order.
    std::vector<std::vector<unsigned long int>> sorted_indices_left(this->m, std::vector<unsigned long int>());
    std::vector<std::vector<unsigned long int>> sorted_indices_right(this->m, std::vector<unsigned long int>());
    for (unsigned long int feature = 0; feature < this->m; feature++) {
        sorted_indices_left[feature].reserve(left_count);
        sorted_indices_right[feature].reserve(right_count);
        for (unsigned long int idx : sorted_indices[feature]) {
            if (goes_left[idx]) {
                sorted_indices_left[feature].push_back(idx);
            } else {
                sorted_indices_right[feature].push_back(idx);
            }
        }
    }

    // now replace the leaf objectives by those of the greedy completions
    const double final_left_objective = GreedyConst::recursive_fit(std::move(sorted_indices_left), best_n_left, best_y_sum_left, best_y_sum_sq_left, node->left, depth_remaining - 1);
    const double final_right_objective = GreedyConst::recursive_fit(std::move(sorted_indices_right), best_n_right, best_y_sum_right, best_y_sum_sq_right, node->right, depth_remaining - 1);
    const double split_objective = final_left_objective + final_right_objective;

    if (split_objective < node->obj) {
        node->is_leaf = false;
        node->feature_idx = best_feature;
        node->threshold = split_threshold;
        node->obj = split_objective;
        return split_objective;
    }

    // The split did not pay off: drop the children. The pointers are reset so
    // that ~ConstNode does not delete them a second time.
    delete node->left;
    node->left = nullptr;
    delete node->right;
    node->right = nullptr;
    node->is_leaf = true;
    return node->obj;
}
/*
Sum of squared deviations from the mean, computed from running sums.

Parameters:
- n, the number of samples
- sum, the sum of the targets, n*E[y]
- sum_sq, the sum of the squared targets, n*E[y^2]

Returns sum_sq - sum^2 / n, i.e. n * (E[y^2] - E[y]^2).
*/
double GreedyConst::loss(int n, double sum, double sum_sq) {
    const double squared_sum_over_n = sum*sum/n;
    return sum_sq - squared_sum_over_n;
}

// Runs every row of X through the fitted tree.
//
// X: one sample per row.
// Returns a vector with one prediction per row of X, in row order.
VectorXd GreedyConst::predict(MatrixXd X){
    const Eigen::Index sample_count = X.rows();
    VectorXd out(sample_count);
    for (Eigen::Index r = 0; r < sample_count; ++r) {
        out(r) = predict_row(X.row(r));
    }
    return out;
}

// Predicts a single sample by walking from the root down to a leaf.
//
// x: feature vector of the sample.
// At each internal node the walk goes left when x(feature_idx) <= threshold
// and right otherwise; the prediction of the first leaf reached is returned.
// Throws runtime_error if the walk runs into a null node before a leaf.
double GreedyConst::predict_row(VectorXd x) {
    const ConstNode* cursor = this->root;
    while (cursor) {
        if (cursor->is_leaf) {
            return cursor->prediction;
        }
        const bool go_left = x(cursor->feature_idx) <= cursor->threshold;
        cursor = go_left ? cursor->left : cursor->right;
    }
    // Fell off the tree without meeting a leaf.
    throw runtime_error("Invalid tree structure");
}

// Returns the textual rendering of the whole tree (see ConstNode::print_tree),
// or a placeholder message when the tree has no root.
string GreedyConst::print_tree() {
    if (this->root != nullptr) {
        return this->root->print_tree();
    }
    return "No tree currently fit!";
}


// Counts the leaves of the subtree rooted at `n`.
// A null subtree has no leaves; a node flagged is_leaf counts as one leaf
// (its children, if any, are not visited); otherwise the counts of both
// children are added.
std::size_t GreedyConst::count_leaves(const ConstNode* n) {
    if (n == nullptr) {
        return 0;
    }
    if (n->is_leaf) {
        return 1;
    }
    const std::size_t leaves_on_left = count_leaves(n->left);
    const std::size_t leaves_on_right = count_leaves(n->right);
    return leaves_on_left + leaves_on_right;
}

// Number of leaves in the current tree, counted from the root.
std::size_t GreedyConst::n_leaves() const
{
    const std::size_t leaf_total = count_leaves(this->root);
    return leaf_total;
}

// =======================
// CLARITreeConst implementation
// =======================


// Constructs an unfitted CLARI tree; every argument is forwarded unchanged
// to the GreedyConst constructor.
CLARITreeConst::CLARITreeConst(char depth, double lambda, int stride, bool verbose)
    : GreedyConst(depth, lambda, stride, verbose)
{
}

// Grows the subtree rooted at `node` with greedy look-ahead and returns its
// objective.
//
// sorted_indices:  for each feature, the rows that reach `node`, ordered by
//                  that feature's value.
// n:               number of rows that reach `node`.
// y_sum, y_sum_sq: sum and sum of squares of the targets of those rows.
// node:            node to grow; on entry node->obj must hold the objective
//                  the node would have as a leaf.
// depth_remaining: number of split levels still allowed below `node`.
//
// Every evaluated candidate split is scored by completing both sides with
// GreedyConst::recursive_fit (with depth_remaining == 1 the score is simply
// the sum of the two leaf objectives). The best-scoring split is then
// re-grown with this function, and it is kept only if the re-grown children
// together beat the node's own leaf objective.
//
// The node becomes a leaf when depth_remaining is 0, when its objective is at
// most 2*scaled_lambda, or when no candidate scores below the leaf objective.
//
// Stride rule: samples_past_quota counts scanned samples minus `stride` per
// evaluated candidate, and a candidate is only evaluated once the counter
// reaches `stride`. The counter is shared by all features of this call (it is
// not reset when moving to the next feature).
double CLARITreeConst::recursive_fit(vector<vector<unsigned long int>> sorted_indices, int n, double y_sum, double y_sum_sq, ConstNode* node, char depth_remaining) {
    if (depth_remaining == 0 || node->obj <= 2*this->scaled_lambda) {
        node->is_leaf = true;
        return node->obj;
    }

    // Splits every feature's ordering into the rows flagged in `is_left` and
    // the rest, preserving the sorted order on both sides.
    auto split_orders = [this, &sorted_indices](const std::vector<char>& is_left,
                                                std::size_t left_count,
                                                std::vector<std::vector<unsigned long int>>& left_out,
                                                std::vector<std::vector<unsigned long int>>& right_out) {
        left_out.assign(this->m, std::vector<unsigned long int>());
        right_out.assign(this->m, std::vector<unsigned long int>());
        for (unsigned long int f = 0; f < this->m; f++) {
            const std::vector<unsigned long int>& order = sorted_indices[f];
            left_out[f].reserve(left_count);
            right_out[f].reserve(order.size() >= left_count ? order.size() - left_count : 0);
            for (unsigned long int idx : order) {
                if (is_left[idx]) {
                    left_out[f].push_back(idx);
                } else {
                    right_out[f].push_back(idx);
                }
            }
        }
    };

    // Best split found so far. Everything is initialised so no indeterminate
    // value can ever be read; the values only matter once split_flag is true.
    bool split_flag = false;
    unsigned long int best_feature = 0;
    unsigned long int best_feature_index = 0;
    int samples_past_quota = 0;
    double min_obj = node->obj; // a split must beat the node's own leaf objective
    double best_left_leaf_obj = 0.0, best_right_leaf_obj = 0.0;
    double best_y_sum_sq_left = 0.0, best_y_sum_sq_right = 0.0;
    double best_y_sum_left = 0.0, best_y_sum_right = 0.0;
    int best_n_left = 0, best_n_right = 0;

    // Membership mask over all training rows: in_left[r] != 0 while row r is
    // on the left side of the candidate currently being scanned. It is
    // updated one row at a time instead of rebuilding a set per candidate.
    std::vector<char> in_left(static_cast<std::size_t>(this->y.size()), 0);

    for (unsigned long int feature = 0; feature < this->m; feature++) {
        const std::vector<unsigned long int>& order = sorted_indices[feature];
        const Eigen::Index col = static_cast<Eigen::Index>(feature);

        // A new feature starts with every row on the right.
        for (unsigned long int idx : order) {
            in_left[idx] = 0;
        }

        // Running statistics: rows move one at a time from right to left.
        double y_sum_sq_left = 0;
        double y_sum_left = 0;
        int n_left = 0;

        double y_sum_sq_right = y_sum_sq;
        double y_sum_right = y_sum;
        int n_right = n;

        for (unsigned long int feature_idx = 0; feature_idx < order.size(); feature_idx++) {
            // Keep the full index width (no narrowing to int).
            const Eigen::Index row = static_cast<Eigen::Index>(order[feature_idx]);
            const double y_row = this->y(row);

            y_sum_sq_left += y_row * y_row;
            y_sum_left += y_row;
            in_left[order[feature_idx]] = 1;
            n_left++;

            y_sum_right -= y_row;
            y_sum_sq_right -= y_row * y_row;
            n_right--;

            samples_past_quota++;
            if (samples_past_quota < this->stride) {
                continue; // not enough samples seen since the last evaluated split
            }
            if (feature_idx == order.size() - 1) {
                continue; // last row: the right side would be empty
            }
            const Eigen::Index next_row = static_cast<Eigen::Index>(order[feature_idx + 1]);
            if (this->X(row, col) == this->X(next_row, col)) {
                continue; // equal neighbouring values cannot be separated by a threshold
            }
            samples_past_quota -= this->stride;

            // objectives of the two sides if they were left as leaves
            const double left_obj = loss(n_left, y_sum_left, y_sum_sq_left) + this->scaled_lambda;
            const double right_obj = loss(n_right, y_sum_right, y_sum_sq_right) + this->scaled_lambda;

            // With one level left both sides are final leaves, so the score
            // is exact. Otherwise score the candidate by greedy completion.
            double candidate_obj = left_obj + right_obj;
            if (depth_remaining != 1) {
                std::vector<std::vector<unsigned long int>> trial_left_orders;
                std::vector<std::vector<unsigned long int>> trial_right_orders;
                split_orders(in_left, static_cast<std::size_t>(n_left), trial_left_orders, trial_right_orders);

                // Scratch nodes owned by this scope: the trial subtrees are
                // released automatically and never attached to `node`.
                ConstNode trial_left;
                trial_left.obj = left_obj;
                ConstNode trial_right;
                trial_right.obj = right_obj;

                const double greedy_completion_left_obj = GreedyConst::recursive_fit(std::move(trial_left_orders), n_left, y_sum_left, y_sum_sq_left, &trial_left, depth_remaining - 1);
                const double greedy_completion_right_obj = GreedyConst::recursive_fit(std::move(trial_right_orders), n_right, y_sum_right, y_sum_sq_right, &trial_right, depth_remaining - 1);
                candidate_obj = greedy_completion_left_obj + greedy_completion_right_obj;
            }

            if (candidate_obj < min_obj) {
                split_flag = true;
                min_obj = candidate_obj;
                best_feature = feature;
                best_feature_index = feature_idx;
                best_left_leaf_obj = left_obj;
                best_right_leaf_obj = right_obj;
                best_y_sum_sq_left = y_sum_sq_left;
                best_y_sum_sq_right = y_sum_sq_right;
                best_y_sum_left = y_sum_left;
                best_y_sum_right = y_sum_right;
                best_n_left = n_left;
                best_n_right = n_right;
            }
        }
    }

    if (!split_flag) {
        node->is_leaf = true; // no valid split found, this is a leaf
        return node->obj;
    }

    // Threshold: midpoint between the last row sent left and the first row
    // sent right, in the order of the winning feature.
    const std::vector<unsigned long int>& best_order = sorted_indices[best_feature];
    const double split_threshold =
        (this->X(best_order[best_feature_index], best_feature) +
         this->X(best_order[best_feature_index + 1], best_feature)) / 2.0;

    // Fresh children carrying their leaf objectives (any previous children
    // are released first; deleting nullptr is a no-op).
    delete node->left;
    delete node->right;
    node->left = new ConstNode();
    node->left->obj = best_left_leaf_obj;
    node->right = new ConstNode();
    node->right->obj = best_right_leaf_obj;

    if (depth_remaining == 1) {
        // The children are final leaves; nothing further to grow.
        node->is_leaf = false;
        node->feature_idx = best_feature;
        node->threshold = split_threshold;
        node->obj = node->left->obj + node->right->obj;
        return node->obj;
    }

    // Rows going left are exactly the first (best_feature_index + 1) rows in
    // the winning feature's order; rebuild the mask from that prefix.
    const std::size_t best_left_count = static_cast<std::size_t>(best_feature_index) + 1;
    for (unsigned long int idx : best_order) {
        in_left[idx] = 0;
    }
    for (std::size_t k = 0; k < best_left_count; ++k) {
        in_left[best_order[k]] = 1;
    }
    std::vector<std::vector<unsigned long int>> sorted_indices_left;
    std::vector<std::vector<unsigned long int>> sorted_indices_right;
    split_orders(in_left, best_left_count, sorted_indices_left, sorted_indices_right);

    // now replace the scores by recursive completions (using CLARITree's approach, not the full greedy tree)
    const double final_left_obj = recursive_fit(std::move(sorted_indices_left), best_n_left, best_y_sum_left, best_y_sum_sq_left, node->left, depth_remaining - 1);
    const double final_right_obj = recursive_fit(std::move(sorted_indices_right), best_n_right, best_y_sum_right, best_y_sum_sq_right, node->right, depth_remaining - 1);
    const double split_obj = final_left_obj + final_right_obj;

    if (split_obj < node->obj) {
        node->is_leaf = false; // this node is not a leaf, we found a valid split
        node->feature_idx = best_feature;
        node->threshold = split_threshold;
        node->obj = split_obj;
        return split_obj;
    }

    // The split did not pay off: drop the children. The pointers are reset so
    // that ~ConstNode does not delete them a second time.
    delete node->left;
    node->left = nullptr;
    delete node->right;
    node->right = nullptr;
    node->is_leaf = true;
    return node->obj;
}