{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_VmGblraEvLf"
      },
      "source": [
        "## INSTALL NECESSARY LIBRARIES"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "rXthPiNH-l9z",
        "outputId": "b3b823e2-a441-481c-f513-d4ffe51209e1"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Wed Sep 24 21:43:07 2025       \n",
            "+-----------------------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |\n",
            "|-----------------------------------------+------------------------+----------------------+\n",
            "| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |\n",
            "|                                         |                        |               MIG M. |\n",
            "|=========================================+========================+======================|\n",
            "|   0  NVIDIA A100-SXM4-80GB          Off |   00000000:00:05.0 Off |                    0 |\n",
            "| N/A   31C    P0             51W /  400W |       0MiB /  81920MiB |      0%      Default |\n",
            "|                                         |                        |             Disabled |\n",
            "+-----------------------------------------+------------------------+----------------------+\n",
            "                                                                                         \n",
            "+-----------------------------------------------------------------------------------------+\n",
            "| Processes:                                                                              |\n",
            "|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |\n",
            "|        ID   ID                                                               Usage      |\n",
            "|=========================================================================================|\n",
            "|  No running processes found                                                             |\n",
            "+-----------------------------------------------------------------------------------------+\n"
          ]
        }
      ],
      "source": [
        "# Report the attached GPU model, driver version and CUDA version\n",
        "!nvidia-smi\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "aCrb5QHa2MBk",
        "outputId": "f6d093e8-3504-425f-d374-c1daf7905cef"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mounted at /content/drive\n"
          ]
        }
      ],
      "source": [
        "# Mount Google Drive at /content/drive; SAVE_DIR and the iNaturalist data root live under it\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "tpfVl1zypzGa",
        "outputId": "9a59f02a-074b-4a49-cd65-a934a4dc3503"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Requirement already satisfied: torch in /usr/local/lib/python3.12/dist-packages (2.8.0+cu126)\n",
            "Requirement already satisfied: torchvision in /usr/local/lib/python3.12/dist-packages (0.23.0+cu126)\n",
            "Requirement already satisfied: timm in /usr/local/lib/python3.12/dist-packages (1.0.19)\n",
            "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.12/dist-packages (1.6.1)\n",
            "Requirement already satisfied: matplotlib in /usr/local/lib/python3.12/dist-packages (3.10.0)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from torch) (3.19.1)\n",
            "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.12/dist-packages (from torch) (4.15.0)\n",
            "Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch) (75.2.0)\n",
            "Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch) (1.13.3)\n",
            "Requirement already satisfied: networkx in /usr/local/lib/python3.12/dist-packages (from torch) (3.5)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch) (3.1.6)\n",
            "Requirement already satisfied: fsspec in /usr/local/lib/python3.12/dist-packages (from torch) (2025.3.0)\n",
            "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch) (12.6.77)\n",
            "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch) (12.6.77)\n",
            "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch) (12.6.80)\n",
            "Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /usr/local/lib/python3.12/dist-packages (from torch) (9.10.2.21)\n",
            "Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch) (12.6.4.1)\n",
            "Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch) (11.3.0.4)\n",
            "Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch) (10.3.7.77)\n",
            "Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch) (11.7.1.2)\n",
            "Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch) (12.5.4.2)\n",
            "Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /usr/local/lib/python3.12/dist-packages (from torch) (0.7.1)\n",
            "Requirement already satisfied: nvidia-nccl-cu12==2.27.3 in /usr/local/lib/python3.12/dist-packages (from torch) (2.27.3)\n",
            "Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch) (12.6.77)\n",
            "Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch) (12.6.85)\n",
            "Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch) (1.11.1.6)\n",
            "Requirement already satisfied: triton==3.4.0 in /usr/local/lib/python3.12/dist-packages (from torch) (3.4.0)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.12/dist-packages (from torchvision) (2.0.2)\n",
            "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.12/dist-packages (from torchvision) (11.3.0)\n",
            "Requirement already satisfied: pyyaml in /usr/local/lib/python3.12/dist-packages (from timm) (6.0.2)\n",
            "Requirement already satisfied: huggingface_hub in /usr/local/lib/python3.12/dist-packages (from timm) (0.35.0)\n",
            "Requirement already satisfied: safetensors in /usr/local/lib/python3.12/dist-packages (from timm) (0.6.2)\n",
            "Requirement already satisfied: scipy>=1.6.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn) (1.16.2)\n",
            "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn) (1.5.2)\n",
            "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn) (3.6.0)\n",
            "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib) (1.3.3)\n",
            "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.12/dist-packages (from matplotlib) (0.12.1)\n",
            "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib) (4.60.0)\n",
            "Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib) (1.4.9)\n",
            "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib) (25.0)\n",
            "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib) (3.2.4)\n",
            "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.12/dist-packages (from matplotlib) (2.9.0.post0)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.12/dist-packages (from python-dateutil>=2.7->matplotlib) (1.17.0)\n",
            "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch) (1.3.0)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from huggingface_hub->timm) (2.32.4)\n",
            "Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.12/dist-packages (from huggingface_hub->timm) (4.67.1)\n",
            "Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from huggingface_hub->timm) (1.1.10)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch) (3.0.2)\n",
            "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface_hub->timm) (3.4.3)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface_hub->timm) (3.10)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface_hub->timm) (2.5.0)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface_hub->timm) (2025.8.3)\n"
          ]
        }
      ],
      "source": [
        "# Pin the versions recorded in this environment so a fresh runtime reproduces it.\n",
        "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
        "# NOTE(review): `transformers` is imported below but not installed here; it is assumed\n",
        "# to be preinstalled (e.g. on Colab) - pin it too if that is not guaranteed.\n",
        "%pip install -q torch==2.8.0 torchvision==0.23.0 timm==1.0.19 scikit-learn==1.6.1 matplotlib==3.10.0\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qs3z0qPeLNQM"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import timm\n",
        "from torch.utils.data import DataLoader, random_split, Dataset\n",
        "import torchvision.transforms as transforms\n",
        "from torchvision.datasets import CIFAR100\n",
        "import numpy as np\n",
        "from transformers import SwinForImageClassification, SwinConfig\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "from sklearn.manifold import TSNE\n",
        "from sklearn.mixture import GaussianMixture\n",
        "from scipy import linalg\n",
        "import torch.nn.functional as F\n",
        "from PIL import Image\n",
        "import time\n",
        "\n",
        "import itertools\n",
        "from sklearn.metrics import precision_recall_fscore_support\n",
        "import pickle\n",
        "import random\n",
        "import json\n",
        "import pickle\n",
        "import torchvision.transforms.v2 as v2\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Reproducibility Settings\n",
        "def set_seed(seed: int = 42):\n",
        "    \"\"\"Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN kernels.\n",
        "\n",
        "    Note: PYTHONHASHSEED is read at interpreter start-up, so setting it here only\n",
        "    affects Python processes launched afterwards, not the running kernel.\n",
        "    \"\"\"\n",
        "    os.environ[\"PYTHONHASHSEED\"] = str(seed)\n",
        "    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):\n",
        "        seed_fn(seed)\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.manual_seed_all(seed)\n",
        "\n",
        "    # Give up cuDNN autotuning in exchange for run-to-run repeatable kernels.\n",
        "    cudnn = torch.backends.cudnn\n",
        "    cudnn.deterministic, cudnn.benchmark = True, False\n",
        "\n",
        "\n",
        "set_seed(42)\n",
        "\n",
        "def seed_worker(worker_id: int):\n",
        "    \"\"\"DataLoader ``worker_init_fn``: seed NumPy and ``random`` from the worker's torch seed.\n",
        "\n",
        "    ``worker_id`` is unused; the torch seed inside each worker already differs per worker.\n",
        "    \"\"\"\n",
        "    # np.random.seed only accepts values in [0, 2**32), so fold the 64-bit torch seed.\n",
        "    worker_seed = torch.initial_seed() % (1 << 32)\n",
        "    for seed_fn in (np.random.seed, random.seed):\n",
        "        seed_fn(worker_seed)\n",
        "\n",
        "\n",
        "# Seeded generator passed to the DataLoaders (shuffle order and worker base seeds)\n",
        "g = torch.Generator().manual_seed(42)\n",
        "\n",
        "# Select the compute device: CUDA when available, otherwise CPU\n",
        "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")"
      ],
      "metadata": {
        "id": "Xf24_efYUpjH"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Check the selected machine\n",
        "cuda_available = torch.cuda.is_available()\n",
        "if not cuda_available:\n",
        "    print(\"GPU is NOT available.\")\n",
        "else:\n",
        "    print(\"GPU is available!\")\n",
        "    print(\"GPU name:\", torch.cuda.get_device_name(0))\n",
        "\n",
        "print(\"Device being used:\", device)\n",
        "print(\"Using CUDA:\", cuda_available)\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "E4Or6Bpld_Cr",
        "outputId": "78257874-7251-4d75-80bc-390dc84526df"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "GPU is available!\n",
            "GPU name: NVIDIA A100-SXM4-80GB\n",
            "Device being used: cuda\n",
            "Using CUDA: True\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Global performance toggles for faster running time.\n",
        "# TF32 matmuls/convolutions give up a little FP32 precision for speed on Ampere GPUs such as the A100.\n",
        "import torch\n",
        "\n",
        "for tf32_backend in (torch.backends.cuda.matmul, torch.backends.cudnn):\n",
        "    tf32_backend.allow_tf32 = True\n",
        "torch.set_float32_matmul_precision(\"high\")  # PyTorch 2.x\n",
        "\n",
        "# Mixed precision: prefer BF16 on the A100; a GradScaler is only needed for FP16.\n",
        "AMP_DTYPE = torch.bfloat16\n",
        "USE_SCALER = False\n"
      ],
      "metadata": {
        "id": "70mUN_36gXvY"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Checkpoint directory on Google Drive (used to save / resume training)\n",
        "SAVE_DIR = '/content/drive/MyDrive/neurosymbolic_epistemic_ai'\n",
        "os.makedirs(SAVE_DIR, mode=0o777, exist_ok=True)\n"
      ],
      "metadata": {
        "id": "4XoQcjPdh7bb"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "_Y8qCP3Ko749"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## CIFAR-100 - DATA LOAD AND AUGMENTATION"
      ],
      "metadata": {
        "id": "omVz9s15o9ka"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### COMPUTE MEAN AND STD FOR NORMALIZATION"
      ],
      "metadata": {
        "id": "0kG67rJjfcm5"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Plain tensor conversion only (no augmentation / normalization) for measuring channel statistics\n",
        "transform = transforms.Compose([transforms.ToTensor()])\n",
        "\n",
        "# Load both CIFAR-100 splits (downloaded into ./data if missing)\n",
        "tr_data, te_data = (\n",
        "    CIFAR100(root='./data', train=is_train, download=True, transform=transform)\n",
        "    for is_train in (True, False)\n",
        ")\n"
      ],
      "metadata": {
        "id": "OvsJoBI6c3Fk"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Per-channel mean / std of each split, measured on the un-augmented [0, 1] tensors\n",
        "# produced by the ToTensor-only `transform` that tr_data / te_data were built with.\n",
        "\n",
        "def channel_stats(loader):\n",
        "    \"\"\"Load the first batch of `loader` and compute per-channel statistics.\n",
        "\n",
        "    Returns (batch, mean, std): `batch` is the (images, labels) pair and mean / std\n",
        "    are reduced over dims 0, 2 and 3 of the image tensor (one value per channel).\n",
        "    \"\"\"\n",
        "    batch = next(iter(loader))\n",
        "    batch_images = batch[0]\n",
        "    return batch, torch.mean(batch_images, dim=[0, 2, 3]), torch.std(batch_images, dim=[0, 2, 3])\n",
        "\n",
        "\n",
        "# batch_size = dataset size, so the single batch covers every image of the split.\n",
        "train_l = torch.utils.data.DataLoader(tr_data, batch_size=len(tr_data), shuffle=False)\n",
        "test_l = torch.utils.data.DataLoader(te_data, batch_size=len(te_data), shuffle=False)\n",
        "\n",
        "for split_name, split_loader in ((\"Train\", train_l), (\"Test\", test_l)):\n",
        "    data, mean, std = channel_stats(split_loader)\n",
        "    print(f'{split_name} Mean: {mean}')\n",
        "    print(f'{split_name} Standard Deviation: {std}')\n",
        "\n",
        "# As before, the last (test) batch and its statistics stay in `data`, `images`,\n",
        "# `labels`, `mean` and `std`.\n",
        "images, labels = data\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "8b03e3cf-13ca-4aae-fe29-1d81b6a5546f",
        "id": "i6vDS4Z3mloo"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Train Mean: tensor([0.5071, 0.4866, 0.4409])\n",
            "Train Standard Deviation: tensor([0.2673, 0.2564, 0.2762])\n",
            "Test Mean: tensor([0.5088, 0.4874, 0.4419])\n",
            "Test Standard Deviation: tensor([0.2683, 0.2574, 0.2771])\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "N8Dh4JfmE2B8"
      },
      "source": [
        "### LOAD AND PREPROCESS DATA"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "\n",
        "def get_cifar100_fine_and_coarse_labels_from_pickle(root='./data'):\n",
        "    \"\"\"\n",
        "    Read CIFAR-100 class names and the fine -> coarse label mapping from the raw pickles.\n",
        "\n",
        "    torchvision is used first to download/extract the dataset under `root` if needed.\n",
        "\n",
        "    Returns:\n",
        "        fine_label_names: List[str]\n",
        "        coarse_label_names: List[str]\n",
        "        fine_to_coarse: Dict[int, int]  (fine id -> coarse id of its first\n",
        "            occurrence in the training split)\n",
        "    \"\"\"\n",
        "    # Ensure <root>/cifar-100-python is present\n",
        "    from torchvision.datasets import CIFAR100\n",
        "    CIFAR100(root=root, train=True, download=True)\n",
        "    folder = os.path.join(root, 'cifar-100-python')\n",
        "\n",
        "    def _unpickle(name):\n",
        "        # latin1 decoding lets Python 3 read these Python-2-era pickles\n",
        "        with open(os.path.join(folder, name), 'rb') as fh:\n",
        "            return pickle.load(fh, encoding='latin1')\n",
        "\n",
        "    meta = _unpickle('meta')\n",
        "    train_data = _unpickle('train')\n",
        "\n",
        "    fine_to_coarse = {}\n",
        "    for fine, coarse in zip(train_data['fine_labels'], train_data['coarse_labels']):\n",
        "        fine_to_coarse.setdefault(fine, coarse)  # keep only the first occurrence\n",
        "\n",
        "    return meta['fine_label_names'], meta['coarse_label_names'], fine_to_coarse\n"
      ],
      "metadata": {
        "id": "tgTFODzmRF-C"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Get fine labels, coarse labels and fine_to_coarse mapping\n",
        "fine_classes, coarse_classes, fine_to_coarse = get_cifar100_fine_and_coarse_labels_from_pickle()\n",
        "\n",
        "for title, value in ((\"fine_classes\", fine_classes),\n",
        "                     (\"coarse_classes\", coarse_classes),\n",
        "                     (\"fine_to_coarse\", fine_to_coarse)):\n",
        "    print(title, value)\n",
        "\n",
        "for title, value in ((\"Size of fine_classes\", fine_classes),\n",
        "                     (\"Size of coarse_classes\", coarse_classes),\n",
        "                     (\"Size of fine_to_coarse\", fine_to_coarse)):\n",
        "    print(title, len(value))\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "UsrOgTzwRtYO",
        "outputId": "d5ecec57-7822-4378-b1a2-35e8cc20527b"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|██████████| 169M/169M [00:05<00:00, 29.7MB/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "fine_classes ['apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm']\n",
            "coarse_classes ['aquatic_mammals', 'fish', 'flowers', 'food_containers', 'fruit_and_vegetables', 'household_electrical_devices', 'household_furniture', 'insects', 'large_carnivores', 'large_man-made_outdoor_things', 'large_natural_outdoor_scenes', 'large_omnivores_and_herbivores', 'medium_mammals', 'non-insect_invertebrates', 'people', 'reptiles', 'small_mammals', 'trees', 'vehicles_1', 'vehicles_2']\n",
            "fine_to_coarse {19: 11, 29: 15, 0: 4, 11: 14, 1: 1, 86: 5, 90: 18, 28: 3, 23: 10, 31: 11, 39: 5, 96: 17, 82: 2, 17: 9, 71: 10, 8: 18, 97: 8, 80: 16, 74: 16, 59: 17, 70: 2, 87: 5, 84: 6, 64: 12, 52: 17, 42: 8, 47: 17, 65: 16, 21: 11, 22: 5, 81: 19, 24: 7, 78: 15, 45: 13, 49: 10, 56: 17, 76: 9, 89: 19, 73: 1, 14: 7, 9: 3, 6: 7, 20: 6, 98: 14, 36: 16, 55: 0, 72: 0, 43: 8, 51: 4, 35: 14, 83: 4, 33: 10, 27: 15, 53: 4, 92: 2, 50: 16, 15: 11, 18: 7, 46: 14, 75: 12, 38: 11, 66: 12, 77: 13, 69: 19, 95: 0, 99: 13, 93: 15, 4: 0, 61: 3, 94: 6, 68: 9, 34: 12, 32: 1, 88: 8, 67: 1, 30: 0, 62: 2, 63: 12, 40: 5, 26: 13, 48: 18, 79: 13, 85: 19, 54: 2, 44: 15, 7: 7, 12: 9, 2: 14, 41: 19, 37: 9, 13: 18, 25: 6, 10: 3, 57: 4, 5: 6, 60: 10, 91: 1, 3: 8, 58: 18, 16: 3}\n",
            "Size of fine_classes 100\n",
            "Size of coarse_classes 20\n",
            "Size of fine_to_coarse 100\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "q_4KpLJ1dHAh"
      },
      "outputs": [],
      "source": [
        "# Custom Dataset to include fine and coarse labels (integer labels)\n",
        "class CIFAR100WithCoarse(Dataset):\n",
        "    \"\"\"Wrap an (image, target) dataset and yield (image, fine_label, coarse_label).\n",
        "\n",
        "    The wrapped dataset's own target is ignored; labels are read from the parallel\n",
        "    `fine_labels` / `coarse_labels` sequences at the same index. `transform`, if\n",
        "    given, is applied to the image only.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, fine_labels, coarse_labels, dataset, transform=None):\n",
        "        self.dataset = dataset\n",
        "        self.fine_labels, self.coarse_labels = fine_labels, coarse_labels\n",
        "        self.transform = transform\n",
        "\n",
        "    def __len__(self):\n",
        "        # Length follows the label arrays, not the wrapped dataset.\n",
        "        return len(self.fine_labels)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        image, _target = self.dataset[idx]\n",
        "        labels = (self.fine_labels[idx], self.coarse_labels[idx])\n",
        "        if self.transform:\n",
        "            image = self.transform(image)\n",
        "        return (image, *labels)\n",
        "\n",
        "# Custom Dataset to return one-hot encoded labels\n",
        "class CIFAR100WithCoarseOneHot(Dataset):\n",
        "    \"\"\"Wrap an (image, target) dataset and yield float one-hot fine / coarse labels.\n",
        "\n",
        "    Items are (image, fine_one_hot, coarse_one_hot) with vector lengths\n",
        "    `num_fine_labels` and `num_coarse_labels`; the wrapped target is ignored.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, fine_labels, coarse_labels, dataset, num_fine_labels, num_coarse_labels, transform=None):\n",
        "        self.dataset = dataset\n",
        "        self.transform = transform\n",
        "        self.fine_labels = fine_labels\n",
        "        self.coarse_labels = coarse_labels\n",
        "        self.num_fine_labels = num_fine_labels\n",
        "        self.num_coarse_labels = num_coarse_labels\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.fine_labels)\n",
        "\n",
        "    @staticmethod\n",
        "    def _one_hot(label, num_classes):\n",
        "        \"\"\"Integer class id -> float32 one-hot tensor of length `num_classes`.\"\"\"\n",
        "        return F.one_hot(torch.tensor(label), num_classes=num_classes).float()\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        image, _target = self.dataset[idx]\n",
        "        fine_label = self._one_hot(self.fine_labels[idx], self.num_fine_labels)\n",
        "        coarse_label = self._one_hot(self.coarse_labels[idx], self.num_coarse_labels)\n",
        "\n",
        "        if self.transform:\n",
        "            image = self.transform(image)\n",
        "\n",
        "        return image, fine_label, coarse_label\n",
        "\n",
        "class CIFAR100WithBeliefEncoding(Dataset):\n",
        "    \"\"\"Wrap an (image, target) dataset and yield belief-encoded fine / coarse targets.\n",
        "\n",
        "    Without true indices (training) items are (image, fine_belief, coarse_belief);\n",
        "    when `fine_true_idx` is given (validation / test) the true fine and coarse class\n",
        "    indices are appended as tensors, giving a 5-tuple.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, fine_beliefs, coarse_beliefs, dataset, transform=None,\n",
        "                 fine_true_idx=None, coarse_true_idx=None):\n",
        "        self.dataset = dataset\n",
        "        self.transform = transform\n",
        "        self.fine_beliefs = fine_beliefs      # multi-hot over focal sets\n",
        "        self.coarse_beliefs = coarse_beliefs  # multi-hot over focal sets\n",
        "        # NOTE(review): only fine_true_idx is checked in __getitem__; pass both or neither.\n",
        "        self.fine_true_idx = fine_true_idx\n",
        "        self.coarse_true_idx = coarse_true_idx\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.fine_beliefs)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        image, _ = self.dataset[idx]\n",
        "        if self.transform:\n",
        "            image = self.transform(image)\n",
        "        sample = (image,\n",
        "                  torch.tensor(self.fine_beliefs[idx]).float(),\n",
        "                  torch.tensor(self.coarse_beliefs[idx]).float())\n",
        "        if self.fine_true_idx is None:\n",
        "            return sample\n",
        "        return sample + (torch.tensor(self.fine_true_idx[idx]),\n",
        "                         torch.tensor(self.coarse_true_idx[idx]))\n",
        "\n",
        "# Per-channel statistics of the CIFAR-100 *training* split (printed by the statistics cell).\n",
        "# Both pipelines normalize with them. Evaluation inputs must be preprocessed exactly like\n",
        "# training inputs, test-set statistics must not leak into preprocessing, and the\n",
        "# validation split is itself drawn from the training set.\n",
        "CIFAR100_TRAIN_MEAN = (0.5071, 0.4866, 0.4409)\n",
        "CIFAR100_TRAIN_STD = (0.2673, 0.2564, 0.2762)\n",
        "\n",
        "# Training: random geometric / photometric augmentation (incl. AutoAugment with the\n",
        "# CIFAR-10 policy) on the image, then float conversion with scaling, normalization and\n",
        "# RandomErasing on the tensor.\n",
        "train_transform = v2.Compose([\n",
        "    v2.ToImage(),\n",
        "    v2.RandomResizedCrop(224, scale=(0.8, 1.0), antialias=True),\n",
        "    v2.RandomHorizontalFlip(),\n",
        "    v2.RandomRotation(10),\n",
        "    v2.ColorJitter(0.2, 0.2, 0.2, 0.1),\n",
        "    v2.AutoAugment(v2.AutoAugmentPolicy.CIFAR10),\n",
        "    v2.ToDtype(torch.float32, scale=True),\n",
        "    v2.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD),\n",
        "    v2.RandomErasing(p=0.25, scale=(0.02, 0.2), ratio=(0.3, 3.3)),\n",
        "])\n",
        "\n",
        "# Validation / test: deterministic resize to 224x224, same normalization as training.\n",
        "val_test_transform = v2.Compose([\n",
        "    v2.ToImage(),\n",
        "    v2.Resize((224, 224), antialias=True),\n",
        "    v2.ToDtype(torch.float32, scale=True),\n",
        "    v2.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD),\n",
        "])\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Load CIFAR-100 without transforms; the wrapper datasets apply them per item\n",
        "train_dataset_original = CIFAR100(root='./data', train=True, download=True)\n",
        "test_dataset_original = CIFAR100(root='./data', train=False, download=True)\n",
        "\n",
        "# Split the dataset: 40,000 for training and 10,000 for validation (seeded, so the split is stable across runs)\n",
        "train_size, val_size = 40000, 10000\n",
        "generator = torch.Generator().manual_seed(42)\n",
        "train_dataset_part, val_dataset_part = random_split(\n",
        "    train_dataset_original, [train_size, val_size], generator=generator\n",
        ")\n",
        "\n",
        "def _coarse_of(fine_labels):\n",
        "    \"\"\"Map each fine label id to its coarse (superclass) id via `fine_to_coarse`.\"\"\"\n",
        "    return np.array([fine_to_coarse[label] for label in fine_labels])\n",
        "\n",
        "# Fine labels per split: original targets, picked by the Subset indices for train / val\n",
        "all_train_targets = np.array(train_dataset_part.dataset.targets)\n",
        "fine_labels_train = all_train_targets[train_dataset_part.indices]\n",
        "fine_labels_val = all_train_targets[val_dataset_part.indices]\n",
        "fine_labels_test = np.array(test_dataset_original.targets)\n",
        "\n",
        "coarse_labels_train = _coarse_of(fine_labels_train)\n",
        "coarse_labels_val = _coarse_of(fine_labels_val)\n",
        "coarse_labels_test = _coarse_of(fine_labels_test)\n"
      ],
      "metadata": {
        "id": "Q8b42HPkTQfS"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VOcHifBVZQaR"
      },
      "outputs": [],
      "source": [
        "# Number of fine and coarse labels for one-hot encoding\n",
        "num_fine_labels = len(fine_classes)\n",
        "num_coarse_labels = len(coarse_classes)\n",
        "\n",
        "# Create the custom datasets (integer labels)\n",
        "train_dataset_int = CIFAR100WithCoarse(fine_labels_train, coarse_labels_train, train_dataset_part, transform=train_transform)\n",
        "val_dataset_int = CIFAR100WithCoarse(fine_labels_val, coarse_labels_val, val_dataset_part, transform=val_test_transform)\n",
        "test_dataset_int = CIFAR100WithCoarse(fine_labels_test, coarse_labels_test, test_dataset_original, transform=val_test_transform)\n",
        "\n",
        "# Create the custom datasets (one-hot encoded labels)\n",
        "train_dataset_one_hot = CIFAR100WithCoarseOneHot(fine_labels_train, coarse_labels_train, train_dataset_part, num_fine_labels, num_coarse_labels, transform=train_transform)\n",
        "val_dataset_one_hot = CIFAR100WithCoarseOneHot(fine_labels_val, coarse_labels_val, val_dataset_part, num_fine_labels, num_coarse_labels, transform=val_test_transform)\n",
        "test_dataset_one_hot = CIFAR100WithCoarseOneHot(fine_labels_test, coarse_labels_test, test_dataset_original, num_fine_labels, num_coarse_labels, transform=val_test_transform)\n",
        "\n",
        "# Evaluation loaders use their own seeded generator. A DataLoader draws its workers'\n",
        "# base seed from its generator, and the training loaders' shuffling sampler draws from\n",
        "# `g` every epoch; sharing `g` with the evaluation loaders made the training shuffle\n",
        "# order depend on whether (and when) the validation / test loaders had been iterated.\n",
        "g_eval = torch.Generator().manual_seed(42)\n",
        "\n",
        "# dataloaders\n",
        "# NOTE(review): each loader keeps its 32 worker processes alive after first use\n",
        "# (persistent_workers=True); lower num_workers if the host is short on CPU or RAM.\n",
        "loader_kwargs_train = dict(\n",
        "    batch_size=1024,\n",
        "    shuffle=True,\n",
        "    num_workers=32,\n",
        "    pin_memory=True,\n",
        "    persistent_workers=True,\n",
        "    prefetch_factor=8,\n",
        "    worker_init_fn=seed_worker,\n",
        "    generator=g,\n",
        "    drop_last=True,  # every training batch has exactly batch_size samples\n",
        ")\n",
        "\n",
        "loader_kwargs_eval = dict(\n",
        "    batch_size=1024,\n",
        "    shuffle=False,\n",
        "    num_workers=32,\n",
        "    pin_memory=True,\n",
        "    persistent_workers=True,\n",
        "    worker_init_fn=seed_worker,\n",
        "    generator=g_eval,\n",
        ")\n",
        "\n",
        "# Integer-labeled DataLoaders\n",
        "train_loader_int = DataLoader(train_dataset_int, **loader_kwargs_train)\n",
        "val_loader_int   = DataLoader(val_dataset_int, **loader_kwargs_eval)\n",
        "test_loader_int  = DataLoader(test_dataset_int, **loader_kwargs_eval)\n",
        "\n",
        "# One-hot-labeled DataLoaders\n",
        "train_loader_one_hot = DataLoader(train_dataset_one_hot, **loader_kwargs_train)\n",
        "val_loader_one_hot   = DataLoader(val_dataset_one_hot, **loader_kwargs_eval)\n",
        "test_loader_one_hot  = DataLoader(test_dataset_one_hot, **loader_kwargs_eval)\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## iNATURALIST: DATA LOADING AND AUGMENTATION (DO NOT RUN WITH CIFAR)"
      ],
      "metadata": {
        "id": "3RPl8dCFgFGX"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "\n",
        "# Download the iNaturalist 2021 validation split into Google Drive - RUN JUST ONCE.\n",
        "# The extraction cells below read /content/drive/MyDrive/inat_data/2021_valid.tgz,\n",
        "# presumably the archive this download leaves in the same folder.\n",
        "root = \"/content/drive/MyDrive/inat_data\"\n",
        "\n",
        "transform = transforms.ToTensor()\n",
        "\n",
        "# Genus, family and class targets (the later loading cell only uses family and class)\n",
        "dataset = INaturalist(\n",
        "    root=root,\n",
        "    version=\"2021_valid\",\n",
        "    target_type=[\"genus\", \"family\", \"class\"],\n",
        "    transform=transform,\n",
        "    download=True\n",
        ")\n"
      ],
      "metadata": {
        "id": "GeS2bRFKgJIV"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from torch.utils.data import DataLoader\n",
        "import torch\n",
        "from tqdm import tqdm\n",
        "import torchvision.transforms.v2 as v2\n",
        "\n",
        "class TempDataset(Dataset):\n",
        "    \"\"\"Wrap `subset` of (img, target) pairs, applying `transform` (if truthy) to img only.\"\"\"\n",
        "\n",
        "    def __init__(self, subset, transform):\n",
        "        self.subset, self.transform = subset, transform\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.subset)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        image, label = self.subset[idx]\n",
        "        return (self.transform(image) if self.transform else image), label\n",
        "\n",
        "\n",
        "def compute_mean_std(dataset, batch_size=256, num_workers=4):\n",
        "    \"\"\"Per-channel mean and std over every pixel of every image in `dataset`.\n",
        "\n",
        "    Args:\n",
        "        dataset: yields (img, target) where img is a float tensor [C, H, W]; all images\n",
        "            must share one size so the default collate can stack them.\n",
        "        batch_size, num_workers: forwarded to the DataLoader.\n",
        "\n",
        "    Returns:\n",
        "        (mean, std): float32 tensors of shape [C].\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if `dataset` yields no pixels.\n",
        "    \"\"\"\n",
        "    loader = DataLoader(dataset, batch_size=batch_size,\n",
        "                        num_workers=num_workers, shuffle=False)\n",
        "    n_pixels = 0\n",
        "    total = None     # per-channel running sum of x\n",
        "    total_sq = None  # per-channel running sum of x**2\n",
        "\n",
        "    for imgs, _ in loader:\n",
        "        # float64 accumulation: float32 sums over millions of pixels lose precision\n",
        "        flat = imgs.reshape(imgs.size(0), imgs.size(1), -1).double()  # [B, C, H*W]\n",
        "        batch_sum = flat.sum(dim=(0, 2))\n",
        "        batch_sq = (flat ** 2).sum(dim=(0, 2))\n",
        "        total = batch_sum if total is None else total + batch_sum\n",
        "        total_sq = batch_sq if total_sq is None else total_sq + batch_sq\n",
        "        n_pixels += flat.size(0) * flat.size(2)\n",
        "\n",
        "    if n_pixels == 0:\n",
        "        raise ValueError(\"compute_mean_std: no pixels to average (empty dataset?)\")\n",
        "\n",
        "    mean = total / n_pixels\n",
        "    # E[x^2] - E[x]^2 can come out slightly negative from rounding; clamp before sqrt\n",
        "    var = (total_sq / n_pixels - mean ** 2).clamp_min(0.0)\n",
        "    return mean.float(), var.sqrt().float()\n",
        "\n",
        "\n",
        "# NOTE: train_d / val_d / test_d are created by the \"Train/Val/Test split\" cell further\n",
        "# down, so this cell only works after that one has run. A later cell also overwrites\n",
        "# train_mean / train_std etc. with hard-coded values.\n",
        "_missing_splits = [name for name in (\"train_d\", \"val_d\", \"test_d\") if name not in globals()]\n",
        "if _missing_splits:\n",
        "    raise NameError(f\"{_missing_splits} not defined - run the Train/Val/Test split cell first\")\n",
        "\n",
        "tmp_transform = v2.Compose([\n",
        "    v2.Resize((224, 224)),\n",
        "    v2.ToImage(),\n",
        "    v2.ToDtype(torch.float32, scale=True)\n",
        "])\n",
        "\n",
        "train_mean, train_std = compute_mean_std(TempDataset(train_d, tmp_transform))\n",
        "val_mean, val_std     = compute_mean_std(TempDataset(val_d, tmp_transform))\n",
        "test_mean, test_std   = compute_mean_std(TempDataset(test_d, tmp_transform))\n",
        "\n",
        "print(\"Train mean/std:\", train_mean, train_std)\n",
        "print(\"Val   mean/std:\", val_mean, val_std)\n",
        "print(\"Test  mean/std:\", test_mean, test_std)\n"
      ],
      "metadata": {
        "id": "qqzJE7rERFKj",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "79e35eea-98a3-44e9-8371-e6b24d0b1d5b"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Train mean/std: tensor([0.4649, 0.4814, 0.3771]) tensor([0.2337, 0.2246, 0.2436])\n",
            "Val   mean/std: tensor([0.4628, 0.4811, 0.3768]) tensor([0.2335, 0.2245, 0.2434])\n",
            "Test  mean/std: tensor([0.4646, 0.4815, 0.3768]) tensor([0.2321, 0.2236, 0.2429])\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "91c218ff"
      },
      "source": [
        "### LOAD DATA, PREPROCESSING AND FINE-TO-COARSE MAPPING"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Local target directory for the extracted iNaturalist 2021 validation archive\n",
        "!mkdir -p /content/inat_data/2021_valid\n"
      ],
      "metadata": {
        "id": "HswB-qO10vJx"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Unpack the Drive copy of the archive to local disk; it extracts into a nested val/\n",
        "# folder, which a later cell flattens into 2021_valid/.\n",
        "!tar -xzf /content/drive/MyDrive/inat_data/2021_valid.tgz -C /content/inat_data/2021_valid\n",
        "\n"
      ],
      "metadata": {
        "id": "7I9cSc1j0xP6"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Sanity check: top-level contents of the extraction directory\n",
        "!ls /content/inat_data/2021_valid | head\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Vc59H_eW03Ij",
        "outputId": "fc83aa2f-1d18-430a-e37c-5cafcebafe8a"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "val\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# First few category folders inside the nested val/ directory\n",
        "!ls /content/inat_data/2021_valid/val | head\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "dZzqoc8w1Iv8",
        "outputId": "6e89652d-70cf-4584-f65a-aa0e4623ab55"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "00000_Animalia_Annelida_Clitellata_Haplotaxida_Lumbricidae_Lumbricus_terrestris\n",
            "00001_Animalia_Annelida_Polychaeta_Sabellida_Sabellidae_Sabella_spallanzanii\n",
            "00002_Animalia_Annelida_Polychaeta_Sabellida_Serpulidae_Serpula_columbiana\n",
            "00003_Animalia_Annelida_Polychaeta_Sabellida_Serpulidae_Spirobranchus_cariniferus\n",
            "00004_Animalia_Arthropoda_Arachnida_Araneae_Agelenidae_Eratigena_duellica\n",
            "00005_Animalia_Arthropoda_Arachnida_Araneae_Antrodiaetidae_Atypoides_riversi\n",
            "00006_Animalia_Arthropoda_Arachnida_Araneae_Araneidae_Aculepeira_ceropegia\n",
            "00007_Animalia_Arthropoda_Arachnida_Araneae_Araneidae_Agalenatea_redii\n",
            "00008_Animalia_Arthropoda_Arachnida_Araneae_Araneidae_Araneus_bicentenarius\n",
            "00009_Animalia_Arthropoda_Arachnida_Araneae_Araneidae_Araneus_diadematus\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# The archive extracts into 2021_valid/val/<category>/, but the loading cell reads\n",
        "# INaturalist(root=\"/content/inat_data\", version=\"2021_valid\"), so move the category\n",
        "# folders up one level. Guarded so re-running is a no-op; never overwrites an existing folder.\n",
        "import shutil\n",
        "from pathlib import Path\n",
        "\n",
        "nested_dir = Path(\"/content/inat_data/2021_valid/val\")\n",
        "if nested_dir.is_dir():\n",
        "    for entry in sorted(nested_dir.iterdir()):\n",
        "        destination = nested_dir.parent / entry.name\n",
        "        if destination.exists():\n",
        "            raise FileExistsError(f\"{destination} already exists; refusing to overwrite it\")\n",
        "        shutil.move(str(entry), str(destination))\n",
        "    nested_dir.rmdir()  # raises if anything was left behind, instead of rm -rf deleting it\n"
      ],
      "metadata": {
        "id": "jfEo111I12ZL"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# ----------------- Load the local, flattened copy (no image transform) -----------------\n",
        "root = \"/content/inat_data\"\n",
        "dataset = INaturalist(root=root, version=\"2021_valid\", target_type=[\"family\", \"class\"],\n",
        "                      transform=None, download=False)\n",
        "print(\"Samples:\", len(dataset))\n",
        "\n",
        "# ----------------- Taxonomic classes (coarse level) to keep -----------------\n",
        "chosen_class_names = [\n",
        "    \"Mammalia\", \"Aves\", \"Reptilia\", \"Amphibia\", \"Actinopterygii\",\n",
        "    \"Insecta\", \"Arachnida\", \"Chilopoda\", \"Diplopoda\", \"Malacostraca\",\n",
        "    \"Bivalvia\", \"Cephalopoda\", \"Gastropoda\",\n",
        "    \"Agaricomycetes\", \"Lecanoromycetes\",\n",
        "    \"Magnoliopsida\", \"Liliopsida\", \"Pinopsida\", \"Polypodiopsida\",\n",
        "    \"Bryopsida\",\n",
        "]\n",
        "\n",
        "# Class name -> class id; a misspelt name raises KeyError here instead of silently matching nothing\n",
        "class_name_to_id = dataset.categories_index[\"class\"]\n",
        "chosen_class_ids = list(map(class_name_to_id.__getitem__, chosen_class_names))\n"
      ],
      "metadata": {
        "id": "CnPcjQABzX4C",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "bec390f8-f40e-4f00-91c7-ef432a524c9f"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Samples: 100000\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from collections import Counter\n",
        "\n",
        "# ----------------- Family frequencies within the chosen classes (metadata only) -----------------\n",
        "family_counts = Counter(\n",
        "    dataset.category_name(\"family\", dataset.categories_map[cat_id][\"family\"])\n",
        "    for cat_id, _ in dataset.index\n",
        "    if dataset.categories_map[cat_id][\"class\"] in chosen_class_ids\n",
        ")\n",
        "\n",
        "print(\"Total families in chosen classes:\", len(family_counts))\n",
        "print(\"Top 50 families:\", family_counts.most_common(50))\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "zKVjZs_QGMoO",
        "outputId": "210e32e1-1753-4d91-ee42-ddc60617bc21"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Total families in chosen classes: 1008\n",
            "Top 50 families: [('Asteraceae', 4950), ('Fabaceae', 2370), ('Nymphalidae', 2340), ('Erebidae', 1790), ('Geometridae', 1770), ('Noctuidae', 1770), ('Rosaceae', 1480), ('Poaceae', 1240), ('Libellulidae', 1220), ('Orchidaceae', 1130), ('Lamiaceae', 1070), ('Lycaenidae', 1060), ('Ericaceae', 990), ('Ranunculaceae', 950), ('Crambidae', 930), ('Anatidae', 900), ('Boraginaceae', 860), ('Hesperiidae', 830), ('Colubridae', 830), ('Cactaceae', 800), ('Accipitridae', 740), ('Plantaginaceae', 740), ('Asparagaceae', 730), ('Cyperaceae', 730), ('Brassicaceae', 700), ('Apiaceae', 680), ('Coenagrionidae', 630), ('Pieridae', 620), ('Orobanchaceae', 610), ('Liliaceae', 590), ('Apocynaceae', 590), ('Euphorbiaceae', 590), ('Malvaceae', 580), ('Sphingidae', 550), ('Scolopacidae', 530), ('Polygonaceae', 530), ('Onagraceae', 510), ('Pinaceae', 510), ('Fagaceae', 500), ('Laridae', 490), ('Caryophyllaceae', 490), ('Acrididae', 470), ('Parulidae', 470), ('Rubiaceae', 470), ('Papilionidae', 460), ('Araneidae', 450), ('Cerambycidae', 450), ('Tyrannidae', 450), ('Picidae', 430), ('Apidae', 420)]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# ----------------- Keep only images from the 50 most frequent families -----------------\n",
        "top_families = set(name for name, _count in family_counts.most_common(50))\n",
        "\n",
        "keep_indices = [\n",
        "    i\n",
        "    for i, (cat_id, _) in enumerate(dataset.index)\n",
        "    for cats in (dataset.categories_map[cat_id],)\n",
        "    if cats[\"class\"] in chosen_class_ids\n",
        "    and dataset.category_name(\"family\", cats[\"family\"]) in top_families\n",
        "]\n",
        "\n",
        "print(\"Samples after filtering:\", len(keep_indices))\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "tJZtgd43Hv2c",
        "outputId": "c7cc0731-2772-463e-e286-cbd912645754"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Samples after filtering: 45990\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "class NamedSubset(Dataset):\n",
        "    \"\"\"View of `base` restricted to `indices`, returning (img, labels) with named ids.\n",
        "\n",
        "    labels is a dict with the integer \"family_id\" / \"class_id\" and their names\n",
        "    resolved through `base.category_name`.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, base, indices):\n",
        "        self.base = base\n",
        "        self.indices = indices\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.indices)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        real_idx = self.indices[idx]\n",
        "        img, target = self.base[real_idx]\n",
        "        family_id, class_id = self._family_and_class(target)\n",
        "        return img, {\n",
        "            \"family_id\": family_id,\n",
        "            \"class_id\": class_id,\n",
        "            \"family_name\": self.base.category_name(\"family\", family_id),\n",
        "            \"class_name\": self.base.category_name(\"class\", class_id),\n",
        "        }\n",
        "\n",
        "    def _family_and_class(self, target):\n",
        "        \"\"\"Pick (family_id, class_id) out of the base dataset's target.\n",
        "\n",
        "        Uses `base.target_type` (the requested target names, in order) when the base\n",
        "        exposes it; otherwise falls back to position: (genus, family, class) or (family, class).\n",
        "        \"\"\"\n",
        "        if not isinstance(target, (tuple, list)):\n",
        "            target = (target,)  # presumably a single requested target comes back as a bare id\n",
        "        target_types = getattr(self.base, \"target_type\", None)\n",
        "        if target_types is not None and len(target_types) == len(target):\n",
        "            by_name = dict(zip(target_types, target))\n",
        "            if \"family\" not in by_name or \"class\" not in by_name:\n",
        "                raise ValueError(f\"base target_type {list(target_types)} lacks 'family' or 'class'\")\n",
        "            return by_name[\"family\"], by_name[\"class\"]\n",
        "        if len(target) == 3:\n",
        "            return target[1], target[2]\n",
        "        if len(target) == 2:\n",
        "            return target[0], target[1]\n",
        "        raise ValueError(f\"cannot find family/class ids in target {target!r}\")\n",
        "\n",
        "\n",
        "# Filtered view: only the images kept by the top-50-family filter, with named labels\n",
        "subset = NamedSubset(dataset, keep_indices)\n"
      ],
      "metadata": {
        "id": "HKR2E3ayQt7d"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# ----------------- Train/Val/Test split (70 / 15 / 15, seeded) -----------------\n",
        "n_total = len(subset)\n",
        "n_train = int(0.7 * n_total)\n",
        "n_val = int(0.15 * n_total)\n",
        "n_test = n_total - (n_train + n_val)  # remainder, so the three sizes always sum to n_total\n",
        "g = torch.Generator().manual_seed(42)  # the DataLoader cells also pass this generator\n",
        "train_d, val_d, test_d = random_split(subset, [n_train, n_val, n_test], generator=g)\n",
        "\n",
        "# ----------------- Label mapping -----------------\n",
        "def _iter_label_names(split):\n",
        "    \"\"\"Yield (family_name, class_name) for every sample of `split`, in order.\n",
        "\n",
        "    Fast path: when `split` is a random_split Subset of a NamedSubset whose base exposes\n",
        "    `index`, `categories_map` and `category_name` (as INaturalist does), names are read from\n",
        "    that metadata - the same lookup the family-counting cell uses - so no image is decoded.\n",
        "    Otherwise falls back to iterating the split, which loads every image.\n",
        "    \"\"\"\n",
        "    named = getattr(split, \"dataset\", None)\n",
        "    base = getattr(named, \"base\", None)\n",
        "    has_metadata = all(hasattr(base, a) for a in (\"index\", \"categories_map\", \"category_name\"))\n",
        "    if hasattr(split, \"indices\") and hasattr(named, \"indices\") and has_metadata:\n",
        "        for i in split.indices:\n",
        "            cat_id = base.index[named.indices[i]][0]\n",
        "            cats = base.categories_map[cat_id]\n",
        "            yield (base.category_name(\"family\", cats[\"family\"]),\n",
        "                   base.category_name(\"class\", cats[\"class\"]))\n",
        "    else:\n",
        "        for _, t in split:\n",
        "            yield t[\"family_name\"], t[\"class_name\"]\n",
        "\n",
        "\n",
        "def build_label_indexers(split):\n",
        "    \"\"\"Map each family / class name present in `split` to a contiguous index (sorted by name).\n",
        "\n",
        "    Returns:\n",
        "        (fam2idx, cls2idx): dicts name -> int in [0, number of distinct names).\n",
        "    \"\"\"\n",
        "    fam_names, cls_names = set(), set()\n",
        "    for fam, cls in _iter_label_names(split):\n",
        "        fam_names.add(fam)\n",
        "        cls_names.add(cls)\n",
        "    fam2idx = {f: i for i, f in enumerate(sorted(fam_names))}\n",
        "    cls2idx = {c: i for i, c in enumerate(sorted(cls_names))}\n",
        "    return fam2idx, cls2idx\n",
        "\n",
        "# Label vocabularies are built from the training split only; the val/test wrappers reuse them\n",
        "fam2idx, cls2idx = build_label_indexers(train_d)\n",
        "num_fine, num_coarse = len(fam2idx), len(cls2idx)\n",
        "print(f\"Fine labels: {num_fine}, Coarse labels: {num_coarse}\")\n"
      ],
      "metadata": {
        "id": "NeSYRpESVJBt",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "bf236d03-5944-48d5-9ac0-d6711007b265"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Fine labels: 50, Coarse labels: 7\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "print(f\"Total samples: {len(subset)}\")"
      ],
      "metadata": {
        "id": "ELitgepea-Gc",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "51a71aa4-278e-4f6a-c1be-d792e3fa675c"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Total samples: 45990\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# ----------------- Dataset wrappers -----------------\n",
        "class INatWithCoarse(Dataset):\n",
        "    \"\"\"Yield (img, fine_idx, coarse_idx) from items shaped (img, {\"family_name\", \"class_name\", ...}).\"\"\"\n",
        "\n",
        "    def __init__(self, base, fam2idx, cls2idx, transform=None):\n",
        "        self.base = base\n",
        "        self.f2i = fam2idx   # family name -> fine label index\n",
        "        self.c2i = cls2idx   # class name  -> coarse label index\n",
        "        self.transform = transform\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.base)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        img, names = self.base[idx]\n",
        "        if self.transform:\n",
        "            img = self.transform(img)\n",
        "        fine = self.f2i[names[\"family_name\"]]\n",
        "        coarse = self.c2i[names[\"class_name\"]]\n",
        "        return img, fine, coarse\n",
        "\n",
        "class INatWithCoarseOneHot(Dataset):\n",
        "    \"\"\"Same samples as INatWithCoarse, but both labels returned as float one-hot vectors.\"\"\"\n",
        "\n",
        "    def __init__(self, base, fam2idx, cls2idx, num_fam, num_cls, transform=None):\n",
        "        self.base = INatWithCoarse(base, fam2idx, cls2idx, transform)\n",
        "        self.num_fam = num_fam\n",
        "        self.num_cls = num_cls\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.base)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        img, fine_idx, coarse_idx = self.base[idx]\n",
        "        fine_vec = F.one_hot(torch.tensor(fine_idx), self.num_fam).float()\n",
        "        coarse_vec = F.one_hot(torch.tensor(coarse_idx), self.num_cls).float()\n",
        "        return img, fine_vec, coarse_vec\n",
        "\n",
        "# ----------------- Belief-encoding wrapper -----------------\n",
        "class INatWithBeliefEncoding(Dataset):\n",
        "    \"\"\"Pair each image of `dataset` with precomputed fine/coarse belief vectors.\n",
        "\n",
        "    Items are (img, fine_belief, coarse_belief) as float tensors; when both\n",
        "    `fine_true_idx` and `coarse_true_idx` are given, the true fine and coarse\n",
        "    indices are appended as long scalar tensors (val/test mode).\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, fine_beliefs, coarse_beliefs, dataset,\n",
        "                 transform=None, fine_true_idx=None, coarse_true_idx=None):\n",
        "        \"\"\"\n",
        "        fine_beliefs:    numpy array or tensor [N, num_fine_labels] (multi-hot or soft)\n",
        "        coarse_beliefs:  numpy array or tensor [N, num_coarse_labels]\n",
        "        dataset:         base dataset (aligned with beliefs); items must unpack as (img, labels)\n",
        "        transform:       optional torchvision transform applied to img\n",
        "        fine_true_idx:   optional [N] array of true fine indices\n",
        "        coarse_true_idx: optional [N] array of true coarse indices\n",
        "        \"\"\"\n",
        "        n = len(dataset)\n",
        "        assert len(fine_beliefs) == len(coarse_beliefs) == n, (\n",
        "            f\"length mismatch: fine_beliefs={len(fine_beliefs)}, \"\n",
        "            f\"coarse_beliefs={len(coarse_beliefs)}, dataset={n}\"\n",
        "        )\n",
        "        for name, true_idx in ((\"fine_true_idx\", fine_true_idx), (\"coarse_true_idx\", coarse_true_idx)):\n",
        "            assert true_idx is None or len(true_idx) == n, (\n",
        "                f\"{name} has {len(true_idx)} entries, expected {n}\"\n",
        "            )\n",
        "\n",
        "        self.fine_beliefs = fine_beliefs\n",
        "        self.coarse_beliefs = coarse_beliefs\n",
        "        self.dataset = dataset\n",
        "        self.transform = transform\n",
        "        self.fine_true_idx = fine_true_idx\n",
        "        self.coarse_true_idx = coarse_true_idx\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.dataset)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        img, _ = self.dataset[idx]   # labels from the base dataset are ignored here\n",
        "\n",
        "        if self.transform:\n",
        "            img = self.transform(img)\n",
        "\n",
        "        fine_label   = torch.as_tensor(self.fine_beliefs[idx]).float()\n",
        "        coarse_label = torch.as_tensor(self.coarse_beliefs[idx]).float()\n",
        "\n",
        "        # Without both true-index arrays -> training mode (beliefs only)\n",
        "        if self.fine_true_idx is None or self.coarse_true_idx is None:\n",
        "            return img, fine_label, coarse_label\n",
        "\n",
        "        # With true indices -> val/test mode (beliefs + true ids)\n",
        "        return (\n",
        "            img,\n",
        "            fine_label,\n",
        "            coarse_label,\n",
        "            torch.tensor(int(self.fine_true_idx[idx])).long(),\n",
        "            torch.tensor(int(self.coarse_true_idx[idx])).long()\n",
        "        )\n"
      ],
      "metadata": {
        "id": "vBBrrPbpd0o_"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Per-channel RGB statistics, hard-coded so this cell does not need the slow statistics pass.\n",
        "# NOTE(review): they differ slightly from the values printed by the statistics cell above - confirm their source.\n",
        "train_mean = [0.4648, 0.4816, 0.3769]\n",
        "train_std  = [0.2334, 0.2243, 0.2434]\n",
        "val_mean   = [0.4638, 0.4809, 0.3774]  # val/test stats kept for reference only, not used below\n",
        "val_std    = [0.2334, 0.2247, 0.2433]\n",
        "test_mean  = [0.4645, 0.4807, 0.3770]\n",
        "test_std   = [0.2337, 0.2248, 0.2438]\n",
        "\n",
        "# ----------------- Transforms -----------------\n",
        "# Every split is normalised with the TRAINING statistics: evaluation data must be\n",
        "# preprocessed exactly like the data the model was fitted on, and deriving the\n",
        "# normalisation from val/test would leak information about those splits.\n",
        "train_transform = v2.Compose([\n",
        "    v2.ToImage(),\n",
        "    v2.RandomResizedCrop(224, scale=(0.8, 1.0), antialias=True),\n",
        "    v2.RandomHorizontalFlip(),\n",
        "    v2.RandomRotation(10),\n",
        "    v2.ColorJitter(0.2,0.2,0.2,0.1),\n",
        "    v2.ToDtype(torch.float32, scale=True),\n",
        "    v2.Normalize(mean=train_mean, std=train_std),\n",
        "])\n",
        "\n",
        "\n",
        "def _make_eval_transform(mean, std):\n",
        "    \"\"\"Deterministic 224x224 resize, float conversion and normalisation for val/test.\"\"\"\n",
        "    return v2.Compose([\n",
        "        v2.ToImage(),\n",
        "        v2.Resize((224, 224), antialias=True),\n",
        "        v2.ToDtype(torch.float32, scale=True),\n",
        "        v2.Normalize(mean=mean, std=std),\n",
        "    ])\n",
        "\n",
        "\n",
        "val_transform = _make_eval_transform(train_mean, train_std)\n",
        "test_transform = _make_eval_transform(train_mean, train_std)\n"
      ],
      "metadata": {
        "id": "8KtCg0nzRqMv"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# ----------------- Final datasets -----------------\n",
        "# Integer labels: items are (img, family_idx, class_idx)\n",
        "train_int = INatWithCoarse(train_d, fam2idx, cls2idx, transform=train_transform)\n",
        "val_int = INatWithCoarse(val_d, fam2idx, cls2idx, transform=val_transform)\n",
        "test_int = INatWithCoarse(test_d, fam2idx, cls2idx, transform=test_transform)\n",
        "\n",
        "# One-hot labels: items are (img, family_one_hot, class_one_hot)\n",
        "train_onehot = INatWithCoarseOneHot(\n",
        "    train_d, fam2idx, cls2idx, num_fine, num_coarse, transform=train_transform)\n",
        "val_onehot = INatWithCoarseOneHot(\n",
        "    val_d, fam2idx, cls2idx, num_fine, num_coarse, transform=val_transform)\n",
        "test_onehot = INatWithCoarseOneHot(\n",
        "    test_d, fam2idx, cls2idx, num_fine, num_coarse, transform=test_transform)\n",
        "\n",
        "# ----------------- DataLoaders -----------------\n",
        "# Options shared by every split; training also shuffles, prefetches deeper and drops the last partial batch\n",
        "_common_loader_kwargs = dict(\n",
        "    batch_size=512,\n",
        "    num_workers=12,\n",
        "    pin_memory=True,\n",
        "    persistent_workers=True,\n",
        "    worker_init_fn=seed_worker,\n",
        "    generator=g,\n",
        ")\n",
        "loader_kwargs_train = dict(_common_loader_kwargs, shuffle=True, prefetch_factor=4, drop_last=True)\n",
        "loader_kwargs_eval = dict(_common_loader_kwargs, shuffle=False)\n",
        "\n",
        "train_loader_int, val_loader_int, test_loader_int = (\n",
        "    DataLoader(train_int, **loader_kwargs_train),\n",
        "    DataLoader(val_int, **loader_kwargs_eval),\n",
        "    DataLoader(test_int, **loader_kwargs_eval),\n",
        ")\n",
        "\n",
        "train_loader_onehot, val_loader_onehot, test_loader_onehot = (\n",
        "    DataLoader(train_onehot, **loader_kwargs_train),\n",
        "    DataLoader(val_onehot, **loader_kwargs_eval),\n",
        "    DataLoader(test_onehot, **loader_kwargs_eval),\n",
        ")\n"
      ],
      "metadata": {
        "id": "mqGNOlUhRu0Z"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Per-sample fine (family) / coarse (class) label ids for the training split, plus the\n",
        "# fine->coarse map. Labels are read from the INaturalist metadata, so no image is decoded\n",
        "# (train_d is a random_split Subset of NamedSubset, whose .base is the INaturalist dataset).\n",
        "fine_labels_train = []\n",
        "coarse_labels_train = []\n",
        "fine_to_coarse = {}\n",
        "_named_split, _inat = train_d.dataset, train_d.dataset.base\n",
        "for i in train_d.indices:\n",
        "    cats = _inat.categories_map[_inat.index[_named_split.indices[i]][0]]\n",
        "    fam_id = fam2idx[_inat.category_name(\"family\", cats[\"family\"])]\n",
        "    cls_id = cls2idx[_inat.category_name(\"class\", cats[\"class\"])]\n",
        "    fine_labels_train.append(fam_id)\n",
        "    coarse_labels_train.append(cls_id)\n",
        "    # Each family must sit under exactly one class; fail loudly if the data says otherwise\n",
        "    if fine_to_coarse.setdefault(fam_id, cls_id) != cls_id:\n",
        "        raise ValueError(f\"family {fam_id} maps to classes {fine_to_coarse[fam_id]} and {cls_id}\")\n",
        "\n",
        "fine_labels_train   = np.array(fine_labels_train)\n",
        "coarse_labels_train = np.array(coarse_labels_train)\n",
        "\n",
        "print(\"fine_labels_train shape:\", fine_labels_train.shape)\n",
        "print(\"coarse_labels_train shape:\", coarse_labels_train.shape)\n"
      ],
      "metadata": {
        "id": "QIWfq-0NiC1c",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "c6fad9f0-215a-44c4-b8ed-f8b0fe30b1ed"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "fine_labels_train shape: (32192,)\n",
            "coarse_labels_train shape: (32192,)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Per-sample fine (family) / coarse (class) label ids for the validation split, read from\n",
        "# the INaturalist metadata so no image is decoded (val_d is a random_split Subset of\n",
        "# NamedSubset, whose .base is the INaturalist dataset). Uses the train-built fam2idx/cls2idx.\n",
        "fine_labels_val = []\n",
        "coarse_labels_val = []\n",
        "_named_split, _inat = val_d.dataset, val_d.dataset.base\n",
        "for i in val_d.indices:\n",
        "    cats = _inat.categories_map[_inat.index[_named_split.indices[i]][0]]\n",
        "    fine_labels_val.append(fam2idx[_inat.category_name(\"family\", cats[\"family\"])])\n",
        "    coarse_labels_val.append(cls2idx[_inat.category_name(\"class\", cats[\"class\"])])\n",
        "\n",
        "fine_labels_val   = np.array(fine_labels_val)\n",
        "coarse_labels_val = np.array(coarse_labels_val)\n",
        "\n",
        "print(\"fine_labels_val shape:\", fine_labels_val.shape)\n",
        "print(\"coarse_labels_val shape:\", coarse_labels_val.shape)\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "JJ4mkPsCCNd0",
        "outputId": "789325d8-d58b-4f75-8bbc-6088be252905"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "fine_labels_val shape: (6898,)\n",
            "coarse_labels_val shape: (6898,)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Per-sample fine (family) / coarse (class) label ids for the test split, read from\n",
        "# the INaturalist metadata so no image is decoded (test_d is a random_split Subset of\n",
        "# NamedSubset, whose .base is the INaturalist dataset). Uses the train-built fam2idx/cls2idx.\n",
        "fine_labels_test = []\n",
        "coarse_labels_test = []\n",
        "_named_split, _inat = test_d.dataset, test_d.dataset.base\n",
        "for i in test_d.indices:\n",
        "    cats = _inat.categories_map[_inat.index[_named_split.indices[i]][0]]\n",
        "    fine_labels_test.append(fam2idx[_inat.category_name(\"family\", cats[\"family\"])])\n",
        "    coarse_labels_test.append(cls2idx[_inat.category_name(\"class\", cats[\"class\"])])\n",
        "\n",
        "fine_labels_test   = np.array(fine_labels_test)\n",
        "coarse_labels_test = np.array(coarse_labels_test)\n",
        "\n",
        "print(\"fine_labels_test shape:\", fine_labels_test.shape)\n",
        "print(\"coarse_labels_test shape:\", coarse_labels_test.shape)\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "daz9bWErCaK6",
        "outputId": "a00ff34d-7126-4507-9285-136b225ef3f0"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "fine_labels_test shape: (6900,)\n",
            "coarse_labels_test shape: (6900,)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "print(f\"fine_to_coarse :  {fine_to_coarse}\")"
      ],
      "metadata": {
        "id": "527501WylTk7",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "79cd8161-f98a-4e4e-9fcb-51e7685a7db5"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "fine_to_coarse :  {45: 4, 23: 2, 39: 2, 44: 4, 32: 2, 8: 4, 17: 3, 22: 4, 6: 0, 2: 1, 21: 4, 40: 5, 33: 4, 29: 2, 31: 2, 16: 2, 49: 1, 11: 4, 27: 2, 26: 1, 5: 4, 24: 2, 25: 4, 34: 3, 3: 4, 28: 3, 9: 4, 12: 4, 13: 2, 10: 4, 15: 6, 7: 3, 18: 2, 30: 4, 14: 2, 20: 4, 19: 4, 47: 1, 0: 1, 35: 4, 38: 1, 37: 1, 1: 2, 42: 3, 41: 4, 4: 2, 46: 4, 36: 2, 43: 4, 48: 2}\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Invert the name -> index maps so fine_classes[i] / coarse_classes[i] is the name of label i\n",
        "idx2fam = {idx: name for name, idx in fam2idx.items()}\n",
        "idx2cls = {idx: name for name, idx in cls2idx.items()}\n",
        "fine_classes = list(map(idx2fam.__getitem__, range(num_fine)))\n",
        "coarse_classes = list(map(idx2cls.__getitem__, range(num_coarse)))\n",
        "\n",
        "print(\"fine_classes\", fine_classes)\n",
        "print(\"coarse_classes\", coarse_classes)\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "X3gM_VlhSE78",
        "outputId": "65ac737d-c268-4dbe-b306-6b367cbd46fa"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "fine_classes ['Accipitridae', 'Acrididae', 'Anatidae', 'Apiaceae', 'Apidae', 'Apocynaceae', 'Araneidae', 'Asparagaceae', 'Asteraceae', 'Boraginaceae', 'Brassicaceae', 'Cactaceae', 'Caryophyllaceae', 'Cerambycidae', 'Coenagrionidae', 'Colubridae', 'Crambidae', 'Cyperaceae', 'Erebidae', 'Ericaceae', 'Euphorbiaceae', 'Fabaceae', 'Fagaceae', 'Geometridae', 'Hesperiidae', 'Lamiaceae', 'Laridae', 'Libellulidae', 'Liliaceae', 'Lycaenidae', 'Malvaceae', 'Noctuidae', 'Nymphalidae', 'Onagraceae', 'Orchidaceae', 'Orobanchaceae', 'Papilionidae', 'Parulidae', 'Picidae', 'Pieridae', 'Pinaceae', 'Plantaginaceae', 'Poaceae', 'Polygonaceae', 'Ranunculaceae', 'Rosaceae', 'Rubiaceae', 'Scolopacidae', 'Sphingidae', 'Tyrannidae']\n",
            "coarse_classes ['Arachnida', 'Aves', 'Insecta', 'Liliopsida', 'Magnoliopsida', 'Pinopsida', 'Reptilia']\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## MODEL"
      ],
      "metadata": {
        "id": "iWS7FQeNcKJX"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "\n",
        "# Swin backbone + shared MLP + two heads (fine and coarse) returning raw logits\n",
        "class SwinMultiTask(nn.Module):\n",
        "    \"\"\"Swin Transformer features -> two dense blocks -> fine and coarse linear heads.\n",
        "\n",
        "    Args:\n",
        "        num_fine_labels: output size of the fine head.\n",
        "        num_coarse_labels: output size of the coarse head.\n",
        "        backbone_name: timm model name, created with num_classes=0.\n",
        "        pretrained: passed to timm.create_model.\n",
        "        freeze_backbone: if True (default), backbone parameters get requires_grad=False.\n",
        "        dropout: dropout probability applied after each dense block.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, num_fine_labels, num_coarse_labels,\n",
        "                 backbone_name='swin_base_patch4_window7_224', pretrained=True,\n",
        "                 freeze_backbone=True, dropout=0.3):\n",
        "        super().__init__()\n",
        "\n",
        "        # Backbone from timm; its features feed the dense blocks below\n",
        "        self.base_model = timm.create_model(backbone_name, pretrained=pretrained, num_classes=0)\n",
        "\n",
        "        # Frozen by default: only the dense blocks and heads are trained\n",
        "        for param in self.base_model.parameters():\n",
        "            param.requires_grad = not freeze_backbone\n",
        "\n",
        "        # Dropout for regularisation\n",
        "        self.dropout = nn.Dropout(dropout)\n",
        "\n",
        "        # Batch normalisation for the two dense blocks\n",
        "        self.bn1 = nn.BatchNorm1d(1024)\n",
        "        self.bn2 = nn.BatchNorm1d(512)\n",
        "\n",
        "        # Dense layers on top of the backbone features\n",
        "        self.fc1 = nn.Linear(self.base_model.num_features, 1024)\n",
        "        self.fc2 = nn.Linear(1024, 512)\n",
        "\n",
        "        # Classification heads\n",
        "        self.fine_classifier = nn.Linear(512, num_fine_labels)\n",
        "        self.coarse_classifier = nn.Linear(512, num_coarse_labels)\n",
        "\n",
        "        self.relu = nn.ReLU()\n",
        "\n",
        "    def forward(self, x, return_features=False):\n",
        "        \"\"\"Return (fine_logits, coarse_logits); with return_features=True, return instead the\n",
        "        1024-d first-block activations (fc1 -> bn1 -> ReLU -> dropout), before fc2.\"\"\"\n",
        "        features = self.base_model(x)\n",
        "\n",
        "        # First dense block (1024 units)\n",
        "        x = self.dropout(self.relu(self.bn1(self.fc1(features))))\n",
        "        if return_features:\n",
        "            return x\n",
        "\n",
        "        # Second dense block (512 units)\n",
        "        x = self.dropout(self.relu(self.bn2(self.fc2(x))))\n",
        "\n",
        "        # Raw logits (no softmax) from both heads\n",
        "        return self.fine_classifier(x), self.coarse_classifier(x)\n"
      ],
      "metadata": {
        "id": "2NaEV7bCcMBw"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## COMPUTE FOCAL SETS"
      ],
      "metadata": {
        "id": "FSSkMTIib_ko"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### GET EMBEDDINGS"
      ],
      "metadata": {
        "id": "zVHYJp--cCy2"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Initialize model\n",
        "model_aux = SwinMultiTask(num_fine_labels, num_coarse_labels).to(device)\n",
        "model_aux.eval()\n",
        "\n",
        "# Extract features from penultimate layer\n",
        "def extract_features_from_loader(model, dataloader, device):\n",
        "    model.eval()\n",
        "    features, fine_labels, coarse_labels = [], [], []\n",
        "\n",
        "    with torch.no_grad():\n",
        "        for inputs, fine, coarse in dataloader:  # integer dataloader\n",
        "            inputs = inputs.to(device)\n",
        "            feats = model(inputs, return_features=True)\n",
        "            features.append(feats.cpu())\n",
        "            fine_labels.extend(fine.numpy().tolist())\n",
        "            coarse_labels.extend(coarse.numpy().tolist())\n",
        "\n",
        "    features = torch.cat(features, dim=0).numpy()\n",
        "    fine_labels = np.array(fine_labels)\n",
        "    coarse_labels = np.array(coarse_labels)\n",
        "    return features, fine_labels, coarse_labels\n",
        "\n",
        "# return data features with fine an coarse labels\n",
        "train_features, fine_train_labels, coarse_train_labels = extract_features_from_loader(\n",
        "    model_aux, train_loader_int, device\n",
        ")\n",
        "\n",
        "print(train_features.shape)\n",
        "print(fine_train_labels.shape)\n",
        "print(coarse_train_labels.shape)\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "hRXXO5yBcGG4",
        "outputId": "2daafa23-0171-433e-f5b6-bfffdd6b91c7"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "(40000, 1024)\n",
            "(40000,)\n",
            "(40000,)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### FOCAL SET CALCULATION"
      ],
      "metadata": {
        "id": "T3ekFKKRo8YY"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "\n",
        "# Fit Gaussian Mixture Models (GMM) for each class\n",
        "def fit_gmm(classes, train_embedded_tsne, y_train):\n",
        "    individual_gms = []\n",
        "    for i in range(len(classes)):\n",
        "        gmm = GaussianMixture(n_components=1, random_state=7)\n",
        "        gmm.fit(train_embedded_tsne[y_train == i])\n",
        "        individual_gms.append(gmm)\n",
        "\n",
        "    return individual_gms\n",
        "\n",
        "# Create ellipses based on GMM\n",
        "def ellipse(individual_gms, num_classes):\n",
        "    means = []\n",
        "    eigen_vecs = []\n",
        "    stds = []\n",
        "    feature_space = 3\n",
        "\n",
        "    for gmm in individual_gms:\n",
        "        means.append(gmm.means_[0])\n",
        "        v, w = linalg.eigh(gmm.covariances_[0])\n",
        "        v = 2.0 * np.sqrt(7.815) * np.sqrt(v)\n",
        "        stds.append(v)\n",
        "        eigen_vecs.append(w)\n",
        "\n",
        "    means = np.array(means)\n",
        "    eigen_vecs = np.array(eigen_vecs)\n",
        "    stds = np.array(stds)\n",
        "\n",
        "    max_std = np.max(stds)\n",
        "    max_len = int(max_std) + 2\n",
        "    reg_shape = (max_len,) * feature_space\n",
        "    center = np.array(reg_shape) // 2\n",
        "\n",
        "    indices = np.indices(reg_shape)\n",
        "    indices = np.transpose(indices, list(np.arange(1, len(reg_shape) + 1)) + [0])\n",
        "    indices = indices.reshape((np.prod(reg_shape), feature_space))\n",
        "\n",
        "    regions = []\n",
        "    vecs = indices - center\n",
        "    vec_norms = np.linalg.norm(vecs, axis=-1) + 1e-31\n",
        "\n",
        "    for i in range(num_classes):\n",
        "        ell = np.sum(vecs[:, None, :] * eigen_vecs[i][None, :, :], axis=-1)\n",
        "        ell = np.abs(ell / (vec_norms[:, None] * np.linalg.norm(eigen_vecs[i], axis=-1)[None, :]))\n",
        "        ell = np.linalg.norm(np.sum((ell * (stds[i][None, :] / 2))[:, :, None] * eigen_vecs[i][None, :, :], axis=1), axis=-1) + 1e-25\n",
        "        ell = (vec_norms <= ell).reshape(reg_shape).astype(np.float32)\n",
        "\n",
        "        regions.append(ell)\n",
        "\n",
        "    return regions, means, max_len\n",
        "\n",
        "# Calculate overlaps and choose the top-k sets with highest overlaps\n",
        "def overlaps(k, classes, num_clusters, classes_dict, regions, means, max_len):\n",
        "    clusters = classes\n",
        "    overlaps = {}\n",
        "    top_sets = [set([c]) for c in clusters]\n",
        "\n",
        "    for cardinality in range(2, num_clusters + 1):\n",
        "        for ts in top_sets:\n",
        "            for clus in clusters:\n",
        "                s = ts.copy()\n",
        "                s.add(clus)\n",
        "                s = sorted(s)\n",
        "                if len(s) == cardinality and \",\".join(s) not in overlaps:\n",
        "                    region = np.zeros_like(regions[0])\n",
        "                    smallest_region = np.inf\n",
        "                    for num, name in enumerate(s):\n",
        "                        c = classes_dict[name]\n",
        "                        if num == 0:\n",
        "                            region += regions[c]\n",
        "                            reg_cen = means[c]\n",
        "                        else:\n",
        "                            top_corner = means[c] - reg_cen\n",
        "                            if any(top_corner < -max_len) or any(top_corner > max_len):\n",
        "                                pass\n",
        "                            else:\n",
        "                                limits = []\n",
        "                                start_points = []\n",
        "                                for val in top_corner:\n",
        "                                    if val < 0:\n",
        "                                        limits.append((int(abs(val)), max_len))\n",
        "                                        start_points.append(0)\n",
        "                                    else:\n",
        "                                        limits.append((0, max_len - int(val)))\n",
        "                                        start_points.append(int(val))\n",
        "\n",
        "                                eval_s = []\n",
        "                                for n1 in range(len(limits)):\n",
        "                                    eval_s.append(f\"{limits[n1][0]}:{limits[n1][1]}\")\n",
        "                                eval_s = \",\".join(eval_s)\n",
        "                                cutout = eval(f\"regions[{c}][{eval_s}]\")\n",
        "\n",
        "                                eval_s = []\n",
        "                                for n1 in range(len(start_points)):\n",
        "                                    eval_s.append(f\"{start_points[n1]}:{start_points[n1]}+{cutout.shape[n1]}\")\n",
        "                                eval_s = \",\".join(eval_s)\n",
        "                                exec(f\"region[{eval_s}] += cutout\")\n",
        "\n",
        "                        if np.sum(regions[c]) < smallest_region:\n",
        "                            smallest_region = np.sum(regions[c])\n",
        "\n",
        "                    intersection = np.sum([region == len(s)])\n",
        "                    op = intersection / np.sum(region != 0)\n",
        "                    overlaps[\",\".join(s)] = op\n",
        "\n",
        "        keys = np.array(list(overlaps.keys()))\n",
        "        values = np.array(list(overlaps.values()))\n",
        "        arg_sorted = np.argsort(values)[::-1]\n",
        "        top_sets = [set([num for num in cl.split(\",\")]) for cl in keys[arg_sorted[:k]] if len(set([num for num in cl.split(\",\")])) == cardinality]\n",
        "\n",
        "    keys = list(overlaps.keys())\n",
        "    keys = np.array([set([num for num in cl.split(\",\")]) for cl in keys])\n",
        "    values = np.array(list(overlaps.values()))\n",
        "    arg_sorted = np.argsort(values)[::-1]\n",
        "    new_k = min(k, np.sum(values[arg_sorted[:k]] != 0))\n",
        "    new_classes = [set([c]) for c in classes] + list(keys[arg_sorted[:new_k]])\n",
        "\n",
        "    return new_classes\n"
      ],
      "metadata": {
        "id": "iP4-Do3xddjK"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Apply t-SNE\n",
        "train_embedded_tsne = TSNE(\n",
        "    n_components=3, init='random', perplexity=30, random_state=42, n_jobs=-1\n",
        ").fit_transform(train_features)\n",
        "\n",
        "# Fit GMMs\n",
        "individual_gms_fine = fit_gmm(fine_classes, train_embedded_tsne, fine_train_labels)\n",
        "individual_gms_coarse = fit_gmm(coarse_classes, train_embedded_tsne, coarse_train_labels)\n",
        "\n",
        "# Calculate ellipses for Fine Labels\n",
        "regions_fine, means_fine, max_len_fine = ellipse(individual_gms_fine, len(fine_classes))\n",
        "\n",
        "# Calculate ellipses for Coarse Labels\n",
        "regions_coarse, means_coarse, max_len_coarse = ellipse(individual_gms_coarse, len(coarse_classes))\n"
      ],
      "metadata": {
        "id": "97I_Wn0OhXWd"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Fine Label Overlap and Class Selection\n",
        "classes_dict_fine = {cls: idx for idx, cls in enumerate(fine_classes)}\n",
        "num_clusters_fine = len(fine_classes)\n",
        "k_fine = 20\n",
        "new_classes_fine = overlaps(k_fine, fine_classes, num_clusters_fine, classes_dict_fine, regions_fine, means_fine, max_len_fine)\n",
        "\n",
        "# Coarse Label Overlap and Class Selection\n",
        "classes_dict_coarse = {cls: idx for idx, cls in enumerate(coarse_classes)}\n",
        "num_clusters_coarse = len(coarse_classes)\n",
        "k_coarse = 4\n",
        "new_classes_coarse = overlaps(k_coarse, coarse_classes, num_clusters_coarse, classes_dict_coarse, regions_coarse, means_coarse, max_len_coarse)\n"
      ],
      "metadata": {
        "id": "BHWDLl_Hhqaq"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# save new classes\n",
        "FOCAL_SETS_FINE = os.path.join(SAVE_DIR, 'new_classes_fine.pth')\n",
        "torch.save(new_classes_fine, FOCAL_SETS_FINE)\n",
        "\n",
        "FOCAL_SETS_COARSE = os.path.join(SAVE_DIR, 'new_classes_coarse.pth')\n",
        "torch.save(new_classes_coarse, FOCAL_SETS_COARSE)\n",
        "\n",
        "# Print the new classes\n",
        "print(\"New Fine Classes:\", new_classes_fine)\n",
        "print(\"New Coarse Classes:\", new_classes_coarse)\n",
        "\n",
        "# Print the new classes\n",
        "print(\"New Fine Classes Size:\", len(new_classes_fine))\n",
        "print(\"New Coarse Classes Size:\", len(new_classes_coarse))\n",
        "\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ZXy9yRnJhtsR",
        "outputId": "1772f45d-99c7-43d0-8cc3-9f4ec60e5c81"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "New Fine Classes: [{'apple'}, {'aquarium_fish'}, {'baby'}, {'bear'}, {'beaver'}, {'bed'}, {'bee'}, {'beetle'}, {'bicycle'}, {'bottle'}, {'bowl'}, {'boy'}, {'bridge'}, {'bus'}, {'butterfly'}, {'camel'}, {'can'}, {'castle'}, {'caterpillar'}, {'cattle'}, {'chair'}, {'chimpanzee'}, {'clock'}, {'cloud'}, {'cockroach'}, {'couch'}, {'crab'}, {'crocodile'}, {'cup'}, {'dinosaur'}, {'dolphin'}, {'elephant'}, {'flatfish'}, {'forest'}, {'fox'}, {'girl'}, {'hamster'}, {'house'}, {'kangaroo'}, {'keyboard'}, {'lamp'}, {'lawn_mower'}, {'leopard'}, {'lion'}, {'lizard'}, {'lobster'}, {'man'}, {'maple_tree'}, {'motorcycle'}, {'mountain'}, {'mouse'}, {'mushroom'}, {'oak_tree'}, {'orange'}, {'orchid'}, {'otter'}, {'palm_tree'}, {'pear'}, {'pickup_truck'}, {'pine_tree'}, {'plain'}, {'plate'}, {'poppy'}, {'porcupine'}, {'possum'}, {'rabbit'}, {'raccoon'}, {'ray'}, {'road'}, {'rocket'}, {'rose'}, {'sea'}, {'seal'}, {'shark'}, {'shrew'}, {'skunk'}, {'skyscraper'}, {'snail'}, {'snake'}, {'spider'}, {'squirrel'}, {'streetcar'}, {'sunflower'}, {'sweet_pepper'}, {'table'}, {'tank'}, {'telephone'}, {'television'}, {'tiger'}, {'tractor'}, {'train'}, {'trout'}, {'tulip'}, {'turtle'}, {'wardrobe'}, {'whale'}, {'willow_tree'}, {'wolf'}, {'woman'}, {'worm'}, {'pine_tree', 'willow_tree'}, {'seal', 'otter'}, {'crocodile', 'beaver'}, {'trout', 'lizard'}, {'pine_tree', 'maple_tree'}, {'crocodile', 'otter'}, {'woman', 'man'}, {'maple_tree', 'willow_tree'}, {'squirrel', 'porcupine'}, {'shrew', 'porcupine'}, {'squirrel', 'raccoon'}, {'beaver', 'otter'}, {'castle', 'house'}, {'tiger', 'leopard'}, {'lizard', 'turtle'}, {'trout', 'turtle'}, {'crocodile', 'seal'}, {'lobster', 'lizard'}, {'whale', 'dolphin'}, {'bear', 'elephant'}]\n",
            "New Coarse Classes: [{'aquatic_mammals'}, {'fish'}, {'flowers'}, {'food_containers'}, {'fruit_and_vegetables'}, {'household_electrical_devices'}, {'household_furniture'}, {'insects'}, {'large_carnivores'}, {'large_man-made_outdoor_things'}, {'large_natural_outdoor_scenes'}, {'large_omnivores_and_herbivores'}, {'medium_mammals'}, {'non-insect_invertebrates'}, {'people'}, {'reptiles'}, {'small_mammals'}, {'trees'}, {'vehicles_1'}, {'vehicles_2'}, {'vehicles_1', 'vehicles_2'}, {'large_carnivores', 'medium_mammals'}, {'reptiles', 'aquatic_mammals'}, {'large_carnivores', 'large_omnivores_and_herbivores'}]\n",
            "New Fine Classes Size: 120\n",
            "New Coarse Classes Size: 24\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Paths\n",
        "FOCAL_SETS_FINE = os.path.join(SAVE_DIR, 'new_classes_fine.pth')\n",
        "FOCAL_SETS_COARSE = os.path.join(SAVE_DIR, 'new_classes_coarse.pth')\n",
        "\n",
        "# Load\n",
        "new_classes_fine = torch.load(FOCAL_SETS_FINE)\n",
        "new_classes_coarse = torch.load(FOCAL_SETS_COARSE)\n",
        "\n",
        "# Print the new classes\n",
        "print(\"New Fine Classes:\", new_classes_fine)\n",
        "print(\"New Coarse Classes:\", new_classes_coarse)\n",
        "\n",
        "# Print the new classes\n",
        "print(\"New Fine Classes Size:\", len(new_classes_fine))\n",
        "print(\"New Coarse Classes Size:\", len(new_classes_coarse))\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "cpJRHBqtpYQR",
        "outputId": "9e153f33-24e4-44c4-d592-6afab075aa0b"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "New Fine Classes: [{'apple'}, {'aquarium_fish'}, {'baby'}, {'bear'}, {'beaver'}, {'bed'}, {'bee'}, {'beetle'}, {'bicycle'}, {'bottle'}, {'bowl'}, {'boy'}, {'bridge'}, {'bus'}, {'butterfly'}, {'camel'}, {'can'}, {'castle'}, {'caterpillar'}, {'cattle'}, {'chair'}, {'chimpanzee'}, {'clock'}, {'cloud'}, {'cockroach'}, {'couch'}, {'crab'}, {'crocodile'}, {'cup'}, {'dinosaur'}, {'dolphin'}, {'elephant'}, {'flatfish'}, {'forest'}, {'fox'}, {'girl'}, {'hamster'}, {'house'}, {'kangaroo'}, {'keyboard'}, {'lamp'}, {'lawn_mower'}, {'leopard'}, {'lion'}, {'lizard'}, {'lobster'}, {'man'}, {'maple_tree'}, {'motorcycle'}, {'mountain'}, {'mouse'}, {'mushroom'}, {'oak_tree'}, {'orange'}, {'orchid'}, {'otter'}, {'palm_tree'}, {'pear'}, {'pickup_truck'}, {'pine_tree'}, {'plain'}, {'plate'}, {'poppy'}, {'porcupine'}, {'possum'}, {'rabbit'}, {'raccoon'}, {'ray'}, {'road'}, {'rocket'}, {'rose'}, {'sea'}, {'seal'}, {'shark'}, {'shrew'}, {'skunk'}, {'skyscraper'}, {'snail'}, {'snake'}, {'spider'}, {'squirrel'}, {'streetcar'}, {'sunflower'}, {'sweet_pepper'}, {'table'}, {'tank'}, {'telephone'}, {'television'}, {'tiger'}, {'tractor'}, {'train'}, {'trout'}, {'tulip'}, {'turtle'}, {'wardrobe'}, {'whale'}, {'willow_tree'}, {'wolf'}, {'woman'}, {'worm'}, {'willow_tree', 'pine_tree'}, {'seal', 'otter'}, {'crocodile', 'beaver'}, {'trout', 'lizard'}, {'maple_tree', 'pine_tree'}, {'crocodile', 'otter'}, {'woman', 'man'}, {'willow_tree', 'maple_tree'}, {'porcupine', 'squirrel'}, {'shrew', 'porcupine'}, {'raccoon', 'squirrel'}, {'beaver', 'otter'}, {'house', 'castle'}, {'tiger', 'leopard'}, {'lizard', 'turtle'}, {'trout', 'turtle'}, {'crocodile', 'seal'}, {'lizard', 'lobster'}, {'whale', 'dolphin'}, {'elephant', 'bear'}]\n",
            "New Coarse Classes: [{'aquatic_mammals'}, {'fish'}, {'flowers'}, {'food_containers'}, {'fruit_and_vegetables'}, {'household_electrical_devices'}, {'household_furniture'}, {'insects'}, {'large_carnivores'}, {'large_man-made_outdoor_things'}, {'large_natural_outdoor_scenes'}, {'large_omnivores_and_herbivores'}, {'medium_mammals'}, {'non-insect_invertebrates'}, {'people'}, {'reptiles'}, {'small_mammals'}, {'trees'}, {'vehicles_1'}, {'vehicles_2'}, {'vehicles_1', 'vehicles_2'}, {'large_carnivores', 'medium_mammals'}, {'reptiles', 'aquatic_mammals'}, {'large_omnivores_and_herbivores', 'large_carnivores'}]\n",
            "New Fine Classes Size: 120\n",
            "New Coarse Classes Size: 24\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## METRICS"
      ],
      "metadata": {
        "id": "GxGVUknPdXrX"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Function to calculate logical consistency\n",
        "def calculate_logical_consistency_base(fine_preds, coarse_preds):\n",
        "    correct_logical = 0\n",
        "    for fine_pred, coarse_pred in zip(fine_preds, coarse_preds):\n",
        "        if fine_to_coarse[fine_pred.item()] == coarse_pred.item():\n",
        "            correct_logical += 1\n",
        "    return correct_logical\n"
      ],
      "metadata": {
        "id": "SK2e0up6UARJ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eAMq7a90-8Sz"
      },
      "outputs": [],
      "source": [
        "# Calculate logical consistency for multihot labels\n",
        "def calculate_logical_consistency(pred_fine, pred_coarse):\n",
        "    mapped = torch.tensor([fine_to_coarse[int(i)] for i in pred_fine.tolist()])\n",
        "    return (mapped == pred_coarse).float().mean().item()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a_0V0tb7Oo8E"
      },
      "outputs": [],
      "source": [
        "\n",
        "# Computes the Expected Calibration Error (ECE) for a given set of predictions and true labels\n",
        "def compute_ece_pytorch(probabilities, labels, num_bins=10):\n",
        "\n",
        "    # Get the max confidence and the predicted class from the probabilities\n",
        "    confidences, predictions = torch.max(probabilities, 1)\n",
        "\n",
        "    # Initialize bin boundaries\n",
        "    bins = torch.linspace(0, 1, num_bins + 1, device=probabilities.device)  # Bins from 0 to 1\n",
        "    bin_lowers = bins[:-1]  # Lower boundaries of the bins\n",
        "    bin_uppers = bins[1:]   # Upper boundaries of the bins\n",
        "\n",
        "    # Initialize accumulators for accuracy, confidence, and bin counts\n",
        "    bin_accuracy = torch.zeros(num_bins, device=probabilities.device)\n",
        "    bin_confidence = torch.zeros(num_bins, device=probabilities.device)\n",
        "    bin_counts = torch.zeros(num_bins, device=probabilities.device)\n",
        "\n",
        "    # Compute accuracy (1 if correct, 0 if incorrect)\n",
        "    accuracies = predictions.eq(labels).float()\n",
        "\n",
        "    # Populate bin-wise accuracy, confidence, and counts\n",
        "    for i in range(num_bins):\n",
        "        if i == 0:\n",
        "            bin_mask = (confidences >= bin_lowers[i]) & (confidences <= bin_uppers[i])\n",
        "        else:\n",
        "            bin_mask = (confidences > bin_lowers[i]) & (confidences <= bin_uppers[i])\n",
        "\n",
        "        bin_count = bin_mask.sum().float()\n",
        "\n",
        "        if bin_count > 0:\n",
        "            bin_accuracy[i] = accuracies[bin_mask].mean()  # Mean accuracy in the bin\n",
        "            bin_confidence[i] = confidences[bin_mask].mean()  # Mean confidence in the bin\n",
        "            bin_counts[i] = bin_count\n",
        "\n",
        "    # Normalize bin counts\n",
        "    total_count = confidences.size(0)\n",
        "    weights = bin_counts / total_count\n",
        "\n",
        "    # Compute ECE as the weighted sum of the absolute difference between accuracy and confidence\n",
        "    ece = torch.sum(torch.abs(bin_confidence - bin_accuracy) * weights)\n",
        "\n",
        "    return ece.item()\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "def compute_entropy(probs: torch.Tensor) -> float:\n",
        "    \"\"\"\n",
        "    Average Shannon entropy in nats,\n",
        "    probs: [N, C]\n",
        "    \"\"\"\n",
        "    if probs.numel() == 0:\n",
        "        return float(\"nan\")\n",
        "\n",
        "    # Clean & (re)normalize per row\n",
        "    probs = torch.nan_to_num(probs, nan=0.0, posinf=1.0, neginf=0.0)\n",
        "    row_sum = probs.sum(dim=1, keepdim=True).clamp_min(1e-12)\n",
        "    probs = probs / row_sum\n",
        "\n",
        "    # Standard entropy with clamp\n",
        "    p = probs.clamp(min=1e-12, max=1.0)\n",
        "    ent = -torch.sum(p * torch.log(p), dim=1)  # [N]\n",
        "    # If a row were degenerate, ent could still carry NaN; clean once more\n",
        "    ent = torch.nan_to_num(ent, nan=0.0, posinf=0.0, neginf=0.0)\n",
        "    return ent.mean().item()\n"
      ],
      "metadata": {
        "id": "BGAI_KEihEk1"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "T3y3C_NMqCys"
      },
      "outputs": [],
      "source": [
        "from sklearn.metrics import precision_score, recall_score, f1_score\n",
        "# Compute recall, precision and recall for multilabels\n",
        "def compute_multilabel_metrics(preds, labels, threshold=0.5):\n",
        "    preds = preds.detach().cpu()\n",
        "    labels = labels.detach().cpu()\n",
        "\n",
        "    binary_preds = (preds > threshold).float().numpy()\n",
        "    binary_labels = labels.float().numpy()\n",
        "\n",
        "    precision = precision_score(binary_labels, binary_preds, average='macro', zero_division=0)\n",
        "    recall = recall_score(binary_labels, binary_preds, average='macro', zero_division=0)\n",
        "    f1 = f1_score(binary_labels, binary_preds, average='macro', zero_division=0)\n",
        "\n",
        "    return precision, recall, f1\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "n1tUU9lJq9Ev"
      },
      "outputs": [],
      "source": [
        "# Compute recall, precision and recall for single label\n",
        "def compute_singlelabel_metrics(y_true, y_pred):\n",
        "    precision = precision_score(y_true, y_pred, average='macro', zero_division=0)\n",
        "    recall = recall_score(y_true, y_pred, average='macro', zero_division=0)\n",
        "    f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)\n",
        "    return precision, recall, f1\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## HELPER FUNCTIONS"
      ],
      "metadata": {
        "id": "6UaJEO0OYRGw"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# load trained configs\n",
        "def load_trained_configs(results_list):\n",
        "    \"\"\"Collect {config_key, best_model_path, best_val_loss} entries that exist on disk.\"\"\"\n",
        "    configs = []\n",
        "    for r in results_list:\n",
        "        path = r.get('best_model_path')\n",
        "        key  = r.get('config_key', 'unknown_config')\n",
        "        bvl  = r.get('best_val_loss', float('nan'))\n",
        "        if isinstance(path, str) and os.path.exists(path):\n",
        "            configs.append({'config_key': key, 'best_model_path': path, 'best_val_loss': bvl})\n",
        "        else:\n",
        "            print(f\"[WARN] Skipping {key}: best_model_path missing or not found -> {path}\")\n",
        "    return configs\n",
        "\n",
        "# load state dictionary of a saved model\n",
        "def load_state_dict_robust(model, ckpt_path, device):\n",
        "    \"\"\"Load either a checkpoint dict with 'model_state_dict' or a raw state_dict.\"\"\"\n",
        "    obj = torch.load(ckpt_path, map_location=device)\n",
        "    if isinstance(obj, dict) and 'model_state_dict' in obj:\n",
        "        model.load_state_dict(obj['model_state_dict'])\n",
        "    else:\n",
        "        # raw state_dict\n",
        "        model.load_state_dict(obj)\n",
        "    return obj  # return for possible\n",
        "\n",
        "def safe_div(a, b):\n",
        "    return (a / b) if b else 0.0\n",
        "\n",
        "# Helper functions for per-config paths\n",
        "def get_partial_checkpoint_path(config_key):\n",
        "    return os.path.join(FOCAL_DIR, f\"partial_checkpoint_{config_key}.pth\")\n",
        "\n",
        "def get_best_model_path(config_key):\n",
        "    return os.path.join(FOCAL_DIR, f\"best_model_{config_key}.pth\")\n"
      ],
      "metadata": {
        "id": "6M2-0RaeYP4z"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Subdir for this experiment\n",
        "FOCAL_DIR = os.path.join(SAVE_DIR, \"focal\")\n",
        "os.makedirs(FOCAL_DIR, exist_ok=True)\n",
        "\n",
        "# File locations\n",
        "RESULTS_FILE        = os.path.join(FOCAL_DIR, \"all_training_results.pkl\")\n",
        "BEST_MODEL_METADATA = os.path.join(FOCAL_DIR, \"best_model_metadata.pkl\")\n",
        "\n",
        "\n",
        "def get_checkpoint_file(config_key):\n",
        "    \"\"\"Generate a unique checkpoint path for each config.\"\"\"\n",
        "    return os.path.join(FOCAL_DIR, f\"checkpoint_{config_key}.pt\")\n",
        "\n",
        "# Save the traininhg and validation results\n",
        "def save_checkpoint(epoch, model, criterion, optimizer, scheduler, scaler,\n",
        "                    results, config_key, best_val_loss, early_stopping_counter):\n",
        "    state = {\n",
        "        \"epoch\": epoch,\n",
        "        \"model_state_dict\": model.state_dict(),\n",
        "        \"criterion_state_dict\": criterion.state_dict(),\n",
        "        \"optimizer_state_dict\": optimizer.state_dict(),\n",
        "        \"scheduler_state_dict\": scheduler.state_dict(),\n",
        "        \"scaler_state_dict\": scaler.state_dict(),\n",
        "        \"results\": results,\n",
        "        \"config_key\": config_key,\n",
        "        \"best_val_loss\": best_val_loss,\n",
        "        \"early_stopping_counter\": early_stopping_counter\n",
        "    }\n",
        "    ckpt_file = get_checkpoint_file(config_key)\n",
        "    torch.save(state, ckpt_file)\n",
        "    print(f\"[{config_key}] Checkpoint saved at epoch {epoch+1}, val_loss={best_val_loss:.4f}\")\n",
        "\n",
        "\n",
        "def load_checkpoint(model, criterion, optimizer, scheduler, scaler, device, config_key):\n",
        "    ckpt_file = get_checkpoint_file(config_key)\n",
        "    if os.path.exists(ckpt_file):\n",
        "        ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)\n",
        "        model.load_state_dict(ckpt[\"model_state_dict\"])\n",
        "        criterion.load_state_dict(ckpt[\"criterion_state_dict\"])\n",
        "        optimizer.load_state_dict(ckpt[\"optimizer_state_dict\"])\n",
        "        scheduler.load_state_dict(ckpt[\"scheduler_state_dict\"])\n",
        "        scaler.load_state_dict(ckpt[\"scaler_state_dict\"])\n",
        "        print(f\"[{config_key}] Resumed from epoch {ckpt['epoch']} with best_val_loss={ckpt['best_val_loss']:.4f}\")\n",
        "        return (ckpt[\"epoch\"] + 1,\n",
        "                ckpt[\"results\"],\n",
        "                ckpt[\"best_val_loss\"],\n",
        "                ckpt[\"config_key\"],\n",
        "                ckpt.get(\"early_stopping_counter\", 0))\n",
        "    else:\n",
        "        return 0, {\"train\": [], \"val\": []}, float(\"inf\"), None, 0\n",
        "\n"
      ],
      "metadata": {
        "id": "pHZEKtnXxN6e"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## ENCODE GROUND TRUTHS WITH NEW CLASSES AND CREATE NEW DATALOADERS"
      ],
      "metadata": {
        "id": "CiIz3FD3pJHG"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Modifying the ground truth with belief encoding\n",
        "def groundtruthmod(y, classes, new_classes, dict):\n",
        "    y_encoded = np.zeros((len(y), len(new_classes)), dtype=int)\n",
        "    for i, label in enumerate(y):\n",
        "        for j, class_ in enumerate(new_classes):\n",
        "            if class_.issubset(set(classes)) and dict[label] in class_:\n",
        "                y_encoded[i, j] = 1\n",
        "    return y_encoded\n"
      ],
      "metadata": {
        "id": "i6G9TvBepPPG"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Reverse lookups: integer label index -> class name, for fine and coarse levels\n",
        "fine_dict_inverse = {idx: name for idx, name in enumerate(fine_classes)}\n",
        "coarse_dict_inverse = {idx: name for idx, name in enumerate(coarse_classes)}\n"
      ],
      "metadata": {
        "id": "iBaYBjURpRxD"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Encode fine labels (train / val / test) as multi-hot vectors over the fine focal sets\n",
        "fine_labels_train_encoded, fine_labels_val_encoded, fine_labels_test_encoded = (\n",
        "    groundtruthmod(split_labels, fine_classes, new_classes_fine, fine_dict_inverse)\n",
        "    for split_labels in (fine_labels_train, fine_labels_val, fine_labels_test)\n",
        ")\n",
        "\n",
        "# Sanity check: one row per sample, one column per fine focal set\n",
        "print(f\"Fine labels train encoded shape: {fine_labels_train_encoded.shape}\")\n",
        "print(f\"Fine labels val encoded shape: {fine_labels_val_encoded.shape}\")\n",
        "print(f\"Fine labels test encoded shape: {fine_labels_test_encoded.shape}\")\n",
        "\n",
        "# Same encoding for the coarse level of the hierarchy\n",
        "coarse_labels_train_encoded, coarse_labels_val_encoded, coarse_labels_test_encoded = (\n",
        "    groundtruthmod(split_labels, coarse_classes, new_classes_coarse, coarse_dict_inverse)\n",
        "    for split_labels in (coarse_labels_train, coarse_labels_val, coarse_labels_test)\n",
        ")\n",
        "\n",
        "# Sanity check: one row per sample, one column per coarse focal set\n",
        "print(f\"Coarse labels train encoded shape: {coarse_labels_train_encoded.shape}\")\n",
        "print(f\"Coarse labels val encoded shape: {coarse_labels_val_encoded.shape}\")\n",
        "print(f\"Coarse labels test encoded shape: {coarse_labels_test_encoded.shape}\")\n"
      ],
      "metadata": {
        "id": "wSidSLv3pU5N",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "0b7d48a0-b94a-4f6e-f30e-17d3ab2334cb"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Fine labels train encoded shape: (40000, 120)\n",
            "Fine labels val encoded shape: (10000, 120)\n",
            "Fine labels test encoded shape: (10000, 120)\n",
            "Coarse labels train encoded shape: (40000, 24)\n",
            "Coarse labels val encoded shape: (10000, 24)\n",
            "Coarse labels test encoded shape: (10000, 24)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Training split: belief-encoded targets only (no true-index arrays passed)\n",
        "train_dataset_new = CIFAR100WithBeliefEncoding(\n",
        "    fine_labels_train_encoded,\n",
        "    coarse_labels_train_encoded,\n",
        "    train_dataset_part,\n",
        "    transform=train_transform,\n",
        ")\n",
        "\n",
        "# Val/test splits also carry the true singleton class indices, which\n",
        "# validate_one_epoch requires to score BetP predictions\n",
        "val_dataset_new = CIFAR100WithBeliefEncoding(\n",
        "    fine_labels_val_encoded,\n",
        "    coarse_labels_val_encoded,\n",
        "    val_dataset_part,\n",
        "    transform=val_test_transform,\n",
        "    fine_true_idx=fine_labels_val,\n",
        "    coarse_true_idx=coarse_labels_val,\n",
        ")\n",
        "test_dataset_new = CIFAR100WithBeliefEncoding(\n",
        "    fine_labels_test_encoded,\n",
        "    coarse_labels_test_encoded,\n",
        "    test_dataset_original,\n",
        "    transform=val_test_transform,\n",
        "    fine_true_idx=fine_labels_test,\n",
        "    coarse_true_idx=coarse_labels_test,\n",
        ")\n",
        "\n",
        "# Loaders: loader_kwargs_* (defined elsewhere in the notebook) carry the reproducibility settings\n",
        "train_loader_new, val_loader_new, test_loader_new = (\n",
        "    DataLoader(train_dataset_new, **loader_kwargs_train),\n",
        "    DataLoader(val_dataset_new, **loader_kwargs_eval),\n",
        "    DataLoader(test_dataset_new, **loader_kwargs_eval),\n",
        ")\n"
      ],
      "metadata": {
        "id": "TX0bcBwLbbTB"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xBiaQZEBPb6J"
      },
      "source": [
        "## BELIEF - MASS - PIGNISTIC PROBABILITY"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kxrAn3KYh309"
      },
      "outputs": [],
      "source": [
        "# Generate mass coefficient matrix\n",
        "def mass_coeff(new_classes):\n",
        "    mass_co = np.zeros((len(new_classes), len(new_classes)))\n",
        "\n",
        "    for i, A in enumerate(new_classes):\n",
        "        for j, B in enumerate(new_classes):\n",
        "            leng = 0\n",
        "            if set(B).issubset(set(A)):\n",
        "                leng = (-1) ** (len(A) - len(B))\n",
        "            mass_co[j][i] = leng\n",
        "    return mass_co\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "def build_mass_coeff_tensor(new_classes, device):\n",
        "    mat = mass_coeff(new_classes)\n",
        "    return torch.tensor(mat, dtype=torch.float32, device=device)\n",
        "\n",
        "def build_betp_matrix(classes, new_classes, device):\n",
        "    # new_classes **without** Ω here; we append Ω as the last row/col logically\n",
        "    all_sets = list(new_classes) + [set(classes)]  # append Ω\n",
        "    M = torch.zeros((len(all_sets), len(classes)), dtype=torch.float32, device=device)\n",
        "    for i, c in enumerate(classes):\n",
        "        sc = {c}\n",
        "        for j, A in enumerate(all_sets):\n",
        "            if sc.issubset(A):\n",
        "                M[j, i] = 1.0 / len(A)\n",
        "    return M\n",
        "\n",
        "# Belief to mass conversion\n",
        "def belief_to_mass(pred_bel, mass_coeff_tensor):\n",
        "    # Ensure same device & dtype as predictions\n",
        "    M = mass_coeff_tensor.to(device=pred_bel.device, dtype=pred_bel.dtype)\n",
        "    m = pred_bel @ M\n",
        "    m = torch.clamp(m, min=0.0)\n",
        "    rest = (1.0 - m.sum(dim=-1)).clamp_min(0.0)\n",
        "    m = torch.cat([m, rest.unsqueeze(-1)], dim=-1)\n",
        "    m = m / m.sum(dim=-1, keepdim=True)\n",
        "    return m\n",
        "\n",
        "def final_betp(mass_with_omega, betp_matrix):\n",
        "    P = betp_matrix.to(device=mass_with_omega.device, dtype=mass_with_omega.dtype)\n",
        "    return mass_with_omega @ P\n",
        "\n"
      ],
      "metadata": {
        "id": "oq7lpyo77yYD"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w5nvvJW2Qb5Y"
      },
      "source": [
        "## CALCULATE MEMBERSHIP VALUE WITH T-NORMS"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Convert string-based fine_classes/coarse_classes to index-based\n",
        "fine_class_name_to_index = {name: idx for idx, name in enumerate(fine_classes)}\n",
        "coarse_class_name_to_index = {name: idx for idx, name in enumerate(coarse_classes)}\n",
        "\n",
        "def _names_to_index_sets(focal_sets, name_to_index):\n",
        "    \"\"\"Translate each focal set of class names into the set of their integer indices.\"\"\"\n",
        "    return [{name_to_index[name] for name in s} for s in focal_sets]\n",
        "\n",
        "# Index-based focal sets, in the same order as the name-based ones\n",
        "new_classes_fine_idx = _names_to_index_sets(new_classes_fine, fine_class_name_to_index)\n",
        "new_classes_coarse_idx = _names_to_index_sets(new_classes_coarse, coarse_class_name_to_index)\n"
      ],
      "metadata": {
        "id": "vNT3_XjpeM48"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PwjrhJAsJfXo"
      },
      "outputs": [],
      "source": [
        "# Triangular membership function\n",
        "def triangular_membership(x, a=0.5, b=1.0, sigma=None):\n",
        "    mid = (a + b) * 0.5\n",
        "    half = (b - a) * 0.5 + 1e-8\n",
        "    return (1 - ((x - mid).abs() / half)).clamp_min(0.0)\n",
        "\n",
        "# Trapezoidal membership function\n",
        "def trapezoidal_membership(x, a=0.0, b=0.3, c=0.7, d=1.0, sigma=None):\n",
        "    # rising, plateau, falling\n",
        "    rise   = (x - a) / (b - a + 1e-8)\n",
        "    fall   = (d - x) / (d - c + 1e-8)\n",
        "    plateau= x.new_ones(())\n",
        "    return torch.clamp(torch.min(torch.min(rise, plateau), fall), min=0.0)\n",
        "\n",
        "# Gaussian membership function for fuzzy logic\n",
        "def gaussian_membership(x, mean=1.0, sigma=1.0):\n",
        "    return torch.exp(-((x - mean) ** 2) / (2 * sigma ** 2))\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ViHktSpN_ovg"
      },
      "outputs": [],
      "source": [
        "# T-norms for logical conjunction\n",
        "def product_t_norm(x, y):\n",
        "    return x * y\n",
        "\n",
        "def godel_t_norm(x, y):\n",
        "    return torch.min(x, y)\n",
        "\n",
        "def lukasiewicz_t_norm(x, y):\n",
        "    return (x + y - 1).clamp_min(0.0)\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Build structures for logical consistency calculation\n",
        "def build_consistency_structures(new_classes_fine_idx, new_classes_coarse_idx, fine_to_coarse, device):\n",
        "    F, C = len(new_classes_fine_idx), len(new_classes_coarse_idx)\n",
        "    M_fc   = torch.zeros((F, C), dtype=torch.float32, device=device)\n",
        "    Kappa  = torch.zeros((F, C), dtype=torch.float32, device=device)\n",
        "    w_f    = torch.tensor([1.0/len(s) for s in new_classes_fine_idx],   dtype=torch.float32, device=device)\n",
        "    w_c    = torch.tensor([1.0/len(s) for s in new_classes_coarse_idx], dtype=torch.float32, device=device)\n",
        "\n",
        "    # normalize weights to mean≈1\n",
        "    w_f /= (w_f.mean() + 1e-12)\n",
        "    w_c /= (w_c.mean() + 1e-12)\n",
        "\n",
        "    # fine set A -> projected coarse set C(A)\n",
        "    projected_list = []\n",
        "    for A in new_classes_fine_idx:\n",
        "        projected = { fine_to_coarse[i] for i in A }  # set of coarse indices\n",
        "        projected_list.append(projected)\n",
        "\n",
        "    # fill M_fc and Kappa\n",
        "    for i, CA in enumerate(projected_list):\n",
        "        sizeA = max(1, len(CA))\n",
        "        for j, B in enumerate(new_classes_coarse_idx):\n",
        "            inter = len(CA & B)\n",
        "            if inter == 0:\n",
        "                continue\n",
        "            M_fc[i, j] = 1.0\n",
        "            # softness: “how much of CA is inside B”\n",
        "            Kappa[i, j] = inter / sizeA   # ∈ (0,1], equals 1 if CA⊆B\n",
        "\n",
        "    return M_fc, Kappa, w_f, w_c\n"
      ],
      "metadata": {
        "id": "kI_oI7W2uQRH"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# consistency term\n",
        "def belief_consistency_weighted(\n",
        "    mass_fine,        # [B, F] masses on fine focal sets (no Ω)\n",
        "    mass_coarse,      # [B, C] masses on coarse focal sets (no Ω)\n",
        "    M_fc,             # [F, C] 0/1 feasibility mask\n",
        "    Kappa,            # [F, C] compatibility in [0,1]\n",
        "    w_f,              # [F] fine specificity weights (mean≈1)\n",
        "    w_c,              # [C] coarse precision weights (mean≈1)\n",
        "    t_norm,           # e.g. product/gödel/Łukasiewicz\n",
        "    membership_fn,    # fuzzifier for coarse masses, signature: fn(x, sigma=?)\n",
        "    sigma=1.0,\n",
        "):\n",
        "    # Clamp only for fuzzy logic (do NOT clamp for the penalties)\n",
        "    mF = mass_fine.clamp(0.0, 1.0)      # [B, F]\n",
        "    mC = mass_coarse.clamp(0.0, 1.0)    # [B, C]\n",
        "\n",
        "    mu_coarse = membership_fn(mC, sigma=sigma)  # [B, C]\n",
        "\n",
        "    # Broadcast to [B, F, C]\n",
        "    mF  = mF.unsqueeze(2)               # [B, F, 1]\n",
        "    muC = mu_coarse.unsqueeze(1)        # [B, 1, C]\n",
        "\n",
        "    MC  = M_fc.unsqueeze(0)             # [1, F, C]\n",
        "    KAP = Kappa.unsqueeze(0)            # [1, F, C]\n",
        "\n",
        "    combined = t_norm(mF, muC) * MC * KAP\n",
        "\n",
        "    WF = w_f.view(1, -1, 1)\n",
        "    WC = w_c.view(1, 1, -1)\n",
        "    combined = combined * WF * WC\n",
        "\n",
        "    valid_pairs = (M_fc * Kappa).sum().clamp_min(1.0)  # scalar normalizer\n",
        "    per_sample  = combined.sum(dim=(1, 2)) / valid_pairs\n",
        "    return per_sample.mean()   # scalar consistency in [0,1]\n"
      ],
      "metadata": {
        "id": "egeFEc3jpO-U"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## LOSS FUNCTION"
      ],
      "metadata": {
        "id": "679TAqPhKAD2"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class LearnableBeliefLoss_bce(nn.Module):\n",
        "    \"\"\"BCE on focal-set logits plus belief-function regularizers with learned weights.\n",
        "\n",
        "    total = BCE_fine + BCE_coarse\n",
        "            + alpha * (negative-mass penalty)\n",
        "            + beta  * (mass-sum > 1 penalty)\n",
        "            + gamma * (1 - fine/coarse fuzzy consistency)\n",
        "    where alpha, beta, gamma = exp(log_*) clamped to [1e-5, 1], [1e-5, 1], [1e-5, 10].\n",
        "\n",
        "    NOTE(review): nothing keeps alpha/beta/gamma away from their lower clamp. If log_*\n",
        "    are optimized together with the model, minimizing the total pushes each weight down\n",
        "    whenever its penalty is positive. Confirm this is intended.\n",
        "    \"\"\"\n",
        "    def __init__(self, mass_coeff_matrix_fine, mass_coeff_matrix_coarse,\n",
        "                 M_fc, Kappa, w_f, w_c, t_norm_function, membership_fn,\n",
        "                 sigma=1.0, epsilon=1e-8):\n",
        "        super().__init__()\n",
        "        # Learned log-weights; exp(-2.3) ≈ 0.1 initial weight each\n",
        "        self.log_alpha = nn.Parameter(torch.tensor(-2.3))\n",
        "        self.log_beta  = nn.Parameter(torch.tensor(-2.3))\n",
        "        self.log_gamma = nn.Parameter(torch.tensor(-2.3))\n",
        "\n",
        "        # Möbius (belief -> mass) coefficient matrices. They are plain attributes, so they\n",
        "        # are not in state_dict; forward moves them to the logits' device.\n",
        "        self.mass_coeff_matrix_fine   = mass_coeff_matrix_fine\n",
        "        self.mass_coeff_matrix_coarse = mass_coeff_matrix_coarse\n",
        "        self.t_norm_function = t_norm_function\n",
        "        self.membership_fn   = membership_fn\n",
        "        self.sigma = sigma\n",
        "        self.epsilon = epsilon  # stored for compatibility; not used in the loss\n",
        "\n",
        "        # Consistency structures as buffers: they follow .to(device) and are in state_dict\n",
        "        self.register_buffer(\"M_fc\", M_fc)\n",
        "        self.register_buffer(\"Kappa\", Kappa)\n",
        "        self.register_buffer(\"w_f\", w_f)\n",
        "        self.register_buffer(\"w_c\", w_c)\n",
        "\n",
        "    def forward(self, fine_logits, fine_labels, coarse_logits, coarse_labels):\n",
        "        \"\"\"Return (total_loss, aux), where aux holds per-term floats for logging.\n",
        "\n",
        "        Runs in float32 with autocast disabled. Under mixed precision the logits can be\n",
        "        fp16, and the Möbius transform sums sigmoid terms with ±1 coefficients. That\n",
        "        cancellation loses most of its precision in half precision.\n",
        "        \"\"\"\n",
        "        with torch.autocast(device_type=fine_logits.device.type, enabled=False):\n",
        "            return self._forward_fp32(fine_logits.float(), fine_labels,\n",
        "                                      coarse_logits.float(), coarse_labels)\n",
        "\n",
        "    def _forward_fp32(self, fine_logits, fine_labels, coarse_logits, coarse_labels):\n",
        "        # Loss computation proper; expects float32 logits.\n",
        "        # 1) BCE on logits\n",
        "        bce_fine   = F.binary_cross_entropy_with_logits(fine_logits,   fine_labels.float())\n",
        "        bce_coarse = F.binary_cross_entropy_with_logits(coarse_logits, coarse_labels.float())\n",
        "\n",
        "        # 2) Möbius inverse → focal-set masses (NO Ω here)\n",
        "        fine_probs   = torch.sigmoid(fine_logits)\n",
        "        coarse_probs = torch.sigmoid(coarse_logits)\n",
        "\n",
        "        M_f = self.mass_coeff_matrix_fine.to(device=fine_probs.device, dtype=fine_probs.dtype)\n",
        "        M_c = self.mass_coeff_matrix_coarse.to(device=coarse_probs.device, dtype=coarse_probs.dtype)\n",
        "\n",
        "        mass_fine   = fine_probs   @ M_f     # [B, F]\n",
        "        mass_coarse = coarse_probs @ M_c     # [B, C]\n",
        "\n",
        "        # 3) Mass-validity penalties on the unclamped masses\n",
        "        #    negatives penalty\n",
        "        mass_reg_fine   = F.relu(-mass_fine).mean()\n",
        "        mass_reg_coarse = F.relu(-mass_coarse).mean()\n",
        "        #    sum<=1 (per-sample, then mean)\n",
        "        mass_sum_fine   = F.relu(mass_fine.sum(dim=-1)   - 1).mean()\n",
        "        mass_sum_coarse = F.relu(mass_coarse.sum(dim=-1) - 1).mean()\n",
        "\n",
        "        # 4) Fuzzy logical consistency (internally clamps for fuzzy logic only)\n",
        "        consistency = belief_consistency_weighted(\n",
        "            mass_fine, mass_coarse,\n",
        "            self.M_fc, self.Kappa, self.w_f, self.w_c,\n",
        "            self.t_norm_function, self.membership_fn, self.sigma\n",
        "        )\n",
        "\n",
        "        # 5) Learned positive weights\n",
        "        alpha = torch.exp(self.log_alpha).clamp(1e-5, 1.0)\n",
        "        beta  = torch.exp(self.log_beta ).clamp(1e-5, 1.0)\n",
        "        gamma = torch.exp(self.log_gamma).clamp(1e-5, 10.0)\n",
        "\n",
        "        mass_reg = mass_reg_fine + mass_reg_coarse\n",
        "        mass_sum = mass_sum_fine + mass_sum_coarse\n",
        "        consistency_penalty = gamma * (1 - consistency)\n",
        "        total_loss = bce_fine + bce_coarse + alpha * mass_reg + beta * mass_sum + consistency_penalty\n",
        "\n",
        "        aux = {\n",
        "            \"bce_fine\": bce_fine.item(), \"bce_coarse\": bce_coarse.item(),\n",
        "            \"mass_reg\": mass_reg.item(),\n",
        "            \"mass_sum\": mass_sum.item(),\n",
        "            \"consistency_penalty\": consistency_penalty.item(),\n",
        "            \"alpha\": alpha.item(), \"beta\": beta.item(), \"gamma\": gamma.item()\n",
        "        }\n",
        "        return total_loss, aux\n"
      ],
      "metadata": {
        "id": "8ISgIFAuJ_Ry"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## TRAIN AND VALIDATION - NEUROSYMBOLIC EPISTEMIC AI"
      ],
      "metadata": {
        "id": "0p-vqwlqtLv6"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Möbius (belief -> mass) coefficient matrices for the name-based focal sets,\n",
        "# built as float32 tensors directly on `device`\n",
        "mass_coeff_matrix_fine = build_mass_coeff_tensor(new_classes_fine, device)\n",
        "mass_coeff_matrix_coarse = build_mass_coeff_tensor(new_classes_coarse, device)\n"
      ],
      "metadata": {
        "id": "iPH_2qvrNnrd"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Index-based Möbius (belief -> mass) and pignistic (mass+Ω -> BetP) matrices,\n",
        "# used to turn focal-set beliefs into singleton-class predictions at validation\n",
        "M_fine, P_fine = (\n",
        "    build_mass_coeff_tensor(new_classes_fine_idx, device),\n",
        "    build_betp_matrix(range(len(fine_classes)), new_classes_fine_idx, device),\n",
        ")\n",
        "M_coarse, P_coarse = (\n",
        "    build_mass_coeff_tensor(new_classes_coarse_idx, device),\n",
        "    build_betp_matrix(range(len(coarse_classes)), new_classes_coarse_idx, device),\n",
        ")\n"
      ],
      "metadata": {
        "id": "_h-cCqgv96Gh"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### TRAINING AND VALIDATION STEPS"
      ],
      "metadata": {
        "id": "v44ThzrmvaYz"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "scaler = torch.cuda.amp.GradScaler()\n",
        "\n",
        "def train_one_epoch(model, criterion, optimizer, loader, device, epoch, grad_clip=None):\n",
        "    model.train()\n",
        "    running_loss, n_batches = 0.0, 0\n",
        "\n",
        "    for inputs, fine_labels, coarse_labels in loader:\n",
        "        inputs      = inputs.to(device, non_blocking=True)\n",
        "        fine_labels = fine_labels.to(device, non_blocking=True)\n",
        "        coarse_labels = coarse_labels.to(device, non_blocking=True)\n",
        "\n",
        "        optimizer.zero_grad(set_to_none=True)\n",
        "        with torch.amp.autocast('cuda', enabled=(device.type == \"cuda\")):\n",
        "            fine_logits, coarse_logits = model(inputs)\n",
        "            loss, _ = criterion(fine_logits, fine_labels, coarse_logits, coarse_labels)\n",
        "\n",
        "        scaler.scale(loss).backward()\n",
        "\n",
        "        # optional gradient clipping\n",
        "        if grad_clip is not None:\n",
        "            scaler.unscale_(optimizer)\n",
        "            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)\n",
        "\n",
        "        scaler.step(optimizer)\n",
        "        scaler.update()\n",
        "\n",
        "        running_loss += float(loss.detach())\n",
        "        n_batches += 1\n",
        "\n",
        "    return running_loss / max(1, n_batches)\n"
      ],
      "metadata": {
        "id": "OEeHCIxkvYVr",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "9364f1c0-8be8-4b8a-a5c0-ed1fbbe1ead2"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/tmp/ipython-input-3620454188.py:1: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.\n",
            "  scaler = torch.cuda.amp.GradScaler()\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "@torch.no_grad()\n",
        "def validate_one_epoch(model, criterion, loader, device,\n",
        "                       M_fine, P_fine, M_coarse, P_coarse):\n",
        "    \"\"\"\n",
        "    Evaluate `model` on `loader` and return a dict of validation metrics.\n",
        "\n",
        "    Expects val loader to yield 5-tuples:\n",
        "      (image, fine_belief, coarse_belief, fine_true_idx, coarse_true_idx)\n",
        "    Predictions are the argmax of BetP over singleton classes. BetP is computed from\n",
        "    the sigmoid focal-set beliefs via belief_to_mass (M_*) and final_betp (P_*).\n",
        "    \"loss\" is the sample-weighted mean of the per-batch criterion values.\n",
        "\n",
        "    Raises RuntimeError if a batch is not a 5-tuple or the loader is empty.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    loss_sum, n_samples = 0.0, 0\n",
        "    f_logits, c_logits = [], []\n",
        "    fine_true_idx_list, coarse_true_idx_list = [], []\n",
        "\n",
        "    for batch in loader:\n",
        "        if len(batch) != 5:\n",
        "            raise RuntimeError(\"Val loader must provide true class indices (set fine_true_idx/coarse_true_idx in dataset).\")\n",
        "        inputs, fine_labels, coarse_labels, fine_true_idx, coarse_true_idx = batch\n",
        "        fine_true_idx_list.append(fine_true_idx.cpu())\n",
        "        coarse_true_idx_list.append(coarse_true_idx.cpu())\n",
        "\n",
        "        inputs = inputs.to(device)\n",
        "        fine_labels = fine_labels.to(device)\n",
        "        coarse_labels = coarse_labels.to(device)\n",
        "\n",
        "        with torch.amp.autocast('cuda', enabled=(device.type == \"cuda\")):\n",
        "            fine_log, coarse_log = model(inputs)\n",
        "            loss, _ = criterion(fine_log, fine_labels, coarse_log, coarse_labels)\n",
        "\n",
        "        # Weight by batch size so a smaller final batch does not skew the epoch mean\n",
        "        batch_size = inputs.shape[0]\n",
        "        loss_sum += loss.item() * batch_size\n",
        "        n_samples += batch_size\n",
        "        # Keep logits in float32: under autocast they can be fp16, which loses\n",
        "        # precision in the ±1 Möbius sums and the entropy/ECE computations below\n",
        "        f_logits.append(fine_log.detach().float().cpu())\n",
        "        c_logits.append(coarse_log.detach().float().cpu())\n",
        "\n",
        "    if n_samples == 0:\n",
        "        raise RuntimeError(\"Validation loader yielded no samples.\")\n",
        "\n",
        "    # Stack\n",
        "    fine_logits = torch.cat(f_logits)           # [N, F]\n",
        "    coarse_logits = torch.cat(c_logits)         # [N, C]\n",
        "    ytrue_fine = torch.cat(fine_true_idx_list)  # [N]\n",
        "    ytrue_coarse = torch.cat(coarse_true_idx_list)  # [N]\n",
        "\n",
        "    # Probabilities over focal sets\n",
        "    fine_probs = torch.sigmoid(fine_logits).to(device)\n",
        "    coarse_probs = torch.sigmoid(coarse_logits).to(device)\n",
        "\n",
        "    # Masses (+Ω) and BetP over SINGLETON classes\n",
        "    fine_mass   = belief_to_mass(fine_probs,   M_fine)\n",
        "    coarse_mass = belief_to_mass(coarse_probs, M_coarse)\n",
        "    fine_betp   = final_betp(fine_mass,   P_fine)     # [N, |Θ_fine|]\n",
        "    coarse_betp = final_betp(coarse_mass, P_coarse)   # [N, |Θ_coarse|]\n",
        "\n",
        "    # Predictions (singleton classes)\n",
        "    pred_fine   = fine_betp.argmax(dim=1).cpu()\n",
        "    pred_coarse = coarse_betp.argmax(dim=1).cpu()\n",
        "\n",
        "    # Metrics\n",
        "    fine_acc   = (pred_fine   == ytrue_fine).float().mean().item()\n",
        "    coarse_acc = (pred_coarse == ytrue_coarse).float().mean().item()\n",
        "\n",
        "    # Macro Precision/Recall/F1 (singleton class level)\n",
        "    fine_prec, fine_rec, fine_f1 = compute_singlelabel_metrics(ytrue_fine.numpy(),   pred_fine.numpy())\n",
        "    coarse_prec, coarse_rec, coarse_f1 = compute_singlelabel_metrics(ytrue_coarse.numpy(), pred_coarse.numpy())\n",
        "\n",
        "    # Entropy (normalized) on BetP; and \"softmax\" entropy on focal-set logits (diagnostic)\n",
        "    fine_entropy_betp   = compute_entropy(fine_betp)\n",
        "    coarse_entropy_betp = compute_entropy(coarse_betp)\n",
        "    fine_entropy_soft   = compute_entropy(torch.softmax(fine_logits,   dim=1))\n",
        "    coarse_entropy_soft = compute_entropy(torch.softmax(coarse_logits, dim=1))\n",
        "\n",
        "    # ECE of BetP against the true singleton class\n",
        "    fine_ece_betp   = compute_ece_pytorch(fine_betp.cpu(),   ytrue_fine)\n",
        "    coarse_ece_betp = compute_ece_pytorch(coarse_betp.cpu(), ytrue_coarse)\n",
        "\n",
        "    # Ω stats (Ω is the last mass column)\n",
        "    fine_omega_mean   = fine_mass[:, -1].mean().item()\n",
        "    coarse_omega_mean = coarse_mass[:, -1].mean().item()\n",
        "\n",
        "    # Logical consistency (singleton predictions)\n",
        "    logical_cons = calculate_logical_consistency(pred_fine, pred_coarse)\n",
        "\n",
        "    metrics = {\n",
        "        \"loss\": loss_sum / n_samples,\n",
        "        \"fine_acc\": fine_acc, \"coarse_acc\": coarse_acc,\n",
        "        \"fine_precision\": float(fine_prec), \"fine_recall\": float(fine_rec), \"fine_f1\": float(fine_f1),\n",
        "        \"coarse_precision\": float(coarse_prec), \"coarse_recall\": float(coarse_rec), \"coarse_f1\": float(coarse_f1),\n",
        "        \"fine_entropy_betp\": fine_entropy_betp, \"coarse_entropy_betp\": coarse_entropy_betp,\n",
        "        \"fine_entropy_softmax\": fine_entropy_soft, \"coarse_entropy_softmax\": coarse_entropy_soft,\n",
        "        \"fine_ece_betp\": fine_ece_betp, \"coarse_ece_betp\": coarse_ece_betp,\n",
        "        \"fine_omega_mean\": fine_omega_mean, \"coarse_omega_mean\": coarse_omega_mean,\n",
        "        \"logical_consistency\": logical_cons\n",
        "    }\n",
        "    return metrics\n"
      ],
      "metadata": {
        "id": "EXPv29BIvWh5"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### TRAIN AND VALIDATION LOOP"
      ],
      "metadata": {
        "id": "zlZ2WmHrvsZ7"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Train and validation loop function\n",
        "@torch.no_grad()\n",
        "def get_alpha_beta_gamma(criterion):\n",
        "    alpha = torch.exp(criterion.log_alpha).clamp(1e-5, 1.0).item()\n",
        "    beta  = torch.exp(criterion.log_beta ).clamp(1e-5, 1.0).item()\n",
        "    gamma = torch.exp(criterion.log_gamma).clamp(1e-5,10.0).item()\n",
        "    return alpha, beta, gamma\n",
        "\n",
        "def train_model(config_key, model, criterion, optimizer, scheduler,\n",
        "                train_loader, val_loader, num_epochs, device,\n",
        "                early_stopping_patience=5,\n",
        "                M_fine=None, P_fine=None, M_coarse=None, P_coarse=None):\n",
        "    \"\"\"Train with per-epoch validation, best-model saving and early stopping.\n",
        "\n",
        "    Resumes from the checkpoint for `config_key` when one exists. Whenever the\n",
        "    validation loss improves, saves the weights to\n",
        "    FOCAL_DIR/best_model_<config_key>.pth plus a full training checkpoint.\n",
        "    Stops after `early_stopping_patience` consecutive epochs without improvement;\n",
        "    otherwise steps `scheduler` with the validation loss. Uses the module-level\n",
        "    `scaler`; M_*/P_* are forwarded to validate_one_epoch.\n",
        "\n",
        "    Returns the results dict: {\"train\": [mean train loss per epoch],\n",
        "    \"val\": [validation metrics dict per epoch]}.\n",
        "    \"\"\"\n",
        "    # Load checkpoint if available (specific to this config); the saved config key is unused\n",
        "    start_epoch, results, best_val_loss, _saved_config_key, early_stopping_counter = load_checkpoint(\n",
        "        model, criterion, optimizer, scheduler, scaler, device, config_key\n",
        "    )\n",
        "\n",
        "    for epoch in range(start_epoch, num_epochs):\n",
        "        print(f\"[{config_key}] Epoch {epoch+1}/{num_epochs}\")\n",
        "\n",
        "        # Train\n",
        "        train_loss = train_one_epoch(\n",
        "            model, criterion, optimizer, train_loader, device, epoch\n",
        "        )\n",
        "\n",
        "        # Validate\n",
        "        val_log = validate_one_epoch(\n",
        "            model, criterion, val_loader, device,\n",
        "            M_fine, P_fine, M_coarse, P_coarse\n",
        "        )\n",
        "\n",
        "        results[\"train\"].append(train_loss)\n",
        "        results[\"val\"].append(val_log)\n",
        "\n",
        "        # Check improvement\n",
        "        if val_log[\"loss\"] < best_val_loss:\n",
        "            best_val_loss = val_log[\"loss\"]\n",
        "            # Reset BEFORE checkpointing so the checkpoint stores the post-improvement\n",
        "            # counter (0); otherwise a resumed run starts with a stale count.\n",
        "            early_stopping_counter = 0\n",
        "            alpha, beta, gamma = get_alpha_beta_gamma(criterion)\n",
        "            # val_log is already in results, so the weights are recorded there too\n",
        "            val_log.update({\"alpha\": alpha, \"beta\": beta, \"gamma\": gamma})\n",
        "\n",
        "            # Save best weights only\n",
        "            best_model_path = os.path.join(FOCAL_DIR, f\"best_model_{config_key}.pth\")\n",
        "            torch.save(model.state_dict(), best_model_path)\n",
        "\n",
        "            # Full resumable checkpoint at each new best\n",
        "            save_checkpoint(epoch, model, criterion, optimizer, scheduler, scaler,\n",
        "                            results, config_key, best_val_loss, early_stopping_counter)\n",
        "\n",
        "            print(f\"[{config_key}] New BEST: val_loss={best_val_loss:.4f} \"\n",
        "                  f\"(fine_acc={val_log['fine_acc']:.3f}, coarse_acc={val_log['coarse_acc']:.3f})\")\n",
        "        else:\n",
        "            early_stopping_counter += 1\n",
        "            print(f\"[{config_key}] No improvement ({early_stopping_counter}/{early_stopping_patience}) \"\n",
        "                  f\"| val_loss={val_log['loss']:.4f} best={best_val_loss:.4f}\")\n",
        "\n",
        "        if early_stopping_counter >= early_stopping_patience:\n",
        "            print(f\"[{config_key}] Early stopping triggered\")\n",
        "            break\n",
        "\n",
        "        scheduler.step(val_log[\"loss\"])\n",
        "\n",
        "    return results"
      ],
      "metadata": {
        "id": "Eq-HkbNptLCJ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Hyperparameter grid\n",
        "t_norms = {\"lukasiewicz\": lukasiewicz_t_norm}\n",
        "membership_functions = {\n",
        "    'trapezoidal': trapezoidal_membership,\n",
        "    'gaussian': gaussian_membership,\n",
        "    'triangular': triangular_membership,\n",
        "}\n",
        "\n",
        "num_epochs = 300\n",
        "early_stopping_patience = 5\n",
        "train_total_samples = len(train_loader_new.dataset)\n",
        "\n",
        "\n",
        "# Resume checkpoint if exists.\n",
        "# NOTE: RESULTS_FILE is written by this notebook; never point pickle.load at an\n",
        "# untrusted file, since unpickling can execute arbitrary code.\n",
        "if os.path.exists(RESULTS_FILE):\n",
        "    with open(RESULTS_FILE, \"rb\") as f:\n",
        "        all_results = pickle.load(f)\n",
        "    completed_configs = {r[\"config_key\"] for r in all_results}\n",
        "    print(f\"Resuming training. {len(completed_configs)} configs already completed.\")\n",
        "    print(sorted(list(completed_configs))[:5], \"...\")\n",
        "else:\n",
        "    all_results = []\n",
        "    completed_configs = set()\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    all_args = list(itertools.product(t_norms.items(), membership_functions.items()))\n",
        "    formatted_args = [(t[0], t[1], m[0], m[1]) for t, m in all_args]\n",
        "\n",
        "    for args in formatted_args:\n",
        "        t_norm_type, t_norm_fn, membership_type, membership_fn = args\n",
        "        config_key = f\"learnable_{t_norm_type}_{membership_type}\"\n",
        "\n",
        "        if config_key in completed_configs:\n",
        "            print(f\"Skipping already completed config: {config_key}\")\n",
        "            continue\n",
        "\n",
        "        # Build model and loss\n",
        "        model = SwinMultiTask(len(new_classes_fine), len(new_classes_coarse)).to(device)\n",
        "\n",
        "        # Build consistency structures\n",
        "        M_fc, Kappa, w_f, w_c = build_consistency_structures(\n",
        "            new_classes_fine_idx, new_classes_coarse_idx, fine_to_coarse, device\n",
        "        )\n",
        "\n",
        "        criterion = LearnableBeliefLoss_bce(\n",
        "            mass_coeff_matrix_fine,\n",
        "            mass_coeff_matrix_coarse,\n",
        "            M_fc, Kappa, w_f, w_c,\n",
        "            t_norm_fn,\n",
        "            membership_fn\n",
        "        ).to(device)\n",
        "\n",
        "        optimizer = optim.AdamW(\n",
        "            list(model.parameters()) + list(criterion.parameters()), lr=2e-4\n",
        "        )\n",
        "        scheduler = optim.lr_scheduler.ReduceLROnPlateau(\n",
        "            optimizer, 'min', factor=0.1, patience=3\n",
        "        )\n",
        "\n",
        "        # Train and Validation\n",
        "        result = train_model(\n",
        "            config_key=config_key,\n",
        "            model=model,\n",
        "            criterion=criterion,\n",
        "            optimizer=optimizer,\n",
        "            scheduler=scheduler,\n",
        "            train_loader=train_loader_new,\n",
        "            val_loader=val_loader_new,\n",
        "            num_epochs=num_epochs,\n",
        "            device=device,\n",
        "            early_stopping_patience=early_stopping_patience,\n",
        "            M_fine=M_fine, P_fine=P_fine, M_coarse=M_coarse, P_coarse=P_coarse\n",
        "        )\n",
        "\n",
        "        # Add metadata to result\n",
        "        result[\"config_key\"] = config_key\n",
        "        result[\"best_model_path\"] = f\"best_model_{config_key}.pth\"\n",
        "        result[\"best_val_loss\"] = min(m[\"loss\"] for m in result[\"val\"])\n",
        "\n",
        "        all_results.append(result)\n",
        "\n",
        "        # train_model writes the best-epoch weights (a bare state_dict) to this file\n",
        "        # whenever val_loss improves, but after early stopping `model` holds the\n",
        "        # LAST epoch's weights. Reload the best ones so the checkpoint below does\n",
        "        # not overwrite them with later, worse weights.\n",
        "        best_model_file = os.path.join(FOCAL_DIR, result[\"best_model_path\"])\n",
        "        if os.path.exists(best_model_file):\n",
        "            best_state = torch.load(best_model_file, map_location=device)\n",
        "            # A previous run of this cell may already have wrapped the weights.\n",
        "            if \"model_state_dict\" in best_state:\n",
        "                best_state = best_state[\"model_state_dict\"]\n",
        "            model.load_state_dict(best_state)\n",
        "\n",
        "        # Save model checkpoint: best-epoch model weights; criterion and optimizer\n",
        "        # state are the ones at the end of training.\n",
        "        torch.save(\n",
        "            {\n",
        "                \"model_state_dict\": model.state_dict(),\n",
        "                \"criterion_state_dict\": criterion.state_dict(),\n",
        "                \"optimizer_state_dict\": optimizer.state_dict(),\n",
        "                \"config_key\": config_key,\n",
        "            },\n",
        "            best_model_file\n",
        "        )\n",
        "        print(f\"Saved model for {config_key} -> {result['best_model_path']}\")\n",
        "\n",
        "        # Persist results after every config so a crash later in the grid does not\n",
        "        # lose the configs that already finished (they are skipped on resume).\n",
        "        with open(RESULTS_FILE, 'wb') as f:\n",
        "            pickle.dump(all_results, f)\n",
        "\n",
        "    # Save best result\n",
        "    best_result = min(all_results, key=lambda x: x['best_val_loss'])\n",
        "    with open(BEST_MODEL_METADATA, 'wb') as f:\n",
        "        pickle.dump({\n",
        "            'best_model_path': best_result['best_model_path'],\n",
        "            'best_config_key': best_result['config_key'],\n",
        "        }, f)\n",
        "\n",
        "    with open(RESULTS_FILE, 'wb') as f:\n",
        "        pickle.dump(all_results, f)\n",
        "\n",
        "    print(f\"Saved all results to {RESULTS_FILE}\")\n",
        "    print(f\"Best model path: {best_result['best_model_path']}\")\n"
      ],
      "metadata": {
        "id": "TGBG2aVzsUpQ",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000,
          "referenced_widgets": [
            "7086ccebb8de40389de97f4bfd29ae0e",
            "9539b0594a2744b3aa999ba7e348ede3",
            "0274c22b63d941b3ad3fd545f79e058c",
            "d62f3f4a0edb4759b9ea932db60dc147",
            "6086d0e734a74e5788f8035ade29b570",
            "1584e13fbad841ab877b0306a9bb9bc6",
            "781a32bbb5244b14a8525af16e9590cf",
            "71a375e09243424a913ce6f5d1f583e8",
            "93493f3755db4beeb32b9a4cbc14f17e",
            "25852e7661ac4b62966e3d945d002033",
            "06c772e5438b440faa9adf7984a068f1"
          ]
        },
        "outputId": "089d5e28-ee90-410a-88e6-0b1367808e80"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: \n",
            "The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
            "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
            "You will be able to reuse this secret in all of your notebooks.\n",
            "Please note that authentication is recommended but still optional to access public models or datasets.\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "model.safetensors:   0%|          | 0.00/353M [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "7086ccebb8de40389de97f4bfd29ae0e"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[learnable_lukasiewicz_trapezoidal] Resumed from epoch 299 with best_val_loss=0.0496\n",
            "Saved model for learnable_lukasiewicz_trapezoidal -> best_model_learnable_lukasiewicz_trapezoidal.pth\n",
            "[learnable_lukasiewicz_gaussian] Resumed from epoch 299 with best_val_loss=0.0499\n",
            "Saved model for learnable_lukasiewicz_gaussian -> best_model_learnable_lukasiewicz_gaussian.pth\n",
            "[learnable_lukasiewicz_triangular] Resumed from epoch 194 with best_val_loss=0.0639\n",
            "[learnable_lukasiewicz_triangular] Epoch 196/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 196, val_loss=0.0638\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0638 (fine_acc=0.855, coarse_acc=0.923)\n",
            "[learnable_lukasiewicz_triangular] Epoch 197/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 197, val_loss=0.0637\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0637 (fine_acc=0.854, coarse_acc=0.924)\n",
            "[learnable_lukasiewicz_triangular] Epoch 198/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 198, val_loss=0.0637\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0637 (fine_acc=0.853, coarse_acc=0.922)\n",
            "[learnable_lukasiewicz_triangular] Epoch 199/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 199, val_loss=0.0633\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0633 (fine_acc=0.855, coarse_acc=0.924)\n",
            "[learnable_lukasiewicz_triangular] Epoch 200/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 200, val_loss=0.0633\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0633 (fine_acc=0.855, coarse_acc=0.924)\n",
            "[learnable_lukasiewicz_triangular] Epoch 201/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 201, val_loss=0.0631\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0631 (fine_acc=0.853, coarse_acc=0.923)\n",
            "[learnable_lukasiewicz_triangular] Epoch 202/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 202, val_loss=0.0630\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0630 (fine_acc=0.855, coarse_acc=0.922)\n",
            "[learnable_lukasiewicz_triangular] Epoch 203/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 203, val_loss=0.0626\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0626 (fine_acc=0.854, coarse_acc=0.922)\n",
            "[learnable_lukasiewicz_triangular] Epoch 204/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 204, val_loss=0.0626\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0626 (fine_acc=0.857, coarse_acc=0.922)\n",
            "[learnable_lukasiewicz_triangular] Epoch 205/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 205, val_loss=0.0622\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0622 (fine_acc=0.856, coarse_acc=0.923)\n",
            "[learnable_lukasiewicz_triangular] Epoch 206/300\n",
            "[learnable_lukasiewicz_triangular] No improvement (1/5) | val_loss=0.0624 best=0.0622\n",
            "[learnable_lukasiewicz_triangular] Epoch 207/300\n",
            "[learnable_lukasiewicz_triangular] No improvement (2/5) | val_loss=0.0623 best=0.0622\n",
            "[learnable_lukasiewicz_triangular] Epoch 208/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 208, val_loss=0.0619\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0619 (fine_acc=0.857, coarse_acc=0.925)\n",
            "[learnable_lukasiewicz_triangular] Epoch 209/300\n",
            "[learnable_lukasiewicz_triangular] No improvement (1/5) | val_loss=0.0619 best=0.0619\n",
            "[learnable_lukasiewicz_triangular] Epoch 210/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 210, val_loss=0.0615\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0615 (fine_acc=0.855, coarse_acc=0.923)\n",
            "[learnable_lukasiewicz_triangular] Epoch 211/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 211, val_loss=0.0614\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0614 (fine_acc=0.856, coarse_acc=0.922)\n",
            "[learnable_lukasiewicz_triangular] Epoch 212/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 212, val_loss=0.0614\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0614 (fine_acc=0.856, coarse_acc=0.924)\n",
            "[learnable_lukasiewicz_triangular] Epoch 213/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 213, val_loss=0.0610\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0610 (fine_acc=0.857, coarse_acc=0.923)\n",
            "[learnable_lukasiewicz_triangular] Epoch 214/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 214, val_loss=0.0608\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0608 (fine_acc=0.856, coarse_acc=0.923)\n",
            "[learnable_lukasiewicz_triangular] Epoch 215/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 215, val_loss=0.0604\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0604 (fine_acc=0.857, coarse_acc=0.925)\n",
            "[learnable_lukasiewicz_triangular] Epoch 216/300\n",
            "[learnable_lukasiewicz_triangular] No improvement (1/5) | val_loss=0.0605 best=0.0604\n",
            "[learnable_lukasiewicz_triangular] Epoch 217/300\n",
            "[learnable_lukasiewicz_triangular] No improvement (2/5) | val_loss=0.0606 best=0.0604\n",
            "[learnable_lukasiewicz_triangular] Epoch 218/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 218, val_loss=0.0604\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0604 (fine_acc=0.855, coarse_acc=0.924)\n",
            "[learnable_lukasiewicz_triangular] Epoch 219/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 219, val_loss=0.0603\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0603 (fine_acc=0.854, coarse_acc=0.923)\n",
            "[learnable_lukasiewicz_triangular] Epoch 220/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 220, val_loss=0.0602\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0602 (fine_acc=0.856, coarse_acc=0.922)\n",
            "[learnable_lukasiewicz_triangular] Epoch 221/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 221, val_loss=0.0596\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0596 (fine_acc=0.854, coarse_acc=0.924)\n",
            "[learnable_lukasiewicz_triangular] Epoch 222/300\n",
            "[learnable_lukasiewicz_triangular] No improvement (1/5) | val_loss=0.0599 best=0.0596\n",
            "[learnable_lukasiewicz_triangular] Epoch 223/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 223, val_loss=0.0594\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0594 (fine_acc=0.856, coarse_acc=0.925)\n",
            "[learnable_lukasiewicz_triangular] Epoch 224/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 224, val_loss=0.0589\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0589 (fine_acc=0.858, coarse_acc=0.926)\n",
            "[learnable_lukasiewicz_triangular] Epoch 225/300\n",
            "[learnable_lukasiewicz_triangular] No improvement (1/5) | val_loss=0.0592 best=0.0589\n",
            "[learnable_lukasiewicz_triangular] Epoch 226/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 226, val_loss=0.0588\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0588 (fine_acc=0.857, coarse_acc=0.926)\n",
            "[learnable_lukasiewicz_triangular] Epoch 227/300\n",
            "[learnable_lukasiewicz_triangular] No improvement (1/5) | val_loss=0.0588 best=0.0588\n",
            "[learnable_lukasiewicz_triangular] Epoch 228/300\n",
            "[learnable_lukasiewicz_triangular] No improvement (2/5) | val_loss=0.0588 best=0.0588\n",
            "[learnable_lukasiewicz_triangular] Epoch 229/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 229, val_loss=0.0584\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0584 (fine_acc=0.857, coarse_acc=0.925)\n",
            "[learnable_lukasiewicz_triangular] Epoch 230/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 230, val_loss=0.0582\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0582 (fine_acc=0.858, coarse_acc=0.924)\n",
            "[learnable_lukasiewicz_triangular] Epoch 231/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 231, val_loss=0.0581\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0581 (fine_acc=0.859, coarse_acc=0.924)\n",
            "[learnable_lukasiewicz_triangular] Epoch 232/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 232, val_loss=0.0579\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0579 (fine_acc=0.857, coarse_acc=0.924)\n",
            "[learnable_lukasiewicz_triangular] Epoch 233/300\n",
            "[learnable_lukasiewicz_triangular] No improvement (1/5) | val_loss=0.0580 best=0.0579\n",
            "[learnable_lukasiewicz_triangular] Epoch 234/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 234, val_loss=0.0578\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0578 (fine_acc=0.858, coarse_acc=0.923)\n",
            "[learnable_lukasiewicz_triangular] Epoch 235/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 235, val_loss=0.0578\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0578 (fine_acc=0.855, coarse_acc=0.924)\n",
            "[learnable_lukasiewicz_triangular] Epoch 236/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 236, val_loss=0.0576\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0576 (fine_acc=0.857, coarse_acc=0.923)\n",
            "[learnable_lukasiewicz_triangular] Epoch 237/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 237, val_loss=0.0572\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0572 (fine_acc=0.856, coarse_acc=0.924)\n",
            "[learnable_lukasiewicz_triangular] Epoch 238/300\n",
            "[learnable_lukasiewicz_triangular] No improvement (1/5) | val_loss=0.0572 best=0.0572\n",
            "[learnable_lukasiewicz_triangular] Epoch 239/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 239, val_loss=0.0568\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0568 (fine_acc=0.857, coarse_acc=0.924)\n",
            "[learnable_lukasiewicz_triangular] Epoch 240/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 240, val_loss=0.0565\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0565 (fine_acc=0.857, coarse_acc=0.923)\n",
            "[learnable_lukasiewicz_triangular] Epoch 241/300\n",
            "[learnable_lukasiewicz_triangular] No improvement (1/5) | val_loss=0.0565 best=0.0565\n",
            "[learnable_lukasiewicz_triangular] Epoch 242/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 242, val_loss=0.0564\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0564 (fine_acc=0.856, coarse_acc=0.926)\n",
            "[learnable_lukasiewicz_triangular] Epoch 243/300\n",
            "[learnable_lukasiewicz_triangular] No improvement (1/5) | val_loss=0.0565 best=0.0564\n",
            "[learnable_lukasiewicz_triangular] Epoch 244/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 244, val_loss=0.0563\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0563 (fine_acc=0.855, coarse_acc=0.923)\n",
            "[learnable_lukasiewicz_triangular] Epoch 245/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 245, val_loss=0.0563\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0563 (fine_acc=0.857, coarse_acc=0.923)\n",
            "[learnable_lukasiewicz_triangular] Epoch 246/300\n",
            "[learnable_lukasiewicz_triangular] No improvement (1/5) | val_loss=0.0565 best=0.0563\n",
            "[learnable_lukasiewicz_triangular] Epoch 247/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 247, val_loss=0.0559\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0559 (fine_acc=0.856, coarse_acc=0.923)\n",
            "[learnable_lukasiewicz_triangular] Epoch 248/300\n",
            "[learnable_lukasiewicz_triangular] No improvement (1/5) | val_loss=0.0565 best=0.0559\n",
            "[learnable_lukasiewicz_triangular] Epoch 249/300\n",
            "[learnable_lukasiewicz_triangular] No improvement (2/5) | val_loss=0.0560 best=0.0559\n",
            "[learnable_lukasiewicz_triangular] Epoch 250/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 250, val_loss=0.0557\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0557 (fine_acc=0.855, coarse_acc=0.922)\n",
            "[learnable_lukasiewicz_triangular] Epoch 251/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 251, val_loss=0.0554\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0554 (fine_acc=0.854, coarse_acc=0.922)\n",
            "[learnable_lukasiewicz_triangular] Epoch 252/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 252, val_loss=0.0554\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0554 (fine_acc=0.856, coarse_acc=0.923)\n",
            "[learnable_lukasiewicz_triangular] Epoch 253/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 253, val_loss=0.0553\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0553 (fine_acc=0.857, coarse_acc=0.923)\n",
            "[learnable_lukasiewicz_triangular] Epoch 254/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 254, val_loss=0.0552\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0552 (fine_acc=0.855, coarse_acc=0.923)\n",
            "[learnable_lukasiewicz_triangular] Epoch 255/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 255, val_loss=0.0552\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0552 (fine_acc=0.857, coarse_acc=0.923)\n",
            "[learnable_lukasiewicz_triangular] Epoch 256/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 256, val_loss=0.0547\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0547 (fine_acc=0.858, coarse_acc=0.922)\n",
            "[learnable_lukasiewicz_triangular] Epoch 257/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 257, val_loss=0.0547\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0547 (fine_acc=0.857, coarse_acc=0.923)\n",
            "[learnable_lukasiewicz_triangular] Epoch 258/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 258, val_loss=0.0545\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0545 (fine_acc=0.857, coarse_acc=0.924)\n",
            "[learnable_lukasiewicz_triangular] Epoch 259/300\n",
            "[learnable_lukasiewicz_triangular] No improvement (1/5) | val_loss=0.0547 best=0.0545\n",
            "[learnable_lukasiewicz_triangular] Epoch 260/300\n",
            "[learnable_lukasiewicz_triangular] No improvement (2/5) | val_loss=0.0545 best=0.0545\n",
            "[learnable_lukasiewicz_triangular] Epoch 261/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 261, val_loss=0.0544\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0544 (fine_acc=0.857, coarse_acc=0.923)\n",
            "[learnable_lukasiewicz_triangular] Epoch 262/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 262, val_loss=0.0543\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0543 (fine_acc=0.857, coarse_acc=0.921)\n",
            "[learnable_lukasiewicz_triangular] Epoch 263/300\n",
            "[learnable_lukasiewicz_triangular] No improvement (1/5) | val_loss=0.0544 best=0.0543\n",
            "[learnable_lukasiewicz_triangular] Epoch 264/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 264, val_loss=0.0540\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0540 (fine_acc=0.856, coarse_acc=0.922)\n",
            "[learnable_lukasiewicz_triangular] Epoch 265/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 265, val_loss=0.0538\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0538 (fine_acc=0.856, coarse_acc=0.924)\n",
            "[learnable_lukasiewicz_triangular] Epoch 266/300\n",
            "[learnable_lukasiewicz_triangular] No improvement (1/5) | val_loss=0.0539 best=0.0538\n",
            "[learnable_lukasiewicz_triangular] Epoch 267/300\n",
            "[learnable_lukasiewicz_triangular] No improvement (2/5) | val_loss=0.0540 best=0.0538\n",
            "[learnable_lukasiewicz_triangular] Epoch 268/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 268, val_loss=0.0536\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0536 (fine_acc=0.857, coarse_acc=0.923)\n",
            "[learnable_lukasiewicz_triangular] Epoch 269/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 269, val_loss=0.0530\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0530 (fine_acc=0.857, coarse_acc=0.922)\n",
            "[learnable_lukasiewicz_triangular] Epoch 270/300\n",
            "[learnable_lukasiewicz_triangular] No improvement (1/5) | val_loss=0.0530 best=0.0530\n",
            "[learnable_lukasiewicz_triangular] Epoch 271/300\n",
            "[learnable_lukasiewicz_triangular] No improvement (2/5) | val_loss=0.0534 best=0.0530\n",
            "[learnable_lukasiewicz_triangular] Epoch 272/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 272, val_loss=0.0528\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0528 (fine_acc=0.857, coarse_acc=0.925)\n",
            "[learnable_lukasiewicz_triangular] Epoch 273/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 273, val_loss=0.0526\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0526 (fine_acc=0.856, coarse_acc=0.924)\n",
            "[learnable_lukasiewicz_triangular] Epoch 274/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 274, val_loss=0.0526\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0526 (fine_acc=0.858, coarse_acc=0.924)\n",
            "[learnable_lukasiewicz_triangular] Epoch 275/300\n",
            "[learnable_lukasiewicz_triangular] Checkpoint saved at epoch 275, val_loss=0.0521\n",
            "[learnable_lukasiewicz_triangular] New BEST: val_loss=0.0521 (fine_acc=0.858, coarse_acc=0.925)\n",
            "[learnable_lukasiewicz_triangular] Epoch 276/300\n",
            "[learnable_lukasiewicz_triangular] No improvement (1/5) | val_loss=0.0523 best=0.0521\n",
            "[learnable_lukasiewicz_triangular] Epoch 277/300\n",
            "[learnable_lukasiewicz_triangular] No improvement (2/5) | val_loss=0.0525 best=0.0521\n",
            "[learnable_lukasiewicz_triangular] Epoch 278/300\n",
            "[learnable_lukasiewicz_triangular] No improvement (3/5) | val_loss=0.0527 best=0.0521\n",
            "[learnable_lukasiewicz_triangular] Epoch 279/300\n",
            "[learnable_lukasiewicz_triangular] No improvement (4/5) | val_loss=0.0526 best=0.0521\n",
            "[learnable_lukasiewicz_triangular] Epoch 280/300\n",
            "[learnable_lukasiewicz_triangular] No improvement (5/5) | val_loss=0.0521 best=0.0521\n",
            "[learnable_lukasiewicz_triangular] Early stopping triggered\n",
            "Saved model for learnable_lukasiewicz_triangular -> best_model_learnable_lukasiewicz_triangular.pth\n",
            "Saved all results to /content/drive/MyDrive/NO_WARMUP/focal/all_training_results.pkl\n",
            "Best model path: best_model_learnable_lukasiewicz_trapezoidal.pth\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Summarise every config at the epoch with the lowest validation loss.\n",
        "print(f\"Loaded {len(all_results)} configs from {RESULTS_FILE}\\n\")\n",
        "\n",
        "for res in all_results:\n",
        "    # np.argmin picks the first epoch that attains the minimum val loss\n",
        "    best_epoch = int(np.argmin([epoch_log[\"loss\"] for epoch_log in res[\"val\"]]))\n",
        "    best_val, best_train = res[\"val\"][best_epoch], res[\"train\"][best_epoch]\n",
        "\n",
        "    print(\"=\" * 60)\n",
        "    print(f\"Config: {res['config_key']}\")\n",
        "    print(f\" Best epoch: {best_epoch + 1}\")  # reported 1-based\n",
        "    print(f\" Train Loss: {best_train:.4f}\")\n",
        "    print(f\" Val Loss: {best_val['loss']:.4f}\")\n",
        "    print(f\" Val Acc (Fine/Coarse): {best_val['fine_acc']:.4f} / {best_val['coarse_acc']:.4f}\")\n",
        "    # Precision / recall / F1 per hierarchy level (\"Fine\", then \"Coarse\")\n",
        "    for level in (\"fine\", \"coarse\"):\n",
        "        p_val, r_val, f1_val = (best_val[f\"{level}_{metric}\"] for metric in (\"precision\", \"recall\", \"f1\"))\n",
        "        print(f\" Val P/R/F1 {level.capitalize()}: {p_val:.4f} / {r_val:.4f} / {f1_val:.4f}\")\n",
        "    print(f\" Val Entropy (Fine/Coarse BETP): {best_val['fine_entropy_betp']:.4f} / {best_val['coarse_entropy_betp']:.4f}\")\n",
        "    print(f\" Val Entropy (Fine/Coarse Softmax): {best_val['fine_entropy_softmax']:.4f} / {best_val['coarse_entropy_softmax']:.4f}\")\n",
        "    print(f\" Val ECE (Fine/Coarse BETP): {best_val['fine_ece_betp']:.4f} / {best_val['coarse_ece_betp']:.4f}\")\n",
        "    print(f\" Val Ω Mean (Fine/Coarse): {best_val['fine_omega_mean']:.4f} / {best_val['coarse_omega_mean']:.4f}\")\n",
        "    print(f\" Val Logical Consistency: {best_val['logical_consistency']:.4f}\")\n",
        "    print(f\" Best model path: {res['best_model_path']}\")\n"
      ],
      "metadata": {
        "id": "9_w49zulEBoC",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "d914b4b1-e3e7-4b3b-edd2-dfda7358f41f"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Loaded 3 configs from /content/drive/MyDrive/NO_WARMUP/focal/all_training_results.pkl\n",
            "\n",
            "============================================================\n",
            "Config: learnable_lukasiewicz_trapezoidal\n",
            " Best epoch: 298\n",
            " Train Loss: 0.0629\n",
            " Val Loss: 0.0496\n",
            " Val Acc (Fine/Coarse): 0.8573 / 0.9230\n",
            " Val P/R/F1 Fine: 0.8592 / 0.8569 / 0.8558\n",
            " Val P/R/F1 Coarse: 0.9225 / 0.9225 / 0.9221\n",
            " Val Entropy (Fine/Coarse BETP): 0.8804 / 0.2788\n",
            " Val Entropy (Fine/Coarse Softmax): 0.4648 / 0.1877\n",
            " Val ECE (Fine/Coarse BETP): 0.0633 / 0.0080\n",
            " Val Ω Mean (Fine/Coarse): 0.0548 / 0.0196\n",
            " Val Logical Consistency: 0.9650\n",
            " Best model path: best_model_learnable_lukasiewicz_trapezoidal.pth\n",
            "============================================================\n",
            "Config: learnable_lukasiewicz_gaussian\n",
            " Best epoch: 299\n",
            " Train Loss: 0.0626\n",
            " Val Loss: 0.0499\n",
            " Val Acc (Fine/Coarse): 0.8550 / 0.9241\n",
            " Val P/R/F1 Fine: 0.8563 / 0.8545 / 0.8532\n",
            " Val P/R/F1 Coarse: 0.9233 / 0.9234 / 0.9231\n",
            " Val Entropy (Fine/Coarse BETP): 0.8813 / 0.2778\n",
            " Val Entropy (Fine/Coarse Softmax): 0.4634 / 0.1864\n",
            " Val ECE (Fine/Coarse BETP): 0.0608 / 0.0057\n",
            " Val Ω Mean (Fine/Coarse): 0.0549 / 0.0180\n",
            " Val Logical Consistency: 0.9669\n",
            " Best model path: best_model_learnable_lukasiewicz_gaussian.pth\n",
            "============================================================\n",
            "Config: learnable_lukasiewicz_triangular\n",
            " Best epoch: 275\n",
            " Train Loss: 0.0666\n",
            " Val Loss: 0.0521\n",
            " Val Acc (Fine/Coarse): 0.8576 / 0.9254\n",
            " Val P/R/F1 Fine: 0.8589 / 0.8568 / 0.8556\n",
            " Val P/R/F1 Coarse: 0.9247 / 0.9246 / 0.9244\n",
            " Val Entropy (Fine/Coarse BETP): 0.9165 / 0.2839\n",
            " Val Entropy (Fine/Coarse Softmax): 0.4873 / 0.1920\n",
            " Val ECE (Fine/Coarse BETP): 0.0690 / 0.0071\n",
            " Val Ω Mean (Fine/Coarse): 0.0555 / 0.0196\n",
            " Val Logical Consistency: 0.9668\n",
            " Best model path: best_model_learnable_lukasiewicz_triangular.pth\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## TEST NEUROSYMBOLIC EPISTEMIC AI"
      ],
      "metadata": {
        "id": "OxVBjyoqT4Lb"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Align the coarse prediction with the parent class of a confident fine prediction\n",
        "def constrained_output(fine_probs, coarse_probs,\n",
        "                       M_fine, P_fine, M_coarse, P_coarse,\n",
        "                       fine_to_coarse,\n",
        "                       fine_threshold=0.5, coarse_threshold=0.5,\n",
        "                       device=\"cpu\"):\n",
        "    \"\"\"Hierarchy-constrained fine/coarse predictions from belief outputs.\n",
        "\n",
        "    Both heads go belief -> mass (`belief_to_mass`) -> class-space pignistic\n",
        "    probability (`final_betp`). The fine prediction is the fine BetP argmax.\n",
        "    The coarse prediction is the coarse BetP argmax, except when the fine\n",
        "    prediction is confident (its BetP >= fine_threshold) and the coarse BetP\n",
        "    of its parent `fine_to_coarse[fine_pred]` is below coarse_threshold: then\n",
        "    the parent class is forced.\n",
        "\n",
        "    Args:\n",
        "        fine_probs, coarse_probs: per-sample beliefs with batch as dim 0,\n",
        "            presumably (batch, n_focal_sets) -- TODO confirm.\n",
        "        M_fine, P_fine, M_coarse, P_coarse: matrices for belief_to_mass / final_betp.\n",
        "        fine_to_coarse: indexable map fine class id -> coarse class id.\n",
        "        fine_threshold, coarse_threshold: confidence thresholds (see above).\n",
        "        device: unused, kept for backward compatibility; outputs are placed\n",
        "            on fine_probs.device.\n",
        "\n",
        "    Returns:\n",
        "        (fine_pred, coarse_pred): long tensors of shape (batch,) with class ids.\n",
        "    \"\"\"\n",
        "    out_device = fine_probs.device\n",
        "\n",
        "    fine_betp = final_betp(belief_to_mass(fine_probs, M_fine), P_fine)\n",
        "    fine_pred = torch.argmax(fine_betp, dim=-1)\n",
        "    fine_conf = fine_betp.gather(1, fine_pred.unsqueeze(1)).squeeze(1)\n",
        "\n",
        "    # Coarse BetP for the whole batch (assumes both helpers act row-wise, as\n",
        "    # their full-batch use in evaluate_config implies). The old low-confidence\n",
        "    # fallback took argmax of the raw coarse beliefs, i.e. a focal-set index\n",
        "    # rather than a coarse class id.\n",
        "    coarse_betp = final_betp(belief_to_mass(coarse_probs, M_coarse), P_coarse)\n",
        "    coarse_pred = torch.argmax(coarse_betp, dim=-1)\n",
        "\n",
        "    expected_coarse = torch.as_tensor(\n",
        "        [int(fine_to_coarse[int(f)]) for f in fine_pred.tolist()],\n",
        "        dtype=torch.long, device=coarse_betp.device)\n",
        "    expected_score = coarse_betp.gather(1, expected_coarse.unsqueeze(1)).squeeze(1)\n",
        "\n",
        "    # Force the parent only when fine is confident and coarse gives the parent\n",
        "    # less than coarse_threshold; otherwise keep the coarse BetP argmax.\n",
        "    force_parent = (fine_conf >= fine_threshold) & (expected_score < coarse_threshold)\n",
        "    constrained_coarse = torch.where(force_parent, expected_coarse, coarse_pred)\n",
        "\n",
        "    return fine_pred.to(out_device), constrained_coarse.to(out_device, dtype=torch.long)\n"
      ],
      "metadata": {
        "id": "AUdmb-Kwr00V"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "@torch.no_grad()\n",
        "def evaluate_config(model_path, config_key, M_fine, P_fine, M_coarse, P_coarse,\n",
        "                    fine_threshold=0.5, coarse_threshold=0.5):\n",
        "    \"\"\"Evaluate one saved checkpoint on `test_loader_new`.\n",
        "\n",
        "    Loads a SwinMultiTask checkpoint, runs the whole test loader once and\n",
        "    reports hierarchy-constrained accuracy (via `constrained_output`), plain\n",
        "    BetP-argmax accuracy, logical consistency, ECE and entropy (BetP and\n",
        "    softmax), single-/multi-label P/R/F1 and focal-set coverage.\n",
        "\n",
        "    Relies on notebook globals: SwinMultiTask, new_classes_fine,\n",
        "    new_classes_coarse, test_loader_new, fine_to_coarse and the metric helpers.\n",
        "\n",
        "    Args:\n",
        "        model_path: checkpoint file, either a raw state_dict or a dict holding\n",
        "            it under 'model_state_dict'.\n",
        "        config_key: identifier copied into the result.\n",
        "        M_fine, P_fine, M_coarse, P_coarse: focal-set matrices (tensors).\n",
        "        fine_threshold, coarse_threshold: forwarded to `constrained_output`.\n",
        "\n",
        "    Returns:\n",
        "        dict with 'config_key', 'best_model_path' and 'test_*' metrics; ECE,\n",
        "        entropy and logical consistency are plain Python floats.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if the test loader yields no samples.\n",
        "    \"\"\"\n",
        "    # Single device for model, matrices and batches: the device of M_fine.\n",
        "    if torch.is_tensor(M_fine):\n",
        "        dev = M_fine.device\n",
        "    else:\n",
        "        dev = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    M_fine   = M_fine.to(dev)\n",
        "    P_fine   = P_fine.to(dev)\n",
        "    M_coarse = M_coarse.to(dev)\n",
        "    P_coarse = P_coarse.to(dev)\n",
        "\n",
        "    # focal-set -> class membership masks are batch-invariant; build them once\n",
        "    P_fine_mask   = (P_fine > 0).float()\n",
        "    P_coarse_mask = (P_coarse > 0).float()\n",
        "\n",
        "    # build & load model\n",
        "    model = SwinMultiTask(len(new_classes_fine), len(new_classes_coarse)).to(dev)\n",
        "    state = torch.load(model_path, map_location=dev)\n",
        "    model.load_state_dict(state.get(\"model_state_dict\", state))\n",
        "    model.eval()\n",
        "\n",
        "    # accumulators\n",
        "    N = 0\n",
        "    correct_fine = correct_coarse = 0\n",
        "    corr_fine_betp = corr_coarse_betp = 0\n",
        "    correct_fine_cov = correct_coarse_cov = 0.0\n",
        "    logical_cons_sum = 0.0\n",
        "\n",
        "    total_fine_ece_betp = total_coarse_ece_betp = 0.0\n",
        "    total_fine_ece_soft = total_coarse_ece_soft = 0.0\n",
        "    total_ent_fine_betp = total_ent_coarse_betp = 0.0\n",
        "    total_ent_fine_soft = total_ent_coarse_soft = 0.0\n",
        "\n",
        "    y_true_fine, y_pred_fine = [], []\n",
        "    y_true_coarse, y_pred_coarse = [], []\n",
        "    y_true_fine_ml, y_pred_fine_ml = [], []\n",
        "    y_true_coarse_ml, y_pred_coarse_ml = [], []\n",
        "\n",
        "    for batch in test_loader_new:\n",
        "        # expect 5-tuple like val\n",
        "        inputs, fine_bel, coarse_bel, ytrue_fine_ids, ytrue_coarse_ids = batch\n",
        "        # Batches go to `dev` (the model's device), not the global `device`.\n",
        "        inputs = inputs.to(dev)\n",
        "        fine_bel, coarse_bel = fine_bel.to(dev), coarse_bel.to(dev)\n",
        "        ytrue_fine_ids, ytrue_coarse_ids = ytrue_fine_ids.to(dev), ytrue_coarse_ids.to(dev)\n",
        "        bs = inputs.size(0)\n",
        "\n",
        "        fine_logits, coarse_logits = model(inputs)\n",
        "        fine_probs   = torch.sigmoid(fine_logits)\n",
        "        coarse_probs = torch.sigmoid(coarse_logits)\n",
        "\n",
        "        # belief -> mass over focal sets\n",
        "        fine_mass   = belief_to_mass(fine_probs,   M_fine)\n",
        "        coarse_mass = belief_to_mass(coarse_probs, M_coarse)\n",
        "\n",
        "        # classes covered by any focal set that received positive mass\n",
        "        fine_label_mask   = ((fine_mass > 0).float()   @ P_fine_mask)   > 0\n",
        "        coarse_label_mask = ((coarse_mass > 0).float() @ P_coarse_mask) > 0\n",
        "\n",
        "        # Coverage: true label inside predicted set?\n",
        "        idx = torch.arange(bs, device=dev)\n",
        "        correct_fine_cov   += fine_label_mask[idx, ytrue_fine_ids].float().sum().item()\n",
        "        correct_coarse_cov += coarse_label_mask[idx, ytrue_coarse_ids].float().sum().item()\n",
        "\n",
        "        # BetP in class space\n",
        "        fine_betp   = final_betp(fine_mass,   P_fine)\n",
        "        coarse_betp = final_betp(coarse_mass, P_coarse)\n",
        "\n",
        "        # softmax baseline (diagnostic).\n",
        "        # NOTE(review): softmax runs over the raw head outputs, which are also fed to\n",
        "        # belief_to_mass; if that dimension is the focal-set count rather than the\n",
        "        # class count, these diagnostics are not in class space -- confirm.\n",
        "        fine_soft   = torch.softmax(fine_logits, dim=1)\n",
        "        coarse_soft = torch.softmax(coarse_logits, dim=1)\n",
        "\n",
        "        # plain pignistic argmax\n",
        "        pig_pred_fine   = fine_betp.argmax(dim=1)\n",
        "        pig_pred_coarse = coarse_betp.argmax(dim=1)\n",
        "\n",
        "        # hierarchy-constrained predictions\n",
        "        final_fine, final_coarse = constrained_output(\n",
        "            fine_probs, coarse_probs,\n",
        "            M_fine, P_fine, M_coarse, P_coarse,\n",
        "            fine_to_coarse,\n",
        "            fine_threshold=fine_threshold,\n",
        "            coarse_threshold=coarse_threshold,\n",
        "            device=dev\n",
        "        )\n",
        "\n",
        "        # counts\n",
        "        N += bs\n",
        "        corr_fine_betp   += (pig_pred_fine   == ytrue_fine_ids).sum().item()\n",
        "        corr_coarse_betp += (pig_pred_coarse == ytrue_coarse_ids).sum().item()\n",
        "        correct_fine     += (final_fine      == ytrue_fine_ids).sum().item()\n",
        "        correct_coarse   += (final_coarse    == ytrue_coarse_ids).sum().item()\n",
        "        # Summed per batch and divided by N below, so the helper presumably\n",
        "        # returns a per-batch count of consistent pairs -- TODO confirm.\n",
        "        logical_cons_sum += float(calculate_logical_consistency_base(final_fine, final_coarse))\n",
        "\n",
        "        # batch-size-weighted ECE & entropy; float() keeps accumulators off the GPU.\n",
        "        # NOTE(review): a bs-weighted mean of per-batch ECE approximates, but is not\n",
        "        # equal to, the dataset-level ECE.\n",
        "        total_fine_ece_betp   += float(compute_ece_pytorch(fine_betp,   ytrue_fine_ids))   * bs\n",
        "        total_coarse_ece_betp += float(compute_ece_pytorch(coarse_betp, ytrue_coarse_ids)) * bs\n",
        "        total_fine_ece_soft   += float(compute_ece_pytorch(fine_soft,   ytrue_fine_ids))   * bs\n",
        "        total_coarse_ece_soft += float(compute_ece_pytorch(coarse_soft, ytrue_coarse_ids)) * bs\n",
        "\n",
        "        total_ent_fine_betp   += float(compute_entropy(fine_betp))   * bs\n",
        "        total_ent_coarse_betp += float(compute_entropy(coarse_betp)) * bs\n",
        "        total_ent_fine_soft   += float(compute_entropy(fine_soft))   * bs\n",
        "        total_ent_coarse_soft += float(compute_entropy(coarse_soft)) * bs\n",
        "\n",
        "        # PRF collectors (CPU copies)\n",
        "        y_true_fine.append(ytrue_fine_ids.cpu())\n",
        "        y_pred_fine.append(final_fine.cpu())\n",
        "        y_true_coarse.append(ytrue_coarse_ids.cpu())\n",
        "        y_pred_coarse.append(final_coarse.cpu())\n",
        "\n",
        "        # multilabel over focal sets\n",
        "        y_true_fine_ml.append(fine_bel.cpu())\n",
        "        y_pred_fine_ml.append(fine_probs.cpu())\n",
        "        y_true_coarse_ml.append(coarse_bel.cpu())\n",
        "        y_pred_coarse_ml.append(coarse_probs.cpu())\n",
        "\n",
        "    if N == 0:\n",
        "        raise ValueError(\"test_loader_new yielded no samples\")\n",
        "\n",
        "    # PRF (single-label)\n",
        "    yt_fine   = torch.cat(y_true_fine).numpy()\n",
        "    yp_fine   = torch.cat(y_pred_fine).numpy()\n",
        "    yt_coarse = torch.cat(y_true_coarse).numpy()\n",
        "    yp_coarse = torch.cat(y_pred_coarse).numpy()\n",
        "\n",
        "    prec_fine, rec_fine, f1_fine = compute_singlelabel_metrics(yt_fine, yp_fine)\n",
        "    prec_coarse, rec_coarse, f1_coarse = compute_singlelabel_metrics(yt_coarse, yp_coarse)\n",
        "\n",
        "    # PRF (multilabel on focal-set space)\n",
        "    yt_fine_ml   = torch.cat(y_true_fine_ml)\n",
        "    yp_fine_ml   = torch.cat(y_pred_fine_ml)\n",
        "    yt_coarse_ml = torch.cat(y_true_coarse_ml)\n",
        "    yp_coarse_ml = torch.cat(y_pred_coarse_ml)\n",
        "    ml_prec_fine, ml_rec_fine, ml_f1_fine = compute_multilabel_metrics(yp_fine_ml, yt_fine_ml)\n",
        "    ml_prec_coarse, ml_rec_coarse, ml_f1_coarse = compute_multilabel_metrics(yp_coarse_ml, yt_coarse_ml)\n",
        "\n",
        "    return {\n",
        "        'config_key': config_key,\n",
        "        'best_model_path': model_path,\n",
        "        'test_accuracy_fine': correct_fine / N,\n",
        "        'test_accuracy_coarse': correct_coarse / N,\n",
        "        'test_pignistic_accuracy_fine': corr_fine_betp / N,\n",
        "        'test_pignistic_accuracy_coarse': corr_coarse_betp / N,\n",
        "        'test_logical_consistency': logical_cons_sum / N,\n",
        "        # ECE + entropy\n",
        "        'test_ece_fine_betp': total_fine_ece_betp / N, 'test_ece_coarse_betp': total_coarse_ece_betp / N,\n",
        "        'test_ece_fine_softmax': total_fine_ece_soft / N, 'test_ece_coarse_softmax': total_coarse_ece_soft / N,\n",
        "        'test_entropy_fine_betp': total_ent_fine_betp / N, 'test_entropy_coarse_betp': total_ent_coarse_betp / N,\n",
        "        'test_entropy_fine_softmax': total_ent_fine_soft / N, 'test_entropy_coarse_softmax': total_ent_coarse_soft / N,\n",
        "        # PRF\n",
        "        'test_precision_fine': prec_fine, 'test_recall_fine': rec_fine, 'test_f1_fine': f1_fine,\n",
        "        'test_precision_coarse': prec_coarse, 'test_recall_coarse': rec_coarse, 'test_f1_coarse': f1_coarse,\n",
        "        'test_multilabel_precision_fine': ml_prec_fine, 'test_multilabel_recall_fine': ml_rec_fine, 'test_multilabel_f1_fine': ml_f1_fine,\n",
        "        'test_multilabel_precision_coarse': ml_prec_coarse, 'test_multilabel_recall_coarse': ml_rec_coarse, 'test_multilabel_f1_coarse': ml_f1_coarse,\n",
        "        # Coverage\n",
        "        'test_fine_coverage': correct_fine_cov / N,\n",
        "        'test_coarse_coverage': correct_coarse_cov / N\n",
        "    }\n"
      ],
      "metadata": {
        "id": "t9iVoV-p9I91"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Test trained configs\n",
        "import glob, csv, math, traceback\n",
        "\n",
        "TEST_RESULTS_PKL = os.path.join(SAVE_DIR, 'test_results_all.pkl')\n",
        "TEST_RESULTS_CSV = os.path.join(SAVE_DIR, 'focal', 'test_results_all.csv')\n",
        "\n",
        "# Load the training summaries (locally produced pickle; only unpickle trusted files)\n",
        "with open(RESULTS_FILE, \"rb\") as f:\n",
        "    all_results = pickle.load(f)\n",
        "\n",
        "print(f\"Loaded {len(all_results)} configs from {RESULTS_FILE}\\n\")\n",
        "\n",
        "# Resolve relative checkpoint paths against FOCAL_DIR\n",
        "for cfg in all_results:\n",
        "    if not os.path.isabs(cfg['best_model_path']):\n",
        "        cfg['best_model_path'] = os.path.join(FOCAL_DIR, cfg['best_model_path'])\n",
        "\n",
        "# test all trained models over a grid of constraint thresholds\n",
        "configs = all_results\n",
        "\n",
        "fine_thresholds   = [0.4, 0.5, 0.6]\n",
        "coarse_thresholds = [0.4, 0.5, 0.6]\n",
        "\n",
        "all_test_results = []\n",
        "failed_runs = []  # (config_key, fine_th, coarse_th) of runs that raised\n",
        "\n",
        "# move helper matrices once here (saves tiny overhead)\n",
        "M_fine_d   = M_fine.to(device)\n",
        "P_fine_d   = P_fine.to(device)\n",
        "M_coarse_d = M_coarse.to(device)\n",
        "P_coarse_d = P_coarse.to(device)\n",
        "\n",
        "for i, cfg in enumerate(configs, 1):\n",
        "    key = cfg['config_key']\n",
        "    path = cfg['best_model_path']\n",
        "    bvl = cfg.get('best_val_loss', float('nan'))\n",
        "\n",
        "    for ft in fine_thresholds:\n",
        "        for ct in coarse_thresholds:\n",
        "            print(f\"[{i}/{len(configs)}] Testing {key} (val_loss={bvl:.6f}) \"\n",
        "                  f\"with fine_th={ft}, coarse_th={ct}: {path}\")\n",
        "            try:\n",
        "                res = evaluate_config(path, key,\n",
        "                                      M_fine_d, P_fine_d,\n",
        "                                      M_coarse_d, P_coarse_d,\n",
        "                                      fine_threshold=ft, coarse_threshold=ct)\n",
        "            except Exception as e:\n",
        "                # Keep sweeping the grid, but keep the traceback for debugging.\n",
        "                print(f\" !! Failed on {key} with ft={ft}, ct={ct}: {e}\\n\")\n",
        "                traceback.print_exc()\n",
        "                failed_runs.append((key, ft, ct))\n",
        "                continue\n",
        "\n",
        "            res['best_val_loss'] = bvl\n",
        "            res['fine_threshold'] = ft\n",
        "            res['coarse_threshold'] = ct\n",
        "            all_test_results.append(res)\n",
        "\n",
        "            print(f\" -> Done: FineAcc={res['test_accuracy_fine']:.4f} | \"\n",
        "                  f\"CoarseAcc={res['test_accuracy_coarse']:.4f} | \"\n",
        "                  f\"FineF1={res['test_f1_fine']:.4f} | \"\n",
        "                  f\"CoarseF1={res['test_f1_coarse']:.4f}\\n\")\n",
        "\n",
        "if failed_runs:\n",
        "    print(f\" !! {len(failed_runs)} run(s) failed: {failed_runs}\\n\")\n",
        "\n",
        "# Save pkl\n",
        "with open(TEST_RESULTS_PKL, 'wb') as f:\n",
        "    pickle.dump(all_test_results, f)\n",
        "\n",
        "# Print a sorted summary\n",
        "def sort_key(r):\n",
        "    \"\"\"Summary ordering: runs with a numeric best_val_loss first (ascending,\n",
        "    ties broken by config key, then fine/coarse thresholds); the rest last,\n",
        "    by descending test_f1_fine.\"\"\"\n",
        "    try:\n",
        "        # float() also accepts ints, numpy scalars and 0-d tensors\n",
        "        bvl = float(r.get('best_val_loss', float('nan')))\n",
        "    except (TypeError, ValueError):\n",
        "        # e.g. None or a non-numeric placeholder\n",
        "        bvl = float('nan')\n",
        "    if math.isnan(bvl):\n",
        "        # fallback: sort by test_f1_fine desc\n",
        "        return (1, -r.get('test_f1_fine', 0.0))\n",
        "    return (0, bvl, str(r.get('config_key', '')),\n",
        "            r.get('fine_threshold', 0.0), r.get('coarse_threshold', 0.0))\n",
        "\n",
        "all_test_results_sorted = sorted(all_test_results, key=sort_key)\n",
        "print(\"Test Summary: \\n\")\n",
        "for r in all_test_results_sorted:\n",
        "    # Each config appears once per threshold pair, so the header names the pair.\n",
        "    print(f\"{r['config_key']:<35} \"\n",
        "          f\"val_loss={r.get('best_val_loss', float('nan')):>8.5f} \"\n",
        "          f\"fine_th={r.get('fine_threshold', 'n/a')} coarse_th={r.get('coarse_threshold', 'n/a')}\\n\"\n",
        "          f\"  Accuracies: Fine={r['test_accuracy_fine']:.4f}, Coarse={r['test_accuracy_coarse']:.4f}\\n\"\n",
        "          f\"  Pignistic Acc: Fine={r['test_pignistic_accuracy_fine']:.4f}, Coarse={r['test_pignistic_accuracy_coarse']:.4f}\\n\"\n",
        "          f\"  Logical Consistency={r['test_logical_consistency']:.4f}\\n\"\n",
        "          f\"  Single-label PRF: Fine(P={r['test_precision_fine']:.4f}, R={r['test_recall_fine']:.4f}, F1={r['test_f1_fine']:.4f}) | \"\n",
        "          f\"Coarse(P={r['test_precision_coarse']:.4f}, R={r['test_recall_coarse']:.4f}, F1={r['test_f1_coarse']:.4f})\\n\"\n",
        "          f\"  Multi-label PRF: Fine(P={r['test_multilabel_precision_fine']:.4f}, R={r['test_multilabel_recall_fine']:.4f}, F1={r['test_multilabel_f1_fine']:.4f}) | \"\n",
        "          f\"Coarse(P={r['test_multilabel_precision_coarse']:.4f}, R={r['test_multilabel_recall_coarse']:.4f}, F1={r['test_multilabel_f1_coarse']:.4f})\\n\"\n",
        "          f\"  ECE (BetP): Fine={r['test_ece_fine_betp']:.4f}, Coarse={r['test_ece_coarse_betp']:.4f}\\n\"\n",
        "          f\"  ECE (Softmax): Fine={r['test_ece_fine_softmax']:.4f}, Coarse={r['test_ece_coarse_softmax']:.4f}\\n\"\n",
        "          f\"  Entropy (BetP): Fine={r['test_entropy_fine_betp']:.4f}, Coarse={r['test_entropy_coarse_betp']:.4f}\\n\"\n",
        "          f\"  Entropy (Softmax): Fine={r['test_entropy_fine_softmax']:.4f}, Coarse={r['test_entropy_coarse_softmax']:.4f}\\n\"\n",
        "          f\"  Coverage: Fine={r['test_fine_coverage']:.4f}, Coarse={r['test_coarse_coverage']:.4f}\\n\"\n",
        "    )\n",
        "\n",
        "print(f\"\\nSaved per-config test results to:\\n- {TEST_RESULTS_PKL}\")\n"
      ],
      "metadata": {
        "id": "GNwWCrsYeIUe",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "db2a1ff4-7af4-4849-a2c7-4d4645ad26ec"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Loaded 3 configs from /content/drive/MyDrive/NO_WARMUP/focal/all_training_results.pkl\n",
            "\n",
            "[1/3] Testing learnable_lukasiewicz_trapezoidal (val_loss=0.049613) with fine_th=0.4, coarse_th=0.4: /content/drive/MyDrive/NO_WARMUP/focal/best_model_learnable_lukasiewicz_trapezoidal.pth\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py:627: UserWarning: This DataLoader will create 32 worker processes in total. Our suggested max number of worker in current system is 12, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            " -> Done: FineAcc=0.8568 | CoarseAcc=0.9136 | FineF1=0.8566 | CoarseF1=0.7721\n",
            "\n",
            "[1/3] Testing learnable_lukasiewicz_trapezoidal (val_loss=0.049613) with fine_th=0.4, coarse_th=0.5: /content/drive/MyDrive/NO_WARMUP/focal/best_model_learnable_lukasiewicz_trapezoidal.pth\n",
            " -> Done: FineAcc=0.8568 | CoarseAcc=0.9130 | FineF1=0.8566 | CoarseF1=0.7716\n",
            "\n",
            "[1/3] Testing learnable_lukasiewicz_trapezoidal (val_loss=0.049613) with fine_th=0.4, coarse_th=0.6: /content/drive/MyDrive/NO_WARMUP/focal/best_model_learnable_lukasiewicz_trapezoidal.pth\n",
            " -> Done: FineAcc=0.8568 | CoarseAcc=0.9130 | FineF1=0.8566 | CoarseF1=0.7716\n",
            "\n",
            "[1/3] Testing learnable_lukasiewicz_trapezoidal (val_loss=0.049613) with fine_th=0.5, coarse_th=0.4: /content/drive/MyDrive/NO_WARMUP/focal/best_model_learnable_lukasiewicz_trapezoidal.pth\n",
            " -> Done: FineAcc=0.8568 | CoarseAcc=0.9094 | FineF1=0.8566 | CoarseF1=0.7709\n",
            "\n",
            "[1/3] Testing learnable_lukasiewicz_trapezoidal (val_loss=0.049613) with fine_th=0.5, coarse_th=0.5: /content/drive/MyDrive/NO_WARMUP/focal/best_model_learnable_lukasiewicz_trapezoidal.pth\n",
            " -> Done: FineAcc=0.8568 | CoarseAcc=0.9093 | FineF1=0.8566 | CoarseF1=0.7709\n",
            "\n",
            "[1/3] Testing learnable_lukasiewicz_trapezoidal (val_loss=0.049613) with fine_th=0.5, coarse_th=0.6: /content/drive/MyDrive/NO_WARMUP/focal/best_model_learnable_lukasiewicz_trapezoidal.pth\n",
            " -> Done: FineAcc=0.8568 | CoarseAcc=0.9093 | FineF1=0.8566 | CoarseF1=0.7709\n",
            "\n",
            "[1/3] Testing learnable_lukasiewicz_trapezoidal (val_loss=0.049613) with fine_th=0.6, coarse_th=0.4: /content/drive/MyDrive/NO_WARMUP/focal/best_model_learnable_lukasiewicz_trapezoidal.pth\n",
            " -> Done: FineAcc=0.8568 | CoarseAcc=0.9033 | FineF1=0.8566 | CoarseF1=0.7690\n",
            "\n",
            "[1/3] Testing learnable_lukasiewicz_trapezoidal (val_loss=0.049613) with fine_th=0.6, coarse_th=0.5: /content/drive/MyDrive/NO_WARMUP/focal/best_model_learnable_lukasiewicz_trapezoidal.pth\n",
            " -> Done: FineAcc=0.8568 | CoarseAcc=0.9032 | FineF1=0.8566 | CoarseF1=0.7689\n",
            "\n",
            "[1/3] Testing learnable_lukasiewicz_trapezoidal (val_loss=0.049613) with fine_th=0.6, coarse_th=0.6: /content/drive/MyDrive/NO_WARMUP/focal/best_model_learnable_lukasiewicz_trapezoidal.pth\n",
            " -> Done: FineAcc=0.8568 | CoarseAcc=0.9032 | FineF1=0.8566 | CoarseF1=0.7689\n",
            "\n",
            "[2/3] Testing learnable_lukasiewicz_gaussian (val_loss=0.049939) with fine_th=0.4, coarse_th=0.4: /content/drive/MyDrive/NO_WARMUP/focal/best_model_learnable_lukasiewicz_gaussian.pth\n",
            " -> Done: FineAcc=0.8571 | CoarseAcc=0.9100 | FineF1=0.8571 | CoarseF1=0.7686\n",
            "\n",
            "[2/3] Testing learnable_lukasiewicz_gaussian (val_loss=0.049939) with fine_th=0.4, coarse_th=0.5: /content/drive/MyDrive/NO_WARMUP/focal/best_model_learnable_lukasiewicz_gaussian.pth\n",
            " -> Done: FineAcc=0.8571 | CoarseAcc=0.9108 | FineF1=0.8571 | CoarseF1=0.7692\n",
            "\n",
            "[2/3] Testing learnable_lukasiewicz_gaussian (val_loss=0.049939) with fine_th=0.4, coarse_th=0.6: /content/drive/MyDrive/NO_WARMUP/focal/best_model_learnable_lukasiewicz_gaussian.pth\n",
            " -> Done: FineAcc=0.8571 | CoarseAcc=0.9108 | FineF1=0.8571 | CoarseF1=0.7692\n",
            "\n",
            "[2/3] Testing learnable_lukasiewicz_gaussian (val_loss=0.049939) with fine_th=0.5, coarse_th=0.4: /content/drive/MyDrive/NO_WARMUP/focal/best_model_learnable_lukasiewicz_gaussian.pth\n",
            " -> Done: FineAcc=0.8571 | CoarseAcc=0.9062 | FineF1=0.8571 | CoarseF1=0.7675\n",
            "\n",
            "[2/3] Testing learnable_lukasiewicz_gaussian (val_loss=0.049939) with fine_th=0.5, coarse_th=0.5: /content/drive/MyDrive/NO_WARMUP/focal/best_model_learnable_lukasiewicz_gaussian.pth\n",
            " -> Done: FineAcc=0.8571 | CoarseAcc=0.9068 | FineF1=0.8571 | CoarseF1=0.7680\n",
            "\n",
            "[2/3] Testing learnable_lukasiewicz_gaussian (val_loss=0.049939) with fine_th=0.5, coarse_th=0.6: /content/drive/MyDrive/NO_WARMUP/focal/best_model_learnable_lukasiewicz_gaussian.pth\n",
            " -> Done: FineAcc=0.8571 | CoarseAcc=0.9068 | FineF1=0.8571 | CoarseF1=0.7680\n",
            "\n",
            "[2/3] Testing learnable_lukasiewicz_gaussian (val_loss=0.049939) with fine_th=0.6, coarse_th=0.4: /content/drive/MyDrive/NO_WARMUP/focal/best_model_learnable_lukasiewicz_gaussian.pth\n",
            " -> Done: FineAcc=0.8571 | CoarseAcc=0.9005 | FineF1=0.8571 | CoarseF1=0.7657\n",
            "\n",
            "[2/3] Testing learnable_lukasiewicz_gaussian (val_loss=0.049939) with fine_th=0.6, coarse_th=0.5: /content/drive/MyDrive/NO_WARMUP/focal/best_model_learnable_lukasiewicz_gaussian.pth\n",
            " -> Done: FineAcc=0.8571 | CoarseAcc=0.9010 | FineF1=0.8571 | CoarseF1=0.7661\n",
            "\n",
            "[2/3] Testing learnable_lukasiewicz_gaussian (val_loss=0.049939) with fine_th=0.6, coarse_th=0.6: /content/drive/MyDrive/NO_WARMUP/focal/best_model_learnable_lukasiewicz_gaussian.pth\n",
            " -> Done: FineAcc=0.8571 | CoarseAcc=0.9010 | FineF1=0.8571 | CoarseF1=0.7661\n",
            "\n",
            "[3/3] Testing learnable_lukasiewicz_triangular (val_loss=0.052090) with fine_th=0.4, coarse_th=0.4: /content/drive/MyDrive/NO_WARMUP/focal/best_model_learnable_lukasiewicz_triangular.pth\n",
            " -> Done: FineAcc=0.8542 | CoarseAcc=0.9147 | FineF1=0.8542 | CoarseF1=0.7713\n",
            "\n",
            "[3/3] Testing learnable_lukasiewicz_triangular (val_loss=0.052090) with fine_th=0.4, coarse_th=0.5: /content/drive/MyDrive/NO_WARMUP/focal/best_model_learnable_lukasiewicz_triangular.pth\n",
            " -> Done: FineAcc=0.8542 | CoarseAcc=0.9147 | FineF1=0.8542 | CoarseF1=0.7713\n",
            "\n",
            "[3/3] Testing learnable_lukasiewicz_triangular (val_loss=0.052090) with fine_th=0.4, coarse_th=0.6: /content/drive/MyDrive/NO_WARMUP/focal/best_model_learnable_lukasiewicz_triangular.pth\n",
            " -> Done: FineAcc=0.8542 | CoarseAcc=0.9147 | FineF1=0.8542 | CoarseF1=0.7713\n",
            "\n",
            "[3/3] Testing learnable_lukasiewicz_triangular (val_loss=0.052090) with fine_th=0.5, coarse_th=0.4: /content/drive/MyDrive/NO_WARMUP/focal/best_model_learnable_lukasiewicz_triangular.pth\n",
            " -> Done: FineAcc=0.8542 | CoarseAcc=0.9081 | FineF1=0.8542 | CoarseF1=0.7692\n",
            "\n",
            "[3/3] Testing learnable_lukasiewicz_triangular (val_loss=0.052090) with fine_th=0.5, coarse_th=0.5: /content/drive/MyDrive/NO_WARMUP/focal/best_model_learnable_lukasiewicz_triangular.pth\n",
            " -> Done: FineAcc=0.8542 | CoarseAcc=0.9074 | FineF1=0.8542 | CoarseF1=0.7686\n",
            "\n",
            "[3/3] Testing learnable_lukasiewicz_triangular (val_loss=0.052090) with fine_th=0.5, coarse_th=0.6: /content/drive/MyDrive/NO_WARMUP/focal/best_model_learnable_lukasiewicz_triangular.pth\n",
            " -> Done: FineAcc=0.8542 | CoarseAcc=0.9074 | FineF1=0.8542 | CoarseF1=0.7686\n",
            "\n",
            "[3/3] Testing learnable_lukasiewicz_triangular (val_loss=0.052090) with fine_th=0.6, coarse_th=0.4: /content/drive/MyDrive/NO_WARMUP/focal/best_model_learnable_lukasiewicz_triangular.pth\n",
            " -> Done: FineAcc=0.8542 | CoarseAcc=0.9020 | FineF1=0.8542 | CoarseF1=0.7670\n",
            "\n",
            "[3/3] Testing learnable_lukasiewicz_triangular (val_loss=0.052090) with fine_th=0.6, coarse_th=0.5: /content/drive/MyDrive/NO_WARMUP/focal/best_model_learnable_lukasiewicz_triangular.pth\n",
            " -> Done: FineAcc=0.8542 | CoarseAcc=0.9016 | FineF1=0.8542 | CoarseF1=0.7667\n",
            "\n",
            "[3/3] Testing learnable_lukasiewicz_triangular (val_loss=0.052090) with fine_th=0.6, coarse_th=0.6: /content/drive/MyDrive/NO_WARMUP/focal/best_model_learnable_lukasiewicz_triangular.pth\n",
            " -> Done: FineAcc=0.8542 | CoarseAcc=0.9016 | FineF1=0.8542 | CoarseF1=0.7667\n",
            "\n",
            "Test Summary: \n",
            "\n",
            "learnable_lukasiewicz_trapezoidal   val_loss= 0.04961\n",
            "  Accuracies: Fine=0.8568, Coarse=0.9136\n",
            "  Pignistic Acc: Fine=0.8568, Coarse=0.9249\n",
            "  Logical Consistency=0.9537\n",
            "  Single-label PRF: Fine(P=0.8613, R=0.8568, F1=0.8566) | Coarse(P=0.7835, R=0.7613, F1=0.7721)\n",
            "  Multi-label PRF: Fine(P=0.9202, R=0.7781, F1=0.8381) | Coarse(P=0.9462, R=0.9102, F1=0.9274)\n",
            "  ECE (BetP): Fine=0.0625, Coarse=0.0161\n",
            "  ECE (Softmax): Fine=0.0407, Coarse=0.0738\n",
            "  Entropy (BetP): Fine=0.9138, Coarse=0.2805\n",
            "  Entropy (Softmax): Fine=0.5874, Coarse=0.3376\n",
            "\n",
            "learnable_lukasiewicz_trapezoidal   val_loss= 0.04961\n",
            "  Accuracies: Fine=0.8568, Coarse=0.9130\n",
            "  Pignistic Acc: Fine=0.8568, Coarse=0.9249\n",
            "  Logical Consistency=0.9555\n",
            "  Single-label PRF: Fine(P=0.8613, R=0.8568, F1=0.8566) | Coarse(P=0.7830, R=0.7608, F1=0.7716)\n",
            "  Multi-label PRF: Fine(P=0.9202, R=0.7781, F1=0.8381) | Coarse(P=0.9462, R=0.9102, F1=0.9274)\n",
            "  ECE (BetP): Fine=0.0625, Coarse=0.0161\n",
            "  ECE (Softmax): Fine=0.0407, Coarse=0.0738\n",
            "  Entropy (BetP): Fine=0.9138, Coarse=0.2805\n",
            "  Entropy (Softmax): Fine=0.5874, Coarse=0.3376\n",
            "\n",
            "learnable_lukasiewicz_trapezoidal   val_loss= 0.04961\n",
            "  Accuracies: Fine=0.8568, Coarse=0.9130\n",
            "  Pignistic Acc: Fine=0.8568, Coarse=0.9249\n",
            "  Logical Consistency=0.9555\n",
            "  Single-label PRF: Fine(P=0.8613, R=0.8568, F1=0.8566) | Coarse(P=0.7830, R=0.7608, F1=0.7716)\n",
            "  Multi-label PRF: Fine(P=0.9202, R=0.7781, F1=0.8381) | Coarse(P=0.9462, R=0.9102, F1=0.9274)\n",
            "  ECE (BetP): Fine=0.0625, Coarse=0.0161\n",
            "  ECE (Softmax): Fine=0.0407, Coarse=0.0738\n",
            "  Entropy (BetP): Fine=0.9138, Coarse=0.2805\n",
            "  Entropy (Softmax): Fine=0.5874, Coarse=0.3376\n",
            "\n",
            "learnable_lukasiewicz_trapezoidal   val_loss= 0.04961\n",
            "  Accuracies: Fine=0.8568, Coarse=0.9094\n",
            "  Pignistic Acc: Fine=0.8568, Coarse=0.9249\n",
            "  Logical Consistency=0.9462\n",
            "  Single-label PRF: Fine(P=0.8613, R=0.8568, F1=0.8566) | Coarse(P=0.7850, R=0.7578, F1=0.7709)\n",
            "  Multi-label PRF: Fine(P=0.9202, R=0.7781, F1=0.8381) | Coarse(P=0.9462, R=0.9102, F1=0.9274)\n",
            "  ECE (BetP): Fine=0.0625, Coarse=0.0161\n",
            "  ECE (Softmax): Fine=0.0407, Coarse=0.0738\n",
            "  Entropy (BetP): Fine=0.9138, Coarse=0.2805\n",
            "  Entropy (Softmax): Fine=0.5874, Coarse=0.3376\n",
            "\n",
            "learnable_lukasiewicz_trapezoidal   val_loss= 0.04961\n",
            "  Accuracies: Fine=0.8568, Coarse=0.9093\n",
            "  Pignistic Acc: Fine=0.8568, Coarse=0.9249\n",
            "  Logical Consistency=0.9475\n",
            "  Single-label PRF: Fine(P=0.8613, R=0.8568, F1=0.8566) | Coarse(P=0.7849, R=0.7578, F1=0.7709)\n",
            "  Multi-label PRF: Fine(P=0.9202, R=0.7781, F1=0.8381) | Coarse(P=0.9462, R=0.9102, F1=0.9274)\n",
            "  ECE (BetP): Fine=0.0625, Coarse=0.0161\n",
            "  ECE (Softmax): Fine=0.0407, Coarse=0.0738\n",
            "  Entropy (BetP): Fine=0.9138, Coarse=0.2805\n",
            "  Entropy (Softmax): Fine=0.5874, Coarse=0.3376\n",
            "\n",
            "learnable_lukasiewicz_trapezoidal   val_loss= 0.04961\n",
            "  Accuracies: Fine=0.8568, Coarse=0.9093\n",
            "  Pignistic Acc: Fine=0.8568, Coarse=0.9249\n",
            "  Logical Consistency=0.9475\n",
            "  Single-label PRF: Fine(P=0.8613, R=0.8568, F1=0.8566) | Coarse(P=0.7849, R=0.7578, F1=0.7709)\n",
            "  Multi-label PRF: Fine(P=0.9202, R=0.7781, F1=0.8381) | Coarse(P=0.9462, R=0.9102, F1=0.9274)\n",
            "  ECE (BetP): Fine=0.0625, Coarse=0.0161\n",
            "  ECE (Softmax): Fine=0.0407, Coarse=0.0738\n",
            "  Entropy (BetP): Fine=0.9138, Coarse=0.2805\n",
            "  Entropy (Softmax): Fine=0.5874, Coarse=0.3376\n",
            "\n",
            "learnable_lukasiewicz_trapezoidal   val_loss= 0.04961\n",
            "  Accuracies: Fine=0.8568, Coarse=0.9033\n",
            "  Pignistic Acc: Fine=0.8568, Coarse=0.9249\n",
            "  Logical Consistency=0.9369\n",
            "  Single-label PRF: Fine(P=0.8613, R=0.8568, F1=0.8566) | Coarse(P=0.7867, R=0.7527, F1=0.7690)\n",
            "  Multi-label PRF: Fine(P=0.9202, R=0.7781, F1=0.8381) | Coarse(P=0.9462, R=0.9102, F1=0.9274)\n",
            "  ECE (BetP): Fine=0.0625, Coarse=0.0161\n",
            "  ECE (Softmax): Fine=0.0407, Coarse=0.0738\n",
            "  Entropy (BetP): Fine=0.9138, Coarse=0.2805\n",
            "  Entropy (Softmax): Fine=0.5874, Coarse=0.3376\n",
            "\n",
            "learnable_lukasiewicz_trapezoidal   val_loss= 0.04961\n",
            "  Accuracies: Fine=0.8568, Coarse=0.9032\n",
            "  Pignistic Acc: Fine=0.8568, Coarse=0.9249\n",
            "  Logical Consistency=0.9374\n",
            "  Single-label PRF: Fine(P=0.8613, R=0.8568, F1=0.8566) | Coarse(P=0.7867, R=0.7527, F1=0.7689)\n",
            "  Multi-label PRF: Fine(P=0.9202, R=0.7781, F1=0.8381) | Coarse(P=0.9462, R=0.9102, F1=0.9274)\n",
            "  ECE (BetP): Fine=0.0625, Coarse=0.0161\n",
            "  ECE (Softmax): Fine=0.0407, Coarse=0.0738\n",
            "  Entropy (BetP): Fine=0.9138, Coarse=0.2805\n",
            "  Entropy (Softmax): Fine=0.5874, Coarse=0.3376\n",
            "\n",
            "learnable_lukasiewicz_trapezoidal   val_loss= 0.04961\n",
            "  Accuracies: Fine=0.8568, Coarse=0.9032\n",
            "  Pignistic Acc: Fine=0.8568, Coarse=0.9249\n",
            "  Logical Consistency=0.9374\n",
            "  Single-label PRF: Fine(P=0.8613, R=0.8568, F1=0.8566) | Coarse(P=0.7867, R=0.7527, F1=0.7689)\n",
            "  Multi-label PRF: Fine(P=0.9202, R=0.7781, F1=0.8381) | Coarse(P=0.9462, R=0.9102, F1=0.9274)\n",
            "  ECE (BetP): Fine=0.0625, Coarse=0.0161\n",
            "  ECE (Softmax): Fine=0.0407, Coarse=0.0738\n",
            "  Entropy (BetP): Fine=0.9138, Coarse=0.2805\n",
            "  Entropy (Softmax): Fine=0.5874, Coarse=0.3376\n",
            "\n",
            "learnable_lukasiewicz_gaussian      val_loss= 0.04994\n",
            "  Accuracies: Fine=0.8571, Coarse=0.9100\n",
            "  Pignistic Acc: Fine=0.8571, Coarse=0.9241\n",
            "  Logical Consistency=0.9502\n",
            "  Single-label PRF: Fine(P=0.8617, R=0.8571, F1=0.8571) | Coarse(P=0.7797, R=0.7583, F1=0.7686)\n",
            "  Multi-label PRF: Fine(P=0.9191, R=0.7781, F1=0.8383) | Coarse(P=0.9444, R=0.9095, F1=0.9260)\n",
            "  ECE (BetP): Fine=0.0647, Coarse=0.0170\n",
            "  ECE (Softmax): Fine=0.0428, Coarse=0.0649\n",
            "  Entropy (BetP): Fine=0.9181, Coarse=0.2818\n",
            "  Entropy (Softmax): Fine=0.5912, Coarse=0.3439\n",
            "\n",
            "learnable_lukasiewicz_gaussian      val_loss= 0.04994\n",
            "  Accuracies: Fine=0.8571, Coarse=0.9108\n",
            "  Pignistic Acc: Fine=0.8571, Coarse=0.9241\n",
            "  Logical Consistency=0.9527\n",
            "  Single-label PRF: Fine(P=0.8617, R=0.8571, F1=0.8571) | Coarse(P=0.7804, R=0.7590, F1=0.7692)\n",
            "  Multi-label PRF: Fine(P=0.9191, R=0.7781, F1=0.8383) | Coarse(P=0.9444, R=0.9095, F1=0.9260)\n",
            "  ECE (BetP): Fine=0.0647, Coarse=0.0170\n",
            "  ECE (Softmax): Fine=0.0428, Coarse=0.0649\n",
            "  Entropy (BetP): Fine=0.9181, Coarse=0.2818\n",
            "  Entropy (Softmax): Fine=0.5912, Coarse=0.3439\n",
            "\n",
            "learnable_lukasiewicz_gaussian      val_loss= 0.04994\n",
            "  Accuracies: Fine=0.8571, Coarse=0.9108\n",
            "  Pignistic Acc: Fine=0.8571, Coarse=0.9241\n",
            "  Logical Consistency=0.9527\n",
            "  Single-label PRF: Fine(P=0.8617, R=0.8571, F1=0.8571) | Coarse(P=0.7804, R=0.7590, F1=0.7692)\n",
            "  Multi-label PRF: Fine(P=0.9191, R=0.7781, F1=0.8383) | Coarse(P=0.9444, R=0.9095, F1=0.9260)\n",
            "  ECE (BetP): Fine=0.0647, Coarse=0.0170\n",
            "  ECE (Softmax): Fine=0.0428, Coarse=0.0649\n",
            "  Entropy (BetP): Fine=0.9181, Coarse=0.2818\n",
            "  Entropy (Softmax): Fine=0.5912, Coarse=0.3439\n",
            "\n",
            "learnable_lukasiewicz_gaussian      val_loss= 0.04994\n",
            "  Accuracies: Fine=0.8571, Coarse=0.9062\n",
            "  Pignistic Acc: Fine=0.8571, Coarse=0.9241\n",
            "  Logical Consistency=0.9429\n",
            "  Single-label PRF: Fine(P=0.8617, R=0.8571, F1=0.8571) | Coarse(P=0.7810, R=0.7552, F1=0.7675)\n",
            "  Multi-label PRF: Fine(P=0.9191, R=0.7781, F1=0.8383) | Coarse(P=0.9444, R=0.9095, F1=0.9260)\n",
            "  ECE (BetP): Fine=0.0647, Coarse=0.0170\n",
            "  ECE (Softmax): Fine=0.0428, Coarse=0.0649\n",
            "  Entropy (BetP): Fine=0.9181, Coarse=0.2818\n",
            "  Entropy (Softmax): Fine=0.5912, Coarse=0.3439\n",
            "\n",
            "learnable_lukasiewicz_gaussian      val_loss= 0.04994\n",
            "  Accuracies: Fine=0.8571, Coarse=0.9068\n",
            "  Pignistic Acc: Fine=0.8571, Coarse=0.9241\n",
            "  Logical Consistency=0.9448\n",
            "  Single-label PRF: Fine(P=0.8617, R=0.8571, F1=0.8571) | Coarse(P=0.7815, R=0.7557, F1=0.7680)\n",
            "  Multi-label PRF: Fine(P=0.9191, R=0.7781, F1=0.8383) | Coarse(P=0.9444, R=0.9095, F1=0.9260)\n",
            "  ECE (BetP): Fine=0.0647, Coarse=0.0170\n",
            "  ECE (Softmax): Fine=0.0428, Coarse=0.0649\n",
            "  Entropy (BetP): Fine=0.9181, Coarse=0.2818\n",
            "  Entropy (Softmax): Fine=0.5912, Coarse=0.3439\n",
            "\n",
            "learnable_lukasiewicz_gaussian      val_loss= 0.04994\n",
            "  Accuracies: Fine=0.8571, Coarse=0.9068\n",
            "  Pignistic Acc: Fine=0.8571, Coarse=0.9241\n",
            "  Logical Consistency=0.9448\n",
            "  Single-label PRF: Fine(P=0.8617, R=0.8571, F1=0.8571) | Coarse(P=0.7815, R=0.7557, F1=0.7680)\n",
            "  Multi-label PRF: Fine(P=0.9191, R=0.7781, F1=0.8383) | Coarse(P=0.9444, R=0.9095, F1=0.9260)\n",
            "  ECE (BetP): Fine=0.0647, Coarse=0.0170\n",
            "  ECE (Softmax): Fine=0.0428, Coarse=0.0649\n",
            "  Entropy (BetP): Fine=0.9181, Coarse=0.2818\n",
            "  Entropy (Softmax): Fine=0.5912, Coarse=0.3439\n",
            "\n",
            "learnable_lukasiewicz_gaussian      val_loss= 0.04994\n",
            "  Accuracies: Fine=0.8571, Coarse=0.9005\n",
            "  Pignistic Acc: Fine=0.8571, Coarse=0.9241\n",
            "  Logical Consistency=0.9347\n",
            "  Single-label PRF: Fine(P=0.8617, R=0.8571, F1=0.8571) | Coarse(P=0.7824, R=0.7504, F1=0.7657)\n",
            "  Multi-label PRF: Fine(P=0.9191, R=0.7781, F1=0.8383) | Coarse(P=0.9444, R=0.9095, F1=0.9260)\n",
            "  ECE (BetP): Fine=0.0647, Coarse=0.0170\n",
            "  ECE (Softmax): Fine=0.0428, Coarse=0.0649\n",
            "  Entropy (BetP): Fine=0.9181, Coarse=0.2818\n",
            "  Entropy (Softmax): Fine=0.5912, Coarse=0.3439\n",
            "\n",
            "learnable_lukasiewicz_gaussian      val_loss= 0.04994\n",
            "  Accuracies: Fine=0.8571, Coarse=0.9010\n",
            "  Pignistic Acc: Fine=0.8571, Coarse=0.9241\n",
            "  Logical Consistency=0.9357\n",
            "  Single-label PRF: Fine(P=0.8617, R=0.8571, F1=0.8571) | Coarse(P=0.7828, R=0.7508, F1=0.7661)\n",
            "  Multi-label PRF: Fine(P=0.9191, R=0.7781, F1=0.8383) | Coarse(P=0.9444, R=0.9095, F1=0.9260)\n",
            "  ECE (BetP): Fine=0.0647, Coarse=0.0170\n",
            "  ECE (Softmax): Fine=0.0428, Coarse=0.0649\n",
            "  Entropy (BetP): Fine=0.9181, Coarse=0.2818\n",
            "  Entropy (Softmax): Fine=0.5912, Coarse=0.3439\n",
            "\n",
            "learnable_lukasiewicz_gaussian      val_loss= 0.04994\n",
            "  Accuracies: Fine=0.8571, Coarse=0.9010\n",
            "  Pignistic Acc: Fine=0.8571, Coarse=0.9241\n",
            "  Logical Consistency=0.9357\n",
            "  Single-label PRF: Fine(P=0.8617, R=0.8571, F1=0.8571) | Coarse(P=0.7828, R=0.7508, F1=0.7661)\n",
            "  Multi-label PRF: Fine(P=0.9191, R=0.7781, F1=0.8383) | Coarse(P=0.9444, R=0.9095, F1=0.9260)\n",
            "  ECE (BetP): Fine=0.0647, Coarse=0.0170\n",
            "  ECE (Softmax): Fine=0.0428, Coarse=0.0649\n",
            "  Entropy (BetP): Fine=0.9181, Coarse=0.2818\n",
            "  Entropy (Softmax): Fine=0.5912, Coarse=0.3439\n",
            "\n",
            "learnable_lukasiewicz_triangular    val_loss= 0.05209\n",
            "  Accuracies: Fine=0.8542, Coarse=0.9147\n",
            "  Pignistic Acc: Fine=0.8542, Coarse=0.9253\n",
            "  Logical Consistency=0.9566\n",
            "  Single-label PRF: Fine(P=0.8586, R=0.8542, F1=0.8542) | Coarse(P=0.7810, R=0.7622, F1=0.7713)\n",
            "  Multi-label PRF: Fine(P=0.9173, R=0.7758, F1=0.8365) | Coarse(P=0.9460, R=0.9097, F1=0.9270)\n",
            "  ECE (BetP): Fine=0.0642, Coarse=0.0154\n",
            "  ECE (Softmax): Fine=0.0405, Coarse=0.0652\n",
            "  Entropy (BetP): Fine=0.9317, Coarse=0.2829\n",
            "  Entropy (Softmax): Fine=0.5962, Coarse=0.3366\n",
            "\n",
            "learnable_lukasiewicz_triangular    val_loss= 0.05209\n",
            "  Accuracies: Fine=0.8542, Coarse=0.9147\n",
            "  Pignistic Acc: Fine=0.8542, Coarse=0.9253\n",
            "  Logical Consistency=0.9593\n",
            "  Single-label PRF: Fine(P=0.8586, R=0.8542, F1=0.8542) | Coarse(P=0.7810, R=0.7622, F1=0.7713)\n",
            "  Multi-label PRF: Fine(P=0.9173, R=0.7758, F1=0.8365) | Coarse(P=0.9460, R=0.9097, F1=0.9270)\n",
            "  ECE (BetP): Fine=0.0642, Coarse=0.0154\n",
            "  ECE (Softmax): Fine=0.0405, Coarse=0.0652\n",
            "  Entropy (BetP): Fine=0.9317, Coarse=0.2829\n",
            "  Entropy (Softmax): Fine=0.5962, Coarse=0.3366\n",
            "\n",
            "learnable_lukasiewicz_triangular    val_loss= 0.05209\n",
            "  Accuracies: Fine=0.8542, Coarse=0.9147\n",
            "  Pignistic Acc: Fine=0.8542, Coarse=0.9253\n",
            "  Logical Consistency=0.9593\n",
            "  Single-label PRF: Fine(P=0.8586, R=0.8542, F1=0.8542) | Coarse(P=0.7810, R=0.7622, F1=0.7713)\n",
            "  Multi-label PRF: Fine(P=0.9173, R=0.7758, F1=0.8365) | Coarse(P=0.9460, R=0.9097, F1=0.9270)\n",
            "  ECE (BetP): Fine=0.0642, Coarse=0.0154\n",
            "  ECE (Softmax): Fine=0.0405, Coarse=0.0652\n",
            "  Entropy (BetP): Fine=0.9317, Coarse=0.2829\n",
            "  Entropy (Softmax): Fine=0.5962, Coarse=0.3366\n",
            "\n",
            "learnable_lukasiewicz_triangular    val_loss= 0.05209\n",
            "  Accuracies: Fine=0.8542, Coarse=0.9081\n",
            "  Pignistic Acc: Fine=0.8542, Coarse=0.9253\n",
            "  Logical Consistency=0.9460\n",
            "  Single-label PRF: Fine(P=0.8586, R=0.8542, F1=0.8542) | Coarse(P=0.7827, R=0.7567, F1=0.7692)\n",
            "  Multi-label PRF: Fine(P=0.9173, R=0.7758, F1=0.8365) | Coarse(P=0.9460, R=0.9097, F1=0.9270)\n",
            "  ECE (BetP): Fine=0.0642, Coarse=0.0154\n",
            "  ECE (Softmax): Fine=0.0405, Coarse=0.0652\n",
            "  Entropy (BetP): Fine=0.9317, Coarse=0.2829\n",
            "  Entropy (Softmax): Fine=0.5962, Coarse=0.3366\n",
            "\n",
            "learnable_lukasiewicz_triangular    val_loss= 0.05209\n",
            "  Accuracies: Fine=0.8542, Coarse=0.9074\n",
            "  Pignistic Acc: Fine=0.8542, Coarse=0.9253\n",
            "  Logical Consistency=0.9477\n",
            "  Single-label PRF: Fine(P=0.8586, R=0.8542, F1=0.8542) | Coarse(P=0.7821, R=0.7562, F1=0.7686)\n",
            "  Multi-label PRF: Fine(P=0.9173, R=0.7758, F1=0.8365) | Coarse(P=0.9460, R=0.9097, F1=0.9270)\n",
            "  ECE (BetP): Fine=0.0642, Coarse=0.0154\n",
            "  ECE (Softmax): Fine=0.0405, Coarse=0.0652\n",
            "  Entropy (BetP): Fine=0.9317, Coarse=0.2829\n",
            "  Entropy (Softmax): Fine=0.5962, Coarse=0.3366\n",
            "\n",
            "learnable_lukasiewicz_triangular    val_loss= 0.05209\n",
            "  Accuracies: Fine=0.8542, Coarse=0.9074\n",
            "  Pignistic Acc: Fine=0.8542, Coarse=0.9253\n",
            "  Logical Consistency=0.9477\n",
            "  Single-label PRF: Fine(P=0.8586, R=0.8542, F1=0.8542) | Coarse(P=0.7821, R=0.7562, F1=0.7686)\n",
            "  Multi-label PRF: Fine(P=0.9173, R=0.7758, F1=0.8365) | Coarse(P=0.9460, R=0.9097, F1=0.9270)\n",
            "  ECE (BetP): Fine=0.0642, Coarse=0.0154\n",
            "  ECE (Softmax): Fine=0.0405, Coarse=0.0652\n",
            "  Entropy (BetP): Fine=0.9317, Coarse=0.2829\n",
            "  Entropy (Softmax): Fine=0.5962, Coarse=0.3366\n",
            "\n",
            "learnable_lukasiewicz_triangular    val_loss= 0.05209\n",
            "  Accuracies: Fine=0.8542, Coarse=0.9020\n",
            "  Pignistic Acc: Fine=0.8542, Coarse=0.9253\n",
            "  Logical Consistency=0.9373\n",
            "  Single-label PRF: Fine(P=0.8586, R=0.8542, F1=0.8542) | Coarse(P=0.7841, R=0.7517, F1=0.7670)\n",
            "  Multi-label PRF: Fine(P=0.9173, R=0.7758, F1=0.8365) | Coarse(P=0.9460, R=0.9097, F1=0.9270)\n",
            "  ECE (BetP): Fine=0.0642, Coarse=0.0154\n",
            "  ECE (Softmax): Fine=0.0405, Coarse=0.0652\n",
            "  Entropy (BetP): Fine=0.9317, Coarse=0.2829\n",
            "  Entropy (Softmax): Fine=0.5962, Coarse=0.3366\n",
            "\n",
            "learnable_lukasiewicz_triangular    val_loss= 0.05209\n",
            "  Accuracies: Fine=0.8542, Coarse=0.9016\n",
            "  Pignistic Acc: Fine=0.8542, Coarse=0.9253\n",
            "  Logical Consistency=0.9381\n",
            "  Single-label PRF: Fine(P=0.8586, R=0.8542, F1=0.8542) | Coarse(P=0.7838, R=0.7513, F1=0.7667)\n",
            "  Multi-label PRF: Fine(P=0.9173, R=0.7758, F1=0.8365) | Coarse(P=0.9460, R=0.9097, F1=0.9270)\n",
            "  ECE (BetP): Fine=0.0642, Coarse=0.0154\n",
            "  ECE (Softmax): Fine=0.0405, Coarse=0.0652\n",
            "  Entropy (BetP): Fine=0.9317, Coarse=0.2829\n",
            "  Entropy (Softmax): Fine=0.5962, Coarse=0.3366\n",
            "\n",
            "learnable_lukasiewicz_triangular    val_loss= 0.05209\n",
            "  Accuracies: Fine=0.8542, Coarse=0.9016\n",
            "  Pignistic Acc: Fine=0.8542, Coarse=0.9253\n",
            "  Logical Consistency=0.9381\n",
            "  Single-label PRF: Fine(P=0.8586, R=0.8542, F1=0.8542) | Coarse(P=0.7838, R=0.7513, F1=0.7667)\n",
            "  Multi-label PRF: Fine(P=0.9173, R=0.7758, F1=0.8365) | Coarse(P=0.9460, R=0.9097, F1=0.9270)\n",
            "  ECE (BetP): Fine=0.0642, Coarse=0.0154\n",
            "  ECE (Softmax): Fine=0.0405, Coarse=0.0652\n",
            "  Entropy (BetP): Fine=0.9317, Coarse=0.2829\n",
            "  Entropy (Softmax): Fine=0.5962, Coarse=0.3366\n",
            "\n",
            "\n",
            "Saved per-config test results to:\n",
            "- /content/drive/MyDrive/NO_WARMUP/test_results_all.pkl\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import csv\n",
        "\n",
        "# Export every per-config test-result dict as one CSV row.\n",
        "# NOTE(review): the filename says 'cifar_warmup_godel', but SAVE_DIR in the logged run is\n",
        "# NO_WARMUP and the configs logged above include Lukasiewicz variants -- confirm the name\n",
        "# before relying on it downstream.\n",
        "TEST_RESULTS_CSV = os.path.join(SAVE_DIR, 'focal', 'test_results_cifar_warmup_godel_all.csv')\n",
        "\n",
        "if all_test_results:\n",
        "    # Header = union of keys across ALL rows (first row's key order first), so a config\n",
        "    # that recorded an extra metric does not make DictWriter raise ValueError.\n",
        "    fieldnames = list(dict.fromkeys(key for row in all_test_results for key in row))\n",
        "\n",
        "    # The 'focal' subdirectory is not created anywhere in this cell; make sure it exists.\n",
        "    os.makedirs(os.path.dirname(TEST_RESULTS_CSV), exist_ok=True)\n",
        "\n",
        "    with open(TEST_RESULTS_CSV, mode='w', newline='') as f:\n",
        "        # restval='' leaves a blank cell for metrics a given config did not record.\n",
        "        writer = csv.DictWriter(f, fieldnames=fieldnames, restval='')\n",
        "        writer.writeheader()\n",
        "        writer.writerows(all_test_results)\n",
        "\n",
        "    print(f\"Saved all test results to CSV:\\n- {TEST_RESULTS_CSV}\")\n",
        "else:\n",
        "    print(\"No results to save.\")\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "7aq7o_s5UiYP",
        "outputId": "509f81c5-1faf-4970-8b0c-f6e9b02c3949"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Saved all test results to CSV:\n",
            "- /content/drive/MyDrive/NO_WARMUP/focal/test_results_cifar_warmup_godel_all.csv\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## BASE: softmax multi-task baseline (fine + coarse heads)"
      ],
      "metadata": {
        "id": "FAv3bsxgtYcB"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Train the baseline two-head (fine/coarse) model with summed cross-entropy,\n",
        "# mixed precision, ReduceLROnPlateau and early stopping on the validation loss.\n",
        "\n",
        "# Gradient scaler for mixed-precision (AMP) training\n",
        "scaler = GradScaler()\n",
        "\n",
        "# Results / checkpoint paths.\n",
        "# NOTE: path components passed to os.path.join must be relative -- an absolute\n",
        "# component such as \"/best_model.pth\" discards BASE_DIR and writes to the filesystem root.\n",
        "BASE_DIR = SAVE_DIR + \"/base\"\n",
        "os.makedirs(BASE_DIR, exist_ok=True)\n",
        "RESULTS_FILE = os.path.join(BASE_DIR, \"base_all_results.pkl\")\n",
        "BASE_RESULTS_PATH = os.path.join(BASE_DIR, \"base_results.pkl\")\n",
        "# Same location the evaluation cell reloads from (BASE_DIR + '/best_model.pth').\n",
        "model_path = os.path.join(BASE_DIR, \"best_model.pth\")\n",
        "\n",
        "# Load previous results if available (only unpickle files this notebook produced)\n",
        "if os.path.exists(RESULTS_FILE):\n",
        "    with open(RESULTS_FILE, \"rb\") as f:\n",
        "        results = pickle.load(f)\n",
        "else:\n",
        "    results = {}\n",
        "\n",
        "# Base model\n",
        "model_base = SwinMultiTask(\n",
        "    num_fine_labels=num_fine,\n",
        "    num_coarse_labels=num_coarse\n",
        ").to(device)\n",
        "\n",
        "# Per-epoch history: one list entry appended per completed epoch\n",
        "base_results = {\n",
        "    \"train_loss\": [], \"val_loss\": [],\n",
        "    \"val_accuracy_fine\": [],\n",
        "    \"val_accuracy_coarse\": [],\n",
        "    \"val_ece_fine\": [],\n",
        "    \"val_ece_coarse\": [],\n",
        "    \"val_entropy_fine\": [],\n",
        "    \"val_entropy_coarse\": [],\n",
        "    \"val_consistency\": [],\n",
        "    \"val_precision_fine\": [], \"val_recall_fine\": [], \"val_f1_fine\": [],\n",
        "    \"val_precision_coarse\": [], \"val_recall_coarse\": [], \"val_f1_coarse\": [],\n",
        "}\n",
        "\n",
        "# Define loss functions\n",
        "criterion_fine   = nn.CrossEntropyLoss()\n",
        "criterion_coarse = nn.CrossEntropyLoss()\n",
        "\n",
        "# Define optimizer\n",
        "optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model_base.parameters()), lr=2e-4)  # Only train unfrozen layers\n",
        "scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)\n",
        "\n",
        "# Parameters for early stopping\n",
        "num_epochs = 300\n",
        "early_stopping_patience = 5\n",
        "best_val_loss = float('inf')\n",
        "early_stopping_counter = 0\n",
        "\n",
        "\n",
        "# Training and validation loop\n",
        "for epoch in range(num_epochs):\n",
        "\n",
        "    # Training\n",
        "    model_base.train()  # Set model to training mode\n",
        "    total_train_loss = 0.0\n",
        "\n",
        "    for inputs, fine_labels, coarse_labels in train_loader_int:\n",
        "        inputs = inputs.to(device, non_blocking=True)\n",
        "        fine_labels, coarse_labels = fine_labels.to(device), coarse_labels.to(device)\n",
        "\n",
        "        # Zero the gradients\n",
        "        optimizer.zero_grad()\n",
        "\n",
        "        # Forward + loss under autocast\n",
        "        with autocast():\n",
        "            fine_logits, coarse_logits = model_base(inputs)\n",
        "            fine_loss = criterion_fine(fine_logits, fine_labels)\n",
        "            coarse_loss = criterion_coarse(coarse_logits, coarse_labels)\n",
        "            loss = fine_loss + coarse_loss\n",
        "\n",
        "        # Backpropagation with scaling\n",
        "        scaler.scale(loss).backward()\n",
        "        scaler.step(optimizer)\n",
        "        scaler.update()\n",
        "\n",
        "        # Accumulate metrics\n",
        "        total_train_loss += loss.item()\n",
        "\n",
        "    # Aggregate train (mean of per-batch losses)\n",
        "    avg_train_loss = total_train_loss / len(train_loader_int)\n",
        "\n",
        "    # Save training metrics\n",
        "    base_results[\"train_loss\"].append(avg_train_loss)\n",
        "\n",
        "    # Validation\n",
        "    model_base.eval()  # Set model to evaluation mode\n",
        "    total_val_loss = 0.0\n",
        "    correct_fine_val = 0\n",
        "    correct_coarse_val = 0\n",
        "    total_samples_val = 0\n",
        "\n",
        "    # Initialize storage for additional validation metrics\n",
        "    all_fine_probs_val = []\n",
        "    all_coarse_probs_val = []\n",
        "    all_fine_preds_val = []\n",
        "    all_coarse_preds_val = []\n",
        "    all_fine_labels_val = []\n",
        "    all_coarse_labels_val = []\n",
        "\n",
        "    with torch.no_grad():  # Disable gradient calculation for validation\n",
        "        for inputs, fine_labels, coarse_labels in val_loader_int:\n",
        "            inputs, fine_labels, coarse_labels = inputs.to(device), fine_labels.to(device), coarse_labels.to(device)\n",
        "\n",
        "            with autocast():\n",
        "                fine_logits, coarse_logits = model_base(inputs)\n",
        "                fine_loss = criterion_fine(fine_logits, fine_labels)\n",
        "                coarse_loss = criterion_coarse(coarse_logits, coarse_labels)\n",
        "                loss = fine_loss + coarse_loss\n",
        "\n",
        "            # Accumulate metrics\n",
        "            total_val_loss += loss.item()\n",
        "            _, fine_preds = torch.max(fine_logits, 1)\n",
        "            _, coarse_preds = torch.max(coarse_logits, 1)\n",
        "            correct_fine_val += (fine_preds == fine_labels).sum().item()\n",
        "            correct_coarse_val += (coarse_preds == coarse_labels).sum().item()\n",
        "            total_samples_val += fine_labels.size(0)\n",
        "\n",
        "            # Store data for additional metrics. Logits may be float16 after autocast and this\n",
        "            # softmax runs outside the autocast context, so upcast first for stable ECE/entropy.\n",
        "            all_fine_probs_val.append(torch.softmax(fine_logits.float(), dim=1).cpu())\n",
        "            all_coarse_probs_val.append(torch.softmax(coarse_logits.float(), dim=1).cpu())\n",
        "            all_fine_preds_val.append(fine_preds.cpu())\n",
        "            all_coarse_preds_val.append(coarse_preds.cpu())\n",
        "            all_fine_labels_val.append(fine_labels.cpu())\n",
        "            all_coarse_labels_val.append(coarse_labels.cpu())\n",
        "\n",
        "    avg_val_loss        = total_val_loss / len(val_loader_int)\n",
        "    fine_accuracy_val   = correct_fine_val / total_samples_val\n",
        "    coarse_accuracy_val = correct_coarse_val / total_samples_val\n",
        "\n",
        "    all_fine_probs_val   = torch.cat(all_fine_probs_val)\n",
        "    all_coarse_probs_val = torch.cat(all_coarse_probs_val)\n",
        "    all_fine_preds_val   = torch.cat(all_fine_preds_val)\n",
        "    all_coarse_preds_val = torch.cat(all_coarse_preds_val)\n",
        "    all_fine_labels_val  = torch.cat(all_fine_labels_val)\n",
        "    all_coarse_labels_val= torch.cat(all_coarse_labels_val)\n",
        "\n",
        "    fine_ece_val   = compute_ece_pytorch(all_fine_probs_val, all_fine_labels_val)\n",
        "    coarse_ece_val = compute_ece_pytorch(all_coarse_probs_val, all_coarse_labels_val)\n",
        "\n",
        "    fine_entropy_val   = compute_entropy(all_fine_probs_val)\n",
        "    coarse_entropy_val = compute_entropy(all_coarse_probs_val)\n",
        "\n",
        "    logical_consistency_val = calculate_logical_consistency_base(all_fine_preds_val, all_coarse_preds_val) / total_samples_val\n",
        "\n",
        "    P_f_m_va, R_f_m_va, F1_f_m_va = compute_singlelabel_metrics(all_fine_labels_val.numpy(), all_fine_preds_val.numpy())\n",
        "    P_c_m_va, R_c_m_va, F1_c_m_va = compute_singlelabel_metrics(all_coarse_labels_val.numpy(), all_coarse_preds_val.numpy())\n",
        "\n",
        "    # Logging the results\n",
        "    print(f\"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}\")\n",
        "    print(f\"Val Loss: {avg_val_loss:.4f} | Acc(f) {fine_accuracy_val:.4f} Acc(c) {coarse_accuracy_val:.4f}\")\n",
        "    print(f\"Val Logical Consistency: {logical_consistency_val:.4f}\")\n",
        "    print(f\"Val ECE:     Fine {fine_ece_val:.4f} | Coarse {coarse_ece_val:.4f}\")\n",
        "    print(f\"Val Entropy: Fine {fine_entropy_val:.3f} | Coarse {coarse_entropy_val:.3f}\")\n",
        "    print(f\"Val   PR/F1 (fine)   -> Macro:  P {P_f_m_va:.3f} R {R_f_m_va:.3f} F1 {F1_f_m_va:.3f} \")\n",
        "    print(f\"Val   PR/F1 (coarse) -> Macro:  P {P_c_m_va:.3f} R {R_c_m_va:.3f} F1 {F1_c_m_va:.3f} \")\n",
        "\n",
        "    # Save validation metrics\n",
        "    base_results[\"val_loss\"].append(avg_val_loss)\n",
        "    base_results[\"val_accuracy_fine\"].append(fine_accuracy_val)\n",
        "    base_results[\"val_accuracy_coarse\"].append(coarse_accuracy_val)\n",
        "    base_results[\"val_ece_fine\"].append(fine_ece_val)\n",
        "    base_results[\"val_ece_coarse\"].append(coarse_ece_val)\n",
        "    base_results[\"val_entropy_fine\"].append(fine_entropy_val)\n",
        "    base_results[\"val_entropy_coarse\"].append(coarse_entropy_val)\n",
        "    base_results[\"val_consistency\"].append(logical_consistency_val)\n",
        "    base_results[\"val_precision_fine\"].append(P_f_m_va)\n",
        "    base_results[\"val_recall_fine\"].append(R_f_m_va)\n",
        "    base_results[\"val_f1_fine\"].append(F1_f_m_va)\n",
        "    base_results[\"val_precision_coarse\"].append(P_c_m_va)\n",
        "    base_results[\"val_recall_coarse\"].append(R_c_m_va)\n",
        "    base_results[\"val_f1_coarse\"].append(F1_c_m_va)\n",
        "\n",
        "    # Early stopping check: checkpoint only when the validation loss improves\n",
        "    if avg_val_loss < best_val_loss:\n",
        "        best_val_loss = avg_val_loss\n",
        "        early_stopping_counter = 0\n",
        "        torch.save(model_base.state_dict(), model_path)\n",
        "\n",
        "        with open(BASE_RESULTS_PATH, \"wb\") as f:\n",
        "            pickle.dump(base_results, f)\n",
        "        print(f\"Best model saved to {model_path}\")\n",
        "    else:\n",
        "        early_stopping_counter += 1\n",
        "        if early_stopping_counter >= early_stopping_patience:\n",
        "            print(\"Early stopping triggered\")\n",
        "            break\n",
        "\n",
        "    # Adjust learning rate\n",
        "    scheduler.step(avg_val_loss)\n",
        "\n",
        "# Persist the full history, including epochs after the last improvement\n",
        "# (the in-loop dump above only runs when the validation loss improves).\n",
        "with open(BASE_RESULTS_PATH, \"wb\") as f:\n",
        "    pickle.dump(base_results, f)\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "0ziVM7intXqU",
        "outputId": "115faa7f-1e4e-4ea3-8ca0-789da8f64ddd"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch [1/200], Train Loss: 3.1819\n",
            "Val Loss: 1.9694 | Acc(f) 0.6158 Acc(c) 0.9332\n",
            "Val Logical Consistency: 0.9452\n",
            "Val ECE:     Fine 0.2598 | Coarse 0.1077\n",
            "Val Entropy: Fine 2.645 | Coarse 0.620\n",
            "Val   PR/F1 (fine)   -> Macro:  P 0.669 R 0.536 F1 0.534 \n",
            "Val   PR/F1 (coarse) -> Macro:  P 0.923 R 0.840 F1 0.870 \n",
            "Best model saved to /content/drive/MyDrive/imagenet-1k/base\n",
            "Epoch [2/200], Train Loss: 1.7950\n",
            "Val Loss: 1.3871 | Acc(f) 0.6980 Acc(c) 0.9482\n",
            "Val Logical Consistency: 0.9640\n",
            "Val ECE:     Fine 0.1717 | Coarse 0.0523\n",
            "Val Entropy: Fine 1.898 | Coarse 0.379\n",
            "Val   PR/F1 (fine)   -> Macro:  P 0.722 R 0.659 F1 0.659 \n",
            "Val   PR/F1 (coarse) -> Macro:  P 0.934 R 0.912 F1 0.921 \n",
            "Best model saved to /content/drive/MyDrive/imagenet-1k/base\n",
            "Epoch [3/200], Train Loss: 1.4520\n",
            "Val Loss: 1.1659 | Acc(f) 0.7325 Acc(c) 0.9513\n",
            "Val Logical Consistency: 0.9703\n",
            "Val ECE:     Fine 0.1261 | Coarse 0.0329\n",
            "Val Entropy: Fine 1.546 | Coarse 0.295\n",
            "Val   PR/F1 (fine)   -> Macro:  P 0.755 R 0.709 F1 0.712 \n",
            "Val   PR/F1 (coarse) -> Macro:  P 0.938 R 0.927 F1 0.931 \n",
            "Best model saved to /content/drive/MyDrive/imagenet-1k/base\n",
            "Epoch [4/200], Train Loss: 1.2800\n",
            "Val Loss: 1.0491 | Acc(f) 0.7496 Acc(c) 0.9553\n",
            "Val Logical Consistency: 0.9684\n",
            "Val ECE:     Fine 0.0962 | Coarse 0.0238\n",
            "Val Entropy: Fine 1.339 | Coarse 0.244\n",
            "Val   PR/F1 (fine)   -> Macro:  P 0.761 R 0.736 F1 0.738 \n",
            "Val   PR/F1 (coarse) -> Macro:  P 0.942 R 0.928 F1 0.934 \n",
            "Best model saved to /content/drive/MyDrive/imagenet-1k/base\n",
            "Epoch [5/200], Train Loss: 1.1864\n",
            "Val Loss: 0.9782 | Acc(f) 0.7588 Acc(c) 0.9581\n",
            "Val Logical Consistency: 0.9716\n",
            "Val ECE:     Fine 0.0712 | Coarse 0.0179\n",
            "Val Entropy: Fine 1.184 | Coarse 0.205\n",
            "Val   PR/F1 (fine)   -> Macro:  P 0.770 R 0.744 F1 0.748 \n",
            "Val   PR/F1 (coarse) -> Macro:  P 0.946 R 0.925 F1 0.935 \n",
            "Best model saved to /content/drive/MyDrive/imagenet-1k/base\n",
            "Epoch [6/200], Train Loss: 1.1118\n",
            "Val Loss: 0.9430 | Acc(f) 0.7665 Acc(c) 0.9575\n",
            "Val Logical Consistency: 0.9710\n",
            "Val ECE:     Fine 0.0625 | Coarse 0.0174\n",
            "Val Entropy: Fine 1.111 | Coarse 0.204\n",
            "Val   PR/F1 (fine)   -> Macro:  P 0.772 R 0.757 F1 0.760 \n",
            "Val   PR/F1 (coarse) -> Macro:  P 0.941 R 0.937 F1 0.938 \n",
            "Best model saved to /content/drive/MyDrive/imagenet-1k/base\n",
            "Epoch [7/200], Train Loss: 1.0584\n",
            "Val Loss: 0.9066 | Acc(f) 0.7696 Acc(c) 0.9596\n",
            "Val Logical Consistency: 0.9756\n",
            "Val ECE:     Fine 0.0471 | Coarse 0.0123\n",
            "Val Entropy: Fine 1.027 | Coarse 0.177\n",
            "Val   PR/F1 (fine)   -> Macro:  P 0.777 R 0.760 F1 0.764 \n",
            "Val   PR/F1 (coarse) -> Macro:  P 0.948 R 0.935 F1 0.941 \n",
            "Best model saved to /content/drive/MyDrive/imagenet-1k/base\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# load model\n",
        "model = SwinMultiTask(num_fine_labels=num_fine, num_coarse_labels=num_coarse)\n",
        "model.load_state_dict(torch.load(BASE_DIR + '/best_model.pth', map_location=device))\n",
        "model.to(device)\n",
        "model.eval()\n",
        "\n",
        "# losses (same as train/val)\n",
        "criterion_fine = nn.NLLLoss()\n",
        "criterion_coarse = nn.NLLLoss()\n",
        "\n",
        "total_test_loss = 0.0\n",
        "correct_fine_test = 0\n",
        "correct_coarse_test = 0\n",
        "correct_logical_test = 0\n",
        "total_samples_test = 0\n",
        "\n",
        "# storage for entropy/ECE\n",
        "all_fine_probs, all_coarse_probs = [], []\n",
        "all_fine_targets, all_coarse_targets = [], []\n",
        "all_fine_preds, all_coarse_preds = [], []\n",
        "\n",
        "with torch.no_grad():\n",
        "    for inputs, fine_labels, coarse_labels in test_loader_int:\n",
        "        inputs = inputs.to(device)\n",
        "        fine_targets = fine_labels.to(device)\n",
        "        coarse_targets = coarse_labels.to(device)\n",
        "        bs = inputs.size(0)\n",
        "\n",
        "        fine_logits, coarse_logits = model(inputs)\n",
        "\n",
        "        fine_log_probs   = F.log_softmax(fine_logits, dim=1)\n",
        "        coarse_log_probs = F.log_softmax(coarse_logits, dim=1)\n",
        "\n",
        "        fine_loss   = criterion_fine(fine_log_probs, fine_targets)\n",
        "        coarse_loss = criterion_coarse(coarse_log_probs, coarse_targets)\n",
        "        loss = fine_loss + coarse_loss\n",
        "        total_test_loss += loss.item()\n",
        "\n",
        "        fine_preds   = fine_logits.argmax(dim=1)\n",
        "        coarse_preds = coarse_logits.argmax(dim=1)\n",
        "\n",
        "        correct_fine_test   += (fine_preds == fine_targets).sum().item()\n",
        "        correct_coarse_test += (coarse_preds == coarse_targets).sum().item()\n",
        "        total_samples_test  += bs\n",
        "\n",
        "        # logical consistency\n",
        "        correct_logical_test += calculate_logical_consistency_base(fine_preds, coarse_preds)\n",
        "\n",
        "        # store for entropy/ECE + PR/RC/F1\n",
        "        fine_probs   = torch.softmax(fine_logits, dim=1).cpu()\n",
        "        coarse_probs = torch.softmax(coarse_logits, dim=1).cpu()\n",
        "\n",
        "        all_fine_probs.append(fine_probs)\n",
        "        all_coarse_probs.append(coarse_probs)\n",
        "        all_fine_targets.append(fine_targets.cpu())\n",
        "        all_coarse_targets.append(coarse_targets.cpu())\n",
        "        all_fine_preds.append(fine_preds.cpu())\n",
        "        all_coarse_preds.append(coarse_preds.cpu())\n",
        "\n",
        "\n",
        "# aggregate\n",
        "avg_test_loss        = total_test_loss / max(1, len(test_loader_int))\n",
        "fine_accuracy_test   = correct_fine_test / max(1, total_samples_test)\n",
        "coarse_accuracy_test = correct_coarse_test / max(1, total_samples_test)\n",
        "logical_consistency  = correct_logical_test / max(1, total_samples_test)\n",
        "\n",
        "# concat all\n",
        "all_fine_probs   = torch.cat(all_fine_probs)\n",
        "all_coarse_probs = torch.cat(all_coarse_probs)\n",
        "all_fine_targets = torch.cat(all_fine_targets)\n",
        "all_coarse_targets = torch.cat(all_coarse_targets)\n",
        "all_fine_preds   = torch.cat(all_fine_preds)\n",
        "all_coarse_preds = torch.cat(all_coarse_preds)\n",
        "\n",
        "# entropy\n",
        "avg_fine_entropy   = compute_entropy(all_fine_probs)\n",
        "avg_coarse_entropy = compute_entropy(all_coarse_probs)\n",
        "\n",
        "# ECE\n",
        "avg_fine_ece   = compute_ece_pytorch(all_fine_probs, all_fine_targets, num_bins=10)\n",
        "avg_coarse_ece = compute_ece_pytorch(all_coarse_probs, all_coarse_targets, num_bins=10)\n",
        "\n",
        "# PR/RC/F1 (macro)\n",
        "P_f_m, R_f_m, F1_f_m, _ = precision_recall_fscore_support(\n",
        "    all_fine_targets.numpy(), all_fine_preds.numpy(), average='macro', zero_division=0\n",
        ")\n",
        "P_c_m, R_c_m, F1_c_m, _ = precision_recall_fscore_support(\n",
        "    all_coarse_targets.numpy(), all_coarse_preds.numpy(), average='macro', zero_division=0\n",
        ")\n",
        "\n",
        "print(f\"Test Loss: {avg_test_loss:.4f} | Acc(f) {fine_accuracy_test:.4f} Acc(c) {coarse_accuracy_test:.4f} \"\n",
        "      f\"| Logical Consistency {logical_consistency:.4f}\")\n",
        "print(f\"Entropy: Fine {avg_fine_entropy:.3f} | Coarse {avg_coarse_entropy:.3f}\")\n",
        "print(f\"ECE:     Fine {avg_fine_ece:.4f} | Coarse {avg_coarse_ece:.4f}\")\n",
        "print(f\"(fine)   ->   Precision {P_f_m:.3f} Recall {R_f_m:.3f} F1 {F1_f_m:.3f}\")\n",
        "print(f\"(coarse) ->  Precision {P_c_m:.3f} Recall {R_c_m:.3f} F1 {F1_c_m:.3f}\")\n",
        "\n",
        "test_results = {\n",
        "    \"test_loss\": avg_test_loss,\n",
        "    \"fine_accuracy_test\": fine_accuracy_test,\n",
        "    \"coarse_accuracy_test\": coarse_accuracy_test,\n",
        "    \"logical_consistency_test\": logical_consistency,\n",
        "    \"avg_fine_entropy\": avg_fine_entropy,\n",
        "    \"avg_coarse_entropy\": avg_coarse_entropy,\n",
        "    \"avg_fine_ece\": avg_fine_ece,\n",
        "    \"avg_coarse_ece\": avg_coarse_ece,\n",
        "    \"precision_fine\": P_f_m,\n",
        "    \"recall_fine\": R_f_m,\n",
        "    \"f1_fine\": F1_f_m,\n",
        "    \"precision_coarse\": P_c_m,\n",
        "    \"recall_coarse\": R_c_m,\n",
        "    \"f1_coarse\": F1_c_m,\n",
        "}\n",
        "\n",
        "# Merge everything (train/val already in base_results)\n",
        "all_base_results = {\n",
        "    \"test_results\": test_results,\n",
        "}\n",
        "\n",
        "# Save as pickle\n",
        "with open(os.path.join(SAVE_DIR, \"base_test_results.pkl\"), \"wb\") as f:\n",
        "    pickle.dump(all_base_test_results, f)\n",
        "\n",
        "# Also save as JSON\n",
        "with open(os.path.join(SAVE_DIR, \"base_test_results.json\"), \"w\") as f:\n",
        "    json.dump(all_base_test_results, f, indent=4)\n"
      ],
      "metadata": {
        "id": "Sd9q0RY1yJ9e"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Path  (adjust folder path)\n",
        "results_path = BASE_DIR + \"/base_results.pkl\"\n",
        "\n",
        "# Load pickle\n",
        "with open(results_path, \"rb\") as f:\n",
        "    results = pickle.load(f)\n",
        "\n",
        "#  save ase_results\n",
        "if \"train_val_results\" in results:\n",
        "    base_results = results[\"train_val_results\"]\n",
        "else:\n",
        "    base_results = results\n",
        "\n",
        "# Find best epoch\n",
        "best_idx = int(np.argmin(base_results[\"val_loss\"]))\n",
        "\n",
        "print(\"\\n==== Best Results (Lowest Val Loss) ====\")\n",
        "print(f\"Best Epoch: {best_idx+1}\")\n",
        "print(f\"Train Loss: {base_results['train_loss'][best_idx]:.4f}\")\n",
        "print(f\"Val Loss:   {base_results['val_loss'][best_idx]:.4f}\")\n",
        "\n",
        "\n",
        "print(f\"Val Fine Accuracy:    {base_results['val_accuracy_fine'][best_idx]:.4f}\")\n",
        "\n",
        "print(f\"Val Coarse Accuracy:  {base_results['val_accuracy_coarse'][best_idx]:.4f}\")\n",
        "print(f\"Val Fine F1:     {base_results['val_f1_fine'][best_idx]:.4f}\")\n",
        "print(f\"Val Coarse F1:   {base_results['val_f1_coarse'][best_idx]:.4f}\")\n",
        "print(f\"Val Fine ECE:     {base_results['val_ece_fine'][best_idx]:.4f}\")\n",
        "print(f\"Val Coarse ECE:   {base_results['val_ece_coarse'][best_idx]:.4f}\")\n",
        "print(f\"Val Fine Entropy:     {base_results['val_entropy_fine'][best_idx]:.4f}\")\n",
        "print(f\"Val Coarse Entropy:   {base_results['val_entropy_coarse'][best_idx]:.4f}\")\n",
        "print(f\"Val Consistency:   {base_results['val_consistency'][best_idx]:.4f}\")"
      ],
      "metadata": {
        "id": "DoaUe6x6LOeV"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "##CLCO"
      ],
      "metadata": {
        "id": "SO94QjW0vu_x"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### LOSS FUNCTION"
      ],
      "metadata": {
        "id": "rRZtbg2uxDPn"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class FuzzyMembershipLoss(nn.Module):\n",
        "    def __init__(self,\n",
        "                 init_gamma=0.1,\n",
        "                 t_norm=product_t_norm,\n",
        "                 membership_fn=triangular_membership,\n",
        "                 sigma=1.0):\n",
        "        super().__init__()\n",
        "        # raw parameter in log-space for stability\n",
        "        self.log_gamma = nn.Parameter(torch.log(torch.tensor(init_gamma, dtype=torch.float32)))\n",
        "\n",
        "        self.t_norm = t_norm\n",
        "        self.membership_fn = membership_fn\n",
        "        self.sigma = sigma\n",
        "        self.ce_fine = nn.CrossEntropyLoss()\n",
        "        self.ce_coarse = nn.CrossEntropyLoss()\n",
        "\n",
        "        # Build once\n",
        "        if isinstance(fine_to_coarse, dict):\n",
        "            self.register_buffer(\n",
        "                \"mapping_tensor\",\n",
        "                torch.tensor([fine_to_coarse[i] for i in range(len(fine_to_coarse))], dtype=torch.long)\n",
        "            )\n",
        "        else:\n",
        "            self.register_buffer(\"mapping_tensor\", fine_to_coarse.clone())\n",
        "\n",
        "    @property\n",
        "    def gamma(self):\n",
        "        # map log_gamma → positive γ, clamped to [1e-5, 10.0]\n",
        "        return torch.exp(self.log_gamma).clamp(1e-5, 10.0)\n",
        "\n",
        "    def forward(self, fine_logits, fine_labels, coarse_logits, coarse_labels):\n",
        "        # standard CE losses\n",
        "        loss_fine = self.ce_fine(fine_logits, fine_labels)\n",
        "        loss_coarse = self.ce_coarse(coarse_logits, coarse_labels)\n",
        "\n",
        "        fine_preds = torch.argmax(fine_logits, dim=1)\n",
        "        coarse_probs = torch.softmax(coarse_logits, dim=-1)\n",
        "\n",
        "        # map fine → expected coarse class\n",
        "        expected_coarse = self.mapping_tensor[fine_preds]\n",
        "\n",
        "        # gather coarse probs at expected_coarse positions\n",
        "        expected_probs = coarse_probs.gather(1, expected_coarse.unsqueeze(1)).squeeze(1)\n",
        "\n",
        "        # compute membership in batch\n",
        "        membership_values = self.membership_fn(expected_probs, sigma=self.sigma)\n",
        "\n",
        "        # mean consistency score\n",
        "        consistency_score = membership_values.mean()\n",
        "\n",
        "        # constraint term\n",
        "        L_constraints = self.gamma * (1 - consistency_score)\n",
        "\n",
        "        # Total loss\n",
        "        total_loss = loss_fine + loss_coarse + L_constraints\n",
        "        return total_loss, {\n",
        "            \"loss_fine\": loss_fine.item(),\n",
        "            \"loss_coarse\": loss_coarse.item(),\n",
        "            \"consistency_penalty\": L_constraints.item(),\n",
        "            \"gamma\": self.gamma.item()\n",
        "        }\n"
      ],
      "metadata": {
        "id": "DZvXBT1Dxj1f"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### TRAIN AND VALIDATION"
      ],
      "metadata": {
        "id": "snCyn8g3vyK0"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#  Config\n",
        "t_norms = {\n",
        "    'product':      product_t_norm,\n",
        "    'godel':        godel_t_norm,\n",
        "    'lukasiewicz':  lukasiewicz_t_norm,\n",
        "}\n",
        "\n",
        "membership_function = {\n",
        "    'trapezoidal': trapezoidal_membership,\n",
        "    'gaussian': gaussian_membership,\n",
        "    'triangular': triangular_membership}\n",
        "\n",
        "\n",
        "RESULTS_FILE     = SAVE_DIR + '/neuro/neuro_all_results.pkl'\n",
        "CHECKPOINT_FILE  = SAVE_DIR + '/neuro/neuro_training_checkpoint.pkl'\n",
        "BEST_META_FILE   = SAVE_DIR + '/neuro/neuro_best_model_metadata.pkl'\n",
        "NEURO_DIR        = SAVE_DIR + '/neuro'\n",
        "\n",
        "num_epochs = 300\n",
        "early_stopping_patience = 5\n",
        "\n",
        "os.makedirs(NEURO_DIR, exist_ok=True)\n",
        "\n",
        "#  Resume checkpoint list of results\n",
        "if os.path.exists(CHECKPOINT_FILE):\n",
        "    with open(CHECKPOINT_FILE, 'rb') as f:\n",
        "        all_results = pickle.load(f)\n",
        "    completed_configs = {r['config_key'] for r in all_results}\n",
        "    print(f\"Resuming training. {len(completed_configs)} configs already completed.\")\n",
        "    print(sorted(list(completed_configs))[:5], \"...\")\n",
        "else:\n",
        "    all_results = []\n",
        "    completed_configs = set()\n",
        "\n",
        "#  Train config function\n",
        "def train_single_config(t_norm_key, t_norm_func, membership_key, membership_func):\n",
        "    config_key = f\"neuro_{t_norm_key}_{membership_key}\"\n",
        "    print(f\"\\nTraining with {config_key}...\\n\")\n",
        "\n",
        "    partial_path = f\"{NEURO_DIR}/partial_checkpoint_{config_key}.pth\"\n",
        "    start_epoch = 0\n",
        "\n",
        "    results = {\n",
        "        'config_key': config_key,\n",
        "        'train_loss':[],\n",
        "        'val_loss': [],\n",
        "        'val_accuracy_fine': [],\n",
        "        'val_accuracy_coarse': [],\n",
        "        'val_ece_fine': [],\n",
        "        'val_ece_coarse': [],\n",
        "        'val_entropy_fine': [],\n",
        "        'val_entropy_coarse': [],\n",
        "        'val_logical_consistency': [],\n",
        "        'val_precision_fine': [], 'val_recall_fine': [], 'val_f1_fine': [],\n",
        "        'val_precision_coarse': [], 'val_recall_coarse': [], 'val_f1_coarse': [],\n",
        "        'best_val_loss': float('inf'),\n",
        "        'best_model_path': f'{NEURO_DIR}/best_model_{config_key}.pth',\n",
        "    }\n",
        "\n",
        "    # model\n",
        "    model_neuro = SwinMultiTask(\n",
        "        num_fine_labels=num_fine,\n",
        "        num_coarse_labels=num_coarse\n",
        "    ).to(device)\n",
        "\n",
        "    # loss / optimizer/ schedular\n",
        "\n",
        "    loss_fn = FuzzyMembershipLoss(\n",
        "      init_gamma=0.1,\n",
        "      t_norm=t_norm_func,\n",
        "      membership_fn=membership_func,\n",
        "      sigma=1.0).to(device)\n",
        "\n",
        "    optimizer = optim.AdamW(list(model_neuro.parameters()) + list(loss_fn.parameters()), lr=1e-4)\n",
        "\n",
        "    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)\n",
        "    early_stopping_counter = 0\n",
        "\n",
        "    # Resume partial checkpoint\n",
        "    if os.path.exists(partial_path):\n",
        "        checkpoint = torch.load(partial_path, map_location=device)\n",
        "        model_neuro.load_state_dict(checkpoint['model_state_dict'])\n",
        "        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
        "        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])\n",
        "        loss_fn.load_state_dict(checkpoint['loss_fn_state_dict'])\n",
        "        results = checkpoint['results']\n",
        "        start_epoch = checkpoint['epoch'] + 1\n",
        "        print(f\"Resuming {config_key} from epoch {start_epoch}\")\n",
        "    ...\n",
        "    for epoch in range(start_epoch, num_epochs):\n",
        "        print(f\"Epoch {epoch+1}/{num_epochs}\")\n",
        "\n",
        "        # TRAIN\n",
        "        model_neuro.train()\n",
        "        run_loss = 0.0\n",
        "\n",
        "        for inputs, fine_labels, coarse_labels in train_loader_int:\n",
        "            inputs = inputs.to(device)\n",
        "            fine_labels = fine_labels.to(device)\n",
        "            coarse_labels = coarse_labels.to(device)\n",
        "\n",
        "            optimizer.zero_grad()\n",
        "\n",
        "            # Forward + loss under autocast\n",
        "            with autocast():\n",
        "                fine_logits, coarse_logits = model_neuro(inputs)\n",
        "                loss, _ = loss_fn(fine_logits, fine_labels, coarse_logits, coarse_labels)\n",
        "                run_loss += loss.item()\n",
        "\n",
        "            # Backpropagation with scaling\n",
        "            scaler.scale(loss).backward()\n",
        "            scaler.step(optimizer)\n",
        "            scaler.update()\n",
        "\n",
        "        train_loss = run_loss / len(train_loader_int)\n",
        "\n",
        "        # store results\n",
        "        results['train_loss'].append(train_loss)\n",
        "\n",
        "        # VALIDATION\n",
        "        model_neuro.eval()\n",
        "        val_loss_accum = 0.0\n",
        "        total_val = 0\n",
        "        correct_fine_val = 0\n",
        "        correct_coarse_val = 0\n",
        "        logical_val = 0\n",
        "\n",
        "        all_fine_probs, all_coarse_probs = [], []\n",
        "        all_fine_preds, all_coarse_preds = [], []\n",
        "        all_fine_true, all_coarse_true = [], []\n",
        "\n",
        "        with torch.no_grad():\n",
        "            for inputs, fine_labels, coarse_labels in val_loader_int:\n",
        "                inputs = inputs.to(device)\n",
        "                fine_labels = fine_labels.to(device)\n",
        "                coarse_labels = coarse_labels.to(device)\n",
        "\n",
        "                # Forward + loss under autocast\n",
        "                with autocast():\n",
        "                    fine_logits, coarse_logits = model_neuro(inputs)\n",
        "                    loss, _ = loss_fn(fine_logits, fine_labels, coarse_logits, coarse_labels)\n",
        "                val_loss_accum += loss.item()\n",
        "\n",
        "                fine_probs = torch.softmax(fine_logits, dim=1)\n",
        "                coarse_probs = torch.softmax(coarse_logits, dim=1)\n",
        "                fine_pred = fine_probs.argmax(dim=1)\n",
        "                coarse_pred = coarse_probs.argmax(dim=1)\n",
        "\n",
        "                bs = fine_labels.size(0)\n",
        "                total_val += bs\n",
        "                correct_fine_val += (fine_pred == fine_labels).sum().item()\n",
        "                correct_coarse_val += (coarse_pred == coarse_labels).sum().item()\n",
        "                logical_val += calculate_logical_consistency_base(fine_pred, coarse_pred)\n",
        "\n",
        "                # collect\n",
        "                all_fine_probs.append(fine_probs.cpu())\n",
        "                all_coarse_probs.append(coarse_probs.cpu())\n",
        "                all_fine_preds.append(fine_pred.cpu())\n",
        "                all_coarse_preds.append(coarse_pred.cpu())\n",
        "                all_fine_true.append(fine_labels.cpu())\n",
        "                all_coarse_true.append(coarse_labels.cpu())\n",
        "\n",
        "        # AFTER loop → metrics\n",
        "        all_fine_probs = torch.cat(all_fine_probs)\n",
        "        all_coarse_probs = torch.cat(all_coarse_probs)\n",
        "        all_fine_preds = torch.cat(all_fine_preds)\n",
        "        all_coarse_preds = torch.cat(all_coarse_preds)\n",
        "        all_fine_true  = torch.cat(all_fine_true)\n",
        "        all_coarse_true= torch.cat(all_coarse_true)\n",
        "\n",
        "        val_loss = val_loss_accum / len(val_loader_int)\n",
        "        val_acc_fine = correct_fine_val / total_val\n",
        "        val_acc_coarse = correct_coarse_val / total_val\n",
        "        val_consistency = logical_val / total_val\n",
        "\n",
        "        val_ece_fine   = float(compute_ece_pytorch(all_fine_probs, all_fine_true))\n",
        "        val_ece_coarse = float(compute_ece_pytorch(all_coarse_probs, all_coarse_true))\n",
        "        val_ent_fine   = float(compute_entropy(all_fine_probs))\n",
        "        val_ent_coarse = float(compute_entropy(all_coarse_probs))\n",
        "\n",
        "        P_f_va, R_f_va, F1_f_va, _ = precision_recall_fscore_support(\n",
        "            all_fine_true.numpy(), all_fine_preds.numpy(), average='macro', zero_division=0\n",
        "        )\n",
        "        P_c_va, R_c_va, F1_c_va, _ = precision_recall_fscore_support(\n",
        "            all_coarse_true.numpy(), all_coarse_preds.numpy(), average='macro', zero_division=0\n",
        "        )\n",
        "\n",
        "        results['val_loss'].append(val_loss)\n",
        "        results['val_accuracy_fine'].append(val_acc_fine)\n",
        "        results['val_accuracy_coarse'].append(val_acc_coarse)\n",
        "        results['val_ece_fine'].append(val_ece_fine)\n",
        "        results['val_ece_coarse'].append(val_ece_coarse)\n",
        "        results['val_entropy_fine'].append(val_ent_fine)\n",
        "        results['val_entropy_coarse'].append(val_ent_coarse)\n",
        "        results['val_logical_consistency'].append(val_consistency)\n",
        "        results['val_precision_fine'].append(float(P_f_va))\n",
        "        results['val_recall_fine'].append(float(R_f_va))\n",
        "        results['val_f1_fine'].append(float(F1_f_va))\n",
        "        results['val_precision_coarse'].append(float(P_c_va))\n",
        "        results['val_recall_coarse'].append(float(R_c_va))\n",
        "        results['val_f1_coarse'].append(float(F1_c_va))\n",
        "\n",
        "        #  Save partial checkpoint (epoch-level)\n",
        "        torch.save({\n",
        "            'epoch': epoch,\n",
        "            'model_state_dict': model_neuro.state_dict(),\n",
        "            'optimizer_state_dict': optimizer.state_dict(),\n",
        "            'scheduler_state_dict': scheduler.state_dict(),\n",
        "            'loss_fn_state_dict': loss_fn.state_dict(),\n",
        "            'results': results\n",
        "        }, partial_path)\n",
        "\n",
        "        #  Track best model by val loss\n",
        "        if val_loss < results['best_val_loss']:\n",
        "            results['best_val_loss'] = val_loss\n",
        "            torch.save({\n",
        "                'epoch': epoch,\n",
        "                'model_state_dict': model_neuro.state_dict(),\n",
        "                'optimizer_state_dict': optimizer.state_dict(),\n",
        "                'scheduler_state_dict': scheduler.state_dict(),\n",
        "                'loss_fn_state_dict': loss_fn.state_dict(),\n",
        "            }, results['best_model_path'])\n",
        "            early_stopping_counter = 0\n",
        "            print(f\"[{config_key}] New best val loss: {val_loss:.4f} at epoch {epoch+1}\")\n",
        "        else:\n",
        "            early_stopping_counter += 1\n",
        "            if early_stopping_counter >= early_stopping_patience:\n",
        "                print(\"Early stopping triggered.\")\n",
        "                break\n",
        "\n",
        "        scheduler.step(val_loss)\n",
        "\n",
        "    if os.path.exists(partial_path):\n",
        "        os.remove(partial_path)\n",
        "\n",
        "    return results\n",
        "\n",
        "# Main loop\n",
        "if __name__ == \"__main__\":\n",
        "    all_args = list(itertools.product(t_norms.items(), membership_function.items()))\n",
        "    formatted_args = [(t[0], t[1], m[0], m[1]) for t, m in all_args]\n",
        "    for args in formatted_args:\n",
        "        t_norm_key, t_norm_func, membership_key, membership_fn = args\n",
        "        cfg_key = f\"neuro_{t_norm_key}_{membership_key}\"\n",
        "\n",
        "        if cfg_key in completed_configs:\n",
        "            print(f\"Skipping already completed config: {cfg_key}\")\n",
        "            continue\n",
        "        result = train_single_config(t_norm_key, t_norm_func, membership_key, membership_fn)\n",
        "        all_results.append(result)\n",
        "\n",
        "        # Summary print\n",
        "        print(f\"\\n=== Finished training {result['config_key']} ===\")\n",
        "        print(f\"Train Loss: {result['train_loss'][-1]:.4f} | Val Loss: {result['val_loss'][-1]:.4f}\")\n",
        "        print(f\"Val   Acc (Fine/Coarse): {result['val_accuracy_fine'][-1]:.4f}/{result['val_accuracy_coarse'][-1]:.4f}\")\n",
        "        print(f\"Val ECE (F/C): {result['val_ece_fine'][-1]:.4f}/{result['val_ece_coarse'][-1]:.4f}\")\n",
        "        print(f\"Val Entropy (F/C): {result['val_entropy_fine'][-1]:.4f}/{result['val_entropy_coarse'][-1]:.4f}\")\n",
        "        print(f\"Val Consistency: {result['val_logical_consistency'][-1]:.4f}\")\n",
        "        print(f\"Val  P/R/F1 Fine: {result['val_precision_fine'][-1]:.4f}/{result['val_recall_fine'][-1]:.4f}/{result['val_f1_fine'][-1]:.4f}\")\n",
        "        print(f\"Val  P/R/F1 Coarse: {result['val_precision_coarse'][-1]:.4f}/{result['val_recall_coarse'][-1]:.4f}/{result['val_f1_coarse'][-1]:.4f}\")\n",
        "        print(f\"Saved model: {result['best_model_path']}\")\n",
        "\n",
        "        # Save checkpoint list after each config (godel-style)\n",
        "        with open(CHECKPOINT_FILE, 'wb') as f:\n",
        "            pickle.dump(all_results, f)\n",
        "\n",
        "    # Finalize: choose best and save results list & metadata\n",
        "    best_result = min(all_results, key=lambda x: x['best_val_loss'])\n",
        "    print(f\"\\n=== Best Model: {best_result['config_key']} with val loss {best_result['best_val_loss']:.4f} ===\")\n",
        "\n",
        "    with open(BEST_META_FILE, 'wb') as f:\n",
        "        pickle.dump({\n",
        "            'best_model_path': best_result['best_model_path'],\n",
        "            'best_config_key': best_result['config_key']\n",
        "        }, f)\n",
        "\n",
        "    with open(RESULTS_FILE, 'wb') as f:\n",
        "        pickle.dump(all_results, f)\n",
        "\n",
        "    print(f\"Saved all results to {RESULTS_FILE}\")\n",
        "    print(f\"Best model path: {best_result['best_model_path']}\")\n",
        "\n",
        "    if os.path.exists(CHECKPOINT_FILE):\n",
        "        os.rename(CHECKPOINT_FILE, CHECKPOINT_FILE + \".bak\")\n",
        "        print(f\"Backup saved to {CHECKPOINT_FILE}.bak\")\n"
      ],
      "metadata": {
        "id": "kSQpghq17a0J"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import pickle\n",
        "import numpy as np\n",
        "\n",
        "RESULTS_FILE = SAVE_DIR + '/neuro/neuro_all_results.pkl'\n",
        "\n",
        "with open(RESULTS_FILE, \"rb\") as f:\n",
        "    all_results = pickle.load(f)\n",
        "\n",
        "print(f\"Loaded {len(all_results)} configs from {RESULTS_FILE}\\n\")\n",
        "\n",
        "for res in all_results:\n",
        "    config_key = res[\"config_key\"]\n",
        "    val_losses = res[\"val_loss\"]\n",
        "    train_losses = res[\"train_loss\"]\n",
        "\n",
        "    if not val_losses:  # skip if empty\n",
        "        continue\n",
        "\n",
        "    best_epoch = int(np.argmin(val_losses))\n",
        "\n",
        "    print(\"\\n\")\n",
        "    print(f\"Config: {config_key}\")\n",
        "    print(f\" Best epoch: {best_epoch+1}\")\n",
        "    print(f\" Train Loss: {train_losses[best_epoch]:.4f}\")\n",
        "    print(f\" Val Loss: {val_losses[best_epoch]:.4f}\")\n",
        "    print(f\" Val Acc (Fine/Coarse): {res['val_accuracy_fine'][best_epoch]:.4f} / {res['val_accuracy_coarse'][best_epoch]:.4f}\")\n",
        "    print(f\" Val ECE   (Fine/Coarse): {res['val_ece_fine'][best_epoch]:.4f} / {res['val_ece_coarse'][best_epoch]:.4f}\")\n",
        "    print(f\" Val Entropy (F/C): {res['val_entropy_fine'][best_epoch]:.4f} / {res['val_entropy_coarse'][best_epoch]:.4f}\")\n",
        "    print(f\" Val Consistency: {res['val_logical_consistency'][best_epoch]:.4f}\")\n",
        "    print(f\" Val P/R/F1 Fine: {res['val_precision_fine'][best_epoch]:.4f} / \"\n",
        "          f\"{res['val_recall_fine'][best_epoch]:.4f} / {res['val_f1_fine'][best_epoch]:.4f}\")\n",
        "    print(f\" Val P/R/F1 Coarse: {res['val_precision_coarse'][best_epoch]:.4f} / \"\n",
        "          f\"{res['val_recall_coarse'][best_epoch]:.4f} / {res['val_f1_coarse'][best_epoch]:.4f}\")\n",
        "    print(f\" Best model path: {res['best_model_path']}\")\n"
      ],
      "metadata": {
        "id": "fJNb6hLHCBjC"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### TEST"
      ],
      "metadata": {
        "id": "iGhlyevQv4vO"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def constrained_output_clco(\n",
        "    fine_probs, coarse_probs,\n",
        "    fine_to_coarse,\n",
        "    device,\n",
        "    fine_threshold=0.5, coarse_threshold=0.5\n",
        "):\n",
        "    \"\"\"\n",
        "    Turn per-head softmax probabilities into (fine, coarse) predictions.\n",
        "\n",
        "    Per sample:\n",
        "      * fine prediction = argmax of fine_probs;\n",
        "      * if that fine probability is >= fine_threshold, let\n",
        "        expected = fine_to_coarse[fine_pred]. The coarse prediction is\n",
        "        `expected` when coarse_probs[expected] < coarse_threshold, and the\n",
        "        coarse argmax otherwise;\n",
        "      * if the fine probability is below fine_threshold, the coarse argmax\n",
        "        is used unchanged.\n",
        "\n",
        "    NOTE(review): when coarse_threshold < 0.5 the coarse argmax can differ\n",
        "    from `expected` even for a confident fine prediction, so the output is\n",
        "    not guaranteed to respect fine_to_coarse. Confirm this is intended.\n",
        "\n",
        "    Args:\n",
        "        fine_probs: (batch, num_fine) probabilities.\n",
        "        coarse_probs: (batch, num_coarse) probabilities.\n",
        "        fine_to_coarse: indexable by an int fine class id, giving a coarse id.\n",
        "        device: device of the returned coarse predictions.\n",
        "        fine_threshold, coarse_threshold: confidence cut-offs (see above).\n",
        "\n",
        "    Returns:\n",
        "        (fine_pred, coarse_pred): int64 tensors of shape (batch,). fine_pred\n",
        "        stays on fine_probs' device; coarse_pred is moved to `device`.\n",
        "    \"\"\"\n",
        "    fine_pred = torch.argmax(fine_probs, dim=-1)\n",
        "    coarse_pred = torch.argmax(coarse_probs, dim=-1)\n",
        "\n",
        "    # Vectorised on purpose: a per-sample loop with .item() forces a GPU->CPU\n",
        "    # sync on every call. Threshold comparisons are done in float64 so they\n",
        "    # give exactly the same result as comparing Python floats from .item().\n",
        "    fine_conf = fine_probs.gather(-1, fine_pred.unsqueeze(-1)).squeeze(-1)\n",
        "    rows = torch.nonzero(fine_conf.double() >= fine_threshold, as_tuple=True)[0]\n",
        "\n",
        "    if rows.numel() > 0:\n",
        "        # Parent coarse class implied by each confident fine prediction.\n",
        "        expected = torch.as_tensor(\n",
        "            [fine_to_coarse[k] for k in fine_pred[rows].tolist()],\n",
        "            dtype=torch.long, device=coarse_probs.device,\n",
        "        )\n",
        "        rows = rows.to(coarse_probs.device)\n",
        "        expected_conf = coarse_probs[rows, expected]\n",
        "        coarse_pred[rows] = torch.where(\n",
        "            expected_conf.double() < coarse_threshold, expected, coarse_pred[rows]\n",
        "        )\n",
        "\n",
        "    # coarse_pred comes from argmax, so it is int64 even for an empty batch.\n",
        "    return fine_pred, coarse_pred.to(device)\n"
      ],
      "metadata": {
        "id": "Xa1M38_exBN1"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import math\n",
        "# Paths (neuro): ensure SAVE_DIR exists, then derive the results pickle path.\n",
        "os.makedirs(SAVE_DIR, exist_ok=True)\n",
        "TEST_RESULTS_PKL = os.path.join(SAVE_DIR, 'neuro_test_results_all.pkl')\n",
        "\n",
        "# Load the pickled training results (unpickling can run arbitrary code,\n",
        "# so RESULTS_FILE must point at a trusted file).\n",
        "with open(RESULTS_FILE, 'rb') as f:\n",
        "    all_results = pickle.load(f)\n",
        "\n",
        "\n",
        "def evaluate_config(model_path, config_key,\n",
        "                    fine_threshold=0.5, coarse_threshold=0.5):\n",
        "    \"\"\"\n",
        "    Evaluate one trained checkpoint on `test_loader_int`.\n",
        "\n",
        "    Rebuilds SwinMultiTask, loads the weights at `model_path`, runs the test\n",
        "    loader in eval mode without autograd, derives predictions with\n",
        "    constrained_output_clco(fine_threshold, coarse_threshold) and returns a\n",
        "    flat dict of metrics for the fine and coarse heads (accuracy, logical\n",
        "    consistency, ECE, entropy, macro precision/recall/F1).\n",
        "\n",
        "    'test_loss' is the per-sample mean of fine CE + coarse CE (unweighted).\n",
        "    NOTE(review): the training objective may weight/combine the heads\n",
        "    differently - confirm before comparing it with validation loss.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if `test_loader_int` yields no samples.\n",
        "    \"\"\"\n",
        "    # Rebuild full SwinMultiTask model and load trained weights\n",
        "    model = SwinMultiTask(\n",
        "        num_fine_labels=num_fine_labels,\n",
        "        num_coarse_labels=num_coarse_labels\n",
        "    ).to(device)\n",
        "    _ = load_state_dict_robust(model, model_path, device)\n",
        "    # A freshly constructed nn.Module is in training mode; without eval() any\n",
        "    # dropout / normalisation layers keep training-mode behaviour at test time.\n",
        "    model.eval()\n",
        "\n",
        "    # Accumulators\n",
        "    total = 0\n",
        "    correct_fine = 0\n",
        "    correct_coarse = 0\n",
        "    correct_logical = 0\n",
        "    test_loss_accum = 0.0\n",
        "\n",
        "    # For ECE / entropy / PRF\n",
        "    batch_fine_probs, batch_coarse_probs = [], []\n",
        "    batch_fine_preds, batch_coarse_preds = [], []\n",
        "    batch_fine_true,  batch_coarse_true  = [], []\n",
        "\n",
        "    # no_grad: without it every stored .cpu() probability tensor keeps its\n",
        "    # autograd graph (and the GPU activations behind it) alive across batches.\n",
        "    with torch.no_grad():\n",
        "        for inputs, fine_labels, coarse_labels in test_loader_int:\n",
        "            inputs = inputs.to(device)\n",
        "            fine_labels = fine_labels.to(device)\n",
        "            coarse_labels = coarse_labels.to(device)\n",
        "\n",
        "            fine_logits, coarse_logits = model(inputs)\n",
        "\n",
        "            # Summed per batch and divided by the sample count at the end, so a\n",
        "            # short final batch is weighted correctly.\n",
        "            batch_loss = (\n",
        "                torch.nn.functional.cross_entropy(\n",
        "                    fine_logits, fine_labels.long(), reduction='sum')\n",
        "                + torch.nn.functional.cross_entropy(\n",
        "                    coarse_logits, coarse_labels.long(), reduction='sum')\n",
        "            )\n",
        "            test_loss_accum += batch_loss.item()\n",
        "\n",
        "            fine_probs = torch.softmax(fine_logits, dim=1)\n",
        "            coarse_probs = torch.softmax(coarse_logits, dim=1)\n",
        "\n",
        "            fine_pred, coarse_pred = constrained_output_clco(\n",
        "                fine_probs, coarse_probs,\n",
        "                fine_to_coarse=fine_to_coarse,\n",
        "                device=device,\n",
        "                fine_threshold=fine_threshold,\n",
        "                coarse_threshold=coarse_threshold\n",
        "            )\n",
        "\n",
        "            # labels are already integer class indices\n",
        "            total += fine_labels.size(0)\n",
        "            correct_fine   += (fine_pred   == fine_labels).sum().item()\n",
        "            correct_coarse += (coarse_pred == coarse_labels).sum().item()\n",
        "            correct_logical += calculate_logical_consistency_base(fine_pred, coarse_pred)\n",
        "\n",
        "            batch_fine_probs.append(fine_probs.cpu())\n",
        "            batch_coarse_probs.append(coarse_probs.cpu())\n",
        "            batch_fine_preds.append(fine_pred.cpu())\n",
        "            batch_coarse_preds.append(coarse_pred.cpu())\n",
        "            batch_fine_true.append(fine_labels.cpu())\n",
        "            batch_coarse_true.append(coarse_labels.cpu())\n",
        "\n",
        "    if total == 0:\n",
        "        # torch.cat([]) below would otherwise fail with an opaque error.\n",
        "        raise ValueError(\"test_loader_int yielded no samples; nothing to evaluate\")\n",
        "\n",
        "    # Aggregates (fractions)\n",
        "    test_loss = test_loss_accum / total\n",
        "    fine_accuracy   = safe_div(correct_fine, total)\n",
        "    coarse_accuracy = safe_div(correct_coarse, total)\n",
        "    logical_consistency = safe_div(correct_logical, total)\n",
        "\n",
        "    all_fprob = torch.cat(batch_fine_probs)\n",
        "    all_cprob = torch.cat(batch_coarse_probs)\n",
        "    all_ftrue = torch.cat(batch_fine_true)\n",
        "    all_ctrue = torch.cat(batch_coarse_true)\n",
        "    all_fpred = torch.cat(batch_fine_preds)\n",
        "    all_cpred = torch.cat(batch_coarse_preds)\n",
        "\n",
        "    # ECE & entropy (original probs)\n",
        "    test_ece_fine   = float(compute_ece_pytorch(all_fprob, all_ftrue))\n",
        "    test_ece_coarse = float(compute_ece_pytorch(all_cprob, all_ctrue))\n",
        "    test_ent_fine   = float(compute_entropy(all_fprob))\n",
        "    test_ent_coarse = float(compute_entropy(all_cprob))\n",
        "\n",
        "    # Macro PR/Recall/F1\n",
        "    ytf = all_ftrue.numpy(); ypf = all_fpred.numpy()\n",
        "    ytc = all_ctrue.numpy(); ypc = all_cpred.numpy()\n",
        "\n",
        "    P_f, R_f, F1_f, _ = precision_recall_fscore_support(ytf, ypf, average='macro', zero_division=0)\n",
        "    P_c, R_c, F1_c, _ = precision_recall_fscore_support(ytc, ypc, average='macro', zero_division=0)\n",
        "\n",
        "    out = {\n",
        "        'config_key': config_key,\n",
        "        'best_model_path': model_path,\n",
        "        'fine_threshold': fine_threshold,\n",
        "        'coarse_threshold': coarse_threshold,\n",
        "        'test_loss': float(test_loss),\n",
        "        'test_accuracy_fine': float(fine_accuracy),\n",
        "        'test_accuracy_coarse': float(coarse_accuracy),\n",
        "        'test_logical_consistency': float(logical_consistency),\n",
        "        'test_ece_fine': float(test_ece_fine),\n",
        "        'test_ece_coarse': float(test_ece_coarse),\n",
        "        'test_entropy_fine': float(test_ent_fine),\n",
        "        'test_entropy_coarse': float(test_ent_coarse),\n",
        "        'test_precision_fine': float(P_f),\n",
        "        'test_recall_fine': float(R_f),\n",
        "        'test_f1_fine': float(F1_f),\n",
        "        'test_precision_coarse': float(P_c),\n",
        "        'test_recall_coarse': float(R_c),\n",
        "        'test_f1_coarse': float(F1_c),\n",
        "    }\n",
        "    return out\n",
        "\n",
        "\n",
        "import traceback\n",
        "\n",
        "# Run every trained config over the grid of fine/coarse thresholds\n",
        "fine_thresholds   = [0.4, 0.5, 0.6, 0.7]\n",
        "coarse_thresholds = [0.3, 0.5, 0.7]\n",
        "\n",
        "configs = load_trained_configs(all_results)\n",
        "print(f\"Found {len(configs)} trained configs to test.\\n\")\n",
        "\n",
        "all_test_results = []\n",
        "for i, cfg in enumerate(configs, 1):\n",
        "    key  = cfg['config_key']\n",
        "    path = cfg['best_model_path']\n",
        "    # Coerce to float: the ':.6f' format below runs outside the try and would\n",
        "    # abort the whole sweep on None / non-numeric values (they become NaN).\n",
        "    try:\n",
        "        bvl = float(cfg.get('best_val_loss', float('nan')))\n",
        "    except (TypeError, ValueError):\n",
        "        bvl = float('nan')\n",
        "\n",
        "    # NOTE(review): evaluate_config reloads the checkpoint and re-runs the full\n",
        "    # test set for every (ft, ct) pair, although only its post-processing uses\n",
        "    # the thresholds; caching the probabilities per config would avoid the\n",
        "    # len(fine_thresholds) * len(coarse_thresholds) repeated forward passes.\n",
        "    for ft in fine_thresholds:\n",
        "        for ct in coarse_thresholds:\n",
        "            print(f\"[{i}/{len(configs)}] Testing {key} \"\n",
        "                  f\"(val_loss={bvl:.6f}, fine_th={ft}, coarse_th={ct}): {path}\")\n",
        "            try:\n",
        "                res = evaluate_config(path, key,\n",
        "                                      fine_threshold=ft,\n",
        "                                      coarse_threshold=ct)\n",
        "                res['best_val_loss'] = bvl  # carry forward for sorting/report\n",
        "                all_test_results.append(res)\n",
        "\n",
        "                print(f\" -> Done: FineAcc={res['test_accuracy_fine']:.4f} | \"\n",
        "                      f\"CoarseAcc={res['test_accuracy_coarse']:.4f} | \"\n",
        "                      f\"FineF1={res['test_f1_fine']:.4f} | \"\n",
        "                      f\"CoarseF1={res['test_f1_coarse']:.4f}\\n\")\n",
        "            except Exception as e:\n",
        "                # Best-effort sweep: keep going, but keep the traceback visible.\n",
        "                print(f\" !! Failed on {key}, thresholds {ft}/{ct}: {e}\\n\")\n",
        "                traceback.print_exc()\n",
        "\n",
        "\n",
        "# Save PKL\n",
        "with open(TEST_RESULTS_PKL, 'wb') as f:\n",
        "    pickle.dump(all_test_results, f)\n",
        "\n",
        "# Column order, presumably intended for a CSV/table export of all_test_results.\n",
        "# NOTE(review): not referenced anywhere else in this cell.\n",
        "fieldnames = [\n",
        "    'config_key',\n",
        "    'best_val_loss',\n",
        "    'fine_threshold',\n",
        "    'coarse_threshold',\n",
        "    'test_loss',\n",
        "    'test_accuracy_fine',\n",
        "    'test_accuracy_coarse',\n",
        "    'test_precision_fine',\n",
        "    'test_recall_fine',\n",
        "    'test_f1_fine',\n",
        "    'test_precision_coarse',\n",
        "    'test_recall_coarse',\n",
        "    'test_f1_coarse',\n",
        "    'test_logical_consistency',\n",
        "    'test_ece_fine',\n",
        "    'test_ece_coarse',\n",
        "    'test_entropy_fine',\n",
        "    'test_entropy_coarse',\n",
        "    'best_model_path',\n",
        "]\n",
        "\n",
        "\n",
        "# Summary\n",
        "def sort_key(r):\n",
        "    \"\"\"Sort key for result rows.\n",
        "\n",
        "    Rows whose best_val_loss converts to a non-NaN float (Python/NumPy floats,\n",
        "    ints, ...) come first, ascending by that loss; all other rows follow,\n",
        "    ordered by descending test_f1_fine (missing F1 counts as 0.0).\n",
        "    \"\"\"\n",
        "    try:\n",
        "        bvl = float(r.get('best_val_loss', float('nan')))\n",
        "    except (TypeError, ValueError):\n",
        "        bvl = float('nan')\n",
        "    if math.isnan(bvl):\n",
        "        return (1, -r.get('test_f1_fine', 0.0))  # fallback\n",
        "    return (0, bvl)\n",
        "\n",
        "all_test_results_sorted = sorted(all_test_results, key=sort_key)\n",
        "print(\"\\nTest Summary:\\n\")\n",
        "for r in all_test_results_sorted:\n",
        "    # One line per (config, thresholds): header + val_loss, then ' | '-joined metrics.\n",
        "    summary_parts = [\n",
        "        f\"{r['config_key']:<30} \"\n",
        "        f\"(ft={r['fine_threshold']:.2f}, ct={r['coarse_threshold']:.2f}) \"\n",
        "        f\"val_loss={r.get('best_val_loss', float('nan')):>8.5f}\",\n",
        "        f\"FineAcc={r['test_accuracy_fine']:.4f}\",\n",
        "        f\"CoarseAcc={r['test_accuracy_coarse']:.4f}\",\n",
        "        f\"FineF1={r['test_f1_fine']:.4f}\",\n",
        "        f\"CoarseF1={r['test_f1_coarse']:.4f}\",\n",
        "        f\"FineP/R={r['test_precision_fine']:.4f}/{r['test_recall_fine']:.4f}\",\n",
        "        f\"CoarseP/R={r['test_precision_coarse']:.4f}/{r['test_recall_coarse']:.4f}\",\n",
        "        f\"ECE(F/C)={r['test_ece_fine']:.4f}/{r['test_ece_coarse']:.4f}\",\n",
        "        f\"Entropy(F/C)={r['test_entropy_fine']:.4f}/{r['test_entropy_coarse']:.4f}\",\n",
        "    ]\n",
        "    print(\" | \".join(summary_parts))\n",
        "\n",
        "print(f\"\\nSaved per-config test results to:\\n- {TEST_RESULTS_PKL}\")\n"
      ],
      "metadata": {
        "id": "_mkoUbxWYB3y"
      },
      "execution_count": null,
      "outputs": []
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "A100",
      "machine_shape": "hm",
      "provenance": [],
      "collapsed_sections": [
        "0kG67rJjfcm5",
        "iWS7FQeNcKJX",
        "v44ThzrmvaYz",
        "zlZ2WmHrvsZ7"
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "7086ccebb8de40389de97f4bfd29ae0e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_9539b0594a2744b3aa999ba7e348ede3",
              "IPY_MODEL_0274c22b63d941b3ad3fd545f79e058c",
              "IPY_MODEL_d62f3f4a0edb4759b9ea932db60dc147"
            ],
            "layout": "IPY_MODEL_6086d0e734a74e5788f8035ade29b570"
          }
        },
        "9539b0594a2744b3aa999ba7e348ede3": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_1584e13fbad841ab877b0306a9bb9bc6",
            "placeholder": "​",
            "style": "IPY_MODEL_781a32bbb5244b14a8525af16e9590cf",
            "value": "model.safetensors: 100%"
          }
        },
        "0274c22b63d941b3ad3fd545f79e058c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_71a375e09243424a913ce6f5d1f583e8",
            "max": 352685652,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_93493f3755db4beeb32b9a4cbc14f17e",
            "value": 352685652
          }
        },
        "d62f3f4a0edb4759b9ea932db60dc147": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_25852e7661ac4b62966e3d945d002033",
            "placeholder": "​",
            "style": "IPY_MODEL_06c772e5438b440faa9adf7984a068f1",
            "value": " 353M/353M [00:04&lt;00:00, 88.3MB/s]"
          }
        },
        "6086d0e734a74e5788f8035ade29b570": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "1584e13fbad841ab877b0306a9bb9bc6": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "781a32bbb5244b14a8525af16e9590cf": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "71a375e09243424a913ce6f5d1f583e8": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "93493f3755db4beeb32b9a4cbc14f17e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "25852e7661ac4b62966e3d945d002033": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "06c772e5438b440faa9adf7984a068f1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        }
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}