{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a5b907d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import json\n",
    "\n",
    "from collections import defaultdict\n",
    "\n",
    "# Input files (paths intentionally left blank here; fill in before running).\n",
    "INSTANCE_PATH = \"\"\n",
    "ENTITY_MAP_PATH = \"\"\n",
    "IMAGE_MAP_PATH = \"\"\n",
    "# Seed the global RNG so the sampled masks/replacements are reproducible on re-run.\n",
    "SEED = 42\n",
    "MASK_TOKEN = \"[MASK]\"\n",
    "random.seed(SEED)\n",
    "\n",
    "\n",
    "def load_json(path):\n",
    "    \"\"\"Load a UTF-8 encoded JSON file, closing the handle afterwards.\"\"\"\n",
    "    with open(path, \"r\", encoding=\"utf-8\") as f:\n",
    "        return json.load(f)\n",
    "\n",
    "\n",
    "def substitute(instance, target, replacement):\n",
    "    \"\"\"Return (entities, triples) copied from `instance` with every element equal\n",
    "    to `target` replaced by `replacement`. Every position of each triple is checked.\"\"\"\n",
    "    entities = [replacement if ent == target else ent for ent in instance[\"entity\"]]\n",
    "    triples = [\n",
    "        [replacement if ent == target else ent for ent in triple]\n",
    "        for triple in instance[\"triple\"]\n",
    "    ]\n",
    "    return entities, triples\n",
    "\n",
    "\n",
    "def derive(instance, entities, triples, **extra):\n",
    "    \"\"\"Build a new instance dict from the given entities/triples.\n",
    "\n",
    "    relation, source and task_type are carried over from `instance`; any `extra`\n",
    "    keyword fields (e.g. answer) are appended after them.\n",
    "    \"\"\"\n",
    "    return {\n",
    "        \"entity\": entities,\n",
    "        \"relation\": instance[\"relation\"],\n",
    "        \"triple\": triples,\n",
    "        \"source\": instance[\"source\"],\n",
    "        \"task_type\": instance[\"task_type\"],\n",
    "        **extra,\n",
    "    }\n",
    "\n",
    "\n",
    "def sample_replacement(candidates, exclude, max_tries=100):\n",
    "    \"\"\"Return a random element of `candidates` not in `exclude`, or None if none exists.\"\"\"\n",
    "    if candidates:\n",
    "        # Fast path: rejection sampling, cheap when `exclude` is small relative to `candidates`.\n",
    "        for _ in range(max_tries):\n",
    "            choice = random.choice(candidates)\n",
    "            if choice not in exclude:\n",
    "                return choice\n",
    "    # Exact fallback so we never loop forever or report a false 'no candidate'.\n",
    "    remaining = [c for c in candidates if c not in exclude]\n",
    "    return random.choice(remaining) if remaining else None\n",
    "\n",
    "\n",
    "instance_data = load_json(INSTANCE_PATH)\n",
    "entity_map = load_json(ENTITY_MAP_PATH)\n",
    "image_map = load_json(IMAGE_MAP_PATH)\n",
    "\n",
    "# entity name -> image; entity ids missing from image_map are skipped instead of raising KeyError.\n",
    "name_to_image = {\n",
    "    name: image_map[entity_id]\n",
    "    for entity_id, name in entity_map.items()\n",
    "    if entity_id in image_map\n",
    "}\n",
    "# Only entities that have an image are valid replacements for error detection (task 6).\n",
    "imaged_names = list(name_to_image)\n",
    "\n",
    "processed_instance = defaultdict(list)\n",
    "for instance in instance_data:\n",
    "    if len(instance[\"entity\"]) == 0:\n",
    "        continue\n",
    "    task_type = instance[\"task_type\"]\n",
    "    if task_type in (1, 2, 4, 5):\n",
    "        # 1: Entity Count, 2: Relation Count, 4: Path Count, 5: Description -- used as-is.\n",
    "        processed_instance[task_type].append(instance)\n",
    "    elif task_type == 3:\n",
    "        # Image Count: set image_mask to 0 for k random entities, 1 <= k <= num_entity.\n",
    "        num_entity = len(instance[\"entity\"])\n",
    "        if num_entity <= 1:\n",
    "            continue\n",
    "        masked = set(random.sample(range(num_entity), random.randint(1, num_entity)))\n",
    "        image_mask = [0 if i in masked else 1 for i in range(num_entity)]\n",
    "        # Copy instead of mutating, so instance_data stays the raw input.\n",
    "        processed_instance[task_type].append({**instance, \"image_mask\": image_mask})\n",
    "    elif task_type == 6:\n",
    "        # Error Detection: swap one entity for a *different* entity that has an image.\n",
    "        old_entity = random.choice(instance[\"entity\"])\n",
    "        new_entity = sample_replacement(imaged_names, set(instance[\"entity\"]))\n",
    "        if new_entity is None:\n",
    "            # No valid replacement exists; emitting an unchanged instance would be mislabelled.\n",
    "            continue\n",
    "        entities, triples = substitute(instance, old_entity, new_entity)\n",
    "        processed_instance[task_type].append(\n",
    "            derive(instance, entities, triples, old_entity=old_entity, new_entity=new_entity)\n",
    "        )\n",
    "    elif task_type == 7:\n",
    "        # Entity Prediction: mask one entity everywhere it occurs.\n",
    "        answer_entity = random.choice(instance[\"entity\"])\n",
    "        entities, triples = substitute(instance, answer_entity, MASK_TOKEN)\n",
    "        processed_instance[task_type].append(\n",
    "            derive(instance, entities, triples, answer=answer_entity)\n",
    "        )\n",
    "    elif task_type == 8:\n",
    "        # Relation Prediction: mask the last element (relation) of one chosen triple.\n",
    "        if not instance[\"triple\"]:\n",
    "            continue\n",
    "        answer_triple = random.choice(instance[\"triple\"])\n",
    "        # Compare by value so exact duplicates of the chosen triple are masked too.\n",
    "        triples = [\n",
    "            [t[0], t[1], MASK_TOKEN] if t == answer_triple else t\n",
    "            for t in instance[\"triple\"]\n",
    "        ]\n",
    "        processed_instance[task_type].append(\n",
    "            derive(instance, instance[\"entity\"], triples, answer=answer_triple[-1])\n",
    "        )\n",
    "\n",
    "for task in range(1, 9):\n",
    "    print(task, len(processed_instance[task]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d7c2da9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the processed instances; UTF-8 is required because ensure_ascii=False emits raw non-ASCII text.\n",
    "with open(\"\", \"w\", encoding=\"utf-8\") as f:\n",
    "    json.dump(processed_instance, f, ensure_ascii=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3cc57ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the entity-name -> image mapping; UTF-8 is required because ensure_ascii=False emits raw non-ASCII text.\n",
    "with open(\"\", \"w\", encoding=\"utf-8\") as f:\n",
    "    json.dump(name_to_image, f, ensure_ascii=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "vvan",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
