#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <pthread.h>
#include <semaphore.h>
#include <vector>
#include <string>
#include <map>
#include <set>
#include <queue>
#include <algorithm>
#include <gsl/gsl_rng.h>

#define MAX_STRING 1000
#define MAX_THREADS 100
#define MAX_LENGTH 100

// Logistic function: returns 1 / (1 + e^(-x)).
double sigmoid(double x)
{
    const double denom = 1.0 + exp(-x);
    return 1.0 / denom;
}

// Returns the magnitude of x: -x when x is negative, x otherwise.
double abs_val(double x)
{
    return (x < 0) ? -x : x;
}

// Argument bundle passed through the void* parameter of a thread entry point:
// an untyped object pointer plus an integer id.
struct ArgStruct
{
    void *ptr;
    int id;

    ArgStruct(void *_ptr, int _id) : ptr(_ptr), id(_id)
    {
    }
};

// A (head, tail, relation) fact stored as integer ids.
struct Triplet
{
    int h, t, r;

    // Strict ordering: by relation first, then head, then tail.
    friend bool operator < (Triplet u, Triplet v)
    {
        if (u.r != v.r) return u.r < v.r;
        if (u.h != v.h) return u.h < v.h;
        return u.t < v.t;
    }

    // Two triplets are equal when all three ids match.
    friend bool operator == (Triplet u, Triplet v)
    {
        return u.h == v.h && u.t == v.t && u.r == v.r;
    }
};

// An (id, score) pair used when ranking candidates.
struct RankListEntry
{
    int id;
    double val;

    // Deliberately reversed: an entry with the larger val compares as "less",
    // so an ascending std::sort puts the highest values first.
    friend bool operator < (RankListEntry u, RankListEntry v)
    {
        return v.val < u.val;
    }
};

// A scalar weight together with the running statistics used by its
// Adam-style update rule.
struct Parameter
{
    // data: current value; m / v: decayed averages of the gradient and of its
    // square; t: number of updates applied so far.
    double data, m, v, t;

    Parameter() : data(0), m(0), v(0), t(0)
    {
    }

    // Resets the value and all optimizer statistics to zero.
    void clear()
    {
        data = m = v = t = 0;
    }

    // Applies one update step.
    //   grad          - gradient for this step; data is moved in its direction
    //   learning_rate - step size
    //   weight_decay  - coefficient of the pull of data towards zero
    void update(double grad, double learning_rate, double weight_decay=0)
    {
        const double g = grad - weight_decay * data;

        t += 1;
        m = 0.9 * m + 0.1 * g;
        v = 0.999 * v + 0.001 * g * g;

        // exp(log(b) * t) is b^t; these terms correct for m and v starting at zero.
        const double m_correction = 1 - exp(log(0.9) * t);
        const double v_correction = 1 - exp(log(0.999) * t);

        const double m_hat = m / m_correction;
        const double v_hat_root = sqrt(v) / sqrt(v_correction) + 0.00000001;

        data += learning_rate * m_hat / v_hat_root;
    }
};

// A logic rule: a sequence of body relations (r_body) implying a head relation
// (r_head), together with its statistics and a learnable weight.
struct Rule
{
    std::vector<int> r_body;   // relation ids of the rule body, in path order
    int r_head;                // relation id of the rule head, -1 when unset
    int type;                  // set to the body length by the rule search, -1 when unset
    double n_correct, n_groundings, H, cn;
    Parameter wt;              // learnable weight of the rule

    Rule()
    {
        clear();
    }

    // Resets the rule to the empty / unset state.
    void clear()
    {
        r_body.clear(); r_head = -1;
        type = -1;
        n_correct = 0; n_groundings = 0;
        H = 0;
        cn = 0;
        wt.clear();
    }

    // Strict weak ordering: by type, then head relation, then the first `type`
    // body relations compared lexicographically.
    //
    // The compared prefix is clamped to [0, r_body.size()] so that unset rules
    // (type == -1) or rules whose body is shorter than `type` compare safely.
    // The previous loop `k != u.type` never terminated and read out of bounds
    // in those cases. For rules whose body holds at least `type` elements the
    // result is unchanged.
    friend bool operator < (const Rule &u, const Rule &v)
    {
        if (u.type != v.type) return u.type < v.type;
        if (u.r_head != v.r_head) return u.r_head < v.r_head;

        const size_t limit = (u.type > 0) ? size_t(u.type) : 0;
        const size_t nu = std::min(limit, u.r_body.size());
        const size_t nv = std::min(limit, v.r_body.size());
        return std::lexicographical_compare(u.r_body.begin(), u.r_body.begin() + nu,
                                            v.r_body.begin(), v.r_body.begin() + nv);
    }

    // Rules are equal when body, head and type all match; statistics and the
    // weight are ignored.
    friend bool operator == (const Rule &u, const Rule &v)
    {
        return u.type == v.type && u.r_head == v.r_head && u.r_body == v.r_body;
    }
};

