{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44ab22ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import random\n",
    "import math\n",
    "import copy\n",
    "import io\n",
    "import time\n",
    "import matplotlib.pyplot as plt\n",
    "from collections import OrderedDict\n",
    "\n",
    "\n",
    "import torch\n",
    "import torchvision\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch import cuda\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import Dataset, DataLoader, TensorDataset, Sampler\n",
    "\n",
    "def set_seed(seed):\n",
    "    torch.manual_seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    random.seed(seed)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed(seed)\n",
    "        torch.cuda.manual_seed_all(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "    g = torch.Generator()\n",
    "    g.manual_seed(seed)\n",
    "\n",
    "# Seed every RNG once for the whole notebook, then prefer the GPU when present.\n",
    "set_seed(0)\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "print(f\"Using device: {device}\")\n",
    "\n",
    "import sys\n",
    "sys.path.append('models_scratch/')\n",
    "sys.path.append('data/')\n",
    "from models_scratch import *\n",
    "from data_utils import *\n",
    "\n",
    "%matplotlib inline\n",
    "sns.set(style=\"whitegrid\")\n",
    "\n",
    "# Precision switch: when True, the default floating-point dtype becomes float64,\n",
    "# so float tensors and model parameters created after this cell default to double.\n",
    "doublePrecision = False\n",
    "\n",
    "\n",
    "if doublePrecision:\n",
    "    torch.set_default_dtype(torch.float64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8821c45f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_cifar10_train_loader(batch_size=256, size=32, root='data/', download=False):\n",
    "    \"\"\"Build the augmented CIFAR-10 training *dataset* (no DataLoader is created).\n",
    "\n",
    "    Args:\n",
    "        batch_size: unused; kept so existing callers keep working.\n",
    "        size: side length passed to ``transforms.Resize`` after the 32px random crop.\n",
    "        root: directory containing (or receiving) the CIFAR-10 files.\n",
    "        download: if True, let torchvision fetch the data into ``root`` when missing.\n",
    "\n",
    "    Returns:\n",
    "        ``torchvision.datasets.CIFAR10`` (train split) with random crop + horizontal flip\n",
    "        augmentation and per-channel normalisation, re-sampled on every item access.\n",
    "    \"\"\"\n",
    "    # `transforms` is not imported by the notebook's import cell; import it here instead of\n",
    "    # relying on a star-import to re-export it.\n",
    "    from torchvision import transforms\n",
    "\n",
    "    transform = transforms.Compose([\n",
    "        transforms.RandomCrop(32, padding=4),\n",
    "        transforms.Resize(size),\n",
    "        transforms.RandomHorizontalFlip(),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
    "    ])\n",
    "    return torchvision.datasets.CIFAR10(root=root, train=True, download=download, transform=transform)\n",
    "def get_cifar10_test_loader(size=32, root='data/', download=False):\n",
    "    \"\"\"Build the CIFAR-10 test *dataset* (no augmentation, no DataLoader).\n",
    "\n",
    "    Args:\n",
    "        size: side length passed to ``transforms.Resize``.\n",
    "        root: directory containing (or receiving) the CIFAR-10 files.\n",
    "        download: if True, let torchvision fetch the data into ``root`` when missing.\n",
    "\n",
    "    Returns:\n",
    "        ``torchvision.datasets.CIFAR10`` (test split) with resize + normalisation only,\n",
    "        using the same normalisation constants as the training split.\n",
    "    \"\"\"\n",
    "    # `transforms` is not imported by the notebook's import cell; import it locally.\n",
    "    from torchvision import transforms\n",
    "\n",
    "    transform = transforms.Compose([\n",
    "        transforms.Resize(size),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
    "    ])\n",
    "    return torchvision.datasets.CIFAR10(root=root, train=False, download=download, transform=transform)\n",
    "\n",
    "# Despite their names, both helpers return torchvision datasets rather than DataLoaders.\n",
    "trainset = get_cifar10_train_loader()\n",
    "testset = get_cifar10_test_loader()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cfe7dba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Materialise both splits as in-memory tensors. Iterating a dataset applies its transform\n",
    "# once per image, so the random crop/flip of the training transform is sampled a single\n",
    "# time here; the tensors reused by the TensorDataset below carry that frozen augmentation.\n",
    "# NOTE(review): if per-epoch augmentation was intended, train from the dataset instead.\n",
    "X_train = torch.stack([img for img, _ in trainset])\n",
    "y_train = torch.tensor(trainset.targets)\n",
    "\n",
    "X_test = torch.stack([img for img, _ in testset])\n",
    "y_test = torch.tensor(testset.targets) \n",
    "\n",
    "def filter_cifar10(X, y, minority_class=1, majority_class=9, minority_fraction=0.1):\n",
    "    y = y.clone().detach()\n",
    "\n",
    "    minority_indices = torch.where(y == minority_class)[0]\n",
    "    majority_indices = torch.where(y == majority_class)[0]\n",
    "\n",
    "    n1 = len(majority_indices)\n",
    "    n0 = int((minority_fraction / (1 - minority_fraction)) * n1)\n",
    "\n",
    "    selected_minority_indices = minority_indices[torch.randperm(len(minority_indices))[:n0]]\n",
    "\n",
    "    final_indices = torch.cat([selected_minority_indices, majority_indices])\n",
    "    final_indices = final_indices[torch.randperm(len(final_indices))]\n",
    "\n",
    "    X_filtered = X[final_indices]\n",
    "    y_filtered = y[final_indices]\n",
    "\n",
    "    y_filtered = (y_filtered == majority_class).long()\n",
    "\n",
    "    S_filtered = y_filtered.clone()\n",
    "\n",
    "    return X_filtered, y_filtered, S_filtered\n",
    "\n",
    "# Class 9 (truck in CIFAR-10's standard ordering) is the majority; class 1 (automobile)\n",
    "# is subsampled to about 3% of each split.\n",
    "X_train, y_train, S_train = filter_cifar10(X_train, y_train, minority_class=1, majority_class=9, minority_fraction=0.03)\n",
    "X_test, y_test, S_test = filter_cifar10(X_test, y_test, minority_class=1, majority_class=9, minority_fraction=0.03)\n",
    "\n",
    "\n",
    "class DeterministicReverseSampler(Sampler):\n",
    "    def __init__(self, base_sampler):\n",
    "        self.base_sampler = base_sampler\n",
    "        self._cached_indices = None\n",
    "\n",
    "    def __iter__(self):\n",
    "        if self._cached_indices is None:\n",
    "            self._cached_indices = list(self.base_sampler)\n",
    "        return iter(reversed(self._cached_indices))\n",
    "\n",
    "    def __len__(self):\n",
    "        if self._cached_indices is None:\n",
    "            self._cached_indices = list(self.base_sampler)\n",
    "        return len(self._cached_indices)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d6985cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_model(network, num_classes, input_channels, input_height, input_width, batch_norm = False, device='cuda'):\n",
    "    \n",
    "    if batch_norm:\n",
    "        norm_layer = nn.BatchNorm2d\n",
    "    else:\n",
    "        norm_layer = None\n",
    "\n",
    "    if network == \"vgg11\":\n",
    "        net = VGG(\"VGG11\", num_classes=num_classes, batch_norm=batch_norm)\n",
    "    elif network == \"vgg19\":\n",
    "        net = VGG(\"VGG19\", num_classes=num_classes, batch_norm=batch_norm)\n",
    "    elif network == \"resnet18\":\n",
    "        net = resnet18(norm_layer=norm_layer, num_classes=num_classes)\n",
    "    elif network == \"resnet34\":\n",
    "        net = resnet34(norm_layer=norm_layer, num_classes=num_classes)\n",
    "    elif network == \"resnet50\":\n",
    "        net = resnet50(norm_layer=norm_layer, num_classes=num_classes)\n",
    "    elif network == \"densenet121\":\n",
    "        net = densenet121(norm_layer=norm_layer, num_classes=num_classes,\n",
    "                          input_channels=input_channels, input_height=input_height, input_width=input_width)\n",
    "    elif network == \"mobilenet\":\n",
    "        net = MobileNet(num_classes=num_classes,\n",
    "                          input_channels=input_channels, input_height=input_height, input_width=input_width)\n",
    "    elif network == \"squeezenet\":\n",
    "        net = SqueezeNet(num_classes=num_classes,\n",
    "                          input_channels=input_channels, input_height=input_height, input_width=input_width)\n",
    "    elif network == \"lenet\":\n",
    "        net = LeNet5(num_classes=num_classes, input_channels=input_channels,\n",
    "                     input_height=input_height, input_width=input_width)\n",
    "    else:\n",
    "        raise ValueError(\"Invalid network name.\")\n",
    "\n",
    "    net = net.to(device)\n",
    "    \n",
    "    num_params = sum(p.numel() for p in net.parameters())\n",
    "    print(f\"Total number of parameters in {network}: {num_params:,}\")\n",
    "    \n",
    "    return net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27e00ab2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Full-batch loader: each epoch is one shuffled batch holding the whole training set.\n",
    "set_seed(0)\n",
    "g = torch.Generator().manual_seed(0)\n",
    "train_dataset = TensorDataset(X_train, S_train, y_train)\n",
    "trainloader = DataLoader(\n",
    "    train_dataset,\n",
    "    batch_size=len(y_train),\n",
    "    shuffle=True,\n",
    "    generator=g,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9819ec63",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pass `device` by keyword: positionally it bound to `batch_norm` (a truthy string, which\n",
    "# selects BatchNorm2d for resnet18) and the model always went to the default 'cuda'.\n",
    "# batch_norm=True keeps the architecture that the positional call produced.\n",
    "model = build_model(\"resnet18\", 2, 3, 32, 32, batch_norm=True, device=device)\n",
    "theta_init = copy.deepcopy(model.state_dict())\n",
    "learning_rate = 3e-2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "85ad3119",
   "metadata": {},
   "source": [
    "### Step 1: train on the majority group ($S=1$) only, starting from `theta_init`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b88e89ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_S1_phase(model, dataloader, device, epochs=5):\n",
    "\n",
    "    metrics = {\"epoch\": [], \"loss\": [], \"acc\": []}\n",
    "    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0)\n",
    "    criterion = nn.CrossEntropyLoss(reduction='mean')\n",
    "    model.train()\n",
    "    best_acc = 0\n",
    "    \n",
    "    for epoch in range(epochs):\n",
    "        running_loss = 0.0\n",
    "        correct = 0\n",
    "        count = 0\n",
    "        for X_batch, S_batch, y_batch in dataloader:\n",
    "            mask = (S_batch == 1)\n",
    "            if mask.sum().item() == 0:\n",
    "                break\n",
    "            X_s1 = X_batch[mask].to(device)\n",
    "            y_s1 = y_batch[mask].to(device)\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(X_s1)\n",
    "            loss = criterion(outputs, y_s1)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            running_loss += loss.item() * X_s1.size(0)\n",
    "            count += X_s1.size(0)\n",
    "            _, preds = outputs.max(1)\n",
    "            correct += preds.eq(y_s1).sum().item()\n",
    "            \n",
    "        avg_loss = running_loss / count if count > 0 else 0\n",
    "        avg_acc = 100 * correct / count if count > 0 else 0\n",
    "        best_acc = max(best_acc, avg_acc)\n",
    "        metrics[\"epoch\"].append(epoch)\n",
    "        metrics[\"loss\"].append(avg_loss)\n",
    "        metrics[\"acc\"].append(best_acc)\n",
    "        \n",
    "        if (epoch + 1) % 10 == 0 or epoch == 0:\n",
    "            print(f\"Phase S=1, Epoch {epoch +1}, Loss: {avg_loss:.4f}, Acc: {avg_acc:.2f}%\")\n",
    "            \n",
    "    theta1 = copy.deepcopy(model.state_dict())\n",
    "    return metrics, theta1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9d782f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 1 run: 150 epochs on the S=1 samples only.\n",
    "metrics_phase1, theta1 = train_S1_phase(model, dataloader=trainloader, device=device, epochs=150)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cae7f264",
   "metadata": {},
   "source": [
    "### Step 2: gradient ascent on the full training set, starting from `theta1`, to obtain `theta_unlucky`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2358d96e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gradient_ascent_phase(model, dataloader, device, theta1, epochs=10):\n",
    "\n",
    "    metrics = {\n",
    "        \"epoch\": [],\n",
    "        \"loss_s0\": [],\n",
    "        \"loss_s1\": [],\n",
    "        \"loss_global\": [],\n",
    "        \"acc_s0\": [],\n",
    "        \"acc_s1\": [],\n",
    "        \"acc_global\": []\n",
    "    }\n",
    "\n",
    "    model.load_state_dict(theta1)\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0)\n",
    "    criterion = nn.CrossEntropyLoss(reduction='mean')\n",
    "    for name, param in model.named_parameters():\n",
    "        if name in theta1:\n",
    "            saved_param = theta1[name]\n",
    "            assert param.shape == saved_param.shape, f\"Shape mismatch for {name}: {param.shape} vs {saved_param.shape}\"\n",
    "            assert torch.allclose(param.cpu(), saved_param.cpu(), atol=1e-6), f\"Value mismatch for {name}\"\n",
    "        else:\n",
    "            raise ValueError(f\"{name} not found in theta_unlucky\")\n",
    "    model.train()\n",
    "\n",
    "    best_acc = 0\n",
    "    best_acc_s0 = 0\n",
    "    best_acc_s1 = 0\n",
    "\n",
    "    for epoch in range(epochs):\n",
    "        total_loss, total_samples = 0.0, 0\n",
    "        loss_s0_sum, count_s0 = 0.0, 0\n",
    "        loss_s1_sum, count_s1 = 0.0, 0\n",
    "        correct_total, correct_s0, correct_s1 = 0, 0, 0\n",
    "\n",
    "        for X_batch, S_batch, y_batch in dataloader:\n",
    "            X_batch, S_batch, y_batch = X_batch.to(device), S_batch.to(device), y_batch.to(device)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(X_batch)\n",
    "            loss = criterion(outputs, y_batch)\n",
    "            loss.backward()\n",
    "\n",
    "            for param in model.parameters():\n",
    "                if param.grad is not None:\n",
    "                    param.grad.data.mul_(-1)\n",
    "            optimizer.step()\n",
    "\n",
    "            bsize = y_batch.size(0)\n",
    "            total_loss += loss.item() * bsize\n",
    "            total_samples += bsize\n",
    "\n",
    "            _, preds = outputs.max(1)\n",
    "            correct_total += preds.eq(y_batch).sum().item()\n",
    "\n",
    "            mask_s0 = (S_batch == 0)\n",
    "            if mask_s0.any():\n",
    "                n_s0 = mask_s0.sum().item()\n",
    "                loss_s0 = criterion(outputs[mask_s0], y_batch[mask_s0]).item()\n",
    "                loss_s0_sum += loss_s0 * n_s0\n",
    "                count_s0 += n_s0\n",
    "                correct_s0 += preds[mask_s0].eq(y_batch[mask_s0]).sum().item()\n",
    "\n",
    "            mask_s1 = (S_batch == 1)\n",
    "            if mask_s1.any():\n",
    "                n_s1 = mask_s1.sum().item()\n",
    "                loss_s1 = criterion(outputs[mask_s1], y_batch[mask_s1]).item()\n",
    "                loss_s1_sum += loss_s1 * n_s1\n",
    "                count_s1 += n_s1\n",
    "                correct_s1 += preds[mask_s1].eq(y_batch[mask_s1]).sum().item()\n",
    "\n",
    "        avg_loss    = total_loss / total_samples\n",
    "        avg_loss_s0 = loss_s0_sum / count_s0 if count_s0 > 0 else 0\n",
    "        avg_loss_s1 = loss_s1_sum / count_s1 if count_s1 > 0 else 0\n",
    "\n",
    "        acc_total = (correct_total / total_samples) * 100\n",
    "        acc_s0    = (correct_s0 / count_s0) * 100 if count_s0 > 0 else 0\n",
    "        acc_s1    = (correct_s1 / count_s1) * 100 if count_s1 > 0 else 0\n",
    "\n",
    "        best_acc_s0 = max(best_acc_s0, acc_s0)\n",
    "        best_acc_s1 = max(best_acc_s1, acc_s1)\n",
    "        best_acc    = max(best_acc, acc_total)\n",
    "\n",
    "        metrics[\"epoch\"].append(epoch)\n",
    "        metrics[\"loss_s0\"].append(avg_loss_s0)\n",
    "        metrics[\"loss_s1\"].append(avg_loss_s1)\n",
    "        metrics[\"loss_global\"].append(avg_loss)\n",
    "        metrics[\"acc_s0\"].append(best_acc_s0)\n",
    "        metrics[\"acc_s1\"].append(best_acc_s1)\n",
    "        metrics[\"acc_global\"].append(best_acc)\n",
    "\n",
    "        if (epoch + 1) % 10 == 0 or epoch == 0:\n",
    "            print(f\"Gradient Ascent, Epoch {epoch+1}, Loss S0: {avg_loss_s0:.4f}, Acc S0: {acc_s0:.2f}%, \"\n",
    "                  f\"Loss S1: {avg_loss_s1:.4f}, Acc S1: {acc_s1:.2f}%, Global Loss: {avg_loss:.4f}, Global Acc: {acc_total:.2f}%\")\n",
    "\n",
    "    theta_unlucky = copy.deepcopy(model.state_dict())\n",
    "    return metrics, theta_unlucky"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0b81cc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 2 run: 100 epochs of gradient ascent starting from theta1.\n",
    "metrics_phase2, theta_unlucky = gradient_ascent_phase(\n",
    "    model, trainloader, device, theta1=theta1, epochs=100\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c759435",
   "metadata": {},
   "source": [
    "### Step 3: train on the full data from `theta_unlucky` and, for comparison, from `theta_init`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17e84b2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_full_phase(model, dataloader, device, theta, epochs, kappa=99):\n",
    " \n",
    "    metrics = {\n",
    "        \"epoch\": [],\n",
    "        \"loss_s0\": [],\n",
    "        \"loss_s1\": [],\n",
    "        \"loss_global\": [],\n",
    "        \"acc_s0\": [],\n",
    "        \"acc_s1\": [],\n",
    "        \"acc_global\": []\n",
    "    }\n",
    "    model.load_state_dict(theta)\n",
    "    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0)\n",
    "    criterion = nn.CrossEntropyLoss(reduction='mean')\n",
    "    model.train()\n",
    "    \n",
    "    best_acc = 0\n",
    "    best_acc_s0 = 0\n",
    "    best_acc_s1 = 0\n",
    "    for epoch in range(epochs):\n",
    "        total_loss, total_samples = 0.0, 0\n",
    "        loss_s0_sum, count_s0 = 0.0, 0\n",
    "        loss_s1_sum, count_s1 = 0.0, 0\n",
    "        correct_total, correct_s0, correct_s1 = 0, 0, 0\n",
    "        total_loss, total_samples = 0.0, 0\n",
    "        for X_batch, S_batch, y_batch in trainloader:\n",
    "            X_batch, S_batch, y_batch = X_batch.to(device), S_batch.to(device), y_batch.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(X_batch)\n",
    "            loss = criterion(outputs, y_batch)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            \n",
    "            bsize = y_batch.size(0)\n",
    "            total_loss += loss.item() * bsize\n",
    "            total_samples += bsize\n",
    "            \n",
    "            _, preds = outputs.max(1)\n",
    "            correct_total += preds.eq(y_batch).sum().item()\n",
    "            \n",
    "            mask_s0 = (S_batch == 0)\n",
    "            if mask_s0.any():\n",
    "                n_s0 = mask_s0.sum().item()\n",
    "                loss_s0 = criterion(outputs[mask_s0], y_batch[mask_s0]).item()\n",
    "                loss_s0_sum += loss_s0 * n_s0\n",
    "                count_s0 += n_s0\n",
    "                correct_s0 += preds[mask_s0].eq(y_batch[mask_s0]).sum().item()\n",
    "                \n",
    "            mask_s1 = (S_batch == 1)\n",
    "            if mask_s1.any():\n",
    "                n_s1 = mask_s1.sum().item()\n",
    "                loss_s1 = criterion(outputs[mask_s1], y_batch[mask_s1]).item()\n",
    "                loss_s1_sum += loss_s1 * n_s1\n",
    "                count_s1 += n_s1\n",
    "                correct_s1 += preds[mask_s1].eq(y_batch[mask_s1]).sum().item()\n",
    "        \n",
    "        avg_loss    = total_loss / total_samples\n",
    "        avg_loss_s0 = loss_s0_sum / count_s0 if count_s0 > 0 else 0\n",
    "        avg_loss_s1 = loss_s1_sum / count_s1 if count_s1 > 0 else 0\n",
    "        \n",
    "        acc_total = (correct_total / total_samples) * 100\n",
    "        acc_s0    = (correct_s0 / count_s0) * 100 if count_s0 > 0 else 0\n",
    "        acc_s1    = (correct_s1 / count_s1) * 100 if count_s1 > 0 else 0\n",
    "        \n",
    "        best_acc_s0 = max(best_acc_s0, acc_s0)\n",
    "        best_acc_s1 = max(best_acc_s1, acc_s1)\n",
    "        best_acc = max(best_acc, acc_total)\n",
    "        \n",
    "        metrics[\"epoch\"].append(epoch)\n",
    "        metrics[\"loss_s0\"].append(avg_loss_s0)\n",
    "        metrics[\"loss_s1\"].append(avg_loss_s1)\n",
    "        metrics[\"loss_global\"].append(avg_loss)\n",
    "        metrics[\"acc_s0\"].append(best_acc_s0)\n",
    "        metrics[\"acc_s1\"].append(best_acc_s1)\n",
    "        metrics[\"acc_global\"].append(best_acc_s1)\n",
    "        \n",
    "        if acc_s0 > kappa:\n",
    "            final_epoch = epoch + 1\n",
    "            break\n",
    "        \n",
    "        if (epoch + 1) % 10 == 0 or epoch == 0:\n",
    "            print(f\"Phase Full, Epoch {epoch+1}, Loss S0: {avg_loss_s0:.4f}, Acc S0: {acc_s0:.2f}%, \"\n",
    "              f\"Loss S1: {avg_loss_s1:.4f}, Acc S1: {acc_s1:.2f}%, Global Loss: {avg_loss:.4f}, Global Acc: {acc_total:.2f}%\")\n",
    "    return metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f18f08b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Retrain on the full data from the gradient-ascent solution.\n",
    "metrics_phase3 = train_full_phase(\n",
    "    model, trainloader, device, theta=theta_unlucky, epochs=5000\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6d10035",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fresh model restored to the Step-1 initial weights. `device` is passed by keyword\n",
    "# (positionally it bound to `batch_norm`); batch_norm=True keeps the BatchNorm layers the\n",
    "# positional call enabled, so `theta_init` still matches the model's state_dict keys.\n",
    "# No optimizer/criterion here: train_full_phase builds its own.\n",
    "model = build_model(\"resnet18\", 2, 3, 32, 32, batch_norm=True, device=device)\n",
    "model.load_state_dict(theta_init)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "527ac471",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline: the same full-data training, started from theta_init instead.\n",
    "metrics_phase4 = train_full_phase(model, trainloader, device, theta=theta_init, epochs=5000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d7a599f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_full_phase2(model, dataloader, device, theta, epochs, kappa=99):\n",
    "    metrics = {\n",
    "        \"epoch\": [],\n",
    "        \"loss_s0\": [],\n",
    "        \"loss_s1\": [],\n",
    "        \"loss_global\": [],\n",
    "        \"acc_s0\": [],\n",
    "        \"acc_s1\": [],\n",
    "        \"acc_global\": []\n",
    "    }\n",
    "    model.load_state_dict(theta)\n",
    "    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0)\n",
    "    criterion = nn.CrossEntropyLoss(reduction='mean')\n",
    "    model.train()\n",
    "    \n",
    "    best_acc = 0\n",
    "    best_acc_s0 = 0\n",
    "    best_acc_s1 = 0\n",
    "    for epoch in range(epochs):\n",
    "        total_loss, total_samples = 0.0, 0\n",
    "        loss_s0_sum, count_s0 = 0.0, 0\n",
    "        loss_s1_sum, count_s1 = 0.0, 0\n",
    "        correct_total, correct_s0, correct_s1 = 0, 0, 0\n",
    "        total_loss, total_samples = 0.0, 0\n",
    "        for X_batch, S_batch, y_batch in trainloader:\n",
    "            X_batch, S_batch, y_batch = X_batch.to(device), S_batch.to(device), y_batch.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(X_batch)\n",
    "            loss = criterion(outputs, y_batch)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            \n",
    "            bsize = y_batch.size(0)\n",
    "            total_loss += loss.item() * bsize\n",
    "            total_samples += bsize\n",
    "            \n",
    "            _, preds = outputs.max(1)\n",
    "            correct_total += preds.eq(y_batch).sum().item()\n",
    "            \n",
    "            mask_s0 = (S_batch == 0)\n",
    "            if mask_s0.any():\n",
    "                n_s0 = mask_s0.sum().item()\n",
    "                loss_s0 = criterion(outputs[mask_s0], y_batch[mask_s0]).item()\n",
    "                loss_s0_sum += loss_s0 * n_s0\n",
    "                count_s0 += n_s0\n",
    "                correct_s0 += preds[mask_s0].eq(y_batch[mask_s0]).sum().item()\n",
    "                \n",
    "            mask_s1 = (S_batch == 1)\n",
    "            if mask_s1.any():\n",
    "                n_s1 = mask_s1.sum().item()\n",
    "                loss_s1 = criterion(outputs[mask_s1], y_batch[mask_s1]).item()\n",
    "                loss_s1_sum += loss_s1 * n_s1\n",
    "                count_s1 += n_s1\n",
    "                correct_s1 += preds[mask_s1].eq(y_batch[mask_s1]).sum().item()\n",
    "        \n",
    "        avg_loss    = total_loss / total_samples\n",
    "        avg_loss_s0 = loss_s0_sum / count_s0 if count_s0 > 0 else 0\n",
    "        avg_loss_s1 = loss_s1_sum / count_s1 if count_s1 > 0 else 0\n",
    "        \n",
    "        acc_total = (correct_total / total_samples) * 100\n",
    "        acc_s0    = (correct_s0 / count_s0) * 100 if count_s0 > 0 else 0\n",
    "        acc_s1    = (correct_s1 / count_s1) * 100 if count_s1 > 0 else 0\n",
    "        \n",
    "        best_acc_s0 = max(best_acc_s0, acc_s0)\n",
    "        best_acc_s1 = max(best_acc_s1, acc_s1)\n",
    "        best_acc = max(best_acc, acc_total)\n",
    "        \n",
    "        metrics[\"epoch\"].append(epoch)\n",
    "        metrics[\"loss_s0\"].append(avg_loss_s0)\n",
    "        metrics[\"loss_s1\"].append(avg_loss_s1)\n",
    "        metrics[\"loss_global\"].append(avg_loss)\n",
    "        metrics[\"acc_s0\"].append(best_acc_s0)\n",
    "        metrics[\"acc_s1\"].append(best_acc_s1)\n",
    "        metrics[\"acc_global\"].append(best_acc_s1)\n",
    "        \n",
    "        if acc_total > kappa:\n",
    "            final_epoch = epoch + 1\n",
    "            break\n",
    "        \n",
    "        if (epoch + 1) % 10 == 0 or epoch == 0:\n",
    "            print(f\"Phase Full, Epoch {epoch+1}, Loss S0: {avg_loss_s0:.4f}, Acc S0: {acc_s0:.2f}%, \"\n",
    "              f\"Loss S1: {avg_loss_s1:.4f}, Acc S1: {acc_s1:.2f}%, Global Loss: {avg_loss:.4f}, Global Acc: {acc_total:.2f}%\")\n",
    "    return metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cf3ce98",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fresh model from theta_init, trained until the *global* accuracy exceeds kappa.\n",
    "# `device` is passed by keyword (positionally it bound to `batch_norm`); batch_norm=True keeps\n",
    "# the BatchNorm layers the positional call enabled, so theta_init's keys still match.\n",
    "# No optimizer/criterion here: train_full_phase2 builds its own.\n",
    "model = build_model(\"resnet18\", 2, 3, 32, 32, batch_norm=True, device=device)\n",
    "model.load_state_dict(theta_init)\n",
    "metrics_phase5 = train_full_phase2(model, trainloader, device, theta_init, epochs=5000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b44e2801",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_metrics(metrics_phase3, metrics_phase4, metrics_phase5, skip=5,\n",
    "                 save_path=\"results/CIFAR-2/trajectory.pdf\", titles=None):\n",
    "    \"\"\"Plot per-group (L_0, L_1) and global (L) training-loss curves of three runs side by side.\n",
    "\n",
    "    Args:\n",
    "        metrics_phase3, metrics_phase4, metrics_phase5: metrics dicts with ``\"epoch\"``,\n",
    "            ``\"loss_s0\"``, ``\"loss_s1\"`` and ``\"loss_global\"`` lists (left, middle, right panel).\n",
    "        skip: number of initial epochs omitted from every panel.\n",
    "        save_path: output file (parent directories are created); ``None`` skips saving.\n",
    "        titles: optional sequence of three panel titles.\n",
    "    \"\"\"\n",
    "    import os\n",
    "    from collections import OrderedDict\n",
    "\n",
    "    import matplotlib.pyplot as plt\n",
    "\n",
    "    color_dict = {\n",
    "        r\"$L$\": \"green\",\n",
    "        r\"$L_0$\": \"blue\",\n",
    "        r\"$L_1$\": \"orange\"\n",
    "    }\n",
    "    # (metrics key, legend label) in drawing order.\n",
    "    curves = [(\"loss_s1\", r\"$L_1$\"), (\"loss_s0\", r\"$L_0$\"), (\"loss_global\", r\"$L$\")]\n",
    "    runs = [metrics_phase3, metrics_phase4, metrics_phase5]\n",
    "\n",
    "    fig, axes = plt.subplots(1, 3, figsize=(15, 3.75))\n",
    "\n",
    "    for i, (ax, metrics) in enumerate(zip(axes, runs)):\n",
    "        # Plot against the recorded epoch numbers so dropping `skip` epochs does not shift the x-axis.\n",
    "        epochs = metrics[\"epoch\"][skip:]\n",
    "        for key, label in curves:\n",
    "            ax.plot(epochs, metrics[key][skip:], label=label, color=color_dict[label])\n",
    "        ax.set_xlabel(\"Epoch\")\n",
    "        ax.set_ylabel(\"Loss\" if i == 0 else \"\")\n",
    "        ax.set_yscale('log')\n",
    "        if titles is not None:\n",
    "            ax.set_title(titles[i])\n",
    "\n",
    "    # One shared legend, de-duplicated by label across panels.\n",
    "    unique = OrderedDict()\n",
    "    for ax in axes:\n",
    "        for handle, label in zip(*ax.get_legend_handles_labels()):\n",
    "            unique.setdefault(label, handle)\n",
    "\n",
    "    fig.legend(list(unique.values()), list(unique.keys()), loc=\"lower center\",\n",
    "               ncol=len(unique), bbox_to_anchor=(0.5, -0.05))\n",
    "\n",
    "    fig.tight_layout(rect=[0, 0.07, 1, 1])\n",
    "    if save_path is not None:\n",
    "        directory = os.path.dirname(save_path)\n",
    "        if directory:\n",
    "            os.makedirs(directory, exist_ok=True)\n",
    "        fig.savefig(save_path, bbox_inches=\"tight\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81997945",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Left: from theta_unlucky; middle: from theta_init (both stop on S=0 acc); right: from theta_init (stops on global acc).\n",
    "plot_metrics(\n",
    "    metrics_phase3=metrics_phase3,\n",
    "    metrics_phase4=metrics_phase4,\n",
    "    metrics_phase5=metrics_phase5,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
