{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "from src.model.components.clip_text_autoencoder import CLIPAutoEncoder\n",
    "from src.dataset.text import SimpleTokenizer\n",
    "from pathlib import Path\n",
    "import random\n",
    "from src.utils.datatool import read_jsonlines\n",
    "import torch\n",
    "\n",
    "from lmdb_embeddings.reader import LmdbEmbeddingsReader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the distinct annotation labels found across all samples, in sorted order.\n",
    "samples = read_jsonlines(\"/data/vtt/meta/vtt_all.jsonl\")\n",
    "descriptions = sorted(\n",
    "    {\n",
    "        annotation[\"label\"]\n",
    "        for sample in samples\n",
    "        for annotation in sample[\"annotation\"]\n",
    "    }\n",
    ")\n",
    "print(f\"Total descriptions: {len(descriptions)}\")\n",
    "print(descriptions[:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decoder checkpoints to compare, keyed by a short model name.\n",
    "# Disabled entries are kept commented out for reference.\n",
    "ckpt_path_list = {\n",
    "    # \"vitb16_4_50e\": Path(\"/log/exp/vtt/WikiHowDataModule.CLIPAutoEncoder.GenerationLoss.2023-01-14_00-11-16/checkpoints/last.ckpt\"),\n",
    "    # \"vitb16_4_300e\": Path(\"/log/exp/vtt/WikiHowDataModule.CLIPAutoEncoder.GenerationLoss.2023-01-20_12-35-24/checkpoints/last.ckpt\"),\n",
    "    \"vitb16_4\": Path(\"/log/exp/vtt/WikiHowDataModule.CLIPAutoEncoder.GenerationLoss.2023-01-20_12-35-27/checkpoints/last.ckpt\"),\n",
    "    \"vitb32_4\": Path(\"/log/exp/vtt/WikiHowDataModule.CLIPAutoEncoder.GenerationLoss.2023-01-20_12-35-26/checkpoints/last.ckpt\"),\n",
    "    \"vitl14_4\": Path(\"/log/exp/vtt/WikiHowDataModule.CLIPAutoEncoder.GenerationLoss.2023-01-20_12-35-28/checkpoints/last.ckpt\"),\n",
    "}\n",
    "\n",
    "# Build one autoencoder per checkpoint, then call eval() and cuda() on it.\n",
    "models = {}\n",
    "for name, ckpt_path in ckpt_path_list.items():\n",
    "    autoencoder = CLIPAutoEncoder(from_decoder_ckpt=ckpt_path)\n",
    "    models[name] = autoencoder\n",
    "    autoencoder.eval()\n",
    "    autoencoder.cuda()\n",
    "\n",
    "tokenizer = SimpleTokenizer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Qualitative round-trip check: pick random descriptions, encode each one and\n",
    "# decode it back with every loaded model, printing the results side by side.\n",
    "NUM_SAMPLES = 5  # how many descriptions to draw (with replacement)\n",
    "SEED = None  # set to an int (e.g. 0) to make the selection reproducible\n",
    "\n",
    "if SEED is not None:\n",
    "    random.seed(SEED)\n",
    "\n",
    "# Inference only: disable autograd tracking while encoding/decoding.\n",
    "with torch.no_grad():\n",
    "    for i in range(NUM_SAMPLES):\n",
    "        text = random.choice(descriptions)\n",
    "        print(f\"Text: {text}\")\n",
    "        print(\"De-text: \")\n",
    "        for name, model in models.items():\n",
    "            embedding = model.encode_raw(text)\n",
    "            de_text = model.decode_raw(embedding)\n",
    "            print(f\"  {name}: {de_text}\")\n",
    "        print()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "949777d72b0d2535278d3dc13498b2535136f6dfe0678499012e853ee9abcab1"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
