{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b42e5452-e408-427d-aaa9-fab847baba63",
   "metadata": {},
   "source": [
    "### <span style=\"color:blue\"> *---------------------------------------*</span>\n",
    "# <span style=\"color:blue\"> *UDV project*</span>\n",
    "### <span style=\"color:blue\"> *---------------------------------------*</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d62d1060",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports libraries\n",
    "\n",
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, Subset\n",
    "import numpy as np\n",
    "\n",
    "from torchvision import transforms\n",
    "from torchvision import models\n",
    "from collections import OrderedDict\n",
    "\n",
    "import pickle\n",
    "import time\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd39d620-34bb-4cdb-8745-72732e39c93c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the global seed for reproducibility.\n",
    "# torch.manual_seed also seeds all CUDA devices; numpy is imported in this\n",
    "# notebook as well, so its global RNG is seeded with the same value.\n",
    "\n",
    "public_seed = 529\n",
    "torch.manual_seed(public_seed)\n",
    "np.random.seed(public_seed)\n",
    "print(\"Current Seed is {0}\".format(public_seed))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78c9549f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Obtain the compute device from the project helper get_device();\n",
    "# `device` is later handed to every udv_frame_* training/validation call.\n",
    "\n",
    "from udvFunctions.udvDevice import get_device\n",
    "device = get_device()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cbb36fb-4ac0-4823-b807-a20574e72626",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load dataset\n",
    "\n",
    "from udvFunctions.udvClassDataset import MNIST_Pre"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3da2afff-786f-422a-a9ff-c08923536558",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare UDV constraints\n",
    "\n",
    "from udvFunctions.udvConstraints import Matrix_bothside, UDV_Diag\n",
    "from udvFunctions.udvConstraints import Vector_left_U, Vector_right_V"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0877c1c-16a9-4f17-beea-0c1be846e856",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare the single-connection layer that carries the weight matrix 'w'\n",
    "\n",
    "from udvFunctions.udvDiagonalLayer import D_singleConnection   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5226c9ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare training framework\n",
    "\n",
    "# Model_0: Vector_uwv\n",
    "from udvFunctions.udvClassVectorUWV import udv_frame_v_uwv\n",
    "\n",
    "# Model_1: Vector_uv\n",
    "from udvFunctions.udvClassVectorUV import udv_frame_v_uv\n",
    "\n",
    "# Model_2: Matrix_uwv\n",
    "from udvFunctions.udvClassMatrixUWV import udv_frame_m_uwv\n",
    "\n",
    "# Model_3: Matrix_uv\n",
    "from udvFunctions.udvClassMatrixUV import udv_frame_m_uv\n",
    "\n",
    "# Model_4: ReLU\n",
    "from udvFunctions.udvClassReLU import udv_frame_relu\n",
    "\n",
    "# Model_5: Linear activation\n",
    "from udvFunctions.udvClassLinear import udv_frame_LAct\n",
    "\n",
    "# Model_6: Original transfer learning model\n",
    "from udvFunctions.udvClassOrig import udv_frame_Orig\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe2173aa-94b2-4be3-a8d6-c7dabfbc33cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Identical initialisation\n",
    "\n",
    "from udvFunctions.udvSameInit import seedList1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0712116-d92c-41c4-8d5d-f7e6fbea1bca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Other customised functions\n",
    "\n",
    "from udvFunctions.udvOtherFunctions import store_metrics, take_avg, check_shapes"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c0cd1610",
   "metadata": {},
   "source": [
    "### <span style=\"color:blue\"> *---------------------------------------*</span>\n",
    "# <span style=\"color:blue\"> *Setting Experiments:*</span>\n",
    "### <span style=\"color:blue\"> *---------------------------------------*</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c48d712b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Experiment configuration shared by every model cell below:\n",
    "# data, epochs/seeds, backbone, optimiser, constraint limits and result path.\n",
    "from udvFunctions.udvClassObtainModels import model_InitPara, revised_model\n",
    "torch.manual_seed(public_seed)\n",
    "\n",
    "# Load dataset\n",
    "data_path = \"./MNIST\"\n",
    "train_dataset, val_dataset, num_output = MNIST_Pre(load_All = True, data_path = data_path)\n",
    "\n",
    "# Set number of epochs and how many different seeds for training\n",
    "num_epochs = 70        # Number of epochs; Min:5\n",
    "num_seeds = 1          # Number of seeds;  Min:1\n",
    "# Enforce the documented minima up front (num_seeds < 1 would otherwise fail much later)\n",
    "assert num_epochs >= 5, \"num_epochs must be at least 5, got {0}\".format(num_epochs)\n",
    "assert num_seeds >= 1, \"num_seeds must be at least 1, got {0}\".format(num_seeds)\n",
    "\n",
    "# Define top layer (classifier)\n",
    "trans_model = 'regnet_x_32gf' # 'efficientnet_b0' or 'maxvit_t' or 'regnet_x_32gf'\n",
    "num_input, num_hidden_1, batch_size = model_InitPara(name = trans_model, scale_factor = 2/3)\n",
    "\n",
    "# Is feature layers trainable? (All layers before the classifier)\n",
    "train_features = True           # Either 'True' (all layers trainable) or 'False' (Only classifier trainable)\n",
    "\n",
    "# Is pre-trained weights used?\n",
    "pre_trained = True              # Either False (no pre-trained weights) or True (pre-trained weights from ImageNet-1K)\n",
    "\n",
    "# Load data\n",
    "num_workers = 4\n",
    "train_dataloader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers)\n",
    "val_dataloader = DataLoader(val_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers)\n",
    "\n",
    "# Set optimiser (later) and loss function.\n",
    "# optimiser_name is encoded in result_path, so the optimiser built in each model cell must match it.\n",
    "optimiser_name = 'Adam'             # 'Adam', 'NAdam', 'SGD', 'SGDM'\n",
    "assert optimiser_name in ('Adam', 'NAdam', 'SGD', 'SGDM'), \"Unsupported optimiser_name: {0}\".format(optimiser_name)\n",
    "SGD_M = 0.9                         # Only available when 'SGDM' is specified in optimiser_name\n",
    "learning_rate = 1e-6\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "\n",
    "# Constraints setting\n",
    "vector_u_norm = 1\n",
    "vector_v_norm = 1\n",
    "matrix_uvnorm = 1\n",
    "d_threshold = 0\n",
    "d_boundto = 0\n",
    "\n",
    "# time record: each model cell appends (start, end, duration)\n",
    "time_record = []\n",
    "\n",
    "# Set result path (exist_ok avoids the check-then-create race and is idempotent on re-run)\n",
    "result_path = './{0}_{1}_{2}_H{3}_BS{4}_E{5}_S{6}_TrainFea{7}_Pre{8}/SingleLayer'.format(trans_model, optimiser_name, learning_rate, num_hidden_1, batch_size, num_epochs, num_seeds, int(train_features), pre_trained)\n",
    "os.makedirs(result_path, exist_ok = True)\n",
    "\n",
    "# Re-producible setting\n",
    "torch.backends.cudnn.deterministic = True\n",
    "torch.backends.cudnn.benchmark = False"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6fa48c7",
   "metadata": {},
   "source": [
    "### <span style=\"color:blue\"> *---------------------------------------*</span>\n",
    "## <span style=\"color:blue\"> *The following model with uwv constraints (Vector): Model_0_UDV-v1*</span>\n",
    "### <span style=\"color:blue\"> *---------------------------------------*</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bf7dae4-0b70-4aac-b662-44de291f3b29",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model_0 (UDV-v1): vector constraints on u and v plus UDV_Diag on w; trained once per seed\n",
    "torch.manual_seed(public_seed)\n",
    "\n",
    "# Generate parameters matrices from public seed\n",
    "init_ulist1, init_wlist1, init_vlist1 = seedList1(num_seeds = num_seeds, \n",
    "                                                  public_seed = public_seed,\n",
    "                                                  num_input = num_input,\n",
    "                                                  num_hidden_1 = num_hidden_1, \n",
    "                                                  num_output = num_output\n",
    "                                                 )\n",
    "\n",
    "# To store train/validation loss/accuracy\n",
    "full_train_loss_model_0 = []\n",
    "full_val_loss_model_0 = []\n",
    "full_train_acc_model_0 = []\n",
    "full_val_acc_model_0 = []\n",
    "\n",
    "# Saved weights\n",
    "u_1_V_model_0 = []\n",
    "w_1_V_model_0 = []\n",
    "v_1_V_model_0 = []\n",
    "\n",
    "# Load constraints\n",
    "constraints_u = Vector_left_U(normLim = vector_u_norm)\n",
    "constraints_d = UDV_Diag(threshold = d_threshold, boundTo = d_boundto)\n",
    "constraints_v = Vector_right_V(normLim = vector_v_norm)\n",
    "\n",
    "# Timer start\n",
    "start_time = time.time()\n",
    "time_record.append(start_time)\n",
    "\n",
    "# Model training and validation \n",
    "for i in range (1, num_seeds + 1):\n",
    "    model = revised_model(name = trans_model,\n",
    "                          train_features = train_features,\n",
    "                          pre_trained = pre_trained, \n",
    "                          model_order = 0, \n",
    "                          num_input = num_input,\n",
    "                          num_hidden_1 = num_hidden_1, \n",
    "                          num_output = num_output)\n",
    "\n",
    "    # Use reproducible parameters overwrite the random initialization\n",
    "    with torch.no_grad():\n",
    "        check_shapes(model_weight = model.classifier.fc1.weight, list_sample = init_ulist1[i-1])\n",
    "        check_shapes(model_weight = model.classifier.diag1.weight, list_sample = init_wlist1[i-1])\n",
    "        check_shapes(model_weight = model.classifier.fc2.weight, list_sample = init_vlist1[i-1])\n",
    "        model.classifier.fc1.weight = nn.Parameter(init_ulist1[i-1])\n",
    "        model.classifier.diag1.weight = nn.Parameter(init_wlist1[i-1])\n",
    "        model.classifier.fc2.weight = nn.Parameter(init_vlist1[i-1])\n",
    "\n",
    "    # Load the optimiser selected by optimiser_name in the config cell, so the\n",
    "    # optimiser actually used matches the one encoded in result_path\n",
    "    if optimiser_name == 'SGDM':\n",
    "        optimizer = optim.SGD(model.parameters(), lr = learning_rate, momentum = SGD_M)\n",
    "    else:\n",
    "        optimizer = getattr(optim, optimiser_name)(model.parameters(), lr = learning_rate) # 'Adam', 'NAdam' or 'SGD'\n",
    "\n",
    "    # Model training/validation loops\n",
    "    model, train_losses, val_losses, train_accs, val_accs, save_weights_list = udv_frame_v_uwv(model = model, \n",
    "                                                                                               optimizer = optimizer, \n",
    "                                                                                               loss_fn = loss_fn, \n",
    "                                                                                               train_loader = train_dataloader, \n",
    "                                                                                               val_loader = val_dataloader, \n",
    "                                                                                               num_epochs = num_epochs, \n",
    "                                                                                               device = device,\n",
    "                                                                                               constraints_u = constraints_u,\n",
    "                                                                                               constraints_d = constraints_d,\n",
    "                                                                                               constraints_v = constraints_v\n",
    "                                                                                              )\n",
    "    \n",
    "    # Record training/validation loss/acc of all epochs \n",
    "    full_train_loss_model_0 = store_metrics(full_train_loss_model_0, train_losses)\n",
    "    full_val_loss_model_0 = store_metrics(full_val_loss_model_0, val_losses)\n",
    "    full_train_acc_model_0 = store_metrics(full_train_acc_model_0, train_accs)\n",
    "    full_val_acc_model_0 = store_metrics(full_val_acc_model_0, val_accs)\n",
    "    \n",
    "    # Save model weights\n",
    "    check_shapes(model_weight = model.classifier.fc1.weight, list_sample = save_weights_list[0])\n",
    "    check_shapes(model_weight = model.classifier.diag1.weight, list_sample = save_weights_list[1])\n",
    "    check_shapes(model_weight = model.classifier.fc2.weight, list_sample = save_weights_list[2])\n",
    "    u_1_V_model_0.append(save_weights_list[0])\n",
    "    w_1_V_model_0.append(save_weights_list[1])\n",
    "    v_1_V_model_0.append(save_weights_list[2])\n",
    "    \n",
    "    # Save model weights (ALL)\n",
    "    torch.save(model.state_dict(),'{0}/Seed_{1}_finalEpoch_{2}_model_0.pt'.format(result_path, i, num_epochs))\n",
    "\n",
    "# Average training/validation loss/acc from all seeds  \n",
    "avg_train_loss_model_0 = take_avg(full_train_loss_model_0)\n",
    "avg_val_loss_model_0 = take_avg(full_val_loss_model_0)\n",
    "avg_train_acc_model_0 = take_avg(full_train_acc_model_0)\n",
    "avg_val_acc_model_0 = take_avg(full_val_acc_model_0)\n",
    "\n",
    "# Timer end \n",
    "end_time = time.time()    \n",
    "time_record.append(end_time)\n",
    "duration = end_time - start_time\n",
    "time_record.append(duration)\n",
    "\n",
    "del init_ulist1, init_wlist1, init_vlist1, model, constraints_u, constraints_d, constraints_v, save_weights_list, train_losses, val_losses, train_accs, val_accs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb4a5fd8",
   "metadata": {},
   "source": [
    "### <span style=\"color:blue\"> *---------------------------------------*</span>\n",
    "## <span style=\"color:blue\"> *The following model with uv constraints (Vector): Model_1_UDV-v2*</span>\n",
    "### <span style=\"color:blue\"> *---------------------------------------*</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d77af7f6-a93e-47a6-b1f0-08df173d3ef3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model_1 (UDV-v2): only the u and v vector constraints are passed to the training loop; trained once per seed\n",
    "torch.manual_seed(public_seed)\n",
    "\n",
    "# Generate parameters matrices from public seed\n",
    "init_ulist1, init_wlist1, init_vlist1 = seedList1(num_seeds = num_seeds, \n",
    "                                                  public_seed = public_seed,\n",
    "                                                  num_input = num_input,\n",
    "                                                  num_hidden_1 = num_hidden_1, \n",
    "                                                  num_output = num_output\n",
    "                                                 )\n",
    "\n",
    "# To store train/validation loss/accuracy\n",
    "full_train_loss_model_1 = []\n",
    "full_val_loss_model_1 = []\n",
    "full_train_acc_model_1 = []\n",
    "full_val_acc_model_1 = []\n",
    "\n",
    "# Saved weights\n",
    "u_1_V_model_1 = []\n",
    "w_1_V_model_1 = []\n",
    "v_1_V_model_1 = []\n",
    "\n",
    "# Load constraints\n",
    "constraints_u = Vector_left_U(normLim = vector_u_norm)\n",
    "constraints_v = Vector_right_V(normLim = vector_v_norm)\n",
    "\n",
    "# Timer start\n",
    "start_time = time.time()\n",
    "time_record.append(start_time)\n",
    "\n",
    "for i in range (1, num_seeds + 1): \n",
    "    model = revised_model(name = trans_model,\n",
    "                          train_features = train_features,\n",
    "                          pre_trained = pre_trained, \n",
    "                          model_order = 1, \n",
    "                          num_input = num_input,\n",
    "                          num_hidden_1 = num_hidden_1, \n",
    "                          num_output = num_output)\n",
    "    \n",
    "    # Use reproducible parameters overwrite the random initialization\n",
    "    with torch.no_grad():\n",
    "        check_shapes(model_weight = model.classifier.fc1.weight, list_sample = init_ulist1[i-1])\n",
    "        check_shapes(model_weight = model.classifier.diag1.weight, list_sample = init_wlist1[i-1])\n",
    "        check_shapes(model_weight = model.classifier.fc2.weight, list_sample = init_vlist1[i-1])\n",
    "        model.classifier.fc1.weight = nn.Parameter(init_ulist1[i-1])\n",
    "        model.classifier.diag1.weight = nn.Parameter(init_wlist1[i-1])\n",
    "        model.classifier.fc2.weight = nn.Parameter(init_vlist1[i-1])\n",
    "\n",
    "    # Load the optimiser selected by optimiser_name in the config cell, so the\n",
    "    # optimiser actually used matches the one encoded in result_path\n",
    "    if optimiser_name == 'SGDM':\n",
    "        optimizer = optim.SGD(model.parameters(), lr = learning_rate, momentum = SGD_M)\n",
    "    else:\n",
    "        optimizer = getattr(optim, optimiser_name)(model.parameters(), lr = learning_rate) # 'Adam', 'NAdam' or 'SGD'\n",
    "    \n",
    "    model, train_losses, val_losses, train_accs, val_accs, save_weights_list = udv_frame_v_uv(model = model, \n",
    "                                                                                              optimizer = optimizer, \n",
    "                                                                                              loss_fn = loss_fn, \n",
    "                                                                                              train_loader = train_dataloader, \n",
    "                                                                                              val_loader = val_dataloader, \n",
    "                                                                                              num_epochs = num_epochs, \n",
    "                                                                                              device = device,\n",
    "                                                                                              constraints_u = constraints_u,\n",
    "                                                                                              constraints_v = constraints_v\n",
    "                                                                                             )\n",
    "    # Record training/validation loss/acc of all epochs\n",
    "    full_train_loss_model_1 = store_metrics(full_train_loss_model_1, train_losses)\n",
    "    full_val_loss_model_1 = store_metrics(full_val_loss_model_1, val_losses)\n",
    "    full_train_acc_model_1 = store_metrics(full_train_acc_model_1, train_accs)\n",
    "    full_val_acc_model_1 = store_metrics(full_val_acc_model_1, val_accs)\n",
    "    \n",
    "    # Save model weights\n",
    "    check_shapes(model_weight = model.classifier.fc1.weight, list_sample = save_weights_list[0])\n",
    "    check_shapes(model_weight = model.classifier.diag1.weight, list_sample = save_weights_list[1])\n",
    "    check_shapes(model_weight = model.classifier.fc2.weight, list_sample = save_weights_list[2])\n",
    "    u_1_V_model_1.append(save_weights_list[0])\n",
    "    w_1_V_model_1.append(save_weights_list[1])\n",
    "    v_1_V_model_1.append(save_weights_list[2])\n",
    "    \n",
    "    torch.save(model.state_dict(),'{0}/Seed_{1}_finalEpoch_{2}_model_1.pt'.format(result_path, i, num_epochs))\n",
    "\n",
    "# Average training/validation loss/acc from all seeds\n",
    "avg_train_loss_model_1 = take_avg(full_train_loss_model_1)\n",
    "avg_val_loss_model_1 = take_avg(full_val_loss_model_1)\n",
    "avg_train_acc_model_1 = take_avg(full_train_acc_model_1)\n",
    "avg_val_acc_model_1 = take_avg(full_val_acc_model_1)\n",
    "\n",
    "# Timer end\n",
    "end_time = time.time()\n",
    "time_record.append(end_time)\n",
    "duration = end_time - start_time\n",
    "time_record.append(duration)\n",
    "\n",
    "del init_ulist1, init_wlist1, init_vlist1, model, constraints_u, constraints_v, save_weights_list, train_losses, val_losses, train_accs, val_accs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5472d996",
   "metadata": {},
   "source": [
    "### <span style=\"color:blue\"> *---------------------------------------*</span>\n",
    "## <span style=\"color:blue\"> *The following model with uwv constraints (Matrix): Model_2_UDV*</span>\n",
    "### <span style=\"color:blue\"> *---------------------------------------*</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5e7020c-a8c0-4d3c-8b84-54d3f756f28a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model_2 (UDV): Matrix_bothside constraint on u/v plus UDV_Diag on w; trained once per seed\n",
    "torch.manual_seed(public_seed)\n",
    "\n",
    "# Generate parameters matrices from public seed\n",
    "init_ulist1, init_wlist1, init_ulist2 = seedList1(num_seeds = num_seeds, \n",
    "                                                  public_seed = public_seed,\n",
    "                                                  num_input = num_input,\n",
    "                                                  num_hidden_1 = num_hidden_1, \n",
    "                                                  num_output = num_output\n",
    "                                                 )\n",
    "\n",
    "# To store train/validation loss/accuracy\n",
    "full_train_loss_model_2 = []\n",
    "full_val_loss_model_2 = []\n",
    "full_train_acc_model_2 = []\n",
    "full_val_acc_model_2 = []\n",
    "\n",
    "# Saved weights\n",
    "u_1_M_model_2 = []\n",
    "w_1_M_model_2 = []\n",
    "u_2_M_model_2 = []\n",
    "\n",
    "# Load constraints\n",
    "constraints_uv = Matrix_bothside(normLim = matrix_uvnorm)\n",
    "constraints_d = UDV_Diag(threshold = d_threshold, boundTo = d_boundto)\n",
    "\n",
    "# Timer start\n",
    "start_time = time.time()\n",
    "time_record.append(start_time)\n",
    "\n",
    "for i in range (1, num_seeds + 1):\n",
    "    model = revised_model(name = trans_model,\n",
    "                          train_features = train_features,\n",
    "                          pre_trained = pre_trained, \n",
    "                          model_order = 2, \n",
    "                          num_input = num_input,\n",
    "                          num_hidden_1 = num_hidden_1, \n",
    "                          num_output = num_output)\n",
    "    \n",
    "    # Use reproducible parameters overwrite the random initialization\n",
    "    with torch.no_grad():\n",
    "        check_shapes(model_weight = model.classifier.fc1.weight, list_sample = init_ulist1[i-1])\n",
    "        check_shapes(model_weight = model.classifier.diag1.weight, list_sample = init_wlist1[i-1])\n",
    "        check_shapes(model_weight = model.classifier.fc2.weight, list_sample = init_ulist2[i-1])\n",
    "        model.classifier.fc1.weight = nn.Parameter(init_ulist1[i-1])\n",
    "        model.classifier.diag1.weight = nn.Parameter(init_wlist1[i-1])\n",
    "        model.classifier.fc2.weight = nn.Parameter(init_ulist2[i-1])\n",
    "    \n",
    "    # Load the optimiser selected by optimiser_name in the config cell, so the\n",
    "    # optimiser actually used matches the one encoded in result_path\n",
    "    if optimiser_name == 'SGDM':\n",
    "        optimizer = optim.SGD(model.parameters(), lr = learning_rate, momentum = SGD_M)\n",
    "    else:\n",
    "        optimizer = getattr(optim, optimiser_name)(model.parameters(), lr = learning_rate) # 'Adam', 'NAdam' or 'SGD'\n",
    "    \n",
    "    model, train_losses, val_losses, train_accs, val_accs, save_weights_list = udv_frame_m_uwv(model = model, \n",
    "                                                                                               optimizer = optimizer, \n",
    "                                                                                               loss_fn = loss_fn, \n",
    "                                                                                               train_loader = train_dataloader, \n",
    "                                                                                               val_loader = val_dataloader, \n",
    "                                                                                               num_epochs = num_epochs, \n",
    "                                                                                               device = device,\n",
    "                                                                                               constraints_uv = constraints_uv,\n",
    "                                                                                               constraints_d = constraints_d,\n",
    "                                                                                              ) \n",
    "    # Record training/validation loss/acc of all epochs\n",
    "    full_train_loss_model_2 = store_metrics(full_train_loss_model_2, train_losses)\n",
    "    full_val_loss_model_2 = store_metrics(full_val_loss_model_2, val_losses)\n",
    "    full_train_acc_model_2 = store_metrics(full_train_acc_model_2, train_accs)\n",
    "    full_val_acc_model_2 = store_metrics(full_val_acc_model_2, val_accs)\n",
    "    \n",
    "    # Save model weights\n",
    "    check_shapes(model_weight = model.classifier.fc1.weight, list_sample = save_weights_list[0])\n",
    "    check_shapes(model_weight = model.classifier.diag1.weight, list_sample = save_weights_list[1])\n",
    "    check_shapes(model_weight = model.classifier.fc2.weight, list_sample = save_weights_list[2])\n",
    "    u_1_M_model_2.append(save_weights_list[0])\n",
    "    w_1_M_model_2.append(save_weights_list[1])\n",
    "    u_2_M_model_2.append(save_weights_list[2])\n",
    "    \n",
    "    torch.save(model.state_dict(),'{0}/Seed_{1}_finalEpoch_{2}_model_2.pt'.format(result_path, i, num_epochs))\n",
    "\n",
    "# Average training/validation loss/acc from all seeds\n",
    "avg_train_loss_model_2 = take_avg(full_train_loss_model_2)\n",
    "avg_val_loss_model_2 = take_avg(full_val_loss_model_2)\n",
    "avg_train_acc_model_2 = take_avg(full_train_acc_model_2)\n",
    "avg_val_acc_model_2 = take_avg(full_val_acc_model_2)\n",
    "\n",
    "# Timer end\n",
    "end_time = time.time()\n",
    "time_record.append(end_time)\n",
    "duration = end_time - start_time\n",
    "time_record.append(duration)\n",
    "\n",
    "del init_ulist1, init_wlist1, init_ulist2, model, constraints_uv, constraints_d, save_weights_list, train_losses, val_losses, train_accs, val_accs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4fa2575",
   "metadata": {},
   "source": [
    "### <span style=\"color:blue\"> *---------------------------------------*</span>\n",
    "## <span style=\"color:blue\"> *The following model with uv constraints (Matrix): Model_3_UDV-s*</span>\n",
    "### <span style=\"color:blue\"> *---------------------------------------*</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cde0657-ce62-4848-b202-ed193278f67f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model_3 (UDV-s): only the Matrix_bothside u/v constraint is passed to the training loop; trained once per seed\n",
    "torch.manual_seed(public_seed)\n",
    "\n",
    "# Generate parameters matrices from public seed\n",
    "init_ulist1, init_wlist1, init_ulist2 = seedList1(num_seeds = num_seeds, \n",
    "                                                  public_seed = public_seed,\n",
    "                                                  num_input = num_input,\n",
    "                                                  num_hidden_1 = num_hidden_1, \n",
    "                                                  num_output = num_output\n",
    "                                                 )\n",
    "\n",
    "# To store train/validation loss/accuracy\n",
    "full_train_loss_model_3 = []\n",
    "full_val_loss_model_3 = []\n",
    "full_train_acc_model_3 = []\n",
    "full_val_acc_model_3 = []\n",
    "\n",
    "# Saved weights\n",
    "u_1_M_model_3 = []\n",
    "w_1_M_model_3 = []\n",
    "u_2_M_model_3 = []\n",
    "\n",
    "# Load constraints\n",
    "constraints_uv = Matrix_bothside(normLim = matrix_uvnorm)\n",
    "\n",
    "# Timer start\n",
    "start_time = time.time()\n",
    "time_record.append(start_time)\n",
    "\n",
    "for i in range (1, num_seeds + 1):  \n",
    "    model = revised_model(name = trans_model,\n",
    "                          train_features = train_features,\n",
    "                          pre_trained = pre_trained, \n",
    "                          model_order = 3, \n",
    "                          num_input = num_input,\n",
    "                          num_hidden_1 = num_hidden_1, \n",
    "                          num_output = num_output)\n",
    "    \n",
    "    # Use reproducible parameters overwrite the random initialization\n",
    "    with torch.no_grad():\n",
    "        check_shapes(model_weight = model.classifier.fc1.weight, list_sample = init_ulist1[i-1])\n",
    "        check_shapes(model_weight = model.classifier.diag1.weight, list_sample = init_wlist1[i-1])\n",
    "        check_shapes(model_weight = model.classifier.fc2.weight, list_sample = init_ulist2[i-1])\n",
    "        model.classifier.fc1.weight = nn.Parameter(init_ulist1[i-1])\n",
    "        model.classifier.diag1.weight = nn.Parameter(init_wlist1[i-1])\n",
    "        model.classifier.fc2.weight = nn.Parameter(init_ulist2[i-1])\n",
    "\n",
    "    # Load the optimiser selected by optimiser_name in the config cell, so the\n",
    "    # optimiser actually used matches the one encoded in result_path\n",
    "    if optimiser_name == 'SGDM':\n",
    "        optimizer = optim.SGD(model.parameters(), lr = learning_rate, momentum = SGD_M)\n",
    "    else:\n",
    "        optimizer = getattr(optim, optimiser_name)(model.parameters(), lr = learning_rate) # 'Adam', 'NAdam' or 'SGD'\n",
    "    \n",
    "    model, train_losses, val_losses, train_accs, val_accs, save_weights_list = udv_frame_m_uv(model = model, \n",
    "                                                                                              optimizer = optimizer, \n",
    "                                                                                              loss_fn = loss_fn, \n",
    "                                                                                              train_loader = train_dataloader, \n",
    "                                                                                              val_loader = val_dataloader, \n",
    "                                                                                              num_epochs = num_epochs, \n",
    "                                                                                              device = device,\n",
    "                                                                                              constraints_uv = constraints_uv,\n",
    "                                                                                             )\n",
    "    # Record training/validation loss/acc of all epochs\n",
    "    full_train_loss_model_3 = store_metrics(full_train_loss_model_3, train_losses)\n",
    "    full_val_loss_model_3 = store_metrics(full_val_loss_model_3, val_losses)\n",
    "    full_train_acc_model_3 = store_metrics(full_train_acc_model_3, train_accs)\n",
    "    full_val_acc_model_3 = store_metrics(full_val_acc_model_3, val_accs)\n",
    "    \n",
    "    # Save model weights\n",
    "    check_shapes(model_weight = model.classifier.fc1.weight, list_sample = save_weights_list[0])\n",
    "    check_shapes(model_weight = model.classifier.diag1.weight, list_sample = save_weights_list[1])\n",
    "    check_shapes(model_weight = model.classifier.fc2.weight, list_sample = save_weights_list[2])\n",
    "    u_1_M_model_3.append(save_weights_list[0])\n",
    "    w_1_M_model_3.append(save_weights_list[1])\n",
    "    u_2_M_model_3.append(save_weights_list[2])\n",
    "    \n",
    "    torch.save(model.state_dict(),'{0}/Seed_{1}_finalEpoch_{2}_model_3.pt'.format(result_path, i, num_epochs))\n",
    "\n",
    "# Average training/validation loss/acc from all seeds\n",
    "avg_train_loss_model_3 = take_avg(full_train_loss_model_3)\n",
    "avg_val_loss_model_3 = take_avg(full_val_loss_model_3)\n",
    "avg_train_acc_model_3 = take_avg(full_train_acc_model_3)\n",
    "avg_val_acc_model_3 = take_avg(full_val_acc_model_3)\n",
    "\n",
    "# Timer end\n",
    "end_time = time.time()\n",
    "time_record.append(end_time)\n",
    "duration = end_time - start_time\n",
    "time_record.append(duration)\n",
    "\n",
    "del init_ulist1, init_wlist1, init_ulist2, model, constraints_uv, save_weights_list, train_losses, val_losses, train_accs, val_accs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e630e64e",
   "metadata": {},
   "source": [
    "### <span style=\"color:blue\"> *---------------------------------------*</span>\n",
    "## <span style=\"color:blue\"> *The following model with ReLU: Model_4_UV_ReLU*</span>\n",
    "### <span style=\"color:blue\"> *---------------------------------------*</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af946cce-7c6a-4653-a720-4797bd6f13d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model 4 (UV-ReLU): train the classifier head once per seed, collecting the loss/accuracy\n",
    "# histories, the weight snapshots returned by udv_frame_relu and each seed's final state_dict.\n",
    "# NOTE(review): keep the statement order - the global torch seed is reset first, and seedList1 /\n",
    "# revised_model presumably draw from the torch RNG in turn; confirm before reordering anything.\n",
    "torch.manual_seed(public_seed)\n",
    "\n",
    "# One pre-generated initial weight set per seed (indexed i-1 below); init_wlist1 is unused by this model.\n",
    "init_llist1, init_wlist1, init_llist2 = seedList1(num_seeds = num_seeds, \n",
    "                                                  public_seed = public_seed,\n",
    "                                                  num_input = num_input,\n",
    "                                                  num_hidden_1 = num_hidden_1, \n",
    "                                                  num_output = num_output\n",
    "                                                 )\n",
    "\n",
    "\n",
    "\n",
    "# Per-seed histories and weight snapshots, accumulated across the seed loop.\n",
    "full_train_loss_model_4 = []\n",
    "full_val_loss_model_4 = []\n",
    "full_train_acc_model_4 =[]\n",
    "full_val_acc_model_4 = []\n",
    "\n",
    "l_1_model_4 = []\n",
    "l_2_model_4 = []\n",
    "\n",
    "# time_record receives [start, end, duration] for this model's cell.\n",
    "start_time = time.time()\n",
    "time_record.append(start_time)\n",
    "\n",
    "# Seeds are numbered 1..num_seeds.\n",
    "for i in range (1, num_seeds + 1): \n",
    "    model = revised_model(name = trans_model,\n",
    "                          train_features = train_features,\n",
    "                          pre_trained = pre_trained, \n",
    "                          model_order = 4, \n",
    "                          num_input = num_input,\n",
    "                          num_hidden_1 = num_hidden_1, \n",
    "                          num_output = num_output)\n",
    "    \n",
    "    # Replace the freshly built fc1/fc2 weights with this seed's pre-generated ones (shapes checked first).\n",
    "    # NOTE(review): nn.Parameter shares storage with the wrapped tensor, so training presumably updates\n",
    "    # the init_llist1/init_llist2 entries in place unless the model is moved to another device.\n",
    "    with torch.no_grad():\n",
    "        check_shapes(model_weight = model.classifier.fc1.weight, list_sample = init_llist1[i-1])\n",
    "        check_shapes(model_weight = model.classifier.fc2.weight, list_sample = init_llist2[i-1])\n",
    "        model.classifier.fc1.weight = nn.Parameter(init_llist1[i-1])\n",
    "        model.classifier.fc2.weight = nn.Parameter(init_llist2[i-1])\n",
    "        \n",
    "    optimizer = optim.Adam(model.parameters(), lr = learning_rate) # Or replace 'Adam' by 'NAdam' or 'SGD'\n",
    "    #optimizer = optim.SGD(model.parameters(), lr = learning_rate, momentum = SGD_M) # Only available when 'SGDM' is specified in optimiser_name\n",
    "    \n",
    "    # Train; returns the model, train/val loss and accuracy histories, and save_weights_list.\n",
    "    model, train_losses, val_losses, train_accs, val_accs, save_weights_list = udv_frame_relu(model = model, \n",
    "                                                                                              optimizer = optimizer, \n",
    "                                                                                              loss_fn = loss_fn, \n",
    "                                                                                              train_loader = train_dataloader, \n",
    "                                                                                              val_loader = val_dataloader, \n",
    "                                                                                              num_epochs = num_epochs, \n",
    "                                                                                              device = device\n",
    "                                                                                             )    \n",
    "    # Append this seed's histories; take_avg below combines them across seeds.\n",
    "    full_train_loss_model_4 = store_metrics(full_train_loss_model_4, train_losses)\n",
    "    full_val_loss_model_4 = store_metrics(full_val_loss_model_4, val_losses)\n",
    "    full_train_acc_model_4 = store_metrics(full_train_acc_model_4, train_accs)\n",
    "    full_val_acc_model_4 = store_metrics(full_val_acc_model_4, val_accs)\n",
    "    \n",
    "    # save_weights_list[0] / [1] are shape-checked against fc1 / fc2 before being kept.\n",
    "    check_shapes(model_weight = model.classifier.fc1.weight, list_sample = save_weights_list[0])\n",
    "    check_shapes(model_weight = model.classifier.fc2.weight, list_sample = save_weights_list[1])\n",
    "    l_1_model_4.append(save_weights_list[0])\n",
    "    l_2_model_4.append(save_weights_list[1])\n",
    "    \n",
    "    # Persist this seed's final weights under result_path.\n",
    "    torch.save(model.state_dict(),'{0}/Seed_{1}_finalEpoch_{2}_model_4.pt'.format(result_path, i, num_epochs))\n",
    "    \n",
    "avg_train_loss_model_4 = take_avg(full_train_loss_model_4)\n",
    "avg_val_loss_model_4 = take_avg(full_val_loss_model_4)\n",
    "avg_train_acc_model_4 = take_avg(full_train_acc_model_4)\n",
    "avg_val_acc_model_4 = take_avg(full_val_acc_model_4)\n",
    "\n",
    "end_time = time.time()\n",
    "time_record.append(end_time)\n",
    "duration = end_time - start_time\n",
    "time_record.append(duration)\n",
    "\n",
    "# Drop per-model objects before the next model's cell runs.\n",
    "del init_llist1, init_wlist1, init_llist2, model, save_weights_list, train_losses, val_losses, train_accs, val_accs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7fb229a",
   "metadata": {},
   "source": [
    "### <span style=\"color:blue\"> *---------------------------------------*</span>\n",
    "## <span style=\"color:blue\"> *The following model uses fully connected layers (no activation): Model_5_UV*</span>\n",
    "### <span style=\"color:blue\"> *---------------------------------------*</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e8111ee-fcc5-4102-b974-abb37dbe0ac7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model 5 (UV, no activation): train the classifier head once per seed, collecting the loss/accuracy\n",
    "# histories, the weight snapshots returned by udv_frame_LAct and, like the Model 3 and Model 4 cells,\n",
    "# each seed's final state_dict.\n",
    "# NOTE(review): keep the statement order - the global torch seed is reset first, and seedList1 /\n",
    "# revised_model presumably draw from the torch RNG in turn; confirm before reordering anything.\n",
    "torch.manual_seed(public_seed)\n",
    "\n",
    "# One pre-generated initial weight set per seed (indexed i-1 below); init_wlist1 is unused by this model.\n",
    "init_llist1, init_wlist1, init_llist2 = seedList1(num_seeds = num_seeds, \n",
    "                                                  public_seed = public_seed,\n",
    "                                                  num_input = num_input,\n",
    "                                                  num_hidden_1 = num_hidden_1, \n",
    "                                                  num_output = num_output\n",
    "                                                 )\n",
    "\n",
    "# Per-seed histories and weight snapshots, accumulated across the seed loop.\n",
    "full_train_loss_model_5 = []\n",
    "full_val_loss_model_5 = []\n",
    "full_train_acc_model_5 =[]\n",
    "full_val_acc_model_5 = []\n",
    "\n",
    "l_1_model_5 = []\n",
    "l_2_model_5 = []\n",
    "\n",
    "# time_record receives [start, end, duration] for this model's cell.\n",
    "start_time = time.time()\n",
    "time_record.append(start_time)\n",
    "\n",
    "# Seeds are numbered 1..num_seeds.\n",
    "for i in range (1, num_seeds + 1):\n",
    "    model = revised_model(name = trans_model,\n",
    "                          train_features = train_features,\n",
    "                          pre_trained = pre_trained, \n",
    "                          model_order = 5, \n",
    "                          num_input = num_input,\n",
    "                          num_hidden_1 = num_hidden_1, \n",
    "                          num_output = num_output)\n",
    "    \n",
    "    # Replace the freshly built fc1/fc2 weights with this seed's pre-generated ones (shapes checked first).\n",
    "    with torch.no_grad():\n",
    "        check_shapes(model_weight = model.classifier.fc1.weight, list_sample = init_llist1[i-1])\n",
    "        check_shapes(model_weight = model.classifier.fc2.weight, list_sample = init_llist2[i-1])\n",
    "        model.classifier.fc1.weight = nn.Parameter(init_llist1[i-1])\n",
    "        model.classifier.fc2.weight = nn.Parameter(init_llist2[i-1])\n",
    "\n",
    "    optimizer = optim.Adam(model.parameters(), lr = learning_rate) # Or replace 'Adam' by 'NAdam' or 'SGD'\n",
    "    #optimizer = optim.SGD(model.parameters(), lr = learning_rate, momentum = SGD_M) # Only available when 'SGDM' is specified in optimiser_name\n",
    "    \n",
    "    # Train; returns the model, train/val loss and accuracy histories, and save_weights_list.\n",
    "    model, train_losses, val_losses, train_accs, val_accs, save_weights_list = udv_frame_LAct(model = model, \n",
    "                                                                                              optimizer = optimizer, \n",
    "                                                                                              loss_fn = loss_fn, \n",
    "                                                                                              train_loader = train_dataloader, \n",
    "                                                                                              val_loader = val_dataloader, \n",
    "                                                                                              num_epochs = num_epochs, \n",
    "                                                                                              device = device\n",
    "                                                                                             )  \n",
    "    # Append this seed's histories; take_avg below combines them across seeds.\n",
    "    full_train_loss_model_5 = store_metrics(full_train_loss_model_5, train_losses)\n",
    "    full_val_loss_model_5 = store_metrics(full_val_loss_model_5, val_losses)\n",
    "    full_train_acc_model_5 = store_metrics(full_train_acc_model_5, train_accs)\n",
    "    full_val_acc_model_5 = store_metrics(full_val_acc_model_5, val_accs)\n",
    "    \n",
    "    # save_weights_list[0] / [1] are shape-checked against fc1 / fc2 before being kept.\n",
    "    check_shapes(model_weight = model.classifier.fc1.weight, list_sample = save_weights_list[0])\n",
    "    check_shapes(model_weight = model.classifier.fc2.weight, list_sample = save_weights_list[1])\n",
    "    l_1_model_5.append(save_weights_list[0])\n",
    "    l_2_model_5.append(save_weights_list[1])\n",
    "    \n",
    "    # Persist this seed's final weights under result_path, matching the Model 3 / Model 4 naming scheme.\n",
    "    torch.save(model.state_dict(),'{0}/Seed_{1}_finalEpoch_{2}_model_5.pt'.format(result_path, i, num_epochs))\n",
    "    \n",
    "avg_train_loss_model_5 = take_avg(full_train_loss_model_5)\n",
    "avg_val_loss_model_5 = take_avg(full_val_loss_model_5)\n",
    "avg_train_acc_model_5 = take_avg(full_train_acc_model_5)\n",
    "avg_val_acc_model_5 = take_avg(full_val_acc_model_5)\n",
    "\n",
    "end_time = time.time()\n",
    "time_record.append(end_time)\n",
    "duration = end_time - start_time\n",
    "time_record.append(duration)\n",
    "\n",
    "# Drop per-model objects before the next model's cell runs.\n",
    "del init_llist1, init_wlist1, init_llist2, model, save_weights_list, train_losses, val_losses, train_accs, val_accs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4b7ee3f8",
   "metadata": {},
   "source": [
    "### <span style=\"color:blue\"> *---------------------------------------*</span>\n",
    "## <span style=\"color:blue\"> *The following model is the original transferred model: Model_6*</span>\n",
    "### <span style=\"color:blue\"> *---------------------------------------*</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b443c021-0d79-435d-ac61-f1ef5ba15d7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model 6: the original transferred model, trained once per seed. No custom weight initialisation\n",
    "# is applied here, and udv_frame_Orig returns no weight snapshots, so only metrics are collected.\n",
    "# NOTE(review): unlike the Model 3 and Model 4 cells, no per-seed state_dict is saved here.\n",
    "torch.manual_seed(public_seed)\n",
    "\n",
    "# Per-seed histories, accumulated across the seed loop.\n",
    "full_train_loss_model_6 = []\n",
    "full_val_loss_model_6 = []\n",
    "full_train_acc_model_6 =[]\n",
    "full_val_acc_model_6 = []\n",
    "\n",
    "# time_record receives [start, end, duration] for this model's cell.\n",
    "start_time = time.time()\n",
    "time_record.append(start_time)\n",
    "\n",
    "# Seeds are numbered 1..num_seeds; presumably each revised_model call draws fresh weights from the torch RNG.\n",
    "for i in range (1, num_seeds + 1):\n",
    "    model = revised_model(name = trans_model,\n",
    "                          train_features = train_features,\n",
    "                          pre_trained = pre_trained, \n",
    "                          model_order = 6, \n",
    "                          num_input = num_input,\n",
    "                          num_hidden_1 = num_hidden_1, \n",
    "                          num_output = num_output)\n",
    "    \n",
    "    optimizer = optim.Adam(model.parameters(), lr = learning_rate) # Or replace 'Adam' by 'NAdam' or 'SGD'\n",
    "    #optimizer = optim.SGD(model.parameters(), lr = learning_rate, momentum = SGD_M) # Only available when 'SGDM' is specified in optimiser_name\n",
    "    \n",
    "    # Train; returns the model and its train/val loss and accuracy histories.\n",
    "    model, train_losses, val_losses, train_accs, val_accs = udv_frame_Orig(model = model, \n",
    "                                                                           optimizer = optimizer, \n",
    "                                                                           loss_fn = loss_fn, \n",
    "                                                                           train_loader = train_dataloader, \n",
    "                                                                           val_loader = val_dataloader, \n",
    "                                                                           num_epochs = num_epochs, \n",
    "                                                                           device = device,\n",
    "                                                                          )  \n",
    "    # Append this seed's histories; take_avg below combines them across seeds.\n",
    "    full_train_loss_model_6 = store_metrics(full_train_loss_model_6, train_losses)\n",
    "    full_val_loss_model_6 = store_metrics(full_val_loss_model_6, val_losses)    \n",
    "    full_train_acc_model_6 = store_metrics(full_train_acc_model_6, train_accs)\n",
    "    full_val_acc_model_6 = store_metrics(full_val_acc_model_6, val_accs)\n",
    "    \n",
    "avg_train_loss_model_6 = take_avg(full_train_loss_model_6)\n",
    "avg_val_loss_model_6 = take_avg(full_val_loss_model_6)\n",
    "avg_train_acc_model_6 = take_avg(full_train_acc_model_6)\n",
    "avg_val_acc_model_6 = take_avg(full_val_acc_model_6)\n",
    "\n",
    "end_time = time.time()\n",
    "time_record.append(end_time)\n",
    "duration = end_time - start_time\n",
    "time_record.append(duration)\n",
    "\n",
    "# Drop per-model objects once training is finished.\n",
    "del model, train_losses, val_losses, train_accs, val_accs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "073ce221",
   "metadata": {},
   "source": [
    "### <span style=\"color:blue\"> *---------------------------------------*</span>\n",
    "## <span style=\"color:blue\"> *Save Variables*</span>\n",
    "### <span style=\"color:blue\"> *---------------------------------------*</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d295a8bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save variables for analysis: the run configuration, per-seed histories (full_*),\n",
    "# seed averages (avg_*), weight snapshots and timings of every model are written\n",
    "# into a single pickle file under result_path.\n",
    "\n",
    "store_file_path = f'{result_path}/results.pkl'\n",
    "\n",
    "# Keyword order below fixes the key order of the saved dict.\n",
    "variables = dict(\n",
    "    # --- run configuration ---\n",
    "    num_input=num_input,\n",
    "    num_hidden_1=num_hidden_1,\n",
    "    num_output=num_output,\n",
    "    public_seed=public_seed,\n",
    "    data_path=data_path,\n",
    "    batch_size=batch_size,\n",
    "    num_epochs=num_epochs,\n",
    "    num_seeds=num_seeds,\n",
    "    learning_rate=learning_rate,\n",
    "    device=device,\n",
    "    optimiser_name=optimiser_name,\n",
    "    trans_model=trans_model,\n",
    "\n",
    "    # --- losses ---\n",
    "    full_train_loss_model_0=full_train_loss_model_0,\n",
    "    full_train_loss_model_1=full_train_loss_model_1,\n",
    "    full_train_loss_model_2=full_train_loss_model_2,\n",
    "    full_train_loss_model_3=full_train_loss_model_3,\n",
    "    full_train_loss_model_4=full_train_loss_model_4,\n",
    "    full_train_loss_model_5=full_train_loss_model_5,\n",
    "    full_train_loss_model_6=full_train_loss_model_6,\n",
    "    avg_train_loss_model_0=avg_train_loss_model_0,\n",
    "    avg_train_loss_model_1=avg_train_loss_model_1,\n",
    "    avg_train_loss_model_2=avg_train_loss_model_2,\n",
    "    avg_train_loss_model_3=avg_train_loss_model_3,\n",
    "    avg_train_loss_model_4=avg_train_loss_model_4,\n",
    "    avg_train_loss_model_5=avg_train_loss_model_5,\n",
    "    avg_train_loss_model_6=avg_train_loss_model_6,\n",
    "    full_val_loss_model_0=full_val_loss_model_0,\n",
    "    full_val_loss_model_1=full_val_loss_model_1,\n",
    "    full_val_loss_model_2=full_val_loss_model_2,\n",
    "    full_val_loss_model_3=full_val_loss_model_3,\n",
    "    full_val_loss_model_4=full_val_loss_model_4,\n",
    "    full_val_loss_model_5=full_val_loss_model_5,\n",
    "    full_val_loss_model_6=full_val_loss_model_6,\n",
    "    avg_val_loss_model_0=avg_val_loss_model_0,\n",
    "    avg_val_loss_model_1=avg_val_loss_model_1,\n",
    "    avg_val_loss_model_2=avg_val_loss_model_2,\n",
    "    avg_val_loss_model_3=avg_val_loss_model_3,\n",
    "    avg_val_loss_model_4=avg_val_loss_model_4,\n",
    "    avg_val_loss_model_5=avg_val_loss_model_5,\n",
    "    avg_val_loss_model_6=avg_val_loss_model_6,\n",
    "\n",
    "    # --- accuracies ---\n",
    "    full_train_acc_model_0=full_train_acc_model_0,\n",
    "    full_train_acc_model_1=full_train_acc_model_1,\n",
    "    full_train_acc_model_2=full_train_acc_model_2,\n",
    "    full_train_acc_model_3=full_train_acc_model_3,\n",
    "    full_train_acc_model_4=full_train_acc_model_4,\n",
    "    full_train_acc_model_5=full_train_acc_model_5,\n",
    "    full_train_acc_model_6=full_train_acc_model_6,\n",
    "    avg_train_acc_model_0=avg_train_acc_model_0,\n",
    "    avg_train_acc_model_1=avg_train_acc_model_1,\n",
    "    avg_train_acc_model_2=avg_train_acc_model_2,\n",
    "    avg_train_acc_model_3=avg_train_acc_model_3,\n",
    "    avg_train_acc_model_4=avg_train_acc_model_4,\n",
    "    avg_train_acc_model_5=avg_train_acc_model_5,\n",
    "    avg_train_acc_model_6=avg_train_acc_model_6,\n",
    "    full_val_acc_model_0=full_val_acc_model_0,\n",
    "    full_val_acc_model_1=full_val_acc_model_1,\n",
    "    full_val_acc_model_2=full_val_acc_model_2,\n",
    "    full_val_acc_model_3=full_val_acc_model_3,\n",
    "    full_val_acc_model_4=full_val_acc_model_4,\n",
    "    full_val_acc_model_5=full_val_acc_model_5,\n",
    "    full_val_acc_model_6=full_val_acc_model_6,\n",
    "    avg_val_acc_model_0=avg_val_acc_model_0,\n",
    "    avg_val_acc_model_1=avg_val_acc_model_1,\n",
    "    avg_val_acc_model_2=avg_val_acc_model_2,\n",
    "    avg_val_acc_model_3=avg_val_acc_model_3,\n",
    "    avg_val_acc_model_4=avg_val_acc_model_4,\n",
    "    avg_val_acc_model_5=avg_val_acc_model_5,\n",
    "    avg_val_acc_model_6=avg_val_acc_model_6,\n",
    "\n",
    "    # --- weight snapshots (models 0-5; model 6 keeps none) ---\n",
    "    u_1_V_model_0=u_1_V_model_0,\n",
    "    v_1_V_model_0=v_1_V_model_0,\n",
    "    w_1_V_model_0=w_1_V_model_0,\n",
    "    u_1_V_model_1=u_1_V_model_1,\n",
    "    v_1_V_model_1=v_1_V_model_1,\n",
    "    w_1_V_model_1=w_1_V_model_1,\n",
    "    u_1_M_model_2=u_1_M_model_2,\n",
    "    u_2_M_model_2=u_2_M_model_2,\n",
    "    w_1_M_model_2=w_1_M_model_2,\n",
    "    u_1_M_model_3=u_1_M_model_3,\n",
    "    u_2_M_model_3=u_2_M_model_3,\n",
    "    w_1_M_model_3=w_1_M_model_3,\n",
    "    l_1_model_4=l_1_model_4,\n",
    "    l_2_model_4=l_2_model_4,\n",
    "    l_1_model_5=l_1_model_5,\n",
    "    l_2_model_5=l_2_model_5,\n",
    "\n",
    "    # --- timings: [start, end, duration] appended by each model cell ---\n",
    "    time_record=time_record,\n",
    ")\n",
    "\n",
    "with open(store_file_path, mode='wb') as pkl_file:\n",
    "    pickle.dump(variables, pkl_file)\n",
    "\n",
    "print(\"Saving pickle file, done\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a508d100-948a-4e68-949d-482c021ad4ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summarise every model: reduce each seed-averaged curve (avg_*) to one number by averaging\n",
    "# its last `last_n` epochs with avg_stable, then print one row per model.\n",
    "from udvFunctions.udvOtherFunctions import avg_stable\n",
    "\n",
    "last_n = 5            # how many epochs (from end) are averaged\n",
    "final_train_loss_m0 = avg_stable(avg_train_loss_model_0, last_n)\n",
    "final_train_loss_m1 = avg_stable(avg_train_loss_model_1, last_n)\n",
    "final_train_loss_m2 = avg_stable(avg_train_loss_model_2, last_n)\n",
    "final_train_loss_m3 = avg_stable(avg_train_loss_model_3, last_n)\n",
    "final_train_loss_m4 = avg_stable(avg_train_loss_model_4, last_n)\n",
    "final_train_loss_m5 = avg_stable(avg_train_loss_model_5, last_n)\n",
    "final_train_loss_m6 = avg_stable(avg_train_loss_model_6, last_n)\n",
    "\n",
    "final_val_loss_m0 = avg_stable(avg_val_loss_model_0, last_n)\n",
    "final_val_loss_m1 = avg_stable(avg_val_loss_model_1, last_n)\n",
    "final_val_loss_m2 = avg_stable(avg_val_loss_model_2, last_n)\n",
    "final_val_loss_m3 = avg_stable(avg_val_loss_model_3, last_n)\n",
    "final_val_loss_m4 = avg_stable(avg_val_loss_model_4, last_n)\n",
    "final_val_loss_m5 = avg_stable(avg_val_loss_model_5, last_n)\n",
    "final_val_loss_m6 = avg_stable(avg_val_loss_model_6, last_n)\n",
    "\n",
    "final_train_acc_m0 = avg_stable(avg_train_acc_model_0, last_n)\n",
    "final_train_acc_m1 = avg_stable(avg_train_acc_model_1, last_n)\n",
    "final_train_acc_m2 = avg_stable(avg_train_acc_model_2, last_n)\n",
    "final_train_acc_m3 = avg_stable(avg_train_acc_model_3, last_n)\n",
    "final_train_acc_m4 = avg_stable(avg_train_acc_model_4, last_n)\n",
    "final_train_acc_m5 = avg_stable(avg_train_acc_model_5, last_n)\n",
    "final_train_acc_m6 = avg_stable(avg_train_acc_model_6, last_n)\n",
    "\n",
    "final_val_acc_m0 = avg_stable(avg_val_acc_model_0, last_n)\n",
    "final_val_acc_m1 = avg_stable(avg_val_acc_model_1, last_n)\n",
    "final_val_acc_m2 = avg_stable(avg_val_acc_model_2, last_n)\n",
    "final_val_acc_m3 = avg_stable(avg_val_acc_model_3, last_n)\n",
    "final_val_acc_m4 = avg_stable(avg_val_acc_model_4, last_n)\n",
    "final_val_acc_m5 = avg_stable(avg_val_acc_model_5, last_n)\n",
    "final_val_acc_m6 = avg_stable(avg_val_acc_model_6, last_n)\n",
    "\n",
    "print('Batch size: {0}\\n#Seeds: {1}\\n#Epochs: {2}\\nOptimiser: {3}\\nLearning Rate: {4}\\n#Hidden_1: {5}'.format(batch_size, num_seeds, num_epochs, optimiser_name, learning_rate, num_hidden_1))\n",
    "print('Final results are averaged from the last {0} epochs\\n'.format(last_n))\n",
    "# Each row must pass its own model's final_* values (suffix mK for model_K).\n",
    "print('model_0: UDV-v1            (avg:-{0}): T-loss {1}; V-loss {2}; T-acc {3}; V-acc {4}'.format(last_n, final_train_loss_m0, final_val_loss_m0, final_train_acc_m0, final_val_acc_m0))\n",
    "print('model_1: UDV-v2            (avg:-{0}): T-loss {1}; V-loss {2}; T-acc {3}; V-acc {4}'.format(last_n, final_train_loss_m1, final_val_loss_m1, final_train_acc_m1, final_val_acc_m1))\n",
    "print('model_2: UDV               (avg:-{0}): T-loss {1}; V-loss {2}; T-acc {3}; V-acc {4}'.format(last_n, final_train_loss_m2, final_val_loss_m2, final_train_acc_m2, final_val_acc_m2))\n",
    "print('model_3: UDV-s             (avg:-{0}): T-loss {1}; V-loss {2}; T-acc {3}; V-acc {4}'.format(last_n, final_train_loss_m3, final_val_loss_m3, final_train_acc_m3, final_val_acc_m3))\n",
    "print('model_4: UV-ReLU           (avg:-{0}): T-loss {1}; V-loss {2}; T-acc {3}; V-acc {4}'.format(last_n, final_train_loss_m4, final_val_loss_m4, final_train_acc_m4, final_val_acc_m4))\n",
    "print('model_5: UV                (avg:-{0}): T-loss {1}; V-loss {2}; T-acc {3}; V-acc {4}'.format(last_n, final_train_loss_m5, final_val_loss_m5, final_train_acc_m5, final_val_acc_m5))\n",
    "print('model_6: Transferred Model (avg:-{0}): T-loss {1}; V-loss {2}; T-acc {3}; V-acc {4}'.format(last_n, final_train_loss_m6, final_val_loss_m6, final_train_acc_m6, final_val_acc_m6))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
