{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87cc4d2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "import gudhi as gd\n",
    "import gudhi.wasserstein as wasserstein\n",
    "import gudhi.hera as hera\n",
    "\n",
    "from collections import defaultdict\n",
    "from tqdm import tqdm\n",
    "\n",
    "from itertools import combinations, combinations_with_replacement, product\n",
    "\n",
    "import ripserplusplus as rpp\n",
    "\n",
    "from scipy.spatial import distance_matrix\n",
    "from scipy.stats import pearsonr\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.svm import SVC\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60c6cb38",
   "metadata": {},
   "outputs": [],
   "source": [
    "cmap = 'viridis'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e82d3dc",
   "metadata": {},
   "source": [
    "# Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf6d08f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_name = 'MNIST'\n",
    "version = 'd1'\n",
    "\n",
    "models = {\n",
    "     'PCA':'PCA',\n",
    "     'UMAP':'UMAP',\n",
    "    'Basic AutoEncoder':'AE',\n",
    "    'Topological AutoEncoder':'TopoAE (Moor et.al.)',\n",
    "    'RTD AutoEncoder H1':'RTD-AE',\n",
    "    'NSA AutoEncoder':'NSA-AE',\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "306043c0",
   "metadata": {},
   "source": [
    "## Calculate distance matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2912badd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def pdist_gpu(a, b, device = 'cuda'):\n",
    "    A = torch.tensor(a, dtype = torch.float64)\n",
    "    B = torch.tensor(b, dtype = torch.float64)\n",
    "\n",
    "    size = (A.shape[0] + B.shape[0]) * A.shape[1] / 1e9\n",
    "    max_size = 0.2\n",
    "\n",
    "    if size > max_size:\n",
    "        parts = int(size / max_size) + 1\n",
    "    else:\n",
    "        parts = 1\n",
    "\n",
    "    pdist = np.zeros((A.shape[0], B.shape[0]))\n",
    "    At = A.to(device)\n",
    "\n",
    "    for p in range(parts):\n",
    "        i1 = int(p * B.shape[0] / parts)\n",
    "        i2 = int((p + 1) * B.shape[0] / parts)\n",
    "        i2 = min(i2, B.shape[0])\n",
    "\n",
    "        Bt = B[i1:i2].to(device)\n",
    "        pt = torch.cdist(At, Bt)\n",
    "        pdist[:, i1:i2] = pt.cpu()\n",
    "\n",
    "        del Bt, pt\n",
    "        torch.cuda.empty_cache()\n",
    "\n",
    "    del At\n",
    "\n",
    "    return pdist\n",
    "\n",
    "def zero_out_diagonal(distances):# make 0 on diagonal\n",
    "    return distances * (np.ones_like(distances) - np.eye(*distances.shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe527b9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = np.load(f'data/{dataset_name}/prepared/train_data.npy')\n",
    "#data = np.load(f'data/{dataset_name}/prepared/data.npy')\n",
    "data = data.reshape(data.shape[0], -1)\n",
    "ids = np.random.choice(np.arange(len(data)), size=min(30000, len(data)), replace=False)\n",
    "data = data[ids]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb13b31d-cba0-41bc-832e-33e05fb45627",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af4dffca",
   "metadata": {},
   "outputs": [],
   "source": [
    "original_distances = pdist_gpu(data, data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58be8c03",
   "metadata": {},
   "outputs": [],
   "source": [
    "original_distances = zero_out_diagonal(original_distances)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20754cfe",
   "metadata": {},
   "source": [
    "## Pearson correlation for pairwise distances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c64d654a",
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    labels = np.load(f'data/{dataset_name}/prepared/train_labels.npy')\n",
    "    #labels = np.load(f'data/{dataset_name}/prepared/labels.npy')\n",
    "except FileNotFoundError:\n",
    "    labels = np.load(f'data/{dataset_name}/prepared/train_data.npy')\n",
    "# ids = np.random.choice(np.arange(0, len(labels)), size=min(6000, len(labels)), replace=False)\n",
    "\n",
    "def get_distances(data):\n",
    "    data = data.reshape(data.shape[0], -1)\n",
    "    distances = distance_matrix(data, data)\n",
    "    distances = distances[np.triu(np.ones_like(distances), k=1) > 0]\n",
    "    return distances\n",
    " # take only different "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "224a191b",
   "metadata": {},
   "outputs": [],
   "source": [
    "labels.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cf4d9fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "results = {}\n",
    "for model_name in models:\n",
    "    try:\n",
    "        latent = np.load(f'data/{dataset_name}/{model_name}_latent_output_{version}.npy')[ids]\n",
    "        print(latent.shape)\n",
    "        latent_distances = pdist_gpu(latent, latent)\n",
    "    except FileNotFoundError:\n",
    "        continue\n",
    "    results[model_name] = pearsonr(\n",
    "        latent_distances[np.triu(np.ones_like(original_distances), k=1) > 0], \n",
    "        original_distances[np.triu(np.ones_like(original_distances), k=1) > 0])[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c60cc2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "results"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8115d4e9",
   "metadata": {},
   "source": [
    "## Triplet accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33fe0026",
   "metadata": {},
   "outputs": [],
   "source": [
    "def triplet_accuracy(input_data, latent_data, triplets=None):\n",
    "    # calculate distance matricies\n",
    "    input_data = input_data.reshape(input_data.shape[0], -1)\n",
    "    input_distances = zero_out_diagonal(pdist_gpu(input_data, input_data))\n",
    "    latent_data = latent_data.reshape(latent_data.shape[0], -1)\n",
    "    latent_distances = zero_out_diagonal(pdist_gpu(latent_data, latent_data))\n",
    "    # generate triplets\n",
    "    if triplets is None:\n",
    "        triplets = np.asarray(list(combinations(range(len(input_data)), r=3)))\n",
    "    i_s = triplets[:, 0]\n",
    "    j_s = triplets[:, 1]\n",
    "    k_s = triplets[:, 2]\n",
    "    acc = (np.logical_xor(\n",
    "        input_distances[i_s, j_s] < input_distances[i_s, k_s], \n",
    "        latent_distances[i_s, j_s] < latent_distances[i_s, k_s]\n",
    "    ) == False)\n",
    "    acc = np.mean(acc.astype(np.int32))\n",
    "    return acc\n",
    "\n",
    "\n",
    "def avg_triplet_accuracy(input_data, latent_data, batch_size=128, n_runs=20):\n",
    "    # average over batches\n",
    "    accs = []\n",
    "    triplets = np.asarray(list(combinations(range(min(batch_size, len(input_data))), r=3)))\n",
    "    if batch_size > len(input_data):\n",
    "        accs.append(triplet_accuracy(input_data, latent_data, triplets=triplets))\n",
    "        return accs\n",
    "    for _ in range(n_runs):\n",
    "        ids = np.random.choice(np.arange(len(input_data)), size=batch_size, replace=False)\n",
    "        accs.append(triplet_accuracy(input_data[ids], latent_data[ids], triplets=triplets))\n",
    "    return accs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab6716f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_data = np.load(f'data/{dataset_name}/prepared/train_data.npy')\n",
    "\n",
    "for model_name in models:\n",
    "    try:\n",
    "        latent_data = np.load(f'data/{dataset_name}/{model_name}_latent_output_{version}.npy')\n",
    "    except FileNotFoundError:\n",
    "        continue\n",
    "    accs = avg_triplet_accuracy(input_data, latent_data, batch_size=150, n_runs=10)\n",
    "    print(f\"Model: {model_name}, triplet acc: {np.mean(accs):.3f} $\\pm$ {np.std(accs):.3f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f49d68d1",
   "metadata": {},
   "source": [
    "# RTD"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f0dab6f",
   "metadata": {},
   "source": [
    "Switch to ripser++ from ArGentum"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa50da87",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.loss import RTDLoss\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "180e3b39",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "n_runs = 10\n",
    "batch_size = 200\n",
    "\n",
    "loss = RTDLoss(dim=1, engine='ripser')\n",
    "\n",
    "data = np.load(f'data/{dataset_name}/prepared/train_data.npy')\n",
    "data = data.reshape(len(data), -1)\n",
    "\n",
    "if batch_size > len(data):\n",
    "    n_runs=1\n",
    "    \n",
    "max_dim = 1\n",
    "results = defaultdict(list)\n",
    "\n",
    "for i in tqdm(range(n_runs)):\n",
    "    ids = np.random.choice(np.arange(0, len(data)), size=min(batch_size, len(data)), replace=False)\n",
    "    \n",
    "    x = data[ids]\n",
    "    x_distances = distance_matrix(x, x)\n",
    "    x_distances = x_distances/np.percentile(x_distances.flatten(), 90)\n",
    "    \n",
    "    for model_name in models:\n",
    "        try:\n",
    "            z = np.load(f'data/{dataset_name}/{model_name}_latent_output_{version}.npy')\n",
    "        except FileNotFoundError:\n",
    "            try:\n",
    "                z = np.load(f'data/{dataset_name}/{model_name}_latent_output.npy')\n",
    "            except FileNotFoundError:\n",
    "                continue\n",
    "        z = z.reshape(len(z), -1)[ids]\n",
    "        z_distances = distance_matrix(z, z)\n",
    "        z_distances = z_distances/np.percentile(z_distances.flatten(), 90)\n",
    "        print(f'Calculating RTD for: {model_name}')\n",
    "        with torch.no_grad():\n",
    "            _, _, value = loss(torch.tensor(x_distances), torch.tensor(z_distances))\n",
    "        results[model_name].append(value.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c728dbd7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "names = [\n",
    "    'PCA', \n",
    "    'UMAP', \n",
    "    'Basic AutoEncoder', \n",
    "    'Topological AutoEncoder', \n",
    "    'RTD AutoEncoder H1',\n",
    "    'NSA AutoEncoder',\n",
    "]\n",
    "for model_name in names:\n",
    "    if model_name in results:\n",
    "        print(f\"{model_name}: {np.mean(results[model_name]):.2f} $\\pm$ {np.std(results[model_name]):.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a032104",
   "metadata": {},
   "source": [
    "# Tripet acc. between cluster centers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aafb5354",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_cluster_distances(data, labels):\n",
    "    clusters = []\n",
    "    if len(data.shape) > 2:\n",
    "        data = data.reshape(data.shape[0], -1)\n",
    "    for l in np.sort(np.unique(labels)):\n",
    "        clusters.append(np.mean(data[labels == l], axis=0))\n",
    "    clusters = np.asarray(clusters)\n",
    "    return distance_matrix(clusters, clusters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fab3ed9",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = np.load(f'data/{dataset_name}/prepared/train_data.npy')\n",
    "labels = np.load(f'data/{dataset_name}/prepared/train_labels.npy')\n",
    "original_distances = get_cluster_distances(data, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4dc59f54",
   "metadata": {},
   "outputs": [],
   "source": [
    "def triplet_accuracy_between_clusters(original_distances, latent_distances):\n",
    "    triplets = np.asarray(list(combinations(range(len(original_distances)), r=3)))\n",
    "    i_s = triplets[:, 0]\n",
    "    j_s = triplets[:, 1]\n",
    "    k_s = triplets[:, 2]\n",
    "    acc = (np.logical_xor(\n",
    "        original_distances[i_s, j_s] < original_distances[i_s, k_s], \n",
    "        latent_distances[i_s, j_s] < latent_distances[i_s, k_s]\n",
    "    ) == False)\n",
    "    return acc\n",
    "\n",
    "def triplet_accuracy_between_clusters_(original_distances, latent_distances):\n",
    "    ids = range(len(original_distances))\n",
    "    triplets = np.asarray(list(product(ids, ids, ids)))\n",
    "    i_s = triplets[:, 0]\n",
    "    j_s = triplets[:, 1]\n",
    "    k_s = triplets[:, 2]\n",
    "    acc = (np.logical_xor(\n",
    "        original_distances[i_s, j_s] < original_distances[i_s, k_s], \n",
    "        latent_distances[i_s, j_s] < latent_distances[i_s, k_s]\n",
    "    ) == False)\n",
    "    return acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f8d3598",
   "metadata": {},
   "outputs": [],
   "source": [
    "for model_name in models:\n",
    "    try:\n",
    "        latent_data = np.load(f'data/{dataset_name}/{model_name}_latent_output_{version}.npy')\n",
    "    except FileNotFoundError:\n",
    "        continue\n",
    "    latent_distances = get_cluster_distances(latent_data, labels)\n",
    "    accs = triplet_accuracy_between_clusters_(original_distances, latent_distances)\n",
    "    print(f\"Model: {model_name}, triplet acc: {np.mean(accs):.3f} $\\pm$ {np.std(accs):.3f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e3d6ff1-8aaf-4c2c-8e5f-2380724e7542",
   "metadata": {},
   "source": [
    "### MSE Between X and X hat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38e3f1f5-3e51-4c25-82ed-ca95f9f753da",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "criterion = nn.MSELoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60387a28-d26a-41bf-a085-880cb353a64d",
   "metadata": {},
   "outputs": [],
   "source": [
    "if dataset_name in ['COIL-20','COIL-100']:\n",
    "    data = np.load(f'data/{dataset_name}/prepared/data.npy')\n",
    "    data = data.reshape(len(data), -1)\n",
    "    labels = np.load(f'data/{dataset_name}/prepared/labels.npy')\n",
    "else:\n",
    "    data = np.load(f'data/{dataset_name}/prepared/train_data.npy')\n",
    "    data = data.reshape(len(data), -1)\n",
    "    labels = np.load(f'data/{dataset_name}/prepared/train_labels.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d979de7-43cb-4f40-8401-e279f1c37bea",
   "metadata": {},
   "outputs": [],
   "source": [
    "scaler = {\n",
    "    \"Basic AutoEncoder\":255,\n",
    "    \"RTD AutoEncoder H1\":255,\n",
    "    \"NSA AutoEncoder\":255,\n",
    "    \"Topological AutoEncoder\":255,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7e5a416-0966-42c6-9992-1e94b90d0d3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "for model_name in models:\n",
    "    try:\n",
    "        output_data = np.load(f'data/{dataset_name}/{model_name}_final_output_{version}.npy')\n",
    "    except FileNotFoundError:\n",
    "        continue\n",
    "    accs = criterion(torch.tensor(data)/scaler[model_name], torch.tensor(output_data))\n",
    "    print(f\"Model: {model_name}, MSE value: {accs:e}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24e8ff8e-db18-4028-a5d0-3b9d6ce07656",
   "metadata": {},
   "source": [
    "#### Test MSE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6da7243a-8ff6-41cd-b1ec-5dbb9c4ab702",
   "metadata": {},
   "outputs": [],
   "source": [
    "#No Test MSE for COIL Datasets\n",
    "\n",
    "data = np.load(f'data/{dataset_name}/prepared/test_data.npy')\n",
    "data = data.reshape(len(data), -1)\n",
    "labels = np.load(f'data/{dataset_name}/prepared/test_labels.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9da1b24-717f-4edc-bc94-c1838f097082",
   "metadata": {},
   "outputs": [],
   "source": [
    "for model_name in models:\n",
    "    try:\n",
    "        output_data = np.load(f'data/{dataset_name}/{model_name}_final_output_{version}_test.npy')\n",
    "    except FileNotFoundError:\n",
    "        continue\n",
    "    accs = criterion(torch.tensor(data)/scaler[model_name], torch.tensor(output_data))\n",
    "    print(f\"Model: {model_name}, MSE value: {accs:e}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d7de2a6f-2d06-4422-949b-cfb56edfcf55",
   "metadata": {},
   "source": [
    "### Calculate NSA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24c86f4e-d930-4070-bd11-fc24499dc6ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.loss import NSALoss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0fedf25-9313-43b9-8132-124a9eb05fb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = NSALoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4bf9dd85-31c5-4ac6-9189-c302b5963e80",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_runs = 10\n",
    "batch_size = 2000\n",
    "\n",
    "loss = criterion\n",
    "\n",
    "if dataset_name in ['COIL-20','COIL-100']:\n",
    "    data = np.load(f'data/{dataset_name}/prepared/data.npy')\n",
    "    data = data.reshape(len(data), -1)\n",
    "else:\n",
    "    data = np.load(f'data/{dataset_name}/prepared/train_data.npy')\n",
    "    data = data.reshape(len(data), -1)\n",
    "data = data/255\n",
    "\n",
    "if batch_size > len(data):\n",
    "    n_runs=1\n",
    "    \n",
    "results = defaultdict(list)\n",
    "\n",
    "for i in tqdm(range(n_runs)):\n",
    "    ids = np.random.choice(np.arange(0, len(data)), size=min(batch_size, len(data)), replace=False)\n",
    "    \n",
    "    x = data[ids]\n",
    "    x = x/255\n",
    "    \n",
    "    for model_name in models:\n",
    "        try:\n",
    "            z = np.load(f'data/{dataset_name}/{model_name}_latent_output_{version}.npy')\n",
    "        except FileNotFoundError:\n",
    "            try:\n",
    "                z = np.load(f'data/{dataset_name}/{model_name}_latent_output.npy')\n",
    "            except FileNotFoundError:\n",
    "                continue\n",
    "        z = z.reshape(len(z), -1)[ids]\n",
    "        print(f'Calculating NSA for: {model_name}')\n",
    "        with torch.no_grad():\n",
    "            value = loss(torch.tensor(x), torch.tensor(z))\n",
    "        results[model_name].append(value.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e9b0a3f-968b-48a3-8d44-663df87bdcca",
   "metadata": {},
   "outputs": [],
   "source": [
    "names = [\n",
    "    'PCA', \n",
    "    'UMAP', \n",
    "    'Basic AutoEncoder', \n",
    "    'Topological AutoEncoder', \n",
    "    'RTD AutoEncoder H1',\n",
    "    'NSA AutoEncoder',\n",
    "]\n",
    "for model_name in names:\n",
    "    if model_name in results:\n",
    "        print(f\"{model_name}: {np.mean(results[model_name]):.2f} $\\pm$ {np.std(results[model_name]):.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e204aea-926d-4f0d-8f7e-ff0a1adf43ae",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
