{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0684aa57-3064-456a-aa85-342ae65ea669",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import random\n",
    "import warnings\n",
    "import time\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "from sklearn.cluster import KMeans\n",
    "from sklearn.mixture import BayesianGaussianMixture\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.datasets import make_circles, load_iris\n",
    "from sklearn import preprocessing\n",
    "dir_ = os.path.abspath(os.path.join(os.getcwd(), \"..\", \"..\"))\n",
    "os.chdir(dir_)\n",
    "# Local modules\n",
    "import utils\n",
    "import prior\n",
    "import transformer\n",
    "import main\n",
    "from sklearn.metrics import rand_score, cluster, adjusted_mutual_info_score\n",
    "from sklearn.metrics.cluster import adjusted_rand_score\n",
    "# Settings\n",
    "matplotlib.use(\"TkAgg\")\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "%matplotlib inline\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c518bb13-c2a5-4089-a61f-5c3515df30a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "class Normalize(nn.Module):\n",
    "    \"\"\"Standardise inputs with fixed statistics: ``(x - mean) / std``.\n",
    "\n",
    "    ``mean`` and ``std`` are kept as plain attributes (they are not registered\n",
    "    as buffers); they must broadcast against the input.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, mean, std):\n",
    "        super().__init__()\n",
    "        self.mean, self.std = mean, std\n",
    "\n",
    "    def forward(self, x):\n",
    "        centred = x - self.mean\n",
    "        return centred / self.std\n",
    "\n",
    "\n",
    "class Transformer(nn.Transformer):\n",
    "    \"\"\"Cluster-PFN network: the encoder of ``nn.Transformer`` with custom read-in/out.\n",
    "\n",
    "    Every point of a sequence is projected with ``linear_x``; one extra 'cluster\n",
    "    token' is appended whose output is read as a prediction of the number of\n",
    "    clusters. ``embedding(num_clusters)`` is added to every token, conditioning\n",
    "    the per-point predictions on a cluster-count index (callers in this notebook\n",
    "    pass 0 before a count has been predicted, otherwise predicted argmax + 1).\n",
    "\n",
    "    ``self.decoder`` intentionally replaces the decoder stack created by\n",
    "    ``nn.Transformer`` with a linear read-out. Attribute names must stay as they\n",
    "    are so that saved ``state_dict`` checkpoints keep loading.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, d_model, nhead, nhid, nlayers, in_features=1, buckets_size=100):\n",
    "        super(Transformer, self).__init__(d_model=d_model, nhead=nhead, dim_feedforward=nhid, dropout=0,\n",
    "                                          num_encoder_layers=nlayers)\n",
    "        self.model_type = 'Transformer'\n",
    "        self.src_mask = None\n",
    "        self.d_model = d_model\n",
    "        self.nhead = nhead\n",
    "        self.nhid = nhid\n",
    "        self.nlayers = nlayers\n",
    "        self.embed_size = d_model\n",
    "        self.linear_x = nn.Linear(in_features, d_model)\n",
    "        # per-token logits over buckets_size outputs (cluster ids / cluster counts)\n",
    "        self.decoder = nn.Linear(d_model, buckets_size)\n",
    "        self.linear_num_clusters = nn.Linear(in_features, d_model)\n",
    "        # buckets_size + 1 rows: count indices 1..buckets_size plus index 0\n",
    "        self.embedding = nn.Embedding(buckets_size + 1, self.embed_size)\n",
    "\n",
    "    def _generate_mask(self, size, device=None):\n",
    "        \"\"\"Additive attention mask for ``size`` tokens, the last being the cluster token.\n",
    "\n",
    "        Data tokens attend only to data tokens; the cluster token attends to all\n",
    "        tokens. Allowed positions are 0, blocked positions ``-inf``. ``device``\n",
    "        (default: CPU) should match the encoder input.\n",
    "        \"\"\"\n",
    "        mask = torch.full((size, size), float('-inf'), device=device)\n",
    "        mask[:size - 1, :size - 1] = 0.0\n",
    "        mask[size - 1] = 0.0\n",
    "        return mask\n",
    "\n",
    "    def forward(self, X, num_clusters):\n",
    "        \"\"\"Run the model on a batch of point sets.\n",
    "\n",
    "        Args:\n",
    "            X: float tensor (S, B, in_features) of points.\n",
    "            num_clusters: long tensor of cluster-count indices whose embedding\n",
    "                broadcasts against (S + 1, B, d_model), e.g. shape (1, B).\n",
    "\n",
    "        Returns:\n",
    "            (per-point logits (S, B, buckets_size),\n",
    "             cluster-token logits (1, B, buckets_size)).\n",
    "        \"\"\"\n",
    "        train = self.linear_x(X)  # (S, B, d_model)\n",
    "        # Learned cluster token built from a constant -1 input. It is created on\n",
    "        # X's device/dtype so the model also runs off-CPU.\n",
    "        cluster_input = torch.full((1, X.shape[1], X.shape[2]), -1, dtype=X.dtype, device=X.device)\n",
    "        cluster_embedding = self.linear_num_clusters(cluster_input)\n",
    "        train = torch.cat((train, cluster_embedding), dim=0)  # (S + 1, B, d_model)\n",
    "        train = train + self.embedding(num_clusters)  # broadcast over all tokens\n",
    "        src_mask = self._generate_mask(train.shape[0], device=train.device)\n",
    "        output = self.decoder(self.encoder(train, mask=src_mask))  # (S + 1, B, buckets_size)\n",
    "        return output[:-1], output[-1:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98fa7f5a-95ed-4e4e-907d-c4b187411a1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn import preprocessing\n",
    "import torch\n",
    "from scipy.stats import invwishart\n",
    "import random\n",
    "\n",
    "def sort(x, y, X_true, centers):\n",
    "    \"\"\"Relabel clusters by distance from the origin, then shuffle the rows.\n",
    "\n",
    "    Rows are ordered by the Euclidean norm of ``x``; labels are renumbered\n",
    "    0, 1, ... in order of first appearance in that ordering. Only the first\n",
    "    ``centers`` distinct labels are renumbered; any further label raises\n",
    "    ``KeyError``. All three arrays are then permuted with the same\n",
    "    ``np.random.permutation`` so they stay row-aligned.\n",
    "\n",
    "    Returns:\n",
    "        (x, relabelled y, X_true), shuffled consistently.\n",
    "    \"\"\"\n",
    "    order = np.argsort(np.linalg.norm(x, axis=1))\n",
    "    x_by_norm, y_by_norm, true_by_norm = x[order], y[order], X_true[order]\n",
    "\n",
    "    relabel = {}\n",
    "    for label in y_by_norm:\n",
    "        if len(relabel) == centers:\n",
    "            break\n",
    "        if label not in relabel:\n",
    "            relabel[label] = len(relabel)\n",
    "\n",
    "    y_relabelled = np.array([relabel[label] for label in y_by_norm])\n",
    "    shuffle = np.random.permutation(len(x_by_norm))\n",
    "    return x_by_norm[shuffle], y_relabelled[shuffle], true_by_norm[shuffle]\n",
    "\n",
    "\n",
    "def generate_bayesian_gmm_data(\n",
    "        batch_size=32,\n",
    "        start_seq_len = 100,\n",
    "        seq_len=500,\n",
    "        num_features=2,\n",
    "        min_classes=1,\n",
    "        num_classes=10,\n",
    "        weight_concentration_prior=1.0,  # Dirichlet prior for mixture weights\n",
    "        mean_prior=0.0,  # Mean of prior over cluster means\n",
    "        mean_precision_prior=0.01,  # Precision (confidence) over means\n",
    "        degrees_of_freedom_prior=None,  # Degrees of freedom for the inverse-Wishart prior\n",
    "        covariance_prior=None,  # Scale matrix for the inverse-Wishart prior\n",
    "        seed=None,\n",
    "        nan_frac = None,\n",
    "        **kwargs\n",
    "):\n",
    "    \"\"\"Sample a batch of synthetic datasets from a Bayesian Gaussian-mixture prior.\n",
    "\n",
    "    One sequence length is drawn for the whole batch with\n",
    "    ``random.randint(start_seq_len, seq_len)`` (inclusive). For every batch element:\n",
    "\n",
    "    * ``n_components`` is drawn uniformly from ``[min_classes, num_classes]``;\n",
    "    * mixture weights ~ Dirichlet(``weight_concentration_prior``) and per-component\n",
    "      point counts ~ Multinomial(seq_len, weights) (a component may get 0 points);\n",
    "    * per component, ``Sigma ~ InvWishart(degrees_of_freedom_prior, covariance_prior)``\n",
    "      and ``mu ~ N(mean_prior, Sigma / mean_precision_prior)``;\n",
    "    * points are min-max scaled per feature to [0, 1];\n",
    "    * if ``nan_frac`` is truthy, up to a random fraction (uniform in [0, nan_frac))\n",
    "      of the scaled entries is set to ``-2`` as a missing-value marker, never a\n",
    "      whole row;\n",
    "    * ``sort`` relabels clusters by distance from the origin and shuffles the rows.\n",
    "\n",
    "    ``seed`` (if given) seeds numpy, ``random`` and torch. The exact order of the\n",
    "    random calls below determines the data produced for a seed. ``**kwargs`` is\n",
    "    ignored.\n",
    "\n",
    "    Returns:\n",
    "        clusters_x: float32 tensor (seq_len, batch_size, num_features) with the\n",
    "            scaled (and possibly masked) features.\n",
    "        clusters_y: float32 tensor (seq_len, batch_size) of relabelled cluster ids.\n",
    "        clusters_x_true: float32 tensor (seq_len, batch_size, num_features) with the\n",
    "            unscaled, unmasked features, row-aligned with ``clusters_x``.\n",
    "        batch_classes: integer tensor (1, batch_size) with each element's sampled\n",
    "            ``n_components``.\n",
    "    \"\"\"\n",
    "    if seed is not None:\n",
    "        np.random.seed(seed)\n",
    "        random.seed(seed)\n",
    "        torch.manual_seed(seed)\n",
    "\n",
    "    # Feature-count sampling is disabled; with features == num_features the np.pad below is a no-op.\n",
    "    #features = random.randint(2, num_features)\n",
    "    features = num_features\n",
    "    if covariance_prior is None:\n",
    "        covariance_prior = np.eye(features)\n",
    "\n",
    "    if degrees_of_freedom_prior is None:\n",
    "        degrees_of_freedom_prior = features\n",
    "\n",
    "    # One length for the whole batch; note this rebinds the seq_len argument.\n",
    "    seq_len = random.randint(start_seq_len, seq_len)\n",
    "    clusters_x = np.zeros((batch_size, seq_len, num_features))\n",
    "    clusters_y = np.zeros((batch_size, seq_len))\n",
    "    clusters_x_true = np.zeros((batch_size, seq_len, num_features))\n",
    "    batch_classes = []\n",
    "    for i in range(batch_size):\n",
    "        n_components = np.random.randint(min_classes, num_classes + 1)\n",
    "\n",
    "        # Sample weights from Dirichlet prior\n",
    "        weights = np.random.dirichlet(np.ones(n_components) * weight_concentration_prior)\n",
    "\n",
    "        # Number of points per component ~ Multinomial(seq_len, weights); may be 0\n",
    "        counts = np.random.multinomial(seq_len, weights)\n",
    "\n",
    "        # Sample cluster parameters: Sigma ~ InvWishart, mu ~ N(mean_prior, Sigma / mean_precision_prior)\n",
    "        means = []\n",
    "        covariances = []\n",
    "        for _ in range(n_components):\n",
    "            Sigma = invwishart.rvs(df=degrees_of_freedom_prior, scale=covariance_prior)\n",
    "            mu = np.random.multivariate_normal(np.full(features, mean_prior), Sigma / mean_precision_prior)\n",
    "            means.append(mu)\n",
    "            covariances.append(Sigma)\n",
    "        # Sample points from GMM (rows grouped by component; shuffled later by sort)\n",
    "        X = []\n",
    "        y = []\n",
    "        for k in range(n_components):\n",
    "            n_k = counts[k]\n",
    "            X_k = np.random.multivariate_normal(means[k], covariances[k], size=n_k)\n",
    "            X.append(X_k)\n",
    "            y.append(np.full(n_k, k))\n",
    "\n",
    "        X = np.vstack(X)\n",
    "        y = np.concatenate(y)\n",
    "        X = np.pad(X, ((0, 0), (0, num_features - features)), mode='constant')\n",
    "\n",
    "        # Scale each feature to [0, 1]; X itself stays unscaled and becomes X_true.\n",
    "        x = preprocessing.MinMaxScaler().fit_transform(X)\n",
    "        n_samples, n_features = seq_len, features\n",
    "\n",
    "        if nan_frac:\n",
    "            fraction_missing = np.random.uniform(0, nan_frac)\n",
    "            n_elements = n_samples * n_features\n",
    "            n_missing = int(fraction_missing * n_elements)\n",
    "\n",
    "            if n_missing > 0:\n",
    "                all_indices = np.arange(n_elements)\n",
    "                np.random.shuffle(all_indices)\n",
    "                chosen = all_indices[:n_missing]\n",
    "\n",
    "                rows, cols = np.unravel_index(chosen, (n_samples, n_features))\n",
    "\n",
    "                # Ensure no row is fully missing: for every row whose entries were all\n",
    "                # chosen, un-choose one randomly picked entry of that row.\n",
    "                full_rows = [r for r in np.unique(rows) if np.sum(rows == r) == n_features]\n",
    "                to_keep = []\n",
    "                for r in full_rows:\n",
    "                    row_mask = rows == r\n",
    "                    row_positions_in_chosen = np.where(row_mask)[0]\n",
    "                    keep_pos = np.random.choice(row_positions_in_chosen)\n",
    "                    to_keep.append(keep_pos)\n",
    "                # Remove all \"to_keep\" indices from chosen at once\n",
    "                chosen = np.delete(chosen, to_keep)\n",
    "                rows, cols = np.unravel_index(chosen, (n_samples, n_features))\n",
    "                # -2 marks a missing entry (outside the [0, 1] min-max range)\n",
    "                x[rows, cols] = -2\n",
    "\n",
    "        # Relabel by distance from the origin of the scaled x and shuffle; X_true stays row-aligned.\n",
    "        x, y, X_true = sort(x, y, X, n_components)\n",
    "        clusters_x[i] = x\n",
    "        clusters_y[i] = y\n",
    "        clusters_x_true[i] = X_true\n",
    "        batch_classes.append(n_components)\n",
    "\n",
    "    # (batch, seq, feat) -> (seq, batch, feat) sequence-first layout\n",
    "    clusters_x_true = torch.tensor(clusters_x_true, dtype=torch.float32).permute(1, 0, 2)\n",
    "    clusters_x = torch.tensor(clusters_x, dtype=torch.float32).permute(1, 0, 2)\n",
    "    clusters_y = torch.tensor(clusters_y, dtype=torch.float32).permute(1, 0)\n",
    "    batch_classes = torch.tensor(batch_classes).unsqueeze(0)\n",
    "\n",
    "    return clusters_x, clusters_y, clusters_x_true, batch_classes\n",
    "\n",
    "\n",
    "def get_clusters_bayesian_gmm(X, batch_classes=None, n_components=10, weight_concentration_prior=1.0,\n",
    "                              mean_prior=None, mean_precision_prior=None, degrees_of_freedom_prior=None,\n",
    "                              covariance_prior=None, random_state=42, n_init=None, max_iter=5000):\n",
    "    \"\"\"Pick the number of clusters per dataset with variational Bayesian GMMs.\n",
    "\n",
    "    For every batch element a ``BayesianGaussianMixture`` is fitted for each K in\n",
    "    1..n_components. The K maximising ``lower_bound_ - log(K!)`` is kept\n",
    "    (presumably the log(K!) term offsets the K! equivalent relabellings of a\n",
    "    K-component solution).\n",
    "\n",
    "    Args:\n",
    "        X: array/tensor (S, B, F). Columns that are all zero within a batch\n",
    "            element (e.g. zero padding) are dropped before fitting.\n",
    "        batch_classes: unused, kept for API compatibility.\n",
    "        mean_prior, covariance_prior, degrees_of_freedom_prior: forwarded to\n",
    "            sklearn. When None they default to zeros(F'), eye(F') and F', where\n",
    "            F' is the number of kept columns.\n",
    "        mean_precision_prior, weight_concentration_prior: forwarded to sklearn.\n",
    "        random_state: seed forwarded to every fit.\n",
    "        n_init: initialisations per fit. None means 1, since recent scikit-learn\n",
    "            versions reject None.\n",
    "        max_iter: maximum iterations per fit.\n",
    "\n",
    "    Returns:\n",
    "        (clusters, time_arr, labels, probs): per batch element, the selected K,\n",
    "        the total fit time in seconds over all K, and the hard labels and\n",
    "        responsibilities of the selected model.\n",
    "    \"\"\"\n",
    "    if n_init is None:\n",
    "        n_init = 1\n",
    "    clusters, time_arr, labels, probs = [], [], [], []\n",
    "    for batch in range(X.shape[1]):\n",
    "        X_curr = X[:, batch, :]\n",
    "        X_curr = X_curr[:, (X_curr != 0).any(axis=0)]  # drop all-zero columns\n",
    "        features = X_curr.shape[1]\n",
    "        curr_mean_prior = np.zeros(features) if mean_prior is None else mean_prior\n",
    "        curr_covariance_prior = np.eye(features) if covariance_prior is None else covariance_prior\n",
    "        curr_dof_prior = features if degrees_of_freedom_prior is None else degrees_of_freedom_prior\n",
    "\n",
    "        time_tot = 0.0\n",
    "        best_score = float('-inf')\n",
    "        best_cluster, best_label, best_prob = -1, [], []\n",
    "        for component in range(1, n_components + 1):\n",
    "            model = BayesianGaussianMixture(n_components=component,\n",
    "                                            weight_concentration_prior_type='dirichlet_distribution',\n",
    "                                            weight_concentration_prior=weight_concentration_prior,\n",
    "                                            mean_prior=curr_mean_prior,\n",
    "                                            mean_precision_prior=mean_precision_prior,\n",
    "                                            n_init=n_init,\n",
    "                                            degrees_of_freedom_prior=curr_dof_prior,\n",
    "                                            max_iter=max_iter,\n",
    "                                            covariance_prior=curr_covariance_prior,\n",
    "                                            random_state=random_state)\n",
    "            start_time = time.perf_counter()\n",
    "            model.fit(X_curr)\n",
    "            time_tot += time.perf_counter() - start_time\n",
    "            score = model.lower_bound_ - np.log(math.factorial(component))\n",
    "            if score > best_score:\n",
    "                best_score = score\n",
    "                best_cluster = component\n",
    "                best_label = model.predict(X_curr)\n",
    "                best_prob = model.predict_proba(X_curr)\n",
    "        time_arr.append(time_tot)\n",
    "        clusters.append(best_cluster)\n",
    "        labels.append(best_label)\n",
    "        probs.append(best_prob)\n",
    "    return clusters, time_arr, labels, probs\n",
    "\n",
    "def get_nll_transformer(model, X, true_labels, batch_classes, n_components=5, condition=False):\n",
    "    \"\"\"Per-dataset NLL of the true labels under the model, minimised over label matchings.\n",
    "\n",
    "    Two-stage protocol per batch element:\n",
    "    * The model is first run with cluster-count index 0 to predict the count.\n",
    "    * It is then re-run conditioned on that prediction (argmax + 1).\n",
    "    The NLL uses the log-softmax over all outputs, restricted to the first\n",
    "    ``n_components`` columns. It is minimised over injective matchings of the\n",
    "    true labels to those columns.\n",
    "\n",
    "    This is an assignment problem. With\n",
    "    ``cost[c, j] = -sum_{i: y_i = c} log p(j | x_i)``, the mean NLL of a matching\n",
    "    is ``sum_c cost[c, match(c)] / N``. So ``scipy.optimize.linear_sum_assignment``\n",
    "    finds the exact optimum without enumerating every permutation.\n",
    "\n",
    "    Args:\n",
    "        X: tensor (S, B, F); true_labels: tensor (S, B) of integer-valued labels.\n",
    "        batch_classes: tensor (B,); only the shapes of its elements are used.\n",
    "        condition: unused, kept for API compatibility.\n",
    "\n",
    "    Returns:\n",
    "        list of B floats. An entry is ``inf`` when there are more distinct true\n",
    "        labels than usable output columns, or when no matching has a finite cost.\n",
    "    \"\"\"\n",
    "    from scipy.optimize import linear_sum_assignment\n",
    "\n",
    "    losses = []\n",
    "    with torch.no_grad():  # evaluation only: no autograd graph needed\n",
    "        for batch_index in range(true_labels.shape[1]):\n",
    "            train_x = X[:, batch_index].unsqueeze(1)\n",
    "            train_y = true_labels[:, batch_index].long()\n",
    "            batch = batch_classes[batch_index].unsqueeze(-1)\n",
    "\n",
    "            # Stage 1: count index 0 -> predicted number of clusters.\n",
    "            _, cluster_output = model(train_x, torch.full(batch.shape, 0, dtype=torch.long))\n",
    "            predicted_index = cluster_output.view(-1, cluster_output.shape[2]).argmax(dim=1).long().item()\n",
    "            # Stage 2: condition on the predicted count.\n",
    "            logits, _ = model(train_x, torch.full(batch.shape, predicted_index + 1, dtype=torch.long))\n",
    "            log_probs = F.log_softmax(logits.squeeze(1), dim=-1)[:, :n_components]\n",
    "\n",
    "            unique_labels, inverse = torch.unique(train_y, return_inverse=True)\n",
    "            num_true_classes = len(unique_labels)\n",
    "            if num_true_classes > log_probs.shape[1]:\n",
    "                losses.append(float('inf'))\n",
    "                continue\n",
    "            log_p = log_probs.double().cpu().numpy()\n",
    "            inverse = inverse.cpu().numpy()\n",
    "            cost = np.zeros((num_true_classes, log_p.shape[1]))\n",
    "            for c in range(num_true_classes):\n",
    "                cost[c] = -log_p[inverse == c].sum(axis=0)\n",
    "            try:\n",
    "                rows, cols = linear_sum_assignment(cost)\n",
    "                losses.append(float(cost[rows, cols].sum() / len(train_y)))\n",
    "            except ValueError:  # every matching hits a -inf log-probability\n",
    "                losses.append(float('inf'))\n",
    "    return losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb48a123-4c2a-46b1-a10c-5e7ce1a95be5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration; the model sizes must match the checkpoint at model_path.\n",
    "d_model, nhead, nhid, nlayers = 256, 4, 512, 4\n",
    "in_features = 2\n",
    "num_features = in_features\n",
    "num_outputs = 10  # model output buckets (cluster ids / counts)\n",
    "num_classes = 8  # max number of mixture components per generated dataset\n",
    "batch_size = 1\n",
    "mean_precision_prior = 0.1\n",
    "weight_concentration_prior = 1.0\n",
    "start_seq_len = 100\n",
    "seq_len = 500\n",
    "iterations = 20\n",
    "start_seed = 0  # NOTE(review): not used in this notebook\n",
    "nr = 1  # NOTE(review): not used in this notebook\n",
    "n_init = 1\n",
    "seed = 0\n",
    "model_path = \"models/models_original/pfn_hard_2D.pt\"\n",
    "\n",
    "model = Transformer(d_model, nhead, nhid, nlayers, in_features=in_features,\n",
    "                    buckets_size=num_outputs)\n",
    "checkpoint = torch.load(model_path, weights_only=True, map_location=torch.device('cpu'))\n",
    "model.load_state_dict(checkpoint['model_state_dict'])\n",
    "model.eval();  # trailing ';' suppresses the module repr in the cell output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68d6210a-e0a5-4c4e-99a4-583736d71e5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import itertools \n",
    "\n",
    "def compute_correct_clusters(batch_size = 100, iterations = 1000, start_seq_len = 100,\n",
    "                             seq_len=500, num_features=2, min_classes=1, num_classes=10, weight_concentration_prior=1,\n",
    "                             mean_precision_prior = None, n_init = None, seed = 0,\n",
    "                             condition_on_true_clusters=False):\n",
    "    \"\"\"Benchmark Cluster-PFN (the global ``model``) against Bayesian GMM VI.\n",
    "\n",
    "    Each iteration draws a batch with ``generate_bayesian_gmm_data`` (seed\n",
    "    ``seed + iteration``) and scores every generated dataset:\n",
    "\n",
    "    * Cluster-PFN on the scaled features: cluster-count accuracy, the metrics of\n",
    "      ``utils.compute_external_metrics`` and the NLL of ``get_nll_transformer``;\n",
    "    * Bayesian GMM VI (``get_clusters_bayesian_gmm``) on the unscaled features:\n",
    "      count accuracy, ARI, AMI, purity and the label-matching-minimised NLL\n",
    "      (``inf`` when fewer components than true clusters are selected).\n",
    "\n",
    "    Args:\n",
    "        condition_on_true_clusters: if True, condition the PFN on the true number\n",
    "            of clusters (oracle setting). Its count head then receives the\n",
    "            answer as input. By default the PFN first predicts the count with\n",
    "            conditioning index 0 and is then conditioned on its own prediction,\n",
    "            the same protocol as ``get_nll_transformer``.\n",
    "\n",
    "    Returns:\n",
    "        ``{'model': {...}, 'gmm': {...}}``. Each inner dict maps 'clusters',\n",
    "        'rand_index', 'ami', 'purity' and 'nll' to a 1-D array with one entry\n",
    "        per generated dataset.\n",
    "    \"\"\"\n",
    "    from scipy.optimize import linear_sum_assignment\n",
    "\n",
    "    clusters_model, r_model, a_model, p_model, nll_transformer = [], [], [], [], []\n",
    "    clusters_gmm, ari_gmm, ami_gmm, purity_gmm, nll_gmm = [], [], [], [], []\n",
    "    for iteration in range(iterations):\n",
    "        train_X, train_Y, X_true, batch_classes = generate_bayesian_gmm_data(\n",
    "            batch_size=batch_size, start_seq_len=start_seq_len, seq_len=seq_len,\n",
    "            num_features=num_features, min_classes=min_classes, num_classes=num_classes,\n",
    "            mean_precision_prior=mean_precision_prior,\n",
    "            weight_concentration_prior=weight_concentration_prior,\n",
    "            seed=seed + iteration)\n",
    "        train_Y = train_Y.cpu()\n",
    "\n",
    "        with torch.no_grad():\n",
    "            if condition_on_true_clusters:\n",
    "                logits, cluster_output = model(train_X, batch_classes)\n",
    "            else:\n",
    "                # Stage 1: index 0 -> predicted count; stage 2: condition on it.\n",
    "                _, cluster_output = model(train_X, torch.full(batch_classes.shape, 0, dtype=torch.long))\n",
    "                logits, _ = model(train_X, torch.argmax(cluster_output, dim=-1) + 1)\n",
    "        predictions = torch.argmax(logits, dim=-1).cpu().numpy()\n",
    "        predicted_counts = (torch.argmax(cluster_output, dim=-1) + 1).squeeze(0).tolist()\n",
    "\n",
    "        # Model metrics.\n",
    "        # NOTE(review): the GMM 'rand_index' entries are sklearn's *adjusted* Rand\n",
    "        # index; confirm utils' 'rand_index' is adjusted too before comparing them.\n",
    "        r_model.extend(utils.compute_external_metrics(train_Y, predictions, 'rand_index'))\n",
    "        a_model.extend(utils.compute_external_metrics(train_Y, predictions, 'ami'))\n",
    "        p_model.extend(utils.compute_external_metrics(train_Y, predictions, 'purity'))\n",
    "\n",
    "        batch_classes = batch_classes.squeeze(0).detach()\n",
    "        bayes_cluster_prediction, _, labels, probabilities = get_clusters_bayesian_gmm(\n",
    "            X_true, batch_classes, mean_precision_prior=mean_precision_prior,\n",
    "            weight_concentration_prior=weight_concentration_prior, n_init=n_init)\n",
    "        nll_transformer.extend(get_nll_transformer(model, train_X, train_Y, batch_classes=batch_classes,\n",
    "                                                   n_components=num_classes, condition=True))\n",
    "\n",
    "        for i, gmm_k in enumerate(bayes_cluster_prediction):\n",
    "            true_label = train_Y[:, i].to(torch.int).tolist()\n",
    "            true_k = int(batch_classes[i])\n",
    "            label_pred = labels[i]\n",
    "            clusters_gmm.append(gmm_k == true_k)\n",
    "            clusters_model.append(predicted_counts[i] == true_k)\n",
    "            ari_gmm.append(adjusted_rand_score(true_label, label_pred))\n",
    "            ami_gmm.append(adjusted_mutual_info_score(true_label, label_pred))\n",
    "            matrix = cluster.contingency_matrix(true_label, label_pred)\n",
    "            purity_gmm.append(np.sum(np.amax(matrix, axis=0)) / np.sum(matrix))\n",
    "            if gmm_k < true_k:\n",
    "                nll_gmm.append(float('inf'))\n",
    "                continue\n",
    "            # Best injective matching true label -> GMM component as an assignment\n",
    "            # problem: cost[c, j] = -sum of log p_j over points with true label c.\n",
    "            with np.errstate(divide='ignore'):\n",
    "                log_p = np.log(np.asarray(probabilities[i], dtype=np.float64))\n",
    "            target = np.asarray(true_label)\n",
    "            cost = np.zeros((true_k, log_p.shape[1]))\n",
    "            for c in range(true_k):\n",
    "                cost[c] = -log_p[target == c].sum(axis=0)\n",
    "            try:\n",
    "                rows, cols = linear_sum_assignment(cost)\n",
    "                nll_gmm.append(float(cost[rows, cols].sum() / len(target)))\n",
    "            except ValueError:  # every matching hits a zero probability\n",
    "                nll_gmm.append(float('inf'))\n",
    "    return {\n",
    "        'model': {\n",
    "            'clusters': np.array(clusters_model),\n",
    "            'rand_index': np.array(r_model),\n",
    "            'ami': np.array(a_model),\n",
    "            'purity': np.array(p_model),\n",
    "            'nll': np.array(nll_transformer),\n",
    "        },\n",
    "        'gmm': {\n",
    "            'clusters': np.array(clusters_gmm),\n",
    "            'rand_index': np.array(ari_gmm),\n",
    "            'ami': np.array(ami_gmm),\n",
    "            'purity': np.array(purity_gmm),\n",
    "            'nll': np.array(nll_gmm),\n",
    "        },\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a363972e-d69e-4787-8b58-a8accae9ffa4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the benchmark with the configuration defined earlier in the notebook.\n",
    "results = compute_correct_clusters(\n",
    "    batch_size=batch_size,\n",
    "    iterations=iterations,\n",
    "    start_seq_len=start_seq_len,\n",
    "    seq_len=seq_len,\n",
    "    num_features=num_features,\n",
    "    num_classes=num_classes,\n",
    "    weight_concentration_prior=weight_concentration_prior,\n",
    "    mean_precision_prior=mean_precision_prior,\n",
    "    n_init=n_init,\n",
    "    seed=seed,\n",
    ")\n",
    "\n",
    "# Unpack per-metric arrays: Cluster-PFN ('model') and Bayesian GMM VI ('gmm').\n",
    "clusters_pfn, rand_index_pfn, ami_pfn, purity_pfn, nll_pfn = (\n",
    "    results['model'][key] for key in ('clusters', 'rand_index', 'ami', 'purity', 'nll'))\n",
    "clusters_gmm, rand_index_gmm, ami_gmm, purity_gmm, nll_gmm = (\n",
    "    results['gmm'][key] for key in ('clusters', 'rand_index', 'ami', 'purity', 'nll'))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ddba543-9897-40e9-95e9-4b9bac8bf4f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_violin_plots_external_metrics(ami_pfn, ami_bayes, ari_pfn, ari_bayes, purity_pfn, purity_bayes,\n",
    "                                          ylim=(0.4, 1.01)):\n",
    "    \"\"\"Draw split violin plots of AMI, ARI and purity for Cluster-PFN vs Bayesian GMM VI.\n",
    "\n",
    "    Args:\n",
    "        ami_pfn, ari_pfn, purity_pfn: 1-D per-dataset scores of Cluster-PFN\n",
    "            (equal lengths).\n",
    "        ami_bayes, ari_bayes, purity_bayes: 1-D per-dataset scores of the\n",
    "            Bayesian GMM (equal lengths; may differ from the PFN length).\n",
    "        ylim: y-axis limits. Scores outside them are cut from view; pass None\n",
    "            to let matplotlib choose.\n",
    "\n",
    "    Returns:\n",
    "        None; the figure is shown with ``plt.show()``.\n",
    "    \"\"\"\n",
    "    n_pfn, n_bayes = len(ami_pfn), len(ami_bayes)\n",
    "    df = pd.DataFrame({\n",
    "        'AMI': np.concatenate([ami_pfn, ami_bayes]),\n",
    "        'ARI': np.concatenate([ari_pfn, ari_bayes]),\n",
    "        'Purity': np.concatenate([purity_pfn, purity_bayes]),\n",
    "        # each method's rows are labelled using that method's own length\n",
    "        'Model': ['PFN'] * n_pfn + ['Bayesian GMM VI'] * n_bayes,\n",
    "    })\n",
    "    # long form: one row per (Model, Metric, Score) for seaborn\n",
    "    df_melted = df.melt(id_vars='Model', var_name='Metric', value_name='Score')\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(10, 6))\n",
    "    sns.violinplot(data=df_melted, x='Metric', y='Score', hue='Model', split=True,\n",
    "                   inner=None, cut=0, palette='tab10', ax=ax)\n",
    "    ax.set_title(\"Cluster-PFN vs Bayesian GMM VI external metrics\", fontsize=16)\n",
    "    if ylim is not None:\n",
    "        ax.set_ylim(*ylim)\n",
    "    ax.tick_params(axis='x', labelsize=16)\n",
    "    ax.set_xlabel(\"Metric\", fontsize=16)\n",
    "    ax.set_ylabel(\"Score\", fontsize=16)\n",
    "    ax.legend(loc='lower right')\n",
    "    fig.tight_layout()\n",
    "    plt.show()\n",
    "    return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fff4d3c-2d74-48a0-8b5a-f1736516a40e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fraction of generated datasets whose number of clusters was recovered exactly.\n",
    "print(f\"pfn cluster accuracy: {np.mean(clusters_pfn)}\")\n",
    "print(f\"VI cluster accuracy: {np.mean(clusters_gmm)}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a80ad3b2-1440-409d-a263-65826c355e36",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare both methods' per-dataset external metrics as split violin plots.\n",
    "compute_violin_plots_external_metrics(\n",
    "    ami_pfn, ami_gmm,\n",
    "    rand_index_pfn, rand_index_gmm,\n",
    "    purity_pfn, purity_gmm,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10f75657-32d4-4a73-9a97-86de150e93f2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cuda_test2",
   "language": "python",
   "name": "cuda_test2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
