#pragma once
#include "utils.h"
#include "kg.h"

/// Rule-based predictor over a knowledge graph.
///
/// Exposes multi-threaded training (`learn`) and evaluation (`evaluate`)
/// entry points. Per-thread work is declared as `learn_thread` /
/// `evaluate_thread`. The static `*_thread_caller` functions have the
/// `void *(void *)` signature of a POSIX thread start routine.
///
/// NOTE(review): ownership of `rules`, `p_kg`, `G` and `mean` is decided by the
/// constructor/destructor definitions, which are not in this header. Confirm
/// which of them (if any) the destructor frees before relying on it.
class Predictor
{
public:
    // Presumably the `_rules` argument passed to the constructor. TODO confirm.
    std::vector<Rule> *rules = nullptr;
    int n_thread = 0, top_k = 0, n_epoch = 0;
    // Hyperparameters; names mirror the arguments of learn(). TODO confirm they are stored from there.
    double temperature = 0.0, learning_rate = 0.0, weight_decay = 0.0;
    double portion = 0.0;
    double prior_weight = 0.0, H_temperature = 0.0;
    bool test = false;
    // Presumably set from the constructor's `clipping` flag. TODO confirm.
    bool clip = false;
    // Running accumulators; presumably updated by worker threads under `mutex`.
    long long total_count = 0;
    double total_loss = 0.0;
    // NOTE(review): the meaning of each (int, int) pair is defined in the .cpp.
    std::vector< std::pair<int, int> > ranks;
    KnowledgeGraph *p_kg = nullptr;
    Graph *G = nullptr;
    // POSIX semaphore. Only the original object may be used for
    // synchronization (copies are undefined), which is one reason copying is
    // disabled below.
    sem_t mutex;
    real_t *mean = nullptr;

    Predictor(std::vector<Rule> *_rules, KnowledgeGraph *kg, Graph *g, bool clipping=false);
    ~Predictor();

    // Copying is disabled. The object owns a sem_t, and it has raw pointers
    // whose cleanup (if any) lives in the user-declared destructor. A
    // member-wise copy would duplicate the semaphore and alias that state.
    Predictor(const Predictor &) = delete;
    Predictor &operator=(const Predictor &) = delete;

    /// Presumably collects into `dests` the nodes reached by applying rule `r`
    /// from start node `vst`, ignoring `removed_triplet`. TODO confirm against
    /// the definition.
    void rule_dest(Rule *r, int vst, std::vector<int> &dests, Triplet removed_triplet);

    /// Training work for one worker; `thread` is presumably the worker index.
    void learn_thread(int thread);
    /// Thread-start trampoline for learn_thread (POSIX `void *(void *)` shape).
    static void *learn_thread_caller(void *arg);
    /// Runs training with the given hyperparameters on `_num_threads` workers.
    void learn(double _learning_rate, double _weight_decay, double _temperature, double _portion, int _num_threads);

    /// Evaluation work for one worker; `thread` is presumably the worker index.
    void evaluate_thread(int thread);
    /// Thread-start trampoline for evaluate_thread (POSIX `void *(void *)` shape).
    static void *evaluate_thread_caller(void *arg);
    /// Runs evaluation on `_num_threads` workers and returns the aggregated Result.
    Result evaluate(int _num_threads);
};

