{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a042ff0c-c012-4bea-9e32-077019930810",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "import torchvision.models as models\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "# WILDS\n",
    "from wilds import get_dataset\n",
    "from wilds.common.data_loaders import get_train_loader, get_eval_loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "495cc41e-f82d-4f75-a4be-3caae2adb6ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run-wide settings: compute device and optimisation hyperparameters.\n",
    "# Use the GPU when CUDA is available; otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "\n",
    "num_epochs = 5        # number of passes of the training loop\n",
    "batch_size = 32       # mini-batch size given to the training loader\n",
    "learning_rate = 1e-4  # learning rate given to the optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d60545ae-fe8a-4df7-a554-2288270521af",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the full iWildCam dataset through WILDS, downloading it if necessary\n",
    "dataset = get_dataset(dataset=\"iwildcam\", download=True)\n",
    "\n",
    "# Per-image preprocessing: resize to 448x448, then convert to a tensor\n",
    "train_transform = transforms.Compose(\n",
    "    [\n",
    "        transforms.Resize((448, 448)),\n",
    "        transforms.ToTensor(),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# The \"train\" split with the preprocessing attached\n",
    "train_data = dataset.get_subset(\"train\", transform=train_transform)\n",
    "\n",
    "# The \"standard\" training loader over that split, batched by `batch_size`\n",
    "train_loader = get_train_loader(\"standard\", train_data, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "acaa8948-9519-470b-8d46-e121e1818236",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --------------------------\n",
    "# 3. Create model\n",
    "# --------------------------\n",
    "# We use a ResNet-152, pretrained on ImageNet. The weights must be requested\n",
    "# explicitly (`weights=None` would give a randomly initialised network);\n",
    "# torchvision downloads them on first use.\n",
    "model = models.resnet152(weights=models.ResNet152_Weights.DEFAULT)\n",
    "# Replace the final fully connected layer to match the number of classes for iWildCam.\n",
    "# Both sizes are read from their sources instead of being hard-coded: the input\n",
    "# width from the layer being replaced, the output width from the WILDS dataset.\n",
    "model.fc = nn.Linear(model.fc.in_features, dataset.n_classes)\n",
    "model = model.to(device)\n",
    "\n",
    "# --------------------------\n",
    "# 4. Define Loss and Optimizer\n",
    "# --------------------------\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "812cfa6b-0e30-4d07-83d2-c93b6fcefbc8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training loss:  4.9739\n",
      "training loss:  4.9180\n",
      "training loss:  4.6620\n",
      "training loss:  4.2888\n",
      "training loss:  3.7062\n",
      "training loss:  4.1623\n",
      "training loss:  3.4184\n",
      "training loss:  3.7050\n",
      "training loss:  3.1832\n",
      "training loss:  3.0448\n",
      "training loss:  3.0124\n",
      "training loss:  3.3612\n",
      "training loss:  2.3886\n",
      "training loss:  2.8278\n",
      "training loss:  2.9781\n",
      "training loss:  3.3020\n",
      "training loss:  2.9077\n",
      "training loss:  3.2632\n",
      "training loss:  3.8324\n",
      "training loss:  3.0886\n",
      "training loss:  3.3191\n",
      "training loss:  2.3739\n",
      "training loss:  3.0675\n",
      "training loss:  2.8124\n",
      "training loss:  2.8324\n",
      "training loss:  2.8304\n",
      "training loss:  2.6286\n",
      "training loss:  2.5930\n",
      "training loss:  3.4455\n",
      "training loss:  2.1726\n",
      "training loss:  2.6566\n"
     ]
    }
   ],
   "source": [
    "# Report the running mean loss every `log_every` steps instead of on every\n",
    "# batch, which would flood the cell output.\n",
    "log_every = 50\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    model.train()\n",
    "    loss_sum = 0.0  # sum of per-sample losses seen so far in this epoch\n",
    "    num_seen = 0    # number of samples seen so far in this epoch\n",
    "    for step, labeled_batch in enumerate(train_loader, start=1):\n",
    "        # Unpack labeled data; the third element of each batch is not used here\n",
    "        x, y, _ = labeled_batch\n",
    "        x, y = x.to(device), y.to(device)\n",
    "        # Forward pass on labeled data\n",
    "        outputs = model(x)\n",
    "        loss = criterion(outputs, y)\n",
    "        # Backprop and optimize\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # CrossEntropyLoss averages over the batch by default, so weight by the\n",
    "        # batch size to keep the epoch mean exact when the last batch is smaller.\n",
    "        loss_sum += loss.item() * y.size(0)\n",
    "        num_seen += y.size(0)\n",
    "        if step % log_every == 0:\n",
    "            print(f'training loss: {loss_sum / num_seen : .4f}')\n",
    "\n",
    "    # The epoch summary is the mean over every sample, not just the final batch;\n",
    "    # max(..., 1) avoids a ZeroDivisionError if the loader yields nothing.\n",
    "    print(f\"Epoch [{epoch+1}/{num_epochs}], Loss: {loss_sum / max(num_seen, 1):.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f10e6ad5-bb96-43d5-8d56-015fe148a69f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocessing for evaluation images: resize to 448x448, then convert to a tensor\n",
    "test_transform = transforms.Compose(\n",
    "    [\n",
    "        transforms.Resize((448, 448)),\n",
    "        transforms.ToTensor(),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# The \"test\" split with the preprocessing attached\n",
    "test_data = dataset.get_subset(\"test\", transform=test_transform)\n",
    "\n",
    "# The \"standard\" evaluation loader over that split, 64 images per batch\n",
    "test_loader = get_eval_loader(\"standard\", test_data, batch_size=64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7df0359f-6bf8-4691-ba6e-f924919821c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy over the full test set\n",
    "model.eval()  # switch layers such as BatchNorm to inference behaviour\n",
    "num_correct = 0\n",
    "num_total = 0\n",
    "with torch.no_grad():  # gradients are not needed for evaluation\n",
    "    for x, y_true, _ in test_loader:\n",
    "        y_pred = model(x.to(device))\n",
    "        # argmax of the logits equals argmax of their softmax, so no softmax is needed\n",
    "        label_pred = torch.argmax(y_pred, dim=1)\n",
    "        # Compare on the model's device; .item() keeps the counter a plain int\n",
    "        num_correct += (label_pred == y_true.to(device)).sum().item()\n",
    "        num_total += y_true.size(0)\n",
    "\n",
    "# max(..., 1) avoids a ZeroDivisionError if the loader yields nothing\n",
    "print('Acc: ' + str(num_correct / max(num_total, 1)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