class KnowledgeGraph
{
protected:
    int entity_size, relation_size, train_triplet_size, valid_triplet_size, test_triplet_size, all_triplet_size;
    std::map<std::string, int> ent2id, rel2id;
    std::map<int, std::string> id2ent, id2rel;
    std::vector<Triplet> train_triplets, valid_triplets, test_triplets;
    std::vector<std::pair<int, int> > *e2rn;
    std::set<Triplet> set_train_triplets, set_all_triplets;

public:
    friend class RuleMiner;
    friend class ReasoningPredictor;
    friend class RuleGenerator;
    
    KnowledgeGraph()
    {
        entity_size = 0; relation_size = 0;
        train_triplet_size = 0; valid_triplet_size = 0; test_triplet_size = 0;
        all_triplet_size = 0;
        
        ent2id.clear(); rel2id.clear();
        id2ent.clear(); id2rel.clear();
        train_triplets.clear(); valid_triplets.clear(); test_triplets.clear();
        set_train_triplets.clear(); set_all_triplets.clear();
        e2rn = NULL;
    }
    
    ~KnowledgeGraph()
    {
        ent2id.clear(); rel2id.clear();
        id2ent.clear(); id2rel.clear();
        train_triplets.clear(); valid_triplets.clear(); test_triplets.clear();
        set_train_triplets.clear(); set_all_triplets.clear();
        for (int k = 0; k != entity_size; k++) e2rn[k].clear();
        delete [] e2rn;
    }
    
    void read_data(char *data_path)
    {
        char s_head[MAX_STRING], s_tail[MAX_STRING], s_ent[MAX_STRING], s_rel[MAX_STRING], s_file[MAX_STRING];
        int h, t, r, id;
        Triplet triplet;
        std::map<std::string, int>::iterator iter;
        FILE *fi;
        
        strcpy(s_file, data_path);
        strcat(s_file, "/entities.dict");
        fi = fopen(s_file, "rb");
        if (fi == NULL)
        {
            printf("ERROR: file of entities not found!\n");
            exit(1);
        }
        while (1)
        {
            if (fscanf(fi, "%d %s", &id, s_ent) != 2) break;
            
            ent2id[s_ent] = id;
            id2ent[id] = s_ent;
            entity_size += 1;
        }
        fclose(fi);
        
        strcpy(s_file, data_path);
        strcat(s_file, "/relations.dict");
        fi = fopen(s_file, "rb");
        if (fi == NULL)
        {
            printf("ERROR: file of relations not found!\n");
            exit(1);
        }
        while (1)
        {
            if (fscanf(fi, "%d %s", &id, s_rel) != 2) break;
            
            rel2id[s_rel] = id;
            id2rel[id] = s_rel;
            relation_size += 1;
        }
        fclose(fi);
        
        strcpy(s_file, data_path);
        strcat(s_file, "/train.txt");
        fi = fopen(s_file, "rb");
        if (fi == NULL)
        {
            printf("ERROR: file of train triplets not found!\n");
            exit(1);
        }
        while (1)
        {
            if (fscanf(fi, "%s %s %s", s_head, s_rel, s_tail) != 3) break;
            if (ent2id.count(s_head) == 0 || ent2id.count(s_tail) == 0 || rel2id.count(s_rel) == 0) continue;
            
            h = ent2id[s_head]; t = ent2id[s_tail]; r = rel2id[s_rel];
            triplet.h = h; triplet.t = t; triplet.r = r;
            train_triplets.push_back(triplet);
            set_train_triplets.insert(triplet);
            set_all_triplets.insert(triplet);
        }
        fclose(fi);
        
        train_triplet_size = int(train_triplets.size());
        e2rn = new std::vector<std::pair<int, int> > [entity_size];
        for (int k = 0; k != train_triplet_size; k++)
        {
            h = train_triplets[k].h; r = train_triplets[k].r; t = train_triplets[k].t;
            e2rn[h].push_back(std::make_pair(r, t));
        }

        strcpy(s_file, data_path);
        strcat(s_file, "/valid.txt");
        fi = fopen(s_file, "rb");
        if (fi == NULL)
        {
            printf("ERROR: file of test triplets not found!\n");
            exit(1);
        }
        while (1)
        {
            if (fscanf(fi, "%s %s %s", s_head, s_rel, s_tail) != 3) break;
            if (ent2id.count(s_head) == 0 || ent2id.count(s_tail) == 0 || rel2id.count(s_rel) == 0) continue;

            h = ent2id[s_head]; t = ent2id[s_tail]; r = rel2id[s_rel];
            triplet.h = h; triplet.t = t; triplet.r = r;
            valid_triplets.push_back(triplet);
            set_all_triplets.insert(triplet);
        }
        fclose(fi);
        valid_triplet_size = int(valid_triplets.size());
        
        strcpy(s_file, data_path);
        strcat(s_file, "/test.txt");
        fi = fopen(s_file, "rb");
        if (fi == NULL)
        {
            printf("ERROR: file of test triplets not found!\n");
            exit(1);
        }
        while (1)
        {
            if (fscanf(fi, "%s %s %s", s_head, s_rel, s_tail) != 3) break;
            if (ent2id.count(s_head) == 0 || ent2id.count(s_tail) == 0 || rel2id.count(s_rel) == 0) continue;

            h = ent2id[s_head]; t = ent2id[s_tail]; r = rel2id[s_rel];
            triplet.h = h; triplet.t = t; triplet.r = r;
            test_triplets.push_back(triplet);
            set_all_triplets.insert(triplet);
        }
        fclose(fi);
        test_triplet_size = int(test_triplets.size());

        all_triplet_size = int(set_all_triplets.size());
        
        printf("#Entities: %d          \n", entity_size);
        printf("#Relations: %d          \n", relation_size);
        printf("#Train triplets: %d          \n", train_triplet_size);
        printf("#Valid triplets: %d          \n", valid_triplet_size);
        printf("#Test triplets: %d          \n", test_triplet_size);
        printf("#All triplets: %d          \n", all_triplet_size);
    }
    
