#include <algorithm>
#include <bitset>
#include <cassert>
#include <cmath>
#include <functional>
#include <iostream>
#include <limits>
#include <map>
#include <memory>
#include <random>
#include <set>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

#include "quad_tree.h"

using namespace std;

// File-local random engine used by the weighted-sampling routines
// (sample_node_by_weight / sample_by_weight). It is seeded once from
// std::random_device, so sampling is typically not reproducible across runs.
static std::random_device rd;
static std::mt19937 gen_sample(rd());

// Snap `point` onto the grid of cell size `width` anchored at `root_origin`,
// i.e. compute the lower corner of the width-sized grid cell containing `point`.
//
// @param point        Coordinates of the point to snap.
// @param root_origin  Grid anchor; must have the same dimension as `point`.
// @param width        Cell side length; must be strictly positive.
// @param result       Output; resized to point.size() and overwritten with the
//                     cell's lower corner.
// @throws std::invalid_argument on a dimension mismatch or a non-positive width.
void point_of_id_to_origin(const std::vector<double>& point,
                                 const std::vector<double>& root_origin,
                                 double width, std::vector<double>& result) {
    if (point.size() != root_origin.size())
        throw std::invalid_argument("Dimension mismatch.");
    // A zero, negative or NaN width would turn floor(delta / width) into
    // inf/NaN and silently corrupt every grid computation downstream.
    if (!(width > 0.0))
        throw std::invalid_argument("Grid width must be positive.");

    // Size the output here so callers may pass an empty (or shorter) vector
    // without writing past its end.
    result.resize(point.size());
    for (size_t i = 0; i < point.size(); ++i) {
        const double delta = point[i] - root_origin[i];
        result[i] = std::floor(delta / width) * width + root_origin[i];
    }
}


// Construct a tree node at depth `d` whose cell has side length `w`.
// `i` is stored as id_in_depth: insert_recursive passes the index of the point
// that created the node, and point_to_grid_index later uses that point to
// recover the cell origin. The node's child sampler starts empty with a total
// weight of zero.
HSTNode::HSTNode(int d, const int i, double w)
    : depth(d), id_in_depth(i), width(w), node_sampler(), node_total_weight(0) {}

// Rebuild this node's child sampler from scratch.
//
// Each child is keyed by (depth, id_in_depth) and weighted by its A_count.
// Children that already contain B points are left out entirely, matching the
// invariant the incremental updates maintain: insert_update_B removes a child
// from its parent's sampler when it receives its first B point, and
// delete_update_B inserts it again once its last B point is gone. Inserting
// such children here (even with weight 0) would leave an entry that
// delete_update_B later inserts a second time under the same key.
void HSTNode::Build_Node_Sampler() {
    node_sampler.clear();
    node_total_weight = 0.0;

    for (auto& child : children) {
        auto& child_node = child.second;
        if (child_node->B_count > 0) continue;  // not sampleable while it holds B points
        double weight = child_node->A_count;
        node_sampler.insert(make_pair(child_node->depth, child_node->id_in_depth), weight, child_node);
        node_total_weight += weight;
    }
}


// Create an empty quadtree over `d`-dimensional points whose root cell has
// side length `delta`; insert_recursive stops descending at `max_depth`.
// Each axis receives a shift drawn uniformly from [0, delta / 2); build()
// subtracts it from the data's minimum corner to place the root grid.
// NOTE(review): the engine is default-constructed (fixed default seed), so
// every QuadTree gets the same shift vector on every run — confirm this
// reproducibility is intended.
QuadTree::QuadTree(int d, double delta, int max_depth)
    : dim(d), delta(delta), max_depth(max_depth) {
    default_random_engine shift_engine;
    uniform_real_distribution<double> shift_dist(0.0, delta / 2);
    random_shift.resize(dim);
    std::generate(random_shift.begin(), random_shift.end(),
                  [&shift_engine, &shift_dist]() { return shift_dist(shift_engine); });
}


