{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "DUQ_FM_final_c.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU",
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "033c2cd6f6fc44aca937e213d7505d4a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_221c3d31d4b64fc1ac17b75b050a1fd4",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_b94137bad9ed4bc3bbe7e2f8837a3f8d",
              "IPY_MODEL_e60f689aaa8d464081cba497657142d2"
            ]
          }
        },
        "221c3d31d4b64fc1ac17b75b050a1fd4": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "b94137bad9ed4bc3bbe7e2f8837a3f8d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_714ce153aa6c43c1a0e014c85783f436",
            "_dom_classes": [],
            "description": "Epoch [1/30]: [468/468] 100%",
            "_model_name": "FloatProgressModel",
            "bar_style": "",
            "max": 468,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 468,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_65f61c473a4e4ad1ac6452c2b684ce72"
          }
        },
        "e60f689aaa8d464081cba497657142d2": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_99ef99ca67944d2d9e4e91fb8628bee4",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " [00:19&lt;00:00]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_a6f5055dbc234db9955e3bd9db525681"
          }
        },
        "714ce153aa6c43c1a0e014c85783f436": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "65f61c473a4e4ad1ac6452c2b684ce72": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "99ef99ca67944d2d9e4e91fb8628bee4": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "a6f5055dbc234db9955e3bd9db525681": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "8da51da9caf84e7aa747548e2565093e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_e7705347db02484fa65848784f054d02",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_2829a70558c64051902098eb1e4ee2ed",
              "IPY_MODEL_c9be8a65cf8845acbf664960c54e0040"
            ]
          }
        },
        "e7705347db02484fa65848784f054d02": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "2829a70558c64051902098eb1e4ee2ed": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_4b0d08efda4b41da92fb2f60eeda1a1b",
            "_dom_classes": [],
            "description": "Epoch [2/30]: [468/468] 100%",
            "_model_name": "FloatProgressModel",
            "bar_style": "",
            "max": 468,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 468,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_45f0bcddb03e47c2a021ebd694d566c3"
          }
        },
        "c9be8a65cf8845acbf664960c54e0040": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_50dcac658e1545499fafc75bc8879154",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " [00:19&lt;00:00]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_a21666dd846641a78583ebdf432b2c77"
          }
        },
        "4b0d08efda4b41da92fb2f60eeda1a1b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "45f0bcddb03e47c2a021ebd694d566c3": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "50dcac658e1545499fafc75bc8879154": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "a21666dd846641a78583ebdf432b2c77": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "c0f8db75a1c941c29980631db65a17b3": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_a2714e60df6e4912be025f79bcbc023d",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_b0a1609ef6f64ebe804417fb99d8bc77",
              "IPY_MODEL_b939c5da20b844feae3b64f3ad55d70d"
            ]
          }
        },
        "a2714e60df6e4912be025f79bcbc023d": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "b0a1609ef6f64ebe804417fb99d8bc77": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_14660d04c4104036b04529b5f8292864",
            "_dom_classes": [],
            "description": "Epoch [3/30]: [468/468] 100%",
            "_model_name": "FloatProgressModel",
            "bar_style": "",
            "max": 468,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 468,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_3dc1f878f23d447d961ad207d3576486"
          }
        },
        "b939c5da20b844feae3b64f3ad55d70d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_608ce2322fd24bc38fe0f7ec5a7c46af",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " [00:19&lt;00:00]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_e9bd1e9df8614a849fcdae1da0ca0369"
          }
        },
        "14660d04c4104036b04529b5f8292864": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "3dc1f878f23d447d961ad207d3576486": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "608ce2322fd24bc38fe0f7ec5a7c46af": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "e9bd1e9df8614a849fcdae1da0ca0369": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "dd51bb586a9e47f9bd105dc60920b86d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_feabfa66c8874029bbcbe99e38b9d11c",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_a5a04cfb26dd4bbdbde532a5d14578c1",
              "IPY_MODEL_d306ca83a62a40069a916204e7952375"
            ]
          }
        },
        "feabfa66c8874029bbcbe99e38b9d11c": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "a5a04cfb26dd4bbdbde532a5d14578c1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_c491ad93b96f4430a1dd01f0ec4fbd74",
            "_dom_classes": [],
            "description": "Epoch [4/30]: [468/468] 100%",
            "_model_name": "FloatProgressModel",
            "bar_style": "",
            "max": 468,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 468,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_e71e9639b87c4816af9997f931c1c9fb"
          }
        },
        "d306ca83a62a40069a916204e7952375": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_abb475e1b5d34ba98ffc47983b88f5c8",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " [00:19&lt;00:00]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_ed2aa5b6830a448285c13431400374d2"
          }
        },
        "c491ad93b96f4430a1dd01f0ec4fbd74": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "e71e9639b87c4816af9997f931c1c9fb": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "abb475e1b5d34ba98ffc47983b88f5c8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "ed2aa5b6830a448285c13431400374d2": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "ec81b7ede40e4b2ca267e2bbda56b545": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_ef4da194d5d2476d8fccd512d305942b",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_6ba7980351dd4b20b735cfc1719dd517",
              "IPY_MODEL_8289423d4ae844d6ae6fec002a320304"
            ]
          }
        },
        "ef4da194d5d2476d8fccd512d305942b": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "6ba7980351dd4b20b735cfc1719dd517": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_296c8d15ee3e4d5fac02f74c7aa0e70c",
            "_dom_classes": [],
            "description": "Epoch [5/30]: [468/468] 100%",
            "_model_name": "FloatProgressModel",
            "bar_style": "",
            "max": 468,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 468,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_46b7b35419e242fc8c6be97242ad23a1"
          }
        },
        "8289423d4ae844d6ae6fec002a320304": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_1cec65c1933c4081a83d7f5f9d824def",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " [00:19&lt;00:00]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_680e93a1cdcb45a29feca9f7aa2b0e7b"
          }
        },
        "296c8d15ee3e4d5fac02f74c7aa0e70c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "46b7b35419e242fc8c6be97242ad23a1": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "1cec65c1933c4081a83d7f5f9d824def": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "680e93a1cdcb45a29feca9f7aa2b0e7b": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "99b826a4760843d589b54d926815cc73": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_00040273d9ba4b158a2fcd398e230003",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_0fdae08e52a04b04aa773df3f3982dc5",
              "IPY_MODEL_6101a9827d4a4414b93ce9a29d6cd104"
            ]
          }
        },
        "00040273d9ba4b158a2fcd398e230003": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "0fdae08e52a04b04aa773df3f3982dc5": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_dc2c4b6dc7654285b8c7ff7c83eb79f7",
            "_dom_classes": [],
            "description": "Epoch [6/30]: [468/468] 100%",
            "_model_name": "FloatProgressModel",
            "bar_style": "",
            "max": 468,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 468,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_f2626ff8d77f42b88cda64158fea4999"
          }
        },
        "6101a9827d4a4414b93ce9a29d6cd104": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_ce55cf43640e47f19491bb7f45a21f7f",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " [00:19&lt;00:00]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_547bad7b18774f4aaed0a66b590c1a09"
          }
        },
        "dc2c4b6dc7654285b8c7ff7c83eb79f7": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "f2626ff8d77f42b88cda64158fea4999": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "ce55cf43640e47f19491bb7f45a21f7f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "547bad7b18774f4aaed0a66b590c1a09": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "8115b4007c6a4b099f10965d25771af7": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_cbc0d575106d44df84a5681e0c87173c",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_6d10f9eb02414244b44ed572a6cd9361",
              "IPY_MODEL_0b4fb3d18eaa46e3bfc65cda8ac2334d"
            ]
          }
        },
        "cbc0d575106d44df84a5681e0c87173c": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "6d10f9eb02414244b44ed572a6cd9361": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_b2ce33b157f543c2be1cc23ede617dc7",
            "_dom_classes": [],
            "description": "Epoch [7/30]: [468/468] 100%",
            "_model_name": "FloatProgressModel",
            "bar_style": "",
            "max": 468,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 468,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_fa05a73b73d1450d92d0e816c873eb92"
          }
        },
        "0b4fb3d18eaa46e3bfc65cda8ac2334d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_7fb75ed79d8c494c8b19eed1bc846a08",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " [00:19&lt;00:00]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_9dce5d15234d4325b7dbb40b070dedca"
          }
        },
        "b2ce33b157f543c2be1cc23ede617dc7": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "fa05a73b73d1450d92d0e816c873eb92": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "7fb75ed79d8c494c8b19eed1bc846a08": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "9dce5d15234d4325b7dbb40b070dedca": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "d6976ef7a70d40fe84c0e4a8ea823bdf": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_a44624dae9aa4fb5a87c59a1dd993aa7",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_5d5723a921f54b669ef5c41d8db7c135",
              "IPY_MODEL_9b383a99f8984fa0aa7b369b1039140e"
            ]
          }
        },
        "a44624dae9aa4fb5a87c59a1dd993aa7": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "5d5723a921f54b669ef5c41d8db7c135": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_913a90cc9f6b4638bcf99c4aab0ae696",
            "_dom_classes": [],
            "description": "Epoch [8/30]: [195/468]  42%",
            "_model_name": "FloatProgressModel",
            "bar_style": "",
            "max": 468,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 195,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_e631103843114d52b776306a3f5c8442"
          }
        },
        "9b383a99f8984fa0aa7b369b1039140e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_18df2292a7b34ea584db828050a7ddcb",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " [00:08&lt;00:11]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_c98d4169e44e49cfa8f22df3225c683b"
          }
        },
        "913a90cc9f6b4638bcf99c4aab0ae696": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "e631103843114d52b776306a3f5c8442": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "18df2292a7b34ea584db828050a7ddcb": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "c98d4169e44e49cfa8f22df3225c683b": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        }
      }
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "2lppBolPJ60z",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "edc8c709-7a4b-45d7-febf-768e633e8b7a"
      },
      "source": [
        "# Fetch the notMNIST archive into data/; -f makes curl fail on an HTTP error instead of saving the error page as the .mat file, -L follows redirects.\n",
        "!mkdir -p data && cd data && curl -fL -O \"http://yaroslavvb.com/upload/notMNIST/notMNIST_small.mat\"\n",
        "\n",
        "# %pip installs into the environment of the running kernel.\n",
        "# The recorded run of this notebook resolved pytorch-ignite 0.4.2 with torch 1.7.0; pin those versions to reproduce it exactly.\n",
        "%pip install pytorch-ignite"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current\n",
            "                                 Dload  Upload   Total   Spent    Left  Speed\n",
            "100  112M  100  112M    0     0  12.8M      0  0:00:08  0:00:08 --:--:-- 16.7M\n",
            "Requirement already satisfied: pytorch-ignite in /usr/local/lib/python3.6/dist-packages (0.4.2)\n",
            "Requirement already satisfied: torch<2,>=1.3 in /usr/local/lib/python3.6/dist-packages (from pytorch-ignite) (1.7.0+cu101)\n",
            "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch<2,>=1.3->pytorch-ignite) (0.16.0)\n",
            "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch<2,>=1.3->pytorch-ignite) (3.7.4.3)\n",
            "Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from torch<2,>=1.3->pytorch-ignite) (0.8)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch<2,>=1.3->pytorch-ignite) (1.18.5)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hdvCGqwyfbcY"
      },
      "source": [
        "import random\n",
        "import numpy as np\n",
        "\n",
        "import torch\n",
        "import torch.utils.data\n",
        "from torch.nn import functional as F\n",
        "\n",
        "from ignite.engine import Events, Engine\n",
        "from ignite.metrics import Accuracy, Loss\n",
        "\n",
        "from ignite.contrib.handlers.tqdm_logger import ProgressBar\n",
        "\n",
        "from utils.evaluate_ood import (\n",
        "    get_fashionmnist_mnist_ood,\n",
        "    get_fashionmnist_notmnist_ood,\n",
        ")\n",
        "from utils.datasets import FastFashionMNIST, get_FashionMNIST\n",
        "from utils.cnn_duq import CNN_DUQ\n"
      ],
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tJOnZiIwRjnw"
      },
      "source": [
        "# Module-level handle to the most recently built network; train_model rebinds it via `global model`.\n",
        "model={}\n",
        "def train_model(l_gradient_penalty, length_scale, final_model,epochs,centroidz, seed=None):\n",
        "    \"\"\"Train a CNN_DUQ classifier on FashionMNIST and measure its accuracy.\n",
        "\n",
        "    The training loss is binary cross-entropy plus ``l_gradient_penalty`` times a\n",
        "    two-sided gradient penalty on the inputs. Every 5 epochs the validation\n",
        "    metrics and the MNIST / NotMNIST AUROC values are printed.\n",
        "\n",
        "    The model and all batches are moved with ``.cuda()``, so a CUDA device is\n",
        "    required. The network is also published as the module-level ``model``.\n",
        "\n",
        "    Args:\n",
        "        l_gradient_penalty: Weight of the gradient penalty term in the loss.\n",
        "        length_scale: Passed through to ``CNN_DUQ``.\n",
        "        final_model: If true, train on the whole training set and use the test\n",
        "            set for validation; otherwise hold out the last 5000 shuffled\n",
        "            training indices as the validation set.\n",
        "        epochs: Number of epochs the trainer runs for.\n",
        "        centroidz: Passed through to ``CNN_DUQ`` as its embedding size.\n",
        "        seed: Optional integer. When given, the ``random``, NumPy and torch\n",
        "            generators are seeded before the split and the model are created.\n",
        "            The default ``None`` leaves all generators untouched.\n",
        "\n",
        "    Returns:\n",
        "        ``(model, val_accuracy, test_accuracy)`` where the accuracies come from\n",
        "        a final evaluator pass over the validation and test loaders.\n",
        "    \"\"\"\n",
        "    if seed is not None:\n",
        "        random.seed(seed)\n",
        "        np.random.seed(seed)\n",
        "        torch.manual_seed(seed)\n",
        "\n",
        "    input_size = 28\n",
        "    num_classes = 10\n",
        "    embedding_size = centroidz\n",
        "    learnable_length_scale = False\n",
        "    gamma = 0.999\n",
        "    holdout_size = 5000  # validation examples held out when final_model is false\n",
        "\n",
        "    # FashionMNIST train / test splits\n",
        "    dataset = FastFashionMNIST(\"data/\", train=True, download=True)\n",
        "    test_dataset = FastFashionMNIST(\"data/\", train=False, download=True)\n",
        "\n",
        "    if final_model:\n",
        "        train_dataset = dataset\n",
        "        val_dataset = test_dataset\n",
        "    else:\n",
        "        # Random train/validation split; sizes follow the dataset length\n",
        "        # (55000 / 5000 for a 60000-example training set).\n",
        "        idx = list(range(len(dataset)))\n",
        "        random.shuffle(idx)\n",
        "        split = len(dataset) - holdout_size\n",
        "        train_dataset = torch.utils.data.Subset(dataset, indices=idx[:split])\n",
        "        val_dataset = torch.utils.data.Subset(dataset, indices=idx[split:])\n",
        "\n",
        "    dl_train = torch.utils.data.DataLoader(\n",
        "        train_dataset, batch_size=128, shuffle=True, num_workers=0, drop_last=True\n",
        "    )\n",
        "\n",
        "    dl_val = torch.utils.data.DataLoader(\n",
        "        val_dataset, batch_size=2000, shuffle=False, num_workers=0\n",
        "    )\n",
        "\n",
        "    dl_test = torch.utils.data.DataLoader(\n",
        "        test_dataset, batch_size=2000, shuffle=False, num_workers=0\n",
        "    )\n",
        "\n",
        "\n",
        "    # Model (bound to the module-level name so it stays reachable from other\n",
        "    # cells, e.g. after an interrupted run)\n",
        "    global model\n",
        "    model = CNN_DUQ(\n",
        "        input_size,\n",
        "        num_classes,\n",
        "        embedding_size,\n",
        "        learnable_length_scale,\n",
        "        length_scale,\n",
        "        gamma,\n",
        "    )\n",
        "\n",
        "    model = model.cuda()\n",
        "\n",
        "    # Optimiser\n",
        "    optimizer = torch.optim.SGD(\n",
        "        model.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4\n",
        "    )\n",
        "\n",
        "    # The evaluator emits (y_pred, y, x, y_pred_sum); each metric picks the\n",
        "    # pieces it needs through one of the transforms below.\n",
        "    def output_transform_bce(output):\n",
        "        y_pred, y, _, _ = output\n",
        "        return y_pred, y\n",
        "\n",
        "    def output_transform_acc(output):\n",
        "        y_pred, y, _, _ = output\n",
        "        # Accuracy gets class indices, so undo the one-hot encoding.\n",
        "        return y_pred, torch.argmax(y, dim=1)\n",
        "\n",
        "    def output_transform_gp(output):\n",
        "        _, _, x, y_pred_sum = output\n",
        "        return x, y_pred_sum\n",
        "\n",
        "    def calc_gradient_penalty(x, y_pred_sum):\n",
        "        \"\"\"Return the mean of (||d y_pred_sum / d x||_2 - 1)^2 over the batch.\"\"\"\n",
        "        gradients = torch.autograd.grad(\n",
        "            outputs=y_pred_sum,\n",
        "            inputs=x,\n",
        "            grad_outputs=torch.ones_like(y_pred_sum),\n",
        "            create_graph=True,\n",
        "            retain_graph=True,\n",
        "        )[0]\n",
        "\n",
        "        gradients = gradients.flatten(start_dim=1)\n",
        "\n",
        "        # L2 norm\n",
        "        grad_norm = gradients.norm(2, dim=1)\n",
        "\n",
        "        # Two sided penalty\n",
        "        gradient_penalty = ((grad_norm - 1) ** 2).mean()\n",
        "\n",
        "        return gradient_penalty\n",
        "\n",
        "    def step(engine, batch):\n",
        "        \"\"\"Run one optimisation step, then update the model's embeddings.\"\"\"\n",
        "        model.train()\n",
        "        optimizer.zero_grad()\n",
        "\n",
        "        x, y = batch\n",
        "        y = F.one_hot(y, num_classes=num_classes).float()\n",
        "\n",
        "        x, y = x.cuda(), y.cuda()\n",
        "\n",
        "        # The gradient penalty differentiates the output w.r.t. the input.\n",
        "        x.requires_grad_(True)\n",
        "\n",
        "        _, y_pred = model(x)\n",
        "\n",
        "        loss = F.binary_cross_entropy(y_pred, y)\n",
        "        loss += l_gradient_penalty * calc_gradient_penalty(x, y_pred.sum(1))\n",
        "\n",
        "        x.requires_grad_(False)\n",
        "\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "        with torch.no_grad():\n",
        "            model.eval()\n",
        "            model.update_embeddings(x, y)\n",
        "\n",
        "        return loss.item()\n",
        "\n",
        "    def eval_step(engine, batch):\n",
        "        \"\"\"Forward one batch and return everything the attached metrics need.\"\"\"\n",
        "        model.eval()\n",
        "\n",
        "        x, y = batch\n",
        "        y = F.one_hot(y, num_classes=num_classes).float()\n",
        "\n",
        "        x, y = x.cuda(), y.cuda()\n",
        "\n",
        "        # Gradients stay enabled here because the gradient_penalty metric\n",
        "        # differentiates y_pred_sum with respect to x.\n",
        "        x.requires_grad_(True)\n",
        "\n",
        "        _, y_pred = model(x)\n",
        "\n",
        "        return y_pred, y, x, y_pred.sum(1)\n",
        "\n",
        "    trainer = Engine(step)\n",
        "    evaluator = Engine(eval_step)\n",
        "\n",
        "    Accuracy(output_transform=output_transform_acc).attach(evaluator, \"accuracy\")\n",
        "    Loss(F.binary_cross_entropy, output_transform=output_transform_bce).attach(\n",
        "        evaluator, \"bce\"\n",
        "    )\n",
        "    Loss(calc_gradient_penalty, output_transform=output_transform_gp).attach(\n",
        "        evaluator, \"gradient_penalty\"\n",
        "    )\n",
        "\n",
        "    scheduler = torch.optim.lr_scheduler.MultiStepLR(\n",
        "        optimizer, milestones=[10, 20], gamma=0.2\n",
        "    )\n",
        "\n",
        "    pbar = ProgressBar()\n",
        "    pbar.attach(trainer)\n",
        "\n",
        "    @trainer.on(Events.EPOCH_COMPLETED)\n",
        "    def log_results(trainer):\n",
        "        scheduler.step()\n",
        "\n",
        "        # Log every 5th epoch\n",
        "        if trainer.state.epoch % 5 == 0:\n",
        "            evaluator.run(dl_val)\n",
        "\n",
        "            # AUROC on FashionMNIST + Mnist / NotMnist (the accuracy the\n",
        "            # helpers also return is not reported here)\n",
        "            _, roc_auc_mnist = get_fashionmnist_mnist_ood(model)\n",
        "            _, roc_auc_notmnist = get_fashionmnist_notmnist_ood(model)\n",
        "            metrics = evaluator.state.metrics\n",
        "\n",
        "            print(\n",
        "                f\"Validation Results - Epoch: {trainer.state.epoch} \"\n",
        "                f\"Val_Acc: {metrics['accuracy']:.4f} \"\n",
        "                f\"BCE: {metrics['bce']:.2f} \"\n",
        "                f\"GP: {metrics['gradient_penalty']:.6f} \"\n",
        "                f\"AUROC MNIST: {roc_auc_mnist:.4f} \"\n",
        "                f\"AUROC NotMNIST: {roc_auc_notmnist:.2f} \"\n",
        "            )\n",
        "            print(f\"Sigma: {model.sigma}\")\n",
        "\n",
        "    # Train\n",
        "    trainer.run(dl_train, max_epochs=epochs)\n",
        "\n",
        "    # Validation\n",
        "    evaluator.run(dl_val)\n",
        "    val_accuracy = evaluator.state.metrics[\"accuracy\"]\n",
        "\n",
        "    # Test\n",
        "    evaluator.run(dl_test)\n",
        "    test_accuracy = evaluator.state.metrics[\"accuracy\"]\n",
        "\n",
        "    return model, val_accuracy, test_accuracy"
      ],
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tdAmrBrcSCRc",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 116,
          "referenced_widgets": [
            "033c2cd6f6fc44aca937e213d7505d4a",
            "221c3d31d4b64fc1ac17b75b050a1fd4",
            "b94137bad9ed4bc3bbe7e2f8837a3f8d",
            "e60f689aaa8d464081cba497657142d2",
            "714ce153aa6c43c1a0e014c85783f436",
            "65f61c473a4e4ad1ac6452c2b684ce72",
            "99ef99ca67944d2d9e4e91fb8628bee4",
            "a6f5055dbc234db9955e3bd9db525681",
            "8da51da9caf84e7aa747548e2565093e",
            "e7705347db02484fa65848784f054d02",
            "2829a70558c64051902098eb1e4ee2ed",
            "c9be8a65cf8845acbf664960c54e0040",
            "4b0d08efda4b41da92fb2f60eeda1a1b",
            "45f0bcddb03e47c2a021ebd694d566c3",
            "50dcac658e1545499fafc75bc8879154",
            "a21666dd846641a78583ebdf432b2c77",
            "c0f8db75a1c941c29980631db65a17b3",
            "a2714e60df6e4912be025f79bcbc023d",
            "b0a1609ef6f64ebe804417fb99d8bc77",
            "b939c5da20b844feae3b64f3ad55d70d",
            "14660d04c4104036b04529b5f8292864",
            "3dc1f878f23d447d961ad207d3576486",
            "608ce2322fd24bc38fe0f7ec5a7c46af",
            "e9bd1e9df8614a849fcdae1da0ca0369",
            "dd51bb586a9e47f9bd105dc60920b86d",
            "feabfa66c8874029bbcbe99e38b9d11c",
            "a5a04cfb26dd4bbdbde532a5d14578c1",
            "d306ca83a62a40069a916204e7952375",
            "c491ad93b96f4430a1dd01f0ec4fbd74",
            "e71e9639b87c4816af9997f931c1c9fb",
            "abb475e1b5d34ba98ffc47983b88f5c8",
            "ed2aa5b6830a448285c13431400374d2",
            "ec81b7ede40e4b2ca267e2bbda56b545",
            "ef4da194d5d2476d8fccd512d305942b",
            "6ba7980351dd4b20b735cfc1719dd517",
            "8289423d4ae844d6ae6fec002a320304",
            "296c8d15ee3e4d5fac02f74c7aa0e70c",
            "46b7b35419e242fc8c6be97242ad23a1",
            "1cec65c1933c4081a83d7f5f9d824def",
            "680e93a1cdcb45a29feca9f7aa2b0e7b",
            "99b826a4760843d589b54d926815cc73",
            "00040273d9ba4b158a2fcd398e230003",
            "0fdae08e52a04b04aa773df3f3982dc5",
            "6101a9827d4a4414b93ce9a29d6cd104",
            "dc2c4b6dc7654285b8c7ff7c83eb79f7",
            "f2626ff8d77f42b88cda64158fea4999",
            "ce55cf43640e47f19491bb7f45a21f7f",
            "547bad7b18774f4aaed0a66b590c1a09",
            "8115b4007c6a4b099f10965d25771af7",
            "cbc0d575106d44df84a5681e0c87173c",
            "6d10f9eb02414244b44ed572a6cd9361",
            "0b4fb3d18eaa46e3bfc65cda8ac2334d",
            "b2ce33b157f543c2be1cc23ede617dc7",
            "fa05a73b73d1450d92d0e816c873eb92",
            "7fb75ed79d8c494c8b19eed1bc846a08",
            "9dce5d15234d4325b7dbb40b070dedca",
            "d6976ef7a70d40fe84c0e4a8ea823bdf",
            "a44624dae9aa4fb5a87c59a1dd993aa7",
            "5d5723a921f54b669ef5c41d8db7c135",
            "9b383a99f8984fa0aa7b369b1039140e",
            "913a90cc9f6b4638bcf99c4aab0ae696",
            "e631103843114d52b776306a3f5c8442",
            "18df2292a7b34ea584db828050a7ddcb",
            "c98d4169e44e49cfa8f22df3225c683b"
          ]
        },
        "outputId": "16b88103-1c07-40cd-a0e3-abdfab2fd3bb"
      },
      "source": [
        "if __name__ == \"__main__\":\n",
        "    # Sweep the centroid (embedding) size: train one model per configuration and\n",
        "    # collect validation/test accuracy plus MNIST / NotMNIST AUROC.\n",
        "\n",
        "    # NOTE(review): the returned test dataset is not used below; the call is kept\n",
        "    # as-is in case it is needed for its side effects -- TODO confirm.\n",
        "    _, _, _, fashionmnist_test_dataset = get_FashionMNIST()\n",
        "    centroid_size=[2,10,50,256,500,1000]\n",
        "    l_gradient_penalty = 0.05\n",
        "    length_scales = [0.1]\n",
        "    epochs=30\n",
        "\n",
        "    repetition = 1  # Increase for multiple repetitions\n",
        "    final_model = True  # set true for final model to train on full train set\n",
        "\n",
        "    results = {}\n",
        "\n",
        "    for centroidz in centroid_size:\n",
        "        print(f\"for Centroid Size = {centroidz} \")\n",
        "        for length_scale in length_scales:\n",
        "            # One key per configuration, shared by the store and the progress print.\n",
        "            config_key = f\"lgp{l_gradient_penalty}_ls{length_scale}_centroidsz{centroidz}\"\n",
        "            val_accuracies = []\n",
        "            test_accuracies = []\n",
        "            roc_aucs_mnist = []\n",
        "            roc_aucs_notmnist = []\n",
        "\n",
        "            for _ in range(repetition):\n",
        "                print(\" ### NEW MODEL ### \")\n",
        "                model, val_accuracy, test_accuracy = train_model(\n",
        "                    l_gradient_penalty, length_scale, final_model, epochs,centroidz\n",
        "                )\n",
        "                # Only the AUROC is kept; the accuracy the helpers return is unused.\n",
        "                _, roc_auc_mnist = get_fashionmnist_mnist_ood(model)\n",
        "                _, roc_auc_notmnist = get_fashionmnist_notmnist_ood(model)\n",
        "\n",
        "                val_accuracies.append(val_accuracy)\n",
        "                test_accuracies.append(test_accuracy)\n",
        "                roc_aucs_mnist.append(roc_auc_mnist)\n",
        "                roc_aucs_notmnist.append(roc_auc_notmnist)\n",
        "\n",
        "            # All stats (mean over repetitions)\n",
        "            results[config_key] = [\n",
        "                (\"val acc\", np.mean(val_accuracies)),\n",
        "                (\"test acc\", np.mean(test_accuracies)),\n",
        "                (\"M auroc\", np.mean(roc_aucs_mnist)),\n",
        "                (\"NM auroc\", np.mean(roc_aucs_notmnist)),\n",
        "            ]\n",
        "            # Report each configuration as soon as it finishes so an interrupted\n",
        "            # sweep still leaves a record of the completed ones.\n",
        "            print(config_key, results[config_key])\n",
        "\n",
        "    # Save. NOTE: only the model from the last configuration of the sweep is\n",
        "    # written; models for the earlier configurations are not checkpointed.\n",
        "    torch.save(model.state_dict(), \"DUQ_FM_30_FULL.pt\")\n",
        "    print(results)\n"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "for Centroid Size = 1000 \n",
            " ### NEW MODEL ### \n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "033c2cd6f6fc44aca937e213d7505d4a",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "\r"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "8da51da9caf84e7aa747548e2565093e",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "\r"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "c0f8db75a1c941c29980631db65a17b3",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "\r"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "dd51bb586a9e47f9bd105dc60920b86d",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "\r"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "ec81b7ede40e4b2ca267e2bbda56b545",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "\rValidation Results - Epoch: 5 Val_Acc: 0.8994 BCE: 0.06 GP: 0.153114 AUROC MNIST: 0.9261 AUROC NotMNIST: 0.96 \n",
            "Sigma: 0.1\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "99b826a4760843d589b54d926815cc73",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "\r"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "8115b4007c6a4b099f10965d25771af7",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "\r"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "d6976ef7a70d40fe84c0e4a8ea823bdf",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    }
  ]
}