
#pragma once

#include <bits/stdc++.h>
#include "basic.hpp"

using namespace std;


// Reorders a parent-pointer tree into DFS pre-order.
//
// Each element of `edges` is a (node, parent) pair. The entry whose parent is
// -1 is treated as the root; if several such entries exist, the last one in
// `edges` wins. Parent ids are used as indices into a table of size
// edges.size(), so every id must lie in [0, edges.size()).
//
// edges:   (node, parent) pairs describing the tree; not modified here.
// returns: (node, parent) pairs in pre-order, siblings kept in input order.
//          Entries not reachable from the root are omitted. An empty vector is
//          returned when `edges` has no root entry (including empty input).
std::vector<std::pair<int, int>> reorder_single_DFS(std::vector<std::pair<int, int>> &edges) {
    const int n = static_cast<int>(edges.size());
    std::vector<std::vector<int>> children(n);
    int root = -1;

    // Bucket every non-root node under its parent and remember the root.
    for (const auto &[node, parent] : edges) {
        if (parent == -1) root = node;
        else children[parent].push_back(node);
    }

    std::vector<std::pair<int, int>> result;

    // Without a root there is nothing to traverse; starting the walk anyway
    // would index children[-1].
    if (root == -1) return result;

    result.reserve(n);

    std::stack<std::pair<int, int>> s;
    s.push({root, -1});

    while (!s.empty()) {
        auto [node, parent] = s.top();
        s.pop();
        result.push_back({node, parent});

        // Push children in reverse so the first child is popped first.
        const auto &childList = children[node];
        for (auto it = childList.rbegin(); it != childList.rend(); ++it) s.push({*it, node});
    }

    return result;
}


// vector<pair<int, int>> reorder_multiple_DFS(vector<pair<int, int>> &edges) {
//     int n = edges.size();
//     vector<vector<int>> children(n);
//     vector<int> roots;

//     for (const auto &edge : edges) {
//         int node = edge.first; int parent = edge.second;
//         if (parent == -1)roots.push_back(node);
//         else children[parent].push_back(node);
//     }

//     vector<pair<int, int>> result;
//     result.reserve(n);

//     for (int root : roots) {
//         stack<pair<int, int>> s;
//         s.push({root, -1});
//         while (!s.empty()) {
//             auto [node, parent] = s.top();
//             s.pop();
//             result.push_back({node, parent});
            
//             const auto &childList = children[node];
//             for (int i = childList.size() - 1; i >= 0; --i)s.push({childList[i], node});
//         }
//     }
    
//     return result;
// }

// Rewrites `edges` in place as the DFS pre-order of the forest it describes.
//
// Each element is a (node, parent) pair; entries with parent -1 are roots and
// their trees are emitted one after another in the order the roots appear in
// the input. Siblings keep their input order. Parent ids index a table of size
// edges.size(). Entries not reachable from any root are dropped from `edges`.
void reorder_multiple_DFS(vector<pair<int, int>>& edges) {
    const int total = edges.size();

    // adjacency[p] holds the nodes whose parent is p, in input order.
    vector<vector<int>> adjacency(total);
    vector<int> forestRoots;
    for (const auto& [child, par] : edges) {
        if (par != -1) adjacency[par].push_back(child);
        else forestRoots.push_back(child);
    }

    vector<pair<int, int>> ordered;
    ordered.reserve(total);

    // A single explicit stack serves every tree: it is empty again whenever a
    // tree has been fully walked.
    stack<pair<int, int>> pending;
    for (const int treeRoot : forestRoots) {
        pending.push({treeRoot, -1});
        while (!pending.empty()) {
            const pair<int, int> current = pending.top();
            pending.pop();
            ordered.push_back(current);

            // Reverse push order so the first child is visited first.
            const vector<int>& kids = adjacency[current.first];
            for (auto it = kids.rbegin(); it != kids.rend(); ++it)
                pending.push({*it, current.first});
        }
    }

    edges = std::move(ordered);
}




