#include "Fragments.h"

#include <utility>

void FragmentCollection::storing_fragments(const map<Fragment, Location>& fragments){
    map<Fragment, vector<Location>> local_fragment_count;

    for (const auto &fragment: fragments) {
        Fragment reversed_fragment = fragment.first;
        std::reverse(reversed_fragment.begin(), reversed_fragment.end());
        if (fragment.first < reversed_fragment) {
            local_fragment_count[fragment.first].push_back(fragment.second);
        } else if(fragment.first > reversed_fragment){
            vector<NodeId> reversed_location = fragment.second;
            std::reverse(reversed_location.begin(), reversed_location.end());
            local_fragment_count[reversed_fragment].push_back(reversed_location);
        }else{
            vector<NodeId> reversed_location = fragment.second;
            std::reverse(reversed_location.begin(), reversed_location.end());
            if(reversed_location<fragment.second){
                local_fragment_count[reversed_fragment].push_back(reversed_location);
            }else{
                local_fragment_count[fragment.first].push_back(fragment.second);
            }
        }
    }

#pragma omp critical
    {
        for (const auto &entry : local_fragment_count) {
            this->fragment_count[entry.first].insert(this->fragment_count[entry.first].end(),entry.second.begin(), entry.second.end());
        }
    }
}

void FragmentCollection::storing_cyclic_fragments(const map<Fragment, Location>& fragments){
    map<Fragment, vector<Location>> local_fragment_cyclic_count;

    for (const auto &fragment: fragments) {
        Fragment new_fragment = fragment.first;
        Fragment new_location = fragment.second;
        tie(new_fragment, new_location) = order_cyclic_fragment_and_location(new_fragment, new_location);
        vector<pair<Fragment, Location>> sub_patterns = find_bodies_of_cyclic_fragment(new_fragment, new_location);
        for(auto sub_pattern:sub_patterns){
            map<Fragment, Location> local_sub_pattern{};
            local_sub_pattern.insert({sub_pattern.first, sub_pattern.second});
            storing_fragments(local_sub_pattern);
        }
        local_fragment_cyclic_count[new_fragment].push_back(new_location);
    }

#pragma omp critical
    {
        for (const auto &entry : local_fragment_cyclic_count) {
            this->fragment_cyclic_count[entry.first].insert(this->fragment_cyclic_count[entry.first].end(),entry.second.begin(), entry.second.end());
        }
    }
}

void FragmentCollection::get_fragments(size_t number_of_random_walks, size_t max_depth, bool force_cyclic){
    unordered_set<NodeId> nodes_set = hg.get_nodes();
    vector<NodeId> nodes(nodes_set.begin(), nodes_set.end());
    nodes.reserve(nodes_set.size());

#pragma omp parallel for schedule(dynamic)
    for(size_t i = 0; i < nodes.size(); i++) {
        NodeId node = nodes[i];
        vector<EdgeId> unaries = unary_binary_edges[node].first;
        vector<EdgeId> binaries = unary_binary_edges[node].second;
        unordered_set<EdgeId> previous_edges{};

        // Reserve space to avoid multiple allocations
        unaries.reserve(number_of_random_walks);
        binaries.reserve(number_of_random_walks);

        // Single Unary predicates
        map<Fragment, Location> initial_unary_fragment{};
        for(auto unary : unaries){
            previous_edges.insert(unary);
            Fragment fragment{edge_to_predicate[unary]};
            Location unary_location{unary};
            initial_unary_fragment.insert({fragment, unary_location});
        }
        storing_fragments(initial_unary_fragment);
        set_initial_fragments(node, unaries, binaries);
        Fragment null_fragment{0};
        Location null_location{};
        initial_unary_fragment.insert({null_fragment,null_location});
        for(auto binary:binaries){
            NodeId next_node = hg.get_next_node(node, binary);
            map<Fragment, Location> final_fragments{};
            for(const auto& fragment:initial_unary_fragment){
                Fragment new_fragment = fragment.first;
                new_fragment.push_back(hg.get_predicate_id(binary));
                Location new_location = fragment.second;
                new_location.push_back(binary);
                final_fragments.insert({new_fragment, new_location});
            }
            size_t left_walks = 1 + ((number_of_random_walks - 1) / binaries.size());
            previous_edges.insert(binary);
            next_step(left_walks,
                      next_node,
                      node,
                      1,
                      final_fragments,
                      previous_edges,
                      max_depth,
                      force_cyclic);
        }
    }
}

