{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from sklearn.base import BaseEstimator, ClassifierMixin\n",
    "from torch.distributions.multivariate_normal import MultivariateNormal\n",
    "from torch.distributions import Normal\n",
    "import os, sys\n",
    "# Make the local torch_plnet checkout importable (it presumably provides the `util` and\n",
    "# `models` packages imported in a later cell). Override via the TORCH_PLNET_PATH env variable.\n",
    "TORCH_PLNET_PATH = os.environ.get(\"TORCH_PLNET_PATH\", \"/home/ra43rid/torch_plnet\")\n",
    "if TORCH_PLNET_PATH not in sys.path:  # keeps re-runs of this cell from growing sys.path\n",
    "    sys.path.append(TORCH_PLNET_PATH)\n",
    "# Tensors created without an explicit device are allocated on the GPU.\n",
    "torch.set_default_device(\"cuda\")\n",
    "\n",
    "class TorchMultivariateGaussianClassifier(BaseEstimator, ClassifierMixin):\n",
    "    \"\"\"Bayes classifier and data generator for known multivariate Gaussian classes.\n",
    "\n",
    "    Every class is given by a mean vector, a covariance matrix and a prior; posteriors\n",
    "    P(y=c|x) follow from Bayes' rule, so no fitting is needed.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, class_params=None, device=\"cpu\"):\n",
    "        \"\"\"\n",
    "        Initialize the classifier with class parameters.\n",
    "\n",
    "        Parameters:\n",
    "        class_params: dict\n",
    "            A dictionary where keys are class labels and values are dictionaries with\n",
    "            'mean' (vector), 'cov' (matrix), and 'prior' for each class.\n",
    "        device: str\n",
    "            The device to use for computations ('cpu' or 'cuda').\n",
    "        \"\"\"\n",
    "        self.class_params = class_params if class_params is not None else {}\n",
    "        self.device = torch.device(device)\n",
    "        self.classes_ = torch.tensor(list(self.class_params.keys()), device=self.device)\n",
    "        # Read the dimensionality from the first class instead of class_params[0]: that lookup\n",
    "        # raised KeyError for the default (empty) params and for labels not starting at 0.\n",
    "        first_params = next(iter(self.class_params.values()), None)\n",
    "        self.n_features = len(first_params[\"cov\"]) if first_params is not None else 0\n",
    "        self.n_classes = len(self.classes_)\n",
    "\n",
    "    def fit(self, X, y=None):\n",
    "        \"\"\"\n",
    "        Fit method for compatibility. This classifier doesn't require fitting.\n",
    "        \"\"\"\n",
    "        self.classes_ = torch.tensor(list(self.class_params.keys()), device=self.device)\n",
    "        return self\n",
    "\n",
    "    def _class_distribution(self, c):\n",
    "        \"\"\"Return (MultivariateNormal, prior) for class label c, with parameters on self.device.\"\"\"\n",
    "        params = self.class_params[int(c)]\n",
    "        mean = torch.as_tensor(params[\"mean\"], dtype=torch.float32, device=self.device)\n",
    "        cov = torch.as_tensor(params[\"cov\"], dtype=torch.float32, device=self.device)\n",
    "        return MultivariateNormal(mean, covariance_matrix=cov), params[\"prior\"]\n",
    "\n",
    "    def predict_proba(self, X):\n",
    "        \"\"\"\n",
    "        Predict the probability of each class for the given input data X.\n",
    "\n",
    "        Parameters:\n",
    "        X: torch.Tensor or array-like of shape (n_samples, n_features)\n",
    "            Input features.\n",
    "\n",
    "        Returns:\n",
    "        probs: torch.Tensor of shape (n_samples, n_classes)\n",
    "            Predicted probabilities for each class (rows sum to 1; priors are renormalised).\n",
    "        \"\"\"\n",
    "        # as_tensor also casts/moves tensors that arrive as float64 or on another device;\n",
    "        # log_prob cannot mix those with the float32 parameters on self.device.\n",
    "        X = torch.as_tensor(X, dtype=torch.float32, device=self.device)\n",
    "\n",
    "        log_joint = torch.empty((X.shape[0], len(self.classes_)), device=self.device)\n",
    "        for i, c in enumerate(self.classes_):\n",
    "            dist, prior = self._class_distribution(c)\n",
    "            # log P(x|y=c) + log P(y=c)\n",
    "            log_joint[:, i] = dist.log_prob(X) + float(np.log(prior))\n",
    "\n",
    "        # Normalise in log space: exp(log_prob) underflows to 0 for points far from every\n",
    "        # mean, and the former probs / probs.sum() then produced NaN rows (0 / 0).\n",
    "        return torch.softmax(log_joint, dim=1)\n",
    "\n",
    "    def predict(self, X):\n",
    "        \"\"\"\n",
    "        Predict the class label for each sample in X.\n",
    "\n",
    "        Parameters:\n",
    "        X: torch.Tensor or array-like of shape (n_samples, n_features)\n",
    "            Input features.\n",
    "\n",
    "        Returns:\n",
    "        predictions: torch.Tensor of shape (n_samples,)\n",
    "            Predicted class labels.\n",
    "        \"\"\"\n",
    "        probs = self.predict_proba(X)\n",
    "        return self.classes_[torch.argmax(probs, dim=1)]\n",
    "\n",
    "    def generate_data(self, n_samples=100):\n",
    "        \"\"\"\n",
    "        Generate synthetic data using the predefined class parameters.\n",
    "\n",
    "        Parameters:\n",
    "        n_samples: int\n",
    "            Number of samples to generate.\n",
    "\n",
    "        Returns:\n",
    "        X: torch.Tensor of shape (n_samples, n_features)\n",
    "            Generated features.\n",
    "        y: torch.Tensor of shape (n_samples,)\n",
    "            Generated labels.\n",
    "        \"\"\"\n",
    "        if n_samples <= 0:\n",
    "            # torch.stack([]) would raise; return correctly shaped empty tensors instead.\n",
    "            return (torch.empty((0, self.n_features), device=self.device),\n",
    "                    torch.empty(0, dtype=torch.long, device=self.device))\n",
    "\n",
    "        # Loop invariants, built once instead of once per sample (MultivariateNormal factorises\n",
    "        # the covariance on construction). Random draws still happen in the same order.\n",
    "        priors = torch.tensor([self.class_params[int(c)][\"prior\"] for c in self.classes_], device=self.device)\n",
    "        dists = [self._class_distribution(c)[0] for c in self.classes_]\n",
    "\n",
    "        X = []\n",
    "        y = []\n",
    "        for _ in range(n_samples):\n",
    "            # Sample class based on priors (torch.multinomial accepts unnormalised weights)\n",
    "            k = torch.multinomial(priors, num_samples=1).item()\n",
    "            # Sample feature vector from the corresponding multivariate Gaussian\n",
    "            X.append(dists[k].sample())\n",
    "            y.append(self.classes_[k].item())\n",
    "\n",
    "        return torch.stack(X), torch.tensor(y, device=self.device)\n",
    "\n",
    "class TorchGaussianSyntheticClassifier(BaseEstimator, ClassifierMixin):\n",
    "    \"\"\"Bayes classifier and data generator for known univariate Gaussian classes.\n",
    "\n",
    "    Every class is given by a mean, a standard deviation and a prior; posteriors P(y=c|x)\n",
    "    follow from Bayes' rule, so no fitting is needed.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, class_params=None, device=\"cpu\"):\n",
    "        \"\"\"\n",
    "        Initialize the classifier with class parameters.\n",
    "\n",
    "        Parameters:\n",
    "        class_params: dict\n",
    "            A dictionary where keys are class labels and values are dictionaries with\n",
    "            'mean', 'std', and 'prior' for each class.\n",
    "        device: str\n",
    "            The device to use for computations ('cpu' or 'cuda').\n",
    "        \"\"\"\n",
    "        self.class_params = class_params if class_params is not None else {}\n",
    "        self.device = torch.device(device)\n",
    "        self.classes_ = torch.tensor(list(self.class_params.keys()), device=self.device)\n",
    "        self.n_features = 1\n",
    "        self.n_classes = len(self.classes_)\n",
    "\n",
    "    def fit(self, X, y=None):\n",
    "        \"\"\"\n",
    "        Fit method for compatibility. This classifier doesn't require fitting.\n",
    "        \"\"\"\n",
    "        self.classes_ = torch.tensor(list(self.class_params.keys()), device=self.device)\n",
    "        return self\n",
    "\n",
    "    def _class_distribution(self, c):\n",
    "        \"\"\"Return (Normal, prior) for class label c, with parameters on self.device.\"\"\"\n",
    "        params = self.class_params[int(c)]\n",
    "        # Explicit tensors: Normal built from Python numbers lands on the global default\n",
    "        # device, which need not be self.device.\n",
    "        loc = torch.as_tensor(params[\"mean\"], dtype=torch.float32, device=self.device)\n",
    "        scale = torch.as_tensor(params[\"std\"], dtype=torch.float32, device=self.device)\n",
    "        return Normal(loc=loc, scale=scale), params[\"prior\"]\n",
    "\n",
    "    def predict_proba(self, X):\n",
    "        \"\"\"\n",
    "        Predict the probability of each class for the given input data X.\n",
    "\n",
    "        Parameters:\n",
    "        X: torch.Tensor or array-like of shape (n_samples,) or (n_samples, 1)\n",
    "            Input features.\n",
    "\n",
    "        Returns:\n",
    "        probs: torch.Tensor of shape (n_samples, n_classes)\n",
    "            Predicted probabilities for each class (rows sum to 1; priors are renormalised).\n",
    "        \"\"\"\n",
    "        X = torch.as_tensor(X, dtype=torch.float32, device=self.device)\n",
    "        if X.dim() == 2 and X.shape[1] == 1:\n",
    "            X = X[:, 0]  # accept column vectors, since n_features == 1\n",
    "\n",
    "        log_joint = torch.empty((len(X), len(self.classes_)), device=self.device)\n",
    "        for i, c in enumerate(self.classes_):\n",
    "            dist, prior = self._class_distribution(c)\n",
    "            # log P(x|y=c) + log P(y=c)\n",
    "            log_joint[:, i] = dist.log_prob(X) + float(np.log(prior))\n",
    "\n",
    "        # Normalise in log space: exp(log_prob) underflows to 0 far from every mean, and the\n",
    "        # former probs / probs.sum() then produced NaN rows (0 / 0).\n",
    "        return torch.softmax(log_joint, dim=1)\n",
    "\n",
    "    def predict(self, X):\n",
    "        \"\"\"\n",
    "        Predict the class label for each sample in X.\n",
    "\n",
    "        Parameters:\n",
    "        X: torch.Tensor or array-like of shape (n_samples,)\n",
    "            Input features.\n",
    "\n",
    "        Returns:\n",
    "        predictions: torch.Tensor of shape (n_samples,)\n",
    "            Predicted class labels.\n",
    "        \"\"\"\n",
    "        probs = self.predict_proba(X)\n",
    "        return self.classes_[torch.argmax(probs, dim=1)]\n",
    "\n",
    "    def generate_data(self, n_samples=100):\n",
    "        \"\"\"\n",
    "        Generate synthetic data using the predefined class parameters.\n",
    "\n",
    "        Parameters:\n",
    "        n_samples: int\n",
    "            Number of samples to generate.\n",
    "\n",
    "        Returns:\n",
    "        X: torch.Tensor of shape (n_samples,)\n",
    "            Generated features.\n",
    "        y: torch.Tensor of shape (n_samples,)\n",
    "            Generated labels.\n",
    "        \"\"\"\n",
    "        # Loop invariants, built once instead of once per sample; draws happen in the same order.\n",
    "        priors = torch.tensor([self.class_params[int(c)][\"prior\"] for c in self.classes_], device=self.device)\n",
    "        dists = [self._class_distribution(c)[0] for c in self.classes_]\n",
    "\n",
    "        X = []\n",
    "        y = []\n",
    "        for _ in range(n_samples):\n",
    "            # Sample class based on priors (torch.multinomial accepts unnormalised weights)\n",
    "            k = torch.multinomial(priors, num_samples=1).item()\n",
    "            # Sample feature value from the corresponding Gaussian\n",
    "            X.append(dists[k].sample().item())\n",
    "            y.append(self.classes_[k].item())\n",
    "\n",
    "        return torch.tensor(X, device=self.device), torch.tensor(y, device=self.device)\n",
    "\n",
    "def goodman_kruskal_gamma(x, y):\n",
    "    \"\"\"\n",
    "    Compute Goodman and Kruskal's Gamma for two ordinal variables.\n",
    "    \n",
    "    Parameters:\n",
    "    x, y: Lists or arrays of ordinal data (same length)\n",
    "    \n",
    "    Returns:\n",
    "    gamma: Goodman and Kruskal's Gamma\n",
    "    \"\"\"\n",
    "    if len(x) != len(y):\n",
    "        raise ValueError(\"Both variables must have the same length.\")\n",
    "    \n",
    "    concordant = 0\n",
    "    discordant = 0\n",
    "    \n",
    "    n = len(x)\n",
    "    for i in range(n - 1):\n",
    "        for j in range(i + 1, n):\n",
    "            # Determine concordance or discordance\n",
    "            if (x[i] > x[j] and y[i] > y[j]) or (x[i] < x[j] and y[i] < y[j]):\n",
    "                concordant += 1\n",
    "            elif (x[i] > x[j] and y[i] < y[j]) or (x[i] < x[j] and y[i] > y[j]):\n",
    "                discordant += 1\n",
    "    \n",
    "    # Compute Gamma\n",
    "    if concordant + discordant == 0:\n",
    "        return 0  # Avoid division by zero\n",
    "    gamma = (concordant - discordant) / (concordant + discordant)\n",
    "    return gamma\n",
    "\n",
    "from torchcp.classification.score import APS, THR, SAPS\n",
    "# Conformal scores from torchcp. score_type=\"identity\" presumably applies no transform, so the\n",
    "# inputs are used as-is (here: posteriors from predict_proba) -- confirm against torchcp docs.\n",
    "aps = APS(score_type=\"identity\", randomized=False)\n",
    "# Randomized variant of APS; presumably draws random numbers, so repeated calls may differ.\n",
    "rand_aps = APS(score_type=\"identity\", randomized=True)\n",
    "# THR is used as the LAC score throughout this notebook.\n",
    "lac = THR(score_type=\"identity\",)\n",
    "# SAPS: currently not used by any later cell.\n",
    "saps = SAPS(score_type=\"identity\",randomized=False)\n",
    "\n",
    "# 1 feature, 2 classes (class 2 disabled). The priors sum to 0.7, not 1: predict_proba\n",
    "# renormalises the posteriors and torch.multinomial accepts unnormalised weights, so the\n",
    "# effective priors are 3/7 and 4/7.\n",
    "class_params_1f_2c = {\n",
    "    0: {\"mean\": 1, \"std\": 1, \"prior\": 0.3},\n",
    "    1: {\"mean\": 3, \"std\": 1, \"prior\": 0.4},\n",
    "    # 2: {\"mean\": 4, \"std\": 2.2, \"prior\": 0.3},\n",
    "}   \n",
    "\n",
    "# 1 feature, 3 classes; the third class has a wider spread (std 2.2).\n",
    "class_params_1f_3c = {\n",
    "    0: {\"mean\": 1, \"std\": 1, \"prior\": 0.3},\n",
    "    1: {\"mean\": 3, \"std\": 1, \"prior\": 0.4},\n",
    "    2: {\"mean\": 4, \"std\": 2.2, \"prior\": 0.3},\n",
    "}   \n",
    "\n",
    "# 2 features, 3 classes. No classifier is instantiated from these parameters in this cell.\n",
    "# Every covariance matrix must be symmetric positive definite (required by MultivariateNormal);\n",
    "# the commented-out third rows are leftovers from the 3-feature variant below.\n",
    "class_params_2d_3c = {\n",
    "    0: {\n",
    "        \"mean\": [3.0, 2.0],  # Mean vector for class 0\n",
    "        \"cov\": [\n",
    "            [1.0, 0.5],\n",
    "            [0.5, 1.2],\n",
    "            # [0.3, 0.4, 0.8]\n",
    "        ],  # Covariance matrix for class 0\n",
    "        \"prior\": 0.3  # Prior probability for class 0\n",
    "    },\n",
    "    1: {\n",
    "        \"mean\": [3.0, 4.0],  # Mean vector for class 1\n",
    "        \"cov\": [\n",
    "            [1.5, 0.3],\n",
    "            [0.3, 1.1],\n",
    "            # [0.2, 0.1, 0.9]\n",
    "        ],  # Covariance matrix for class 1\n",
    "        \"prior\": 0.4  # Prior probability for class 1\n",
    "    },\n",
    "    2: {\n",
    "        \"mean\": [1.0, 2.0],  # Mean vector for class 2\n",
    "        \"cov\": [\n",
    "            [1.2, 0.4],\n",
    "            [0.4, 1.3],\n",
    "            # [0.3, 0.5, 1.4]\n",
    "        ],  # Covariance matrix for class 2\n",
    "        \"prior\": 0.3  # Prior probability for class 2\n",
    "    },\n",
    "}\n",
    "\n",
    "# 3 features, 3 classes (used by clf_3d_3c).\n",
    "class_params_3d_3c = {\n",
    "    0: {\n",
    "        \"mean\": [3.0, 2.0, 4.0],  # Mean vector for class 0\n",
    "        \"cov\": [\n",
    "            [1.0, 0.5, 0.3],\n",
    "            [0.5, 1.2, 0.4],\n",
    "            [0.3, 0.4, 0.8]\n",
    "        ],  # Covariance matrix for class 0\n",
    "        \"prior\": 0.3  # Prior probability for class 0\n",
    "    },\n",
    "    1: {\n",
    "        \"mean\": [3.0, 4.0, 1.0],  # Mean vector for class 1\n",
    "        \"cov\": [\n",
    "            [1.5, 0.3, 0.2],\n",
    "            [0.3, 1.1, 0.1],\n",
    "            [0.2, 0.1, 0.9]\n",
    "        ],  # Covariance matrix for class 1\n",
    "        \"prior\": 0.4  # Prior probability for class 1\n",
    "    },\n",
    "    2: {\n",
    "        \"mean\": [1.0, 2.0, 2.0],  # Mean vector for class 2\n",
    "        \"cov\": [\n",
    "            [1.2, 0.4, 0.3],\n",
    "            [0.4, 1.3, 0.5],\n",
    "            [0.3, 0.5, 1.4]\n",
    "        ],  # Covariance matrix for class 2\n",
    "        \"prior\": 0.3  # Prior probability for class 2\n",
    "    },\n",
    "}\n",
    "\n",
    "\n",
    "# 2 features, 2 classes (used by clf_2d_2c). The priors sum to 0.7; predict_proba and\n",
    "# torch.multinomial renormalise them, so the effective priors are 3/7 and 4/7.\n",
    "class_params_2d_2c = {\n",
    "    0: {\n",
    "        \"mean\": [3.0, 2.0],  # Mean vector for class 0\n",
    "        \"cov\": [\n",
    "            [1.0, 0.5],\n",
    "            [0.5, 1.2],\n",
    "        ],  # Covariance matrix for class 0\n",
    "        \"prior\": 0.3  # Prior probability for class 0\n",
    "    },\n",
    "    1: {\n",
    "        \"mean\": [2.0, 3.0],  # Mean vector for class 1\n",
    "        \"cov\": [\n",
    "            [1.5, 0.3],\n",
    "            [0.3, 1.1],\n",
    "        ],  # Covariance matrix for class 1\n",
    "        \"prior\": 0.4  # Prior probability for class 1\n",
    "    },\n",
    "    # 2: {\n",
    "    #     \"mean\": [1.0, 2.0, 2.0],  # Mean vector for class 2\n",
    "    #     \"cov\": [\n",
    "    #         [1.2, 0.4],\n",
    "    #         [0.4, 1.3],\n",
    "    #         # [0.3, 0.5, 1.4]\n",
    "    #     ],  # Covariance matrix for class 2\n",
    "    #     \"prior\": 0.3  # Prior probability for class 2\n",
    "    # },\n",
    "}\n",
    "\n",
    "# Generators (with exact-posterior predict_proba) used by the experiment cells; all on the GPU.\n",
    "# No instance is created for class_params_2d_3c.\n",
    "clf_1d_2c = TorchGaussianSyntheticClassifier(class_params=class_params_1f_2c, device=\"cuda\")\n",
    "clf_1d_3c = TorchGaussianSyntheticClassifier(class_params=class_params_1f_3c, device=\"cuda\")\n",
    "clf_3d_3c = TorchMultivariateGaussianClassifier(class_params=class_params_3d_3c, device=\"cuda\")\n",
    "clf_2d_2c = TorchMultivariateGaussianClassifier(class_params=class_params_2d_2c, device=\"cuda\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "class OracleAnnotator:\n",
    "    def __init__(self,score, generator):\n",
    "        self.score = score\n",
    "        self.classes_ = generator.classes_\n",
    "        self.generator = generator\n",
    "\n",
    "    # we assume y is already label encoded\n",
    "    def get_conformity(self, X, y):\n",
    "        y_pred_proba = self.generator.predict_proba(X)\n",
    "        scores = self.score(y_pred_proba, y)\n",
    "        return scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "from zmq import device\n",
    "from util.ranking_datasets import LabelPairDataset\n",
    "from models.ranking_models import LabelRankingModel\n",
    "from torch.utils.data.dataloader import DataLoader\n",
    "from scipy.stats import kendalltau\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from scipy.stats import kendalltau\n",
    "from joblib import Parallel, delayed\n",
    "# Same default as in the first cell; repeated so this cell also works when run on its own.\n",
    "torch.set_default_device(\"cuda\")\n",
    "\n",
    "def _to_numpy(values):\n",
    "    \"\"\"Return values as a NumPy array, detaching and copying torch tensors to the CPU.\"\"\"\n",
    "    if isinstance(values, torch.Tensor):\n",
    "        return values.detach().cpu().numpy()\n",
    "    return np.asarray(values)\n",
    "\n",
    "\n",
    "def _build_pair_dataset(X, y, oracle_annotator, device):\n",
    "    \"\"\"Build a LabelPairDataset from all pairs of the labelled instances (X, y).\n",
    "\n",
    "    Instances are sorted by descending oracle score and each pair (i, j), i < j in that\n",
    "    order, is emitted once. Pairs whose scores tie after rounding to 6 decimals carry no\n",
    "    ordering information and are emitted in both orders.\n",
    "    \"\"\"\n",
    "    conformities = _to_numpy(oracle_annotator.get_conformity(\n",
    "        torch.as_tensor(X, device=device), torch.as_tensor(y, device=device)))\n",
    "\n",
    "    sort_idx = (-conformities).argsort(axis=0).flatten()\n",
    "    X_sorted = _to_numpy(X)[sort_idx]\n",
    "    y_sorted = _to_numpy(y)[sort_idx]\n",
    "    conformities_sorted = conformities[sort_idx]\n",
    "\n",
    "    # All index pairs i < j in row-major order -- the same order as the nested Python\n",
    "    # comprehensions this replaces, without building O(n^2) tuples in Python.\n",
    "    i_idx, j_idx = np.triu_indices(len(X_sorted), k=1)\n",
    "    X_pairs = np.stack((X_sorted[i_idx], X_sorted[j_idx]), axis=1)\n",
    "    y_pairs = np.stack((y_sorted[i_idx], y_sorted[j_idx]), axis=1)\n",
    "    conformity_pairs = np.stack((conformities_sorted[i_idx], conformities_sorted[j_idx]), axis=1).round(6)\n",
    "    mask = conformity_pairs[:, 0] == conformity_pairs[:, 1]\n",
    "\n",
    "    X_tied, y_tied = X_pairs[mask], y_pairs[mask]\n",
    "    X_pairs_augmented = np.vstack((X_pairs[~mask], X_tied, X_tied[:, ::-1]))\n",
    "    y_pairs_augmented = np.vstack((y_pairs[~mask], y_tied, y_tied[:, ::-1]))\n",
    "    y_pairs_augmented = np.expand_dims(y_pairs_augmented, axis=-1)\n",
    "\n",
    "    ds = LabelPairDataset()\n",
    "    ds.create_from_numpy_pairs(X_pairs_augmented, y_pairs_augmented)\n",
    "    return ds\n",
    "\n",
    "\n",
    "def fit_model_with_all_pairs(X_train, y_train, oracle_annotator, generator, learning_rate = 0.01, num_epochs=200, X_val = None, y_val = None):\n",
    "    \"\"\"Train a LabelRankingModel on all pairwise orderings induced by an oracle.\n",
    "\n",
    "    Parameters:\n",
    "    X_train, y_train: training instances and their (label-encoded) class labels\n",
    "    oracle_annotator: object whose get_conformity(X, y) scores each (x, y) pair\n",
    "    generator: provides n_features, n_classes and classes_ for the model shape\n",
    "    learning_rate, num_epochs: forwarded to LabelRankingModel._fit\n",
    "    X_val, y_val: optional validation data for early stopping; without it the training\n",
    "        pairs are reused for validation\n",
    "\n",
    "    Returns:\n",
    "    model: the fitted LabelRankingModel\n",
    "\n",
    "    Raises:\n",
    "    ValueError: if X_val is given without y_val\n",
    "    \"\"\"\n",
    "    device = \"cuda\"  # matches model.cuda() below; this notebook runs everything on the GPU\n",
    "    ds = _build_pair_dataset(X_train, y_train, oracle_annotator, device)\n",
    "    pair_loader = DataLoader(ds, batch_size=64)\n",
    "\n",
    "    if X_val is not None:\n",
    "        if y_val is None:\n",
    "            raise ValueError(\"y_val is required when X_val is given.\")\n",
    "        # Previously the validation pairs were built from X_train/y_train, so early stopping\n",
    "        # monitored the training data and X_val/y_val were silently ignored.\n",
    "        ds_val = _build_pair_dataset(X_val, y_val, oracle_annotator, device)\n",
    "        val_loader = DataLoader(ds_val, batch_size=64)\n",
    "    else:\n",
    "        val_loader = DataLoader(ds, batch_size=64)\n",
    "\n",
    "    model = LabelRankingModel(input_dim=generator.n_features, hidden_dims=2*[5*generator.n_features], activations=[torch.nn.Sigmoid(), torch.nn.Sigmoid()], output_dim=len(generator.classes_))\n",
    "    model.num_classes = generator.n_classes\n",
    "    model.cuda()\n",
    "    model._fit(pair_loader, val_loader=val_loader, num_epochs=num_epochs, learning_rate=learning_rate, patience=20, verbose=True)\n",
    "    return model\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.multiprocessing as mp\n",
    "from concurrent.futures import ThreadPoolExecutor\n",
    "from scipy.stats import kendalltau\n",
    "def train_model(X_train, y_train, oracle_annotator, generator, learning_rate, num_epochs, X_val=None, y_val=None):\n",
    "    \"\"\"Train one ranking model for oracle_annotator (the unit of work handed to joblib).\n",
    "\n",
    "    Forwards every argument unchanged to fit_model_with_all_pairs and returns its model.\n",
    "    \"\"\"\n",
    "    fit_options = dict(learning_rate=learning_rate, num_epochs=num_epochs, X_val=X_val, y_val=y_val)\n",
    "    return fit_model_with_all_pairs(X_train, y_train, oracle_annotator, generator, **fit_options)\n",
    "\n",
    "def evaluate_model(model, oracle, name, X_test, y_test, taus, gammas):\n",
    "    \"\"\"Correlate a trained ranking model with its oracle on a test set.\n",
    "\n",
    "    For every test instance the model's skill for the true label is compared with the\n",
    "    oracle's score of the same (x, y) pair. Kendall's tau and Goodman-Kruskal's gamma are\n",
    "    appended to taus[name] and gammas[name] (mutated in place); nothing is returned.\n",
    "    \"\"\"\n",
    "    def to_numpy(values):\n",
    "        # Accept torch tensors and array-likes alike; calling .detach() on the NumPy arrays\n",
    "        # produced here raised AttributeError.\n",
    "        if isinstance(values, torch.Tensor):\n",
    "            return values.detach().cpu().numpy()\n",
    "        return np.asarray(values)\n",
    "\n",
    "    labels = to_numpy(y_test).astype(np.int64)\n",
    "    all_skills = to_numpy(model.predict_class_skills(X_test))\n",
    "    skills = np.take_along_axis(all_skills, labels[:, np.newaxis], axis=1).ravel()\n",
    "    # as_tensor avoids the copy (and warning) torch.tensor makes when given a tensor.\n",
    "    conformities = to_numpy(oracle.get_conformity(\n",
    "        torch.as_tensor(X_test, device=\"cuda\"), torch.as_tensor(y_test, device=\"cuda\"))).ravel()\n",
    "\n",
    "    tau_corr, _ = kendalltau(skills, conformities)\n",
    "    gamma_corr = goodman_kruskal_gamma(skills, conformities)\n",
    "\n",
    "    taus[name].append(tau_corr)\n",
    "    gammas[name].append(gamma_corr)\n",
    "\n",
    "\n",
    "def conduct_oracle_experiment(num_instances_to_check, generator, learning_rate=0.01, num_epochs=250):\n",
    "    \"\"\"Check how well ranking models trained on oracle pairs reproduce conformal scores.\n",
    "\n",
    "    For each training-set size in num_instances_to_check, one model per oracle score (LAC,\n",
    "    APS, randomized APS) is trained on every generated instance paired with every class\n",
    "    label and evaluated on a fixed test set of 100 generated instances. In addition the LAC\n",
    "    model's skills are turned into pseudo-probabilities and fed into (randomized) APS\n",
    "    (\"own_aps\" / \"own_rand_aps\") to see whether APS can be reconstructed from LAC.\n",
    "\n",
    "    Returns:\n",
    "    taus, gammas: dicts keyed by \"lac\", \"aps\", \"own_aps\", \"rand_aps\", \"own_rand_aps\"\n",
    "        holding one Kendall's tau / Goodman-Kruskal gamma per training-set size.\n",
    "    \"\"\"\n",
    "    oracle_annotator_aps = OracleAnnotator(generator=generator, score=aps)\n",
    "    oracle_annotator_lac = OracleAnnotator(generator=generator, score=lac)\n",
    "    oracle_annotator_rand_aps = OracleAnnotator(generator=generator, score=rand_aps)\n",
    "    oracles = [oracle_annotator_lac, oracle_annotator_aps, oracle_annotator_rand_aps]\n",
    "    names = [\"lac\", \"aps\", \"rand_aps\"]\n",
    "\n",
    "    X_test, y_test = generator.generate_data(n_samples=100)\n",
    "    X_test, y_test = X_test.to(\"cuda\"), y_test.to(\"cuda\")\n",
    "    taus = {\"lac\": [], \"aps\": [], \"own_aps\": [], \"rand_aps\": [], \"own_rand_aps\": []}\n",
    "    gammas = {\"lac\": [], \"aps\": [], \"own_aps\": [], \"rand_aps\": [], \"own_rand_aps\": []}\n",
    "    class_labels = generator.classes_.detach().cpu().numpy()\n",
    "\n",
    "    for num_instances in num_instances_to_check:\n",
    "        # Every generated instance is paired with every class label.\n",
    "        X_gen, _ = generator.generate_data(n_samples=num_instances)\n",
    "        X_train = X_gen.repeat_interleave(generator.n_classes, dim=0)\n",
    "        y_train = np.tile(class_labels, num_instances)\n",
    "        X_gen, _ = generator.generate_data(n_samples=num_instances)\n",
    "        X_val = X_gen.repeat_interleave(generator.n_classes, dim=0)\n",
    "        y_val = np.tile(class_labels, num_instances)\n",
    "\n",
    "        # n_jobs=1: the three models are trained one after another in this process.\n",
    "        models = Parallel(n_jobs=1, backend=\"loky\")(\n",
    "            delayed(train_model)(X_train, y_train, oracle, generator, learning_rate, num_epochs, X_val, y_val)\n",
    "            for oracle in oracles\n",
    "        )\n",
    "        model_lac = models[0]\n",
    "\n",
    "        # Each task appends only to its own taus[name] / gammas[name] lists.\n",
    "        with ThreadPoolExecutor() as executor:\n",
    "            futures = [\n",
    "                executor.submit(evaluate_model, model, oracle, name, X_test, y_test, taus, gammas)\n",
    "                for model, oracle, name in zip(models, oracles, names)\n",
    "            ]\n",
    "            for future in futures:\n",
    "                future.result()  # re-raises exceptions from the worker threads\n",
    "\n",
    "        # Pseudo-probabilities from the LAC model: negated skills, min-max scaled to [0, 1].\n",
    "        # clamp_min guards against 0 / 0 when all skills are equal.\n",
    "        with torch.no_grad():\n",
    "            neg_skills = -model_lac(X_test)\n",
    "        span = (neg_skills.max() - neg_skills.min()).clamp_min(torch.finfo(neg_skills.dtype).tiny)\n",
    "        pseudo_probs = (neg_skills - neg_skills.min()) / span\n",
    "\n",
    "        own_aps = aps._calculate_single_label(pseudo_probs, y_test).detach().cpu().numpy()\n",
    "        aps_scores = oracle_annotator_aps.get_conformity(X_test, y_test).detach().cpu().numpy()\n",
    "        tau_corr, _ = kendalltau(own_aps, aps_scores)\n",
    "        taus[\"own_aps\"].append(tau_corr)\n",
    "        gammas[\"own_aps\"].append(goodman_kruskal_gamma(own_aps, aps_scores))\n",
    "\n",
    "        # Randomized APS reconstructed from the same pseudo-probabilities. Previously these were\n",
    "        # negated a second time (values in [-1, 0]) and the gamma was computed from the\n",
    "        # deterministic own_aps / aps_scores pair instead of the randomized one.\n",
    "        own_rand_aps = rand_aps._calculate_single_label(pseudo_probs, y_test).detach().cpu().numpy()\n",
    "        rand_aps_scores = oracle_annotator_rand_aps.get_conformity(X_test, y_test).detach().cpu().numpy()\n",
    "        tau_corr, _ = kendalltau(own_rand_aps, rand_aps_scores)\n",
    "        taus[\"own_rand_aps\"].append(tau_corr)\n",
    "        gammas[\"own_rand_aps\"].append(goodman_kruskal_gamma(own_rand_aps, rand_aps_scores))\n",
    "\n",
    "    return taus, gammas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the oracle experiment on the 3-feature / 3-class Gaussian generator.\n",
    "# Seed the NumPy and torch RNGs (data generation, presumably also model initialisation)\n",
    "# so that Restart & Run All reproduces the results.\n",
    "SEED = 0\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "clf = clf_3d_3c\n",
    "clf.fit(None, None)\n",
    "\n",
    "# Training-set sizes: number of generated instances, each paired with every class label.\n",
    "num_instances_to_check = np.linspace(10, 50, 5).astype(int)\n",
    "NUM_EPOCHS = 10  # short run; conduct_oracle_experiment defaults to 250\n",
    "\n",
    "taus, gammas = conduct_oracle_experiment(num_instances_to_check=num_instances_to_check, generator=clf, learning_rate=0.01, num_epochs=NUM_EPOCHS)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shutil\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "# usetex needs a LaTeX installation; without one every draw/save raises, so fall back to\n",
    "# matplotlib's built-in mathtext (which also renders $\\tau$).\n",
    "sns.set(font_scale=1.5, rc={'text.usetex': shutil.which(\"latex\") is not None})\n",
    "sns.set_style(\"whitegrid\")\n",
    "plt.rc('font', **{'family': 'serif'})\n",
    "plt.rcParams[\"figure.figsize\"] = (7, 3)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.set_title(\"\")\n",
    "ax.set_ylabel(r\"Kendall's $\\tau$-b\")\n",
    "ax.set_xlabel(r\"No. Instances\")\n",
    "# (key in taus, legend label, extra line style). taus also holds \"rand_aps\" and \"own_rand_aps\".\n",
    "series = [\n",
    "    (\"lac\", \"LAC\", {}),\n",
    "    (\"aps\", \"APS\", {}),\n",
    "    (\"own_aps\", \"Reconstructed APS\", {\"linestyle\": \"--\"}),\n",
    "]\n",
    "for key, label, line_kwargs in series:\n",
    "    sns.lineplot(x=num_instances_to_check, y=taus[key], ax=ax, marker=\"o\", label=label, legend=False, **line_kwargs)\n",
    "# Figure-level legend below the axes; passed to savefig so the tight bounding box keeps it.\n",
    "lgd = fig.legend(loc='upper center', ncol=3, bbox_to_anchor=(0.5, 0.08), frameon=False)\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig(\"replicating.pdf\", bbox_extra_artists=(lgd,), bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch check of the APS reconstruction for one model. NOTE: model_lac, oracle_annotator_aps,\n",
    "# X_test and y_test are local variables of conduct_oracle_experiment and are not defined at\n",
    "# notebook level by the cells above -- assign them before running this cell.\n",
    "with torch.no_grad():\n",
    "    neg_skills = -model_lac(X_test)\n",
    "y_test = torch.as_tensor(y_test, device=\"cuda\")\n",
    "# Same pseudo-probabilities as conduct_oracle_experiment: negated skills min-max scaled to [0, 1].\n",
    "span = (neg_skills.max() - neg_skills.min()).clamp_min(torch.finfo(neg_skills.dtype).tiny)\n",
    "pseudo_probs = (neg_skills - neg_skills.min()) / span\n",
    "own_aps = aps._calculate_single_label(pseudo_probs, y_test).detach().cpu().numpy()\n",
    "aps_scores = oracle_annotator_aps.get_conformity(X_test, y_test).detach().cpu().numpy()\n",
    "tau_corr, p_value = kendalltau(own_aps, aps_scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
