{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# CIFAR-10"
      ],
      "metadata": {
        "id": "w88fKmIns1nx"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RmhtveC7qUqv"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n",
        "from torch.utils.data import DataLoader, Dataset, Subset\n",
        "from torchvision import datasets, transforms\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "from sklearn.metrics import classification_report, confusion_matrix\n",
        "import seaborn as sns\n",
        "import time\n",
        "import pandas as pd\n",
        "from torch.optim.lr_scheduler import CosineAnnealingLR\n",
        "from thop import profile  # For calculating FLOPs\n",
        "import math\n",
        "\n",
        "GLOBAL_TRAIN_BATCH_INDICES = None\n",
        "\n",
        "# GPU Configuration\n",
        "def configure_gpu():\n",
        "    \"\"\"\n",
        "    Configure GPU settings for PyTorch\n",
        "    \"\"\"\n",
        "    # Check for GPU availability\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    print(f\"Using device: {device}\")\n",
        "\n",
        "    # Print number of available GPUs\n",
        "    if torch.cuda.is_available():\n",
        "        print(f\"Num GPUs Available: {torch.cuda.device_count()}\")\n",
        "\n",
        "        # Configure GPU memory management\n",
        "        torch.cuda.empty_cache()\n",
        "        print(\"GPU memory cleared\")\n",
        "\n",
        "    return device\n",
        "\n",
        "def load_images(batch_size=128, val_size=5000, seed=None):\n",
        "    \"\"\"\n",
        "    Load CIFAR-10 and split the official training set into train/validation parts.\n",
        "\n",
        "    With the default arguments the split is:\n",
        "    - Training: 45,000 samples\n",
        "    - Validation: 5,000 samples\n",
        "    - Test: 10,000 samples\n",
        "\n",
        "    The training indices are grouped once into fixed batches, which are stored\n",
        "    in the module-level GLOBAL_TRAIN_BATCH_INDICES so that every epoch can\n",
        "    reuse the same batch composition.\n",
        "\n",
        "    Args:\n",
        "        batch_size (int): Samples per fixed training batch and per\n",
        "            validation/test loader batch.\n",
        "        val_size (int): Number of training-set samples held out for validation.\n",
        "        seed (int, optional): Seed for the split and the batch grouping. When\n",
        "            None (default), NumPy's global random state is used.\n",
        "\n",
        "    Returns:\n",
        "        tuple: (full_train_dataset, val_loader, test_loader, class_names).\n",
        "        full_train_dataset is the augmented view of the whole training set,\n",
        "        including the held-out validation samples; training samples are\n",
        "        selected through GLOBAL_TRAIN_BATCH_INDICES.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: If val_size is negative or leaves no training samples.\n",
        "    \"\"\"\n",
        "    global GLOBAL_TRAIN_BATCH_INDICES\n",
        "\n",
        "    # A dedicated RandomState makes the split reproducible without touching the\n",
        "    # global NumPy stream; seed=None falls back to the global stream.\n",
        "    rng = np.random if seed is None else np.random.RandomState(seed)\n",
        "\n",
        "    # Define transformations\n",
        "    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))\n",
        "\n",
        "    transform_train = transforms.Compose([\n",
        "        transforms.RandomCrop(32, padding=4),\n",
        "        transforms.RandomHorizontalFlip(),\n",
        "        transforms.ToTensor(),\n",
        "        normalize\n",
        "    ])\n",
        "\n",
        "    transform_val_test = transforms.Compose([\n",
        "        transforms.ToTensor(),\n",
        "        normalize\n",
        "    ])\n",
        "\n",
        "    # Load full training dataset (50,000 samples)\n",
        "    # Use CIFAR100 instead of CIFAR10 to run it with CIFAR100. Change the normalization values and class numbers accordingly.\n",
        "    full_train_dataset = datasets.CIFAR10(\n",
        "        root='./data',\n",
        "        train=True,\n",
        "        download=True,\n",
        "        transform=transform_train\n",
        "    )\n",
        "\n",
        "    # Same training images without augmentation; validation samples come from this view\n",
        "    full_train_dataset_for_val = datasets.CIFAR10(\n",
        "        root='./data',\n",
        "        train=True,\n",
        "        download=True,\n",
        "        transform=transform_val_test\n",
        "    )\n",
        "\n",
        "    # Load test dataset\n",
        "    test_dataset = datasets.CIFAR10(\n",
        "        root='./data',\n",
        "        train=False,\n",
        "        download=True,\n",
        "        transform=transform_val_test\n",
        "    )\n",
        "\n",
        "    class_names = ['plane', 'car', 'bird', 'cat', 'deer',\n",
        "                   'dog', 'frog', 'horse', 'ship', 'truck']\n",
        "\n",
        "    # --- Split training data into train and validation ---\n",
        "    num_total_train = len(full_train_dataset)\n",
        "    if not 0 <= val_size < num_total_train:\n",
        "        raise ValueError(\n",
        "            f\"val_size must be in [0, {num_total_train}), got {val_size}\"\n",
        "        )\n",
        "\n",
        "    indices = np.arange(num_total_train)\n",
        "    rng.shuffle(indices)\n",
        "\n",
        "    # First val_size shuffled indices form the validation set, the rest is training data\n",
        "    val_indices = indices[:val_size]\n",
        "    train_indices = indices[val_size:]\n",
        "\n",
        "    print(f\"\\nData Split:\")\n",
        "    print(f\"  Training samples: {len(train_indices)}\")\n",
        "    print(f\"  Validation samples: {len(val_indices)}\")\n",
        "    print(f\"  Test samples: {len(test_dataset)}\")\n",
        "\n",
        "    # Create validation subset (using non-augmented transforms)\n",
        "    val_dataset = Subset(full_train_dataset_for_val, val_indices)\n",
        "\n",
        "    # --- Consistent batch handling for training ---\n",
        "    # Shuffle once more, then cut into fixed batches; the last batch may be shorter\n",
        "    rng.shuffle(train_indices)\n",
        "    batch_indices = [\n",
        "        train_indices[i:i + batch_size].tolist()\n",
        "        for i in range(0, len(train_indices), batch_size)\n",
        "    ]\n",
        "\n",
        "    # Save batch indices for later access in training loop\n",
        "    GLOBAL_TRAIN_BATCH_INDICES = batch_indices  # Make accessible for pruning logic\n",
        "\n",
        "    print(f\"\\nCreated consistent batch setup with {len(batch_indices)} training batches\")\n",
        "    print(f\"Each batch contains the same data samples across all epochs\")\n",
        "    print(f\"Intra-batch shuffling is enabled to prevent memorization\")\n",
        "\n",
        "    # Validation DataLoader\n",
        "    val_loader = DataLoader(\n",
        "        val_dataset,\n",
        "        batch_size=batch_size,\n",
        "        shuffle=False,\n",
        "        num_workers=4,\n",
        "        pin_memory=True\n",
        "    )\n",
        "\n",
        "    # Test DataLoader\n",
        "    test_loader = DataLoader(\n",
        "        test_dataset,\n",
        "        batch_size=batch_size,\n",
        "        shuffle=False,\n",
        "        num_workers=4,\n",
        "        pin_memory=True\n",
        "    )\n",
        "\n",
        "    return full_train_dataset, val_loader, test_loader, class_names\n",
        "\n",
        "\n",
        "def build_train_loader_for_epoch(full_train_dataset, active_batch_indices, batch_size, shuffle_within_batch=True):\n",
        "    \"\"\"\n",
        "    Build a DataLoader that yields exactly the requested fixed training batches.\n",
        "\n",
        "    Each loader batch is one logical batch from GLOBAL_TRAIN_BATCH_INDICES, in\n",
        "    the order given by active_batch_indices. The batches are handed to the\n",
        "    DataLoader through batch_sampler, so the i-th loader batch always matches\n",
        "    logical_index_order[i], even when a logical batch is shorter than the\n",
        "    others (the remainder batch) and is not the last one in the order.\n",
        "\n",
        "    Args:\n",
        "        full_train_dataset: Dataset indexed by the sample indices stored in\n",
        "            GLOBAL_TRAIN_BATCH_INDICES.\n",
        "        active_batch_indices: Iterable of logical batch indices to include.\n",
        "        batch_size (int): Kept for interface compatibility; batch boundaries\n",
        "            come from the stored logical batches.\n",
        "        shuffle_within_batch (bool): Shuffle the sample order inside each\n",
        "            batch using NumPy's global random state.\n",
        "\n",
        "    Returns:\n",
        "        tuple: (train_loader, logical_index_order), where logical_index_order\n",
        "        holds one logical batch index per loader batch, in iteration order.\n",
        "\n",
        "    Raises:\n",
        "        RuntimeError: If GLOBAL_TRAIN_BATCH_INDICES has not been populated.\n",
        "    \"\"\"\n",
        "    global GLOBAL_TRAIN_BATCH_INDICES\n",
        "\n",
        "    if GLOBAL_TRAIN_BATCH_INDICES is None:\n",
        "        raise RuntimeError(\n",
        "            \"GLOBAL_TRAIN_BATCH_INDICES is not set; call load_images() first\"\n",
        "        )\n",
        "\n",
        "    # One list of sample indices per active logical batch, in iteration order\n",
        "    batches = []\n",
        "    logical_index_order = []\n",
        "    for logical_idx in active_batch_indices:\n",
        "        # Copy so the stored batch definition is never reordered in place\n",
        "        sample_indices = list(GLOBAL_TRAIN_BATCH_INDICES[logical_idx])\n",
        "        if shuffle_within_batch:\n",
        "            np.random.shuffle(sample_indices)\n",
        "        batches.append(sample_indices)\n",
        "        logical_index_order.append(logical_idx)\n",
        "\n",
        "    # batch_sampler takes an iterable of index lists and uses each one as a\n",
        "    # batch verbatim, so batches are never re-chunked by a fixed batch size.\n",
        "    train_loader = DataLoader(\n",
        "        full_train_dataset,\n",
        "        batch_sampler=batches,\n",
        "        num_workers=4,\n",
        "        pin_memory=True\n",
        "    )\n",
        "\n",
        "    return train_loader, logical_index_order\n",
        "\n",
        "\n",
        "# Updated ResNet model definitions from paste-2.txt\n",
        "class BasicBlock(nn.Module):\n",
        "    expansion = 1\n",
        "\n",
        "    def __init__(self, in_planes, planes, stride=1):\n",
        "        super(BasicBlock, self).__init__()\n",
        "        self.conv1 = nn.Conv2d(\n",
        "            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
        "        self.bn1 = nn.BatchNorm2d(planes)\n",
        "        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,\n",
        "                               stride=1, padding=1, bias=False)\n",
        "        self.bn2 = nn.BatchNorm2d(planes)\n",
        "\n",
        "        self.shortcut = nn.Sequential()\n",
        "        if stride != 1 or in_planes != self.expansion*planes:\n",
        "            self.shortcut = nn.Sequential(\n",
        "                nn.Conv2d(in_planes, self.expansion*planes,\n",
        "                          kernel_size=1, stride=stride, bias=False),\n",
        "                nn.BatchNorm2d(self.expansion*planes)\n",
        "            )\n",
        "\n",
        "    def forward(self, x):\n",
        "        out = F.relu(self.bn1(self.conv1(x)))\n",
        "        out = self.bn2(self.conv2(out))\n",
        "        out += self.shortcut(x)\n",
        "        out = F.relu(out)\n",
        "        return out\n",
        "\n",
        "\n",
        "class Bottleneck(nn.Module):\n",
        "    expansion = 4\n",
        "\n",
        "    def __init__(self, in_planes, planes, stride=1):\n",
        "        super(Bottleneck, self).__init__()\n",
        "        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)\n",
        "        self.bn1 = nn.BatchNorm2d(planes)\n",
        "        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,\n",
        "                               stride=stride, padding=1, bias=False)\n",
        "        self.bn2 = nn.BatchNorm2d(planes)\n",
        "        self.conv3 = nn.Conv2d(planes, self.expansion *\n",
        "                               planes, kernel_size=1, bias=False)\n",
        "        self.bn3 = nn.BatchNorm2d(self.expansion*planes)\n",
        "\n",
        "        self.shortcut = nn.Sequential()\n",
        "        if stride != 1 or in_planes != self.expansion*planes:\n",
        "            self.shortcut = nn.Sequential(\n",
        "                nn.Conv2d(in_planes, self.expansion*planes,\n",
        "                          kernel_size=1, stride=stride, bias=False),\n",
        "                nn.BatchNorm2d(self.expansion*planes)\n",
        "            )\n",
        "\n",
        "    def forward(self, x):\n",
        "        out = F.relu(self.bn1(self.conv1(x)))\n",
        "        out = F.relu(self.bn2(self.conv2(out)))\n",
        "        out = self.bn3(self.conv3(out))\n",
        "        out += self.shortcut(x)\n",
        "        out = F.relu(out)\n",
        "        return out\n",
        "\n",
        "\n",
        "class ResNet(nn.Module):\n",
        "    def __init__(self, block, num_blocks, num_classes=10):\n",
        "        super(ResNet, self).__init__()\n",
        "        self.in_planes = 64\n",
        "\n",
        "        # List to track layer names for activation tracking\n",
        "        self.layer_names = ['layer1', 'layer2', 'layer3', 'layer4']\n",
        "\n",
        "        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,\n",
        "                               stride=1, padding=1, bias=False)\n",
        "        self.bn1 = nn.BatchNorm2d(64)\n",
        "        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)\n",
        "        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)\n",
        "        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)\n",
        "        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)\n",
        "        self.avgpool = nn.AdaptiveAvgPool2d((1,1))\n",
        "        self.dropout = nn.Dropout(0.2)  # Adding dropout for regularization\n",
        "        self.linear = nn.Linear(512*block.expansion, num_classes)\n",
        "\n",
        "        # List of layers to track activations\n",
        "        self.conv_layers = [self.layer1, self.layer2, self.layer3, self.layer4]\n",
        "\n",
        "    def _make_layer(self, block, planes, num_blocks, stride):\n",
        "        strides = [stride] + [1]*(num_blocks-1)\n",
        "        layers = []\n",
        "        for stride in strides:\n",
        "            layers.append(block(self.in_planes, planes, stride))\n",
        "            self.in_planes = planes * block.expansion\n",
        "        return nn.Sequential(*layers)\n",
        "\n",
        "    def forward(self, x, return_activations=False):\n",
        "        # Initialize list to store activations if requested\n",
        "        activations = []\n",
        "\n",
        "        # Initial convolution\n",
        "        out = F.relu(self.bn1(self.conv1(x)))\n",
        "\n",
        "        # ResNet block 1\n",
        "        out = self.layer1(out)\n",
        "        if return_activations:\n",
        "            activations.append(out.detach())\n",
        "\n",
        "        # ResNet block 2\n",
        "        out = self.layer2(out)\n",
        "        if return_activations:\n",
        "            activations.append(out.detach())\n",
        "\n",
        "        # ResNet block 3\n",
        "        out = self.layer3(out)\n",
        "        if return_activations:\n",
        "            activations.append(out.detach())\n",
        "\n",
        "        # ResNet block 4\n",
        "        out = self.layer4(out)\n",
        "        if return_activations:\n",
        "            activations.append(out.detach())\n",
        "\n",
        "        # Global average pooling and final classifier\n",
        "        out = self.avgpool(out)\n",
        "        out = out.view(out.size(0), -1)\n",
        "        out = self.dropout(out)  # Apply dropout before final layer\n",
        "        out = self.linear(out)\n",
        "\n",
        "        if return_activations:\n",
        "            return out, activations\n",
        "        return out\n",
        "\n",
        "    def create_emb(self, x):\n",
        "        \"\"\"\n",
        "        Extract embeddings from the model\n",
        "        \"\"\"\n",
        "        out = F.relu(self.bn1(self.conv1(x)))\n",
        "        out = self.layer1(out)\n",
        "        out = self.layer2(out)\n",
        "        out = self.layer3(out)\n",
        "        out = self.layer4(out)\n",
        "        out = self.avgpool(out)\n",
        "        out = out.view(out.size(0), -1)\n",
        "        return out\n",
        "\n",
        "def ResNet18(num_classes=10):\n",
        "    \"\"\"Build a ResNet from BasicBlock with [2, 2, 2, 2] blocks per stage.\"\"\"\n",
        "    stage_depths = [2, 2, 2, 2]\n",
        "    return ResNet(BasicBlock, stage_depths, num_classes=num_classes)\n",
        "\n",
        "def ResNet50(num_classes=10):\n",
        "    \"\"\"Build a ResNet from Bottleneck with [3, 4, 6, 3] blocks per stage.\"\"\"\n",
        "    stage_depths = [3, 4, 6, 3]\n",
        "    return ResNet(Bottleneck, stage_depths, num_classes=num_classes)\n",
        "\n",
        "def calculate_model_parameters(model):\n",
        "    \"\"\"\n",
        "    Calculate the total number of parameters in a model\n",
        "    \"\"\"\n",
        "    total_params = sum(p.numel() for p in model.parameters())\n",
        "    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
        "\n",
        "    print(f\"\\nModel Parameters:\")\n",
        "    print(f\"Total parameters: {total_params:,}\")\n",
        "    print(f\"Trainable parameters: {trainable_params:,}\")\n",
        "\n",
        "    return total_params, trainable_params\n",
        "\n",
        "def calculate_flops(model, input_size=(3, 32, 32)):\n",
        "    \"\"\"\n",
        "    Profile the model with thop on a single random input and print the result.\n",
        "\n",
        "    Args:\n",
        "        model: Module to profile. It must have at least one parameter, since\n",
        "            the dummy input is placed on the device of the first parameter.\n",
        "        input_size (tuple): Shape of one input sample, without the batch dimension.\n",
        "\n",
        "    Returns:\n",
        "        The operation count reported by thop's profile() for a batch of one.\n",
        "    \"\"\"\n",
        "    # Build the dummy batch on the same device as the model\n",
        "    target_device = next(model.parameters()).device\n",
        "    sample = torch.randn(1, *input_size).to(target_device)\n",
        "\n",
        "    # profile() also returns a parameter count, which is not needed here\n",
        "    flops, _ = profile(model, inputs=(sample,), verbose=False)\n",
        "\n",
        "    print(f\"\\nModel Computational Requirements:\")\n",
        "    print(f\"FLOPs: {flops:,} ({flops/1e9:.2f} GFLOPs)\")\n",
        "\n",
        "    return flops\n",
        "\n",
        "\n",
        "def train_epoch(model, device, train_loader, optimizer, epoch, logical_index_order, scheduler=None, threshold=0.0001):\n",
        "    \"\"\"\n",
        "    Run one training epoch and record per-layer activation statistics.\n",
        "\n",
        "    For each batch, when np.random.random() > 0.5 the inputs are mixed with a\n",
        "    permuted copy of the same batch (lam drawn from Beta(0.2, 0.2)) and the\n",
        "    loss is the lam-weighted sum of the cross-entropies against both target\n",
        "    sets; otherwise plain cross-entropy is used. Gradients are clipped to a\n",
        "    max norm of 1.0 before each optimizer step.\n",
        "\n",
        "    Args:\n",
        "        model: Network exposing `layer_names` and a forward that accepts\n",
        "            return_activations=True and returns (output, activations).\n",
        "        device: Device each batch is moved to.\n",
        "        train_loader: Iterable of (data, target) batches.\n",
        "        optimizer: Optimizer stepped once per batch.\n",
        "        epoch: Epoch number, used only in log messages.\n",
        "        logical_index_order: Logical batch index of each loader batch, by position.\n",
        "        scheduler: Optional LR scheduler, stepped once at the end of the epoch.\n",
        "        threshold: Only printed in the epoch summary; not otherwise used here.\n",
        "\n",
        "    Returns:\n",
        "        tuple: (epoch_loss, epoch_acc, std_df, batch_indices_list, throughput_metrics).\n",
        "        epoch_loss is the mean of the per-batch losses, epoch_acc is a percentage,\n",
        "        std_df is indexed by logical batch index with one activation-std column\n",
        "        per layer plus their row-wise mean in 'mean_std', and throughput_metrics\n",
        "        is a dict with 'epoch_time', 'images_per_second', 'avg_batch_time' and\n",
        "        'total_batches'.\n",
        "    \"\"\"\n",
        "\n",
        "    model.train()\n",
        "    train_loss = 0\n",
        "    correct = 0\n",
        "    total = 0\n",
        "\n",
        "    # Initialize storage for standard deviations per layer\n",
        "    std_devs = {layer_name: [] for layer_name in model.layer_names}\n",
        "    batch_indices_list = []\n",
        "\n",
        "    # Track throughput\n",
        "    start_time = time.time()\n",
        "    total_images_processed = 0\n",
        "    batch_processing_times = []\n",
        "\n",
        "    num_active_batches = len(logical_index_order)\n",
        "    print(f\"Using {num_active_batches} batches this epoch\")\n",
        "\n",
        "    # Iterate over the loader; each batch aligns with logical_index_order\n",
        "    for loader_batch_idx, (data, target) in enumerate(train_loader):\n",
        "        # The loader_batch_idx-th batch from the loader corresponds to\n",
        "        # logical batch logical_index_order[loader_batch_idx]\n",
        "        logical_batch_idx = logical_index_order[loader_batch_idx]\n",
        "\n",
        "        # Timer starts after the batch has been fetched, so data loading time\n",
        "        # is not part of batch_time (it is part of the epoch time)\n",
        "        batch_start_time = time.time()\n",
        "        batch_indices_list.append(logical_batch_idx)\n",
        "\n",
        "        # Move to device\n",
        "        data, target = data.to(device), target.to(device)\n",
        "\n",
        "        # Optional: Mixup augmentation\n",
        "        if np.random.random() > 0.5:\n",
        "            lam = np.random.beta(0.2, 0.2)\n",
        "            rand_idx = torch.randperm(data.size(0)).to(device)\n",
        "            mixed_data = lam * data + (1 - lam) * data[rand_idx]\n",
        "            target_a, target_b = target, target[rand_idx]\n",
        "\n",
        "            optimizer.zero_grad()\n",
        "            output, activations = model(mixed_data, return_activations=True)\n",
        "            loss = lam * F.cross_entropy(output, target_a) + (1 - lam) * F.cross_entropy(output, target_b)\n",
        "        else:\n",
        "            optimizer.zero_grad()\n",
        "            output, activations = model(data, return_activations=True)\n",
        "            loss = F.cross_entropy(output, target)\n",
        "\n",
        "        # Compute standard deviation of activations for each layer\n",
        "        # (a single scalar over the whole activation tensor of the batch)\n",
        "        for layer_idx, act in enumerate(activations):\n",
        "            std_dev = torch.std(act).item()\n",
        "            std_devs[model.layer_names[layer_idx]].append(std_dev)\n",
        "\n",
        "        # Backward pass\n",
        "        loss.backward()\n",
        "        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
        "        optimizer.step()\n",
        "\n",
        "        # Accumulate loss and accuracy\n",
        "        # NOTE(review): predictions are compared with the original targets even\n",
        "        # for mixup batches, so the training accuracy is only approximate there\n",
        "        train_loss += loss.item()\n",
        "        _, predicted = output.max(1)\n",
        "        total += target.size(0)\n",
        "        correct += predicted.eq(target).sum().item()\n",
        "\n",
        "        # Throughput tracking\n",
        "        batch_end_time = time.time()\n",
        "        batch_time = batch_end_time - batch_start_time\n",
        "        batch_processing_times.append(batch_time)\n",
        "        total_images_processed += data.size(0)\n",
        "\n",
        "        # Print progress\n",
        "        # NOTE(review): the denominator uses the current batch's size, so it is\n",
        "        # approximate when batch sizes differ\n",
        "        if (loader_batch_idx + 1) % 100 == 0:\n",
        "            print(f'Train Epoch: {epoch} [{total_images_processed}/{num_active_batches * data.size(0)} '\n",
        "                  f'({100. * (loader_batch_idx + 1) / num_active_batches:.0f}%)]\\t'\n",
        "                  f'Loss: {loss.item():.6f}\\tAccuracy: {100. * correct / total:.2f}%')\n",
        "\n",
        "    # End timing\n",
        "    end_time = time.time()\n",
        "    training_time = end_time - start_time\n",
        "\n",
        "    # Build DataFrame of standard deviations\n",
        "    std_df = pd.DataFrame(std_devs)\n",
        "    std_df['batch_idx'] = batch_indices_list\n",
        "    std_df = std_df.set_index('batch_idx')\n",
        "    std_df['mean_std'] = std_df.mean(axis=1)\n",
        "\n",
        "    # Throughput metrics\n",
        "    avg_batch_time = np.mean(batch_processing_times) if batch_processing_times else 0\n",
        "    images_per_second = total_images_processed / training_time if training_time > 0 else 0\n",
        "\n",
        "    # Average loss and accuracy\n",
        "    epoch_loss = train_loss / len(batch_indices_list) if batch_indices_list else 0\n",
        "    epoch_acc = 100. * correct / total if total > 0 else 0\n",
        "\n",
        "    # Step scheduler\n",
        "    if scheduler is not None:\n",
        "        scheduler.step()\n",
        "\n",
        "    print(f\"\\nEpoch {epoch} completed in {training_time:.2f} seconds\")\n",
        "    print(f\"Training Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%\")\n",
        "    print(f\"Throughput: {images_per_second:.2f} images/second, Avg batch time: {avg_batch_time:.4f} s\")\n",
        "    print(f\"Current threshold: {threshold:.8f}\")\n",
        "    print(f\"Batch composition: CONSISTENT across epochs (same samples per batch)\")\n",
        "    print(f\"Intra-batch order: SHUFFLED to prevent memorization\")\n",
        "\n",
        "    throughput_metrics = {\n",
        "        'epoch_time': training_time,\n",
        "        'images_per_second': images_per_second,\n",
        "        'avg_batch_time': avg_batch_time,\n",
        "        'total_batches': len(batch_indices_list)\n",
        "    }\n",
        "\n",
        "    return epoch_loss, epoch_acc, std_df, batch_indices_list, throughput_metrics\n",
        "\n",
        "\n",
        "def validate(model, device, data_loader, dataset_name=\"Validation\"):\n",
        "    \"\"\"\n",
        "    Validates/tests the model on a given dataset\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    val_loss = 0\n",
        "    correct = 0\n",
        "    all_preds = []\n",
        "    all_targets = []\n",
        "\n",
        "    with torch.no_grad():\n",
        "        for data, target in data_loader:\n",
        "            # Move data to device\n",
        "            data, target = data.to(device), target.to(device)\n",
        "\n",
        "            # Forward pass\n",
        "            output = model(data)\n",
        "\n",
        "            # Compute loss\n",
        "            val_loss += F.cross_entropy(output, target).item()\n",
        "\n",
        "            # Calculate accuracy\n",
        "            pred = output.argmax(dim=1)\n",
        "            correct += pred.eq(target).sum().item()\n",
        "\n",
        "            # Store predictions and targets for confusion matrix\n",
        "            all_preds.extend(pred.cpu().numpy())\n",
        "            all_targets.extend(target.cpu().numpy())\n",
        "\n",
        "    # Calculate average loss and accuracy\n",
        "    val_loss /= len(data_loader)\n",
        "    val_acc = 100. * correct / len(data_loader.dataset)\n",
        "\n",
        "    print(f\"{dataset_name} Loss: {val_loss:.4f}, Accuracy: {val_acc:.2f}%\")\n",
        "\n",
        "    return val_loss, val_acc, all_targets, all_preds\n",
        "\n",
        "def plot_training_metrics(epoch_times, train_losses, train_accs, val_losses, val_accs):\n",
        "    \"\"\"\n",
        "    Plot training time, accuracy and loss per epoch and save the figure.\n",
        "\n",
        "    All panels use 1-based epoch numbers on the x-axis. The figure is written\n",
        "    to 'training_metrics_improved_woes.pdf' and then shown.\n",
        "\n",
        "    Args:\n",
        "        epoch_times: Training time in seconds, one value per epoch.\n",
        "        train_losses, val_losses: Loss values, one per epoch.\n",
        "        train_accs, val_accs: Accuracy values, one per epoch.\n",
        "    \"\"\"\n",
        "    def epoch_axis(values):\n",
        "        # 1-based epoch numbers so every panel shares the same x convention\n",
        "        return range(1, len(values) + 1)\n",
        "\n",
        "    plt.figure(figsize=(18, 12))\n",
        "\n",
        "    # Top-left panel: training time per epoch\n",
        "    plt.subplot(2, 2, 1)\n",
        "    plt.plot(epoch_axis(epoch_times), epoch_times, marker='o')\n",
        "    plt.title('Training Time per Epoch')\n",
        "    plt.xlabel('Epoch')\n",
        "    plt.ylabel('Time (seconds)')\n",
        "    plt.grid(True)\n",
        "\n",
        "    # Bottom-left panel: accuracy\n",
        "    plt.subplot(2, 2, 3)\n",
        "    plt.plot(epoch_axis(train_accs), train_accs, label='Train Accuracy')\n",
        "    plt.plot(epoch_axis(val_accs), val_accs, label='Validation Accuracy')\n",
        "    plt.title('Model Accuracy')\n",
        "    plt.xlabel('Epoch')\n",
        "    plt.ylabel('Accuracy')\n",
        "    plt.legend()\n",
        "    plt.grid(True)\n",
        "\n",
        "    # Bottom-right panel: loss\n",
        "    plt.subplot(2, 2, 4)\n",
        "    plt.plot(epoch_axis(train_losses), train_losses, label='Train Loss')\n",
        "    plt.plot(epoch_axis(val_losses), val_losses, label='Validation Loss')\n",
        "    plt.title('Model Loss')\n",
        "    plt.xlabel('Epoch')\n",
        "    plt.ylabel('Loss')\n",
        "    plt.legend()\n",
        "    plt.grid(True)\n",
        "\n",
        "    plt.tight_layout()\n",
        "    plt.savefig('training_metrics_improved_woes.pdf')\n",
        "    plt.show()\n",
        "\n",
        "def plot_confusion_matrix(all_targets, all_preds, class_names, title='Confusion Matrix'):\n",
        "    \"\"\"\n",
        "    Plot the confusion matrix as an annotated heatmap and save it.\n",
        "\n",
        "    The figure is written to 'confusion_matrix_improved_woes.pdf' and then shown.\n",
        "\n",
        "    Args:\n",
        "        all_targets: True class indices.\n",
        "        all_preds: Predicted class indices.\n",
        "        class_names: Tick labels, one per class.\n",
        "        title (str): Figure title.\n",
        "    \"\"\"\n",
        "    # Pin the label set so the matrix always has one row/column per class name,\n",
        "    # even if a class is absent from both targets and predictions.\n",
        "    # Assumes targets/predictions are integer class indices in class_names order.\n",
        "    labels = list(range(len(class_names)))\n",
        "    cm = confusion_matrix(all_targets, all_preds, labels=labels)\n",
        "\n",
        "    plt.figure(figsize=(10, 8))\n",
        "    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',\n",
        "                xticklabels=class_names,\n",
        "                yticklabels=class_names)\n",
        "    plt.title(title)\n",
        "    plt.xlabel('Predicted Label')\n",
        "    plt.ylabel('True Label')\n",
        "    plt.tight_layout()\n",
        "    plt.savefig('confusion_matrix_improved_woes.pdf')\n",
        "    plt.show()\n",
        "\n",
        "def calculate_exponential_threshold(delta_start, delta_end, T, t):\n",
        "    \"\"\"\n",
        "    Calculate threshold using exponential function: δ(t) = δ_initial * e^(alpha * t)\n",
        "\n",
        "    Args:\n",
        "        delta_start (float): Initial threshold value\n",
        "        delta_end (float): Final threshold value\n",
        "        T (int): Total number of epochs\n",
        "        t (int): Current epoch number\n",
        "\n",
        "    Returns:\n",
        "        float: Threshold value for the current epoch\n",
        "    \"\"\"\n",
        "\n",
        "    # Compute alpha\n",
        "    alpha = math.log(delta_end / delta_start) / T\n",
        "    threshold = delta_start * math.exp(alpha * t)\n",
        "    return threshold\n",
        "\n",
        "\n",
        "\n",
        "def plot_batch_prune_analysis(threshold_history, remaining_batches_history,\n",
        "                            dropped_per_epoch_history, pruned_percentage_history):\n",
        "    \"\"\"\n",
        "    Draw a six-panel overview of batch pruning under the exponential threshold.\n",
        "\n",
        "    The figure is written to 'batch_prune_exponential_threshold_analysis.pdf'\n",
        "    and then shown. Each argument is a sequence with one value per epoch.\n",
        "    \"\"\"\n",
        "    def epoch_axis(series):\n",
        "        # 1-based epoch numbers matching the length of the series\n",
        "        return range(1, len(series) + 1)\n",
        "\n",
        "    def decorate(title, xlabel, ylabel, xscale=None, yscale=None):\n",
        "        # Titles and labels first, then axis scales, then the grid\n",
        "        plt.title(title)\n",
        "        plt.xlabel(xlabel)\n",
        "        plt.ylabel(ylabel)\n",
        "        if yscale is not None:\n",
        "            plt.yscale(yscale)\n",
        "        if xscale is not None:\n",
        "            plt.xscale(xscale)\n",
        "        plt.grid(True, alpha=0.3)\n",
        "\n",
        "    plt.figure(figsize=(20, 12))\n",
        "\n",
        "    # Panel 1: threshold per epoch, on a log y-axis\n",
        "    plt.subplot(3, 2, 1)\n",
        "    plt.plot(epoch_axis(threshold_history), threshold_history, marker='o', color='red')\n",
        "    decorate('Exponential Threshold Evolution Over Epochs', 'Epoch', 'Threshold Value', yscale='log')\n",
        "\n",
        "    # Panel 2: batches still in use per epoch\n",
        "    plt.subplot(3, 2, 2)\n",
        "    plt.plot(epoch_axis(remaining_batches_history), remaining_batches_history, marker='o', color='blue')\n",
        "    decorate('Remaining Batches per Epoch', 'Epoch', 'Number of Batches')\n",
        "\n",
        "    # Panel 3: batches dropped in each epoch\n",
        "    plt.subplot(3, 2, 3)\n",
        "    plt.bar(epoch_axis(dropped_per_epoch_history), dropped_per_epoch_history, color='orange', alpha=0.7)\n",
        "    decorate('Batches Dropped per Epoch', 'Epoch', 'Number of Batches Dropped')\n",
        "\n",
        "    # Panel 4: cumulative pruned percentage\n",
        "    plt.subplot(3, 2, 4)\n",
        "    plt.plot(epoch_axis(pruned_percentage_history), pruned_percentage_history,\n",
        "             marker='d', color='green', linewidth=2)\n",
        "    decorate('Cumulative Pruned Percentage Over Epochs', 'Epoch', 'Pruned Percentage (%)')\n",
        "\n",
        "    # Panel 5: threshold against remaining batches, on a log x-axis\n",
        "    plt.subplot(3, 2, 5)\n",
        "    plt.scatter(threshold_history, remaining_batches_history, alpha=0.6, color='coral')\n",
        "    decorate('Threshold vs Remaining Batches', 'Threshold Value', 'Remaining Batches', xscale='log')\n",
        "\n",
        "    # Panel 6: threshold per epoch drawn with semilogy\n",
        "    plt.subplot(3, 2, 6)\n",
        "    plt.semilogy(epoch_axis(threshold_history), threshold_history, marker='s', color='purple', linewidth=2)\n",
        "    decorate('Exponential Threshold Growth Pattern', 'Epoch', 'Threshold Value (log scale)')\n",
        "\n",
        "    plt.tight_layout()\n",
        "    plt.savefig('batch_prune_exponential_threshold_analysis.pdf', dpi=300, bbox_inches='tight')\n",
        "    plt.show()\n",
        "\n",
        "def run_cifar10_classification():\n",
        "    # Configure GPU\n",
        "    device = configure_gpu()\n",
        "\n",
        "    total_epochs = 200\n",
        "    batch_size = 128\n",
        "    learning_rate = 0.05\n",
        "    weight_decay = 5e-4\n",
        "    momentum = 0.9\n",
        "\n",
        "\n",
        "    # Exponential threshold parameters\n",
        "    delta_start = 0.000001   # starting δ\n",
        "    delta_end = 0.00005    # ending δ\n",
        "\n",
        "\n",
        "    min_batch_percentage = 0  # Minimum percentage of batches to keep before stopping\n",
        "\n",
        "    # Load data with train/val/test split (45k/5k/10k)\n",
        "    full_train_dataset, val_loader, test_loader, class_names = load_images(batch_size, val_size=5000)\n",
        "\n",
        "    # Calculate max batches per epoch\n",
        "    global GLOBAL_TRAIN_BATCH_INDICES\n",
        "    total_original_batches = len(GLOBAL_TRAIN_BATCH_INDICES)\n",
        "    max_epochs = total_epochs  # for DUI calculation\n",
        "\n",
        "    print(f\"Total original training batches per epoch: {total_original_batches}\")\n",
        "    print(f\"Maximum number of epochs: {max_epochs}\")\n",
        "\n",
        "    # Create improved ResNet18 model\n",
        "    model = ResNet18(num_classes=10).to(device)\n",
        "\n",
        "    # Print model summary\n",
        "    print(model)\n",
        "\n",
        "    # Calculate model parameters\n",
        "    total_params, trainable_params = calculate_model_parameters(model)\n",
        "\n",
        "    # Calculate FLOPs\n",
        "    try:\n",
        "        flops = calculate_flops(model)\n",
        "    except ImportError:\n",
        "        print(\"thop package not installed. Skipping FLOPs calculation.\")\n",
        "        flops = None\n",
        "\n",
        "    # Define improved optimizer and scheduler with OneCycleLR\n",
        "    optimizer = optim.SGD(model.parameters(),\n",
        "                         lr=learning_rate,\n",
        "                         momentum=momentum,\n",
        "                         weight_decay=weight_decay,\n",
        "                         nesterov=True)  # Enable Nesterov momentum\n",
        "\n",
        "    # CosineAnnealingLR schedule for better convergence\n",
        "    scheduler = CosineAnnealingLR(optimizer, T_max=total_epochs, eta_min=1e-5)\n",
        "\n",
        "    # Training and validation history\n",
        "    train_losses = []\n",
        "    train_accs = []\n",
        "    val_losses = []\n",
        "    val_accs = []\n",
        "\n",
        "    # Metrics to track\n",
        "    epoch_times = []\n",
        "    batch_counts = []\n",
        "\n",
        "    # Initialize tracking for batch dropping with exponential threshold\n",
        "    prev_epoch_std_df = None\n",
        "    batch_indices_to_use = list(range(total_original_batches))  # Start with all batches\n",
        "    total_dropped_batches = 0\n",
        "    total_remaining_batches = []  # Track remaining batches for each epoch\n",
        "    total_batches_dropped_per_epoch = []  # Track batches dropped per epoch\n",
        "\n",
        "    # New tracking for exponential threshold\n",
        "    threshold_history = []  # Track threshold values over epochs\n",
        "    pruned_percentage_history = []  # Track cumulative pruned percentage\n",
        "\n",
        "    # Visualizations for batch dropping analysis\n",
        "    batch_std_history = {}  # Dictionary to store std values by batch across epochs\n",
        "\n",
        "    # Track total training time\n",
        "    total_training_start = time.time()\n",
        "\n",
        "\n",
        "    best_val_acc = 0\n",
        "    best_model_state = None\n",
        "\n",
        "    # Train and evaluate\n",
        "    for epoch in range(1, total_epochs + 1):\n",
        "        print(f\"\\n{'='*50}\")\n",
        "        print(f\"EPOCH {epoch}/{total_epochs}\")\n",
        "        print(f\"{'='*50}\")\n",
        "\n",
        "        # Calculate threshold using exponential function\n",
        "        current_threshold = calculate_exponential_threshold(delta_start, delta_end, total_epochs, epoch)\n",
        "\n",
        "        # Record threshold history\n",
        "        threshold_history.append(current_threshold)\n",
        "\n",
        "        print(f\"Current threshold (δ({epoch})): {current_threshold:.8f}\")\n",
        "\n",
        "\n",
        "        train_loader, logical_index_order = build_train_loader_for_epoch(\n",
        "            full_train_dataset, batch_indices_to_use, batch_size, shuffle_within_batch=True\n",
        "        )\n",
        "\n",
        "        # Train for one epoch\n",
        "        train_loss, train_acc, current_std_df, used_batch_indices, throughput_metrics = train_epoch(\n",
        "            model, device, train_loader, optimizer, epoch,\n",
        "            logical_index_order=logical_index_order,\n",
        "            scheduler=scheduler,\n",
        "            threshold=current_threshold\n",
        "        )\n",
        "\n",
        "        # Record metrics\n",
        "        epoch_times.append(throughput_metrics['epoch_time'])\n",
        "        batch_counts.append(throughput_metrics['total_batches'])\n",
        "        train_losses.append(train_loss)\n",
        "        train_accs.append(train_acc)\n",
        "\n",
        "        # Track remaining batches for this epoch\n",
        "        current_remaining_batches = len(used_batch_indices) if used_batch_indices is not None else total_original_batches\n",
        "        total_remaining_batches.append(current_remaining_batches)\n",
        "\n",
        "        # Calculate and record pruned percentage\n",
        "        current_pruned_percentage = (1 - current_remaining_batches / total_original_batches) * 100\n",
        "        pruned_percentage_history.append(current_pruned_percentage)\n",
        "\n",
        "        # Evaluate on VALIDATION set (not test set)\n",
        "        val_loss, val_acc, _, _ = validate(model, device, val_loader, dataset_name=\"Validation\")\n",
        "        val_losses.append(val_loss)\n",
        "        val_accs.append(val_acc)\n",
        "\n",
        "        # Save best model so far (based on validation accuracy)\n",
        "        if val_acc > best_val_acc:\n",
        "            best_val_acc = val_acc\n",
        "            best_model_state = {key: value.cpu().clone() for key, value in model.state_dict().items()}\n",
        "            print(f\"New best validation accuracy: {best_val_acc:.2f}%\")\n",
        "            # Save model to disk\n",
        "            torch.save(model.state_dict(), 'best_resnet18_exponential_threshold.pth')\n",
        "\n",
        "        # Track batches dropped for this epoch\n",
        "        batches_dropped_this_epoch = 0\n",
        "\n",
        "        # After the first epoch, we start the batch dropping mechanism\n",
        "        if epoch > 1 and prev_epoch_std_df is not None:\n",
        "            # Find matching batch indices between previous and current epoch\n",
        "            common_indices = set(prev_epoch_std_df.index).intersection(set(current_std_df.index))\n",
        "\n",
        "            # Initialize list to track which batches to drop\n",
        "            batches_to_drop = []\n",
        "\n",
        "            # Compare mean standard deviations for each batch using exponential threshold\n",
        "            for idx in common_indices:\n",
        "                prev_mean_std = prev_epoch_std_df.loc[idx, 'mean_std']\n",
        "                curr_mean_std = current_std_df.loc[idx, 'mean_std']\n",
        "\n",
        "                # Update the batch std history for visualization\n",
        "                if idx not in batch_std_history:\n",
        "                    batch_std_history[idx] = []\n",
        "                batch_std_history[idx].append(curr_mean_std)\n",
        "\n",
        "                # If the absolute difference is below the exponential threshold, mark this batch for dropping\n",
        "                if abs(prev_mean_std - curr_mean_std) <= current_threshold:\n",
        "                    batches_to_drop.append(idx)\n",
        "\n",
        "            # Print information about dropped batches\n",
        "            if batches_to_drop:\n",
        "                print(f\"\\nDropping {len(batches_to_drop)} batches for the next epoch:\")\n",
        "                print(f\"Batch indices: {batches_to_drop[:10]}{'...' if len(batches_to_drop) > 10 else ''}\")\n",
        "                total_dropped_batches += len(batches_to_drop)\n",
        "                batches_dropped_this_epoch = len(batches_to_drop)\n",
        "            else:\n",
        "                print(\"\\nNo batches to drop for the next epoch\")\n",
        "\n",
        "            # Update batch indices for the next epoch (exclude the ones to drop)\n",
        "            batch_indices_to_use = [i for i in batch_indices_to_use if i not in set(batches_to_drop)]\n",
        "\n",
        "        # Save current dataframe for comparison with next epoch\n",
        "        prev_epoch_std_df = current_std_df\n",
        "\n",
        "        # Record batches dropped this epoch\n",
        "        total_batches_dropped_per_epoch.append(batches_dropped_this_epoch)\n",
        "\n",
        "        # Calculate percent of original batches remaining\n",
        "        percent_remaining = len(batch_indices_to_use) / total_original_batches * 100\n",
        "        print(f\"\\nBatches remaining: {len(batch_indices_to_use)}/{total_original_batches} ({percent_remaining:.2f}%)\")\n",
        "        print(f\"Total batches dropped so far: {total_dropped_batches}/{total_original_batches} ({total_dropped_batches/total_original_batches*100:.2f}%)\")\n",
        "\n",
        "        # Plot std distribution of remaining batches vs dropped batches for better analysis\n",
        "        if epoch % 10 == 0:  # Only plot every 10 epochs to avoid clutter\n",
        "            plot_batch_std_distribution(current_std_df, batch_indices_to_use, epoch)\n",
        "\n",
        "        # Early stopping if we've dropped too many batches\n",
        "        if len(batch_indices_to_use) <= total_original_batches * min_batch_percentage:\n",
        "            print(f\"\\nStopping early as fewer than {min_batch_percentage*100}% of original batches remain\")\n",
        "            break\n",
        "\n",
        "    # If we have a best model state, load it back\n",
        "    if best_model_state is not None:\n",
        "        model.load_state_dict(best_model_state)\n",
        "        print(f\"Loaded best model with validation accuracy: {best_val_acc:.2f}%\")\n",
        "\n",
        "    # Calculate total training time\n",
        "    total_training_end = time.time()\n",
        "    total_training_time = total_training_end - total_training_start\n",
        "    average_epoch_time = sum(epoch_times) / len(epoch_times) if epoch_times else 0\n",
        "\n",
        "    # Calculate Data Utilization Index (DUI)\n",
        "    total_epochs_completed = len(epoch_times)\n",
        "    sum_remaining_batches = sum(total_remaining_batches)\n",
        "    max_possible_batches = max_epochs * total_original_batches\n",
        "    data_utilization_index = sum_remaining_batches / max_possible_batches\n",
        "    data_savings_index = 1 - data_utilization_index\n",
        "    # Plot batch pruning analysis with exponential threshold\n",
        "    plot_batch_prune_analysis(threshold_history, total_remaining_batches,\n",
        "                            total_batches_dropped_per_epoch, pruned_percentage_history)\n",
        "\n",
        "    # Plot original batch dropping analysis for comparison\n",
        "    plot_batch_dropping_analysis(total_remaining_batches, total_batches_dropped_per_epoch, epoch_times, batch_std_history)\n",
        "\n",
        "    # Final stats on batch dropping\n",
        "    percent_dropped = total_dropped_batches / total_original_batches * 100\n",
        "    percent_remaining = 100 - percent_dropped\n",
        "    total_remaining_batch = total_original_batches - total_dropped_batches\n",
        "    print(f\"\\nData Utilization Index (DUI): {data_utilization_index:.4f}\")\n",
        "    print(f\"\\nData Savings Index (DSI): {data_savings_index:.4f}\")\n",
        "    print(f\"Sum of remaining batches across all epochs: {sum_remaining_batches}\")\n",
        "    print(f\"Maximum possible batches (max_epochs * total_batches): {max_possible_batches}\")\n",
        "\n",
        "    print(f\"\\nFinal Statistics:\")\n",
        "    print(f\"Total batches dropped: {total_dropped_batches}/{total_original_batches} ({percent_dropped:.2f}%)\")\n",
        "    print(f\"Total batches remaining: {total_remaining_batch}/{total_original_batches} ({percent_remaining:.2f}%)\")\n",
        "    print(f\"Final threshold: {threshold_history[-1]:.8f}\")\n",
        "\n",
        "    # Output model statistics and results\n",
        "    print(f\"\\nTraining Time Statistics:\")\n",
        "    print(f\"Total training time: {total_training_time:.2f} seconds ({total_training_time/60:.2f} minutes)\")\n",
        "    print(f\"Average epoch time: {average_epoch_time:.2f} seconds\")\n",
        "    print(f\"Number of epochs completed: {len(epoch_times)}\")\n",
        "\n",
        "    # Print model statistics\n",
        "    print(f\"\\nModel Statistics:\")\n",
        "    print(f\"Total parameters: {total_params:,}\")\n",
        "    if flops:\n",
        "        print(f\"FLOPs: {flops:,} ({flops/1e9:.2f} GFLOPs)\")\n",
        "    print(f\"Best validation accuracy: {best_val_acc:.2f}%\")\n",
        "\n",
        "    # Perform final validation on the VALIDATION set\n",
        "    print(f\"\\n{'='*50}\")\n",
        "    print(\"FINAL VALIDATION SET EVALUATION\")\n",
        "    print(f\"{'='*50}\")\n",
        "    final_val_loss, final_val_acc, final_val_targets, final_val_preds = validate(\n",
        "        model, device, val_loader, dataset_name=\"Final Validation\"\n",
        "    )\n",
        "\n",
        "    # Perform final evaluation on the TEST set\n",
        "    print(f\"\\n{'='*50}\")\n",
        "    print(\"FINAL TEST SET EVALUATION\")\n",
        "    print(f\"{'='*50}\")\n",
        "    final_test_loss, final_test_acc, final_test_targets, final_test_preds = validate(\n",
        "        model, device, test_loader, dataset_name=\"Final Test\"\n",
        "    )\n",
        "\n",
        "    print(f\"\\nFinal Results Summary:\")\n",
        "    print(f\"  Validation Loss: {final_val_loss:.4f}, Accuracy: {final_val_acc:.2f}%\")\n",
        "    print(f\"  Test Loss: {final_test_loss:.4f}, Accuracy: {final_test_acc:.2f}%\")\n",
        "\n",
        "    # Plot confusion matrix for TEST set evaluation\n",
        "    plot_confusion_matrix(final_test_targets, final_test_preds, class_names, title='Test Set Confusion Matrix')\n",
        "\n",
        "    # Plot training metrics\n",
        "    plot_training_metrics(epoch_times, train_losses, train_accs, val_losses, val_accs)\n",
        "\n",
        "    # Create and save final summary DataFrame with exponential threshold info\n",
        "    summary_data = {\n",
        "        'total_training_time': [total_training_time],\n",
        "        'average_epoch_time': [average_epoch_time],\n",
        "        'epochs_completed': [total_epochs_completed],\n",
        "        'best_validation_accuracy': [best_val_acc],\n",
        "        'final_validation_accuracy': [final_val_acc],\n",
        "        'final_test_accuracy': [final_test_acc],\n",
        "        'total_parameters': [total_params],\n",
        "        'data_utilization_index': [data_utilization_index],\n",
        "        'data_savings_index': [data_savings_index],\n",
        "        'total_batches_dropped': [total_dropped_batches],\n",
        "        'total_remaining_batches': [total_remaining_batch],\n",
        "        'batch_drop_percentage': [percent_dropped],\n",
        "        'final_threshold': [threshold_history[-1] if threshold_history else delta_start],\n",
        "        'learning_rate': [learning_rate],\n",
        "        'weight_decay': [weight_decay],\n",
        "        'train_samples': [45000],\n",
        "        'val_samples': [5000],\n",
        "        'test_samples': [10000]\n",
        "    }\n",
        "\n",
        "    summary_df = pd.DataFrame(summary_data)\n",
        "    summary_df.to_csv('exponential_threshold_batch_dropping_summary.csv', index=False)\n",
        "    print(\"\\nSummary saved to exponential_threshold_batch_dropping_summary.csv\")\n",
        "\n",
        "    # Save threshold history for further analysis\n",
        "    threshold_df = pd.DataFrame({\n",
        "        'epoch': range(1, len(threshold_history) + 1),\n",
        "        'threshold': threshold_history,\n",
        "        'remaining_batches': total_remaining_batches,\n",
        "        'batches_dropped': total_batches_dropped_per_epoch,\n",
        "        'pruned_percentage': pruned_percentage_history\n",
        "    })\n",
        "    threshold_df.to_csv('exponential_threshold_history.csv', index=False)\n",
        "    print(\"Threshold history saved to exponential_threshold_history.csv\")\n",
        "\n",
        "    return model, summary_df\n",
        "\n",
        "def plot_batch_std_distribution(std_df, remaining_batch_indices, epoch):\n",
        "    \"\"\"\n",
        "    Plots the distribution of standard deviations for remaining vs dropped batches\n",
        "    \"\"\"\n",
        "    plt.figure(figsize=(10, 6))\n",
        "\n",
        "    # Separate std values for remaining and dropped batches\n",
        "    remaining_stds = std_df.loc[std_df.index.isin(remaining_batch_indices), 'mean_std']\n",
        "    dropped_stds = std_df.loc[~std_df.index.isin(remaining_batch_indices), 'mean_std']\n",
        "\n",
        "    # Plot histograms\n",
        "    plt.hist(remaining_stds, bins=30, alpha=0.5, label='Remaining Batches')\n",
        "    if len(dropped_stds) > 0:\n",
        "        plt.hist(dropped_stds, bins=30, alpha=0.5, label='Dropped Batches')\n",
        "\n",
        "    plt.title(f'Batch Std Distribution - Epoch {epoch}')\n",
        "    plt.xlabel('Mean Standard Deviation')\n",
        "    plt.ylabel('Number of Batches')\n",
        "    plt.legend()\n",
        "    plt.grid(True, alpha=0.3)\n",
        "    plt.savefig(f'batch_std_distribution_epoch_{epoch}_woes.pdf')\n",
        "    plt.close()\n",
        "\n",
        "def plot_batch_dropping_analysis(remaining_batches, dropped_per_epoch, epoch_times, batch_std_history):\n",
        "    \"\"\"\n",
        "    Create comprehensive visualizations for batch dropping analysis\n",
        "    \"\"\"\n",
        "    plt.figure(figsize=(18, 12))\n",
        "\n",
        "    # Plot 1: Remaining batches over epochs\n",
        "    plt.subplot(2, 2, 1)\n",
        "    plt.plot(range(1, len(remaining_batches) + 1), remaining_batches, marker='o')\n",
        "    plt.title('Remaining Batches per Epoch')\n",
        "    plt.xlabel('Epoch')\n",
        "    plt.ylabel('Number of Batches')\n",
        "    plt.grid(True)\n",
        "\n",
        "    # Plot 2: Batches dropped per epoch\n",
        "    plt.subplot(2, 2, 2)\n",
        "    plt.bar(range(1, len(dropped_per_epoch) + 1), dropped_per_epoch)\n",
        "    plt.title('Batches Dropped per Epoch')\n",
        "    plt.xlabel('Epoch')\n",
        "    plt.ylabel('Number of Batches Dropped')\n",
        "    plt.grid(True)\n",
        "\n",
        "    # Plot 3: Training time vs remaining batches\n",
        "    plt.subplot(2, 2, 3)\n",
        "    plt.scatter(remaining_batches, epoch_times)\n",
        "    plt.title('Training Time vs Remaining Batches')\n",
        "    plt.xlabel('Number of Batches')\n",
        "    plt.ylabel('Epoch Time (seconds)')\n",
        "    plt.grid(True)\n",
        "\n",
        "    # Plot 4: Standard deviation trend for a sample of batches\n",
        "    plt.subplot(2, 2, 4)\n",
        "    # Sample up to 10 batches for visualization\n",
        "    sample_batches = list(batch_std_history.keys())[:10]\n",
        "    for batch_idx in sample_batches:\n",
        "        std_values = batch_std_history[batch_idx]\n",
        "        plt.plot(range(1, len(std_values) + 1), std_values, label=f'Batch {batch_idx}')\n",
        "\n",
        "    plt.title('Std Deviation Trend for Sample Batches')\n",
        "    plt.xlabel('Epoch')\n",
        "    plt.ylabel('Mean Std Deviation_woes')\n",
        "    plt.legend()\n",
        "    plt.grid(True)\n",
        "\n",
        "    plt.tight_layout()\n",
        "    plt.savefig('batch_dropping_analysis.pdf')\n",
        "    plt.close()\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    # Set random seeds for reproducibility\n",
        "    torch.manual_seed(42)\n",
        "    np.random.seed(42)\n",
        "    torch.backends.cudnn.deterministic = True\n",
        "    torch.backends.cudnn.benchmark = False\n",
        "\n",
        "    # Run the entire classification process\n",
        "    model, summary_df = run_cifar10_classification()"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# SVHN"
      ],
      "metadata": {
        "id": "75zumjwHs7wl"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n",
        "from torch.utils.data import DataLoader, Dataset, Subset\n",
        "from torchvision import datasets, transforms\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "from sklearn.metrics import classification_report, confusion_matrix\n",
        "import seaborn as sns\n",
        "import time\n",
        "import pandas as pd\n",
        "from torch.optim.lr_scheduler import CosineAnnealingLR\n",
        "from thop import profile  # For calculating FLOPs\n",
        "import math\n",
        "\n",
        "GLOBAL_TRAIN_BATCH_INDICES = None\n",
        "\n",
        "# GPU Configuration\n",
        "def configure_gpu():\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    print(f\"Using device: {device}\")\n",
        "\n",
        "    if torch.cuda.is_available():\n",
        "        print(f\"Num GPUs Available: {torch.cuda.device_count()}\")\n",
        "        torch.cuda.empty_cache()\n",
        "        print(\"GPU memory cleared\")\n",
        "\n",
        "    return device\n",
        "\n",
        "def load_images(batch_size=128, val_size=5000):\n",
        "    \"\"\"\n",
        "    Loads SVHN with train/val/test split:\n",
        "    - Training: 68,257 samples (73,257 - 5,000)\n",
        "    - Validation: 5,000 samples\n",
        "    - Test: 26,032 samples\n",
        "\n",
        "    SVHN images are 32x32 RGB, with labels 0-9 (digit classes).\n",
        "    Uses consistent batch composition with efficient sampling for training.\n",
        "    \"\"\"\n",
        "    # SVHN normalization statistics\n",
        "    svhn_mean = (0.4377, 0.4438, 0.4728)\n",
        "    svhn_std = (0.1980, 0.2010, 0.1970)\n",
        "\n",
        "    # Define transformations (SVHN typically uses lighter augmentation than CIFAR)\n",
        "    transform_train = transforms.Compose([\n",
        "        transforms.RandomCrop(32, padding=4),\n",
        "        transforms.ToTensor(),\n",
        "        transforms.Normalize(svhn_mean, svhn_std)\n",
        "    ])\n",
        "\n",
        "    transform_val_test = transforms.Compose([\n",
        "        transforms.ToTensor(),\n",
        "        transforms.Normalize(svhn_mean, svhn_std)\n",
        "    ])\n",
        "\n",
        "    # Load full training dataset (73,257 samples)\n",
        "    full_train_dataset = datasets.SVHN(\n",
        "        root='./data',\n",
        "        split='train',\n",
        "        download=True,\n",
        "        transform=transform_train\n",
        "    )\n",
        "\n",
        "    # Load validation dataset with test transforms (no augmentation)\n",
        "    full_train_dataset_for_val = datasets.SVHN(\n",
        "        root='./data',\n",
        "        split='train',\n",
        "        download=True,\n",
        "        transform=transform_val_test\n",
        "    )\n",
        "\n",
        "    # Load test dataset (26,032 samples)\n",
        "    test_dataset = datasets.SVHN(\n",
        "        root='./data',\n",
        "        split='test',\n",
        "        download=True,\n",
        "        transform=transform_val_test\n",
        "    )\n",
        "\n",
        "    # SVHN has 10 digit classes\n",
        "    class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']\n",
        "\n",
        "    # --- Split training data into train and validation ---\n",
        "    num_total_train = len(full_train_dataset)  # 73,257\n",
        "    indices = np.arange(num_total_train)\n",
        "    np.random.shuffle(indices)\n",
        "\n",
        "    val_indices = indices[:val_size]\n",
        "    train_indices = indices[val_size:]\n",
        "\n",
        "    print(f\"\\nData Split (SVHN):\")\n",
        "    print(f\"  Total training pool: {num_total_train}\")\n",
        "    print(f\"  Training samples: {len(train_indices)}\")\n",
        "    print(f\"  Validation samples: {len(val_indices)}\")\n",
        "    print(f\"  Test samples: {len(test_dataset)}\")\n",
        "    print(f\"  Number of classes: {len(class_names)}\")\n",
        "\n",
        "    # Create validation subset (using non-augmented transforms)\n",
        "    val_dataset = Subset(full_train_dataset_for_val, val_indices)\n",
        "\n",
        "    # --- Efficient Consistent Batch Handling for Training ---\n",
        "    np.random.shuffle(train_indices)\n",
        "\n",
        "    # Group into batches\n",
        "    batch_indices = [\n",
        "        train_indices[i:i + batch_size].tolist()\n",
        "        for i in range(0, len(train_indices), batch_size)\n",
        "    ]\n",
        "\n",
        "    global GLOBAL_TRAIN_BATCH_INDICES\n",
        "    GLOBAL_TRAIN_BATCH_INDICES = batch_indices\n",
        "\n",
        "    print(f\"\\nCreated consistent batch setup with {len(batch_indices)} training batches\")\n",
        "    print(f\"Each batch contains the same data samples across all epochs\")\n",
        "    print(f\"Intra-batch shuffling is enabled to prevent memorization\")\n",
        "\n",
        "    # Validation DataLoader\n",
        "    val_loader = DataLoader(\n",
        "        val_dataset,\n",
        "        batch_size=batch_size,\n",
        "        shuffle=False,\n",
        "        num_workers=4,\n",
        "        pin_memory=True\n",
        "    )\n",
        "\n",
        "    # Test DataLoader\n",
        "    test_loader = DataLoader(\n",
        "        test_dataset,\n",
        "        batch_size=batch_size,\n",
        "        shuffle=False,\n",
        "        num_workers=4,\n",
        "        pin_memory=True\n",
        "    )\n",
        "\n",
        "    return full_train_dataset, val_loader, test_loader, class_names\n",
        "\n",
        "\n",
        "def build_train_loader_for_epoch(full_train_dataset, active_batch_indices, batch_size, shuffle_within_batch=True):\n",
        "\n",
        "    global GLOBAL_TRAIN_BATCH_INDICES\n",
        "\n",
        "    active_batches = []\n",
        "    for logical_idx in active_batch_indices:\n",
        "        sample_indices = GLOBAL_TRAIN_BATCH_INDICES[logical_idx].copy()\n",
        "        if shuffle_within_batch:\n",
        "            np.random.shuffle(sample_indices)\n",
        "        active_batches.append((logical_idx, sample_indices))\n",
        "\n",
        "    flat_indices = []\n",
        "    logical_index_order = []\n",
        "    for logical_idx, sample_indices in active_batches:\n",
        "        flat_indices.extend(sample_indices)\n",
        "        logical_index_order.append(logical_idx)\n",
        "\n",
        "    class OrderedIndexSampler:\n",
        "        def __init__(self, indices):\n",
        "            self.indices = indices\n",
        "\n",
        "        def __iter__(self):\n",
        "            return iter(self.indices)\n",
        "\n",
        "        def __len__(self):\n",
        "            return len(self.indices)\n",
        "\n",
        "    sampler = OrderedIndexSampler(flat_indices)\n",
        "\n",
        "    train_loader = DataLoader(\n",
        "        full_train_dataset,\n",
        "        batch_size=batch_size,\n",
        "        sampler=sampler,\n",
        "        num_workers=4,\n",
        "        shuffle=False,\n",
        "        pin_memory=True,\n",
        "        drop_last=False\n",
        "    )\n",
        "\n",
        "    return train_loader, logical_index_order\n",
        "\n",
        "\n",
        "# ResNet model definitions\n",
        "class BasicBlock(nn.Module):\n",
        "    expansion = 1\n",
        "\n",
        "    def __init__(self, in_planes, planes, stride=1):\n",
        "        super(BasicBlock, self).__init__()\n",
        "        self.conv1 = nn.Conv2d(\n",
        "            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
        "        self.bn1 = nn.BatchNorm2d(planes)\n",
        "        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,\n",
        "                               stride=1, padding=1, bias=False)\n",
        "        self.bn2 = nn.BatchNorm2d(planes)\n",
        "\n",
        "        self.shortcut = nn.Sequential()\n",
        "        if stride != 1 or in_planes != self.expansion*planes:\n",
        "            self.shortcut = nn.Sequential(\n",
        "                nn.Conv2d(in_planes, self.expansion*planes,\n",
        "                          kernel_size=1, stride=stride, bias=False),\n",
        "                nn.BatchNorm2d(self.expansion*planes)\n",
        "            )\n",
        "\n",
        "    def forward(self, x):\n",
        "        out = F.relu(self.bn1(self.conv1(x)))\n",
        "        out = self.bn2(self.conv2(out))\n",
        "        out += self.shortcut(x)\n",
        "        out = F.relu(out)\n",
        "        return out\n",
        "\n",
        "\n",
        "class Bottleneck(nn.Module):\n",
        "    expansion = 4\n",
        "\n",
        "    def __init__(self, in_planes, planes, stride=1):\n",
        "        super(Bottleneck, self).__init__()\n",
        "        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)\n",
        "        self.bn1 = nn.BatchNorm2d(planes)\n",
        "        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,\n",
        "                               stride=stride, padding=1, bias=False)\n",
        "        self.bn2 = nn.BatchNorm2d(planes)\n",
        "        self.conv3 = nn.Conv2d(planes, self.expansion *\n",
        "                               planes, kernel_size=1, bias=False)\n",
        "        self.bn3 = nn.BatchNorm2d(self.expansion*planes)\n",
        "\n",
        "        self.shortcut = nn.Sequential()\n",
        "        if stride != 1 or in_planes != self.expansion*planes:\n",
        "            self.shortcut = nn.Sequential(\n",
        "                nn.Conv2d(in_planes, self.expansion*planes,\n",
        "                          kernel_size=1, stride=stride, bias=False),\n",
        "                nn.BatchNorm2d(self.expansion*planes)\n",
        "            )\n",
        "\n",
        "    def forward(self, x):\n",
        "        out = F.relu(self.bn1(self.conv1(x)))\n",
        "        out = F.relu(self.bn2(self.conv2(out)))\n",
        "        out = self.bn3(self.conv3(out))\n",
        "        out += self.shortcut(x)\n",
        "        out = F.relu(out)\n",
        "        return out\n",
        "\n",
        "\n",
        "class ResNet(nn.Module):\n",
        "    def __init__(self, block, num_blocks, num_classes=10):\n",
        "        super(ResNet, self).__init__()\n",
        "        self.in_planes = 64\n",
        "\n",
        "        self.layer_names = ['layer1', 'layer2', 'layer3', 'layer4']\n",
        "\n",
        "        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,\n",
        "                               stride=1, padding=1, bias=False)\n",
        "        self.bn1 = nn.BatchNorm2d(64)\n",
        "        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)\n",
        "        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)\n",
        "        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)\n",
        "        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)\n",
        "        self.avgpool = nn.AdaptiveAvgPool2d((1,1))\n",
        "        self.dropout = nn.Dropout(0.2)\n",
        "        self.linear = nn.Linear(512*block.expansion, num_classes)\n",
        "\n",
        "        self.conv_layers = [self.layer1, self.layer2, self.layer3, self.layer4]\n",
        "\n",
        "    def _make_layer(self, block, planes, num_blocks, stride):\n",
        "        strides = [stride] + [1]*(num_blocks-1)\n",
        "        layers = []\n",
        "        for stride in strides:\n",
        "            layers.append(block(self.in_planes, planes, stride))\n",
        "            self.in_planes = planes * block.expansion\n",
        "        return nn.Sequential(*layers)\n",
        "\n",
        "    def forward(self, x, return_activations=False):\n",
        "        activations = []\n",
        "\n",
        "        out = F.relu(self.bn1(self.conv1(x)))\n",
        "\n",
        "        out = self.layer1(out)\n",
        "        if return_activations:\n",
        "            activations.append(out.detach())\n",
        "\n",
        "        out = self.layer2(out)\n",
        "        if return_activations:\n",
        "            activations.append(out.detach())\n",
        "\n",
        "        out = self.layer3(out)\n",
        "        if return_activations:\n",
        "            activations.append(out.detach())\n",
        "\n",
        "        out = self.layer4(out)\n",
        "        if return_activations:\n",
        "            activations.append(out.detach())\n",
        "\n",
        "        out = self.avgpool(out)\n",
        "        out = out.view(out.size(0), -1)\n",
        "        out = self.dropout(out)\n",
        "        out = self.linear(out)\n",
        "\n",
        "        if return_activations:\n",
        "            return out, activations\n",
        "        return out\n",
        "\n",
        "    def create_emb(self, x):\n",
        "        out = F.relu(self.bn1(self.conv1(x)))\n",
        "        out = self.layer1(out)\n",
        "        out = self.layer2(out)\n",
        "        out = self.layer3(out)\n",
        "        out = self.layer4(out)\n",
        "        out = self.avgpool(out)\n",
        "        out = out.view(out.size(0), -1)\n",
        "        return out\n",
        "\n",
        "def ResNet18(num_classes=10):\n",
        "    \"\"\"Build a ResNet from BasicBlock with stage depths [2, 2, 2, 2].\"\"\"\n",
        "    stage_depths = [2, 2, 2, 2]\n",
        "    return ResNet(BasicBlock, stage_depths, num_classes=num_classes)\n",
        "\n",
        "def ResNet50(num_classes=10):\n",
        "    \"\"\"Build a ResNet from Bottleneck with stage depths [3, 4, 6, 3].\"\"\"\n",
        "    stage_depths = [3, 4, 6, 3]\n",
        "    return ResNet(Bottleneck, stage_depths, num_classes=num_classes)\n",
        "\n",
        "def calculate_model_parameters(model):\n",
        "    total_params = sum(p.numel() for p in model.parameters())\n",
        "    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
        "\n",
        "    print(f\"\\nModel Parameters:\")\n",
        "    print(f\"Total parameters: {total_params:,}\")\n",
        "    print(f\"Trainable parameters: {trainable_params:,}\")\n",
        "\n",
        "    return total_params, trainable_params\n",
        "\n",
        "def calculate_flops(model, input_size=(3, 32, 32)):\n",
        "    \"\"\"Profile one forward pass with ``profile`` and print/return its first result.\n",
        "\n",
        "    A single random input of shape ``(1, *input_size)`` is created and moved to\n",
        "    the device of the model's first parameter before profiling.\n",
        "    \"\"\"\n",
        "    model_device = next(model.parameters()).device\n",
        "    sample = torch.randn(1, *input_size).to(model_device)\n",
        "\n",
        "    # profile returns two values; the second one is not used here.\n",
        "    # NOTE(review): thop's README calls the first value MACs, not FLOPs - confirm units.\n",
        "    flops, _ = profile(model, inputs=(sample,), verbose=False)\n",
        "\n",
        "    print(f\"\\nModel Computational Requirements:\")\n",
        "    print(f\"FLOPs: {flops:,} ({flops/1e9:.2f} GFLOPs)\")\n",
        "\n",
        "    return flops\n",
        "\n",
        "\n",
        "def train_epoch(model, device, train_loader, optimizer, epoch, logical_index_order, scheduler=None, threshold=0.0001):\n",
        "\n",
        "    model.train()\n",
        "    train_loss = 0\n",
        "    correct = 0\n",
        "    total = 0\n",
        "\n",
        "    std_devs = {layer_name: [] for layer_name in model.layer_names}\n",
        "    batch_indices_list = []\n",
        "\n",
        "    start_time = time.time()\n",
        "    total_images_processed = 0\n",
        "    batch_processing_times = []\n",
        "\n",
        "    num_active_batches = len(logical_index_order)\n",
        "    print(f\"Using {num_active_batches} batches this epoch\")\n",
        "\n",
        "    for loader_batch_idx, (data, target) in enumerate(train_loader):\n",
        "        logical_batch_idx = logical_index_order[loader_batch_idx]\n",
        "\n",
        "        batch_start_time = time.time()\n",
        "        batch_indices_list.append(logical_batch_idx)\n",
        "\n",
        "        data, target = data.to(device), target.to(device)\n",
        "\n",
        "        # Optional: Mixup augmentation\n",
        "        if np.random.random() > 0.5:\n",
        "            lam = np.random.beta(0.2, 0.2)\n",
        "            rand_idx = torch.randperm(data.size(0)).to(device)\n",
        "            mixed_data = lam * data + (1 - lam) * data[rand_idx]\n",
        "            target_a, target_b = target, target[rand_idx]\n",
        "\n",
        "            optimizer.zero_grad()\n",
        "            output, activations = model(mixed_data, return_activations=True)\n",
        "            loss = lam * F.cross_entropy(output, target_a) + (1 - lam) * F.cross_entropy(output, target_b)\n",
        "        else:\n",
        "            optimizer.zero_grad()\n",
        "            output, activations = model(data, return_activations=True)\n",
        "            loss = F.cross_entropy(output, target)\n",
        "\n",
        "        for layer_idx, act in enumerate(activations):\n",
        "            std_dev = torch.std(act).item()\n",
        "            std_devs[model.layer_names[layer_idx]].append(std_dev)\n",
        "\n",
        "        loss.backward()\n",
        "        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
        "        optimizer.step()\n",
        "\n",
        "        train_loss += loss.item()\n",
        "        _, predicted = output.max(1)\n",
        "        total += target.size(0)\n",
        "        correct += predicted.eq(target).sum().item()\n",
        "\n",
        "        batch_end_time = time.time()\n",
        "        batch_time = batch_end_time - batch_start_time\n",
        "        batch_processing_times.append(batch_time)\n",
        "        total_images_processed += data.size(0)\n",
        "\n",
        "        if (loader_batch_idx + 1) % 100 == 0:\n",
        "            print(f'Train Epoch: {epoch} [{total_images_processed}/{num_active_batches * data.size(0)} '\n",
        "                  f'({100. * (loader_batch_idx + 1) / num_active_batches:.0f}%)]\\t'\n",
        "                  f'Loss: {loss.item():.6f}\\tAccuracy: {100. * correct / total:.2f}%')\n",
        "\n",
        "    end_time = time.time()\n",
        "    training_time = end_time - start_time\n",
        "\n",
        "    std_df = pd.DataFrame(std_devs)\n",
        "    std_df['batch_idx'] = batch_indices_list\n",
        "    std_df = std_df.set_index('batch_idx')\n",
        "    std_df['mean_std'] = std_df.mean(axis=1)\n",
        "\n",
        "    avg_batch_time = np.mean(batch_processing_times) if batch_processing_times else 0\n",
        "    images_per_second = total_images_processed / training_time if training_time > 0 else 0\n",
        "\n",
        "    epoch_loss = train_loss / len(batch_indices_list) if batch_indices_list else 0\n",
        "    epoch_acc = 100. * correct / total if total > 0 else 0\n",
        "\n",
        "    if scheduler is not None:\n",
        "        scheduler.step()\n",
        "\n",
        "    print(f\"\\nEpoch {epoch} completed in {training_time:.2f} seconds\")\n",
        "    print(f\"Training Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%\")\n",
        "    print(f\"Throughput: {images_per_second:.2f} images/second, Avg batch time: {avg_batch_time:.4f} s\")\n",
        "    print(f\"Current threshold: {threshold:.8f}\")\n",
        "    print(f\"Batch composition: CONSISTENT across epochs (same samples per batch)\")\n",
        "    print(f\"Intra-batch order: SHUFFLED to prevent memorization\")\n",
        "\n",
        "    throughput_metrics = {\n",
        "        'epoch_time': training_time,\n",
        "        'images_per_second': images_per_second,\n",
        "        'avg_batch_time': avg_batch_time,\n",
        "        'total_batches': len(batch_indices_list)\n",
        "    }\n",
        "\n",
        "    return epoch_loss, epoch_acc, std_df, batch_indices_list, throughput_metrics\n",
        "\n",
        "\n",
        "def validate(model, device, data_loader, dataset_name=\"Validation\"):\n",
        "    model.eval()\n",
        "    val_loss = 0\n",
        "    correct = 0\n",
        "    all_preds = []\n",
        "    all_targets = []\n",
        "\n",
        "    with torch.no_grad():\n",
        "        for data, target in data_loader:\n",
        "            data, target = data.to(device), target.to(device)\n",
        "\n",
        "            output = model(data)\n",
        "\n",
        "            val_loss += F.cross_entropy(output, target).item()\n",
        "\n",
        "            pred = output.argmax(dim=1)\n",
        "            correct += pred.eq(target).sum().item()\n",
        "\n",
        "            all_preds.extend(pred.cpu().numpy())\n",
        "            all_targets.extend(target.cpu().numpy())\n",
        "\n",
        "    val_loss /= len(data_loader)\n",
        "    val_acc = 100. * correct / len(data_loader.dataset)\n",
        "\n",
        "    print(f\"{dataset_name} Loss: {val_loss:.4f}, Accuracy: {val_acc:.2f}%\")\n",
        "\n",
        "    return val_loss, val_acc, all_targets, all_preds\n",
        "\n",
        "def plot_training_metrics(epoch_times, train_losses, train_accs, val_losses, val_accs):\n",
        "    \"\"\"Plot epoch time, accuracy and loss curves, save them as a PDF and show them.\n",
        "\n",
        "    Every panel uses a 1-based epoch axis, so the accuracy and loss curves line\n",
        "    up with the epoch-time panel. The figure is written to\n",
        "    'training_metrics_svhn.pdf' in the working directory.\n",
        "\n",
        "    Args:\n",
        "        epoch_times: Per-epoch training time in seconds.\n",
        "        train_losses, train_accs: Per-epoch training loss / accuracy.\n",
        "        val_losses, val_accs: Per-epoch validation loss / accuracy.\n",
        "    \"\"\"\n",
        "    def _epoch_axis(series):\n",
        "        # Epochs are numbered from 1 everywhere else in the training log.\n",
        "        return range(1, len(series) + 1)\n",
        "\n",
        "    plt.figure(figsize=(18, 12))\n",
        "\n",
        "    plt.subplot(2, 2, 1)\n",
        "    plt.plot(_epoch_axis(epoch_times), epoch_times, marker='o')\n",
        "    plt.title('Training Time per Epoch')\n",
        "    plt.xlabel('Epoch')\n",
        "    plt.ylabel('Time (seconds)')\n",
        "    plt.grid(True)\n",
        "\n",
        "    # Slot 2 of the 2x2 grid stays empty; panel positions are unchanged.\n",
        "    plt.subplot(2, 2, 3)\n",
        "    plt.plot(_epoch_axis(train_accs), train_accs, label='Train Accuracy')\n",
        "    plt.plot(_epoch_axis(val_accs), val_accs, label='Validation Accuracy')\n",
        "    plt.title('Model Accuracy (SVHN)')\n",
        "    plt.xlabel('Epoch')\n",
        "    plt.ylabel('Accuracy')\n",
        "    plt.legend()\n",
        "    plt.grid(True)\n",
        "\n",
        "    plt.subplot(2, 2, 4)\n",
        "    plt.plot(_epoch_axis(train_losses), train_losses, label='Train Loss')\n",
        "    plt.plot(_epoch_axis(val_losses), val_losses, label='Validation Loss')\n",
        "    plt.title('Model Loss (SVHN)')\n",
        "    plt.xlabel('Epoch')\n",
        "    plt.ylabel('Loss')\n",
        "    plt.legend()\n",
        "    plt.grid(True)\n",
        "\n",
        "    plt.tight_layout()\n",
        "    plt.savefig('training_metrics_svhn.pdf')\n",
        "    plt.show()\n",
        "\n",
        "def plot_confusion_matrix(all_targets, all_preds, class_names, title='Confusion Matrix'):\n",
        "    \"\"\"Draw the confusion matrix as an annotated heatmap, save it as a PDF and show it.\n",
        "\n",
        "    Rows are true labels and columns are predicted labels; both axes are\n",
        "    labelled with ``class_names``. The figure is written to\n",
        "    'confusion_matrix_svhn.pdf' regardless of ``title``.\n",
        "    \"\"\"\n",
        "    matrix = confusion_matrix(all_targets, all_preds)\n",
        "\n",
        "    fig, ax = plt.subplots(figsize=(10, 8))\n",
        "    heatmap_kwargs = dict(annot=True, fmt='d', cmap='Blues',\n",
        "                          xticklabels=class_names,\n",
        "                          yticklabels=class_names)\n",
        "    sns.heatmap(matrix, ax=ax, **heatmap_kwargs)\n",
        "    ax.set_title(title)\n",
        "    ax.set_xlabel('Predicted Label')\n",
        "    ax.set_ylabel('True Label')\n",
        "    fig.tight_layout()\n",
        "    fig.savefig('confusion_matrix_svhn.pdf')\n",
        "    plt.show()\n",
        "\n",
        "def calculate_exponential_threshold(delta_start, delta_end, T, t):\n",
        "    \"\"\"\n",
        "    Calculate threshold using exponential function: δ(t) = δ_initial * e^(alpha * t)\n",
        "    \"\"\"\n",
        "    alpha = math.log(delta_end / delta_start) / T\n",
        "    threshold = delta_start * math.exp(alpha * t)\n",
        "    return threshold\n",
        "\n",
        "\n",
        "def plot_batch_prune_analysis(threshold_history, remaining_batches_history,\n",
        "                            dropped_per_epoch_history, pruned_percentage_history):\n",
        "    \"\"\"Draw a 3x2 overview of the pruning run, save it as a PDF and show it.\n",
        "\n",
        "    Panels: threshold per epoch (log y), remaining batches per epoch, batches\n",
        "    dropped per epoch, pruned percentage per epoch, threshold vs remaining\n",
        "    batches (log x), and the threshold again on a semilog axis. Epochs are\n",
        "    numbered from 1. The figure is written to\n",
        "    'batch_prune_exponential_threshold_analysis_svhn.pdf'.\n",
        "    \"\"\"\n",
        "    def epoch_axis(series):\n",
        "        return range(1, len(series) + 1)\n",
        "\n",
        "    def annotate(title, xlabel, ylabel):\n",
        "        # Shared labelling and light grid for the current panel.\n",
        "        plt.title(title)\n",
        "        plt.xlabel(xlabel)\n",
        "        plt.ylabel(ylabel)\n",
        "        plt.grid(True, alpha=0.3)\n",
        "\n",
        "    plt.figure(figsize=(20, 12))\n",
        "\n",
        "    # Panel 1: threshold per epoch on a log y-axis.\n",
        "    plt.subplot(3, 2, 1)\n",
        "    plt.plot(epoch_axis(threshold_history), threshold_history, marker='o', color='red')\n",
        "    plt.yscale('log')\n",
        "    annotate('Exponential Threshold Evolution Over Epochs', 'Epoch', 'Threshold Value')\n",
        "\n",
        "    # Panel 2: batches still in use.\n",
        "    plt.subplot(3, 2, 2)\n",
        "    plt.plot(epoch_axis(remaining_batches_history), remaining_batches_history, marker='o', color='blue')\n",
        "    annotate('Remaining Batches per Epoch', 'Epoch', 'Number of Batches')\n",
        "\n",
        "    # Panel 3: batches dropped after each epoch.\n",
        "    plt.subplot(3, 2, 3)\n",
        "    plt.bar(epoch_axis(dropped_per_epoch_history), dropped_per_epoch_history, color='orange', alpha=0.7)\n",
        "    annotate('Batches Dropped per Epoch', 'Epoch', 'Number of Batches Dropped')\n",
        "\n",
        "    # Panel 4: pruned share of the original batches.\n",
        "    plt.subplot(3, 2, 4)\n",
        "    plt.plot(epoch_axis(pruned_percentage_history), pruned_percentage_history,\n",
        "             marker='d', color='green', linewidth=2)\n",
        "    annotate('Cumulative Pruned Percentage Over Epochs', 'Epoch', 'Pruned Percentage (%)')\n",
        "\n",
        "    # Panel 5: threshold against remaining batches, log x-axis.\n",
        "    plt.subplot(3, 2, 5)\n",
        "    plt.scatter(threshold_history, remaining_batches_history, alpha=0.6, color='coral')\n",
        "    plt.xscale('log')\n",
        "    annotate('Threshold vs Remaining Batches', 'Threshold Value', 'Remaining Batches')\n",
        "\n",
        "    # Panel 6: threshold growth on a semilog axis.\n",
        "    plt.subplot(3, 2, 6)\n",
        "    plt.semilogy(epoch_axis(threshold_history), threshold_history, marker='s', color='purple', linewidth=2)\n",
        "    annotate('Exponential Threshold Growth Pattern', 'Epoch', 'Threshold Value (log scale)')\n",
        "\n",
        "    plt.tight_layout()\n",
        "    plt.savefig('batch_prune_exponential_threshold_analysis_svhn.pdf', dpi=300, bbox_inches='tight')\n",
        "    plt.show()\n",
        "\n",
        "def run_svhn_classification():\n",
        "    \"\"\"Train ResNet18 with threshold-based batch pruning and report the results.\n",
        "\n",
        "    From the second epoch on, a batch is removed from all later epochs when its\n",
        "    ``mean_std`` activation statistic changed by no more than the current\n",
        "    exponential threshold since the previous epoch. The best weights by\n",
        "    validation accuracy are restored before the final validation and test\n",
        "    evaluation. Plots, the best checkpoint and two CSV summaries are written\n",
        "    to the working directory.\n",
        "\n",
        "    Returns:\n",
        "        ``(model, summary_df)``: the trained model and a one-row summary frame.\n",
        "    \"\"\"\n",
        "    # Configure GPU\n",
        "    device = configure_gpu()\n",
        "\n",
        "    # Hyperparameters\n",
        "    total_epochs = 200\n",
        "    batch_size = 128\n",
        "    learning_rate = 0.05\n",
        "    weight_decay = 5e-4\n",
        "    momentum = 0.9\n",
        "    val_size = 5000  # used for the data split and echoed in the summary\n",
        "\n",
        "    # Exponential threshold parameters\n",
        "    delta_start = 0.000001   # starting δ\n",
        "    delta_end = 0.00005      # ending δ\n",
        "\n",
        "    min_batch_percentage = 0\n",
        "\n",
        "    # Load SVHN data\n",
        "    full_train_dataset, val_loader, test_loader, class_names = load_images(batch_size, val_size=val_size)\n",
        "\n",
        "    global GLOBAL_TRAIN_BATCH_INDICES\n",
        "    if GLOBAL_TRAIN_BATCH_INDICES is None:\n",
        "        # presumably populated by load_images - fail with a clear message if it was not\n",
        "        raise RuntimeError(\"GLOBAL_TRAIN_BATCH_INDICES is not set; load_images was expected to populate it\")\n",
        "    total_original_batches = len(GLOBAL_TRAIN_BATCH_INDICES)\n",
        "    max_epochs = total_epochs\n",
        "\n",
        "    # Compute actual train size for summary\n",
        "    train_size = sum(len(b) for b in GLOBAL_TRAIN_BATCH_INDICES)\n",
        "\n",
        "    print(f\"Total original training batches per epoch: {total_original_batches}\")\n",
        "    print(f\"Maximum number of epochs: {max_epochs}\")\n",
        "\n",
        "    # Create ResNet18 model with 10 output classes\n",
        "    model = ResNet18(num_classes=10).to(device)\n",
        "\n",
        "    print(model)\n",
        "\n",
        "    total_params, trainable_params = calculate_model_parameters(model)\n",
        "\n",
        "    try:\n",
        "        flops = calculate_flops(model)\n",
        "    except ImportError:\n",
        "        print(\"thop package not installed. Skipping FLOPs calculation.\")\n",
        "        flops = None\n",
        "\n",
        "    optimizer = optim.SGD(model.parameters(),\n",
        "                         lr=learning_rate,\n",
        "                         momentum=momentum,\n",
        "                         weight_decay=weight_decay,\n",
        "                         nesterov=True)\n",
        "\n",
        "    scheduler = CosineAnnealingLR(optimizer, T_max=total_epochs, eta_min=1e-5)\n",
        "\n",
        "    # Training and validation history\n",
        "    train_losses = []\n",
        "    train_accs = []\n",
        "    val_losses = []\n",
        "    val_accs = []\n",
        "\n",
        "    epoch_times = []\n",
        "    batch_counts = []\n",
        "\n",
        "    # Initialize tracking for batch dropping\n",
        "    prev_epoch_std_df = None\n",
        "    batch_indices_to_use = list(range(total_original_batches))\n",
        "    total_dropped_batches = 0\n",
        "    total_remaining_batches = []\n",
        "    total_batches_dropped_per_epoch = []\n",
        "\n",
        "    threshold_history = []\n",
        "    pruned_percentage_history = []\n",
        "\n",
        "    batch_std_history = {}\n",
        "\n",
        "    total_training_start = time.time()\n",
        "\n",
        "    best_val_acc = 0\n",
        "    best_model_state = None\n",
        "\n",
        "    for epoch in range(1, total_epochs + 1):\n",
        "        print(f\"\\n{'='*50}\")\n",
        "        print(f\"EPOCH {epoch}/{total_epochs}\")\n",
        "        print(f\"{'='*50}\")\n",
        "\n",
        "        current_threshold = calculate_exponential_threshold(delta_start, delta_end, total_epochs, epoch)\n",
        "\n",
        "        threshold_history.append(current_threshold)\n",
        "\n",
        "        print(f\"Current threshold (δ({epoch})): {current_threshold:.8f}\")\n",
        "\n",
        "        train_loader, logical_index_order = build_train_loader_for_epoch(\n",
        "            full_train_dataset, batch_indices_to_use, batch_size, shuffle_within_batch=True\n",
        "        )\n",
        "\n",
        "        # Train for one epoch\n",
        "        train_loss, train_acc, current_std_df, used_batch_indices, throughput_metrics = train_epoch(\n",
        "            model, device, train_loader, optimizer, epoch,\n",
        "            logical_index_order=logical_index_order,\n",
        "            scheduler=scheduler,\n",
        "            threshold=current_threshold\n",
        "        )\n",
        "\n",
        "        epoch_times.append(throughput_metrics['epoch_time'])\n",
        "        batch_counts.append(throughput_metrics['total_batches'])\n",
        "        train_losses.append(train_loss)\n",
        "        train_accs.append(train_acc)\n",
        "\n",
        "        current_remaining_batches = len(used_batch_indices) if used_batch_indices is not None else total_original_batches\n",
        "        total_remaining_batches.append(current_remaining_batches)\n",
        "\n",
        "        current_pruned_percentage = (1 - current_remaining_batches / total_original_batches) * 100\n",
        "        pruned_percentage_history.append(current_pruned_percentage)\n",
        "\n",
        "        # Evaluate on VALIDATION set\n",
        "        val_loss, val_acc, _, _ = validate(model, device, val_loader, dataset_name=\"Validation\")\n",
        "        val_losses.append(val_loss)\n",
        "        val_accs.append(val_acc)\n",
        "\n",
        "        # Save best model\n",
        "        if val_acc > best_val_acc:\n",
        "            best_val_acc = val_acc\n",
        "            # Keep a CPU copy in memory so the best weights can be restored after training.\n",
        "            best_model_state = {key: value.cpu().clone() for key, value in model.state_dict().items()}\n",
        "            print(f\"New best validation accuracy: {best_val_acc:.2f}%\")\n",
        "            torch.save(model.state_dict(), 'best_resnet18_svhn_exponential_threshold.pth')\n",
        "\n",
        "        batches_dropped_this_epoch = 0\n",
        "\n",
        "        # After the first epoch, start the batch dropping mechanism\n",
        "        if epoch > 1 and prev_epoch_std_df is not None:\n",
        "            # Only batches trained in both the previous and the current epoch can be compared.\n",
        "            common_indices = set(prev_epoch_std_df.index).intersection(set(current_std_df.index))\n",
        "\n",
        "            batches_to_drop = []\n",
        "\n",
        "            for idx in common_indices:\n",
        "                prev_mean_std = prev_epoch_std_df.loc[idx, 'mean_std']\n",
        "                curr_mean_std = current_std_df.loc[idx, 'mean_std']\n",
        "\n",
        "                if idx not in batch_std_history:\n",
        "                    batch_std_history[idx] = []\n",
        "                batch_std_history[idx].append(curr_mean_std)\n",
        "\n",
        "                # A batch whose statistic barely moved is dropped from later epochs.\n",
        "                if abs(prev_mean_std - curr_mean_std) <= current_threshold:\n",
        "                    batches_to_drop.append(idx)\n",
        "\n",
        "            if batches_to_drop:\n",
        "                print(f\"\\nDropping {len(batches_to_drop)} batches for the next epoch:\")\n",
        "                print(f\"Batch indices: {batches_to_drop[:10]}{'...' if len(batches_to_drop) > 10 else ''}\")\n",
        "                total_dropped_batches += len(batches_to_drop)\n",
        "                batches_dropped_this_epoch = len(batches_to_drop)\n",
        "            else:\n",
        "                print(\"\\nNo batches to drop for the next epoch\")\n",
        "\n",
        "            # Build the lookup set once instead of once per remaining batch.\n",
        "            dropped_lookup = set(batches_to_drop)\n",
        "            batch_indices_to_use = [i for i in batch_indices_to_use if i not in dropped_lookup]\n",
        "\n",
        "        prev_epoch_std_df = current_std_df\n",
        "\n",
        "        total_batches_dropped_per_epoch.append(batches_dropped_this_epoch)\n",
        "\n",
        "        percent_remaining = len(batch_indices_to_use) / total_original_batches * 100\n",
        "        print(f\"\\nBatches remaining: {len(batch_indices_to_use)}/{total_original_batches} ({percent_remaining:.2f}%)\")\n",
        "        print(f\"Total batches dropped so far: {total_dropped_batches}/{total_original_batches} ({total_dropped_batches/total_original_batches*100:.2f}%)\")\n",
        "\n",
        "        if epoch % 10 == 0:\n",
        "            plot_batch_std_distribution(current_std_df, batch_indices_to_use, epoch)\n",
        "\n",
        "        # Early stopping if we've dropped too many batches\n",
        "        if len(batch_indices_to_use) <= total_original_batches * min_batch_percentage:\n",
        "            print(f\"\\nStopping early as fewer than {min_batch_percentage*100}% of original batches remain\")\n",
        "            break\n",
        "\n",
        "    # Load best model\n",
        "    if best_model_state is not None:\n",
        "        model.load_state_dict(best_model_state)\n",
        "        print(f\"Loaded best model with validation accuracy: {best_val_acc:.2f}%\")\n",
        "\n",
        "    # Calculate total training time\n",
        "    total_training_end = time.time()\n",
        "    total_training_time = total_training_end - total_training_start\n",
        "    average_epoch_time = sum(epoch_times) / len(epoch_times) if epoch_times else 0\n",
        "\n",
        "    # Calculate Data Utilization Index (DUI)\n",
        "    total_epochs_completed = len(epoch_times)\n",
        "    sum_remaining_batches = sum(total_remaining_batches)\n",
        "    max_possible_batches = max_epochs * total_original_batches\n",
        "    data_utilization_index = sum_remaining_batches / max_possible_batches\n",
        "    data_savings_index = 1 - data_utilization_index\n",
        "\n",
        "    # Plot batch pruning analysis\n",
        "    plot_batch_prune_analysis(threshold_history, total_remaining_batches,\n",
        "                            total_batches_dropped_per_epoch, pruned_percentage_history)\n",
        "\n",
        "    plot_batch_dropping_analysis(total_remaining_batches, total_batches_dropped_per_epoch, epoch_times, batch_std_history)\n",
        "\n",
        "    # Final stats\n",
        "    percent_dropped = total_dropped_batches / total_original_batches * 100\n",
        "    percent_remaining = 100 - percent_dropped\n",
        "    total_remaining_batch = total_original_batches - total_dropped_batches\n",
        "    # Same fallback for the printout and the saved summary.\n",
        "    final_threshold = threshold_history[-1] if threshold_history else delta_start\n",
        "    print(f\"\\nData Utilization Index (DUI): {data_utilization_index:.4f}\")\n",
        "    print(f\"\\nData Savings Index (DSI): {data_savings_index:.4f}\")\n",
        "    print(f\"Sum of remaining batches across all epochs: {sum_remaining_batches}\")\n",
        "    print(f\"Maximum possible batches (max_epochs * total_batches): {max_possible_batches}\")\n",
        "\n",
        "    print(f\"\\nFinal Statistics:\")\n",
        "    print(f\"Total batches dropped: {total_dropped_batches}/{total_original_batches} ({percent_dropped:.2f}%)\")\n",
        "    print(f\"Total batches remaining: {total_remaining_batch}/{total_original_batches} ({percent_remaining:.2f}%)\")\n",
        "    print(f\"Final threshold: {final_threshold:.8f}\")\n",
        "\n",
        "    print(f\"\\nTraining Time Statistics:\")\n",
        "    print(f\"Total training time: {total_training_time:.2f} seconds ({total_training_time/60:.2f} minutes)\")\n",
        "    print(f\"Average epoch time: {average_epoch_time:.2f} seconds\")\n",
        "    print(f\"Number of epochs completed: {len(epoch_times)}\")\n",
        "\n",
        "    print(f\"\\nModel Statistics:\")\n",
        "    print(f\"Total parameters: {total_params:,}\")\n",
        "    if flops:\n",
        "        print(f\"FLOPs: {flops:,} ({flops/1e9:.2f} GFLOPs)\")\n",
        "    print(f\"Best validation accuracy: {best_val_acc:.2f}%\")\n",
        "\n",
        "    # Final validation\n",
        "    print(f\"\\n{'='*50}\")\n",
        "    print(\"FINAL VALIDATION SET EVALUATION\")\n",
        "    print(f\"{'='*50}\")\n",
        "    final_val_loss, final_val_acc, final_val_targets, final_val_preds = validate(\n",
        "        model, device, val_loader, dataset_name=\"Final Validation\"\n",
        "    )\n",
        "\n",
        "    # Final test\n",
        "    print(f\"\\n{'='*50}\")\n",
        "    print(\"FINAL TEST SET EVALUATION\")\n",
        "    print(f\"{'='*50}\")\n",
        "    final_test_loss, final_test_acc, final_test_targets, final_test_preds = validate(\n",
        "        model, device, test_loader, dataset_name=\"Final Test\"\n",
        "    )\n",
        "\n",
        "    print(f\"\\nFinal Results Summary:\")\n",
        "    print(f\"  Validation Loss: {final_val_loss:.4f}, Accuracy: {final_val_acc:.2f}%\")\n",
        "    print(f\"  Test Loss: {final_test_loss:.4f}, Accuracy: {final_test_acc:.2f}%\")\n",
        "\n",
        "    # Plot confusion matrix for TEST set\n",
        "    plot_confusion_matrix(final_test_targets, final_test_preds, class_names, title='SVHN Test Set Confusion Matrix')\n",
        "\n",
        "    # Plot training metrics\n",
        "    plot_training_metrics(epoch_times, train_losses, train_accs, val_losses, val_accs)\n",
        "\n",
        "    # Save summary\n",
        "    summary_data = {\n",
        "        'dataset': ['SVHN'],\n",
        "        'num_classes': [10],\n",
        "        'total_training_time': [total_training_time],\n",
        "        'average_epoch_time': [average_epoch_time],\n",
        "        'epochs_completed': [total_epochs_completed],\n",
        "        'best_validation_accuracy': [best_val_acc],\n",
        "        'final_validation_accuracy': [final_val_acc],\n",
        "        'final_test_accuracy': [final_test_acc],\n",
        "        'total_parameters': [total_params],\n",
        "        'data_utilization_index': [data_utilization_index],\n",
        "        'data_savings_index': [data_savings_index],\n",
        "        'total_batches_dropped': [total_dropped_batches],\n",
        "        'total_remaining_batches': [total_remaining_batch],\n",
        "        'batch_drop_percentage': [percent_dropped],\n",
        "        'final_threshold': [final_threshold],\n",
        "        'learning_rate': [learning_rate],\n",
        "        'weight_decay': [weight_decay],\n",
        "        'train_samples': [train_size],\n",
        "        'val_samples': [val_size],\n",
        "        'test_samples': [len(test_loader.dataset)]\n",
        "    }\n",
        "\n",
        "    summary_df = pd.DataFrame(summary_data)\n",
        "    summary_df.to_csv('exponential_threshold_batch_dropping_summary_svhn.csv', index=False)\n",
        "    print(\"\\nSummary saved to exponential_threshold_batch_dropping_summary_svhn.csv\")\n",
        "\n",
        "    # Save threshold history\n",
        "    threshold_df = pd.DataFrame({\n",
        "        'epoch': range(1, len(threshold_history) + 1),\n",
        "        'threshold': threshold_history,\n",
        "        'remaining_batches': total_remaining_batches,\n",
        "        'batches_dropped': total_batches_dropped_per_epoch,\n",
        "        'pruned_percentage': pruned_percentage_history\n",
        "    })\n",
        "    threshold_df.to_csv('exponential_threshold_history_svhn.csv', index=False)\n",
        "    print(\"Threshold history saved to exponential_threshold_history_svhn.csv\")\n",
        "\n",
        "    return model, summary_df\n",
        "\n",
        "def plot_batch_std_distribution(std_df, remaining_batch_indices, epoch):\n",
        "    \"\"\"Save a histogram of per-batch ``mean_std`` values, split by pruning status.\n",
        "\n",
        "    Rows of ``std_df`` whose index is in ``remaining_batch_indices`` are drawn\n",
        "    as remaining batches, all other rows as dropped batches. The figure is\n",
        "    written to 'batch_std_distribution_epoch_<epoch>_svhn.pdf' and then closed\n",
        "    without being shown.\n",
        "    \"\"\"\n",
        "    plt.figure(figsize=(10, 6))\n",
        "\n",
        "    # Split the rows once by whether their batch index is still in use.\n",
        "    still_used = std_df.index.isin(remaining_batch_indices)\n",
        "    kept_stds = std_df.loc[still_used, 'mean_std']\n",
        "    pruned_stds = std_df.loc[~still_used, 'mean_std']\n",
        "\n",
        "    hist_style = dict(bins=30, alpha=0.5)\n",
        "    plt.hist(kept_stds, label='Remaining Batches', **hist_style)\n",
        "    if len(pruned_stds) > 0:\n",
        "        plt.hist(pruned_stds, label='Dropped Batches', **hist_style)\n",
        "\n",
        "    plt.title(f'Batch Std Distribution - Epoch {epoch} (SVHN)')\n",
        "    plt.xlabel('Mean Standard Deviation')\n",
        "    plt.ylabel('Number of Batches')\n",
        "    plt.legend()\n",
        "    plt.grid(True, alpha=0.3)\n",
        "    plt.savefig(f'batch_std_distribution_epoch_{epoch}_svhn.pdf')\n",
        "    plt.close()\n",
        "\n",
        "def plot_batch_dropping_analysis(remaining_batches, dropped_per_epoch, epoch_times, batch_std_history):\n",
        "    \"\"\"Save a 2x2 summary figure of the batch-dropping run.\n",
        "\n",
        "    Panels (left to right, top to bottom):\n",
        "        1. Line plot of remaining batches per epoch.\n",
        "        2. Bar chart of batches dropped per epoch.\n",
        "        3. Scatter of epoch time against the number of remaining batches.\n",
        "        4. Std-deviation trend for up to the first 10 batches in\n",
        "           ``batch_std_history``.\n",
        "\n",
        "    Args:\n",
        "        remaining_batches: Sequence with one batch count per epoch.\n",
        "        dropped_per_epoch: Sequence with one dropped-batch count per epoch.\n",
        "        epoch_times: Sequence of epoch durations; paired element-wise with\n",
        "            ``remaining_batches`` in the scatter plot, so both must have the\n",
        "            same length.\n",
        "        batch_std_history: Mapping from batch index to that batch's sequence\n",
        "            of std values.\n",
        "\n",
        "    The figure is written to 'batch_dropping_analysis_svhn.pdf' and then\n",
        "    closed; nothing is returned.\n",
        "    \"\"\"\n",
        "    fig, axes = plt.subplots(2, 2, figsize=(18, 12))\n",
        "    (ax_remaining, ax_dropped), (ax_time, ax_std) = axes\n",
        "\n",
        "    ax_remaining.plot(range(1, len(remaining_batches) + 1), remaining_batches, marker='o')\n",
        "    ax_remaining.set_title('Remaining Batches per Epoch')\n",
        "    ax_remaining.set_xlabel('Epoch')\n",
        "    ax_remaining.set_ylabel('Number of Batches')\n",
        "    ax_remaining.grid(True)\n",
        "\n",
        "    ax_dropped.bar(range(1, len(dropped_per_epoch) + 1), dropped_per_epoch)\n",
        "    ax_dropped.set_title('Batches Dropped per Epoch')\n",
        "    ax_dropped.set_xlabel('Epoch')\n",
        "    ax_dropped.set_ylabel('Number of Batches Dropped')\n",
        "    ax_dropped.grid(True)\n",
        "\n",
        "    ax_time.scatter(remaining_batches, epoch_times)\n",
        "    ax_time.set_title('Training Time vs Remaining Batches')\n",
        "    ax_time.set_xlabel('Number of Batches')\n",
        "    ax_time.set_ylabel('Epoch Time (seconds)')\n",
        "    ax_time.grid(True)\n",
        "\n",
        "    # Cap the trend panel at the first 10 batches so the plot stays readable.\n",
        "    sample_batches = list(batch_std_history.keys())[:10]\n",
        "    for batch_idx in sample_batches:\n",
        "        std_values = batch_std_history[batch_idx]\n",
        "        ax_std.plot(range(1, len(std_values) + 1), std_values, label=f'Batch {batch_idx}')\n",
        "\n",
        "    ax_std.set_title('Std Deviation Trend for Sample Batches')\n",
        "    ax_std.set_xlabel('Epoch')\n",
        "    ax_std.set_ylabel('Mean Std Deviation')\n",
        "    # With no history there are no labelled lines; calling legend() would only\n",
        "    # emit matplotlib's \"No artists with labels found\" warning.\n",
        "    if sample_batches:\n",
        "        ax_std.legend()\n",
        "    ax_std.grid(True)\n",
        "\n",
        "    fig.tight_layout()\n",
        "    fig.savefig('batch_dropping_analysis_svhn.pdf')\n",
        "    plt.close(fig)\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    # Reproducibility: pin cuDNN to deterministic kernels and turn off its\n",
        "    # autotuner, which may otherwise pick different algorithms between runs.\n",
        "    torch.backends.cudnn.benchmark = False\n",
        "    torch.backends.cudnn.deterministic = True\n",
        "\n",
        "    # Seed the NumPy and PyTorch global RNGs with the same fixed value.\n",
        "    np.random.seed(42)\n",
        "    torch.manual_seed(42)\n",
        "\n",
        "    # Run the full SVHN experiment; keep the trained model and summary table.\n",
        "    model, summary_df = run_svhn_classification()"
      ],
      "metadata": {
        "id": "vAQYU9-Rs9jb"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}