#pragma once
#include <semaphore.h>

#include <iostream>
#include <queue>
#include <set>
#include <utility>
#include <vector>

#include "utils.h"
#include "rule.h"


/**
 * Trains rules over a Graph.
 *
 * Only the declaration lives in this header; the member functions are defined
 * elsewhere, so the per-member notes below describe what the signatures show
 * and hedge anything that depends on the implementation.
 *
 * All raw-pointer members start out as nullptr and all hyper-parameters start
 * out value-initialised, so a member that has not been set up yet never holds
 * an indeterminate value.
 *
 * NOTE(review): the class declares a destructor and holds raw pointers plus a
 * sem_t, but declares no copy/move operations. A compiler-generated copy would
 * share the same pointers and duplicate the semaphore object -- avoid copying
 * instances; confirm against the destructor whether copies should be deleted.
 */
class RuleTrainer
{
public:

    // Graph this trainer operates on -- presumably the pointer handed to the
    // constructor; ownership is not visible from this header.
    Graph *G = nullptr;

    // vto[e][v_from] = neighbors of v_from via e.
    std::vector<int> **vto = nullptr;
    // Same two-level layout as vto -- presumably weight[e][v_from] holds one
    // weight per entry of vto[e][v_from]; TODO confirm.
    std::vector<real_t> **weight = nullptr;
    // Two-level integer table; judging by the name, positive counts -- the
    // index meaning is not visible here, verify against the implementation.
    int **npos = nullptr;
    // Integer arrays; judging by the names, totals of positive / negative
    // samples -- TODO confirm what they are indexed by.
    int *total_pos = nullptr;
    int *total_neg = nullptr;
    // Array of int vectors; presumably source vertices -- index meaning not
    // visible here, TODO confirm.
    std::vector<int> *vfrom = nullptr;

    // Training hyper-parameters. They mirror the _lr / _power / _norm / _ratio
    // parameters of train(), which presumably assigns them.
    real_t lr{};
    real_t power{};
    real_t norm{};
    real_t ratio{};

    // POSIX semaphore; presumably serialises the worker threads' access to
    // `tasks` -- TODO confirm.
    sem_t mutex;
    // Array of dRule objects; presumably per-thread scratch used by training.
    dRule *dRs = nullptr;
    // Queue of pending integer task ids; what an id denotes is decided by the
    // implementation.
    std::queue<int> tasks;
    // Three-level boolean table; presumably filled in by setup_nonzero().
    bool ***nonzero = nullptr;

    // Argument bundle carrying the training parameters together with a
    // back-pointer to the trainer -- presumably what train_thread_caller()
    // receives through its void* argument.
    struct Args
    {
        int n_rule;
        int len_rule;
        int n_epoch;
        int n_batch_per_epoch;
        std::vector<Rule> *rules;
        RuleTrainer *ruletrainer;
    };

    // Builds a trainer for the graph _G.
    RuleTrainer(Graph* _G);
    ~RuleTrainer();

    // Returns a list of vertex ids for start vertex vst and step count n_step
    // (default 3) -- presumably the vertices reachable within n_step steps.
    std::vector<int> end_set(int vst, int n_step=3);
    // Fills the output set `dest` for rule r starting from vertex vst.
    void rule_dest(Rule *r, int vst, std::set<int> &dest);
    // Two accuracy measures for rule r, each returned as a pair of ints; the
    // meaning of the two components is defined by the implementation.
    std::pair<int, int> accuracy(Rule *r);
    std::pair<int, int> accuracy2(Rule *r);
    // Re-weights using rule r; the meaning of the bool result is defined by
    // the implementation.
    bool reweight(Rule *r);
    // Precondition: all rules share the same rule head.
    bool reweight_multirules(std::vector<Rule> *rules);
    // Negative sampling: one variant parameterised by a step count, the other
    // by a sampling rate.
    void negative_sample_all(int n_step=3);
    void negative_sample(real_t rate=3.0);

    // Presumably populates `nonzero` for walks of up to n_step steps.
    void setup_nonzero(int n_step=3);

    // Computes the pair (l0, l1) for head_type using dR. batch_st and
    // batch_size both default to 0.
    std::pair<real_t, real_t> l_value(int head_type, dRule *dR, int batch_st=0, int batch_size=0);
    // Trains a single rule for head_type and returns it.
    Rule train_one_rule(int head_type, int n_epoch, int n_batch_per_epoch, dRule *dR);
    // Trains n_rule rules for head_type; `rule` is presumably the output list.
    void train_one_head(int n_rule, int n_epoch, int n_batch_per_epoch, int head_type, std::vector<Rule> *rule, dRule *dR);
    // Static entry point with the void*(void*) shape used for thread start
    // routines; presumably unpacks an Args and calls train_thread().
    static void *train_thread_caller(void *arg);
    // Worker body for the thread with index `thread`.
    void train_thread(int n_rule, int len_rule, int n_epoch, int n_batch_per_epoch, std::vector<Rule> *rules, int thread);
    // Top-level training entry point using n_thread threads; `rules` is
    // presumably where the learned rules are stored.
    void train(int n_rule, int len_rule, int n_epoch, int n_batch_per_epoch, int n_thread, std::vector<Rule> *rules, real_t _lr=1e-2, real_t _power=2.0, real_t _norm=0.2, real_t _ratio=0.8);
};



