{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.cluster import AgglomerativeClustering\n",
    "from openai import OpenAI\n",
    "from typing import List\n",
    "from sklearn.metrics import silhouette_score\n",
    "from tqdm import tqdm\n",
    "import argparse\n",
    "import json\n",
    "import os\n",
    "import random\n",
    "import logging\n",
    "from concurrent.futures import ThreadPoolExecutor, as_completed\n",
    "import concurrent.futures\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def construct_args():\n",
    "    parser = argparse.ArgumentParser(description='Cluster entities using hierarchical clustering and refine the clusters using LLM.')\n",
    "    parser.add_argument('--output_dir', type=str, default=\"/data/pj20/lamake_data\")\n",
    "    parser.add_argument('--data_dir', type=str, default=\"/home/pj20/server-03/lamake/data\")\n",
    "    parser.add_argument('--dataset', type=str, default=\"FB15K-237\", help='Path to the dataset file containing the list of entities to cluster.')\n",
    "    parser.add_argument('--dimensions', type=int, default=1024, help='Dimensionality of the embeddings. Default: 1024.')\n",
    "    parser.add_argument('--num_threads', type=int, default=10, help='Number of threads to use for multi-threaded processes. Default: 10.')\n",
    "    parser.add_argument('--max_entities', type=int, default=100, help='Maximum number of entities to include in an LLM request. Default: 100.')\n",
    "    \n",
    "    args = parser.parse_args(args=[])\n",
    "    args.log_dir = f\"{args.output_dir}/{args.dataset}/logs\"\n",
    "    \n",
    "    return args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = construct_args()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "from cluster import  read_entities, create_entity_info_emb_dict, generate_embeddings, build_hierarchy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "clustering_file = \"/data/pj20/lamake_data/FB15K-237/clustering/clustering_0.52.pkl\"\n",
    "with open(clustering_file, \"rb\") as f:\n",
    "    clustering = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "entities = read_entities('/home/pj20/server-03/lamake/data/FB15K-237/entities.dict')\n",
    "entity_info, entity_embeddings = create_entity_info_emb_dict(args, entities)\n",
    "entities_text, original_descriptions = [], []\n",
    "for entity in entities:\n",
    "    entities_text.append(entity_info[entity][\"text_label\"])\n",
    "    original_descriptions.append(entity_info[entity][\"original_description\"])\n",
    "    \n",
    "print(\"Start Generating Embeddings...\")\n",
    "embeddings, entity_info, entity_embeddings = generate_embeddings(args, entity_info=entity_info, entity_embeddings=entity_embeddings, dim=args.dimensions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 150,
   "metadata": {},
   "outputs": [],
   "source": [
    "clusters = {}\n",
    "for i in range(clustering.n_clusters_):\n",
    "    cluster_indices = np.where(clustering.labels_ == i)[0]\n",
    "    cluster_entities = [entities_text[idx] for idx in cluster_indices]\n",
    "    clusters[f\"Cluster_{i+1}\"] = cluster_entities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "clusters_ = {int(i): entities for i, entities in enumerate(clusters.values())}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(entities)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clustering.children_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 151,
   "metadata": {},
   "outputs": [],
   "source": [
    "initial_hier = build_hierarchy(clustering.children_, len(entities), entity_labels=entities_text, clustering=clustering)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 152,
   "metadata": {},
   "outputs": [],
   "source": [
    "clusters_ = {int(i): entities for i, entities in enumerate(clusters.values())}\n",
    "entity2clusterid = {}\n",
    "\n",
    "for i, cluster in enumerate(clusters_.values()):\n",
    "    for entity in cluster:\n",
    "        entity2clusterid[entity] = i\n",
    "        \n",
    "clusterid2count = defaultdict(int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "entity2clusterid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('./initial_hier.json', 'w') as f:\n",
    "    json.dump(initial_hier, f, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 113,
   "metadata": {},
   "outputs": [],
   "source": [
    "entity2clusterid = {}\n",
    "\n",
    "for i, cluster in enumerate(clusters_.values()):\n",
    "    for entity in cluster:\n",
    "        entity2clusterid[entity] = i"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 154,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "clusterid2count = defaultdict(int)\n",
    "\n",
    "\n",
    "def label_(d, leaf_keys=None, leaf_values=None):\n",
    "    if leaf_keys is None:\n",
    "        leaf_keys = []\n",
    "    if leaf_values is None:\n",
    "        leaf_values = []\n",
    "    for key, value in d.items():\n",
    "        if isinstance(value, dict):  # If the value is another dictionary, recurse into it\n",
    "            label_(value, leaf_keys, leaf_values)\n",
    "        else:  # If the value is not a dictionary, then it's a leaf node\n",
    "            cluster_id = entity2clusterid[value]\n",
    "            d[key] = [cluster_id, clusterid2count[cluster_id]]\n",
    "            clusterid2count[entity2clusterid[value]] += 1\n",
    "    return d\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 155,
   "metadata": {},
   "outputs": [],
   "source": [
    "hierarchy = label_(initial_hier)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 132,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('./initial_hier_numeric.json', 'w') as f:\n",
    "    json.dump(hierarchy, f, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 135,
   "metadata": {},
   "outputs": [],
   "source": [
    "def refine_1(d, clusters_, leaf_keys=None, leaf_values=None):\n",
    "    if leaf_keys is None:\n",
    "        leaf_keys = []\n",
    "    if leaf_values is None:\n",
    "        leaf_values = []\n",
    "    \n",
    "    keys_to_delete = []  # List to hold keys of items to be deleted\n",
    "    items_to_update = {}  # Dictionary to hold items to be updated\n",
    "\n",
    "    for key, value in list(d.items()):  # Convert dict_items to a list to safely iterate\n",
    "        if isinstance(value, dict):  # If the value is another dictionary, recurse into it\n",
    "            refine_1(value, clusters_, leaf_keys, leaf_values)\n",
    "        else:\n",
    "            if value[1] > 0:\n",
    "                keys_to_delete.append(key)\n",
    "            else:\n",
    "                items_to_update[key] = clusters_[value[0]]\n",
    "\n",
    "    # Now, delete keys marked for deletion\n",
    "    for key in keys_to_delete:\n",
    "        del d[key]\n",
    "\n",
    "    # Update the dictionary with new values\n",
    "    for key, new_value in items_to_update.items():\n",
    "        d[key] = new_value\n",
    "\n",
    "    return d\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 143,
   "metadata": {},
   "outputs": [],
   "source": [
    "hierarchy = refine_1(hierarchy, clusters_)\n",
    "with open('./refined_hier.json', 'w') as f:\n",
    "    json.dump(hierarchy, f, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 145,
   "metadata": {},
   "outputs": [],
   "source": [
    "def refine_2(d):\n",
    "    # Recursive function to process and refine each dictionary\n",
    "    def process_dict(sub_dict):\n",
    "        for key in list(sub_dict.keys()):  # Iterate over a copy of the keys\n",
    "            value = sub_dict[key]\n",
    "            if isinstance(value, dict):\n",
    "                if value:  # Check if the dictionary is not empty\n",
    "                    result = process_dict(value)\n",
    "                    # If the result is a single entry with a list, replace the current dict\n",
    "                    if len(result) == 1 and isinstance(list(result.values())[0], list):\n",
    "                        sub_dict[key] = list(result.values())[0]\n",
    "                    else:\n",
    "                        sub_dict[key] = result\n",
    "                else:\n",
    "                    del sub_dict[key]  # Remove empty dictionaries\n",
    "        return sub_dict\n",
    "\n",
    "    # Copy the original dictionary to avoid modification issues\n",
    "    refined_dict = process_dict(d.copy())\n",
    "    return refined_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 146,
   "metadata": {},
   "outputs": [],
   "source": [
    "hierarchy = refine_2(hierarchy)\n",
    "with open('./refined_hier.json', 'w') as f:\n",
    "    json.dump(hierarchy, f, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 148,
   "metadata": {},
   "outputs": [],
   "source": [
    "def refine_3(d):\n",
    "    # Recursive function to process and refine each dictionary\n",
    "    def process_dict(sub_dict):\n",
    "        new_dict = {}  # To accumulate refined results\n",
    "        for key, value in list(sub_dict.items()):\n",
    "            if isinstance(value, dict):\n",
    "                processed = process_dict(value)  # Recursively process\n",
    "                if processed:  # Only add non-empty results\n",
    "                    new_dict[key] = processed\n",
    "            else:  # Keep non-dict items as they are\n",
    "                new_dict[key] = value\n",
    "        return new_dict\n",
    "\n",
    "    # Start the processing with the original dictionary\n",
    "    refined_dict = process_dict(d)\n",
    "    return refined_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 149,
   "metadata": {},
   "outputs": [],
   "source": [
    "hierarchy = refine_3(hierarchy)\n",
    "with open('./refined_hier.json', 'w') as f:\n",
    "    json.dump(hierarchy, f, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import refine_4\n",
    "import json\n",
    "\n",
    "with open('/data/pj20/lamake_data/FB15K-237/seed_clusters.json', 'r') as f:\n",
    "    seed_clusters = json.load(f)\n",
    "    \n",
    "hierarchy = refine_4(seed_clusters)\n",
    "\n",
    "with open('./refined_hier.json', 'w') as f:\n",
    "    json.dump(hierarchy, f, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from openai import OpenAI\n",
    "\n",
    "with open('./openai_api.key', 'r') as f:\n",
    "    api_key = f.read().strip()\n",
    "client = OpenAI(api_key=api_key)\n",
    "\n",
    "def gpt_chat_return_response(model, prompt, seed=44):\n",
    "    response = client.chat.completions.create(\n",
    "        model=model,\n",
    "        messages=[\n",
    "            {\"role\": \"user\", \"content\": prompt}\n",
    "        ],\n",
    "        max_tokens=200,\n",
    "        temperature=0,\n",
    "        seed=seed,\n",
    "        logprobs=True\n",
    "    )\n",
    "    return response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_lca(root, node1, node2):\n",
    "    if root is None:\n",
    "        return None\n",
    "    \n",
    "    if isinstance(root, list):\n",
    "        if node1 in root or node2 in root:\n",
    "            return root\n",
    "    \n",
    "    if root == node1 or root == node2:\n",
    "        return root\n",
    "\n",
    "    lca_list = []\n",
    "    if isinstance(root, dict):\n",
    "        for child in root.values():\n",
    "            lca = find_lca(child, node1, node2)\n",
    "            if lca is not None:\n",
    "                lca_list.append(lca)\n",
    "            if len(lca_list) > 1:\n",
    "                return root\n",
    "\n",
    "    return lca_list[0] if lca_list else None\n",
    "\n",
    "def find_distance_from_root_to_node(root, node, distance=0):\n",
    "    if root is None:\n",
    "        return -1\n",
    "\n",
    "    if isinstance(root, list):\n",
    "        if node in root:\n",
    "            return distance\n",
    "\n",
    "    if root == node:\n",
    "        return distance\n",
    "\n",
    "    if isinstance(root, dict):\n",
    "        for child in root.values():\n",
    "            dist = find_distance_from_root_to_node(child, node, distance + 1)\n",
    "            if dist != -1:\n",
    "                return dist\n",
    "\n",
    "    return -1\n",
    "\n",
    "def distance_between_nodes(root, node1, node2):\n",
    "    lca = find_lca(root, node1, node2)\n",
    "    if lca is None:\n",
    "        return -1\n",
    "\n",
    "    distance1 = find_distance_from_root_to_node(lca, node1, 0)\n",
    "    distance2 = find_distance_from_root_to_node(lca, node2, 0)\n",
    "    \n",
    "    return distance1 + distance2 if distance1 != -1 and distance2 != -1 else -1\n",
    "\n",
    "# 树结构示例\n",
    "tree = {\n",
    "    \"Cluster_29078\": {\n",
    "        \"Cluster_29074\": {\n",
    "            \"Cluster_29030\": {\n",
    "                \"Cluster_27322\": {\n",
    "                    \"Cluster_1143\": [\"American Kennel Club\"],\n",
    "                    \"Cluster_25030\": {\n",
    "                        \"Cluster_22602\": [\"dog\", \"bulldog\"],\n",
    "                        \"Cluster_23900\": {\n",
    "                            \"Cluster_20210\": [\"Golden Retriever\", \"Labrador Retriever\", \"German Shepherd dog\"],\n",
    "                            \"Cluster_20718\": [\"Chihuahua\", \"Yorkshire Terrier\"]\n",
    "                        }\n",
    "                    }\n",
    "                }\n",
    "            }\n",
    "        }\n",
    "    }\n",
    "}\n",
    "\n",
    "# Recalculate the distance with the revised functions\n",
    "distance = distance_between_nodes(tree, \"Golden Retriever\", \"American Kennel Club\")\n",
    "print(f\"Distance between Golden Retriever and American Kennel Club: {distance}\")\n",
    "\n",
    "distance = distance_between_nodes(tree, \"Golden Retriever\", \"bulldog\")\n",
    "print(f\"Distance between Golden Retriever and bulldog: {distance}\")\n",
    "\n",
    "distance = distance_between_nodes(tree, \"Golden Retriever\", \"Labrador Retriever\")\n",
    "print(f\"Distance between Golden Retriever and Labrador Retriever: {distance}\")\n",
    "\n",
    "distance = distance_between_nodes(tree, \"Golden Retriever\", \"Chihuahua\")\n",
    "print(f\"Distance between Golden Retriever and Chihuahua: {distance}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_tree_depth(root):\n",
    "    if root is None:\n",
    "        return 0\n",
    "    \n",
    "    if isinstance(root, list) or isinstance(root, str):\n",
    "        return 1  # Leaf nodes contribute a depth of 1\n",
    "    \n",
    "    if isinstance(root, dict):\n",
    "        max_depth = 0\n",
    "        for child in root.values():\n",
    "            child_depth = compute_tree_depth(child)\n",
    "            if child_depth > max_depth:\n",
    "                max_depth = child_depth\n",
    "        return 1 + max_depth  # Add 1 for the depth from the current node to its children\n",
    "\n",
    "    return 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('/data/pj20/lamake_data/FB15K-237/seed_hierarchy.json', 'r') as f:\n",
    "    tree = json.load(f)\n",
    "\n",
    "compute_tree_depth(tree)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(entity_embeddings), len(entity_info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "\n",
    "# Load entity embeddings\n",
    "with open('/data/pj20/lamake_data/FB15K-237/entity_embeddings.json', 'r') as file:\n",
    "    entity_embeddings = json.load(file)\n",
    "    \n",
    "# load entity info\n",
    "with open('/data/pj20/lamake_data/FB15K-237/entity_info.json', 'r') as file:\n",
    "    entity_info = json.load(file)\n",
    "    \n",
    "label2entity = {entity_info[entity]['text_label']: entity for entity in entity_info.keys()}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_cluster_embedding(cluster, embeddings_dict, cluster_id):\n",
    "    if isinstance(cluster, list):\n",
    "        # Base case: cluster is a list of entities\n",
    "        embeddings = [entity_embeddings.get(label2entity.get(entity)) for entity in cluster]\n",
    "        cluster_embedding = np.mean(embeddings, axis=0)\n",
    "        print(f\"Computed embedding for {cluster_id} with entities: {cluster}\")\n",
    "        print(f\"Cluster embedding: {cluster_embedding}\")\n",
    "    elif isinstance(cluster, dict):\n",
    "        # Recursive case: cluster has sub-clusters\n",
    "        sub_embeddings = []\n",
    "        for sub_cluster_id, sub_cluster in cluster.items():\n",
    "            sub_embedding = compute_cluster_embedding(sub_cluster, embeddings_dict, sub_cluster_id)\n",
    "            sub_embeddings.append(sub_embedding)\n",
    "        cluster_embedding = np.mean(sub_embeddings, axis=0)\n",
    "        print(f\"Computed embedding for {cluster_id} with sub-clusters: {list(cluster.keys())}\")\n",
    "        print(f\"Parent cluster embeddings: {cluster_embedding}\")\n",
    "\n",
    "    embeddings_dict[cluster_id] = cluster_embedding\n",
    "    return cluster_embedding\n",
    "\n",
    "\n",
    "# Initialize dictionary to store embeddings\n",
    "cluster_embeddings = {}\n",
    "# Trigger the recursive computation\n",
    "for cluster_id, cluster_data in tree.items():\n",
    "    compute_cluster_embedding(cluster_data, cluster_embeddings, cluster_id)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import *\n",
    "\n",
    "entity2clusterid = {}\n",
    "leaf_keys, leaf_values = find_leaves(tree)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(len(leaf_values)):\n",
    "    for entity_label in leaf_values[i]:\n",
    "        entity2clusterid[label2entity[entity_label]] = leaf_keys[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "entity2clusterid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('/data/pj20/lamake_data/FB15K-237/entity_info.json', 'r') as f:\n",
    "    entity_info = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "entity_info['/m/08k05y']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "map_child_to_parent(tree)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_lca_key(root, key1, key2):\n",
    "    if root is None:\n",
    "        return None\n",
    "\n",
    "    # If the current root (or dict) contains the key directly, we check its keys\n",
    "    if key1 in root or key2 in root:\n",
    "        return root  # Found one of the keys at the current level, return this root\n",
    "\n",
    "    lca_list = []\n",
    "    if isinstance(root, dict):\n",
    "        for key, child in root.items():\n",
    "            if key == key1 or key == key2:\n",
    "                lca_list.append(key)\n",
    "            lca = find_lca_key(child, key1, key2)\n",
    "            if lca is not None:\n",
    "                lca_list.append(lca)\n",
    "            if len(lca_list) > 1:\n",
    "                return root  # Both keys found in different subtrees\n",
    "\n",
    "    return lca_list[0] if lca_list else None\n",
    "\n",
    "def find_distance_from_root_to_key(root, key, distance=0):\n",
    "    if root is None:\n",
    "        return -1\n",
    "\n",
    "    # Check if the key is the current root's direct key\n",
    "    if key in root:\n",
    "        return distance\n",
    "\n",
    "    if isinstance(root, dict):\n",
    "        for child_key, child in root.items():\n",
    "            if child_key == key:\n",
    "                return distance + 1\n",
    "            dist = find_distance_from_root_to_key(child, key, distance + 1)\n",
    "            if dist != -1:\n",
    "                return dist\n",
    "\n",
    "    return -1\n",
    "\n",
    "def distance_between_keys(root, key1, key2):\n",
    "    lca = find_lca_key(root, key1, key2)\n",
    "    if lca is None:\n",
    "        return -1\n",
    "\n",
    "    distance1 = find_distance_from_root_to_key(lca, key1, 0)\n",
    "    distance2 = find_distance_from_root_to_key(lca, key2, 0)\n",
    "    \n",
    "    return distance1 + distance2 if distance1 != -1 and distance2 != -1 else -1\n",
    "\n",
    "\n",
    "tree = {\n",
    "    \"Cluster_29078\": {\n",
    "        \"Cluster_29074\": {\n",
    "            \"Cluster_29030\": {\n",
    "                \"Cluster_27322\": {\n",
    "                    \"Cluster_1143\": [\"American Kennel Club\"],\n",
    "                    \"Cluster_25030\": {\n",
    "                        \"Cluster_22602\": [\"dog\", \"bulldog\"],\n",
    "                        \"Cluster_23900\": {\n",
    "                            \"Cluster_20210\": [\"Golden Retriever\", \"Labrador Retriever\", \"German Shepherd dog\"],\n",
    "                            \"Cluster_20718\": [\"Chihuahua\", \"Yorkshire Terrier\"]\n",
    "                        }\n",
    "                    }\n",
    "                }\n",
    "            }\n",
    "        }\n",
    "    }\n",
    "}\n",
    "\n",
    "# Example usage with keys from the tree:\n",
    "key1 = \"Cluster_29078\"\n",
    "key2 = \"Cluster_1143\"\n",
    "\n",
    "# Calculate the distance between two keys\n",
    "distance = distance_between_keys(tree, key1, key2)\n",
    "print(f\"Distance between {key1} and {key2}: {distance}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_keys = get_all_keys(tree)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_distance(parent_map, key, root):\n",
    "    distance = 0\n",
    "    while key != root:\n",
    "        key = parent_map[key]\n",
    "        distance += 1\n",
    "    return distance\n",
    "\n",
    "\n",
    "def find_nearest_keys_lca_based(tree, input_key, n, parent_map):\n",
    "    all_keys = set(parent_map.keys())\n",
    "    distances = []\n",
    "    \n",
    "    for key in all_keys:\n",
    "        if key != input_key:\n",
    "            dist = distance_between_keys(parent_map, input_key, key)\n",
    "            if dist != -1:  # Only consider valid distances\n",
    "                distances.append((key, dist))\n",
    "    \n",
    "    # Sort the list of distances based on distance, and return the first n keys\n",
    "    distances.sort(key=lambda x: x[1])\n",
    "    return [key for key, dist in distances[:n]]\n",
    "\n",
    "def find_lca(parent_map, key1, key2):\n",
    "    ancestors = set()\n",
    "    # Climb up from key1 to the root, collecting all ancestors\n",
    "    while key1 in parent_map:\n",
    "        ancestors.add(key1)\n",
    "        key1 = parent_map.get(key1, None)  # Safely get parent or None if not exists\n",
    "        if key1 is None:\n",
    "            break\n",
    "    # Climb up from key2 until we find the first common ancestor\n",
    "    while key2 not in ancestors:\n",
    "        key2 = parent_map.get(key2, None)  # Safely get parent or None if not exists\n",
    "        if key2 is None:\n",
    "            return None  # If reached the top without finding an ancestor, return None\n",
    "    return key2\n",
    "\n",
    "def distance_between_keys(parent_map, key1, key2):\n",
    "    # Find root two levels above current key\n",
    "    root1 = parent_map.get(key1)\n",
    "    if root1:\n",
    "        root1 = parent_map.get(root1)\n",
    "    \n",
    "    root2 = parent_map.get(key2)\n",
    "    if root2:\n",
    "        root2 = parent_map.get(root2)\n",
    "\n",
    "    # Find LCA considering two levels up as the root\n",
    "    if root1 and root2:\n",
    "        lca = find_lca(parent_map, key1, key2)\n",
    "        if lca:\n",
    "            distance1 = find_distance(parent_map, key1, lca)\n",
    "            distance2 = find_distance(parent_map, key2, lca)\n",
    "            return distance1 + distance2\n",
    "    return -1  # Return -1 if no valid LCA is found\n",
    "\n",
    "# Building the parent map\n",
    "parent_map = map_child_to_parent(tree)\n",
    "\n",
    "# Example usage:\n",
    "input_key = \"Cluster_23662\"\n",
    "nearest_keys = find_nearest_keys_lca_based(tree, input_key, 5, parent_map)\n",
    "print(f\"The nearest 3 keys to {input_key} are: {nearest_keys}\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import *\n",
    "def node2parentpath(d, source_cluster):\n",
    "    parent_path = []\n",
    "    parent_distances = []\n",
    "    child_parent = map_child_to_parent(d)\n",
    "    current_parent = child_parent[source_cluster]\n",
    "    while current_parent in child_parent.keys():\n",
    "        parent_path.append(current_parent)\n",
    "        parent_distances.append(distance_between_keys(d, current_parent, source_cluster))\n",
    "        current_parent = child_parent[current_parent]\n",
    "    \n",
    "    return parent_path, parent_distances\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "node2parentpath(tree, 'Cluster_1143')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(tree)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('/data/pj20/lamake_data/FB15K-237/seed_hierarchy.json', 'r') as f:\n",
    "    tree = json.load(f)\n",
    "\n",
    "tree = {\n",
    "    \"Cluster_top\": tree\n",
    "}\n",
    "\n",
    "with open('/data/pj20/lamake_data/FB15K-237/seed_hierarchy.json', 'w') as f:\n",
    "    json.dump(tree, f, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rename_clusters_to_ints(original_dict, start_index=0, key_map=None):\n",
    "    \"\"\"\n",
    "    Recursively renames keys of the nested dictionary to integers, incrementing from a given start index.\n",
    "    Also tracks the mapping from original keys to new keys.\n",
    "    \"\"\"\n",
    "    if key_map is None:\n",
    "        key_map = {}\n",
    "\n",
    "    new_dict = {}\n",
    "    index = start_index\n",
    "\n",
    "    for key, value in original_dict.items():\n",
    "        key_map[key] = index\n",
    "        if isinstance(value, dict):\n",
    "            new_dict[index], index, key_map = rename_clusters_to_ints(value, index + 1, key_map)\n",
    "        else:\n",
    "            new_dict[index] = value\n",
    "            index += 1\n",
    "\n",
    "    key_map_inv = {v: k for k, v in key_map.items()}\n",
    "    return new_dict, index, key_map, key_map_inv\n",
    "\n",
    "with open('/data/pj20/lamake_data/FB15K-237/seed_hierarchy.json', 'r') as f:\n",
    "    tree = json.load(f)\n",
    "    \n",
    "tree, _, key_map, key_map_inv = rename_clusters_to_ints(tree)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "\n",
    "def rename_unique_keys(d, prefix=\"Cluster\"):\n",
    "    counter = itertools.count(1)  # Initialize the counter outside the function\n",
    "\n",
    "    def rename_recursively(d):\n",
    "        \"\"\"Recursively renames all keys using a globally unique counter.\"\"\"\n",
    "        new_dict = {}\n",
    "        for key, value in d.items():\n",
    "            new_key = f\"{prefix}_{next(counter)}\"  # Generate a globally unique key\n",
    "            if isinstance(value, dict):\n",
    "                # Recursively rename keys in sub-dictionaries\n",
    "                new_dict[new_key] = rename_recursively(value)\n",
    "            else:\n",
    "                # Apply new keys to values that are lists\n",
    "                new_dict[new_key] = value\n",
    "        return new_dict\n",
    "\n",
    "    return rename_recursively(d)  # Start the recursive renaming\n",
    "\n",
    "# Original nested dictionary\n",
    "nested_dict = {\n",
    "    \"Cluster_llm_root\": {\n",
    "        \"Cluster_5241\": [\n",
    "            \"Vidyasagar\"\n",
    "        ],\n",
    "        \"Cluster_29077\": {\n",
    "            \"Cluster_29033\": {\n",
    "                \"Cluster_25423\": {\n",
    "                    \"Cluster_6928\": [\n",
    "                        \"old age\"\n",
    "                    ],\n",
    "                    \"Cluster_21691\": [\n",
    "                        \"adolescence\",\n",
    "                        \"young adult\",\n",
    "                        \"coming of age\"\n",
    "                    ]\n",
    "                },\n",
    "                \"Cluster_28201\": {\n",
    "                    \"Cluster_16051\": [\n",
    "                        \"20th century\",\n",
    "                        \"19th century\"\n",
    "                    ],\n",
    "                    \"Cluster_26871\": {\n",
    "                        \"Cluster_20840\": [\n",
    "                            \"modern architecture\",\n",
    "                            \"modernism\"\n",
    "                        ],\n",
    "                        \"Cluster_24018\": {\n",
    "                            \"Cluster_1631\": [\n",
    "                                \"Surrealism\"\n",
    "                            ],\n",
    "                            \"Cluster_2924\": [\n",
    "                                \"New Romanticism\"\n",
    "                            ]\n",
    "                        }\n",
    "                    }\n",
    "                }\n",
    "            }\n",
    "        }\n",
    "    }\n",
    "}\n",
    "\n",
    "# Renaming all keys in the dictionary with unique names\n",
    "renamed_dict = rename_unique_keys(nested_dict, \"Cluster\")\n",
    "print(renamed_dict)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from trialmind.llm import call_llm # for chat models\n",
    "from trialmind.llm import function_call_llm # for function  call\n",
    "\n",
    "outputs = call_llm(\n",
    "        prompt_template=\"tell me a joke about {item}\",\n",
    "        inputs = {\"item\": \"dog\"},\n",
    "        llm=\"gpt-4\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "entity_info_file = \"/shared/pj20/lamake_data/WN18RR/entity_info.json\"\n",
    "entity_embedding_file = \"/shared/pj20/lamake_data/WN18RR/entity_init_embeddings.json\"\n",
    "\n",
    "with open(entity_info_file, \"r\") as f:\n",
    "    entity_info = json.load(f)\n",
    "    \n",
    "with open(entity_embedding_file, \"r\") as f:\n",
    "    entity_embeddings = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "entity_info['06611376'], entity_info['13555775']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "list(entity_info.keys()), list(entity_embeddings.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "entities_dict = \"\"\n",
    "entity_ids = list(entity_info.keys())\n",
    "for i in range(len(entity_ids)):\n",
    "    entities_dict += f\"{i}\\t{entity_ids[i]}\\n\"\n",
    "    \n",
    "with open(\"../../data/WN18RR/entities.dict\", \"w\") as f:\n",
    "    f.write(entities_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#compute cosine similarity between two entities\n",
    "from sklearn.metrics.pairwise import cosine_similarity\n",
    "\n",
    "def compute_similarity(entity1, entity2, entity_info, entity_embeddings):\n",
    "    entity1_emb = entity_embeddings[entity1]\n",
    "    entity2_emb = entity_embeddings[entity2]\n",
    "    return cosine_similarity([entity1_emb], [entity2_emb])[0][0]\n",
    "\n",
    "compute_similarity('14252320', '02722458', entity_info, entity_embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "from collections import defaultdict, deque\n",
    "\n",
    "# Function to read the entity index to entity ID mapping\n",
    "def read_entity_mapping(filename):\n",
    "    entity_to_index = {}\n",
    "    with open(filename, 'r') as file:\n",
    "        reader = csv.reader(file, delimiter='\\t')\n",
    "        for row in reader:\n",
    "            index, entity_id = int(row[0]), row[1]\n",
    "            entity_to_index[entity_id] = index\n",
    "    return entity_to_index\n",
    "\n",
    "# Function to read the knowledge graph triples\n",
    "def read_knowledge_graph(filename):\n",
    "    graph = defaultdict(list)\n",
    "    with open(filename, 'r') as file:\n",
    "        reader = csv.reader(file, delimiter='\\t')\n",
    "        for row in reader:\n",
    "            e1, relation, e2 = row[0], row[1], row[2]\n",
    "            graph[e1].append(e2)\n",
    "            graph[e2].append(e1)  # Assuming the graph is undirected; if directed, remove this line\n",
    "    return graph\n",
    "\n",
    "# Function to perform BFS and find k-hop neighbors\n",
    "def find_k_hop_neighbors(graph, entity_to_index, k):\n",
    "    k_hop_neighbors = {}\n",
    "    for entity_id in graph:\n",
    "        visited = set()\n",
    "        queue = deque([(entity_id, 0)])\n",
    "        neighbors = set()\n",
    "\n",
    "        while queue:\n",
    "            current, depth = queue.popleft()\n",
    "            if depth > k:\n",
    "                break\n",
    "            if current in visited:\n",
    "                continue\n",
    "            visited.add(current)\n",
    "\n",
    "            if depth == k:\n",
    "                neighbors.add(current)\n",
    "            else:\n",
    "                for neighbor in graph[current]:\n",
    "                    if neighbor not in visited:\n",
    "                        queue.append((neighbor, depth + 1))\n",
    "\n",
    "        k_hop_neighbors[entity_id] = [entity_to_index[neighbor] for neighbor in neighbors if neighbor in entity_to_index]\n",
    "    \n",
    "    return k_hop_neighbors\n",
    "\n",
    "# Main function\n",
    "def main(entity_mapping_file, kg_file, k):\n",
    "    entity_to_index = read_entity_mapping(entity_mapping_file)\n",
    "    graph = read_knowledge_graph(kg_file)\n",
    "    k_hop_neighbors = find_k_hop_neighbors(graph, entity_to_index, k)\n",
    "\n",
    "    return k_hop_neighbors\n",
    "\n",
    "# Example usage\n",
    "entity_mapping_file = '/home/pj20/server-03/lamake/data/WN18RR/entities.dict'\n",
    "kg_file = '/home/pj20/server-03/lamake/data/WN18RR/train.txt'\n",
    "k = 2  # Replace with desired k-hop value\n",
    "\n",
    "k_hop_neighbors = main(entity_mapping_file, kg_file, k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "k_hop_neighbors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "entity_info_file = \"/shared/pj20/lamake_data/WN18RR/entity_info_seed_hier.json\"\n",
    "\n",
    "with open(entity_info_file, \"r\") as f:\n",
    "    entity_info = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(entity_info['00260881']['k_hop_neighbors'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def cosine_similarity(a, b):\n",
    "    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from openai import OpenAI\n",
    "\n",
    "with open('./openai_api.key', 'r') as f:\n",
    "    api_key = f.read().strip()\n",
    "    \n",
    "\n",
    "client = OpenAI(api_key=api_key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "relation = \"_member_of_domain_usage\".replace(\"_\", \" \")\n",
    "relation_emb = client.embeddings.create(\n",
    "                input=relation,\n",
    "                model=\"text-embedding-3-large\",\n",
    "                dimensions=1024,\n",
    "            ).data[0].embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "relation_emb = np.array(relation_emb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "ent_emb_file = \"/shared/pj20/lamake_data/WN18RR/entity_init_embeddings.npy\"\n",
    "\n",
    "entity_embs = np.load(ent_emb_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "e1_emb = entity_embs[34590][:1024]\n",
    "e2_emb = entity_embs[31909][:1024]\n",
    "\n",
    "cosine_similarity(e1_emb+relation_emb, e2_emb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def cosine_similarity(a, b):\n",
    "    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))\n",
    "\n",
    "# Load entity embeddings\n",
    "entity_embs = np.load(\"/shared/pj20/lamake_data/WN18RR/entity_init_embeddings.npy\")\n",
    "relation_embs= np.load(\"/shared/pj20/lamake_data/WN18RR/checkpoints/pRotatE_seed_batch_512_hidden_512_dist_cosine/relation_embedding.npy\")\n",
    "\n",
    "# Entity embeddings\n",
    "e1_emb = entity_embs[18015]\n",
    "e2_emb = entity_embs[19725]\n",
    "\n",
    "# Compute relation embedding\n",
    "relation = \"_derivationally_related_form\".replace(\"_\", \" \")\n",
    "relation_emb = client.embeddings.create(\n",
    "                input=relation,\n",
    "                model=\"text-embedding-3-large\",\n",
    "                dimensions=1024,\n",
    "            ).data[0].embedding\n",
    "\n",
    "relation_emb = np.concatenate([relation_emb, relation_emb])\n",
    "# relation_emb = relation_embs[0]\n",
    "\n",
    "# Compute target embedding\n",
    "target_emb = e1_emb + relation_emb\n",
    "\n",
    "# Compute cosine similarities with all entity embeddings\n",
    "similarities = np.array([cosine_similarity(target_emb, emb) for emb in entity_embs])\n",
    "\n",
    "# Sort entities by similarity\n",
    "sorted_indices = np.argsort(similarities)[::-1]\n",
    "\n",
    "# Find rank of e2_emb\n",
    "e2_index = 19725\n",
    "e2_rank = np.where(sorted_indices == e2_index)[0][0] + 1  # +1 for 1-based rank\n",
    "\n",
    "print(f\"Rank of e2_emb: {e2_rank}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "kgc",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
