{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VYo6UiW90F3A"
      },
      "outputs": [],
      "source": [
        "#White and non White data loader ethnicity\n",
        "from torch.utils.data.dataset import Dataset\n",
        "from torch.utils.data import DataLoader\n",
        "from datasets import load_dataset\n",
        "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n",
        "import torch\n",
        "import wandb\n",
        "import numpy as np\n",
        "from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score\n",
        "import fire\n",
        "import opacus\n",
        "from opacus.utils.batch_memory_manager import BatchMemoryManager\n",
        "from opacus.accountants.utils import get_noise_multiplier\n",
        "import torch.nn.functional as F\n",
        "import os\n",
        "\n",
        "\n",
        "# ====================================================================================================\n",
        "# User-configurable settings\n",
        "# ====================================================================================================\n",
        "\n",
        "\n",
        "# Random seed; must be the same seed that was used for training.\n",
        "SEED = 1000\n",
        "\n",
        "# Paths\n",
        "CACHE_DIR = \"/cache\"  # Hugging Face datasets cache directory\n",
        "DATA_DIR = \"/datasets/unpackedmimiciii/mimic3MultiGender\"  # directory holding the CSV splits\n",
        "MODEL_NAME_OR_PATH = \"medicalai/ClinicalBERT\"  # base model instantiated before the checkpoint weights are loaded\n",
        "\n",
        "# Checkpoint to analyse: the model/optimizer file and the tokenizer share one run directory.\n",
        "# NOTE: the directory name appears to encode the training seed - keep it in sync with SEED.\n",
        "CHECKPOINT_DIR = \"/Checkpoints/MimicLOS/NEWSEED1000-20010CheckpointsNEW\"\n",
        "CHECKPOINT_EPOCH = 200  # epoch whose checkpoint and tokenizer are loaded\n",
        "CHECKPOINT_PATH = f\"{CHECKPOINT_DIR}/checkpoint_epoch_{CHECKPOINT_EPOCH}.pt\"  # model + optimizer state\n",
        "TOKENIZER_PATH = f\"{CHECKPOINT_DIR}/tokenizer_epoch_{CHECKPOINT_EPOCH}\"  # saved tokenizer directory\n",
        "\n",
        "# Dataset files (relative to DATA_DIR)\n",
        "TRAIN_FILE = \"LOS_WEEKS_train.csv\"\n",
        "VAL_FILE   = \"LOS_WEEKS_val.csv\"\n",
        "TEST_FILE  = \"LOS_WEEKS_test.csv\"\n",
        "\n",
        "def seed_everything(seed):\n",
        "    \"\"\"Seed Python, NumPy and PyTorch (CPU and every GPU) and force deterministic cuDNN.\n",
        "\n",
        "    Args:\n",
        "        seed: Integer seed; should equal the seed used for training.\n",
        "    \"\"\"\n",
        "    import random  # local import: the notebook's import block does not import `random`\n",
        "\n",
        "    # Only affects interpreters started after this point: the running process fixed\n",
        "    # its own hash seed at start-up.\n",
        "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    torch.cuda.manual_seed_all(seed)\n",
        "\n",
        "    # Deterministic cuDNN kernels; with benchmark mode off cuDNN no longer picks\n",
        "    # algorithms by timing, which can differ from run to run.\n",
        "    torch.backends.cudnn.deterministic = True\n",
        "    torch.backends.cudnn.benchmark = False\n",
        "\n",
        "    print(f\"Global seed set to: {seed}\")\n",
        "\n",
        "\n",
        "class TextDataset(Dataset):\n",
        "    \"\"\"Map-style dataset over pre-tokenized encodings and integer class labels.\n",
        "\n",
        "    Args:\n",
        "        encodings: Mapping (e.g. a tokenizer's BatchEncoding) from field name to an\n",
        "            indexable object (list or tensor) holding one row per example.\n",
        "        labels: Indexable sequence (list or 1-D tensor) of integer class labels.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, encodings, labels):\n",
        "        self.encodings = encodings\n",
        "        self.labels = labels\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        # torch.as_tensor accepts Python values and tensors alike; torch.tensor() on a\n",
        "        # tensor slice emits a copy-construct UserWarning for every item fetched.\n",
        "        item = {key: torch.as_tensor(val[idx]) for key, val in self.encodings.items()}\n",
        "        item['labels'] = torch.as_tensor(self.labels[idx])\n",
        "        return item\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.labels)\n",
        "\n",
        "\n",
        "def _tokenize_texts(tokenizer, texts):\n",
        "    \"\"\"Tokenize ``texts`` to PyTorch tensors truncated/padded to exactly 512 tokens.\"\"\"\n",
        "    return tokenizer(\n",
        "        texts,\n",
        "        truncation=True,\n",
        "        padding='max_length',\n",
        "        max_length=512,\n",
        "        return_tensors='pt',\n",
        "    )\n",
        "\n",
        "\n",
        "def _safe_ratio(numerator, denominator):\n",
        "    \"\"\"Return ``numerator / denominator``, or NaN when ``denominator`` is zero.\"\"\"\n",
        "    return numerator / denominator if denominator else float('nan')\n",
        "\n",
        "\n",
        "def compute_metrics(predictions, labels):\n",
        "    \"\"\"Accuracy and macro-averaged precision/recall/F1 from raw logits.\n",
        "\n",
        "    Args:\n",
        "        predictions: Array of logits, one row per example; the predicted class is the\n",
        "            arg-max over the last axis.\n",
        "        labels: Array of integer class labels, one per example.\n",
        "\n",
        "    Returns:\n",
        "        Dict with keys ``accuracy``, ``precision``, ``recall`` and ``f1``.\n",
        "    \"\"\"\n",
        "    predicted_classes = np.argmax(predictions, axis=-1)\n",
        "    # zero_division=0 gives the same value sklearn's default uses for classes with no\n",
        "    # true/predicted samples, without a warning on every small logical batch.\n",
        "    return {\n",
        "        \"accuracy\": accuracy_score(labels, predicted_classes),\n",
        "        \"precision\": precision_score(labels, predicted_classes, average='macro', zero_division=0),\n",
        "        \"recall\": recall_score(labels, predicted_classes, average='macro', zero_division=0),\n",
        "        \"f1\": f1_score(labels, predicted_classes, average='macro', zero_division=0),\n",
        "    }\n",
        "\n",
        "\n",
        "def load_checkpoint(checkpoint_path, model_name_or_path, tokenizer_path, dataset, seed=SEED,\n",
        "                    batch_size=500, num_labels=4, target_epsilon=8, target_delta=1e-5,\n",
        "                    epochs=200, max_grad_norm=0.1):\n",
        "    \"\"\"Rebuild the Opacus-wrapped training setup and restore a checkpoint into it.\n",
        "\n",
        "    A fresh model is created from ``model_name_or_path``; every parameter except the\n",
        "    classifier and the listed layer-5 sub-modules is frozen; model, AdamW optimizer\n",
        "    and a DataLoader over ``dataset['train']`` go through\n",
        "    ``opacus.PrivacyEngine.make_private``; then the checkpoint's model and optimizer\n",
        "    states are loaded. The keyword defaults are the values this notebook always used;\n",
        "    presumably they must match the training run (TODO confirm).\n",
        "\n",
        "    Args:\n",
        "        checkpoint_path: File written by ``torch.save`` holding ``epoch``,\n",
        "            ``best_val_f1``, ``model_state_dict`` and ``optimizer_state_dict``.\n",
        "        model_name_or_path: Hugging Face model id or path of the base architecture.\n",
        "        tokenizer_path: Directory of the saved tokenizer.\n",
        "        dataset: DatasetDict whose ``train`` split has ``text`` and ``los_label``.\n",
        "        seed: Seed of the shuffling generator of the rebuilt training loader.\n",
        "        batch_size: Logical training batch size; also sets the Opacus sample rate.\n",
        "        num_labels: Number of classes of the classification head.\n",
        "        target_epsilon, target_delta, epochs: Privacy budget from which\n",
        "            ``get_noise_multiplier`` derives the noise multiplier.\n",
        "        max_grad_norm: Per-sample clipping bound passed to ``make_private``.\n",
        "\n",
        "    Returns:\n",
        "        ``(model, tokenizer, optimizer, start_epoch, best_val_f1)``; ``model`` and\n",
        "        ``optimizer`` are the objects returned by ``make_private``.\n",
        "    \"\"\"\n",
        "    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects; only\n",
        "    # load checkpoints from a trusted source.\n",
        "    checkpoint = torch.load(checkpoint_path, weights_only=False)\n",
        "    model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, num_labels=num_labels)\n",
        "    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)\n",
        "\n",
        "    # Freeze every parameter whose name contains none of these markers.\n",
        "    trainable_markers = (\n",
        "        \"classifier\",\n",
        "        \"layer.5.output_layer_norm\",\n",
        "        \"layer.5.ffn.lin2\",\n",
        "        \"layer.5.ffn.lin1\",\n",
        "        \"layer.5.sa_layer_norm\",\n",
        "        \"layer.5.attention.out_lin\",\n",
        "        \"layer.5.attention.v_lin\",\n",
        "        \"layer.5.attention.k_lin\",\n",
        "    )\n",
        "    for name, parameter in model.named_parameters():\n",
        "        parameter.requires_grad = any(marker in name for marker in trainable_markers)\n",
        "\n",
        "    model.to(\"cuda\")\n",
        "    start_epoch = checkpoint['epoch']\n",
        "    best_val_f1 = checkpoint['best_val_f1']\n",
        "\n",
        "    # lr / weight_decay are replaced by the param groups restored from the checkpoint's\n",
        "    # optimizer state below.\n",
        "    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)\n",
        "\n",
        "    # make_private needs the training DataLoader, so rebuild it as in training.\n",
        "    train_split = dataset['train']\n",
        "    train_dataset = TextDataset(\n",
        "        _tokenize_texts(tokenizer, train_split['text']),\n",
        "        [int(label) for label in train_split['los_label']],\n",
        "    )\n",
        "    generator = torch.Generator()\n",
        "    generator.manual_seed(seed)\n",
        "    train_dataloader = DataLoader(\n",
        "        train_dataset,\n",
        "        batch_size=batch_size,\n",
        "        shuffle=True,\n",
        "        drop_last=True,\n",
        "        worker_init_fn=lambda worker_id: torch.manual_seed(seed + worker_id),\n",
        "        generator=generator,\n",
        "    )\n",
        "\n",
        "    model.train()\n",
        "    sigma = get_noise_multiplier(\n",
        "        target_epsilon=target_epsilon,\n",
        "        target_delta=target_delta,\n",
        "        sample_rate=batch_size / len(train_split),\n",
        "        epochs=epochs,\n",
        "    )\n",
        "    print(sigma)\n",
        "\n",
        "    privacy_engine = opacus.PrivacyEngine()\n",
        "    model, optimizer, train_dataloader = privacy_engine.make_private(\n",
        "        module=model,\n",
        "        optimizer=optimizer,\n",
        "        data_loader=train_dataloader,\n",
        "        noise_multiplier=sigma,\n",
        "        max_grad_norm=max_grad_norm,\n",
        "        poisson_sampling=False,\n",
        "        loss_reduction=\"sum\",\n",
        "    )\n",
        "    print(f\"Loaded checkpoint from {checkpoint_path}, starting from epoch {start_epoch+1}\")\n",
        "\n",
        "    # strict=False tolerates key mismatches, but a mismatch silently leaves those weights\n",
        "    # at their pre-trained / randomly initialised values, so report it.\n",
        "    load_result = model.load_state_dict(checkpoint[\"model_state_dict\"], strict=False)\n",
        "    missing_keys = getattr(load_result, \"missing_keys\", [])\n",
        "    unexpected_keys = getattr(load_result, \"unexpected_keys\", [])\n",
        "    if missing_keys or unexpected_keys:\n",
        "        print(f\"WARNING: keys missing from checkpoint: {missing_keys}\")\n",
        "        print(f\"WARNING: unexpected keys in checkpoint: {unexpected_keys}\")\n",
        "    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
        "\n",
        "    # Report the learning rate of the last parameter group restored from the checkpoint.\n",
        "    lr = optimizer.param_groups[-1][\"lr\"]\n",
        "    print(f\"Loaded learning rate from checkpoint: {lr}\")\n",
        "\n",
        "    return model, tokenizer, optimizer, start_epoch, best_val_f1\n",
        "\n",
        "\n",
        "def filter_ethnicity(dataset_split, tokenizer):\n",
        "    \"\"\"Split a dataset split by ethnicity and tokenize both parts.\n",
        "\n",
        "    Rows whose ``ethnicity_group`` equals ``'White'`` form the first dataset; all other\n",
        "    rows form the second.\n",
        "\n",
        "    Returns:\n",
        "        ``(white_dataset, nonwhite_dataset)`` as TextDataset objects built from the\n",
        "        ``text`` and ``los_label`` columns.\n",
        "    \"\"\"\n",
        "    def _to_text_dataset(split):\n",
        "        labels = torch.tensor([int(label) for label in split['los_label']])\n",
        "        return TextDataset(_tokenize_texts(tokenizer, split['text']), labels)\n",
        "\n",
        "    white_split = dataset_split.filter(lambda example: example['ethnicity_group'] == 'White')\n",
        "    nonwhite_split = dataset_split.filter(lambda example: example['ethnicity_group'] != 'White')\n",
        "    return _to_text_dataset(white_split), _to_text_dataset(nonwhite_split)\n",
        "\n",
        "\n",
        "def _gradient_statistics(model, unclipped_grad_sums):\n",
        "    \"\"\"Compare the un-clipped, clipped and noisy gradient of one logical batch.\n",
        "\n",
        "    Call right after an ``optimizer.pre_step()`` that returned True, while every\n",
        "    trainable parameter still carries ``summed_grad`` (Opacus' clipped sum) and\n",
        "    ``grad`` (clipped sum plus noise; with ``loss_reduction=\"sum\"`` Opacus does not\n",
        "    rescale it).\n",
        "\n",
        "    Args:\n",
        "        model: Opacus-wrapped model whose trainable parameters are inspected.\n",
        "        unclipped_grad_sums: Parameter name -> sum of that parameter's per-sample\n",
        "            gradients over the logical batch.\n",
        "\n",
        "    Returns:\n",
        "        ``(norm_before, norm_clipped, norm_noise, cosine, angle_degrees)`` as floats:\n",
        "        L2 norms over all trainable parameters of the un-clipped gradient, the clipped\n",
        "        sum and the noise ``grad - summed_grad``, plus the cosine similarity and the\n",
        "        angle between the un-clipped and the clipped gradient.\n",
        "    \"\"\"\n",
        "    before_parts, clipped_parts, noise_parts = [], [], []\n",
        "    with torch.no_grad():\n",
        "        for name, p in model.named_parameters():\n",
        "            if p.requires_grad:\n",
        "                before_parts.append(unclipped_grad_sums[name].reshape(-1))\n",
        "                clipped_parts.append(p.summed_grad.reshape(-1))\n",
        "                # Norm of the noise vector itself, not a difference of two norms.\n",
        "                noise_parts.append((p.grad - p.summed_grad.view_as(p.grad)).reshape(-1))\n",
        "        before = torch.cat(before_parts)\n",
        "        clipped = torch.cat(clipped_parts)\n",
        "        noise = torch.cat(noise_parts)\n",
        "        cosine = F.cosine_similarity(before.view(1, -1), clipped.view(1, -1), dim=1).item()\n",
        "        angle = torch.rad2deg(torch.acos(torch.tensor(cosine).clamp(-1.0, 1.0))).item()\n",
        "    return before.norm(p=2).item(), clipped.norm(p=2).item(), noise.norm(p=2).item(), cosine, angle\n",
        "\n",
        "\n",
        "def compute_gradient_norms(model, dataloader, optimizer, group, groupname, micro_batch_size=40):\n",
        "    \"\"\"Measure DP gradient statistics and classification metrics for one group.\n",
        "\n",
        "    Each logical batch of ``dataloader`` is split by Opacus' BatchMemoryManager into\n",
        "    physical micro-batches of at most ``micro_batch_size`` examples. Gradients are\n",
        "    clipped and noised through ``optimizer.pre_step()``, but ``optimizer.step()`` is\n",
        "    never called, so the model weights are not changed.\n",
        "\n",
        "    When ``groupname == 'ethnicity'`` every logical batch logs to wandb (keys suffixed\n",
        "    with ``group``): mean per-example loss, accuracy, macro precision/recall/F1, the\n",
        "    gradient norm before clipping, the norm of the clipped sum ``p.summed_grad``, the\n",
        "    norm of the added noise, their ratios and the cosine similarity / angle between\n",
        "    the un-clipped and clipped gradients; averages over logical batches are printed\n",
        "    at the end. For any other ``groupname`` nothing is logged or printed.\n",
        "\n",
        "    Args:\n",
        "        model: Opacus GradSampleModule (trainable parameters expose ``grad_sample``).\n",
        "        dataloader: DataLoader of the group's examples; its batch size is the logical\n",
        "            batch size.\n",
        "        optimizer: Opacus DPOptimizer returned by ``make_private``.\n",
        "        group: Group label used in wandb keys and printed messages, e.g. ``'White'``.\n",
        "        groupname: Grouping attribute; only ``'ethnicity'`` is logged and printed.\n",
        "        micro_batch_size: Maximum number of examples per forward/backward pass.\n",
        "    \"\"\"\n",
        "    # Reuse an active wandb run so that both groups log into the same run.\n",
        "    if wandb.run is None:\n",
        "        wandb.init(project=\"gradient_analysis_per_gender\", name=\"white and non-white\")\n",
        "\n",
        "    # Train mode, as during DP training: dropout stays active, also for the metrics.\n",
        "    model.train()\n",
        "    # Summed loss, consistent with loss_reduction=\"sum\" passed to make_private.\n",
        "    loss_fn = torch.nn.CrossEntropyLoss(reduction=\"sum\")\n",
        "    log_results = groupname == 'ethnicity'\n",
        "\n",
        "    # Accumulators for the logical batch in progress.\n",
        "    logical_predictions, logical_labels = [], []\n",
        "    logical_loss_sum, logical_num_samples = 0.0, 0\n",
        "    unclipped_grad_sums = {}  # parameter name -> summed per-sample gradients (pre-clipping)\n",
        "\n",
        "    # Statistics across logical batches.\n",
        "    cosine_similarities, angles = [], []\n",
        "    history = {\"accuracy\": [], \"f1\": [], \"grad_before\": [], \"grad_after\": [], \"loss\": []}\n",
        "\n",
        "    with BatchMemoryManager(\n",
        "            data_loader=dataloader,\n",
        "            max_physical_batch_size=micro_batch_size,\n",
        "            optimizer=optimizer,\n",
        "    ) as memory_safe_data_loader:\n",
        "        # Iterate the wrapped loader: it yields the micro-batches and tells the optimizer\n",
        "        # which pre_step() completes a logical batch. Iterating `dataloader` directly\n",
        "        # would push whole logical batches through the model at once.\n",
        "        for batch in memory_safe_data_loader:\n",
        "            optimizer.zero_grad()\n",
        "            batch = {k: v.to('cuda') for k, v in batch.items()}\n",
        "            logits = model(**batch).logits\n",
        "            labels = batch['labels']\n",
        "            loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))\n",
        "\n",
        "            logical_loss_sum += loss.item()\n",
        "            logical_num_samples += labels.numel()\n",
        "            logical_predictions.extend(logits.detach().cpu().numpy())\n",
        "            logical_labels.extend(labels.detach().cpu().numpy())\n",
        "\n",
        "            loss.backward()\n",
        "\n",
        "            # Accumulate the un-clipped gradient: the sum of per-sample gradients.\n",
        "            with torch.no_grad():\n",
        "                for name, p in model.named_parameters():\n",
        "                    if p.requires_grad:\n",
        "                        micro_batch_sum = p.grad_sample.sum(dim=0)\n",
        "                        previous_sum = unclipped_grad_sums.get(name)\n",
        "                        unclipped_grad_sums[name] = (\n",
        "                            micro_batch_sum if previous_sum is None else previous_sum + micro_batch_sum\n",
        "                        )\n",
        "\n",
        "            # Clips and accumulates into p.summed_grad; returns True only once the\n",
        "            # logical batch is complete and noise has been added to p.grad.\n",
        "            if optimizer.pre_step():\n",
        "                (norm_before, norm_clipped, norm_noise,\n",
        "                 cosine_similarity, angle) = _gradient_statistics(model, unclipped_grad_sums)\n",
        "                cosine_similarities.append(cosine_similarity)\n",
        "                angles.append(angle)\n",
        "\n",
        "                eval_metrics = compute_metrics(np.array(logical_predictions), np.array(logical_labels))\n",
        "                # Mean cross-entropy per example of this logical batch; unlike an average\n",
        "                # of per-micro-batch summed losses it does not depend on micro_batch_size.\n",
        "                batch_loss = logical_loss_sum / logical_num_samples\n",
        "                # Running means over all logical batches processed so far.\n",
        "                avg_cosine_similarity = sum(cosine_similarities) / len(cosine_similarities)\n",
        "                avg_angle = sum(angles) / len(angles)\n",
        "\n",
        "                if log_results:\n",
        "                    # Keys are kept verbatim for comparability with earlier runs. The\n",
        "                    # 'after clipping with Noise' value is the norm of p.summed_grad, the\n",
        "                    # clipped sum to which Opacus adds the noise (in p.grad).\n",
        "                    wandb.log({\n",
        "                        f\"eval_loss_{group}\": batch_loss,\n",
        "                        f\"eval_accuracy_{group}\": eval_metrics['accuracy'],\n",
        "                        f\"eval_precision_{group}\": eval_metrics['precision'],\n",
        "                        f\"eval_recall_{group}\": eval_metrics['recall'],\n",
        "                        f\"eval_f1_{group}\": eval_metrics['f1'],\n",
        "                        f\"logical batch grad_norm before clipping: {group}\": norm_before,\n",
        "                        f\"logical batch grad norm after clipping with Noise: {group}\": norm_clipped,\n",
        "                        f\"overal noise: grad - summed_gradg: {group}\": norm_noise,\n",
        "                        f\"summed_grad/noise grad: {group}\": _safe_ratio(norm_clipped, norm_noise),\n",
        "                        f\"norm after clipping / norm before clipping: {group} \": _safe_ratio(norm_clipped, norm_before),\n",
        "                        f\"logical_batch_cosine_similarity: {group}\": avg_cosine_similarity,\n",
        "                        f\"logical_batch_angle_degrees: {group}\": avg_angle,\n",
        "                        f\"Cosine Similarity_{group}\": cosine_similarity,\n",
        "                        f\"Angle in Degrees_{group}\": angle\n",
        "                    })\n",
        "                    history[\"accuracy\"].append(eval_metrics['accuracy'])\n",
        "                    history[\"f1\"].append(eval_metrics['f1'])\n",
        "                    history[\"grad_before\"].append(norm_before)\n",
        "                    history[\"grad_after\"].append(norm_clipped)\n",
        "                    history[\"loss\"].append(batch_loss)\n",
        "\n",
        "                # Start a fresh logical batch.\n",
        "                logical_predictions, logical_labels = [], []\n",
        "                logical_loss_sum, logical_num_samples = 0.0, 0\n",
        "                unclipped_grad_sums = {}\n",
        "\n",
        "            optimizer.zero_grad()\n",
        "\n",
        "    if not log_results:\n",
        "        return\n",
        "    if not history[\"accuracy\"]:\n",
        "        print(f\"No logical batches for {group}; nothing to average.\")\n",
        "        return\n",
        "\n",
        "    def _mean(values):\n",
        "        return sum(values) / len(values)\n",
        "\n",
        "    print(f\"Average Gradient Norm Before Clipping {group}: {_mean(history['grad_before'])}\")\n",
        "    print(f\"Average Gradient Norm After Clipping {group}: {_mean(history['grad_after'])}\")\n",
        "    print(f\"Average Accuracy for {group}: {_mean(history['accuracy'])}\")\n",
        "    print(f\"Average f1score for {group}: {_mean(history['f1'])}\")\n",
        "    print(f\"Average loss for {group}: {_mean(history['loss'])}\")\n",
        "\n",
        "\n",
        "def main(*, seed=SEED, eval_batch_size=500, micro_batch_size=40):\n",
        "    \"\"\"Compare DP gradient statistics of a trained checkpoint on White vs non-White test data.\n",
        "\n",
        "    Loads the CSV splits, restores the checkpoint into the Opacus setup used for\n",
        "    training, splits the test split by ``ethnicity_group`` and runs\n",
        "    ``compute_gradient_norms`` on each part, logging into a single wandb run.\n",
        "\n",
        "    Args:\n",
        "        seed: Global random seed; should equal the seed used for training.\n",
        "        eval_batch_size: Logical batch size of the per-group test loaders.\n",
        "        micro_batch_size: Physical batch size used by Opacus' BatchMemoryManager.\n",
        "    \"\"\"\n",
        "    seed_everything(seed)\n",
        "\n",
        "    data_files = {\n",
        "        \"train\": f\"{DATA_DIR}/{TRAIN_FILE}\",\n",
        "        \"validation\": f\"{DATA_DIR}/{VAL_FILE}\",\n",
        "        \"test\": f\"{DATA_DIR}/{TEST_FILE}\",\n",
        "    }\n",
        "    dataset = load_dataset('csv', data_files=data_files, cache_dir=CACHE_DIR)\n",
        "\n",
        "    model, tokenizer, optimizer, _start_epoch, _best_val_f1 = load_checkpoint(\n",
        "        CHECKPOINT_PATH, MODEL_NAME_OR_PATH, TOKENIZER_PATH, dataset, seed=seed,\n",
        "    )\n",
        "\n",
        "    white_dataset, nonwhite_dataset = filter_ethnicity(dataset['test'], tokenizer)\n",
        "    white_dataloader = DataLoader(white_dataset, batch_size=eval_batch_size)\n",
        "    nonwhite_dataloader = DataLoader(nonwhite_dataset, batch_size=eval_batch_size)\n",
        "\n",
        "    try:\n",
        "        print(\"\\nComputing gradient norms for White dataset:\")\n",
        "        compute_gradient_norms(model, white_dataloader, optimizer, 'White',\n",
        "                               groupname='ethnicity', micro_batch_size=micro_batch_size)\n",
        "\n",
        "        print(\"\\nComputing gradient norms for non-White dataset:\")\n",
        "        compute_gradient_norms(model, nonwhite_dataloader, optimizer, 'nonWhite',\n",
        "                               groupname='ethnicity', micro_batch_size=micro_batch_size)\n",
        "    finally:\n",
        "        # Close the run so that re-executing the cell starts a fresh one.\n",
        "        if wandb.run is not None:\n",
        "            wandb.finish()\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "if __name__ == '__main__':\n",
        "    import sys\n",
        "\n",
        "    # fire.Fire(main) parses sys.argv. Inside a Jupyter/Colab kernel sys.argv holds the\n",
        "    # kernel's own launch arguments (e.g. -f <connection file>), which fire cannot\n",
        "    # map onto main() and reports as an error, so call main() directly there.\n",
        "    if 'ipykernel' in sys.modules:\n",
        "        main()\n",
        "    else:\n",
        "        fire.Fire(main)"
      ]
    }
  ]
}