{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import pickle\n",
    "from Dataset import OntologyDataset\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "from tqdm import tqdm\n",
    "import re\n",
    "from torch.utils.data.sampler import RandomSampler\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
    "from torch.autograd import Variable\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "import os\n",
    "\n",
    "info_path = \"dfalc_data\"\n",
    "file_name = \"toyOntology.owl\"\n",
    "params = {\n",
    "        \"conceptPath\": os.path.join(info_path,file_name+\"_concepts.txt\"),\n",
    "        \"rolePath\": os.path.join(info_path,file_name+\"_roles.txt\"),\n",
    "        \"individualPath\": os.path.join(info_path,file_name+\"_individuals.txt\"),\n",
    "        \"normalizationPath\": os.path.join(info_path,file_name+\"_normalization.txt\"),\n",
    "        \"batchSize\": 3,\n",
    "        \"epochSize\":10,\n",
    "        \"earlystopping\":10,\n",
    "        \"dist\": \"minkowski\",\n",
    "        \"norm\":1,\n",
    "        \"norm_rate\":0.5,\n",
    "        \"norm_rate2\":0\n",
    "    }\n",
    "to_train = False\n",
    "\n",
    "save_path = \"dfalc_data\"\n",
    "if to_train: save_path = os.path.join(save_path,\"training\")\n",
    "else: save_path = os.path.join(save_path,\"testing\")\n",
    "save_path += \"/toyOntology_\"\n",
    "dataset = OntologyDataset(params,save_path)\n",
    "\n",
    "individualSize = 8\n",
    "\n",
    "\n",
    "print(dataset.concept2id)\n",
    "print(dataset.role2id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#exists NF3\n",
    "cEmb_init = torch.zeros(dataset.conceptSize, individualSize)\n",
    "rEmb_init = torch.zeros(4, individualSize, individualSize)\n",
    "cEmb_init[2] = torch.FloatTensor([0,0,0,0.9,0.9,0,0,0.9])\n",
    "cEmb_init[5] = torch.FloatTensor([0.9,0,0,0,0,0,0.9,0])\n",
    "rEmb_init[3,0,1], rEmb_init[3,2,3], rEmb_init[3,4,5], rEmb_init[3,6,7] = 0.9,0.9,0.9,0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#forall NF4\n",
    "cEmb_init[3] = torch.FloatTensor([0,0,0,0.9,0.9,0,0,0.9])\n",
    "cEmb_init[7] = torch.FloatTensor([0.9,0,0,0,0,0,0.9,0])\n",
    "rEmb_init[2,0,1], rEmb_init[2,2,3], rEmb_init[2,4,5], rEmb_init[2,6,7] = 0.9,0.9,0.9,0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#exists NF5\n",
    "cEmb_init[0] = torch.FloatTensor([0,0,0,0.9,0.9,0,0,0.9])\n",
    "cEmb_init[4] = torch.FloatTensor([0.9,0,0,0,0,0,0.9,0])\n",
    "rEmb_init[0,0,1], rEmb_init[0,2,3], rEmb_init[0,4,5], rEmb_init[0,6,7] = 0.9,0.9,0.9,0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#forall NF6\n",
    "cEmb_init[1] = torch.FloatTensor([0,0,0,0.9,0.9,0,0,0.9])\n",
    "cEmb_init[6] = torch.FloatTensor([0.9,0,0,0,0,0,0.9,0])\n",
    "rEmb_init[1,0,1], rEmb_init[1,2,3], rEmb_init[1,4,5], rEmb_init[1,6,7] = 0.9,0.9,0.9,0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "    \n",
    "class DFALC(nn.Module):\n",
    "    def __init__(self, params, conceptSize, roleSize, cEmb_init, rEmb_init,  device, name=\"Godel\"):\n",
    "        super().__init__()\n",
    "        self.params = params\n",
    "        self.conceptSize, self.roleSize = conceptSize, roleSize\n",
    "        self.device = device\n",
    "        self.cEmb = nn.Parameter(torch.tensor(cEmb_init))\n",
    "        self.rEmb = nn.Parameter(torch.tensor(rEmb_init))\n",
    "        self.relu = torch.nn.ReLU()\n",
    "        # self.c_mask, self.r_mask = self.get_mask()\n",
    "        self.logic_name = name\n",
    "        self.epsilon = 1e-2\n",
    "        self.p=2\n",
    "        self.alpha=0.8\n",
    "\n",
    "    def to_sparse(self, A):\n",
    "        return torch.sparse_coo_tensor(np.where(A!=0),A[np.where(A!=0)],A.shape)\n",
    "    \n",
    "    def index_sparse(self, A, idx):\n",
    "        return torch.where(A.indices[0] in idx)\n",
    "    \n",
    "    def pi_0(self, x):\n",
    "        return (1-self.epsilon)*x+self.epsilon\n",
    "    \n",
    "    def pi_1(self, x):\n",
    "        return (1-self.epsilon)*x\n",
    "    \n",
    "    \n",
    "    def neg(self, x, negf):\n",
    "        negf = negf.unsqueeze(1)\n",
    "        # print(\"negf: \",negf.shape)\n",
    "        # print(\"x: \",x.shape)\n",
    "        negf2 = negf*(-2) + 1\n",
    "        # print(\"negf2: \",negf2)\n",
    "        # print(\"negf2: \",negf2.shape)\n",
    "        \n",
    "        return negf2*x\n",
    "        \n",
    "    def t_norm(self, x, y):\n",
    "        if self.logic_name == \"Godel\":\n",
    "            return torch.minimum(x,y)\n",
    "        elif self.logic_name == \"LTN\":\n",
    "            return self.pi_0(x)*self.pi_0(y)\n",
    "        # elif self.logic_name == \"Product\":\n",
    "        #     return x*y\n",
    "        \n",
    "    def t_cnorm(self, x, y):\n",
    "        if self.logic_name == \"Godel\":\n",
    "            return torch.maximum(x,y)\n",
    "        elif self.logic_name == \"LTN\":\n",
    "            a = self.pi_1(x)\n",
    "            b = self.pi_1(y)\n",
    "            return a+b-a*b\n",
    "        # elif self.logic_name == \"Product\":\n",
    "        #     return x+y-x*y\n",
    "\n",
    "    def forall(self, r, x):\n",
    "        if self.logic_name == \"Godel\":\n",
    "            return torch.min(self.t_cnorm(1-r,x.unsqueeze(1).expand(r.shape)),2).values\n",
    "        elif self.logic_name == \"LTN\":\n",
    "            return 1-torch.pow(torch.mean(torch.pow(1-self.pi_1(self.t_cnorm(r,x.unsqueeze(1).expand(r.shape))),self.p),2),1/self.p)\n",
    "        # elif self.logic_name == \"Product\":\n",
    "        #     return torch.prod(torch.max(-b,0),2)\n",
    "    \n",
    "    def exist(self, r, x):\n",
    "        if self.logic_name == \"Godel\":\n",
    "            return torch.max(self.t_norm(r,x.unsqueeze(1).expand(r.shape)),2).values\n",
    "        elif self.logic_name == \"LTN\":\n",
    "            return torch.pow(torch.mean(torch.pow(self.pi_0(self.t_norm(r,x.unsqueeze(1).expand(r.shape))),self.p),2),1/self.p)\n",
    "    \n",
    "    def L2(self, x, dim=1):\n",
    "        return torch.sqrt(torch.sum((x)**2, dim))\n",
    "    \n",
    "    def L2_dist(self, x, y, dim=1):\n",
    "        return torch.sqrt(torch.sum((x-y)**2, dim))\n",
    "    \n",
    "    def L1(self,x,dim=1):\n",
    "        return torch.sum(torch.abs(x),dim)\n",
    "    \n",
    "    def L1_dist(self,x,y,dim=1):\n",
    "        return torch.sum(torch.abs(x-y),dim)\n",
    "    \n",
    "    def HierarchyLoss(self, lefte, righte):\n",
    "        return torch.mean(self.L1(self.relu(lefte-righte)))\n",
    "\n",
    "\n",
    "        \n",
    "    def forward(self, batch, atype, device):\n",
    "        left, right, negf = batch\n",
    "        # print(\"here negf: \", negf.shape)\n",
    "        # print('here left: ',left.shape)\n",
    "        # print(\"here right: \", right.shape)\n",
    "        \n",
    "        loss, lefte, righte, loss1 = None, None, None, None\n",
    "        \n",
    "        self.cEmb[-1,:].detach().masked_fill_(self.cEmb[-1,:].gt(0.0),1.0)\n",
    "        self.cEmb[-2,:].detach().masked_fill_(self.cEmb[-2,:].lt(1),0.0)\n",
    "        \n",
    "        \n",
    "        if atype == 0:\n",
    "            lefte = self.neg(self.cEmb[left],-negf[:,0])\n",
    "            righte = self.neg(self.cEmb[right],negf[:,1])\n",
    "            shape = lefte.shape\n",
    "            # b_c_mask = self.c_mask[left] \n",
    "            \n",
    "        elif atype == 1:\n",
    "            righte = self.neg(self.cEmb[right], negf[:,2])\n",
    "            shape = righte.shape\n",
    "            lefte = self.t_norm(self.neg(self.cEmb[left[:,0]],negf[:,0]), self.neg(self.cEmb[left[:,1]],negf[:,1]))\n",
    "            loss1 = -righte*(self.relu(lefte-righte).detach())\n",
    "            \n",
    "        elif atype == 2:\n",
    "            lefte = self.neg(self.cEmb[left], negf[:,0])\n",
    "            shape = lefte.shape\n",
    "            righte = self.t_cnorm(self.neg(self.cEmb[right[:,0]],negf[:,1]), self.neg(self.cEmb[right[:,1]],negf[:,2]))\n",
    "            loss1 = -lefte*(self.relu(lefte-righte).detach())\n",
    "\n",
    "        elif atype == 3:\n",
    "            lefte = self.neg(self.cEmb[left], negf[:,0])\n",
    "            shape = lefte.shape\n",
    "            righte = self.exist(self.rEmb[right[:,0]], self.neg(self.cEmb[right[:,1]],negf[:,1]))\n",
    "\n",
    "        elif atype == 4:\n",
    "            lefte = self.neg(self.cEmb[left], negf[:,0])\n",
    "            shape = lefte.shape\n",
    "            righte = self.forall(self.rEmb[right[:,0]],self.neg(self.cEmb[right[:,1]], negf[:,1]))\n",
    "            \n",
    "            \n",
    "        elif atype == 5:\n",
    "            righte = self.neg(self.cEmb[right], negf[:,1])\n",
    "            shape = righte.shape\n",
    "            lefte = self.exist(self.rEmb[left[:,0]],self.neg(self.cEmb[left[:,1]], negf[:,0]))\n",
    "            lefte2 = self.neg(self.cEmb[left[:,1]], negf[:,0])\n",
    "            righte2 = torch.matmul(righte, self.rEmb[left[:,0]]).squeeze(2)\n",
    "            # righte1 = torch.matmul(self.rEmb[left[:,0]],lefte2.T).squeeze(2)\n",
    "            righte1 = torch.bmm(self.rEmb[left[:,0]],lefte2.unsqueeze(2)).squeeze()\n",
    "            # print(\"hehre: \",((1-self.relu(righte-self.alpha))*self.relu(righte1-righte)).detach())\n",
    "            loss1 =  (1-righte)*(((1-self.relu(righte-self.alpha))*self.relu(righte1-self.alpha)*self.relu(self.alpha-righte)).detach()) #self.relu(torch.max(self.rEmb[left[:,0]],1).values-self.alpha)*\n",
    "\n",
    "        elif atype == 6:\n",
    "            righte = self.neg(self.cEmb[right], negf[:,1])\n",
    "            shape = righte.shape\n",
    "            lefte = self.forall(self.rEmb[left[:,0]],self.neg(self.cEmb[left[:,1]], negf[:,0]))\n",
    "            lefte2 = self.neg(self.cEmb[left[:,1]], negf[:,0])\n",
    "            # righte2 = torch.matmul(righte, self.rEmb[left[:,0]]).squeeze(2)\n",
    "            righte2 = torch.bmm(righte.unsqueeze(1), self.rEmb[left[:,0]]).squeeze()\n",
    "            loss1 = (1-lefte2)*(((1-self.relu(lefte2-self.alpha))*self.relu(righte2-self.alpha)*self.relu(self.alpha-lefte2)).detach())\n",
    "        # print(\"lefte: \", lefte)\n",
    "        # print(\"righte: \", righte)\n",
    "        # print(\"r: \", torch.max(self.rEmb[left[:,0]],1).values)\n",
    "        # print(\"lefte2: \", lefte2)\n",
    "        # print(\"righte2: \", righte2)\n",
    "        # print(\"righte1: \", righte1)\n",
    "        # print(\"loss1: \", loss1)\n",
    "        # print(\"atype: \",atype)\n",
    "        loss = self.HierarchyLoss(lefte, righte)\n",
    "        return loss#torch.mean(torch.sum(loss1,1))\n",
    "        # print(\"loss: \", self.relu(lefte-righte)+loss1)\n",
    "            \n",
    "        # return torch.mean(torch.sum(self.relu(lefte-righte),1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "human tensor([0.9000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9000, 0.0000],\n",
      "       grad_fn=<SelectBackward>)\n",
      "grass tensor([0.0000, 0.0000, 0.0000, 0.9000, 0.9000, 0.0000, 0.0000, 0.9000],\n",
      "       grad_fn=<SelectBackward>)\n",
      "tensor([[0.0000, 0.9000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.9000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],\n",
      "       grad_fn=<SelectBackward>)\n",
      "tensor(5) tensor([3, 2]) tensor([0, 0])\n",
      "loss:  tensor(1.8000, grad_fn=<MeanBackward0>)\n",
      "tensor(5) tensor([3, 2]) tensor([0, 0])\n",
      "loss:  tensor(1.7000, grad_fn=<MeanBackward0>)\n",
      "tensor(5) tensor([3, 2]) tensor([0, 0])\n",
      "loss:  tensor(1.5256, grad_fn=<MeanBackward0>)\n",
      "tensor(5) tensor([3, 2]) tensor([0, 0])\n",
      "loss:  tensor(1.3397, grad_fn=<MeanBackward0>)\n",
      "tensor(5) tensor([3, 2]) tensor([0, 0])\n",
      "loss:  tensor(1.1488, grad_fn=<MeanBackward0>)\n",
      "tensor(5) tensor([3, 2]) tensor([0, 0])\n",
      "loss:  tensor(1.0029, grad_fn=<MeanBackward0>)\n",
      "tensor(5) tensor([3, 2]) tensor([0, 0])\n",
      "loss:  tensor(0.8441, grad_fn=<MeanBackward0>)\n",
      "tensor(5) tensor([3, 2]) tensor([0, 0])\n",
      "loss:  tensor(0.6729, grad_fn=<MeanBackward0>)\n",
      "tensor(5) tensor([3, 2]) tensor([0, 0])\n",
      "loss:  tensor(0.4938, grad_fn=<MeanBackward0>)\n",
      "tensor(5) tensor([3, 2]) tensor([0, 0])\n",
      "loss:  tensor(0.3091, grad_fn=<MeanBackward0>)\n",
      "tensor(5) tensor([3, 2]) tensor([0, 0])\n",
      "loss:  tensor(0.1364, grad_fn=<MeanBackward0>)\n",
      "tensor(5) tensor([3, 2]) tensor([0, 0])\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "human tensor([0.3500, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3500, 0.0000],\n",
      "       grad_fn=<SelectBackward>)\n",
      "grass tensor([0.3796, 0.0000, 0.0000, 0.9000, 0.9000, 0.0000, 0.0000, 0.9000],\n",
      "       grad_fn=<SelectBackward>)\n",
      "tensor([[0.3624, 0.9000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.9000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.3624, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],\n",
      "       grad_fn=<SelectBackward>)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\92803\\.conda\\envs\\rl_hw\\lib\\site-packages\\ipykernel_launcher.py:12: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  if sys.path[0] == '':\n",
      "c:\\Users\\92803\\.conda\\envs\\rl_hw\\lib\\site-packages\\ipykernel_launcher.py:13: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  del sys.path[0]\n"
     ]
    }
   ],
   "source": [
    "import torch.optim as optim\n",
    "device = torch.device(\"cpu\")\n",
    "model = DFALC(params,dataset.conceptSize, dataset.roleSize, cEmb_init=cEmb_init, rEmb_init=rEmb_init, device=device)\n",
    "dataset.mode = 3\n",
    "optimizer = optim.Adam(model.parameters(),lr=0.05)\n",
    "print('human', model.cEmb[5])\n",
    "print('grass',model.cEmb[2])\n",
    "print(model.rEmb[3])\n",
    "for e in range(30):\n",
    "    for l,r,n in dataset:\n",
    "        print(l,r,n)\n",
    "        loss = model([l.unsqueeze(0),r.unsqueeze(0),n.unsqueeze(0)],dataset.mode,device)\n",
    "        print(\"loss: \", loss)\n",
    "        if loss.item() == 0: break\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    if loss.item() == 0: break\n",
    "print('human', model.cEmb[5])\n",
    "print('grass',model.cEmb[2])\n",
    "print(model.rEmb[3])\n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "teacher tensor([0.9000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9000, 0.0000],\n",
      "       grad_fn=<SelectBackward>)\n",
      "class tensor([0.0000, 0.0000, 0.0000, 0.9000, 0.9000, 0.0000, 0.0000, 0.9000],\n",
      "       grad_fn=<SelectBackward>)\n",
      "tensor([[0.0000, 0.9000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.9000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],\n",
      "       grad_fn=<SelectBackward>)\n",
      "loss:  tensor(0.8000, grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0.7000, grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0.6000, grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0.5000, grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0.4000, grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0.3000, grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0.2000, grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0.1000, grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "teacher tensor([0.0917, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9000, 0.0000],\n",
      "       grad_fn=<SelectBackward>)\n",
      "class tensor([0.0000, 0.0000, 0.0000, 0.9000, 0.9000, 0.0000, 0.0000, 0.9000],\n",
      "       grad_fn=<SelectBackward>)\n",
      "tensor([[0.0000, 0.0917, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.9000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],\n",
      "       grad_fn=<SelectBackward>)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\92803\\.conda\\envs\\rl_hw\\lib\\site-packages\\ipykernel_launcher.py:12: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  if sys.path[0] == '':\n",
      "c:\\Users\\92803\\.conda\\envs\\rl_hw\\lib\\site-packages\\ipykernel_launcher.py:13: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  del sys.path[0]\n"
     ]
    }
   ],
   "source": [
    "import torch.optim as optim\n",
    "device = torch.device(\"cpu\")\n",
    "model = DFALC(params,dataset.conceptSize, dataset.roleSize, cEmb_init=cEmb_init, rEmb_init=rEmb_init, device=device)\n",
    "dataset.mode = 4\n",
    "optimizer = optim.Adam(model.parameters(),lr=0.05)\n",
    "print('teacher', model.cEmb[7])\n",
    "print('class',model.cEmb[3])\n",
    "print(model.rEmb[2])\n",
    "for e in range(30):\n",
    "    for l,r,n in dataset:\n",
    "        loss = model([l.unsqueeze(0),r.unsqueeze(0),n.unsqueeze(0)],dataset.mode,device)\n",
    "        print(\"loss: \", loss)\n",
    "        # if loss.item() == 0: break\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    # if loss.item() == 0: break\n",
    "print('teacher', model.cEmb[7])\n",
    "print('class',model.cEmb[3])\n",
    "print(model.rEmb[2])\n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([[0.0000, 0.0000, 0.0000, 0.9000, 0.9000, 0.0000, 0.0000, 0.9000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.9000, 0.9000, 0.0000, 0.0000, 0.9000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.9000, 0.9000, 0.0000, 0.0000, 0.9000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.9000, 0.9000, 0.0000, 0.0000, 0.9000],\n",
       "        [0.9000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9000, 0.0000],\n",
       "        [0.9000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9000, 0.0000],\n",
       "        [0.9000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9000, 0.0000],\n",
       "        [0.0917, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9000, 0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],\n",
       "       requires_grad=True)"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.cEmb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TV tensor([0.0000, 0.0000, 0.0000, 0.9000, 0.9000, 0.0000, 0.0000, 0.9000],\n",
      "       grad_fn=<SelectBackward>)\n",
      "screen tensor([0.9000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9000, 0.0000],\n",
      "       grad_fn=<SelectBackward>)\n",
      "tensor([[0.0000, 0.9000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.9000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],\n",
      "       grad_fn=<SelectBackward>)\n",
      "loss:  tensor(0.9000, grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0.8000, grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0.7000, grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0.6000, grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0.5000, grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0.4000, grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0.3000, grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0.2000, grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0.1000, grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "loss:  tensor(0., grad_fn=<MeanBackward0>)\n",
      "TV tensor([0.0000, 0.0000, 0.0000, 0.0400, 0.9000, 0.0000, 0.0000, 0.9000],\n",
      "       grad_fn=<SelectBackward>)\n",
      "screen tensor([0.9000, 0.0000, 0.8600, 0.0000, 0.0000, 0.0000, 0.9000, 0.0000],\n",
      "       grad_fn=<SelectBackward>)\n",
      "tensor([[0.0000, 0.9000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.9000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],\n",
      "       grad_fn=<SelectBackward>)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\92803\\.conda\\envs\\rl_hw\\lib\\site-packages\\ipykernel_launcher.py:12: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  if sys.path[0] == '':\n",
      "c:\\Users\\92803\\.conda\\envs\\rl_hw\\lib\\site-packages\\ipykernel_launcher.py:13: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  del sys.path[0]\n"
     ]
    }
   ],
   "source": [
    "import torch.optim as optim\n",
    "device = torch.device(\"cpu\")\n",
    "model = DFALC(params,dataset.conceptSize, dataset.roleSize, cEmb_init=cEmb_init, rEmb_init=rEmb_init, device=device)\n",
    "dataset.mode = 5\n",
    "optimizer = optim.Adam(model.parameters(),lr=0.05)\n",
    "print('TV', model.cEmb[0])\n",
    "print('screen',model.cEmb[4])\n",
    "print(model.rEmb[0])\n",
    "for e in range(30):\n",
    "    for l,r,n in dataset:\n",
    "        loss = model([l.unsqueeze(0),r.unsqueeze(0),n.unsqueeze(0)],dataset.mode,device)\n",
    "        print(\"loss: \", loss)\n",
    "        # if loss.item() == 0: break\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    # if loss.item() == 0: break\n",
    "print('TV', model.cEmb[0])\n",
    "print('screen',model.cEmb[4])\n",
    "print(model.rEmb[0])\n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "parent tensor([0.9000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9000, 0.0000],\n",
      "       grad_fn=<SelectBackward>)\n",
      "child tensor([0.0000, 0.0000, 0.0000, 0.9000, 0.9000, 0.0000, 0.0000, 0.9000],\n",
      "       grad_fn=<SelectBackward>)\n",
      "tensor([[0.0000, 0.9000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.9000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],\n",
      "       grad_fn=<SelectBackward>)\n",
      "parent tensor([0.2396, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9000, 0.0000],\n",
      "       grad_fn=<SelectBackward>)\n",
      "child tensor([0.4030, 1.0024, 0.4030, 1.1866, 0.9000, 1.0024, 1.0024, 1.1866],\n",
      "       grad_fn=<SelectBackward>)\n",
      "tensor([[0.0000, 1.3030, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.7148, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 1.3030, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.2866, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9000, 0.0000, 0.0000],\n",
      "        [0.7148, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.7148, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.2866, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],\n",
      "       grad_fn=<SelectBackward>)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\92803\\.conda\\envs\\rl_hw\\lib\\site-packages\\ipykernel_launcher.py:12: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  if sys.path[0] == '':\n",
      "c:\\Users\\92803\\.conda\\envs\\rl_hw\\lib\\site-packages\\ipykernel_launcher.py:13: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  del sys.path[0]\n"
     ]
    }
   ],
   "source": [
    "import torch.optim as optim\n",
    "device = torch.device(\"cpu\")\n",
    "model = DFALC(params,dataset.conceptSize, dataset.roleSize, cEmb_init=cEmb_init, rEmb_init=rEmb_init, device=device)\n",
    "dataset.mode = 6\n",
    "optimizer = optim.Adam(model.parameters(),lr=0.05)\n",
    "print('parent', model.cEmb[6])\n",
    "print('child',model.cEmb[1])\n",
    "print(model.rEmb[1])\n",
    "for e in range(30):\n",
    "    for l,r,n in dataset:\n",
    "        loss = model([l.unsqueeze(0),r.unsqueeze(0),n.unsqueeze(0)],dataset.mode,device)\n",
    "        # print(\"loss: \", loss)\n",
    "        # if loss.item() == 0: break\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    # if loss.item() == 0: break\n",
    "print('parent', model.cEmb[6])\n",
    "print('child',model.cEmb[1])\n",
    "print(model.rEmb[1])\n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.0 ('rl_hw')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "b5d4ea6110d76bf407abdf3fc85b4f9a1bbb4f7f6454d667a509d28831b3322d"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
