{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "316f0514",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# pickle5 is used by the training cell to save the results table.\n",
    "import pickle5\n",
    "\n",
    "# Pin this kernel to one specific MIG GPU instance (machine-specific UUID -- change it for your machine).\n",
    "# NOTE(review): CUDA reads this variable only when it initializes, so it must be set before torch\n",
    "# touches the GPU, i.e. before the training cell below is run.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"MIG-dc67f6ae-4c27-5869-bcd2-8560a2da46c7\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d2e28869",
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_params(LSTM_layers = 1, batch_size = 8, beta = 0.15, Lambda = 0.15, learning_rate = 5e-5):\n",
    "    \"\"\"Build the options dict passed to `Trainer` for one NELL-995 run.\n",
    "\n",
    "    Args:\n",
    "        LSTM_layers: stored as options['LSTM_layers'] (agent setting).\n",
    "        batch_size: stored as options['batch_size']; also determines\n",
    "            options['total_iterations'] (see below).\n",
    "        beta: stored as options['beta'].\n",
    "        Lambda: stored as options['Lambda'].\n",
    "        learning_rate: stored as options['learning_rate'].\n",
    "\n",
    "    Returns:\n",
    "        dict of settings, including the relation and entity vocabularies\n",
    "        parsed from the JSON files under options['vocab_dir'].\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: if a vocab file is missing.\n",
    "        ZeroDivisionError: if batch_size is 0.\n",
    "    \"\"\"\n",
    "    # Import explicitly instead of relying on a star import in another cell\n",
    "    # to have put `json` into the notebook namespace.\n",
    "    import json\n",
    "    import os\n",
    "\n",
    "    options = {}\n",
    "\n",
    "    # basic setting\n",
    "    options['use_cuda'] = True\n",
    "    options['vocab_dir'] = '../MINERVA/datasets/data_preprocessed/nell-995/vocab/'\n",
    "    options['data_input_dir'] = '../MINERVA/datasets/data_preprocessed/nell-995/'\n",
    "    options['device'] = 'cuda' if options['use_cuda'] else 'cpu'\n",
    "    # `with` closes each vocab file as soon as it has been parsed.\n",
    "    with open(os.path.join(options['vocab_dir'], 'relation_vocab.json')) as f:\n",
    "        options['relation_vocab'] = json.load(f)\n",
    "    with open(os.path.join(options['vocab_dir'], 'entity_vocab.json')) as f:\n",
    "        options['entity_vocab'] = json.load(f)\n",
    "    # Keep the trailing '/': callers append file names with `+`.\n",
    "    options['model_dir'] = './outputs_NELL-sub-tune/'\n",
    "    options['output_dir'] = './outputs_NELL-sub-tune/'\n",
    "\n",
    "    # agent setting\n",
    "    options['pretrained_embeddings_relation'] = {}\n",
    "    options['pretrained_embeddings_entity'] = {}\n",
    "    options['embedding_size'] = 50\n",
    "    options['hidden_size'] = 200\n",
    "    options['use_entity_embeddings'] = 1\n",
    "    options['train_entity_embeddings'] = 1\n",
    "    options['train_relation_embeddings'] = 1\n",
    "    options['path_length'] = 3\n",
    "    options['LSTM_layers'] = LSTM_layers\n",
    "    options['max_num_actions'] = 40\n",
    "    options['gnn_layer'] = 1\n",
    "\n",
    "    # hyperparameters\n",
    "    options['test_rollouts'] = 40\n",
    "    options['num_rollouts'] = 20\n",
    "    options['batch_size'] = batch_size\n",
    "    options['eval_batch_size'] = 32\n",
    "    options['beta'] = beta\n",
    "    options['Lambda'] = Lambda\n",
    "    options['gamma'] = 1\n",
    "    options['positive_reward'] = 1\n",
    "    options['negative_reward'] = 0\n",
    "    options['learning_rate'] = learning_rate\n",
    "    options['grad_clip_norm'] = 100\n",
    "    options['eval_every'] = 200\n",
    "    # True division, so this is a float; total_iterations * batch_size always equals 2000 * 64.\n",
    "    options['total_iterations'] = 2000*(64/batch_size)\n",
    "    options['pool'] = 'max'\n",
    "\n",
    "    return options"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc64eafe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n",
      "Reading vocab...\n",
      "batcher loaded\n",
      "KG constructed\n",
      "Reading vocab...\n",
      "Contains full graph\n"
     ]
    }
   ],
   "source": [
    "import gc\n",
    "\n",
    "import torch\n",
    "\n",
    "from model.ours2 import *  # presumably provides Trainer\n",
    "\n",
    "# Hyperparameter grid; the nested loops below visit it in exactly this order.\n",
    "LSTM_LAYER_GRID = [1, 2]\n",
    "BETA_LAMBDA_GRID = [0.1, 0.08, 0.12, 0.05, 0.15, 0.02, 0.18]  # one value used for both beta and Lambda\n",
    "BATCH_SIZE_GRID = [32, 64, 128]\n",
    "LEARNING_RATE_GRID = [1e-4, 5e-5, 5e-4, 1e-3, 1e-5]\n",
    "\n",
    "results = {}  # run name -> return value of trainer.test(...)\n",
    "for layer in LSTM_LAYER_GRID:\n",
    "    for bl in BETA_LAMBDA_GRID:\n",
    "        for bs in BATCH_SIZE_GRID:\n",
    "            for lr in LEARNING_RATE_GRID:\n",
    "                params = set_params(LSTM_layers=layer, batch_size=bs, beta=bl, Lambda=bl, learning_rate=lr)\n",
    "                name = f'{layer}-{bs}-{bl}-{bl}-{lr}'\n",
    "\n",
    "                trainer = Trainer(params)\n",
    "                try:\n",
    "                    trainer.train()\n",
    "                    torch.cuda.empty_cache()\n",
    "\n",
    "                    # All runs share params['model_dir'], so this loads the checkpoint written by the\n",
    "                    # train() call above -- assumes train() always saves one (TODO confirm); otherwise a\n",
    "                    # stale checkpoint from an earlier run would be evaluated.\n",
    "                    trainer.agent.load_state_dict(torch.load(params['model_dir'] + 'agent.ckpt'))\n",
    "                    trainer.agent.eval()\n",
    "                    trainer.test_environment = trainer.test_test_environment\n",
    "                    test_metrics = trainer.test(beam=True, print_paths=False, save_model=False)\n",
    "\n",
    "                    print(name)\n",
    "                    print(test_metrics)\n",
    "                    print('-------------')\n",
    "                    results[name] = test_metrics\n",
    "\n",
    "                    # Re-save the whole table after every run so finished results survive an interrupt.\n",
    "                    with open(params['output_dir'] + 'results_table.pk5', 'wb') as f:\n",
    "                        pickle5.dump(results, f)\n",
    "                finally:\n",
    "                    # Release the trainer and cached GPU memory even if this run raised.\n",
    "                    del trainer\n",
    "                    gc.collect()\n",
    "                    torch.cuda.empty_cache()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
