#include "predictor.h"

#include <algorithm>
#include <cstdlib>
#include <random>
#include <vector>

// Builds a predictor over the rule sets `_rules` (indexed by relation id,
// 0 .. g->n_edge_type-1), the knowledge graph `kg` and the graph `g` used for
// rule grounding. `clipping` makes rule_dest report each destination once.
Predictor::Predictor(std::vector<Rule> *_rules, KnowledgeGraph *kg, Graph *g, bool clipping)
{
    // Default hyper-parameters. learn() overwrites learning_rate, weight_decay,
    // temperature, portion and n_thread; evaluate() overwrites n_thread.
    n_thread = 4; top_k = 100; n_epoch = 1;
    temperature = 100; learning_rate = 0.01; weight_decay = 0.0005;
    portion = 1.0;
    prior_weight = 0; H_temperature = 1;
    total_count = 0; total_loss = 0;
    test = true;
    ranks.clear();
    p_kg = kg;
    G = g;
    clip = clipping;
    rules = _rules;
    sem_init(&mutex, 0, 1);

    // Per-relation mean of the current rule weights.
    mean = new real_t[g->n_edge_type];
    for (int i = 0; i < g->n_edge_type; i++)
    {
        real_t sum = 0.0;
        for (const Rule &rule : _rules[i])
            sum += rule.wt.var.value;
        // A relation without rules would otherwise produce 0/0 = NaN.
        mean[i] = _rules[i].empty() ? real_t(0) : sum / real_t(_rules[i].size());
    }
}

// Releases the semaphore and the per-relation mean array created in the constructor.
// NOTE(review): `mean` is owned through a raw pointer; unless the header
// deletes Predictor's copy operations, copying a Predictor double-deletes it.
Predictor::~Predictor()
{
    // The original called sem_init here, which re-initialises instead of
    // releasing the semaphore.
    sem_destroy(&mutex);
    delete[] mean;
}

// Grounds rule `r` starting at entity `vst`: follows the relations of
// r->r_body in order through G->linklist and appends every entity reached
// after the last body relation to `dests`. The edge
// (removed_triplet.h, removed_triplet.r, removed_triplet.t) is never followed.
//
// The walk enumerates paths, not distinct vertices (there is no visited set),
// so without clipping an entity reached by k paths appears k times. With
// `clip`, `dests` is reduced to its distinct values in ascending order.
void Predictor::rule_dest(Rule *r, int vst, std::vector<int> &dests, Triplet removed_triplet)
{
    const int body_len = int(r->r_body.size());

    // Breadth-first frontier of (entity, number of body relations already followed).
    std::queue< std::pair<int, int> > frontier;
    frontier.push(std::make_pair(vst, 0));
    while (!frontier.empty())
    {
        const int current_e = frontier.front().first;
        const int current_d = frontier.front().second;
        frontier.pop();

        // -1 marks an invalid entity: stop walking. Unlike an early return,
        // this still reaches the clipping step so `dests` keeps its invariant.
        if (current_e == -1) break;

        if (current_d == body_len)
        {
            dests.push_back(current_e);
            continue;
        }

        const int current_r = r->r_body[current_d];
        const auto &neighbours = G->linklist[current_e][current_r];
        for (auto it = neighbours.begin(); it != neighbours.end(); ++it)
        {
            const int next_e = *it;
            const bool is_removed = current_e == removed_triplet.h
                                 && current_r == removed_triplet.r
                                 && next_e == removed_triplet.t;
            if (is_removed) continue;
            frontier.push(std::make_pair(next_e, current_d + 1));
        }
    }

    // Clipping: count each destination once per rule. sort + unique yields
    // the same ascending, duplicate-free order as a std::set without
    // allocating a node per element.
    if (clip)
    {
        std::sort(dests.begin(), dests.end());
        dests.erase(std::unique(dests.begin(), dests.end()), dests.end());
    }
}

