#include "CandidateRules.h"

#include <utility>

// Sort predicate: orders rules by descending utility.
// Returns true when rule1's utility is strictly greater than rule2's.
bool compare_rule_utility(const Rule& rule1, const Rule& rule2){
    return rule2.utility < rule1.utility;
}
// Sort predicate: orders rules by descending fidelity.
// Returns true when rule1's fidelity is strictly greater than rule2's.
bool compare_rule_fidelity(const Rule& rule1, const Rule& rule2){
    return rule2.fidelity < rule1.fidelity;
}
// Formats `value` in fixed-point notation with exactly two digits after
// the decimal point (e.g. 2.5 -> "2.50").
std::string doubleToStringWithTwoDecimals(double value) {
    std::ostringstream formatter;
    formatter.setf(std::ios_base::fixed, std::ios_base::floatfield);
    formatter.precision(2);
    formatter << value;
    return formatter.str();
}
// Returns the (non-cyclic) fragments that qualify as rule candidates.
//
// A fragment is kept when all of the following hold:
//   * it was found at more than 5 locations,
//   * its first or its last predicate is one of `target_predicates`,
//   * its shape is accepted (see the two shape tests below).
// The result is de-duplicated and ordered through a std::set.
vector<Fragment> CandidateRuleCollection::get_constrained_fragments(const unordered_set<PredicateId>& target_predicates) {
    set<Fragment> accepted{};
    for(const auto& entry:fragment_collection.get_fragment_count()){
        const auto& candidate = entry.first;
        const auto& occurrences = entry.second;
        if(occurrences.size() <= 5){
            continue;
        }
        // Longer than two elements with a real (non-null) predicate at both ends.
        const bool long_with_real_ends =
                candidate.front() != 0 && candidate.back() != 0 && candidate.size() > 2;
        // NOTE(review): this test demands front == 0, back == 0 AND front != back,
        // which cannot all hold, so two-element fragments are never accepted.
        // Kept exactly as written; confirm whether `!= 0` was intended.
        const bool two_element_shape =
                candidate.size() == 2 && candidate.front() == 0 && candidate.back() == 0
                && candidate.front() != candidate.back();
        if(!(long_with_real_ends || two_element_shape)){
            continue;
        }
        const bool touches_target =
                has(target_predicates, candidate.front()) || has(target_predicates, candidate.back());
        if(touches_target){
            accepted.insert(candidate);
        }
    }
    return vector<Fragment>(accepted.begin(), accepted.end());
}
// Returns the cyclic fragments that were found at 10 or more locations,
// de-duplicated and ordered through a std::set.
vector<Fragment> CandidateRuleCollection::get_constrained_cyclic_fragments(){
    set<Fragment> frequent_cycles{};
    for(const auto& entry:fragment_collection.get_fragment_cyclic_count()){
        if(entry.second.size() < 10){
            continue;
        }
        frequent_cycles.insert(entry.first);
    }
    return vector<Fragment>(frequent_cycles.begin(), frequent_cycles.end());
}

