#include "ruletrainer.h"


// Builds the per-edge-type training tables from the adjacency lists of _G.
//
// For every edge type k and source vertex v this fills:
//   vto[k][v]     - the targets of v's outgoing k-edges (the positive examples),
//   weight[k][v]  - one weight per entry of vto[k][v], initialised to 1.0,
//   npos[k][v]    - the number of positive targets of v,
//   vfrom[k]      - the list of vertices having at least one outgoing k-edge,
//   total_pos[k]  - the total number of k-edges,
//   total_neg[k]  - zero (only negative_sample() increments it).
// Finally the task-queue semaphore is initialised with value 1.
RuleTrainer::RuleTrainer(Graph* _G)
{
    G = _G;
    dRs = NULL;

    const int n_type = _G->n_edge_type;
    const int n_node = _G->n_vertex;

    vto = new std::vector<int> *[n_type];
    npos = new int *[n_type];
    weight = new std::vector<real_t> *[n_type];
    vfrom = new std::vector<int>[n_type];
    total_pos = new int[n_type];
    total_neg = new int[n_type];

    for (int type = 0; type < n_type; ++type)
    {
        vto[type] = new std::vector<int>[n_node];
        npos[type] = new int[n_node];
        weight[type] = new std::vector<real_t>[n_node];
        total_pos[type] = 0;
        total_neg[type] = 0;

        for (int src = 0; src < n_node; ++src)
        {
            const auto &out_edges = G->linklist[src][type];
            const int degree = out_edges.size();

            npos[type][src] = degree;
            if (degree > 0)
            {
                // Only vertices with at least one edge of this type are training sources.
                vfrom[type].push_back(src);
                total_pos[type] += degree;
            }

            for (auto edge = out_edges.begin(); edge != out_edges.end(); ++edge)
            {
                vto[type][src].push_back(*edge);
                weight[type][src].push_back(1.0);
            }
        }
    }

    sem_init(&mutex, 0, 1);
}

// Releases every table allocated by the constructor, the per-thread dRule
// array (if train() created one) and the task-queue semaphore.
//
// G must still be alive here: its n_edge_type is needed to walk the tables.
RuleTrainer::~RuleTrainer()
{
    for (int k = 0; k < G->n_edge_type; k++)
    {
        delete[] vto[k];
        delete[] npos[k];
        delete[] weight[k];
    }
    delete[] vto;
    delete[] npos;
    delete[] weight;
    delete[] vfrom;
    // These two counters are allocated in the constructor as well and were
    // previously leaked.
    delete[] total_pos;
    delete[] total_neg;
    // delete[] on a null pointer is a no-op, so no guard is required.
    delete[] dRs;
    // The semaphore must be destroyed here, not re-initialised.
    sem_destroy(&mutex);
}

// Returns, in ascending order and without duplicates, the vertices that are
// reached from vst by walks of exactly n_step edges, following edges of any
// type. With n_step == 0 the result is just {vst}.
std::vector<int> RuleTrainer::end_set(int vst, int n_step)
{
    std::set<int> frontier;
    frontier.insert(vst);

    for (int step = 0; step < n_step; ++step)
    {
        // Expand the whole current frontier by one edge.
        std::set<int> reached;
        for (int v : frontier)
        {
            for (int k = 0; k < G->n_edge_type; ++k)
            {
                for (const auto &to : G->linklist[v][k])
                    reached.insert(to);
            }
        }
        frontier.swap(reached);
    }

    return std::vector<int>(frontier.begin(), frontier.end());
}

