{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "131b07e8-4b49-4b45-b704-038c72890773",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\mdeep2\\anaconda3\\envs\\py39\\lib\\site-packages\\h5py\\__init__.py:36: UserWarning: h5py is running against HDF5 1.12.2 when it was built against 1.12.1, this may cause problems\n",
      "  _warn((\"h5py is running against HDF5 {0} when it was built against {1}, \"\n"
     ]
    }
   ],
   "source": [
    "## Parts of this code are adapted from https://github.com/artemyk/ibsgd\n",
    "\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import os\n",
    "from tensorflow.keras import backend as K\n",
    "import utils\n",
    "\n",
    "def generate_data():\n",
    "    \"\"\"Load MNIST and prepare it for a fully-connected classifier.\n",
    "\n",
    "    Images are normalized to 1 (divided by 255.0) and flattened to one row\n",
    "    per sample; integer labels are one-hot encoded over 10 classes.\n",
    "\n",
    "    Returns:\n",
    "        tuple: (x_train, Y_train, x_test, Y_test) -- flattened, scaled images\n",
    "        and one-hot labels for the train and test splits.\n",
    "    \"\"\"\n",
    "    (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()\n",
    "\n",
    "    def _flatten_scaled(images):\n",
    "        # Scale first (produces floats), then collapse every non-sample axis.\n",
    "        scaled = images / 255.0\n",
    "        return scaled.reshape(scaled.shape[0], -1)\n",
    "\n",
    "    num_classes = 10\n",
    "    return (\n",
    "        _flatten_scaled(train_images),\n",
    "        tf.keras.utils.to_categorical(train_labels, num_classes),\n",
    "        _flatten_scaled(test_images),\n",
    "        tf.keras.utils.to_categorical(test_labels, num_classes),\n",
    "    )\n",
    "\n",
    "class getMIOutput(tf.keras.callbacks.Callback):\n",
    "    \"\"\"Keras callback that saves one layer's activations to .npy files.\n",
    "\n",
    "    At the end of every epoch accepted by ``do_save_func``, the first\n",
    "    ``num_selection`` rows of ``trn`` are passed through the model and the\n",
    "    output of ``model.layers[Z_layer_idx]`` (flattened to one row per sample)\n",
    "    is written to ``save_dir/MN_epoch_{epoch}_z_{Z_layer_idx}.npy``.\n",
    "\n",
    "    Args:\n",
    "        trn: inputs whose activations are recorded (sliced as ``trn[:num_selection]``).\n",
    "        tst: held-out inputs; stored as ``self.tst`` but not used by this callback.\n",
    "        Z_layer_idx: index into ``model.layers`` of the layer to record.\n",
    "        num_selection: maximum number of leading ``trn`` rows to record.\n",
    "        do_save_func: optional ``f(epoch)``; epochs for which it returns a falsy\n",
    "            value are skipped. ``None`` saves every epoch.\n",
    "        save_dir: keyword-only output directory (default ``'savedata'``);\n",
    "            created on demand.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, trn, tst, Z_layer_idx, num_selection, do_save_func=None, *kargs, save_dir='savedata', **kwargs):\n",
    "        super(getMIOutput, self).__init__(*kargs, **kwargs)\n",
    "        self.trn = trn\n",
    "        self.tst = tst\n",
    "        self.Z_layer_idx = Z_layer_idx\n",
    "        self.num_selection = num_selection\n",
    "        self.do_save_func = do_save_func  # control the saved epoch\n",
    "        self.save_dir = save_dir\n",
    "        self.layer_values = []\n",
    "        self.layerixs = []\n",
    "        self.layerfuncs = []\n",
    "\n",
    "    def on_train_begin(self, logs=None):\n",
    "        \"\"\"Validate Z_layer_idx and build one activation function per model layer.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if Z_layer_idx is not a valid index into model.layers.\n",
    "        \"\"\"\n",
    "        n_layers = len(self.model.layers)\n",
    "        if not 0 <= self.Z_layer_idx < n_layers:\n",
    "            # An out-of-range index used to silently save empty arrays every epoch.\n",
    "            raise ValueError(\n",
    "                f'Z_layer_idx={self.Z_layer_idx} is out of range for a model with {n_layers} layers')\n",
    "        # Reset so that a second fit() call does not append duplicate entries.\n",
    "        self.layer_values = []\n",
    "        self.layerixs = []\n",
    "        self.layerfuncs = []\n",
    "        for lndx, l in enumerate(self.model.layers):\n",
    "            self.layerixs.append(lndx)\n",
    "            self.layer_values.append(lndx)\n",
    "            self.layerfuncs.append(K.function(self.model.inputs, [l.output,]))\n",
    "\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        \"\"\"Save the recorded layer's activations if do_save_func selects this epoch.\"\"\"\n",
    "        if self.do_save_func is not None and not self.do_save_func(epoch):\n",
    "            return\n",
    "\n",
    "        # Activations of the selected layer for the first num_selection training inputs.\n",
    "        activity = np.asarray(self.layerfuncs[self.Z_layer_idx]([self.trn[:self.num_selection],])[0])\n",
    "        # Flatten everything after the sample axis. Use the real row count rather than\n",
    "        # num_selection so a trn shorter than num_selection cannot break the reshape.\n",
    "        activity_array = activity.reshape(activity.shape[0], -1)\n",
    "\n",
    "        # Save the numpy array to an npy file, creating the output directory if needed\n",
    "        os.makedirs(self.save_dir, exist_ok=True)\n",
    "        filename = f\"MN_epoch_{epoch}_z_{self.Z_layer_idx}.npy\"\n",
    "        filepath = os.path.join(self.save_dir, filename)\n",
    "        np.save(filepath, activity_array)\n",
    "\n",
    "        print(f\"Saved data for epoch {epoch} to {filename}\")\n",
    "        \n",
    "        \n",
    "def do_report_MN(epoch, dense_until=100, sparse_until=200, sparse_every=5):\n",
    "    \"\"\"Epoch filter used as getMIOutput's do_save_func.\n",
    "\n",
    "    Default schedule: save every epoch below 100, every 5th epoch in\n",
    "    [100, 200), and never from epoch 200 on.\n",
    "\n",
    "    Args:\n",
    "        epoch: epoch index passed by the callback (the first epoch is 0).\n",
    "        dense_until: epochs below this are always saved.\n",
    "        sparse_until: epochs at or beyond this are never saved.\n",
    "        sparse_every: saving interval for epochs in [dense_until, sparse_until).\n",
    "\n",
    "    Returns:\n",
    "        bool: True if activations should be saved for ``epoch``.\n",
    "    \"\"\"\n",
    "    if epoch < dense_until:\n",
    "        return True\n",
    "    if epoch < sparse_until:\n",
    "        return epoch % sparse_every == 0\n",
    "    # Explicit False; the original fell off the end and returned None here.\n",
    "    return False\n",
    "    \n",
    "    \n",
    "\n",
    "def _make_optimizer(config):\n",
    "    \"\"\"Return a Keras optimizer built from config['optimizer'] and config['lr'].\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if config['optimizer'] is not one of the supported names.\n",
    "    \"\"\"\n",
    "    optimizers = {\n",
    "        'SGD': tf.keras.optimizers.SGD,\n",
    "        'Adam': tf.keras.optimizers.Adam,\n",
    "        # Add other optimizers as needed\n",
    "    }\n",
    "    name = config['optimizer']\n",
    "    if name not in optimizers:\n",
    "        # An if/elif without else would leave the optimizer unbound (UnboundLocalError).\n",
    "        raise ValueError(f'Unsupported optimizer {name!r}; expected one of {sorted(optimizers)}')\n",
    "    return optimizers[name](learning_rate=config['lr'])\n",
    "\n",
    "\n",
    "def train_model(config):\n",
    "    \"\"\"Train the MNIST MLP described by ``config`` and print its final metrics.\n",
    "\n",
    "    Activations of ``model.layers[config['z_idx']]`` are saved during training\n",
    "    by a ``getMIOutput`` callback on the epochs selected by ``do_report_MN``.\n",
    "    The MNIST test split is passed as validation data, so the printed\n",
    "    generalization gap is (final train metric) - (final test metric).\n",
    "\n",
    "    Args:\n",
    "        config: dict with keys\n",
    "            'optimizer'  -- 'SGD' or 'Adam'\n",
    "            'lr'         -- learning rate\n",
    "            'batch_size' -- mini-batch size for fit()\n",
    "            'epoch'      -- number of training epochs\n",
    "            'z_idx'      -- index into model.layers of the layer to record\n",
    "            'activation' -- optional hidden-layer activation (default 'relu')\n",
    "            'num_selection' -- optional number of training samples whose\n",
    "                               activations are saved (default 10000)\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if config['optimizer'] is not supported.\n",
    "    \"\"\"\n",
    "    # Get data\n",
    "    x_1train, Y_1train, x_1test, Y_1test = generate_data()\n",
    "\n",
    "    # Model training\n",
    "    tf.keras.backend.clear_session()\n",
    "    tf.random.set_seed(42)\n",
    "\n",
    "    # Hidden activation honours config['activation'] instead of being hard-coded.\n",
    "    activation = config.get('activation', 'relu')\n",
    "    input_layer = tf.keras.layers.Input((x_1train.shape[1],))\n",
    "    x = tf.keras.layers.Dense(1024, activation=activation)(input_layer)\n",
    "    for _ in range(3):\n",
    "        x = tf.keras.layers.Dense(20, activation=activation)(x)\n",
    "    CE_output = tf.keras.layers.Dense(10, activation='softmax', name='CE')(x)\n",
    "\n",
    "    model = tf.keras.Model(inputs=input_layer, outputs=[CE_output])\n",
    "\n",
    "    # Use the optimizer and learning rate from the config\n",
    "    opt = _make_optimizer(config)\n",
    "\n",
    "    model.compile(optimizer=opt,\n",
    "                  loss={'CE': 'categorical_crossentropy'},\n",
    "                  metrics={'CE': 'accuracy'})\n",
    "\n",
    "    reporter = getMIOutput(trn=x_1train,\n",
    "                           tst=x_1test,\n",
    "                           Z_layer_idx=config['z_idx'],  # Use z_idx from config\n",
    "                           num_selection=config.get('num_selection', 10000),\n",
    "                           do_save_func=do_report_MN)\n",
    "\n",
    "    history = model.fit(x=x_1train, y=Y_1train,\n",
    "                        batch_size=config['batch_size'],  # Use batch size from config\n",
    "                        epochs=config['epoch'],  # Use number of epochs from config\n",
    "                        verbose=0,\n",
    "                        validation_data=(x_1test, Y_1test),\n",
    "                        callbacks=[reporter,])\n",
    "\n",
    "    # Final generalization gap: train minus test (validation), for accuracy and loss\n",
    "    final_train_acc = history.history['accuracy'][-1]\n",
    "    final_val_acc = history.history['val_accuracy'][-1]\n",
    "    final_train_loss = history.history['loss'][-1]\n",
    "    final_val_loss = history.history['val_loss'][-1]\n",
    "\n",
    "    generalization_gap_acc = final_train_acc - final_val_acc\n",
    "    generalization_gap_loss = final_train_loss - final_val_loss\n",
    "\n",
    "    print(f\"Final Generalization Gap (Accuracy): {generalization_gap_acc}\")\n",
    "    print(f\"Final Generalization Gap (Loss): {generalization_gap_loss}\")\n",
    "    print(f\"Final Train Accuracy: {final_train_acc}\")\n",
    "    print(f\"Final Val Accuracy: {final_val_acc}\")\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e7508c22-bf65-4fbf-88a5-4cd96aef06bb",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved data for epoch 0 to MN_epoch_0_z_4.npy\n",
      "Saved data for epoch 1 to MN_epoch_1_z_4.npy\n",
      "Saved data for epoch 2 to MN_epoch_2_z_4.npy\n",
      "Saved data for epoch 3 to MN_epoch_3_z_4.npy\n",
      "Saved data for epoch 4 to MN_epoch_4_z_4.npy\n",
      "Saved data for epoch 5 to MN_epoch_5_z_4.npy\n",
      "Saved data for epoch 6 to MN_epoch_6_z_4.npy\n",
      "Saved data for epoch 7 to MN_epoch_7_z_4.npy\n",
      "Saved data for epoch 8 to MN_epoch_8_z_4.npy\n",
      "Saved data for epoch 9 to MN_epoch_9_z_4.npy\n",
      "Saved data for epoch 10 to MN_epoch_10_z_4.npy\n",
      "Saved data for epoch 11 to MN_epoch_11_z_4.npy\n",
      "Saved data for epoch 12 to MN_epoch_12_z_4.npy\n",
      "Saved data for epoch 13 to MN_epoch_13_z_4.npy\n",
      "Saved data for epoch 14 to MN_epoch_14_z_4.npy\n",
      "Saved data for epoch 15 to MN_epoch_15_z_4.npy\n",
      "Saved data for epoch 16 to MN_epoch_16_z_4.npy\n",
      "Saved data for epoch 17 to MN_epoch_17_z_4.npy\n",
      "Saved data for epoch 18 to MN_epoch_18_z_4.npy\n",
      "Saved data for epoch 19 to MN_epoch_19_z_4.npy\n",
      "Saved data for epoch 20 to MN_epoch_20_z_4.npy\n",
      "Saved data for epoch 21 to MN_epoch_21_z_4.npy\n",
      "Saved data for epoch 22 to MN_epoch_22_z_4.npy\n",
      "Saved data for epoch 23 to MN_epoch_23_z_4.npy\n",
      "Saved data for epoch 24 to MN_epoch_24_z_4.npy\n",
      "Saved data for epoch 25 to MN_epoch_25_z_4.npy\n",
      "Saved data for epoch 26 to MN_epoch_26_z_4.npy\n",
      "Saved data for epoch 27 to MN_epoch_27_z_4.npy\n",
      "Saved data for epoch 28 to MN_epoch_28_z_4.npy\n",
      "Saved data for epoch 29 to MN_epoch_29_z_4.npy\n",
      "Saved data for epoch 30 to MN_epoch_30_z_4.npy\n",
      "Saved data for epoch 31 to MN_epoch_31_z_4.npy\n",
      "Saved data for epoch 32 to MN_epoch_32_z_4.npy\n",
      "Saved data for epoch 33 to MN_epoch_33_z_4.npy\n",
      "Saved data for epoch 34 to MN_epoch_34_z_4.npy\n",
      "Saved data for epoch 35 to MN_epoch_35_z_4.npy\n",
      "Saved data for epoch 36 to MN_epoch_36_z_4.npy\n",
      "Saved data for epoch 37 to MN_epoch_37_z_4.npy\n",
      "Saved data for epoch 38 to MN_epoch_38_z_4.npy\n",
      "Saved data for epoch 39 to MN_epoch_39_z_4.npy\n",
      "Saved data for epoch 40 to MN_epoch_40_z_4.npy\n",
      "Saved data for epoch 41 to MN_epoch_41_z_4.npy\n",
      "Saved data for epoch 42 to MN_epoch_42_z_4.npy\n",
      "Saved data for epoch 43 to MN_epoch_43_z_4.npy\n",
      "Saved data for epoch 44 to MN_epoch_44_z_4.npy\n",
      "Saved data for epoch 45 to MN_epoch_45_z_4.npy\n",
      "Saved data for epoch 46 to MN_epoch_46_z_4.npy\n",
      "Saved data for epoch 47 to MN_epoch_47_z_4.npy\n",
      "Saved data for epoch 48 to MN_epoch_48_z_4.npy\n",
      "Saved data for epoch 49 to MN_epoch_49_z_4.npy\n",
      "Saved data for epoch 50 to MN_epoch_50_z_4.npy\n",
      "Saved data for epoch 51 to MN_epoch_51_z_4.npy\n",
      "Saved data for epoch 52 to MN_epoch_52_z_4.npy\n",
      "Saved data for epoch 53 to MN_epoch_53_z_4.npy\n",
      "Saved data for epoch 54 to MN_epoch_54_z_4.npy\n",
      "Saved data for epoch 55 to MN_epoch_55_z_4.npy\n",
      "Saved data for epoch 56 to MN_epoch_56_z_4.npy\n",
      "Saved data for epoch 57 to MN_epoch_57_z_4.npy\n",
      "Saved data for epoch 58 to MN_epoch_58_z_4.npy\n",
      "Saved data for epoch 59 to MN_epoch_59_z_4.npy\n",
      "Saved data for epoch 60 to MN_epoch_60_z_4.npy\n",
      "Saved data for epoch 61 to MN_epoch_61_z_4.npy\n",
      "Saved data for epoch 62 to MN_epoch_62_z_4.npy\n",
      "Saved data for epoch 63 to MN_epoch_63_z_4.npy\n",
      "Saved data for epoch 64 to MN_epoch_64_z_4.npy\n",
      "Saved data for epoch 65 to MN_epoch_65_z_4.npy\n",
      "Saved data for epoch 66 to MN_epoch_66_z_4.npy\n",
      "Saved data for epoch 67 to MN_epoch_67_z_4.npy\n",
      "Saved data for epoch 68 to MN_epoch_68_z_4.npy\n",
      "Saved data for epoch 69 to MN_epoch_69_z_4.npy\n",
      "Saved data for epoch 70 to MN_epoch_70_z_4.npy\n",
      "Saved data for epoch 71 to MN_epoch_71_z_4.npy\n",
      "Saved data for epoch 72 to MN_epoch_72_z_4.npy\n",
      "Saved data for epoch 73 to MN_epoch_73_z_4.npy\n",
      "Saved data for epoch 74 to MN_epoch_74_z_4.npy\n",
      "Saved data for epoch 75 to MN_epoch_75_z_4.npy\n",
      "Saved data for epoch 76 to MN_epoch_76_z_4.npy\n",
      "Saved data for epoch 77 to MN_epoch_77_z_4.npy\n",
      "Saved data for epoch 78 to MN_epoch_78_z_4.npy\n",
      "Saved data for epoch 79 to MN_epoch_79_z_4.npy\n",
      "Saved data for epoch 80 to MN_epoch_80_z_4.npy\n",
      "Saved data for epoch 81 to MN_epoch_81_z_4.npy\n",
      "Saved data for epoch 82 to MN_epoch_82_z_4.npy\n",
      "Saved data for epoch 83 to MN_epoch_83_z_4.npy\n",
      "Saved data for epoch 84 to MN_epoch_84_z_4.npy\n",
      "Saved data for epoch 85 to MN_epoch_85_z_4.npy\n",
      "Saved data for epoch 86 to MN_epoch_86_z_4.npy\n",
      "Saved data for epoch 87 to MN_epoch_87_z_4.npy\n",
      "Saved data for epoch 88 to MN_epoch_88_z_4.npy\n",
      "Saved data for epoch 89 to MN_epoch_89_z_4.npy\n",
      "Saved data for epoch 90 to MN_epoch_90_z_4.npy\n",
      "Saved data for epoch 91 to MN_epoch_91_z_4.npy\n",
      "Saved data for epoch 92 to MN_epoch_92_z_4.npy\n",
      "Saved data for epoch 93 to MN_epoch_93_z_4.npy\n",
      "Saved data for epoch 94 to MN_epoch_94_z_4.npy\n",
      "Saved data for epoch 95 to MN_epoch_95_z_4.npy\n",
      "Saved data for epoch 96 to MN_epoch_96_z_4.npy\n",
      "Saved data for epoch 97 to MN_epoch_97_z_4.npy\n",
      "Saved data for epoch 98 to MN_epoch_98_z_4.npy\n",
      "Saved data for epoch 99 to MN_epoch_99_z_4.npy\n",
      "Saved data for epoch 100 to MN_epoch_100_z_4.npy\n",
      "Saved data for epoch 105 to MN_epoch_105_z_4.npy\n",
      "Saved data for epoch 110 to MN_epoch_110_z_4.npy\n",
      "Saved data for epoch 115 to MN_epoch_115_z_4.npy\n",
      "Saved data for epoch 120 to MN_epoch_120_z_4.npy\n",
      "Saved data for epoch 125 to MN_epoch_125_z_4.npy\n",
      "Saved data for epoch 130 to MN_epoch_130_z_4.npy\n",
      "Saved data for epoch 135 to MN_epoch_135_z_4.npy\n",
      "Saved data for epoch 140 to MN_epoch_140_z_4.npy\n",
      "Saved data for epoch 145 to MN_epoch_145_z_4.npy\n",
      "Saved data for epoch 150 to MN_epoch_150_z_4.npy\n",
      "Saved data for epoch 155 to MN_epoch_155_z_4.npy\n",
      "Saved data for epoch 160 to MN_epoch_160_z_4.npy\n",
      "Saved data for epoch 165 to MN_epoch_165_z_4.npy\n",
      "Saved data for epoch 170 to MN_epoch_170_z_4.npy\n",
      "Saved data for epoch 175 to MN_epoch_175_z_4.npy\n",
      "Saved data for epoch 180 to MN_epoch_180_z_4.npy\n",
      "Saved data for epoch 185 to MN_epoch_185_z_4.npy\n",
      "Saved data for epoch 190 to MN_epoch_190_z_4.npy\n",
      "Saved data for epoch 195 to MN_epoch_195_z_4.npy\n",
      "Final Generalization Gap (Accuracy): 0.00550001859664917\n",
      "Final Generalization Gap (Loss): -0.01142418384552002\n",
      "Final Train Accuracy: 0.9611999988555908\n",
      "Final Val Accuracy: 0.9556999802589417\n"
     ]
    }
   ],
   "source": [
    "# Hyper-parameters for one training run\n",
    "MODEL_CONFIG = dict(\n",
    "    optimizer='SGD',\n",
    "    lr=0.0005,\n",
    "    batch_size=128,\n",
    "    epoch=200,\n",
    "    z_idx=4,  # index into model.layers of the layer whose activations are saved\n",
    "    activation='relu',\n",
    ")\n",
    "\n",
    "# Run training with this configuration\n",
    "train_model(MODEL_CONFIG)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea47f7b5-fd39-4bb0-ae1d-4e7b84b7212c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check of the preprocessed data (importgenerate_data() was an undefined name)\n",
    "x_train, Y_train, x_test, Y_test = generate_data()\n",
    "x_train.shape, Y_train.shape, x_test.shape, Y_test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92a57f9d-81ef-44d6-b0e3-a5c9c6a8e0f5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