// Builds the candidate rules from the mined fragments.
//
// Parameters:
//   fragments         fragment collection; taken by value and moved into the
//                     `fragment_collection` member once all statistics that
//                     read it have been gathered.
//   min_fidelity      a rule is kept only if its fidelity is strictly greater.
//   target_predicates names of the predicates that may become rule heads.
//
// Steps:
//   1. Count the edges of every target predicate and derive the per-predicate
//      scaling factor total_count / count (stored in `prior_probability`).
//   2. Turn every constrained linear fragment into HEAD <- BODY rules (head at
//      the front and/or at the back) in an OpenMP parallel loop.
//   3. Turn every constrained cyclic fragment (6 or 8 elements) into rules.
//   4. Compute coverage and the stand-alone utility of each candidate rule.
//
// The shared maps are only READ inside the parallel loop (via find-based
// lookups); a missing key behaves like an empty location list / a factor of 0,
// which is the value the inserting operator[] would have produced.
CandidateRuleCollection::CandidateRuleCollection(FragmentCollection fragments, double min_fidelity, const vector<string>& target_predicates) {
    unordered_set<PredicateId> target_predicates_ids{};
    unordered_map<PredicateId, double> target_counts{};
    double total_count{0};
    for(const auto& predicate:target_predicates){
        target_predicates_ids.insert(fragments.get_predicate_id(predicate));
    }
    map<Fragment, vector<Location>> fragment_count = fragments.get_fragment_count();
    map<Fragment, vector<Location>> fragment_cyclic_count = fragments.get_fragment_cyclic_count();

    // Everything below reads `fragments`, so it must run BEFORE the object is
    // moved into the member (a moved-from collection has unspecified content).
    for(const auto& edge:fragments.hg.get_predicates()){
        if(has(target_predicates_ids,edge.second)){
            target_counts[edge.second]++;
            total_count++;
        }
    }
    size_t not_empty_nodes{0};
    // Guarded: dereferencing begin() of an empty set is undefined behaviour.
    if(!target_predicates_ids.empty() && fragments.is_unary(*target_predicates_ids.begin())){
        for(auto node:fragments.hg.get_nodes()){
            auto members = fragments.hg.get_memberships(node);
            for(auto member:members){
                if(fragments.is_unary_edge(member)){
                    not_empty_nodes++;
                    break;
                }
            }
        }
    }
    total_count += (double)(fragments.hg.number_of_nodes() - not_empty_nodes);
    // Despite its name this holds total_count / count per target predicate.
    unordered_map<PredicateId, double > prior_probability{};
    for(const auto& predicate:target_counts){
        prior_probability[predicate.first]=total_count/predicate.second;
    }

    fragment_collection = std::move(fragments);
    this->fragment_collection.set_fragment_count(fragment_count);
    this->fragment_collection.set_fragment_cyclic_count(fragment_cyclic_count);

    // Read-only lookups: never insert, hence safe to call from several threads.
    const vector<Location> no_locations{};
    const auto locations_in = [&no_locations](const map<Fragment, vector<Location>>& counts,
                                              const Fragment& key) -> const vector<Location>& {
        const auto found = counts.find(key);
        return found == counts.end() ? no_locations : found->second;
    };
    const auto prior_of = [&prior_probability](PredicateId predicate) -> double {
        const auto found = prior_probability.find(predicate);
        return found == prior_probability.end() ? 0.0 : found->second;
    };

    cout << fragment_count.size()+fragment_cyclic_count.size() << " Fragments found" << endl;
    vector<Fragment> constrained_fragments = get_constrained_fragments(target_predicates_ids);
    vector<Fragment> constrained_cyclic_fragments = get_constrained_cyclic_fragments();
    cout << constrained_fragments.size() + constrained_cyclic_fragments.size() << " Constrained Fragments found" << endl;
    Timer normal_fragments("Normal Fragments");
    set<Fragment> unique_rules{};
    const int constrained_total = (int)constrained_fragments.size();
#pragma omp parallel for schedule(dynamic)
    for(int i =0; i<constrained_total; i++){
        const Fragment& fragment = constrained_fragments[i];
        const vector<Location>& fragment_locations = locations_in(fragment_count, fragment);
        auto TT_count = (double)fragment_locations.size();
        //Option 1 HEAD <- BODY
        // IF HEAD IN TARGETS
        size_t symmetry_factor = 1;
        Fragment reversed_fragment = fragment;
        reverse(reversed_fragment.begin(),reversed_fragment.end());
        if(has(target_predicates_ids, fragment.front())){
            Fragment head_1{fragment.front()};
            Fragment body_1(fragment.begin() + 1, fragment.end());
            if(fragment == reversed_fragment){
                symmetry_factor = 2;
            }
            body_1.insert(body_1.begin(),0); //Replace removed unary edge by NULL Predicate
            body_1 = order_fragment(body_1); // Sort body to canonical shape
            // Degenerate case of Unary->Unary or Binary-> Binary Rule
            if(fragment.size()==2){
                if(fragment_collection.is_unary(fragment.front())){
                    body_1 = Fragment{fragment.back()};
                }else{
                    head_1 = Fragment{0,fragment.front(),0};
                    body_1 = Fragment{0,fragment.back(), 0};
                }
            }

            double TF_count_1 = (double)locations_in(fragment_count, body_1).size() - TT_count;
            double fidelity_1 = (TT_count*symmetry_factor*prior_of(head_1[0]))/(TF_count_1 + TT_count);
            if(fidelity_1 > min_fidelity) {
                Rule rule_front{};
                rule_front.body = body_1;
                rule_front.head = head_1;
                rule_front.rule = fragment;
                rule_front.fidelity = fidelity_1;
                rule_front.locations = fragment_locations;
#pragma omp critical
                {
                    candidate_rules[rule_front.rule[0]].push_back(rule_front);
                }
            }
        }

        if(has(target_predicates_ids, fragment.back())){
            if(reversed_fragment != fragment){
                //Option 2 BODY -> HEAD
                Fragment rule_2 = fragment;
                reverse(rule_2.begin(),rule_2.end()); // Enforce HEAD <- BODY shape, as body might get reversed for utility calculation
                Fragment head_2{rule_2.front()};
                Fragment body_2(rule_2.begin() + 1, rule_2.end());
                body_2.insert(body_2.begin(),0); //Replace removed unary edge by NULL Predicate
                body_2 = order_fragment(body_2); // Sort body to canonical shape

                // Degenerate case of Unary->Unary or Binary-> Binary Rule
                if(rule_2.size()==2){
                    if(fragment_collection.is_unary(rule_2.front())){
                        body_2 = Fragment{rule_2.back()};
                    }else{
                        head_2 = Fragment{0,rule_2.front(),0};
                        body_2 = Fragment{0,rule_2.back(), 0};
                    }
                }
                double TF_count_2 = (double)locations_in(fragment_count, body_2).size() - TT_count;
                double fidelity_2 = (TT_count*prior_of(head_2[0]))/(TF_count_2 + TT_count);
                if(fidelity_2 > min_fidelity){
                    Rule rule_back{};
                    rule_back.body=body_2;
                    rule_back.head=head_2;
                    rule_back.rule=rule_2;
                    rule_back.fidelity=fidelity_2;
                    // Reorder locations to HEAD <- BODY shape
                    for(const auto& location: fragment_locations){
                        Location reversed_location = location;
                        reverse(reversed_location.begin(),reversed_location.end());
                        rule_back.locations.push_back(reversed_location);
                    }
#pragma omp critical
                    {
                        candidate_rules[rule_back.rule[0]].push_back(rule_back);
                    }
                }
            }
        }
    }
    normal_fragments.Stop();
    Timer cyclic_fragments_timer("Cyclic Fragments");
    // This loop runs serially: `unique_rules` is shared between iterations and
    // is read/written without synchronisation.
    for(size_t j = 0; j < constrained_cyclic_fragments.size(); j++) {
        set<Fragment> tested_fragments{};
        Fragment fragment = constrained_cyclic_fragments[j];
        const vector<Location>& cycle_locations = locations_in(fragment_cyclic_count, fragment);
        auto TT_count = (double)cycle_locations.size();
        if(fragment.size()==6){
            // NOTE(review): declared outside the head loop, so a factor set for
            // one head position carries over to the following positions of the
            // same fragment. Kept as-is; confirm whether a per-head reset is intended.
            double symmetry_factor = 1;
            for(int i=0; i<6; i++) {
                if (has(target_predicates_ids, fragment[i])) {
                    Fragment head_cyclic{fragment[i]};
                    Fragment body_cyclic{fragment[(i + 1) % 6], fragment[(i + 2) % 6], fragment[(i + 3) % 6],
                                         fragment[(i + 4) % 6], fragment[(i + 5) % 6]};
                    Fragment reversed_body = body_cyclic;
                    std::reverse(reversed_body.begin(), reversed_body.end());
                    bool reverse_location = false;
                    if (body_cyclic > reversed_body) { // Find canonical order
                        body_cyclic = reversed_body;
                        reverse_location = true;
                    }
                    size_t symmetry_count = count(body_cyclic.begin(), body_cyclic.end(), head_cyclic[0]);
                    if(symmetry_count == 1){
                        symmetry_factor = 2;
                    }else if(symmetry_count == 2){
                        symmetry_factor = 3;
                    }
                    Fragment rule{head_cyclic[0], body_cyclic[0], body_cyclic[1], body_cyclic[2], body_cyclic[3], body_cyclic[4]};
                    if (i % 2 == 0) { // i.e. the head is a unary predicate
                        body_cyclic.push_back(0);
                        body_cyclic.insert(body_cyclic.begin(), 0);
                    }
                    double TF_count_cyclic = (double)locations_in(fragment_count, body_cyclic).size() - TT_count;
                    double fidelity_cyclic = (TT_count *symmetry_factor*prior_of(head_cyclic[0]))/ (TF_count_cyclic + TT_count);
                    if (fidelity_cyclic > min_fidelity && !has(tested_fragments, rule)) {
                        tested_fragments.insert(rule);
                        unique_rules.insert(rule);
                        Rule rule_cyclic{};
                        rule_cyclic.body = body_cyclic;
                        rule_cyclic.head = head_cyclic;
                        rule_cyclic.rule = rule;
                        rule_cyclic.fidelity = fidelity_cyclic;
                        for(const auto& location:cycle_locations){
                            Location ordered_location{};
                            Fragment ordered_fragment{};
                            if(reverse_location){
                                vector<size_t> order{(size_t)(i)%6, (size_t)(i+5)%6, (size_t)(i+4)%6, (size_t)(i+3)%6, (size_t)(i+2)%6, (size_t)(i+1)%6};
                                tie(ordered_fragment, ordered_location) = reorder_cyclic_fragment(fragment,location, {order});

                            }else{
                                vector<size_t> order{(size_t)(i)%6, (size_t)(i+1)%6, (size_t)(i+2)%6, (size_t)(i+3)%6, (size_t)(i+4)%6, (size_t)(i+5)%6};
                                tie(ordered_fragment, ordered_location) = reorder_cyclic_fragment(fragment,location, {order});
                            }
                            rule_cyclic.locations.push_back(ordered_location);
                        }
#pragma omp critical
                        {
                            candidate_rules[rule_cyclic.rule[0]].push_back(rule_cyclic);
                        }
                    }
                }
            }
        }else if(fragment.size() >= 8){
            // Handled as an 8-element cycle; the size guard prevents reading
            // fragment[0..7] out of bounds on shorter fragments.
            for(int i=0; i<8; i++) {
                if (fragment[i] != 0 && i%2 !=0 && has(target_predicates_ids, fragment[i])) {
                    Fragment head_cyclic{fragment[i]};
                    Fragment body_cyclic{fragment[(i + 1) % 8], fragment[(i + 2) % 8], fragment[(i + 3) % 8],
                                         fragment[(i + 4) % 8], fragment[(i + 5) % 8],
                                         fragment[(i + 6) % 8], fragment[(i + 7) % 8]};
                    Fragment reversed_body = body_cyclic;
                    std::reverse(reversed_body.begin(), reversed_body.end());
                    bool reverse_location = false;
                    if (body_cyclic > reversed_body) { // Find canonical order
                        body_cyclic = reversed_body;
                        reverse_location = true;
                    }

                    Fragment rule{head_cyclic[0], body_cyclic[0], body_cyclic[1], body_cyclic[2], body_cyclic[3], body_cyclic[4], body_cyclic[5], body_cyclic[6]};
                    // Fall back to the other orientation when the canonical body was never counted.
                    const vector<Location>& body_locations = locations_in(fragment_count, body_cyclic);
                    double TF_count_cyclic;
                    if(!body_locations.empty()){
                        TF_count_cyclic = (double)body_locations.size() - TT_count;
                    }else{
                        TF_count_cyclic = (double)locations_in(fragment_count, reversed_body).size() - TT_count;
                    }

                    double fidelity_cyclic = TT_count / (TF_count_cyclic + TT_count);
                    if (fidelity_cyclic > min_fidelity && !has(unique_rules, rule)) {
                        Rule rule_cyclic{};
                        rule_cyclic.body = body_cyclic;
                        rule_cyclic.head = head_cyclic;
                        rule_cyclic.rule = rule;
                        rule_cyclic.fidelity = fidelity_cyclic;
                        for(const auto& location:cycle_locations){
                            Location ordered_location{};
                            Fragment ordered_fragment{};
                            if(reverse_location){
                                vector<size_t> order{(size_t)(i)%8, (size_t)(i+7)%8, (size_t)(i+6)%8, (size_t)(i+5)%8, (size_t)(i+4)%8, (size_t)(i+3)%8, (size_t)(i+2)%8, (size_t)(i+1)%8};
                                tie(ordered_fragment, ordered_location) = reorder_cyclic_fragment(fragment,location, {order});

                            }else{
                                vector<size_t> order{(size_t)(i)%8, (size_t)(i+1)%8, (size_t)(i+2)%8, (size_t)(i+3)%8, (size_t)(i+4)%8, (size_t)(i+5)%8, (size_t)(i+6)%8, (size_t)(i+7)%8};
                                tie(ordered_fragment, ordered_location) = reorder_cyclic_fragment(fragment,location, {order});
                            }
                            rule_cyclic.locations.push_back(ordered_location);
                        }
#pragma omp critical
                        {
                            candidate_rules[rule_cyclic.rule[0]].push_back(rule_cyclic);
                        }
                    }
                }
            }
        }
    }
    cyclic_fragments_timer.Stop();
    Timer initial_utility_calculation("Initial Utility Calculation");
    Timer coverage_calculation("Coverage Calculation");
    size_t total_number_of_rules{0};
    for(auto & [predicate, rules]:candidate_rules){
        total_number_of_rules += rules.size();
    }
    cout << "Total number of rules " << total_number_of_rules << endl;
    calculate_coverage();
    coverage_calculation.Stop();
    Timer utility_calculation("Utility Calculation");
    // Stand-alone utility: each rule is scored as a one-element rule set.
    for(auto& [predicate, rules]:candidate_rules){
        for(auto& rule:rules){
            rule.utility = calculate_utility({rule}).first;
        }
    }
    utility_calculation.Stop();
    initial_utility_calculation.Stop();
}