// Collects into dest every vertex reached from vst by following the relation
// sequence r->r_body edge by edge. dest is cleared first.
//
// The traversal is breadth-first over (vertex, number of relations consumed)
// pairs; the same pair may be enqueued several times when multiple paths lead
// to it. If a vertex id of -1 is dequeued the function returns immediately,
// leaving dest with whatever was collected up to that point.
void RuleTrainer::rule_dest(Rule *r, int vst, std::set<int> &dest)
{
    dest.clear();

    const int body_len = int(r->r_body.size());

    std::queue< std::pair<int, int> > pending;
    pending.push(std::make_pair(vst, 0));

    while (!pending.empty())
    {
        const int node = pending.front().first;
        if (node == -1) return;
        const int depth = pending.front().second;
        pending.pop();

        if (depth == body_len)
        {
            // The whole body has been matched: node is a destination.
            dest.insert(node);
            continue;
        }

        const int rel = r->r_body[depth];
        const auto &successors = G->linklist[node][rel];
        for (auto edge = successors.begin(); edge != successors.end(); ++edge)
        {
            const int next = *edge;
            pending.push(std::make_pair(next, depth + 1));
        }
    }
}

// Evaluates rule r over every vertex of the graph.
//
// Returns {hits, reached} where
//   hits    = number of positive (source, target) pairs of relation r->head
//             whose target is among the rule's destinations from that source,
//   reached = total number of destinations produced by the rule, summed over
//             all sources.
std::pair<int, int> RuleTrainer::accuracy(Rule *r)
{
    int hits = 0;
    int reached = 0;

    std::set<int> dest;
    for (int src = 0; src < G->n_vertex; ++src)
    {
        rule_dest(r, src, dest);

        const std::vector<int> &targets = vto[r->head][src];
        const int n_true = npos[r->head][src];
        for (int i = 0; i < n_true; ++i)
        {
            // Only the first npos entries of vto are positives.
            hits += dest.count(targets[i]);
        }
        reached += dest.size();
    }

    return std::pair<int, int>{hits, reached};
}

// Evaluates how many positives of relation r->head rule r recovers.
//
// Returns {hits, positives} where hits is the number of positive
// (source, target) pairs whose target is among the rule's destinations from
// that source, and positives is the total number of positive pairs.
std::pair<int, int> RuleTrainer::accuracy2(Rule *r)
{
    int hits = 0;
    int positives = 0;

    std::set<int> dest;
    for (int src = 0; src < G->n_vertex; ++src)
    {
        rule_dest(r, src, dest);

        const std::vector<int> &targets = vto[r->head][src];
        const int n_true = npos[r->head][src];
        for (int i = 0; i < n_true; ++i)
        {
            hits += dest.count(targets[i]);
        }
        positives += n_true;
    }

    return std::pair<int, int>{hits, positives};
}

// Re-weights the training examples of relation r->head after rule r has been
// learned, so that examples the rule already covers matter less next time.
//
// With r == NULL every weight of every relation is reset to 1.0 instead.
//
// Otherwise, for each source vertex of r->head:
//   - a positive example covered by the rule has its weight multiplied by
//     (1 - ratio);
//   - a negative example covered by the rule has its weight divided by
//     (1 - ratio + 1e-10); uncovered negatives are divided by (1 + 1e-10).
// Positives and negatives are then rescaled separately so that each group's
// weight sum equals its sum before the update (the new sum is clamped to at
// least 1e-5 before dividing).
//
// Always returns true.
bool RuleTrainer::reweight(Rule *r)
{
    if (r == NULL)
    {
        for (int k = 0; k < G->n_edge_type; ++k)
        {
            for (std::size_t s = 0; s < vfrom[k].size(); ++s)
            {
                std::vector<real_t> &w = weight[k][vfrom[k][s]];
                for (std::size_t i = 0; i < w.size(); ++i)
                    w[i] = 1.0;
            }
        }
        return true;
    }

    const std::vector<int> &sources = vfrom[r->head];

    real_t pos_before = 0, neg_before = 0, pos_after = 0, neg_after = 0;
    std::set<int> dest;

    // Pass 1: apply the coverage penalty and record the weight sums.
    for (std::size_t s = 0; s < sources.size(); ++s)
    {
        const int src = sources[s];
        rule_dest(r, src, dest);

        const std::vector<int> &targets = vto[r->head][src];
        std::vector<real_t> &w = weight[r->head][src];
        const int n_true = npos[r->head][src];
        const int n_all = (int)w.size();

        for (int i = 0; i < n_true; ++i)
        {
            pos_before += w[i];
            w[i] *= 1 - ratio * dest.count(targets[i]);
            pos_after += w[i];
        }
        for (int i = n_true; i < n_all; ++i)
        {
            neg_before += w[i];
            w[i] /= 1 - ratio * dest.count(targets[i]) + 1e-10;
            neg_after += w[i];
        }
    }

    const real_t pos_scale = pos_before / std::max(pos_after, 1e-5);
    const real_t neg_scale = neg_before / std::max(neg_after, 1e-5);

    // Pass 2: restore the original mass of each group.
    for (std::size_t s = 0; s < sources.size(); ++s)
    {
        const int src = sources[s];
        std::vector<real_t> &w = weight[r->head][src];
        const int n_true = npos[r->head][src];
        const int n_all = (int)w.size();

        for (int i = 0; i < n_true; ++i)
            w[i] *= pos_scale;
        for (int i = n_true; i < n_all; ++i)
            w[i] *= neg_scale;
    }
    return true;
}

