{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "131b07e8-4b49-4b45-b704-038c72890773",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\mdeep2\\anaconda3\\envs\\py39\\lib\\site-packages\\h5py\\__init__.py:36: UserWarning: h5py is running against HDF5 1.12.2 when it was built against 1.12.1, this may cause problems\n",
      "  _warn((\"h5py is running against HDF5 {0} when it was built against {1}, \"\n"
     ]
    }
   ],
   "source": [
    "## Parts of this code are adapted (with modifications) from: https://github.com/artemyk/ibsgd \n",
    "\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import os\n",
    "from tensorflow.keras import backend as K\n",
    "import utils\n",
    "\n",
    "def generate_data():\n",
    "    \"\"\"Load MNIST and return flattened, scaled inputs with one-hot labels.\n",
    "\n",
    "    Returns:\n",
    "        tuple: ``(x_train, Y_train, x_test, Y_test)``. The image arrays are divided\n",
    "        by 255 and reshaped to one row per sample; the labels are one-hot encoded\n",
    "        over 10 classes.\n",
    "    \"\"\"\n",
    "    train_split, test_split = tf.keras.datasets.mnist.load_data()\n",
    "\n",
    "    def _flatten_scaled(images):\n",
    "        # normalization to 1, then collapse every axis after the sample axis\n",
    "        scaled = images / 255.0\n",
    "        return scaled.reshape(scaled.shape[0], -1)\n",
    "\n",
    "    def _one_hot(labels):\n",
    "        return tf.keras.utils.to_categorical(labels, 10)\n",
    "\n",
    "    train_images, train_labels = train_split\n",
    "    test_images, test_labels = test_split\n",
    "\n",
    "    return (_flatten_scaled(train_images), _one_hot(train_labels),\n",
    "            _flatten_scaled(test_images), _one_hot(test_labels))\n",
    "\n",
    "class getMIOutput(tf.keras.callbacks.Callback):\n",
    "    \"\"\"Keras callback that saves one layer's activations to an .npy file after selected epochs.\n",
    "\n",
    "    Args:\n",
    "        trn: Training inputs; the first ``num_selection`` rows are fed through the model.\n",
    "        tst: Test inputs. Stored on the instance but not read by this callback.\n",
    "        Z_layer_idx: Index into ``model.layers`` of the layer whose output is saved.\n",
    "        num_selection: Number of leading rows of ``trn`` to evaluate.\n",
    "        do_save_func: Optional predicate ``epoch -> bool``. When given, epochs for which\n",
    "            it returns a falsy value are skipped; ``None`` saves every epoch.\n",
    "        *kargs, **kwargs: Forwarded to ``tf.keras.callbacks.Callback``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, trn, tst, Z_layer_idx, num_selection, do_save_func=None, *kargs, **kwargs):\n",
    "        super(getMIOutput, self).__init__(*kargs, **kwargs)\n",
    "        self.trn = trn\n",
    "        self.tst = tst\n",
    "        self.Z_layer_idx = Z_layer_idx\n",
    "        self.num_selection = num_selection\n",
    "        self.do_save_func = do_save_func # control the saved epoch\n",
    "        self.layer_values = []\n",
    "        self.layerixs = []\n",
    "        self.layerfuncs = []\n",
    "\n",
    "    def on_train_begin(self, logs=None):\n",
    "        \"\"\"Build one backend function per model layer that maps the model inputs to that layer's output.\"\"\"\n",
    "        # Start from empty lists so that calling fit() again with the same callback\n",
    "        # does not leave duplicate entries behind.\n",
    "        self.layer_values = []\n",
    "        self.layerixs = []\n",
    "        self.layerfuncs = []\n",
    "        for lndx, l in enumerate(self.model.layers):\n",
    "            self.layerixs.append(lndx)\n",
    "            self.layer_values.append(lndx)\n",
    "            self.layerfuncs.append(K.function(self.model.inputs, [l.output,]))\n",
    "\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        \"\"\"Evaluate the selected layer on the leading rows of ``trn`` and save the result.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If ``Z_layer_idx`` does not address one of the model's layers.\n",
    "        \"\"\"\n",
    "        if self.do_save_func is not None and not self.do_save_func(epoch):\n",
    "            return\n",
    "\n",
    "        # Fail with a clear message instead of an opaque reshape error on an empty list.\n",
    "        if not 0 <= self.Z_layer_idx < len(self.layerfuncs):\n",
    "            raise ValueError(\n",
    "                f\"Z_layer_idx={self.Z_layer_idx} is out of range for a model with \"\n",
    "                f\"{len(self.layerfuncs)} layers\")\n",
    "\n",
    "        activity = self.layerfuncs[self.Z_layer_idx]([self.trn[:self.num_selection],])[0]\n",
    "\n",
    "        # One row per evaluated sample. Using the actual row count (rather than\n",
    "        # num_selection) keeps the layout correct when trn has fewer rows than requested.\n",
    "        activity_array = np.asarray(activity)\n",
    "        activity_array = activity_array.reshape(activity_array.shape[0], -1)\n",
    "\n",
    "        # np.save does not create missing directories, so make sure the target exists.\n",
    "        os.makedirs('savedata', exist_ok=True)\n",
    "\n",
    "        # Save the numpy array to an npy file\n",
    "        filename = f\"MN_epoch_{epoch}_z_{self.Z_layer_idx}.npy\"\n",
    "        filepath = os.path.join('savedata', filename)\n",
    "        np.save(filepath, activity_array)\n",
    "\n",
    "        print(f\"Saved data for epoch {epoch} to {filename}\")\n",
    "        \n",
    "        \n",
    "def do_report_MN(epoch):\n",
    "    if epoch < 100  :       \n",
    "        return (epoch % 1 == 0)\n",
    "    elif epoch < 200  :\n",
    "        return (epoch % 5 == 0)\n",
    "    \n",
    "    \n",
    "\n",
    "def train_model(config):\n",
    "    \"\"\"Train a dense MNIST classifier described by ``config`` and print its generalization gap.\n",
    "\n",
    "    The network is a stack of Dense layers with 1024, 20, 20 and 20 ReLU units followed\n",
    "    by a 10-way softmax output named ``CE``. A ``getMIOutput`` callback saves the\n",
    "    activations of layer ``config[\"z_idx\"]`` during training.\n",
    "\n",
    "    Args:\n",
    "        config: Mapping with the keys ``\"optimizer\"`` (``\"SGD\"`` or ``\"Adam\"``),\n",
    "            ``\"lr\"`` (learning rate), ``\"batch_size\"``, ``\"epoch\"`` (number of\n",
    "            epochs) and ``\"z_idx\"`` (index of the layer to record).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``config[\"optimizer\"]`` is not a supported optimizer name.\n",
    "    \"\"\"\n",
    "    # Resolve the optimizer class up front so that an unsupported name fails with a\n",
    "    # clear error before any data is loaded (it used to leave `opt` unbound).\n",
    "    optimizer_classes = {\n",
    "        \"SGD\": tf.keras.optimizers.SGD,\n",
    "        \"Adam\": tf.keras.optimizers.Adam,\n",
    "        # Add other optimizers as needed\n",
    "    }\n",
    "    optimizer_name = config[\"optimizer\"]\n",
    "    if optimizer_name not in optimizer_classes:\n",
    "        raise ValueError(\n",
    "            f\"Unsupported optimizer {optimizer_name!r}; expected one of {sorted(optimizer_classes)}\")\n",
    "\n",
    "    # Get data\n",
    "    x_1train,Y_1train,x_1test,Y_1test = generate_data()\n",
    "\n",
    "    # Model training\n",
    "    tf.keras.backend.clear_session()\n",
    "    tf.random.set_seed(42)\n",
    "\n",
    "    input_layer = tf.keras.layers.Input((x_1train.shape[1],))\n",
    "    x = tf.keras.layers.Dense(1024, activation='relu')(input_layer)\n",
    "    x = tf.keras.layers.Dense(20, activation='relu')(x)\n",
    "    x = tf.keras.layers.Dense(20, activation='relu')(x)\n",
    "    x = tf.keras.layers.Dense(20, activation='relu')(x)\n",
    "    CE_output = tf.keras.layers.Dense(10, activation='softmax', name='CE')(x)\n",
    "\n",
    "    model = tf.keras.Model(inputs=input_layer, outputs=[CE_output])\n",
    "\n",
    "    # Use the optimizer and learning rate from the config\n",
    "    opt = optimizer_classes[optimizer_name](learning_rate=config[\"lr\"])\n",
    "\n",
    "    model.compile(optimizer=opt,\n",
    "                  loss={'CE': 'categorical_crossentropy'},\n",
    "                  metrics={'CE': 'accuracy'})\n",
    "\n",
    "    reporter = getMIOutput(trn=x_1train,\n",
    "                           tst=x_1test,\n",
    "                           Z_layer_idx=config[\"z_idx\"],  # Use z_idx from config\n",
    "                           num_selection=10000,\n",
    "                           do_save_func=do_report_MN)\n",
    "\n",
    "    history = model.fit(x=x_1train, y=Y_1train,\n",
    "                        batch_size=config[\"batch_size\"],  # Use batch size from config\n",
    "                        epochs=config[\"epoch\"],  # Use number of epochs from config\n",
    "                        verbose=0,\n",
    "                        validation_data=(x_1test, Y_1test),\n",
    "                        callbacks=[reporter,])\n",
    "\n",
    "    # Print the final generalization gap (train accuracy - test accuracy / train loss - test loss)\n",
    "    final_train_acc = history.history['accuracy'][-1]\n",
    "    final_val_acc = history.history['val_accuracy'][-1]\n",
    "    final_train_loss = history.history['loss'][-1]\n",
    "    final_val_loss = history.history['val_loss'][-1]\n",
    "\n",
    "    generalization_gap_acc = final_train_acc - final_val_acc\n",
    "    generalization_gap_loss = final_train_loss - final_val_loss\n",
    "\n",
    "    print(f\"Final Generalization Gap (Accuracy): {generalization_gap_acc}\")\n",
    "    print(f\"Final Generalization Gap (Loss): {generalization_gap_loss}\")\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e7508c22-bf65-4fbf-88a5-4cd96aef06bb",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved data for epoch 0 to IB_epoch_0_z_2.npy\n",
      "Saved data for epoch 1 to IB_epoch_1_z_2.npy\n",
      "Saved data for epoch 2 to IB_epoch_2_z_2.npy\n",
      "Saved data for epoch 3 to IB_epoch_3_z_2.npy\n",
      "Saved data for epoch 4 to IB_epoch_4_z_2.npy\n",
      "Saved data for epoch 5 to IB_epoch_5_z_2.npy\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[2], line 11\u001b[0m\n\u001b[0;32m      1\u001b[0m MODEL_CONFIG \u001b[38;5;241m=\u001b[39m {\n\u001b[0;32m      2\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124moptimizer\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSGD\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m      3\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlr\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m0.0005\u001b[39m,\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m      7\u001b[0m     \n\u001b[0;32m      8\u001b[0m }\n\u001b[0;32m     10\u001b[0m \u001b[38;5;66;03m# Call the train_model function with the config\u001b[39;00m\n\u001b[1;32m---> 11\u001b[0m \u001b[43mtrain_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43mMODEL_CONFIG\u001b[49m\u001b[43m)\u001b[49m\n",
      "Cell \u001b[1;32mIn[1], line 111\u001b[0m, in \u001b[0;36mtrain_model\u001b[1;34m(config)\u001b[0m\n\u001b[0;32m    101\u001b[0m model\u001b[38;5;241m.\u001b[39mcompile(optimizer\u001b[38;5;241m=\u001b[39mopt,\n\u001b[0;32m    102\u001b[0m               loss\u001b[38;5;241m=\u001b[39m{\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mCE\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcategorical_crossentropy\u001b[39m\u001b[38;5;124m'\u001b[39m},\n\u001b[0;32m    103\u001b[0m               metrics\u001b[38;5;241m=\u001b[39m{\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mCE\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;124m'\u001b[39m\u001b[38;5;124maccuracy\u001b[39m\u001b[38;5;124m'\u001b[39m})\n\u001b[0;32m    105\u001b[0m reporter \u001b[38;5;241m=\u001b[39m getMIOutput(trn\u001b[38;5;241m=\u001b[39mx_1train,\n\u001b[0;32m    106\u001b[0m                        tst\u001b[38;5;241m=\u001b[39mx_1test,\n\u001b[0;32m    107\u001b[0m                        Z_layer_idx\u001b[38;5;241m=\u001b[39mconfig[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mz_idx\u001b[39m\u001b[38;5;124m\"\u001b[39m],  \u001b[38;5;66;03m# Use z_idx from config\u001b[39;00m\n\u001b[0;32m    108\u001b[0m                        num_selection\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m10000\u001b[39m,\n\u001b[0;32m    109\u001b[0m                        do_save_func\u001b[38;5;241m=\u001b[39mdo_report_MN)\n\u001b[1;32m--> 111\u001b[0m history \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mx_1train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mY_1train\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    112\u001b[0m \u001b[43m                    
\u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconfig\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mbatch_size\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m  \u001b[49m\u001b[38;5;66;43;03m# Use batch size from config\u001b[39;49;00m\n\u001b[0;32m    113\u001b[0m \u001b[43m                    \u001b[49m\u001b[43mepochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconfig\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mepoch\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m  \u001b[49m\u001b[38;5;66;43;03m# Use number of epochs from config\u001b[39;49;00m\n\u001b[0;32m    114\u001b[0m \u001b[43m                    \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m    115\u001b[0m \u001b[43m                    \u001b[49m\u001b[43mvalidation_data\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mx_1test\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mY_1test\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    116\u001b[0m \u001b[43m                    \u001b[49m\u001b[43mcallbacks\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[43mreporter\u001b[49m\u001b[43m,\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    118\u001b[0m \u001b[38;5;66;03m# Print the final generalization gap (train accuracy - test accuracy / train loss - test loss)\u001b[39;00m\n\u001b[0;32m    119\u001b[0m final_train_acc \u001b[38;5;241m=\u001b[39m history\u001b[38;5;241m.\u001b[39mhistory[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124maccuracy\u001b[39m\u001b[38;5;124m'\u001b[39m][\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\n",
      "File \u001b[1;32m~\\anaconda3\\envs\\py39\\lib\\site-packages\\keras\\utils\\traceback_utils.py:65\u001b[0m, in \u001b[0;36mfilter_traceback.<locals>.error_handler\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m     63\u001b[0m filtered_tb \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m     64\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m---> 65\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m fn(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m     66\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[0;32m     67\u001b[0m     filtered_tb \u001b[38;5;241m=\u001b[39m _process_traceback_frames(e\u001b[38;5;241m.\u001b[39m__traceback__)\n",
      "File \u001b[1;32m~\\anaconda3\\envs\\py39\\lib\\site-packages\\keras\\engine\\training.py:1564\u001b[0m, in \u001b[0;36mModel.fit\u001b[1;34m(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)\u001b[0m\n\u001b[0;32m   1556\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m tf\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mexperimental\u001b[38;5;241m.\u001b[39mTrace(\n\u001b[0;32m   1557\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtrain\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m   1558\u001b[0m     epoch_num\u001b[38;5;241m=\u001b[39mepoch,\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m   1561\u001b[0m     _r\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m,\n\u001b[0;32m   1562\u001b[0m ):\n\u001b[0;32m   1563\u001b[0m     callbacks\u001b[38;5;241m.\u001b[39mon_train_batch_begin(step)\n\u001b[1;32m-> 1564\u001b[0m     tmp_logs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_function\u001b[49m\u001b[43m(\u001b[49m\u001b[43miterator\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1565\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m data_handler\u001b[38;5;241m.\u001b[39mshould_sync:\n\u001b[0;32m   1566\u001b[0m         context\u001b[38;5;241m.\u001b[39masync_wait()\n",
      "File \u001b[1;32m~\\anaconda3\\envs\\py39\\lib\\site-packages\\tensorflow\\python\\util\\traceback_utils.py:150\u001b[0m, in \u001b[0;36mfilter_traceback.<locals>.error_handler\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m    148\u001b[0m filtered_tb \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m    149\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m--> 150\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m fn(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m    151\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[0;32m    152\u001b[0m   filtered_tb \u001b[38;5;241m=\u001b[39m _process_traceback_frames(e\u001b[38;5;241m.\u001b[39m__traceback__)\n",
      "File \u001b[1;32m~\\anaconda3\\envs\\py39\\lib\\site-packages\\tensorflow\\python\\eager\\def_function.py:915\u001b[0m, in \u001b[0;36mFunction.__call__\u001b[1;34m(self, *args, **kwds)\u001b[0m\n\u001b[0;32m    912\u001b[0m compiler \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mxla\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_jit_compile \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnonXla\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m    914\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m OptionalXlaContext(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_jit_compile):\n\u001b[1;32m--> 915\u001b[0m   result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwds)\n\u001b[0;32m    917\u001b[0m new_tracing_count \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mexperimental_get_tracing_count()\n\u001b[0;32m    918\u001b[0m without_tracing \u001b[38;5;241m=\u001b[39m (tracing_count \u001b[38;5;241m==\u001b[39m new_tracing_count)\n",
      "File \u001b[1;32m~\\anaconda3\\envs\\py39\\lib\\site-packages\\tensorflow\\python\\eager\\def_function.py:947\u001b[0m, in \u001b[0;36mFunction._call\u001b[1;34m(self, *args, **kwds)\u001b[0m\n\u001b[0;32m    944\u001b[0m   \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_lock\u001b[38;5;241m.\u001b[39mrelease()\n\u001b[0;32m    945\u001b[0m   \u001b[38;5;66;03m# In this case we have created variables on the first call, so we run the\u001b[39;00m\n\u001b[0;32m    946\u001b[0m   \u001b[38;5;66;03m# defunned version which is guaranteed to never create variables.\u001b[39;00m\n\u001b[1;32m--> 947\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_stateless_fn(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwds)  \u001b[38;5;66;03m# pylint: disable=not-callable\u001b[39;00m\n\u001b[0;32m    948\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_stateful_fn \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m    949\u001b[0m   \u001b[38;5;66;03m# Release the lock early so that multiple threads can perform the call\u001b[39;00m\n\u001b[0;32m    950\u001b[0m   \u001b[38;5;66;03m# in parallel.\u001b[39;00m\n\u001b[0;32m    951\u001b[0m   \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_lock\u001b[38;5;241m.\u001b[39mrelease()\n",
      "File \u001b[1;32m~\\anaconda3\\envs\\py39\\lib\\site-packages\\tensorflow\\python\\eager\\function.py:2496\u001b[0m, in \u001b[0;36mFunction.__call__\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   2493\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_lock:\n\u001b[0;32m   2494\u001b[0m   (graph_function,\n\u001b[0;32m   2495\u001b[0m    filtered_flat_args) \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_maybe_define_function(args, kwargs)\n\u001b[1;32m-> 2496\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mgraph_function\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_flat\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m   2497\u001b[0m \u001b[43m    \u001b[49m\u001b[43mfiltered_flat_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcaptured_inputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgraph_function\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcaptured_inputs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32m~\\anaconda3\\envs\\py39\\lib\\site-packages\\tensorflow\\python\\eager\\function.py:1862\u001b[0m, in \u001b[0;36mConcreteFunction._call_flat\u001b[1;34m(self, args, captured_inputs, cancellation_manager)\u001b[0m\n\u001b[0;32m   1858\u001b[0m possible_gradient_type \u001b[38;5;241m=\u001b[39m gradients_util\u001b[38;5;241m.\u001b[39mPossibleTapeGradientTypes(args)\n\u001b[0;32m   1859\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (possible_gradient_type \u001b[38;5;241m==\u001b[39m gradients_util\u001b[38;5;241m.\u001b[39mPOSSIBLE_GRADIENT_TYPES_NONE\n\u001b[0;32m   1860\u001b[0m     \u001b[38;5;129;01mand\u001b[39;00m executing_eagerly):\n\u001b[0;32m   1861\u001b[0m   \u001b[38;5;66;03m# No tape is watching; skip to running the function.\u001b[39;00m\n\u001b[1;32m-> 1862\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_build_call_outputs(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_inference_function\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcall\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m   1863\u001b[0m \u001b[43m      \u001b[49m\u001b[43mctx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcancellation_manager\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcancellation_manager\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[0;32m   1864\u001b[0m forward_backward \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_select_forward_and_backward_functions(\n\u001b[0;32m   1865\u001b[0m     args,\n\u001b[0;32m   1866\u001b[0m     possible_gradient_type,\n\u001b[0;32m   1867\u001b[0m     executing_eagerly)\n\u001b[0;32m   1868\u001b[0m forward_function, args_with_tangents \u001b[38;5;241m=\u001b[39m forward_backward\u001b[38;5;241m.\u001b[39mforward()\n",
      "File \u001b[1;32m~\\anaconda3\\envs\\py39\\lib\\site-packages\\tensorflow\\python\\eager\\function.py:499\u001b[0m, in \u001b[0;36m_EagerDefinedFunction.call\u001b[1;34m(self, ctx, args, cancellation_manager)\u001b[0m\n\u001b[0;32m    497\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m _InterpolateFunctionError(\u001b[38;5;28mself\u001b[39m):\n\u001b[0;32m    498\u001b[0m   \u001b[38;5;28;01mif\u001b[39;00m cancellation_manager \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m--> 499\u001b[0m     outputs \u001b[38;5;241m=\u001b[39m \u001b[43mexecute\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexecute\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m    500\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mstr\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msignature\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    501\u001b[0m \u001b[43m        \u001b[49m\u001b[43mnum_outputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_num_outputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    502\u001b[0m \u001b[43m        \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    503\u001b[0m \u001b[43m        \u001b[49m\u001b[43mattrs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattrs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    504\u001b[0m \u001b[43m        \u001b[49m\u001b[43mctx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mctx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    505\u001b[0m   \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m    506\u001b[0m     outputs \u001b[38;5;241m=\u001b[39m execute\u001b[38;5;241m.\u001b[39mexecute_with_cancellation(\n\u001b[0;32m    507\u001b[0m         
\u001b[38;5;28mstr\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msignature\u001b[38;5;241m.\u001b[39mname),\n\u001b[0;32m    508\u001b[0m         num_outputs\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_outputs,\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m    511\u001b[0m         ctx\u001b[38;5;241m=\u001b[39mctx,\n\u001b[0;32m    512\u001b[0m         cancellation_manager\u001b[38;5;241m=\u001b[39mcancellation_manager)\n",
      "File \u001b[1;32m~\\anaconda3\\envs\\py39\\lib\\site-packages\\tensorflow\\python\\eager\\execute.py:54\u001b[0m, in \u001b[0;36mquick_execute\u001b[1;34m(op_name, num_outputs, inputs, attrs, ctx, name)\u001b[0m\n\u001b[0;32m     52\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m     53\u001b[0m   ctx\u001b[38;5;241m.\u001b[39mensure_initialized()\n\u001b[1;32m---> 54\u001b[0m   tensors \u001b[38;5;241m=\u001b[39m \u001b[43mpywrap_tfe\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mTFE_Py_Execute\u001b[49m\u001b[43m(\u001b[49m\u001b[43mctx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_handle\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mop_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m     55\u001b[0m \u001b[43m                                      \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattrs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_outputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     56\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m core\u001b[38;5;241m.\u001b[39m_NotOkStatusException \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[0;32m     57\u001b[0m   \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Hyper-parameters for this run; train_model() reads exactly these keys.\n",
    "MODEL_CONFIG = dict(\n",
    "    optimizer=\"SGD\",\n",
    "    lr=0.0005,\n",
    "    batch_size=128,\n",
    "    epoch=200,\n",
    "    z_idx=2,\n",
    ")\n",
    "\n",
    "# Start training with the configuration above\n",
    "train_model(MODEL_CONFIG)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea47f7b5-fd39-4bb0-ae1d-4e7b84b7212c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check of the data pipeline: show the shape of each returned array.\n",
    "# (The cell previously read `importgenerate_data()`, which raises NameError.)\n",
    "[split.shape for split in generate_data()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92a57f9d-81ef-44d6-b0e3-a5c9c6a8e0f5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