/**
 * Stores the two-edge fragments that can be formed directly at a node.
 *
 * Two kinds of pairs are stored:
 *  - every unordered pair of the node's unary edges,
 *  - every unordered pair of binary edges that lead from the node to the same
 *    neighbouring node (as reported by hg.get_next_node).
 *
 * Each pair becomes a fragment of the two edge predicates with the two edge
 * ids as its location. Pairs of one batch are collected in a map keyed by
 * fragment, so when several pairs of the batch yield the same fragment only
 * the first one inserted is kept.
 *
 * @param node     node whose surrounding edges are combined.
 * @param unaries  unary edges of the node.
 * @param binaries binary edges of the node.
 */
void FragmentCollection::set_initial_fragments(NodeId node, vector<EdgeId> unaries, const vector<EdgeId> &binaries) {
    // Builds one fragment per unordered pair of edges and stores the batch.
    const auto store_edge_pairs = [this](const vector<EdgeId> &edges){
        map<Fragment, Location> pair_fragments{};
        for(size_t first = 0; first < edges.size(); first++){
            for(size_t second = first + 1; second < edges.size(); second++){
                Fragment fragment{};
                fragment.push_back(edge_to_predicate[edges[first]]);
                fragment.push_back(edge_to_predicate[edges[second]]);
                Location location{edges[first], edges[second]};
                pair_fragments.insert({fragment, location});
            }
        }
        storing_fragments(pair_fragments);
    };

    store_edge_pairs(unaries);

    // Group the binary edges by the node they lead to.
    map<NodeId,vector<EdgeId>> edges_by_neighbour{};
    for(EdgeId binary : binaries){
        NodeId next = hg.get_next_node(node, binary);
        edges_by_neighbour[next].push_back(binary);
    }
    for(const auto &group : edges_by_neighbour){
        // A single edge to a neighbour cannot form a pair.
        if(group.second.size() > 1){
            store_edge_pairs(group.second);
        }
    }
}