bool RuleTrainer::reweight_multirules(std::vector<Rule> *rules)
{
    int n_rule = rules->size();
    int type = (*rules)[0].head;
    for (auto ptrvfrom = vfrom[type].begin(); ptrvfrom != vfrom[type].end(); ptrvfrom++)
    {
        for (int i = 0; i < npos[type][*ptrvfrom]; i++)
        {
            weight[type][*ptrvfrom][i] = 1e-3;
        }
        for (int i = npos[type][*ptrvfrom]; i < (int)weight[type][*ptrvfrom].size(); i++)
        {
            weight[type][*ptrvfrom][i] = 1e-3;
        }
    }

    std::set<int> dest;
    for (auto r = rules->begin(); r != rules->end(); r++)
    {
        for (auto ptrvfrom = vfrom[r->head].begin(); ptrvfrom != vfrom[r->head].end(); ptrvfrom++)
        {
            rule_dest(&*r, *ptrvfrom, dest);
            for (int i = 0; i < npos[r->head][*ptrvfrom]; i++)
            {
                weight[r->head][*ptrvfrom][i] += r->wt.var.value * dest.count(vto[r->head][*ptrvfrom][i]);
            }
            for (int i = npos[r->head][*ptrvfrom]; i < (int)weight[r->head][*ptrvfrom].size(); i++)
            {
                weight[r->head][*ptrvfrom][i] += r->wt.var.value * dest.count(vto[r->head][*ptrvfrom][i]);
            }
        }
    }

    // real_t mpos=1e-5, mneg=1e-5;
    // for (auto ptrvfrom = vfrom[type].begin(); ptrvfrom != vfrom[type].end(); ptrvfrom++)
    // {
    //     for (int i = 0; i < npos[type][*ptrvfrom]; i++)
    //         mpos = std::max(mpos, weight[type][*ptrvfrom][i]);
    //     for (int i = npos[type][*ptrvfrom]; i < (int)weight[type][*ptrvfrom].size(); i++)
    //         mneg = std::max(mneg, weight[type][*ptrvfrom][i]);
    // }
    for (auto ptrvfrom = vfrom[type].begin(); ptrvfrom != vfrom[type].end(); ptrvfrom++)
    {
        for (int i = 0; i < npos[type][*ptrvfrom]; i++)
        {
            // weight[type][*ptrvfrom][i] = std::max(1e-5, weight[type][*ptrvfrom][i]);
            weight[type][*ptrvfrom][i] = 1 / weight[type][*ptrvfrom][i];
        }
        for (int i = npos[type][*ptrvfrom]; i < (int)weight[type][*ptrvfrom].size(); i++)
        {
            // weight[type][*ptrvfrom][i] = std::max(1e-5, weight[type][*ptrvfrom][i]);
            // weight[type][*ptrvfrom][i] /= mneg;
        }
    }


    return true;
}

