{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.base import BaseEstimator, ClassifierMixin\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.utils.validation import check_is_fitted\n",
    "from mapie.classification import MapieClassifier\n",
    "from util.ranking_datasets import DyadOneHotPairDataset\n",
    "from models.ranking_models import DyadRankingModel, SortLayer\n",
    "from torch.utils.data import DataLoader\n",
    "from scipy.stats import kendalltau\n",
    "from sklearn.datasets import make_classification\n",
    "from mapie.conformity_scores import LACConformityScore, APSConformityScore, TopKConformityScore\n",
    "import torch\n",
    "\n",
    "class MultinomialSyntheticDataGenerator(BaseEstimator, ClassifierMixin):\n",
    "    def __init__(self, random_state=None):\n",
    "        \"\"\"\n",
    "        A custom estimator for generating synthetic data using multinomial logistic regression,\n",
    "        with the feature distribution inferred from the training data.\n",
    "        \n",
    "        Parameters:\n",
    "        - random_state (int or None): Seed for reproducibility. It is passed to LogisticRegression\n",
    "          and also used to seed NumPy's global RNG, which generate() and generate_instances() draw from.\n",
    "        \"\"\"\n",
    "        self.random_state = random_state\n",
    "        # NOTE(review): this seeds the *global* NumPy RNG, so it affects every later np.random call\n",
    "        # in the notebook; kept as-is so existing results stay reproducible.\n",
    "        np.random.seed(self.random_state)\n",
    "\n",
    "\n",
    "    def fit(self, X, y):\n",
    "        \"\"\"\n",
    "        Fits a multinomial logistic regression model to the data and estimates the feature distribution.\n",
    "        \n",
    "        Parameters:\n",
    "        - X (ndarray): Feature matrix of shape (n_samples, n_features).\n",
    "        - y (ndarray): Target labels of shape (n_samples,).\n",
    "        \n",
    "        Returns:\n",
    "        - self: The fitted instance.\n",
    "        \"\"\"\n",
    "        # Store the label set and the mean and covariance of the features\n",
    "        self.classes_ = np.unique(y)\n",
    "        self.feature_mean_ = np.mean(X, axis=0)\n",
    "        self.feature_cov_ = np.cov(X, rowvar=False)\n",
    "        \n",
    "        # Fit a logistic regression model for P(y | x)\n",
    "        self.model_ = LogisticRegression(multi_class=\"multinomial\", solver=\"lbfgs\", random_state=self.random_state)\n",
    "        self.model_.fit(X, y)\n",
    "        \n",
    "        # Store the number of classes and features\n",
    "        self.n_classes_ = len(self.classes_)\n",
    "        self.n_features_ = X.shape[1]\n",
    "        \n",
    "        return self\n",
    "\n",
    "    def predict_proba(self, X):\n",
    "        \"\"\"\n",
    "        Predicts class probabilities for the given feature matrix.\n",
    "        \n",
    "        Parameters:\n",
    "        - X (ndarray): Feature matrix of shape (n_samples, n_features).\n",
    "        \n",
    "        Returns:\n",
    "        - probabilities (ndarray): Predicted probabilities of shape (n_samples, n_classes),\n",
    "          with columns in the order of self.classes_.\n",
    "        \"\"\"\n",
    "        check_is_fitted(self, \"model_\")\n",
    "        return self.model_.predict_proba(X)\n",
    "    \n",
    "\n",
    "    def predict(self, X):\n",
    "        \"\"\"\n",
    "        Predicts class labels for the given feature matrix.\n",
    "        \n",
    "        Parameters:\n",
    "        - X (ndarray): Feature matrix of shape (n_samples, n_features).\n",
    "        \n",
    "        Returns:\n",
    "        - labels (ndarray): Predicted labels of shape (n_samples,).\n",
    "        \"\"\"\n",
    "        check_is_fitted(self, \"model_\")\n",
    "        return self.model_.predict(X)\n",
    "\n",
    "    def generate(self, n):\n",
    "        \"\"\"\n",
    "        Generates synthetic data and labels based on the learned model and feature distribution.\n",
    "        \n",
    "        Parameters:\n",
    "        - n (int): Number of synthetic samples to generate.\n",
    "        \n",
    "        Returns:\n",
    "        - X_synthetic (ndarray): Generated feature matrix of shape (n, n_features).\n",
    "        - y_synthetic (ndarray): Generated labels of shape (n,), taken from self.classes_.\n",
    "        \"\"\"\n",
    "        # Generate synthetic features based on the inferred distribution\n",
    "        X_synthetic = self.generate_instances(n)\n",
    "        if n == 0:\n",
    "            # predict_proba rejects empty input; there is nothing to label\n",
    "            return X_synthetic, self.classes_[:0]\n",
    "        \n",
    "        # Compute class probabilities (columns follow self.classes_)\n",
    "        P_Y_given_X = self.predict_proba(X_synthetic)\n",
    "        \n",
    "        # Sample the actual class labels rather than column indices, so label sets that are\n",
    "        # not exactly 0..n_classes-1 are reproduced correctly\n",
    "        y_synthetic = np.array([np.random.choice(self.classes_, p=probs) for probs in P_Y_given_X])\n",
    "        \n",
    "        return X_synthetic, y_synthetic\n",
    "\n",
    "\n",
    "    def generate_instances(self, n):\n",
    "        \"\"\"\n",
    "        Generates synthetic feature vectors (without labels) from the inferred feature distribution.\n",
    "        \n",
    "        Parameters:\n",
    "        - n (int): Number of instances to generate.\n",
    "        \n",
    "        Returns:\n",
    "        - X (ndarray): Generated feature matrix of shape (n, n_features).\n",
    "        \"\"\"\n",
    "        check_is_fitted(self, [\"model_\", \"feature_mean_\", \"feature_cov_\"])\n",
    "        \n",
    "        # Sample from a multivariate normal with the mean and covariance estimated in fit()\n",
    "        X = np.random.multivariate_normal(self.feature_mean_, self.feature_cov_, n)\n",
    "        return X\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class OracleAnnotator:\n",
    "    \"\"\"\n",
    "    Builds pairs of (instance, label) dyads ordered by the conformity score of a fitted MAPIE classifier.\n",
    "    \n",
    "    Parameters:\n",
    "    - mapie_clf: fitted MapieClassifier; its estimator and conformity score function define the scores.\n",
    "    - generator: fitted MultinomialSyntheticDataGenerator used to sample instances.\n",
    "    \"\"\"\n",
    "    def __init__(self,mapie_clf, generator):\n",
    "        self.mapie_clf = mapie_clf\n",
    "        self.classes_ = mapie_clf.classes_\n",
    "        self.generator = generator\n",
    "\n",
    "    def _sort_pairs_by_conformity(self, X, y, n):\n",
    "        \"\"\"\n",
    "        Groups consecutive rows of (X, y) into n pairs and orders each pair by descending conformity score.\n",
    "        \n",
    "        Parameters:\n",
    "        - X (ndarray): Feature matrix of shape (2*n, n_features); rows 2i and 2i+1 form pair i.\n",
    "        - y (ndarray): Labels of shape (2*n,), aligned with X.\n",
    "        - n (int): Number of pairs.\n",
    "        \n",
    "        Returns:\n",
    "        - X_pairs (ndarray): Shape (n, 2, n_features); the dyad with the higher score comes first.\n",
    "        - y_pairs (ndarray): Shape (n, 2, 1), ordered like X_pairs.\n",
    "        \"\"\"\n",
    "        conformities = self.get_conformity(X, y)\n",
    "        \n",
    "        X_rs = X.reshape(n, 2, self.generator.n_features_)\n",
    "        y_rs = y.reshape(n, 2)\n",
    "        # argsort of the negated scores gives descending order within each pair\n",
    "        sort_idx = (-conformities.reshape(n, 2)).argsort(axis=1)\n",
    "        X_pairs = np.take_along_axis(X_rs, sort_idx[:, :, np.newaxis], axis=1)\n",
    "        y_pairs = np.expand_dims(np.take_along_axis(y_rs, sort_idx, axis=1), axis=-1)\n",
    "        return X_pairs, y_pairs\n",
    "\n",
    "    def generate_pairs_in_instance(self, n):\n",
    "        \"\"\"\n",
    "        Generates n pairs that share one sampled instance but carry two distinct labels.\n",
    "        \n",
    "        Parameters:\n",
    "        - n (int): Number of pairs to generate.\n",
    "        \n",
    "        Returns:\n",
    "        - X_pairs (ndarray): Shape (n, 2, n_features), ordered by descending conformity score.\n",
    "        - y_pairs (ndarray): Shape (n, 2, 1), ordered like X_pairs.\n",
    "        \"\"\"\n",
    "        # Each sampled instance appears twice, once for each label of its pair\n",
    "        X = np.repeat(self.generator.generate_instances(n), repeats=2, axis=0)\n",
    "        # Two distinct labels per instance (requires at least two classes)\n",
    "        y = np.hstack([np.random.choice(self.classes_, size=2, replace=False) for _ in range(n)])\n",
    "        return self._sort_pairs_by_conformity(X, y, n)\n",
    "\n",
    "\n",
    "    def generate_pairs_cross_instance(self, n):\n",
    "        \"\"\"\n",
    "        Generates n pairs of two independently sampled instances, each with an independently drawn label.\n",
    "        \n",
    "        Parameters:\n",
    "        - n (int): Number of pairs to generate.\n",
    "        \n",
    "        Returns:\n",
    "        - X_pairs (ndarray): Shape (n, 2, n_features), ordered by descending conformity score.\n",
    "        - y_pairs (ndarray): Shape (n, 2, 1), ordered like X_pairs.\n",
    "        \"\"\"\n",
    "        X = self.generator.generate_instances(2 * n)\n",
    "        y = np.random.choice(self.classes_, size=2 * n, replace=True)\n",
    "        return self._sort_pairs_by_conformity(X, y, n)\n",
    "    \n",
    "    def create_pairs_for_classification_data(self, X):\n",
    "        \"\"\"\n",
    "        Pairs up consecutive rows of the given feature matrix, draws a random label for each row,\n",
    "        and orders each pair by descending conformity score.\n",
    "        \n",
    "        Parameters:\n",
    "        - X (array-like): Feature matrix of shape (m, n_features). Rows 2i and 2i+1 form pair i;\n",
    "          if m is odd, the last row is dropped.\n",
    "        \n",
    "        Returns:\n",
    "        - X_pairs (ndarray): Shape (m // 2, 2, n_features).\n",
    "        - y_pairs (ndarray): Shape (m // 2, 2, 1), ordered like X_pairs.\n",
    "        \"\"\"\n",
    "        # NOTE(review): assumed intent -- pair the supplied rows the same way\n",
    "        # generate_pairs_cross_instance pairs sampled ones; confirm.\n",
    "        X = np.asarray(X)\n",
    "        n = len(X) // 2\n",
    "        X = X[: 2 * n]\n",
    "        y = np.random.choice(self.classes_, size=2 * n, replace=True)\n",
    "        return self._sort_pairs_by_conformity(X, y, n)\n",
    "\n",
    "    # we assume y is already label encoded\n",
    "    def get_conformity(self, X, y):\n",
    "        \"\"\"\n",
    "        Computes the MAPIE conformity scores of the dyads (X[i], y[i]).\n",
    "        \n",
    "        y is passed both as the labels and as y_enc, so it must already be label encoded.\n",
    "        \"\"\"\n",
    "        y_pred_proba = self.mapie_clf.estimator.predict_proba(X)\n",
    "        scores = self.mapie_clf.conformity_score_function_.get_conformity_scores(\n",
    "                        y, y_pred_proba, y_enc=y\n",
    "                    )\n",
    "        return scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from venv import create\n",
    "# from util.ranking_datasets import DyadOneHotPairDataset\n",
    "# from mapie.classification import MapieClassifier\n",
    "# from mapie.conformity_scores.sets import APSConformityScore, LACConformityScore, NaiveConformityScore, TopKConformityScore\n",
    "# from torch.utils.data.dataloader import DataLoader\n",
    "# from sklearn.datasets import make_classification\n",
    "# from scipy.stats import kendalltau\n",
    "# import matplotlib.pyplot as plt\n",
    "\n",
    "# def create_dyads(X,y, n_classes):\n",
    "#     y_1h = np.eye(n_classes)[y.reshape(-1)].reshape(*y.shape, n_classes)\n",
    "#     dyads = np.concatenate((X, y_1h.squeeze()), axis=1)\n",
    "#     return dyads\n",
    "\n",
    "# X_seed, y_seed = make_classification(n_samples=1000, n_features=3, n_classes=3, n_informative=3, n_redundant=0, n_repeated=0, n_clusters_per_class=1, random_state=42)\n",
    "# conformity_score = APSConformityScore()\n",
    "# generator = MultinomialSyntheticDataGenerator(random_state=42)\n",
    "# generator.fit(X_seed, y_seed)\n",
    "# X_cal, y_cal = generator.generate(n=100)\n",
    "# mapie_clf = MapieClassifier(estimator=generator, cv=\"prefit\", conformity_score=conformity_score)\n",
    "# # create mapie classifier for conformity scores\n",
    "# mapie_clf.fit(X_cal, y_cal)\n",
    "# # create \n",
    "# oracle_annotator = OracleAnnotator(mapie_clf, generator)\n",
    "\n",
    "# # generate all possible pairs for a couple of instances\n",
    "\n",
    "# def create_training_data(n_instances):\n",
    "#     n_classes = len(generator.classes_)\n",
    "#     X_train = generator.generate_instances(n_instances).repeat(n_classes, axis=0)\n",
    "#     y_train = np.tile(generator.classes_, n_instances)\n",
    "#     conformities = oracle_annotator.get_conformity(X_train,y_train)\n",
    "#     sort_idx = (-conformities).argsort(axis=0).flatten()\n",
    "\n",
    "#     X_sorted = X_train[sort_idx]\n",
    "#     y_sorted = y_train[sort_idx]\n",
    "\n",
    "#     X_pairs = np.array([(X_sorted[i], X_sorted[j]) for i in range(len(X_sorted)) for j in range(i + 1, len(X_sorted))])\n",
    "#     y_pairs = np.array([(y_sorted[i], y_sorted[j]) for i in range(len(y_sorted)) for j in range(i + 1, len(y_sorted))])\n",
    "#     y_pairs = np.expand_dims(y_pairs, axis=-1)\n",
    "#     y_pairs_1h = np.eye(n_classes)[y_pairs.reshape(-1)].reshape(*y_pairs.shape, n_classes)\n",
    "#     dyads = np.concatenate((X_pairs, y_pairs_1h.squeeze()), axis=2)\n",
    "#     ds_1h = DyadOneHotPairDataset()\n",
    "#     ds_1h.create_from_numpy_dyad_pairs(dyads)\n",
    "#     return ds_1h, X_train, y_train\n",
    "\n",
    "# from models.ranking_models import DyadRankingModel, SortLayer\n",
    "# import torch\n",
    "\n",
    "# model = DyadRankingModel(input_dim=6,hidden_dims=[6,6,6],activations=[torch.nn.Sigmoid(), SortLayer(), torch.nn.Identity()])\n",
    "\n",
    "\n",
    "# train_data, X_train, y_train = create_training_data(100)\n",
    "# val_data, X_val, y_val = create_training_data(50)\n",
    "\n",
    "# train_loader = DataLoader(train_data, 64)\n",
    "# val_loader = DataLoader(val_data, 64)\n",
    "\n",
    "\n",
    "# model._fit(train_loader,val_loader=val_loader, num_epochs=200, patience=1000, learning_rate=0.01, verbose=True)\n",
    "\n",
    "# def create_dyads(X,y, n_classes):\n",
    "#     y_1h = np.eye(n_classes)[y.reshape(-1)].reshape(*y.shape, n_classes)\n",
    "#     dyads = np.concatenate((X, y_1h.squeeze()), axis=1)\n",
    "#     return dyads\n",
    "\n",
    "# X_test, y_test = generator.generate(10)\n",
    "\n",
    "# conformities = oracle_annotator.get_conformity(X_test, y_test)\n",
    "\n",
    "# dyads_test= create_dyads(X_test, y_test, 3)\n",
    "# dyads_tensor = torch.tensor(dyads_test, dtype=torch.float32)\n",
    "# skills = model(dyads_tensor).detach().cpu().numpy()\n",
    "# print(\"out of sample\", kendalltau(skills, conformities))\n",
    "\n",
    "# X_test, y_test = X_train, y_train\n",
    "\n",
    "# conformities = oracle_annotator.get_conformity(X_test, y_test)\n",
    "\n",
    "# dyads_test= create_dyads(X_test, y_test, 3)\n",
    "# dyads_tensor = torch.tensor(dyads_test, dtype=torch.float32)\n",
    "# skills = model(dyads_tensor).detach().cpu().numpy()\n",
    "# print(\"in of sample\", kendalltau(skills, conformities))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _build_pair_dataset(generator, oracle_annotator, num_instances):\n",
    "    \"\"\"\n",
    "    Builds the training pairs for one run of conduct_oracle_experiment.\n",
    "    \n",
    "    Each of num_instances sampled instances is crossed with every class, the resulting dyads are\n",
    "    sorted by descending oracle conformity score, and every pair (i, j) with i < j in that order\n",
    "    becomes one training pair (so the first dyad of a pair has a score at least as high as the second).\n",
    "    \n",
    "    Returns:\n",
    "    - DyadOneHotPairDataset built from dyads of shape (n_pairs, 2, n_features + n_classes).\n",
    "    \"\"\"\n",
    "    n_classes = generator.n_classes_\n",
    "    X_train = generator.generate_instances(num_instances).repeat(n_classes, axis=0)\n",
    "    y_train = np.tile(generator.classes_, num_instances)\n",
    "    conformities = oracle_annotator.get_conformity(X_train, y_train)\n",
    "    sort_idx = (-conformities).argsort(axis=0).flatten()\n",
    "\n",
    "    X_sorted = X_train[sort_idx]\n",
    "    y_sorted = y_train[sort_idx]\n",
    "\n",
    "    # All index pairs i < j, in the same row-major order as a nested i/j loop, without a Python O(n^2) loop\n",
    "    first, second = np.triu_indices(len(X_sorted), k=1)\n",
    "    X_pairs = np.stack((X_sorted[first], X_sorted[second]), axis=1)\n",
    "    y_pairs = np.stack((y_sorted[first], y_sorted[second]), axis=1)\n",
    "\n",
    "    # One-hot encode the labels; assumes labels are 0..n_classes-1 since they index np.eye\n",
    "    y_pairs_1h = np.eye(n_classes)[y_pairs.reshape(-1)].reshape(*y_pairs.shape, n_classes)\n",
    "    dyads = np.concatenate((X_pairs, y_pairs_1h), axis=2)\n",
    "    ds_1h = DyadOneHotPairDataset()\n",
    "    ds_1h.create_from_numpy_dyad_pairs(dyads)\n",
    "    return ds_1h\n",
    "\n",
    "\n",
    "def conduct_oracle_experiment(conformity_score, num_instances_to_check, generator, X_cal, y_cal, num_test_instances=100):\n",
    "    \"\"\"\n",
    "    Measures how well a learned dyad ranking model recovers the ordering induced by a MAPIE conformity score.\n",
    "    \n",
    "    For each entry of num_instances_to_check, a DyadRankingModel is trained on all conformity-ordered\n",
    "    pairs built from that many sampled instances and evaluated by Kendall's tau between its predicted\n",
    "    skills and the oracle conformity scores on freshly sampled test dyads.\n",
    "    \n",
    "    Parameters:\n",
    "    - conformity_score: MAPIE conformity score object (e.g. LACConformityScore()).\n",
    "    - num_instances_to_check (iterable of int): numbers of training instances to evaluate.\n",
    "    - generator: fitted MultinomialSyntheticDataGenerator; MAPIE's prefit estimator and the data source.\n",
    "    - X_cal, y_cal (ndarray): calibration data passed to MapieClassifier.fit.\n",
    "    - num_test_instances (int): number of test instances, each crossed with every class (default 100).\n",
    "    \n",
    "    Returns:\n",
    "    - tau_corrs (list of float): Kendall's tau for each entry of num_instances_to_check.\n",
    "    - skills_from_model: predicted skills on the test dyads of the last run (None if there were no runs).\n",
    "    - conformity_scores: oracle scores on the test dyads of the last run (None if there were no runs).\n",
    "    - models (list): the trained DyadRankingModel of every run.\n",
    "    \"\"\"\n",
    "    tau_corrs = []\n",
    "    models = []\n",
    "    skills_from_model = None\n",
    "    conformity_scores = None\n",
    "\n",
    "    # create mapie classifier for conformity scores (the generator is already fitted, hence cv=\"prefit\")\n",
    "    mapie_clf = MapieClassifier(estimator=generator, cv=\"prefit\", conformity_score=conformity_score)\n",
    "    mapie_clf.fit(X_cal, y_cal)\n",
    "    oracle_annotator = OracleAnnotator(mapie_clf, generator)\n",
    "\n",
    "    n_classes = generator.n_classes_\n",
    "    # A dyad is the instance's features concatenated with a one-hot encoding of the class\n",
    "    dyad_dim = generator.n_features_ + n_classes\n",
    "\n",
    "    for num_instances in num_instances_to_check:\n",
    "        pair_loader = DataLoader(_build_pair_dataset(generator, oracle_annotator, num_instances), batch_size=64)\n",
    "        model = DyadRankingModel(input_dim=dyad_dim,hidden_dims=[6,6,6],activations=[torch.nn.Sigmoid(), SortLayer(), torch.nn.ReLU()])\n",
    "\n",
    "        model.num_classes = n_classes\n",
    "        # model.to(\"cuda\")\n",
    "        device = next(model.parameters()).device\n",
    "        print(f\"Model is on: {device}\")\n",
    "        # NOTE(review): the training pairs double as the validation set, so early stopping is not out-of-sample\n",
    "        model._fit(pair_loader, val_loader=pair_loader, num_epochs=250, learning_rate=0.001, patience=100, verbose=True)\n",
    "\n",
    "        # generate data from data generating process and check whether the learned non-conformity relation sorts them correctly\n",
    "        X_test = generator.generate_instances(num_test_instances).repeat(n_classes, axis=0)\n",
    "        y_test = np.tile(generator.classes_, num_test_instances)\n",
    "        skills_from_model = np.take_along_axis(model.predict_class_skills(X_test), y_test[:, np.newaxis], axis=1)\n",
    "        conformity_scores = oracle_annotator.get_conformity(X_test, y_test)\n",
    "        tau_corr, _p_value = kendalltau(skills_from_model, conformity_scores)\n",
    "        tau_corrs.append(tau_corr)\n",
    "        # keep the trained model itself (appending `models` would store the list inside itself)\n",
    "        models.append(model)\n",
    "        torch.cuda.empty_cache()\n",
    "    return tau_corrs, skills_from_model, conformity_scores, models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training-set sizes (numbers of sampled instances) evaluated by conduct_oracle_experiment\n",
    "num_pairs_to_check = np.linspace(10, 100, num=3).astype(int)\n",
    "\n",
    "# Seed data that the synthetic generator is fitted to\n",
    "X_train, y_train = make_classification(\n",
    "    n_samples=100,\n",
    "    n_features=3,\n",
    "    n_classes=3,\n",
    "    n_informative=3,\n",
    "    n_redundant=0,\n",
    "    n_repeated=0,\n",
    "    n_clusters_per_class=1,\n",
    "    random_state=42,\n",
    ")\n",
    "\n",
    "# Fit the generator (fit returns the estimator itself), then draw calibration data from it\n",
    "generator = MultinomialSyntheticDataGenerator(random_state=42).fit(X_train, y_train)\n",
    "X_cal, y_cal = generator.generate(n=100)\n",
    "\n",
    "tau_corrs_LAC, skills_LAC, conformities_LAC, models_LAC = conduct_oracle_experiment(\n",
    "    LACConformityScore(), num_pairs_to_check, generator, X_cal, y_cal\n",
    ")\n",
    "# Disabled: needs NaiveConformityScore imported; conduct_oracle_experiment returns four values.\n",
    "# tau_corrs_Naive, skills_Naive, conformities_Naive, models_Naive = conduct_oracle_experiment(NaiveConformityScore(), num_pairs_to_check, generator, X_cal, y_cal)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same experiment for the APS and TopK conformity scores\n",
    "tau_corrs_APS, skills_APS, conformities_APS, models_APS = conduct_oracle_experiment(APSConformityScore(), num_pairs_to_check, generator, X_cal, y_cal)\n",
    "tau_corrs_TopK, skills_TopK, conformities_TopK, models_TopK = conduct_oracle_experiment(TopKConformityScore(), num_pairs_to_check, generator, X_cal, y_cal)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "sns.set(font_scale=1.5,rc={'text.usetex' : True})\n",
    "sns.set_style(\"whitegrid\")\n",
    "plt.rc('font', **{'family': 'serif'})\n",
    "plt.rcParams[\"figure.figsize\"] = (12, 3)\n",
    "\n",
    "# One panel per conformity score; the Kendall's tau axis is shared across panels\n",
    "results_per_score = [(\"LAC\", tau_corrs_LAC), (\"APS\", tau_corrs_APS), (\"TopK\", tau_corrs_TopK)]\n",
    "\n",
    "fig, axes = plt.subplots(1, len(results_per_score), sharey=True)\n",
    "axes = axes.ravel()\n",
    "axes[0].set_ylim([-0.1,1])\n",
    "\n",
    "for ax, (score_name, score_taus) in zip(axes, results_per_score):\n",
    "    ax.set_title(score_name)\n",
    "    ax.set_ylabel(r\"Kendalls $\\tau$\")\n",
    "    # x values are the numbers of training instances passed to conduct_oracle_experiment,\n",
    "    # not the number of training pairs built from them\n",
    "    ax.set_xlabel(r\"No. Instances\")\n",
    "    sns.lineplot(x=num_pairs_to_check, y=score_taus, ax=ax, marker=\"o\")\n",
    "\n",
    "fig.tight_layout()\n",
    "# Separate file name so the combined figure saved by the next cell does not overwrite this one\n",
    "plt.savefig(\"replicating_per_score.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "sns.set(font_scale=1.5,rc={'text.usetex' : True})\n",
    "sns.set_style(\"whitegrid\")\n",
    "plt.rc('font', **{'family': 'serif'})\n",
    "plt.rcParams[\"figure.figsize\"] = (7, 3)\n",
    "\n",
    "# All conformity scores in a single panel, with one shared legend placed below the axes\n",
    "fig, ax = plt.subplots()\n",
    "\n",
    "ax.set_title(\"\")\n",
    "ax.set_ylabel(r\"Kendalls $\\tau$\")\n",
    "# x values are the numbers of training instances per run, not the number of training pairs\n",
    "ax.set_xlabel(r\"No. Instances\")\n",
    "ax.set_ylim([-0.2,1])\n",
    "for score_name, score_taus in [(\"LAC\", tau_corrs_LAC), (\"APS\", tau_corrs_APS), (\"TopK\", tau_corrs_TopK)]:\n",
    "    sns.lineplot(x=num_pairs_to_check, y=score_taus, ax=ax, marker=\"o\", label=score_name, legend=False)\n",
    "lgd = fig.legend(loc='upper center', ncol=3, bbox_to_anchor=(0.5, 0.08), frameon=False)\n",
    "\n",
    "fig.tight_layout() \n",
    "# bbox_extra_artists keeps the legend, which sits outside the axes, inside the saved bounding box\n",
    "plt.savefig(\"replicating.pdf\",bbox_extra_artists=(lgd,), bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "plnet",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
