{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "source": [
    "## Load SII Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "ename": "FileNotFoundError",
     "evalue": "dfalc_data/testing/features.csv not found.",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mFileNotFoundError\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[1], line 118\u001b[0m\n\u001b[1;32m    113\u001b[0m     partOf_of_pair_of_data \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39marray([idx_whole_for_data[i] \u001b[38;5;241m==\u001b[39m j \u001b[38;5;28;01mfor\u001b[39;00m p \u001b[38;5;129;01min\u001b[39;00m pics \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m pics[p] \u001b[38;5;28;01mfor\u001b[39;00m j \u001b[38;5;129;01min\u001b[39;00m pics[p]])\n\u001b[1;32m    115\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m data, pairs_of_data, types_of_data, partOf_of_pair_of_data, pairs_of_bb_idxs, pics\n\u001b[0;32m--> 118\u001b[0m data, pairs_of_data, types_of_data, partOf_of_pairs_of_data, pairs_of_bb_idxs, pics \u001b[38;5;241m=\u001b[39m get_data(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest\u001b[39m\u001b[38;5;124m\"\u001b[39m,max_rows\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1000000000\u001b[39m)\n\u001b[1;32m    119\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mindividual size: \u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mconcept size: \u001b[39m\u001b[38;5;124m\"\u001b[39m, data\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m1\u001b[39m]\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m4\u001b[39m)\n",
      "Cell \u001b[0;32mIn[1], line 72\u001b[0m, in \u001b[0;36mget_data\u001b[0;34m(train_or_test_swritch, max_rows)\u001b[0m\n\u001b[1;32m     70\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m train_or_test_swritch \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m     71\u001b[0m     data_dir \u001b[38;5;241m=\u001b[39m data_testing_dir\n\u001b[0;32m---> 72\u001b[0m data \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mgenfromtxt(data_dir\u001b[38;5;241m+\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfeatures.csv\u001b[39m\u001b[38;5;124m\"\u001b[39m,delimiter\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m,\u001b[39m\u001b[38;5;124m\"\u001b[39m,max_rows\u001b[38;5;241m=\u001b[39mmax_rows)\n\u001b[1;32m     73\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m add_noise_flag:\n\u001b[1;32m     74\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mexists(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdfalc_data/training/masked\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m+\u001b[39m\u001b[38;5;28mstr\u001b[39m(mask_rate)\u001b[38;5;241m+\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_type.pkl\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n",
      "File \u001b[0;32m/data2/kristal/anaconda/lib/python3.11/site-packages/numpy/lib/npyio.py:1980\u001b[0m, in \u001b[0;36mgenfromtxt\u001b[0;34m(fname, dtype, comments, delimiter, skip_header, skip_footer, converters, missing_values, filling_values, usecols, names, excludelist, deletechars, replace_space, autostrip, case_sensitive, defaultfmt, unpack, usemask, loose, invalid_raise, max_rows, encoding, ndmin, like)\u001b[0m\n\u001b[1;32m   1978\u001b[0m     fname \u001b[38;5;241m=\u001b[39m os_fspath(fname)\n\u001b[1;32m   1979\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(fname, \u001b[38;5;28mstr\u001b[39m):\n\u001b[0;32m-> 1980\u001b[0m     fid \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mlib\u001b[38;5;241m.\u001b[39m_datasource\u001b[38;5;241m.\u001b[39mopen(fname, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mrt\u001b[39m\u001b[38;5;124m'\u001b[39m, encoding\u001b[38;5;241m=\u001b[39mencoding)\n\u001b[1;32m   1981\u001b[0m     fid_ctx \u001b[38;5;241m=\u001b[39m contextlib\u001b[38;5;241m.\u001b[39mclosing(fid)\n\u001b[1;32m   1982\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n",
      "File \u001b[0;32m/data2/kristal/anaconda/lib/python3.11/site-packages/numpy/lib/_datasource.py:193\u001b[0m, in \u001b[0;36mopen\u001b[0;34m(path, mode, destpath, encoding, newline)\u001b[0m\n\u001b[1;32m    156\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m    157\u001b[0m \u001b[38;5;124;03mOpen `path` with `mode` and return the file object.\u001b[39;00m\n\u001b[1;32m    158\u001b[0m \n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    189\u001b[0m \n\u001b[1;32m    190\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m    192\u001b[0m ds \u001b[38;5;241m=\u001b[39m DataSource(destpath)\n\u001b[0;32m--> 193\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m ds\u001b[38;5;241m.\u001b[39mopen(path, mode, encoding\u001b[38;5;241m=\u001b[39mencoding, newline\u001b[38;5;241m=\u001b[39mnewline)\n",
      "File \u001b[0;32m/data2/kristal/anaconda/lib/python3.11/site-packages/numpy/lib/_datasource.py:533\u001b[0m, in \u001b[0;36mDataSource.open\u001b[0;34m(self, path, mode, encoding, newline)\u001b[0m\n\u001b[1;32m    530\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m _file_openers[ext](found, mode\u001b[38;5;241m=\u001b[39mmode,\n\u001b[1;32m    531\u001b[0m                               encoding\u001b[38;5;241m=\u001b[39mencoding, newline\u001b[38;5;241m=\u001b[39mnewline)\n\u001b[1;32m    532\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 533\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mFileNotFoundError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpath\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m not found.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
      "\u001b[0;31mFileNotFoundError\u001b[0m: dfalc_data/testing/features.csv not found."
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import csv\n",
    "import os\n",
    "import pickle\n",
    "\n",
    "# Generate data in different types\n",
    "data_training_dir = \"dfalc_data/training/\"\n",
    "data_testing_dir = \"dfalc_data/testing/\"\n",
    "zero_distance_threshold = 6\n",
    "number_of_features = 65\n",
    "name = \"animal\"\n",
    "types = {}\n",
    "add_noise_flag = True\n",
    "mask_rate =0.1\n",
    "with open(\"classes.csv\", \"r\") as f:\n",
    "    cnt = 0\n",
    "    for line in f.readlines():\n",
    "        types[line.strip()] = cnt\n",
    "        cnt += 1\n",
    "        \n",
    "        \n",
    "if name == \"vehicle\":\n",
    "    # uncomment this line for training the vehicle object types\n",
    "    selected_types_name = np.array(['aeroplane','artifact_wing','body','engine','stern','wheel','bicycle','chain_wheel','handlebar','headlight','saddle','bus','bodywork','door','license_plate','mirror','window','car','motorbike','train','coach','locomotive','boat'])\n",
    "\n",
    "if name == \"indoor\":\n",
    "    # uncomment this line for training the indoor object types\n",
    "    selected_types_name = np.array(['bottle','body','cap','pottedplant','plant','pot','tvmonitor','screen']) #'chair','sofa','diningtable'\n",
    "\n",
    "if name == \"animal\":\n",
    "    # uncomment this line for training the animal object types\n",
    "    selected_types_name = np.array(['person','arm','ear','ebrow','foot','hair','hand','mouth','nose','eye','head','leg','neck','torso','cat','tail','bird','animal_wing','beak','sheep','horn','muzzle','cow','dog','horse','hoof'])\n",
    "\n",
    "\n",
    "\n",
    "selected_types = [types[n] for n in selected_types_name]\n",
    "\n",
    "# uncomment this line for training all the object types\n",
    "# selected_types = list(range(len(types)))\n",
    "\n",
    "def add_noise(data, mask_rate = 0.2):\n",
    "    nosied_data = np.array(data, copy=True)\n",
    "    coords = []\n",
    "    for i in range(data.shape[0]):\n",
    "        for j in range(data.shape[1]):\n",
    "            coords.append([i,j])\n",
    "\n",
    "    size = len(coords)\n",
    "    print(\"cEmb masked size: \", int(mask_rate*size))\n",
    "    for i in np.random.choice(size, int(mask_rate*size), replace=False):\n",
    "        x,y = coords[i]\n",
    "        nosied_data[x,y] = np.random.uniform(0,1)\n",
    "    return nosied_data\n",
    "\n",
    "def containment_ratios_between_two_bbxes(bb1, bb2):\n",
    "    bb1_area = (bb1[-2] - bb1[-4]) * (bb1[-1] - bb1[-3])\n",
    "    bb2_area = (bb2[-2] - bb2[-4]) * (bb2[-1] - bb2[-3])\n",
    "    w_intersec = max(0,min([bb1[-2], bb2[-2]]) - max([bb1[-4], bb2[-4]]))\n",
    "    h_intersec = max(0,min([bb1[-1], bb2[-1]]) - max([bb1[-3], bb2[-3]]))\n",
    "    bb_area_intersection = w_intersec * h_intersec\n",
    "    return [float(bb_area_intersection)/bb1_area, float(bb_area_intersection)/bb2_area]\n",
    "\n",
    "def get_data(train_or_test_swritch,max_rows=10000000):\n",
    "    assert train_or_test_swritch == \"train\" or train_or_test_swritch == \"test\"\n",
    "\n",
    "    # Fetching the data from the file system\n",
    "\n",
    "    if train_or_test_swritch == \"train\":\n",
    "        data_dir = data_training_dir\n",
    "    if train_or_test_swritch == \"test\":\n",
    "        data_dir = data_testing_dir\n",
    "    data = np.genfromtxt(data_dir+\"features.csv\",delimiter=\",\",max_rows=max_rows)\n",
    "    if add_noise_flag:\n",
    "        if os.path.exists(\"dfalc_data/training/masked\"+str(mask_rate)+\"_type.pkl\"):\n",
    "            data = pickle.load(open(\"dfalc_data/training/masked\"+str(mask_rate)+\"_type.pkl\",\"rb\"))\n",
    "        else:\n",
    "            data = add_noise(data,mask_rate)\n",
    "            pickle.dump(data,open(\"dfalc_data/training/masked\"+str(mask_rate)+\"_type.pkl\",\"wb\"),2)\n",
    "    types_of_data = np.genfromtxt(data_dir + \"types.csv\", dtype=\"i\", max_rows=max_rows)\n",
    "    idx_whole_for_data = np.genfromtxt(data_dir+ \"partOf.csv\",dtype=\"i\",max_rows=max_rows)\n",
    "    idx_of_cleaned_data = np.where(np.logical_and(\n",
    "        np.all(data[:, -2:] - data[:, -4:-2] >= zero_distance_threshold, axis=1),\n",
    "        np.in1d(types_of_data,selected_types)))[0]\n",
    "    print(\"deleting\", len(data) - len(idx_of_cleaned_data), \"small bb out of\", data.shape[0], \"bb\")\n",
    "    data = data[idx_of_cleaned_data]\n",
    "    data[:, -4:] /= 500\n",
    "\n",
    "    # Cleaning data by removing small bounding boxes and recomputing indexes of partof data\n",
    "\n",
    "    types_of_data = types_of_data[idx_of_cleaned_data]\n",
    "    idx_whole_for_data = idx_whole_for_data[idx_of_cleaned_data]\n",
    "    for i in range(len(idx_whole_for_data)):\n",
    "        if idx_whole_for_data[i] != -1 and idx_whole_for_data[i] in idx_of_cleaned_data:\n",
    "            idx_whole_for_data[i] = np.where(idx_whole_for_data[i] == idx_of_cleaned_data)[0]\n",
    "        else:\n",
    "            idx_whole_for_data[i] = -1\n",
    "\n",
    "    # Grouping bbs that belong to the same picture\n",
    "\n",
    "    pics = {} #记录了每张图片对应的bbox，即data的id\n",
    "    for i in range(len(data)):\n",
    "        if data[i][0] in pics:\n",
    "            pics[data[i][0]].append(i)\n",
    "        else:\n",
    "            pics[data[i][0]] = [i]\n",
    "\n",
    "    pairs_of_data = np.array(\n",
    "        [np.concatenate((data[i][1:], data[j][1:], containment_ratios_between_two_bbxes(data[i], data[j]))) for p in\n",
    "         pics for i in pics[p] for j in pics[p]])\n",
    "\n",
    "    pairs_of_bb_idxs = np.array([(i,j) for p in pics for i in pics[p] for j in pics[p]]) #枚举同一张图片里不同objects(bbox) pair\n",
    "\n",
    "    partOf_of_pair_of_data = np.array([idx_whole_for_data[i] == j for p in pics for i in pics[p] for j in pics[p]])\n",
    "\n",
    "    return data, pairs_of_data, types_of_data, partOf_of_pair_of_data, pairs_of_bb_idxs, pics\n",
    "\n",
    "\n",
    "data, pairs_of_data, types_of_data, partOf_of_pairs_of_data, pairs_of_bb_idxs, pics = get_data(\"test\",max_rows=1000000000)\n",
    "print(\"individual size: \", data.shape[0], \"concept size: \", data.shape[1]-4)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'selected_types' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[5], line 16\u001b[0m\n\u001b[1;32m     13\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mnn\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mfunctional\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mF\u001b[39;00m\n\u001b[1;32m     14\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mos\u001b[39;00m\n\u001b[0;32m---> 16\u001b[0m conceptSize \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlen\u001b[39m(selected_types)\n\u001b[1;32m     17\u001b[0m roleSize \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m     18\u001b[0m individualSize \u001b[38;5;241m=\u001b[39m data\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m]\n",
      "\u001b[0;31mNameError\u001b[0m: name 'selected_types' is not defined"
     ]
    }
   ],
   "source": [
    "from numpy import require\n",
    "from Dataset import OntologyDataset\n",
    "from model import DFALC, DFALC2\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "from tqdm import tqdm\n",
    "import re\n",
    "from torch.utils.data.sampler import RandomSampler\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
    "from torch.autograd import Variable\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "import os\n",
    "\n",
    "conceptSize = len(selected_types)\n",
    "roleSize = 1\n",
    "individualSize = data.shape[0]\n",
    "cEmb_candid = torch.Tensor(data[:,1:-4])\n",
    "rEmb_candid = torch.zeros(1, individualSize, individualSize)\n",
    "partOf_of_pairs_idx = np.where(partOf_of_pairs_of_data)[0]\n",
    "for idx,(i,j) in enumerate(pairs_of_bb_idxs):\n",
    "    i_partof_j_p, j_partof_i_p = pairs_of_data[idx][-2], pairs_of_data[idx][-1]\n",
    "    if i_partof_j_p == 1: j_partof_i_p = 0\n",
    "    if j_partof_i_p == 1: i_partof_j_p = 0\n",
    "    rEmb_candid[0,i,j] = i_partof_j_p\n",
    "    rEmb_candid[0,j,i] = j_partof_i_p\n",
    "\n",
    "info_path = \"dfalc_data\"\n",
    "file_name = \"PascalPartOntology_\"+name+\".owl\"\n",
    "with open(os.path.join(info_path,file_name+\"_roles.txt\"),\"w\") as f:\n",
    "    f.write(\"http://www.w3.org/2002/07/partOf\")\n",
    "with open(os.path.join(info_path,file_name+\"_individuals.txt\"),\"w\") as f:\n",
    "    individuals = []\n",
    "    for p in pics:\n",
    "        for i in pics[p]:\n",
    "            f.write(str(i) + \"\\n\")\n",
    "\n",
    "params = {\n",
    "        \"conceptPath\": os.path.join(info_path,file_name+\"_concepts.txt\"),\n",
    "        \"rolePath\": os.path.join(info_path,file_name+\"_roles.txt\"),\n",
    "        \"individualPath\": os.path.join(info_path,file_name+\"_individuals.txt\"),\n",
    "        \"normalizationPath\": os.path.join(info_path,file_name+\"_normalization.txt\"),\n",
    "        \"batchSize\": 3,\n",
    "        \"epochSize\":10,\n",
    "        \"earlystopping\":10,\n",
    "        \"dist\": \"minkowski\",\n",
    "        \"norm\":1,\n",
    "        \"norm_rate\":0.5,\n",
    "        \"norm_rate2\":0\n",
    "    }\n",
    "to_train = False\n",
    "\n",
    "save_path = \"dfalc_data\"\n",
    "if to_train: save_path = os.path.join(save_path,\"training\")\n",
    "else: save_path = os.path.join(save_path,\"testing\")\n",
    "save_path += \"/PascalPartOntology_\"\n",
    "dataset = OntologyDataset(params,save_path)\n",
    "\n",
    "cEmb_init = torch.zeros(dataset.conceptSize-2, individualSize)\n",
    "rEmb_init = torch.zeros(1, individualSize, individualSize)\n",
    "# cEmb_init.fill_(0.5)\n",
    "# rEmb_init.fill_(0.5)\n",
    "\n",
    "true_rEmb = torch.zeros(1, individualSize, individualSize)\n",
    "for idx, (i,j) in enumerate(pairs_of_bb_idxs[partOf_of_pairs_of_data]):\n",
    "#     ci = dataset.concept2id['http://www.w3.org/2002/07/'+id2types[types_of_data[i]]]\n",
    "#     cj = dataset.concept2id['http://www.w3.org/2002/07/'+id2types[types_of_data[j]]]\n",
    "    true_rEmb[0,i,j] = 1 #\n",
    "    rEmb_init[0,i,j] = rEmb_candid[0,i,j]\n",
    "# rEmb_init = rEmb_candid.clone()\n",
    "\n",
    "\n",
    "\n",
    "def get_definers_initial_semantics(cid, use_partOf=False):\n",
    "    print(cid)\n",
    "    if use_partOf:\n",
    "        emb = torch.zeros(1,individualSize)\n",
    "        print(\"mode: 5\")\n",
    "        dataset.mode = 5\n",
    "        for left,right,neg in dataset:\n",
    "            print(left, right,neg)\n",
    "            if right.item() == cid:\n",
    "                emb = torch.max(torch.minimum(true_rEmb[0],cEmb_init[left[1]].unsqueeze(1).expand(true_rEmb[0].shape)),1).values + 0.1\n",
    "                break\n",
    "        dataset.mode = 0\n",
    "        print(\"mode: 0\")\n",
    "        \n",
    "        for left,right,neg in dataset:\n",
    "            print(left, right,neg)\n",
    "            if right.item() == cid:\n",
    "                emb = torch.maximum(cEmb_init[left.item()], emb) + 0.1\n",
    "       \n",
    "        return emb\n",
    "    else:\n",
    "        dataset.mode = 2\n",
    "        for left,right,neg in dataset:\n",
    "            if left == cid:\n",
    "                if id2concept[right[0].item()][:7] == \"definer\":\n",
    "                    return torch.maximum(get_definers_initial_semantics(right[0],use_partOf=False),cEmb_init[right[1]]) - 0.1\n",
    "                return torch.maximum(cEmb_init[right[0]],cEmb_init[right[1]]) - 0.1\n",
    "                \n",
    "    \n",
    "    return torch.zeros(1,individualSize)\n",
    "\n",
    "cid2typeid = {} \n",
    "interest_cids = []\n",
    "cnt = 0\n",
    "id2concept = {c:idx for idx, c in dataset.concept2id.items()}\n",
    "\n",
    "                \n",
    "for c,idx in dataset.concept2id.items():\n",
    "    cname = c.replace(\"http://www.w3.org/2002/07/\",\"\")\n",
    "    if cname in selected_types_name:\n",
    "        cEmb_init[idx] = cEmb_candid[:,types[cname]]\n",
    "        cid2typeid[cnt] = types[cname]\n",
    "        interest_cids.append(True)\n",
    "        cnt += 1\n",
    "    else:\n",
    "        interest_cids.append(False)\n",
    "interest_cids = interest_cids[:-2]\n",
    "for c,idx in dataset.concept2id.items():\n",
    "    cname = c.replace(\"http://www.w3.org/2002/07/\",\"\")\n",
    "    if cname in selected_types_name: continue\n",
    "    if (idx == dataset.conceptSize -3) or (idx == dataset.conceptSize-4):continue\n",
    "    cEmb_init[idx] = get_definers_initial_semantics(idx,use_partOf=True)\n",
    "    print(torch.max(cEmb_init[idx]))\n",
    "        \n",
    "cEmb_init[-1] = torch.ones(1,individualSize)\n",
    "cEmb_init[-2] = torch.zeros(1,individualSize)\n",
    "\n",
    "id2types = {idx:name for name,idx in types.items()}\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "id2concept"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Baseline Performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'cEmb_init' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[6], line 3\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msklearn\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmetrics\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m confusion_matrix\n\u001b[1;32m      2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msklearn\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmetrics\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m precision_recall_fscore_support\n\u001b[0;32m----> 3\u001b[0m p_baseline \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39margmax(cEmb_init[interest_cids],\u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mnumpy()\n\u001b[1;32m      4\u001b[0m h_baseline \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39marray([cid2typeid[i] \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m p_baseline])\n\u001b[1;32m      6\u001b[0m y_true \u001b[38;5;241m=\u001b[39m []\n",
      "\u001b[0;31mNameError\u001b[0m: name 'cEmb_init' is not defined"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import confusion_matrix\n",
    "from sklearn.metrics import precision_recall_fscore_support\n",
    "p_baseline = torch.argmax(cEmb_init[interest_cids],0).detach().numpy()\n",
    "h_baseline = np.array([cid2typeid[i] for i in p_baseline])\n",
    "\n",
    "y_true = []\n",
    "y_baseline = []\n",
    "# partof_true = np.ravel(true_rEmb.numpy()).T\n",
    "# partof_baseline = (np.ravel(rEmb_init.numpy()).T > .5).astype(int)\n",
    "for idx,t in enumerate(types_of_data):\n",
    "    if t in selected_types:\n",
    "        y_baseline.append(selected_types_name[np.where(selected_types==h_baseline[idx])[0][0]])\n",
    "        y_true.append(selected_types_name[np.where(selected_types==t)[0][0]])\n",
    "\n",
    "\n",
    "\n",
    "# confusion_matrix(y_true, y_pred, labels=selected_types_name)\n",
    "p,r,f,s = precision_recall_fscore_support(y_true, y_baseline)\n",
    "print(precision_recall_fscore_support(y_true, y_baseline,average='macro'))\n",
    "print({selected_types_name[i]:s[i] for i in range(len(selected_types))})\n",
    "# print(\"PartOf: \", precision_recall_fscore_support(partof_true, partof_baseline))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Type Classification - Revising based on DF-ALC with rule-based loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.optim.lr_scheduler import LambdaLR\n",
    "from torch import nn\n",
    "from sklearn.metrics import precision_recall_fscore_support\n",
    "device = torch.device(\"cuda:0\")\n",
    "\n",
    "import pickle\n",
    "params = {\n",
    "        \"conceptPath\": os.path.join(info_path,file_name+\"_concepts.txt\"),\n",
    "        \"rolePath\": os.path.join(info_path,file_name+\"_roles.txt\"),\n",
    "        \"individualPath\": os.path.join(info_path,file_name+\"_individuals.txt\"),\n",
    "        \"normalizationPath\": os.path.join(info_path,file_name+\"_normalization.txt\"),\n",
    "        \"batchSize\": 5,\n",
    "        \"epochSize\":10,\n",
    "        \"earlystopping\":10,\n",
    "        \"dist\": \"minkowski\",\n",
    "        \"norm\":1,\n",
    "        \"norm_rate\":0.3,\n",
    "        \"norm_rate2\":0.3,\n",
    "        \"alpha\": 0.8\n",
    "    }\n",
    "\n",
    "\n",
    "\n",
    "nEpoch = 15000\n",
    "batchSz = 4\n",
    "best_loss = 100\n",
    "last_best_epoch = 0\n",
    "patience = 1\n",
    "best_f1 = 0\n",
    "best_cEmb = None\n",
    "\n",
    "model = DFALC(params, conceptSize, roleSize, cEmb_init, rEmb_init, device).to(device)\n",
    "model = nn.DataParallel(model, device_ids=[0,1])\n",
    "# model= nn.DataParallel(model,device_ids=[0, 1, 2,3])\n",
    "optimizer = optim.Adam(model.parameters(), 2e-4)\n",
    "# scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.65 ** epoch)\n",
    "\n",
    "for epoch in range(1, nEpoch+1):\n",
    "    loss_final, data_cnt = 0, 0\n",
    "#     losses = []\n",
    "    for mode in range(7):\n",
    "        dataset.mode = mode\n",
    "        if len(dataset) == 0: continue\n",
    "#         print(mode, len(dataset))\n",
    "#         if mode ==3 or mode == 4: continue\n",
    "        loader = DataLoader(dataset, sampler = RandomSampler(dataset), batch_size=batchSz)\n",
    "        data_cnt += len(dataset)\n",
    "        for i, batch in enumerate(loader):\n",
    "            ptriplets = [b.to(device) for b in batch]\n",
    "            loss = model(ptriplets, mode, device)#, is_input.contiguous()\n",
    "#             losses.append(loss)\n",
    "#             if loss.item()<1e-6: break\n",
    "            optimizer.zero_grad()\n",
    "            torch.mean(loss).backward(retain_graph=True)\n",
    "            optimizer.step()\n",
    "#         if mode %2 == 0:\n",
    "#             loss = losses[0]\n",
    "# #             print(losses)\n",
    "#             optimizer.zero_grad()\n",
    "#             if mode > 1:\n",
    "#                 loss = torch.mean(torch.sum(torch.concat(losses,0).to(device),1))\n",
    "#             else:\n",
    "#                 for i in range(1,len(losses)):\n",
    "#                     loss += losses[i]\n",
    "#             loss.backward(retain_graph=True)\n",
    "#             optimizer.step()\n",
    "#             losses = []\n",
    "#             data_cnt = 0\n",
    "#     scheduler.step(losses)\n",
    "#             losses.append(loss)\n",
    "    \n",
    "\n",
    "        # err = computeErr(preds_selected.data.detach(), boardSz, unperm)/batchSz\n",
    "#     loss = losses[0]\n",
    "#     for i in range(1,len(losses)):\n",
    "#         loss += losses[i]\n",
    "    \n",
    "    \n",
    "    \n",
    "    p = torch.argmax(model.module.cEmb[torch.BoolTensor(interest_cids)],0).detach().cpu().numpy()\n",
    "    h = np.array([cid2typeid[i] for i in p])\n",
    "    y_pred = []\n",
    "    for idx,t in enumerate(types_of_data):\n",
    "        if t in selected_types:\n",
    "            y_pred.append(selected_types_name[np.where(selected_types==h[idx])[0][0]])\n",
    "    \n",
    "    precision, recall, fbeta_score, support = precision_recall_fscore_support(y_true, y_pred,average='macro')\n",
    "    print(\"Epoch {} loss {:6f} precision {} recall {} f1 {}\".format(epoch,torch.mean(loss).item(),precision, recall, fbeta_score))\n",
    "    if fbeta_score > best_f1:\n",
    "        best_f1 = fbeta_score\n",
    "        best_cEmb = model.module.cEmb.detach().cpu()\n",
    "        # err_final += err\n",
    "#     if loss_final < best_loss:\n",
    "#         best_loss = loss_final\n",
    "#         last_best_epoch = epoch\n",
    "#     else:\n",
    "#         patience -= 1\n",
    "#     if patience == 0:\n",
    "#         break\n",
    "    \n",
    "    # loss_final, err_final = loss_final/len(loader), err_final/len(loader)\n",
    "    # scheduler.step(err_final/len(loader))\n",
    "#     if not to_train:\n",
    "#         print('TESTING SET RESULTS: Average loss: {:.4f}'.format(loss_final))\n",
    "#     if loss_final < 0.05: break\n",
    "\n",
    "#print('memory: {:.2f} MB, cached: {:.2f} MB'.format(torch.cuda.memory_allocated()/2.**20, torch.cuda.memory_cached()/2.**20))\n",
    "\n",
    "# pickle.dump(best_cEmb,open(\"/data/data_wuxuan/SII/dfalc_data/testing/cEmb_\"+name+\".pkl\",\"wb\"))\n",
    "# pickle.dump(model.rEmb,open(\"/data/data_wuxuan/SII/dfalc_data/testing/rEmb_\"+name+\".pkl\",\"wb\"))\n",
    "# pickle.dump(cEmb_candid,open(\"/data/data_wuxuan/SII/dfalc_data/testing/true_cEmb_\"+name+\".pkl\",\"wb\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Add Noise to PartOf Relation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from numpy import require\n",
    "from Dataset import OntologyDataset\n",
    "from model import DFALC, DFALC2\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "from tqdm import tqdm\n",
    "import re\n",
    "from torch.utils.data.sampler import RandomSampler\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
    "from torch.autograd import Variable\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "import os\n",
    "\n",
    "conceptSize = len(selected_types)\n",
    "roleSize = 1\n",
    "individualSize = data.shape[0]\n",
    "cEmb_candid = torch.Tensor(data[:,1:-4])\n",
    "rEmb_candid = torch.zeros(1, individualSize, individualSize)\n",
    "partOf_of_pairs_idx = np.where(partOf_of_pairs_of_data)[0]\n",
    "for idx,(i,j) in enumerate(pairs_of_bb_idxs):\n",
    "    i_partof_j_p, j_partof_i_p = pairs_of_data[idx][-2], pairs_of_data[idx][-1]\n",
    "    if i_partof_j_p == 1: j_partof_i_p = 0\n",
    "    if j_partof_i_p == 1: i_partof_j_p = 0\n",
    "    rEmb_candid[0,i,j] = i_partof_j_p\n",
    "    rEmb_candid[0,j,i] = j_partof_i_p\n",
    "\n",
    "info_path = \"dfalc_data\"\n",
    "file_name = \"PascalPartOntology_\"+name+\"_e.owl\"\n",
    "with open(os.path.join(info_path,file_name+\"_roles.txt\"),\"w\") as f:\n",
    "    f.write(\"http://www.w3.org/2002/07/partOf\")\n",
    "with open(os.path.join(info_path,file_name+\"_individuals.txt\"),\"w\") as f:\n",
    "    individuals = []\n",
    "    for p in pics:\n",
    "        for i in pics[p]:\n",
    "            f.write(str(i) + \"\\n\")\n",
    "\n",
    "params = {\n",
    "        \"conceptPath\": os.path.join(info_path,file_name+\"_concepts.txt\"),\n",
    "        \"rolePath\": os.path.join(info_path,file_name+\"_roles.txt\"),\n",
    "        \"individualPath\": os.path.join(info_path,file_name+\"_individuals.txt\"),\n",
    "        \"normalizationPath\": os.path.join(info_path,file_name+\"_normalization.txt\"),\n",
    "        \"batchSize\": 3,\n",
    "        \"epochSize\":10,\n",
    "        \"earlystopping\":10,\n",
    "        \"dist\": \"minkowski\",\n",
    "        \"norm\":1,\n",
    "        \"norm_rate\":0.5,\n",
    "        \"norm_rate2\":0\n",
    "    }\n",
    "to_train = False\n",
    "\n",
    "save_path = \"dfalc_data\"\n",
    "if to_train: save_path = os.path.join(save_path,\"training\")\n",
    "else: save_path = os.path.join(save_path,\"testing\")\n",
    "save_path += \"/PascalPartOntology_\"\n",
    "dataset = OntologyDataset(params,save_path)\n",
    "\n",
    "cEmb_init = torch.zeros(dataset.conceptSize-2, individualSize)\n",
    "rEmb_init = torch.zeros(1, individualSize, individualSize)\n",
    "# cEmb_init.fill_(0.5)\n",
    "# rEmb_init.fill_(0.5)\n",
    "\n",
    "true_rEmb = torch.zeros(1, individualSize, individualSize)\n",
    "for idx, (i,j) in enumerate(pairs_of_bb_idxs[partOf_of_pairs_of_data]):\n",
    "#     ci = dataset.concept2id['http://www.w3.org/2002/07/'+id2types[types_of_data[i]]]\n",
    "#     cj = dataset.concept2id['http://www.w3.org/2002/07/'+id2types[types_of_data[j]]]\n",
    "    true_rEmb[0,i,j] = 1 #\n",
    "    rEmb_init[0,i,j] = rEmb_candid[0,i,j]\n",
    "# rEmb_init = rEmb_candid.clone()\n",
    "\n",
    "\n",
    "\n",
    "def get_definers_initial_semantics(cid, use_partOf=False):\n",
    "    print(cid)\n",
    "    if use_partOf:\n",
    "        emb = torch.zeros(1,individualSize)\n",
    "        print(\"mode: 5\")\n",
    "        dataset.mode = 5\n",
    "        for left,right,neg in dataset:\n",
    "            print(left, right,neg)\n",
    "            if right.item() == cid:\n",
    "                emb = torch.max(torch.minimum(true_rEmb[0],cEmb_init[left[1]].unsqueeze(1).expand(true_rEmb[0].shape)),1).values + 0.1\n",
    "                break\n",
    "        dataset.mode = 0\n",
    "        print(\"mode: 0\")\n",
    "        \n",
    "        for left,right,neg in dataset:\n",
    "            print(left, right,neg)\n",
    "            if right.item() == cid:\n",
    "                emb = torch.maximum(cEmb_init[left.item()], emb) + 0.1\n",
    "       \n",
    "        return emb\n",
    "    else:\n",
    "        dataset.mode = 2\n",
    "        for left,right,neg in dataset:\n",
    "            if left == cid:\n",
    "                if id2concept[right[0].item()][:7] == \"definer\":\n",
    "                    return torch.maximum(get_definers_initial_semantics(right[0],use_partOf=False),cEmb_init[right[1]]) - 0.1\n",
    "                return torch.maximum(cEmb_init[right[0]],cEmb_init[right[1]]) - 0.1\n",
    "                \n",
    "    \n",
    "    return torch.zeros(1,individualSize)\n",
    "\n",
    "cid2typeid = {} \n",
    "interest_cids = []\n",
    "cnt = 0\n",
    "id2concept = {c:idx for idx, c in dataset.concept2id.items()}\n",
    "\n",
    "                \n",
    "for c,idx in dataset.concept2id.items():\n",
    "    cname = c.replace(\"http://www.w3.org/2002/07/\",\"\")\n",
    "    if cname in selected_types_name:\n",
    "        cEmb_init[idx] = cEmb_candid[:,types[cname]]\n",
    "        cid2typeid[cnt] = types[cname]\n",
    "        interest_cids.append(True)\n",
    "        cnt += 1\n",
    "    else:\n",
    "        interest_cids.append(False)\n",
    "interest_cids = interest_cids[:-2]\n",
    "for c,idx in dataset.concept2id.items():\n",
    "    cname = c.replace(\"http://www.w3.org/2002/07/\",\"\")\n",
    "    if cname in selected_types_name: continue\n",
    "    if (idx == dataset.conceptSize -3) or (idx == dataset.conceptSize-4):continue\n",
    "    cEmb_init[idx] = get_definers_initial_semantics(idx,use_partOf=True)\n",
    "    print(torch.max(cEmb_init[idx]))\n",
    "        \n",
    "cEmb_init[-1] = torch.ones(1,individualSize)\n",
    "cEmb_init[-2] = torch.zeros(1,individualSize)\n",
    "\n",
    "id2types = {idx:name for name,idx in types.items()}\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Add Noise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sklearn\n",
    "def add_noise(cEmb, rEmb, mask_rate = 0.2, mask_rEmb = False, mask_cEmb = False):\n",
    "    masked_cEmb, masked_rEmb = np.array(cEmb, copy=True), np.array(rEmb, copy=True)\n",
    "    if mask_cEmb:\n",
    "        coords = []\n",
    "        for i in range(conceptSize):\n",
    "            for j in range(individualSize):\n",
    "                coords.append([i,j])\n",
    "\n",
    "        size = len(coords)\n",
    "        print(\"cEmb masked size: \", int(mask_rate*size))\n",
    "        for i in np.random.choice(size, int(mask_rate*size), replace=False):\n",
    "            x,y = coords[i]\n",
    "            masked_cEmb[x,y] = np.random.uniform(0,1)\n",
    "\n",
    "    if mask_rEmb:\n",
    "        coords = set()\n",
    "\n",
    "        size = individualSize*individualSize\n",
    "        print(\"rEmb masked size: \", int(mask_rate*size))\n",
    "        for i in np.random.choice(individualSize*individualSize, int(mask_rate*size), replace=False):\n",
    "            x,y,z = np.random.randint(0,high=roleSize), np.random.randint(0,high=individualSize), np.random.randint(0,high=individualSize)\n",
    "            while (x,y,z) in coords:\n",
    "                x,y,z = np.random.randint(0,high=roleSize), np.random.randint(0,high=individualSize), np.random.randint(0,high=individualSize)\n",
    "            coords.add((x,y,z))\n",
    "    #         if masked_rEmb[x,y,z] == 0:\n",
    "            v = np.random.uniform(0,1)\n",
    "\n",
    "    #         while (v == -1) or (v == 1):\n",
    "    #             v = np.random.uniform(1-self.alpha,self.alpha)\n",
    "            masked_rEmb[x,y,z] = v\n",
    "    return torch.Tensor(masked_cEmb), torch.Tensor(masked_rEmb)\n",
    "noised_cEmb,noised_rEmb = add_noise(cEmb_init.numpy(),rEmb_init.numpy(),0.1,False,True)\n",
    "\n",
    "from sklearn.metrics import confusion_matrix\n",
    "from sklearn.metrics import precision_recall_fscore_support\n",
    "p_baseline = torch.argmax(noised_cEmb[interest_cids],0).detach().numpy()\n",
    "h_baseline = np.array([cid2typeid[i] for i in p_baseline])\n",
    "\n",
    "y_true = []\n",
    "y_baseline = []\n",
    "partof_true = np.ravel(true_rEmb.numpy()).T\n",
    "partof_baseline = (np.ravel(noised_rEmb.numpy()).T > .5).astype(int)\n",
    "for idx,t in enumerate(types_of_data):\n",
    "    if t in selected_types:\n",
    "        y_baseline.append(selected_types_name[np.where(selected_types==h_baseline[idx])[0][0]])\n",
    "        y_true.append(selected_types_name[np.where(selected_types==t)[0][0]])\n",
    "\n",
    "\n",
    "\n",
    "# confusion_matrix(y_true, y_pred, labels=selected_types_name)\n",
    "p,r,f,s = precision_recall_fscore_support(y_true, y_baseline)\n",
    "print(precision_recall_fscore_support(y_true, y_baseline,average='macro'))\n",
    "print({selected_types_name[i]:s[i] for i in range(len(selected_types))})\n",
    "print(\"PartOf: \", precision_recall_fscore_support(partof_true, partof_baseline),sklearn.metrics.accuracy_score(partof_true,partof_baseline))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part-of Relation Classification - Revising based on DF-ALC with rule-based loss\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  },
  "vscode": {
   "interpreter": {
    "hash": "b5d4ea6110d76bf407abdf3fc85b4f9a1bbb4f7f6454d667a509d28831b3322d"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
