{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e7325f48-2cd4-4cc2-a149-7eef1ed29f01",
   "metadata": {
    "tags": []
   },
   "source": [
    "### <span style=\"color:blue\"> *---------------------------------------*</span>\n",
    "# <span style=\"color:blue\"> *UDV project*</span>\n",
    "### <span style=\"color:blue\"> *---------------------------------------*</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d62d1060",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports libraries\n",
    "\n",
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import pandas as pd\n",
    "\n",
    "from sklearn.preprocessing import MinMaxScaler, LabelEncoder\n",
    "from sklearn.impute import SimpleImputer\n",
    "\n",
    "import pickle\n",
    "import time\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3fbd59a-b3fe-4c1e-99af-f1e41904c67c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the global seed for reproducibility\n",
    "\n",
    "public_seed = 529\n",
    "torch.manual_seed(public_seed)\n",
    "print(\"Current Seed is {0}\".format(public_seed))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64521320-8885-459a-b69d-af11048f4b7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Obtain device\n",
    "\n",
    "from udvFunctions.udvDevice import get_device\n",
    "\n",
    "# All experiments are pinned to the CPU by default.\n",
    "# get_device() is still called, but its result is only used when force_cpu is False.\n",
    "force_cpu = True\n",
    "\n",
    "detected_device = get_device()\n",
    "device = \"cpu\" if force_cpu else detected_device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f3e66bc-2cf3-4763-8d96-eecf36485b88",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load dataset\n",
    "# Later cells rely on len(dataset) and on dataset.features (indexed as a 2-D array).\n",
    "\n",
    "from udvFunctions.udvDatasetPreprocessing import HousingPriceDataset\n",
    "# Path is relative to the working directory the notebook is run from\n",
    "data_path = \"./HP_Orig.csv\"\n",
    "dataset = HousingPriceDataset(data_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07cee382",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare UDV constraints\n",
    "\n",
    "from udvFunctions.udvConstraints import Matrix_bothside, UDV_Diag\n",
    "from udvFunctions.udvConstraints import Vector_left_U, Vector_right_V"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86ab3012",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare the single-connection layer that carries the weight matrix 'w'\n",
    "\n",
    "from udvFunctions.udvDiagonalLayer import D_singleConnection   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14bfcc20-2be0-4187-8e02-b2090461410e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare networks\n",
    "\n",
    "from udvFunctions.udvRegNetworks import UDV_net_1, UDV_relu_1\n",
    "from udvFunctions.udvRegNetworks import relu_net_1, fc_net_1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6012e40-8ae9-4449-a1ce-6f662ef830be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare loss function (optional)\n",
    "\n",
    "from udvFunctions.udvLoss import UDV_Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d97a9e38-9328-420a-a540-2212c8ad67f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare training frameworks\n",
    "\n",
    "# Model_0 and Model_2: Matrix_uwv\n",
    "from udvFunctions.udvRegMatrixUWV import udv_frame_m_uwv\n",
    "\n",
    "# Model_1 and Model_3: Matrix_uv\n",
    "from udvFunctions.udvRegMatrixUV import udv_frame_m_uv\n",
    "\n",
    "# Model_4: udv_frame_relu; Model_5: uvrelu_frame_m_uv\n",
    "from udvFunctions.udvRegReLU import udv_frame_relu, uvrelu_frame_m_uv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81c08633-a1dc-46d6-b137-1dd3f3bc41ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Identical initialisation\n",
    "\n",
    "from udvFunctions.udvSameInit import seedList1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd86ae55-4c1e-4fe5-8fc6-c3d9ba51536e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Other customised functions\n",
    "\n",
    "from udvFunctions.udvOtherFunctions import store_metrics, take_avg, check_shapes"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c0cd1610",
   "metadata": {
    "tags": []
   },
   "source": [
    "### <span style=\"color:blue\"> *---------------------------------------*</span>\n",
    "# <span style=\"color:blue\"> *Setting Experiments:*</span>\n",
    "### <span style=\"color:blue\"> *---------------------------------------*</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c48d712b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Re-seed immediately before the random train/validation split\n",
    "torch.manual_seed(public_seed)\n",
    "\n",
    "# Dataset split\n",
    "training_ratio = 0.8\n",
    "train_size = int(training_ratio * len(dataset))\n",
    "validation_size = len(dataset) - train_size\n",
    "train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, validation_size])\n",
    "\n",
    "# Number of epochs and number of differently initialised runs (seeds) per model\n",
    "num_epochs = 100       # Number of epochs; Min:5; 200\n",
    "num_seeds = 50         # Number of seeds;  Min:1; 1000\n",
    "\n",
    "# Layer sizes\n",
    "num_input = len(dataset.features[0,:])  # Length of the first row of dataset.features\n",
    "num_output = 1                          # Number of output features\n",
    "# Number of hidden neurons in fully connected layer 1 (constrained layer)\n",
    "num_hidden_1 = round(((num_output + 2) * num_input) ** 0.5 + 2 * (num_input / (num_output + 2)) ** 0.5)\n",
    "# Discarded (expandable design); the value is still embedded in result_path below\n",
    "num_hidden_2 = round(num_output * (num_input / (num_output + 2)) ** 0.5)\n",
    "\n",
    "# Load data\n",
    "batch_size = 128\n",
    "train_dataloader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)\n",
    "val_dataloader = DataLoader(val_dataset, batch_size = batch_size, shuffle = False)\n",
    "\n",
    "# Load data (Full batch)\n",
    "# batch_size = \"full\"\n",
    "# train_dataloader = DataLoader(train_dataset, batch_size = len(train_dataset), shuffle = True)\n",
    "# val_dataloader = DataLoader(val_dataset, batch_size = len(val_dataset), shuffle = False)\n",
    "\n",
    "# Set optimiser (later) and loss function\n",
    "optimiser_name = 'Adam'             # 'Adam', 'NAdam', 'SGD', 'SGDM'\n",
    "SGD_M = 0.9                         # Only available when 'SGDM' is specified in optimiser_name\n",
    "learning_rate = 1e-3\n",
    "\n",
    "# loss_fn = UDV_Loss()    # MSE/2\n",
    "loss_fn = nn.MSELoss()  # MSE\n",
    "\n",
    "# Constraints setting\n",
    "vector_u_norm = 1\n",
    "vector_v_norm = 1\n",
    "matrix_uvnorm = 1\n",
    "d_threshold = 0\n",
    "d_boundto = 0\n",
    "\n",
    "# A fully connected network without any activation produces many NaNs at high learning rates,\n",
    "# so the flag is switched off once learning_rate reaches 1.\n",
    "# NOTE(review): test_LinearAct is presumably read by later cells - confirm.\n",
    "test_LinearAct = learning_rate < 1\n",
    "\n",
    "# Time record: each model cell appends its start time, end time and duration\n",
    "time_record = []\n",
    "\n",
    "# Set result path (the folder name encodes the settings above)\n",
    "result_path = './{0}_{1}_H1_{2}_H2_{3}_BS{4}_E{5}_S{6}/SingleLayer'.format(optimiser_name, learning_rate, num_hidden_1, num_hidden_2, batch_size, num_epochs, num_seeds)\n",
    "# exist_ok makes the cell safe to re-run and avoids a check-then-create race\n",
    "os.makedirs(result_path, exist_ok = True)\n",
    "\n",
    "# Reproducible setting\n",
    "torch.backends.cudnn.deterministic = True\n",
    "torch.backends.cudnn.benchmark = False"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5472d996",
   "metadata": {
    "tags": []
   },
   "source": [
    "### <span style=\"color:blue\"> *---------------------------------------*</span>\n",
    "## <span style=\"color:blue\"> *The following model with uwv constraints (Matrix): Model_0_UDV*</span>\n",
    "### <span style=\"color:blue\"> *---------------------------------------*</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e30e434",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model_0 (UDV): UDV_net_1 trained by udv_frame_m_uwv with constraints_uv and constraints_d\n",
    "torch.manual_seed(public_seed)\n",
    "\n",
    "# The optimiser is built from optimiser_name so that training always matches the name\n",
    "# embedded in result_path ('SGDM' means SGD with momentum SGD_M).\n",
    "if optimiser_name not in ('Adam', 'NAdam', 'SGD', 'SGDM'):\n",
    "    raise ValueError('Unsupported optimiser_name: {0}'.format(optimiser_name))\n",
    "\n",
    "# Initial weights, one entry per run; seedList1 receives the same arguments in every model cell\n",
    "init_ulist1, init_wlist1, init_ulist2 = seedList1(num_seeds = num_seeds, \n",
    "                                                  public_seed = public_seed,\n",
    "                                                  num_input = num_input,\n",
    "                                                  num_hidden_1 = num_hidden_1, \n",
    "                                                  num_output = num_output\n",
    "                                                 )\n",
    "\n",
    "# Per-run losses, accumulated through store_metrics\n",
    "full_train_loss_model_0 = []\n",
    "full_val_loss_model_0 = []\n",
    "\n",
    "# Per-run weights returned by the training framework (fc1, diag1, fc2)\n",
    "u_1_M_model_0 = []\n",
    "w_1_M_model_0 = []\n",
    "u_2_M_model_0 = []\n",
    "\n",
    "constraints_uv = Matrix_bothside(normLim = matrix_uvnorm)\n",
    "constraints_d = UDV_Diag(threshold = d_threshold, boundTo = d_boundto)\n",
    "\n",
    "start_time = time.time()\n",
    "time_record.append(start_time)\n",
    "\n",
    "for i in range(1, num_seeds + 1):\n",
    "    model = UDV_net_1(num_input = num_input, num_hidden_1 = num_hidden_1, num_output = num_output)\n",
    "    with torch.no_grad():\n",
    "        # Replace the default initialisation with the prepared weights of run i\n",
    "        check_shapes(model_weight = model.fc1.weight, list_sample = init_ulist1[i-1])\n",
    "        check_shapes(model_weight = model.diag1.weight, list_sample = init_wlist1[i-1])\n",
    "        check_shapes(model_weight = model.fc2.weight, list_sample = init_ulist2[i-1])\n",
    "        model.fc1.weight = nn.Parameter(init_ulist1[i-1])\n",
    "        model.diag1.weight = nn.Parameter(init_wlist1[i-1])\n",
    "        model.fc2.weight = nn.Parameter(init_ulist2[i-1])\n",
    "\n",
    "    # A fresh optimiser per run, created after the parameters were replaced\n",
    "    if optimiser_name == 'SGDM':\n",
    "        optimizer = optim.SGD(model.parameters(), lr = learning_rate, momentum = SGD_M)\n",
    "    else:\n",
    "        optimizer = getattr(optim, optimiser_name)(model.parameters(), lr = learning_rate)\n",
    "\n",
    "    model, train_losses, val_losses, save_weights_list = udv_frame_m_uwv(\n",
    "        model = model,\n",
    "        optimizer = optimizer,\n",
    "        loss_fn = loss_fn,\n",
    "        train_loader = train_dataloader,\n",
    "        val_loader = val_dataloader,\n",
    "        num_epochs = num_epochs,\n",
    "        device = device,\n",
    "        constraints_uv = constraints_uv,\n",
    "        constraints_d = constraints_d,\n",
    "    )\n",
    "    full_train_loss_model_0 = store_metrics(full_train_loss_model_0, train_losses)\n",
    "    full_val_loss_model_0 = store_metrics(full_val_loss_model_0, val_losses)\n",
    "\n",
    "    # Check the returned weights against the model before keeping them\n",
    "    check_shapes(model_weight = model.fc1.weight, list_sample = save_weights_list[0])\n",
    "    check_shapes(model_weight = model.diag1.weight, list_sample = save_weights_list[1])\n",
    "    check_shapes(model_weight = model.fc2.weight, list_sample = save_weights_list[2])\n",
    "    u_1_M_model_0.append(save_weights_list[0])\n",
    "    w_1_M_model_0.append(save_weights_list[1])\n",
    "    u_2_M_model_0.append(save_weights_list[2])\n",
    "\n",
    "avg_train_loss_model_0 = take_avg(full_train_loss_model_0)\n",
    "avg_val_loss_model_0 = take_avg(full_val_loss_model_0)\n",
    "\n",
    "# time_record receives start time, end time and duration (seconds) for this model\n",
    "end_time = time.time()\n",
    "time_record.append(end_time)\n",
    "duration = end_time - start_time\n",
    "time_record.append(duration)\n",
    "\n",
    "# Drop the per-model temporaries so they cannot leak into the next model cell\n",
    "del init_ulist1, init_wlist1, init_ulist2, model, constraints_uv, constraints_d, save_weights_list"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4fa2575",
   "metadata": {
    "tags": []
   },
   "source": [
    "### <span style=\"color:blue\"> *---------------------------------------*</span>\n",
    "## <span style=\"color:blue\"> *The following model with uv constraints (Matrix): Model_1_UDV-s*</span>\n",
    "### <span style=\"color:blue\"> *---------------------------------------*</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6a75acf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model_1 (UDV-s): UDV_net_1 trained by udv_frame_m_uv with constraints_uv only\n",
    "torch.manual_seed(public_seed)\n",
    "\n",
    "# The optimiser is built from optimiser_name so that training always matches the name\n",
    "# embedded in result_path ('SGDM' means SGD with momentum SGD_M).\n",
    "if optimiser_name not in ('Adam', 'NAdam', 'SGD', 'SGDM'):\n",
    "    raise ValueError('Unsupported optimiser_name: {0}'.format(optimiser_name))\n",
    "\n",
    "# Initial weights, one entry per run; seedList1 receives the same arguments in every model cell\n",
    "init_ulist1, init_wlist1, init_ulist2 = seedList1(num_seeds = num_seeds, \n",
    "                                                  public_seed = public_seed,\n",
    "                                                  num_input = num_input,\n",
    "                                                  num_hidden_1 = num_hidden_1, \n",
    "                                                  num_output = num_output\n",
    "                                                 )\n",
    "\n",
    "# Per-run losses, accumulated through store_metrics\n",
    "full_train_loss_model_1 = []\n",
    "full_val_loss_model_1 = []\n",
    "\n",
    "# Per-run weights returned by the training framework (fc1, diag1, fc2)\n",
    "u_1_M_model_1 = []\n",
    "w_1_M_model_1 = []\n",
    "u_2_M_model_1 = []\n",
    "\n",
    "constraints_uv = Matrix_bothside(normLim = matrix_uvnorm)\n",
    "\n",
    "start_time = time.time()\n",
    "time_record.append(start_time)\n",
    "\n",
    "for i in range(1, num_seeds + 1):\n",
    "    model = UDV_net_1(num_input = num_input, num_hidden_1 = num_hidden_1, num_output = num_output)\n",
    "    with torch.no_grad():\n",
    "        # Replace the default initialisation with the prepared weights of run i\n",
    "        check_shapes(model_weight = model.fc1.weight, list_sample = init_ulist1[i-1])\n",
    "        check_shapes(model_weight = model.diag1.weight, list_sample = init_wlist1[i-1])\n",
    "        check_shapes(model_weight = model.fc2.weight, list_sample = init_ulist2[i-1])\n",
    "        model.fc1.weight = nn.Parameter(init_ulist1[i-1])\n",
    "        model.diag1.weight = nn.Parameter(init_wlist1[i-1])\n",
    "        model.fc2.weight = nn.Parameter(init_ulist2[i-1])\n",
    "\n",
    "    # A fresh optimiser per run, created after the parameters were replaced\n",
    "    if optimiser_name == 'SGDM':\n",
    "        optimizer = optim.SGD(model.parameters(), lr = learning_rate, momentum = SGD_M)\n",
    "    else:\n",
    "        optimizer = getattr(optim, optimiser_name)(model.parameters(), lr = learning_rate)\n",
    "\n",
    "    model, train_losses, val_losses, save_weights_list = udv_frame_m_uv(\n",
    "        model = model,\n",
    "        optimizer = optimizer,\n",
    "        loss_fn = loss_fn,\n",
    "        train_loader = train_dataloader,\n",
    "        val_loader = val_dataloader,\n",
    "        num_epochs = num_epochs,\n",
    "        device = device,\n",
    "        constraints_uv = constraints_uv,\n",
    "    )\n",
    "    full_train_loss_model_1 = store_metrics(full_train_loss_model_1, train_losses)\n",
    "    full_val_loss_model_1 = store_metrics(full_val_loss_model_1, val_losses)\n",
    "\n",
    "    # Check the returned weights against the model before keeping them\n",
    "    check_shapes(model_weight = model.fc1.weight, list_sample = save_weights_list[0])\n",
    "    check_shapes(model_weight = model.diag1.weight, list_sample = save_weights_list[1])\n",
    "    check_shapes(model_weight = model.fc2.weight, list_sample = save_weights_list[2])\n",
    "    u_1_M_model_1.append(save_weights_list[0])\n",
    "    w_1_M_model_1.append(save_weights_list[1])\n",
    "    u_2_M_model_1.append(save_weights_list[2])\n",
    "\n",
    "avg_train_loss_model_1 = take_avg(full_train_loss_model_1)\n",
    "avg_val_loss_model_1 = take_avg(full_val_loss_model_1)\n",
    "\n",
    "# time_record receives start time, end time and duration (seconds) for this model\n",
    "end_time = time.time()\n",
    "time_record.append(end_time)\n",
    "duration = end_time - start_time\n",
    "time_record.append(duration)\n",
    "\n",
    "# Drop the per-model temporaries so they cannot leak into the next model cell\n",
    "del init_ulist1, init_wlist1, init_ulist2, model, constraints_uv, save_weights_list"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14bc0f7e-372a-4e92-862b-dc5cb3b797bc",
   "metadata": {},
   "source": [
    "### <span style=\"color:blue\"> *---------------------------------------*</span>\n",
    "## <span style=\"color:blue\"> *The following model with ReLU_Diag_uwv constraints (Matrix): Model_2_UDV_ReLU*</span>\n",
    "### <span style=\"color:blue\"> *---------------------------------------*</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86e33f9a-f1cb-4d20-998e-ffaad165a607",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model_2 (UDV_ReLU): UDV_relu_1 trained by udv_frame_m_uwv with constraints_uv and constraints_d\n",
    "torch.manual_seed(public_seed)\n",
    "\n",
    "# The optimiser is built from optimiser_name so that training always matches the name\n",
    "# embedded in result_path ('SGDM' means SGD with momentum SGD_M).\n",
    "if optimiser_name not in ('Adam', 'NAdam', 'SGD', 'SGDM'):\n",
    "    raise ValueError('Unsupported optimiser_name: {0}'.format(optimiser_name))\n",
    "\n",
    "# Initial weights, one entry per run; seedList1 receives the same arguments in every model cell\n",
    "init_ulist1, init_wlist1, init_ulist2 = seedList1(num_seeds = num_seeds, \n",
    "                                                  public_seed = public_seed,\n",
    "                                                  num_input = num_input,\n",
    "                                                  num_hidden_1 = num_hidden_1, \n",
    "                                                  num_output = num_output\n",
    "                                                 )\n",
    "\n",
    "# Per-run losses, accumulated through store_metrics\n",
    "full_train_loss_model_2 = []\n",
    "full_val_loss_model_2 = []\n",
    "\n",
    "# Per-run weights returned by the training framework (fc1, diag1, fc2)\n",
    "u_1_M_model_2 = []\n",
    "w_1_M_model_2 = []\n",
    "u_2_M_model_2 = []\n",
    "\n",
    "constraints_uv = Matrix_bothside(normLim = matrix_uvnorm)\n",
    "constraints_d = UDV_Diag(threshold = d_threshold, boundTo = d_boundto)\n",
    "\n",
    "start_time = time.time()\n",
    "time_record.append(start_time)\n",
    "\n",
    "for i in range(1, num_seeds + 1):\n",
    "    model = UDV_relu_1(num_input = num_input, num_hidden_1 = num_hidden_1, num_output = num_output)\n",
    "    with torch.no_grad():\n",
    "        # Replace the default initialisation with the prepared weights of run i\n",
    "        check_shapes(model_weight = model.fc1.weight, list_sample = init_ulist1[i-1])\n",
    "        check_shapes(model_weight = model.diag1.weight, list_sample = init_wlist1[i-1])\n",
    "        check_shapes(model_weight = model.fc2.weight, list_sample = init_ulist2[i-1])\n",
    "        model.fc1.weight = nn.Parameter(init_ulist1[i-1])\n",
    "        model.diag1.weight = nn.Parameter(init_wlist1[i-1])\n",
    "        model.fc2.weight = nn.Parameter(init_ulist2[i-1])\n",
    "\n",
    "    # A fresh optimiser per run, created after the parameters were replaced\n",
    "    if optimiser_name == 'SGDM':\n",
    "        optimizer = optim.SGD(model.parameters(), lr = learning_rate, momentum = SGD_M)\n",
    "    else:\n",
    "        optimizer = getattr(optim, optimiser_name)(model.parameters(), lr = learning_rate)\n",
    "\n",
    "    model, train_losses, val_losses, save_weights_list = udv_frame_m_uwv(\n",
    "        model = model,\n",
    "        optimizer = optimizer,\n",
    "        loss_fn = loss_fn,\n",
    "        train_loader = train_dataloader,\n",
    "        val_loader = val_dataloader,\n",
    "        num_epochs = num_epochs,\n",
    "        device = device,\n",
    "        constraints_uv = constraints_uv,\n",
    "        constraints_d = constraints_d,\n",
    "    )\n",
    "    full_train_loss_model_2 = store_metrics(full_train_loss_model_2, train_losses)\n",
    "    full_val_loss_model_2 = store_metrics(full_val_loss_model_2, val_losses)\n",
    "\n",
    "    # Check the returned weights against the model before keeping them\n",
    "    check_shapes(model_weight = model.fc1.weight, list_sample = save_weights_list[0])\n",
    "    check_shapes(model_weight = model.diag1.weight, list_sample = save_weights_list[1])\n",
    "    check_shapes(model_weight = model.fc2.weight, list_sample = save_weights_list[2])\n",
    "    u_1_M_model_2.append(save_weights_list[0])\n",
    "    w_1_M_model_2.append(save_weights_list[1])\n",
    "    u_2_M_model_2.append(save_weights_list[2])\n",
    "\n",
    "avg_train_loss_model_2 = take_avg(full_train_loss_model_2)\n",
    "avg_val_loss_model_2 = take_avg(full_val_loss_model_2)\n",
    "\n",
    "# time_record receives start time, end time and duration (seconds) for this model\n",
    "end_time = time.time()\n",
    "time_record.append(end_time)\n",
    "duration = end_time - start_time\n",
    "time_record.append(duration)\n",
    "\n",
    "# Drop the per-model temporaries so they cannot leak into the next model cell\n",
    "del init_ulist1, init_wlist1, init_ulist2, model, constraints_uv, constraints_d, save_weights_list"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "edc23c64-1faa-4ecd-9fe7-8cfe81e0954e",
   "metadata": {},
   "source": [
    "### <span style=\"color:blue\"> *---------------------------------------*</span>\n",
    "## <span style=\"color:blue\"> *The following model with ReLU_Diag_uv constraints (Matrix): Model_3_UDV_ReLU-s*</span>\n",
    "### <span style=\"color:blue\"> *---------------------------------------*</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fc71d8b-3b08-4861-9319-1f582c6744d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model_3 (UDV_ReLU-s): UDV_relu_1 trained by udv_frame_m_uv with constraints_uv only\n",
    "torch.manual_seed(public_seed)\n",
    "\n",
    "# The optimiser is built from optimiser_name so that training always matches the name\n",
    "# embedded in result_path ('SGDM' means SGD with momentum SGD_M).\n",
    "if optimiser_name not in ('Adam', 'NAdam', 'SGD', 'SGDM'):\n",
    "    raise ValueError('Unsupported optimiser_name: {0}'.format(optimiser_name))\n",
    "\n",
    "# Initial weights, one entry per run; seedList1 receives the same arguments in every model cell\n",
    "init_ulist1, init_wlist1, init_ulist2 = seedList1(num_seeds = num_seeds, \n",
    "                                                  public_seed = public_seed,\n",
    "                                                  num_input = num_input,\n",
    "                                                  num_hidden_1 = num_hidden_1, \n",
    "                                                  num_output = num_output\n",
    "                                                 )\n",
    "\n",
    "# Per-run losses, accumulated through store_metrics\n",
    "full_train_loss_model_3 = []\n",
    "full_val_loss_model_3 = []\n",
    "\n",
    "# Per-run weights returned by the training framework (fc1, diag1, fc2)\n",
    "u_1_M_model_3 = []\n",
    "w_1_M_model_3 = []\n",
    "u_2_M_model_3 = []\n",
    "\n",
    "constraints_uv = Matrix_bothside(normLim = matrix_uvnorm)\n",
    "\n",
    "start_time = time.time()\n",
    "time_record.append(start_time)\n",
    "\n",
    "for i in range(1, num_seeds + 1):\n",
    "    model = UDV_relu_1(num_input = num_input, num_hidden_1 = num_hidden_1, num_output = num_output)\n",
    "    with torch.no_grad():\n",
    "        # Replace the default initialisation with the prepared weights of run i\n",
    "        check_shapes(model_weight = model.fc1.weight, list_sample = init_ulist1[i-1])\n",
    "        check_shapes(model_weight = model.diag1.weight, list_sample = init_wlist1[i-1])\n",
    "        check_shapes(model_weight = model.fc2.weight, list_sample = init_ulist2[i-1])\n",
    "        model.fc1.weight = nn.Parameter(init_ulist1[i-1])\n",
    "        model.diag1.weight = nn.Parameter(init_wlist1[i-1])\n",
    "        model.fc2.weight = nn.Parameter(init_ulist2[i-1])\n",
    "\n",
    "    # A fresh optimiser per run, created after the parameters were replaced\n",
    "    if optimiser_name == 'SGDM':\n",
    "        optimizer = optim.SGD(model.parameters(), lr = learning_rate, momentum = SGD_M)\n",
    "    else:\n",
    "        optimizer = getattr(optim, optimiser_name)(model.parameters(), lr = learning_rate)\n",
    "\n",
    "    model, train_losses, val_losses, save_weights_list = udv_frame_m_uv(\n",
    "        model = model,\n",
    "        optimizer = optimizer,\n",
    "        loss_fn = loss_fn,\n",
    "        train_loader = train_dataloader,\n",
    "        val_loader = val_dataloader,\n",
    "        num_epochs = num_epochs,\n",
    "        device = device,\n",
    "        constraints_uv = constraints_uv,\n",
    "    )\n",
    "    full_train_loss_model_3 = store_metrics(full_train_loss_model_3, train_losses)\n",
    "    full_val_loss_model_3 = store_metrics(full_val_loss_model_3, val_losses)\n",
    "\n",
    "    # Check the returned weights against the model before keeping them\n",
    "    check_shapes(model_weight = model.fc1.weight, list_sample = save_weights_list[0])\n",
    "    check_shapes(model_weight = model.diag1.weight, list_sample = save_weights_list[1])\n",
    "    check_shapes(model_weight = model.fc2.weight, list_sample = save_weights_list[2])\n",
    "    u_1_M_model_3.append(save_weights_list[0])\n",
    "    w_1_M_model_3.append(save_weights_list[1])\n",
    "    u_2_M_model_3.append(save_weights_list[2])\n",
    "\n",
    "avg_train_loss_model_3 = take_avg(full_train_loss_model_3)\n",
    "avg_val_loss_model_3 = take_avg(full_val_loss_model_3)\n",
    "\n",
    "# time_record receives start time, end time and duration (seconds) for this model\n",
    "end_time = time.time()\n",
    "time_record.append(end_time)\n",
    "duration = end_time - start_time\n",
    "time_record.append(duration)\n",
    "\n",
    "# Drop the per-model temporaries so they cannot leak into the next model cell\n",
    "del init_ulist1, init_wlist1, init_ulist2, model, constraints_uv, save_weights_list"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e630e64e",
   "metadata": {
    "tags": []
   },
   "source": [
    "### <span style=\"color:blue\"> *---------------------------------------*</span>\n",
    "## <span style=\"color:blue\"> *The following model with ReLU: Model_4_UV_ReLU*</span>\n",
    "### <span style=\"color:blue\"> *---------------------------------------*</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9ace0e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model_4 (UV_ReLU): relu_net_1 trained by udv_frame_relu, no constraint objects are passed\n",
    "torch.manual_seed(public_seed)\n",
    "\n",
    "# The optimiser is built from optimiser_name so that training always matches the name\n",
    "# embedded in result_path ('SGDM' means SGD with momentum SGD_M).\n",
    "if optimiser_name not in ('Adam', 'NAdam', 'SGD', 'SGDM'):\n",
    "    raise ValueError('Unsupported optimiser_name: {0}'.format(optimiser_name))\n",
    "\n",
    "# Initial weights, one entry per run; seedList1 receives the same arguments in every model cell.\n",
    "# init_wlist1 is not used by this model (only fc1 and fc2 are initialised below).\n",
    "init_llist1, init_wlist1, init_llist2 = seedList1(num_seeds = num_seeds, \n",
    "                                                  public_seed = public_seed,\n",
    "                                                  num_input = num_input,\n",
    "                                                  num_hidden_1 = num_hidden_1, \n",
    "                                                  num_output = num_output\n",
    "                                                 )\n",
    "\n",
    "# Per-run losses, accumulated through store_metrics\n",
    "full_train_loss_model_4 = []\n",
    "full_val_loss_model_4 = []\n",
    "\n",
    "# Per-run weights returned by the training framework (fc1, fc2)\n",
    "l_1_model_4 = []\n",
    "l_2_model_4 = []\n",
    "\n",
    "start_time = time.time()\n",
    "time_record.append(start_time)\n",
    "\n",
    "for i in range(1, num_seeds + 1):\n",
    "    # Load relu model\n",
    "    model = relu_net_1(num_input = num_input, num_hidden_1 = num_hidden_1, num_output = num_output)\n",
    "    with torch.no_grad():\n",
    "        # Replace the default initialisation with the prepared weights of run i\n",
    "        check_shapes(model_weight = model.fc1.weight, list_sample = init_llist1[i-1])\n",
    "        check_shapes(model_weight = model.fc2.weight, list_sample = init_llist2[i-1])\n",
    "        model.fc1.weight = nn.Parameter(init_llist1[i-1])\n",
    "        model.fc2.weight = nn.Parameter(init_llist2[i-1])\n",
    "\n",
    "    # A fresh optimiser per run, created after the parameters were replaced\n",
    "    if optimiser_name == 'SGDM':\n",
    "        optimizer = optim.SGD(model.parameters(), lr = learning_rate, momentum = SGD_M)\n",
    "    else:\n",
    "        optimizer = getattr(optim, optimiser_name)(model.parameters(), lr = learning_rate)\n",
    "\n",
    "    model, train_losses, val_losses, save_weights_list = udv_frame_relu(\n",
    "        model = model,\n",
    "        optimizer = optimizer,\n",
    "        loss_fn = loss_fn,\n",
    "        train_loader = train_dataloader,\n",
    "        val_loader = val_dataloader,\n",
    "        num_epochs = num_epochs,\n",
    "        device = device\n",
    "    )\n",
    "    full_train_loss_model_4 = store_metrics(full_train_loss_model_4, train_losses)\n",
    "    full_val_loss_model_4 = store_metrics(full_val_loss_model_4, val_losses)\n",
    "\n",
    "    # Check the returned weights against the model before keeping them\n",
    "    check_shapes(model_weight = model.fc1.weight, list_sample = save_weights_list[0])\n",
    "    check_shapes(model_weight = model.fc2.weight, list_sample = save_weights_list[1])\n",
    "    l_1_model_4.append(save_weights_list[0])\n",
    "    l_2_model_4.append(save_weights_list[1])\n",
    "\n",
    "avg_train_loss_model_4 = take_avg(full_train_loss_model_4)\n",
    "avg_val_loss_model_4 = take_avg(full_val_loss_model_4)\n",
    "\n",
    "# time_record receives start time, end time and duration (seconds) for this model\n",
    "end_time = time.time()\n",
    "time_record.append(end_time)\n",
    "duration = end_time - start_time\n",
    "time_record.append(duration)\n",
    "\n",
    "# Drop the per-model temporaries so they cannot leak into the next model cell\n",
    "del init_llist1, init_wlist1, init_llist2, model, save_weights_list"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7fb229a",
   "metadata": {
    "tags": []
   },
   "source": [
    "### <span style=\"color:blue\"> *---------------------------------------*</span>\n",
    "## <span style=\"color:blue\"> *The following model with ReLU_uv constraints (Matrix): Model_5_UV_ReLU(constrained)*</span>\n",
    "### <span style=\"color:blue\"> *---------------------------------------*</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43512d4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model_5 (UV_ReLU, constrained): relu_net_1 trained by uvrelu_frame_m_uv with constraints_uv\n",
    "torch.manual_seed(public_seed)\n",
    "\n",
    "# The optimiser is built from optimiser_name so that training always matches the name\n",
    "# embedded in result_path ('SGDM' means SGD with momentum SGD_M).\n",
    "if optimiser_name not in ('Adam', 'NAdam', 'SGD', 'SGDM'):\n",
    "    raise ValueError('Unsupported optimiser_name: {0}'.format(optimiser_name))\n",
    "\n",
    "# Initial weights, one entry per run; seedList1 receives the same arguments in every model cell.\n",
    "# init_wlist1 is not used by this model (only fc1 and fc2 are initialised below).\n",
    "init_llist1, init_wlist1, init_llist2 = seedList1(num_seeds = num_seeds, \n",
    "                                                  public_seed = public_seed,\n",
    "                                                  num_input = num_input,\n",
    "                                                  num_hidden_1 = num_hidden_1, \n",
    "                                                  num_output = num_output\n",
    "                                                 )\n",
    "\n",
    "# Per-run losses, accumulated through store_metrics\n",
    "full_train_loss_model_5 = []\n",
    "full_val_loss_model_5 = []\n",
    "\n",
    "# Per-run weights returned by the training framework (fc1, fc2)\n",
    "l_1_model_5 = []\n",
    "l_2_model_5 = []\n",
    "\n",
    "constraints_uv = Matrix_bothside(normLim = matrix_uvnorm)\n",
    "\n",
    "start_time = time.time()\n",
    "time_record.append(start_time)\n",
    "\n",
    "for i in range(1, num_seeds + 1):\n",
    "    model = relu_net_1(num_input = num_input, num_hidden_1 = num_hidden_1, num_output = num_output)\n",
    "    with torch.no_grad():\n",
    "        # Replace the default initialisation with the prepared weights of run i\n",
    "        check_shapes(model_weight = model.fc1.weight, list_sample = init_llist1[i-1])\n",
    "        check_shapes(model_weight = model.fc2.weight, list_sample = init_llist2[i-1])\n",
    "        model.fc1.weight = nn.Parameter(init_llist1[i-1])\n",
    "        model.fc2.weight = nn.Parameter(init_llist2[i-1])\n",
    "\n",
    "    # A fresh optimiser per run, created after the parameters were replaced\n",
    "    if optimiser_name == 'SGDM':\n",
    "        optimizer = optim.SGD(model.parameters(), lr = learning_rate, momentum = SGD_M)\n",
    "    else:\n",
    "        optimizer = getattr(optim, optimiser_name)(model.parameters(), lr = learning_rate)\n",
    "\n",
    "    model, train_losses, val_losses, save_weights_list = uvrelu_frame_m_uv(\n",
    "        model = model,\n",
    "        optimizer = optimizer,\n",
    "        loss_fn = loss_fn,\n",
    "        train_loader = train_dataloader,\n",
    "        val_loader = val_dataloader,\n",
    "        num_epochs = num_epochs,\n",
    "        device = device,\n",
    "        constraints_uv = constraints_uv,\n",
    "    )\n",
    "    full_train_loss_model_5 = store_metrics(full_train_loss_model_5, train_losses)\n",
    "    full_val_loss_model_5 = store_metrics(full_val_loss_model_5, val_losses)\n",
    "\n",
    "    # Check the returned weights against the model before keeping them\n",
    "    check_shapes(model_weight = model.fc1.weight, list_sample = save_weights_list[0])\n",
    "    check_shapes(model_weight = model.fc2.weight, list_sample = save_weights_list[1])\n",
    "    l_1_model_5.append(save_weights_list[0])\n",
    "    l_2_model_5.append(save_weights_list[1])\n",
    "\n",
    "avg_train_loss_model_5 = take_avg(full_train_loss_model_5)\n",
    "avg_val_loss_model_5 = take_avg(full_val_loss_model_5)\n",
    "\n",
    "# time_record receives start time, end time and duration (seconds) for this model\n",
    "end_time = time.time()\n",
    "time_record.append(end_time)\n",
    "duration = end_time - start_time\n",
    "time_record.append(duration)\n",
    "\n",
    "# Drop the per-model temporaries so they cannot leak into the next model cell\n",
    "del init_llist1, init_wlist1, init_llist2, model, constraints_uv, save_weights_list"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "073ce221",
   "metadata": {
    "tags": []
   },
   "source": [
    "### <span style=\"color:blue\"> *---------------------------------------*</span>\n",
    "## <span style=\"color:blue\"> *Saving*</span>\n",
    "### <span style=\"color:blue\"> *---------------------------------------*</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "634fb413-ee1a-4150-90d3-d9dfc8b26ed6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save variables for analysis: run configuration, loss histories, stored\n",
    "# matrices/weights and timing records are written to a single pickle file.\n",
    "\n",
    "store_file_path = '{0}/results.pkl'.format(result_path)\n",
    "\n",
    "# Create the target directory if needed so that the dump cannot fail (and lose\n",
    "# the results computed by the cells above) merely because the folder is missing.\n",
    "os.makedirs(os.path.dirname(store_file_path) or '.', exist_ok=True)\n",
    "\n",
    "variables = {\n",
    "    # --- run environment ---\n",
    "    'device': device,\n",
    "    'data_path': data_path,\n",
    "    'public_seed': public_seed,\n",
    "    \n",
    "    # --- network sizes ---\n",
    "    'num_input': num_input,\n",
    "    'num_hidden_1': num_hidden_1,\n",
    "    'num_hidden_2': num_hidden_2,\n",
    "    'num_output': num_output,\n",
    "    \n",
    "    # --- training hyper-parameters ---\n",
    "    'batch_size': batch_size,\n",
    "    'num_epochs': num_epochs,\n",
    "    'num_seeds': num_seeds,\n",
    "    'optimiser_name': optimiser_name,\n",
    "    'learning_rate': learning_rate,\n",
    "    \n",
    "    # --- training / validation losses: per-run ('full') and averaged ('avg') ---\n",
    "    'full_train_loss_model_0': full_train_loss_model_0,\n",
    "    'full_train_loss_model_1': full_train_loss_model_1,\n",
    "    'full_train_loss_model_2': full_train_loss_model_2,\n",
    "    'full_train_loss_model_3': full_train_loss_model_3,\n",
    "    'full_train_loss_model_4': full_train_loss_model_4,\n",
    "    'full_train_loss_model_5': full_train_loss_model_5,\n",
    "    'avg_train_loss_model_0': avg_train_loss_model_0,\n",
    "    'avg_train_loss_model_1': avg_train_loss_model_1,\n",
    "    'avg_train_loss_model_2': avg_train_loss_model_2,\n",
    "    'avg_train_loss_model_3': avg_train_loss_model_3,\n",
    "    'avg_train_loss_model_4': avg_train_loss_model_4,\n",
    "    'avg_train_loss_model_5': avg_train_loss_model_5,\n",
    "    'full_val_loss_model_0': full_val_loss_model_0,\n",
    "    'full_val_loss_model_1': full_val_loss_model_1,\n",
    "    'full_val_loss_model_2': full_val_loss_model_2,\n",
    "    'full_val_loss_model_3': full_val_loss_model_3,\n",
    "    'full_val_loss_model_4': full_val_loss_model_4,\n",
    "    'full_val_loss_model_5': full_val_loss_model_5,\n",
    "    'avg_val_loss_model_0': avg_val_loss_model_0,\n",
    "    'avg_val_loss_model_1': avg_val_loss_model_1,\n",
    "    'avg_val_loss_model_2': avg_val_loss_model_2,\n",
    "    'avg_val_loss_model_3': avg_val_loss_model_3,\n",
    "    'avg_val_loss_model_4': avg_val_loss_model_4,\n",
    "    'avg_val_loss_model_5': avg_val_loss_model_5,\n",
    "    \n",
    "    # --- objects stored by the model_0 .. model_3 cells ---\n",
    "    'u_1_M_model_0': u_1_M_model_0,\n",
    "    'w_1_M_model_0': w_1_M_model_0,\n",
    "    'u_2_M_model_0': u_2_M_model_0,\n",
    "    \n",
    "    'u_1_M_model_1': u_1_M_model_1,\n",
    "    'w_1_M_model_1': w_1_M_model_1,\n",
    "    'u_2_M_model_1': u_2_M_model_1,\n",
    "    \n",
    "    'u_1_M_model_2': u_1_M_model_2,\n",
    "    'w_1_M_model_2': w_1_M_model_2,\n",
    "    'u_2_M_model_2': u_2_M_model_2,\n",
    "\n",
    "    'u_1_M_model_3': u_1_M_model_3,\n",
    "    'w_1_M_model_3': w_1_M_model_3,\n",
    "    'u_2_M_model_3': u_2_M_model_3,\n",
    "\n",
    "    # --- per-run layer weights collected for model_4 and model_5 ---\n",
    "    'l_1_model_4': l_1_model_4,\n",
    "    'l_2_model_4': l_2_model_4,\n",
    "    \n",
    "    'l_1_model_5': l_1_model_5,\n",
    "    'l_2_model_5': l_2_model_5,\n",
    "    \n",
    "    # --- timestamps / durations appended by the training cells ---\n",
    "    'time_record': time_record\n",
    "}\n",
    "\n",
    "with open(store_file_path, 'wb') as file:\n",
    "    pickle.dump(variables, file)\n",
    "\n",
    "print(\"Saving pickle file, done\")\n",
    "\n",
    "# Console summary of the run: reduce each averaged loss curve with avg_stable\n",
    "# and print the settings followed by one line per model.\n",
    "from udvFunctions.udvOtherFunctions import avg_stable\n",
    "\n",
    "last_n = 20           # how many epochs (from end) are averaged\n",
    "\n",
    "# Training-loss summaries for model_0 .. model_5.\n",
    "final_train_loss_m0 = avg_stable(avg_train_loss_model_0, last_n)\n",
    "final_train_loss_m1 = avg_stable(avg_train_loss_model_1, last_n)\n",
    "final_train_loss_m2 = avg_stable(avg_train_loss_model_2, last_n)\n",
    "final_train_loss_m3 = avg_stable(avg_train_loss_model_3, last_n)\n",
    "final_train_loss_m4 = avg_stable(avg_train_loss_model_4, last_n)\n",
    "final_train_loss_m5 = avg_stable(avg_train_loss_model_5, last_n)\n",
    "\n",
    "# Validation-loss summaries for model_0 .. model_5.\n",
    "final_val_loss_m0 = avg_stable(avg_val_loss_model_0, last_n)\n",
    "final_val_loss_m1 = avg_stable(avg_val_loss_model_1, last_n)\n",
    "final_val_loss_m2 = avg_stable(avg_val_loss_model_2, last_n)\n",
    "final_val_loss_m3 = avg_stable(avg_val_loss_model_3, last_n)\n",
    "final_val_loss_m4 = avg_stable(avg_val_loss_model_4, last_n)\n",
    "final_val_loss_m5 = avg_stable(avg_val_loss_model_5, last_n)\n",
    "\n",
    "# Settings header ('Optimiser' spelling fixed to match optimiser_name).\n",
    "print('Batch size: {0}\\n#Seeds: {1}\\n#Epochs: {2}\\nOptimiser: {3}\\nLearning Rate: {4}\\n#Hidden_1: {5}\\n#Hidden_2: {6}'.format(\n",
    "    batch_size, num_seeds, num_epochs, optimiser_name, learning_rate, num_hidden_1, num_hidden_2))\n",
    "print('Final results are averaged from the last {0} epochs\\n'.format(last_n))\n",
    "\n",
    "# One result line per model: training loss, then validation loss.\n",
    "print('model_0: UDV                      (avg:-{0}): T-loss {1}; V-loss is {2}'.format(last_n, final_train_loss_m0, final_val_loss_m0))\n",
    "print('model_1: UDV-s                    (avg:-{0}): T-loss {1}; V-loss is {2}'.format(last_n, final_train_loss_m1, final_val_loss_m1))\n",
    "print('model_2: UDV-ReLU                 (avg:-{0}): T-loss {1}; V-loss is {2}'.format(last_n, final_train_loss_m2, final_val_loss_m2))\n",
    "print('model_3: UDV-ReLU-s               (avg:-{0}): T-loss {1}; V-loss is {2}'.format(last_n, final_train_loss_m3, final_val_loss_m3))\n",
    "print('model_4: UV-ReLU                  (avg:-{0}): T-loss {1}; V-loss is {2}'.format(last_n, final_train_loss_m4, final_val_loss_m4))\n",
    "print('model_5: UV-ReLU(Constrained)     (avg:-{0}): T-loss {1}; V-loss is {2}'.format(last_n, final_train_loss_m5, final_val_loss_m5))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
