{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "icxjKpH8kKE7",
        "outputId": "28c28bce-b392-4d13-be5c-43e02b739eff"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[?25l   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/44.8 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m44.8/44.8 kB\u001b[0m \u001b[31m1.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h"
          ]
        }
      ],
      "source": [
        "# Install CLIP and its text-processing dependencies.\n",
        "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
        "%pip install --quiet ftfy regex tqdm\n",
        "%pip install --quiet git+https://github.com/openai/CLIP.git\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# CLIP Knowledge Distillation: RN50 Teacher -> RN34 Student on Food-101\n",
        "# Food-101: 101K images, 101 food classes\n",
        "\n",
        "import os\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torchvision\n",
        "import torchvision.transforms as transforms\n",
        "from torch.utils.data import DataLoader, Dataset\n",
        "import clip\n",
        "import numpy as np\n",
        "from tqdm import tqdm\n",
        "\n",
        "# ================================================================\n",
        "# 1. SETUP\n",
        "# ================================================================\n",
        "# Dependencies (ftfy, regex, tqdm, openai/CLIP) are installed by the first cell.\n",
        "\n",
        "# Seed torch (CPU + CUDA) and NumPy so weight init, shuffling and template\n",
        "# sampling are repeatable. DataLoader workers derive their base seed from torch's RNG.\n",
        "SEED = 42\n",
        "torch.manual_seed(SEED)\n",
        "np.random.seed(SEED)\n",
        "\n",
        "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
        "print(f\"Using device: {device}\")\n",
        "\n",
        "# ================================================================\n",
        "# 2. LOAD CLIP TEACHER MODEL (RN50)\n",
        "# ================================================================\n",
        "print(\"Loading CLIP RN50 teacher model...\")\n",
        "teacher, preprocess = clip.load(\"RN50\", device)\n",
        "teacher.eval()\n",
        "\n",
        "# Architecture constants read from the teacher so the student can match them.\n",
        "input_resolution = teacher.visual.input_resolution  # 224 for RN50\n",
        "context_length = teacher.context_length             # 77\n",
        "vocab_size = teacher.vocab_size                     # 49408\n",
        "\n",
        "print(f\"Teacher model parameters: {sum(p.numel() for p in teacher.parameters()):,}\")\n",
        "print(f\"Input resolution: {input_resolution}\")\n",
        "print(f\"Context length: {context_length}\")\n",
        "\n",
        "# ================================================================\n",
        "# 3. FOOD-101 CLASS NAMES AND TEMPLATES\n",
        "# ================================================================\n",
        "# Food-101 class names with underscores replaced by spaces for CLIP templates.\n",
        "# The list index is the dataset label, so the order must match the label\n",
        "# indices torchvision's Food101 assigns (its sorted class folder names).\n",
        "FOOD101_CLASSES = [\n",
        "    'apple pie', 'baby back ribs', 'baklava', 'beef carpaccio', 'beef tartare',\n",
        "    'beet salad', 'beignets', 'bibimbap', 'bread pudding', 'breakfast burrito',\n",
        "    'bruschetta', 'caesar salad', 'cannoli', 'caprese salad', 'carrot cake',\n",
        "    'ceviche', 'cheese plate', 'cheesecake', 'chicken curry', 'chicken quesadilla',\n",
        "    'chicken wings', 'chocolate cake', 'chocolate mousse', 'churros', 'clam chowder',\n",
        "    'club sandwich', 'crab cakes', 'creme brulee', 'croque madame', 'cup cakes',\n",
        "    'deviled eggs', 'donuts', 'dumplings', 'edamame', 'eggs benedict',\n",
        "    'escargots', 'falafel', 'filet mignon', 'fish and chips', 'foie gras',\n",
        "    'french fries', 'french onion soup', 'french toast', 'fried calamari', 'fried rice',\n",
        "    'frozen yogurt', 'garlic bread', 'gnocchi', 'greek salad', 'grilled cheese sandwich',\n",
        "    'grilled salmon', 'guacamole', 'gyoza', 'hamburger', 'hot and sour soup',\n",
        "    'hot dog', 'huevos rancheros', 'hummus', 'ice cream', 'lasagna',\n",
        "    'lobster bisque', 'lobster roll sandwich', 'macaroni and cheese', 'macarons', 'miso soup',\n",
        "    'mussels', 'nachos', 'omelette', 'onion rings', 'oysters',\n",
        "    'pad thai', 'paella', 'pancakes', 'panna cotta', 'peking duck',\n",
        "    'pho', 'pizza', 'pork chop', 'poutine', 'prime rib',\n",
        "    'pulled pork sandwich', 'ramen', 'ravioli', 'red velvet cake', 'risotto',\n",
        "    'samosa', 'sashimi', 'scallops', 'seaweed salad', 'shrimp and grits',\n",
        "    'spaghetti bolognese', 'spaghetti carbonara', 'spring rolls', 'steak', 'strawberry shortcake',\n",
        "    'sushi', 'tacos', 'takoyaki', 'tiramisu', 'tuna tartare', 'waffles'\n",
        "]\n",
        "\n",
        "# CLIP prompt templates; each class name is inserted with str.format.\n",
        "FOOD_TEMPLATES = [\n",
        "    'a photo of {}.',\n",
        "    'a picture of {}.',\n",
        "    'an image of {}.',\n",
        "    'a delicious {}.',\n",
        "    'a tasty {}.',\n",
        "    'a fresh {}.',\n",
        "    'a homemade {}.',\n",
        "    'a restaurant-style {}.',\n",
        "    'a plate of {}.',\n",
        "    'a serving of {}.',\n",
        "    'a close-up of {}.',\n",
        "    'a gourmet {}.',\n",
        "]\n",
        "\n",
        "# Sanity checks: one unique name per Food-101 label, one '{}' slot per template.\n",
        "assert len(FOOD101_CLASSES) == 101 and len(set(FOOD101_CLASSES)) == 101\n",
        "assert all(t.count('{}') == 1 for t in FOOD_TEMPLATES)\n",
        "\n",
        "print(f\"Food-101 dataset: {len(FOOD101_CLASSES)} food classes\")\n",
        "print(f\"Using {len(FOOD_TEMPLATES)} text templates per class\")\n",
        "print(f\"Sample classes: {FOOD101_CLASSES[:5]}\")\n",
        "\n",
        "# ================================================================\n",
        "# 4. STUDENT MODEL DEFINITION (ResNet-34)\n",
        "# ================================================================\n",
        "class StudentImageEncoder(nn.Module):\n",
        "    \"\"\"ResNet-34 image tower projecting into the teacher's embedding space.\n",
        "\n",
        "    The ImageNet-pretrained classification head is replaced by a linear layer of\n",
        "    width ``output_dim``; outputs are L2-normalised along the last dimension.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, output_dim):\n",
        "        super().__init__()\n",
        "        backbone = torchvision.models.resnet34(weights='ResNet34_Weights.DEFAULT')\n",
        "        backbone.fc = nn.Linear(backbone.fc.in_features, output_dim)\n",
        "        self.encoder = backbone\n",
        "\n",
        "    def forward(self, x):\n",
        "        embedding = self.encoder(x)\n",
        "        norms = embedding.norm(dim=-1, keepdim=True)\n",
        "        return embedding / norms\n",
        "\n",
        "class StudentTextEncoder(nn.Module):\n",
        "    \"\"\"Small Transformer text tower producing L2-normalised sentence embeddings.\n",
        "\n",
        "    Args:\n",
        "        vocab_size: number of token ids (CLIP BPE vocabulary size).\n",
        "        context_length: maximum sequence length (size of the positional table).\n",
        "        output_dim: embedding width; must be divisible by the 8 attention heads.\n",
        "\n",
        "    Input is a (batch, seq_len) LongTensor laid out like ``clip.tokenize`` output:\n",
        "    SOT, text tokens, EOT, then zero padding. EOT has the largest id in each row,\n",
        "    so ``argmax`` locates it (the same convention CLIP's own text encoder uses).\n",
        "    Positions after EOT are masked out of self-attention and of the mean pool, so\n",
        "    the embedding depends only on the real tokens, not on the amount of padding.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, vocab_size, context_length, output_dim):\n",
        "        super(StudentTextEncoder, self).__init__()\n",
        "        self.token_embedding = nn.Embedding(vocab_size, output_dim)\n",
        "        self.positional_embedding = nn.Parameter(torch.zeros(context_length, output_dim))\n",
        "        nn.init.normal_(self.positional_embedding, std=0.01)\n",
        "        encoder_layer = nn.TransformerEncoderLayer(d_model=output_dim, nhead=8)\n",
        "        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)\n",
        "        self.ln_final = nn.LayerNorm(output_dim)\n",
        "\n",
        "    def forward(self, x):\n",
        "        seq_len = x.shape[1]\n",
        "        # Tokens up to and including EOT are real; everything after is padding.\n",
        "        eot_index = x.argmax(dim=-1)                                  # (batch,)\n",
        "        positions = torch.arange(seq_len, device=x.device)\n",
        "        valid = positions.unsqueeze(0) <= eot_index.unsqueeze(1)      # (batch, seq_len)\n",
        "\n",
        "        h = self.token_embedding(x) + self.positional_embedding[:seq_len]\n",
        "        h = h.permute(1, 0, 2)  # (seq_len, batch_size, output_dim); layer is not batch_first\n",
        "        # src_key_padding_mask: True marks positions attention must ignore.\n",
        "        h = self.transformer(h, src_key_padding_mask=~valid)\n",
        "        h = h.permute(1, 0, 2)  # (batch_size, seq_len, output_dim)\n",
        "        h = self.ln_final(h)\n",
        "\n",
        "        # Masked mean pooling over the real tokens only.\n",
        "        weights = valid.unsqueeze(-1).to(h.dtype)\n",
        "        h = (h * weights).sum(dim=1) / weights.sum(dim=1)\n",
        "        return h / h.norm(dim=-1, keepdim=True)  # Normalize\n",
        "\n",
        "# ================================================================\n",
        "# 5. FOOD-101 DATASET PREPARATION\n",
        "# ================================================================\n",
        "class Food101CLIPDataset(Dataset):\n",
        "    \"\"\"Wrap a labelled image dataset so every sample also carries a caption slot.\n",
        "\n",
        "    For split 'train' the caption is a freshly tokenised prompt built from a\n",
        "    randomly drawn template and the sample's class name. For any other split the\n",
        "    integer label is placed in the caption slot (evaluation builds its own\n",
        "    per-class text embeddings).\n",
        "\n",
        "    Each item is ``(image, text_tokens_or_label, label)``.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, food101_dataset, class_names, templates, context_length, split='train'):\n",
        "        self.dataset, self.class_names = food101_dataset, class_names\n",
        "        self.templates, self.context_length = templates, context_length\n",
        "        self.split = split\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.dataset)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        image, label = self.dataset[idx]\n",
        "        class_name = self.class_names[label]  # looked up for every split\n",
        "\n",
        "        if self.split != 'train':\n",
        "            # Caption slot carries the label; evaluation does not use it.\n",
        "            return image, label, label\n",
        "\n",
        "        prompt = np.random.choice(self.templates).format(class_name)\n",
        "        tokens = clip.tokenize([prompt], context_length=self.context_length)[0]\n",
        "        return image, tokens, label\n",
        "\n",
        "# Data transforms for Food-101. Normalisation uses CLIP's image mean/std because\n",
        "# the same tensors are fed to both the CLIP teacher and the student.\n",
        "CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)\n",
        "CLIP_STD = (0.26862954, 0.26130258, 0.27577711)\n",
        "\n",
        "transform_train = transforms.Compose([\n",
        "    transforms.RandomResizedCrop(input_resolution),\n",
        "    transforms.RandomHorizontalFlip(),\n",
        "    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),\n",
        "    transforms.ToTensor(),\n",
        "    transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD)\n",
        "])\n",
        "\n",
        "transform_val = transforms.Compose([\n",
        "    transforms.Resize(256),\n",
        "    transforms.CenterCrop(input_resolution),\n",
        "    transforms.ToTensor(),\n",
        "    transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD)\n",
        "])\n",
        "\n",
        "# Load Food-101 dataset\n",
        "print(\"Downloading Food-101 dataset (this may take ~10 minutes for 5GB)...\")\n",
        "train_food101 = torchvision.datasets.Food101(\n",
        "    root='./data', split='train', download=True, transform=transform_train\n",
        ")\n",
        "val_food101 = torchvision.datasets.Food101(\n",
        "    root='./data', split='test', download=True, transform=transform_val\n",
        ")\n",
        "\n",
        "# Labels index FOOD101_CLASSES, so its order must match torchvision's class order\n",
        "# exactly; otherwise images would silently be paired with the wrong captions.\n",
        "_torchvision_class_names = [name.replace('_', ' ') for name in train_food101.classes]\n",
        "assert _torchvision_class_names == FOOD101_CLASSES, \"FOOD101_CLASSES order differs from torchvision's Food101 labels\"\n",
        "\n",
        "print(f\"✅ Food-101 loaded successfully!\")\n",
        "print(f\"   Training images: {len(train_food101):,}\")\n",
        "print(f\"   Validation images: {len(val_food101):,}\")\n",
        "\n",
        "# Create CLIP datasets\n",
        "train_dataset = Food101CLIPDataset(\n",
        "    train_food101, FOOD101_CLASSES, FOOD_TEMPLATES, context_length, 'train'\n",
        ")\n",
        "val_dataset = Food101CLIPDataset(\n",
        "    val_food101, FOOD101_CLASSES, FOOD_TEMPLATES, context_length, 'val'\n",
        ")\n",
        "\n",
        "# Data loaders\n",
        "train_loader = DataLoader(\n",
        "    train_dataset, batch_size=64, shuffle=True, num_workers=4, pin_memory=True\n",
        ")\n",
        "val_loader = DataLoader(\n",
        "    val_dataset, batch_size=64, shuffle=False, num_workers=4, pin_memory=True\n",
        ")\n",
        "\n",
        "print(f\"   Training batches: {len(train_loader):,}\")\n",
        "print(f\"   Validation batches: {len(val_loader):,}\")\n"
      ],
      "metadata": {
        "id": "6e0b2wiJkK51"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "\n",
        "# ================================================================\n",
        "# 6. DISTILLATION LOSS FUNCTION\n",
        "# ================================================================\n",
        "def contrastive_loss_with_kl_l2_te(\n",
        "    student_image_features,\n",
        "    student_text_features,\n",
        "    teacher_image_features,\n",
        "    teacher_text_features,\n",
        "    temperature=0.07,\n",
        "    alpha=5.0,  # weight for KL term\n",
        "    beta=100,    # weight for L2 term\n",
        "    gamma=2.5   # weight for TE rewards\n",
        "):\n",
        "    \"\"\"Distillation objective: contrastive + ICL + KL + L2 - TE rewards.\n",
        "\n",
        "    Row i of each (batch, dim) feature tensor is treated as the matching\n",
        "    image/text pair i; all other rows in the batch act as negatives.\n",
        "\n",
        "    Args:\n",
        "        student_image_features, student_text_features: student embeddings.\n",
        "        teacher_image_features, teacher_text_features: teacher embeddings;\n",
        "            no gradient is propagated into them.\n",
        "        temperature: divisor applied to every cosine-similarity logit.\n",
        "        alpha, beta, gamma: weights of the KL, L2 and TE terms.\n",
        "\n",
        "    Returns:\n",
        "        (total_loss, kl_loss, l2_loss, icl_loss, te1, te2) as float32 scalars.\n",
        "        te1/te2 are 0 when the batch has fewer than two rows.\n",
        "    \"\"\"\n",
        "    eps = 1e-6\n",
        "\n",
        "    # Work in float32 whatever the input dtypes: mixing an fp16 teacher with an\n",
        "    # fp32 student would otherwise fail in the matmuls or run the loss in fp16.\n",
        "    # F.normalize also avoids division by zero for an all-zero vector.\n",
        "    student_image_features = F.normalize(student_image_features.float(), dim=-1)\n",
        "    student_text_features = F.normalize(student_text_features.float(), dim=-1)\n",
        "    teacher_image_features = F.normalize(teacher_image_features.float(), dim=-1).detach()\n",
        "    teacher_text_features = F.normalize(teacher_text_features.float(), dim=-1).detach()\n",
        "\n",
        "    batch_size = student_image_features.size(0)\n",
        "    labels = torch.arange(batch_size, device=student_image_features.device)\n",
        "\n",
        "    # Student logits\n",
        "    logits_per_image_student = student_image_features @ student_text_features.t() / temperature\n",
        "    logits_per_text_student = logits_per_image_student.t()\n",
        "\n",
        "    # Teacher logits (targets only; teacher features are detached above)\n",
        "    logits_per_image_teacher = teacher_image_features @ teacher_text_features.t() / temperature\n",
        "    logits_per_text_teacher = logits_per_image_teacher.t()\n",
        "\n",
        "    # Contrastive loss (symmetric InfoNCE)\n",
        "    contrastive_loss = (F.cross_entropy(logits_per_image_student, labels) +\n",
        "                        F.cross_entropy(logits_per_text_student, labels)) / 2\n",
        "\n",
        "    # Interactive Contrastive Learning (ICL): student features against the\n",
        "    # teacher's features of the other modality\n",
        "    icl_loss_I_to_T = F.cross_entropy(student_image_features @ teacher_text_features.t() / temperature, labels)\n",
        "    icl_loss_T_to_I = F.cross_entropy(student_text_features @ teacher_image_features.t() / temperature, labels)\n",
        "    icl_loss = (icl_loss_I_to_T + icl_loss_T_to_I) / 2\n",
        "\n",
        "    # KL(teacher || student) over the in-batch similarity distributions, both directions\n",
        "    kl_img = F.kl_div(F.log_softmax(logits_per_image_student, dim=-1),\n",
        "                      F.softmax(logits_per_image_teacher, dim=-1), reduction='batchmean')\n",
        "    kl_txt = F.kl_div(F.log_softmax(logits_per_text_student, dim=-1),\n",
        "                      F.softmax(logits_per_text_teacher, dim=-1), reduction='batchmean')\n",
        "    kl_loss = (kl_img + kl_txt) / 2\n",
        "\n",
        "    # L2 distance between the normalized embeddings\n",
        "    l2_loss = (F.mse_loss(student_image_features, teacher_image_features) +\n",
        "               F.mse_loss(student_text_features, teacher_text_features)) / 2\n",
        "\n",
        "    # Transfer Entropy (TE) approximation: cosine similarity between the\n",
        "    # consecutive-row differences of student and teacher features. It needs at\n",
        "    # least two rows; with one row the differences are empty and mean() is NaN.\n",
        "    if batch_size < 2:\n",
        "        te1 = te2 = student_image_features.new_zeros(())\n",
        "    else:\n",
        "        student_diff_img = student_image_features[1:] - student_image_features[:-1]\n",
        "        teacher_diff_img = teacher_image_features[1:] - teacher_image_features[:-1]\n",
        "        student_diff_txt = student_text_features[1:] - student_text_features[:-1]\n",
        "        teacher_diff_txt = teacher_text_features[1:] - teacher_text_features[:-1]\n",
        "\n",
        "        te_img = F.cosine_similarity(student_diff_img, teacher_diff_img, dim=-1, eps=eps).mean()\n",
        "        te_txt = F.cosine_similarity(student_diff_txt, teacher_diff_txt, dim=-1, eps=eps).mean()\n",
        "        te1 = (te_img + te_txt) / 2\n",
        "\n",
        "        # Concatenated TE over the joint image+text difference vectors\n",
        "        te2 = F.cosine_similarity(\n",
        "            torch.cat([student_diff_img, student_diff_txt], dim=-1),\n",
        "            torch.cat([teacher_diff_img, teacher_diff_txt], dim=-1),\n",
        "            dim=-1, eps=eps,\n",
        "        ).mean()\n",
        "\n",
        "    # Total loss: TE terms are rewards (higher is better), hence subtracted\n",
        "    total_loss = (contrastive_loss + alpha * kl_loss + beta * l2_loss +\n",
        "                  icl_loss - gamma * te1 - gamma * te2)\n",
        "\n",
        "    return total_loss, kl_loss, l2_loss, icl_loss, te1, te2\n",
        "\n",
        "# ================================================================\n",
        "# 7. INITIALIZE STUDENT MODELS\n",
        "# ================================================================\n",
        "print(\"Initializing student models...\")\n",
        "# The students must embed into the teacher's joint embedding space (1024-d for\n",
        "# RN50), so the width is read from the teacher rather than hard-coded.\n",
        "embed_dim = teacher.visual.output_dim\n",
        "student_image_encoder = StudentImageEncoder(output_dim=embed_dim).to(device)\n",
        "student_text_encoder = StudentTextEncoder(vocab_size, context_length, output_dim=embed_dim).to(device)\n",
        "\n",
        "# Count parameters\n",
        "student_params = sum(p.numel() for p in student_image_encoder.parameters()) + \\\n",
        "                sum(p.numel() for p in student_text_encoder.parameters())\n",
        "teacher_params = sum(p.numel() for p in teacher.parameters())\n",
        "\n",
        "print(f\"Student model parameters: {student_params:,}\")\n",
        "print(f\"Teacher model parameters: {teacher_params:,}\")\n",
        "print(f\"Compression ratio: {teacher_params/student_params:.1f}x\")\n",
        "\n",
        "# Optimizer over both student towers only; the teacher is not optimized.\n",
        "optimizer = torch.optim.Adam(\n",
        "    list(student_image_encoder.parameters()) + list(student_text_encoder.parameters()),\n",
        "    lr=1e-4\n",
        ")\n",
        "\n",
        "print(\"Models initialized successfully!\")\n",
        "\n",
        "# ================================================================\n",
        "# 8. TRAINING FUNCTION\n",
        "# ================================================================\n",
        "def train_epoch(epoch, num_epochs):\n",
        "    \"\"\"Run one distillation epoch over ``train_loader`` and update the students.\n",
        "\n",
        "    Uses the module-level ``teacher``, ``student_image_encoder``,\n",
        "    ``student_text_encoder``, ``optimizer`` and ``device``.\n",
        "\n",
        "    Args:\n",
        "        epoch: zero-based epoch index (display only).\n",
        "        num_epochs: total number of epochs (display only).\n",
        "\n",
        "    Returns:\n",
        "        dict of per-batch averages for 'loss', 'kl', 'l2', 'icl', 'te1', 'te2'.\n",
        "    \"\"\"\n",
        "    student_image_encoder.train()\n",
        "    student_text_encoder.train()\n",
        "\n",
        "    metric_names = ('loss', 'kl', 'l2', 'icl', 'te1', 'te2')\n",
        "    totals = dict.fromkeys(metric_names, 0.0)\n",
        "    num_batches = 0\n",
        "\n",
        "    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')\n",
        "\n",
        "    for batch_idx, (images, texts, _) in enumerate(progress_bar):\n",
        "        images = images.to(device, non_blocking=True)\n",
        "        texts = texts.to(device, non_blocking=True)\n",
        "\n",
        "        # Teacher outputs (no gradients). clip.load keeps the teacher in fp16 on\n",
        "        # CUDA, so upcast its features: the loss is then computed in float32\n",
        "        # instead of downcasting the float32 student outputs to fp16.\n",
        "        with torch.no_grad():\n",
        "            teacher_image_features = teacher.encode_image(images).float()\n",
        "            teacher_text_features = teacher.encode_text(texts).float()\n",
        "\n",
        "        # Student outputs\n",
        "        student_image_features = student_image_encoder(images)\n",
        "        student_text_features = student_text_encoder(texts)\n",
        "\n",
        "        # Compute loss\n",
        "        loss, kl_loss, l2_loss, icl_loss, te1, te2 = contrastive_loss_with_kl_l2_te(\n",
        "            student_image_features, student_text_features,\n",
        "            teacher_image_features, teacher_text_features\n",
        "        )\n",
        "\n",
        "        # Backpropagation\n",
        "        optimizer.zero_grad()\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "        # Update metrics\n",
        "        batch_values = dict(zip(metric_names, (t.item() for t in (loss, kl_loss, l2_loss, icl_loss, te1, te2))))\n",
        "        for name, value in batch_values.items():\n",
        "            totals[name] += value\n",
        "        num_batches += 1\n",
        "\n",
        "        # Update progress bar every 100 batches\n",
        "        if batch_idx % 100 == 0:\n",
        "            progress_bar.set_postfix({\n",
        "                'Loss': f\"{batch_values['loss']:.3f}\",\n",
        "                'KL': f\"{batch_values['kl']:.3f}\",\n",
        "                'L2': f\"{batch_values['l2']:.5f}\",\n",
        "                'ICL': f\"{batch_values['icl']:.3f}\",\n",
        "                'TE1': f\"{batch_values['te1']:.3f}\",\n",
        "                'TE2': f\"{batch_values['te2']:.3f}\"\n",
        "            })\n",
        "\n",
        "    # Print epoch averages (max(..., 1) guards against an empty loader)\n",
        "    averages = {name: total / max(num_batches, 1) for name, total in totals.items()}\n",
        "    print(f\"\\n📊 Epoch [{epoch+1}/{num_epochs}] Summary:\")\n",
        "    print(f\"   Loss: {averages['loss']:.4f} | \"\n",
        "          f\"KL: {averages['kl']:.4f} | \"\n",
        "          f\"L2: {averages['l2']:.6f}\")\n",
        "    print(f\"   ICL: {averages['icl']:.4f} | \"\n",
        "          f\"TE1: {averages['te1']:.4f} | \"\n",
        "          f\"TE2: {averages['te2']:.4f}\")\n",
        "    return averages\n",
        "\n",
        "# ================================================================\n",
        "# 9. EVALUATION FUNCTION\n",
        "# ================================================================\n",
        "@torch.no_grad()\n",
        "def evaluate_food_classification():\n",
        "    \"\"\"Prompt-based Food-101 classification accuracy of the distilled student.\n",
        "\n",
        "    Each class is represented by the mean of its student text embeddings over all\n",
        "    FOOD_TEMPLATES (re-normalised). Every validation image is assigned to the\n",
        "    classes with the highest cosine similarity.\n",
        "\n",
        "    Returns:\n",
        "        (top1_acc, top5_acc) as percentages over the validation set.\n",
        "    \"\"\"\n",
        "    print(\"🍕 Evaluating Food-101 classification performance...\")\n",
        "    for tower in (student_image_encoder, student_text_encoder):\n",
        "        tower.eval()\n",
        "\n",
        "    # One prototype embedding per class: template-averaged, then unit-normalised.\n",
        "    print(\"Computing text features for all 101 food classes...\")\n",
        "    prototypes = []\n",
        "    for class_name in FOOD101_CLASSES:\n",
        "        prompts = [template.format(class_name) for template in FOOD_TEMPLATES]\n",
        "        embeddings = student_text_encoder(clip.tokenize(prompts).to(device))\n",
        "        mean_embedding = embeddings.mean(dim=0, keepdim=True)\n",
        "        prototypes.append(mean_embedding / mean_embedding.norm(dim=-1, keepdim=True))\n",
        "    text_features = torch.cat(prototypes, dim=0)  # (num_classes, embed_dim)\n",
        "\n",
        "    num_correct_top1 = num_correct_top5 = num_seen = 0\n",
        "    for images, _, labels in tqdm(val_loader, desc=\"Evaluating\"):\n",
        "        images, labels = images.to(device), labels.to(device)\n",
        "\n",
        "        image_features = student_image_encoder(images)\n",
        "        image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n",
        "        logits = image_features @ text_features.t()  # (batch_size, num_classes)\n",
        "\n",
        "        top1_pred = logits.max(1)[1]\n",
        "        top5_pred = logits.topk(5, 1, True, True)[1]\n",
        "        num_correct_top1 += top1_pred.eq(labels).sum().item()\n",
        "        num_correct_top5 += top5_pred.eq(labels.view(-1, 1)).sum().item()\n",
        "        num_seen += labels.size(0)\n",
        "\n",
        "    top1_acc = 100. * num_correct_top1 / num_seen\n",
        "    top5_acc = 100. * num_correct_top5 / num_seen\n",
        "\n",
        "    print(f\"\\n🎯 Student Model Performance on Food-101:\")\n",
        "    print(f\"   Top-1 Accuracy: {top1_acc:.2f}%\")\n",
        "    print(f\"   Top-5 Accuracy: {top5_acc:.2f}%\")\n",
        "    print(f\"   Evaluated on {num_seen:,} images\")\n",
        "\n",
        "    return top1_acc, top5_acc\n"
      ],
      "metadata": {
        "id": "cusnHjpbspjg"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "\n",
        "# ================================================================\n",
        "# 10. MAIN TRAINING LOOP\n",
        "# ================================================================\n",
        "def main(num_epochs=10, eval_every=None, checkpoint_prefix=None):\n",
        "    \"\"\"Train the student towers, optionally evaluating periodically, then evaluate.\n",
        "\n",
        "    Args:\n",
        "        num_epochs: number of training epochs.\n",
        "        eval_every: if set, run evaluate_food_classification() every this many\n",
        "            epochs and track the best top-1 accuracy. None (default) skips\n",
        "            intermediate evaluation.\n",
        "        checkpoint_prefix: if set (together with eval_every), save the student\n",
        "            state dicts to '<prefix>_image.pth' and '<prefix>_text.pth' whenever\n",
        "            a new best top-1 accuracy is reached.\n",
        "\n",
        "    Returns:\n",
        "        (top1_acc, top5_acc) from the final evaluation.\n",
        "    \"\"\"\n",
        "    print(f\"\\n🚀 Starting Food-101 CLIP distillation training for {num_epochs} epochs...\")\n",
        "    print(f\"📊 Dataset: {len(train_dataset):,} training images, {len(val_dataset):,} validation images\")\n",
        "    print(f\"🏫 Teacher: CLIP RN50, Student: ResNet-34\")\n",
        "    print(f\"💾 Using: {device}\")\n",
        "\n",
        "    best_accuracy = 0.0\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        print(f\"\\n{'='*60}\")\n",
        "        print(f\"🍳 EPOCH {epoch+1}/{num_epochs}\")\n",
        "        print(f\"{'='*60}\")\n",
        "\n",
        "        # Training\n",
        "        train_epoch(epoch, num_epochs)\n",
        "\n",
        "        # Optional periodic evaluation with best-model tracking\n",
        "        if eval_every and (epoch + 1) % eval_every == 0:\n",
        "            top1_acc, _ = evaluate_food_classification()\n",
        "            if top1_acc > best_accuracy:\n",
        "                best_accuracy = top1_acc\n",
        "                print(f\"🎉 New best accuracy: {best_accuracy:.2f}%\")\n",
        "                if checkpoint_prefix:\n",
        "                    torch.save(student_image_encoder.state_dict(), f\"{checkpoint_prefix}_image.pth\")\n",
        "                    torch.save(student_text_encoder.state_dict(), f\"{checkpoint_prefix}_text.pth\")\n",
        "\n",
        "    print(f\"\\n🏆 Training completed!\")\n",
        "    if eval_every:\n",
        "        print(f\"📈 Best Top-1 Accuracy: {best_accuracy:.2f}%\")\n",
        "    print(f\"\\n🔥 Final evaluation:\")\n",
        "    return evaluate_food_classification()\n",
        "\n",
        "# ================================================================\n",
        "# 11. RUN THE TRAINING\n",
        "# ================================================================\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n",
        "\n"
      ],
      "metadata": {
        "id": "ZnHdwhzYkLQm"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "9qajJVzEkL_B"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "-aiC2_JnkMYe"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "fucZlOdokMzB"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}