{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [],
      "metadata": {
        "id": "zOhZtIBJn0JB"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Checkpoint under evaluation and the artefacts needed to rebuild it\n",
        "CHECKPOINT_PATH = \"Checkpoints/NormalFair/MimicLOS/NEWSEED2025NormalFair/checkpoint_epoch_200.pt\"\n",
        "MODEL_NAME_OR_PATH = \"medicalai/ClinicalBERT\"\n",
        "# Directory the tokenizer was saved to.\n",
        "# NOTE(review): this path is absolute while CHECKPOINT_PATH is relative - confirm that is intended.\n",
        "TOKENIZER_PATH = \"/Checkpoints/NormalFair/MimicLOS/NEWSEED2025NormalFair/tokenizer_epoch_200\"\n",
        "\n",
        "# Cache directory handed to load_dataset, and the folder holding the LOS_WEEKS_{train,val,test}.csv files\n",
        "CHACHE_DIR = \"/cache\"\n",
        "DATA_DIR = \"/datasets/unpackedmimiciii/mimic3MultiGender\"\n",
        "\n",
        "# Choose the seed that matches the seed the checkpoint was trained with.\n",
        "# SEED_LIST=(42 0 1234 2023 777 999 1337)\n",
        "SEED = 2025"
      ],
      "metadata": {
        "id": "vV2hhEzsn72p"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FEUZJAvEmCiq"
      },
      "outputs": [],
      "source": [
        "from opacus.optimizers import DPOptimizer\n",
        "from opacus.optimizers.optimizer import (\n",
        "    _check_processed_flag,\n",
        "    _generate_noise,\n",
        "    _mark_as_processed\n",
        ")\n",
        "\n",
        "class NormalFairOptimizer(DPOptimizer):\n",
        "    \"\"\"DPOptimizer whose per-sample scaling factor is tanh(max_grad_norm / norm).\n",
        "\n",
        "    Construction is inherited unchanged from DPOptimizer; only\n",
        "    ``clip_and_accumulate`` is overridden.\n",
        "    \"\"\"\n",
        "\n",
        "    def clip_and_accumulate(self):\n",
        "        \"\"\"\n",
        "        Scale every per-sample gradient and add the result to ``p.summed_grad``.\n",
        "\n",
        "        For each sample the L2 norm over all parameters is computed and the\n",
        "        sample's gradients are multiplied by\n",
        "        ``tanh(max_grad_norm / (norm + 1e-6))``. The scaled gradients are then\n",
        "        summed over the sample dimension and accumulated per parameter.\n",
        "        \"\"\"\n",
        "        first_grad_sample = self.grad_samples[0]\n",
        "\n",
        "        if len(first_grad_sample) != 0:\n",
        "            # One norm per (sample, parameter), then combined into one norm per sample\n",
        "            norms_by_param = [\n",
        "                gs.reshape(len(gs), -1).norm(2, dim=-1) for gs in self.grad_samples\n",
        "            ]\n",
        "            sample_norms = torch.stack(norms_by_param, dim=1).norm(2, dim=1)\n",
        "            # tanh replaces the usual hard clamp; this is the only change to the clipping rule\n",
        "            scale_per_sample = torch.tanh(self.max_grad_norm / (sample_norms + 1e-6))\n",
        "        else:\n",
        "            # Empty batch: nothing to scale\n",
        "            scale_per_sample = torch.zeros((0,), device=first_grad_sample.device)\n",
        "\n",
        "        for param in self.params:\n",
        "            _check_processed_flag(param.grad_sample)\n",
        "            flat_grad_sample = self._get_flat_grad_sample(param)\n",
        "            # Weighted sum over the sample dimension\n",
        "            scaled_sum = torch.einsum(\"i,i...\", scale_per_sample, flat_grad_sample)\n",
        "\n",
        "            if param.summed_grad is None:\n",
        "                param.summed_grad = scaled_sum\n",
        "            else:\n",
        "                param.summed_grad += scaled_sum\n",
        "\n",
        "            _mark_as_processed(param.grad_sample)\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from opacus import PrivacyEngine\n",
        "import torch.optim as optim\n",
        "from typing import List, Union\n",
        "from opacus.optimizers import DPOptimizer\n",
        "\n",
        "class NormalfairPrivacyEngine(PrivacyEngine):\n",
        "    \"\"\"PrivacyEngine that wraps the optimizer in a NormalFairOptimizer.\n",
        "\n",
        "    Construction is inherited unchanged from PrivacyEngine.\n",
        "    \"\"\"\n",
        "\n",
        "    def _prepare_optimizer(\n",
        "        self,\n",
        "        *,\n",
        "        optimizer: optim.Optimizer,\n",
        "        noise_multiplier: float,\n",
        "        max_grad_norm: Union[float, List[float]],\n",
        "        expected_batch_size: int,\n",
        "        loss_reduction: str = \"mean\",\n",
        "        distributed: bool = False,\n",
        "        clipping: str = \"flat\",\n",
        "        noise_generator=None,\n",
        "        grad_sample_mode=\"hooks\",\n",
        "        **kwargs,\n",
        "    ) -> DPOptimizer:\n",
        "        \"\"\"Build the NormalFairOptimizer used for private training.\n",
        "\n",
        "        ``distributed``, ``clipping`` and ``grad_sample_mode`` are accepted for\n",
        "        signature compatibility but are not used here: a NormalFairOptimizer\n",
        "        is always returned. Any extra ``kwargs`` are forwarded to it.\n",
        "        \"\"\"\n",
        "        # Unwrap an optimizer that is already a DPOptimizer so it is not wrapped twice\n",
        "        base_optimizer = (\n",
        "            optimizer.original_optimizer\n",
        "            if isinstance(optimizer, DPOptimizer)\n",
        "            else optimizer\n",
        "        )\n",
        "\n",
        "        # Secure mode takes precedence; otherwise use the caller's generator (may be None)\n",
        "        rng = self.secure_rng if self.secure_mode else noise_generator\n",
        "\n",
        "        return NormalFairOptimizer(\n",
        "            optimizer=base_optimizer,\n",
        "            noise_multiplier=noise_multiplier,\n",
        "            max_grad_norm=max_grad_norm,\n",
        "            expected_batch_size=expected_batch_size,\n",
        "            loss_reduction=loss_reduction,\n",
        "            generator=rng,\n",
        "            secure_mode=self.secure_mode,\n",
        "            **kwargs,\n",
        "        )\n",
        ""
      ],
      "metadata": {
        "id": "6U04VBq_mDu-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "**Testing phase ethnicity**"
      ],
      "metadata": {
        "id": "wf-zJlMhmYgu"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#White and non White data loader ethnicity fair Adaptive privacy (New approach)\n",
        "from torch.utils.data.dataset import Dataset\n",
        "from torch.utils.data import DataLoader\n",
        "from datasets import load_dataset\n",
        "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n",
        "import torch\n",
        "import wandb\n",
        "import numpy as np\n",
        "from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score\n",
        "import fire\n",
        "import opacus\n",
        "from opacus.utils.batch_memory_manager import BatchMemoryManager\n",
        "from opacus.accountants.utils import get_noise_multiplier\n",
        "import torch.nn.functional as F\n",
        "import os\n",
        "from opacus import PrivacyEngine\n",
        "import torch.optim as optim\n",
        "from typing import List, Union\n",
        "from opacus.optimizers import DPOptimizer\n",
        "from opacus.optimizers import AdaClipDPOptimizer, DPOptimizer\n",
        "from opacus.optimizers.optimizer import (\n",
        "    _check_processed_flag,\n",
        "    _generate_noise,\n",
        "    _mark_as_processed\n",
        ")\n",
        "\n",
        "def main():\n",
        "\n",
        "    def seed_everything(seed):\n",
        "        \"\"\"Seed NumPy and PyTorch (CPU and CUDA), set PYTHONHASHSEED and force deterministic cuDNN.\"\"\"\n",
        "        os.environ['PYTHONHASHSEED'] = str(seed)\n",
        "\n",
        "        # Apply the same seed to every random number generator in use\n",
        "        for apply_seed in (\n",
        "            np.random.seed,\n",
        "            torch.manual_seed,\n",
        "            torch.cuda.manual_seed,\n",
        "            torch.cuda.manual_seed_all,\n",
        "        ):\n",
        "            apply_seed(seed)\n",
        "\n",
        "        # Deterministic cuDNN kernels; benchmark mode is switched off for full determinism\n",
        "        cudnn = torch.backends.cudnn\n",
        "        cudnn.deterministic = True\n",
        "        cudnn.benchmark = False\n",
        "\n",
        "        print(f\"Global seed set to: {seed}\")\n",
        "\n",
        "    # `seed` is read here and inside load_checkpoint (generator and worker seeding),\n",
        "    # so bind it from the SEED constant of the configuration cell before first use.\n",
        "    seed = SEED\n",
        "    seed_everything(seed)\n",
        "\n",
        "    # Load the train / validation / test CSV splits\n",
        "    cache_dir = CHACHE_DIR\n",
        "    data_dir = DATA_DIR\n",
        "    data_files = {\n",
        "        \"train\": f\"{data_dir}/LOS_WEEKS_train.csv\",\n",
        "        \"validation\": f\"{data_dir}/LOS_WEEKS_val.csv\",\n",
        "        \"test\": f\"{data_dir}/LOS_WEEKS_test.csv\",\n",
        "    }\n",
        "    dataset = load_dataset('csv', data_files=data_files, cache_dir=cache_dir)\n",
        "\n",
        "    class TextDataset(Dataset):\n",
        "        \"\"\"Map-style dataset pairing tokenizer encodings with their labels.\"\"\"\n",
        "\n",
        "        def __init__(self, encodings, labels):\n",
        "            self.encodings = encodings\n",
        "            self.labels = labels\n",
        "\n",
        "        def __len__(self):\n",
        "            # One item per label\n",
        "            return len(self.labels)\n",
        "\n",
        "        def __getitem__(self, idx):\n",
        "            # Every encoding field at position idx, plus the label under 'labels'\n",
        "            sample = {}\n",
        "            for field, values in self.encodings.items():\n",
        "                sample[field] = torch.tensor(values[idx])\n",
        "            sample['labels'] = torch.tensor(self.labels[idx])\n",
        "            return sample\n",
        "\n",
        "    def load_checkpoint(checkpoint_path, model_name_or_path,tokenizer_path,dataset):\n",
        "        \"\"\"Rebuild the DP-wrapped model and optimizer, then restore their saved state.\n",
        "\n",
        "        The model is created from ``model_name_or_path`` with 4 labels, partially\n",
        "        frozen, moved to CUDA and made private together with an AdamW optimizer and\n",
        "        a loader over ``dataset['train']``. The checkpoint's model and optimizer\n",
        "        state dicts are loaded afterwards.\n",
        "\n",
        "        Returns:\n",
        "            (model, tokenizer, optimizer, start_epoch, best_val_f1)\n",
        "        \"\"\"\n",
        "        # weights_only=False unpickles arbitrary objects - only load checkpoints you trust\n",
        "        saved_state = torch.load(checkpoint_path,weights_only=False)\n",
        "        model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, num_labels=4)\n",
        "        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)\n",
        "\n",
        "        # A parameter stays trainable only if its name contains one of these fragments\n",
        "        trainable_fragments = (\n",
        "            \"classifier\",\n",
        "            \"layer.5.output_layer_norm\",\n",
        "            \"layer.5.ffn.lin2\",\n",
        "            \"layer.5.ffn.lin1\",\n",
        "            \"layer.5.sa_layer_norm\",\n",
        "            \"layer.5.attention.out_lin\",\n",
        "            \"layer.5.attention.v_lin\",\n",
        "            \"layer.5.attention.k_lin\",\n",
        "        )\n",
        "        for param_name, param in model.named_parameters():\n",
        "            param.requires_grad = any(fragment in param_name for fragment in trainable_fragments)\n",
        "\n",
        "        model.to(\"cuda\")\n",
        "        start_epoch = saved_state['epoch']\n",
        "        best_val_f1 = saved_state['best_val_f1']\n",
        "\n",
        "        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)\n",
        "\n",
        "        def process_dataset(dataset_split):\n",
        "            # Tokenize the text column and pair it with integer labels\n",
        "            columns = dataset_split.select_columns([\"text\", \"los_label\"])\n",
        "            encodings = tokenizer(\n",
        "                columns['text'],\n",
        "                truncation=True,\n",
        "                padding='max_length',\n",
        "                max_length=512,\n",
        "                return_tensors='pt'\n",
        "            )\n",
        "            return TextDataset(encodings, [int(raw) for raw in columns['los_label']])\n",
        "\n",
        "        # Seeded generator so the shuffling order is reproducible\n",
        "        shuffle_generator = torch.Generator()\n",
        "        shuffle_generator.manual_seed(seed)\n",
        "\n",
        "        train_dataset = process_dataset(dataset['train'])\n",
        "\n",
        "        train_dataloader = DataLoader(\n",
        "            train_dataset,\n",
        "            batch_size=500,\n",
        "            shuffle=True,\n",
        "            drop_last=True,\n",
        "            worker_init_fn=lambda worker_id: torch.manual_seed(seed + worker_id),  # per-worker seeding\n",
        "            generator=shuffle_generator\n",
        "        )\n",
        "\n",
        "        model.train()\n",
        "\n",
        "        # Opacus setup: noise multiplier for the target (epsilon, delta) over 200 epochs\n",
        "        sigma = get_noise_multiplier(\n",
        "            target_epsilon=8,\n",
        "            target_delta=1e-5,\n",
        "            sample_rate=500 / len(dataset['train']),\n",
        "            epochs=200\n",
        "        )\n",
        "        print(sigma)\n",
        "\n",
        "        privacy_engine = NormalfairPrivacyEngine()\n",
        "\n",
        "        # The returned private data loader is not needed by the caller\n",
        "        model, optimizer, _private_loader = privacy_engine.make_private(\n",
        "            module=model,\n",
        "            optimizer=optimizer,\n",
        "            data_loader=train_dataloader,\n",
        "            noise_multiplier=sigma,\n",
        "            max_grad_norm=0.1,\n",
        "            poisson_sampling=False,  # the loader uses drop_last=True\n",
        "            loss_reduction=\"sum\"\n",
        "        )\n",
        "\n",
        "        print(f\"Loaded checkpoint from {checkpoint_path}, starting from epoch {start_epoch+1}\")\n",
        "\n",
        "        # strict=False allows missing keys\n",
        "        model.load_state_dict(saved_state[\"model_state_dict\"], strict=False)\n",
        "        optimizer.load_state_dict(saved_state['optimizer_state_dict'])\n",
        "\n",
        "        # Report the learning rate restored from the checkpoint (last param group wins)\n",
        "        for group_options in optimizer.param_groups:\n",
        "            lr = group_options[\"lr\"]\n",
        "\n",
        "        print(f\"Loaded learning rate from checkpoint: {lr}\")\n",
        "\n",
        "        return model, tokenizer, optimizer , start_epoch, best_val_f1\n",
        "\n",
        "    # Restore model, tokenizer and optimizer from the paths set in the configuration cell\n",
        "    checkpoint_path = CHECKPOINT_PATH\n",
        "    model_name_or_path = MODEL_NAME_OR_PATH\n",
        "    tokenizer_path = TOKENIZER_PATH  # where the tokenizer was saved\n",
        "\n",
        "    model, tokenizer, optimizer, start_epoch, best_val_f1 = load_checkpoint(\n",
        "        checkpoint_path=checkpoint_path,\n",
        "        model_name_or_path=model_name_or_path,\n",
        "        tokenizer_path=tokenizer_path,\n",
        "        dataset=dataset,\n",
        "    )\n",
        "\n",
        "\n",
        "    def filter_ethnicity(dataset_split):\n",
        "        \"\"\"Split a dataset by ethnicity group and tokenize each part.\n",
        "\n",
        "        Rows whose ``ethnicity_group`` equals 'White' form the first subset; all\n",
        "        other rows form the second.\n",
        "\n",
        "        Returns:\n",
        "            (white TextDataset, non-white TextDataset), each holding the tokenized\n",
        "            ``text`` column and the integer ``los_label`` values as a tensor.\n",
        "        \"\"\"\n",
        "\n",
        "        def encode_subset(subset):\n",
        "            # Same tokenizer settings for both groups\n",
        "            encodings = tokenizer(\n",
        "                subset['text'],\n",
        "                truncation=True,\n",
        "                padding='max_length',\n",
        "                max_length=512,\n",
        "                return_tensors='pt'\n",
        "            )\n",
        "            labels = torch.tensor([int(label) for label in subset['los_label']])\n",
        "            return TextDataset(encodings, labels)\n",
        "\n",
        "        white_dataset = dataset_split.filter(lambda example: example['ethnicity_group'] == 'White')\n",
        "        nonwhite_dataset = dataset_split.filter(lambda example: example['ethnicity_group'] != 'White')\n",
        "\n",
        "        return encode_subset(white_dataset), encode_subset(nonwhite_dataset)\n",
        "\n",
        "    # Split the test set by ethnicity group and build one loader per group\n",
        "    White_dataset, nonWhite_dataset = filter_ethnicity(dataset['test'])\n",
        "\n",
        "    White_dataloader, nonWhite_dataloader = (\n",
        "        DataLoader(group_dataset, batch_size=500)\n",
        "        for group_dataset in (White_dataset, nonWhite_dataset)\n",
        "    )\n",
        "\n",
        "\n",
        "    def compute_metrics(predictions, labels):\n",
        "        \"\"\"Return accuracy and macro-averaged precision, recall and F1.\n",
        "\n",
        "        ``predictions`` holds per-class scores; the predicted class is the argmax\n",
        "        over the last axis.\n",
        "        \"\"\"\n",
        "        predicted_classes = np.argmax(predictions, axis=-1)\n",
        "\n",
        "        metrics = {\"accuracy\": accuracy_score(labels, predicted_classes)}\n",
        "        macro_scorers = (\n",
        "            (\"precision\", precision_score),\n",
        "            (\"recall\", recall_score),\n",
        "            (\"f1\", f1_score),\n",
        "        )\n",
        "        for metric_name, scorer in macro_scorers:\n",
        "            metrics[metric_name] = scorer(labels, predicted_classes, average='macro')\n",
        "        return metrics\n",
        "\n",
        "    def compute_gradient_norms(model, dataloader, optimizer, group, groupname):\n",
        "        \"\"\"Log gradient and classification statistics for one group of samples.\n",
        "\n",
        "        Each logical batch of ``dataloader`` is split into physical batches of at\n",
        "        most 40 samples. Summed per-sample gradients are accumulated per layer\n",
        "        across the physical batches. When ``optimizer.pre_step()`` reports the\n",
        "        final step of a logical batch, the following are computed: the norm of\n",
        "        the accumulated raw gradient, the norm of ``p.summed_grad``, the norm of\n",
        "        ``p.grad - p.summed_grad``, the cosine similarity / angle between the\n",
        "        raw gradient and ``p.summed_grad``, and the classification metrics.\n",
        "\n",
        "        Args:\n",
        "            model: module whose trainable parameters expose ``grad_sample`` after backward.\n",
        "            dataloader: loader yielding dict batches that include a 'labels' entry.\n",
        "            optimizer: optimizer providing ``pre_step`` and ``zero_grad``.\n",
        "            group: group label used in metric names ('White' or 'nonWhite' are summarized).\n",
        "            groupname: metrics are logged and summarized only when this is 'ethnicity'.\n",
        "\n",
        "        Returns:\n",
        "            None. Results are sent to wandb and printed.\n",
        "        \"\"\"\n",
        "\n",
        "        wandb.init(project=\"gradient_analysis_per_gender\", name=\"white and non-white\")\n",
        "\n",
        "        model.train()  # Enable gradient tracking\n",
        "        eval_loss = 0\n",
        "        total_micro_batches = 0\n",
        "        micro_batch_size = 40\n",
        "        all_logical_labels = []\n",
        "        all_logical_predictions = []\n",
        "        layer_wise_grad_samples = {}\n",
        "        cosine_similarities = []\n",
        "        angles = []\n",
        "\n",
        "        # One entry per completed logical batch; only filled for 'White' / 'nonWhite'\n",
        "        accuracy_history = []\n",
        "        f1_history = []\n",
        "        grad_before_history = []\n",
        "        grad_after_history = []\n",
        "        loss_history = []\n",
        "\n",
        "        summarized_groups = ('White', 'nonWhite')\n",
        "        loss_fn = torch.nn.CrossEntropyLoss(reduction=\"sum\")\n",
        "\n",
        "        with BatchMemoryManager(\n",
        "                data_loader=dataloader,\n",
        "                max_physical_batch_size=micro_batch_size,\n",
        "                optimizer=optimizer\n",
        "            ) as memory_safe_data_loader:\n",
        "\n",
        "            # Iterate the memory-safe loader, not the raw one: only the wrapped loader\n",
        "            # applies the physical batch size limit configured above.\n",
        "            for batch in memory_safe_data_loader:\n",
        "\n",
        "                total_micro_batches += 1\n",
        "\n",
        "                optimizer.zero_grad()\n",
        "                batch = {k: v.to('cuda') for k, v in batch.items()}\n",
        "                outputs = model(**batch)\n",
        "                logits = outputs.logits\n",
        "                labels = batch['labels']\n",
        "                loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))\n",
        "                eval_loss += loss.item()\n",
        "\n",
        "                # Keep predictions and labels of the whole logical batch for the metrics\n",
        "                all_logical_predictions.extend(logits.detach().cpu().numpy())\n",
        "                all_logical_labels.extend(labels.detach().cpu().numpy())\n",
        "\n",
        "                # Backward pass\n",
        "                loss.backward()\n",
        "\n",
        "                # Sum the per-sample gradients of this physical batch and accumulate per layer\n",
        "                for name, p in model.named_parameters():\n",
        "                    if p.requires_grad:\n",
        "                        with torch.no_grad():\n",
        "                            grad_sample = p.grad_sample.sum(dim=0).clone().detach()\n",
        "                            if name not in layer_wise_grad_samples:\n",
        "                                layer_wise_grad_samples[name] = torch.zeros_like(grad_sample)\n",
        "                            layer_wise_grad_samples[name] += grad_sample  # Accumulate across physical batches\n",
        "\n",
        "                is_final_step = optimizer.pre_step()\n",
        "\n",
        "                if is_final_step:\n",
        "                    # Raw (pre-clipping) gradient of the logical batch, layer by layer\n",
        "                    all_layer_grads_before = list(layer_wise_grad_samples.values())\n",
        "                    total_grad_norm_squared = 0.0\n",
        "                    for layer_grad in all_layer_grads_before:\n",
        "                        total_grad_norm_squared += torch.norm(layer_grad, p=2) ** 2\n",
        "\n",
        "                    total_grad_norm_squared_summed_noisy = 0.0\n",
        "                    total_noise_grad = 0.0\n",
        "                    all_layer_grads_after = []\n",
        "\n",
        "                    for name, p in model.named_parameters():\n",
        "                        if p.requires_grad:\n",
        "                            with torch.no_grad():\n",
        "                                all_layer_grads_after.append(p.summed_grad.clone().detach())\n",
        "\n",
        "                                layer_grad_norm_summed_noisy = torch.norm(p.summed_grad, p=2).item()\n",
        "                                total_grad_norm_squared_summed_noisy += layer_grad_norm_summed_noisy ** 2\n",
        "\n",
        "                                # Norm of the difference (grad - summed_grad), as the metric name\n",
        "                                # states - not the difference of the two norms.\n",
        "                                total_noise_grad += torch.norm(p.grad - p.summed_grad, p=2).item() ** 2\n",
        "\n",
        "                    # Flatten all layers into one vector each\n",
        "                    total_grad_before = torch.cat([v.view(-1) for v in all_layer_grads_before])\n",
        "                    total_grad_after = torch.cat([v.view(-1) for v in all_layer_grads_after])\n",
        "\n",
        "                    # Direction change between the raw gradient and summed_grad\n",
        "                    cosine_similarity = torch.nn.functional.cosine_similarity(\n",
        "                        total_grad_before.view(1, -1),\n",
        "                        total_grad_after.view(1, -1),\n",
        "                        dim=1\n",
        "                    ).item()\n",
        "                    angle = torch.acos(torch.clamp(torch.tensor(cosine_similarity), -1.0, 1.0)).item() * (180 / torch.pi)\n",
        "\n",
        "                    cosine_similarities.append(cosine_similarity)\n",
        "                    angles.append(angle)\n",
        "\n",
        "                    # Overall norms across all layers\n",
        "                    overall_grad_norm = total_grad_norm_squared ** 0.5\n",
        "                    overall_grad_norm_summed_noisy = total_grad_norm_squared_summed_noisy ** 0.5\n",
        "                    overal_noise_grad = total_noise_grad ** 0.5\n",
        "\n",
        "                    # Metrics over every physical batch of this logical batch\n",
        "                    eval_metrics = compute_metrics(np.array(all_logical_predictions), np.array(all_logical_labels))\n",
        "                    # eval_loss is never reset, so this is a running average over all physical batches so far\n",
        "                    avg_eval_loss = eval_loss / total_micro_batches\n",
        "\n",
        "                    # Running averages over the logical batches seen so far\n",
        "                    avg_cosine_similarity = sum(cosine_similarities) / len(cosine_similarities)\n",
        "                    avg_angle = sum(angles) / len(angles)\n",
        "\n",
        "                    if groupname=='ethnicity':\n",
        "                        wandb.log({\n",
        "                            f\"eval_loss_{group}\": avg_eval_loss,\n",
        "                            f\"eval_accuracy_{group}\": eval_metrics['accuracy'],\n",
        "                            f\"eval_precision_{group}\": eval_metrics['precision'],\n",
        "                            f\"eval_recall_{group}\": eval_metrics['recall'],\n",
        "                            f\"eval_f1_{group}\": eval_metrics['f1'],\n",
        "                            f\"logical batch grad_norm before clipping: {group}\": overall_grad_norm,\n",
        "                            f\"logical batch grad norm after clipping with Noise: {group}\": overall_grad_norm_summed_noisy,\n",
        "                            f\"overal noise: grad - summed_gradg: {group}\": overal_noise_grad,\n",
        "                            f\"summed_grad/noise grad: {group}\":(overall_grad_norm_summed_noisy/overal_noise_grad),\n",
        "                            f\"norm after clipping / norm before clipping: {group} \" :overall_grad_norm_summed_noisy/overall_grad_norm,\n",
        "                            f\"logical_batch_cosine_similarity: {group}\": avg_cosine_similarity,\n",
        "                            f\"logical_batch_angle_degrees: {group}\": avg_angle,\n",
        "                            f\"Cosine Similarity_{group}\": cosine_similarity,\n",
        "                            f\"Angle in Degrees_{group}\": angle\n",
        "                        })\n",
        "\n",
        "                        if group in summarized_groups:\n",
        "                            accuracy_history.append(eval_metrics['accuracy'])\n",
        "                            f1_history.append(eval_metrics['f1'])\n",
        "                            grad_before_history.append(overall_grad_norm)\n",
        "                            grad_after_history.append(overall_grad_norm_summed_noisy)\n",
        "                            loss_history.append(avg_eval_loss)\n",
        "\n",
        "                    # Start the next logical batch from a clean slate\n",
        "                    all_logical_labels = []\n",
        "                    all_logical_predictions = []\n",
        "                    layer_wise_grad_samples = {}\n",
        "\n",
        "                optimizer.zero_grad()\n",
        "\n",
        "        if groupname=='ethnicity' and group in summarized_groups:\n",
        "            if not grad_before_history:\n",
        "                # No logical batch finished (e.g. empty loader): nothing to average\n",
        "                print(f\"No completed logical batch for {group}; skipping averages\")\n",
        "                return\n",
        "\n",
        "            def mean(values):\n",
        "                return sum(values) / len(values)\n",
        "\n",
        "            print(f\"Average Gradient Norm Before Clipping {group}: {mean(grad_before_history)}\")\n",
        "            print(f\"Average Gradient Norm After Clipping {group}: {mean(grad_after_history)}\")\n",
        "            print(f\"Average Accuracy for {group}: {mean(accuracy_history)}\")\n",
        "            print(f\"Average f1score for {group}: {mean(f1_history)}\")\n",
        "            print(f\"Average loss for {group}: {mean(loss_history)}\")\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "    # Compute gradient norms for different groups\n",
        "\n",
        "    group_runs = (\n",
        "        (\"\\nComputing gradient norms for White dataset:\", White_dataloader, 'White'),\n",
        "        (\"\\nComputing gradient norms for non-White dataset:\", nonWhite_dataloader, 'nonWhite'),\n",
        "    )\n",
        "    # Same analysis for each group, White first\n",
        "    for heading, group_loader, group_label in group_runs:\n",
        "        print(heading)\n",
        "        compute_gradient_norms(model, group_loader, optimizer, group_label, groupname='ethnicity')\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "if __name__ == '__main__':\n",
        "    import fire\n",
        "    fire.Fire(main)"
      ],
      "metadata": {
        "id": "DIKChDYSmD6W"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "**Testing Gender**"
      ],
      "metadata": {
        "id": "3bg1QmBnmWTl"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#gender and multi group ethnicity Fair Adaptive DP (New approach)\n",
        "from torch.utils.data.dataset import Dataset\n",
        "from torch.utils.data import DataLoader\n",
        "from datasets import load_dataset\n",
        "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n",
        "import torch\n",
        "import wandb\n",
        "import numpy as np\n",
        "from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score\n",
        "import fire\n",
        "import opacus\n",
        "from opacus.utils.batch_memory_manager import BatchMemoryManager\n",
        "from opacus.accountants.utils import get_noise_multiplier\n",
        "import torch.nn.functional as F\n",
        "import os\n",
        "from opacus.optimizers import AdaClipDPOptimizer, DPOptimizer\n",
        "from opacus.optimizers.optimizer import (\n",
        "    _check_processed_flag,\n",
        "    _generate_noise,\n",
        "    _mark_as_processed\n",
        ")\n",
        "from opacus import PrivacyEngine\n",
        "import torch.optim as optim\n",
        "from typing import List, Union\n",
        "from opacus.optimizers import DPOptimizer\n",
        "\n",
        "\n",
        "def main():\n",
        "\n",
        "    def seed_everything(seed):\n",
        "        \"\"\"Set the seed for reproducibility across all libraries, including GPU operations.\"\"\"\n",
        "        os.environ['PYTHONHASHSEED'] = str(seed)\n",
        "        np.random.seed(seed)\n",
        "        torch.manual_seed(seed)\n",
        "        torch.cuda.manual_seed(seed)\n",
        "        torch.cuda.manual_seed_all(seed)\n",
        "\n",
        "        # Ensure deterministic behavior in cuDNN\n",
        "        torch.backends.cudnn.deterministic = True\n",
        "        torch.backends.cudnn.benchmark = False  # Set to False for full determinism\n",
        "\n",
        "        print(f\"Global seed set to: {seed}\")\n",
        "\n",
        "    seed_everything(seed)\n",
        "\n",
        "    # Load dataset\n",
        "    cache_dir = CHACHE_DIR\n",
        "    data_dir = DATA_DIR\n",
        "    data_files = {\n",
        "        \"train\": f\"{data_dir}/LOS_WEEKS_train.csv\",\n",
        "        \"validation\": f\"{data_dir}/LOS_WEEKS_val.csv\",\n",
        "        \"test\": f\"{data_dir}/LOS_WEEKS_test.csv\",\n",
        "    }\n",
        "    dataset = load_dataset('csv', data_files=data_files, cache_dir=cache_dir)\n",
        "\n",
        "    class TextDataset(Dataset):\n",
        "        def __init__(self, encodings, labels):\n",
        "            self.encodings = encodings\n",
        "            self.labels = labels\n",
        "\n",
        "        def __getitem__(self, idx):\n",
        "            item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}\n",
        "            item['labels'] = torch.tensor(self.labels[idx])\n",
        "            return item\n",
        "\n",
        "        def __len__(self):\n",
        "            return len(self.labels)\n",
        "\n",
        "    def load_checkpoint(checkpoint_path, model_name_or_path,tokenizer_path,dataset):\n",
        "\n",
        "        checkpoint = torch.load(checkpoint_path,weights_only=False)\n",
        "        model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, num_labels=4)\n",
        "        # Remove `_module` created by opacus\n",
        "        #new_state_dict = {}\n",
        "        #for key, value in checkpoint[\"model_state_dict\"].items():\n",
        "        #    new_key = key.replace(\"_module.\", \"\")  # Remove _module prefix\n",
        "        #    new_state_dict[new_key] = value\n",
        "\n",
        "        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)\n",
        "\n",
        "        #Freeze layers\n",
        "        for name, parameter in model.named_parameters():\n",
        "            if \"classifier\" in name or \"layer.5.output_layer_norm\" in name or \"layer.5.ffn.lin2\" in name or \"layer.5.ffn.lin1\" in name or \"layer.5.sa_layer_norm\" in name or \"layer.5.attention.out_lin\" in name or \"layer.5.attention.v_lin\" in name or \"layer.5.attention.k_lin\" in name:\n",
        "                parameter.requires_grad = True\n",
        "            else:\n",
        "                parameter.requires_grad = False\n",
        "\n",
        "        model.to(\"cuda\")\n",
        "        start_epoch = checkpoint['epoch']\n",
        "        best_val_f1 = checkpoint['best_val_f1']\n",
        "\n",
        "        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)\n",
        "\n",
        "        def process_dataset(dataset_split):\n",
        "            dataset_split = dataset_split.select_columns([\"text\", \"los_label\"])\n",
        "\n",
        "            encodings = tokenizer(\n",
        "                dataset_split['text'],\n",
        "                truncation=True,\n",
        "                padding='max_length',\n",
        "                max_length=512,\n",
        "                return_tensors='pt'\n",
        "            )\n",
        "            labels = [int(label) for label in dataset_split['los_label']]\n",
        "            return TextDataset(encodings, labels)\n",
        "\n",
        "\n",
        "        train_dataset = process_dataset(dataset['train'])\n",
        "\n",
        "        g = torch.Generator()\n",
        "        g.manual_seed(seed)\n",
        "\n",
        "        train_dataloader = DataLoader(\n",
        "            train_dataset,\n",
        "            batch_size=500, #batch_size,\n",
        "            shuffle=True,\n",
        "            drop_last=True,\n",
        "            worker_init_fn=lambda worker_id: torch.manual_seed(seed + worker_id),  # Fixes worker shuffling\n",
        "            generator=g\n",
        "            )\n",
        "\n",
        "        model.train()\n",
        "\n",
        "        # Opacus setup\n",
        "        sigma = get_noise_multiplier(\n",
        "            target_epsilon=8,\n",
        "            target_delta=1e-5,\n",
        "            sample_rate=500 / len(dataset['train']),\n",
        "            epochs=200\n",
        "        )\n",
        "        print(sigma)\n",
        "\n",
        "        privacy_engine = NormalfairPrivacyEngine()\n",
        "\n",
        "        model, optimizer, train_loader = privacy_engine.make_private(\n",
        "            module=model,\n",
        "            optimizer=optimizer,\n",
        "            data_loader=train_dataloader,\n",
        "            noise_multiplier=sigma,\n",
        "            max_grad_norm=0.1,\n",
        "            poisson_sampling=False, # poisson_sampling=True, If I was using drop_last = False in dataloader\n",
        "            loss_reduction=\"sum\" #sum\n",
        "        )\n",
        "\n",
        "\n",
        "        print(f\"Loaded checkpoint from {checkpoint_path}, starting from epoch {start_epoch+1}\")\n",
        "\n",
        "        model.load_state_dict(checkpoint[\"model_state_dict\"], strict=False)  # strict=False allows missing keys\n",
        "\n",
        "        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
        "\n",
        "        # Extract LR from loaded optimizer\n",
        "        for param_group in optimizer.param_groups:\n",
        "            lr = param_group[\"lr\"]  # Extract the learning rate from the optimizer checkpoint\n",
        "\n",
        "        print(f\"Loaded learning rate from checkpoint: {lr}\")\n",
        "\n",
        "\n",
        "        return model, tokenizer, optimizer , start_epoch, best_val_f1\n",
        "\n",
        "    # Load a checkpoint, Define paths\n",
        "    checkpoint_path = CHECKPOINT_PATH\n",
        "    model_name_or_path = MODEL_NAME_OR_PATH\n",
        "    tokenizer_path = TOKENIZER_PATH\n",
        "\n",
        "    model, tokenizer, optimizer, start_epoch, best_val_f1 = load_checkpoint(checkpoint_path, model_name_or_path,tokenizer_path,dataset)\n",
        "\n",
        "\n",
        "    def filter_ethnicity(dataset_split, ethnicity_group):\n",
        "        \"\"\"Filter dataset by gender and tokenize.\"\"\"\n",
        "        dataset_split = dataset_split.filter(lambda example: example['ethnicity_group'] == ethnicity_group)\n",
        "\n",
        "        # Keep only relevant columns\n",
        "        dataset_split = dataset_split.select_columns([\"text\", \"los_label\"])\n",
        "\n",
        "        encodings = tokenizer(\n",
        "            dataset_split['text'],\n",
        "            truncation=True,\n",
        "            padding='max_length',\n",
        "            max_length=512,\n",
        "            return_tensors='pt'\n",
        "        )\n",
        "\n",
        "        labels = torch.tensor([int(label) for label in dataset_split['los_label']])\n",
        "        return TextDataset(encodings, labels)\n",
        "\n",
        "    def filter_gender(dataset_split, gender, num_samples=None):\n",
        "        \"\"\"Filter dataset by gender and tokenize.\"\"\"\n",
        "        dataset_split = dataset_split.filter(lambda example: example['gender'] == gender)\n",
        "\n",
        "        # Select only the first `num_samples` if provided (only for females)\n",
        "        if num_samples is not None:\n",
        "            dataset_split = dataset_split.select(range(min(num_samples, len(dataset_split))))\n",
        "\n",
        "        # Keep only relevant columns\n",
        "        dataset_split = dataset_split.select_columns([\"text\", \"los_label\"])\n",
        "\n",
        "        encodings = tokenizer(\n",
        "            dataset_split['text'],\n",
        "            truncation=True,\n",
        "            padding='max_length',\n",
        "            max_length=512,\n",
        "            return_tensors='pt'\n",
        "        )\n",
        "\n",
        "        labels = torch.tensor([int(label) for label in dataset_split['los_label']])\n",
        "        return TextDataset(encodings, labels)\n",
        "\n",
        "    # Create gender-specific dataloaders\n",
        "    female_dataset = filter_gender(dataset['test'], gender='F', num_samples=500) #female\n",
        "    male_dataset = filter_gender(dataset['test'], gender='M') #male\n",
        "\n",
        "    White_dataset = filter_ethnicity(dataset['test'], ethnicity_group='White') #White\n",
        "    Black_dataset = filter_ethnicity(dataset['test'], ethnicity_group='Black') #Black\n",
        "    Hispanic_dataset = filter_ethnicity(dataset['test'], ethnicity_group='Hispanic') #Hispanic\n",
        "    Asian_dataset = filter_ethnicity(dataset['test'], ethnicity_group='Asian') #Asian\n",
        "    other_dataset = filter_ethnicity(dataset['test'], ethnicity_group='other') #other\n",
        "    unknown_dataset = filter_ethnicity(dataset['test'], ethnicity_group='unknown') #unknown\n",
        "\n",
        "\n",
        "    female_dataloader = DataLoader(female_dataset, batch_size=500)\n",
        "    male_dataloader = DataLoader(male_dataset, batch_size=500)\n",
        "\n",
        "    White_dataloader = DataLoader(White_dataset, batch_size=500)\n",
        "    Black_dataloader = DataLoader(Black_dataset, batch_size=500)\n",
        "    Hispanic_dataloader = DataLoader(Hispanic_dataset, batch_size=500)\n",
        "    Asian_dataloader = DataLoader(Asian_dataset, batch_size=500)\n",
        "    other_dataloader = DataLoader(other_dataset, batch_size=500)\n",
        "    unknown_dataloader = DataLoader(unknown_dataset, batch_size=500)\n",
        "\n",
        "\n",
        "    def compute_metrics(predictions, labels):\n",
        "        predictions = np.argmax(predictions, axis=-1)\n",
        "        return {\n",
        "            \"accuracy\": accuracy_score(labels, predictions),\n",
        "            \"precision\": precision_score(labels, predictions, average='macro'),\n",
        "            \"recall\": recall_score(labels, predictions, average='macro'),\n",
        "            \"f1\": f1_score(labels, predictions, average='macro')\n",
        "        }\n",
        "\n",
        "    def compute_gradient_norms(model, dataloader, optimizer, group, groupname):\n",
        "\n",
        "        wandb.init(project=\"gradient_analysis_per_gender\", name=\"evaluation_run\")\n",
        "\n",
        "        model.train()  # Enable gradient tracking\n",
        "        total_grad_norm_before = 0.0\n",
        "        total_grad_norm_after = 0.0\n",
        "        count = 0\n",
        "        eval_loss = 0\n",
        "        total_micro_batches = 0\n",
        "        micro_batch_size = 40\n",
        "        grad_list = []  # Initialize grad_list\n",
        "        all_logical_labels = []\n",
        "        all_logical_predictions = []\n",
        "        layer_wise_grad_samples = {}\n",
        "        cosine_similarities = []\n",
        "        angles = []\n",
        "\n",
        "        accuracy_male = []\n",
        "        accuracy_female = []\n",
        "        accuracy_White = []\n",
        "        accuracy_Black = []\n",
        "        accuracy_Hispanic = []\n",
        "        accuracy_Asian = []\n",
        "        accuracy_unknown = []\n",
        "        accuracy_other = []\n",
        "\n",
        "        grad_before_male = []\n",
        "        grad_before_female = []\n",
        "        grad_before_White = []\n",
        "        grad_before_Black = []\n",
        "        grad_before_Hispanic = []\n",
        "        grad_before_Asian = []\n",
        "        grad_before_unknown = []\n",
        "        grad_before_other = []\n",
        "\n",
        "        grad_after_male = []\n",
        "        grad_after_female = []\n",
        "        grad_after_White = []\n",
        "        grad_after_Black = []\n",
        "        grad_after_Hispanic = []\n",
        "        grad_after_Asian = []\n",
        "        grad_after_unknown = []\n",
        "        grad_after_other = []\n",
        "\n",
        "        f1score_male = []\n",
        "        f1score_female = []\n",
        "        f1score_White = []\n",
        "        f1score_Black = []\n",
        "        f1score_Hispanic = []\n",
        "        f1score_Asian = []\n",
        "        f1score_unknown = []\n",
        "        f1score_other = []\n",
        "\n",
        "        ratio_male = []\n",
        "        ratio_female = []\n",
        "        ratio_White = []\n",
        "        ratio_Black = []\n",
        "        ratio_Hispanic = []\n",
        "        ratio_Asian = []\n",
        "        ratio_unknown = []\n",
        "        ratio_other = []\n",
        "\n",
        "        loss_male = []\n",
        "        loss_female = []\n",
        "\n",
        "\n",
        "        #with torch.no_grad():\n",
        "        with BatchMemoryManager(\n",
        "                data_loader=dataloader,\n",
        "                max_physical_batch_size=micro_batch_size,\n",
        "                optimizer=optimizer\n",
        "            ) as memory_safe_data_loader:\n",
        "\n",
        "            for batch in dataloader:\n",
        "\n",
        "                total_micro_batches += 1\n",
        "\n",
        "                optimizer.zero_grad()\n",
        "                batch = {k: v.to('cuda') for k, v in batch.items()}\n",
        "                outputs = model(**batch)\n",
        "                loss_fn = torch.nn.CrossEntropyLoss(reduction=\"sum\")\n",
        "                loss = loss_fn(outputs.logits.view(-1, outputs.logits.size(-1)), batch['labels'].view(-1))\n",
        "                eval_loss += loss.item()\n",
        "                logits = outputs.logits\n",
        "                labels = batch['labels']\n",
        "\n",
        "                # Compute and log metrics every batch\n",
        "                predictions = logits.detach().cpu().numpy()\n",
        "                labels_cpu = labels.detach().cpu().numpy()\n",
        "                all_logical_predictions.extend(predictions)\n",
        "                all_logical_labels.extend(labels_cpu)\n",
        "\n",
        "                # Backward pass\n",
        "                loss.backward()\n",
        "\n",
        "                # Compute gradient norm before clipping\n",
        "                for name, p in model.named_parameters():\n",
        "                    if p.requires_grad:\n",
        "                        with torch.no_grad():\n",
        "                            grad_sample = p.grad_sample.sum(dim = 0).clone().detach()\n",
        "                            if name not in layer_wise_grad_samples:\n",
        "                                layer_wise_grad_samples[name] = torch.zeros_like(grad_sample)  # Store per-layer name\n",
        "                            layer_wise_grad_samples[name] += grad_sample  # Accumulate across physical batches\n",
        "\n",
        "                is_final_step = optimizer.pre_step()\n",
        "\n",
        "                if is_final_step:\n",
        "                    total_grad_norm_squared = 0.0\n",
        "                    total_grad_norm_squared_summed_noisy = 0.0\n",
        "                    total_noise_grad = 0.0\n",
        "                    all_layer_grads_before=[]\n",
        "                    all_layer_grads_after=[]\n",
        "\n",
        "                    for name, grad_list in layer_wise_grad_samples.items():\n",
        "                        all_layer_grads_before.append(grad_list)  # Store per-layer gradients\n",
        "                        ## Compute the L2 norm squared\n",
        "                        total_grad_norm_squared += torch.norm(grad_list, p=2) ** 2\n",
        "\n",
        "\n",
        "                    overall_grad_norm = total_grad_norm_squared ** 0.5\n",
        "\n",
        "                    for name, p in model.named_parameters():\n",
        "                        if p.requires_grad:\n",
        "                            with torch.no_grad():\n",
        "                                all_layer_grads_after.append(p.summed_grad.clone().detach())\n",
        "\n",
        "                                layer_grad_norm_summed_noisy = torch.norm(p.summed_grad, p=2).item()\n",
        "                                total_grad_norm_squared_summed_noisy += layer_grad_norm_summed_noisy ** 2\n",
        "\n",
        "                                grad = torch.norm(p.grad, p=2).item()\n",
        "                                total_noise_grad += (grad - layer_grad_norm_summed_noisy) ** 2\n",
        "\n",
        "\n",
        "                    # Convert lists to single tensors\n",
        "                    total_grad_before = torch.cat([v.view(-1) for v in all_layer_grads_before])\n",
        "                    total_grad_after = torch.cat([v.view(-1) for v in all_layer_grads_after])  # Aggregate all layers\n",
        "\n",
        "\n",
        "                    # Compute Cosine Similarity\n",
        "                    cosine_similarity = torch.nn.functional.cosine_similarity(\n",
        "                        total_grad_before.view(1, -1),\n",
        "                        total_grad_after.view(1, -1),\n",
        "                        dim=1\n",
        "                    ).item()\n",
        "\n",
        "                    # Compute Angle in Degrees\n",
        "                    angle = torch.acos(torch.clamp(torch.tensor(cosine_similarity), -1.0, 1.0)).item() * (180 / torch.pi)\n",
        "\n",
        "                    # Store values for logging\n",
        "                    cosine_similarities.append(cosine_similarity)\n",
        "                    angles.append(angle)\n",
        "\n",
        "                    #print(f\"Cosine Similarity: {cosine_similarity}\")\n",
        "                    #print(f\"Angle in Degrees: {angle}\")\n",
        "\n",
        "                    # Compute overall gradient norm\n",
        "                    overall_grad_norm = total_grad_norm_squared ** 0.5\n",
        "                    overall_grad_norm_summed_noisy = total_grad_norm_squared_summed_noisy ** 0.5\n",
        "                    overal_noise_grad = total_noise_grad ** 0.5\n",
        "\n",
        "                    #print(f\"logical batch grad_norm before clipping: {overall_grad_norm}\")\n",
        "                    #print(f\"logical batch grad norm after clipping with Noise: {overall_grad_norm_summed_noisy}\")\n",
        "                    #print(f\"overal noise: grad - summed_grad: {overal_noise_grad}\")\n",
        "                    #print(f\"summed_grad/noise grad: {overall_grad_norm_summed_noisy/overal_noise_grad}\")\n",
        "\n",
        "                    # Compute metrics after all physical batches in the logical batch are processed\n",
        "                    eval_metrics = compute_metrics(np.array(all_logical_predictions), np.array(all_logical_labels))\n",
        "                    avg_eval_loss = eval_loss / total_micro_batches\n",
        "\n",
        "                    # Compute the average similarity and angle across all layers\n",
        "                    avg_cosine_similarity = sum(cosine_similarities) / len(cosine_similarities)\n",
        "                    avg_angle = sum(angles) / len(angles)\n",
        "\n",
        "                    if groupname=='gender':\n",
        "                    # Log to wandb\n",
        "                        wandb.log({\n",
        "                            f\"eval_loss_{group}\": avg_eval_loss,\n",
        "                            f\"eval_accuracy_{group}\": eval_metrics['accuracy'],\n",
        "                            f\"eval_precision_{group}\": eval_metrics['precision'],\n",
        "                            f\"eval_recall_{group}\": eval_metrics['recall'],\n",
        "                            f\"eval_f1_{group}\": eval_metrics['f1'],\n",
        "                            f\"logical batch grad_norm before clipping: {group}\": overall_grad_norm,\n",
        "                            f\"logical batch grad norm after clipping with Noise: {group}\": overall_grad_norm_summed_noisy,\n",
        "                            f\"overal noise: grad - summed_gradg: {group}\": overal_noise_grad,\n",
        "                            f\"summed_grad/noise grad: {group}\":(overall_grad_norm_summed_noisy/overal_noise_grad),\n",
        "                            f\"norm after clipping / norm before clipping: {group} \" :overall_grad_norm_summed_noisy/overall_grad_norm,\n",
        "                            f\"logical_batch_cosine_similarity: {group}\": avg_cosine_similarity,\n",
        "                            f\"logical_batch_angle_degrees: {group}\": avg_angle,\n",
        "                            f\"Cosine Similarity_{group}\": cosine_similarity,\n",
        "                            f\"Angle in Degrees_{group}\": angle\n",
        "                        })\n",
        "\n",
        "                        #print(f\"eval_loss_{group}: {avg_eval_loss}\")\n",
        "                        #print(f\"eval_accuracy_{group}: {eval_metrics['accuracy']}\")\n",
        "                        #print(f\"eval_precision_{group}: {eval_metrics['precision']}\")\n",
        "                        #print(f\"eval_recall_{group}: { eval_metrics['recall']}\")\n",
        "                        #print(f\"eval_f1_{group}: {eval_metrics['f1']}\")\n",
        "                        #print(f\"logical_batch_cosine_similarity{group}: {avg_cosine_similarity}\")\n",
        "                        #print(f\"logical_batch_angle_degrees{group}: {avg_angle}\")\n",
        "\n",
        "                        if group=='male':\n",
        "                            accuracy_male.append(eval_metrics['accuracy'])\n",
        "                            f1score_male.append(eval_metrics['f1'])\n",
        "                            grad_before_male.append(overall_grad_norm)\n",
        "                            grad_after_male.append(overall_grad_norm_summed_noisy)\n",
        "                            ratio_male.append(overall_grad_norm_summed_noisy/overall_grad_norm)\n",
        "                            loss_male.append(avg_eval_loss)\n",
        "\n",
        "                        elif group=='female':\n",
        "                            accuracy_female.append(eval_metrics['accuracy'])\n",
        "                            f1score_female.append(eval_metrics['f1'])\n",
        "                            grad_before_female.append(overall_grad_norm)\n",
        "                            grad_after_female.append(overall_grad_norm_summed_noisy)\n",
        "                            ratio_female.append(overall_grad_norm_summed_noisy/overall_grad_norm)\n",
        "                            loss_female.append(avg_eval_loss)\n",
        "\n",
        "\n",
        "                    elif groupname=='ethnicity':\n",
        "                        wandb.log({\n",
        "                            f\"eval_loss_{group}\": avg_eval_loss,\n",
        "                            f\"eval_accuracy_{group}\": eval_metrics['accuracy'],\n",
        "                            f\"eval_precision_{group}\": eval_metrics['precision'],\n",
        "                            f\"eval_recall_{group}\": eval_metrics['recall'],\n",
        "                            f\"eval_f1_{group}\": eval_metrics['f1'],\n",
        "                            f\"logical batch grad_norm before clipping: {group}\": overall_grad_norm,\n",
        "                            f\"logical batch grad norm after clipping with Noise: {group}\": overall_grad_norm_summed_noisy,\n",
        "                            f\"overal noise: grad - summed_gradg: {group}\": overal_noise_grad,\n",
        "                            f\"summed_grad/noise grad: {group}\":(overall_grad_norm_summed_noisy/overal_noise_grad),\n",
        "                            f\"norm after clipping / norm before clipping: {group} \" :overall_grad_norm_summed_noisy/overall_grad_norm,\n",
        "                            f\"logical_batch_cosine_similarity: {group}\": avg_cosine_similarity,\n",
        "                            f\"logical_batch_angle_degrees: {group}\": avg_angle,\n",
        "                            f\"Cosine Similarity_{group}\": cosine_similarity,\n",
        "                            f\"Angle in Degrees_{group}\": angle\n",
        "                        })\n",
        "\n",
        "                        #print(f\"eval_loss_{group}: {avg_eval_loss}\")\n",
        "                        #print(f\"eval_accuracy_{group}: {eval_metrics['accuracy']}\")\n",
        "                        #print(f\"eval_precision_{group}: {eval_metrics['precision']}\")\n",
        "                        #print(f\"eval_recall_{group}: { eval_metrics['recall']}\")\n",
        "                        #print(f\"eval_f1_{group}: {eval_metrics['f1']}\")\n",
        "                        #print(f\"logical_batch_cosine_similarity{group}: {avg_cosine_similarity}\")\n",
        "                        #print(f\"logical_batch_angle_degrees{group}: {avg_angle}\")\n",
        "\n",
        "                        if group=='White':\n",
        "                            accuracy_White.append(eval_metrics['accuracy'])\n",
        "                            f1score_White.append(eval_metrics['f1'])\n",
        "                            grad_before_White.append(overall_grad_norm)\n",
        "                            grad_after_White.append(overall_grad_norm_summed_noisy)\n",
        "                            ratio_White.append(overall_grad_norm_summed_noisy/overall_grad_norm)\n",
        "\n",
        "                        elif group=='Black':\n",
        "                            accuracy_Black.append(eval_metrics['accuracy'])\n",
        "                            f1score_Black.append(eval_metrics['f1'])\n",
        "                            grad_before_Black.append(overall_grad_norm)\n",
        "                            grad_after_Black.append(overall_grad_norm_summed_noisy)\n",
        "                            ratio_Black.append(overall_grad_norm_summed_noisy/overall_grad_norm)\n",
        "\n",
        "                        elif group=='Hispanic':\n",
        "                            accuracy_Hispanic.append(eval_metrics['accuracy'])\n",
        "                            f1score_Hispanic.append(eval_metrics['f1'])\n",
        "                            grad_before_Hispanic.append(overall_grad_norm)\n",
        "                            grad_after_Hispanic.append(overall_grad_norm_summed_noisy)\n",
        "                            ratio_Hispanic.append(overall_grad_norm_summed_noisy/overall_grad_norm)\n",
        "\n",
        "                        elif group=='Asian':\n",
        "                            accuracy_Asian.append(eval_metrics['accuracy'])\n",
        "                            f1score_Asian.append(eval_metrics['f1'])\n",
        "                            grad_before_Asian.append(overall_grad_norm)\n",
        "                            grad_after_Asian.append(overall_grad_norm_summed_noisy)\n",
        "                            ratio_Asian.append(overall_grad_norm_summed_noisy/overall_grad_norm)\n",
        "\n",
        "                        elif group=='unknown':\n",
        "                            accuracy_unknown.append(eval_metrics['accuracy'])\n",
        "                            f1score_unknown.append(eval_metrics['f1'])\n",
        "                            grad_before_unknown.append(overall_grad_norm)\n",
        "                            grad_after_unknown.append(overall_grad_norm_summed_noisy)\n",
        "                            ratio_unknown.append(overall_grad_norm_summed_noisy/overall_grad_norm)\n",
        "\n",
        "                        elif group=='other':\n",
        "                            accuracy_other.append(eval_metrics['accuracy'])\n",
        "                            f1score_other.append(eval_metrics['f1'])\n",
        "                            grad_before_other.append(overall_grad_norm)\n",
        "                            grad_after_other.append(overall_grad_norm_summed_noisy)\n",
        "                            ratio_other.append(overall_grad_norm_summed_noisy/overall_grad_norm)\n",
        "\n",
        "                    all_logical_labels = []\n",
        "                    all_logical_predictions = []\n",
        "                    grad_list = []  # Reset grad_list\n",
        "                    layer_wise_grad_samples = {}\n",
        "\n",
        "\n",
        "                optimizer.zero_grad()\n",
        "                count += 1\n",
        "\n",
        "            if groupname=='gender':\n",
        "\n",
        "                if group == 'male':\n",
        "                    avg_grad_norm_before_male = sum(grad_before_male)/len(grad_before_male)\n",
        "                    avg_grad_norm_after_male = sum(grad_after_male) / len(grad_after_male)\n",
        "                    avg_accuracy_male = sum(accuracy_male) / len(accuracy_male)\n",
        "                    avg_f1_male = sum(f1score_male) / len(f1score_male)\n",
        "                    avg_loss_male = sum(loss_male) / len(loss_male)\n",
        "                    print(f\"Average Gradient Norm Before Clipping {group}: {avg_grad_norm_before_male}\")\n",
        "                    print(f\"Average Gradient Norm After Clipping {group}: {avg_grad_norm_after_male}\")\n",
        "                    print(f\"Average Accuracy for {group}: {avg_accuracy_male}\")\n",
        "                    print(f\"Average f1score for {group}: {avg_f1_male}\")\n",
        "                    print(f\"Average loss for {group}: {avg_loss_male}\")\n",
        "\n",
        "\n",
        "                elif group == 'female':\n",
        "                    avg_grad_norm_before_female = sum(grad_before_female)/len(grad_before_female)\n",
        "                    avg_grad_norm_after_female = sum(grad_after_female) / len(grad_after_female)\n",
        "                    avg_accuracy_female = sum(accuracy_female) / len(accuracy_female)\n",
        "                    avg_f1_female = sum(f1score_female) / len(f1score_female)\n",
        "                    avg_loss_female = sum(loss_female) / len(loss_female)\n",
        "\n",
        "                    print(f\"Average Gradient Norm Before Clipping {group}: {avg_grad_norm_before_female}\")\n",
        "                    print(f\"Average Gradient Norm After Clipping {group}: {avg_grad_norm_after_female}\")\n",
        "                    print(f\"Average Accuracy for {group}: {avg_accuracy_female}\")\n",
        "                    print(f\"Average f1score for {group}: {avg_f1_female}\")\n",
        "                    print(f\"Average loss for {group}: {avg_loss_female}\")\n",
        "\n",
        "\n",
        "            if groupname=='ethnicity':\n",
        "\n",
        "                if group == 'White':\n",
        "                    avg_grad_norm_before_White = sum(grad_before_White)/len(grad_before_White)\n",
        "                    avg_grad_norm_after_White = sum(grad_after_White) / len(grad_after_White)\n",
        "                    avg_accuracy_White = sum(accuracy_White) / len(accuracy_White)\n",
        "                    avg_f1_White = sum(f1score_White) / len(f1score_White)\n",
        "\n",
        "                    print(f\"Average Gradient Norm Before Clipping {group}: {avg_grad_norm_before_White}\")\n",
        "                    print(f\"Average Gradient Norm After Clipping {group}: {avg_grad_norm_after_White}\")\n",
        "                    print(f\"Average Accuracy for {group}: {avg_accuracy_White}\")\n",
        "                    print(f\"Average f1score for {group}: {avg_f1_White}\")\n",
        "\n",
        "                elif group == 'Black':\n",
        "                    avg_grad_norm_before_Black = sum(grad_before_Black)/len(grad_before_Black)\n",
        "                    avg_grad_norm_after_Black = sum(grad_after_Black) / len(grad_after_Black)\n",
        "                    avg_accuracy_Black = sum(accuracy_Black) / len(accuracy_Black)\n",
        "                    avg_f1_Black = sum(f1score_Black) / len(f1score_Black)\n",
        "\n",
        "                    print(f\"Average Gradient Norm Before Clipping {group}: {avg_grad_norm_before_Black}\")\n",
        "                    print(f\"Average Gradient Norm After Clipping {group}: {avg_grad_norm_after_Black}\")\n",
        "                    print(f\"Average Accuracy for {group}: {avg_accuracy_Black}\")\n",
        "                    print(f\"Average f1score for {group}: {avg_f1_Black}\")\n",
        "\n",
        "                elif group == 'Hispanic':\n",
        "                    avg_grad_norm_before_Hispanic = sum(grad_before_Hispanic)/len(grad_before_Hispanic)\n",
        "                    avg_grad_norm_after_Hispanic = sum(grad_after_Hispanic) / len(grad_after_Hispanic)\n",
        "                    avg_accuracy_Hispanic = sum(accuracy_Hispanic) / len(accuracy_Hispanic)\n",
        "                    avg_f1_Hispanic = sum(f1score_Hispanic) / len(f1score_Hispanic)\n",
        "\n",
        "                    print(f\"Average Gradient Norm Before Clipping {group}: {avg_grad_norm_before_Hispanic}\")\n",
        "                    print(f\"Average Gradient Norm After Clipping {group}: {avg_grad_norm_after_Hispanic}\")\n",
        "                    print(f\"Average Accuracy for {group}: {avg_accuracy_Hispanic}\")\n",
        "                    print(f\"Average f1score for {group}: {avg_f1_Hispanic}\")\n",
        "\n",
        "                elif group == 'Asian':\n",
        "                    avg_grad_norm_before_Asian = sum(grad_before_Asian)/len(grad_before_Asian)\n",
        "                    avg_grad_norm_after_Asian = sum(grad_after_Asian) / len(grad_after_Asian)\n",
        "                    avg_accuracy_Asian = sum(accuracy_Asian) / len(accuracy_Asian)\n",
        "                    avg_f1_Asian = sum(f1score_Asian) / len(f1score_Asian)\n",
        "\n",
        "                    print(f\"Average Gradient Norm Before Clipping {group}: {avg_grad_norm_before_Asian}\")\n",
        "                    print(f\"Average Gradient Norm After Clipping {group}: {avg_grad_norm_after_Asian}\")\n",
        "                    print(f\"Average Accuracy for {group}: {avg_accuracy_Asian}\")\n",
        "                    print(f\"Average f1score for {group}: {avg_f1_Asian}\")\n",
        "\n",
        "                elif group == 'unknown':\n",
        "                    avg_grad_norm_before_unknown = sum(grad_before_unknown)/len(grad_before_unknown)\n",
        "                    avg_grad_norm_after_unknown = sum(grad_after_unknown) / len(grad_after_unknown)\n",
        "                    avg_accuracy_unknown = sum(accuracy_unknown) / len(accuracy_unknown)\n",
        "                    avg_f1_unknown = sum(f1score_unknown) / len(f1score_unknown)\n",
        "\n",
        "                    print(f\"Average Gradient Norm Before Clipping {group}: {avg_grad_norm_before_unknown}\")\n",
        "                    print(f\"Average Gradient Norm After Clipping {group}: {avg_grad_norm_after_unknown}\")\n",
        "                    print(f\"Average Accuracy for {group}: {avg_accuracy_unknown}\")\n",
        "                    print(f\"Average f1score for {group}: {avg_f1_unknown}\")\n",
        "\n",
        "                elif group == 'other':\n",
        "                    avg_grad_norm_before_other = sum(grad_before_other)/len(grad_before_other)\n",
        "                    avg_grad_norm_after_other = sum(grad_after_other) / len(grad_after_other)\n",
        "                    avg_accuracy_other = sum(accuracy_other) / len(accuracy_other)\n",
        "                    avg_f1_other = sum(f1score_other) / len(f1score_other)\n",
        "\n",
        "                    print(f\"Average Gradient Norm Before Clipping {group}: {avg_grad_norm_before_other}\")\n",
        "                    print(f\"Average Gradient Norm After Clipping {group}: {avg_grad_norm_after_other}\")\n",
        "                    print(f\"Average Accuracy for {group}: {avg_accuracy_other}\")\n",
        "                    print(f\"Average f1score for {group}: {avg_f1_other}\")\n",
        "\n",
        "\n",
        "\n",
        "    # Compute gradient norms for female and male datasets\n",
        "    print(\"Computing gradient norms for female dataset:\")\n",
        "    compute_gradient_norms(model, female_dataloader, optimizer,'female', groupname='gender')\n",
        "\n",
        "    print(\"\\nComputing gradient norms for male dataset:\")\n",
        "    compute_gradient_norms(model, male_dataloader, optimizer,'male', groupname='gender')\n",
        "\n",
        "\n",
        "\n",
        "if __name__ == '__main__':\n",
        "    import fire\n",
        "    fire.Fire(main)"
      ],
      "metadata": {
        "id": "h0x4ArOpmRsh"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "**Testing Age**"
      ],
      "metadata": {
        "id": "4TdIL6ZImlZb"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#Age and insurance and marital status Adaptive Differential Privacy\n",
        "from torch.utils.data.dataset import Dataset\n",
        "from torch.utils.data import DataLoader\n",
        "from datasets import load_dataset\n",
        "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n",
        "import torch\n",
        "import wandb\n",
        "import numpy as np\n",
        "from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score\n",
        "import fire\n",
        "import opacus\n",
        "from opacus.utils.batch_memory_manager import BatchMemoryManager\n",
        "from opacus.accountants.utils import get_noise_multiplier\n",
        "import torch.nn.functional as F\n",
        "import os\n",
        "from opacus import PrivacyEngine\n",
        "import torch.optim as optim\n",
        "from typing import List, Union\n",
        "from opacus.optimizers import DPOptimizer\n",
        "from opacus.optimizers import AdaClipDPOptimizer, DPOptimizer\n",
        "from opacus.optimizers.optimizer import (\n",
        "    _check_processed_flag,\n",
        "    _generate_noise,\n",
        "    _mark_as_processed\n",
        ")\n",
        "\n",
        "\n",
        "def main():\n",
        "\n",
        "    def seed_everything(seed):\n",
        "        \"\"\"Set the seed for reproducibility across all libraries, including GPU operations.\"\"\"\n",
        "        os.environ['PYTHONHASHSEED'] = str(seed)\n",
        "        np.random.seed(seed)\n",
        "        torch.manual_seed(seed)\n",
        "        torch.cuda.manual_seed(seed)\n",
        "        torch.cuda.manual_seed_all(seed)\n",
        "\n",
        "        # Ensure deterministic behavior in cuDNN\n",
        "        torch.backends.cudnn.deterministic = True\n",
        "        torch.backends.cudnn.benchmark = False  # Set to False for full determinism\n",
        "\n",
        "        print(f\"Global seed set to: {seed}\")\n",
        "\n",
        "\n",
        "    seed_everything(seed)\n",
        "\n",
        "    # Load dataset\n",
        "    cache_dir = CHACHE_DIR\n",
        "    data_dir = DATA_DIR\n",
        "    data_files = {\n",
        "        \"train\": f\"{data_dir}/LOS_WEEKS_train.csv\",\n",
        "        \"validation\": f\"{data_dir}/LOS_WEEKS_val.csv\",\n",
        "        \"test\": f\"{data_dir}/LOS_WEEKS_test.csv\",\n",
        "    }\n",
        "    dataset = load_dataset('csv', data_files=data_files, cache_dir=cache_dir)\n",
        "\n",
        "    class TextDataset(Dataset):\n",
        "        def __init__(self, encodings, labels):\n",
        "            self.encodings = encodings\n",
        "            self.labels = labels\n",
        "\n",
        "        def __getitem__(self, idx):\n",
        "            item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}\n",
        "            item['labels'] = torch.tensor(self.labels[idx])\n",
        "            return item\n",
        "\n",
        "        def __len__(self):\n",
        "            return len(self.labels)\n",
        "\n",
        "    def load_checkpoint(checkpoint_path, model_name_or_path,tokenizer_path,dataset):\n",
        "\n",
        "        checkpoint = torch.load(checkpoint_path,weights_only=False)\n",
        "        model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, num_labels=4)\n",
        "        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)\n",
        "\n",
        "        #Freeze layers\n",
        "        for name, parameter in model.named_parameters():\n",
        "            if \"classifier\" in name or \"layer.5.output_layer_norm\" in name or \"layer.5.ffn.lin2\" in name or \"layer.5.ffn.lin1\" in name or \"layer.5.sa_layer_norm\" in name or \"layer.5.attention.out_lin\" in name or \"layer.5.attention.v_lin\" in name or \"layer.5.attention.k_lin\" in name:\n",
        "                parameter.requires_grad = True\n",
        "            else:\n",
        "                parameter.requires_grad = False\n",
        "\n",
        "        model.to(\"cuda\")\n",
        "        start_epoch = checkpoint['epoch']\n",
        "        best_val_f1 = checkpoint['best_val_f1']\n",
        "\n",
        "        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)\n",
        "\n",
        "        def process_dataset(dataset_split):\n",
        "            dataset_split = dataset_split.select_columns([\"text\", \"los_label\"])\n",
        "\n",
        "            encodings = tokenizer(\n",
        "                dataset_split['text'],\n",
        "                truncation=True,\n",
        "                padding='max_length',\n",
        "                max_length=512,\n",
        "                return_tensors='pt'\n",
        "            )\n",
        "            labels = [int(label) for label in dataset_split['los_label']]\n",
        "            return TextDataset(encodings, labels)\n",
        "\n",
        "        g = torch.Generator()\n",
        "        g.manual_seed(seed)\n",
        "\n",
        "        train_dataset = process_dataset(dataset['train'])\n",
        "\n",
        "        train_dataloader = DataLoader(\n",
        "            train_dataset,\n",
        "            batch_size=500, #batch_size,\n",
        "            shuffle=True,\n",
        "            drop_last=True,\n",
        "            worker_init_fn=lambda worker_id: torch.manual_seed(seed + worker_id),  # Fixes worker shuffling\n",
        "            generator=g\n",
        "            )\n",
        "\n",
        "        model.train()\n",
        "\n",
        "        # Opacus setup\n",
        "        sigma = get_noise_multiplier(\n",
        "            target_epsilon=8,\n",
        "            target_delta=1e-5,\n",
        "            sample_rate=500 / len(dataset['train']),\n",
        "            epochs=200\n",
        "        )\n",
        "        print(sigma)\n",
        "\n",
        "        privacy_engine = NormalfairPrivacyEngine()\n",
        "\n",
        "        model, optimizer, train_loader = privacy_engine.make_private(\n",
        "            module=model,\n",
        "            optimizer=optimizer,\n",
        "            data_loader=train_dataloader,\n",
        "            noise_multiplier=sigma,\n",
        "            max_grad_norm=0.1,\n",
        "            poisson_sampling=False, # poisson_sampling=True, If I was using drop_last = False in dataloader\n",
        "            loss_reduction=\"sum\" #sum\n",
        "        )\n",
        "\n",
        "\n",
        "        print(f\"Loaded checkpoint from {checkpoint_path}, starting from epoch {start_epoch+1}\")\n",
        "\n",
        "        model.load_state_dict(checkpoint[\"model_state_dict\"], strict=False)  # strict=False allows missing keys\n",
        "        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
        "\n",
        "        # Extract LR from loaded optimizer\n",
        "        for param_group in optimizer.param_groups:\n",
        "            lr = param_group[\"lr\"]  # Extract the learning rate from the optimizer checkpoint\n",
        "\n",
        "        print(f\"Loaded learning rate from checkpoint: {lr}\")\n",
        "\n",
        "        return model, tokenizer, optimizer, start_epoch, best_val_f1\n",
        "\n",
        "    # Load a checkpoint, Define paths\n",
        "    checkpoint_path =    CHECKPOINT_PATH\n",
        "    model_name_or_path = MODEL_NAME_OR_PATH\n",
        "    tokenizer_path = TOKENIZER_PATH\n",
        "\n",
        "    model, tokenizer, optimizer, start_epoch, best_val_f1 = load_checkpoint(checkpoint_path, model_name_or_path,tokenizer_path,dataset)\n",
        "\n",
        "\n",
        "    def filter_insurance(dataset_split, insurance_group):\n",
        "        \"\"\"Filter dataset by gender and tokenize.\"\"\"\n",
        "        dataset_split = dataset_split.filter(lambda example: example['insurance'] == insurance_group)\n",
        "\n",
        "        # Keep only relevant columns\n",
        "        dataset_split = dataset_split.select_columns([\"text\", \"los_label\"])\n",
        "\n",
        "        encodings = tokenizer(\n",
        "            dataset_split['text'],\n",
        "            truncation=True,\n",
        "            padding='max_length',\n",
        "            max_length=512,\n",
        "            return_tensors='pt'\n",
        "        )\n",
        "\n",
        "        labels = torch.tensor([int(label) for label in dataset_split['los_label']])\n",
        "        return TextDataset(encodings, labels)\n",
        "\n",
        "    def filter_age_median(dataset_split, age_median_gp, num_samples=None):\n",
        "        \"\"\"Filter dataset by gender and tokenize.\"\"\"\n",
        "        dataset_split = dataset_split.filter(lambda example: example['age_median_gp'] == age_median_gp)\n",
        "\n",
        "        # Select only the first `num_samples` if provided (only for females)\n",
        "        if num_samples is not None:\n",
        "            dataset_split = dataset_split.select(range(min(num_samples, len(dataset_split))))\n",
        "\n",
        "        # Keep only relevant columns\n",
        "        dataset_split = dataset_split.select_columns([\"text\", \"los_label\"])\n",
        "\n",
        "        encodings = tokenizer(\n",
        "            dataset_split['text'],\n",
        "            truncation=True,\n",
        "            padding='max_length',\n",
        "            max_length=512,\n",
        "            return_tensors='pt'\n",
        "        )\n",
        "\n",
        "        labels = torch.tensor([int(label) for label in dataset_split['los_label']])\n",
        "        return TextDataset(encodings, labels)\n",
        "\n",
        "    def filter_age_bins(dataset_split, age_bins, num_samples=None):\n",
        "        \"\"\"Filter dataset by gender and tokenize.\"\"\"\n",
        "        dataset_split = dataset_split.filter(lambda example: example['age_bins'] == age_bins)\n",
        "\n",
        "        # Select only the first `num_samples` if provided (only for females)\n",
        "        if num_samples is not None:\n",
        "            dataset_split = dataset_split.select(range(min(num_samples, len(dataset_split))))\n",
        "\n",
        "        # Keep only relevant columns\n",
        "        dataset_split = dataset_split.select_columns([\"text\", \"los_label\"])\n",
        "\n",
        "        encodings = tokenizer(\n",
        "            dataset_split['text'],\n",
        "            truncation=True,\n",
        "            padding='max_length',\n",
        "            max_length=512,\n",
        "            return_tensors='pt'\n",
        "        )\n",
        "\n",
        "        labels = torch.tensor([int(label) for label in dataset_split['los_label']])\n",
        "        return TextDataset(encodings, labels)\n",
        "\n",
        "    # Create gender-specific dataloaders\n",
        "    below_median_dataset = filter_age_median(dataset['test'], age_median_gp=0) #below median = 0\n",
        "    above_median_dataset = filter_age_median(dataset['test'], age_median_gp=1) #male\n",
        "\n",
        "\n",
        "    Senior_dataset = filter_age_bins(dataset['test'], age_bins='Senior') #below median = 0\n",
        "    Middle_aged_dataset = filter_age_bins(dataset['test'], age_bins='Middle-aged') #male\n",
        "    Elderly_dataset = filter_age_bins(dataset['test'], age_bins='Elderly') #male\n",
        "    Adult_dataset = filter_age_bins(dataset['test'], age_bins='Adult') #male\n",
        "    Young_Adult_dataset = filter_age_bins(dataset['test'], age_bins='Young Adult') #male\n",
        "    Child_dataset = filter_age_bins(dataset['test'], age_bins='Child') #male\n",
        "\n",
        "    Medicare_dataset = filter_insurance(dataset['test'], insurance_group='Medicare')\n",
        "    Private_dataset = filter_insurance(dataset['test'], insurance_group='Private')\n",
        "    Medicaid_dataset = filter_insurance(dataset['test'], insurance_group='Medicaid')\n",
        "    Government_dataset = filter_insurance(dataset['test'], insurance_group='Government')\n",
        "    SelfPay_dataset = filter_insurance(dataset['test'], insurance_group='Self Pay')\n",
        "\n",
        "\n",
        "    below_median_dataloader = DataLoader(below_median_dataset, batch_size=500)\n",
        "    above_median_dataloader = DataLoader(above_median_dataset, batch_size=500)\n",
        "\n",
        "    Medicare_dataloader = DataLoader(Medicare_dataset, batch_size=500)\n",
        "    Private_dataloader = DataLoader(Private_dataset, batch_size=500)\n",
        "    Medicaid_dataloader = DataLoader(Medicaid_dataset, batch_size=500)\n",
        "    Government_dataloader = DataLoader(Government_dataset, batch_size=500)\n",
        "    SelfPay_dataloader = DataLoader(SelfPay_dataset, batch_size=500)\n",
        "\n",
        "\n",
        "    Senior_dataloader = DataLoader(Senior_dataset, batch_size=500)\n",
        "    Middle_aged_dataloader = DataLoader(Middle_aged_dataset, batch_size=500)\n",
        "    Elderly_dataloader = DataLoader(Elderly_dataset, batch_size=500)\n",
        "    Adult_dataloader = DataLoader(Adult_dataset, batch_size=500)\n",
        "    Young_Adult_dataloader = DataLoader(Young_Adult_dataset, batch_size=500)\n",
        "    Child_dataloader = DataLoader(Child_dataset, batch_size=500)\n",
        "\n",
        "\n",
        "    def compute_metrics(predictions, labels):\n",
        "        predictions = np.argmax(predictions, axis=-1)\n",
        "        return {\n",
        "            \"accuracy\": accuracy_score(labels, predictions),\n",
        "            \"precision\": precision_score(labels, predictions, average='macro'),\n",
        "            \"recall\": recall_score(labels, predictions, average='macro'),\n",
        "            \"f1\": f1_score(labels, predictions, average='macro')\n",
        "        }\n",
        "\n",
        "    def compute_gradient_norms(model, dataloader, optimizer, group, groupname):\n",
        "\n",
        "        wandb.init(project=\"gradient_analysis_per_gender\", name=\"evaluation_run\")\n",
        "\n",
        "        model.train()  # Enable gradient tracking\n",
        "        total_grad_norm_before = 0.0\n",
        "        total_grad_norm_after = 0.0\n",
        "        count = 0\n",
        "        eval_loss = 0\n",
        "        total_micro_batches = 0\n",
        "        micro_batch_size = 40\n",
        "        grad_list = []  # Initialize grad_list\n",
        "        all_logical_labels = []\n",
        "        all_logical_predictions = []\n",
        "        layer_wise_grad_samples = {}\n",
        "        cosine_similarities = []\n",
        "        angles = []\n",
        "\n",
        "        accuracy_below_median = []\n",
        "        accuracy_above_median = []\n",
        "        accuracy_Medicare = []\n",
        "        accuracy_Private = []\n",
        "        accuracy_Medicaid = []\n",
        "        accuracy_Government= []\n",
        "        accuracy_SelfPay = []\n",
        "        accuracy_Senior = []\n",
        "        accuracy_Middle_aged = []\n",
        "        accuracy_Elderly = []\n",
        "        accuracy_Adult = []\n",
        "        accuracy_Young_Adult = []\n",
        "        accuracy_Child = []\n",
        "\n",
        "\n",
        "        grad_before_below_median = []\n",
        "        grad_before_above_median = []\n",
        "        grad_before_Medicare = []\n",
        "        grad_before_Private = []\n",
        "        grad_before_Medicaid = []\n",
        "        grad_before_Government = []\n",
        "        grad_before_SelfPay = []\n",
        "        grad_before_Senior = []\n",
        "        grad_before_Middle_aged = []\n",
        "        grad_before_Elderly = []\n",
        "        grad_before_Adult = []\n",
        "        grad_before_Young_Adult = []\n",
        "        grad_before_Child = []\n",
        "\n",
        "\n",
        "\n",
        "        grad_after_below_median = []\n",
        "        grad_after_above_median = []\n",
        "        grad_after_Medicare = []\n",
        "        grad_after_Private = []\n",
        "        grad_after_Medicaid = []\n",
        "        grad_after_Government = []\n",
        "        grad_after_SelfPay = []\n",
        "        grad_after_Senior = []\n",
        "        grad_after_Middle_aged = []\n",
        "        grad_after_Elderly = []\n",
        "        grad_after_Adult = []\n",
        "        grad_after_Young_Adult = []\n",
        "        grad_after_Child = []\n",
        "\n",
        "        f1score_below_median = []\n",
        "        f1score_above_median = []\n",
        "        f1score_Medicare = []\n",
        "        f1score_Private = []\n",
        "        f1score_Medicaid = []\n",
        "        f1score_Government = []\n",
        "        f1score_SelfPay = []\n",
        "        f1score_Senior = []\n",
        "        f1score_Middle_aged = []\n",
        "        f1score_Elderly = []\n",
        "        f1score_Adult = []\n",
        "        f1score_Young_Adult = []\n",
        "        f1score_Child = []\n",
        "\n",
        "        ratio_below_median = []\n",
        "        ratio_above_median = []\n",
        "        ratio_Medicare = []\n",
        "        ratio_Private = []\n",
        "        ratio_Medicaid = []\n",
        "        ratio_Government = []\n",
        "        ratio_SelfPay = []\n",
        "        ratio_Senior = []\n",
        "        ratio_Middle_aged = []\n",
        "        ratio_Elderly = []\n",
        "        ratio_Adult = []\n",
        "        ratio_Young_Adult = []\n",
        "        ratio_Child = []\n",
        "\n",
        "        loss_below_median = []\n",
        "        loss_above_median = []\n",
        "        loss_Medicare = []\n",
        "        loss_private = []\n",
        "        loss_Medicaid = []\n",
        "        loss_Government = []\n",
        "        loss_SelfPay = []\n",
        "\n",
        "        #with torch.no_grad():\n",
        "        with BatchMemoryManager(\n",
        "                data_loader=dataloader,\n",
        "                max_physical_batch_size=micro_batch_size,\n",
        "                optimizer=optimizer\n",
        "            ) as memory_safe_data_loader:\n",
        "\n",
        "            for batch in dataloader:\n",
        "\n",
        "                total_micro_batches += 1\n",
        "\n",
        "                optimizer.zero_grad()\n",
        "                batch = {k: v.to('cuda') for k, v in batch.items()}\n",
        "                outputs = model(**batch)\n",
        "                loss_fn = torch.nn.CrossEntropyLoss(reduction=\"sum\")\n",
        "                loss = loss_fn(outputs.logits.view(-1, outputs.logits.size(-1)), batch['labels'].view(-1))\n",
        "                eval_loss += loss.item()\n",
        "                logits = outputs.logits\n",
        "                labels = batch['labels']\n",
        "\n",
        "                # Compute and log metrics every batch\n",
        "                predictions = logits.detach().cpu().numpy()\n",
        "                labels_cpu = labels.detach().cpu().numpy()\n",
        "                all_logical_predictions.extend(predictions)\n",
        "                all_logical_labels.extend(labels_cpu)\n",
        "\n",
        "                # Backward pass\n",
        "                loss.backward()\n",
        "\n",
        "                # Compute gradient norm before clipping\n",
        "                for name, p in model.named_parameters():\n",
        "                    if p.requires_grad:\n",
        "                        with torch.no_grad():\n",
        "                            grad_sample = p.grad_sample.sum(dim = 0).clone().detach()\n",
        "                            if name not in layer_wise_grad_samples:\n",
        "                                layer_wise_grad_samples[name] = torch.zeros_like(grad_sample)  # Store per-layer name\n",
        "                            layer_wise_grad_samples[name] += grad_sample  # Accumulate across physical batches\n",
        "\n",
        "                is_final_step = optimizer.pre_step()\n",
        "\n",
        "                if is_final_step:\n",
        "                    total_grad_norm_squared = 0.0\n",
        "                    total_grad_norm_squared_summed_noisy = 0.0\n",
        "                    total_noise_grad = 0.0\n",
        "                    all_layer_grads_before=[]\n",
        "                    all_layer_grads_after=[]\n",
        "\n",
        "                    for name, grad_list in layer_wise_grad_samples.items():\n",
        "                        all_layer_grads_before.append(grad_list)  # Store per-layer gradients\n",
        "                        ## Compute the L2 norm squared\n",
        "                        total_grad_norm_squared += torch.norm(grad_list, p=2) ** 2\n",
        "\n",
        "\n",
        "\n",
        "                    overall_grad_norm = total_grad_norm_squared ** 0.5\n",
        "\n",
        "                    for name, p in model.named_parameters():\n",
        "                        if p.requires_grad:\n",
        "                            with torch.no_grad():\n",
        "                                all_layer_grads_after.append(p.summed_grad.clone().detach())\n",
        "\n",
        "                                layer_grad_norm_summed_noisy = torch.norm(p.summed_grad, p=2).item()\n",
        "                                total_grad_norm_squared_summed_noisy += layer_grad_norm_summed_noisy ** 2\n",
        "\n",
        "                                grad = torch.norm(p.grad, p=2).item()\n",
        "                                total_noise_grad += (grad - layer_grad_norm_summed_noisy) ** 2\n",
        "\n",
        "\n",
        "                    # Convert lists to single tensors\n",
        "                    total_grad_before = torch.cat([v.view(-1) for v in all_layer_grads_before])\n",
        "                    total_grad_after = torch.cat([v.view(-1) for v in all_layer_grads_after])  # Aggregate all layers\n",
        "\n",
        "\n",
        "                    # Compute Cosine Similarity\n",
        "                    cosine_similarity = torch.nn.functional.cosine_similarity(\n",
        "                        total_grad_before.view(1, -1),\n",
        "                        total_grad_after.view(1, -1),\n",
        "                        dim=1\n",
        "                    ).item()\n",
        "\n",
        "                    # Compute Angle in Degrees\n",
        "                    angle = torch.acos(torch.clamp(torch.tensor(cosine_similarity), -1.0, 1.0)).item() * (180 / torch.pi)\n",
        "\n",
        "                    # Store values for logging\n",
        "                    cosine_similarities.append(cosine_similarity)\n",
        "                    angles.append(angle)\n",
        "\n",
        "                    #print(f\"Cosine Similarity: {cosine_similarity}\")\n",
        "                    #print(f\"Angle in Degrees: {angle}\")\n",
        "\n",
        "                    # Compute overall gradient norm\n",
        "                    overall_grad_norm = total_grad_norm_squared ** 0.5\n",
        "                    overall_grad_norm_summed_noisy = total_grad_norm_squared_summed_noisy ** 0.5\n",
        "                    overal_noise_grad = total_noise_grad ** 0.5\n",
        "\n",
        "                    #print(f\"logical batch grad_norm before clipping: {overall_grad_norm}\")\n",
        "                    #print(f\"logical batch grad norm after clipping with Noise: {overall_grad_norm_summed_noisy}\")\n",
        "                    #print(f\"overal noise: grad - summed_grad: {overal_noise_grad}\")\n",
        "                    #print(f\"summed_grad/noise grad: {overall_grad_norm_summed_noisy/overal_noise_grad}\")\n",
        "\n",
        "                    # Compute metrics after all physical batches in the logical batch are processed\n",
        "                    eval_metrics = compute_metrics(np.array(all_logical_predictions), np.array(all_logical_labels))\n",
        "                    avg_eval_loss = eval_loss / total_micro_batches\n",
        "\n",
        "                    # Compute the average similarity and angle across all layers\n",
        "                    avg_cosine_similarity = sum(cosine_similarities) / len(cosine_similarities)\n",
        "                    avg_angle = sum(angles) / len(angles)\n",
        "\n",
        "                    if groupname=='age_median':\n",
        "                    # Log to wandb\n",
        "                        wandb.log({\n",
        "                            f\"eval_loss_{group}\": avg_eval_loss,\n",
        "                            f\"eval_accuracy_{group}\": eval_metrics['accuracy'],\n",
        "                            f\"eval_precision_{group}\": eval_metrics['precision'],\n",
        "                            f\"eval_recall_{group}\": eval_metrics['recall'],\n",
        "                            f\"eval_f1_{group}\": eval_metrics['f1'],\n",
        "                            f\"logical batch grad_norm before clipping: {group}\": overall_grad_norm,\n",
        "                            f\"logical batch grad norm after clipping with Noise: {group}\": overall_grad_norm_summed_noisy,\n",
        "                            f\"overal noise: grad - summed_gradg: {group}\": overal_noise_grad,\n",
        "                            f\"summed_grad/noise grad: {group}\":(overall_grad_norm_summed_noisy/overal_noise_grad),\n",
        "                            f\"norm after clipping / norm before clipping: {group} \" :overall_grad_norm_summed_noisy/overall_grad_norm,\n",
        "                            f\"logical_batch_cosine_similarity: {group}\": avg_cosine_similarity,\n",
        "                            f\"logical_batch_angle_degrees: {group}\": avg_angle,\n",
        "                            f\"Cosine Similarity_{group}\": cosine_similarity,\n",
        "                            f\"Angle in Degrees_{group}\": angle\n",
        "                        })\n",
        "\n",
        "                        #print(f\"eval_loss_{group}: {avg_eval_loss}\")\n",
        "                        #print(f\"eval_accuracy_{group}: {eval_metrics['accuracy']}\")\n",
        "                        #print(f\"eval_precision_{group}: {eval_metrics['precision']}\")\n",
        "                        #print(f\"eval_recall_{group}: { eval_metrics['recall']}\")\n",
        "                        #print(f\"eval_f1_{group}: {eval_metrics['f1']}\")\n",
        "                        #print(f\"logical_batch_cosine_similarity{group}: {avg_cosine_similarity}\")\n",
        "                        #print(f\"logical_batch_angle_degrees{group}: {avg_angle}\")\n",
        "\n",
        "                        if group=='below_median':\n",
        "                            accuracy_below_median.append(eval_metrics['accuracy'])\n",
        "                            f1score_below_median.append(eval_metrics['f1'])\n",
        "                            grad_before_below_median.append(overall_grad_norm)\n",
        "                            grad_after_below_median.append(overall_grad_norm_summed_noisy)\n",
        "                            ratio_below_median.append(overall_grad_norm_summed_noisy/overall_grad_norm)\n",
        "                            loss_below_median.append(avg_eval_loss)\n",
        "\n",
        "                        elif group=='above_median':\n",
        "                            accuracy_above_median.append(eval_metrics['accuracy'])\n",
        "                            f1score_above_median.append(eval_metrics['f1'])\n",
        "                            grad_before_above_median.append(overall_grad_norm)\n",
        "                            grad_after_above_median.append(overall_grad_norm_summed_noisy)\n",
        "                            ratio_above_median.append(overall_grad_norm_summed_noisy/overall_grad_norm)\n",
        "                            loss_above_median.append(avg_eval_loss)\n",
        "\n",
        "                    elif groupname=='insurance':\n",
        "                        wandb.log({\n",
        "                            f\"eval_loss_{group}\": avg_eval_loss,\n",
        "                            f\"eval_accuracy_{group}\": eval_metrics['accuracy'],\n",
        "                            f\"eval_precision_{group}\": eval_metrics['precision'],\n",
        "                            f\"eval_recall_{group}\": eval_metrics['recall'],\n",
        "                            f\"eval_f1_{group}\": eval_metrics['f1'],\n",
        "                            f\"logical batch grad_norm before clipping: {group}\": overall_grad_norm,\n",
        "                            f\"logical batch grad norm after clipping with Noise: {group}\": overall_grad_norm_summed_noisy,\n",
        "                            f\"overal noise: grad - summed_gradg: {group}\": overal_noise_grad,\n",
        "                            f\"summed_grad/noise grad: {group}\":(overall_grad_norm_summed_noisy/overal_noise_grad),\n",
        "                            f\"norm after clipping / norm before clipping: {group} \" :overall_grad_norm_summed_noisy/overall_grad_norm,\n",
        "                            f\"logical_batch_cosine_similarity: {group}\": avg_cosine_similarity,\n",
        "                            f\"logical_batch_angle_degrees: {group}\": avg_angle,\n",
        "                            f\"Cosine Similarity_{group}\": cosine_similarity,\n",
        "                            f\"Angle in Degrees_{group}\": angle\n",
        "                        })\n",
        "\n",
        "                        print(f\"eval_loss_{group}: {avg_eval_loss}\")\n",
        "                        print(f\"eval_accuracy_{group}: {eval_metrics['accuracy']}\")\n",
        "                        print(f\"eval_precision_{group}: {eval_metrics['precision']}\")\n",
        "                        print(f\"eval_recall_{group}: { eval_metrics['recall']}\")\n",
        "                        print(f\"eval_f1_{group}: {eval_metrics['f1']}\")\n",
        "                        print(f\"logical_batch_cosine_similarity{group}: {avg_cosine_similarity}\")\n",
        "                        print(f\"logical_batch_angle_degrees{group}: {avg_angle}\")\n",
        "\n",
        "                        if group=='Medicare':\n",
        "                            accuracy_Medicare.append(eval_metrics['accuracy'])\n",
        "                            f1score_Medicare.append(eval_metrics['f1'])\n",
        "                            grad_before_Medicare.append(overall_grad_norm)\n",
        "                            grad_after_Medicare.append(overall_grad_norm_summed_noisy)\n",
        "                            ratio_Medicare.append(overall_grad_norm_summed_noisy/overall_grad_norm)\n",
        "                            loss_Medicare.append(avg_eval_loss)\n",
        "\n",
        "                        elif group=='Private':\n",
        "                            accuracy_Private.append(eval_metrics['accuracy'])\n",
        "                            f1score_Private.append(eval_metrics['f1'])\n",
        "                            grad_before_Private.append(overall_grad_norm)\n",
        "                            grad_after_Private.append(overall_grad_norm_summed_noisy)\n",
        "                            ratio_Private.append(overall_grad_norm_summed_noisy/overall_grad_norm)\n",
        "                            loss_private.append(avg_eval_loss)\n",
        "\n",
        "\n",
        "                        elif group=='Medicaid':\n",
        "                            accuracy_Medicaid.append(eval_metrics['accuracy'])\n",
        "                            f1score_Medicaid.append(eval_metrics['f1'])\n",
        "                            grad_before_Medicaid.append(overall_grad_norm)\n",
        "                            grad_after_Medicaid.append(overall_grad_norm_summed_noisy)\n",
        "                            ratio_Medicaid.append(overall_grad_norm_summed_noisy/overall_grad_norm)\n",
        "                            loss_Medicaid.append(avg_eval_loss)\n",
        "\n",
        "                        elif group=='Government':\n",
        "                            accuracy_Government.append(eval_metrics['accuracy'])\n",
        "                            f1score_Government.append(eval_metrics['f1'])\n",
        "                            grad_before_Government.append(overall_grad_norm)\n",
        "                            grad_after_Government.append(overall_grad_norm_summed_noisy)\n",
        "                            ratio_Government.append(overall_grad_norm_summed_noisy/overall_grad_norm)\n",
        "                            loss_Government.append(avg_eval_loss)\n",
        "\n",
        "                        elif group=='SelfPay':\n",
        "                            accuracy_SelfPay.append(eval_metrics['accuracy'])\n",
        "                            f1score_SelfPay.append(eval_metrics['f1'])\n",
        "                            grad_before_SelfPay.append(overall_grad_norm)\n",
        "                            grad_after_SelfPay.append(overall_grad_norm_summed_noisy)\n",
        "                            ratio_SelfPay.append(overall_grad_norm_summed_noisy/overall_grad_norm)\n",
        "                            loss_SelfPay.append(avg_eval_loss)\n",
        "\n",
        "\n",
        "                    elif groupname=='age_bins':\n",
        "\n",
        "                        wandb.log({\n",
        "                            f\"eval_loss_{group}\": avg_eval_loss,\n",
        "                            f\"eval_accuracy_{group}\": eval_metrics['accuracy'],\n",
        "                            f\"eval_precision_{group}\": eval_metrics['precision'],\n",
        "                            f\"eval_recall_{group}\": eval_metrics['recall'],\n",
        "                            f\"eval_f1_{group}\": eval_metrics['f1'],\n",
        "                            f\"logical batch grad_norm before clipping: {group}\": overall_grad_norm,\n",
        "                            f\"logical batch grad norm after clipping with Noise: {group}\": overall_grad_norm_summed_noisy,\n",
        "                            f\"overal noise: grad - summed_gradg: {group}\": overal_noise_grad,\n",
        "                            f\"summed_grad/noise grad: {group}\":(overall_grad_norm_summed_noisy/overal_noise_grad),\n",
        "                            f\"norm after clipping / norm before clipping: {group} \" :overall_grad_norm_summed_noisy/overall_grad_norm,\n",
        "                            f\"logical_batch_cosine_similarity: {group}\": avg_cosine_similarity,\n",
        "                            f\"logical_batch_angle_degrees: {group}\": avg_angle,\n",
        "                            f\"Cosine Similarity_{group}\": cosine_similarity,\n",
        "                            f\"Angle in Degrees_{group}\": angle\n",
        "                        })\n",
        "\n",
        "                        print(f\"eval_loss_{group}: {avg_eval_loss}\")\n",
        "                        print(f\"eval_accuracy_{group}: {eval_metrics['accuracy']}\")\n",
        "                        print(f\"eval_precision_{group}: {eval_metrics['precision']}\")\n",
        "                        print(f\"eval_recall_{group}: { eval_metrics['recall']}\")\n",
        "                        print(f\"eval_f1_{group}: {eval_metrics['f1']}\")\n",
        "                        print(f\"logical_batch_cosine_similarity{group}: {avg_cosine_similarity}\")\n",
        "                        print(f\"logical_batch_angle_degrees{group}: {avg_angle}\")\n",
        "\n",
        "                        if group=='Senior':\n",
        "                            accuracy_Senior.append(eval_metrics['accuracy'])\n",
        "                            f1score_Senior.append(eval_metrics['f1'])\n",
        "                            grad_before_Senior.append(overall_grad_norm)\n",
        "                            grad_after_Senior.append(overall_grad_norm_summed_noisy)\n",
        "                            ratio_Senior.append(overall_grad_norm_summed_noisy/overall_grad_norm)\n",
        "\n",
        "                        elif group=='Middle_aged':\n",
        "                            accuracy_Middle_aged.append(eval_metrics['accuracy'])\n",
        "                            f1score_Middle_aged.append(eval_metrics['f1'])\n",
        "                            grad_before_Middle_aged.append(overall_grad_norm)\n",
        "                            grad_after_Middle_aged.append(overall_grad_norm_summed_noisy)\n",
        "                            ratio_Middle_aged.append(overall_grad_norm_summed_noisy/overall_grad_norm)\n",
        "\n",
        "                        elif group=='Young_Adult':\n",
        "                            accuracy_Young_Adult.append(eval_metrics['accuracy'])\n",
        "                            f1score_Young_Adult.append(eval_metrics['f1'])\n",
        "                            grad_before_Young_Adult.append(overall_grad_norm)\n",
        "                            grad_after_Young_Adult.append(overall_grad_norm_summed_noisy)\n",
        "                            ratio_Young_Adult.append(overall_grad_norm_summed_noisy/overall_grad_norm)\n",
        "\n",
        "                        elif group=='Elderly':\n",
        "                            accuracy_Elderly.append(eval_metrics['accuracy'])\n",
        "                            f1score_Elderly.append(eval_metrics['f1'])\n",
        "                            grad_before_Elderly.append(overall_grad_norm)\n",
        "                            grad_after_Elderly.append(overall_grad_norm_summed_noisy)\n",
        "                            ratio_Elderly.append(overall_grad_norm_summed_noisy/overall_grad_norm)\n",
        "\n",
        "                        elif group=='Adult':\n",
        "                            accuracy_Adult.append(eval_metrics['accuracy'])\n",
        "                            f1score_Adult.append(eval_metrics['f1'])\n",
        "                            grad_before_Adult.append(overall_grad_norm)\n",
        "                            grad_after_Adult.append(overall_grad_norm_summed_noisy)\n",
        "                            ratio_Adult.append(overall_grad_norm_summed_noisy/overall_grad_norm)\n",
        "\n",
        "                        elif group=='Child':\n",
        "                            accuracy_Child.append(eval_metrics['accuracy'])\n",
        "                            f1score_Child.append(eval_metrics['f1'])\n",
        "                            grad_before_Child.append(overall_grad_norm)\n",
        "                            grad_after_Child.append(overall_grad_norm_summed_noisy)\n",
        "                            ratio_Child.append(overall_grad_norm_summed_noisy/overall_grad_norm)\n",
        "\n",
        "\n",
        "                    all_logical_labels = []\n",
        "                    all_logical_predictions = []\n",
        "                    grad_list = []  # Reset grad_list\n",
        "                    layer_wise_grad_samples = {}\n",
        "\n",
        "\n",
        "                optimizer.zero_grad()\n",
        "                count += 1\n",
        "\n",
        "            if groupname=='age_median':\n",
        "\n",
        "                if group == 'below_median':\n",
        "                    avg_grad_norm_before_below_median = sum(grad_before_below_median)/len(grad_before_below_median)\n",
        "                    avg_grad_norm_after_below_median = sum(grad_after_below_median) / len(grad_after_below_median)\n",
        "                    avg_accuracy_below_median = sum(accuracy_below_median) / len(accuracy_below_median)\n",
        "                    avg_f1_below_median = sum(f1score_below_median) / len(f1score_below_median)\n",
        "                    avg_loss_below_median = sum(loss_below_median) / len(loss_below_median)\n",
        "                    print(f\"Average Gradient Norm Before Clipping {group}: {avg_grad_norm_before_below_median}\")\n",
        "                    print(f\"Average Gradient Norm After Clipping {group}: {avg_grad_norm_after_below_median}\")\n",
        "                    print(f\"Average Accuracy for {group}: {avg_accuracy_below_median}\")\n",
        "                    print(f\"Average f1score for {group}: {avg_f1_below_median}\")\n",
        "                    print(f\"Average loss for {group}: {avg_loss_below_median}\")\n",
        "\n",
        "                elif group == 'above_median':\n",
        "                    avg_grad_norm_before_above_median = sum(grad_before_above_median)/len(grad_before_above_median)\n",
        "                    avg_grad_norm_after_above_median = sum(grad_after_above_median) / len(grad_after_above_median)\n",
        "                    avg_accuracy_above_median = sum(accuracy_above_median) / len(accuracy_above_median)\n",
        "                    avg_f1_above_median = sum(f1score_above_median) / len(f1score_above_median)\n",
        "                    avg_loss_above_median = sum(loss_above_median) / len(loss_above_median)\n",
        "                    print(f\"Average Gradient Norm Before Clipping {group}: {avg_grad_norm_before_above_median}\")\n",
        "                    print(f\"Average Gradient Norm After Clipping {group}: {avg_grad_norm_after_above_median}\")\n",
        "                    print(f\"Average Accuracy for {group}: {avg_accuracy_above_median}\")\n",
        "                    print(f\"Average f1score for {group}: {avg_f1_above_median}\")\n",
        "                    print(f\"Average loss for {group}: {avg_loss_above_median}\")\n",
        "\n",
        "            if groupname=='insurance':\n",
        "\n",
        "                if group == 'Medicare':\n",
        "                    avg_grad_norm_before_Medicare = sum(grad_before_Medicare)/len(grad_before_Medicare)\n",
        "                    avg_grad_norm_after_Medicare = sum(grad_after_Medicare) / len(grad_after_Medicare)\n",
        "                    avg_accuracy_Medicare = sum(accuracy_Medicare) / len(accuracy_Medicare)\n",
        "                    avg_f1_Medicare = sum(f1score_Medicare) / len(f1score_Medicare)\n",
        "\n",
        "                    print(f\"Average Gradient Norm Before Clipping {group}: {avg_grad_norm_before_Medicare}\")\n",
        "                    print(f\"Average Gradient Norm After Clipping {group}: {avg_grad_norm_after_Medicare}\")\n",
        "                    print(f\"Average Accuracy for {group}: {avg_accuracy_Medicare}\")\n",
        "                    print(f\"Average f1score for {group}: {avg_f1_Medicare}\")\n",
        "\n",
        "                elif group == 'Private':\n",
        "                    avg_grad_norm_before_Private = sum(grad_before_Private)/len(grad_before_Private)\n",
        "                    avg_grad_norm_after_Private = sum(grad_after_Private) / len(grad_after_Private)\n",
        "                    avg_accuracy_Private = sum(accuracy_Private) / len(accuracy_Private)\n",
        "                    avg_f1_Private = sum(f1score_Private) / len(f1score_Private)\n",
        "\n",
        "                    print(f\"Average Gradient Norm Before Clipping {group}: {avg_grad_norm_before_Private}\")\n",
        "                    print(f\"Average Gradient Norm After Clipping {group}: {avg_grad_norm_after_Private}\")\n",
        "                    print(f\"Average Accuracy for {group}: {avg_accuracy_Private}\")\n",
        "                    print(f\"Average f1score for {group}: {avg_f1_Private}\")\n",
        "\n",
        "                elif group == 'Medicaid':\n",
        "                    avg_grad_norm_before_Medicaid = sum(grad_before_Medicaid)/len(grad_before_Medicaid)\n",
        "                    avg_grad_norm_after_Medicaid = sum(grad_after_Medicaid) / len(grad_after_Medicaid)\n",
        "                    avg_accuracy_Medicaid = sum(accuracy_Medicaid) / len(accuracy_Medicaid)\n",
        "                    avg_f1_Medicaid = sum(f1score_Medicaid) / len(f1score_Medicaid)\n",
        "\n",
        "                    print(f\"Average Gradient Norm Before Clipping {group}: {avg_grad_norm_before_Medicaid}\")\n",
        "                    print(f\"Average Gradient Norm After Clipping {group}: {avg_grad_norm_after_Medicaid}\")\n",
        "                    print(f\"Average Accuracy for {group}: {avg_accuracy_Medicaid}\")\n",
        "                    print(f\"Average f1score for {group}: {avg_f1_Medicaid}\")\n",
        "\n",
        "                elif group == 'Government':\n",
        "                    avg_grad_norm_before_Government = sum(grad_before_Government)/len(grad_before_Government)\n",
        "                    avg_grad_norm_after_Government = sum(grad_after_Government) / len(grad_after_Government)\n",
        "                    avg_accuracy_Government = sum(accuracy_Government) / len(accuracy_Government)\n",
        "                    avg_f1_Government = sum(f1score_Government) / len(f1score_Government)\n",
        "\n",
        "                    print(f\"Average Gradient Norm Before Clipping {group}: {avg_grad_norm_before_Government}\")\n",
        "                    print(f\"Average Gradient Norm After Clipping {group}: {avg_grad_norm_after_Government}\")\n",
        "                    print(f\"Average Accuracy for {group}: {avg_accuracy_Government}\")\n",
        "                    print(f\"Average f1score for {group}: {avg_f1_Government}\")\n",
        "\n",
        "                elif group == 'SelfPay':\n",
        "                    avg_grad_norm_before_SelfPay = sum(grad_before_SelfPay)/len(grad_before_SelfPay)\n",
        "                    avg_grad_norm_after_SelfPay = sum(grad_after_SelfPay) / len(grad_after_SelfPay)\n",
        "                    avg_accuracy_SelfPay = sum(accuracy_SelfPay) / len(accuracy_SelfPay)\n",
        "                    avg_f1_SelfPay = sum(f1score_SelfPay) / len(f1score_SelfPay)\n",
        "\n",
        "                    print(f\"Average Gradient Norm Before Clipping {group}: {avg_grad_norm_before_SelfPay}\")\n",
        "                    print(f\"Average Gradient Norm After Clipping {group}: {avg_grad_norm_after_SelfPay}\")\n",
        "                    print(f\"Average Accuracy for {group}: {avg_accuracy_SelfPay}\")\n",
        "                    print(f\"Average f1score for {group}: {avg_f1_SelfPay}\")\n",
        "\n",
        "            if groupname=='age_bins':\n",
        "\n",
        "                if group == 'Senior':\n",
        "                    avg_grad_norm_before_Senior = sum(grad_before_Senior)/len(grad_before_Senior)\n",
        "                    avg_grad_norm_after_Senior = sum(grad_after_Senior) / len(grad_after_Senior)\n",
        "                    avg_accuracy_Senior = sum(accuracy_Senior) / len(accuracy_Senior)\n",
        "                    avg_f1_Senior = sum(f1score_Senior) / len(f1score_Senior)\n",
        "\n",
        "                    print(f\"Average Gradient Norm Before Clipping {group}: {avg_grad_norm_before_Senior}\")\n",
        "                    print(f\"Average Gradient Norm After Clipping {group}: {avg_grad_norm_after_Senior}\")\n",
        "                    print(f\"Average Accuracy for {group}: {avg_accuracy_Senior}\")\n",
        "                    print(f\"Average f1score for {group}: {avg_f1_Senior}\")\n",
        "\n",
        "                elif group == 'Middle_aged':\n",
        "                    avg_grad_norm_before_Middle_aged = sum(grad_before_Middle_aged)/len(grad_before_Middle_aged)\n",
        "                    avg_grad_norm_after_Middle_aged = sum(grad_after_Middle_aged) / len(grad_after_Middle_aged)\n",
        "                    avg_accuracy_Middle_aged = sum(accuracy_Middle_aged) / len(accuracy_Middle_aged)\n",
        "                    avg_f1_Middle_aged = sum(f1score_Middle_aged) / len(f1score_Middle_aged)\n",
        "\n",
        "                    print(f\"Average Gradient Norm Before Clipping {group}: {avg_grad_norm_before_Middle_aged}\")\n",
        "                    print(f\"Average Gradient Norm After Clipping {group}: {avg_grad_norm_after_Middle_aged}\")\n",
        "                    print(f\"Average Accuracy for {group}: {avg_accuracy_Middle_aged}\")\n",
        "                    print(f\"Average f1score for {group}: {avg_f1_Middle_aged}\")\n",
        "\n",
        "                elif group == 'Elderly':\n",
        "                    avg_grad_norm_before_Elderly = sum(grad_before_Elderly)/len(grad_before_Elderly)\n",
        "                    avg_grad_norm_after_Elderly = sum(grad_after_Elderly) / len(grad_after_Elderly)\n",
        "                    avg_accuracy_Elderly = sum(accuracy_Elderly) / len(accuracy_Elderly)\n",
        "                    avg_f1_Elderly = sum(f1score_Elderly) / len(f1score_Elderly)\n",
        "\n",
        "                    print(f\"Average Gradient Norm Before Clipping {group}: {avg_grad_norm_before_Elderly}\")\n",
        "                    print(f\"Average Gradient Norm After Clipping {group}: {avg_grad_norm_after_Elderly}\")\n",
        "                    print(f\"Average Accuracy for {group}: {avg_accuracy_Elderly}\")\n",
        "                    print(f\"Average f1score for {group}: {avg_f1_Elderly}\")\n",
        "\n",
        "                elif group == 'Adult':\n",
        "                    avg_grad_norm_before_Adult = sum(grad_before_Adult)/len(grad_before_Adult)\n",
        "                    avg_grad_norm_after_Adult = sum(grad_after_Adult) / len(grad_after_Adult)\n",
        "                    avg_accuracy_Adult = sum(accuracy_Adult) / len(accuracy_Adult)\n",
        "                    avg_f1_Adult = sum(f1score_Adult) / len(f1score_Adult)\n",
        "\n",
        "                    print(f\"Average Gradient Norm Before Clipping {group}: {avg_grad_norm_before_Adult}\")\n",
        "                    print(f\"Average Gradient Norm After Clipping {group}: {avg_grad_norm_after_Adult}\")\n",
        "                    print(f\"Average Accuracy for {group}: {avg_accuracy_Adult}\")\n",
        "                    print(f\"Average f1score for {group}: {avg_f1_Adult}\")\n",
        "\n",
        "                elif group == 'Young_Adult':\n",
        "                    avg_grad_norm_before_Young_Adult= sum(grad_before_Young_Adult)/len(grad_before_Young_Adult)\n",
        "                    avg_grad_norm_after_Young_Adult= sum(grad_after_Young_Adult) / len(grad_after_Young_Adult)\n",
        "                    avg_accuracy_Young_Adult= sum(accuracy_Young_Adult) / len(accuracy_Young_Adult)\n",
        "                    avg_f1_Young_Adult = sum(f1score_Young_Adult) / len(f1score_Young_Adult)\n",
        "\n",
        "                    print(f\"Average Gradient Norm Before Clipping {group}: {avg_grad_norm_before_Young_Adult}\")\n",
        "                    print(f\"Average Gradient Norm After Clipping {group}: {avg_grad_norm_after_Young_Adult}\")\n",
        "                    print(f\"Average Accuracy for {group}: {avg_accuracy_Young_Adult}\")\n",
        "                    print(f\"Average f1score for {group}: {avg_f1_Young_Adult}\")\n",
        "\n",
        "                elif group == 'Child':\n",
        "                    avg_grad_norm_before_Child = sum(grad_before_Child)/len(grad_before_Child)\n",
        "                    avg_grad_norm_after_Child = sum(grad_after_Child) / len(grad_after_Child)\n",
        "                    avg_accuracy_Child = sum(accuracy_Child) / len(accuracy_Child)\n",
        "                    avg_f1_Child = sum(f1score_Child) / len(f1score_Child)\n",
        "\n",
        "                    print(f\"Average Gradient Norm Before Clipping {group}: {avg_grad_norm_before_Child}\")\n",
        "                    print(f\"Average Gradient Norm After Clipping {group}: {avg_grad_norm_after_Child}\")\n",
        "                    print(f\"Average Accuracy for {group}: {avg_accuracy_Child}\")\n",
        "                    print(f\"Average f1score for {group}: {avg_f1_Child}\")\n",
        "\n",
        "\n",
        "\n",
        "    # Compute gradient norms for the above_median and below_median datasets\n",
        "    print(\"Computing gradient norms for above_median dataset:\")\n",
        "    compute_gradient_norms(model, above_median_dataloader, optimizer,'above_median', groupname='age_median')\n",
        "\n",
        "    print(\"\\nComputing gradient norms for below_median dataset:\")\n",
        "    compute_gradient_norms(model, below_median_dataloader, optimizer,'below_median', groupname='age_median')\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "# Script entry point: python-fire turns `main` into a command-line interface,\n",
        "# so main's parameters can be passed as command-line arguments.\n",
        "# NOTE(review): inside a notebook kernel `__name__` is also '__main__', so this\n",
        "# runs whenever the cell is executed and Fire will read the kernel's own argv -- confirm.\n",
        "if __name__ == '__main__':\n",
        "    import fire\n",
        "    fire.Fire(main)"
      ],
      "metadata": {
        "id": "VXv0HRpamoik"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}