{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d852e3f-ebd1-44f0-a5a7-cac5e6df0c44",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import jsonlines\n",
    "from collections import defaultdict\n",
    "from pathlib import Path\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import spacy\n",
    "\n",
    "sys.path.append(\"..\")\n",
    "import src.utils.datatool as dtool  # noqa: E402"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c3a8b69-52e9-464b-8499-dfd2d96d54b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.font_manager\n",
    "import matplotlib.image as mpimg\n",
    "print(f\"available fonts: {sorted([f.name for f in matplotlib.font_manager.fontManager.ttflist])}\")\n",
    "\n",
    "plt.style.use('seaborn-muted')\n",
    "\n",
    "plt.rcParams[\"figure.dpi\"] = 150\n",
    "plt.rcParams[\"savefig.dpi\"] = 300\n",
    "plt.rcParams[\"savefig.format\"] = \"pdf\"\n",
    "plt.rcParams[\"savefig.bbox\"] = \"tight\"\n",
    "plt.rcParams[\"savefig.pad_inches\"] = 0.1\n",
    "\n",
    "plt.rcParams['figure.titlesize'] = 18\n",
    "plt.rcParams['axes.titlesize'] = 18\n",
    "plt.rcParams['font.family'] = 'Helvetica'\n",
    "plt.rcParams['font.size'] = 18\n",
    "\n",
    "plt.rcParams[\"lines.linewidth\"] = 2\n",
    "plt.rcParams['axes.labelsize'] = 16\n",
    "plt.rcParams['axes.labelweight'] = 'bold'\n",
    "plt.rcParams['xtick.labelsize'] = 16\n",
    "plt.rcParams['ytick.labelsize'] = 16\n",
    "plt.rcParams['legend.fontsize'] = 16\n",
    "plt.rcParams['axes.linewidth'] = 2\n",
    "plt.rcParams['axes.titlepad'] = 6\n",
    "\n",
    "plt.rcParams['mathtext.fontset'] = 'dejavuserif'\n",
    "plt.rcParams['mathtext.it'] = 'serif:italic'\n",
    "plt.rcParams['lines.marker'] = \"\"\n",
    "plt.rcParams['legend.frameon'] = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33650e81-32db-47c6-a09a-6b821454534c",
   "metadata": {},
   "outputs": [],
   "source": [
    "with jsonlines.open(\"/data/vtt/meta/vtt.jsonl\") as reader:\n",
    "    data = list(reader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f342c04-024c-4fb5-b54a-1add63fa80ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "data[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eaa398bc-87d7-4e17-992e-1d8189d8f9a4",
   "metadata": {},
   "source": [
    "## Language Compositional Generalization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "511a387e-dae2-4983-b5dc-9827db3d61cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def list2count(_list):\n",
    "    count = defaultdict(int)\n",
    "    for x in _list:\n",
    "        count[x] += 1\n",
    "    count = {key: val for key, val in sorted(count.items())}\n",
    "    return count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b082fec4-6f99-45fb-86d7-b5598c4e3fb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# python -m spacy download en_core_web_sm\n",
    "nlp = spacy.load(\"en_core_web_sm\")\n",
    "lemmatizer = nlp.get_pipe(\"lemmatizer\")\n",
    "\n",
    "sentences = defaultdict(list)\n",
    "words = defaultdict(lambda: defaultdict(int))\n",
    "words_all = []\n",
    "for sample in tqdm(data):\n",
    "    for step in sample[\"annotation\"]:\n",
    "        sentences[sample[\"ori\"]].append(len(step['label'].split()))\n",
    "        doc = nlp(step['label'])\n",
    "        for word in doc:\n",
    "            word = str(word)\n",
    "            words_all.append(word)\n",
    "            if word not in [\",\", \".\"]:\n",
    "                words[\"all\"][word] += 1\n",
    "                words[sample[\"ori\"]][word] += 1\n",
    "sentences_count = {}\n",
    "for key, val in sentences.items():\n",
    "    sentences_count[key] = list2count(val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebca528b-b5d7-4f88-8898-a20b8a86f74f",
   "metadata": {},
   "outputs": [],
   "source": [
    "unique_words_all = set(words_all)\n",
    "len(unique_words_all)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0e39d37-036d-4191-a7d1-e54273f60e10",
   "metadata": {},
   "outputs": [],
   "source": [
    "stat_all = defaultdict(lambda: defaultdict(int))\n",
    "stat_unique = defaultdict(lambda:defaultdict(set))\n",
    "for sample in data:\n",
    "    \n",
    "    stat_all['all']['Samples'] += 1\n",
    "    stat_all[sample['ori']]['Samples'] += 1\n",
    "    stat_all[sample['split']]['Samples'] += 1\n",
    "    \n",
    "    stat_all['all']['Transformations'] += len(sample['annotation'])\n",
    "    stat_all[sample['ori']]['Transformations'] += len(sample['annotation'])\n",
    "    stat_all[sample['split']]['Transformations'] += len(sample['annotation'])\n",
    "    \n",
    "    stat_all['all']['States'] += (len(sample['annotation']) + 1)\n",
    "    stat_all[sample['ori']]['States'] += (len(sample['annotation']) + 1)\n",
    "    stat_all[sample['split']]['States'] += (len(sample['annotation']) + 1)\n",
    "    \n",
    "    stat_unique['all']['Categories'].add(sample['category'])\n",
    "    stat_unique[sample['ori']]['Categories'].add(sample['category'])\n",
    "    stat_unique[sample['split']]['Categories'].add(sample['category'])\n",
    "    \n",
    "    stat_unique['all']['Topics'].add(sample['topic'])\n",
    "    stat_unique[sample['ori']]['Topics'].add(sample['topic'])\n",
    "    stat_unique[sample['split']]['Topics'].add(sample['topic'])\n",
    "    \n",
    "    for t in sample['annotation']:\n",
    "        stat_unique['all']['transformations'].add(t['label'])\n",
    "        stat_unique[sample['ori']]['transformations'].add(t['label'])\n",
    "        stat_unique[sample['split']]['transformations'].add(t['label'])\n",
    "        \n",
    "for dataset, info in stat_unique.items():\n",
    "    for key, s in info.items():\n",
    "        if key == \"transformations\":\n",
    "            key = \"Unique Transformations\"\n",
    "        stat_all[dataset][key] = len(s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e8d2cc2-9897-49e8-a118-8acee10b660c",
   "metadata": {},
   "outputs": [],
   "source": [
    "### words in unique transforamtions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67e957b3-91a1-4c84-91bd-fe9eff282713",
   "metadata": {},
   "outputs": [],
   "source": [
    "words_all = set()\n",
    "words_cnt = defaultdict(int)\n",
    "for t in stat_unique['all']['transformations']:\n",
    "    doc = nlp(t)\n",
    "    for word in doc:\n",
    "        word = str(word)\n",
    "        words_all.add(word)\n",
    "        if word not in [\",\", \".\"]:\n",
    "            words_cnt[word] += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d56daf90-68e8-4446-aee7-ffb89094a710",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(stat_unique['all']['transformations'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "269eca0c-d510-4a4b-9867-2c805de07b3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(len(words_all))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45cc8dd3-b5ff-48af-a297-f8f576beb6e8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b80593c6-dfec-4363-8e23-a4ec97906edd",
   "metadata": {},
   "outputs": [],
   "source": [
    "top_words"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ff97a78-e868-4f25-b634-2d97b123502b",
   "metadata": {},
   "outputs": [],
   "source": [
    "t_words_cnt = pd.Series(words_cnt).sort_values()[::-1]\n",
    "top_words = t_words_cnt[2:52]\n",
    "top_words_str = list(top_words.index)\n",
    "top_words_cnt = list(top_words.values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44e6f375-f856-4290-bdbb-76080a926914",
   "metadata": {},
   "outputs": [],
   "source": [
    "width, height = plt.figaspect(0.15)\n",
    "font_size = 16\n",
    "plt.rcParams[\"figure.dpi\"] = 200\n",
    "plt.rcParams[\"savefig.dpi\"] = 300\n",
    "plt.rcParams['axes.labelsize'] = font_size + 2\n",
    "plt.rcParams['axes.labelweight'] = 'normal'\n",
    "plt.rcParams['legend.fontsize'] = font_size\n",
    "plt.rcParams['xtick.labelsize'] = font_size\n",
    "plt.rcParams['ytick.labelsize'] = font_size\n",
    "plt.rcParams['axes.linewidth'] = 1\n",
    "\n",
    "plt.figure(figsize=(width, height))\n",
    "plt.xticks(rotation='vertical')\n",
    "colormap = \"tab20b\"\n",
    "colors = plt.get_cmap(colormap).colors\n",
    "axis = plt.bar(top_words_str, top_words_cnt, color=colors[2])\n",
    "plt.ylabel(\"Count\")\n",
    "plt.margins(x=0.005)\n",
    "plt.savefig(\"top_words.pdf\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff8460aa-c65b-400d-ad50-bf512af4f1aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "unique_train_transformations = list(sorted(stat_unique['train']['transformations']))\n",
    "print(len(unique_train_transformations))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddbbd5f3-c19b-4eab-8249-fcb40402d577",
   "metadata": {},
   "outputs": [],
   "source": [
    "### TTNet transformations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c877c05d-a5b3-4501-9581-0e2e9af835e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "results = dtool.read_jsonlines(\"/log/exp/vtt/VTTDataModule.TTNetDiff.TellingLossV1.2022-09-19_21-44-42/detail.jsonl\")\n",
    "unique_ttnet_transformations = set()\n",
    "for sample in results:\n",
    "    for t in sample[\"preds\"]:\n",
    "        unique_ttnet_transformations.add(t)\n",
    "print(f\"unique ttnet transformations: {len(unique_ttnet_transformations)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "827ca3f0-289c-4844-8853-9db7c83874d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# no useful unique transformations from TTNet\n",
    "ttnet_transformations_only = unique_ttnet_transformations - set(unique_train_transformations)\n",
    "ttnet_transformations_only"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bb7a1f4-46b7-4348-afb2-a7683e35b6ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "### CrossTask related videos"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11fb5205-5bfc-4bb4-bdb1-786d4903c8b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = dtool.read_jsonlines(\"../docs/lists/tasks.jsonl\")\n",
    "print(len(tasks))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6090f162-9d9b-4f57-8f49-eec10722ede2",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"add\" not in words_all"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a54dd46-824f-48b1-9c4c-183ff10b0f36",
   "metadata": {},
   "outputs": [],
   "source": [
    "candidates_tasks = []\n",
    "def is_task_valid(task):\n",
    "    for t in task[\"steps\"]:\n",
    "        doc = nlp(t)\n",
    "        for word in doc:\n",
    "            word = str(word)\n",
    "            if word not in words_all:\n",
    "                print(word)\n",
    "                return False\n",
    "    return True\n",
    "for task in tasks:\n",
    "    if task['type'] == 'related' and is_task_valid(task):\n",
    "        candidates_tasks.append(task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdeed743-2f17-4806-b374-ee879c0324b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "candidates_tasks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fa9b6e1-1990-4140-a08d-164190d1aafe",
   "metadata": {},
   "outputs": [],
   "source": [
    "EXPERIMENTS = {\n",
    "    \"cst\": \"/log/exp/vtt/VTTDataModule.CST.GenerationLoss.2022-09-17_00-06-25\",\n",
    "    \"glacnet\": \"/log/exp/vtt/VTTDataModule.GLACNet.GenerationLoss.2022-09-18_17-56-36\",\n",
    "    \"densecap\": \"/log/exp/vtt/VTTDataModule.DenseCap.GenerationLoss.2022-09-25_15-15-34\",\n",
    "    \"ttnet_base\": \"/log/exp/vtt/VTTDataModule.TTNetMTM.GenerationLoss.2022-09-16_10-59-03\",\n",
    "    \"ttnet\": \"/log/exp/vtt/VTTDataModule.TTNetDiff.TellingLossV1.2022-09-19_21-44-42\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e02b780e-008a-4acb-a51e-65906055af17",
   "metadata": {},
   "outputs": [],
   "source": [
    "for exp, exp_path in EXPERIMENTS.items():\n",
    "    result_path = Path(exp_path) / \"detail.jsonl\"\n",
    "    results = dtool.read_jsonlines(result_path)\n",
    "    print(f\"{exp}:\")\n",
    "    print([t for t in results[-1][\"preds\"]])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "870a612c-5c28-4ccf-91da-242913e510b3",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Combination Generalization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b884254d-1dc8-46fe-af6a-fdcafb16d50a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# topic, all, train, val, test\n",
    "process_set = defaultdict(lambda: defaultdict(set))\n",
    "process = {}\n",
    "for sample in data:\n",
    "    p = \"-\".join([x[\"label\"] for x in sample[\"annotation\"]])\n",
    "    process_set[sample[\"category\"]][\"All\"].add(p)\n",
    "    process_set[sample[\"category\"]][sample[\"split\"].capitalize()].add(p)\n",
    "for topic, sets in process_set.items():\n",
    "    process[topic] = {}\n",
    "    s_train = sets['Train']\n",
    "    for split, s in sets.items():\n",
    "        process[topic][split] = len(s)\n",
    "        if split == \"Val\" or split == \"Test\":\n",
    "            name = f\"{split} Unique\"\n",
    "            process[topic][name] = len(s - s_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c3e0800-b50a-4b4c-91d8-f02d1164e250",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame(process).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fa87f5b-5b84-42a3-aa57-c61d63f866fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = df.sort_index()\n",
    "df.loc[\"Total\"] = df.sum()\n",
    "df = df[[\"Train\", \"Val\", \"Val Unique\", \"Test\", \"Test Unique\", \"All\"]]\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5092ca7-79e9-4650-b336-29b3dcd822cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(df.style.to_latex(hrules=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52278c42-9971-4e61-81a1-cb0340c57c97",
   "metadata": {},
   "outputs": [],
   "source": [
    "META_FILE = Path(\"/data/vtt/meta/vtt.jsonl\")\n",
    "EXPERIMENTS = {\n",
    "    \"cst\": \"/log/exp/vtt/VTTDataModule.CST.GenerationLoss.2022-09-17_00-06-25\",\n",
    "    \"glacnet\": \"/log/exp/vtt/VTTDataModule.GLACNet.GenerationLoss.2022-09-18_17-56-36\",\n",
    "    \"densecap\": \"/log/exp/vtt/VTTDataModule.DenseCap.GenerationLoss.2022-09-25_15-15-34\",\n",
    "    \"ttnet_base\": \"/log/exp/vtt/VTTDataModule.TTNetMTM.GenerationLoss.2022-09-16_10-59-03\",\n",
    "    \"ttnet\": \"/log/exp/vtt/VTTDataModule.TTNetDiff.TellingLossV1.2022-09-19_21-44-42\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "681fcff5-de11-4172-8cc4-a4925eef148c",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_samples = dtool.JSONLList(META_FILE, lambda x: x[\"split\"] == \"test\").samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65e2aa8e-68f7-4ad5-803c-0b7848c39965",
   "metadata": {},
   "outputs": [],
   "source": [
    "process_set_train = set()\n",
    "for topic, sets in process_set.items():\n",
    "    process_set_train = process_set_train | sets[\"Train\"]\n",
    "print(len(process_set_train))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4436193c-2418-4a9e-8cfb-a06f0027eb5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_share = []\n",
    "test_only = []\n",
    "for i, sample in enumerate(test_samples):\n",
    "    p = \"-\".join([x[\"label\"] for x in sample[\"annotation\"]])\n",
    "    if p in process_set_train:\n",
    "        test_share.append(i)\n",
    "    else:\n",
    "        test_only.append(i)\n",
    "print(f\"share: {len(test_share)}\")\n",
    "print(f\"only: {len(test_only)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d9e6320-f061-4ec8-850c-120290e41ce4",
   "metadata": {},
   "outputs": [],
   "source": [
    "METRICS = [\"BLEU_4\", \"ROUGE\", \"METEOR\", \"CIDEr\", \"SPICE\", \"BERTScore\"]\n",
    "def compute_metrics(results, metrics=METRICS):\n",
    "    scores = defaultdict(list)\n",
    "    for result in results:\n",
    "        for metric in metrics:\n",
    "            if type(result[metric]) is list:\n",
    "                scores[metric].extend(result[metric])\n",
    "            else:\n",
    "                scores[metric].append(result[metric])\n",
    "    return {key: np.mean(value) for key, value in scores.items()}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b34377e-1c4d-4cbd-b6c7-b15b332a047a",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Automatic Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b00c2743-84e9-410d-9e74-a8e88265979c",
   "metadata": {},
   "outputs": [],
   "source": [
    "scores = {}\n",
    "for exp, exp_path in EXPERIMENTS.items():\n",
    "    result_path = Path(exp_path) / \"detail.jsonl\"\n",
    "    results = dtool.read_jsonlines(result_path)\n",
    "    scores[exp] = compute_metrics(results)\n",
    "df_scores = pd.DataFrame(scores).T\n",
    "df_scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a356745-70e2-4817-b48b-49cacabb4847",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(df_scores[[\"CIDEr\"]].style.format(precision=2).to_latex(hrules=True, ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cd3df16-d2fd-4474-92b1-80959663104e",
   "metadata": {},
   "outputs": [],
   "source": [
    "scores = {}\n",
    "for exp, exp_path in EXPERIMENTS.items():\n",
    "    result_path = Path(exp_path) / \"detail.jsonl\"\n",
    "    results = dtool.read_jsonlines(result_path)\n",
    "    # share\n",
    "    results = [results[i] for i in test_share]\n",
    "    scores[exp] = compute_metrics(results)\n",
    "df_scores = pd.DataFrame(scores).T\n",
    "df_scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd01999f-e447-4db6-85c8-0fee5691867f",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(df_scores[[\"CIDEr\"]].style.format(precision=2).to_latex(hrules=True, ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49bd0f09-82a3-42ff-9ed5-60848c27378a",
   "metadata": {},
   "outputs": [],
   "source": [
    "scores = {}\n",
    "for exp, exp_path in EXPERIMENTS.items():\n",
    "    result_path = Path(exp_path) / \"detail.jsonl\"\n",
    "    results = dtool.read_jsonlines(result_path)\n",
    "    # only\n",
    "    results = [results[i] for i in test_only]\n",
    "    scores[exp] = compute_metrics(results)\n",
    "df_scores = pd.DataFrame(scores).T\n",
    "df_scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2a48170-e567-43c0-8c40-96ecf8bfec16",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(df_scores[[\"CIDEr\"]].style.format(precision=2).to_latex(hrules=True, ))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dd52436d-1d3e-4473-9f02-f775fe030edf",
   "metadata": {},
   "source": [
    "### Human Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71945a4d-ae14-4b7f-8ae7-6b1673d658e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "HUMAN_RESULTS_DIR = Path(\"../docs/lists/human_results\")\n",
    "EXPS = [\"cst\", \"glacnet\", \"densecap\", \"ttnet_base\", \"ttnet\"]\n",
    "HUMAN_METRICS = [\"fluency\", \"relevance\", \"logical_soundness\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd77c847-f365-48c1-b2fb-5bf091d520e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "scores = {}\n",
    "for exp in EXPS:\n",
    "    path = HUMAN_RESULTS_DIR / f\"{exp}.jsonl\"\n",
    "    results = dtool.read_jsonlines(path)\n",
    "    scores[exp] = compute_metrics(results, HUMAN_METRICS)\n",
    "df_scores = pd.DataFrame(scores).T\n",
    "df_scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc82e637-12a5-43ba-a53f-752c9512d0e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(df_scores.style.format(precision=2).to_latex(hrules=True, ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc9a67cc-d9c4-4089-af82-5f910f5bb963",
   "metadata": {},
   "outputs": [],
   "source": [
    "scores = {}\n",
    "for exp in EXPS:\n",
    "    path = HUMAN_RESULTS_DIR / f\"{exp}.jsonl\"\n",
    "    results = dtool.read_jsonlines(path)\n",
    "    results = [x for x in results if x[\"index\"] in test_share]\n",
    "    scores[exp] = compute_metrics(results, HUMAN_METRICS)\n",
    "df_scores = pd.DataFrame(scores).T\n",
    "df_scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02b06fcf-e088-4272-b3d7-f1d57e00d1f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(df_scores.style.format(precision=2).to_latex(hrules=True, ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "458115f9-297a-446e-9956-56a9a7f0709b",
   "metadata": {},
   "outputs": [],
   "source": [
    "scores = {}\n",
    "for exp in EXPS:\n",
    "    path = HUMAN_RESULTS_DIR / f\"{exp}.jsonl\"\n",
    "    results = dtool.read_jsonlines(path)\n",
    "    results = [x for x in results if x[\"index\"] in test_only]\n",
    "    scores[exp] = compute_metrics(results, HUMAN_METRICS)\n",
    "df_scores = pd.DataFrame(scores).T\n",
    "df_scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba125c3e-629a-45b3-b720-2f09a2c15a7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(df_scores.style.format(precision=2).to_latex(hrules=True, ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc1cbdec-a906-4ea4-aa6b-5876f6a02cf0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "949777d72b0d2535278d3dc13498b2535136f6dfe0678499012e853ee9abcab1"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