// Pack a 0/1 child-index vector into a bitset<N> (bit i <- vec[i]) so it can
// be used as the key of the children unordered_map. Bits at positions
// >= vec.size() stay 0.
// Throws std::out_of_range if vec has more than N entries, and
// std::invalid_argument if any entry is neither 0 nor 1.
template <std::size_t N>
bitset<N> QuadTree::to_bitset(const vector<int>& vec) {
    if (vec.size() > N) throw out_of_range("Error: child_idx size exceeds the limit of DIM bits.");
    bitset<N> packed;
    size_t pos = 0;
    for (const int bit : vec) {
        const bool is_binary = (bit == 0 || bit == 1);
        if (!is_binary) {
            throw invalid_argument("Error: Invalid index value. Got vec[" + to_string(pos) + "] = " + to_string(bit) + ", must be 0 or 1.");
        }
        packed.set(pos, bit == 1);
        ++pos;
    }
    return packed;
}


// Map `point` to the 0/1 index of the child sub-cell it falls into inside the
// cell identified by (id_in_depth, width).
//
// The cell's lower corner is recovered by snapping the point that created the
// node (A[id_in_depth] for ids below B_ID_shift, otherwise
// B[id_in_depth - B_ID_shift]) onto the width-sized grid anchored at
// root_origin. Along each axis the result is 0 for the lower half of the cell
// and 1 for the upper half; points outside the cell yield other values, which
// to_bitset rejects.
//
// @throws std::invalid_argument if `point` has fewer than `dim` coordinates.
// @throws std::out_of_range if the id does not index into A / B.
vector<int> QuadTree::point_to_grid_index(const vector<double>& point, int id_in_depth, double width, std::vector<std::vector<double>>& A, std::vector<std::vector<double>>& B) {
    // Use the runtime dimension (as build() and calculate_current_id() do)
    // rather than the compile-time DIM: point_of_id_to_origin writes one entry
    // per coordinate, so a DIM-sized buffer overflows when dim > DIM, and
    // reading point[i] up to DIM overruns when dim < DIM.
    const size_t n_dims = static_cast<size_t>(dim);
    if (point.size() < n_dims)
        throw std::invalid_argument("Dimension mismatch.");

    std::vector<double> origin(n_dims);
    if (id_in_depth >= B_ID_shift) {
        point_of_id_to_origin(B.at(id_in_depth - B_ID_shift), root_origin, width, origin);
    } else {
        point_of_id_to_origin(A.at(id_in_depth), root_origin, width, origin);
    }

    const double half_width = width / 2.0;
    vector<int> child_idx(n_dims);
    for (size_t i = 0; i < n_dims; ++i) {
        // Truncation toward zero (not floor) keeps a point lying a rounding
        // error below the cell's lower edge in the lower half.
        child_idx[i] = static_cast<int>((point[i] - origin[i]) / half_width);
    }
    return child_idx;
}


// Check if a point in A matches any point in B
bool QuadTree::is_matched(const vector<double>& point_a, const vector<vector<double>>& dataset_b,
                int id_in_depth, double width, vector<vector<double>>& A, vector<vector<double>>& B) {
    if (dataset_b.size() == 0) return false;
    vector<int> child_idx_a = point_to_grid_index(point_a, id_in_depth, width, A, B);
    // Check the points in B. If a B point is found that belongs to the same subgrid, it does not match in the current cell.
    for (const auto& b : dataset_b) {
        if (point_to_grid_index(b, id_in_depth, width, A, B) == child_idx_a) {
            return false;  // They are in the same subgrid, so not a match.
        }
    }
    return true;  // No points from B in the same subgrid, hence it's a match.
}


// Recursively build the HST tree
void QuadTree::build(const vector<vector<double>>& A, const vector<vector<double>>& B, 
    const vector<vector<double>>& c_A, const vector<vector<double>>& c_B) {
    vector<double> min_corner(dim, numeric_limits<double>::max());
    for (const auto& pt : A) {
        for (size_t i = 0; i < dim; ++i)
            min_corner[i] = min(min_corner[i], pt[i]);
    }
    for (const auto& pt : B) {
        for (size_t i = 0; i < dim; ++i)
            min_corner[i] = min(min_corner[i], pt[i]);
    }

    root_origin.resize(dim);
    for (int i = 0; i < dim; ++i){
        root_origin[i] = min_corner[i] - random_shift[i];
        // cout << "root origin: " << root_origin[i] << endl;
    }

    if (c_A.size() == 0 && c_B.size() == 0){
        root = make_shared<HSTNode>(0, 0, delta);
        root->Build_Node_Sampler();
        root->A_count = 0;
        root->B_count = 0;
        root->match_count = 0;
        root->is_leaf = false;
        root->parent = nullptr;
        // cout << "root->width: " << root->width << endl;
    }else{
        shared_ptr<HSTNode> parent_node = nullptr;
        unordered_map<bitset<DIM>, vector<pair<string, vector<double>>>> buckets;
        root = build_subtree(c_A, c_B, 0, root_origin, root_origin, delta, parent_node);
    }
}


