{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "from datasets import Dataset\n",
    "from typing import Dict\n",
    "from tqdm import tqdm\n",
    "from datasets.dataset_dict import DatasetDict\n",
    "import hydra\n",
    "from nn_core.common import PROJECT_ROOT\n",
    "from nn_core.common import PROJECT_ROOT\n",
    "from tqdm import tqdm\n",
    "from latent_invariances.aes.aes_model_ids import MODEL_DATASET_RUN_ID\n",
    "from functools import partial\n",
    "from modelzoo.pl_modules.aes.pl_autoencoder import LightningAutoencoder\n",
    "from modelzoo.data.vision.datamodule import collate_fn\n",
    "from omegaconf import OmegaConf\n",
    "import pandas as pd\n",
    "\n",
    "from modelzoo import MODELZOO_ROOT"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Select models to consider"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Architectures considered in this notebook.\n",
    "MODEL_NAMES = [\"ae\", \"vae\", \"linearized_ae\", \"linearized_vae\"]\n",
    "\n",
    "# Local tab-separated index of the selected runs; its folder is created on demand.\n",
    "MODEL_ZOO_DIR = PROJECT_ROOT / \"data\" / \"model_zoo\"\n",
    "MODEL_ZOO_DIR.mkdir(exist_ok=True, parents=True)\n",
    "LOCAL_MODELS_INDEX = MODEL_ZOO_DIR / \"absolute_models.tsv\"\n",
    "\n",
    "## One-off recipe that builds the local index from MODELZOO_ROOT (kept for reference)\n",
    "# model_index = MODELZOO_ROOT / 'index.csv'\n",
    "# df = pd.read_csv(model_index, sep='\\t')\n",
    "# df['model'] = df['tags'].map(lambda x: x.split(',')[1])\n",
    "# df = df[df.model.isin(MODEL_NAMES)]\n",
    "# df.to_csv(PROJECT_ROOT/'data'/'model_zoo' / \"absolute_models.tsv\", sep='\\t', index=False)\n",
    "\n",
    "# Read the index of selected models and display it\n",
    "df = pd.read_csv(LOCAL_MODELS_INDEX, sep=\"\\t\")\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distinct dataset names present in the models index\n",
    "set(df[\"dataset\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Embed data for each dataset and each model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def embed_samples(loader, model) -> list:\n",
    "    \"\"\"Encode every batch of `loader` with `model` on the GPU.\n",
    "\n",
    "    Args:\n",
    "        loader: iterable of batches; each batch is indexed with \"x\" to get the model input.\n",
    "        model: module exposing `encode`; the encoder output is indexed with \"batch_latent\".\n",
    "\n",
    "    Returns:\n",
    "        A flat list of NumPy arrays, one per entry along the first axis of each\n",
    "        batch of latents, in the loader's iteration order.\n",
    "    \"\"\"\n",
    "    import torch  # local import: the import cell only pulls in torch.utils.data\n",
    "\n",
    "    model.cuda().eval()\n",
    "\n",
    "    embeddings = []\n",
    "    # Inference only: without no_grad each forward pass would record an autograd\n",
    "    # graph, wasting GPU memory for results that are detached right away.\n",
    "    with torch.no_grad():\n",
    "        for batch in tqdm(loader, desc=\"Embedding samples\"):\n",
    "            x = batch[\"x\"].to(\"cuda\")\n",
    "            latents = model.encode(x)[\"batch_latent\"]\n",
    "            embeddings.extend(latents.detach().cpu().numpy())\n",
    "\n",
    "    # Move the model off the GPU once its embeddings are collected.\n",
    "    model.cpu()\n",
    "    return embeddings\n",
    "\n",
    "\n",
    "def extract_samples_id(loader) -> tuple:\n",
    "    \"\"\"Collect the labels and sample ids of every batch in `loader`.\n",
    "\n",
    "    Args:\n",
    "        loader: iterable of batches; each batch is indexed with \"id\" and \"y\".\n",
    "\n",
    "    Returns:\n",
    "        A `(labels, ids)` pair of flat lists -- note that labels come first.\n",
    "    \"\"\"\n",
    "    ids = []\n",
    "    labels = []\n",
    "\n",
    "    for batch in tqdm(loader, desc=\"Extract ids\"):\n",
    "        ids.extend(batch[\"id\"].cpu().numpy())\n",
    "        labels.extend(batch[\"y\"].cpu().numpy())\n",
    "\n",
    "    return labels, ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from modelzoo.utils.io_model import load_local_ckpt\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "\n",
    "def _make_loader(dataset, split, datamodule):\n",
    "    \"\"\"Build an unshuffled DataLoader over `dataset`, collated as the given `split`.\n",
    "\n",
    "    Shuffling is off so that ids, targets and embeddings gathered in separate\n",
    "    passes over the loader line up row by row.\n",
    "    \"\"\"\n",
    "    return DataLoader(\n",
    "        dataset,\n",
    "        batch_size=32,\n",
    "        pin_memory=True,\n",
    "        shuffle=False,\n",
    "        num_workers=4,\n",
    "        collate_fn=partial(\n",
    "            collate_fn, split=split, metadata=datamodule.metadata, transform=datamodule.transform_batch\n",
    "        ),\n",
    "    )\n",
    "\n",
    "\n",
    "# must do one at a time or we go OOM\n",
    "STORING_MODELS = [\"ae\", \"vae\", \"linearized_ae\", \"linearized_vae\"]\n",
    "\n",
    "EMBEDDINGS_DIR = PROJECT_ROOT / \"data\" / \"model_zoo\" / \"embeddings\"\n",
    "EMBEDDINGS_DIR.mkdir(exist_ok=True, parents=True)\n",
    "\n",
    "for STORING_MODEL in STORING_MODELS:\n",
    "    assert STORING_MODEL in set(df[\"model\"]), f\"No '{STORING_MODEL}' run listed in the models index\"\n",
    "\n",
    "    # Keep only the runs of the current model, so the log below lists exactly\n",
    "    # the checkpoints that are actually embedded in this pass.\n",
    "    model_runs = df.loc[df[\"model\"] == STORING_MODEL, [\"wandb_id\", \"dataset\", \"model\", \"seed_index\"]]\n",
    "\n",
    "    # dataset name -> DatasetDict holding ids, targets and one column per seed\n",
    "    embeddings = {}\n",
    "    for _, (run_id, dataset_name, model_name, seed_name) in model_runs.iterrows():\n",
    "        print(f\"Model: {model_name}, Dataset: {dataset_name}, Seed: {seed_name}, Run ID: {run_id}\")\n",
    "\n",
    "        # GET RUN DIR\n",
    "        filepath = MODELZOO_ROOT / \"checkpoints\" / f\"{run_id}.ckpt.zip\"\n",
    "\n",
    "        # INSTANTIATE MODEL AND DATAMODULE\n",
    "        model, ckpt = load_local_ckpt(filepath, strict=False)\n",
    "        cfg = OmegaConf.create(ckpt[\"cfg\"])\n",
    "\n",
    "        datamodule = hydra.utils.instantiate(OmegaConf.to_container(cfg.nn.data), _recursive_=False)\n",
    "        datamodule.setup(stage=\"fit\")\n",
    "\n",
    "        train_loader = _make_loader(datamodule.train_dataset, \"train\", datamodule)\n",
    "        val_loader = _make_loader(datamodule.val_datasets[0], \"val\", datamodule)\n",
    "\n",
    "        column_name = f\"{model_name}_{seed_name}\"\n",
    "\n",
    "        # Ids and targets are extracted once per dataset and shared by all seeds;\n",
    "        # this assumes every run of a dataset iterates its samples in the same order.\n",
    "        if dataset_name not in embeddings:\n",
    "            train_targets, train_ids = extract_samples_id(train_loader)\n",
    "            val_targets, val_ids = extract_samples_id(val_loader)\n",
    "\n",
    "            # The validation split is stored under the \"test\" key.\n",
    "            embeddings[dataset_name] = DatasetDict(\n",
    "                {\n",
    "                    \"train\": Dataset.from_dict({\"id\": train_ids, \"target\": train_targets}),\n",
    "                    \"test\": Dataset.from_dict({\"id\": val_ids, \"target\": val_targets}),\n",
    "                }\n",
    "            )\n",
    "\n",
    "        train_embeds = embed_samples(train_loader, model)\n",
    "        val_embeds = embed_samples(val_loader, model)\n",
    "\n",
    "        # Attach the embeddings as a new column: row i receives embeds[i].\n",
    "        embeddings[dataset_name][\"train\"] = embeddings[dataset_name][\"train\"].map(\n",
    "            function=lambda _, i: {column_name: train_embeds[i]}, with_indices=True, desc=\"Embedding store train\"\n",
    "        )\n",
    "        embeddings[dataset_name][\"test\"] = embeddings[dataset_name][\"test\"].map(\n",
    "            function=lambda _, i: {column_name: val_embeds[i]}, with_indices=True, desc=\"Embedding store test\"\n",
    "        )\n",
    "\n",
    "    # Persist one DatasetDict per (dataset, model) pair before moving to the next model.\n",
    "    for dataset_name, dataset_embeds in embeddings.items():\n",
    "        dataset_embeds_dir = EMBEDDINGS_DIR / dataset_name / STORING_MODEL\n",
    "        dataset_embeds_dir.mkdir(exist_ok=True, parents=True)\n",
    "\n",
    "        dataset_embeds.save_to_disk(str(dataset_embeds_dir))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Concatenate datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "STORING_MODELS = [\"ae\", \"vae\", \"linearized_ae\", \"linearized_vae\"]\n",
    "DATASETS = [\"mnist\", \"fashion_mnist\", \"cifar10\", \"cifar100\"]\n",
    "\n",
    "EMBEDDINGS_DIR = PROJECT_ROOT / \"data\" / \"model_zoo\" / \"embeddings\"\n",
    "\n",
    "\n",
    "# dataset name -> list with one DatasetDict per model, in STORING_MODELS order\n",
    "datas = defaultdict(list)\n",
    "\n",
    "# Load datasets\n",
    "# Iterate over the STORING_MODELS declared in this cell (the same list used when\n",
    "# storing) rather than a constant defined in a much earlier cell.\n",
    "for model_name in STORING_MODELS:\n",
    "    for dataset_name in DATASETS:\n",
    "        dataset_embeds_dir = EMBEDDINGS_DIR / dataset_name / model_name\n",
    "        datas[dataset_name].append(DatasetDict.load_from_disk(str(dataset_embeds_dir)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split names stored in the first DatasetDict loaded for MNIST\n",
    "mnist_first = datas[\"mnist\"][0]\n",
    "mnist_first.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "\n",
    "# Merge, for every dataset, the per-model DatasetDicts column-wise:\n",
    "# \"id\" and \"target\" are taken once from the first one, the remaining\n",
    "# columns are taken from every one of them.\n",
    "datadicts = {}\n",
    "for dataset_name, per_model_dicts in datas.items():\n",
    "    reference = per_model_dicts[0]\n",
    "\n",
    "    merged_splits = {}\n",
    "    for split in reference.keys():\n",
    "        key_columns = reference[split].select_columns([\"id\", \"target\"])\n",
    "        embedding_columns = [\n",
    "            model_dict[split].remove_columns([\"id\", \"target\"]) for model_dict in per_model_dicts\n",
    "        ]\n",
    "        merged_splits[split] = datasets.concatenate_datasets([key_columns] + embedding_columns, axis=1)\n",
    "\n",
    "    datadicts[dataset_name] = DatasetDict(merged_splits)\n",
    "datadicts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write each merged DatasetDict to its own folder under embeddings_all\n",
    "EMBEDDINGS_DIR = PROJECT_ROOT / \"data\" / \"model_zoo\" / \"embeddings_all\"\n",
    "\n",
    "for dataset_name in datadicts:\n",
    "    target_dir = EMBEDDINGS_DIR / dataset_name\n",
    "    target_dir.mkdir(exist_ok=True, parents=True)\n",
    "\n",
    "    datadicts[dataset_name].save_to_disk(str(target_dir))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "latent-invariances",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