// Adds exhaustive negative examples: for every vertex i and edge type k, each
// vertex reachable from i in exactly n_step steps (see end_set) that is not a
// k-neighbour of i is appended to vto[k][i] with weight 1.0.
//
// total_neg is not updated by this function.
void RuleTrainer::negative_sample_all(int n_step)
{
    // Heap-backed flag table; a stack array sized by n_vertex is non-standard
    // C++ and can overflow the stack on large graphs.
    std::vector<char> is_candidate(G->n_vertex, 1);

    for (int i = 0; i < G->n_vertex; i++)
    {
        const std::vector<int> reachable = end_set(i, n_step);
        for (int k = 0; k < G->n_edge_type; k++)
        {
            const auto &positives = G->linklist[i][k];

            // Temporarily mask out the true k-neighbours of i.
            for (auto ptrvto = positives.begin(); ptrvto != positives.end(); ++ptrvto)
                is_candidate[*ptrvto] = 0;

            for (auto it = reachable.begin(); it != reachable.end(); ++it)
            {
                if (is_candidate[*it] == 1)
                {
                    vto[k][i].push_back(*it);
                    weight[k][i].push_back(1.0);
                }
            }

            // Undo the mask so the table is all-ones again for the next pair.
            for (auto ptrvto = positives.begin(); ptrvto != positives.end(); ++ptrvto)
                is_candidate[*ptrvto] = 1;
        }
    }
}

// Adds randomly sampled negative examples.
//
// rate == 0 delegates to negative_sample_all(). Otherwise, for every edge type
// k and every source vertex with positives, up to rate * (current number of
// targets) random vertices are drawn with rand(); a draw that hits a vertex
// already listed for this source is discarded (not retried). Accepted draws
// are appended to vto[k][src] and counted in total_neg[k].
//
// Afterwards every negative of type k receives the weight
//   (n_vertex^2 - total_pos[k]) / total_neg[k].
// Edge types for which no negative was counted are left untouched.
void RuleTrainer::negative_sample(real_t rate)
{
    if (rate == 0)
    {
        negative_sample_all();
        return;
    }

    // Owned by a vector so it is released on every exit path (the raw array
    // used previously was never freed).
    std::vector<char> visited(G->n_vertex, 0);

    for (int k = 0; k < G->n_edge_type; k++)
    {
        for (auto ptrvfrom = vfrom[k].begin(); ptrvfrom != vfrom[k].end(); ++ptrvfrom)
        {
            std::vector<int> &targets = vto[k][*ptrvfrom];

            for (std::size_t i = 0; i < targets.size(); i++)
                visited[targets[i]] = 1;

            const int num = rate * targets.size();
            for (int cnt = 0; cnt < num; cnt++)
            {
                const int randvto = rand() % G->n_vertex;
                if (visited[randvto] == 1) continue;
                visited[randvto] = 1;
                targets.push_back(randvto);
                weight[k][*ptrvfrom].push_back(-1); // placeholder, set below
                total_neg[k]++;
            }

            // Clear the marks of both the old targets and the new negatives.
            for (std::size_t i = 0; i < targets.size(); i++)
                visited[targets[i]] = 0;
        }
    }

    const real_t n_vertex = (real_t)G->n_vertex;
    for (int k = 0; k < G->n_edge_type; k++)
    {
        // Nothing was sampled for this type; dividing would give an infinite weight.
        if (total_neg[k] == 0)
            continue;

        // The product is formed in floating point: n_vertex * n_vertex
        // overflows int for graphs with more than ~46k vertices.
        const real_t w = (n_vertex * n_vertex - (real_t)total_pos[k]) / (real_t)total_neg[k];
        for (auto ptrvfrom = vfrom[k].begin(); ptrvfrom != vfrom[k].end(); ++ptrvfrom)
        {
            std::vector<real_t> &wts = weight[k][*ptrvfrom];
            for (int i = npos[k][*ptrvfrom]; i < (int)wts.size(); i++)
                wts[i] = w;
        }
    }
}