void FragmentCollection::next_step(size_t number_of_random_walks,
                                   NodeId node,
                                   NodeId starting_node,
                                   int depth,
                                   const map<vector<PredicateId>, Location> &fragments,
                                   unordered_set<EdgeId> previous_edges,
                                   size_t max_depth,
                                   bool force_cyclic){
    if((depth == 3 || depth ==4) && starting_node == node){
        map<vector<PredicateId>,Location> cyclic_fragments{};
        for(const auto& fragment:fragments){
            Fragment new_fragment = fragment.first;
            Location new_location = fragment.second;
            cyclic_fragments.insert({new_fragment, new_location});
        }
        storing_cyclic_fragments(cyclic_fragments);
    }
    if(depth <= 3 && depth < max_depth){
        map<vector<PredicateId>, Location> new_fragments{};
        vector<EdgeId> unaries = unary_binary_edges[node].first;
        vector<EdgeId> binaries = unary_binary_edges[node].second;
        vector<EdgeId> new_binaries{};
        vector<EdgeId> new_unaries{};
        std::set_difference(unaries.begin(),unaries.end(),previous_edges.begin(),previous_edges.end(),std::inserter(new_unaries,new_unaries.begin()));
        std::set_difference(binaries.begin(),binaries.end(),previous_edges.begin(),previous_edges.end(),std::inserter(new_binaries,new_binaries.begin()));
        vector<EdgeId> cyclic_binaries{};
        vector<EdgeId> non_cyclic_binaries{};
        vector<EdgeId> final_binaries;
        bool store_fragments = false;
        if(force_cyclic && depth == 2){
            for(auto binary:new_binaries){
                NodeId next_node = hg.get_next_node(node, binary);
                if(next_node == starting_node){
                    cyclic_binaries.push_back(binary);
                }else{
                    non_cyclic_binaries.push_back(binary);
                }
            }
            if(!cyclic_binaries.empty()){
                store_fragments = true;
            }
            if(cyclic_binaries.size()>number_of_random_walks) {
                final_binaries = get_random_elements(cyclic_binaries, number_of_random_walks);
            }else{
                final_binaries = cyclic_binaries;
                if(!non_cyclic_binaries.empty()) {
                    vector<EdgeId> non_cyclic_final_binaries = get_random_elements(non_cyclic_binaries,
                                                                                   number_of_random_walks -
                                                                                   cyclic_binaries.size());
                    final_binaries.insert(final_binaries.end(), non_cyclic_final_binaries.begin(),
                                          non_cyclic_final_binaries.end());
                }
            }
        }else{
            store_fragments = true;
            if(new_binaries.size()>number_of_random_walks){
                final_binaries = get_random_elements(new_binaries, number_of_random_walks);
            }else{
                final_binaries = new_binaries;
            }
        }
        for(const auto& fragment:fragments){
            vector<PredicateId> new_fragment = fragment.first;
            new_fragment.push_back(0);
            new_fragments.insert({new_fragment,fragment.second});
        }
        for(auto unary:new_unaries){
            previous_edges.insert(unary);
            for(const auto& fragment:fragments){
                vector<PredicateId> new_fragment = fragment.first;
                new_fragment.push_back(edge_to_predicate[unary]);
                Location new_location = fragment.second;
                if(!has(fragment.second,unary)){
                    new_location.push_back(unary);
                    new_fragments.insert({new_fragment,new_location});
                }
            }
        }
        if(store_fragments){
            storing_fragments(new_fragments);
        }
        for(auto binary:final_binaries){
            NodeId next_node = hg.get_next_node(node, binary);
            map<vector<PredicateId>,Location> final_fragments{};
            for(const auto& fragment:new_fragments){
                vector<PredicateId> new_fragment = fragment.first;
                new_fragment.push_back(edge_to_predicate[binary]);
                Location new_location = fragment.second;
                new_location.push_back(binary);
                final_fragments.insert({new_fragment,new_location});
            }
            unordered_set<EdgeId> previous_edges_copy = previous_edges;
            previous_edges_copy.insert(binary);
            next_step(1 + ((number_of_random_walks - 1) / final_binaries.size()),
                      next_node,
                      starting_node,
                      depth+1,
                      final_fragments,
                      previous_edges_copy,
                      max_depth,
                      force_cyclic);
        }
    }
}

/**
 * Builds a fragment collection for a hypergraph and immediately collects its
 * fragments.
 *
 * The hypergraph is copied into the collection, the edge-to-predicate table
 * and the per-node (unary edges, binary edges) table are cached from it, and
 * get_fragments is then run with the given parameters.
 *
 * @param hg                      hypergraph to collect fragments from (copied).
 * @param number_of_random_walks  forwarded to get_fragments.
 * @param max_depth               forwarded to get_fragments.
 * @param force_cyclic            forwarded to get_fragments.
 */
FragmentCollection::FragmentCollection(const HyperGraph& hg, size_t number_of_random_walks, size_t max_depth, bool force_cyclic){
    this->hg = hg;
    this->edge_to_predicate = this->hg.get_predicates();
    for(auto node : this->hg.get_nodes()){
        vector<EdgeId> unaries;
        vector<EdgeId> binaries;
        tie(unaries, binaries) = this->hg.get_unary_binary_edges(node);
        // The local vectors are not needed afterwards, so move them into the cache.
        this->unary_binary_edges[node] = make_pair(std::move(unaries), std::move(binaries));
    }
    get_fragments(number_of_random_walks, max_depth, force_cyclic);
}


