{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4",
      "collapsed_sections": [
        "jQdA66phYStq",
        "Q6Xk2HrvYV2E",
        "_nHtBAeNkzfS",
        "Qvf0KtmiL2H-",
        "hOAtK0XQ4CBn",
        "YakPqpgVX6en"
      ]
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "a995221ad90d4a1e8b2f292a2144836a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_39f80b41a398473aa0753fb1f9a85e05",
              "IPY_MODEL_f453a12668c44b70bd027828cb22d3f9",
              "IPY_MODEL_ab43eae1b2254714a4c6c2e8c5946b08"
            ],
            "layout": "IPY_MODEL_2f886b935c124a83a9e6bb5b4eb1318c"
          }
        },
        "39f80b41a398473aa0753fb1f9a85e05": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_77c6ad26324b47aebd59c01cdc5e9ffb",
            "placeholder": "​",
            "style": "IPY_MODEL_8abff24900964d7fb8808bc379f15715",
            "value": "README.md: "
          }
        },
        "f453a12668c44b70bd027828cb22d3f9": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_690f534690aa4360b6d25e29a45dbbe6",
            "max": 1,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_f4fdf8d55afc4f1d8778c97ad4c0065f",
            "value": 1
          }
        },
        "ab43eae1b2254714a4c6c2e8c5946b08": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_b34dec285b064459956df5cafff604dd",
            "placeholder": "​",
            "style": "IPY_MODEL_1b41483d67474afd8c9c489b719ccf1d",
            "value": " 6.09k/? [00:00&lt;00:00, 153kB/s]"
          }
        },
        "2f886b935c124a83a9e6bb5b4eb1318c": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "77c6ad26324b47aebd59c01cdc5e9ffb": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "8abff24900964d7fb8808bc379f15715": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "690f534690aa4360b6d25e29a45dbbe6": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "20px"
          }
        },
        "f4fdf8d55afc4f1d8778c97ad4c0065f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "b34dec285b064459956df5cafff604dd": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "1b41483d67474afd8c9c489b719ccf1d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "87a65870c8a2413a955fa7f6580d06dc": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_860a84b8925145bbbb03a5707a29fee2",
              "IPY_MODEL_2534511ff7bf4c26be3798f872332ecd",
              "IPY_MODEL_bcf1434503c64467b1933816504b56ff"
            ],
            "layout": "IPY_MODEL_5492c4d0bdf94c6cb11f3b5a2588ff6f"
          }
        },
        "860a84b8925145bbbb03a5707a29fee2": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_245a95fee2fc40bab018983aed796c94",
            "placeholder": "​",
            "style": "IPY_MODEL_af479a90137e442398fa4aed086cd156",
            "value": "data/train-00000-of-00001.parquet: 100%"
          }
        },
        "2534511ff7bf4c26be3798f872332ecd": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_e61976050df44f8d9ba15393b39472d1",
            "max": 200734290,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_7ca1b8c20fde4049973f521ce5721418",
            "value": 200734290
          }
        },
        "bcf1434503c64467b1933816504b56ff": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_35e5a2e314ea4e22ad64fef2097a66b1",
            "placeholder": "​",
            "style": "IPY_MODEL_756b21451452433599c8320de3e47129",
            "value": " 201M/201M [00:02&lt;00:00, 114MB/s]"
          }
        },
        "5492c4d0bdf94c6cb11f3b5a2588ff6f": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "245a95fee2fc40bab018983aed796c94": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "af479a90137e442398fa4aed086cd156": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "e61976050df44f8d9ba15393b39472d1": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "7ca1b8c20fde4049973f521ce5721418": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "35e5a2e314ea4e22ad64fef2097a66b1": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "756b21451452433599c8320de3e47129": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "ad3640ee50f0407ea4276796c2eb5ccc": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_5fb3c0b71298475db284141676bbd5f7",
              "IPY_MODEL_f1d8b4890d934be2a415725bf3f6b20b",
              "IPY_MODEL_6fe543e0e6354dbc800383abcfbdaf97"
            ],
            "layout": "IPY_MODEL_f0c814e6c640436781724cac8c502997"
          }
        },
        "5fb3c0b71298475db284141676bbd5f7": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_d164e8a25b15426daa6d358c066c3501",
            "placeholder": "​",
            "style": "IPY_MODEL_07666125af1543ada56d7c69fd3e09c6",
            "value": "Generating train split: 100%"
          }
        },
        "f1d8b4890d934be2a415725bf3f6b20b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_c6c95710fae848d28f636e9694ca53b1",
            "max": 814277,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_74c2d043833d48d9928cb24cc5992e38",
            "value": 814277
          }
        },
        "6fe543e0e6354dbc800383abcfbdaf97": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_60954716da3c4de689a229ee1124b3e3",
            "placeholder": "​",
            "style": "IPY_MODEL_ed0881322e2647bc8418f8ef738cd109",
            "value": " 814277/814277 [00:06&lt;00:00, 140969.79 examples/s]"
          }
        },
        "f0c814e6c640436781724cac8c502997": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d164e8a25b15426daa6d358c066c3501": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "07666125af1543ada56d7c69fd3e09c6": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "c6c95710fae848d28f636e9694ca53b1": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "74c2d043833d48d9928cb24cc5992e38": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "60954716da3c4de689a229ee1124b3e3": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "ed0881322e2647bc8418f8ef738cd109": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        }
      }
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Stylized Simulation"
      ],
      "metadata": {
        "id": "jQdA66phYStq"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "import pandas as pd\n",
        "from dataclasses import dataclass, asdict\n",
        "from typing import Dict, Any, Optional, List\n",
        "\n",
        "\n",
        "# ===================================================================\n",
        "# 0. Utilities\n",
        "# ===================================================================\n",
        "\n",
        "def sigmoid(x: np.ndarray) -> np.ndarray:\n",
        "    return 1.0 / (1.0 + np.exp(-x))\n",
        "\n",
        "\n",
        "# ===================================================================\n",
        "# 1. Simulation configuration\n",
        "# ===================================================================\n",
        "\n",
        "@dataclass\n",
        "class SimulationConfig:\n",
        "    # basic structure\n",
        "    n_clients: int = 100               # total number of clients\n",
        "    T: int = 300                       # number of rounds\n",
        "    gaming_frac: float = 0.3           # fraction of clients using gaming strategy\n",
        "\n",
        "    # reward-related\n",
        "    base_reward: float = 1.0           # scale from metric to reward\n",
        "    reward_bias: float = 0.2           # baseline reward independent of metric\n",
        "\n",
        "    # audit / penalty\n",
        "    audit_budget: int = 5              # maximum number of clients audited per round\n",
        "    alpha_penalty: float = 0.7         # penalty strength α when caught\n",
        "    p_detect: float = 0.7              # probability of detecting gaming when audited\n",
        "\n",
        "    # public / private metric mix\n",
        "    rho_pub: float = 0.6               # weight on public metric (PB vs PC mix)\n",
        "    noise_pub: float = 0.01            # noise on public metric\n",
        "    noise_priv: float = 0.01           # noise on private metric\n",
        "\n",
        "    # welfare impact\n",
        "    gamma_welfare: float = 0.7         # coefficient γ: how strongly gaming harms welfare\n",
        "\n",
        "    # participation cost / decision parameters\n",
        "    base_cost_mean: float = 0.4        # mean participation cost across clients\n",
        "    base_cost_std: float = 0.1         # std dev of participation cost\n",
        "    participation_beta: float = 4.0    # logistic slope (larger → sharper threshold)\n",
        "    participation_bias: float = 0.0    # logistic offset (baseline profit threshold)\n",
        "\n",
        "    seed: Optional[int] = None         # random seed (for reproducibility)\n",
        "\n",
        "\n",
        "# ===================================================================\n",
        "# 2. Client initialization\n",
        "# ===================================================================\n",
        "\n",
        "def init_clients(cfg: SimulationConfig) -> Dict[str, Any]:\n",
        "    \"\"\"\n",
        "    For each client, initialize:\n",
        "    - type: 'honest' or 'gaming'\n",
        "    - cost: participation cost\n",
        "    \"\"\"\n",
        "    rng = np.random.default_rng(cfg.seed)\n",
        "\n",
        "    n = cfg.n_clients\n",
        "    n_gaming = int(np.round(cfg.gaming_frac * n))\n",
        "\n",
        "    client_type = np.array(['honest'] * n, dtype=object)\n",
        "    if n_gaming > 0:\n",
        "        gaming_idx = rng.choice(n, size=n_gaming, replace=False)\n",
        "        client_type[gaming_idx] = 'gaming'\n",
        "\n",
        "    # participation cost: controlled by mean/std in config\n",
        "    cost = rng.normal(loc=cfg.base_cost_mean,\n",
        "                      scale=cfg.base_cost_std,\n",
        "                      size=n)\n",
        "    cost = np.clip(cost, 0.05, None)  # avoid negative or too small costs\n",
        "\n",
        "    return {\n",
        "        \"client_type\": client_type,\n",
        "        \"cost\": cost,\n",
        "        \"rng\": rng,\n",
        "    }\n",
        "\n",
        "\n",
        "# ===================================================================\n",
        "# 3. Single-round simulation\n",
        "# ===================================================================\n",
        "\n",
        "def simulate_round(\n",
        "    cfg: SimulationConfig,\n",
        "    state: Dict[str, Any],\n",
        "    M_prev: float\n",
        ") -> Dict[str, Any]:\n",
        "\n",
        "    client_type = state[\"client_type\"]\n",
        "    cost = state[\"cost\"]\n",
        "    rng = state[\"rng\"]\n",
        "\n",
        "    n = cfg.n_clients\n",
        "\n",
        "    # ---------------------------------------------------------------\n",
        "    # 3.1 Participation decision: logistic-based probabilistic join\n",
        "    # ---------------------------------------------------------------\n",
        "    approx_p_audit = cfg.audit_budget / max(n, 1)\n",
        "\n",
        "    expected_penalty_gaming = cfg.alpha_penalty * approx_p_audit * cfg.p_detect\n",
        "    expected_penalty_honest = 0.0\n",
        "\n",
        "    profit = np.zeros(n)\n",
        "\n",
        "    mask_g = (client_type == 'gaming')\n",
        "    mask_h = (client_type == 'honest')\n",
        "\n",
        "    profit[mask_g] = (\n",
        "        cfg.reward_bias\n",
        "        + cfg.base_reward * M_prev\n",
        "        - expected_penalty_gaming\n",
        "        - cost[mask_g]\n",
        "    )\n",
        "    profit[mask_h] = (\n",
        "        cfg.reward_bias\n",
        "        + cfg.base_reward * M_prev\n",
        "        - expected_penalty_honest\n",
        "        - cost[mask_h]\n",
        "    )\n",
        "\n",
        "    logits = cfg.participation_beta * (profit - cfg.participation_bias)\n",
        "    p_participate = sigmoid(logits)\n",
        "\n",
        "    participate = rng.random(size=n) < p_participate\n",
        "    participants_idx = np.where(participate)[0]\n",
        "    n_participants = len(participants_idx)\n",
        "\n",
        "    # ---------------------------------------------------------------\n",
        "    # 3.2 Number of honest / gaming participants and participation rate\n",
        "    # ---------------------------------------------------------------\n",
        "    if n_participants > 0:\n",
        "        types_participants = client_type[participants_idx]\n",
        "        H = int(np.sum(types_participants == 'honest'))\n",
        "        G = int(np.sum(types_participants == 'gaming'))\n",
        "    else:\n",
        "        H = 0\n",
        "        G = 0\n",
        "\n",
        "    x_t = n_participants / n\n",
        "\n",
        "    # ---------------------------------------------------------------\n",
        "    # 3.3 Welfare and metric computation\n",
        "    # ---------------------------------------------------------------\n",
        "    if n == 0:\n",
        "        W_t = 0.0\n",
        "    else:\n",
        "        W_t = (H - cfg.gamma_welfare * G) / n\n",
        "        W_t = float(np.clip(W_t, 0.0, 1.0))\n",
        "\n",
        "    delta_metric = 0.3\n",
        "    noise_pub = rng.normal(0.0, cfg.noise_pub)\n",
        "    noise_priv = rng.normal(0.0, cfg.noise_priv)\n",
        "\n",
        "    M_pub = W_t + delta_metric * (G / max(n, 1)) + noise_pub\n",
        "    M_priv = W_t + noise_priv\n",
        "\n",
        "    M_t = cfg.rho_pub * M_pub + (1.0 - cfg.rho_pub) * M_priv\n",
        "    M_t = float(np.clip(M_t, 0.0, 1.0))\n",
        "\n",
        "    # ---------------------------------------------------------------\n",
        "    # 3.4 Auditing and penalties\n",
        "    # ---------------------------------------------------------------\n",
        "    audited_idx = np.array([], dtype=int)\n",
        "    penalized_idx = np.array([], dtype=int)\n",
        "\n",
        "    if n_participants > 0 and cfg.audit_budget > 0:\n",
        "        k = min(cfg.audit_budget, n_participants)\n",
        "        audited_idx = rng.choice(participants_idx, size=k, replace=False)\n",
        "\n",
        "        is_gaming_audited = (client_type[audited_idx] == 'gaming')\n",
        "        detect_mask = rng.random(size=k) < (cfg.p_detect * is_gaming_audited.astype(float))\n",
        "        penalized_idx = audited_idx[detect_mask]\n",
        "\n",
        "    rewards = np.zeros(n)\n",
        "    if n_participants > 0:\n",
        "        rewards[participants_idx] = cfg.base_reward * M_t - cost[participants_idx]\n",
        "        rewards[penalized_idx] -= cfg.alpha_penalty\n",
        "\n",
        "    return {\n",
        "        \"W_t\": W_t,\n",
        "        \"M_t\": M_t,\n",
        "        \"x_t\": x_t,\n",
        "        \"H_t\": H,\n",
        "        \"G_t\": G,\n",
        "        \"n_participants\": n_participants,\n",
        "        \"audited_idx\": audited_idx,\n",
        "        \"penalized_idx\": penalized_idx,\n",
        "        \"rewards\": rewards,\n",
        "    }\n",
        "\n",
        "\n",
        "# ===================================================================\n",
        "# 4. Full simulation runner\n",
        "# ===================================================================\n",
        "\n",
        "def run_simulation(cfg: SimulationConfig) -> pd.DataFrame:\n",
        "    \"\"\"\n",
        "    Run T rounds of simulation under a single policy configuration (cfg)\n",
        "    and return a DataFrame of per-round logs.\n",
        "    Logged fields: W_t, M_t, x_t, H_t, G_t, n_participants, n_audited, n_penalized.\n",
        "    \"\"\"\n",
        "    state = init_clients(cfg)\n",
        "    M_prev = 0.6  # initial metric\n",
        "\n",
        "    logs: List[Dict[str, Any]] = []\n",
        "\n",
        "    for t in range(cfg.T):\n",
        "        round_result = simulate_round(cfg, state, M_prev)\n",
        "\n",
        "        logs.append({\n",
        "            \"t\": t,\n",
        "            \"W_t\": round_result[\"W_t\"],\n",
        "            \"M_t\": round_result[\"M_t\"],\n",
        "            \"x_t\": round_result[\"x_t\"],\n",
        "            \"H_t\": round_result[\"H_t\"],\n",
        "            \"G_t\": round_result[\"G_t\"],\n",
        "            \"n_participants\": round_result[\"n_participants\"],\n",
        "            \"n_audited\": len(round_result[\"audited_idx\"]),\n",
        "            \"n_penalized\": len(round_result[\"penalized_idx\"]),\n",
        "        })\n",
        "\n",
        "        M_prev = round_result[\"M_t\"]\n",
        "\n",
        "    df = pd.DataFrame(logs)\n",
        "    return df\n",
        "\n",
        "\n",
        "# ===================================================================\n",
        "# 5. Helper for PoG estimation (aligned vs gaming)\n",
        "# ===================================================================\n",
        "\n",
        "def estimate_price_of_gaming(\n",
        "    cfg_aligned: SimulationConfig,\n",
        "    cfg_gaming: SimulationConfig,\n",
        "    burn_in: int = 100\n",
        ") -> Dict[str, Any]:\n",
        "    \"\"\"\n",
        "    Compare two scenarios to approximate PoG:\n",
        "    - cfg_aligned: gaming_frac = 0\n",
        "    - cfg_gaming: gaming_frac > 0\n",
        "\n",
        "    We approximate steady state by averaging W_t, x_t, M_t after 'burn_in' rounds.\n",
        "    \"\"\"\n",
        "\n",
        "    df_align = run_simulation(cfg_aligned)\n",
        "    df_game = run_simulation(cfg_gaming)\n",
        "\n",
        "    mask_align = (df_align[\"t\"] >= burn_in)\n",
        "    mask_game = (df_game[\"t\"] >= burn_in)\n",
        "\n",
        "    W_align = df_align.loc[mask_align, \"W_t\"].mean()\n",
        "    W_game = df_game.loc[mask_game, \"W_t\"].mean()\n",
        "\n",
        "    x_align = df_align.loc[mask_align, \"x_t\"].mean()\n",
        "    x_game = df_game.loc[mask_game, \"x_t\"].mean()\n",
        "\n",
        "    M_align = df_align.loc[mask_align, \"M_t\"].mean()\n",
        "    M_game = df_game.loc[mask_game, \"M_t\"].mean()\n",
        "\n",
        "    if W_align <= 0:\n",
        "        PoG = np.nan\n",
        "    else:\n",
        "        PoG = (W_align - W_game) / W_align\n",
        "\n",
        "    return {\n",
        "        \"W_align\": W_align,\n",
        "        \"W_game\": W_game,\n",
        "        \"x_align\": x_align,\n",
        "        \"x_game\": x_game,\n",
        "        \"M_align\": M_align,\n",
        "        \"M_game\": M_game,\n",
        "        \"PoG\": PoG,\n",
        "        \"cfg_aligned\": asdict(cfg_aligned),\n",
        "        \"cfg_gaming\": asdict(cfg_gaming),\n",
        "        \"df_aligned\": df_align,\n",
        "        \"df_gaming\": df_game,\n",
        "    }\n",
        "\n",
        "\n",
        "# ===================================================================\n",
        "# 6. Experiment 1: sweep alpha_penalty vs PoG\n",
        "# ===================================================================\n",
        "\n",
        "def sweep_alpha(\n",
        "    base_cfg: SimulationConfig,\n",
        "    alphas: List[float],\n",
        "    burn_in: int = 100,\n",
        ") -> pd.DataFrame:\n",
        "    \"\"\"\n",
        "    Vary alpha_penalty and measure PoG and participation.\n",
        "    \"\"\"\n",
        "    rows = []\n",
        "\n",
        "    for alpha in alphas:\n",
        "        cfg_gaming = SimulationConfig(**{**asdict(base_cfg),\n",
        "                                         \"gaming_frac\": base_cfg.gaming_frac,\n",
        "                                         \"alpha_penalty\": alpha})\n",
        "        cfg_aligned = SimulationConfig(**{**asdict(base_cfg),\n",
        "                                          \"gaming_frac\": 0.0,\n",
        "                                          \"alpha_penalty\": alpha})\n",
        "\n",
        "        result = estimate_price_of_gaming(cfg_aligned, cfg_gaming, burn_in=burn_in)\n",
        "\n",
        "        rows.append({\n",
        "            \"alpha_penalty\": alpha,\n",
        "            \"W_align\": result[\"W_align\"],\n",
        "            \"W_game\": result[\"W_game\"],\n",
        "            \"x_align\": result[\"x_align\"],\n",
        "            \"x_game\": result[\"x_game\"],\n",
        "            \"M_align\": result[\"M_align\"],\n",
        "            \"M_game\": result[\"M_game\"],\n",
        "            \"PoG\": result[\"PoG\"],\n",
        "        })\n",
        "\n",
        "    return pd.DataFrame(rows)\n",
        "\n",
        "\n",
        "# ===================================================================\n",
        "# 7. Experiment 2: sweep rho_pub (public metric weight) vs PoG\n",
        "# ===================================================================\n",
        "\n",
        "def sweep_rho_pub(\n",
        "    base_cfg: SimulationConfig,\n",
        "    rhos: List[float],\n",
        "    burn_in: int = 100,\n",
        ") -> pd.DataFrame:\n",
        "    \"\"\"\n",
        "    Vary rho_pub (weight on public metric) and measure PoG and M-W gap.\n",
        "    \"\"\"\n",
        "    rows = []\n",
        "\n",
        "    for rho in rhos:\n",
        "        cfg_gaming = SimulationConfig(**{**asdict(base_cfg),\n",
        "                                         \"gaming_frac\": base_cfg.gaming_frac,\n",
        "                                         \"rho_pub\": rho})\n",
        "        cfg_aligned = SimulationConfig(**{**asdict(base_cfg),\n",
        "                                          \"gaming_frac\": 0.0,\n",
        "                                          \"rho_pub\": rho})\n",
        "\n",
        "        result = estimate_price_of_gaming(cfg_aligned, cfg_gaming, burn_in=burn_in)\n",
        "\n",
        "        # M - W gap (metric inflation) under each scenario\n",
        "        M_W_gap_align = result[\"M_align\"] - result[\"W_align\"]\n",
        "        M_W_gap_game = result[\"M_game\"] - result[\"W_game\"]\n",
        "\n",
        "        rows.append({\n",
        "            \"rho_pub\": rho,\n",
        "            \"W_align\": result[\"W_align\"],\n",
        "            \"W_game\": result[\"W_game\"],\n",
        "            \"x_align\": result[\"x_align\"],\n",
        "            \"x_game\": result[\"x_game\"],\n",
        "            \"M_align\": result[\"M_align\"],\n",
        "            \"M_game\": result[\"M_game\"],\n",
        "            \"M_W_gap_align\": M_W_gap_align,\n",
        "            \"M_W_gap_game\": M_W_gap_game,\n",
        "            \"PoG\": result[\"PoG\"],\n",
        "        })\n",
        "\n",
        "    return pd.DataFrame(rows)\n",
        "\n",
        "\n",
        "# ===================================================================\n",
        "# 8. Experiment 3: participation vs alpha (resilience / stability)\n",
        "# ===================================================================\n",
        "\n",
        "def sweep_alpha_participation(\n",
        "    base_cfg: SimulationConfig,\n",
        "    alphas: List[float],\n",
        "    burn_in: int = 100,\n",
        ") -> pd.DataFrame:\n",
        "    \"\"\"\n",
        "    Vary alpha_penalty and, under the gaming scenario, measure average participation (x_t).\n",
        "    (We only track the gaming case here rather than aligned vs gaming.)\n",
        "    \"\"\"\n",
        "    rows = []\n",
        "\n",
        "    for alpha in alphas:\n",
        "        cfg_gaming = SimulationConfig(**{**asdict(base_cfg),\n",
        "                                         \"alpha_penalty\": alpha})\n",
        "\n",
        "        df_game = run_simulation(cfg_gaming)\n",
        "        mask_game = (df_game[\"t\"] >= burn_in)\n",
        "\n",
        "        x_mean = df_game.loc[mask_game, \"x_t\"].mean()\n",
        "        W_mean = df_game.loc[mask_game, \"W_t\"].mean()\n",
        "\n",
        "        rows.append({\n",
        "            \"alpha_penalty\": alpha,\n",
        "            \"x_game_mean\": x_mean,\n",
        "            \"W_game_mean\": W_mean,\n",
        "        })\n",
        "\n",
        "    return pd.DataFrame(rows)\n",
        "\n",
        "\n",
        "# ===================================================================\n",
        "# 9. Simple usage example (only when run as a script)\n",
        "# ===================================================================\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    # Base configuration (near the \"bad policy\" region used in the discussion)\n",
        "    base_cfg = SimulationConfig(\n",
        "        n_clients=100,\n",
        "        T=300,\n",
        "        gaming_frac=0.3,\n",
        "        base_reward=1.0,\n",
        "        reward_bias=0.2,\n",
        "        audit_budget=10,\n",
        "        alpha_penalty=0.7,\n",
        "        p_detect=0.7,\n",
        "        rho_pub=0.6,\n",
        "        noise_pub=0.01,\n",
        "        noise_priv=0.01,\n",
        "        gamma_welfare=0.7,\n",
        "        base_cost_mean=0.4,\n",
        "        base_cost_std=0.1,\n",
        "        participation_beta=4.0,\n",
        "        participation_bias=0.0,\n",
        "        seed=42,\n",
        "    )\n",
        "\n",
        "    print(\"=== Single simulation example (gaming scenario) ===\")\n",
        "    df_example = run_simulation(base_cfg)\n",
        "    print(df_example.head())\n",
        "    print(df_example.tail())\n",
        "\n",
        "    print(\"\\n=== PoG (aligned vs gaming) example ===\")\n",
        "    cfg_aligned = SimulationConfig(**{**asdict(base_cfg), \"gaming_frac\": 0.0})\n",
        "    result = estimate_price_of_gaming(cfg_aligned, base_cfg, burn_in=100)\n",
        "    print(f\"W_align ≈ {result['W_align']:.3f}\")\n",
        "    print(f\"W_game  ≈ {result['W_game']:.3f}\")\n",
        "    print(f\"x_align ≈ {result['x_align']:.3f}\")\n",
        "    print(f\"x_game  ≈ {result['x_game']:.3f}\")\n",
        "    print(f\"PoG     ≈ {result['PoG']:.3f}\")\n",
        "\n",
        "    print(\"\\n=== Experiment 1: PoG vs alpha_penalty ===\")\n",
        "    alphas = [0.3, 0.5, 0.7, 1.0, 1.5]\n",
        "    df_alpha = sweep_alpha(base_cfg, alphas, burn_in=100)\n",
        "    print(df_alpha)\n",
        "\n",
        "    print(\"\\n=== Experiment 2: PoG vs rho_pub ===\")\n",
        "    rhos = [1.0, 0.8, 0.6, 0.4, 0.2]\n",
        "    df_rho = sweep_rho_pub(base_cfg, rhos, burn_in=100)\n",
        "    print(df_rho)\n",
        "\n",
        "    print(\"\\n=== Experiment 3: participation vs alpha (gaming only) ===\")\n",
        "    df_alpha_part = sweep_alpha_participation(base_cfg, alphas, burn_in=100)\n",
        "    print(df_alpha_part)"
      ],
      "metadata": {
        "id": "fptdZLLMnpfD"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Real-World Federated Learning"
      ],
      "metadata": {
        "id": "Q6Xk2HrvYV2E"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import random\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "from torch.utils.data import DataLoader, Subset, ConcatDataset\n",
        "from torchvision import datasets, transforms\n",
        "\n",
        "# ============================================================\n",
        "# 0. Config / Seed\n",
        "# ============================================================\n",
        "\n",
        "class Config:\n",
        "    n_clients = 30\n",
        "    gaming_frac = 0.3\n",
        "    n_rounds = 40\n",
        "    local_epochs = 2\n",
        "    batch_size = 64\n",
        "    lr = 0.01\n",
        "    momentum = 0.9\n",
        "\n",
        "    # Dirichlet non-IID partition parameter\n",
        "    dirichlet_alpha = 0.5\n",
        "\n",
        "    # Fraction of public validation data leaked to gaming clients\n",
        "    leak_fraction = 1.0  # use entire head-only public val for maximum effect\n",
        "\n",
        "    # Class split\n",
        "    head_classes = {0, 1, 2, 3, 4}   # head classes used by the public metric\n",
        "    tail_classes = {5, 6, 7, 8, 9}   # tail classes used by welfare\n",
        "\n",
        "    seed = 42\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "\n",
        "def set_seed(seed: int):\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    torch.cuda.manual_seed_all(seed)\n",
        "\n",
        "\n",
        "set_seed(Config.seed)\n",
        "\n",
        "# ============================================================\n",
        "# 1. Dataset preparation (Fashion-MNIST)\n",
        "#    - train 60k -> train_local 50k + public_val (head-biased) ~10k\n",
        "#    - test 10k  -> hidden welfare evaluation on tail-only\n",
        "# ============================================================\n",
        "\n",
        "transform = transforms.Compose([\n",
        "    transforms.ToTensor(),\n",
        "    transforms.Normalize((0.5,), (0.5,))\n",
        "])\n",
        "\n",
        "root = \"./data\"\n",
        "\n",
        "full_train = datasets.FashionMNIST(root=root, train=True, download=True, transform=transform)\n",
        "test_set = datasets.FashionMNIST(root=root, train=False, download=True, transform=transform)\n",
        "\n",
        "# From the 60k training examples:\n",
        "#   first 50k: local training\n",
        "#   last 10k : public validation candidates\n",
        "train_local_size = 50_000\n",
        "indices_all = np.arange(len(full_train))\n",
        "train_local_indices = indices_all[:train_local_size]\n",
        "public_val_candidate_indices = indices_all[train_local_size:]\n",
        "\n",
        "train_local_set = Subset(full_train, train_local_indices)\n",
        "\n",
        "# Public validation: only head classes\n",
        "all_train_targets = np.array(full_train.targets)\n",
        "head_mask = np.isin(all_train_targets, list(Config.head_classes))\n",
        "public_val_head_indices = [idx for idx in public_val_candidate_indices if head_mask[idx]]\n",
        "\n",
        "public_val_head_indices = np.array(public_val_head_indices)\n",
        "rng = np.random.default_rng(Config.seed)\n",
        "rng.shuffle(public_val_head_indices)\n",
        "\n",
        "leak_size = int(Config.leak_fraction * len(public_val_head_indices))\n",
        "leak_indices = public_val_head_indices[:leak_size]\n",
        "\n",
        "public_val_set = Subset(full_train, public_val_head_indices)\n",
        "public_val_loader = DataLoader(public_val_set, batch_size=Config.batch_size, shuffle=False)\n",
        "\n",
        "# Hidden welfare: only tail classes from the test set\n",
        "all_test_targets = np.array(test_set.targets)\n",
        "tail_mask = np.isin(all_test_targets, list(Config.tail_classes))\n",
        "tail_test_indices = np.where(tail_mask)[0]\n",
        "hidden_tail_test_set = Subset(test_set, tail_test_indices)\n",
        "hidden_tail_test_loader = DataLoader(hidden_tail_test_set, batch_size=Config.batch_size, shuffle=False)\n",
        "\n",
        "# (Optional) If we want full test accuracy later\n",
        "full_test_loader = DataLoader(test_set, batch_size=Config.batch_size, shuffle=False)\n",
        "\n",
        "# Public leak subset for gaming clients (head-only)\n",
        "leak_subset = Subset(full_train, leak_indices)\n",
        "\n",
        "\n",
        "# ============================================================\n",
        "# 2. Non-IID partitioning (Dirichlet) for client local datasets\n",
        "# ============================================================\n",
        "\n",
        "def make_dirichlet_partitions(labels: np.ndarray, n_clients: int, alpha: float, rng: np.random.Generator):\n",
        "    \"\"\"\n",
        "    labels: 1D array of class labels for train_local_set\n",
        "    n_clients: number of clients\n",
        "    alpha: Dirichlet concentration parameter\n",
        "\n",
        "    Returns:\n",
        "        list of index arrays per client (indices are with respect to train_local_set)\n",
        "    \"\"\"\n",
        "    n_classes = int(labels.max()) + 1\n",
        "    client_indices = [[] for _ in range(n_clients)]\n",
        "\n",
        "    # Collect indices per class (indices are within train_local_set)\n",
        "    class_indices = []\n",
        "    for k in range(n_classes):\n",
        "        idx_k = np.where(labels == k)[0]\n",
        "        rng.shuffle(idx_k)\n",
        "        class_indices.append(idx_k)\n",
        "\n",
        "    for k in range(n_classes):\n",
        "        idx_k = class_indices[k]\n",
        "        n_k = len(idx_k)\n",
        "        if n_k == 0:\n",
        "            continue\n",
        "\n",
        "        proportions = rng.dirichlet(alpha * np.ones(n_clients))\n",
        "        proportions = proportions / proportions.sum()\n",
        "        splits = (np.cumsum(proportions) * n_k).astype(int)\n",
        "        shard = np.split(idx_k, splits[:-1])\n",
        "\n",
        "        for cid, shard_c in enumerate(shard):\n",
        "            client_indices[cid].extend(shard_c.tolist())\n",
        "\n",
        "    for cid in range(n_clients):\n",
        "        rng.shuffle(client_indices[cid])\n",
        "\n",
        "    return client_indices\n",
        "\n",
        "\n",
        "rng = np.random.default_rng(Config.seed)\n",
        "\n",
        "# Extract labels for train_local_set (in the order of train_local_indices)\n",
        "train_targets = np.array(full_train.targets)[train_local_indices]\n",
        "\n",
        "client_indices = make_dirichlet_partitions(\n",
        "    labels=train_targets,\n",
        "    n_clients=Config.n_clients,\n",
        "    alpha=Config.dirichlet_alpha,\n",
        "    rng=rng\n",
        ")\n",
        "\n",
        "client_datasets_base = [\n",
        "    Subset(train_local_set, idxs) for idxs in client_indices\n",
        "]\n",
        "\n",
        "\n",
        "# ============================================================\n",
        "# 3. Helper to filter a Subset by allowed labels\n",
        "# ============================================================\n",
        "\n",
        "def filter_subset_by_labels(base_subset: Subset, allowed_labels: set[int]) -> Subset:\n",
        "    \"\"\"\n",
        "    base_subset: Subset(train_local_set, ...)\n",
        "    allowed_labels: set of class labels to keep (e.g., head_classes)\n",
        "\n",
        "    Returns:\n",
        "        A new Subset(train_local_set, ...) that only includes samples whose\n",
        "        labels are in allowed_labels.\n",
        "    \"\"\"\n",
        "    assert isinstance(base_subset.dataset, Subset) or isinstance(base_subset.dataset, datasets.FashionMNIST) \\\n",
        "        or isinstance(base_subset.dataset, torch.utils.data.Dataset)\n",
        "\n",
        "    # base_subset is a subset of train_local_set.\n",
        "    # train_local_set.indices: indices into full_train.\n",
        "    # base_subset.indices: indices within train_local_set (0~49999).\n",
        "    kept_indices_in_train_local = []\n",
        "\n",
        "    for idx_in_train_local in base_subset.indices:\n",
        "        global_idx = train_local_set.indices[idx_in_train_local]\n",
        "        label = int(full_train.targets[global_idx])\n",
        "        if label in allowed_labels:\n",
        "            kept_indices_in_train_local.append(idx_in_train_local)\n",
        "\n",
        "    return Subset(train_local_set, kept_indices_in_train_local)\n",
        "\n",
        "\n",
        "# ============================================================\n",
        "# 4. CNN model definition\n",
        "# ============================================================\n",
        "\n",
        "class SimpleCNN(nn.Module):\n",
        "    def __init__(self, num_classes: int = 10):\n",
        "        super().__init__()\n",
        "        self.features = nn.Sequential(\n",
        "            nn.Conv2d(1, 32, kernel_size=3, padding=1),  # 1x28x28 -> 32x28x28\n",
        "            nn.ReLU(inplace=True),\n",
        "            nn.MaxPool2d(2),                             # 32x14x14\n",
        "            nn.Conv2d(32, 64, kernel_size=3, padding=1), # 64x14x14\n",
        "            nn.ReLU(inplace=True),\n",
        "            nn.MaxPool2d(2),                             # 64x7x7\n",
        "        )\n",
        "        self.classifier = nn.Sequential(\n",
        "            nn.Flatten(),\n",
        "            nn.Linear(64 * 7 * 7, 128),\n",
        "            nn.ReLU(inplace=True),\n",
        "            nn.Linear(128, num_classes),\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.features(x)\n",
        "        x = self.classifier(x)\n",
        "        return x\n",
        "\n",
        "\n",
        "def get_model():\n",
        "    model = SimpleCNN(num_classes=10)\n",
        "    return model.to(Config.device)\n",
        "\n",
        "\n",
        "# ============================================================\n",
        "# 5. Evaluation function (accuracy)\n",
        "# ============================================================\n",
        "\n",
        "def evaluate_model(model: nn.Module, data_loader: DataLoader) -> float:\n",
        "    model.eval()\n",
        "    correct = 0\n",
        "    total = 0\n",
        "    with torch.no_grad():\n",
        "        for x, y in data_loader:\n",
        "            x = x.to(Config.device)\n",
        "            y = y.to(Config.device)\n",
        "            logits = model(x)\n",
        "            preds = logits.argmax(dim=1)\n",
        "            correct += (preds == y).sum().item()\n",
        "            total += y.size(0)\n",
        "    return correct / total if total > 0 else 0.0\n",
        "\n",
        "\n",
        "# ============================================================\n",
        "# 6. Local training\n",
        "#    (honest vs gaming is controlled by which dataset they receive)\n",
        "# ============================================================\n",
        "\n",
        "def local_train(model: nn.Module, dataset, epochs: int) -> nn.Module:\n",
        "    loader = DataLoader(dataset, batch_size=Config.batch_size, shuffle=True)\n",
        "    criterion = nn.CrossEntropyLoss()\n",
        "    optimizer = optim.SGD(model.parameters(), lr=Config.lr, momentum=Config.momentum)\n",
        "\n",
        "    model.train()\n",
        "    for _ in range(epochs):\n",
        "        for x, y in loader:\n",
        "            x = x.to(Config.device)\n",
        "            y = y.to(Config.device)\n",
        "            optimizer.zero_grad()\n",
        "            logits = model(x)\n",
        "            loss = criterion(logits, y)\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "    return model\n",
        "\n",
        "\n",
        "# ============================================================\n",
        "# 7. FedAvg aggregation\n",
        "# ============================================================\n",
        "\n",
        "def fedavg(models: list[nn.Module]) -> nn.Module:\n",
        "    global_model = get_model()\n",
        "    global_state = global_model.state_dict()\n",
        "\n",
        "    state_dicts = [m.state_dict() for m in models]\n",
        "\n",
        "    with torch.no_grad():\n",
        "        for key in global_state.keys():\n",
        "            stacked = torch.stack([sd[key] for sd in state_dicts], dim=0)\n",
        "            global_state[key] = stacked.mean(dim=0)\n",
        "\n",
        "    global_model.load_state_dict(global_state)\n",
        "    return global_model\n",
        "\n",
        "\n",
        "# ============================================================\n",
        "# 8. Federated learning experiment (only gaming_frac changes)\n",
        "# ============================================================\n",
        "\n",
        "def run_federated_experiment(gaming_frac: float):\n",
        "    \"\"\"\n",
        "    gaming_frac: 0.0 => all honest (aligned)\n",
        "                 0.3 => 30% gaming\n",
        "\n",
        "    Returns:\n",
        "      - W_history: per-round hidden tail test accuracy (welfare)\n",
        "      - M_history: per-round head-only public validation accuracy (metric)\n",
        "      - overall_test_history: per-round full test accuracy (for reference)\n",
        "    \"\"\"\n",
        "    n_clients = Config.n_clients\n",
        "    n_gaming = int(round(gaming_frac * n_clients))\n",
        "\n",
        "    all_client_ids = np.arange(n_clients)\n",
        "    rng_local = np.random.default_rng(Config.seed)  # fix seed to keep client composition consistent\n",
        "    gaming_clients = set(rng_local.choice(all_client_ids, size=n_gaming, replace=False))\n",
        "    honest_clients = [cid for cid in all_client_ids if cid not in gaming_clients]\n",
        "\n",
        "    print(f\"[Run] gaming_frac = {gaming_frac:.2f}, \"\n",
        "          f\"honest = {len(honest_clients)}, gaming = {len(gaming_clients)}\")\n",
        "\n",
        "    # Construct training dataset for each client\n",
        "    client_train_datasets = []\n",
        "    for cid in range(n_clients):\n",
        "        base_ds = client_datasets_base[cid]\n",
        "        if cid in gaming_clients and gaming_frac > 0:\n",
        "            # Gaming client:\n",
        "            #   1) Remove tail classes from local data (head-only)\n",
        "            #   2) Concatenate head-only public leak\n",
        "            base_head_only = filter_subset_by_labels(base_ds, Config.head_classes)\n",
        "            ds = ConcatDataset([base_head_only, leak_subset])\n",
        "        else:\n",
        "            # Honest client: use entire local data (head + tail)\n",
        "            ds = base_ds\n",
        "        client_train_datasets.append(ds)\n",
        "\n",
        "    global_model = get_model()\n",
        "\n",
        "    W_history = []\n",
        "    M_history = []\n",
        "    overall_test_history = []\n",
        "\n",
        "    for rnd in range(Config.n_rounds):\n",
        "        print(f\"=== Round {rnd+1}/{Config.n_rounds} ===\")\n",
        "\n",
        "        local_models = []\n",
        "        for cid in range(n_clients):\n",
        "            lm = get_model()\n",
        "            lm.load_state_dict(global_model.state_dict())\n",
        "            lm = local_train(lm, client_train_datasets[cid], epochs=Config.local_epochs)\n",
        "            local_models.append(lm)\n",
        "\n",
        "        global_model = fedavg(local_models)\n",
        "\n",
        "        # Evaluation:\n",
        "        # - W_t: tail-only hidden test (welfare)\n",
        "        # - M_t: head-only public validation (metric)\n",
        "        # - overall test accuracy as an additional reference\n",
        "        W_t = evaluate_model(global_model, hidden_tail_test_loader)\n",
        "        M_t = evaluate_model(global_model, public_val_loader)\n",
        "        overall_acc = evaluate_model(global_model, full_test_loader)\n",
        "\n",
        "        W_history.append(W_t)\n",
        "        M_history.append(M_t)\n",
        "        overall_test_history.append(overall_acc)\n",
        "\n",
        "        print(f\"  Hidden tail test acc (W_t) : {W_t:.4f}\")\n",
        "        print(f\"  Public head val acc (M_t)  : {M_t:.4f}\")\n",
        "        print(f\"  Full test acc (ref)       : {overall_acc:.4f}\")\n",
        "\n",
        "    return W_history, M_history, overall_test_history\n",
        "\n",
        "\n",
        "# ============================================================\n",
        "# 9. Run aligned vs gaming experiments and summarize\n",
        "# ============================================================\n",
        "\n",
        "def tail_mean(vals, tail=10):\n",
        "    vals = np.array(vals)\n",
        "    if len(vals) < tail:\n",
        "        return vals.mean()\n",
        "    return vals[-tail:].mean()\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    # aligned (no gaming)\n",
        "    W_aligned, M_aligned, full_aligned = run_federated_experiment(gaming_frac=0.0)\n",
        "\n",
        "    # gaming (30% gaming)\n",
        "    W_gaming, M_gaming, full_gaming = run_federated_experiment(gaming_frac=Config.gaming_frac)\n",
        "\n",
        "    W_align_mean = tail_mean(W_aligned)\n",
        "    W_game_mean = tail_mean(W_gaming)\n",
        "    M_align_mean = tail_mean(M_aligned)\n",
        "    M_game_mean = tail_mean(M_gaming)\n",
        "    full_align_mean = tail_mean(full_aligned)\n",
        "    full_game_mean = tail_mean(full_gaming)\n",
        "\n",
        "    if W_align_mean > 0:\n",
        "        PoG = (W_align_mean - W_game_mean) / W_align_mean\n",
        "    else:\n",
        "        PoG = float(\"nan\")\n",
        "\n",
        "    print(\"\\n=== Summary (last 10 rounds, head-metric vs tail-welfare) ===\")\n",
        "    print(f\"W_align (tail) ≈ {W_align_mean:.3f}\")\n",
        "    print(f\"W_game  (tail) ≈ {W_game_mean:.3f}\")\n",
        "    print(f\"M_align (head) ≈ {M_align_mean:.3f}\")\n",
        "    print(f\"M_game  (head) ≈ {M_game_mean:.3f}\")\n",
        "    print(f\"Full test (aligned) ≈ {full_align_mean:.3f}\")\n",
        "    print(f\"Full test (gaming)  ≈ {full_game_mean:.3f}\")\n",
        "    print(f\"PoG (tail welfare)  ≈ {PoG:.3f}\")"
      ],
      "metadata": {
        "id": "-_CC07g2oU1b"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Estimator reliability under partial audits"
      ],
      "metadata": {
        "id": "_nHtBAeNkzfS"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "import random\n",
        "from dataclasses import dataclass\n",
        "from typing import Dict, List, Tuple\n",
        "\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "from torch.utils.data import DataLoader, Subset, ConcatDataset\n",
        "from torchvision import datasets, transforms\n",
        "\n",
        "\n",
        "# ============================================================\n",
        "# E1: Estimator Reliability (Ground-truth vs Audit-based estimator)\n",
        "# Standalone code (not meant to be merged into your old script)\n",
        "# ============================================================\n",
        "\n",
        "# ----------------------------\n",
        "# 0) Config / Seed\n",
        "# ----------------------------\n",
        "@dataclass\n",
        "class Config:\n",
        "    # Keep small for Colab\n",
        "    n_clients: int = 12\n",
        "    n_rounds: int = 25\n",
        "    local_epochs: int = 1\n",
        "    batch_size: int = 64\n",
        "    lr: float = 0.01\n",
        "    momentum: float = 0.9\n",
        "\n",
        "    # Non-IID partition\n",
        "    dirichlet_alpha: float = 0.5\n",
        "\n",
        "    # Head/Tail split\n",
        "    head_classes: Tuple[int, ...] = (0, 1, 2, 3, 4)\n",
        "    tail_classes: Tuple[int, ...] = (5, 6, 7, 8, 9)\n",
        "\n",
        "    # Public leak\n",
        "    leak_fraction: float = 1.0  # use all head-only public val for max gaming effect\n",
        "\n",
        "    # Strategy mix for profiles\n",
        "    gaming_fracs: Tuple[float, ...] = (0.0, 0.1, 0.2, 0.3, 0.4, 0.5)     # for GT / estimator comparison\n",
        "    benign_frac: float = 0.10                               # small benign cooperation fraction\n",
        "    update_scale_factor: float = 1.75                       # for update-scaling gaming\n",
        "\n",
        "    # Audit budgets to test estimator reliability\n",
        "    audit_budgets: Tuple[float, ...] = (0.10, 0.25, 0.50)\n",
        "    audit_trials: int = 5  # resample audits to estimate estimator noise\n",
        "\n",
        "    # Evaluation\n",
        "    tail_k: int = 7          # tail-mean over last K rounds\n",
        "    pog_threshold: float = 0.05  # risk threshold for FP/FN\n",
        "\n",
        "    seed: int = 42\n",
        "    root: str = \"./data\"\n",
        "    device: torch.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "\n",
        "CFG = Config()\n",
        "\n",
        "\n",
        "def set_seed(seed: int):\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    torch.cuda.manual_seed_all(seed)\n",
        "\n",
        "\n",
        "set_seed(CFG.seed)\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 1) Dataset preparation\n",
        "# ----------------------------\n",
        "transform = transforms.Compose([\n",
        "    transforms.ToTensor(),\n",
        "    transforms.Normalize((0.5,), (0.5,))\n",
        "])\n",
        "\n",
        "full_train = datasets.FashionMNIST(root=CFG.root, train=True, download=True, transform=transform)\n",
        "\n",
        "# We'll use:\n",
        "# - 50k for client-local data (partitioned)\n",
        "# - 10k as public validation candidates (from the same training set)\n",
        "train_local_size = 50_000\n",
        "idx_all = np.arange(len(full_train))\n",
        "idx_local = idx_all[:train_local_size]\n",
        "idx_public_candidates = idx_all[train_local_size:]\n",
        "\n",
        "train_local_set = Subset(full_train, idx_local)\n",
        "\n",
        "targets_all = np.array(full_train.targets)\n",
        "\n",
        "# Public validation set: head-only from last 10k\n",
        "head_mask = np.isin(targets_all, list(CFG.head_classes))\n",
        "public_head_indices = [i for i in idx_public_candidates if head_mask[i]]\n",
        "public_head_indices = np.array(public_head_indices)\n",
        "\n",
        "rng = np.random.default_rng(CFG.seed)\n",
        "rng.shuffle(public_head_indices)\n",
        "\n",
        "leak_size = int(CFG.leak_fraction * len(public_head_indices))\n",
        "leak_indices = public_head_indices[:leak_size]\n",
        "\n",
        "public_val_set = Subset(full_train, public_head_indices)\n",
        "public_val_loader = DataLoader(public_val_set, batch_size=CFG.batch_size, shuffle=False)\n",
        "\n",
        "# Leak subset that gaming clients can append\n",
        "leak_subset = Subset(full_train, leak_indices)\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 2) Utilities: partition + filtering\n",
        "# ----------------------------\n",
        "def make_dirichlet_partitions(labels: np.ndarray, n_clients: int, alpha: float, rng: np.random.Generator):\n",
        "    \"\"\"Returns list of index arrays per client (indices w.r.t. train_local_set).\"\"\"\n",
        "    n_classes = int(labels.max()) + 1\n",
        "    client_indices: List[List[int]] = [[] for _ in range(n_clients)]\n",
        "\n",
        "    class_indices = []\n",
        "    for k in range(n_classes):\n",
        "        idx_k = np.where(labels == k)[0]\n",
        "        rng.shuffle(idx_k)\n",
        "        class_indices.append(idx_k)\n",
        "\n",
        "    for k in range(n_classes):\n",
        "        idx_k = class_indices[k]\n",
        "        n_k = len(idx_k)\n",
        "        if n_k == 0:\n",
        "            continue\n",
        "        proportions = rng.dirichlet(alpha * np.ones(n_clients))\n",
        "        proportions = proportions / proportions.sum()\n",
        "        splits = (np.cumsum(proportions) * n_k).astype(int)\n",
        "        shards = np.split(idx_k, splits[:-1])\n",
        "        for cid, shard in enumerate(shards):\n",
        "            client_indices[cid].extend(shard.tolist())\n",
        "\n",
        "    for cid in range(n_clients):\n",
        "        rng.shuffle(client_indices[cid])\n",
        "\n",
        "    return client_indices\n",
        "\n",
        "\n",
        "def filter_subset_by_labels(base_subset: Subset, allowed_labels: Tuple[int, ...]) -> Subset:\n",
        "    \"\"\"base_subset is a Subset of train_local_set; keep only items with labels in allowed_labels.\"\"\"\n",
        "    allowed = set(allowed_labels)\n",
        "    kept = []\n",
        "    # base_subset.indices are indices inside train_local_set (0..train_local_size-1)\n",
        "    for idx_in_local in base_subset.indices:\n",
        "        global_idx = train_local_set.indices[idx_in_local]  # index into full_train\n",
        "        y = int(full_train.targets[global_idx])\n",
        "        if y in allowed:\n",
        "            kept.append(idx_in_local)\n",
        "    return Subset(train_local_set, kept)\n",
        "\n",
        "\n",
        "def split_subset_train_eval(base_subset: Subset, eval_ratio: float = 0.2) -> Tuple[Subset, Subset]:\n",
        "    \"\"\"Split a client's local subset into train/eval (by index).\"\"\"\n",
        "    idxs = list(base_subset.indices)\n",
        "    rng_local = np.random.default_rng(CFG.seed + 123)\n",
        "    rng_local.shuffle(idxs)\n",
        "    n_eval = int(round(eval_ratio * len(idxs)))\n",
        "    eval_idxs = idxs[:n_eval]\n",
        "    train_idxs = idxs[n_eval:]\n",
        "    return Subset(train_local_set, train_idxs), Subset(train_local_set, eval_idxs)\n",
        "\n",
        "\n",
        "def tail_only_subset(base_subset: Subset) -> Subset:\n",
        "    return filter_subset_by_labels(base_subset, CFG.tail_classes)\n",
        "\n",
        "\n",
        "def oversample_tail(train_subset: Subset, factor: int = 2) -> ConcatDataset:\n",
        "    \"\"\"Benign cooperation: upweight tail samples in local training.\"\"\"\n",
        "    tail_sub = tail_only_subset(train_subset)\n",
        "    # If tail is empty, just return original\n",
        "    if len(tail_sub) == 0:\n",
        "        return ConcatDataset([train_subset])\n",
        "    reps = [tail_sub for _ in range(factor)]\n",
        "    return ConcatDataset([train_subset] + reps)\n",
        "\n",
        "\n",
        "# Build base client datasets (Dirichlet)\n",
        "train_targets_local = np.array(full_train.targets)[idx_local]\n",
        "client_indices = make_dirichlet_partitions(\n",
        "    labels=train_targets_local,\n",
        "    n_clients=CFG.n_clients,\n",
        "    alpha=CFG.dirichlet_alpha,\n",
        "    rng=np.random.default_rng(CFG.seed)\n",
        ")\n",
        "\n",
        "client_base_subsets = [Subset(train_local_set, idxs) for idxs in client_indices]\n",
        "\n",
        "# For auditing/welfare evaluation, each client has a held-out local eval set; welfare = tail-accuracy on that eval set.\n",
        "client_train_subsets: List[Subset] = []\n",
        "client_eval_subsets: List[Subset] = []\n",
        "client_tail_eval_subsets: List[Subset] = []\n",
        "\n",
        "for cid in range(CFG.n_clients):\n",
        "    tr, ev = split_subset_train_eval(client_base_subsets[cid], eval_ratio=0.2)\n",
        "    client_train_subsets.append(tr)\n",
        "    client_eval_subsets.append(ev)\n",
        "    client_tail_eval_subsets.append(tail_only_subset(ev))\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 3) Model\n",
        "# ----------------------------\n",
        "class SimpleCNN(nn.Module):\n",
        "    def __init__(self, num_classes: int = 10):\n",
        "        super().__init__()\n",
        "        self.features = nn.Sequential(\n",
        "            nn.Conv2d(1, 32, kernel_size=3, padding=1),\n",
        "            nn.ReLU(inplace=True),\n",
        "            nn.MaxPool2d(2),\n",
        "            nn.Conv2d(32, 64, kernel_size=3, padding=1),\n",
        "            nn.ReLU(inplace=True),\n",
        "            nn.MaxPool2d(2),\n",
        "        )\n",
        "        self.classifier = nn.Sequential(\n",
        "            nn.Flatten(),\n",
        "            nn.Linear(64 * 7 * 7, 128),\n",
        "            nn.ReLU(inplace=True),\n",
        "            nn.Linear(128, num_classes),\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.features(x)\n",
        "        x = self.classifier(x)\n",
        "        return x\n",
        "\n",
        "\n",
        "def get_model() -> nn.Module:\n",
        "    return SimpleCNN().to(CFG.device)\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 4) Train/Eval helpers\n",
        "# ----------------------------\n",
        "@torch.no_grad()\n",
        "def evaluate_acc(model: nn.Module, loader: DataLoader) -> float:\n",
        "    model.eval()\n",
        "    correct, total = 0, 0\n",
        "    for x, y in loader:\n",
        "        x, y = x.to(CFG.device), y.to(CFG.device)\n",
        "        logits = model(x)\n",
        "        pred = logits.argmax(dim=1)\n",
        "        correct += (pred == y).sum().item()\n",
        "        total += y.size(0)\n",
        "    return float(correct / total) if total > 0 else 0.0\n",
        "\n",
        "\n",
        "def local_train(model: nn.Module, dataset, epochs: int) -> nn.Module:\n",
        "    loader = DataLoader(dataset, batch_size=CFG.batch_size, shuffle=True)\n",
        "    opt = optim.SGD(model.parameters(), lr=CFG.lr, momentum=CFG.momentum)\n",
        "    crit = nn.CrossEntropyLoss()\n",
        "\n",
        "    model.train()\n",
        "    for _ in range(epochs):\n",
        "        for x, y in loader:\n",
        "            x, y = x.to(CFG.device), y.to(CFG.device)\n",
        "            opt.zero_grad()\n",
        "            loss = crit(model(x), y)\n",
        "            loss.backward()\n",
        "            opt.step()\n",
        "    return model\n",
        "\n",
        "\n",
        "def state_dict_add_scaled_delta(global_sd: Dict[str, torch.Tensor],\n",
        "                                local_sd: Dict[str, torch.Tensor],\n",
        "                                scale: float) -> Dict[str, torch.Tensor]:\n",
        "    \"\"\"Return transmitted state = global + scale*(local - global).\"\"\"\n",
        "    out = {}\n",
        "    for k in global_sd.keys():\n",
        "        out[k] = global_sd[k] + scale * (local_sd[k] - global_sd[k])\n",
        "    return out\n",
        "\n",
        "\n",
        "def fedavg_from_states(states: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:\n",
        "    \"\"\"Average a list of state_dicts (FedAvg).\"\"\"\n",
        "    avg = {}\n",
        "    for k in states[0].keys():\n",
        "        stacked = torch.stack([sd[k] for sd in states], dim=0)\n",
        "        avg[k] = stacked.mean(dim=0)\n",
        "    return avg\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 5) Strategy definitions\n",
        "# ----------------------------\n",
        "STR_HONEST = \"HONEST\"\n",
        "STR_BENIGN = \"BENIGN_TAIL_UPWEIGHT\"\n",
        "STR_GAME_HEAD = \"GAMING_HEAD_ONLY_LEAK\"\n",
        "STR_GAME_SCALE = \"GAMING_UPDATE_SCALING\"\n",
        "\n",
        "\n",
        "def assign_strategies(n_clients: int,\n",
        "                      gaming_frac: float,\n",
        "                      benign_frac: float,\n",
        "                      rng: np.random.Generator) -> Dict[int, str]:\n",
        "    \"\"\"\n",
        "    Assign each client a strategy from {HONEST, BENIGN, GAMING_HEAD, GAMING_SCALE}.\n",
        "    To keep it simple:\n",
        "      - benign_frac portion get BENIGN\n",
        "      - gaming_frac portion get a GAMING strategy (half head-only, half update-scaling)\n",
        "      - rest are HONEST\n",
        "    \"\"\"\n",
        "    all_ids = np.arange(n_clients)\n",
        "    rng.shuffle(all_ids)\n",
        "\n",
        "    n_benign = int(round(benign_frac * n_clients))\n",
        "    n_gaming = int(round(gaming_frac * n_clients))\n",
        "\n",
        "    benign_ids = set(all_ids[:n_benign])\n",
        "    gaming_ids = list(all_ids[n_benign:n_benign + n_gaming])\n",
        "    honest_ids = set(all_ids[n_benign + n_gaming:])\n",
        "\n",
        "    strat = {}\n",
        "    for cid in benign_ids:\n",
        "        strat[cid] = STR_BENIGN\n",
        "    for cid in honest_ids:\n",
        "        strat[cid] = STR_HONEST\n",
        "\n",
        "    # split gaming IDs into two types\n",
        "    half = len(gaming_ids) // 2\n",
        "    for cid in gaming_ids[:half]:\n",
        "        strat[cid] = STR_GAME_HEAD\n",
        "    for cid in gaming_ids[half:]:\n",
        "        strat[cid] = STR_GAME_SCALE\n",
        "\n",
        "    return strat\n",
        "\n",
        "\n",
        "def build_client_train_dataset(cid: int, strategy: str):\n",
        "    base_train = client_train_subsets[cid]\n",
        "\n",
        "    if strategy == STR_HONEST:\n",
        "        return base_train\n",
        "\n",
        "    if strategy == STR_BENIGN:\n",
        "        # upweight tail a bit (benign cooperation)\n",
        "        return oversample_tail(base_train, factor=2)\n",
        "\n",
        "    if strategy == STR_GAME_HEAD:\n",
        "        # Remove tail samples from local training; append public head leak\n",
        "        head_only = filter_subset_by_labels(base_train, CFG.head_classes)\n",
        "        return ConcatDataset([head_only, leak_subset])\n",
        "\n",
        "    if strategy == STR_GAME_SCALE:\n",
        "        # Training data itself is honest; manipulation happens in transmitted update\n",
        "        return base_train\n",
        "\n",
        "    raise ValueError(f\"Unknown strategy: {strategy}\")\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 6) Run FL once per profile, store per-round:\n",
        "#    - public metric M_t (head-only public val accuracy)\n",
        "#    - per-client tail accuracies on client-heldout tail eval (for GT + audit sampling)\n",
        "# ----------------------------\n",
        "def run_profile_once(gaming_frac: float, benign_frac: float, seed_offset: int = 0):\n",
        "    rng = np.random.default_rng(CFG.seed + seed_offset)\n",
        "\n",
        "    strat_map = assign_strategies(CFG.n_clients, gaming_frac, benign_frac, rng)\n",
        "\n",
        "    # Pre-build each client's training dataset according to strategy\n",
        "    client_train_ds = [build_client_train_dataset(cid, strat_map[cid]) for cid in range(CFG.n_clients)]\n",
        "\n",
        "    global_model = get_model()\n",
        "    global_sd = {k: v.detach().clone() for k, v in global_model.state_dict().items()}\n",
        "\n",
        "    M_hist: List[float] = []\n",
        "    tail_acc_matrix = np.zeros((CFG.n_rounds, CFG.n_clients), dtype=np.float32)\n",
        "\n",
        "    for rnd in range(CFG.n_rounds):\n",
        "        transmitted_states: List[Dict[str, torch.Tensor]] = []\n",
        "\n",
        "        for cid in range(CFG.n_clients):\n",
        "            local_model = get_model()\n",
        "            local_model.load_state_dict(global_sd)\n",
        "\n",
        "            local_model = local_train(local_model, client_train_ds[cid], epochs=CFG.local_epochs)\n",
        "            local_sd = local_model.state_dict()\n",
        "\n",
        "            # Manipulation at transmission time for UPDATE_SCALING gaming\n",
        "            if strat_map[cid] == STR_GAME_SCALE and gaming_frac > 0:\n",
        "                tx_sd = state_dict_add_scaled_delta(global_sd, local_sd, scale=CFG.update_scale_factor)\n",
        "            else:\n",
        "                tx_sd = {k: v.detach().clone() for k, v in local_sd.items()}\n",
        "\n",
        "            transmitted_states.append(tx_sd)\n",
        "\n",
        "        # FedAvg update\n",
        "        global_sd = fedavg_from_states(transmitted_states)\n",
        "        global_model.load_state_dict(global_sd)\n",
        "\n",
        "        # Metric: server-observable public head validation accuracy\n",
        "        M_t = evaluate_acc(global_model, public_val_loader)\n",
        "        M_hist.append(M_t)\n",
        "\n",
        "        # Welfare ingredient: per-client tail eval accuracy (experimenter can compute for GT;\n",
        "        # audit will subsample from this vector)\n",
        "        for cid in range(CFG.n_clients):\n",
        "            tail_ev = client_tail_eval_subsets[cid]\n",
        "            if len(tail_ev) == 0:\n",
        "                tail_acc_matrix[rnd, cid] = np.nan  # if a client has no tail samples\n",
        "                continue\n",
        "            loader = DataLoader(tail_ev, batch_size=CFG.batch_size, shuffle=False)\n",
        "            tail_acc_matrix[rnd, cid] = evaluate_acc(global_model, loader)\n",
        "\n",
        "        if (rnd + 1) % 10 == 0 or rnd == 0:\n",
        "            w_full = np.nanmean(tail_acc_matrix[rnd])\n",
        "            print(f\"[Profile gf={gaming_frac:.2f}] Round {rnd+1:>2}/{CFG.n_rounds} | \"\n",
        "                  f\"M(head public)={M_t:.4f} | W_full_tail(mean clients)={w_full:.4f}\")\n",
        "\n",
        "    return {\n",
        "        \"gaming_frac\": gaming_frac,\n",
        "        \"benign_frac\": benign_frac,\n",
        "        \"strategy_map\": strat_map,\n",
        "        \"M_hist\": np.array(M_hist, dtype=np.float32),\n",
        "        \"tail_acc_matrix\": tail_acc_matrix,  # shape [T, N]\n",
        "    }\n",
        "\n",
        "\n",
        "def tail_mean(x: np.ndarray, k: int) -> float:\n",
        "    if len(x) == 0:\n",
        "        return float(\"nan\")\n",
        "    if len(x) <= k:\n",
        "        return float(np.nanmean(x))\n",
        "    return float(np.nanmean(x[-k:]))\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 7) Estimator: audit-based welfare estimate from tail acc matrix\n",
        "#    - GT welfare per round: mean over all clients (nanmean)\n",
        "#    - Estimated welfare per round: mean over audited clients (nanmean over sampled subset)\n",
        "#    - PoG uses baseline welfare from aligned profile\n",
        "# ----------------------------\n",
        "def compute_pog_series(w_ref: float, w_series: np.ndarray) -> np.ndarray:\n",
        "    # PoG_t = (W_ref - W_t) / W_ref\n",
        "    if not np.isfinite(w_ref) or w_ref <= 1e-8:\n",
        "        return np.full_like(w_series, np.nan, dtype=np.float32)\n",
        "    out = (w_ref - w_series) / w_ref\n",
        "    return out.astype(np.float32)\n",
        "\n",
        "\n",
        "def audit_estimate_welfare_series(tail_acc_matrix: np.ndarray,\n",
        "                                 audit_budget: float,\n",
        "                                 rng: np.random.Generator) -> np.ndarray:\n",
        "    \"\"\"\n",
        "    tail_acc_matrix: [T, N] per-client tail acc (nan if no tail data)\n",
        "    returns: W_hat_series [T]\n",
        "    \"\"\"\n",
        "    T, N = tail_acc_matrix.shape\n",
        "    m = max(1, int(round(audit_budget * N)))\n",
        "    W_hat = np.zeros(T, dtype=np.float32)\n",
        "    for t in range(T):\n",
        "        audited = rng.choice(np.arange(N), size=m, replace=False)\n",
        "        W_hat[t] = float(np.nanmean(tail_acc_matrix[t, audited]))\n",
        "    return W_hat\n",
        "\n",
        "\n",
        "def detection_delay(pog_hat: np.ndarray, tau: float) -> int:\n",
        "    \"\"\"First round index (1-based) where PoG_hat >= tau; return 0 if never detected.\"\"\"\n",
        "    for i, v in enumerate(pog_hat):\n",
        "        if np.isfinite(v) and v >= tau:\n",
        "            return i + 1\n",
        "    return 0\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 8) Main: run baseline + profiles, then evaluate estimator reliability\n",
        "# ----------------------------\n",
        "def main():\n",
        "    print(\"\\n=== E1: Estimator Reliability (GT vs Audit-based Estimator) ===\")\n",
        "    print(f\"Device: {CFG.device}\")\n",
        "    print(f\"Clients={CFG.n_clients}, Rounds={CFG.n_rounds}, LocalEpochs={CFG.local_epochs}\")\n",
        "    print(f\"Audit budgets={CFG.audit_budgets}, trials={CFG.audit_trials}\\n\")\n",
        "\n",
        "    # 8.1 Baseline aligned run (gaming_frac=0.0), used as welfare reference\n",
        "    baseline = run_profile_once(gaming_frac=0.0, benign_frac=CFG.benign_frac, seed_offset=0)\n",
        "    M_ref_series = baseline[\"M_hist\"]\n",
        "    W_ref_series_full = np.nanmean(baseline[\"tail_acc_matrix\"], axis=1)  # GT welfare series\n",
        "    W_ref = tail_mean(W_ref_series_full, CFG.tail_k)\n",
        "    M_ref = tail_mean(M_ref_series, CFG.tail_k)\n",
        "\n",
        "    print(\"\\n[Baseline] Reference values (tail-mean over last K rounds)\")\n",
        "    print(f\"  W_ref (tail welfare) = {W_ref:.4f}\")\n",
        "    print(f\"  M_ref (head metric)  = {M_ref:.4f}\\n\")\n",
        "\n",
        "    # 8.2 Run profiles for each gaming_frac (one run each), store GT\n",
        "    profiles = []\n",
        "    seed_offset = 10\n",
        "    for gf in CFG.gaming_fracs:\n",
        "        prof = run_profile_once(gaming_frac=gf, benign_frac=CFG.benign_frac, seed_offset=seed_offset)\n",
        "        seed_offset += 1\n",
        "        profiles.append(prof)\n",
        "\n",
        "    # 8.3 For each profile, compute GT PoG and estimator performance for each audit budget\n",
        "    rows = []\n",
        "    for prof in profiles:\n",
        "        gf = prof[\"gaming_frac\"]\n",
        "        M_series = prof[\"M_hist\"]\n",
        "        tail_mat = prof[\"tail_acc_matrix\"]\n",
        "\n",
        "        # Ground-truth welfare series (full visibility, experimenter only)\n",
        "        W_full_series = np.nanmean(tail_mat, axis=1)\n",
        "        W_full = tail_mean(W_full_series, CFG.tail_k)\n",
        "\n",
        "        # Ground-truth PoG (final)\n",
        "        pog_gt_series = compute_pog_series(W_ref, W_full_series)\n",
        "        pog_gt = tail_mean(pog_gt_series, CFG.tail_k)\n",
        "\n",
        "        # Simple manipulation index proxy: metric inflation vs baseline (final)\n",
        "        M_final = tail_mean(M_series, CFG.tail_k)\n",
        "        manip_delta = M_final - M_ref\n",
        "\n",
        "        # Risk label for FP/FN (based on GT)\n",
        "        is_risky_gt = (np.isfinite(pog_gt) and pog_gt >= CFG.pog_threshold)\n",
        "\n",
        "        print(f\"\\n[Profile gf={gf:.2f}] GT summary:\")\n",
        "        print(f\"  M_final(head metric)  = {M_final:.4f} (delta vs ref {manip_delta:+.4f})\")\n",
        "        print(f\"  W_full(tail welfare)  = {W_full:.4f}\")\n",
        "        print(f\"  PoG_GT                = {pog_gt:.4f} (risk>=tau? {is_risky_gt})\")\n",
        "\n",
        "        # Estimator: for each budget, resample audits multiple times\n",
        "        for b in CFG.audit_budgets:\n",
        "            pog_hats = []\n",
        "            delays = []\n",
        "            risky_preds = []\n",
        "\n",
        "            for tr in range(CFG.audit_trials):\n",
        "                rng_a = np.random.default_rng(CFG.seed + 1000 + tr + int(100 * b) + int(1000 * gf))\n",
        "                W_hat_series = audit_estimate_welfare_series(tail_mat, audit_budget=b, rng=rng_a)\n",
        "                pog_hat_series = compute_pog_series(W_ref, W_hat_series)\n",
        "                pog_hat = tail_mean(pog_hat_series, CFG.tail_k)\n",
        "\n",
        "                pog_hats.append(pog_hat)\n",
        "                delays.append(detection_delay(pog_hat_series, CFG.pog_threshold))\n",
        "                risky_preds.append(np.isfinite(pog_hat) and pog_hat >= CFG.pog_threshold)\n",
        "\n",
        "            pog_hat_mean = float(np.nanmean(pog_hats))\n",
        "            pog_hat_std = float(np.nanstd(pog_hats))\n",
        "            delay_mean = float(np.mean(delays))\n",
        "\n",
        "            # Classification outcomes\n",
        "            pred_risky = (np.mean(risky_preds) >= 0.5)  # majority vote across trials\n",
        "            fp = int((not is_risky_gt) and pred_risky)\n",
        "            fn = int(is_risky_gt and (not pred_risky))\n",
        "\n",
        "            rows.append({\n",
        "                \"gaming_frac\": gf,\n",
        "                \"audit_budget\": b,\n",
        "                \"M_final\": M_final,\n",
        "                \"M_delta_vs_ref\": manip_delta,\n",
        "                \"W_full\": W_full,\n",
        "                \"PoG_GT\": pog_gt,\n",
        "                \"PoG_hat_mean\": pog_hat_mean,\n",
        "                \"PoG_hat_std\": pog_hat_std,\n",
        "                \"delay_mean_rounds\": delay_mean,\n",
        "                \"risk_GT\": int(is_risky_gt),\n",
        "                \"risk_pred_majority\": int(pred_risky),\n",
        "                \"FP\": fp,\n",
        "                \"FN\": fn,\n",
        "            })\n",
        "\n",
        "            print(f\"  [Audit b={b:.2f}] PoG_hat={pog_hat_mean:.4f}±{pog_hat_std:.4f} | \"\n",
        "                  f\"delay~{delay_mean:.1f} rounds | pred_risky={pred_risky} | FP={fp} FN={fn}\")\n",
        "\n",
        "    # 8.4 Aggregate reliability metrics across profiles for each budget\n",
        "    print(\"\\n=== Aggregated Estimator Reliability (across profiles) ===\")\n",
        "    rows_by_b = {}\n",
        "    for r in rows:\n",
        "        rows_by_b.setdefault(r[\"audit_budget\"], []).append(r)\n",
        "\n",
        "    for b, rs in rows_by_b.items():\n",
        "        gt = np.array([x[\"PoG_GT\"] for x in rs], dtype=np.float32)\n",
        "        hat = np.array([x[\"PoG_hat_mean\"] for x in rs], dtype=np.float32)\n",
        "\n",
        "        # Spearman rank correlation (manual, simple)\n",
        "        def rankdata(a):\n",
        "            temp = a.argsort()\n",
        "            ranks = np.empty_like(temp)\n",
        "            ranks[temp] = np.arange(len(a))\n",
        "            return ranks.astype(np.float32)\n",
        "\n",
        "        mask = np.isfinite(gt) & np.isfinite(hat)\n",
        "        if mask.sum() >= 2:\n",
        "            rg = rankdata(gt[mask])\n",
        "            rh = rankdata(hat[mask])\n",
        "            spearman = float(np.corrcoef(rg, rh)[0, 1])\n",
        "        else:\n",
        "            spearman = float(\"nan\")\n",
        "\n",
        "        FP = sum(x[\"FP\"] for x in rs)\n",
        "        FN = sum(x[\"FN\"] for x in rs)\n",
        "        n = len(rs)\n",
        "\n",
        "        print(f\"Audit b={b:.2f} | Spearman(PoG_GT, PoG_hat)={spearman:.3f} | \"\n",
        "              f\"FP={FP}/{n} FN={FN}/{n}\")\n",
        "\n",
        "    # 8.5 Save CSV\n",
        "    out_path = \"E1_estimator_reliability_results.csv\"\n",
        "    import csv\n",
        "    with open(out_path, \"w\", newline=\"\", encoding=\"utf-8\") as f:\n",
        "        writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))\n",
        "        writer.writeheader()\n",
        "        writer.writerows(rows)\n",
        "\n",
        "    print(f\"\\nSaved: {out_path}\")\n",
        "    print(\"Done.\")\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "9ZN1_uC3pGmt",
        "outputId": "ea547c01-92c4-4d48-cdcb-dc3c4482cfe5"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "=== E1: Estimator Reliability (GT vs Audit-based Estimator) ===\n",
            "Device: cuda\n",
            "Clients=12, Rounds=25, LocalEpochs=1\n",
            "Audit budgets=(0.1, 0.25, 0.5), trials=5\n",
            "\n",
            "[Profile gf=0.00] Round  1/25 | M(head public)=0.6346 | W_full_tail(mean clients)=0.5981\n",
            "[Profile gf=0.00] Round 10/25 | M(head public)=0.8438 | W_full_tail(mean clients)=0.8482\n",
            "[Profile gf=0.00] Round 20/25 | M(head public)=0.8705 | W_full_tail(mean clients)=0.8552\n",
            "\n",
            "[Baseline] Reference values (tail-mean over last K rounds)\n",
            "  W_ref (tail welfare) = 0.8822\n",
            "  M_ref (head metric)  = 0.8621\n",
            "\n",
            "[Profile gf=0.00] Round  1/25 | M(head public)=0.5285 | W_full_tail(mean clients)=0.6210\n",
            "[Profile gf=0.00] Round 10/25 | M(head public)=0.8358 | W_full_tail(mean clients)=0.8682\n",
            "[Profile gf=0.00] Round 20/25 | M(head public)=0.8625 | W_full_tail(mean clients)=0.8922\n",
            "[Profile gf=0.10] Round  1/25 | M(head public)=0.5904 | W_full_tail(mean clients)=0.6928\n",
            "[Profile gf=0.10] Round 10/25 | M(head public)=0.8495 | W_full_tail(mean clients)=0.8594\n",
            "[Profile gf=0.10] Round 20/25 | M(head public)=0.8627 | W_full_tail(mean clients)=0.8944\n",
            "[Profile gf=0.20] Round  1/25 | M(head public)=0.6741 | W_full_tail(mean clients)=0.6403\n",
            "[Profile gf=0.20] Round 10/25 | M(head public)=0.8737 | W_full_tail(mean clients)=0.8174\n",
            "[Profile gf=0.20] Round 20/25 | M(head public)=0.8835 | W_full_tail(mean clients)=0.8829\n",
            "[Profile gf=0.30] Round  1/25 | M(head public)=0.7503 | W_full_tail(mean clients)=0.4927\n",
            "[Profile gf=0.30] Round 10/25 | M(head public)=0.8882 | W_full_tail(mean clients)=0.7935\n",
            "[Profile gf=0.30] Round 20/25 | M(head public)=0.9065 | W_full_tail(mean clients)=0.8449\n",
            "[Profile gf=0.40] Round  1/25 | M(head public)=0.7259 | W_full_tail(mean clients)=0.5711\n",
            "[Profile gf=0.40] Round 10/25 | M(head public)=0.8668 | W_full_tail(mean clients)=0.8007\n",
            "[Profile gf=0.40] Round 20/25 | M(head public)=0.8992 | W_full_tail(mean clients)=0.8676\n",
            "[Profile gf=0.50] Round  1/25 | M(head public)=0.7943 | W_full_tail(mean clients)=0.5059\n",
            "[Profile gf=0.50] Round 10/25 | M(head public)=0.9004 | W_full_tail(mean clients)=0.7919\n",
            "[Profile gf=0.50] Round 20/25 | M(head public)=0.9175 | W_full_tail(mean clients)=0.8285\n",
            "\n",
            "[Profile gf=0.00] GT summary:\n",
            "  M_final(head metric)  = 0.8668 (delta vs ref +0.0047)\n",
            "  W_full(tail welfare)  = 0.8861\n",
            "  PoG_GT                = -0.0044 (risk>=tau? False)\n",
            "  [Audit b=0.10] PoG_hat=0.0176±0.0272 | delay~1.0 rounds | pred_risky=False | FP=0 FN=0\n",
            "  [Audit b=0.25] PoG_hat=-0.0101±0.0085 | delay~1.0 rounds | pred_risky=False | FP=0 FN=0\n",
            "  [Audit b=0.50] PoG_hat=-0.0074±0.0075 | delay~1.0 rounds | pred_risky=False | FP=0 FN=0\n",
            "\n",
            "[Profile gf=0.10] GT summary:\n",
            "  M_final(head metric)  = 0.8667 (delta vs ref +0.0046)\n",
            "  W_full(tail welfare)  = 0.8874\n",
            "  PoG_GT                = -0.0059 (risk>=tau? False)\n",
            "  [Audit b=0.10] PoG_hat=0.0046±0.0214 | delay~1.0 rounds | pred_risky=False | FP=0 FN=0\n",
            "  [Audit b=0.25] PoG_hat=-0.0080±0.0082 | delay~1.0 rounds | pred_risky=False | FP=0 FN=0\n",
            "  [Audit b=0.50] PoG_hat=-0.0056±0.0031 | delay~1.0 rounds | pred_risky=False | FP=0 FN=0\n",
            "\n",
            "[Profile gf=0.20] GT summary:\n",
            "  M_final(head metric)  = 0.8844 (delta vs ref +0.0223)\n",
            "  W_full(tail welfare)  = 0.8841\n",
            "  PoG_GT                = -0.0022 (risk>=tau? False)\n",
            "  [Audit b=0.10] PoG_hat=0.0078±0.0101 | delay~1.2 rounds | pred_risky=False | FP=0 FN=0\n",
            "  [Audit b=0.25] PoG_hat=-0.0034±0.0065 | delay~1.0 rounds | pred_risky=False | FP=0 FN=0\n",
            "  [Audit b=0.50] PoG_hat=-0.0079±0.0067 | delay~1.0 rounds | pred_risky=False | FP=0 FN=0\n",
            "\n",
            "[Profile gf=0.30] GT summary:\n",
            "  M_final(head metric)  = 0.9113 (delta vs ref +0.0492)\n",
            "  W_full(tail welfare)  = 0.8454\n",
            "  PoG_GT                = 0.0417 (risk>=tau? False)\n",
            "  [Audit b=0.10] PoG_hat=0.0597±0.0184 | delay~1.0 rounds | pred_risky=True | FP=1 FN=0\n",
            "  [Audit b=0.25] PoG_hat=0.0377±0.0140 | delay~1.0 rounds | pred_risky=False | FP=0 FN=0\n",
            "  [Audit b=0.50] PoG_hat=0.0445±0.0088 | delay~1.0 rounds | pred_risky=False | FP=0 FN=0\n",
            "\n",
            "[Profile gf=0.40] GT summary:\n",
            "  M_final(head metric)  = 0.9052 (delta vs ref +0.0431)\n",
            "  W_full(tail welfare)  = 0.8642\n",
            "  PoG_GT                = 0.0204 (risk>=tau? False)\n",
            "  [Audit b=0.10] PoG_hat=0.0406±0.0242 | delay~1.4 rounds | pred_risky=False | FP=0 FN=0\n",
            "  [Audit b=0.25] PoG_hat=0.0172±0.0047 | delay~1.0 rounds | pred_risky=False | FP=0 FN=0\n",
            "  [Audit b=0.50] PoG_hat=0.0178±0.0053 | delay~1.0 rounds | pred_risky=False | FP=0 FN=0\n",
            "\n",
            "[Profile gf=0.50] GT summary:\n",
            "  M_final(head metric)  = 0.9218 (delta vs ref +0.0598)\n",
            "  W_full(tail welfare)  = 0.8405\n",
            "  PoG_GT                = 0.0472 (risk>=tau? False)\n",
            "  [Audit b=0.10] PoG_hat=0.0383±0.0320 | delay~1.0 rounds | pred_risky=False | FP=0 FN=0\n",
            "  [Audit b=0.25] PoG_hat=0.0455±0.0219 | delay~1.0 rounds | pred_risky=True | FP=1 FN=0\n",
            "  [Audit b=0.50] PoG_hat=0.0482±0.0043 | delay~1.0 rounds | pred_risky=True | FP=1 FN=0\n",
            "\n",
            "=== Aggregated Estimator Reliability (across profiles) ===\n",
            "Audit b=0.10 | Spearman(PoG_GT, PoG_hat)=0.771 | FP=1/6 FN=0/6\n",
            "Audit b=0.25 | Spearman(PoG_GT, PoG_hat)=0.943 | FP=1/6 FN=0/6\n",
            "Audit b=0.50 | Spearman(PoG_GT, PoG_hat)=0.771 | FP=1/6 FN=0/6\n",
            "\n",
            "Saved: E1_estimator_reliability_results.csv\n",
            "Done.\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Noise (privacy) and auditability"
      ],
      "metadata": {
        "id": "Qvf0KtmiL2H-"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "import random\n",
        "from dataclasses import dataclass\n",
        "from typing import Dict, List, Tuple\n",
        "\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "from torch.utils.data import DataLoader, Subset, ConcatDataset\n",
        "from torchvision import datasets, transforms\n",
        "\n",
        "\n",
        "# ============================================================\n",
        "# E2: Noise / Privacy Trade-off (DP-like update noise injection)\n",
        "# Standalone code (separate from your E1 script)\n",
        "#\n",
        "# Changes requested:\n",
        "# - noise_multipliers = (0.00, 0.05, 0.10)\n",
        "# - keep the rest of settings the same as the previous E2 code\n",
        "# - add more measured metrics\n",
        "# ============================================================\n",
        "\n",
        "# ----------------------------\n",
        "# 0) Config / Seed\n",
        "# ----------------------------\n",
        "@dataclass\n",
        "class Config:\n",
        "    # Keep small for Colab\n",
        "    n_clients: int = 12\n",
        "    n_rounds: int = 25\n",
        "    local_epochs: int = 1\n",
        "    batch_size: int = 64\n",
        "    lr: float = 0.01\n",
        "    momentum: float = 0.9\n",
        "\n",
        "    # Non-IID partition\n",
        "    dirichlet_alpha: float = 0.5\n",
        "\n",
        "    # Head/Tail split\n",
        "    head_classes: Tuple[int, ...] = (0, 1, 2, 3, 4)\n",
        "    tail_classes: Tuple[int, ...] = (5, 6, 7, 8, 9)\n",
        "\n",
        "    # Public leak (for GAMING_HEAD_ONLY_LEAK)\n",
        "    leak_fraction: float = 1.0\n",
        "\n",
        "    # Strategy mix\n",
        "    benign_frac: float = 0.10\n",
        "    gaming_frac_main: float = 0.30  # E2's \"gaming\" condition\n",
        "    update_scale_factor: float = 1.75  # for update-scaling gaming\n",
        "\n",
        "    # DP-like noise injection (privacy knob)\n",
        "    # noise_multiplier = sigma, noise std = sigma * clip_C\n",
        "    clip_C: float = 1.0\n",
        "    noise_multipliers: Tuple[float, ...] = (0.00, 0.05, 0.10)  # <-- requested\n",
        "\n",
        "    # Evaluation summary\n",
        "    tail_k: int = 7\n",
        "\n",
        "    seed: int = 42\n",
        "    root: str = \"./data\"\n",
        "    device: torch.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "\n",
        "CFG = Config()\n",
        "\n",
        "\n",
        "def set_seed(seed: int):\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    torch.cuda.manual_seed_all(seed)\n",
        "\n",
        "\n",
        "set_seed(CFG.seed)\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 1) Dataset preparation\n",
        "# ----------------------------\n",
        "transform = transforms.Compose([\n",
        "    transforms.ToTensor(),\n",
        "    transforms.Normalize((0.5,), (0.5,))\n",
        "])\n",
        "\n",
        "full_train = datasets.FashionMNIST(root=CFG.root, train=True, download=True, transform=transform)\n",
        "\n",
        "# We'll use:\n",
        "# - 50k for client-local data\n",
        "# - 10k as public validation candidates (from train)\n",
        "train_local_size = 50_000\n",
        "idx_all = np.arange(len(full_train))\n",
        "idx_local = idx_all[:train_local_size]\n",
        "idx_public_candidates = idx_all[train_local_size:]\n",
        "\n",
        "train_local_set = Subset(full_train, idx_local)\n",
        "targets_all = np.array(full_train.targets)\n",
        "\n",
        "# Public validation set: head-only from last 10k\n",
        "head_mask = np.isin(targets_all, list(CFG.head_classes))\n",
        "public_head_indices = [i for i in idx_public_candidates if head_mask[i]]\n",
        "public_head_indices = np.array(public_head_indices)\n",
        "\n",
        "rng = np.random.default_rng(CFG.seed)\n",
        "rng.shuffle(public_head_indices)\n",
        "\n",
        "leak_size = int(CFG.leak_fraction * len(public_head_indices))\n",
        "leak_indices = public_head_indices[:leak_size]\n",
        "\n",
        "public_val_set = Subset(full_train, public_head_indices)\n",
        "public_val_loader = DataLoader(public_val_set, batch_size=CFG.batch_size, shuffle=False)\n",
        "\n",
        "# Leak subset that gaming clients can append\n",
        "leak_subset = Subset(full_train, leak_indices)\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 2) Utilities: partition + filtering\n",
        "# ----------------------------\n",
        "def make_dirichlet_partitions(labels: np.ndarray, n_clients: int, alpha: float, rng: np.random.Generator):\n",
        "    \"\"\"Returns list of index arrays per client (indices w.r.t. train_local_set).\"\"\"\n",
        "    n_classes = int(labels.max()) + 1\n",
        "    client_indices: List[List[int]] = [[] for _ in range(n_clients)]\n",
        "\n",
        "    class_indices = []\n",
        "    for k in range(n_classes):\n",
        "        idx_k = np.where(labels == k)[0]\n",
        "        rng.shuffle(idx_k)\n",
        "        class_indices.append(idx_k)\n",
        "\n",
        "    for k in range(n_classes):\n",
        "        idx_k = class_indices[k]\n",
        "        n_k = len(idx_k)\n",
        "        if n_k == 0:\n",
        "            continue\n",
        "        proportions = rng.dirichlet(alpha * np.ones(n_clients))\n",
        "        proportions = proportions / proportions.sum()\n",
        "        splits = (np.cumsum(proportions) * n_k).astype(int)\n",
        "        shards = np.split(idx_k, splits[:-1])\n",
        "        for cid, shard in enumerate(shards):\n",
        "            client_indices[cid].extend(shard.tolist())\n",
        "\n",
        "    for cid in range(n_clients):\n",
        "        rng.shuffle(client_indices[cid])\n",
        "\n",
        "    return client_indices\n",
        "\n",
        "\n",
        "def filter_subset_by_labels(base_subset: Subset, allowed_labels: Tuple[int, ...]) -> Subset:\n",
        "    \"\"\"base_subset is a Subset of train_local_set; keep only items with labels in allowed_labels.\"\"\"\n",
        "    allowed = set(allowed_labels)\n",
        "    kept = []\n",
        "    for idx_in_local in base_subset.indices:\n",
        "        global_idx = train_local_set.indices[idx_in_local]  # index into full_train\n",
        "        y = int(full_train.targets[global_idx])\n",
        "        if y in allowed:\n",
        "            kept.append(idx_in_local)\n",
        "    return Subset(train_local_set, kept)\n",
        "\n",
        "\n",
        "def split_subset_train_eval(base_subset: Subset, eval_ratio: float = 0.2) -> Tuple[Subset, Subset]:\n",
        "    idxs = list(base_subset.indices)\n",
        "    rng_local = np.random.default_rng(CFG.seed + 123)\n",
        "    rng_local.shuffle(idxs)\n",
        "    n_eval = int(round(eval_ratio * len(idxs)))\n",
        "    eval_idxs = idxs[:n_eval]\n",
        "    train_idxs = idxs[n_eval:]\n",
        "    return Subset(train_local_set, train_idxs), Subset(train_local_set, eval_idxs)\n",
        "\n",
        "\n",
        "def tail_only_subset(base_subset: Subset) -> Subset:\n",
        "    return filter_subset_by_labels(base_subset, CFG.tail_classes)\n",
        "\n",
        "\n",
        "def oversample_tail(train_subset: Subset, factor: int = 2) -> ConcatDataset:\n",
        "    \"\"\"Benign cooperation: upweight tail samples in local training.\"\"\"\n",
        "    tail_sub = tail_only_subset(train_subset)\n",
        "    if len(tail_sub) == 0:\n",
        "        return ConcatDataset([train_subset])\n",
        "    reps = [tail_sub for _ in range(factor)]\n",
        "    return ConcatDataset([train_subset] + reps)\n",
        "\n",
        "\n",
        "# Build base client datasets (Dirichlet)\n",
        "train_targets_local = np.array(full_train.targets)[idx_local]\n",
        "client_indices = make_dirichlet_partitions(\n",
        "    labels=train_targets_local,\n",
        "    n_clients=CFG.n_clients,\n",
        "    alpha=CFG.dirichlet_alpha,\n",
        "    rng=np.random.default_rng(CFG.seed)\n",
        ")\n",
        "\n",
        "client_base_subsets = [Subset(train_local_set, idxs) for idxs in client_indices]\n",
        "\n",
        "client_train_subsets: List[Subset] = []\n",
        "client_eval_subsets: List[Subset] = []\n",
        "client_tail_eval_subsets: List[Subset] = []\n",
        "client_tail_eval_sizes: List[int] = []\n",
        "\n",
        "for cid in range(CFG.n_clients):\n",
        "    tr, ev = split_subset_train_eval(client_base_subsets[cid], eval_ratio=0.2)\n",
        "    client_train_subsets.append(tr)\n",
        "    client_eval_subsets.append(ev)\n",
        "    tail_ev = tail_only_subset(ev)\n",
        "    client_tail_eval_subsets.append(tail_ev)\n",
        "    client_tail_eval_sizes.append(len(tail_ev))\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 3) Model\n",
        "# ----------------------------\n",
        "class SimpleCNN(nn.Module):\n",
        "    def __init__(self, num_classes: int = 10):\n",
        "        super().__init__()\n",
        "        self.features = nn.Sequential(\n",
        "            nn.Conv2d(1, 32, kernel_size=3, padding=1),\n",
        "            nn.ReLU(inplace=True),\n",
        "            nn.MaxPool2d(2),\n",
        "            nn.Conv2d(32, 64, kernel_size=3, padding=1),\n",
        "            nn.ReLU(inplace=True),\n",
        "            nn.MaxPool2d(2),\n",
        "        )\n",
        "        self.classifier = nn.Sequential(\n",
        "            nn.Flatten(),\n",
        "            nn.Linear(64 * 7 * 7, 128),\n",
        "            nn.ReLU(inplace=True),\n",
        "            nn.Linear(128, num_classes),\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.features(x)\n",
        "        x = self.classifier(x)\n",
        "        return x\n",
        "\n",
        "\n",
        "def get_model() -> nn.Module:\n",
        "    return SimpleCNN().to(CFG.device)\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 4) Train/Eval helpers\n",
        "# ----------------------------\n",
        "@torch.no_grad()\n",
        "def evaluate_acc(model: nn.Module, loader: DataLoader) -> float:\n",
        "    model.eval()\n",
        "    correct, total = 0, 0\n",
        "    for x, y in loader:\n",
        "        x, y = x.to(CFG.device), y.to(CFG.device)\n",
        "        logits = model(x)\n",
        "        pred = logits.argmax(dim=1)\n",
        "        correct += (pred == y).sum().item()\n",
        "        total += y.size(0)\n",
        "    return float(correct / total) if total > 0 else 0.0\n",
        "\n",
        "\n",
        "def local_train(model: nn.Module, dataset, epochs: int) -> nn.Module:\n",
        "    \"\"\"Train `model` in place on `dataset` for `epochs` passes and return it.\n",
        "\n",
        "    Uses shuffled mini-batches, SGD with momentum and cross-entropy loss; batch size,\n",
        "    learning rate, momentum and device are all read from CFG.\n",
        "    \"\"\"\n",
        "    batches = DataLoader(dataset, batch_size=CFG.batch_size, shuffle=True)\n",
        "    optimizer = optim.SGD(model.parameters(), lr=CFG.lr, momentum=CFG.momentum)\n",
        "    criterion = nn.CrossEntropyLoss()\n",
        "\n",
        "    model.train()\n",
        "    for _epoch in range(epochs):\n",
        "        for inputs, labels in batches:\n",
        "            inputs = inputs.to(CFG.device)\n",
        "            labels = labels.to(CFG.device)\n",
        "            optimizer.zero_grad()\n",
        "            criterion(model(inputs), labels).backward()\n",
        "            optimizer.step()\n",
        "    return model\n",
        "\n",
        "\n",
        "def fedavg_from_states(states: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:\n",
        "    \"\"\"Element-wise, unweighted mean of several state dicts (plain FedAvg).\n",
        "\n",
        "    Args:\n",
        "        states: non-empty list of state dicts; every dict must contain the keys of\n",
        "            states[0] with matching shapes.\n",
        "\n",
        "    Returns:\n",
        "        Dict mapping each key of states[0] to the mean of that entry across all\n",
        "        states, in the entry's original dtype.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if `states` is empty.\n",
        "    \"\"\"\n",
        "    if not states:\n",
        "        raise ValueError(\"fedavg_from_states requires at least one state dict\")\n",
        "\n",
        "    avg = {}\n",
        "    for key in states[0].keys():\n",
        "        stacked = torch.stack([sd[key] for sd in states], dim=0)\n",
        "        if stacked.is_floating_point() or stacked.is_complex():\n",
        "            avg[key] = stacked.mean(dim=0)\n",
        "        else:\n",
        "            # Tensor.mean() rejects integer/bool tensors (e.g. BatchNorm's\n",
        "            # num_batches_tracked counter), so average in float and round back.\n",
        "            avg[key] = stacked.to(torch.float64).mean(dim=0).round().to(stacked.dtype)\n",
        "    return avg\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 5) Strategy definitions\n",
        "# ----------------------------\n",
        "# Strategy identifiers; used as the values of the client-id -> strategy map.\n",
        "# HONEST: the client trains on its own subset and sends an unscaled update.\n",
        "STR_HONEST = \"HONEST\"\n",
        "# BENIGN: the client's subset is passed through oversample_tail(..., factor=2).\n",
        "STR_BENIGN = \"BENIGN_TAIL_UPWEIGHT\"\n",
        "# GAMING (head-only + leak): head-class samples only, plus the leaked subset.\n",
        "STR_GAME_HEAD = \"GAMING_HEAD_ONLY_LEAK\"\n",
        "# GAMING (update scaling): unmodified data, but the transmitted delta is\n",
        "# multiplied by CFG.update_scale_factor.\n",
        "STR_GAME_SCALE = \"GAMING_UPDATE_SCALING\"\n",
        "\n",
        "\n",
        "def assign_strategies(n_clients: int,\n",
        "                      gaming_frac: float,\n",
        "                      benign_frac: float,\n",
        "                      rng: np.random.Generator) -> Dict[int, str]:\n",
        "    \"\"\"\n",
        "    Randomly assign each client a strategy from {HONEST, BENIGN, GAMING_HEAD, GAMING_SCALE}.\n",
        "\n",
        "    Client ids 0..n_clients-1 are shuffled in place with `rng`, then split in order:\n",
        "      - the first round(benign_frac * n_clients) ids get BENIGN\n",
        "      - the next round(gaming_frac * n_clients) ids get a GAMING strategy\n",
        "        (first half head-only; the rest, including any odd one out, update-scaling)\n",
        "      - all remaining ids are HONEST\n",
        "    If the two counts together exceed n_clients, slicing truncates the gaming\n",
        "    group (and leaves no honest clients) rather than raising.\n",
        "\n",
        "    Returns:\n",
        "        Dict keyed by built-in Python ints (not numpy integers), matching the\n",
        "        annotation, so the map can be serialized or type-checked safely.\n",
        "    \"\"\"\n",
        "    shuffled = np.arange(n_clients)\n",
        "    rng.shuffle(shuffled)\n",
        "    # .tolist() turns the np.int64 ids into plain ints before they become dict keys.\n",
        "    ids = shuffled.tolist()\n",
        "\n",
        "    n_benign = int(round(benign_frac * n_clients))\n",
        "    n_gaming = int(round(gaming_frac * n_clients))\n",
        "\n",
        "    benign_ids = ids[:n_benign]\n",
        "    gaming_ids = ids[n_benign:n_benign + n_gaming]\n",
        "    honest_ids = ids[n_benign + n_gaming:]\n",
        "\n",
        "    strat: Dict[int, str] = {}\n",
        "    for cid in benign_ids:\n",
        "        strat[cid] = STR_BENIGN\n",
        "    for cid in honest_ids:\n",
        "        strat[cid] = STR_HONEST\n",
        "\n",
        "    half = len(gaming_ids) // 2\n",
        "    for pos, cid in enumerate(gaming_ids):\n",
        "        strat[cid] = STR_GAME_HEAD if pos < half else STR_GAME_SCALE\n",
        "\n",
        "    return strat\n",
        "\n",
        "\n",
        "def build_client_train_dataset(cid: int, strategy: str):\n",
        "    \"\"\"Return the training dataset that client `cid` uses under `strategy`.\n",
        "\n",
        "    HONEST and GAMING_UPDATE_SCALING use the client's unmodified subset; BENIGN passes\n",
        "    it through oversample_tail(factor=2); GAMING_HEAD_ONLY_LEAK keeps only the\n",
        "    CFG.head_classes samples and concatenates the leak subset.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if `strategy` is not one of the known identifiers.\n",
        "    \"\"\"\n",
        "    own_data = client_train_subsets[cid]\n",
        "\n",
        "    if strategy in (STR_HONEST, STR_GAME_SCALE):\n",
        "        return own_data\n",
        "    elif strategy == STR_BENIGN:\n",
        "        return oversample_tail(own_data, factor=2)\n",
        "    elif strategy == STR_GAME_HEAD:\n",
        "        head_part = filter_subset_by_labels(own_data, CFG.head_classes)\n",
        "        return ConcatDataset([head_part, leak_subset])\n",
        "\n",
        "    raise ValueError(f\"Unknown strategy: {strategy}\")\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 6) DP-like noise injection on transmitted update\n",
        "# ----------------------------\n",
        "def _flatten_delta(delta_sd: Dict[str, torch.Tensor]) -> torch.Tensor:\n",
        "    \"\"\"Concatenate every tensor of `delta_sd`, flattened to 1-D, in dict iteration order.\"\"\"\n",
        "    return torch.cat([t.reshape(-1) for t in delta_sd.values()], dim=0)\n",
        "\n",
        "\n",
        "def _unflatten_to_sd(vec: torch.Tensor, template_sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:\n",
        "    \"\"\"Slice the flat `vec` back into tensors shaped like the entries of `template_sd`.\n",
        "\n",
        "    Elements are consumed in the template's key order, numel() at a time.\n",
        "    \"\"\"\n",
        "    rebuilt = {}\n",
        "    cursor = 0\n",
        "    for name, ref in template_sd.items():\n",
        "        stop = cursor + ref.numel()\n",
        "        rebuilt[name] = vec[cursor:stop].reshape(ref.shape)\n",
        "        cursor = stop\n",
        "    return rebuilt\n",
        "\n",
        "\n",
        "def dp_perturb_delta(delta_sd: Dict[str, torch.Tensor], clip_C: float, noise_mult: float) -> Dict[str, torch.Tensor]:\n",
        "    \"\"\"\n",
        "    Clip-then-noise perturbation of an update, in the style of DP-SGD.\n",
        "\n",
        "    The delta is handled as one flat vector: when its L2 norm exceeds clip_C it is\n",
        "    rescaled by clip_C / (norm + 1e-12); then, when noise_mult > 0, Gaussian noise\n",
        "    with std = noise_mult * clip_C is added. The result is returned in the same\n",
        "    key/shape layout as delta_sd.\n",
        "    \"\"\"\n",
        "    flat = _flatten_delta(delta_sd)\n",
        "\n",
        "    # Step 1: global-norm clipping\n",
        "    l2 = torch.norm(flat, p=2)\n",
        "    if l2 > clip_C:\n",
        "        shrink = clip_C / (l2 + 1e-12)\n",
        "        flat = flat * shrink\n",
        "\n",
        "    # Step 2: additive Gaussian noise (skipped entirely when noise_mult <= 0)\n",
        "    if noise_mult > 0:\n",
        "        sigma = noise_mult * clip_C\n",
        "        flat = flat + torch.randn_like(flat) * sigma\n",
        "\n",
        "    return _unflatten_to_sd(flat, delta_sd)\n",
        "\n",
        "\n",
        "def make_transmitted_state(global_sd: Dict[str, torch.Tensor],\n",
        "                           local_sd: Dict[str, torch.Tensor],\n",
        "                           scale: float,\n",
        "                           clip_C: float,\n",
        "                           noise_mult: float) -> Dict[str, torch.Tensor]:\n",
        "    \"\"\"\n",
        "    Build the state a client sends to the server:\n",
        "        transmitted = global + DP( scale * (local - global) )\n",
        "\n",
        "    Scaling is applied to the raw delta before it is clipped and noised by\n",
        "    dp_perturb_delta; keys are taken from global_sd.\n",
        "    \"\"\"\n",
        "    scaled_delta = {k: scale * (local_sd[k] - g) for k, g in global_sd.items()}\n",
        "\n",
        "    noisy_delta = dp_perturb_delta(scaled_delta, clip_C=clip_C, noise_mult=noise_mult)\n",
        "\n",
        "    return {k: g + noisy_delta[k] for k, g in global_sd.items()}\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 7) Run FL once (given gaming_frac and noise_mult)\n",
        "#    Per-round:\n",
        "#    - M_t: public head accuracy (metric)\n",
        "#    - W_t: mean tail accuracy across clients (welfare)\n",
        "#    - W_std_t: std tail acc across clients\n",
        "#    - W_min_t, W_max_t: min/max tail acc across clients\n",
        "#    - tail_nonempty_count_t: number of clients with tail eval > 0\n",
        "# ----------------------------\n",
        "def run_once(gaming_frac: float, benign_frac: float, noise_mult: float, seed_offset: int = 0):\n",
        "    \"\"\"\n",
        "    Simulate one federated run and record per-round metric/welfare histories.\n",
        "\n",
        "    Args:\n",
        "        gaming_frac: fraction of clients assigned a GAMING strategy.\n",
        "        benign_frac: fraction of clients assigned the BENIGN strategy.\n",
        "        noise_mult: noise multiplier passed to make_transmitted_state for every client.\n",
        "        seed_offset: added to CFG.seed to seed the RNG used for strategy assignment.\n",
        "            NOTE: this function does not reseed torch; model init and training\n",
        "            draw from whatever global RNG state is current.\n",
        "\n",
        "    Returns:\n",
        "        Dict with the run parameters, the strategy map and per-strategy counts, and\n",
        "        one array entry per round: M_hist, W_hist, W_std_hist, W_min_hist,\n",
        "        W_max_hist (float32) and tail_nonempty_hist (int32).\n",
        "    \"\"\"\n",
        "    rng = np.random.default_rng(CFG.seed + seed_offset)\n",
        "    strat_map = assign_strategies(CFG.n_clients, gaming_frac, benign_frac, rng)\n",
        "\n",
        "    # Training datasets depend only on the (fixed) strategy, so build them once.\n",
        "    client_train_ds = [build_client_train_dataset(cid, strat_map[cid]) for cid in range(CFG.n_clients)]\n",
        "\n",
        "    global_model = get_model()\n",
        "    # Detached copy of the weights; this dict is the server-side source of truth.\n",
        "    global_sd = {k: v.detach().clone() for k, v in global_model.state_dict().items()}\n",
        "\n",
        "    M_hist: List[float] = []\n",
        "    W_hist: List[float] = []\n",
        "    W_std_hist: List[float] = []\n",
        "    W_min_hist: List[float] = []\n",
        "    W_max_hist: List[float] = []\n",
        "    tail_nonempty_hist: List[int] = []\n",
        "\n",
        "    # For reporting how many are gaming/benign/honest (fixed for the run)\n",
        "    n_honest = sum(1 for s in strat_map.values() if s == STR_HONEST)\n",
        "    n_benign = sum(1 for s in strat_map.values() if s == STR_BENIGN)\n",
        "    n_ghead = sum(1 for s in strat_map.values() if s == STR_GAME_HEAD)\n",
        "    n_gscale = sum(1 for s in strat_map.values() if s == STR_GAME_SCALE)\n",
        "\n",
        "    for rnd in range(CFG.n_rounds):\n",
        "        transmitted_states: List[Dict[str, torch.Tensor]] = []\n",
        "\n",
        "        # Every client participates in every round.\n",
        "        for cid in range(CFG.n_clients):\n",
        "            # A new model object is built per client and then overwritten with the\n",
        "            # current global weights.\n",
        "            local_model = get_model()\n",
        "            local_model.load_state_dict(global_sd)\n",
        "\n",
        "            local_model = local_train(local_model, client_train_ds[cid], epochs=CFG.local_epochs)\n",
        "            local_sd = local_model.state_dict()\n",
        "\n",
        "            # scale for UPDATE_SCALING gaming; others scale=1\n",
        "            if strat_map[cid] == STR_GAME_SCALE and gaming_frac > 0:\n",
        "                scale = CFG.update_scale_factor\n",
        "            else:\n",
        "                scale = 1.0\n",
        "\n",
        "            # DP-like perturbation on transmitted update (applies to everyone)\n",
        "            tx_sd = make_transmitted_state(\n",
        "                global_sd=global_sd,\n",
        "                local_sd=local_sd,\n",
        "                scale=scale,\n",
        "                clip_C=CFG.clip_C,\n",
        "                noise_mult=noise_mult\n",
        "            )\n",
        "            transmitted_states.append(tx_sd)\n",
        "\n",
        "        # FedAvg update\n",
        "        global_sd = fedavg_from_states(transmitted_states)\n",
        "        global_model.load_state_dict(global_sd)\n",
        "\n",
        "        # Metric: server-observable public head validation accuracy\n",
        "        M_t = evaluate_acc(global_model, public_val_loader)\n",
        "        M_hist.append(M_t)\n",
        "\n",
        "        # Welfare: tail accuracy distribution across clients (experimenter-side evaluation)\n",
        "        tail_accs = []\n",
        "        nonempty = 0\n",
        "        for cid in range(CFG.n_clients):\n",
        "            tail_ev = client_tail_eval_subsets[cid]\n",
        "            # Clients without tail evaluation samples are excluded from the welfare stats.\n",
        "            if len(tail_ev) == 0:\n",
        "                continue\n",
        "            nonempty += 1\n",
        "            loader = DataLoader(tail_ev, batch_size=CFG.batch_size, shuffle=False)\n",
        "            tail_accs.append(evaluate_acc(global_model, loader))\n",
        "\n",
        "        tail_nonempty_hist.append(nonempty)\n",
        "\n",
        "        if len(tail_accs) > 0:\n",
        "            W_t = float(np.mean(tail_accs))\n",
        "            W_std_t = float(np.std(tail_accs))\n",
        "            W_min_t = float(np.min(tail_accs))\n",
        "            W_max_t = float(np.max(tail_accs))\n",
        "        else:\n",
        "            # NOTE(review): with no tail clients the mean is recorded as 0.0 while\n",
        "            # std/min/max are NaN - confirm this asymmetry is intended.\n",
        "            W_t, W_std_t, W_min_t, W_max_t = 0.0, float(\"nan\"), float(\"nan\"), float(\"nan\")\n",
        "\n",
        "        W_hist.append(W_t)\n",
        "        W_std_hist.append(W_std_t)\n",
        "        W_min_hist.append(W_min_t)\n",
        "        W_max_hist.append(W_max_t)\n",
        "\n",
        "        # Progress line on the first round and every 10th round.\n",
        "        if (rnd + 1) % 10 == 0 or rnd == 0:\n",
        "            print(f\"[gf={gaming_frac:.2f} | noise={noise_mult:.2f}] \"\n",
        "                  f\"Round {rnd+1:>2}/{CFG.n_rounds} | M(head)={M_t:.4f} | \"\n",
        "                  f\"W_tail(mean)={W_t:.4f} | W_std={W_std_t:.4f} | n_tail_clients={nonempty}\")\n",
        "\n",
        "    return {\n",
        "        \"gaming_frac\": gaming_frac,\n",
        "        \"benign_frac\": benign_frac,\n",
        "        \"noise_mult\": noise_mult,\n",
        "        \"strategy_map\": strat_map,\n",
        "        \"strategy_counts\": {\n",
        "            \"honest\": n_honest,\n",
        "            \"benign\": n_benign,\n",
        "            \"gaming_head\": n_ghead,\n",
        "            \"gaming_scale\": n_gscale,\n",
        "        },\n",
        "        \"M_hist\": np.array(M_hist, dtype=np.float32),\n",
        "        \"W_hist\": np.array(W_hist, dtype=np.float32),\n",
        "        \"W_std_hist\": np.array(W_std_hist, dtype=np.float32),\n",
        "        \"W_min_hist\": np.array(W_min_hist, dtype=np.float32),\n",
        "        \"W_max_hist\": np.array(W_max_hist, dtype=np.float32),\n",
        "        \"tail_nonempty_hist\": np.array(tail_nonempty_hist, dtype=np.int32),\n",
        "    }\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 8) Summary metrics\n",
        "# ----------------------------\n",
        "def tail_slice(x: np.ndarray, k: int) -> np.ndarray:\n",
        "    \"\"\"Return the last `k` elements of `x` (all of `x` if it has at most `k`).\n",
        "\n",
        "    For k <= 0 an empty slice is returned. A bare `x[-k:]` would instead yield the\n",
        "    whole array when k == 0, because `-0` is just `0`.\n",
        "    \"\"\"\n",
        "    if k <= 0:\n",
        "        return x[:0]\n",
        "    return x if len(x) <= k else x[-k:]\n",
        "\n",
        "\n",
        "def mean_last_k(x: np.ndarray, k: int) -> float:\n",
        "    \"\"\"Mean of the window tail_slice(x, k); NaN when that window is empty.\"\"\"\n",
        "    window = tail_slice(x, k)\n",
        "    if len(window) == 0:\n",
        "        return float(\"nan\")\n",
        "    return float(np.mean(window))\n",
        "\n",
        "\n",
        "def std_last_k(x: np.ndarray, k: int) -> float:\n",
        "    \"\"\"Standard deviation (np.std defaults) of tail_slice(x, k); NaN when empty.\"\"\"\n",
        "    window = tail_slice(x, k)\n",
        "    if len(window) == 0:\n",
        "        return float(\"nan\")\n",
        "    return float(np.std(window))\n",
        "\n",
        "\n",
        "def min_last_k(x: np.ndarray, k: int) -> float:\n",
        "    \"\"\"Minimum of the window tail_slice(x, k); NaN when that window is empty.\"\"\"\n",
        "    window = tail_slice(x, k)\n",
        "    if len(window) == 0:\n",
        "        return float(\"nan\")\n",
        "    return float(np.min(window))\n",
        "\n",
        "\n",
        "def max_last_k(x: np.ndarray, k: int) -> float:\n",
        "    \"\"\"Maximum of the window tail_slice(x, k); NaN when that window is empty.\"\"\"\n",
        "    window = tail_slice(x, k)\n",
        "    if len(window) == 0:\n",
        "        return float(\"nan\")\n",
        "    return float(np.max(window))\n",
        "\n",
        "\n",
        "def compute_pog(W_ref: float, W_final: float) -> float:\n",
        "    \"\"\"Relative welfare drop (W_ref - W_final) / W_ref.\n",
        "\n",
        "    Returns NaN when W_ref is non-finite or not greater than 1e-8, since the ratio\n",
        "    would be undefined or numerically meaningless.\n",
        "    \"\"\"\n",
        "    ref_is_usable = np.isfinite(W_ref) and W_ref > 1e-8\n",
        "    if ref_is_usable:\n",
        "        return float((W_ref - W_final) / W_ref)\n",
        "    return float(\"nan\")\n",
        "\n",
        "\n",
        "def safe_div(a: float, b: float, eps: float = 1e-12) -> float:\n",
        "    \"\"\"Divide `a` by `b` with the denominator pushed away from zero by `eps`.\n",
        "\n",
        "    `eps` is applied in the direction of b's sign (b >= 0 counts as positive), so the\n",
        "    denominator's magnitude is at least eps. Adding eps unconditionally would cancel\n",
        "    a small negative b (b == -eps gives a zero denominator). For b >= 0 the result is\n",
        "    exactly a / (b + eps).\n",
        "    \"\"\"\n",
        "    denom = b + eps if b >= 0 else b - eps\n",
        "    return float(a / denom)\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 9) Main: Sweep noise levels, compare aligned vs gaming\n",
        "# ----------------------------\n",
        "def main():\n",
        "    \"\"\"Run experiment E2: sweep DP-like noise levels for aligned vs. gaming clients.\n",
        "\n",
        "    Steps:\n",
        "      1. Run a noise-free, gaming-free reference and take its last-K means (M_ref, W_ref).\n",
        "      2. For each value in CFG.noise_multipliers run one ALIGNED (gf=0) and one GAMING\n",
        "         (gf=CFG.gaming_frac_main) simulation, each with its own seed offset, and\n",
        "         summarize both into flat rows plus paired same-noise deltas.\n",
        "      3. Write all rows to E2_noise_privacy_tradeoff_results.csv (skipped when the\n",
        "         sweep is empty, so an existing file is not truncated).\n",
        "\n",
        "    Progress and summaries are printed to stdout. Returns None.\n",
        "    \"\"\"\n",
        "    import csv\n",
        "\n",
        "    print(\"\\n=== E2: Noise / Privacy Trade-off (DP-like update noise) ===\")\n",
        "    print(f\"Device: {CFG.device}\")\n",
        "    print(f\"Clients={CFG.n_clients}, Rounds={CFG.n_rounds}, LocalEpochs={CFG.local_epochs}\")\n",
        "    print(f\"clip_C={CFG.clip_C}, noise_multipliers={CFG.noise_multipliers}\")\n",
        "    print(f\"benign_frac={CFG.benign_frac}, gaming_frac_main={CFG.gaming_frac_main}\\n\")\n",
        "\n",
        "    # 9.1 Reference baseline (aligned, noise=0)\n",
        "    print(\">> Running reference baseline: gf=0.0, noise=0.0\")\n",
        "    ref = run_once(gaming_frac=0.0, benign_frac=CFG.benign_frac, noise_mult=0.0, seed_offset=0)\n",
        "    M_ref = mean_last_k(ref[\"M_hist\"], CFG.tail_k)\n",
        "    W_ref = mean_last_k(ref[\"W_hist\"], CFG.tail_k)\n",
        "\n",
        "    print(\"\\n[Reference] tail-mean over last K rounds\")\n",
        "    print(f\"  M_ref (head metric) = {M_ref:.4f}\")\n",
        "    print(f\"  W_ref (tail welfare)= {W_ref:.4f}\\n\")\n",
        "\n",
        "    # Defined once, outside the sweep loop; closes over M_ref and W_ref.\n",
        "    def summarize(tag: str, out: Dict):\n",
        "        \"\"\"Collapse one run_once() result into a flat row of last-K summary statistics.\"\"\"\n",
        "        M_final = mean_last_k(out[\"M_hist\"], CFG.tail_k)\n",
        "        W_final = mean_last_k(out[\"W_hist\"], CFG.tail_k)\n",
        "        gap = M_final - W_final\n",
        "\n",
        "        # \"Classic\" PoG vs reference W_ref (kept as-is; does NOT change your paper's definition)\n",
        "        pog = compute_pog(W_ref, W_final)\n",
        "\n",
        "        # Additional metrics (more stable / more informative under noise)\n",
        "        M_std = std_last_k(out[\"M_hist\"], CFG.tail_k)\n",
        "        W_std = std_last_k(out[\"W_hist\"], CFG.tail_k)\n",
        "\n",
        "        # Tail distribution stats across clients (per-round std/min/max, summarized over last K)\n",
        "        W_std_across_clients_mean = mean_last_k(out[\"W_std_hist\"], CFG.tail_k)\n",
        "        W_min = mean_last_k(out[\"W_min_hist\"], CFG.tail_k)\n",
        "        W_max = mean_last_k(out[\"W_max_hist\"], CFG.tail_k)\n",
        "\n",
        "        # Stability / spread\n",
        "        range_M = max_last_k(out[\"M_hist\"], CFG.tail_k) - min_last_k(out[\"M_hist\"], CFG.tail_k)\n",
        "        range_W = max_last_k(out[\"W_hist\"], CFG.tail_k) - min_last_k(out[\"W_hist\"], CFG.tail_k)\n",
        "\n",
        "        # Deltas vs reference\n",
        "        dM_vs_ref = M_final - M_ref\n",
        "        dW_vs_ref = W_final - W_ref\n",
        "\n",
        "        # Normalized gap ratio (how large is gap relative to welfare)\n",
        "        gap_ratio = safe_div(gap, max(W_final, 0.0))\n",
        "\n",
        "        # Average number of clients with non-empty tail eval (should be constant, but log anyway)\n",
        "        tail_nonempty = int(np.round(mean_last_k(out[\"tail_nonempty_hist\"].astype(np.float32), CFG.tail_k)))\n",
        "\n",
        "        sc = out[\"strategy_counts\"]\n",
        "\n",
        "        return {\n",
        "            \"condition\": tag,\n",
        "            \"noise_mult\": float(out[\"noise_mult\"]),\n",
        "            \"gaming_frac\": float(out[\"gaming_frac\"]),\n",
        "            \"benign_frac\": float(out[\"benign_frac\"]),\n",
        "            \"n_clients\": int(CFG.n_clients),\n",
        "            \"n_honest\": int(sc[\"honest\"]),\n",
        "            \"n_benign\": int(sc[\"benign\"]),\n",
        "            \"n_gaming_head\": int(sc[\"gaming_head\"]),\n",
        "            \"n_gaming_scale\": int(sc[\"gaming_scale\"]),\n",
        "\n",
        "            # Primary end metrics\n",
        "            \"M_final\": float(M_final),\n",
        "            \"W_tail_final\": float(W_final),\n",
        "            \"gap_M_minus_W\": float(gap),\n",
        "            \"gap_ratio_gap_over_W\": float(gap_ratio),\n",
        "            \"PoG_vs_ref\": float(pog),\n",
        "\n",
        "            # Deltas vs reference\n",
        "            \"M_delta_vs_ref\": float(dM_vs_ref),\n",
        "            \"W_delta_vs_ref\": float(dW_vs_ref),\n",
        "\n",
        "            # Within-run temporal stability (last K rounds)\n",
        "            \"M_std_lastK\": float(M_std),\n",
        "            \"W_std_lastK\": float(W_std),\n",
        "            \"M_range_lastK\": float(range_M),\n",
        "            \"W_range_lastK\": float(range_W),\n",
        "\n",
        "            # Across-client tail distribution stats (last K rounds averaged)\n",
        "            \"W_std_across_clients_lastK\": float(W_std_across_clients_mean),\n",
        "            \"W_min_across_clients_lastK\": float(W_min),\n",
        "            \"W_max_across_clients_lastK\": float(W_max),\n",
        "\n",
        "            # Tail eval coverage\n",
        "            \"tail_nonempty_clients_lastK\": int(tail_nonempty),\n",
        "        }\n",
        "\n",
        "    # 9.2 Sweep noise for two conditions: aligned (gf=0) vs gaming (gf=gaming_frac_main)\n",
        "    results = []\n",
        "    seed_offset = 10\n",
        "    for nm in CFG.noise_multipliers:\n",
        "        # aligned under noise\n",
        "        print(f\"\\n>> Running ALIGNED under noise={nm:.2f} (gf=0.0)\")\n",
        "        aligned = run_once(gaming_frac=0.0, benign_frac=CFG.benign_frac, noise_mult=nm, seed_offset=seed_offset)\n",
        "        seed_offset += 1\n",
        "\n",
        "        # gaming under noise\n",
        "        print(f\"\\n>> Running GAMING under noise={nm:.2f} (gf={CFG.gaming_frac_main:.2f})\")\n",
        "        gaming = run_once(gaming_frac=CFG.gaming_frac_main, benign_frac=CFG.benign_frac, noise_mult=nm, seed_offset=seed_offset)\n",
        "        seed_offset += 1\n",
        "\n",
        "        row_aligned = summarize(\"ALIGNED\", aligned)\n",
        "        row_gaming = summarize(\"GAMING\", gaming)\n",
        "\n",
        "        # Paired \"extra effect\" metrics at the same noise (does NOT change PoG definition)\n",
        "        # These help you talk about \"additional harm under gaming\" without redefining PoG.\n",
        "        deltaW_same_noise = row_aligned[\"W_tail_final\"] - row_gaming[\"W_tail_final\"]\n",
        "        deltaGap_same_noise = row_gaming[\"gap_M_minus_W\"] - row_aligned[\"gap_M_minus_W\"]\n",
        "\n",
        "        # Relative additional harm ratio (if aligned welfare is too small, it will blow up; keep it but be cautious)\n",
        "        rel_deltaW = safe_div(deltaW_same_noise, max(row_aligned[\"W_tail_final\"], 0.0))\n",
        "\n",
        "        # Both rows carry the same paired deltas (same key order as the CSV header).\n",
        "        paired = {\n",
        "            \"paired_deltaW_aligned_minus_gaming\": float(deltaW_same_noise),\n",
        "            \"paired_deltaGap_gaming_minus_aligned\": float(deltaGap_same_noise),\n",
        "            \"paired_rel_deltaW_over_Waligned\": float(rel_deltaW),\n",
        "        }\n",
        "        for row in (row_aligned, row_gaming):\n",
        "            row.update(paired)\n",
        "            results.append(row)\n",
        "\n",
        "        # quick print\n",
        "        print(f\"\\n[Summary @ noise={nm:.2f}]\")\n",
        "        print(f\"  ALIGNED: M={row_aligned['M_final']:.4f} | W={row_aligned['W_tail_final']:.4f} | \"\n",
        "              f\"gap={row_aligned['gap_M_minus_W']:.4f} | PoG(ref)={row_aligned['PoG_vs_ref']:.4f}\")\n",
        "        print(f\"  GAMING : M={row_gaming['M_final']:.4f} | W={row_gaming['W_tail_final']:.4f} | \"\n",
        "              f\"gap={row_gaming['gap_M_minus_W']:.4f} | PoG(ref)={row_gaming['PoG_vs_ref']:.4f}\")\n",
        "        print(f\"  Paired: ΔW(aligned-gaming)={deltaW_same_noise:+.4f} | \"\n",
        "              f\"Δgap(gaming-aligned)={deltaGap_same_noise:+.4f} | relΔW={rel_deltaW:+.4f}\")\n",
        "\n",
        "    # 9.3 Save CSV\n",
        "    if not results:\n",
        "        # Empty sweep: there is no row to derive the header from, and opening the\n",
        "        # file for writing would only truncate any previous results.\n",
        "        print(\"\\nNo noise multipliers configured; no CSV written.\")\n",
        "        print(\"Done.\")\n",
        "        return\n",
        "\n",
        "    out_path = \"E2_noise_privacy_tradeoff_results.csv\"\n",
        "    with open(out_path, \"w\", newline=\"\", encoding=\"utf-8\") as f:\n",
        "        fieldnames = list(results[0].keys())\n",
        "        writer = csv.DictWriter(f, fieldnames=fieldnames)\n",
        "        writer.writeheader()\n",
        "        writer.writerows(results)\n",
        "\n",
        "    print(f\"\\nSaved: {out_path}\")\n",
        "    print(\"Done.\")\n",
        "\n",
        "\n",
        "# Script-style entry point: runs the full E2 sweep whenever this code executes as __main__.\n",
        "if __name__ == \"__main__\":\n",
        "    main()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "wTzIl0xGiUtb",
        "outputId": "8ab1eb33-2374-4bee-ffd8-63027159d41d"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "=== E2: Noise / Privacy Trade-off (DP-like update noise) ===\n",
            "Device: cuda\n",
            "Clients=12, Rounds=25, LocalEpochs=1\n",
            "clip_C=1.0, noise_multipliers=(0.0, 0.05, 0.1)\n",
            "benign_frac=0.1, gaming_frac_main=0.3\n",
            "\n",
            ">> Running reference baseline: gf=0.0, noise=0.0\n",
            "[gf=0.00 | noise=0.00] Round  1/25 | M(head)=0.4477 | W_tail(mean)=0.4587 | W_std=0.1238 | n_tail_clients=12\n",
            "[gf=0.00 | noise=0.00] Round 10/25 | M(head)=0.8371 | W_tail(mean)=0.8291 | W_std=0.0725 | n_tail_clients=12\n",
            "[gf=0.00 | noise=0.00] Round 20/25 | M(head)=0.8690 | W_tail(mean)=0.8483 | W_std=0.0610 | n_tail_clients=12\n",
            "\n",
            "[Reference] tail-mean over last K rounds\n",
            "  M_ref (head metric) = 0.8605\n",
            "  W_ref (tail welfare)= 0.8767\n",
            "\n",
            "\n",
            ">> Running ALIGNED under noise=0.00 (gf=0.0)\n",
            "[gf=0.00 | noise=0.00] Round  1/25 | M(head)=0.5039 | W_tail(mean)=0.5282 | W_std=0.1274 | n_tail_clients=12\n",
            "[gf=0.00 | noise=0.00] Round 10/25 | M(head)=0.8389 | W_tail(mean)=0.8468 | W_std=0.0602 | n_tail_clients=12\n",
            "[gf=0.00 | noise=0.00] Round 20/25 | M(head)=0.8613 | W_tail(mean)=0.8854 | W_std=0.0580 | n_tail_clients=12\n",
            "\n",
            ">> Running GAMING under noise=0.00 (gf=0.30)\n",
            "[gf=0.30 | noise=0.00] Round  1/25 | M(head)=0.6259 | W_tail(mean)=0.5145 | W_std=0.2273 | n_tail_clients=12\n",
            "[gf=0.30 | noise=0.00] Round 10/25 | M(head)=0.8780 | W_tail(mean)=0.7559 | W_std=0.1258 | n_tail_clients=12\n",
            "[gf=0.30 | noise=0.00] Round 20/25 | M(head)=0.9037 | W_tail(mean)=0.8399 | W_std=0.0772 | n_tail_clients=12\n",
            "\n",
            "[Summary @ noise=0.00]\n",
            "  ALIGNED: M=0.8656 | W=0.8831 | gap=-0.0175 | PoG(ref)=-0.0074\n",
            "  GAMING : M=0.9046 | W=0.8289 | gap=0.0757 | PoG(ref)=0.0545\n",
            "  Paired: ΔW(aligned-gaming)=+0.0542 | Δgap(gaming-aligned)=+0.0932 | relΔW=+0.0614\n",
            "\n",
            ">> Running ALIGNED under noise=0.05 (gf=0.0)\n",
            "[gf=0.00 | noise=0.05] Round  1/25 | M(head)=0.2729 | W_tail(mean)=0.3869 | W_std=0.1292 | n_tail_clients=12\n",
            "[gf=0.00 | noise=0.05] Round 10/25 | M(head)=0.8236 | W_tail(mean)=0.8277 | W_std=0.0615 | n_tail_clients=12\n",
            "[gf=0.00 | noise=0.05] Round 20/25 | M(head)=0.8363 | W_tail(mean)=0.8973 | W_std=0.0390 | n_tail_clients=12\n",
            "\n",
            ">> Running GAMING under noise=0.05 (gf=0.30)\n",
            "[gf=0.30 | noise=0.05] Round  1/25 | M(head)=0.5874 | W_tail(mean)=0.3098 | W_std=0.0933 | n_tail_clients=12\n",
            "[gf=0.30 | noise=0.05] Round 10/25 | M(head)=0.8794 | W_tail(mean)=0.7741 | W_std=0.0711 | n_tail_clients=12\n",
            "[gf=0.30 | noise=0.05] Round 20/25 | M(head)=0.9098 | W_tail(mean)=0.8169 | W_std=0.0803 | n_tail_clients=12\n",
            "\n",
            "[Summary @ noise=0.05]\n",
            "  ALIGNED: M=0.8419 | W=0.8802 | gap=-0.0383 | PoG(ref)=-0.0041\n",
            "  GAMING : M=0.9101 | W=0.8290 | gap=0.0812 | PoG(ref)=0.0544\n",
            "  Paired: ΔW(aligned-gaming)=+0.0513 | Δgap(gaming-aligned)=+0.1195 | relΔW=+0.0583\n",
            "\n",
            ">> Running ALIGNED under noise=0.10 (gf=0.0)\n",
            "[gf=0.00 | noise=0.10] Round  1/25 | M(head)=0.3495 | W_tail(mean)=0.1260 | W_std=0.0525 | n_tail_clients=12\n",
            "[gf=0.00 | noise=0.10] Round 10/25 | M(head)=0.7833 | W_tail(mean)=0.8095 | W_std=0.0760 | n_tail_clients=12\n",
            "[gf=0.00 | noise=0.10] Round 20/25 | M(head)=0.7513 | W_tail(mean)=0.8679 | W_std=0.0355 | n_tail_clients=12\n",
            "\n",
            ">> Running GAMING under noise=0.10 (gf=0.30)\n",
            "[gf=0.30 | noise=0.10] Round  1/25 | M(head)=0.3039 | W_tail(mean)=0.3438 | W_std=0.1860 | n_tail_clients=12\n",
            "[gf=0.30 | noise=0.10] Round 10/25 | M(head)=0.8568 | W_tail(mean)=0.7344 | W_std=0.1002 | n_tail_clients=12\n",
            "[gf=0.30 | noise=0.10] Round 20/25 | M(head)=0.8530 | W_tail(mean)=0.7952 | W_std=0.0416 | n_tail_clients=12\n",
            "\n",
            "[Summary @ noise=0.10]\n",
            "  ALIGNED: M=0.8018 | W=0.8440 | gap=-0.0422 | PoG(ref)=0.0373\n",
            "  GAMING : M=0.8592 | W=0.7807 | gap=0.0785 | PoG(ref)=0.1095\n",
            "  Paired: ΔW(aligned-gaming)=+0.0633 | Δgap(gaming-aligned)=+0.1207 | relΔW=+0.0750\n",
            "\n",
            "Saved: E2_noise_privacy_tradeoff_results.csv\n",
            "Done.\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# High-alignment metrics"
      ],
      "metadata": {
        "id": "hOAtK0XQ4CBn"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "import random\n",
        "from dataclasses import dataclass\n",
        "from typing import Dict, List, Tuple\n",
        "\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "from torch.utils.data import DataLoader, Subset, ConcatDataset\n",
        "from torchvision import datasets, transforms\n",
        "\n",
        "\n",
        "# ============================================================\n",
        "# E3: High-alignment regime (metric closer to welfare)\n",
        "# Fashion-MNIST, same skeleton as your E2 script\n",
        "#\n",
        "# Key change vs E2:\n",
        "# - Remove DP noise sweep\n",
        "# - Sweep lambda for mixed public metric:\n",
        "#     M_public(lambda) = (1-lambda)*M_head + lambda*M_tail\n",
        "# - Keep strategy mix / partitions / model / rounds identical\n",
        "# - Save CSV with richer metrics (similar to E2)\n",
        "# ============================================================\n",
        "\n",
        "# ----------------------------\n",
        "# 0) Config / Seed\n",
        "# ----------------------------\n",
        "@dataclass\n",
        "class Config:\n",
        "    \"\"\"All tunable settings for the E3 experiment, gathered in one place.\n",
        "\n",
        "    Field order is part of the dataclass constructor signature; add new fields at\n",
        "    the end (or pass everything by keyword) to stay compatible.\n",
        "    \"\"\"\n",
        "\n",
        "    # Keep small for Colab\n",
        "    n_clients: int = 12\n",
        "    n_rounds: int = 25\n",
        "    local_epochs: int = 1\n",
        "    batch_size: int = 64\n",
        "    lr: float = 0.01\n",
        "    momentum: float = 0.9\n",
        "\n",
        "    # Non-IID partition\n",
        "    dirichlet_alpha: float = 0.5\n",
        "\n",
        "    # Head/Tail split\n",
        "    head_classes: Tuple[int, ...] = (0, 1, 2, 3, 4)\n",
        "    tail_classes: Tuple[int, ...] = (5, 6, 7, 8, 9)\n",
        "\n",
        "    # Public leak (for GAMING_HEAD_ONLY_LEAK)\n",
        "    # Fraction of the public head validation indices exposed as the leak subset.\n",
        "    leak_fraction: float = 1.0\n",
        "\n",
        "    # Strategy mix (same as your E2)\n",
        "    benign_frac: float = 0.10\n",
        "    gaming_frac_main: float = 0.30  # E3's \"gaming\" condition\n",
        "    update_scale_factor: float = 1.75  # for update-scaling gaming\n",
        "\n",
        "    # E3: alignment knob(s)\n",
        "    # lambda=0 -> metric=head only, lambda=1 -> metric=tail only\n",
        "    metric_lambdas: Tuple[float, ...] = (0.00, 0.30, 0.60)\n",
        "\n",
        "    # Evaluation summary\n",
        "    tail_k: int = 7  # last K rounds to average\n",
        "\n",
        "    seed: int = 42\n",
        "    root: str = \"./data\"\n",
        "    # Evaluated once, when the class body runs: CUDA if available, else CPU.\n",
        "    device: torch.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "\n",
        "# Module-level settings instance read by the rest of this script.\n",
        "CFG = Config()\n",
        "\n",
        "\n",
        "def set_seed(seed: int):\n",
        "    \"\"\"Seed Python's `random`, NumPy's global RNG, and PyTorch (CPU plus all CUDA devices).\"\"\"\n",
        "    seeders = (\n",
        "        random.seed,\n",
        "        np.random.seed,\n",
        "        torch.manual_seed,\n",
        "        torch.cuda.manual_seed_all,\n",
        "    )\n",
        "    for apply_seed in seeders:\n",
        "        apply_seed(seed)\n",
        "\n",
        "\n",
        "# Seed the global RNGs once, before the data preparation code that follows.\n",
        "set_seed(CFG.seed)\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 1) Dataset preparation\n",
        "# ----------------------------\n",
        "# Convert images to tensors and normalize the single channel with mean 0.5 / std 0.5.\n",
        "transform = transforms.Compose([\n",
        "    transforms.ToTensor(),\n",
        "    transforms.Normalize((0.5,), (0.5,))\n",
        "])\n",
        "\n",
        "full_train = datasets.FashionMNIST(root=CFG.root, train=True, download=True, transform=transform)\n",
        "\n",
        "# We'll use (split by position, without shuffling):\n",
        "# - 50k for client-local data\n",
        "# - 10k as public validation candidates (from train)\n",
        "train_local_size = 50_000\n",
        "idx_all = np.arange(len(full_train))\n",
        "idx_local = idx_all[:train_local_size]\n",
        "idx_public_candidates = idx_all[train_local_size:]\n",
        "\n",
        "train_local_set = Subset(full_train, idx_local)\n",
        "targets_all = np.array(full_train.targets)\n",
        "\n",
        "# Build public validation sets (head + tail) from last 10k\n",
        "head_mask = np.isin(targets_all, list(CFG.head_classes))\n",
        "tail_mask = np.isin(targets_all, list(CFG.tail_classes))\n",
        "\n",
        "# Boolean-mask selection: vectorized, order-preserving, returns a fresh integer\n",
        "# array (so the in-place shuffles below never touch idx_all), and keeps an\n",
        "# integer dtype even if no candidate matches.\n",
        "public_head_indices = idx_public_candidates[head_mask[idx_public_candidates]]\n",
        "public_tail_indices = idx_public_candidates[tail_mask[idx_public_candidates]]\n",
        "\n",
        "# Shuffle order matters for reproducibility: head first, then tail, from one RNG.\n",
        "rng = np.random.default_rng(CFG.seed)\n",
        "rng.shuffle(public_head_indices)\n",
        "rng.shuffle(public_tail_indices)\n",
        "\n",
        "public_head_val_set = Subset(full_train, public_head_indices)\n",
        "public_tail_val_set = Subset(full_train, public_tail_indices)\n",
        "\n",
        "public_head_val_loader = DataLoader(public_head_val_set, batch_size=CFG.batch_size, shuffle=False)\n",
        "public_tail_val_loader = DataLoader(public_tail_val_set, batch_size=CFG.batch_size, shuffle=False)\n",
        "\n",
        "# Leak subset that gaming clients can append (head-only, as in your E2):\n",
        "# the first leak_fraction of the shuffled public head indices.\n",
        "leak_size = int(CFG.leak_fraction * len(public_head_indices))\n",
        "leak_indices = public_head_indices[:leak_size]\n",
        "leak_subset = Subset(full_train, leak_indices)\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 2) Utilities: partition + filtering\n",
        "# ----------------------------\n",
        "def make_dirichlet_partitions(labels: np.ndarray, n_clients: int, alpha: float, rng: np.random.Generator):\n",
        "    \"\"\"Returns list of index arrays per client (indices w.r.t. train_local_set).\"\"\"\n",
        "    n_classes = int(labels.max()) + 1\n",
        "    client_indices: List[List[int]] = [[] for _ in range(n_clients)]\n",
        "\n",
        "    class_indices = []\n",
        "    for k in range(n_classes):\n",
        "        idx_k = np.where(labels == k)[0]\n",
        "        rng.shuffle(idx_k)\n",
        "        class_indices.append(idx_k)\n",
        "\n",
        "    for k in range(n_classes):\n",
        "        idx_k = class_indices[k]\n",
        "        n_k = len(idx_k)\n",
        "        if n_k == 0:\n",
        "            continue\n",
        "        proportions = rng.dirichlet(alpha * np.ones(n_clients))\n",
        "        proportions = proportions / proportions.sum()\n",
        "        splits = (np.cumsum(proportions) * n_k).astype(int)\n",
        "        shards = np.split(idx_k, splits[:-1])\n",
        "        for cid, shard in enumerate(shards):\n",
        "            client_indices[cid].extend(shard.tolist())\n",
        "\n",
        "    for cid in range(n_clients):\n",
        "        rng.shuffle(client_indices[cid])\n",
        "\n",
        "    return client_indices\n",
        "\n",
        "\n",
        "def filter_subset_by_labels(base_subset: Subset, allowed_labels: Tuple[int, ...]) -> Subset:\n",
        "    \"\"\"base_subset is a Subset of train_local_set; keep only items with labels in allowed_labels.\"\"\"\n",
        "    allowed = set(allowed_labels)\n",
        "    kept = []\n",
        "    for idx_in_local in base_subset.indices:\n",
        "        global_idx = train_local_set.indices[idx_in_local]  # index into full_train\n",
        "        y = int(full_train.targets[global_idx])\n",
        "        if y in allowed:\n",
        "            kept.append(idx_in_local)\n",
        "    return Subset(train_local_set, kept)\n",
        "\n",
        "\n",
        "def split_subset_train_eval(base_subset: Subset, eval_ratio: float = 0.2) -> Tuple[Subset, Subset]:\n",
        "    idxs = list(base_subset.indices)\n",
        "    rng_local = np.random.default_rng(CFG.seed + 123)\n",
        "    rng_local.shuffle(idxs)\n",
        "    n_eval = int(round(eval_ratio * len(idxs)))\n",
        "    eval_idxs = idxs[:n_eval]\n",
        "    train_idxs = idxs[n_eval:]\n",
        "    return Subset(train_local_set, train_idxs), Subset(train_local_set, eval_idxs)\n",
        "\n",
        "\n",
        "def tail_only_subset(base_subset: Subset) -> Subset:\n",
        "    return filter_subset_by_labels(base_subset, CFG.tail_classes)\n",
        "\n",
        "\n",
        "def oversample_tail(train_subset: Subset, factor: int = 2) -> ConcatDataset:\n",
        "    \"\"\"Benign cooperation: upweight tail samples in local training.\"\"\"\n",
        "    tail_sub = tail_only_subset(train_subset)\n",
        "    if len(tail_sub) == 0:\n",
        "        return ConcatDataset([train_subset])\n",
        "    reps = [tail_sub for _ in range(factor)]\n",
        "    return ConcatDataset([train_subset] + reps)\n",
        "\n",
        "\n",
        "# Build base client datasets (Dirichlet)\n",
        "train_targets_local = np.array(full_train.targets)[idx_local]\n",
        "client_indices = make_dirichlet_partitions(\n",
        "    labels=train_targets_local,\n",
        "    n_clients=CFG.n_clients,\n",
        "    alpha=CFG.dirichlet_alpha,\n",
        "    rng=np.random.default_rng(CFG.seed)\n",
        ")\n",
        "\n",
        "client_base_subsets = [Subset(train_local_set, idxs) for idxs in client_indices]\n",
        "\n",
        "client_train_subsets: List[Subset] = []\n",
        "client_eval_subsets: List[Subset] = []\n",
        "client_tail_eval_subsets: List[Subset] = []\n",
        "client_tail_eval_sizes: List[int] = []\n",
        "\n",
        "for cid in range(CFG.n_clients):\n",
        "    tr, ev = split_subset_train_eval(client_base_subsets[cid], eval_ratio=0.2)\n",
        "    client_train_subsets.append(tr)\n",
        "    client_eval_subsets.append(ev)\n",
        "    tail_ev = tail_only_subset(ev)\n",
        "    client_tail_eval_subsets.append(tail_ev)\n",
        "    client_tail_eval_sizes.append(len(tail_ev))\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 3) Model\n",
        "# ----------------------------\n",
        "class SimpleCNN(nn.Module):\n",
        "    def __init__(self, num_classes: int = 10):\n",
        "        super().__init__()\n",
        "        self.features = nn.Sequential(\n",
        "            nn.Conv2d(1, 32, kernel_size=3, padding=1),\n",
        "            nn.ReLU(inplace=True),\n",
        "            nn.MaxPool2d(2),\n",
        "            nn.Conv2d(32, 64, kernel_size=3, padding=1),\n",
        "            nn.ReLU(inplace=True),\n",
        "            nn.MaxPool2d(2),\n",
        "        )\n",
        "        self.classifier = nn.Sequential(\n",
        "            nn.Flatten(),\n",
        "            nn.Linear(64 * 7 * 7, 128),\n",
        "            nn.ReLU(inplace=True),\n",
        "            nn.Linear(128, num_classes),\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.features(x)\n",
        "        x = self.classifier(x)\n",
        "        return x\n",
        "\n",
        "\n",
        "def get_model() -> nn.Module:\n",
        "    return SimpleCNN().to(CFG.device)\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 4) Train/Eval helpers\n",
        "# ----------------------------\n",
        "@torch.no_grad()\n",
        "def evaluate_acc(model: nn.Module, loader: DataLoader) -> float:\n",
        "    model.eval()\n",
        "    correct, total = 0, 0\n",
        "    for x, y in loader:\n",
        "        x, y = x.to(CFG.device), y.to(CFG.device)\n",
        "        logits = model(x)\n",
        "        pred = logits.argmax(dim=1)\n",
        "        correct += (pred == y).sum().item()\n",
        "        total += y.size(0)\n",
        "    return float(correct / total) if total > 0 else 0.0\n",
        "\n",
        "\n",
        "def local_train(model: nn.Module, dataset, epochs: int) -> nn.Module:\n",
        "    loader = DataLoader(dataset, batch_size=CFG.batch_size, shuffle=True)\n",
        "    opt = optim.SGD(model.parameters(), lr=CFG.lr, momentum=CFG.momentum)\n",
        "    crit = nn.CrossEntropyLoss()\n",
        "\n",
        "    model.train()\n",
        "    for _ in range(epochs):\n",
        "        for x, y in loader:\n",
        "            x, y = x.to(CFG.device), y.to(CFG.device)\n",
        "            opt.zero_grad()\n",
        "            loss = crit(model(x), y)\n",
        "            loss.backward()\n",
        "            opt.step()\n",
        "    return model\n",
        "\n",
        "\n",
        "def fedavg_from_states(states: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:\n",
        "    avg = {}\n",
        "    for k in states[0].keys():\n",
        "        stacked = torch.stack([sd[k] for sd in states], dim=0)\n",
        "        avg[k] = stacked.mean(dim=0)\n",
        "    return avg\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 5) Strategy definitions (same as E2)\n",
        "# ----------------------------\n",
        "STR_HONEST = \"HONEST\"\n",
        "STR_BENIGN = \"BENIGN_TAIL_UPWEIGHT\"\n",
        "STR_GAME_HEAD = \"GAMING_HEAD_ONLY_LEAK\"\n",
        "STR_GAME_SCALE = \"GAMING_UPDATE_SCALING\"\n",
        "\n",
        "\n",
        "def assign_strategies(n_clients: int,\n",
        "                      gaming_frac: float,\n",
        "                      benign_frac: float,\n",
        "                      rng: np.random.Generator) -> Dict[int, str]:\n",
        "    \"\"\"\n",
        "    Assign each client a strategy from {HONEST, BENIGN, GAMING_HEAD, GAMING_SCALE}.\n",
        "      - benign_frac portion get BENIGN\n",
        "      - gaming_frac portion get a GAMING strategy (half head-only, half update-scaling)\n",
        "      - rest are HONEST\n",
        "    \"\"\"\n",
        "    all_ids = np.arange(n_clients)\n",
        "    rng.shuffle(all_ids)\n",
        "\n",
        "    n_benign = int(round(benign_frac * n_clients))\n",
        "    n_gaming = int(round(gaming_frac * n_clients))\n",
        "\n",
        "    benign_ids = set(all_ids[:n_benign])\n",
        "    gaming_ids = list(all_ids[n_benign:n_benign + n_gaming])\n",
        "    honest_ids = set(all_ids[n_benign + n_gaming:])\n",
        "\n",
        "    strat = {}\n",
        "    for cid in benign_ids:\n",
        "        strat[cid] = STR_BENIGN\n",
        "    for cid in honest_ids:\n",
        "        strat[cid] = STR_HONEST\n",
        "\n",
        "    half = len(gaming_ids) // 2\n",
        "    for cid in gaming_ids[:half]:\n",
        "        strat[cid] = STR_GAME_HEAD\n",
        "    for cid in gaming_ids[half:]:\n",
        "        strat[cid] = STR_GAME_SCALE\n",
        "\n",
        "    return strat\n",
        "\n",
        "\n",
        "def build_client_train_dataset(cid: int, strategy: str):\n",
        "    base_train = client_train_subsets[cid]\n",
        "\n",
        "    if strategy == STR_HONEST:\n",
        "        return base_train\n",
        "\n",
        "    if strategy == STR_BENIGN:\n",
        "        return oversample_tail(base_train, factor=2)\n",
        "\n",
        "    if strategy == STR_GAME_HEAD:\n",
        "        head_only = filter_subset_by_labels(base_train, CFG.head_classes)\n",
        "        return ConcatDataset([head_only, leak_subset])\n",
        "\n",
        "    if strategy == STR_GAME_SCALE:\n",
        "        return base_train\n",
        "\n",
        "    raise ValueError(f\"Unknown strategy: {strategy}\")\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 6) Transmitted state (no DP in E3; keep update scaling for gaming)\n",
        "# ----------------------------\n",
        "def make_transmitted_state_no_dp(global_sd: Dict[str, torch.Tensor],\n",
        "                                 local_sd: Dict[str, torch.Tensor],\n",
        "                                 scale: float) -> Dict[str, torch.Tensor]:\n",
        "    \"\"\"\n",
        "    transmitted = global + scale * (local - global)\n",
        "    \"\"\"\n",
        "    tx = {}\n",
        "    for k in global_sd.keys():\n",
        "        tx[k] = global_sd[k] + scale * (local_sd[k] - global_sd[k])\n",
        "    return tx\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 7) Run FL once (given gaming_frac and lambda)\n",
        "#    Per-round:\n",
        "#    - M_head_t: public head accuracy\n",
        "#    - M_tail_t: public tail accuracy\n",
        "#    - M_pub_t : mixed metric (lambda)\n",
        "#    - W_t     : mean tail accuracy across clients (welfare)\n",
        "#    plus distribution stats across clients on tail eval\n",
        "# ----------------------------\n",
        "def run_once(gaming_frac: float, benign_frac: float, metric_lambda: float, seed_offset: int = 0):\n",
        "    rng = np.random.default_rng(CFG.seed + seed_offset)\n",
        "    strat_map = assign_strategies(CFG.n_clients, gaming_frac, benign_frac, rng)\n",
        "\n",
        "    client_train_ds = [build_client_train_dataset(cid, strat_map[cid]) for cid in range(CFG.n_clients)]\n",
        "\n",
        "    global_model = get_model()\n",
        "    global_sd = {k: v.detach().clone() for k, v in global_model.state_dict().items()}\n",
        "\n",
        "    M_head_hist: List[float] = []\n",
        "    M_tail_hist: List[float] = []\n",
        "    M_pub_hist: List[float] = []\n",
        "\n",
        "    W_hist: List[float] = []\n",
        "    W_std_hist: List[float] = []\n",
        "    W_min_hist: List[float] = []\n",
        "    W_max_hist: List[float] = []\n",
        "    tail_nonempty_hist: List[int] = []\n",
        "\n",
        "    n_honest = sum(1 for s in strat_map.values() if s == STR_HONEST)\n",
        "    n_benign = sum(1 for s in strat_map.values() if s == STR_BENIGN)\n",
        "    n_ghead = sum(1 for s in strat_map.values() if s == STR_GAME_HEAD)\n",
        "    n_gscale = sum(1 for s in strat_map.values() if s == STR_GAME_SCALE)\n",
        "\n",
        "    for rnd in range(CFG.n_rounds):\n",
        "        transmitted_states: List[Dict[str, torch.Tensor]] = []\n",
        "\n",
        "        for cid in range(CFG.n_clients):\n",
        "            local_model = get_model()\n",
        "            local_model.load_state_dict(global_sd)\n",
        "\n",
        "            local_model = local_train(local_model, client_train_ds[cid], epochs=CFG.local_epochs)\n",
        "            local_sd = local_model.state_dict()\n",
        "\n",
        "            # scale for UPDATE_SCALING gaming; others scale=1\n",
        "            if strat_map[cid] == STR_GAME_SCALE and gaming_frac > 0:\n",
        "                scale = CFG.update_scale_factor\n",
        "            else:\n",
        "                scale = 1.0\n",
        "\n",
        "            tx_sd = make_transmitted_state_no_dp(global_sd=global_sd, local_sd=local_sd, scale=scale)\n",
        "            transmitted_states.append(tx_sd)\n",
        "\n",
        "        # FedAvg update\n",
        "        global_sd = fedavg_from_states(transmitted_states)\n",
        "        global_model.load_state_dict(global_sd)\n",
        "\n",
        "        # Public accuracies\n",
        "        M_head = evaluate_acc(global_model, public_head_val_loader)\n",
        "        M_tail = evaluate_acc(global_model, public_tail_val_loader)\n",
        "        M_pub = (1.0 - metric_lambda) * M_head + metric_lambda * M_tail\n",
        "\n",
        "        M_head_hist.append(M_head)\n",
        "        M_tail_hist.append(M_tail)\n",
        "        M_pub_hist.append(M_pub)\n",
        "\n",
        "        # Welfare: tail accuracy distribution across clients (experimenter-side evaluation)\n",
        "        tail_accs = []\n",
        "        nonempty = 0\n",
        "        for cid in range(CFG.n_clients):\n",
        "            tail_ev = client_tail_eval_subsets[cid]\n",
        "            if len(tail_ev) == 0:\n",
        "                continue\n",
        "            nonempty += 1\n",
        "            loader = DataLoader(tail_ev, batch_size=CFG.batch_size, shuffle=False)\n",
        "            tail_accs.append(evaluate_acc(global_model, loader))\n",
        "\n",
        "        tail_nonempty_hist.append(nonempty)\n",
        "\n",
        "        if len(tail_accs) > 0:\n",
        "            W_t = float(np.mean(tail_accs))\n",
        "            W_std_t = float(np.std(tail_accs))\n",
        "            W_min_t = float(np.min(tail_accs))\n",
        "            W_max_t = float(np.max(tail_accs))\n",
        "        else:\n",
        "            W_t, W_std_t, W_min_t, W_max_t = 0.0, float(\"nan\"), float(\"nan\"), float(\"nan\")\n",
        "\n",
        "        W_hist.append(W_t)\n",
        "        W_std_hist.append(W_std_t)\n",
        "        W_min_hist.append(W_min_t)\n",
        "        W_max_hist.append(W_max_t)\n",
        "\n",
        "        if (rnd + 1) % 10 == 0 or rnd == 0:\n",
        "            print(f\"[gf={gaming_frac:.2f} | lam={metric_lambda:.2f}] \"\n",
        "                  f\"Round {rnd+1:>2}/{CFG.n_rounds} | \"\n",
        "                  f\"M_head={M_head:.4f} | M_tail={M_tail:.4f} | M_pub={M_pub:.4f} | \"\n",
        "                  f\"W_tail(mean)={W_t:.4f} | W_std={W_std_t:.4f} | n_tail_clients={nonempty}\")\n",
        "\n",
        "    return {\n",
        "        \"gaming_frac\": gaming_frac,\n",
        "        \"benign_frac\": benign_frac,\n",
        "        \"metric_lambda\": metric_lambda,\n",
        "        \"strategy_map\": strat_map,\n",
        "        \"strategy_counts\": {\n",
        "            \"honest\": n_honest,\n",
        "            \"benign\": n_benign,\n",
        "            \"gaming_head\": n_ghead,\n",
        "            \"gaming_scale\": n_gscale,\n",
        "        },\n",
        "        \"M_head_hist\": np.array(M_head_hist, dtype=np.float32),\n",
        "        \"M_tail_hist\": np.array(M_tail_hist, dtype=np.float32),\n",
        "        \"M_pub_hist\": np.array(M_pub_hist, dtype=np.float32),\n",
        "        \"W_hist\": np.array(W_hist, dtype=np.float32),\n",
        "        \"W_std_hist\": np.array(W_std_hist, dtype=np.float32),\n",
        "        \"W_min_hist\": np.array(W_min_hist, dtype=np.float32),\n",
        "        \"W_max_hist\": np.array(W_max_hist, dtype=np.float32),\n",
        "        \"tail_nonempty_hist\": np.array(tail_nonempty_hist, dtype=np.int32),\n",
        "    }\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 8) Summary metrics (same style as E2)\n",
        "# ----------------------------\n",
        "def tail_slice(x: np.ndarray, k: int) -> np.ndarray:\n",
        "    if len(x) == 0:\n",
        "        return x\n",
        "    if len(x) <= k:\n",
        "        return x\n",
        "    return x[-k:]\n",
        "\n",
        "\n",
        "def mean_last_k(x: np.ndarray, k: int) -> float:\n",
        "    xs = tail_slice(x, k)\n",
        "    return float(np.mean(xs)) if len(xs) > 0 else float(\"nan\")\n",
        "\n",
        "\n",
        "def std_last_k(x: np.ndarray, k: int) -> float:\n",
        "    xs = tail_slice(x, k)\n",
        "    return float(np.std(xs)) if len(xs) > 0 else float(\"nan\")\n",
        "\n",
        "\n",
        "def min_last_k(x: np.ndarray, k: int) -> float:\n",
        "    xs = tail_slice(x, k)\n",
        "    return float(np.min(xs)) if len(xs) > 0 else float(\"nan\")\n",
        "\n",
        "\n",
        "def max_last_k(x: np.ndarray, k: int) -> float:\n",
        "    xs = tail_slice(x, k)\n",
        "    return float(np.max(xs)) if len(xs) > 0 else float(\"nan\")\n",
        "\n",
        "\n",
        "def safe_div(a: float, b: float, eps: float = 1e-12) -> float:\n",
        "    return float(a / (b + eps))\n",
        "\n",
        "\n",
        "# \"PoG-style\" summary helpers.\n",
        "# Since your paper's exact PoG definition may differ, we compute two common views:\n",
        "# (1) Global reference vs aligned @ lambda=0 (kept constant across lambdas)\n",
        "# (2) Paired at same lambda: (W_aligned - W_gaming) / W_aligned  (good for E3 story)\n",
        "def pog_vs_ref(W_ref: float, W_final: float) -> float:\n",
        "    if not np.isfinite(W_ref) or W_ref <= 1e-8:\n",
        "        return float(\"nan\")\n",
        "    return float((W_ref - W_final) / W_ref)\n",
        "\n",
        "\n",
        "def pog_paired_same_lambda(W_aligned: float, W_gaming: float) -> float:\n",
        "    if not np.isfinite(W_aligned) or W_aligned <= 1e-8:\n",
        "        return float(\"nan\")\n",
        "    return float((W_aligned - W_gaming) / W_aligned)\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 9) Main: Sweep lambdas, compare aligned vs gaming\n",
        "# ----------------------------\n",
        "def main():\n",
        "    print(\"\\n=== E3: High-alignment regime (Fashion-MNIST) ===\")\n",
        "    print(f\"Device: {CFG.device}\")\n",
        "    print(f\"Clients={CFG.n_clients}, Rounds={CFG.n_rounds}, LocalEpochs={CFG.local_epochs}\")\n",
        "    print(f\"dirichlet_alpha={CFG.dirichlet_alpha}\")\n",
        "    print(f\"benign_frac={CFG.benign_frac}, gaming_frac_main={CFG.gaming_frac_main}\")\n",
        "    print(f\"metric_lambdas={CFG.metric_lambdas}\")\n",
        "    print(f\"Head classes={CFG.head_classes} | Tail classes={CFG.tail_classes}\\n\")\n",
        "\n",
        "    # 9.1 Global reference baseline (aligned, lambda=0.0) to keep a fixed anchor if desired\n",
        "    print(\">> Running global reference baseline: ALIGNED, lambda=0.0\")\n",
        "    ref = run_once(gaming_frac=0.0, benign_frac=CFG.benign_frac, metric_lambda=0.0, seed_offset=0)\n",
        "    Mpub_ref = mean_last_k(ref[\"M_pub_hist\"], CFG.tail_k)\n",
        "    W_ref = mean_last_k(ref[\"W_hist\"], CFG.tail_k)\n",
        "\n",
        "    print(\"\\n[Global Reference] mean over last K rounds\")\n",
        "    print(f\"  M_pub_ref (lambda=0.0) = {Mpub_ref:.4f}\")\n",
        "    print(f\"  W_ref (tail welfare)   = {W_ref:.4f}\\n\")\n",
        "\n",
        "    results = []\n",
        "    seed_offset = 10\n",
        "\n",
        "    for lam in CFG.metric_lambdas:\n",
        "        print(f\"\\n==============================\")\n",
        "        print(f\" Lambda = {lam:.2f}\")\n",
        "        print(f\"==============================\")\n",
        "\n",
        "        # ALIGNED at this lambda\n",
        "        print(f\">> Running ALIGNED @ lambda={lam:.2f} (gf=0.0)\")\n",
        "        aligned = run_once(gaming_frac=0.0, benign_frac=CFG.benign_frac, metric_lambda=lam, seed_offset=seed_offset)\n",
        "        seed_offset += 1\n",
        "\n",
        "        # GAMING at this lambda\n",
        "        print(f\"\\n>> Running GAMING  @ lambda={lam:.2f} (gf={CFG.gaming_frac_main:.2f})\")\n",
        "        gaming = run_once(gaming_frac=CFG.gaming_frac_main, benign_frac=CFG.benign_frac, metric_lambda=lam, seed_offset=seed_offset)\n",
        "        seed_offset += 1\n",
        "\n",
        "        def summarize(tag: str, out: Dict):\n",
        "            M_head_final = mean_last_k(out[\"M_head_hist\"], CFG.tail_k)\n",
        "            M_tail_final = mean_last_k(out[\"M_tail_hist\"], CFG.tail_k)\n",
        "            M_pub_final = mean_last_k(out[\"M_pub_hist\"], CFG.tail_k)\n",
        "            W_final = mean_last_k(out[\"W_hist\"], CFG.tail_k)\n",
        "\n",
        "            # gaps\n",
        "            gap_pub_minus_W = M_pub_final - W_final\n",
        "            gap_head_minus_tail = M_head_final - M_tail_final\n",
        "            align_error_abs = abs(gap_pub_minus_W)\n",
        "\n",
        "            # global reference PoG-style (anchor at aligned lambda=0.0 reference)\n",
        "            pog_global = pog_vs_ref(W_ref=W_ref, W_final=W_final)\n",
        "\n",
        "            # stability\n",
        "            Mpub_std = std_last_k(out[\"M_pub_hist\"], CFG.tail_k)\n",
        "            W_std = std_last_k(out[\"W_hist\"], CFG.tail_k)\n",
        "\n",
        "            Mpub_range = max_last_k(out[\"M_pub_hist\"], CFG.tail_k) - min_last_k(out[\"M_pub_hist\"], CFG.tail_k)\n",
        "            W_range = max_last_k(out[\"W_hist\"], CFG.tail_k) - min_last_k(out[\"W_hist\"], CFG.tail_k)\n",
        "\n",
        "            # across-client tail distribution stats (already per-round)\n",
        "            W_std_across_clients = mean_last_k(out[\"W_std_hist\"], CFG.tail_k)\n",
        "            W_min = mean_last_k(out[\"W_min_hist\"], CFG.tail_k)\n",
        "            W_max = mean_last_k(out[\"W_max_hist\"], CFG.tail_k)\n",
        "\n",
        "            # normalized ratios\n",
        "            gap_pub_over_W = safe_div(gap_pub_minus_W, max(W_final, 0.0))\n",
        "            gap_headtail_over_tail = safe_div(gap_head_minus_tail, max(M_tail_final, 0.0))\n",
        "\n",
        "            # tail coverage\n",
        "            tail_nonempty = int(np.round(mean_last_k(out[\"tail_nonempty_hist\"].astype(np.float32), CFG.tail_k)))\n",
        "\n",
        "            sc = out[\"strategy_counts\"]\n",
        "\n",
        "            return {\n",
        "                \"condition\": tag,\n",
        "                \"lambda\": float(out[\"metric_lambda\"]),\n",
        "                \"gaming_frac\": float(out[\"gaming_frac\"]),\n",
        "                \"benign_frac\": float(out[\"benign_frac\"]),\n",
        "                \"n_clients\": int(CFG.n_clients),\n",
        "\n",
        "                \"n_honest\": int(sc[\"honest\"]),\n",
        "                \"n_benign\": int(sc[\"benign\"]),\n",
        "                \"n_gaming_head\": int(sc[\"gaming_head\"]),\n",
        "                \"n_gaming_scale\": int(sc[\"gaming_scale\"]),\n",
        "\n",
        "                # Primary: public metrics + welfare\n",
        "                \"M_head_final\": float(M_head_final),\n",
        "                \"M_tail_final\": float(M_tail_final),\n",
        "                \"M_pub_final\": float(M_pub_final),\n",
        "                \"W_tail_final\": float(W_final),\n",
        "\n",
        "                # Alignment / gap diagnostics\n",
        "                \"gap_Mpub_minus_W\": float(gap_pub_minus_W),\n",
        "                \"abs_align_error_|Mpub-W|\": float(align_error_abs),\n",
        "                \"gap_Mhead_minus_Mtail\": float(gap_head_minus_tail),\n",
        "\n",
        "                \"gap_pub_over_W\": float(gap_pub_over_W),\n",
        "                \"gap_headtail_over_Mtail\": float(gap_headtail_over_tail),\n",
        "\n",
        "                # PoG-style (global anchor; keep constant anchor)\n",
        "                \"PoG_vs_global_refW\": float(pog_global),\n",
        "\n",
        "                # Temporal stability (last K)\n",
        "                \"Mpub_std_lastK\": float(Mpub_std),\n",
        "                \"W_std_lastK\": float(W_std),\n",
        "                \"Mpub_range_lastK\": float(Mpub_range),\n",
        "                \"W_range_lastK\": float(W_range),\n",
        "\n",
        "                # Across-client tail distribution stats (last K averaged)\n",
        "                \"W_std_across_clients_lastK\": float(W_std_across_clients),\n",
        "                \"W_min_across_clients_lastK\": float(W_min),\n",
        "                \"W_max_across_clients_lastK\": float(W_max),\n",
        "\n",
        "                # Tail eval coverage\n",
        "                \"tail_nonempty_clients_lastK\": int(tail_nonempty),\n",
        "            }\n",
        "\n",
        "        row_aligned = summarize(\"ALIGNED\", aligned)\n",
        "        row_gaming = summarize(\"GAMING\", gaming)\n",
        "\n",
        "        # Paired, same-lambda effects (this is the clean E3 story)\n",
        "        paired_deltaW = row_aligned[\"W_tail_final\"] - row_gaming[\"W_tail_final\"]\n",
        "        paired_deltaMpub = row_gaming[\"M_pub_final\"] - row_aligned[\"M_pub_final\"]\n",
        "        paired_deltaGap = row_gaming[\"gap_Mpub_minus_W\"] - row_aligned[\"gap_Mpub_minus_W\"]\n",
        "        paired_pog = pog_paired_same_lambda(row_aligned[\"W_tail_final\"], row_gaming[\"W_tail_final\"])\n",
        "\n",
        "        # Store paired columns on both rows\n",
        "        for r in (row_aligned, row_gaming):\n",
        "            r[\"paired_deltaW_aligned_minus_gaming\"] = float(paired_deltaW)\n",
        "            r[\"paired_deltaMpub_gaming_minus_aligned\"] = float(paired_deltaMpub)\n",
        "            r[\"paired_deltaGap_gaming_minus_aligned\"] = float(paired_deltaGap)\n",
        "            r[\"paired_PoG_same_lambda\"] = float(paired_pog)\n",
        "\n",
        "        results.append(row_aligned)\n",
        "        results.append(row_gaming)\n",
        "\n",
        "        # quick print\n",
        "        print(f\"\\n[Summary @ lambda={lam:.2f}] (last K mean)\")\n",
        "        print(f\"  ALIGNED: Mpub={row_aligned['M_pub_final']:.4f} | W={row_aligned['W_tail_final']:.4f} | \"\n",
        "              f\"gap={row_aligned['gap_Mpub_minus_W']:.4f}\")\n",
        "        print(f\"  GAMING : Mpub={row_gaming['M_pub_final']:.4f} | W={row_gaming['W_tail_final']:.4f} | \"\n",
        "              f\"gap={row_gaming['gap_Mpub_minus_W']:.4f}\")\n",
        "        print(f\"  Paired : ΔW(aligned-gaming)={paired_deltaW:+.4f} | \"\n",
        "              f\"Δgap(gaming-aligned)={paired_deltaGap:+.4f} | paired_PoG={paired_pog:+.4f}\")\n",
        "\n",
        "    # 9.2 Save CSV\n",
        "    out_path = \"E3_high_alignment_metric_sweep_results.csv\"\n",
        "    import csv\n",
        "    with open(out_path, \"w\", newline=\"\", encoding=\"utf-8\") as f:\n",
        "        fieldnames = list(results[0].keys())\n",
        "        writer = csv.DictWriter(f, fieldnames=fieldnames)\n",
        "        writer.writeheader()\n",
        "        writer.writerows(results)\n",
        "\n",
        "    print(f\"\\nSaved: {out_path}\")\n",
        "    print(\"Done.\")\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "yjozde2g4FIw",
        "outputId": "cec89563-e40d-477a-92fb-0132e9db52ed"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "=== E3: High-alignment regime (Fashion-MNIST) ===\n",
            "Device: cuda\n",
            "Clients=12, Rounds=25, LocalEpochs=1\n",
            "dirichlet_alpha=0.5\n",
            "benign_frac=0.1, gaming_frac_main=0.3\n",
            "metric_lambdas=(0.0, 0.3, 0.6)\n",
            "Head classes=(0, 1, 2, 3, 4) | Tail classes=(5, 6, 7, 8, 9)\n",
            "\n",
            ">> Running global reference baseline: ALIGNED, lambda=0.0\n",
            "[gf=0.00 | lam=0.00] Round  1/25 | M_head=0.6344 | M_tail=0.5729 | M_pub=0.6344 | W_tail(mean)=0.5971 | W_std=0.1427 | n_tail_clients=12\n",
            "[gf=0.00 | lam=0.00] Round 10/25 | M_head=0.8485 | M_tail=0.8365 | M_pub=0.8485 | W_tail(mean)=0.8358 | W_std=0.0664 | n_tail_clients=12\n",
            "[gf=0.00 | lam=0.00] Round 20/25 | M_head=0.8682 | M_tail=0.8676 | M_pub=0.8682 | W_tail(mean)=0.8692 | W_std=0.0552 | n_tail_clients=12\n",
            "\n",
            "[Global Reference] mean over last K rounds\n",
            "  M_pub_ref (lambda=0.0) = 0.8648\n",
            "  W_ref (tail welfare)   = 0.8791\n",
            "\n",
            "\n",
            "==============================\n",
            " Lambda = 0.00\n",
            "==============================\n",
            ">> Running ALIGNED @ lambda=0.00 (gf=0.0)\n",
            "[gf=0.00 | lam=0.00] Round  1/25 | M_head=0.7277 | M_tail=0.6208 | M_pub=0.7277 | W_tail(mean)=0.6281 | W_std=0.1370 | n_tail_clients=12\n",
            "[gf=0.00 | lam=0.00] Round 10/25 | M_head=0.8501 | M_tail=0.8407 | M_pub=0.8501 | W_tail(mean)=0.8365 | W_std=0.0682 | n_tail_clients=12\n",
            "[gf=0.00 | lam=0.00] Round 20/25 | M_head=0.8774 | M_tail=0.8637 | M_pub=0.8774 | W_tail(mean)=0.8732 | W_std=0.0615 | n_tail_clients=12\n",
            "\n",
            ">> Running GAMING  @ lambda=0.00 (gf=0.30)\n",
            "[gf=0.30 | lam=0.00] Round  1/25 | M_head=0.5982 | M_tail=0.3868 | M_pub=0.5982 | W_tail(mean)=0.4519 | W_std=0.2703 | n_tail_clients=12\n",
            "[gf=0.30 | lam=0.00] Round 10/25 | M_head=0.8876 | M_tail=0.7666 | M_pub=0.8876 | W_tail(mean)=0.7756 | W_std=0.1150 | n_tail_clients=12\n",
            "[gf=0.30 | lam=0.00] Round 20/25 | M_head=0.9053 | M_tail=0.8132 | M_pub=0.9053 | W_tail(mean)=0.8226 | W_std=0.0923 | n_tail_clients=12\n",
            "\n",
            "[Summary @ lambda=0.00] (last K mean)\n",
            "  ALIGNED: Mpub=0.8633 | W=0.8917 | gap=-0.0284\n",
            "  GAMING : Mpub=0.9101 | W=0.8396 | gap=0.0705\n",
            "  Paired : ΔW(aligned-gaming)=+0.0521 | Δgap(gaming-aligned)=+0.0989 | paired_PoG=+0.0585\n",
            "\n",
            "==============================\n",
            " Lambda = 0.30\n",
            "==============================\n",
            ">> Running ALIGNED @ lambda=0.30 (gf=0.0)\n",
            "[gf=0.00 | lam=0.30] Round  1/25 | M_head=0.6141 | M_tail=0.6071 | M_pub=0.6120 | W_tail(mean)=0.6249 | W_std=0.1402 | n_tail_clients=12\n",
            "[gf=0.00 | lam=0.30] Round 10/25 | M_head=0.8466 | M_tail=0.8532 | M_pub=0.8485 | W_tail(mean)=0.8623 | W_std=0.0598 | n_tail_clients=12\n",
            "[gf=0.00 | lam=0.30] Round 20/25 | M_head=0.8517 | M_tail=0.8878 | M_pub=0.8625 | W_tail(mean)=0.8909 | W_std=0.0467 | n_tail_clients=12\n",
            "\n",
            ">> Running GAMING  @ lambda=0.30 (gf=0.30)\n",
            "[gf=0.30 | lam=0.30] Round  1/25 | M_head=0.7318 | M_tail=0.4379 | M_pub=0.6436 | W_tail(mean)=0.4894 | W_std=0.1064 | n_tail_clients=12\n",
            "[gf=0.30 | lam=0.30] Round 10/25 | M_head=0.8859 | M_tail=0.7976 | M_pub=0.8594 | W_tail(mean)=0.8070 | W_std=0.0784 | n_tail_clients=12\n",
            "[gf=0.30 | lam=0.30] Round 20/25 | M_head=0.9083 | M_tail=0.8377 | M_pub=0.8871 | W_tail(mean)=0.8415 | W_std=0.0622 | n_tail_clients=12\n",
            "\n",
            "[Summary @ lambda=0.30] (last K mean)\n",
            "  ALIGNED: Mpub=0.8730 | W=0.8863 | gap=-0.0133\n",
            "  GAMING : Mpub=0.8913 | W=0.8470 | gap=0.0442\n",
            "  Paired : ΔW(aligned-gaming)=+0.0393 | Δgap(gaming-aligned)=+0.0575 | paired_PoG=+0.0443\n",
            "\n",
            "==============================\n",
            " Lambda = 0.60\n",
            "==============================\n",
            ">> Running ALIGNED @ lambda=0.60 (gf=0.0)\n",
            "[gf=0.00 | lam=0.60] Round  1/25 | M_head=0.6253 | M_tail=0.5159 | M_pub=0.5597 | W_tail(mean)=0.5612 | W_std=0.1580 | n_tail_clients=12\n",
            "[gf=0.00 | lam=0.60] Round 10/25 | M_head=0.8495 | M_tail=0.8389 | M_pub=0.8431 | W_tail(mean)=0.8476 | W_std=0.0606 | n_tail_clients=12\n",
            "[gf=0.00 | lam=0.60] Round 20/25 | M_head=0.8731 | M_tail=0.8719 | M_pub=0.8724 | W_tail(mean)=0.8706 | W_std=0.0546 | n_tail_clients=12\n",
            "\n",
            ">> Running GAMING  @ lambda=0.60 (gf=0.30)\n",
            "[gf=0.30 | lam=0.60] Round  1/25 | M_head=0.5194 | M_tail=0.3580 | M_pub=0.4226 | W_tail(mean)=0.3994 | W_std=0.0913 | n_tail_clients=12\n",
            "[gf=0.30 | lam=0.60] Round 10/25 | M_head=0.8890 | M_tail=0.7866 | M_pub=0.8275 | W_tail(mean)=0.7872 | W_std=0.0837 | n_tail_clients=12\n",
            "[gf=0.30 | lam=0.60] Round 20/25 | M_head=0.9055 | M_tail=0.8456 | M_pub=0.8696 | W_tail(mean)=0.8438 | W_std=0.0564 | n_tail_clients=12\n",
            "\n",
            "[Summary @ lambda=0.60] (last K mean)\n",
            "  ALIGNED: Mpub=0.8755 | W=0.8800 | gap=-0.0045\n",
            "  GAMING : Mpub=0.8740 | W=0.8511 | gap=0.0229\n",
            "  Paired : ΔW(aligned-gaming)=+0.0289 | Δgap(gaming-aligned)=+0.0274 | paired_PoG=+0.0329\n",
            "\n",
            "Saved: E3_high_alignment_metric_sweep_results.csv\n",
            "Done.\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Modern attack–defense replication"
      ],
      "metadata": {
        "id": "YakPqpgVX6en"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Install the dataset dependencies into the environment of the running kernel.\n",
        "%pip install -q \"flwr-datasets[vision]\" datasets"
      ],
      "metadata": {
        "id": "imBnJSwq-Y_f"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# ============================================================\n",
        "# E4 FULL (No LEAF): FEMNIST via HuggingFace/Flower Datasets\n",
        "#   - Attacks: PoisonedFL (multi-round consistency + proximal + dyn magnitude),\n",
        "#              Backdoor/Model Replacement\n",
        "#   - Defenses: FedCC (linear CKA filtering), Attack-Adaptive Aggregation\n",
        "#   - FL: FedAvg skeleton (with defense hooks)\n",
        "#   - Output: CSV + NPZ histories\n",
        "# ============================================================\n",
        "\n",
        "# If running on Colab, install once:\n",
        "# !pip -q install flwr-datasets[vision] datasets\n",
        "\n",
        "import os\n",
        "import csv\n",
        "import json\n",
        "import random\n",
        "from dataclasses import dataclass\n",
        "from typing import Dict, List, Tuple, Optional\n",
        "\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "from torch.utils.data import DataLoader, TensorDataset, Dataset\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 0) Config / Seed\n",
        "# ----------------------------\n",
        "from dataclasses import dataclass\n",
        "from typing import Tuple\n",
        "import torch\n",
        "import os\n",
        "\n",
        "@dataclass\n",
        "class Config:\n",
        "    \"\"\"Hyper-parameters, attack/defense knobs and output locations for the E4 FEMNIST runs.\"\"\"\n",
        "\n",
        "    # compute budget (Colab-friendly)\n",
        "    n_clients: int = 32            # CHANGED (24 -> 32)\n",
        "    n_rounds: int = 100            # CHANGED (60 -> 100)\n",
        "    local_epochs: int = 1\n",
        "    batch_size: int = 64\n",
        "    lr: float = 0.02\n",
        "    momentum: float = 0.9\n",
        "    weight_decay: float = 0.0\n",
        "\n",
        "    # participation\n",
        "    clients_per_round: int = 12    # CHANGED (8 -> 12)\n",
        "\n",
        "    # FEMNIST classes: typically 62 classes (digits + letters)\n",
        "    # We'll define \"head\" as digits, \"tail\" as letters for head/tail gap.\n",
        "    head_classes: Tuple[int, ...] = tuple(range(0, 10))       # digits\n",
        "    tail_classes: Tuple[int, ...] = tuple(range(10, 62))      # letters\n",
        "\n",
        "    # public sets (sampled from client train data)\n",
        "    public_support_size: int = 128       # (used for CKA) kept unchanged\n",
        "    public_head_eval_size: int = 4096    # CHANGED (2048 -> 4096)\n",
        "    public_tail_eval_size: int = 4096    # CHANGED (2048 -> 4096)\n",
        "\n",
        "    # malicious population + attack schedule\n",
        "    malicious_frac: float = 0.30         # CHANGED (0.20 -> 0.30)\n",
        "    attack_start_round: int = 5          # CHANGED (8 -> 5)\n",
        "\n",
        "    # PoisonedFL-ish knobs (more faithful but still light)\n",
        "    poisonedfl_scale: float = 2.2\n",
        "    poisonedfl_scale_min: float = 1.0\n",
        "    poisonedfl_scale_max: float = 4.0\n",
        "    poisonedfl_dyn_eta: float = 0.8\n",
        "    poisonedfl_beta_consistency: float = 1e-3\n",
        "    poisonedfl_mu_global: float = 5e-4\n",
        "    poisonedfl_tail_label_flip: bool = True\n",
        "\n",
        "    # Backdoor / Model Replacement knobs\n",
        "    backdoor_target_label: int = 0\n",
        "    backdoor_poison_frac: float = 0.30\n",
        "    model_replacement_gamma: float = 6.0\n",
        "    trigger_size: int = 3\n",
        "    trigger_value: float = 1.0   # in normalized space, clamped to [-1, 1]\n",
        "\n",
        "    # FedCC (CKA) filtering knobs\n",
        "    fedcc_reject_frac: float = 0.25\n",
        "    fedcc_min_keep: int = 3\n",
        "\n",
        "    # Attack-adaptive aggregation knobs\n",
        "    adaagg_alpha: float = 3.0\n",
        "\n",
        "    # summary / alarm\n",
        "    tail_k: int = 10\n",
        "    warmup_ignore: int = 5\n",
        "    alarm_threshold_std: float = 2.0     # CHANGED (3.0 -> 2.0)\n",
        "\n",
        "    # FEMNIST client selection safeguards\n",
        "    min_train_samples_per_client: int = 40\n",
        "    min_test_samples_per_client: int = 10\n",
        "    max_partition_id_try: int = 5000\n",
        "\n",
        "    # reproducibility\n",
        "    seed: int = 42\n",
        "    # NOTE: this default is evaluated once, when the class body runs.\n",
        "    device: torch.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    # outputs\n",
        "    out_dir: str = \"./E4_outputs_noLEAF\"\n",
        "    csv_name: str = \"E4_FEMNIST_2x2_results.csv\"\n",
        "\n",
        "\n",
        "# Module-level configuration instance built from the Config defaults.\n",
        "CFG = Config()\n",
        "# Create the output directory (no error if it already exists).\n",
        "os.makedirs(CFG.out_dir, exist_ok=True)\n",
        "\n",
        "\n",
        "def set_seed(seed: int):\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    torch.cuda.manual_seed_all(seed)\n",
        "\n",
        "\n",
        "set_seed(CFG.seed)\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 1) FEMNIST via HF/Flower Datasets\n",
        "# ----------------------------\n",
        "def build_femnist_clients_from_hf(\n",
        "    n_clients: int,\n",
        "    train_ratio: float,\n",
        "    seed: int,\n",
        "    min_train: int,\n",
        "    min_test: int,\n",
        "    max_try: int,\n",
        ") -> Tuple[List[TensorDataset], List[TensorDataset], int]:\n",
        "    \"\"\"\n",
        "    Uses flwr_datasets FederatedDataset with NaturalIdPartitioner(partition_by=\"writer_id\").\n",
        "    Loads partitions sequentially (partition_id=0,1,2,...) until collecting enough clients.\n",
        "\n",
        "    A partition is accepted only if it has at least min_train + min_test samples and its\n",
        "    shuffled train/test split (train share = train_ratio) satisfies both minimums.\n",
        "    Size checks run before any image conversion, so rejected (or empty) partitions are\n",
        "    never turned into tensors.\n",
        "\n",
        "    Args:\n",
        "      n_clients: number of clients to collect.\n",
        "      train_ratio: fraction of each client's samples placed in its train split.\n",
        "      seed: seed of the NumPy generator that shuffles each client's samples.\n",
        "      min_train: minimum number of train samples per accepted client.\n",
        "      min_test: minimum number of test samples per accepted client.\n",
        "      max_try: number of partition ids (starting at 0) that may be tried.\n",
        "\n",
        "    Returns:\n",
        "      client_train_sets, client_test_sets, num_classes\n",
        "      (num_classes = largest label seen in the accepted clients + 1)\n",
        "\n",
        "    Raises:\n",
        "      RuntimeError: if fewer than n_clients usable clients were collected; the message\n",
        "        includes the partition-loading error when one stopped the scan.\n",
        "    \"\"\"\n",
        "    from flwr_datasets import FederatedDataset\n",
        "    from flwr_datasets.partitioner import NaturalIdPartitioner\n",
        "\n",
        "    rng = np.random.default_rng(seed)\n",
        "\n",
        "    fds = FederatedDataset(\n",
        "        dataset=\"flwrlabs/femnist\",\n",
        "        partitioners={\"train\": NaturalIdPartitioner(partition_by=\"writer_id\")},\n",
        "    )\n",
        "\n",
        "    client_train_sets: List[TensorDataset] = []\n",
        "    client_test_sets: List[TensorDataset] = []\n",
        "    max_label = -1\n",
        "    load_error: Optional[Exception] = None\n",
        "\n",
        "    # Helper: PIL->torch normalized [-1,1] (assumes 8-bit pixel values)\n",
        "    def to_tensors(part) -> Tuple[torch.Tensor, torch.Tensor]:\n",
        "        images = part[\"image\"]  # PIL images\n",
        "        labels = np.array(part[\"character\"], dtype=np.int64)\n",
        "        # PIL -> np\n",
        "        X = np.stack([np.array(img, dtype=np.float32) for img in images], axis=0)  # [N,28,28]\n",
        "        X = X / 255.0\n",
        "        X = (X - 0.5) / 0.5\n",
        "        X = torch.tensor(X, dtype=torch.float32).unsqueeze(1)  # [N,1,28,28]\n",
        "        Y = torch.tensor(labels, dtype=torch.long)\n",
        "        return X, Y\n",
        "\n",
        "    for pid in range(max_try):\n",
        "        if len(client_train_sets) >= n_clients:\n",
        "            break\n",
        "        try:\n",
        "            part = fds.load_partition(partition_id=pid, split=\"train\")\n",
        "        except Exception as exc:\n",
        "            # Ran out of partitions or unsupported id. Keep the cause so that e.g. a\n",
        "            # download failure is not misreported as a plain shortage of clients.\n",
        "            load_error = exc\n",
        "            break\n",
        "\n",
        "        # Cheap size check first: skip small/empty clients before converting images\n",
        "        n = len(part)\n",
        "        if n < (min_train + min_test):\n",
        "            continue  # too small client\n",
        "\n",
        "        # split per-client (the shuffle is drawn exactly where it was before,\n",
        "        # so the RNG stream and the resulting splits are unchanged)\n",
        "        idx = np.arange(n)\n",
        "        rng.shuffle(idx)\n",
        "        cut = int(round(train_ratio * n))\n",
        "        tr_idx = idx[:cut]\n",
        "        te_idx = idx[cut:]\n",
        "\n",
        "        if len(tr_idx) < min_train or len(te_idx) < min_test:\n",
        "            continue\n",
        "\n",
        "        # Convert only the accepted client\n",
        "        X, Y = to_tensors(part)\n",
        "\n",
        "        client_train_sets.append(TensorDataset(X[tr_idx], Y[tr_idx]))\n",
        "        client_test_sets.append(TensorDataset(X[te_idx], Y[te_idx]))\n",
        "\n",
        "        max_label = max(max_label, int(Y.max().item()))\n",
        "\n",
        "    if len(client_train_sets) < n_clients:\n",
        "        detail = f\" Last partition load error: {load_error!r}\" if load_error is not None else \"\"\n",
        "        raise RuntimeError(\n",
        "            f\"Not enough FEMNIST clients collected. Got {len(client_train_sets)} / {n_clients}. \"\n",
        "            f\"Increase max_partition_id_try (currently {max_try}) or relax min_*_samples.\"\n",
        "            f\"{detail}\"\n",
        "        ) from load_error\n",
        "\n",
        "    num_classes = max_label + 1\n",
        "    # FEMNIST is usually 62 classes; we keep inferred value to be safe.\n",
        "    return client_train_sets, client_test_sets, num_classes\n",
        "\n",
        "\n",
        "def filter_by_labels(ds: TensorDataset, allowed: Tuple[int, ...]) -> TensorDataset:\n",
        "    allowed_set = set(int(x) for x in allowed)\n",
        "    X, Y = ds.tensors\n",
        "    if len(Y) == 0:\n",
        "        return TensorDataset(X, Y)\n",
        "    mask = torch.zeros_like(Y, dtype=torch.bool)\n",
        "    for c in allowed_set:\n",
        "        mask |= (Y == c)\n",
        "    return TensorDataset(X[mask], Y[mask])\n",
        "\n",
        "\n",
        "def sample_public_from_clients(\n",
        "    client_train_sets: List[TensorDataset],\n",
        "    allowed_labels: Tuple[int, ...],\n",
        "    total_size: int,\n",
        "    rng: np.random.Generator,\n",
        ") -> TensorDataset:\n",
        "    \"\"\"\n",
        "    Pool the samples with an allowed label from all clients and draw a random subset.\n",
        "\n",
        "    Returns at most `total_size` samples, picked by shuffling the pooled indices with\n",
        "    `rng`. An empty TensorDataset is returned when no client has a matching sample.\n",
        "    \"\"\"\n",
        "    per_client = (filter_by_labels(ds, allowed_labels) for ds in client_train_sets)\n",
        "    pooled = [ds_f.tensors for ds_f in per_client if len(ds_f) > 0]\n",
        "    if not pooled:\n",
        "        return TensorDataset(torch.empty(0), torch.empty(0, dtype=torch.long))\n",
        "\n",
        "    X = torch.cat([pair[0] for pair in pooled], dim=0)\n",
        "    Y = torch.cat([pair[1] for pair in pooled], dim=0)\n",
        "\n",
        "    # Every pooled part is non-empty, so there is at least one sample here.\n",
        "    n = len(Y)\n",
        "    order = np.arange(n)\n",
        "    rng.shuffle(order)\n",
        "    chosen = order[: min(total_size, n)]\n",
        "    return TensorDataset(X[chosen], Y[chosen])\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 2) Model with penultimate features\n",
        "# ----------------------------\n",
        "class SimpleCNNFeat(nn.Module):\n",
        "    def __init__(self, num_classes: int):\n",
        "        super().__init__()\n",
        "        self.features = nn.Sequential(\n",
        "            nn.Conv2d(1, 32, kernel_size=3, padding=1),\n",
        "            nn.ReLU(inplace=True),\n",
        "            nn.MaxPool2d(2),\n",
        "            nn.Conv2d(32, 64, kernel_size=3, padding=1),\n",
        "            nn.ReLU(inplace=True),\n",
        "            nn.MaxPool2d(2),\n",
        "        )\n",
        "        self.fc1 = nn.Linear(64 * 7 * 7, 128)\n",
        "        self.relu = nn.ReLU(inplace=True)\n",
        "        self.fc2 = nn.Linear(128, num_classes)\n",
        "\n",
        "    def forward(self, x, return_feat: bool = False):\n",
        "        z = self.features(x).flatten(1)\n",
        "        feat = self.relu(self.fc1(z))\n",
        "        logits = self.fc2(feat)\n",
        "        if return_feat:\n",
        "            return logits, feat\n",
        "        return logits\n",
        "\n",
        "\n",
        "def get_model(num_classes: int) -> nn.Module:\n",
        "    \"\"\"Build a freshly initialised SimpleCNNFeat and move it to CFG.device.\"\"\"\n",
        "    net = SimpleCNNFeat(num_classes=num_classes)\n",
        "    return net.to(CFG.device)\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 3) Eval helpers\n",
        "# ----------------------------\n",
        "@torch.no_grad()\n",
        "def evaluate_acc(model: nn.Module, loader: DataLoader) -> float:\n",
        "    \"\"\"Top-1 accuracy of `model` over `loader` in eval mode; 0.0 if the loader yields no samples.\"\"\"\n",
        "    model.eval()\n",
        "    n_hit = 0\n",
        "    n_seen = 0\n",
        "    for inputs, targets in loader:\n",
        "        inputs = inputs.to(CFG.device)\n",
        "        targets = targets.to(CFG.device)\n",
        "        predicted = model(inputs).argmax(dim=1)\n",
        "        n_hit += (predicted == targets).sum().item()\n",
        "        n_seen += targets.size(0)\n",
        "    if n_seen == 0:\n",
        "        return 0.0\n",
        "    return float(n_hit / n_seen)\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def collect_features(model: nn.Module, loader: DataLoader) -> torch.Tensor:\n",
        "    \"\"\"\n",
        "    Run `model` in eval mode over `loader` and concatenate the features returned by\n",
        "    `model(x, return_feat=True)`. Labels are ignored. An empty loader gives a\n",
        "    (0, 128) tensor on CFG.device.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    chunks = [\n",
        "        model(batch.to(CFG.device), return_feat=True)[1].detach()\n",
        "        for batch, _ in loader\n",
        "    ]\n",
        "    if not chunks:\n",
        "        return torch.empty(0, 128, device=CFG.device)\n",
        "    return torch.cat(chunks, dim=0)\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 4) State helpers\n",
        "# ----------------------------\n",
        "def fedavg_mean(states: List[Dict[str, torch.Tensor]], weights: Optional[List[float]] = None) -> Dict[str, torch.Tensor]:\n",
        "    if weights is None:\n",
        "        weights = [1.0] * len(states)\n",
        "    wsum = float(sum(weights))\n",
        "    avg = {}\n",
        "    for k in states[0].keys():\n",
        "        stacked = torch.stack([sd[k] * float(w) for sd, w in zip(states, weights)], dim=0)\n",
        "        avg[k] = stacked.sum(dim=0) / wsum\n",
        "    return avg\n",
        "\n",
        "\n",
        "def make_transmitted_state(global_sd: Dict[str, torch.Tensor], local_sd: Dict[str, torch.Tensor], scale: float) -> Dict[str, torch.Tensor]:\n",
        "    tx = {}\n",
        "    for k in global_sd.keys():\n",
        "        tx[k] = global_sd[k] + scale * (local_sd[k] - global_sd[k])\n",
        "    return tx\n",
        "\n",
        "\n",
        "def flatten_update(global_sd: Dict[str, torch.Tensor], tx_sd: Dict[str, torch.Tensor]) -> torch.Tensor:\n",
        "    vecs = []\n",
        "    for k in global_sd.keys():\n",
        "        vecs.append((tx_sd[k] - global_sd[k]).detach().flatten())\n",
        "    return torch.cat(vecs, dim=0)\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 5) Backdoor wrapper (trigger + target relabel)\n",
        "# ----------------------------\n",
        "class BackdoorWrapper(Dataset):\n",
        "    def __init__(self, base: TensorDataset, poison_frac: float, target_label: int,\n",
        "                 trigger_size: int, trigger_value: float, seed: int):\n",
        "        self.base = base\n",
        "        self.poison_frac = float(poison_frac)\n",
        "        self.target_label = int(target_label)\n",
        "        self.trigger_size = int(trigger_size)\n",
        "        self.trigger_value = float(trigger_value)\n",
        "        rng = np.random.default_rng(seed)\n",
        "\n",
        "        n = len(base)\n",
        "        idx = np.arange(n)\n",
        "        rng.shuffle(idx)\n",
        "        self.poison_idx = set(idx[: int(round(self.poison_frac * n))].tolist())\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.base)\n",
        "\n",
        "    def _apply_trigger(self, x: torch.Tensor) -> torch.Tensor:\n",
        "        x2 = x.clone()\n",
        "        s = self.trigger_size\n",
        "        x2[:, -s:, -s:] = torch.clamp(torch.tensor(self.trigger_value, dtype=x2.dtype), -1.0, 1.0)\n",
        "        return x2\n",
        "\n",
        "    def __getitem__(self, idx: int):\n",
        "        x, y = self.base[idx]\n",
        "        if idx in self.poison_idx:\n",
        "            x = self._apply_trigger(x)\n",
        "            y = torch.tensor(self.target_label, dtype=torch.long)\n",
        "        return x, y\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 6) PoisonedFL-ish local training (multi-round consistency + proximal)\n",
        "# ----------------------------\n",
        "def _param_l2(sd_a: Dict[str, torch.Tensor], sd_b: Dict[str, torch.Tensor]) -> torch.Tensor:\n",
        "    \"\"\"Sum over the keys of `sd_a` of the squared differences sd_a[k] - sd_b[k] (scalar tensor).\"\"\"\n",
        "    total = torch.tensor(0.0, device=CFG.device)\n",
        "    for name, tensor_a in sd_a.items():\n",
        "        diff = tensor_a - sd_b[name]\n",
        "        total = total + diff.pow(2).sum()\n",
        "    return total\n",
        "\n",
        "\n",
        "def local_train_honest(model: nn.Module, dataset, epochs: int) -> nn.Module:\n",
        "    \"\"\"\n",
        "    Train `model` in place on `dataset` for `epochs` passes with SGD and cross-entropy,\n",
        "    using the batch size, learning rate, momentum and weight decay from CFG.\n",
        "    Returns the same model object.\n",
        "    \"\"\"\n",
        "    batches = DataLoader(dataset, batch_size=CFG.batch_size, shuffle=True)\n",
        "    optimizer = optim.SGD(\n",
        "        model.parameters(),\n",
        "        lr=CFG.lr,\n",
        "        momentum=CFG.momentum,\n",
        "        weight_decay=CFG.weight_decay,\n",
        "    )\n",
        "    criterion = nn.CrossEntropyLoss()\n",
        "\n",
        "    model.train()\n",
        "    for _epoch in range(epochs):\n",
        "        for inputs, targets in batches:\n",
        "            inputs = inputs.to(CFG.device)\n",
        "            targets = targets.to(CFG.device)\n",
        "            batch_loss = criterion(model(inputs), targets)\n",
        "            optimizer.zero_grad()\n",
        "            batch_loss.backward()\n",
        "            optimizer.step()\n",
        "    return model\n",
        "\n",
        "\n",
        "def local_train_poisonedfl(\n",
        "    model: nn.Module,\n",
        "    dataset,\n",
        "    global_sd: Dict[str, torch.Tensor],\n",
        "    prev_mal_tx_sd: Optional[Dict[str, torch.Tensor]],\n",
        "    epochs: int,\n",
        ") -> Dict[str, torch.Tensor]:\n",
        "    \"\"\"\n",
        "    Lightweight-but-closer PoisonedFL-style:\n",
        "      loss = CE(on modified labels) +\n",
        "             mu * ||w - w_global||^2 +\n",
        "             beta * ||w - w_prev_mal||^2\n",
        "\n",
        "    Args:\n",
        "      model: model to train in place (expected to start from the global weights).\n",
        "      dataset: local training data of (image, label) pairs.\n",
        "      global_sd: current global state dict (proximal anchor).\n",
        "      prev_mal_tx_sd: state dict this attacker sent previously, or None to skip\n",
        "        the consistency term.\n",
        "      epochs: number of passes over `dataset`.\n",
        "\n",
        "    Returns:\n",
        "      model.state_dict() after training.\n",
        "    \"\"\"\n",
        "    loader = DataLoader(dataset, batch_size=CFG.batch_size, shuffle=True)\n",
        "    opt = optim.SGD(model.parameters(), lr=CFG.lr, momentum=CFG.momentum, weight_decay=CFG.weight_decay)\n",
        "    crit = nn.CrossEntropyLoss()\n",
        "\n",
        "    head_arr = torch.tensor(list(CFG.head_classes), device=CFG.device, dtype=torch.long)\n",
        "    tail_arr = torch.tensor(\n",
        "        sorted(set(int(x) for x in CFG.tail_classes)), device=CFG.device, dtype=torch.long\n",
        "    )\n",
        "\n",
        "    model.train()\n",
        "    for _ in range(epochs):\n",
        "        for x, y in loader:\n",
        "            x, y = x.to(CFG.device), y.to(CFG.device)\n",
        "\n",
        "            # degrade tail via flipping tail labels into head labels (cheap & consistent)\n",
        "            y_mod = y\n",
        "            if CFG.poisonedfl_tail_label_flip:\n",
        "                y_mod = y.clone()\n",
        "                # one broadcast comparison instead of a Python loop over tail classes\n",
        "                tail_mask = (y_mod.unsqueeze(-1) == tail_arr).any(dim=-1)\n",
        "                if tail_mask.any() and len(head_arr) > 0:\n",
        "                    ridx = torch.randint(0, len(head_arr), (tail_mask.sum().item(),), device=CFG.device)\n",
        "                    y_mod[tail_mask] = head_arr[ridx]\n",
        "\n",
        "            logits = model(x)\n",
        "            loss_ce = crit(logits, y_mod)\n",
        "\n",
        "            # keep_vars=True returns the live Parameters. The default state_dict()\n",
        "            # returns detached tensors, which made both regularisers constants with\n",
        "            # zero gradient, so they never influenced the update.\n",
        "            cur_sd = model.state_dict(keep_vars=True)\n",
        "            loss_glob = _param_l2(cur_sd, global_sd)\n",
        "\n",
        "            if prev_mal_tx_sd is not None:\n",
        "                loss_cons = _param_l2(cur_sd, prev_mal_tx_sd)\n",
        "            else:\n",
        "                loss_cons = torch.tensor(0.0, device=CFG.device)\n",
        "\n",
        "            loss = loss_ce + CFG.poisonedfl_mu_global * loss_glob + CFG.poisonedfl_beta_consistency * loss_cons\n",
        "\n",
        "            opt.zero_grad()\n",
        "            loss.backward()\n",
        "            opt.step()\n",
        "\n",
        "    return model.state_dict()\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 7) Defenses: FedCC (linear CKA) + Attack-Adaptive Aggregation\n",
        "# ----------------------------\n",
        "def linear_cka(X: torch.Tensor, Y: torch.Tensor, eps: float = 1e-12) -> float:\n",
        "    if X.numel() == 0 or Y.numel() == 0:\n",
        "        return 0.0\n",
        "    n = min(X.shape[0], Y.shape[0])\n",
        "    X = X[:n]\n",
        "    Y = Y[:n]\n",
        "\n",
        "    Xc = X - X.mean(dim=0, keepdim=True)\n",
        "    Yc = Y - Y.mean(dim=0, keepdim=True)\n",
        "\n",
        "    XT_Y = Xc.t() @ Yc\n",
        "    XT_X = Xc.t() @ Xc\n",
        "    YT_Y = Yc.t() @ Yc\n",
        "\n",
        "    num = (XT_Y.pow(2)).sum()\n",
        "    den = torch.sqrt((XT_X.pow(2)).sum() * (YT_Y.pow(2)).sum() + eps)\n",
        "    return float((num / (den + eps)).clamp(0.0, 1.0).item())\n",
        "\n",
        "\n",
        "def defense_fedcc_cka_filter(\n",
        "    num_classes: int,\n",
        "    global_sd: Dict[str, torch.Tensor],\n",
        "    client_tx_sds: List[Dict[str, torch.Tensor]],\n",
        "    support_loader: DataLoader,\n",
        ") -> Tuple[Dict[str, torch.Tensor], Dict]:\n",
        "    \"\"\"\n",
        "    FedCC-style filter: score each client state by the linear CKA between its features\n",
        "    and the global model's features on `support_loader`, drop the lowest-scoring\n",
        "    clients, and average the rest with equal weights.\n",
        "\n",
        "    round(CFG.fedcc_reject_frac * n) clients are rejected, capped so that at least\n",
        "    CFG.fedcc_min_keep clients remain whenever n allows it.\n",
        "\n",
        "    Returns:\n",
        "      (new global state dict, info dict with the similarities and kept/rejected indices)\n",
        "    \"\"\"\n",
        "\n",
        "    def _support_features(state: Dict[str, torch.Tensor]) -> torch.Tensor:\n",
        "        # A fresh model per state, exactly as before (keeps the RNG usage identical).\n",
        "        net = get_model(num_classes)\n",
        "        net.load_state_dict(state)\n",
        "        return collect_features(net, support_loader)\n",
        "\n",
        "    G = _support_features(global_sd)\n",
        "    sims = [linear_cka(G, _support_features(tx_sd)) for tx_sd in client_tx_sds]\n",
        "\n",
        "    n = len(client_tx_sds)\n",
        "    wanted_rejects = int(round(CFG.fedcc_reject_frac * n))\n",
        "    reject_n = min(wanted_rejects, max(0, n - CFG.fedcc_min_keep))\n",
        "\n",
        "    ranking = np.argsort(sims)  # ascending similarity\n",
        "    keep_idx = ranking[reject_n:].tolist()\n",
        "    kept = set(keep_idx)\n",
        "\n",
        "    new_sd = fedavg_mean([client_tx_sds[i] for i in keep_idx])\n",
        "\n",
        "    info = {\n",
        "        \"defense\": \"FedCC_CKA\",\n",
        "        \"cka_sim\": sims,\n",
        "        \"keep_idx\": keep_idx,\n",
        "        \"reject_idx\": [i for i in range(n) if i not in kept],\n",
        "    }\n",
        "    return new_sd, info\n",
        "\n",
        "\n",
        "def defense_attack_adaptive_agg(\n",
        "    global_sd: Dict[str, torch.Tensor],\n",
        "    client_tx_sds: List[Dict[str, torch.Tensor]],\n",
        ") -> Tuple[Dict[str, torch.Tensor], Dict]:\n",
        "    \"\"\"\n",
        "    Soft re-weighting of client states by how anomalous their update looks.\n",
        "\n",
        "    For each client the flattened update (tx - global) gets a score\n",
        "      score = z-score of its L2 norm + (1 - cosine similarity to the coordinate-wise median update)\n",
        "    and the aggregation weights are softmax(-CFG.adaagg_alpha * score).\n",
        "\n",
        "    Returns:\n",
        "      (weighted-average state dict, info dict with the per-client scores and weights)\n",
        "\n",
        "    Raises:\n",
        "      ValueError: if `client_tx_sds` is empty.\n",
        "    \"\"\"\n",
        "    if len(client_tx_sds) == 0:\n",
        "        raise ValueError(\"defense_attack_adaptive_agg needs at least one client update\")\n",
        "\n",
        "    updates = [flatten_update(global_sd, tx) for tx in client_tx_sds]\n",
        "    U = torch.stack(updates, dim=0)\n",
        "\n",
        "    med = U.median(dim=0).values\n",
        "    med_norm = med.norm() + 1e-12\n",
        "\n",
        "    norms = torch.tensor([u.norm().item() for u in updates], device=U.device)\n",
        "    if norms.numel() > 1:\n",
        "        z = (norms - norms.mean()) / (norms.std() + 1e-12)\n",
        "    else:\n",
        "        # The sample std of a single value is NaN, which would turn the weights and\n",
        "        # the aggregated model into NaN. A lone update has no norm outlier score.\n",
        "        z = torch.zeros_like(norms)\n",
        "\n",
        "    cos_terms = []\n",
        "    for u in updates:\n",
        "        cos_terms.append((u @ med) / (u.norm() * med_norm + 1e-12))\n",
        "    cos_terms = torch.stack(cos_terms, dim=0).clamp(-1.0, 1.0)\n",
        "    cos_dist = 1.0 - cos_terms\n",
        "\n",
        "    score = z + cos_dist\n",
        "    weights = torch.softmax(-CFG.adaagg_alpha * score, dim=0).detach().cpu().numpy().tolist()\n",
        "\n",
        "    new_sd = fedavg_mean(client_tx_sds, weights=weights)\n",
        "\n",
        "    info = {\n",
        "        \"defense\": \"AttackAdaptiveAggregation\",\n",
        "        \"scores\": score.detach().cpu().numpy().tolist(),\n",
        "        \"weights\": weights,\n",
        "    }\n",
        "    return new_sd, info\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 8) Run one FL (attack x defense)\n",
        "# ----------------------------\n",
        "def tail_slice(x: np.ndarray, k: int) -> np.ndarray:\n",
        "    if len(x) <= k:\n",
        "        return x\n",
        "    return x[-k:]\n",
        "\n",
        "\n",
        "def mean_last_k(x: np.ndarray, k: int) -> float:\n",
        "    \"\"\"Mean of the window returned by tail_slice(x, k); NaN when that window is empty.\"\"\"\n",
        "    window = tail_slice(x, k)\n",
        "    if len(window) == 0:\n",
        "        return float(\"nan\")\n",
        "    return float(np.mean(window))\n",
        "\n",
        "\n",
        "def std_last_k(x: np.ndarray, k: int) -> float:\n",
        "    \"\"\"Population standard deviation (np.std) of tail_slice(x, k); NaN when that window is empty.\"\"\"\n",
        "    window = tail_slice(x, k)\n",
        "    if len(window) == 0:\n",
        "        return float(\"nan\")\n",
        "    return float(np.std(window))\n",
        "\n",
        "\n",
        "def run_fl_once(\n",
        "    num_classes: int,\n",
        "    client_train_sets: List[TensorDataset],\n",
        "    client_test_sets: List[TensorDataset],\n",
        "    public_support: TensorDataset,\n",
        "    public_head_eval: TensorDataset,\n",
        "    public_tail_eval: TensorDataset,\n",
        "    defense_name: str,  # \"FedCC\" | \"AttackAdaptiveAggregation\"\n",
        "    attack_name: str,   # \"NONE\" | \"POISONEDFL\" | \"BACKDOOR\"\n",
        "    seed_offset: int = 0,\n",
        ") -> Dict:\n",
        "    \"\"\"Run one federated-learning simulation for a single (attack, defense) pair.\n",
        "\n",
        "    Every round samples ``min(CFG.clients_per_round, n_clients)`` clients without\n",
        "    replacement, trains each one locally starting from the current global weights,\n",
        "    aggregates the transmitted states with the chosen defense, and evaluates the new\n",
        "    global model on the public head/tail sets and on each client's tail-class test\n",
        "    subset.\n",
        "\n",
        "    Args:\n",
        "        num_classes: Passed to ``get_model`` and to the FedCC defense.\n",
        "        client_train_sets: Per-client training datasets, indexed by client id.\n",
        "        client_test_sets: Per-client test datasets; filtered to ``CFG.tail_classes``\n",
        "            for the welfare metric ``W_tail``.\n",
        "        public_support: Dataset handed (as a loader) to the FedCC defense.\n",
        "        public_head_eval: Dataset used for the ``M_head`` accuracy.\n",
        "        public_tail_eval: Dataset used for the ``M_tail`` accuracy.\n",
        "        defense_name: ``\"FedCC\"`` or ``\"AttackAdaptiveAggregation\"``.\n",
        "        attack_name: ``\"NONE\"``, ``\"POISONEDFL\"`` or ``\"BACKDOOR\"``.\n",
        "        seed_offset: Added to ``CFG.seed`` to seed this run's NumPy generator, which\n",
        "            picks the malicious clients and the per-round participants.\n",
        "\n",
        "    Returns:\n",
        "        Dict with the run settings plus float32 arrays: per-round histories\n",
        "        ``M_head_hist``, ``M_tail_hist``, ``W_tail_hist``, ``gap_hist`` and\n",
        "        ``reject_rate_hist``, and ``scale_hist`` with one entry per client update.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: If ``defense_name`` is not a supported defense.\n",
        "    \"\"\"\n",
        "    rng = np.random.default_rng(CFG.seed + seed_offset)\n",
        "\n",
        "    n_clients = len(client_train_sets)\n",
        "    n_mal = int(round(CFG.malicious_frac * n_clients))\n",
        "    all_ids = np.arange(n_clients)\n",
        "    rng.shuffle(all_ids)\n",
        "    malicious_ids = set(all_ids[:n_mal].tolist())\n",
        "\n",
        "    # build per-client train datasets (wrap backdoor for malicious)\n",
        "    built_train_sets = []\n",
        "    for cid in range(n_clients):\n",
        "        base = client_train_sets[cid]\n",
        "        if cid in malicious_ids and attack_name == \"BACKDOOR\":\n",
        "            built = BackdoorWrapper(\n",
        "                base=base,\n",
        "                poison_frac=CFG.backdoor_poison_frac,\n",
        "                target_label=CFG.backdoor_target_label,\n",
        "                trigger_size=CFG.trigger_size,\n",
        "                trigger_value=CFG.trigger_value,\n",
        "                seed=int(CFG.seed + 777 + seed_offset + cid),\n",
        "            )\n",
        "            built_train_sets.append(built)\n",
        "        else:\n",
        "            built_train_sets.append(base)\n",
        "\n",
        "    # loaders\n",
        "    support_loader = DataLoader(public_support, batch_size=CFG.batch_size, shuffle=False)\n",
        "    head_eval_loader = DataLoader(public_head_eval, batch_size=CFG.batch_size, shuffle=False)\n",
        "    tail_eval_loader = DataLoader(public_tail_eval, batch_size=CFG.batch_size, shuffle=False)\n",
        "\n",
        "    # welfare: tail accuracy across clients on their test-tail subsets\n",
        "    client_tail_tests = [filter_by_labels(ds, CFG.tail_classes) for ds in client_test_sets]\n",
        "    # The tail subsets never change between rounds, so build their loaders once;\n",
        "    # clients without any tail-class test sample are left out of the welfare mean.\n",
        "    client_tail_loaders = [\n",
        "        DataLoader(ds_tail, batch_size=CFG.batch_size, shuffle=False)\n",
        "        for ds_tail in client_tail_tests\n",
        "        if len(ds_tail) > 0\n",
        "    ]\n",
        "\n",
        "    # init global\n",
        "    global_model = get_model(num_classes)\n",
        "    global_sd = {k: v.detach().clone() for k, v in global_model.state_dict().items()}\n",
        "\n",
        "    # PoisonedFL memory\n",
        "    prev_mal_tx_sd: Dict[int, Optional[Dict[str, torch.Tensor]]] = {cid: None for cid in malicious_ids}\n",
        "\n",
        "    last_reject_rate = 0.0\n",
        "\n",
        "    # histories\n",
        "    M_head_hist, M_tail_hist, W_tail_hist, gap_hist = [], [], [], []\n",
        "    reject_rate_hist = []\n",
        "    scale_hist = []\n",
        "\n",
        "    for rnd in range(CFG.n_rounds):\n",
        "        part = rng.choice(n_clients, size=min(CFG.clients_per_round, n_clients), replace=False).tolist()\n",
        "        client_tx_sds = []\n",
        "\n",
        "        for cid in part:\n",
        "            local_model = get_model(num_classes)\n",
        "            local_model.load_state_dict(global_sd)\n",
        "\n",
        "            is_attack_active = (attack_name != \"NONE\") and (rnd >= CFG.attack_start_round) and (cid in malicious_ids)\n",
        "\n",
        "            if is_attack_active and attack_name == \"POISONEDFL\":\n",
        "                local_sd = local_train_poisonedfl(\n",
        "                    model=local_model,\n",
        "                    dataset=built_train_sets[cid],\n",
        "                    global_sd=global_sd,\n",
        "                    prev_mal_tx_sd=prev_mal_tx_sd[cid],\n",
        "                    epochs=CFG.local_epochs,\n",
        "                )\n",
        "\n",
        "                # dynamic magnitude adjustment based on last reject rate (FedCC only; else 0)\n",
        "                scale = CFG.poisonedfl_scale * (1.0 + CFG.poisonedfl_dyn_eta * (last_reject_rate - 0.25))\n",
        "                scale = float(np.clip(scale, CFG.poisonedfl_scale_min, CFG.poisonedfl_scale_max))\n",
        "                scale_hist.append(scale)\n",
        "\n",
        "                tx_sd = make_transmitted_state(global_sd, local_sd, scale=scale)\n",
        "                client_tx_sds.append(tx_sd)\n",
        "\n",
        "                prev_mal_tx_sd[cid] = {k: v.detach().clone() for k, v in tx_sd.items()}\n",
        "\n",
        "            else:\n",
        "                # Honest training. The backdoored dataset is only used once the attack is\n",
        "                # active: before CFG.attack_start_round a malicious client trains on its\n",
        "                # clean data, so the attack really begins at the configured round (the\n",
        "                # FN / EW-delay metrics are measured from that round).\n",
        "                train_ds = built_train_sets[cid] if is_attack_active else client_train_sets[cid]\n",
        "                local_model = local_train_honest(local_model, train_ds, epochs=CFG.local_epochs)\n",
        "                local_sd = local_model.state_dict()\n",
        "\n",
        "                if is_attack_active and attack_name == \"BACKDOOR\":\n",
        "                    scale = CFG.model_replacement_gamma\n",
        "                else:\n",
        "                    scale = 1.0\n",
        "                scale_hist.append(scale)\n",
        "\n",
        "                tx_sd = make_transmitted_state(global_sd, local_sd, scale=scale)\n",
        "                client_tx_sds.append(tx_sd)\n",
        "\n",
        "        # defense\n",
        "        if defense_name == \"FedCC\":\n",
        "            new_sd, info = defense_fedcc_cka_filter(\n",
        "                num_classes=num_classes,\n",
        "                global_sd=global_sd,\n",
        "                client_tx_sds=client_tx_sds,\n",
        "                support_loader=support_loader,\n",
        "            )\n",
        "            rej = len(info[\"reject_idx\"])\n",
        "            last_reject_rate = float(rej / max(1, len(client_tx_sds)))\n",
        "        elif defense_name == \"AttackAdaptiveAggregation\":\n",
        "            new_sd, info = defense_attack_adaptive_agg(global_sd=global_sd, client_tx_sds=client_tx_sds)\n",
        "            last_reject_rate = 0.0\n",
        "        else:\n",
        "            raise ValueError(f\"Unknown defense: {defense_name}\")\n",
        "\n",
        "        global_sd = new_sd\n",
        "        global_model.load_state_dict(global_sd)\n",
        "\n",
        "        # public head/tail metric\n",
        "        M_head = evaluate_acc(global_model, head_eval_loader)\n",
        "        M_tail = evaluate_acc(global_model, tail_eval_loader)\n",
        "\n",
        "        # welfare: mean tail acc across clients\n",
        "        tail_accs = [evaluate_acc(global_model, loader) for loader in client_tail_loaders]\n",
        "        W_tail = float(np.mean(tail_accs)) if len(tail_accs) > 0 else 0.0\n",
        "\n",
        "        gap = float(M_head - W_tail)\n",
        "\n",
        "        M_head_hist.append(M_head)\n",
        "        M_tail_hist.append(M_tail)\n",
        "        W_tail_hist.append(W_tail)\n",
        "        gap_hist.append(gap)\n",
        "        reject_rate_hist.append(last_reject_rate)\n",
        "\n",
        "        if (rnd + 1) % 10 == 0 or rnd == 0:\n",
        "            print(\n",
        "                f\"[{attack_name} | {defense_name}] \"\n",
        "                f\"Round {rnd+1:>3}/{CFG.n_rounds} | \"\n",
        "                f\"M_head={M_head:.4f} M_tail={M_tail:.4f} W_tail={W_tail:.4f} gap={gap:.4f} | \"\n",
        "                f\"rej_rate={last_reject_rate:.2f}\"\n",
        "            )\n",
        "\n",
        "    return {\n",
        "        \"attack\": attack_name,\n",
        "        \"defense\": defense_name,\n",
        "        \"n_clients\": n_clients,\n",
        "        \"clients_per_round\": CFG.clients_per_round,\n",
        "        \"malicious_frac\": CFG.malicious_frac,\n",
        "        \"n_malicious\": n_mal,\n",
        "        \"attack_start_round\": CFG.attack_start_round,\n",
        "        \"M_head_hist\": np.array(M_head_hist, dtype=np.float32),\n",
        "        \"M_tail_hist\": np.array(M_tail_hist, dtype=np.float32),\n",
        "        \"W_tail_hist\": np.array(W_tail_hist, dtype=np.float32),\n",
        "        \"gap_hist\": np.array(gap_hist, dtype=np.float32),\n",
        "        \"reject_rate_hist\": np.array(reject_rate_hist, dtype=np.float32),\n",
        "        \"scale_hist\": np.array(scale_hist, dtype=np.float32),\n",
        "    }\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 9) Threshold + summary metrics\n",
        "# ----------------------------\n",
        "def compute_alarm_threshold_from_baseline(gap_hist: np.ndarray) -> float:\n",
        "    \"\"\"Alarm threshold from a baseline run's gap history.\n",
        "\n",
        "    Takes the last ``CFG.tail_k`` gap values and returns their mean plus\n",
        "    ``CFG.alarm_threshold_std`` times their (population) standard deviation.\n",
        "    \"\"\"\n",
        "    window = tail_slice(gap_hist, CFG.tail_k).astype(np.float64)\n",
        "    center = float(window.mean())\n",
        "    spread = float(window.std())\n",
        "    return center + CFG.alarm_threshold_std * spread\n",
        "\n",
        "\n",
        "def fp_fn_ewdelay(gap_hist: np.ndarray, thr: float, attack_start_round: int, is_attack: bool) -> Tuple[float, float, float]:\n",
        "    \"\"\"Alarm statistics of a gap history against the threshold ``thr``.\n",
        "\n",
        "    Returns a ``(fp, fn, ew)`` triple:\n",
        "\n",
        "    * baseline run (``is_attack`` false): ``fp`` is the fraction of rounds from\n",
        "      ``CFG.warmup_ignore`` onward whose gap exceeds ``thr`` (0.0 if there are no such\n",
        "      rounds); ``fn`` and ``ew`` are NaN.\n",
        "    * attack run: ``fn`` is the fraction of rounds from ``attack_start_round`` onward\n",
        "      whose gap does not exceed ``thr`` (NaN if there are none), ``ew`` is the number\n",
        "      of rounds between the attack start and the first gap above ``thr`` (NaN if the\n",
        "      threshold is never crossed); ``fp`` is NaN.\n",
        "    \"\"\"\n",
        "    n_rounds = len(gap_hist)\n",
        "    gaps = gap_hist.astype(np.float64)\n",
        "    nan = float(\"nan\")\n",
        "\n",
        "    if is_attack:\n",
        "        onset = min(n_rounds, attack_start_round)\n",
        "        after_onset = gaps[onset:]\n",
        "        miss_rate = float(np.mean(after_onset <= thr)) if len(after_onset) > 0 else nan\n",
        "        # First round at or after the onset whose gap raises the alarm, if any.\n",
        "        first_alarm = next((t for t in range(onset, n_rounds) if gaps[t] > thr), None)\n",
        "        delay = nan if first_alarm is None else float(first_alarm - onset)\n",
        "        return nan, miss_rate, delay\n",
        "\n",
        "    first_scored = min(n_rounds, CFG.warmup_ignore)\n",
        "    if n_rounds > first_scored:\n",
        "        false_alarm_rate = float(np.mean(gaps[first_scored:] > thr))\n",
        "    else:\n",
        "        false_alarm_rate = 0.0\n",
        "    return false_alarm_rate, nan, nan\n",
        "\n",
        "\n",
        "def summarize_run(run: Dict, thr: float, baseline_W: Optional[float]) -> Dict:\n",
        "    \"\"\"Collapse one ``run_fl_once`` result into a flat row of summary metrics.\n",
        "\n",
        "    Args:\n",
        "        run: Result dict with the run settings and the ``*_hist`` arrays.\n",
        "        thr: Alarm threshold on the gap, forwarded to ``fp_fn_ewdelay``.\n",
        "        baseline_W: Baseline welfare used for the PoG ratio; PoG is NaN when it is\n",
        "            ``None``, non-finite, or not above ``1e-8``.\n",
        "\n",
        "    Returns:\n",
        "        Dict of scalars (last-``CFG.tail_k`` means/stds, PoG, FP/FN rates, EW delay,\n",
        "        average reject rate) alongside the copied run settings.\n",
        "    \"\"\"\n",
        "    k = CFG.tail_k\n",
        "    gap_series = run[\"gap_hist\"]\n",
        "    welfare_series = run[\"W_tail_hist\"]\n",
        "    reject_series = run[\"reject_rate_hist\"]\n",
        "\n",
        "    head_acc = mean_last_k(run[\"M_head_hist\"], k)\n",
        "    tail_acc = mean_last_k(run[\"M_tail_hist\"], k)\n",
        "    welfare = mean_last_k(welfare_series, k)\n",
        "    gap_mean = mean_last_k(gap_series, k)\n",
        "\n",
        "    # Relative welfare loss against the baseline; only defined for a usable baseline.\n",
        "    baseline_usable = baseline_W is not None and np.isfinite(baseline_W) and baseline_W > 1e-8\n",
        "    if baseline_usable:\n",
        "        pog = float((baseline_W - welfare) / baseline_W)\n",
        "    else:\n",
        "        pog = float(\"nan\")\n",
        "\n",
        "    attacked = run[\"attack\"] != \"NONE\"\n",
        "    fp, fn, ew = fp_fn_ewdelay(gap_series, thr, CFG.attack_start_round, is_attack=attacked)\n",
        "\n",
        "    if len(reject_series) > 0:\n",
        "        avg_reject = float(np.mean(reject_series))\n",
        "    else:\n",
        "        avg_reject = float(\"nan\")\n",
        "\n",
        "    # Key order matters: it becomes the CSV column order.\n",
        "    summary = {\"condition\": \"ATTACK\" if attacked else \"BASELINE\"}\n",
        "    for key in (\n",
        "        \"attack\",\n",
        "        \"defense\",\n",
        "        \"n_clients\",\n",
        "        \"clients_per_round\",\n",
        "        \"malicious_frac\",\n",
        "        \"n_malicious\",\n",
        "        \"attack_start_round\",\n",
        "    ):\n",
        "        summary[key] = run[key]\n",
        "    summary[\"threshold_gap\"] = float(thr)\n",
        "\n",
        "    summary[\"M_head_lastK\"] = float(head_acc)\n",
        "    summary[\"M_tail_lastK\"] = float(tail_acc)\n",
        "    summary[\"W_tail_lastK\"] = float(welfare)\n",
        "    summary[\"gap_Mhead_minus_W_lastK\"] = float(gap_mean)\n",
        "\n",
        "    summary[\"PoG_vs_baselineW_same_defense\"] = float(pog)\n",
        "    summary[\"FP_rate_baseline\"] = float(fp)\n",
        "    summary[\"FN_rate_attack\"] = float(fn)\n",
        "    summary[\"EW_delay_rounds\"] = float(ew)\n",
        "\n",
        "    summary[\"gap_std_lastK\"] = float(std_last_k(gap_series, k))\n",
        "    summary[\"W_std_lastK\"] = float(std_last_k(welfare_series, k))\n",
        "\n",
        "    summary[\"avg_reject_rate\"] = avg_reject\n",
        "    return summary\n",
        "\n",
        "\n",
        "# ----------------------------\n",
        "# 10) Main: baselines per defense + 2x2 runs + CSV\n",
        "# ----------------------------\n",
        "def main():\n",
        "    \"\"\"Run the E4 experiment end to end and write its artifacts.\n",
        "\n",
        "    Loads the FEMNIST clients, samples the public head / tail / support sets from the\n",
        "    client training data, runs one attack-free baseline per defense (which fixes that\n",
        "    defense's alarm threshold and baseline welfare), then runs every attack x defense\n",
        "    combination. Each run's histories are saved as ``.npz`` and all summary rows are\n",
        "    written to one CSV, both under ``CFG.out_dir``.\n",
        "\n",
        "    Raises:\n",
        "        RuntimeError: If the public head or tail evaluation set is empty.\n",
        "    \"\"\"\n",
        "    print(\"\\n=== E4 (No LEAF): FEMNIST | 2x2 Attacks x Defenses ===\")\n",
        "    print(f\"Device: {CFG.device}\")\n",
        "    print(f\"n_clients={CFG.n_clients}, clients_per_round={CFG.clients_per_round}, rounds={CFG.n_rounds}, local_epochs={CFG.local_epochs}\")\n",
        "    print(f\"malicious_frac={CFG.malicious_frac}, attack_start_round={CFG.attack_start_round}\")\n",
        "    print(f\"public_support_size={CFG.public_support_size} (CKA)\\n\")\n",
        "\n",
        "    # Create the output directory up front: the first file is only written after a\n",
        "    # complete baseline run, so a missing directory would otherwise waste that run.\n",
        "    if CFG.out_dir:\n",
        "        os.makedirs(CFG.out_dir, exist_ok=True)\n",
        "\n",
        "    rng = np.random.default_rng(CFG.seed)\n",
        "\n",
        "    # 1) Load FEMNIST clients (writer_id = client) without LEAF\n",
        "    client_train_sets, client_test_sets, num_classes = build_femnist_clients_from_hf(\n",
        "        n_clients=CFG.n_clients,\n",
        "        train_ratio=0.8,\n",
        "        seed=CFG.seed,\n",
        "        min_train=CFG.min_train_samples_per_client,\n",
        "        min_test=CFG.min_test_samples_per_client,\n",
        "        max_try=CFG.max_partition_id_try,\n",
        "    )\n",
        "    print(f\"Collected clients: {len(client_train_sets)}\")\n",
        "    print(f\"Inferred num_classes: {num_classes}\")\n",
        "\n",
        "    # 2) Public sets from client train data\n",
        "    public_head = sample_public_from_clients(client_train_sets, CFG.head_classes, CFG.public_head_eval_size, rng)\n",
        "    public_tail = sample_public_from_clients(client_train_sets, CFG.tail_classes, CFG.public_tail_eval_size, rng)\n",
        "    public_support = sample_public_from_clients(client_train_sets, CFG.head_classes, CFG.public_support_size, rng)\n",
        "    if len(public_support) == 0:\n",
        "        # No head-class samples available: fall back to sampling from all classes.\n",
        "        public_support = sample_public_from_clients(client_train_sets, tuple(range(num_classes)), CFG.public_support_size, rng)\n",
        "\n",
        "    if len(public_head) == 0 or len(public_tail) == 0:\n",
        "        raise RuntimeError(\n",
        "            \"Public head/tail eval sets are empty. \"\n",
        "            \"Adjust head_classes/tail_classes or increase public_*_eval_size.\"\n",
        "        )\n",
        "\n",
        "    defenses = [\"FedCC\", \"AttackAdaptiveAggregation\"]\n",
        "    attacks = [\"POISONEDFL\", \"BACKDOOR\"]\n",
        "\n",
        "    # Arguments shared by every run_fl_once call below.\n",
        "    shared_kwargs = dict(\n",
        "        num_classes=num_classes,\n",
        "        client_train_sets=client_train_sets,\n",
        "        client_test_sets=client_test_sets,\n",
        "        public_support=public_support,\n",
        "        public_head_eval=public_head,\n",
        "        public_tail_eval=public_tail,\n",
        "    )\n",
        "\n",
        "    seed_offset = 0  # incremented after every run so each run gets its own seed\n",
        "    baseline_cache: Dict[str, Tuple[float, float]] = {}  # defense -> (thr, baseline_W)\n",
        "    rows: List[Dict] = []\n",
        "\n",
        "    # 3) Baselines per defense (attack=NONE), used to set alarm threshold per defense\n",
        "    for defense in defenses:\n",
        "        print(f\"\\n--- BASELINE (NONE) | defense={defense} ---\")\n",
        "        base = run_fl_once(\n",
        "            defense_name=defense,\n",
        "            attack_name=\"NONE\",\n",
        "            seed_offset=seed_offset,\n",
        "            **shared_kwargs,\n",
        "        )\n",
        "        seed_offset += 1\n",
        "\n",
        "        thr = compute_alarm_threshold_from_baseline(base[\"gap_hist\"])\n",
        "        Wb = mean_last_k(base[\"W_tail_hist\"], CFG.tail_k)\n",
        "        baseline_cache[defense] = (thr, Wb)\n",
        "\n",
        "        rows.append(summarize_run(base, thr=thr, baseline_W=None))\n",
        "\n",
        "        np.savez(os.path.join(CFG.out_dir, f\"E4_hist_BASELINE_{defense}.npz\"), **base)\n",
        "\n",
        "    # 4) 2x2 runs\n",
        "    for attack in attacks:\n",
        "        for defense in defenses:\n",
        "            print(f\"\\n=== RUN | attack={attack} | defense={defense} ===\")\n",
        "            thr, Wb = baseline_cache[defense]\n",
        "\n",
        "            out = run_fl_once(\n",
        "                defense_name=defense,\n",
        "                attack_name=attack,\n",
        "                seed_offset=seed_offset,\n",
        "                **shared_kwargs,\n",
        "            )\n",
        "            seed_offset += 1\n",
        "\n",
        "            rows.append(summarize_run(out, thr=thr, baseline_W=Wb))\n",
        "\n",
        "            np.savez(os.path.join(CFG.out_dir, f\"E4_hist_{attack}_{defense}.npz\"), **out)\n",
        "\n",
        "    # 5) Save CSV\n",
        "    csv_path = os.path.join(CFG.out_dir, CFG.csv_name)\n",
        "    with open(csv_path, \"w\", newline=\"\", encoding=\"utf-8\") as f:\n",
        "        fieldnames = list(rows[0].keys())\n",
        "        w = csv.DictWriter(f, fieldnames=fieldnames)\n",
        "        w.writeheader()\n",
        "        w.writerows(rows)\n",
        "\n",
        "    print(f\"\\nSaved CSV: {csv_path}\")\n",
        "    print(f\"Saved histories (.npz) in: {CFG.out_dir}\")\n",
        "    print(\"Done.\")\n",
        "\n",
        "\n",
        "# Entry point: run the whole experiment when this code executes as the main module.\n",
        "if __name__ == \"__main__\":\n",
        "    main()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000,
          "referenced_widgets": [
            "a995221ad90d4a1e8b2f292a2144836a",
            "39f80b41a398473aa0753fb1f9a85e05",
            "f453a12668c44b70bd027828cb22d3f9",
            "ab43eae1b2254714a4c6c2e8c5946b08",
            "2f886b935c124a83a9e6bb5b4eb1318c",
            "77c6ad26324b47aebd59c01cdc5e9ffb",
            "8abff24900964d7fb8808bc379f15715",
            "690f534690aa4360b6d25e29a45dbbe6",
            "f4fdf8d55afc4f1d8778c97ad4c0065f",
            "b34dec285b064459956df5cafff604dd",
            "1b41483d67474afd8c9c489b719ccf1d",
            "87a65870c8a2413a955fa7f6580d06dc",
            "860a84b8925145bbbb03a5707a29fee2",
            "2534511ff7bf4c26be3798f872332ecd",
            "bcf1434503c64467b1933816504b56ff",
            "5492c4d0bdf94c6cb11f3b5a2588ff6f",
            "245a95fee2fc40bab018983aed796c94",
            "af479a90137e442398fa4aed086cd156",
            "e61976050df44f8d9ba15393b39472d1",
            "7ca1b8c20fde4049973f521ce5721418",
            "35e5a2e314ea4e22ad64fef2097a66b1",
            "756b21451452433599c8320de3e47129",
            "ad3640ee50f0407ea4276796c2eb5ccc",
            "5fb3c0b71298475db284141676bbd5f7",
            "f1d8b4890d934be2a415725bf3f6b20b",
            "6fe543e0e6354dbc800383abcfbdaf97",
            "f0c814e6c640436781724cac8c502997",
            "d164e8a25b15426daa6d358c066c3501",
            "07666125af1543ada56d7c69fd3e09c6",
            "c6c95710fae848d28f636e9694ca53b1",
            "74c2d043833d48d9928cb24cc5992e38",
            "60954716da3c4de689a229ee1124b3e3",
            "ed0881322e2647bc8418f8ef738cd109"
          ]
        },
        "id": "yxbsbfYo-LfP",
        "outputId": "3e28f18a-9040-4550-8da1-e84e80392377"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "=== E4 (No LEAF): FEMNIST | 2x2 Attacks x Defenses ===\n",
            "Device: cuda\n",
            "n_clients=32, clients_per_round=12, rounds=100, local_epochs=1\n",
            "malicious_frac=0.3, attack_start_round=5\n",
            "public_support_size=128 (CKA)\n",
            "\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: \n",
            "The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
            "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
            "You will be able to reuse this secret in all of your notebooks.\n",
            "Please note that authentication is recommended but still optional to access public models or datasets.\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "README.md: 0.00B [00:00, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "a995221ad90d4a1e8b2f292a2144836a"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "data/train-00000-of-00001.parquet:   0%|          | 0.00/201M [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "87a65870c8a2413a955fa7f6580d06dc"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Generating train split:   0%|          | 0/814277 [00:00<?, ? examples/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "ad3640ee50f0407ea4276796c2eb5ccc"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Generating partition_id_to_indices: 814277it [00:00, 1702870.76it/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collected clients: 32\n",
            "Inferred num_classes: 62\n",
            "\n",
            "--- BASELINE (NONE) | defense=FedCC ---\n",
            "[NONE | FedCC] Round   1/100 | M_head=0.0000 M_tail=0.0393 W_tail=0.0427 gap=-0.0427 | rej_rate=0.25\n",
            "[NONE | FedCC] Round  10/100 | M_head=0.0000 M_tail=0.0776 W_tail=0.0879 gap=-0.0879 | rej_rate=0.25\n",
            "[NONE | FedCC] Round  20/100 | M_head=0.0000 M_tail=0.0776 W_tail=0.0879 gap=-0.0879 | rej_rate=0.25\n",
            "[NONE | FedCC] Round  30/100 | M_head=0.0000 M_tail=0.1440 W_tail=0.1553 gap=-0.1553 | rej_rate=0.25\n",
            "[NONE | FedCC] Round  40/100 | M_head=0.0777 M_tail=0.1628 W_tail=0.1777 gap=-0.1000 | rej_rate=0.25\n",
            "[NONE | FedCC] Round  50/100 | M_head=0.2914 M_tail=0.2266 W_tail=0.2311 gap=0.0603 | rej_rate=0.25\n",
            "[NONE | FedCC] Round  60/100 | M_head=0.5932 M_tail=0.3044 W_tail=0.3089 gap=0.2843 | rej_rate=0.25\n",
            "[NONE | FedCC] Round  70/100 | M_head=0.7029 M_tail=0.3770 W_tail=0.3642 gap=0.3386 | rej_rate=0.25\n",
            "[NONE | FedCC] Round  80/100 | M_head=0.6586 M_tail=0.4573 W_tail=0.4557 gap=0.2029 | rej_rate=0.25\n",
            "[NONE | FedCC] Round  90/100 | M_head=0.7493 M_tail=0.5029 W_tail=0.4994 gap=0.2499 | rej_rate=0.25\n",
            "[NONE | FedCC] Round 100/100 | M_head=0.7615 M_tail=0.5540 W_tail=0.5449 gap=0.2167 | rej_rate=0.25\n",
            "\n",
            "--- BASELINE (NONE) | defense=AttackAdaptiveAggregation ---\n",
            "[NONE | AttackAdaptiveAggregation] Round   1/100 | M_head=0.0000 M_tail=0.0703 W_tail=0.0785 gap=-0.0785 | rej_rate=0.00\n",
            "[NONE | AttackAdaptiveAggregation] Round  10/100 | M_head=0.0000 M_tail=0.0791 W_tail=0.0794 gap=-0.0794 | rej_rate=0.00\n",
            "[NONE | AttackAdaptiveAggregation] Round  20/100 | M_head=0.0000 M_tail=0.1418 W_tail=0.1418 gap=-0.1418 | rej_rate=0.00\n",
            "[NONE | AttackAdaptiveAggregation] Round  30/100 | M_head=0.0014 M_tail=0.1521 W_tail=0.1609 gap=-0.1595 | rej_rate=0.00\n",
            "[NONE | AttackAdaptiveAggregation] Round  40/100 | M_head=0.0450 M_tail=0.1709 W_tail=0.1788 gap=-0.1339 | rej_rate=0.00\n",
            "[NONE | AttackAdaptiveAggregation] Round  50/100 | M_head=0.2878 M_tail=0.2046 W_tail=0.2117 gap=0.0760 | rej_rate=0.00\n",
            "[NONE | AttackAdaptiveAggregation] Round  60/100 | M_head=0.4371 M_tail=0.2393 W_tail=0.2336 gap=0.2035 | rej_rate=0.00\n",
            "[NONE | AttackAdaptiveAggregation] Round  70/100 | M_head=0.4950 M_tail=0.2825 W_tail=0.2852 gap=0.2098 | rej_rate=0.00\n",
            "[NONE | AttackAdaptiveAggregation] Round  80/100 | M_head=0.6453 M_tail=0.3376 W_tail=0.3310 gap=0.3144 | rej_rate=0.00\n",
            "[NONE | AttackAdaptiveAggregation] Round  90/100 | M_head=0.6478 M_tail=0.3621 W_tail=0.3430 gap=0.3048 | rej_rate=0.00\n",
            "[NONE | AttackAdaptiveAggregation] Round 100/100 | M_head=0.6061 M_tail=0.4050 W_tail=0.4086 gap=0.1975 | rej_rate=0.00\n",
            "\n",
            "=== RUN | attack=POISONEDFL | defense=FedCC ===\n",
            "[POISONEDFL | FedCC] Round   1/100 | M_head=0.0000 M_tail=0.0334 W_tail=0.0332 gap=-0.0332 | rej_rate=0.25\n",
            "[POISONEDFL | FedCC] Round  10/100 | M_head=0.0000 M_tail=0.0776 W_tail=0.0879 gap=-0.0879 | rej_rate=0.25\n",
            "[POISONEDFL | FedCC] Round  20/100 | M_head=0.0000 M_tail=0.0791 W_tail=0.0787 gap=-0.0787 | rej_rate=0.25\n",
            "[POISONEDFL | FedCC] Round  30/100 | M_head=0.1665 M_tail=0.0417 W_tail=0.0406 gap=0.1259 | rej_rate=0.25\n",
            "[POISONEDFL | FedCC] Round  40/100 | M_head=0.0903 M_tail=0.1421 W_tail=0.1382 gap=-0.0479 | rej_rate=0.25\n",
            "[POISONEDFL | FedCC] Round  50/100 | M_head=0.3086 M_tail=0.0229 W_tail=0.0253 gap=0.2834 | rej_rate=0.25\n",
            "[POISONEDFL | FedCC] Round  60/100 | M_head=0.4863 M_tail=0.0273 W_tail=0.0303 gap=0.4560 | rej_rate=0.25\n",
            "[POISONEDFL | FedCC] Round  70/100 | M_head=0.2950 M_tail=0.3606 W_tail=0.3638 gap=-0.0689 | rej_rate=0.25\n",
            "[POISONEDFL | FedCC] Round  80/100 | M_head=0.7450 M_tail=0.0000 W_tail=0.0000 gap=0.7450 | rej_rate=0.25\n",
            "[POISONEDFL | FedCC] Round  90/100 | M_head=0.7270 M_tail=0.1562 W_tail=0.1505 gap=0.5765 | rej_rate=0.25\n",
            "[POISONEDFL | FedCC] Round 100/100 | M_head=0.7335 M_tail=0.4719 W_tail=0.4539 gap=0.2796 | rej_rate=0.25\n",
            "\n",
            "=== RUN | attack=POISONEDFL | defense=AttackAdaptiveAggregation ===\n",
            "[POISONEDFL | AttackAdaptiveAggregation] Round   1/100 | M_head=0.0000 M_tail=0.0354 W_tail=0.0309 gap=-0.0309 | rej_rate=0.00\n",
            "[POISONEDFL | AttackAdaptiveAggregation] Round  10/100 | M_head=0.0000 M_tail=0.1248 W_tail=0.1302 gap=-0.1302 | rej_rate=0.00\n",
            "[POISONEDFL | AttackAdaptiveAggregation] Round  20/100 | M_head=0.0000 M_tail=0.1177 W_tail=0.1118 gap=-0.1118 | rej_rate=0.00\n",
            "[POISONEDFL | AttackAdaptiveAggregation] Round  30/100 | M_head=0.1227 M_tail=0.1584 W_tail=0.1656 gap=-0.0429 | rej_rate=0.00\n",
            "[POISONEDFL | AttackAdaptiveAggregation] Round  40/100 | M_head=0.3068 M_tail=0.1711 W_tail=0.1828 gap=0.1241 | rej_rate=0.00\n",
            "[POISONEDFL | AttackAdaptiveAggregation] Round  50/100 | M_head=0.4975 M_tail=0.2478 W_tail=0.2622 gap=0.2353 | rej_rate=0.00\n",
            "[POISONEDFL | AttackAdaptiveAggregation] Round  60/100 | M_head=0.6014 M_tail=0.3020 W_tail=0.3059 gap=0.2955 | rej_rate=0.00\n",
            "[POISONEDFL | AttackAdaptiveAggregation] Round  70/100 | M_head=0.6942 M_tail=0.4011 W_tail=0.3956 gap=0.2986 | rej_rate=0.00\n",
            "[POISONEDFL | AttackAdaptiveAggregation] Round  80/100 | M_head=0.8058 M_tail=0.4546 W_tail=0.4289 gap=0.3769 | rej_rate=0.00\n",
            "[POISONEDFL | AttackAdaptiveAggregation] Round  90/100 | M_head=0.8209 M_tail=0.4688 W_tail=0.4574 gap=0.3635 | rej_rate=0.00\n",
            "[POISONEDFL | AttackAdaptiveAggregation] Round 100/100 | M_head=0.7737 M_tail=0.5515 W_tail=0.5360 gap=0.2377 | rej_rate=0.00\n",
            "\n",
            "=== RUN | attack=BACKDOOR | defense=FedCC ===\n",
            "[BACKDOOR | FedCC] Round   1/100 | M_head=0.0446 M_tail=0.0681 W_tail=0.0805 gap=-0.0359 | rej_rate=0.25\n",
            "[BACKDOOR | FedCC] Round  10/100 | M_head=0.1040 M_tail=0.0000 W_tail=0.0000 gap=0.1040 | rej_rate=0.25\n",
            "[BACKDOOR | FedCC] Round  20/100 | M_head=0.0000 M_tail=0.1528 W_tail=0.1501 gap=-0.1501 | rej_rate=0.25\n",
            "[BACKDOOR | FedCC] Round  30/100 | M_head=0.1040 M_tail=0.0688 W_tail=0.0701 gap=0.0339 | rej_rate=0.25\n",
            "[BACKDOOR | FedCC] Round  40/100 | M_head=0.0547 M_tail=0.2156 W_tail=0.2297 gap=-0.1750 | rej_rate=0.25\n",
            "[BACKDOOR | FedCC] Round  50/100 | M_head=0.4874 M_tail=0.3450 W_tail=0.3531 gap=0.1343 | rej_rate=0.25\n",
            "[BACKDOOR | FedCC] Round  60/100 | M_head=0.5424 M_tail=0.2129 W_tail=0.2173 gap=0.3251 | rej_rate=0.25\n",
            "[BACKDOOR | FedCC] Round  70/100 | M_head=0.6122 M_tail=0.4456 W_tail=0.4414 gap=0.1708 | rej_rate=0.25\n",
            "[BACKDOOR | FedCC] Round  80/100 | M_head=0.7014 M_tail=0.5166 W_tail=0.5143 gap=0.1871 | rej_rate=0.25\n",
            "[BACKDOOR | FedCC] Round  90/100 | M_head=0.2633 M_tail=0.1665 W_tail=0.1845 gap=0.0788 | rej_rate=0.25\n",
            "[BACKDOOR | FedCC] Round 100/100 | M_head=0.7701 M_tail=0.5688 W_tail=0.5617 gap=0.2085 | rej_rate=0.25\n",
            "\n",
            "=== RUN | attack=BACKDOOR | defense=AttackAdaptiveAggregation ===\n",
            "[BACKDOOR | AttackAdaptiveAggregation] Round   1/100 | M_head=0.0000 M_tail=0.0776 W_tail=0.0879 gap=-0.0879 | rej_rate=0.00\n",
            "[BACKDOOR | AttackAdaptiveAggregation] Round  10/100 | M_head=0.0000 M_tail=0.0776 W_tail=0.0879 gap=-0.0879 | rej_rate=0.00\n",
            "[BACKDOOR | AttackAdaptiveAggregation] Round  20/100 | M_head=0.0000 M_tail=0.0776 W_tail=0.0879 gap=-0.0879 | rej_rate=0.00\n",
            "[BACKDOOR | AttackAdaptiveAggregation] Round  30/100 | M_head=0.0000 M_tail=0.0791 W_tail=0.0787 gap=-0.0787 | rej_rate=0.00\n",
            "[BACKDOOR | AttackAdaptiveAggregation] Round  40/100 | M_head=0.0439 M_tail=0.0776 W_tail=0.0780 gap=-0.0341 | rej_rate=0.00\n",
            "[BACKDOOR | AttackAdaptiveAggregation] Round  50/100 | M_head=0.2022 M_tail=0.0955 W_tail=0.1132 gap=0.0890 | rej_rate=0.00\n",
            "[BACKDOOR | AttackAdaptiveAggregation] Round  60/100 | M_head=0.4579 M_tail=0.2000 W_tail=0.2138 gap=0.2441 | rej_rate=0.00\n",
            "[BACKDOOR | AttackAdaptiveAggregation] Round  70/100 | M_head=0.5806 M_tail=0.2893 W_tail=0.2979 gap=0.2827 | rej_rate=0.00\n",
            "[BACKDOOR | AttackAdaptiveAggregation] Round  80/100 | M_head=0.6151 M_tail=0.3726 W_tail=0.3910 gap=0.2241 | rej_rate=0.00\n",
            "[BACKDOOR | AttackAdaptiveAggregation] Round  90/100 | M_head=0.6813 M_tail=0.4407 W_tail=0.4324 gap=0.2489 | rej_rate=0.00\n",
            "[BACKDOOR | AttackAdaptiveAggregation] Round 100/100 | M_head=0.7183 M_tail=0.4265 W_tail=0.4159 gap=0.3024 | rej_rate=0.00\n",
            "\n",
            "Saved CSV: ./E4_outputs_noLEAF/E4_FEMNIST_2x2_results.csv\n",
            "Saved histories (.npz) in: ./E4_outputs_noLEAF\n",
            "Done.\n"
          ]
        }
      ]
    }
  ]
}