{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Placeholders: replace each `...` with the real loading code before running.\n",
    "# NOTE: `pnuemonia_patient_account_id` keeps its historical misspelling because\n",
    "# every later cell refers to it by that name.\n",
    "\n",
    "# One-hot encoded physician matrix, indexed by patient account id\n",
    "Z = ...\n",
    "pnuemonia_patient_account_id = Z.index\n",
    "\n",
    "# One-hot encoded comorbidity categories\n",
    "W = ...\n",
    "\n",
    "# Treatment indicator matrix\n",
    "T = ...\n",
    "\n",
    "# Encounter-level table: the \"Construct Y\" section reads its\n",
    "# \"Enc___Length_of_Stay_in_Hours\" column, so load it here as well.\n",
    "encdf = ...\n",
    "\n",
    "# Only a cell's last expression is rendered, so display each preview explicitly.\n",
    "display(Z.head(), W.head(), T.head())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Baseline Construction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Comorbidity indicator columns of W that feed the baseline (sickness) score.\n",
    "comorb_type = (\n",
    "    \"AIDS ALCOHOL ANEMDEF AUTOIMMUNE BLDLOSS CANCER_LEUK \"\n",
    "    \"CANCER_LYMPH CANCER_METS CANCER_NSITU CANCER_SOLID \"\n",
    "    \"CBVD_POA CBVD_SQLA COAG DEMENTIA DEPRESS DIAB_CX \"\n",
    "    \"DIAB_UNCX DRUG_ABUSE HF HTN_CX HTN_UNCX LIVER_MLD \"\n",
    "    \"LIVER_SEV LUNG_CHRONIC NEURO_MOVT NEURO_OTH NEURO_SEIZ \"\n",
    "    \"OBESE PARALYSIS PERIVASC PSYCHOSES PULMCIRC RENLFL_MOD \"\n",
    "    \"RENLFL_SEV THYROID_HYPO THYROID_OTH ULCER_PEPTIC VALVE \"\n",
    "    \"WGHTLOSS\"\n",
    ").split()\n",
    "\n",
    "# Prevalence of each comorbidity (column sums)\n",
    "W.loc[:, comorb_type].sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1. Map comorbidities to \"recovery penalties\": extra hours of length of stay.\n",
    "# Clinical intuition: Heart Failure and Severe Liver disease slow recovery more than Obesity.\n",
    "\n",
    "comorb_weights_dict = {\n",
    "    # --- Category A: Respiratory & Hemodynamic (Direct Interference) ---\n",
    "    'LUNG_CHRONIC': 22,   # COPD/Asthma makes gas exchange much harder\n",
    "    'HF': 20,             # Heart Failure leads to pulmonary edema (fluid in lungs)\n",
    "    'PULMCIRC': 18,       # Pulmonary hypertension/embolism is a direct lung complication\n",
    "    'VALVE': 12,          # Valvular issues complicate fluid management\n",
    "\n",
    "    # --- Category B: Immune & Metabolic Exhaustion ---\n",
    "    'CANCER_METS': 25,    # Massive systemic load\n",
    "    'AIDS': 20,           # Profound immunosuppression\n",
    "    'CANCER_LEUK': 15,    # Impaired white blood cell function\n",
    "    'WGHTLOSS': 12,       # Malnutrition is a major predictor of pneumonia mortality\n",
    "    'ALCOHOL': 10,        # Increases risk of aspiration and blunts immune response\n",
    "\n",
    "    # --- Category C: Fluid & Filtration (Kidney/Liver) ---\n",
    "    'RENLFL_SEV': 18,     # Difficulty clearing meds and managing hydration\n",
    "    'LIVER_SEV': 20,      # Impaired protein synthesis for healing\n",
    "    'RENLFL_MOD': 10,\n",
    "    'LIVER_MLD': 7,\n",
    "\n",
    "    # --- Category D: Neurological (Aspiration Risk) ---\n",
    "    'DEMENTIA': 15,       # High risk of aspirating (swallowing saliva/food into lungs)\n",
    "    'PARALYSIS': 15,      # Impaired cough reflex and mobility\n",
    "    'CBVD_POA': 12,       # Stroke history often means impaired airway protection\n",
    "    'NEURO_SEIZ': 8,\n",
    "\n",
    "    # --- Category E: Standard Chronic (Systemic Stress) ---\n",
    "    'DIAB_CX': 12,        # High blood sugar slows healing and feeds bacteria\n",
    "    'COAG': 10,           # Clotting issues (DIC is a risk in sepsis)\n",
    "    'HTN_CX': 6,          # General vascular strain\n",
    "    'DIAB_UNCX': 5,\n",
    "    'ANEMDEF': 5,         # Low oxygen carrying capacity\n",
    "\n",
    "    # --- Category F: Low Impact in Pneumonia Context ---\n",
    "    'OBESE': 3,           # Can complicate ventilation but not a primary delay\n",
    "    'THYROID_HYPO': 2,\n",
    "    'DEPRESS': 2,\n",
    "    'PSYCHOSES': 2,\n",
    "    'ULCER_PEPTIC': 3,    # Mostly a risk for GI bleed, not lung recovery\n",
    "}\n",
    "\n",
    "# Default penalty (hours) for comorbidities not listed above.\n",
    "DEFAULT_COMORB_WEIGHT = 5\n",
    "\n",
    "# Guard against typos: a key that is not a real comorbidity column would\n",
    "# silently never be used.\n",
    "unknown_keys = set(comorb_weights_dict) - set(comorb_type)\n",
    "assert not unknown_keys, f\"Unknown comorbidities in comorb_weights_dict: {sorted(unknown_keys)}\"\n",
    "\n",
    "# Weight vector aligned with the column order of `comorb_type`.\n",
    "weights = np.array([comorb_weights_dict.get(c, DEFAULT_COMORB_WEIGHT) for c in comorb_type])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-patient sickness score: weighted sum of the comorbidity indicators,\n",
    "# indexed like W (the comorbidity columns are expected to be 0/1).\n",
    "patient_sickness = W[comorb_type].dot(weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Treatment Simplification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "# Collapse the wide treatment matrix T into one column per parsed item:\n",
    "# name_form valued by dose when a name and dose can be parsed,\n",
    "# otherwise the (shortened) raw string valued 1.\n",
    "\n",
    "# Step 1: melt the wide indicator matrix to long format.\n",
    "# rename_axis returns a new frame, so T itself is left untouched.\n",
    "df_long = (\n",
    "    T.rename_axis('Phys___Patient_Account')\n",
    "    .reset_index()\n",
    "    .melt(id_vars='Phys___Patient_Account', var_name='raw_string', value_name='indicator')\n",
    ")\n",
    "\n",
    "# Step 2: keep only items that were actually given, and filter out icdpx__\n",
    "# codes for this run. .copy() makes the column assignment below operate on an\n",
    "# independent frame rather than a filtered slice.\n",
    "df_long = df_long[\n",
    "    (df_long['indicator'] == 1)\n",
    "    & ~df_long['raw_string'].str.contains('icdpx__', regex=False)\n",
    "].copy()\n",
    "\n",
    "# Step 3: parse molecule name, dose and form with regexes.\n",
    "#   name: text after 'activitycode__' (letters, hyphens, parentheses, whitespace)\n",
    "#   dose: a number (commas allowed) followed by a unit (mg, mcg, ml, g, unit(s), %, intl)\n",
    "#   form: a formulation/route token such as tab, cap, soln, iv, po, inj\n",
    "def extract_info(s):\n",
    "    \"\"\"Map one raw treatment string to a Series indexed ['col_name', 'val'].\n",
    "\n",
    "    If both a name and a dose are found, col_name is name + '_' + form (form is\n",
    "    'unspecified' when absent) and val is the dose as a float (1.0 if it cannot\n",
    "    be parsed). The captured name keeps trailing whitespace, which is why keys\n",
    "    look like 'levofloxacin _po'.\n",
    "\n",
    "    Otherwise val is 1 (binary indicator) and col_name is the lower-cased string,\n",
    "    shortened to its first three words when at least three words remain after\n",
    "    removing 'activitycode__'. Shorter strings keep any 'activitycode__' prefix;\n",
    "    downstream keys such as 'activitycode__hc venipuncture' rely on this.\n",
    "    \"\"\"\n",
    "    s_lower = s.lower()\n",
    "\n",
    "    name_match = re.search(r'activitycode__([a-z\\-\\(\\)\\s]+)', s_lower)\n",
    "    dose_match = re.search(r'([\\d,]+\\.?\\d*)\\s*(mg|mcg|ml|g|unit|units|\\%|intl)', s_lower)\n",
    "    form_match = re.search(r'\\b(tabs?|caps?|soln|liq|iv|po|inj|susp|tp|opht|spray)\\b', s_lower)\n",
    "\n",
    "    if name_match and dose_match:\n",
    "        name = name_match.group(1)\n",
    "        raw_dose = dose_match.group(1).replace(',', '')\n",
    "        try:\n",
    "            dose = float(raw_dose)\n",
    "        except ValueError:  # e.g. the number pattern matched only commas\n",
    "            dose = 1.0\n",
    "        form = form_match.group(1) if form_match else \"unspecified\"\n",
    "        return pd.Series([f\"{name}_{form}\", dose], index=['col_name', 'val'])\n",
    "\n",
    "    # Fallback: use the original string and treat value as 1 (binary indicator)\n",
    "    words = s_lower.replace(\"activitycode__\", \"\").split()\n",
    "    if len(words) >= 3:\n",
    "        s_lower = f\"{words[0]} {words[1]} {words[2]}\"\n",
    "    return pd.Series([s_lower, 1], index=['col_name', 'val'])\n",
    "\n",
    "df_long[['molecule', 'dose']] = df_long['raw_string'].apply(extract_info)\n",
    "\n",
    "# Step 4: pivot back to wide. 'max' keeps the largest dose when several raw\n",
    "# items map to the same column (or the 1 for binary items).\n",
    "final_matrix = df_long.pivot_table(\n",
    "    index='Phys___Patient_Account',\n",
    "    columns='molecule',\n",
    "    values='dose',\n",
    "    aggfunc='max',\n",
    ").fillna(0)\n",
    "\n",
    "final_matrix.shape, T.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Binarize: 1 if the item was given at any dose, else 0.\n",
    "T2 = final_matrix.gt(0).astype('int64')\n",
    "T2.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Physician Construction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collapse rarely seen physicians into a shared 'no_pref' group.\n",
    "\n",
    "# 1. Identify all physician columns\n",
    "phys_cols = [col for col in Z.columns if col.startswith('physicianid__')]\n",
    "\n",
    "# 2. Get patient counts for each physician\n",
    "phys_counts = Z[phys_cols].sum(axis=0).sort_values(ascending=False)\n",
    "\n",
    "# 3. Physicians with fewer than `threshold` patients are treated as rare\n",
    "threshold = 10\n",
    "rare_physicians = phys_counts[phys_counts < threshold].index.tolist()\n",
    "frequent_physicians = phys_counts[phys_counts >= threshold].index.tolist()\n",
    "\n",
    "# 4. Physician_Group = the first physician column equal to 1, or 'no_pref'\n",
    "#    when there is none or that physician is rare. Vectorized equivalent of\n",
    "#    the previous row-wise apply (which was O(rows x physicians)).\n",
    "if phys_cols:\n",
    "    is_active = Z[phys_cols].eq(1).to_numpy()\n",
    "    # argmax gives the first True per row (0 when none; masked out below)\n",
    "    first_idx = is_active.argmax(axis=1)\n",
    "    first_active = pd.Series(np.asarray(phys_cols, dtype=object)[first_idx], index=Z.index)\n",
    "    keep = is_active.any(axis=1) & ~first_active.isin(rare_physicians).to_numpy()\n",
    "    # NOTE(review): 'phys_' never occurs in the 'physicianid__' prefix, so this\n",
    "    # replace does not strip it; kept as-is so group labels stay unchanged.\n",
    "    labels = first_active.str.replace('phys_', '', regex=False)\n",
    "    Z['Physician_Group'] = labels.where(keep, 'no_pref')\n",
    "else:\n",
    "    Z['Physician_Group'] = 'no_pref'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instrument matrix: one dummy column per physician group, for the pneumonia cohort.\n",
    "Z2 = pd.get_dummies(Z.loc[pnuemonia_patient_account_id, \"Physician_Group\"], drop_first=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Simple Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.decomposition import PCA, SparsePCA\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "# Quick look at the main axes of treatment variation (\"clinical styles\")\n",
    "# and how physician groups score on them.\n",
    "N_COMPONENTS = 5\n",
    "N_TOP_FEATURES = 10\n",
    "\n",
    "# 1. Treatment matrix: parsed doses / binary indicators, one column per item\n",
    "X_treatment = final_matrix.copy()\n",
    "\n",
    "# Standardize the data (PCA is sensitive to scale!)\n",
    "scaler = StandardScaler()\n",
    "X_scaled = scaler.fit_transform(X_treatment)\n",
    "\n",
    "# 2. Run PCA\n",
    "pca = PCA(n_components=N_COMPONENTS)\n",
    "pca_features = pca.fit_transform(X_scaled)\n",
    "\n",
    "# PCA scores for each patient\n",
    "pca_cols = [f'PC{i+1}' for i in range(N_COMPONENTS)]\n",
    "df_pca = pd.DataFrame(pca_features, columns=pca_cols, index=X_treatment.index)\n",
    "\n",
    "# 3. Link each patient's scores to their physician group\n",
    "df_style = df_pca.join(Z[['Physician_Group']])\n",
    "\n",
    "# 4. Physician \"style score\": mean PCA score per physician group\n",
    "physician_style_scores = df_style.groupby('Physician_Group')[pca_cols].mean()\n",
    "\n",
    "# Variance explained per component\n",
    "explained_variance = pca.explained_variance_ratio_\n",
    "cumulative_variance = np.cumsum(explained_variance)\n",
    "print(\"PCA Variance Explained:\")\n",
    "print(\"-\" * 30)\n",
    "for i, ratio in enumerate(explained_variance):\n",
    "    print(f\"Component {i+1}: {ratio:.2%} variance\")\n",
    "\n",
    "print(\"-\" * 30)\n",
    "print(f\"Total variance explained by {N_COMPONENTS} components: {cumulative_variance[-1]:.2%}\")\n",
    "\n",
    "# 5. Loadings: which treatments drive each component\n",
    "loadings = pd.DataFrame(\n",
    "    pca.components_.T,\n",
    "    columns=pca_cols,\n",
    "    index=X_treatment.columns\n",
    ")\n",
    "\n",
    "# Top positively-loaded treatments for every component\n",
    "for pc in pca_cols:\n",
    "    print(f\"Top features for {pc} (The dominant clinical style):\")\n",
    "    print(loadings[pc].sort_values(ascending=False).head(N_TOP_FEATURES))\n",
    "\n",
    "\n",
    "print(\"\\nPhysician Style Scores (Example):\")\n",
    "print(physician_style_scores.head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Names only of the 10 most positively loaded treatments per component.\n",
    "for i in range(1, 6):\n",
    "    print(f\"Top features for PC{i} (The dominant clinical style):\")\n",
    "    ranked = loadings[f'PC{i}'].sort_values(ascending=False)\n",
    "    print(ranked.head(10).index)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Construct Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Which treatments' usage rates vary most across physician groups,\n",
    "# restricted to commonly used items.\n",
    "is_used = final_matrix > 0\n",
    "\n",
    "# Across-physician-group variance of each treatment's usage rate, largest first\n",
    "temp = is_used.groupby(Z[\"Physician_Group\"].loc[is_used.index]).mean().var().sort_values(ascending=False)\n",
    "\n",
    "# Usage rate = share of visits with the item. final_matrix holds doses, so it\n",
    "# must be binarized first; averaging raw doses does not give a rate.\n",
    "usage_rates = is_used.mean()\n",
    "high_usage_mask = usage_rates > 0.05  # Only look at drugs used in >5% of visits\n",
    "\n",
    "# Align the mask to temp's variance-sorted index explicitly\n",
    "temp = temp[high_usage_mask.reindex(temp.index, fill_value=False)]\n",
    "\n",
    "# Second view excludes items whose names start with 'hc' or 'activitycode'\n",
    "not_hc_or_activity = ~(temp.index.str.startswith(\"hc\") | temp.index.str.startswith(\"activitycode\"))\n",
    "temp.head(20).index, temp[not_hc_or_activity].head(20).index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import gaussian_kde\n",
    "\n",
    "# Empirical length-of-stay (LOS) distribution of the pneumonia cohort. The KDE\n",
    "# fitted here is passed to `theta` below as its noise model.\n",
    "\n",
    "# 1. Extract the data (drop NaNs so the KDE fit doesn't fail)\n",
    "los_data = encdf.loc[pnuemonia_patient_account_id, \"Enc___Length_of_Stay_in_Hours\"].dropna()\n",
    "\n",
    "# 2. Fit the kernel density estimator (bandwidth via Scott's rule by default)\n",
    "kde = gaussian_kde(los_data)\n",
    "\n",
    "# 3. Visualize the fit: density-normalized histogram vs. the KDE curve\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "ax.hist(los_data, bins=50, density=True, alpha=0.5, color='gray', label='Actual Data')\n",
    "x_range = np.linspace(los_data.min(), los_data.max(), 1000)\n",
    "ax.plot(x_range, kde(x_range), color='red', lw=2, label='KDE Fit')\n",
    "ax.set_title(\"KDE Fit for Length of Stay\", fontsize=14)\n",
    "ax.set_xlabel(\"Hours\", fontsize=12)\n",
    "ax.set_ylabel(\"Density\", fontsize=12)\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def theta(T, noisemodel, noise=True, baseline = patient_sickness, base_hours=48.0, noise_scale=5.0):\n",
    "    \"\"\"Simulate length of stay (hours) from a treatment matrix.\n",
    "\n",
    "    Y = base_hours + baseline + sum_j effect_j * T[:, j] (+ noise), floored at 0.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    T : DataFrame\n",
    "        Treatment matrix (0/1 or bool). Must contain every key of\n",
    "        treatment_effects as a column, otherwise a KeyError is raised.\n",
    "    noisemodel : object with resample(size=...), e.g. scipy.stats.gaussian_kde\n",
    "        Only used when noise=True.\n",
    "    noise : bool\n",
    "        Add independent per-patient noise noisemodel.resample(...) / noise_scale.\n",
    "    baseline : Series\n",
    "        Per-patient sickness penalty (hours), looked up at T.index. The default\n",
    "        is bound once, when this cell is executed.\n",
    "    base_hours : float\n",
    "        Outcome intercept in hours (default 48).\n",
    "    noise_scale : float\n",
    "        Divisor applied to the sampled noise (default 5).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (treatment_effects, Y) : (dict, Series indexed like T)\n",
    "    \"\"\"\n",
    "    # --- STEP 1: DEFINE TREATMENT EFFECTS (hours; negative = shorter stay) ---\n",
    "    treatment_effects = {\n",
    "        # --- Antibiotics & Primary Respiratory (The Core Signal) ---\n",
    "        'levofloxacin _po': -10.0,\n",
    "        'azithromycin _po': -8.0,\n",
    "        'ipratropium-albuterol _unspecified': -4.5,\n",
    "        'albuterol sulfate _unspecified': -4.0,\n",
    "        'hc stat nebulizer': -3.5,\n",
    "        'hc metered dose': -2.0,\n",
    "        'hc inhalation tx': -3.0,\n",
    "\n",
    "        # --- Steroids & Anti-inflammatories (Managing the Immune Response) ---\n",
    "        'prednisone _po': -6.0,\n",
    "        'dexamethasone sodium phosphate _inj': -5.5,\n",
    "        'dexamethasone _po': -5.0,\n",
    "        'ibuprofen _po': -1.0,  # Lowered weight; mainly for fever/pain\n",
    "        'aspirin _po': -1.0,\n",
    "\n",
    "        # --- Anticoagulants & Fluids (Supportive Care) ---\n",
    "        'heparin_iv': -5.0,\n",
    "        'heparin_inj': -4.0,\n",
    "        'dextrose _iv': -2.0,\n",
    "\n",
    "        # --- Symptomatic Control ---\n",
    "        'acetaminophen _po': -1.0,\n",
    "        'guaifenesin _po': -0.5,  # Expectorant to clear mucus\n",
    "\n",
    "        # --- Diagnostic & Process ---\n",
    "        'hc blood culture': -1.5,\n",
    "        'hc pc02-02 (cg4)': -0.8,     # Arterial blood gas monitoring\n",
    "        'hc pclact-lactate (cg4)': -0.5, # Sepsis monitoring\n",
    "        'hc chest 2v': -0.5,          # Standard X-ray\n",
    "        'hc chest 1': -0.4,\n",
    "        'hc therapypro/dx ivp': -2.0,\n",
    "        'hc stat:thrpst trtmnt': -1.0,\n",
    "        'activitycode__hc venipuncture': -1.0,\n",
    "        'hc cbc w/auto': -1.0,\n",
    "        'hc metabolic pnl': -1.0\n",
    "        }\n",
    "\n",
    "    # --- STEP 2: GENERATE Y ---\n",
    "    treatment_total = (T[list(treatment_effects.keys())] * pd.Series(treatment_effects)).sum(axis=1)\n",
    "\n",
    "    # Outcome = base hours + sickness + treatment effects (+ noise)\n",
    "    y = base_hours + baseline.loc[T.index] + treatment_total\n",
    "    if noise:\n",
    "        # One independent draw per patient; size=1 would add the same shift to everyone.\n",
    "        y = y + noisemodel.resample(size=len(T)).flatten() / noise_scale\n",
    "\n",
    "    return treatment_effects, np.maximum(y, 0)\n",
    "\n",
    "treatment_effects, Y2 = theta(T2, kde)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summary statistics of the simulated length of stay (hours)\n",
    "Y2.describe(percentiles=[0.25, 0.5, 0.75])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LIRR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.decomposition import PCA\n",
    "from abc import ABC, abstractmethod\n",
    "\n",
    "# Abstract base class that defines the common fitting interface\n",
    "class BaseModel(ABC):\n",
    "    @abstractmethod\n",
    "    def fit(self, Z, X, y):\n",
    "        \"\"\"\n",
    "        Fit the model to the training data.\n",
    "\n",
    "        Parameters:\n",
    "        -----------\n",
    "        Z : array-like of shape (n_samples, k)\n",
    "            Instruments.\n",
    "        X : array-like of shape (n_samples, m)\n",
    "            Features (endogenous treatments).\n",
    "        y : array-like of shape (n_samples, 1)\n",
    "            Outcome.\n",
    "\n",
    "        Hyper-parameters such as the latent dimension r (a scalar) are passed\n",
    "        to the subclass constructor, not to fit; this matches every subclass.\n",
    "\n",
    "        Returns:\n",
    "        --------\n",
    "        self : object\n",
    "            Returns self.\n",
    "        \"\"\"\n",
    "        pass\n",
    "\n",
    "\n",
    "from sklearn.metrics import r2_score\n",
    "# IV regression model\n",
    "class IVRegression(BaseModel):\n",
    "    '''\n",
    "    Two-stage least squares (TSLS).\n",
    "\n",
    "    Stage 1: regress X on Z and keep the fitted values X_hat.\n",
    "    Stage 2: regress Y on X_hat; theta is the stage-2 coefficient (transposed).\n",
    "\n",
    "    Z: instrumental var, shape (n, k)\n",
    "    X: endo var, shape (n, m)\n",
    "    Y: outcome\n",
    "\n",
    "    NOTE(review): `intercept` is stored but not used; both stages always fit an\n",
    "    intercept (LinearRegression default). Existing callers pass intercept=False\n",
    "    with data that is not always mean-centered, so honouring it would change\n",
    "    their results.\n",
    "\n",
    "    Attributes set by fit: first, second (stage models), theta, and\n",
    "    first_stage_r2 (R^2 of stage 1, useful as a weak-instrument check).\n",
    "    '''\n",
    "    def __init__(self, intercept=False):\n",
    "        self.theta = None\n",
    "        self.intercept = intercept\n",
    "        self.first = None\n",
    "        self.second = None\n",
    "        self.first_stage_r2 = None\n",
    "\n",
    "    def fit(self, Z, X, Y):\n",
    "        self.first = LinearRegression().fit(Z, X)\n",
    "        X_hat = self.first.predict(Z)\n",
    "        # Keep the first-stage fit quality instead of computing and discarding it.\n",
    "        self.first_stage_r2 = r2_score(X, X_hat)\n",
    "        self.second = LinearRegression().fit(X_hat, Y)\n",
    "        self.theta = self.second.coef_.T\n",
    "        return self\n",
    "\n",
    "\n",
    "\n",
    "# Proposed method via SVD\n",
    "class LIRR(BaseModel):\n",
    "    '''\n",
    "    Latent IV regression through a rank-r factorization of the first stage.\n",
    "\n",
    "    1. Least squares X ~ Z (no intercept) gives the coefficient matrix C.\n",
    "    2. Thin SVD C = U S Vh; B = U[:, :r] and A = diag(S[:r]) @ Vh[:r].\n",
    "    3. theta from IV regression of Y on the latent D = X @ B, instrumented by Z.\n",
    "\n",
    "    Z: instrumental var\n",
    "    X: features\n",
    "    Y: outcome\n",
    "    r: latent dimension\n",
    "\n",
    "    Assume all vars are mean-centered and standardized.\n",
    "    '''\n",
    "\n",
    "    def __init__(self, r, intercept=False):\n",
    "        self.r = r\n",
    "        self.intercept = intercept\n",
    "        self.A = None\n",
    "        self.B = None\n",
    "        self.ivmodel = None\n",
    "\n",
    "    def fit(self, Z, X, y):\n",
    "        # Coefficients of the multi-output least squares X ~ Z\n",
    "        coef = LinearRegression(fit_intercept=False).fit(Z, X).coef_\n",
    "\n",
    "        # Rank-r factorization of the coefficients via thin SVD\n",
    "        left, sing, right_t = np.linalg.svd(coef, full_matrices=False)\n",
    "        k = self.r\n",
    "        self.A = np.diag(sing[:k]) @ right_t[:k, :]\n",
    "        self.B = left[:, :k]\n",
    "\n",
    "        # IV regression on the projected (latent) features\n",
    "        latent = X @ self.B\n",
    "        self.ivmodel = IVRegression(self.intercept).fit(Z, latent, y)\n",
    "        return self\n",
    "\n",
    "    def gettheta(self):\n",
    "        return self.ivmodel.theta\n",
    "\n",
    "    def encode(self, X):\n",
    "        return X @ self.B\n",
    "\n",
    "    def decode(self, D):\n",
    "        return D @ self.B.T\n",
    "\n",
    "\n",
    "# Naive method PCA followed by IVregression\n",
    "class PCAMethod(BaseModel):\n",
    "    '''\n",
    "    PCA followed by IV regression (baseline for LIRR).\n",
    "\n",
    "    Z: instrumental var\n",
    "    X: features\n",
    "    Y: outcome\n",
    "    r: latent dimension\n",
    "\n",
    "    Assume all vars are mean-centered and standardized.\n",
    "    '''\n",
    "\n",
    "    def __init__(self, r, intercept=False):\n",
    "        self.r = r\n",
    "        self.PCA = None\n",
    "        # Declared up front (as in LIRR) so the attributes exist before fit().\n",
    "        self.B = None\n",
    "        self.ivmodel = None\n",
    "        self.intercept = intercept\n",
    "\n",
    "    def fit(self, Z, X, y):\n",
    "        # First PCA, assume X is standardized\n",
    "        pca = PCA(n_components=self.r)\n",
    "        X_reduced = pca.fit_transform(X)\n",
    "        self.PCA = pca\n",
    "        # Component loadings, shape (n_features, r): map latent directions to X space\n",
    "        self.B = pca.components_.T\n",
    "\n",
    "        # then IV regression of y on the PCA scores, instrumented by Z\n",
    "        self.ivmodel = IVRegression(self.intercept)\n",
    "        self.ivmodel.fit(Z, X_reduced, y)\n",
    "\n",
    "        return self\n",
    "\n",
    "    def gettheta(self):\n",
    "        return self.ivmodel.theta\n",
    "\n",
    "    def encode(self, X):\n",
    "        return self.PCA.transform(X)\n",
    "\n",
    "    def decode(self, D):\n",
    "        return self.PCA.inverse_transform(D)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example: fit LIRR on the standardized pneumonia cohort, move every patient\n",
    "# `alpha` standardized units against the unit-norm theta direction, and\n",
    "# re-simulate length of stay without noise.\n",
    "data = {\n",
    "    \"X\": T2.loc[pnuemonia_patient_account_id],\n",
    "    \"Y\": pd.DataFrame(index=pnuemonia_patient_account_id, data=Y2.loc[pnuemonia_patient_account_id].values),\n",
    "    \"Z\": Z2.loc[pnuemonia_patient_account_id],\n",
    "}\n",
    "data_trans = {}\n",
    "XStandardScaler = StandardScaler().fit(data[\"X\"])\n",
    "data_trans['X'] = XStandardScaler.transform(data[\"X\"])\n",
    "YStandardScaler = StandardScaler().fit(data[\"Y\"])\n",
    "data_trans['Y'] = YStandardScaler.transform(data[\"Y\"])\n",
    "data_trans['Z'] = StandardScaler().fit_transform(data[\"Z\"])\n",
    "\n",
    "r = 5\n",
    "mymodel = LIRR(r, intercept=False)\n",
    "mymodel.fit(data_trans['Z'], data_trans['X'], data_trans['Y'])\n",
    "data_trans['D'] = mymodel.encode(data_trans['X'])\n",
    "\n",
    "# Perturb in scaled feature space: latent theta direction mapped back through B.\n",
    "# Named theta_hat so the `theta` simulator function is not shadowed.\n",
    "alpha = 20\n",
    "theta_hat = mymodel.gettheta()\n",
    "data_trans['Xprime'] = XStandardScaler.inverse_transform(\n",
    "    data_trans['X'] - ((alpha * theta_hat.T / np.linalg.norm(theta_hat)) @ mymodel.B.T)\n",
    ")\n",
    "data_trans['Xprime'] = pd.DataFrame(data=data_trans['Xprime'], index=pnuemonia_patient_account_id, columns=T2.columns)\n",
    "data_trans['Xprime'] = data_trans['Xprime'] > 0.2  # re-binarize the perturbed treatments\n",
    "_, data_trans[\"Yprime\"] = theta(data_trans['Xprime'], kde, noise=False)\n",
    "_, data_trans[\"Y_denoise\"] = theta(data[\"X\"], kde, noise=False)\n",
    "\n",
    "# Improvement = hours saved (noise-free LOS before minus after), the same sign\n",
    "# convention as get_improvement: positive means a shorter stay.\n",
    "improvement = (\n",
    "    data_trans[\"Y_denoise\"].loc[pnuemonia_patient_account_id].values\n",
    "    - data_trans[\"Yprime\"].loc[pnuemonia_patient_account_id].values\n",
    ").mean()\n",
    "print(f\"Average predicted improvement in recovery: {improvement:.4f} hours\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Synthetic Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_smart_sickness_nudge(W_syn, treatment_columns, strength=0.99):\n",
    "    \"\"\"Confounding term added to treatment probabilities in the synthetic data.\n",
    "\n",
    "    For each rule (driver comorbidities -> treatments), every rule treatment\n",
    "    present in treatment_columns gets `strength` added for patients who have at\n",
    "    least one of the driver comorbidities. Contributions of several rules add up.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    W_syn : DataFrame of 0/1 comorbidity indicators (must contain the driver columns).\n",
    "    treatment_columns : pandas Index of treatment names (get_loc is used).\n",
    "    strength : float, default 0.99\n",
    "        Amount added to the probability. With the default, affected patients end\n",
    "        up receiving the treatment almost surely once probabilities are clipped.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    ndarray of shape (len(W_syn), len(treatment_columns))\n",
    "    \"\"\"\n",
    "    nudge = np.zeros((len(W_syn), len(treatment_columns)))\n",
    "\n",
    "    rules = [\n",
    "        # Lung/Heart issues drive respiratory meds\n",
    "        (['LUNG_CHRONIC', 'HF', 'PULMCIRC', 'VALVE'],\n",
    "         ['hc stat nebulizer', 'hc inhalation tx', 'albuterol sulfate _unspecified', 'ipratropium-albuterol _unspecified']),\n",
    "        # Immune issues drive stronger antibiotic usage\n",
    "        (['AIDS', 'CANCER_METS', 'CANCER_LEUK', 'WGHTLOSS'],\n",
    "         ['levofloxacin _po', 'azithromycin _po', 'ceftriaxone _iv']),\n",
    "        # Inflammatory/asthma issues drive steroids\n",
    "        (['LUNG_CHRONIC', 'AUTOIMMUNE'],\n",
    "         ['prednisone _po', 'dexamethasone _po']),\n",
    "    ]\n",
    "\n",
    "    for drivers, treatments in rules:\n",
    "        present = [col for col in treatments if col in treatment_columns]\n",
    "        if not present:\n",
    "            continue\n",
    "        # 1 if the patient has any driver comorbidity (computed once per rule)\n",
    "        has_driver = W_syn[drivers].max(axis=1).to_numpy()\n",
    "        for col in present:\n",
    "            nudge[:, treatment_columns.get_loc(col)] += has_driver * strength\n",
    "\n",
    "    return nudge\n",
    "\n",
    "def generate_synthetic_data_clean_iv(Z_real, X_real, W_real, comorb_type, n_target=50000, seed=None):\n",
    "    \"\"\"Simulate a cohort in which the physician assignment is a clean instrument.\n",
    "\n",
    "    - W (comorbidities): independent Bernoulli draws at the global rates + 0.01,\n",
    "      so they do not depend on the physician.\n",
    "    - Z (physician): drawn uniformly from the physicians present in Z_real.\n",
    "    - X (treatments): Bernoulli with the assigned physician's empirical usage\n",
    "      rates (clipped to [0.01, 0.99]) plus a comorbidity-driven nudge.\n",
    "    - Sickness: W @ comorbidity weights, used as the outcome baseline.\n",
    "\n",
    "    Uses the notebook globals comorb_weights_dict and get_smart_sickness_nudge.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    seed : int or None\n",
    "        Seed for np.random.default_rng; pass an int for a reproducible cohort.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    X_syn, Z_syn, W_syn, patient_sickness, doc_assignments_ser\n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    indx = np.arange(n_target)\n",
    "\n",
    "    # --- STEP 1: Generate Comorbidities (Independent of Physician) ---\n",
    "    global_comorb_probs = W_real.mean() + 0.01\n",
    "    W_syn_raw = (rng.random((n_target, W_real.shape[1])) < global_comorb_probs.values).astype(int)\n",
    "    W_syn = pd.DataFrame(W_syn_raw, columns=W_real.columns, index=indx)\n",
    "\n",
    "    # --- STEP 2: Generate Physician Assignments (The Instrument) ---\n",
    "    real_doc = Z_real.idxmax(axis=1)  # physician column of each real patient\n",
    "    phys_ids = real_doc.unique()\n",
    "    doc_assignments = rng.choice(phys_ids, size=n_target)\n",
    "    doc_assignments_ser = pd.Series(doc_assignments, index=indx)\n",
    "    Z_syn = pd.get_dummies(doc_assignments_ser, prefix='Physician')\n",
    "\n",
    "    # --- STEP 3: Generate Treatments (Related to W and Z) ---\n",
    "    # Physician style component: empirical usage rate per physician\n",
    "    physician_probs = X_real.groupby(real_doc).mean().clip(0.01, 0.99)\n",
    "    base_probs = physician_probs.loc[doc_assignments].values\n",
    "\n",
    "    # Sickness bias component (confounding between W and X)\n",
    "    sickness_nudge = get_smart_sickness_nudge(W_syn, X_real.columns)\n",
    "\n",
    "    X_probs = np.clip(base_probs + sickness_nudge, 0.01, 0.99)\n",
    "    X_syn_raw = (rng.random((n_target, X_real.shape[1])) < X_probs).astype(int)\n",
    "    X_syn = pd.DataFrame(X_syn_raw, columns=X_real.columns, index=indx)\n",
    "\n",
    "    # --- STEP 4: Calculate Sickness Penalty (Related to W and Y) ---\n",
    "    weights_vec = np.array([comorb_weights_dict.get(c, 5) for c in comorb_type])\n",
    "    patient_sickness = pd.Series(W_syn[comorb_type].values @ weights_vec,\n",
    "                                 index=indx)\n",
    "\n",
    "    return X_syn, Z_syn, W_syn, patient_sickness, doc_assignments_ser\n",
    "\n",
    "# Draw one synthetic cohort, then simulate its noisy length of stay.\n",
    "(X_syn, Z_syn, W_syn,\n",
    " sickness_score_syn, doc_ids_syn) = generate_synthetic_data_clean_iv(\n",
    "    data[\"Z\"], data[\"X\"], W.loc[data[\"Z\"].index], comorb_type\n",
    ")\n",
    "_, Y_syn_df = theta(X_syn, kde, noise=True, baseline=sickness_score_syn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LassoCV\n",
    "\n",
    "# Pick the LASSO penalty by cross-validation on the synthetic cohort drawn in\n",
    "# the previous cell (re-drawing here would silently replace that cohort).\n",
    "\n",
    "# 1. Scale\n",
    "scaler = StandardScaler()\n",
    "X_scaled = scaler.fit_transform(X_syn)\n",
    "\n",
    "# 2. Fit with 5-fold cross-validation\n",
    "model = LassoCV(cv=5, random_state=42).fit(X_scaled, Y_syn_df)\n",
    "\n",
    "# 3. See which features survived\n",
    "relevant_features = int(np.count_nonzero(model.coef_))\n",
    "print(f\"Optimal Lambda: {model.alpha_}\")\n",
    "print(f\"Features retained: {relevant_features}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import GroupKFold\n",
    "from sklearn.linear_model import LinearRegression, Lasso\n",
    "\n",
    "def get_improvement(model, X_train_scaled, Y_train, Z_train, X_test_raw, X_scaler,\n",
    "                    alpha, threshold, theta_func, baseline_score, noisemodel=None):\n",
    "    \"\"\"Fit `model`, push test patients against its effect direction, and score it.\n",
    "\n",
    "    Models whose fit() takes Z (the IV models) are fitted on (Z, X, Y); others on\n",
    "    (X, Y). The unit-norm direction is the latent theta mapped through model.B\n",
    "    when the model has gettheta(), otherwise coef_. Each scaled test row moves\n",
    "    by -alpha along it, is mapped back to raw units and re-binarized with\n",
    "    > threshold, and noise-free outcomes from theta_func are compared.\n",
    "\n",
    "    noisemodel is only forwarded to theta_func, which is always called with\n",
    "    noise=False here, so the default None avoids reading a notebook global.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    Series indexed like X_test_raw: Y(original) - Y(perturbed), i.e. hours saved.\n",
    "    \"\"\"\n",
    "    # 1. Fit Model\n",
    "    # Some models need Z (Latent IV), some don't (OLS/Lasso)\n",
    "    if 'Z' in model.fit.__code__.co_varnames:\n",
    "        model.fit(Z_train, X_train_scaled, Y_train)\n",
    "    else:\n",
    "        model.fit(X_train_scaled, Y_train.ravel())\n",
    "\n",
    "    # 2. Get Improvement Direction (unit norm; 1e-9 guards against a zero vector)\n",
    "    if hasattr(model, 'gettheta'):\n",
    "        theta_vec = model.gettheta()\n",
    "        direction_latent = theta_vec.T / (np.linalg.norm(theta_vec) + 1e-9)\n",
    "        direction_X = direction_latent @ model.B.T\n",
    "    else:\n",
    "        coeffs = model.coef_\n",
    "        direction_X = coeffs / (np.linalg.norm(coeffs) + 1e-9)\n",
    "\n",
    "    # 3. Apply Perturbation in scaled space, then return to raw units\n",
    "    X_test_scaled = X_scaler.transform(X_test_raw)\n",
    "    X_prime_scaled = X_test_scaled - (alpha * direction_X)\n",
    "    X_prime_raw = X_scaler.inverse_transform(X_prime_scaled)\n",
    "\n",
    "    # Reconstruct DataFrame and re-binarize\n",
    "    X_prime_df = pd.DataFrame(X_prime_raw, index=X_test_raw.index, columns=X_test_raw.columns)\n",
    "    X_prime_df = X_prime_df > threshold\n",
    "\n",
    "    # 4. Score without noise\n",
    "    _, Y_prime = theta_func(X_prime_df, noisemodel, noise=False, baseline=baseline_score)\n",
    "    _, Y_denoise = theta_func(X_test_raw, noisemodel, noise=False, baseline=baseline_score)\n",
    "\n",
    "    return Y_denoise - Y_prime\n",
    "\n",
    "def kold(doc_ids_syn, T2, Z2, Y2, sickness_score_syn, pnuemonia_patient_account_id, alpha=10, threshold=0.2,\n",
    "         n_splits=5, r=5, lasso_alpha=20):\n",
    "    \"\"\"Grouped K-fold comparison of perturbation improvements across models.\n",
    "\n",
    "    Patients are split with `GroupKFold`, grouped by `doc_ids_syn`, so no group\n",
    "    label appears in both the training and the test part of a fold. In each fold\n",
    "    every model is fitted on the training patients and `get_improvement` scores\n",
    "    the perturbed treatments of the test patients.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    doc_ids_syn :\n",
    "        Group label per patient (presumably doctor ids), looked up by patient id.\n",
    "    T2, Z2, Y2 :\n",
    "        Treatments, instruments (passed to latent-IV models as Z) and outcomes;\n",
    "        pandas objects indexed by patient account id (accessed via `.loc`).\n",
    "    sickness_score_syn :\n",
    "        Baseline passed through to `theta` via `get_improvement`.\n",
    "    pnuemonia_patient_account_id : pd.Index\n",
    "        Patient account ids to cross-validate over.\n",
    "    alpha, threshold : float\n",
    "        Perturbation step size and binarisation cut-off (see `get_improvement`).\n",
    "    n_splits : int\n",
    "        Number of GroupKFold folds.\n",
    "    r : int\n",
    "        Latent rank for the LIRR and PCA models.\n",
    "    lasso_alpha : float\n",
    "        Regularisation strength of the Lasso baseline.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    final_stats : dict\n",
    "        {model_name: {'mean': float, 'std': float}} over all test patients,\n",
    "        pooled across folds (std is the population std, ddof=0).\n",
    "    results : dict\n",
    "        {model_name: [one improvement array per fold]}.\n",
    "    \"\"\"\n",
    "    gkf = GroupKFold(n_splits=n_splits)\n",
    "    groups = doc_ids_syn[pnuemonia_patient_account_id]\n",
    "\n",
    "    def _make_models():\n",
    "        \"\"\"Return fresh, unfitted instances of every compared model.\"\"\"\n",
    "        return {\n",
    "            'LIRR': LIRR(r=r, intercept=False),\n",
    "            'PCA': PCAMethod(r=r, intercept=False),\n",
    "            'OLS': LinearRegression(fit_intercept=False),\n",
    "            'LASSO': Lasso(alpha=lasso_alpha, fit_intercept=False),\n",
    "        }\n",
    "\n",
    "    # {model_name: [per-fold improvement arrays]}; keys are inserted in model order.\n",
    "    results = {}\n",
    "\n",
    "    for train_idx, test_idx in gkf.split(pnuemonia_patient_account_id, groups=groups):\n",
    "        train_ids = pnuemonia_patient_account_id[train_idx]\n",
    "        test_ids = pnuemonia_patient_account_id[test_idx]\n",
    "\n",
    "        # Shared per-fold data; the scaler is fitted on training rows only.\n",
    "        T_train = T2.loc[train_ids]\n",
    "        X_scaler = StandardScaler(with_mean=False).fit(T_train)\n",
    "        X_train_scaled = X_scaler.transform(T_train)\n",
    "        Y_train = Y2.loc[train_ids].values.reshape(-1, 1)\n",
    "        Z_train = Z2.loc[train_ids].values\n",
    "        X_test_raw = T2.loc[test_ids]\n",
    "\n",
    "        # Rebuild the models every fold so no fitted state carries over between folds.\n",
    "        for name, model in _make_models().items():\n",
    "            improvement = get_improvement(\n",
    "                model, X_train_scaled, Y_train, Z_train, X_test_raw,\n",
    "                X_scaler, alpha, threshold, theta, sickness_score_syn\n",
    "            )\n",
    "            results.setdefault(name, []).append(improvement)\n",
    "\n",
    "    # Pool every test patient across folds and summarise per model.\n",
    "    final_stats = {}\n",
    "    for name, improvements in results.items():\n",
    "        all_imp = np.concatenate(improvements)\n",
    "        m_score = np.mean(all_imp)\n",
    "        s_score = np.std(all_imp)\n",
    "        final_stats[name] = {'mean': m_score, 'std': s_score}\n",
    "        print(f\"{name: <6} | Mean: {m_score: .4f} hrs | Std: {s_score: .4f}\")\n",
    "\n",
    "    # Per-model summary statistics plus the raw per-fold improvements.\n",
    "    return final_stats, results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simulate a synthetic cohort from the observed data and its comorbidity columns.\n",
    "X_syn, Z_syn, W_syn, sickness_score_syn, doc_ids_syn = generate_synthetic_data_clean_iv(\n",
    "    data[\"Z\"], data[\"X\"], W.loc[data[\"Z\"].index], comorb_type\n",
    ")\n",
    "\n",
    "# Noisy outcomes for the synthetic treatments (used as the training targets below).\n",
    "_, Y_syn_df = theta(X_syn, kde, noise=True, baseline=sickness_score_syn)\n",
    "\n",
    "# Grouped cross-validated comparison of all models.\n",
    "full_stats, improvements = kold(\n",
    "    doc_ids_syn=doc_ids_syn,\n",
    "    T2=X_syn,\n",
    "    Z2=Z_syn,\n",
    "    Y2=Y_syn_df,\n",
    "    sickness_score_syn=sickness_score_syn,\n",
    "    pnuemonia_patient_account_id=X_syn.index,\n",
    "    alpha=10,\n",
    "    threshold=0.2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# kold summary per model: {name: {'mean': ..., 'std': ...}} of hours saved, pooled across folds\n",
    "full_stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "def plot_combined_kde(improvements_dict, xlim=(-5, 60), save_path='combined_model_kde.png'):\n",
    "    \"\"\"Overlay one KDE per model of per-patient hours saved, pooled over CV folds.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    improvements_dict : dict\n",
    "        {model_name: [per-fold improvement arrays]}, as returned by `kold`.\n",
    "    xlim : tuple of float\n",
    "        x-axis limits, in hours.\n",
    "    save_path : str or None\n",
    "        File to write the figure to (300 dpi); `None` skips saving.\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(figsize=(8, 6))\n",
    "\n",
    "    # One distinct colour per model, actually applied to each curve.\n",
    "    palette = sns.color_palette(\"husl\", len(improvements_dict))\n",
    "\n",
    "    for color, (model_key, data_list) in zip(palette, improvements_dict.items()):\n",
    "        # Pool the cross-validation folds into a single array.\n",
    "        all_data = np.concatenate(data_list)\n",
    "        mu = np.mean(all_data)\n",
    "        sns.kdeplot(all_data, ax=ax, color=color, fill=True,\n",
    "                    label=f'{model_key} (μ: {mu:.2f} hrs)',\n",
    "                    linewidth=2, alpha=0.5, bw_adjust=3)\n",
    "\n",
    "    # Reference line: mass to the right of 0 means recovery time was saved.\n",
    "    ax.axvline(0, color='black', linestyle='--', alpha=0.6, label='No Improvement')\n",
    "\n",
    "    ax.set_title('Comparison of Distribution of Saved Recovery Time', fontsize=16)\n",
    "    ax.set_xlabel('Hours Saved', fontsize=13)\n",
    "    ax.set_ylabel('Density', fontsize=13)\n",
    "    ax.set_xlim(*xlim)\n",
    "    ax.grid(axis='y', alpha=0.3)\n",
    "    ax.legend(title=\"Models\", loc='upper right')\n",
    "\n",
    "    fig.tight_layout()\n",
    "    if save_path is not None:\n",
    "        fig.savefig(save_path, dpi=300)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "plot_combined_kde(improvements)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ipca_test",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
