{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch_geometric.data import Data, DataLoader\n",
    "from cnf import BipartiteData\n",
    "from loss import LossCompute, AccuracyCompute, LossMetric\n",
    "import time\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from utils import load_checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [],
   "source": [
    "# add a folder raw under dataset manually before execute this chunk\n",
    "from data import SATDataset\n",
    "ds = SATDataset('../dataset', 'RND3SAT/uf50-218', False)\n",
    "last_trn, last_val = int(len(ds)), int(len(ds))\n",
    "train_ds = ds[: last_trn]\n",
    "valid_ds = ds[last_trn: last_val]\n",
    "test_ds = ds[last_val:]"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "test_data = train_ds[1]\n",
    "edge_index_pos = test_data.edge_index_pos\n",
    "edge_index_neg = test_data.edge_index_neg\n",
    "variable_count = max(max(edge_index_pos[1]), max(edge_index_neg[1])) + 1\n",
    "clause_count = len(edge_index_pos[1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Debug for loss.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def test_loss(iter_num, par_sm, par_sg, var_num, plot=False):\n",
    "    loss_func = LossCompute(par_sm, par_sg, metric=LossMetric.linear_loss, debug=True)\n",
    "    sat_rate = np.zeros(iter_num)\n",
    "    loss_v = np.zeros(iter_num)\n",
    "    for i in range(iter_num):\n",
    "        x_s = LossCompute.push_to_side(torch.rand(var_num, 1), par_sg)\n",
    "        loss, sm = loss_func(x_s, edge_index_pos, edge_index_neg, clause_count=218, is_train=False)\n",
    "        satisfied_percentage = sum(sm > 0.5).numpy() / clause_count\n",
    "        loss_v[i] = loss\n",
    "        sat_rate[i] = satisfied_percentage\n",
    "    if plot:\n",
    "        plt.plot(sat_rate, loss_v, \"ro\")\n",
    "        plt.xlabel(\"Satisfied Clauses / Numer of Clauses\")\n",
    "        plt.ylabel(\"Loss\")\n",
    "    return sat_rate, loss_v"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "error",
     "ename": "TypeError",
     "evalue": "__call__() missing 1 required positional argument: 'gr_idx_cls'",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-5-318e82c1bda7>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mstart\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0msat_rate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss_v\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtest_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m5000\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m30\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m50\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvariable_count\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mplot\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      3\u001b[0m \u001b[0mspan\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mstart\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"Time for loss1 and loss2 to compute loss of 5000 FG respectively takes {span}s\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-4-9558f9e9132c>\u001b[0m in \u001b[0;36mtest_loss\u001b[0;34m(iter_num, par_sm, par_sg, var_num, plot)\u001b[0m\n\u001b[1;32m      5\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miter_num\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m         \u001b[0mx_s\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mLossCompute\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpush_to_side\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrand\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvar_num\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpar_sg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m         \u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msm\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mloss_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_s\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0medge_index_pos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0medge_index_neg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclause_count\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m218\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mis_train\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      8\u001b[0m         \u001b[0msatisfied_percentage\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msm\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0.5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mclause_count\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      9\u001b[0m         \u001b[0mloss_v\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mTypeError\u001b[0m: __call__() missing 1 required positional argument: 'gr_idx_cls'"
     ]
    }
   ],
   "source": [
    "start = time.time()\n",
    "sat_rate, loss_v = test_loss(5000, 30, 50, variable_count, plot=True)\n",
    "span = time.time() - start\n",
    "print(f\"Time for loss1 and loss2 to compute loss of 5000 FG respectively takes {span}s\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Debug for models.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import models\n",
    "from args import make_args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Args():\n",
    "    def __init__(self):\n",
    "        self.dataset = 'RND3SAT/uf50-218'\n",
    "        self.dataset_root = '../dataset'\n",
    "        self.loss = 'l2'\n",
    "        self.use_gpu = True\n",
    "        self.cuda = '0'\n",
    "        self.graph_valid_ratio = 0.1\n",
    "        self.graph_test_ratio = 0.1\n",
    "        self.feature_transform = False\n",
    "        self.drop_rate = 0.5\n",
    "        self.speedip = False\n",
    "        self.load_model = False\n",
    "        self.batch_size = 1\n",
    "        self.num_layers = 2\n",
    "        self.num_encoder_layers = 4\n",
    "        self.num_decoder_layers = 2\n",
    "        self.num_meta_paths = 4\n",
    "        self.encoder_channels = '1,16,32,32,32'\n",
    "        self.decoder_channels = '32,16,16'\n",
    "        self.self_att_heads = 8\n",
    "        self.cross_att_heads = 8\n",
    "        self.lr = 1e-6\n",
    "        self.sm_par = 1e-3\n",
    "        self.sig_par = 1e-3\n",
    "        self.warmup_steps = 200\n",
    "        self.opt_train_factor = 4\n",
    "        self.epoch_num = 201\n",
    "        self.epoch_log = 50\n",
    "        self.epoch_save = 50\n",
    "        self.save_root = 'saved_model'\n",
    "        self.root = '../dataset'\n",
    "        self.activation = 'relu'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## sample a test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "args = Args()\n",
    "device = torch.device('cuda:0') if args.use_gpu and torch.cuda.is_available() else torch.device('cpu')\n",
    "# device = 'cuda' if torch.cuda.is_available() else 'cpu\n",
    "# download and save the dataset\n",
    "dataset = SATDataset(args.root, args.dataset, use_negative=False)\n",
    "dataset, perm = dataset.shuffle(return_perm=True)\n",
    "num_clauses = dataset.num_clauses[perm][900:]\n",
    "test_loader = DataLoader(dataset[900:],\n",
    "                             batch_size=1,\n",
    "                             shuffle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize a model and to make it run through"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "new accuracy:  None\n0.0\ntensor(0.9260, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9555, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9545, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9608, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9465, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9622, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9674, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9365, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9337, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9608, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9474, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9385, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9751, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9760, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9791, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9460, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9575, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9365, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9454, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9339, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9431, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9738, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9401, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9432, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9589, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9449, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9471, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9310, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9551, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9502, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9579, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9556, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9373, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9347, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9271, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9654, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9512, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9649, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9580, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9437, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9431, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9120, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9743, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9619, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9576, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9604, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9622, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9448, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9697, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9545, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9672, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9496, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9476, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9500, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9579, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9254, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9357, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9355, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9398, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9327, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9559, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9532, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9467, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9346, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9516, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9441, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9466, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9573, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9412, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9631, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9470, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9325, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9722, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9580, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9761, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9452, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9563, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9572, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9736, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9492, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9479, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9628, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9567, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9227, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9486, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9419, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9313, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9351, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9546, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9522, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9422, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9352, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9318, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9698, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9683, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9549, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9357, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9628, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9500, grad_fn=<DivBackward0>)\nnew accuracy:  None\n0.0\ntensor(0.9669, grad_fn=<DivBackward0>)\n"
    }
   ],
   "source": [
    "model = models.make_model(args)\n",
    "path = '../saved_model/check_point_500.pickle'\n",
    "checkpoint = torch.load(path)\n",
    "model.load_state_dict(checkpoint['model_state_dict'])\n",
    "loss_func = LossCompute(30, 50, metric=LossMetric.linear_loss, debug=True)\n",
    "accuracy = AccuracyCompute()\n",
    "for i, batch in enumerate(test_loader):\n",
    "    xv = model(batch, args)\n",
    "    num_cls = num_clauses[i * 1 : (i + 1) * 1]\n",
    "    gr_idx_cls = torch.cat([torch.tensor([i] * num_cls[i]) for i in range(num_cls.size(0))])\n",
    "    loss, sm = loss_func(xv, batch.edge_index_pos, batch.edge_index_neg, clause_count=218, gr_idx_cls=gr_idx_cls, is_train=False)\n",
    "    acc = accuracy(xv, batch.edge_index_pos, batch.edge_index_neg).item()\n",
    "    print(\"new accuracy: \", flip_search(xv, sm, batch.edge_index_pos, batch.edge_index_neg, acc))\n",
    "    print(acc)\n",
    "    print(torch.sum(sm) / 218)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "tensor([0.2695, 0.7240, 0.8077, 0.8801])\naccuracy:  tensor(1.)\ntensor([0.7167, 0.7261, 0.8555])\nnew accuracy:  None\ntensor([0.2695, 0.7240, 0.8077, 0.8801])\n"
    }
   ],
   "source": [
    "from loss import flip_search\n",
    "x = torch.rand(4)\n",
    "edge_index_pos = torch.tensor([\n",
    "    [0, 0, 1, 2, 2],\n",
    "    [0, 1, 1, 2, 3],\n",
    "])\n",
    "edge_index_neg = torch.tensor([\n",
    "    [0, 1, 1, 2],\n",
    "    [2, 0, 2, 1],\n",
    "])\n",
    "accuracy_compute = AccuracyCompute()\n",
    "accuracy = accuracy_compute(x, edge_index_pos, edge_index_neg)\n",
    "sm = LossCompute.get_sm(x, edge_index_pos, edge_index_neg, 10, 10)\n",
    "print(x)\n",
    "print(\"accuracy: \", accuracy)\n",
    "print(sm)\n",
    "print(\"new accuracy: \", flip_search(x, sm, edge_index_pos, edge_index_neg, accuracy))\n",
    "print(x)"
   ]
  },
  {
   "source": [
    "vx = model(vx, vc, graph)   # graph is concatenated batch of 32 graphs  \n",
    "vx.shape = (sum(lit_num_of_each_graph), 1)  \n",
    "vc = sm = LossCompute(vx, graph)  \n",
    "batch_index.shape = (sum(cls_num_of_each_graph), 1)  "
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# After finish model.py, debug train.py"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "name": "python38364bitbaseconda16d50af7b8b1404a8c193832594d3dff",
   "language": "python",
   "display_name": "Python 3.8.3 64-bit ('base': conda)"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "3.8.3-final"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}