    bool check_observed(Triplet triplet)
    {
        if (set_train_triplets.count(triplet) != 0) return true;
        else return false;
    }

    bool check_true(Triplet triplet)
    {
        if (set_all_triplets.count(triplet) != 0) return true;
        else return false;
    }
    
    void rule_search(int r, int e, int goal, int *path, int depth, int max_depth, std::set<Rule> &rule_set)
    {
        if (e == goal)
        {
            Rule rule;
            rule.type = depth;
            rule.r_head = r;
            rule.r_body.clear();
            for (int k = 0; k != depth; k++)
            {
                rule.r_body.push_back(path[k]);
            }
            rule_set.insert(rule);
            //return;
        }
        if (depth == max_depth)
        {
            return;
        }
        
        int len = int(e2rn[e].size());
        int cur_r, cur_n;
        for (int k = 0; k != len; k++)
        {
            cur_r = e2rn[e][k].first;
            cur_n = e2rn[e][k].second;
            path[depth] = cur_r;
            rule_search(r, cur_n, goal, path, depth+1, max_depth, rule_set);
        }
    }
    
    void rule_destination(int e, Rule rule, std::vector<int> &dests)
    {
        std::queue< std::pair<int, int> > queue;
        queue.push(std::make_pair(e, 0));
        int current_e, current_d, current_r, next_e;
        while (!queue.empty())
        {
            std::pair<int, int> pair = queue.front();
            current_e = pair.first;
            current_d = pair.second;
            queue.pop();
            if (current_d == int(rule.r_body.size()))
            {
                dests.push_back(current_e);
                continue;
            }
            current_r = rule.r_body[current_d];
            for (int k = 0; k != int(e2rn[current_e].size()); k++)
            {
                if (e2rn[current_e][k].first != current_r) continue;
                
                next_e = e2rn[current_e][k].second;
                queue.push(std::make_pair(next_e, current_d + 1));
            }
        }
    }
};

class RuleMiner
{
protected:
    KnowledgeGraph *p_kg;
    int num_threads, max_length;
    double threshold, support;
    long long total_count;
    std::vector<Rule> *rel2rules;
    std::map<Rule, std::pair<double, double> > *rel2rule2stat;
    sem_t mutex;

public:
    RuleMiner()
    {
        num_threads = 4; max_length = 3;
        total_count = 0;
        rel2rule2stat = NULL; rel2rules = NULL;
        sem_init(&mutex, 0, 1);
        p_kg = NULL;
    }
    
    ~RuleMiner()
    {
        total_count = 0;
        if (rel2rule2stat != NULL)
        {
            for (int r = 0; r != p_kg->relation_size; r++) rel2rule2stat[r].clear();
            delete [] rel2rule2stat;
            rel2rule2stat = NULL;
        }
        if (rel2rules != NULL)
        {
            for (int r = 0; r != p_kg->relation_size; r++) rel2rules[r].clear();
            delete [] rel2rules;
            rel2rules = NULL;
        }
        sem_init(&mutex, 0, 1);
        p_kg = NULL;
    }
    
    void init_knowledge_graph(KnowledgeGraph *_p_kg)
    {
        p_kg = _p_kg;
        rel2rules = new std::vector<Rule> [p_kg->relation_size];
        rel2rule2stat = new std::map<Rule, std::pair<double, double> > [p_kg->relation_size];
    }
    
    void clear()
    {
        total_count = 0;
        for (int k = 0; k != p_kg->relation_size; k++)
        {
            rel2rules[k].clear();
            rel2rule2stat[k].clear();
        }
        sem_init(&mutex, 0, 1);
    }
    
