{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# os/logging are used by later cells; import them explicitly instead of relying on `from utils import *`.\n",
    "import argparse\n",
    "import logging\n",
    "import os\n",
    "\n",
    "from torch.utils.data   import DataLoader\n",
    "from utils              import *\n",
    "from model              import *\n",
    "from dataloader         import TrainDataset\n",
    "from dataloader         import BidirectionalOneShotIterator\n",
    "\n",
    "\n",
    "def construct_args(argv=None):\n",
    "    \"\"\"Build the LAMAKE run configuration.\n",
    "\n",
    "    Args:\n",
    "        argv: Optional list of command-line tokens to parse. Defaults to an\n",
    "            empty list (not sys.argv) so the notebook does not pick up the\n",
    "            Jupyter kernel's own command line.\n",
    "\n",
    "    Returns:\n",
    "        argparse.Namespace in which ``data_path`` has been rewritten to\n",
    "        ``<data_path>/<dataset>`` and ``save_path`` defaults to\n",
    "        ``<process_path>/<dataset>/checkpoints/<model>`` when not given.\n",
    "    \"\"\"\n",
    "    parser = argparse.ArgumentParser(description='LAMAKE')\n",
    "    # Data paths\n",
    "    parser.add_argument('--data_path', type=str, default='../data', help='Path to the dataset')\n",
    "    parser.add_argument('--process_path', type=str, default='/data/pj20/lamake_data', help='Path to the entity hierarchy')\n",
    "    parser.add_argument('--dataset', type=str, default='FB15K-237', help='Dataset name')\n",
    "    parser.add_argument('--hierarchy_type', type=str, default='seed', choices=['seed', 'llm'],  help='Type of hierarchy to use')\n",
    "\n",
    "    # train, valid, test\n",
    "    parser.add_argument('--do_train', action='store_true')\n",
    "    parser.add_argument('--do_valid', action='store_true')\n",
    "    parser.add_argument('--do_test',  action='store_true')\n",
    "    parser.add_argument('--evaluate_train', action='store_true', help='Evaluate on training data')\n",
    "\n",
    "    parser.add_argument('--countries', action='store_true', help='Use Countries S1/S2/S3 datasets')\n",
    "    parser.add_argument('--regions', type=int, nargs='+', default=None,\n",
    "                        help='Region Id for Countries S1/S2/S3 datasets, DO NOT MANUALLY SET')\n",
    "\n",
    "    # Model settings\n",
    "    parser.add_argument('-de', '--double_entity_embedding', action='store_true')\n",
    "    parser.add_argument('-dr', '--double_relation_embedding', action='store_true')\n",
    "\n",
    "    parser.add_argument('-n', '--negative_sample_size', default=128, type=int)\n",
    "    parser.add_argument('-d', '--hidden_dim', default=500, type=int)\n",
    "    parser.add_argument('-g', '--gamma', default=12.0, type=float)\n",
    "    parser.add_argument('-adv', '--negative_adversarial_sampling', action='store_true')\n",
    "    parser.add_argument('-a', '--adversarial_temperature', default=1.0, type=float)\n",
    "    parser.add_argument('-b', '--batch_size', default=1024, type=int)\n",
    "    parser.add_argument('-r', '--regularization', default=0.0, type=float)\n",
    "    parser.add_argument('--test_batch_size', default=4, type=int, help='valid/test batch size')\n",
    "\n",
    "    # Model hyperparameters\n",
    "    parser.add_argument('--model', type=str, default='TransE', help='Knowledge graph embedding model')\n",
    "\n",
    "    # Hyperparameters\n",
    "    parser.add_argument('--rho', type=float, default=0.5, help='Weight for the randomly initialized component')\n",
    "    parser.add_argument('--lambda_1', type=float, default=0.5, help='Weight for the inter-level cluster separation')\n",
    "    parser.add_argument('--lambda_2', type=float, default=0.5, help='Weight for the hierarchical distance maintenance')\n",
    "    parser.add_argument('--lambda_3', type=float, default=0.5, help='Weight for the cluster cohesion')\n",
    "    parser.add_argument('--zeta_1', type=float, default=0.5, help='Weight for the entire hierarchical constraint')\n",
    "    parser.add_argument('--zeta_2', type=float, default=0.5, help='Weight for the text embedding deviation')\n",
    "    parser.add_argument('--zeta_3', type=float, default=0.5, help='Weight for the link prediction score')\n",
    "\n",
    "    # Training settings\n",
    "    parser.add_argument('--num_epochs', type=int, default=1000, help='Number of training epochs')\n",
    "    parser.add_argument('--early_stop', type=int, default=10, help='Number of epochs for early stopping')\n",
    "    parser.add_argument('--cuda', action='store_true', help='Use GPU for training')\n",
    "    parser.add_argument('--uni_weight', action='store_true', help='Use uniform weight for positive and negative samples')\n",
    "\n",
    "    parser.add_argument('-lr', '--learning_rate', default=0.0001, type=float)\n",
    "    parser.add_argument('-cpu', '--cpu_num', default=10, type=int)\n",
    "    parser.add_argument('-init', '--init_checkpoint', default=None, type=str)\n",
    "    parser.add_argument('-save', '--save_path', default=None, type=str)\n",
    "    parser.add_argument('--max_steps', default=100000, type=int)\n",
    "    parser.add_argument('--warm_up_steps', default=None, type=int)\n",
    "\n",
    "    parser.add_argument('--save_checkpoint_steps', default=10000, type=int)\n",
    "    parser.add_argument('--valid_steps', default=10000, type=int)\n",
    "    parser.add_argument('--log_steps', default=100, type=int, help='train log every xx steps')\n",
    "    parser.add_argument('--test_log_steps', default=1000, type=int, help='valid/test log every xx steps')\n",
    "\n",
    "    parser.add_argument('--nentity', type=int, default=0, help='DO NOT MANUALLY SET')\n",
    "    parser.add_argument('--nrelation', type=int, default=0, help='DO NOT MANUALLY SET')\n",
    "\n",
    "    args = parser.parse_args(args=[] if argv is None else argv)\n",
    "\n",
    "    args.data_path = f'{args.data_path}/{args.dataset}'\n",
    "    # Respect an explicit --save_path; derive the default location only otherwise.\n",
    "    if args.save_path is None:\n",
    "        args.save_path = f'{args.process_path}/{args.dataset}/checkpoints/{args.model}'\n",
    "\n",
    "    return args\n",
    "\n",
    "args = construct_args()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Notebook override: construct_args parsed an empty argv, so switch training on here.\n",
    "setattr(args, 'do_train', True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validate the selected run mode(s), then prepare the checkpoint directory and logger.\n",
    "if not (args.do_train or args.do_valid or args.do_test):\n",
    "    raise ValueError('one of train/val/test mode must be chosen.')\n",
    "if args.init_checkpoint:\n",
    "    override_config(args)\n",
    "elif args.data_path is None:\n",
    "    raise ValueError('one of init_checkpoint/data_path must be chosen.')\n",
    "if args.do_train and args.save_path is None:\n",
    "    raise ValueError('Where do you want to save your trained model?')\n",
    "if args.save_path:\n",
    "    # exist_ok=True makes the cell safe to re-run and avoids the exists()/makedirs() race.\n",
    "    os.makedirs(args.save_path, exist_ok=True)\n",
    "\n",
    "set_logger(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _read_id_dict(path):\n",
    "    \"\"\"Read a tab-separated `<id><TAB><name>` file into a {name: int id} dict.\n",
    "\n",
    "    Blank lines are skipped. The file is decoded as UTF-8 explicitly so the\n",
    "    result does not depend on the platform's default encoding.\n",
    "    \"\"\"\n",
    "    name2id = {}\n",
    "    with open(path, encoding='utf-8') as fin:\n",
    "        for raw_line in fin:\n",
    "            line = raw_line.strip()\n",
    "            if not line:\n",
    "                continue\n",
    "            idx, name = line.split('\\t')\n",
    "            name2id[name] = int(idx)\n",
    "    return name2id\n",
    "\n",
    "\n",
    "entity2id = _read_id_dict(os.path.join(args.data_path, 'entities.dict'))\n",
    "id2entity = {v: k for k, v in entity2id.items()}\n",
    "relation2id = _read_id_dict(os.path.join(args.data_path, 'relations.dict'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-05-05 15:06:02,402 INFO     Base Model: TransE\n",
      "2024-05-05 15:06:02,403 INFO     Data Path: ../data/FB15K-237\n",
      "2024-05-05 15:06:02,404 INFO     #entity: 14541\n",
      "2024-05-05 15:06:02,404 INFO     #relation: 237\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-05-05 15:06:02,591 INFO     #train: 272115\n",
      "2024-05-05 15:06:02,605 INFO     #valid: 17535\n",
      "2024-05-05 15:06:02,621 INFO     #test: 20466\n",
      "100%|██████████| 272115/272115 [00:00<00:00, 450734.96it/s]\n",
      "100%|██████████| 17535/17535 [00:00<00:00, 99258.83it/s]\n",
      "100%|██████████| 20466/20466 [00:00<00:00, 868138.08it/s]\n"
     ]
    }
   ],
   "source": [
    "nentity = len(entity2id)\n",
    "nrelation = len(relation2id)\n",
    "\n",
    "args.nentity = nentity\n",
    "args.nrelation = nrelation\n",
    "\n",
    "# Pass values as logging arguments (lazy formatting) rather than pre-formatting with %.\n",
    "logging.info('Base Model: %s', args.model)\n",
    "logging.info('Data Path: %s', args.data_path)\n",
    "logging.info('#entity: %d', nentity)\n",
    "logging.info('#relation: %d', nrelation)\n",
    "\n",
    "train_triples = read_triple(os.path.join(args.data_path, 'train.txt'), entity2id, relation2id)\n",
    "logging.info('#train: %d', len(train_triples))\n",
    "valid_triples = read_triple(os.path.join(args.data_path, 'valid.txt'), entity2id, relation2id)\n",
    "logging.info('#valid: %d', len(valid_triples))\n",
    "test_triples  = read_triple(os.path.join(args.data_path, 'test.txt'),  entity2id, relation2id)\n",
    "logging.info('#test: %d', len(test_triples))\n",
    "\n",
    "# All three splits read the same hierarchy file; build its path once.\n",
    "hierarchy_info_path = os.path.join(args.process_path, args.dataset,\n",
    "                                   f'entity_info_{args.hierarchy_type}_hier.json')\n",
    "entity_info_train = read_entity_info(hierarchy_info_path, train_triples, id2entity)\n",
    "entity_info_valid = read_entity_info(hierarchy_info_path, valid_triples, id2entity)\n",
    "entity_info_test  = read_entity_info(hierarchy_info_path, test_triples,  id2entity)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_train_dataloader(mode):\n",
    "    \"\"\"Build a shuffled DataLoader over the training triples.\n",
    "\n",
    "    Args:\n",
    "        mode: TrainDataset mode string, 'head-batch' or 'tail-batch'.\n",
    "\n",
    "    Returns:\n",
    "        torch DataLoader using TrainDataset.collate_fn and args.batch_size.\n",
    "    \"\"\"\n",
    "    return DataLoader(\n",
    "        TrainDataset(train_triples, nentity, nrelation, args.negative_sample_size, entity_info_train, mode),\n",
    "        batch_size=args.batch_size,\n",
    "        shuffle=True,\n",
    "        num_workers=max(1, args.cpu_num // 2),\n",
    "        collate_fn=TrainDataset.collate_fn,\n",
    "    )\n",
    "\n",
    "\n",
    "if args.do_train:\n",
    "    # Set training dataloader iterator from one loader per mode\n",
    "    train_dataloader_head = build_train_dataloader('head-batch')\n",
    "    train_dataloader_tail = build_train_dataloader('tail-batch')\n",
    "    train_iterator = BidirectionalOneShotIterator(train_dataloader_head, train_dataloader_tail)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[12442,     9,  9097],\n",
       "         [13299,   194, 13616],\n",
       "         [ 1389,   173,  5641],\n",
       "         ...,\n",
       "         [12798,   111,  5835],\n",
       "         [ 5635,     2,  5029],\n",
       "         [   45,   122,  4994]]),\n",
       " tensor([[ 9451, 11360,  9586,  ...,  6641, 12764, 13862],\n",
       "         [13814,  8415, 13967,  ...,  5200, 12163, 14080],\n",
       "         [ 5950, 11913,  9431,  ...,  3454, 11881,  5658],\n",
       "         ...,\n",
       "         [ 5304, 11429, 10302,  ...,  3779, 14390,  4326],\n",
       "         [ 2716,  7173,  7194,  ...,  2645,   658,   180],\n",
       "         [ 1694,  4621,  7333,  ..., 12884, 14031,  2634]]),\n",
       " tensor([0.1961, 0.2236, 0.2236,  ..., 0.0373, 0.0210, 0.1085]),\n",
       " tensor([ 7114, 10154,  8905,  ...,  5430,  7479,  3452]),\n",
       " tensor([ 9260, 10154,  8916,  ..., 10263, 10237,  2786]),\n",
       " tensor([[ 7113,  7112,  7115,  7104,  7116],\n",
       "         [10153, 10151, 10155, 10141, 10152],\n",
       "         [ 8904,  8898,  8906,  8894,  8899],\n",
       "         ...,\n",
       "         [ 5429,  5428,  5431,  5424,  5432],\n",
       "         [ 7478,  7477,  7480,  7473,  7481],\n",
       "         [ 3451,  3450,  3453,  3449,  3454]]),\n",
       " tensor([[ 9259,  9258,  9261,  9257,  9262],\n",
       "         [10153, 10151, 10155, 10141, 10152],\n",
       "         [ 8915,  8911,  8917,  8910,  8912],\n",
       "         ...,\n",
       "         [10262, 10261, 10264, 10259, 10265],\n",
       "         [10236, 10230, 10238, 10229, 10231],\n",
       "         [ 2784,  2783,  2785,  2781,  2787]]),\n",
       " tensor([[ 7113,  7112,  7104,  ...,    -1,    -1,    -1],\n",
       "         [10153, 10151, 10141,  ...,    -1,    -1,    -1],\n",
       "         [ 8904,  8898,  8894,  ...,    -1,    -1,    -1],\n",
       "         ...,\n",
       "         [ 5429,  5428,  5424,  ...,    -1,    -1,    -1],\n",
       "         [ 7478,  7477,  7473,  ...,    -1,    -1,    -1],\n",
       "         [ 3451,  3450,  3449,  ...,    -1,    -1,    -1]]),\n",
       " tensor([[ 9259,  9258,  9257,  ...,    -1,    -1,    -1],\n",
       "         [10153, 10151, 10141,  ...,    -1,    -1,    -1],\n",
       "         [ 8915,  8911,  8910,  ...,    -1,    -1,    -1],\n",
       "         ...,\n",
       "         [10262, 10261, 10259,  ...,    -1,    -1,    -1],\n",
       "         [10236, 10230, 10229,  ...,    -1,    -1,    -1],\n",
       "         [ 2784,  2783,  2781,  ...,    -1,    -1,    -1]]),\n",
       " 'tail-batch')"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity-check one training batch. Note: this consumes that batch from train_iterator.\n",
    "preview_batch = next(train_iterator)\n",
    "preview_batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pre-computed embeddings; presumably passed to KGFIT's entity_text_embeddings /\n",
    "# cluster_embeddings arguments in the model cell below.\n",
    "entity_text_embeddings = read_entity_initial_embedding(args)  # per-entity text embeddings\n",
    "cluster_embeddings     = read_cluster_embeddings(args)         # hierarchy-cluster embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Size of random relation_embedding: torch.Size([237, 500])\n",
      "Size of random entity_embedding_init: torch.Size([14541, 500])\n",
      "Size of entity_text_embeddings: torch.Size([14541, 500])\n",
      "Size of cluster_embeddings: torch.Size([10451, 500])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import numpy                as np\n",
    "import torch.nn             as nn\n",
    "import torch.nn.functional  as F\n",
    "\n",
    "from utils                  import *\n",
    "from dataloader             import *\n",
    "from tqdm                   import tqdm\n",
    "from torch.utils.data       import DataLoader\n",
    "from sklearn.metrics        import average_precision_score\n",
    "\n",
    "\n",
    "class KGFIT(nn.Module):\n",
    "    def __init__(self, base_model, nentity, nrelation, hidden_dim, gamma, \n",
    "                    double_entity_embedding=False, double_relation_embedding=False,\n",
    "                    entity_text_embeddings=None, cluster_embeddings=None, \n",
    "                    rho=0.4, lambda_1=0.5, lambda_2=0.5, lambda_3=0.5, \n",
    "                    zeta_1=0.3, zeta_2=0.2, zeta_3=0.5,\n",
    "                    ):\n",
    "        \n",
    "        super(KGFIT, self).__init__()\n",
    "        self.model_name = base_model\n",
    "        self.nentity = nentity\n",
    "        self.nrelation = nrelation\n",
    "        self.hidden_dim = hidden_dim\n",
    "        self.epsilon = 2.0\n",
    "        \n",
    "        self.gamma = nn.Parameter(\n",
    "            torch.Tensor([gamma]), \n",
    "            requires_grad=False\n",
    "        )\n",
    "        \n",
    "        self.embedding_range = nn.Parameter(\n",
    "            torch.Tensor([(self.gamma.item() + self.epsilon) / hidden_dim]), \n",
    "            requires_grad=False\n",
    "        )\n",
    "        \n",
    "        self.entity_dim = hidden_dim*2 if double_entity_embedding else hidden_dim\n",
    "        self.relation_dim = hidden_dim*2 if double_relation_embedding else hidden_dim\n",
    "        \n",
    "        # Initialize relation embeddings (Equation 7)\n",
    "        self.relation_embedding = nn.Parameter(torch.zeros(nrelation, self.relation_dim))\n",
    "        nn.init.uniform_(\n",
    "            tensor=self.relation_embedding, \n",
    "            a=-self.embedding_range.item(), \n",
    "            b=self.embedding_range.item()\n",
    "        )\n",
    "        print(f\"Size of random relation_embedding: {self.relation_embedding.size()}\")\n",
    "        \n",
    "        # Initialize randomly initialized component of entity embeddings\n",
    "        self.entity_embedding_init = nn.Parameter(torch.zeros(nentity, self.entity_dim))\n",
    "        nn.init.uniform_(\n",
    "            tensor=self.entity_embedding_init, \n",
    "            a=-self.embedding_range.item(), \n",
    "            b=self.embedding_range.item()\n",
    "        )\n",
    "        print(f\"Size of random entity_embedding_init: {self.entity_embedding_init.size()}\")\n",
    "        \n",
    "        ent_text_emb, ent_desc_emb      = torch.chunk(entity_text_embeddings, 2, dim=1)\n",
    "        clus_text_emb, clus_desc_emb    = torch.chunk(cluster_embeddings, 2, dim=1)\n",
    "        \n",
    "        # concatenate ent_text_emb[:self.entity_dim/2] and ent_desc_emb[:self.entity_dim/2], size: (nentity, self.entity_dim)\n",
    "        self.entity_text_embeddings = torch.cat([ent_text_emb[:, :self.entity_dim//2], ent_desc_emb[:, :self.entity_dim//2]], dim=1)\n",
    "        self.entity_text_embeddings.requires_grad = False\n",
    "        print(f\"Size of entity_text_embeddings: {self.entity_text_embeddings.size()}\")\n",
    "        # concatenate clus_text_emb[:self.entity_dim/2] and clus_desc_emb[:self.entity_dim/2], size: (nentity, self.entity_dim)\n",
    "        self.cluster_embeddings     = torch.cat([clus_text_emb[:, :self.entity_dim//2], clus_desc_emb[:, :self.entity_dim//2]], dim=1)\n",
    "        self.cluster_embeddings.requires_grad = False\n",
    "        print(f\"Size of cluster_embeddings: {self.cluster_embeddings.size()}\")\n",
    "        \n",
    "        if base_model == 'pRotatE':\n",
    "            self.modulus = nn.Parameter(torch.Tensor([[0.5 * self.embedding_range.item()]]))\n",
    "        \n",
    "        #Do not forget to modify this line when you add a new model in the \"forward\" function\n",
    "        if base_model not in ['TransE', 'DistMult', 'ComplEx', 'RotatE', 'pRotatE']:\n",
    "            raise ValueError('model %s not supported' % base_model)\n",
    "            \n",
    "        if base_model == 'RotatE' and (not double_entity_embedding or double_relation_embedding):\n",
    "            raise ValueError('RotatE should use --double_entity_embedding')\n",
    "\n",
    "        if base_model == 'ComplEx' and (not double_entity_embedding or not double_relation_embedding):\n",
    "            raise ValueError('ComplEx should use --double_entity_embedding and --double_relation_embedding')\n",
    "        \n",
    "        # Hyperparameters\n",
    "        self.rho = rho              # Hyperparameter controlling the influence of the randomly initialized component in the embedding\n",
    "        \n",
    "        self.lambda_1 = lambda_1    # Hyperparameter controlling the influence of the inter-level cluster separation\n",
    "        self.lambda_2 = lambda_2    # Hyperparameter controlling the influence of the hierarchical distance maintenance\n",
    "        self.lambda_3 = lambda_3    # Hyperparameter controlling the influence of the cluster cohesion\n",
    "        \n",
    "        self.zeta_1 = zeta_1        # Hyperparameter controlling the influence of the entire hierarchical constraint\n",
    "        self.zeta_2 = zeta_2        # Hyperparameter controlling the influence of the text embedding deviation\n",
    "        self.zeta_3 = zeta_3        # Hyperparameter controlling the influence of the link prediction score\n",
    "\n",
    "    @staticmethod\n",
    "    def get_masked_embeddings(indices, embeddings, dim_size):\n",
    "        \"\"\"\n",
    "        Retrieves and applies a mask to embeddings based on provided indices.\n",
    "        \n",
    "        Args:\n",
    "            indices (torch.Tensor): Tensor of indices with possible -1 indicating invalid entries.\n",
    "            embeddings (torch.nn.Parameter): Embeddings from which to select.\n",
    "            dim_size (tuple): The desired dimension sizes of the output tensor.\n",
    "        \n",
    "        Returns:\n",
    "            torch.Tensor: Masked and selected embeddings based on valid indices.\n",
    "        \"\"\"\n",
    "        valid_mask = indices != -1\n",
    "        # Initialize tensor to hold the masked embeddings\n",
    "        masked_embeddings = torch.zeros(*dim_size, dtype=embeddings.dtype, device=embeddings.device)\n",
    "        # Apply mask to filter valid indices\n",
    "        valid_indices = indices[valid_mask]\n",
    "        selected_embeddings = torch.index_select(embeddings, dim=0, index=valid_indices)\n",
    "        # Place selected embeddings back into the appropriate locations\n",
    "        masked_embeddings.view(-1, embeddings.shape[1])[valid_mask.view(-1)] = selected_embeddings\n",
    "        return masked_embeddings\n",
    "\n",
    "\n",
    "    def forward(self, sample, self_cluster_ids, neighbor_clusters_ids, parent_ids, mode='single'):\n",
    "        if mode == 'single':\n",
    "            self_cluster_ids_head, self_cluster_ids_tail = self_cluster_ids\n",
    "            neighbor_clusters_ids_head, neighbor_clusters_ids_tail = neighbor_clusters_ids\n",
    "            parent_ids_head, parent_ids_tail = parent_ids\n",
    "            \n",
    "            # positive relation embeddings,     size: (batch_size, 1, hidden_dim)\n",
    "            relation = torch.index_select(self.relation_embedding, dim=0, index=sample[:, 1]).unsqueeze(1)\n",
    "            # positive head embeddings,         size: (batch_size, 1, hidden_dim)\n",
    "            head_init = torch.index_select(self.entity_embedding_init, dim=0, index=sample[:, 0]).unsqueeze(1)\n",
    "            # positive tail embeddings,         size: (batch_size, 1, hidden_dim)\n",
    "            tail_init = torch.index_select(self.entity_embedding_init, dim=0, index=sample[:, 2]).unsqueeze(1)\n",
    "            # positive head text embeddings,    size: (batch_size, 1, hidden_dim)\n",
    "            head_text = torch.index_select(self.entity_text_embeddings, dim=0, index=sample[:, 0]).unsqueeze(1)\n",
    "            # positive tail text embeddings,    size: (batch_size, 1, hidden_dim)\n",
    "            tail_text = torch.index_select(self.entity_text_embeddings, dim=0, index=sample[:, 2]).unsqueeze(1)\n",
    "            # positive head cluster embeddings, size: (batch_size, 1, hidden_dim)\n",
    "            cluster_emb_head = torch.index_select(self.cluster_embeddings, dim=0, index=self_cluster_ids_head).unsqueeze(1)\n",
    "            # positive tail cluster embeddings, size: (batch_size, 1, hidden_dim)\n",
    "            cluster_emb_tail = torch.index_select(self.cluster_embeddings, dim=0, index=self_cluster_ids_tail).unsqueeze(1)\n",
    "            \n",
    "            # Example usage in the model's forward function\n",
    "            neighbor_clusters_emb_head = self.get_masked_embeddings(\n",
    "                neighbor_clusters_ids_head, self.cluster_embeddings,\n",
    "                (neighbor_clusters_ids_head.size(0), neighbor_clusters_ids_head.size(1), self.hidden_dim)\n",
    "            )\n",
    "\n",
    "            neighbor_clusters_emb_tail = self.get_masked_embeddings(\n",
    "                neighbor_clusters_ids_tail, self.cluster_embeddings,\n",
    "                (neighbor_clusters_ids_tail.size(0), neighbor_clusters_ids_tail.size(1), self.hidden_dim)\n",
    "            )\n",
    "\n",
    "            parent_emb_head = self.get_masked_embeddings(\n",
    "                parent_ids_head, self.cluster_embeddings,\n",
    "                (parent_ids_head.size(0), parent_ids_head.size(1), self.hidden_dim)\n",
    "            )\n",
    "\n",
    "            parent_emb_tail = self.get_masked_embeddings(\n",
    "                parent_ids_tail, self.cluster_embeddings,\n",
    "                (parent_ids_tail.size(0), parent_ids_tail.size(1), self.hidden_dim)\n",
    "            )\n",
    "            \n",
    "            # Combine entity embeddings with text embeddings and randomly initialized component, size: (batch_size, 1, hidden_dim)\n",
    "            head_combined           =   self.rho * head_init + (1 - self.rho) * head_text\n",
    "            tail_combined           =   self.rho * tail_init + (1 - self.rho) * tail_text\n",
    "            \n",
    "            # Text Embedding Deviation,         (lower -> better),     size: (batch_size, 1)\n",
    "            text_dist               =   self.distance(head_combined, head_text  ) + self.distance(tail_combined, tail_text  )\n",
    "\n",
    "            # Cluster Cohesion,                 (lower -> better),     size: (batch_size, 1)\n",
    "            self_cluster_dist       =   self.distance(head_combined, cluster_emb_head) + self.distance(tail_combined, cluster_emb_tail)\n",
    "            \n",
    "            # Inter-level Cluster Separation,   (higher -> better),     size: (batch_size, neibor_cluster_size)\n",
    "            neighbor_cluster_dist   =   self.distance(head_combined, neighbor_clusters_emb_head) + self.distance(tail_combined, neighbor_clusters_emb_tail)\n",
    "            \n",
    "            #Hierarchical Distance Maintenance, (higher -> better),     size: (batch_size, max_parent_num)\n",
    "            hier_dist = 0\n",
    "            for i in range(len(parent_emb_head)-1):\n",
    "                parent_embedding, parent_parent_embedding = parent_emb_head[i], parent_emb_head[i+1]\n",
    "                hier_dist           +=   (self.distance(head_combined, parent_parent_embedding) - self.distance(head_combined, parent_embedding)) / len(parent_emb_head)\n",
    "                \n",
    "            for i in range(len(parent_emb_tail)-1):\n",
    "                parent_embedding, parent_parent_embedding = parent_emb_tail[i], parent_emb_tail[i+1]\n",
    "                hier_dist           +=   (self.distance(tail_combined, parent_parent_embedding) - self.distance(tail_combined, parent_embedding)) / len(parent_emb_tail)\n",
    "                \n",
    "            # KGE Score (positive),               (lower -> better),     size: (batch_size, 1)\n",
    "            link_pred_score         =   self.score_func(head_combined, relation, tail_combined, mode)\n",
    "                \n",
    "            \n",
    "            \n",
    "        elif mode == 'head-batch':\n",
    "            tail_part, head_part = sample\n",
    "            batch_size, negative_sample_size = head_part.size(0), head_part.size(1)\n",
    "            \n",
    "            assert torch.all(head_part < self.nentity), \"head_part contains out-of-bounds indices\"\n",
    "            assert torch.all(tail_part < self.nentity), \"tail_part contains out-of-bounds indices\"\n",
    "            assert torch.all(neighbor_clusters_ids < len(self.cluster_embeddings)), \"neighbor_clusters_ids contains out-of-bounds indices\"\n",
    "            assert torch.all(parent_ids < len(self.cluster_embeddings)), \"parent_ids contains out-of-bounds indices\"\n",
    "            \n",
    "            # positive relation embeddings,     size: (batch_size, 1, hidden_dim)\n",
    "            relation  = torch.index_select(self.relation_embedding, dim=0, index=tail_part[:, 1]).unsqueeze(1)\n",
    "            print(f\"Size of relation: {relation.size()}\")\n",
    "            # positive tail embeddings,         size: (batch_size, 1, hidden_dim)\n",
    "            tail_init = torch.index_select(self.entity_embedding_init, dim=0, index=tail_part[:, 2]).unsqueeze(1)\n",
    "            print(f\"Size of tail_init: {tail_init.size()}\")\n",
    "            # negative head embeddings,         size: (batch_size, negative_sample_size, hidden_dim)\n",
    "            head_init = torch.index_select(self.entity_embedding_init, dim=0, index=head_part.view(-1)).view(batch_size, negative_sample_size, -1)\n",
    "            print(f\"Size of head_init: {head_init.size()}\")\n",
    "            # positive tail text embeddings,    size: (batch_size, 1, hidden_dim)\n",
    "            tail_text = torch.index_select(self.entity_text_embeddings, dim=0, index=tail_part[:, 2]).unsqueeze(1)\n",
    "            print(f\"Size of tail_text: {tail_text.size()}\")\n",
    "            # negative head text embeddings,    size: (batch_size, negative_sample_size, hidden_dim)\n",
    "            head_text = torch.index_select(self.entity_text_embeddings, dim=0, index=head_part.view(-1)).view(batch_size, negative_sample_size, -1)\n",
    "            print(f\"Size of head_text: {head_text.size()}\")\n",
    "            # positive tail cluster embeddings, size: (batch_size, 1, hidden_dim)\n",
    "            cluster_emb = torch.index_select(self.cluster_embeddings, dim=0, index=self_cluster_ids).unsqueeze(1)\n",
    "            print(f\"Size of cluster_emb: {cluster_emb.size()}\")\n",
    "            # positive other cluster embeddings, size: (batch_size, max_num_neighbor_clusters, hidden_dim)\n",
    "            neighbor_cluster_emb = self.get_masked_embeddings(\n",
    "                neighbor_clusters_ids, self.cluster_embeddings,\n",
    "                (neighbor_clusters_ids.size(0), neighbor_clusters_ids.size(1), self.hidden_dim)\n",
    "            )\n",
    "            print(f\"Size of neighbor_cluster_emb: {neighbor_cluster_emb.size()}\")\n",
    "            # positive parent embeddings, size: (batch_size, max_parent_num, hidden_dim)\n",
    "            parent_emb = self.get_masked_embeddings(\n",
    "                parent_ids, self.cluster_embeddings,\n",
    "                (parent_ids.size(0), parent_ids.size(1), self.hidden_dim)\n",
    "            )\n",
    "            \n",
    "            # positive tail embeddings,         size: (batch_size, 1, hidden_dim)\n",
    "            tail_combined           =   self.rho * tail_init + (1 - self.rho) * tail_text\n",
    "            print(f\"Size of tail_combined: {tail_combined.size()}\")\n",
    "            # # negative head embeddings,         size: (batch_size, negative_sample_size, hidden_dim)\n",
    "            head_combined           =   self.rho * head_init + (1 - self.rho) * head_text\n",
    "            print(f\"Size of head_combined: {head_combined.size()}\")\n",
    "            \n",
    "            # Text Embedding Deviation,         (lower -> better),      size: (batch_size, 1)\n",
    "            text_dist               =   self.distance(tail_combined, tail_text  )\n",
    "\n",
    "            # Cluster Cohesion,                 (lower -> better),      size: (batch_size, 1)\n",
    "            self_cluster_dist       =   self.distance(tail_combined, cluster_emb)\n",
    "            \n",
    "            # Inter-level Cluster Separation,   (higher -> better),     size: (batch_size, neibor_cluster_size)\n",
    "            neighbor_cluster_dist   =   self.distance(tail_combined, neighbor_cluster_emb)\n",
    "            \n",
    "            #Hierarchical Distance Maintenance, (higher -> better),     size: (batch_size, max_parent_num)\n",
    "            hier_dist = 0\n",
    "            for i in range(len(parent_emb)-1):\n",
    "                parent_embedding, parent_parent_embedding = parent_emb[i], parent_emb[i+1]\n",
    "                hier_dist           +=  self.distance(tail_combined, parent_parent_embedding) - self.distance(tail_combined, parent_embedding)\n",
    "                \n",
    "            # KGE Score (negative heads),       (lower -> better),      size: (batch_size, negative_sample_size)\n",
    "            link_pred_score         =   self.score_func(head_combined, relation, tail_combined, mode)\n",
    "            \n",
    "            \n",
    "            \n",
    "        elif mode == 'tail-batch':\n",
    "            head_part, tail_part = sample\n",
    "            batch_size, negative_sample_size = tail_part.size(0), tail_part.size(1)\n",
    "            \n",
    "            assert torch.all(head_part < self.nentity), \"head_part contains out-of-bounds indices\"\n",
    "            assert torch.all(tail_part < self.nentity), \"tail_part contains out-of-bounds indices\"\n",
    "            assert torch.all(neighbor_clusters_ids < len(self.cluster_embeddings)), \"neighbor_clusters_ids contains out-of-bounds indices\"\n",
    "            assert torch.all(parent_ids < len(self.cluster_embeddings)), \"parent_ids contains out-of-bounds indices\"\n",
    "            \n",
    "            # positive relation embeddings,     size: (batch_size, 1, hidden_dim)\n",
    "            relation  = torch.index_select(self.relation_embedding, dim=0, index=head_part[:, 1]).unsqueeze(1)\n",
    "            print(f\"Size of relation: {relation.size()}\")\n",
    "            # positive head embeddings,         size: (batch_size, 1, hidden_dim)\n",
    "            head_init = torch.index_select(self.entity_embedding_init, dim=0, index=head_part[:, 0]).unsqueeze(1)\n",
    "            print(f\"Size of head_init: {head_init.size()}\")\n",
    "            # negative tail embeddings,         size: (batch_size, negative_sample_size, hidden_dim)\n",
    "            tail_init = torch.index_select(self.entity_embedding_init, dim=0, index=tail_part.view(-1)).view(batch_size, negative_sample_size, -1)\n",
    "            print(f\"Size of tail_init: {tail_init.size()}\")\n",
    "            # positive head text embeddings,    size: (batch_size, 1, hidden_dim)\n",
    "            head_text = torch.index_select(self.entity_text_embeddings, dim=0, index=head_part[:, 0]).unsqueeze(1)\n",
    "            print(f\"Size of head_text: {head_text.size()}\")\n",
    "            # negative tail text embeddings,    size: (batch_size, negative_sample_size, hidden_dim)\n",
    "            tail_text = torch.index_select(self.entity_text_embeddings, dim=0, index=tail_part.view(-1)).view(batch_size, negative_sample_size, -1)\n",
    "            print(f\"Size of tail_text: {tail_text.size()}\")\n",
    "            # positive head cluster embeddings, size: (batch_size, 1, hidden_dim)\n",
    "            cluster_emb = torch.index_select(self.cluster_embeddings, dim=0, index=self_cluster_ids).unsqueeze(1)\n",
    "            print(f\"Size of cluster_emb: {cluster_emb.size()}\")\n",
    "            # positive other cluster embeddings, size: (batch_size, max_num_neighbor_clusters, hidden_dim)\n",
    "            neighbor_cluster_emb = self.get_masked_embeddings(\n",
    "                neighbor_clusters_ids, self.cluster_embeddings,\n",
    "                (neighbor_clusters_ids.size(0), neighbor_clusters_ids.size(1), self.hidden_dim)\n",
    "            )\n",
    "            print(f\"Size of neighbor_cluster_emb: {neighbor_cluster_emb.size()}\")\n",
    "            # positive parent embeddings, size: (batch_size, max_parent_num, hidden_dim)\n",
    "            parent_emb = self.get_masked_embeddings(\n",
    "                parent_ids, self.cluster_embeddings,\n",
    "                (parent_ids.size(0), parent_ids.size(1), self.hidden_dim)\n",
    "            )\n",
    "            print(f\"Size of parent_emb: {parent_emb.size()}\")\n",
    "            \n",
    "            # positive head embeddings,        size: (batch_size, 1, hidden_dim)\n",
    "            head_combined = self.rho * head_init + (1 - self.rho) * head_text \n",
    "            print(f\"Size of head_combined: {head_combined.size()}\")\n",
    "            # negative tail embeddings,       size: (batch_size, negative_sample_size, hidden_dim)\n",
    "            tail_combined = self.rho * tail_init + (1 - self.rho) * tail_text \n",
    "            print(f\"Size of tail_combined: {tail_combined.size()}\")\n",
    "            \n",
    "            # Text Embedding Deviation,         (lower -> better),      size: (batch_size, 1)\n",
    "            text_dist               =   self.distance(head_combined, head_text  )\n",
    "            \n",
    "            # Cluster Cohesion,                 (lower -> better),      size: (batch_size, 1)\n",
    "            self_cluster_dist       =   self.distance(head_combined, cluster_emb)\n",
    "            \n",
    "            # Inter-level Cluster Separation,   (higher -> better),     size: (batch_size, neibor_cluster_size)\n",
    "            neighbor_cluster_dist   =   self.distance(head_combined, neighbor_cluster_emb)\n",
    "            \n",
    "            #Hierarchical Distance Maintenance, (higher -> better),     size: (batch_size, max_parent_num)\n",
    "            hier_dist = 0\n",
    "            for i in range(len(parent_emb)-1):\n",
    "                parent_embedding, parent_parent_embedding = parent_emb[i], parent_emb[i+1]\n",
    "                hier_dist           +=   self.distance(head_combined, parent_parent_embedding) - self.distance(head_combined, parent_embedding)\n",
    "            \n",
    "            # KGE Score (negative tails),       (lower -> better),      size: (batch_size, negative_sample_size)\n",
    "            link_pred_score         =   self.score_func(head_combined, relation, tail_combined, mode)\n",
    "            \n",
    "        \n",
    "        else:\n",
    "            raise ValueError('mode %s not supported' % mode)\n",
    "        \n",
    "        \n",
    "        return text_dist, self_cluster_dist, neighbor_cluster_dist, hier_dist, link_pred_score \n",
    "\n",
    "    def distance(self, embeddings1, embeddings2, metric='cosine'):\n",
    "        \"\"\"\n",
    "        Compute the distance between two sets of embeddings along the last dimension.\n",
    "\n",
    "        Args:\n",
    "            embeddings1: embedding tensor, combined element-wise (with broadcasting)\n",
    "                with embeddings2.\n",
    "            embeddings2: embedding tensor whose last dimension matches embeddings1.\n",
    "            metric: 'euclidean' (L2 norm of the difference) or 'cosine'\n",
    "                (1 - cosine similarity). Defaults to 'cosine'.\n",
    "\n",
    "        Returns:\n",
    "            Tensor of distances with the last dimension reduced away.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if metric is neither 'euclidean' nor 'cosine'.\n",
    "        \"\"\"\n",
    "        if metric == 'euclidean':\n",
    "            return torch.norm(embeddings1 - embeddings2, p=2, dim=-1)\n",
    "        if metric == 'cosine':\n",
    "            embeddings1_norm = F.normalize(embeddings1, p=2, dim=-1)\n",
    "            embeddings2_norm = F.normalize(embeddings2, p=2, dim=-1)\n",
    "            cosine_similarity = torch.sum(embeddings1_norm * embeddings2_norm, dim=-1)\n",
    "            return 1 - cosine_similarity\n",
    "        # Fail loudly instead of returning None into downstream loss arithmetic.\n",
    "        raise ValueError('metric %s not supported' % metric)\n",
    "\n",
    "    def score_func(self, head, relation, tail, mode='single'):\n",
    "        \"\"\"\n",
    "        Score triples with the KGE function selected by ``self.model_name``.\n",
    "\n",
    "        head, relation, tail and mode are forwarded unchanged to one of\n",
    "        TransE, DistMult, ComplEx, RotatE or pRotatE, and that method's\n",
    "        result is returned as-is.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if ``self.model_name`` is not one of those five names.\n",
    "        \"\"\"\n",
    "        scorers = {\n",
    "            'TransE': self.TransE,\n",
    "            'DistMult': self.DistMult,\n",
    "            'ComplEx': self.ComplEx,\n",
    "            'RotatE': self.RotatE,\n",
    "            'pRotatE': self.pRotatE,\n",
    "        }\n",
    "        scorer = scorers.get(self.model_name)\n",
    "        if scorer is None:\n",
    "            raise ValueError('model %s not supported' % self.model_name)\n",
    "        return scorer(head, relation, tail, mode)\n",
    "\n",
    "    def TransE(self, head, relation, tail, mode):\n",
    "        \"\"\"\n",
    "        TransE score: gamma minus the L1 norm (over dim 2) of head + relation - tail.\n",
    "\n",
    "        ``mode`` only decides how the sum is grouped: 'head-batch' evaluates\n",
    "        relation - tail first, every other mode evaluates head + relation first.\n",
    "        \"\"\"\n",
    "        if mode != 'head-batch':\n",
    "            translation_error = (head + relation) - tail\n",
    "        else:\n",
    "            translation_error = head + (relation - tail)\n",
    "        return self.gamma.item() - torch.norm(translation_error, p=1, dim=2)\n",
    "\n",
    "    def DistMult(self, head, relation, tail, mode):\n",
    "        \"\"\"\n",
    "        DistMult score: the trilinear product head * relation * tail summed over dim 2.\n",
    "\n",
    "        ``mode`` only decides which pair is multiplied first ('head-batch'\n",
    "        multiplies relation and tail first, other modes head and relation).\n",
    "        \"\"\"\n",
    "        trilinear = head * (relation * tail) if mode == 'head-batch' else (head * relation) * tail\n",
    "        return torch.sum(trilinear, dim=2)\n",
    "\n",
    "    def ComplEx(self, head, relation, tail, mode):\n",
    "        \"\"\"\n",
    "        ComplEx score: Re(<head, relation, conj(tail)>) summed over dim 2.\n",
    "\n",
    "        Each tensor is split in half along dim 2 into real and imaginary parts;\n",
    "        ``mode`` only changes which two factors are combined first\n",
    "        ('head-batch' combines relation with tail, other modes head with relation).\n",
    "        \"\"\"\n",
    "        h_re, h_im = torch.chunk(head, 2, dim=2)\n",
    "        r_re, r_im = torch.chunk(relation, 2, dim=2)\n",
    "        t_re, t_im = torch.chunk(tail, 2, dim=2)\n",
    "\n",
    "        if mode != 'head-batch':\n",
    "            hr_re = h_re * r_re - h_im * r_im\n",
    "            hr_im = h_re * r_im + h_im * r_re\n",
    "            per_dim = hr_re * t_re + hr_im * t_im\n",
    "        else:\n",
    "            rt_re = r_re * t_re + r_im * t_im\n",
    "            rt_im = r_re * t_im - r_im * t_re\n",
    "            per_dim = h_re * rt_re + h_im * rt_im\n",
    "\n",
    "        return torch.sum(per_dim, dim=2)\n",
    "\n",
    "    def RotatE(self, head, relation, tail, mode):\n",
    "        \"\"\"\n",
    "        RotatE score: gamma minus the summed modulus of the rotation residual.\n",
    "\n",
    "        head and tail are split in half along dim 2 into real/imaginary parts;\n",
    "        relation values become phases (divided by ``embedding_range / pi``) and\n",
    "        act as unit-modulus rotations. 'head-batch' rotates the tail by the\n",
    "        conjugate phase and compares it with the head; any other mode rotates\n",
    "        the head and compares it with the tail.\n",
    "        \"\"\"\n",
    "        pi = 3.14159265358979323846\n",
    "\n",
    "        h_re, h_im = torch.chunk(head, 2, dim=2)\n",
    "        t_re, t_im = torch.chunk(tail, 2, dim=2)\n",
    "\n",
    "        # Presumably relation values lie in [-embedding_range, embedding_range],\n",
    "        # so this maps their phases onto [-pi, pi].\n",
    "        phase = relation / (self.embedding_range.item() / pi)\n",
    "        rot_re, rot_im = torch.cos(phase), torch.sin(phase)\n",
    "\n",
    "        if mode == 'head-batch':\n",
    "            diff_re = (rot_re * t_re + rot_im * t_im) - h_re\n",
    "            diff_im = (rot_re * t_im - rot_im * t_re) - h_im\n",
    "        else:\n",
    "            diff_re = (h_re * rot_re - h_im * rot_im) - t_re\n",
    "            diff_im = (h_re * rot_im + h_im * rot_re) - t_im\n",
    "\n",
    "        residual_norm = torch.stack([diff_re, diff_im], dim=0).norm(dim=0)\n",
    "        return self.gamma.item() - residual_norm.sum(dim=2)\n",
    "\n",
    "    def pRotatE(self, head, relation, tail, mode):\n",
    "        \"\"\"\n",
    "        pRotatE score: gamma - modulus * sum over dim 2 of |sin(phase residual)|.\n",
    "\n",
    "        Head, relation and tail values are all converted to phases by dividing\n",
    "        by ``embedding_range / pi``; ``mode`` only changes how the residual\n",
    "        phase_head + phase_relation - phase_tail is grouped.\n",
    "        \"\"\"\n",
    "        # Same constant as RotatE. The previous literal 3.14159262358979323846 had a\n",
    "        # wrong digit, which slightly mis-scaled every phase.\n",
    "        pi = 3.14159265358979323846\n",
    "\n",
    "        # Presumably embeddings lie in [-embedding_range, embedding_range], so the\n",
    "        # resulting phases land in [-pi, pi].\n",
    "        phase_scale = self.embedding_range.item() / pi\n",
    "        phase_head = head / phase_scale\n",
    "        phase_relation = relation / phase_scale\n",
    "        phase_tail = tail / phase_scale\n",
    "\n",
    "        if mode == 'head-batch':\n",
    "            score = phase_head + (phase_relation - phase_tail)\n",
    "        else:\n",
    "            score = (phase_head + phase_relation) - phase_tail\n",
    "\n",
    "        score = torch.abs(torch.sin(score))\n",
    "\n",
    "        score = self.gamma.item() - score.sum(dim = 2) * self.modulus\n",
    "        return score\n",
    "    \n",
    "    \n",
    "###### KG-FIT Model ######\n",
    "# rho mixes the initial entity embeddings with the text embeddings; lambda_1..3 weight the\n",
    "# cluster-cohesion / neighbour-separation / hierarchy terms, and zeta_1..3 weight the\n",
    "# hierarchy block, the text-deviation term and the link-prediction loss in the training step.\n",
    "model = KGFIT(\n",
    "    base_model=args.model,\n",
    "    nentity=nentity, nrelation=nrelation,\n",
    "    hidden_dim=args.hidden_dim, gamma=args.gamma,\n",
    "    double_entity_embedding=args.double_entity_embedding,\n",
    "    double_relation_embedding=args.double_relation_embedding,\n",
    "    entity_text_embeddings=entity_text_embeddings,\n",
    "    cluster_embeddings=cluster_embeddings,\n",
    "    rho=args.rho,\n",
    "    lambda_1=args.lambda_1, lambda_2=args.lambda_2, lambda_3=args.lambda_3,\n",
    "    zeta_1=args.zeta_1, zeta_2=args.zeta_2, zeta_3=args.zeta_3,\n",
    ")\n",
    "##########################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam over the trainable parameters only (tensors with requires_grad=False are skipped).\n",
    "current_learning_rate = args.learning_rate\n",
    "optimizer = torch.optim.Adam(\n",
    "    [param for param in model.parameters() if param.requires_grad],\n",
    "    lr=current_learning_rate,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): overrides the parsed --cuda setting, so the training-step cell below\n",
    "# skips moving its batch tensors to the GPU.\n",
    "args.cuda = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Size of relation: torch.Size([1024, 1, 500])\n",
      "Size of head_init: torch.Size([1024, 1, 500])\n",
      "Size of tail_init: torch.Size([1024, 128, 500])\n",
      "Size of head_text: torch.Size([1024, 1, 500])\n",
      "Size of tail_text: torch.Size([1024, 128, 500])\n",
      "Size of cluster_emb: torch.Size([1024, 1, 500])\n",
      "Size of neighbor_cluster_emb: torch.Size([1024, 5, 500])\n",
      "Size of parent_emb: torch.Size([1024, 45, 500])\n",
      "Size of head_combined: torch.Size([1024, 1, 500])\n",
      "Size of tail_combined: torch.Size([1024, 128, 500])\n",
      "Loss: 302.8885498046875\n"
     ]
    },
    {
     "ename": "",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
      "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
      "\u001b[1;31mClick <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. \n",
      "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
     ]
    }
   ],
   "source": [
    "# One manual KG-FIT training step: score a negative batch and its positive triples,\n",
    "# combine the KGE loss with the hierarchy/text terms, and take one optimizer step.\n",
    "model.train()\n",
    "\n",
    "optimizer.zero_grad()\n",
    "\n",
    "positive_sample, negative_sample, subsampling_weight, cluster_id_head, cluster_id_tail, \\\n",
    "    neighbor_clusters_ids_head, neighbor_clusters_ids_tail, parent_ids_head, parent_ids_tail, mode = next(train_iterator)\n",
    "\n",
    "if args.cuda:\n",
    "    positive_sample = positive_sample.cuda()\n",
    "    negative_sample = negative_sample.cuda()\n",
    "    subsampling_weight = subsampling_weight.cuda()\n",
    "    cluster_id_head = cluster_id_head.cuda()\n",
    "    cluster_id_tail = cluster_id_tail.cuda()\n",
    "    neighbor_clusters_ids_head = neighbor_clusters_ids_head.cuda()\n",
    "    neighbor_clusters_ids_tail = neighbor_clusters_ids_tail.cuda()\n",
    "    parent_ids_head = parent_ids_head.cuda()\n",
    "    parent_ids_tail = parent_ids_tail.cuda()\n",
    "\n",
    "## Negative Samples\n",
    "# Negatives replace one side of the triple, so the hierarchy terms use the side that is\n",
    "# kept: the tail for 'head-batch', the head for 'tail-batch'.\n",
    "if mode == 'head-batch':\n",
    "    self_cluster_ids = cluster_id_tail\n",
    "    neighbor_clusters_ids = neighbor_clusters_ids_tail\n",
    "    parent_ids = parent_ids_tail\n",
    "elif mode == 'tail-batch':\n",
    "    self_cluster_ids = cluster_id_head\n",
    "    neighbor_clusters_ids = neighbor_clusters_ids_head\n",
    "    parent_ids = parent_ids_head\n",
    "else:\n",
    "    # Fail loudly: otherwise ids left over from an earlier run of this cell (e.g. the\n",
    "    # positive-sample tuples assigned below) would be passed to the model silently.\n",
    "    raise ValueError('mode %s not supported' % mode)\n",
    "\n",
    "text_dist_n, self_cluster_dist_n, neighbor_cluster_dist_n, hier_dist_n, negative_score = \\\n",
    "    model((positive_sample, negative_sample), self_cluster_ids, neighbor_clusters_ids, parent_ids, mode=mode)\n",
    "\n",
    "# NOTE(review): the model's hier_dist loop runs over range(len(parent_emb) - 1), i.e. the\n",
    "# batch axis; when it runs zero times hier_dist_* is the int 0 and .mean() fails — confirm intent.\n",
    "neighbor_cluster_dist_mean_n = neighbor_cluster_dist_n.mean(dim=1, keepdim=True)\n",
    "hier_dist_mean_n = hier_dist_n.mean(dim=1, keepdim=True)\n",
    "\n",
    "if args.negative_adversarial_sampling:\n",
    "    #In self-adversarial sampling, we do not apply back-propagation on the sampling weight\n",
    "    negative_score = (F.softmax(negative_score * args.adversarial_temperature, dim = 1).detach()\n",
    "                        * F.logsigmoid(-negative_score)).sum(dim = 1)\n",
    "else:\n",
    "    negative_score = F.logsigmoid(-negative_score).mean(dim = 1)\n",
    "\n",
    "## Positive Sample\n",
    "self_cluster_ids = (cluster_id_head, cluster_id_tail)\n",
    "neighbor_clusters_ids = (neighbor_clusters_ids_head, neighbor_clusters_ids_tail)\n",
    "parent_ids = (parent_ids_head, parent_ids_tail)\n",
    "\n",
    "text_dist_p, self_cluster_dist_p, neighbor_cluster_dist_p, hier_dist_p, positive_score = \\\n",
    "    model(positive_sample, self_cluster_ids, neighbor_clusters_ids, parent_ids, mode='single')\n",
    "\n",
    "neighbor_cluster_dist_mean_p = neighbor_cluster_dist_p.mean(dim=1, keepdim=True)\n",
    "hier_dist_mean_p = hier_dist_p.mean(dim=1, keepdim=True)\n",
    "\n",
    "positive_score = F.logsigmoid(positive_score).squeeze(dim = 1)\n",
    "\n",
    "if args.uni_weight:\n",
    "    positive_sample_loss = - positive_score.mean()\n",
    "    negative_sample_loss = - negative_score.mean()\n",
    "else:\n",
    "    positive_sample_loss = - (subsampling_weight * positive_score).sum()/subsampling_weight.sum()\n",
    "    negative_sample_loss = - (subsampling_weight * negative_score).sum()/subsampling_weight.sum()\n",
    "\n",
    "## Loss function\n",
    "loss = (positive_sample_loss + negative_sample_loss)/2\n",
    "\n",
    "if args.regularization != 0.0:\n",
    "    #Use L3 regularization for ComplEx and DistMult\n",
    "    # NOTE(review): the forward pass reads entity_embedding_init, not entity_embedding;\n",
    "    # confirm KGFIT defines entity_embedding before enabling --regularization.\n",
    "    regularization = args.regularization * (\n",
    "        model.entity_embedding.norm(p = 3)**3 +\n",
    "        model.relation_embedding.norm(p = 3)**3\n",
    "    )\n",
    "    loss = loss + regularization\n",
    "    regularization_log = {'regularization': regularization.item()}\n",
    "else:\n",
    "    regularization_log = {}\n",
    "\n",
    "# Scalar KGE loss plus per-sample hierarchy (zeta_1) and text-deviation (zeta_2) terms;\n",
    "# the per-sample terms make this a (batch_size, 1) tensor before reduction.\n",
    "loss = model.zeta_3 * loss \\\n",
    "    + model.zeta_1 * (model.lambda_1 * (self_cluster_dist_n + self_cluster_dist_p) \\\n",
    "                        - model.lambda_2 * (neighbor_cluster_dist_mean_n + neighbor_cluster_dist_mean_p) \\\n",
    "                        - model.lambda_3 * (hier_dist_mean_n + hier_dist_mean_p)) \\\n",
    "    + model.zeta_2 * (text_dist_n + text_dist_p)\n",
    "\n",
    "# Average over the batch: summing would also add the broadcast scalar KGE loss once per\n",
    "# sample, making the loss scale with batch size and incomparable to the logged\n",
    "# positive/negative sample losses.\n",
    "loss = loss.mean()\n",
    "print(f\"Loss: {loss}\")\n",
    "\n",
    "loss.backward()\n",
    "\n",
    "optimizer.step()\n",
    "\n",
    "log = {\n",
    "    **regularization_log,\n",
    "    'positive_sample_loss': positive_sample_loss.item(),\n",
    "    'negative_sample_loss': negative_sample_loss.item(),\n",
    "    'loss': loss.item()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1024, 1])"
      ]
     },
     "execution_count": 84,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape of the most recently computed loss tensor.\n",
    "loss.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 86835/86835 [00:01<00:00, 57471.65it/s] \n",
      "100%|██████████| 3034/3034 [00:00<00:00, 316980.98it/s]\n",
      "100%|██████████| 3134/3134 [00:00<00:00, 310395.73it/s]\n"
     ]
    }
   ],
   "source": [
    "from utils import *\n",
    "from dataloader import *\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# Machine-specific locations: edit these two constants to run on another host.\n",
    "WN18RR_DIR = '/home/pj20/server-03/lamake/data/WN18RR'\n",
    "ENTITY_INFO_PATH = '/shared/pj20/lamake_data/WN18RR/entity_info_seed_hier.json'\n",
    "\n",
    "with open(f'{WN18RR_DIR}/entities.dict') as fin:\n",
    "    entity2id = dict()\n",
    "    for line in fin:\n",
    "        eid, entity = line.strip().split('\\t')\n",
    "        entity2id[entity] = int(eid)\n",
    "id2entity = {v: k for k, v in entity2id.items()}\n",
    "\n",
    "with open(f'{WN18RR_DIR}/relations.dict') as fin:\n",
    "    relation2id = dict()\n",
    "    for line in fin:\n",
    "        rid, relation = line.strip().split('\\t')\n",
    "        relation2id[relation] = int(rid)\n",
    "\n",
    "nentity = len(entity2id)\n",
    "nrelation = len(relation2id)\n",
    "\n",
    "train_triples = read_triple(f'{WN18RR_DIR}/train.txt', entity2id, relation2id)\n",
    "logging.info('#train: %d' % len(train_triples))\n",
    "valid_triples = read_triple(f'{WN18RR_DIR}/valid.txt', entity2id, relation2id)\n",
    "logging.info('#valid: %d' % len(valid_triples))\n",
    "test_triples  = read_triple(f'{WN18RR_DIR}/test.txt',  entity2id, relation2id)\n",
    "logging.info('#test: %d' % len(test_triples))\n",
    "entity_info_train = read_entity_info(ENTITY_INFO_PATH, train_triples, id2entity)\n",
    "entity_info_valid = read_entity_info(ENTITY_INFO_PATH, valid_triples, id2entity)\n",
    "entity_info_test = read_entity_info(ENTITY_INFO_PATH, test_triples, id2entity)\n",
    "\n",
    "#All true triples\n",
    "all_true_triples = train_triples + valid_triples + test_triples\n",
    "\n",
    "# Head-batch test loader over the test_triples loaded above (no second read from disk).\n",
    "test_dataloader_head = DataLoader(\n",
    "    TestDataset(\n",
    "        test_triples,\n",
    "        all_true_triples,\n",
    "        nentity,\n",
    "        nrelation,\n",
    "        entity_info_test,\n",
    "        'head-batch',\n",
    "        rerank=True,\n",
    "    ),\n",
    "    batch_size=4,\n",
    "    num_workers=10,\n",
    "    collate_fn=TestDataset.collate_fn\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([24250,     9, 23509])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First field of test sample 1 (a 3-id tensor, presumably (head, relation, tail)).\n",
    "# NOTE(review): this reuses the name positive_sample from the training-step cell.\n",
    "positive_sample = test_dataloader_head.dataset[1][0]\n",
    "positive_sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "dataset = test_dataloader_head.dataset\n",
    "\n",
    "# Tally, over the head-batch test set, how often sample[2] at the head id equals 3 vs 0.\n",
    "count_3 = 0\n",
    "count_0 = 0\n",
    "\n",
    "for i in tqdm(range(len(dataset))):\n",
    "    # Fetch each sample once: dataset[i] may rebuild the sample on every access.\n",
    "    sample = dataset[i]\n",
    "    # NOTE(review): assumes sample[0] is the (head, relation, tail) id triple and that\n",
    "    # sample[2] is indexed by entity id — confirm against TestDataset.__getitem__.\n",
    "    head_label = sample[2][sample[0][0]]\n",
    "    if head_label == 3:\n",
    "        count_3 += 1\n",
    "    elif head_label == 0:\n",
    "        count_0 += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2589, 545)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# (label-3 count, label-0 count) from the tally cell\n",
    "(count_3, count_0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8261008296107212"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fraction of tallied samples labelled 3, among those labelled 3 or 0.\n",
    "count_3 / (count_0 + count_3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the raw entity-info JSON, closing the file handle afterwards.\n",
    "with open('/shared/pj20/lamake_data/WN18RR/entity_info_seed_hier.json') as entity_info_file:\n",
    "    entity_info = json.load(entity_info_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1433"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of entries under 'k_hop_neighbors' for entity '01441510'.\n",
    "len(entity_info['01441510']['k_hop_neighbors'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "kgc",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