// Convert a cell origin into integer grid coordinates relative to
// `root_origin` at resolution `width`. Each ratio
// (current_origin - root_origin) / width must lie within 1e-3 of an integer;
// otherwise a std::runtime_error is thrown.
vector<int> QuadTree::calculate_current_id(const vector<double>& current_origin, const vector<double>& root_origin, double width) {
    constexpr double kIntegerTolerance = 1e-3;  // TODO: 1e-3? 1e-4?
    vector<int> current_id(dim);
    size_t axis = 0;
    for (int& id : current_id) {
        const double quo = (current_origin[axis] - root_origin[axis]) / width;
        const double nearest = round(quo);
        if (fabs(quo - nearest) > kIntegerTolerance) {
            throw runtime_error("Warning: grid id is not an integer, quo = " + to_string(quo));
        }
        id = static_cast<int>(nearest);
        ++axis;
    }
    return current_id;
}


void QuadTree::process_subgrid(shared_ptr<HSTNode>& node, unordered_map<bitset<DIM>, vector<pair<string, vector<double>>>>& buckets,
                               const vector<double>& current_origin, const vector<double>& root_origin, double width, int depth) {
    // For each subgrid, recursively build the tree
    for (const auto& bucket : buckets) {
        vector<double> child_origin = current_origin;
        for (size_t i = 0; i < dim; ++i) {
            child_origin[i] += (bucket.first[i] == 0) ? 0 : width / 2;
        }

        // Create a new HSTNode for this subgrid
        vector<vector<double>> points_A_in_subgrid;
        vector<vector<double>> points_B_in_subgrid;

        // Separate the points into A and B points for the next level
        for (const auto& point : bucket.second) {
            if (point.first == "A") {
                points_A_in_subgrid.push_back(point.second);
            } else if (point.first == "B") {
                points_B_in_subgrid.push_back(point.second);
            }
        }
        // Recursively build the subtree for this subgrid
        auto child_node = build_subtree(points_A_in_subgrid, points_B_in_subgrid,
            depth + 1, child_origin, root_origin, width / 2, node);
        // Only add as child if node was created (has points)
        if (child_node) {
            node->children[bucket.first] = child_node;
            assert(child_node.get() != node.get());
        }
    }
}


// Presumably intended to build the subtree for the cell at `current_origin`
// (side length `width`, depth `depth`) holding points_A / points_B, attached
// under `parent_node`.
// NOTE(review): currently a stub — it ignores every argument and always
// returns nullptr, so build() leaves `root` null whenever c_A or c_B is
// non-empty, and process_subgrid never attaches children. TODO implement.
shared_ptr<HSTNode> QuadTree::build_subtree(const vector<vector<double>>& points_A, const vector<vector<double>>& points_B, 
    int depth, const vector<double>& current_origin, const vector<double>& root_origin, double width, shared_ptr<HSTNode>& parent_node) {

    return nullptr;
}


std::shared_ptr<HSTNode> QuadTree::sample_node_by_weight(std::shared_ptr<HSTNode> node) {
    if (!node || node->node_total_weight <= 0) return nullptr;
    
    std::uniform_real_distribution<> dis(0.0, node->node_total_weight);
    double random_weight = dis(gen_sample);
    
    std::shared_ptr<BBTNode> sampled = node->node_sampler.sample(random_weight);
    
    return sampled ? sampled->hst_ptr : nullptr;
}

