{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82edbb2f-4436-44a6-b4e5-8edddb69c570",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28d437f3-18f8-4c9e-9c20-15166f24b893",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "from src.dataset.vtt import VTTDataset, VTTSingleDataset\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6cd9f44-7b6b-4a6d-8afc-814bb3135816",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = VTTDataset(\n",
    "    meta_path=\"meta/vtt_integrated.jsonl\",\n",
    "    load_trans_frames=True,\n",
    "    transform_cfg=dict(resize=True, random_crop=True, random_hflip=True),\n",
    "    prefix_start=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ce7feaf-e74c-4a2a-9afd-abdf6603a2d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.arange(24).reshape(2,3,4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "064c43c3-7aa5-4c52-a5c0-21f093a5b0b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "z = torch.tensor([0,1,2,3,0,1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2194683c-2943-41b3-b997-51dd10bfcfb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick x[i, j, z[i, j]] for every (i, j).\n",
    "# Advanced indexing needs index tensors that broadcast to one common shape;\n",
    "# index lengths 2, 3 and 6 do not, so z is viewed as (rows, cols) and the row\n",
    "# index gets a trailing axis to broadcast against the column index.\n",
    "n_rows, n_cols, _ = x.shape\n",
    "x[torch.arange(n_rows)[:, None], torch.arange(n_cols), z.reshape(n_rows, n_cols)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d395973-35b7-4df3-b4db-3e01823d9387",
   "metadata": {},
   "outputs": [],
   "source": [
    "res = dataset[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ed34704-2816-446c-bd53-a5ef1eb3032a",
   "metadata": {},
   "outputs": [],
   "source": [
    "res[\"states_mask\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e967cf87-9d8b-4fc8-bc9a-afd10fff61b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every label row, read the id at position (mask sum - 1).\n",
    "# Presumably label_mask is a 0/1 mask, making this the last unmasked token - confirm.\n",
    "# The row count comes from the tensor itself instead of a hard-coded 12.\n",
    "last_token_idx = torch.sum(res[\"label_mask\"], dim=-1) - 1\n",
    "res[\"label_ids\"][torch.arange(res[\"label_ids\"].shape[0]), last_token_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a776332-4142-4aea-bf5c-759886dcb444",
   "metadata": {},
   "outputs": [],
   "source": [
    "res[\"label_ids\"].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ab5b753-c558-40cd-8b56-70811d0a4cf6",
   "metadata": {},
   "outputs": [],
   "source": [
    "res[\"label_ids\"][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8101b5c7-cd05-46d1-8200-0c9425bb50ea",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "res[\"label_mask\"][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba658e07-9fe5-4121-8afc-0ce0a8b98597",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset.tokenizer.decode(res[\"label_ids\"][0].tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee2197cb-ce23-4b62-8eda-4a8b9e54e806",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset.transform.to_pil(res[\"states\"][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4087a4d8-53cc-4ae8-ba40-8f26ca89ac5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "res[\"states_mask\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f544072e-a4b0-4ef5-995b-a67e8ba93add",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset.transform.to_pil(res[\"states\"][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bf7ba2a-eb2f-4c92-a710-ab2ae0afbb7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset.transform.to_pil(res[\"trans\"][0][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23b520f1-56f0-430e-a55e-a1cc8eba2c17",
   "metadata": {},
   "outputs": [],
   "source": [
    "res[\"trans_mask\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d30647ca-25f1-4b0d-8587-f12d4cad5a24",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = VTTDataset(\n",
    "    transform_cfg=dict(resize=True),\n",
    "    prefix_start=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38cd3cb1-e97e-4afa-97ff-d3117ba3c4ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset[0]['states'][2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97d093e5-2d3e-4eaa-82aa-1558fff031b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = VTTSingleDataset(\n",
    "    transform_cfg=dict(resize=True),\n",
    "    prefix_start=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27f59b9e-b8a3-4a1b-b889-a80534123a24",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset[1]['states'][2]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f8fadc1-8bbf-4817-8138-65e0d32bab29",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Tokenizers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdae8522-be3a-412e-a528-aa1302719a34",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.dataset.text import SimpleTokenizer\n",
    "tokenizer = SimpleTokenizer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8579b908-540f-4b35-a10e-0311ed29cb99",
   "metadata": {},
   "outputs": [],
   "source": [
    "res = tokenizer.encode(\"And everyone wants to help with the new addition to the house.\")\n",
    "res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59b0eec0-cc4b-4edf-a4c9-52bf31eb266e",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer.decode(res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d30dda0-000b-455b-89ed-a99263894b51",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer.encode(\"<|endoftext|>\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0caa21c0-5dfa-45c7-97b9-dc2f57d3d2eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import GPT2TokenizerFast\n",
    "tokenizer = GPT2TokenizerFast.from_pretrained(\"gpt2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9ee2434-81ab-48c5-9dc3-7ab18159e1ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03757d9f-d2c8-41f5-bcb2-3dabf14a760a",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer.encode(\"<-start_of_text->\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "914f5455-98f2-4dba-9687-27918393e201",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer.encode(\"And everyone wants to help with the new addition to the house.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18231551-cc3e-4cdb-bd5d-2a2030140731",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.dataset.vtt import VTTDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b853a8a1-6cbf-4421-9207-a90923146747",
   "metadata": {},
   "outputs": [],
   "source": [
    "vtt = VTTDataset(return_raw_text=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec72fae0-a8b8-4027-86f0-1272ee5e4de6",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = vtt[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8734f0c0-0093-492a-a2f8-74b939740c38",
   "metadata": {},
   "outputs": [],
   "source": [
    "x[\"text\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42bc0639-6853-4529-892f-9f289e75f918",
   "metadata": {},
   "outputs": [],
   "source": [
    "x[\"label_ids\"].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a79700dc-d630-4cf4-bfcf-652f852e571d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97b570d9-8690-4b2a-a2bc-f7281d35ca0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.model.components.text_encoder import TextCLIP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df2f314f-4f01-4a3c-8c70-d0471d9f4287",
   "metadata": {},
   "outputs": [],
   "source": [
    "text_clip = TextCLIP(\"ViT-L/14\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b81578f4-5443-4dea-bcd2-04e5b41b4e35",
   "metadata": {},
   "outputs": [],
   "source": [
    "text_clip(x[\"label_ids\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "715dac29-8a88-4ae6-a1d2-575d2195c94d",
   "metadata": {},
   "outputs": [],
   "source": [
    "x[\"label_mask\"].sum(dim=-1) > 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8ec5554-34e5-4f7b-afe5-2d5e215657d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.model.components.clip import clip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eee11315-9006-41a8-af43-4241a2d4afb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "model, _ = clip.load(\"ViT-L/14\", device=\"cpu\", download_root=\"/data/pretrain/clip\")\n",
    "text = clip.tokenize(x[\"text\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e865d9ca-32cb-4410-a4fa-d4f574f0f67d",
   "metadata": {},
   "outputs": [],
   "source": [
    "text.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25cd302e-3715-4e57-a7b8-a5d7aa2a6ba2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define `modified_text` here so this cell runs on a fresh kernel (Restart & Run All).\n",
    "# Every position where `text` is 0 is replaced by 49407\n",
    "# (presumably CLIP's <|endoftext|> id, used as padding by the dataset - confirm).\n",
    "modified_text = text + (text == 0) * 49407\n",
    "res_text_1 = text_clip(modified_text)\n",
    "res_text_1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3dda062e-3527-4faa-b6da-3ec770afce76",
   "metadata": {},
   "outputs": [],
   "source": [
    "res_text = model.encode_text(text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "beed7fbd-f7b2-4851-8c9a-515c770cf2c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "res_text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "611eeab9-a034-43bf-8e66-ab3e1164228a",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.mean(res_text_1 * res_text, dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ce6b700-17e9-4978-b75c-1792718f4f3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "modified_text = text + (text == 0) * 49407\n",
    "modified_text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "595fa2a0-ccf0-4412-a86b-ee4bb419b36c",
   "metadata": {},
   "outputs": [],
   "source": [
    "res_text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cc6471b-410d-4fb6-b71d-c1b37c902b11",
   "metadata": {},
   "outputs": [],
   "source": [
    "res_text.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "209e0956-7ecb-42a0-9eca-8c44aa57fa5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eee9ed17-0048-4aca-8f76-da6d5088695a",
   "metadata": {},
   "source": [
    "## Compute MAX WORDS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ef98c74-63c0-4317-8dff-58c6036d3c6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.dataset.text import SimpleTokenizer\n",
    "tokenizer = SimpleTokenizer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4abea5d0-53dd-45a0-9258-f6070eb9c286",
   "metadata": {},
   "outputs": [],
   "source": [
    "import src.utils.datatool as dtool"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "629afd93-25f9-4c8a-b56b-67bd1a95a0b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = dtool.read_jsonlines(\"/data/vtt/meta/vtt_integrated.jsonl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ce547cb-ceac-43f3-86c3-24c053236a86",
   "metadata": {},
   "outputs": [],
   "source": [
    "# tqdm.auto picks the widget bar inside Jupyter and falls back to the plain-text bar\n",
    "# when ipywidgets is unavailable, so the cell also works outside a widget-enabled kernel.\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "# Encoded length of every annotation label across all samples.\n",
    "n_words = []\n",
    "for sample in tqdm(data, desc=\"tokenizing labels\"):\n",
    "    n_words.extend(len(tokenizer.encode(ann['label'])) for ann in sample['annotation'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2999e3ac-7afe-4d3f-ab0e-32062f3c3b69",
   "metadata": {},
   "outputs": [],
   "source": [
    "max(n_words)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b96acbc-bad5-40a1-a446-c14b02b3abc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ac19bd4-b5a2-4b7d-9a40-4f34a98060f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Histogram of encoded label lengths; hist returns (counts, bin_edges, patches), not an axis.\n",
    "fig, ax = plt.subplots()\n",
    "counts, bin_edges, _ = ax.hist(n_words, bins=21)\n",
    "ax.set(title=\"Encoded label length\", xlabel=\"tokens per label\", ylabel=\"number of labels\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9d9f34b-f852-47f3-9c30-73dff2155c95",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