    void search_thread(int thread)
    {
        int triplet_size = p_kg->train_triplet_size;
        int bg = int(triplet_size / num_threads) * thread;
        int ed = int(triplet_size / num_threads) * (thread + 1);
        if (thread == num_threads - 1) ed = triplet_size;
        
        std::set<Rule>::iterator iter;
        std::set<Rule> rule_set;
        std::vector<int> dests;
        Rule rule;
        Triplet triplet;
        int path[MAX_LENGTH], h, r, t, hit, dest;
        
        for (int T = bg; T != ed; T++)
        {
            if (T % 10 == 0)
            {
                total_count += 10;
                printf("Rule Discovery | Progress: %.3lf%%          %c", (double)total_count / (double)(triplet_size + 1) * 100, 13);
                fflush(stdout);
            }
            
            h = p_kg->train_triplets[T].h;
            r = p_kg->train_triplets[T].r;
            t = p_kg->train_triplets[T].t;
            
            rule_set.clear();
            p_kg->rule_search(r, h, t, path, 0, max_length, rule_set);
        
            for (iter = rule_set.begin(); iter != rule_set.end(); iter++)
            {
                if (iter->type == 1 && iter->r_body[0] == r)
                {
                    rule_set.erase(iter);
                    break;
                }
            }
            
            for (iter = rule_set.begin(); iter != rule_set.end(); iter++)
            {
                rule = *iter;
                dests.clear();
                p_kg->rule_destination(h, rule, dests);
                
                hit = 0;
                triplet = p_kg->train_triplets[T];
                for (int k = 0; k != int(dests.size()); k++)
                {
                    dest = dests[k];
                    triplet.t = dest;
                    if (p_kg->check_observed(triplet)) hit += 1;
                }
                
                sem_wait(&mutex);
                if (rel2rule2stat[r].count(rule) == 0) rel2rule2stat[r][rule] = std::make_pair(0.0, 0.0);
                rel2rule2stat[r][rule].first += hit;
                rel2rule2stat[r][rule].second += dests.size();
                sem_post(&mutex);
            }
        }
        rule_set.clear();
        dests.clear();
    }
    
    static void *search_thread_caller(void *arg)
    {
        RuleMiner *ptr = (RuleMiner *)(((ArgStruct *)arg)->ptr);
        int thread = ((ArgStruct *)arg)->id;
        ptr->search_thread(thread);
        pthread_exit(NULL);
    }
    
    void search(int _max_length, int _num_threads)
    {
        max_length = _max_length;
        num_threads = _num_threads;
        
        pthread_t *pt = (pthread_t *)malloc(num_threads * sizeof(pthread_t));
        for (int k = 0; k != num_threads; k++) pthread_create(&pt[k], NULL, RuleMiner::search_thread_caller, new ArgStruct(this, k));
        for (int k = 0; k != num_threads; k++) pthread_join(pt[k], NULL);
        printf("Rule Discovery | DONE!                              \n");
        free(pt);
    }
    
    void filter(double _threshold, double _support)
    {
        threshold = _threshold;
        support = _support;
        
        int rel;
        Rule rule;
        std::map< Rule, std::pair<double, double> >::iterator iter;
        
        for (rel = 0; rel != p_kg->relation_size; rel++)
        {
            for (iter = rel2rule2stat[rel].begin(); iter != rel2rule2stat[rel].end(); iter++)
            {
                rule = iter->first;
                double n_correct, n_groundings, precision;
                n_correct = (iter->second).first;
                n_groundings = (iter->second).second;
                precision = n_correct / n_groundings;
                if (n_correct < support || precision < threshold) continue;
                rule.n_correct = n_correct;
                rule.n_groundings = n_groundings;
                rel2rules[rel].push_back(rule);
            }
        }
    }
    
    std::vector<Rule> *get_logic_rules()
    {
        return rel2rules;
    }
};

class ReasoningPredictor
{
protected:
    KnowledgeGraph *p_kg;
    std::vector<Rule> *rel2rules;
    int num_threads, top_k;
    double temperature, learning_rate, weight_decay;
    bool test;
    long long total_count;
    double total_loss;
    std::vector<int> hits;
    sem_t mutex;

public:
    ReasoningPredictor()
    {
        num_threads = 4; top_k = 100;
        temperature = 100; learning_rate = 0.01; weight_decay = 0.0005;
        total_count = 0; total_loss = 0;
        rel2rules = NULL;
        test = true;
        sem_init(&mutex, 0, 1);
        p_kg = NULL;
    }
    
    ~ReasoningPredictor()
    {
        total_count = 0; total_loss = 0;
        if (rel2rules != NULL)
        {
            for (int r = 0; r != p_kg->relation_size; r++) rel2rules[r].clear();
            delete [] rel2rules;
            rel2rules = NULL;
        }
        sem_init(&mutex, 0, 1);
        p_kg = NULL;
    }
    
    void init_knowledge_graph(KnowledgeGraph *_p_kg)
    {
        p_kg = _p_kg;
    }
    
    void set_logic_rules(std::vector<Rule> * _rel2rules)
    {
        if (rel2rules != NULL)
        {
            for (int r = 0; r != p_kg->relation_size; r++) rel2rules[r].clear();
            delete [] rel2rules;
        }
        
        rel2rules = new std::vector<Rule> [p_kg->relation_size];
        for (int r = 0; r != p_kg->relation_size; r++)
        {
            rel2rules[r] = _rel2rules[r];
            for (int k = 0; k != int(rel2rules[r].size()); k++)
            {
                rel2rules[r][k].wt.clear();
                rel2rules[r][k].H = 0;
            }
        }
    }
    
    std::vector<Rule> *get_logic_rules()
    {
        return rel2rules;
    }
    
