{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from models import Model\n",
    "import torch\n",
    "from dataloader import load_data\n",
    "from utils import get_training_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "([tensor([[ 5.8442e-01, -1.0356e-01,  8.1319e-04,  ...,  3.6470e-13,\n",
      "         -8.1418e-01, -4.6541e-01],\n",
      "        [ 1.2335e+00, -9.5575e-01,  7.9504e-04,  ...,  3.6470e-13,\n",
      "         -2.5618e-01, -1.9892e+00],\n",
      "        [ 8.6501e-01, -5.7860e-01,  8.0410e-04,  ...,  3.6470e-13,\n",
      "         -3.1976e-01, -1.0682e+00],\n",
      "        [ 5.6083e-01, -3.5569e-01,  8.1139e-04,  ...,  3.6470e-13,\n",
      "         -5.0912e-02, -1.1522e+00],\n",
      "        [ 5.9504e-01, -1.2344e+00,  7.9975e-04,  ...,  3.6470e-13,\n",
      "         -3.4080e-01, -6.1032e-01]], device='cuda:7', grad_fn=<AddmmBackward0>), tensor([[-0.4326, -0.5708,  0.2880,  ...,  0.1598, -0.1299, -0.6277],\n",
      "        [-0.3381, -0.4483, -0.1010,  ..., -0.1019, -0.3313, -0.3271],\n",
      "        [-0.1578, -0.3170, -0.3024,  ..., -0.3029, -0.3202, -0.2118],\n",
      "        [-0.0832, -0.3379,  0.0926,  ...,  0.0070, -0.0386, -0.3235],\n",
      "        [-0.0428, -0.4614,  0.0952,  ...,  0.0340, -0.0761, -0.3674]],\n",
      "       device='cuda:7', grad_fn=<AddmmBackward0>)], tensor([[ 8.6250,  3.5353,  2.5977,  2.3122,  1.7462,  3.2218,  2.3302, -0.9583,\n",
      "          1.8395,  0.6811,  2.6614,  2.2533,  0.6455, -0.2155, -1.2876, -0.7696,\n",
      "          1.5092,  0.6498,  1.7766,  2.5994, -0.1170,  5.0221,  0.0792,  1.7163,\n",
      "         -0.2546,  2.4614,  1.6401, -1.9606, -0.9698, -1.3876, -1.8324, -1.5089,\n",
      "         -1.1285, -1.9537, -1.9629, -1.8634, -1.9777, -1.9411, -1.9872, -1.9298,\n",
      "         -1.7938, -1.9857, -2.0122, -2.0151, -2.0146, -2.0151, -2.0149],\n",
      "        [ 3.4078,  8.7837,  4.3939,  3.7900,  3.5715,  1.9104, -0.1307, -0.9546,\n",
      "         -0.4370,  5.3219,  0.5227,  1.6680, -0.9306,  0.9265, -1.8959, -1.9098,\n",
      "          0.6061,  2.4731,  0.0157,  3.9762,  0.2678,  0.4718, -0.5785,  1.6585,\n",
      "          1.1599,  0.9607,  0.8731, -1.7668,  0.8049,  0.1789, -1.6464, -1.5529,\n",
      "         -1.2138, -1.7634, -1.7779, -1.7133, -1.7877, -1.7575, -1.7928, -1.7395,\n",
      "         -1.6037, -1.7908, -1.8146, -1.8169, -1.8164, -1.8169, -1.8167],\n",
      "        [ 1.4560,  5.9002,  8.8393,  1.9918,  2.0915,  0.6639,  1.5218,  0.6239,\n",
      "          1.3836,  2.1215,  0.8923,  2.1940, -0.0375, -0.6560, -0.9030, -0.3696,\n",
      "          0.4463,  1.8303,  0.1459,  2.0479, -0.8840, -0.0096, -0.7529,  1.8901,\n",
      "          0.8748, -0.0989, -0.5645, -1.7227,  2.4612,  3.9332, -1.6202, -1.4203,\n",
      "         -1.1587, -1.7189, -1.7267, -1.6559, -1.7423, -1.7179, -1.7458, -1.7022,\n",
      "         -1.5674, -1.7443, -1.7651, -1.7673, -1.7668, -1.7674, -1.7671],\n",
      "        [ 6.5086,  3.7320,  3.0863,  5.9444,  1.7101,  2.0350,  2.5874, -0.8093,\n",
      "          1.2827,  2.6385,  1.0594,  3.1118,  0.2759,  0.9857, -1.5033, -0.6232,\n",
      "          0.2904, -0.1523,  1.1692,  1.8500, -1.2071,  1.6546, -1.6078,  2.9196,\n",
      "          0.6128,  0.9179,  1.4693, -1.7239, -0.4869, -0.9233, -1.6104, -1.4562,\n",
      "         -1.1702, -1.7201, -1.7349, -1.6652, -1.7468, -1.7209, -1.7530, -1.7023,\n",
      "         -1.5788, -1.7513, -1.7740, -1.7764, -1.7760, -1.7764, -1.7762],\n",
      "        [ 5.7744,  4.7144,  3.6592,  5.9032,  1.3523,  2.2229,  1.4047, -1.2503,\n",
      "          1.1803,  2.2009,  0.7076,  2.9293, -0.0316,  0.1391, -2.0155, -1.4973,\n",
      "          1.2216,  0.3286,  0.9488,  2.8041, -1.0363,  2.3572, -1.9520,  2.7534,\n",
      "          0.3699,  1.3186,  1.2791, -1.6531, -0.1117, -0.6066, -1.5438, -1.4310,\n",
      "         -1.0977, -1.6514, -1.6628, -1.6060, -1.6733, -1.6444, -1.6796, -1.6272,\n",
      "         -1.5156, -1.6778, -1.7002, -1.7024, -1.7021, -1.7024, -1.7023]],\n",
      "       device='cuda:7', grad_fn=<AddmmBackward0>))\n"
     ]
    }
   ],
   "source": [
    "def inference(\n",
    "    state_dict_path,\n",
    "    data_name,\n",
    "    data_cache_path,\n",
    "    dw_emb_path,\n",
    "    student_name,\n",
    "    config_path,\n",
    "    device='cuda:7',\n",
    "    n_data=5\n",
    "):\n",
    "    \"\"\"Run a trained student model on the first `n_data` nodes of a dataset.\n",
    "\n",
    "    Node features from `load_data` are concatenated column-wise with the\n",
    "    per-node position embeddings stored at `dw_emb_path` (presumably DeepWalk\n",
    "    embeddings -- TODO confirm), then fed through `Model` in eval mode.\n",
    "\n",
    "    Args:\n",
    "        state_dict_path: path to the saved student model `state_dict`.\n",
    "        data_name: dataset name, passed to `load_data` and `get_training_config`.\n",
    "        data_cache_path: data directory passed to `load_data`.\n",
    "        dw_emb_path: path to a saved 2-D tensor with one row per graph node.\n",
    "        student_name: student model name used to look up its training config.\n",
    "        config_path: path to the training config file.\n",
    "        device: torch device used for the model inputs and checkpoint tensors.\n",
    "        n_data: number of leading nodes to run; `None` runs every node.\n",
    "\n",
    "    Returns:\n",
    "        `(batch_mlp_emb, logits)` as returned by `model(None, feats)`,\n",
    "        computed without autograd tracking.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the position embeddings and the graph disagree on the\n",
    "            number of nodes.\n",
    "    \"\"\"\n",
    "    # get node features\n",
    "    g, labels, _idx_train, _idx_val, _idx_test = load_data(\n",
    "        data_name,\n",
    "        data_cache_path,\n",
    "        split_idx=0,\n",
    "        seed=0,\n",
    "        labelrate_train=20,\n",
    "        labelrate_val=30,\n",
    "    )\n",
    "    node_feats = g.ndata[\"feat\"]\n",
    "\n",
    "    # position (dw) embeddings; load on CPU so only the rows we keep reach `device`.\n",
    "    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.\n",
    "    position_feature = torch.load(dw_emb_path, map_location='cpu')\n",
    "    if position_feature.shape[0] != node_feats.shape[0]:\n",
    "        raise ValueError(\n",
    "            f\"position embeddings have {position_feature.shape[0]} rows, \"\n",
    "            f\"but the graph has {node_feats.shape[0]} nodes\"\n",
    "        )\n",
    "    len_position_feature = position_feature.shape[-1]\n",
    "    # Slice before concatenating/moving so the full feature matrix is never\n",
    "    # copied to the GPU just to keep `n_data` rows (slicing first selects the\n",
    "    # same rows as concatenating first).\n",
    "    feats = torch.cat(\n",
    "        [node_feats[:n_data].to(device), position_feature[:n_data].to(device)],\n",
    "        dim=1,\n",
    "    )\n",
    "\n",
    "    # build the student model and load its weights\n",
    "    conf = get_training_config(config_path, student_name, data_name)  # student config\n",
    "    conf['feat_dim'] = node_feats.shape[1]\n",
    "    conf['label_dim'] = labels.int().max().item() + 1\n",
    "    conf['device'] = device\n",
    "    model = Model(conf, None, len_position_feature)\n",
    "    # map_location avoids depending on the device the checkpoint was saved from.\n",
    "    state_dict = torch.load(state_dict_path, map_location=device)\n",
    "    model.load_state_dict(state_dict)\n",
    "    model.eval()\n",
    "\n",
    "    # inference only: skip building the autograd graph\n",
    "    with torch.no_grad():\n",
    "        batch_mlp_emb, logits = model(None, feats)\n",
    "    return batch_mlp_emb, logits\n",
    "\n",
    "\n",
    "res = inference(\n",
    "    state_dict_path='./outputs/transductive/ogbn-products/SAGE_MLP/seed_0/model.pth',\n",
    "    data_name='ogbn-products',\n",
    "    data_cache_path='./data',\n",
    "    dw_emb_path='./outputs/transductive/ogbn-products/SAGE_MLP/dw_emb.pt',\n",
    "    student_name='MLP',\n",
    "    config_path='./tran.conf.yaml',\n",
    "    n_data=5\n",
    ")\n",
    "print(res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
