{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "waWAvTxhagqP"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from opacus.optimizers import AdaClipDPOptimizer, DPOptimizer\n",
        "from opacus.optimizers.optimizer import (\n",
        "    _check_processed_flag,\n",
        "    _generate_noise,\n",
        "    _mark_as_processed\n",
        ")\n",
        "\n",
        "class FairOptimizer(AdaClipDPOptimizer):\n",
        "    \"\"\"AdaClip DP optimizer that rescales per-sample gradients with a smooth tanh factor.\n",
        "\n",
        "    The parent class multiplies each per-sample gradient g_i by the hard clip factor\n",
        "    min(1, C / (||g_i|| + 1e-6)), where C is `max_grad_norm`. This class uses\n",
        "    tanh(C / (||g_i|| + 1e-6)) instead. Since tanh(x) <= x for x >= 0, the rescaled\n",
        "    norm never exceeds C, so the per-sample sensitivity bound is unchanged.\n",
        "    \"\"\"\n",
        "\n",
        "    def clip_and_accumulate(self):\n",
        "        \"\"\"Rescale per-sample gradients and accumulate them into `p.summed_grad`.\n",
        "\n",
        "        Also updates `self.sample_size` and `self.unclipped_num`, the counters the\n",
        "        adaptive (AdaClip) parent presumably uses to move `max_grad_norm` toward the\n",
        "        target unclipped quantile.\n",
        "        \"\"\"\n",
        "        per_param_norms = [\n",
        "            g.view(len(g), -1).norm(2, dim=-1) for g in self.grad_samples\n",
        "        ]\n",
        "        per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)\n",
        "        # Ratio C / norm: the parent clips a sample exactly when this is < 1.\n",
        "        norm_ratio = self.max_grad_norm / (per_sample_norms + 1e-6)\n",
        "        # The fairness change: smooth tanh scaling instead of clamp(max=1.0).\n",
        "        per_sample_clip_factor = torch.tanh(norm_ratio)\n",
        "\n",
        "        # Adaptive-clipping bookkeeping. \"Unclipped\" keeps the parent's meaning\n",
        "        # (norm within the current bound). Testing the tanh factor instead would be\n",
        "        # wrong: tanh(x) < 1 unless the norm is far below C, so almost no sample\n",
        "        # would count as unclipped and the adaptive bound would likely be pushed up.\n",
        "        self.sample_size += len(per_sample_clip_factor)\n",
        "        self.unclipped_num += (\n",
        "            len(norm_ratio) - (norm_ratio < 1).sum()\n",
        "        )\n",
        "\n",
        "        for p in self.params:\n",
        "            _check_processed_flag(p.grad_sample)\n",
        "            grad_sample = self._get_flat_grad_sample(p)\n",
        "            # Weighted sum over the batch dimension: sum_i factor_i * grad_sample_i.\n",
        "            grad = torch.einsum(\"i,i...\", per_sample_clip_factor, grad_sample)\n",
        "\n",
        "            if p.summed_grad is not None:\n",
        "                p.summed_grad += grad\n",
        "            else:\n",
        "                p.summed_grad = grad\n",
        "\n",
        "            _mark_as_processed(p.grad_sample)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from opacus import PrivacyEngine\n",
        "import torch.optim as optim\n",
        "from typing import List, Union\n",
        "from opacus.optimizers import DPOptimizer\n",
        "\n",
        "class fairPrivacyEngine(PrivacyEngine):\n",
        "    \"\"\"PrivacyEngine whose `make_private` wraps the optimizer in a `FairOptimizer`.\n",
        "\n",
        "    `FairOptimizer` extends `AdaClipDPOptimizer`, so only adaptive clipping on a\n",
        "    single (non-distributed) process with a non-ghost grad-sample mode is supported.\n",
        "    \"\"\"\n",
        "\n",
        "    def _prepare_optimizer(\n",
        "        self,\n",
        "        *,\n",
        "        optimizer: optim.Optimizer,\n",
        "        noise_multiplier: float,\n",
        "        max_grad_norm: Union[float, List[float]],\n",
        "        expected_batch_size: int,\n",
        "        loss_reduction: str = \"mean\",\n",
        "        distributed: bool = False,\n",
        "        clipping: str = \"flat\",\n",
        "        noise_generator=None,\n",
        "        grad_sample_mode=\"hooks\",\n",
        "        **kwargs,\n",
        "    ) -> DPOptimizer:\n",
        "        \"\"\"Build the `FairOptimizer` used by `make_private`.\n",
        "\n",
        "        Arguments mirror `PrivacyEngine._prepare_optimizer`. Extra `kwargs` (the\n",
        "        adaptive-clipping settings such as `target_unclipped_quantile`) are\n",
        "        forwarded unchanged to `FairOptimizer`.\n",
        "\n",
        "        Raises:\n",
        "            ValueError: if `clipping` is not \"adaptive\", `distributed` is true, or\n",
        "                `grad_sample_mode` is \"ghost\". These options used to be ignored,\n",
        "                silently yielding an optimizer that did not match the request.\n",
        "        \"\"\"\n",
        "        if clipping != \"adaptive\":\n",
        "            raise ValueError(\n",
        "                \"fairPrivacyEngine requires clipping='adaptive' (FairOptimizer extends \"\n",
        "                f\"AdaClipDPOptimizer), got clipping={clipping!r}\"\n",
        "            )\n",
        "        if distributed:\n",
        "            raise ValueError(\"fairPrivacyEngine does not support distributed=True\")\n",
        "        if grad_sample_mode == \"ghost\":\n",
        "            raise ValueError(\"fairPrivacyEngine does not support grad_sample_mode='ghost'\")\n",
        "\n",
        "        # Re-wrapping an already private optimizer: start from the raw one.\n",
        "        if isinstance(optimizer, DPOptimizer):\n",
        "            optimizer = optimizer.original_optimizer\n",
        "\n",
        "        # Secure mode uses the engine's `secure_rng`; otherwise honour a caller RNG.\n",
        "        generator = None\n",
        "        if self.secure_mode:\n",
        "            generator = self.secure_rng\n",
        "        elif noise_generator is not None:\n",
        "            generator = noise_generator\n",
        "\n",
        "        return FairOptimizer(\n",
        "            optimizer=optimizer,\n",
        "            noise_multiplier=noise_multiplier,\n",
        "            max_grad_norm=max_grad_norm,\n",
        "            expected_batch_size=expected_batch_size,\n",
        "            loss_reduction=loss_reduction,\n",
        "            generator=generator,\n",
        "            secure_mode=self.secure_mode,\n",
        "            **kwargs,\n",
        "        )\n",
        ""
      ],
      "metadata": {
        "id": "rcrWMLsjahx9"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#private\n",
        "import pandas as pd\n",
        "from sklearn.model_selection import train_test_split\n",
        "from sklearn.metrics import accuracy_score, classification_report\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import opacus\n",
        "from opacus.accountants.utils import get_noise_multiplier\n",
        "from opacus.utils.batch_memory_manager import BatchMemoryManager\n",
        "from torch.utils.data import DataLoader, TensorDataset\n",
        "from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, classification_report, confusion_matrix\n",
        "#import matplotlib.pyplot as plt\n",
        "import fire\n",
        "import numpy as np\n",
        "from tqdm.auto import tqdm\n",
        "from transformers import get_linear_schedule_with_warmup\n",
        "import random\n",
        "import os\n",
        "import wandb\n",
        "from sklearn.preprocessing import MinMaxScaler\n",
        "\n",
        "# ================================\n",
        "# User Configurable Variables\n",
        "# ================================\n",
        "\n",
        "SEED = 42\n",
        "# Paths\n",
        "eIcu1_PATH = \"/eICUGOSSIS/physionet.org/files/gossis-1-eicu/1.0.0/gossis-1-eicu-only-model-ready.csv.gz\"   #Select the path that the file is saved\n",
        "eIcu2_PATH = \"/eICUGOSSIS/physionet.org/files/gossis-1-eicu/1.0.0/gossis-1-eicu-only.csv.gz\" #select the path that the file is saved\n",
        "CHECKPOINT_PATH =\"/Checkpoints/120CheckpointseICUNewFairAdaptivelr=0.001_60epoch\" #Select the path for saving chekcpoints\n",
        "\n",
        "\n",
        "def seed_everything(seed=42):\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    torch.cuda.manual_seed(seed)\n",
        "    torch.cuda.manual_seed_all(seed)\n",
        "    # Ensure deterministic behavior in cuDNN\n",
        "    torch.backends.cudnn.deterministic = True\n",
        "    torch.backends.cudnn.benchmark = False\n",
        "\n",
        "seed_everything(SEED)\n",
        "\n",
        "class EarlyStopping:\n",
        "    def __init__(self, patience=3, threshold=0.0):\n",
        "        self.patience = patience\n",
        "        self.threshold = threshold\n",
        "        self.counter = 0\n",
        "        self.best_score = None\n",
        "        self.early_stop = False\n",
        "\n",
        "    def __call__(self, val_score):\n",
        "        if self.best_score is None:\n",
        "            self.best_score = val_score\n",
        "        elif val_score < self.best_score + self.threshold:\n",
        "            self.counter += 1\n",
        "            if self.counter >= self.patience:\n",
        "                self.early_stop = True\n",
        "        else:\n",
        "            self.best_score = val_score\n",
        "            self.counter = 0\n",
        "        return self.early_stop\n",
        "\n",
        "def compute_metrics(logits, labels):\n",
        "    \"\"\"Compute binary-classification metrics from two-class logits.\n",
        "\n",
        "    Args:\n",
        "        logits: Array-like or tensor of shape (n_samples, 2) with raw model outputs;\n",
        "            column 1 is treated as the positive class.\n",
        "        labels: Array-like of n_samples ground-truth labels in {0, 1}.\n",
        "\n",
        "    Returns:\n",
        "        Dict with \"accuracy\", \"precision\", \"recall\", \"f1\" and \"roc_auc\". ROC AUC is\n",
        "        undefined when `labels` holds a single class (possible for a small logical\n",
        "        batch); \"roc_auc\" is then NaN instead of raising.\n",
        "    \"\"\"\n",
        "    # as_tensor + detach avoids the copy-construct warning (and extra copy) that\n",
        "    # torch.tensor(tensor) triggers; cpu() lets GPU tensors be passed directly.\n",
        "    logits = torch.as_tensor(logits).detach().float().cpu()\n",
        "    probabilities = torch.softmax(logits, dim=-1).numpy()\n",
        "    # Take argmax in torch rather than np.argmax on a tensor, which relies on\n",
        "    # numpy's fallback conversion path.\n",
        "    predictions = logits.argmax(dim=-1).numpy()\n",
        "    accuracy = accuracy_score(labels, predictions)\n",
        "    precision = precision_score(labels, predictions, average='binary')\n",
        "    recall = recall_score(labels, predictions, average='binary')\n",
        "    f1 = f1_score(labels, predictions, average='binary')\n",
        "    if np.unique(np.asarray(labels)).size < 2:\n",
        "        roc_auc = float(\"nan\")\n",
        "    else:\n",
        "        roc_auc = roc_auc_score(labels, probabilities[:, 1])\n",
        "    return {\"accuracy\": accuracy, \"precision\": precision, \"recall\": recall, \"f1\": f1, \"roc_auc\": roc_auc}\n",
        "\n",
        "\n",
        "class ICUModel(nn.Module):\n",
        "    def __init__(self, input_size):\n",
        "        super(ICUModel, self).__init__()\n",
        "        self.fc1 = nn.Linear(input_size, 128)\n",
        "        self.norm1 = nn.GroupNorm(32, 128)  # Replaces BatchNorm1d\n",
        "        self.fc2 = nn.Linear(128, 64)\n",
        "        self.norm2 = nn.GroupNorm(32, 64)\n",
        "        self.fc3 = nn.Linear(64, 32)\n",
        "        self.fc4 = nn.Linear(32, 2)\n",
        "        self.relu = nn.ReLU()\n",
        "        self.dropout = nn.Dropout(0.3)\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.relu(self.norm1(self.fc1(x)))\n",
        "        x = self.relu(self.norm2(self.fc2(x)))\n",
        "        x = self.relu(self.fc3(x))\n",
        "        x = self.fc4(x)\n",
        "        return x\n",
        "\n",
        "\n",
        "\n",
        "# NOTE(review): inside `main` this name is shadowed by its own `micro_batch_size`\n",
        "# parameter (which is also reassigned to 32 there), so `main` never reads this value.\n",
        "micro_batch_size = 32\n",
        "\n",
        "def main(lr=0.001, epochs=60, epsilon=8, max_grad_norm=0.1, batch_size = 64 ,micro_batch_size=32):\n",
        "\n",
        "    wandb.init(\n",
        "        project=\"icu_mortality_prediction\",\n",
        "        name=\"icu_FairAdaptive_privacy_lr=0.001\",\n",
        "        config={\"learning_rate\": lr, \"epochs\": epochs, \"batch_size\": batch_size}\n",
        "    )\n",
        "\n",
        "    checkpoint_dir = CHECKPOINT_PATH\n",
        "    os.makedirs(checkpoint_dir, exist_ok=True)\n",
        "\n",
        "    # Check if GPU is available\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    print(f\"Using device: {device}\")\n",
        "\n",
        "\n",
        "     # Load dataset\n",
        "    file_path = eIcu1_PATH\n",
        "    df = pd.read_csv(file_path, compression='gzip')\n",
        "\n",
        "    df = df[~((df[\"icu_death\"] == 1) & (df[\"hospital_death\"] == 0))]\n",
        "\n",
        "    df=df.dropna()\n",
        "\n",
        "    df = pd.get_dummies(df, columns=[\"dx_class\", \"dx_sub\", \"dcs_group\", \"icu_admit_source\", \"group\"], drop_first=True)\n",
        "\n",
        "    # Add gender and thnicity\n",
        "    df2 = pd.read_csv(eIcu2_PATH)\n",
        "    df = pd.merge(df,df2[['patientunitstayid','encounter_id','ethnicity','gender']], on =['patientunitstayid','encounter_id'] , how='inner')\n",
        "\n",
        "    df=df.dropna()\n",
        "\n",
        "    # Drop unnecessary columns\n",
        "    df = df.drop(columns=[\"patientunitstayid\", \"encounter_id\", \"partition\", \"icu_death\"])  # ID-like columns\n",
        "\n",
        "    # Undersampling it\n",
        "    # Count samples in each class\n",
        "    count_0 = len(df[df['hospital_death'] == 0])\n",
        "    count_1 = len(df[df['hospital_death'] == 1])\n",
        "\n",
        "    # Under-sample class 0 to match class 1\n",
        "    df_class_0 = df[df['hospital_death'] == 0].sample(n=count_1, random_state=42)\n",
        "    df_class_1 = df[df['hospital_death'] == 1]\n",
        "\n",
        "    # Combine and shuffle\n",
        "    df = pd.concat([df_class_0, df_class_1]).sample(frac=1, random_state=42).reset_index(drop=True)\n",
        "\n",
        "    median_age = df['age'].median() #median is 65\n",
        "    df['age_median_gp'] = df['age'].apply(lambda x: 0 if x < median_age else 1)\n",
        "    df = df.drop(columns=['age'])\n",
        "\n",
        "    df['gender']=df['gender'].map({'M': 1, 'F': 0})\n",
        "    df['ethnicity'] = df['ethnicity'].apply(lambda x: 1 if x == 'Caucasian' else 0)\n",
        "\n",
        "    # Select numeric columns (int and float types)\n",
        "    numeric_cols = df.select_dtypes(include=['int64', 'float64']).columns\n",
        "\n",
        "    # Initialize the MinMaxScaler\n",
        "    scaler = MinMaxScaler()\n",
        "\n",
        "    # Apply Min-Max scaling to the numeric columns\n",
        "    df[numeric_cols] = scaler.fit_transform(df[numeric_cols])\n",
        "\n",
        "    # high_corr_columns_list = high_corr_columns_list + [\"gender\", \"ethnicity\",\"age_median_gp\"]\n",
        "    X = df.drop(columns=[\"hospital_death\"])  # All features except target\n",
        "    #X = X [high_corr_columns_list]\n",
        "    y = df[\"hospital_death\"]  # Target variable (0 = Survived, 1 = Died)\n",
        "\n",
        "    # First, split data into 70% train and 30% temporary (test + val)\n",
        "    X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.3, stratify=y, random_state=42)\n",
        "    # Split the 30% temp data into 10% validation and 20% test\n",
        "    X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=2/3, stratify=y_temp, random_state=42)  # 2/3 of 30% = 20%\n",
        "    #X_train, X_test, X_val = X_train.drop(columns=[\"gender\", \"ethnicity\",\"age_median_gp\"]), X_test.drop(columns=[\"gender\", \"ethnicity\",\"age_median_gp\"]), X_val.drop(columns=[\"gender\", \"ethnicity\",\"age_median_gp\"])\n",
        "\n",
        "    print(f\"Train: {len(X_train)} samples\")\n",
        "    print(f\"Validation: {len(X_val)} samples\")\n",
        "    print(f\"Test: {len(X_test)} samples\")\n",
        "\n",
        "    X_train = X_train.astype('float32')\n",
        "    X_val = X_val.astype('float32')\n",
        "\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    X_train_torch = torch.tensor(X_train.values, dtype=torch.float32).to(device)\n",
        "    y_train_torch = torch.tensor(y_train.values, dtype=torch.long).squeeze().to(device)\n",
        "    #y_test_torch = torch.tensor(y_test.values, dtype=torch.long).squeeze().to(device)\n",
        "    X_val_torch = torch.tensor(X_val.values, dtype=torch.float32).to(device)\n",
        "    y_val_torch = torch.tensor(y_val.values, dtype=torch.long).squeeze().to(device)\n",
        "\n",
        "    # Create DataLoader\n",
        "    train_dataset = TensorDataset(X_train_torch, y_train_torch)\n",
        "    train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle=True)\n",
        "\n",
        "    # Create Test DataLoader (Batching for Memory Efficiency)\n",
        "    #test_dataset = TensorDataset(X_test_torch, y_test_torch)\n",
        "    #test_loader = DataLoader(test_dataset, batch_size = batch_size, shuffle=False)  # Use batch_size based on memory capacity\n",
        "    val_dataset = TensorDataset(X_val_torch, y_val_torch)\n",
        "    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)\n",
        "\n",
        "\n",
        "    # Create val DataLoader (Batching for Memory Efficiency)\n",
        "    white_mask = X_test[\"ethnicity\"] == 1\n",
        "    Age_Below_med_mask = X_test[\"age_median_gp\"] == 0\n",
        "    Age_Above_med_mask = X_test[\"age_median_gp\"] == 1\n",
        "\n",
        "\n",
        "    #white_mask = X_val[\"gender\"] == \"M\"\n",
        "    X_val_white, y_val_white = X_test[white_mask], y_test[white_mask]\n",
        "    X_val_nonwhite, y_val_nonwhite = X_test[~white_mask], y_test[~white_mask]\n",
        "    X_val_age_below_med, y_val_age_below_med = X_test[Age_Below_med_mask], y_test[Age_Below_med_mask]\n",
        "    X_val_age_above_med, y_val_age_above_med = X_test[Age_Above_med_mask], y_test[Age_Above_med_mask]\n",
        "\n",
        "\n",
        "\n",
        "    # Convert to tensors\n",
        "    X_val_white = X_val_white.astype('float32')\n",
        "    y_val_white = y_val_white.astype('float32')\n",
        "    X_val_nonwhite = X_val_nonwhite.astype('float32')\n",
        "    y_val_nonwhite = y_val_nonwhite.astype('float32')\n",
        "\n",
        "\n",
        "    X_val_age_below_med = X_val_age_below_med.astype('float32')\n",
        "    y_val_age_below_med = y_val_age_below_med.astype('float32')\n",
        "    X_val_age_above_med = X_val_age_above_med.astype('float32')\n",
        "    y_val_age_above_med = y_val_age_above_med.astype('float32')\n",
        "\n",
        "    X_val_white_torch = torch.tensor(X_val_white.values, dtype=torch.float32).to(device)\n",
        "    y_val_white_torch = torch.tensor(y_val_white.values, dtype=torch.long).squeeze().to(device)\n",
        "\n",
        "    X_val_nonwhite_torch = torch.tensor(X_val_nonwhite.values, dtype=torch.float32).to(device)\n",
        "    y_val_nonwhite_torch = torch.tensor(y_val_nonwhite.values, dtype=torch.long).squeeze().to(device)\n",
        "\n",
        "    X_val_age_below_med_torch = torch.tensor(X_val_age_below_med.values, dtype=torch.float32).to(device)\n",
        "    y_val_age_below_med_torch = torch.tensor(y_val_age_below_med.values, dtype=torch.long).squeeze().to(device)\n",
        "    X_val_age_above_med_torch = torch.tensor(X_val_age_above_med.values, dtype=torch.float32).to(device)\n",
        "    y_val_age_above_med_torch = torch.tensor(y_val_age_above_med.values, dtype=torch.long).squeeze().to(device)\n",
        "\n",
        "    # Create DataLoaders\n",
        "    white_dataset = TensorDataset(X_val_white_torch, y_val_white_torch)\n",
        "    nonwhite_dataset = TensorDataset(X_val_nonwhite_torch, y_val_nonwhite_torch)\n",
        "    age_above_med_dataset = TensorDataset(X_val_age_above_med_torch, y_val_age_above_med_torch)\n",
        "    age_below_med_dataset = TensorDataset(X_val_age_below_med_torch, y_val_age_below_med_torch)\n",
        "\n",
        "\n",
        "    White_dataloader = DataLoader(white_dataset, batch_size=batch_size, shuffle=False)\n",
        "    nonWhite_dataloader = DataLoader(nonwhite_dataset, batch_size=batch_size, shuffle=False)\n",
        "\n",
        "    age_above_med_dataloader = DataLoader(age_above_med_dataset, batch_size=batch_size, shuffle=False)\n",
        "    age_below_med_dataloader = DataLoader(age_below_med_dataset, batch_size=batch_size, shuffle=False)\n",
        "\n",
        "\n",
        "    # Create val DataLoader (Batching for Memory Efficiency)\n",
        "    male_mask = X_test[\"gender\"] == 1 #Male=1, female=0\n",
        "    X_val_male, y_val_male = X_test[male_mask], y_test[male_mask]\n",
        "    X_val_female, y_val_female = X_test[~male_mask], y_test[~male_mask]\n",
        "\n",
        "    # Convert to tensors\n",
        "    X_val_male = X_val_male.astype('float32')\n",
        "    y_val_male = y_val_male.astype('float32')\n",
        "    X_val_female = X_val_female.astype('float32')\n",
        "    y_val_female = y_val_female.astype('float32')\n",
        "\n",
        "\n",
        "    X_val_male_torch = torch.tensor(X_val_male.values, dtype=torch.float32).to(device)\n",
        "    y_val_male_torch = torch.tensor(y_val_male.values, dtype=torch.long).squeeze().to(device)\n",
        "\n",
        "    X_val_female_torch = torch.tensor(X_val_female.values, dtype=torch.float32).to(device)\n",
        "    y_val_female_torch = torch.tensor(y_val_female.values, dtype=torch.long).squeeze().to(device)\n",
        "\n",
        "    # Create DataLoaders\n",
        "    male_dataset = TensorDataset(X_val_male_torch, y_val_male_torch)\n",
        "    female_dataset = TensorDataset(X_val_female_torch, y_val_female_torch)\n",
        "\n",
        "    male_dataloader = DataLoader(male_dataset, batch_size=batch_size, shuffle=False)\n",
        "    female_dataloader = DataLoader(female_dataset, batch_size=batch_size, shuffle=False)\n",
        "\n",
        "    # Compute class weights\n",
        "    class_weights = torch.tensor([0.5, 1.0]).to(device)\n",
        "\n",
        "    # Initialize Model and Move to GPU\n",
        "    model = ICUModel(input_size=X_train.shape[1]).to(device)\n",
        "    criterion = nn.CrossEntropyLoss(weight=class_weights,reduction=\"sum\")\n",
        "    #criterion = nn.CrossEntropyLoss(reduction=\"sum\")\n",
        "    base_optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=2.9173130489231118e-05) #2.9173130489231118e-05\n",
        "\n",
        "    # Opacus setup\n",
        "    sample_rate = batch_size / len(X_train)\n",
        "\n",
        "    sigma = get_noise_multiplier(\n",
        "        target_epsilon=epsilon,\n",
        "        target_delta=1e-5,\n",
        "        sample_rate=sample_rate,\n",
        "        epochs=epochs\n",
        "    )\n",
        "    print(f\"Using Noise Multiplier (sigma): {sigma:.4f}\")\n",
        "\n",
        "    m = batch_size  # users per round\n",
        "    sigma_b = m / 20        # from the paper\n",
        "\n",
        "    privacy_engine = fairPrivacyEngine()\n",
        "\n",
        "    model, optimizer, train_loader = privacy_engine.make_private(\n",
        "        module=model,\n",
        "        optimizer=base_optimizer,\n",
        "        clipping = 'adaptive',\n",
        "        data_loader=train_loader,\n",
        "        noise_multiplier=sigma,\n",
        "        max_grad_norm=max_grad_norm,\n",
        "        poisson_sampling=False, # poisson_sampling=True, If I was using drop_last = False in dataloader\n",
        "        loss_reduction=\"sum\", #sum\n",
        "        target_unclipped_quantile=0.5, # gamma = median\n",
        "        clipbound_learning_rate=0.2, # ηC\n",
        "        max_clipbound=2.0,  # upper limit on clipping norm\n",
        "        min_clipbound=0.1,  # lower limit\n",
        "        unclipped_num_std=sigma_b\n",
        "    )\n",
        "\n",
        "    # Compute warmup steps dynamically\n",
        "\n",
        "    num_training_steps = len(train_loader) * epochs\n",
        "    num_warmup_steps = int(0.1 * num_training_steps)\n",
        "    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)\n",
        "\n",
        "    early_stopper = EarlyStopping(patience=5, threshold=0.0)  # You can tweak this\n",
        "    best_val_f1 = 0\n",
        "    best_model_state = None\n",
        "\n",
        "    micro_batch_size = 32\n",
        "    grad_norms_before_clipping = []\n",
        "    grad_norms_after_clipping = []\n",
        "    grad_norms_after_clipping_noisy = []\n",
        "    all_noises_grad_sample = []\n",
        "    all_noises_grad = []\n",
        "\n",
        "    # Train Model\n",
        "    for epoch in range(epochs):\n",
        "        model.train()\n",
        "        total_loss = 0\n",
        "        print_loss=0\n",
        "        total_micro_batches = 0\n",
        "        micro_batch_size = 32\n",
        "        all_train_predictions = []\n",
        "        all_train_labels = []\n",
        "        grad_list = []\n",
        "        all_logical_predictions = []\n",
        "        all_logical_labels = []\n",
        "        layer_wise_grad_samples = {}\n",
        "        cosine_similarities = []\n",
        "        angles = []\n",
        "        progress_bar = tqdm(train_loader, desc=f\"Epoch {epoch + 1}\")\n",
        "\n",
        "        optimizer.zero_grad()\n",
        "\n",
        "        with BatchMemoryManager(\n",
        "            data_loader=train_loader,\n",
        "            max_physical_batch_size=micro_batch_size,\n",
        "            optimizer=optimizer,\n",
        "        ) as memory_safe_data_loader:\n",
        "\n",
        "            for batch_idx, batch in enumerate(memory_safe_data_loader):\n",
        "\n",
        "                total_micro_batches += 1\n",
        "\n",
        "                X_batch, y_batch = batch\n",
        "                X_batch, y_batch = X_batch.to(device), y_batch.to(device)\n",
        "\n",
        "                # Forward pass\n",
        "                outputs = model(X_batch)\n",
        "                logits = outputs\n",
        "                labels = y_batch\n",
        "\n",
        "                # Apply weights to CrossEntropyLoss\n",
        "                loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))\n",
        "                total_loss += loss.item()\n",
        "                print_loss = loss.item()\n",
        "\n",
        "                # Compute and log metrics every batch\n",
        "                predictions = logits.detach().cpu().numpy()\n",
        "                labels_cpu = labels.detach().cpu().numpy()\n",
        "                all_logical_predictions.extend(predictions)\n",
        "                all_logical_labels.extend(labels_cpu)\n",
        "\n",
        "                # Backward pass\n",
        "                loss.backward()\n",
        "\n",
        "                for name, p in model.named_parameters():\n",
        "                    if p.requires_grad:\n",
        "                        with torch.no_grad():\n",
        "                            grad_sample = p.grad_sample.sum(dim = 0).clone().detach()  #birone if #sum accross samples\n",
        "                            if name not in layer_wise_grad_samples:\n",
        "                                layer_wise_grad_samples[name] = torch.zeros_like(grad_sample)  # Store per-layer name\n",
        "                            layer_wise_grad_samples[name] += grad_sample  # Accumulate across physical batches\n",
        "\n",
        "                #Calculating and storing gradient before clipping and adding noise and after clipping and adding noise\n",
        "                is_final_step = optimizer.pre_step()\n",
        "\n",
        "                if is_final_step == True:\n",
        "                    total_grad_norm_squared = 0.0\n",
        "                    total_grad_norm_squared_summed_noisy = 0.0\n",
        "                    total_noise_grad = 0.0\n",
        "                    all_layer_grads_before=[]\n",
        "                    all_layer_grads_after=[]\n",
        "\n",
        "                    num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
        "                    expected_noise = (num_trainable_params ** 0.5) * sigma # * max_grad_norm\n",
        "\n",
        "                    for name, grad_list in layer_wise_grad_samples.items():\n",
        "                        all_layer_grads_before.append(grad_list)  # Store per-layer gradients\n",
        "                        ## Compute the L2 norm squared\n",
        "                        total_grad_norm_squared += torch.norm(grad_list, p=2) ** 2\n",
        "\n",
        "                    for p in model.parameters():\n",
        "                        if p.requires_grad:\n",
        "                             with torch.no_grad():\n",
        "                                    all_layer_grads_after.append(p.summed_grad.clone().detach())\n",
        "\n",
        "                                    layer_grad_norm_summed_noisy = torch.norm(p.summed_grad, p=2).item()  #torch.norm(p.summed_grad.view(-1), p=2).item()\n",
        "                                    total_grad_norm_squared_summed_noisy += layer_grad_norm_summed_noisy ** 2\n",
        "\n",
        "                                    grad = torch.norm(p.grad, p=2).item()\n",
        "                                    total_noise_grad += (grad - layer_grad_norm_summed_noisy) ** 2\n",
        "\n",
        "\n",
        "                    # Convert lists to single tensors\n",
        "                    total_grad_before = torch.cat([v.view(-1) for v in all_layer_grads_before])\n",
        "                    total_grad_after = torch.cat([v.view(-1) for v in all_layer_grads_after])  # Aggregate all layers\n",
        "\n",
        "                    # Compute Cosine Similarity\n",
        "                    cosine_similarity = torch.nn.functional.cosine_similarity(\n",
        "                        total_grad_before.view(1, -1),\n",
        "                        total_grad_after.view(1, -1),\n",
        "                        dim=1\n",
        "                    ).item()\n",
        "\n",
        "                    # Compute Angle in Degrees\n",
        "                    angle = torch.acos(torch.clamp(torch.tensor(cosine_similarity), -1.0, 1.0)).item() * (180 / torch.pi)\n",
        "\n",
        "                    # Store values for logging\n",
        "                    cosine_similarities.append(cosine_similarity)\n",
        "                    angles.append(angle)\n",
        "\n",
        "                    #print(f\"Cosine Similarity: {cosine_similarity}\")\n",
        "                    #print(f\"Angle in Degrees: {angle}\")\n",
        "\n",
        "                    # Compute the overall gradient norm\n",
        "                    overall_grad_norm = total_grad_norm_squared ** 0.5\n",
        "                    grad_norms_before_clipping.append(overall_grad_norm)\n",
        "                    overall_grad_norm_summed_noisy = total_grad_norm_squared_summed_noisy ** 0.5\n",
        "                    grad_norms_after_clipping_noisy.append(overall_grad_norm_summed_noisy)\n",
        "                    overal_noise_grad = total_noise_grad ** 0.5\n",
        "                    all_noises_grad.append(overal_noise_grad)\n",
        "\n",
        "\n",
        "                    # Log gradient norms for the logical batch\n",
        "                    wandb.log({\n",
        "                        \"logical_batch_grad_norm_after_clipping_with_Noise\":overall_grad_norm_summed_noisy,\n",
        "                        \"logical_batch_grad_norm_before_clipping\":overall_grad_norm,\n",
        "                        \"overal noise grad - summed_grad\":overal_noise_grad,\n",
        "                        \"summed_grad/noise grad\":overall_grad_norm_summed_noisy/overal_noise_grad,\n",
        "                        \"Expected noise\" : expected_noise,\n",
        "                        \"norm after clipping / norm before clipping: \":overall_grad_norm_summed_noisy/overall_grad_norm,\n",
        "                        \"Cosine Similarity\": cosine_similarity,\n",
        "                        \"Angle in Degrees\": angle\n",
        "                    })\n",
        "                    #print(f\"logical batch grad_norm before clipping: {overall_grad_norm}\")\n",
        "                    #print(f\"logical batch grad norm after clipping with Noise: {overall_grad_norm_summed_noisy}\")\n",
        "                    #print(f\"overal noise: grad - summed_grad: {overal_noise_grad}\")\n",
        "                    #print(f\"summed_grad/noise grad: {overall_grad_norm_summed_noisy/overal_noise_grad}\")\n",
        "                    #print(f\"Expected noise: {expected_noise}\")\n",
        "\n",
        "\n",
        "                    # Compute metrics after all physical batches in the logical batch are processed\n",
        "                    train_metrics = compute_metrics(np.array(all_logical_predictions), np.array(all_logical_labels))\n",
        "                    avg_train_loss = total_loss / total_micro_batches\n",
        "\n",
        "                    # Log to wandb\n",
        "                    wandb.log({\n",
        "                        \"train_loss\": avg_train_loss,\n",
        "                        \"train_accuracy\": train_metrics['accuracy'],\n",
        "                        \"train_precision\": train_metrics['precision'],\n",
        "                        \"train_recall\": train_metrics['recall'],\n",
        "                        \"train_f1\": train_metrics['f1'],\n",
        "                        \"train auoc\":train_metrics['roc_auc']} )\n",
        "\n",
        "                    #print(f\"train_loss: {avg_train_loss}\")\n",
        "                    #print(f\"train_accuracy: {train_metrics['accuracy']}\")\n",
        "                    #print(f\"train_precision: {train_metrics['precision']}\")\n",
        "                    #print(f\"train_recall: { train_metrics['recall']}\")\n",
        "                    #print(f\"train_f1: {train_metrics['f1']}\")\n",
        "                    #print(f\"train_Roc: {train_metrics['roc_auc']}\")\n",
        "\n",
        "\n",
        "                    optimizer.original_optimizer.step()\n",
        "                    scheduler.step()\n",
        "                    all_logical_labels = []\n",
        "                    all_logical_predictions = []\n",
        "                    grad_list = []\n",
        "                    layer_wise_grad_samples = {}\n",
        "\n",
        "                progress_bar.set_postfix({\n",
        "                \"lr\": scheduler.get_last_lr()[0],\n",
        "                \"step_loss\": print_loss\n",
        "                })\n",
        "\n",
        "                optimizer.zero_grad()\n",
        "\n",
        "            if (epoch + 1) % 10 == 0:\n",
        "                checkpoint_path = os.path.join(checkpoint_dir, f\"checkpoint_epoch_{epoch+1}.pt\")\n",
        "                torch.save({\n",
        "                    'epoch': epoch + 1,\n",
        "                    'model_state_dict': model.state_dict(),\n",
        "                    'optimizer_state_dict': optimizer.state_dict(),\n",
        "                    'scheduler_state_dict': scheduler.state_dict(),\n",
        "                    'loss': total_loss / total_micro_batches,\n",
        "                    'sigma': sigma\n",
        "                }, checkpoint_path)\n",
        "                print(f\"Full checkpoint saved at epoch {epoch + 1} -> {checkpoint_path}\")\n",
        "\n",
        "            print(\"Training complete without differential privacy!\")\n",
        "\n",
        "        # **Inference and Evaluation**\n",
        "        model.eval()\n",
        "        val_preds = []\n",
        "        val_labels = []\n",
        "        val_probs = []\n",
        "        val_loss = 0\n",
        "        with torch.no_grad():\n",
        "            for X_batch, y_batch in val_loader:\n",
        "                outputs = model(X_batch)\n",
        "                loss = criterion(outputs, y_batch)\n",
        "                probs = torch.softmax(outputs, dim=1)[:, 1]\n",
        "                val_preds.extend(outputs.argmax(dim=1).cpu().numpy())\n",
        "                val_labels.extend(y_batch.cpu().numpy())\n",
        "                val_probs.extend(probs.detach().cpu().numpy())\n",
        "                val_loss += loss.item()\n",
        "\n",
        "        val_f1 = f1_score(val_labels, val_preds)\n",
        "        val_acc = accuracy_score(val_labels, val_preds)\n",
        "        val_auroc = roc_auc_score(val_labels, val_probs)\n",
        "        val_precision = precision_score(val_labels, val_preds)\n",
        "        val_recall = recall_score(val_labels, val_preds)\n",
        "\n",
        "        #val_metrics = compute_metrics(np.stack(val_preds), np.stack(val_labels))\n",
        "        avg_val_loss = val_loss / len(val_loader)\n",
        "\n",
        "        #print(f\"Epoch {epoch+1}: Train F1 {val_f1:.4f}, Val F1 {val_metrics['f1']:.4f}\")\n",
        "\n",
        "        wandb.log({\n",
        "            \"val/loss\": avg_val_loss,\n",
        "            \"val/accuracy\": val_acc,\n",
        "            \"val/precision\": val_precision,\n",
        "            \"val/recall\": val_recall,\n",
        "            \"val/f1\": val_f1,\n",
        "            \"val/roc_auc\": val_auroc,\n",
        "            \"epoch\": epoch + 1,\n",
        "            \"sigma\":sigma\n",
        "        })\n",
        "\n",
        "        ## Early stopping check\n",
        "        #if early_stopper(val_f1):\n",
        "        #    print(f\"Early stopping triggered at epoch {epoch + 1}\")\n",
        "        #    break\n",
        "\n",
        "\n",
        "    # Print Metrics\n",
        "    print(f\"Accuracy: {val_acc:.4f}\")\n",
        "    print(f\"Precision: {val_precision:.4f}\")\n",
        "    print(f\"Recall: {val_recall:.4f}\")\n",
        "    print(f\"F1-score: {val_f1:.4f}\")\n",
        "    print(f\"AUROC: {val_auroc:.4f}\")\n",
        "\n",
        "\n",
        "    print('Final Testing phase with val data:')\n",
        "\n",
        "    def compute_gradient_norms(model, dataloader, optimizer, group, groupname,\n",
        "                               micro_batch_size=32, result_file_path=\"resultsEICU.txt\"):\n",
        "        \"\"\"Measure how DP clipping changes the gradients of one subgroup.\n",
        "\n",
        "        Runs forward/backward passes over `dataloader` in physical batches of at\n",
        "        most `micro_batch_size` and calls `optimizer.pre_step()`, but never a\n",
        "        parameter step, so the model weights are not updated. For every logical\n",
        "        batch it compares the summed raw per-sample gradients (`grad_sample`)\n",
        "        with the optimizer's `summed_grad`, prints the per-batch statistics, and\n",
        "        at the end prints the averages over all logical batches and appends them\n",
        "        to `result_file_path`.\n",
        "\n",
        "        Args:\n",
        "            model: model wrapped so that its parameters expose `grad_sample`.\n",
        "            dataloader: loader over the subgroup's samples.\n",
        "            optimizer: the DP optimizer used for training.\n",
        "            group: label used in the printed / written output (e.g. 'female').\n",
        "            groupname: name of the grouping attribute; currently unused.\n",
        "            micro_batch_size: max physical batch size for BatchMemoryManager.\n",
        "            result_file_path: text file the averaged results are appended to.\n",
        "\n",
        "        NOTE(review): `pre_step()` runs the optimizer's clip_and_accumulate (for\n",
        "        FairOptimizer this updates `sample_size` / `unclipped_num`), so optimizer\n",
        "        state may change between calls - confirm this is acceptable when several\n",
        "        groups are evaluated one after another.\n",
        "        \"\"\"\n",
        "        # Per-sample gradients are collected in training mode, so stay in train()\n",
        "        # even though no parameter update happens here.\n",
        "        model.train()\n",
        "        total_loss = 0.0\n",
        "        total_micro_batches = 0\n",
        "        all_logical_labels = []\n",
        "        all_logical_predictions = []\n",
        "        layer_wise_grad_samples = {}\n",
        "        cosine_similarities = []\n",
        "        angles = []\n",
        "\n",
        "        # One entry per completed logical batch of this group.\n",
        "        accuracies = []\n",
        "        f1_scores = []\n",
        "        grad_norms_before = []\n",
        "        grad_norms_after = []\n",
        "\n",
        "        with BatchMemoryManager(\n",
        "                data_loader=dataloader,\n",
        "                max_physical_batch_size=micro_batch_size,\n",
        "                optimizer=optimizer\n",
        "        ) as memory_safe_data_loader:\n",
        "            # Iterate the wrapped loader (not `dataloader` itself) so logical batches\n",
        "            # are split into physical batches of at most `micro_batch_size`, which is\n",
        "            # what the per-layer accumulation below is written for.\n",
        "            for X_batch, y_batch in memory_safe_data_loader:\n",
        "                total_micro_batches += 1\n",
        "                optimizer.zero_grad()\n",
        "\n",
        "                X_batch, y_batch = X_batch.to(device), y_batch.to(device)\n",
        "\n",
        "                # Forward pass\n",
        "                logits = model(X_batch)\n",
        "                loss = criterion(logits.view(-1, logits.size(-1)), y_batch.view(-1))\n",
        "                total_loss += loss.item()\n",
        "\n",
        "                all_logical_predictions.extend(logits.detach().cpu().numpy())\n",
        "                all_logical_labels.extend(y_batch.detach().cpu().numpy())\n",
        "\n",
        "                # Backward pass (per-sample gradients are read from p.grad_sample below)\n",
        "                loss.backward()\n",
        "\n",
        "                # Accumulate, per layer, the sum of raw per-sample gradients of this\n",
        "                # physical batch until the logical batch is complete.\n",
        "                with torch.no_grad():\n",
        "                    for name, p in model.named_parameters():\n",
        "                        if not p.requires_grad:\n",
        "                            continue\n",
        "                        grad_sum = p.grad_sample.sum(dim=0).clone().detach()\n",
        "                        if name in layer_wise_grad_samples:\n",
        "                            layer_wise_grad_samples[name] += grad_sum\n",
        "                        else:\n",
        "                            layer_wise_grad_samples[name] = grad_sum\n",
        "\n",
        "                is_final_step = optimizer.pre_step()\n",
        "\n",
        "                if is_final_step:\n",
        "                    with torch.no_grad():\n",
        "                        # Both vectors follow named_parameters() order, so they are comparable.\n",
        "                        total_grad_before = torch.cat([g.view(-1) for g in layer_wise_grad_samples.values()])\n",
        "                        grads_after = []\n",
        "                        norm_sq_summed = 0.0\n",
        "                        noise_sq = 0.0\n",
        "                        for name, p in model.named_parameters():\n",
        "                            if not p.requires_grad:\n",
        "                                continue\n",
        "                            grads_after.append(p.summed_grad.clone().detach().view(-1))\n",
        "                            layer_norm_summed = torch.norm(p.summed_grad, p=2).item()\n",
        "                            norm_sq_summed += layer_norm_summed ** 2\n",
        "                            noise_sq += (torch.norm(p.grad, p=2).item() - layer_norm_summed) ** 2\n",
        "                        total_grad_after = torch.cat(grads_after)\n",
        "\n",
        "                        # Plain Python floats, so printed / written values are numbers, not tensor reprs.\n",
        "                        overall_grad_norm = torch.norm(total_grad_before, p=2).item()\n",
        "                        overall_grad_norm_summed_noisy = norm_sq_summed ** 0.5\n",
        "                        overal_noise_grad = noise_sq ** 0.5\n",
        "\n",
        "                        cosine_similarity = torch.nn.functional.cosine_similarity(\n",
        "                            total_grad_before.view(1, -1),\n",
        "                            total_grad_after.view(1, -1),\n",
        "                            dim=1\n",
        "                        ).item()\n",
        "\n",
        "                    # Angle between the raw and the optimizer-processed gradient directions\n",
        "                    angle = torch.acos(torch.clamp(torch.tensor(cosine_similarity), -1.0, 1.0)).item() * (180 / torch.pi)\n",
        "                    cosine_similarities.append(cosine_similarity)\n",
        "                    angles.append(angle)\n",
        "\n",
        "                    # Guard the ratio: a zero noise norm would otherwise raise ZeroDivisionError.\n",
        "                    summed_to_noise = (overall_grad_norm_summed_noisy / overal_noise_grad\n",
        "                                       if overal_noise_grad else float(\"nan\"))\n",
        "\n",
        "                    print(f\"Cosine Similarity: {cosine_similarity}\")\n",
        "                    print(f\"Angle in Degrees: {angle}\")\n",
        "                    # NOTE(review): summed_grad is filled by the optimizer's clip_and_accumulate\n",
        "                    # (clipped per-sample sum); confirm whether noise is already included in it\n",
        "                    # before relying on the 'with Noise' / 'noise' labels below.\n",
        "                    print(f\"logical batch grad_norm before clipping: {overall_grad_norm}\")\n",
        "                    print(f\"logical batch grad norm after clipping with Noise: {overall_grad_norm_summed_noisy}\")\n",
        "                    print(f\"overal noise: grad - summed_grad: {overal_noise_grad}\")\n",
        "                    print(f\"summed_grad/noise grad: {summed_to_noise}\")\n",
        "\n",
        "                    # Metrics of the logical batch that just completed.\n",
        "                    eval_metrics = compute_metrics(np.array(all_logical_predictions), np.array(all_logical_labels))\n",
        "                    # Running averages over everything processed so far in this call\n",
        "                    # (all logical batches), not just the current one.\n",
        "                    avg_eval_loss = total_loss / total_micro_batches\n",
        "                    avg_cosine_similarity = sum(cosine_similarities) / len(cosine_similarities)\n",
        "                    avg_angle = sum(angles) / len(angles)\n",
        "\n",
        "                    print(f\"eval_loss_{group}: {avg_eval_loss}\")\n",
        "                    print(f\"eval_accuracy_{group}: {eval_metrics['accuracy']}\")\n",
        "                    print(f\"eval_precision_{group}: {eval_metrics['precision']}\")\n",
        "                    print(f\"eval_recall_{group}: { eval_metrics['recall']}\")\n",
        "                    print(f\"eval_f1_{group}: {eval_metrics['f1']}\")\n",
        "                    print(f\"logical_batch_cosine_similarity{group}: {avg_cosine_similarity}\")\n",
        "                    print(f\"logical_batch_angle_degrees{group}: {avg_angle}\")\n",
        "\n",
        "                    accuracies.append(eval_metrics['accuracy'])\n",
        "                    f1_scores.append(eval_metrics['f1'])\n",
        "                    grad_norms_before.append(overall_grad_norm)\n",
        "                    grad_norms_after.append(overall_grad_norm_summed_noisy)\n",
        "\n",
        "                    # Start collecting the next logical batch.\n",
        "                    all_logical_labels = []\n",
        "                    all_logical_predictions = []\n",
        "                    layer_wise_grad_samples = {}\n",
        "\n",
        "                optimizer.zero_grad()\n",
        "\n",
        "        if not grad_norms_before:\n",
        "            # Avoid ZeroDivisionError (and aborting the remaining groups) when no\n",
        "            # logical batch was completed.\n",
        "            print(f\"No logical batch was processed for {group}; nothing to report.\")\n",
        "            return\n",
        "\n",
        "        avg_grad_norm_before = sum(grad_norms_before) / len(grad_norms_before)\n",
        "        avg_grad_norm_after = sum(grad_norms_after) / len(grad_norms_after)\n",
        "        avg_accuracy = sum(accuracies) / len(accuracies)\n",
        "        avg_f1 = sum(f1_scores) / len(f1_scores)\n",
        "\n",
        "        summary = [\n",
        "            f\"Average Gradient Norm Before Clipping {group}: {avg_grad_norm_before}\",\n",
        "            f\"Average Gradient Norm After Clipping {group}: {avg_grad_norm_after}\",\n",
        "            f\"Average Accuracy for {group}: {avg_accuracy}\",\n",
        "            f\"Average f1score for {group}: {avg_f1}\",\n",
        "        ]\n",
        "        for line in summary:\n",
        "            print(line)\n",
        "\n",
        "        # Append mode: results of successive groups / runs accumulate in one file.\n",
        "        with open(result_file_path, \"a\") as f:\n",
        "            for line in summary:\n",
        "                f.write(line + \"\\n\")\n",
        "            f.write(\"\\n\")  # blank line between groups for readability\n",
        "\n",
        "\n",
        "\n",
        "    # Compute gradient norms for different groups\n",
        "\n",
        "    #print(\"\\nComputing gradient norms for White dataset:\")\n",
        "    #compute_gradient_norms(model, White_dataloader, optimizer,'White' , groupname='ethnicity')\n",
        "\n",
        "    #print(\"\\nComputing gradient norms for non-White dataset:\")\n",
        "    #compute_gradient_norms(model, nonWhite_dataloader, optimizer,'nonWhite' , groupname='ethnicity')\n",
        "\n",
        "    print(\"\\nComputing gradient norms for female dataset:\")\n",
        "    compute_gradient_norms(model, female_dataloader, optimizer,'female' , groupname='gender')\n",
        "\n",
        "    print(\"\\nComputing gradient norms for male dataset:\")\n",
        "    compute_gradient_norms(model, male_dataloader, optimizer,'male' , groupname='gender')\n",
        "\n",
        "    #print(\"\\nComputing gradient norms for age below median dataset:\")\n",
        "    #compute_gradient_norms(model, age_below_med_dataloader, optimizer,'below_median' , groupname='age')\n",
        "\n",
        "    #print(\"\\nComputing gradient norms for age above median dataset:\")\n",
        "    #compute_gradient_norms(model, age_above_med_dataloader, optimizer,'above_median' , groupname='age')\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "if __name__ == '__main__':\n",
        "    import sys\n",
        "\n",
        "    # Inside a Jupyter/Colab kernel sys.argv typically holds the kernel's own\n",
        "    # launch flags (e.g. `-f <connection file>`), which Fire would try to parse\n",
        "    # as arguments for main. Run main with its defaults there, and keep normal\n",
        "    # command-line parsing when this code is executed as a script.\n",
        "    if 'ipykernel' in sys.modules:\n",
        "        fire.Fire(main, command=[])\n",
        "    else:\n",
        "        fire.Fire(main)"
      ],
      "metadata": {
        "id": "JNxje3pSajhD"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}