void Predictor::learn_thread(int thread)
{
    int triplet_size = p_kg->triplets.size();
    int bg = int(triplet_size / n_thread) * thread;
    int ed = bg + int(triplet_size / n_thread * portion);
    if (thread == n_thread - 1 && portion == 1) ed = triplet_size;
    
    std::vector<int> dests;
    Triplet triplet;
    int h, r, dest, index;
    double logit, target, grad;
    std::map<int, double> dest2logit;
    std::map<int, std::vector<int> > dest2index;
    std::map<int, double>::iterator iter;
    
    for (int epoch = 0; epoch < n_epoch; epoch++)
    {
    for (int T = bg; T < ed; T++)
    {
        if (T % 10 == 0)
        {
            total_count += 10;
            printf("Learning Rule Weights: %.3lf%% | Loss: %.6lf          %c", (double)total_count / (double)(triplet_size * portion * n_epoch + 1) * 100, total_loss / (total_count), 13);
            fflush(stdout);
        }
        
        h = p_kg->triplets[T].h;
        r = p_kg->triplets[T].r;

                
        dest2logit.clear();
        dest2index.clear();
        for (index = 0; index != int(rules[r].size()); index++)
        {
            dests.clear();
            rule_dest(&rules[r][index], h, dests, p_kg->triplets[T]);
            for (int i = 0; i != int(dests.size()); i++)
            {
                dest = dests[i];
                if (dest2logit.count(dest) == 0) dest2logit[dest] = 0;
                if (dest2index.count(dest) == 0) dest2index[dest] = std::vector<int>();
                dest2logit[dest] += rules[r][index].wt.var.value / temperature;
                dest2index[dest].push_back(index);
            }
        }

        double max_val = -1000000, sum_val = 0;
        for (iter = dest2logit.begin(); iter != dest2logit.end(); iter++)
            max_val = std::max(max_val, iter->second);
        for (iter = dest2logit.begin(); iter != dest2logit.end(); iter++)
            sum_val += exp(iter->second - max_val);
        for (iter = dest2logit.begin(); iter != dest2logit.end(); iter++)
            dest2logit[iter->first] = exp(dest2logit[iter->first] - max_val) / sum_val;
        for (iter = dest2logit.begin(); iter != dest2logit.end(); iter++)
        {
            dest = iter->first;
            logit = iter->second;
            
            triplet = p_kg->triplets[T];
            triplet.t = dest;
            if (std::count(G->linklist[triplet.h][triplet.r].begin(), G->linklist[triplet.h][triplet.r].end(), triplet.t))
                target = 1;
            else
                target = 0;
            grad = (target - logit) / temperature;
            
            double add = (target - logit) / dest2logit.size();
            if (add < 0)
                add = -add;
            total_loss += add;

            for (int k = 0; k != int(dest2index[dest].size()); k++)
            {
                index = dest2index[dest][k];
                rules[r][index].wt.var.grad = grad;
                rules[r][index].wt.update(learning_rate, weight_decay, true);
            }
        }
    }
    }
    dest2logit.clear();
    dest2index.clear();
    dests.clear();
}

void *Predictor::learn_thread_caller(void *arg)
{
    Predictor *ptr = (Predictor *)(((ArgStruct *)arg)->ptr);
    int thread = ((ArgStruct *)arg)->id;
    ptr->learn_thread(thread);
    pthread_exit(NULL);
}