// Fills `covered_locations` and `covered_locations_count` of every candidate rule.
//
// For each location of a rule, the first edge of the location (location[0])
// is added to the rule's `covered_locations` and its counter in
// `covered_locations_count` is incremented.
//
// Rules are updated in place. Each loop iteration touches exactly one rule,
// so no critical section (and no temporary copy of the rules) is required.
void CandidateRuleCollection::calculate_coverage() {
    // Plain pair access instead of structured bindings: the rules reference is
    // used inside the OpenMP region below.
    for(auto& predicate_entry:this->candidate_rules){
        auto& rules = predicate_entry.second;
        const size_t rule_total = rules.size();
#pragma omp parallel for schedule(dynamic)
        for(size_t i=0; i<rule_total; i++){
            auto& rule = rules[i];
            for(const auto& location:rule.locations){
                if(location.empty()){
                    continue; // nothing to cover; avoids reading location[0] out of bounds
                }
                const auto covered_edge = location[0];
                rule.covered_locations_count[covered_edge]++;
                rule.covered_locations.insert(covered_edge);
            }
        }
    }
}

// Returns the rules to be emitted: for every target predicate, its candidate
// rules sorted by descending fidelity and truncated to the best 20.
//
// Predicates are visited in the iteration order of
// hg.get_predicate_name_to_id(); the rules of one predicate are contiguous in
// the result.
//
// Parameters:
//   max_number_of_rules  currently NOT applied (see note below).
//   target_predicates    names of the predicates whose rules are returned.
//
// NOTE(review): the previous implementation additionally contained a greedy,
// utility-driven selection bounded by `max_number_of_rules`, but it could
// never execute (an unconditional `continue` preceded it and the loop counter
// was initialised to its own upper bound). That unreachable code has been
// removed; the observable result is unchanged. Confirm whether the greedy
// selection / the global cap should be reinstated.
vector<Rule> CandidateRuleCollection::get_final_rules(size_t max_number_of_rules, const vector<string>& target_predicates) {
    (void)max_number_of_rules; // kept for interface compatibility, see note above
    const size_t max_rules_per_predicate = 20;

    vector<Rule> final_order;
    for(const auto& predicate:this->fragment_collection.hg.get_predicate_name_to_id()){
        if(!has(target_predicates, predicate.first)){
            continue;
        }
        // find() instead of operator[]: do not create empty entries in the member map.
        const auto candidates = candidate_rules.find(predicate.second);
        if(candidates == candidate_rules.end() || candidates->second.empty()){
            continue;
        }
        vector<Rule> ranked(candidates->second.begin(), candidates->second.end());
        sort(ranked.begin(), ranked.end(), compare_rule_fidelity);
        const size_t kept = min(max_rules_per_predicate, ranked.size());
        final_order.insert(final_order.end(), ranked.begin(), ranked.begin() + kept);
    }
    return final_order;
}

