{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3190377",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb96f3a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import os\n",
    "import json\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader\n",
    "import pytorch_lightning as pl\n",
    "\n",
    "from src.autoencoder import AutoEncoder, NSAAutoEncoder\n",
    "from src.utils import *\n",
    "from src.loss import RTDLoss, NSALoss\n",
    "from src.top_ae import TopologicallyRegularizedAutoencoder\n",
    "\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.metrics import accuracy_score, confusion_matrix\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import StandardScaler, MinMaxScaler\n",
    "\n",
    "from collections import defaultdict\n",
    "\n",
    "from tqdm.notebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63bb080d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment settings. The whole dict is later unpacked into train_models(**config),\n",
    "# which forwards every key it does not consume itself to get_list_of_models(**kwargs).\n",
    "config = {\n",
    "    \"dataset_name\":\"MNIST\",\n",
    "    \"version\":\"d1\",\n",
    "    \"model_name\":\"default\",\n",
    "    \"max_epochs\":250,\n",
    "    \"gpus\":[0],\n",
    "    \"rtd_every_n_batches\":1,\n",
    "    \"rtd_start_epoch\":60,\n",
    "    # A dict literal keeps only the last value of a repeated key, so a second \"nsa_l\" entry\n",
    "    # here would leave the RTD setting undefined.\n",
    "    # NOTE(review): \"rtd_l\" is chosen to match the other rtd_* keys; confirm it is the name\n",
    "    # the model code in src/ actually reads.\n",
    "    \"rtd_l\":1.0, # rtd loss\n",
    "    \"nsa_every_n_batches\":1,\n",
    "    \"nsa_start_epoch\":60,\n",
    "    \"nsa_l\":1.0, # nsa loss\n",
    "    \"n_runs\":1, # number of runs for each model\n",
    "    \"card\":50, # number of points on the persistence diagram\n",
    "    \"n_threads\":1, # number of threads for parallel ripser computation of pers homology\n",
    "    \"latent_dim\":16, # latent dimension (2 or 3 for visualization purposes)\n",
    "    \"input_dim\":28*28,\n",
    "    \"n_hidden_layers\":3,\n",
    "    \"hidden_dim\":512,\n",
    "    \"batch_size\":64,\n",
    "    \"engine\":\"ripser\",\n",
    "    \"is_sym\":True,\n",
    "    \"lr\":1e-4,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90495af7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_model(input_dim, latent_dim=2, n_hidden_layers=2, m_type='encoder', **kwargs):\n",
    "    \"\"\"Build a fully connected encoder or decoder as an ``nn.Sequential``.\n",
    "\n",
    "    Encoder: each hidden block is Linear+ReLU; the width is halved from one block to the\n",
    "    next for as long as the halved width is still >= ``latent_dim`` (otherwise it is kept),\n",
    "    and a final Linear maps to ``latent_dim``.\n",
    "    Decoder: starts at ``latent_dim`` and doubles the width with every Linear+ReLU block;\n",
    "    a final Linear maps back to ``input_dim``.\n",
    "\n",
    "    Args:\n",
    "        input_dim: size of one flattened input sample.\n",
    "        latent_dim: size of the bottleneck.\n",
    "        n_hidden_layers: requested number of hidden blocks, capped at\n",
    "            ``int(log2(input_dim)) - 1``.\n",
    "        m_type: ``'encoder'`` or ``'decoder'``.\n",
    "        **kwargs: ignored; allows passing a whole config dict.\n",
    "\n",
    "    Returns:\n",
    "        The assembled ``nn.Sequential``. Any other ``m_type`` yields an empty one.\n",
    "    \"\"\"\n",
    "    depth = min(int(np.log2(input_dim)) - 1, n_hidden_layers)\n",
    "    layers = []\n",
    "    if m_type == 'encoder':\n",
    "        width = input_dim\n",
    "        for _ in range(depth):\n",
    "            # Shrink only while the halved width does not drop below the bottleneck size.\n",
    "            next_width = width // 2 if width // 2 >= latent_dim else width\n",
    "            layers.extend([nn.Linear(width, next_width), nn.ReLU()])\n",
    "            width = next_width\n",
    "        layers.append(nn.Linear(width, latent_dim))\n",
    "    elif m_type == 'decoder':\n",
    "        width = latent_dim\n",
    "        for _ in range(depth):\n",
    "            layers.extend([nn.Linear(width, 2 * width), nn.ReLU()])\n",
    "            width *= 2\n",
    "        layers.append(nn.Linear(width, input_dim))\n",
    "    return nn.Sequential(*layers)\n",
    "\n",
    "def get_list_of_models(**config):\n",
    "    \"\"\"Instantiate the autoencoder variants that are compared in this notebook.\n",
    "\n",
    "    One template encoder and one template decoder are built from ``config``; every model\n",
    "    receives its own deep copy of them. Passing the same module objects to all models would\n",
    "    (assuming the wrappers keep them by reference) make the models share weights, so training\n",
    "    them one after another would leave all of them with the weights of the last one trained.\n",
    "\n",
    "    Args:\n",
    "        **config: experiment settings; forwarded unchanged to the network builder, the\n",
    "            model constructors and ``RTDLoss``.\n",
    "\n",
    "    Returns:\n",
    "        ``(models, encoder, decoder)`` where ``models`` maps a display name to a model and\n",
    "        ``encoder`` / ``decoder`` are the template networks the per-model copies were made from.\n",
    "    \"\"\"\n",
    "    import copy  # local import: only needed for the per-model copies below\n",
    "\n",
    "    # NOTE(review): get_linear_model is not defined in this notebook; presumably it is provided\n",
    "    # by `from src.utils import *` -- confirm (get_model above looks like a local equivalent).\n",
    "    encoder = get_linear_model(\n",
    "        m_type='encoder',\n",
    "        **config\n",
    "    )\n",
    "    decoder = get_linear_model(\n",
    "        m_type='decoder',\n",
    "        **config\n",
    "    )\n",
    "\n",
    "    def private_networks():\n",
    "        # Independent copies with identical initial weights for a single model.\n",
    "        return {'encoder': copy.deepcopy(encoder), 'decoder': copy.deepcopy(decoder)}\n",
    "\n",
    "    models = {\n",
    "        'Basic AutoEncoder':AutoEncoder(\n",
    "            MSELoss = nn.MSELoss(),\n",
    "            **private_networks(),\n",
    "            **config\n",
    "        ),\n",
    "        'Topological AutoEncoder':TopologicallyRegularizedAutoencoder(\n",
    "            MSELoss = nn.MSELoss(),\n",
    "            **private_networks(),\n",
    "            **config\n",
    "        ),\n",
    "        'RTD AutoEncoder H1':AutoEncoder(\n",
    "            RTDLoss = RTDLoss(dim=1, lp=1.0,  **config), # only H1\n",
    "            MSELoss = nn.MSELoss(),\n",
    "            **private_networks(),\n",
    "            **config\n",
    "        ),\n",
    "        'NSA AutoEncoder':NSAAutoEncoder(\n",
    "            NSALoss = NSALoss(),\n",
    "            MSELoss = nn.MSELoss(),\n",
    "            **private_networks(),\n",
    "            **config\n",
    "        ),\n",
    "    }\n",
    "    return models, encoder, decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81219803",
   "metadata": {},
   "outputs": [],
   "source": [
    "def collate_with_matrix(samples):\n",
    "    \"\"\"Collate ``(index, sample, label)`` triples into a batch plus its distance matrix.\n",
    "\n",
    "    Returns:\n",
    "        ``(data, x_dist, labels)`` where ``x_dist`` holds the pairwise Euclidean distances\n",
    "        between the (flattened) samples of the batch, divided by the square root of the\n",
    "        flattened sample size.\n",
    "    \"\"\"\n",
    "    _, batch, targets = zip(*samples)\n",
    "    data = torch.tensor(np.asarray(batch))\n",
    "    labels = torch.tensor(np.asarray(targets))\n",
    "    # Distances are taken between flat vectors, so collapse everything after the batch axis.\n",
    "    flat = torch.flatten(data, start_dim=1) if len(data.shape) > 2 else data\n",
    "    # Symmetrising the result, (x_dist + x_dist.T) / 2.0, to absorb cdist round-off is\n",
    "    # deliberately left disabled.\n",
    "    x_dist = torch.cdist(flat, flat, p=2) / np.sqrt(flat.shape[1])\n",
    "    return data, x_dist, labels\n",
    "\n",
    "def collate_with_matrix_geodesic(samples):\n",
    "    \"\"\"Collate ``(index, sample, label, distance_row)`` tuples into a batch.\n",
    "\n",
    "    Each ``distance_row`` is indexed with the batch's sample indices, so the returned\n",
    "    ``x_dist`` is the batch-by-batch sub-matrix of the stacked rows.\n",
    "\n",
    "    Returns:\n",
    "        ``(data, x_dist, labels)``.\n",
    "    \"\"\"\n",
    "    batch_ids, batch, targets, dist_rows = zip(*samples)\n",
    "    data = torch.tensor(np.asarray(batch))\n",
    "    labels = torch.tensor(np.asarray(targets))\n",
    "    # Keep only the columns that belong to the samples of this batch.\n",
    "    x_dist = torch.tensor(np.asarray(dist_rows)[:, batch_ids])\n",
    "    return data, x_dist, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0fb1020",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _load_npy_or_none(path):\n",
    "    # Load a .npy file, or return None when the file does not exist.\n",
    "    try:\n",
    "        return np.load(path)\n",
    "    except FileNotFoundError:\n",
    "        return None\n",
    "\n",
    "dataset_name = config['dataset_name']\n",
    "prepared_dir = f'data/{dataset_name}/prepared'\n",
    "\n",
    "# The COIL datasets are read from data.npy instead of train_data.npy. The label file is\n",
    "# picked the same way (labels.npy vs train_labels.npy) -- presumably labels.npy is the\n",
    "# companion of data.npy; confirm against the data preparation step.\n",
    "is_coil = dataset_name in ['COIL-20','COIL-100']\n",
    "train_file = 'data.npy' if is_coil else 'train_data.npy'\n",
    "label_file = 'labels.npy' if is_coil else 'train_labels.npy'\n",
    "\n",
    "train_data = np.load(f'{prepared_dir}/{train_file}').astype(np.float32)\n",
    "train_labels = _load_npy_or_none(f'{prepared_dir}/{label_file}')\n",
    "\n",
    "# Row indices of the sampled test split; stays None when test_data.npy exists.\n",
    "ids = None\n",
    "test_data = _load_npy_or_none(f'{prepared_dir}/test_data.npy')\n",
    "if test_data is None:\n",
    "    # No test file: sample 20% of the training rows without replacement.\n",
    "    # NOTE(review): these rows stay in train_data as well, and the draw is not seeded.\n",
    "    ids = np.random.choice(np.arange(len(train_data)), size=int(0.2*len(train_data)), replace=False)\n",
    "    test_data = train_data[ids]\n",
    "else:\n",
    "    test_data = test_data.astype(np.float32)\n",
    "\n",
    "if ids is not None:\n",
    "    # The test rows were sampled from train_data, so their labels must come from the same rows.\n",
    "    test_labels = None if train_labels is None else train_labels[ids]\n",
    "else:\n",
    "    # Test data came from disk; without a matching label file there is nothing to derive.\n",
    "    test_labels = _load_npy_or_none(f'{prepared_dir}/test_labels.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c76d71a2-2fab-436c-9047-a9722771b670",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "class CustomMinMaxScaler:\n",
    "    \"\"\"Min-max scaler that can start out fitted with the global min/max of an array.\n",
    "\n",
    "    ``CustomMinMaxScaler(data)`` stores ``data.min()`` / ``data.max()`` (single scalars over\n",
    "    the whole array), whereas ``fit`` stores per-column values (``axis=0``).\n",
    "\n",
    "    Calling the constructor without an argument keeps the earlier notebook behaviour of\n",
    "    using the notebook-level ``train_data``; if no such variable exists the scaler starts\n",
    "    unfitted.\n",
    "    \"\"\"\n",
    "    def __init__(self, data=None):\n",
    "        if data is None:\n",
    "            # Fall back to the notebook global so existing no-argument calls behave as before.\n",
    "            data = globals().get('train_data')\n",
    "        if data is None:\n",
    "            self.min_vals = None\n",
    "            self.max_vals = None\n",
    "            self.is_fitted = False\n",
    "        else:\n",
    "            self.min_vals = data.min()\n",
    "            self.max_vals = data.max()\n",
    "            self.is_fitted = True\n",
    "        \n",
    "    def fit(self, data):\n",
    "        \"\"\"Store the per-column minimum and maximum of ``data``.\"\"\"\n",
    "        self.min_vals = np.min(data, axis=0)\n",
    "        self.max_vals = np.max(data, axis=0)\n",
    "        self.is_fitted = True\n",
    "        \n",
    "    def transform(self, data):\n",
    "        \"\"\"Return ``(data - min) / (max - min)`` using the stored statistics.\n",
    "\n",
    "        Raises:\n",
    "            sklearn.exceptions.NotFittedError: if no statistics have been stored yet.\n",
    "        \"\"\"\n",
    "        if not self.is_fitted:\n",
    "            # Imported here because the name is not imported anywhere else in the notebook.\n",
    "            from sklearn.exceptions import NotFittedError\n",
    "            raise NotFittedError('CustomMinMaxScaler has not been fitted yet.')\n",
    "        span = np.asarray(self.max_vals - self.min_vals)\n",
    "        # A zero range (constant data/column) would give NaN or inf; divide by 1 instead.\n",
    "        # ones_like keeps the dtype so float32 inputs are not promoted.\n",
    "        span = np.where(span == 0, np.ones_like(span), span)\n",
    "        return (data - self.min_vals) / span\n",
    "    \n",
    "    def fit_transform(self, data):\n",
    "        \"\"\"Fit on ``data`` (per column) and return the scaled ``data``.\"\"\"\n",
    "        self.fit(data)\n",
    "        return self.transform(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28cc76f2",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Settings shared by the train and test dataset wrappers.\n",
    "flatten = True\n",
    "geodesic = False\n",
    "scaler = CustomMinMaxScaler()\n",
    "dataset_kwargs = dict(geodesic=geodesic, flatten=flatten, n_neighbors=2)\n",
    "\n",
    "train = FromNumpyDataset(train_data, train_labels, scaler=scaler, **dataset_kwargs)\n",
    "print(\"Train done\")\n",
    "# The test set is scaled with the scaler held by the train dataset.\n",
    "test = FromNumpyDataset(test_data, test_labels, scaler=train.scaler, **dataset_kwargs)\n",
    "\n",
    "# Both loaders batch identically; only the shuffling differs.\n",
    "batch_collate = collate_with_matrix_geodesic if geodesic else collate_with_matrix\n",
    "loader_kwargs = dict(batch_size=config[\"batch_size\"], num_workers=2, collate_fn=batch_collate)\n",
    "\n",
    "train_loader = DataLoader(train, shuffle=True, **loader_kwargs)\n",
    "val_loader = DataLoader(test, shuffle=False, **loader_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f1cdafc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_autoencoder(model, train_loader, val_loader=None, model_name='default', \n",
    "                      dataset_name='F-MNIST', gpus=[0], max_epochs=100, run=0, version=\"d1\"):\n",
    "    \"\"\"Fit ``model`` with a PyTorch Lightning trainer and return it.\n",
    "\n",
    "    Logs go to TensorBoard under ``<cwd>/lightning_logs`` with the run version\n",
    "    ``<dataset_name>_<model_name>_<version>_<run>``.\n",
    "\n",
    "    Args:\n",
    "        model: module passed to ``Trainer.fit``.\n",
    "        train_loader: training dataloader.\n",
    "        val_loader: optional validation dataloader.\n",
    "        model_name, dataset_name, version, run: only used to build the log version string.\n",
    "        gpus: forwarded to ``pl.Trainer(gpus=...)``.\n",
    "        max_epochs: forwarded to ``pl.Trainer(max_epochs=...)``.\n",
    "\n",
    "    Returns:\n",
    "        The same ``model`` object after fitting.\n",
    "    \"\"\"\n",
    "    run_tag = f\"{dataset_name}_{model_name}_{version}_{run}\"\n",
    "    logger = pl.loggers.TensorBoardLogger(save_dir=os.getcwd(), name='lightning_logs', version=run_tag)\n",
    "    # NOTE(review): confirm the installed pytorch_lightning version still accepts `gpus=`.\n",
    "    trainer = pl.Trainer(\n",
    "        logger=logger, \n",
    "        gpus=gpus, \n",
    "        max_epochs=max_epochs, \n",
    "        log_every_n_steps=1, \n",
    "        num_sanity_val_steps=0\n",
    "    )\n",
    "    trainer.fit(model, train_loader, val_loader)\n",
    "    return model\n",
    "\n",
    "def dump_figures(figures, dataset_name, version):\n",
    "    \"\"\"Save each figure to ``results/<dataset_name>/<model_name>_<version>.png``.\n",
    "\n",
    "    Args:\n",
    "        figures: mapping of model name to an object exposing ``savefig(path)``.\n",
    "        dataset_name: name of the sub-directory below ``results``.\n",
    "        version: suffix of the file names.\n",
    "    \"\"\"\n",
    "    out_dir = f'results/{dataset_name}'\n",
    "    # savefig does not create missing directories, so make sure the target exists first.\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "    for model_name, figure in figures.items():\n",
    "        figure.savefig(f'{out_dir}/{model_name}_{version}.png')\n",
    "\n",
    "def train_models(train_loader, val_loader, dataset_name=\"\", max_epochs=1, gpus=[], n_neighbors=[1], n_runs=1, version='', **kwargs):\n",
    "    \"\"\"Build all models from ``get_list_of_models(**kwargs)`` and train them in turn.\n",
    "\n",
    "    Models whose name contains ``'AutoEncoder'`` are fitted with ``train_autoencoder``\n",
    "    (always with ``run=0``); any other entry is treated as an sklearn-style estimator and\n",
    "    fitted via ``fit_transform`` on ``train_loader.dataset.data`` (the result is discarded).\n",
    "\n",
    "    ``n_neighbors`` and ``n_runs`` are accepted but not used by this function.\n",
    "\n",
    "    Returns:\n",
    "        ``(encoder, decoder, models)`` as produced by ``get_list_of_models``, with the\n",
    "        autoencoder entries of ``models`` replaced by their fitted versions.\n",
    "    \"\"\"\n",
    "    models, encoder, decoder = get_list_of_models(**kwargs)\n",
    "    \n",
    "    for model_name in tqdm(models, desc=\"Training models\"):\n",
    "        if 'AutoEncoder' in model_name: # train an autoencoder\n",
    "            # Keyword arguments keep this call correct if the callee's parameter order changes.\n",
    "            models[model_name] = train_autoencoder(\n",
    "                models[model_name], \n",
    "                train_loader, \n",
    "                val_loader, \n",
    "                model_name=model_name, \n",
    "                dataset_name=dataset_name,\n",
    "                gpus=gpus,\n",
    "                max_epochs=max_epochs,\n",
    "                run=0,\n",
    "                version=version\n",
    "            )\n",
    "        else: # umap / pca / t-sne (sklearn interface)\n",
    "            models[model_name].fit_transform(train_loader.dataset.data)\n",
    "    return encoder, decoder, models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b98c06e8",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Train every model returned by get_list_of_models on the loaders built above.\n",
    "# train_models consumes dataset_name / max_epochs / gpus / n_runs / version from config\n",
    "# and forwards the remaining keys to get_list_of_models.\n",
    "encoder, decoder, trained_models = train_models(train_loader, val_loader, **config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b4b2900-e1c7-4bf9-b5d6-118309296ffb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a91b08b8-a917-4462-a964-2fc4710ccbb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "version = config['version']\n",
    "# Unshuffled loader over the training set, used only for exporting representations.\n",
    "# It has its own name so the shuffled `train_loader` is not overwritten: re-running the\n",
    "# training cell after this one would otherwise train on an unshuffled loader.\n",
    "train_eval_loader = DataLoader(\n",
    "    train,\n",
    "    batch_size=config[\"batch_size\"],\n",
    "    num_workers=2,\n",
    "    collate_fn=collate_with_matrix_geodesic if geodesic else collate_with_matrix,\n",
    "    shuffle=False\n",
    ")\n",
    "\n",
    "# Latent codes of the training set, one pair of files per model.\n",
    "for model_name, model in trained_models.items():\n",
    "    latent, labels = get_latent_representations(model, train_eval_loader)\n",
    "    print(latent.shape)\n",
    "    np.save(f'data/{dataset_name}/{model_name}_latent_output_{version}.npy', latent)\n",
    "    np.save(f'data/{dataset_name}/{model_name}_latent_labels_{version}.npy', labels)\n",
    "\n",
    "# Results of get_output_representations for the training set, one pair of files per model.\n",
    "for model_name, model in trained_models.items():\n",
    "    outputs, labels = get_output_representations(model, train_eval_loader)\n",
    "    print(outputs.shape)\n",
    "    np.save(f'data/{dataset_name}/{model_name}_final_output_{version}.npy', outputs)\n",
    "    np.save(f'data/{dataset_name}/{model_name}_final_labels_{version}.npy', labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd8a1433",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Latent codes of the validation/test set, one pair of files per model.\n",
    "# Relies on `dataset_name` and `version` defined in earlier cells.\n",
    "for model_name, model in trained_models.items():\n",
    "    test_latent, test_targets = get_latent_representations(model, val_loader)\n",
    "    print(test_latent.shape)\n",
    "    np.save(f'data/{dataset_name}/{model_name}_latent_output_{version}_test.npy', test_latent)\n",
    "    np.save(f'data/{dataset_name}/{model_name}_latent_labels_{version}_test.npy', test_targets)\n",
    "\n",
    "# Results of get_output_representations for the same loader, saved under the *_final_* names.\n",
    "for model_name, model in trained_models.items():\n",
    "    test_output, test_targets = get_output_representations(model, val_loader)\n",
    "    print(test_output.shape)\n",
    "    np.save(f'data/{dataset_name}/{model_name}_final_output_{version}_test.npy', test_output)\n",
    "    np.save(f'data/{dataset_name}/{model_name}_final_labels_{version}_test.npy', test_targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79754037-2ca5-4ba6-9ad1-ac0125f42ba0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
