{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2410bd58-ad01-46e0-b4c5-e599cd86a0c8",
   "metadata": {
    "tags": []
   },
   "source": [
    "### <span style=\"color:red\">DC_FW Project: Libraries</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d62d1060",
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "Test setup for the neural-network experiment.\n",
    "\n",
    "[MHSY25] H. Maskan, Y. Hou, S. Sra, A. Yurtsever,\n",
    "\"Revisiting Frank-Wolfe for Structured Nonconvex Optimization\",\n",
    "39th Conference on Neural Information Processing Systems (NeurIPS 2025).\n",
    "\n",
    "Contact information: https://github.com/hoomyhh\n",
    "'''\n",
    "# This block imports libraries\n",
    "\n",
    "# Basic libraries:\n",
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "# import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, Subset\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "#===========================================================\n",
    "# Classification needs:\n",
    "from torchvision import transforms\n",
    "from torchvision import models\n",
    "\n",
    "# For specific datasets \n",
    "# from torchvision.datasets import MNIST\n",
    "from torchvision.datasets import CIFAR10\n",
    "\n",
    "import pickle\n",
    "import time\n",
    "import datetime\n",
    "import random\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01f18f97-4adf-4248-af24-6c9e21e116f7",
   "metadata": {
    "tags": []
   },
   "source": [
    "### <span style=\"color:blue\">GPU/CPU</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78c9549f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(f\"Using {device} device\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9333a55d-96ac-4479-8e9d-b54dedd1192a",
   "metadata": {
    "tags": []
   },
   "source": [
    "### <span style=\"color:blue\">Reproducible setting</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "744d3c6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def reset_seed(public_seed = 126):\n",
    "    \"\"\"Seed every random number generator used in this notebook.\n",
    "\n",
    "    Seeds Python's `random`, NumPy's global generator and PyTorch with the same\n",
    "    value, then switches cuDNN to deterministic mode with benchmarking disabled.\n",
    "\n",
    "    Args:\n",
    "        public_seed: integer seed shared by all three generators.\n",
    "    \"\"\"\n",
    "    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):\n",
    "        seed_fn(public_seed)\n",
    "\n",
    "    cudnn = torch.backends.cudnn\n",
    "    cudnn.deterministic = True\n",
    "    cudnn.benchmark = False"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2553d652-0748-412d-af3e-4d16208dc81f",
   "metadata": {
    "tags": []
   },
   "source": [
    "### <span style=\"color:blue\">Dataset pre-processing</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "341a6c33-2249-4195-b103-9d8997cb522b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Local dataset path (download = False below, so the data must already be on disk)\n",
    "data_path = \"./CIFAR10\"\n",
    "\n",
    "\n",
    "# Dataset pre-processing\n",
    "# Warning: the normalisation statistics below are only meant for CIFAR10\n",
    "transform = transforms.Compose([transforms.ToTensor(),\n",
    "                                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))\n",
    "                               ])\n",
    "\n",
    "# Load all or partial data\n",
    "Load_ALL = True\n",
    "\n",
    "if Load_ALL:\n",
    "    print(\"Load all data\")\n",
    "    train_dataset = CIFAR10(data_path, train = True, download = False, transform = transform)\n",
    "    val_dataset = CIFAR10(data_path, train = False, download = False, transform = transform)\n",
    "else:\n",
    "    print(\"Load partial data\")\n",
    "    num_train_samples = 1024 * 6\n",
    "    num_test_samples = 1024\n",
    "    CIFAR10_dataset = CIFAR10(root = data_path, train = True, download = False, transform = transform)\n",
    "\n",
    "    # Seed before sampling so the subsets are identical after Restart & Run All.\n",
    "    reset_seed()\n",
    "\n",
    "    # Draw ONE sample without replacement and split it: the training and the\n",
    "    # validation subsets are then disjoint. Two independent draws from the same\n",
    "    # split can overlap and leak training images into validation.\n",
    "    sampled_indices = random.sample(range(len(CIFAR10_dataset)), num_train_samples + num_test_samples)\n",
    "    train_indices = sampled_indices[:num_train_samples]\n",
    "    test_indices = sampled_indices[num_train_samples:]\n",
    "\n",
    "    train_dataset = Subset(CIFAR10_dataset, train_indices)\n",
    "    val_dataset = Subset(CIFAR10_dataset, test_indices)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9960820f-72cf-4ca7-8549-387357447676",
   "metadata": {
    "tags": []
   },
   "source": [
    "### <span style=\"color:blue\">Network structure</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f81ff050",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CIFAR10CNN(nn.Module):\n",
    "    \"\"\"Small CNN that maps 3-channel 32x32 images to 10 class logits.\n",
    "\n",
    "    Two convolutional stages (each: two 3x3 conv -> ReLU -> BatchNorm layers,\n",
    "    then 2x2 max-pooling and dropout) feed a three-layer fully connected head.\n",
    "    The head expects 64 feature maps of size 8x8 after the two poolings.\n",
    "    \"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    def _conv_stage(in_channels, out_channels):\n",
    "        \"\"\"Build one convolutional stage; layers are created in the order applied.\"\"\"\n",
    "        return nn.Sequential(\n",
    "            nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1),\n",
    "            nn.ReLU(),\n",
    "            nn.BatchNorm2d(out_channels),\n",
    "            nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1),\n",
    "            nn.ReLU(),\n",
    "            nn.BatchNorm2d(out_channels),\n",
    "            nn.MaxPool2d(kernel_size = 2, stride = 2, padding = 0),\n",
    "            nn.Dropout(0.01),\n",
    "        )\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Stages are built in forward order so that parameter initialisation\n",
    "        # consumes the random stream in the same sequence for a given seed.\n",
    "        self.conv_block1 = self._conv_stage(3, 32)\n",
    "        self.conv_block2 = self._conv_stage(32, 64)\n",
    "\n",
    "        flat_features = 64 * 8 * 8\n",
    "        self.fc_block = nn.Sequential(\n",
    "            nn.Flatten(),\n",
    "            nn.Linear(flat_features, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.BatchNorm1d(512),\n",
    "            nn.Dropout(0.01),\n",
    "            nn.Linear(512, 1024),\n",
    "            nn.ReLU(),\n",
    "            nn.BatchNorm1d(1024),\n",
    "            nn.Dropout(0.01),\n",
    "            nn.Linear(1024, 10),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the (batch, 10) logits for the image batch `x`.\"\"\"\n",
    "        features = self.conv_block2(self.conv_block1(x))\n",
    "        return self.fc_block(features)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "acd105f7-da94-4f15-b0d7-82c1b97824d1",
   "metadata": {
    "tags": []
   },
   "source": [
    "### <span style=\"color:red\">The proposed DC-FW algorithm</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96a84689-725c-4170-a423-37955636b561",
   "metadata": {},
   "outputs": [],
   "source": [
    "# DC_FW for ICML adjustment\n",
    "\n",
    "class DC_FW_ICML(torch.optim.Optimizer):\n",
    "    \"\"\"Frank-Wolfe optimiser with a proximal inner loop (DC-FW).\n",
    "\n",
    "    Every call to `step()` is one outer iteration t. For each parameter tensor\n",
    "    that has a gradient, it runs at most `inner_upper` inner updates\n",
    "\n",
    "        G_{t,k}   = grad_f(W_t) - alpha * (W_t - W_{t,k})\n",
    "        S_{t,k}   = -tuning_c * sign(G_{t,k})\n",
    "        W_{t,k+1} = W_{t,k} + eta_{t,k} * (S_{t,k} - W_{t,k})\n",
    "\n",
    "    and stops early once the gap estimate is at most half of the tolerance.\n",
    "    grad_f(W_t) is whatever is stored in `.grad` when `step()` is called; it is\n",
    "    not re-evaluated inside the inner loop.\n",
    "\n",
    "    Args:\n",
    "        params: iterable of parameters (or parameter-group dicts) to optimise.\n",
    "        alpha: positive weight of the proximal term.\n",
    "        tuning_c: positive scale c of the LMO output -c * sign(G).\n",
    "        inner_upper: maximum number of inner iterations per parameter and step.\n",
    "\n",
    "    Per-parameter state lives in three lists indexed by the flat position of the\n",
    "    parameter across all parameter groups:\n",
    "        para_t:    copy of the outer iterate W_t.\n",
    "        gap:       latest gap estimate, None until the parameter has a gradient.\n",
    "        tolerance: current tolerance, None until the parameter has a gradient.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, params, alpha, tuning_c, inner_upper):\n",
    "        assert alpha > 0.0, f\"Invalid alpha: {alpha}, it should be positive constant\"\n",
    "        assert tuning_c > 0.0, f\"Invalid tuning_c: {tuning_c}, it should be positive constant\"\n",
    "\n",
    "        # Main variables\n",
    "        self.iteration_counter = 1             # Incremented once per step(); never reset for this instance\n",
    "        self.inner_loop_upper = inner_upper    # Upper bound on the inner-loop index k\n",
    "\n",
    "        self.gap = []                # Gap estimate per parameter\n",
    "        self.tolerance = []          # Tolerance per parameter\n",
    "        self.para_t = []             # Outer iterate W_t per parameter (refreshed after each inner loop)\n",
    "\n",
    "        self.init_W_t = True         # True until the first step() has printed the initialisation summary\n",
    "\n",
    "        # Auxiliary variables\n",
    "        self.gap_warn = 0            # A gap below this value triggers the warning print\n",
    "\n",
    "        # defaults (all hyper-parameters are stored per parameter group)\n",
    "        defaults = dict(alpha = alpha,\n",
    "                        tuning_c = tuning_c)\n",
    "        super(DC_FW_ICML, self).__init__(params, defaults)\n",
    "\n",
    "\n",
    "    # Restores the optimiser state when a saved optimiser is loaded (nothing extra needed here)\n",
    "    def __setstate__(self, state):\n",
    "        super(DC_FW_ICML, self).__setstate__(state)\n",
    "\n",
    "\n",
    "    def _indexed_params(self):\n",
    "        \"\"\"Yield (flat_index, group, parameter) over all parameter groups.\"\"\"\n",
    "        flat_index = 0\n",
    "        for group in self.param_groups:\n",
    "            for para in group['params']:\n",
    "                yield flat_index, group, para\n",
    "                flat_index += 1\n",
    "\n",
    "\n",
    "    @staticmethod\n",
    "    def _gap_estimate(G_tk, D_tk):\n",
    "        \"\"\"Return ||G||_1 * ||D||_inf, the gap estimate used by this optimiser.\n",
    "\n",
    "        2-D tensors use torch.linalg.matrix_norm (ord 1 / inf); every other rank\n",
    "        uses torch.linalg.vector_norm (ord 1 / inf).\n",
    "        NOTE(review): for 2-D tensors these are induced matrix norms, not the\n",
    "        element-wise norms used for the other ranks - confirm this is intended.\n",
    "        Alternative kept from the original code: gap = torch.sum(-G_tk * D_tk).\n",
    "        \"\"\"\n",
    "        if G_tk.ndim == 2:\n",
    "            return torch.linalg.matrix_norm(G_tk, ord = 1) * torch.linalg.matrix_norm(D_tk, ord = np.inf)\n",
    "        return torch.linalg.vector_norm(G_tk, ord = 1) * torch.linalg.vector_norm(D_tk, ord = np.inf)\n",
    "\n",
    "\n",
    "    # Step function (optimiser.step())\n",
    "    def step(self, closure = None):\n",
    "        \"\"\"Run one outer iteration, including the inner loop, on every parameter.\n",
    "\n",
    "        Args:\n",
    "            closure: optional callable that re-evaluates the model and returns the loss.\n",
    "\n",
    "        Returns:\n",
    "            The value returned by `closure`, or None when no closure is given.\n",
    "        \"\"\"\n",
    "        loss = None\n",
    "        if closure is not None:\n",
    "            loss = closure()\n",
    "\n",
    "        # Pass 1: create missing state, then refresh gap and tolerance from the current gradient\n",
    "        # ----------------------------------------------------\n",
    "        for para_index, group, para_tk in self._indexed_params():\n",
    "            if para_index == len(self.para_t):\n",
    "                # First time this parameter is seen: remember W_t, gap and tolerance come later\n",
    "                self.para_t.append(para_tk.detach().clone())\n",
    "                self.gap.append(None)\n",
    "                self.tolerance.append(None)\n",
    "\n",
    "            # A parameter without gradient (e.g. a frozen layer) keeps its state untouched\n",
    "            if para_tk.grad is None:\n",
    "                continue\n",
    "\n",
    "            G_tk = para_tk.grad.data\n",
    "            S_tk = -group['tuning_c'] * torch.sign(G_tk)\n",
    "            D_tk = S_tk - para_tk.data # (S_tk - W_tk), same as (S_tk - W_t) at this point\n",
    "\n",
    "            gap = self._gap_estimate(G_tk, D_tk)\n",
    "            self.gap[para_index] = gap\n",
    "\n",
    "            # The tolerance starts at 0.9 * gap and only shrinks when the gap falls below it\n",
    "            if self.tolerance[para_index] is None or gap < self.tolerance[para_index]:\n",
    "                self.tolerance[para_index] = gap * 0.9\n",
    "\n",
    "        if self.init_W_t is True: # Summary is printed only once per optimiser instance\n",
    "            self.init_W_t = False\n",
    "\n",
    "            print(f\"Parameters is divided to {len(self.para_t)} (= #layers * 2 when bias available) small groups and update independently\")\n",
    "            print(f\"Initialisation: self.gap {self.gap}\\nTol is {self.tolerance}\")\n",
    "            print(\"Initialisation for the loops has done (only once)\\n\\n\")\n",
    "\n",
    "        # Pass 2: inner loop for every parameter (from shallow to deep, weights to bias)\n",
    "        # ----------------------------------------------------\n",
    "        for para_index, group, para_tk in self._indexed_params():\n",
    "            # Skip the parameter without gradient\n",
    "            if para_tk.grad is None:\n",
    "                continue\n",
    "\n",
    "            alpha = group['alpha']\n",
    "            tuning_c = group['tuning_c']\n",
    "\n",
    "            # Inner loop: ends early on the gap condition or after inner_loop_upper iterations\n",
    "            for inner_counter in range(self.inner_loop_upper):\n",
    "\n",
    "                # Do: G_{t,k} = gradient_f(W_{t}) - alpha * (W_t - W_{t,k})\n",
    "                G_tk = para_tk.grad.data - alpha * (self.para_t[para_index] - para_tk.data)\n",
    "\n",
    "                # Do: S_{t,k} = lmo_D (G_{t,k}), where lmo_D(G_t) = -c*sign(G_t)  [If norm(W,inf) <= c]\n",
    "                S_tk = -tuning_c * torch.sign(G_tk)\n",
    "\n",
    "                # Do: D_{t,k} = S_{t,k} - W_{t,k}\n",
    "                D_tk = S_tk - para_tk.data\n",
    "\n",
    "                # Do: gap = L1_norm of G_tk * inf_norm of D_tk\n",
    "                self.gap[para_index] = self._gap_estimate(G_tk, D_tk)\n",
    "\n",
    "                # Warning check: negative gap\n",
    "                if self.gap[para_index].item() < self.gap_warn:\n",
    "                    print(f\"Inner loop gap {self.gap[para_index]} (Threo: {self.gap_warn}) for para_index {para_index} in inner loop {inner_counter} at iteration {self.iteration_counter}\")\n",
    "\n",
    "                # Check gap condition\n",
    "                if self.gap[para_index].item() <= (0.5 * self.tolerance[para_index].item()):\n",
    "                    break\n",
    "\n",
    "                # Squared l2 norm of D_{t,k} (only needed when the loop continues)\n",
    "                if G_tk.ndim == 2:\n",
    "                    D_tk_l2norm = torch.linalg.matrix_norm(D_tk, ord = 'fro') ** 2\n",
    "                else:\n",
    "                    D_tk_l2norm = torch.linalg.vector_norm(D_tk, ord = 2) ** 2\n",
    "\n",
    "                # Step size clipped to [0, 1]: eta_{t,k} = gap / (alpha * ||D_{t,k}||^2)\n",
    "                # Alternatives kept from the original code: 2 / (s + 1) or 2 / (k + 1)\n",
    "                eta_tk = torch.minimum(torch.maximum(self.gap[para_index] / (D_tk_l2norm * alpha), torch.tensor(0.0, device = self.gap[para_index].device)), torch.tensor(1.0, device = self.gap[para_index].device))\n",
    "\n",
    "                # Do: W_{t,k+1} = W_{t,k} + eta_{t,k} (S_{t,k} - W_{t,k})\n",
    "                para_tk.data = para_tk.data + eta_tk * (S_tk - para_tk.data)\n",
    "\n",
    "            # Inner loop finished: save the result as the next outer iterate\n",
    "            # Do: W_{t+1} = W_{t,k}\n",
    "            self.para_t[para_index] = para_tk.data.detach().clone()\n",
    "\n",
    "        self.iteration_counter += 1 # Never reset; re-instantiate the optimiser to start again\n",
    "        return loss"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e3e0ad0-9131-47ce-905f-36d45df4c1e9",
   "metadata": {
    "tags": []
   },
   "source": [
    "### <span style=\"color:purple\">The classical FW algorithm</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7baca0f-16b4-43fe-8cb5-1050b06a16e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# FW for ICML adjustment\n",
    "\n",
    "class FW_ICML(torch.optim.Optimizer):\n",
    "    \"\"\"Classical Frank-Wolfe optimiser with the sign LMO.\n",
    "\n",
    "    Each `step()` moves every parameter that has a gradient towards\n",
    "    S_t = -tuning_c * sign(grad) with the step size eta_t = 2 / (s + 1), where\n",
    "    s is `iteration_counter` (1 on the first call, +1 after every call).\n",
    "    Unlike DC_FW_ICML there is no proximal weight alpha.\n",
    "\n",
    "    Args:\n",
    "        params: iterable of parameters (or parameter-group dicts) to optimise.\n",
    "        tuning_c: positive scale c of the LMO output -c * sign(G).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, params, tuning_c):\n",
    "        assert tuning_c > 0.0, f\"Invalid tuning_c: {tuning_c}, it should be positive constant\"\n",
    "        self.iteration_counter = 1\n",
    "\n",
    "        defaults = dict(tuning_c = tuning_c)\n",
    "        super(FW_ICML, self).__init__(params, defaults)\n",
    "\n",
    "\n",
    "    def __setstate__(self, state):\n",
    "        super(FW_ICML, self).__setstate__(state)\n",
    "\n",
    "\n",
    "    def step(self, closure = None):\n",
    "        \"\"\"Apply one Frank-Wolfe update to every parameter that has a gradient.\n",
    "\n",
    "        Args:\n",
    "            closure: optional callable that re-evaluates the model and returns the loss.\n",
    "\n",
    "        Returns:\n",
    "            The value returned by `closure`, or None when no closure is given.\n",
    "        \"\"\"\n",
    "        loss = None\n",
    "        if closure is not None:\n",
    "            loss = closure()\n",
    "\n",
    "        # Do: eta_{t} = 2/(s+1); the same value is used for every parameter in this step\n",
    "        eta_t = 2.0 / (self.iteration_counter + 1)\n",
    "\n",
    "        for group in self.param_groups:\n",
    "            tuning_c = group['tuning_c']\n",
    "\n",
    "            for para in group['params']:\n",
    "                if para.grad is None:\n",
    "                    continue\n",
    "\n",
    "                # Do: S_{t} = lmo_D (G_{t}), where G_{t} = gradient_f(W_{t}) and lmo_D(G_t) = -c*sign(G_t)\n",
    "                S_t = -tuning_c * torch.sign(para.grad.data)\n",
    "\n",
    "                # Do: W_{t+1} = W_{t} + eta_{t} * (S_{t} - W_{t})\n",
    "                para.data = para.data + eta_t * (S_t - para.data)\n",
    "\n",
    "        self.iteration_counter += 1\n",
    "        return loss"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "05b6be21-8ca2-433e-acc6-7c848e79dd10",
   "metadata": {
    "tags": []
   },
   "source": [
    "### <span style=\"color:blue\">Train / Validation framework</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7c3d8b6-e56f-455d-abf8-56910429eea0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def testFrame(model, optimizer, loss_fn, train_loader, val_loader, num_epochs, device):\n",
    "    \"\"\"Train `model` for `num_epochs` epochs and validate it after every epoch.\n",
    "\n",
    "    The model is moved to `device` first. One summary line is printed per epoch.\n",
    "\n",
    "    Returns:\n",
    "        (model, train_loss_epoch, val_loss_epoch, train_acc_epoch, val_acc_epoch);\n",
    "        each of the four lists holds one value per epoch.\n",
    "    \"\"\"\n",
    "    model.to(device)\n",
    "    history = {\"train_loss\": [], \"val_loss\": [], \"train_acc\": [], \"val_acc\": []}\n",
    "\n",
    "    for epoch in range(1, num_epochs + 1):\n",
    "        model, train_loss, train_acc = training_loop(model, optimizer, loss_fn, train_loader, device)\n",
    "        val_loss, val_acc = validation(model, loss_fn, val_loader, device)\n",
    "\n",
    "        history[\"train_loss\"].append(train_loss)\n",
    "        history[\"train_acc\"].append(train_acc)\n",
    "        history[\"val_loss\"].append(val_loss)\n",
    "        history[\"val_acc\"].append(val_acc)\n",
    "\n",
    "        print(f\"epoch {epoch}/{num_epochs} training_loss: {train_loss:.12f} val_loss: {val_loss:.12f} train_acc: {train_acc:.5f} val_acc: {val_acc:.5f}\")\n",
    "\n",
    "    return model, history[\"train_loss\"], history[\"val_loss\"], history[\"train_acc\"], history[\"val_acc\"]\n",
    "\n",
    "\n",
    "def training_loop(model, optimizer, loss_fn, train_loader, device):\n",
    "    \"\"\"Run one training epoch over `train_loader`.\n",
    "\n",
    "    Returns:\n",
    "        (model, mean of the per-batch losses, correct predictions divided by\n",
    "        the dataset size). Predictions come from the forward pass made before\n",
    "        the optimiser step of the same batch.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    batch_losses = []\n",
    "    num_correct = 0\n",
    "\n",
    "    for data, targets in train_loader:\n",
    "        data, targets = data.to(device), targets.to(device)\n",
    "        outputs = model(data)\n",
    "        loss = loss_fn(outputs, targets)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        batch_losses.append(loss.item())\n",
    "        predictions = outputs.argmax(1)\n",
    "        num_correct += (predictions == targets).sum().item()\n",
    "\n",
    "    avg_train_loss = sum(batch_losses) / len(train_loader)\n",
    "    train_accuracy = num_correct / len(train_loader.dataset)\n",
    "    return model, avg_train_loss, train_accuracy\n",
    "\n",
    "\n",
    "def validation(model, loss_fn, val_loader, device):\n",
    "    \"\"\"Evaluate `model` on `val_loader` in eval mode with gradients disabled.\n",
    "\n",
    "    Returns:\n",
    "        (mean of the per-batch losses, correct predictions divided by the\n",
    "        dataset size).\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    loss_sum = 0\n",
    "    num_correct = 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for data, targets in val_loader:\n",
    "            data, targets = data.to(device), targets.to(device)\n",
    "            outputs = model(data)\n",
    "\n",
    "            loss_sum += loss_fn(outputs, targets).item()\n",
    "            num_correct += (outputs.argmax(1) == targets).sum().item()\n",
    "\n",
    "    return loss_sum / len(val_loader), num_correct / len(val_loader.dataset)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5f91db5-870e-4c11-9448-3b5e17db5a7f",
   "metadata": {
    "tags": []
   },
   "source": [
    "### <span style=\"color:blue\">Helper Functions</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "356b66c5-7790-4a00-8617-1f8aa84bcd09",
   "metadata": {},
   "outputs": [],
   "source": [
    "# use to store loss or accuracy\n",
    "def store_metrics(a, b):\n",
    "    \"\"\"Add the series `b` to the collection `a` and return the collection.\n",
    "\n",
    "    A non-empty `a` is extended in place and returned. An empty `a` is left\n",
    "    untouched and a new one-element list `[b]` is returned instead.\n",
    "    \"\"\"\n",
    "    if len(a) > 0:\n",
    "        a.append(b)\n",
    "        return a\n",
    "    return [b]\n",
    "\n",
    "# Averaging results among various seeds\n",
    "def take_avg(list_in):\n",
    "    \"\"\"Average several series element-wise (one series per seed).\n",
    "\n",
    "    Args:\n",
    "        list_in: list of numeric sequences.\n",
    "\n",
    "    Returns:\n",
    "        A list with the element-wise mean over all series; series of different\n",
    "        lengths are cut to the shortest one. An empty `list_in` is returned as is.\n",
    "    \"\"\"\n",
    "    if len(list_in) == 0:\n",
    "        return list_in\n",
    "\n",
    "    totals = list_in[0]\n",
    "    for series in list_in[1:]:\n",
    "        totals = [acc + val for acc, val in zip(totals, series)]\n",
    "\n",
    "    count = len(list_in)\n",
    "    return [total / count for total in totals]\n",
    "\n",
    "# Averaging results of last few epochs\n",
    "def avg_stable(numbers, last_n):\n",
    "    \"\"\"Return the mean of the last `last_n` entries of `numbers`.\n",
    "\n",
    "    When `numbers` has fewer than `last_n` entries, all of them are averaged.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `numbers` is empty or `last_n` is smaller than 1.\n",
    "    \"\"\"\n",
    "    if not numbers:\n",
    "        raise ValueError(\"Void list\")\n",
    "    if last_n < 1:\n",
    "        # numbers[-0:] is the WHOLE list and a negative last_n would drop leading\n",
    "        # entries instead of selecting trailing ones, so reject both explicitly.\n",
    "        raise ValueError(f\"last_n must be a positive integer, got {last_n}\")\n",
    "\n",
    "    tail = numbers[-last_n:]\n",
    "    return sum(tail) / len(tail)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8524ec9f-9902-4cf4-9720-cc90b4ad16ad",
   "metadata": {
    "tags": []
   },
   "source": [
    "### <span style=\"color:red\">Experiment setting</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dbab915c-8155-44b0-8481-298c28e99b2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "reset_seed()\n",
    "\n",
    "batch_size = 256      # Integer mini-batch size, or 'full' to feed each dataset as one batch (test for GD)\n",
    "num_workers = 2\n",
    "\n",
    "# With 'full' every loader yields its whole dataset in a single batch\n",
    "train_batch_size = len(train_dataset) if batch_size == 'full' else batch_size\n",
    "val_batch_size = len(val_dataset) if batch_size == 'full' else batch_size\n",
    "train_dataloader = DataLoader(train_dataset, batch_size = train_batch_size, shuffle = True, num_workers = num_workers)\n",
    "val_dataloader = DataLoader(val_dataset, batch_size = val_batch_size, shuffle = False, num_workers = num_workers)\n",
    "\n",
    "\n",
    "num_epochs = 100      # Epochs per run\n",
    "num_seeds = 1         # Number of runs (one model seed each) that are averaged\n",
    "last_n = 5            # The final metrics are the mean over the last `last_n` epochs\n",
    "\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "time_record = []      # Receives start time, end time and duration of every experiment\n",
    "\n",
    "# DC_FW and FW configuration\n",
    "#----------------------\n",
    "alpha = 1             # DC_FW only  (1, 10, 100, 500, 1000)\n",
    "inner_upper = 10000   # DC_FW only (fixed)\n",
    "tuning_c = 1          # DC_FW and FW shared (1, 10, 100)\n",
    "#----------------------\n",
    "\n",
    "# Set result path (the folder name encodes the configuration above)\n",
    "result_path = './CIFAR10_CNN_Alpha_{0}_c_{1}_innerup_{2}_BS_{3}_E_{4}_S_{5}'.format(alpha, tuning_c, inner_upper, batch_size, num_epochs, num_seeds)\n",
    "\n",
    "if not os.path.exists(result_path):\n",
    "    # If the folder doesn't exist, create it\n",
    "    os.makedirs(result_path)\n",
    "    print(\"Result folder created successfully.\")\n",
    "else:\n",
    "    print(\"Result folder already exists.\")\n",
    "\n",
    "# Re-seed so that the model seeds do not depend on anything executed above\n",
    "reset_seed()\n",
    "# One distinct model-initialisation seed per run, drawn from 1..1023\n",
    "# (random.sample raises ValueError when num_seeds exceeds 1023)\n",
    "model_seed = random.sample(range(1, 1024), num_seeds)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae08cab5-837b-43a0-87bc-372c4a956915",
   "metadata": {
    "tags": []
   },
   "source": [
    "### <span style=\"color:blue\">Model_0: DC_FW</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "522b6ca9-fb65-4f7a-b0c6-199a4b59839d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One entry per seed; every entry is the per-epoch series returned by testFrame\n",
    "full_train_loss_DC_FW, full_val_loss_DC_FW = [], []\n",
    "full_train_acc_DC_FW, full_val_acc_DC_FW = [], []\n",
    "\n",
    "start_time = time.time()\n",
    "time_record.append(start_time)\n",
    "\n",
    "for i in range(num_seeds): # One full training run per model seed\n",
    "    reset_seed(public_seed = model_seed[i])\n",
    "    print(f\"Training model with DC_FW in {i + 1}th seed, model is initialised by the seed: {model_seed[i]}\\n\")\n",
    "\n",
    "    model = CIFAR10CNN()\n",
    "    optimizer = DC_FW_ICML(model.parameters(), alpha = alpha, tuning_c = tuning_c, inner_upper = inner_upper)\n",
    "    model, train_losses, val_losses, train_acc_e, val_acc_e = testFrame(model, optimizer, loss_fn,\n",
    "                                                                        train_dataloader, val_dataloader,\n",
    "                                                                        num_epochs, device)\n",
    "\n",
    "    full_train_loss_DC_FW = store_metrics(full_train_loss_DC_FW, train_losses)\n",
    "    full_val_loss_DC_FW = store_metrics(full_val_loss_DC_FW, val_losses)\n",
    "    full_train_acc_DC_FW = store_metrics(full_train_acc_DC_FW, train_acc_e)\n",
    "    full_val_acc_DC_FW = store_metrics(full_val_acc_DC_FW, val_acc_e)\n",
    "\n",
    "# Per-epoch mean over the seeds\n",
    "avg_train_loss_DC_FW = take_avg(full_train_loss_DC_FW)\n",
    "avg_val_loss_DC_FW = take_avg(full_val_loss_DC_FW)\n",
    "avg_train_acc_DC_FW = take_avg(full_train_acc_DC_FW)\n",
    "avg_val_acc_DC_FW = take_avg(full_val_acc_DC_FW)\n",
    "\n",
    "# Mean of the seed-averaged curves over the last `last_n` epochs\n",
    "final_train_loss_DC_FW = avg_stable(avg_train_loss_DC_FW, last_n = last_n)\n",
    "final_val_loss_DC_FW = avg_stable(avg_val_loss_DC_FW, last_n = last_n)\n",
    "final_train_acc_DC_FW = avg_stable(avg_train_acc_DC_FW, last_n = last_n)\n",
    "final_val_acc_DC_FW = avg_stable(avg_val_acc_DC_FW, last_n = last_n)\n",
    "\n",
    "end_time = time.time()\n",
    "duration = end_time - start_time\n",
    "time_record.extend([end_time, duration])\n",
    "\n",
    "del model # Can be commented "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0bea0f1c-383e-4f55-b37b-3775ee211f3b",
   "metadata": {
    "tags": []
   },
   "source": [
    "### <span style=\"color:blue\">Model_1: FW</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5357f079-8970-437a-8e68-0f1f045d5c66",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collections of FW results, updated through store_metrics after each seed's run.\n",
    "full_train_loss_FW, full_val_loss_FW = [], []\n",
    "full_train_acc_FW, full_val_acc_FW = [], []\n",
    "\n",
    "# time_record receives, in order: start timestamp, end timestamp, elapsed seconds.\n",
    "start_time = time.time()\n",
    "time_record.append(start_time)\n",
    "\n",
    "# One complete training run per seed, using the seeds listed in model_seed.\n",
    "for i in range(num_seeds):\n",
    "    # Re-seed with this run's seed before the model is constructed.\n",
    "    reset_seed(public_seed = model_seed[i])\n",
    "    print(f\"Training model with FW in {i + 1}th seed, model is initialised by the seed: {model_seed[i]}\\n\")\n",
    "\n",
    "    # Fresh network and a fresh FW optimizer for every seed.\n",
    "    model = CIFAR10CNN()\n",
    "    optimizer = FW_ICML(model.parameters(), tuning_c = tuning_c)\n",
    "\n",
    "    # testFrame hands back the model plus train/val losses and accuracies\n",
    "    # (presumably one value per epoch - confirm against testFrame).\n",
    "    model, train_losses, val_losses, train_acc_e, val_acc_e = testFrame(\n",
    "        model = model,\n",
    "        optimizer = optimizer,\n",
    "        loss_fn = loss_fn,\n",
    "        train_loader = train_dataloader,\n",
    "        val_loader = val_dataloader,\n",
    "        num_epochs = num_epochs,\n",
    "        device = device,\n",
    "    )\n",
    "\n",
    "    # Fold this seed's results into the per-seed collections.\n",
    "    full_train_loss_FW = store_metrics(full_train_loss_FW, train_losses)\n",
    "    full_val_loss_FW = store_metrics(full_val_loss_FW, val_losses)\n",
    "    full_train_acc_FW = store_metrics(full_train_acc_FW, train_acc_e)\n",
    "    full_val_acc_FW = store_metrics(full_val_acc_FW, val_acc_e)\n",
    "\n",
    "# Aggregate every metric over the seeds with take_avg.\n",
    "avg_train_loss_FW, avg_val_loss_FW = (\n",
    "    take_avg(full_train_loss_FW),\n",
    "    take_avg(full_val_loss_FW),\n",
    ")\n",
    "avg_train_acc_FW, avg_val_acc_FW = (\n",
    "    take_avg(full_train_acc_FW),\n",
    "    take_avg(full_val_acc_FW),\n",
    ")\n",
    "\n",
    "# Reduce each averaged curve with avg_stable\n",
    "# (presumably the mean of its last `last_n` entries - confirm against avg_stable).\n",
    "final_train_loss_FW, final_val_loss_FW = (\n",
    "    avg_stable(avg_train_loss_FW, last_n = last_n),\n",
    "    avg_stable(avg_val_loss_FW, last_n = last_n),\n",
    ")\n",
    "final_train_acc_FW, final_val_acc_FW = (\n",
    "    avg_stable(avg_train_acc_FW, last_n = last_n),\n",
    "    avg_stable(avg_val_acc_FW, last_n = last_n),\n",
    ")\n",
    "\n",
    "# Stop the clock after the aggregation above, then log end timestamp and duration.\n",
    "end_time = time.time()\n",
    "duration = end_time - start_time\n",
    "time_record.append(end_time)\n",
    "time_record.append(duration)\n",
    "\n",
    "del model  # optional: drops the name bound to the last trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ec16d58-de3f-427c-adbe-b07fed19662a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final cell of the notebook: prints a completion message when execution reaches this point.\n",
    "print(\"all done\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
