{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.base import BaseEstimator, ClassifierMixin\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.utils.validation import check_is_fitted\n",
    "import torch\n",
    "\n",
    "class MultinomialSyntheticDataGenerator(BaseEstimator, ClassifierMixin):\n",
    "    \"\"\"\n",
    "    Synthetic data generator built on a multinomial logistic regression.\n",
    "\n",
    "    Features are sampled from a Gaussian whose mean and covariance are estimated\n",
    "    from the training data; labels are sampled from the class probabilities that\n",
    "    the fitted logistic regression predicts for those features.\n",
    "\n",
    "    Parameters:\n",
    "    - random_state (int or None): Seed for reproducibility. It seeds NumPy's global\n",
    "      random number generator and is passed on to the logistic regression.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, random_state=None):\n",
    "        self.random_state = random_state\n",
    "        # Seeds NumPy's *global* RNG, which generate()/generate_instances() draw from.\n",
    "        np.random.seed(self.random_state)\n",
    "\n",
    "    def fit(self, X, y):\n",
    "        \"\"\"\n",
    "        Fits a multinomial logistic regression model to the data and estimates the feature distribution.\n",
    "\n",
    "        Parameters:\n",
    "        - X (ndarray): Feature matrix of shape (n_samples, n_features).\n",
    "        - y (ndarray): Target labels of shape (n_samples,).\n",
    "\n",
    "        Returns:\n",
    "        - self: The fitted instance.\n",
    "        \"\"\"\n",
    "        # Store the classes and the number of classes and features\n",
    "        self.classes_ = np.unique(y)\n",
    "        self.n_classes_ = len(self.classes_)\n",
    "        self.n_features_ = X.shape[1]\n",
    "\n",
    "        # Store mean and covariance of features\n",
    "        self.feature_mean_ = np.mean(X, axis=0)\n",
    "        self.feature_cov_ = np.cov(X, rowvar=False)\n",
    "\n",
    "        # Fit a logistic regression model\n",
    "        self.model_ = LogisticRegression(multi_class=\"multinomial\", solver=\"lbfgs\", random_state=self.random_state)\n",
    "        self.model_.fit(X, y)\n",
    "\n",
    "        return self\n",
    "\n",
    "    def predict_proba(self, X):\n",
    "        \"\"\"\n",
    "        Predicts class probabilities for the given feature matrix.\n",
    "\n",
    "        Parameters:\n",
    "        - X (ndarray): Feature matrix of shape (n_samples, n_features).\n",
    "\n",
    "        Returns:\n",
    "        - probabilities (ndarray): Predicted probabilities of shape (n_samples, n_classes).\n",
    "        \"\"\"\n",
    "        check_is_fitted(self, \"model_\")\n",
    "        return self.model_.predict_proba(X)\n",
    "\n",
    "    def predict(self, X):\n",
    "        \"\"\"\n",
    "        Predicts class labels for the given feature matrix.\n",
    "\n",
    "        Parameters:\n",
    "        - X (ndarray): Feature matrix of shape (n_samples, n_features).\n",
    "\n",
    "        Returns:\n",
    "        - labels (ndarray): Predicted labels of shape (n_samples,).\n",
    "        \"\"\"\n",
    "        check_is_fitted(self, \"model_\")\n",
    "        return self.model_.predict(X)\n",
    "\n",
    "    def generate(self, n):\n",
    "        \"\"\"\n",
    "        Generates synthetic features and labels based on the learned model and feature distribution.\n",
    "\n",
    "        Parameters:\n",
    "        - n (int): Number of synthetic samples to generate.\n",
    "\n",
    "        Returns:\n",
    "        - X_synthetic (ndarray): Generated feature matrix of shape (n, n_features).\n",
    "        - y_synthetic (ndarray): Generated labels of shape (n,), taken from classes_.\n",
    "        \"\"\"\n",
    "        X_synthetic = self.generate_instances(n)\n",
    "\n",
    "        # Compute class probabilities\n",
    "        P_Y_given_X = self.predict_proba(X_synthetic)\n",
    "\n",
    "        # Sample the actual class labels (not their positions) so that non 0..K-1 labels survive\n",
    "        y_synthetic = np.array([np.random.choice(self.classes_, p=probs) for probs in P_Y_given_X])\n",
    "\n",
    "        return X_synthetic, y_synthetic\n",
    "\n",
    "    def generate_instances(self, n):\n",
    "        \"\"\"\n",
    "        Generates synthetic features only, based on the inferred feature distribution.\n",
    "\n",
    "        Parameters:\n",
    "        - n (int): Number of synthetic instances to generate.\n",
    "\n",
    "        Returns:\n",
    "        - X (ndarray): Generated feature matrix of shape (n, n_features).\n",
    "        \"\"\"\n",
    "        check_is_fitted(self, [\"model_\", \"feature_mean_\", \"feature_cov_\"])\n",
    "\n",
    "        if self.n_features_ == 1:\n",
    "            # For a single feature np.cov yields the variance, while np.random.normal\n",
    "            # expects the standard deviation as its scale.\n",
    "            feature_std = np.sqrt(self.feature_cov_)\n",
    "            return np.random.normal(self.feature_mean_, feature_std, n).reshape(-1, 1)\n",
    "        return np.random.multivariate_normal(self.feature_mean_, self.feature_cov_, n)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class OracleAnnotator:\n",
    "    def __init__(self,mapie_clf, generator):\n",
    "        self.mapie_clf = mapie_clf\n",
    "        self.classes_ = mapie_clf.classes_\n",
    "        self.generator = generator\n",
    "\n",
    "    def generate_pairs_in_instance(self, n):\n",
    "        \"\"\"\n",
    "        Generates synthetic data and labels based on the learned model and feature distribution.\n",
    "        \n",
    "        Returns:\n",
    "        - X_synthetic (ndarray): Generated feature matrix of shape (n_samples, n_features).\n",
    "        - y_synthetic (ndarray): Generated labels of shape (n_samples,).\n",
    "        \"\"\"        \n",
    "        # Generate synthetic features based on the inferred distribution\n",
    "        X = self.generator.generate_instances(n)\n",
    "        X = np.repeat(X, repeats=2, axis=0)\n",
    "\n",
    "        y = np.hstack([np.random.choice(self.classes_, size=2, replace=False) for _ in range(n)])\n",
    "\n",
    "        conformities = self.get_conformity(X,y)\n",
    "\n",
    "        X_rs = X.reshape(n,2,self.generator.n_features_)\n",
    "        y_rs = y.reshape(n,2)\n",
    "        conformities_n_rs = - conformities.reshape(n,2)\n",
    "        sort_idx = conformities_n_rs.argsort(axis=1)\n",
    "        X_rs[sort_idx]\n",
    "        y_rs[sort_idx,:]\n",
    "        X_pairs = np.take_along_axis(X_rs, sort_idx[:, :, np.newaxis], axis=1)\n",
    "        y_pairs = np.expand_dims(np.take_along_axis(y_rs, sort_idx, axis=1),axis=-1)\n",
    "        return X_pairs, y_pairs\n",
    "\n",
    "\n",
    "    def generate_pairs_cross_instance(self, n):\n",
    "        \"\"\"\n",
    "        Generates synthetic data and labels based on the learned model and feature distribution.\n",
    "        \n",
    "        Returns:\n",
    "        - X_synthetic (ndarray): Generated feature matrix of shape (n_samples, n_features).\n",
    "        - y_synthetic (ndarray): Generated labels of shape (n_samples,).\n",
    "        \"\"\"        \n",
    "        # Generate synthetic features based on the inferred distribution\n",
    "        \n",
    "        X = self.generator.generate_instances(2*n)\n",
    "        y = np.random.choice(self.classes_, size=2*n, replace=True)\n",
    "        conformities = self.get_conformity(X,y)\n",
    "\n",
    "        X_rs = X.reshape(n,2,self.generator.n_features_)\n",
    "        y_rs = y.reshape(n,2)\n",
    "        conformities_n_rs = - conformities.reshape(n,2)\n",
    "        sort_idx = conformities_n_rs.argsort(axis=1)\n",
    "        X_rs[sort_idx]\n",
    "        y_rs[sort_idx,:]\n",
    "        X_pairs = np.take_along_axis(X_rs, sort_idx[:, :, np.newaxis], axis=1)\n",
    "        y_pairs = np.expand_dims(np.take_along_axis(y_rs, sort_idx, axis=1),axis=-1)\n",
    "\n",
    "        return X_pairs, y_pairs\n",
    "    \n",
    "    # def create_pairs_for_classification_data(self, X):\n",
    "    #     \"\"\"\n",
    "    #     Generates synthetic data and labels based on the learned model and feature distribution.\n",
    "        \n",
    "    #     Returns:\n",
    "    #     - X_synthetic (ndarray): Generated feature matrix of shape (n_samples, n_features).\n",
    "    #     - y_synthetic (ndarray): Generated labels of shape (n_samples,).\n",
    "    #     \"\"\"\n",
    "    #     # Generate synthetic features based on the inferred distribution\n",
    "        \n",
    "    #     X = self.generator.generate_instances(2*n)\n",
    "    #     y = np.random.choice(self.classes_, size=2*n, replace=True)\n",
    "    #     conformities = self.get_conformity(X,y)\n",
    "\n",
    "    #     X_rs = X.reshape(n,2,self.generator.n_features_)\n",
    "    #     y_rs = y.reshape(n,2)\n",
    "    #     conformities_n_rs = - conformities.reshape(n,2)\n",
    "    #     sort_idx = conformities_n_rs.argsort(axis=1)\n",
    "    #     X_rs[sort_idx]\n",
    "    #     y_rs[sort_idx,:]\n",
    "    #     X_pairs = np.take_along_axis(X_rs, sort_idx[:, :, np.newaxis], axis=1)\n",
    "    #     y_pairs = np.expand_dims(np.take_along_axis(y_rs, sort_idx, axis=1),axis=-1)\n",
    "\n",
    "        return X_pairs, y_pairs\n",
    "\n",
    "    # we assume y is already label encoded\n",
    "    def get_conformity(self, X, y):\n",
    "        y_pred_proba = self.mapie_clf.estimator.predict_proba(X)\n",
    "        scores = self.mapie_clf.conformity_score_function_.get_conformity_scores(\n",
    "                        y, y_pred_proba, y_enc=y\n",
    "                    )\n",
    "        return scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def goodman_kruskal_gamma(x, y):\n",
    "    \"\"\"\n",
    "    Compute Goodman and Kruskal's Gamma for two ordinal variables.\n",
    "    \n",
    "    Parameters:\n",
    "    x, y: Lists or arrays of ordinal data (same length)\n",
    "    \n",
    "    Returns:\n",
    "    gamma: Goodman and Kruskal's Gamma\n",
    "    \"\"\"\n",
    "    if len(x) != len(y):\n",
    "        raise ValueError(\"Both variables must have the same length.\")\n",
    "    \n",
    "    concordant = 0\n",
    "    discordant = 0\n",
    "    \n",
    "    n = len(x)\n",
    "    for i in range(n - 1):\n",
    "        for j in range(i + 1, n):\n",
    "            # Determine concordance or discordance\n",
    "            if (x[i] > x[j] and y[i] > y[j]) or (x[i] < x[j] and y[i] < y[j]):\n",
    "                concordant += 1\n",
    "            elif (x[i] > x[j] and y[i] < y[j]) or (x[i] < x[j] and y[i] > y[j]):\n",
    "                discordant += 1\n",
    "    \n",
    "    # Compute Gamma\n",
    "    if concordant + discordant == 0:\n",
    "        return 0  # Avoid division by zero\n",
    "    gamma = (concordant - discordant) / (concordant + discordant)\n",
    "    return gamma"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mapie.classification import MapieClassifier\n",
    "from mapie.conformity_scores.sets import APSConformityScore, LACConformityScore, NaiveConformityScore, TopKConformityScore\n",
    "from util.ranking_datasets import LabelPairDataset\n",
    "from models.ranking_models import LabelRankingModel\n",
    "from torch.utils.data.dataloader import DataLoader\n",
    "from sklearn.datasets import make_classification\n",
    "from scipy.stats import kendalltau\n",
    "import matplotlib.pyplot as plt\n",
    "from models.ranking_models import SortLayer\n",
    "import torch\n",
    "\n",
    "def conduct_oracle_experiment(conformity_score, num_instances_to_check, generator, X_cal, y_cal):\n",
    "    \"\"\"\n",
    "    Trains one label ranking model per training-set size on oracle-ordered pairs and\n",
    "    measures how well its learned skills agree with the oracle's conformity scores.\n",
    "\n",
    "    Parameters:\n",
    "    - conformity_score: Conformity score object handed to MapieClassifier.\n",
    "    - num_instances_to_check (iterable of int): Numbers of generated training instances to try.\n",
    "    - generator: Fitted data generator; used as the prefit estimator and to sample instances.\n",
    "    - X_cal, y_cal (ndarray): Data passed to MapieClassifier.fit.\n",
    "\n",
    "    Returns:\n",
    "    - tau_corrs (list): Kendall's tau for every entry of num_instances_to_check.\n",
    "    - gamma_corrs (list): Goodman and Kruskal's gamma for every entry.\n",
    "    - skills_from_model, conformity_scores: Test values of the last iteration (None if none ran).\n",
    "    - models (list): The trained model of every iteration.\n",
    "    \"\"\"\n",
    "    tau_corrs = []\n",
    "    gamma_corrs = []\n",
    "    models = []\n",
    "    skills_from_model = None\n",
    "    conformity_scores = None\n",
    "\n",
    "    # create mapie classifier for conformity scores\n",
    "    mapie_clf = MapieClassifier(estimator=generator, cv=\"prefit\", conformity_score=conformity_score)\n",
    "    mapie_clf.fit(X_cal, y_cal)\n",
    "    oracle_annotator = OracleAnnotator(mapie_clf, generator)\n",
    "\n",
    "    n_classes = len(generator.classes_)\n",
    "    n_test_instances = 100\n",
    "\n",
    "    for num_instances in num_instances_to_check:\n",
    "        # Pair every generated instance with every class label and sort by descending score\n",
    "        X_train = generator.generate_instances(num_instances).repeat(n_classes, axis=0)\n",
    "        y_train = np.tile(generator.classes_, num_instances)\n",
    "        conformities = oracle_annotator.get_conformity(X_train,y_train)\n",
    "        sort_idx = (-conformities).argsort(axis=0).flatten()\n",
    "\n",
    "        X_sorted = X_train[sort_idx]\n",
    "        y_sorted = y_train[sort_idx]\n",
    "\n",
    "        # All index pairs (i, j) with i < j, in the same order as a nested loop would produce\n",
    "        first, second = np.triu_indices(len(X_sorted), k=1)\n",
    "        X_pairs = np.stack([X_sorted[first], X_sorted[second]], axis=1)\n",
    "        y_pairs = np.stack([y_sorted[first], y_sorted[second]], axis=1)[..., np.newaxis]\n",
    "\n",
    "        ds = LabelPairDataset()\n",
    "        ds.create_from_numpy_pairs(X_pairs, y_pairs)\n",
    "        pair_loader = DataLoader(ds, batch_size=64)\n",
    "        print(len(ds))\n",
    "        model = LabelRankingModel(input_dim=generator.n_features_, hidden_dims=3*[generator.n_features_], activations=[torch.nn.Sigmoid(), torch.nn.Sigmoid(), torch.nn.Sigmoid()], output_dim=n_classes)\n",
    "        model.num_classes = generator.n_classes_\n",
    "        device = next(model.parameters()).device\n",
    "        print(f\"Model is on: {device}\")\n",
    "        # The training pairs double as validation data\n",
    "        model._fit(pair_loader, val_loader=pair_loader, num_epochs=300, learning_rate=0.01, patience=100, verbose=True)\n",
    "\n",
    "        # generate data from data generating process and check whether the learned non-conformity relation sorts them correctly\n",
    "        X_test = generator.generate_instances(n_test_instances).repeat(n_classes, axis=0)\n",
    "        y_test = np.tile(generator.classes_, n_test_instances)\n",
    "        skills_from_model = np.take_along_axis(model.predict_class_skills(X_test), y_test[:,np.newaxis], axis=1)\n",
    "        conformity_scores = oracle_annotator.get_conformity(X_test, y_test)\n",
    "        tau_corr, p_value = kendalltau(skills_from_model, conformity_scores)\n",
    "        tau_corrs.append(tau_corr)\n",
    "        gamma_corr = goodman_kruskal_gamma(skills_from_model,conformity_scores)\n",
    "        gamma_corrs.append(gamma_corr)\n",
    "        # Keep the trained model itself (the list used to be appended to itself)\n",
    "        models.append(model)\n",
    "    return tau_corrs, gamma_corrs, skills_from_model, conformity_scores, models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_instances_to_check = np.linspace(10,100,5).astype(int)\n",
    "X_train, y_train = make_classification(\n",
    "    n_samples=100, n_features=3, n_classes=3, n_informative=3, n_redundant=0, n_repeated=0, n_clusters_per_class=1, random_state=42\n",
    ")\n",
    "\n",
    "# Initialize and fit the generator\n",
    "generator = MultinomialSyntheticDataGenerator(random_state=42)\n",
    "generator.fit(X_train, y_train)\n",
    "X_cal, y_cal = generator.generate(n=100)\n",
    "\n",
    "# One experiment per conformity score; the APS results are stored under *_APS names\n",
    "tau_corrs_LAC, gamma_corrs_LAC, skills_LAC, conformities_LAC, models_LAC = conduct_oracle_experiment(LACConformityScore(), num_instances_to_check, generator, X_cal, y_cal)\n",
    "tau_corrs_APS, gamma_corrs_APS, skills_APS, conformities_APS, models_APS = conduct_oracle_experiment(APSConformityScore(), num_instances_to_check, generator, X_cal, y_cal)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "# Global styling: LaTeX text rendering, white grid, serif fonts, wide and flat figure\n",
    "sns.set(font_scale=1.5, rc={\"text.usetex\": True})\n",
    "sns.set_style(\"whitegrid\")\n",
    "plt.rc(\"font\", family=\"serif\")\n",
    "plt.rcParams[\"figure.figsize\"] = (7, 3)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.set_title(\"\")\n",
    "ax.set_ylabel(r\"Gamma\")\n",
    "ax.set_xlabel(r\"No. Instances\")\n",
    "\n",
    "# One line per conformity score, drawn in this order\n",
    "gamma_by_score = {\"LAC\": gamma_corrs_LAC, \"APS\": gamma_corrs_APS}\n",
    "for score_name, gamma_values in gamma_by_score.items():\n",
    "    sns.lineplot(x=num_instances_to_check, y=gamma_values, ax=ax, marker=\"o\", label=score_name, legend=False)\n",
    "\n",
    "# Shared legend placed below the axes; passed to savefig so it is not cropped\n",
    "lgd = fig.legend(loc=\"upper center\", ncol=3, bbox_to_anchor=(0.5, 0.08), frameon=False)\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.savefig(\"replicating.pdf\", bbox_extra_artists=(lgd,), bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create mapie classifier for conformity scores\n",
    "mapie_clf = MapieClassifier(estimator=generator, cv=\"prefit\", conformity_score=APSConformityScore())\n",
    "mapie_clf.fit(X_cal, y_cal)\n",
    "oracle_annotator = OracleAnnotator(mapie_clf, generator)\n",
    "\n",
    "# Inspect the last model of the APS experiment; it must not be replaced by a fresh,\n",
    "# untrained network, otherwise the scatter plot only shows random weights.\n",
    "model = models_APS[-1]\n",
    "\n",
    "# generate data from data generating process and check whether the learned non-conformity relation sorts them correctly\n",
    "n_classes = len(generator.classes_)\n",
    "n_test_instances = 100\n",
    "X_test = generator.generate_instances(n_test_instances).repeat(n_classes, axis=0)\n",
    "y_test = np.tile(generator.classes_, n_test_instances)\n",
    "skills_from_model = np.take_along_axis(model.predict_class_skills(X_test), y_test[:,np.newaxis], axis=1)\n",
    "conformity_scores = oracle_annotator.get_conformity(X_test, y_test)\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(skills_from_model, conformity_scores)\n",
    "ax.set_xlabel(\"learn utility\")\n",
    "ax.set_ylabel(\"conformity score\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_seed, y_seed = make_classification(n_samples=1000, n_features=3, n_classes=3, n_informative=3, n_redundant=0, n_repeated=0, n_clusters_per_class=1, random_state=42)\n",
    "conformity_score = APSConformityScore()\n",
    "generator = MultinomialSyntheticDataGenerator(random_state=42)\n",
    "generator.fit(X_seed, y_seed)\n",
    "X_cal, y_cal = generator.generate(n=100)\n",
    "# create mapie classifier for conformity scores\n",
    "mapie_clf = MapieClassifier(estimator=generator, cv=\"prefit\", conformity_score=conformity_score)\n",
    "mapie_clf.fit(X_cal, y_cal)\n",
    "oracle_annotator = OracleAnnotator(mapie_clf, generator)\n",
    "\n",
    "# generate all possible pairs for a couple of instances\n",
    "n_instances = 30\n",
    "n_classes = len(generator.classes_)\n",
    "n_obs = n_instances * n_classes\n",
    "X_train = generator.generate_instances(n_instances).repeat(n_classes, axis=0)\n",
    "y_train = np.tile(generator.classes_, n_instances)\n",
    "conformities = oracle_annotator.get_conformity(X_train,y_train)\n",
    "sort_idx = (-conformities).argsort(axis=0).flatten()\n",
    "\n",
    "X_sorted = X_train[sort_idx]\n",
    "y_sorted = y_train[sort_idx]\n",
    "conformities_sorted = conformities[sort_idx]\n",
    "\n",
    "# All index pairs (i, j) with i < j, in the same order as a nested loop would produce\n",
    "first, second = np.triu_indices(n_obs, k=1)\n",
    "X_pairs = np.stack([X_sorted[first], X_sorted[second]], axis=1)\n",
    "y_pairs = np.stack([y_sorted[first], y_sorted[second]], axis=1)[..., np.newaxis]\n",
    "\n",
    "ds = LabelPairDataset()\n",
    "ds.create_from_numpy_pairs(X_pairs, y_pairs)\n",
    "# Preview a few pairs only; printing every pair floods the cell output\n",
    "for pair_idx, (x_pair, y_pair) in enumerate(ds):\n",
    "    if pair_idx >= 3:\n",
    "        break\n",
    "    print(x_pair, y_pair)\n",
    "print(len(ds))\n",
    "model = LabelRankingModel(input_dim=generator.n_features_, hidden_dims=3*[generator.n_features_], activations=[torch.nn.Sigmoid(),torch.nn.Sigmoid(),torch.nn.Sigmoid()], output_dim=len(generator.classes_))\n",
    "\n",
    "pair_loader = DataLoader(ds, batch_size=32)\n",
    "val_loader = DataLoader(ds, batch_size=32)\n",
    "model.num_classes = generator.n_classes_\n",
    "model._fit(pair_loader, val_loader=val_loader, num_epochs=250, patience=250, learning_rate=0.01, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate on fresh instances, each paired with every class label\n",
    "n_test_instances = 100\n",
    "X_test = generator.generate_instances(n_test_instances).repeat(len(generator.classes_), axis=0)\n",
    "y_test = np.tile(generator.classes_, n_test_instances)\n",
    "\n",
    "conformity_scores = oracle_annotator.get_conformity(X_test, y_test)\n",
    "skills_from_model = np.take_along_axis(model.predict_class_skills(X_test), y_test[:,np.newaxis], axis=1)\n",
    "tau_corr, p_value = kendalltau(skills_from_model, conformity_scores)\n",
    "# Show the result instead of silently discarding it\n",
    "tau_corr, p_value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "plt.rcParams[\"figure.figsize\"] = (7, 7)\n",
    "\n",
    "# Conformity score on x and learned skill on y; labelled because the axes are\n",
    "# swapped compared with the scatter plot further up in the notebook.\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(conformity_scores, skills_from_model)\n",
    "ax.set_xlabel(\"conformity score\")\n",
    "ax.set_ylabel(\"learned utility\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rank agreement between the oracle's conformity scores and the learned skills\n",
    "goodman_kruskal_gamma(x=conformity_scores, y=skills_from_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# generate all possible pairs for a couple of instances\n",
    "n_instances = 10\n",
    "n_classes = len(generator.classes_)\n",
    "n_obs = n_instances * n_classes\n",
    "X_train = generator.generate_instances(n_instances).repeat(n_classes, axis=0)\n",
    "y_train = np.tile(generator.classes_, n_instances)\n",
    "conformities = oracle_annotator.get_conformity(X_train,y_train)\n",
    "sort_idx = (-conformities).argsort(axis=0).flatten()\n",
    "\n",
    "X_sorted = X_train[sort_idx]\n",
    "y_sorted = y_train[sort_idx]\n",
    "conformities_sorted = conformities[sort_idx]\n",
    "\n",
    "# All index pairs (i, j) with i < j, in the same order as a nested loop would produce\n",
    "first, second = np.triu_indices(n_obs, k=1)\n",
    "X_pairs = np.stack([X_sorted[first], X_sorted[second]], axis=1)\n",
    "y_pairs = np.stack([y_sorted[first], y_sorted[second]], axis=1)[..., np.newaxis]\n",
    "ds = LabelPairDataset()\n",
    "ds.create_from_numpy_pairs(X_pairs, y_pairs)\n",
    "# Preview a few pairs only; printing every pair floods the cell output\n",
    "for pair_idx, (x_pair, y_pair) in enumerate(ds):\n",
    "    if pair_idx >= 3:\n",
    "        break\n",
    "    print(x_pair, y_pair)\n",
    "pair_loader = DataLoader(ds, batch_size=32)\n",
    "model = LabelRankingModel(input_dim=generator.n_features_, hidden_dims=3*[generator.n_features_], activations=[torch.nn.Sigmoid(), SortLayer(),torch.nn.Identity()], output_dim=len(generator.classes_))\n",
    "model.num_classes = generator.n_classes_\n",
    "print(len(pair_loader))\n",
    "# The training pairs double as validation data\n",
    "model._fit(pair_loader, val_loader=pair_loader, num_epochs=250, learning_rate=0.001, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mapie.classification import MapieClassifier\n",
    "# create mapie classifier for conformity scores\n",
    "mapie_clf = MapieClassifier(estimator=generator, cv=\"prefit\", conformity_score=TopKConformityScore())\n",
    "# Calibrate on samples drawn from the generator, as the other cells do; X_train/y_train\n",
    "# pair every instance with every label and are therefore not a sample of (X, y).\n",
    "mapie_clf.fit(X_cal, y_cal)\n",
    "oracle_annotator = OracleAnnotator(mapie_clf, generator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mapie.classification import MapieClassifier\n",
    "from mapie.conformity_scores.sets import APSConformityScore, LACConformityScore, NaiveConformityScore, TopKConformityScore\n",
    "from util.ranking_datasets import LabelPairDataset\n",
    "from models.ranking_models import LabelRankingModel\n",
    "from torch.utils.data.dataloader import DataLoader\n",
    "from sklearn.datasets import make_classification\n",
    "from scipy.stats import kendalltau\n",
    "import matplotlib.pyplot as plt\n",
    "from models.ranking_models import SortLayer\n",
    "import torch\n",
    "# make_classification rejects 3 classes with a single informative feature (it requires\n",
    "# n_classes * n_clusters_per_class <= 2**n_informative), so the one-dimensional seed\n",
    "# data is drawn from three Gaussians, one per class.\n",
    "seed_rng = np.random.RandomState(42)\n",
    "class_centers = np.array([1.0, 3.0, 5.0])\n",
    "y_seed = seed_rng.randint(len(class_centers), size=1000)\n",
    "X_seed = seed_rng.normal(loc=class_centers[y_seed], scale=0.5).reshape(-1, 1)\n",
    "conformity_score = APSConformityScore()\n",
    "generator = MultinomialSyntheticDataGenerator(random_state=42)\n",
    "generator.fit(X_seed, y_seed)\n",
    "X_cal, y_cal = generator.generate(n=100)\n",
    "# create mapie classifier for conformity scores\n",
    "mapie_clf = MapieClassifier(estimator=generator, cv=\"prefit\", conformity_score=conformity_score)\n",
    "mapie_clf.fit(X_cal, y_cal)\n",
    "oracle_annotator = OracleAnnotator(mapie_clf, generator)\n",
    "\n",
    "# generate all possible pairs for a couple of instances\n",
    "n_instances = 10\n",
    "n_classes = len(generator.classes_)\n",
    "n_obs = n_instances * n_classes\n",
    "X_train = generator.generate_instances(n_instances).repeat(n_classes, axis=0)\n",
    "y_train = np.tile(generator.classes_, n_instances)\n",
    "conformities = oracle_annotator.get_conformity(X_train,y_train)\n",
    "sort_idx = (-conformities).argsort(axis=0).flatten()\n",
    "\n",
    "X_sorted = X_train[sort_idx]\n",
    "y_sorted = y_train[sort_idx]\n",
    "conformities_sorted = conformities[sort_idx]\n",
    "\n",
    "# All index pairs (i, j) with i < j, in the same order as a nested loop would produce\n",
    "first, second = np.triu_indices(n_obs, k=1)\n",
    "X_pairs = np.stack([X_sorted[first], X_sorted[second]], axis=1)\n",
    "y_pairs = np.stack([y_sorted[first], y_sorted[second]], axis=1)[..., np.newaxis]\n",
    "\n",
    "ds = LabelPairDataset()\n",
    "ds.create_from_numpy_pairs(X_pairs, y_pairs)\n",
    "# Preview a few pairs only; printing every pair floods the cell output\n",
    "for pair_idx, (x_pair, y_pair) in enumerate(ds):\n",
    "    if pair_idx >= 3:\n",
    "        break\n",
    "    print(x_pair, y_pair)\n",
    "print(len(ds))\n",
    "model = LabelRankingModel(input_dim=generator.n_features_, hidden_dims=3*[generator.n_features_], activations=[torch.nn.Sigmoid(), SortLayer(),torch.nn.Identity()], output_dim=len(generator.classes_))\n",
    "\n",
    "pair_loader = DataLoader(ds, batch_size=32)\n",
    "model.num_classes = generator.n_classes_\n",
    "# The training pairs double as validation data\n",
    "model._fit(pair_loader, val_loader=pair_loader, num_epochs=250, learning_rate=0.001, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# make_classification requires n_classes * n_clusters_per_class <= 2**n_informative;\n",
    "# with 3 classes and the default of 2 clusters per class that needs n_informative >= 3.\n",
    "# The seed makes the cell reproducible.\n",
    "X_seed, y_seed = make_classification(100, n_informative=3, n_classes=3, random_state=42)\n",
    "generator = MultinomialSyntheticDataGenerator(random_state=42)\n",
    "generator.fit(X_seed, y_seed)\n",
    "X_cal, y_cal = generator.generate(n=100)\n",
    "\n",
    "num_instances = 30\n",
    "n_classes = len(generator.classes_)\n",
    "n_test_instances = 100\n",
    "\n",
    "# create mapie classifier for conformity scores\n",
    "mapie_clf = MapieClassifier(estimator=generator, cv=\"prefit\", conformity_score=APSConformityScore())\n",
    "mapie_clf.fit(X_cal, y_cal)\n",
    "oracle_annotator = OracleAnnotator(mapie_clf, generator)\n",
    "\n",
    "# Pair every generated instance with every class label and sort by descending score\n",
    "X_train = generator.generate_instances(num_instances).repeat(n_classes, axis=0)\n",
    "y_train = np.tile(generator.classes_, num_instances)\n",
    "conformities = oracle_annotator.get_conformity(X_train,y_train)\n",
    "sort_idx = (-conformities).argsort(axis=0).flatten()\n",
    "\n",
    "X_sorted = X_train[sort_idx]\n",
    "y_sorted = y_train[sort_idx]\n",
    "\n",
    "# All index pairs (i, j) with i < j, in the same order as a nested loop would produce\n",
    "first, second = np.triu_indices(len(X_sorted), k=1)\n",
    "X_pairs = np.stack([X_sorted[first], X_sorted[second]], axis=1)\n",
    "y_pairs = np.stack([y_sorted[first], y_sorted[second]], axis=1)[..., np.newaxis]\n",
    "\n",
    "ds = LabelPairDataset()\n",
    "ds.create_from_numpy_pairs(X_pairs, y_pairs)\n",
    "pair_loader = DataLoader(ds, batch_size=64)\n",
    "print(len(ds))\n",
    "model = LabelRankingModel(input_dim=generator.n_features_, hidden_dims=3*[generator.n_features_], activations=[torch.nn.Sigmoid(), torch.nn.Sigmoid(), torch.nn.Sigmoid()], output_dim=n_classes)\n",
    "model.num_classes = generator.n_classes_\n",
    "device = next(model.parameters()).device\n",
    "print(f\"Model is on: {device}\")\n",
    "# The training pairs double as validation data\n",
    "model._fit(pair_loader, val_loader=pair_loader, num_epochs=300, learning_rate=0.01, patience=100, verbose=True)\n",
    "\n",
    "# generate data from data generating process and check whether the learned non-conformity relation sorts them correctly\n",
    "X_test = generator.generate_instances(n_test_instances).repeat(n_classes, axis=0)\n",
    "y_test = np.tile(generator.classes_, n_test_instances)\n",
    "skills_from_model = np.take_along_axis(model.predict_class_skills(X_test), y_test[:,np.newaxis], axis=1)\n",
    "conformity_scores = oracle_annotator.get_conformity(X_test, y_test)\n",
    "tau_corr, p_value = kendalltau(skills_from_model, conformity_scores)\n",
    "# Show the result instead of silently discarding it\n",
    "tau_corr, p_value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(42)\n",
    "class_1 = np.random.normal(loc=1, scale=0.5, size=100)\n",
    "class_2 = np.random.normal(loc=3, scale=0.5, size=100)\n",
    "class_3 = np.random.normal(loc=5, scale=0.5, size=100)\n",
    "\n",
    "# Combine into one dataset\n",
    "X_seed = np.concatenate([class_1, class_2]).reshape(-1,1)\n",
    "y_seed = np.concatenate([np.zeros(100), np.ones(100)])\n",
    "\n",
    "\n",
    "generator = MultinomialSyntheticDataGenerator(random_state=42)\n",
    "generator.fit(X_seed, y_seed)\n",
    "X_cal, y_cal = generator.generate(n=100)\n",
    "\n",
    "num_instances = 30\n",
    "\n",
    "mapie_clf = MapieClassifier(estimator=generator, cv=\"prefit\", conformity_score=APSConformityScore())\n",
    "# create mapie classifier for conformity scores\n",
    "mapie_clf.fit(X_cal, y_cal)\n",
    "# create \n",
    "oracle_annotator = OracleAnnotator(mapie_clf, generator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "X,y = generator.generate(20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.scatter(X.flatten(), y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "def _sort_sum(probs):\n",
    "        \"\"\"\n",
    "        Sort probabilities and calculate cumulative sum.\n",
    "\n",
    "        Args:\n",
    "            probs (torch.Tensor): The prediction probabilities.\n",
    "\n",
    "        Returns:\n",
    "            tuple: A tuple containing:\n",
    "                - indices (torch.Tensor): The rank of ordered probabilities in descending order.\n",
    "                - ordered (torch.Tensor): The ordered probabilities in descending order.\n",
    "                - cumsum (torch.Tensor): The accumulation of sorted probabilities.\n",
    "        \"\"\"\n",
    "        ordered, indices = torch.sort(probs, dim=-1, descending=True)\n",
    "        cumsum = torch.cumsum(ordered, dim=-1)\n",
    "        return indices, ordered, cumsum\n",
    "\n",
    "def _calculate_single_label(probs, label):\n",
    "    \"\"\"\n",
    "    Calculate non-conformity score for a single label.\n",
    "\n",
    "    Args:\n",
    "        probs (torch.Tensor): The prediction probabilities.\n",
    "        label (torch.Tensor): The ground truth label.\n",
    "\n",
    "    Returns:\n",
    "        torch.Tensor: The non-conformity score for the given label.\n",
    "    \"\"\"\n",
    "    indices, ordered, cumsum = _sort_sum(probs)\n",
    "    U = torch.zeros(indices.shape[0], device=probs.device)\n",
    "\n",
    "    idx = torch.where(indices == label.view(-1, 1))\n",
    "    scores = cumsum[idx] - U * ordered[idx]\n",
    "    return scores\n",
    "\n",
    "def get_conformity(clf, X, y):\n",
    "    y_pred_proba = clf.predict_proba(X)\n",
    "    \n",
    "    scores = _calculate_single_label(torch.tensor(y_pred_proba),torch.tensor(y))\n",
    "    return scores\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "sns.set(font_scale=1.5,rc={'text.usetex' : True})\n",
    "sns.set_style(\"whitegrid\")\n",
    "plt.rc('font', **{'family': 'serif'})\n",
    "plt.rcParams[\"figure.figsize\"] = (7, 2)\n",
    "\n",
    "# Step 1: Define class parameters\n",
    "class_params = {\n",
    "    0: {\"mean\": 2, \"std\": 1, \"prior\": 0.3},\n",
    "    1: {\"mean\": 5, \"std\": 0.8, \"prior\": 0.4},\n",
    "    2: {\"mean\": 8, \"std\": 1.2, \"prior\": 0.3},\n",
    "}\n",
    "\n",
    "# Step 2: Create an instance of the classifier\n",
    "clf = GaussianSyntheticClassifier(class_params=class_params)\n",
    "\n",
    "# Step 3: Generate synthetic data\n",
    "X, y = clf.generate_data(n_samples=1000)\n",
    "\n",
    "# Step 4: Fit the classifier (it's effectively a no-op, but we call it for sklearn compatibility)\n",
    "clf.fit(X, y)\n",
    "\n",
    "# Step 5: Predict probabilities for a range of x values (for plotting)\n",
    "x_range = np.linspace(X.min() - 1, X.max() + 1, 1000)\n",
    "probs = clf.predict_proba(x_range)\n",
    "\n",
    "# Step 6: Plot the data and probabilities\n",
    "plt.figure(figsize=(10, 6))\n",
    "\n",
    "# Scatter plot of the synthetic data\n",
    "# plt.scatter(X, y + np.random.uniform(-0.1, 0.1, len(y)), c=y, cmap='viridis', alpha=0.6, edgecolor='k', label='Data')\n",
    "\n",
    "# Plot the probabilities for each class\n",
    "for i, c in enumerate(clf.classes_):\n",
    "    plt.plot(x_range, probs[:, i], label=rf'$P(y={c} \\mid x)$', linewidth=2)\n",
    "\n",
    "X_nc = np.linspace(-2,12,300).reshape(-1,1)\n",
    "y_0  = np.full((300),0)\n",
    "y_1  = np.full((300),1)\n",
    "y_2  = np.full((300),2)\n",
    "y_0_nc = get_conformity(clf,X_nc,y_0).detach().cpu().numpy()\n",
    "y_1_nc = get_conformity(clf,X_nc,y_1).detach().cpu().numpy()\n",
    "y_2_nc = get_conformity(clf,X_nc,y_2).detach().cpu().numpy()\n",
    "\n",
    "\n",
    "plt.plot(X_nc.flatten(), y_0_nc, linestyle=\"--\", label=\"APS for y=0\")\n",
    "plt.plot(X_nc.flatten(), y_1_nc, linestyle=\"--\", label=\"APS for y=1\")\n",
    "plt.plot(X_nc.flatten(), y_2_nc, linestyle=\"--\", label=\"APS for y=2\")\n",
    "\n",
    "\n",
    "# Labels and title\n",
    "plt.title(\"Conformity scores for synthetic data\")\n",
    "plt.xlabel(\"Feature Value (x)\")\n",
    "plt.ylabel(\"Class / Probability\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cp_rank",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
