{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "cfe40922",
   "metadata": {
    "code_folding": [
     14
    ],
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# config\n",
    "import os.path as osp\n",
    "import numpy as np\n",
    "import random\n",
    "import torch\n",
    "from easydict import EasyDict as edict\n",
    "import argparse\n",
    "\n",
    "\n",
    "import sys\n",
     "# argparse's parse_args() (called in cfg.get_args) reads sys.argv[1:];\n",
     "# resetting sys.argv makes it see no options, so every argument takes its default.\n",
     "sys.argv = ['']\n",
    "\n",
    "\n",
    "\n",
    "class cfg():\n",
    "    def __init__(self):\n",
    "#         self.this_dir = osp.dirname(__file__)\n",
    "        # change\n",
    "        self.data_root = osp.abspath(osp.join(osp.abspath(''), '..', '..', 'data', ''))\n",
    "\n",
    "        # TODO: add some static variable  (The frequency of change is low)\n",
    "\n",
    "    def get_args(self):\n",
    "        parser = argparse.ArgumentParser()\n",
    "        # base\n",
    "        parser.add_argument('--gpu', default=0, type=int)\n",
    "        parser.add_argument('--batch_size', default=3500, type=int)\n",
    "        parser.add_argument('--epoch', default=250, type=int)\n",
    "        parser.add_argument(\"--save_model\", default=0, type=int, choices=[0, 1])\n",
    "        parser.add_argument(\"--only_test\", default=0, type=int, choices=[0, 1])\n",
    "        parser.add_argument(\"--enable_sota\", action=\"store_true\", default=False)\n",
    "\n",
    "        # torthlight\n",
    "        parser.add_argument(\"--no_tensorboard\", default=False, action=\"store_true\")\n",
    "        parser.add_argument(\"--exp_name\", default=\"EA_exp\", type=str, help=\"Experiment name\")\n",
    "        parser.add_argument(\"--dump_path\", default=\"dump/\", type=str, help=\"Experiment dump path\")\n",
    "        parser.add_argument(\"--exp_id\", default=\"001\", type=str, help=\"Experiment ID\")\n",
    "        parser.add_argument(\"--random_seed\", default=42, type=int)\n",
    "        parser.add_argument(\"--data_path\", default=\"mmkg\", type=str, help=\"Experiment path\")\n",
    "\n",
    "        # --------- EA -----------\n",
    "        parser.add_argument(\"--data_choice\", default=\"DBP15K\", type=str, choices=[\"DBP15K\", \"DWY\", \"FBYG15K\", \"FBDB15K\"], help=\"Experiment path\")\n",
    "        parser.add_argument(\"--data_rate\", type=float, default=0.3, help=\"training set rate\")\n",
    "        # parser.add_argument(\"--data_rate\", type=float, default=0.3, choices=[0.2, 0.3, 0.5, 0.8], help=\"training set rate\")\n",
    "\n",
    "        # TODO: add some dynamic variable\n",
    "        parser.add_argument(\"--model_name\", default=\"EVA\", type=str, choices=[\"EVA\", \"MCLEA\", \"MSNEA\", \"MEAformer\"], help=\"model name\")\n",
    "        parser.add_argument(\"--model_name_save\", default=\"\", type=str, help=\"model name for model load\")\n",
    "\n",
    "        # 训练阶段\n",
    "        parser.add_argument('--workers', type=int, default=8)\n",
    "        parser.add_argument('--accumulation_steps', type=int, default=1)\n",
    "        parser.add_argument(\"--scheduler\", default=\"fixed\", type=str, choices=[\"linear\", \"cos\", \"fixed\"])\n",
    "        parser.add_argument(\"--optim\", default=\"adam\", type=str, choices=[\"adamw\", \"adam\"])\n",
    "        parser.add_argument('--lr', type=float, default=5e-4)\n",
    "        parser.add_argument('--weight_decay', type=float, default=0.0001)\n",
    "        parser.add_argument(\"--adam_epsilon\", default=1e-8, type=float)\n",
    "        parser.add_argument('--eval_epoch', default=2, type=int, help='evaluate each n epoch')\n",
    "\n",
    "        # 可选\n",
    "        parser.add_argument('--margin', default=1, type=float, help='The fixed margin in loss function. ')\n",
    "        parser.add_argument('--emb_dim', default=1000, type=int, help='The embedding dimension in KGE model.')\n",
    "        parser.add_argument('--adv_temp', default=1.0, type=float, help='The temperature of sampling in self-adversarial negative sampling.')\n",
    "        parser.add_argument(\"--contrastive_loss\", default=0, type=int, choices=[0, 1])\n",
    "        parser.add_argument('--clip', type=float, default=1., help='gradient clipping')\n",
    "\n",
    "        # --------- EVA -----------\n",
    "        parser.add_argument(\"--data_split\", default=\"fr_en\", type=str, help=\"Experiment split\", choices=[\"dbp_wd_15k_V2\", \"dbp_wd_15k_V1\", \"zh_en\", \"ja_en\", \"fr_en\", \"norm\"])\n",
    "        parser.add_argument(\"--hidden_units\", type=str, default=\"300,300,300\", help=\"hidden units in each hidden layer(including in_dim and out_dim), splitted with comma\")\n",
    "        parser.add_argument(\"--dropout\", type=float, default=0.0, help=\"dropout rate for layers\")\n",
    "        parser.add_argument(\"--attn_dropout\", type=float, default=0.0, help=\"dropout rate for gat layers\")\n",
    "        parser.add_argument(\"--distance\", type=int, default=2, help=\"L1 distance or L2 distance. ('1', '2')\", choices=[1, 2])\n",
    "        parser.add_argument(\"--csls\", action=\"store_true\", default=True, help=\"use CSLS for inference\")\n",
    "        parser.add_argument(\"--csls_k\", type=int, default=3, help=\"top k for csls\")\n",
    "        parser.add_argument(\"--il\", action=\"store_true\", default=False, help=\"Iterative learning?\")\n",
    "        parser.add_argument(\"--semi_learn_step\", type=int, default=10, help=\"If IL, what's the update step?\")\n",
    "        parser.add_argument(\"--il_start\", type=int, default=500, help=\"If Il, when to start?\")\n",
    "        parser.add_argument(\"--unsup\", action=\"store_true\", default=False)\n",
    "        parser.add_argument(\"--unsup_k\", type=int, default=1000, help=\"|visual seed|\")\n",
    "\n",
    "        # --------- MCLEA -----------\n",
    "        parser.add_argument(\"--unsup_mode\", type=str, default=\"img\", help=\"unsup mode\", choices=[\"img\", \"name\", \"char\"])\n",
    "        parser.add_argument(\"--tau\", type=float, default=0.1, help=\"the temperature factor of contrastive loss\")\n",
    "        parser.add_argument(\"--tau2\", type=float, default=4., help=\"the temperature factor of alignment loss\")\n",
    "        parser.add_argument(\"--alpha\", type=float, default=0.2, help=\"the margin of InfoMaxNCE loss\")\n",
    "        parser.add_argument(\"--with_weight\", type=int, default=1, help=\"Whether to weight the fusion of different \")\n",
    "        parser.add_argument(\"--structure_encoder\", type=str, default=\"gat\", help=\"the encoder of structure view\", choices=[\"gat\", \"gcn\"])\n",
    "        parser.add_argument(\"--ab_weight\", type=float, default=0.5, help=\"the weight of NTXent Loss\")\n",
    "\n",
    "        parser.add_argument(\"--projection\", action=\"store_true\", default=False, help=\"add projection for model\")\n",
    "        parser.add_argument(\"--heads\", type=str, default=\"2,2\", help=\"heads in each gat layer, splitted with comma\")\n",
    "        parser.add_argument(\"--instance_normalization\", action=\"store_true\", default=False, help=\"enable instance normalization\")\n",
    "        # 为了避免冲突给这些dim置-1默认\n",
    "        parser.add_argument(\"--attr_dim\", type=int, default=300, help=\"the hidden size of attr and rel features\")\n",
    "        parser.add_argument(\"--img_dim\", type=int, default=300, help=\"the hidden size of img feature\")\n",
    "        parser.add_argument(\"--name_dim\", type=int, default=300, help=\"the hidden size of name feature\")\n",
    "        parser.add_argument(\"--char_dim\", type=int, default=300, help=\"the hidden size of char feature\")\n",
    "\n",
    "        parser.add_argument(\"--w_gcn\", action=\"store_false\", default=True, help=\"with gcn features\")\n",
    "        parser.add_argument(\"--w_rel\", action=\"store_false\", default=True, help=\"with rel features\")\n",
    "        parser.add_argument(\"--w_attr\", action=\"store_false\", default=True, help=\"with attr features\")\n",
    "        parser.add_argument(\"--w_name\", action=\"store_false\", default=True, help=\"with name features\")\n",
    "        parser.add_argument(\"--w_char\", action=\"store_false\", default=True, help=\"with char features\")\n",
    "        parser.add_argument(\"--w_img\", action=\"store_false\", default=True, help=\"with img features\")\n",
    "        parser.add_argument(\"--use_surface\", type=int, default=0, help=\"whether to use the surface\")\n",
    "\n",
    "        parser.add_argument(\"--inner_view_num\", type=int, default=6, help=\"the number of inner view\")\n",
    "        parser.add_argument(\"--word_embedding\", type=str, default=\"glove\", help=\"the type of word embedding, [glove|fasttext]\", choices=[\"glove\", \"bert\"])\n",
    "        # projection head\n",
    "        parser.add_argument(\"--use_project_head\", action=\"store_true\", default=False, help=\"use projection head\")\n",
    "        parser.add_argument(\"--zoom\", type=float, default=0.1, help=\"narrow the range of losses\")\n",
    "        parser.add_argument(\"--reduction\", type=str, default=\"mean\", help=\"[sum|mean]\", choices=[\"sum\", \"mean\"])\n",
    "\n",
    "        # --------- MEAformer -----------\n",
    "        parser.add_argument(\"--hidden_size\", type=int, default=300, help=\"the hidden size of MEAformer\")\n",
    "        parser.add_argument(\"--intermediate_size\", type=int, default=400, help=\"the hidden size of MEAformer\")\n",
    "        parser.add_argument(\"--num_attention_heads\", type=int, default=1, help=\"the number of attention_heads of MEAformer\")\n",
    "        parser.add_argument(\"--num_hidden_layers\", type=int, default=1, help=\"the number of hidden_layers of MEAformer\")\n",
    "        parser.add_argument(\"--position_embedding_type\", default=\"absolute\", type=str)\n",
    "        parser.add_argument(\"--use_intermediate\", type=int, default=1, help=\"whether to use_intermediate\")\n",
    "        parser.add_argument(\"--replay\", type=int, default=0, help=\"whether to use replay strategy\")\n",
    "        parser.add_argument(\"--neg_cross_kg\", type=int, default=0, help=\"whether to force the negative samples in the opposite KG\")\n",
    "        parser.add_argument(\"--ratio\", type=str, default=\"1.0\", help=\"which visual adapt\", choices=[\"0.05\", \"0.1\", \"0.15\", \"0.2\", \"0.3\", \"0.4\",\n",
    "                                                                                                    \"0.45\", \"0.5\", \"0.55\", \"0.6\", \"0.7\", \"0.75\", \"0.8\", \"0.9\", \"1.0\"])\n",
    "\n",
    "        # --------- MSNEA -----------\n",
    "        parser.add_argument(\"--dim\", type=int, default=100, help=\"the hidden size of MSNEA\")\n",
    "        parser.add_argument(\"--neg_triple_num\", type=int, default=1, help=\"neg triple num\")\n",
    "        # 是否使用bert\n",
    "        parser.add_argument(\"--use_bert\", type=int, default=0)\n",
    "        # 是否使用属性值\n",
    "        parser.add_argument(\"--use_attr_value\", type=int, default=0)\n",
    "        # parser.add_argument(\"--learning_rate\", type=int, default=0.001)\n",
    "        # parser.add_argument(\"--optimizer\", type=str, default=\"Adam\")\n",
    "        # parser.add_argument(\"--max_epoch\", type=int, default=200)\n",
    "\n",
    "        # parser.add_argument(\"--save_path\", type=str, default=\"save_pkl\", help=\"save path\")\n",
    "\n",
    "\n",
    "\n",
    "        # --------- GEEA -----------\n",
    "        parser.add_argument(\"--num_layers\", type=int, default=3, help=\"number of layers for each sub-VAE\")\n",
    "        parser.add_argument(\"--es\", action=\"store_true\", default=False, help=\"process the datasets for entity synthesis\")\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "        # ------------ 并行训练 ------------\n",
    "        # 是否并行\n",
    "        parser.add_argument('--rank', type=int, default=0, help='rank to dist')\n",
    "        parser.add_argument('--dist', type=int, default=0, help='whether to dist')\n",
    "        # 不要改该参数，系统会自动分配\n",
    "        parser.add_argument('--device', default='cuda', help='device id (i.e. 0 or 0,1 or cpu)')\n",
    "        # 开启的进程数(注意不是线程),不用设置该参数，会根据nproc_per_node自动设置\n",
    "        parser.add_argument('--world-size', default=3, type=int,\n",
    "                            help='number of distributed processes')\n",
    "        parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')\n",
    "        parser.add_argument(\"--local_rank\", default=-1, type=int)\n",
    "\n",
    "        self.cfg = parser.parse_args()\n",
    "\n",
    "    def update_train_configs(self):\n",
    "        # add some constraint for parameters\n",
    "        # e.g. cannot save and test at the same time\n",
    "        assert not (self.cfg.save_model and self.cfg.only_test)\n",
    "\n",
    "        # TODO: update some dynamic variable\n",
    "        self.cfg.data_root = self.data_root\n",
    "\n",
    "        if self.cfg.use_surface:\n",
    "            self.cfg.w_name = True\n",
    "            self.cfg.w_char = True\n",
    "        else:\n",
    "            self.cfg.w_name = False\n",
    "            self.cfg.w_char = False\n",
    "\n",
    "        if self.cfg.data_choice in [\"FBYG15K\", \"FBDB15K\"]:\n",
    "            # 不使用intermediate\n",
    "            self.cfg.use_intermediate = 0\n",
    "            self.cfg.data_split = \"norm\"\n",
    "            self.cfg.inner_view_num = 4\n",
    "            # assert self.cfg.data_rate in [0.2, 0.5, 0.8]\n",
    "            self.cfg.w_name = False\n",
    "            self.cfg.w_char = False\n",
    "            # 不能使用文本信息\n",
    "            self.cfg.use_surface = 0\n",
    "            data_split_name = f\"{self.cfg.data_rate}_\"\n",
    "        else:\n",
    "            # DBP数据集，可能存在surface\n",
    "            data_split_name = f\"{self.cfg.data_split}_\"\n",
    "            if self.cfg.w_name and self.cfg.w_char:\n",
    "                data_split_name = f\"{data_split_name}with_surface_\"\n",
    "\n",
    "        self.cfg.exp_id = f\"{self.cfg.model_name}_{self.cfg.data_choice}_{data_split_name}{self.cfg.exp_id}\"\n",
    "        self.cfg.data_path = osp.join(self.data_root, self.cfg.data_path)\n",
    "        self.cfg.dump_path = osp.join(self.cfg.data_path, self.cfg.dump_path)\n",
    "        if self.cfg.only_test == 1:\n",
    "            self.save_model = 0\n",
    "            # 测试不需要并行\n",
    "            self.dist = 0\n",
    "\n",
    "        # 信息更多了容易拟合\n",
    "        # 跨语言的数据容易出现过拟合\n",
    "        if self.cfg.model_name not in [\"MEAformer\", \"MSNEA\", \"EVA\", \"MCLEA\"]:\n",
    "            if (self.cfg.data_choice == \"DBP15K\" and (self.cfg.w_name or self.cfg.w_char)):\n",
    "                # 六种模态\n",
    "\n",
    "                self.cfg.epoch = min(800, self.cfg.epoch)\n",
    "                self.cfg.il_start = min(500, self.cfg.il_start)\n",
    "                self.cfg.eval_epoch = min(50, self.cfg.eval_epoch)\n",
    "\n",
    "            if self.cfg.attr_dim >= 300:\n",
    "                self.cfg.epoch = min(1000, self.cfg.epoch)\n",
    "                self.cfg.il_start = min(500, self.cfg.il_start)\n",
    "                self.cfg.eval_epoch = min(50, self.cfg.eval_epoch)\n",
    "\n",
    "        # --------- MSNEA -----------\n",
    "        self.cfg.dim = self.cfg.attr_dim\n",
    "\n",
    "        # --------- MEAformer -----------\n",
    "        self.cfg.max_position_embeddings = self.cfg.inner_view_num + 1\n",
    "        assert self.cfg.hidden_size == self.cfg.attr_dim\n",
    "\n",
    "        if self.cfg.enable_sota:\n",
    "            if self.cfg.il:\n",
    "                self.cfg.eval_epoch = max(2, self.cfg.eval_epoch)\n",
    "                self.cfg.weight_decay = max(0.0005, self.cfg.weight_decay)\n",
    "                if self.cfg.data_rate > 0.5:\n",
    "                    self.cfg.weight_decay = max(0.001, self.cfg.weight_decay)\n",
    "                if self.cfg.data_choice == \"DBP15K\":\n",
    "                    if not self.cfg.use_surface:\n",
    "                        self.cfg.weight_decay = max(0.001, self.cfg.weight_decay)\n",
    "            else:\n",
    "                if self.cfg.data_choice == \"DBP15K\" or \"FBYG\" in self.cfg.data_choice:\n",
    "                    self.cfg.epoch = 250\n",
    "                else:\n",
    "                    self.cfg.epoch = 500\n",
    "        return self.cfg\n",
    "    \n",
    "cfg = cfg()\n",
    "cfg.get_args()\n",
    "\n",
    "cfg.cfg.model_name = 'MCLEA'\n",
    "\n",
    "\n",
    "cfg.cfg.data_rate = 0.8\n",
    "cfg.cfg.data_choice = 'DBP15K'\n",
    "cfg.cfg.data_split = 'zh_en'\n",
    "\n",
    "# cfg.cfg.data_rate = 0.8\n",
    "# cfg.cfg.data_choice = 'FBYG15K'\n",
    "# cfg.cfg.data_split = 'norm'\n",
    "\n",
    "cfg.cfg.num_layer=3\n",
    "\n",
    "# cfg.cfg.il = True\n",
    "# cfg.cfg.il_start = 20\n",
    "cfg.cfg.epoch = 200\n",
    "cfg.cfg.lr = 1e-3\n",
    "\n",
    "cfgs = cfg.update_train_configs()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "902e3fb4",
   "metadata": {
    "code_folding": [
     13,
     107
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from numpy import cov\n",
    "from numpy import trace\n",
    "from numpy import iscomplexobj\n",
    "from numpy.random import random\n",
    "from scipy.linalg import sqrtm\n",
    "\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from src.utils import pairwise_distances, csls_sim\n",
    "\n",
    "class MutualVAE(nn.Module):\n",
    "\n",
    "    def __init__(self, in_dim, hidden_dims, latent_dim=None, **kwargs):\n",
    "        super(MutualVAE, self).__init__()\n",
    "\n",
    "        if latent_dim:\n",
    "            self.latent_dim = latent_dim\n",
    "        else:\n",
    "            self.latent_dim = hidden_dims[-1]\n",
    "\n",
    "        modules = []\n",
    "\n",
    "        # encoder\n",
    "        for h_dim in hidden_dims:\n",
    "            modules.append(\n",
    "                nn.Sequential(\n",
    "                    nn.Linear(in_dim, h_dim),\n",
    "                    nn.LeakyReLU()\n",
    "                )\n",
    "            )\n",
    "            in_dim = h_dim\n",
    "\n",
    "        self.encoder = nn.Sequential(*modules)\n",
    "\n",
    "        self.fc_mu = nn.Linear(hidden_dims[-1], latent_dim)\n",
    "        self.fc_var = nn.Linear(hidden_dims[-1], latent_dim)\n",
    "\n",
    "        # decoder\n",
    "\n",
    "        modules = []\n",
    "\n",
    "        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1])\n",
    "\n",
    "        hidden_dims.reverse()\n",
    "\n",
    "        for i in range(len(hidden_dims) - 1):\n",
    "            modules.append(\n",
    "                nn.Sequential(\n",
    "                    nn.Linear(hidden_dims[i], hidden_dims[i+1]),\n",
    "                    nn.LeakyReLU()\n",
    "                )\n",
    "            )\n",
    "\n",
    "        self.decoder = nn.Sequential(*modules)\n",
    "\n",
    "    def encode(self, x):\n",
    "\n",
    "        x = self.encoder(x)\n",
    "\n",
    "        mu_x = self.fc_mu(x)\n",
    "        log_var_x = self.fc_var(x)\n",
    "\n",
    "        return (mu_x, log_var_x)\n",
    "\n",
    "    def decode(self, z, reparameterize=False):\n",
    "        if reparameterize:\n",
    "            z = self.reparameterize(*z)\n",
    "\n",
    "        z = self.decoder_input(z)\n",
    "        x = self.decoder(z)\n",
    "\n",
    "        return x\n",
    "\n",
    "    def reparameterize(self, mu, logvar):\n",
    "        std = torch.exp(.5 * logvar)\n",
    "        eps = torch.rand_like(std)\n",
    "        return eps * std + mu\n",
    "\n",
    "    def forward(self, embs, train_links, left_ents, right_ents):\n",
    "\n",
    "        # train_links (x,y) used for flows : x->y and y->x, supervised learning\n",
    "        # left_ents, right_ents used for flows : x->x and y->y, self-supervised learning\n",
    "\n",
    "        # flows: x->y and y->x\n",
    "        x = embs[train_links[:, 0]]\n",
    "        y = embs[train_links[:, 1]]\n",
    "\n",
    "        z_xy, z_yx = self.encode(x), self.encode(y)\n",
    "        y_xy, x_yx = self.decode(z_xy, reparameterize=True), self.decode(\n",
    "            z_yx, reparameterize=True)\n",
    "\n",
    "        # flows : x->x and y->y\n",
    "        sampled_x, sampled_y = embs[left_ents], embs[right_ents]\n",
    "        z_xx, z_yy = self.encode(sampled_x), self.encode(sampled_y)\n",
    "        x_xx, y_yy = self.decode(z_xx, reparameterize=True), self.decode(\n",
    "            z_yy, reparameterize=True)\n",
    "\n",
    "        flows = {'xx': (sampled_x, z_xx, x_xx),\n",
    "                 'yy': (sampled_y, z_yy, y_yy),\n",
    "                 'xy': (x, z_xy, y_xy),\n",
    "                 'yx': (y, z_yx, x_yx)}\n",
    "\n",
    "        return flows\n",
    "\n",
    "class NeighborDecoder(nn.Module):\n",
    "    def __init__(self, sub_dim, ent_embs) -> None:\n",
    "        super().__init__()\n",
    "\n",
    "        self.ent_embs = None\n",
    "#         self.activation = nn.Sigmoid()\n",
    "        self.subdecoder = nn.Sequential(nn.Linear(sub_dim, sub_dim),\n",
    "                                               nn.Tanh(),\n",
    "                                               nn.Dropout(0.5),\n",
    "                                               nn.BatchNorm1d(sub_dim),\n",
    "                                               nn.Linear(sub_dim, sub_dim),\n",
    "                                               nn.Tanh(),\n",
    "                                               nn.Dropout(0.5),\n",
    "                                               nn.BatchNorm1d(sub_dim),\n",
    "                                            )\n",
    "        self.register_parameter('bias', nn.Parameter(torch.zeros(ent_embs.shape[0])))\n",
    "\n",
    "    def forward(self, x):\n",
    "        output = self.subdecoder(x)\n",
    "        output = x @ self.ent_embs.T + self.bias\n",
    "        return F.tanh(output)\n",
    "\n",
    "\n",
    "\n",
    "class GEEA(nn.Module):\n",
    "\n",
     "    def __init__(self, args, kgs, concrete_features, sub_dims, joint_dim, ent_embs, fusion_layer):\n",
     "        \"\"\"Build one MutualVAE generator and one decoder per modality with a concrete feature.\n",
     "\n",
     "        Args:\n",
     "            args: config namespace; `args.num_layers` sets the depth of each MutualVAE.\n",
     "            kgs: indexed with 'left_ents'; only len(kgs['left_ents']) is read here.\n",
     "            concrete_features: one entry per modality (aligned with `sub_dims`); None\n",
     "                entries get no generator/decoder and are counted in\n",
     "                `num_none_concrete_feature`.\n",
     "            sub_dims: per-modality embedding sizes; sub_dims[0] becomes `self.latent_dim`.\n",
     "            joint_dim: not used in this constructor.\n",
     "            ent_embs: only forwarded to NeighborDecoder (unreachable branch below).\n",
     "            fusion_layer: module called by `re_fusion` to fuse sub-embeddings.\n",
     "        \"\"\"\n",
     "        super().__init__()\n",
     "        self.args = args\n",
     "        self.kgs = kgs\n",
     "        self.latent_dim=sub_dims[0]\n",
     "\n",
     "        # built as plain lists, wrapped in nn.ModuleList after the loop\n",
     "        self.subgenerators = []\n",
     "        self.subdecoders = []\n",
     "\n",
     "        self.num_none_concrete_feature = 0\n",
     "        # Only modalities with a concrete feature get entries, so positions in\n",
     "        # subgenerators/subdecoders differ from sub_dims indices whenever a None\n",
     "        # precedes a real feature.\n",
     "        for i, sub_dim, concrete_feature in zip(range(len(sub_dims)), sub_dims, concrete_features):\n",
     "            if concrete_feature is not None:\n",
     "                subgenerator = MutualVAE(in_dim=sub_dim,\n",
     "                                        hidden_dims=[sub_dim,]*args.num_layers,\n",
     "                                        latent_dim=sub_dim)\n",
     "                self.subgenerators.append(subgenerator)\n",
     "                \n",
     "                # NOTE(review): i comes from range() and is never -1, so the\n",
     "                # NeighborDecoder branch is disabled; every modality uses the MLP below.\n",
     "                if i==-1:\n",
     "                    subdecoder = NeighborDecoder(sub_dim, ent_embs)\n",
     "                else:\n",
     "                    # MLP decoder: sub_dim -> 1000 -> concrete_feature.shape[-1] with no final\n",
     "                    # activation; prior_reconstruction_loss feeds its outputs to cross-entropy.\n",
     "                    subdecoder = nn.Sequential(nn.Linear(sub_dim, 1000),\n",
     "                                               nn.Tanh(),\n",
     "                                               nn.Dropout(0.5),\n",
     "                                               nn.BatchNorm1d(1000),\n",
     "                                               nn.Linear(1000, concrete_feature.shape[-1]),\n",
     "#                                                nn.Softmax()\n",
     "                                            )\n",
     "                self.subdecoders.append(subdecoder)\n",
     "            else:\n",
     "                self.num_none_concrete_feature += 1\n",
     "                \n",
     "#             break\n",
     "        \n",
     "        \n",
     "\n",
     "\n",
     "        self.subgenerators = nn.ModuleList(self.subgenerators)\n",
     "#         self.subgenerators = nn.ModuleList([self.subgenerators[0], ]*len(sub_dims))\n",
     "        self.subdecoders = nn.ModuleList(self.subdecoders)\n",
     "\n",
     "        # for distribution matching: number_samples is this fraction of the left\n",
     "        # entities; sample_prop is also the default KL weight in distribution_match_loss\n",
     "        self.sample_prop = 1./7\n",
     "        # print(self.kgs.keys())\n",
     "        self.number_samples = int(\n",
     "            len(self.kgs['left_ents']) * self.sample_prop)\n",
     "\n",
     "        # for prior and post reconstruction: BCE is used by sampled_bce_loss,\n",
     "        # MSE by post_reconstruction_loss and reconstruction_loss\n",
     "        self.prior_reconstruction_loss_func = nn.BCELoss()#reduction='sum'\n",
     "        self.post_reconstruction_loss_func = nn.MSELoss()\n",
     "        self.concrete_features = concrete_features\n",
     "        \n",
     "        self.fusion_layer = fusion_layer\n",
     "\n",
     "        # per-flow loss weights, in the order xx, yy, xy, yx\n",
     "        self.flow_weights = [1, 1, 1, 1]\n",
    "\n",
    "    def distribution_match_loss(self, outputs):\n",
    "\n",
    "        def kld_loss(mu, logvar, kld_weight=self.sample_prop):\n",
    "            return kld_weight * torch.mean(-.5 * torch.sum(1 + logvar - mu**2 - logvar.exp(), dim=1), dim=0)\n",
    "\n",
    "        # output = (x, z=(mu, var), reconstrctued_x)\n",
    "        xx_distribution_match_loss = [\n",
    "            kld_loss(*output['xx'][1]) for output in outputs]\n",
    "        yy_distribution_match_loss = [\n",
    "            kld_loss(*output['yy'][1]) for output in outputs]\n",
    "\n",
    "        return sum(xx_distribution_match_loss) + sum(yy_distribution_match_loss) \n",
    "\n",
    "    \n",
    "    def sampled_bce_loss(self, predicted, label, neg_ratio=5):\n",
    "        pos_mask = torch.where(label>0)\n",
    "\n",
    "\n",
    "        neg = torch.randint(high=label.shape[-1], size=(len(pos_mask[0])*neg_ratio,))\n",
    "        neg_mask = [pos_mask[0].repeat(neg_ratio), neg]\n",
    "\n",
    "        predicted_pos = predicted[pos_mask]\n",
    "        label_pos = torch.ones_like(predicted_pos)\n",
    "        predicted_neg = predicted[neg_mask]\n",
    "        label_neg = torch.zeros_like(predicted_neg)\n",
    "\n",
    "        loss = self.prior_reconstruction_loss_func(predicted_pos, label_pos) + self.prior_reconstruction_loss_func(predicted_neg, label_neg) / neg_ratio\n",
    "        return loss\n",
    "    \n",
    "    def sampled_crossentropy_loss(self, predicted, label, neg_ratio=1):\n",
    "        pos_mask, labels = torch.where(label>0)\n",
    "        sampled = torch.randperm(len(pos_mask))[:3500]\n",
    "        pos_mask, labels = pos_mask[sampled], labels[sampled]\n",
    "        \n",
    "        predicted_pos = predicted[pos_mask]\n",
    "        \n",
    "        loss = F.cross_entropy(predicted_pos.cuda(), labels.cuda()) \n",
    "        return loss\n",
    "    \n",
    "    \n",
    "    def prior_reconstruction_loss(self, outputs, train_links, left_ents, right_ents):\n",
    "\n",
    "        prior_reconstruction_loss = []\n",
    "        \n",
    "        for output, subdecoder, concrete_feature in  zip(outputs, self.subdecoders, self.concrete_features):\n",
    "            \n",
    "            \n",
    "            \n",
    "            reconstructed_xx = subdecoder(output['xx'][-1])\n",
    "            reconstructed_yy = subdecoder(output['yy'][-1])\n",
    "            reconstructed_xy = subdecoder(output['xy'][-1])\n",
    "            reconstructed_yx = subdecoder(output['yx'][-1])\n",
    "            \n",
    "#             reconstructed_xx = subdecoder(output['xx'][0])\n",
    "#             reconstructed_yy = subdecoder(output['yy'][0])\n",
    "#             reconstructed_xy = subdecoder(output['xy'][0])\n",
    "#             reconstructed_yx = subdecoder(output['yx'][0])\n",
    "            \n",
    "            concrete_xx = concrete_feature[left_ents].cuda()#.argmax(dim=-1).cuda()#*0.9+1./len(left_ents)\n",
    "            concrete_yy = concrete_feature[right_ents].cuda()#.argmax(dim=-1).cuda()#*0.9+1./len(right_ents)\n",
    "            concrete_xy = concrete_feature[train_links[:, 1]].cuda()#.argmax(dim=-1).cuda()#*0.9+1./len(train_links)\n",
    "            concrete_yx = concrete_feature[train_links[:, 0]].cuda()#.argmax(dim=-1).cuda()#*0.9+1./len(train_links)\n",
    "\n",
    " \n",
    "            loss_xx = self.sampled_crossentropy_loss(\n",
    "                reconstructed_xx, concrete_xx)\n",
    "            loss_yy = self.sampled_crossentropy_loss(\n",
    "                reconstructed_yy, concrete_yy)\n",
    "            loss_xy = self.sampled_crossentropy_loss(\n",
    "                reconstructed_xy, concrete_xy)\n",
    "            loss_yx = self.sampled_crossentropy_loss(\n",
    "                reconstructed_yx, concrete_yx)\n",
    "\n",
    "            loss_list = [loss_xx, loss_yy, loss_xy, loss_yx]\n",
    "\n",
    "            prior_reconstruction_loss += [sum(loss*flow_weight for loss, flow_weight in zip(loss_list, self. flow_weights)), ]\n",
    "        \n",
    "        print('Neighbor-IMG-REL-ATTR:%.3f-%.3f-%.3f-%.3f' % (prior_reconstruction_loss[0].item(), prior_reconstruction_loss[1].item(), prior_reconstruction_loss[2].item(), prior_reconstruction_loss[3].item()))\n",
    "        \n",
    "        return sum(prior_reconstruction_loss)\n",
    "\n",
    "    def re_fusion(self, sub_embs):\n",
    "        # self.fusion_layer.requires_grad_(False)\n",
    "        sub_embs = sub_embs+[None,]*self.num_none_concrete_feature\n",
    "        # self.fusion_layer.requires_grad_(True)\n",
    "        return self.fusion_layer(*sub_embs)\n",
    "    \n",
    "    def reconstruction_loss(self, outputs):\n",
    "        loss = 0.\n",
    "        for output in outputs:\n",
    "            for flow in output.keys():\n",
    "                input_, z, output_ = output[flow]\n",
    "                loss += self.post_reconstruction_loss_func(input_.detach(), output_)\n",
    "        return loss\n",
    "            \n",
    "        \n",
    "\n",
    "    def post_reconstruction_loss(self, outputs, joint_emb, train_links, left_ents, right_ents):\n",
    "\n",
    "        xx, yy, xy, yx = [], [], [], []\n",
    "\n",
    "        for output, subdecoder in zip(outputs, self.subdecoders):\n",
    "            xx.append(output['xx'][-1])\n",
    "            yy.append(output['yy'][-1])\n",
    "            xy.append(output['xy'][-1])\n",
    "            yx.append(output['yx'][-1])\n",
    "\n",
    "        # reconstructed\n",
    "        reconstructed_xx = self.re_fusion(xx)\n",
    "        reconstructed_yy = self.re_fusion(yy)\n",
    "        reconstructed_xy = self.re_fusion(xy)\n",
    "        reconstructed_yx = self.re_fusion(yx)\n",
    "\n",
    "        # the targets\n",
    "        joint_emb = joint_emb.detach()\n",
    "        joint_xx = joint_emb[left_ents]\n",
    "        joint_yy = joint_emb[right_ents]\n",
    "        joint_xy = joint_emb[train_links[:, 1]]\n",
    "        joint_yx = joint_emb[train_links[:, 0]]\n",
    "\n",
    "        # loss\n",
    "        loss_xx = self.post_reconstruction_loss_func(\n",
    "            reconstructed_xx, joint_xx) \n",
    "        loss_yy = self.post_reconstruction_loss_func(\n",
    "            reconstructed_yy, joint_yy) \n",
    "        loss_xy = self.post_reconstruction_loss_func(\n",
    "            reconstructed_xy, joint_xy) \n",
    "        loss_yx = self.post_reconstruction_loss_func(\n",
    "            reconstructed_yx, joint_yx) \n",
    "\n",
    "        return loss_xx + loss_yy + loss_xy + loss_yx\n",
    "    \n",
    "\n",
    "    def encode(self, xs, sub_embs):\n",
    "        sub_embs = [embs for embs in sub_embs if embs is not None]\n",
    "        \n",
    "        x_zs = [subgenerator.encode(embs[xs])\n",
    "                   for embs, subgenerator in zip(sub_embs, self.subgenerators)]\n",
    "        \n",
    "        return x_zs\n",
    "\n",
    "    def decode(self, zs, reparameterize=False):\n",
    "\n",
    "        reconstructed_x = [subgenerator.decode(z, reparameterize=reparameterize)\n",
    "                   for subgenerator, z in zip(self.subgenerators, zs)]\n",
    "\n",
    "        return reconstructed_x\n",
    "    \n",
    "    def sample(self, num):\n",
    "        z = torch.randn(num, self.latent_dim).cuda()\n",
    "        \n",
    "        samples = self.decode(z)\n",
    "        \n",
    "        return samples\n",
    "    \n",
    "    def id2feature(self):\n",
    "        pass\n",
    "\n",
    "\n",
    "    def sample_from_x_to_y(self, xs, sub_embs):\n",
    "#         x = np.random.choice(\n",
    "#             self.kgs['left_ents'], num, replace=False)\n",
    "        \n",
    "        zs = self.encode(xs, sub_embs)\n",
    "\n",
    "        samples = self.decode(zs, reparameterize=True)\n",
    "\n",
    "        return samples\n",
    "    \n",
    "    def recover_img_to_id(self, img_samples, all_imgs):\n",
    "        \"\"\"Return, for each generated image feature, the index of the nearest real one.\n",
    "\n",
    "        Distances are computed on CPU with ``pairwise_distances``; the argmin is\n",
    "        taken over the last dimension (the columns of ``all_imgs``).\n",
    "        \"\"\"\n",
    "        dist_matrix = pairwise_distances(img_samples.cpu(), all_imgs.cpu())\n",
    "        nearest = torch.argmin(dist_matrix, dim=-1)\n",
    "        return nearest\n",
    "\n",
    "    def recover_to_feature(self, samples, ajd_loc=0, img_loc=1, max_counts=10):\n",
    "        features = []\n",
    "        for i, sub_samples, subdecoder in enumerate(zip(samples,self.subdecoders)):\n",
    "            sub_samples = subdecoder(sub_samples)\n",
    "            \n",
    "            if i == ajd_loc:\n",
    "#                 sub_samples[:, self.kgs['left_ents']] = -1e10\n",
    "                features.append(torch.argsort(sub_samples, dim=-1, descending=True)[:, :max_counts].cpu().numpy())\n",
    "            elif i == img_loc:\n",
    "                features.append(self.recover_img_to_id(sub_samples, self.concrete_features[img_loc]).cpu().numpy())\n",
    "            else:\n",
    "                features.append(torch.argsort(sub_samples, dim=-1, descending=True)[:, :max_counts].cpu().numpy())\n",
    "                    \n",
    "                \n",
    "        return features\n",
    "\n",
    "\n",
    "    def forward(self, train_links, sub_embs, joint_emb):\n",
    "        sub_embs = [embs for embs in sub_embs if embs is not None]\n",
    "        self.subdecoders[0].ent_embs=sub_embs[0]\n",
    "\n",
    "        # for self-supervised learning\n",
    "        left_ents = np.random.choice(\n",
    "            self.kgs['left_ents'], self.number_samples, replace=False)\n",
    "        right_ents = np.random.choice(\n",
    "            self.kgs['right_ents'], self.number_samples, replace=False)\n",
    "\n",
    "        outputs = [subgenerator(embs, train_links, left_ents, right_ents)\n",
    "                   for embs, subgenerator in zip(sub_embs, self.subgenerators)]\n",
    "\n",
    "        distribution_match_loss = self.distribution_match_loss(outputs)\n",
    "        prior_reconstruction_loss = self.prior_reconstruction_loss(\n",
    "            outputs, train_links, left_ents, right_ents)\n",
    "        post_reconstruction_loss = self.post_reconstruction_loss(outputs, joint_emb, train_links, left_ents, right_ents)\n",
    "        \n",
    "        reconstruction_loss = self.reconstruction_loss(outputs)\n",
    "\n",
    "        print('DistMatch Loss: %.3f; PriorRec Loss: %.3f; PostRec Loss: %.3f; Rec Loss: %.3f' % (\n",
    "            distribution_match_loss.item(), prior_reconstruction_loss.item(), post_reconstruction_loss.item(), reconstruction_loss.item()))\n",
    "        \n",
    "#         return  distribution_match_loss*0.5 + reconstruction_loss*self.latent_dim + prior_reconstruction_loss \n",
    "#         return distribution_match_loss*0.5 + prior_reconstruction_loss*self.latent_dim  \n",
    "        return distribution_match_loss*0.5 + prior_reconstruction_loss*self.latent_dim + reconstruction_loss*self.latent_dim + post_reconstruction_loss\n",
    "\n",
    "    \n",
    "    \n",
    "# calculate frechet inception distance\n",
    "def calculate_fid(act1, act2):\n",
    "#     print(act1, act2)\n",
    "     # calculate mean and covariance statistics\n",
    "    mu1, sigma1 = act1.mean(axis=0), cov(act1, rowvar=False)\n",
    "    mu2, sigma2 = act2.mean(axis=0), cov(act2, rowvar=False)\n",
    "    # calculate sum squared difference between means\n",
    "    ssdiff = np.sum((mu1 - mu2)**2.0)\n",
    "    # calculate sqrt of product between cov\n",
    "    covmean = sqrtm(sigma1.dot(sigma2))\n",
    "    # check and correct imaginary numbers from sqrt\n",
    "    if iscomplexobj(covmean):\n",
    "        covmean = covmean.real\n",
    "    # calculate score\n",
    "    fid = ssdiff + trace(sigma1 + sigma2 - 2.0 * covmean)\n",
    "    return fid\n",
    "\n",
    "def eval_genarative_model(es_ill, g, model):\n",
    "    \"\"\"Evaluate the generative module on a set of aligned entity pairs.\n",
    "\n",
    "    Left entities of ``es_ill`` are encoded and decoded by ``g``. The decoded\n",
    "    embeddings and decoded concrete features are compared (MSE) with those of the\n",
    "    aligned right entities. Embeddings decoded from a random latent are compared\n",
    "    with the right entities' embeddings via FID.\n",
    "\n",
    "    Args:\n",
    "        es_ill: indexable (n, 2) array of (left, right) entity ids.\n",
    "        g: generative module providing ``encode``, ``decode``, ``subdecoders``,\n",
    "            ``concrete_features`` and ``latent_dim``.\n",
    "        model: model providing ``joint_emb_generat``. If it does not, the\n",
    "            notebook-global ``runner.model`` is used instead.\n",
    "\n",
    "    Returns:\n",
    "        (rc_error, prc_error, fid): mean embedding-level MSE, mean concrete-feature\n",
    "        MSE and mean per-modality FID.\n",
    "    \"\"\"\n",
    "    encoder = model if hasattr(model, 'joint_emb_generat') else runner.model\n",
    "\n",
    "    # Pure evaluation: no autograd graph is needed for anything below.\n",
    "    with torch.no_grad():\n",
    "        embs = list(encoder.joint_emb_generat(only_joint=False))\n",
    "        left, right = es_ill[:, 0], es_ill[:, 1]\n",
    "\n",
    "        # Embedding level: translate left entities and compare with their right counterparts.\n",
    "        left_z = g.encode(left, embs)\n",
    "        left_emb_y = g.decode(left_z, reparameterize=True)\n",
    "        right_emb_y = [emb[right] for emb in embs if emb is not None]\n",
    "        rc_error = [F.mse_loss(ly.cpu(), ry.cpu()) for ly, ry in zip(left_emb_y, right_emb_y)]\n",
    "        rc_error = sum(rc_error) / len(rc_error)\n",
    "\n",
    "        # Concrete-feature level: decode each modality back to its input space.\n",
    "        left_concrete_y = [subdecoder(ley) for ley, subdecoder in zip(left_emb_y, g.subdecoders)]\n",
    "        right_concrete_y = [feat[right] for feat in g.concrete_features if feat is not None]\n",
    "        prc_error = [F.mse_loss(ly.cpu(), ry.cpu()) for ly, ry in zip(left_concrete_y, right_concrete_y)]\n",
    "        prc_error = sum(prc_error) / len(prc_error)\n",
    "\n",
    "        # FID between embeddings decoded from one shared random latent and the real ones.\n",
    "        sample_z = torch.randn(len(left), g.latent_dim).repeat(len(left_concrete_y), 1, 1)\n",
    "        sample_z = sample_z.to(right_emb_y[0].device)\n",
    "        sample_emb_y = g.decode(sample_z, reparameterize=False)\n",
    "        fid = [calculate_fid(sy.cpu().numpy(), ry.cpu().numpy()) for sy, ry in zip(sample_emb_y, right_emb_y)]\n",
    "        fid = sum(fid) / len(fid)\n",
    "\n",
    "    return rc_error, prc_error, fid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "472058e1",
   "metadata": {
    "code_folding": [
     203
    ]
   },
   "outputs": [],
   "source": [
    "# model\n",
    "\n",
    "import types\n",
    "import torch\n",
    "import transformers\n",
    "import torch.nn.functional as F\n",
    "import torch.nn as nn\n",
    "from torch.nn import CrossEntropyLoss\n",
    "import numpy as np\n",
    "import pdb\n",
    "import math\n",
    "from model.Tool_model import AutomaticWeightedLoss\n",
    "from model.MCLEA_tools import MultiModalEncoder\n",
    "from model.MCLEA_loss import CustomMultiLossLayer, ial_loss, icl_loss\n",
    "from src.utils import pairwise_distances\n",
    "import os.path as osp\n",
    "import json\n",
    "# from geea import GEEA\n",
    "\n",
    "\n",
    "class MCLEA(nn.Module):\n",
    "    def __init__(self, kgs, args):\n",
    "        super().__init__()\n",
    "        self.kgs = kgs\n",
    "        self.args = args\n",
    "        self.img_features = F.normalize(\n",
    "            torch.FloatTensor(kgs[\"images_list\"])).cuda()\n",
    "        self.input_idx = kgs[\"input_idx\"].cuda()\n",
    "        self.adj = kgs[\"adj\"].cuda()\n",
    "        self.rel_features = torch.Tensor(kgs[\"rel_features\"]).cuda()\n",
    "        self.att_features = torch.Tensor(kgs[\"att_features\"]).cuda()\n",
    "        self.name_features = None\n",
    "        self.char_features = None\n",
    "        if kgs[\"name_features\"] is not None:\n",
    "            self.name_features = kgs[\"name_features\"].cuda()\n",
    "            self.char_features = kgs[\"char_features\"].cuda()\n",
    "\n",
    "        img_dim = self._get_img_dim(kgs)\n",
    "\n",
    "        char_dim = kgs[\"char_features\"].shape[1] if self.char_features is not None else 100\n",
    "\n",
    "        # 主要修改的部分在这\n",
    "        # 和EVA没区别，相当于重构了EVA\n",
    "        self.multimodal_encoder = MultiModalEncoder(args=self.args,\n",
    "                                                    ent_num=kgs[\"ent_num\"],\n",
    "                                                    img_feature_dim=img_dim,\n",
    "                                                    char_feature_dim=char_dim,\n",
    "                                                    use_project_head=self.args.use_project_head,\n",
    "                                                    attr_input_dim=kgs[\"att_features\"].shape[1])\n",
    "\n",
    "        self.multi_loss_layer = CustomMultiLossLayer(loss_num=6)  # 6\n",
    "        self.align_multi_loss_layer = CustomMultiLossLayer(loss_num=6)  # 6\n",
    "\n",
    "        self.criterion_cl = icl_loss(\n",
    "            tau=self.args.tau, ab_weight=self.args.ab_weight, n_view=2)\n",
    "        self.criterion_align = ial_loss(tau=self.args.tau2,\n",
    "                                        ab_weight=self.args.ab_weight,\n",
    "                                        zoom=self.args.zoom,\n",
    "                                        reduction=self.args.reduction)\n",
    "\n",
    "        self.concrete_features = [self.adj.cpu().to_dense().bool().float(), self.img_features, self.rel_features.bool().float(), self.att_features.bool().float(),\n",
    "                                   self.name_features, self.char_features]\n",
    "\n",
    "        self.geea = GEEA(args, kgs, self.concrete_features,\n",
    "                         sub_dims=[self.multimodal_encoder.n_units[-1], self.args.img_dim, self.args.attr_dim, self.args.attr_dim,\n",
    "                          self.args.name_dim, self.args.char_dim], joint_dim=self.multimodal_encoder.n_units[-1], ent_embs=self.multimodal_encoder.entity_emb.weight,\n",
    "                         fusion_layer=self.multimodal_encoder)\n",
    "        \n",
    "        # self.concrete_features = [self.img_features, self.rel_features, self.att_features,\n",
    "        #                            self.name_features, self.char_features]\n",
    "\n",
    "        # self.geea = GEEA(args, kgs, self.concrete_features,\n",
    "        #                  sub_dims=[self.args.img_dim, self.args.attr_dim, self.args.attr_dim,\n",
    "        #                   self.args.name_dim, self.args.char_dim], joint_dim=self.multimodal_encoder.n_units[-1],\n",
    "        #                  fusion_layer=self.multimodal_encoder)\n",
    "\n",
    "    def forward(self, batch):\n",
    "        gph_emb, img_emb, rel_emb, att_emb, name_emb, char_emb, joint_emb = self.joint_emb_generat(\n",
    "            only_joint=False)\n",
    "\n",
    "        \n",
    "        \n",
    "        \n",
    "        # 和EVA 差不多\n",
    "        # ---- 融合模态内对比学习  ----\n",
    "        #  ICL loss for joint embedding\n",
    "        loss_joi = self.criterion_cl(joint_emb, batch)\n",
    "        # ---- 独立模态内对比学习  ----\n",
    "        # ICL loss for uni-modal embedding\n",
    "        in_loss = self.inner_view_loss(\n",
    "            gph_emb, rel_emb, att_emb, img_emb, name_emb, char_emb, batch)\n",
    "\n",
    "        # 模态融合\n",
    "        # IAL loss for uni-modal embedding\n",
    "        align_loss = self.kl_alignment_loss(\n",
    "            joint_emb, gph_emb, rel_emb, att_emb, img_emb, name_emb, char_emb, batch)\n",
    "\n",
    "        \n",
    "        loss_all = 0.\n",
    "        loss_all += loss_joi + in_loss + align_loss\n",
    "\n",
    "        geea_loss = self.geea(batch, [gph_emb, img_emb, rel_emb, att_emb, name_emb, char_emb], joint_emb)\n",
    "        loss_all += geea_loss\n",
    "\n",
    "        # loss_all.backward(retain_graph=True)\n",
    "        weight_raw = self.multimodal_encoder.fusion.weight.reshape(-1).tolist()\n",
    "        # pdb.set_trace()\n",
    "        loss_dic = {\"joint_Intra_modal\": loss_joi.item(\n",
    "        ), \"Intra_modal\": in_loss.item(), \"Inter_modal\": align_loss.item()}\n",
    "        output = {\"loss_dic\": loss_dic, \"emb\": joint_emb, \"weight\": weight_raw}\n",
    "        return loss_all, output\n",
    "\n",
    "    def inner_view_loss(self, gph_emb, rel_emb, att_emb, img_emb, name_emb, char_emb, train_ill):\n",
    "\n",
    "        loss_GCN = self.criterion_cl(\n",
    "            gph_emb, train_ill) if gph_emb is not None else 0\n",
    "        loss_rel = self.criterion_cl(\n",
    "            rel_emb, train_ill) if rel_emb is not None else 0\n",
    "        loss_att = self.criterion_cl(\n",
    "            att_emb, train_ill) if att_emb is not None else 0\n",
    "        loss_img = self.criterion_cl(\n",
    "            img_emb, train_ill) if img_emb is not None else 0\n",
    "        loss_name = self.criterion_cl(\n",
    "            name_emb, train_ill) if name_emb is not None else 0\n",
    "        loss_char = self.criterion_cl(\n",
    "            char_emb, train_ill) if char_emb is not None else 0\n",
    "\n",
    "        total_loss = self.multi_loss_layer(\n",
    "            [loss_GCN, loss_rel, loss_att, loss_img, loss_name, loss_char])\n",
    "        return total_loss\n",
    "\n",
    "    def kl_alignment_loss(self, joint_emb, gph_emb, rel_emb, att_emb, img_emb, name_emb, char_emb, train_ill):\n",
    "        zoom = self.args.zoom\n",
    "        loss_GCN = self.criterion_align(\n",
    "            gph_emb, joint_emb, train_ill) if gph_emb is not None else 0\n",
    "        loss_rel = self.criterion_align(\n",
    "            rel_emb, joint_emb, train_ill) if rel_emb is not None else 0\n",
    "        loss_att = self.criterion_align(\n",
    "            att_emb, joint_emb, train_ill) if att_emb is not None else 0\n",
    "        loss_img = self.criterion_align(\n",
    "            img_emb, joint_emb, train_ill) if img_emb is not None else 0\n",
    "        loss_name = self.criterion_align(\n",
    "            name_emb, joint_emb, train_ill) if name_emb is not None else 0\n",
    "        loss_char = self.criterion_align(\n",
    "            char_emb, joint_emb, train_ill) if char_emb is not None else 0\n",
    "\n",
    "        total_loss = self.align_multi_loss_layer(\n",
    "            [loss_GCN, loss_rel, loss_att, loss_img, loss_name, loss_char]) * zoom\n",
    "        return total_loss\n",
    "\n",
    "    # --------- necessary ---------------\n",
    "    def joint_emb_generat(self, only_joint=True):\n",
    "        gph_emb, img_emb, rel_emb, att_emb, name_emb, char_emb = self.multimodal_encoder._emb_generate(self.input_idx,\n",
    "                                                                                                       self.adj,\n",
    "                                                                                                       self.img_features,\n",
    "                                                                                                       self.rel_features,\n",
    "                                                                                                       self.att_features,\n",
    "                                                                                                       self.name_features,\n",
    "                                                                                                       self.char_features)\n",
    "\n",
    "        joint_emb = self.multimodal_encoder(\n",
    "            gph_emb, img_emb, rel_emb, att_emb, name_emb, char_emb)\n",
    "\n",
    "        if only_joint:\n",
    "            return joint_emb\n",
    "        else:\n",
    "            return gph_emb, img_emb, rel_emb, att_emb, name_emb, char_emb, joint_emb\n",
    "\n",
    "    # --------- share ---------------\n",
    "\n",
    "    def _get_img_dim(self, kgs):\n",
    "        if isinstance(kgs[\"images_list\"], list):\n",
    "            img_dim = kgs[\"images_list\"][0].shape[1]\n",
    "        elif isinstance(kgs[\"images_list\"], np.ndarray) or torch.is_tensor(kgs[\"images_list\"]):\n",
    "            img_dim = kgs[\"images_list\"].shape[1]\n",
    "        return img_dim\n",
    "\n",
    "    # 通用的生成link/刷新训练集方法\n",
    "    def Iter_new_links(self, epoch, left_non_train, final_emb, right_non_train, new_links=[]):\n",
    "        if len(left_non_train) == 0 or len(right_non_train) == 0:\n",
    "            return new_links\n",
    "        distance_list = []\n",
    "        for i in np.arange(0, len(left_non_train), 1000):\n",
    "            d = pairwise_distances(\n",
    "                final_emb[left_non_train[i:i + 1000]], final_emb[right_non_train])\n",
    "            distance_list.append(d)\n",
    "        distance = torch.cat(distance_list, dim=0)\n",
    "        preds_l = torch.argmin(distance, dim=1).cpu().numpy().tolist()\n",
    "        preds_r = torch.argmin(distance.t(), dim=1).cpu().numpy().tolist()\n",
    "        del distance_list, distance, final_emb\n",
    "        if (epoch + 1) % (self.args.semi_learn_step * 10) == self.args.semi_learn_step:\n",
    "            # 在未匹配的里面只要有匹配（互为最近）的就加进去\n",
    "            new_links = [(left_non_train[i], right_non_train[p])\n",
    "                         for i, p in enumerate(preds_l) if preds_r[p] == i]\n",
    "        else:\n",
    "            new_links = [(left_non_train[i], right_non_train[p]) for i, p in enumerate(preds_l) if (\n",
    "                preds_r[p] == i) and ((left_non_train[i], right_non_train[p]) in new_links)]\n",
    "\n",
    "        # if self.args.rank == 0:\n",
    "        #     print(\"[epoch %d] #links in candidate set: %d\" % (epoch, len(new_links)))\n",
    "        return new_links\n",
    "\n",
    "    def data_refresh(self, logger, train_ill, test_ill_, left_non_train, right_non_train, new_links=[]):\n",
    "        if len(new_links) != 0 and (len(left_non_train) != 0 and len(right_non_train) != 0):\n",
    "            new_links_select = new_links\n",
    "            # if len(new_links) >= 5000: new_links = random.sample(new_links, 5000)\n",
    "            train_ill = np.vstack((train_ill, np.array(new_links_select)))\n",
    "            num_true = len([nl for nl in new_links_select if nl in test_ill_])\n",
    "            # remove from left/right_non_train\n",
    "            for nl in new_links_select:\n",
    "                left_non_train.remove(nl[0])\n",
    "                right_non_train.remove(nl[1])\n",
    "\n",
    "            if self.args.rank == 0:\n",
    "                logger.info(f\"#new_links_select:{len(new_links_select)}\")\n",
    "                logger.info(f\"train_ill.shape:{train_ill.shape}\")\n",
    "                logger.info(f\"#true_links: {num_true}\")\n",
    "                logger.info(\n",
    "                    f\"true link ratio: {(100 * num_true / len(new_links_select)):.1f}%\")\n",
    "                logger.info(\n",
    "                    f\"#entity not in train set: {len(left_non_train)} (left) {len(right_non_train)} (right)\")\n",
    "\n",
    "            new_links = []\n",
    "        else:\n",
    "            logger.info(\"len(new_links) is 0\")\n",
    "\n",
    "        return left_non_train, right_non_train, train_ill, new_links\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4f435aa9",
   "metadata": {
    "code_folding": [
     0
    ]
   },
   "outputs": [],
   "source": [
    "# runner\n",
    "import os\n",
    "import os.path as osp\n",
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "from torch.utils.data import DataLoader, RandomSampler\n",
    "from torch.cuda.amp import GradScaler, autocast\n",
    "from datetime import datetime\n",
    "from easydict import EasyDict as edict\n",
    "from tqdm import tqdm\n",
    "import pdb\n",
    "import pprint\n",
    "import json\n",
    "import pickle\n",
    "from collections import defaultdict\n",
    "\n",
    "from torchlight import initialize_exp, set_seed, get_dump_path\n",
    "from src.data import load_data, Collator_base, EADataset\n",
    "from src.utils import set_optim, Loss_log, pairwise_distances, csls_sim\n",
    "# from model import MEAformer\n",
    "\n",
    "from src.distributed_utils import init_distributed_mode, dist_pdb, is_main_process, reduce_value, cleanup\n",
    "import torch.distributed as dist\n",
    "from torch.nn.parallel import DistributedDataParallel\n",
    "import torch.nn.functional as F\n",
    "import scipy\n",
    "import gc\n",
    "import copy\n",
    "\n",
    "\n",
    "class Runner:\n",
    "    def __init__(self, args, writer=None, logger=None, rank=0):\n",
    "        \"\"\"Set up dump paths, data, model, data loaders and, for training, the optimizer.\n",
    "\n",
    "        Args:\n",
    "            args: experiment configuration.\n",
    "            writer: optional writer object (presumably a tensorboard SummaryWriter).\n",
    "            logger: logger used for progress messages.\n",
    "            rank: process rank; rank 0 performs the logging/saving elsewhere in this class.\n",
    "        \"\"\"\n",
    "        self.datapath = edict()\n",
    "        self.datapath.log_dir = get_dump_path(args)\n",
    "        self.datapath.model_dir = os.path.join(self.datapath.log_dir, 'model')\n",
    "        self.rank, self.args = rank, args\n",
    "        self.writer, self.logger = writer, logger\n",
    "        self.scaler = GradScaler()\n",
    "        self.model_list = []\n",
    "\n",
    "        # Seeded before and again after data/model construction (presumably so that\n",
    "        # training randomness does not depend on the draws setup consumed).\n",
    "        set_seed(args.random_seed)\n",
    "        self.data_init()\n",
    "        self.model_choise()\n",
    "        set_seed(args.random_seed)\n",
    "\n",
    "        self.csv_results = []\n",
    "\n",
    "        if self.args.only_test:\n",
    "            self.dataloader_init(test_set=self.test_set)\n",
    "            return\n",
    "\n",
    "        self.dataloader_init(train_set=self.train_set, eval_set=self.eval_set, test_set=self.test_set)\n",
    "        if self.args.dist:\n",
    "            self.model_sync()\n",
    "        else:\n",
    "            self.model_list = [self.model]\n",
    "\n",
    "        # When args.il is set the first training stage ends at il_start.\n",
    "        if self.args.il:\n",
    "            assert self.args.il_start < self.args.epoch\n",
    "            stage_one_epochs = self.args.il_start\n",
    "        else:\n",
    "            stage_one_epochs = self.args.epoch\n",
    "        self.optim_init(self.args, total_epoch=stage_one_epochs)\n",
    "\n",
    "    def model_sync(self):\n",
    "        \"\"\"Give every distributed rank identical initial weights.\n",
    "\n",
    "        Rank 0 writes its state dict to ``<data_path>/tmp/initial_weights.pt``. After\n",
    "        a barrier every rank loads that file via ``_model_sync``.\n",
    "        \"\"\"\n",
    "        tmp_dir = osp.join(self.args.data_path, \"tmp\")\n",
    "        if not os.path.exists(tmp_dir):\n",
    "            os.makedirs(tmp_dir)\n",
    "        ckpt_path = osp.join(tmp_dir, \"initial_weights.pt\")\n",
    "        is_main = self.rank == 0\n",
    "        if is_main:\n",
    "            torch.save(self.model.state_dict(), ckpt_path)\n",
    "        dist.barrier()  # nobody reads the file before rank 0 has written it\n",
    "        self.model = self._model_sync(self.model, ckpt_path)\n",
    "\n",
    "    def _model_sync(self, model, checkpoint_path):\n",
    "        \"\"\"Load ``checkpoint_path`` into ``model`` and wrap it for distributed training.\n",
    "\n",
    "        BatchNorm layers are converted to SyncBatchNorm and the model is wrapped in\n",
    "        DistributedDataParallel. The DDP wrapper is appended to ``self.model_list``,\n",
    "        and the unwrapped module is returned.\n",
    "        \"\"\"\n",
    "        state = torch.load(checkpoint_path, map_location=self.args.device)\n",
    "        model.load_state_dict(state)\n",
    "        synced = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(self.args.device)\n",
    "        ddp_model = torch.nn.parallel.DistributedDataParallel(\n",
    "            synced, device_ids=[self.args.gpu], find_unused_parameters=True)\n",
    "        self.model_list.append(ddp_model)\n",
    "        return ddp_model.module\n",
    "\n",
    "    def model_choise(self):\n",
    "        \"\"\"Build the model named by ``args.model_name``, load saved weights, log its size.\"\"\"\n",
    "        # Built lazily: only the selected model class has to be defined.\n",
    "        builders = {\n",
    "            \"EVA\": lambda: EVA(self.KGs, self.args),\n",
    "            \"MCLEA\": lambda: MCLEA(self.KGs, self.args),\n",
    "            \"MSNEA\": lambda: MSNEA(self.KGs, self.args),\n",
    "            \"MEAformer\": lambda: MEAformer(self.KGs, self.args),\n",
    "        }\n",
    "        assert self.args.model_name in builders\n",
    "        self.model = builders[self.args.model_name]()\n",
    "        self.model = self._load_model(self.model, model_name=self.args.model_name_save)\n",
    "\n",
    "        n_trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)\n",
    "        self.logger.info(f\"total params num: {n_trainable}\")\n",
    "\n",
    "    def optim_init(self, opt, total_step=None, total_epoch=None, accumulation_step=None):\n",
    "        \"\"\"Store total/warm-up step counts on ``opt`` and build the optimizer and scheduler.\n",
    "\n",
    "        Total steps come from ``total_epoch`` (times steps per epoch) if given, else from\n",
    "        ``total_step``, else from ``opt.epoch`` epochs. Warm-up is 15% of the total.\n",
    "        \"\"\"\n",
    "        steps_per_epoch = len(self.train_dataloader)\n",
    "        if total_epoch is not None:\n",
    "            opt.total_steps = int(steps_per_epoch * total_epoch)\n",
    "        elif total_step is not None:\n",
    "            opt.total_steps = int(total_step)\n",
    "        else:\n",
    "            opt.total_steps = int(steps_per_epoch * opt.epoch)\n",
    "        opt.warmup_steps = int(opt.total_steps * 0.15)\n",
    "\n",
    "        should_log = self.rank == 0 and total_step is None\n",
    "        if should_log:\n",
    "            self.logger.info(f\"warmup_steps: {opt.warmup_steps}\")\n",
    "            self.logger.info(f\"total_steps: {opt.total_steps}\")\n",
    "            self.logger.info(f\"weight_decay: {opt.weight_decay}\")\n",
    "\n",
    "        frozen = []  # no parameter groups are frozen\n",
    "        self.optimizer, self.scheduler = set_optim(opt, self.model_list, frozen, accumulation_step)\n",
    "\n",
    "    def data_init(self):\n",
    "        \"\"\"Load the KGs and the train/eval/test splits and cache entity-id tensors on GPU.\n",
    "\n",
    "        In distributed training (and not test-only) a DistributedSampler is created for\n",
    "        each split; otherwise the sampler attributes are None.\n",
    "        \"\"\"\n",
    "        self.KGs, self.non_train, self.train_set, self.eval_set, self.test_set, self.test_ill_ = load_data(self.logger, self.args)\n",
    "        self.train_ill = self.train_set.data\n",
    "        self.eval_left = torch.LongTensor(self.eval_set[:, 0].squeeze()).cuda()\n",
    "        self.eval_right = torch.LongTensor(self.eval_set[:, 1].squeeze()).cuda()\n",
    "        if self.test_set is not None:\n",
    "            # Index the test split exactly like the eval split above.\n",
    "            self.test_left = torch.LongTensor(self.test_set[:, 0].squeeze()).cuda()\n",
    "            self.test_right = torch.LongTensor(self.test_set[:, 1].squeeze()).cuda()\n",
    "\n",
    "        # Defaults, so the attributes exist even when no distributed samplers are built.\n",
    "        self.train_sampler = None\n",
    "        self.eval_sampler = None\n",
    "        self.test_sampler = None\n",
    "        if self.args.dist and not self.args.only_test:\n",
    "            self.train_sampler = torch.utils.data.distributed.DistributedSampler(self.train_set)\n",
    "            self.eval_sampler = torch.utils.data.distributed.DistributedSampler(self.eval_set)\n",
    "            if self.test_set is not None:\n",
    "                self.test_sampler = torch.utils.data.distributed.DistributedSampler(self.test_set)\n",
    "\n",
    "    def dataloader_init(self, train_set=None, eval_set=None, test_set=None):\n",
    "        \"\"\"Create data loaders for whichever splits are passed in.\n",
    "\n",
    "        The worker count is capped by the CPU count and the batch size. Distributed\n",
    "        (non test-only) runs use the samplers prepared in ``data_init``.\n",
    "        \"\"\"\n",
    "        bs = self.args.batch_size\n",
    "        collator = Collator_base(self.args)\n",
    "        # os.cpu_count() returns None when the count cannot be determined.\n",
    "        self.args.workers = min([os.cpu_count() or 1, self.args.batch_size, self.args.workers])\n",
    "        if self.args.dist and not self.args.only_test:\n",
    "            if train_set is not None:\n",
    "                self.train_dataloader = self._dataloader_dist(train_set, self.train_sampler, bs, collator)\n",
    "            if test_set is not None:\n",
    "                self.test_dataloader = self._dataloader_dist(test_set, self.test_sampler, bs, collator)\n",
    "            if eval_set is not None:\n",
    "                self.eval_dataloader = self._dataloader_dist(eval_set, self.eval_sampler, bs, collator)\n",
    "        else:\n",
    "            if train_set is not None:\n",
    "                self.train_dataloader = self._dataloader(train_set, bs, collator)\n",
    "            if test_set is not None:\n",
    "                self.test_dataloader = self._dataloader(test_set, bs, collator)\n",
    "            if eval_set is not None:\n",
    "                self.eval_dataloader = self._dataloader(eval_set, bs, collator)\n",
    "\n",
    "    def _dataloader_dist(self, train_set, train_sampler, batch_size, collator):\n",
    "        train_dataloader = DataLoader(\n",
    "            train_set,\n",
    "            sampler=train_sampler,\n",
    "            pin_memory=True,\n",
    "            num_workers=self.args.workers,\n",
    "            persistent_workers=True,  # True\n",
    "            drop_last=True,\n",
    "            batch_size=batch_size,\n",
    "            collate_fn=collator\n",
    "        )\n",
    "        return train_dataloader\n",
    "\n",
    "    def _dataloader(self, train_set, batch_size, collator):\n",
    "        train_dataloader = DataLoader(\n",
    "            train_set,\n",
    "            num_workers=self.args.workers,\n",
    "            persistent_workers=True,  # True\n",
    "            shuffle=(self.args.only_test == 0),\n",
    "            # drop_last=(self.args.only_test == 0),\n",
    "            drop_last=False,\n",
    "            batch_size=batch_size,\n",
    "            collate_fn=collator\n",
    "        )\n",
    "        return train_dataloader\n",
    "\n",
    "    def run(self):\n",
    "        self.loss_log = Loss_log()\n",
    "        self.curr_loss = 0.\n",
    "        self.lr = self.args.lr\n",
    "        self.curr_loss_dic = defaultdict(float)\n",
    "        self.weight = [1, 1, 1, 1, 1, 1]\n",
    "        self.loss_weight = [1, 1]\n",
    "        self.loss_item = 99999.\n",
    "        self.step = 1\n",
    "        self.epoch = 0\n",
    "        self.new_links = []\n",
    "        self.best_model_wts = None\n",
    "\n",
    "        self.best_mrr = 0\n",
    "\n",
    "        self.early_stop_init = 1000\n",
    "        self.early_stop_count = self.early_stop_init\n",
    "        self.stage = 0\n",
    "\n",
    "        with tqdm(total=self.args.epoch) as _tqdm:\n",
    "            for i in range(self.args.epoch):\n",
    "                # _tqdm.set_description(f'Train | epoch {i} Loss {self.loss_log.get_loss():.5f} Acc {self.loss_log.get_acc()*100:.3f}%')\n",
    "                if self.args.dist and not self.args.only_test:\n",
    "                    self.train_sampler.set_epoch(i)\n",
    "                # -------------------------------\n",
    "                self.epoch = i\n",
    "                if self.args.il and (self.epoch == self.args.il_start and self.stage == 0) or (self.early_stop_count <= 0 and self.epoch <= self.args.il_start):\n",
    "                    if self.early_stop_count <= 0:\n",
    "                        logger.info(f\"Early stop in epoch {self.epoch}... Begin iteration....\")\n",
    "                    self.stage = 1\n",
    "                    self.early_stop_init = 2000\n",
    "                    self.early_stop_count = self.early_stop_init\n",
    "\n",
    "                    self.eval_epoch = 1\n",
    "\n",
    "                    self.step = 1\n",
    "                    self.args.lr = self.args.lr / 5\n",
    "                    self.optim_init(self.args, total_epoch=(self.args.epoch - self.args.il_start) * 3)\n",
    "                    if self.best_model_wts is not None:\n",
    "                        self.logger.info(\"load from the best model before IL... \")\n",
    "                        self.model.load_state_dict(self.best_model_wts)\n",
    "                    name = self._save_name_define()\n",
    "                    self.test(save_name=f\"{name}_test_ep{self.args.epoch}_no_iter\")\n",
    "                    if self.rank == 0:\n",
    "                        if not self.args.only_test and self.args.save_model:\n",
    "                            self._save_model(self.model, input_name=f\"{name}_non_iter\")\n",
    "\n",
    "                if self.stage == 1 and (self.epoch + 1) % self.args.semi_learn_step == 0 and self.args.il:\n",
    "                    self.il_for_ea()\n",
    "\n",
    "                if self.stage == 1 and (self.epoch + 1) % (self.args.semi_learn_step * 10) == 0 and len(self.new_links) != 0 and self.args.il:\n",
    "                    self.il_for_data_ref()\n",
    "\n",
    "                self.train(_tqdm)\n",
    "                self.loss_log.update(self.curr_loss)\n",
    "                self.loss_item = self.loss_log.get_loss()\n",
    "                _tqdm.set_description(f'Train | Ep [{self.epoch}/{self.args.epoch}] Step [{self.step}/{self.args.total_steps}] LR [{self.lr:.5f}] Loss {self.loss_log.get_loss():.5f} ')\n",
    "                self.update_loss_log()\n",
    "                if (i + 1) % self.args.eval_epoch == 0:\n",
    "                    self.eval()\n",
    "                    if self.args.es:\n",
    "                        self.eval_es()\n",
    "                _tqdm.update(1)\n",
    "                if self.stage == 1 and self.early_stop_count <= 0:\n",
    "                    logger.info(f\"Early stop in epoch {self.epoch}\")\n",
    "                    break\n",
    "\n",
    "#         name = self._save_name_define()\n",
    "#         if self.best_model_wts is not None:\n",
    "#             self.logger.info(\"load from the best model before final testing ... \")\n",
    "#             self.model.load_state_dict(self.best_model_wts)\n",
    "#         self.test(save_name=f\"{name}_test_ep{self.args.epoch}\")\n",
    "\n",
    "#         if self.rank == 0:\n",
    "#             self.logger.info(f\"min loss {self.loss_log.get_min_loss()}\")\n",
    "#             if not self.args.only_test and self.args.save_model:\n",
    "#                 self._save_model(self.model, input_name=name)\n",
    "    \n",
    "    def eval_es(self):\n",
    "        \"\"\"Score the generative (``geea``) sub-model on its ``es_ill`` pairs and print PRE/RE/FID.\"\"\"\n",
    "        generator = self.model.geea\n",
    "        es_pairs = generator.kgs['es_ill']\n",
    "        rc_err, prc_err, fid_score = eval_genarative_model(es_pairs, generator, self.model)\n",
    "        report = 'ES RESULTS: PRE:%.6f;\\t RE:%.6f;\\t FID:%.6f' % (prc_err, rc_err, fid_score)\n",
    "        print(report)\n",
    "    \n",
    "    def il_for_ea(self):\n",
    "        with torch.no_grad():\n",
    "            if self.args.model_name in [\"MEAformer\"]:\n",
    "                final_emb, weight_norm = self.model.joint_emb_generat()\n",
    "            else:\n",
    "                final_emb = self.model.joint_emb_generat()\n",
    "            final_emb = F.normalize(final_emb)\n",
    "            self.new_links = self.model.Iter_new_links(self.epoch, self.non_train[\"left\"], final_emb, self.non_train[\"right\"], new_links=self.new_links)\n",
    "            if (self.epoch + 1) % (self.args.semi_learn_step * 5) == 0:\n",
    "                self.logger.info(f\"[epoch {self.epoch}] #links in candidate set: {len(self.new_links)}\")\n",
    "\n",
    "    def il_for_data_ref(self):\n",
    "        \"\"\"Refresh the train/non-train splits via ``model.data_refresh`` and rebuild the train loader.\n",
    "\n",
    "        ``data_refresh`` receives the current candidate links and returns the new left/right\n",
    "        non-train ids, the new training pairs and the new value for ``self.new_links``.\n",
    "        \"\"\"\n",
    "        refreshed = self.model.data_refresh(\n",
    "            self.logger, self.train_ill, self.test_ill_, self.non_train[\"left\"], self.non_train[\"right\"],\n",
    "            new_links=self.new_links)\n",
    "        self.non_train[\"left\"], self.non_train[\"right\"], self.train_ill, self.new_links = refreshed\n",
    "        # Re-seed before rebuilding the loader (presumably so its shuffling is reproducible).\n",
    "        set_seed(self.args.random_seed)\n",
    "        self.train_set = EADataset(self.train_ill)\n",
    "        self.dataloader_init(train_set=self.train_set)\n",
    "\n",
    "    def _save_name_define(self):\n",
    "        prefix = \"\"\n",
    "        if self.args.dist:\n",
    "            prefix = f\"dist_{prefix}\"\n",
    "        if self.args.il:\n",
    "            prefix = f\"il{self.args.epoch-self.args.il_start}_b{self.args.il_start}_{prefix}\"\n",
    "        name = f'{self.args.exp_id}_{prefix}'\n",
    "        return name\n",
    "\n",
    "    def train(self, _tqdm):\n",
    "        self.model.train()\n",
    "        curr_loss = 0.\n",
    "        self.loss_log.acc_init()\n",
    "        accumulation_steps = self.args.accumulation_steps\n",
    "        # torch.cuda.empty_cache()\n",
    "        for batch in self.train_dataloader:\n",
    "            loss, output = self.model(batch)\n",
    "            loss = loss / accumulation_steps\n",
    "            self.scaler.scale(loss).backward()\n",
    "            if self.args.dist:\n",
    "                loss = reduce_value(loss, average=True)\n",
    "            self.step += 1\n",
    "            if not self.args.dist or is_main_process():\n",
    "                curr_loss += loss.item()\n",
    "                self.output_statistic(loss, output)\n",
    "\n",
    "            if self.step % accumulation_steps == 0:\n",
    "                self.scaler.unscale_(self.optimizer)\n",
    "                for model in self.model_list:\n",
    "                    torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.clip)\n",
    "                scale = self.scaler.get_scale()\n",
    "                self.scaler.step(self.optimizer)\n",
    "                self.scaler.update()\n",
    "                skip_lr_sched = (scale > self.scaler.get_scale())\n",
    "                if not skip_lr_sched:\n",
    "                    self.scheduler.step()\n",
    "\n",
    "                if not self.args.dist or is_main_process():\n",
    "                    self.lr = self.scheduler.get_last_lr()[-1]\n",
    "                    self.writer.add_scalars(\"lr\", {\"lr\": self.lr}, self.step)\n",
    "                for model in self.model_list:\n",
    "                    model.zero_grad(set_to_none=True)\n",
    "\n",
    "            if self.args.dist:\n",
    "                torch.cuda.synchronize(self.args.device)\n",
    "\n",
    "        return curr_loss\n",
    "\n",
    "    def output_statistic(self, loss, output):\n",
    "        self.curr_loss += loss.item()\n",
    "        if output is None:\n",
    "            return\n",
    "        for key in output['loss_dic'].keys():\n",
    "            self.curr_loss_dic[key] += output['loss_dic'][key]\n",
    "        if 'weight' in output and output['weight'] is not None:\n",
    "            self.weight = output['weight']\n",
    "        if 'loss_weight' in output and output['loss_weight'] is not None:\n",
    "            self.loss_weight = output['loss_weight']\n",
    "\n",
    "    def update_loss_log(self):\n",
    "        vis_dict = {\"train_loss\": self.curr_loss}\n",
    "        vis_dict.update(self.curr_loss_dic)\n",
    "        self.writer.add_scalars(\"loss\", vis_dict, self.step)\n",
    "\n",
    "        if self.weight is not None:\n",
    "            weight_dic = {}\n",
    "            weight_dic[\"img\"] = self.weight[0]\n",
    "            weight_dic[\"attr\"] = self.weight[1]\n",
    "            weight_dic[\"rel\"] = self.weight[2]\n",
    "            weight_dic[\"graph\"] = self.weight[3]\n",
    "            if self.args.w_name or self.args.w_char:\n",
    "                weight_dic[\"name\"] = self.weight[4]\n",
    "                weight_dic[\"char\"] = self.weight[5]\n",
    "            self.writer.add_scalars(\"modal_weight\", weight_dic, self.step)\n",
    "\n",
    "        if self.loss_weight is not None and self.loss_weight != [1, 1]:\n",
    "            weight_dic = {}\n",
    "            weight_dic[\"mask\"] = 1 / (self.loss_weight[0]**2)\n",
    "            weight_dic[\"kpi\"] = 1 / (self.loss_weight[1]**2)\n",
    "            self.writer.add_scalars(\"loss_weight\", weight_dic, self.step)\n",
    "\n",
    "        self.curr_loss = 0.\n",
    "        for key in self.curr_loss_dic:\n",
    "            self.curr_loss_dic[key] = 0.\n",
    "\n",
    "    def eval(self, last_epoch=False, save_name=\"\"):\n",
    "        test_left = self.eval_left\n",
    "        test_right = self.eval_right\n",
    "        self.model.eval()\n",
    "        self._test(test_left, test_right, last_epoch=last_epoch, save_name=save_name)\n",
    "\n",
    "    # one time test\n",
    "    def test(self, save_name=\"\", last_epoch=True):\n",
    "        if self.test_set is None:\n",
    "            test_left = self.eval_left\n",
    "            test_right = self.eval_right\n",
    "        else:\n",
    "            test_left = self.test_left\n",
    "            test_right = self.test_right\n",
    "        self.model.eval()\n",
    "        self.logger.info(\" --------------------- Test result --------------------- \")\n",
    "        self._test(test_left, test_right, last_epoch=last_epoch, save_name=save_name)\n",
    "\n",
    "    def _test(self, test_left, test_right, last_epoch=False, save_name=\"\", loss=None):\n",
    "        \"\"\"Rank-based alignment evaluation of ``test_left`` against ``test_right``.\n",
    "\n",
    "        Builds a distance matrix between the normalized joint embeddings of the two sides. For\n",
    "        each row (l2r) and each column (r2l), it finds the rank of the diagonal entry, so the\n",
    "        gold match of position i is assumed to be position i on the other side. It reports\n",
    "        acc@{1,10,50}, mean rank and MRR. On rank 0 it also decrements the early-stop counter and\n",
    "        appends the l2r MRR via ``save_csv_results``. When not in test-only mode and\n",
    "        ``last_epoch`` is False, a new best l2r MRR snapshots the model weights.\n",
    "\n",
    "        Args:\n",
    "            test_left, test_right: index tensors into the joint embedding matrix.\n",
    "            last_epoch: when True, also write the top-3 retrievals per l2r query (and MEAformer's\n",
    "                ``weight_norm``) under ``<data_path>/<model_name>/<save_name>_pred``.\n",
    "            save_name: directory stem for those files (defaults to ``args.model_name``).\n",
    "            loss: unused.\n",
    "        \"\"\"\n",
    "        with torch.no_grad():\n",
    "            # NOTE(review): w_normalized is never reassigned, so the JSON dump below is dead code.\n",
    "            w_normalized = None\n",
    "            if self.args.model_name in [\"MEAformer\"]:\n",
    "                final_emb, weight_norm = self.model.joint_emb_generat()\n",
    "            else:\n",
    "                final_emb = self.model.joint_emb_generat()\n",
    "                weight_norm = None\n",
    "            final_emb = F.normalize(final_emb)\n",
    "\n",
    "        # Accumulators for acc@k, mean rank and MRR in both directions.\n",
    "        top_k = [1, 10, 50]\n",
    "        acc_l2r = np.zeros((len(top_k)), dtype=np.float32)\n",
    "        acc_r2l = np.zeros((len(top_k)), dtype=np.float32)\n",
    "        test_total, test_loss, mean_l2r, mean_r2l, mrr_l2r, mrr_r2l = 0, 0., 0., 0., 0., 0.\n",
    "        # NOTE(review): any other args.distance value leaves `distance` undefined (NameError below).\n",
    "        if self.args.distance == 2:\n",
    "            distance = pairwise_distances(final_emb[test_left], final_emb[test_right])\n",
    "        elif self.args.distance == 1:\n",
    "            # L1 (cityblock) distance computed on CPU via SciPy.\n",
    "            distance = torch.FloatTensor(scipy.spatial.distance.cdist(\n",
    "                final_emb[test_left].cpu().data.numpy(),\n",
    "                final_emb[test_right].cpu().data.numpy(), metric=\"cityblock\"))\n",
    "        if self.args.csls is True:\n",
    "            # Re-score the similarity (1 - distance) with csls_sim, then convert back to a distance.\n",
    "            distance = 1 - csls_sim(1 - distance, self.args.csls_k)\n",
    "\n",
    "        if last_epoch:\n",
    "            to_write = []\n",
    "            test_left_np = test_left.cpu().numpy()\n",
    "            test_right_np = test_right.cpu().numpy()\n",
    "            to_write.append([\"idx\", \"rank\", \"query_id\", \"gt_id\", \"ret1\", \"ret2\", \"ret3\", \"v1\", \"v2\", \"v3\"])\n",
    "        # Left-to-right: sort each row by ascending distance; the gold match is column idx.\n",
    "        for idx in range(test_left.shape[0]):\n",
    "            values, indices = torch.sort(distance[idx, :], descending=False)\n",
    "            rank = (indices == idx).nonzero(as_tuple=False).squeeze().item()\n",
    "            mean_l2r += (rank + 1)\n",
    "            mrr_l2r += 1.0 / (rank + 1)\n",
    "            for i in range(len(top_k)):\n",
    "                if rank < top_k[i]:\n",
    "                    acc_l2r[i] += 1\n",
    "            if last_epoch:\n",
    "                # Record the top-3 retrieved right ids and their distances (needs >= 3 right entities).\n",
    "                indices = indices.cpu().numpy()\n",
    "                to_write.append([idx, rank, test_left_np[idx], test_right_np[idx], test_right_np[indices[0]], test_right_np[indices[1]],\n",
    "                                 test_right_np[indices[2]], round(values[0].item(), 4), round(values[1].item(), 4), round(values[2].item(), 4)])\n",
    "        if last_epoch:\n",
    "            import csv\n",
    "            if save_name == \"\":\n",
    "                save_name = self.args.model_name\n",
    "            save_pred_path = osp.join(self.args.data_path, self.args.model_name, f\"{save_name}_pred\")\n",
    "            os.makedirs(save_pred_path, exist_ok=True)\n",
    "            with open(osp.join(save_pred_path, f\"{self.args.model_name}_{self.args.data_choice}_{self.args.data_split}_{self.args.data_rate}_ep{self.args.il_start}_pred.txt\"), \"w\") as f:\n",
    "                wr = csv.writer(f, dialect='excel')\n",
    "                wr.writerows(to_write)\n",
    "            if w_normalized is not None:\n",
    "                with open(osp.join(save_pred_path, f\"{self.args.model_name}_{self.args.data_choice}_{self.args.data_split}_{self.args.data_rate}_ep{self.args.il_start}_wight.json\"), \"w\") as fp:\n",
    "                    json.dump(w_normalized.cpu().tolist(), fp)\n",
    "            if weight_norm is not None:\n",
    "                # Only MEAformer provides weight_norm; dump it for all / left / right ids.\n",
    "                wight_dic = {\"all\": weight_norm.cpu(), \"left\": weight_norm[test_left].cpu(), \"right\": weight_norm[test_right].cpu()}\n",
    "                with open(osp.join(save_pred_path, f\"{self.args.model_name}_{self.args.data_choice}_{self.args.data_split}_{self.args.data_rate}_ep{self.args.il_start}_wight_dic.pkl\"), \"wb\") as fp:\n",
    "                    pickle.dump(wight_dic, fp)\n",
    "\n",
    "        # Right-to-left: same ranking over each column.\n",
    "        for idx in range(test_right.shape[0]):\n",
    "            _, indices = torch.sort(distance[:, idx], descending=False)\n",
    "            rank = (indices == idx).nonzero(as_tuple=False).squeeze().item()\n",
    "            mean_r2l += (rank + 1)\n",
    "            mrr_r2l += 1.0 / (rank + 1)\n",
    "            for i in range(len(top_k)):\n",
    "                if rank < top_k[i]:\n",
    "                    acc_r2l[i] += 1\n",
    "        mean_l2r /= test_left.size(0)\n",
    "        mean_r2l /= test_right.size(0)\n",
    "        mrr_l2r /= test_left.size(0)\n",
    "        mrr_r2l /= test_right.size(0)\n",
    "        for i in range(len(top_k)):\n",
    "            acc_l2r[i] = round(acc_l2r[i] / test_left.size(0), 4)\n",
    "            acc_r2l[i] = round(acc_r2l[i] / test_right.size(0), 4)\n",
    "        gc.collect()\n",
    "        if not self.args.only_test:\n",
    "            Loss_out = f\", Loss = {self.loss_item:.4f}\"\n",
    "        else:\n",
    "            # Test-only mode: label the log line \"Test\" and set the counter to 1 before the rank-0 decrement.\n",
    "            Loss_out = \"\"\n",
    "            self.epoch = \"Test\"\n",
    "            self.early_stop_count = 1\n",
    "\n",
    "        if self.rank == 0:\n",
    "            self.logger.info(f\"Ep {self.epoch} | l2r: acc of top {top_k} = {acc_l2r}, mr = {mean_l2r:.3f}, mrr = {mrr_l2r:.3f}{Loss_out}\")\n",
    "            self.logger.info(f\"Ep {self.epoch} | r2l: acc of top {top_k} = {acc_r2l}, mr = {mean_r2l:.3f}, mrr = {mrr_r2l:.3f}{Loss_out}\")\n",
    "            # Only rank 0 decrements the early-stop counter.\n",
    "            self.early_stop_count -= 1\n",
    "            \n",
    "            self.save_csv_results(mrr_l2r)\n",
    "            \n",
    "        # Best-model tracking uses l2r MRR; NOTE(review): not rank-gated, but uses self.logger.\n",
    "        if not self.args.only_test and mrr_l2r > max(self.loss_log.acc) and not last_epoch:\n",
    "            self.logger.info(f\"Best model update in Ep {self.epoch}: MRR from [{max(self.loss_log.acc)}] --> [{mrr_l2r}] ... \")\n",
    "            self.loss_log.update_acc(mrr_l2r)\n",
    "            self.early_stop_count = self.early_stop_init\n",
    "            self.best_model_wts = copy.deepcopy(self.model.state_dict())\n",
    "            \n",
    "    def save_csv_results(self, mrr):\n",
    "        self.csv_results.append(mrr)\n",
    "        pd.Series(self.csv_results).to_csv('csvs/saved.csv')\n",
    "        \n",
    "    def _load_model(self, model, model_name=None):\n",
    "        if model_name is None:\n",
    "            model_name = self.args.model_name_save\n",
    "        save_path = osp.join(self.args.data_path, self.args.model_name, 'save')\n",
    "        save_path = osp.join(save_path, f'{model_name}.pkl')\n",
    "        if len(model_name) > 0 and not os.path.exists(save_path):\n",
    "            print(f\"not exists {model_name} !! \")\n",
    "            pdb.set_trace()\n",
    "        if (len(model_name) == 0 or not os.path.exists(save_path)) and self.rank == 0:\n",
    "            if len(model_name) > 0:\n",
    "                self.logger.info(f\"{model_name}.pkl not exist!!\")\n",
    "            else:\n",
    "                self.logger.info(\"Random init...\")\n",
    "            model.cuda()\n",
    "            return model\n",
    "        if 'Dist' in self.args.model_name:\n",
    "            model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(save_path, map_location=self.args.device).items()})\n",
    "        else:\n",
    "            model.load_state_dict(torch.load(save_path, map_location=self.args.device))\n",
    "\n",
    "        model.cuda()\n",
    "        if self.rank == 0:\n",
    "            self.logger.info(f\"loading model [{model_name}.pkl] done!\")\n",
    "\n",
    "        return model\n",
    "\n",
    "    def _save_model(self, model, input_name=\"\"):\n",
    "\n",
    "        model_name = self.args.model_name\n",
    "\n",
    "        save_path = osp.join(self.args.data_path, model_name, 'save')\n",
    "        os.makedirs(save_path, exist_ok=True)\n",
    "\n",
    "        if input_name == \"\":\n",
    "            input_name = self._save_name_define()\n",
    "        save_path = osp.join(save_path, f'{input_name}.pkl')\n",
    "\n",
    "        if model is None:\n",
    "            return\n",
    "        if self.args.save_model:\n",
    "            torch.save(model.state_dict(), save_path)\n",
    "\n",
    "            self.logger.info(f\"saving [{save_path}] done!\")\n",
    "\n",
    "        return save_path\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1559a13a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO - 05/23/23 09:22:33 - 0:00:00 - ============ Initialized logger ============\n",
      "INFO - 05/23/23 09:22:33 - 0:00:00 - The experiment will be stored in D:\\Projects\\diffusion_model\\data\\mmkg\\dump/0523-EA_exp\\MCLEA_DBP15K_zh_en_001\n",
      "                                     \n",
      "INFO - 05/23/23 09:22:33 - 0:00:00 - Running command: python \n",
      "\n"
     ]
    }
   ],
   "source": [
    "set_seed(cfgs.random_seed)\n",
    "# -----  Init ----------\n",
    "# Distributed mode is initialised only for real training runs; otherwise use the 'file_system'\n",
    "# tensor-sharing strategy (presumably to avoid file-descriptor exhaustion in DataLoader workers).\n",
    "if cfgs.dist and not cfgs.only_test:\n",
    "    init_distributed_mode(args=cfgs)\n",
    "else:\n",
    "    torch.multiprocessing.set_sharing_strategy('file_system')\n",
    "rank = cfgs.rank\n",
    "\n",
    "# Only rank 0 gets a logger and (unless disabled) a TensorBoard writer; other ranks keep None.\n",
    "writer, logger = None, None\n",
    "if rank == 0:\n",
    "    logger = initialize_exp(cfgs)\n",
    "    logger_path = get_dump_path(cfgs)\n",
    "    cfgs.time_stamp = \"{0:%Y-%m-%dT%H-%M-%S/}\".format(datetime.now())\n",
    "    # NOTE: SummaryWriter ignores `comment` when `log_dir` is given, so this string does not affect the log path.\n",
    "    comment = f'batch_size={cfgs.batch_size} exp_id={cfgs.exp_id}'\n",
    "    if not cfgs.no_tensorboard and not cfgs.only_test:\n",
    "        writer = SummaryWriter(log_dir=os.path.join(logger_path, 'tensorboard', cfgs.time_stamp), comment=comment)\n",
    "\n",
    "cfgs.device = torch.device(cfgs.device)\n",
    "\n",
    "# -----  Begin ----------\n",
    "torch.cuda.set_device(cfgs.gpu)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25c59f04",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading raw data...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO - 05/23/23 09:22:34 - 0:00:01 - 77.09% entities have images\n",
      "INFO - 05/23/23 09:22:34 - 0:00:01 - image feature shape:(38960, 2048)\n",
      "INFO - 05/23/23 09:22:34 - 0:00:01 - #left entity : 19388, #right entity: 19572\n",
      "INFO - 05/23/23 09:22:34 - 0:00:01 - #left entity not in train set: 7388, #right entity not in train set: 7572\n",
      "INFO - 05/23/23 09:22:34 - 0:00:01 - relation feature shape:(38960, 1000)\n",
      "INFO - 05/23/23 09:22:34 - 0:00:02 - attribute feature shape:(38960, 1000)\n",
      "INFO - 05/23/23 09:22:34 - 0:00:02 - -----dataset summary-----\n",
      "INFO - 05/23/23 09:22:34 - 0:00:02 - dataset:\t\t D:\\Projects\\diffusion_model\\data\\mmkg\\DBP15K\\zh_en\n",
      "INFO - 05/23/23 09:22:34 - 0:00:02 - triple num:\t 165556\n",
      "INFO - 05/23/23 09:22:34 - 0:00:02 - entity num:\t 38960\n",
      "INFO - 05/23/23 09:22:34 - 0:00:02 - relation num:\t 3024\n",
      "INFO - 05/23/23 09:22:34 - 0:00:02 - train ill num:\t 12000 \t test ill num:\t 3000\n",
      "INFO - 05/23/23 09:22:34 - 0:00:02 - -------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "getting a sparse tensor r_adj...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO - 05/23/23 09:22:37 - 0:00:04 - Random init...\n",
      "INFO - 05/23/23 09:22:37 - 0:00:04 - total params num: 60180126\n",
      "INFO - 05/23/23 09:22:37 - 0:00:04 - warmup_steps: 120\n",
      "INFO - 05/23/23 09:22:37 - 0:00:04 - total_steps: 800\n",
      "INFO - 05/23/23 09:22:37 - 0:00:04 - weight_decay: 0.0001\n",
      "  0%|                                                                                                                                                                                                                                                               | 0/200 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Neighbor-IMG-REL-ATTR:42.943-31.147-28.284-28.292\n",
      "DistMatch Loss: 0.346; PriorRec Loss: 130.665; PostRec Loss: 0.000; Rec Loss: 0.229\n",
      "Neighbor-IMG-REL-ATTR:42.880-31.156-28.285-28.292\n",
      "DistMatch Loss: 0.354; PriorRec Loss: 130.613; PostRec Loss: 0.000; Rec Loss: 0.228\n",
      "Neighbor-IMG-REL-ATTR:42.850-31.134-28.257-28.285\n",
      "DistMatch Loss: 0.367; PriorRec Loss: 130.526; PostRec Loss: 0.000; Rec Loss: 0.266\n",
      "Neighbor-IMG-REL-ATTR:42.840-31.160-28.224-28.250\n",
      "DistMatch Loss: 0.379; PriorRec Loss: 130.475; PostRec Loss: 0.000; Rec Loss: 0.251\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Train | Ep [0/200] Step [5/800] LR [0.00100] Loss 157121.79297 :   0%|▉                                                                                                                                                                                     | 1/200 [00:11<37:41, 11.36s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Neighbor-IMG-REL-ATTR:42.796-31.141-28.195-28.203\n",
      "DistMatch Loss: 0.397; PriorRec Loss: 130.335; PostRec Loss: 0.000; Rec Loss: 0.308\n",
      "Neighbor-IMG-REL-ATTR:42.775-31.174-28.201-28.168\n",
      "DistMatch Loss: 0.423; PriorRec Loss: 130.317; PostRec Loss: 0.000; Rec Loss: 0.273\n",
      "Neighbor-IMG-REL-ATTR:42.817-31.129-28.040-28.086\n",
      "DistMatch Loss: 0.502; PriorRec Loss: 130.072; PostRec Loss: 0.000; Rec Loss: 0.315\n",
      "Neighbor-IMG-REL-ATTR:42.789-31.096-27.781-27.835\n",
      "DistMatch Loss: 0.703; PriorRec Loss: 129.501; PostRec Loss: 0.000; Rec Loss: 0.300\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Train | Ep [1/200] Step [9/800] LR [0.00100] Loss 156536.32812 :   0%|▉                                                                                                                                                                                     | 1/200 [00:15<37:41, 11.36s/it]INFO - 05/23/23 09:22:54 - 0:00:21 - Ep 1 | l2r: acc of top [1, 10, 50] = [0.816  0.9807 0.997 ], mr = 1.985, mrr = 0.877, Loss = 156536.3281\n",
      "INFO - 05/23/23 09:22:54 - 0:00:21 - Ep 1 | r2l: acc of top [1, 10, 50] = [0.8057 0.979  0.998 ], mr = 2.018, mrr = 0.872, Loss = 156536.3281\n",
      "INFO - 05/23/23 09:22:54 - 0:00:21 - Best model update in Ep 1: MRR from [0.0] --> [0.8770374597840452] ... \n",
      "Train | Ep [1/200] Step [9/800] LR [0.00100] Loss 156536.32812 :   1%|█▊                                                                                                                                                                                    | 2/200 [00:16<25:51,  7.84s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Neighbor-IMG-REL-ATTR:42.732-31.127-27.287-27.583\n",
      "DistMatch Loss: 1.029; PriorRec Loss: 128.729; PostRec Loss: 0.000; Rec Loss: 0.296\n",
      "Neighbor-IMG-REL-ATTR:42.762-31.118-26.493-27.199\n",
      "DistMatch Loss: 1.721; PriorRec Loss: 127.572; PostRec Loss: 0.000; Rec Loss: 0.310\n",
      "Neighbor-IMG-REL-ATTR:42.753-31.162-25.683-26.804\n",
      "DistMatch Loss: 2.497; PriorRec Loss: 126.402; PostRec Loss: 0.000; Rec Loss: 0.398\n",
      "Neighbor-IMG-REL-ATTR:42.732-31.143-25.015-26.314\n",
      "DistMatch Loss: 3.502; PriorRec Loss: 125.204; PostRec Loss: 0.000; Rec Loss: 0.358\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Train | Ep [2/200] Step [13/800] LR [0.00100] Loss 152884.48438 :   2%|██▋                                                                                                                                                                                | 3/200 [01:27<1:59:43, 36.46s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Neighbor-IMG-REL-ATTR:42.725-31.143-24.310-25.665\n",
      "DistMatch Loss: 4.846; PriorRec Loss: 123.842; PostRec Loss: 0.000; Rec Loss: 0.399\n",
      "Neighbor-IMG-REL-ATTR:42.708-31.125-23.679-25.004\n",
      "DistMatch Loss: 6.870; PriorRec Loss: 122.516; PostRec Loss: 0.000; Rec Loss: 0.368\n",
      "Neighbor-IMG-REL-ATTR:42.714-31.119-23.393-24.398\n",
      "DistMatch Loss: 14.663; PriorRec Loss: 121.624; PostRec Loss: 0.000; Rec Loss: 0.492\n",
      "Neighbor-IMG-REL-ATTR:42.710-31.106-23.055-24.001\n",
      "DistMatch Loss: 12.138; PriorRec Loss: 120.873; PostRec Loss: 0.000; Rec Loss: 0.423\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Train | Ep [3/200] Step [17/800] LR [0.00100] Loss 147277.16406 :   2%|██▋                                                                                                                                                                                | 3/200 [01:32<1:59:43, 36.46s/it]INFO - 05/23/23 09:24:10 - 0:01:37 - Ep 3 | l2r: acc of top [1, 10, 50] = [0.8203 0.985  0.997 ], mr = 1.867, mrr = 0.883, Loss = 147277.1641\n",
      "INFO - 05/23/23 09:24:10 - 0:01:37 - Ep 3 | r2l: acc of top [1, 10, 50] = [0.813  0.9857 0.9967], mr = 1.865, mrr = 0.879, Loss = 147277.1641\n",
      "INFO - 05/23/23 09:24:10 - 0:01:37 - Best model update in Ep 3: MRR from [0.8770374597840452] --> [0.8825224899670588] ... \n",
      "Train | Ep [3/200] Step [17/800] LR [0.00100] Loss 147277.16406 :   2%|███▌                                                                                                                                                                               | 4/200 [01:32<1:19:25, 24.31s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Neighbor-IMG-REL-ATTR:42.711-31.140-22.528-23.486\n",
      "DistMatch Loss: 13.074; PriorRec Loss: 119.865; PostRec Loss: 0.000; Rec Loss: 0.505\n",
      "Neighbor-IMG-REL-ATTR:42.717-31.125-22.392-23.140\n",
      "DistMatch Loss: 14.357; PriorRec Loss: 119.375; PostRec Loss: 0.000; Rec Loss: 0.528\n",
      "Neighbor-IMG-REL-ATTR:42.706-31.132-21.971-22.786\n",
      "DistMatch Loss: 14.907; PriorRec Loss: 118.595; PostRec Loss: 0.000; Rec Loss: 0.476\n",
      "Neighbor-IMG-REL-ATTR:42.691-31.095-21.810-22.490\n",
      "DistMatch Loss: 15.364; PriorRec Loss: 118.086; PostRec Loss: 0.000; Rec Loss: 0.480\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Train | Ep [4/200] Step [21/800] LR [0.00100] Loss 143495.13281 :   2%|████▌                                                                                                                                                                                | 5/200 [01:38<57:06, 17.57s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Neighbor-IMG-REL-ATTR:42.692-31.128-21.554-22.322\n",
      "DistMatch Loss: 15.646; PriorRec Loss: 117.696; PostRec Loss: 0.000; Rec Loss: 0.537\n",
      "Neighbor-IMG-REL-ATTR:42.681-31.092-21.435-22.039\n",
      "DistMatch Loss: 17.523; PriorRec Loss: 117.247; PostRec Loss: 0.000; Rec Loss: 0.515\n",
      "Neighbor-IMG-REL-ATTR:42.676-31.100-21.053-21.875\n",
      "DistMatch Loss: 17.839; PriorRec Loss: 116.704; PostRec Loss: 0.000; Rec Loss: 0.475\n",
      "Neighbor-IMG-REL-ATTR:42.677-31.124-21.182-21.628\n",
      "DistMatch Loss: 16.913; PriorRec Loss: 116.610; PostRec Loss: 0.000; Rec Loss: 0.483\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Train | Ep [5/200] Step [25/800] LR [0.00100] Loss 141204.47266 :   2%|████▌                                                                                                                                                                                | 5/200 [01:45<57:06, 17.57s/it]INFO - 05/23/23 09:24:23 - 0:01:51 - Ep 5 | l2r: acc of top [1, 10, 50] = [0.843  0.986  0.9957], mr = 1.926, mrr = 0.899, Loss = 141204.4727\n",
      "INFO - 05/23/23 09:24:23 - 0:01:51 - Ep 5 | r2l: acc of top [1, 10, 50] = [0.842  0.9857 0.9957], mr = 1.951, mrr = 0.899, Loss = 141204.4727\n",
      "INFO - 05/23/23 09:24:23 - 0:01:51 - Best model update in Ep 5: MRR from [0.8825224899670588] --> [0.8993888104830363] ... \n",
      "Train | Ep [5/200] Step [25/800] LR [0.00100] Loss 141204.47266 :   3%|█████▍                                                                                                                                                                               | 6/200 [01:46<46:18, 14.32s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Neighbor-IMG-REL-ATTR:42.667-31.069-20.815-21.519\n",
      "DistMatch Loss: 15.983; PriorRec Loss: 116.070; PostRec Loss: 0.000; Rec Loss: 0.437\n",
      "Neighbor-IMG-REL-ATTR:42.662-31.106-20.695-21.559\n",
      "DistMatch Loss: 16.901; PriorRec Loss: 116.022; PostRec Loss: 0.000; Rec Loss: 0.416\n",
      "Neighbor-IMG-REL-ATTR:42.646-31.083-20.560-21.337\n",
      "DistMatch Loss: 17.962; PriorRec Loss: 115.625; PostRec Loss: 0.000; Rec Loss: 0.483\n",
      "Neighbor-IMG-REL-ATTR:42.661-31.082-20.736-21.406\n",
      "DistMatch Loss: 18.096; PriorRec Loss: 115.885; PostRec Loss: 0.000; Rec Loss: 0.416\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Train | Ep [6/200] Step [29/800] LR [0.00100] Loss 139729.14844 :   4%|██████▎                                                                                                                                                                              | 7/200 [01:54<39:00, 12.12s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Neighbor-IMG-REL-ATTR:42.644-31.083-20.431-21.000\n",
      "DistMatch Loss: 18.499; PriorRec Loss: 115.158; PostRec Loss: 0.000; Rec Loss: 0.413\n",
      "Neighbor-IMG-REL-ATTR:42.674-31.078-20.226-20.959\n",
      "DistMatch Loss: 19.364; PriorRec Loss: 114.936; PostRec Loss: 0.000; Rec Loss: 0.441\n",
      "Neighbor-IMG-REL-ATTR:42.633-31.042-20.150-20.785\n",
      "DistMatch Loss: 19.191; PriorRec Loss: 114.610; PostRec Loss: 0.000; Rec Loss: 0.412\n",
      "Neighbor-IMG-REL-ATTR:42.639-31.107-20.022-20.889\n",
      "DistMatch Loss: 19.803; PriorRec Loss: 114.658; PostRec Loss: 0.000; Rec Loss: 0.375\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Train | Ep [7/200] Step [33/800] LR [0.00100] Loss 138425.87109 :   4%|██████▎                                                                                                                                                                              | 7/200 [01:59<39:00, 12.12s/it]INFO - 05/23/23 09:24:38 - 0:02:06 - Ep 7 | l2r: acc of top [1, 10, 50] = [0.8553 0.984  0.997 ], mr = 1.834, mrr = 0.906, Loss = 138425.8711\n",
      "INFO - 05/23/23 09:24:38 - 0:02:06 - Ep 7 | r2l: acc of top [1, 10, 50] = [0.8587 0.985  0.9973], mr = 1.843, mrr = 0.909, Loss = 138425.8711\n",
      "INFO - 05/23/23 09:24:38 - 0:02:06 - Best model update in Ep 7: MRR from [0.8993888104830363] --> [0.9064826748008512] ... \n",
      "Train | Ep [7/200] Step [33/800] LR [0.00100] Loss 138425.87109 :   4%|███████▏                                                                                                                                                                             | 8/200 [02:01<33:41, 10.53s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Neighbor-IMG-REL-ATTR:42.652-31.068-19.924-20.691\n",
      "DistMatch Loss: 28.860; PriorRec Loss: 114.335; PostRec Loss: 0.000; Rec Loss: 0.376\n",
      "Neighbor-IMG-REL-ATTR:42.613-31.087-19.784-20.588\n",
      "DistMatch Loss: 22.517; PriorRec Loss: 114.073; PostRec Loss: 0.000; Rec Loss: 0.397\n",
      "Neighbor-IMG-REL-ATTR:42.613-31.063-19.840-20.490\n",
      "DistMatch Loss: 143.898; PriorRec Loss: 114.006; PostRec Loss: 0.000; Rec Loss: 0.498\n",
      "Neighbor-IMG-REL-ATTR:42.598-31.128-19.753-20.517\n",
      "DistMatch Loss: 21.700; PriorRec Loss: 113.996; PostRec Loss: 0.000; Rec Loss: 0.419\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Train | Ep [8/200] Step [37/800] LR [0.00100] Loss 137624.18750 :   4%|████████▏                                                                                                                                                                            | 9/200 [02:07<29:01,  9.12s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Neighbor-IMG-REL-ATTR:42.589-31.053-19.576-20.436\n",
      "DistMatch Loss: 21.343; PriorRec Loss: 113.653; PostRec Loss: 0.000; Rec Loss: 0.371\n",
      "Neighbor-IMG-REL-ATTR:42.605-31.080-19.606-20.381\n",
      "DistMatch Loss: 25.367; PriorRec Loss: 113.672; PostRec Loss: 0.000; Rec Loss: 0.431\n",
      "Neighbor-IMG-REL-ATTR:42.621-31.028-19.359-20.145\n",
      "DistMatch Loss: 22.569; PriorRec Loss: 113.153; PostRec Loss: 0.000; Rec Loss: 0.426\n",
      "Neighbor-IMG-REL-ATTR:42.588-31.072-19.142-20.006\n",
      "DistMatch Loss: 26.200; PriorRec Loss: 112.808; PostRec Loss: 0.000; Rec Loss: 0.425\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Train | Ep [9/200] Step [41/800] LR [0.00100] Loss 136615.28906 :   4%|████████▏                                                                                                                                                                            | 9/200 [02:13<29:01,  9.12s/it]INFO - 05/23/23 09:24:52 - 0:02:19 - Ep 9 | l2r: acc of top [1, 10, 50] = [0.8597 0.9843 0.9977], mr = 1.769, mrr = 0.911, Loss = 136615.2891\n",
      "INFO - 05/23/23 09:24:52 - 0:02:19 - Ep 9 | r2l: acc of top [1, 10, 50] = [0.8583 0.986  0.9973], mr = 1.870, mrr = 0.910, Loss = 136615.2891\n",
      "INFO - 05/23/23 09:24:52 - 0:02:19 - Best model update in Ep 9: MRR from [0.9064826748008512] --> [0.9110493600023121] ... \n",
      "Train | Ep [9/200] Step [41/800] LR [0.00100] Loss 136615.28906 :   5%|█████████                                                                                                                                                                           | 10/200 [02:14<27:10,  8.58s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Neighbor-IMG-REL-ATTR:42.587-31.089-19.060-20.004\n",
      "DistMatch Loss: 23.573; PriorRec Loss: 112.740; PostRec Loss: 0.000; Rec Loss: 0.422\n",
      "Neighbor-IMG-REL-ATTR:42.548-31.057-18.893-20.076\n",
      "DistMatch Loss: 24.028; PriorRec Loss: 112.574; PostRec Loss: 0.000; Rec Loss: 0.396\n",
      "Neighbor-IMG-REL-ATTR:42.544-31.047-18.883-19.876\n",
      "DistMatch Loss: 27.827; PriorRec Loss: 112.349; PostRec Loss: 0.000; Rec Loss: 0.437\n",
      "Neighbor-IMG-REL-ATTR:42.580-31.081-18.914-20.008\n",
      "DistMatch Loss: 25.828; PriorRec Loss: 112.583; PostRec Loss: 0.000; Rec Loss: 0.431\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Train | Ep [10/200] Step [45/800] LR [0.00100] Loss 135714.08594 :   6%|█████████▊                                                                                                                                                                         | 11/200 [02:20<24:18,  7.72s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Neighbor-IMG-REL-ATTR:42.561-31.067-18.813-19.847\n",
      "DistMatch Loss: 27.591; PriorRec Loss: 112.289; PostRec Loss: 0.000; Rec Loss: 0.444\n",
      "Neighbor-IMG-REL-ATTR:42.544-31.077-18.668-19.694\n",
      "DistMatch Loss: 26.029; PriorRec Loss: 111.984; PostRec Loss: 0.000; Rec Loss: 0.390\n",
      "Neighbor-IMG-REL-ATTR:42.554-31.034-18.360-19.548\n",
      "DistMatch Loss: 25.950; PriorRec Loss: 111.496; PostRec Loss: 0.000; Rec Loss: 0.389\n",
      "Neighbor-IMG-REL-ATTR:42.540-31.052-18.433-19.590\n",
      "DistMatch Loss: 28.120; PriorRec Loss: 111.615; PostRec Loss: 0.000; Rec Loss: 0.451\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Train | Ep [11/200] Step [49/800] LR [0.00100] Loss 134853.80469 :   6%|█████████▊                                                                                                                                                                         | 11/200 [02:26<24:18,  7.72s/it]INFO - 05/23/23 09:25:05 - 0:02:32 - Ep 11 | l2r: acc of top [1, 10, 50] = [0.857  0.9847 0.9973], mr = 1.875, mrr = 0.909, Loss = 134853.8047\n",
      "INFO - 05/23/23 09:25:05 - 0:02:32 - Ep 11 | r2l: acc of top [1, 10, 50] = [0.861 0.985 0.997], mr = 1.923, mrr = 0.911, Loss = 134853.8047\n",
      "Train | Ep [11/200] Step [49/800] LR [0.00100] Loss 134853.80469 :   6%|██████████▋                                                                                                                                                                        | 12/200 [02:28<24:14,  7.74s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Neighbor-IMG-REL-ATTR:42.505-31.008-18.253-19.505\n",
      "DistMatch Loss: 29.922; PriorRec Loss: 111.271; PostRec Loss: 0.000; Rec Loss: 0.474\n",
      "Neighbor-IMG-REL-ATTR:42.504-31.052-18.109-19.369\n",
      "DistMatch Loss: 28.299; PriorRec Loss: 111.035; PostRec Loss: 0.000; Rec Loss: 0.390\n",
      "Neighbor-IMG-REL-ATTR:42.515-31.029-18.009-19.270\n",
      "DistMatch Loss: 29.353; PriorRec Loss: 110.822; PostRec Loss: 0.000; Rec Loss: 0.424\n",
      "Neighbor-IMG-REL-ATTR:42.479-31.029-18.068-19.478\n",
      "DistMatch Loss: 31.668; PriorRec Loss: 111.054; PostRec Loss: 0.000; Rec Loss: 0.482\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Train | Ep [12/200] Step [53/800] LR [0.00100] Loss 133926.77734 :   6%|███████████▋                                                                                                                                                                       | 13/200 [02:33<21:56,  7.04s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Neighbor-IMG-REL-ATTR:42.479-31.023-18.053-19.199\n",
      "DistMatch Loss: 31.383; PriorRec Loss: 110.755; PostRec Loss: 0.000; Rec Loss: 0.425\n",
      "Neighbor-IMG-REL-ATTR:42.471-31.034-17.746-19.144\n",
      "DistMatch Loss: 29.194; PriorRec Loss: 110.395; PostRec Loss: 0.000; Rec Loss: 0.419\n",
      "Neighbor-IMG-REL-ATTR:42.488-30.998-17.844-19.205\n",
      "DistMatch Loss: 31.172; PriorRec Loss: 110.534; PostRec Loss: 0.000; Rec Loss: 0.416\n",
      "Neighbor-IMG-REL-ATTR:42.479-31.034-17.872-19.287\n",
      "DistMatch Loss: 29.593; PriorRec Loss: 110.671; PostRec Loss: 0.000; Rec Loss: 0.419\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Train | Ep [13/200] Step [57/800] LR [0.00100] Loss 133351.80859 :   6%|███████████▋                                                                                                                                                                       | 13/200 [02:39<21:56,  7.04s/it]INFO - 05/23/23 09:25:18 - 0:02:45 - Ep 13 | l2r: acc of top [1, 10, 50] = [0.869  0.9837 0.9963], mr = 1.882, mrr = 0.916, Loss = 133351.8086\n",
      "INFO - 05/23/23 09:25:18 - 0:02:45 - Ep 13 | r2l: acc of top [1, 10, 50] = [0.8697 0.9843 0.9967], mr = 1.961, mrr = 0.916, Loss = 133351.8086\n",
      "INFO - 05/23/23 09:25:18 - 0:02:45 - Best model update in Ep 13: MRR from [0.9110493600023121] --> [0.9159309086728497] ... \n",
      "Train | Ep [13/200] Step [57/800] LR [0.00100] Loss 133351.80859 :   7%|████████████▌                                                                                                                                                                      | 14/200 [02:40<21:55,  7.07s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Neighbor-IMG-REL-ATTR:42.464-31.003-17.674-19.126\n",
      "DistMatch Loss: 31.095; PriorRec Loss: 110.267; PostRec Loss: 0.000; Rec Loss: 0.474\n",
      "Neighbor-IMG-REL-ATTR:42.449-30.997-17.701-19.020\n",
      "DistMatch Loss: 31.018; PriorRec Loss: 110.166; PostRec Loss: 0.000; Rec Loss: 0.400\n",
      "Neighbor-IMG-REL-ATTR:42.461-31.014-17.600-18.994\n",
      "DistMatch Loss: 31.469; PriorRec Loss: 110.069; PostRec Loss: 0.000; Rec Loss: 0.440\n",
      "Neighbor-IMG-REL-ATTR:42.444-30.991-17.486-19.091\n",
      "DistMatch Loss: 33.091; PriorRec Loss: 110.012; PostRec Loss: 0.000; Rec Loss: 0.467\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Train | Ep [14/200] Step [61/800] LR [0.00100] Loss 132832.27344 :   8%|█████████████▍                                                                                                                                                                     | 15/200 [02:46<20:49,  6.75s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Neighbor-IMG-REL-ATTR:42.439-31.024-17.341-18.721\n",
      "DistMatch Loss: 32.096; PriorRec Loss: 109.524; PostRec Loss: 0.000; Rec Loss: 0.429\n",
      "Neighbor-IMG-REL-ATTR:42.374-31.006-17.313-18.962\n",
      "DistMatch Loss: 30.288; PriorRec Loss: 109.656; PostRec Loss: 0.000; Rec Loss: 0.459\n",
      "Neighbor-IMG-REL-ATTR:42.385-30.993-17.212-18.728\n",
      "DistMatch Loss: 31.537; PriorRec Loss: 109.318; PostRec Loss: 0.000; Rec Loss: 0.413\n",
      "Neighbor-IMG-REL-ATTR:42.360-30.983-17.147-18.890\n",
      "DistMatch Loss: 34.860; PriorRec Loss: 109.380; PostRec Loss: 0.000; Rec Loss: 0.434\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Train | Ep [15/200] Step [65/800] LR [0.00100] Loss 132028.06641 :   8%|█████████████▍                                                                                                                                                                     | 15/200 [02:52<20:49,  6.75s/it]INFO - 05/23/23 09:25:31 - 0:02:58 - Ep 15 | l2r: acc of top [1, 10, 50] = [0.8713 0.9857 0.9967], mr = 1.858, mrr = 0.918, Loss = 132028.0664\n",
      "INFO - 05/23/23 09:25:31 - 0:02:58 - Ep 15 | r2l: acc of top [1, 10, 50] = [0.873  0.9853 0.997 ], mr = 1.815, mrr = 0.919, Loss = 132028.0664\n",
      "INFO - 05/23/23 09:25:31 - 0:02:58 - Best model update in Ep 15: MRR from [0.9159309086728497] --> [0.9179064339099082] ... \n",
      "Train | Ep [15/200] Step [65/800] LR [0.00100] Loss 132028.06641 :   8%|██████████████▎                                                                                                                                                                    | 16/200 [02:54<21:11,  6.91s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Neighbor-IMG-REL-ATTR:42.363-30.984-17.072-18.706\n",
      "DistMatch Loss: 32.465; PriorRec Loss: 109.124; PostRec Loss: 0.000; Rec Loss: 0.433\n",
      "Neighbor-IMG-REL-ATTR:42.331-30.983-16.883-18.582\n",
      "DistMatch Loss: 31.531; PriorRec Loss: 108.779; PostRec Loss: 0.000; Rec Loss: 0.392\n",
      "Neighbor-IMG-REL-ATTR:42.308-30.994-16.990-18.668\n",
      "DistMatch Loss: 31.133; PriorRec Loss: 108.960; PostRec Loss: 0.000; Rec Loss: 0.404\n",
      "Neighbor-IMG-REL-ATTR:42.329-30.993-16.972-18.576\n",
      "DistMatch Loss: 30.943; PriorRec Loss: 108.869; PostRec Loss: 0.000; Rec Loss: 0.393\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Train | Ep [16/200] Step [69/800] LR [0.00100] Loss 131348.45312 :   8%|███████████████▏                                                                                                                                                                   | 17/200 [03:00<20:24,  6.69s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Neighbor-IMG-REL-ATTR:42.322-30.985-16.829-18.563\n",
      "DistMatch Loss: 33.924; PriorRec Loss: 108.699; PostRec Loss: 0.000; Rec Loss: 0.457\n",
      "Neighbor-IMG-REL-ATTR:42.334-30.958-16.737-18.533\n",
      "DistMatch Loss: 33.582; PriorRec Loss: 108.562; PostRec Loss: 0.000; Rec Loss: 0.435\n",
      "Neighbor-IMG-REL-ATTR:42.296-31.008-16.724-18.589\n",
      "DistMatch Loss: 35.396; PriorRec Loss: 108.618; PostRec Loss: 0.000; Rec Loss: 0.456\n",
      "Neighbor-IMG-REL-ATTR:42.266-30.982-16.606-18.391\n",
      "DistMatch Loss: 34.391; PriorRec Loss: 108.245; PostRec Loss: 0.000; Rec Loss: 0.436\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Train | Ep [17/200] Step [73/800] LR [0.00100] Loss 130919.28711 :   8%|███████████████▏                                                                                                                                                                   | 17/200 [03:06<20:24,  6.69s/it]INFO - 05/23/23 09:25:45 - 0:03:12 - Ep 17 | l2r: acc of top [1, 10, 50] = [0.867  0.986  0.9967], mr = 1.957, mrr = 0.916, Loss = 130919.2871\n",
      "INFO - 05/23/23 09:25:45 - 0:03:12 - Ep 17 | r2l: acc of top [1, 10, 50] = [0.8757 0.9863 0.998 ], mr = 1.810, mrr = 0.920, Loss = 130919.2871\n",
      "Train | Ep [17/200] Step [73/800] LR [0.00100] Loss 130919.28711 :   9%|████████████████                                                                                                                                                                   | 18/200 [03:07<20:58,  6.91s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Neighbor-IMG-REL-ATTR:42.256-30.968-16.748-18.576\n",
      "DistMatch Loss: 33.122; PriorRec Loss: 108.547; PostRec Loss: 0.000; Rec Loss: 0.446\n",
      "Neighbor-IMG-REL-ATTR:42.202-30.996-16.548-18.414\n",
      "DistMatch Loss: 32.806; PriorRec Loss: 108.160; PostRec Loss: 0.000; Rec Loss: 0.421\n",
      "Neighbor-IMG-REL-ATTR:42.200-30.985-16.467-18.380\n",
      "DistMatch Loss: 34.957; PriorRec Loss: 108.032; PostRec Loss: 0.000; Rec Loss: 0.449\n",
      "Neighbor-IMG-REL-ATTR:42.211-30.964-16.504-18.332\n",
      "DistMatch Loss: 35.742; PriorRec Loss: 108.011; PostRec Loss: 0.000; Rec Loss: 0.455\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Train | Ep [18/200] Step [77/800] LR [0.00100] Loss 130502.69141 :  10%|█████████████████                                                                                                                                                                  | 19/200 [03:13<19:35,  6.50s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Neighbor-IMG-REL-ATTR:42.178-30.962-16.431-18.251\n",
      "DistMatch Loss: 35.539; PriorRec Loss: 107.822; PostRec Loss: 0.000; Rec Loss: 0.479\n",
      "Neighbor-IMG-REL-ATTR:42.168-30.977-16.433-18.308\n",
      "DistMatch Loss: 35.962; PriorRec Loss: 107.886; PostRec Loss: 0.000; Rec Loss: 0.406\n",
      "Neighbor-IMG-REL-ATTR:42.167-30.948-16.415-18.346\n",
      "DistMatch Loss: 35.489; PriorRec Loss: 107.876; PostRec Loss: 0.000; Rec Loss: 0.456\n",
      "Neighbor-IMG-REL-ATTR:42.148-30.988-16.359-18.236\n",
      "DistMatch Loss: 35.383; PriorRec Loss: 107.732; PostRec Loss: 0.000; Rec Loss: 0.478\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Train | Ep [19/200] Step [81/800] LR [0.00100] Loss 130087.39453 :  10%|█████████████████                                                                                                                                                                  | 19/200 [03:19<19:35,  6.50s/it]INFO - 05/23/23 09:25:57 - 0:03:24 - Ep 19 | l2r: acc of top [1, 10, 50] = [0.869  0.9863 0.9973], mr = 1.847, mrr = 0.917, Loss = 130087.3945\n",
      "INFO - 05/23/23 09:25:57 - 0:03:24 - Ep 19 | r2l: acc of top [1, 10, 50] = [0.8707 0.987  0.997 ], mr = 1.723, mrr = 0.918, Loss = 130087.3945\n",
      "Train | Ep [19/200] Step [81/800] LR [0.00100] Loss 130087.39453 :  10%|█████████████████▉                                                                                                                                                                 | 20/200 [03:20<19:59,  6.66s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Neighbor-IMG-REL-ATTR:42.136-30.958-16.416-18.316\n",
      "DistMatch Loss: 36.330; PriorRec Loss: 107.825; PostRec Loss: 0.000; Rec Loss: 0.482\n",
      "Neighbor-IMG-REL-ATTR:42.115-30.971-16.257-18.186\n",
      "DistMatch Loss: 35.852; PriorRec Loss: 107.529; PostRec Loss: 0.000; Rec Loss: 0.427\n",
      "Neighbor-IMG-REL-ATTR:42.072-30.945-16.196-18.129\n",
      "DistMatch Loss: 36.283; PriorRec Loss: 107.342; PostRec Loss: 0.000; Rec Loss: 0.410\n",
      "Neighbor-IMG-REL-ATTR:42.061-30.978-16.119-18.069\n",
      "DistMatch Loss: 36.731; PriorRec Loss: 107.226; PostRec Loss: 0.000; Rec Loss: 0.446\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Train | Ep [20/200] Step [85/800] LR [0.00100] Loss 129653.88867 :  10%|██████████████████▊                                                                                                                                                                | 21/200 [03:25<18:49,  6.31s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Neighbor-IMG-REL-ATTR:42.087-30.963-16.223-17.990\n",
      "DistMatch Loss: 35.925; PriorRec Loss: 107.263; PostRec Loss: 0.000; Rec Loss: 0.463\n",
      "Neighbor-IMG-REL-ATTR:42.013-30.947-16.153-18.213\n",
      "DistMatch Loss: 36.752; PriorRec Loss: 107.327; PostRec Loss: 0.000; Rec Loss: 0.416\n",
      "Neighbor-IMG-REL-ATTR:42.019-30.964-16.042-18.006\n",
      "DistMatch Loss: 37.009; PriorRec Loss: 107.031; PostRec Loss: 0.000; Rec Loss: 0.440\n",
      "Neighbor-IMG-REL-ATTR:41.968-30.948-15.988-18.166\n",
      "DistMatch Loss: 37.358; PriorRec Loss: 107.070; PostRec Loss: 0.000; Rec Loss: 0.428\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Train | Ep [21/200] Step [89/800] LR [0.00100] Loss 129281.07227 :  10%|██████████████████▊                                                                                                                                                                | 21/200 [03:31<18:49,  6.31s/it]INFO - 05/23/23 09:26:10 - 0:03:37 - Ep 21 | l2r: acc of top [1, 10, 50] = [0.8733 0.9867 0.997 ], mr = 2.104, mrr = 0.919, Loss = 129281.0723\n",
      "INFO - 05/23/23 09:26:10 - 0:03:37 - Ep 21 | r2l: acc of top [1, 10, 50] = [0.8717 0.986  0.9977], mr = 1.842, mrr = 0.918, Loss = 129281.0723\n",
      "INFO - 05/23/23 09:26:10 - 0:03:37 - Best model update in Ep 21: MRR from [0.9179064339099082] --> [0.9188213930926427] ... \n",
      "Train | Ep [21/200] Step [89/800] LR [0.00100] Loss 129281.07227 :  11%|███████████████████▋                                                                                                                                                               | 22/200 [03:32<19:31,  6.58s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Neighbor-IMG-REL-ATTR:42.004-30.947-16.057-17.957\n",
      "DistMatch Loss: 37.802; PriorRec Loss: 106.966; PostRec Loss: 0.000; Rec Loss: 0.508\n",
      "Neighbor-IMG-REL-ATTR:41.938-30.967-15.890-18.035\n",
      "DistMatch Loss: 36.788; PriorRec Loss: 106.830; PostRec Loss: 0.000; Rec Loss: 0.425\n",
      "Neighbor-IMG-REL-ATTR:41.905-30.947-15.874-17.903\n",
      "DistMatch Loss: 37.260; PriorRec Loss: 106.629; PostRec Loss: 0.000; Rec Loss: 0.435\n",
      "Neighbor-IMG-REL-ATTR:41.905-30.922-15.822-17.842\n",
      "DistMatch Loss: 35.441; PriorRec Loss: 106.491; PostRec Loss: 0.000; Rec Loss: 0.391\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Train | Ep [22/200] Step [93/800] LR [0.00100] Loss 128752.46680 :  12%|████████████████████▌                                                                                                                                                              | 23/200 [03:38<18:54,  6.41s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Neighbor-IMG-REL-ATTR:41.896-30.955-15.925-17.847\n",
      "DistMatch Loss: 37.623; PriorRec Loss: 106.623; PostRec Loss: 0.000; Rec Loss: 0.421\n",
      "Neighbor-IMG-REL-ATTR:41.873-30.924-15.758-17.821\n",
      "DistMatch Loss: 37.849; PriorRec Loss: 106.376; PostRec Loss: 0.000; Rec Loss: 0.478\n"
     ]
    }
   ],
   "source": [
    "# import os\n",
    "# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'\n",
    "\n",
    "runner = Runner(cfgs, writer, logger, rank)\n",
    "if cfgs.only_test:\n",
    "    runner.test(last_epoch=False)\n",
    "else:\n",
    "    runner.run()\n",
    "\n",
    "# -----  End ----------\n",
    "if not cfgs.no_tensorboard and not cfgs.only_test and rank == 0:\n",
    "    writer.close()\n",
    "    logger.info(\"done!\")\n",
    "\n",
    "if cfgs.dist and not cfgs.only_test:\n",
    "    dist.barrier()\n",
    "    dist.destroy_process_group()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dde79750",
   "metadata": {},
   "outputs": [],
   "source": [
    "# def eval_nongenarative_model(es_ill, g, model):\n",
    "#     gph_emb, img_emb, rel_emb, att_emb, name_emb, char_emb, joint_emb = runner.model.joint_emb_generat(\n",
    "#             only_joint=False)\n",
    "#     embs = [gph_emb, img_emb, rel_emb, att_emb, name_emb, char_emb, joint_emb]\n",
    "    \n",
    "#     left, right = es_ill[:, 0], es_ill[:, 1]\n",
    "    \n",
    "#     left_z =  g.encode(left, [gph_emb, img_emb, rel_emb, att_emb, name_emb, char_emb, joint_emb])\n",
    "#     left_emb_y =  [emb[left] for emb in embs if emb is not None]\n",
    "#     right_emb_y = [emb[right] for emb in embs if emb is not None]\n",
    "#     with torch.no_grad():\n",
    "#         error = [F.mse_loss(ly, ry) for ly, ry in zip(left_emb_y, right_emb_y)]\n",
    "#         error = sum(error) / len(error)\n",
    "        \n",
    "#     left_concrete_y = [subdecoder(ley) for ley, subdecoder in zip(left_emb_y, g.subdecoders)]\n",
    "#     right_concrete_y = [concrete_feature[right] for concrete_feature in g.concrete_features if concrete_feature is not None]\n",
    "    \n",
    "#     with torch.no_grad():\n",
    "#         rc_error = [F.mse_loss(ly.cpu(), ry.cpu()) for ly, ry in zip(left_emb_y, right_emb_y)]\n",
    "#         rc_error = sum(rc_error) / len(rc_error)\n",
    "        \n",
    "#         prc_error = [F.mse_loss(ly.cpu(), ry.cpu()) for ly, ry in zip(left_concrete_y, right_concrete_y)]\n",
    "#         prc_error = sum(prc_error) / len(prc_error)\n",
    "        \n",
    "#     return rc_error, prc_error\n",
    "# es_ill = runner.model.geea.kgs['es_ill']\n",
    "# g = runner.model.geea\n",
    "# rc_error, prc_error = eval_nongenarative_model(es_ill, g, runner.model)\n",
    "# rc_error, prc_error\n",
    "\n",
    "\n",
    "\n",
    "# def decode(self, zs, reparameterize=False):\n",
    "\n",
    "#     reconstructed_x = [subgenerator.decode(z, reparameterize=reparameterize)\n",
    "#                for subgenerator, z in zip(self.subgenerators, zs)]\n",
    "\n",
    "#     return reconstructed_x\n",
    "# g.decode = decode\n",
    "\n",
    "# gph_emb, img_emb, rel_emb, att_emb, name_emb, char_emb, joint_emb = runner.model.joint_emb_generat(\n",
    "#             only_joint=False)\n",
    "\n",
    "# ents = [23873, 5899, 860]\n",
    "# samples = runner.model.geea.sample_from_x_to_y(ents, [gph_emb, img_emb, rel_emb, att_emb, name_emb, char_emb])\n",
    "# recovered_samples = runner.model.geea.recover_to_feature(samples,max_counts=30)\n",
    "# recovered_samples[2]\n",
    "\n",
    "# import numpy\n",
    "# from numpy import cov\n",
    "# from numpy import trace\n",
    "# from numpy import iscomplexobj\n",
    "# from numpy.random import random\n",
    "# from scipy.linalg import sqrtm\n",
    " \n",
    "# # calculate frechet inception distance\n",
    "# def calculate_fid(act1, act2):\n",
    "#     print(act1, act2)\n",
    "#      # calculate mean and covariance statistics\n",
    "#     mu1, sigma1 = act1.mean(axis=0), cov(act1, rowvar=False)\n",
    "#     mu2, sigma2 = act2.mean(axis=0), cov(act2, rowvar=False)\n",
    "#     # calculate sum squared difference between means\n",
    "#     ssdiff = numpy.sum((mu1 - mu2)**2.0)\n",
    "#     # calculate sqrt of product between cov\n",
    "#     covmean = sqrtm(sigma1.dot(sigma2))\n",
    "#     # check and correct imaginary numbers from sqrt\n",
    "#     if iscomplexobj(covmean):\n",
    "#         covmean = covmean.real\n",
    "#     # calculate score\n",
    "#     fid = ssdiff + trace(sigma1 + sigma2 - 2.0 * covmean)\n",
    "#     return fid\n",
    "\n",
    "# def rc(es_ill, g, model):\n",
    "#     gph_emb, img_emb, rel_emb, att_emb, name_emb, char_emb, joint_emb = runner.model.joint_emb_generat(\n",
    "#             only_joint=False)\n",
    "#     embs = [gph_emb, img_emb, rel_emb, att_emb, name_emb, char_emb, joint_emb]\n",
    "    \n",
    "#     left, right = es_ill[:, 0], es_ill[:, 1]\n",
    "    \n",
    "#     left_z =  g.encode(left, [gph_emb, img_emb, rel_emb, att_emb, name_emb, char_emb, joint_emb])\n",
    "#     left_emb_y =  g.decode(left_z, reparameterize=True)\n",
    "#     right_emb_y = [emb[right] for emb in embs if emb is not None]\n",
    "#     with torch.no_grad():\n",
    "#         error = [F.mse_loss(ly, ry) for ly, ry in zip(left_emb_y, right_emb_y)]\n",
    "#         error = sum(error) / len(error)\n",
    "        \n",
    "#     left_concrete_y = [subdecoder(ley) for ley, subdecoder in zip(left_emb_y, g.subdecoders)]\n",
    "#     right_concrete_y = [concrete_feature[right] for concrete_feature in g.concrete_features if concrete_feature is not None]\n",
    "    \n",
    "#     with torch.no_grad():\n",
    "#         rc_error = [F.mse_loss(ly.cpu(), ry.cpu()) for ly, ry in zip(left_emb_y, right_emb_y)]\n",
    "#         rc_error = sum(rc_error) / len(rc_error)\n",
    "        \n",
    "#         prc_error = [F.mse_loss(ly.cpu(), ry.cpu()) for ly, ry in zip(left_concrete_y, right_concrete_y)]\n",
    "#         prc_error = sum(prc_error) / len(prc_error)\n",
    "    \n",
    "#     sample_z = torch.randn(len(left), runner.model.geea.latent_dim).repeat(len(left_concrete_y), 1, 1).cuda()\n",
    "#     sample_emb_y = g.decode(sample_z, reparameterize=False)\n",
    "    \n",
    "#     with torch.no_grad():\n",
    "#         fid = [calculate_fid(sy.cpu().numpy(), ry.cpu().numpy()) for sy, ry in zip(sample_emb_y, right_emb_y)]\n",
    "#         fid = sum(fid) / len(fid)\n",
    "        \n",
    "#     return rc_error, prc_error, fid\n",
    "\n",
    "# es_ill = runner.model.geea.kgs['es_ill']\n",
    "# g = runner.model.geea\n",
    "# rc(es_ill, g, runner.model)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# z = g.encode(es_ill[:, 0], [gph_emb, img_emb, rel_emb, att_emb, name_emb, char_emb, joint_emb])\n",
    "\n",
    "# y_xy = g.decode(g,z, reparameterize=True)\n",
    "\n",
    "# F.mse_loss(y_xy[3],att_emb[es_ill[:, 1]])\n",
    "\n",
    "# F.mse_loss(y_xy[3],att_emb[es_ill[:, 1]])\n",
    "\n",
    "# runner.eval_set.data\n",
    "\n",
    "# es_ill = np.random.randint(test_ill)\n",
    "\n",
    "# samples[0]\n",
    "\n",
    "# mask = torch.where(runner.model.geea.concrete_features[0]>0)\n",
    "\n",
    "# z = torch.randn(10, runner.model.geea.latent_dim).cuda()\n",
    "\n",
    "# z.repeat(4,1,1)\n",
    "\n",
    "# def decode(self, zs, reparameterize=False):\n",
    "#     decoded = [subgenerator.decode(z, reparameterize=reparameterize)\n",
    "#                for subgenerator, z in zip(self.subgenerators, zs)]\n",
    "#     reconstructed_x = [subdecoder(d)\n",
    "#                for d, subdecoder in zip(decoded, self.subdecoders)]\n",
    "\n",
    "#     return reconstructed_x, decoded\n",
    "\n",
    "# z = torch.randn(4,10, runner.model.geea.latent_dim).cuda()\n",
    "        \n",
    "# samples, decoded = decode(runner.model.geea, z)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54ea1a3b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "gph_emb, img_emb, rel_emb, att_emb, name_emb, char_emb, joint_emb = runner.model.joint_emb_generat(\n",
    "            only_joint=False)\n",
    "\n",
    "embs = [gph_emb, img_emb, rel_emb, att_emb, name_emb, char_emb]\n",
    "\n",
    "g = runner.model.geea\n",
    "\n",
    "import pandas as pd\n",
    "def id2feature(self):\n",
    "    self.id2ent = pd.Series(self.kgs['ent2id'].keys(), index=self.kgs['ent2id'].values()) \n",
    "    self.id2rel = pd.Series(self.kgs['rel2id'].keys(), index=self.kgs['rel2id'].values()) \n",
    "    self.id2attr = pd.Series(self.kgs['attr2id'].keys(), index=self.kgs['attr2id'].values()) \n",
    "    self.left2right = pd.Series(list(self.kgs['test_ill'][:,1])+list(self.kgs['train_ill'][:,1]), index=list(self.kgs['test_ill'][:,0])+list(self.kgs['train_ill'][:,0])) \n",
    "id2feature(g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afe40665",
   "metadata": {},
   "outputs": [],
   "source": [
    "# test_ill = g.kgs['test_ill'].data\n",
    "\n",
    "# encoded_x, encoded_y = g.encode(test_ill[:,0], embs), g.encode(test_ill[:,1], embs)\n",
    "\n",
    "# encoded_x\n",
    "\n",
    "# decoded_x= g.decode(encoded_x, reparameterize=True)\n",
    "# decoded_y= g.decode(encoded_y, reparameterize=True)\n",
    "\n",
    "# decoded_x = [x.cpu() for x in decoded_x]\n",
    "# decoded_y = [y.cpu() for y in decoded_y]\n",
    "\n",
    "\n",
    "# dis = pairwise_distances(gph_emb[test_ill[:,0]].cpu(), gph_emb[test_ill[:,1]].cpu())\n",
    "# dis = pairwise_distances(img_emb[test_ill[:,0]].cpu(), img_emb[test_ill[:,1]].cpu())\n",
    "# mask = dis.argmin(dim=-1).cpu() == torch.arange(len(dis)).cpu()\n",
    "# mask.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c26fedaf",
   "metadata": {},
   "outputs": [],
   "source": [
    "dis = pairwise_distances(gph_emb.cpu(), gph_emb.cpu())\n",
    "dis[:, g.kgs['left_ents']] = 1e4\n",
    "argsort_dis = torch.argsort(dis, dim=-1)\n",
    "dis_top30 = argsort_dis[:, :30]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71897cb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "encoded = g.encode(g.kgs['input_idx'], embs)\n",
    "\n",
    "decoded = g.decode(encoded,reparameterize=True)\n",
    "\n",
    "decoded_attr = g.subdecoders[3].cpu()(decoded[3].cpu())\n",
    "\n",
    "argsort_attr = torch.argsort(decoded_attr, dim=-1, descending=True)\n",
    "\n",
    "attr_top30 = argsort_attr[:, :30]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7272e55e",
   "metadata": {},
   "outputs": [],
   "source": [
    "attr_top30"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3547908",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import pandas as pd\n",
    "# def id2feature(self):\n",
    "#     self.id2ent = pd.Series(self.kgs['ent2id'].keys(), index=self.kgs['ent2id'].values()) \n",
    "#     self.id2rel = pd.Series(self.kgs['rel2id'].keys(), index=self.kgs['rel2id'].values()) \n",
    "#     self.id2attr = pd.Series(self.kgs['attr2id'].keys(), index=self.kgs['attr2id'].values()) \n",
    "#     self.left2right = pd.Series(list(self.kgs['test_ill'][:,1])+list(self.kgs['train_ill'][:,1]), index=list(self.kgs['test_ill'][:,0])+list(self.kgs['train_ill'][:,0])) \n",
    "# g = runner.model.geea\n",
    "# id2feature(g)\n",
    "\n",
    "# print(g.id2ent.loc[test_ill[mask][:,1]])\n",
    "\n",
    "# g.id2ent[g.id2ent.values=='<http://dbpedia.org/resource/The_Beatles>']\n",
    "\n",
    "# g.id2ent[g.id2ent.values=='/m/06mmr']\n",
    "\n",
    "# g.id2ent[g.id2ent.values=='<http://dbpedia.org/resource/Star_Wars>']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec740f94",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "g.id2ent[g.id2ent.values=='http://fr.dbpedia.org/resource/Nintendo_3DS']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91553ce6",
   "metadata": {},
   "outputs": [],
   "source": [
    "g.id2ent[g.id2ent.values=='<http://yago-knowledge.org/resource/The_Matrix>']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9f196ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "12094  in g.left2right"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "309f683a",
   "metadata": {},
   "outputs": [],
   "source": [
    "g.kgs['att_features'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4198b3d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3be839da",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def verbose_feature(self, ents, samples):\n",
    "    adj_mat, attr_mat = self.concrete_features[0], self.concrete_features[3]\n",
    "    right_ents = self.left2right.loc[ents].values\n",
    "    \n",
    "    neighbors, attrs = samples[0][ents], samples[1][ents]\n",
    "    for ent, neighbor, attr, right_ent in zip(ents, neighbors, attrs, right_ents):\n",
    "        n = self.id2ent.loc[neighbor]\n",
    "        print(attr)\n",
    "        a = self.id2attr.loc[attr]\n",
    "        \n",
    "        rn = np.where(adj_mat[right_ent].cpu()>0)[0]\n",
    "        rn = self.id2ent.loc[rn]\n",
    "        ra = np.where(attr_mat[right_ent].cpu()>0)[0]\n",
    "        ra = self.id2attr.loc[ra]\n",
    "        print('#'*20 + '\\n this entity is:')\n",
    "        print(self.id2ent.loc[ent])\n",
    "        print('its counterpart is:')\n",
    "        print(self.id2ent.loc[right_ent])\n",
    "        \n",
    "        print('Preditc: it has the neighors:')\n",
    "        print(n)\n",
    "        print('In fact: it has the neighors:')\n",
    "        print(rn)\n",
    "        \n",
    "        print('Overlapped neighors:')\n",
    "        print(n[n.isin(rn)])\n",
    "        \n",
    "        print('Preditct: it has the attribute:')\n",
    "        print(a) \n",
    "        print('In fact: it has the attribute:')\n",
    "        print(ra)\n",
    "        print('Overlapped attribute:')\n",
    "        print(a[a.isin(ra)])\n",
    "        \n",
    "        \n",
    "pd.set_option('display.max_columns', None)  # or 1000\n",
    "pd.set_option('display.max_rows', None)  # or 1000\n",
    "pd.set_option('display.max_colwidth', None)  # or 199\n",
    "    \n",
    "verbose_feature(g, [7615    ,], [dis_top30, attr_top30])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc917eea",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b8d4e35",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c173478",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "\n",
    "def verbose_features(self, ents, samples):\n",
    "    adj_mat, attr_mat = self.concrete_features[0], self.concrete_features[3]\n",
    "    right_ents = self.left2right.loc[ents].values\n",
    "    \n",
    "    neighbors, imgs, rels, attrs = samples\n",
    "    for ent, neighbor, img, rel, attr, right_ent in zip(ents, neighbors, imgs, rels, attrs, right_ents):\n",
    "        n = self.id2ent.loc[neighbor]\n",
    "        i = self.id2ent.loc[img]\n",
    "#         r = self.id2rel.loc[rel]\n",
    "        a = self.id2attr.loc[attr]\n",
    "        \n",
    "        rn = np.where(adj_mat[right_ent].cpu()>0)[0]\n",
    "        rn = self.id2ent.loc[rn]\n",
    "        ra = np.where(attr_mat[right_ent].cpu()>0)[0]\n",
    "        ra = self.id2attr.loc[ra]\n",
    "        print('#'*20 + '\\n this entity is:')\n",
    "        print(self.id2ent.loc[ent])\n",
    "        print('its counterpart is:')\n",
    "        print(self.id2ent.loc[right_ent])\n",
    "        \n",
    "        print('Preditc: it has the neighors:')\n",
    "        print(n)\n",
    "        print('In fact: it has the neighors:')\n",
    "        print(rn)\n",
    "        \n",
    "        \n",
    "        print('its image refers to:')\n",
    "        print(i)\n",
    "        \n",
    "        \n",
    "#         print('it has the relations:')\n",
    "#         print(r)\n",
    "        \n",
    "        print('Preditct: it has the attribute:')\n",
    "        print(a) \n",
    "        print('In fact: it has the attribute:')\n",
    "        print(ra)\n",
    "        \n",
    "        \n",
    "\n",
    "verbose_features(g, ents, recovered_samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e0b78cf",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0058f396",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from PIL import Image\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6878a1ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.image as mpimg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d874c683",
   "metadata": {},
   "outputs": [],
   "source": [
    "a1 =runner.KGs['att_features'][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a05cca90",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.where(runner.KGs['att_features'][:5]>0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87954664",
   "metadata": {},
   "outputs": [],
   "source": [
    "runner.KGs['att_features'][]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.16"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