// Computes the utility of a set of rules.
//
//   precision  = sum of the rules' fidelities
//   recall     = sum over every covered head edge of
//                log(1 + total number of times the rules cover that edge)
//   complexity = calculate_complexity(rules)
//   utility    = precision * recall * complexity
//
// Returns {utility, number of distinct head edges covered by the rule set}.
//
// The per-edge counters are read with find() on the const rules, so no Rule
// is copied inside the parallel loop; an edge a rule does not cover
// contributes 0, exactly as a default-inserted counter would.
pair<double,size_t> CandidateRuleCollection::calculate_utility(const vector<Rule>& rules) {
    double global_precision{0};
    for(const auto& rule:rules){
        global_precision += rule.fidelity;
    }
    const double global_complexity = calculate_complexity(rules);

    unordered_set<EdgeId> head_ids{};
    for(const auto& rule:rules){
        head_ids.insert(rule.covered_locations.begin(),rule.covered_locations.end());
    }
    // Indexable copy so the OpenMP loop can iterate by position.
    const vector<EdgeId> head_ids_vector(head_ids.begin(), head_ids.end());
    const size_t head_total = head_ids_vector.size();

    double global_recall{0};
#pragma omp parallel for schedule(dynamic) reduction(+:global_recall)
    for(size_t i = 0; i<head_total; i++){
        double local_recall_degree{1};
        for(const auto& rule:rules){
            const auto counted = rule.covered_locations_count.find(head_ids_vector[i]);
            if(counted != rule.covered_locations_count.end()){
                local_recall_degree += counted->second;
            }
        }
        global_recall += log(local_recall_degree);
    }
    const double global_utility = global_precision * global_recall * global_complexity;
    return {global_utility, head_ids.size()};
}