// Trains the rule weights: stores the hyper-parameters, shuffles
// p_kg->triplets, runs `_n_thread` learn_thread workers on disjoint slices and
// prints the mean reported loss.
void Predictor::learn(double _learning_rate, double _weight_decay, double _temperature, double _portion, int _n_thread)
{
    learning_rate = _learning_rate;
    weight_decay = _weight_decay;
    temperature = _temperature;
    portion = _portion;
    n_thread = _n_thread;

    total_count = 0;
    total_loss = 0;

    // std::random_shuffle was removed in C++17. Seeding the engine from
    // std::rand() keeps the shuffle controlled by any srand() the program calls.
    std::mt19937 rng(std::rand());
    std::shuffle(p_kg->triplets.begin(), p_kg->triplets.end(), rng);

    std::vector<pthread_t> threads(n_thread > 0 ? n_thread : 0);
    for (int k = 0; k < int(threads.size()); k++)
        pthread_create(&threads[k], NULL, Predictor::learn_thread_caller, new ArgStruct(this, k));
    for (int k = 0; k < int(threads.size()); k++)
        pthread_join(threads[k], NULL);

    // No progress tick happens for an empty training set; avoid 0/0.
    const double avg_loss = total_count > 0 ? total_loss / total_count : 0.0;
    printf("Learning Rule Weights | DONE! | Loss: %.6lf                             \n", avg_loss);
}

void Predictor::evaluate_thread(int thread)
{
    std::vector<Triplet> *p_triplets;
    p_triplets = &(p_kg->test_triplets);
    
    int triplet_size = int((*p_triplets).size());

    
    int bg = int(triplet_size / n_thread) * thread;
    int ed = int(triplet_size / n_thread) * (thread + 1);
    if (thread == n_thread - 1) ed = triplet_size;
    
    std::vector<int> dests;
    Triplet triplet;
    int h, r, t, dest, num_g, num_ge, index;
    double t_val;
    RankListEntry *rank_list;
    
    rank_list = new RankListEntry [G->n_vertex];
    // std::ofstream fout(std::to_string(thread));

    for (int T = bg; T != ed; T++)
    {
        if (T % 10 == 0)
        {
            total_count += 10;
            printf("Evaluation | Progress: %.3lf%%          %c", (double)total_count / (double)(triplet_size + 1) * 100, 13);
            fflush(stdout);
        }
        
        h = (*p_triplets)[T].h;
        r = (*p_triplets)[T].r;
        t = (*p_triplets)[T].t;
        
        real_t m = 0.0;
        for (index = 0; index != int(rules[r].size()); index++)
        {
            m += rules[r][index].wt.var.value;
        }
        m /= real_t(rules[r].size());

        for (int k = 0; k < G->n_vertex; k++)
        {
            rank_list[k].id = k;
            rank_list[k].val = 0;
        }
        
        for (index = 0; index != int(rules[r].size()); index++)
        {
            dests.clear();
            rule_dest(&rules[r][index], h, dests, (*p_triplets)[T]);
            for (int i = 0; i != int(dests.size()); i++)
            {
                dest = dests[i];
                rank_list[dest].val += rules[r][index].wt.var.value;
                if (dest == t)
                {
                    rules[r][index].contribution += rules[r][index].wt.var.value;
                }
            }

        }

        t_val = rank_list[t].val;
//debug
// std::cout<<"Weights: "<<t_val<<"\n";
        
// output
        // triplet.h = h;
        // triplet.r = r;
        // triplet.t = t;
        // fout<<h<<" "<<r<<" "<<t<<" ";
        // for (index = 0; index < G->n_vertex; index++)
        // {
        //     fout<<rank_list[index].val<<" ";
        // }
        // fout<<"\n";
        // for (int k = 0; k != G->n_vertex; k++)
        // {
        //     triplet.t = k;
        //     if (p_kg->all_triplets.count(triplet))
        //         fout<<k<<" ";
        // }
        // fout<<"\n";


        std::sort(rank_list, rank_list + G->n_vertex);


        num_g = 0; num_ge = 0;
        triplet = (*p_triplets)[T];
        for (int k = 0; k != G->n_vertex; k++)
        {
            triplet.t = rank_list[k].id;
            if (p_kg->all_triplets.count(triplet) && rank_list[k].id != t)
                continue;

            if (rank_list[k].val > t_val) num_g += 1;
            if (rank_list[k].val >= t_val) num_ge += 1;
            if (rank_list[k].val < t_val) break;
        }

        sem_wait(&mutex);
        ranks.push_back(std::make_pair(num_g, num_ge));
//debug
// std::cout<<num_g<<" "<<num_ge<<"\n";

        sem_post(&mutex);
    }

    delete []rank_list;
    dests.clear();
    pthread_exit(NULL);
}