std::pair<real_t, real_t> RuleTrainer::l_value(int head_type, dRule *dR, int batch_st, int batch_size)
{
    real_t l0 = 0.0, l1 = 0.0;
    std::vector<real_t> val;
    int batch_ed = std::min(batch_st + batch_size, (int)vfrom[head_type].size());
    // for (auto ptrvfrom = vfrom[head_type].begin(); ptrvfrom != vfrom[head_type].end(); ptrvfrom++)
    for (int vf = 0; vf < vfrom[head_type].size(); vf++)
    {
        val = dR->forward(vfrom[head_type][vf], vto[head_type][vfrom[head_type][vf]]);
        for (int i = 0; i < npos[head_type][vfrom[head_type][vf]]; i++)
        {
            l0 += weight[head_type][vfrom[head_type][vf]][i] * val[i];
        }
        for (int i = npos[head_type][vfrom[head_type][vf]]; i < (int)vto[head_type][vfrom[head_type][vf]].size(); i++)
        {
            l1 += weight[head_type][vfrom[head_type][vf]][i] * val[i];
        }
    }
// debug
// std::cout<<"l: "<<l0<<" "<<l1<<"\n";

    return std::pair<real_t, real_t>{l0, l0 + l1};
}

// Trains the differentiable rule dR for relation head_type and returns the
// rule extracted from it via dR->to_rule(), with its head set to head_type.
//
// Each epoch shuffles a private copy of the source vertices and runs
// n_batch_per_epoch batches of size (number of sources / n_batch_per_epoch).
// Per batch:
//   1. dR->forward_R(), then l_value() gives {P, P + N};
//   2. if P is zero the remaining batches of this epoch are skipped;
//   3. the per-example gradient factors are
//          negatives: (1 - norm) / (P + N)
//          positives: -1 / P + (1 - norm) / (P + N)
//      each multiplied by the example weight;
//   4. for every source of the batch: dR->forward(), dR->backward(grad);
//   5. dR->backward_R() and dR->update(lr).
// Sources left over after the last batch of an epoch are not visited.
Rule RuleTrainer::train_one_rule(int head_type, int n_epoch, int n_batch_per_epoch, dRule *dR)
{
    dR->reset(nonzero[head_type]);
    dR->set_power(power);

    const int batch_size = vfrom[head_type].size() / n_batch_per_epoch;
    // Shuffle a copy so the member list keeps its original order.
    std::vector<int> shuffled = vfrom[head_type];
    std::vector<real_t> grad;

    for (int epoch = 0; epoch < n_epoch; ++epoch)
    {
        std::random_shuffle(shuffled.begin(), shuffled.end());
        std::size_t cursor = 0;

        for (int batch = 0; batch < n_batch_per_epoch; ++batch)
        {
            dR->forward_R();
            const std::pair<real_t, real_t> l = l_value(head_type, dR, batch * batch_size, batch_size);
            if (l.first == 0) break;

            const real_t neg_factor = (1.0 - norm) / l.second;
            const real_t pos_factor = -1.0 / l.first + neg_factor;

            for (int j = 0; j < batch_size && cursor < shuffled.size(); ++j, ++cursor)
            {
                const int src = shuffled[cursor];
                std::vector<int> &targets = vto[head_type][src];
                const std::vector<real_t> &w = weight[head_type][src];
                const int n_true = npos[head_type][src];
                const int n_all = (int)targets.size();

                dR->forward(src, targets);

                grad.clear();
                for (int i = 0; i < n_true; ++i)
                    grad.push_back(w[i] * pos_factor);
                for (int i = n_true; i < n_all; ++i)
                    grad.push_back(w[i] * neg_factor);

                dR->backward(grad);
            }

            dR->backward_R();

            dR->update(lr);
        }
    }

    Rule ret = dR->to_rule();
    ret.head = head_type;
    return ret;
}