    void learn_thread(int thread)
    {
        int triplet_size = p_kg->train_triplet_size;
        int bg = int(triplet_size / num_threads) * thread;
        int ed = int(triplet_size / num_threads) * (thread + 1);
        if (thread == num_threads - 1) ed = triplet_size;
        
        std::vector<int> dests;
        Triplet triplet;
        int h, r, t, dest, index;
        double logit, target, grad;
        std::map<int, double> dest2logit;
        std::map<int, std::vector<int> > dest2index;
        std::map<int, double>::iterator iter;
        
        for (int T = bg; T != ed; T++)
        {
            if (T % 10 == 0)
            {
                total_count += 10;
                printf("Learning Rule Weights | Progress: %.3lf%% | Loss: %.6lf          %c", (double)total_count / (double)(triplet_size + 1) * 100, total_loss / total_count, 13);
                fflush(stdout);
            }
            
            h = p_kg->train_triplets[T].h;
            r = p_kg->train_triplets[T].r;
            t = p_kg->train_triplets[T].t;
            
            dest2logit.clear();
            dest2index.clear();
            for (index = 0; index != int(rel2rules[r].size()); index++)
            {
                dests.clear();
                p_kg->rule_destination(h, rel2rules[r][index], dests);
                
                for (int i = 0; i != int(dests.size()); i++)
                {
                    dest = dests[i];
                    if (dest2logit.count(dest) == 0) dest2logit[dest] = 0;
                    if (dest2index.count(dest) == 0) dest2index[dest] = std::vector<int>();
                    dest2logit[dest] += rel2rules[r][index].wt.data / temperature;
                    dest2index[dest].push_back(index);
                }
            }

            double max_val = -1000000, sum_val = 0;
            for (iter = dest2logit.begin(); iter != dest2logit.end(); iter++)
                max_val = std::max(max_val, iter->second);
            for (iter = dest2logit.begin(); iter != dest2logit.end(); iter++)
                sum_val += exp(iter->second - max_val);
            for (iter = dest2logit.begin(); iter != dest2logit.end(); iter++)
                dest2logit[iter->first] = exp(dest2logit[iter->first] - max_val) / sum_val;

            for (iter = dest2logit.begin(); iter != dest2logit.end(); iter++)
            {
                dest = iter->first;
                logit = iter->second;
                
                triplet = p_kg->train_triplets[T];
                triplet.t = dest;
                if (p_kg->check_observed(triplet) == true) target = 1.0;
                else target = 0;
                grad = (target - logit) / temperature;
                
                total_loss += abs_val(target - logit) / dest2logit.size();

                for (int k = 0; k != int(dest2index[dest].size()); k++)
                {
                    index = dest2index[dest][k];
                    rel2rules[r][index].wt.update(grad, learning_rate, weight_decay);
                }
            }
        }
        dest2logit.clear();
        dest2index.clear();
        dests.clear();
    }
    
    static void *learn_thread_caller(void *arg)
    {
        ReasoningPredictor *ptr = (ReasoningPredictor *)(((ArgStruct *)arg)->ptr);
        int thread = ((ArgStruct *)arg)->id;
        ptr->learn_thread(thread);
        pthread_exit(NULL);
    }
    
    void learn(double _learning_rate, double _weight_decay, double _temperature, int _num_threads)
    {
        learning_rate = _learning_rate;
        weight_decay = _weight_decay;
        temperature = _temperature;
        num_threads = _num_threads;
        
        total_count = 0;
        total_loss = 0;
        
        pthread_t *pt = (pthread_t *)malloc(num_threads * sizeof(pthread_t));
        for (int k = 0; k != num_threads; k++) pthread_create(&pt[k], NULL, ReasoningPredictor::learn_thread_caller, new ArgStruct(this, k));
        for (int k = 0; k != num_threads; k++) pthread_join(pt[k], NULL);
        printf("Learning Rule Weights | DONE! | Loss: %.6lf                             \n", total_loss / total_count);
        free(pt);
    }
    
