{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SSYeyifq4aLw"
      },
      "outputs": [],
      "source": [
        "from opacus.optimizers import DPOptimizer\n",
        "from opacus.optimizers.optimizer import (\n",
        "    _check_processed_flag,\n",
        "    _generate_noise,\n",
        "    _mark_as_processed\n",
        ")\n",
        "\n",
        "class NormalFairOptimizer(DPOptimizer):\n",
        "    def __init__(self, *args, **kwargs):\n",
        "        super().__init__(*args,**kwargs)\n",
        "\n",
        "\n",
        "    def clip_and_accumulate(self):\n",
        "        \"\"\"\n",
        "        Performs gradient clipping.\n",
        "        Stores clipped and aggregated gradients into `p.summed_grad```\n",
        "        \"\"\"\n",
        "\n",
        "        if len(self.grad_samples[0]) == 0:\n",
        "            # Empty batch\n",
        "            per_sample_clip_factor = torch.zeros(\n",
        "                (0,), device=self.grad_samples[0].device\n",
        "            )\n",
        "        else:\n",
        "            per_param_norms = [\n",
        "                g.reshape(len(g), -1).norm(2, dim=-1) for g in self.grad_samples\n",
        "            ]\n",
        "            per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)\n",
        "            per_sample_clip_factor = torch.tanh(self.max_grad_norm / (per_sample_norms + 1e-6))  #the only change to make it fair with tanh\n",
        "\n",
        "        for p in self.params:\n",
        "            _check_processed_flag(p.grad_sample)\n",
        "            grad_sample = self._get_flat_grad_sample(p)\n",
        "            grad = torch.einsum(\"i,i...\", per_sample_clip_factor, grad_sample)\n",
        "\n",
        "            if p.summed_grad is not None:\n",
        "                p.summed_grad += grad\n",
        "            else:\n",
        "                p.summed_grad = grad\n",
        "\n",
        "            _mark_as_processed(p.grad_sample)\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from opacus import PrivacyEngine\n",
        "import torch.optim as optim\n",
        "from typing import List, Union\n",
        "from opacus.optimizers import DPOptimizer\n",
        "\n",
        "class NormalfairPrivacyEngine(PrivacyEngine):\n",
        "    def __init__(self, *args, **kwargs):\n",
        "        super().__init__(*args, **kwargs)\n",
        "\n",
        "    def _prepare_optimizer(\n",
        "        self,\n",
        "        *,\n",
        "        optimizer: optim.Optimizer,\n",
        "        noise_multiplier: float,\n",
        "        max_grad_norm: Union[float, List[float]],\n",
        "        expected_batch_size: int,\n",
        "        loss_reduction: str = \"mean\",\n",
        "        distributed: bool = False,\n",
        "        clipping: str = \"flat\",\n",
        "        noise_generator=None,\n",
        "        grad_sample_mode=\"hooks\",\n",
        "        **kwargs,\n",
        "    ) -> DPOptimizer:\n",
        "        if isinstance(optimizer, DPOptimizer):\n",
        "            optimizer = optimizer.original_optimizer\n",
        "\n",
        "        generator = None\n",
        "        if self.secure_mode:\n",
        "            generator = self.secure_rng\n",
        "        elif noise_generator is not None:\n",
        "            generator = noise_generator\n",
        "\n",
        "        new_fair = NormalFairOptimizer(\n",
        "            optimizer=optimizer,\n",
        "            noise_multiplier=noise_multiplier,\n",
        "            max_grad_norm=max_grad_norm,\n",
        "            expected_batch_size=expected_batch_size,\n",
        "            loss_reduction=loss_reduction,\n",
        "            generator=generator,\n",
        "            secure_mode=self.secure_mode,\n",
        "            **kwargs,\n",
        "        )\n",
        "\n",
        "        return new_fair\n",
        ""
      ],
      "metadata": {
        "id": "RzGGwGTy4gkR"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}