/**
 * Returns a copy of fragment_count in which duplicate locations of each
 * fragment have been removed. The locations pass through a std::set, so they
 * are returned in ascending order.
 */
map<Fragment, vector<Location>> FragmentCollection::get_fragment_count() {
    map<Fragment, vector<Location>> deduplicated;
    for(const auto &entry : this->fragment_count){
        const set<Location> distinct(entry.second.begin(), entry.second.end());
        deduplicated.emplace(entry.first, vector<Location>(distinct.begin(), distinct.end()));
    }
    return deduplicated;
}

/** Forwards to the hypergraph's is_unary check for the given predicate id. */
bool FragmentCollection::is_unary(PredicateId predicate) {
    const bool unary = hg.is_unary(predicate);
    return unary;
}

/**
 * Returns a copy of fragment_cyclic_count in which duplicate locations of each
 * cyclic fragment have been removed. The locations pass through a std::set,
 * so they are returned in ascending order.
 */
map<Fragment, vector<Location>> FragmentCollection::get_fragment_cyclic_count() {
    map<Fragment, vector<Location>> deduplicated;
    for(const auto &entry : this->fragment_cyclic_count){
        const set<Location> distinct(entry.second.begin(), entry.second.end());
        deduplicated.emplace(entry.first, vector<Location>(distinct.begin(), distinct.end()));
    }
    return deduplicated;
}

/** Forwards to the hypergraph's get_predicate_from_id for the given predicate id. */
string FragmentCollection::get_predicate_from_id(PredicateId predicate) {
    string predicate_name = hg.get_predicate_from_id(predicate);
    return predicate_name;
}

/** Forwards to the hypergraph's get_edge and returns the node ids it reports for the edge. */
vector<NodeId> FragmentCollection::get_edge(EdgeId edge) {
    vector<NodeId> edge_nodes = hg.get_edge(edge);
    return edge_nodes;
}

/**
 * Forwards to the hypergraph's get_predicate_id for the given predicate.
 * The by-value argument is moved into the call.
 */
PredicateId FragmentCollection::get_predicate_id(Predicate predicate) {
    const PredicateId predicate_id = hg.get_predicate_id(std::move(predicate));
    return predicate_id;
}

/** Returns whether the predicate that the hypergraph reports for this edge is unary. */
bool FragmentCollection::is_unary_edge(EdgeId edge) {
    const PredicateId edge_predicate = hg.get_predicate_id(edge);
    return is_unary(edge_predicate);
}

/**
 * Returns the cached (unary edges, binary edges) pair of a node.
 *
 * A node without a cached entry yields a pair of empty vectors. The lookup
 * uses find() so that asking about an unknown node does not insert an empty
 * entry into the cache, as operator[] would.
 *
 * @param node node to look up.
 * @return copy of the node's unary and binary edge lists.
 */
pair<vector<EdgeId>, vector<EdgeId>> FragmentCollection::get_unary_binary_edges(NodeId node) {
    const auto found = this->unary_binary_edges.find(node);
    if (found == this->unary_binary_edges.end()) {
        return pair<vector<EdgeId>, vector<EdgeId>>{};
    }
    return found->second;
}

/**
 * Replaces the stored fragment_count table.
 *
 * @param new_fragment_count replacement table; taken by value and moved into
 *        the member, so no second copy of the table is made.
 */
void FragmentCollection::set_fragment_count(map<Fragment, vector<Location>> new_fragment_count) {
    this->fragment_count = std::move(new_fragment_count);
}

/**
 * Replaces the stored fragment_cyclic_count table.
 *
 * @param new_fragment_cyclic_count replacement table; taken by value and moved
 *        into the member, so no second copy of the table is made.
 */
void FragmentCollection::set_fragment_cyclic_count(map<Fragment, vector<Location>> new_fragment_cyclic_count) {
    this->fragment_cyclic_count = std::move(new_fragment_cyclic_count);
}

/** Default constructor: all members are default-initialised and no fragments are collected. */
FragmentCollection::FragmentCollection() = default;
