#include "clari_tree.hpp"

#include <chrono>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <string>

using namespace Eigen;
using namespace std;


int main(int argc, char* argv[]){
    if (argc != 5) {
        cout << "Usage to run a file: " << argv[0] << " <csv_file> <depth>" << endl;
        cout << "Running test cases instead" << endl;

        // Test case 1: 

        MatrixXd X {{1, 0, 0, 1},
                    {0, 1, 0, 2},
                    {0, 0, 1, 3}, 
                    {1, 0, 0, 4},
                    {0, 1, 0, 5},
                    {0, 0, 1, 6}};

        VectorXd y(6);
        y << 1, 2, 3, -1, -2, -3;
        double kappa = 0.1;

        Greedy greedy_tree = Greedy(kappa, 2, 0.02, 1, false);
        double greedy_loss = greedy_tree.fit(X, y);
        cout << "Greedy loss found: " << greedy_loss << endl;
        CLARITree clari_tree = CLARITree(kappa, 2, 0.02, 1, false);
        double clari_loss = clari_tree.fit(X, y);
        cout << "CLARITree loss found: " << clari_loss << endl;
        cout << "Greedy tree structure:\n" << greedy_tree.print_tree() << endl;
        cout << "CLARITree structure:\n" << clari_tree.print_tree() << endl;
        cout << "MSE of predictions (not counting ridge penalty on optima):" << endl;
        cout << "Greedy predictions: " << (greedy_tree.predict(X) - y).squaredNorm() / X.rows() << endl;
        cout << "CLARITree predictions: " << (clari_tree.predict(X) - y).squaredNorm() / X.rows() << endl;
        // cout << endl;

        // test case 2:
        MatrixXd X2 = MatrixXd::Identity(3, 3);
        VectorXd y2(3);
        y2 << 1, 2, 3;
        kappa = 0.1;

        // Call the fit function
        greedy_loss = greedy_tree.fit(X2, y2);
        cout << "Greedy loss found: " << greedy_loss << endl;
        clari_loss = clari_tree.fit(X2, y2);
        cout << "CLARITree loss found: " << clari_loss << endl;
        cout << "Greedy tree structure:\n" << greedy_tree.print_tree() << endl;
        cout << "CLARITree structure:\n" << clari_tree.print_tree() << endl;
        cout << "MSE of predictions (not counting ridge penalty on optima):" << endl;
        cout << "Greedy predictions: " << (greedy_tree.predict(X2) - y2).squaredNorm() / X2.rows() << endl;
        cout << "CLARITree predictions: " << (clari_tree.predict(X2) - y2).squaredNorm() / X2.rows() << endl;

    } else {
        cout << "Running with provided CSV file: " << argv[1] << " and depth: " << argv[2] << " and lambda: " << argv[3] << endl;
        int depth = stoi(argv[2]);
        double lambda = stod(argv[3]);
        double kappa = stod(argv[4]);
        string filename = argv[1];
        MatrixXd X3;
        VectorXd y3;
        if (!readCSV(filename, X3, y3)) {
            cerr << "Error reading CSV file" << endl;
            return 1;
        }
     
        Greedy greedy_tree = Greedy(kappa, depth, lambda, 1, false);
        CLARITree clari_tree = CLARITree(kappa, depth, lambda, 1, false);

        // ---- Compute R^2 ----
        auto compute_R2 = [&](const VectorXd& y_true, const VectorXd& y_pred){
            double ss_res = (y_true - y_pred).squaredNorm();
            double mean_y = y_true.mean();
            double ss_tot = (y_true.array() - mean_y).matrix().squaredNorm();
            return 1.0 - ss_res / ss_tot;
        };

        
        // Fit the model
        // timing:
        auto start = chrono::high_resolution_clock::now();
        double greedy_loss = greedy_tree.fit(X3, y3);
        VectorXd yhat_greedy = greedy_tree.predict(X3);
        auto end = chrono::high_resolution_clock::now();
        chrono::duration<double> elapsed = end - start;
        cout << "Greedy loss found: " << greedy_loss << " in " << elapsed.count() << " seconds" << endl;
        cout << "Greedy R^2:  " << compute_R2(y3, yhat_greedy) << endl;
        start = chrono::high_resolution_clock::now();
        double clari_loss = clari_tree.fit(X3, y3);
        VectorXd yhat_clari = clari_tree.predict(X3);
        end = chrono::high_resolution_clock::now();
        elapsed = end - start;
        cout << "CLARITree loss found: " << clari_loss << " in " << elapsed.count() << " seconds" << endl;
        cout << "CLARITree R^2: " << compute_R2(y3, yhat_clari) << endl;

        // cout << "Greedy tree structure:\n" << greedy_tree.print_tree() << endl;
        // cout << "CLARITree structure:\n" << clari_tree.print_tree() << endl;
    }
        return 0;

}

