{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QG3oJYniwSUK"
      },
      "outputs": [],
      "source": [
        "# ====================================================================================================\n",
        "# User-configurable settings\n",
        "# ====================================================================================================\n",
        "\n",
        "# Random seed (match with training seed for checkpoint compatibility)\n",
        "SEED = 2025\n",
        "\n",
        "# Paths\n",
        "CACHE_DIR = \"/cache\"  #Cache directory\n",
        "DATA_DIR = \"/datasets/unpackedmimiciii/mimic3MultiGender\" #Data directory\n",
        "CHECKPOINT_PATH = \"/Checkpoints/MimicLOS/NEWSEED42-20010CheckpointsFairAdaptiveoptimizer/checkpoint_epoch_200.pt\"  #checkpoint directory\n",
        "TOKENIZER_PATH = \"/Checkpoints/MimicLOS/NEWSEED42-20010CheckpointsFairAdaptiveoptimizer/tokenizer_epoch_200\"  #Tokenizer directory\n",
        "MODEL_NAME_OR_PATH = \"medicalai/ClinicalBERT\"\n",
        "\n",
        "# Dataset files (relative to DATA_DIR)\n",
        "TRAIN_FILE = \"LOS_WEEKS_train.csv\"\n",
        "VAL_FILE   = \"LOS_WEEKS_val.csv\"\n",
        "TEST_FILE  = \"LOS_WEEKS_test.csv\"\n",
        "\n",
        "\n",
        "\n",
        "from torch.utils.data.dataset import Dataset\n",
        "from torch.utils.data import DataLoader\n",
        "from datasets import load_dataset\n",
        "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n",
        "import torch\n",
        "import wandb\n",
        "import numpy as np\n",
        "from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score\n",
        "import fire\n",
        "import opacus\n",
        "from opacus.utils.batch_memory_manager import BatchMemoryManager\n",
        "from opacus.accountants.utils import get_noise_multiplier\n",
        "import torch.nn.functional as F\n",
        "import os\n",
        "from opacus import PrivacyEngine\n",
        "import torch.optim as optim\n",
        "from typing import List, Union\n",
        "from opacus.optimizers import DPOptimizer\n",
        "from opacus.optimizers import AdaClipDPOptimizer, DPOptimizer\n",
        "from opacus.optimizers.optimizer import (\n",
        "    _check_processed_flag,\n",
        "    _generate_noise,\n",
        "    _mark_as_processed\n",
        ")\n",
        "\n",
        "\n",
        "class FairOptimizer(AdaClipDPOptimizer):\n",
        "    def __init__(self, *args, **kwargs):\n",
        "        super().__init__(*args,**kwargs)\n",
        "\n",
        "\n",
        "    def clip_and_accumulate(self):\n",
        "        per_param_norms = [\n",
        "            g.view(len(g), -1).norm(2, dim=-1) for g in self.grad_samples\n",
        "        ]\n",
        "        per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)\n",
        "        per_sample_clip_factor = torch.tanh(self.max_grad_norm / (per_sample_norms + 1e-6))  #the only change to make it fair with tanh\n",
        "\n",
        "        # the two lines below are the only changes\n",
        "        # relative to the parent DPOptimizer class.\n",
        "        self.sample_size += len(per_sample_clip_factor)\n",
        "        self.unclipped_num += (\n",
        "            len(per_sample_clip_factor) - (per_sample_clip_factor < 1).sum()\n",
        "        )\n",
        "\n",
        "        for p in self.params:\n",
        "            _check_processed_flag(p.grad_sample)\n",
        "            grad_sample = self._get_flat_grad_sample(p)\n",
        "            grad = torch.einsum(\"i,i...\", per_sample_clip_factor, grad_sample)\n",
        "\n",
        "            if p.summed_grad is not None:\n",
        "                p.summed_grad += grad\n",
        "            else:\n",
        "                p.summed_grad = grad\n",
        "\n",
        "            _mark_as_processed(p.grad_sample)\n",
        "\n",
        "\n",
        "class fairPrivacyEngine(PrivacyEngine):\n",
        "    def __init__(self, *args, **kwargs):\n",
        "        super().__init__(*args, **kwargs)\n",
        "\n",
        "    def _prepare_optimizer(\n",
        "        self,\n",
        "        *,\n",
        "        optimizer: optim.Optimizer,\n",
        "        noise_multiplier: float,\n",
        "        max_grad_norm: Union[float, List[float]],\n",
        "        expected_batch_size: int,\n",
        "        loss_reduction: str = \"mean\",\n",
        "        distributed: bool = False,\n",
        "        clipping: str = \"flat\",\n",
        "        noise_generator=None,\n",
        "        grad_sample_mode=\"hooks\",\n",
        "        **kwargs,\n",
        "    ) -> DPOptimizer:\n",
        "        if isinstance(optimizer, DPOptimizer):\n",
        "            optimizer = optimizer.original_optimizer\n",
        "\n",
        "        generator = None\n",
        "        if self.secure_mode:\n",
        "            generator = self.secure_rng\n",
        "        elif noise_generator is not None:\n",
        "            generator = noise_generator\n",
        "\n",
        "\n",
        "        new_fair = FairOptimizer(\n",
        "            optimizer=optimizer,\n",
        "            noise_multiplier=noise_multiplier,\n",
        "            max_grad_norm=max_grad_norm,\n",
        "            expected_batch_size=expected_batch_size,\n",
        "            loss_reduction=loss_reduction,\n",
        "            generator=generator,\n",
        "            secure_mode=self.secure_mode,\n",
        "            **kwargs,\n",
        "        )\n",
        "\n",
        "        return new_fair\n",
        "\n",
        "\n",
        "def main():\n",
        "\n",
        "    def seed_everything(seed):\n",
        "        \"\"\"Set the seed for reproducibility across all libraries, including GPU operations.\"\"\"\n",
        "        os.environ['PYTHONHASHSEED'] = str(seed)\n",
        "        np.random.seed(seed)\n",
        "        torch.manual_seed(seed)\n",
        "        torch.cuda.manual_seed(seed)\n",
        "        torch.cuda.manual_seed_all(seed)\n",
        "\n",
        "        # Ensure deterministic behavior in cuDNN\n",
        "        torch.backends.cudnn.deterministic = True\n",
        "        torch.backends.cudnn.benchmark = False  # Set to False for full determinism\n",
        "\n",
        "        print(f\"Global seed set to: {seed}\")\n",
        "\n",
        "    # SEED_LIST=(42 0 1234 2023 777 999 1337)\n",
        "    seed = SEED\n",
        "    seed_everything(seed)\n",
        "\n",
        "    # Load dataset\n",
        "    cache_dir = CACHE_DIR\n",
        "    data_dir = DATA_DIR\n",
        "    data_files = {\n",
        "        \"train\": f\"{data_dir}/LOS_WEEKS_train.csv\",\n",
        "        \"validation\": f\"{data_dir}/LOS_WEEKS_val.csv\",\n",
        "        \"test\": f\"{data_dir}/LOS_WEEKS_test.csv\",\n",
        "    }\n",
        "    dataset = load_dataset('csv', data_files=data_files, cache_dir=cache_dir)\n",
        "\n",
        "    class TextDataset(Dataset):\n",
        "        def __init__(self, encodings, labels):\n",
        "            self.encodings = encodings\n",
        "            self.labels = labels\n",
        "\n",
        "        def __getitem__(self, idx):\n",
        "            item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}\n",
        "            item['labels'] = torch.tensor(self.labels[idx])\n",
        "            return item\n",
        "\n",
        "        def __len__(self):\n",
        "            return len(self.labels)\n",
        "\n",
        "    def load_checkpoint(checkpoint_path, model_name_or_path,tokenizer_path,dataset):\n",
        "\n",
        "        checkpoint = torch.load(checkpoint_path,weights_only=False)\n",
        "        model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, num_labels=4)\n",
        "\n",
        "\n",
        "        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)\n",
        "\n",
        "        #Freeze layers\n",
        "        for name, parameter in model.named_parameters():\n",
        "            if \"classifier\" in name or \"layer.5.output_layer_norm\" in name or \"layer.5.ffn.lin2\" in name or \"layer.5.ffn.lin1\" in name or \"layer.5.sa_layer_norm\" in name or \"layer.5.attention.out_lin\" in name or \"layer.5.attention.v_lin\" in name or \"layer.5.attention.k_lin\" in name:\n",
        "                parameter.requires_grad = True\n",
        "            else:\n",
        "                parameter.requires_grad = False\n",
        "\n",
        "        model.to(\"cuda\")\n",
        "        start_epoch = checkpoint['epoch']\n",
        "        best_val_f1 = checkpoint['best_val_f1']\n",
        "\n",
        "        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)\n",
        "\n",
        "        def process_dataset(dataset_split):\n",
        "            dataset_split = dataset_split.select_columns([\"text\", \"los_label\"])\n",
        "\n",
        "            encodings = tokenizer(\n",
        "                dataset_split['text'],\n",
        "                truncation=True,\n",
        "                padding='max_length',\n",
        "                max_length=512,\n",
        "                return_tensors='pt'\n",
        "            )\n",
        "            labels = [int(label) for label in dataset_split['los_label']]\n",
        "            return TextDataset(encodings, labels)\n",
        "\n",
        "        g = torch.Generator()\n",
        "        g.manual_seed(seed)\n",
        "\n",
        "        train_dataset = process_dataset(dataset['train'])\n",
        "\n",
        "        train_dataloader = DataLoader(\n",
        "            train_dataset,\n",
        "            batch_size=500, #batch_size,\n",
        "            shuffle=True,\n",
        "            drop_last=True,\n",
        "            worker_init_fn=lambda worker_id: torch.manual_seed(seed + worker_id),  # Fixes worker shuffling\n",
        "            generator=g\n",
        "            )\n",
        "\n",
        "        model.train()\n",
        "\n",
        "        # Opacus setup\n",
        "        sigma = get_noise_multiplier(\n",
        "            target_epsilon=8,\n",
        "            target_delta=1e-5,\n",
        "            sample_rate=500 / len(dataset['train']),\n",
        "            epochs=200\n",
        "        )\n",
        "        print(sigma)\n",
        "\n",
        "        #z = sigma\n",
        "        m = 500  # users per round\n",
        "        sigma_b = m / 20        # from the paper\n",
        "\n",
        "        privacy_engine = fairPrivacyEngine()\n",
        "\n",
        "        model, optimizer, train_loader = privacy_engine.make_private(\n",
        "            module=model,\n",
        "            optimizer=optimizer,\n",
        "            clipping = 'adaptive',\n",
        "            data_loader=train_dataloader,\n",
        "            noise_multiplier=sigma,\n",
        "            max_grad_norm=0.1,\n",
        "            poisson_sampling=False, # poisson_sampling=True, If I was using drop_last = False in dataloader\n",
        "            loss_reduction=\"sum\", #sum\n",
        "            target_unclipped_quantile=0.5, # gamma = median\n",
        "            clipbound_learning_rate=0.2, # ηC\n",
        "            max_clipbound=2.0,  # upper limit on clipping norm\n",
        "            min_clipbound=0.1,  # lower limit\n",
        "            unclipped_num_std=sigma_b\n",
        "        )\n",
        "\n",
        "        print(f\"Loaded checkpoint from {checkpoint_path}, starting from epoch {start_epoch+1}\")\n",
        "\n",
        "        model.load_state_dict(checkpoint[\"model_state_dict\"], strict=False)  # strict=False allows missing keys\n",
        "        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
        "\n",
        "        # Extract LR from loaded optimizer\n",
        "        for param_group in optimizer.param_groups:\n",
        "            lr = param_group[\"lr\"]  # Extract the learning rate from the optimizer checkpoint\n",
        "\n",
        "        print(f\"Loaded learning rate from checkpoint: {lr}\")\n",
        "\n",
        "        return model, tokenizer, optimizer, start_epoch, best_val_f1\n",
        "\n",
        "    # Load a checkpoint, Define paths\n",
        "    checkpoint_path = CHECKPOINT_PATH  #\"/Checkpoints/MimicLOS/NEWSEED42-20010CheckpointsFairAdaptiveoptimizer/checkpoint_epoch_200.pt\"\n",
        "    model_name_or_path = MODEL_NAME_OR_PATH  #\"medicalai/ClinicalBERT\"\n",
        "    tokenizer_path = TOKENIZER_PATH  #\"/Checkpoints/MimicLOS/NEWSEED42-20010CheckpointsFairAdaptiveoptimizer/tokenizer_epoch_200\"\n",
        "\n",
        "\n",
        "    model, tokenizer, optimizer, start_epoch, best_val_f1 = load_checkpoint(checkpoint_path, model_name_or_path,tokenizer_path,dataset)\n",
        "\n",
        "\n",
        "\n",
        "    def filter_age_median(dataset_split, age_median_gp, num_samples=None):\n",
        "        \"\"\"Filter dataset by gender and tokenize.\"\"\"\n",
        "        dataset_split = dataset_split.filter(lambda example: example['age_median_gp'] == age_median_gp)\n",
        "\n",
        "        # Select only the first `num_samples` if provided (only for females)\n",
        "        if num_samples is not None:\n",
        "            dataset_split = dataset_split.select(range(min(num_samples, len(dataset_split))))\n",
        "\n",
        "        # Keep only relevant columns\n",
        "        dataset_split = dataset_split.select_columns([\"text\", \"los_label\"])\n",
        "\n",
        "        encodings = tokenizer(\n",
        "            dataset_split['text'],\n",
        "            truncation=True,\n",
        "            padding='max_length',\n",
        "            max_length=512,\n",
        "            return_tensors='pt'\n",
        "        )\n",
        "\n",
        "        labels = torch.tensor([int(label) for label in dataset_split['los_label']])\n",
        "        return TextDataset(encodings, labels)\n",
        "\n",
        "\n",
        "\n",
        "    # Create gender-specific dataloaders\n",
        "    below_median_dataset = filter_age_median(dataset['test'], age_median_gp=0) #below median = 0\n",
        "    above_median_dataset = filter_age_median(dataset['test'], age_median_gp=1) #male\n",
        "\n",
        "\n",
        "    below_median_dataloader = DataLoader(below_median_dataset, batch_size=500)\n",
        "    above_median_dataloader = DataLoader(above_median_dataset, batch_size=500)\n",
        "\n",
        "\n",
        "    def compute_metrics(predictions, labels):\n",
        "        predictions = np.argmax(predictions, axis=-1)\n",
        "        return {\n",
        "            \"accuracy\": accuracy_score(labels, predictions),\n",
        "            \"precision\": precision_score(labels, predictions, average='macro'),\n",
        "            \"recall\": recall_score(labels, predictions, average='macro'),\n",
        "            \"f1\": f1_score(labels, predictions, average='macro')\n",
        "        }\n",
        "\n",
        "    def compute_gradient_norms(model, dataloader, optimizer, group, groupname):\n",
        "\n",
        "        wandb.init(project=\"gradient_analysis_per_gender\", name=\"evaluation_run\")\n",
        "\n",
        "        model.train()  # Enable gradient tracking\n",
        "        total_grad_norm_before = 0.0\n",
        "        total_grad_norm_after = 0.0\n",
        "        count = 0\n",
        "        eval_loss = 0\n",
        "        total_micro_batches = 0\n",
        "        micro_batch_size = 40\n",
        "        grad_list = []  # Initialize grad_list\n",
        "        all_logical_labels = []\n",
        "        all_logical_predictions = []\n",
        "        layer_wise_grad_samples = {}\n",
        "        cosine_similarities = []\n",
        "        angles = []\n",
        "\n",
        "        accuracy_below_median = []\n",
        "        accuracy_above_median = []\n",
        "        accuracy_Medicare = []\n",
        "        accuracy_Private = []\n",
        "        accuracy_Medicaid = []\n",
        "        accuracy_Government= []\n",
        "        accuracy_SelfPay = []\n",
        "        accuracy_Senior = []\n",
        "        accuracy_Middle_aged = []\n",
        "        accuracy_Elderly = []\n",
        "        accuracy_Adult = []\n",
        "        accuracy_Young_Adult = []\n",
        "        accuracy_Child = []\n",
        "\n",
        "\n",
        "        grad_before_below_median = []\n",
        "        grad_before_above_median = []\n",
        "        grad_before_Medicare = []\n",
        "        grad_before_Private = []\n",
        "        grad_before_Medicaid = []\n",
        "        grad_before_Government = []\n",
        "        grad_before_SelfPay = []\n",
        "        grad_before_Senior = []\n",
        "        grad_before_Middle_aged = []\n",
        "        grad_before_Elderly = []\n",
        "        grad_before_Adult = []\n",
        "        grad_before_Young_Adult = []\n",
        "        grad_before_Child = []\n",
        "\n",
        "\n",
        "\n",
        "        grad_after_below_median = []\n",
        "        grad_after_above_median = []\n",
        "        grad_after_Medicare = []\n",
        "        grad_after_Private = []\n",
        "        grad_after_Medicaid = []\n",
        "        grad_after_Government = []\n",
        "        grad_after_SelfPay = []\n",
        "        grad_after_Senior = []\n",
        "        grad_after_Middle_aged = []\n",
        "        grad_after_Elderly = []\n",
        "        grad_after_Adult = []\n",
        "        grad_after_Young_Adult = []\n",
        "        grad_after_Child = []\n",
        "\n",
        "        f1score_below_median = []\n",
        "        f1score_above_median = []\n",
        "        f1score_Medicare = []\n",
        "        f1score_Private = []\n",
        "        f1score_Medicaid = []\n",
        "        f1score_Government = []\n",
        "        f1score_SelfPay = []\n",
        "        f1score_Senior = []\n",
        "        f1score_Middle_aged = []\n",
        "        f1score_Elderly = []\n",
        "        f1score_Adult = []\n",
        "        f1score_Young_Adult = []\n",
        "        f1score_Child = []\n",
        "\n",
        "        ratio_below_median = []\n",
        "        ratio_above_median = []\n",
        "        ratio_Medicare = []\n",
        "        ratio_Private = []\n",
        "        ratio_Medicaid = []\n",
        "        ratio_Government = []\n",
        "        ratio_SelfPay = []\n",
        "        ratio_Senior = []\n",
        "        ratio_Middle_aged = []\n",
        "        ratio_Elderly = []\n",
        "        ratio_Adult = []\n",
        "        ratio_Young_Adult = []\n",
        "        ratio_Child = []\n",
        "\n",
        "        loss_below_median = []\n",
        "        loss_above_median = []\n",
        "        loss_Medicare = []\n",
        "        loss_private = []\n",
        "        loss_Medicaid = []\n",
        "        loss_Government = []\n",
        "        loss_SelfPay = []\n",
        "\n",
        "        #with torch.no_grad():\n",
        "        with BatchMemoryManager(\n",
        "                data_loader=dataloader,\n",
        "                max_physical_batch_size=micro_batch_size,\n",
        "                optimizer=optimizer\n",
        "            ) as memory_safe_data_loader:\n",
        "\n",
        "            for batch in dataloader:\n",
        "\n",
        "                total_micro_batches += 1\n",
        "\n",
        "                optimizer.zero_grad()\n",
        "                batch = {k: v.to('cuda') for k, v in batch.items()}\n",
        "                outputs = model(**batch)\n",
        "                loss_fn = torch.nn.CrossEntropyLoss(reduction=\"sum\")\n",
        "                loss = loss_fn(outputs.logits.view(-1, outputs.logits.size(-1)), batch['labels'].view(-1))\n",
        "                eval_loss += loss.item()\n",
        "                logits = outputs.logits\n",
        "                labels = batch['labels']\n",
        "\n",
        "                # Compute and log metrics every batch\n",
        "                predictions = logits.detach().cpu().numpy()\n",
        "                labels_cpu = labels.detach().cpu().numpy()\n",
        "                all_logical_predictions.extend(predictions)\n",
        "                all_logical_labels.extend(labels_cpu)\n",
        "\n",
        "                # Backward pass\n",
        "                loss.backward()\n",
        "\n",
        "                # Compute gradient norm before clipping\n",
        "                for name, p in model.named_parameters():\n",
        "                    if p.requires_grad:\n",
        "                        with torch.no_grad():\n",
        "                            grad_sample = p.grad_sample.sum(dim = 0).clone().detach()\n",
        "\n",
        "                            if name not in layer_wise_grad_samples:\n",
        "                                layer_wise_grad_samples[name] = torch.zeros_like(grad_sample)  # Store per-layer name\n",
        "\n",
        "                            layer_wise_grad_samples[name] += grad_sample  # Accumulate across physical batches\n",
        "\n",
        "                is_final_step = optimizer.pre_step()\n",
        "\n",
        "                if is_final_step:\n",
        "                    total_grad_norm_squared = 0.0\n",
        "                    total_grad_norm_squared_summed_noisy = 0.0\n",
        "                    total_noise_grad = 0.0\n",
        "                    all_layer_grads_before=[]\n",
        "                    all_layer_grads_after=[]\n",
        "\n",
        "                    for name, grad_list in layer_wise_grad_samples.items():\n",
        "                        all_layer_grads_before.append(grad_list)  # Store per-layer gradients\n",
        "                        ## Compute the L2 norm squared\n",
        "                        total_grad_norm_squared += torch.norm(grad_list, p=2) ** 2\n",
        "\n",
        "\n",
        "\n",
        "                    overall_grad_norm = total_grad_norm_squared ** 0.5\n",
        "\n",
        "                    for name, p in model.named_parameters():\n",
        "                        if p.requires_grad:\n",
        "                            with torch.no_grad():\n",
        "                                all_layer_grads_after.append(p.summed_grad.clone().detach())\n",
        "\n",
        "                                layer_grad_norm_summed_noisy = torch.norm(p.summed_grad, p=2).item()\n",
        "                                total_grad_norm_squared_summed_noisy += layer_grad_norm_summed_noisy ** 2\n",
        "\n",
        "                                grad = torch.norm(p.grad, p=2).item()\n",
        "                                total_noise_grad += (grad - layer_grad_norm_summed_noisy) ** 2\n",
        "\n",
        "\n",
        "                    # Convert lists to single tensors\n",
        "                    total_grad_before = torch.cat([v.view(-1) for v in all_layer_grads_before])\n",
        "                    total_grad_after = torch.cat([v.view(-1) for v in all_layer_grads_after])  # Aggregate all layers\n",
        "\n",
        "\n",
        "                    # Compute Cosine Similarity\n",
        "                    cosine_similarity = torch.nn.functional.cosine_similarity(\n",
        "                        total_grad_before.view(1, -1),\n",
        "                        total_grad_after.view(1, -1),\n",
        "                        dim=1\n",
        "                    ).item()\n",
        "\n",
        "                    # Compute Angle in Degrees\n",
        "                    angle = torch.acos(torch.clamp(torch.tensor(cosine_similarity), -1.0, 1.0)).item() * (180 / torch.pi)\n",
        "\n",
        "                    # Store values for logging\n",
        "                    cosine_similarities.append(cosine_similarity)\n",
        "                    angles.append(angle)\n",
        "\n",
        "                    #print(f\"Cosine Similarity: {cosine_similarity}\")\n",
        "                    #print(f\"Angle in Degrees: {angle}\")\n",
        "\n",
        "                    # Compute overall gradient norm\n",
        "                    overall_grad_norm = total_grad_norm_squared ** 0.5\n",
        "                    overall_grad_norm_summed_noisy = total_grad_norm_squared_summed_noisy ** 0.5\n",
        "                    overal_noise_grad = total_noise_grad ** 0.5\n",
        "\n",
        "                    #print(f\"logical batch grad_norm before clipping: {overall_grad_norm}\")\n",
        "                    #print(f\"logical batch grad norm after clipping with Noise: {overall_grad_norm_summed_noisy}\")\n",
        "                    #print(f\"overal noise: grad - summed_grad: {overal_noise_grad}\")\n",
        "                    #print(f\"summed_grad/noise grad: {overall_grad_norm_summed_noisy/overal_noise_grad}\")\n",
        "\n",
        "                    # Compute metrics after all physical batches in the logical batch are processed\n",
        "                    eval_metrics = compute_metrics(np.array(all_logical_predictions), np.array(all_logical_labels))\n",
        "                    avg_eval_loss = eval_loss / total_micro_batches\n",
        "\n",
        "                    # Compute the average similarity and angle across all layers\n",
        "                    avg_cosine_similarity = sum(cosine_similarities) / len(cosine_similarities)\n",
        "                    avg_angle = sum(angles) / len(angles)\n",
        "\n",
        "                    if groupname=='age_median':\n",
        "                    # Log to wandb\n",
        "                        wandb.log({\n",
        "                            f\"eval_loss_{group}\": avg_eval_loss,\n",
        "                            f\"eval_accuracy_{group}\": eval_metrics['accuracy'],\n",
        "                            f\"eval_precision_{group}\": eval_metrics['precision'],\n",
        "                            f\"eval_recall_{group}\": eval_metrics['recall'],\n",
        "                            f\"eval_f1_{group}\": eval_metrics['f1'],\n",
        "                            f\"logical batch grad_norm before clipping: {group}\": overall_grad_norm,\n",
        "                            f\"logical batch grad norm after clipping with Noise: {group}\": overall_grad_norm_summed_noisy,\n",
        "                            f\"overal noise: grad - summed_gradg: {group}\": overal_noise_grad,\n",
        "                            f\"summed_grad/noise grad: {group}\":(overall_grad_norm_summed_noisy/overal_noise_grad),\n",
        "                            f\"norm after clipping / norm before clipping: {group} \" :overall_grad_norm_summed_noisy/overall_grad_norm,\n",
        "                            f\"logical_batch_cosine_similarity: {group}\": avg_cosine_similarity,\n",
        "                            f\"logical_batch_angle_degrees: {group}\": avg_angle,\n",
        "                            f\"Cosine Similarity_{group}\": cosine_similarity,\n",
        "                            f\"Angle in Degrees_{group}\": angle\n",
        "                        })\n",
        "\n",
        "                        #print(f\"eval_loss_{group}: {avg_eval_loss}\")\n",
        "                        #print(f\"eval_accuracy_{group}: {eval_metrics['accuracy']}\")\n",
        "                        #print(f\"eval_precision_{group}: {eval_metrics['precision']}\")\n",
        "                        #print(f\"eval_recall_{group}: { eval_metrics['recall']}\")\n",
        "                        #print(f\"eval_f1_{group}: {eval_metrics['f1']}\")\n",
        "                        #print(f\"logical_batch_cosine_similarity{group}: {avg_cosine_similarity}\")\n",
        "                        #print(f\"logical_batch_angle_degrees{group}: {avg_angle}\")\n",
        "\n",
        "                        if group=='below_median':\n",
        "                            accuracy_below_median.append(eval_metrics['accuracy'])\n",
        "                            f1score_below_median.append(eval_metrics['f1'])\n",
        "                            grad_before_below_median.append(overall_grad_norm)\n",
        "                            grad_after_below_median.append(overall_grad_norm_summed_noisy)\n",
        "                            ratio_below_median.append(overall_grad_norm_summed_noisy/overall_grad_norm)\n",
        "                            loss_below_median.append(avg_eval_loss)\n",
        "\n",
        "                        elif group=='above_median':\n",
        "                            accuracy_above_median.append(eval_metrics['accuracy'])\n",
        "                            f1score_above_median.append(eval_metrics['f1'])\n",
        "                            grad_before_above_median.append(overall_grad_norm)\n",
        "                            grad_after_above_median.append(overall_grad_norm_summed_noisy)\n",
        "                            ratio_above_median.append(overall_grad_norm_summed_noisy/overall_grad_norm)\n",
        "                            loss_above_median.append(avg_eval_loss)\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "                    all_logical_labels = []\n",
        "                    all_logical_predictions = []\n",
        "                    grad_list = []  # Reset grad_list\n",
        "                    layer_wise_grad_samples = {}\n",
        "\n",
        "\n",
        "                optimizer.zero_grad()\n",
        "                count += 1\n",
        "\n",
        "            if groupname=='age_median':\n",
        "\n",
        "                if group == 'below_median':\n",
        "                    avg_grad_norm_before_below_median = sum(grad_before_below_median)/len(grad_before_below_median)\n",
        "                    avg_grad_norm_after_below_median = sum(grad_after_below_median) / len(grad_after_below_median)\n",
        "                    avg_accuracy_below_median = sum(accuracy_below_median) / len(accuracy_below_median)\n",
        "                    avg_f1_below_median = sum(f1score_below_median) / len(f1score_below_median)\n",
        "                    avg_loss_below_median = sum(loss_below_median) / len(loss_below_median)\n",
        "                    print(f\"Average Gradient Norm Before Clipping {group}: {avg_grad_norm_before_below_median}\")\n",
        "                    print(f\"Average Gradient Norm After Clipping {group}: {avg_grad_norm_after_below_median}\")\n",
        "                    print(f\"Average Accuracy for {group}: {avg_accuracy_below_median}\")\n",
        "                    print(f\"Average f1score for {group}: {avg_f1_below_median}\")\n",
        "                    print(f\"Average loss for {group}: {avg_loss_below_median}\")\n",
        "\n",
        "                elif group == 'above_median':\n",
        "                    avg_grad_norm_before_above_median = sum(grad_before_above_median)/len(grad_before_above_median)\n",
        "                    avg_grad_norm_after_above_median = sum(grad_after_above_median) / len(grad_after_above_median)\n",
        "                    avg_accuracy_above_median = sum(accuracy_above_median) / len(accuracy_above_median)\n",
        "                    avg_f1_above_median = sum(f1score_above_median) / len(f1score_above_median)\n",
        "                    avg_loss_above_median = sum(loss_above_median) / len(loss_above_median)\n",
        "                    print(f\"Average Gradient Norm Before Clipping {group}: {avg_grad_norm_before_above_median}\")\n",
        "                    print(f\"Average Gradient Norm After Clipping {group}: {avg_grad_norm_after_above_median}\")\n",
        "                    print(f\"Average Accuracy for {group}: {avg_accuracy_above_median}\")\n",
        "                    print(f\"Average f1score for {group}: {avg_f1_above_median}\")\n",
        "                    print(f\"Average loss for {group}: {avg_loss_above_median}\")\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "    # Compute gradient norms for above_median and male datasets\n",
        "    print(\"Computing gradient norms for above_median dataset:\")\n",
        "    compute_gradient_norms(model, above_median_dataloader, optimizer,'above_median', groupname='age_median')\n",
        "\n",
        "    print(\"\\nComputing gradient norms for below_median dataset:\")\n",
        "    compute_gradient_norms(model, below_median_dataloader, optimizer,'below_median', groupname='age_median')\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "if __name__ == '__main__':\n",
        "    import fire\n",
        "    fire.Fire(main)"
      ]
    }
  ]
}