{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e06b5c1",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-20T03:50:21.151633Z",
     "start_time": "2024-11-20T03:50:21.148111Z"
    }
   },
   "outputs": [],
   "source": [
    "# Must run before TensorFlow is imported (next cell) so TF's native log level applies;\n",
    "# '2' presumably hides INFO and WARNING messages - confirm against the TF docs.\n",
    "import os\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c99fcb7",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-20T03:50:23.702787Z",
     "start_time": "2024-11-20T03:50:21.153603Z"
    }
   },
   "outputs": [],
   "source": [
    "from Linear_Separability_numpy import *\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "import numpy as np\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8814dd24",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-20T03:50:23.709592Z",
     "start_time": "2024-11-20T03:50:23.705331Z"
    }
   },
   "outputs": [],
   "source": [
    "from tensorflow.keras import layers, models, regularizers\n",
    "from tensorflow.keras.datasets import imdb\n",
    "from tensorflow.keras.preprocessing.sequence import pad_sequences"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d222e5e3",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e18e9a2",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-20T03:50:27.907281Z",
     "start_time": "2024-11-20T03:50:23.712983Z"
    }
   },
   "outputs": [],
   "source": [
    "# Vocabulary size passed to imdb.load_data and fixed review length (in tokens).\n",
    "max_features = 10000\n",
    "maxlen = 500\n",
    "\n",
    "(train_features, train_labels), (test_features, test_labels) = imdb.load_data(num_words=max_features)\n",
    "\n",
    "# Bring every review to exactly `maxlen` token ids.\n",
    "train_features, test_features = [\n",
    "    pad_sequences(reviews, maxlen=maxlen) for reviews in (train_features, test_features)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "181cbcbd",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-20T03:50:27.916645Z",
     "start_time": "2024-11-20T03:50:27.910221Z"
    }
   },
   "outputs": [],
   "source": [
    "# Class balance of the training labels (label 1 counted as positive, 0 as negative)\n",
    "unique, counts = np.unique(train_labels, return_counts=True)\n",
    "label_counts = dict(zip(unique, counts))\n",
    "positive_count = label_counts.get(1, 0)\n",
    "negative_count = label_counts.get(0, 0)\n",
    "print(f\"Training labels - Positive count: {positive_count}, Negative count: {negative_count}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d714474",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-20T03:50:27.922909Z",
     "start_time": "2024-11-20T03:50:27.918946Z"
    }
   },
   "outputs": [],
   "source": [
    "# Fixed probe subset of training reviews; the LS metrics are computed on it every epoch.\n",
    "# A seeded generator keeps the subset (and therefore every LS curve) reproducible across runs.\n",
    "subset_size = 200\n",
    "subset_seed = 42\n",
    "\n",
    "rng = np.random.default_rng(subset_seed)\n",
    "indices = rng.choice(len(train_features), size=subset_size, replace=False)\n",
    "\n",
    "subset_features = train_features[indices]\n",
    "subset_labels = train_labels[indices]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45ae285e",
   "metadata": {},
   "source": [
    "# Training parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a29cdd6d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-20T03:50:27.927450Z",
     "start_time": "2024-11-20T03:50:27.924755Z"
    }
   },
   "outputs": [],
   "source": [
    "# Mini-batch size shared by the train and test tf.data pipelines\n",
    "batch_size = 128"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "561327c9",
   "metadata": {},
   "source": [
    "# Network structure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e33651b4",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-20T03:50:30.931298Z",
     "start_time": "2024-11-20T03:50:27.929580Z"
    }
   },
   "outputs": [],
   "source": [
    "# Reproducible weight initialisation and shuffle order.\n",
    "tf.random.set_seed(42)\n",
    "\n",
    "# tf.data pipelines. Shuffle *before* batching so each epoch sees freshly mixed mini-batches;\n",
    "# shuffling after .batch() would only permute a fixed set of batches. A buffer as large as\n",
    "# the training set gives a full uniform shuffle.\n",
    "train_dataset = (tf.data.Dataset.from_tensor_slices((train_features, train_labels))\n",
    "                 .shuffle(buffer_size=len(train_features))\n",
    "                 .batch(batch_size))\n",
    "test_dataset = tf.data.Dataset.from_tensor_slices((test_features, test_labels)).batch(batch_size)\n",
    "\n",
    "model = tf.keras.Sequential([\n",
    "    layers.Embedding(input_dim=max_features, output_dim=128, input_length=maxlen),\n",
    "\n",
    "    layers.Conv1D(filters=20, kernel_size=3, activation='relu', padding='same'),\n",
    "    layers.MaxPooling1D(pool_size=2, padding='same'),\n",
    "\n",
    "    layers.Conv1D(filters=20, kernel_size=3, activation='relu', padding='same'),\n",
    "    layers.MaxPooling1D(pool_size=2, padding='same'),\n",
    "\n",
    "    layers.Flatten(),\n",
    "\n",
    "    layers.Dense(50, activation='sigmoid'),\n",
    "    #layers.Dropout(0.5),\n",
    "    layers.Dense(1, activation='sigmoid')\n",
    "])\n",
    "\n",
    "model.compile(loss='binary_crossentropy',\n",
    "              optimizer=tf.keras.optimizers.Adam(learning_rate=5e-6),\n",
    "              metrics=['accuracy'])\n",
    "\n",
    "# Side model exposing the output of every layer except Dropout/Flatten\n",
    "# (Flatten only reshapes its input; Dropout is the identity at inference time).\n",
    "layer_outputs = [\n",
    "    layer.output for layer in model.layers\n",
    "    if 'dropout' not in layer.name and 'flatten' not in layer.name\n",
    "]\n",
    "activation_model = tf.keras.models.Model(inputs=model.input, outputs=layer_outputs)\n",
    "\n",
    "# Sanity check: activations of the untrained network on the training probe subset.\n",
    "activations = activation_model.predict(subset_features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f1620d1",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-20T03:50:30.941938Z",
     "start_time": "2024-11-20T03:50:30.935709Z"
    }
   },
   "outputs": [],
   "source": [
    "# Symbolic output tensors used to build activation_model (one per recorded layer)\n",
    "layer_outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15138d0c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-20T03:50:30.949179Z",
     "start_time": "2024-11-20T03:50:30.944238Z"
    }
   },
   "outputs": [],
   "source": [
    "# Output shape of the first recorded layer (Embedding) on the probe subset\n",
    "activations[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3a5fd53",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-20T03:50:30.956151Z",
     "start_time": "2024-11-20T03:50:30.951368Z"
    }
   },
   "outputs": [],
   "source": [
    "# Output shape of the second recorded layer (first Conv1D) on the probe subset\n",
    "activations[1].shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "89588125",
   "metadata": {},
   "source": [
    "# Initialize record matrices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa0a7eaf",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-20T03:50:31.150480Z",
     "start_time": "2024-11-20T03:50:30.958428Z"
    }
   },
   "outputs": [],
   "source": [
    "# parameters (train)\n",
    "num_epochs = 100\n",
    "batch_ep = 1  # x-axis step between consecutive recorded epochs\n",
    "x_plot = np.arange(num_epochs) * batch_ep\n",
    "reserved_layers = 0  # leading activation_model outputs excluded from the LS records\n",
    "\n",
    "# Number of rows in the per-layer records. Read it from the model's outputs rather than\n",
    "# len(activation_model.predict(...)): with a single output, predict returns one array and\n",
    "# len() would give the batch size instead of the number of layers. This also avoids\n",
    "# pulling a training batch and running a forward pass just to count layers.\n",
    "num_recorded_layers = len(activation_model.outputs) - reserved_layers\n",
    "\n",
    "# Per-layer LS records: one row per recorded layer, one column per epoch.\n",
    "LS_star_squence = np.zeros((num_recorded_layers, num_epochs))\n",
    "LS_0_squence = np.zeros((num_recorded_layers, num_epochs))\n",
    "LS_1_squence = np.zeros((num_recorded_layers, num_epochs))\n",
    "LS_2_squence = np.zeros((num_recorded_layers, num_epochs))\n",
    "LDA_squence = np.zeros((num_recorded_layers, num_epochs))\n",
    "\n",
    "# Baseline LS of the raw input features (a single row).\n",
    "LS_star_squence_base = np.zeros((1, num_epochs))\n",
    "LS_0_squence_base = np.zeros((1, num_epochs))\n",
    "LS_1_squence_base = np.zeros((1, num_epochs))\n",
    "LS_2_squence_base = np.zeros((1, num_epochs))\n",
    "LDA_squence_base = np.zeros((1, num_epochs))\n",
    "\n",
    "# Per-epoch loss / accuracy on the full train and test sets.\n",
    "train_loss_squence = np.zeros((num_epochs,))\n",
    "train_accuracy_squence = np.zeros((num_epochs,))\n",
    "test_loss_squence = np.zeros((num_epochs,))\n",
    "test_accuracy_squence = np.zeros((num_epochs,))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "836d152f",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88db6281",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-20T04:02:09.695548Z",
     "start_time": "2024-11-20T03:50:31.157667Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "subset_labels_col = subset_labels.reshape(-1, 1)  # loop-invariant, reshape once\n",
    "num_recorded_layers = len(activation_model.outputs) - reserved_layers\n",
    "\n",
    "# LS of the original (input) data: computed once and broadcast across all epochs.\n",
    "(LS_star_squence_base[0, :], LS_0_squence_base[0, :], LS_1_squence_base[0, :],\n",
    " LS_2_squence_base[0, :], LDA_squence_base[0, :]) = W(tf.constant(subset_features), subset_labels_col)\n",
    "\n",
    "for iter_e in range(num_epochs):\n",
    "    # One epoch of training. No validation_data: the test set is evaluated explicitly\n",
    "    # below, so passing it to fit() as well would evaluate it twice per epoch.\n",
    "    model.fit(train_dataset, epochs=1)\n",
    "\n",
    "    # End-of-epoch loss/accuracy with the current weights (fit() reports a running\n",
    "    # average accumulated while the weights were still changing).\n",
    "    train_loss_squence[iter_e], train_accuracy_squence[iter_e] = model.evaluate(train_dataset, verbose=0)\n",
    "    test_loss_squence[iter_e], test_accuracy_squence[iter_e] = model.evaluate(test_dataset, verbose=0)\n",
    "\n",
    "    # LS of every recorded layer's output on the fixed probe subset.\n",
    "    layer_acts = activation_model(subset_features)\n",
    "    if not isinstance(layer_acts, (list, tuple)):\n",
    "        layer_acts = [layer_acts]  # a single-output model returns a bare tensor\n",
    "    for layers_i in range(num_recorded_layers):\n",
    "        (LS_star_squence[layers_i, iter_e], LS_0_squence[layers_i, iter_e], LS_1_squence[layers_i, iter_e],\n",
    "         LS_2_squence[layers_i, iter_e], LDA_squence[layers_i, iter_e]) = W(layer_acts[layers_i + reserved_layers], subset_labels_col)\n",
    "\n",
    "    print(f'********** epoch {iter_e + 1}/{num_epochs} finished - '\n",
    "          f'train loss {train_loss_squence[iter_e]:.4f}, acc {train_accuracy_squence[iter_e]:.4f} | '\n",
    "          f'test loss {test_loss_squence[iter_e]:.4f}, acc {test_accuracy_squence[iter_e]:.4f} **********')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01d1d6e0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-20T04:02:09.702868Z",
     "start_time": "2024-11-20T04:02:09.698991Z"
    }
   },
   "outputs": [],
   "source": [
    "import time\n",
    "import warnings\n",
    "\n",
    "# NOTE: this silences *all* Python warnings for the rest of the kernel session,\n",
    "# including any raised by the plotting helpers called below.\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bee1b2e3",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-20T04:02:09.710667Z",
     "start_time": "2024-11-20T04:02:09.705539Z"
    }
   },
   "outputs": [],
   "source": [
    "# Layer names from the helper (shown for reference); replaced by readable labels in the next cell.\n",
    "layer_name_list = get_layer_name(model)[reserved_layers:]\n",
    "layer_name_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebf3295a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-20T04:02:11.231913Z",
     "start_time": "2024-11-20T04:02:09.713201Z"
    }
   },
   "outputs": [],
   "source": [
    "# Human-readable labels, one per recorded layer (Flatten/Dropout outputs are not recorded).\n",
    "layer_name_list = ['1st_embedding','1st_conv1d','1st_max_pooling','2nd_conv1d','2nd_max_pooling','1st_dense','2nd_dense']\n",
    "# Guard against silently mislabelled curves if the architecture changes.\n",
    "assert len(layer_name_list) == LS_star_squence.shape[0], (\n",
    "    f'{len(layer_name_list)} labels for {LS_star_squence.shape[0]} recorded layers')\n",
    "\n",
    "Separability_figure = plot_Separability_figure(layer_name_list,x_plot,LS_star_squence,LS_0_squence,LS_1_squence,LS_2_squence,LDA_squence,LS_star_squence_base,LS_0_squence_base,LS_1_squence_base,LS_2_squence_base,LDA_squence_base)\n",
    "net_figure = plot_net_figure(layer_name_list,x_plot,train_loss_squence,train_accuracy_squence,test_loss_squence,test_accuracy_squence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "229871e6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-20T04:02:11.238559Z",
     "start_time": "2024-11-20T04:02:11.234108Z"
    }
   },
   "outputs": [],
   "source": [
    "# (recorded layers, epochs)\n",
    "LS_star_squence.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b51fcab",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-20T04:02:11.244939Z",
     "start_time": "2024-11-20T04:02:11.240972Z"
    }
   },
   "outputs": [],
   "source": [
    "# Bundle every array passed to the plotting helpers so the figures can be rebuilt later.\n",
    "info = dict(\n",
    "    layer_name_list=layer_name_list,\n",
    "    x_plot=x_plot,\n",
    "    LS_star_squence=LS_star_squence,\n",
    "    LS_0_squence=LS_0_squence,\n",
    "    LS_1_squence=LS_1_squence,\n",
    "    LS_2_squence=LS_2_squence,\n",
    "    LDA_squence=LDA_squence,\n",
    "    train_loss_squence=train_loss_squence,\n",
    "    train_accuracy_squence=train_accuracy_squence,\n",
    "    test_loss_squence=test_loss_squence,\n",
    "    test_accuracy_squence=test_accuracy_squence,\n",
    "    LS_star_squence_base=LS_star_squence_base,\n",
    "    LS_0_squence_base=LS_0_squence_base,\n",
    "    LS_1_squence_base=LS_1_squence_base,\n",
    "    LS_2_squence_base=LS_2_squence_base,\n",
    "    LDA_squence_base=LDA_squence_base,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff351e33",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-20T04:02:11.251486Z",
     "start_time": "2024-11-20T04:02:11.247234Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "\n",
    "def save_info(info, file_path):\n",
    "    \"\"\"Pickle `info` to `file_path`, creating the parent directory if it does not exist.\n",
    "\n",
    "    Only unpickle the resulting file from trusted sources: pickle.load can execute arbitrary code.\n",
    "    \"\"\"\n",
    "    parent_dir = os.path.dirname(file_path)\n",
    "    if parent_dir:  # a bare filename has no directory component to create\n",
    "        os.makedirs(parent_dir, exist_ok=True)\n",
    "    with open(file_path, 'wb') as file:\n",
    "        pickle.dump(info, file)\n",
    "\n",
    "save_info(info, '../saved_data/CNN.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fed698e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5341e3b8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tensorflow-lsm",
   "language": "python",
   "name": "tensorflow-lsm"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
