{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.base import BaseEstimator, ClassifierMixin\n",
    "from scipy.stats import norm\n",
    "import torch\n",
    "import numpy as np\n",
    "from sklearn.base import BaseEstimator, ClassifierMixin\n",
    "from torch.distributions import Normal\n",
    "\n",
    "class TorchGaussianSyntheticClassifier(BaseEstimator, ClassifierMixin):\n",
    "    \"\"\"Bayes classifier for a known one-dimensional Gaussian class model.\n",
    "\n",
    "    Each class ``c`` has likelihood ``p(x | y=c) = N(mean_c, std_c**2)`` and\n",
    "    prior weight ``prior_c``. Posteriors follow from Bayes' rule; the priors are\n",
    "    renormalised implicitly, so they do not need to sum to one.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, class_params=None, device=\"cpu\"):\n",
    "        \"\"\"\n",
    "        Initialize the classifier with class parameters.\n",
    "\n",
    "        Parameters:\n",
    "        class_params: dict or None\n",
    "            A dictionary where keys are (integer) class labels and values are\n",
    "            dictionaries with 'mean', 'std', and 'prior' for each class.\n",
    "            None is treated as an empty dictionary.\n",
    "        device: str or torch.device\n",
    "            The device to use for computations ('cpu' or 'cuda').\n",
    "        \"\"\"\n",
    "        self.class_params = class_params if class_params is not None else {}\n",
    "        # Store the argument unchanged. sklearn's clone() rebuilds the estimator\n",
    "        # from get_params() and raises if __init__ alters a parameter, which\n",
    "        # wrapping it in torch.device() did. torch factories accept either form.\n",
    "        self.device = device\n",
    "        # classes_ is needed before fit() too: generate_data()/predict_proba()\n",
    "        # may be called on an unfitted instance.\n",
    "        self._init_classes()\n",
    "\n",
    "    def _init_classes(self):\n",
    "        \"\"\"Set classes_, n_features_ and n_classes_ from class_params.\"\"\"\n",
    "        self.classes_ = torch.tensor(list(self.class_params.keys()), device=self.device)\n",
    "        self.n_features_ = 1\n",
    "        # Derived from the parameters instead of a hard-coded 3, which was wrong\n",
    "        # for any model that does not have exactly three classes.\n",
    "        self.n_classes_ = len(self.classes_)\n",
    "\n",
    "    def _class_tensors(self):\n",
    "        \"\"\"Return (means, stds, priors) as float32 tensors ordered like classes_.\"\"\"\n",
    "        params = [self.class_params[int(c)] for c in self.classes_]\n",
    "\n",
    "        def column(key):\n",
    "            return torch.tensor([p[key] for p in params], dtype=torch.float32, device=self.device)\n",
    "\n",
    "        return column(\"mean\"), column(\"std\"), column(\"prior\")\n",
    "\n",
    "    def fit(self, X, y=None):\n",
    "        \"\"\"\n",
    "        Fit method for compatibility. This classifier doesn't require fitting;\n",
    "        it only re-derives the class attributes from class_params.\n",
    "        \"\"\"\n",
    "        self._init_classes()\n",
    "        return self\n",
    "\n",
    "    def predict_proba(self, X):\n",
    "        \"\"\"\n",
    "        Predict the probability of each class for the given input data X.\n",
    "\n",
    "        Parameters:\n",
    "        X: torch.Tensor or array-like of shape (n_samples,) or (n_samples, 1)\n",
    "            Input features. Converted to float32 on self.device.\n",
    "\n",
    "        Returns:\n",
    "        probs: torch.Tensor of shape (n_samples, n_classes)\n",
    "            Predicted probabilities P(y=c|x), columns ordered like classes_.\n",
    "        \"\"\"\n",
    "        X = torch.as_tensor(X, dtype=torch.float32, device=self.device)\n",
    "        # Accept a scalar, a 1-D array or sklearn-style (n_samples, 1) input;\n",
    "        # reshape raises if a row holds more than one feature.\n",
    "        X = X.reshape(-1) if X.ndim <= 1 else X.reshape(X.shape[0])\n",
    "        means, stds, priors = self._class_tensors()\n",
    "\n",
    "        # log P(x|y=c) + log P(y=c), shape (n_samples, n_classes). Normalising in\n",
    "        # log space (softmax = log-sum-exp) avoids the 0/0 = NaN that exp()\n",
    "        # underflow produced for x far from every class mean.\n",
    "        log_joint = Normal(loc=means, scale=stds).log_prob(X.unsqueeze(1)) + torch.log(priors)\n",
    "        return torch.softmax(log_joint, dim=1)\n",
    "\n",
    "    def predict(self, X):\n",
    "        \"\"\"\n",
    "        Predict the class label for each sample in X.\n",
    "\n",
    "        Parameters:\n",
    "        X: torch.Tensor or array-like of shape (n_samples,)\n",
    "            Input features.\n",
    "\n",
    "        Returns:\n",
    "        predictions: torch.Tensor of shape (n_samples,)\n",
    "            Predicted class labels (the maximum-posterior entry of classes_).\n",
    "        \"\"\"\n",
    "        probs = self.predict_proba(X)\n",
    "        return self.classes_[torch.argmax(probs, dim=1)]\n",
    "\n",
    "    def generate_data(self, n_samples=100):\n",
    "        \"\"\"\n",
    "        Generate synthetic data using the predefined class parameters.\n",
    "\n",
    "        Labels are drawn from the (renormalised) priors and features from the\n",
    "        matching Gaussian, all in one vectorised call. A fixed torch seed\n",
    "        therefore yields a different, identically distributed, draw than the\n",
    "        former one-sample-at-a-time loop.\n",
    "\n",
    "        Parameters:\n",
    "        n_samples: int\n",
    "            Number of samples to generate.\n",
    "\n",
    "        Returns:\n",
    "        X: torch.Tensor of shape (n_samples,)\n",
    "            Generated features.\n",
    "        y: torch.Tensor of shape (n_samples,)\n",
    "            Generated labels (values taken from classes_).\n",
    "        \"\"\"\n",
    "        if n_samples <= 0:\n",
    "            # torch.multinomial rejects num_samples=0; return empty tensors instead.\n",
    "            return torch.empty(0, device=self.device), self.classes_[:0]\n",
    "        means, stds, priors = self._class_tensors()\n",
    "        # Class indices (positions in classes_), sampled with replacement.\n",
    "        idx = torch.multinomial(priors, num_samples=n_samples, replacement=True)\n",
    "        X = Normal(loc=means[idx], scale=stds[idx]).sample()\n",
    "        return X, self.classes_[idx]\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import rand\n",
    "from torchcp.classification.score import APS, THR, RAPS, SAPS\n",
    "# Conformity score functions compared in the plots below. They share\n",
    "# score_type=\"identity\": predict_proba already returns normalised\n",
    "# probabilities, so presumably no extra transform is wanted (see torchcp docs).\n",
    "SCORE_KWARGS = {\"score_type\": \"identity\"}\n",
    "aps = APS(**SCORE_KWARGS, randomized=False)\n",
    "aps_randomized = APS(**SCORE_KWARGS, randomized=True)\n",
    "lac = THR(**SCORE_KWARGS)\n",
    "raps = RAPS(**SCORE_KWARGS, randomized=True, penalty=2, kreg=1)\n",
    "saps = SAPS(**SCORE_KWARGS, randomized=True)\n",
    "# Default score selection.\n",
    "score = aps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from turtle import color\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "sns.set(font_scale=1.5,rc={'text.usetex' : True})\n",
    "sns.set_style(\"whitegrid\")\n",
    "plt.rc('font', **{'family': 'serif'})\n",
    "plt.rcParams[\"figure.figsize\"] = (14, 3)\n",
    "\n",
    "# Step 1: Define class parameters (re-enable class 2 for a 3-class example;\n",
    "# the plotting code below adapts to the number of classes)\n",
    "class_params = {\n",
    "    0: {\"mean\": 1, \"std\": 1, \"prior\": 0.3},\n",
    "    1: {\"mean\": 3, \"std\": 1, \"prior\": 0.4},\n",
    "    # 2: {\"mean\": 4, \"std\": 2.2, \"prior\": 0.3},\n",
    "}\n",
    "\n",
    "# Seed torch so the sampled data (which sets the plotted x-range) and the\n",
    "# randomized APS scores are reproducible under Restart & Run All.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Step 2: Create an instance of the Torch-based classifier\n",
    "clf = TorchGaussianSyntheticClassifier(class_params=class_params, device=\"cpu\")\n",
    "\n",
    "# Step 3: Generate synthetic data\n",
    "X, y = clf.generate_data(n_samples=1000)\n",
    "\n",
    "# Step 4: Fit the classifier (it's effectively a no-op, but we call it for sklearn compatibility)\n",
    "clf.fit(X, y)\n",
    "\n",
    "# Step 5: Predict probabilities for a range of x values (for plotting)\n",
    "x_range = torch.linspace(float(X.min()) - 1, float(X.max()) + 1, 1000)\n",
    "probs = clf.predict_proba(x_range)\n",
    "\n",
    "# Grid on which the conformity scores are evaluated. It is the same for every\n",
    "# score function, so its posterior is computed once, outside the plotting loop.\n",
    "x_min, x_max = -2, 6\n",
    "N_GRID = 300\n",
    "X_nc = torch.linspace(x_min, x_max, N_GRID)\n",
    "probs_nc = clf.predict_proba(X_nc)\n",
    "\n",
    "# Per-class colours: light for p(y|x), dark for the score s(x, y).\n",
    "prob_colors = [\"blue\", \"red\", \"green\"]\n",
    "score_colors = [\"darkblue\", \"darkred\", \"darkgreen\"]\n",
    "\n",
    "# Step 6: Plot posteriors and conformity scores, one panel per score function\n",
    "fig, axes = plt.subplots(1, 3, sharey=True)\n",
    "axes = axes.ravel()\n",
    "scores = [lac, aps, aps_randomized]\n",
    "titles = [\"LAC\", \"APS\", \"Randomized APS\"]\n",
    "# score_fn (not score) so the loop does not overwrite the global `score`.\n",
    "for panel, score_fn, title in zip(axes, scores, titles):\n",
    "    for j, c in enumerate(clf.classes_):\n",
    "        panel.plot(x_range, probs[:, j], label=rf'$p(y={int(c)} \\mid x)$',\n",
    "                   linewidth=2, alpha=0.25, color=prob_colors[j % len(prob_colors)])\n",
    "    for j, c in enumerate(clf.classes_):\n",
    "        # The score takes labels as column indices into the probability matrix.\n",
    "        labels_j = torch.full((N_GRID,), j, dtype=torch.long)\n",
    "        s_j = score_fn(probs_nc, labels_j).detach().cpu().numpy()\n",
    "        panel.plot(X_nc, s_j, linestyle=\"--\", label=rf\"$s(x,y={int(c)})$\",\n",
    "                   color=score_colors[j % len(score_colors)])\n",
    "    panel.set_title(title)\n",
    "    panel.set_xlabel(\"Feature Value ($x$)\")\n",
    "    panel.set_xticks([])\n",
    "    panel.set_yticks([0, 1])\n",
    "    panel.set_xlim([x_min + 1, x_max - 1])\n",
    "\n",
    "axes[0].set_ylabel(\"Class Probability\")\n",
    "\n",
    "# Shared legend below the panels; one column per entry keeps it on one row.\n",
    "handles, labels = axes[0].get_legend_handles_labels()\n",
    "lgd = fig.legend(handles, labels, loc='upper center', ncol=len(handles),\n",
    "                 bbox_to_anchor=(0.5, 0.08), frameon=False)\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig(\"lac_aps.pdf\", bbox_extra_artists=(lgd,), bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "plnet",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
