{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 517,
     "status": "ok",
     "timestamp": 1727849368648,
     "user": {
      "displayName": "秋山俊太",
      "userId": "03114718366633972315"
     },
     "user_tz": -540
    },
    "id": "CBdVtd0iTWIX",
    "outputId": "f982c732-ae71-4a21-c400-18485227d7a7"
   },
   "outputs": [],
   "source": [
    "from __future__ import print_function, division\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "from torchvision import datasets, models, transforms, utils\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "import os\n",
    "import copy\n",
    "\n",
    "# Seed the RNGs so the random weight initialisation is reproducible.\n",
    "SEED = 0\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "print(\"PyTorch Version:\", torch.__version__)\n",
    "print(\"Torchvision Version:\", torchvision.__version__)\n",
    "print(\"GPU is available?\", torch.cuda.is_available())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "mbjnWaRYWO4L"
   },
   "outputs": [],
   "source": [
    "def singular_value_bounding(W, lower=3/4, upper=5/4):\n",
    "    \"\"\"Return a copy of ``W`` whose singular values are clamped to [lower, upper].\n",
    "\n",
    "    The matrix is decomposed with a reduced SVD, every singular value is\n",
    "    clipped into the closed interval ``[lower, upper]`` and the matrix is\n",
    "    rebuilt from the clipped spectrum.\n",
    "\n",
    "    Parameters:\n",
    "    - W: Matrix to bound. A ``torch.Tensor`` is decomposed with\n",
    "      ``torch.linalg.svd`` (no NumPy round trip); anything else is passed\n",
    "      through ``np.asarray`` and decomposed with ``np.linalg.svd``.\n",
    "    - lower: Smallest singular value allowed (default: 3/4).\n",
    "    - upper: Largest singular value allowed (default: 5/4).\n",
    "\n",
    "    Returns:\n",
    "    - A ``torch.Tensor`` when ``W`` is a tensor, otherwise a ``numpy.ndarray``,\n",
    "      with the same shape as ``W``.\n",
    "\n",
    "    Raises:\n",
    "    - ValueError: If ``lower`` is greater than ``upper``.\n",
    "    \"\"\"\n",
    "    if lower > upper:\n",
    "        raise ValueError('lower must not exceed upper')\n",
    "\n",
    "    if torch.is_tensor(W):\n",
    "        U, sigma, Vh = torch.linalg.svd(W, full_matrices=False)\n",
    "        sigma = sigma.clamp(min=lower, max=upper)\n",
    "    else:\n",
    "        U, sigma, Vh = np.linalg.svd(np.asarray(W), full_matrices=False)\n",
    "        sigma = np.clip(sigma, lower, upper)\n",
    "\n",
    "    # Scaling the columns of U by sigma equals U @ diag(sigma) without\n",
    "    # materialising the diagonal matrix.\n",
    "    return (U * sigma[..., None, :]) @ Vh"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "afSpv3-S09n5"
   },
   "outputs": [],
   "source": [
    "class BlockCoordinateNN(nn.Module):\n",
    "    \"\"\"Fully connected ReLU network made of ``L`` linear layers.\n",
    "\n",
    "    The layers map ``input_dim -> r``, then ``L - 2`` times ``r -> r`` and\n",
    "    finally ``r -> 1``. ReLU follows every layer except the last one.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim = 10, r = 30, L = 5):\n",
    "        super(BlockCoordinateNN, self).__init__()\n",
    "        if L < 2:\n",
    "            raise ValueError('L must be at least 2 (one hidden layer plus the output layer)')\n",
    "        widths = [input_dim] + [r] * (L - 1) + [1]\n",
    "        # nn.ModuleList (not a plain list) registers the layers as sub-modules so\n",
    "        # parameters(), named_parameters(), state_dict() and .to(device) see them.\n",
    "        self.layers = nn.ModuleList(\n",
    "            [nn.Linear(n_in, n_out) for n_in, n_out in zip(widths[:-1], widths[1:])]\n",
    "        )\n",
    "        self.activation = nn.ReLU()\n",
    "        # loss[0]: end-to-end loss, loss[1]: first-layer weight loss, then for\n",
    "        # k = 1..L-1: loss[2k] is the V_k loss and loss[2k+1] the W_{k+1} loss.\n",
    "        self.loss = [[] for _ in range(2*L)]\n",
    "        self.L = L\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return ``(output, Vs)`` where ``Vs`` lists the L-1 hidden activations.\"\"\"\n",
    "        Vs = []\n",
    "        V = x\n",
    "        for layer in self.layers[:-1]:\n",
    "            V = self.activation(layer(V))\n",
    "            Vs.append(V)\n",
    "        return self.layers[-1](V), Vs\n",
    "\n",
    "    def apply_singular_value_bounding(self):\n",
    "        \"\"\"Clamp, in place, the singular values of every layer's weight matrix.\"\"\"\n",
    "        with torch.no_grad():\n",
    "            for layer in self.layers:\n",
    "                layer.weight.copy_(singular_value_bounding(layer.weight))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Gz7-5kqX7EQB"
   },
   "outputs": [],
   "source": [
    "def compute_loss(model, Vs, batch_y, criterion, layer_index, Vnext=None):\n",
    "    \"\"\"Loss used to fit layer ``layer_index + 1`` on the input ``Vs[layer_index]``.\n",
    "\n",
    "    For the output layer (``layer_index == model.L - 2``) the flattened layer\n",
    "    output is compared with ``batch_y``; for every other layer the activated\n",
    "    layer output is compared with ``Vnext``.\n",
    "    \"\"\"\n",
    "    layer_out = model.layers[layer_index + 1](Vs[layer_index])\n",
    "    if layer_index == model.L-2:\n",
    "        return criterion(layer_out.view(-1), batch_y)\n",
    "    return criterion(model.activation(layer_out), Vnext)\n",
    "\n",
    "def update_weights(model, optimizer, loss, retain_graph=False):\n",
    "    \"\"\"Take one optimizer step on ``loss``.\n",
    "\n",
    "    ``model`` is not used; it is kept so existing call sites stay valid.\n",
    "    \"\"\"\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward(retain_graph=retain_graph)\n",
    "    optimizer.step()\n",
    "\n",
    "def train_vj(model, criterion, Vj, batch_y, Vnext, layer_index, eta, num_inner_iters=1000):\n",
    "    \"\"\"Gradient-descent update of the auxiliary activation ``Vj``.\n",
    "\n",
    "    The block feeding the output layer (``layer_index == model.L - 2``) takes\n",
    "    exactly one step against ``batch_y``; every other block takes\n",
    "    ``num_inner_iters`` steps against ``Vnext``.\n",
    "\n",
    "    Returns:\n",
    "    - ``(Vj, loss_Vj)``: the updated, detached ``Vj`` and the loss evaluated\n",
    "      just before the final step.\n",
    "\n",
    "    Raises:\n",
    "    - ValueError: If a hidden block is asked to take fewer than one step.\n",
    "    \"\"\"\n",
    "    is_output_block = layer_index == model.L-2\n",
    "    next_layer = model.layers[layer_index + 1]\n",
    "    num_steps = 1 if is_output_block else num_inner_iters\n",
    "    if num_steps < 1:\n",
    "        raise ValueError('num_inner_iters must be at least 1')\n",
    "\n",
    "    Vj = Vj.detach().requires_grad_(True)\n",
    "    for _ in range(num_steps):\n",
    "        if is_output_block:\n",
    "            loss_Vj = criterion(next_layer(Vj).view(-1), batch_y)\n",
    "        else:\n",
    "            loss_Vj = criterion(model.activation(next_layer(Vj)), Vnext)\n",
    "        grad = torch.autograd.grad(loss_Vj, Vj)[0]\n",
    "        # Detach after every step: the update is plain gradient descent, so\n",
    "        # keeping the history would only grow the autograd graph and memory.\n",
    "        Vj = (Vj - eta * grad).detach().requires_grad_(True)\n",
    "    return Vj.detach(), loss_Vj.detach()\n",
    "\n",
    "def print_loss(model, epoch, num_epochs):\n",
    "    \"\"\"Print the most recent value of every loss series stored in ``model.loss``.\"\"\"\n",
    "    print(f'Epoch [{epoch + 1}/{num_epochs}], W{model.L} Loss: {model.loss[-1][-1]:.4f}')\n",
    "    for j in range(model.L - 1, 0, -1):\n",
    "        print(f'Epoch [{epoch + 1}/{num_epochs}], V{j} Loss: {model.loss[2 * j][-1]:.4f}')\n",
    "        print(f'Epoch [{epoch + 1}/{num_epochs}], W{j} Loss: {model.loss[2 * j - 1][-1]:.4f}')\n",
    "    print(f'Epoch [{epoch + 1}/{num_epochs}], Total Loss: {model.loss[0][-1]:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "SYwqpESZ6RAS"
   },
   "outputs": [],
   "source": [
    "def _set_trainable(layer, flag):\n",
    "    \"\"\"Switch ``requires_grad`` on or off for every parameter of ``layer``.\"\"\"\n",
    "    for param in layer.parameters():\n",
    "        param.requires_grad = flag\n",
    "\n",
    "def train_bcd(model, X_train_tensor, y_train_tensor, num_epochs=100, batch_size=None, eta=0.1, num_inner_iters=100):\n",
    "    \"\"\"\n",
    "    Train the BCD model using the given training data.\n",
    "\n",
    "    Each epoch sweeps from the output layer back to the first layer. For every\n",
    "    layer the weights are updated with SGD and then the auxiliary activation\n",
    "    feeding that layer is updated with ``train_vj``. The per-block losses are\n",
    "    appended to ``model.loss`` and printed every 10 epochs.\n",
    "\n",
    "    Parameters:\n",
    "    - model: The model to be trained.\n",
    "    - X_train_tensor: Tensor of training features.\n",
    "    - y_train_tensor: Tensor of training labels.\n",
    "    - num_epochs: Number of epochs to train (default: 100).\n",
    "    - batch_size: Accepted for interface compatibility but currently ignored;\n",
    "      every update uses the full dataset (default: None).\n",
    "    - eta: Learning rate (default: 0.1).\n",
    "    - num_inner_iters: Number of inner iterations (default: 100).\n",
    "\n",
    "    Returns:\n",
    "    - None\n",
    "\n",
    "    Raises:\n",
    "    - ValueError: If ``num_inner_iters`` is smaller than 1.\n",
    "\n",
    "    Note: every layer is left with ``requires_grad=False`` when training ends.\n",
    "    \"\"\"\n",
    "    if num_inner_iters < 1:\n",
    "        raise ValueError('num_inner_iters must be at least 1')\n",
    "\n",
    "    criterion = nn.MSELoss()\n",
    "    optimizer_list = [optim.SGD(layer.parameters(), lr=eta) for layer in model.layers]\n",
    "\n",
    "    batch_X, batch_y = X_train_tensor, y_train_tensor\n",
    "    first_layer = model.layers[0]\n",
    "    Vs = None\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        # The forward pass is only used for monitoring and, in the first epoch,\n",
    "        # to initialise the auxiliary activations, so no autograd graph is needed.\n",
    "        with torch.no_grad():\n",
    "            outputs, Vs_init = model(batch_X)\n",
    "            loss = criterion(outputs.view(-1), batch_y)\n",
    "\n",
    "        if Vs is None:\n",
    "            Vs = Vs_init\n",
    "\n",
    "        new_Vs = []\n",
    "\n",
    "        Vnext = None  # Target for the hidden blocks; unused by the output block.\n",
    "\n",
    "        for j in range(model.L-2, -1, -1):\n",
    "            layer = model.layers[j + 1]\n",
    "\n",
    "            # Weight step for layer j + 1 with its input Vs[j] held fixed.\n",
    "            _set_trainable(layer, True)\n",
    "            loss_Wj = compute_loss(model, Vs, batch_y, criterion, j, Vnext)\n",
    "            update_weights(model, optimizer_list[j + 1], loss_Wj)\n",
    "            model.loss[2 * (j + 1) + 1].append(loss_Wj.item())\n",
    "            _set_trainable(layer, False)\n",
    "\n",
    "            # Activation step for Vs[j] with the freshly updated weights frozen.\n",
    "            Vj = Vs[j].clone().detach().requires_grad_(True)\n",
    "            Vj, loss_Vj = train_vj(model, criterion, Vj, batch_y, Vnext, j, eta, num_inner_iters)\n",
    "\n",
    "            Vnext = Vj.clone().detach().requires_grad_(False)\n",
    "            new_Vs.insert(0, Vj)\n",
    "            model.loss[2 * (j + 1)].append(loss_Vj.item())\n",
    "\n",
    "        # First layer: fit the raw inputs to the updated first activation.\n",
    "        _set_trainable(first_layer, True)\n",
    "        for _ in range(num_inner_iters):\n",
    "            loss_W0 = criterion(model.activation(first_layer(batch_X)), Vnext)\n",
    "            update_weights(model, optimizer_list[0], loss_W0)\n",
    "        model.loss[1].append(loss_W0.item())\n",
    "        _set_trainable(first_layer, False)\n",
    "\n",
    "        Vs = new_Vs\n",
    "        model.loss[0].append(loss.item())\n",
    "\n",
    "        if (epoch + 1) % 10 == 0:\n",
    "            print_loss(model, epoch, num_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 1096815,
     "status": "ok",
     "timestamp": 1727850969473,
     "user": {
      "displayName": "秋山俊太",
      "userId": "03114718366633972315"
     },
     "user_tz": -540
    },
    "id": "SzviIT4IlFw6",
    "outputId": "80215780-69d4-47cd-ec84-a2f091801dbe"
   },
   "outputs": [],
   "source": [
    "# NOTE(review): X_train_tensor and y_train_tensor are not defined anywhere in\n",
    "# this notebook; a data-loading cell must create them before this cell can run\n",
    "# on a fresh kernel.\n",
    "model = BlockCoordinateNN(input_dim =600,r=30)\n",
    "model.apply_singular_value_bounding()\n",
    "train_bcd(model, num_epochs = 5000, eta = 1, X_train_tensor = X_train_tensor, y_train_tensor = y_train_tensor, num_inner_iters=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fsBEFWzgB19E"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "authorship_tag": "ABX9TyNhfkZeMfi+fk7EEDv4dzPb",
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
