{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "904c0281-f9f0-4d1f-b5ac-d718f3492786",
   "metadata": {},
   "source": [
    "### <span style=\"color:blue\"> *---------------------------------------*</span>\n",
    "# <span style=\"color:blue\"> *UDV project - SVD based pruning*</span>\n",
    "### <span style=\"color:blue\"> *---------------------------------------*</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "354b9fc0-342a-483c-bf55-a33ca67c2e22",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports libraries\n",
    "\n",
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, Subset\n",
    "import numpy as np\n",
    "\n",
    "from torchvision import transforms\n",
    "from torchvision import models\n",
    "from collections import OrderedDict\n",
    "\n",
    "import pickle\n",
    "import time\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78c9549f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Set the global seed for reproducibility\n",
    "\n",
    "public_seed = 529\n",
    "torch.manual_seed(public_seed)\n",
    "print(\"Current Seed is {0}\".format(public_seed))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e91d4b0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Obtain device\n",
    "\n",
    "from udvFunctions.udvDevice import get_device\n",
    "device = get_device()"
   ]
  },
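  {
   "cell_type": "code",
   "execution_count": null,
   "id": "get-device-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only (not the imported get_device): it is assumed to pick the GPU when one\n",
    "# is available and fall back to the CPU otherwise.\n",
    "def get_device_sketch():\n",
    "    return torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },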
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "707acd6c-c250-458f-953d-f18f4203f4ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load dataset\n",
    "\n",
    "from udvFunctions.udvClassDataset import MNIST_Pre"
   ]
  },
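  {
   "cell_type": "code",
   "execution_count": null,
   "id": "mnist-pre-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only (not the imported MNIST_Pre): a helper of this kind is assumed to return\n",
    "# the train/validation datasets plus the number of classes; the exact role of load_All in the real\n",
    "# implementation is not known, so it is ignored here.\n",
    "from torchvision import datasets\n",
    "\n",
    "def MNIST_Pre_sketch(load_All = True, data_path = \"./MNIST\"):\n",
    "    transform = transforms.ToTensor()\n",
    "    train_ds = datasets.MNIST(root = data_path, train = True, download = True, transform = transform)\n",
    "    val_ds = datasets.MNIST(root = data_path, train = False, download = True, transform = transform)\n",
    "    return train_ds, val_ds, 10"
   ]
  },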
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9089550-4d77-4c59-ac37-681700dc1cb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare the single-connection layer that carry weight matrix 'w'\n",
    "\n",
    "from udvFunctions.udvDiagonalLayer import D_singleConnection   "
   ]
  },
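  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d-singleconnection-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only (not the imported implementation): D_singleConnection is assumed to hold a\n",
    "# weight of shape (1, n) and scale each input feature by its own weight, i.e. y = w * x elementwise,\n",
    "# which is equivalent to multiplying by diag(w). The shape matches how diag1.weight is assigned\n",
    "# below (S_drop.unsqueeze(0)).\n",
    "class D_singleConnection_sketch(nn.Module):\n",
    "    def __init__(self, num_features):\n",
    "        super().__init__()\n",
    "        self.weight = nn.Parameter(torch.ones(1, num_features))\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Broadcast the (1, n) weight over the batch dimension\n",
    "        return x * self.weight"
   ]
  },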
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b41017af",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the validation loop\n",
    "\n",
    "from udvFunctions.udvClassVal import class_valLoop"
   ]
  },
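  {
   "cell_type": "code",
   "execution_count": null,
   "id": "class-valloop-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only (not the imported class_valLoop): a validation loop of this kind is assumed\n",
    "# to run the model in eval mode over val_loader and return the mean loss and accuracy.\n",
    "def class_valLoop_sketch(model, loss_fn, val_loader, device):\n",
    "    model.eval()\n",
    "    total_loss, total_correct, total_samples = 0.0, 0, 0\n",
    "    with torch.no_grad():\n",
    "        for inputs, labels in val_loader:\n",
    "            inputs, labels = inputs.to(device), labels.to(device)\n",
    "            outputs = model(inputs)\n",
    "            total_loss += loss_fn(outputs, labels).item() * labels.size(0)\n",
    "            total_correct += (outputs.argmax(dim = 1) == labels).sum().item()\n",
    "            total_samples += labels.size(0)\n",
    "    return total_loss / total_samples, total_correct / total_samples"
   ]
  },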
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69bd5c10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Other customised functions\n",
    "\n",
    "from udvFunctions.udvOtherFunctions import store_metrics, take_avg, check_shapes"
   ]
  },
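  {
   "cell_type": "code",
   "execution_count": null,
   "id": "check-shapes-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only (not the imported check_shapes): judging from how it is called below,\n",
    "# it is assumed to raise an error when a replacement tensor does not match the shape of the\n",
    "# model weight it is about to overwrite.\n",
    "def check_shapes_sketch(model_weight, list_sample):\n",
    "    if model_weight.shape != list_sample.shape:\n",
    "        raise ValueError(\"Shape mismatch: {0} vs {1}\".format(tuple(model_weight.shape), tuple(list_sample.shape)))"
   ]
  },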
  {
   "cell_type": "markdown",
   "id": "c0cd1610",
   "metadata": {},
   "source": [
    "### <span style=\"color:blue\"> *---------------------------------------*</span>\n",
    "# <span style=\"color:blue\"> *Set SVD-based pruning test (Without re-train):*</span>\n",
    "### <span style=\"color:blue\"> *---------------------------------------*</span>"
   ]
  },
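  {
   "cell_type": "markdown",
   "id": "svd-pruning-note",
   "metadata": {},
   "source": [
    "The cells below prune the factorised classifier without re-training. Writing the trained head as $y = V\\,\\mathrm{diag}(w)\\,U\\,x$ (with $U$ = `fc1.weight`, $w$ = `diag1.weight`, $V$ = `fc2.weight`, and assuming `diag1` applies an elementwise scaling, as its $(1, n)$ weight shape suggests), the code forms $M^{\\top} = (\\mathrm{diag}(w)\\,U)^{\\top}$ and takes its SVD:\n",
    "\n",
    "$$M^{\\top} = \\tilde{U}\\,\\mathrm{diag}(S)\\,\\tilde{V}^{\\top}, \\qquad y = V\\,\\tilde{V}\\,\\mathrm{diag}(S)\\,\\tilde{U}^{\\top}\\,x .$$\n",
    "\n",
    "Keeping only the largest $k$ singular values yields a hidden layer of width $k$: `fc1.weight` $\\leftarrow \\tilde{U}_k^{\\top}$, `diag1.weight` $\\leftarrow S_k$, `fc2.weight` $\\leftarrow V\\,\\tilde{V}_k$. Each truncation level is then evaluated on the validation set, and the losses/accuracies are averaged over seeds."
   ]
  },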
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d00ce2d-4b97-4372-bca6-4a1ab015115a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset selection\n",
    "from udvFunctions.udvOtherFunctions import obtain_current_directory\n",
    "\n",
    "dataset_choice = 0 # 0 points MNIST dataset\n",
    "torch.manual_seed(public_seed)\n",
    "if dataset_choice == 0:\n",
    "    folder_list = obtain_current_directory(endString = \"PreTrue\") \n",
    "    data_path = \"./MNIST\"\n",
    "    train_dataset, val_dataset, num_output = MNIST_Pre(load_All = True, data_path = data_path)  \n",
    "    pre_path = \"./\"\n",
    "    print(\"load the MNIST dataset\")\n",
    "\n",
    "else:\n",
    "    print(\"Wrong selection on dataset\")\n",
    "\n",
    "model_list = [\"/SingleLayer/Seed_1_finalEpoch_70_model_0.pt\",\n",
    "              \"/SingleLayer/Seed_1_finalEpoch_70_model_1.pt\",\n",
    "              \"/SingleLayer/Seed_1_finalEpoch_70_model_2.pt\",\n",
    "              \"/SingleLayer/Seed_1_finalEpoch_70_model_3.pt\"]\n"
   ]
  },
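  {
   "cell_type": "code",
   "execution_count": null,
   "id": "pruning-list-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only (not the imported pruning_list used in the next cell): the helper is\n",
    "# assumed to return a decreasing list of hidden-layer widths, shrinking the starting width by\n",
    "# deduce_factor at each step.\n",
    "def pruning_list_sketch(start_number, deduce_factor = 0.9):\n",
    "    widths, current = [], start_number\n",
    "    while int(current * deduce_factor) >= 1:\n",
    "        current = int(current * deduce_factor)\n",
    "        widths.append(current)\n",
    "    return widths\n",
    "\n",
    "# Example: pruning_list_sketch(100) -> [90, 81, 72, 64, ...]"
   ]
  },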
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f4479ef-366f-42e2-ae9f-d45634f4655c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Warning: raise Error if folder_list is not match\n",
    "# Revise the folder_list according to the baseline results\n",
    "from udvFunctions.udvOtherFunctions import CPU_Unpickler, non_Value, pruning_list, same_saving\n",
    "from udvFunctions.udvClassObtainModels import revised_model, pruning_model\n",
    "\n",
    "for drop_file_index in range(len(folder_list)): \n",
    "    # Set path and open file\n",
    "    store_file_path = pre_path + folder_list[drop_file_index] + '/SingleLayer/results.pkl'\n",
    "    print(\"current pickle file is: \", store_file_path)\n",
    "\n",
    "    with open(store_file_path, 'rb') as file:\n",
    "        #variables = CPU_Unpickler(file).load()\n",
    "        variables = pickle.load(file)\n",
    "        \n",
    "    # Read parameters\n",
    "    num_epochs = variables['num_epochs']\n",
    "    num_seeds = variables['num_seeds']\n",
    "    batch_size = variables['batch_size']\n",
    "\n",
    "    trans_model = variables['trans_model']\n",
    "    optimiser_name = variables['optimiser_name']\n",
    "    learning_rate = variables['learning_rate']\n",
    "\n",
    "    num_input = variables['num_input']\n",
    "    num_hidden_1 = variables['num_hidden_1']\n",
    "    num_output = variables['num_output']   \n",
    "\n",
    "    u_1_V_model_0 = variables['u_1_V_model_0']\n",
    "    w_1_V_model_0 = variables['w_1_V_model_0']\n",
    "    v_1_V_model_0 = variables['v_1_V_model_0']\n",
    "\n",
    "    u_1_V_model_1 = variables['u_1_V_model_1']\n",
    "    w_1_V_model_1 = variables['w_1_V_model_1']\n",
    "    v_1_V_model_1 = variables['v_1_V_model_1']\n",
    "\n",
    "    u_1_M_model_2 = variables['u_1_M_model_2']\n",
    "    w_1_M_model_2 = variables['w_1_M_model_2']\n",
    "    u_2_M_model_2 = variables['u_2_M_model_2']\n",
    "\n",
    "    u_1_M_model_3 = variables['u_1_M_model_3']\n",
    "    w_1_M_model_3 = variables['w_1_M_model_3']\n",
    "    u_2_M_model_3 = variables['u_2_M_model_3']\n",
    "    \n",
    "    seed_index_list = list(range(num_seeds))\n",
    "\n",
    "    num_workers = 4\n",
    "    train_dataloader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers)\n",
    "    val_dataloader = DataLoader(val_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers)\n",
    "    \n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "\n",
    "    # Re-producible setting\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "\n",
    "    # Create pruning list and blank space for saving validation loss\n",
    "    # This code gradually decrease number of hidden neurons and obtain validation loss\n",
    "    neurons_list = pruning_list(start_number = num_hidden_1, deduce_factor = 0.9) \n",
    "    val_losses_m_0 = [0] * (len(neurons_list) + 1)\n",
    "    val_losses_m_1 = [0] * (len(neurons_list) + 1)\n",
    "    val_losses_m_2 = [0] * (len(neurons_list) + 1)\n",
    "    val_losses_m_3 = [0] * (len(neurons_list) + 1)\n",
    "    val_acces_m_0 = [0] * (len(neurons_list) + 1)\n",
    "    val_acces_m_1 = [0] * (len(neurons_list) + 1)\n",
    "    val_acces_m_2 = [0] * (len(neurons_list) + 1)\n",
    "    val_acces_m_3 = [0] * (len(neurons_list) + 1)\n",
    "\n",
    "    # The SVD-based pruning results will be averaged by the same number of seeds as baseline code\n",
    "    for seed_index in seed_index_list:\n",
    "    \n",
    "    # =======================Model_0=================Model_0====================================\n",
    "        # If \"NaN\" is involved, no need to do the validation\n",
    "        if not (non_Value(u_1_V_model_0) or non_Value(w_1_V_model_0) or non_Value(v_1_V_model_0)):    \n",
    "            \n",
    "            # Set the pt file path and load model\n",
    "            store_pt_path = pre_path + folder_list[drop_file_index] + model_list[0]\n",
    "            model = revised_model(name = trans_model,\n",
    "                                  train_features = True,\n",
    "                                  pre_trained = True, \n",
    "                                  model_order = 0, \n",
    "                                  num_input = num_input,\n",
    "                                  num_hidden_1 = num_hidden_1, \n",
    "                                  num_output = num_output)\n",
    "            # Model for validation (Not trainable)\n",
    "            model.load_state_dict(torch.load(store_pt_path))\n",
    "            model.to(device)\n",
    "            \n",
    "            # Compare (Verify) the saved weights from pickle and pt files\n",
    "            same_saving(u1 = u_1_V_model_0[seed_index], w1 = w_1_V_model_0[seed_index], v1 = v_1_V_model_0[seed_index],\n",
    "                        u2 = model.classifier.fc1.weight, w2 = model.classifier.diag1.weight, v2 = model.classifier.fc2.weight)\n",
    "            \n",
    "            # Obtain the baseline (Original number of hidden neurons)\n",
    "            val_loss_drop, val_acc_drop = class_valLoop(model = model,\n",
    "                                                        loss_fn = loss_fn,\n",
    "                                                        val_loader = val_dataloader,\n",
    "                                                        device = device)\n",
    "            val_losses_m_0[0] += val_loss_drop\n",
    "            val_acces_m_0[0] += val_acc_drop\n",
    "            del model\n",
    "            \n",
    "            # SVD of model_0: Vector_uwv - UDV-v1\n",
    "            u1w1_m_0 = (torch.mul(w_1_V_model_0[seed_index].t(), u_1_V_model_0[seed_index])).t()\n",
    "            U_m_0, S_m_0, Vh_m_0 = torch.linalg.svd(u1w1_m_0, full_matrices = False)\n",
    "            \n",
    "            # Gradually reduce number of hidden neurons\n",
    "            for num_hidden_new_index, num_hidden_new in enumerate(neurons_list):\n",
    "                U_m_0_drop = U_m_0.clone().detach()[:,:num_hidden_new]\n",
    "                S_m_0_drop = S_m_0.clone().detach()[:num_hidden_new]\n",
    "                Vh_m_0_drop = Vh_m_0.clone().detach()[:num_hidden_new,:]\n",
    "                \n",
    "                #Re-load the model\n",
    "                model = revised_model(name = trans_model,\n",
    "                                      train_features = True,\n",
    "                                      pre_trained = True, \n",
    "                                      model_order = 0, \n",
    "                                      num_input = num_input,\n",
    "                                      num_hidden_1 = num_hidden_1, \n",
    "                                      num_output = num_output)\n",
    "                \n",
    "                # Model for validation (Not trainable) and set new top layers\n",
    "                model = pruning_model(model = model, \n",
    "                                      name = trans_model, \n",
    "                                      num_input = num_input,\n",
    "                                      num_hidden_new = num_hidden_new,\n",
    "                                      num_output = num_output,\n",
    "                                      ptPath = store_pt_path,\n",
    "                                      device = device)\n",
    "                \n",
    "                # Verify the weight matrices' shape\n",
    "                with torch.no_grad():\n",
    "                    check_shapes(model_weight = model.classifier.fc1.weight, list_sample = U_m_0_drop.t())\n",
    "                    check_shapes(model_weight = model.classifier.diag1.weight, list_sample = S_m_0_drop.unsqueeze(0))\n",
    "                    check_shapes(model_weight = model.classifier.fc2.weight, list_sample = (v_1_V_model_0[seed_index].clone().detach())@(Vh_m_0_drop.t()))\n",
    "                    model.classifier.fc1.weight = nn.Parameter(U_m_0_drop.t())\n",
    "                    model.classifier.diag1.weight = nn.Parameter(S_m_0_drop.unsqueeze(0))\n",
    "                    model.classifier.fc2.weight = nn.Parameter((v_1_V_model_0[seed_index].clone().detach())@(Vh_m_0_drop.t()))\n",
    "                \n",
    "                # Obtain the validation metrics \n",
    "                val_loss_drop, val_acc_drop = class_valLoop(model = model,\n",
    "                                                            loss_fn = loss_fn,\n",
    "                                                            val_loader = val_dataloader,\n",
    "                                                            device = device)\n",
    "                val_losses_m_0[num_hidden_new_index + 1] += val_loss_drop\n",
    "                val_acces_m_0[num_hidden_new_index + 1] += val_acc_drop\n",
    "                print(\"#neurons = {0} has been updated\".format(num_hidden_new))\n",
    "                \n",
    "                del model, U_m_0_drop, S_m_0_drop, Vh_m_0_drop, val_loss_drop, val_acc_drop\n",
    "            del u1w1_m_0, U_m_0, S_m_0, Vh_m_0\n",
    "        \n",
    "        else:\n",
    "            # If NaN things involved, we still maintrain the saving structuer\n",
    "            val_losses_m_0[0] += 10000\n",
    "            val_acces_m_0[0] += 10000\n",
    "            for num_hidden_new_index, num_hidden_new in enumerate(neurons_list):\n",
    "                val_losses_m_0[num_hidden_new_index + 1] += 10000\n",
    "                val_acces_m_0[num_hidden_new_index + 1] += 10000\n",
    "        print(\"finish test model 0\\n\\n\\n\")\n",
    "    # =======================Model_0======DONE=======Model_0====================================\n",
    "    \n",
    "    # =======================Model_1=================Model_1====================================\n",
    "        # UDV-v2\n",
    "        if not (non_Value(u_1_V_model_1) or non_Value(w_1_V_model_1) or non_Value(v_1_V_model_1)):    \n",
    "            store_pt_path = pre_path + folder_list[drop_file_index] + model_list[1]\n",
    "            model = revised_model(name = trans_model,\n",
    "                                  train_features = True,\n",
    "                                  pre_trained = True, \n",
    "                                  model_order = 1, \n",
    "                                  num_input = num_input,\n",
    "                                  num_hidden_1 = num_hidden_1, \n",
    "                                  num_output = num_output)\n",
    "\n",
    "            model.load_state_dict(torch.load(store_pt_path))\n",
    "            model.to(device)\n",
    "            \n",
    "            same_saving(u1 = u_1_V_model_1[seed_index], w1 = w_1_V_model_1[seed_index], v1 = v_1_V_model_1[seed_index],\n",
    "                        u2 = model.classifier.fc1.weight, w2 = model.classifier.diag1.weight, v2 = model.classifier.fc2.weight)\n",
    "            \n",
    "            val_loss_drop, val_acc_drop = class_valLoop(model = model,\n",
    "                                                        loss_fn = loss_fn,\n",
    "                                                        val_loader = val_dataloader,\n",
    "                                                        device = device)\n",
    "            val_losses_m_1[0] += val_loss_drop\n",
    "            val_acces_m_1[0] += val_acc_drop\n",
    "            del model\n",
    "            \n",
    "            u1w1_m_1 = (torch.mul(w_1_V_model_1[seed_index].t(), u_1_V_model_1[seed_index])).t()\n",
    "            U_m_1, S_m_1, Vh_m_1 = torch.linalg.svd(u1w1_m_1, full_matrices = False)\n",
    "\n",
    "            for num_hidden_new_index, num_hidden_new in enumerate(neurons_list):\n",
    "                U_m_1_drop = U_m_1.clone().detach()[:,:num_hidden_new]\n",
    "                S_m_1_drop = S_m_1.clone().detach()[:num_hidden_new]\n",
    "                Vh_m_1_drop = Vh_m_1.clone().detach()[:num_hidden_new,:]\n",
    "                \n",
    "                model = revised_model(name = trans_model,\n",
    "                                      train_features = True,\n",
    "                                      pre_trained = True, \n",
    "                                      model_order = 1, \n",
    "                                      num_input = num_input,\n",
    "                                      num_hidden_1 = num_hidden_1, \n",
    "                                      num_output = num_output)\n",
    "                model = pruning_model(model = model, \n",
    "                                      name = trans_model, \n",
    "                                      num_input = num_input,\n",
    "                                      num_hidden_new = num_hidden_new,\n",
    "                                      num_output = num_output,\n",
    "                                      ptPath = store_pt_path,\n",
    "                                      device = device)\n",
    "                \n",
    "                with torch.no_grad():\n",
    "                    check_shapes(model_weight = model.classifier.fc1.weight, list_sample = U_m_1_drop.t())\n",
    "                    check_shapes(model_weight = model.classifier.diag1.weight, list_sample = S_m_1_drop.unsqueeze(0))\n",
    "                    check_shapes(model_weight = model.classifier.fc2.weight, list_sample = (v_1_V_model_1[seed_index].clone().detach())@(Vh_m_1_drop.t()))\n",
    "                    model.classifier.fc1.weight = nn.Parameter(U_m_1_drop.t())\n",
    "                    model.classifier.diag1.weight = nn.Parameter(S_m_1_drop.unsqueeze(0))\n",
    "                    model.classifier.fc2.weight = nn.Parameter((v_1_V_model_1[seed_index].clone().detach())@(Vh_m_1_drop.t()))\n",
    "                \n",
    "                val_loss_drop, val_acc_drop = class_valLoop(model = model,\n",
    "                                                            loss_fn = loss_fn,\n",
    "                                                            val_loader = val_dataloader,\n",
    "                                                            device = device)\n",
    "                val_losses_m_1[num_hidden_new_index + 1] += val_loss_drop\n",
    "                val_acces_m_1[num_hidden_new_index + 1] += val_acc_drop\n",
    "                print(\"#neurons = {0} has been updated\".format(num_hidden_new))\n",
    "                \n",
    "                del model, U_m_1_drop, S_m_1_drop, Vh_m_1_drop, val_loss_drop, val_acc_drop\n",
    "            del u1w1_m_1, U_m_1, S_m_1, Vh_m_1\n",
    "        \n",
    "        else:\n",
    "            val_losses_m_1[0] += 10000\n",
    "            val_acces_m_1[0] += 10000\n",
    "            for num_hidden_new_index, num_hidden_new in enumerate(neurons_list):\n",
    "                val_losses_m_1[num_hidden_new_index + 1] += 10000\n",
    "                val_acces_m_1[num_hidden_new_index + 1] += 10000\n",
    "        print(\"finish test model 1\\n\\n\\n\")\n",
    "    # =======================Model_1======DONE=======Model_1==================================== \n",
    "\n",
    "    # =======================Model_2==================Model_2====================================\n",
    "        # UDV\n",
    "        if not (non_Value(u_1_M_model_2) or non_Value(w_1_M_model_2) or non_Value(u_2_M_model_2)):    \n",
    "            store_pt_path = pre_path + folder_list[drop_file_index] + model_list[2]\n",
    "            model = revised_model(name = trans_model,\n",
    "                                  train_features = True,\n",
    "                                  pre_trained = True, \n",
    "                                  model_order = 2, \n",
    "                                  num_input = num_input,\n",
    "                                  num_hidden_1 = num_hidden_1, \n",
    "                                  num_output = num_output)\n",
    "            \n",
    "            model.load_state_dict(torch.load(store_pt_path))\n",
    "            model.to(device)\n",
    "            \n",
    "            same_saving(u1 = u_1_M_model_2[seed_index], w1 = w_1_M_model_2[seed_index], v1 = u_2_M_model_2[seed_index],\n",
    "                        u2 = model.classifier.fc1.weight, w2 = model.classifier.diag1.weight, v2 = model.classifier.fc2.weight)\n",
    "            \n",
    "            val_loss_drop, val_acc_drop = class_valLoop(model = model,\n",
    "                                                        loss_fn = loss_fn,\n",
    "                                                        val_loader = val_dataloader,\n",
    "                                                        device = device)\n",
    "            val_losses_m_2[0] += val_loss_drop\n",
    "            val_acces_m_2[0] += val_acc_drop\n",
    "            del model\n",
    "            \n",
    "            u1w1_m_2 = (torch.mul(w_1_M_model_2[seed_index].t(), u_1_M_model_2[seed_index])).t()\n",
    "            U_m_2, S_m_2, Vh_m_2 = torch.linalg.svd(u1w1_m_2, full_matrices = False)\n",
    "            \n",
    "            for num_hidden_new_index, num_hidden_new in enumerate(neurons_list):\n",
    "                U_m_2_drop = U_m_2.clone().detach()[:,:num_hidden_new]\n",
    "                S_m_2_drop = S_m_2.clone().detach()[:num_hidden_new]\n",
    "                Vh_m_2_drop = Vh_m_2.clone().detach()[:num_hidden_new,:]\n",
    "                \n",
    "                model = revised_model(name = trans_model,\n",
    "                                      train_features = True,\n",
    "                                      pre_trained = True, \n",
    "                                      model_order = 2, \n",
    "                                      num_input = num_input,\n",
    "                                      num_hidden_1 = num_hidden_1, \n",
    "                                      num_output = num_output)\n",
    "                \n",
    "                model = pruning_model(model = model, \n",
    "                                      name = trans_model, \n",
    "                                      num_input = num_input,\n",
    "                                      num_hidden_new = num_hidden_new,\n",
    "                                      num_output = num_output,\n",
    "                                      ptPath = store_pt_path,\n",
    "                                      device = device)\n",
    "                \n",
    "                with torch.no_grad():\n",
    "                    check_shapes(model_weight = model.classifier.fc1.weight, list_sample = U_m_2_drop.t())\n",
    "                    check_shapes(model_weight = model.classifier.diag1.weight, list_sample = S_m_2_drop.unsqueeze(0))\n",
    "                    check_shapes(model_weight = model.classifier.fc2.weight, list_sample = (u_2_M_model_2[seed_index].clone().detach())@(Vh_m_2_drop.t()))\n",
    "                    model.classifier.fc1.weight = nn.Parameter(U_m_2_drop.t())\n",
    "                    model.classifier.diag1.weight = nn.Parameter(S_m_2_drop.unsqueeze(0))\n",
    "                    model.classifier.fc2.weight = nn.Parameter((u_2_M_model_2[seed_index].clone().detach())@(Vh_m_2_drop.t()))\n",
    "                \n",
    "                val_loss_drop, val_acc_drop = class_valLoop(model = model,\n",
    "                                                            loss_fn = loss_fn,\n",
    "                                                            val_loader = val_dataloader,\n",
    "                                                            device = device)\n",
    "                val_losses_m_2[num_hidden_new_index + 1] += val_loss_drop\n",
    "                val_acces_m_2[num_hidden_new_index + 1] += val_acc_drop\n",
    "                print(\"#neurons = {0} has been updated\".format(num_hidden_new))\n",
    "                \n",
    "                del model, U_m_2_drop, S_m_2_drop, Vh_m_2_drop, val_loss_drop, val_acc_drop\n",
    "            del u1w1_m_2, U_m_2, S_m_2, Vh_m_2\n",
    "        \n",
    "        else:\n",
    "            val_losses_m_2[0] += 10000\n",
    "            val_acces_m_2[0] += 10000\n",
    "            for num_hidden_new_index, num_hidden_new in enumerate(neurons_list):\n",
    "                val_losses_m_2[num_hidden_new_index + 1] += 10000\n",
    "                val_acces_m_2[num_hidden_new_index + 1] += 10000\n",
    "        print(\"finish test model 2\\n\\n\\n\")\n",
    "    # =======================Model_2======DONE=======Model_2==================================== \n",
    "\n",
    "    # =======================Model_3==================Model_3====================================\n",
    "        # UDV-s\n",
    "        if not (non_Value(u_1_M_model_3) or non_Value(w_1_M_model_3) or non_Value(u_2_M_model_3)):    \n",
    "            store_pt_path = pre_path + folder_list[drop_file_index] + model_list[3]\n",
    "            model = revised_model(name = trans_model,\n",
    "                                  train_features = True,\n",
    "                                  pre_trained = True, \n",
    "                                  model_order = 3, \n",
    "                                  num_input = num_input,\n",
    "                                  num_hidden_1 = num_hidden_1, \n",
    "                                  num_output = num_output)\n",
    "            \n",
    "            model.load_state_dict(torch.load(store_pt_path))\n",
    "            model.to(device)\n",
    "            \n",
    "            same_saving(u1 = u_1_M_model_3[seed_index], w1 = w_1_M_model_3[seed_index], v1 = u_2_M_model_3[seed_index],\n",
    "                        u2 = model.classifier.fc1.weight, w2 = model.classifier.diag1.weight, v2 = model.classifier.fc2.weight)\n",
    "            \n",
    "            val_loss_drop, val_acc_drop = class_valLoop(model = model,\n",
    "                                                        loss_fn = loss_fn,\n",
    "                                                        val_loader = val_dataloader,\n",
    "                                                        device = device)\n",
    "            val_losses_m_3[0] += val_loss_drop\n",
    "            val_acces_m_3[0] += val_acc_drop\n",
    "            del model\n",
    "            \n",
    "            u1w1_m_3 = (torch.mul(w_1_M_model_3[seed_index].t(), u_1_M_model_3[seed_index])).t()\n",
    "            U_m_3, S_m_3, Vh_m_3 = torch.linalg.svd(u1w1_m_3, full_matrices = False)\n",
    "            \n",
    "            for num_hidden_new_index, num_hidden_new in enumerate(neurons_list):\n",
    "                U_m_3_drop = U_m_3.clone().detach()[:,:num_hidden_new]\n",
    "                S_m_3_drop = S_m_3.clone().detach()[:num_hidden_new]\n",
    "                Vh_m_3_drop = Vh_m_3.clone().detach()[:num_hidden_new,:]\n",
    "                \n",
    "                model = revised_model(name = trans_model,\n",
    "                                      train_features = True,\n",
    "                                      pre_trained = True, \n",
    "                                      model_order = 3, \n",
    "                                      num_input = num_input,\n",
    "                                      num_hidden_1 = num_hidden_1, \n",
    "                                      num_output = num_output)\n",
    "                \n",
    "                model = pruning_model(model = model, \n",
    "                                      name = trans_model, \n",
    "                                      num_input = num_input,\n",
    "                                      num_hidden_new = num_hidden_new,\n",
    "                                      num_output = num_output,\n",
    "                                      ptPath = store_pt_path,\n",
    "                                      device = device)\n",
    "                \n",
    "                with torch.no_grad():\n",
    "                    check_shapes(model_weight = model.classifier.fc1.weight, list_sample = U_m_3_drop.t())\n",
    "                    check_shapes(model_weight = model.classifier.diag1.weight, list_sample = S_m_3_drop.unsqueeze(0))\n",
    "                    check_shapes(model_weight = model.classifier.fc2.weight, list_sample = (u_2_M_model_3[seed_index].clone().detach())@(Vh_m_3_drop.t()))\n",
    "                    model.classifier.fc1.weight = nn.Parameter(U_m_3_drop.t())\n",
    "                    model.classifier.diag1.weight = nn.Parameter(S_m_3_drop.unsqueeze(0))\n",
    "                    model.classifier.fc2.weight = nn.Parameter((u_2_M_model_3[seed_index].clone().detach())@(Vh_m_3_drop.t()))\n",
    "                \n",
    "                val_loss_drop, val_acc_drop = class_valLoop(model = model,\n",
    "                                                            loss_fn = loss_fn,\n",
    "                                                            val_loader = val_dataloader,\n",
    "                                                            device = device)\n",
    "                val_losses_m_3[num_hidden_new_index + 1] += val_loss_drop\n",
    "                val_acces_m_3[num_hidden_new_index + 1] += val_acc_drop\n",
    "                print(\"#neurons = {0} has been updated\".format(num_hidden_new))\n",
    "                \n",
    "                del model, U_m_3_drop, S_m_3_drop, Vh_m_3_drop, val_loss_drop, val_acc_drop\n",
    "            del u1w1_m_3, U_m_3, S_m_3, Vh_m_3\n",
    "        \n",
    "        else:\n",
    "            val_losses_m_3[0] += 10000\n",
    "            val_acces_m_3[0] += 10000\n",
    "            for num_hidden_new_index, num_hidden_new in enumerate(neurons_list):\n",
    "                val_losses_m_3[num_hidden_new_index + 1] += 10000\n",
    "                val_acces_m_3[num_hidden_new_index + 1] += 10000\n",
    "        print(\"finish test model 3\\n\\n\\n\")\n",
    "    # =======================Model_3======DONE=======Model_3==================================== \n",
    "    \n",
    "    # Avergae the accumulate validation loss/acc\n",
    "    for avg_index in range(len(val_losses_m_0)):\n",
    "        val_losses_m_0[avg_index] /= num_seeds\n",
    "        val_losses_m_1[avg_index] /= num_seeds\n",
    "        val_losses_m_2[avg_index] /= num_seeds\n",
    "        val_losses_m_3[avg_index] /= num_seeds\n",
    "        val_acces_m_0[avg_index] /= num_seeds\n",
    "        val_acces_m_1[avg_index] /= num_seeds\n",
    "        val_acces_m_2[avg_index] /= num_seeds\n",
    "        val_acces_m_3[avg_index] /= num_seeds\n",
    "\n",
    "    # Convert data to percentage compared to the baseline\n",
    "    val_losses_m_0_plot = [(x - val_losses_m_0[0]) / val_losses_m_0[0] for x in val_losses_m_0]\n",
    "    val_losses_m_1_plot = [(x - val_losses_m_1[0]) / val_losses_m_1[0] for x in val_losses_m_1]\n",
    "    val_losses_m_2_plot = [(x - val_losses_m_2[0]) / val_losses_m_2[0] for x in val_losses_m_2]\n",
    "    val_losses_m_3_plot = [(x - val_losses_m_3[0]) / val_losses_m_3[0] for x in val_losses_m_3]\n",
    "    # val_acces_m_0_plot = [(x - val_acces_m_0[0]) / val_acces_m_0[0] for x in val_acces_m_0]\n",
    "    # val_acces_m_1_plot = [(x - val_acces_m_1[0]) / val_acces_m_1[0] for x in val_acces_m_1]\n",
    "    # val_acces_m_2_plot = [(x - val_acces_m_2[0]) / val_acces_m_2[0] for x in val_acces_m_2]\n",
    "    # val_acces_m_3_plot = [(x - val_acces_m_3[0]) / val_acces_m_3[0] for x in val_acces_m_3]\n",
    "\n",
    "    # =======================WriteToFile=================WriteToFile============================     \n",
    "\n",
    "    with open('XXXXXX.txt', 'a+') as file: # Create or add text to the file\n",
    "        file.seek(0, 2) # Find the end of the file\n",
    "        for write_index in range(0, len(neurons_list) + 1):\n",
    "            file.write(\"{0}\\t\".format(val_losses_m_0[write_index]))\n",
    "            file.write(\"{0}\\t\".format(val_losses_m_1[write_index]))\n",
    "            file.write(\"{0}\\t\".format(val_losses_m_2[write_index]))\n",
    "            file.write(\"{0}\\t\".format(val_losses_m_3[write_index]))\n",
    "            if write_index == 0:\n",
    "                file.write(\"{0}\\t\".format(num_hidden_1))\n",
    "            else:\n",
    "                file.write(\"{0}\\t\".format(neurons_list[write_index - 1]))\n",
    "            file.write(\"{0}\\t\".format(val_losses_m_0_plot[write_index]))\n",
    "            file.write(\"{0}\\t\".format(val_losses_m_1_plot[write_index]))\n",
    "            file.write(\"{0}\\t\".format(val_losses_m_2_plot[write_index]))\n",
    "            file.write(\"{0}\\t\".format(val_losses_m_3_plot[write_index]))\n",
    "            if write_index == 0:\n",
    "                file.write(\"{0}\\t\".format(num_hidden_1))\n",
    "            else:\n",
    "                file.write(\"{0}\\t\".format(neurons_list[write_index - 1]))\n",
    "            file.write(\"{0}\\t\".format(val_acces_m_0[write_index]))\n",
    "            file.write(\"{0}\\t\".format(val_acces_m_1[write_index]))\n",
    "            file.write(\"{0}\\t\".format(val_acces_m_2[write_index]))\n",
    "            file.write(\"{0}\\t\".format(val_acces_m_3[write_index]))\n",
    "            file.write(\"\\n\")\n",
    "        file.write(\"\\n\\n\\n\\n\\n\")\n",
    "    print(\"The file has been run: \", store_file_path)\n",
    "print(\"All done\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
