{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import random\n",
    "import warnings\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "cmap = 'Blues'\n",
    "warnings.filterwarnings('ignore')\n",
    "sys.path.insert(0, '../GRAPH_Framework-main')\n",
    "from tasks.experiment import ModelTest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import networkx as nx\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Gene symbols to extract; the list order defines the gene column order of the LUAD table.\n",
    "gene = [\n",
    "    'ESR1', 'ESR2', 'NCOA1', 'NCOA3', 'FOS', 'JUN', 'SP1', 'CCND1', 'MYC', 'PGR', 'WNT1', 'WNT4',\n",
    "    'TNFSF11', 'ERBB2', 'FGF1', 'FGF2', 'FGF3', 'FGF4', 'FGF17', 'FGF6', 'FGF7', 'FGF8', 'FGF9',\n",
    "    'FGF10', 'FGF16', 'FGF5', 'FGF18', 'FGF20', 'FGF22', 'FGF19', 'FGF21', 'FGF23', 'FGFR1', 'IGF1',\n",
    "    'IGF1R', 'EGF', 'EGFR', 'KIT', 'SHC1', 'SHC2', 'SHC3', 'SHC4', 'GRB2', 'SOS1', 'SOS2', 'HRAS',\n",
    "    'KRAS', 'NRAS', 'ARAF', 'BRAF', 'RAF1', 'MAP2K1', 'MAP2K2', 'MAPK1', 'MAPK3', 'PIK3CA', 'PIK3CD',\n",
    "    'PIK3CB', 'PIK3R1', 'PIK3R2', 'PIK3R3', 'PTEN', 'AKT1', 'AKT2', 'AKT3', 'MTOR', 'RPS6KB1', 'RPS6KB2',\n",
    "    'JAG1', 'JAG2', 'DLL3', 'DLL1', 'DLL4', 'NOTCH1', 'NOTCH2', 'NOTCH3', 'NOTCH4', 'HES1', 'HES5', 'HEYL',\n",
    "    'HEY1', 'HEY2', 'FLT4', 'CDKN1A', 'NFKB2', 'WNT2', 'WNT2B', 'WNT3', 'WNT3A', 'WNT5A', 'WNT5B', 'WNT6',\n",
    "    'WNT7A', 'WNT7B', 'WNT8A', 'WNT8B', 'WNT9A', 'WNT9B', 'WNT10B', 'WNT10A', 'WNT11', 'WNT16', 'FZD1',\n",
    "    'FZD7', 'FZD2', 'FZD3', 'FZD4', 'FZD5', 'FZD8', 'FZD6', 'FZD10', 'FZD9', 'LRP5', 'LRP6', 'DVL3',\n",
    "    'DVL2', 'DVL1', 'FRAT1', 'FRAT2', 'GSK3B', 'AXIN1', 'AXIN2', 'APC', 'APC2', 'CTNNB1', 'CSNK1A1L',\n",
    "    'CSNK1A1', 'TCF7', 'TCF7L1', 'TCF7L2', 'LEF1', 'TP53', 'GADD45A', 'GADD45B', 'GADD45G', 'BAX',\n",
    "    'BAK1', 'DDB2', 'POLK', 'CDK4', 'CDK6', 'RB1', 'E2F1', 'E2F2', 'E2F3', 'BRCA1', 'BRCA2',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download manifest: supplies the 'id' and 'filename' columns used further down.\n",
    "file_path = '../Source_Data/TCGA/LUAD/gdc_download/MANIFEST.txt'\n",
    "dict_df = pd.read_csv(file_path, sep='\\t')\n",
    "\n",
    "# Sample sheet, restricted to rows whose 'Sample Type' is 'Primary Tumor'.\n",
    "file_path = '../Source_Data/TCGA/LUAD/gdc_sample_sheet.tsv'\n",
    "samp_df = pd.read_csv(file_path, sep='\\t').loc[lambda sheet: sheet['Sample Type'] == 'Primary Tumor']\n",
    "\n",
    "# Clinical table: keep only the demographic columns and collapse repeated rows.\n",
    "demo_coln = ['case_submitter_id', 'age_at_index', 'ethnicity', 'gender', 'race']\n",
    "file_path = '../Source_Data/TCGA/LUAD/clinical.cart/clinical.tsv'\n",
    "demo_df = pd.read_csv(file_path, sep='\\t')[demo_coln].drop_duplicates()\n",
    "\n",
    "# Inner joins: manifest 'id' -> sample sheet 'File ID', then 'Case ID' -> 'case_submitter_id'.\n",
    "info = pd.merge(dict_df, samp_df, how='inner', left_on='id', right_on='File ID')\n",
    "info = pd.merge(info, demo_df, how='inner', left_on='Case ID', right_on='case_submitter_id')\n",
    "\n",
    "# Empty result table indexed by 'id': demographic columns followed by one column per gene.\n",
    "LUAD = pd.DataFrame(columns=['id'] + demo_coln[1:] + gene).set_index('id')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fill LUAD with one row per file: four demographic values followed by one TPM value per gene.\n",
    "temp_coln = ['id', 'age_at_index', 'ethnicity', 'gender', 'race', 'filename']\n",
    "for _, row in info[temp_coln].iterrows():\n",
    "    file_id = row['id']  # named file_id so the builtin id() is not shadowed for later cells\n",
    "    temp_info = [row[col] for col in temp_coln[1:-1]]\n",
    "    file_path = '../Source_Data/TCGA/LUAD/gdc_download/' + row['filename']\n",
    "    # skiprows=1 drops the first line of the file before the header row is parsed.\n",
    "    temp_sample_df = pd.read_csv(file_path, sep='\\t', skiprows=1)[['gene_name', 'tpm_unstranded']]\n",
    "    selected_rows = temp_sample_df[temp_sample_df['gene_name'].isin(gene)]\n",
    "    # isin() keeps the rows in file order, which need not match the order of `gene`.\n",
    "    # Look each value up by gene name so it lands under its own LUAD column; a repeated\n",
    "    # gene_name keeps its first row and a gene absent from the file becomes NaN.\n",
    "    tpm_by_gene = selected_rows.drop_duplicates('gene_name').set_index('gene_name')['tpm_unstranded']\n",
    "    temp_gene = tpm_by_gene.reindex(gene).to_list()\n",
    "    LUAD.loc[file_id] = temp_info + temp_gene\n",
    "\n",
    "# LUAD.to_csv('../Source_Data/TCGA/LUAD_GENE.csv')\n",
    "# LUAD = pd.read_csv('../Source_Data/TCGA/LUAD_GENE.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the table on the 'gender' column.\n",
    "male = LUAD[LUAD['gender'] == 'male']\n",
    "female = LUAD[LUAD['gender'] == 'female']\n",
    "\n",
    "# Gene columns only, with any row containing a missing value removed.\n",
    "luad = LUAD[gene].dropna().to_numpy()\n",
    "luad1 = male[gene].dropna().to_numpy()\n",
    "luad2 = female[gene].dropna().to_numpy()\n",
    "\n",
    "# The scaler is re-fitted on each matrix, so every matrix is standardised with its own statistics.\n",
    "scaler = StandardScaler()\n",
    "D = scaler.fit_transform(luad)\n",
    "D1, D2 = (scaler.fit_transform(block) for block in (luad1, luad2))\n",
    "D_lt = [D1, D2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standardise the pooled matrix and each group matrix again, re-fitting the scaler every time.\n",
    "scaler = StandardScaler()\n",
    "D = scaler.fit_transform(D)\n",
    "for i, group_matrix in enumerate(D_lt):\n",
    "    D_lt[i] = scaler.fit_transform(group_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings dict passed unchanged to the ModelTest calls below.\n",
    "parameters = {\n",
    "    'max_iter': 2000,\n",
    "    'step_size': 1e-4,\n",
    "    'lam': 1e-1,\n",
    "    'lamm': 1e-1,\n",
    "    'rhom': 1,\n",
    "    'tol': 1e-3,\n",
    "}\n",
    "\n",
    "# Run the 'Precision' model on the pooled matrix D and the per-group list D_lt.\n",
    "precision_test = ModelTest(model_type='Precision', showfig=False)\n",
    "precision_test.group_graph(D, D_lt, parameters)\n",
    "\n",
    "precision_test.runtime(1, D, D_lt, parameters)\n",
    "precision_test.summary()\n",
    "precision_test.plot()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Graph_Learning",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