double QuadTree::calculateTotalWeight() {
    double total = 0;
    std::function<void(std::shared_ptr<HSTNode>)> calc = [&](std::shared_ptr<HSTNode> node) {
        if (!node) return;
        total += node->width * node->match_count;
        // cout << "depth: " << node->depth << " width: " << node->width << " match_count: " << node->match_count << " weight: " << node->tree_weight << endl;
        for (auto& child : node->children) {
            calc(child.second);
        }
    };
    calc(root);
    return total;
}

void QuadTree::Build_Tree_Sampler() {
    weight_tree.clear();
    total_weight = 0;

    std::function<void(std::shared_ptr<HSTNode>)> build_weights = [&](std::shared_ptr<HSTNode> node) {
        if (!node) return;
        
        double weight = node->width * node->match_count;
        node->tree_weight = weight;
        total_weight += weight;
        
        auto node_id = std::make_pair(node->depth, node->id_in_depth);
        weight_tree.insert(node_id, weight, node);
        
        for (auto& child : node->children) {
            build_weights(child.second);
        }
    };
    
    build_weights(root);
}

int countNonZeroNodes_AVL(const std::shared_ptr<BBTNode>& node, double threshold = 1e-4) {
    if (!node) return 0;

    int count = (node->weight > threshold) ? 1 : 0;
    count += countNonZeroNodes_AVL(node->left, threshold);
    count += countNonZeroNodes_AVL(node->right, threshold);
    return count;
}

int countNonZeroNodes_quadtree(const std::shared_ptr<HSTNode>& node, double threshold = 1e-4) {
    if (!node) return 0;

    int count = (node->tree_weight > threshold) ? 1 : 0;
    for (auto& child : node->children) {
        count += countNonZeroNodes_quadtree(child.second, threshold);
    }

    return count;
}

// Draw a uniform value in [0, total_weight) and let weight_tree map it to an
// HST node. weight_tree entries carry width * match_count (see
// Build_Tree_Sampler), so the sample is presumably proportional to that weight.
// Returns nullptr when there is no root, total_weight is non-positive, or the
// tree yields no entry.
std::shared_ptr<HSTNode> QuadTree::sample_by_weight() {
    if (root == nullptr || total_weight <= 0) {
        return nullptr;
    }

    std::uniform_real_distribution<> pick(0.0, total_weight);
    const double target = pick(gen_sample);

    const std::shared_ptr<BBTNode> chosen = weight_tree.sample(target);
    if (!chosen) {
        return nullptr;
    }
    return chosen->hst_ptr;
}


// Propagate the insertion of A point `ai` from `node` (the node returned by
// insert_recursive) up to the root. For each node it updates the counters,
// the node's entry in its parent's child sampler, and its entry in the global
// weight tree.
//
// @param node   Starting node; the loop follows ->parent until nullptr.
// @param C      Pass 1 (as insert() does). The new A point is credited to the
//               match_count of the first node on the walk up with
//               B_count > 0; C is then cleared so no higher node counts it.
// @param point  Not used by this function.
// @param ai     Index of the inserted A point, recorded in single_A_point.
void QuadTree::insert_update_A(shared_ptr<HSTNode> node, int C, const vector<double>& point, int ai){
    while(node != nullptr) {
        // A leaf that already holds A points just records the newest one;
        // the reset branch below never applies to leaves.
        if (node->A_count != 0 && node->is_leaf) {
            node->single_A_point = ai;
        }
        if (node->A_count == 0) {  // First A point in this node
            node->single_A_point = ai;
        } else if (node->A_count == 1 && !node->is_leaf) {  // Second A point in an internal node: no single representative any more
            node->single_A_point.reset();
        }
        node->A_count += 1;
        if (node->parent) {
            // While this node holds no B points, its parent's sampler weights
            // it by A_count: replace the old entry (A_count - 1) with the new one.
            if (node->B_count == 0) {
                double old_weight = node->A_count - 1;
                double new_weight = node->A_count; 
                node->parent->node_sampler.remove(make_pair(node->depth, node->id_in_depth), old_weight);
                node->parent->node_sampler.insert(make_pair(node->depth, node->id_in_depth), new_weight, node);
                node->parent->node_total_weight += 1;
            }
        }
        
        // Credit the new A point to the first node on the walk up that holds
        // B points, exactly once.
        if(node->B_count > 0 && C == 1) {
            node->match_count += 1;
            C = 0;
        }
        
        // Recompute this node's weight (width * match_count) and keep
        // total_weight in sync with it.
        double old_weight = node->tree_weight;
        total_weight -= old_weight;
        
        double new_weight = node->width * node->match_count;
        total_weight += new_weight;
        node->tree_weight = new_weight;
        
        auto node_id = std::make_pair(node->depth, node->id_in_depth);
        
        // weight_tree is only touched when the weight changed by more than 1e-4.
        if (fabs(old_weight - new_weight) > 1e-4) {
            weight_tree.remove(node_id, old_weight);
            weight_tree.insert(node_id, new_weight, node);
        }
        
        node = node->parent;
    }
}


