{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Methylation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_methy_df = pd.read_csv('./data/raw_data/CCLE_RRBS_TSS1kb_20181022.txt', sep='\\t')\n",
    "display(raw_methy_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the part of locus_id before the first '_'.\n",
    "# The vectorised .str accessor leaves missing values as NaN, whereas calling\n",
    "# x.split() inside apply() raises AttributeError on a non-string value.\n",
    "raw_methy_df['locus_id'] = raw_methy_df['locus_id'].str.split('_', n=1).str[0]\n",
    "# Drop the rows that have no value in the column CpG_sites_hg19.\n",
    "raw_methy_df = raw_methy_df.dropna(subset=['CpG_sites_hg19'])\n",
    "# Remove the columns CpG_sites_hg19 and avg_coverage.\n",
    "raw_methy_df = raw_methy_df.drop(columns=['CpG_sites_hg19', 'avg_coverage'])\n",
    "display(raw_methy_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Identify columns to convert (all except 'locus_id')\n",
    "cols_to_convert = raw_methy_df.columns.difference(['locus_id'])\n",
    "# Convert these columns to numeric (using pd.to_numeric for safety, which can handle errors)\n",
    "raw_methy_df[cols_to_convert] = raw_methy_df[cols_to_convert].apply(pd.to_numeric, errors='coerce')\n",
    "raw_methy_df = raw_methy_df.fillna(0.0)\n",
    "display(raw_methy_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Now perform the groupby mean operation\n",
    "methy_df = raw_methy_df.groupby('locus_id', as_index=False).mean()\n",
    "display(methy_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bmg_promoter_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Promoter/BioMedGraphica_Conn_Promoter.csv')\n",
    "# keep BioMedGraphica_Conn_ID and HGNC_Symbol\n",
    "bmg_promoter_df = bmg_promoter_df[['BioMedGraphica_Conn_ID', 'HGNC_Symbol']]\n",
    "display(bmg_promoter_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# merge the biomedgraphica_id with the raw_methy_df\n",
    "merged_methy_df = pd.merge(bmg_promoter_df, methy_df, left_on='HGNC_Symbol', right_on='locus_id', how='inner')\n",
    "merged_methy_df.drop(columns=['HGNC_Symbol', 'locus_id'], inplace=True)\n",
    "display(merged_methy_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "final_merged_methy_df = pd.merge(bmg_promoter_df, methy_df, left_on='HGNC_Symbol', right_on='locus_id', how='left')\n",
    "# fill the NaN with 0.0\n",
    "final_merged_methy_df = final_merged_methy_df.fillna(0.0)\n",
    "final_merged_methy_df.drop(columns=['HGNC_Symbol', 'locus_id'], inplace=True)\n",
    "display(final_merged_methy_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Gene"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_gene_df = pd.read_csv('./data/raw_data/OmicsCNGene.csv')\n",
    "display(raw_gene_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use \"Unnamed: 0\" as the index, transpose, then turn the new row labels\n",
    "# (the former column headers) into a regular column called gene_name.\n",
    "raw_gene_t_df = (\n",
    "    raw_gene_df\n",
    "    .set_index(\"Unnamed: 0\")\n",
    "    .transpose()\n",
    "    .reset_index()\n",
    "    .rename(columns={'index': 'gene_name'})\n",
    ")\n",
    "# Keep only the text before the first \"(\" of each label and trim surrounding whitespace.\n",
    "raw_gene_t_df['gene_name'] = [label.split('(')[0].strip() for label in raw_gene_t_df['gene_name']]\n",
    "display(raw_gene_t_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Identify columns to convert (all except 'Unnamed: 0' and 'gene_name')\n",
    "cols_to_convert = raw_gene_t_df.columns.difference(['Unnamed: 0', 'gene_name'])\n",
    "# Convert these columns to numeric (using pd.to_numeric for safety, which can handle errors)\n",
    "raw_gene_t_df[cols_to_convert] = raw_gene_t_df[cols_to_convert].apply(pd.to_numeric, errors='coerce')\n",
    "raw_gene_t_df = raw_gene_t_df.fillna(0.0)\n",
    "display(raw_gene_t_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# groupby the gene_name and calculate the mean\n",
    "gene_df = raw_gene_t_df.groupby('gene_name', as_index=False).mean()\n",
    "display(gene_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bmg_gene_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Gene/BioMedGraphica_Conn_Gene.csv')\n",
    "# keep BioMedGraphica_Conn_ID and HGNC_Symbol\n",
    "bmg_gene_df = bmg_gene_df[['BioMedGraphica_Conn_ID', 'HGNC_Symbol']]\n",
    "display(bmg_gene_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# merge the biomedgraphica_id with the raw_gene_t_df\n",
    "merged_gene_df = pd.merge(bmg_gene_df, gene_df, left_on='HGNC_Symbol', right_on='gene_name', how='inner')\n",
    "merged_gene_df.drop(columns=['HGNC_Symbol', 'gene_name'], inplace=True)\n",
    "display(merged_gene_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "final_merged_gene_df = pd.merge(bmg_gene_df, gene_df, left_on='HGNC_Symbol', right_on='gene_name', how='left')\n",
    "# fill the NaN with 0.0\n",
    "final_merged_gene_df = final_merged_gene_df.fillna(0.0)\n",
    "final_merged_gene_df.drop(columns=['HGNC_Symbol', 'gene_name'], inplace=True)\n",
    "display(final_merged_gene_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Transcript"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_transcript_df = pd.read_csv('./data/raw_data/OmicsExpressionProteinCodingGenesTPMLogp1BatchCorrected.csv')\n",
    "display(raw_transcript_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First, set the \"Unnamed: 0\" column as the index:\n",
    "raw_transcript_df = raw_transcript_df.set_index(\"Unnamed: 0\")\n",
    "# Then transpose the DataFrame:\n",
    "raw_transcript_df = raw_transcript_df.transpose()\n",
    "# move the index to a column and rename it as the gene_name\n",
    "raw_transcript_df.reset_index(inplace=True)\n",
    "raw_transcript_df.rename(columns={'index': 'gene_name'}, inplace=True)\n",
    "# keep only the part of gene_name before the first \"(\" and strip the surrounding whitespace\n",
    "raw_transcript_df['gene_name'] = raw_transcript_df['gene_name'].apply(lambda x: x.split('(')[0].strip())\n",
    "# Optionally, if you want to view the result:\n",
    "display(raw_transcript_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Identify columns to convert (all except 'Unnamed: 0' and 'gene_name')\n",
    "cols_to_convert = raw_transcript_df.columns.difference(['Unnamed: 0', 'gene_name'])\n",
    "# Convert these columns to numeric (using pd.to_numeric for safety, which can handle errors)\n",
    "raw_transcript_df[cols_to_convert] = raw_transcript_df[cols_to_convert].apply(pd.to_numeric, errors='coerce')\n",
    "raw_transcript_df = raw_transcript_df.fillna(0.0)\n",
    "display(raw_transcript_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bmg_transcript_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Transcript/BioMedGraphica_Conn_Transcript.csv')\n",
    "# keep BioMedGraphica_Conn_ID and HGNC_Symbol\n",
    "bmg_transcript_df = bmg_transcript_df[['BioMedGraphica_Conn_ID', 'HGNC_Symbol']]\n",
    "display(bmg_transcript_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# merge the biomedgraphica_id with the raw_transcript_df\n",
    "merge_transcript_df = pd.merge(bmg_transcript_df, raw_transcript_df, left_on='HGNC_Symbol', right_on='gene_name', how='inner')\n",
    "merge_transcript_df.drop(columns=['HGNC_Symbol', 'gene_name'], inplace=True)\n",
    "display(merge_transcript_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# merge the biomedgraphica_id with the raw_transcript_df\n",
    "final_merged_transcript_df = pd.merge(bmg_transcript_df, raw_transcript_df, left_on='HGNC_Symbol', right_on='gene_name', how='left')\n",
    "# fill the NaN with 0.0\n",
    "final_merged_transcript_df = final_merged_transcript_df.fillna(0.0)\n",
    "final_merged_transcript_df.drop(columns=['HGNC_Symbol', 'gene_name'], inplace=True)\n",
    "display(final_merged_transcript_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Protein"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_protein_df = pd.read_csv('./data/raw_data/protein_quant_current_normalized.csv')\n",
    "# Remove the descriptive columns listed here.\n",
    "metadata_columns = ['Protein_Id', 'Gene_Symbol', 'Description', 'Group_ID', 'Uniprot']\n",
    "raw_protein_df = raw_protein_df.drop(columns=metadata_columns)\n",
    "# Remove every column whose name matches the regex 'Peptides'.\n",
    "peptide_columns = raw_protein_df.filter(regex='Peptides').columns\n",
    "raw_protein_df = raw_protein_df.drop(columns=peptide_columns)\n",
    "display(raw_protein_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Identify columns to convert (all except 'Uniprot_Acc')\n",
    "cols_to_convert = raw_protein_df.columns.difference(['Uniprot_Acc'])\n",
    "# Convert these columns to numeric (using pd.to_numeric for safety, which can handle errors)\n",
    "raw_protein_df[cols_to_convert] = raw_protein_df[cols_to_convert].apply(pd.to_numeric, errors='coerce')\n",
    "raw_protein_df = raw_protein_df.fillna(0.0)\n",
    "display(raw_protein_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# groupby the protein_name and calculate the mean\n",
    "protein_df = raw_protein_df.groupby('Uniprot_Acc', as_index=False).mean()\n",
    "display(protein_df)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the BioMedGraphica protein entities and keep BioMedGraphica_Conn_ID and Uniprot_ID\n",
    "bmg_protein_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Protein/BioMedGraphica_Conn_Protein.csv')\n",
    "bmg_protein_df = bmg_protein_df[['BioMedGraphica_Conn_ID', 'Uniprot_ID']]\n",
    "display(bmg_protein_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# merge the biomedgraphica_id with the raw_protein_df\n",
    "merged_protein_df = pd.merge(bmg_protein_df, protein_df, left_on='Uniprot_ID', right_on='Uniprot_Acc', how='inner')\n",
    "merged_protein_df.drop(columns=['Uniprot_ID', 'Uniprot_Acc'], inplace=True)\n",
    "display(merged_protein_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# merge the biomedgraphica_id with the raw_protein_df\n",
    "final_merged_protein_df = pd.merge(bmg_protein_df, protein_df, left_on='Uniprot_ID', right_on='Uniprot_Acc', how='left')\n",
    "# fill the NaN with 0.0\n",
    "final_merged_protein_df = final_merged_protein_df.fillna(0.0)\n",
    "final_merged_protein_df.drop(columns=['Uniprot_ID', 'Uniprot_Acc'], inplace=True)\n",
    "display(final_merged_protein_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Drug"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_drug_df = pd.read_csv('./data/raw_data/sanger-dose-response.csv')\n",
    "display(raw_drug_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# keep columns ['ARXSPAN_ID', 'DRUG_NAME', 'IC50_PUBLISHED', 'AUC_PUBLISHED']\n",
    "raw_drug_df = raw_drug_df[['ARXSPAN_ID', 'DRUG_NAME', 'IC50_PUBLISHED', 'AUC_PUBLISHED']]\n",
    "# check if there is nan in all of the dataframe\n",
    "print(raw_drug_df.isnull().sum())\n",
    "# drop the nan in this dataframe\n",
    "raw_drug_df = raw_drug_df.dropna().reset_index(drop=True)\n",
    "# check if there is nan in all of the dataframe again\n",
    "print(raw_drug_df.isnull().sum())\n",
    "display(raw_drug_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fetch the drug_name as a independent dataframe\n",
    "drug_name_df = raw_drug_df[['DRUG_NAME']]\n",
    "# remove the duplicate in the drug_name_df\n",
    "drug_name_df = drug_name_df.drop_duplicates().reset_index(drop=True)\n",
    "display(drug_name_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bmg_drug_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Drug/BioMedGraphica_Conn_Drug.csv')\n",
    "display(bmg_drug_df)\n",
    "# Keep BioMedGraphica_Conn_ID together with every name / synonym column.\n",
    "# .copy() makes this an independent frame, so adding columns to it later does not\n",
    "# raise pandas' SettingWithCopyWarning.\n",
    "drug_name_columns = ['BioMedGraphica_Conn_ID', 'PubChem_Name', 'IUPAC_Name', 'UNII_Name', 'DrugBank_Name', 'PubChem_Synonym']\n",
    "bmg_drug_name_df = bmg_drug_df[drug_name_columns].copy()\n",
    "display(bmg_drug_name_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combine 'PubChem_Name', 'IUPAC_Name', 'UNII_Name', 'DrugBank_Name', 'PubChem_Synonym' into one\n",
    "# ';'-separated string per row, skipping the missing values.\n",
    "name_columns = ['PubChem_Name', 'IUPAC_Name', 'UNII_Name', 'DrugBank_Name', 'PubChem_Synonym']\n",
    "combined_names = bmg_drug_name_df[name_columns].apply(lambda row: ';'.join(row.dropna().astype(str)), axis=1)\n",
    "# assign() builds a new frame instead of writing a column into bmg_drug_name_df, which may\n",
    "# itself be a slice of bmg_drug_df (that write is what triggers SettingWithCopyWarning).\n",
    "bmg_drug_name_df = bmg_drug_name_df[['BioMedGraphica_Conn_ID']].assign(Drug_Name=combined_names)\n",
    "display(bmg_drug_name_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drug_Name holds several names separated by ';'. Split both name columns into lists of\n",
    "# stripped names so that the two tables can be matched name by name.\n",
    "def split_drug_names(raw_names):\n",
    "    \"\"\"Return the ';'-separated names in raw_names, each stripped; an empty string gives [].\"\"\"\n",
    "    return [name.strip() for name in raw_names.split(';')] if raw_names else []\n",
    "\n",
    "# Report missing values in the Drug_Name column before they are replaced.\n",
    "print(f\"Number of NaN values in Drug_Name column: {bmg_drug_name_df['Drug_Name'].isna().sum()}\")\n",
    "\n",
    "# assign() returns new frames, so no column is written into a frame that may be a slice\n",
    "# of another one (which would raise SettingWithCopyWarning).\n",
    "# Missing names become '' so that the split below works on every row.\n",
    "bmg_drug_name_df = bmg_drug_name_df.assign(Drug_Name=bmg_drug_name_df['Drug_Name'].fillna(''))\n",
    "bmg_drug_name_df = bmg_drug_name_df.assign(drug_name_list=bmg_drug_name_df['Drug_Name'].astype(str).apply(split_drug_names))\n",
    "\n",
    "# Same treatment for the screened drug names.\n",
    "drug_name_df = drug_name_df.assign(drug_name_list=drug_name_df['DRUG_NAME'].astype(str).apply(split_drug_names))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "display(bmg_drug_name_df)\n",
    "display(drug_name_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map every upper-cased BioMedGraphica drug name to its BioMedGraphica_Conn_ID.\n",
    "# If the same name is listed for several IDs, the last row wins (plain dict assignment).\n",
    "drug_name_to_bmg = {}\n",
    "# zip() over the two columns avoids building one Series per row, which iterrows() does.\n",
    "for bmg_id, candidate_names in zip(bmg_drug_name_df['BioMedGraphica_Conn_ID'], bmg_drug_name_df['drug_name_list']):\n",
    "    for drug_name in candidate_names:\n",
    "        if drug_name:  # Avoid empty strings\n",
    "            drug_name_to_bmg[drug_name.upper()] = bmg_id\n",
    "\n",
    "# Map each screened DRUG_NAME to the BioMedGraphica ID of its first name that is known above.\n",
    "# NOTE: despite its name, arxspan_to_bmg is keyed by DRUG_NAME, not by ARXSPAN_ID.\n",
    "arxspan_to_bmg = {}\n",
    "for screened_name, candidate_names in zip(drug_name_df['DRUG_NAME'], drug_name_df['drug_name_list']):\n",
    "    for drug_name in candidate_names:\n",
    "        name_key = drug_name.upper()\n",
    "        if name_key in drug_name_to_bmg:\n",
    "            arxspan_to_bmg[screened_name] = drug_name_to_bmg[name_key]\n",
    "            break\n",
    "\n",
    "# Create a mapping dataframe\n",
    "mapping_df = pd.DataFrame(list(arxspan_to_bmg.items()), columns=['DRUG_NAME', 'BioMedGraphica_Conn_ID'])\n",
    "\n",
    "# Display how many drug names were successfully matched\n",
    "print(f\"Successfully matched {len(mapping_df)} out of {len(drug_name_df)} drugs\")\n",
    "\n",
    "# Display the mapping\n",
    "print(\"\\nSample of drug name mappings:\")\n",
    "display(mapping_df)\n",
    "\n",
    "# mapping_df can now be merged with a drug response table on 'DRUG_NAME', e.g.:\n",
    "# merged_df = pd.merge(drug_response_df, mapping_df, on='DRUG_NAME', how='left')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "final_merged_drug_df = mapping_df.copy()\n",
    "# sort values by the column BioMedGraphica_Conn_ID\n",
    "final_merged_drug_df = final_merged_drug_df.sort_values(by='BioMedGraphica_Conn_ID').reset_index(drop=True)\n",
    "display(final_merged_drug_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# filter the final drug score dataframe by filtering the 'DRUG_NAME' in the final_merged_drug_df DRUG_NAME\n",
    "final_drug_df = raw_drug_df[raw_drug_df['DRUG_NAME'].isin(final_merged_drug_df['DRUG_NAME'])].reset_index(drop=True)\n",
    "display(final_drug_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. CRISPR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load entity type ['Promoter', 'Gene', 'Transcript', 'Protein'] for bmgc_entity_df\n",
    "bmgc_entity_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/BioMedGraphica_Conn_Entity.csv')\n",
    "# check if there is any null values in the bmgc_entity_df\n",
    "print(bmgc_entity_df.isnull().sum())\n",
    "# keep the rows whose Type is in the list ['Promoter', 'Gene', 'Transcript', 'Protein']\n",
    "bmgc_omics_df = bmgc_entity_df[bmgc_entity_df['Type'].isin(['Promoter', 'Gene', 'Transcript', 'Protein'])].reset_index(drop=True)\n",
    "# check if there is any null values in the bmgc_omics_df\n",
    "print(bmgc_omics_df.isnull().sum())\n",
    "display(bmgc_omics_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the relation\n",
    "bmgc_relation_df = pd.read_csv('./data/BioMedGraphica-Conn/Relation/BioMedGraphica_Conn_Relation.csv')\n",
    "# check if there is any null values in the bmgc_relation_df\n",
    "print(bmgc_relation_df.isnull().sum())\n",
    "\n",
    "# keep the rows whose relation Type is in the list ['Promoter-Gene', 'Gene-Transcript', 'Transcript-Protein', 'Protein-Protein']\n",
    "bmgc_omics_relation_df = bmgc_relation_df[bmgc_relation_df['Type'].isin(['Promoter-Gene', 'Gene-Transcript', 'Transcript-Protein', 'Protein-Protein'])].reset_index(drop=True)\n",
    "# check if there is any null values in the bmgc_omics_relation_df\n",
    "print(bmgc_omics_relation_df.isnull().sum())\n",
    "display(bmgc_omics_relation_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# translation chain converging to the same node\n",
    "# fetch the promoter, gene, transcript and protein entity alone\n",
    "promoter_entity_df = bmgc_omics_df[bmgc_omics_df['Type'] == 'Promoter'].copy()\n",
    "gene_entity_df = bmgc_omics_df[bmgc_omics_df['Type'] == 'Gene'].copy()\n",
    "transcript_entity_df = bmgc_omics_df[bmgc_omics_df['Type'] == 'Transcript'].copy()\n",
    "protein_entity_df = bmgc_omics_df[bmgc_omics_df['Type'] == 'Protein'].copy()\n",
    "\n",
    "display(bmgc_omics_relation_df)\n",
    "# recheck the null values in bmgc_omics_relation_df\n",
    "print(\"Null values in bmgc_omics_relation_df:\")\n",
    "print(bmgc_omics_relation_df.isnull().sum())\n",
    "\n",
    "# fetch the Promoter-Gene, Gene-Transcript, Transcript-Protein relation alone\n",
    "promoter_gene_relation_df = bmgc_omics_relation_df[bmgc_omics_relation_df['Type'] == 'Promoter-Gene'].copy()\n",
    "gene_transcript_relation_df = bmgc_omics_relation_df[bmgc_omics_relation_df['Type'] == 'Gene-Transcript'].copy()\n",
    "transcript_protein_relation_df = bmgc_omics_relation_df[bmgc_omics_relation_df['Type'] == 'Transcript-Protein'].copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gene_transcript_entity_df = pd.merge(gene_entity_df, gene_transcript_relation_df[['BMGC_From_ID', 'BMGC_To_ID']], left_on='BioMedGraphica_Conn_ID', right_on='BMGC_From_ID', how='outer')\n",
    "gene_transcript_protein_entity_df = pd.merge(gene_transcript_entity_df, transcript_protein_relation_df[['BMGC_From_ID', 'BMGC_To_ID']], left_on='BMGC_To_ID', right_on='BMGC_From_ID', how='outer')\n",
    "# drop NaN values in BMGC_From_ID_x\tBMGC_To_ID_x BMGC_From_ID_y\tBMGC_To_ID_y\n",
    "gene_transcript_protein_entity_df = gene_transcript_protein_entity_df.dropna(subset=['BMGC_From_ID_x', 'BMGC_To_ID_x', 'BMGC_From_ID_y', 'BMGC_To_ID_y']).reset_index(drop=True)\n",
    "# keep the columns ['BioMedGraphica_Conn_ID', 'BMGC_To_ID_y'] and rename them to ['BMGC_GN_ID', 'BMGC_PT_ID']\n",
    "gene_transcript_protein_entity_df = gene_transcript_protein_entity_df[['BioMedGraphica_Conn_ID', 'BMGC_To_ID_y']].rename(columns={'BioMedGraphica_Conn_ID': 'BMGC_GN_ID', 'BMGC_To_ID_y': 'BMGC_PT_ID'}).sort_values(by='BMGC_GN_ID').reset_index(drop=True)\n",
    "# drop duplicates rows in gene_transcript_protein_entity_df\n",
    "gene_transcript_protein_entity_df = gene_transcript_protein_entity_df.drop_duplicates().reset_index(drop=True)\n",
    "display(gene_transcript_protein_entity_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bmgc_promoter_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Promoter/BioMedGraphica_Conn_Promoter.csv')\n",
    "bmgc_gene_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Gene/BioMedGraphica_Conn_Gene.csv')\n",
    "# Work on a copy of the gene -> protein pairs.\n",
    "promoter_gene_transcript_protein_entity_df = gene_transcript_protein_entity_df.copy()\n",
    "# Put the promoter IDs and the gene IDs side by side; concat(axis=1) aligns the two files on row index.\n",
    "# NOTE(review): this assumes row i of the promoter file belongs to row i of the gene file - confirm,\n",
    "# or derive the pairs from promoter_gene_relation_df instead.\n",
    "promoter_id_df = bmgc_promoter_df[['BioMedGraphica_Conn_ID']].rename(columns={'BioMedGraphica_Conn_ID': 'BMGC_PM_ID'})\n",
    "gene_id_df = bmgc_gene_df[['BioMedGraphica_Conn_ID']].rename(columns={'BioMedGraphica_Conn_ID': 'BMGC_GN_ID'})\n",
    "promoter_gene_df = pd.concat([promoter_id_df, gene_id_df], axis=1)\n",
    "# Attach the promoter ID to every gene -> protein pair, then keep only promoter -> protein.\n",
    "promoter_protein_entity_df = promoter_gene_transcript_protein_entity_df.merge(promoter_gene_df, on='BMGC_GN_ID', how='left')\n",
    "promoter_protein_entity_df = promoter_protein_entity_df[['BMGC_PM_ID', 'BMGC_PT_ID']]\n",
    "display(promoter_protein_entity_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transcript_protein_entity_df = pd.merge(transcript_entity_df, transcript_protein_relation_df[['BMGC_From_ID', 'BMGC_To_ID']], left_on='BioMedGraphica_Conn_ID', right_on='BMGC_From_ID', how='outer')\n",
    "# drop NaN rows in the BMGC_From_ID and\tBMGC_To_ID\n",
    "transcript_protein_entity_df = transcript_protein_entity_df.dropna(subset=['BMGC_From_ID', 'BMGC_To_ID']).reset_index(drop=True)\n",
    "# keep the columns ['BioMedGraphica_Conn_ID', 'BMGC_To_ID'] and rename the columns to ['BMGC_TS_ID', 'BMGC_PT_ID']\n",
    "transcript_protein_entity_df = transcript_protein_entity_df[['BioMedGraphica_Conn_ID', 'BMGC_To_ID']].rename(columns={'BioMedGraphica_Conn_ID': 'BMGC_TS_ID', 'BMGC_To_ID': 'BMGC_PT_ID'}).sort_values(by='BMGC_TS_ID').reset_index(drop=True)\n",
    "# drop duplicates rows in transcript_protein_entity_df\n",
    "transcript_protein_entity_df = transcript_protein_entity_df.drop_duplicates().reset_index(drop=True)\n",
    "display(transcript_protein_entity_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# keep the columns ['BioMedGraphica_Conn_ID'] and rename the columns to ['BMGC_PT_ID']\n",
    "only_protein_entity_df = protein_entity_df[['BioMedGraphica_Conn_ID']].rename(columns={'BioMedGraphica_Conn_ID': 'BMGC_PT_ID'}).sort_values(by='BMGC_PT_ID').reset_index(drop=True)\n",
    "display(only_protein_entity_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.1 CRISPR top 100 gene entities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "raw_crispr_df = pd.read_csv('./data/raw_data/CRISPRGeneEffect.csv')\n",
    "display(raw_crispr_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get the column names of the raw_crispr_df aside from the first column and convert this to a list\n",
    "raw_crispr_df_columns = raw_crispr_df.columns[1:].tolist()\n",
    "print(raw_crispr_df_columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Coalign the samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample identifiers of each omics table: every column after the leading BioMedGraphica_Conn_ID column.\n",
    "methy_samples = final_merged_methy_df.columns[1:]\n",
    "gene_samples = final_merged_gene_df.columns[1:]\n",
    "transcript_samples = final_merged_transcript_df.columns[1:]\n",
    "protein_samples = final_merged_protein_df.columns[1:]\n",
    "# Unique ARXSPAN_ID values of the drug table. sorted() gives a reproducible order; the\n",
    "# iteration order of a set of strings can change between kernel restarts.\n",
    "drug_samples = sorted(set(final_drug_df['ARXSPAN_ID']))\n",
    "\n",
    "# print all samples\n",
    "print(methy_samples)\n",
    "print(gene_samples)\n",
    "print(transcript_samples)\n",
    "print(protein_samples)\n",
    "print(drug_samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell line names of the protein samples: remove only the last '_<suffix>' token of each column name.\n",
    "# Cutting after the second '_' would truncate names whose tissue part itself contains '_'\n",
    "# (e.g. '<LINE>_LARGE_INTESTINE_<suffix>' would become '<LINE>_LARGE').\n",
    "# NOTE(review): assumes every protein sample column ends with exactly one suffix token - confirm.\n",
    "protein_split_samples = [s.rsplit('_', 1)[0] for s in protein_samples]\n",
    "print(protein_split_samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cell_line_anno_df = pd.read_csv('./data/raw_data/Cell_lines_annotations_20181226.txt', sep='\\t')\n",
    "# Keep only the rows that have both a depMapID and a Pathology value, then renumber the index.\n",
    "cell_line_anno_df = cell_line_anno_df.dropna(subset=['depMapID', 'Pathology']).reset_index(drop=True)\n",
    "# Rows whose PATHOLOGIST_ANNOTATION contains 'benign' (case-insensitive);\n",
    "# a missing annotation is treated as an empty string and therefore never matches.\n",
    "is_benign = cell_line_anno_df['PATHOLOGIST_ANNOTATION'].fillna('').str.contains('benign', case=False)\n",
    "cell_line_anno_df = cell_line_anno_df[~is_benign]\n",
    "display(cell_line_anno_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "\n",
    "# Number of distinct cleaned protein sample names.\n",
    "print(len(set(protein_split_samples)))\n",
    "# Cleaned names that are shared by more than one protein column.\n",
    "sample_counts = collections.Counter(protein_split_samples)\n",
    "rep_list = [name for name in sample_counts if sample_counts[name] > 1]\n",
    "# Collect every column of final_merged_protein_df whose name contains one of the repeated names.\n",
    "rep_col_list = []\n",
    "for col in final_merged_protein_df.columns:\n",
    "    for rep in rep_list:\n",
    "        if rep in col:\n",
    "            rep_col_list.append(col)\n",
    "            break\n",
    "print(rep_col_list)\n",
    "# Show the repeated columns before they are removed.\n",
    "display(final_merged_protein_df[rep_col_list])\n",
    "\n",
    "# Remove all of the repeated columns from final_merged_protein_df.\n",
    "final_merged_protein_df = final_merged_protein_df.drop(columns=rep_col_list)\n",
    "display(final_merged_protein_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get the overlapped cell lines in methy_samples with cell_line_anno_df by merge on cell line id\n",
    "# first format the methy_samples as dataframe\n",
    "methy_samples_df = pd.DataFrame(methy_samples, columns=['CCLE_Name'])\n",
    "# merge the methy_samples_df with the cell_line_anno_df\n",
    "merged_methy_samples_df = pd.merge(methy_samples_df, cell_line_anno_df, left_on='CCLE_Name', right_on='CCLE_ID', how='inner')\n",
    "display(merged_methy_samples_df)\n",
    "\n",
    "# get the map dictionary from the ccle id to depmap id\n",
    "methy_map_dict = dict(zip(merged_methy_samples_df['CCLE_Name'], merged_methy_samples_df['depMapID']))\n",
    "print(methy_map_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recover the cell line name of each remaining protein column by removing only its last '_<suffix>' token.\n",
    "# Cutting after the second '_' would truncate names whose tissue part itself contains '_'\n",
    "# (e.g. '<LINE>_LARGE_INTESTINE_<suffix>' would become '<LINE>_LARGE' and never match a CCLE_ID).\n",
    "# NOTE(review): assumes every protein sample column ends with exactly one suffix token - confirm.\n",
    "original_protein_samples = final_merged_protein_df.columns[1:]\n",
    "cleaned_protein_samples = [s.rsplit('_', 1)[0] for s in original_protein_samples]\n",
    "# One row per protein column: original column name and cleaned cell line name.\n",
    "cleaned_protein_samples_dict_df = pd.DataFrame({'CCLE_Name': original_protein_samples, 'cleaned_CCLE_Name': cleaned_protein_samples})\n",
    "# Inner merge with the annotation table: columns without a matching CCLE_ID are dropped here.\n",
    "merged_cleaned_protein_samples_df = pd.merge(cleaned_protein_samples_dict_df, cell_line_anno_df, left_on='cleaned_CCLE_Name', right_on='CCLE_ID', how='inner')\n",
    "display(merged_cleaned_protein_samples_df)\n",
    "# get the depmap protein id\n",
    "depmap_protein_samples = merged_cleaned_protein_samples_df['depMapID'].to_list()\n",
    "\n",
    "# Build the column -> depMapID map from the merged rows themselves (as methy_map_dict does).\n",
    "# Zipping the full column list against depmap_protein_samples would pair names with the wrong\n",
    "# IDs as soon as the inner merge drops a single column.\n",
    "protein_map_dict = dict(zip(merged_cleaned_protein_samples_df['CCLE_Name'], merged_cleaned_protein_samples_df['depMapID']))\n",
    "print(protein_map_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get the methy_mapped_samples and protein_mapped_samples from merged_methy_samples_df and merged_cleaned_protein_samples_df\n",
    "methy_mapped_samples = list(merged_methy_samples_df['depMapID'])\n",
    "protein_mapped_samples = list(merged_cleaned_protein_samples_df['depMapID'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.1 Aligning over omics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 6.1.1 Aligning with intersection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# samples that appear in every omics layer (methylation, gene, transcript, protein)\n",
    "omics_overlapped_samples = sorted(\n",
    "    set(methy_mapped_samples).intersection(gene_samples, transcript_samples, protein_mapped_samples)\n",
    ")\n",
    "print(\"length of omics overlapped samples: \", len(omics_overlapped_samples))\n",
    "# keep only those that also have a row in the cell line annotation table\n",
    "annotation_samples = cell_line_anno_df['depMapID'].tolist()\n",
    "overlapped_omics_annotation_samples = sorted(set(omics_overlapped_samples).intersection(annotation_samples))\n",
    "print(\"length of overlapped omics and annotation samples: \", len(overlapped_omics_annotation_samples))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 6.1.2 Aligning with union on omics and intersection with annotation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# samples that appear in at least one omics layer (methylation, gene, transcript, protein)\n",
    "omics_union_samples = sorted(\n",
    "    set(methy_mapped_samples).union(gene_samples, transcript_samples, protein_mapped_samples)\n",
    ")\n",
    "print(\"length of omics union samples: \", len(omics_union_samples))\n",
    "# of that union, keep the samples that also have a row in the annotation table\n",
    "overlapped_omics_union_annotation_samples = sorted(set(omics_union_samples).intersection(annotation_samples))\n",
    "print(\"length of overlapped omics and annotation samples: \", len(overlapped_omics_union_annotation_samples))\n",
    "print(overlapped_omics_union_annotation_samples)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 6.1.3 overlapped_omics_union_annotation_samples (cancerous / non-cancerous) Ratio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the table of non-cancerous cell lines\n",
    "non_cancerous_samples_df = pd.read_csv('./data/raw_data/cell-lines-in-Non-Cancerous.csv')\n",
    "display(non_cancerous_samples_df)\n",
    "# DepMap IDs listed in that table\n",
    "non_cancerous_samples = non_cancerous_samples_df['Depmap Id'].tolist()\n",
    "# non-cancerous cell lines that are part of overlapped_omics_union_annotation_samples\n",
    "overlapped_non_cancerous_samples = sorted(\n",
    "    set(overlapped_omics_union_annotation_samples).intersection(non_cancerous_samples)\n",
    ")\n",
    "print(\"length of overlapped omics and annotation samples: \", len(overlapped_omics_union_annotation_samples))\n",
    "print(\"length of overlapped non cancerous samples: \", len(overlapped_non_cancerous_samples))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.2 Overlapping over DTI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# samples of the annotated omics union (overlapped_omics_union_annotation_samples) that also have drug data\n",
    "dti_overlapped_samples = sorted(set(overlapped_omics_union_annotation_samples).intersection(drug_samples))\n",
    "print(\"len(overlapped_samples):\", len(dti_overlapped_samples))\n",
    "# the DTI sample set must not contain non-cancerous cell lines: report how many are present\n",
    "dti_overlapped_non_cancerous_samples = sorted(set(dti_overlapped_samples).intersection(non_cancerous_samples))\n",
    "print(\"len(overlapped_non_cancerous_samples):\", len(dti_overlapped_non_cancerous_samples))\n",
    "# remove them from dti_overlapped_samples (nothing changes when none were found)\n",
    "dti_overlapped_samples = sorted(set(dti_overlapped_samples).difference(dti_overlapped_non_cancerous_samples))\n",
    "# attach the annotation columns to the remaining samples\n",
    "dti_sample_id_df = pd.DataFrame(dti_overlapped_samples, columns=['depMapID'])\n",
    "dti_overlapped_samples_df = pd.merge(dti_sample_id_df, cell_line_anno_df, on='depMapID', how='inner')\n",
    "display(dti_overlapped_samples_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create the folder for processed outputs; exist_ok avoids the separate existence check\n",
    "# (and the race between checking and creating), so re-running the cell is harmless\n",
    "import os\n",
    "os.makedirs('./data/process_data', exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# write the annotated DTI sample table to CSV without the index column\n",
    "dti_overlapped_samples_df.to_csv('./data/process_data/dti_overlapped_samples.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# list the columns of the DTI sample table\n",
    "print(dti_overlapped_samples_df.columns)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Annotate the samples with disease"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 7.1 Add annotation from Cellosaurus"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import pandas as pd\n",
    "# import requests\n",
    "\n",
    "# url = 'https://ftp.expasy.org/databases/cellosaurus/cellosaurus.obo'\n",
    "# response = requests.get(url)\n",
    "\n",
    "# with open('./data/raw_data/cellosaurus.obo', 'wb') as f:\n",
    "#     f.write(response.content)\n",
    "\n",
    "# # Path to the downloaded OBO file\n",
    "# obo_path = './data/raw_data/cellosaurus.obo'\n",
    "\n",
    "# entries = []\n",
    "# current_entry = {}\n",
    "\n",
    "# # Read and parse\n",
    "# with open(obo_path, 'r', encoding='utf-8') as file:\n",
    "#     for line in file:\n",
    "#         line = line.strip()\n",
    "\n",
    "#         if line == \"[Term]\":\n",
    "#             if current_entry.get(\"xref\") and any(\"NCBI_TaxID:9606\" in x for x in current_entry[\"xref\"]):\n",
    "#                 # Separate NCIt and ORDO xrefs\n",
    "#                 ncit_refs = [x for x in current_entry[\"xref\"] if x.startswith(\"NCIt:\")]\n",
    "#                 ordo_refs = [x for x in current_entry[\"xref\"] if x.startswith(\"ORDO:\")]\n",
    "#                 current_entry[\"xref_NCIt\"] = \"; \".join(ncit_refs)\n",
    "#                 current_entry[\"xref_ORDO\"] = \"; \".join(ordo_refs)\n",
    "#                 entries.append({\n",
    "#                     \"id\": current_entry.get(\"id\", \"\"),\n",
    "#                     \"name\": current_entry.get(\"name\", \"\"),\n",
    "#                     \"synonym\": \"; \".join(current_entry.get(\"synonym\", [])),\n",
    "#                     \"xref_NCIt\": current_entry.get(\"xref_NCIt\", \"\"),\n",
    "#                     \"xref_ORDO\": current_entry.get(\"xref_ORDO\", \"\")\n",
    "#                 })\n",
    "#             current_entry = {\"synonym\": [], \"xref\": []}\n",
    "\n",
    "#         elif line.startswith(\"id:\"):\n",
    "#             current_entry[\"id\"] = line.split(\"id:\")[1].strip()\n",
    "\n",
    "#         elif line.startswith(\"name:\"):\n",
    "#             current_entry[\"name\"] = line.split(\"name:\")[1].strip()\n",
    "\n",
    "#         elif line.startswith(\"synonym:\"):\n",
    "#             synonym = line.split(\"synonym:\")[1].split(\"RELATED\")[0].strip().strip('\"')\n",
    "#             current_entry[\"synonym\"].append(synonym)\n",
    "\n",
    "#         elif line.startswith(\"xref:\"):\n",
    "#             current_entry[\"xref\"].append(line.split(\"xref:\")[1].strip())\n",
    "\n",
    "# # Handle last entry\n",
    "# if current_entry.get(\"xref\") and any(\"NCBI_TaxID:9606\" in x for x in current_entry[\"xref\"]):\n",
    "#     ncit_refs = [x for x in current_entry[\"xref\"] if x.startswith(\"NCIt:\")]\n",
    "#     ordo_refs = [x for x in current_entry[\"xref\"] if x.startswith(\"ORDO:\")]\n",
    "#     entries.append({\n",
    "#         \"id\": current_entry.get(\"id\", \"\"),\n",
    "#         \"name\": current_entry.get(\"name\", \"\"),\n",
    "#         \"synonym\": \"; \".join(current_entry.get(\"synonym\", [])),\n",
    "#         \"xref_NCIt\": \"; \".join(ncit_refs),\n",
    "#         \"xref_ORDO\": \"; \".join(ordo_refs)\n",
    "#     })\n",
    "\n",
    "# # Create DataFrame\n",
    "# cellosaurus_parsed_df = pd.DataFrame(entries)\n",
    "# cellosaurus_parsed_df = cellosaurus_parsed_df[[\"id\", \"name\", \"synonym\", \"xref_NCIt\", \"xref_ORDO\"]]\n",
    "# cellosaurus_parsed_df.to_csv('./data/raw_data/cellosaurus_parsed.csv', index=False)\n",
    "# # Show preview\n",
    "# display(cellosaurus_parsed_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Step 1: Keep the desired columns\n",
    "# dti_overlapped_samples_desc_df = dti_overlapped_samples_df[['depMapID', 'Name', 'Pathology', 'Histology', 'type', 'PATHOLOGIST_ANNOTATION', 'tcga_code']]\n",
    "# display(dti_overlapped_samples_desc_df)\n",
    "\n",
    "# # Step 2: Try direct merge on Name == name using left join\n",
    "# matched_df = pd.merge(\n",
    "#     dti_overlapped_samples_desc_df,\n",
    "#     cellosaurus_parsed_df,\n",
    "#     left_on='Name',\n",
    "#     right_on='name',\n",
    "#     how='left'\n",
    "# )\n",
    "\n",
    "# # Step 3: Find rows where no match occurred (name is NaN)\n",
    "# unmatched_df = matched_df[matched_df['name'].isna()].copy()\n",
    "\n",
    "# # Step 4: Drop the Cellosaurus columns to prep for synonym match\n",
    "# cols_to_drop = [col for col in matched_df.columns if col in cellosaurus_parsed_df.columns]\n",
    "# unmatched_df = unmatched_df.drop(columns=cols_to_drop)\n",
    "\n",
    "# # Step 5: Expand cellosaurus synonyms\n",
    "# synonym_expanded_df = cellosaurus_parsed_df.copy()\n",
    "# synonym_expanded_df = synonym_expanded_df.dropna(subset=['synonym'])\n",
    "# synonym_expanded_df = synonym_expanded_df.assign(\n",
    "#     synonym=synonym_expanded_df['synonym'].str.split(';')\n",
    "# ).explode('synonym')\n",
    "# synonym_expanded_df['synonym'] = synonym_expanded_df['synonym'].str.strip()\n",
    "\n",
    "# # Step 6: Left join unmatched rows with synonym-expanded Cellosaurus\n",
    "# synonym_matched_df = pd.merge(\n",
    "#     unmatched_df,\n",
    "#     synonym_expanded_df,\n",
    "#     left_on='Name',\n",
    "#     right_on='synonym',\n",
    "#     how='left'\n",
    "# )\n",
    "\n",
    "# # Step 7: Combine direct-matched (non-NaN name) and synonym-matched rows\n",
    "# all_match_df = pd.concat(\n",
    "#     [matched_df[matched_df['name'].notna()], synonym_matched_df],\n",
    "#     ignore_index=True\n",
    "# )\n",
    "\n",
    "# # Optional: Drop duplicates based on depMapID if needed\n",
    "# all_match_df = all_match_df.drop_duplicates(subset=['depMapID'])\n",
    "# display(all_match_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Function to store primary and extra xref matches with better column handling\n",
    "# def process_xref_columns_with_extras(df):\n",
    "#     # Create a copy to avoid modifying the original\n",
    "#     result_df = df.copy()\n",
    "    \n",
    "#     # Always create all the columns, even if they might end up empty\n",
    "#     result_df['xref_NCIt_ID'] = ''\n",
    "#     result_df['xref_NCIt_name'] = ''\n",
    "#     result_df['xref_NCIt_extra_ID'] = ''\n",
    "#     result_df['xref_NCIt_extra_name'] = ''\n",
    "#     result_df['xref_ORDO_ID'] = ''\n",
    "#     result_df['xref_ORDO_name'] = ''\n",
    "#     result_df['xref_ORDO_extra_ID'] = ''\n",
    "#     result_df['xref_ORDO_extra_name'] = ''\n",
    "    \n",
    "#     # Process xref_NCIt if it exists\n",
    "#     if 'xref_NCIt' in result_df.columns:\n",
    "#         # Only process rows that have xref_NCIt values\n",
    "#         mask = result_df['xref_NCIt'].notna() & (result_df['xref_NCIt'] != '')\n",
    "        \n",
    "#         for idx in result_df[mask].index:\n",
    "#             xref_value = result_df.loc[idx, 'xref_NCIt']\n",
    "            \n",
    "#             if '; ' in xref_value:  # Multiple entries\n",
    "#                 entries = xref_value.split('; ')\n",
    "                \n",
    "#                 # Process first entry for primary columns\n",
    "#                 if ' ! ' in entries[0]:\n",
    "#                     id_val, name_val = entries[0].split(' ! ', 1)\n",
    "#                     result_df.loc[idx, 'xref_NCIt_ID'] = id_val\n",
    "#                     result_df.loc[idx, 'xref_NCIt_name'] = name_val\n",
    "                \n",
    "#                 # Process additional entries for extra columns\n",
    "#                 extra_ids = []\n",
    "#                 extra_names = []\n",
    "#                 for entry in entries[1:]:\n",
    "#                     if ' ! ' in entry:\n",
    "#                         id_val, name_val = entry.split(' ! ', 1)\n",
    "#                         extra_ids.append(id_val)\n",
    "#                         extra_names.append(name_val)\n",
    "                \n",
    "#                 if extra_ids:\n",
    "#                     result_df.loc[idx, 'xref_NCIt_extra_ID'] = '; '.join(extra_ids)\n",
    "#                     result_df.loc[idx, 'xref_NCIt_extra_name'] = '; '.join(extra_names)\n",
    "                    \n",
    "#             elif ' ! ' in xref_value:  # Single entry\n",
    "#                 id_val, name_val = xref_value.split(' ! ', 1)\n",
    "#                 result_df.loc[idx, 'xref_NCIt_ID'] = id_val\n",
    "#                 result_df.loc[idx, 'xref_NCIt_name'] = name_val\n",
    "        \n",
    "#         # Drop the original column\n",
    "#         result_df = result_df.drop(columns=['xref_NCIt'])\n",
    "    \n",
    "#     # Process xref_ORDO if it exists\n",
    "#     if 'xref_ORDO' in result_df.columns:\n",
    "#         # Only process rows that have xref_ORDO values\n",
    "#         mask = result_df['xref_ORDO'].notna() & (result_df['xref_ORDO'] != '')\n",
    "        \n",
    "#         for idx in result_df[mask].index:\n",
    "#             xref_value = result_df.loc[idx, 'xref_ORDO']\n",
    "            \n",
    "#             if '; ' in xref_value:  # Multiple entries\n",
    "#                 entries = xref_value.split('; ')\n",
    "                \n",
    "#                 # Process first entry for primary columns\n",
    "#                 if ' ! ' in entries[0]:\n",
    "#                     id_val, name_val = entries[0].split(' ! ', 1)\n",
    "#                     result_df.loc[idx, 'xref_ORDO_ID'] = id_val\n",
    "#                     result_df.loc[idx, 'xref_ORDO_name'] = name_val\n",
    "                \n",
    "#                 # Process additional entries for extra columns\n",
    "#                 extra_ids = []\n",
    "#                 extra_names = []\n",
    "#                 for entry in entries[1:]:\n",
    "#                     if ' ! ' in entry:\n",
    "#                         id_val, name_val = entry.split(' ! ', 1)\n",
    "#                         extra_ids.append(id_val)\n",
    "#                         extra_names.append(name_val)\n",
    "                \n",
    "#                 if extra_ids:\n",
    "#                     result_df.loc[idx, 'xref_ORDO_extra_ID'] = '; '.join(extra_ids)\n",
    "#                     result_df.loc[idx, 'xref_ORDO_extra_name'] = '; '.join(extra_names)\n",
    "                    \n",
    "#             elif ' ! ' in xref_value:  # Single entry\n",
    "#                 id_val, name_val = xref_value.split(' ! ', 1)\n",
    "#                 result_df.loc[idx, 'xref_ORDO_ID'] = id_val\n",
    "#                 result_df.loc[idx, 'xref_ORDO_name'] = name_val\n",
    "        \n",
    "#         # Drop the original column\n",
    "#         result_df = result_df.drop(columns=['xref_ORDO'])\n",
    "    \n",
    "#     return result_df\n",
    "\n",
    "# # Apply the function\n",
    "# all_match_df = process_xref_columns_with_extras(all_match_df)\n",
    "\n",
    "# # Update the column list to include the extra columns\n",
    "# dti_overlapped_samples_desc_co_df = all_match_df[['depMapID', 'Name', 'Pathology', 'Histology', \n",
    "#                                              'type', 'PATHOLOGIST_ANNOTATION', 'tcga_code', 'id', \n",
    "#                                              'xref_NCIt_ID', 'xref_NCIt_name', 'xref_NCIt_extra_ID', 'xref_NCIt_extra_name',\n",
    "#                                              'xref_ORDO_ID', 'xref_ORDO_name', 'xref_ORDO_extra_ID', 'xref_ORDO_extra_name']]\n",
    "# display(dti_overlapped_samples_desc_co_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 7.2 Cell line disease match"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 7.2.1 Cell line disease hard match"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# bmg_disease_name_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Disease/BioMedGraphica_Conn_Disease_GUI_Name.csv')\n",
    "# display(bmg_disease_name_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Step 1: Expand BMG_Disease_Name (still lowercased and exploded)\n",
    "# bmg_expanded = bmg_disease_name_df.copy()\n",
    "\n",
    "# # First, strip any surrounding quotes and fill NaN values with empty strings\n",
    "# bmg_expanded['Disease_Name_List'] = bmg_expanded['Disease_Name_List'].fillna('')\n",
    "# bmg_expanded['Disease_Name_List'] = bmg_expanded['Disease_Name_List'].str.replace('^\"|\"$', '', regex=True)\n",
    "\n",
    "# # Split on pipe with spaces\n",
    "# bmg_expanded['Disease_Name_List'] = bmg_expanded['Disease_Name_List'].str.split(' \\| ')\n",
    "\n",
    "# # Apply strip only to non-empty lists\n",
    "# bmg_expanded['Disease_Name_List'] = bmg_expanded['Disease_Name_List'].apply(\n",
    "#     lambda x: [name.strip() for name in x] if isinstance(x, list) else []\n",
    "# )\n",
    "\n",
    "# # Explode the list into separate rows\n",
    "# bmg_expanded = bmg_expanded.explode('Disease_Name_List')\n",
    "\n",
    "# # Remove rows with empty disease names after exploding\n",
    "# bmg_expanded = bmg_expanded[bmg_expanded['Disease_Name_List'].str.len() > 0]\n",
    "\n",
    "# # Create lowercase version for easier matching\n",
    "# bmg_expanded['Disease_Name_List_lower'] = bmg_expanded['Disease_Name_List'].str.lower()\n",
    "\n",
    "# display(bmg_expanded)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Create a lookup list of tuples: (Conn_ID, Original_Name, Lower_Name)\n",
    "# bmg_records = bmg_expanded[['BioMedGraphica_Conn_ID', 'Disease_Name_List', 'Disease_Name_List_lower']].to_records(index=False)\n",
    "\n",
    "# # Updated matching function: look for exact full-name match\n",
    "# def match_bmg_disease_exact(xref_name, bmg_records):\n",
    "#     if pd.isna(xref_name):\n",
    "#         return (None, None)\n",
    "#     xref_name_lower = str(xref_name).strip().lower()\n",
    "    \n",
    "#     for conn_id, original_name, bmg_name_lower in bmg_records:\n",
    "#         if xref_name_lower == bmg_name_lower:\n",
    "#             return (conn_id, original_name)\n",
    "    \n",
    "#     return (None, None)\n",
    "\n",
    "# # NCIt\n",
    "# ncit_matches = dti_overlapped_samples_desc_co_df['xref_NCIt_name'].apply(lambda x: match_bmg_disease_exact(x, bmg_records))\n",
    "# dti_overlapped_samples_desc_co_df['NCIt_BMGC_ID'] = ncit_matches.apply(lambda x: x[0])\n",
    "# dti_overlapped_samples_desc_co_df['NCIt_BMGC_name'] = ncit_matches.apply(lambda x: x[1])\n",
    "\n",
    "# # ORDO\n",
    "# ordo_matches = dti_overlapped_samples_desc_co_df['xref_ORDO_name'].apply(lambda x: match_bmg_disease_exact(x, bmg_records))\n",
    "# dti_overlapped_samples_desc_co_df['ORDO_BMGC_ID'] = ordo_matches.apply(lambda x: x[0])\n",
    "# dti_overlapped_samples_desc_co_df['ORDO_BMGC_name'] = ordo_matches.apply(lambda x: x[1])\n",
    "\n",
    "# display(dti_overlapped_samples_desc_co_df)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # keep the columns ['depMapID', 'Name', 'Pathology', 'Histology', 'type', 'PATHOLOGIST_ANNOTATION', 'tcga_code', 'id', 'xref_NCIt_ID', 'xref_NCIt_name', 'NCIt_BMGC_ID', 'NCIt_BMGC_name', 'xref_ORDO_ID', 'xref_ORDO_name', 'ORDO_BMGC_ID', 'ORDO_BMGC_name']\n",
    "# dti_overlapped_samples_desc_co_df = dti_overlapped_samples_desc_co_df[['depMapID', 'Name', 'Pathology', 'Histology', 'type', 'PATHOLOGIST_ANNOTATION', 'tcga_code', 'id', 'xref_NCIt_ID', 'xref_NCIt_name', 'NCIt_BMGC_ID', 'NCIt_BMGC_name', 'xref_ORDO_ID', 'xref_ORDO_name', 'ORDO_BMGC_ID', 'ORDO_BMGC_name']]\n",
    "# display(dti_overlapped_samples_desc_co_df)\n",
    "# dti_overlapped_samples_desc_co_df.to_csv('./data/process_data/dti_overlapped_samples_desc_co.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge NCIt_BMGC_ID/ORDO_BMGC_ID and NCIt_BMGC_name/ORDO_BMGC_name into BMGC_Matched columns\n",
    "# Strategy:\n",
    "# 1. Prioritize NCIt data when both are available\n",
    "# 2. Use NCIt data if only NCIt is available\n",
    "# 3. Use ORDO data if only ORDO is available\n",
    "# 4. Leave as null if both are null\n",
    "\n",
    "# dti_matched_df = dti_overlapped_samples_desc_co_df.copy()\n",
    "\n",
    "# # Create new merged ID column\n",
    "# dti_matched_df['BMGC_Matched_ID'] = (\n",
    "#     dti_matched_df['NCIt_BMGC_ID'].fillna(\n",
    "#         dti_matched_df['ORDO_BMGC_ID']\n",
    "#     )\n",
    "# )\n",
    "\n",
    "# # Create new merged name column\n",
    "# dti_matched_df['BMGC_Matched_name'] = (\n",
    "#     dti_matched_df['NCIt_BMGC_name'].fillna(\n",
    "#         dti_matched_df['ORDO_BMGC_name']\n",
    "#     )\n",
    "# )\n",
    "\n",
    "# # Display the dataframe with new columns\n",
    "# display(dti_matched_df)\n",
    "\n",
    "# # Check number of null values in new columns\n",
    "# print(f\"Null values in BMGC_Matched_ID: {dti_matched_df['BMGC_Matched_ID'].isna().sum()}\")\n",
    "# print(f\"Null values in BMGC_Matched_name: {dti_matched_df['BMGC_Matched_name'].isna().sum()}\")\n",
    "\n",
    "# # only keep the columns where the BMGC_Matched_ID is not null\n",
    "# dti_matched_df = dti_matched_df[dti_matched_df['BMGC_Matched_ID'].notna()].reset_index(drop=True)\n",
    "# display(dti_matched_df)\n",
    "\n",
    "# # Save to CSV\n",
    "# dti_matched_df.to_csv('./data/process_data/dti_matched.csv', index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 7.2.2 Cell line disease soft match"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # check the unmatched samples in the dti_overlapped_samples_desc_co_df by checking the NaN in the NCIt_BMGC_ID and ORDO_BMGC_ID\n",
    "# unmatched_samples_df = dti_overlapped_samples_desc_co_df[dti_overlapped_samples_desc_co_df['NCIt_BMGC_ID'].isna() & dti_overlapped_samples_desc_co_df['ORDO_BMGC_ID'].isna()]\n",
    "# # keep the columns ['depMapID', 'Name', 'Pathology', 'Histology', 'type', 'PATHOLOGIST_ANNOTATION', 'tcga_code', 'id', 'xref_NCIt_name', 'xref_ORDO_name']\n",
    "# unmatched_samples_df = unmatched_samples_df[['depMapID', 'Name', 'Pathology', 'Histology', 'type', 'PATHOLOGIST_ANNOTATION', 'tcga_code', 'id', 'xref_NCIt_name', 'xref_ORDO_name']]\n",
    "# display(unmatched_samples_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# embed the disease name for bmg_disease_df by bert-based model\n",
    "import torch\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from typing import List, Tuple, Dict\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "\n",
    "class SentenceDataset(Dataset):\n",
    "    \"\"\"Dataset wrapper around a list of sentences, so they can be batched by a DataLoader.\"\"\"\n",
    "\n",
    "    def __init__(self, sentences: List[str]):\n",
    "        # Keep a reference to the caller's list; no copy is made.\n",
    "        self.sentences = sentences\n",
    "\n",
    "    def __len__(self) -> int:\n",
    "        # Number of sentences in the dataset.\n",
    "        return len(self.sentences)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        # Return the sentence stored at position idx.\n",
    "        return self.sentences[idx]\n",
    "\n",
    "\n",
    "class TextEncoder():\n",
    "    def __init__(self, model_path: str = \"dmis-lab/biobert-v1.1\", device: str = \"cuda\"):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            model_path (str, optional): Path to the deberta model. Defaults to 'dmis-lab/biobert-v1.1'.\n",
    "            device (str, optional): Device to run the model on ('cpu' or 'cuda'). Defaults to 'cpu'.\n",
    "        \"\"\"\n",
    "        self.model_path = model_path\n",
    "        self.device = device\n",
    "        self.model = None\n",
    "        self.tokenizer = None\n",
    "\n",
    "    def load_model(self):\n",
    "        \"\"\"\n",
    "        Load the deberta model and tokenizer from the specified model path.\n",
    "        \"\"\"\n",
    "        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)\n",
    "        self.model = AutoModel.from_pretrained(self.model_path).to(self.device)\n",
    "\n",
    "    def generate_embeddings(self, sentences: List[str], batch_size: int = 32, seq_emb_dim: int = 64) -> torch.Tensor:\n",
    "        \"\"\"\n",
    "        Generate a single-dimensional embedding for each sentence.\n",
    "\n",
    "        Args:\n",
    "            sentences (List[str]): List of sentences to embed.\n",
    "            batch_size (int, optional): Batch size for DataLoader. Defaults to 32.\n",
    "\n",
    "        Returns:\n",
    "            List[float]: List of single-dimensional embeddings.\n",
    "        \"\"\"\n",
    "        dataset = SentenceDataset(sentences)\n",
    "        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "        embedding_batches = []\n",
    "        for batch in tqdm(dataloader, desc=\"Embedding sentences\", unit=\"batch\"):\n",
    "            inputs = self.tokenizer(batch, return_tensors=\"pt\", padding=True, truncation=True, max_length=512).to(self.device)\n",
    "            with torch.no_grad():\n",
    "                outputs = self.model(**inputs)\n",
    "            # Handle single batch case properly\n",
    "            mean_embeddings = torch.mean(outputs.last_hidden_state, dim=1)  # [batch_size, hidden_dim]\n",
    "            \n",
    "            # For adaptive pooling, we need to reshape for 1D adaptive pooling\n",
    "            # [batch_size, 1, hidden_dim] -> [batch_size, 1, seq_emb_dim] -> [batch_size, seq_emb_dim]\n",
    "            batch_size = mean_embeddings.size(0)\n",
    "            reshaped = mean_embeddings.view(batch_size, 1, -1)\n",
    "            projected = torch.nn.functional.adaptive_avg_pool1d(reshaped, output_size=seq_emb_dim)\n",
    "            projected = projected.squeeze(1)  # Only squeeze dimension 1, keep batch dimension\n",
    "            embedding_batches.append(projected)\n",
    "        return torch.cat(embedding_batches, dim=0)\n",
    "\n",
    "    def save_embeddings(self, embeddings, output_npy_path):\n",
    "        \"\"\"\n",
    "        Save embeddings to a .npy file.\n",
    "        \n",
    "        Args:\n",
    "            embeddings (torch.Tensor): The embeddings to save.\n",
    "            output_npy_path (str): Path to save the embeddings file.\n",
    "        \"\"\"\n",
    "        # Move embeddings to CPU before converting to numpy\n",
    "        embeddings_cpu = embeddings.cpu().numpy()\n",
    "        np.save(output_npy_path, embeddings_cpu)\n",
    "        print(f\"Embeddings saved at {output_npy_path} with shape {embeddings_cpu.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # convert the bmg_disease_name_df to the list of disease names of each rows\n",
    "# disease_names = list(bmg_disease_name_df['Disease_Name_List'])\n",
    "# # to make sure each disease name is a string\n",
    "# disease_names = [str(name) for name in disease_names]\n",
    "# print(len(disease_names))\n",
    "# print(disease_names[:5])\n",
    "# print(disease_names[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#  # Use language model to embed the name and description\n",
    "# name_sentence_list = disease_names\n",
    "# text_encoder = TextEncoder()\n",
    "# text_encoder.load_model()\n",
    "# name_embeddings = text_encoder.generate_embeddings(name_sentence_list, batch_size=32, seq_emb_dim=768)\n",
    "# print(name_embeddings.shape)\n",
    "# # mkdir folder BMG_emb\n",
    "# if not os.path.exists('./data/BMG_emb'):\n",
    "#     os.makedirs('./data/BMG_emb')\n",
    "# text_encoder.save_embeddings(name_embeddings, './data/BMG_emb/disease_name_embeddings.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # convert each row in the unmatched_samples_df to a sentence like \"Name: NIH:OVCAR-3, Pathology: OVARY, Histology: CARCINOMA, type: CANCER, PATHOLOGIST_ANNOTATION: OVARY, tcga_code: OV\"\n",
    "# depmap_desc_sentence_list = []\n",
    "# for idx, row in unmatched_samples_df.iterrows():\n",
    "#     depmap_desc_sentence_list.append(f\" {row['type']}, {row['PATHOLOGIST_ANNOTATION']}\")\n",
    "\n",
    "# # and create the ncit_desc_sentence_list and ordo_desc_sentence_list\n",
    "# ncit_desc_sentence_list = list(unmatched_samples_df['xref_NCIt_name'])\n",
    "# ordo_desc_sentence_list = list(unmatched_samples_df['xref_ORDO_name'])\n",
    "\n",
    "# # make sure each element in the depmap_desc_sentence_list, ncit_desc_sentence_list and ordo_desc_sentence_list is a string\n",
    "# depmap_desc_sentence_list = [str(name) for name in depmap_desc_sentence_list]\n",
    "# ncit_desc_sentence_list = [str(name) for name in ncit_desc_sentence_list]\n",
    "# ordo_desc_sentence_list = [str(name) for name in ordo_desc_sentence_list]\n",
    "\n",
    "# print(depmap_desc_sentence_list)\n",
    "# print(ncit_desc_sentence_list)\n",
    "# print(ordo_desc_sentence_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Define the sentence lists to process and their corresponding prefixes\n",
    "# sentence_list_mapping = {\n",
    "#     'depmap': depmap_desc_sentence_list,\n",
    "#     'ncit': ncit_desc_sentence_list,\n",
    "#     'ordo': ordo_desc_sentence_list\n",
    "# }\n",
    "\n",
    "# # Define top_k for this matching task\n",
    "# top_k = 2\n",
    "\n",
    "# # Process each sentence list and add the matches to the DataFrame\n",
    "# for prefix, sentence_list in sentence_list_mapping.items():\n",
    "#     # Lists to store the matches for each description\n",
    "#     all_matched_lists = [[] for _ in range(top_k)]\n",
    "#     all_matched_name_lists = [[] for _ in range(top_k)]\n",
    "#     all_similarity_lists = [[] for _ in range(top_k)]\n",
    "    \n",
    "#     for desc in sentence_list:\n",
    "#         desc_embeddings = text_encoder.generate_embeddings([desc], batch_size=1, seq_emb_dim=768)\n",
    "#         # Calculate cosine similarity\n",
    "#         cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)\n",
    "#         similarity = cos(name_embeddings, desc_embeddings)\n",
    "#         # Find the index of top k most similar diseases\n",
    "#         top_k_idx = torch.argsort(similarity, descending=True)[:top_k]\n",
    "        \n",
    "#         # Process each of the top k matches\n",
    "#         for rank, idx in enumerate(top_k_idx):\n",
    "#             idx_int = idx.item()\n",
    "#             disease_id = bmg_disease_name_df.iloc[idx_int]['BioMedGraphica_Conn_ID']\n",
    "#             disease_name = disease_names[idx_int]\n",
    "#             sim_score = similarity[idx_int].item()\n",
    "            \n",
    "#             # Store in the corresponding lists\n",
    "#             all_matched_lists[rank].append(disease_id)\n",
    "#             all_matched_name_lists[rank].append(disease_name)\n",
    "#             all_similarity_lists[rank].append(sim_score)\n",
    "    \n",
    "#     # Add the matched data to the DataFrame with the appropriate prefix\n",
    "#     for i in range(top_k):\n",
    "#         rank = i + 1\n",
    "#         unmatched_samples_df[f'{prefix}_match_{rank}_disease'] = all_matched_lists[i]\n",
    "#         unmatched_samples_df[f'{prefix}_match_{rank}_disease_name'] = all_matched_name_lists[i]\n",
    "#         unmatched_samples_df[f'{prefix}_match_{rank}_similarity'] = all_similarity_lists[i]\n",
    "    \n",
    "#     print(f\"Finished processing {prefix} descriptions\")\n",
    "\n",
    "# # Display the updated DataFrame\n",
    "# display(unmatched_samples_df)\n",
    "\n",
    "# # Save the updated DataFrame to CSV\n",
    "# unmatched_samples_df.to_csv('./data/process_data/unmatched_samples_softmatch.csv', index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 7.3 Manual filter and match"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# unmatched_samples_df = pd.read_csv('./data/process_data/unmatched_samples_softmatch.csv')\n",
    "# # add the 2 empty columnns 'BMGC_manual_ID' and 'BMGC_manual_name' to the unmatched_samples_manual_df\n",
    "# unmatched_samples_manual_df = unmatched_samples_df.copy()\n",
    "# unmatched_samples_manual_df['BMGC_manual_ID'] = None\n",
    "# unmatched_samples_manual_df['BMGC_manual_name'] = None\n",
    "# # change the order of the columns to make the 'BMGC_manual_ID' and 'BMGC_manual_name' after id\n",
    "# unmatched_samples_manual_df = unmatched_samples_manual_df[['depMapID', 'Name', 'Pathology', 'Histology', 'type', 'PATHOLOGIST_ANNOTATION', 'tcga_code', 'id', 'BMGC_manual_ID', 'BMGC_manual_name', 'xref_NCIt_name', 'xref_ORDO_name', 'depmap_match_1_disease', 'depmap_match_1_disease_name', 'depmap_match_2_disease', 'depmap_match_2_disease_name', 'ncit_match_1_disease', 'ncit_match_1_disease_name', 'ncit_match_2_disease', 'ncit_match_2_disease_name', 'ordo_match_1_disease', 'ordo_match_1_disease_name', 'ordo_match_2_disease', 'ordo_match_2_disease_name']]\n",
    "# display(unmatched_samples_manual_df)\n",
    "# unmatched_samples_manual_df.to_csv('./data/process_data/unmatched_samples_manual.csv', index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 7.4 Combine matched and manual_unmatched"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exact-matched samples: rename BMGC_Matched_ID / BMGC_Matched_name to the shared\n",
    "# BMGC_Disease_ID / BMGC_Disease_name column names.\n",
    "dti_matched_df = pd.read_csv('./data/process_data/dti_matched.csv')\n",
    "dti_matched_df = dti_matched_df.rename(columns={\n",
    "    'BMGC_Matched_ID': 'BMGC_Disease_ID',\n",
    "    'BMGC_Matched_name': 'BMGC_Disease_name',\n",
    "})\n",
    "# Keep the sample annotation columns plus the matched disease ID and name.\n",
    "dti_matched_df = dti_matched_df.loc[:, [\n",
    "    'depMapID', 'Name', 'BMGC_Disease_ID', 'BMGC_Disease_name',\n",
    "    'Pathology', 'Histology', 'type', 'PATHOLOGIST_ANNOTATION', 'tcga_code',\n",
    "]]\n",
    "display(dti_matched_df)\n",
    "\n",
    "# Manually curated samples: rename BMGC_manual_ID / BMGC_manual_name the same way.\n",
    "dti_unmatched_manual_df = pd.read_csv('./data/process_data/unmatched_samples_manual.csv')\n",
    "dti_unmatched_manual_df = dti_unmatched_manual_df.rename(columns={\n",
    "    'BMGC_manual_ID': 'BMGC_Disease_ID',\n",
    "    'BMGC_manual_name': 'BMGC_Disease_name',\n",
    "})\n",
    "display(dti_unmatched_manual_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Manually validated matches for the samples that had no exact match; rename the manual\n",
    "# ID/name columns to BMGC_Disease_ID / BMGC_Disease_name so both tables share those names.\n",
    "dti_unmatched_manual_valid_df = pd.read_csv('./data/manual_match_data/unmatched_samples_manual_valid.csv')\n",
    "dti_unmatched_manual_valid_df = dti_unmatched_manual_valid_df.rename(columns={\n",
    "    'BMGC_manual_ID': 'BMGC_Disease_ID',\n",
    "    'BMGC_manual_name': 'BMGC_Disease_name',\n",
    "})\n",
    "display(dti_unmatched_manual_valid_df)\n",
    "\n",
    "# Stack the exact matches on top of the manual matches and order the rows by depMapID.\n",
    "dti_combined_df = (\n",
    "    pd.concat([dti_matched_df, dti_unmatched_manual_valid_df], ignore_index=True)\n",
    "    .sort_values(by='depMapID')\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "# Report how many rows still lack a disease ID or a disease name.\n",
    "print(dti_combined_df[['BMGC_Disease_ID', 'BMGC_Disease_name']].isna().sum())\n",
    "# Lower-case the disease names.\n",
    "dti_combined_df['BMGC_Disease_name'] = dti_combined_df['BMGC_Disease_name'].str.lower()\n",
    "display(dti_combined_df)\n",
    "dti_combined_df.to_csv('./data/process_data/dti_combined_samples.csv', index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8. CRISPR/RNAi Biomarkers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output folders for the two pretraining datasets; nothing happens if they already exist.\n",
    "for pretrain_dir in ('./data/pretrain_plain_data', './data/pretrain_status_data'):\n",
    "    os.makedirs(pretrain_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-entity name/ID tables; the BioMedGraphica_ID column is dropped from each of them.\n",
    "bmgc_promoter_name_id_df = pd.read_csv(\n",
    "    './data/BioMedGraphica-Conn/Entity/Promoter/BioMedGraphica_Conn_Promoter_LLM_Name_ID_Combined.csv'\n",
    ").drop(columns=['BioMedGraphica_ID'])\n",
    "bmgc_gene_name_id_df = pd.read_csv(\n",
    "    './data/BioMedGraphica-Conn/Entity/Gene/BioMedGraphica_Conn_Gene_LLM_Name_ID_Combined.csv'\n",
    ").drop(columns=['BioMedGraphica_ID'])\n",
    "bmgc_transcript_name_id_df = pd.read_csv(\n",
    "    './data/BioMedGraphica-Conn/Entity/Transcript/BioMedGraphica_Conn_Transcript_LLM_Name_ID_Combined.csv'\n",
    ").drop(columns=['BioMedGraphica_ID'])\n",
    "bmgc_protein_name_id_df = pd.read_csv(\n",
    "    './data/BioMedGraphica-Conn/Entity/Protein/BioMedGraphica_Conn_Protein_LLM_Name_ID_Combined.csv'\n",
    ").drop(columns=['BioMedGraphica_ID'])\n",
    "\n",
    "# Stack the four tables, then left-join them onto the IDs of bmgc_omics_df so the result\n",
    "# has one row per bmgc_omics_df row, in the same order.\n",
    "bmgc_omics_name_id_tmp_df = pd.concat(\n",
    "    [bmgc_promoter_name_id_df, bmgc_gene_name_id_df, bmgc_transcript_name_id_df, bmgc_protein_name_id_df],\n",
    "    axis=0,\n",
    ").reset_index(drop=True)\n",
    "bmgc_omics_name_id_df = bmgc_omics_df[['BioMedGraphica_Conn_ID']].merge(\n",
    "    bmgc_omics_name_id_tmp_df, on='BioMedGraphica_Conn_ID', how='left'\n",
    ")\n",
    "# Null counts before filling.\n",
    "print('Null values in bmgc_omics_name_id_df:')\n",
    "print(bmgc_omics_name_id_df.isna().sum())\n",
    "# Missing Names_and_IDs entries become a single space (not an empty string).\n",
    "bmgc_omics_name_id_df['Names_and_IDs'] = bmgc_omics_name_id_df['Names_and_IDs'].fillna(' ')\n",
    "# Null counts after filling.\n",
    "print('Null values in bmgc_omics_name_id_df:')\n",
    "print(bmgc_omics_name_id_df.isna().sum())\n",
    "display(bmgc_omics_name_id_df)\n",
    "# The same table is written to both pretraining folders.\n",
    "bmgc_omics_name_id_df.to_csv('./data/pretrain_plain_data/bmgc_omics_name.csv', index=False)\n",
    "bmgc_omics_name_id_df.to_csv('./data/pretrain_status_data/bmgc_omics_name.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# No promoter description file is read here: reuse the promoter ID table and give it an\n",
    "# all-NaN Description column.\n",
    "bmgc_promoter_desc_df = bmgc_promoter_name_id_df.drop(columns=['Names_and_IDs']).assign(Description=np.nan)\n",
    "# Gene / transcript / protein descriptions, each without its BioMedGraphica_ID column.\n",
    "bmgc_gene_desc_df = pd.read_csv(\n",
    "    './data/BioMedGraphica-Conn/Entity/Gene/BioMedGraphica_Conn_Gene_Description_Combined.csv'\n",
    ").drop(columns=['BioMedGraphica_ID'])\n",
    "bmgc_transcript_desc_df = pd.read_csv(\n",
    "    './data/BioMedGraphica-Conn/Entity/Transcript/BioMedGraphica_Conn_Transcript_Description_Combined.csv'\n",
    ").drop(columns=['BioMedGraphica_ID'])\n",
    "bmgc_protein_desc_df = pd.read_csv(\n",
    "    './data/BioMedGraphica-Conn/Entity/Protein/BioMedGraphica_Conn_Protein_Description_Combined.csv'\n",
    ").drop(columns=['BioMedGraphica_ID'])\n",
    "\n",
    "# Stack the four description tables.\n",
    "bmgc_omics_desc_tmp_df = pd.concat(\n",
    "    [bmgc_promoter_desc_df, bmgc_gene_desc_df, bmgc_transcript_desc_df, bmgc_protein_desc_df],\n",
    "    axis=0,\n",
    ").reset_index(drop=True)\n",
    "# Null counts of the stacked table.\n",
    "print(bmgc_omics_desc_tmp_df.isna().sum())\n",
    "# Left-join onto the IDs of bmgc_omics_df.\n",
    "bmgc_omics_desc_df = bmgc_omics_df[['BioMedGraphica_Conn_ID']].merge(\n",
    "    bmgc_omics_desc_tmp_df, on='BioMedGraphica_Conn_ID', how='left'\n",
    ")\n",
    "# Null counts after the join.\n",
    "print(bmgc_omics_desc_df.isna().sum())\n",
    "# Missing descriptions become a single space (not an empty string).\n",
    "bmgc_omics_desc_df['Description'] = bmgc_omics_desc_df['Description'].fillna(' ')\n",
    "# Null counts after filling.\n",
    "print(bmgc_omics_desc_df.isna().sum())\n",
    "display(bmgc_omics_desc_df)\n",
    "\n",
    "# The same table is written to both pretraining folders.\n",
    "bmgc_omics_desc_df.to_csv('./data/pretrain_plain_data/bmgc_omics_desc.csv', index=False)\n",
    "bmgc_omics_desc_df.to_csv('./data/pretrain_status_data/bmgc_omics_desc.csv', index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 8.0 Entity markers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _rows_of_type(table, type_label):\n",
    "    \"\"\"Return an independent copy of the rows of `table` whose 'Type' equals `type_label`.\"\"\"\n",
    "    return table[table['Type'] == type_label].copy()\n",
    "\n",
    "# Translation chain converging to the same node: split the omics entities by type.\n",
    "promoter_entity_df = _rows_of_type(bmgc_omics_df, 'Promoter')\n",
    "gene_entity_df = _rows_of_type(bmgc_omics_df, 'Gene')\n",
    "transcript_entity_df = _rows_of_type(bmgc_omics_df, 'Transcript')\n",
    "protein_entity_df = _rows_of_type(bmgc_omics_df, 'Protein')\n",
    "\n",
    "display(bmgc_omics_relation_df)\n",
    "# Null counts of the relation table.\n",
    "print('Null values in bmgc_omics_relation_df:')\n",
    "print(bmgc_omics_relation_df.isna().sum())\n",
    "\n",
    "# Split the relations by type in the same way.\n",
    "promoter_gene_relation_df = _rows_of_type(bmgc_omics_relation_df, 'Promoter-Gene')\n",
    "gene_transcript_relation_df = _rows_of_type(bmgc_omics_relation_df, 'Gene-Transcript')\n",
    "transcript_protein_relation_df = _rows_of_type(bmgc_omics_relation_df, 'Transcript-Protein')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gene -> Transcript -> Protein: join the gene entities through the Gene-Transcript\n",
    "# relation and then through the Transcript-Protein relation.\n",
    "gene_transcript_entity_df = gene_entity_df.merge(\n",
    "    gene_transcript_relation_df[['BMGC_From_ID', 'BMGC_To_ID']],\n",
    "    left_on='BioMedGraphica_Conn_ID',\n",
    "    right_on='BMGC_From_ID',\n",
    "    how='outer',\n",
    ")\n",
    "gene_transcript_protein_entity_df = gene_transcript_entity_df.merge(\n",
    "    transcript_protein_relation_df[['BMGC_From_ID', 'BMGC_To_ID']],\n",
    "    left_on='BMGC_To_ID',\n",
    "    right_on='BMGC_From_ID',\n",
    "    how='outer',\n",
    ")\n",
    "# Discard rows where either relation hop is missing (NaN in any of the four link columns).\n",
    "gene_transcript_protein_entity_df = (\n",
    "    gene_transcript_protein_entity_df\n",
    "    .dropna(subset=['BMGC_From_ID_x', 'BMGC_To_ID_x', 'BMGC_From_ID_y', 'BMGC_To_ID_y'])\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "# Keep only the gene ID and the protein ID at the end of the chain (the transcript ID is\n",
    "# not retained), renamed to BMGC_GN_ID / BMGC_PT_ID and ordered by gene ID.\n",
    "gene_transcript_protein_entity_df = (\n",
    "    gene_transcript_protein_entity_df[['BioMedGraphica_Conn_ID', 'BMGC_To_ID_y']]\n",
    "    .rename(columns={'BioMedGraphica_Conn_ID': 'BMGC_GN_ID', 'BMGC_To_ID_y': 'BMGC_PT_ID'})\n",
    "    .sort_values(by='BMGC_GN_ID')\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "# Collapse repeated gene/protein pairs.\n",
    "gene_transcript_protein_entity_df = gene_transcript_protein_entity_df.drop_duplicates().reset_index(drop=True)\n",
    "display(gene_transcript_protein_entity_df)\n",
    "\n",
    "# Promoter -> Protein: start from a copy of the gene/protein pairs and swap each gene ID\n",
    "# for a promoter ID.\n",
    "bmgc_promoter_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Promoter/BioMedGraphica_Conn_Promoter.csv')\n",
    "bmgc_gene_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Gene/BioMedGraphica_Conn_Gene.csv')\n",
    "promoter_gene_transcript_protein_entity_df = gene_transcript_protein_entity_df.copy()\n",
    "# NOTE(review): promoter and gene IDs are paired purely by row position in the two entity\n",
    "# files (column-wise concat) -- presumably both files are row-aligned; TODO confirm, or\n",
    "# derive the pairs from promoter_gene_relation_df instead.\n",
    "promoter_gene_df = pd.concat(\n",
    "    [\n",
    "        bmgc_promoter_df[['BioMedGraphica_Conn_ID']].rename(columns={'BioMedGraphica_Conn_ID': 'BMGC_PM_ID'}),\n",
    "        bmgc_gene_df[['BioMedGraphica_Conn_ID']].rename(columns={'BioMedGraphica_Conn_ID': 'BMGC_GN_ID'}),\n",
    "    ],\n",
    "    axis=1,\n",
    ")\n",
    "promoter_protein_entity_df = promoter_gene_transcript_protein_entity_df.merge(\n",
    "    promoter_gene_df, left_on='BMGC_GN_ID', right_on='BMGC_GN_ID', how='left'\n",
    ").drop(columns=['BMGC_GN_ID'])\n",
    "promoter_protein_entity_df = promoter_protein_entity_df[['BMGC_PM_ID', 'BMGC_PT_ID']]\n",
    "display(promoter_protein_entity_df)\n",
    "\n",
    "# Transcript -> Protein: join the transcript entities through the Transcript-Protein relation.\n",
    "transcript_protein_entity_df = transcript_entity_df.merge(\n",
    "    transcript_protein_relation_df[['BMGC_From_ID', 'BMGC_To_ID']],\n",
    "    left_on='BioMedGraphica_Conn_ID',\n",
    "    right_on='BMGC_From_ID',\n",
    "    how='outer',\n",
    ")\n",
    "# Discard rows without a relation (NaN in BMGC_From_ID or BMGC_To_ID).\n",
    "transcript_protein_entity_df = (\n",
    "    transcript_protein_entity_df\n",
    "    .dropna(subset=['BMGC_From_ID', 'BMGC_To_ID'])\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "# Keep the transcript ID and protein ID, renamed to BMGC_TS_ID / BMGC_PT_ID, ordered by transcript ID.\n",
    "transcript_protein_entity_df = (\n",
    "    transcript_protein_entity_df[['BioMedGraphica_Conn_ID', 'BMGC_To_ID']]\n",
    "    .rename(columns={'BioMedGraphica_Conn_ID': 'BMGC_TS_ID', 'BMGC_To_ID': 'BMGC_PT_ID'})\n",
    "    .sort_values(by='BMGC_TS_ID')\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "# Collapse repeated transcript/protein pairs.\n",
    "transcript_protein_entity_df = transcript_protein_entity_df.drop_duplicates().reset_index(drop=True)\n",
    "display(transcript_protein_entity_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Protein-only marker table: just the protein IDs, renamed to BMGC_PT_ID and sorted.\n",
    "only_protein_entity_df = (\n",
    "    protein_entity_df[['BioMedGraphica_Conn_ID']]\n",
    "    .rename(columns={'BioMedGraphica_Conn_ID': 'BMGC_PT_ID'})\n",
    "    .sort_values(by='BMGC_PT_ID')\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "display(only_protein_entity_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 8.1 CRISPR top 100 gene entities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Load the raw CRISPR gene-effect table as-is; later cells use its 'Unnamed: 0' column as\n",
    "# the row key and treat the remaining column labels as gene names.\n",
    "raw_crispr_df = pd.read_csv(\n",
    "    './data/raw_data/CRISPRGeneEffect.csv'\n",
    ")\n",
    "display(raw_crispr_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every column label of raw_crispr_df except the first one, as a plain list.\n",
    "raw_crispr_df_columns = list(raw_crispr_df.columns[1:])\n",
    "print(raw_crispr_df_columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn the table on its side: 'Unnamed: 0' becomes the column header, and the former\n",
    "# column labels move into a regular column called gene_name.\n",
    "raw_crispr_t_df = (\n",
    "    raw_crispr_df\n",
    "    .set_index('Unnamed: 0')\n",
    "    .transpose()\n",
    "    .reset_index()\n",
    "    .rename(columns={'index': 'gene_name'})\n",
    ")\n",
    "# Keep only the text before the first '(' in each gene_name, without surrounding whitespace.\n",
    "raw_crispr_t_df['gene_name'] = raw_crispr_t_df['gene_name'].map(\n",
    "    lambda label: label.partition('(')[0].strip()\n",
    ")\n",
    "display(raw_crispr_t_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every column except gene_name is coerced to numbers (unparseable entries become NaN),\n",
    "# after which all NaN values in the table are replaced by 0.0.\n",
    "cols_to_convert = raw_crispr_t_df.columns.difference(['gene_name'])\n",
    "numeric_values = raw_crispr_t_df[cols_to_convert].apply(pd.to_numeric, errors='coerce')\n",
    "raw_crispr_t_df[cols_to_convert] = numeric_values\n",
    "raw_crispr_t_df = raw_crispr_t_df.fillna(0.0)\n",
    "display(raw_crispr_t_df)\n",
    "\n",
    "# Rows sharing a gene_name are averaged into one row.\n",
    "crispr_df = (\n",
    "    raw_crispr_t_df\n",
    "    .groupby('gene_name', as_index=False)\n",
    "    .mean()\n",
    ")\n",
    "display(crispr_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gene entity table, reduced to the BioMedGraphica_Conn_ID and HGNC_Symbol columns.\n",
    "bmg_gene_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Gene/BioMedGraphica_Conn_Gene.csv')\n",
    "bmg_gene_df = bmg_gene_df.loc[:, ['BioMedGraphica_Conn_ID', 'HGNC_Symbol']]\n",
    "display(bmg_gene_df)\n",
    "# Attach the gene ID to each CRISPR row by matching HGNC_Symbol to gene_name (unmatched\n",
    "# rows on either side are dropped), then remove the two name columns.\n",
    "merged_crispr_df = (\n",
    "    bmg_gene_df\n",
    "    .merge(crispr_df, left_on='HGNC_Symbol', right_on='gene_name', how='inner')\n",
    "    .drop(columns=['HGNC_Symbol', 'gene_name'])\n",
    ")\n",
    "display(merged_crispr_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach the CRISPR values to every gene/protein pair via the gene ID.\n",
    "protein_crispr_df = gene_transcript_protein_entity_df.merge(\n",
    "    merged_crispr_df,\n",
    "    left_on='BMGC_GN_ID',\n",
    "    right_on='BioMedGraphica_Conn_ID',\n",
    "    how='left',\n",
    ").drop(columns=['BioMedGraphica_Conn_ID'])\n",
    "# Rows containing any NaN are removed.\n",
    "protein_crispr_df = protein_crispr_df.dropna().reset_index(drop=True)\n",
    "display(protein_crispr_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "overlapped_omics_union_annotation_samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CRISPR data: keep the two ID columns plus every sample column that appears in\n",
    "# overlapped_omics_union_annotation_samples (original column order is preserved).\n",
    "columns_to_keep = ['BMGC_GN_ID', 'BMGC_PT_ID']\n",
    "columns_to_keep = columns_to_keep + [\n",
    "    col\n",
    "    for col in protein_crispr_df.columns\n",
    "    if col in overlapped_omics_union_annotation_samples\n",
    "]\n",
    "processed_crispr_df = protein_crispr_df[columns_to_keep]\n",
    "display(processed_crispr_df)\n",
    "# The same table is written to both pretraining folders.\n",
    "processed_crispr_df.to_csv('./data/pretrain_plain_data/processed_crispr.csv', index=False)\n",
    "processed_crispr_df.to_csv('./data/pretrain_status_data/processed_crispr.csv', index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 8.2 RNAi top 100 transcript entities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the raw RNAi table and replace every NaN with 0.0.\n",
    "raw_rna_combined_df = pd.read_csv('./data/raw_data/D2_combined_gene_dep_scores.csv').fillna(0.0)\n",
    "display(raw_rna_combined_df)\n",
    "# The 'Unnamed: 0' column becomes gene_name.\n",
    "raw_rna_combined_df = raw_rna_combined_df.rename(columns={'Unnamed: 0': 'gene_name'})\n",
    "# Keep only the text before the first '(' in each gene_name, without surrounding whitespace.\n",
    "raw_rna_combined_df['gene_name'] = raw_rna_combined_df['gene_name'].map(\n",
    "    lambda label: label.partition('(')[0].strip()\n",
    ")\n",
    "display(raw_rna_combined_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample labels of the RNAi table (every column after gene_name), as a one-column frame.\n",
    "rna_combined_samples = list(raw_rna_combined_df.columns[1:])\n",
    "rna_combined_samples_df = pd.DataFrame({'CCLE_Name': rna_combined_samples})\n",
    "# Keep the RNAi samples whose label matches a CCLE_ID in cell_line_anno_df.\n",
    "rna_combined_samples_anno_df = rna_combined_samples_df.merge(\n",
    "    cell_line_anno_df, left_on='CCLE_Name', right_on='CCLE_ID', how='inner'\n",
    ")\n",
    "display(rna_combined_samples_anno_df)\n",
    "\n",
    "# CCLE name -> depMapID lookup, used to relabel the RNAi sample columns.\n",
    "rna_combined_samples_map_dict = dict(\n",
    "    zip(rna_combined_samples_anno_df['CCLE_Name'], rna_combined_samples_anno_df['depMapID'])\n",
    ")\n",
    "print(rna_combined_samples_map_dict)\n",
    "maped_rna_combined_df = raw_rna_combined_df.rename(columns=rna_combined_samples_map_dict)\n",
    "# Keep gene_name plus, in sorted order, the columns found in overlapped_omics_union_annotation_samples.\n",
    "columns_to_keep = ['gene_name'] + sorted(\n",
    "    col for col in maped_rna_combined_df.columns if col in overlapped_omics_union_annotation_samples\n",
    ")\n",
    "filtered_rna_df = maped_rna_combined_df[columns_to_keep]\n",
    "display(filtered_rna_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Transcript entity table, reduced to the BioMedGraphica_Conn_ID and HGNC_Symbol columns.\n",
    "bmg_transcript_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Transcript/BioMedGraphica_Conn_Transcript.csv')\n",
    "bmg_transcript_df = bmg_transcript_df.loc[:, ['BioMedGraphica_Conn_ID', 'HGNC_Symbol']]\n",
    "display(bmg_transcript_df)\n",
    "# Attach the transcript ID to each RNAi row by matching HGNC_Symbol to gene_name (unmatched\n",
    "# rows on either side are dropped), then remove the two name columns.\n",
    "merge_rna_df = (\n",
    "    bmg_transcript_df\n",
    "    .merge(filtered_rna_df, left_on='HGNC_Symbol', right_on='gene_name', how='inner')\n",
    "    .drop(columns=['HGNC_Symbol', 'gene_name'])\n",
    ")\n",
    "display(merge_rna_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach the RNAi values to every transcript/protein pair via the transcript ID.\n",
    "protein_rna_df = transcript_protein_entity_df.merge(\n",
    "    merge_rna_df,\n",
    "    left_on='BMGC_TS_ID',\n",
    "    right_on='BioMedGraphica_Conn_ID',\n",
    "    how='left',\n",
    ").drop(columns=['BioMedGraphica_Conn_ID'])\n",
    "# Rows containing any NaN are removed.\n",
    "processed_rna_df = protein_rna_df.dropna().reset_index(drop=True)\n",
    "display(processed_rna_df)\n",
    "# The same table is written to both pretraining folders.\n",
    "processed_rna_df.to_csv('./data/pretrain_plain_data/processed_rna.csv', index=False)\n",
    "processed_rna_df.to_csv('./data/pretrain_status_data/processed_rna.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample columns of each processed table: everything after the two leading ID columns.\n",
    "processed_crispr_samples = list(processed_crispr_df.columns[2:])\n",
    "processed_rna_samples = list(processed_rna_df.columns[2:])\n",
    "\n",
    "# Samples present in both the CRISPR and the RNAi table.\n",
    "overlapped_crispr_rna_samples = sorted(set(processed_crispr_samples).intersection(processed_rna_samples))\n",
    "print('The overlapped samples between protein_crispr_samples and protein_rna_samples are:')\n",
    "print(len(overlapped_crispr_rna_samples))\n",
    "print(overlapped_crispr_rna_samples)\n",
    "\n",
    "# Samples present in at least one of the two tables.\n",
    "union_crispr_rna_samples = sorted(set(processed_crispr_samples).union(processed_rna_samples))\n",
    "print('The union samples between protein_crispr_samples and protein_rna_samples are:')\n",
    "print(len(union_crispr_rna_samples))\n",
    "print(union_crispr_rna_samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restrict the CRISPR/RNAi intersection to the samples in overlapped_omics_annotation_samples\n",
    "# (the overlap across omics, not the union across omics).\n",
    "overlapped_omics_crispr_rna_samples = sorted(set(overlapped_crispr_rna_samples) & set(overlapped_omics_annotation_samples))\n",
    "print(\"The CRISPR/RNAi overlapped samples that are also in overlapped_omics_annotation_samples are:\")\n",
    "print(len(overlapped_omics_crispr_rna_samples))\n",
    "print(overlapped_omics_crispr_rna_samples)\n",
    "\n",
    "# Restrict the CRISPR/RNAi union to the samples in overlapped_omics_annotation_samples.\n",
    "overlapped_omics_union_crispr_rna_samples = sorted(set(union_crispr_rna_samples) & set(overlapped_omics_annotation_samples))\n",
    "# This label previously repeated the intersection message although the union set is printed.\n",
    "print(\"The CRISPR/RNAi union samples that are also in overlapped_omics_annotation_samples are:\")\n",
    "print(len(overlapped_omics_union_crispr_rna_samples))\n",
    "print(overlapped_omics_union_crispr_rna_samples)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 9. Sample Splits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 9.1 Pretrain, Drug and Target Sample Splits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# depMapID values of the DTI samples; the printed count is the number of distinct IDs.\n",
    "dti_overlapped_samples = dti_overlapped_samples_df['depMapID'].tolist()\n",
    "print(dti_overlapped_samples)\n",
    "n_unique_dti_samples = len(set(dti_overlapped_samples))\n",
    "print('length of dti_overlapped_samples: ', n_unique_dti_samples)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 9.1.1 Pretrain Samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Your existing code for rest_samples is good\n",
    "rest_samples = sorted(list(set(overlapped_omics_union_annotation_samples) - set(dti_overlapped_samples)))\n",
    "print(\"len(rest_samples):\", len(rest_samples))\n",
    "print(rest_samples[:5])\n",
    "\n",
    "# check the non-cancerous samples in the rest_samples\n",
    "non_cancerous_rest_samples = sorted(list(set(non_cancerous_samples_df['Depmap Id']) & set(rest_samples)))\n",
    "print(\"len(non_cancerous_rest_samples):\", len(non_cancerous_rest_samples))\n",
    "print(non_cancerous_rest_samples[:5])\n",
    "\n",
    "cancerous_rest_samples = sorted(list(set(rest_samples) - set(non_cancerous_rest_samples)))\n",
    "print(\"len(cancerous_rest_samples):\", len(cancerous_rest_samples))\n",
    "print(cancerous_rest_samples[:5])\n",
    "\n",
    "# convert the rest_samples to a dataframe with marking the sample cancerous or non-cancerous status\n",
    "rest_samples_df = pd.DataFrame({\n",
    "    'depMapID': rest_samples,\n",
    "    'cancerous_status': ['cancerous' if sample in cancerous_rest_samples else 'non-cancerous' for sample in rest_samples]\n",
    "})\n",
    "display(rest_samples_df)\n",
    "\n",
    "# Split the rest_samples into pretrain_plain_samples and pretrain_status_samples\n",
    "# Sample 20% of the rest_samples for pretrain_plain_samples while maintaining the cancerous/non-cancerous ratio\n",
    "\n",
    "def stratified_split_pretrain_samples(rest_samples_df, pretrain_plain_ratio=0.2, random_state=42):\n",
    "    \"\"\"\n",
    "    Split rest samples into pretrain_plain and pretrain_status while maintaining \n",
    "    the cancerous/non-cancerous ratio in both splits.\n",
    "    \n",
    "    Args:\n",
    "        rest_samples_df: DataFrame with depMapID and cancerous_status columns\n",
    "        pretrain_plain_ratio: Ratio of samples to allocate to pretrain_plain (default 0.2)\n",
    "        random_state: Random seed for reproducibility\n",
    "    \n",
    "    Returns:\n",
    "        tuple: (pretrain_plain_samples, pretrain_status_samples)\n",
    "    \"\"\"\n",
    "    # Use stratified split to maintain the ratio\n",
    "    pretrain_plain_df, pretrain_status_df = train_test_split(\n",
    "        rest_samples_df,\n",
    "        test_size=1-pretrain_plain_ratio,  # 0.8 for pretrain_status\n",
    "        stratify=rest_samples_df['cancerous_status'],\n",
    "        random_state=random_state\n",
    "    )\n",
    "    \n",
    "    # Extract sample lists\n",
    "    pretrain_plain_samples = sorted(pretrain_plain_df['depMapID'].tolist())\n",
    "    pretrain_status_samples = sorted(pretrain_status_df['depMapID'].tolist())\n",
    "    \n",
    "    return pretrain_plain_samples, pretrain_status_samples, pretrain_plain_df, pretrain_status_df\n",
    "\n",
    "# Perform the split\n",
    "pretrain_plain_samples, pretrain_status_samples, pretrain_plain_df, pretrain_status_df = stratified_split_pretrain_samples(\n",
    "    rest_samples_df, \n",
    "    pretrain_plain_ratio=0.2, \n",
    "    random_state=42\n",
    ")\n",
    "\n",
    "# Display results\n",
    "print(f\"\\n📊 Pretrain Sample Split Summary:\")\n",
    "print(f\"Total rest samples: {len(rest_samples_df)}\")\n",
    "print(f\"Pretrain plain samples: {len(pretrain_plain_samples)} ({len(pretrain_plain_samples)/len(rest_samples_df)*100:.1f}%)\")\n",
    "print(f\"Pretrain status samples: {len(pretrain_status_samples)} ({len(pretrain_status_samples)/len(rest_samples_df)*100:.1f}%)\")\n",
    "\n",
    "print(f\"\\n📋 Pretrain Plain Sample Distribution:\")\n",
    "print(pretrain_plain_df['cancerous_status'].value_counts())\n",
    "print(pretrain_plain_df['cancerous_status'].value_counts(normalize=True))\n",
    "\n",
    "print(f\"\\n📋 Pretrain Status Sample Distribution:\")\n",
    "print(pretrain_status_df['cancerous_status'].value_counts())\n",
    "print(pretrain_status_df['cancerous_status'].value_counts(normalize=True))\n",
    "\n",
    "# Save the sample lists and detailed DataFrames\n",
    "pretrain_plain_df.to_csv('./data/pretrain_plain_data/pretrain_plain_samples.csv', index=False)\n",
    "pretrain_status_df.to_csv('./data/pretrain_status_data/pretrain_status_samples.csv', index=False)\n",
    "\n",
    "# Also save just the sample IDs for convenience\n",
    "pd.DataFrame(pretrain_plain_samples, columns=['depMapID']).to_csv('./data/pretrain_plain_data/pretrain_plain_sample_ids.csv', index=False)\n",
    "pd.DataFrame(pretrain_status_samples, columns=['depMapID']).to_csv('./data/pretrain_status_data/pretrain_status_sample_ids.csv', index=False)\n",
    "\n",
    "print(f\"\\n✅ Saved pretrain sample splits to:\")\n",
    "print(f\"- ./data/pretrain_plain_data/pretrain_plain_samples.csv ({len(pretrain_plain_samples)} samples)\")\n",
    "print(f\"- ./data/pretrain_status_data/pretrain_status_samples.csv ({len(pretrain_status_samples)} samples)\")\n",
    "\n",
    "# Verify the splits don't overlap\n",
    "overlap_check = set(pretrain_plain_samples) & set(pretrain_status_samples)\n",
    "print(f\"\\n🔍 Overlap check: {len(overlap_check)} overlapping samples (should be 0)\")\n",
    "\n",
    "# Show sample cancerous/non-cancerous ratios\n",
    "pretrain_plain_cancerous = pretrain_plain_df['cancerous_status'].value_counts()\n",
    "pretrain_status_cancerous = pretrain_status_df['cancerous_status'].value_counts()\n",
    "\n",
    "print(f\"\\n📈 Final Cancerous/Non-cancerous Distribution:\")\n",
    "print(f\"Pretrain Plain - Cancerous: {pretrain_plain_cancerous.get('cancerous', 0)}, Non-cancerous: {pretrain_plain_cancerous.get('non-cancerous', 0)}\")\n",
    "print(f\"Pretrain Status - Cancerous: {pretrain_status_cancerous.get('cancerous', 0)}, Non-cancerous: {pretrain_status_cancerous.get('non-cancerous', 0)}\")\n",
    "\n",
    "# Split the pretrain_status_samples into training and test sets according to the cancerous status\n",
    "def stratified_split_status_samples(pretrain_status_df, test_size=0.2, random_state=42):\n",
    "    \"\"\"\n",
    "    Divide the pretrain status samples into a training part and a test part,\n",
    "    stratifying on `cancerous_status` so both parts keep the same class ratio.\n",
    "\n",
    "    Args:\n",
    "        pretrain_status_df: DataFrame holding `depMapID` and `cancerous_status` columns\n",
    "        test_size: Share of the rows placed in the test part (default 0.2)\n",
    "        random_state: Seed forwarded to `train_test_split` for reproducibility\n",
    "\n",
    "    Returns:\n",
    "        tuple: (train_samples, test_samples, train_df, test_df), where the two\n",
    "        sample lists hold the sorted `depMapID` values of the matching DataFrame\n",
    "    \"\"\"\n",
    "    status_labels = pretrain_status_df['cancerous_status']\n",
    "    train_part, test_part = train_test_split(\n",
    "        pretrain_status_df,\n",
    "        test_size=test_size,\n",
    "        random_state=random_state,\n",
    "        stratify=status_labels,\n",
    "    )\n",
    "\n",
    "    # Sort the ID lists so their order does not depend on the shuffle\n",
    "    train_ids = train_part['depMapID'].tolist()\n",
    "    train_ids.sort()\n",
    "    test_ids = test_part['depMapID'].tolist()\n",
    "    test_ids.sort()\n",
    "\n",
    "    return train_ids, test_ids, train_part, test_part\n",
    "\n",
    "# Run the stratified split (20% test) on the status cohort\n",
    "status_split = stratified_split_status_samples(pretrain_status_df, test_size=0.2, random_state=42)\n",
    "(pretrain_status_train_samples, pretrain_status_test_samples,\n",
    " pretrain_status_train_df, pretrain_status_test_df) = status_split\n",
    "\n",
    "# Sizes and class counts used by the report below\n",
    "n_status_total = len(pretrain_status_df)\n",
    "n_status_train = len(pretrain_status_train_samples)\n",
    "n_status_test = len(pretrain_status_test_samples)\n",
    "status_train_counts = pretrain_status_train_df['cancerous_status'].value_counts()\n",
    "status_test_counts = pretrain_status_test_df['cancerous_status'].value_counts()\n",
    "\n",
    "print(f\"\\n📊 Pretrain Status Sample Split Summary:\"\n",
    "      f\"\\nTotal pretrain status samples: {n_status_total}\")\n",
    "print(f\"Training samples: {n_status_train} ({n_status_train/n_status_total*100:.1f}%)\")\n",
    "print(f\"Test samples: {n_status_test} ({n_status_test/n_status_total*100:.1f}%)\")\n",
    "print(f\"\\n📋 Training Sample Distribution:\"\n",
    "      f\"\\n{status_train_counts}\")\n",
    "# Only the training part is also reported as normalized proportions\n",
    "print(pretrain_status_train_df['cancerous_status'].value_counts(normalize=True))\n",
    "print(f\"\\n📋 Test Sample Distribution:\"\n",
    "      f\"\\n{status_test_counts}\")\n",
    "\n",
    "# Persist both parts with their full sample annotations\n",
    "for split_name, split_df in (('train', pretrain_status_train_df), ('test', pretrain_status_test_df)):\n",
    "    split_df.to_csv(f'./data/pretrain_status_data/pretrain_status_{split_name}_samples.csv', index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 9.1.2 Target-CRISPR Samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Samples that appear in both dti_overlapped_samples and overlapped_crispr_rna_samples\n",
    "overlapped_dti_crispr_rna_samples = sorted(\n",
    "    set(dti_overlapped_samples).intersection(overlapped_crispr_rna_samples)\n",
    ")\n",
    "print(\"The overlapped samples between dti_overlapped_samples and overlapped_crispr_rna_samples are:\")\n",
    "print(len(overlapped_dti_crispr_rna_samples))\n",
    "print(overlapped_dti_crispr_rna_samples)\n",
    "\n",
    "# Pull the annotation rows for those samples out of dti_overlapped_samples_df\n",
    "in_overlap = dti_overlapped_samples_df['depMapID'].isin(overlapped_dti_crispr_rna_samples)\n",
    "overlapped_dti_crispr_rna_samples_df = dti_overlapped_samples_df.loc[in_overlap].reset_index(drop=True)\n",
    "# Samples without a TCGA code are labelled 'Unknown'\n",
    "overlapped_dti_crispr_rna_samples_df['tcga_code'] = overlapped_dti_crispr_rna_samples_df['tcga_code'].fillna('Unknown')\n",
    "display(overlapped_dti_crispr_rna_samples_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "def stratified_tcga_split_target(df, test_ratio=0.2, random_state=42):\n",
    "    \"\"\"\n",
    "    Split samples by TCGA code with special handling for small groups:\n",
    "    - For TCGA codes with >=5 samples: split according to test_ratio (0.8/0.2)\n",
    "    - For TCGA codes with 2-4 samples: keep 1 sample for test\n",
    "    - For TCGA codes with 1 sample: keep that sample for train\n",
    "    \"\"\"\n",
    "    np.random.seed(random_state)\n",
    "    \n",
    "    train_samples = []\n",
    "    test_samples = []\n",
    "    \n",
    "    # Group by TCGA code\n",
    "    tcga_groups = df.groupby('tcga_code')\n",
    "    \n",
    "    print(\"TCGA Code Distribution and Split Strategy for Target Samples:\")\n",
    "    print(\"-\" * 70)\n",
    "    \n",
    "    for tcga_code, group in tcga_groups:\n",
    "        n_samples = len(group)\n",
    "        samples = group['depMapID'].tolist()\n",
    "        \n",
    "        # Shuffle samples within each TCGA group\n",
    "        np.random.shuffle(samples)\n",
    "        \n",
    "        if n_samples == 1:\n",
    "            # Single sample: keep for train\n",
    "            train_samples.extend(samples)\n",
    "            test_samples_for_group = []\n",
    "            strategy = \"1 sample -> train only\"\n",
    "            \n",
    "        elif n_samples < 5:\n",
    "            # 2-4 samples: keep 1 for test, rest for train\n",
    "            test_samples_for_group = samples[:1]\n",
    "            train_samples_for_group = samples[1:]\n",
    "            train_samples.extend(train_samples_for_group)\n",
    "            test_samples.extend(test_samples_for_group)\n",
    "            strategy = f\"{n_samples} samples -> 1 test, {n_samples-1} train\"\n",
    "            \n",
    "        else:\n",
    "            # 5+ samples: use the specified ratio\n",
    "            n_test = max(1, int(n_samples * test_ratio))\n",
    "            test_samples_for_group = samples[:n_test]\n",
    "            train_samples_for_group = samples[n_test:]\n",
    "            train_samples.extend(train_samples_for_group)\n",
    "            test_samples.extend(test_samples_for_group)\n",
    "            strategy = f\"{n_samples} samples -> {n_test} test, {len(train_samples_for_group)} train\"\n",
    "        \n",
    "        print(f\"TCGA {tcga_code}: {strategy}\")\n",
    "    \n",
    "    print(\"-\" * 70)\n",
    "    print(f\"Total train samples: {len(train_samples)}\")\n",
    "    print(f\"Total test samples: {len(test_samples)}\")\n",
    "    print(f\"Train ratio: {len(train_samples)/(len(train_samples)+len(test_samples)):.3f}\")\n",
    "    print(f\"Test ratio: {len(test_samples)/(len(train_samples)+len(test_samples)):.3f}\")\n",
    "    \n",
    "    return train_samples, test_samples\n",
    "\n",
    "# Prepare the overlapped samples with TCGA annotation (work on a copy)\n",
    "overlapped_dti_crispr_rna_samples_df_annotated = overlapped_dti_crispr_rna_samples_df.copy()\n",
    "# Fill NaN TCGA codes with 'Unknown' if any exist\n",
    "overlapped_dti_crispr_rna_samples_df_annotated['tcga_code'] = overlapped_dti_crispr_rna_samples_df_annotated['tcga_code'].fillna('Unknown')\n",
    "\n",
    "# Perform stratified split based on TCGA codes\n",
    "target_crispr_train_samples, target_crispr_test_samples = stratified_tcga_split_target(\n",
    "    overlapped_dti_crispr_rna_samples_df_annotated,\n",
    "    test_ratio=0.2,\n",
    "    random_state=42\n",
    ")\n",
    "\n",
    "# Count the samples\n",
    "print(f\"\\n📊 Target Sample Count Summary:\")\n",
    "print(f\"Total samples in dataset: {len(overlapped_dti_crispr_rna_samples_df_annotated)}\")\n",
    "print(f\"Training samples: {len(target_crispr_train_samples)}\")\n",
    "print(f\"Test samples: {len(target_crispr_test_samples)}\")\n",
    "print(f\"Training percentage: {len(target_crispr_train_samples)/len(overlapped_dti_crispr_rna_samples_df_annotated)*100:.1f}%\")\n",
    "print(f\"Test percentage: {len(target_crispr_test_samples)/len(overlapped_dti_crispr_rna_samples_df_annotated)*100:.1f}%\")\n",
    "\n",
    "# Both target tasks receive identical copies of the split files.\n",
    "# exist_ok=True replaces the separate os.path.exists check.\n",
    "target_output_dirs = ['./data/TargetQA', './data/TargetScreen']\n",
    "for out_dir in target_output_dirs:\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "\n",
    "# Sort the samples before saving\n",
    "train_samples_sorted = sorted(target_crispr_train_samples)\n",
    "test_samples_sorted = sorted(target_crispr_test_samples)\n",
    "\n",
    "# Save the sorted train and test sample IDs to CSV files\n",
    "for out_dir in target_output_dirs:\n",
    "    pd.DataFrame(train_samples_sorted, columns=['depMapID']).to_csv(f'{out_dir}/train_samples.csv', index=False)\n",
    "    pd.DataFrame(test_samples_sorted, columns=['depMapID']).to_csv(f'{out_dir}/test_samples.csv', index=False)\n",
    "\n",
    "# Create detailed split information DataFrames (annotation rows of each part)\n",
    "train_target_samples_df = overlapped_dti_crispr_rna_samples_df_annotated[\n",
    "    overlapped_dti_crispr_rna_samples_df_annotated['depMapID'].isin(target_crispr_train_samples)\n",
    "].reset_index(drop=True)\n",
    "\n",
    "test_target_samples_df = overlapped_dti_crispr_rna_samples_df_annotated[\n",
    "    overlapped_dti_crispr_rna_samples_df_annotated['depMapID'].isin(target_crispr_test_samples)\n",
    "].reset_index(drop=True)\n",
    "\n",
    "# Display the TCGA distributions in train and test sets\n",
    "print(\"\\n📋 Train set TCGA distribution:\")\n",
    "print(train_target_samples_df['tcga_code'].value_counts().sort_index())\n",
    "print(\"\\n📋 Test set TCGA distribution:\")\n",
    "print(test_target_samples_df['tcga_code'].value_counts().sort_index())\n",
    "\n",
    "# Save detailed sample information\n",
    "for out_dir in target_output_dirs:\n",
    "    train_target_samples_df.to_csv(f'{out_dir}/train_samples_detailed.csv', index=False)\n",
    "    test_target_samples_df.to_csv(f'{out_dir}/test_samples_detailed.csv', index=False)\n",
    "\n",
    "print(f\"\\n✅ Saved stratified train and test splits to:\")\n",
    "for out_dir in target_output_dirs:\n",
    "    print(f\"- {out_dir}/train_samples.csv ({len(train_samples_sorted)} samples)\")\n",
    "    print(f\"- {out_dir}/test_samples.csv ({len(test_samples_sorted)} samples)\")\n",
    "print(f\"- Detailed sample information with TCGA annotations also saved\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 9.1.3 Drug Samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# DTI samples left over once the DTI/CRISPR/RNA overlap is taken out\n",
    "remaining_dti_samples = sorted(\n",
    "    set(dti_overlapped_samples).difference(overlapped_dti_crispr_rna_samples)\n",
    ")\n",
    "print(\"The remaining samples in dti_overlapped_samples after removing overlapped_dti_crispr_rna_samples are:\")\n",
    "print(len(remaining_dti_samples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "# DTI samples that are not part of the DTI/CRISPR/RNA overlap\n",
    "remaining_dti_samples = sorted(\n",
    "    set(dti_overlapped_samples).difference(overlapped_dti_crispr_rna_samples)\n",
    ")\n",
    "print(\"The remaining samples in dti_overlapped_samples after removing overlapped_dti_crispr_rna_samples are:\")\n",
    "print(len(remaining_dti_samples))\n",
    "\n",
    "# Annotation rows for the remaining samples, taken from dti_overlapped_samples_df\n",
    "is_remaining = dti_overlapped_samples_df['depMapID'].isin(remaining_dti_samples)\n",
    "remaining_dti_samples_df = dti_overlapped_samples_df.loc[is_remaining].reset_index(drop=True)\n",
    "\n",
    "# Samples without a TCGA code are labelled 'Unknown'\n",
    "remaining_dti_samples_df['tcga_code'] = remaining_dti_samples_df['tcga_code'].fillna('Unknown')\n",
    "\n",
    "def stratified_tcga_split_dti(df, test_ratio=0.2, random_state=42):\n",
    "    \"\"\"\n",
    "    Split samples by TCGA code with special handling for small groups:\n",
    "    - For TCGA codes with >=5 samples: split according to test_ratio (0.8/0.2)\n",
    "    - For TCGA codes with 2-4 samples: keep 1 sample for test\n",
    "    - For TCGA codes with 1 sample: keep that sample for train\n",
    "    \"\"\"\n",
    "    np.random.seed(random_state)\n",
    "    \n",
    "    train_samples = []\n",
    "    test_samples = []\n",
    "    \n",
    "    # Group by TCGA code\n",
    "    tcga_groups = df.groupby('tcga_code')\n",
    "    \n",
    "    print(\"TCGA Code Distribution and Split Strategy for DTI Drug Samples:\")\n",
    "    print(\"-\" * 70)\n",
    "    \n",
    "    for tcga_code, group in tcga_groups:\n",
    "        n_samples = len(group)\n",
    "        samples = group['depMapID'].tolist()\n",
    "        \n",
    "        # Shuffle samples within each TCGA group\n",
    "        np.random.shuffle(samples)\n",
    "        \n",
    "        if n_samples == 1:\n",
    "            # Single sample: keep for train\n",
    "            train_samples.extend(samples)\n",
    "            test_samples_for_group = []\n",
    "            strategy = \"1 sample -> train only\"\n",
    "            \n",
    "        elif n_samples < 5:\n",
    "            # 2-4 samples: keep 1 for test, rest for train\n",
    "            test_samples_for_group = samples[:1]\n",
    "            train_samples_for_group = samples[1:]\n",
    "            train_samples.extend(train_samples_for_group)\n",
    "            test_samples.extend(test_samples_for_group)\n",
    "            strategy = f\"{n_samples} samples -> 1 test, {n_samples-1} train\"\n",
    "            \n",
    "        else:\n",
    "            # 5+ samples: use the specified ratio\n",
    "            n_test = max(1, int(n_samples * test_ratio))\n",
    "            test_samples_for_group = samples[:n_test]\n",
    "            train_samples_for_group = samples[n_test:]\n",
    "            train_samples.extend(train_samples_for_group)\n",
    "            test_samples.extend(test_samples_for_group)\n",
    "            strategy = f\"{n_samples} samples -> {n_test} test, {len(train_samples_for_group)} train\"\n",
    "        \n",
    "        print(f\"TCGA {tcga_code}: {strategy}\")\n",
    "    \n",
    "    print(\"-\" * 70)\n",
    "    print(f\"Total train samples: {len(train_samples)}\")\n",
    "    print(f\"Total test samples: {len(test_samples)}\")\n",
    "    print(f\"Train ratio: {len(train_samples)/(len(train_samples)+len(test_samples)):.3f}\")\n",
    "    print(f\"Test ratio: {len(test_samples)/(len(train_samples)+len(test_samples)):.3f}\")\n",
    "    \n",
    "    return train_samples, test_samples\n",
    "\n",
    "# Perform stratified split based on TCGA codes for remaining DTI samples\n",
    "dti_train_samples, dti_test_samples = stratified_tcga_split_dti(\n",
    "    remaining_dti_samples_df,\n",
    "    test_ratio=0.2,\n",
    "    random_state=42\n",
    ")\n",
    "\n",
    "# Count the samples\n",
    "print(f\"\\n📊 DTI Drug Sample Count Summary:\")\n",
    "print(f\"Total remaining DTI samples: {len(remaining_dti_samples_df)}\")\n",
    "print(f\"Training samples: {len(dti_train_samples)}\")\n",
    "print(f\"Test samples: {len(dti_test_samples)}\")\n",
    "print(f\"Training percentage: {len(dti_train_samples)/len(remaining_dti_samples_df)*100:.1f}%\")\n",
    "print(f\"Test percentage: {len(dti_test_samples)/len(remaining_dti_samples_df)*100:.1f}%\")\n",
    "\n",
    "# Both drug tasks receive identical copies of the split files.\n",
    "# exist_ok=True replaces the separate os.path.exists check.\n",
    "drug_output_dirs = ['./data/DrugQA', './data/DrugScreen']\n",
    "for out_dir in drug_output_dirs:\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "\n",
    "# Sort the samples before saving\n",
    "dti_train_samples_sorted = sorted(dti_train_samples)\n",
    "dti_test_samples_sorted = sorted(dti_test_samples)\n",
    "\n",
    "# Save the full annotated table of remaining samples\n",
    "for out_dir in drug_output_dirs:\n",
    "    remaining_dti_samples_df.to_csv(f'{out_dir}/remaining_dti_samples.csv', index=False)\n",
    "# Save the sorted train and test sample IDs to CSV files\n",
    "for out_dir in drug_output_dirs:\n",
    "    pd.DataFrame(dti_train_samples_sorted, columns=['depMapID']).to_csv(f'{out_dir}/train_samples.csv', index=False)\n",
    "    pd.DataFrame(dti_test_samples_sorted, columns=['depMapID']).to_csv(f'{out_dir}/test_samples.csv', index=False)\n",
    "\n",
    "# Create detailed split information DataFrames (annotation rows of each part)\n",
    "train_dti_samples_df = remaining_dti_samples_df[\n",
    "    remaining_dti_samples_df['depMapID'].isin(dti_train_samples)\n",
    "].reset_index(drop=True)\n",
    "\n",
    "test_dti_samples_df = remaining_dti_samples_df[\n",
    "    remaining_dti_samples_df['depMapID'].isin(dti_test_samples)\n",
    "].reset_index(drop=True)\n",
    "\n",
    "# Display the TCGA distributions in train and test sets\n",
    "print(\"\\n📋 DTI Drug Train set TCGA distribution:\")\n",
    "print(train_dti_samples_df['tcga_code'].value_counts().sort_index())\n",
    "print(\"\\n📋 DTI Drug Test set TCGA distribution:\")\n",
    "print(test_dti_samples_df['tcga_code'].value_counts().sort_index())\n",
    "\n",
    "# Save detailed sample information\n",
    "for out_dir in drug_output_dirs:\n",
    "    train_dti_samples_df.to_csv(f'{out_dir}/train_samples_detailed.csv', index=False)\n",
    "    test_dti_samples_df.to_csv(f'{out_dir}/test_samples_detailed.csv', index=False)\n",
    "\n",
    "print(f\"\\n✅ Saved stratified DTI drug train and test splits to:\")\n",
    "for out_dir in drug_output_dirs:\n",
    "    print(f\"- {out_dir}/train_samples.csv ({len(dti_train_samples_sorted)} samples)\")\n",
    "    print(f\"- {out_dir}/test_samples.csv ({len(dti_test_samples_sorted)} samples)\")\n",
    "print(f\"- Detailed sample information with TCGA annotations also saved\")\n",
    "\n",
    "# Summary of all splits\n",
    "print(f\"\\n📈 Complete Sample Split Summary:\")\n",
    "print(f\"Original DTI overlapped samples: {len(dti_overlapped_samples)}\")\n",
    "print(f\"Target identification samples (CRISPR/RNA): {len(overlapped_dti_crispr_rna_samples)}\")\n",
    "print(f\"Drug screening samples (remaining): {len(remaining_dti_samples)}\")\n",
    "print(f\"  - DTI Drug Train: {len(dti_train_samples)}\")\n",
    "print(f\"  - DTI Drug Test: {len(dti_test_samples)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 9.2 Data Integration"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 9.2.0 Data Preparation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rename the methylation columns through methy_map_dict\n",
    "maped_methy_df = final_merged_methy_df.rename(columns=methy_map_dict)\n",
    "display(maped_methy_df)\n",
    "# Rename the protein columns through protein_map_dict\n",
    "mapped_protein_df = final_merged_protein_df.rename(columns=protein_map_dict)\n",
    "display(mapped_protein_df)\n",
    "# Sanity check on everything except the first column (presumably the ID column - TODO confirm)\n",
    "protein_values = mapped_protein_df.iloc[:, 1:]\n",
    "print(\"Sum of all values in mapped_protein_df:\", protein_values.sum().sum())\n",
    "# check if all values in mapped_protein_df are zero\n",
    "print(\"Are all values in mapped_protein_df zero?\", protein_values.eq(0).all().all())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 9.2.0 All tasks share the same edge_index and nodes_index system"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Task folders that each receive an identical copy of the node / edge index files\n",
    "graph_output_dirs = [\n",
    "    './data/pretrain_plain_data',\n",
    "    './data/pretrain_status_data',\n",
    "    './data/TargetQA',\n",
    "    './data/TargetScreen',\n",
    "    './data/DrugQA',\n",
    "    './data/DrugScreen',\n",
    "]\n",
    "endpoint_cols = ['BMGC_From_ID', 'BMGC_To_ID']\n",
    "\n",
    "# Node table: one row per BioMedGraphica_Conn_ID with a 0-based Index and its Type\n",
    "nodes = bmgc_entity_df['BioMedGraphica_Conn_ID'].tolist()\n",
    "nodes_index = np.arange(len(nodes))\n",
    "nodes_index_data = (\n",
    "    pd.DataFrame({'Node': nodes, 'Index': nodes_index})\n",
    "    .merge(bmgc_entity_df[['BioMedGraphica_Conn_ID', 'Type']], left_on='Node', right_on='BioMedGraphica_Conn_ID', how='left')\n",
    "    .drop(columns=['BioMedGraphica_Conn_ID'])\n",
    ")\n",
    "display(nodes_index_data)\n",
    "for out_dir in graph_output_dirs:\n",
    "    nodes_index_data.to_csv(f'{out_dir}/nodes_index.csv', index=False)\n",
    "\n",
    "# Lookup from node ID to integer index, used to translate every edge list below\n",
    "node_index_dict = dict(zip(nodes_index_data['Node'], nodes_index_data['Index']))\n",
    "print(\"\\nFirst 10 items in node_index_dict:\")\n",
    "print(list(node_index_dict.items())[:10])\n",
    "\n",
    "\n",
    "def _relations_to_node_indices(relation_rows):\n",
    "    \"\"\"Copy the From/To ID columns of `relation_rows` and map both through node_index_dict.\"\"\"\n",
    "    indexed = relation_rows[endpoint_cols].copy()\n",
    "    for col in endpoint_cols:\n",
    "        indexed[col] = indexed[col].map(node_index_dict)\n",
    "    return indexed\n",
    "\n",
    "\n",
    "# Full graph: every relation in bmgc_relation_df\n",
    "edge_index_df = _relations_to_node_indices(bmgc_relation_df)\n",
    "# IDs missing from node_index_dict show up as nulls here\n",
    "print(\"Null values in edge_index_df:\")\n",
    "print(edge_index_df.isnull().sum())\n",
    "display(edge_index_df)\n",
    "# Transpose so the saved array has shape (2, number of edges)\n",
    "edge_index_array = edge_index_df.to_numpy().T\n",
    "print('The shape of edge_index_array is:', edge_index_array.shape)\n",
    "print(edge_index_array)\n",
    "for out_dir in graph_output_dirs:\n",
    "    np.save(f'{out_dir}/edge_index.npy', edge_index_array)\n",
    "\n",
    "# Internal edges: relations of Type Promoter-Gene, Gene-Transcript or Transcript-Protein\n",
    "internal_edge_types = ['Promoter-Gene', 'Gene-Transcript', 'Transcript-Protein']\n",
    "is_internal = bmgc_relation_df['Type'].isin(internal_edge_types)\n",
    "internal_edge_index_df = _relations_to_node_indices(bmgc_relation_df[is_internal])\n",
    "print(\"Null values in internal_edge_index_df:\")\n",
    "print(internal_edge_index_df.isnull().sum())\n",
    "internal_edge_index_array = internal_edge_index_df.to_numpy().T\n",
    "print('The shape of internal_edge_index_array is:', internal_edge_index_array.shape)\n",
    "print(internal_edge_index_array)\n",
    "for out_dir in graph_output_dirs:\n",
    "    np.save(f'{out_dir}/internal_edge_index.npy', internal_edge_index_array)\n",
    "\n",
    "# PPI edges: relations of Type Protein-Protein\n",
    "is_ppi = bmgc_relation_df['Type'].isin(['Protein-Protein'])\n",
    "ppi_edge_index_df = _relations_to_node_indices(bmgc_relation_df[is_ppi])\n",
    "print(\"Null values in ppi_edge_index_df:\")\n",
    "print(ppi_edge_index_df.isnull().sum())\n",
    "ppi_edge_index_array = ppi_edge_index_df.to_numpy().T\n",
    "print('The shape of ppi_edge_index_array is:', ppi_edge_index_array.shape)\n",
    "print(ppi_edge_index_array)\n",
    "for out_dir in graph_output_dirs:\n",
    "    np.save(f'{out_dir}/ppi_edge_index.npy', ppi_edge_index_array)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 9.2.0 All tasks share the same textual description"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Entity types, in the order their name tables are loaded and reported\n",
    "name_entity_types = [\n",
    "    'Promoter', 'Gene', 'Transcript', 'Protein', 'Pathway', 'Metabolite',\n",
    "    'Microbiota', 'Exposure', 'Phenotype', 'Disease', 'Drug',\n",
    "]\n",
    "\n",
    "# Load one LLM name/ID table per entity type, dropping the BioMedGraphica_ID column\n",
    "name_id_frames = {}\n",
    "for entity in name_entity_types:\n",
    "    name_id_path = f'./data/BioMedGraphica-Conn/Entity/{entity}/BioMedGraphica_Conn_{entity}_LLM_Name_ID_Combined.csv'\n",
    "    name_id_frames[entity] = pd.read_csv(name_id_path).drop(columns=['BioMedGraphica_ID'])\n",
    "\n",
    "# Bind every table to its own variable name\n",
    "bmgc_promoter_name_id_df = name_id_frames['Promoter']\n",
    "bmgc_gene_name_id_df = name_id_frames['Gene']\n",
    "bmgc_transcript_name_id_df = name_id_frames['Transcript']\n",
    "bmgc_protein_name_id_df = name_id_frames['Protein']\n",
    "bmgc_pathway_name_id_df = name_id_frames['Pathway']\n",
    "bmgc_metabolite_name_id_df = name_id_frames['Metabolite']\n",
    "bmgc_microbiota_name_id_df = name_id_frames['Microbiota']\n",
    "bmgc_exposure_name_id_df = name_id_frames['Exposure']\n",
    "bmgc_phenotype_name_id_df = name_id_frames['Phenotype']\n",
    "bmgc_disease_name_id_df = name_id_frames['Disease']\n",
    "bmgc_drug_name_id_df = name_id_frames['Drug']\n",
    "\n",
    "# Row count of each table, followed by the total over all tables\n",
    "for entity, frame in name_id_frames.items():\n",
    "    print(f\"bmgc_{entity.lower()}_name_id_df:\", len(frame))\n",
    "print(\"Total number of rows in all dataframes:\", sum(len(frame) for frame in name_id_frames.values()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from the full entity ID list; the left join below keeps every ID, named or not\n",
    "bmgc_entity_cp_df = bmgc_entity_df[['BioMedGraphica_Conn_ID']].copy()\n",
    "display(bmgc_entity_cp_df)\n",
    "# Stack the per-entity name/ID tables into a single lookup table\n",
    "name_id_tables = [\n",
    "    bmgc_promoter_name_id_df,\n",
    "    bmgc_gene_name_id_df,\n",
    "    bmgc_transcript_name_id_df,\n",
    "    bmgc_protein_name_id_df,\n",
    "    bmgc_pathway_name_id_df,\n",
    "    bmgc_metabolite_name_id_df,\n",
    "    bmgc_microbiota_name_id_df,\n",
    "    bmgc_exposure_name_id_df,\n",
    "    bmgc_phenotype_name_id_df,\n",
    "    bmgc_disease_name_id_df,\n",
    "    bmgc_drug_name_id_df,\n",
    "]\n",
    "bmgc_name_id_tmp_df = pd.concat(name_id_tables, ignore_index=True)\n",
    "display(bmgc_name_id_tmp_df)\n",
    "# Attach the names to the entity list on BioMedGraphica_Conn_ID\n",
    "bmgc_entity_cp_df = bmgc_entity_cp_df.merge(bmgc_name_id_tmp_df, on='BioMedGraphica_Conn_ID', how='left')\n",
    "display(bmgc_entity_cp_df)\n",
    "# Every task folder gets the same name table\n",
    "name_output_dirs = (\n",
    "    './data/pretrain_plain_data',\n",
    "    './data/pretrain_status_data',\n",
    "    './data/TargetScreen',\n",
    "    './data/TargetQA',\n",
    "    './data/DrugScreen',\n",
    "    './data/DrugQA',\n",
    ")\n",
    "for out_dir in name_output_dirs:\n",
    "    bmgc_entity_cp_df.to_csv(f'{out_dir}/bmgc_name.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# No description CSV is read for promoters and microbiota: keep their name/ID table without\n",
    "# 'Names_and_IDs' and add an all-NaN Description column instead\n",
    "bmgc_promoter_desc_df = bmgc_promoter_name_id_df.drop(columns=['Names_and_IDs']).assign(Description=np.nan)\n",
    "bmgc_microbiota_desc_df = bmgc_microbiota_name_id_df.drop(columns=['Names_and_IDs']).assign(Description=np.nan)\n",
    "# Every other entity type loads its combined description CSV and discards the BioMedGraphica_ID column\n",
    "bmgc_gene_desc_df = (\n",
    "    pd.read_csv('./data/BioMedGraphica-Conn/Entity/Gene/BioMedGraphica_Conn_Gene_Description_Combined.csv')\n",
    "    .drop(columns=['BioMedGraphica_ID'])\n",
    ")\n",
    "bmgc_transcript_desc_df = (\n",
    "    pd.read_csv('./data/BioMedGraphica-Conn/Entity/Transcript/BioMedGraphica_Conn_Transcript_Description_Combined.csv')\n",
    "    .drop(columns=['BioMedGraphica_ID'])\n",
    ")\n",
    "bmgc_protein_desc_df = (\n",
    "    pd.read_csv('./data/BioMedGraphica-Conn/Entity/Protein/BioMedGraphica_Conn_Protein_Description_Combined.csv')\n",
    "    .drop(columns=['BioMedGraphica_ID'])\n",
    ")\n",
    "bmgc_pathway_desc_df = (\n",
    "    pd.read_csv('./data/BioMedGraphica-Conn/Entity/Pathway/BioMedGraphica_Conn_Pathway_Description_Combined.csv')\n",
    "    .drop(columns=['BioMedGraphica_ID'])\n",
    ")\n",
    "bmgc_metabolite_desc_df = (\n",
    "    pd.read_csv('./data/BioMedGraphica-Conn/Entity/Metabolite/BioMedGraphica_Conn_Metabolite_Description_Combined.csv')\n",
    "    .drop(columns=['BioMedGraphica_ID'])\n",
    ")\n",
    "bmgc_exposure_desc_df = (\n",
    "    pd.read_csv('./data/BioMedGraphica-Conn/Entity/Exposure/BioMedGraphica_Conn_Exposure_Description_Combined.csv')\n",
    "    .drop(columns=['BioMedGraphica_ID'])\n",
    ")\n",
    "bmgc_phenotype_desc_df = (\n",
    "    pd.read_csv('./data/BioMedGraphica-Conn/Entity/Phenotype/BioMedGraphica_Conn_Phenotype_Description_Combined.csv')\n",
    "    .drop(columns=['BioMedGraphica_ID'])\n",
    ")\n",
    "bmgc_disease_desc_df = (\n",
    "    pd.read_csv('./data/BioMedGraphica-Conn/Entity/Disease/BioMedGraphica_Conn_Disease_Description_Combined.csv')\n",
    "    .drop(columns=['BioMedGraphica_ID'])\n",
    ")\n",
    "bmgc_drug_desc_df = (\n",
    "    pd.read_csv('./data/BioMedGraphica-Conn/Entity/Drug/BioMedGraphica_Conn_Drug_Description_Combined.csv')\n",
    "    .drop(columns=['BioMedGraphica_ID'])\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the per-entity description tables into one lookup table\n",
    "bmgc_desc_tmp_df = pd.concat(\n",
    "    [\n",
    "        bmgc_promoter_desc_df,\n",
    "        bmgc_gene_desc_df,\n",
    "        bmgc_transcript_desc_df,\n",
    "        bmgc_protein_desc_df,\n",
    "        bmgc_pathway_desc_df,\n",
    "        bmgc_metabolite_desc_df,\n",
    "        bmgc_microbiota_desc_df,\n",
    "        bmgc_exposure_desc_df,\n",
    "        bmgc_phenotype_desc_df,\n",
    "        bmgc_disease_desc_df,\n",
    "        bmgc_drug_desc_df,\n",
    "    ],\n",
    "    ignore_index=True,\n",
    ")\n",
    "display(bmgc_desc_tmp_df)\n",
    "# Rebuild the ID-only entity frame, then left join the descriptions onto it by BioMedGraphica_Conn_ID\n",
    "bmgc_entity_cp_df = bmgc_entity_df[['BioMedGraphica_Conn_ID']].copy()\n",
    "bmgc_desc_df = bmgc_entity_cp_df.merge(bmgc_desc_tmp_df, on='BioMedGraphica_Conn_ID', how='left')\n",
    "# Null count per column before filling\n",
    "print(\"Null values in bmgc_desc_df:\")\n",
    "print(bmgc_desc_df.isna().sum())\n",
    "# Missing descriptions are replaced by a single space (' '), not by an empty string\n",
    "bmgc_desc_df['Description'] = bmgc_desc_df['Description'].fillna(' ')\n",
    "# Null count per column after filling\n",
    "print(\"Null values in bmgc_desc_df:\")\n",
    "print(bmgc_desc_df.isna().sum())\n",
    "display(bmgc_desc_df)\n",
    "# Write the same description table into each task folder\n",
    "for bmgc_desc_csv_path in (\n",
    "    './data/pretrain_plain_data/bmgc_desc.csv',\n",
    "    './data/pretrain_status_data/bmgc_desc.csv',\n",
    "    './data/TargetScreen/bmgc_desc.csv',\n",
    "    './data/TargetQA/bmgc_desc.csv',\n",
    "    './data/DrugScreen/bmgc_desc.csv',\n",
    "    './data/DrugQA/bmgc_desc.csv',\n",
    "):\n",
    "    bmgc_desc_df.to_csv(bmgc_desc_csv_path, index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 9.2.1 Pretrain data integration (plain + status)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"The overlapped samples dataframe for methylation between pretrain_plain_samples and pretrain_status_samples are:\")\n",
    "# Give maped_methy_df an all-zero column for each plain-pretraining sample it does not have yet\n",
    "for sample in pretrain_plain_samples:\n",
    "    if sample in maped_methy_df.columns:\n",
    "        continue\n",
    "    maped_methy_df[sample] = 0.0\n",
    "# Entity ID column followed by the plain-pretraining samples, in list order\n",
    "pretrain_plain_methy_df = maped_methy_df[['BioMedGraphica_Conn_ID', *pretrain_plain_samples]]\n",
    "display(pretrain_plain_methy_df)\n",
    "# Same zero-filling for the status-pretraining samples\n",
    "for sample in pretrain_status_samples:\n",
    "    if sample in maped_methy_df.columns:\n",
    "        continue\n",
    "    maped_methy_df[sample] = 0.0\n",
    "# Entity ID column followed by the status-pretraining samples, in list order\n",
    "pretrain_status_methy_df = maped_methy_df[['BioMedGraphica_Conn_ID', *pretrain_status_samples]]\n",
    "display(pretrain_status_methy_df)\n",
    "\n",
    "print(\"The overlapped samples dataframe for protein between pretrain_plain_samples and pretrain_status_samples are:\")\n",
    "# Give mapped_protein_df an all-zero column for each plain-pretraining sample it does not have yet\n",
    "for sample in pretrain_plain_samples:\n",
    "    if sample in mapped_protein_df.columns:\n",
    "        continue\n",
    "    mapped_protein_df[sample] = 0.0\n",
    "# Entity ID column followed by the plain-pretraining samples, in list order\n",
    "pretrain_plain_protein_df = mapped_protein_df[['BioMedGraphica_Conn_ID', *pretrain_plain_samples]]\n",
    "display(pretrain_plain_protein_df)\n",
    "# Same zero-filling for the status-pretraining samples\n",
    "for sample in pretrain_status_samples:\n",
    "    if sample in mapped_protein_df.columns:\n",
    "        continue\n",
    "    mapped_protein_df[sample] = 0.0\n",
    "# Entity ID column followed by the status-pretraining samples, in list order\n",
    "pretrain_status_protein_df = mapped_protein_df[['BioMedGraphica_Conn_ID', *pretrain_status_samples]]\n",
    "display(pretrain_status_protein_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add an all-zero column to the gene frame, then to the transcript frame, for every\n",
    "# plain-pretraining sample that the frame does not contain yet\n",
    "for omics_df in (final_merged_gene_df, final_merged_transcript_df):\n",
    "    for sample in pretrain_plain_samples:\n",
    "        if sample not in omics_df.columns:\n",
    "            omics_df[sample] = 0.0\n",
    "# Entity ID column followed by the plain-pretraining samples, in list order\n",
    "pretrain_plain_gene_df = final_merged_gene_df[['BioMedGraphica_Conn_ID', *pretrain_plain_samples]]\n",
    "display(pretrain_plain_gene_df)\n",
    "pretrain_plain_transcript_df = final_merged_transcript_df[['BioMedGraphica_Conn_ID', *pretrain_plain_samples]]\n",
    "display(pretrain_plain_transcript_df)\n",
    "\n",
    "# Same zero-filling for the status-pretraining samples\n",
    "for omics_df in (final_merged_gene_df, final_merged_transcript_df):\n",
    "    for sample in pretrain_status_samples:\n",
    "        if sample not in omics_df.columns:\n",
    "            omics_df[sample] = 0.0\n",
    "# Entity ID column followed by the status-pretraining samples, in list order\n",
    "pretrain_status_gene_df = final_merged_gene_df[['BioMedGraphica_Conn_ID', *pretrain_status_samples]]\n",
    "display(pretrain_status_gene_df)\n",
    "pretrain_status_transcript_df = final_merged_transcript_df[['BioMedGraphica_Conn_ID', *pretrain_status_samples]]\n",
    "display(pretrain_status_transcript_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plain pretraining: stack the methylation, gene, transcript and protein frames row-wise\n",
    "pretrain_plain_omics_df = pd.concat(\n",
    "    [pretrain_plain_methy_df, pretrain_plain_gene_df, pretrain_plain_transcript_df, pretrain_plain_protein_df],\n",
    "    axis=0,\n",
    "    ignore_index=True,\n",
    ")\n",
    "display(pretrain_plain_omics_df)\n",
    "# Left join onto the entity table, drop the entity metadata columns and zero-fill entities without omics values\n",
    "pretrain_plain_feat_df = (\n",
    "    bmgc_entity_df.merge(pretrain_plain_omics_df, left_on='BioMedGraphica_Conn_ID', right_on='BioMedGraphica_Conn_ID', how='left')\n",
    "    .drop(columns=['BioMedGraphica_ID', 'Type'])\n",
    "    .fillna(0.0)\n",
    ")\n",
    "display(pretrain_plain_feat_df)\n",
    "# Status pretraining: same two steps\n",
    "pretrain_status_omics_df = pd.concat(\n",
    "    [pretrain_status_methy_df, pretrain_status_gene_df, pretrain_status_transcript_df, pretrain_status_protein_df],\n",
    "    axis=0,\n",
    "    ignore_index=True,\n",
    ")\n",
    "display(pretrain_status_omics_df)\n",
    "pretrain_status_feat_df = (\n",
    "    bmgc_entity_df.merge(pretrain_status_omics_df, left_on='BioMedGraphica_Conn_ID', right_on='BioMedGraphica_Conn_ID', how='left')\n",
    "    .drop(columns=['BioMedGraphica_ID', 'Type'])\n",
    "    .fillna(0.0)\n",
    ")\n",
    "display(pretrain_status_feat_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop the ID column and transpose, so each array row holds one sample column of the feature frame\n",
    "pretrain_plain_array = pretrain_plain_feat_df.drop(columns=['BioMedGraphica_Conn_ID']).to_numpy().T\n",
    "pretrain_status_array = pretrain_status_feat_df.drop(columns=['BioMedGraphica_Conn_ID']).to_numpy().T\n",
    "# Persist both arrays as .npy files\n",
    "np.save('./data/pretrain_plain_data/pretrain_plain_feature.npy', pretrain_plain_array)\n",
    "np.save('./data/pretrain_status_data/pretrain_status_feature.npy', pretrain_status_array)\n",
    "print(\"Shape of pretrain_plain_array:\", pretrain_plain_array.shape)\n",
    "print(\"Shape of pretrain_status_array:\", pretrain_status_array.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### 9.2.1.1 Pretrain Status"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Binary label per row of pretrain_status_df: 1 where cancerous_status == 'cancerous', otherwise 0\n",
    "pretrain_status_label = np.array([int(status == 'cancerous') for status in pretrain_status_df['cancerous_status']])\n",
    "# Persist the full label vector\n",
    "np.save('./data/pretrain_status_data/pretrain_status_label.npy', pretrain_status_label)\n",
    "print(\"Shape of pretrain_status_label:\", pretrain_status_label.shape)\n",
    "\n",
    "# Labels of the train / test split, built from the split dataframes with the same rule\n",
    "pretrain_status_train_label = np.array([int(status == 'cancerous') for status in pretrain_status_train_df['cancerous_status']])\n",
    "pretrain_status_test_label = np.array([int(status == 'cancerous') for status in pretrain_status_test_df['cancerous_status']])\n",
    "# Features of the train / test split: the split's sample columns, transposed to one row per sample\n",
    "pretrain_status_train_feat = pretrain_status_feat_df[pretrain_status_train_samples].to_numpy().T\n",
    "pretrain_status_test_feat = pretrain_status_feat_df[pretrain_status_test_samples].to_numpy().T\n",
    "\n",
    "# Persist split features and labels\n",
    "np.save('./data/pretrain_status_data/pretrain_status_train_feature.npy', pretrain_status_train_feat)\n",
    "np.save('./data/pretrain_status_data/pretrain_status_test_feature.npy', pretrain_status_test_feat)\n",
    "np.save('./data/pretrain_status_data/pretrain_status_train_label.npy', pretrain_status_train_label)\n",
    "np.save('./data/pretrain_status_data/pretrain_status_test_label.npy', pretrain_status_test_label)\n",
    "print(\"Shape of pretrain_status_train_feat:\", pretrain_status_train_feat.shape)\n",
    "print(\"Shape of pretrain_status_test_feat:\", pretrain_status_test_feat.shape)\n",
    "print(\"Shape of pretrain_status_train_label:\", pretrain_status_train_label.shape)\n",
    "print(\"Shape of pretrain_status_test_label:\", pretrain_status_test_label.shape)\n",
    "# Class balance of both splits\n",
    "print(f\"\\n📈 Pretrain Status Training Set Distribution (Numpy Files):\\n\"\n",
    "      f\"Cancerous: {np.sum(pretrain_status_train_label == 1)}, Non-cancerous: {np.sum(pretrain_status_train_label == 0)}\")\n",
    "print(f\"📈 Pretrain Status Test Set Distribution (Numpy Files):\\n\"\n",
    "      f\"Cancerous: {np.sum(pretrain_status_test_label == 1)}, Non-cancerous: {np.sum(pretrain_status_test_label == 0)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 9.2.2 Target data integration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"The overlapped samples dataframe for methylation between overlapped_dti_crispr_rna_samples:\")\n",
    "# Give maped_methy_df an all-zero column for each overlapped sample it does not have yet\n",
    "for sample in overlapped_dti_crispr_rna_samples:\n",
    "    if sample in maped_methy_df.columns:\n",
    "        continue\n",
    "    maped_methy_df[sample] = 0.0\n",
    "# Entity ID column followed by the overlapped samples, in list order\n",
    "target_crispr_methy_df = maped_methy_df[['BioMedGraphica_Conn_ID', *overlapped_dti_crispr_rna_samples]]\n",
    "display(target_crispr_methy_df)\n",
    "\n",
    "print(\"The overlapped samples dataframe for protein between overlapped_dti_crispr_rna_samples are:\")\n",
    "# Give mapped_protein_df an all-zero column for each overlapped sample it does not have yet\n",
    "for sample in overlapped_dti_crispr_rna_samples:\n",
    "    if sample in mapped_protein_df.columns:\n",
    "        continue\n",
    "    mapped_protein_df[sample] = 0.0\n",
    "# Entity ID column followed by the overlapped samples, in list order\n",
    "target_crispr_protein_df = mapped_protein_df[['BioMedGraphica_Conn_ID', *overlapped_dti_crispr_rna_samples]]\n",
    "display(target_crispr_protein_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add an all-zero column to the gene frame, then to the transcript frame, for every\n",
    "# overlapped sample that the frame does not contain yet\n",
    "for omics_df in (final_merged_gene_df, final_merged_transcript_df):\n",
    "    for sample in overlapped_dti_crispr_rna_samples:\n",
    "        if sample not in omics_df.columns:\n",
    "            omics_df[sample] = 0.0\n",
    "# Entity ID column followed by the overlapped samples, in list order\n",
    "target_crispr_gene_df = final_merged_gene_df[['BioMedGraphica_Conn_ID', *overlapped_dti_crispr_rna_samples]]\n",
    "display(target_crispr_gene_df)\n",
    "target_crispr_transcript_df = final_merged_transcript_df[['BioMedGraphica_Conn_ID', *overlapped_dti_crispr_rna_samples]]\n",
    "display(target_crispr_transcript_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the methylation, gene, transcript and protein frames row-wise\n",
    "target_crispr_omics_df = pd.concat(\n",
    "    [target_crispr_methy_df, target_crispr_gene_df, target_crispr_transcript_df, target_crispr_protein_df],\n",
    "    axis=0,\n",
    "    ignore_index=True,\n",
    ")\n",
    "display(target_crispr_omics_df)\n",
    "# Left join onto the entity table, drop the entity metadata columns and zero-fill entities without omics values\n",
    "target_crispr_df = (\n",
    "    bmgc_entity_df.merge(target_crispr_omics_df, left_on='BioMedGraphica_Conn_ID', right_on='BioMedGraphica_Conn_ID', how='left')\n",
    "    .drop(columns=['BioMedGraphica_ID', 'Type'])\n",
    "    .fillna(0.0)\n",
    ")\n",
    "display(target_crispr_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### 9.2.2.1 Target Screen"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split features: the sample columns of each split, transposed to one row per sample\n",
    "target_crispr_train_feat = target_crispr_df[target_crispr_train_samples].to_numpy().T\n",
    "target_crispr_test_feat = target_crispr_df[target_crispr_test_samples].to_numpy().T\n",
    "# Persist both arrays as .npy files\n",
    "np.save('./data/TargetScreen/target_crispr_train_feature.npy', target_crispr_train_feat)\n",
    "np.save('./data/TargetScreen/target_crispr_test_feature.npy', target_crispr_test_feat)\n",
    "# Report the resulting shapes\n",
    "print(\"Shape of target_crispr_train_feat:\", target_crispr_train_feat.shape)\n",
    "print(\"Shape of target_crispr_test_feat:\", target_crispr_test_feat.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# All sample columns of target_crispr_df (ID column removed), transposed to one row per sample\n",
    "target_crispr_feat = target_crispr_df.drop(columns=['BioMedGraphica_Conn_ID']).to_numpy().T\n",
    "# Persist the array in the TargetScreen folder\n",
    "np.save('./data/TargetScreen/target_crispr_feature.npy', target_crispr_feat)\n",
    "# Report the resulting shape\n",
    "print(\"Shape of target_crispr_feat:\", target_crispr_feat.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Internal relation table for KO: one edge table per relation type, with the generic\n",
    "# BMGC_From_ID / BMGC_To_ID columns renamed to the entity kinds they hold\n",
    "bmgc_promoter_gene_df = (\n",
    "    bmgc_omics_relation_df[bmgc_omics_relation_df['Type'] == 'Promoter-Gene']\n",
    "    .drop(columns=['Type'])\n",
    "    .rename(columns={'BMGC_From_ID':'promoterID','BMGC_To_ID':'geneID'})\n",
    ")\n",
    "bmgc_gene_transcript_df = (\n",
    "    bmgc_omics_relation_df[bmgc_omics_relation_df['Type'] == 'Gene-Transcript']\n",
    "    .drop(columns=['Type'])\n",
    "    .rename(columns={'BMGC_From_ID':'geneID','BMGC_To_ID':'transcriptID'})\n",
    ")\n",
    "bmgc_transcript_protein_df = (\n",
    "    bmgc_omics_relation_df[bmgc_omics_relation_df['Type'] == 'Transcript-Protein']\n",
    "    .drop(columns=['Type'])\n",
    "    .rename(columns={'BMGC_From_ID':'transcriptID','BMGC_To_ID':'proteinID'})\n",
    ")\n",
    "# Outer join transcript-protein with gene-transcript on the shared transcript ID\n",
    "bmgc_gene_transcript_protein_df = pd.merge(bmgc_transcript_protein_df, bmgc_gene_transcript_df, on='transcriptID', how='outer')\n",
    "display(bmgc_gene_transcript_protein_df)\n",
    "# Outer join the promoter-gene edges on the shared gene ID\n",
    "bmgc_promoter_gene_transcript_protein_df = pd.merge(bmgc_gene_transcript_protein_df, bmgc_promoter_gene_df, on='geneID', how='outer')\n",
    "display(bmgc_promoter_gene_transcript_protein_df)\n",
    "# Keep only rows that have a gene ID, restricted to the four ID columns\n",
    "internal_relation_df = (\n",
    "    bmgc_promoter_gene_transcript_protein_df\n",
    "    .dropna(subset=['geneID'])\n",
    "    [['promoterID', 'geneID', 'transcriptID', 'proteinID']]\n",
    "    .copy()\n",
    ")\n",
    "display(internal_relation_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_crispr_df = pd.read_csv('./data/raw_data/CRISPRGeneEffect.csv')\n",
    "# Rename the first column to 'Sample'; its original header stays available in first_column_name\n",
    "first_column_name = raw_crispr_df.columns[0]\n",
    "raw_crispr_df = raw_crispr_df.rename(columns={first_column_name: 'Sample'})\n",
    "display(raw_crispr_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-column frames holding the train / test depMapIDs under the column name 'Sample'\n",
    "train_target_samples_id_df = train_target_samples_df[['depMapID']].rename(columns={'depMapID': 'Sample'})\n",
    "test_target_samples_id_df = test_target_samples_df[['depMapID']].rename(columns={'depMapID': 'Sample'})\n",
    "# Inner joins keep only the raw_crispr_df rows whose Sample belongs to the respective split\n",
    "train_crispr_score_df = pd.merge(raw_crispr_df, train_target_samples_id_df, on=\"Sample\", how=\"inner\")\n",
    "display(train_crispr_score_df)\n",
    "test_crispr_score_df = pd.merge(raw_crispr_df, test_target_samples_id_df, on=\"Sample\", how=\"inner\")\n",
    "display(test_crispr_score_df)\n",
    "\n",
    "# Transpose so that every former column becomes a row and every sample a column;\n",
    "# the former column labels are stored in the 'HGNC_Symbol' column\n",
    "train_crispr_score_t_df = (\n",
    "    train_crispr_score_df.set_index(\"Sample\").T.reset_index().rename(columns={\"index\": \"HGNC_Symbol\"})\n",
    ")\n",
    "display(train_crispr_score_t_df)\n",
    "test_crispr_score_t_df = (\n",
    "    test_crispr_score_df.set_index(\"Sample\").T.reset_index().rename(columns={\"index\": \"HGNC_Symbol\"})\n",
    ")\n",
    "display(test_crispr_score_t_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "def extract_gene_name(gene):\n",
    "    \"\"\"Return ``str(gene)`` with every parenthesised group, and any whitespace directly before it, removed.\"\"\"\n",
    "    label = str(gene)\n",
    "    return re.sub(r\"\\s*\\(.*?\\)\", \"\", label)\n",
    "\n",
    "# bmgc_gene_df columns removed after the join with the CRISPR scores\n",
    "bmgc_gene_annotation_cols = ['BioMedGraphica_ID','NCBI_Gene_ID','Ensembl_Gene_ID_Version','Gene_Start', 'Gene_End','Chromosome','Gene_Type','Ensembl_Gene_ID','HGNC_ID','Gene_Name','RefSeq_ID','OMIM_ID','HGNC_Symbol']\n",
    "\n",
    "# ---- training split ----\n",
    "# Clean the gene labels with extract_gene_name before matching on HGNC_Symbol\n",
    "train_crispr_score_t_df[\"HGNC_Symbol\"] = train_crispr_score_t_df[\"HGNC_Symbol\"].apply(extract_gene_name)\n",
    "unique_Gene_map = bmgc_gene_df[['HGNC_Symbol']].drop_duplicates()\n",
    "unique_omics_Gene = train_crispr_score_t_df[['HGNC_Symbol']].drop_duplicates()\n",
    "# Match rate = shared symbols / distinct symbols in bmgc_gene_df\n",
    "intersection = set(unique_Gene_map['HGNC_Symbol']).intersection(unique_omics_Gene['HGNC_Symbol'])\n",
    "total_Gene = len(unique_Gene_map)\n",
    "if total_Gene > 0:\n",
    "    match_rate = len(intersection) / total_Gene\n",
    "else:\n",
    "    match_rate = 0\n",
    "print(f\"match rate:{match_rate:.2%}\")\n",
    "# Inner join on the symbol, then drop the annotation columns\n",
    "train_crispr_score_bmgc_df = (\n",
    "    bmgc_gene_df.merge(train_crispr_score_t_df, on=\"HGNC_Symbol\", how=\"inner\")\n",
    "    .drop(columns=bmgc_gene_annotation_cols)\n",
    ")\n",
    "display(train_crispr_score_bmgc_df)\n",
    "\n",
    "# ---- test split (same steps) ----\n",
    "test_crispr_score_t_df[\"HGNC_Symbol\"] = test_crispr_score_t_df[\"HGNC_Symbol\"].apply(extract_gene_name)\n",
    "unique_Gene_map = bmgc_gene_df[['HGNC_Symbol']].drop_duplicates()\n",
    "unique_omics_Gene = test_crispr_score_t_df[['HGNC_Symbol']].drop_duplicates()\n",
    "intersection = set(unique_Gene_map['HGNC_Symbol']).intersection(unique_omics_Gene['HGNC_Symbol'])\n",
    "total_Gene = len(unique_Gene_map)\n",
    "if total_Gene > 0:\n",
    "    match_rate = len(intersection) / total_Gene\n",
    "else:\n",
    "    match_rate = 0\n",
    "print(f\"match rate:{match_rate:.2%}\")\n",
    "test_crispr_score_bmgc_df = (\n",
    "    bmgc_gene_df.merge(test_crispr_score_t_df, on=\"HGNC_Symbol\", how=\"inner\")\n",
    "    .drop(columns=bmgc_gene_annotation_cols)\n",
    ")\n",
    "display(test_crispr_score_bmgc_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Long format: one row per (BioMedGraphica_Conn_ID, ACH_ID) pair; pairs without a score are removed\n",
    "train_crispr_score_label = train_crispr_score_bmgc_df.copy()\n",
    "train_crispr_score_label_melted = (\n",
    "    train_crispr_score_label\n",
    "    .melt(id_vars=[\"BioMedGraphica_Conn_ID\"], var_name=\"ACH_ID\", value_name=\"Value\")\n",
    "    .dropna(subset=['Value'])\n",
    ")\n",
    "display(train_crispr_score_label_melted)\n",
    "test_crispr_score_label = test_crispr_score_bmgc_df.copy()\n",
    "test_crispr_score_label_melted = (\n",
    "    test_crispr_score_label\n",
    "    .melt(id_vars=[\"BioMedGraphica_Conn_ID\"], var_name=\"ACH_ID\", value_name=\"Value\")\n",
    "    .dropna(subset=['Value'])\n",
    ")\n",
    "display(test_crispr_score_label_melted)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Merge into internal relation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lookup: depMapID -> index label of that sample in overlapped_dti_crispr_rna_samples_df\n",
    "target_crispr_sample_index_dict = dict(zip(overlapped_dti_crispr_rna_samples_df['depMapID'], overlapped_dti_crispr_rna_samples_df.index))\n",
    "# Replace the sample IDs (ACH_ID) and entity IDs (BioMedGraphica_Conn_ID) of both label tables by\n",
    "# their indices. NOTE: the columns are overwritten in place, so this cell is not idempotent. A second\n",
    "# run looks the already-mapped values up again and turns every value that is not a dictionary key\n",
    "# into NaN; re-run the melt cell before re-running this one.\n",
    "train_crispr_score_label_melted['ACH_ID'] = train_crispr_score_label_melted['ACH_ID'].map(target_crispr_sample_index_dict)\n",
    "train_crispr_score_label_melted['BioMedGraphica_Conn_ID'] = train_crispr_score_label_melted['BioMedGraphica_Conn_ID'].map(node_index_dict)\n",
    "test_crispr_score_label_melted['ACH_ID'] = test_crispr_score_label_melted['ACH_ID'].map(target_crispr_sample_index_dict)\n",
    "test_crispr_score_label_melted['BioMedGraphica_Conn_ID'] = test_crispr_score_label_melted['BioMedGraphica_Conn_ID'].map(node_index_dict)\n",
    "\n",
    "# Same ID -> index mapping for the four ID columns of the internal relation table\n",
    "internal_relation_map_df = internal_relation_df.copy()\n",
    "for id_column in ['promoterID', 'geneID', 'transcriptID', 'proteinID']:\n",
    "    internal_relation_map_df[id_column] = internal_relation_map_df[id_column].map(node_index_dict)\n",
    "\n",
    "ko_internal_relation_map_df = internal_relation_map_df.copy()\n",
    "display(ko_internal_relation_map_df)\n",
    "# One row per (promoter, gene) pair, with the distinct transcript / protein indices collected into lists.\n",
    "# groupby drops NaN keys by default, which silently discarded every gene that has no promoter index;\n",
    "# dropna=False keeps those genes (their promoterID stays NaN). Rows without a gene index were\n",
    "# dropped by the default as well and are still excluded, now explicitly.\n",
    "ko_internal_relation_index_df = (\n",
    "    ko_internal_relation_map_df\n",
    "    .dropna(subset=['geneID'])\n",
    "    .groupby([\"promoterID\", \"geneID\"], as_index=False, dropna=False)\n",
    "    .agg({\n",
    "        \"transcriptID\": lambda x: list(set(x.dropna())),\n",
    "        \"proteinID\": lambda x: list(set(x.dropna())),\n",
    "    })\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Helper that turns the float indices inside a list into ints\n",
    "def remove_decimal_suffix(list_with_decimals):\n",
    "    \"\"\"Return a copy of a list with every non-missing item passed through ``int``.\n",
    "\n",
    "    Items for which ``pd.isna`` is true are kept unchanged; a non-list argument is returned as is.\n",
    "    \"\"\"\n",
    "    if not isinstance(list_with_decimals, list):\n",
    "        return list_with_decimals\n",
    "    converted = []\n",
    "    for item in list_with_decimals:\n",
    "        converted.append(item if pd.isna(item) else int(item))\n",
    "    return converted\n",
    "\n",
    "# Convert the index lists of both list-valued columns to integer lists\n",
    "for list_column in ('transcriptID', 'proteinID'):\n",
    "    ko_internal_relation_index_df[list_column] = ko_internal_relation_index_df[list_column].apply(remove_decimal_suffix)\n",
    "\n",
    "# Preview the first rows to verify the conversion\n",
    "display(ko_internal_relation_index_df.head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach promoter / transcript / protein indices to every training label through its gene index;\n",
    "# the left join keeps labels whose gene has no row in ko_internal_relation_index_df\n",
    "train_crispr_score_label_melted_ko = pd.merge(\n",
    "    train_crispr_score_label_melted,\n",
    "    ko_internal_relation_index_df,\n",
    "    left_on='BioMedGraphica_Conn_ID',\n",
    "    right_on='geneID',\n",
    "    how='left',\n",
    ").drop(columns=['BioMedGraphica_Conn_ID'])\n",
    "display(train_crispr_score_label_melted_ko)\n",
    "\n",
    "# merged_ids = [promoterID, geneID] + transcript indices + protein indices;\n",
    "# a transcript / protein cell that is not a list contributes nothing\n",
    "train_crispr_score_label_melted_ko['merged_ids'] = [\n",
    "    [promoter_id, gene_id]\n",
    "    + (transcript_ids if isinstance(transcript_ids, list) else [])\n",
    "    + (protein_ids if isinstance(protein_ids, list) else [])\n",
    "    for promoter_id, gene_id, transcript_ids, protein_ids in zip(\n",
    "        train_crispr_score_label_melted_ko['promoterID'],\n",
    "        train_crispr_score_label_melted_ko['geneID'],\n",
    "        train_crispr_score_label_melted_ko['transcriptID'],\n",
    "        train_crispr_score_label_melted_ko['proteinID'],\n",
    "    )\n",
    "]\n",
    "\n",
    "# The four source columns are now redundant\n",
    "train_crispr_score_label_melted_ko = train_crispr_score_label_melted_ko.drop(columns=['promoterID', 'geneID', 'transcriptID', 'proteinID'])\n",
    "display(train_crispr_score_label_melted_ko)\n",
    "\n",
    "# Column order expected in the saved array: ACH_ID, merged_ids, Value\n",
    "train_crispr_score_label_melted_ko = train_crispr_score_label_melted_ko.loc[:, ['ACH_ID', 'merged_ids', 'Value']].copy()\n",
    "display(train_crispr_score_label_melted_ko)\n",
    "# Convert to a numpy array and show its shape plus the first 10 rows as a sample\n",
    "train_crispr_score_label_melted_ko_array = train_crispr_score_label_melted_ko.to_numpy()\n",
    "print('The shape of train_crispr_score_label_melted_ko_array is:', train_crispr_score_label_melted_ko_array.shape)\n",
    "print(train_crispr_score_label_melted_ko_array[:10])  # Print first 10 rows as a sample\n",
    "# Persist the array in the TargetScreen folder\n",
    "np.save('./data/TargetScreen/train_crispr_score_label_melted_ko.npy', train_crispr_score_label_melted_ko_array)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach promoter / transcript / protein indices to every test label through its gene index;\n",
    "# the left join keeps labels whose gene has no row in ko_internal_relation_index_df\n",
    "test_crispr_score_label_melted_ko = pd.merge(\n",
    "    test_crispr_score_label_melted,\n",
    "    ko_internal_relation_index_df,\n",
    "    left_on='BioMedGraphica_Conn_ID',\n",
    "    right_on='geneID',\n",
    "    how='left',\n",
    ").drop(columns=['BioMedGraphica_Conn_ID'])\n",
    "display(test_crispr_score_label_melted_ko)\n",
    "\n",
    "# merged_ids = [promoterID, geneID] + transcript indices + protein indices;\n",
    "# a transcript / protein cell that is not a list contributes nothing\n",
    "test_crispr_score_label_melted_ko['merged_ids'] = [\n",
    "    [promoter_id, gene_id]\n",
    "    + (transcript_ids if isinstance(transcript_ids, list) else [])\n",
    "    + (protein_ids if isinstance(protein_ids, list) else [])\n",
    "    for promoter_id, gene_id, transcript_ids, protein_ids in zip(\n",
    "        test_crispr_score_label_melted_ko['promoterID'],\n",
    "        test_crispr_score_label_melted_ko['geneID'],\n",
    "        test_crispr_score_label_melted_ko['transcriptID'],\n",
    "        test_crispr_score_label_melted_ko['proteinID'],\n",
    "    )\n",
    "]\n",
    "\n",
    "# The four source columns are now redundant\n",
    "test_crispr_score_label_melted_ko = test_crispr_score_label_melted_ko.drop(columns=['promoterID', 'geneID', 'transcriptID', 'proteinID'])\n",
    "display(test_crispr_score_label_melted_ko)\n",
    "\n",
    "# Column order expected in the saved array: ACH_ID, merged_ids, Value\n",
    "test_crispr_score_label_melted_ko = test_crispr_score_label_melted_ko.loc[:, ['ACH_ID', 'merged_ids', 'Value']].copy()\n",
    "display(test_crispr_score_label_melted_ko)\n",
    "# Convert to a numpy array and show its shape plus the first 10 rows as a sample\n",
    "test_crispr_score_label_melted_ko_array = test_crispr_score_label_melted_ko.to_numpy()\n",
    "print('The shape of test_crispr_score_label_melted_ko_array is:', test_crispr_score_label_melted_ko_array.shape)\n",
    "print(test_crispr_score_label_melted_ko_array[:10])  # Print first 10 rows as a sample\n",
    "# Persist the array in the TargetScreen folder\n",
    "np.save('./data/TargetScreen/test_crispr_score_label_melted_ko.npy', test_crispr_score_label_melted_ko_array)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### 9.2.2.2 TargetQA"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Omic Feature"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature matrix for TargetQA: every column except the ID column, transposed\n",
    "target_crispr_feat = target_crispr_df.drop(columns=['BioMedGraphica_Conn_ID']).to_numpy().T\n",
    "# Persist the matrix as a .npy file\n",
    "np.save('./data/TargetQA/target_crispr_feature.npy', target_crispr_feat)\n",
    "# Report the resulting shape\n",
    "print(\"Shape of target_crispr_feat:\", target_crispr_feat.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Omic Information"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep 'gene_name' plus the overlapped sample columns (sorted by name) from gene_df\n",
    "dti_crispr_rna_gene_df = gene_df[['gene_name'] + sorted(overlapped_dti_crispr_rna_samples)].copy()\n",
    "display(dti_crispr_rna_gene_df)\n",
    "# Protein table that carries the 'Names_and_IDs' text used by the extract_* helpers\n",
    "bmgc_protein_llmnameid_combined_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Protein/BioMedGraphica_Conn_Protein_LLM_Name_ID_Combined.csv')\n",
    "display(bmgc_protein_llmnameid_combined_df)\n",
    "\n",
    "def extract_gn_info(dti_crispr_rna_gene_df, bmgc_gene_df, gene_transcript_protein_entity_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, k=10):\n",
    "    \"\"\"\n",
    "    Rank genes by one sample's values and collect the proteins linked to the k highest.\n",
    "\n",
    "    Args:\n",
    "        dti_crispr_rna_gene_df (pd.DataFrame): Table with a 'gene_name' column and one value column per sample\n",
    "        bmgc_gene_df (pd.DataFrame): Gene table with 'BioMedGraphica_Conn_ID' and 'HGNC_Symbol' columns\n",
    "        gene_transcript_protein_entity_df (pd.DataFrame): Link table with 'BMGC_GN_ID' and 'BMGC_PT_ID' columns\n",
    "        bmgc_protein_llmnameid_combined_df (pd.DataFrame): Protein table with 'BioMedGraphica_Conn_ID', 'BioMedGraphica_ID' and 'Names_and_IDs' columns\n",
    "        sample_ach_name (str): Name of the sample column to rank by\n",
    "        k (int): Number of highest-valued rows to keep\n",
    "\n",
    "    Returns:\n",
    "        tuple: (\n",
    "            list: 'gene_name' of the top-k rows, highest value first,\n",
    "            list: 'BMGC_PT_ID' of the proteins linked to those genes,\n",
    "            list: matching 'Names_and_IDs' text with ' | ' replaced by ' or '\n",
    "        )\n",
    "        Three \"non-existed\" strings are returned when the sample column is absent.\n",
    "    \"\"\"\n",
    "    # Sentinel triple for samples that have no column in the table\n",
    "    if sample_ach_name not in dti_crispr_rna_gene_df.columns:\n",
    "        return \"non-existed\", \"non-existed\", \"non-existed\"\n",
    "    # Keep the k largest rows, ordered from highest to lowest value\n",
    "    ranked_genes = dti_crispr_rna_gene_df.nlargest(k, sample_ach_name)[['gene_name', sample_ach_name]]\n",
    "    ranked_genes = ranked_genes.sort_values(by=sample_ach_name, ascending=False).reset_index(drop=True)\n",
    "    gene_names = ranked_genes['gene_name'].tolist()\n",
    "    # Gene symbol -> BMGC gene ID\n",
    "    gene_id_lookup = bmgc_gene_df[['BioMedGraphica_Conn_ID', 'HGNC_Symbol']].copy()\n",
    "    ranked_with_ids = gene_id_lookup.merge(ranked_genes, left_on='HGNC_Symbol', right_on='gene_name', how='inner')\n",
    "    # BMGC gene ID -> linked protein IDs\n",
    "    gene_proteins = gene_transcript_protein_entity_df.merge(ranked_with_ids, left_on='BMGC_GN_ID', right_on='BioMedGraphica_Conn_ID', how='inner')\n",
    "    gene_proteins = gene_proteins.drop(columns=['BioMedGraphica_Conn_ID', 'gene_name'])\n",
    "    # Protein ID -> descriptive names\n",
    "    protein_info = gene_proteins.merge(bmgc_protein_llmnameid_combined_df, left_on='BMGC_PT_ID', right_on='BioMedGraphica_Conn_ID', how='inner')\n",
    "    protein_info = protein_info.drop(columns=['BioMedGraphica_Conn_ID', 'BioMedGraphica_ID', sample_ach_name])\n",
    "    protein_ids = protein_info['BMGC_PT_ID'].tolist()\n",
    "    protein_names = protein_info['Names_and_IDs'].replace(r' \\| ', ' or ', regex=True).tolist()\n",
    "    return gene_names, protein_ids, protein_names\n",
    "\n",
    "# Example usage: top-10 genes of a single sample; the three returned lists are printed below\n",
    "sample_ach_name = 'ACH-000001'\n",
    "k=10\n",
    "top_k_gene_hgnc_name_list, top_k_gene_protein_bmgc_id_list, top_k_gene_protein_bmgc_llmnameid_list = extract_gn_info(dti_crispr_rna_gene_df, bmgc_gene_df, gene_transcript_protein_entity_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, k)\n",
    "print(f\"Top {k} Gene HGNC Names:\", top_k_gene_hgnc_name_list)\n",
    "print(f\"Top {k} Gene Protein BMGC IDs:\", top_k_gene_protein_bmgc_id_list)\n",
    "print(f\"Top {k} Gene Protein BMGC LLM Name IDs:\", top_k_gene_protein_bmgc_llmnameid_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Samples present both in raw_transcript_df (every column after the first) and in overlapped_dti_crispr_rna_samples, sorted by name\n",
    "transcript_overlapped_dti_crispr_rna_samples = sorted(list(set(raw_transcript_df.columns[1:]) & set(overlapped_dti_crispr_rna_samples)))\n",
    "# Keep 'gene_name' plus those sample columns\n",
    "dti_crispr_rna_transcript_df = raw_transcript_df[['gene_name'] + transcript_overlapped_dti_crispr_rna_samples].copy()\n",
    "display(dti_crispr_rna_transcript_df)\n",
    "# Transcript entity table; extract_ts_info reads its 'BioMedGraphica_Conn_ID' and 'HGNC_Symbol' columns\n",
    "bmgc_transcript_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Transcript/BioMedGraphica_Conn_Transcript.csv')\n",
    "\n",
    "def extract_ts_info(dti_crispr_rna_transcript_df, bmgc_transcript_df, transcript_protein_entity_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, k=10):\n",
    "    \"\"\"\n",
    "    Rank transcript rows by one sample's values and collect the proteins linked to the k highest.\n",
    "\n",
    "    Args:\n",
    "        dti_crispr_rna_transcript_df (pd.DataFrame): Table with a 'gene_name' column and one value column per sample\n",
    "        bmgc_transcript_df (pd.DataFrame): Transcript table with 'BioMedGraphica_Conn_ID' and 'HGNC_Symbol' columns\n",
    "        transcript_protein_entity_df (pd.DataFrame): Link table with 'BMGC_TS_ID' and 'BMGC_PT_ID' columns\n",
    "        bmgc_protein_llmnameid_combined_df (pd.DataFrame): Protein table with 'BioMedGraphica_Conn_ID', 'BioMedGraphica_ID' and 'Names_and_IDs' columns\n",
    "        sample_ach_name (str): Name of the sample column to rank by\n",
    "        k (int): Number of highest-valued rows to keep\n",
    "\n",
    "    Returns:\n",
    "        tuple: (\n",
    "            list: 'gene_name' of the top-k rows, highest value first,\n",
    "            list: 'BMGC_PT_ID' of the proteins linked to the matching transcripts,\n",
    "            list: matching 'Names_and_IDs' text with ' | ' replaced by ' or '\n",
    "        )\n",
    "        Three \"non-existed\" strings are returned when the sample column is absent.\n",
    "    \"\"\"\n",
    "    # Sentinel triple for samples that have no column in the table\n",
    "    if sample_ach_name not in dti_crispr_rna_transcript_df.columns:\n",
    "        return \"non-existed\", \"non-existed\", \"non-existed\"\n",
    "    # Keep the k largest rows, ordered from highest to lowest value\n",
    "    ranked_transcripts = dti_crispr_rna_transcript_df.nlargest(k, sample_ach_name)[['gene_name', sample_ach_name]]\n",
    "    ranked_transcripts = ranked_transcripts.sort_values(by=sample_ach_name, ascending=False).reset_index(drop=True)\n",
    "    transcript_names = ranked_transcripts['gene_name'].tolist()\n",
    "    # Symbol -> BMGC transcript ID\n",
    "    transcript_id_lookup = bmgc_transcript_df[['BioMedGraphica_Conn_ID', 'HGNC_Symbol']].copy()\n",
    "    ranked_with_ids = transcript_id_lookup.merge(ranked_transcripts, left_on='HGNC_Symbol', right_on='gene_name', how='inner')\n",
    "    # BMGC transcript ID -> linked protein IDs\n",
    "    transcript_proteins = transcript_protein_entity_df.merge(ranked_with_ids, left_on='BMGC_TS_ID', right_on='BioMedGraphica_Conn_ID', how='inner')\n",
    "    transcript_proteins = transcript_proteins.drop(columns=['BioMedGraphica_Conn_ID', 'gene_name'])\n",
    "    # Protein ID -> descriptive names\n",
    "    protein_info = transcript_proteins.merge(bmgc_protein_llmnameid_combined_df, left_on='BMGC_PT_ID', right_on='BioMedGraphica_Conn_ID', how='inner')\n",
    "    protein_info = protein_info.drop(columns=['BioMedGraphica_Conn_ID', 'BioMedGraphica_ID', sample_ach_name])\n",
    "    protein_ids = protein_info['BMGC_PT_ID'].tolist()\n",
    "    protein_names = protein_info['Names_and_IDs'].replace(r' \\| ', ' or ', regex=True).tolist()\n",
    "    return transcript_names, protein_ids, protein_names\n",
    "\n",
    "# Example usage: top-10 transcript rows of a single sample; the three returned lists are printed below\n",
    "sample_ach_name = 'ACH-000001'\n",
    "k=10\n",
    "top_k_transcript_hgnc_name_list, top_k_transcript_protein_bmgc_id_list, top_k_transcript_protein_bmgc_llmnameid_list = extract_ts_info(dti_crispr_rna_transcript_df, bmgc_transcript_df, transcript_protein_entity_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, k)\n",
    "print(f\"Top {k} Transcript HGNC Names:\", top_k_transcript_hgnc_name_list)\n",
    "print(f\"Top {k} Transcript Protein BMGC IDs:\", top_k_transcript_protein_bmgc_id_list)\n",
    "print(f\"Top {k} Transcript Protein BMGC LLM Name IDs:\", top_k_transcript_protein_bmgc_llmnameid_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Protein entity table reduced to the ID, UniProt accession and HGNC symbol columns\n",
    "bmg_protein_all_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Protein/BioMedGraphica_Conn_Protein.csv')\n",
    "bmg_protein_all_df = bmg_protein_all_df[['BioMedGraphica_Conn_ID', 'Uniprot_ID', 'HGNC_Symbol']].copy()\n",
    "display(bmg_protein_all_df)\n",
    "\n",
    "# Rename columns in raw_protein_df using protein_map_dict (defined in an earlier cell)\n",
    "raw_protein_map_df = raw_protein_df.rename(columns=protein_map_dict)\n",
    "# Inner-join on the UniProt accession to attach HGNC symbols; rows without a match are dropped\n",
    "symbol_protein_map_df = pd.merge(raw_protein_map_df, bmg_protein_all_df, left_on='Uniprot_Acc', right_on='Uniprot_ID', how='inner')\n",
    "# Reorder columns: the three identifier columns first, then every remaining column sorted by name\n",
    "symbol_protein_map_df = symbol_protein_map_df[['Uniprot_ID', 'Uniprot_Acc', 'HGNC_Symbol'] + sorted(set(symbol_protein_map_df.columns) - {'Uniprot_ID', 'Uniprot_Acc', 'HGNC_Symbol'})]\n",
    "# Identify overlapping samples between protein data and the provided sample list\n",
    "protein_overlapped_samples = sorted(set(symbol_protein_map_df.columns) & set(overlapped_dti_crispr_rna_samples))\n",
    "# Select only HGNC symbol and overlapping sample columns\n",
    "dti_crispr_rna_protein_df = symbol_protein_map_df[['HGNC_Symbol'] + protein_overlapped_samples].copy()\n",
    "# Split multiple HGNC symbols by \";\" and expand into multiple rows (each symbol gets a copy of the row's values)\n",
    "dti_crispr_rna_protein_df = dti_crispr_rna_protein_df.assign(HGNC_Symbol=dti_crispr_rna_protein_df['HGNC_Symbol'].str.split(';')).explode('HGNC_Symbol')\n",
    "# Remove leading/trailing whitespace in gene symbols\n",
    "dti_crispr_rna_protein_df['HGNC_Symbol'] = dti_crispr_rna_protein_df['HGNC_Symbol'].str.strip()\n",
    "# Drop rows with empty or missing gene symbols\n",
    "dti_crispr_rna_protein_df = dti_crispr_rna_protein_df[dti_crispr_rna_protein_df['HGNC_Symbol'].notna() & (dti_crispr_rna_protein_df['HGNC_Symbol'] != '')].reset_index(drop=True)\n",
    "# Display final dataframe\n",
    "display(dti_crispr_rna_protein_df)\n",
    "\n",
    "# Same CSV as bmg_protein_all_df above, loaded again with all of its columns for the lookups below\n",
    "bmgc_protein_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Protein/BioMedGraphica_Conn_Protein.csv')\n",
    "\n",
    "def extract_pt_info(dti_crispr_rna_protein_df, bmgc_protein_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, k=10):\n",
    "    \"\"\"\n",
    "    Rank protein rows by one sample's values and look up the BMGC proteins of the k highest.\n",
    "\n",
    "    Args:\n",
    "        dti_crispr_rna_protein_df (pd.DataFrame): Table with an 'HGNC_Symbol' column and one value column per sample\n",
    "        bmgc_protein_df (pd.DataFrame): Protein table with 'BioMedGraphica_Conn_ID' and 'HGNC_Symbol' columns\n",
    "        bmgc_protein_llmnameid_combined_df (pd.DataFrame): Protein table with 'BioMedGraphica_Conn_ID', 'BioMedGraphica_ID' and 'Names_and_IDs' columns\n",
    "        sample_ach_name (str): Name of the sample column to rank by\n",
    "        k (int): Number of highest-valued rows to keep\n",
    "\n",
    "    Returns:\n",
    "        tuple: (\n",
    "            list: 'HGNC_Symbol' of the top-k rows, highest value first,\n",
    "            list: 'BioMedGraphica_Conn_ID' of the proteins whose symbol matches,\n",
    "            list: matching 'Names_and_IDs' text with ' | ' and ';' replaced by ' or '\n",
    "        )\n",
    "        Three \"non-existed\" strings are returned when the sample column is absent.\n",
    "    \"\"\"\n",
    "    # Sentinel triple for samples that have no column in the table\n",
    "    if sample_ach_name not in dti_crispr_rna_protein_df.columns:\n",
    "        return \"non-existed\", \"non-existed\", \"non-existed\"\n",
    "    # Keep the k largest rows, ordered from highest to lowest value\n",
    "    ranked_proteins = dti_crispr_rna_protein_df.nlargest(k, sample_ach_name)[['HGNC_Symbol', sample_ach_name]]\n",
    "    ranked_proteins = ranked_proteins.sort_values(by=sample_ach_name, ascending=False).reset_index(drop=True)\n",
    "    protein_symbols = ranked_proteins['HGNC_Symbol'].tolist()\n",
    "    # Symbol -> BMGC protein ID\n",
    "    protein_id_lookup = bmgc_protein_df[['BioMedGraphica_Conn_ID', 'HGNC_Symbol']].copy()\n",
    "    ranked_with_ids = protein_id_lookup.merge(ranked_proteins, left_on='HGNC_Symbol', right_on='HGNC_Symbol', how='inner')\n",
    "    # Protein ID -> descriptive names\n",
    "    protein_info = ranked_with_ids.merge(bmgc_protein_llmnameid_combined_df, left_on='BioMedGraphica_Conn_ID', right_on='BioMedGraphica_Conn_ID', how='inner')\n",
    "    protein_info = protein_info.drop(columns=['BioMedGraphica_ID', sample_ach_name])\n",
    "    protein_ids = protein_info['BioMedGraphica_Conn_ID'].tolist()\n",
    "    # Both separators, ' | ' and ';', are turned into ' or '\n",
    "    protein_names = protein_info['Names_and_IDs'].replace([r' \\| ', r';'], ' or ', regex=True).tolist()\n",
    "    return protein_symbols, protein_ids, protein_names\n",
    "\n",
    "# Example usage: top-10 protein rows of a single sample; the three returned lists are printed below\n",
    "sample_ach_name = 'ACH-000001'\n",
    "k = 10\n",
    "top_k_protein_hgnc_name_list, top_k_protein_bmgc_id_list, top_k_protein_bmgc_llmnameid_list = extract_pt_info(dti_crispr_rna_protein_df, bmgc_protein_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, k)\n",
    "print(f\"Top {k} Protein HGNC Names:\", top_k_protein_hgnc_name_list)\n",
    "print(f\"Top {k} Protein BMGC IDs:\", top_k_protein_bmgc_id_list)\n",
    "print(f\"Top {k} Protein BMGC LLM Name IDs:\", top_k_protein_bmgc_llmnameid_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Related Proteins"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def bmgc_pt_id_to_hgnc(bmgc_id_list, bmgc_protein_df):\n",
    "    \"\"\"\n",
    "    Convert a list of BioMedGraphica IDs to their corresponding HGNC symbols.\n",
    "    \n",
    "    Args:\n",
    "        bmgc_id_list (list): List of BioMedGraphica IDs; a single non-list value is wrapped into a list\n",
    "        bmgc_protein_df (pd.DataFrame): DataFrame with 'BioMedGraphica_Conn_ID' and 'HGNC_Symbol' columns;\n",
    "            'HGNC_Symbol' may hold several symbols separated by ';'\n",
    "        \n",
    "    Returns:\n",
    "        tuple: (\n",
    "            dict: Dictionary mapping each BioMedGraphica ID to its list of HGNC symbols\n",
    "                  (empty list when the ID is unknown or its symbol is missing),\n",
    "            list: Combined list of all distinct HGNC symbols, in first-seen order\n",
    "        )\n",
    "    \"\"\"\n",
    "    # Ensure bmgc_id_list is actually a list\n",
    "    if not isinstance(bmgc_id_list, list):\n",
    "        bmgc_id_list = [bmgc_id_list]\n",
    "    # Build the ID -> symbol-field lookup once instead of scanning the whole table for every ID.\n",
    "    # keep='first' mirrors using the first matching row of an ID; rows without an ID can never match.\n",
    "    first_rows = bmgc_protein_df.dropna(subset=['BioMedGraphica_Conn_ID']).drop_duplicates(subset='BioMedGraphica_Conn_ID', keep='first')\n",
    "    symbol_field_by_id = dict(zip(first_rows['BioMedGraphica_Conn_ID'], first_rows['HGNC_Symbol']))\n",
    "    results = {}\n",
    "    all_hgnc_symbols = []\n",
    "    for bmgc_id in bmgc_id_list:\n",
    "        symbol_field = symbol_field_by_id.get(bmgc_id)\n",
    "        # Unknown ID or missing symbol -> empty list\n",
    "        if symbol_field is None or pd.isna(symbol_field):\n",
    "            results[bmgc_id] = []\n",
    "            continue\n",
    "        # Strip first, then de-duplicate, so 'A; A' yields a single 'A'; dict.fromkeys keeps a stable order\n",
    "        stripped_symbols = [symbol.strip() for symbol in symbol_field.split(';')]\n",
    "        hgnc_list = list(dict.fromkeys(symbol for symbol in stripped_symbols if symbol != ''))\n",
    "        results[bmgc_id] = hgnc_list\n",
    "        all_hgnc_symbols.extend(hgnc_list)\n",
    "    # Remove duplicates from the combined list while keeping first-seen order\n",
    "    all_hgnc_symbols = list(dict.fromkeys(all_hgnc_symbols))\n",
    "    return results, all_hgnc_symbols\n",
    "\n",
    "def hgnc_to_bmgc_pt_id(hgnc_list, bmgc_protein_df):\n",
    "    \"\"\"\n",
    "    Convert a list of HGNC symbols to their corresponding BioMedGraphica IDs.\n",
    "    \n",
    "    Every protein row that lists the symbol is returned, including rows whose\n",
    "    'HGNC_Symbol' field holds several symbols separated by ';'.\n",
    "    \n",
    "    Args:\n",
    "        hgnc_list (list): List of HGNC symbols; a single non-list value is wrapped into a list\n",
    "        bmgc_protein_df (pd.DataFrame): DataFrame with 'BioMedGraphica_Conn_ID' and 'HGNC_Symbol' columns\n",
    "        \n",
    "    Returns:\n",
    "        tuple: (\n",
    "            dict: Dictionary mapping each HGNC symbol to its list of BioMedGraphica IDs\n",
    "                  (table order, empty list when nothing matches),\n",
    "            list: Combined list of all distinct BioMedGraphica IDs, in first-seen order\n",
    "        )\n",
    "    \"\"\"\n",
    "    # Ensure hgnc_list is actually a list\n",
    "    if not isinstance(hgnc_list, list):\n",
    "        hgnc_list = [hgnc_list]\n",
    "    # Table symbols are compared after stripping, so the queries are stripped the same way\n",
    "    query_keys = [hgnc.strip() if isinstance(hgnc, str) else hgnc for hgnc in hgnc_list]\n",
    "    # One row per (protein ID, single symbol): split ';'-joined symbol fields and expand them\n",
    "    pairs = bmgc_protein_df[['BioMedGraphica_Conn_ID', 'HGNC_Symbol']].dropna()\n",
    "    pairs = pairs.assign(HGNC_Symbol=pairs['HGNC_Symbol'].astype(str).str.split(';')).explode('HGNC_Symbol')\n",
    "    pairs['HGNC_Symbol'] = pairs['HGNC_Symbol'].str.strip()\n",
    "    pairs = pairs[pairs['HGNC_Symbol'].isin(query_keys)]\n",
    "    # symbol -> ordered, de-duplicated IDs (dict keys serve as an insertion-ordered set)\n",
    "    ids_by_symbol = {}\n",
    "    for symbol, bmgc_value in zip(pairs['HGNC_Symbol'], pairs['BioMedGraphica_Conn_ID']):\n",
    "        id_bucket = ids_by_symbol.setdefault(symbol, {})\n",
    "        # ';' inside an ID field is still treated as a separator, as before\n",
    "        for bmgc_part in str(bmgc_value).split(';'):\n",
    "            bmgc_part = bmgc_part.strip()\n",
    "            if bmgc_part != '':\n",
    "                id_bucket[bmgc_part] = None\n",
    "    results = {}\n",
    "    all_bmgc_ids = []\n",
    "    for hgnc, query_key in zip(hgnc_list, query_keys):\n",
    "        bmgc_list = list(ids_by_symbol.get(query_key, {}))\n",
    "        results[hgnc] = bmgc_list\n",
    "        all_bmgc_ids.extend(bmgc_list)\n",
    "    # Remove duplicates from the combined list while keeping first-seen order\n",
    "    all_bmgc_ids = list(dict.fromkeys(all_bmgc_ids))\n",
    "    return results, all_bmgc_ids\n",
    "\n",
    "# Example usage: protein IDs -> HGNC symbols (per-ID dict, then the combined symbol list)\n",
    "bmgc_ids = ['BMGC_PT000001', 'BMGC_PT013541']\n",
    "hgnc_dict, all_hgnc_symbols = bmgc_pt_id_to_hgnc(bmgc_ids, bmgc_protein_df)\n",
    "print(hgnc_dict)\n",
    "print(all_hgnc_symbols)\n",
    "# Example usage: HGNC symbols -> protein IDs (per-symbol dict, then the combined ID list)\n",
    "hgnc_list = ['BRCA1', 'TP53']\n",
    "bmgc_dict, all_bmgc_ids = hgnc_to_bmgc_pt_id(hgnc_list, bmgc_protein_df)\n",
    "print(bmgc_dict)\n",
    "print(all_bmgc_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_disease_protein(selected_sample_disease_bmgc_id, edge_index, \n",
    "                    node_index_df, nodeid_index_dict, index_nodeid_dict):\n",
    "    \"\"\"\n",
    "    Find the protein nodes that share an edge, in either direction, with one disease node.\n",
    "\n",
    "    Args:\n",
    "        selected_sample_disease_bmgc_id: Node ID of the disease; must be a key of nodeid_index_dict\n",
    "        edge_index (np.ndarray): 2-row array; row 0 holds source node indices, row 1 target node indices\n",
    "        node_index_df (pd.DataFrame): Node table with 'Type' and 'Index' columns\n",
    "        nodeid_index_dict (dict): Mapping from node ID to node index\n",
    "        index_nodeid_dict (dict): Mapping from node index to node ID\n",
    "\n",
    "    Returns:\n",
    "        tuple: (sorted list of neighbouring protein node indices, list of their node IDs in the same order)\n",
    "    \"\"\"\n",
    "    disease_idx = nodeid_index_dict[selected_sample_disease_bmgc_id]\n",
    "    source_row = edge_index[0, :]\n",
    "    target_row = edge_index[1, :]\n",
    "    # Sources of edges that end at the disease, followed by targets of edges that leave it\n",
    "    neighbour_nodes = np.concatenate([source_row[target_row == disease_idx], target_row[source_row == disease_idx]])\n",
    "    neighbour_set = set(np.unique(neighbour_nodes))\n",
    "    # Indices of all nodes typed as 'Protein'\n",
    "    protein_indices = node_index_df.loc[node_index_df['Type'] == 'Protein', 'Index'].tolist()\n",
    "    # Keep only the neighbours that are protein nodes\n",
    "    disease_protein_index = sorted(list(neighbour_set & set(protein_indices)))\n",
    "    # Translate the indices back to node IDs\n",
    "    disease_protein_bmgc_id = [index_nodeid_dict[node_idx] for node_idx in disease_protein_index]\n",
    "    return disease_protein_index, disease_protein_bmgc_id\n",
    "\n",
    "def extract_ppi_nodes(disease_protein_index, edge_index, node_index_df, index_nodeid_dict):\n",
    "    \"\"\"\n",
    "    Find the protein nodes that share an edge, in either direction, with any of the given protein nodes.\n",
    "\n",
    "    Args:\n",
    "        disease_protein_index (list): Node indices of the seed proteins\n",
    "        edge_index (np.ndarray): 2-row array; row 0 holds source node indices, row 1 target node indices\n",
    "        node_index_df (pd.DataFrame): Node table with 'Type' and 'Index' columns\n",
    "        index_nodeid_dict (dict): Mapping from node index to node ID\n",
    "\n",
    "    Returns:\n",
    "        tuple: (sorted list of neighbouring protein node indices excluding the seeds,\n",
    "                list of their node IDs in the same order)\n",
    "    \"\"\"\n",
    "    # Get protein node index\n",
    "    protein_node_index_df = node_index_df[node_index_df['Type'] == 'Protein']\n",
    "    protein_node_index_list = protein_node_index_df['Index'].tolist()\n",
    "    # Two vectorised membership tests replace one full scan of the edge array per seed protein\n",
    "    seed_nodes = np.asarray(disease_protein_index)\n",
    "    # Sources of edges that end at a seed, and targets of edges that start at a seed\n",
    "    incoming_sources = edge_index[0, np.isin(edge_index[1, :], seed_nodes)]\n",
    "    outgoing_targets = edge_index[1, np.isin(edge_index[0, :], seed_nodes)]\n",
    "    unique_protein_related_nodes = np.unique(np.concatenate([incoming_sources, outgoing_targets]))\n",
    "    # Remove the seed proteins themselves from this list to avoid duplication\n",
    "    unique_protein_related_nodes = np.setdiff1d(unique_protein_related_nodes, seed_nodes)\n",
    "    # Filter to only keep protein nodes among the neighbours\n",
    "    ppi_nodes_index = sorted(\n",
    "        list(set(unique_protein_related_nodes) & set(protein_node_index_list))\n",
    "    )\n",
    "    # Map PPI node index to BMGC id\n",
    "    ppi_nodes_bmgc_id = [index_nodeid_dict[i] for i in ppi_nodes_index]\n",
    "    return ppi_nodes_index, ppi_nodes_bmgc_id\n",
    "\n",
    "def extract_kg_related_proteins(selected_sample_disease_bmgc_id, edge_index,  # './data/DTI_data/edge_index.npy'\n",
    "                               node_index_df, nodeid_index_dict, index_nodeid_dict):\n",
    "    \"\"\"\n",
    "    Collect the proteins directly connected to a disease and the proteins connected to those.\n",
    "    \n",
    "    Args:\n",
    "        selected_sample_disease_bmgc_id (str): BMGC ID of the selected disease\n",
    "        edge_index: Edge index array, passed through unchanged to both helper functions\n",
    "        node_index_df (pd.DataFrame): DataFrame with node type information\n",
    "        nodeid_index_dict (dict): Mapping from node ID to index\n",
    "        index_nodeid_dict (dict): Mapping from index to node ID\n",
    "    \n",
    "    Returns:\n",
    "        tuple: (disease_protein_index, disease_protein_bmgc_id, ppi_nodes_index, ppi_nodes_bmgc_id)\n",
    "    \"\"\"\n",
    "    # First hop: proteins that share an edge with the disease node\n",
    "    disease_hits = extract_disease_protein(\n",
    "        selected_sample_disease_bmgc_id=selected_sample_disease_bmgc_id,\n",
    "        edge_index=edge_index,\n",
    "        node_index_df=node_index_df,\n",
    "        nodeid_index_dict=nodeid_index_dict,\n",
    "        index_nodeid_dict=index_nodeid_dict\n",
    "    )\n",
    "    # Second hop: proteins that share an edge with any first-hop protein\n",
    "    # (can replace this with LLM to generate PPI, may need NER and mapping to BMGC id)\n",
    "    ppi_hits = extract_ppi_nodes(\n",
    "        disease_protein_index=disease_hits[0],\n",
    "        edge_index=edge_index,\n",
    "        node_index_df=node_index_df,\n",
    "        index_nodeid_dict=index_nodeid_dict\n",
    "    )\n",
    "    # No BMGC-ID-to-HGNC conversion happens here; indices and BMGC IDs are returned as they are\n",
    "    disease_protein_index, disease_protein_bmgc_id = disease_hits\n",
    "    ppi_nodes_index, ppi_nodes_bmgc_id = ppi_hits\n",
    "    return disease_protein_index, disease_protein_bmgc_id, ppi_nodes_index, ppi_nodes_bmgc_id\n",
    "\n",
    "# Example usage; nodes_index_data and node_index_dict are expected from earlier cells\n",
    "selected_sample_disease_bmgc_id = 'BMGC_DS07934'\n",
    "edge_index = np.load('./data/TargetQA/edge_index.npy')\n",
    "# Reverse lookup (node index -> node ID) built from the node table\n",
    "index_node_dict = dict(zip(nodes_index_data['Index'], nodes_index_data['Node']))\n",
    "\n",
    "disease_protein_index, disease_protein_bmgc_id, ppi_nodes_index, ppi_nodes_bmgc_id = extract_kg_related_proteins(\n",
    "    selected_sample_disease_bmgc_id=selected_sample_disease_bmgc_id,\n",
    "    edge_index=edge_index,\n",
    "    node_index_df=nodes_index_data,\n",
    "    nodeid_index_dict=node_index_dict,\n",
    "    index_nodeid_dict=index_node_dict\n",
    ")\n",
    "# Only the direct disease-protein results are printed; the PPI results are kept in variables\n",
    "print(\"Disease Protein Index:\", disease_protein_index)\n",
    "print(\"Disease Protein BMGC ID:\", disease_protein_bmgc_id)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### CRISPR Answer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep 'gene_name' plus the overlapped sample columns (sorted by name) from crispr_df\n",
    "answer_crispr_df = crispr_df[['gene_name'] + sorted(overlapped_dti_crispr_rna_samples)].copy()\n",
    "display(answer_crispr_df)\n",
    "\n",
    "def extract_answer(answer_crispr_df, bmgc_gene_df, gene_transcript_protein_entity_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, top_bm=100):\n",
    "    \"\"\"\n",
    "    Rank genes by one sample's values and collect the proteins linked to the top_bm lowest.\n",
    "\n",
    "    Args:\n",
    "        answer_crispr_df (pd.DataFrame): Table with a 'gene_name' column and one value column per sample\n",
    "        bmgc_gene_df (pd.DataFrame): Gene table with 'BioMedGraphica_Conn_ID' and 'HGNC_Symbol' columns\n",
    "        gene_transcript_protein_entity_df (pd.DataFrame): Link table with 'BMGC_GN_ID' and 'BMGC_PT_ID' columns\n",
    "        bmgc_protein_llmnameid_combined_df (pd.DataFrame): Protein table with 'BioMedGraphica_Conn_ID', 'BioMedGraphica_ID' and 'Names_and_IDs' columns\n",
    "        sample_ach_name (str): Name of the sample column to rank by; it must exist in answer_crispr_df\n",
    "        top_bm (int): Number of lowest-valued rows to keep\n",
    "\n",
    "    Returns:\n",
    "        tuple: (\n",
    "            list: 'gene_name' of the selected rows, lowest value first,\n",
    "            list: 'BMGC_PT_ID' of the proteins linked to those genes,\n",
    "            list: matching 'Names_and_IDs' text with ' | ' replaced by ' or '\n",
    "        )\n",
    "    \"\"\"\n",
    "    # Keep the top_bm smallest rows, ordered from lowest to highest value\n",
    "    lowest_genes = answer_crispr_df.nsmallest(top_bm, sample_ach_name)[['gene_name', sample_ach_name]]\n",
    "    lowest_genes = lowest_genes.sort_values(by=sample_ach_name, ascending=True).reset_index(drop=True)\n",
    "    gene_names = lowest_genes['gene_name'].tolist()\n",
    "    # Gene symbol -> BMGC gene ID\n",
    "    gene_id_lookup = bmgc_gene_df[['BioMedGraphica_Conn_ID', 'HGNC_Symbol']].copy()\n",
    "    lowest_with_ids = gene_id_lookup.merge(lowest_genes, left_on='HGNC_Symbol', right_on='gene_name', how='inner')\n",
    "    # BMGC gene ID -> linked protein IDs\n",
    "    gene_proteins = gene_transcript_protein_entity_df.merge(lowest_with_ids, left_on='BMGC_GN_ID', right_on='BioMedGraphica_Conn_ID', how='inner')\n",
    "    gene_proteins = gene_proteins.drop(columns=['BioMedGraphica_Conn_ID', 'gene_name'])\n",
    "    # Protein ID -> descriptive names\n",
    "    protein_info = gene_proteins.merge(bmgc_protein_llmnameid_combined_df, left_on='BMGC_PT_ID', right_on='BioMedGraphica_Conn_ID', how='inner')\n",
    "    protein_info = protein_info.drop(columns=['BioMedGraphica_Conn_ID', 'BioMedGraphica_ID', sample_ach_name])\n",
    "    protein_ids = protein_info['BMGC_PT_ID'].tolist()\n",
    "    protein_names = protein_info['Names_and_IDs'].replace(r' \\| ', ' or ', regex=True).tolist()\n",
    "    return gene_names, protein_ids, protein_names\n",
    "\n",
    "# Example usage: the 10 lowest-valued genes of a single sample; the three returned lists are printed below\n",
    "sample_ach_name = 'ACH-000001'\n",
    "top_bm = 10\n",
    "# Reads the same CSV that an earlier cell already loaded into this variable\n",
    "bmgc_protein_llmnameid_combined_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Protein/BioMedGraphica_Conn_Protein_LLM_Name_ID_Combined.csv')\n",
    "top_bm_gene_hgnc_name_list, top_bm_gene_protein_bmgc_id_list, top_bm_gene_protein_bmgc_llmnameid_list = extract_answer(answer_crispr_df, bmgc_gene_df, gene_transcript_protein_entity_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, top_bm)\n",
    "print(f\"Top {top_bm} Gene HGNC Names:\", top_bm_gene_hgnc_name_list)\n",
    "print(f\"Top {top_bm} Gene Protein BMGC IDs:\", top_bm_gene_protein_bmgc_id_list)\n",
    "print(f\"Top {top_bm} Gene Protein BMGC LLM Name IDs:\", top_bm_gene_protein_bmgc_llmnameid_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Knowledge Graph Information"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_protein_relationships(hgnc_symbols, bmgc_protein_df, bmgc_relation_df):\n",
    "    \"\"\"\n",
    "    Find relationships between a list of proteins based on HGNC symbols.\n",
    "    \n",
    "    Every protein row whose 'HGNC_Symbol' is in hgnc_symbols takes part, so a symbol\n",
    "    shared by several protein IDs contributes the relations of all of those IDs.\n",
    "    \n",
    "    Parameters:\n",
    "    -----------\n",
    "    hgnc_symbols : list\n",
    "        List of HGNC symbols to find relationships between\n",
    "    bmgc_protein_df : pandas DataFrame\n",
    "        DataFrame containing BioMedGraphica_Conn_ID and HGNC_Symbol columns\n",
    "    bmgc_relation_df : pandas DataFrame\n",
    "        DataFrame containing BMGC_From_ID and BMGC_To_ID columns, and optionally relation_type\n",
    "        \n",
    "    Returns:\n",
    "    --------\n",
    "    tuple (pandas DataFrame, list)\n",
    "        - DataFrame with source_symbol, target_symbol (and relation_type when that column exists),\n",
    "          one row per matching relation in bmgc_relation_df order; the columns are present even\n",
    "          when no relation matches\n",
    "        - List of text descriptions in \"A -> B\" format, or \"A -type-> B\" when relation_type exists\n",
    "    \"\"\"\n",
    "    # Filter the protein DataFrame to only include the proteins we care about\n",
    "    filtered_proteins = bmgc_protein_df[bmgc_protein_df['HGNC_Symbol'].isin(hgnc_symbols)]\n",
    "    # BMGC ID -> HGNC symbol; its keys are ALL protein IDs of the requested symbols.\n",
    "    # (A symbol -> ID dict would keep only one ID per symbol and silently drop the others.)\n",
    "    bmgc_to_hgnc = dict(zip(filtered_proteins['BioMedGraphica_Conn_ID'], filtered_proteins['HGNC_Symbol']))\n",
    "    bmgc_ids = list(bmgc_to_hgnc)\n",
    "    # Filter the relationship DataFrame to only include relationships between our proteins\n",
    "    protein_relations = bmgc_relation_df[\n",
    "        bmgc_relation_df['BMGC_From_ID'].isin(bmgc_ids) & \n",
    "        bmgc_relation_df['BMGC_To_ID'].isin(bmgc_ids)\n",
    "    ]\n",
    "    # Map the BMGC IDs back to HGNC symbols (every ID is a key of bmgc_to_hgnc after the filter)\n",
    "    source_symbols = protein_relations['BMGC_From_ID'].map(bmgc_to_hgnc).tolist()\n",
    "    target_symbols = protein_relations['BMGC_To_ID'].map(bmgc_to_hgnc).tolist()\n",
    "    result_df = pd.DataFrame(\n",
    "        {'source_symbol': source_symbols, 'target_symbol': target_symbols},\n",
    "        columns=['source_symbol', 'target_symbol']\n",
    "    )\n",
    "    # Add relation type if it exists in the DataFrame and include it in the text descriptions\n",
    "    if 'relation_type' in bmgc_relation_df.columns:\n",
    "        relation_types = protein_relations['relation_type'].tolist()\n",
    "        result_df['relation_type'] = relation_types\n",
    "        text_descriptions = [\n",
    "            f'{source_symbol} -{relation_type}-> {target_symbol}'\n",
    "            for source_symbol, relation_type, target_symbol in zip(source_symbols, relation_types, target_symbols)\n",
    "        ]\n",
    "    else:\n",
    "        text_descriptions = [\n",
    "            f'{source_symbol} -> {target_symbol}'\n",
    "            for source_symbol, target_symbol in zip(source_symbols, target_symbols)\n",
    "        ]\n",
    "    return result_df, text_descriptions\n",
    "\n",
    "# Define the HGNC symbols whose pairwise relations should be looked up\n",
    "hgnc_symbols = ['SNRPD3', 'RAN', 'RPS8', 'UBL5', 'SMU1', 'RRM1', 'PSMA6', 'PSMB3', 'WEE1', \n",
    "                'PHB1', 'BANF1', 'KIF11', 'SNRPD1', 'PSMA3', 'PSMD11', 'PRPF19', 'SNRPF', \n",
    "                'RPS29', 'CDC27', 'SRSF3', 'TUBGCP2', 'ECD', 'RPS20', 'PCNA', 'PSMA7', 'CDC7', \n",
    "                'RPL17', 'GINS1', 'PHB2', 'SRSF2', 'MAD2L1', 'MED14']\n",
    "\n",
    "# Call the function to find relationships (bmgc_relation_df is expected from an earlier cell)\n",
    "relationships_df, relationship_texts = find_protein_relationships(hgnc_symbols, bmgc_protein_df, bmgc_relation_df)\n",
    "\n",
    "# You can now use both the DataFrame and text descriptions\n",
    "# Example usage: report how many relations were found and show the first five\n",
    "print(f\"Found {len(relationship_texts)} relationships between the proteins\")\n",
    "print(\"Example relationships:\")\n",
    "print(relationship_texts[:5])  # Print first 5 relationships"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Formulate QA JSON"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def qa_sample_info(sample_ach_name, selected_sample_disease_bmgc_id, k=100, top_bm=100):\n",
    "    # LLM Info\n",
    "    print(f\"Sample ACH Name: {sample_ach_name}\")\n",
    "    print(f\"Extracting top {k} gene information for {sample_ach_name}...\")\n",
    "    top_k_gene_hgnc_name_list, top_k_gene_protein_bmgc_id_list, top_k_gene_protein_bmgc_llmnameid_list = extract_gn_info(dti_crispr_rna_gene_df, bmgc_gene_df, gene_transcript_protein_entity_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, k)\n",
    "    print(f\"Extracting top {k} transcript information for {sample_ach_name}...\")\n",
    "    top_k_transcript_hgnc_name_list, top_k_transcript_protein_bmgc_id_list, top_k_transcript_protein_bmgc_llmnameid_list = extract_ts_info(dti_crispr_rna_transcript_df, bmgc_transcript_df, transcript_protein_entity_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, k)\n",
    "    print(f\"Extracting top {k} protein information for {sample_ach_name}...\")\n",
    "    top_k_protein_hgnc_name_list, top_k_protein_bmgc_id_list, top_k_protein_bmgc_llmnameid_list = extract_pt_info(dti_crispr_rna_protein_df, bmgc_protein_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, k)\n",
    "    # KG Info\n",
    "    edge_index = np.load('./data/TargetQA/edge_index.npy')\n",
    "    print(f\"Extracting disease-related proteins index and bmgc id for {selected_sample_disease_bmgc_id} ({sample_ach_name}) ...\")\n",
    "    disease_protein_index, disease_protein_bmgc_id, ppi_nodes_index, ppi_nodes_bmgc_id = extract_kg_related_proteins(selected_sample_disease_bmgc_id, edge_index, nodes_index_data, node_index_dict, index_node_dict)\n",
    "    print(f\"Knowledge Graph Info: Found {len(disease_protein_index)} disease-related proteins directly connected to {selected_sample_disease_bmgc_id} and {len(ppi_nodes_index)} proteins in their PPI network\")\n",
    "    print(f\"Mapping disease-related proteins to HGNC symbols...\")\n",
    "    disease_protein_hgnc_dict, disease_protein_hgnc_list = bmgc_pt_id_to_hgnc(disease_protein_bmgc_id, bmgc_protein_df)\n",
    "    print(f\"Mapping PPI-related proteins to HGNC symbols...\")\n",
    "    ppi_hgnc_dict, ppi_hgnc_list = bmgc_pt_id_to_hgnc(ppi_nodes_bmgc_id, bmgc_protein_df)\n",
    "    # LLM Used KG Info\n",
    "    print(f\"Extracting protein relationships from BMGC...\")\n",
    "    # Convert the any non-existed string in top_k_gene_hgnc_name_list + top_k_transcript_hgnc_name_list + top_k_protein_hgnc_name_list + disease_protein_hgnc_list into empty list []\n",
    "    if top_k_gene_hgnc_name_list == \"non-existed\": top_k_gene_hgnc_name_list = []\n",
    "    if top_k_transcript_hgnc_name_list == \"non-existed\": top_k_transcript_hgnc_name_list = []\n",
    "    if top_k_protein_hgnc_name_list == \"non-existed\": top_k_protein_hgnc_name_list = []\n",
    "    if disease_protein_hgnc_list == \"non-existed\": disease_protein_hgnc_list = []\n",
    "    # Combine all the HGNC symbols into a single list for relationship extraction\n",
    "    omics_disease_protein_hgnc_list = list(set(top_k_gene_hgnc_name_list + top_k_transcript_hgnc_name_list + top_k_protein_hgnc_name_list + disease_protein_hgnc_list))\n",
    "    relationships_df, relationship_texts = find_protein_relationships(omics_disease_protein_hgnc_list, bmgc_protein_df, bmgc_relation_df)\n",
    "    print(f\"Knowledge Graph Info: Found {len(omics_disease_protein_hgnc_list)} unique proteins and {len(relationship_texts)} relationships between them\")\n",
    "    # Answer Info\n",
    "    print(f\"Extracting top {top_bm} CRISPR gene information for {sample_ach_name}...\")\n",
    "    top_bm_gene_hgnc_name_list, top_bm_gene_protein_bmgc_id_list, top_bm_gene_protein_bmgc_llmnameid_list = extract_answer(answer_crispr_df, bmgc_gene_df, gene_transcript_protein_entity_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, top_bm)\n",
    "    return top_k_gene_hgnc_name_list, top_k_gene_protein_bmgc_id_list, top_k_gene_protein_bmgc_llmnameid_list, \\\n",
    "              top_k_transcript_hgnc_name_list, top_k_transcript_protein_bmgc_id_list, top_k_transcript_protein_bmgc_llmnameid_list, \\\n",
    "                top_k_protein_hgnc_name_list, top_k_protein_bmgc_id_list, top_k_protein_bmgc_llmnameid_list, \\\n",
    "                disease_protein_index, disease_protein_bmgc_id, ppi_nodes_index, ppi_nodes_bmgc_id, \\\n",
    "                    disease_protein_hgnc_dict, disease_protein_hgnc_list, ppi_hgnc_dict, ppi_hgnc_list, relationship_texts, \\\n",
    "                        top_bm_gene_hgnc_name_list, top_bm_gene_protein_bmgc_id_list, top_bm_gene_protein_bmgc_llmnameid_list\n",
    "\n",
    "# Example usage\n",
    "sample_ach_name = 'ACH-000001'\n",
    "selected_sample_disease_bmgc_id = 'BMGC_DS07934'\n",
    "k = 100\n",
    "top_bm = 100\n",
    "return_tuples = qa_sample_info(sample_ach_name, selected_sample_disease_bmgc_id, k=k, top_bm=top_bm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Parameters\n",
    "k = 10\n",
    "top_bm = 100\n",
    "save_every_n = 10\n",
    "\n",
    "# Output folder and filename\n",
    "output_dir = \"./data/TargetQA\"\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "output_filename = f\"target_qa_k{k}_bm{top_bm}.json\"\n",
    "output_path = os.path.join(output_dir, output_filename)\n",
    "\n",
    "# Load existing JSON file if it exists\n",
    "if os.path.exists(output_path):\n",
    "    with open(output_path, \"r\") as f:\n",
    "        multi_sample_qa_json = json.load(f)\n",
    "    print(f\"Loaded existing JSON with {len(multi_sample_qa_json)} processed samples\")\n",
    "else:\n",
    "    multi_sample_qa_json = {}\n",
    "    print(\"No existing JSON found, starting fresh\")\n",
    "\n",
    "# Load sample info data\n",
    "dti_sample_info_index = pd.read_csv('./data/process_data/dti_combined_samples.csv')\n",
    "target_sample_info_index = pd.merge(dti_sample_info_index, overlapped_dti_crispr_rna_samples_df['depMapID'], how='inner', on='depMapID')\n",
    "target_sample_info_index = target_sample_info_index[target_sample_info_index['depMapID'].isin(overlapped_dti_crispr_rna_samples)].reset_index(drop=True)\n",
    "target_sample_info_index['BMGC_Disease_name'] = target_sample_info_index['BMGC_Disease_name'].replace(r' \\| ', ' or ', regex=True)\n",
    "# Insert a new column \"Index\" in the first position\n",
    "target_sample_info_index.insert(0, 'Index', range(1, len(target_sample_info_index) + 1))\n",
    "display(target_sample_info_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the list of samples that have already been processed\n",
    "processed_samples = set(multi_sample_qa_json.keys())\n",
    "print(f\"Found {len(processed_samples)} already processed samples\")\n",
    "\n",
    "count = len(processed_samples)\n",
    "total_to_process = len(target_sample_info_index)\n",
    "remaining = total_to_process - count\n",
    "print(f\"Total samples to process: {total_to_process}, already processed: {count}, remaining: {remaining}\")\n",
    "\n",
    "# Iterate through the sample info dataframe, skipping already processed samples\n",
    "for idx, row_tuple in tqdm(enumerate(target_sample_info_index.iterrows()), total=len(target_sample_info_index)):\n",
    "    _, row = row_tuple  # Unpack the tuple - index and row data\n",
    "\n",
    "    sample_ach_name = row[\"depMapID\"]\n",
    "    \n",
    "    # Skip if already processed\n",
    "    if sample_ach_name in processed_samples:\n",
    "        continue\n",
    "        \n",
    "    count += 1\n",
    "    target_sample_index = row[\"Index\"]\n",
    "    cell_line_name = row[\"Name\"]\n",
    "    disease = row[\"BMGC_Disease_name\"]\n",
    "    disease_bmgc_id = row[\"BMGC_Disease_ID\"]\n",
    "\n",
    "    print(f\"Processing sample {count}/{total_to_process}: {sample_ach_name} ({cell_line_name})\")\n",
    "    print(f\"Sample Index: {target_sample_index}\")\n",
    "    print(f\"Sample Disease: {disease}\")\n",
    "    print(f\"Sample Disease BMGC ID: {disease_bmgc_id}\")\n",
    "\n",
    "    try:\n",
    "        (top_k_gene_hgnc, top_k_gene_bmgc, top_k_gene_llm,\n",
    "        top_k_ts_hgnc, top_k_ts_bmgc, top_k_ts_llm,\n",
    "        top_k_pt_hgnc, top_k_pt_bmgc, top_k_pt_llm,\n",
    "        dis_pt_idx, dis_pt_bmgc, ppi_idx, ppi_bmgc,\n",
    "        dis_pt_hgnc_dict, dis_pt_hgnc, ppi_hgnc_dict, ppi_hgnc, relationship_texts,\n",
    "        ans_hgnc, ans_bmgc, ans_llm) = qa_sample_info(sample_ach_name, disease_bmgc_id, k=k, top_bm=top_bm)\n",
    "\n",
    "        multi_sample_qa_json[sample_ach_name] = {\n",
    "            \"cell_line_name\": cell_line_name,\n",
    "            \"sample_index\": target_sample_index,\n",
    "            \"disease\": disease,\n",
    "            \"disease_bmgc_id\": disease_bmgc_id,\n",
    "            \"input\": {\n",
    "                \"top_k_gene\": {\n",
    "                    \"hgnc_symbols\": top_k_gene_hgnc,\n",
    "                    \"protein_bmgc_ids\": top_k_gene_bmgc,\n",
    "                    \"protein_llmname_ids\": top_k_gene_llm\n",
    "                },\n",
    "                \"top_k_transcript\": {\n",
    "                    \"hgnc_symbols\": top_k_ts_hgnc,\n",
    "                    \"protein_bmgc_ids\": top_k_ts_bmgc,\n",
    "                    \"protein_llmname_ids\": top_k_ts_llm\n",
    "                },\n",
    "                \"top_k_protein\": {\n",
    "                    \"hgnc_symbols\": top_k_pt_hgnc,\n",
    "                    \"protein_bmgc_ids\": top_k_pt_bmgc,\n",
    "                    \"protein_llmname_ids\": top_k_pt_llm\n",
    "                },\n",
    "                \"knowledge_graph\": {\n",
    "                    \"disease_protein\": {\n",
    "                        \"bmgc_ids\": dis_pt_bmgc,\n",
    "                        \"hgnc_symbols\": dis_pt_hgnc,\n",
    "                        \"indices\": dis_pt_idx\n",
    "                    },\n",
    "                    \"ppi_neighbors\": {\n",
    "                        \"bmgc_ids\": ppi_bmgc,\n",
    "                        \"hgnc_symbols\": ppi_hgnc,\n",
    "                        \"indices\": ppi_idx\n",
    "                    },\n",
    "                    \"protein_relationships\": relationship_texts,\n",
    "                }\n",
    "            },\n",
    "            \"ground_truth_answer\": {\n",
    "                \"top_bm_gene\": {\n",
    "                    \"hgnc_symbols\": ans_hgnc,\n",
    "                    \"protein_bmgc_ids\": ans_bmgc,\n",
    "                    \"protein_llmname_ids\": ans_llm\n",
    "                }\n",
    "            }\n",
    "        }\n",
    "\n",
    "    except Exception as e:\n",
    "        print(f\"⚠️ Error processing {sample_ach_name}: {e}\")\n",
    "        continue\n",
    "\n",
    "    # Periodic save every N samples\n",
    "    if count % save_every_n == 0:\n",
    "        with open(output_path, \"w\") as f:\n",
    "            json.dump(multi_sample_qa_json, f, indent=2, default=lambda o: int(o) if hasattr(o, 'item') else o)\n",
    "        print(f\"💾 Auto-saved JSON at {count}/{total_to_process} samples to: {output_path}\")\n",
    "        print(f\"Last processed sample: {sample_ach_name}\")\n",
    "        processed = len(multi_sample_qa_json)\n",
    "        remaining = total_to_process - processed\n",
    "        print(f\"Progress: {processed}/{total_to_process} ({processed/total_to_process*100:.1f}%), Remaining: {remaining}\")\n",
    "\n",
    "# Final save after loop\n",
    "with open(output_path, \"w\") as f:\n",
    "    json.dump(multi_sample_qa_json, f, indent=2, default=lambda o: int(o) if hasattr(o, 'item') else o)\n",
    "\n",
    "print(f\"✅ Final JSON saved to: {output_path}\")\n",
    "print(f\"Total samples processed: {len(multi_sample_qa_json)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Split QA as Train and Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import pandas as pd\n",
    "\n",
    "def read_json(file_path):\n",
    "    with open(file_path, 'r') as file:\n",
    "        data = json.load(file)\n",
    "    return data\n",
    "\n",
    "def separate_data(data):\n",
    "    # Get the training and testing list\n",
    "    target_tr_df = pd.read_csv('./data/TargetQA/train_samples.csv')\n",
    "    target_te_df = pd.read_csv('./data/TargetQA/test_samples.csv')\n",
    "    target_tr_list = target_tr_df['depMapID'].tolist()\n",
    "    target_te_list = target_te_df['depMapID'].tolist()\n",
    "    # Separate the data into two lists based on tr/te list with (for sample_id, sample_info in data.items())\n",
    "    tr_data = {}\n",
    "    te_data = {}\n",
    "    for sample_id, sample_info in data.items():\n",
    "        if sample_id in target_tr_list:\n",
    "            tr_data[sample_id] = sample_info\n",
    "        elif sample_id in target_te_list:\n",
    "            te_data[sample_id] = sample_info\n",
    "    # Check if the separated data is correct\n",
    "    tr_data_count = len(tr_data)\n",
    "    te_data_count = len(te_data)\n",
    "    print(f\"Training data count: {tr_data_count}\")\n",
    "    print(f\"Testing data count: {te_data_count}\")\n",
    "    # Check if the training and testing data are mutually exclusive\n",
    "    if set(tr_data.keys()).intersection(te_data.keys()):\n",
    "        print(\"Error: Training and testing data are not mutually exclusive.\")\n",
    "    else:\n",
    "        print(\"Training and testing data are mutually exclusive.\")\n",
    "    return tr_data, te_data\n",
    "\n",
    "def save_to_json(data, file_path):\n",
    "    with open(file_path, 'w') as file:\n",
    "        json.dump(data, file, indent=2)\n",
    "    print(f\"Data saved to {file_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the JSON file\n",
    "data = read_json('./data/TargetQA/target_qa_k10_bm100.json')\n",
    "\n",
    "# Separate the data into training and testing sets\n",
    "tr_data, te_data = separate_data(data)\n",
    "# Save the separated data to new JSON files\n",
    "save_to_json(tr_data, './data/TargetQA/target_qa_k10_bm100_tr.json')\n",
    "save_to_json(te_data, './data/TargetQA/target_qa_k10_bm100_te.json')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Formulate QA text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 9.2.3 Drug data integration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# insert the samples that is in the remaining_dti_samples but not in the maped_methy_df\n",
    "for sample in remaining_dti_samples:\n",
    "    if sample not in maped_methy_df.columns:\n",
    "        maped_methy_df[sample] = 0.0\n",
    "# filter out the samples that is not in the remaining_dti_samples\n",
    "dti_methy_df = maped_methy_df[['BioMedGraphica_Conn_ID'] + remaining_dti_samples]\n",
    "display(dti_methy_df)\n",
    "\n",
    "# insert the samples that is in the remaining_dti_samples but not in the mapped_protein_df\n",
    "for sample in remaining_dti_samples:\n",
    "    if sample not in mapped_protein_df.columns:\n",
    "        mapped_protein_df[sample] = 0.0\n",
    "# filter out the samples that is not in the remaining_dti_samples\n",
    "dti_protein_df = mapped_protein_df[['BioMedGraphica_Conn_ID'] + remaining_dti_samples]\n",
    "display(dti_protein_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fill in the missing samples in the final_merged_gene_df\n",
    "for sample in remaining_dti_samples:\n",
    "    if sample not in final_merged_gene_df.columns:\n",
    "        final_merged_gene_df[sample] = 0.0\n",
    "# fill in the missing samples in the final_merged_transcript_df\n",
    "for sample in remaining_dti_samples:\n",
    "    if sample not in final_merged_transcript_df.columns:\n",
    "        final_merged_transcript_df[sample] = 0.0\n",
    "# get the final gene, transcript, drug dataframe by filtering the remaining_dti_samples\n",
    "dti_gene_df = final_merged_gene_df[['BioMedGraphica_Conn_ID'] + remaining_dti_samples]\n",
    "display(dti_gene_df)\n",
    "dti_transcript_df = final_merged_transcript_df[['BioMedGraphica_Conn_ID'] + remaining_dti_samples]\n",
    "display(dti_transcript_df)\n",
    "dti_drug_overlap_df = final_drug_df[final_drug_df['ARXSPAN_ID'].isin(remaining_dti_samples)].reset_index(drop=True)\n",
    "dti_drug_overlap_df = pd.merge(dti_drug_overlap_df, dti_combined_df, left_on='ARXSPAN_ID', right_on='depMapID', how='left')\n",
    "dti_drug_overlap_df = pd.merge(dti_drug_overlap_df, final_merged_drug_df, on='DRUG_NAME', how='left').rename(columns={'BMGC_Disease_name': 'BMGC_Disease_Name', 'DRUG_NAME': 'BMGC_Drug_Name', 'BioMedGraphica_Conn_ID': 'BMGC_Drug_ID', 'Name': 'Cell_Line_Name', 'AUC_PUBLISHED': 'AUC'})\n",
    "# only keep columns ['depMapID', 'Cell_Line_Name', 'BMGC_Drug_ID', 'BMGC_Drug_Name', 'BMGC_Disease_ID', 'BMGC_Disease_Name', 'AUC']\n",
    "dti_drug_overlap_df = dti_drug_overlap_df[['depMapID', 'Cell_Line_Name', 'BMGC_Drug_ID', 'BMGC_Drug_Name', 'BMGC_Disease_ID', 'BMGC_Disease_Name', 'AUC']]\n",
    "# check if there is null values in the dti_drug_overlap_df\n",
    "print(dti_drug_overlap_df.isnull().sum())\n",
    "display(dti_drug_overlap_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dti_omics_df = pd.concat([dti_methy_df, dti_gene_df, dti_transcript_df, dti_protein_df], axis=0).reset_index(drop=True)\n",
    "display(dti_omics_df)\n",
    "dti_feat_df = pd.merge(bmgc_entity_df, dti_omics_df, left_on='BioMedGraphica_Conn_ID', right_on='BioMedGraphica_Conn_ID', how='left').drop(columns=['BioMedGraphica_ID', 'Type'])\n",
    "dti_feat_df = dti_feat_df.fillna(0.0)\n",
    "display(dti_feat_df)\n",
    "\n",
    "# convert dti_feat_df to numpy array and transpose it\n",
    "dti_array = dti_feat_df.drop(columns=['BioMedGraphica_Conn_ID']).values.T\n",
    "print(\"Shape of dti_array:\", dti_array.shape)\n",
    "# Save the numpy array to .npy file\n",
    "np.save('./data/DrugQA/dti_feature.npy', dti_array)\n",
    "np.save('./data/DrugScreen/dti_feature.npy', dti_array)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# filter out the drug screen by training and test with condition that meet both cell line test samples and drug test samples for test\n",
    "final_training_drug_screen_df = dti_drug_overlap_df[dti_drug_overlap_df['depMapID'].isin(dti_train_samples)].reset_index(drop=True)\n",
    "display(final_training_drug_screen_df)\n",
    "# and calculate the final training drug ids and cell line ids\n",
    "training_drug_ids = list(set(final_training_drug_screen_df['BMGC_Drug_ID']))\n",
    "training_cell_line_ids = list(set(final_training_drug_screen_df['depMapID']))\n",
    "print(\"len(training_drug_ids):\", len(training_drug_ids))\n",
    "print(\"len(training_cell_line_ids):\", len(training_cell_line_ids))\n",
    "\n",
    "# the rest of the drug screen is the test drug screen\n",
    "final_test_drug_screen_df = dti_drug_overlap_df[dti_drug_overlap_df['depMapID'].isin(dti_test_samples)].reset_index(drop=True)\n",
    "display(final_test_drug_screen_df)\n",
    "# and calculate the final test drug ids and cell line ids\n",
    "test_drug_ids = list(set(final_test_drug_screen_df['BMGC_Drug_ID']))\n",
    "test_cell_line_ids = list(set(final_test_drug_screen_df['depMapID']))\n",
    "print(\"len(test_drug_ids):\", len(test_drug_ids))\n",
    "print(\"len(test_cell_line_ids):\", len(test_cell_line_ids))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"The overlapped samples dataframe for methylation between remaining_dti_samples:\")\n",
    "# insert the samples that is in the overlapped_samples but not in the maped_methy_df\n",
    "for sample in remaining_dti_samples:\n",
    "    if sample not in maped_methy_df.columns:\n",
    "        maped_methy_df[sample] = 0.0\n",
    "# filter out the samples that is not in the remaining_dti_samples\n",
    "dti_methy_df = maped_methy_df[['BioMedGraphica_Conn_ID'] + remaining_dti_samples]\n",
    "display(dti_methy_df)\n",
    "\n",
    "print(\"The overlapped samples dataframe for protein between remaining_dti_samples are:\")\n",
    "# insert the samples that is in the pretraining_samples but not in the mapped_protein_df\n",
    "for sample in remaining_dti_samples:\n",
    "    if sample not in mapped_protein_df.columns:\n",
    "        mapped_protein_df[sample] = 0.0\n",
    "# filter out the samples that is not in the remaining_dti_samples\n",
    "dti_protein_df = mapped_protein_df[['BioMedGraphica_Conn_ID'] + remaining_dti_samples]\n",
    "display(dti_protein_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fill in the missing samples in the final_merged_gene_df for remaining_dti_samples\n",
    "for sample in remaining_dti_samples:\n",
    "    if sample not in final_merged_gene_df.columns:\n",
    "        final_merged_gene_df[sample] = 0.0\n",
    "# fill in the missing samples in the final_merged_transcript_df for remaining_dti_samples\n",
    "for sample in remaining_dti_samples:\n",
    "    if sample not in final_merged_transcript_df.columns:\n",
    "        final_merged_transcript_df[sample] = 0.0\n",
    "# get the final gene, transcript, drug dataframe by filtering the remaining_dti_samples\n",
    "dti_gene_df = final_merged_gene_df[['BioMedGraphica_Conn_ID'] + remaining_dti_samples]\n",
    "display(dti_gene_df)\n",
    "dti_transcript_df = final_merged_transcript_df[['BioMedGraphica_Conn_ID'] + remaining_dti_samples]\n",
    "display(dti_transcript_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dti_omics_df = pd.concat([dti_methy_df, dti_gene_df, dti_transcript_df, dti_protein_df], axis=0).reset_index(drop=True)\n",
    "display(dti_omics_df)\n",
    "dti_df = pd.merge(bmgc_entity_df, dti_omics_df, left_on='BioMedGraphica_Conn_ID', right_on='BioMedGraphica_Conn_ID', how='left').drop(columns=['BioMedGraphica_ID', 'Type'])\n",
    "dti_df = dti_df.fillna(0.0)\n",
    "display(dti_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select the features from dti_df\n",
    "dti_feat = dti_df.drop(columns=['BioMedGraphica_Conn_ID']).values.T\n",
    "# Save the features to .npy files\n",
    "np.save('./data/DrugQA/dti_feature.npy', dti_feat)\n",
    "np.save('./data/DrugScreen/dti_feature.npy', dti_feat)\n",
    "# Print the shapes of the features\n",
    "print(\"Shape of dti_feat:\", dti_feat.shape)\n",
    "\n",
    "# Create a dictionary mapping each depMapID to its corresponding row index (default DataFrame index)\n",
    "dti_sample_index_dict = dict(zip(remaining_dti_samples_df['depMapID'], remaining_dti_samples_df.index))\n",
    "print(dti_sample_index_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### 9.2.3.1 Drug Screen Integration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For final_training_drug_screen_df, only keep the columns ['depMapID', 'BMGC_Drug_ID', 'BMGC_Disease_ID', 'AUC']\n",
    "final_training_drug_screen_dfc = final_training_drug_screen_df[['depMapID', 'BMGC_Drug_ID', 'BMGC_Disease_ID', 'AUC']].copy()\n",
    "# Map depMapID to dti_sample_index, BMGC_Drug_ID to node_index, BMGC_Disease_ID to node_index\n",
    "final_training_drug_screen_dfc['dti_sample_index'] = final_training_drug_screen_dfc['depMapID'].map(dti_sample_index_dict)\n",
    "final_training_drug_screen_dfc['BMGC_Drug_ID'] = final_training_drug_screen_dfc['BMGC_Drug_ID'].map(node_index_dict)\n",
    "final_training_drug_screen_dfc['BMGC_Disease_ID'] = final_training_drug_screen_dfc['BMGC_Disease_ID'].map(node_index_dict)\n",
    "final_training_drug_screen_dfc = final_training_drug_screen_dfc.drop(columns=['depMapID'])\n",
    "# reorder the columns to ['dti_sample_index', 'BMGC_Drug_ID', 'BMGC_Disease_ID', 'AUC']\n",
    "final_training_drug_screen_dfc = final_training_drug_screen_dfc[['dti_sample_index', 'BMGC_Drug_ID', 'BMGC_Disease_ID', 'AUC']]\n",
    "display(final_training_drug_screen_dfc)\n",
    "\n",
    "# For final_test_drug_screen_df, only keep the columns ['depMapID', 'BMGC_Drug_ID', 'BMGC_Disease_ID', 'AUC']\n",
    "final_test_drug_screen_dfc = final_test_drug_screen_df[['depMapID', 'BMGC_Drug_ID', 'BMGC_Disease_ID', 'AUC']].copy()\n",
    "# Map depMapID to dti_sample_index, BMGC_Drug_ID to node_index, BMGC_Disease_ID to node_index\n",
    "final_test_drug_screen_dfc['dti_sample_index'] = final_test_drug_screen_dfc['depMapID'].map(dti_sample_index_dict)\n",
    "final_test_drug_screen_dfc['BMGC_Drug_ID'] = final_test_drug_screen_dfc['BMGC_Drug_ID'].map(node_index_dict)\n",
    "final_test_drug_screen_dfc['BMGC_Disease_ID'] = final_test_drug_screen_dfc['BMGC_Disease_ID'].map(node_index_dict)\n",
    "final_test_drug_screen_dfc = final_test_drug_screen_dfc.drop(columns=['depMapID'])\n",
    "# reorder the columns to ['dti_sample_index', 'BMGC_Drug_ID', 'BMGC_Disease_ID', 'AUC']\n",
    "final_test_drug_screen_dfc = final_test_drug_screen_dfc[['dti_sample_index', 'BMGC_Drug_ID', 'BMGC_Disease_ID', 'AUC']]\n",
    "display(final_test_drug_screen_dfc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert final_training_drug_screen_dfc and final_test_drug_screen_dfc to numpy arrays and save them\n",
    "final_training_drug_screen_array = final_training_drug_screen_dfc.values\n",
    "final_test_drug_screen_array = final_test_drug_screen_dfc.values\n",
    "# Save the numpy arrays to .npy files\n",
    "np.save('./data/DrugScreen/final_training_drug_screen.npy', final_training_drug_screen_array)\n",
    "np.save('./data/DrugScreen/final_test_drug_screen.npy', final_test_drug_screen_array)\n",
    "print(\"Shape of final_training_drug_screen_array:\", final_training_drug_screen_array.shape)\n",
    "print(\"Shape of final_test_drug_screen_array:\", final_test_drug_screen_array.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### 9.2.3.2 DrugQA Integration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# select the columns in the remaining_dti_samples for gene_df\n",
    "# Only include samples that actually exist in gene_df columns\n",
    "available_samples = [sample for sample in remaining_dti_samples if sample in gene_df.columns]\n",
    "dti_gene_df = gene_df[['gene_name'] + sorted(available_samples)].copy()\n",
    "print(f\"Total remaining_dti_samples: {len(remaining_dti_samples)}\")\n",
    "print(f\"Available samples in gene_df: {len(available_samples)}\")\n",
    "print(f\"Missing samples: {len(remaining_dti_samples) - len(available_samples)}\")\n",
    "display(dti_gene_df)\n",
    "bmgc_protein_llmnameid_combined_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Protein/BioMedGraphica_Conn_Protein_LLM_Name_ID_Combined.csv')\n",
    "display(bmgc_protein_llmnameid_combined_df)\n",
    "\n",
    "def extract_dti_gn_info(dti_gene_df, bmgc_gene_df, gene_transcript_protein_entity_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, k=10):\n",
    "    # Check if sample_ach_name exists in dti_gene_df columns\n",
    "    if sample_ach_name not in dti_gene_df.columns:\n",
    "        return \"non-existed\", \"non-existed\", \"non-existed\"\n",
    "    # Extract the top k highest values for the given sample name\n",
    "    top_k_genes = dti_gene_df.nlargest(k, sample_ach_name)[['gene_name', sample_ach_name]]\n",
    "    # Sort the top k genes by their values in descending order\n",
    "    top_k_genes = top_k_genes.sort_values(by=sample_ach_name, ascending=False).reset_index(drop=True)\n",
    "    top_k_gene_hgnc_name_list = top_k_genes['gene_name'].tolist()\n",
    "    # Merge with the bmgc_gene_df to get the BioMedGraphica_Conn_ID\n",
    "    bmgc_gene_df = bmgc_gene_df[['BioMedGraphica_Conn_ID', 'HGNC_Symbol']].copy()\n",
    "    top_k_bmgc_gene_df = pd.merge(bmgc_gene_df, top_k_genes, left_on='HGNC_Symbol', right_on='gene_name', how='inner')\n",
    "    # Get the corresponding proteins\n",
    "    top_k_bmgc_gene_protein_df = pd.merge(gene_transcript_protein_entity_df, top_k_bmgc_gene_df, left_on='BMGC_GN_ID', right_on='BioMedGraphica_Conn_ID', how='inner').drop(columns=['BioMedGraphica_Conn_ID', 'gene_name'])\n",
    "    top_k_bmgc_gene_protein_info_df = pd.merge(top_k_bmgc_gene_protein_df, bmgc_protein_llmnameid_combined_df, left_on='BMGC_PT_ID', right_on='BioMedGraphica_Conn_ID', how='inner').drop(columns=['BioMedGraphica_Conn_ID', 'BioMedGraphica_ID', sample_ach_name])\n",
    "    top_k_gene_protein_bmgc_id_list = top_k_bmgc_gene_protein_info_df['BMGC_PT_ID'].tolist()\n",
    "    top_k_gene_protein_bmgc_llmnameid_list = top_k_bmgc_gene_protein_info_df['Names_and_IDs'].replace(r' \\| ', ' or ', regex=True).tolist()\n",
    "    return top_k_gene_hgnc_name_list, top_k_gene_protein_bmgc_id_list, top_k_gene_protein_bmgc_llmnameid_list\n",
    "\n",
    "# Example usage\n",
    "sample_ach_name = 'ACH-000002'\n",
    "k=10\n",
    "top_k_gene_hgnc_name_list, top_k_gene_protein_bmgc_id_list, top_k_gene_protein_bmgc_llmnameid_list = extract_dti_gn_info(dti_gene_df, bmgc_gene_df, gene_transcript_protein_entity_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, k)\n",
    "print(f\"Top {k} Gene HGNC Names:\", top_k_gene_hgnc_name_list)\n",
    "print(f\"Top {k} Gene Protein BMGC IDs:\", top_k_gene_protein_bmgc_id_list)\n",
    "print(f\"Top {k} Gene Protein BMGC LLM Name IDs:\", top_k_gene_protein_bmgc_llmnameid_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# select the columns in the remaining_dti_samples for raw_transcript_df\n",
    "# Only include samples that actually exist in raw_transcript_df columns\n",
    "available_transcript_samples = [sample for sample in remaining_dti_samples if sample in raw_transcript_df.columns]\n",
    "dti_transcript_df = raw_transcript_df[['gene_name'] + sorted(available_transcript_samples)].copy()\n",
    "print(f\"Total remaining_dti_samples: {len(remaining_dti_samples)}\")\n",
    "print(f\"Available samples in raw_transcript_df: {len(available_transcript_samples)}\")\n",
    "print(f\"Missing samples: {len(remaining_dti_samples) - len(available_transcript_samples)}\")\n",
    "display(dti_transcript_df)\n",
    "\n",
    "def extract_dti_ts_info(dti_transcript_df, bmgc_transcript_df, transcript_protein_entity_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, k=10):\n",
    "    # Check if sample_ach_name exists in dti_transcript_df columns\n",
    "    if sample_ach_name not in dti_transcript_df.columns:\n",
    "        return \"non-existed\", \"non-existed\", \"non-existed\"\n",
    "    # Extract the top k highest values for the given sample name\n",
    "    top_k_transcripts = dti_transcript_df.nlargest(k, sample_ach_name)[['gene_name', sample_ach_name]]\n",
    "    # Sort the top k transcripts by their values in descending order\n",
    "    top_k_transcripts = top_k_transcripts.sort_values(by=sample_ach_name, ascending=False).reset_index(drop=True)\n",
    "    top_k_transcript_hgnc_name_list = top_k_transcripts['gene_name'].tolist()\n",
    "    # Merge with the bmgc_transcript_df to get the BioMedGraphica_Conn_ID\n",
    "    bmgc_transcript_df = bmgc_transcript_df[['BioMedGraphica_Conn_ID', 'HGNC_Symbol']].copy()\n",
    "    top_k_bmgc_transcript_df = pd.merge(bmgc_transcript_df, top_k_transcripts, left_on='HGNC_Symbol', right_on='gene_name', how='inner')\n",
    "    # Get the corresponding proteins\n",
    "    top_k_bmgc_transcript_protein_df = pd.merge(transcript_protein_entity_df, top_k_bmgc_transcript_df, left_on='BMGC_TS_ID', right_on='BioMedGraphica_Conn_ID', how='inner').drop(columns=['BioMedGraphica_Conn_ID', 'gene_name'])\n",
    "    top_k_bmgc_transcript_protein_info_df = pd.merge(top_k_bmgc_transcript_protein_df, bmgc_protein_llmnameid_combined_df, left_on='BMGC_PT_ID', right_on='BioMedGraphica_Conn_ID', how='inner').drop(columns=['BioMedGraphica_Conn_ID', 'BioMedGraphica_ID', sample_ach_name])\n",
    "    top_k_transcript_protein_bmgc_id_list = top_k_bmgc_transcript_protein_info_df['BMGC_PT_ID'].tolist()\n",
    "    top_k_transcript_protein_bmgc_llmnameid_list = top_k_bmgc_transcript_protein_info_df['Names_and_IDs'].replace(r' \\| ', ' or ', regex=True).tolist()\n",
    "    return top_k_transcript_hgnc_name_list, top_k_transcript_protein_bmgc_id_list, top_k_transcript_protein_bmgc_llmnameid_list\n",
    "\n",
    "# Example usage\n",
    "sample_ach_name = 'ACH-000002'\n",
    "k=10\n",
    "top_k_transcript_hgnc_name_list, top_k_transcript_protein_bmgc_id_list, top_k_transcript_protein_bmgc_llmnameid_list = extract_dti_ts_info(dti_transcript_df, bmgc_transcript_df, transcript_protein_entity_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, k)\n",
    "print(f\"Top {k} Transcript HGNC Names:\", top_k_transcript_hgnc_name_list)\n",
    "print(f\"Top {k} Transcript Protein BMGC IDs:\", top_k_transcript_protein_bmgc_id_list)\n",
    "print(f\"Top {k} Transcript Protein BMGC LLM Name IDs:\", top_k_transcript_protein_bmgc_llmnameid_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load protein mapping data with BioMedGraphica IDs, Uniprot IDs, and HGNC symbols\n",
    "bmg_protein_all_df = pd.read_csv('./data/BioMedGraphica-Conn/Entity/Protein/BioMedGraphica_Conn_Protein.csv')\n",
    "bmg_protein_all_df = bmg_protein_all_df[['BioMedGraphica_Conn_ID', 'Uniprot_ID', 'HGNC_Symbol']].copy()\n",
    "display(bmg_protein_all_df)\n",
    "\n",
    "# Map sample IDs in raw protein data to standardized format\n",
    "raw_protein_map_df = raw_protein_df.rename(columns=protein_map_dict)\n",
    "\n",
    "# Join protein expression data with HGNC symbol annotations\n",
    "symbol_protein_map_df = pd.merge(raw_protein_map_df, bmg_protein_all_df, left_on='Uniprot_Acc', right_on='Uniprot_ID', how='inner')\n",
    "\n",
    "# Reorganize columns: protein identifiers first, then sample expression values\n",
    "identifier_cols = ['Uniprot_ID', 'Uniprot_Acc', 'HGNC_Symbol']\n",
    "expression_cols = sorted(set(symbol_protein_map_df.columns) - set(identifier_cols))\n",
    "symbol_protein_map_df = symbol_protein_map_df[identifier_cols + expression_cols]\n",
    "\n",
    "# Only include samples that actually exist in symbol_protein_map_df columns\n",
    "available_protein_samples = [sample for sample in remaining_dti_samples if sample in symbol_protein_map_df.columns]\n",
    "print(f\"Total remaining_dti_samples: {len(remaining_dti_samples)}\")\n",
    "print(f\"Available samples in symbol_protein_map_df: {len(available_protein_samples)}\")\n",
    "print(f\"Missing samples: {len(remaining_dti_samples) - len(available_protein_samples)}\")\n",
    "\n",
    "# Extract protein data for DTI samples with HGNC symbols\n",
    "dti_protein_df = symbol_protein_map_df[['HGNC_Symbol'] + sorted(available_protein_samples)].copy()\n",
    "\n",
    "# Handle multiple HGNC symbols per protein (semicolon-separated)\n",
    "dti_protein_df = dti_protein_df.assign(\n",
    "    HGNC_Symbol=dti_protein_df['HGNC_Symbol'].str.split(';')\n",
    ").explode('HGNC_Symbol')\n",
    "\n",
    "# Clean up HGNC symbols: remove whitespace and filter out empty entries\n",
    "dti_protein_df['HGNC_Symbol'] = dti_protein_df['HGNC_Symbol'].str.strip()\n",
    "dti_protein_df = dti_protein_df[\n",
    "    dti_protein_df['HGNC_Symbol'].notna() & \n",
    "    (dti_protein_df['HGNC_Symbol'] != '')\n",
    "].reset_index(drop=True)\n",
    "\n",
    "display(dti_protein_df)\n",
    "\n",
    "def extract_dti_pt_info(dti_protein_df, bmgc_protein_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, k=10):\n",
    "    # Check if sample_ach_name exists in dti_protein_df columns\n",
    "    if sample_ach_name not in dti_protein_df.columns:\n",
    "        return \"non-existed\", \"non-existed\", \"non-existed\"\n",
    "    # Extract the top k highest values for the given sample name\n",
    "    top_k_proteins = dti_protein_df.nlargest(k, sample_ach_name)[['HGNC_Symbol', sample_ach_name]]\n",
    "    # Sort the top k proteins by their values in descending order\n",
    "    top_k_proteins = top_k_proteins.sort_values(by=sample_ach_name, ascending=False).reset_index(drop=True)\n",
    "    top_k_protein_hgnc_name_list = top_k_proteins['HGNC_Symbol'].tolist()\n",
    "    # Merge with the bmgc_protein_df to get the BioMedGraphica_Conn_ID\n",
    "    bmgc_protein_df = bmgc_protein_df[['BioMedGraphica_Conn_ID', 'HGNC_Symbol']].copy()\n",
    "    top_k_bmgc_protein_df = pd.merge(bmgc_protein_df, top_k_proteins, left_on='HGNC_Symbol', right_on='HGNC_Symbol', how='inner')\n",
    "    # Get the corresponding protein information\n",
    "    top_k_bmgc_protein_info_df = pd.merge(top_k_bmgc_protein_df, bmgc_protein_llmnameid_combined_df, left_on='BioMedGraphica_Conn_ID', right_on='BioMedGraphica_Conn_ID', how='inner').drop(columns=['BioMedGraphica_ID', sample_ach_name])\n",
    "    top_k_protein_bmgc_id_list = top_k_bmgc_protein_info_df['BioMedGraphica_Conn_ID'].tolist()\n",
    "    # Replace both \"|\" and \";\" with \" or \"\n",
    "    top_k_protein_bmgc_llmnameid_list = top_k_bmgc_protein_info_df['Names_and_IDs'].replace([r' \\| ', r';'], ' or ', regex=True).tolist()\n",
    "    return top_k_protein_hgnc_name_list, top_k_protein_bmgc_id_list, top_k_protein_bmgc_llmnameid_list\n",
    "\n",
    "# Example usage\n",
    "sample_ach_name = 'ACH-000008'\n",
    "k = 10\n",
    "top_k_protein_hgnc_name_list, top_k_protein_bmgc_id_list, top_k_protein_bmgc_llmnameid_list = extract_dti_pt_info(dti_protein_df, bmgc_protein_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, k)\n",
    "print(f\"Top {k} Protein HGNC Names:\", top_k_protein_hgnc_name_list)\n",
    "print(f\"Top {k} Protein BMGC IDs:\", top_k_protein_bmgc_id_list)\n",
    "print(f\"Top {k} Protein BMGC LLM Name IDs:\", top_k_protein_bmgc_llmnameid_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_drug_protein(selected_sample_drug_bmgc_id, edge_index, \n",
    "                    node_index_df, nodeid_index_dict, index_nodeid_dict):\n",
    "    # Extract the index based on the selected drug BMGC ID\n",
    "    sample_drug_bmgc_id_index = nodeid_index_dict[selected_sample_drug_bmgc_id]\n",
    "    # Find incoming edges (source nodes that point to the drug)\n",
    "    incoming_mask = edge_index[1, :] == sample_drug_bmgc_id_index\n",
    "    incoming_source_nodes = edge_index[0, incoming_mask]\n",
    "    # Find outgoing edges (target nodes that the drug points to)\n",
    "    outgoing_mask = edge_index[0, :] == sample_drug_bmgc_id_index\n",
    "    outgoing_target_nodes = edge_index[1, outgoing_mask]\n",
    "    # Combine all neighbor nodes (both incoming and outgoing)\n",
    "    drug_related_nodes = np.concatenate([incoming_source_nodes, outgoing_target_nodes])\n",
    "    unique_drug_related_nodes = np.unique(drug_related_nodes)\n",
    "    # Get protein node index\n",
    "    protein_node_index_df = node_index_df[node_index_df['Type'] == 'Protein']\n",
    "    protein_node_index_list = protein_node_index_df['Index'].tolist()\n",
    "    # Filter to get only protein nodes directly connected to the drug\n",
    "    drug_protein_index = sorted(\n",
    "        list(set(unique_drug_related_nodes) & set(protein_node_index_list))\n",
    "    )\n",
    "    # Map protein index to BMGC id\n",
    "    drug_protein_bmgc_id = [index_nodeid_dict[i] for i in drug_protein_index]\n",
    "    return drug_protein_index, drug_protein_bmgc_id\n",
    "\n",
    "# Example usage\n",
    "selected_sample_drug_bmgc_id = 'BMGC_DG00001'\n",
    "\n",
    "drug_protein_index, drug_protein_bmgc_id = extract_drug_protein(\n",
    "    selected_sample_drug_bmgc_id=selected_sample_drug_bmgc_id,\n",
    "    edge_index=edge_index,\n",
    "    node_index_df=nodes_index_data,\n",
    "    nodeid_index_dict=node_index_dict,\n",
    "    index_nodeid_dict=index_node_dict\n",
    ")\n",
    "\n",
    "print(\"Drug Protein Index:\", drug_protein_index)\n",
    "print(\"Drug Protein BMGC ID:\", drug_protein_bmgc_id)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Formulate the drug related protein json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from tqdm import tqdm\n",
    "\n",
    "dti_drug_ids = sorted(list(set(dti_drug_overlap_df['BMGC_Drug_ID'])))\n",
    "\n",
    "drug_kg_protein_bmgc_dict = {}\n",
    "print_count = 0\n",
    "max_prints = 5\n",
    "\n",
    "for i, drug_bmgc_id in enumerate(tqdm(dti_drug_ids, desc=\"Processing drug-protein relationships\")):\n",
    "    drug_protein_index, drug_protein_bmgc_id = extract_drug_protein(\n",
    "        selected_sample_drug_bmgc_id=drug_bmgc_id,\n",
    "        edge_index=edge_index,\n",
    "        node_index_df=nodes_index_data,\n",
    "        nodeid_index_dict=node_index_dict,\n",
    "        index_nodeid_dict=index_node_dict\n",
    "    )\n",
    "    hgnc_dict, drug_protein_hgnc_list = bmgc_pt_id_to_hgnc(drug_protein_bmgc_id, bmgc_protein_df)\n",
    "    \n",
    "    # Only print first 5 entries\n",
    "    if print_count < max_prints:\n",
    "        print(f\"Drug BMGC ID: {drug_bmgc_id}, Protein Index: {drug_protein_index}, Protein BMGC ID: {drug_protein_bmgc_id}\")\n",
    "        print_count += 1\n",
    "    elif print_count == max_prints:\n",
    "        print(\"... (remaining entries processed silently)\")\n",
    "        print_count += 1\n",
    "    \n",
    "    # Convert NumPy integers to Python integers for JSON serialization\n",
    "    drug_kg_protein_bmgc_dict[drug_bmgc_id] = {\n",
    "        'drug_protein_index': [int(x) for x in drug_protein_index],  # Convert numpy ints to Python ints\n",
    "        'drug_protein_bmgc_id': drug_protein_bmgc_id\n",
    "    }\n",
    "    \n",
    "    # Save every 10 iterations\n",
    "    if (i + 1) % 10 == 0:\n",
    "        output_path = \"./data/DrugQA/drug_kg_protein_relationships.json\"\n",
    "        with open(output_path, \"w\") as f:\n",
    "            json.dump(drug_kg_protein_bmgc_dict, f, indent=2)\n",
    "        print(f\"💾 Auto-saved after processing {i + 1}/{len(dti_drug_ids)} drugs\")\n",
    "\n",
    "# Final save after loop completion\n",
    "output_path = \"./data/DrugQA/drug_kg_protein_relationships.json\"\n",
    "with open(output_path, \"w\") as f:\n",
    "    json.dump(drug_kg_protein_bmgc_dict, f, indent=2)\n",
    "    \n",
    "print(f\"✅ Final save completed - processed {len(drug_kg_protein_bmgc_dict)} drugs and saved to: {output_path}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Formulate the sample related json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def drug_qa_sample_info(sample_ach_name, selected_sample_disease_bmgc_id, k=10):\n",
    "    # LLM Info\n",
    "    print(f\"Sample ACH Name: {sample_ach_name}\")\n",
    "    print(f\"Extracting top {k} gene information for {sample_ach_name}...\")\n",
    "    top_k_gene_hgnc_name_list, top_k_gene_protein_bmgc_id_list, top_k_gene_protein_bmgc_llmnameid_list = extract_dti_gn_info(dti_gene_df, bmgc_gene_df, gene_transcript_protein_entity_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, k)\n",
    "    print(f\"Extracting top {k} transcript information for {sample_ach_name}...\")\n",
    "    top_k_transcript_hgnc_name_list, top_k_transcript_protein_bmgc_id_list, top_k_transcript_protein_bmgc_llmnameid_list = extract_dti_ts_info(dti_transcript_df, bmgc_transcript_df, transcript_protein_entity_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, k)\n",
    "    print(f\"Extracting top {k} protein information for {sample_ach_name}...\")\n",
    "    top_k_protein_hgnc_name_list, top_k_protein_bmgc_id_list, top_k_protein_bmgc_llmnameid_list = extract_dti_pt_info(dti_protein_df, bmgc_protein_df, bmgc_protein_llmnameid_combined_df, sample_ach_name, k)\n",
    "    # KG Info\n",
    "    edge_index = np.load('./data/DrugQA/edge_index.npy')\n",
    "    print(f\"Extracting disease-related proteins index and bmgc id for {selected_sample_disease_bmgc_id} ({sample_ach_name}) ...\")\n",
    "    disease_protein_index, disease_protein_bmgc_id, ppi_nodes_index, ppi_nodes_bmgc_id = extract_kg_related_proteins(selected_sample_disease_bmgc_id, edge_index, nodes_index_data, node_index_dict, index_node_dict)\n",
    "    print(f\"Knowledge Graph Info: Found {len(disease_protein_index)} disease-related proteins directly connected to {selected_sample_disease_bmgc_id} and {len(ppi_nodes_index)} proteins in their PPI network\")\n",
    "    print(f\"Mapping disease-related proteins to HGNC symbols...\")\n",
    "    disease_protein_hgnc_dict, disease_protein_hgnc_list = bmgc_pt_id_to_hgnc(disease_protein_bmgc_id, bmgc_protein_df)\n",
    "    print(f\"Mapping PPI-related proteins to HGNC symbols...\")\n",
    "    ppi_hgnc_dict, ppi_hgnc_list = bmgc_pt_id_to_hgnc(ppi_nodes_bmgc_id, bmgc_protein_df)\n",
    "    # LLM Used KG Info\n",
    "    print(f\"Extracting protein relationships from BMGC...\")\n",
    "    # Convert the any non-existed string in top_k_gene_hgnc_name_list + top_k_transcript_hgnc_name_list + top_k_protein_hgnc_name_list + disease_protein_hgnc_list into empty list []\n",
    "    if top_k_gene_hgnc_name_list == \"non-existed\": top_k_gene_hgnc_name_list = []\n",
    "    if top_k_transcript_hgnc_name_list == \"non-existed\": top_k_transcript_hgnc_name_list = []\n",
    "    if top_k_protein_hgnc_name_list == \"non-existed\": top_k_protein_hgnc_name_list = []\n",
    "    if disease_protein_hgnc_list == \"non-existed\": disease_protein_hgnc_list = []\n",
    "    # Combine all the HGNC symbols into a single list for relationship extraction\n",
    "    omics_disease_protein_hgnc_list = list(set(top_k_gene_hgnc_name_list + top_k_transcript_hgnc_name_list + top_k_protein_hgnc_name_list + disease_protein_hgnc_list))\n",
    "    relationships_df, relationship_texts = find_protein_relationships(omics_disease_protein_hgnc_list, bmgc_protein_df, bmgc_relation_df)\n",
    "    print(f\"Knowledge Graph Info: Found {len(omics_disease_protein_hgnc_list)} unique proteins and {len(relationship_texts)} relationships between them\")\n",
    "    return top_k_gene_hgnc_name_list, top_k_gene_protein_bmgc_id_list, top_k_gene_protein_bmgc_llmnameid_list, \\\n",
    "              top_k_transcript_hgnc_name_list, top_k_transcript_protein_bmgc_id_list, top_k_transcript_protein_bmgc_llmnameid_list, \\\n",
    "                top_k_protein_hgnc_name_list, top_k_protein_bmgc_id_list, top_k_protein_bmgc_llmnameid_list, \\\n",
    "                disease_protein_index, disease_protein_bmgc_id, ppi_nodes_index, ppi_nodes_bmgc_id, \\\n",
    "                    disease_protein_hgnc_dict, disease_protein_hgnc_list, ppi_hgnc_dict, ppi_hgnc_list, relationship_texts\n",
    "\n",
    "# Example usage\n",
    "sample_ach_name = 'ACH-000002'\n",
    "selected_sample_disease_bmgc_id = 'BMGC_DS07934'\n",
    "k = 10\n",
    "return_tuples = drug_qa_sample_info(sample_ach_name, selected_sample_disease_bmgc_id, k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Parameters\n",
    "k = 10\n",
    "save_every_n = 10\n",
    "\n",
    "# Output folder and filename\n",
    "output_dir = \"./data/DrugQA\"\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "output_filename = f\"drug_qa_k{k}_bm{top_bm}.json\"\n",
    "output_path = os.path.join(output_dir, output_filename)\n",
    "\n",
    "# Load existing JSON file if it exists\n",
    "if os.path.exists(output_path):\n",
    "    with open(output_path, \"r\") as f:\n",
    "        multi_sample_qa_json = json.load(f)\n",
    "    print(f\"Loaded existing JSON with {len(multi_sample_qa_json)} processed samples\")\n",
    "else:\n",
    "    multi_sample_qa_json = {}\n",
    "    print(\"No existing JSON found, starting fresh\")\n",
    "\n",
    "# Load sample info data\n",
    "dti_sample_info_index = pd.read_csv('./data/process_data/dti_combined_samples.csv')\n",
    "dti_sample_info_index = pd.merge(dti_sample_info_index, remaining_dti_samples_df['depMapID'], how='inner', on='depMapID').reset_index(drop=True)\n",
    "dti_sample_info_index['BMGC_Disease_name'] = dti_sample_info_index['BMGC_Disease_name'].replace(r' \\| ', ' or ', regex=True)\n",
    "# Insert a new column \"Index\" in the first position\n",
    "dti_sample_info_index.insert(0, 'Index', range(1, len(dti_sample_info_index) + 1))\n",
    "display(dti_sample_info_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the list of samples that have already been processed\n",
    "processed_samples = set(multi_sample_qa_json.keys())\n",
    "print(f\"Found {len(processed_samples)} already processed samples\")\n",
    "\n",
    "count = len(processed_samples)\n",
    "total_to_process = len(dti_sample_info_index)\n",
    "remaining = total_to_process - count\n",
    "print(f\"Total samples to process: {total_to_process}, already processed: {count}, remaining: {remaining}\")\n",
    "\n",
    "# Iterate through the sample info dataframe, skipping already processed samples\n",
    "for idx, row_tuple in tqdm(enumerate(dti_sample_info_index.iterrows()), total=len(dti_sample_info_index)):\n",
    "    _, row = row_tuple  # Unpack the tuple - index and row data\n",
    "\n",
    "    sample_ach_name = row[\"depMapID\"]\n",
    "    \n",
    "    # Skip if already processed\n",
    "    if sample_ach_name in processed_samples:\n",
    "        continue\n",
    "        \n",
    "    count += 1\n",
    "    target_sample_index = row[\"Index\"]\n",
    "    cell_line_name = row[\"Name\"]\n",
    "    disease = row[\"BMGC_Disease_name\"]\n",
    "    disease_bmgc_id = row[\"BMGC_Disease_ID\"]\n",
    "\n",
    "    print(f\"Processing sample {count}/{total_to_process}: {sample_ach_name} ({cell_line_name})\")\n",
    "    print(f\"Sample Index: {target_sample_index}\")\n",
    "    print(f\"Sample Disease: {disease}\")\n",
    "    print(f\"Sample Disease BMGC ID: {disease_bmgc_id}\")\n",
    "\n",
    "    try:\n",
    "        (top_k_gene_hgnc, top_k_gene_bmgc, top_k_gene_llm,\n",
    "        top_k_ts_hgnc, top_k_ts_bmgc, top_k_ts_llm,\n",
    "        top_k_pt_hgnc, top_k_pt_bmgc, top_k_pt_llm,\n",
    "        dis_pt_idx, dis_pt_bmgc, ppi_idx, ppi_bmgc,\n",
    "        dis_pt_hgnc_dict, dis_pt_hgnc, \n",
    "        ppi_hgnc_dict, ppi_hgnc, relationship_texts) = drug_qa_sample_info(sample_ach_name, disease_bmgc_id, k)\n",
    "\n",
    "        multi_sample_qa_json[sample_ach_name] = {\n",
    "            \"cell_line_name\": cell_line_name,\n",
    "            \"sample_index\": target_sample_index,\n",
    "            \"disease\": disease,\n",
    "            \"disease_bmgc_id\": disease_bmgc_id,\n",
    "            \"input\": {\n",
    "                \"top_k_gene\": {\n",
    "                    \"hgnc_symbols\": top_k_gene_hgnc,\n",
    "                    \"protein_bmgc_ids\": top_k_gene_bmgc,\n",
    "                    \"protein_llmname_ids\": top_k_gene_llm\n",
    "                },\n",
    "                \"top_k_transcript\": {\n",
    "                    \"hgnc_symbols\": top_k_ts_hgnc,\n",
    "                    \"protein_bmgc_ids\": top_k_ts_bmgc,\n",
    "                    \"protein_llmname_ids\": top_k_ts_llm\n",
    "                },\n",
    "                \"top_k_protein\": {\n",
    "                    \"hgnc_symbols\": top_k_pt_hgnc,\n",
    "                    \"protein_bmgc_ids\": top_k_pt_bmgc,\n",
    "                    \"protein_llmname_ids\": top_k_pt_llm\n",
    "                },\n",
    "                \"knowledge_graph\": {\n",
    "                    \"disease_protein\": {\n",
    "                        \"bmgc_ids\": dis_pt_bmgc,\n",
    "                        \"hgnc_symbols\": dis_pt_hgnc,\n",
    "                        \"indices\": dis_pt_idx\n",
    "                    },\n",
    "                    \"ppi_neighbors\": {\n",
    "                        \"bmgc_ids\": ppi_bmgc,\n",
    "                        \"hgnc_symbols\": ppi_hgnc,\n",
    "                        \"indices\": ppi_idx\n",
    "                    },\n",
    "                    \"protein_relationships\": relationship_texts,\n",
    "                }\n",
    "            }\n",
    "        }\n",
    "\n",
    "    except Exception as e:\n",
    "        print(f\"⚠️ Error processing {sample_ach_name}: {e}\")\n",
    "        continue\n",
    "\n",
    "    # Periodic save every N samples\n",
    "    if count % save_every_n == 0:\n",
    "        with open(output_path, \"w\") as f:\n",
    "            json.dump(multi_sample_qa_json, f, indent=2, default=lambda o: int(o) if hasattr(o, 'item') else o)\n",
    "        print(f\"💾 Auto-saved JSON at {count}/{total_to_process} samples to: {output_path}\")\n",
    "        print(f\"Last processed sample: {sample_ach_name}\")\n",
    "        processed = len(multi_sample_qa_json)\n",
    "        remaining = total_to_process - processed\n",
    "        print(f\"Progress: {processed}/{total_to_process} ({processed/total_to_process*100:.1f}%), Remaining: {remaining}\")\n",
    "\n",
    "# Final save after loop\n",
    "with open(output_path, \"w\") as f:\n",
    "    json.dump(multi_sample_qa_json, f, indent=2, default=lambda o: int(o) if hasattr(o, 'item') else o)\n",
    "\n",
    "print(f\"✅ Final JSON saved to: {output_path}\")\n",
    "print(f\"Total samples processed: {len(multi_sample_qa_json)}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mkg",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
