{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aP4dGQjT5A4G"
      },
      "outputs": [],
      "source": [
        "#non-private Optuna Best weight-decay =  0.00027258944899542076, Best learning rate: 0.00047755386467452173\n",
        "import pandas as pd\n",
        "from sklearn.model_selection import train_test_split\n",
        "from sklearn.metrics import accuracy_score, classification_report\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import opacus\n",
        "from opacus.accountants.utils import get_noise_multiplier\n",
        "from opacus.utils.batch_memory_manager import BatchMemoryManager\n",
        "from torch.utils.data import DataLoader, TensorDataset\n",
        "from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, classification_report, confusion_matrix\n",
        "#import matplotlib.pyplot as plt\n",
        "import fire\n",
        "import numpy as np\n",
        "from tqdm.auto import tqdm\n",
        "from transformers import get_linear_schedule_with_warmup\n",
        "import optuna\n",
        "import torch.nn.functional as F\n",
        "import wandb\n",
        "from sklearn.preprocessing import MinMaxScaler\n",
        "import random\n",
        "\n",
        "\n",
        "# ================================\n",
        "# User Configurable Variables\n",
        "# ================================\n",
        "\n",
        "# Seed handed to seed_everything() right after its definition below.\n",
        "SEED = 42\n",
        "# Paths\n",
        "# Gzipped model-ready GOSSIS-1 eICU table; the training pipeline (doc_classification) reads it as the main feature/label frame.\n",
        "eIcu1_PATH = \"/eICUGOSSIS/physionet.org/files/gossis-1-eicu/1.0.0/gossis-1-eicu-only-model-ready.csv.gz\"\n",
        "# Companion GOSSIS-1 eICU table; only its ethnicity and gender columns are merged into the main frame (joined on patientunitstayid and encounter_id).\n",
        "eIcu2_PATH = \"/eICUGOSSIS/physionet.org/files/gossis-1-eicu/1.0.0/gossis-1-eicu-only.csv.gz\"\n",
        "\n",
        "\n",
        "def seed_everything(seed=42):\n",
        "    \"\"\"Make runs repeatable by seeding every RNG this notebook touches.\n",
        "\n",
        "    Seeds Python's ``random`` module, NumPy's global generator and PyTorch\n",
        "    (CPU plus current and all CUDA devices) with ``seed``, and puts cuDNN in\n",
        "    deterministic mode with benchmark autotuning switched off.\n",
        "\n",
        "    Args:\n",
        "        seed: integer seed shared by all generators (default 42).\n",
        "    \"\"\"\n",
        "    # The cuDNN flags are independent of the seeding calls, so set them first.\n",
        "    torch.backends.cudnn.benchmark = False\n",
        "    torch.backends.cudnn.deterministic = True\n",
        "\n",
        "    seeders = (\n",
        "        random.seed,\n",
        "        np.random.seed,\n",
        "        torch.manual_seed,\n",
        "        torch.cuda.manual_seed,\n",
        "        torch.cuda.manual_seed_all,\n",
        "    )\n",
        "    for apply_seed in seeders:\n",
        "        apply_seed(seed)\n",
        "\n",
        "# Seed every RNG once, when the cell runs, before any data split or model is built.\n",
        "seed_everything(SEED)\n",
        "\n",
        "\n",
        "class EarlyStopping:\n",
        "    \"\"\"Patience-based early stopping for a score where higher is better.\n",
        "\n",
        "    Call the instance once per evaluation with the latest score. A score\n",
        "    counts as an improvement unless it is strictly below\n",
        "    ``best_score + threshold``. After ``patience`` consecutive\n",
        "    non-improving calls ``early_stop`` becomes True and stays True.\n",
        "\n",
        "    Attributes (all public, read by callers):\n",
        "        patience: number of consecutive non-improving calls tolerated.\n",
        "        threshold: minimum gain over ``best_score`` that counts as progress.\n",
        "        counter: current run of non-improving calls.\n",
        "        best_score: reference score, None until the first call.\n",
        "        early_stop: sticky stop flag returned by every call.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, patience=3, threshold=0.0):\n",
        "        self.patience, self.threshold = patience, threshold\n",
        "        self.counter = 0\n",
        "        self.best_score = None\n",
        "        self.early_stop = False\n",
        "\n",
        "    def __call__(self, val_score):\n",
        "        \"\"\"Record ``val_score`` and return the (possibly updated) stop flag.\"\"\"\n",
        "        if self.best_score is None:\n",
        "            # First observation: nothing to compare against yet.\n",
        "            self.best_score = val_score\n",
        "            return self.early_stop\n",
        "\n",
        "        stalled = val_score < self.best_score + self.threshold\n",
        "        if not stalled:\n",
        "            # Enough progress: adopt the new reference and restart the count.\n",
        "            self.best_score = val_score\n",
        "            self.counter = 0\n",
        "            return self.early_stop\n",
        "\n",
        "        self.counter += 1\n",
        "        if self.counter >= self.patience:\n",
        "            self.early_stop = True\n",
        "        return self.early_stop\n",
        "\n",
        "def compute_metrics(logits, labels):\n",
        "    \"\"\"Compute binary-classification metrics from raw two-class logits.\n",
        "\n",
        "    Args:\n",
        "        logits: list, NumPy array or tensor of per-sample logits. The last\n",
        "            axis must hold at least two entries because column 1 is used as\n",
        "            the positive-class score.\n",
        "        labels: ground-truth 0/1 labels as a list, NumPy array or tensor.\n",
        "\n",
        "    Returns:\n",
        "        dict with keys ``accuracy``, ``precision``, ``recall``, ``f1`` and\n",
        "        ``roc_auc``.\n",
        "    \"\"\"\n",
        "    # Normalise the input exactly once. as_tensor avoids the copy (and the\n",
        "    # UserWarning) of calling torch.tensor() on an existing tensor, and\n",
        "    # detach()/cpu() keep .numpy() valid for CUDA or grad-tracking tensors.\n",
        "    logits = torch.as_tensor(logits).detach().float().cpu()\n",
        "    if torch.is_tensor(labels):\n",
        "        labels = labels.detach().cpu().numpy()\n",
        "\n",
        "    probabilities = torch.softmax(logits, dim=-1).numpy()\n",
        "    # Take the arg-max on the tensor itself instead of handing a tensor to NumPy.\n",
        "    predictions = logits.argmax(dim=-1).numpy()\n",
        "\n",
        "    return {\n",
        "        \"accuracy\": accuracy_score(labels, predictions),\n",
        "        \"precision\": precision_score(labels, predictions, average='binary'),\n",
        "        \"recall\": recall_score(labels, predictions, average='binary'),\n",
        "        \"f1\": f1_score(labels, predictions, average='binary'),\n",
        "        \"roc_auc\": roc_auc_score(labels, probabilities[:, 1]),\n",
        "    }\n",
        "\n",
        "\n",
        "\n",
        "class ICUModel(nn.Module):\n",
        "    \"\"\"Fully connected network mapping a feature vector to two class logits.\n",
        "\n",
        "    Layout: Linear(input_size, 128) -> BatchNorm1d -> ReLU ->\n",
        "    Linear(128, 64) -> BatchNorm1d -> ReLU -> Linear(64, 32) -> ReLU ->\n",
        "    Linear(32, 2).\n",
        "\n",
        "    Args:\n",
        "        input_size: number of input features per sample.\n",
        "        dropout_p: drop probability of the ``dropout`` layer (default 0.3).\n",
        "        apply_dropout: when True, ``dropout`` is applied after every hidden\n",
        "            activation. The default False reproduces the original forward\n",
        "            pass, in which the dropout layer was constructed but never\n",
        "            called.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, input_size, dropout_p=0.3, apply_dropout=False):\n",
        "        super().__init__()\n",
        "        # Sub-module names and creation order are kept so state_dict keys and\n",
        "        # seeded weight initialisation stay the same as before.\n",
        "        self.fc1 = nn.Linear(input_size, 128)\n",
        "        self.bn1 = nn.BatchNorm1d(128)\n",
        "        self.fc2 = nn.Linear(128, 64)\n",
        "        self.bn2 = nn.BatchNorm1d(64)\n",
        "        self.fc3 = nn.Linear(64, 32)\n",
        "        self.fc4 = nn.Linear(32, 2)\n",
        "        self.relu = nn.ReLU()\n",
        "        self.dropout = nn.Dropout(dropout_p)\n",
        "        self.apply_dropout = apply_dropout\n",
        "\n",
        "    def _maybe_drop(self, x):\n",
        "        \"\"\"Apply dropout only when the model was built with apply_dropout=True.\"\"\"\n",
        "        return self.dropout(x) if self.apply_dropout else x\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Return raw (unnormalised) logits with 2 entries per input row.\"\"\"\n",
        "        x = self._maybe_drop(self.relu(self.bn1(self.fc1(x))))\n",
        "        x = self._maybe_drop(self.relu(self.bn2(self.fc2(x))))\n",
        "        x = self._maybe_drop(self.relu(self.fc3(x)))\n",
        "        return self.fc4(x)\n",
        "\n",
        "\n",
        "\n",
        "def test_model(model, dataloader):\n",
        "    \"\"\"Evaluate ``model`` on ``dataloader`` and report binary-classification metrics.\n",
        "\n",
        "    The class-weighted cross-entropy is summed over all samples and divided\n",
        "    by the number of samples, so the reported loss does not depend on the\n",
        "    batch size or on a smaller final batch.\n",
        "\n",
        "    Args:\n",
        "        model: torch module returning two-class logits for a feature batch.\n",
        "            It is switched to eval mode and left in eval mode.\n",
        "        dataloader: iterable of ``(features, labels)`` tensor batches.\n",
        "\n",
        "    Returns:\n",
        "        dict with keys ``loss``, ``accuracy``, ``precision``, ``recall``,\n",
        "        ``f1`` and ``roc_auc``. Every value is NaN when the loader yields no\n",
        "        samples; ``roc_auc`` alone is NaN when only one class is present.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "\n",
        "    # Evaluate on the device the model already lives on; picking a device\n",
        "    # independently would break for a CPU model on a CUDA-capable machine.\n",
        "    first_param = next(model.parameters(), None)\n",
        "    if first_param is not None:\n",
        "        device = first_param.device\n",
        "    else:\n",
        "        device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    class_weights = torch.tensor([0.5, 1.0]).to(device)\n",
        "    criterion = nn.CrossEntropyLoss(weight=class_weights,reduction=\"sum\")\n",
        "\n",
        "    total_loss = 0.0\n",
        "    all_preds, all_labels, all_probs = [], [], []\n",
        "    with torch.no_grad():\n",
        "        for X_batch, y_batch in dataloader:\n",
        "            X_batch = X_batch.to(device)\n",
        "            y_batch = y_batch.to(device)\n",
        "            outputs = model(X_batch)\n",
        "            total_loss += criterion(outputs, y_batch).item()\n",
        "\n",
        "            all_preds.extend(outputs.argmax(dim=1).cpu().numpy())\n",
        "            all_labels.extend(y_batch.cpu().numpy())\n",
        "            all_probs.extend(torch.softmax(outputs, dim=1)[:, 1].cpu().numpy())\n",
        "\n",
        "    print(\"Test Results:\")\n",
        "    nan = float(\"nan\")\n",
        "    n_samples = len(all_labels)\n",
        "    if n_samples == 0:\n",
        "        # Nothing to score (e.g. an empty subgroup): report NaNs instead of\n",
        "        # failing with a division by zero.\n",
        "        print(\"No samples to evaluate.\")\n",
        "        return {\"loss\": nan, \"accuracy\": nan, \"precision\": nan, \"recall\": nan, \"f1\": nan, \"roc_auc\": nan}\n",
        "\n",
        "    # The criterion sums over samples, so normalise by the sample count.\n",
        "    avg_loss = total_loss / n_samples\n",
        "    accuracy = accuracy_score(all_labels, all_preds)\n",
        "    precision = precision_score(all_labels, all_preds, average='binary')\n",
        "    recall = recall_score(all_labels, all_preds, average='binary')\n",
        "    f1 = f1_score(all_labels, all_preds, average='binary')\n",
        "    # AUROC is undefined when only one class is present in the labels.\n",
        "    if len(np.unique(all_labels)) < 2:\n",
        "        roc_auc = nan\n",
        "    else:\n",
        "        roc_auc = roc_auc_score(all_labels, all_probs)\n",
        "\n",
        "    print(f\"Loss: {avg_loss:.4f} | Accuracy: {accuracy:.4f} | Precision: {precision:.4f} | Recall: {recall:.4f} | F1: {f1:.4f} | AUROC: {roc_auc:.4f}\")\n",
        "    return {\n",
        "        \"loss\": avg_loss,\n",
        "        \"accuracy\": accuracy,\n",
        "        \"precision\": precision,\n",
        "        \"recall\": recall,\n",
        "        \"f1\": f1,\n",
        "        \"roc_auc\": roc_auc\n",
        "    }\n",
        "\n",
        "# Private helpers used by doc_classification\n",
        "def _load_cohort():\n",
        "    \"\"\"Load, clean, class-balance and encode the eICU cohort.\n",
        "\n",
        "    Returns:\n",
        "        pandas.DataFrame holding the unscaled feature columns, the binary\n",
        "        ``gender`` (M=1, F=0), ``ethnicity`` (Caucasian=1, else 0) and\n",
        "        ``age_median_gp`` (1 when age >= cohort median) columns, and the\n",
        "        ``hospital_death`` target.\n",
        "    \"\"\"\n",
        "    df = pd.read_csv(eIcu1_PATH, compression='gzip')\n",
        "\n",
        "    # Drop rows recorded as an ICU death but not as a hospital death.\n",
        "    df = df[~((df[\"icu_death\"] == 1) & (df[\"hospital_death\"] == 0))]\n",
        "    df = df.dropna()\n",
        "    df = pd.get_dummies(df, columns=[\"dx_class\", \"dx_sub\", \"dcs_group\", \"icu_admit_source\", \"group\"], drop_first=True)\n",
        "\n",
        "    # Add gender and ethnicity from the companion table.\n",
        "    demographics = pd.read_csv(eIcu2_PATH)\n",
        "    df = pd.merge(df, demographics[['patientunitstayid', 'encounter_id', 'ethnicity', 'gender']], on=['patientunitstayid', 'encounter_id'], how='inner')\n",
        "    df = df.dropna()\n",
        "    # Keep only genders the binary encoding below can represent; any other\n",
        "    # value would turn into NaN in .map() and feed NaNs to the model.\n",
        "    df = df[df['gender'].isin(['M', 'F'])]\n",
        "\n",
        "    # Drop identifier/partition columns and icu_death so they are not used as features.\n",
        "    df = df.drop(columns=[\"patientunitstayid\", \"encounter_id\", \"partition\", \"icu_death\"])\n",
        "\n",
        "    # Under-sample class 0 to the size of class 1, then shuffle.\n",
        "    df_class_1 = df[df['hospital_death'] == 1]\n",
        "    df_class_0 = df[df['hospital_death'] == 0].sample(n=len(df_class_1), random_state=42)\n",
        "    df = pd.concat([df_class_0, df_class_1]).sample(frac=1, random_state=42).reset_index(drop=True)\n",
        "\n",
        "    # Binary group columns used for the subgroup evaluation.\n",
        "    median_age = df['age'].median()\n",
        "    df['age_median_gp'] = (df['age'] >= median_age).astype('int64')\n",
        "    df = df.drop(columns=['age'])\n",
        "    df['gender'] = df['gender'].map({'M': 1, 'F': 0}).astype('int64')\n",
        "    df['ethnicity'] = (df['ethnicity'] == 'Caucasian').astype('int64')\n",
        "    return df\n",
        "\n",
        "\n",
        "def _make_loader(features, labels, batch_size, device, shuffle=False, drop_last=False):\n",
        "    \"\"\"Wrap a feature frame and a label series in a DataLoader whose tensors live on ``device``.\"\"\"\n",
        "    X = torch.tensor(features.astype('float32').values, dtype=torch.float32).to(device)\n",
        "    # reshape(-1) keeps the labels 1-D even for a one-row subset, where\n",
        "    # squeeze() would produce a 0-D tensor that TensorDataset rejects.\n",
        "    y = torch.tensor(labels.values, dtype=torch.long).reshape(-1).to(device)\n",
        "    return DataLoader(TensorDataset(X, y), batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)\n",
        "\n",
        "\n",
        "def _binary_metrics(labels, preds, probs):\n",
        "    \"\"\"Return accuracy, precision, recall, F1 and AUROC for one pass over a split.\"\"\"\n",
        "    return {\n",
        "        \"accuracy\": accuracy_score(labels, preds),\n",
        "        \"precision\": precision_score(labels, preds),\n",
        "        \"recall\": recall_score(labels, preds),\n",
        "        \"f1\": f1_score(labels, preds),\n",
        "        \"roc_auc\": roc_auc_score(labels, probs),\n",
        "    }\n",
        "\n",
        "\n",
        "# Main function: ICU hospital-mortality classification (non-private baseline)\n",
        "def doc_classification(lr=0.00001, epochs=30, batch_size=32, weight_decay=2.9173130489231118e-05, restore_best=False): #0.0005, 0.001\n",
        "    \"\"\"Train ICUModel on the balanced eICU cohort and print overall and subgroup test metrics.\n",
        "\n",
        "    Pipeline: load and encode the cohort, split 70/10/20 into\n",
        "    train/validation/test (stratified), min-max scale the numeric columns\n",
        "    with statistics fitted on the training split only, train with AdamW and\n",
        "    a linear warm-up schedule while logging per-epoch metrics to Weights &\n",
        "    Biases, then evaluate on the test split and on its gender, age and\n",
        "    ethnicity subsets.\n",
        "\n",
        "    Differences from the previous revision that change reported numbers:\n",
        "    the scaler no longer sees validation/test rows, and logged losses are\n",
        "    per-sample means instead of per-batch sums.\n",
        "\n",
        "    Args:\n",
        "        lr: peak learning rate for AdamW.\n",
        "        epochs: number of training epochs (no early stopping is applied).\n",
        "        batch_size: batch size of every DataLoader.\n",
        "        weight_decay: AdamW weight decay; the default is the previously\n",
        "            hard-coded value.\n",
        "        restore_best: when True, the weights of the epoch with the best\n",
        "            validation F1 are restored before testing. The default False\n",
        "            keeps testing the final-epoch weights, as before.\n",
        "\n",
        "    Returns:\n",
        "        None. Results are printed and logged.\n",
        "    \"\"\"\n",
        "    wandb.init(\n",
        "        project=\"icu_mortality_prediction\",\n",
        "        name=\"icu_run_with_batchnorm_dropout_normalization\",\n",
        "        config={\"learning_rate\": lr, \"epochs\": epochs, \"batch_size\": batch_size, \"weight_decay\": weight_decay, \"restore_best\": restore_best}\n",
        "    )\n",
        "\n",
        "    df = _load_cohort()\n",
        "    X = df.drop(columns=[\"hospital_death\"])  # All features except target\n",
        "    y = df[\"hospital_death\"]  # Target variable (0 = Survived, 1 = Died)\n",
        "\n",
        "    # First, split data into 70% train and 30% temporary (test + val)\n",
        "    X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.3, stratify=y, random_state=42)\n",
        "    # Split the 30% temp data into 10% validation and 20% test\n",
        "    X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=2/3, stratify=y_temp, random_state=42)  # 2/3 of 30% = 20%\n",
        "\n",
        "    print(f\"Train: {len(X_train)} samples\")\n",
        "    print(f\"Validation: {len(X_val)} samples\")\n",
        "    print(f\"Test: {len(X_test)} samples\")\n",
        "\n",
        "    # Subgroup masks are taken from the unscaled test frame, where the group\n",
        "    # columns are guaranteed to hold exact 0/1 values.\n",
        "    subgroups = [\n",
        "        (\"\\n=== Test on Female Subset ===\", X_test[\"gender\"] == 0),\n",
        "        (\"\\n=== Test on Male Subset ===\", X_test[\"gender\"] == 1),\n",
        "        (\"\\n=== Test on age below median Subset ===\", X_test[\"age_median_gp\"] == 0),\n",
        "        (\"\\n=== Test on age above Subset ===\", X_test[\"age_median_gp\"] == 1),\n",
        "        (\"\\n=== Test on White Subset ===\", X_test[\"ethnicity\"] == 1),\n",
        "        (\"\\n=== Test on non-White Subset ===\", X_test[\"ethnicity\"] == 0),\n",
        "    ]\n",
        "\n",
        "    # Fit the scaler on the training split only, so no validation/test\n",
        "    # statistics leak into training, then apply it to all three splits.\n",
        "    numeric_cols = X_train.select_dtypes(include=['int64', 'float64']).columns\n",
        "    scaler = MinMaxScaler()\n",
        "    X_train, X_val, X_test = X_train.copy(), X_val.copy(), X_test.copy()\n",
        "    X_train[numeric_cols] = scaler.fit_transform(X_train[numeric_cols])\n",
        "    X_val[numeric_cols] = scaler.transform(X_val[numeric_cols])\n",
        "    X_test[numeric_cols] = scaler.transform(X_test[numeric_cols])\n",
        "\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    # BatchNorm1d cannot normalise a training batch of a single sample, so\n",
        "    # drop the trailing batch only when it would contain exactly one row.\n",
        "    drop_single = len(X_train) % batch_size == 1\n",
        "    train_loader = _make_loader(X_train, y_train, batch_size, device, shuffle=True, drop_last=drop_single)\n",
        "    val_loader = _make_loader(X_val, y_val, batch_size, device)\n",
        "    test_loader = _make_loader(X_test, y_test, batch_size, device)\n",
        "\n",
        "    # Fixed per-class loss weights (class 0, class 1).\n",
        "    class_weights = torch.tensor([0.5, 1.0]).to(device)\n",
        "\n",
        "    # Initialize model on the selected device\n",
        "    model = ICUModel(input_size=X_train.shape[1]).to(device)\n",
        "    criterion = nn.CrossEntropyLoss(weight=class_weights,reduction=\"sum\")\n",
        "    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)\n",
        "\n",
        "    num_training_steps = len(train_loader) * epochs\n",
        "    num_warmup_steps = int(0.1 * num_training_steps)\n",
        "    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)\n",
        "\n",
        "    # Early stopping is intentionally not applied: every run trains for the\n",
        "    # full number of epochs. The EarlyStopping class can be called on the\n",
        "    # validation F1 inside the loop if that is wanted.\n",
        "    best_val_f1 = 0\n",
        "    best_model_state = None\n",
        "\n",
        "    for epoch in range(epochs):\n",
        "        # === TRAINING ===\n",
        "        model.train()\n",
        "        total_loss = 0.0\n",
        "        all_preds, all_labels, all_probs = [], [], []\n",
        "        for X_batch, y_batch in tqdm(train_loader, desc=f\"Epoch {epoch+1} Train\"):\n",
        "            optimizer.zero_grad()\n",
        "            logits = model(X_batch)\n",
        "            loss = criterion(logits, y_batch)\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "            scheduler.step()\n",
        "\n",
        "            total_loss += loss.item()\n",
        "            all_preds.extend(logits.argmax(dim=1).detach().cpu().numpy())\n",
        "            all_labels.extend(y_batch.detach().cpu().numpy())\n",
        "            all_probs.extend(torch.softmax(logits, dim=1)[:, 1].detach().cpu().numpy())\n",
        "\n",
        "        train_metrics = _binary_metrics(all_labels, all_preds, all_probs)\n",
        "        wandb.log({\n",
        "            # The criterion sums over samples, so divide by the sample count.\n",
        "            \"train/loss\": total_loss / max(len(all_labels), 1),\n",
        "            \"train/accuracy\": train_metrics[\"accuracy\"],\n",
        "            \"train/precision\": train_metrics[\"precision\"],\n",
        "            \"train/recall\": train_metrics[\"recall\"],\n",
        "            \"train/f1\": train_metrics[\"f1\"],\n",
        "            \"train/roc_auc\": train_metrics[\"roc_auc\"],\n",
        "            \"train_learning_rate\": scheduler.get_last_lr()[0],\n",
        "            \"train_epoch\": epoch + 1\n",
        "        })\n",
        "\n",
        "        # === VALIDATION ===\n",
        "        model.eval()\n",
        "        val_loss = 0.0\n",
        "        val_preds, val_labels, val_probs = [], [], []\n",
        "        with torch.no_grad():\n",
        "            for X_batch, y_batch in val_loader:\n",
        "                outputs = model(X_batch)\n",
        "                val_loss += criterion(outputs, y_batch).item()\n",
        "                val_preds.extend(outputs.argmax(dim=1).cpu().numpy())\n",
        "                val_labels.extend(y_batch.cpu().numpy())\n",
        "                val_probs.extend(torch.softmax(outputs, dim=1)[:, 1].cpu().numpy())\n",
        "\n",
        "        val_metrics = _binary_metrics(val_labels, val_preds, val_probs)\n",
        "        wandb.log({\n",
        "            \"val/loss\": val_loss / max(len(val_labels), 1),\n",
        "            \"val/accuracy\": val_metrics[\"accuracy\"],\n",
        "            \"val/precision\": val_metrics[\"precision\"],\n",
        "            \"val/recall\": val_metrics[\"recall\"],\n",
        "            \"val/f1\": val_metrics[\"f1\"],\n",
        "            \"val/roc_auc\": val_metrics[\"roc_auc\"],\n",
        "            \"epoch\": epoch + 1\n",
        "        })\n",
        "\n",
        "        if val_metrics[\"f1\"] > best_val_f1:\n",
        "            best_val_f1 = val_metrics[\"f1\"]\n",
        "            # state_dict() hands back references to the live parameters, so\n",
        "            # clone them; otherwise later epochs overwrite the snapshot.\n",
        "            best_model_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}\n",
        "\n",
        "    # Optionally go back to the best validation-F1 weights before testing.\n",
        "    if restore_best and best_model_state is not None:\n",
        "        model.load_state_dict(best_model_state)\n",
        "\n",
        "    print(\"\\n=== Test on Full Test Set ===\")\n",
        "    test_model(model, test_loader)\n",
        "    for title, mask in subgroups:\n",
        "        print(title)\n",
        "        if y_test[mask].nunique() < 2:\n",
        "            # Precision/recall/AUROC are undefined without both classes.\n",
        "            print(\"Skipped: subset does not contain both classes.\")\n",
        "            continue\n",
        "        test_model(model, _make_loader(X_test[mask], y_test[mask], batch_size, device))\n",
        "\n",
        "    # Close the W&B run so a notebook kernel does not keep it open.\n",
        "    wandb.finish()\n",
        "\n",
        "if __name__ == '__main__':\n",
        "    # Inside a Jupyter/Colab kernel sys.argv carries the kernel launch flags\n",
        "    # (e.g. -f <connection file>), which Fire would try to parse as arguments\n",
        "    # of doc_classification and abort. Detect the kernel and call the function\n",
        "    # directly with its defaults; keep the Fire CLI for script execution.\n",
        "    try:\n",
        "        running_in_kernel = 'IPKernelApp' in get_ipython().config\n",
        "    except NameError:\n",
        "        # get_ipython only exists inside IPython sessions.\n",
        "        running_in_kernel = False\n",
        "\n",
        "    if running_in_kernel:\n",
        "        doc_classification()\n",
        "    else:\n",
        "        fire.Fire(doc_classification)"
      ]
    }
  ]
}