{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.preprocessing import LabelEncoder, StandardScaler, OneHotEncoder\n",
    "from sklearn.impute import SimpleImputer\n",
    "from sklearn.compose import ColumnTransformer\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import accuracy_score\n",
    "from math import ceil, log2\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "import openml\n",
    "from conformal.conformal import ConformalPredictor, ConformalRankingPredictor\n",
    "from models.ranking_resnet import LabelRankingResnet\n",
    "from models.classifier_resnet import ClassifierResnet\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def visualize_pair(img1, class1, img2, class2, class_names, prob1=None, prob2=None):\n",
    "    \"\"\"\n",
    "    Visualize a comparison pair, including the images, their labels, and probabilities.\n",
    "\n",
    "    :param img1: first image; must support ``permute(1, 2, 0)`` to go from (C, H, W) to (H, W, C)\n",
    "    :param class1: class index of the first image, used to index ``class_names``\n",
    "    :param img2: second image, same layout as ``img1``\n",
    "    :param class2: class index of the second image, used to index ``class_names``\n",
    "    :param class_names: sequence mapping a class index to a readable name\n",
    "    :param prob1: optional probability appended to the first title\n",
    "    :param prob2: optional probability appended to the second title\n",
    "    \"\"\"\n",
    "    fig, axes = plt.subplots(1, 2, figsize=(8, 4))\n",
    "    panels = ((img1, class1, prob1), (img2, class2, prob2))\n",
    "    for ax, (img, cls, prob) in zip(axes, panels):\n",
    "        ax.imshow(img.permute(1, 2, 0))  # Convert from (C, H, W) to (H, W, C)\n",
    "        title = f\"Class {cls} ({class_names[cls]})\"\n",
    "        # Compare against None so that a probability of exactly 0.0 is still shown.\n",
    "        if prob is not None:\n",
    "            title += f\"\\nProb: {prob:.2f}\"\n",
    "        ax.set_title(title)\n",
    "        ax.axis('off')\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_rowwise_pairs(probs, i_threshold=None, j_threshold=None):\n",
    "    result = []\n",
    "    for row_idx, row in enumerate(probs):\n",
    "        i_indices, j_indices = np.where(row[:, None] > row)\n",
    "        if i_threshold is not None:\n",
    "            mask_i = row[i_indices] > i_threshold\n",
    "            i_indices, j_indices = i_indices[mask_i], j_indices[mask_i]\n",
    "        if j_threshold is not None:\n",
    "            mask_j = row[j_indices] > j_threshold\n",
    "            i_indices, j_indices = i_indices[mask_j], j_indices[mask_j]\n",
    "        pairs = np.column_stack((np.full(i_indices.shape, row_idx), i_indices, j_indices))\n",
    "        result.append(pairs)\n",
    "    try:\n",
    "        result = np.vstack(result)\n",
    "    except:\n",
    "        result = np.array([])   \n",
    "    return result\n",
    "\n",
    "def get_rowwise_pairs_with_max(matrix, j_threshold=None):\n",
    "    result = []\n",
    "    for row_idx, row in enumerate(matrix):\n",
    "        max_value = np.max(row)\n",
    "        max_ids = np.argwhere(row==max_value)\n",
    "        for max_idx in max_ids:\n",
    "            j_indices = np.where(row < max_value)[0]\n",
    "            if j_threshold is not None:\n",
    "                j_indices = j_indices[row[j_indices] > j_threshold]\n",
    "            \n",
    "            pairs = np.column_stack((np.full(j_indices.shape, row_idx),np.full(j_indices.shape, row_idx), np.full(j_indices.shape, max_idx), j_indices))\n",
    "        result.append(pairs)\n",
    "    try:\n",
    "        result = np.vstack(result)\n",
    "    except:\n",
    "        result = np.array([])\n",
    "    return result\n",
    "\n",
    "def get_cross_row_pairs(matrix):\n",
    "    num_rows, num_cols = matrix.shape\n",
    "    result = []\n",
    "    # Iterate over all pairs of rows (k, l)\n",
    "    for k in range(num_rows):\n",
    "        for l in range(num_rows):\n",
    "            if k != l:\n",
    "                # Compare all pairs of elements from row k and row l\n",
    "                i_indices, j_indices = np.where(matrix[k][:, None] > matrix[l])\n",
    "                # Combine row indices (k, l) with column indices (i, j)\n",
    "                pairs = np.column_stack((np.full(i_indices.shape, k), i_indices, np.full(j_indices.shape, l), j_indices))\n",
    "                result.append(pairs)\n",
    "    try:\n",
    "        result = np.stack(result)\n",
    "    except:\n",
    "        result = np.array([])\n",
    "    return result\n",
    "\n",
    "\n",
    "def get_cross_row_pairs_with_max(matrix):\n",
    "    \"\"\" Generates pairs between argmax classes across instances\n",
    "\n",
    "    :param matrix: _description_\n",
    "    :return: _description_\n",
    "    \"\"\"\n",
    "    result = []\n",
    "\n",
    "    max_indices = [np.where(row == row.max())[0] for row in matrix]\n",
    "\n",
    "    pairs = []\n",
    "\n",
    "    for i in range(len(matrix)):\n",
    "        for j in range(len(matrix)):\n",
    "            if i != j:\n",
    "                # Compare all combinations of maxima indices between row i and row j\n",
    "                for col_i in max_indices[i]:\n",
    "                    for col_j in max_indices[j]:\n",
    "                        if matrix[i, col_i] > matrix[j, col_j]:\n",
    "                            pairs.append((i, j, col_i, col_j))\n",
    "    try:\n",
    "        result = np.stack(np.array(pairs))\n",
    "    except:\n",
    "        result = np.array([])\n",
    "\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset\n",
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torchvision import models\n",
    "from torch.utils.data import Subset\n",
    "\n",
    "class PairwiseCIFAR10H(Dataset):\n",
    "\n",
    "    def sample_rows(array, sample):\n",
    "        if isinstance(sample, float):  # Fraction of rows\n",
    "            num_rows = int(sample * array.shape[0])\n",
    "        elif isinstance(sample, int):  # Number of rows\n",
    "            num_rows = sample\n",
    "        else:\n",
    "            raise ValueError(\"Sample must be a float (fraction) or int (number).\")\n",
    "        \n",
    "        sampled_indices = np.random.choice(array.shape[0], size=num_rows, replace=False)\n",
    "        return array[sampled_indices]\n",
    "\n",
    "    def __init__(self, dataset, probs, in_instance_pairs=1.0, cross_instance_pairs=1.0):\n",
    "\n",
    "        self.dataset = dataset\n",
    "        self.probs = probs\n",
    "        print(\"Generating in-instance pairs:\")\n",
    "        in_instance_pairs = get_rowwise_pairs_with_max(self.probs)\n",
    "        print(\"Generating cross-instance pairs:\")\n",
    "        cross_instance_pairs = get_cross_row_pairs_with_max(self.probs)\n",
    "        self.pair_indices = np.vstack([in_instance_pairs, cross_instance_pairs])\n",
    "\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.pair_indices)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        img_a_idx, img_b_idx, label_a, label_b = self.pair_indices[idx]\n",
    "        img_a, ground_truth_a = self.dataset[img_a_idx]\n",
    "        img_b, ground_truth_b = self.dataset[img_b_idx]\n",
    "        return img_a, label_a, img_b, label_b, img_a_idx, img_b_idx, label_a, label_b\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Demo: build the pairwise dataset for the first five CIFAR-10 test images and plot every pair.\n",
    "# Normalization is left disabled in this transform (see the commented-out line),\n",
    "# presumably so that imshow receives values in [0, 1] -- confirm.\n",
    "transform = transforms.Compose([\n",
    "    transforms.Resize((224, 224)),  # ResNet expects 224x224 input\n",
    "    transforms.ToTensor(),\n",
    "    # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
    "])\n",
    "# train=False selects the CIFAR-10 test split.\n",
    "dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)\n",
    "\n",
    "class_names = ('plane', 'car', 'bird', 'cat',\n",
    "           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n",
    "# NOTE(review): presumably one row of label probabilities per test image, in the same\n",
    "# order as `dataset` -- verify against the file's origin.\n",
    "probs = np.load(\"data/cifar10h-probs.npy\")\n",
    "\n",
    "# Keep the first five images together with their five probability rows.\n",
    "subset = Subset(dataset, range(0,5))\n",
    "subset_probs = probs[0:5]\n",
    "\n",
    "pair_data = PairwiseCIFAR10H(subset, subset_probs)\n",
    "\n",
    "# Plot each pair; the title shows the probability of the paired label on its own image.\n",
    "for x1,l1,x2,l2, img_a_idx, img_b_idx, label_a, label_b in pair_data:\n",
    "    prob1 = subset_probs[img_a_idx][label_a]\n",
    "    prob2 = subset_probs[img_b_idx][label_b]\n",
    "    visualize_pair(x1, l1, x2, l2, class_names, prob1, prob2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import ceil\n",
    "\n",
    "# Set device (use GPU if available)\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Step 1: Load the CIFAR-10 dataset\n",
    "transform = transforms.Compose([\n",
    "    transforms.Resize((224, 224)),  # ResNet expects 224x224 input\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize to [-1, 1]\n",
    "])\n",
    "\n",
    "batch_size = 64\n",
    "\n",
    "# Training and test data loaders\n",
    "trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)\n",
    "\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)\n",
    "# First 1000 images of the training split, served by `subset_loader`.\n",
    "subset = Subset(trainset, indices=range(0,1000))\n",
    "\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)\n",
    "# The shuffling generator has to live on the device torch currently uses for new tensors.\n",
    "# A hard-coded 'cuda' generator fails on CPU-only machines and whenever the default\n",
    "# device is the CPU, so derive the device from a freshly created tensor instead.\n",
    "subset_loader = torch.utils.data.DataLoader(dataset=subset, batch_size=batch_size, shuffle=True, generator=torch.Generator(device=torch.empty(0).device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "model_clf = models.resnet18(pretrained=True)\n",
    "num_ftrs = model_clf.fc.in_features\n",
    "\n",
    "# Replace the final fully connected layer with a 10-way output layer.\n",
    "model_clf.fc = nn.Linear(num_ftrs, 10)\n",
    "model_clf = model_clf.to(device)\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model_clf.parameters(), lr=0.001)\n",
    "\n",
    "epochs = 30\n",
    "for epoch in range(epochs):\n",
    "    model_clf.train()\n",
    "    running_loss = 0.0\n",
    "    # NOTE(review): fine-tuning iterates `testloader` (the CIFAR-10 test split, unshuffled);\n",
    "    # presumably intentional because the CIFAR-10H labels refer to that split -- confirm.\n",
    "    for inputs, labels in testloader:\n",
    "        inputs, labels = inputs.to(device), labels.to(device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        outputs = model_clf(inputs)\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        running_loss += loss.item()\n",
    "\n",
    "    # Average over the batches that were actually iterated (testloader, not trainloader).\n",
    "    print(f'Epoch [{epoch + 1}/{epochs}], Loss: {running_loss / len(testloader):.4f}')\n",
    "\n",
    "save_path = \"./finetuned_models/clf_cifar10h.pth\"\n",
    "# torch.save does not create missing directories; create the folder first so that a\n",
    "# finished training run is not lost to a missing-path error.\n",
    "os.makedirs(os.path.dirname(save_path), exist_ok=True)\n",
    "torch.save(model_clf.state_dict(), save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy of the fine-tuned classifier on the batches served by `subset_loader`.\n",
    "model_clf.eval()\n",
    "correct, total = 0, 0\n",
    "\n",
    "# No gradients are needed for evaluation.\n",
    "with torch.no_grad():\n",
    "    for inputs, labels in subset_loader:\n",
    "        inputs = inputs.to(device)\n",
    "        labels = labels.to(device)\n",
    "        outputs = model_clf(inputs)\n",
    "        # Predicted class = position of the largest output in each row.\n",
    "        predicted = torch.max(outputs, 1).indices\n",
    "        total += labels.size(0)\n",
    "        correct += (predicted == labels).sum().item()\n",
    "\n",
    "print(f'Test Accuracy: {100 * correct / total:.2f}%')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.set_default_device(\"cuda\")\n",
    "# pairset = PairwiseCIFAR10H(testset,probs)\n",
    "# print(len(pairset))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: `pairset` is only created by the commented-out cell above\n",
    "# (`pairset = PairwiseCIFAR10H(testset,probs)`); re-enable and run it before this cell.\n",
    "batch_size_scale = len(pairset) / len(testset)\n",
    "from torchvision.models import resnet18\n",
    "from torch.optim import Adam\n",
    "\n",
    "batch_size = 64\n",
    "# batch_size = int(batch_size_scale*batch_size)\n",
    "# Match the shuffling generator to torch's current default device instead of hard-coding\n",
    "# 'cuda', which fails on CPU-only machines and when the default device is the CPU.\n",
    "pairset_loader = torch.utils.data.DataLoader(dataset=pairset, batch_size=batch_size, shuffle=True, generator=torch.Generator(device=torch.empty(0).device))\n",
    "\n",
    "rnk_model = resnet18(pretrained=True)\n",
    "num_ftrs = rnk_model.fc.in_features\n",
    "# NOTE(review): `num_ftrs` was computed but never used, while the evaluation cell compares\n",
    "# the argmax of the outputs with 10-class labels; a 10-way head like the classifier's is\n",
    "# assumed here -- confirm this is the intended ranking head.\n",
    "rnk_model.fc = nn.Linear(num_ftrs, 10)\n",
    "# Follow the `device` chosen earlier rather than forcing CUDA.\n",
    "rnk_model = rnk_model.to(device)\n",
    "\n",
    "optimizer = Adam(rnk_model.parameters(), lr=0.001)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Accuracy of `rnk_model`, used as a plain classifier, on the batches of `subset_loader`.\n",
    "rnk_model.eval()\n",
    "correct = 0\n",
    "total = 0\n",
    "# Fall back to the CPU when no GPU is available instead of hard-coding \"cuda\".\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "with torch.no_grad():\n",
    "    for inputs, labels in subset_loader:\n",
    "        inputs, labels = inputs.to(device), labels.to(device)\n",
    "        outputs = rnk_model(inputs)\n",
    "        # Predicted class = position of the largest output in each row.\n",
    "        _, predicted = torch.max(outputs, 1)\n",
    "        total += labels.size(0)\n",
    "        correct += (predicted == labels).sum().item()\n",
    "\n",
    "print(f'Test Accuracy: {100 * correct / total:.2f}%')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show how many items `pairset` holds. NOTE(review): `pairset` is only assigned in a\n",
    "# commented-out cell of this notebook, so this fails on a fresh kernel -- confirm where it should come from.\n",
    "len(pairset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "plnet",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
