{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"main.ipynb","provenance":[],"collapsed_sections":[],"machine_shape":"hm"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"mNdZRfeYIZoC","executionInfo":{"status":"ok","timestamp":1636293769938,"user_tz":-420,"elapsed":10670,"user":{"displayName":"hồng sơn lương","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"05966228294099207788"}},"outputId":"869ae020-3e9a-4d8a-ca6a-5dbcf8862bf3"},"source":["!pip install from-root\n","!pip install emnist\n","!pip install -U PyYAML"],"execution_count":1,"outputs":[{"output_type":"stream","name":"stdout","text":["Collecting from-root\n","  Downloading from_root-1.0.2-py3-none-any.whl (6.5 kB)\n","Installing collected packages: from-root\n","Successfully installed from-root-1.0.2\n","Collecting emnist\n","  Downloading emnist-0.0-py3-none-any.whl (7.3 kB)\n","Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from emnist) (4.62.3)\n","Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from emnist) (2.23.0)\n","Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from emnist) (1.19.5)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->emnist) (2021.5.30)\n","Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->emnist) (2.10)\n","Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->emnist) (3.0.4)\n","Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->emnist) (1.24.3)\n","Installing collected packages: emnist\n","Successfully installed emnist-0.0\n","Requirement already satisfied: PyYAML in /usr/local/lib/python3.7/dist-packages (3.13)\n","Collecting PyYAML\n","  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)\n","\u001b[K     |████████████████████████████████| 596 kB 5.0 MB/s \n","\u001b[?25hInstalling collected packages: PyYAML\n","  Attempting uninstall: PyYAML\n","    Found existing installation: PyYAML 3.13\n","    Uninstalling PyYAML-3.13:\n","      Successfully uninstalled PyYAML-3.13\n","Successfully installed PyYAML-6.0\n"]}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"fK_DoXmzIh2_","executionInfo":{"status":"ok","timestamp":1636293803861,"user_tz":-420,"elapsed":32052,"user":{"displayName":"hồng sơn lương","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"05966228294099207788"}},"outputId":"72fa57cd-d193-4ceb-c0c7-2a68bbf9211f"},"source":["from google.colab import drive\n","drive.mount('/content/drive')\n","%cd /content/drive/Shareddrives/Duong-Son/FL-project/fl_learning/\n","!ls"],"execution_count":2,"outputs":[{"output_type":"stream","name":"stdout","text":["Mounted at /content/drive\n","/content/drive/Shareddrives/Duong-Son/FL-project/fl_learning\n","export_env.sh  main.py\tpackages\t      utils\n","main.ipynb     model\tsystem_configuration\n"]}]},{"cell_type":"code","metadata":{"id":"nPpL81A-PgMA","executionInfo":{"status":"ok","timestamp":1636293810125,"user_tz":-420,"elapsed":6272,"user":{"displayName":"hồng sơn lương","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"05966228294099207788"}}},"source":["# from-root\n","import sys\n","from from_root import from_root\n","sys.path.append(str(from_root()))\n","#\n","from utils.header import *\n","from packages.dataset.dataset import JointKDDataset\n","from packages.engine.derivative_nw import TeacherNetwork, StudentNetwork\n","from utils.c_reliability import CReliability\n","from utils.helper_function import *"],"execution_count":3,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"F0fzMbNYPjXG","executionInfo":{"status":"ok","timestamp":1635953366175,"user_tz":-540,"elapsed":61409,"user":{"displayName":"Dương Minh","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"17555676595979985239"}},"outputId":"ba1b4bc1-fecb-43ca-bc28-9546f3d0fe9d"},"source":["start=datetime.now()\n","# get configure\n","g_config = GeneralConfigure()\n","# prepare dataset\n","dataset = JointKDDataset()\n","cluster_data = dataset.preProcessData()\n","list_of_samples_region_train, list_of_samples_region_test = dataset.assignDataForClients(cluster_data)\n","\n","# get student network\n","student_model = StudentNetwork()\n","student  = keras.models.Model(inputs  = student_model.model.input,                      \n","                              outputs = student_model.model.get_layer('logits').output)\n","student.compile(optimizer='adam',\n","                loss='sparse_categorical_crossentropy', \n","                metrics=['accuracy'])\n","# define distilation loss function\n","distillation_loss_fn = keras.losses.KLDivergence()\n","# create teacher model\n","train_able = False\n","if train_able == True:\n","  teacher_params = []\n","  teacher = []\n","  for region in range(g_config.regions):\n","    teacher.append(TeacherNetwork(list_of_samples_region_train[region+1],\n","                                        list_of_samples_region_test[region+1],\n","                                        list_of_samples_region_train[0],\n","                                        list_of_samples_region_test[0],\n","                                        student,\n","                                        distillation_loss_fn,\n","                                        g_config.model_path, region, dataset.num_classes))\n","    teacher_param, _ = teacher[region].regionalAggregation()\n","    teacher_params.append(teacher_param)\n","else:\n","  pass\n","# C-Reliability \n","c_reliability = CReliability(list_of_samples_region_train[0][0], dataset.num_classes)\n","beta = c_reliability.weightedClass()\n","\n","dataset_server = processDataServer(list_of_samples_region_train[0][0])"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["num_of_clients: 0  -  13395\n","num_of_clients: 1  -  446\n","num_of_clients: 2  -  3024\n","num_of_clients: 3  -  3025\n"]}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":314},"id":"tU8OMmnWSMna","executionInfo":{"status":"error","timestamp":1635952590268,"user_tz":-420,"elapsed":21023,"user":{"displayName":"hồng sơn lương","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"05966228294099207788"}},"outputId":"587fa17c-0c69-413b-f474-3a6bcfaa1f0f"},"source":["for epoch in range(g_config.distil_epochs):\n","  print(\"Start of epoch %d\" % (epoch,))\n","  for step, batch_train in enumerate(dataset_server):\n","    loss = 0\n","    x_batch_tf, y_batch_tf = batch_train\n","    x_batch = x_batch_tf.numpy()\n","    y_batch = np.argmax(y_batch_tf.numpy(), axis=1)\n","    # print(len(x_batch))\n","    # print(len(y_batch))\n","    for region in range(g_config.regions):\n","      teacher_model = tf.keras.models.load_model(os.path.join(c_reliability.model_path,f\"teacher_region_{region}.h5\"),compile=False)\n","      logit_teacher_model = keras.models.Model(inputs  = teacher_model.input, outputs = teacher_model.get_layer('logits').output)\n","      logit_teacher_model.compile(optimizer='adam',\n","                                  loss='sparse_categorical_crossentropy',\n","                                  metrics=['sparse_categorical_accuracy'])\n","      logits_predict = logit_teacher_model(x_batch, training=False)\n","      softmax_predict = softmax(logits_predict)\n","      rounded_predict = np.argmax(softmax_predict, axis = 1)\n","      pseudo_dataset = list(zip(x_batch, rounded_predict, y_batch))\n","\n","      aligned_data_teacher, aligned_label_teacher = dataAlignment(pseudo_dataset, dataset.num_classes)\n","      teacher_predictions = {}\n","      for label in range(dataset.num_classes):\n","        # Calculate individual loss by \n","        # Forward pass of teacher\n","        teacher_predictions[label] = logit_teacher_model(aligned_data_teacher[label].reshape(-1,28,28,1), training=False)\n","      with tf.GradientTape() as tape:\n","        # Calculate each label-driven loss\n","        # Forward pass of student\n","        for label in range(dataset.num_classes):\n","          student_predictions = student(aligned_data_teacher[label].reshape(-1,28,28,1), training=True)\n","          # Compute losses at $label$ round\n","          # print(aligned_label_teacher[label])\n","          # print(student_predictions)\n","          student_loss = g_config.student_loss_fn(aligned_label_teacher[label], student_predictions)\n","          distillation_loss = beta[region][label] * distillation_loss_fn(tf.nn.softmax(teacher_predictions[label] / 20, axis=1),\n","                                                      tf.nn.softmax(student_predictions        / 20, axis=1))\n","          if label == 0: \n","              concat_student_predictions = student_predictions\n","          else:\n","              concat_student_predictions = tf.concat([concat_student_predictions,student_predictions], 0)\n","          loss += g_config.alpha * student_loss + (1 - g_config.alpha) * distillation_loss\n","    # Compute gradients\n","    trainable_vars = student.trainable_variables\n","    gradients = tape.gradient(loss, trainable_vars)\n","\n","    # Update weights\n","    g_config.optimizer.apply_gradients(zip(gradients, trainable_vars))\n","    if step % 2 == 0:\n","      print(f\"step {step}: studentloss = {student_loss}, distillation loss = {loss}, accuracy = {g_config.accuracy_metric.result()}\")\n","      student.evaluate(x = list_of_samples_region_test[0][0], y = list_of_samples_region_test[0][1])\n","stop = datetime.now()\n","print('Total Time: ', stop - start)"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Start of epoch 0\n"]},{"output_type":"error","ename":"ResourceExhaustedError","evalue":"ignored","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mResourceExhaustedError\u001b[0m                    Traceback (most recent call last)","\u001b[0;32m<ipython-input-16-6c19fd7f017b>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     14\u001b[0m                                   \u001b[0mloss\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'sparse_categorical_crossentropy'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     15\u001b[0m                                   metrics=['sparse_categorical_accuracy'])\n\u001b[0;32m---> 16\u001b[0;31m       \u001b[0mlogits_predict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlogit_teacher_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_batch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraining\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     17\u001b[0m       \u001b[0msoftmax_predict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msoftmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlogits_predict\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     18\u001b[0m       \u001b[0mrounded_predict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msoftmax_predict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/keras/engine/base_layer.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1035\u001b[0m         with autocast_variable.enable_auto_cast_variables(\n\u001b[1;32m   1036\u001b[0m             self._compute_dtype_object):\n\u001b[0;32m-> 1037\u001b[0;31m           \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcall_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1038\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1039\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_activity_regularizer\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/keras/engine/functional.py\u001b[0m in \u001b[0;36mcall\u001b[0;34m(self, inputs, training, mask)\u001b[0m\n\u001b[1;32m    413\u001b[0m     \"\"\"\n\u001b[1;32m    414\u001b[0m     return self._run_internal_graph(\n\u001b[0;32m--> 415\u001b[0;31m         inputs, training=training, mask=mask)\n\u001b[0m\u001b[1;32m    416\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    417\u001b[0m   \u001b[0;32mdef\u001b[0m \u001b[0mcompute_output_shape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_shape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/keras/engine/functional.py\u001b[0m in \u001b[0;36m_run_internal_graph\u001b[0;34m(self, inputs, training, mask)\u001b[0m\n\u001b[1;32m    548\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    549\u001b[0m         \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmap_arguments\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor_dict\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 550\u001b[0;31m         \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    551\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    552\u001b[0m         \u001b[0;31m# Update tensor_dict.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/keras/engine/base_layer.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1035\u001b[0m         with autocast_variable.enable_auto_cast_variables(\n\u001b[1;32m   1036\u001b[0m             self._compute_dtype_object):\n\u001b[0;32m-> 1037\u001b[0;31m           \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcall_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1038\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1039\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_activity_regularizer\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/keras/layers/convolutional.py\u001b[0m in \u001b[0;36mcall\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m    247\u001b[0m       \u001b[0minputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compute_causal_padding\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    248\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 249\u001b[0;31m     \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_convolution_op\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkernel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    250\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    251\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muse_bias\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/dispatch.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    204\u001b[0m     \u001b[0;34m\"\"\"Call target, and fall back on dispatchers if there is a TypeError.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    205\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 206\u001b[0;31m       \u001b[0;32mreturn\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    207\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mTypeError\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    208\u001b[0m       \u001b[0;31m# Note: convert_to_eager_tensor currently raises a ValueError, not a\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/nn_ops.py\u001b[0m in \u001b[0;36mconvolution_v2\u001b[0;34m(input, filters, strides, padding, data_format, dilations, name)\u001b[0m\n\u001b[1;32m   1136\u001b[0m       \u001b[0mdata_format\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdata_format\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1137\u001b[0m       \u001b[0mdilations\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdilations\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1138\u001b[0;31m       name=name)\n\u001b[0m\u001b[1;32m   1139\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1140\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/nn_ops.py\u001b[0m in \u001b[0;36mconvolution_internal\u001b[0;34m(input, filters, strides, padding, data_format, dilations, name, call_from_convolution, num_spatial_dims)\u001b[0m\n\u001b[1;32m   1266\u001b[0m           \u001b[0mdata_format\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdata_format\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1267\u001b[0m           \u001b[0mdilations\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdilations\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1268\u001b[0;31m           name=name)\n\u001b[0m\u001b[1;32m   1269\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1270\u001b[0m       \u001b[0;32mif\u001b[0m \u001b[0mchannel_index\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/nn_ops.py\u001b[0m in \u001b[0;36m_conv2d_expanded_batch\u001b[0;34m(input, filters, strides, padding, data_format, dilations, name)\u001b[0m\n\u001b[1;32m   2720\u001b[0m         \u001b[0mdata_format\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdata_format\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2721\u001b[0m         \u001b[0mdilations\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdilations\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2722\u001b[0;31m         name=name)\n\u001b[0m\u001b[1;32m   2723\u001b[0m   return squeeze_batch_dims(\n\u001b[1;32m   2724\u001b[0m       \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/gen_nn_ops.py\u001b[0m in \u001b[0;36mconv2d\u001b[0;34m(input, filter, strides, padding, use_cudnn_on_gpu, explicit_paddings, data_format, dilations, name)\u001b[0m\n\u001b[1;32m    930\u001b[0m       \u001b[0;32mreturn\u001b[0m \u001b[0m_result\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    931\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0m_core\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_NotOkStatusException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 932\u001b[0;31m       \u001b[0m_ops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mraise_from_not_ok_status\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    933\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0m_core\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_FallbackException\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    934\u001b[0m       \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py\u001b[0m in \u001b[0;36mraise_from_not_ok_status\u001b[0;34m(e, name)\u001b[0m\n\u001b[1;32m   6939\u001b[0m   \u001b[0mmessage\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmessage\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m\" name: \"\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;34m\"\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   6940\u001b[0m   \u001b[0;31m# pylint: disable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 6941\u001b[0;31m   \u001b[0msix\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mraise_from\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_status_to_exception\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmessage\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   6942\u001b[0m   \u001b[0;31m# pylint: enable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   6943\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/six.py\u001b[0m in \u001b[0;36mraise_from\u001b[0;34m(value, from_value)\u001b[0m\n","\u001b[0;31mResourceExhaustedError\u001b[0m: OOM when allocating tensor with shape[7000,128,28,28] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc [Op:Conv2D]"]}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":429},"id":"t9BttHXTIikO","executionInfo":{"status":"error","timestamp":1636293633194,"user_tz":-420,"elapsed":145924,"user":{"displayName":"hồng sơn lương","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"05966228294099207788"}},"outputId":"35af1946-713a-4d40-8cdd-b3000ebbfbda"},"source":["def main():\n","    start=datetime.now()\n","    # get configure\n","    g_config = GeneralConfigure()\n","    # prepare dataset\n","    dataset = JointKDDataset()\n","    cluster_data = dataset.preProcessData()\n","    list_of_samples_region_train, list_of_samples_region_test = dataset.assignDataForClients(cluster_data)\n","    \n","    # get student network\n","    student_model = StudentNetwork()\n","    student  = keras.models.Model(inputs  = student_model.model.input,                      \n","                                 outputs = student_model.model.get_layer('logits').output)\n","    student.compile(optimizer='adam',\n","                    loss='sparse_categorical_crossentropy', \n","                    metrics=['accuracy'])\n","    # define distilation loss function\n","    distillation_loss_fn = keras.losses.KLDivergence()\n","    # create teacher model\n","    train_able = False\n","    if train_able == True:\n","        teacher_params = []\n","        teacher = []\n","        for region in range(g_config.regions):\n","            teacher.append(TeacherNetwork(list_of_samples_region_train[region+1],\n","                                        list_of_samples_region_test[region+1],\n","                                        list_of_samples_region_train[0],\n","                                        list_of_samples_region_test[0],\n","                                        student,\n","                                        distillation_loss_fn,\n","                                        g_config.model_path, region, dataset.num_classes))\n","            teacher_param, _ = teacher[region].regionalAggregation()\n","            teacher_params.append(teacher_param)\n","    else:\n","        pass\n","    # C-Reliability \n","    c_reliability = CReliability(list_of_samples_region_train[0][0], dataset.num_classes)\n","    beta = c_reliability.weightedClass()\n","\n","    dataset_server = processDataServer(list_of_samples_region_train[0][0])\n","    for epoch in range(g_config.distil_epochs):\n","        print(\"Start of epoch %d\" % (epoch,))\n","        for step, batch_train in enumerate(dataset_server):\n","            loss = 0\n","            x_batch_tf, y_batch_tf = batch_train\n","            x_batch = x_batch_tf.numpy()\n","            y_batch = np.argmax(y_batch_tf.numpy(), axis=1)\n","            # print(len(x_batch))\n","            # print(len(y_batch))\n","            for region in range(g_config.regions):\n","                teacher_model = tf.keras.models.load_model(os.path.join(c_reliability.model_path,f\"teacher_region_{region}.h5\"),compile=False)\n","                logit_teacher_model = keras.models.Model(inputs  = teacher_model.input, outputs = teacher_model.get_layer('logits').output)\n","                logit_teacher_model.compile(optimizer='adam',\n","                                            loss='sparse_categorical_crossentropy',\n","                                            metrics=['sparse_categorical_accuracy'])\n","                logits_predict = logit_teacher_model(x_batch, training=False)\n","                softmax_predict = softmax(logits_predict)\n","                rounded_predict = np.argmax(softmax_predict, axis = 1)\n","                pseudo_dataset = list(zip(x_batch, rounded_predict, y_batch))\n","\n","                aligned_data_teacher, aligned_label_teacher = dataAlignment(pseudo_dataset, dataset.num_classes)\n","                teacher_predictions = {}\n","                for label in range(dataset.num_classes):\n","                    # Calculate individual loss by \n","                    # Forward pass of teacher\n","                    teacher_predictions[label] = logit_teacher_model(aligned_data_teacher[label].reshape(-1,28,28,1), training=False)\n","                with tf.GradientTape() as tape:\n","                    # Calculate each label-driven loss\n","                    # Forward pass of student\n","                    for label in range(dataset.num_classes):\n","                        student_predictions = student(aligned_data_teacher[label].reshape(-1,28,28,1), training=True)\n","                        # Compute losses at $label$ round\n","                        student_loss = g_config.student_loss_fn(aligned_label_teacher[label], student_predictions)\n","                        distillation_loss = beta[region][label] * distillation_loss_fn(tf.nn.softmax(teacher_predictions[label] / 20, axis=1),\n","                                                                    tf.nn.softmax(student_predictions        / 20, axis=1))\n","                        if label == 0: \n","                            concat_student_predictions = student_predictions\n","                        else:\n","                            concat_student_predictions = tf.concat([concat_student_predictions,student_predictions], 0)\n","                        loss += g_config.alpha * student_loss + (1 - g_config.alpha) * distillation_loss\n","            # Compute gradients\n","            trainable_vars = student.trainable_variables\n","            gradients = tape.gradient(loss, trainable_vars)\n","\n","            # Update weights\n","            g_config.optimizer.apply_gradients(zip(gradients, trainable_vars))       \n","\n","            if step % 2 == 0:\n","                print(f\"step {step}: studentloss = {student_loss}, distillation loss = {loss}, accuracy = {g_config.accuracy_metric.result()}\")\n","                student.evaluate(x = list_of_samples_region_test[1][0][0], y = np.argmax(list_of_samples_region_test[1][0][1], axis=1))\n","\n","    stop = datetime.now()\n","    print('Total Time: ', stop - start)\n","\n","\n","if __name__ == \"__main__\":\n","    os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'\n","    main()"],"execution_count":4,"outputs":[{"output_type":"stream","name":"stderr","text":["Downloading emnist.zip: 536MB [00:05, 95.6MB/s]\n"]},{"output_type":"stream","name":"stdout","text":["num_of_clients: 0  -  13395\n","num_of_clients: 1  -  446\n","num_of_clients: 2  -  3024\n","num_of_clients: 3  -  3025\n","Start of epoch 0\n"]},{"output_type":"error","ename":"KeyError","evalue":"ignored","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mKeyError\u001b[0m                                  Traceback (most recent call last)","\u001b[0;32m<ipython-input-4-73c101a2d4fd>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     96\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m__name__\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"__main__\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     97\u001b[0m     \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menviron\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'TF_FORCE_GPU_ALLOW_GROWTH'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'true'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 98\u001b[0;31m     \u001b[0mmain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m","\u001b[0;32m<ipython-input-4-73c101a2d4fd>\u001b[0m in \u001b[0;36mmain\u001b[0;34m()\u001b[0m\n\u001b[1;32m     71\u001b[0m                         \u001b[0mstudent_predictions\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mstudent\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maligned_data_teacher\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mlabel\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m28\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m28\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraining\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     72\u001b[0m                         \u001b[0;31m# Compute losses at $label$ round\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 73\u001b[0;31m                         \u001b[0mstudent_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mg_config\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstudent_loss_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maligned_label_teacher\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mlabel\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstudent_predictions\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     74\u001b[0m                         distillation_loss = beta[region][label] * distillation_loss_fn(tf.nn.softmax(teacher_predictions[label] / 20, axis=1),\n\u001b[1;32m     75\u001b[0m                                                                     tf.nn.softmax(student_predictions        / 20, axis=1))\n","\u001b[0;31mKeyError\u001b[0m: 2"]}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"APBeW0Iqrq_8","outputId":"d3ba44f8-1a11-4c36-87dd-71a1a1d39790"},"source":["from utils.header import *\n","from packages.dataset.dataset import JointKDDataset\n","from packages.engine.derivative_nw import TeacherNetwork, StudentNetwork\n","from utils.c_reliability import CReliability\n","from utils.helper_function import *\n","\n","def main():\n","    start=datetime.now()\n","    # get configure\n","    g_config = GeneralConfigure()\n","    # prepare dataset\n","    dataset = JointKDDataset()\n","    cluster_data = dataset.preProcessData()\n","    list_of_samples_region_train, list_of_samples_region_test = dataset.assignDataForClients(cluster_data)\n","    \n","    # get student network\n","    student_model = StudentNetwork()\n","    student  = keras.models.Model(inputs  = student_model.model.input,                      \n","                                 outputs = student_model.model.get_layer('logits').output)\n","    student.compile(optimizer='adam',\n","                    loss='sparse_categorical_crossentropy', \n","                    metrics=['accuracy'])\n","    # define distilation loss function\n","    distillation_loss_fn = keras.losses.KLDivergence()\n","    # create teacher model\n","    train_able = False\n","    if train_able == True:\n","        teacher_params = []\n","        teacher = []\n","        for region in range(g_config.regions):\n","            teacher.append(TeacherNetwork(list_of_samples_region_train[region+1],\n","                                        list_of_samples_region_test[region+1],\n","                                        list_of_samples_region_train[0],\n","                                        list_of_samples_region_test[0],\n","                                        student,\n","                                        distillation_loss_fn,\n","                                        g_config.model_path, region, dataset.num_classes))\n","            teacher_param, _ = teacher[region].regionalAggregation()\n","            teacher_params.append(teacher_param)\n","    else:\n","        pass\n","    # C-Reliability \n","    c_reliability = CReliability(list_of_samples_region_train[0][0], dataset.num_classes)\n","    beta = c_reliability.weightedClass()\n","\n","    dataset_server = processDataServer(list_of_samples_region_train[0][0])\n","    for epoch in range(g_config.distil_epochs):\n","        print(\"Start of epoch %d\" % (epoch,))\n","        for step, batch_train in enumerate(dataset_server):\n","            loss = 0\n","            x_batch_tf, y_batch_tf = batch_train\n","            x_batch = x_batch_tf.numpy()\n","            y_batch = np.argmax(y_batch_tf.numpy(), axis=1)\n","            # print(len(x_batch))\n","            # print(len(y_batch))\n","            for region in range(g_config.regions):\n","                teacher_model = tf.keras.models.load_model(os.path.join(c_reliability.model_path,f\"teacher_region_{region}.h5\"),compile=False)\n","                logit_teacher_model = keras.models.Model(inputs  = teacher_model.input, outputs = teacher_model.get_layer('logits').output)\n","                logit_teacher_model.compile(optimizer='adam',\n","                                            loss='sparse_categorical_crossentropy',\n","                                            metrics=['sparse_categorical_accuracy'])\n","                logits_predict = logit_teacher_model(x_batch, training=False)\n","                softmax_predict = softmax(logits_predict)\n","                rounded_predict = np.argmax(softmax_predict, axis = 1)\n","                pseudo_dataset = list(zip(x_batch, rounded_predict, y_batch))\n","\n","                aligned_data_teacher, aligned_label_teacher = dataAlignment(pseudo_dataset, dataset.num_classes)\n","                teacher_predictions = {}\n","                for label in range(dataset.num_classes):\n","                    # Calculate individual loss by \n","                    # Forward pass of teacher\n","                    teacher_predictions[label] = logit_teacher_model(aligned_data_teacher[label].reshape(-1,28,28,1), training=False)\n","                with tf.GradientTape() as tape:\n","                    # Calculate each label-driven loss\n","                    # Forward pass of student\n","                    for label in range(0, dataset.num_classes):\n","                        if len(aligned_data_teacher[label]) == 0:\n","                            distillation_loss = 0\n","                            student_loss = 0\n","                        else:\n","                            student_predictions = student(aligned_data_teacher[label].reshape(-1,28,28,1), training=True)\n","                            student_loss = g_config.student_loss_fn(aligned_label_teacher[label], student_predictions)\n","                            distillation_loss = beta[region][label] * distillation_loss_fn(tf.nn.softmax(teacher_predictions[label] / 20, axis=1),\n","                                                                    tf.nn.softmax(student_predictions        / 20, axis=1))\n","                            \n","                            if label == 0 or step == 0: \n","                                concat_student_predictions = student_predictions\n","                            else:\n","                                concat_student_predictions = tf.concat([concat_student_predictions,student_predictions], 0)\n","                        loss += g_config.alpha * student_loss + (1 - g_config.alpha) * distillation_loss\n","            # Compute gradients\n","            trainable_vars = student.trainable_variables\n","            gradients = tape.gradient(loss, trainable_vars)\n","\n","            # Update weights\n","            g_config.optimizer.apply_gradients(zip(gradients, trainable_vars))       \n","\n","            if step % 2 == 0:\n","                print(f\"step {step}: studentloss = {student_loss}, distillation loss = {loss}, accuracy = {g_config.accuracy_metric.result()}\")\n","                student.evaluate(x = list_of_samples_region_test[1][0][0], y = np.argmax(list_of_samples_region_test[1][0][1], axis=1))\n","\n","    stop = datetime.now()\n","    print('Total Time: ', stop - start)\n","\n","\n","if __name__ == \"__main__\":\n","    os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'\n","    main()"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stderr","text":["Downloading emnist.zip: 536MB [00:03, 155MB/s]\n"]},{"output_type":"stream","name":"stdout","text":["num_of_clients: 0  -  13395\n","num_of_clients: 1  -  446\n","num_of_clients: 2  -  3024\n","num_of_clients: 3  -  3025\n","Start of epoch 0\n","step 0: studentloss = 3.8191587924957275, distillation loss = 7.71936559677124, accuracy = 0.0\n","14/14 [==============================] - 1s 33ms/step - loss: 8.2787 - accuracy: 0.0695\n","step 2: studentloss = 3.7260489463806152, distillation loss = 7.10902214050293, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 8.1610 - accuracy: 0.0830\n","step 4: studentloss = 4.08432674407959, distillation loss = 7.162630558013916, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 6.5540 - accuracy: 0.0695\n","step 6: studentloss = 3.9012537002563477, distillation loss = 6.686339378356934, accuracy = 0.0\n","14/14 [==============================] - 0s 33ms/step - loss: 5.4848 - accuracy: 0.1996\n","step 8: studentloss = 3.6225669384002686, distillation loss = 5.72910213470459, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 4.0966 - accuracy: 0.3363\n","step 10: studentloss = 2.195136070251465, distillation loss = 5.183002948760986, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 4.1710 - accuracy: 0.2937\n","step 12: studentloss = 4.171525001525879, distillation loss = 5.315173625946045, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 3.5744 - accuracy: 0.4103\n","step 14: studentloss = 2.984687328338623, distillation loss = 3.9971299171447754, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 3.7063 - accuracy: 0.4238\n","step 16: studentloss = 0, distillation loss = 4.665457248687744, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 3.4120 - accuracy: 0.4439\n","step 18: studentloss = 2.5794944763183594, distillation loss = 4.272212982177734, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 3.1046 - accuracy: 0.5224\n","step 20: studentloss = 1.958707332611084, distillation loss = 3.881497383117676, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.9237 - accuracy: 0.4933\n","step 22: studentloss = 0.05403607711195946, distillation loss = 3.853257656097412, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.9630 - accuracy: 0.5291\n","step 24: studentloss = 1.4831351041793823, distillation loss = 3.8306753635406494, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.9821 - accuracy: 0.5673\n","step 26: studentloss = 2.035869598388672, distillation loss = 3.3127973079681396, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 3.0172 - accuracy: 0.5717\n","step 28: studentloss = 3.9262189865112305, distillation loss = 3.3605387210845947, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.9951 - accuracy: 0.5852\n","step 30: studentloss = 3.4732117652893066, distillation loss = 3.060673713684082, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 3.0354 - accuracy: 0.5740\n","step 32: studentloss = 6.24877405166626, distillation loss = 3.20624041557312, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.9311 - accuracy: 0.6054\n","step 34: studentloss = 3.101273536682129, distillation loss = 2.506582498550415, accuracy = 0.0\n","14/14 [==============================] - 0s 33ms/step - loss: 2.9490 - accuracy: 0.6099\n","step 36: studentloss = 2.829397201538086, distillation loss = 2.334752321243286, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.9459 - accuracy: 0.6413\n","step 38: studentloss = 2.5742852687835693, distillation loss = 2.3410637378692627, accuracy = 0.0\n","14/14 [==============================] - 0s 35ms/step - loss: 2.9030 - accuracy: 0.6233\n","step 40: studentloss = 3.036294937133789, distillation loss = 3.1050474643707275, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.9648 - accuracy: 0.6233\n","step 42: studentloss = 2.0585060119628906, distillation loss = 2.6840388774871826, accuracy = 0.0\n","14/14 [==============================] - 0s 35ms/step - loss: 2.9535 - accuracy: 0.6256\n","step 44: studentloss = 0.022435294464230537, distillation loss = 2.1309075355529785, accuracy = 0.0\n","14/14 [==============================] - 0s 35ms/step - loss: 2.8211 - accuracy: 0.6368\n","step 46: studentloss = 0, distillation loss = 2.1691527366638184, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.8031 - accuracy: 0.6704\n","step 48: studentloss = 0.9015950560569763, distillation loss = 2.605268716812134, accuracy = 0.0\n","14/14 [==============================] - 0s 35ms/step - loss: 2.7968 - accuracy: 0.6883\n","step 50: studentloss = 0.9407201409339905, distillation loss = 2.728671073913574, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.7955 - accuracy: 0.6771\n","step 52: studentloss = 0.6810805797576904, distillation loss = 2.518712043762207, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.8138 - accuracy: 0.6771\n","step 54: studentloss = 0.5776483416557312, distillation loss = 2.0172247886657715, accuracy = 0.0\n","14/14 [==============================] - 0s 35ms/step - loss: 2.8720 - accuracy: 0.6771\n","step 56: studentloss = 1.985278606414795, distillation loss = 2.403078317642212, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.8706 - accuracy: 0.6996\n","step 58: studentloss = 0, distillation loss = 2.1082355976104736, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.8931 - accuracy: 0.7063\n","step 60: studentloss = 0.5950157046318054, distillation loss = 2.270151376724243, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.8895 - accuracy: 0.7130\n","step 62: studentloss = 3.4442319869995117, distillation loss = 2.057637929916382, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.8472 - accuracy: 0.7220\n","step 64: studentloss = 4.018974781036377, distillation loss = 1.7721405029296875, accuracy = 0.0\n","14/14 [==============================] - 1s 35ms/step - loss: 2.8456 - accuracy: 0.7130\n","step 66: studentloss = 2.6868736743927, distillation loss = 2.038865089416504, accuracy = 0.0\n","14/14 [==============================] - 0s 35ms/step - loss: 2.8359 - accuracy: 0.7399\n","step 68: studentloss = 2.5742197036743164, distillation loss = 1.9774155616760254, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.8403 - accuracy: 0.7130\n","step 70: studentloss = 1.336325764656067, distillation loss = 1.9661903381347656, accuracy = 0.0\n","14/14 [==============================] - 1s 35ms/step - loss: 2.8437 - accuracy: 0.7242\n","step 72: studentloss = 1.4556376934051514, distillation loss = 1.4319987297058105, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.8590 - accuracy: 0.7332\n","step 74: studentloss = 1.3579570055007935, distillation loss = 1.559253215789795, accuracy = 0.0\n","14/14 [==============================] - 0s 35ms/step - loss: 2.8556 - accuracy: 0.7332\n","step 76: studentloss = 0.8662737607955933, distillation loss = 1.904473900794983, accuracy = 0.0\n","14/14 [==============================] - 0s 35ms/step - loss: 2.8552 - accuracy: 0.7466\n","step 78: studentloss = 0.43100348114967346, distillation loss = 1.8170123100280762, accuracy = 0.0\n","14/14 [==============================] - 1s 35ms/step - loss: 2.8332 - accuracy: 0.7534\n","step 80: studentloss = 0.007159008178859949, distillation loss = 1.9766261577606201, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.8253 - accuracy: 0.7578\n","step 82: studentloss = 0.3844163417816162, distillation loss = 1.8308039903640747, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.8439 - accuracy: 0.7601\n","step 84: studentloss = 0.39438652992248535, distillation loss = 1.5253403186798096, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.8284 - accuracy: 0.7422\n","step 86: studentloss = 0.13573719561100006, distillation loss = 1.9140123128890991, accuracy = 0.0\n","14/14 [==============================] - 0s 33ms/step - loss: 2.8324 - accuracy: 0.7623\n","step 88: studentloss = 0.17223544418811798, distillation loss = 1.514766812324524, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.8210 - accuracy: 0.7399\n","step 90: studentloss = 1.6409600973129272, distillation loss = 2.088489055633545, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.8505 - accuracy: 0.7377\n","step 92: studentloss = 0.0032877461053431034, distillation loss = 1.930861234664917, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.8563 - accuracy: 0.7646\n","step 94: studentloss = 4.875369071960449, distillation loss = 1.957887053489685, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.8989 - accuracy: 0.7489\n","step 96: studentloss = 0.15455377101898193, distillation loss = 1.6769365072250366, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.8604 - accuracy: 0.7556\n","step 98: studentloss = 1.4462553262710571, distillation loss = 1.6575192213058472, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.8538 - accuracy: 0.7803\n","step 100: studentloss = 0.08834616094827652, distillation loss = 1.53761887550354, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.8262 - accuracy: 0.7848\n","step 102: studentloss = 1.3528661727905273, distillation loss = 1.6651241779327393, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.8634 - accuracy: 0.7937\n","step 104: studentloss = 1.0328130722045898, distillation loss = 1.4892802238464355, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.8656 - accuracy: 0.8027\n","Start of epoch 1\n","step 0: studentloss = 0.009958011098206043, distillation loss = 1.5656688213348389, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.8651 - accuracy: 0.7915\n","step 2: studentloss = 0.03348323330283165, distillation loss = 1.4918272495269775, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.8605 - accuracy: 0.7870\n","step 4: studentloss = 0, distillation loss = 1.363111138343811, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.8291 - accuracy: 0.7780\n","step 6: studentloss = 0.7483108043670654, distillation loss = 1.260353684425354, accuracy = 0.0\n","14/14 [==============================] - 0s 35ms/step - loss: 2.8218 - accuracy: 0.7735\n","step 8: studentloss = 0.05495346710085869, distillation loss = 1.5930144786834717, accuracy = 0.0\n","14/14 [==============================] - 0s 35ms/step - loss: 2.8142 - accuracy: 0.7780\n","step 10: studentloss = 0.017883000895380974, distillation loss = 1.5084556341171265, accuracy = 0.0\n","14/14 [==============================] - 0s 35ms/step - loss: 2.8089 - accuracy: 0.7780\n","step 12: studentloss = 0, distillation loss = 1.3743647336959839, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.7972 - accuracy: 0.7915\n","step 14: studentloss = 0.011106324382126331, distillation loss = 1.351064920425415, accuracy = 0.0\n","14/14 [==============================] - 0s 35ms/step - loss: 2.7944 - accuracy: 0.7937\n","step 16: studentloss = 1.0208660364151, distillation loss = 1.8219324350357056, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.7922 - accuracy: 0.8161\n","step 18: studentloss = 0.4795767068862915, distillation loss = 1.3481330871582031, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.7868 - accuracy: 0.8184\n","step 20: studentloss = 1.840871810913086, distillation loss = 1.398158073425293, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.7853 - accuracy: 0.8139\n","step 22: studentloss = 1.7177187204360962, distillation loss = 1.6436841487884521, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.7871 - accuracy: 0.8072\n","step 24: studentloss = 2.0267457962036133, distillation loss = 1.2439801692962646, accuracy = 0.0\n","14/14 [==============================] - 0s 35ms/step - loss: 2.7919 - accuracy: 0.8229\n","step 26: studentloss = 0.002617150079458952, distillation loss = 1.2268844842910767, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.7964 - accuracy: 0.8094\n","step 28: studentloss = 0.533573567867279, distillation loss = 1.6433820724487305, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.8189 - accuracy: 0.8184\n","step 30: studentloss = 1.1640921831130981, distillation loss = 1.412760853767395, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.8078 - accuracy: 0.8117\n","step 32: studentloss = 0.22995667159557343, distillation loss = 1.2514835596084595, accuracy = 0.0\n","14/14 [==============================] - 0s 35ms/step - loss: 2.8081 - accuracy: 0.8049\n","step 34: studentloss = 0.17965975403785706, distillation loss = 1.4882394075393677, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.7702 - accuracy: 0.7825\n","step 36: studentloss = 0.11477375030517578, distillation loss = 1.2234344482421875, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.7600 - accuracy: 0.7937\n","step 38: studentloss = 0.4953252971172333, distillation loss = 1.3751802444458008, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.7593 - accuracy: 0.7937\n","step 40: studentloss = 0.7169988751411438, distillation loss = 1.4279909133911133, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.7665 - accuracy: 0.7735\n","step 42: studentloss = 1.111491084098816, distillation loss = 1.2579998970031738, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.7815 - accuracy: 0.8027\n","step 44: studentloss = 0.4872099757194519, distillation loss = 1.5147778987884521, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.7945 - accuracy: 0.8049\n","step 46: studentloss = 0.012679283507168293, distillation loss = 1.3932892084121704, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.8004 - accuracy: 0.8004\n","step 48: studentloss = 1.5057584047317505, distillation loss = 1.3584169149398804, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.8243 - accuracy: 0.8072\n","step 50: studentloss = 0.03701483830809593, distillation loss = 1.117074966430664, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.8171 - accuracy: 0.8094\n","step 52: studentloss = 3.1355528831481934, distillation loss = 1.3615750074386597, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.8215 - accuracy: 0.8184\n","step 54: studentloss = 0.041664667427539825, distillation loss = 1.4830960035324097, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.7982 - accuracy: 0.8027\n","step 56: studentloss = 0.3378978669643402, distillation loss = 1.2266757488250732, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.7901 - accuracy: 0.7780\n","step 58: studentloss = 0.051076821982860565, distillation loss = 1.3453096151351929, accuracy = 0.0\n","14/14 [==============================] - 0s 34ms/step - loss: 2.7747 - accuracy: 0.7848\n"]}]}]}