{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-JD9lvUMiJ6K"
      },
      "outputs": [],
      "source": [
        "def prepcosessNew(df):\n",
        "    \"\"\"Clean, coarsen and gender-rebalance one split of the UCI Adult data.\n",
        "\n",
        "    The caller's DataFrame is not modified; every step works on a new frame.\n",
        "\n",
        "    Steps:\n",
        "      1. '?' -> NaN, drop rows with any missing value, drop duplicate rows.\n",
        "      2. income -> 0 for '<=50K' / '<=50K.', 1 for anything else.\n",
        "      3. age -> int (rows whose age contains no digits are dropped);\n",
        "         hours-per-week -> 0 (< 40), 1 (40..60), 2 (> 60).\n",
        "      4. Collapse education, marital-status, workclass, native-country, race,\n",
        "         relationship and occupation into coarse groups; sex -> 1 Male / 0 Female.\n",
        "      5. Undersample men with income 0 so that, where possible, #men == #women\n",
        "         (every woman and every man with income 1 is kept).\n",
        "      6. age -> 0/1 around the median of *this* frame, drop rows with\n",
        "         capital-gain == 99999, binarise capital-gain (> 5000) and\n",
        "         capital-loss (> 40), drop fnlwgt / education-num, map the binary\n",
        "         categoricals to 0/1 and one-hot encode occupation.\n",
        "\n",
        "    NOTE: the age median is computed per call, so the train and test splits\n",
        "    are each thresholded at their own median.\n",
        "\n",
        "    Args:\n",
        "        df: raw Adult rows with the 15 standard columns.\n",
        "\n",
        "    Returns:\n",
        "        (X, y): feature DataFrame and the aligned 0/1 income Series.\n",
        "    \"\"\"\n",
        "    # 1. Missing values and duplicates. Reassign instead of using inplace=True\n",
        "    #    so the DataFrame passed in by the caller is left untouched.\n",
        "    df = df.replace('?', np.nan)\n",
        "    print(\"Number of observation before dropping Nulls:\", df.shape[0])\n",
        "    df = df.dropna(how='any')\n",
        "    print(\"Number of observation after dropping Nulls:\", df.shape[0])\n",
        "    df = df.drop_duplicates().copy()\n",
        "    print(\"Number of observation after dropping duplicates:\", df.shape[0])\n",
        "\n",
        "    # 2. Binary label; both the plain and the trailing-'.' spelling map to 0.\n",
        "    df['income'] = (~df['income'].isin(['<=50K', '<=50K.'])).astype(int)\n",
        "\n",
        "    # 3. Age may be read as text (e.g. if the raw file has a non-record line):\n",
        "    #    keep the first run of digits and drop rows that have none.\n",
        "    age_digits = df['age'].astype(str).str.strip().str.extract(r'(\\d+)', expand=False)\n",
        "    df['age'] = pd.to_numeric(age_digits, errors='coerce')\n",
        "    print(\"Non-convertible values:\", df['age'].isna().sum())\n",
        "    df = df.dropna(subset=['age']).copy()\n",
        "    df['age'] = df['age'].astype(int)\n",
        "\n",
        "    hours = df['hours-per-week']\n",
        "    df['hours-per-week'] = np.where(hours < 40, 0, np.where(hours <= 60, 1, 2))\n",
        "\n",
        "    # 4. Coarsen the categorical columns. Every column is reassigned: the chained\n",
        "    #    `df[col].replace(..., inplace=True)` form does not write back to `df`\n",
        "    #    under pandas copy-on-write.\n",
        "    def collapse(series, groups):\n",
        "        \"\"\"Apply the (values, replacement) pairs in `groups` in order.\"\"\"\n",
        "        for values, replacement in groups:\n",
        "            series = series.replace(to_replace=values, value=replacement)\n",
        "        return series\n",
        "\n",
        "    hs_grad = ['HS-grad', '11th', '10th', '9th', '12th']\n",
        "    elementary = ['1st-4th', '5th-6th', '7th-8th']\n",
        "    low_edu = ['Preschool', 'elementary_school', 'HS-grad', 'Some-college', 'Assoc-voc', 'Assoc-acdm']\n",
        "    high_edu = ['Bachelors', 'Masters', 'Doctorate', 'Prof-school']\n",
        "    df['education'] = collapse(df['education'], [\n",
        "        (hs_grad, 'HS-grad'),\n",
        "        (elementary, 'elementary_school'),\n",
        "        (low_edu, 0),   # low education\n",
        "        (high_edu, 1),  # high education\n",
        "    ])\n",
        "\n",
        "    married = ['Married-spouse-absent', 'Married-civ-spouse', 'Married-AF-spouse']\n",
        "    separated = ['Separated', 'Divorced']\n",
        "    non_married = ['Never-married', 'Separated', 'Widowed']\n",
        "    df['marital-status'] = collapse(df['marital-status'], [\n",
        "        (married, 'Married'),\n",
        "        (separated, 'Separated'),\n",
        "        (non_married, 'other'),\n",
        "    ])\n",
        "\n",
        "    self_employed = ['Self-emp-not-inc', 'Self-emp-inc']\n",
        "    govt_employees = ['Local-gov', 'State-gov', 'Federal-gov']\n",
        "    non_private = ['Govt_employees', 'Self_employed', 'Without-pay']\n",
        "    df['workclass'] = collapse(df['workclass'], [\n",
        "        (self_employed, 'Self_employed'),\n",
        "        (govt_employees, 'Govt_employees'),\n",
        "        (non_private, 'non_private'),\n",
        "    ])\n",
        "\n",
        "    # Country -> US / non_US\n",
        "    us_country = ['United-States']\n",
        "    non_us_country = ['Holand-Netherlands','Scotland','Honduras','Hungary','Outlying-US(Guam-USVI-etc)','Yugoslavia','Laos','Thailand',\n",
        "                     'Trinadad&Tobago','Cambodia','Hong','Ireland','Ecuador','France','Greece','Peru','Nicaragua','Portugal','Taiwan',\n",
        "                     'Haiti','Iran','Columbia','Poland','Japan','Guatemala','Vietnam','Dominican-Republic','Italy','China','South','Jamaica',\n",
        "                     'England','Cuba','India','El-Salvador','Canada','Puerto-Rico','Germany','Philippines','Mexico']\n",
        "    df['native-country'] = collapse(df['native-country'], [\n",
        "        (us_country, 'US'),\n",
        "        (non_us_country, 'non_US'),\n",
        "    ])\n",
        "\n",
        "    # Race -> white / non_white\n",
        "    df['race'] = collapse(df['race'], [\n",
        "        (['White'], 'white'),\n",
        "        (['Black', 'Asian-Pac-Islander', 'Amer-Indian-Eskimo', 'Other'], 'non_white'),\n",
        "    ])\n",
        "\n",
        "    # Relationship -> married / other\n",
        "    df['relationship'] = collapse(df['relationship'], [\n",
        "        (['Husband', 'Wife'], 'married'),\n",
        "        (['Not-in-family', 'Own-child', 'Unmarried', 'Other-relative'], 'other'),\n",
        "    ])\n",
        "\n",
        "    # Occupation -> office / heavy_work / other\n",
        "    office = ['Adm-clerical', 'Exec-managerial', 'Prof-specialty', 'Sales', 'Tech-support']\n",
        "    heavy_work = ['Craft-repair', 'Machine-op-inspct', 'Transport-moving', 'Handlers-cleaners',\n",
        "                  'Farming-fishing', 'Priv-house-serv', 'Armed-Forces']\n",
        "    other = ['Other-service', 'Protective-serv']\n",
        "    df['occupation'] = collapse(df['occupation'], [\n",
        "        (office, 'office'),\n",
        "        (heavy_work, 'heavy_work'),\n",
        "        (other, 'other'),\n",
        "    ])\n",
        "\n",
        "    df['sex'] = df['sex'].map({'Male': 1, 'Female': 0})\n",
        "\n",
        "    # 5. Gender rebalancing: keep every woman and every man with income 1, then\n",
        "    #    sample just enough men with income 0 that #men == #women.\n",
        "    df_women_class_0 = df[(df['sex'] == 0) & (df['income'] == 0)]  # Women with income <=50K\n",
        "    df_women_class_1 = df[(df['sex'] == 0) & (df['income'] == 1)]  # Women with income >50K\n",
        "    df_men_class_0 = df[(df['sex'] == 1) & (df['income'] == 0)]  # Men with income <=50K\n",
        "    df_men_class_1 = df[(df['sex'] == 1) & (df['income'] == 1)]  # Men with income >50K\n",
        "\n",
        "    n_women = len(df_women_class_0) + len(df_women_class_1)\n",
        "    # Clamp so .sample() never receives a negative n or more rows than exist.\n",
        "    n_men_class_0 = min(max(n_women - len(df_men_class_1), 0), len(df_men_class_0))\n",
        "    df_men_class_0_sampled = df_men_class_0.sample(n=n_men_class_0, random_state=42)\n",
        "\n",
        "    df = pd.concat([df_women_class_0, df_women_class_1, df_men_class_1, df_men_class_0_sampled])\n",
        "    df = df.sample(frac=1, random_state=42).reset_index(drop=True)  # shuffle\n",
        "\n",
        "    print('Size after undersampling:', df.shape[0])\n",
        "    print('Women with income 0 size:', df_women_class_0.shape[0], 'Women with income 1 size:', df_women_class_1.shape[0])\n",
        "    print('Men with income 0 size:', df_men_class_0_sampled.shape[0], 'Men with income 1 size:', df_men_class_1.shape[0])\n",
        "\n",
        "    # 6. Binarise age around the median of this (rebalanced) frame.\n",
        "    median_age = df['age'].median()\n",
        "    df['age'] = (df['age'] > median_age).astype(int)\n",
        "\n",
        "    # capital-gain == 99999 is treated as an outlier (different from the paper).\n",
        "    print(\"Number of observation before removing:\", df.shape)\n",
        "    df = df[df['capital-gain'] != 99999].copy()\n",
        "    print(\"Number of observation after removing:\", df.shape)\n",
        "\n",
        "    df['capital-gain'] = (df['capital-gain'] > 5000).astype(int)\n",
        "    df['capital-loss'] = (df['capital-loss'] > 40).astype(int)\n",
        "    df = df.drop(columns=['fnlwgt', 'education-num'])\n",
        "\n",
        "    # Values missing from a mapping would become NaN.\n",
        "    binary_maps = {\n",
        "        'workclass': {'Private': 1, 'non_private': 0},\n",
        "        'marital-status': {'Married': 1, 'other': 0},\n",
        "        'relationship': {'married': 1, 'other': 0},\n",
        "        'race': {'white': 1, 'non_white': 0},\n",
        "        'native-country': {'US': 1, 'non_US': 0},\n",
        "    }\n",
        "    for col, mapping in binary_maps.items():\n",
        "        df[col] = df[col].map(mapping)\n",
        "    df = pd.get_dummies(df, columns=['occupation'])\n",
        "\n",
        "    y = df['income']\n",
        "    X = df.drop(columns=['income'])\n",
        "    return X, y"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#Adult income non private\n",
        "#non private\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "from sklearn.pipeline import Pipeline\n",
        "from sklearn.base import TransformerMixin\n",
        "from sklearn.preprocessing import MinMaxScaler,StandardScaler\n",
        "from sklearn.model_selection import train_test_split,cross_val_score,GridSearchCV\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "from torch.utils.data import Dataset, DataLoader\n",
        "from sklearn.preprocessing import StandardScaler\n",
        "import opacus\n",
        "from opacus.accountants.utils import get_noise_multiplier\n",
        "from opacus.utils.batch_memory_manager import BatchMemoryManager\n",
        "from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score\n",
        "from sklearn.preprocessing import LabelEncoder\n",
        "import torch.nn.functional as F\n",
        "import optuna\n",
        "from transformers import get_linear_schedule_with_warmup\n",
        "from torch.utils.data import ConcatDataset\n",
        "import sys\n",
        "from tqdm.auto import tqdm\n",
        "import wandb\n",
        "import os\n",
        "print('start')\n",
        "\n",
        "# ====================================================================================================\n",
        "# User-configurable settings\n",
        "# ====================================================================================================\n",
        "\n",
        "# Random seed (match with training seed for checkpoint compatibility)\n",
        "seed = 2025\n",
        "\n",
        "# DP-SGD hyperparameters\n",
        "epsilon = 8\n",
        "max_grad_norm = 0.01  # choose 0.01 or 0.05\n",
        "epochs = 30\n",
        "batch_size = 256\n",
        "lr = 4e-5  # use lr = 4e-5 with max_grad_norm = 0.01 and lr = 5e-5 with max_grad_norm = 0.05\n",
        "\n",
        "# Paths\n",
        "train_path = '/datasets/adults_extracted/adult.data'\n",
        "test_path = '/datasets/adults_extracted/adult.test'\n",
        "# The checkpoint directory is derived from the clipping norm and epoch count. This keeps a run\n",
        "# from writing into the directory of a run with different settings. A hard-coded\n",
        "# 'clipping=0.05' path would otherwise receive the checkpoints of max_grad_norm = 0.01 runs.\n",
        "CHECKPOINT_PATH = f\"/Checkpoints/LessComplexIncomeclipping={max_grad_norm}/{epochs}CheckpointsIncomeEqual\"\n",
        "\n",
        "# Column names for the raw Adult files, passed to read_csv(names=...).\n",
        "columns = [\"age\", \"workclass\", \"fnlwgt\", \"education\", \"education-num\",\n",
        "       \"marital-status\", \"occupation\", \"relationship\", \"race\", \"sex\",\n",
        "       \"capital-gain\", \"capital-loss\", \"hours-per-week\", \"native-country\", \"income\"]\n",
        "\n",
        "# Set seeds for reproducibility.\n",
        "np.random.seed(seed)\n",
        "torch.manual_seed(seed)\n",
        "torch.cuda.manual_seed_all(seed)  # every visible GPU, not only the current one\n",
        "torch.backends.cudnn.deterministic = True\n",
        "torch.backends.cudnn.benchmark = False\n",
        "\n",
        "# Dedicated generator for the training DataLoader's shuffling.\n",
        "g = torch.Generator()\n",
        "g.manual_seed(seed)\n",
        "\n",
        "wandb.init(\n",
        "    project=\"income-DP_differentThreshold\",  # Change project name if needed\n",
        "    name=f\"income-DP-genderEqual-{max_grad_norm} Threshold-{epochs}Epoch-{batch_size} batch size, {lr} LR\",  # Give a unique run name\n",
        ")\n",
        "# Record every hyperparameter with the run, not only inside its name.\n",
        "wandb.config.update({\n",
        "    \"seed\": seed,\n",
        "    \"epsilon\": epsilon,\n",
        "    \"max_grad_norm\": max_grad_norm,\n",
        "    \"epochs\": epochs,\n",
        "    \"batch_size\": batch_size,\n",
        "    \"lr\": lr,\n",
        "})\n",
        "\n",
        "class EarlyStopping:\n",
        "    \"\"\"Signal when a validation score (higher is better) stops improving.\n",
        "\n",
        "    A score counts as an improvement only if it exceeds the best score seen so\n",
        "    far by more than `threshold`. An equal score (a plateau) is not an improvement.\n",
        "    After `patience` consecutive non-improving calls, `early_stop` becomes True\n",
        "    and stays True.\n",
        "\n",
        "    Args:\n",
        "        patience: number of consecutive non-improving calls tolerated.\n",
        "        threshold: minimum increase over the best score that counts as improvement.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, patience=3, threshold=0.0):\n",
        "        self.patience = patience\n",
        "        self.threshold = threshold\n",
        "        self.counter = 0\n",
        "        self.best_score = None\n",
        "        self.early_stop = False\n",
        "\n",
        "    def __call__(self, val_score):\n",
        "        \"\"\"Record `val_score` and return whether training should stop.\"\"\"\n",
        "        if self.best_score is None:\n",
        "            self.best_score = val_score\n",
        "        elif val_score <= self.best_score + self.threshold:\n",
        "            # No strict improvement: a plateau must not reset the patience counter.\n",
        "            self.counter += 1\n",
        "            if self.counter >= self.patience:\n",
        "                self.early_stop = True\n",
        "        else:\n",
        "            self.best_score = val_score\n",
        "            self.counter = 0\n",
        "        return self.early_stop\n",
        "\n",
        "def compute_metrics(preds, labels, probs=None):\n",
        "    \"\"\"Score binary predictions against ground-truth labels.\n",
        "\n",
        "    Args:\n",
        "        preds: predicted 0/1 labels (anything np.array accepts).\n",
        "        labels: ground-truth 0/1 labels.\n",
        "        probs: optional positive-class scores used for ROC AUC.\n",
        "\n",
        "    Returns:\n",
        "        dict with keys accuracy, precision, recall, f1 and auc_roc.\n",
        "        Precision, recall and F1 fall back to 0 on zero division. auc_roc is None\n",
        "        unless `probs` is given and `labels` holds exactly two distinct values.\n",
        "    \"\"\"\n",
        "    y_pred = np.array(preds)\n",
        "    y_true = np.array(labels)\n",
        "\n",
        "    scores = {\"accuracy\": accuracy_score(y_true, y_pred)}\n",
        "    for key, scorer in ((\"precision\", precision_score),\n",
        "                        (\"recall\", recall_score),\n",
        "                        (\"f1\", f1_score)):\n",
        "        scores[key] = scorer(y_true, y_pred, zero_division=0)\n",
        "\n",
        "    # ROC AUC is undefined unless both classes appear in the ground truth.\n",
        "    has_both_classes = np.unique(y_true).size == 2\n",
        "    scores[\"auc_roc\"] = roc_auc_score(y_true, probs) if (has_both_classes and probs is not None) else None\n",
        "    return scores\n",
        "\n",
        "# ====================================================================================================\n",
        "# Data: load, preprocess, split and build the evaluation subgroups\n",
        "# ====================================================================================================\n",
        "checkpoint_dir = CHECKPOINT_PATH\n",
        "os.makedirs(checkpoint_dir, exist_ok=True)\n",
        "\n",
        "df_train = pd.read_csv(train_path, names=columns, skipinitialspace=True)\n",
        "df_test = pd.read_csv(test_path, names=columns, skipinitialspace=True)\n",
        "\n",
        "X_train, y_train = prepcosessNew(df_train)\n",
        "X_test, y_test = prepcosessNew(df_test)\n",
        "\n",
        "# Each split is one-hot encoded separately. Align the test columns with the training\n",
        "# columns so that feature i means the same thing in every matrix below.\n",
        "X_test = X_test.reindex(columns=X_train.columns, fill_value=0)\n",
        "feature_names = list(X_train.columns)\n",
        "\n",
        "# Half of the preprocessed test split is held out as the validation set.\n",
        "X_test, X_val, y_test, y_val = train_test_split(X_test, y_test, test_size=0.5, random_state=42, stratify=y_test)\n",
        "\n",
        "X_train = X_train.to_numpy(dtype=np.float32)\n",
        "y_train = y_train.to_numpy(dtype=np.int64)\n",
        "X_test = X_test.to_numpy(dtype=np.float32)\n",
        "y_test = y_test.to_numpy(dtype=np.int64)\n",
        "X_val = X_val.to_numpy(dtype=np.float32)\n",
        "y_val = y_val.to_numpy(dtype=np.int64)\n",
        "\n",
        "\n",
        "def _as_tensors(features, labels):\n",
        "    \"\"\"Return (float32 features, float32 labels unsqueezed to (n, 1)) as torch tensors.\"\"\"\n",
        "    return (torch.tensor(features, dtype=torch.float32),\n",
        "            torch.tensor(labels, dtype=torch.float32).unsqueeze(1))\n",
        "\n",
        "\n",
        "X_train_tensor, y_train_tensor = _as_tensors(X_train, y_train)\n",
        "X_test_tensor, y_test_tensor = _as_tensors(X_test, y_test)\n",
        "X_val_tensor, y_val_tensor = _as_tensors(X_val, y_val)\n",
        "\n",
        "# Column positions must be looked up in the *preprocessed* feature matrix.\n",
        "# Preprocessing drops fnlwgt/education-num and one-hot encodes occupation, so indices\n",
        "# into the raw `columns` list do not line up: columns.index('sex') is 9, and\n",
        "# position 9 of X is hours-per-week.\n",
        "age_col_idx = feature_names.index('age')\n",
        "gender_col_idx = feature_names.index('sex')\n",
        "assert np.isin(X_test[:, age_col_idx], (0, 1)).all(), \"expected age to be binarised\"\n",
        "assert np.isin(X_test[:, gender_col_idx], (0, 1)).all(), \"expected sex to be encoded as 0/1\"\n",
        "\n",
        "# Age subgroups (age was binarised at the median during preprocessing).\n",
        "age_below_med_mask = X_test[:, age_col_idx] == 0\n",
        "X_test_below_med, y_test_below_med = X_test[age_below_med_mask], y_test[age_below_med_mask]\n",
        "X_test_above_med, y_test_above_med = X_test[~age_below_med_mask], y_test[~age_below_med_mask]\n",
        "X_test_below_med_tensor, y_test_below_med_tensor = _as_tensors(X_test_below_med, y_test_below_med)\n",
        "X_test_above_med_tensor, y_test_above_med_tensor = _as_tensors(X_test_above_med, y_test_above_med)\n",
        "\n",
        "# Gender subgroups (preprocessing maps Male -> 1, Female -> 0).\n",
        "male_mask = X_test[:, gender_col_idx] == 1\n",
        "female_mask = X_test[:, gender_col_idx] == 0\n",
        "X_test_male, y_test_male = X_test[male_mask], y_test[male_mask]\n",
        "X_test_female, y_test_female = X_test[female_mask], y_test[female_mask]\n",
        "X_test_male_tensor, y_test_male_tensor = _as_tensors(X_test_male, y_test_male)\n",
        "X_test_female_tensor, y_test_female_tensor = _as_tensors(X_test_female, y_test_female)\n",
        "\n",
        "\n",
        "class IncomeDataset(Dataset):\n",
        "    \"\"\"Map-style dataset that pairs item `idx` of `X` with item `idx` of `y`.\n",
        "\n",
        "    The length is taken from `X`. Both containers are stored as given and only\n",
        "    need to support integer indexing.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, X, y):\n",
        "        self.X, self.y = X, y\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.X)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        features = self.X[idx]\n",
        "        target = self.y[idx]\n",
        "        return features, target\n",
        "\n",
        "# Datasets: training, validation and the per-subgroup test splits.\n",
        "train_dataset = IncomeDataset(X_train_tensor, y_train_tensor)\n",
        "val_dataset = IncomeDataset(X_val_tensor, y_val_tensor)\n",
        "male_dataset = IncomeDataset(X_test_male_tensor, y_test_male_tensor)\n",
        "female_dataset = IncomeDataset(X_test_female_tensor, y_test_female_tensor)\n",
        "age_above_med = IncomeDataset(X_test_above_med_tensor, y_test_above_med_tensor)\n",
        "age_below_med = IncomeDataset(X_test_below_med_tensor, y_test_below_med_tensor)\n",
        "\n",
        "# Only the training loader shuffles, driven by the seeded generator `g`.\n",
        "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,\n",
        "                          worker_init_fn=lambda worker_id: torch.manual_seed(seed + worker_id),  # Fixes worker shuffling\n",
        "                          generator=g)  # Ensures deterministic shuffling\n",
        "\n",
        "# Evaluation loaders keep a fixed order, for two reasons:\n",
        "# - metrics over a full pass do not depend on the order;\n",
        "# - a shuffling DataLoader without its own generator seeds its sampler from the global\n",
        "#   torch RNG on every pass, so evaluating would shift the random stream used in training.\n",
        "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)\n",
        "male_loader = DataLoader(male_dataset, batch_size=batch_size, shuffle=False)\n",
        "female_loader = DataLoader(female_dataset, batch_size=batch_size, shuffle=False)\n",
        "age_above_med_loader = DataLoader(age_above_med, batch_size=batch_size, shuffle=False)\n",
        "age_below_med_loader = DataLoader(age_below_med, batch_size=batch_size, shuffle=False)\n",
        "\n",
        "class IncomeModel(nn.Module):\n",
        "    \"\"\"Two-hidden-layer ReLU MLP that outputs one raw logit per row.\n",
        "\n",
        "    No sigmoid is applied; the logits are meant for BCEWithLogitsLoss.\n",
        "    The attribute names fc1/fc2/out determine the state_dict keys, so renaming\n",
        "    them would break loading existing checkpoints.\n",
        "\n",
        "    Args:\n",
        "        input_size: number of input features.\n",
        "        hidden_size: width of both hidden layers (default 256, the original width).\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, input_size, hidden_size=256):\n",
        "        super().__init__()\n",
        "        self.fc1 = nn.Linear(input_size, hidden_size)\n",
        "        self.fc2 = nn.Linear(hidden_size, hidden_size)\n",
        "        self.out = nn.Linear(hidden_size, 1)\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Map a (batch, input_size) tensor to (batch, 1) logits.\"\"\"\n",
        "        x = F.relu(self.fc1(x))\n",
        "        x = F.relu(self.fc2(x))\n",
        "        return self.out(x)\n",
        "\n",
        "# Initialize model\n",
        "# ----------------------------------------------------------------------------------------------------\n",
        "# Model, loss, optimizer, LR schedule and DP-SGD (Opacus) setup.\n",
        "# Note the ordering: the scheduler is built on the plain Adam optimizer. Its step count comes\n",
        "# from the un-wrapped train_loader. Both happen before make_private() replaces model, optimizer\n",
        "# and train_loader with their DP-wrapped versions.\n",
        "# ----------------------------------------------------------------------------------------------------\n",
        "input_size = X_train.shape[1]  # Number of features\n",
        "model = IncomeModel(input_size)\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "model.to(device)  # Move model to GPU if available\n",
        "\n",
        "# Define loss function and optimizer\n",
        "#criterion = nn.BCEWithLogitsLoss()\n",
        "# pos_weight > 1 up-weights the loss terms of positive (income 1) rows.\n",
        "pos_weight = torch.tensor([2.0]).to(device)  # Only one weight for the positive class\n",
        "# Summed (not averaged) loss; this must agree with loss_reduction passed to make_private below.\n",
        "criterion = nn.BCEWithLogitsLoss(reduction=\"sum\", pos_weight=pos_weight)\n",
        "optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.000005)\n",
        "\n",
        "# Compute warmup steps dynamically\n",
        "# Steps are counted in logical batches of batch_size rows. The training loop calls\n",
        "# scheduler.step() right after optimizer.original_optimizer.step(), once per logical batch.\n",
        "warmup_ratio = 0.1  # 10% of total steps\n",
        "num_training_steps = len(train_loader) * epochs\n",
        "num_warmup_steps = int(warmup_ratio * num_training_steps)\n",
        "scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)\n",
        "\n",
        "\n",
        "# Fraction of the training rows per logical batch; used only to calibrate the noise below.\n",
        "sample_rate = batch_size / len(X_train)\n",
        "\n",
        "# Noise multiplier targeting an (epsilon, delta=1e-5) budget over `epochs` epochs at this sample rate.\n",
        "sigma = get_noise_multiplier(\n",
        "    target_epsilon=epsilon,\n",
        "    target_delta=1e-5,\n",
        "    sample_rate=sample_rate,\n",
        "    epochs=epochs\n",
        ")\n",
        "print(f\"Using Noise Multiplier (sigma): {sigma:.4f}\")\n",
        "\n",
        "privacy_engine = opacus.PrivacyEngine()\n",
        "\n",
        "# NOTE(review): sigma is calibrated for subsampled batches at sample_rate, while\n",
        "# poisson_sampling=False keeps fixed-size shuffled batches. The achieved privacy guarantee\n",
        "# may therefore differ from the target epsilon; confirm this is intended.\n",
        "model, optimizer, train_loader = privacy_engine.make_private(\n",
        "    module=model,\n",
        "    optimizer=optimizer,\n",
        "    data_loader=train_loader,\n",
        "    noise_multiplier=sigma,\n",
        "    max_grad_norm=max_grad_norm,\n",
        "    poisson_sampling=False, # poisson_sampling=True, If I was using drop_last = False in dataloader\n",
        "    loss_reduction=\"sum\" #sum\n",
        ")\n",
        "\n",
        "# best_val_f1 / best_model_state presumably track the best validation result later in the loop; TODO confirm.\n",
        "early_stopper = EarlyStopping(patience=20, threshold=0.0)  # You can tweak this\n",
        "best_val_f1 = 0\n",
        "best_model_state = None\n",
        "\n",
        "# Max physical batch size for BatchMemoryManager: each logical batch is processed in chunks of at most this many rows.\n",
        "micro_batch_size = 20\n",
        "# Per-logical-step gradient diagnostics appended by the training loop.\n",
        "# NOTE(review): grad_norms_after_clipping and all_noises_grad_sample are not appended to in the\n",
        "# per-step code of the loop; confirm whether they are still needed.\n",
        "grad_norms_before_clipping = []\n",
        "grad_norms_after_clipping = []\n",
        "grad_norms_after_clipping_noisy = []\n",
        "all_noises_grad_sample = []\n",
        "all_noises_grad = []\n",
        "\n",
        "#Training loop\n",
        "model.train()  # Set to training mode (the epoch loop also calls model.train() each epoch)\n",
        "for epoch in range(epochs):\n",
        "    model.train()\n",
        "    total_loss = 0\n",
        "    total = 0\n",
        "    all_preds = []\n",
        "    all_labels = []\n",
        "    all_probs = []\n",
        "\n",
        "    print_loss=0\n",
        "    total_micro_batches = 0\n",
        "    all_train_predictions = []\n",
        "    all_train_labels = []\n",
        "    grad_list = []\n",
        "    all_logical_predictions = []\n",
        "    all_logical_labels = []\n",
        "    layer_wise_grad_samples = {}\n",
        "    cosine_similarities = []\n",
        "    angles = []\n",
        "\n",
        "    progress_bar = tqdm(train_loader, desc=f\"Epoch {epoch + 1}\")\n",
        "\n",
        "    optimizer.zero_grad()\n",
        "\n",
        "    with BatchMemoryManager(\n",
        "    data_loader=train_loader,\n",
        "    max_physical_batch_size=micro_batch_size,\n",
        "    optimizer=optimizer,\n",
        "    ) as memory_safe_data_loader:\n",
        "\n",
        "        for batch_idx, batch in enumerate(memory_safe_data_loader):\n",
        "\n",
        "            total_micro_batches += 1\n",
        "\n",
        "            X_batch, y_batch = batch\n",
        "            X_batch, y_batch = X_batch.to(device), y_batch.to(device)\n",
        "\n",
        "            # Forward pass\n",
        "            outputs = model(X_batch)\n",
        "            probs = torch.sigmoid(outputs).squeeze()\n",
        "            predicted = (probs >= 0.5).long()\n",
        "\n",
        "            loss = criterion(outputs, y_batch)\n",
        "            total_loss += loss.item()\n",
        "            print_loss = loss.item()\n",
        "\n",
        "            # Compute and log metrics every batch\n",
        "            all_preds.extend(predicted.detach().cpu().numpy())\n",
        "            all_labels.extend(y_batch.detach().cpu().numpy())\n",
        "            all_probs.extend(probs.detach().cpu().numpy())\n",
        "\n",
        "            # Backward pass\n",
        "            loss.backward()\n",
        "\n",
        "            for name, p in model.named_parameters():\n",
        "                if p.requires_grad:\n",
        "                    with torch.no_grad():\n",
        "                        grad_sample = p.grad_sample.sum(dim = 0).clone().detach()  #birone if #sum accross samples\n",
        "                        if name not in layer_wise_grad_samples:\n",
        "                            layer_wise_grad_samples[name] = torch.zeros_like(grad_sample)  # Store per-layer name\n",
        "                        layer_wise_grad_samples[name] += grad_sample  # Accumulate across physical batches\n",
        "\n",
        "            #Calculating and storing gradient before clipping and adding noise and after clipping and adding noise\n",
        "            is_final_step = optimizer.pre_step()\n",
        "\n",
        "            if is_final_step == True:\n",
        "                total_grad_norm_squared = 0.0\n",
        "                total_grad_norm_squared_summed_noisy = 0.0\n",
        "                total_noise_grad = 0.0\n",
        "                all_layer_grads_before=[]\n",
        "                all_layer_grads_after=[]\n",
        "\n",
        "                num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
        "                expected_noise = (num_trainable_params ** 0.5) * sigma # * max_grad_norm\n",
        "\n",
        "                for name, grad_list in layer_wise_grad_samples.items():\n",
        "                    all_layer_grads_before.append(grad_list)  # Store per-layer gradients\n",
        "                    ## Compute the L2 norm squared\n",
        "                    total_grad_norm_squared += torch.norm(grad_list, p=2) ** 2\n",
        "\n",
        "                for p in model.parameters():\n",
        "                    if p.requires_grad:\n",
        "                         with torch.no_grad():\n",
        "                                all_layer_grads_after.append(p.summed_grad.clone().detach())\n",
        "\n",
        "                                layer_grad_norm_summed_noisy = torch.norm(p.summed_grad, p=2).item()  #torch.norm(p.summed_grad.view(-1), p=2).item()\n",
        "                                total_grad_norm_squared_summed_noisy += layer_grad_norm_summed_noisy ** 2\n",
        "\n",
        "                                grad = torch.norm(p.grad, p=2).item()\n",
        "                                total_noise_grad += (grad - layer_grad_norm_summed_noisy) ** 2\n",
        "\n",
        "                # Convert lists to single tensors\n",
        "                total_grad_before = torch.cat([v.view(-1) for v in all_layer_grads_before])\n",
        "                total_grad_after = torch.cat([v.view(-1) for v in all_layer_grads_after])  # Aggregate all layers\n",
        "\n",
        "                # Compute Cosine Similarity\n",
        "                cosine_similarity = torch.nn.functional.cosine_similarity(\n",
        "                    total_grad_before.view(1, -1),\n",
        "                    total_grad_after.view(1, -1),\n",
        "                    dim=1\n",
        "                ).item()\n",
        "\n",
        "                # Compute Angle in Degrees\n",
        "                angle = torch.acos(torch.clamp(torch.tensor(cosine_similarity), -1.0, 1.0)).item() * (180 / torch.pi)\n",
        "\n",
        "                # Store values for logging\n",
        "                cosine_similarities.append(cosine_similarity)\n",
        "                angles.append(angle)\n",
        "\n",
        "                #print(f\"Cosine Similarity: {cosine_similarity}\")\n",
        "                #print(f\"Angle in Degrees: {angle}\")\n",
        "\n",
        "                # Compute the overall gradient norm\n",
        "                overall_grad_norm = total_grad_norm_squared ** 0.5\n",
        "                grad_norms_before_clipping.append(overall_grad_norm)\n",
        "                overall_grad_norm_summed_noisy = total_grad_norm_squared_summed_noisy ** 0.5\n",
        "                grad_norms_after_clipping_noisy.append(overall_grad_norm_summed_noisy)\n",
        "                overal_noise_grad = total_noise_grad ** 0.5\n",
        "                all_noises_grad.append(overal_noise_grad)\n",
        "\n",
        "                # Compute metrics after all physical batches in the logical batch are processed\n",
        "                train_metrics = compute_metrics(np.array(all_preds), np.array(all_labels), np.array(all_probs))\n",
        "                avg_train_loss = total_loss / total_micro_batches\n",
        "\n",
        "                # print(f\"train_loss: {avg_train_loss}\")\n",
        "                # print(f\"train_accuracy: {train_metrics['accuracy']}\")\n",
        "                # print(f\"train_precision: {train_metrics['precision']}\")\n",
        "                # print(f\"train_recall: { train_metrics['recall']}\")\n",
        "                # print(f\"train_f1: {train_metrics['f1']}\")\n",
        "                # print(f\"auroc: {train_metrics['auc_roc']}\")\n",
        "\n",
        "                wandb.log({\n",
        "                    \"epoch\": epoch + 1,\n",
        "                    \"train_loss\": avg_train_loss,\n",
        "                    \"train_f1\": train_metrics[\"f1\"],\n",
        "                    \"train_accuracy\": train_metrics[\"accuracy\"],\n",
        "                    \"train_precicion\": train_metrics[\"precision\"],\n",
        "                    \"train_recall\": train_metrics[\"recall\"],\n",
        "                    \"train_auroc\": train_metrics[\"auc_roc\"],\n",
        "\n",
        "                })\n",
        "\n",
        "                optimizer.original_optimizer.step()\n",
        "                scheduler.step()\n",
        "                all_logical_labels = []\n",
        "                all_logical_predictions = []\n",
        "                all_labels = []\n",
        "                all_preds = []\n",
        "                all_probs = []\n",
        "                grad_list = []\n",
        "                layer_wise_grad_samples = {}\n",
        "\n",
        "            progress_bar.set_postfix({\n",
        "            \"lr\": scheduler.get_last_lr()[0],\n",
        "            \"step_loss\": print_loss\n",
        "            })\n",
        "\n",
        "            optimizer.zero_grad()\n",
        "\n",
        "    # Save checkpoint every 5 epochs\n",
        "    if (epoch + 1) % 5 == 0:\n",
        "        checkpoint_path = os.path.join(checkpoint_dir, f\"checkpoint_epoch_{epoch+1}.pt\")\n",
        "        torch.save({\n",
        "            'epoch': epoch + 1,\n",
        "            'model_state_dict': model.state_dict(),\n",
        "            'optimizer_state_dict': optimizer.state_dict(),\n",
        "            'scheduler_state_dict': scheduler.state_dict(),\n",
        "            'best_val_f1': best_val_f1 if 'best_val_f1' in locals() else None,\n",
        "            'sigma': sigma\n",
        "        }, checkpoint_path)\n",
        "        print(f\" Checkpoint saved at {checkpoint_path}\")\n",
        "\n",
        "    print(f\"Epoch {epoch+1}/{epochs} | Loss: {total_loss:.4f} | Acc: {train_metrics['accuracy']:.4f} | Precision: {train_metrics['precision']:.4f} | Recall: {train_metrics['recall']:.4f} | F1: {train_metrics['f1']:.4f} | AUROC: {train_metrics['auc_roc']}\",flush=True)\n",
        "    print(\"Training complete with differential privacy!\")\n",
        "\n",
        "    #Evaluate\n",
        "    model.eval()  # Set to evaluation mode\n",
        "    eval_loss = 0\n",
        "    all_eval_preds = []\n",
        "    all_eval_labels = []\n",
        "    all_eval_probs = []\n",
        "    with torch.no_grad():\n",
        "        for X_batch, y_batch in val_loader:\n",
        "            X_batch, y_batch = X_batch.to('cuda'), y_batch.to('cuda')\n",
        "\n",
        "            outputs = model(X_batch)\n",
        "            probs = torch.sigmoid(outputs).squeeze()\n",
        "            predicted = (probs >= 0.5).long()\n",
        "            loss = criterion(outputs, y_batch)\n",
        "            eval_loss += loss.item()\n",
        "\n",
        "            all_eval_preds.extend(predicted.detach().cpu().numpy())\n",
        "            all_eval_labels.extend(y_batch.detach().cpu().numpy())\n",
        "            all_eval_probs.extend(probs.detach().cpu().numpy())\n",
        "\n",
        "    average_loss = eval_loss / len(val_loader)\n",
        "    eval_metrics = compute_metrics(np.array(all_eval_preds), np.array(all_eval_labels), np.array(all_eval_probs))\n",
        "\n",
        "    # Log metrics for the epoch\n",
        "    current_lr = scheduler.get_last_lr()[0]\n",
        "\n",
        "    # Print epoch metrics\n",
        "    print(f\"\\nEpoch {epoch + 1} Summary:\")\n",
        "    print(f\"Validation Loss: {average_loss:.4f}\")\n",
        "    print(f\"Validation Accuracy: {eval_metrics['accuracy']:.4f}\")\n",
        "    print(f\"Validation F1: {eval_metrics['f1']:.4f}\")\n",
        "    print(f\"Learning Rate: {current_lr:.2e}\")\n",
        "    print(f\"Sigma: {sigma}\")\n",
        "\n",
        "    wandb.log({\n",
        "        \"epoch\": epoch + 1,\n",
        "        \"val_loss\": average_loss,\n",
        "        \"val_f1\": eval_metrics[\"f1\"],\n",
        "        \"val_accuracy\": eval_metrics[\"accuracy\"],\n",
        "        \"val_auroc\": eval_metrics[\"auc_roc\"],\n",
        "\n",
        "    })\n",
        "\n",
        "    ## Early stopping check\n",
        "    # if early_stopper(eval_metrics[\"accuracy\"]):\n",
        "    #     print(f\"Early stopping triggered at epoch {epoch + 1}\")\n",
        "    #     break\n",
        "\n",
        "\n",
        "# Banner for the post-training evaluation section below.\n",
        "print(\"Final Testing phase with val data:\")\n",
        "\n",
        "def compute_gradient_norms(model, dataloader, optimizer, group, groupname,\n",
        "                           micro_batch_size=20, result_file_path=\"resultsIncome.txt\"):\n",
        "    \"\"\"Measure gradient statistics of a DP (Opacus) model on one data group.\n",
        "\n",
        "    For each logical batch of `dataloader` this runs forward/backward, sums the\n",
        "    raw per-sample gradients (`p.grad_sample`) over the samples, then calls\n",
        "    `optimizer.pre_step()` so the DP optimizer fills `p.summed_grad` (the clipped\n",
        "    sum). `step()` is never called, so the model weights are not updated.\n",
        "    Per logical batch it records the L2 norm of the raw summed gradient\n",
        "    ('before clipping'), the L2 norm of `p.summed_grad` ('after clipping'),\n",
        "    the cosine similarity / angle between the two, and accuracy / F1 of the\n",
        "    predictions. The averages are printed and appended to `result_file_path`.\n",
        "\n",
        "    Args:\n",
        "        model: model whose trainable parameters expose `grad_sample` after backward.\n",
        "        dataloader: loader yielding `(X, y)` batches for a single group.\n",
        "        optimizer: Opacus DP optimizer (`pre_step`, `summed_grad`, `zero_grad`).\n",
        "        group: label of the group, used in the printed / written report.\n",
        "        groupname: grouping attribute (e.g. 'age'); currently unused, kept for callers.\n",
        "        micro_batch_size: maximum physical batch size for `BatchMemoryManager`.\n",
        "        result_file_path: text file the averages are appended to.\n",
        "\n",
        "    Returns:\n",
        "        dict of averaged metrics, or None when no logical batch was processed.\n",
        "\n",
        "    Relies on the notebook globals `device`, `criterion`, `compute_metrics` and\n",
        "    `BatchMemoryManager`.\n",
        "    NOTE(review): `pre_step()` is the DP optimizer's clip/noise step; confirm it\n",
        "    is acceptable that it runs (and may be privacy-accounted) during evaluation.\n",
        "    \"\"\"\n",
        "    model.train()  # keep train mode so per-sample gradients are populated\n",
        "\n",
        "    grad_norms_before = []    # ||sum of raw per-sample grads||, one per logical batch\n",
        "    grad_norms_after = []     # ||p.summed_grad|| (clipped sum), one per logical batch\n",
        "    cosine_similarities = []\n",
        "    angles = []\n",
        "    accuracies = []\n",
        "    f1_scores = []\n",
        "\n",
        "    layer_wise_grad_samples = {}  # parameter name -> summed raw per-sample gradient\n",
        "    all_preds, all_labels, all_probs = [], [], []\n",
        "\n",
        "    with BatchMemoryManager(\n",
        "            data_loader=dataloader,\n",
        "            max_physical_batch_size=micro_batch_size,\n",
        "            optimizer=optimizer\n",
        "        ) as memory_safe_data_loader:\n",
        "\n",
        "        # Iterate the wrapped loader (not `dataloader`): it splits each logical\n",
        "        # batch into physical batches and signals the optimizer which\n",
        "        # `pre_step()` closes the logical batch. Iterating the raw loader left\n",
        "        # the wrapper unused and made every batch a 'final step'.\n",
        "        for X_batch, y_batch in memory_safe_data_loader:\n",
        "            X_batch, y_batch = X_batch.to(device), y_batch.to(device)\n",
        "\n",
        "            outputs = model(X_batch)\n",
        "            # reshape(-1) rather than squeeze(): a physical batch of one sample\n",
        "            # would otherwise become a 0-d array that list.extend cannot iterate.\n",
        "            probs = torch.sigmoid(outputs).reshape(-1)\n",
        "            predicted = (probs >= 0.5).long()\n",
        "            loss = criterion(outputs, y_batch)\n",
        "\n",
        "            all_preds.extend(predicted.detach().cpu().numpy())\n",
        "            all_labels.extend(y_batch.detach().cpu().numpy())\n",
        "            all_probs.extend(probs.detach().cpu().numpy())\n",
        "\n",
        "            loss.backward()\n",
        "\n",
        "            # Accumulate raw (unclipped) per-sample gradients, summed over the\n",
        "            # samples, across the physical batches of the current logical batch.\n",
        "            with torch.no_grad():\n",
        "                for name, p in model.named_parameters():\n",
        "                    if not p.requires_grad:\n",
        "                        continue\n",
        "                    grad_sum = p.grad_sample.sum(dim=0)\n",
        "                    if name in layer_wise_grad_samples:\n",
        "                        layer_wise_grad_samples[name] += grad_sum\n",
        "                    else:\n",
        "                        layer_wise_grad_samples[name] = grad_sum\n",
        "\n",
        "            if optimizer.pre_step():\n",
        "                with torch.no_grad():\n",
        "                    # Both vectors follow named_parameters() order, so their\n",
        "                    # elements line up for the cosine similarity.\n",
        "                    grad_before = torch.cat([g.reshape(-1) for g in layer_wise_grad_samples.values()])\n",
        "                    grad_after = torch.cat([\n",
        "                        p.summed_grad.reshape(-1)\n",
        "                        for p in model.parameters() if p.requires_grad\n",
        "                    ])\n",
        "                    cos_sim = torch.nn.functional.cosine_similarity(\n",
        "                        grad_before.view(1, -1), grad_after.view(1, -1), dim=1\n",
        "                    ).item()\n",
        "                    # .item() keeps plain floats in the lists, so the report shows\n",
        "                    # numbers instead of tensor(..., device=...) reprs.\n",
        "                    grad_norms_before.append(grad_before.norm(p=2).item())\n",
        "                    grad_norms_after.append(grad_after.norm(p=2).item())\n",
        "\n",
        "                cosine_similarities.append(cos_sim)\n",
        "                angles.append(\n",
        "                    torch.acos(torch.clamp(torch.tensor(cos_sim), -1.0, 1.0)).item() * (180 / torch.pi)\n",
        "                )\n",
        "\n",
        "                batch_metrics = compute_metrics(np.array(all_preds), np.array(all_labels), np.array(all_probs))\n",
        "                accuracies.append(batch_metrics['accuracy'])\n",
        "                f1_scores.append(batch_metrics['f1'])\n",
        "\n",
        "                # Start fresh for the next logical batch.\n",
        "                layer_wise_grad_samples = {}\n",
        "                all_preds, all_labels, all_probs = [], [], []\n",
        "\n",
        "            optimizer.zero_grad()\n",
        "\n",
        "    if not grad_norms_before:\n",
        "        print(f\"No logical batch processed for {group}; nothing to report.\")\n",
        "        return None\n",
        "\n",
        "    def _mean(values):\n",
        "        \"\"\"Arithmetic mean of a non-empty list.\"\"\"\n",
        "        return sum(values) / len(values)\n",
        "\n",
        "    results = {\n",
        "        'avg_grad_norm_before_clipping': _mean(grad_norms_before),\n",
        "        'avg_grad_norm_after_clipping': _mean(grad_norms_after),\n",
        "        'avg_cosine_similarity': _mean(cosine_similarities),\n",
        "        'avg_angle_degrees': _mean(angles),\n",
        "        'avg_accuracy': _mean(accuracies),\n",
        "        'avg_f1': _mean(f1_scores),\n",
        "    }\n",
        "\n",
        "    report_lines = [\n",
        "        f\"Average Gradient Norm Before Clipping {group}: {results['avg_grad_norm_before_clipping']}\",\n",
        "        f\"Average Gradient Norm After Clipping {group}: {results['avg_grad_norm_after_clipping']}\",\n",
        "        f\"Average Accuracy for {group}: {results['avg_accuracy']}\",\n",
        "        f\"Average f1score for {group}: {results['avg_f1']}\",\n",
        "    ]\n",
        "    for line in report_lines:\n",
        "        print(line)\n",
        "\n",
        "    # Append so results from several groups / runs accumulate in one file.\n",
        "    with open(result_file_path, \"a\") as f:\n",
        "        f.writelines(line + \"\\n\" for line in report_lines)\n",
        "\n",
        "    return results\n",
        "\n",
        "\n",
        "\n",
        "# Compute gradient norms for different groups\n",
        "\n",
        "#print(\"\\nComputing gradient norms for age below median dataset:\")\n",
        "#compute_gradient_norms(model, age_below_med_loader, optimizer,'below age median' , groupname='age')\n",
        "\n",
        "#print(\"\\nComputing gradient norms for age above/equal to median dataset:\")\n",
        "#compute_gradient_norms(model, age_above_med_loader, optimizer,'above age median' , groupname='age')"
      ],
      "metadata": {
        "id": "KqRHnnZfiKti"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}