{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.models as models\n",
    "from torch.utils.data import Subset\n",
    "from torch.utils.data import DataLoader\n",
    "from sklearn.model_selection import train_test_split\n",
    "import numpy as np\n",
    "\n",
    "# Step 1: Prepare CIFAR-10 Dataset\n",
    "transform = transforms.Compose([\n",
    "    transforms.Resize((224, 224)),  # Resize to 224x224 (for ResNet)\n",
    "    transforms.RandomHorizontalFlip(),  # Data augmentation\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])  # CIFAR-10 normalization\n",
    "])\n",
    "\n",
    "\n",
    "# Step 1: Prepare CIFAR-10 Dataset\n",
    "transform_otf = transforms.Compose([\n",
    "    transforms.ToPILImage(),\n",
    "    transforms.Resize((224, 224)),  # Resize to 224x224 (for ResNet)\n",
    "    transforms.RandomHorizontalFlip(),  # Data augmentation\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])  # CIFAR-10 normalization\n",
    "])\n",
    "\n",
    "\n",
    "train_dataset_orig = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)\n",
    "train_dataset = Subset(train_dataset_orig, range(0,64))\n",
    "cal_dataset = Subset(train_dataset_orig, range(64,1024))\n",
    "\n",
    "test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)\n",
    "\n",
    "train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)\n",
    "test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)\n",
    "\n",
    "X = train_dataset_orig.data\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "X = train_dataset_orig.data\n",
    "y = train_dataset_orig.targets\n",
    "\n",
    "X_train, X_test, y_train, y_test = train_test_split(X,y,test_size=150)\n",
    "\n",
    "\n",
    "# Step 2: Modify ResNet18 for CIFAR-10 (10 classes)\n",
    "model = models.resnet18(pretrained=False)  # Do not load pretrained weights\n",
    "model.fc = nn.Linear(model.fc.in_features, 10)  # Modify the final layer for CIFAR-10 (10 classes)\n",
    "\n",
    "# Step 3: Define Loss Function and Optimizer\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "# Step 4: Train the Model\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "model.to(device)\n",
    "\n",
    "num_epochs = 5  # You can adjust the number of epochs\n",
    "for epoch in range(num_epochs):\n",
    "    model.train()\n",
    "    running_loss = 0.0\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    for inputs, labels in train_loader:\n",
    "        inputs, labels = inputs.to(device), labels.to(device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        outputs = model(inputs)\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        running_loss += loss.item()\n",
    "        _, predicted = torch.max(outputs, 1)\n",
    "        total += labels.size(0)\n",
    "        correct += (predicted == labels).sum().item()\n",
    "\n",
    "    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}, Accuracy: {100 * correct / total:.2f}%')\n",
    "\n",
    "# Step 5: Fine-tuned model is ready\n",
    "\n",
    "# Step 6: Wrap the model for MAPIE (similar to earlier)\n",
    "class TorchClassifierWrapper:\n",
    "    def __init__(self, model, transform=None, device=None):\n",
    "        self.model = model\n",
    "        self.transform = transform\n",
    "        self.device = device or (\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "        self.model.to(self.device)\n",
    "        self.classes_ = np.arange(10)  # CIFAR-10 has 10 classes\n",
    "        # self.n_features_in_ = None  # Set in fit()\n",
    "\n",
    "    def fit(self, X, y):\n",
    "        self.classes_ = np.unique(y)\n",
    "        # self.n_features_in_ = X.shape[1:]  # Image shape, e.g., (3, 224, 224)\n",
    "        return self\n",
    "\n",
    "    def predict_proba(self, X):\n",
    "        from sklearn.utils.validation import check_is_fitted\n",
    "        check_is_fitted(self, [\"classes_\"])\n",
    "        \n",
    "        self.model.eval()\n",
    "        \n",
    "        tensors = torch.stack([transform_otf(image) for image in X])\n",
    "\n",
    "        with torch.no_grad():\n",
    "            # Transform and move data to device\n",
    "            if isinstance(X, np.ndarray):\n",
    "                X = torch.tensor(X).float()\n",
    "            X_tensor = tensors.to(self.device)\n",
    "            logits = self.model(X_tensor)\n",
    "            probabilities = torch.nn.functional.softmax(logits, dim=1)\n",
    "        return probabilities.cpu().numpy()\n",
    "\n",
    "    def predict(self, X):\n",
    "        probabilities = self.predict_proba(X)\n",
    "        return np.argmax(probabilities, axis=1)\n",
    "\n",
    "    def __sklearn_is_fitted__(self):\n",
    "        return True\n",
    "\n",
    "    # def get_params(self, deep=True):\n",
    "    #     return {\"model\": self.model, \"transform\": self.transform, \"device\": self.device}\n",
    "\n",
    "    # def set_params(self, **params):\n",
    "    #     for param, value in params.items():\n",
    "    #         setattr(self, param, value)\n",
    "    #     return self\n",
    "\n",
    "# Wrap the model\n",
    "wrapped_model = TorchClassifierWrapper(model=model, transform=transform)\n",
    "\n",
    "# Step 7: Use MAPIE for conformal predictions\n",
    "from mapie.classification import MapieClassifier\n",
    "\n",
    "# Create a dummy fit call (MAPIE needs this to work)\n",
    "X_cal, y_cal = X_test, y_test\n",
    "\n",
    "wrapped_model.fit(X_cal, y_cal)  # Fit the model wrapper\n",
    "mapie = MapieClassifier(estimator=wrapped_model, method=\"aps\", cv=\"prefit\")\n",
    "mapie.fit(X_cal, y_cal)  # Fit MAPIE with dummy data\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 8: Make predictions on CIFAR-10 test set\n",
    "\n",
    "y_pred, y_pred_set = mapie.predict(X_train[:20], alpha=0.1)\n",
    "y_pred_set = [\n",
    "    np.where(prediction_set_row)[0].tolist()\n",
    "    for prediction_set_row in y_pred_set\n",
    "]\n",
    "print(\"True values:\", y_train[:5])\n",
    "print(\"Predicted labels:\", y_pred[:5])\n",
    "print(\"Prediction sets:\", y_pred_set[:5])\n",
    "    # break  # Remove to process the entire dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.models as models\n",
    "from torch.utils.data import Subset\n",
    "from torch.utils.data import DataLoader\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Step 1: Prepare CIFAR-10 Dataset\n",
    "transform = transforms.Compose([\n",
    "    transforms.Resize((224, 224)),  # Resize to 224x224 (for ResNet)\n",
    "    transforms.ToTensor(),\n",
    "    # transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])  # CIFAR-10 normalization\n",
    "])\n",
    "\n",
    "\n",
    "# Step 1: Prepare CIFAR-10 Dataset\n",
    "transform_otf = transforms.Compose([\n",
    "    transforms.ToPILImage(),\n",
    "    transforms.Resize((224, 224)),  # Resize to 224x224 (for ResNet)\n",
    "    transforms.ToTensor(),\n",
    "    # transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])  # CIFAR-10 normalization\n",
    "])\n",
    "\n",
    "train_dataset_orig = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)\n",
    "train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from mapie.classification import MapieClassifier\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.datasets import make_classification\n",
    "from sklearn.datasets import load_wine, load_iris\n",
    "from sklearn.model_selection import train_test_split\n",
    "from models.classifier_model import ClassifierModel\n",
    "\n",
    "# Generate synthetic data\n",
    "X, y = make_classification(n_classes=10,n_samples=15000, n_features=15, random_state=42,n_informative=5)\n",
    "# X,y = load_iris(return_X_y=True)\n",
    "X_train, X_test, y_train, y_test = train_test_split(X,y,test_size=1/4, shuffle=True)\n",
    "X_train, X_cal, y_train, y_cal = train_test_split(X_train,y_train,test_size=1/3, shuffle=True)\n",
    "\n",
    "# Fit MapieClassifiers\n",
    "clf = ClassifierModel(input_dim=X.shape[1],hidden_dim=16,output_dim=len(np.unique(y)))\n",
    "clf.fit(X_train, y_train)\n",
    "# clf = RandomForestClassifier()\n",
    "# clf.fit(X_train, y_train)\n",
    "mapie = MapieClassifier(estimator=clf, method=\"score\", cv=\"prefit\")\n",
    "mapie.fit(X_cal, y_cal)\n",
    "\n",
    "# Predict with alpha for prediction sets\n",
    "alpha = 0.05\n",
    "predictions, prediction_sets = mapie.predict(X_test, alpha=alpha)\n",
    "\n",
    "plausible_labels = [\n",
    "    np.where(prediction_set_row)[0].tolist()  # Extract indices where value is True\n",
    "    for prediction_set_row in prediction_sets\n",
    "]\n",
    "\n",
    "\n",
    "print(plausible_labels)\n",
    "\n",
    "# Coverage Calculation\n",
    "coverage = np.mean([\n",
    "    y_test[i] in plausible_labels[i]\n",
    "    for i in range(len(y_test))\n",
    "])\n",
    "\n",
    "# Efficiency Calculation\n",
    "efficiency = np.mean([\n",
    "    len(plausible_labels[i])\n",
    "    for i in range(len(plausible_labels))\n",
    "])\n",
    "\n",
    "print(f\"Coverage: {coverage:.2f}\")\n",
    "print(f\"Efficiency: {efficiency:.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mapie.estimator.gradient_updates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "plnet",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
