{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PTrFma-Wh8Dj"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import torchvision\n",
        "import torchvision.transforms as transforms\n",
        "from torch.utils.data import DataLoader, SubsetRandomSampler\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "from sklearn.metrics import confusion_matrix, classification_report\n",
        "import seaborn as sns\n",
        "from torchvision.models import resnet18\n",
        "\n",
        "# Use the first CUDA device when one is available, otherwise fall back to CPU.\n",
        "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
        "print(f\"Using device: {device}\")\n",
        "\n",
        "# Seed the torch and NumPy global RNGs (the index shuffle below uses NumPy).\n",
        "torch.manual_seed(42)\n",
        "np.random.seed(42)\n",
        "\n",
        "# Display name for each of the ten class indices, in label order.\n",
        "class_names = [\n",
        "    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',\n",
        "    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot',\n",
        "]\n",
        "\n",
        "# Per-image preprocessing: to tensor, normalise with mean 0.5 / std 0.5,\n",
        "# then tile the channel axis three times to get a 3-channel tensor.\n",
        "transform = transforms.Compose([\n",
        "    transforms.ToTensor(),\n",
        "    transforms.Normalize((0.5,), (0.5,)),\n",
        "    transforms.Lambda(lambda x: x.repeat(3, 1, 1)),\n",
        "])\n",
        "\n",
        "# Both splits share the same root directory, download flag and transform.\n",
        "dataset_kwargs = dict(root='./data', download=True, transform=transform)\n",
        "train_dataset = torchvision.datasets.FashionMNIST(train=True, **dataset_kwargs)\n",
        "test_dataset = torchvision.datasets.FashionMNIST(train=False, **dataset_kwargs)\n",
        "\n",
        "# Shuffle all training indices once, then carve off the first 20% for validation.\n",
        "train_size = len(train_dataset)\n",
        "indices = list(range(train_size))\n",
        "np.random.shuffle(indices)\n",
        "\n",
        "split = int(np.floor(0.2 * train_size))\n",
        "val_indices = indices[:split]\n",
        "train_indices = indices[split:]\n",
        "\n",
        "train_sampler = SubsetRandomSampler(train_indices)\n",
        "val_sampler = SubsetRandomSampler(val_indices)\n",
        "\n",
        "batch_size = 128\n",
        "\n",
        "# Settings common to every loader.\n",
        "loader_kwargs = dict(batch_size=batch_size, num_workers=2)\n",
        "\n",
        "# Train and validation loaders read disjoint index subsets of train_dataset.\n",
        "train_loader = DataLoader(train_dataset, sampler=train_sampler, **loader_kwargs)\n",
        "val_loader = DataLoader(train_dataset, sampler=val_sampler, **loader_kwargs)\n",
        "test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)\n",
        "\n",
        "from torchvision.models import ResNet18_Weights\n",
        "\n",
        "# ImageNet-pretrained ResNet-18 with the final layer swapped for a 10-way head.\n",
        "model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)\n",
        "model.fc = nn.Linear(model.fc.in_features, 10)\n",
        "model = model.to(device)\n",
        "\n",
        "criterion = nn.CrossEntropyLoss()\n",
        "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
        "# Multiply the LR by 0.1 after 5 epochs without a lower monitored value.\n",
        "scheduler = optim.lr_scheduler.ReduceLROnPlateau(\n",
        "    optimizer,\n",
        "    mode='min',\n",
        "    factor=0.1,\n",
        "    patience=5,\n",
        ")\n",
        "\n",
        "def train_epoch(model, train_loader, optimizer, criterion, device):\n",
        "    \"\"\"Run one optimisation pass over ``train_loader``.\n",
        "\n",
        "    Switches ``model`` to training mode and, for every batch, moves the\n",
        "    tensors to ``device``, runs forward and backward, and steps ``optimizer``.\n",
        "\n",
        "    Returns:\n",
        "        tuple: ``(mean_loss, accuracy)`` where ``mean_loss`` is the batch-size\n",
        "        weighted average of the batch losses and ``accuracy`` is the fraction\n",
        "        of samples whose highest-scoring output index equals the label.\n",
        "    \"\"\"\n",
        "    model.train()\n",
        "    loss_sum, n_correct, n_seen = 0.0, 0, 0\n",
        "\n",
        "    for batch_x, batch_y in train_loader:\n",
        "        batch_x = batch_x.to(device)\n",
        "        batch_y = batch_y.to(device)\n",
        "\n",
        "        optimizer.zero_grad()\n",
        "        logits = model(batch_x)\n",
        "        batch_loss = criterion(logits, batch_y)\n",
        "        batch_loss.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "        # Weight the batch loss by its size (assumes criterion returns a\n",
        "        # per-batch mean) so the epoch figure is a per-sample average.\n",
        "        loss_sum += batch_loss.item() * batch_x.size(0)\n",
        "        n_seen += batch_y.size(0)\n",
        "        n_correct += (logits.max(dim=1).indices == batch_y).sum().item()\n",
        "\n",
        "    return loss_sum / n_seen, n_correct / n_seen\n",
        "\n",
        "def evaluate(model, data_loader, criterion, device):\n",
        "    \"\"\"Score ``model`` on ``data_loader`` without updating any weights.\n",
        "\n",
        "    Switches ``model`` to eval mode and runs every batch under\n",
        "    ``torch.no_grad()``.\n",
        "\n",
        "    Returns:\n",
        "        tuple: ``(mean_loss, accuracy, all_preds, all_labels)``. The first two\n",
        "        are per-sample averages over the whole loader; the last two are flat\n",
        "        lists of predicted and true class indices in iteration order.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    loss_sum, n_correct, n_seen = 0.0, 0, 0\n",
        "    all_preds, all_labels = [], []\n",
        "\n",
        "    with torch.no_grad():\n",
        "        for batch_x, batch_y in data_loader:\n",
        "            batch_x = batch_x.to(device)\n",
        "            batch_y = batch_y.to(device)\n",
        "\n",
        "            logits = model(batch_x)\n",
        "            batch_loss = criterion(logits, batch_y)\n",
        "            batch_pred = logits.max(dim=1).indices\n",
        "\n",
        "            # Weight by batch size (assumes criterion returns a per-batch mean).\n",
        "            loss_sum += batch_loss.item() * batch_x.size(0)\n",
        "            n_seen += batch_y.size(0)\n",
        "            n_correct += (batch_pred == batch_y).sum().item()\n",
        "\n",
        "            # Collect on the host as NumPy scalars for downstream reporting.\n",
        "            all_preds.extend(batch_pred.cpu().numpy())\n",
        "            all_labels.extend(batch_y.cpu().numpy())\n",
        "\n",
        "    return loss_sum / n_seen, n_correct / n_seen, all_preds, all_labels\n",
        "\n",
        "# Training budget and per-epoch history (plotted after the loop).\n",
        "num_epochs = 100\n",
        "best_val_loss = float('inf')\n",
        "train_losses = []\n",
        "train_accs = []\n",
        "val_losses = []\n",
        "val_accs = []\n",
        "\n",
        "for epoch in range(num_epochs):\n",
        "    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)\n",
        "    train_losses.append(train_loss)\n",
        "    train_accs.append(train_acc)\n",
        "\n",
        "    # Predictions/labels are not needed during training, only the metrics.\n",
        "    val_loss, val_acc, _, _ = evaluate(model, val_loader, criterion, device)\n",
        "    val_losses.append(val_loss)\n",
        "    val_accs.append(val_acc)\n",
        "\n",
        "    # The scheduler is driven by the validation loss of this epoch.\n",
        "    scheduler.step(val_loss)\n",
        "\n",
        "    # Checkpoint whenever validation loss reaches a new minimum.\n",
        "    if val_loss < best_val_loss:\n",
        "        best_val_loss = val_loss\n",
        "        torch.save(model.state_dict(), 'best_model_10class.pth')\n",
        "\n",
        "    print(f'Epoch {epoch+1}/{num_epochs} | '\n",
        "          f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | '\n",
        "          f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}')\n",
        "\n",
        "# Side-by-side learning curves: loss on the left, accuracy on the right.\n",
        "plt.figure(figsize=(12, 5))\n",
        "\n",
        "plt.subplot(1, 2, 1)\n",
        "plt.plot(train_losses, label='Train')\n",
        "plt.plot(val_losses, label='Validation')\n",
        "plt.title('Loss')\n",
        "plt.xlabel('Epoch')\n",
        "plt.legend()\n",
        "\n",
        "plt.subplot(1, 2, 2)\n",
        "plt.plot(train_accs, label='Train')\n",
        "plt.plot(val_accs, label='Validation')\n",
        "plt.title('Accuracy')\n",
        "plt.xlabel('Epoch')\n",
        "plt.legend()\n",
        "\n",
        "plt.tight_layout()\n",
        "plt.show()\n",
        "\n",
        "# Restore the best checkpoint before the final test evaluation.\n",
        "# map_location=device remaps the saved tensors onto the current device, so the\n",
        "# load also works when the checkpoint was written on a GPU runtime and the\n",
        "# notebook is now running on CPU (or the other way round).\n",
        "model.load_state_dict(torch.load('best_model_10class.pth', map_location=device))\n",
        "\n",
        "test_loss, test_acc, test_preds, test_labels = evaluate(model, test_loader, criterion, device)\n",
        "print(f'Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.4f}')\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [],
      "metadata": {
        "id": "kpTQEW9Gh-yE"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def binary_accuracy_from_multiclass(model, data_loader, device,\n",
        "                                    positive_classes=(0, 2, 3, 4, 6)):\n",
        "    \"\"\"Binary accuracy obtained by collapsing multi-class outputs into two groups.\n",
        "\n",
        "    Each sample's softmax probabilities are summed over ``positive_classes``\n",
        "    and over all remaining classes; the sample is predicted positive when the\n",
        "    positive mass is at least the negative mass. The true binary label is 1\n",
        "    when the multi-class label is in ``positive_classes``.\n",
        "\n",
        "    Args:\n",
        "        model: module mapping a batch of inputs to ``(batch, num_classes)`` scores.\n",
        "        data_loader: iterable of ``(inputs, labels)`` batches with integer labels.\n",
        "        device: device the batches are moved to before the forward pass.\n",
        "        positive_classes: class indices treated as positive. Defaults to\n",
        "            Shirt(6), Dress(3), Coat(4), Pullover(2), T-shirt/top(0).\n",
        "\n",
        "    Returns:\n",
        "        Fraction of samples whose binary prediction matches the binary label,\n",
        "        or 0 when the loader yields no samples.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    binary_correct = 0\n",
        "    binary_total = 0\n",
        "\n",
        "    positive_idx = torch.as_tensor(list(positive_classes), dtype=torch.long, device=device)\n",
        "\n",
        "    with torch.no_grad():\n",
        "        for inputs, labels in data_loader:\n",
        "            inputs, labels = inputs.to(device), labels.to(device)\n",
        "            class_scores = torch.softmax(model(inputs), dim=1)\n",
        "\n",
        "            # Boolean lookup table over class indices, sized from the model output\n",
        "            # rather than a hard-coded class count.\n",
        "            is_positive = torch.zeros(class_scores.shape[1], dtype=torch.bool, device=device)\n",
        "            is_positive[positive_idx] = True\n",
        "\n",
        "            # Vectorised membership test; avoids a per-sample .item() round trip.\n",
        "            binary_true_labels = is_positive[labels]\n",
        "\n",
        "            positive_scores = class_scores[:, is_positive].sum(dim=1)\n",
        "            negative_scores = class_scores[:, ~is_positive].sum(dim=1)\n",
        "\n",
        "            # Ties go to the positive group.\n",
        "            binary_preds = positive_scores >= negative_scores\n",
        "\n",
        "            binary_correct += (binary_preds == binary_true_labels).sum().item()\n",
        "            binary_total += labels.size(0)\n",
        "\n",
        "    binary_accuracy = binary_correct / binary_total if binary_total > 0 else 0\n",
        "\n",
        "    return binary_accuracy\n",
        "\n",
        "# Binary (positive-group vs. rest) accuracy of the model on the test set.\n",
        "binary_acc = binary_accuracy_from_multiclass(model, test_loader, device)\n",
        "# An assignment as the last statement renders nothing, so print the result.\n",
        "print(f'Binary Acc: {binary_acc:.4f}')"
      ],
      "metadata": {
        "id": "OpQ3oJjS3pej"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}