{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b34eb5d1-c4b7-4984-b351-4256d8b92b6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import random\n",
    "import warnings\n",
    "import time\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "from sklearn.cluster import KMeans\n",
    "from sklearn.mixture import BayesianGaussianMixture\n",
    "from sklearn.mixture import GaussianMixture\n",
    "\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.datasets import make_circles, load_iris\n",
    "from sklearn import preprocessing\n",
    "# Run from the project root (two levels up) so the local modules and the\n",
    "# relative datasets/ and models/ paths resolve. Guarded so that re-running\n",
    "# this cell does not climb two more directories each time.\n",
    "if \"PROJECT_ROOT\" not in globals():\n",
    "    PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), \"..\", \"..\"))\n",
    "    os.chdir(PROJECT_ROOT)\n",
    "dir_ = PROJECT_ROOT\n",
    "# Local modules\n",
    "import utils\n",
    "import prior\n",
    "import transformer\n",
    "import main\n",
    "from sklearn.metrics import rand_score, cluster, adjusted_mutual_info_score\n",
    "from sklearn.metrics.cluster import adjusted_rand_score\n",
    "\n",
    "# Settings. Figures are rendered inline; no GUI backend is selected because\n",
    "# %matplotlib inline overrides it and TkAgg cannot load on headless kernels.\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "%matplotlib inline\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca5ba37c-e948-4ef2-b38d-446540bebccc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "class Normalize(nn.Module):\n",
    "    def __init__(self, mean, std):\n",
    "        super().__init__()\n",
    "        self.mean = mean\n",
    "        self.std = std\n",
    "\n",
    "    def forward(self, x):\n",
    "        return (x - self.mean) / self.std\n",
    "\n",
    "\n",
    "class Transformer(nn.Transformer):\n",
    "    def __init__(self, d_model, nhead, nhid, nlayers, in_features=1, buckets_size=100):\n",
    "        super(Transformer, self).__init__(d_model=d_model, nhead=nhead, dim_feedforward=nhid, dropout=0,\n",
    "                                          num_encoder_layers=nlayers)\n",
    "        self.model_type = 'Transformer'\n",
    "        self.src_mask = None\n",
    "        self.d_model = d_model\n",
    "        self.nhead = nhead\n",
    "        self.nhid = nhid\n",
    "        self.nlayers = nlayers\n",
    "        self.embed_size = d_model\n",
    "        self.linear_x = nn.Linear(in_features, d_model)\n",
    "        self.decoder = nn.Linear(d_model, buckets_size)\n",
    "        self.linear_num_clusters = nn.Linear(in_features, d_model)\n",
    "        self.embedding = nn.Embedding(buckets_size + 1,self.embed_size)\n",
    "\n",
    "    def _generate_mask(self, size):\n",
    "        matrix = torch.zeros((size, size), dtype=torch.float32)\n",
    "        matrix[:size -1, :size - 1] = 1\n",
    "        matrix[size -1] = 1\n",
    "        matrix = matrix.masked_fill(matrix == 0, float('-inf')).masked_fill(matrix == 1, 0)\n",
    "        return matrix\n",
    "\n",
    "\n",
    "    def forward(self, X,num_clusters):\n",
    "        # convert features to higher dimension\n",
    "        train = (self.linear_x(X)) # Shape S + 1,B, d_model\n",
    "        cluster_input = torch.full((1,X.shape[1], X.shape[2]), -1, dtype=torch.float) # learns the cluster numbers\n",
    "        cluster_embedding = self.linear_num_clusters(cluster_input)\n",
    "        train = torch.cat((train, cluster_embedding) , dim=0)\n",
    "        cluster_conditional = self.embedding(num_clusters) # 1, B, E\n",
    "        train = train + cluster_conditional #broadcasting automatically done\n",
    "        src_mask = self._generate_mask(train.shape[0])\n",
    "        output = self.encoder(train, mask=src_mask) # S + 1, B, E\n",
    "        output = self.decoder(output)\n",
    "        return output[:-1], output[-1:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c117fc27-35a1-4bb4-a856-e9d8fceefb32",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn import preprocessing\n",
    "import torch\n",
    "from scipy.stats import invwishart\n",
    "import random\n",
    "\n",
    "def sort( x, y,X_true, centers):\n",
    "    distances = np.linalg.norm(x, axis=1)\n",
    "    sorted_indices = np.argsort(distances)\n",
    "    sorted_x = x[sorted_indices]\n",
    "    sorted_y = y[sorted_indices]\n",
    "    sorted_X_true = X_true[sorted_indices]\n",
    "    mapping = {}\n",
    "    storage = set()\n",
    "    curr = 0\n",
    "    for i in range(len(sorted_y)):\n",
    "        if len(mapping) == centers:\n",
    "            break\n",
    "        if sorted_y[i] not in storage:\n",
    "            mapping[sorted_y[i]] = curr\n",
    "            storage.add(sorted_y[i])\n",
    "            curr += 1\n",
    "\n",
    "    y_mapped = np.array([mapping[number] for number in sorted_y])\n",
    "    indices = np.random.permutation(len(sorted_x))\n",
    "    shuffled_x = sorted_x[indices]\n",
    "    shuffled_y = y_mapped[indices]\n",
    "    shuffled_X_true = sorted_X_true[indices]\n",
    "    return shuffled_x, shuffled_y, shuffled_X_true\n",
    "\n",
    "\n",
    "def generate_bayesian_gmm_data(\n",
    "        batch_size=32,\n",
    "        start_seq_len = 100,\n",
    "        seq_len=500,\n",
    "        num_features=2,\n",
    "        min_classes=1,\n",
    "        num_classes=10,\n",
    "        weight_concentration_prior=1.0,  # Dirichlet prior for mixture weights\n",
    "        mean_prior=0.0,  # Mean of prior over cluster means\n",
    "        mean_precision_prior=0.01,  # Precision (confidence) over means\n",
    "        degrees_of_freedom_prior=None,  # Degrees of freedom of the inverse-Wishart prior (None -> num_features)\n",
    "        covariance_prior=None,  # Scale matrix of the inverse-Wishart prior (None -> identity)\n",
    "        seed=None,\n",
    "        nan_frac = None,\n",
    "        **kwargs\n",
    "):\n",
    "    \"\"\"Sample a batch of synthetic clustering tasks from a Bayesian GMM prior.\n",
    "\n",
    "    One sequence length is drawn uniformly from [start_seq_len, seq_len] for\n",
    "    the whole batch. For each task: n_components ~ U{min_classes, ..., num_classes},\n",
    "    weights ~ Dirichlet(weight_concentration_prior), and per component\n",
    "    Sigma ~ InvWishart(degrees_of_freedom_prior, covariance_prior),\n",
    "    mu ~ N(mean_prior, Sigma / mean_precision_prior). Points are min-max\n",
    "    scaled per feature; if nan_frac is truthy, a random fraction (at most\n",
    "    nan_frac) of the scaled entries is set to -2, never a whole row. If seed\n",
    "    is given, the global numpy, random and torch RNGs are seeded. Extra\n",
    "    **kwargs are accepted and ignored.\n",
    "\n",
    "    Returns (clusters_x, clusters_y, clusters_x_true, batch_classes):\n",
    "        clusters_x: float32 (seq_len, batch_size, num_features), scaled + masked;\n",
    "        clusters_y: float32 (seq_len, batch_size), labels relabelled by ``sort``;\n",
    "        clusters_x_true: float32 (seq_len, batch_size, num_features), unscaled, unmasked;\n",
    "        batch_classes: (1, batch_size), the sampled n_components per task.\n",
    "\n",
    "    NOTE(review): multinomial counts can be 0, so a task may contain fewer\n",
    "    distinct clusters than its batch_classes entry -- confirm this is intended.\n",
    "    \"\"\"\n",
    "    if seed is not None:\n",
    "        np.random.seed(seed)\n",
    "        random.seed(seed)\n",
    "        torch.manual_seed(seed)\n",
    "\n",
    "    #features = random.randint(2, num_features)\n",
    "    features = num_features\n",
    "    if covariance_prior is None:\n",
    "        covariance_prior = np.eye(features)\n",
    "\n",
    "    if degrees_of_freedom_prior is None:\n",
    "        degrees_of_freedom_prior = features\n",
    "\n",
    "    # One sequence length shared by the whole batch (reuses the seq_len name).\n",
    "    seq_len = random.randint(start_seq_len, seq_len)\n",
    "    clusters_x = np.zeros((batch_size, seq_len, num_features))\n",
    "    clusters_y = np.zeros((batch_size, seq_len))\n",
    "    clusters_x_true = np.zeros((batch_size, seq_len, num_features))\n",
    "    batch_classes = []\n",
    "    for i in range(batch_size):\n",
    "        n_components = np.random.randint(min_classes, num_classes + 1)\n",
    "\n",
    "        # Sample weights from Dirichlet prior\n",
    "        weights = np.random.dirichlet(np.ones(n_components) * weight_concentration_prior)\n",
    "\n",
    "        # Multinomial draw of per-component point counts (a count can be 0)\n",
    "        counts = np.random.multinomial(seq_len, weights)\n",
    "\n",
    "        # Sample cluster parameters\n",
    "        means = []\n",
    "        covariances = []\n",
    "        for _ in range(n_components):\n",
    "            Sigma = invwishart.rvs(df=degrees_of_freedom_prior, scale=covariance_prior)\n",
    "            mu = np.random.multivariate_normal(np.full(features, mean_prior), Sigma / mean_precision_prior)\n",
    "            means.append(mu)\n",
    "            covariances.append(Sigma)\n",
    "        # Sample points from GMM\n",
    "        X = []\n",
    "        y = []\n",
    "        for k in range(n_components):\n",
    "            n_k = counts[k]\n",
    "            X_k = np.random.multivariate_normal(means[k], covariances[k], size=n_k)\n",
    "            X.append(X_k)\n",
    "            y.append(np.full(n_k, k))\n",
    "\n",
    "        X = np.vstack(X)\n",
    "        y = np.concatenate(y)\n",
    "        # No-op while features == num_features; zero-pads extra columns otherwise.\n",
    "        X = np.pad(X, ((0, 0), (0, num_features - features)), mode='constant')\n",
    "\n",
    "        # Per-feature min-max scaling; X itself stays unscaled (returned as X_true).\n",
    "        x = preprocessing.MinMaxScaler().fit_transform(X)\n",
    "        n_samples, n_features = seq_len, features\n",
    "\n",
    "        # Optionally overwrite a random fraction (uniform in [0, nan_frac]) of the\n",
    "        # scaled entries with the missing-value sentinel -2.\n",
    "        if nan_frac:\n",
    "            fraction_missing = np.random.uniform(0, nan_frac)\n",
    "            n_elements = n_samples * n_features\n",
    "            n_missing = int(fraction_missing * n_elements)\n",
    "\n",
    "            if n_missing > 0:\n",
    "                all_indices = np.arange(n_elements)\n",
    "                np.random.shuffle(all_indices)\n",
    "                chosen = all_indices[:n_missing]\n",
    "\n",
    "                rows, cols = np.unravel_index(chosen, (n_samples, n_features))\n",
    "\n",
    "                # Ensure no row is fully masked: keep one random entry of each such row\n",
    "                full_rows = [r for r in np.unique(rows) if np.sum(rows == r) == n_features]\n",
    "                to_keep = []\n",
    "                for r in full_rows:\n",
    "                    row_mask = rows == r\n",
    "                    row_positions_in_chosen = np.where(row_mask)[0]\n",
    "                    keep_pos = np.random.choice(row_positions_in_chosen)\n",
    "                    to_keep.append(keep_pos)\n",
    "                # Remove all \"to_keep\" indices from chosen at once\n",
    "                chosen = np.delete(chosen, to_keep)\n",
    "                rows, cols = np.unravel_index(chosen, (n_samples, n_features))\n",
    "                x[rows, cols] = -2\n",
    "\n",
    "        # Canonical relabelling + shared shuffle of scaled x, labels and raw X.\n",
    "        x, y, X_true = sort(x, y, X, n_components)\n",
    "        clusters_x[i] = x\n",
    "        clusters_y[i] = y\n",
    "        clusters_x_true[i] = X_true\n",
    "        batch_classes.append(n_components)\n",
    "\n",
    "    # Convert to sequence-first tensors: (seq_len, batch_size, ...).\n",
    "    clusters_x_true = torch.tensor(clusters_x_true, dtype=torch.float32).permute(1, 0, 2)\n",
    "    clusters_x = torch.tensor(clusters_x, dtype=torch.float32).permute(1, 0, 2)\n",
    "    clusters_y = torch.tensor(clusters_y, dtype=torch.float32).permute(1, 0)\n",
    "    batch_classes = torch.tensor(batch_classes).unsqueeze(0)\n",
    "\n",
    "    return clusters_x, clusters_y, clusters_x_true, batch_classes\n",
    "\n",
    "\n",
    "def get_clusters_bayesian_gmm(X, batch_classes=None, n_components=10, weight_concentration_prior=1.0,\n",
    "                              mean_prior=None, mean_precision_prior=None, degrees_of_freedom_prior=None,\n",
    "                              covariance_prior=None, random_state=42, n_init=None):\n",
    "    \"\"\"Choose the number of clusters with Bayesian GMMs and a penalised ELBO.\n",
    "\n",
    "    A ``BayesianGaussianMixture`` (Dirichlet-distribution weight prior,\n",
    "    max_iter=200) is fitted for every K in 1..n_components; the fit with the\n",
    "    largest ``lower_bound_ - log(K!)`` wins (presumably the log(K!) term\n",
    "    offsets the K! equivalent label permutations).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : array-like of shape (n_samples, n_features)\n",
    "    batch_classes : unused; accepted for call-site compatibility.\n",
    "    n_components : largest K tried.\n",
    "    weight_concentration_prior, mean_precision_prior : passed to sklearn.\n",
    "    mean_prior, degrees_of_freedom_prior, covariance_prior : passed to\n",
    "        sklearn; ``None`` selects zeros(n_features), n_features and\n",
    "        eye(n_features) respectively.\n",
    "    random_state : seed passed to every fit.\n",
    "    n_init : initialisations per fit; ``None`` means 1 (sklearn rejects None).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (best_K, labels, responsibilities, fit_seconds). best_K is -1 and the\n",
    "    label/responsibility entries are [] when n_components < 1.\n",
    "    \"\"\"\n",
    "    n_features = X.shape[1]\n",
    "    if mean_prior is None:\n",
    "        mean_prior = np.zeros(n_features)\n",
    "    if degrees_of_freedom_prior is None:\n",
    "        degrees_of_freedom_prior = n_features\n",
    "    if covariance_prior is None:\n",
    "        covariance_prior = np.eye(n_features)\n",
    "    if n_init is None:\n",
    "        n_init = 1\n",
    "\n",
    "    best_elbo = float('-inf')\n",
    "    best_cluster = -1\n",
    "    best_prob = []\n",
    "    best_label = []\n",
    "    time_tot = 0  # seconds spent in model.fit only\n",
    "    for component in range(1, n_components + 1):\n",
    "        model = BayesianGaussianMixture(n_components=component,\n",
    "                                        weight_concentration_prior_type='dirichlet_distribution',\n",
    "                                        weight_concentration_prior=weight_concentration_prior,\n",
    "                                        mean_prior=mean_prior, mean_precision_prior=mean_precision_prior,\n",
    "                                        n_init=n_init,\n",
    "                                        degrees_of_freedom_prior=degrees_of_freedom_prior, max_iter=200,\n",
    "                                        covariance_prior=covariance_prior, random_state=random_state)\n",
    "        start_time = time.time()\n",
    "        model.fit(X)\n",
    "        time_tot += time.time() - start_time\n",
    "        penalised_elbo = model.lower_bound_ - np.log(math.factorial(component))\n",
    "        if penalised_elbo > best_elbo:\n",
    "            best_elbo = penalised_elbo\n",
    "            best_cluster = component\n",
    "            best_label = model.predict(X)\n",
    "            best_prob = model.predict_proba(X)\n",
    "    return best_cluster, best_label, best_prob, time_tot\n",
    "\n",
    "\n",
    "def match_mask(train_X, X_true):\n",
    "    # Copy to avoid modifying original\n",
    "    masked_X = X_true.clone()\n",
    "    \n",
    "    # Wherever train_X == -2, set X_true to NaN\n",
    "    masked_X[train_X == -2] = float('nan')\n",
    "    \n",
    "    return masked_X\n",
    "\n",
    "def apply_mask(train_X, miss_level, rng):\n",
    "    n_features = train_X.shape[1] \n",
    "    n_samples = train_X.shape[0]\n",
    "    corrupted = np.zeros_like(train_X, dtype=bool)\n",
    "    for f in range(n_features):\n",
    "        # Pick random fraction for this feature\n",
    "        n_mask = int(miss_level * n_samples)\n",
    "\n",
    "        # Randomly choose indices for this feature\n",
    "        idx = rng.choice(n_samples, size=n_mask, replace=False)\n",
    "        corrupted[idx, f] = True\n",
    "\n",
    "    fully_corrupted = np.where(np.all(corrupted, axis=1))[0]\n",
    "    for f in fully_corrupted:\n",
    "        # randomly \"unmask\" one feature so not all are -2\n",
    "        f_restore = rng.integers(0, n_features)\n",
    "        corrupted[f, f_restore] = False\n",
    "    train_X[corrupted] = -2\n",
    "\n",
    "    return train_X\n",
    "\n",
    "\n",
    "def generate_progressive_masks(n_samples, n_features, miss_levels, rng):\n",
    "    \"\"\"\n",
    "    Generates progressive masks where the *total* fraction of masked entries\n",
    "    matches each miss_level (not per feature).\n",
    "    \n",
    "    Returns a dict: {miss_level: boolean mask of shape (n_samples, n_features)}\n",
    "    \"\"\"\n",
    "    masks = {}\n",
    "    base_mask = np.zeros((n_samples, n_features), dtype=bool)  # start with no missing\n",
    "\n",
    "    total_entries = n_samples * n_features\n",
    "\n",
    "    for miss_level in miss_levels:\n",
    "        # total number of masked entries desired\n",
    "        n_to_mask_total = int(miss_level * total_entries)\n",
    "        \n",
    "        # start from previous mask\n",
    "        mask = base_mask.copy()\n",
    "        \n",
    "        # how many new entries we still need to add\n",
    "        current_masked = mask.sum()\n",
    "        n_new_to_mask = max(0, n_to_mask_total - current_masked)\n",
    "        \n",
    "        if n_new_to_mask > 0:\n",
    "            available_indices = np.argwhere(~mask)  # coordinates of available entries\n",
    "            selected = rng.choice(len(available_indices), size=n_new_to_mask, replace=False)\n",
    "            chosen_coords = available_indices[selected]\n",
    "            \n",
    "            for r, c in chosen_coords:\n",
    "                mask[r, c] = True\n",
    "\n",
    "        # ensure no row is fully masked\n",
    "        fully_masked_rows = np.where(np.all(mask, axis=1))[0]\n",
    "        for row in fully_masked_rows:\n",
    "            restore_feature = rng.integers(0, n_features)\n",
    "            mask[row, restore_feature] = False\n",
    "\n",
    "        masks[miss_level] = mask\n",
    "        base_mask = mask.copy()  # progressively build\n",
    "\n",
    "    return masks\n",
    "    \n",
    "\n",
    "def purity_score(y_true, y_pred):\n",
    "    \"\"\"Compute purity score given true and predicted labels.\n",
    "\n",
    "    For each predicted cluster, count its most frequent true class; purity is\n",
    "    the sum of those counts divided by the number of samples.\n",
    "    \"\"\"\n",
    "    contingency = cluster.contingency_matrix(y_true, y_pred)  # rows: true classes, cols: predicted clusters\n",
    "    majority_counts = contingency.max(axis=0)\n",
    "    return majority_counts.sum() / contingency.sum()\n",
    "\n",
    "\n",
    "def plot_scores(scores, metric, title, ylabel):\n",
    "    \"\"\"Plot mean +/- standard error of one metric against missingness level.\n",
    "\n",
    "    ``scores[method][metric]`` maps a missingness level to the list of values\n",
    "    collected over repetitions; one error-bar line is drawn per method.\n",
    "    Levels a method has no values for are plotted as NaN.\n",
    "    \"\"\"\n",
    "    levels = sorted({level for per_method in scores.values() for level in per_method[metric]})\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(10, 6))\n",
    "    for method, per_method in scores.items():\n",
    "        means, stderrs = [], []\n",
    "        for level in levels:\n",
    "            vals = np.array(per_method[metric].get(level, []))\n",
    "            if len(vals) == 0:\n",
    "                means.append(np.nan)\n",
    "                stderrs.append(np.nan)\n",
    "                continue\n",
    "            means.append(vals.mean())\n",
    "            stderrs.append(vals.std(ddof=1) / np.sqrt(len(vals)))\n",
    "        ax.errorbar(levels, means, yerr=stderrs, marker='o', markersize=6, capsize=5,\n",
    "                    linewidth=2.2, elinewidth=1.5, capthick=1.5, label=method)\n",
    "\n",
    "    ax.set_xlabel(\"Missingness level\", fontsize=12)\n",
    "    ax.set_ylabel(ylabel, fontsize=12)\n",
    "    ax.set_title(title, fontsize=14)\n",
    "    ax.legend(fontsize=11)\n",
    "    ax.grid(True, alpha=0.4, linestyle=\"--\")\n",
    "    fig.tight_layout()\n",
    "    plt.show()\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "def apply_pca_5d(X, n_components=5, random_state=None):\n",
    "    \"\"\"Project ``X`` onto its leading principal components (5 by default).\n",
    "\n",
    "    A new ``PCA`` is fitted on ``X`` and the transformed data, of shape\n",
    "    (n_samples, n_components), is returned. ``random_state`` is passed to\n",
    "    PCA so runs that use a randomized SVD solver can be made reproducible.\n",
    "    \"\"\"\n",
    "    pca = PCA(n_components=n_components, random_state=random_state)\n",
    "    return pca.fit_transform(X)\n",
    "\n",
    "def impute(X, strategy=\"mean\"):\n",
    "    X_imputed = X.clone()\n",
    "    \n",
    "    for col in range(X.shape[1]):\n",
    "        col_data = X[:, col]\n",
    "        mask = torch.isnan(col_data)\n",
    "        \n",
    "        if mask.any():\n",
    "            if strategy == \"mean\":\n",
    "                fill_value = torch.nanmean(col_data)\n",
    "            elif strategy == \"median\":\n",
    "                fill_value = torch.nanmedian(col_data)\n",
    "            else:\n",
    "                raise ValueError(\"strategy must be 'mean' or 'median'\")\n",
    "            \n",
    "            X_imputed[mask, col] = fill_value\n",
    "    \n",
    "    return X_imputed\n",
    "\n",
    "def get_labels_pfn(model, train_X):\n",
    "    start_time = time.time()\n",
    "    with torch.no_grad():\n",
    "        logits, cluster_output = model(train_X,torch.full((1,1), 0, dtype=torch.long))\n",
    "        cluster_input = torch.argmax(cluster_output, dim=-1) + 1\n",
    "        with torch.no_grad():\n",
    "            logits, cluster_output = model(train_X, torch.full((1,1), cluster_input.item(), dtype=torch.long))\n",
    "        predictions = torch.argmax(logits, dim=-1).cpu().numpy()\n",
    "        predictions_cluster_count = (torch.argmax(cluster_output, dim=-1) + 1).squeeze(0) \n",
    "        end_time = time.time() - start_time\n",
    "        return predictions, predictions_cluster_count, end_time\n",
    "\n",
    "\n",
    "def get_labels_gmm(X, batch_class, n_init=None, random_state=42, max_components=10):\n",
    "    \"\"\"Fit GMMs for K = 1..max_components; return labels of the AIC- and BIC-best fits.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : array-like of shape (n_samples, n_features)\n",
    "    batch_class : unused; accepted for call-site compatibility.\n",
    "    n_init : initialisations per fit; ``None`` means 1 (sklearn rejects None).\n",
    "    random_state : seed passed to every ``GaussianMixture``.\n",
    "    max_components : largest K tried (default 10).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (aic_labels, bic_labels, fit_seconds); the label entries are [] when\n",
    "    max_components < 1.\n",
    "    \"\"\"\n",
    "    if n_init is None:\n",
    "        n_init = 1\n",
    "    best_aic, best_bic = float('inf'), float('inf')\n",
    "    aic_labels, bic_labels = [], []\n",
    "    time_tot = 0  # seconds spent in model.fit only\n",
    "    for component in range(1, max_components + 1):\n",
    "        model = GaussianMixture(n_components=component, n_init=n_init, reg_covar=1e-04,\n",
    "                                random_state=random_state)\n",
    "        start_time = time.time()\n",
    "        model.fit(X)\n",
    "        time_tot += time.time() - start_time\n",
    "        current_aic = model.aic(X)\n",
    "        current_bic = model.bic(X)\n",
    "        if current_aic >= best_aic and current_bic >= best_bic:\n",
    "            continue  # neither criterion improved; no need to predict\n",
    "        predictions = model.predict(X)\n",
    "        if current_aic < best_aic:\n",
    "            best_aic, aic_labels = current_aic, predictions\n",
    "        if current_bic < best_bic:\n",
    "            best_bic, bic_labels = current_bic, predictions\n",
    "    return aic_labels, bic_labels, time_tot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a26b29d-398d-42a3-b429-324e17b22b9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# PFN architecture; must match the checkpoints loaded below.\n",
    "d_model, nhead, nhid, nlayers = 256, 4, 512, 4\n",
    "num_features = 5   # in_features of the PFNs\n",
    "num_outputs = 10   # buckets_size: logits per position\n",
    "\n",
    "# Baseline (BGMM / GMM) settings.\n",
    "n_init = 10\n",
    "mean_precision_prior = 0.1\n",
    "weight_concentration_prior = 1.0\n",
    "\n",
    "# Experiment settings.\n",
    "miss_levels = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]\n",
    "num_iterations = 20\n",
    "start_seed = 3000\n",
    "exp = f\"gls2_pfn_{start_seed}\"\n",
    "\n",
    "# os.path.join keeps the checkpoint paths portable across operating systems.\n",
    "# NOTE(review): pfn_easy and pfn_hard both load pfn_hard_5D.pt; confirm\n",
    "# whether model_path_10 should point to the easy-task checkpoint.\n",
    "model_dir = os.path.join(\"models\", \"models_missingness\")\n",
    "model_path_10 = os.path.join(model_dir, \"pfn_hard_5D.pt\")\n",
    "model_path_20 = os.path.join(model_dir, \"pfn_hard_5D.pt\")\n",
    "\n",
    "\n",
    "def load_pfn(path):\n",
    "    \"\"\"Build a Transformer with the settings above and load its weights from ``path`` (on CPU).\"\"\"\n",
    "    model = Transformer(d_model, nhead, nhid, nlayers, in_features=num_features, buckets_size=num_outputs)\n",
    "    checkpoint = torch.load(path, weights_only=True, map_location=torch.device('cpu'))\n",
    "    model.load_state_dict(checkpoint['model_state_dict'])\n",
    "    model.eval()\n",
    "    return model\n",
    "\n",
    "\n",
    "pfn_10 = load_pfn(model_path_10)  # reported as \"pfn_easy\"\n",
    "pfn_20 = load_pfn(model_path_20)  # reported as \"pfn_hard\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9f03e25-c4fb-42a3-97c9-a0cace5c601c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import LabelEncoder\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "# ================================ RNA data (active) ================================\n",
    "# NOTE(review): assumes every column of data_top_20_RNA.csv is numeric (no\n",
    "# sample-ID column) and that its rows align with labels_RNA.csv -- confirm.\n",
    "data_rna = pd.read_csv(\"datasets/data_top_20_RNA.csv\").to_numpy()\n",
    "data_rna = apply_pca_5d(data_rna)  # 5 components = the PFNs' num_features\n",
    "\n",
    "labels_rna = pd.read_csv(\"datasets/labels_RNA.csv\").to_numpy()\n",
    "labels_str = labels_rna[:, 1]  # extract the second column (labels)\n",
    "le = LabelEncoder()\n",
    "labels_rna = le.fit_transform(labels_str)\n",
    "\n",
    "scaler = MinMaxScaler()\n",
    "train_X = torch.from_numpy(scaler.fit_transform(data_rna)).float()  # PFN input, scaled to [0, 1]\n",
    "X_true = torch.from_numpy(data_rna).float()  # unscaled features for the GMM baselines\n",
    "batch_class = 5\n",
    "train_Y = labels_rna\n",
    "\n",
    "\n",
    "def load_gls_dataset(path):\n",
    "    \"\"\"Load one GWAS cluster CSV (GWAS_cl_data{1,2,3}.csv).\n",
    "\n",
    "    Drops the 'entropy' and 'Unnamed: 0' columns, label-encodes 'Cluster'\n",
    "    and returns (train_X, X_true, train_Y): min-max scaled float32 features,\n",
    "    the raw features, and int64 labels.\n",
    "    \"\"\"\n",
    "    data = pd.read_csv(path).drop(columns=['entropy', 'Unnamed: 0'])\n",
    "    labels = LabelEncoder().fit_transform(data['Cluster'].values)\n",
    "    features = data.drop(columns=['Cluster']).to_numpy()\n",
    "    train_X = torch.tensor(MinMaxScaler().fit_transform(features), dtype=torch.float)\n",
    "    X_true = torch.tensor(features)\n",
    "    train_Y = torch.tensor(labels, dtype=torch.long)\n",
    "    return train_X, X_true, train_Y\n",
    "\n",
    "\n",
    "# ============================ GLS data (alternatives) ==============================\n",
    "# To run on a GLS dataset instead of RNA, uncomment exactly one line (3 clusters each).\n",
    "# GLS 1 also needs the zero-padding lines in the experiment cell (they use X_data_gls).\n",
    "# train_X, X_true, train_Y = load_gls_dataset(\"datasets/GWAS_cl_data1.csv\"); batch_class = 3; X_data_gls = X_true\n",
    "# train_X, X_true, train_Y = load_gls_dataset(\"datasets/GWAS_cl_data2.csv\"); batch_class = 3\n",
    "# train_X, X_true, train_Y = load_gls_dataset(\"datasets/GWAS_cl_data3.csv\"); batch_class = 3\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bde8d648-b16e-4565-be09-1e96abe52a12",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "methods = [\n",
    "    \"pfn_easy\", \"pfn_hard\",\n",
    "    \"bgmm_mean\", \"bgmm_median\",\n",
    "    \"aic_mean\", \"aic_median\",\n",
    "    \"bic_mean\", \"bic_median\",\n",
    "]\n",
    "\n",
    "# Run time is recorded per model family. Both baseline families are timed on\n",
    "# the mean-imputed data so their timings are comparable.\n",
    "methods2 = [\"pfn_easy\", \"pfn_hard\", \"bgmm\", \"gmm\"]\n",
    "scores = {m: {\"ari\": defaultdict(list),\n",
    "              \"ami\": defaultdict(list),\n",
    "              \"purity\": defaultdict(list)}\n",
    "          for m in methods}\n",
    "time_dict = {m: defaultdict(list) for m in methods2}\n",
    "\n",
    "for iteration in range(num_iterations):\n",
    "    # One nested family of masks per repetition, seeded for reproducibility.\n",
    "    rng_mask = np.random.default_rng(start_seed + iteration)\n",
    "    masks = generate_progressive_masks(X_true.shape[0], X_true.shape[1], miss_levels, rng_mask)\n",
    "    for miss_level in miss_levels:\n",
    "        # PFN input: scaled data with missing entries set to the -2 sentinel.\n",
    "        train_X_masked = train_X.clone()\n",
    "        train_X_masked[masks[miss_level]] = -2\n",
    "        # Baseline input: unscaled data, NaN at the same positions, then imputed.\n",
    "        X_true_masked = match_mask(train_X_masked, X_true)\n",
    "\n",
    "        # zeros_pad = torch.zeros((train_X_masked.shape[0], 2), dtype=train_X.dtype) # for gls1 only\n",
    "        # train_X_masked = torch.cat([train_X_masked, zeros_pad], dim=1)\n",
    "\n",
    "        X_true_mean = impute(X_true_masked, strategy='mean')\n",
    "        X_true_median = impute(X_true_masked, strategy='median')\n",
    "\n",
    "        # PFNs take (seq, batch=1, features); squeeze back to one label per point.\n",
    "        labels_pfn_10, cluster_pfn_10, time_pfn_easy = get_labels_pfn(pfn_10, train_X_masked.unsqueeze(1))\n",
    "        labels_pfn_10 = labels_pfn_10.squeeze()\n",
    "        labels_pfn_20, cluster_pfn_20, time_pfn_hard = get_labels_pfn(pfn_20, train_X_masked.unsqueeze(1))\n",
    "        labels_pfn_20 = labels_pfn_20.squeeze()\n",
    "\n",
    "        cluster_bgmm_mean, labels_bgmm_mean, probs_bgmm_mean, time_bgmm_mean = get_clusters_bayesian_gmm(\n",
    "            X=X_true_mean, batch_classes=None, n_init=n_init,\n",
    "            mean_precision_prior=mean_precision_prior, weight_concentration_prior=weight_concentration_prior)\n",
    "        cluster_bgmm_median, labels_bgmm_median, probs_bgmm_median, time_bgmm_median = get_clusters_bayesian_gmm(\n",
    "            X=X_true_median, batch_classes=None, n_init=n_init,\n",
    "            mean_precision_prior=mean_precision_prior, weight_concentration_prior=weight_concentration_prior)\n",
    "\n",
    "        labels_aic_mean, labels_bic_mean, time_gmm_mean = get_labels_gmm(X_true_mean, batch_class, n_init=n_init)\n",
    "        labels_aic_median, labels_bic_median, time_gmm_median = get_labels_gmm(X_true_median, batch_class, n_init=n_init)\n",
    "\n",
    "        time_dict['pfn_easy'][miss_level].append(time_pfn_easy)\n",
    "        time_dict['pfn_hard'][miss_level].append(time_pfn_hard)\n",
    "        time_dict['bgmm'][miss_level].append(time_bgmm_mean)\n",
    "        time_dict['gmm'][miss_level].append(time_gmm_mean)\n",
    "\n",
    "        # Evaluation against the ground-truth labels.\n",
    "        method_labels = {\n",
    "            \"pfn_easy\": labels_pfn_10,\n",
    "            \"pfn_hard\": labels_pfn_20,\n",
    "            \"bgmm_mean\": labels_bgmm_mean,\n",
    "            \"bgmm_median\": labels_bgmm_median,\n",
    "            \"aic_mean\": labels_aic_mean,\n",
    "            \"aic_median\": labels_aic_median,\n",
    "            \"bic_mean\": labels_bic_mean,\n",
    "            \"bic_median\": labels_bic_median,\n",
    "        }\n",
    "        for name, labels in method_labels.items():\n",
    "            scores[name]['ari'][miss_level].append(adjusted_rand_score(train_Y, labels))\n",
    "            scores[name]['ami'][miss_level].append(adjusted_mutual_info_score(train_Y, labels, average_method=\"arithmetic\"))\n",
    "            scores[name]['purity'][miss_level].append(purity_score(train_Y, labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f2024c9-08d5-4d11-89b8-447c18e1e6d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "for metric, title, ylabel in [\n",
    "    (\"ari\", \"Clustering performance vs. Missingness (ARI)\", \"Adjusted Rand Index (ARI)\"),\n",
    "    (\"ami\", \"Clustering performance vs. Missingness (AMI)\", \"Adjusted Mutual Information (AMI)\"),\n",
    "    (\"purity\", \"Clustering performance vs. Missingness (Purity)\", \"Purity\"),\n",
    "]:\n",
    "    plot_scores(scores, metric, title, ylabel)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67295d95-5b80-4abe-999a-cdb7b8e54f35",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cuda_test2",
   "language": "python",
   "name": "cuda_test2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
