{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([1., 1., 1., 0., 1., 1.])\n",
      "1\n"
     ]
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "from datagenerator import *  # provides the ILP* task classes used below\n",
    "from module import LogicModule2\n",
    "from module import LogicModule1\n",
    "from model import LogicModel\n",
    "\n",
    "# Seed the RNGs so Restart & Run All reproduces the recorded results, in case\n",
    "# the task generator or LogicModel's initialisation draw random numbers.\n",
    "SEED = 0\n",
    "random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Task selection: keep exactly one line active.\n",
    "# D = ILP2()\n",
    "# D = ILP5()\n",
    "# D = ILP6()\n",
    "# D = ILP16()\n",
    "D = ILP20()\n",
    "\n",
    "# The target is the last predicate of the first group, i.e. b[target].\n",
    "n_predicate = D.n_predicate\n",
    "target = n_predicate[0]-1\n",
    "b,B = D.get_data(dim=10)\n",
    "# p: ground-truth target facts (positives); neg: their complement, treated as\n",
    "# negatives by the training loss. The target entry is then cleared in the\n",
    "# input so the model cannot simply read the answer back.\n",
    "p = b[target].clone()\n",
    "neg = 1-b[target]\n",
    "b[target] = 0\n",
    "model = LogicModel(n_predicate[0],n_predicate[1])\n",
    "print(p)\n",
    "print(target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████| 300/300 [00:02<00:00, 146.29it/s, POS: 2.92 NEG: 0.97]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Proved:   [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]\n",
      "Unproved: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]\n",
      "Done.\n",
      "tensor(0.7138)\n",
      "tensor(0.9967)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Assemble the program: two module2 layers, four module1 layers, then a final\n",
    "# module1 built for the target index. Each run of this cell appends more\n",
    "# modules to the same model and sees b as modified by inference_ below, so\n",
    "# re-run the setup cell before re-running this one.\n",
    "model.add_module2()\n",
    "model.add_module2()\n",
    "model.add_module1()\n",
    "model.add_module1()\n",
    "model.add_module1()\n",
    "model.add_module1()\n",
    "model.add_module1(target)\n",
    "\n",
    "opt = torch.optim.Adam(model.parameters(),lr=3e-1)\n",
    "# The normalisers do not change during training; the small lower bound keeps\n",
    "# the loss finite (no 0/0 = NaN) when there are no positives or no negatives.\n",
    "n_pos = p.sum().clamp(min=1e-8)\n",
    "n_neg = neg.sum().clamp(min=1e-8)\n",
    "with tqdm(range(300),ncols=80) as _t:\n",
    "    for _ in _t:\n",
    "        pred = model.forward(b, B)[0][model.submodules[-1].info['dim']]\n",
    "        pos_score = (p*pred).sum()\n",
    "        neg_score = (neg*pred).sum()\n",
    "        # Push the mean prediction on positives up and on negatives down.\n",
    "        loss = -(pos_score/n_pos+1e-5).log() + (neg_score/n_neg+1e-5).log()\n",
    "        _t.set_postfix_str('POS: {:.2f} NEG: {:.2f}'.format(pos_score.item(),neg_score.item()))\n",
    "        opt.zero_grad()\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "\n",
    "model.fix_parameters()\n",
    "# inference_ updates its arguments in place: b[target], cleared in the setup\n",
    "# cell, holds the derived facts afterwards.\n",
    "model.inference_(b, B)\n",
    "proved = b[target]\n",
    "print('Proved:   '+str(proved.tolist()))\n",
    "# A correct program derives every positive and no negative. Compute both\n",
    "# without mutating p, so p and neg keep their meaning if the cell is re-run.\n",
    "unproved = (p - proved).clamp(min=0,max=1)\n",
    "wrongly_proved = (neg*proved).clamp(min=0,max=1)\n",
    "print('Unproved: '+str(unproved.tolist()))\n",
    "print('Wrongly proved: '+str(wrongly_proved.tolist()))\n",
    "if unproved.sum() == 0 and wrongly_proved.sum() == 0:\n",
    "    print('Done.')\n",
    "\n",
    "# Confidence of the final module: peak softmax weight of theta, then of the\n",
    "# w* weights associated with the selected choice m.\n",
    "head = model.submodules[-1]\n",
    "print(head.theta.softmax(dim=-1).max())\n",
    "m = head.theta.argmax()\n",
    "if m == 0:\n",
    "    print(head.w1.softmax(dim=-1).max())\n",
    "elif m == 1:\n",
    "    print(head.w2.softmax(dim=-1).max())\n",
    "    print(head.w3.softmax(dim=-1).max())\n",
    "elif m == 2:\n",
    "    print(head.w4.softmax(dim=-1).max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1.],\n",
      "        [0., 0., 0., 1.],\n",
      "        [0., 0., 0., 0.]])\n"
     ]
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "from datagenerator import *  # provides the ILP* task classes used below\n",
    "from module import LogicModule2\n",
    "from module import LogicModule1\n",
    "from model import LogicModel\n",
    "\n",
    "# Seed the RNGs so Restart & Run All reproduces the recorded results, in case\n",
    "# the task generator or LogicModel's initialisation draw random numbers.\n",
    "SEED = 0\n",
    "random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Task selection: keep exactly one line active.\n",
    "# D = ILP1()\n",
    "# D = ILP4()\n",
    "# D = ILP7()\n",
    "# D = ILP8()\n",
    "# D = ILP15()\n",
    "# D = ILP9()\n",
    "D = ILP19()\n",
    "\n",
    "# The target is the last predicate of the second group, i.e. B[target].\n",
    "n_predicate = D.n_predicate\n",
    "target = n_predicate[1]-1\n",
    "b,B = D.get_data(dim=10)\n",
    "# p: ground-truth target facts (positives); neg: their complement, treated as\n",
    "# negatives by the training loss. The target entry is then cleared in the\n",
    "# input so the model cannot simply read the answer back.\n",
    "p = B[target].clone()\n",
    "neg = 1-B[target]\n",
    "B[target] = 0\n",
    "model = LogicModel(n_predicate[0],n_predicate[1])\n",
    "print(p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████| 500/500 [00:02<00:00, 244.79it/s, POS: 2.10 NEG: 0.00]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1.],\n",
      "        [0., 0., 0., 1.],\n",
      "        [0., 0., 0., 0.]])\n",
      "tensor([[0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.]])\n",
      "Unproved: 0.0\n",
      "Done.\n",
      "tensor(0.9997)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Assemble the program: one module2 layer, then a final module2 built for the\n",
    "# target index. Each run of this cell appends more modules to the same model\n",
    "# and sees B as modified by inference_ below, so re-run the setup cell before\n",
    "# re-running this one.\n",
    "model.add_module2()\n",
    "model.add_module2(target)\n",
    "\n",
    "opt = torch.optim.Adam(model.parameters(),lr=1e-1)\n",
    "# The normalisers do not change during training; the small lower bound keeps\n",
    "# the loss finite (no 0/0 = NaN) when there are no positives or no negatives.\n",
    "n_pos = p.sum().clamp(min=1e-8)\n",
    "n_neg = neg.sum().clamp(min=1e-8)\n",
    "with tqdm(range(500),ncols=80) as _t:\n",
    "    for _ in _t:\n",
    "        pred = model.forward(b, B)[1][model.submodules[-1].info['dim']]\n",
    "        pos_score = (p*pred).sum()\n",
    "        neg_score = (neg*pred).sum()\n",
    "        # Push the mean prediction on positives up and on negatives down.\n",
    "        loss = -(pos_score/n_pos+1e-5).log() + (neg_score/n_neg+1e-5).log()\n",
    "        _t.set_postfix_str('POS: {:.2f} NEG: {:.2f}'.format(pos_score.item(),neg_score.item()))\n",
    "        opt.zero_grad()\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "\n",
    "model.fix_parameters()\n",
    "# inference_ updates its arguments in place: B[target], cleared in the setup\n",
    "# cell, holds the derived facts afterwards.\n",
    "model.inference_(b, B)\n",
    "proved = B[target]\n",
    "print(proved)\n",
    "# A correct program derives every positive and no negative. Compute both\n",
    "# without mutating p, so p and neg keep their meaning if the cell is re-run.\n",
    "unproved = (p - proved).clamp(min=0,max=1)\n",
    "wrongly_proved = (neg*proved).clamp(min=0,max=1)\n",
    "print(unproved)\n",
    "print('Unproved: '+str(unproved.sum().tolist()))\n",
    "print('Wrongly proved: '+str(wrongly_proved.sum().tolist()))\n",
    "if unproved.sum() == 0 and wrongly_proved.sum() == 0:\n",
    "    print('Done.')\n",
    "\n",
    "# Confidence of the final module: peak softmax weight of theta, then of the\n",
    "# w* weights associated with the selected choice m.\n",
    "head = model.submodules[-1]\n",
    "print(head.theta.softmax(dim=-1).max())\n",
    "m = head.theta.argmax()\n",
    "if m == 0:\n",
    "    print(head.w1.softmax(dim=-1).max())\n",
    "elif m == 2:\n",
    "    print(head.w5.softmax(dim=-1).max())\n",
    "    print(head.w6.softmax(dim=-1).max())\n",
    "    print(head.w7.softmax(dim=-1).max())\n",
    "# NOTE(review): there is no branch for m == 1, so only theta is reported in\n",
    "# that case - confirm which weights LogicModule2 uses for that choice."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.10 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "b081a66ee97bd2b6a16f43955f1d810b7ea816d6eaeb65e157ef9e038445f0c6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
