{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "131b07e8-4b49-4b45-b704-038c72890773",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Part of this code is adapted from https://github.com/artemyk/ibsgd\n",
    "\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import os\n",
    "from tensorflow.keras import backend as K\n",
    "import utils\n",
    "\n",
    "class getMIOutput(tf.keras.callbacks.Callback):\n",
    "    \"\"\"Keras callback that saves the activations of one layer to .npy files.\n",
    "\n",
    "    At the end of every epoch accepted by ``do_save_func`` (every epoch if it is\n",
    "    None), the output of ``model.layers[Z_layer_idx]`` on the first\n",
    "    ``num_selection`` rows of ``trn`` is reshaped to ``(rows, -1)`` and written to\n",
    "    ``<save_dir>/IB_epoch_<epoch>_z_<Z_layer_idx>.npy``.\n",
    "\n",
    "    Args:\n",
    "        trn: inputs fed to the model to compute the activations.\n",
    "        tst: stored on the instance but not used by this callback.\n",
    "        Z_layer_idx: index into ``model.layers`` of the layer to record.\n",
    "        num_selection: number of leading rows of ``trn`` to evaluate.\n",
    "        do_save_func: optional ``epoch -> bool`` predicate selecting the saved epochs.\n",
    "        save_dir: keyword-only output directory (created if missing).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, trn, tst, Z_layer_idx, num_selection, do_save_func=None,\n",
    "                 *kargs, save_dir='savedata', **kwargs):\n",
    "        super(getMIOutput, self).__init__(*kargs, **kwargs)\n",
    "        self.trn = trn\n",
    "        self.tst = tst\n",
    "        self.Z_layer_idx = Z_layer_idx\n",
    "        self.num_selection = num_selection\n",
    "        self.do_save_func = do_save_func # control the saved epoch\n",
    "        self.save_dir = save_dir\n",
    "        self.layer_values = []\n",
    "        self.layerixs = []\n",
    "        self.layerfuncs = []\n",
    "\n",
    "    def on_train_begin(self, logs=None):\n",
    "        \"\"\"Build one backend function per model layer mapping model inputs to that layer's output.\"\"\"\n",
    "        n_layers = len(self.model.layers)\n",
    "        # Fail early: an out-of-range index would otherwise silently save empty arrays.\n",
    "        if not 0 <= self.Z_layer_idx < n_layers:\n",
    "            raise ValueError(f\"Z_layer_idx={self.Z_layer_idx} is out of range for a model \"\n",
    "                             f\"with {n_layers} layers\")\n",
    "        # Start from empty lists so a second fit() with this callback does not\n",
    "        # append duplicate entries.\n",
    "        self.layer_values = []\n",
    "        self.layerixs = []\n",
    "        self.layerfuncs = []\n",
    "        for lndx, l in enumerate(self.model.layers):\n",
    "            self.layerixs.append(lndx)\n",
    "            self.layer_values.append(lndx)\n",
    "            self.layerfuncs.append(K.function(self.model.inputs, [l.output,]))\n",
    "\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        \"\"\"Save the selected layer's activations for ``epoch`` if ``do_save_func`` allows it.\"\"\"\n",
    "        if self.do_save_func is not None and not self.do_save_func(epoch):\n",
    "            return\n",
    "\n",
    "        # Activations are computed on the training inputs (trn), not on tst.\n",
    "        layer_func = self.layerfuncs[self.Z_layer_idx]\n",
    "        activity = np.asarray(layer_func([self.trn[:self.num_selection],])[0])\n",
    "        # One row per sample, remaining axes flattened, for npy compatibility.\n",
    "        activity = activity.reshape(activity.shape[0], -1)\n",
    "\n",
    "        # np.save does not create missing directories.\n",
    "        os.makedirs(self.save_dir, exist_ok=True)\n",
    "        filename = f\"IB_epoch_{epoch}_z_{self.Z_layer_idx}.npy\"\n",
    "        filepath = os.path.join(self.save_dir, filename)\n",
    "        np.save(filepath, activity)\n",
    "\n",
    "        print(f\"Saved data for epoch {epoch} to {filename}\")\n",
    "        \n",
    "        \n",
    "def do_report_IB(epoch):\n",
    "    # Only log activity for some epochs.  Mainly this is to make things run faster.\n",
    "    if epoch < 20:       # Log for all first 20 epochs\n",
    "        return True\n",
    "    elif epoch < 100:    # Then for every 5th epoch\n",
    "        return (epoch % 5 == 0)\n",
    "    elif epoch < 2000:    # Then every 10th\n",
    "        return (epoch % 20 == 0)\n",
    "    else:                # Then every 100th\n",
    "        return (epoch % 100 == 0)\n",
    "    \n",
    "def get_dataset():\n",
    "    \"\"\"Load the IB dataset and return (X, Y) of its training split.\"\"\"\n",
    "    train_split, _test_split = utils.get_IB_data('2017_12_21_16_51_3_275766')\n",
    "    return train_split.X, train_split.Y\n",
    "\n",
    "def train_model(config):\n",
    "    \"\"\"Train a small dense classifier on the IB data and print its generalization gap.\n",
    "\n",
    "    During training a ``getMIOutput`` callback saves the activations of layer\n",
    "    ``config[\"z_idx\"]`` on the epochs selected by ``do_report_IB``.\n",
    "\n",
    "    Args:\n",
    "        config: dict with the required keys\n",
    "            \"optimizer\"  -- \"SGD\" or \"Adam\"\n",
    "            \"lr\"         -- learning rate\n",
    "            \"batch_size\" -- minibatch size passed to ``model.fit``\n",
    "            \"epoch\"      -- number of training epochs\n",
    "            \"z_idx\"      -- index into ``model.layers`` of the recorded layer\n",
    "            \"activation\" -- activation of the hidden Dense layers\n",
    "          and the optional keys\n",
    "            \"hidden_units\" -- hidden layer widths (default (10, 7, 5, 4, 3))\n",
    "            \"seed\"         -- TensorFlow random seed (default 42)\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``config[\"optimizer\"]`` is not a supported optimizer.\n",
    "    \"\"\"\n",
    "    # Supported optimizers; add other optimizers as needed.\n",
    "    optimizers = {\n",
    "        \"SGD\": tf.keras.optimizers.SGD,\n",
    "        \"Adam\": tf.keras.optimizers.Adam,\n",
    "    }\n",
    "    # Fail before loading data / training instead of hitting an unbound `opt` later.\n",
    "    if config[\"optimizer\"] not in optimizers:\n",
    "        raise ValueError(f\"Unsupported optimizer {config['optimizer']!r}; \"\n",
    "                         f\"expected one of {sorted(optimizers)}\")\n",
    "\n",
    "    # Get data\n",
    "    trn, tst = utils.get_IB_data('2017_12_21_16_51_3_275766')\n",
    "\n",
    "    # Model training\n",
    "    tf.keras.backend.clear_session()\n",
    "    tf.random.set_seed(config.get(\"seed\", 42))\n",
    "\n",
    "    # Hidden layers use config[\"activation\"] (not a notebook-level global), so\n",
    "    # the function depends only on its argument.\n",
    "    input_layer = tf.keras.layers.Input((trn.X.shape[1],))\n",
    "    x = input_layer\n",
    "    for units in config.get(\"hidden_units\", (10, 7, 5, 4, 3)):\n",
    "        x = tf.keras.layers.Dense(units, activation=config[\"activation\"])(x)\n",
    "    CE_output = tf.keras.layers.Dense(2, activation='softmax', name='CE')(x)\n",
    "\n",
    "    model = tf.keras.Model(inputs=input_layer, outputs=[CE_output])\n",
    "\n",
    "    opt = optimizers[config[\"optimizer\"]](learning_rate=config[\"lr\"])\n",
    "    model.compile(optimizer=opt,\n",
    "                  loss={'CE': 'categorical_crossentropy'},\n",
    "                  metrics={'CE': 'accuracy'})\n",
    "\n",
    "    reporter = getMIOutput(trn=trn.X,\n",
    "                           tst=tst.X,\n",
    "                           Z_layer_idx=config[\"z_idx\"],\n",
    "                           num_selection=trn.X.shape[0],\n",
    "                           do_save_func=do_report_IB)\n",
    "\n",
    "    history = model.fit(x=trn.X, y=trn.Y,\n",
    "                        batch_size=config[\"batch_size\"],\n",
    "                        epochs=config[\"epoch\"],\n",
    "                        verbose=0,\n",
    "                        validation_data=(tst.X, tst.Y),\n",
    "                        callbacks=[reporter,])\n",
    "\n",
    "    # Generalization gap = final train value - final validation (test) value,\n",
    "    # for both accuracy and loss.\n",
    "    final_train_acc = history.history['accuracy'][-1]\n",
    "    final_val_acc = history.history['val_accuracy'][-1]\n",
    "    final_train_loss = history.history['loss'][-1]\n",
    "    final_val_loss = history.history['val_loss'][-1]\n",
    "\n",
    "    generalization_gap_acc = final_train_acc - final_val_acc\n",
    "    generalization_gap_loss = final_train_loss - final_val_loss\n",
    "\n",
    "    print(f\"Final Generalization Gap (Accuracy): {generalization_gap_acc}\")\n",
    "    print(f\"Final Generalization Gap (Loss): {generalization_gap_loss}\")\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e7508c22-bf65-4fbf-88a5-4cd96aef06bb",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved data for epoch 0 to IB_epoch_0_z_5.npy\n",
      "Saved data for epoch 1 to IB_epoch_1_z_5.npy\n",
      "Saved data for epoch 2 to IB_epoch_2_z_5.npy\n",
      "Saved data for epoch 3 to IB_epoch_3_z_5.npy\n",
      "Saved data for epoch 4 to IB_epoch_4_z_5.npy\n",
      "Saved data for epoch 5 to IB_epoch_5_z_5.npy\n",
      "Saved data for epoch 6 to IB_epoch_6_z_5.npy\n",
      "Saved data for epoch 7 to IB_epoch_7_z_5.npy\n",
      "Saved data for epoch 8 to IB_epoch_8_z_5.npy\n",
      "Saved data for epoch 9 to IB_epoch_9_z_5.npy\n",
      "Saved data for epoch 10 to IB_epoch_10_z_5.npy\n",
      "Saved data for epoch 11 to IB_epoch_11_z_5.npy\n",
      "Saved data for epoch 12 to IB_epoch_12_z_5.npy\n",
      "Saved data for epoch 13 to IB_epoch_13_z_5.npy\n",
      "Saved data for epoch 14 to IB_epoch_14_z_5.npy\n",
      "Saved data for epoch 15 to IB_epoch_15_z_5.npy\n",
      "Saved data for epoch 16 to IB_epoch_16_z_5.npy\n",
      "Saved data for epoch 17 to IB_epoch_17_z_5.npy\n",
      "Saved data for epoch 18 to IB_epoch_18_z_5.npy\n",
      "Saved data for epoch 19 to IB_epoch_19_z_5.npy\n",
      "Saved data for epoch 20 to IB_epoch_20_z_5.npy\n",
      "Saved data for epoch 25 to IB_epoch_25_z_5.npy\n",
      "Saved data for epoch 30 to IB_epoch_30_z_5.npy\n",
      "Saved data for epoch 35 to IB_epoch_35_z_5.npy\n",
      "Saved data for epoch 40 to IB_epoch_40_z_5.npy\n",
      "Saved data for epoch 45 to IB_epoch_45_z_5.npy\n",
      "Saved data for epoch 50 to IB_epoch_50_z_5.npy\n",
      "Saved data for epoch 55 to IB_epoch_55_z_5.npy\n",
      "Saved data for epoch 60 to IB_epoch_60_z_5.npy\n",
      "Saved data for epoch 65 to IB_epoch_65_z_5.npy\n",
      "Saved data for epoch 70 to IB_epoch_70_z_5.npy\n",
      "Saved data for epoch 75 to IB_epoch_75_z_5.npy\n",
      "Saved data for epoch 80 to IB_epoch_80_z_5.npy\n",
      "Saved data for epoch 85 to IB_epoch_85_z_5.npy\n",
      "Saved data for epoch 90 to IB_epoch_90_z_5.npy\n",
      "Saved data for epoch 95 to IB_epoch_95_z_5.npy\n",
      "Saved data for epoch 100 to IB_epoch_100_z_5.npy\n",
      "Saved data for epoch 120 to IB_epoch_120_z_5.npy\n",
      "Saved data for epoch 140 to IB_epoch_140_z_5.npy\n",
      "Saved data for epoch 160 to IB_epoch_160_z_5.npy\n",
      "Saved data for epoch 180 to IB_epoch_180_z_5.npy\n",
      "Saved data for epoch 200 to IB_epoch_200_z_5.npy\n",
      "Saved data for epoch 220 to IB_epoch_220_z_5.npy\n",
      "Saved data for epoch 240 to IB_epoch_240_z_5.npy\n",
      "Saved data for epoch 260 to IB_epoch_260_z_5.npy\n",
      "Saved data for epoch 280 to IB_epoch_280_z_5.npy\n",
      "Saved data for epoch 300 to IB_epoch_300_z_5.npy\n",
      "Saved data for epoch 320 to IB_epoch_320_z_5.npy\n",
      "Saved data for epoch 340 to IB_epoch_340_z_5.npy\n",
      "Saved data for epoch 360 to IB_epoch_360_z_5.npy\n",
      "Saved data for epoch 380 to IB_epoch_380_z_5.npy\n",
      "Saved data for epoch 400 to IB_epoch_400_z_5.npy\n",
      "Saved data for epoch 420 to IB_epoch_420_z_5.npy\n",
      "Saved data for epoch 440 to IB_epoch_440_z_5.npy\n",
      "Saved data for epoch 460 to IB_epoch_460_z_5.npy\n",
      "Saved data for epoch 480 to IB_epoch_480_z_5.npy\n",
      "Saved data for epoch 500 to IB_epoch_500_z_5.npy\n",
      "Saved data for epoch 520 to IB_epoch_520_z_5.npy\n",
      "Saved data for epoch 540 to IB_epoch_540_z_5.npy\n",
      "Saved data for epoch 560 to IB_epoch_560_z_5.npy\n",
      "Saved data for epoch 580 to IB_epoch_580_z_5.npy\n",
      "Saved data for epoch 600 to IB_epoch_600_z_5.npy\n",
      "Saved data for epoch 620 to IB_epoch_620_z_5.npy\n",
      "Saved data for epoch 640 to IB_epoch_640_z_5.npy\n",
      "Saved data for epoch 660 to IB_epoch_660_z_5.npy\n",
      "Saved data for epoch 680 to IB_epoch_680_z_5.npy\n",
      "Saved data for epoch 700 to IB_epoch_700_z_5.npy\n",
      "Saved data for epoch 720 to IB_epoch_720_z_5.npy\n",
      "Saved data for epoch 740 to IB_epoch_740_z_5.npy\n",
      "Saved data for epoch 760 to IB_epoch_760_z_5.npy\n",
      "Saved data for epoch 780 to IB_epoch_780_z_5.npy\n",
      "Saved data for epoch 800 to IB_epoch_800_z_5.npy\n",
      "Saved data for epoch 820 to IB_epoch_820_z_5.npy\n",
      "Saved data for epoch 840 to IB_epoch_840_z_5.npy\n",
      "Saved data for epoch 860 to IB_epoch_860_z_5.npy\n",
      "Saved data for epoch 880 to IB_epoch_880_z_5.npy\n",
      "Saved data for epoch 900 to IB_epoch_900_z_5.npy\n",
      "Saved data for epoch 920 to IB_epoch_920_z_5.npy\n",
      "Saved data for epoch 940 to IB_epoch_940_z_5.npy\n",
      "Saved data for epoch 960 to IB_epoch_960_z_5.npy\n",
      "Saved data for epoch 980 to IB_epoch_980_z_5.npy\n",
      "Saved data for epoch 1000 to IB_epoch_1000_z_5.npy\n",
      "Saved data for epoch 1020 to IB_epoch_1020_z_5.npy\n",
      "Saved data for epoch 1040 to IB_epoch_1040_z_5.npy\n",
      "Saved data for epoch 1060 to IB_epoch_1060_z_5.npy\n",
      "Saved data for epoch 1080 to IB_epoch_1080_z_5.npy\n",
      "Saved data for epoch 1100 to IB_epoch_1100_z_5.npy\n",
      "Saved data for epoch 1120 to IB_epoch_1120_z_5.npy\n",
      "Saved data for epoch 1140 to IB_epoch_1140_z_5.npy\n",
      "Saved data for epoch 1160 to IB_epoch_1160_z_5.npy\n",
      "Saved data for epoch 1180 to IB_epoch_1180_z_5.npy\n",
      "Saved data for epoch 1200 to IB_epoch_1200_z_5.npy\n",
      "Saved data for epoch 1220 to IB_epoch_1220_z_5.npy\n",
      "Saved data for epoch 1240 to IB_epoch_1240_z_5.npy\n",
      "Saved data for epoch 1260 to IB_epoch_1260_z_5.npy\n",
      "Saved data for epoch 1280 to IB_epoch_1280_z_5.npy\n",
      "Saved data for epoch 1300 to IB_epoch_1300_z_5.npy\n",
      "Saved data for epoch 1320 to IB_epoch_1320_z_5.npy\n",
      "Saved data for epoch 1340 to IB_epoch_1340_z_5.npy\n",
      "Saved data for epoch 1360 to IB_epoch_1360_z_5.npy\n",
      "Saved data for epoch 1380 to IB_epoch_1380_z_5.npy\n",
      "Saved data for epoch 1400 to IB_epoch_1400_z_5.npy\n",
      "Saved data for epoch 1420 to IB_epoch_1420_z_5.npy\n",
      "Saved data for epoch 1440 to IB_epoch_1440_z_5.npy\n",
      "Saved data for epoch 1460 to IB_epoch_1460_z_5.npy\n",
      "Saved data for epoch 1480 to IB_epoch_1480_z_5.npy\n",
      "Saved data for epoch 1500 to IB_epoch_1500_z_5.npy\n",
      "Saved data for epoch 1520 to IB_epoch_1520_z_5.npy\n",
      "Saved data for epoch 1540 to IB_epoch_1540_z_5.npy\n",
      "Saved data for epoch 1560 to IB_epoch_1560_z_5.npy\n",
      "Saved data for epoch 1580 to IB_epoch_1580_z_5.npy\n",
      "Saved data for epoch 1600 to IB_epoch_1600_z_5.npy\n",
      "Saved data for epoch 1620 to IB_epoch_1620_z_5.npy\n",
      "Saved data for epoch 1640 to IB_epoch_1640_z_5.npy\n",
      "Saved data for epoch 1660 to IB_epoch_1660_z_5.npy\n",
      "Saved data for epoch 1680 to IB_epoch_1680_z_5.npy\n",
      "Saved data for epoch 1700 to IB_epoch_1700_z_5.npy\n",
      "Saved data for epoch 1720 to IB_epoch_1720_z_5.npy\n",
      "Saved data for epoch 1740 to IB_epoch_1740_z_5.npy\n",
      "Saved data for epoch 1760 to IB_epoch_1760_z_5.npy\n",
      "Saved data for epoch 1780 to IB_epoch_1780_z_5.npy\n",
      "Saved data for epoch 1800 to IB_epoch_1800_z_5.npy\n",
      "Saved data for epoch 1820 to IB_epoch_1820_z_5.npy\n",
      "Saved data for epoch 1840 to IB_epoch_1840_z_5.npy\n",
      "Saved data for epoch 1860 to IB_epoch_1860_z_5.npy\n",
      "Saved data for epoch 1880 to IB_epoch_1880_z_5.npy\n",
      "Saved data for epoch 1900 to IB_epoch_1900_z_5.npy\n",
      "Saved data for epoch 1920 to IB_epoch_1920_z_5.npy\n",
      "Saved data for epoch 1940 to IB_epoch_1940_z_5.npy\n",
      "Saved data for epoch 1960 to IB_epoch_1960_z_5.npy\n",
      "Saved data for epoch 1980 to IB_epoch_1980_z_5.npy\n",
      "Final Generalization Gap (Accuracy): 0.007941901683807373\n",
      "Final Generalization Gap (Loss): -0.012580882757902145\n"
     ]
    }
   ],
   "source": [
    "# Hyper-parameters for this run.\n",
    "MODEL_CONFIG = dict(\n",
    "    optimizer=\"SGD\",     # \"SGD\" or \"Adam\"\n",
    "    lr=5e-3,             # learning rate\n",
    "    batch_size=256,      # minibatch size\n",
    "    epoch=2000,          # number of training epochs\n",
    "    z_idx=5,             # index into model.layers whose activations are saved\n",
    "    activation=\"tanh\",   # activation of the hidden Dense layers\n",
    ")\n",
    "\n",
    "# Call the train_model function with the config\n",
    "train_model(MODEL_CONFIG)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c451288-c430-4a91-b505-0cf16b61f34b",
   "metadata": {},
   "source": [
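    "Training log from a previous run, kept for reference (its configuration may differ from the one above):\n",
    "\n",
    "```\n",
    "13/13 [==============================] - 0s 4ms/step - loss: 0.0818 - accuracy: 0.9707 - val_loss: 0.1254 - val_accuracy: 0.9499\n",
    "Final Generalization Gap (Accuracy): 0.020765960216522217\n",
    "Final Generalization Gap (Loss): -0.0436239168047905\n",
    "```"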
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea47f7b5-fd39-4bb0-ae1d-4e7b84b7212c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92a57f9d-81ef-44d6-b0e3-a5c9c6a8e0f5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