void CandidateRuleCollection::print_rules(const string& logic,
                                          const string& output_file_name,
                                          size_t max_number_of_rules,
                                          const vector<string>& target_predicates) {

    string model_filename = output_file_name;
    if(logic == "psl") {
        model_filename += ".json";
    }else{
        model_filename += "." + logic;
    }
    ofstream file;
    file.open(model_filename);
    vector<Rule> rules = get_final_rules(max_number_of_rules, target_predicates);
    map<PredicateId,vector<Rule>> utility_normalisation_map{};
    for(auto rule:rules){
        utility_normalisation_map[rule.head[0]].push_back(rule);
    }
    for(auto& [predicate, normalised_rules]:utility_normalisation_map){
        double utility_sum = 0;
        for(const auto& rule:normalised_rules){
            utility_sum += rule.utility;
        }
        for(auto& rule:normalised_rules){
            rule.utility = rule.utility/utility_sum;
        }
    }
    for(auto& rule:rules){
        for(auto&  normalised_rule:utility_normalisation_map[rule.head[0]]){
            if(normalised_rule.rule == rule.rule){
                rule.utility = normalised_rule.utility;
            }
        }
    }
    if(file.is_open()){
        if(logic == "psl"){
            file << "{\n";
            file << "    \"rules\": [\n";
        }
        for(int i =0; i< rules.size(); i++){
            string printable_rule = print_rule(rules[i], logic);
            if(logic == "psl") {
                file << "        " << printable_rule;
                if (i != rules.size() - 1) {
                    file << ",\n";
                } else {
                    file << "\n";
                }
            }else{
                file << printable_rule;
                if(i != rules.size()-1){
                    file << "\n";
                }
            }
        }
        if(logic == "psl"){
            file << "    ]\n";
            file << "}\n";
        }
    }else{
        throw FileNotOpenedException(model_filename);
    }
}

