{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "08397eea",
   "metadata": {},
   "source": [
    "## 1. Protein Node Description"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b445d959",
   "metadata": {},
   "source": [
    "### 1.1 Protein Entity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "488ed4f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Load the gene entity table and its description table; every column is read as a string\n",
    "bmgc_gene_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Gene/BioMedGraphica_Conn_Gene.csv', dtype=str)\n",
    "bmgc_gene_desc_df = pd.read_csv(\n",
    "    './data/BioMedGraphica-Conn/Entity/Gene/BioMedGraphica_Conn_Gene_Description.csv', dtype=str\n",
    ").drop(columns=['BioMedGraphica_ID'])\n",
    "\n",
    "# Columns compared when deciding whether two rows are duplicates (the two NCBI columns are not compared)\n",
    "gene_dedup_columns = ['HGNC_Symbol', 'Gene_Name', 'Gene_Type', 'Chromosome', 'Gene_Start', 'Gene_End', 'Ensembl_Gene_ID', 'Ensembl_Gene_ID_Version', 'Ensembl']\n",
    "# Columns kept for the gene text description\n",
    "gene_llm_columns = gene_dedup_columns + ['NCBI_Gene_ID', 'NCBI Gene']\n",
    "\n",
    "bmgc_gene_llm_df = (\n",
    "    bmgc_gene_df\n",
    "    .merge(bmgc_gene_desc_df, how='left', on='BioMedGraphica_Conn_ID')\n",
    "    .loc[:, gene_llm_columns]\n",
    "    .dropna(subset=['HGNC_Symbol'])  # rows without an HGNC symbol are discarded\n",
    "    .drop_duplicates(subset=gene_dedup_columns)\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "display(bmgc_gene_llm_df)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5522e143",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count the occurrences of every HGNC symbol and keep only the symbols that occur exactly once\n",
    "hgnc_symbol_counts = bmgc_gene_llm_df['HGNC_Symbol'].value_counts().loc[lambda counts: counts == 1]\n",
    "hgnc_symbol_unique_list = list(hgnc_symbol_counts.index)\n",
    "# Keep the rows whose symbol is unambiguous, then rename 'NCBI Gene' to 'NCBI_Gene'\n",
    "has_unique_symbol = bmgc_gene_llm_df['HGNC_Symbol'].isin(hgnc_symbol_unique_list)\n",
    "bmgc_gene_llm_df = (\n",
    "    bmgc_gene_llm_df[has_unique_symbol]\n",
    "    .reset_index(drop=True)\n",
    "    .rename(columns={'NCBI Gene': 'NCBI_Gene'})\n",
    ")\n",
    "display(bmgc_gene_llm_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ccef0d76",
   "metadata": {},
   "source": [
    "### 1.2 Protein Relation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a04478a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the relation table with every column read as a string\n",
    "relation_csv_path = './data/BioMedGraphica-Conn/Relation/BioMedGraphica_Conn_Relation.csv'\n",
    "bmgc_relation_df = pd.read_csv(relation_csv_path, dtype=str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d9eb0c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restrict the relation table to 'Protein-Protein' rows and to the four columns used downstream\n",
    "is_ppi = bmgc_relation_df['Type'] == 'Protein-Protein'\n",
    "bmgc_relation_ppi_df = bmgc_relation_df.loc[\n",
    "    is_ppi, ['BMGC_From_ID', 'BMGC_To_ID', 'Source', 'Type']\n",
    "].reset_index(drop=True)\n",
    "display(bmgc_relation_ppi_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55970c00",
   "metadata": {},
   "source": [
    "### 1.3 Map the gene with protein"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab58c37a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the combined entity table with every column read as a string\n",
    "entity_csv_path = './data/BioMedGraphica-Conn/Entity/BioMedGraphica_Conn_Entity.csv'\n",
    "bmgc_df = pd.read_csv(entity_csv_path, dtype=str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6709c41",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _rows_of_type(frame, type_name):\n",
    "    \"\"\"Return an independent copy of the rows of `frame` whose 'Type' equals `type_name`.\"\"\"\n",
    "    return frame[frame['Type'] == type_name].copy()\n",
    "\n",
    "# Entity tables of the promoter / gene / transcript / protein chain, one frame per type\n",
    "promoter_entity_df = _rows_of_type(bmgc_df, 'Promoter')\n",
    "gene_entity_df = _rows_of_type(bmgc_df, 'Gene')\n",
    "transcript_entity_df = _rows_of_type(bmgc_df, 'Transcript')\n",
    "protein_entity_df = _rows_of_type(bmgc_df, 'Protein')\n",
    "\n",
    "# Report the number of missing values per column of the relation table\n",
    "print(\"Null values in bmgc_relation_df:\")\n",
    "print(bmgc_relation_df.isnull().sum())\n",
    "\n",
    "# Relation tables linking consecutive members of the chain, one frame per type\n",
    "promoter_gene_relation_df = _rows_of_type(bmgc_relation_df, 'Promoter-Gene')\n",
    "gene_transcript_relation_df = _rows_of_type(bmgc_relation_df, 'Gene-Transcript')\n",
    "transcript_protein_relation_df = _rows_of_type(bmgc_relation_df, 'Transcript-Protein')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "217084aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Chain Gene -> Transcript -> Protein through the relation tables so that every gene ID is paired with its protein IDs.\n",
    "# Both merges reuse the column names BMGC_From_ID / BMGC_To_ID, so the second merge produces the _x (gene -> transcript)\n",
    "# and _y (transcript -> protein) suffixed columns referenced below.\n",
    "gene_transcript_entity_df = pd.merge(gene_entity_df, gene_transcript_relation_df[['BMGC_From_ID', 'BMGC_To_ID']], left_on='BioMedGraphica_Conn_ID', right_on='BMGC_From_ID', how='outer')\n",
    "gene_transcript_protein_entity_df = pd.merge(gene_transcript_entity_df, transcript_protein_relation_df[['BMGC_From_ID', 'BMGC_To_ID']], left_on='BMGC_To_ID', right_on='BMGC_From_ID', how='outer')\n",
    "# keep only complete chains: drop rows with NaN in any of BMGC_From_ID_x, BMGC_To_ID_x, BMGC_From_ID_y, BMGC_To_ID_y\n",
    "gene_transcript_protein_entity_df = gene_transcript_protein_entity_df.dropna(subset=['BMGC_From_ID_x', 'BMGC_To_ID_x', 'BMGC_From_ID_y', 'BMGC_To_ID_y']).reset_index(drop=True)\n",
    "# keep the columns ['BioMedGraphica_Conn_ID', 'BMGC_To_ID_y'] (the transcript ID is not kept), rename them to ['BMGC_GN_ID', 'BMGC_PT_ID'] and sort by gene ID\n",
    "gene_transcript_protein_entity_df = gene_transcript_protein_entity_df[['BioMedGraphica_Conn_ID', 'BMGC_To_ID_y']].rename(columns={'BioMedGraphica_Conn_ID': 'BMGC_GN_ID', 'BMGC_To_ID_y': 'BMGC_PT_ID'}).sort_values(by='BMGC_GN_ID').reset_index(drop=True)\n",
    "# drop duplicate (gene, protein) rows in gene_transcript_protein_entity_df\n",
    "gene_transcript_protein_entity_df = gene_transcript_protein_entity_df.drop_duplicates().reset_index(drop=True)\n",
    "# Merge HGNC_Symbol into [gene_transcript_protein_entity_df]\n",
    "# NOTE: this rebinds bmgc_gene_df to a two-column subset of the gene table loaded earlier\n",
    "bmgc_gene_df = bmgc_gene_df[['BioMedGraphica_Conn_ID', 'HGNC_Symbol']]\n",
    "gene_transcript_protein_hgnc_entity_df = pd.merge(gene_transcript_protein_entity_df, bmgc_gene_df, left_on='BMGC_GN_ID', right_on='BioMedGraphica_Conn_ID', how='left')\n",
    "# Genes without an HGNC_Symbol get the placeholder 'Unknown'\n",
    "gene_transcript_protein_hgnc_entity_df['HGNC_Symbol'] = gene_transcript_protein_hgnc_entity_df['HGNC_Symbol'].fillna('Unknown')\n",
    "# Only keep the columns ['BMGC_PT_ID', 'HGNC_Symbol']\n",
    "gene_transcript_protein_hgnc_entity_df = gene_transcript_protein_hgnc_entity_df[['BMGC_PT_ID', 'HGNC_Symbol']]\n",
    "# Build the BMGC_PT_ID -> HGNC_Symbol lookup dictionary.\n",
    "# If a protein ID occurs in several rows, dict() keeps the HGNC_Symbol of the last such row.\n",
    "protein_hgnc_map_dict = dict(zip(gene_transcript_protein_hgnc_entity_df['BMGC_PT_ID'], gene_transcript_protein_hgnc_entity_df['HGNC_Symbol']))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7399c5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map the BMGC_PT_ID to HGNC_Symbol in bmgc_relation_ppi_df['BMGC_From_ID', 'BMGC_To_ID'] columns\n",
    "bmgc_relation_ppi_map_df = bmgc_relation_ppi_df.copy()\n",
    "ppi_endpoint_cols = ['BMGC_From_ID', 'BMGC_To_ID']\n",
    "for col in ppi_endpoint_cols:\n",
    "    bmgc_relation_ppi_map_df[col] = bmgc_relation_ppi_map_df[col].map(protein_hgnc_map_dict)\n",
    "# Series.map yields NaN for protein IDs that are missing from the dictionary. Such rows are dropped together\n",
    "# with the 'Unknown' ones; a plain != 'Unknown' test would keep them (NaN != 'Unknown' is True) and the NaN\n",
    "# partner would later be rendered as the string 'nan' in the related-gene lists.\n",
    "ppi_endpoints = bmgc_relation_ppi_map_df[ppi_endpoint_cols]\n",
    "is_resolved = ppi_endpoints.notna().all(axis=1) & ppi_endpoints.ne('Unknown').all(axis=1)\n",
    "print(f\"Number of relations dropped because an endpoint has no HGNC_Symbol: {int((~is_resolved).sum())}\")\n",
    "bmgc_relation_ppi_map_df = bmgc_relation_ppi_map_df[is_resolved].reset_index(drop=True)\n",
    "# Check if there is any Unknown string in either BMGC_From_ID or BMGC_To_ID\n",
    "unknown_from_id = bmgc_relation_ppi_map_df[bmgc_relation_ppi_map_df['BMGC_From_ID'] == 'Unknown']\n",
    "unknown_to_id = bmgc_relation_ppi_map_df[bmgc_relation_ppi_map_df['BMGC_To_ID'] == 'Unknown']\n",
    "# Print the number of Unknown entries in BMGC_From_ID and BMGC_To_ID\n",
    "print(f\"Number of Unknown entries in BMGC_From_ID: {len(unknown_from_id)}\")\n",
    "print(f\"Number of Unknown entries in BMGC_To_ID: {len(unknown_to_id)}\")\n",
    "display(bmgc_relation_ppi_map_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "871206a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Append every relation once more with its endpoints swapped, so both directions are present\n",
    "swap_endpoints = {'BMGC_From_ID': 'BMGC_To_ID', 'BMGC_To_ID': 'BMGC_From_ID'}\n",
    "bmgc_relation_ppi_map_df_reversed = bmgc_relation_ppi_map_df.rename(columns=swap_endpoints)\n",
    "bmgc_relation_ppi_map_concat_df = pd.concat(\n",
    "    [bmgc_relation_ppi_map_df, bmgc_relation_ppi_map_df_reversed],\n",
    "    ignore_index=True,\n",
    ")\n",
    "display(bmgc_relation_ppi_map_concat_df)\n",
    "# Keep a single row per (BMGC_From_ID, BMGC_To_ID) pair; the first occurrence is the one retained\n",
    "bmgc_relation_ppi_map_concat_df = (\n",
    "    bmgc_relation_ppi_map_concat_df\n",
    "    .drop_duplicates(subset=['BMGC_From_ID', 'BMGC_To_ID'])\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "display(bmgc_relation_ppi_map_concat_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e182e2c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _semicolons_to_commas(partners):\n",
    "    \"\"\"Return the partner list as strings with every ';' rewritten as ', '.\"\"\"\n",
    "    return [str(partner).replace(';', ', ') for partner in partners]\n",
    "\n",
    "# For every symbol in 'BMGC_From_ID', collect its 'BMGC_To_ID' partners into a list ('Source' and 'Type' are not needed)\n",
    "ppi_pairs = bmgc_relation_ppi_map_concat_df[['BMGC_From_ID', 'BMGC_To_ID']]\n",
    "bmgc_relation_ppi_map_concat_group_df = (\n",
    "    ppi_pairs\n",
    "    .groupby('BMGC_From_ID')['BMGC_To_ID']\n",
    "    .apply(list)\n",
    "    .reset_index()\n",
    ")\n",
    "bmgc_relation_ppi_map_concat_group_df['BMGC_To_ID'] = (\n",
    "    bmgc_relation_ppi_map_concat_group_df['BMGC_To_ID'].apply(_semicolons_to_commas)\n",
    ")\n",
    "display(bmgc_relation_ppi_map_concat_group_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f0d0575",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach the partner list to every gene row; the left join keeps genes that have no relation\n",
    "bmgc_gene_llm_relation_df = (\n",
    "    bmgc_gene_llm_df\n",
    "    .merge(bmgc_relation_ppi_map_concat_group_df, left_on='HGNC_Symbol', right_on='BMGC_From_ID', how='left')\n",
    "    .drop(columns=['BMGC_From_ID'])\n",
    ")\n",
    "display(bmgc_gene_llm_relation_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f028ec6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference only: target sentence template for one gene record.\n",
    "# The placeholders are column names of bmgc_gene_llm_relation_df.\n",
    "{\n",
    "\"text\": \"{HGNC_Symbol} short for {Gene_Name} is a {Gene_Type} gene located on Chromosome {Chromosome} from {Gene_Start} to {Gene_End}. The Ensembl Gene ID is {Ensembl_Gene_ID}, also {Ensembl_Gene_ID_Version} and the NCBI Gene ID is {NCBI_Gene_ID}. In details, {Gene_Name} has the NCBI Gene description with {NCBI_Gene}. Also, it has the Ensembl description with {Ensembl}. Aside from that, {HGNC_Symbol} is related to the following genes: {BMGC_To_ID}.\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03b1d286",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "# Function to create text description while handling NaN values and gene relationships\n",
    "def create_text_description(row):\n",
    "    # Start with HGNC_Symbol which we know exists\n",
    "    text = f\"{row['HGNC_Symbol']}\"\n",
    "    \n",
    "    # Add Gene_Name if available\n",
    "    if pd.notna(row['Gene_Name']):\n",
    "        text += f\" short for {row['Gene_Name']}\"\n",
    "    \n",
    "    # Add Gene_Type if available\n",
    "    if pd.notna(row['Gene_Type']):\n",
    "        text += f\" is a {row['Gene_Type']} gene\"\n",
    "    else:\n",
    "        text += \" is a gene\"\n",
    "    \n",
    "    # Add Chromosome, Gene_Start, Gene_End if available\n",
    "    if pd.notna(row['Chromosome']):\n",
    "        text += f\" located on Chromosome {row['Chromosome']}\"\n",
    "        \n",
    "        if pd.notna(row['Gene_Start']) and pd.notna(row['Gene_End']):\n",
    "            text += f\" from {row['Gene_Start']} to {row['Gene_End']}\"\n",
    "    \n",
    "    # Add Ensembl Gene IDs if available\n",
    "    if pd.notna(row['Ensembl_Gene_ID']):\n",
    "        text += f\". The Ensembl Gene ID is {row['Ensembl_Gene_ID']}\"\n",
    "        if pd.notna(row['Ensembl_Gene_ID_Version']):\n",
    "            text += f\", also {row['Ensembl_Gene_ID_Version']}\"\n",
    "    \n",
    "    # Add NCBI Gene ID if available\n",
    "    if pd.notna(row['NCBI_Gene_ID']):\n",
    "        text += f\" and the NCBI Gene ID is {row['NCBI_Gene_ID']}\"\n",
    "    \n",
    "    # Add detailed descriptions if available\n",
    "    if pd.notna(row['Gene_Name']):\n",
    "        if pd.notna(row['NCBI_Gene']):\n",
    "            text += f\". In details, {row['Gene_Name']} has the NCBI Gene description with {row['NCBI_Gene']}\"\n",
    "        \n",
    "        if pd.notna(row['Ensembl']):\n",
    "            text += f\". Also, it has the Ensembl description with {row['Ensembl']}\"\n",
    "    \n",
    "    # Add related genes if available - Fixed to properly check if BMGC_To_ID exists and is not None/NaN\n",
    "    if 'BMGC_To_ID' in row and row['BMGC_To_ID'] is not None and isinstance(row['BMGC_To_ID'], list) and len(row['BMGC_To_ID']) > 0:\n",
    "        # Format the list of related genes as a string\n",
    "        related_genes = ', '.join(row['BMGC_To_ID'])\n",
    "        text += f\". Aside from that, {row['HGNC_Symbol']} is related to the following genes: {related_genes}\"\n",
    "    \n",
    "    # Add period at the end if needed\n",
    "    if not text.endswith('.'):\n",
    "        text += \".\"\n",
    "        \n",
    "    return text\n",
    "\n",
    "# Make sure the target directory exists before anything is written\n",
    "output_dir = './data/TargetPretrain'\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "# Destination of the gene descriptions (JSON Lines: one {\"text\": ...} object per line)\n",
    "output_path = os.path.join(output_dir, 'gene_relation_description.jsonl')\n",
    "\n",
    "# Stream one JSON line per gene row into the file\n",
    "with open(output_path, 'w', encoding='utf-8') as f:\n",
    "    f.writelines(\n",
    "        json.dumps({\"text\": create_text_description(gene_row)}, ensure_ascii=False) + '\\n'\n",
    "        for _, gene_row in bmgc_gene_llm_relation_df.iterrows()\n",
    "    )\n",
    "\n",
    "print(f\"JSONL file created with {len(bmgc_gene_llm_relation_df)} gene descriptions at {os.path.abspath(output_path)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "969ecafc",
   "metadata": {},
   "source": [
    "## 2. Disease Node Description"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed8a6ed6",
   "metadata": {},
   "source": [
    "### 2.1 Disease Entity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5e1fa45",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disease entity table plus its description and display-name tables; every column is read as a string\n",
    "bmgc_disease_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Disease/BioMedGraphica_Conn_Disease.csv', dtype=str)\n",
    "bmgc_disease_desc_df = pd.read_csv(\n",
    "    './data/BioMedGraphica-Conn/Entity/Disease/BioMedGraphica_Conn_Disease_Description.csv', dtype=str\n",
    ").drop(columns=['BioMedGraphica_ID'])\n",
    "bmgc_disease_display_df = pd.read_csv(\n",
    "    './data/BioMedGraphica-Conn/Entity/Disease/BioMedGraphica_Conn_Disease_Display_Name.csv', dtype=str\n",
    ").drop(columns=['BioMedGraphica_ID'])\n",
    "\n",
    "# Columns kept for the disease text: the connection ID, the display name, the per-database names and IDs,\n",
    "# and the free-text description columns\n",
    "disease_llm_columns = [\n",
    "    'BioMedGraphica_Conn_ID', 'BMG_Disease_Name',\n",
    "    'MONDO_Name', 'MONDO_ID', 'UMLS_Name', 'UMLS_ID', 'DO_Name', 'DO_ID',\n",
    "    'SNOMEDCT_Name', 'SNOMEDCT_ID', 'MeSH_Name', 'MeSH_ID',\n",
    "    'ICD11_Title', 'ICD11_ID', 'ICD10_ID', 'OMIM_ID',\n",
    "    'MONDO', 'MeSH', 'NCI', 'SNOMEDCT_US', 'ORPHANET', 'HPO',\n",
    "]\n",
    "bmgc_disease_llm_df = (\n",
    "    bmgc_disease_df\n",
    "    .merge(bmgc_disease_desc_df, how='left', on='BioMedGraphica_Conn_ID')\n",
    "    .merge(bmgc_disease_display_df, how='left', on='BioMedGraphica_Conn_ID')\n",
    "    .loc[:, disease_llm_columns]\n",
    ")\n",
    "display(bmgc_disease_llm_df.head(2))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb3e4ce9",
   "metadata": {},
   "source": [
    "### 2.2 Disease Protein Relation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43fbc8ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select the rows of type ['Protein-Disease', 'Disease-Protein'] and keep the four columns used downstream\n",
    "relation_columns = ['BMGC_From_ID', 'BMGC_To_ID', 'Source', 'Type']\n",
    "is_protein_disease = bmgc_relation_df['Type'].isin(['Protein-Disease', 'Disease-Protein'])\n",
    "bmgc_relation_disease_df = bmgc_relation_df.loc[is_protein_disease, relation_columns].reset_index(drop=True)\n",
    "# The mapping and grouping below expect protein (BMGC_From_ID) -> disease (BMGC_To_ID). 'Disease-Protein' rows\n",
    "# are assumed to be stored the other way round (the 'A-B' types used earlier run From=A, To=B), so their\n",
    "# endpoints are swapped. .to_numpy() prevents pandas from re-aligning the swapped values by column name.\n",
    "is_disease_first = bmgc_relation_disease_df['Type'] == 'Disease-Protein'\n",
    "if is_disease_first.any():\n",
    "    bmgc_relation_disease_df.loc[is_disease_first, ['BMGC_From_ID', 'BMGC_To_ID']] = (\n",
    "        bmgc_relation_disease_df.loc[is_disease_first, ['BMGC_To_ID', 'BMGC_From_ID']].to_numpy()\n",
    "    )\n",
    "# Map the BMGC_From_ID to hgnc_symbol in bmgc_relation_disease_df['BMGC_From_ID'] column\n",
    "bmgc_relation_disease_df['BMGC_From_ID'] = bmgc_relation_disease_df['BMGC_From_ID'].map(protein_hgnc_map_dict)\n",
    "# Drop the rows whose protein has no symbol: 'Unknown' placeholders and the NaN that Series.map\n",
    "# returns for IDs missing from the dictionary\n",
    "has_symbol = bmgc_relation_disease_df['BMGC_From_ID'].notna() & (bmgc_relation_disease_df['BMGC_From_ID'] != 'Unknown')\n",
    "# Several proteins or sources can yield the same gene-disease pair; keep each pair once\n",
    "bmgc_relation_disease_df = (\n",
    "    bmgc_relation_disease_df[has_symbol]\n",
    "    .drop_duplicates(subset=['BMGC_From_ID', 'BMGC_To_ID'])\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "display(bmgc_relation_disease_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81c62036",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every disease ID in 'BMGC_To_ID', gather the 'BMGC_From_ID' values into a list ('Source' and 'Type' are not needed)\n",
    "bmgc_relation_disease_group_df = (\n",
    "    bmgc_relation_disease_df[['BMGC_From_ID', 'BMGC_To_ID']]\n",
    "    .groupby('BMGC_To_ID')['BMGC_From_ID']\n",
    "    .apply(list)\n",
    "    .reset_index()\n",
    ")\n",
    "display(bmgc_relation_disease_group_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de771684",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach the gene list to every disease row; the left join keeps diseases that have no relation\n",
    "bmgc_disease_llm_relation_df = (\n",
    "    bmgc_disease_llm_df\n",
    "    .merge(bmgc_relation_disease_group_df, left_on='BioMedGraphica_Conn_ID', right_on='BMGC_To_ID', how='left')\n",
    "    .drop(columns=['BMGC_To_ID'])\n",
    ")\n",
    "display(bmgc_disease_llm_relation_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1630d1bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _upper_first(text):\n",
    "    \"\"\"Upper-case the first character of `text` and leave every other character untouched.\n",
    "\n",
    "    str.capitalize() is deliberately not used: it also lower-cases the rest of the\n",
    "    string, which would turn acronyms such as 'HIV' into 'Hiv'.\n",
    "    \"\"\"\n",
    "    return text[:1].upper() + text[1:]\n",
    "\n",
    "\n",
    "def _normalize_disease_name(name):\n",
    "    \"\"\"Capitalize a disease display name and replace its commas with spaces.\n",
    "\n",
    "    The name is split on ',', every piece is stripped and the non-empty pieces are\n",
    "    joined with single spaces. When the name contains ', ' each piece gets an\n",
    "    upper-case initial; otherwise only the first piece does. Missing values are\n",
    "    returned unchanged.\n",
    "    \"\"\"\n",
    "    if pd.isna(name):\n",
    "        return name\n",
    "    raw = str(name)\n",
    "    pieces = [piece.strip() for piece in raw.split(',')]\n",
    "    if ', ' in raw:\n",
    "        pieces = [_upper_first(piece) for piece in pieces]\n",
    "    else:\n",
    "        pieces[0] = _upper_first(pieces[0])\n",
    "    return ' '.join(piece for piece in pieces if piece)\n",
    "\n",
    "\n",
    "# For column \"BMG_Disease_Name\": capitalize the initial letter(s) and remove the commas,\n",
    "# preserving the capitalization of all other characters; NaN values stay NaN\n",
    "bmgc_disease_llm_relation_df['BMG_Disease_Name'] = bmgc_disease_llm_relation_df['BMG_Disease_Name'].apply(_normalize_disease_name)\n",
    "# Replace the ' | ' with ' or ' in string columns that might contain it\n",
    "# Non-string values (NaN) are passed through to avoid AttributeError\n",
    "for col in ['BMG_Disease_Name', 'MONDO_Name', 'UMLS_Name', 'DO_Name', 'SNOMEDCT_Name', 'MeSH_Name', 'ICD11_Title']:\n",
    "    bmgc_disease_llm_relation_df[col] = bmgc_disease_llm_relation_df[col].apply(\n",
    "        lambda x: x.replace(' | ', ' or ') if pd.notna(x) and isinstance(x, str) else x\n",
    "    )\n",
    "\n",
    "display(bmgc_disease_llm_relation_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97926f47",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference only: target sentence template used to convert each DataFrame row into a text description.\n",
    "# The placeholders are column names of bmgc_disease_llm_relation_df.\n",
    "{\n",
    "    \"text\": \"{BMG_Disease_Name} is a disease, which has been recorded in MONDO with MONDO Name {MONDO_Name} and {MONDO_ID}. It is also recorded in UMLS with UMLS ID {UMLS_ID} and name with {UMLS_Name}. In addition, it is recorded in Disease Ontology (DO) with {DO_ID} and name with {DO_Name}. The disease is also recorded in SNOMEDCT with SNOMEDCT ID with {SNOMEDCT_ID} and name with {SNOMEDCT_Name}. It is also recorded in MeSH with MeSH ID {MeSH_ID} and {MeSH_Name}. The disease is also recorded in ICD11 with {ICD11_ID}, named as {ICD11_Title} and ICD10 with {ICD10_ID}. The disease is also recorded in OMIM with OMIM ID {OMIM_ID}. In details, the disease {BMG_Disease_Name} has the MONDO description with: {MONDO} {BMG_Disease_Name} also has MeSH description with: {MeSH} NCI description with: {NCI} SNOMEDCT_US description with: {SNOMEDCT_US} ORPHANET description with: {ORPHANET} and HPO description with: {HPO} Aside from that, {BMG_Disease_Name} is related to the following genes: {BMGC_From_ID}.\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3c6ffd3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "# Function to create text description for diseases following the specified template\n",
    "def create_disease_description(row):\n",
    "    # Start with disease name - use alternative names if BMG_Disease_Name is NaN\n",
    "    if pd.notna(row['BMG_Disease_Name']):\n",
    "        disease_name = row['BMG_Disease_Name']\n",
    "    else:\n",
    "        # Try to get name from other sources in order of preference\n",
    "        name_sources = ['MONDO_Name', 'UMLS_Name', 'DO_Name', 'SNOMEDCT_Name', 'MeSH_Name']\n",
    "        disease_name = None\n",
    "        for source in name_sources:\n",
    "            if pd.notna(row[source]):\n",
    "                disease_name = row[source]\n",
    "                break\n",
    "        \n",
    "        # If still no name found, return None to skip this entry\n",
    "        if disease_name is None:\n",
    "            return None\n",
    "    \n",
    "    # Start building the text following the template\n",
    "    text = f\"{disease_name} is a disease\"\n",
    "    \n",
    "    # Add MONDO information\n",
    "    if pd.notna(row['MONDO_ID']) or pd.notna(row['MONDO_Name']):\n",
    "        text += \", which has been recorded in MONDO\"\n",
    "        if pd.notna(row['MONDO_Name']):\n",
    "            text += f\" with MONDO Name {row['MONDO_Name']}\"\n",
    "        if pd.notna(row['MONDO_ID']):\n",
    "            text += f\" and {row['MONDO_ID']}\"\n",
    "    \n",
    "    # Add UMLS information\n",
    "    if pd.notna(row['UMLS_ID']) or pd.notna(row['UMLS_Name']):\n",
    "        text += \". It is also recorded in UMLS\"\n",
    "        if pd.notna(row['UMLS_ID']):\n",
    "            text += f\" with UMLS ID {row['UMLS_ID']}\"\n",
    "        if pd.notna(row['UMLS_Name']):\n",
    "            text += f\" and name with {row['UMLS_Name']}\"\n",
    "    \n",
    "    # Add Disease Ontology information\n",
    "    if pd.notna(row['DO_ID']) or pd.notna(row['DO_Name']):\n",
    "        text += \". In addition, it is recorded in Disease Ontology (DO)\"\n",
    "        if pd.notna(row['DO_ID']):\n",
    "            text += f\" with {row['DO_ID']}\"\n",
    "        if pd.notna(row['DO_Name']):\n",
    "            text += f\" and name with {row['DO_Name']}\"\n",
    "    \n",
    "    # Add SNOMEDCT information\n",
    "    if pd.notna(row['SNOMEDCT_ID']) or pd.notna(row['SNOMEDCT_Name']):\n",
    "        text += \". The disease is also recorded in SNOMEDCT\"\n",
    "        if pd.notna(row['SNOMEDCT_ID']):\n",
    "            text += f\" with SNOMEDCT ID with {row['SNOMEDCT_ID']}\"\n",
    "        if pd.notna(row['SNOMEDCT_Name']):\n",
    "            text += f\" and name with {row['SNOMEDCT_Name']}\"\n",
    "    \n",
    "    # Add MeSH information\n",
    "    if pd.notna(row['MeSH_ID']) or pd.notna(row['MeSH_Name']):\n",
    "        text += \". It is also recorded in MeSH\"\n",
    "        if pd.notna(row['MeSH_ID']):\n",
    "            text += f\" with MeSH ID {row['MeSH_ID']}\"\n",
    "        if pd.notna(row['MeSH_Name']):\n",
    "            text += f\" and {row['MeSH_Name']}\"\n",
    "    \n",
    "    # Add ICD information\n",
    "    if pd.notna(row['ICD11_ID']) or pd.notna(row['ICD11_Title']) or pd.notna(row['ICD10_ID']):\n",
    "        text += \". The disease is also recorded\"\n",
    "        if pd.notna(row['ICD11_ID']) or pd.notna(row['ICD11_Title']):\n",
    "            text += \" in ICD11\"\n",
    "            if pd.notna(row['ICD11_ID']):\n",
    "                text += f\" with {row['ICD11_ID']}\"\n",
    "            if pd.notna(row['ICD11_Title']):\n",
    "                text += f\", named as {row['ICD11_Title']}\"\n",
    "        \n",
    "        if pd.notna(row['ICD10_ID']):\n",
    "            if pd.notna(row['ICD11_ID']) or pd.notna(row['ICD11_Title']):\n",
    "                text += \" and\"\n",
    "            text += f\" ICD10 with {row['ICD10_ID']}\"\n",
    "    \n",
    "    # Add OMIM information\n",
    "    if pd.notna(row['OMIM_ID']):\n",
    "        text += f\". The disease is also recorded in OMIM with OMIM ID {row['OMIM_ID']}\"\n",
    "    \n",
    "    # Add detailed descriptions\n",
    "    descriptions = []\n",
    "    \n",
    "    # MONDO description\n",
    "    if pd.notna(row['MONDO']):\n",
    "        descriptions.append(f\"the disease {disease_name} has the MONDO description with: {row['MONDO']}\")\n",
    "    \n",
    "    # MeSH description\n",
    "    if pd.notna(row['MeSH']):\n",
    "        descriptions.append(f\"{disease_name} also has MeSH description with: {row['MeSH']}\")\n",
    "    \n",
    "    # NCI description\n",
    "    if pd.notna(row['NCI']):\n",
    "        descriptions.append(f\"NCI description with: {row['NCI']}\")\n",
    "    \n",
    "    # SNOMEDCT_US description\n",
    "    if pd.notna(row['SNOMEDCT_US']):\n",
    "        descriptions.append(f\"SNOMEDCT_US description with: {row['SNOMEDCT_US']}\")\n",
    "    \n",
    "    # ORPHANET description\n",
    "    if pd.notna(row['ORPHANET']):\n",
    "        descriptions.append(f\"ORPHANET description with: {row['ORPHANET']}\")\n",
    "    \n",
    "    # HPO description\n",
    "    if pd.notna(row['HPO']):\n",
    "        descriptions.append(f\"HPO description with: {row['HPO']}\")\n",
    "    \n",
    "    # Add descriptions if any exist\n",
    "    if descriptions:\n",
    "        text += \". In details, \" + \" \".join(descriptions)\n",
    "    \n",
    "    # Add related genes if available\n",
    "    if 'BMGC_From_ID' in row and isinstance(row['BMGC_From_ID'], list) and len(row['BMGC_From_ID']) > 0:\n",
    "        # Convert all elements to strings, filtering out any NaN values\n",
    "        valid_genes = [str(gene) for gene in row['BMGC_From_ID'] \n",
    "                      if pd.notna(gene) and not (isinstance(gene, float) and np.isnan(gene))]\n",
    "        \n",
    "        if valid_genes:  # Only proceed if there are valid genes\n",
    "            related_genes = ', '.join(valid_genes)\n",
    "            text += f\". Aside from that, {disease_name} is related to the following genes: {related_genes}\"\n",
    "    \n",
    "    # Add period at the end if needed\n",
    "    if not text.endswith('.'):\n",
    "        text += \".\"\n",
    "    \n",
    "    return text\n",
    "\n",
    "# Destination of the disease descriptions (JSON Lines: one {\"text\": ...} object per line)\n",
    "disease_output_path = os.path.join(output_dir, 'disease_relation_description.jsonl')\n",
    "\n",
    "# Counters for the rows that were written and the rows skipped for lack of a disease name\n",
    "written_count = 0\n",
    "skipped_count = 0\n",
    "\n",
    "with open(disease_output_path, 'w', encoding='utf-8') as f:\n",
    "    for _, disease_row in bmgc_disease_llm_relation_df.iterrows():\n",
    "        description = create_disease_description(disease_row)\n",
    "        if description is None:\n",
    "            skipped_count += 1\n",
    "            continue\n",
    "        f.write(json.dumps({\"text\": description}, ensure_ascii=False) + '\\n')\n",
    "        written_count += 1\n",
    "\n",
    "print(f\"JSONL file created with {written_count} disease descriptions at {os.path.abspath(disease_output_path)}\")\n",
    "print(f\"Skipped {skipped_count} entries due to missing disease names\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ff39e28",
   "metadata": {},
   "source": [
    "## 3. Mixed pretraining data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdc16273",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import random\n",
    "import os\n",
    "\n",
    "def _read_jsonl(path):\n",
    "    \"\"\"Parse every line of the UTF-8 JSON Lines file at `path` and return the records as a list.\"\"\"\n",
    "    with open(path, 'r', encoding='utf-8') as f:\n",
    "        return [json.loads(line) for line in f]\n",
    "\n",
    "# Load the gene and the disease descriptions written by the previous sections\n",
    "gene_file_path = './data/TargetPretrain/gene_relation_description.jsonl'\n",
    "gene_data = _read_jsonl(gene_file_path)\n",
    "\n",
    "disease_file_path = './data/TargetPretrain/disease_relation_description.jsonl'\n",
    "disease_data = _read_jsonl(disease_file_path)\n",
    "\n",
    "# Genes first, diseases second, then shuffle with a fixed seed so the order is reproducible\n",
    "combined_data = [*gene_data, *disease_data]\n",
    "random.seed(42)\n",
    "random.shuffle(combined_data)\n",
    "\n",
    "# Save the mixed data as JSONL\n",
    "output_path = './data/TargetPretrain/mixed_description.jsonl'\n",
    "with open(output_path, 'w', encoding='utf-8') as f:\n",
    "    f.writelines(json.dumps(item, ensure_ascii=False) + '\\n' for item in combined_data)\n",
    "\n",
    "print(f\"Mixed data file created with {len(combined_data)} total descriptions\")\n",
    "print(f\"- {len(gene_data)} gene descriptions\")\n",
    "print(f\"- {len(disease_data)} disease descriptions\")\n",
    "print(f\"Output saved to: {os.path.abspath(output_path)}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mkg",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
