{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98fdad4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from urllib.request import urlopen\n",
    "from PIL import Image\n",
    "import timm\n",
    "import torch\n",
    "\n",
    "img = Image.open(urlopen(\n",
    "    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'\n",
    "))\n",
    "\n",
    "model = timm.create_model('convnext_small.in12k_ft_in1k', pretrained=True)\n",
    "model = model.eval()\n",
    "\n",
    "# get model specific transforms (normalization, resize)\n",
    "data_config = timm.data.resolve_model_data_config(model)\n",
    "transforms = timm.data.create_transform(**data_config, is_training=False)\n",
    "\n",
    "output = model(transforms(img).unsqueeze(0))  # unsqueeze single image into batch of 1\n",
    "\n",
    "top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "936c3d73",
   "metadata": {},
   "outputs": [],
   "source": [
    "top5_class_indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff81f7bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import timm\n",
    "\n",
    "class MLJETConvNext(nn.Module):\n",
    "    def __init__(self, backbone='convnext_small.in12k_ft_in1k', pretrained=True):\n",
    "        super().__init__()\n",
    "        self.backbone = timm.create_model(backbone, pretrained=pretrained, in_chans=1, num_classes=0)\n",
    "        \n",
    "        backbone_features = self.backbone.num_features\n",
    "        \n",
    "        # Multi-head classifier outputs\n",
    "        self.energy_loss_head = nn.Sequential(\n",
    "            nn.Linear(backbone_features, 1),\n",
    "            nn.Sigmoid()\n",
    "        )\n",
    "        self.alpha_head = nn.Sequential(\n",
    "            nn.Linear(backbone_features, 3),\n",
    "            nn.Softmax(dim=1)\n",
    "        )\n",
    "        self.q0_head = nn.Sequential(\n",
    "            nn.Linear(backbone_features, 4),\n",
    "            nn.Softmax(dim=1)\n",
    "        )\n",
    "        \n",
    "    def forward(self, x):\n",
    "        features = self.backbone(x)\n",
    "        energy_loss_output = self.energy_loss_head(features)\n",
    "        alpha_output = self.alpha_head(features)\n",
    "        q0_output = self.q0_head(features)\n",
    "        \n",
    "        return {\n",
    "            'energy_loss_output': energy_loss_output,\n",
    "            'alpha_output': alpha_output,\n",
    "            'q0_output': q0_output\n",
    "        }\n",
    "\n",
    "# Example usage\n",
    "model = MLJETConvNext()\n",
    "\n",
    "# Example input\n",
    "x = torch.randn((1, 1, 32, 32))  # Updated input shape for single-channel input\n",
    "outputs = model(x)\n",
    "\n",
    "print(outputs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
