{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7637ef6c-2441-4133-91bd-7a96978585c1",
   "metadata": {},
   "source": [
    "# Benchmarking of WHI RCT and OS with selection bias"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c05c57be",
   "metadata": {},
   "source": [
    "We include all patients who were not selected; they are assigned S = 0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "133b8401",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd \n",
    "import numpy as np \n",
    "import os \n",
    "import sys \n",
    "from tqdm import tqdm\n",
    "from sklearn.impute import SimpleImputer\n",
    "from sklearn.model_selection import train_test_split\n",
    "from scipy.stats import zscore"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9810e77e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the outcome, CT adherence and demographics/treatment tables\n",
    "dir_path = '/data'\n",
    "csv_dir = os.path.join(dir_path, 'whi/data/main_study/csv')\n",
    "\n",
    "out = pd.read_csv(os.path.join(csv_dir, 'outc_adj_bio.csv'))\n",
    "\n",
    "ct_fu_cols = ['ID', 'ADHRATE', 'ENDDY', 'STARTDY', 'LOST', 'STOPHRT']\n",
    "ct_fu = pd.read_csv(os.path.join(csv_dir, 'adh_ht_pub.csv'))[ct_fu_cols]\n",
    "\n",
    "std_trt_cols = ['ID', 'HRTARM', 'OSFLAG']\n",
    "std_trt = pd.read_csv(os.path.join(csv_dir, 'dem_ctos_bio.csv'))[std_trt_cols]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "45371f73",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Outcomes combined into the global index (GLBL_E / GLBL_DY)\n",
    "glbl_list = [\n",
    "    'CHD', 'BREAST', 'STROKE', 'PE',\n",
    "    'ENDMTRL', 'COLORECTAL', 'BKHIP', 'DEATH',\n",
    "]\n",
    "\n",
    "# Additional outcomes processed alongside, but not part of the global index\n",
    "other_list = ['PTCA', 'DVT']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1977db63",
   "metadata": {},
   "outputs": [],
   "source": [
    "# End of follow-up for CT patients, from the adherence records (adh_ht_pub)\n",
    "# TODO: should STARTDY be taken into account? should LOST also be used for censoring?\n",
    "\n",
    "# only records with a recorded adherence rate are used\n",
    "adherence = ct_fu[ct_fu['ADHRATE'].notna()]\n",
    "\n",
    "\n",
    "def get_lost_day(group):\n",
    "    \"\"\"ENDDY of the first record flagged LOST == 'Yes' in a patient's records, else None.\"\"\"\n",
    "    lost_end_days = group.loc[group['LOST'] == 'Yes', 'ENDDY']\n",
    "    if lost_end_days.empty:\n",
    "        return None\n",
    "    return lost_end_days.iloc[0]\n",
    "\n",
    "\n",
    "# one row per patient: last ENDDY (END_DY) and whether any record has LOST == 'Yes' (1/0)\n",
    "ct_end = (adherence\n",
    "          .groupby('ID')\n",
    "          .agg(END_DY=('ENDDY', 'max'),\n",
    "               LOST=('LOST', lambda s: 1 if 'Yes' in s.values else 0))\n",
    "          .reset_index())\n",
    "\n",
    "# LOST_DY: ENDDY of the record where the patient was flagged as lost (missing otherwise)\n",
    "lost_days = (adherence\n",
    "             .groupby('ID')[['LOST', 'ENDDY']]\n",
    "             .apply(get_lost_day)\n",
    "             .rename('LOST_DY'))\n",
    "\n",
    "ct_end = ct_end.merge(lost_days.to_frame(), on='ID', how='left')\n",
    "ct_end\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "784a8de1-4277-42b4-8530-9ab00e1c96a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CT patients with any adherence record flagged LOST == 'Yes'\n",
    "ct_end.query('LOST == 1').shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce9819a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CT cohort: one row per patient, restricted to the E+P trial arms\n",
    "ep_arms = {'E+P intervention': 1, 'E+P control': 0}\n",
    "ct_df = (std_trt.drop_duplicates('ID')\n",
    "         .loc[lambda d: d['HRTARM'].isin(list(ep_arms))]\n",
    "         .merge(ct_end, on='ID', how='left')\n",
    "         .merge(out, on='ID', how='left'))\n",
    "\n",
    "# OS = 0 marks trial patients; HRTARM coded as 1 = intervention, 0 = control\n",
    "ct_df['OS'] = 0\n",
    "ct_df['HRTARM'] = ct_df['HRTARM'].map(ep_arms)\n",
    "\n",
    "# sizes: all CT patients, intervention arm, control arm; then the first 10 rows\n",
    "print(ct_df.shape)\n",
    "for arm in (1, 0):\n",
    "    print(ct_df[ct_df['HRTARM'] == arm].shape)\n",
    "ct_df.head(n=10)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "cd107864",
   "metadata": {},
   "outputs": [],
   "source": [
    "diff_selection_for_CT = False\n",
    "\n",
    "outcome_names = glbl_list + other_list\n",
    "\n",
    "# Per outcome: event indicator (_E), event-or-censoring day (_DY), event-only day (_EDY).\n",
    "# Events recorded after END_DY are treated as censored at END_DY.\n",
    "for name in outcome_names:\n",
    "    event_day = ct_df[name + 'DY']\n",
    "    has_event = (ct_df[name] == 1) & (event_day <= ct_df['END_DY'])\n",
    "    ct_df[name + '_E'] = has_event.astype(int)\n",
    "    ct_df[name + '_DY'] = np.where(has_event, event_day, ct_df['END_DY'])\n",
    "    ct_df[name + '_EDY'] = np.where(has_event, event_day, np.nan)\n",
    "\n",
    "# Global index: any global-index event, dated by the earliest such event\n",
    "# (otherwise the smallest _DY of the global-index outcomes)\n",
    "glbl_any_event = ct_df[[name + '_E' for name in glbl_list]].any(axis=1)\n",
    "ct_df['GLBL_E'] = glbl_any_event.astype(int)\n",
    "first_event_day = ct_df[[name + '_EDY' for name in glbl_list]].min(axis=1)\n",
    "first_day = ct_df[[name + '_DY' for name in glbl_list]].min(axis=1)\n",
    "ct_df['GLBL_DY'] = np.where(ct_df['GLBL_E'] == 1, first_event_day, first_day)\n",
    "\n",
    "# Selection variable: every CT patient starts as selected\n",
    "ct_df['S'] = 1\n",
    "\n",
    "# Per-outcome selection; with diff_selection_for_CT, patients whose outcome day lies after\n",
    "# END_DY, or after LOST_DY for lost patients, get S = 0\n",
    "for name in outcome_names:\n",
    "    sel_col = 'S_' + name\n",
    "    ct_df[sel_col] = ct_df['S']\n",
    "    if diff_selection_for_CT:\n",
    "        after_end = ct_df[name + 'DY'] > ct_df['END_DY']\n",
    "        after_lost = (ct_df['LOST'] == 1) & (ct_df[name + 'DY'] > ct_df['LOST_DY'])\n",
    "        ct_df[sel_col] = np.where(after_end | after_lost, 0, ct_df[sel_col])\n",
    "ct_df['S_GLBL'] = ct_df['S']\n",
    "\n",
    "# Keep identifiers, selection flags, event indicators and event/censoring days\n",
    "kept_outcomes = outcome_names + ['GLBL']\n",
    "ct_df = ct_df[['ID', 'OS', 'HRTARM']\n",
    "              + ['S_' + name for name in kept_outcomes]\n",
    "              + [name + '_E' for name in kept_outcomes]\n",
    "              + [name + '_DY' for name in kept_outcomes]]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a906d59b-699b-4d99-b959-e715feac45b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CT patients not selected for the stroke outcome\n",
    "ct_df.query('S_STROKE == 0').shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cb9b1af",
   "metadata": {},
   "outputs": [],
   "source": [
    "# stroke events in the CT control arm\n",
    "ct_df[(ct_df['HRTARM'] == 0) & (ct_df['STROKE_E'] == 1)]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5650268-d445-4ed2-a473-13725b8e986e",
   "metadata": {},
   "source": [
    "## OS Specification (SELECTION FLAG + CENSORED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "769eecc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "dir_path = '/data'\n",
    "csv_dir = os.path.join(dir_path, 'whi/data/main_study/csv')\n",
    "\n",
    "def read_cols(filename, cols):\n",
    "    \"\"\"Read one CSV from csv_dir and keep only `cols`, in that order.\"\"\"\n",
    "    return pd.read_csv(os.path.join(csv_dir, filename))[cols]\n",
    "\n",
    "hyst = read_cols('f2_ctos_bio.csv', ['ID', 'HYST'])\n",
    "pre_hrt = read_cols('f43_ctos_bio.csv', ['ID', 'TOTESTAT', 'TOTPSTAT'])\n",
    "post_hrt = read_cols('f48_av1_os_pub.csv', ['ID', 'ELSTYR', 'PLSTYR', 'HRTCMBP'])\n",
    "unc_hf = read_cols('unc_hf_bio.csv', ['ID', 'CHDYRHX', 'CHDEVERHX', 'HYPERTNHX', 'MIHX', 'PVDHX', 'DIABHX', 'STROKEHX'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58ee9f60",
   "metadata": {},
   "outputs": [],
   "source": [
    "# construct os_df\n",
    "\n",
    "# selection_flag: how OS treatment (HRTARM) and selection are defined\n",
    "#   'biased': HRTARM = 1 iff TOTPSTAT == 'Current user'\n",
    "#   'unbiased' / 'manually_biased': HRTARM from ELSTYR / PLSTYR / HRTCMBP, and\n",
    "#       patients with TOTPSTAT == 'Current user' get S = 0\n",
    "selection_flag = 'biased'\n",
    "\n",
    "# additional_selection_processing:\n",
    "#   drop_all_excluded: drops all patients who had hysterectomy OR are on unopposed estrogen;\n",
    "#       thus, selection (S = 0 and S = 1) is based on censoring only\n",
    "#   drop_some_excluded: keeps patients who had hysterectomy OR are on unopposed estrogen but were\n",
    "#       past users of combined HRT, and assigns them S = 0; censored patients are additionally S = 0\n",
    "#   drop_no_excluded: keeps all patients who had hysterectomy OR are on unopposed estrogen, and\n",
    "#       assigns them S = 0; censored patients are additionally S = 0\n",
    "additional_selection_processing = 'drop_all_excluded' # 'drop_some_excluded', 'drop_no_excluded', 'drop_all_excluded'\n",
    "\n",
    "# if censored_patients_sel0 = True, then censored patients are additionally S = 0\n",
    "censored_patients_sel0 = True\n",
    "\n",
    "# fail fast on a typo instead of silently leaving S / HRTARM undefined\n",
    "if selection_flag not in ('biased', 'unbiased', 'manually_biased'):\n",
    "    raise ValueError(f'unknown selection_flag: {selection_flag!r}')\n",
    "if additional_selection_processing not in ('drop_all_excluded', 'drop_some_excluded', 'drop_no_excluded'):\n",
    "    raise ValueError(f'unknown additional_selection_processing: {additional_selection_processing!r}')\n",
    "\n",
    "os_df = std_trt.drop_duplicates('ID')\n",
    "os_df = os_df[os_df['OSFLAG'] == 'Yes']\n",
    "os_df = os_df.merge(hyst, on='ID', how='left')\n",
    "os_df = os_df.merge(pre_hrt, on='ID', how='left')\n",
    "print(os_df['TOTESTAT'].value_counts())\n",
    "print(os_df['HYST'].value_counts())\n",
    "os_df = os_df.merge(post_hrt, on='ID', how='left')\n",
    "\n",
    "if additional_selection_processing == 'drop_some_excluded':\n",
    "    os_df = os_df.merge(unc_hf, on='ID', how='left')\n",
    "    # TODO: the history columns merged from unc_hf (CHDEVERHX, HYPERTNHX, MIHX, PVDHX, DIABHX,\n",
    "    # STROKEHX) are not used by the exclusion rule below\n",
    "    excluded = (os_df['HYST'] == 'Yes') | (os_df['TOTESTAT'] == 'Current user')\n",
    "    totp_never_or_current = os_df['TOTPSTAT'].isin(['Never used', 'Current user'])\n",
    "    # excluded patients are kept only when TOTPSTAT is neither 'Never used' nor 'Current user'\n",
    "    os_df = os_df[~(excluded & totp_never_or_current)].copy()\n",
    "    os_df['S'] = np.where((os_df['HYST'] == 'Yes') | (os_df['TOTESTAT'] == 'Current user'), 0, 1)\n",
    "elif additional_selection_processing == 'drop_no_excluded':\n",
    "    # Selected patients: no hysterectomy and TOTESTAT 'Never used' / 'Past user'\n",
    "    os_df['S'] = np.where((os_df['HYST'] == 'No') & os_df['TOTESTAT'].isin(['Never used', 'Past user']), 1, 0)\n",
    "else:  # 'drop_all_excluded'\n",
    "    os_df = os_df[(os_df['HYST'] == 'No') & os_df['TOTESTAT'].isin(['Never used', 'Past user'])].copy()\n",
    "    os_df['S'] = 1\n",
    "\n",
    "os_df = os_df.merge(out, on='ID', how='left')\n",
    "\n",
    "# 35551 (control) + 17503 (intervention) = 53054\n",
    "\n",
    "print(os_df[os_df['TOTPSTAT'].isin(['Current user'])].shape)\n",
    "print(os_df[os_df['TOTPSTAT'].isin(['Never used', 'Past user'])].shape)\n",
    "\n",
    "# every selection_flag keeps only patients with a known TOTPSTAT\n",
    "os_df = os_df[os_df['TOTPSTAT'].isin(['Never used', 'Past user', 'Current user'])].copy()\n",
    "if selection_flag == 'biased':\n",
    "    os_df['HRTARM'] = os_df['TOTPSTAT'].map({'Current user': 1, 'Never used': 0, 'Past user': 0})\n",
    "else:  # 'unbiased' or 'manually_biased'\n",
    "    conditions = [\n",
    "        (((os_df['ELSTYR'] == 'Yes') & (os_df['PLSTYR'] == 'Yes')) | (os_df['HRTCMBP'] == 'Yes')),\n",
    "        ((os_df['ELSTYR'] == 'No') & (os_df['PLSTYR'] == 'No')),\n",
    "        (((os_df['ELSTYR'] == 'Yes') & (os_df['PLSTYR'] == 'No')) | ((os_df['ELSTYR'] == 'No') & (os_df['PLSTYR'] == 'Yes')))\n",
    "    ]\n",
    "    choices = [1, 0, -1]\n",
    "    # -2: none of the patterns above matched (presumably missing answers); such patients are dropped\n",
    "    os_df['HRTGRP'] = np.select(conditions, choices, default=-2)\n",
    "    os_df = os_df[os_df['HRTGRP'] != -2].copy()\n",
    "    os_df['HRTARM'] = (os_df['HRTGRP'] == 1).astype(int)\n",
    "    os_df['S'] = np.where(os_df['TOTPSTAT'] == 'Current user', 0, os_df['S'])\n",
    "os_df['OS'] = 1\n",
    "\n",
    "# os_end_day = None uses each patient's ENDFOLLOWDY instead of a fixed horizon\n",
    "os_end_day = 6*365\n",
    "os_df['END_DY'] = os_end_day if os_end_day is not None else os_df['ENDFOLLOWDY']\n",
    "\n",
    "# Process outcomes (same as CT)\n",
    "for i in glbl_list + other_list:\n",
    "    os_df[i+'_E'] = ((os_df[i] == 1) & (os_df[i+'DY'] <= os_df['END_DY'])).astype(int)\n",
    "    os_df[i+'_DY'] = np.where(os_df[i+'_E'] == 1, os_df[i+'DY'], os_df['END_DY'])\n",
    "    os_df[i+'_EDY'] = np.where(os_df[i+'_E'] == 1, os_df[i+'DY'], np.nan)\n",
    "\n",
    "# Global index\n",
    "os_df['GLBL_E'] = (os_df[[j+'_E' for j in glbl_list]].sum(axis=1) > 0).astype(int)\n",
    "os_df['GLBL_DY'] = np.where(os_df['GLBL_E'] == 1,\n",
    "                            os_df[[j+'_EDY' for j in glbl_list]].min(axis=1),\n",
    "                            os_df[[j+'_DY' for j in glbl_list]].min(axis=1))\n",
    "\n",
    "# Selection variable adjustment\n",
    "for i in glbl_list + other_list:\n",
    "    os_df['S_'+i] = os_df['S']\n",
    "    if censored_patients_sel0:\n",
    "        # NOTE(review): sets S = 0 where the outcome day i+'DY' lies beyond END_DY;\n",
    "        # confirm this matches the intended definition of 'censored'\n",
    "        os_df['S_'+i] = np.where(os_df[i+'DY'] > os_df['END_DY'], 0, os_df['S_'+i])\n",
    "os_df['S_GLBL'] = os_df['S']\n",
    "\n",
    "# Select needed columns\n",
    "os_df = os_df[['ID', 'OS', 'HRTARM'] +\n",
    "                ['S_'+j for j in glbl_list + other_list + ['GLBL']] +\n",
    "                [j+'_E' for j in glbl_list + other_list + ['GLBL']] +\n",
    "                [j+'_DY' for j in glbl_list + other_list + ['GLBL']]]\n",
    "\n",
    "os_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "583fc12d-95d0-4eeb-b491-0878dcf5f07e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# OS patients not selected (S_CHD == 0), then selected (S_CHD == 1), for CHD\n",
    "for s_value in (0, 1):\n",
    "    print(os_df[os_df['S_CHD'] == s_value].shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f08b4a50-e9fa-4458-ab3e-0b68646f1564",
   "metadata": {},
   "source": [
    "# Setup of dataframes with target variables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "f1eae3cd-659d-4456-a2a7-9e8698dffa75",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the CT (OS = 0) and OS (OS = 1) cohorts into one frame with a fresh RangeIndex\n",
    "ctos_df = pd.concat([ct_df, os_df]).reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "5aff6186-cbc9-47e5-a9b3-4ac60f109772",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline covariates to load, per source file: {column: is_categorical}.\n",
    "# 'manually_biased' deliberately leaves out age (AGE) and menopausal status (MENO).\n",
    "keep_age_meno = selection_flag != 'manually_biased'\n",
    "\n",
    "dem_cols = {'AGE': False} if keep_age_meno else {}\n",
    "dem_cols.update({'ETHNIC': True, 'EDUC': True})\n",
    "\n",
    "categorical_features = {}\n",
    "categorical_features['dem_ctos_bio.csv'] = dem_cols\n",
    "categorical_features['f80_ctos_bio.csv'] = {'BMI': False}\n",
    "categorical_features['f34_ctos_bio.csv'] = {'SMOKING': True}\n",
    "if keep_age_meno:\n",
    "    categorical_features['f31_ctos_bio.csv'] = {'MENO': False}\n",
    "categorical_features['f151_ctos_bio.csv'] = {'PHYSFUN': False}\n",
    "\n",
    "# Columns kept from each file after one-hot encoding\n",
    "new_feature_dict = {}\n",
    "new_feature_dict['dem_ctos_bio.csv'] = (['AGE'] if keep_age_meno else []) + [\n",
    "    'ETHNIC_White',\n",
    "    'EDUC_Some post-graduate or professional',\n",
    "    'EDUC_Some college or Associate Degree',\n",
    "]\n",
    "new_feature_dict['f80_ctos_bio.csv'] = ['BMI']\n",
    "new_feature_dict['f34_ctos_bio.csv'] = ['SMOKING_Past Smoker', 'SMOKING_Current Smoker']\n",
    "if keep_age_meno:\n",
    "    new_feature_dict['f31_ctos_bio.csv'] = ['MENO']\n",
    "new_feature_dict['f151_ctos_bio.csv'] = ['PHYSFUN']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42e14324-7076-49f8-a4f2-dec9c65cd71e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas.api.types as ptypes\n",
    "\n",
    "ctos_temp = ctos_df.copy()\n",
    "# Covariates are merged file by file (see categorical_features / new_feature_dict):\n",
    "# missing values are imputed (mean for continuous, most frequent for categorical),\n",
    "# categoricals are one-hot encoded, and only the listed columns are kept.\n",
    "# NOTE(review): imputers are fitted on CT + OS together, before any train/val split.\n",
    "new_dir_path = os.path.join(dir_path, 'whi/data/main_study/csv')\n",
    "\n",
    "for filename, f_dict in categorical_features.items():\n",
    "    # Read the data\n",
    "    df = pd.read_csv(os.path.join(new_dir_path, filename))\n",
    "    if filename == 'f80_ctos_bio.csv':\n",
    "        # screening visit only\n",
    "        df = df.query('F80VTYP == \"Screening\"')\n",
    "    elif filename == 'f151_ctos_bio.csv':\n",
    "        # earliest form (smallest F151DAYS) per patient; the stable sort keeps the first row on\n",
    "        # ties like idxmin did, but does not fail when all of a patient's F151DAYS are missing\n",
    "        df = (df.sort_values('F151DAYS', kind='stable')\n",
    "                .drop_duplicates('ID')\n",
    "                .reset_index(drop=True)[['ID', 'PHYSFUN']])\n",
    "    # Select needed columns\n",
    "    features = list(f_dict.keys())\n",
    "    # one row per patient so the left merge below cannot duplicate patients\n",
    "    # (std_trt, read from dem_ctos_bio.csv, is de-duplicated on ID the same way)\n",
    "    df = df[['ID'] + features].drop_duplicates('ID')\n",
    "    print(f\"Processed {filename}\")\n",
    "    print(df.shape)\n",
    "\n",
    "    orig_cols = ctos_temp.columns.tolist()\n",
    "    ctos_temp = ctos_temp.merge(df, on='ID', how='left')\n",
    "\n",
    "    # Handle continuous and categorical features separately\n",
    "    cont_features = [f for f in features if not f_dict[f]]\n",
    "    cat_features = [f for f in features if f_dict[f]]\n",
    "\n",
    "    # Handle continuous features\n",
    "    if cont_features:\n",
    "        cont_imputer = SimpleImputer(missing_values=np.nan, strategy='mean')\n",
    "        ctos_temp[cont_features] = cont_imputer.fit_transform(ctos_temp[cont_features])\n",
    "\n",
    "    # Handle categorical features\n",
    "    if cat_features:\n",
    "        cat_imputer = SimpleImputer(missing_values=np.nan, strategy='most_frequent')\n",
    "        ctos_temp[cat_features] = cat_imputer.fit_transform(ctos_temp[cat_features])\n",
    "\n",
    "        # One-hot encode categorical features\n",
    "        ctos_temp = pd.get_dummies(ctos_temp, columns=cat_features, prefix=cat_features)\n",
    "\n",
    "    if filename == 'dem_ctos_bio.csv':\n",
    "        ctos_temp = ctos_temp.rename(columns={'ETHNIC_White (not of Hispanic origin)': 'ETHNIC_White'})\n",
    "\n",
    "    # keep the columns present before this file plus its selected (encoded) covariates\n",
    "    ctos_temp = ctos_temp[orig_cols + new_feature_dict[filename]]\n",
    "\n",
    "# boolean indicator columns (e.g. from get_dummies) are stored as 0/1 integers\n",
    "ctos_temp = ctos_temp.astype({col: int for col in ctos_temp.select_dtypes(include='bool').columns})\n",
    "display(ctos_temp)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "029a6cd8-919a-4d62-92e1-3a4e7f8edac1",
   "metadata": {},
   "source": [
    "## Quick Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6e8f402-eaf2-4d28-adf2-4be03df21187",
   "metadata": {},
   "outputs": [],
   "source": [
    "# hazard ratios for stroke, breast cancer, and CHD in clinical trial vs observational study\n",
    "from lifelines import CoxPHFitter\n",
    "\n",
    "\n",
    "def get_hr(df, Y, E, event_name, HR_cov='HRTARM', study_type='Clinical Trial'):\n",
    "    \"\"\"Fit a Cox proportional-hazards model and report the hazard ratio of `HR_cov`.\n",
    "\n",
    "    df: duration column `Y`, event column `E`; every other column is a covariate.\n",
    "    Prints the lifelines summary and the HR with its 95% CI, and returns\n",
    "    (hazard ratio, lower bound, upper bound) so callers can reuse the numbers.\n",
    "    \"\"\"\n",
    "    cph = CoxPHFitter()\n",
    "    cph.fit(df, duration_col=Y, event_col=E)\n",
    "    cph.print_summary()\n",
    "    cHR = cph.hazard_ratios_[HR_cov]\n",
    "    # confidence_intervals_ are on the coefficient (log-hazard) scale, hence np.exp\n",
    "    cis = cph.confidence_intervals_\n",
    "    lower = np.exp(cis['95% lower-bound'][HR_cov])\n",
    "    upper = np.exp(cis['95% upper-bound'][HR_cov])\n",
    "    print(f'Hazard ratio for {event_name} in {study_type}: {np.round(cHR, 2)} (95% CI: {np.round(lower, 2)}, {np.round(upper, 2)})')\n",
    "    return cHR, lower, upper\n",
    "\n",
    "\n",
    "## CT (selected patients only); a separate name keeps the ct_df built above intact\n",
    "ct_sel = ctos_temp.query('OS == 0 & S_GLBL == 1')\n",
    "for outcome_key, label in [('CHD', 'CHD'), ('STROKE', 'Stroke'), ('BREAST', 'Breast Cancer')]:\n",
    "    event_col, day_col = f'{outcome_key}_E', f'{outcome_key}_DY'\n",
    "    ct_outcome = ct_sel[['HRTARM', event_col, day_col]]\n",
    "    ct_outcome = ct_outcome[ct_outcome[day_col].notna()]\n",
    "    get_hr(ct_outcome, day_col, event_col, label)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6443f28-3de7-4e5f-b423-38aa45116433",
   "metadata": {},
   "outputs": [],
   "source": [
    "# OS: CHD hazard ratio among selected OS patients, adjusted for baseline covariates\n",
    "features = ['AGE', 'ETHNIC_White', 'EDUC_Some post-graduate or professional',\n",
    "            'EDUC_Some college or Associate Degree', 'BMI', 'SMOKING_Past Smoker',\n",
    "            'SMOKING_Current Smoker', 'MENO', 'PHYSFUN']\n",
    "# AGE and MENO are not built when selection_flag == 'manually_biased'\n",
    "features = [f for f in features if f in ctos_temp.columns]\n",
    "treatment = ['HRTARM']\n",
    "events = ['CHD_E', 'CHD_DY']\n",
    "event_name = 'CHD'\n",
    "\n",
    "# a separate name keeps the os_df built above intact\n",
    "os_sel = ctos_temp.query('OS == 1 & S_CHD == 1')\n",
    "os_df_sub = os_sel[features + treatment + events]\n",
    "os_df_sub = os_df_sub[os_df_sub[events[1]].notna()]\n",
    "\n",
    "get_hr(os_df_sub, events[1], events[0], event_name, HR_cov='HRTARM', study_type='Observational Study')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c3b01f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# OS: stroke hazard ratio among selected OS patients, adjusted for baseline covariates\n",
    "features = ['AGE', 'ETHNIC_White', 'EDUC_Some post-graduate or professional',\n",
    "            'EDUC_Some college or Associate Degree', 'BMI', 'SMOKING_Past Smoker',\n",
    "            'SMOKING_Current Smoker', 'MENO', 'PHYSFUN']\n",
    "# AGE and MENO are not built when selection_flag == 'manually_biased'\n",
    "features = [f for f in features if f in ctos_temp.columns]\n",
    "treatment = ['HRTARM']\n",
    "events = ['STROKE_E', 'STROKE_DY']\n",
    "event_name = 'STROKE'\n",
    "\n",
    "# a separate name keeps the os_df built above intact\n",
    "os_sel = ctos_temp.query('OS == 1 & S_STROKE == 1')\n",
    "os_df_sub = os_sel[features + treatment + events]\n",
    "os_df_sub = os_df_sub[os_df_sub[events[1]].notna()]\n",
    "\n",
    "get_hr(os_df_sub, events[1], events[0], event_name, HR_cov='HRTARM', study_type='Observational Study')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2550e75-28e5-46a9-8a23-ed1ba7f1123f",
   "metadata": {},
   "source": [
    "## Adding target variables (CHD/STROKE + LR/RF)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "626aa794-f554-432f-919f-5069c6622646",
   "metadata": {},
   "outputs": [],
   "source": [
    "# columns available for the analysis roles: A (treatment), Y (outcome), R (trial indicator),\n",
    "# S (selection), X (covariates)\n",
    "ctos_temp.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "10b8d739-0108-4cc3-a9bd-987cc4881fb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "if selection_flag == 'manually_biased':\n",
    "    # AGE and MENO are not built in this setting\n",
    "    predictors = ['ETHNIC_White', 'EDUC_Some post-graduate or professional',\n",
    "                  'EDUC_Some college or Associate Degree', 'BMI', 'SMOKING_Past Smoker',\n",
    "                  'SMOKING_Current Smoker', 'PHYSFUN']\n",
    "else:\n",
    "    predictors = ['AGE', 'ETHNIC_White', 'EDUC_Some post-graduate or professional',\n",
    "                  'EDUC_Some college or Associate Degree', 'BMI', 'SMOKING_Past Smoker',\n",
    "                  'SMOKING_Current Smoker', 'MENO', 'PHYSFUN']\n",
    "outcome_name = 'CHD'  # STROKE, BREAST\n",
    "\n",
    "outcome = outcome_name + '_E'\n",
    "trt = 'HRTARM'\n",
    "select = f'S_{outcome_name}'\n",
    "\n",
    "# Analysis frame: A = treatment, Y = outcome event, S = selection flag for this outcome,\n",
    "# Y0 = Y1 = Y, R = 1 for trial (CT) rows; all other non-predictor columns are dropped.\n",
    "kept_columns = set(predictors + [outcome, trt, 'ID', 'S'])\n",
    "drop_columns = [col for col in ctos_temp.columns if col not in kept_columns]\n",
    "df_ctos = (ctos_temp\n",
    "           .rename(columns={trt: 'A', outcome: 'Y'})\n",
    "           .assign(S=lambda d: d[select],\n",
    "                   Y0=lambda d: d['Y'],\n",
    "                   Y1=lambda d: d['Y'],\n",
    "                   R=lambda d: 1 - d['OS'])\n",
    "           .drop(columns=drop_columns))\n",
    "\n",
    "# 20 trials: seed 42 followed by seeds 0..18\n",
    "seeds = [42] + list(range(19))\n",
    "\n",
    "df_rct_train_list, df_obs_train_list = [], []\n",
    "df_rct_val_list, df_obs_val_list = [], []\n",
    "\n",
    "for seed in seeds:\n",
    "    # 75/25 train/validation split, then trial (R == 1) vs observational (R == 0) rows\n",
    "    df_ctos_train, df_ctos_val = train_test_split(df_ctos, test_size=0.25, random_state=seed)\n",
    "    df_rct_train, df_obs_train = df_ctos_train.query('R == 1'), df_ctos_train.query('R == 0')\n",
    "    df_rct_val, df_obs_val = df_ctos_val.query('R == 1'), df_ctos_val.query('R == 0')\n",
    "\n",
    "    df_rct_train_list.append(df_rct_train)\n",
    "    df_obs_train_list.append(df_obs_train)\n",
    "    df_rct_val_list.append(df_rct_val)\n",
    "    df_obs_val_list.append(df_obs_val)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f82ce336-069d-4b0e-a35e-8e88a6ab1c3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# inspect the RCT training split of one trial\n",
    "trial_idx = 9\n",
    "df_rct_train_list[trial_idx]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a01ad45-d447-4c71-85eb-2a42b4862350",
   "metadata": {},
   "source": [
    "## Training models (multiple trials, $n=20$)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "bc013fd2-7dc9-4489-af27-395eaf856321",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../synthetic/')\n",
    "from utils_models_v2 import *\n",
    "from collections import defaultdict\n",
    "from utils_v2 import pearsonr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0375981f-cd42-4331-9f82-7878a14ee6a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "num_trials = len(seeds)\n",
    "model_type = 'RF'\n",
    "se_keys = ['SE_Y0', 'SE_Y1', 'SE_A', 'SE_S']\n",
    "bias_res = []\n",
    "# cov_res['Pearson'][se_key] collects one pearsonr result per trial\n",
    "cov_res = defaultdict(lambda: defaultdict(list))\n",
    "\n",
    "for trial in tqdm(range(num_trials)):\n",
    "    df_rct_train, df_obs_train = df_rct_train_list[trial], df_obs_train_list[trial]\n",
    "    df_rct_val, df_obs_val = df_rct_val_list[trial], df_obs_val_list[trial]\n",
    "\n",
    "    # models fitted on each training split, predictions made on the matching validation split\n",
    "    rct_models = fit_models(df_rct_train, predictors, is_rct=True, model=model_type)\n",
    "    make_preds(df_rct_val, predictors, rct_models)\n",
    "\n",
    "    obs_models = fit_models(df_obs_train, predictors, is_rct=False, model=model_type)\n",
    "    make_preds(df_obs_val, predictors, obs_models)\n",
    "\n",
    "    # model of the trial indicator R on the pooled training rows\n",
    "    pr_model = fit_model(pd.concat([df_rct_train, df_obs_train]), predictors, \"R\")\n",
    "    df_val = merge_df_val(df_rct_val, df_obs_val, predictors, pr_model, rct_models, obs_models)\n",
    "    bias_res.append(df_val['b1(X)'].mean())\n",
    "    for key in se_keys:\n",
    "        cov_res['Pearson'][key].append(pearsonr(df_val, 'abs(b1(X))', key, df_val.shape[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3877061e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# predicted P(A=1) on df_obs_val (the OBS validation split left by the trial loop above)\n",
    "fig, ax = plt.subplots()\n",
    "ax.hist(df_obs_val[\"hat_P(A=1)\"], bins=50)\n",
    "ax.set_title('OBS validation split: predicted P(A=1)')\n",
    "ax.set_xlabel('hat_P(A=1)')\n",
    "ax.set_ylabel('count')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ab07381-3ef7-4ff1-a083-a407f5279385",
   "metadata": {},
   "outputs": [],
   "source": [
    "cov_res_final = defaultdict(list)\n",
    "keys = ['SE_Y0', 'SE_Y1', 'SE_A', 'SE_S']\n",
    "alpha = 0.01\n",
    "for key in keys:\n",
    "    # one pearsonr result per trial, assumed to be an (r, p-value) pair -- the filter below\n",
    "    # and the scatter plots index it that way\n",
    "    l = cov_res['Pearson'][key]\n",
    "    # res = [x[0] for x in l if x[1] < alpha]\n",
    "    # aggregate the correlation coefficients only: np.mean(l) on the raw pairs would\n",
    "    # average r values and p-values together\n",
    "    r_values = np.array([x[0] for x in l], dtype=float)\n",
    "    n = len(r_values)\n",
    "    mean = r_values.mean() if n > 0 else np.nan\n",
    "    # sample standard deviation (ddof=1); undefined for fewer than two trials\n",
    "    std = r_values.std(ddof=1) if n > 1 else np.nan\n",
    "    # normal-approximation 95% CI of the mean r\n",
    "    half_width = 1.96 * std / np.sqrt(n) if n > 0 else np.nan\n",
    "    lower = mean - half_width\n",
    "    upper = mean + half_width\n",
    "    cov_res_final[key].append(mean)\n",
    "    cov_res_final[key].append(lower)\n",
    "    cov_res_final[key].append(upper)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c447551d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# columns of df_obs_val after the model helpers ran on it\n",
    "df_obs_val.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07b51dcc-2a2e-4c7e-8208-1702db1ff495",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame.from_dict(cov_res_final, orient='index', columns=['mean', 'lower', 'upper'])\n",
    "filename_save = f'./results/run_{selection_flag}_includecensored_{censored_patients_sel0}_{outcome_name}_{model_type}.csv'\n",
    "# to_csv does not create missing directories\n",
    "os.makedirs(os.path.dirname(filename_save), exist_ok=True)\n",
    "print(f'Saving {filename_save}....')\n",
    "df.to_csv(\n",
    "    filename_save,          # File name\n",
    "    sep=',',               # Delimiter (comma)\n",
    "    index=True,            # Include index (the SE_* keys)\n",
    "    header=True,           # Include headers\n",
    "    float_format='%.6f'    # Floating-point format\n",
    ")\n",
    "df\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4bb42a31-0b8c-4472-a4e9-3635998fa90a",
   "metadata": {},
   "source": [
    "## Training models (1 run, $n=1$)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "573d089b-7632-415f-a9c7-56c58fb30bf5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../synthetic/')\n",
    "from utils_data_v2 import fit_models, make_preds, merge_df_val, fit_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de2c5601-c05a-433a-b227-479ce1dd7716",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single run: use the last seed of the multi-trial split loop explicitly instead of\n",
    "# relying on the loop variable `seed` leaking from that cell\n",
    "single_run_seed = seeds[-1]\n",
    "df_ctos_train, df_ctos_val = train_test_split(df_ctos, test_size=0.25, random_state=single_run_seed)\n",
    "\n",
    "# split into RCT and OBS\n",
    "df_rct_train = df_ctos_train.query('R == 1')\n",
    "df_obs_train = df_ctos_train.query('R == 0')\n",
    "df_rct_val   = df_ctos_val.query('R == 1')\n",
    "df_obs_val   = df_ctos_val.query('R == 0')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3612575b-8183-423d-8973-d8aaffc16c2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# RCT rows: fit on the training split, then predict on the validation split\n",
    "rct_models = fit_models(\n",
    "    df_rct_train,\n",
    "    predictors,\n",
    "    is_rct=True,\n",
    ")\n",
    "make_preds(df_rct_val, predictors, rct_models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fa60840-0bf0-457a-901d-600238ca1371",
   "metadata": {},
   "outputs": [],
   "source": [
    "# OBS rows: fit on the training split, then predict on the validation split\n",
    "obs_models = fit_models(\n",
    "    df_obs_train,\n",
    "    predictors,\n",
    "    is_rct=False,\n",
    ")\n",
    "make_preds(df_obs_val, predictors, obs_models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c45ef2c7-95d6-4675-bfd8-3dce098f7503",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model of the trial indicator R on the pooled RCT + OBS training rows\n",
    "pooled_train = pd.concat([df_rct_train, df_obs_train])\n",
    "pr_model = fit_model(pooled_train, predictors, \"R\")\n",
    "df_val = merge_df_val(df_rct_val, df_obs_val, predictors, pr_model, rct_models, obs_models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75c756b7-c77f-4088-81c7-dc635cfc1a56",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e43de867-37cc-4612-a55f-899ab66028b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils_v2 import pearsonr\n",
    "from collections import defaultdict\n",
    "\n",
    "# single run: one bias estimate and one pearsonr result per SE_* column\n",
    "n_val = df_val.shape[0]\n",
    "bias_res = [df_val['w1(X)'].mean()]\n",
    "cov_res = defaultdict(lambda: defaultdict(list))\n",
    "print(n_val)\n",
    "for key in ('SE_Y0', 'SE_Y1', 'SE_A', 'SE_S'):\n",
    "    cov_res['Pearson'][key].append(pearsonr(df_val, 'abs(w1(X))', key, n_val))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c439146a-657b-412d-85a8-3f2c03b0255d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pearsonr results keyed by SE_* column (indexed as (r, p-value) pairs by the plots below)\n",
    "cov_res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7934e807-b59a-4bf8-911d-f4e03cdaf493",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# one panel per SE_* key: Pearson r (x) vs -log10(p-value) (y, capped at 5)\n",
    "palette = sns.color_palette(\"tab10\")\n",
    "fig, axs = plt.subplots(2, 2, figsize=(10,8))\n",
    "\n",
    "pearson_res = cov_res[\"Pearson\"]\n",
    "for idx, key in enumerate(pearson_res):\n",
    "    ax = axs[divmod(idx, 2)]\n",
    "    ax.set_title(f\"Cov($w1(X)$, {key})\", fontsize=16)\n",
    "    ax.axhline(y=-np.log10(0.05), color='dimgray', linestyle='--', label='p = 0.05')\n",
    "\n",
    "    results = np.array(pearson_res[key])\n",
    "    r_values, p_values = results[:, 0], results[:, 1]\n",
    "    neg_log_p = np.clip(-np.log10(p_values), a_min=None, a_max=5)\n",
    "    ax.scatter(r_values, neg_log_p, color=palette[idx], s=16, alpha=1)\n",
    "\n",
    "for ax in axs[:, 0]:\n",
    "    ax.set_ylabel('-log10(p-value)', fontsize=16)\n",
    "for ax in axs[1, :]:\n",
    "    ax.set_xlabel(\"Pearson's R\", fontsize=16)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "924c6fee-eb79-423e-96a3-7753c263cf02",
   "metadata": {},
   "outputs": [],
   "source": [
    "# scratch: prefix test and list sorting\n",
    "f = 'BAC'\n",
    "prefix = 'ZD'\n",
    "f.startswith(prefix)\n",
    "li = sorted(['A', 'C', 'B'])\n",
    "print(li)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9ef3376-cecf-4d8e-ba04-17f25cf40afe",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cenfal",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
