{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e38215d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-28T04:32:03.326508Z",
     "start_time": "2023-04-28T04:32:01.589536Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "import tensorflow_addons as tfa"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "335bf866",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-28T04:32:04.222783Z",
     "start_time": "2023-04-28T04:32:03.331987Z"
    }
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "## Prepare the data\n",
    "\"\"\"\n",
    "\n",
    "# Binary task: only the samples labelled 0 or 1 are kept.\n",
    "num_classes = 2\n",
    "input_shape = (32, 32, 3)\n",
    "\n",
    "(x_train_o, y_train_o), (x_test_o, y_test_o) = keras.datasets.cifar10.load_data()\n",
    "\n",
    "# First-axis indices of every sample whose label is 0 or 1.\n",
    "all_index_train = np.where(np.isin(y_train_o, (0, 1)))[0]\n",
    "all_index_test = np.where(np.isin(y_test_o, (0, 1)))[0]\n",
    "\n",
    "# Seeded draw of 500 training indices for the separability measurements.\n",
    "# np.random.randint samples with replacement, so duplicates are possible.\n",
    "np.random.seed(1234)\n",
    "LS_index = all_index_train[np.random.randint(0, len(all_index_train), 500)]\n",
    "Separability_images = x_train_o[LS_index]\n",
    "Separability_labels = y_train_o[LS_index]\n",
    "\n",
    "# Restrict both splits to the two selected classes.\n",
    "x_train, y_train = x_train_o[all_index_train], y_train_o[all_index_train]\n",
    "x_test, y_test = x_test_o[all_index_test], y_test_o[all_index_test]\n",
    "\n",
    "print(f\"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}\")\n",
    "print(f\"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b114103c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-28T04:32:04.240224Z",
     "start_time": "2023-04-28T04:32:04.235762Z"
    }
   },
   "outputs": [],
   "source": [
    "\n",
    "\"\"\"\n",
    "## Configure the hyperparameters\n",
    "\"\"\"\n",
    "\n",
    "learning_rate = 0.001\n",
    "weight_decay = 0.0001\n",
    "batch_size = 256\n",
    "num_epochs = 100\n",
    "image_size = 72  # We'll resize input images to this size\n",
    "patch_size = 6  # Size of the patches to be extract from the input images\n",
    "num_patches = (image_size // patch_size) ** 2\n",
    "projection_dim = 64\n",
    "num_heads = 4\n",
    "transformer_units = [\n",
    "    projection_dim * 2,\n",
    "    projection_dim,\n",
    "]  # Size of the transformer layers\n",
    "transformer_layers = 5\n",
    "mlp_head_units = [2048, 1024]  # Size of the dense layers of the final classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0214e2fb",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-28T04:32:05.552127Z",
     "start_time": "2023-04-28T04:32:04.242295Z"
    }
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "## Use data augmentation\n",
    "\"\"\"\n",
    "\n",
    "# Preprocessing and augmentation steps, applied in the listed order.\n",
    "augmentation_steps = [\n",
    "    layers.Normalization(),\n",
    "    layers.Resizing(image_size, image_size),\n",
    "    layers.RandomFlip(\"horizontal\"),\n",
    "    layers.RandomRotation(factor=0.02),\n",
    "    layers.RandomZoom(height_factor=0.2, width_factor=0.2),\n",
    "]\n",
    "data_augmentation = keras.Sequential(augmentation_steps, name=\"data_augmentation\")\n",
    "\n",
    "# The Normalization layer (first step) learns its mean and variance\n",
    "# from the training images.\n",
    "augmentation_steps[0].adapt(x_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bee1591f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-28T04:32:05.563562Z",
     "start_time": "2023-04-28T04:32:05.556430Z"
    }
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "## Implement multilayer perceptron (MLP)\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "def mlp(x, hidden_units, dropout_rate):\n",
    "    \"\"\"Stack one GELU Dense layer plus one Dropout layer per entry of hidden_units.\n",
    "\n",
    "    Args:\n",
    "        x: Input tensor the first Dense layer is applied to.\n",
    "        hidden_units: Iterable with the number of units of each Dense layer.\n",
    "        dropout_rate: Rate passed to every Dropout layer.\n",
    "\n",
    "    Returns:\n",
    "        The output tensor of the last Dropout layer (x itself when\n",
    "        hidden_units is empty).\n",
    "    \"\"\"\n",
    "    output = x\n",
    "    for width in hidden_units:\n",
    "        dense = layers.Dense(width, activation=tf.nn.gelu)\n",
    "        dropout = layers.Dropout(dropout_rate)\n",
    "        output = dropout(dense(output))\n",
    "    return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9adf4df8",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-28T04:32:05.577348Z",
     "start_time": "2023-04-28T04:32:05.567475Z"
    }
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "## Implement patch creation as a layer\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "class Patches(layers.Layer):\n",
    "    \"\"\"Layer that cuts a batch of images into non-overlapping square patches.\n",
    "\n",
    "    Args:\n",
    "        patch_size: Side length, in pixels, of each square patch.\n",
    "        **kwargs: Standard Keras layer arguments (e.g. name, dtype),\n",
    "            forwarded to layers.Layer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, patch_size, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.patch_size = patch_size\n",
    "\n",
    "    def call(self, images):\n",
    "        \"\"\"Return the patches of `images` reshaped to [batch, -1, patch_dims].\"\"\"\n",
    "        batch_size = tf.shape(images)[0]\n",
    "        # Stride equals the patch size, so patches do not overlap; with VALID\n",
    "        # padding any border that does not fill a whole patch is dropped.\n",
    "        patches = tf.image.extract_patches(\n",
    "            images=images,\n",
    "            sizes=[1, self.patch_size, self.patch_size, 1],\n",
    "            strides=[1, self.patch_size, self.patch_size, 1],\n",
    "            rates=[1, 1, 1, 1],\n",
    "            padding=\"VALID\",\n",
    "        )\n",
    "        # Static size of one flattened patch (last axis of extract_patches).\n",
    "        patch_dims = patches.shape[-1]\n",
    "        patches = tf.reshape(patches, [batch_size, -1, patch_dims])\n",
    "        return patches\n",
    "\n",
    "    def get_config(self):\n",
    "        \"\"\"Return the layer config including patch_size.\n",
    "\n",
    "        Without this override Keras cannot serialize or clone a model that\n",
    "        contains this layer, because the constructor takes an extra argument.\n",
    "        \"\"\"\n",
    "        config = super().get_config()\n",
    "        config.update({\"patch_size\": self.patch_size})\n",
    "        return config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "687a44b4",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-28T04:32:09.338036Z",
     "start_time": "2023-04-28T04:32:05.580137Z"
    }
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Let's display patches for a sample image\n",
    "\"\"\"\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Show one training image picked at random.\n",
    "plt.figure(figsize=(4, 4))\n",
    "image = x_train[np.random.choice(range(x_train.shape[0]))]\n",
    "plt.imshow(image.astype(\"uint8\"))\n",
    "plt.axis(\"off\")\n",
    "\n",
    "# Resize it (as a batch of one) to the model resolution and cut it into patches.\n",
    "image_batch = tf.convert_to_tensor([image])\n",
    "resized_image = tf.image.resize(image_batch, size=(image_size, image_size))\n",
    "patches = Patches(patch_size)(resized_image)\n",
    "for message in (\n",
    "    f\"Image size: {image_size} X {image_size}\",\n",
    "    f\"Patch size: {patch_size} X {patch_size}\",\n",
    "    f\"Patches per image: {patches.shape[1]}\",\n",
    "    f\"Elements per patch: {patches.shape[-1]}\",\n",
    "):\n",
    "    print(message)\n",
    "\n",
    "# One subplot per patch on an n x n grid.\n",
    "n = int(np.sqrt(patches.shape[1]))\n",
    "plt.figure(figsize=(4, 4))\n",
    "for position, patch in enumerate(patches[0], start=1):\n",
    "    plt.subplot(n, n, position)\n",
    "    patch_pixels = tf.reshape(patch, (patch_size, patch_size, 3)).numpy()\n",
    "    plt.imshow(patch_pixels.astype(\"uint8\"))\n",
    "    plt.axis(\"off\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db52bccf",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-28T04:32:09.347653Z",
     "start_time": "2023-04-28T04:32:09.342253Z"
    }
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "## Implement the patch encoding layer\n",
    "The `PatchEncoder` layer will linearly transform a patch by projecting it into a\n",
    "vector of size `projection_dim`. In addition, it adds a learnable position\n",
    "embedding to the projected vector.\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "class PatchEncoder(layers.Layer):\n",
    "    \"\"\"Project each patch linearly and add a learnable position embedding.\n",
    "\n",
    "    Args:\n",
    "        num_patches: Number of patch positions (size of the embedding table).\n",
    "        projection_dim: Output size of both the projection and the embedding.\n",
    "        **kwargs: Standard Keras layer arguments (e.g. name, dtype),\n",
    "            forwarded to layers.Layer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_patches, projection_dim, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.num_patches = num_patches\n",
    "        # Kept only so that get_config can reproduce the constructor call.\n",
    "        self.projection_dim = projection_dim\n",
    "        self.projection = layers.Dense(units=projection_dim)\n",
    "        self.position_embedding = layers.Embedding(\n",
    "            input_dim=num_patches, output_dim=projection_dim\n",
    "        )\n",
    "\n",
    "    def call(self, patch):\n",
    "        \"\"\"Return projection(patch) plus the embedding of positions 0..num_patches-1.\"\"\"\n",
    "        positions = tf.range(start=0, limit=self.num_patches, delta=1)\n",
    "        encoded = self.projection(patch) + self.position_embedding(positions)\n",
    "        return encoded\n",
    "\n",
    "    def get_config(self):\n",
    "        \"\"\"Return the layer config including both constructor arguments.\n",
    "\n",
    "        Without this override Keras cannot serialize or clone a model that\n",
    "        contains this layer, because the constructor takes extra arguments.\n",
    "        \"\"\"\n",
    "        config = super().get_config()\n",
    "        config.update(\n",
    "            {\n",
    "                \"num_patches\": self.num_patches,\n",
    "                \"projection_dim\": self.projection_dim,\n",
    "            }\n",
    "        )\n",
    "        return config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9a5b8e6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-28T04:32:09.357133Z",
     "start_time": "2023-04-28T04:32:09.349890Z"
    }
   },
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "def create_vit_classifier():\n",
    "    \"\"\"Assemble the ViT classifier from the module-level hyperparameters.\n",
    "\n",
    "    Returns:\n",
    "        An uncompiled keras.Model that maps images of shape input_shape to\n",
    "        num_classes logits (the last Dense layer has no activation).\n",
    "    \"\"\"\n",
    "    image_input = layers.Input(shape=input_shape)\n",
    "    # Augmentation is part of the graph, followed by patching and encoding.\n",
    "    augmented = data_augmentation(image_input)\n",
    "    patch_sequence = Patches(patch_size)(augmented)\n",
    "    tokens = PatchEncoder(num_patches, projection_dim)(patch_sequence)\n",
    "\n",
    "    # Pre-norm transformer blocks: attention and MLP, each with a residual.\n",
    "    for _ in range(transformer_layers):\n",
    "        normed = layers.LayerNormalization(epsilon=1e-6)(tokens)\n",
    "        attention = layers.MultiHeadAttention(\n",
    "            num_heads=num_heads, key_dim=projection_dim, dropout=0.1\n",
    "        )\n",
    "        # Self-attention: the same tensor is used as query and value.\n",
    "        attended = attention(normed, normed)\n",
    "        after_attention = layers.Add()([attended, tokens])\n",
    "        hidden = layers.LayerNormalization(epsilon=1e-6)(after_attention)\n",
    "        hidden = mlp(hidden, hidden_units=transformer_units, dropout_rate=0.1)\n",
    "        tokens = layers.Add()([hidden, after_attention])\n",
    "\n",
    "    # Normalize, flatten all patch tokens into one vector per image, regularize.\n",
    "    representation = layers.LayerNormalization(epsilon=1e-6)(tokens)\n",
    "    representation = layers.Flatten()(representation)\n",
    "    representation = layers.Dropout(0.5)(representation)\n",
    "    # Classification head.\n",
    "    features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)\n",
    "    logits = layers.Dense(num_classes)(features)\n",
    "    return keras.Model(inputs=image_input, outputs=logits)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9c90298",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-28T04:32:10.473710Z",
     "start_time": "2023-04-28T04:32:09.358911Z"
    }
   },
   "outputs": [],
   "source": [
    "# Build the classifier (it is compiled later, before training) and print its layer table.\n",
    "vit_classifier = create_vit_classifier()\n",
    "vit_classifier.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fc430c5",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-28T04:32:10.485069Z",
     "start_time": "2023-04-28T04:32:10.476198Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Layers whose name contains any of these fragments are not probed.\n",
    "excluded_name_parts = ('input', 'data_augmentation', 'random_zoom')\n",
    "probed_layers = [\n",
    "    layer\n",
    "    for layer in vit_classifier.layers\n",
    "    if not any(part in layer.name for part in excluded_name_parts)\n",
    "]\n",
    "layer_outputs = [layer.output for layer in probed_layers]\n",
    "layer_names = [layer.name for layer in probed_layers]\n",
    "\n",
    "# Auxiliary model: classifier input -> output of every probed layer.\n",
    "activation_model = tf.keras.models.Model(\n",
    "    inputs=vit_classifier.input, outputs=layer_outputs\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d52ad3f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-28T04:32:17.051946Z",
     "start_time": "2023-04-28T04:32:10.487095Z"
    }
   },
   "outputs": [],
   "source": [
    "from Linear_Separability_numpy import *\n",
    "\n",
    "# parameters (train)\n",
    "num_epochs = 100\n",
    "batch_ep = 1\n",
    "x_plot = np.arange(num_epochs) * batch_ep\n",
    "reserved_layers = 0\n",
    "\n",
    "# x holds the activation model's outputs for the separability subset;\n",
    "# its length sizes the per-layer record matrices below.\n",
    "x = activation_model.predict(Separability_images)\n",
    "record_shape = (len(x) - reserved_layers, num_epochs)\n",
    "\n",
    "# Per-layer, per-epoch records of the four values returned by W.\n",
    "LS_1_squence = np.zeros(record_shape)\n",
    "LS_2_squence = np.zeros(record_shape)\n",
    "J_w_squence = np.zeros(record_shape)\n",
    "LDA_squence = np.zeros(record_shape)\n",
    "\n",
    "# Single-row records of the same four values for the baseline.\n",
    "base_shape = (1, num_epochs)\n",
    "LS_1_squence_base = np.zeros(base_shape)\n",
    "LS_2_squence_base = np.zeros(base_shape)\n",
    "J_w_squence_base = np.zeros(base_shape)\n",
    "LDA_squence_base = np.zeros(base_shape)\n",
    "\n",
    "# Per-epoch loss and accuracy on the train and test sets.\n",
    "train_loss_squence = np.zeros(num_epochs)\n",
    "train_accuracy_squence = np.zeros(num_epochs)\n",
    "test_loss_squence = np.zeros(num_epochs)\n",
    "test_accuracy_squence = np.zeros(num_epochs)\n",
    "\n",
    "# Disabled on purpose (presumably only needed for one-hot labels - TODO confirm):\n",
    "#Separability_labels = np.argmax(Separability_labels,axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87cf5734",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-28T04:32:17.511012Z",
     "start_time": "2023-04-28T04:32:17.056090Z"
    }
   },
   "outputs": [],
   "source": [
    "# Baseline: W evaluated on the raw images scaled by 1/255; its four\n",
    "# return values are stored in row 0 of the *_base records.\n",
    "base_ls_1, base_ls_2, base_j_w, base_lda = W(\n",
    "    tf.constant(Separability_images / 255.0),\n",
    "    Separability_labels.reshape(-1, 1),\n",
    ")\n",
    "LS_1_squence_base[0, :] = base_ls_1\n",
    "LS_2_squence_base[0, :] = base_ls_2\n",
    "J_w_squence_base[0, :] = base_j_w\n",
    "LDA_squence_base[0, :] = base_lda"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0157de2",
   "metadata": {
    "ExecuteTime": {
     "start_time": "2023-04-28T04:31:15.044Z"
    }
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "## Compile, train, and evaluate the model\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "\n",
    "optimizer = tfa.optimizers.AdamW(learning_rate=learning_rate, weight_decay=weight_decay)\n",
    "\n",
    "# The model's last Dense layer has no activation, hence from_logits=True.\n",
    "vit_classifier.compile(\n",
    "    optimizer=optimizer,\n",
    "    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "    metrics=[keras.metrics.SparseCategoricalAccuracy(name=\"accuracy\")],\n",
    ")\n",
    "\n",
    "# Keep only the weights with the best validation accuracy.\n",
    "checkpoint_filepath = \"/tmp/checkpoint\"\n",
    "checkpoint_callback = keras.callbacks.ModelCheckpoint(\n",
    "    checkpoint_filepath,\n",
    "    save_weights_only=True,\n",
    "    save_best_only=True,\n",
    "    monitor=\"val_accuracy\",\n",
    ")\n",
    "\n",
    "for iter_e in range(num_epochs):\n",
    "    # One epoch per fit() call so measurements can be taken between epochs.\n",
    "    history = vit_classifier.fit(\n",
    "        x_train,\n",
    "        y_train,\n",
    "        batch_size=batch_size,\n",
    "        epochs=1,\n",
    "        validation_split=0.1,\n",
    "        callbacks=[checkpoint_callback],\n",
    "    )\n",
    "\n",
    "    # Outputs of every probed layer for the fixed separability subset.\n",
    "    x = activation_model(Separability_images)\n",
    "\n",
    "    # Loss and accuracy after this epoch.\n",
    "    epoch_train_loss, epoch_train_accuracy = vit_classifier.evaluate(x_train, y_train)\n",
    "    train_loss_squence[iter_e] = epoch_train_loss\n",
    "    train_accuracy_squence[iter_e] = epoch_train_accuracy\n",
    "    epoch_test_loss, epoch_test_accuracy = vit_classifier.evaluate(x_test, y_test)\n",
    "    test_loss_squence[iter_e] = epoch_test_loss\n",
    "    test_accuracy_squence[iter_e] = epoch_test_accuracy\n",
    "\n",
    "    # The four values of W for every recorded layer output.\n",
    "    for layers_i in range(len(x) - reserved_layers):\n",
    "        layer_ls_1, layer_ls_2, layer_j_w, layer_lda = W(\n",
    "            x[layers_i + reserved_layers], Separability_labels.reshape(-1, 1)\n",
    "        )\n",
    "        LS_1_squence[layers_i, iter_e] = layer_ls_1\n",
    "        LS_2_squence[layers_i, iter_e] = layer_ls_2\n",
    "        J_w_squence[layers_i, iter_e] = layer_j_w\n",
    "        LDA_squence[layers_i, iter_e] = layer_lda\n",
    "\n",
    "    print('**********'+'the',iter_e,'epochs has finished'+'**********')\n",
    "\n",
    "                \n",
    "\n",
    "# vit_classifier.load_weights(checkpoint_filepath)\n",
    "# _, accuracy, top_5_accuracy = vit_classifier.evaluate(x_test, y_test)\n",
    "# print(f\"Test accuracy: {round(accuracy * 100, 2)}%\")\n",
    "# print(f\"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "540b976c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-28T12:04:31.222949Z",
     "start_time": "2023-04-28T11:59:04.041167Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# NOTE(review): range(99, num_epochs) re-runs only epoch index 99 (looks like a\n",
    "# manual resume - confirm). On a full top-to-bottom run this trains one extra\n",
    "# epoch and overwrites column 99 of every record.\n",
    "for iter_e in range(99, num_epochs):\n",
    "    # One epoch per fit() call so measurements can be taken between epochs.\n",
    "    history = vit_classifier.fit(\n",
    "        x_train,\n",
    "        y_train,\n",
    "        batch_size=batch_size,\n",
    "        epochs=1,\n",
    "        validation_split=0.1,\n",
    "        callbacks=[checkpoint_callback],\n",
    "    )\n",
    "\n",
    "    # Outputs of every probed layer for the fixed separability subset.\n",
    "    x = activation_model(Separability_images)\n",
    "\n",
    "    # Loss and accuracy after this epoch.\n",
    "    epoch_train_loss, epoch_train_accuracy = vit_classifier.evaluate(x_train, y_train)\n",
    "    train_loss_squence[iter_e] = epoch_train_loss\n",
    "    train_accuracy_squence[iter_e] = epoch_train_accuracy\n",
    "    epoch_test_loss, epoch_test_accuracy = vit_classifier.evaluate(x_test, y_test)\n",
    "    test_loss_squence[iter_e] = epoch_test_loss\n",
    "    test_accuracy_squence[iter_e] = epoch_test_accuracy\n",
    "\n",
    "    # The four values of W for every recorded layer output.\n",
    "    for layers_i in range(len(x) - reserved_layers):\n",
    "        layer_ls_1, layer_ls_2, layer_j_w, layer_lda = W(\n",
    "            x[layers_i + reserved_layers], Separability_labels.reshape(-1, 1)\n",
    "        )\n",
    "        LS_1_squence[layers_i, iter_e] = layer_ls_1\n",
    "        LS_2_squence[layers_i, iter_e] = layer_ls_2\n",
    "        J_w_squence[layers_i, iter_e] = layer_j_w\n",
    "        LDA_squence[layers_i, iter_e] = layer_lda\n",
    "\n",
    "    print('**********'+'the',iter_e,'epochs has finished'+'**********')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69370a8b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-28T12:53:01.908492Z",
     "start_time": "2023-04-28T12:53:01.905421Z"
    }
   },
   "outputs": [],
   "source": [
    "import warnings\n",
    "import time\n",
    "# Suppress every Python warning from this point on.\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19776ee7",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-28T12:53:10.896692Z",
     "start_time": "2023-04-28T12:53:02.390081Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Layer names for the plots, skipping the first `reserved_layers` entries.\n",
    "layer_name_list = get_layer_name(vit_classifier)[reserved_layers:]\n",
    "\n",
    "# Per-layer records together with their baselines.\n",
    "Separability_figure = plot_Separability_figure(\n",
    "    layer_name_list,\n",
    "    x_plot,\n",
    "    LS_1_squence,\n",
    "    LS_2_squence,\n",
    "    J_w_squence,\n",
    "    LDA_squence,\n",
    "    LS_1_squence_base,\n",
    "    LS_2_squence_base,\n",
    "    J_w_squence_base,\n",
    "    LDA_squence_base,\n",
    ")\n",
    "\n",
    "# Train/test loss and accuracy per epoch.\n",
    "net_figure = plot_net_figure(\n",
    "    layer_name_list,\n",
    "    x_plot,\n",
    "    train_loss_squence,\n",
    "    train_accuracy_squence,\n",
    "    test_loss_squence,\n",
    "    test_accuracy_squence,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82465e0a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-28T12:53:15.858372Z",
     "start_time": "2023-04-28T12:53:15.844378Z"
    }
   },
   "outputs": [],
   "source": [
    "# Bundle every recorded series into one dictionary keyed by variable name.\n",
    "info = dict(\n",
    "    layer_name_list=layer_name_list,\n",
    "    x_plot=x_plot,\n",
    "    LS_1_squence=LS_1_squence,\n",
    "    LS_2_squence=LS_2_squence,\n",
    "    J_w_squence=J_w_squence,\n",
    "    LDA_squence=LDA_squence,\n",
    "    train_loss_squence=train_loss_squence,\n",
    "    train_accuracy_squence=train_accuracy_squence,\n",
    "    test_loss_squence=test_loss_squence,\n",
    "    test_accuracy_squence=test_accuracy_squence,\n",
    "    LS_1_squence_base=LS_1_squence_base,\n",
    "    LS_2_squence_base=LS_2_squence_base,\n",
    "    J_w_squence_base=J_w_squence_base,\n",
    "    LDA_squence_base=LDA_squence_base,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (tensorflow2.7)",
   "language": "python",
   "name": "test1"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