// DFS pre-order reordering of a single tree that may re-attach a node to the
// node emitted immediately before it.
//
// `edges` holds (node, parent) pairs; the entry with parent -1 is the root (the
// last such entry wins). Parent ids index a table of size edges.size(). For
// every non-root node the two values
//     get_distance(query, dim, node, parent) * alpha
//     get_distance(query, dim, node, prev)
// are compared, where `prev` is the node emitted just before. When the first is
// strictly greater, the node is recorded as (node, prev) and counted as
// "flattened"; otherwise it is recorded as (node, parent).
//
// edges:   (node, parent) pairs describing the tree; not modified here.
// query:   forwarded unchanged to get_distance.
// dim:     forwarded unchanged to get_distance.
// alpha:   multiplier applied to the parent-side distance before comparing.
// returns: the reordered (node, recorded parent) pairs; empty when `edges` has
//          no root entry (including empty input).
//
// Prints "num flattened : <count>" to stdout before returning.
//
// NOTE(review): get_distance is declared elsewhere (presumably basic.hpp); its
// exact semantics are not visible here.
// NOTE(review): for the first child visited right after its parent, prev ==
// parent, so with alpha > 1 and a positive distance the node is counted as
// flattened even though the recorded parent is unchanged - confirm intended.
std::vector<std::pair<int, int>> reorder_single_optimized(std::vector<std::pair<int, int>> &edges, float* query, int dim, float alpha=1.05) {
    const int n = static_cast<int>(edges.size());
    std::vector<std::vector<int>> children(n);
    int root = -1;

    // Bucket every non-root node under its parent and remember the root.
    for (const auto &[node, parent] : edges) {
        if (parent == -1) root = node;
        else children[parent].push_back(node);
    }

    std::vector<std::pair<int, int>> result;
    result.reserve(n);

    std::stack<std::pair<int, int>> s;
    // Only start the walk when a root exists; pushing {-1, -1} would make the
    // loop below index children[-1].
    if (root != -1) s.push({root, -1});

    int prev = -1;   // node emitted on the previous iteration
    int count = 0;   // entries recorded with prev instead of their tree parent
    while (!s.empty()) {
        auto [node, parent] = s.top();
        s.pop();

        int recorded = parent;
        // The root (parent == -1) is always emitted as-is; prev is valid for
        // every other node because the root is emitted first.
        if (parent != -1 &&
            get_distance(query, dim, node, parent) * alpha > get_distance(query, dim, node, prev)) {
            recorded = prev;
            ++count;
        }
        result.push_back({node, recorded});
        prev = node;

        // Push children in reverse so the first child is popped first.
        const auto &childList = children[node];
        for (auto it = childList.rbegin(); it != childList.rend(); ++it) s.push({*it, node});
    }

    std::cout << "num flattened : " << count << std::endl;
    return result;
}

// Forest variant of the "optimized" reordering: walks every tree in DFS
// pre-order and, for each non-root node, records either its tree parent or the
// node emitted just before it within the same tree.
//
// `edges` holds (node, parent) pairs; entries with parent -1 are roots and are
// processed in input order. Parent ids index a table of size edges.size().
// A non-root node is recorded as (node, previous) - and counted as flattened -
// when get_distance(query, dim, node, parent) * alpha is strictly greater than
// get_distance(query, dim, node, previous); otherwise it keeps (node, parent).
// `query` and `dim` are forwarded unchanged to get_distance, which is declared
// outside this block.
//
// Prints "num flattened : <count>" to stdout and returns the reordered pairs.
vector<pair<int, int>> reorder_multiple_optimized(vector<pair<int, int>> &edges, float* query, int dim, float alpha=1.05) {
    const int total = edges.size();

    // kidsOf[p] lists the nodes whose parent is p, in input order.
    vector<vector<int>> kidsOf(total);
    vector<int> treeRoots;
    for (const auto &[child, par] : edges) {
        if (par != -1) kidsOf[par].push_back(child);
        else treeRoots.push_back(child);
    }

    vector<pair<int, int>> ordered;
    ordered.reserve(total);

    int flattened = 0;

    // The stack is empty again after each tree, so one instance is reused.
    stack<pair<int, int>> pending;
    for (const int treeRoot : treeRoots) {
        int lastVisited = -1;  // restarted for every tree
        pending.push({treeRoot, -1});

        while (!pending.empty()) {
            const pair<int, int> top = pending.top();
            pending.pop();
            int cur = top.first;
            int par = top.second;

            // Default to the tree parent; switch to the previously emitted
            // node only for non-roots that satisfy the distance test.
            int attachTo = par;
            if (par != -1 &&
                get_distance(query, dim, cur, par) * alpha > get_distance(query, dim, cur, lastVisited)) {
                attachTo = lastVisited;
                ++flattened;
            }
            ordered.push_back({cur, attachTo});
            lastVisited = cur;

            // Reverse push order so the first child is visited first.
            const vector<int> &kids = kidsOf[cur];
            for (auto it = kids.rbegin(); it != kids.rend(); ++it)
                pending.push({*it, cur});
        }
    }

    cout << "num flattened : " << flattened << endl;
    return ordered;
}



