{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a89e314d-6ad7-40f3-a3dc-c34fff9c8c49",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7525fa3-eed3-4335-be3a-0ea0e99c68e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jsonlines\n",
    "from collections import defaultdict\n",
    "import json\n",
    "from pathlib import Path\n",
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.font_manager\n",
    "import matplotlib.image as mpimg\n",
    "from tqdm.notebook import tqdm\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "478d5b54-fac2-4751-b59e-0330435081b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.font_manager\n",
    "import matplotlib.image as mpimg\n",
    "\n",
    "# List the installed fonts so that a missing 'Helvetica' (selected below) is easy to spot.\n",
    "print(f\"available fonts: {sorted([f.name for f in matplotlib.font_manager.fontManager.ttflist])}\")\n",
    "\n",
    "# matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and newer\n",
    "# releases no longer accept the old names, so pick whichever name is installed.\n",
    "muted_style = next(\n",
    "    (name for name in ('seaborn-v0_8-muted', 'seaborn-muted') if name in plt.style.available),\n",
    "    'default',\n",
    ")\n",
    "plt.style.use(muted_style)\n",
    "\n",
    "plt.rcParams.update({\n",
    "    # figure output\n",
    "    'figure.dpi': 150,\n",
    "    'savefig.dpi': 300,\n",
    "    'savefig.format': 'pdf',\n",
    "    'savefig.bbox': 'tight',\n",
    "    'savefig.pad_inches': 0.1,\n",
    "    # fonts\n",
    "    'figure.titlesize': 18,\n",
    "    'axes.titlesize': 18,\n",
    "    'font.family': 'Helvetica',\n",
    "    'font.size': 18,\n",
    "    # lines, axes, ticks and legend\n",
    "    'lines.linewidth': 2,\n",
    "    'lines.marker': '',\n",
    "    'axes.labelsize': 16,\n",
    "    'axes.labelweight': 'bold',\n",
    "    'axes.linewidth': 2,\n",
    "    'axes.titlepad': 6,\n",
    "    'xtick.labelsize': 16,\n",
    "    'ytick.labelsize': 16,\n",
    "    'legend.fontsize': 16,\n",
    "    'legend.frameon': False,\n",
    "    # math text\n",
    "    'mathtext.fontset': 'dejavuserif',\n",
    "    'mathtext.it': 'serif:italic',\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e469ea93-e031-4883-8b3d-b7345bc7a1fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def list2count(_list):\n",
    "    \"\"\"Count how often each value occurs in ``_list``.\n",
    "\n",
    "    Returns a plain dict mapping value -> count whose keys are inserted in\n",
    "    ascending order.\n",
    "    \"\"\"\n",
    "    tally = defaultdict(int)\n",
    "    for item in _list:\n",
    "        tally[item] += 1\n",
    "    return dict(sorted(tally.items()))\n",
    "\n",
    "\n",
    "def show_statistics(data):\n",
    "    \"\"\"Print headline counts for ``data`` and draw a two-panel overview figure.\n",
    "\n",
    "    Each sample in ``data`` is read for its \"ori\" value and its \"annotation\"\n",
    "    list; every annotation entry must provide a \"segment\" with at least two\n",
    "    elements, and element 1 minus element 0 is summed per sample.\n",
    "\n",
    "    Returns the created figure. It also stays the current pyplot figure, so a\n",
    "    following ``plt.savefig(...)`` keeps working.\n",
    "    \"\"\"\n",
    "    print(f\"Total samples: {len(data)}\")\n",
    "\n",
    "    source_count = defaultdict(int)\n",
    "    for sample in data:\n",
    "        source_count[sample[\"ori\"]] += 1\n",
    "    print(f\"Source count: {json.dumps(source_count, indent=2)}\")\n",
    "\n",
    "    # One entry per sample: number of annotation entries, and summed segment length.\n",
    "    steps = [len(sample[\"annotation\"]) for sample in data]\n",
    "    clip_durations = [\n",
    "        sum(x['segment'][1] - x['segment'][0] for x in sample[\"annotation\"])\n",
    "        for sample in data\n",
    "    ]\n",
    "    steps_count = list2count(steps)\n",
    "    # \"states\" is computed as one more than the step count of every sample.\n",
    "    print(f\"Total steps: {sum(steps)}, total states: {sum(steps) + len(steps)}\")\n",
    "\n",
    "    def steps_count_for(source):\n",
    "        # Step-count histogram restricted to samples whose \"ori\" equals `source`.\n",
    "        return list2count([n for n, sample in zip(steps, data) if sample[\"ori\"] == source])\n",
    "\n",
    "    def plot_step_count(ax):\n",
    "        # Optional per-source line plot; it is not part of the figure drawn below.\n",
    "        ax.plot(list(steps_count.keys()), list(steps_count.values()), \"-\", label=\"total\")\n",
    "        for source, line_style in ((\"cross\", \"--\"), (\"coin\", \"-.\")):\n",
    "            per_source = steps_count_for(source)\n",
    "            ax.plot(list(per_source.keys()), list(per_source.values()), line_style, label=source)\n",
    "\n",
    "        ax.set_title(\"Step Count\", fontstyle=\"italic\")\n",
    "        ax.set_xlabel(\"steps\")\n",
    "        ax.set_ylabel(\"number of samples\")\n",
    "        ax.legend(loc='best', numpoints=1, fancybox=False)\n",
    "\n",
    "    def plot_step_dist(ax):\n",
    "        positions = np.arange(len(steps_count))\n",
    "        ax.bar(positions, list(steps_count.values()), 0.9)\n",
    "        # Ticks and labels are set separately: passing labels to set_xticks\n",
    "        # is only supported from matplotlib 3.5 on.\n",
    "        ax.set_xticks(positions)\n",
    "        ax.set_xticklabels([str(n_step) for n_step in steps_count])\n",
    "\n",
    "        ax.set_title(\"Step Distribution\", fontstyle=\"italic\")\n",
    "        ax.set_xlabel(\"steps\")\n",
    "        ax.set_ylabel(\"number of samples\")\n",
    "\n",
    "    def plot_duration_dist(ax):\n",
    "        ax.hist(clip_durations, bins=100)\n",
    "\n",
    "        ax.set_title(\"Duration Distribution\", fontstyle=\"italic\")\n",
    "        ax.set_xlabel(\"duration\")\n",
    "        ax.set_ylabel(\"number of samples\")\n",
    "\n",
    "    width, height = plt.figaspect(0.3)\n",
    "    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(width, height))\n",
    "    plot_step_dist(ax1)\n",
    "    plot_duration_dist(ax2)\n",
    "\n",
    "    # Leave room above the axes for the figure title.\n",
    "    plt.subplots_adjust(top=0.8, hspace=None, wspace=None)\n",
    "    fig.suptitle(\"Statistics of VTT dataset\")\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f84fc155-6f33-421b-b9ca-a36e664090d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read every record of the metadata file, print/plot the overview and save the figure.\n",
    "vtt_meta_path = \"/data/vtt/meta/vtt.jsonl\"\n",
    "with jsonlines.open(vtt_meta_path) as reader:\n",
    "    data = [record for record in reader]\n",
    "show_statistics(data)\n",
    "plt.savefig(\"statistics-all.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f288da3-5cac-4f50-a722-e642ae3d58a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "with jsonlines.open(\"/data/vtt/meta/vtt.jsonl\") as reader:\n",
    "    data = list(reader)\n",
    "\n",
    "# Per-sample lists, all in the order of `data`.\n",
    "steps = [len(sample[\"annotation\"]) for sample in data]\n",
    "durations = [sample[\"duration\"] for sample in data]\n",
    "# Sum of (segment[1] - segment[0]) over the annotation entries of each sample.\n",
    "clip_durations = [\n",
    "    sum([entry['segment'][1] - entry['segment'][0] for entry in sample[\"annotation\"]])\n",
    "    for sample in data\n",
    "]\n",
    "steps_count = list2count(steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5daae9d-0c22-4461-bcb5-ebc9aa30bb1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "font_size = 16\n",
    "plt.rcParams.update({\n",
    "    'figure.dpi': 200,\n",
    "    'savefig.dpi': 300,\n",
    "    'axes.labelsize': font_size + 2,\n",
    "    'axes.labelweight': 'normal',\n",
    "    'legend.fontsize': font_size,\n",
    "    'xtick.labelsize': font_size,\n",
    "    'ytick.labelsize': font_size,\n",
    "    'axes.linewidth': 1.5,\n",
    "})\n",
    "\n",
    "# One bar per distinct step count, with a tick under every bar.\n",
    "step_values = list(steps_count.keys())\n",
    "plt.bar(step_values, list(steps_count.values()))\n",
    "plt.xticks(step_values)\n",
    "plt.xlabel(\"transformations\")\n",
    "plt.ylabel(\"#Samples\")\n",
    "\n",
    "plt.savefig(\"steps_dist.pdf\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97a982f1-2751-4726-b8d1-1c98682dcf0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "font_size = 16\n",
    "plt.rcParams.update({\n",
    "    'figure.dpi': 200,\n",
    "    'savefig.dpi': 300,\n",
    "    'axes.labelsize': font_size + 2,\n",
    "    'axes.labelweight': 'normal',\n",
    "    'legend.fontsize': font_size,\n",
    "    'xtick.labelsize': font_size,\n",
    "    'ytick.labelsize': font_size,\n",
    "    'axes.linewidth': 1.5,\n",
    "})\n",
    "\n",
    "# Histogram of the per-sample \"duration\" values; `bins` receives the bin edges.\n",
    "_, bins, _ = plt.hist(\n",
    "    durations,\n",
    "    bins=100,\n",
    "    density=False,\n",
    "    edgecolor='w',\n",
    "    linewidth=0.5,\n",
    ")\n",
    "\n",
    "# Disabled: overlay of a fitted normal curve (needs `import scipy.stats`).\n",
    "# mu, sigma = scipy.stats.norm.fit(durations)\n",
    "# best_fit_line = scipy.stats.norm.pdf(bins, mu, sigma)\n",
    "# plt.plot(bins, best_fit_line, 'k--', alpha=0.5)\n",
    "\n",
    "plt.xlabel(\"duration\")\n",
    "plt.ylabel(\"#Samples\")\n",
    "\n",
    "plt.savefig(\"duration_dist.pdf\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0018cbd8-c77e-47e6-955e-272d8854ede6",
   "metadata": {},
   "outputs": [],
   "source": [
    "font_size = 16\n",
    "plt.rcParams.update({\n",
    "    'figure.dpi': 200,\n",
    "    'savefig.dpi': 300,\n",
    "    'axes.labelsize': font_size + 2,\n",
    "    'axes.labelweight': 'normal',\n",
    "    'legend.fontsize': font_size,\n",
    "    'xtick.labelsize': font_size,\n",
    "    'ytick.labelsize': font_size,\n",
    "    'axes.linewidth': 1.5,\n",
    "})\n",
    "\n",
    "# Histogram of the summed segment lengths per sample; `bins` receives the bin edges.\n",
    "_, bins, _ = plt.hist(\n",
    "    clip_durations,\n",
    "    bins=100,\n",
    "    density=False,\n",
    "    edgecolor='w',\n",
    "    linewidth=0.5,\n",
    ")\n",
    "\n",
    "# Disabled: overlay of a fitted normal curve (needs `import scipy.stats`).\n",
    "# NOTE(review): the disabled fit below uses `durations`, not `clip_durations`.\n",
    "# mu, sigma = scipy.stats.norm.fit(durations)\n",
    "# best_fit_line = scipy.stats.norm.pdf(bins, mu, sigma)\n",
    "# plt.plot(bins, best_fit_line, 'k--', alpha=0.5)\n",
    "\n",
    "plt.xlabel(\"segment duration\")\n",
    "plt.ylabel(\"#Samples\")\n",
    "\n",
    "plt.savefig(\"segment_duration_dist.pdf\", dpi=300)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac061eec-7edd-4791-977a-8a5e81f6807f",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Sentence statistics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9d25a15-ae75-4636-8ba7-45501514a10b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import spacy\n",
    "\n",
    "# Requires the small English model: python -m spacy download en_core_web_sm\n",
    "nlp = spacy.load(\"en_core_web_sm\")\n",
    "\n",
    "# Tokens that are kept in `words_all` but never counted in `words`.\n",
    "PUNCTUATION = {\",\", \".\"}\n",
    "\n",
    "sentences = defaultdict(list)                  # \"ori\" value -> label lengths in whitespace-split words\n",
    "words = defaultdict(lambda: defaultdict(int))  # \"ori\" value -> token text -> frequency\n",
    "words_all = []                                 # every token text of every label\n",
    "for sample in tqdm(data):\n",
    "    source = sample[\"ori\"]\n",
    "    for step in sample[\"annotation\"]:\n",
    "        label = step['label']\n",
    "        sentences[source].append(len(label.split()))\n",
    "        for token in nlp(label):\n",
    "            token_text = token.text\n",
    "            words_all.append(token_text)\n",
    "            if token_text not in PUNCTUATION:\n",
    "                words[source][token_text] += 1\n",
    "\n",
    "# Per source: label length -> number of labels with that length.\n",
    "sentences_count = {source: list2count(lengths) for source, lengths in sentences.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0f8a9e3-a044-47e7-8b5c-bcf8441c6e6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One line per source: label length on x, number of labels on y.\n",
    "fig, ax = plt.subplots()\n",
    "for source, length_count in sentences_count.items():\n",
    "    ax.plot(list(length_count.keys()), list(length_count.values()), label=source)\n",
    "ax.set_title(\"Sentences Length Count\", fontstyle=\"italic\")\n",
    "ax.set_xlabel(\"length\")\n",
    "ax.set_ylabel(\"Count\")\n",
    "ax.legend(loc='best', numpoints=1, fancybox=False)\n",
    "fig.savefig(\"statistics-sentences.pdf\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3664c11-46aa-48fe-ab63-1485cb5521cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from wordcloud import WordCloud, STOPWORDS, ImageColorGenerator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e51a7c76-8233-4280-9366-f958379f1633",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Join all collected tokens into one lower-cased text for the word cloud.\n",
    "text = ' '.join(words_all).lower()\n",
    "\n",
    "wordcloud = WordCloud(\n",
    "    width=1000,\n",
    "    height=500,\n",
    "    stopwords=STOPWORDS,\n",
    "    collocations=True,\n",
    "    background_color=\"white\",\n",
    ").generate(text)\n",
    "\n",
    "# 'bilinear' is the documented spelling of this interpolation name.\n",
    "plt.imshow(wordcloud, interpolation='bilinear')\n",
    "# Trailing semicolon keeps the axis-limits repr out of the cell output.\n",
    "plt.axis('off');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65a37675-fa0c-43f4-aeca-3c29c00b24a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "with jsonlines.open(\"/data/vtt/meta/vtt.jsonl\") as reader:\n",
    "    data = list(reader)\n",
    "\n",
    "# Label length (whitespace-separated words) -> number of labels with that length.\n",
    "# NOTE: this rebinds `sentences`, which an earlier cell fills with per-source length lists.\n",
    "sentences = defaultdict(int)\n",
    "for sample in data:\n",
    "    for step in sample[\"annotation\"]:\n",
    "        sentences[len(step[\"label\"].split())] += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a331757-298f-4718-ac45-0a99b26a96e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "font_size = 16\n",
    "plt.rcParams.update({\n",
    "    'figure.dpi': 200,\n",
    "    'savefig.dpi': 300,\n",
    "    'axes.labelsize': font_size + 2,\n",
    "    'axes.labelweight': 'normal',\n",
    "    'legend.fontsize': font_size,\n",
    "    'xtick.labelsize': font_size,\n",
    "    'ytick.labelsize': font_size,\n",
    "    'axes.linewidth': 1.5,\n",
    "})\n",
    "\n",
    "# One bar per label length; matplotlib picks the x ticks itself.\n",
    "axis = plt.bar(list(sentences.keys()), list(sentences.values()))\n",
    "plt.ylabel(\"#Sentences\")\n",
    "plt.xlabel(\"words\")\n",
    "\n",
    "plt.savefig(\"sentences_dist.pdf\", dpi=300)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a503e7a5-8c8e-4493-a4b7-a63cc8934319",
   "metadata": {},
   "source": [
    "## Merge the two plots into one figure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25d7fcf1-e843-4360-806d-eaed29e51a9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "font_size = 16\n",
    "plt.rcParams.update({\n",
    "    'figure.dpi': 200,\n",
    "    'savefig.dpi': 300,\n",
    "    'axes.labelsize': font_size + 2,\n",
    "    'axes.labelweight': 'normal',\n",
    "    'legend.fontsize': font_size,\n",
    "    'xtick.labelsize': font_size,\n",
    "    'ytick.labelsize': font_size,\n",
    "    'axes.linewidth': 1.5,\n",
    "})\n",
    "\n",
    "width, height = plt.figaspect(0.75)\n",
    "fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(width, height))\n",
    "\n",
    "# Top panel: samples per step count, with a tick under every bar.\n",
    "step_values = list(steps_count.keys())\n",
    "ax1.bar(step_values, list(steps_count.values()))\n",
    "ax1.set_xticks(step_values)\n",
    "ax1.set(xlabel=\"transformations\", ylabel=\"#Samples\")\n",
    "\n",
    "# Bottom panel: labels per word count.\n",
    "ax2.bar(list(sentences.keys()), list(sentences.values()))\n",
    "ax2.set(xlabel=\"words\", ylabel=\"#Sentences\")\n",
    "\n",
    "# Extra vertical gap so the top panel's x label does not collide with the bottom panel.\n",
    "fig.subplots_adjust(hspace=0.6)\n",
    "\n",
    "fig.savefig(\"steps_sentences_dist.pdf\", dpi=300)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28de295e-c16d-43d9-a8bf-7c2f1269bb5c",
   "metadata": {},
   "source": [
    "## Statistics of categories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b002c78a-a7d0-4f01-a43e-0114caf3c7d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "with jsonlines.open(\"/data/vtt/meta/vtt.jsonl\") as reader:\n",
    "    data = list(reader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef4efbcf-3cdf-4bce-9124-59f38a2aed01",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import re\n",
    "\n",
    "\n",
    "class ListData:\n",
    "    \"\"\"List of dict records that can also be looked up by each record's \"id\".\"\"\"\n",
    "\n",
    "    def __init__(self, data_list):\n",
    "        self._data_list = data_list\n",
    "        # If two records share an id, the later one wins in the lookup table.\n",
    "        self._id_map = {sample[\"id\"]: sample for sample in data_list}\n",
    "\n",
    "    def __getitem__(self, _id):\n",
    "        \"\"\"Return the record whose \"id\" equals ``_id``; raises KeyError if unknown.\"\"\"\n",
    "        return self._id_map[_id]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self._data_list)\n",
    "\n",
    "    def __iter__(self):\n",
    "        return iter(self._data_list)\n",
    "\n",
    "\n",
    "class Taxonomy:\n",
    "    \"\"\"Domains, targets and actions loaded from a taxonomy JSON file.\n",
    "\n",
    "    The file must hold the lists \"domain\", \"target\" and \"action\". Domains\n",
    "    reference targets through \"target_list\", targets reference actions through\n",
    "    \"action_list\", and the reverse links are \"target_id\" / \"domain_id\".\n",
    "    \"\"\"\n",
    "\n",
    "    # Pieces of a CamelCase label, tried in this order: a known acronym plus\n",
    "    # whatever follows it up to the next capital, a capitalised word, or a run\n",
    "    # without capitals (only reachable at the very start of the label).\n",
    "    _WORD_RE = re.compile(r\"(?:CPR|RJ45|SIM|SSD|CD|TV|PC)[^A-Z]*|[A-Z][^A-Z]*|[^A-Z]+\")\n",
    "\n",
    "    def __init__(self, json_path=\"/data/coin/data/taxonomy.json\"):\n",
    "        # JSON text is UTF-8; do not depend on the platform default encoding.\n",
    "        with open(json_path, encoding=\"utf-8\") as f:\n",
    "            self._data = json.load(f)\n",
    "        self.domains = ListData(self._data[\"domain\"])\n",
    "        self.targets = ListData(self._data[\"target\"])\n",
    "        self.actions = ListData(self._data[\"action\"])\n",
    "\n",
    "    def get_domain_targets(self, domain_id):\n",
    "        \"\"\"Return the target records listed in the domain's \"target_list\".\"\"\"\n",
    "        domain = self.domains[domain_id]\n",
    "        return [self.targets[_id] for _id in domain[\"target_list\"]]\n",
    "\n",
    "    def get_target_actions(self, target_id):\n",
    "        \"\"\"Return the action records listed in the target's \"action_list\".\"\"\"\n",
    "        target = self.targets[target_id]\n",
    "        return [self.actions[_id] for _id in target[\"action_list\"]]\n",
    "\n",
    "    def get_action_target(self, action_id):\n",
    "        \"\"\"Return the target record the action points to via \"target_id\".\"\"\"\n",
    "        return self.targets[self.actions[action_id][\"target_id\"]]\n",
    "\n",
    "    def get_target_domain(self, target_id):\n",
    "        \"\"\"Return the domain record the target points to via \"domain_id\".\"\"\"\n",
    "        return self.domains[self.targets[target_id][\"domain_id\"]]\n",
    "\n",
    "    def get_action_domain(self, action_id):\n",
    "        \"\"\"Return the domain of the action's target.\"\"\"\n",
    "        return self.get_target_domain(self.actions[action_id][\"target_id\"])\n",
    "\n",
    "    def split_words(self, s):\n",
    "        \"\"\"Split a CamelCase label into space-separated words.\n",
    "\n",
    "        The acronyms CPR, RJ45, SIM, SSD, CD, TV and PC are kept in one piece.\n",
    "        No character of ``s`` is dropped: text before the first capital and\n",
    "        lower-case text directly after an acronym stay in the result.\n",
    "        Whitespace in the result is collapsed to single spaces.\n",
    "        \"\"\"\n",
    "        pieces = self._WORD_RE.findall(s)\n",
    "        return \" \".join(\" \".join(pieces).split())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f38c3203-05ad-4a2c-ad37-ce2ae0c43453",
   "metadata": {},
   "outputs": [],
   "source": [
    "taxonomy = Taxonomy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a443e55-619d-4c64-af67-b547c8475fa4",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"{len(taxonomy.domains)} domains\")\n",
    "print(f\"{len(taxonomy.targets)} targets\")\n",
    "print(f\"{len(taxonomy.actions)} actions\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23a77523-c7bc-4b69-942b-a429d45efb5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Numbered list of all taxonomy domains.\n",
    "for i, domain in enumerate(taxonomy.domains):\n",
    "    print(f\"{i:02d}. {domain['label']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b4e7866-0744-4747-8d9c-1ed7f8125859",
   "metadata": {
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Numbered domains, each followed by its numbered targets with the CamelCase\n",
    "# target label split into words. Separate index names keep the inner loop\n",
    "# from overwriting the domain index.\n",
    "for domain_idx, domain in enumerate(taxonomy.domains):\n",
    "    print(f\"{domain_idx:02d}. {domain['label']}\")\n",
    "    for target_idx, target in enumerate(taxonomy.get_domain_targets(domain['id'])):\n",
    "        print(f\"\\t{target_idx:03d}. {taxonomy.split_words(target['label'])}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6905f60-8e27-4c41-9024-309edbd85221",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "import src.utils.datatool as dtool\n",
    "tasks = dtool.read_jsonlines(\"/data/CrossTask/crosstask_release/tasks.jsonl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03419b22-8d26-4ca0-8e34-4af1008c76a8",
   "metadata": {
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "for i, task in enumerate(tasks):\n",
    "    print(f\"\\\"{task['task']}\\\": \\\"\\\",\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29047ed1-6412-47f1-bd5d-d4a3d461b11b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Numbered list of all taxonomy domains.\n",
    "for i, domain in enumerate(taxonomy.domains):\n",
    "    print(f\"{i:02d}. {domain['label']}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "94b522d5-d27f-4b3e-be09-bd6c04e745e1",
   "metadata": {},
   "source": [
    "### Plot category distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f37a1c4-f298-4d70-8c79-6f0da5fcd34e",
   "metadata": {},
   "outputs": [],
   "source": [
    "with jsonlines.open(\"/data/vtt/meta/vtt.jsonl\") as reader:\n",
    "    data = list(reader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f28ea7c-4878-4384-b8dc-d73b3005ad36",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "# Sample counts per topic, per category, and per split broken down by topic.\n",
    "topics = defaultdict(int)\n",
    "categories = defaultdict(int)\n",
    "topics_split = defaultdict(lambda: defaultdict(int))\n",
    "for sample in data:\n",
    "    topic = sample['topic']\n",
    "    topics[topic] += 1\n",
    "    categories[sample['category']] += 1\n",
    "    topics_split[sample['split']][topic] += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5339667a-9af7-4ac4-86f3-0ea17dc936a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "categories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d12a7df-11d0-4688-a20f-a9bc46a32ac0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Categories ordered from most to fewest samples.\n",
    "list_sorted = sorted(categories.items(), key=lambda x: x[1], reverse=True)\n",
    "\n",
    "font_size = 16\n",
    "plt.rcParams.update({\n",
    "    'figure.dpi': 200,\n",
    "    'savefig.dpi': 300,\n",
    "    'axes.labelsize': font_size + 2,\n",
    "    'axes.labelweight': 'normal',\n",
    "    'legend.fontsize': font_size,\n",
    "    'xtick.labelsize': font_size,\n",
    "    'ytick.labelsize': font_size,\n",
    "    'axes.linewidth': 1,\n",
    "})\n",
    "\n",
    "# Draw the bars first and rotate afterwards: plt.xticks(rotation=...) styles the\n",
    "# tick labels present when it is called, so it should follow the call that\n",
    "# creates the category ticks.\n",
    "axis = plt.bar([name for name, _ in list_sorted], [count for _, count in list_sorted])\n",
    "plt.xticks(rotation='vertical')\n",
    "plt.ylabel(\"#Samples\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49e1ad41-fbe5-4ab1-aa03-dc685b55b8cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(plt.colormaps())\n",
    "colors = plt.get_cmap(\"Set3\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0fc17f0-6583-41a2-bb19-e65282935e9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "# Categories ordered from most to fewest samples, and the overall sample total.\n",
    "list_sorted = sorted(categories.items(), key=lambda x: x[1], reverse=True)\n",
    "n_sample = sum(count for _, count in list_sorted)\n",
    "\n",
    "font_size = 12\n",
    "\n",
    "# Candidate qualitative palettes: 'Accent', 'Dark2', 'Paired', 'Pastel1', 'Pastel2',\n",
    "# 'Set1', 'Set2', 'Set3', 'tab10', 'tab20', 'tab20b', 'tab20c'.\n",
    "colormap = \"tab20c\"\n",
    "print(colormap)\n",
    "\n",
    "# Style and rcParams come before plt.figure so that values read when the\n",
    "# figure is created (such as figure.dpi) apply to this figure.\n",
    "# matplotlib 3.6 renamed the seaborn styles to 'seaborn-v0_8-*'; use the installed name.\n",
    "muted_style = next(\n",
    "    (name for name in ('seaborn-v0_8-muted', 'seaborn-muted') if name in plt.style.available),\n",
    "    'default',\n",
    ")\n",
    "plt.style.use(muted_style)\n",
    "plt.rcParams.update({\n",
    "    'figure.dpi': 200,\n",
    "    'savefig.dpi': 300,\n",
    "    'font.size': font_size,\n",
    "    'axes.labelsize': font_size + 2,\n",
    "    'axes.labelweight': 'normal',\n",
    "    'legend.fontsize': font_size,\n",
    "    'xtick.labelsize': font_size,\n",
    "    'ytick.labelsize': font_size,\n",
    "    'axes.linewidth': 1,\n",
    "})\n",
    "\n",
    "width, height = plt.figaspect(1)\n",
    "plt.figure(figsize=(width, height))\n",
    "\n",
    "\n",
    "def count_and_percent(pct):\n",
    "    # `pct` is the wedge share in percent. Converting it back to a sample count\n",
    "    # must round, not truncate: float error can leave the product just below\n",
    "    # the true integer, and int() would then show a count that is one too low.\n",
    "    return f\"{round(pct * n_sample / 100)},\\n{pct:.2f}%\"\n",
    "\n",
    "\n",
    "_, _, autotexts = plt.pie(\n",
    "    [count / n_sample for _, count in list_sorted],\n",
    "    labels=[f\"{name}\" for name, _ in list_sorted],\n",
    "    startangle=15,\n",
    "    explode=[0.02] * len(list_sorted),\n",
    "    autopct=count_and_percent,\n",
    "    pctdistance=0.75,\n",
    "    colors=plt.get_cmap(colormap).colors,\n",
    ")\n",
    "# Shrink the in-wedge text step by step, following the wedge order.\n",
    "for i, autotext in enumerate(autotexts):\n",
    "    autotext.set_fontsize(font_size - 3 - i*0.2)\n",
    "plt.savefig(\"categories_dist.pdf\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88f2245d-a8dd-47d1-96d2-6edbbe674646",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(topics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9928615a-a942-4484-b554-875bb1eb0ad5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Topics ordered from most to fewest samples; per split, the count of each\n",
    "# topic in that same order.\n",
    "sorted_keys = sorted(topics.items(), key=lambda x: x[1], reverse=True)\n",
    "splits = [\"train\", \"val\", \"test\"]\n",
    "list_sorted = {\n",
    "    split: [(topic, topics_split[split][topic]) for topic, _ in sorted_keys]\n",
    "    for split in splits\n",
    "}\n",
    "\n",
    "n_rows = 2\n",
    "\n",
    "# rcParams are set before the figure is created so that creation-time values\n",
    "# such as figure.dpi apply to it.\n",
    "font_size = 8\n",
    "plt.rcParams.update({\n",
    "    'figure.dpi': 200,\n",
    "    'savefig.dpi': 300,\n",
    "    'axes.labelsize': font_size + 2,\n",
    "    'axes.labelweight': 'normal',\n",
    "    'legend.fontsize': font_size,\n",
    "    'xtick.labelsize': font_size,\n",
    "    'ytick.labelsize': font_size,\n",
    "    'axes.linewidth': 1,\n",
    "})\n",
    "\n",
    "# Candidate qualitative palettes: 'Accent', 'Dark2', 'Paired', 'Pastel1', 'Pastel2',\n",
    "# 'Set1', 'Set2', 'Set3', 'tab10', 'tab20', 'tab20b', 'tab20c'.\n",
    "colormap = \"tab20b\"\n",
    "colors = plt.get_cmap(colormap).colors\n",
    "print(colormap)\n",
    "\n",
    "width, height = plt.figaspect(0.2)\n",
    "# squeeze=False always yields a 2-D axes array, so n_rows = 1 works as well.\n",
    "fig, ax = plt.subplots(n_rows, 1, figsize=(width, height * n_rows), squeeze=False)\n",
    "ax = ax[:, 0]\n",
    "\n",
    "n_topics = len(sorted_keys)\n",
    "split_pos = n_topics // n_rows\n",
    "\n",
    "for row, axi in enumerate(ax):\n",
    "    start = row * split_pos\n",
    "    # The last row also takes the remainder when the topics do not divide evenly.\n",
    "    end = n_topics if row == len(ax) - 1 else (row + 1) * split_pos\n",
    "    # Running top of each stacked bar; a new array per step, so the values\n",
    "    # already handed to bar() are never modified.\n",
    "    last_top = np.zeros(end - start)\n",
    "    for j, split in enumerate(splits):\n",
    "        chunk = list_sorted[split][start:end]\n",
    "        key = [name for name, _ in chunk]\n",
    "        val = [count for _, count in chunk]\n",
    "        axi.bar(key, val, bottom=last_top, label=split, color=colors[j + 1])\n",
    "        last_top = last_top + val\n",
    "    axi.margins(x=0.005)\n",
    "    axi.tick_params(axis='x', rotation=90)\n",
    "    axi.set_ylabel(\"#Samples\")\n",
    "    axi.legend()\n",
    "\n",
    "plt.subplots_adjust(hspace=1.2)\n",
    "plt.savefig(\"topics_dist.pdf\", dpi=300)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf561d7e-d3ce-4bbf-9d0d-87b106fb6be9",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Statistics Table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6d4aa05-f632-4068-bdd3-584e997af932",
   "metadata": {},
   "outputs": [],
   "source": [
    "with jsonlines.open(\"/data/vtt/meta/vtt.jsonl\") as reader:\n",
    "    data = list(reader)\n",
    "\n",
    "# statistics[group][metric], where group is \"total\" or the sample's \"split\" value.\n",
    "statistics = defaultdict(lambda: defaultdict(int))\n",
    "for sample in data:\n",
    "    n_segments = len(sample[\"annotation\"])\n",
    "    for group in (\"total\", sample[\"split\"]):\n",
    "        bucket = statistics[group]\n",
    "        bucket[\"Samples\"] += 1\n",
    "        bucket[\"Segments\"] += n_segments\n",
    "        bucket[\"Duration\"] += sample[\"duration\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b7dc192-e164-48b8-a26e-8941859e5a75",
   "metadata": {},
   "outputs": [],
   "source": [
    "statistics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "514b34a3-eda2-4bc3-b230-c70704954990",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "df = pd.DataFrame.from_dict(statistics, orient=\"index\")\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51931de3-4189-4f9b-ad17-68e04830eacf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative growth, printed as percentages of the older value: (new - old) / old * 100.\n",
    "# NOTE(review): all figures are hard-coded literals. Presumably they are the\n",
    "# old/new totals for samples, duration (whole hours plus minutes/60) and\n",
    "# segments - TODO confirm against the statistics table and update when the data changes.\n",
    "print(f\"increase {(13547 - 11827)/11827*100:.2f}% samples\")\n",
    "print(f\"increase {(595 - 476 + (22 - 38)/60)/(476+38/60)*100:.2f}% duration\")\n",
    "print(f\"increase {(55482 - 46354)/46354*100:.2f}% segments\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2ac98e2-4b77-467d-8a17-43ae535ac2a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "from src.utils.timetool import time2str\n",
    "from functools import partial\n",
    "import pandas as pd\n",
    "\n",
    "time_fmt = partial(time2str, units=[\"h\", \"min\"])\n",
    "# Rebuild the frame from `statistics` instead of overwriting df[\"Duration\"] in\n",
    "# place, so re-running this cell never feeds already-formatted strings back\n",
    "# into time2str.\n",
    "df = pd.DataFrame.from_dict(statistics, orient=\"index\").assign(\n",
    "    Duration=lambda frame: frame[\"Duration\"].apply(time_fmt)\n",
    ")\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aae7ea05-99fd-48c1-8a5a-dd3f347d5291",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\n",
    "    df.style.to_latex(\n",
    "        caption=\"Statistics of the VTT dataset\",\n",
    "        hrules=True,\n",
    "        position=\"ht\",\n",
    "        position_float=\"centering\",\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ed4de6f-d5a2-41b8-aeed-73cbc58d1d25",
   "metadata": {},
   "source": [
    "## All topics and categories (for Dataset categories, topic indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d1d7ac8-208e-4501-9de5-85ca45199582",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sorted unique topic / category names; a name's list position serves as its index.\n",
    "# NOTE: this rebinds `topics` and `categories`, which earlier cells use as count dicts.\n",
    "topics = sorted({sample['topic'] for sample in data})\n",
    "categories = sorted({sample['category'] for sample in data})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d2b25fd-59de-465a-ac34-95db65cf4774",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(categories)\n",
    "print(len(categories))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c2d3f65-39df-45d0-8b31-77cbb990916d",
   "metadata": {},
   "outputs": [],
   "source": [
    "categories.index(\"Dish\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98169c1d-23e3-47c0-949f-5bc7d209bd3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(topics)\n",
    "print(len(topics))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d8cef01-b925-4c46-83eb-eec1edf9ec87",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
