{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ce63a777",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "# NOTE(review): pickle5 is not referenced by any cell visible in this notebook;\n",
    "# confirm it is still needed before removing it.\n",
    "import pickle5\n",
    "\n",
    "# Select the GPU before torch is imported. The value is a hard-coded,\n",
    "# machine-specific device identifier and must be changed on another host.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"MIG-6abedaa4-16cd-51b2-9b2f-043073ed897a\"\n",
    "\n",
    "# json and torch are used directly by later cells, so import them explicitly\n",
    "# instead of relying on them leaking in through the wildcard import below.\n",
    "import torch\n",
    "\n",
    "# Wildcard import that brings Trainer (used in a later cell) into scope.\n",
    "from model.ours2 import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "398c7676",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single configuration dict handed to Trainer in the next cell.\n",
    "options = {}\n",
    "\n",
    "# Basic settings: device selection plus dataset, vocabulary and output locations.\n",
    "options['use_cuda'] = True\n",
    "options['vocab_dir'] = '../MINERVA/datasets/data_preprocessed/FB15K-237/vocab/'\n",
    "options['data_input_dir'] = '../MINERVA/datasets/data_preprocessed/FB15K-237/'\n",
    "options['device'] = 'cuda' if options['use_cuda'] else 'cpu'\n",
    "# Read the vocabulary JSON files inside context managers so the file handles\n",
    "# are closed as soon as parsing finishes instead of being left open.\n",
    "with open(options['vocab_dir'] + '/relation_vocab.json') as relation_vocab_file:\n",
    "    options['relation_vocab'] = json.load(relation_vocab_file)\n",
    "with open(options['vocab_dir'] + '/entity_vocab.json') as entity_vocab_file:\n",
    "    options['entity_vocab'] = json.load(entity_vocab_file)\n",
    "options['model_dir'] = './outputs_FB15K-237/'\n",
    "options['output_dir'] = './outputs_FB15K-237/'\n",
    "\n",
    "# Agent settings\n",
    "options['pretrained_embeddings_relation'] = {}\n",
    "options['pretrained_embeddings_entity'] = {}\n",
    "options['embedding_size'] = 50\n",
    "options['hidden_size'] = 200\n",
    "options['use_entity_embeddings'] = 1\n",
    "options['train_entity_embeddings'] = 0\n",
    "options['train_relation_embeddings'] = 1\n",
    "options['path_length'] = 3\n",
    "options['LSTM_layers'] = 1\n",
    "options['max_num_actions'] = 100\n",
    "options['gnn_layer'] = 1\n",
    "\n",
    "# Hyperparameters\n",
    "options['test_rollouts'] = 100\n",
    "options['num_rollouts'] = 20\n",
    "options['batch_size'] = 8\n",
    "options['eval_batch_size'] = 12\n",
    "options['beta'] = 0.08\n",
    "options['Lambda'] = 0.08\n",
    "options['gamma'] = 1\n",
    "options['positive_reward'] = 1\n",
    "options['negative_reward'] = 0\n",
    "options['learning_rate'] = 5e-5\n",
    "options['grad_clip_norm'] = 100\n",
    "options['eval_every'] = 100\n",
    "options['total_iterations'] = 16000\n",
    "options['pool'] = 'max'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "27aa7441",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n"
     ]
    }
   ],
   "source": [
    "# Build the trainer from the options dict configured in the previous cell.\n",
    "trainer = Trainer(options)\n",
    "# Training is left disabled here: the next cell loads a saved checkpoint\n",
    "# (agent.ckpt under options['model_dir']) instead of training from scratch.\n",
    "# trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d6a50465",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/lib/python3.10/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='none' instead.\n",
      "  warnings.warn(warning.format(ret))\n",
      "/root/Research/GraphRL/experiments/model/ours2.py:334: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  return loss, new_state, F.log_softmax(scores), label_action, chosen_relation\n",
      "/root/Research/GraphRL/experiments/model/ours2.py:636: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n",
      "  y = idx // self.max_num_actions\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hits@1: 0.2493, Hits@3: 0.3342, Hits@10: 0.4219, MRR: 0.3060\n"
     ]
    }
   ],
   "source": [
    "# Restore the agent weights from agent.ckpt under options['model_dir'].\n",
    "# map_location remaps the stored tensors onto options['device'], so the load\n",
    "# also works on a CPU-only host (use_cuda = False) for a GPU-saved checkpoint.\n",
    "# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.\n",
    "trainer.agent.load_state_dict(\n",
    "    torch.load(options['model_dir'] + 'agent.ckpt', map_location=options['device'])\n",
    ")\n",
    "trainer.agent.eval()\n",
    "\n",
    "# Evaluate on trainer.test_test_environment (presumably the test split;\n",
    "# confirm in Trainer) with beam=True, without printing paths or saving a model.\n",
    "trainer.test_environment = trainer.test_test_environment\n",
    "test_results = trainer.test(beam=True, print_paths=False, save_model=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "cb5b6cee",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hits@1: 0.2493, Hits@3: 0.3342, Hits@10: 0.4219, MRR: 0.3060\n"
     ]
    }
   ],
   "source": [
    "# Show the value returned by trainer.test() in the previous cell.\n",
    "print(test_results)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
