{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.preprocessing import LabelEncoder, StandardScaler, OneHotEncoder\n",
    "from sklearn.impute import SimpleImputer\n",
    "from sklearn.compose import ColumnTransformer\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import accuracy_score\n",
    "from math import ceil, log2\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "import openml\n",
    "from simple_model import  ConformalPredictor, ConformalRankingPredictor, ClassifierModel, LabelRankingModel\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_models(dataset_id, alpha=0.05, random_state=1):\n",
    "    \"\"\"Fit three conformal predictors on one OpenML dataset and print their test metrics.\n",
    "\n",
    "    A random forest, a ClassifierModel and a LabelRankingModel are each wrapped in a\n",
    "    ConformalPredictor, fitted on an 80% training split and evaluated on the remaining 20%.\n",
    "\n",
    "    Args:\n",
    "        dataset_id: OpenML dataset id; its default target attribute is used as the label.\n",
    "        alpha: Value passed as `alpha` to every ConformalPredictor.\n",
    "        random_state: Seed for the train/test split, the random forest and the network fits.\n",
    "    \"\"\"\n",
    "    dataset = openml.datasets.get_dataset(dataset_id)\n",
    "    X, y, _, _ = dataset.get_data(\n",
    "        target=dataset.default_target_attribute, dataset_format=\"dataframe\"\n",
    "    )\n",
    "\n",
    "    # Every numeric dtype is scaled (not only int64/float64) and every remaining column is\n",
    "    # one-hot encoded, so the ColumnTransformer cannot silently drop a feature.\n",
    "    numerical_features = X.select_dtypes(include=\"number\").columns.tolist()\n",
    "    categorical_features = [col for col in X.columns if col not in numerical_features]\n",
    "\n",
    "    # Encode on the full target: a class missing from the training split would otherwise\n",
    "    # make `transform` fail on the test split and leave the models one output short.\n",
    "    le = LabelEncoder()\n",
    "    y_encoded = le.fit_transform(y)\n",
    "    num_classes = len(le.classes_)\n",
    "\n",
    "    X_train, X_test, y_train, y_test = train_test_split(\n",
    "        X, y_encoded, test_size=0.2, random_state=random_state\n",
    "    )\n",
    "\n",
    "    # Preprocessing for numerical data: impute missing values, then scale\n",
    "    numerical_transformer = Pipeline(\n",
    "        steps=[\n",
    "            (\"imputer\", SimpleImputer(strategy=\"mean\")),\n",
    "            (\"scaler\", StandardScaler()),\n",
    "        ]\n",
    "    )\n",
    "\n",
    "    # Preprocessing for categorical data: impute missing values, then one-hot encode\n",
    "    categorical_transformer = Pipeline(\n",
    "        steps=[\n",
    "            (\"imputer\", SimpleImputer(strategy=\"most_frequent\")),\n",
    "            (\"onehot\", OneHotEncoder(handle_unknown=\"ignore\")),\n",
    "        ]\n",
    "    )\n",
    "\n",
    "    # sparse_threshold=0 makes the transformer always return a dense array.\n",
    "    preprocessor = ColumnTransformer(\n",
    "        transformers=[\n",
    "            (\"num\", numerical_transformer, numerical_features),\n",
    "            (\"cat\", categorical_transformer, categorical_features),\n",
    "        ],\n",
    "        sparse_threshold=0,\n",
    "    )\n",
    "\n",
    "    X_train = preprocessor.fit_transform(X_train)\n",
    "    X_test = preprocessor.transform(X_test)\n",
    "\n",
    "    input_dim = X_train.shape[1]\n",
    "    rf = RandomForestClassifier(random_state=random_state)\n",
    "    clf = ClassifierModel(input_dim=input_dim, hidden_dim=16, output_dim=num_classes)\n",
    "    rank = LabelRankingModel(input_dim=input_dim, hidden_dim=16, output_dim=num_classes)\n",
    "\n",
    "    batch_size_clf = 32\n",
    "    val_frac = 0.2\n",
    "    cal_frac = 1 / 3\n",
    "    num_epochs = 35\n",
    "    learning_rate = 0.01\n",
    "\n",
    "    cp_net = ConformalPredictor(clf, alpha=alpha)\n",
    "    cp_rf = ConformalPredictor(rf, alpha=alpha)\n",
    "    # NOTE(review): ConformalRankingPredictor is imported but never used - confirm that the\n",
    "    # ranking model is meant to be wrapped in the plain ConformalPredictor.\n",
    "    cp_rank = ConformalPredictor(rank, alpha=alpha)\n",
    "\n",
    "    # Both networks are trained with identical settings; patience == num_epochs means\n",
    "    # early stopping can never trigger before the last epoch.\n",
    "    net_fit_kwargs = dict(\n",
    "        num_epochs=num_epochs,\n",
    "        random_state=random_state,\n",
    "        patience=num_epochs,\n",
    "        batch_size=batch_size_clf,\n",
    "        val_frac=val_frac,\n",
    "        cal_size=cal_frac,\n",
    "        learning_rate=learning_rate,\n",
    "    )\n",
    "    cp_net.fit(X_train, y_train, **net_fit_kwargs)\n",
    "    cp_rank.fit(X_train, y_train, **net_fit_kwargs)\n",
    "    cp_rf.fit(X_train, y_train, cal_size=cal_frac)\n",
    "\n",
    "    def evaluate_method(method):\n",
    "        \"\"\"Print accuracy, coverage and mean prediction-set size of one fitted predictor.\"\"\"\n",
    "        pred_sets = method.predict_set(X_test)\n",
    "        y_pred = method.model.predict(X_test)\n",
    "        num_test = len(y_test)\n",
    "        # coverage: share of test points whose true label lies in the predicted set\n",
    "        coverage = np.mean([y_test[i] in pred_sets[i] for i in range(num_test)])\n",
    "        # efficiency: average size of the predicted sets\n",
    "        efficiency = np.mean([len(pred_sets[i]) for i in range(num_test)])\n",
    "        print(f\"Accuracy {accuracy_score(y_test, y_pred)}\")\n",
    "        print(f\"Coverage {coverage} efficiency {efficiency}\")\n",
    "\n",
    "    print(\"Random Forest\")\n",
    "    evaluate_method(cp_rf)\n",
    "    print(\"\\nClassifier Network\")\n",
    "    evaluate_method(cp_net)\n",
    "    print(\"\\nRanking Network\")\n",
    "    evaluate_method(cp_rank)\n",
    "\n",
    "\n",
    "for dataset_id in [61, 187, 15, 31, 4534, 1461]:\n",
    "    print(f\"\\n\\nDataset: {dataset_id}\")\n",
    "    evaluate_models(dataset_id)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from simple_model import LabelPairDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import load_iris\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# A fixed seed keeps the split, and therefore every number computed from it, reproducible\n",
    "# across kernel restarts.\n",
    "X, y = load_iris(return_X_y=True)\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=1/3, random_state=1)\n",
    "num_classes = len(np.unique(y))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from simple_model import LabelRankingModel\n",
    "\n",
    "# Size the model from `num_classes` (computed from the labels) instead of a hard-coded 3.\n",
    "# patience == num_epochs, so training always runs for the full 250 epochs.\n",
    "model = LabelRankingModel(input_dim=X_train.shape[1], hidden_dim=16, output_dim=num_classes)\n",
    "model.fit(X_train, y_train, num_classes=num_classes, random_state=1, batch_size=32, num_epochs=250, patience=250)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss curves recorded on the fitted model; the x axis is the position in each list.\n",
    "for losses, curve_label in (\n",
    "    (model.val_losses, \"validation loss\"),\n",
    "    (model.train_losses, \"train loss\"),\n",
    "):\n",
    "    plt.plot(np.arange(len(losses)), losses, label=curve_label)\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_pred = model.predict(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# scikit-learn metrics take (y_true, y_pred) in that order.\n",
    "accuracy_score(y_test, y_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torchvision import models\n",
    "\n",
    "# Set device (use GPU if available)\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Load the CIFAR-10 dataset\n",
    "transform = transforms.Compose([\n",
    "    transforms.Resize((224, 224)),  # ResNet expects 224x224 input\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize to [-1, 1]\n",
    "])\n",
    "\n",
    "batch_size = 64\n",
    "\n",
    "# Training and test data loaders (only the training loader is shuffled)\n",
    "trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def imshow(img):\n",
    "    \"\"\"Display a (C, H, W) image tensor that was normalised with mean 0.5 and std 0.5.\"\"\"\n",
    "    img = img / 2 + 0.5  # Unnormalize the image (reverse the normalization)\n",
    "    npimg = img.numpy()  # Convert to numpy array\n",
    "    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # Transpose to (H, W, C)\n",
    "    plt.axis('off')  # Turn off axis\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# Get one batch of training images and labels\n",
    "dataiter = iter(trainloader)\n",
    "images, labels = next(dataiter)\n",
    "\n",
    "# Define CIFAR-10 class names\n",
    "classes = ('plane', 'car', 'bird', 'cat',\n",
    "           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n",
    "\n",
    "# Show only as many images as labels are printed, so the grid and the caption line up.\n",
    "num_shown = min(4, len(images))\n",
    "imshow(torchvision.utils.make_grid(images[:num_shown]))\n",
    "\n",
    "# Print the labels of exactly the images shown above\n",
    "print(' '.join(f'{classes[labels[j]]}' for j in range(num_shown)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "probs = np.load(\"data/cifar10h-probs.npy\")\n",
    "counts = np.load(\"data/cifar10h-counts.npy\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_pic_and_probs(k):\n",
    "    \"\"\"Show test image `k` followed by a table of row `k` of `counts` and `probs`.\n",
    "\n",
    "    Relies on the notebook globals `classes`, `counts`, `probs`, `testset` and `imshow`.\n",
    "    \"\"\"\n",
    "    columns = {\n",
    "        'classes': classes,\n",
    "        'counts': counts[k],\n",
    "        'probabilities': probs[k],\n",
    "    }\n",
    "    table = pd.DataFrame(columns)\n",
    "    picture = testset[k][0]  # each dataset item is an (image, label) tuple\n",
    "    imshow(picture)\n",
    "    display(table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "probs.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row-wise Shannon entropy in bits; epsilon keeps log2 finite where a probability is 0.\n",
    "epsilon = 1e-10\n",
    "weighted_logs = probs * np.log2(probs + epsilon)\n",
    "entropy = -np.sum(weighted_logs, axis=1)\n",
    "\n",
    "# Indices of the five highest-entropy rows, largest entropy first.\n",
    "descending = np.argsort(entropy)[::-1]\n",
    "high_confusion = descending[:5]\n",
    "high_confusion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import Dataset\n",
    "from torchvision.datasets import CIFAR10\n",
    "from torchvision import transforms\n",
    "import numpy as np\n",
    "import itertools\n",
    "\n",
    "\n",
    "class CIFAR10SoftLabelComparisonDataset(Dataset):\n",
    "    \"\"\"Dataset of probability comparisons derived from per-image soft labels.\n",
    "\n",
    "    Every item states that ``class1`` on ``img1`` has a strictly larger soft-label\n",
    "    probability than ``class2`` on ``img2``. In-instance items use the same image twice,\n",
    "    cross-instance items use two different images.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, cifar_dataset, soft_labels, num_pairs=1000, in_instance_ratio=0.5, seed=None):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            cifar_dataset: Indexable dataset whose items are (image, label) tuples.\n",
    "            soft_labels (numpy.ndarray): Array of shape (N, num_classes) with soft labels.\n",
    "            num_pairs (int): Total number of pairs requested.\n",
    "            in_instance_ratio (float): Ratio of in-instance comparisons (between 0 and 1).\n",
    "            seed (int, optional): Seed for the pair sampling. With the default ``None`` a\n",
    "                different set of pairs is drawn on every construction.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If ``in_instance_ratio`` is outside [0, 1] or ``soft_labels`` has\n",
    "                fewer rows than ``cifar_dataset`` has samples.\n",
    "        \"\"\"\n",
    "        if not 0.0 <= in_instance_ratio <= 1.0:\n",
    "            raise ValueError('in_instance_ratio must lie in [0, 1]')\n",
    "        if len(soft_labels) < len(cifar_dataset):\n",
    "            raise ValueError('soft_labels needs one row per sample of cifar_dataset')\n",
    "\n",
    "        self.cifar_dataset = cifar_dataset\n",
    "        self.soft_labels = torch.tensor(soft_labels, dtype=torch.float32)  # Convert to torch tensor\n",
    "        self.num_samples = len(cifar_dataset)\n",
    "        self.num_classes = soft_labels.shape[1]\n",
    "        self.num_pairs = num_pairs\n",
    "        self.in_instance_ratio = in_instance_ratio\n",
    "        self._rng = np.random.default_rng(seed)\n",
    "\n",
    "        # Requested split; the generated lists can be shorter when too few valid pairs exist.\n",
    "        self.num_in_instance = int(num_pairs * in_instance_ratio)\n",
    "        self.num_cross_instance = num_pairs - self.num_in_instance\n",
    "\n",
    "        # Precompute the pairs once so that indexing is stable for the dataset's lifetime\n",
    "        self.in_instance_pairs = self._generate_in_instance_pairs()\n",
    "        self.cross_instance_pairs = self._generate_cross_instance_pairs()\n",
    "\n",
    "    def _probs(self):\n",
    "        \"\"\"Soft labels of the dataset's samples as a float32 NumPy array.\"\"\"\n",
    "        return self.soft_labels[: self.num_samples].numpy()\n",
    "\n",
    "    def _choose(self, num_candidates, num_wanted):\n",
    "        \"\"\"Positions of at most ``num_wanted`` distinct candidates, drawn without replacement.\"\"\"\n",
    "        take = min(num_candidates, num_wanted)\n",
    "        if take <= 0:\n",
    "            return np.empty(0, dtype=np.intp)\n",
    "        return self._rng.choice(num_candidates, size=take, replace=False)\n",
    "\n",
    "    def _generate_in_instance_pairs(self):\n",
    "        \"\"\"\n",
    "        Generate in-instance pairs: (image, class1, class2) such that prob(class1|image) > prob(class2|image).\n",
    "        Returns:\n",
    "            list: List of (idx, class1, class2, prob1, prob2).\n",
    "        \"\"\"\n",
    "        probs = self._probs()\n",
    "        # larger[n, a, b] is True when class a is strictly more probable than class b on\n",
    "        # image n; strictness already excludes a == b.\n",
    "        larger = probs[:, :, None] > probs[:, None, :]\n",
    "        idx, class1, class2 = np.nonzero(larger)\n",
    "        chosen = self._choose(idx.size, self.num_in_instance)\n",
    "        return [\n",
    "            (int(n), int(a), int(b), float(probs[n, a]), float(probs[n, b]))\n",
    "            for n, a, b in zip(idx[chosen], class1[chosen], class2[chosen])\n",
    "        ]\n",
    "\n",
    "    def _generate_cross_instance_pairs(self):\n",
    "        \"\"\"\n",
    "        Generate cross-instance pairs: (image1, class1, image2, class2) where prob(class1|image1) > prob(class2|image2).\n",
    "        Returns:\n",
    "            list: List of (idx1, class1, idx2, class2, prob1, prob2).\n",
    "        \"\"\"\n",
    "        if self.num_samples < 2 or self.num_cross_instance <= 0:\n",
    "            return []\n",
    "\n",
    "        probs = self._probs()\n",
    "        # Draw more image pairs than needed; every valid class combination is a candidate.\n",
    "        num_image_pairs = 2 * self.num_samples\n",
    "        first = self._rng.integers(0, self.num_samples, size=num_image_pairs)\n",
    "        # A non-zero offset modulo num_samples guarantees second != first.\n",
    "        offset = self._rng.integers(1, self.num_samples, size=num_image_pairs)\n",
    "        second = (first + offset) % self.num_samples\n",
    "\n",
    "        larger = probs[first][:, :, None] > probs[second][:, None, :]\n",
    "        pos, class1, class2 = np.nonzero(larger)\n",
    "        chosen = self._choose(pos.size, self.num_cross_instance)\n",
    "        idx1 = first[pos[chosen]]\n",
    "        idx2 = second[pos[chosen]]\n",
    "        return [\n",
    "            (int(m), int(a), int(n), int(b), float(probs[m, a]), float(probs[n, b]))\n",
    "            for m, a, n, b in zip(idx1, class1[chosen], idx2, class2[chosen])\n",
    "        ]\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of pairs actually generated (at most ``num_pairs``).\"\"\"\n",
    "        return len(self.in_instance_pairs) + len(self.cross_instance_pairs)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"\n",
    "        In-instance pairs come first, followed by the cross-instance pairs.\n",
    "\n",
    "        Returns:\n",
    "            img1 (Tensor): First image in the comparison.\n",
    "            class1 (int): Class index for the first image.\n",
    "            prob1 (float): Probability of the class for the first image.\n",
    "            img2 (Tensor): Second image in the comparison.\n",
    "            class2 (int): Class index for the second image.\n",
    "            prob2 (float): Probability of the class for the second image.\n",
    "\n",
    "        Raises:\n",
    "            IndexError: If ``idx`` is outside the dataset.\n",
    "        \"\"\"\n",
    "        size = len(self)\n",
    "        if idx < 0:\n",
    "            idx += size\n",
    "        if not 0 <= idx < size:\n",
    "            raise IndexError('pair index out of range')\n",
    "\n",
    "        num_in = len(self.in_instance_pairs)\n",
    "        if idx < num_in:\n",
    "            # Return an in-instance pair: the same image appears on both sides\n",
    "            sample_idx, class1, class2, prob1, prob2 = self.in_instance_pairs[idx]\n",
    "            img1, _ = self.cifar_dataset[sample_idx]\n",
    "            return img1, class1, prob1, img1, class2, prob2\n",
    "\n",
    "        # Return a cross-instance pair\n",
    "        idx1, class1, idx2, class2, prob1, prob2 = self.cross_instance_pairs[idx - num_in]\n",
    "        img1, _ = self.cifar_dataset[idx1]\n",
    "        img2, _ = self.cifar_dataset[idx2]\n",
    "        return img1, class1, prob1, img2, class2, prob2\n",
    "\n",
    "\n",
    "# Example usage\n",
    "transform = transforms.Compose([transforms.ToTensor()])\n",
    "cifar_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)\n",
    "\n",
    "soft_labels = probs\n",
    "\n",
    "# Create the comparison dataset with 60% in-instance, 40% cross-instance pairs;\n",
    "# the fixed seed makes the selection deterministic.\n",
    "comparison_dataset = CIFAR10SoftLabelComparisonDataset(\n",
    "    cifar_dataset, soft_labels, num_pairs=1000, in_instance_ratio=0.6, seed=0\n",
    ")\n",
    "\n",
    "# Access a sample pair\n",
    "img1, class1, prob1, img2, class2, prob2 = comparison_dataset[0]\n",
    "print(f\"Image1 (Class {class1}, Prob {prob1:.2f}) vs Image2 (Class {class2}, Prob {prob2:.2f})\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def visualize_pair(img1, class1, prob1, img2, class2, prob2, class_names):\n",
    "    \"\"\"\n",
    "    Plot the two images of a comparison pair side by side.\n",
    "\n",
    "    Each panel is titled with the class index, its name looked up in `class_names`,\n",
    "    and the probability formatted to two decimals.\n",
    "    \"\"\"\n",
    "    fig, axes = plt.subplots(1, 2, figsize=(8, 4))\n",
    "\n",
    "    panels = ((img1, class1, prob1), (img2, class2, prob2))\n",
    "    for panel, (img, cls, prob) in zip(axes, panels):\n",
    "        panel.imshow(img.permute(1, 2, 0))  # (C, H, W) -> (H, W, C)\n",
    "        panel.set_title(f\"Class {cls} ({class_names[cls]})\\nProb: {prob:.2f}\")\n",
    "        panel.axis('off')\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "class_names = ('plane', 'car', 'bird', 'cat',\n",
    "           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n",
    "\n",
    "batch_size = 1\n",
    "loader = iter(DataLoader(comparison_dataset, batch_size=batch_size, shuffle=True))\n",
    "\n",
    "\n",
    "# Draw five shuffled batches and plot every pair contained in each of them.\n",
    "for _ in range(5):\n",
    "    img1_b, class1_b, prob1_b, img2_b, class2_b, prob2_b = next(loader)\n",
    "    i = 0\n",
    "    while i < batch_size:\n",
    "        img1, class1, prob1 = img1_b[i], class1_b[i], prob1_b[i]\n",
    "        img2, class2, prob2 = img2_b[i], class2_b[i], prob2_b[i]\n",
    "        visualize_pair(img1, class1, prob1, img2, class2, prob2, class_names)\n",
    "        i += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _stack_pairs(chunks, width):\n",
    "    \"\"\"Stack index arrays vertically; with no chunks return an empty (0, width) int array.\"\"\"\n",
    "    if not chunks:\n",
    "        return np.empty((0, width), dtype=np.intp)\n",
    "    return np.vstack(chunks)\n",
    "\n",
    "\n",
    "def get_rowwise_pairs(probs, i_threshold=None, j_threshold=None):\n",
    "    \"\"\"Return every (row, i, j) with probs[row, i] > probs[row, j].\n",
    "\n",
    "    Args:\n",
    "        probs: 2-D array-like; each row is compared only with itself.\n",
    "        i_threshold: If given, keep only pairs whose larger value exceeds it.\n",
    "        j_threshold: If given, keep only pairs whose smaller value exceeds it.\n",
    "\n",
    "    Returns:\n",
    "        Integer array of shape (num_pairs, 3) with columns (row, i, j).\n",
    "    \"\"\"\n",
    "    chunks = []\n",
    "    for row_idx, row in enumerate(probs):\n",
    "        row = np.asarray(row)\n",
    "        i_indices, j_indices = np.nonzero(row[:, None] > row[None, :])\n",
    "        keep = np.ones(i_indices.shape, dtype=bool)\n",
    "        if i_threshold is not None:\n",
    "            keep &= row[i_indices] > i_threshold\n",
    "        if j_threshold is not None:\n",
    "            keep &= row[j_indices] > j_threshold\n",
    "        i_indices, j_indices = i_indices[keep], j_indices[keep]\n",
    "        chunks.append(np.column_stack((np.full(i_indices.shape, row_idx), i_indices, j_indices)))\n",
    "    return _stack_pairs(chunks, 3)\n",
    "\n",
    "\n",
    "def get_rowwise_pairs_with_max(matrix, j_threshold=None):\n",
    "    \"\"\"Return every (row, i, j) where i is a maximal entry of the row and j a smaller one.\n",
    "\n",
    "    When the maximum is tied, each tied position is paired with every smaller entry.\n",
    "\n",
    "    Args:\n",
    "        matrix: 2-D array-like.\n",
    "        j_threshold: If given, keep only pairs whose smaller value exceeds it.\n",
    "\n",
    "    Returns:\n",
    "        Integer array of shape (num_pairs, 3) with columns (row, i, j).\n",
    "    \"\"\"\n",
    "    chunks = []\n",
    "    for row_idx, row in enumerate(matrix):\n",
    "        row = np.asarray(row)\n",
    "        max_value = np.max(row)\n",
    "        j_indices = np.flatnonzero(row < max_value)\n",
    "        if j_threshold is not None:\n",
    "            j_indices = j_indices[row[j_indices] > j_threshold]\n",
    "        for max_idx in np.flatnonzero(row == max_value):\n",
    "            chunks.append(\n",
    "                np.column_stack(\n",
    "                    (np.full(j_indices.shape, row_idx), np.full(j_indices.shape, max_idx), j_indices)\n",
    "                )\n",
    "            )\n",
    "    return _stack_pairs(chunks, 3)\n",
    "\n",
    "\n",
    "def get_cross_row_pairs(matrix):\n",
    "    \"\"\"Return every (k, i, l, j) with k != l and matrix[k, i] > matrix[l, j].\n",
    "\n",
    "    All ordered pairs of rows are visited, so the work grows quadratically with the\n",
    "    number of rows.\n",
    "\n",
    "    Returns:\n",
    "        Integer array of shape (num_pairs, 4) with columns (k, i, l, j).\n",
    "    \"\"\"\n",
    "    matrix = np.asarray(matrix)\n",
    "    num_rows = matrix.shape[0]\n",
    "    chunks = []\n",
    "    for k in range(num_rows):\n",
    "        for l in range(num_rows):\n",
    "            if k == l:\n",
    "                continue\n",
    "            # Compare all pairs of elements from row k and row l\n",
    "            i_indices, j_indices = np.nonzero(matrix[k][:, None] > matrix[l][None, :])\n",
    "            chunks.append(\n",
    "                np.column_stack(\n",
    "                    (np.full(i_indices.shape, k), i_indices, np.full(j_indices.shape, l), j_indices)\n",
    "                )\n",
    "            )\n",
    "    # The per-row-pair chunks have different lengths, so they must be concatenated\n",
    "    # (np.vstack) rather than stacked along a new axis.\n",
    "    return _stack_pairs(chunks, 4)\n",
    "\n",
    "\n",
    "def get_cross_row_pairs_with_max(matrix):\n",
    "    \"\"\"Return every (k, i, l, j) where i / j are maximal entries of rows k / l and\n",
    "    matrix[k, i] > matrix[l, j].\n",
    "\n",
    "    NOTE(review): the previous version computed the maxima of both rows but raised an\n",
    "    AttributeError before using them. Comparing only the row maxima is an assumption\n",
    "    about the intended behaviour - confirm before relying on it.\n",
    "\n",
    "    Returns:\n",
    "        Integer array of shape (num_pairs, 4) with columns (k, i, l, j).\n",
    "    \"\"\"\n",
    "    matrix = np.asarray(matrix)\n",
    "    num_rows = matrix.shape[0]\n",
    "    chunks = []\n",
    "    for k in range(num_rows):\n",
    "        max_k = np.max(matrix[k])\n",
    "        top_k = np.flatnonzero(matrix[k] == max_k)\n",
    "        for l in range(num_rows):\n",
    "            if k == l:\n",
    "                continue\n",
    "            max_l = np.max(matrix[l])\n",
    "            if not max_k > max_l:\n",
    "                continue\n",
    "            top_l = np.flatnonzero(matrix[l] == max_l)\n",
    "            # Cartesian product of the tied maxima of both rows\n",
    "            i_indices = np.repeat(top_k, top_l.size)\n",
    "            j_indices = np.tile(top_l, top_k.size)\n",
    "            chunks.append(\n",
    "                np.column_stack(\n",
    "                    (np.full(i_indices.shape, k), i_indices, np.full(j_indices.shape, l), j_indices)\n",
    "                )\n",
    "            )\n",
    "    return _stack_pairs(chunks, 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pairs = get_rowwise_pairs_with_max(probs[9244:9248])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.vstack(pairs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 216,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Largest value of every row\n",
    "row_max = probs.max(axis=1)\n",
    "\n",
    "# A row's maximum is tied when more than one entry equals it\n",
    "max_hits = (probs == row_max[:, None]).sum(axis=1)\n",
    "is_not_unique = max_hits > 1\n",
    "\n",
    "# Indices of the rows whose maximum is tied\n",
    "rows_with_non_unique_max = np.flatnonzero(is_not_unique)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rows_with_non_unique_max"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def imshow(img):\n",
    "    \"\"\"Display an image tensor that was normalised with mean 0.5 and std 0.5.\n",
    "\n",
    "    Args:\n",
    "        img: Either one (C, H, W) image or a (B, C, H, W) batch; a batch is tiled into a\n",
    "            single grid before plotting.\n",
    "    \"\"\"\n",
    "    if img.dim() == 4:\n",
    "        img = torchvision.utils.make_grid(img)\n",
    "    img = img / 2 + 0.5   # unnormalize\n",
    "    npimg = img.detach().cpu().numpy()   # convert from tensor\n",
    "    plt.imshow(np.transpose(npimg, (1, 2, 0)))   # (C, H, W) -> (H, W, C)\n",
    "    plt.axis('off')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "imshow(testset[32][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.models import ResNet18_Weights, resnet18\n",
    "\n",
    "# `pretrained=True` is deprecated; request the same ImageNet weights explicitly.\n",
    "model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)\n",
    "\n",
    "# Replace the final fully connected layer with a 10-way output layer.\n",
    "num_features = model.fc.in_features\n",
    "model.fc = nn.Linear(num_features, 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import Subset\n",
    "\n",
    "# Resize, convert to a tensor, then normalise every channel with mean 0.5 and std 0.5.\n",
    "transform_test = transforms.Compose(\n",
    "    [\n",
    "        transforms.Resize(224),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),\n",
    "    ]\n",
    ")\n",
    "\n",
    "test_dataset = torchvision.datasets.CIFAR10(\n",
    "    root='./data',\n",
    "    train=False,\n",
    "    transform=transform_test,\n",
    "    download=True,\n",
    ")\n",
    "\n",
    "# Keep only the first five test samples.\n",
    "subset = Subset(test_dataset, indices=range(0,5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "subset_loader = torch.utils.data.DataLoader(dataset=subset, batch_size=batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs = 50\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "# Train on `subset_loader`; the value printed per epoch is the mean batch loss.\n",
    "for epoch in range(num_epochs):\n",
    "    model.train()\n",
    "    running_loss = 0.0\n",
    "\n",
    "    for images, labels in subset_loader:\n",
    "        # Clear gradients left over from the previous step before computing new ones\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        outputs = model(images)\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        running_loss = running_loss + loss.item()\n",
    "\n",
    "    print(f\"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(subset_loader):.4f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "correct = 0\n",
    "total = 0\n",
    "\n",
    "# Count how many samples of `subset_loader` get their highest score on the true class.\n",
    "with torch.no_grad():\n",
    "    for images, labels in subset_loader:\n",
    "        outputs = model(images)\n",
    "        predicted = torch.max(outputs, 1).indices\n",
    "        hits = (predicted == labels).sum().item()\n",
    "        total += labels.size(0)\n",
    "        correct += hits\n",
    "\n",
    "accuracy = 100 * correct / total\n",
    "print(f\"Test Accuracy: {accuracy:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Show every sample of the subset with its class probabilities, true and predicted label.\n",
    "# Iterating inside each batch keeps this working for any loader batch size.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    for images, labels in subset_loader:\n",
    "        outputs = model(images)\n",
    "        probabilities = torch.softmax(outputs, dim=1)  # normalise over the class dimension\n",
    "        predicted = outputs.argmax(dim=1)\n",
    "        for image, label, class_probs, pred in zip(images, labels, probabilities, predicted):\n",
    "            imshow(image)\n",
    "            print(class_probs.cpu().numpy())\n",
    "            print(\"true label:\", class_names[int(label)])\n",
    "            print(\"predicted label:\", class_names[int(pred)])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this loop had no body and could not be parsed; showing each sample with\n",
    "# its label is an assumption about what was intended - adjust or delete as needed.\n",
    "for image, label in subset:\n",
    "    imshow(image)\n",
    "    print(\"true label:\", class_names[label])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "plnet",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
