#include "utils.h"

// Bundles an untyped pointer with an integer id.
ArgStruct::ArgStruct(void *_ptr, int _id)
    : ptr(_ptr), id(_id)
{
}

// Strict weak ordering on triplets: compare r first, then h, then t.
bool operator < (Triplet u, Triplet v)
{
    if (u.r != v.r)
        return u.r < v.r;
    if (u.h != v.h)
        return u.h < v.h;
    return u.t < v.t;
}
    
// Two triplets are equal when their h, t and r fields all match.
bool operator == (Triplet u, Triplet v)
{
    return u.h == v.h && u.t == v.t && u.r == v.r;
}

// Creates a trainable scalar.
//   randinit: when true the value is drawn from randn(); otherwise it is 0.
// All optimizer state (m, v, vm, t), the gradient and the weight-decay
// anchor `record` start at zero.
Parameter::Parameter(bool randinit)
{
    t = 0;
    m = 0;
    v = 0;
    vm = 0;
    record = 0;

    if (randinit)
        var.value = randn();
    else
        var.value = 0.0;
    var.grad = 0;
}

// Default-constructs an empty rule.
// `head` is set to -1, the same "no head" sentinel that Rule::clear()
// uses. Previously it was left uninitialized, although operator== and
// operator< both read it.
// The weight starts as Parameter(false), i.e. zero rather than random, and
// the accumulated contribution starts at 0.
Rule::Rule()
{
    head = -1;
    wt = Parameter(false);
    contribution = 0;
}

// Resets the pending gradient and the optimizer state (moments m/v, vm,
// step counter t). The current value and the `record` anchor are kept.
void Parameter::clear()
{
    var.grad = 0;
    t = 0;
    vm = 0;
    v = 0;
    m = 0;
}

// Rules are equal when they share the same head and the same body.
// The weight and contribution are not compared.
bool operator == (const Rule& a, const Rule& b)
{
    return a.head == b.head && a.r_body == b.r_body;
}

// Orders rules by head, breaking ties by body.
bool operator < (const Rule &a, const Rule &b)
{
    if (a.head != b.head)
        return a.head < b.head;
    return a.r_body < b.r_body;
}

// Sets the current value to x and also records x as the anchor that
// weight decay in update() pulls the value back toward.
void Parameter::set_init_value(real_t x)
{
    record = x;
    var.value = x;
}

// Performs one Adam step (beta1 = 0.9, beta2 = 0.999, eps = 1e-8) on the
// parameter using the accumulated gradient var.grad, then zeroes the
// gradient.
//   lr:           step size.
//   weight_decay: strength of an L2 penalty pulling var.value toward
//                 `record` (the value given to set_init_value, 0 by
//                 default).
//   maximize:     when true, step along the gradient (ascent); otherwise
//                 step against it (descent).
void Parameter::update(real_t lr, real_t weight_decay, bool maximize)
{
    // The penalty must always pull toward `record`. For ascent on
    // f - wd/2*(x-record)^2 the gradient is grad - wd*(x-record); for
    // descent on f + wd/2*(x-record)^2 it is grad + wd*(x-record).
    // Previously the descent case also used the minus sign, which pushed
    // the value away from `record`.
    real_t decay = weight_decay * (var.value - record);
    real_t g = maximize ? var.grad - decay : var.grad + decay;

    t += 1;
    m = 0.9 * m + 0.1 * g;
    v = 0.999 * v + 0.001 * g * g;

    // Bias corrections for the zero-initialized moment estimates.
    real_t bias1 = 1 - std::pow(0.9, t);
    real_t bias2 = 1 - std::pow(0.999, t);

    real_t mt = m / bias1;
    real_t vt = sqrt(v) / sqrt(bias2) + 0.00000001;

    if (maximize)
        var.value += lr * mt / vt;
    else
        var.value -= lr * mt / vt;
    var.grad = 0;
}

// A variable starts with zero value and zero accumulated gradient.
Variable::Variable()
    : value(0), grad(0)
{
}

// Resets the rule to an empty state: head -1, empty body, and a fresh
// zero-initialized weight. Note that `contribution` is left unchanged.
void Rule::clear()
{
    head = -1;
    r_body.clear();
    wt = Parameter(false);
}

// Builds an empty adjacency structure with one std::vector<int> bucket per
// (vertex, edge type) pair: linklist[v][e] holds the targets added by
// add_edge(v, target, e).
//
// If a row allocation throws, the rows already allocated and the
// row-pointer array are released before the exception propagates. The
// destructor does not run for a partially constructed object, so without
// this they would leak.
//
// NOTE(review): ~Graph() frees linklist. Unless utils.h deletes or defines
// the copy constructor/assignment, copying a Graph double-deletes; confirm.
Graph::Graph(int _n_vertex, int _n_edge_type)
{
    n_vertex = _n_vertex;
    n_edge_type = _n_edge_type;
    linklist = new std::vector<int> *[n_vertex];
    int built = 0;
    try
    {
        for (; built < n_vertex; built++)
        {
            linklist[built] = new std::vector<int>[n_edge_type];
        }
    }
    catch (...)
    {
        for (int i = 0; i < built; i++)
        {
            delete[] linklist[i];
        }
        delete[] linklist;
        throw;
    }
    // Presumably the id distance between an edge type and its inverse
    // (the upper half of the type ids being inverses). TODO: confirm.
    offset = _n_edge_type / 2;
}

// Releases every per-vertex bucket array, then the row-pointer array.
Graph::~Graph()
{
    int row = 0;
    while (row < n_vertex)
    {
        delete[] linklist[row];
        ++row;
    }
    delete[] linklist;
}

// Appends v_to to the bucket of edges of type edge_type leaving v_from.
// No bounds checking: v_from must be in [0, n_vertex) and edge_type in
// [0, n_edge_type).
void Graph::add_edge(int v_from, int v_to, int edge_type)
{
    std::vector<int> &targets = linklist[v_from][edge_type];
    targets.push_back(v_to);
}

// Deliberately inverted: u "precedes" v when u has the larger val, so
// sorting with this operator yields descending order of val.
bool operator < (RankListEntry u, RankListEntry v)
{
    return v.val < u.val;
}

// All metrics (mr, mrr, h1, h3, h10) start at zero.
Result::Result()
{
    mr = mrr = 0;
    h1 = h3 = h10 = 0;
}

// Stores the given metrics; the argument order is mr, mrr, h1, h3, h10.
Result::Result(double mr_, double mrr_, double h1_, double h3_, double h10_)
{
    mr = mr_;
    mrr = mrr_;
    h1 = h1_;
    h3 = h3_;
    h10 = h10_;
}