/* Multithreaded reordering: a hash-map-based DFS plus an OpenMP driver that applies it to each edge list. */
// Rewrites `edges` in place as the DFS pre-order of the forest it describes,
// keying child lists by node id in a hash map so ids do not have to index a
// table of size edges.size().
//
// Each element is a (node, parent) pair; entries with parent -1 are roots and
// their trees are emitted in the order the roots appear. Siblings keep their
// input order. Entries not reachable from any root are dropped.
void reorder_multiple_DFS_hash(std::vector<std::pair<int,int>>& edges) {
    // Child lists keyed by parent id.
    std::unordered_map<int, std::vector<int>> kidsByParent;
    kidsByParent.reserve(edges.size());
    std::vector<int> treeRoots;

    for (const auto& [child, par] : edges) {
        if (par != -1) kidsByParent[par].push_back(child);
        else treeRoots.push_back(child);
    }

    std::vector<std::pair<int,int>> ordered;
    ordered.reserve(edges.size());

    // The stack drains completely per tree, so a single instance is reused.
    std::stack<std::pair<int,int>> pending;
    for (const int treeRoot : treeRoots) {
        pending.push({treeRoot, -1});
        while (!pending.empty()) {
            const std::pair<int,int> visit = pending.top();
            pending.pop();
            ordered.push_back(visit);

            const auto found = kidsByParent.find(visit.first);
            if (found == kidsByParent.end()) continue;  // leaf: no children recorded

            // Reverse push order so the first child is visited first.
            const std::vector<int>& kids = found->second;
            for (auto rit = kids.rbegin(); rit != kids.rend(); ++rit)
                pending.push({*rit, visit.first});
        }
    }

    edges = std::move(ordered);
}


// Reorders every edge list in `all_edges` in place by calling
// reorder_multiple_DFS_hash on each one. The loop is annotated for OpenMP with
// dynamic scheduling; each iteration touches only its own element of
// `all_edges`.
//
// all_edges: one (node, parent) edge list per graph; each is rewritten in place.
// n_core:    number of threads requested for the parallel region. A
//            non-positive value leaves the choice to the OpenMP runtime
//            (num_threads requires a positive value).
void parallel_reorder_multiple_DFS(std::vector<std::vector<std::pair<int, int>>>& all_edges,
                                   int n_core)
{
    const int G = static_cast<int>(all_edges.size());

    if (n_core > 0) {
        // Honor the caller's thread budget.
        #pragma omp parallel for schedule(dynamic) num_threads(n_core)
        for (int gi = 0; gi < G; ++gi) {
            reorder_multiple_DFS_hash(all_edges[gi]);
        }
    } else {
        // No usable thread count supplied: fall back to the runtime default.
        #pragma omp parallel for schedule(dynamic)
        for (int gi = 0; gi < G; ++gi) {
            reorder_multiple_DFS_hash(all_edges[gi]);
        }
    }
}