{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e7325f48-2cd4-4cc2-a149-7eef1ed29f01",
   "metadata": {
    "tags": []
   },
   "source": [
    "### <span style=\"color:blue\"> *---------------------------------------*</span>\n",
    "# <span style=\"color:blue\"> *UDV project*</span>\n",
    "### <span style=\"color:blue\"> *---------------------------------------*</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d62d1060",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports libraries\n",
    "\n",
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import pandas as pd\n",
    "\n",
    "from sklearn.preprocessing import MinMaxScaler, LabelEncoder\n",
    "from sklearn.impute import SimpleImputer\n",
    "\n",
    "import pickle\n",
    "import time\n",
    "import datetime\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3fbd59a-b3fe-4c1e-99af-f1e41904c67c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fix PyTorch's global RNG so the whole notebook is reproducible;\n",
    "# later cells re-apply this same seed before every random step\n",
    "public_seed = 529\n",
    "torch.manual_seed(public_seed)\n",
    "\n",
    "print(f\"Current Seed is {public_seed}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64521320-8885-459a-b69d-af11048f4b7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Obtain the compute device from the project helper;\n",
    "# every training frame below receives it as `device`\n",
    "from udvFunctions.udvDevice import get_device\n",
    "\n",
    "device = get_device()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f3e66bc-2cf3-4763-8d96-eecf36485b88",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load dataset (path is relative to the current working directory);\n",
    "# its `features` attribute is used later to size the networks\n",
    "data_path = \"./NYC_Orig.csv\"\n",
    "\n",
    "from udvFunctions.udvDatasetPreprocessing import NYCDataset\n",
    "dataset = NYCDataset(data_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07cee382",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare UDV constraints\n",
    "\n",
    "from udvFunctions.udvConstraints import Matrix_bothside, UDV_Diag\n",
    "from udvFunctions.udvConstraints import Vector_left_U, Vector_right_V"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86ab3012",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare the single-connection layer that carry weight matrix 'w'\n",
    "\n",
    "from udvFunctions.udvDiagonalLayer import D_singleConnection   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14bfcc20-2be0-4187-8e02-b2090461410e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare networks\n",
    "\n",
    "from udvFunctions.udvRegNetworks import UDV_net_1\n",
    "from udvFunctions.udvRegNetworks import relu_net_1, fc_net_1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6012e40-8ae9-4449-a1ce-6f662ef830be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare loss function (optional)\n",
    "\n",
    "from udvFunctions.udvLoss import UDV_Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d97a9e38-9328-420a-a540-2212c8ad67f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare training framework\n",
    "\n",
    "# Model_0: Vector_uwv\n",
    "from udvFunctions.udvRegVectorUWV import udv_frame_v_uwv\n",
    "\n",
    "# Model_1: Vector_uv\n",
    "from udvFunctions.udvRegVectorUV import udv_frame_v_uv\n",
    "\n",
    "# Model_2: Matrix_uwv\n",
    "from udvFunctions.udvRegMatrixUWV import udv_frame_m_uwv\n",
    "\n",
    "# Model_3: Matrix_uv\n",
    "from udvFunctions.udvRegMatrixUV import udv_frame_m_uv\n",
    "\n",
    "# Model_4: ReLU\n",
    "from udvFunctions.udvRegReLU import udv_frame_relu\n",
    "\n",
    "# Model_5: Linear activation\n",
    "from udvFunctions.udvRegLinear import udv_frame_LAct"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81c08633-a1dc-46d6-b137-1dd3f3bc41ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Identical initialisation\n",
    "\n",
    "from udvFunctions.udvSameInit import seedList1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd86ae55-4c1e-4fe5-8fc6-c3d9ba51536e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Other customised functions\n",
    "\n",
    "from udvFunctions.udvOtherFunctions import store_metrics, take_avg, check_shapes"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c0cd1610",
   "metadata": {
    "tags": []
   },
   "source": [
    "### <span style=\"color:blue\"> *---------------------------------------*</span>\n",
    "# <span style=\"color:blue\"> *Setting Experiments:*</span>\n",
    "### <span style=\"color:blue\"> *---------------------------------------*</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c48d712b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "torch.manual_seed(public_seed)\n",
    "\n",
    "# Dataset split (re-seeded above, so the split is identical on every run)\n",
    "training_ratio = 0.8\n",
    "train_size = int(training_ratio * len(dataset))\n",
    "validation_size = len(dataset) - train_size\n",
    "train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, validation_size])\n",
    "\n",
    "# Set number of epochs and how many different seeds for training\n",
    "num_epochs = 50         # Number of epochs; Min:5\n",
    "num_seeds = 100         # Number of seeds;  Min:1\n",
    "# Enforce the documented minimums up front instead of failing deep inside a training cell\n",
    "if num_epochs < 5:\n",
    "    raise ValueError('num_epochs must be at least 5, got {0}'.format(num_epochs))\n",
    "if num_seeds < 1:\n",
    "    raise ValueError('num_seeds must be at least 1, got {0}'.format(num_seeds))\n",
    "\n",
    "# Network sizes\n",
    "num_input = len(dataset.features[0,:])   # Number of input features\n",
    "num_output = 1                           # Number of output features\n",
    "# Number of hidden neurons in fully connected layer 1 (the constrained layer)\n",
    "num_hidden_1 = round((((num_output+2)*num_input)**0.5) + (2*((num_input/(num_output+2))**0.5)))\n",
    "# Discarded by the single-layer models (kept for an expandable design); still part of result_path\n",
    "num_hidden_2 = round(num_output*((num_input/(num_output+2))**0.5))\n",
    "\n",
    "# Load data\n",
    "batch_size = 128\n",
    "train_dataloader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)\n",
    "val_dataloader = DataLoader(val_dataset, batch_size = batch_size, shuffle = False)\n",
    "\n",
    "# Set optimiser (built in each model cell) and loss function\n",
    "supported_optimisers = ('Adam', 'NAdam', 'SGD', 'SGDM')\n",
    "optimiser_name = 'Adam'             # One of supported_optimisers\n",
    "SGD_M = 0.9                         # Momentum; only used when optimiser_name is 'SGDM'\n",
    "learning_rate = 1e-4\n",
    "# optimiser_name is written into result_path, so reject names that no model cell can build\n",
    "if optimiser_name not in supported_optimisers:\n",
    "    raise ValueError('optimiser_name must be one of {0}, got {1!r}'.format(supported_optimisers, optimiser_name))\n",
    "\n",
    "# loss_fn = UDV_Loss()    # MSE/2\n",
    "loss_fn = nn.MSELoss()  # MSE\n",
    "\n",
    "# Constraints setting\n",
    "vector_u_norm = 1\n",
    "vector_v_norm = 1\n",
    "matrix_uvnorm = 1\n",
    "d_threshold = 0\n",
    "d_boundto = 0\n",
    "\n",
    "# A fully connected network without any activation produces many NaNs with a high\n",
    "# learning rate, so the linear-activation model is only trained when learning_rate < 1\n",
    "test_LinearAct = learning_rate < 1\n",
    "\n",
    "# Time record: each model cell appends its start time, end time and duration\n",
    "time_record = []\n",
    "\n",
    "# Set result path (exist_ok avoids the check-then-create race)\n",
    "result_path = './{0}_{1}_H1_{2}_H2_{3}_BS{4}_E{5}_S{6}/SingleLayer'.format(optimiser_name, learning_rate, num_hidden_1, num_hidden_2, batch_size, num_epochs, num_seeds)\n",
    "os.makedirs(result_path, exist_ok = True)\n",
    "\n",
    "# Reproducible setting\n",
    "torch.backends.cudnn.deterministic = True\n",
    "torch.backends.cudnn.benchmark = False"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6fa48c7",
   "metadata": {
    "tags": []
   },
   "source": [
    "### <span style=\"color:blue\"> *---------------------------------------*</span>\n",
    "## <span style=\"color:blue\"> *The following model uses uwv constraints (Vector): Model_0_UDV-v1*</span>\n",
    "### <span style=\"color:blue\"> *---------------------------------------*</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba5c9d38-ec9f-443e-ad53-8e9c8c17f49a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "torch.manual_seed(public_seed)\n",
    "\n",
    "# Generate the per-seed initial parameters (u, w, v) from the public seed;\n",
    "# every model cell regenerates these lists with the same arguments\n",
    "init_ulist1, init_wlist1, init_vlist1 = seedList1(num_seeds = num_seeds, \n",
    "                                                  public_seed = public_seed,\n",
    "                                                  num_input = num_input,\n",
    "                                                  num_hidden_1 = num_hidden_1, \n",
    "                                                  num_output = num_output\n",
    "                                                 )\n",
    "\n",
    "# To store train/validation loss of every seed\n",
    "full_train_loss_model_0 = []\n",
    "full_val_loss_model_0 = []\n",
    "\n",
    "# Trained weights of every seed\n",
    "u_1_V_model_0 = []\n",
    "w_1_V_model_0 = []\n",
    "v_1_V_model_0 = []\n",
    "\n",
    "# Load constraints\n",
    "constraints_u = Vector_left_U(normLim = vector_u_norm)\n",
    "constraints_d = UDV_Diag(threshold = d_threshold, boundTo = d_boundto)\n",
    "constraints_v = Vector_right_V(normLim = vector_v_norm)\n",
    "\n",
    "# Timer start\n",
    "start_time = time.time()\n",
    "time_record.append(start_time)\n",
    "\n",
    "# Model training and validation\n",
    "for i in range(1, num_seeds + 1):\n",
    "    # Load constrained model\n",
    "    model = UDV_net_1(num_input = num_input, num_hidden_1 = num_hidden_1, num_output = num_output)\n",
    "\n",
    "    # Use reproducible parameters to overwrite the random initialisation\n",
    "    with torch.no_grad():\n",
    "        check_shapes(model_weight = model.fc1.weight, list_sample = init_ulist1[i-1])\n",
    "        check_shapes(model_weight = model.diag1.weight, list_sample = init_wlist1[i-1])\n",
    "        check_shapes(model_weight = model.fc2.weight, list_sample = init_vlist1[i-1])\n",
    "        model.fc1.weight = nn.Parameter(init_ulist1[i-1])\n",
    "        model.diag1.weight = nn.Parameter(init_wlist1[i-1])\n",
    "        model.fc2.weight = nn.Parameter(init_vlist1[i-1])\n",
    "\n",
    "    # Load optimiser: honour optimiser_name from the setup cell so the run\n",
    "    # matches the optimiser recorded in result_path\n",
    "    if optimiser_name == 'SGDM':\n",
    "        optimizer = optim.SGD(model.parameters(), lr = learning_rate, momentum = SGD_M)\n",
    "    else:\n",
    "        optimizer = getattr(optim, optimiser_name)(model.parameters(), lr = learning_rate)  # 'Adam', 'NAdam' or 'SGD'\n",
    "\n",
    "    # Model training/validation loops\n",
    "    model, train_losses, val_losses, save_weights_list = udv_frame_v_uwv(model = model, \n",
    "                                                                         optimizer = optimizer, \n",
    "                                                                         loss_fn = loss_fn, \n",
    "                                                                         train_loader = train_dataloader, \n",
    "                                                                         val_loader = val_dataloader, \n",
    "                                                                         num_epochs = num_epochs, \n",
    "                                                                         device = device,\n",
    "                                                                         constraints_u = constraints_u,\n",
    "                                                                         constraints_d = constraints_d,\n",
    "                                                                         constraints_v = constraints_v\n",
    "                                                                        )\n",
    "    # Record training/validation loss of all epochs\n",
    "    full_train_loss_model_0 = store_metrics(full_train_loss_model_0, train_losses)\n",
    "    full_val_loss_model_0 = store_metrics(full_val_loss_model_0, val_losses)\n",
    "\n",
    "    # Save model weights\n",
    "    check_shapes(model_weight = model.fc1.weight, list_sample = save_weights_list[0])\n",
    "    check_shapes(model_weight = model.diag1.weight, list_sample = save_weights_list[1])\n",
    "    check_shapes(model_weight = model.fc2.weight, list_sample = save_weights_list[2])\n",
    "    u_1_V_model_0.append(save_weights_list[0])\n",
    "    w_1_V_model_0.append(save_weights_list[1])\n",
    "    v_1_V_model_0.append(save_weights_list[2])\n",
    "\n",
    "# Average training/validation loss from all seeds\n",
    "avg_train_loss_model_0 = take_avg(full_train_loss_model_0)\n",
    "avg_val_loss_model_0 = take_avg(full_val_loss_model_0)\n",
    "\n",
    "# Timer end\n",
    "end_time = time.time()\n",
    "time_record.append(end_time)\n",
    "duration = end_time - start_time\n",
    "time_record.append(duration)\n",
    "\n",
    "del init_ulist1, init_wlist1, init_vlist1, model, constraints_u, constraints_d, constraints_v, save_weights_list"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb4a5fd8",
   "metadata": {
    "tags": []
   },
   "source": [
    "### <span style=\"color:blue\"> *---------------------------------------*</span>\n",
    "## <span style=\"color:blue\"> *The following model uses uv constraints (Vector): Model_1_UDV-v2*</span>\n",
    "### <span style=\"color:blue\"> *---------------------------------------*</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a0780be-de49-49f8-9823-0845cacd57ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(public_seed)\n",
    "\n",
    "# Same per-seed initial parameters as every other model cell\n",
    "init_ulist1, init_wlist1, init_vlist1 = seedList1(num_seeds = num_seeds, \n",
    "                                                  public_seed = public_seed,\n",
    "                                                  num_input = num_input,\n",
    "                                                  num_hidden_1 = num_hidden_1, \n",
    "                                                  num_output = num_output\n",
    "                                                 )\n",
    "\n",
    "# Train/validation loss and trained weights of every seed\n",
    "full_train_loss_model_1 = []\n",
    "full_val_loss_model_1 = []\n",
    "\n",
    "u_1_V_model_1 = []\n",
    "w_1_V_model_1 = []\n",
    "v_1_V_model_1 = []\n",
    "\n",
    "# Only u and v are constrained in this variant\n",
    "constraints_u = Vector_left_U(normLim = vector_u_norm)\n",
    "constraints_v = Vector_right_V(normLim = vector_v_norm)\n",
    "\n",
    "start_time = time.time()\n",
    "time_record.append(start_time)\n",
    "\n",
    "for i in range(1, num_seeds + 1):\n",
    "    model = UDV_net_1(num_input = num_input, num_hidden_1 = num_hidden_1, num_output = num_output)\n",
    "    with torch.no_grad():\n",
    "        check_shapes(model_weight = model.fc1.weight, list_sample = init_ulist1[i-1])\n",
    "        check_shapes(model_weight = model.diag1.weight, list_sample = init_wlist1[i-1])\n",
    "        check_shapes(model_weight = model.fc2.weight, list_sample = init_vlist1[i-1])\n",
    "        model.fc1.weight = nn.Parameter(init_ulist1[i-1])\n",
    "        model.diag1.weight = nn.Parameter(init_wlist1[i-1])\n",
    "        model.fc2.weight = nn.Parameter(init_vlist1[i-1])\n",
    "\n",
    "    # Honour optimiser_name from the setup cell so the run matches result_path\n",
    "    if optimiser_name == 'SGDM':\n",
    "        optimizer = optim.SGD(model.parameters(), lr = learning_rate, momentum = SGD_M)\n",
    "    else:\n",
    "        optimizer = getattr(optim, optimiser_name)(model.parameters(), lr = learning_rate)  # 'Adam', 'NAdam' or 'SGD'\n",
    "\n",
    "    model, train_losses, val_losses, save_weights_list = udv_frame_v_uv(model = model, \n",
    "                                                                        optimizer = optimizer, \n",
    "                                                                        loss_fn = loss_fn, \n",
    "                                                                        train_loader = train_dataloader, \n",
    "                                                                        val_loader = val_dataloader, \n",
    "                                                                        num_epochs = num_epochs, \n",
    "                                                                        device = device,\n",
    "                                                                        constraints_u = constraints_u,\n",
    "                                                                        constraints_v = constraints_v,\n",
    "                                                                       )\n",
    "    full_train_loss_model_1 = store_metrics(full_train_loss_model_1, train_losses)\n",
    "    full_val_loss_model_1 = store_metrics(full_val_loss_model_1, val_losses)\n",
    "\n",
    "    check_shapes(model_weight = model.fc1.weight, list_sample = save_weights_list[0])\n",
    "    check_shapes(model_weight = model.diag1.weight, list_sample = save_weights_list[1])\n",
    "    check_shapes(model_weight = model.fc2.weight, list_sample = save_weights_list[2])\n",
    "    u_1_V_model_1.append(save_weights_list[0])\n",
    "    w_1_V_model_1.append(save_weights_list[1])\n",
    "    v_1_V_model_1.append(save_weights_list[2])\n",
    "\n",
    "avg_train_loss_model_1 = take_avg(full_train_loss_model_1)\n",
    "avg_val_loss_model_1 = take_avg(full_val_loss_model_1)\n",
    "\n",
    "end_time = time.time()\n",
    "time_record.append(end_time)\n",
    "duration = end_time - start_time\n",
    "time_record.append(duration)\n",
    "\n",
    "del init_ulist1, init_wlist1, init_vlist1, model, constraints_u, constraints_v, save_weights_list"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5472d996",
   "metadata": {
    "tags": []
   },
   "source": [
    "### <span style=\"color:blue\"> *---------------------------------------*</span>\n",
    "## <span style=\"color:blue\"> *The following model uses uwv constraints (Matrix): Model_2_UDV*</span>\n",
    "### <span style=\"color:blue\"> *---------------------------------------*</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e30e434",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(public_seed)\n",
    "\n",
    "# Same per-seed initial parameters as every other model cell\n",
    "init_ulist1, init_wlist1, init_ulist2 = seedList1(num_seeds = num_seeds, \n",
    "                                                  public_seed = public_seed,\n",
    "                                                  num_input = num_input,\n",
    "                                                  num_hidden_1 = num_hidden_1, \n",
    "                                                  num_output = num_output\n",
    "                                                 )\n",
    "\n",
    "# Train/validation loss and trained weights of every seed\n",
    "full_train_loss_model_2 = []\n",
    "full_val_loss_model_2 = []\n",
    "\n",
    "u_1_M_model_2 = []\n",
    "w_1_M_model_2 = []\n",
    "u_2_M_model_2 = []\n",
    "\n",
    "# Matrix constraint on both outer layers plus the diagonal constraint\n",
    "constraints_uv = Matrix_bothside(normLim = matrix_uvnorm)\n",
    "constraints_d = UDV_Diag(threshold = d_threshold, boundTo = d_boundto)\n",
    "\n",
    "start_time = time.time()\n",
    "time_record.append(start_time)\n",
    "\n",
    "for i in range(1, num_seeds + 1):\n",
    "    model = UDV_net_1(num_input = num_input, num_hidden_1 = num_hidden_1, num_output = num_output)\n",
    "    with torch.no_grad():\n",
    "        check_shapes(model_weight = model.fc1.weight, list_sample = init_ulist1[i-1])\n",
    "        check_shapes(model_weight = model.diag1.weight, list_sample = init_wlist1[i-1])\n",
    "        check_shapes(model_weight = model.fc2.weight, list_sample = init_ulist2[i-1])\n",
    "        model.fc1.weight = nn.Parameter(init_ulist1[i-1])\n",
    "        model.diag1.weight = nn.Parameter(init_wlist1[i-1])\n",
    "        model.fc2.weight = nn.Parameter(init_ulist2[i-1])\n",
    "\n",
    "    # Honour optimiser_name from the setup cell so the run matches result_path\n",
    "    if optimiser_name == 'SGDM':\n",
    "        optimizer = optim.SGD(model.parameters(), lr = learning_rate, momentum = SGD_M)\n",
    "    else:\n",
    "        optimizer = getattr(optim, optimiser_name)(model.parameters(), lr = learning_rate)  # 'Adam', 'NAdam' or 'SGD'\n",
    "\n",
    "    model, train_losses, val_losses, save_weights_list = udv_frame_m_uwv(model = model, \n",
    "                                                                         optimizer = optimizer, \n",
    "                                                                         loss_fn = loss_fn, \n",
    "                                                                         train_loader = train_dataloader, \n",
    "                                                                         val_loader = val_dataloader, \n",
    "                                                                         num_epochs = num_epochs, \n",
    "                                                                         device = device,\n",
    "                                                                         constraints_uv = constraints_uv,\n",
    "                                                                         constraints_d = constraints_d,\n",
    "                                                                        )\n",
    "    full_train_loss_model_2 = store_metrics(full_train_loss_model_2, train_losses)\n",
    "    full_val_loss_model_2 = store_metrics(full_val_loss_model_2, val_losses)\n",
    "\n",
    "    check_shapes(model_weight = model.fc1.weight, list_sample = save_weights_list[0])\n",
    "    check_shapes(model_weight = model.diag1.weight, list_sample = save_weights_list[1])\n",
    "    check_shapes(model_weight = model.fc2.weight, list_sample = save_weights_list[2])\n",
    "    u_1_M_model_2.append(save_weights_list[0])\n",
    "    w_1_M_model_2.append(save_weights_list[1])\n",
    "    u_2_M_model_2.append(save_weights_list[2])\n",
    "\n",
    "avg_train_loss_model_2 = take_avg(full_train_loss_model_2)\n",
    "avg_val_loss_model_2 = take_avg(full_val_loss_model_2)\n",
    "\n",
    "end_time = time.time()\n",
    "time_record.append(end_time)\n",
    "duration = end_time - start_time\n",
    "time_record.append(duration)\n",
    "\n",
    "del init_ulist1, init_wlist1, init_ulist2, model, constraints_uv, constraints_d, save_weights_list"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4fa2575",
   "metadata": {
    "tags": []
   },
   "source": [
    "### <span style=\"color:blue\"> *---------------------------------------*</span>\n",
    "## <span style=\"color:blue\"> *The following model uses uv constraints (Matrix): Model_3_UDV-s*</span>\n",
    "### <span style=\"color:blue\"> *---------------------------------------*</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6a75acf",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(public_seed)\n",
    "\n",
    "# Same per-seed initial parameters as every other model cell\n",
    "init_ulist1, init_wlist1, init_ulist2 = seedList1(num_seeds = num_seeds, \n",
    "                                                  public_seed = public_seed,\n",
    "                                                  num_input = num_input,\n",
    "                                                  num_hidden_1 = num_hidden_1, \n",
    "                                                  num_output = num_output\n",
    "                                                 )\n",
    "\n",
    "# Train/validation loss and trained weights of every seed\n",
    "full_train_loss_model_3 = []\n",
    "full_val_loss_model_3 = []\n",
    "\n",
    "u_1_M_model_3 = []\n",
    "w_1_M_model_3 = []\n",
    "u_2_M_model_3 = []\n",
    "\n",
    "# Only the matrix constraint on both outer layers (no diagonal constraint)\n",
    "constraints_uv = Matrix_bothside(normLim = matrix_uvnorm)\n",
    "\n",
    "start_time = time.time()\n",
    "time_record.append(start_time)\n",
    "\n",
    "for i in range(1, num_seeds + 1):\n",
    "    model = UDV_net_1(num_input = num_input, num_hidden_1 = num_hidden_1, num_output = num_output)\n",
    "    with torch.no_grad():\n",
    "        check_shapes(model_weight = model.fc1.weight, list_sample = init_ulist1[i-1])\n",
    "        check_shapes(model_weight = model.diag1.weight, list_sample = init_wlist1[i-1])\n",
    "        check_shapes(model_weight = model.fc2.weight, list_sample = init_ulist2[i-1])\n",
    "        model.fc1.weight = nn.Parameter(init_ulist1[i-1])\n",
    "        model.diag1.weight = nn.Parameter(init_wlist1[i-1])\n",
    "        model.fc2.weight = nn.Parameter(init_ulist2[i-1])\n",
    "\n",
    "    # Honour optimiser_name from the setup cell so the run matches result_path\n",
    "    if optimiser_name == 'SGDM':\n",
    "        optimizer = optim.SGD(model.parameters(), lr = learning_rate, momentum = SGD_M)\n",
    "    else:\n",
    "        optimizer = getattr(optim, optimiser_name)(model.parameters(), lr = learning_rate)  # 'Adam', 'NAdam' or 'SGD'\n",
    "\n",
    "    model, train_losses, val_losses, save_weights_list = udv_frame_m_uv(model = model, \n",
    "                                                                        optimizer = optimizer, \n",
    "                                                                        loss_fn = loss_fn, \n",
    "                                                                        train_loader = train_dataloader, \n",
    "                                                                        val_loader = val_dataloader, \n",
    "                                                                        num_epochs = num_epochs, \n",
    "                                                                        device = device,\n",
    "                                                                        constraints_uv = constraints_uv,\n",
    "                                                                       )\n",
    "    full_train_loss_model_3 = store_metrics(full_train_loss_model_3, train_losses)\n",
    "    full_val_loss_model_3 = store_metrics(full_val_loss_model_3, val_losses)\n",
    "\n",
    "    check_shapes(model_weight = model.fc1.weight, list_sample = save_weights_list[0])\n",
    "    check_shapes(model_weight = model.diag1.weight, list_sample = save_weights_list[1])\n",
    "    check_shapes(model_weight = model.fc2.weight, list_sample = save_weights_list[2])\n",
    "    u_1_M_model_3.append(save_weights_list[0])\n",
    "    w_1_M_model_3.append(save_weights_list[1])\n",
    "    u_2_M_model_3.append(save_weights_list[2])\n",
    "\n",
    "avg_train_loss_model_3 = take_avg(full_train_loss_model_3)\n",
    "avg_val_loss_model_3 = take_avg(full_val_loss_model_3)\n",
    "\n",
    "end_time = time.time()\n",
    "time_record.append(end_time)\n",
    "duration = end_time - start_time\n",
    "time_record.append(duration)\n",
    "\n",
    "del init_ulist1, init_wlist1, init_ulist2, model, constraints_uv, save_weights_list"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e630e64e",
   "metadata": {
    "tags": []
   },
   "source": [
    "### <span style=\"color:blue\"> *---------------------------------------*</span>\n",
    "## <span style=\"color:blue\"> *The following model uses ReLU activation: Model_4_UV_ReLU*</span>\n",
    "### <span style=\"color:blue\"> *---------------------------------------*</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9ace0e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(public_seed)\n",
    "\n",
    "# Same per-seed initial parameters as every other model cell\n",
    "# (the diagonal list is generated but unused: the ReLU model has no diag layer)\n",
    "init_llist1, init_wlist1, init_llist2 = seedList1(num_seeds = num_seeds, \n",
    "                                                  public_seed = public_seed,\n",
    "                                                  num_input = num_input,\n",
    "                                                  num_hidden_1 = num_hidden_1, \n",
    "                                                  num_output = num_output\n",
    "                                                 )\n",
    "\n",
    "# Train/validation loss and trained weights of every seed\n",
    "full_train_loss_model_4 = []\n",
    "full_val_loss_model_4 = []\n",
    "\n",
    "l_1_model_4 = []\n",
    "l_2_model_4 = []\n",
    "\n",
    "start_time = time.time()\n",
    "time_record.append(start_time)\n",
    "\n",
    "for i in range(1, num_seeds + 1):\n",
    "    # Load relu model\n",
    "    model = relu_net_1(num_input = num_input, num_hidden_1 = num_hidden_1, num_output = num_output)\n",
    "    with torch.no_grad():\n",
    "        check_shapes(model_weight = model.fc1.weight, list_sample = init_llist1[i-1])\n",
    "        check_shapes(model_weight = model.fc2.weight, list_sample = init_llist2[i-1])\n",
    "        model.fc1.weight = nn.Parameter(init_llist1[i-1])\n",
    "        model.fc2.weight = nn.Parameter(init_llist2[i-1])\n",
    "\n",
    "    # Honour optimiser_name from the setup cell so the run matches result_path\n",
    "    if optimiser_name == 'SGDM':\n",
    "        optimizer = optim.SGD(model.parameters(), lr = learning_rate, momentum = SGD_M)\n",
    "    else:\n",
    "        optimizer = getattr(optim, optimiser_name)(model.parameters(), lr = learning_rate)  # 'Adam', 'NAdam' or 'SGD'\n",
    "\n",
    "    model, train_losses, val_losses, save_weights_list = udv_frame_relu(model = model, \n",
    "                                                                        optimizer = optimizer, \n",
    "                                                                        loss_fn = loss_fn, \n",
    "                                                                        train_loader = train_dataloader, \n",
    "                                                                        val_loader = val_dataloader, \n",
    "                                                                        num_epochs = num_epochs, \n",
    "                                                                        device = device\n",
    "                                                                       )\n",
    "    full_train_loss_model_4 = store_metrics(full_train_loss_model_4, train_losses)\n",
    "    full_val_loss_model_4 = store_metrics(full_val_loss_model_4, val_losses)\n",
    "\n",
    "    check_shapes(model_weight = model.fc1.weight, list_sample = save_weights_list[0])\n",
    "    check_shapes(model_weight = model.fc2.weight, list_sample = save_weights_list[1])\n",
    "    l_1_model_4.append(save_weights_list[0])\n",
    "    l_2_model_4.append(save_weights_list[1])\n",
    "\n",
    "avg_train_loss_model_4 = take_avg(full_train_loss_model_4)\n",
    "avg_val_loss_model_4 = take_avg(full_val_loss_model_4)\n",
    "\n",
    "end_time = time.time()\n",
    "time_record.append(end_time)\n",
    "duration = end_time - start_time\n",
    "time_record.append(duration)\n",
    "\n",
    "del init_llist1, init_wlist1, init_llist2, model, save_weights_list"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7fb229a",
   "metadata": {
    "tags": []
   },
   "source": [
    "### <span style=\"color:blue\"> *---------------------------------------*</span>\n",
    "## <span style=\"color:blue\"> *The following model uses fully connected layers (NO activation): Model_5_UV*</span>\n",
    "### <span style=\"color:blue\"> *---------------------------------------*</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43512d4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model_5 (UV, no activation): train fc_net_1 once per seed, starting every run from the\n",
    "# initial weights that seedList1 regenerates from public_seed. If test_LinearAct is False,\n",
    "# the Model_4 (UV-ReLU) results are reused instead so every model_5 name still exists.\n",
    "torch.manual_seed(public_seed)\n",
    "\n",
    "init_llist1, init_wlist1, init_llist2 = seedList1(\n",
    "    num_seeds=num_seeds,\n",
    "    public_seed=public_seed,\n",
    "    num_input=num_input,\n",
    "    num_hidden_1=num_hidden_1,\n",
    "    num_output=num_output,\n",
    ")\n",
    "\n",
    "full_train_loss_model_5, full_val_loss_model_5 = [], []\n",
    "l_1_model_5, l_2_model_5 = [], []\n",
    "\n",
    "# The timer starts only after the initial weights have been generated.\n",
    "start_time = time.time()\n",
    "time_record.append(start_time)\n",
    "\n",
    "if not test_LinearAct:\n",
    "    # Model_5 skipped: alias the Model_4 (ReLU) results so the saving cell finds every key.\n",
    "    full_train_loss_model_5, full_val_loss_model_5 = full_train_loss_model_4, full_val_loss_model_4\n",
    "    avg_train_loss_model_5, avg_val_loss_model_5 = avg_train_loss_model_4, avg_val_loss_model_4\n",
    "    l_1_model_5, l_2_model_5 = l_1_model_4, l_2_model_4\n",
    "else:\n",
    "    for i in range(1, num_seeds + 1):\n",
    "        fc1_init, fc2_init = init_llist1[i - 1], init_llist2[i - 1]\n",
    "\n",
    "        # Fresh fully connected model without activation, overwritten with this seed's initial weights\n",
    "        model = fc_net_1(num_input=num_input, num_hidden_1=num_hidden_1, num_output=num_output)\n",
    "        with torch.no_grad():\n",
    "            check_shapes(model_weight=model.fc1.weight, list_sample=fc1_init)\n",
    "            check_shapes(model_weight=model.fc2.weight, list_sample=fc2_init)\n",
    "            model.fc1.weight = nn.Parameter(fc1_init)\n",
    "            model.fc2.weight = nn.Parameter(fc2_init)\n",
    "\n",
    "        # Adam is hard-coded (optimiser_name is not read here); switch to optim.NAdam,\n",
    "        # or optim.SGD(..., momentum=SGD_M) for SGDM, by hand if needed.\n",
    "        optimizer = optim.Adam(model.parameters(), lr=learning_rate)\n",
    "\n",
    "        model, train_losses, val_losses, save_weights_list = udv_frame_LAct(\n",
    "            model=model,\n",
    "            optimizer=optimizer,\n",
    "            loss_fn=loss_fn,\n",
    "            train_loader=train_dataloader,\n",
    "            val_loader=val_dataloader,\n",
    "            num_epochs=num_epochs,\n",
    "            device=device,\n",
    "        )\n",
    "        full_train_loss_model_5 = store_metrics(full_train_loss_model_5, train_losses)\n",
    "        full_val_loss_model_5 = store_metrics(full_val_loss_model_5, val_losses)\n",
    "\n",
    "        # Keep the layer weights udv_frame_LAct returned for this seed.\n",
    "        fc1_learned, fc2_learned = save_weights_list[0], save_weights_list[1]\n",
    "        check_shapes(model_weight=model.fc1.weight, list_sample=fc1_learned)\n",
    "        check_shapes(model_weight=model.fc2.weight, list_sample=fc2_learned)\n",
    "        l_1_model_5.append(fc1_learned)\n",
    "        l_2_model_5.append(fc2_learned)\n",
    "\n",
    "    avg_train_loss_model_5 = take_avg(full_train_loss_model_5)\n",
    "    avg_val_loss_model_5 = take_avg(full_val_loss_model_5)\n",
    "\n",
    "    del model, save_weights_list, fc1_init, fc2_init, fc1_learned, fc2_learned\n",
    "\n",
    "end_time = time.time()\n",
    "time_record.append(end_time)\n",
    "duration = end_time - start_time\n",
    "time_record.append(duration)\n",
    "\n",
    "del init_llist1, init_wlist1, init_llist2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "073ce221",
   "metadata": {
    "tags": []
   },
   "source": [
    "### <span style=\"color:blue\"> *---------------------------------------*</span>\n",
    "## <span style=\"color:blue\"> *Saving*</span>\n",
    "### <span style=\"color:blue\"> *---------------------------------------*</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "967b4922-6e99-4795-ad8b-907223215636",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the run configuration, per-seed and seed-averaged losses, the per-model\n",
    "# u/w/v and l lists, and the timing records to one pickle for offline analysis.\n",
    "# NOTE: pickle.load can execute arbitrary code, so only reload files this notebook wrote.\n",
    "\n",
    "# Create the output directory if needed so a long run is not lost at the very last step.\n",
    "os.makedirs(result_path, exist_ok=True)\n",
    "store_file_path = os.path.join(result_path, 'results.pkl')\n",
    "\n",
    "variables = {\n",
    "    'device': device,\n",
    "    'data_path': data_path,\n",
    "    'public_seed': public_seed,\n",
    "    \n",
    "    'num_input': num_input,\n",
    "    'num_hidden_1': num_hidden_1,\n",
    "    'num_hidden_2': num_hidden_2,\n",
    "    'num_output': num_output,\n",
    "    \n",
    "    'batch_size': batch_size,\n",
    "    'num_epochs': num_epochs,\n",
    "    'num_seeds': num_seeds,\n",
    "    'optimiser_name': optimiser_name,\n",
    "    'learning_rate': learning_rate,\n",
    "    \n",
    "    'full_train_loss_model_0': full_train_loss_model_0,\n",
    "    'full_train_loss_model_1': full_train_loss_model_1,\n",
    "    'full_train_loss_model_2': full_train_loss_model_2,\n",
    "    'full_train_loss_model_3': full_train_loss_model_3,\n",
    "    'full_train_loss_model_4': full_train_loss_model_4,\n",
    "    'full_train_loss_model_5': full_train_loss_model_5,\n",
    "    'avg_train_loss_model_0': avg_train_loss_model_0,\n",
    "    'avg_train_loss_model_1': avg_train_loss_model_1,\n",
    "    'avg_train_loss_model_2': avg_train_loss_model_2,\n",
    "    'avg_train_loss_model_3': avg_train_loss_model_3,\n",
    "    'avg_train_loss_model_4': avg_train_loss_model_4,\n",
    "    'avg_train_loss_model_5': avg_train_loss_model_5,\n",
    "    'full_val_loss_model_0': full_val_loss_model_0,\n",
    "    'full_val_loss_model_1': full_val_loss_model_1,\n",
    "    'full_val_loss_model_2': full_val_loss_model_2,\n",
    "    'full_val_loss_model_3': full_val_loss_model_3,\n",
    "    'full_val_loss_model_4': full_val_loss_model_4,\n",
    "    'full_val_loss_model_5': full_val_loss_model_5,\n",
    "    'avg_val_loss_model_0': avg_val_loss_model_0,\n",
    "    'avg_val_loss_model_1': avg_val_loss_model_1,\n",
    "    'avg_val_loss_model_2': avg_val_loss_model_2,\n",
    "    'avg_val_loss_model_3': avg_val_loss_model_3,\n",
    "    'avg_val_loss_model_4': avg_val_loss_model_4,\n",
    "    'avg_val_loss_model_5': avg_val_loss_model_5,\n",
    "    \n",
    "    'u_1_V_model_0': u_1_V_model_0,\n",
    "    'w_1_V_model_0': w_1_V_model_0,\n",
    "    'v_1_V_model_0': v_1_V_model_0,\n",
    "    \n",
    "    'u_1_V_model_1': u_1_V_model_1,\n",
    "    'w_1_V_model_1': w_1_V_model_1,\n",
    "    'v_1_V_model_1': v_1_V_model_1,\n",
    "    \n",
    "    'u_1_M_model_2': u_1_M_model_2,\n",
    "    'w_1_M_model_2': w_1_M_model_2,\n",
    "    'u_2_M_model_2': u_2_M_model_2,\n",
    "\n",
    "    'u_1_M_model_3': u_1_M_model_3,\n",
    "    'w_1_M_model_3': w_1_M_model_3,\n",
    "    'u_2_M_model_3': u_2_M_model_3,\n",
    "\n",
    "    'l_1_model_4': l_1_model_4,\n",
    "    'l_2_model_4': l_2_model_4,\n",
    "    \n",
    "    'l_1_model_5': l_1_model_5,\n",
    "    'l_2_model_5': l_2_model_5,\n",
    "    \n",
    "    'time_record': time_record\n",
    "}\n",
    "\n",
    "with open(store_file_path, 'wb') as file:\n",
    "    pickle.dump(variables, file)\n",
    "\n",
    "print(\"Saving pickle file, done\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b9e9923-4a6d-45be-a238-e7ca5e4cc9c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from udvFunctions.udvOtherFunctions import avg_stable\n",
    "\n",
    "# Number of epochs, counted from the end, that avg_stable averages into each reported loss.\n",
    "last_n = 5\n",
    "final_train_loss_m0 = avg_stable(avg_train_loss_model_0, last_n)\n",
    "final_train_loss_m1 = avg_stable(avg_train_loss_model_1, last_n)\n",
    "final_train_loss_m2 = avg_stable(avg_train_loss_model_2, last_n)\n",
    "final_train_loss_m3 = avg_stable(avg_train_loss_model_3, last_n)\n",
    "final_train_loss_m4 = avg_stable(avg_train_loss_model_4, last_n)\n",
    "final_train_loss_m5 = avg_stable(avg_train_loss_model_5, last_n)\n",
    "\n",
    "final_val_loss_m0 = avg_stable(avg_val_loss_model_0, last_n)\n",
    "final_val_loss_m1 = avg_stable(avg_val_loss_model_1, last_n)\n",
    "final_val_loss_m2 = avg_stable(avg_val_loss_model_2, last_n)\n",
    "final_val_loss_m3 = avg_stable(avg_val_loss_model_3, last_n)\n",
    "final_val_loss_m4 = avg_stable(avg_val_loss_model_4, last_n)\n",
    "final_val_loss_m5 = avg_stable(avg_val_loss_model_5, last_n)\n",
    "\n",
    "print('Batch size: {0}\\n#Seeds: {1}\\n#Epochs: {2}\\nOptimizer: {3}\\nLearning Rate: {4}\\n#Hidden_1: {5}\\n#Hidden_2: {6}'.format(batch_size, num_seeds, num_epochs, optimiser_name, learning_rate, num_hidden_1, num_hidden_2))\n",
    "print('Final results are averaged from the last {0} epochs\\n'.format(last_n))\n",
    "\n",
    "# (label, final train loss, final val loss) for model_0 ... model_5, in index order.\n",
    "# If test_LinearAct was False, the model_5 row repeats the model_4 (UV-ReLU) numbers.\n",
    "summary_rows = [\n",
    "    ('UDV-v1', final_train_loss_m0, final_val_loss_m0),\n",
    "    ('UDV-v2', final_train_loss_m1, final_val_loss_m1),\n",
    "    ('UDV', final_train_loss_m2, final_val_loss_m2),\n",
    "    ('UDV-s', final_train_loss_m3, final_val_loss_m3),\n",
    "    ('UV-ReLU', final_train_loss_m4, final_val_loss_m4),\n",
    "    ('UV', final_train_loss_m5, final_val_loss_m5),\n",
    "]\n",
    "# Labels are left-aligned in a 12-character column so the '(avg:-N)' parts line up.\n",
    "for model_idx, (label, train_loss, val_loss) in enumerate(summary_rows):\n",
    "    print('model_{0}: {1:<12}(avg:-{2}): T-loss {3}; V-loss is {4}'.format(model_idx, label, last_n, train_loss, val_loss))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