void RuleTrainer::train_one_head(int n_rule, int n_epoch, int n_batch_per_epoch, int head_type, std::vector<Rule> *rule, dRule *dR)
{
// original-------------------------------------------------------------
    std::set< std::vector<int> > s;
    Rule r;

    for (int i = 0; i < n_rule; i++)
    {
        r = train_one_rule(head_type, n_epoch, n_batch_per_epoch, dR);
        if (!reweight(&r))  continue;
        s.insert(r.r_body);
    }

    for (auto it = s.begin(); it != s.end(); it++)
    {
        r.r_body = *it;
        auto accs = accuracy(&r);
        real_t neg_acc = real_t(total_pos[head_type]-accs.first) / real_t(G->n_vertex*G->n_vertex-accs.second);
        r.wt.set_init_value((real_t)accs.first / (real_t)accs.second);
        rule->push_back(r);
    }

    // output rules for each head
    // std::string path = "./rules/";
    // path += std::to_string(head_type);
    // std::ofstream fout(path);
    // for (auto it = rule->begin(); it != rule->end(); it++)
    // {
    //     fout<<it->head<<" "<<it->wt.var.value;
    //     for (auto rr = it->r_body.begin(); rr != it->r_body.end(); rr++)
    //     {
    //         fout<<" "<<*rr;
    //     }
    //     fout<<"\n";
    // }
    // fout.close();
// ---------------------------------------------------------------------

// new------------------------------------------------------------------
    // std::set< std::vector<int> > s;
    // Rule r;

    // int n_rule_per_round = 400;
    // int n_round = n_rule / n_rule_per_round;
    // for (int round = 0; round < n_round; round++)
    // {
    //     for (int i = 0; i < n_rule_per_round; i++)
    //     {
    //         r = train_one_rule(head_type, n_epoch, n_batch_per_epoch, dR);
    //         if (accuracy(&r).first == 0)  continue;
    //         s.insert(r.r_body);
    //     }

    //     rule->clear();
    //     for (auto it = s.begin(); it != s.end(); it++)
    //     {
    //         r.r_body = *it;
    //         auto accs = accuracy(&r);
    //         r.wt.set_init_value((real_t)accs.first / (real_t)accs.second);
    //         rule->push_back(r);
    //     }
    //     reweight2(rule);
    // }
// ---------------------------------------------------------------------


    // sem_wait(&mutex);
    // sem_post(&mutex);
// write rules
    // std::string path = "./data/umls/rules/";
    // path += std::to_string(head_type);
    // std::ofstream fout(path);
    // for (auto it = rule->begin(); it != rule->end(); it++)
    // {
    //     fout<<it->head;
    //     for (auto it2 = it->r_body.begin(); it2 != it->r_body.end(); it2++)
    //     {
    //         fout<<" "<<*it2;
    //     }
    //     fout<<"\n";
    // }
}

// Worker loop: repeatedly takes the next relation id from the shared task
// queue and learns its rules into rules[id] using this thread's own dRule
// (dRs[thread]). Returns once the queue is empty.
//
// The queue is only touched between sem_wait and sem_post on mutex.
// len_rule is not used here.
void RuleTrainer::train_thread(int n_rule, int len_rule, int n_epoch, int n_batch_per_epoch, std::vector<Rule> *rules, int thread)
{
    for (;;)
    {
        int head = 0;

        sem_wait(&mutex);
        const bool drained = tasks.empty();
        if (!drained)
        {
            head = tasks.front();
            tasks.pop();
        }
        sem_post(&mutex);

        if (drained)
            return;

        // Progress line; character 13 is a carriage return so the line is
        // overwritten by the next report.
        printf("Learning Rules %.3lf%% |  #Relation %d/%d          %c", (real_t)(head + 1) / (real_t)(G->n_edge_type) * 100, head + 1, G->n_edge_type, 13);
        train_one_head(n_rule, n_epoch, n_batch_per_epoch, head, &rules[head], &dRs[thread]);
        fflush(stdout);
    }
}

