#pragma once

#include <bitset>
#include <cmath>
#include <iostream>
#include <limits>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

#include "balanced_binary_tree.h"
#include "dataset_config.h"

/// Node of the hierarchical tree built by QuadTree.
///
/// A node holds per-set point counters (A_count / B_count), a match counter,
/// and sampling state (tree_weight, node_sampler, node_total_weight).
/// Children are keyed by a DIM-bit mask, so a node has at most 2^DIM children.
/// The constructor and Build_Node_Sampler() are defined out of line.
struct HSTNode {
    int depth = 0;          // presumably the tree level; set from the constructor's `d` — confirm
    int id_in_depth = 0;    // despite the name, actually one node id (A or B)
    double width = 0.0;     // presumably the cell side length at this level — confirm
    bool is_leaf = false;
    int A_count = 0;
    int B_count = 0;
    int match_count = 0;
    double tree_weight = 0.0;
    std::unordered_map<std::bitset<DIM>, std::shared_ptr<HSTNode>> children;
    // NOTE(review): `parent` is an owning shared_ptr, while the parent also owns
    // this node through its `children` map. When both links are set they form a
    // reference cycle, and the nodes are never freed. std::weak_ptr would break
    // the cycle, but it changes every call site (lock() instead of `->`, no bool
    // conversion), so the type is kept as-is here.
    std::shared_ptr<HSTNode> parent;
    AVLTree node_sampler;
    double node_total_weight = 0.0;
    // Optional A-point index; presumably set only while the node holds a
    // single A point — confirm against insert/remove.
    std::optional<int> single_A_point;

    HSTNode(int d, int i, double w);
    void Build_Node_Sampler();
};

/// Hierarchical tree over two point sets (A and B) with weighted sampling and
/// dynamic insert/remove. Node children are keyed by std::bitset<DIM>.
/// All member functions except getRoot() and dfs_quadtree() are defined out of
/// line, so their signatures here must stay in sync with that definition.
class QuadTree {
public:
    // d: dimensionality (presumably equal to DIM — confirm); delta and
    // max_depth are stored on the instance by the out-of-line constructor.
    QuadTree(int d, double delta, int max_depth);

    // Builds the tree from point sets A and B.
    // NOTE(review): the role of c_A / c_B is not visible from this header.
    void build(const std::vector<std::vector<double>>& A, const std::vector<std::vector<double>>& B, 
        const std::vector<std::vector<double>>& c_A, const std::vector<std::vector<double>>& c_B);

    std::shared_ptr<HSTNode> sample_by_weight(); 
    std::shared_ptr<HSTNode> sample_node_by_weight(std::shared_ptr<HSTNode> node);

    /// Returns a shared handle to the root node.
    std::shared_ptr<HSTNode> getRoot() const { return root; }

    double calculateTotalWeight();
    void Build_Tree_Sampler();

    // Dynamic updates; `label` presumably selects set 'A' or 'B' — confirm.
    void insert(std::vector<double>& point, char label, int point_i, std::vector<std::vector<double>>& A, std::vector<std::vector<double>>& B);
    std::shared_ptr<HSTNode> remove(const std::vector<double>& point, char label, std::vector<std::vector<double>>& A, std::vector<std::vector<double>>& B);

    double total_weight = 0.0; 
    AVLTree weight_tree;

    /// Debug dump: pre-order traversal printing each node's counters.
    /// @param node subtree to print; a null pointer prints nothing.
    /// @param os   destination stream (defaults to std::cout, as before).
    /// Children live in an unordered_map, so sibling order is unspecified.
    /// Uses '\n' rather than std::endl to avoid flushing on every line.
    void dfs_quadtree(const std::shared_ptr<HSTNode>& node, std::ostream& os = std::cout) const {
        if (!node) return;
        os << "quadtree match_count: " << node->match_count
           << " A count: "      << node->A_count
           << " B count: "      << node->B_count
           << " width: "       << node->width
           << " weight: "      << node->tree_weight
           << '\n';
        os << "id in depth: " << node->id_in_depth << " depth: " << node->depth << '\n';
        os << '\n';

        for (const auto& entry : node->children) {
            dfs_quadtree(entry.second, os);
        }
    }

private:
    // Default-initialized so an instance never carries indeterminate values;
    // the constructor overrides these.
    int dim = 0;
    double delta = 0.0;
    int max_depth = 0;
    std::vector<double> random_shift;
    std::shared_ptr<HSTNode> root;
    std::vector<double> root_origin;

    template <std::size_t N = DIM>
    std::bitset<N> to_bitset(const std::vector<int>& vec);
    std::vector<int> point_to_grid_index(const std::vector<double>& point, int id_in_depth, double width, std::vector<std::vector<double>>& A, std::vector<std::vector<double>>& B);
    bool is_matched(const std::vector<double>& point_a, const std::vector<std::vector<double>>& dataset_b, int id_in_depth, double width, std::vector<std::vector<double>>& A, std::vector<std::vector<double>>& B);
    std::shared_ptr<HSTNode> build_subtree(const std::vector<std::vector<double>>& points_A, const std::vector<std::vector<double>>& points_B, 
        int depth, const std::vector<double>& current_origin, const std::vector<double>& root_origin, double width, std::shared_ptr<HSTNode>& parent_node);
    std::vector<int> calculate_current_id(const std::vector<double>& current_origin, const std::vector<double>& root_origin, double width);
    void add_points_to_subbuckets(const std::vector<std::vector<double>>& points_A, const std::vector<std::vector<double>>& points_B, 
        std::unordered_map<std::bitset<DIM>, std::vector<std::pair<std::string, std::vector<double>>>>& buckets,
        const std::vector<double>& current_origin, double width);
    void process_subgrid(std::shared_ptr<HSTNode>& node, std::unordered_map<std::bitset<DIM>, std::vector<std::pair<std::string, std::vector<double>>>>& buckets,
        const std::vector<double>& current_origin, const std::vector<double>& root_origin, double width, int depth);
    std::shared_ptr<HSTNode> insert_recursive(std::shared_ptr<HSTNode>& node, const std::vector<double>& point, char label, int depth, int point_i, std::vector<std::vector<double>>& A, std::vector<std::vector<double>>& B);
    std::shared_ptr<HSTNode> remove_recursive(std::shared_ptr<HSTNode> node, const std::vector<double>& point, char label, int depth, std::vector<std::vector<double>>& A, std::vector<std::vector<double>>& B);
    
    void Update_Sampler();
    
    void delete_update_A(std::shared_ptr<HSTNode> node, int C);
    void delete_update_B(std::shared_ptr<HSTNode> node, int C);
    void insert_update_A(std::shared_ptr<HSTNode> node, int C, const std::vector<double>& point, int ai);
    void insert_update_B(std::shared_ptr<HSTNode> node, int C);
};
