{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torchvision import datasets, transforms\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_seed(seed):\n",
    "    \"\"\"Seed NumPy's global RNG and PyTorch's RNGs (CPU, plus every CUDA device when one is available).\"\"\"\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    if not torch.cuda.is_available():\n",
    "        return\n",
    "    torch.cuda.manual_seed_all(seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MNIST_MLP(nn.Module):\n",
    "    \"\"\"Fully connected classifier: inputs flattened to 784 (28*28*1) features -> 512 -> 256 -> 10 logits.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(28*28*1, 512)\n",
    "        self.fc2 = nn.Linear(512, 256)\n",
    "        self.fc3 = nn.Linear(256, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Flatten each sample to a 784-vector; the output is raw logits (no softmax).\n",
    "        hidden = F.relu(self.fc1(x.view(-1, 28*28*1)))\n",
    "        hidden = F.relu(self.fc2(hidden))\n",
    "        return self.fc3(hidden)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CIFAR10_MLP(nn.Module):\n",
    "    \"\"\"Fully connected classifier: inputs flattened to 3072 (32*32*3) features -> 512 -> 256 -> 10 logits.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(32*32*3, 512)\n",
    "        self.fc2 = nn.Linear(512, 256)\n",
    "        self.fc3 = nn.Linear(256, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Flatten each sample to a 3072-vector; the output is raw logits (no softmax).\n",
    "        hidden = F.relu(self.fc1(x.view(-1, 32*32*3)))\n",
    "        hidden = F.relu(self.fc2(hidden))\n",
    "        return self.fc3(hidden)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MNIST_CNN(nn.Module):\n",
    "    \"\"\"Two conv/ReLU/max-pool stages followed by a three-layer MLP head (10 logits).\n",
    "\n",
    "    The flatten to 64 * 7 * 7 features implies 28x28 inputs: both convs preserve the\n",
    "    spatial size (5x5 with padding 2, 3x3 with padding 1) and each 2x2 pool halves it.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 32, 5, padding=2)\n",
    "        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)\n",
    "        self.pool = nn.MaxPool2d(2, 2)\n",
    "        self.fc1 = nn.Linear(64 * 7 * 7, 128)\n",
    "        self.fc2 = nn.Linear(128, 64)\n",
    "        self.fc3 = nn.Linear(64, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        for conv in (self.conv1, self.conv2):\n",
    "            x = self.pool(F.relu(conv(x)))\n",
    "        x = x.view(-1, 64 * 7 * 7)\n",
    "        for fc in (self.fc1, self.fc2):\n",
    "            x = F.relu(fc(x))\n",
    "        return self.fc3(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CIFAR10_CNN(nn.Module):\n",
    "    \"\"\"Three conv/ReLU/max-pool stages (3 -> 64 -> 128 -> 256 channels) and a three-layer MLP head.\n",
    "\n",
    "    The flatten to 256 * 4 * 4 features implies 32x32 inputs: every conv preserves the\n",
    "    spatial size and each of the three 2x2 pools halves it (32 -> 16 -> 8 -> 4).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv1 = nn.Conv2d(3, 64, 5, padding=2)\n",
    "        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)\n",
    "        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)\n",
    "        self.pool = nn.MaxPool2d(2, 2)\n",
    "        self.fc1 = nn.Linear(256 * 4 * 4, 512)\n",
    "        self.fc2 = nn.Linear(512, 256)\n",
    "        self.fc3 = nn.Linear(256, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        for conv in (self.conv1, self.conv2, self.conv3):\n",
    "            x = self.pool(F.relu(conv(x)))\n",
    "        x = x.view(-1, 256 * 4 * 4)\n",
    "        for fc in (self.fc1, self.fc2):\n",
    "            x = F.relu(fc(x))\n",
    "        return self.fc3(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MNIST_VGG(nn.Module):\n",
    "    \"\"\"Compact VGG-style classifier: three 3x3-conv blocks (2, 2 and 3 convs) and an MLP head.\n",
    "\n",
    "    Every conv keeps the spatial size (kernel 3, padding 1) and every block ends in a 2x2\n",
    "    max-pool, so a 28x28 input shrinks 28 -> 14 -> 7 -> 3, matching the 256 * 3 * 3 flatten.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        # Block 1: 1 -> 64 channels\n",
    "        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)\n",
    "        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)\n",
    "        self.pool1 = nn.MaxPool2d(2, 2)\n",
    "\n",
    "        # Block 2: 64 -> 128 channels\n",
    "        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)\n",
    "        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)\n",
    "        self.pool2 = nn.MaxPool2d(2, 2)\n",
    "\n",
    "        # Block 3: 128 -> 256 channels\n",
    "        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=1)\n",
    "        self.conv6 = nn.Conv2d(256, 256, kernel_size=3, padding=1)\n",
    "        self.conv7 = nn.Conv2d(256, 256, kernel_size=3, padding=1)\n",
    "        self.pool3 = nn.MaxPool2d(2, 2)\n",
    "\n",
    "        # Classifier head on the flattened 256x3x3 feature map\n",
    "        self.fc1 = nn.Linear(256 * 3 * 3, 512)\n",
    "        self.fc2 = nn.Linear(512, 512)\n",
    "        self.fc3 = nn.Linear(512, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        stages = (\n",
    "            ((self.conv1, self.conv2), self.pool1),\n",
    "            ((self.conv3, self.conv4), self.pool2),\n",
    "            ((self.conv5, self.conv6, self.conv7), self.pool3),\n",
    "        )\n",
    "        for convs, pool in stages:\n",
    "            for conv in convs:\n",
    "                x = F.relu(conv(x))\n",
    "            x = pool(x)\n",
    "\n",
    "        x = x.view(-1, 256 * 3 * 3)\n",
    "        for fc in (self.fc1, self.fc2):\n",
    "            x = F.relu(fc(x))\n",
    "        return self.fc3(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Batch Selection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ADAMBS\n",
    "def batch_selection_adambs(w_t, batch_size, gamma):\n",
    "    \"\"\"Draw `batch_size` indices with replacement from a gamma-smoothed weight distribution.\n",
    "\n",
    "    p_t = (1 - gamma) * w_t / sum(w_t) + gamma / n mixes the normalized weights with the\n",
    "    uniform distribution (so, for non-negative weights, every index keeps probability >= gamma / n).\n",
    "\n",
    "    Returns:\n",
    "        (indices, p_t)\n",
    "    \"\"\"\n",
    "    num_samples = len(w_t)\n",
    "    uniform_part = gamma / num_samples\n",
    "    p_t = (1.0-gamma) * w_t / np.sum(w_t) + uniform_part\n",
    "    indices = np.random.choice(num_samples, size=batch_size, replace=True, p=p_t)\n",
    "    return indices, p_t\n",
    "\n",
    "# ADAMCB\n",
    "def distr_multiPlays(weights, K, gamma=0.0):\n",
    "    \"\"\"Return K * ((1 - gamma) * weights / sum(weights) + gamma / n).\n",
    "\n",
    "    This is the gamma-smoothed distribution over the n entries scaled by K, so the\n",
    "    returned vector sums to K (the batch size) rather than to 1.\n",
    "    \"\"\"\n",
    "    n_arms = len(weights)\n",
    "    smoothed = (1.0 - gamma) * (weights / np.sum(weights)) + (gamma / n_arms)\n",
    "    return K * smoothed\n",
    "\n",
    "\n",
    "def getAlpha(temp, w_sorted):\n",
    "    \"\"\"Compute the cap value alpha used to clip the largest weights.\n",
    "\n",
    "    ``w_sorted`` is expected in descending order. For position i the candidate is\n",
    "    alpha_i = temp * sum(w_sorted[i:]) / (1 - i * temp); the first candidate that is\n",
    "    strictly greater than w_sorted[i] is returned.\n",
    "\n",
    "    Raises:\n",
    "        Exception: if no position yields such a candidate.\n",
    "    \"\"\"\n",
    "    tail_sum = np.sum(w_sorted)\n",
    "    for i, weight in enumerate(w_sorted):\n",
    "        candidate = (temp * tail_sum) / (1.0 - i * temp)\n",
    "        if candidate > weight:\n",
    "            return candidate\n",
    "        tail_sum = tail_sum - weight\n",
    "    raise Exception(\"alpha not found\")\n",
    "\n",
    "\n",
    "def find_indices(arr, condition):\n",
    "    \"\"\"Return the first-axis indices where the vectorized predicate ``condition(arr)`` is truthy.\"\"\"\n",
    "    mask = condition(arr)\n",
    "    return np.where(mask)[0]\n",
    "\n",
    "\n",
    "def DepRound(weights_p, k=1):\n",
    "    \"\"\"Sample ``k`` distinct indices from ``weights_p`` (rescaled to sum to 1 if needed).\n",
    "\n",
    "    Raises AssertionError if k >= len(weights_p), or if the rescaled weights are not a\n",
    "    valid probability vector.\n",
    "\n",
    "    NOTE(review): despite its name this is not the dependent-rounding procedure; it\n",
    "    delegates to ``np.random.choice(..., replace=False)``, whose per-index inclusion\n",
    "    probabilities are generally not k * p_i. Confirm this is acceptable for ADAMCB.\n",
    "    \"\"\"\n",
    "    n = len(weights_p)\n",
    "    assert k < n, f\"Error (DepRound): k = {k} should be < n = {n}.\"\n",
    "    total = np.sum(weights_p)\n",
    "    p = weights_p if np.isclose(total, 1) else weights_p / total\n",
    "    assert np.all(0 <= p) and np.all(p <= 1), f\"Error: the weights (p_1, ..., p_K) should all be 0 <= p_i <= 1 ...(={p})\"\n",
    "    assert np.isclose(np.sum(p), 1), f\"Error: the sum of weights p_1 + ... + p_K should be = 1 (= {np.sum(p)})\"\n",
    "    return np.random.choice(n, size=k, replace=False, p=p)\n",
    "\n",
    "\n",
    "def batch_selection_adamcb(weights, K, gamma):\n",
    "    \"\"\"Pick K distinct sample indices from capped, gamma-smoothed weights (ADAMCB).\n",
    "\n",
    "    1. If the largest weight would give its index a multi-play probability\n",
    "       K * p_i >= 1, every weight >= alpha_t (from ``getAlpha``) is capped at alpha_t;\n",
    "       those indices are returned as ``S_null``.\n",
    "    2. ``distr_multiPlays`` turns the (capped) weights into probabilities summing to K.\n",
    "    3. ``DepRound`` draws K distinct indices from them.\n",
    "\n",
    "    Returns:\n",
    "        (indices, p_t, S_null); ``S_null`` is an empty list when nothing was capped.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if any entry of p_t is NaN.\n",
    "    \"\"\"\n",
    "    n = len(weights)\n",
    "    # 1. cap the largest weights if needed\n",
    "    theSum = np.sum(weights)\n",
    "    # Normalized-weight threshold: w_i / sum(w) >= temp  <=>  K * p_i >= 1.\n",
    "    temp = (1.0 / K - gamma / n) * float(1.0 / (1.0 - gamma))\n",
    "    # Float copy, so assigning the fractional cap can never be truncated to an integer.\n",
    "    w_temp = np.array(weights, dtype=float)\n",
    "    if np.max(weights) >= temp * theSum:\n",
    "        w_sorted = np.sort(weights)[::-1]\n",
    "        alpha_t = getAlpha(temp, w_sorted)\n",
    "        S_null = find_indices(w_temp, lambda e: e >= alpha_t)\n",
    "        w_temp[S_null] = alpha_t\n",
    "    else:\n",
    "        S_null = []\n",
    "    # 2. compute the probability\n",
    "    p_t = distr_multiPlays(w_temp, K, gamma=gamma)\n",
    "    # Every entry must be a real number (`False in np.isnan(p_t)` only required one).\n",
    "    assert not np.isnan(p_t).any(), \"Error, probability must be a real number\"\n",
    "    # 3. sample K distinct indices\n",
    "    indices = DepRound(p_t, k=K)\n",
    "    return indices, p_t, S_null"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Compute Unbiased Gradient Estimate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_unbiased_gradient_estimate(model, loss_function, train_dataset, indices, p_t, device):\n",
    "    \"\"\"Importance-weighted mini-batch gradient built from per-sample backward passes.\n",
    "\n",
    "    Each sampled index k contributes its own gradient g_k scaled by 1 / (n * p_t[k]);\n",
    "    the result is G = (1 / K) * sum_k g_k / (n * p_t[k]) with K = len(indices) and\n",
    "    n = len(train_dataset). The ``.grad`` fields of ``model`` are overwritten along the way.\n",
    "\n",
    "    Returns:\n",
    "        (G_t_hat, gradient_norms): one tensor per model parameter, and the L2 norm of\n",
    "        every per-sample gradient (a Python float per sampled index).\n",
    "    \"\"\"\n",
    "    dataset_size = len(train_dataset)\n",
    "    batch_len = len(indices)\n",
    "    params = list(model.parameters())\n",
    "    accum = [torch.zeros_like(param.data, device=device) for param in params]\n",
    "    norms = []\n",
    "\n",
    "    for sample_idx in indices:\n",
    "        inputs, label = train_dataset[sample_idx]\n",
    "        inputs = inputs.to(device).unsqueeze(0)\n",
    "        label = torch.tensor([label], device=device, dtype=torch.long)\n",
    "\n",
    "        model.zero_grad()\n",
    "        loss_function(model(inputs), label).backward()\n",
    "\n",
    "        scale = dataset_size * p_t[sample_idx]\n",
    "        norm_sq = 0\n",
    "        with torch.no_grad():\n",
    "            for slot, param in zip(accum, model.parameters()):\n",
    "                slot += param.grad.data / scale\n",
    "                norm_sq += (param.grad.data ** 2).sum()\n",
    "            grad_norm = torch.sqrt(norm_sq)\n",
    "        norms.append(grad_norm.item())\n",
    "\n",
    "    G_t_hat = [total / batch_len for total in accum]\n",
    "    return G_t_hat, norms"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Update Model Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def update_model_parameters_adam(model, G_t_hat, m_t, v_t, v_hat, t, lr, betas, lambda_, eps, device):\n",
    "    \"\"\"One bias-corrected Adam step driven by the supplied gradient estimate ``G_t_hat``.\n",
    "\n",
    "    ``m_t`` and ``v_t`` (lists aligned with ``model.parameters()``) are updated in place and\n",
    "    the step size decays as lr / sqrt(t). ``v_hat`` and ``lambda_`` are not used here; they\n",
    "    keep the signature identical to the AMSGrad/AdamX update functions.\n",
    "    \"\"\"\n",
    "    beta1, beta2 = betas[0], betas[1]\n",
    "    step_size = lr / t**0.5\n",
    "    m_correction = 1 - beta1 ** t\n",
    "    v_correction = 1 - beta2 ** t\n",
    "    with torch.no_grad():\n",
    "        for i, param in enumerate(model.parameters()):\n",
    "            grad = G_t_hat[i]\n",
    "            m_t[i] = beta1 * m_t[i] + (1 - beta1) * grad\n",
    "            v_t[i] = beta2 * v_t[i] + (1 - beta2) * (grad ** 2)\n",
    "            m_unbiased = m_t[i] / m_correction\n",
    "            v_unbiased = v_t[i] / v_correction\n",
    "            step = step_size * m_unbiased / (torch.sqrt(v_unbiased) + eps)\n",
    "            param.data -= step.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def update_model_parameters_amsgrad(model, G_t_hat, m_t, v_t, v_hat, t, lr, betas, lambda_, eps, device):\n",
    "    \"\"\"One AMSGrad-style step: Adam moments plus a running element-wise max ``v_hat`` of v_t.\n",
    "\n",
    "    ``m_t``, ``v_t`` and ``v_hat`` are updated in place. Both m_t and the running max are\n",
    "    bias-corrected, and the step size decays as lr / sqrt(t). ``lambda_`` is unused.\n",
    "    \"\"\"\n",
    "    beta1, beta2 = betas[0], betas[1]\n",
    "    step_size = lr / t**0.5\n",
    "    m_correction = 1 - beta1 ** t\n",
    "    v_correction = 1 - beta2 ** t\n",
    "    with torch.no_grad():\n",
    "        for i, param in enumerate(model.parameters()):\n",
    "            grad = G_t_hat[i]\n",
    "            m_t[i] = beta1 * m_t[i] + (1 - beta1) * grad\n",
    "            v_t[i] = beta2 * v_t[i] + (1 - beta2) * (grad ** 2)\n",
    "            v_hat[i] = torch.max(v_hat[i], v_t[i])\n",
    "            denom = torch.sqrt(v_hat[i] / v_correction) + eps\n",
    "            param.data -= (step_size * (m_t[i] / m_correction) / denom).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def update_model_parameters_adamx(model, G_t_hat, m_t, v_t, v_hat, t, lr, betas, lambda_, eps, device):\n",
    "    \"\"\"One AdamX-style step with a decaying first-moment coefficient.\n",
    "\n",
    "    beta1_t = betas[0] * lambda_ ** (t - 1). At t == 1 ``v_hat`` is set to v_t; afterwards\n",
    "    it becomes max((1 - beta1_t)**2 / (1 - beta1_{t-1})**2 * v_hat, v_t). ``m_t``, ``v_t``\n",
    "    and ``v_hat`` are updated in place; the step size decays as lr / sqrt(t).\n",
    "    \"\"\"\n",
    "    beta1 = betas[0] * lambda_ ** (t-1)\n",
    "    beta1_prev = betas[0] * lambda_ ** (t-2)\n",
    "    beta2 = betas[1]\n",
    "    step_size = lr / t**0.5\n",
    "    first_step = t == 1\n",
    "    with torch.no_grad():\n",
    "        for i, param in enumerate(model.parameters()):\n",
    "            grad = G_t_hat[i]\n",
    "            m_t[i] = beta1 * m_t[i] + (1 - beta1) * grad\n",
    "            v_t[i] = beta2 * v_t[i] + (1 - beta2) * (grad ** 2)\n",
    "            if first_step:\n",
    "                v_hat[i] = v_t[i]\n",
    "            else:\n",
    "                # Rescale the previous running max before comparing with the new v_t.\n",
    "                shrink = (1-beta1)**2/(1-beta1_prev)**2\n",
    "                v_hat[i] = torch.max(shrink * v_hat[i], v_t[i])\n",
    "\n",
    "            m_unbiased = m_t[i] / (1 - beta1 ** t)\n",
    "            v_unbiased = v_hat[i] / (1 - beta2 ** t)\n",
    "            step = step_size * m_unbiased / (torch.sqrt(v_unbiased) + eps)\n",
    "            param.data -= step.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Update Sample Weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def update_sample_weights_BS(indices, w_t, p_t, sample_gradient_norms, gamma, alpha_p):\n",
    "    \"\"\"Exponential-weights update of the ADAMBS sample weights.\n",
    "\n",
    "    With p_min = gamma / n, each draw of index idx with gradient norm g adds\n",
    "    (-(g**2 / p_t[idx]**2) + 1 / p_min**2) / (K * p_t[idx]) to h_hat[idx] (repeated draws\n",
    "    accumulate). Returns the new weights w_t * exp(-alpha_p * h_hat); ``w_t`` itself is\n",
    "    not modified.\n",
    "    \"\"\"\n",
    "    n = len(p_t)\n",
    "    K = len(indices)\n",
    "    inv_p_min_sq = 1.0 / (gamma / n)**2\n",
    "    h_hat = np.zeros(n)\n",
    "    for pos, idx in enumerate(indices):\n",
    "        g = sample_gradient_norms[pos]\n",
    "        estimated_loss = - (g**2 / p_t[idx]**2) + inv_p_min_sq\n",
    "        h_hat[idx] += estimated_loss / (K*p_t[idx])\n",
    "    return w_t * np.exp(-alpha_p * h_hat)\n",
    "\n",
    "\n",
    "def update_sample_weights_CB(indices, w_t, p_t, sample_gradient_norms, gamma, alpha_p, S_null):\n",
    "    \"\"\"Exponential-weights update of the ADAMCB sample weights.\n",
    "\n",
    "    With p_min = gamma / n, each sampled index idx gets\n",
    "    h_hat[idx] = (-g**2 / p_t[idx]**2 + 1 / p_min**2) / p_t[idx] (a later draw of the same\n",
    "    index overwrites an earlier one). Weights of the capped indices in ``S_null`` keep their\n",
    "    previous value. Returns a new array; ``w_t`` is not modified.\n",
    "    \"\"\"\n",
    "    n = len(p_t)\n",
    "    inv_p_min_sq = 1.0 / (gamma / n)**2\n",
    "    h_hat = np.zeros(n)\n",
    "    for pos, idx in enumerate(indices):\n",
    "        g = sample_gradient_norms[pos]\n",
    "        h_hat[idx] = (- g**2 / p_t[idx]**2 + inv_p_min_sq) / p_t[idx]\n",
    "    updated = w_t * np.exp(-alpha_p * h_hat)\n",
    "    for idx in S_null:\n",
    "        updated[idx] = w_t[idx]\n",
    "    return updated\n",
    "\n",
    "\n",
    "def update_sample_distribution(indices, p_t, sample_gradient_norms, gamma, L, alpha_p):\n",
    "    \"\"\"Exponentially reweight the sampling distribution ``p_t`` and renormalise it.\n",
    "\n",
    "    With p_min = gamma / n, each draw of index idx with gradient norm g adds\n",
    "    (-(g**2 / p_t[idx]**2) + L**2 / p_min**2) / (K * p_t[idx]) to h_hat[idx]. The exponent\n",
    "    -alpha_p * h_hat is clipped to [-100, 100] before ``np.exp`` to avoid overflow.\n",
    "    Returns a new distribution that sums to 1.\n",
    "    \"\"\"\n",
    "    n = len(p_t)\n",
    "    K = len(indices)\n",
    "    bound = L**2 / (gamma / n)**2\n",
    "    h_hat = np.zeros(n)\n",
    "    for pos, idx in enumerate(indices):\n",
    "        g = sample_gradient_norms[pos]\n",
    "        estimated_loss = - (g**2 / p_t[idx]**2) + bound\n",
    "        h_hat[idx] += estimated_loss / (K*p_t[idx])\n",
    "\n",
    "    exponent = np.clip(-alpha_p * h_hat, a_min=-100, a_max=100)\n",
    "    unnormalised = p_t * np.exp(exponent)\n",
    "    return unnormalised / unnormalised.sum()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate1(model, loss_function, train_dataset, device, test_subset_size=1000):\n",
    "    \"\"\"Average loss and accuracy (%) over the first ``test_subset_size`` samples of a dataset.\n",
    "\n",
    "    Samples are evaluated one at a time under ``torch.no_grad()``; the model's train/eval\n",
    "    mode is not changed. The subset is clamped to ``len(train_dataset)``, so a dataset\n",
    "    with fewer than ``test_subset_size`` samples is evaluated in full instead of raising\n",
    "    IndexError, and the averages are taken over the samples actually evaluated.\n",
    "\n",
    "    Returns:\n",
    "        (avg_loss, accuracy_percent)\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if there is nothing to evaluate (empty dataset or non-positive size).\n",
    "    \"\"\"\n",
    "    num_eval = min(test_subset_size, len(train_dataset))\n",
    "    if num_eval <= 0:\n",
    "        raise ValueError(\"evaluate1: nothing to evaluate (empty dataset or test_subset_size <= 0)\")\n",
    "    test_loss = 0.0\n",
    "    test_correct = 0\n",
    "    with torch.no_grad():\n",
    "        for idx in range(num_eval):\n",
    "            data, target = train_dataset[idx]\n",
    "            data, target = data.to(device).unsqueeze(0), torch.tensor([target], device=device, dtype=torch.long)\n",
    "            output = model(data)\n",
    "            test_loss += loss_function(output, target).item()\n",
    "            _, predicted = torch.max(output, 1)\n",
    "            test_correct += (predicted == target).sum().item()\n",
    "    avg_test_loss = test_loss / num_eval\n",
    "    test_accuracy = 100.0 * test_correct / num_eval\n",
    "    return avg_test_loss, test_accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate2(model, loss_function, train_dataset, device, test_subset_size=1000):\n",
    "    \"\"\"Average loss and accuracy (%) over a random subset of ``test_subset_size`` samples.\n",
    "\n",
    "    The subset is drawn without replacement using NumPy's global RNG, so each call sees a\n",
    "    different subset unless the seed is reset. It is clamped to ``len(train_dataset)``, so a\n",
    "    smaller dataset is evaluated in full instead of making ``np.random.choice`` raise.\n",
    "    Samples are evaluated one at a time under ``torch.no_grad()``; the model's mode is\n",
    "    not changed.\n",
    "\n",
    "    Returns:\n",
    "        (avg_loss, accuracy_percent)\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if there is nothing to evaluate (empty dataset or non-positive size).\n",
    "    \"\"\"\n",
    "    num_eval = min(test_subset_size, len(train_dataset))\n",
    "    if num_eval <= 0:\n",
    "        raise ValueError(\"evaluate2: nothing to evaluate (empty dataset or test_subset_size <= 0)\")\n",
    "    test_subset_indices = np.random.choice(len(train_dataset), num_eval, replace=False)\n",
    "    test_loss = 0.0\n",
    "    test_correct = 0\n",
    "    with torch.no_grad():\n",
    "        for idx in test_subset_indices:\n",
    "            data, target = train_dataset[idx]\n",
    "            data, target = data.to(device).unsqueeze(0), torch.tensor([target], device=device, dtype=torch.long)\n",
    "            output = model(data)\n",
    "            test_loss += loss_function(output, target).item()\n",
    "            _, predicted = torch.max(output, 1)\n",
    "            test_correct += (predicted == target).sum().item()\n",
    "    avg_test_loss = test_loss / num_eval\n",
    "    test_accuracy = 100.0 * test_correct / num_eval\n",
    "    return avg_test_loss, test_accuracy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Algorithms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ADAM(model, loss_function, train_dataset, test_dataset, K, max_iter, device, lr=0.001, betas=(0.9, 0.999), eps=1e-8, lambda_ = 1-1e-8):\n",
    "    \"\"\"Train ``model`` with Adam on uniformly sampled mini-batches of size K.\n",
    "\n",
    "    gamma = 1 with constant weights makes ``batch_selection_adambs`` sample uniformly\n",
    "    (with replacement). At t == 1 and every ``one_epoch`` iterations it records the train\n",
    "    metrics from ``evaluate1`` (taken before the parameter update), the test metrics from\n",
    "    ``evaluate2`` (taken after it) and the elapsed wall-clock time.\n",
    "\n",
    "    Returns:\n",
    "        (train_losses, test_losses, train_accuracies, test_accuracies, times)\n",
    "    \"\"\"\n",
    "    n = len(train_dataset)\n",
    "    # max(1, ...) keeps `t % one_epoch` well-defined when K > n (n // K would be 0).\n",
    "    one_epoch = max(1, n // K)\n",
    "\n",
    "    gamma = 1.0\n",
    "\n",
    "    # First/second moment estimates (v_hat is unused by the Adam update but part of its signature)\n",
    "    params = list(model.parameters())\n",
    "    m_t = [torch.zeros_like(p.data, device=device) for p in params]\n",
    "    v_t = [torch.zeros_like(p.data, device=device) for p in params]\n",
    "    v_hat = [torch.zeros_like(p.data, device=device) for p in params]\n",
    "    # Constant sample weights -> uniform sampling distribution\n",
    "    w_t = np.ones(n)\n",
    "\n",
    "    train_losses, train_accuracies, test_losses, test_accuracies, times = [], [], [], [], []\n",
    "    start_time = time.time()\n",
    "    for t in tqdm(range(1, max_iter + 1), desc=\"Training Progress\"):\n",
    "        model.train()\n",
    "        # Select a mini-batch I_t by sampling with replacement from p_t\n",
    "        indices, p_t = batch_selection_adambs(w_t, K, gamma)\n",
    "        G_t_hat, gradient_norms = compute_unbiased_gradient_estimate(model, loss_function, train_dataset, indices, p_t, device)\n",
    "        # NOTE(review): evaluate1 runs on every iteration although its result is only kept on\n",
    "        # logging iterations, so its cost is included in `times` (kept for parity across trainers).\n",
    "        train_loss, train_accuracy = evaluate1(model, loss_function, train_dataset, device)\n",
    "\n",
    "        update_model_parameters_adam(model, G_t_hat, m_t, v_t, v_hat, t, lr, betas, lambda_, eps, device)\n",
    "\n",
    "        if t == 1 or t % one_epoch == 0:\n",
    "            train_losses.append(train_loss)\n",
    "            train_accuracies.append(train_accuracy)\n",
    "            times.append(time.time() - start_time)\n",
    "            model.eval()\n",
    "            test_loss, test_accuracy = evaluate2(model, loss_function, test_dataset, device)\n",
    "            test_losses.append(test_loss)\n",
    "            test_accuracies.append(test_accuracy)\n",
    "            print(f\"Iteration: {t}, Time: {times[-1]:.2f}(s), Train Loss: {train_losses[-1]}, \"\n",
    "                f\"Test Loss: {test_losses[-1]}, Train Accuracy: {train_accuracies[-1]}, \"\n",
    "                f\"Test Accuracy: {test_accuracies[-1]}\")\n",
    "\n",
    "    return train_losses, test_losses, train_accuracies, test_accuracies, times"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def AMSGrad(model, loss_function, train_dataset, test_dataset, K, max_iter, device, lr=0.001, betas=(0.9, 0.999), eps=1e-8, lambda_ = 1-1e-8):\n",
    "    \"\"\"Train ``model`` with AMSGrad-style updates on uniformly sampled mini-batches of size K.\n",
    "\n",
    "    gamma = 1 with constant weights makes ``batch_selection_adambs`` sample uniformly\n",
    "    (with replacement). At t == 1 and every ``one_epoch`` iterations it records the train\n",
    "    metrics from ``evaluate1`` (taken before the parameter update), the test metrics from\n",
    "    ``evaluate2`` (taken after it) and the elapsed wall-clock time.\n",
    "\n",
    "    Returns:\n",
    "        (train_losses, test_losses, train_accuracies, test_accuracies, times)\n",
    "    \"\"\"\n",
    "    n = len(train_dataset)\n",
    "    # max(1, ...) keeps `t % one_epoch` well-defined when K > n (n // K would be 0).\n",
    "    one_epoch = max(1, n // K)\n",
    "\n",
    "    gamma = 1.0\n",
    "\n",
    "    # First/second moment estimates and the running max of the second moment\n",
    "    params = list(model.parameters())\n",
    "    m_t = [torch.zeros_like(p.data, device=device) for p in params]\n",
    "    v_t = [torch.zeros_like(p.data, device=device) for p in params]\n",
    "    v_hat = [torch.zeros_like(p.data, device=device) for p in params]\n",
    "    # Constant sample weights -> uniform sampling distribution\n",
    "    w_t = np.ones(n)\n",
    "\n",
    "    train_losses, train_accuracies, test_losses, test_accuracies, times = [], [], [], [], []\n",
    "    start_time = time.time()\n",
    "    for t in tqdm(range(1, max_iter + 1), desc=\"Training Progress\"):\n",
    "        model.train()\n",
    "        # Select a mini-batch I_t by sampling with replacement from p_t\n",
    "        indices, p_t = batch_selection_adambs(w_t, K, gamma)\n",
    "        G_t_hat, gradient_norms = compute_unbiased_gradient_estimate(model, loss_function, train_dataset, indices, p_t, device)\n",
    "        # NOTE(review): evaluate1 runs on every iteration although its result is only kept on\n",
    "        # logging iterations, so its cost is included in `times` (kept for parity across trainers).\n",
    "        train_loss, train_accuracy = evaluate1(model, loss_function, train_dataset, device)\n",
    "\n",
    "        update_model_parameters_amsgrad(model, G_t_hat, m_t, v_t, v_hat, t, lr, betas, lambda_, eps, device)\n",
    "\n",
    "        if t == 1 or t % one_epoch == 0:\n",
    "            train_losses.append(train_loss)\n",
    "            train_accuracies.append(train_accuracy)\n",
    "            times.append(time.time() - start_time)\n",
    "            model.eval()\n",
    "            test_loss, test_accuracy = evaluate2(model, loss_function, test_dataset, device)\n",
    "            test_losses.append(test_loss)\n",
    "            test_accuracies.append(test_accuracy)\n",
    "            print(f\"Iteration: {t}, Time: {times[-1]:.2f}(s), Train Loss: {train_losses[-1]}, \"\n",
    "                f\"Test Loss: {test_losses[-1]}, Train Accuracy: {train_accuracies[-1]}, \"\n",
    "                f\"Test Accuracy: {test_accuracies[-1]}\")\n",
    "\n",
    "    return train_losses, test_losses, train_accuracies, test_accuracies, times"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ADAMX(model, loss_function, train_dataset, test_dataset, K, max_iter, device, lr=0.001, betas=(0.9, 0.999), eps=1e-8, lambda_ = 1-1e-8):\n",
    "    \"\"\"Train ``model`` with AdamX-style updates (beta1 decayed by lambda_) on uniform mini-batches.\n",
    "\n",
    "    gamma = 1 with constant weights makes ``batch_selection_adambs`` sample uniformly\n",
    "    (with replacement). At t == 1 and every ``one_epoch`` iterations it records the train\n",
    "    metrics from ``evaluate1`` (taken before the parameter update), the test metrics from\n",
    "    ``evaluate2`` (taken after it) and the elapsed wall-clock time.\n",
    "\n",
    "    Returns:\n",
    "        (train_losses, test_losses, train_accuracies, test_accuracies, times)\n",
    "    \"\"\"\n",
    "    n = len(train_dataset)\n",
    "    # max(1, ...) keeps `t % one_epoch` well-defined when K > n (n // K would be 0).\n",
    "    one_epoch = max(1, n // K)\n",
    "\n",
    "    gamma = 1.0\n",
    "\n",
    "    # First/second moment estimates and the rescaled running max of the second moment\n",
    "    params = list(model.parameters())\n",
    "    m_t = [torch.zeros_like(p.data, device=device) for p in params]\n",
    "    v_t = [torch.zeros_like(p.data, device=device) for p in params]\n",
    "    v_hat = [torch.zeros_like(p.data, device=device) for p in params]\n",
    "    # Constant sample weights -> uniform sampling distribution\n",
    "    w_t = np.ones(n)\n",
    "\n",
    "    train_losses, train_accuracies, test_losses, test_accuracies, times = [], [], [], [], []\n",
    "    start_time = time.time()\n",
    "    for t in tqdm(range(1, max_iter + 1), desc=\"Training Progress\"):\n",
    "        model.train()\n",
    "        # Select a mini-batch I_t by sampling with replacement from p_t\n",
    "        indices, p_t = batch_selection_adambs(w_t, K, gamma)\n",
    "        G_t_hat, gradient_norms = compute_unbiased_gradient_estimate(model, loss_function, train_dataset, indices, p_t, device)\n",
    "        # NOTE(review): evaluate1 runs on every iteration although its result is only kept on\n",
    "        # logging iterations, so its cost is included in `times` (kept for parity across trainers).\n",
    "        train_loss, train_accuracy = evaluate1(model, loss_function, train_dataset, device)\n",
    "\n",
    "        update_model_parameters_adamx(model, G_t_hat, m_t, v_t, v_hat, t, lr, betas, lambda_, eps, device)\n",
    "\n",
    "        if t == 1 or t % one_epoch == 0:\n",
    "            train_losses.append(train_loss)\n",
    "            train_accuracies.append(train_accuracy)\n",
    "            times.append(time.time() - start_time)\n",
    "            model.eval()\n",
    "            test_loss, test_accuracy = evaluate2(model, loss_function, test_dataset, device)\n",
    "            test_losses.append(test_loss)\n",
    "            test_accuracies.append(test_accuracy)\n",
    "            print(f\"Iteration: {t}, Time: {times[-1]:.2f}(s), Train Loss: {train_losses[-1]}, \"\n",
    "                f\"Test Loss: {test_losses[-1]}, Train Accuracy: {train_accuracies[-1]}, \"\n",
    "                f\"Test Accuracy: {test_accuracies[-1]}\")\n",
    "\n",
    "    return train_losses, test_losses, train_accuracies, test_accuracies, times"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ADAMBS(model, loss_function, train_dataset, test_dataset, K, max_iter, device, lr=0.001, betas=(0.9, 0.999), eps=1e-8, lambda_ = 1-1e-8):\n",
    "    \"\"\"Adam with an adaptively reweighted sampling distribution over training samples.\n",
    "\n",
    "    ``p_t`` starts uniform. Each iteration draws K indices from it with replacement, takes an\n",
    "    Adam step on the importance-weighted gradient, then reweights ``p_t`` with\n",
    "    ``update_sample_distribution`` using the per-sample gradient norms and the largest norm\n",
    "    seen so far. At t == 1 and every ``one_epoch`` iterations it records train metrics\n",
    "    (``evaluate1``, before the update), test metrics (``evaluate2``, after it) and elapsed time.\n",
    "\n",
    "    Returns:\n",
    "        (train_losses, test_losses, train_accuracies, test_accuracies, times)\n",
    "    \"\"\"\n",
    "    n = len(train_dataset)\n",
    "    # max(1, ...) keeps `t % one_epoch` well-defined when K > n (n // K would be 0).\n",
    "    one_epoch = max(1, n // K)\n",
    "\n",
    "    gamma = 0.4\n",
    "\n",
    "    # First/second moment estimates (v_hat is unused by the Adam update but part of its signature)\n",
    "    params = list(model.parameters())\n",
    "    m_t = [torch.zeros_like(p.data, device=device) for p in params]\n",
    "    v_t = [torch.zeros_like(p.data, device=device) for p in params]\n",
    "    v_hat = [torch.zeros_like(p.data, device=device) for p in params]\n",
    "    # Start from the uniform sampling distribution\n",
    "    p_t = np.ones(n) / n\n",
    "\n",
    "    max_norm = float('-inf')  # largest per-sample gradient norm seen so far\n",
    "    alpha_p = 1e-18\n",
    "    train_losses, train_accuracies, test_losses, test_accuracies, times = [], [], [], [], []\n",
    "    start_time = time.time()\n",
    "    for t in tqdm(range(1, max_iter + 1), desc=\"Training Progress\"):\n",
    "        model.train()\n",
    "        # Select a mini-batch I_t by sampling with replacement from p_t\n",
    "        indices = np.random.choice(n, K, replace=True, p=p_t)\n",
    "        G_t_hat, gradient_norms = compute_unbiased_gradient_estimate(model, loss_function, train_dataset, indices, p_t, device)\n",
    "        # NOTE(review): evaluate1 runs on every iteration although its result is only kept on\n",
    "        # logging iterations, so its cost is included in `times` (kept for parity across trainers).\n",
    "        train_loss, train_accuracy = evaluate1(model, loss_function, train_dataset, device)\n",
    "        max_norm = max(max_norm, max(gradient_norms))\n",
    "\n",
    "        update_model_parameters_adam(model, G_t_hat, m_t, v_t, v_hat, t, lr, betas, lambda_, eps, device)\n",
    "        p_t = update_sample_distribution(indices, p_t, gradient_norms, gamma, max_norm, alpha_p)\n",
    "\n",
    "        if t == 1 or t % one_epoch == 0:\n",
    "            train_losses.append(train_loss)\n",
    "            train_accuracies.append(train_accuracy)\n",
    "            times.append(time.time() - start_time)\n",
    "            model.eval()\n",
    "            test_loss, test_accuracy = evaluate2(model, loss_function, test_dataset, device)\n",
    "            test_losses.append(test_loss)\n",
    "            test_accuracies.append(test_accuracy)\n",
    "            print(f\"Iteration: {t}, Time: {times[-1]:.2f}(s), Train Loss: {train_losses[-1]}, \"\n",
    "                f\"Test Loss: {test_losses[-1]}, Train Accuracy: {train_accuracies[-1]}, \"\n",
    "                f\"Test Accuracy: {test_accuracies[-1]}\")\n",
    "\n",
    "    return train_losses, test_losses, train_accuracies, test_accuracies, times"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ADAMBS_corrected(model, loss_function, train_dataset, test_dataset, K, max_iter, device, lr=0.001, betas=(0.9, 0.999), eps=1e-8, lambda_ = 1-1e-8):\n",
    "    \"\"\"ADAMBS variant with gamma-smoothed sample weights, min-max scaled gradient norms and AdamX steps.\n",
    "\n",
    "    Each iteration samples K indices with replacement via ``batch_selection_adambs``, scales\n",
    "    the per-sample gradient norms to [0, 1] using the smallest/largest norms seen so far,\n",
    "    takes an AdamX step, and updates the weights with ``update_sample_weights_BS``. At t == 1\n",
    "    and every ``one_epoch`` iterations it records train metrics (``evaluate1``, before the\n",
    "    update), test metrics (``evaluate2``, after it) and elapsed time.\n",
    "\n",
    "    Returns:\n",
    "        (train_losses, test_losses, train_accuracies, test_accuracies, times)\n",
    "    \"\"\"\n",
    "    n = len(train_dataset)\n",
    "    # max(1, ...) keeps `t % one_epoch` well-defined when K > n (n // K would be 0).\n",
    "    one_epoch = max(1, n // K)\n",
    "\n",
    "    gamma = 0.4\n",
    "    p_min = gamma / n\n",
    "\n",
    "    # First/second moment estimates and the rescaled running max of the second moment\n",
    "    params = list(model.parameters())\n",
    "    m_t = [torch.zeros_like(p.data, device=device) for p in params]\n",
    "    v_t = [torch.zeros_like(p.data, device=device) for p in params]\n",
    "    v_hat = [torch.zeros_like(p.data, device=device) for p in params]\n",
    "    # Initialize sample weights\n",
    "    w_t = np.ones(n)\n",
    "\n",
    "    max_norm, min_norm = float('-inf'), float('inf')\n",
    "    alpha_p = float(p_min**3)\n",
    "    train_losses, train_accuracies, test_losses, test_accuracies, times = [], [], [], [], []\n",
    "    start_time = time.time()\n",
    "    for t in tqdm(range(1, max_iter + 1), desc=\"Training Progress\"):\n",
    "        model.train()\n",
    "        # Select a mini-batch I_t by sampling with replacement from p_t\n",
    "        indices, p_t = batch_selection_adambs(w_t, K, gamma)\n",
    "        G_t_hat, gradient_norms = compute_unbiased_gradient_estimate(model, loss_function, train_dataset, indices, p_t, device)\n",
    "        # NOTE(review): evaluate1 runs on every iteration although its result is only kept on\n",
    "        # logging iterations, so its cost is included in `times` (kept for parity across trainers).\n",
    "        train_loss, train_accuracy = evaluate1(model, loss_function, train_dataset, device)\n",
    "        max_norm = max(max_norm, max(gradient_norms))\n",
    "        min_norm = min(min_norm, min(gradient_norms))\n",
    "        norm_range = max_norm - min_norm\n",
    "        if norm_range == 0:\n",
    "            # Every norm seen so far is identical (always the case for K == 1 at t == 1);\n",
    "            # dividing by the zero range would raise ZeroDivisionError, so treat all\n",
    "            # samples as sitting at the minimum of the scale.\n",
    "            sample_gradient_norms = [0.0] * len(gradient_norms)\n",
    "        else:\n",
    "            sample_gradient_norms = [(x - min_norm) / norm_range for x in gradient_norms]\n",
    "        update_model_parameters_adamx(model, G_t_hat, m_t, v_t, v_hat, t, lr, betas, lambda_, eps, device)\n",
    "        w_t = update_sample_weights_BS(indices, w_t, p_t, sample_gradient_norms, gamma, alpha_p)\n",
    "        if t == 1 or t % one_epoch == 0:\n",
    "            train_losses.append(train_loss)\n",
    "            train_accuracies.append(train_accuracy)\n",
    "            times.append(time.time() - start_time)\n",
    "            model.eval()\n",
    "            test_loss, test_accuracy = evaluate2(model, loss_function, test_dataset, device)\n",
    "            test_losses.append(test_loss)\n",
    "            test_accuracies.append(test_accuracy)\n",
    "            print(f\"Iteration: {t}, Time: {times[-1]:.2f}(s), Train Loss: {train_losses[-1]}, \"\n",
    "                f\"Test Loss: {test_losses[-1]}, Train Accuracy: {train_accuracies[-1]}, \"\n",
    "                f\"Test Accuracy: {test_accuracies[-1]}\")\n",
    "\n",
    "    return train_losses, test_losses, train_accuracies, test_accuracies, times"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ADAMCB(model, loss_function, train_dataset, test_dataset, K, max_iter, device, lr=0.001, betas=(0.9, 0.999), eps=1e-8, lambda_ = 1-1e-8):\n",
    "    \"\"\"Train `model` with ADAMCB: adaptively sampled mini-batches plus an ADAMX-style update.\n",
    "\n",
    "    Every iteration draws K indices via batch_selection_adamcb from the current sample\n",
    "    weights, builds the gradient estimate with compute_unbiased_gradient_estimate (which\n",
    "    receives the sampling probabilities p_t), applies update_model_parameters_adamx and\n",
    "    then refreshes the sample weights with update_sample_weights_CB using the min-max\n",
    "    normalised per-sample gradient norms.\n",
    "\n",
    "    Args:\n",
    "        model: torch.nn.Module being trained; its parameters are updated in place.\n",
    "        loss_function: loss forwarded to the gradient and evaluation helpers.\n",
    "        train_dataset: dataset sampled from for training (its length sets n).\n",
    "        test_dataset: dataset passed to evaluate2 on logging iterations.\n",
    "        K: mini-batch size.\n",
    "        max_iter: number of iterations to run.\n",
    "        device: torch device on which the optimizer state tensors are allocated.\n",
    "        lr, betas, eps, lambda_: hyper-parameters forwarded to update_model_parameters_adamx.\n",
    "\n",
    "    Returns:\n",
    "        (train_losses, test_losses, train_accuracies, test_accuracies, times), one entry\n",
    "        per logging iteration: t == 1 and every `one_epoch` (= n // K) iterations.\n",
    "    \"\"\"\n",
    "    n = len(train_dataset)\n",
    "    # Logging period in iterations; clamped to 1 so `t % one_epoch` cannot divide by zero when K > n.\n",
    "    one_epoch = max(1, n // K)\n",
    "\n",
    "    gamma = 0.4\n",
    "    p_min = gamma / n\n",
    "\n",
    "    # Optimizer state (m_t, v_t, v_hat): one zero tensor per parameter, allocated on `device`.\n",
    "    params = list(model.parameters())\n",
    "    m_t = [torch.zeros_like(p.data, device=device) for p in params]\n",
    "    v_t = [torch.zeros_like(p.data, device=device) for p in params]\n",
    "    v_hat = [torch.zeros_like(p.data, device=device) for p in params]\n",
    "    # Uniform initial sample weights.\n",
    "    w_t = np.ones(n)\n",
    "\n",
    "    # Running extremes over every per-sample gradient norm seen so far (used for min-max normalisation).\n",
    "    max_norm, min_norm = float('-inf'), float('inf')\n",
    "    alpha_p = float(p_min**3)\n",
    "    train_losses, train_accuracies, test_losses, test_accuracies, times = [], [], [], [], []\n",
    "    start_time = time.time()\n",
    "    for t in tqdm(range(1, max_iter + 1), desc=\"Training Progress\"):\n",
    "        model.train()\n",
    "        # Select a mini-batch I_t by sampling with replacement from p_t\n",
    "        indices, p_t, S_null = batch_selection_adamcb(w_t, K, gamma)\n",
    "        p_t /= np.sum(p_t)\n",
    "        G_t_hat, gradient_norms = compute_unbiased_gradient_estimate(model, loss_function, train_dataset, indices, p_t, device)\n",
    "        # Evaluated before update_model_parameters_adamx runs, so logged train metrics precede this step's update.\n",
    "        train_loss, train_accuracy = evaluate1(model, loss_function, train_dataset, device)\n",
    "        max_norm = max(max_norm, max(gradient_norms))\n",
    "        min_norm = min(min_norm, min(gradient_norms))\n",
    "        # Min-max normalise into [0, 1]. If every norm seen so far is identical (e.g. a single-sample\n",
    "        # batch at t == 1) the range is 0; map to 0.0 instead of raising ZeroDivisionError / producing NaN.\n",
    "        norm_range = max_norm - min_norm\n",
    "        sample_gradient_norms = [(x - min_norm) / norm_range if norm_range > 0 else 0.0 for x in gradient_norms]\n",
    "\n",
    "        update_model_parameters_adamx(model, G_t_hat, m_t, v_t, v_hat, t, lr, betas, lambda_, eps, device)\n",
    "        w_t = update_sample_weights_CB(indices, w_t, p_t, sample_gradient_norms, gamma, alpha_p, S_null)\n",
    "\n",
    "        if t == 1 or t % one_epoch == 0:\n",
    "            train_losses.append(train_loss)\n",
    "            train_accuracies.append(train_accuracy)\n",
    "            times.append(time.time() - start_time)\n",
    "            model.eval()\n",
    "            test_loss, test_accuracy = evaluate2(model, loss_function, test_dataset, device)\n",
    "            test_losses.append(test_loss)\n",
    "            test_accuracies.append(test_accuracy)\n",
    "            print(f\"Iteration: {t}, Time: {times[-1]:.2f}(s), Train Loss: {train_losses[-1]}, \"\n",
    "                f\"Test Loss: {test_losses[-1]}, Train Accuracy: {train_accuracies[-1]}, \"\n",
    "                f\"Test Accuracy: {test_accuracies[-1]}\")\n",
    "\n",
    "    return train_losses, test_losses, train_accuracies, test_accuracies, times"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MNIST setup. Run-wide settings first: mini-batch size K, seeds, loss and device.\n",
    "K = 128\n",
    "seeds = [0, 1, 2, 3, 4]\n",
    "loss_function = nn.CrossEntropyLoss()\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Convert images to tensors, then normalise with the MNIST mean / std.\n",
    "transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])\n",
    "\n",
    "# Training split first, then the test split, both with the same transform.\n",
    "train_dataset, test_dataset = (\n",
    "    datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)\n",
    "    for is_train in (True, False)\n",
    ")\n",
    "\n",
    "# One epoch is n // K iterations; the iteration budget is 10 epochs.\n",
    "n = len(train_dataset)\n",
    "one_epoch = n // K\n",
    "max_iter = 10 * one_epoch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Transformations for the FashionMNIST dataset.\n",
    "# Normalise with FashionMNIST's own training-set statistics (mean ~0.2860, std ~0.3530),\n",
    "# not MNIST's (0.1307, 0.3081): the two datasets have different pixel distributions.\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),  # Converts image to PyTorch tensor\n",
    "    transforms.Normalize((0.2860,), (0.3530,))  # FashionMNIST training-set mean and std\n",
    "])\n",
    "\n",
    "# Loading the FashionMNIST dataset\n",
    "train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)\n",
    "test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)\n",
    "\n",
    "loss_function = nn.CrossEntropyLoss()\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Mini-batch size.\n",
    "K = 128\n",
    "\n",
    "n = len(train_dataset)\n",
    "one_epoch = n // K\n",
    "\n",
    "# Iteration budget: 10 epochs.\n",
    "max_iter = 10 * one_epoch\n",
    "seeds = [0, 1, 2, 3, 4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "# CIFAR-10 setup. Run-wide settings first: mini-batch size K, seeds, loss and device.\n",
    "K = 128\n",
    "seeds = [0, 1, 2, 3, 4]\n",
    "loss_function = nn.CrossEntropyLoss()\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Training images get random-crop (4px padding) and horizontal-flip augmentation; both splits\n",
    "# are converted to tensors and normalised per channel with the same constants.\n",
    "# NOTE(review): the std triple (0.2023, 0.1994, 0.2010) is the widely copied one; the per-channel\n",
    "# std over the full training set is reportedly about (0.2470, 0.2435, 0.2616). Kept unchanged.\n",
    "transform_train = transforms.Compose([\n",
    "    transforms.RandomCrop(32, padding=4),\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
    "])\n",
    "transform_test = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
    "])\n",
    "\n",
    "train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)\n",
    "test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)\n",
    "\n",
    "# One epoch is n // K iterations; the iteration budget is 10 epochs.\n",
    "n = len(train_dataset)\n",
    "one_epoch = n // K\n",
    "max_iter = 10 * one_epoch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-optimizer histories: each list receives one entry (a full training curve) per seed.\n",
    "train_losses_ADAM, test_losses_ADAM, train_accuracies_ADAM, test_accuracies_ADAM, times_ADAM = [], [], [], [], []\n",
    "train_losses_AMSGrad, test_losses_AMSGrad, train_accuracies_AMSGrad, test_accuracies_AMSGrad, times_AMSGrad = [], [], [], [], []\n",
    "train_losses_ADAMX, test_losses_ADAMX, train_accuracies_ADAMX, test_accuracies_ADAMX, times_ADAMX = [], [], [], [], []\n",
    "train_losses_ADAMBS, test_losses_ADAMBS, train_accuracies_ADAMBS, test_accuracies_ADAMBS, times_ADAMBS = [], [], [], [], []\n",
    "train_losses_ADAMCB, test_losses_ADAMCB, train_accuracies_ADAMCB, test_accuracies_ADAMCB, times_ADAMCB = [], [], [], [], []\n",
    "\n",
    "# Optimizers in run order, each paired with the five history lists that its returned\n",
    "# (train_losses, test_losses, train_accuracies, test_accuracies, times) are appended to.\n",
    "runs = {\n",
    "    \"ADAM\": (ADAM, (train_losses_ADAM, test_losses_ADAM, train_accuracies_ADAM, test_accuracies_ADAM, times_ADAM)),\n",
    "    \"AMSGrad\": (AMSGrad, (train_losses_AMSGrad, test_losses_AMSGrad, train_accuracies_AMSGrad, test_accuracies_AMSGrad, times_AMSGrad)),\n",
    "    \"ADAMX\": (ADAMX, (train_losses_ADAMX, test_losses_ADAMX, train_accuracies_ADAMX, test_accuracies_ADAMX, times_ADAMX)),\n",
    "    \"ADAMBS\": (ADAMBS, (train_losses_ADAMBS, test_losses_ADAMBS, train_accuracies_ADAMBS, test_accuracies_ADAMBS, times_ADAMBS)),\n",
    "    \"ADAMCB\": (ADAMCB, (train_losses_ADAMCB, test_losses_ADAMCB, train_accuracies_ADAMCB, test_accuracies_ADAMCB, times_ADAMCB)),\n",
    "}\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "# NOTE(review): the model has to match the dataset loaded by the most recently executed data cell.\n",
    "# Run All leaves CIFAR-10 (3x32x32 images) loaded, which MNIST_CNN presumably cannot consume - TODO\n",
    "# confirm and pick the model class that matches the dataset in use.\n",
    "model_class = MNIST_CNN\n",
    "\n",
    "for i, j in enumerate(seeds):\n",
    "    print(\"Repetition: {} starts!\".format(i+1))\n",
    "    print(\"Seed: {}\".format(j))\n",
    "\n",
    "    for name, (train_fn, histories) in runs.items():\n",
    "        print(\"Training by {} starts!\".format(name))\n",
    "        # Re-seed before building the model so every optimizer starts this repetition from the same RNG state.\n",
    "        set_seed(j)\n",
    "        # Place the model on `device` up front; moving it again to the same device inside train_fn is harmless.\n",
    "        model = model_class().to(device)\n",
    "        train_loss, test_loss, train_accuracy, test_accuracy, elapsed = train_fn(model, loss_function, train_dataset, test_dataset, K, max_iter, device)\n",
    "        for history, value in zip(histories, (train_loss, test_loss, train_accuracy, test_accuracy, elapsed)):\n",
    "            history.append(value)\n",
    "\n",
    "    print(\"Repetition: {} finished!\".format(i+1))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