void *RuleTrainer::train_thread_caller(void *arg)
{
    Args *args = (Args *)(((ArgStruct *)arg)->ptr);
    int thread = ((ArgStruct *)arg)->id;
    args->ruletrainer->train_thread(args->n_rule, args->len_rule, args->n_epoch, args->n_batch_per_epoch, args->rules, thread);
    pthread_exit(NULL);
}

void RuleTrainer::train(int n_rule, int len_rule, int n_epoch, int n_batch_per_epoch, int n_thread, std::vector<Rule> *rules, real_t _lr, real_t _power, real_t _norm, real_t _ratio)
{
    // set up nonzero--------------------------------------------
    nonzero = new bool**[G->n_edge_type];
    for (int k = 0; k < G->n_edge_type; k++)
    {
        nonzero[k] = new bool*[len_rule];
        for (int l = 0;l < len_rule; l++)
        {
            nonzero[k][l] = new bool[G->n_edge_type];
            for (int i = 0; i < G->n_edge_type; i++)
                nonzero[k][l][i] = false;
        }
    }
    int vto[G->n_vertex];
    std::queue<std::pair<int, int> > Q;

    for (int vs = 0; vs < G->n_vertex; vs++)
    {
        while (!Q.empty()) Q.pop();
        std::fill(vto, vto + G->n_vertex, -1);

        Q.push({vs, 0});
        Q.push({-1, -1});
        int step = 0;
        while (!Q.empty())
        {
            auto p = Q.front();
            Q.pop();
            if (p.first == -1)
            {
                step++;
                Q.push({-1, -1});
                if (step < len_rule)
                    continue;
                else
                    break;
            }

            for (int k = 0; k < G->n_edge_type; k++)
            {
                for (auto ptrvto = G->linklist[p.first][k].begin(); ptrvto != G->linklist[p.first][k].end(); ptrvto++)
                {
                    Q.push({*ptrvto, p.second * G->n_edge_type + k});
                }
            }
        }
        while (!Q.empty())
        {
            auto p = Q.front();
            Q.pop();
            if (p.first != -1)
            {
                vto[p.first] = p.second;
            }
        }

        for (int head_type = 0; head_type < G->n_edge_type; head_type++)
        {
            for (auto ptrvto = G->linklist[vs][head_type].begin(); ptrvto != G->linklist[vs][head_type].end(); ptrvto++)
            {
                int path = vto[*ptrvto];
                if (path == -1)
                    continue;

                for (int i = 0; i < len_rule; i++)
                {
                    int rel = path % G->n_edge_type;
                    path /= G->n_edge_type;
                    nonzero[head_type][len_rule - 1 - i][rel] = true;
                }
            }
        }
    }
    // ----------------------------------------------------------

    norm = _norm;
    lr = _lr;
    power = _power;
    ratio = _ratio;
    if (dRs != NULL)
    {
        delete[] dRs;
        dRs = NULL;
    }
    dRs = new dRule[n_thread];
    for (int i = 0; i < n_thread; i++)
        new(dRs + i)dRule(len_rule, G, true);

    while (!tasks.empty())
        tasks.pop();
    //debug 
    for (int i = 0; i < G->n_edge_type; i++)
        tasks.push(i);

    Args *args = new Args{n_rule, len_rule, n_epoch, n_batch_per_epoch, rules, this};

    pthread_t *pt = (pthread_t *)malloc(n_thread * sizeof(pthread_t));
    for (int i = 0; i < n_thread; i++)
        pthread_create(&pt[i], NULL, RuleTrainer::train_thread_caller, new ArgStruct(args, i));
    for (int i = 0; i < n_thread; i++)
        pthread_join(pt[i], NULL);

    printf("Learning Rules | DONE!                                              \n");

    free(pt);
    delete args;
    for (int k = 0; k < G->n_edge_type; k++)
    {
        for (int l = 0;l < len_rule; l++)
            delete[] nonzero[k][l];
        delete[] nonzero[k];
    }
    delete[] nonzero;
    nonzero = NULL;
}