// Computes the utility contributed at a single head edge by the rules that
// cover it.
//
// Only rules whose `covered_locations` contain `head` take part. For those:
//   recall  = sum of covered_locations_count[head]
//   utility = sum of fidelity * local_structure_penalty[head]
// The result is log(1 + recall) * utility / (number of covering rules).
//
// Returns 0 when no rule covers `head` (previously this divided by zero and
// yielded NaN).
double CandidateRuleCollection::compute_local_utility(EdgeId head, vector<Rule> rules) {
    double total_recall_degree{0};
    double total_local_utility{0};
    double normalization_factor{0};
    for(auto & rule : rules){
        if(rule.covered_locations.find(head) == rule.covered_locations.end()){
            continue;
        }
        total_recall_degree += (double)rule.covered_locations_count[head];
        total_local_utility += rule.fidelity * rule.local_structure_penalty[head];
        normalization_factor += 1;
    }
    if(normalization_factor == 0){
        return 0;
    }
    const double global_coverage_mass = log(1+total_recall_degree);
    return (global_coverage_mass * total_local_utility)/normalization_factor;
}

// Renders one rule as text, using the rule's first location to decide which
// atoms share a variable.
//
// Parameters:
//   rule   the rule to print; `rule.rule[0]` is the head, the remaining
//          non-null entries are the body atoms. `rule.locations` must contain
//          at least one location.
//   logic  "psl" yields  "<fidelity>: B1 & B2 >> HEAD ^2"  wrapped in double
//          quotes, with variables A, B, C, ...;
//          any other value yields the clause  !B1 v !B2 v HEAD  with
//          variables v0, v1, v2, ...
//
// Body atoms are emitted from the last rule entry towards the first, consuming
// the location's edges from the back; the head uses the edge that remains.
// Trailing digits of a predicate name are split off and printed as an extra
// quoted constant argument.
//
// Throws std::out_of_range (from the bounds-checked accesses) when the rule has
// no location or the location holds fewer edges than the rule has atoms.
string CandidateRuleCollection::print_rule(Rule rule, const string& logic) {
    const bool is_psl = (logic == "psl");

    // Variable names are generated on demand, so rules touching more than four
    // nodes no longer index past a fixed-size table.
    const auto variable_name = [is_psl](int index) -> std::string {
        if(!is_psl){
            return "v" + std::to_string(index);
        }
        if(index < 26){
            return std::string(1, (char)('A' + index));
        }
        return "V" + std::to_string(index);
    };

    const vector<EdgeId> location = rule.locations.at(0);

    // Number the nodes in order of first appearance, scanning the edges back to front.
    map<NodeId, int> nodes_to_var{};
    int var_counter = 0;
    for(int i = (int)location.size()-1; i>-1; i--){
        for(const auto& node:this->fragment_collection.get_edge(location[i])){
            if(nodes_to_var.emplace(node, var_counter).second){
                var_counter++;
            }
        }
    }

    string printable_rule{};

    // Appends `name(var[,var][,'constant'])` for one predicate placed on one edge.
    const auto append_atom = [&](const auto predicate_id, const auto edge_id){
        const std::string predicate = this->fragment_collection.get_predicate_from_id(predicate_id);
        const size_t pos = predicate.find_last_not_of("0123456789");
        const std::string clean_predicate = predicate.substr(0, pos + 1);
        const std::string constant = predicate.substr(pos + 1);
        const vector<NodeId> nodes = this->fragment_collection.get_edge(edge_id);
        printable_rule += clean_predicate;
        printable_rule += "(";
        printable_rule += variable_name(nodes_to_var.at(nodes.at(0)));
        if (!this->fragment_collection.is_unary(predicate_id)) {
            printable_rule += ",";
            printable_rule += variable_name(nodes_to_var.at(nodes.at(1)));
        }
        if(!constant.empty()){
            printable_rule += ",'" + constant + "'";
        }
        printable_rule += ")";
    };

    int k = (int)location.size()-1;
    if(is_psl){
        printable_rule += "\"" + doubleToStringWithTwoDecimals(rule.fidelity) + ": ";
    }
    bool first_body_atom = true;
    for(int i = (int)rule.rule.size() - 1; i > 0; i--) {
        if (rule.rule[i] == 0) {
            continue; // NULL predicate: placeholder, nothing to print
        }
        if(!first_body_atom){
            printable_rule += is_psl ? " & " : " v ";
        }
        first_body_atom = false;
        if(!is_psl){
            printable_rule += "!";
        }
        append_atom(rule.rule[i], location.at((size_t)k));
        k--;
    }
    printable_rule += is_psl ? " >> " : " v ";
    append_atom(rule.rule[0], location.at((size_t)k));
    if(is_psl){
        printable_rule += " ^2\"";
    }
    return printable_rule;
}