void *Predictor::evaluate_thread_caller(void *arg)
{
    Predictor *ptr = (Predictor *)(((ArgStruct *)arg)->ptr);
    int thread = ((ArgStruct *)arg)->id;
    ptr->evaluate_thread(thread);
    pthread_exit(NULL);
}

// Runs `_num_threads` evaluate_thread workers and aggregates their filtered
// ranks into MR, MRR, Hits@1, Hits@3 and Hits@10. For a tie range
// (num_g, num_ge] each metric is averaged uniformly over the ranks
// num_g+1 .. num_ge. With no evaluated triplets the divisions yield NaN.
Result Predictor::evaluate(int _num_threads)
{
    n_thread = _num_threads;

    ranks.clear();
    total_count = 0;
    total_loss = 0;
    // The constructor already initialised the semaphore; sem_init on an
    // initialised semaphore is undefined, so destroy it before re-creating it.
    sem_destroy(&mutex);
    sem_init(&mutex, 0, 1);

    std::vector<pthread_t> threads(n_thread > 0 ? n_thread : 0);
    for (int k = 0; k < int(threads.size()); k++)
        pthread_create(&threads[k], NULL, Predictor::evaluate_thread_caller, new ArgStruct(this, k));
    for (int k = 0; k < int(threads.size()); k++)
        pthread_join(threads[k], NULL);
    if (test) printf("Evaluation Test | DONE!                              \n");
    else printf("Evaluation Valid | DONE!                              \n");

    // Prefix sums over rank = 1..num_entities, so the sum of a metric over a
    // tie range (g, ge] is table[ge] - table[g].
    const int num_entities = G->n_vertex;
    std::vector<double> table_mr(num_entities + 1, 0.0);
    std::vector<double> table_mrr(num_entities + 1, 0.0);
    std::vector<double> table_hit1(num_entities + 1, 0.0);
    std::vector<double> table_hit3(num_entities + 1, 0.0);
    std::vector<double> table_hit10(num_entities + 1, 0.0);
    for (int rank = 1; rank <= num_entities; rank++)
    {
        table_mr[rank] = rank + table_mr[rank - 1];
        table_mrr[rank] = 1.0 / rank + table_mrr[rank - 1];
        table_hit1[rank] = (rank <= 1 ? 1.0 : 0.0) + table_hit1[rank - 1];
        table_hit3[rank] = (rank <= 3 ? 1.0 : 0.0) + table_hit3[rank - 1];
        table_hit10[rank] = (rank <= 10 ? 1.0 : 0.0) + table_hit10[rank - 1];
    }

    // Mean of a metric over the tied ranks num_g+1 .. num_ge.
    auto tie_mean = [](const std::vector<double> &table, int num_g, int num_ge) {
        return (table[num_ge] - table[num_g]) / (num_ge - num_g);
    };

    double mr = 0, mrr = 0, hit1 = 0, hit3 = 0, hit10 = 0;
    for (const auto &rank : ranks)
    {
        mr += tie_mean(table_mr, rank.first, rank.second);
        mrr += tie_mean(table_mrr, rank.first, rank.second);
        hit1 += tie_mean(table_hit1, rank.first, rank.second);
        hit3 += tie_mean(table_hit3, rank.first, rank.second);
        hit10 += tie_mean(table_hit10, rank.first, rank.second);
    }

    const double n_ranks = double(ranks.size());
    mr /= n_ranks;
    mrr /= n_ranks;
    hit1 /= n_ranks;
    hit3 /= n_ranks;
    hit10 /= n_ranks;

    Result result(mr, mrr, hit1, hit3, hit10);
    return result;
}
