{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "62d4e47e-4aec-4a70-9b41-058f9dc14587",
   "metadata": {},
   "source": [
    "### <span style=\"color:blue\"> *---------------------------------------*</span>\n",
    "# <span style=\"color:blue\"> *UDV project - SVD based pruning*</span>\n",
    "### <span style=\"color:blue\"> *---------------------------------------*</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d62d1060",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports libraries\n",
    "\n",
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import pandas as pd\n",
    "\n",
    "from sklearn.preprocessing import MinMaxScaler, LabelEncoder\n",
    "from sklearn.impute import SimpleImputer\n",
    "\n",
    "import pickle\n",
    "import time\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e91d4b0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the global seed for reproducibility\n",
    "# (torch.manual_seed seeds the default generator on every device, CPU and CUDA)\n",
    "\n",
    "public_seed = 529\n",
    "torch.manual_seed(public_seed)\n",
    "print(f\"Current Seed is {public_seed}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78c9549f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Obtain device\n",
    "# `device` is passed to reg_valLoop in the SVD-pruning evaluation cell; get_device is project-local (udvFunctions.udvDevice)\n",
    "\n",
    "from udvFunctions.udvDevice import get_device\n",
    "device = get_device()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "744d3c6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pre-process dataset\n",
    "\n",
    "from udvFunctions.udvDatasetPreprocessing import HousingPriceDataset\n",
    "from udvFunctions.udvDatasetPreprocessing import NYCDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86ab3012",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare the single-connection layer that carries the weight matrix 'w'\n",
    "\n",
    "from udvFunctions.udvDiagonalLayer import D_singleConnection   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56ad74ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare networks\n",
    "\n",
    "from udvFunctions.udvRegNetworks import UDV_net_1\n",
    "from udvFunctions.udvRegNetworks import relu_net_1, fc_net_1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f705c5a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare loss function (optional)\n",
    "\n",
    "from udvFunctions.udvLoss import UDV_Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b41017af",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the validation loop\n",
    "\n",
    "from udvFunctions.udvRegVal import reg_valLoop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69bd5c10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Other customised functions\n",
    "\n",
    "from udvFunctions.udvOtherFunctions import store_metrics, take_avg, check_shapes"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c0cd1610",
   "metadata": {},
   "source": [
    "### <span style=\"color:blue\"> *---------------------------------------*</span>\n",
    "# <span style=\"color:blue\"> *SVD-based pruning test (without re-training)*</span>\n",
    "### <span style=\"color:blue\"> *---------------------------------------*</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acb7f5c9-bb9c-4ed0-8872-fdf7304ac36c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset selection\n",
    "\n",
    "dataset_choice = 1 # 0 selects the housing price dataset; 1 selects the NYC taxi duration dataset\n",
    "is_full = 0        # 0 selects mini-batch results; 1 selects full-batch results\n",
    "\n",
    "# Fail fast on invalid flags: otherwise `dataset`/`file_list` from an earlier run would silently survive,\n",
    "# and the evaluation cell would treat any is_full other than 0 as full batch\n",
    "if is_full not in (0, 1):\n",
    "    raise ValueError(\"is_full must be 0 (mini-batch) or 1 (full batch), got {0}\".format(is_full))\n",
    "\n",
    "# HP\n",
    "if dataset_choice == 0:\n",
    "    data_path = \"./HP_Orig.csv\"\n",
    "    dataset = HousingPriceDataset(data_path)\n",
    "    print(\"load the housing price dataset\")\n",
    "    if is_full == 0:\n",
    "        pre_path = \"./01_Orig/\"\n",
    "        file_list = [\"Adam_0.001_H1_26_H2_5_BS128_E200_S1000\",\n",
    "                     \"NAdam_0.001_H1_26_H2_5_BS128_E200_S1000\",\n",
    "                     \"SGD_0.1_H1_26_H2_5_BS128_E200_S1000\",\n",
    "                     \"SGDM_0.1_H1_26_H2_5_BS128_E200_S1000\"]\n",
    "    else:\n",
    "        pre_path = \"./02_FullBatch/\"\n",
    "        file_list = [\"Adam_0.001_H1_26_H2_5_BSfull_E200_S1000\",\n",
    "                     \"NAdam_0.001_H1_26_H2_5_BSfull_E200_S1000\",\n",
    "                     \"SGD_0.1_H1_26_H2_5_BSfull_E200_S1000\",\n",
    "                     \"SGDM_0.1_H1_26_H2_5_BSfull_E200_S1000\"]\n",
    "\n",
    "# NYC\n",
    "elif dataset_choice == 1:\n",
    "    data_path = \"./NYC_Orig.csv\"\n",
    "    dataset = NYCDataset(data_path)\n",
    "    print(\"load the NYC taxi duration dataset\")\n",
    "    if is_full == 0:\n",
    "        pre_path = \"./01_Orig/\"\n",
    "        file_list = [\"Adam_0.0001_H1_10_H2_2_BS128_E50_S100\",\n",
    "                     \"NAdam_0.0001_H1_10_H2_2_BS128_E50_S100\",\n",
    "                     \"SGD_1_H1_10_H2_2_BS128_E50_S100\",\n",
    "                     \"SGDM_3_H1_10_H2_2_BS128_E50_S100\"]\n",
    "    else:\n",
    "        pre_path = \"./02_FullBatch/\"\n",
    "        file_list = [\"Adam_0.0001_H1_10_H2_2_BSfull_E50_S100\",\n",
    "                     \"NAdam_0.0001_H1_10_H2_2_BSfull_E50_S100\",\n",
    "                     \"SGD_1_H1_10_H2_2_BSfull_E50_S100\",\n",
    "                     \"SGDM_3_H1_10_H2_2_BSfull_E50_S100\"]\n",
    "else:\n",
    "    raise ValueError(\"Wrong selection on dataset: dataset_choice must be 0 (HP) or 1 (NYC), got {0}\".format(dataset_choice))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8be78ca-3828-4d21-afaf-7a3715cd5244",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate SVD-based pruning (without re-training) for every baseline result in file_list.\n",
    "# Fails (e.g. FileNotFoundError) if file_list does not match the stored baseline results;\n",
    "# revise file_list in the dataset selection cell accordingly.\n",
    "from udvFunctions.udvOtherFunctions import CPU_Unpickler\n",
    "\n",
    "# Tab-separated results are appended to this file (one table per results file); rename before running\n",
    "output_txt_path = 'XXXXXX.txt'\n",
    "\n",
    "# Pickle keys of the per-seed (fc1, diag1, fc2) weights of each single-layer model,\n",
    "# listed in the column order used in the output file\n",
    "svd_model_keys = [('u_1_V_model_0', 'w_1_V_model_0', 'v_1_V_model_0'),  # UDV-v1 (Vector_uwv)\n",
    "                  ('u_1_V_model_1', 'w_1_V_model_1', 'v_1_V_model_1'),  # UDV-v2\n",
    "                  ('u_1_M_model_2', 'w_1_M_model_2', 'u_2_M_model_2'),  # UDV\n",
    "                  ('u_1_M_model_3', 'w_1_M_model_3', 'u_2_M_model_3')]  # UDV-s\n",
    "\n",
    "# Re-producible setting (global cuDNN flags, so they are set once)\n",
    "torch.backends.cudnn.deterministic = True\n",
    "torch.backends.cudnn.benchmark = False\n",
    "\n",
    "\n",
    "def svd_pruning_val_losses(u_1, w_1, v_out, num_input, num_hidden_1, num_output,\n",
    "                           loss_fn, val_loader, device):\n",
    "    \"\"\"Validation losses of one trained UDV_net_1 under SVD-based pruning (no re-training).\n",
    "\n",
    "    The merged first layer (diag(w_1) @ u_1)^T is decomposed with a thin SVD; the network is\n",
    "    then rebuilt with k hidden neurons from the k leading singular values/vectors, with the\n",
    "    kept right singular vectors absorbed into the output layer.\n",
    "\n",
    "    Args:\n",
    "        u_1: trained fc1 weight (same shape as UDV_net_1.fc1.weight).\n",
    "        w_1: trained diag1 weight (same shape as UDV_net_1.diag1.weight, assumed (1, num_hidden_1)).\n",
    "        v_out: trained fc2 weight (same shape as UDV_net_1.fc2.weight).\n",
    "        num_input, num_hidden_1, num_output: layer sizes of the trained network.\n",
    "        loss_fn, val_loader, device: forwarded unchanged to reg_valLoop.\n",
    "\n",
    "    Returns:\n",
    "        List of num_hidden_1 validation losses: index 0 is the original network and\n",
    "        index k is the network pruned to num_hidden_1 - k hidden neurons.\n",
    "    \"\"\"\n",
    "    # NOTE(review): only the fc1/diag1/fc2 weights are overwritten below; if UDV_net_1 layers\n",
    "    # also carry biases, those keep their random initialisation - confirm in udvRegNetworks.\n",
    "    losses = [0] * num_hidden_1\n",
    "\n",
    "    # SVD of the merged first layer; loaded tensors may track gradients, so disable autograd\n",
    "    with torch.no_grad():\n",
    "        u1w1 = (torch.mul(w_1.t(), u_1)).t()\n",
    "        U, S, Vh = torch.linalg.svd(u1w1, full_matrices = False)\n",
    "\n",
    "    # Baseline with the original structure and the original weights\n",
    "    model = UDV_net_1(num_input = num_input, num_hidden_1 = num_hidden_1, num_output = num_output)\n",
    "    with torch.no_grad():\n",
    "        check_shapes(model_weight = model.fc1.weight, list_sample = u_1)\n",
    "        check_shapes(model_weight = model.diag1.weight, list_sample = w_1)\n",
    "        check_shapes(model_weight = model.fc2.weight, list_sample = v_out)\n",
    "        model.fc1.weight = nn.Parameter(u_1.clone().detach())\n",
    "        model.diag1.weight = nn.Parameter(w_1.clone().detach())\n",
    "        model.fc2.weight = nn.Parameter(v_out.clone().detach())\n",
    "    losses[0] = reg_valLoop(model = model, loss_fn = loss_fn, val_loader = val_loader, device = device)\n",
    "\n",
    "    # Gradually decrease the number of hidden neurons (truncated-SVD reconstruction)\n",
    "    for num_hidden_new in range(num_hidden_1 - 1, 0, -1):\n",
    "        U_drop = U.clone().detach()[:, :num_hidden_new]\n",
    "        S_drop = S.clone().detach()[:num_hidden_new]\n",
    "        Vh_drop = Vh.clone().detach()[:num_hidden_new, :]\n",
    "\n",
    "        model = UDV_net_1(num_input = num_input, num_hidden_1 = num_hidden_new, num_output = num_output)\n",
    "        with torch.no_grad():\n",
    "            # Absorb the kept right singular vectors into the output layer\n",
    "            fc2_weight = v_out.clone().detach() @ Vh_drop.t()\n",
    "            check_shapes(model_weight = model.fc1.weight, list_sample = U_drop.t())\n",
    "            check_shapes(model_weight = model.diag1.weight, list_sample = S_drop.unsqueeze(0))\n",
    "            check_shapes(model_weight = model.fc2.weight, list_sample = fc2_weight)\n",
    "            model.fc1.weight = nn.Parameter(U_drop.t())\n",
    "            model.diag1.weight = nn.Parameter(S_drop.unsqueeze(0))\n",
    "            model.fc2.weight = nn.Parameter(fc2_weight)\n",
    "        losses[num_hidden_1 - num_hidden_new] = reg_valLoop(model = model, loss_fn = loss_fn,\n",
    "                                                            val_loader = val_loader, device = device)\n",
    "    return losses\n",
    "\n",
    "\n",
    "for store_file_name in file_list:\n",
    "    # Set path and open file\n",
    "    store_file_path = pre_path + store_file_name + '/SingleLayer/results.pkl'\n",
    "    print(\"current file is: \", store_file_path)\n",
    "\n",
    "    # Unpickling can execute arbitrary code: only load results files produced by the baseline code\n",
    "    with open(store_file_path, 'rb') as pkl_file:\n",
    "        variables = CPU_Unpickler(pkl_file).load()\n",
    "\n",
    "    # Read the parameters needed here\n",
    "    num_seeds = variables['num_seeds']\n",
    "    batch_size = variables['batch_size']\n",
    "    num_input = variables['num_input']\n",
    "    num_hidden_1 = variables['num_hidden_1']\n",
    "    num_output = variables['num_output']\n",
    "    # One (fc1, diag1, fc2) triple of per-seed weight lists per model\n",
    "    model_weights = [tuple(variables[key] for key in keys) for keys in svd_model_keys]\n",
    "\n",
    "    # Use the same way as the baseline code (make sure the same validation dataset)\n",
    "    torch.manual_seed(public_seed)\n",
    "    training_ratio = 0.8\n",
    "    train_size = int(training_ratio * len(dataset))\n",
    "    validation_size = len(dataset) - train_size\n",
    "    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, validation_size])\n",
    "\n",
    "    # Only validation is needed (no re-training); full-batch results are validated in a single batch\n",
    "    val_batch_size = batch_size if is_full == 0 else len(val_dataset)\n",
    "    val_dataloader = DataLoader(val_dataset, batch_size = val_batch_size, shuffle = False)\n",
    "\n",
    "    # loss_fn = UDV_Loss()    # MSE/2\n",
    "    loss_fn = nn.MSELoss()  # MSE\n",
    "\n",
    "    # Sum each model's validation losses over the same number of seeds as the baseline code.\n",
    "    # Seeds stay the outer loop: UDV_net_1(...) presumably draws its initial weights from the\n",
    "    # global RNG, so this keeps the draw order of the original per-model code.\n",
    "    summed_losses = [[0] * num_hidden_1 for _ in svd_model_keys]\n",
    "    for seed_index in range(num_seeds):\n",
    "        for model_index, (u_1_seeds, w_1_seeds, v_out_seeds) in enumerate(model_weights):\n",
    "            seed_losses = svd_pruning_val_losses(u_1_seeds[seed_index], w_1_seeds[seed_index], v_out_seeds[seed_index],\n",
    "                                                 num_input, num_hidden_1, num_output,\n",
    "                                                 loss_fn, val_dataloader, device)\n",
    "            for loss_index, loss_value in enumerate(seed_losses):\n",
    "                summed_losses[model_index][loss_index] += loss_value\n",
    "\n",
    "    # Average the accumulated validation loss over seeds\n",
    "    val_losses = [[loss_sum / num_seeds for loss_sum in model_sums] for model_sums in summed_losses]\n",
    "    # Relative change w.r.t. the un-pruned baseline (a fraction, not a percentage)\n",
    "    val_losses_rel = [[(x - model_losses[0]) / model_losses[0] for x in model_losses] for model_losses in val_losses]\n",
    "\n",
    "    print(\"The file has been run: \", store_file_path)\n",
    "\n",
    "    # Append one row per network size: 4 mean losses, hidden size, 4 relative changes\n",
    "    with open(output_txt_path, 'a') as txt_file:\n",
    "        for write_index in range(num_hidden_1):\n",
    "            row = [model_losses[write_index] for model_losses in val_losses]\n",
    "            row.append(num_hidden_1 - write_index)\n",
    "            row.extend(model_rel[write_index] for model_rel in val_losses_rel)\n",
    "            txt_file.write(\"\".join(\"{0}\\t\".format(value) for value in row))\n",
    "            txt_file.write(\"\\n\")\n",
    "        txt_file.write(\"\\n\\n\\n\\n\\n\")\n",
    "\n",
    "print(\"all done\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