void CandidateRuleCollection::calculate_structure_penalties() {
    for(const auto& [predicate, rules]:this->candidate_rules){
        vector<Rule> rules_vector = vector<Rule>(rules.begin(), rules.end());
        // FIND LOCATIONS OF PREDICATE
#pragma omp parallel for schedule(dynamic)
        for(int j = 0; j<rules_vector.size(); j++){
            auto& rule = rules_vector[j];
            map<EdgeId, vector<double>> local_structure_penalty{};
            if(rule.rule.size() == 2){
                for(auto location:rule.locations){
                    double structure_penalty;
                    bool unary_edge = this->fragment_collection.is_unary_edge(location[0]);
                    vector<NodeId> nodes = this->fragment_collection.hg.get_edge(location[0]);
                    if(unary_edge){
                        vector<EdgeId> unaries;
                        vector<EdgeId> binaries;
                        tie(unaries, binaries) = this->fragment_collection.get_unary_binary_edges(nodes[0]);
                        structure_penalty = 2/(double)unaries.size();
                    }else{
                        vector<EdgeId> unaries_1;
                        vector<EdgeId> binaries_1;
                        tie(unaries_1, binaries_1) = this->fragment_collection.get_unary_binary_edges(nodes[0]);
                        vector<EdgeId> unaries_2;
                        vector<EdgeId> binaries_2;
                        tie(unaries_2, binaries_2) = this->fragment_collection.get_unary_binary_edges(nodes[1]);
                        structure_penalty = 1.0/((double)binaries_2.size()+1.0/(double)binaries_1.size());
                    }
                    local_structure_penalty[location[0]].push_back(structure_penalty);
                }
            }else{
                for(auto location:rule.locations){
                    double structure_penalty = 1;
                    if((rule.rule.size() == 6 || rule.rule.size() == 8)){
                        structure_penalty = 2;
                    }
                    NodeId starting_node= this->fragment_collection.get_edge(location[0])[0];
                    NodeId current_node = starting_node;
                    bool last_edge_binary = false;
                    for(int i=0; i<location.size()-1; i++){
                        if(i == 0 && !this->fragment_collection.is_unary_edge(location[0])){
                            last_edge_binary = true;
                            vector<NodeId> nodes_1 = this->fragment_collection.hg.get_edge(location[0]);
                            vector<NodeId> nodes_2 = this->fragment_collection.hg.get_edge(location[1]);
                            if(nodes_1[1] == nodes_2[0] || nodes_1[1] == nodes_2[1]){
                                current_node = nodes_1[1];
                            }else{
                                current_node = nodes_1[0];
                            }
                        }
                        vector<EdgeId> unaries;
                        vector<EdgeId> binaries;
                        tie(unaries, binaries) = this->fragment_collection.get_unary_binary_edges(current_node);
                        size_t number_of_binaries=binaries.size();
                        if(number_of_binaries == 0){
                            continue;
                        }
                        size_t number_of_unaries=unaries.size();
                        bool unary_edge = this->fragment_collection.is_unary_edge(location[i]);
                        if(unary_edge) {
                            structure_penalty *= 2 / (double)(number_of_unaries + 1);
                        }else{
                            if(current_node == starting_node){
                                structure_penalty *= 1 / (double)(number_of_binaries);
                            }else{
                                structure_penalty *= 2 / (double)(number_of_binaries);
                            }
                            if(last_edge_binary){
                                structure_penalty *= 1/(double)(number_of_unaries + 1);
                            }
                            last_edge_binary = true;
                            current_node = this->fragment_collection.hg.get_next_node(current_node, location[i]);
                        }
                    }
                    local_structure_penalty[location[0]].push_back(structure_penalty);
                }
            }
            for(auto structure_penalty:local_structure_penalty){
                double average_structure_penalty = accumulate(structure_penalty.second.begin(), structure_penalty.second.end(), 0.0)/(double)structure_penalty.second.size();
                rule.local_structure_penalty[structure_penalty.first] = average_structure_penalty;
            }
        }
#pragma omp critical
        {
            this->candidate_rules[predicate] = vector<Rule>(rules_vector.begin(), rules_vector.end());
        }
    }
}

/**
 * Returns the geometric mean of exp(-length) over the given rules, where the
 * length of a rule is the number of edges in its first location
 * (`rule.locations[0].size()`).
 *
 * The geometric mean of exp(-l_i) equals exp(-mean(l_i)); it is computed in
 * that form because the running product of exp(-l_i) underflows to 0 once the
 * summed lengths exceed roughly 745, which would make the result 0 for large
 * rule sets.
 *
 * @param rules rules to score; every rule must have at least one location.
 * @return value in (0, 1]; 1 when `rules` is empty.
 */
double CandidateRuleCollection::calculate_complexity(const vector<Rule>& rules) {
    if(rules.empty()){
        return 1;
    }
    double total_length = 0;
    // Iterate by reference: copying a Rule would also copy all of its locations.
    for(const auto& rule:rules){
        total_length += (double)rule.locations[0].size();
    }
    return exp(-total_length/(double)rules.size());
}
