#include <iostream>
#include <fstream>
#include <set>
#include <ctime>
#include <assert.h>
#include <map>
#include <sys/time.h>
#include "utils.h"
#include "kg.h"
#include "rule.h"
#include "ruletrainer.h"
#include "predictor.h"
using namespace std;


// argv = data_folder, reweight_type, update_function_clip, alpha, n_threads
int main(int argc, char *argv[])
{
    struct timeval t1, t2;
    double timeuse;
    gettimeofday(&t1,NULL);
    
    if (argc < 6)
    {
        cout<<"Wrong arguments!\n";
        return 0;
    }

    string folder = argv[1];
    string reweight_type = argv[2];
    string update_function_clip = argv[3];
    int alpha = atoi(argv[4]);
    int n_thread = atoi(argv[5]);


    srand(time(0));
    int v_from, v_to;
    string ruleoutput;
    KnowledgeGraph *KG = new KnowledgeGraph();
    KG->load_data(folder + "/train.txt");
    KG->load_valid_data(folder + "/valid.txt");
    KG->load_test_data(folder + "/test.txt");

    Graph *G = new Graph(KG->n_entity, KG->n_relation_with_reverse());
    for (auto it = KG->triplets.begin(); it != KG->triplets.end(); it++)
    {
        G->add_edge(it->h, it->t, it->r);
    }



    RuleTrainer x = RuleTrainer(G);
    x.negative_sample_all();
    // x.negative_sample(5);

    vector<Rule> *rules = new vector<Rule>[G->n_edge_type];

    real_t ratio;
    if (reweight_type == "Y")
        ratio = 0.8;
    else if (reweight_type == "N")
        ratio = 0.0;
    else
    {
        cout<<"Wrong argument!\n";
        return 0;
    }
    
    x.train(400, 3, 15, 1, n_thread, rules, 0.01, alpha, 0.2, ratio);



// read rules from txt------------------------------------
// cout<<"Reading rules\n";
// ifstream fin(ruleoutput);
// Rule input;
// double flin;
// int tmp;
// int cnt = 0;
// int step = 0;
// while (fin>>input.head)
// {
//     cout<<step<<'\n';
//     step++;
//     input.r_body.clear();
//     fin>>flin;
//     input.wt.set_init_value(flin);
//     for (int i = 0; i < 3; i++)
//     {
//         fin>>tmp;
//         input.r_body.push_back(tmp);
//     }
//     auto p = x.accuracy(&input);
//     input.wt.var.value = (real_t)p.first/(real_t)p.second;
//     rules[input.head].push_back(input);
// }
// cout<<"Finishied\n";
// -------------------------------------------------------

// output rules in txt-------------------------------------
    // ofstream fout(ruleoutput);
    // for (int k = 0; k < G->n_edge_type; k++)
    // {
    //     for (auto it = rules[k].begin(); it != rules[k].end(); it++)
    //     {
    //         fout<<it->head<<" "<<it->wt.var.value<<" ";
    //         for (auto it2 = it->r_body.begin(); it2 != it->r_body.end(); it2++)
    //         {
    //             fout<<*it2<<" ";
    //         }
    //         fout<<"\n";
    //         // fout<<it->wt.var.value<<" ";
    //     }
    // }
    // fout.close();
// -------------------------------------------------------

// output rules in str-------------------------------------------------
    // ofstream fout(ruleoutput);
    // fout<<"[";
    // for (int i = 0; i < G->n_edge_type; i++)
    // {
    //     fout<<"[";
    //     for (int j = 0; j < rules[i].size(); j++)
    //     {
    //         fout<<"[";
    //         fout<<"\""<<KG->id2r[rules[i][j].head]<<"\",";
    //         for (int k = 0; k < rules[i][j].r_body.size(); k++)
    //             fout<<"\""<<KG->id2r[rules[i][j].r_body[k]]<<"\",";
    //         fout<<"],\n";
    //     }
    //     fout<<"],";
    // }
    // fout<<"]";
    // fout.close();
// ------------------------------------------------------------

    cout<<"\n";
    cout<<"Rule Num: ";
    int rrr = 0;
    for (int i = 0; i < G->n_edge_type; i++)
    {
        cout<<" "<<rules[i].size();
        if (rules[i].size() == 0) rrr++;
    }
    cout<<"\n";
    cout<<rrr<<endl;
    cout<<endl;

    bool clipping;
    if (update_function_clip == "Y")
        clipping = true;
    else if (update_function_clip == "N")
        clipping = false;
    else
    {
        cout<<"Wrong argument!\n";
        return 0;
    }

    Predictor pred = Predictor(rules, KG, G, clipping);
    pred.n_epoch = 1;
    pred.test = true;
    Result result = pred.evaluate(n_thread);
    printf("Iter 0, MR: %lf, MRR: %lf, Hit@1: %lf, Hit@3: %lf, Hit@10: %lf.\n", result.mr, result.mrr, result.h1, result.h3, result.h10);
    for (int i = 0; i < 30; i++)
    {
        pred.learn(0.01, 0.0005, 100, 1.0, n_thread);
        Result result = pred.evaluate(n_thread);
        printf("Iter %d, MR: %lf, MRR: %lf, Hit@1: %lf, Hit@3: %lf, Hit@10: %lf.\n", (i+1)*pred.n_epoch, result.mr, result.mrr, result.h1, result.h3, result.h10);
    }


    vector<int> test_count(G->n_edge_type);
    for (auto it = KG->test_triplets.begin(); it != KG->test_triplets.end(); it++)
    {
        test_count[it->r]++;
    }

    // This is for summarizing rule distributions. --------------------------
    // ofstream foutcont("output.txt");
    // for (int i = 0; i < G->n_edge_type; i++)
    // {
    //     for (int j = 0; j < rules[i].size(); j++)
    //     {
    //         foutcont<<rules[i][j].wt.var.value<<" "<<rules[i][j].contribution/real_t(test_count[i])<<"\n";
    //     }
    // }
    // foutcont.close();
    // -----------------------------------------------------------------------

    gettimeofday(&t2,NULL);
    timeuse = (t2.tv_sec - t1.tv_sec) + (double)(t2.tv_usec - t1.tv_usec)/1000000.0;

    cout<<"time = "<<timeuse<<endl; 
    delete[] rules;
    return 0;
}