{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05193a47",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "def load_triples_and_image_map(dataset):\n",
    "    image_map = json.load(open(\"\".format(dataset), \"r\"))\n",
    "    triples = json.load(open(\"\".format(dataset), \"r\"))\n",
    "    return triples, image_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "421b1772",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "from pickletools import read_stringnl_noescape\n",
    "import random\n",
    "datasets = ['mkgy', 'vsem', 'fb15k237']\n",
    "triple_full_map = {}\n",
    "image_full_map = {}\n",
    "sample_instances = []\n",
    "for dataset in datasets:\n",
    "    triples, image_map = load_triples_and_image_map(dataset)\n",
    "    triple_full_map[dataset] = triples\n",
    "    image_full_map[dataset] = image_map\n",
    "    extra_triple_map = defaultdict(list)\n",
    "    num = len(triples)\n",
    "    count = 0\n",
    "    for i in range(num):\n",
    "        for triple in triples[i]:\n",
    "            head = triple['head']\n",
    "            tail = triple['tail']\n",
    "            rel = triple['relation']\n",
    "            extra_triple_map[(head, tail)].append(rel)\n",
    "    for i in range(num):\n",
    "        num_triple = len(triples[i])\n",
    "        num_cand = min(8, num_triple)\n",
    "        select_triple = random.sample(triples[i], num_cand)\n",
    "        entity_set = set()\n",
    "        relation_set = set()\n",
    "        triple_set = set()\n",
    "        for triple in select_triple:\n",
    "            head = triple['head']\n",
    "            tail = triple['tail']\n",
    "            rel = triple['relation']\n",
    "            entity_set.add(head)\n",
    "            entity_set.add(tail)\n",
    "            relation_set.add(rel)\n",
    "            triple_set.add((head, tail, rel))\n",
    "        entity_set = list(entity_set)\n",
    "        num_entity = len(entity_set)\n",
    "        for j in range(num_entity):\n",
    "            ent1 = entity_set[j]\n",
    "            for k in range(j + 1, num_entity):\n",
    "                ent2 = entity_set[k]\n",
    "                if (ent1, ent2) in extra_triple_map:\n",
    "                    rels = extra_triple_map[(ent1, ent2)]\n",
    "                    for rel in rels:\n",
    "                        relation_set.add(rel)\n",
    "                        triple_set.add((ent1, ent2, rel))\n",
    "                if (ent2, ent1) in extra_triple_map:\n",
    "                    rels = extra_triple_map[(ent2, ent1)]\n",
    "                    for rel in rels:\n",
    "                        relation_set.add(rel)\n",
    "                        triple_set.add((ent2, ent1, rel))\n",
    "\n",
    "        if len(select_triple) != len(triple_set):\n",
    "            # print(len(select_triple), len(triple_set), len(entity_set), len(relation_set))\n",
    "            count += 1\n",
    "        pass\n",
    "        \n",
    "        sample_instances.append({\n",
    "            \"entity\": list(entity_set),\n",
    "            \"relation\": list(relation_set),\n",
    "            \"triple\": list(triple_set),\n",
    "            \"source\": dataset,\n",
    "            \"task_type\": random.choice([1, 2, 3, 4, 5, 6, 7, 8])\n",
    "        })\n",
    "        print(\"Dataset: {}, count: {}\".format(dataset, count))\n",
    "print(len(sample_instances))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88b71976",
   "metadata": {},
   "outputs": [],
   "source": [
    "json.dump(sample_instances, open(\"dataset/subgraph/full_instances.json\", \"w\"), ensure_ascii=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2f41a2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(sample_instances[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2945a82",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "vvan",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