// Propagate the insertion of a B point from `node` (the node returned by
// insert_recursive; dereferenced unconditionally) up to the root.
//
// `C` carries the number of A points that are now counted at a lower node:
//  - when a node receives its first B point, match_count becomes
//    A_count - C and C becomes A_count for the next ancestor;
//  - otherwise match_count is reduced by C and C is cleared.
// A node that receives its first B point is also removed from its parent's
// child sampler (where it was weighted by A_count). Each node's tree_weight
// (width * match_count), total_weight and weight_tree entry are kept in sync.
//
// @param node  Starting node.
// @param C     Pass 0 (as insert() does).
void QuadTree::insert_update_B(shared_ptr<HSTNode> node, int C){
    do {
        node->B_count += 1;
        const bool first_b = (node->B_count == 1);
        if (first_b) {
            node->match_count = node->A_count - C;
            C = node->A_count;
        } else {
            node->match_count -= C;
            C = 0;
        }

        if (node->parent && first_b) {
            node->parent->node_sampler.remove(
                make_pair(node->depth, node->id_in_depth), node->A_count
            );
            node->parent->node_total_weight -= node->A_count;
        }

        double previous = node->tree_weight;
        total_weight -= previous;
        double updated = node->width * node->match_count;
        node->tree_weight = updated;
        total_weight += updated;

        if (fabs(previous - updated) > 1e-4) {
            auto key = std::make_pair(node->depth, node->id_in_depth);
            weight_tree.remove(key, previous);
            weight_tree.insert(key, updated, node);
        }

        node = node->parent;
    } while (node != nullptr);
}


// Insert `point` (label 'A' or 'B', index point_i) into the tree and propagate
// the count/weight updates up to the root.
//
// @throws std::invalid_argument for a label other than 'A' or 'B'. The label is
//         checked before the tree is modified.
// @throws std::logic_error if the tree has no root (build() was not run, or
//         produced a null root).
void QuadTree::insert(vector<double>& point, char label, int point_i, vector<vector<double>>& A, vector<vector<double>>& B) {
    // Validate first: insert_recursive creates child nodes along the point's
    // path, so rejecting the label afterwards would leave those nodes behind
    // with no counts or weights.
    if (label != 'A' && label != 'B') {
        throw std::invalid_argument("Error: wrong label. Valid labels are 'A' or 'B'.");
    }
    if (!root) {
        throw std::logic_error("QuadTree::insert: tree has no root; call build() first.");
    }

    shared_ptr<HSTNode> leaf_node = insert_recursive(root, point, label, 0, point_i, A, B);
    if (label == 'A') {
        insert_update_A(leaf_node, 1, point, point_i);
    } else {
        insert_update_B(leaf_node, 0);
    }
}