    void H_score_thread(int thread)
    {
        int triplet_size = p_kg->train_triplet_size;
        int bg = int(triplet_size / num_threads) * thread;
        int ed = int(triplet_size / num_threads) * (thread + 1);
        if (thread == num_threads - 1) ed = triplet_size;
        
        std::vector<int> dests;
        int h, r, t, dest, index;
        RankListEntry entry;
        std::map<int, std::vector<int> > dest2index;
        std::map<int, std::vector<int> > index2dest;
        std::map<int, std::vector<int> >::iterator iter;
        std::vector<RankListEntry> rule2score;
        std::set<int> valid;
        
        for (int T = bg; T != ed; T++)
        {
            if (T % 10 == 0)
            {
                total_count += 10;
                printf("Computing H Score | Progress: %.3lf%%          %c", (double)total_count / (double)(triplet_size + 1) * 100, 13);
                fflush(stdout);
            }
            
            h = p_kg->train_triplets[T].h;
            r = p_kg->train_triplets[T].r;
            t = p_kg->train_triplets[T].t;
            
            index2dest.clear();
            for (index = 0; index != int(rel2rules[r].size()); index++)
            {
                dests.clear();
                p_kg->rule_destination(h, rel2rules[r][index], dests);
                
                for (int i = 0; i != int(dests.size()); i++)
                {
                    dest = dests[i];
                    if (dest2index.count(dest) == 0) dest2index[dest] = std::vector<int>();
                    if (index2dest.count(index) == 0) index2dest[index] = std::vector<int>();
                    dest2index[dest].push_back(index);
                    index2dest[index].push_back(dest);
                }
            }

            rule2score.clear();
            for (iter = index2dest.begin(); iter != index2dest.end(); iter++)
            {
                index = iter->first;
                dests = iter->second;

                entry.id = index;
                entry.val = 0;

                for (int i = 0; i != int(dests.size()); i++)
                {
                    dest = dests[i];
                    
                    if (dest == t) entry.val += rel2rules[r][index].wt.data;
                    else entry.val -= rel2rules[r][index].wt.data / dest2index.size();
                }

                rule2score.push_back(entry);
            }

            if (top_k == 0)
            {
                for (int k = 0; k != int(rule2score.size()); k++)
                {
                    index = rule2score[k].id;
                    rel2rules[r][index].H += rule2score[k].val / triplet_size;
                }
            }
            else
            {
                std::sort(rule2score.begin(), rule2score.end());
            
                for (int k = 0; k != int(rule2score.size()); k++)
                {
                    if (k == top_k) break;

                    index = rule2score[k].id;
                    rel2rules[r][index].H += 1.0 / triplet_size;
                }
            }
        }
        index2dest.clear();
        dest2index.clear();
        dests.clear();
        pthread_exit(NULL);
    }
    
    static void *H_score_thread_caller(void *arg)
    {
        ReasoningPredictor *ptr = (ReasoningPredictor *)(((ArgStruct *)arg)->ptr);
        int thread = ((ArgStruct *)arg)->id;
        ptr->H_score_thread(thread);
        pthread_exit(NULL);
    }
    
    void H_score(int _top_k, int _num_threads)
    {
        top_k = _top_k;
        num_threads = _num_threads;
        
        total_count = 0;
        total_loss = 0;
        
        pthread_t *pt = (pthread_t *)malloc(num_threads * sizeof(pthread_t));
        for (int k = 0; k != num_threads; k++) pthread_create(&pt[k], NULL, ReasoningPredictor::H_score_thread_caller, new ArgStruct(this, k));
        for (int k = 0; k != num_threads; k++) pthread_join(pt[k], NULL);
        printf("Computing H Score | DONE!                              \n");
        free(pt);
    }
    
    void evaluate_thread(int thread)
    {
        std::vector<Triplet> *p_triplets;
        if (test) p_triplets = &(p_kg->test_triplets);
        else p_triplets = &(p_kg->valid_triplets);
        
        int triplet_size = int((*p_triplets).size());
        int bg = int(triplet_size / num_threads) * thread;
        int ed = int(triplet_size / num_threads) * (thread + 1);
        if (thread == num_threads - 1) ed = triplet_size;
        
        std::vector<int> dests;
        Triplet triplet;
        int h, r, t, hit, dest, npos, nneg, index;
        RankListEntry *rank_list;
        
        rank_list = new RankListEntry [p_kg->entity_size];
        
        for (int T = bg; T != ed; T++)
        {
            if (T % 10 == 0)
            {
                total_count += 10;
                printf("Evaluation | Progress: %.3lf%%          %c", (double)total_count / (double)(triplet_size + 1) * 100, 13);
                fflush(stdout);
            }
            
            h = (*p_triplets)[T].h;
            r = (*p_triplets)[T].r;
            t = (*p_triplets)[T].t;
            
            for (int k = 0; k != p_kg->entity_size; k++)
            {
                rank_list[k].id = k;
                rank_list[k].val = 0;
            }
            
            for (index = 0; index != int(rel2rules[r].size()); index++)
            {
                dests.clear();
                p_kg->rule_destination(h, rel2rules[r][index], dests);
                
                for (int i = 0; i != int(dests.size()); i++)
                {
                    dest = dests[i];
                    rank_list[dest].val += rel2rules[r][index].wt.data;
                }
            }
            
            std::sort(rank_list, rank_list + p_kg->entity_size);
            
            npos = 0; nneg = 0;
            for (int k = 0; k != p_kg->entity_size; k++)
            {
                if (rank_list[k].val > 0) npos += 1;
                if (rank_list[k].val < 0) nneg += 1;
            }
            
            hit = 1;
            triplet = (*p_triplets)[T];
            for (int k = 0; k != p_kg->entity_size; k++)
            {
                if (rank_list[k].id == t)
                {
                    if (rank_list[k].val == 0) hit = npos + (p_kg->entity_size - npos - nneg) / 2 + 1;
                    break;
                }
                
                triplet.t = rank_list[k].id;
                if (p_kg->check_true(triplet) == false) hit += 1;
            }
            
            //for (int k = 0; k != entity_size; k++)
            //{
            //    printf("%s %lf %lf\n", id2ent[rank_list[k].id].c_str(), rank_list[k].total, rank_list[k].cn);
            //}
            
            sem_wait(&mutex);
            hits.push_back(hit);
            sem_post(&mutex);
        }
        delete [] rank_list;
        dests.clear();
        pthread_exit(NULL);
    }
    
