{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "CBdVtd0iTWIX"
   },
   "outputs": [],
   "source": [
    "from __future__ import print_function, division\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "from torchvision import datasets, models, transforms, utils\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "print(\"PyTorch Version:\", torch.__version__)\n",
    "print(\"Torchvision Version:\", torchvision.__version__)\n",
    "print(\"GPU is available?\", torch.cuda.is_available())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "pJn7fK-_TwkB"
   },
   "outputs": [],
   "source": [
    "from sklearn.datasets import make_regression, load_diabetes\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import StandardScaler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "afSpv3-S09n5"
   },
   "outputs": [],
   "source": [
    "class BlockCoordinateNN(nn.Module):\n",
    "    def __init__(self, input_dim = 10, r = 30, L = 5):\n",
    "        super(BlockCoordinateNN, self).__init__()\n",
    "        self.layers = [nn.Linear(input_dim, r)]\n",
    "        for _ in range(L-2):\n",
    "          self.layers.append(nn.Linear(r, r))\n",
    "        self.layers.append(nn.Linear(r, 1))\n",
    "        self.activation = nn.ReLU()\n",
    "        self.loss = [[] for _ in range(2*L)]\n",
    "        self.test_loss = []\n",
    "        self.L = L\n",
    "\n",
    "    def forward(self, x):\n",
    "        Vs = []\n",
    "        V = x\n",
    "        for idx, layer in enumerate(self.layers[:-1]):\n",
    "          if idx == 0:\n",
    "            Vnext = self.activation(layer(V))\n",
    "          else:\n",
    "            Vnext = self.activation(layer(V)) + V\n",
    "          Vs.append(Vnext)\n",
    "          V = Vnext\n",
    "        output = self.layers[-1](Vnext)\n",
    "        return output, Vs\n",
    "\n",
    "    def apply_singular_value_bounding(self):\n",
    "        with torch.no_grad():\n",
    "            for name, param in self.named_parameters():\n",
    "                if 'weight' in name:\n",
    "                    param.copy_(singular_value_bounding(param))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8OJdeVwYlobs"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "# Set the number of features and data size\n",
    "input_dim = 600\n",
    "train_size = 500\n",
    "test_size = 1000\n",
    "\n",
    "# Generate random training and test data\n",
    "X_train = torch.randn(train_size, input_dim)\n",
    "X_test = torch.randn(test_size, input_dim)\n",
    "\n",
    "# Create an instance of the neural network model\n",
    "model = BlockCoordinateNN(input_dim=input_dim, L= 2)\n",
    "\n",
    "# Generate labels from training and test data using the model\n",
    "with torch.no_grad():\n",
    "    y_train, _ = model(X_train)  # Labels for training data\n",
    "    y_test, _ = model(X_test)    # Labels for test data\n",
    "\n",
    "# Create datasets\n",
    "train_dataset = TensorDataset(X_train, y_train)\n",
    "test_dataset = TensorDataset(X_test, y_test)\n",
    "\n",
    "# Create DataLoaders (batch size is arbitrary, e.g., 32)\n",
    "batch_size = 32\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "# Check the DataLoaders\n",
    "len(train_loader), len(test_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "uq75y245dEoc"
   },
   "outputs": [],
   "source": [
    "scaler_X = StandardScaler()\n",
    "scaler_y = StandardScaler()\n",
    "X_train = scaler_X.fit_transform(X_train)\n",
    "X_test = scaler_X.transform(X_test)\n",
    "y_train = scaler_y.fit_transform(y_train.reshape(-1, 1)).flatten()\n",
    "y_test = scaler_y.transform(y_test.reshape(-1, 1)).flatten()\n",
    "\n",
    "X_train_tensor = torch.tensor(X_train, dtype=torch.float32)\n",
    "X_test_tensor = torch.tensor(X_test, dtype=torch.float32)\n",
    "y_train_tensor = torch.tensor(y_train, dtype=torch.float32)\n",
    "y_test_tensor = torch.tensor(y_test, dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "mbjnWaRYWO4L"
   },
   "outputs": [],
   "source": [
    "def singular_value_bounding(W):\n",
    "    # Perform Singular Value Decomposition (SVD)\n",
    "    U, Sigma, Vt = np.linalg.svd(W, full_matrices=False)\n",
    "\n",
    "    # Apply bounding to singular values\n",
    "    bounded_Sigma = [max(-1/2, min(1/2, s)) for s in Sigma]\n",
    "\n",
    "    # Construct the diagonal matrix of bounded singular values\n",
    "    Sigma_bounded = np.diag(bounded_Sigma)\n",
    "\n",
    "    # Reconstruct the matrix using the bounded singular values\n",
    "    W_bounded = U @ Sigma_bounded @ Vt\n",
    "\n",
    "    return W_bounded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Gz7-5kqX7EQB"
   },
   "outputs": [],
   "source": [
    "def compute_loss(model, Vs, batch_y, criterion, layer_index, Vnext=None):\n",
    "    if layer_index == model.L-2:\n",
    "        return criterion(model.layers[layer_index + 1](Vs[layer_index]).view(-1), batch_y)\n",
    "    else:\n",
    "        return criterion(model.activation(model.layers[layer_index + 1](Vs[layer_index]))+Vs[layer_index], Vnext)\n",
    "\n",
    "def update_weights(model, optimizer, loss, retain_graph=False):\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward(retain_graph=retain_graph)\n",
    "    optimizer.step()\n",
    "\n",
    "def train_vj(model, criterion, Vj, batch_y, Vnext, layer_index, eta, num_inner_iters=1000):\n",
    "    if layer_index == model.L-2:\n",
    "        for _ in range(10):\n",
    "            loss_Vj = criterion(model.layers[layer_index + 1](Vj).view(-1), batch_y)\n",
    "            grad = torch.autograd.grad(loss_Vj, Vj, retain_graph=True)[0]\n",
    "            Vj = Vj - eta * grad\n",
    "            Vj = nn.ReLU()(Vj)\n",
    "    else:\n",
    "        for _ in range(num_inner_iters):\n",
    "            loss_Vj = criterion(model.activation(model.layers[layer_index + 1](Vj))+Vj, Vnext)\n",
    "            grad = torch.autograd.grad(loss_Vj, Vj, retain_graph=True)[0]\n",
    "            Vj = Vj - eta * grad\n",
    "            Vj = nn.ReLU()(Vj)\n",
    "    return Vj.detach(), loss_Vj\n",
    "\n",
    "def print_loss(model, epoch, num_epochs):\n",
    "    print(f'Epoch [{epoch + 1}/{num_epochs}], W{model.L} Loss: {model.loss[-1][-1]:.4f}')\n",
    "    for j in range(model.L - 1, 0, -1):\n",
    "        print(f'Epoch [{epoch + 1}/{num_epochs}], V{j} Loss: {model.loss[2 * j][-1]:.4f}')\n",
    "        print(f'Epoch [{epoch + 1}/{num_epochs}], W{j} Loss: {model.loss[2 * j - 1][-1]:.4f}')\n",
    "    print(f'Epoch [{epoch + 1}/{num_epochs}], Total Loss: {model.loss[0][-1]:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "SYwqpESZ6RAS"
   },
   "outputs": [],
   "source": [
    "def train_bcd(model, X_train_tensor, y_train_tensor, num_epochs=100, batch_size=None, eta=0.1, num_inner_iters=100):\n",
    "    \"\"\"\n",
    "    Train the BCD model using the given training data.\n",
    "\n",
    "    Parameters:\n",
    "    - model: The model to be trained.\n",
    "    - X_train_tensor: Tensor of training features.\n",
    "    - y_train_tensor: Tensor of training labels.\n",
    "    - num_epochs: Number of epochs to train (default: 100).\n",
    "    - batch_size: Size of each training batch (default: None, meaning use full dataset).\n",
    "    - eta: Learning rate (default: 0.1).\n",
    "    - num_inner_iters: Number of inner iterations (default: 1).\n",
    "\n",
    "    Returns:\n",
    "    - None\n",
    "    \"\"\"\n",
    "\n",
    "    if batch_size is None:\n",
    "        batch_size = len(X_train_tensor)\n",
    "\n",
    "    criterion = nn.MSELoss()\n",
    "    optimizer_list = [optim.SGD(layer.parameters(), lr=eta) for layer in model.layers]\n",
    "\n",
    "    batch_X, batch_y = X_train_tensor, y_train_tensor\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        outputs, Vs_init = model(batch_X)\n",
    "        loss = criterion(outputs.view(-1), batch_y)\n",
    "\n",
    "        if epoch == 0:\n",
    "            Vs = Vs_init\n",
    "\n",
    "        new_Vs = []\n",
    "\n",
    "        Vnext = None\n",
    "\n",
    "        for j in range(model.L-2, -1, -1):\n",
    "            model.zero_grad()\n",
    "            for param in model.layers[j + 1].parameters():\n",
    "                param.requires_grad = True\n",
    "\n",
    "            loss_Wj = compute_loss(model, Vs, batch_y, criterion, j, Vnext)\n",
    "\n",
    "            if j < model.L-2:\n",
    "              update_weights(model, optimizer_list[j + 1], loss_Wj, retain_graph=True)\n",
    "\n",
    "            model.loss[2 * (j + 1) + 1].append(loss_Wj.item())\n",
    "\n",
    "            for param in model.layers[j + 1].parameters():\n",
    "                param.requires_grad = False\n",
    "\n",
    "            Vj = Vs[j].clone().detach().requires_grad_(True)\n",
    "            Vj, loss_Vj = train_vj(model, criterion, Vj, batch_y, Vnext, j, eta, num_inner_iters)\n",
    "\n",
    "            Vnext = Vj.clone().detach().requires_grad_(False)\n",
    "            new_Vs.insert(0, Vj)\n",
    "            model.loss[2 * (j + 1)].append(loss_Vj.item())\n",
    "\n",
    "        for param in model.layers[0].parameters():\n",
    "            param.requires_grad = True\n",
    "\n",
    "        for _ in range(num_inner_iters):\n",
    "            loss_W0 = criterion(model.layers[0](batch_X), Vnext)\n",
    "            update_weights(model, optimizer_list[0], loss_W0)\n",
    "\n",
    "        model.loss[1].append(loss_W0.item())\n",
    "\n",
    "        for param in model.layers[0].parameters():\n",
    "            param.requires_grad = False\n",
    "\n",
    "        Vs = new_Vs\n",
    "        model.loss[0].append(loss.item())\n",
    "\n",
    "        if (epoch + 1) % 10 == 0:\n",
    "          print_loss(model, epoch, num_epochs)\n",
    "\n",
    "          outputs, _ = model(X_test_tensor)\n",
    "          test_loss = criterion(outputs.view(-1), y_test_tensor)\n",
    "          model.test_loss.append(test_loss.item())\n",
    "          print(f\"test_loss:{test_loss.item()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "SzviIT4IlFw6"
   },
   "outputs": [],
   "source": [
    "input_dim = 600\n",
    "train_size = 500\n",
    "test_size = 1000\n",
    "\n",
    "model = BlockCoordinateNN(input_dim=input_dim, L = 5)\n",
    "model.apply_singular_value_bounding()\n",
    "train_bcd(model, num_epochs = 5000, eta = 1., X_train_tensor = X_train_tensor, y_train_tensor = y_train_tensor, num_inner_iters=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "KSklpGkXT6Vp"
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 5))\n",
    "plt.plot(model.loss[0][:4000], label = \"training loss\",color=\"blue\", lw=2.5)\n",
    "plt.plot(model_mlp.loss[0][:4000],color=\"red\", lw = 2.5, label = \"without skip connection\")\n",
    "\n",
    "plt.xlabel(\"epoch\", fontsize=18)\n",
    "plt.ylabel(\"loss\", fontsize=18)\n",
    "plt.yscale(\"log\")\n",
    "plt.grid(True)\n",
    "ax = plt.gca()\n",
    "ax.yaxis.set_minor_locator(plt.NullLocator())\n",
    "plt.tick_params(labelsize=18)\n",
    "plt.legend(fontsize=18)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "TPU",
  "colab": {
   "authorship_tag": "ABX9TyPIx6EWc+hD1H13XZlqqLki",
   "gpuType": "V28",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