// Descend from `node` to the depth-max_depth node whose cell contains `point`,
// creating missing child nodes on the way, and return that node (marked
// is_leaf).
//
// A newly created child stores `point_id` as its id_in_depth (so that
// point_to_grid_index can later recover its cell origin from that point), has
// half its parent's width, and has `node` as its parent. `label` is passed
// through but not otherwise used here.
// TODO: the decomposition does not continue after reaching the original leaf node.
shared_ptr<HSTNode> QuadTree::insert_recursive(shared_ptr<HSTNode>& node, const vector<double>& point, char label, int depth, int point_id, vector<vector<double>>& A, vector<vector<double>>& B) {
    if (depth >= max_depth) {
        node->is_leaf = true;
        return node;
    }

    const auto idx = point_to_grid_index(point, node->id_in_depth, node->width, A, B);
    const auto key = to_bitset(idx);

    // A single hash lookup instead of find() followed by two operator[] calls.
    auto it = node->children.find(key);
    if (it == node->children.end()) {
        auto child_node = make_shared<HSTNode>(depth + 1, point_id, node->width / 2);
        child_node->Build_Node_Sampler();
        child_node->parent = node;
        assert(child_node.get() != node.get());
        it = node->children.emplace(key, child_node).first;
    }

    // `it->second` is a reference into the map; references to unordered_map
    // elements stay valid across later insertions.
    return insert_recursive(it->second, point, label, depth + 1, point_id, A, B);
}

// Remove `point` (label 'A' or 'B') from the tree: locate its max-depth node,
// propagate the count/weight updates up to the root, and return that node.
//
// @throws std::invalid_argument for a label other than 'A' or 'B'. The label is
//         checked before anything else, consistent with insert().
// @throws std::runtime_error if the point's cell is not in the tree (including
//         when the tree has no root).
shared_ptr<HSTNode> QuadTree::remove(const vector<double>& point, char label, vector<vector<double>>& A, vector<vector<double>>& B) {
    if (label != 'A' && label != 'B') {
        throw invalid_argument("Error: wrong label. Valid labels are 'A' or 'B'.");
    }
    if (!root) {
        // An empty tree cannot contain the point; avoid dereferencing null.
        throw runtime_error("Point not found in tree");
    }

    shared_ptr<HSTNode> leaf_node = remove_recursive(root, point, label, 0, A, B);
    if (label == 'A') {
        delete_update_A(leaf_node, 1);
    } else {
        delete_update_B(leaf_node, 0);
    }
    return leaf_node;
}

// Follow the child cells containing `point` down from `node` and return the
// node reached at depth max_depth. The tree is not modified.
// `label` is passed through but not otherwise used here.
//
// @throws std::runtime_error if a child cell on the path does not exist.
shared_ptr<HSTNode> QuadTree::remove_recursive(shared_ptr<HSTNode> node, const vector<double>& point, char label, int depth, vector<vector<double>>& A, vector<vector<double>>& B) {
    if (depth >= max_depth) {
        return node;
    }

    const auto idx = point_to_grid_index(point, node->id_in_depth, node->width, A, B);
    const auto key = to_bitset(idx);

    // Reuse the find() result instead of a second operator[] lookup.
    const auto it = node->children.find(key);
    if (it == node->children.end()) {
        throw runtime_error("Point not found in tree");
    }
    return remove_recursive(it->second, point, label, depth + 1, A, B);
}

// Propagate the removal of one A point from `node` (the node returned by
// remove_recursive) up to the root, mirroring insert_update_A.
//
// @param node  Starting node; the loop follows ->parent until nullptr.
// @param C     Pass 1 (as remove() does). The removed A point is subtracted
//              from match_count at the first node on the walk up with
//              B_count > 0; C is then cleared.
// @throws std::runtime_error if a node on the path has A_count == 0.
// NOTE(review): unlike delete_update_B, this function never prunes nodes that
// end up with no points.
void QuadTree::delete_update_A(shared_ptr<HSTNode> node, int C) {
    while (node != nullptr) {
        if (node->A_count == 1) { 
            // The last A point leaves: no representative remains.
            node->single_A_point.reset();
        } else if (node->A_count == 2) { 
            // Exactly one A point will remain. The child on the removal path
            // was already updated earlier in this walk, so take the
            // representative from the first child that still has A points
            // (if that child holds exactly one).
            // NOTE(review): a max-depth leaf has no children, so its
            // single_A_point is left as-is and may still name the removed
            // point — confirm this is acceptable.
            for (auto& child : node->children) {
                if (child.second->A_count > 0) {
                    if (child.second->A_count == 1) {
                        node->single_A_point = child.second->single_A_point;
                    }
                    break;
                }
            }
        } else if (node->A_count == 0) {
            // node->single_A_point.reset();
            throw std::runtime_error("Node has no A-class points.");
        }
        node->A_count--;
        if (node->parent) {
            // While this node holds no B points, its parent's sampler weights
            // it by A_count: replace the old entry (A_count + 1) with the new one.
            if (node->B_count == 0) {
                double old_weight = node->A_count + 1;
                double new_weight = node->A_count; 
                node->parent->node_sampler.remove(make_pair(node->depth, node->id_in_depth), old_weight);
                node->parent->node_sampler.insert(make_pair(node->depth, node->id_in_depth), new_weight, node);
                node->parent->node_total_weight -= 1;
            }
        }
        // Un-credit the removed A point at the first node on the walk up that
        // holds B points, exactly once.
        if (node->B_count > 0 && C == 1) {
            node->match_count--;
            C = 0;
        }
        
        // Recompute this node's weight (width * match_count) and keep
        // total_weight in sync with it.
        double old_weight = node->tree_weight;
        total_weight -= old_weight;
        
        double new_weight = node->width * node->match_count;
        node->tree_weight = new_weight;
        total_weight += new_weight;
        
        auto node_id = std::make_pair(node->depth, node->id_in_depth);
        
        // weight_tree is only touched when the weight changed by more than 1e-4.
        if (fabs(old_weight - new_weight) > 1e-4) {
            weight_tree.remove(node_id, old_weight);
            weight_tree.insert(node_id, new_weight, node);
        }
        node = node->parent;
    }
}