    static void *evaluate_thread_caller(void *arg)
    {
        ReasoningPredictor *ptr = (ReasoningPredictor *)(((ArgStruct *)arg)->ptr);
        int thread = ((ArgStruct *)arg)->id;
        ptr->evaluate_thread(thread);
        pthread_exit(NULL);
    }
    
    void evaluate(bool _test)
    {
        test = _test;
        
        hits.clear();
        total_count = 0;
        total_loss = 0;
        sem_init(&mutex, 0, 1);
        
        pthread_t *pt = (pthread_t *)malloc(num_threads * sizeof(pthread_t));
        for (int k = 0; k != num_threads; k++) pthread_create(&pt[k], NULL, ReasoningPredictor::evaluate_thread_caller, new ArgStruct(this, k));
        for (int k = 0; k != num_threads; k++) pthread_join(pt[k], NULL);
        if (test == true) printf("Evaluation Test | DONE!                              \n");
        else printf("Evaluation Valid | DONE!                              \n");
        free(pt);
        
        double hit1 = 0, hit3 = 0, hit10 = 0, mrr = 0, mr = 0, cn = 0;
        for (int k = 0; k != int(hits.size()); k++)
        {
            if (hits[k] <= 1) hit1 += 1;
            if (hits[k] <= 3) hit3 += 1;
            if (hits[k] <= 10) hit10 += 1;
            mrr += 1.0 / hits[k];
            mr += hits[k];
            cn += 1;
        }
        if (cn != 0)
        {
            hit1 /= cn;
            hit3 /= cn;
            hit10 /= cn;
            mrr /= cn;
            mr /= cn;
        }

        printf("Hit@1: %.3lf\n", hit1 * 100);
        printf("Hit@3: %.3lf\n", hit3 * 100);
        printf("Hit@10: %.3lf\n", hit10 * 100);
        printf("MRR: %.3lf\n", mrr);
        printf("MR: %.3lf\n", mr);
    }
    
    std::vector<Rule> *get_top_rules(int top_n)
    {
        std::vector<Rule> *rel2top_rules = new std::vector<Rule> [p_kg->relation_size];
        std::vector<RankListEntry> rank_list;
        RankListEntry entry;
        for (int r = 0; r != p_kg->relation_size; r++)
        {
            rank_list.clear();
            for (int k = 0; k != int(rel2rules[r].size()); k++)
            {
                entry.id = k;
                entry.val = rel2rules[r][k].H;
                rank_list.push_back(entry);
            }
            std::sort(rank_list.begin(), rank_list.end());
            for (int k = 0; k != int(rel2rules[r].size()); k++)
            {
                if (k >= top_n) break;
                int index = rank_list[k].id;
                rel2top_rules[r].push_back(rel2rules[r][index]);
            }
        }
        return rel2top_rules;
    }
};

class RuleGenerator
{
protected:
    KnowledgeGraph *p_kg;
    std::vector<Rule> *rel2rules, *rel2pool;
    std::vector<int> *mapping;
    const gsl_rng_type * gsl_T;
    gsl_rng * gsl_r;
public:
    RuleGenerator()
    {
        rel2rules = NULL;
        rel2pool = NULL;
        mapping = NULL;
        p_kg = NULL;

        gsl_rng_env_setup();
        gsl_T = gsl_rng_rand48;
        gsl_r = gsl_rng_alloc(gsl_T);
        gsl_rng_set(gsl_r, 314159265);
    }
    
    ~RuleGenerator()
    {
        if (rel2rules != NULL)
        {
            for (int r = 0; r != p_kg->relation_size; r++) rel2rules[r].clear();
            delete [] rel2rules;
            rel2rules = NULL;
        }
        if (rel2pool != NULL)
        {
            for (int r = 0; r != p_kg->relation_size; r++) rel2pool[r].clear();
            delete [] rel2pool;
            rel2pool = NULL;
        }
        if (mapping != NULL)
        {
            for (int r = 0; r != p_kg->relation_size; r++) mapping[r].clear();
            delete [] mapping;
            mapping = NULL;
        }
        p_kg = NULL;
    }
    
    void init_knowledge_graph(KnowledgeGraph *_p_kg)
    {
        p_kg = _p_kg;
    }
    
    void set_logic_rules(std::vector<Rule> * _rel2rules)
    {
        if (rel2rules != NULL)
        {
            for (int r = 0; r != p_kg->relation_size; r++) rel2rules[r].clear();
            delete [] rel2rules;
        }
        
        rel2rules = new std::vector<Rule> [p_kg->relation_size];
        for (int r = 0; r != p_kg->relation_size; r++)
        {
            rel2rules[r] = _rel2rules[r];
            for (int k = 0; k != int(rel2rules[r].size()); k++)
            {
                rel2rules[r][k].wt.clear();
                rel2rules[r][k].H = 0;
            }
        }
    }
    
    std::vector<Rule> *get_logic_rules()
    {
        return rel2rules;
    }
    