// Propagate the removal of one B point from `node` (the node returned by
// remove_recursive) up to the root, mirroring insert_update_B.
//
// `C` carries A points whose match must move to a higher node: when a node
// loses its last B point, its match_count drops to 0 and C = A_count is
// passed up; the first ancestor that still has B points adds C back to its
// match_count and clears C.
//
// Order matters: counts are updated and empty leaves are pruned on the way up
// (the recursive call sits in the middle), while each node's parent-sampler
// re-insertion and weight_tree update run only after all its ancestors have
// been processed.
//
// @param node  Starting node; dereferenced unconditionally.
// @param C     Pass 0 (as remove() does).
// @throws std::logic_error if a node to prune is missing from its parent's children.
// NOTE(review): a pruned node is not removed from weight_tree; if its weight
// changes, it is re-inserted there with weight 0 and stays referenced.
void QuadTree::delete_update_B(shared_ptr<HSTNode> node, int C) {
    node->B_count--;
    if (node->B_count == 0){
        // No B points left here: this node matches nothing; hand all its A
        // points to the nearest ancestor that still has B points.
        node->match_count = 0;
        C = node->A_count;
    }else{
        // Still has B points: take back the A points handed up from below.
        node->match_count += C;
        C = 0;
    }

    auto parent_node = node->parent;
    if (parent_node != nullptr) {
        // Prune a node that is now completely empty from its parent's children.
        if (node -> A_count == 0 && node->B_count == 0 && node->children.empty()){
            auto it = std::find_if(parent_node->children.begin(), parent_node->children.end(),
                                [&](auto& kv){ return kv.second.get() == node.get(); });
            if (it == parent_node->children.end()) throw std::logic_error("prune: child node not found in parent->children map");
            std::bitset<DIM> key = it->first;
            parent_node->children.erase(key);
        }
        this->delete_update_B(parent_node, C);
    }
    

    if (node->parent) {
        // A node that just became B-free (with A points) rejoins its parent's
        // child sampler, weighted by A_count.
        if (node->B_count == 0 && node->A_count != 0) {
            node->parent->node_sampler.insert(
                make_pair(node->depth, node->id_in_depth),
                node->A_count,
                node
            );
            node->parent->node_total_weight += node->A_count;
        }
    }
    // Recompute this node's weight (width * match_count), keep total_weight in
    // sync, and touch weight_tree only if the weight changed by more than 1e-4.
    double old_weight = node->tree_weight;
    total_weight -= old_weight;
    double new_weight = node->width * node->match_count;
    node->tree_weight = new_weight;
    total_weight += new_weight;

    if (fabs(old_weight - new_weight) > 1e-4) {
        auto node_id = std::make_pair(node->depth, node->id_in_depth);
        weight_tree.remove(node_id, old_weight);
        weight_tree.insert(node_id, new_weight, node);
    }
}