    void set_pool(std::vector<Rule> * _rel2rules)
    {
        if (rel2pool != NULL)
        {
            for (int r = 0; r != p_kg->relation_size; r++) rel2pool[r].clear();
            delete [] rel2pool;
        }
        
        rel2pool = new std::vector<Rule> [p_kg->relation_size];
        for (int r = 0; r != p_kg->relation_size; r++)
        {
            rel2pool[r] = _rel2rules[r];
            for (int k = 0; k != int(rel2pool[r].size()); k++)
            {
                rel2pool[r][k].wt.clear();
                rel2pool[r][k].H = 0;
                rel2pool[r][k].cn = 0;
            }
        }
    }
    
    void sample_from_pool(int _number, double _temperature=1)
    {
        if (rel2rules != NULL)
        {
            for (int r = 0; r != p_kg->relation_size; r++) rel2rules[r].clear();
            delete [] rel2rules;
        }
        if (mapping != NULL)
        {
            for (int r = 0; r != p_kg->relation_size; r++) mapping[r].clear();
            delete [] mapping;
        }
        
        rel2rules = new std::vector<Rule> [p_kg->relation_size];
        mapping = new std::vector<int> [p_kg->relation_size];
        for (int r = 0; r != p_kg->relation_size; r++)
        {
            std::vector<double> probability;
            double max_val = -1000000, sum_val = 0;
            for (int k = 0; k != int(rel2pool[r].size()); k++)
                max_val = std::max(max_val, rel2pool[r][k].H);
            for (int k = 0; k != int(rel2pool[r].size()); k++)
                sum_val += exp((rel2pool[r][k].H - max_val) / _temperature);
            for (int k = 0; k != int(rel2pool[r].size()); k++)
                probability.push_back(exp((rel2pool[r][k].H - max_val) / _temperature) / sum_val);
            
            for (int k = 0; k != _number; k++)
            {
                double sum_prob = 0, rand_val = gsl_rng_uniform(gsl_r);// double(rand()) / double(RAND_MAX);
                for (int index = 0; index != int(rel2pool[r].size()); index++)
                {
                    sum_prob += probability[index];
                    if (sum_prob > rand_val)
                    {
                        rel2rules[r].push_back(rel2pool[r][index]);
                        mapping[r].push_back(index);
                        break;
                    }
                }                
            }
        }
    }
    
    void random_from_pool(int _number)
    {
        if (rel2rules != NULL)
        {
            for (int r = 0; r != p_kg->relation_size; r++) rel2rules[r].clear();
            delete [] rel2rules;
        }
        if (mapping != NULL)
        {
            for (int r = 0; r != p_kg->relation_size; r++) mapping[r].clear();
            delete [] mapping;
        }
        
        rel2rules = new std::vector<Rule> [p_kg->relation_size];
        mapping = new std::vector<int> [p_kg->relation_size];
        std::vector<int> rand_index;
        for (int r = 0; r != p_kg->relation_size; r++)
        {
            rand_index.clear();
            for (int k = 0; k != int(rel2pool[r].size()); k++) rand_index.push_back(k);
            std::random_shuffle(rand_index.begin(), rand_index.end());
            for (int k = 0; k != int(rel2pool[r].size()); k++)
            {
                if (k >= _number) break;
                int index = rand_index[k];
                rel2rules[r].push_back(rel2pool[r][index]);
                mapping[r].push_back(index);
            }
        }
    }
    
    void best_from_pool(int _number)
    {
        if (rel2rules != NULL)
        {
            for (int r = 0; r != p_kg->relation_size; r++) rel2rules[r].clear();
            delete [] rel2rules;
        }
        if (mapping != NULL)
        {
            for (int r = 0; r != p_kg->relation_size; r++) mapping[r].clear();
            delete [] mapping;
        }
        
        rel2rules = new std::vector<Rule> [p_kg->relation_size];
        mapping = new std::vector<int> [p_kg->relation_size];
        std::vector<RankListEntry> rank_list;
        RankListEntry entry;
        for (int r = 0; r != p_kg->relation_size; r++)
        {
            rank_list.clear();
            for (int k = 0; k != int(rel2pool[r].size()); k++)
            {
                entry.id = k;
                entry.val = rel2pool[r][k].H;
                rank_list.push_back(entry);
            }
            std::sort(rank_list.begin(), rank_list.end());
            for (int k = 0; k != int(rel2pool[r].size()); k++)
            {
                if (k >= _number) break;
                int index = rank_list[k].id;
                rel2rules[r].push_back(rel2pool[r][index]);
                mapping[r].push_back(index);
            }
        }
    }
    
    void update(std::vector<Rule> * _rel2rules)
    {
        for (int r = 0; r != p_kg->relation_size; r++)
        {
            for (int k = 0; k != int(rel2rules[r].size()); k++) rel2rules[r][k].H = _rel2rules[r][k].H;
            for (int k = 0; k != int(rel2rules[r].size()); k++)
            {
                int index = mapping[r][k];
                rel2pool[r][index].H = (rel2pool[r][index].H * rel2pool[r][index].cn + rel2rules[r][k].H) / (rel2pool[r][index].cn + 1);
                rel2pool[r][index].cn += 1;
            }
        }
    }
};
