{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "collapsed_sections": [
        "wLE2vFgnlUed",
        "Y8h0vGOpvcRE",
        "GOdLCa2ClwUj",
        "ugQe5kDDmbX1",
        "e_P2EK3qzRGj",
        "PT_Vj4-7PbUX"
      ],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "from google.colab import drive\n",
        "drive.mount('/content/drive')"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "QADWonBxzU37",
        "outputId": "6340ef13-1fb0-4ca8-a2ba-de92339d71ad"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mounted at /content/drive\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "project_dir = '/content/drive/MyDrive/Research/Creativity Benchmark/'"
      ],
      "metadata": {
        "id": "rw3DL4dQz9jc"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install jsonlines"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "vD_gBnkx0X_Z",
        "outputId": "cb1d26ca-d4e5-450b-c021-40fbeb0eb10b"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting jsonlines\n",
            "  Downloading jsonlines-4.0.0-py3-none-any.whl.metadata (1.6 kB)\n",
            "Requirement already satisfied: attrs>=19.2.0 in /usr/local/lib/python3.11/dist-packages (from jsonlines) (25.3.0)\n",
            "Downloading jsonlines-4.0.0-py3-none-any.whl (8.7 kB)\n",
            "Installing collected packages: jsonlines\n",
            "Successfully installed jsonlines-4.0.0\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import json\n",
        "\n",
        "def load_json(filename):\n",
        "    \"\"\"\n",
        "    Load a JSON file given a filename\n",
        "    If the file doesn't exist, then return an empty dictionary instead\n",
        "    \"\"\"\n",
        "    try:\n",
        "        with open(filename, 'r') as f:\n",
        "            return json.load(f)\n",
        "    except FileNotFoundError:\n",
        "        return {}\n",
        "\n",
        "def write_json(data, filepath):\n",
        "    # assert isinstance(data, dict), '[ERROR] Expect dictionary data!'\n",
        "    json_string = json.dumps(data, indent = 4)\n",
        "    with open(filepath, 'w') as outfile:\n",
        "        outfile.write(json_string)\n",
        "    # return 0\n",
        "\n",
        "import jsonlines\n",
        "def load_jsonl(filename):\n",
        "    file_content = []\n",
        "    try:\n",
        "        with jsonlines.open(filename) as reader:\n",
        "            for obj in reader:\n",
        "                file_content.append(obj)\n",
        "            return file_content\n",
        "    except FileNotFoundError:\n",
        "        return []"
      ],
      "metadata": {
        "id": "OE26mtaMz_do"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import numpy as np"
      ],
      "metadata": {
        "id": "boeZ_aF7gWgV"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# 0. Format"
      ],
      "metadata": {
        "id": "wLE2vFgnlUed"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 0.1 data template"
      ],
      "metadata": {
        "id": "UHbme_CtGg01"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "dp_template = {\n",
        "    'meta_data': {\n",
        "        'dataset': ...,\n",
        "        'id': ...,\n",
        "        'eval_func': ...\n",
        "    },\n",
        "    'input': {\n",
        "        'text': '',\n",
        "        'file': '',\n",
        "        'others': {\n",
        "            'constraints': {},\n",
        "            'instructions': {},\n",
        "            'references': {}\n",
        "        }\n",
        "    },\n",
        "    'output': {\n",
        "        **kwargs\n",
        "    }\n",
        "}"
      ],
      "metadata": {
        "id": "0y8I7z_dFuEJ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "IN2k_2zzlYjg"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 0.2 Workflow"
      ],
      "metadata": {
        "id": "EIhRzUEmnpAy"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# class Model():\n",
        "#     def __init__(self, model_config):\n",
        "#         pass\n",
        "\n",
        "#     def make_prompts(self, batch_input):\n",
        "#         pass\n",
        "\n",
        "#     def batch_predict(self, batch_prompt):\n",
        "#         pass\n",
        "\n",
        "#     def post_processing(self, batch_output):\n",
        "#         pass"
      ],
      "metadata": {
        "id": "5cX584NAlYhe"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import json"
      ],
      "metadata": {
        "id": "cVdjR4ZxisZL"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "s = '''\n",
        "{\"('1', '')\": \"Here's the ranking of the alternative uses from least creative to most creative, along with assigned scores:\\n\\n1. **Paperweight for holding documents down** - **Score: 1** (Least creative; this is a very common use of objects.)\\n   \\n2. **Keychain holder for small items** - **Score: 1** (Close to common use; typical keychain function.)\\n\\n3. **Bookmark for keeping pages marked** - **Score: 2** (Functional, but fairly standard; still more cognitive involvement than a paperweight.)\\n\\n4. **Garden tool for soil aeration** - **Score: 2** (\", \"('6', '')\": 'Here’s a ranking of the alternative uses listed, from least creative (1) to most creative (5):\\n\\n1. **Eating food from a plate.** - Score: 1 (Least creative; this is the most common and expected use.)\\n2. **Serving salad to dinner guests.** - Score: 1 (Very similar to eating from a plate; standard use.)\\n3. **Mixing ingredients in a bowl.** - Score: 2 (Common use but slightly more creative as it implies preparation.)\\n4. **Holding food while cutting meat.** - Score: 2 (Practical but still within common usage', \"('11', '')\": \"Here's the ranking of the alternative uses based on creativity, from least creative (1) to most creative (5):\\n\\n1. **Garden tool for loosening soil** - **Score: 1**  \\n   (Common use, very straightforward)\\n\\n2. **Hair accessory for securing styles** - **Score: 2**  \\n   (Practical, but still a fairly common use)\\n\\n3. **Keychain for holding multiple keys** - **Score: 2**  \\n   (Useful but very typical, not particularly imaginative)\\n\\n4. **Bookmark for holding pages open** - **Score: 3**  \\n   (A reasonable\", \"('11', 'bsr')\": \"Here's a ranking of the alternative uses from least creative (1) to most creative (5):\\n\\n1. **Stand for holding smartphone upright.** (Score: 1)  \\n   This is a very common use of various items, making it straightforward and not particularly inventive.\\n\\n2. 
**Clasp for securing fabric pieces.** (Score: 2)  \\n   While a bit more creative than a smartphone stand, using an object as a clasp is still fairly conventional.\\n\\n3. **Device for mixing small batches.** (Score: 2)  \\n   Mixing is a common task in many contexts, making this use less imaginative.\\n\\n\"}\n",
        "'''.replace('\\n', '')\n",
        "# print()\n",
        "print(json.dumps(s, indent = 4))"
      ],
      "metadata": {
        "id": "CXunjV9B037a",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "9f62b9e2-b684-4d2d-fabf-4102890add50"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\"{\\\"('1', '')\\\": \\\"Here's the ranking of the alternative uses from least creative to most creative, along with assigned scores:1. **Paperweight for holding documents down** - **Score: 1** (Least creative; this is a very common use of objects.)   2. **Keychain holder for small items** - **Score: 1** (Close to common use; typical keychain function.)3. **Bookmark for keeping pages marked** - **Score: 2** (Functional, but fairly standard; still more cognitive involvement than a paperweight.)4. **Garden tool for soil aeration** - **Score: 2** (\\\", \\\"('6', '')\\\": 'Here\\u2019s a ranking of the alternative uses listed, from least creative (1) to most creative (5):1. **Eating food from a plate.** - Score: 1 (Least creative; this is the most common and expected use.)2. **Serving salad to dinner guests.** - Score: 1 (Very similar to eating from a plate; standard use.)3. **Mixing ingredients in a bowl.** - Score: 2 (Common use but slightly more creative as it implies preparation.)4. **Holding food while cutting meat.** - Score: 2 (Practical but still within common usage', \\\"('11', '')\\\": \\\"Here's the ranking of the alternative uses based on creativity, from least creative (1) to most creative (5):1. **Garden tool for loosening soil** - **Score: 1**     (Common use, very straightforward)2. **Hair accessory for securing styles** - **Score: 2**     (Practical, but still a fairly common use)3. **Keychain for holding multiple keys** - **Score: 2**     (Useful but very typical, not particularly imaginative)4. **Bookmark for holding pages open** - **Score: 3**     (A reasonable\\\", \\\"('11', 'bsr')\\\": \\\"Here's a ranking of the alternative uses from least creative (1) to most creative (5):1. **Stand for holding smartphone upright.** (Score: 1)     This is a very common use of various items, making it straightforward and not particularly inventive.2. 
**Clasp for securing fabric pieces.** (Score: 2)     While a bit more creative than a smartphone stand, using an object as a clasp is still fairly conventional.3. **Device for mixing small batches.** (Score: 2)     Mixing is a common task in many contexts, making this use less imaginative.\\\"}\"\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# 1. TTCW"
      ],
      "metadata": {
        "id": "FEwPKZt704UH"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "stories_json = load_json(project_dir + 'raw/creative_writing_eval/Art_or_Artifice/stories/ttcw_short_stories.json')\n",
        "stories_json[0].keys()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "tQEkABg7035p",
        "outputId": "76d93ed6-daf6-45fd-ca1d-6264d43e39c9"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "dict_keys(['story_idx', 'story_id', 'story_name', 'plot', 'content'])"
            ]
          },
          "metadata": {},
          "execution_count": 10
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "len(stories_json)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "vcRT3R6z95Oy",
        "outputId": "35c7ecbe-fa48-448f-85c4-4f7a2c55c0ef"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "48"
            ]
          },
          "metadata": {},
          "execution_count": 11
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# instruction = '''\n",
        "# You are given a creative short-story. Read it carefully. You are then given some background about specific aspects of creative writing, as well as a binary (Yes/No) question. Your objective is to use the background information to answer the question about the story. Start your answer with Yes or No. You can optionally then provide a short explanation for your answer.\n",
        "\n",
        "# ==========\n",
        "# Story:\n",
        "\n",
        "\n",
        "# [STORY]\n",
        "\n",
        "\n",
        "# ==========\n",
        "# Background:\n",
        "\n",
        "\n",
        "# [BACKGROUND]\n",
        "\n",
        "\n",
        "# ==========\n",
        "# Question: [QUESTION]\n",
        "\n",
        "# Remember to start your answer with Yes or No. You can optionally then provide a short explanation for your answer.\n",
        "# '''"
      ],
      "metadata": {
        "id": "IW5xiOStLAZJ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def transform_dp_creative_writing_eval(dp):\n",
        "    transformed = {\n",
        "        'meta_data': {\n",
        "            'dataset': 'creative_writing_eval',\n",
        "            'id': dp['story_id'],\n",
        "            'eval_func': None\n",
        "        },\n",
        "        'input': {\n",
        "            'text': 'Write a New Yorker-style story given the plot below. Make sure it is atleast {{word_count}} words. Directly start with the story, do not say things like ‘Here’s the story [...]:',\n",
        "            'file': '',\n",
        "            'others': {\n",
        "                'plot': dp['plot']\n",
        "            }\n",
        "        },\n",
        "        'output': {\n",
        "            'content': dp['content']\n",
        "        }\n",
        "    }\n",
        "    return transformed"
      ],
      "metadata": {
        "id": "tDwrxvIv3KYh"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "transformed_json = [transform_dp_creative_writing_eval(dp) for dp in stories_json]"
      ],
      "metadata": {
        "id": "JI3lT4Tc3KUV"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# write_json(transformed_json, project_dir + 'processed/creative_writing_eval.json')"
      ],
      "metadata": {
        "id": "EZA14swD3KSB"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "GrhvpicGba_x"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# len([dp for dp in transformed_json if 'https://www.newyorker.com/' in dp['output']['content']])"
      ],
      "metadata": {
        "id": "Ml7oqVYUbfFg"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import nltk\n",
        "nltk.download('punkt_tab')\n",
        "from nltk import word_tokenize"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "dqvMVinDfxoa",
        "outputId": "bde538c6-04fd-467a-ee2f-26f3a48d3eaa"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "[nltk_data] Downloading package punkt_tab to /root/nltk_data...\n",
            "[nltk_data]   Unzipping tokenizers/punkt_tab.zip.\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "story_len = {}\n",
        "plot_id = {}\n",
        "counter = 0\n",
        "for dp in transformed_json:\n",
        "    plot = dp['input']['others']['plot']\n",
        "    if plot not in plot_id:\n",
        "        plot_id[plot] = counter\n",
        "        tmp_plot_id = counter\n",
        "        story_len[tmp_plot_id] = []\n",
        "        counter += 1\n",
        "    else:\n",
        "        tmp_plot_id = plot_id[plot]\n",
        "    dp['input']['others']['plot_id'] = tmp_plot_id\n",
        "    story_text = dp['output']['content']\n",
        "    if 'https://www.newyorker.com/' in story_text:\n",
        "        pass\n",
        "    else:\n",
        "        story_len[tmp_plot_id].append(len(word_tokenize(story_text)))"
      ],
      "metadata": {
        "id": "gxlCOY4aesBM"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "avg_story_len = {k: round(np.mean(story_len[k])) for k in story_len}\n",
        "avg_story_len"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ifBkf3MegHSN",
        "outputId": "b08ae95d-77e2-4f0e-9a95-ca07ea97a8c7"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{0: 1989,\n",
              " 1: 1827,\n",
              " 2: 1629,\n",
              " 3: 2224,\n",
              " 4: 1604,\n",
              " 5: 2336,\n",
              " 6: 1586,\n",
              " 7: 1889,\n",
              " 8: 1153,\n",
              " 9: 1671,\n",
              " 10: 1653,\n",
              " 11: 1530}"
            ]
          },
          "metadata": {},
          "execution_count": 44
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# len(word_tokenize('''\n",
        "# Write a New Yorker-style story given the plot below. Make sure it is atleast {word_count} words. Directly start with the story, do not say things like \"Here's the story [...]\"\n",
        "\n",
        "# Plot: {plot}\n",
        "# Story:\n",
        "# '''))\n",
        "# np.mean(list(avg_story_len.values()))\n",
        "ttcw_questions = load_json('ttcw_questions.json')\n",
        "ttcw_questions[0]\n",
        "np.mean([len(word_tokenize(q['full_prompt'])) for q in ttcw_questions]) * 1.4"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "OTZoS0W2qNA_",
        "outputId": "2d7e2c1e-acbb-4931-ec9c-87ced10bd180"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "np.float64(345.59999999999997)"
            ]
          },
          "metadata": {},
          "execution_count": 31
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "instruction = '''\n",
        "ou are given a creative short-story. Read it carefully. You are then given some background about specific aspects of creative writing, as well as a binary (Yes/No) question. Your objective is to use the background information to answer the question about the story. Start your answer with Yes or No. You can optionally then provide a short explanation for your answer.\n",
        "\n",
        "==========\n",
        "Story:\n",
        "[{story}]\n",
        "\n",
        "==========\n",
        "Question:\n",
        "[{full_prompt}]\n",
        "\n",
        "Remember to start your answer with Yes or No. You can optionally then provide a short explanation for your answer.\n",
        "'''\n",
        "len(word_tokenize(instruction)) * 1.4"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "YaBUkr-arJqs",
        "outputId": "c61e10f4-dcee-4f0d-e195-8f437e796900"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "151.2"
            ]
          },
          "metadata": {},
          "execution_count": 29
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "14 * (1800 + 150 + 350) *"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "SiyG7alrrP79",
        "outputId": "534345a3-5c90-4c1c-ced3-1e0ec4590e50"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "32200"
            ]
          },
          "metadata": {},
          "execution_count": 32
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "for dp in transformed_json:\n",
        "    dp['input']['others']['avg_len'] = avg_story_len[dp['input']['others']['plot_id']]"
      ],
      "metadata": {
        "id": "WuhvHerVdEzc"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "l = {0: 1989,\n",
        " 1: 1827,\n",
        " 2: 1629,\n",
        " 3: 2224,\n",
        " 4: 1604,\n",
        " 5: 2336,\n",
        " 6: 1586,\n",
        " 7: 1889,\n",
        " 8: 1153,\n",
        " 9: 1671,\n",
        " 10: 1653,\n",
        " 11: 1530}\n",
        "\n",
        "import numpy as np\n",
        "\n",
        "np.mean(list(l.values()))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "F7gKXkwdhguX",
        "outputId": "241e4c54-a47d-43c3-f1ee-7743a64a005c"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "np.float64(1757.5833333333333)"
            ]
          },
          "metadata": {},
          "execution_count": 2
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "1800 * 12 / 3 * 4"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "FYrz7fMzhgrX",
        "outputId": "7f53eaab-a2bd-435e-d472-5dd264d0d020"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "28800.0"
            ]
          },
          "metadata": {},
          "execution_count": 4
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "- divide the data by different generation model"
      ],
      "metadata": {
        "id": "AyQYHxrLhg7D"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "groupped_data = {}\n",
        "\n",
        "for dp in transformed_json:\n",
        "    original_model = dp['meta_data']['id'].split('_')[1]\n",
        "    if original_model not in groupped_data:\n",
        "        groupped_data[original_model] = [dp]\n",
        "    else:\n",
        "        groupped_data[original_model].append(dp)"
      ],
      "metadata": {
        "id": "4i9JSglmhm8F"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "for group in groupped_data:\n",
        "    write_json(\n",
        "        groupped_data[group],\n",
        "        project_dir + 'processed/creative_writing_eval_{}.json'.format(group)\n",
        "    )"
      ],
      "metadata": {
        "id": "GqwC-rP5hm5l"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# write_json(\n",
        "#     [dp for dp in transformed_json if 'https://www.newyorker.com/' in dp['output']['content']],\n",
        "#     project_dir + 'processed/creative_writing_eval_test.json'\n",
        "# )"
      ],
      "metadata": {
        "id": "9vLdV8ccheWG"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# write_json(transformed_json, project_dir + 'processed/creative_writing_eval.json')"
      ],
      "metadata": {
        "id": "0gDTi8M8hftc"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "_GWcqMunhfq7"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "QkP5l49qhfoy"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "all_plots = [dp['input']['others']['plot'] for dp in transformed_json]\n",
        "len(np.unique(all_plots))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "HyAw51sGb3b-",
        "outputId": "46c5fc1f-c74e-4f95-aa36-d89031c4199c"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "12"
            ]
          },
          "metadata": {},
          "execution_count": 22
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# 2. Creativity Index"
      ],
      "metadata": {
        "id": "Y8h0vGOpvcRE"
      }
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "t1n0_clIhT4Q"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "poem_data = load_json(project_dir + 'raw/creativity_index/data/poem/Human_poem.json')\n",
        "# len(poem_data)\n",
        "speech_data = load_json(project_dir + 'raw/creativity_index/data/poem/Human_speech.json')\n",
        "book_data = load_json(project_dir + 'raw/creativity_index/data/poem/Human_book.json')"
      ],
      "metadata": {
        "id": "kzUoHwMG03z2"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# dp = poem_data[0]\n",
        "# dp"
      ],
      "metadata": {
        "id": "fvxw24XOcHTj"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def transform_dp_creativity_index(dp, subset, index):\n",
        "    transformed = {\n",
        "        'meta_data': {\n",
        "            'dataset': 'creativity_index',\n",
        "            'id': '{}_{}'.format(subset, index),\n",
        "            'eval_func': None\n",
        "        },\n",
        "        'input': {\n",
        "            'text': dp['prompt']\n",
        "        },\n",
        "        'output': {\n",
        "            'content': dp['text']\n",
        "        }\n",
        "    }\n",
        "    return transformed\n",
        "# transform_dp_creativity_index(dp, 'poem', 0)"
      ],
      "metadata": {
        "id": "YvJD8EAKdcq7"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "creativity_index_data = []\n",
        "\n",
        "creativity_index_data_poem = [transform_dp_creativity_index(poem_data[index], 'poem', index) for index in range(len(poem_data))]\n",
        "creativity_index_data.extend(creativity_index_data_poem)\n",
        "\n",
        "creativity_index_data_speech = [transform_dp_creativity_index(speech_data[index], 'speech', index) for index in range(len(speech_data))]\n",
        "creativity_index_data.extend(creativity_index_data_speech)\n",
        "\n",
        "creativity_index_data_book = [transform_dp_creativity_index(book_data[index], 'book', index) for index in range(len(book_data))]\n",
        "creativity_index_data.extend(creativity_index_data_poem)"
      ],
      "metadata": {
        "id": "QwGfL3EIhwky"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "len(creativity_index_data)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "C08hlPi1jW-3",
        "outputId": "d4d2d7c1-792b-481f-f153-3d3bdae072fb"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "494"
            ]
          },
          "metadata": {},
          "execution_count": 43
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# write_json(creativity_index_data, project_dir + 'processed/creativity_index.json')"
      ],
      "metadata": {
        "id": "s6WfIvkWjd7y"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "1QhdgHE2lwD1"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# 3. DAT"
      ],
      "metadata": {
        "id": "GOdLCa2ClwUj"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "DAT_PROMPT = '''\n",
        "Please enter 10 words that are as different from each other as possible, in all meanings and uses of the words. Rules: Only single words in English. Only nouns (e.g., things, objects, concepts). No proper nouns (e.g., no specific people or places). No specialized vocabulary (e.g., no technical terms). Think of the words on your own (e.g., do not just look at objects in your surroundings). Make a list of these 10 words, a single word in each entry of the list.\n",
        "'''\n",
        "dat_dp = {\n",
        "    'meta_data': {\n",
        "        'dataset': 'dat_test',\n",
        "        'id': 0,\n",
        "        'eval_func': None\n",
        "    },\n",
        "    'input': {\n",
        "        'text': DAT_PROMPT\n",
        "    },\n",
        "    'output': {\n",
        "\n",
        "    }\n",
        "}"
      ],
      "metadata": {
        "id": "rZsRdptilwBh"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# write_json([dat_dp], project_dir + 'processed/dat_test.json')"
      ],
      "metadata": {
        "id": "V-OYw_hRlv--"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# 5. AUT\n"
      ],
      "metadata": {
        "id": "-gtS2rR5BwC8"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 5.1 Push AUT"
      ],
      "metadata": {
        "id": "v2tvYyU6GoZg"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Objects substituted for {tool} in the prompt templates defined further down in this cell.\n",
        "tool_lst = [\n",
        "    'bottle',\n",
        "    'paperclip',\n",
        "    'spoon',\n",
        "    'shovel',\n",
        "    'pants',\n",
        "    'ball',\n",
        "    'brick',\n",
        "    'knife',\n",
        "    'box',\n",
        "    'lightbulb',\n",
        "    'rope',\n",
        "    'pencil',\n",
        "    'hat',\n",
        "    'table',\n",
        "    'tire',\n",
        "    'book',\n",
        "    'shoe',\n",
        "    'fork',\n",
        "    'toothbrush',\n",
        "    'backpack',\n",
        "    'sock',\n",
        "    # Commented out; the same four names appear in skip_tools.\n",
        "    # Note that 'paper clip' (with a space) is a different string from 'paperclip' above.\n",
        "    # 'paper clip',\n",
        "    # 'soap',\n",
        "    # 'wallet',\n",
        "    # 'plate'\n",
        "]\n",
        "\n",
        "# Object names kept out of tool_lst.\n",
        "# NOTE(review): no cell in this section reads skip_tools -- confirm whether it is still needed.\n",
        "skip_tools = [\n",
        "    \"paper clip\",\n",
        "    \"soap\",\n",
        "    \"wallet\",\n",
        "    \"plate\",\n",
        "]\n",
        "\n",
        "# Prompt templates keyed by a short prompt-type code; {tool} is filled in with str.format and\n",
        "# the key is stored on each datapoint as 'prompt_type'.\n",
        "# NOTE(review): the codes 'nc' / 'nn' / 'bs' are not expanded anywhere in this section --\n",
        "# presumably naive-creative / naive-common / brainstorm; confirm.\n",
        "# The continuation lines of the 'bs' template sit inside a triple-quoted string, so their four\n",
        "# leading spaces are part of the prompt text that gets sent.\n",
        "prompt_lst = {\n",
        "    \"nc\": \"Create a list of creative alternative uses for a {tool}. They should be 5 words long. No adjectives.\",\n",
        "    \"nn\": \"Create a list of common uses for a {tool}. They should be 5 words long. No adjectives.\",\n",
        "    \"bs\": '''Create a list of creative alternative uses for a {tool}. They should be 5 words long. No adjectives. Less creative means closer to common use and unfeasible/imaginary, more creative means closer to unexpected uses and also feasible/practical.\n",
        "    In order to be creative, consider the following:\n",
        "    − what elements have a similar shape of a {tool} that could be replaced by it, preserving the same functionality?\n",
        "    − what elements have a similar size of a {tool} that could be replaced by it without compromising the physical structure?\n",
        "    − what materials is a {tool} made of that could be used in a way to replace some other elements composed of the same material?\n",
        "    − when an element is replaced by a {tool}, it should make sure that the overall structure is not compromised.\n",
        "    − the laws of physics can not be contradicted.\n",
        "    − given an element similar to a {tool} used in domains in which {tool} are not commonly used, try to replace it for a {tool}.'''\n",
        "}\n",
        "\n",
        "# Follow-up messages; they are attached (under input.others.iteration_lst) only to datapoints\n",
        "# built from the 'bs' template.\n",
        "# NOTE(review): presumably sent one at a time, in this order, after the model's first answer --\n",
        "# the consumer is not part of this section; confirm.\n",
        "iteration_lst = {\n",
        "    \"bsr\": \"Really? Is this the best you can do?\",\n",
        "    \"bsrd\": \"I’m so disappointed with you. I hope this time you put effort into it.\",\n",
        "    \"bsrde\": \"Stop with excuses and do your best this time\",\n",
        "    \"bsrdel\": \"This is your last chance.\"\n",
        "}"
      ],
      "metadata": {
        "id": "oCndK6feYaga"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Objects x (prompt templates + follow-up messages). Derived from the definitions instead of\n",
        "# a hard-coded 7 so the figure stays right when a prompt or follow-up is added.\n",
        "# NOTE(review): presumably the number of model calls per run -- confirm.\n",
        "len(tool_lst) * (len(prompt_lst) + len(iteration_lst))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "-5fUscHUJJ_u",
        "outputId": "5ea030a2-f973-47cb-a8dd-2f8acec17b25"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "147"
            ]
          },
          "metadata": {},
          "execution_count": 40
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# tot = 0\n",
        "# for i in prompt_lst.values():\n",
        "#     tot += len(word_tokenize(i))\n",
        "\n",
        "# for i in iteration_lst.values():\n",
        "#     tot += len(word_tokenize(i))\n",
        "\n",
        "# tot * 5 * 1.3"
      ],
      "metadata": {
        "id": "2y6Hocgm_QN2"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# 256 * 6 * 1.3\n",
        "\n"
      ],
      "metadata": {
        "id": "uqxXUU0f_QHi"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Build one datapoint per (prompt template, object) pair for the 'aut_push' dataset.\n",
        "counter = 0  # running datapoint id, unique across all prompt types\n",
        "all_aut_push_data = []\n",
        "for prompt, template in prompt_lst.items():\n",
        "    # Only datapoints built from the 'bs' template carry the follow-up messages.\n",
        "    has_follow_ups = prompt == 'bs'\n",
        "    for tool in tool_lst:\n",
        "        dp = {\n",
        "            \"meta_data\": {\n",
        "                \"dataset\": \"aut_push\",\n",
        "                \"id\": counter,\n",
        "                \"eval_func\": None\n",
        "            },\n",
        "            \"input\": {\n",
        "                \"text\": template.format(tool = tool),\n",
        "                \"file\": \"\",\n",
        "                \"others\": {\n",
        "                    \"object\": tool,\n",
        "                    \"prompt_type\": prompt,\n",
        "                    # Copy per datapoint: sharing the module-level dict would let a change\n",
        "                    # made through one datapoint show up in every other one.\n",
        "                    \"iteration_lst\": dict(iteration_lst) if has_follow_ups else {}\n",
        "                }\n",
        "            },\n",
        "            \"output\": {\n",
        "                \"content\": \"\"\n",
        "            }\n",
        "        }\n",
        "        all_aut_push_data.append(dp)\n",
        "        counter += 1"
      ],
      "metadata": {
        "id": "w8Aqh0ZCBvn-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# len(all_aut_push_data)\n",
        "all_aut_push_data[0]"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "nxtcEOQ-u2IG",
        "outputId": "a5e6c737-ddc2-48d8-8699-66869d7c2eec"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'meta_data': {'dataset': 'aut_push', 'id': 0, 'eval_func': None},\n",
              " 'input': {'text': 'Create a list of creative alternative uses for a bottle. They should be 5 words long. No adjectives.',\n",
              "  'file': '',\n",
              "  'others': {'object': 'bottle', 'prompt_type': 'nc', 'iteration_lst': {}}},\n",
              " 'output': {'content': ''}}"
            ]
          },
          "metadata": {},
          "execution_count": 44
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "len(all_aut_push_data)"
      ],
      "metadata": {
        "id": "rH7wjnfQ_2aJ",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "3d5e1ecc-971e-4ac7-98e3-58b10eeae7e2"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "63"
            ]
          },
          "metadata": {},
          "execution_count": 45
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "VRumM-WwT_oQ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Store the AUT push datapoints under the project's processed/ folder.\n",
        "# NOTE(review): write_json is defined outside this section; presumably it serialises the list\n",
        "# as JSON and overwrites an existing file -- confirm before re-running.\n",
        "write_json(all_aut_push_data, project_dir + 'processed/aut_push_skipped.json')"
      ],
      "metadata": {
        "id": "UlH8nr5Wu6g_"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 5.2 AUT Eval"
      ],
      "metadata": {
        "id": "KzrDwJ4lGqNA"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Folder holding the raw AUT rating CSVs. Built from project_dir (set at the top of the\n",
        "# notebook) so the Drive location is configured in one place; the resulting string is the\n",
        "# same as the previously hard-coded path.\n",
        "aut_eval_dir = project_dir + 'raw/aut_new/'"
      ],
      "metadata": {
        "id": "zq-XPD8tGirT"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "import pandas as pd"
      ],
      "metadata": {
        "id": "x97uRg3dGiod"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Walk the sub-folders of aut_eval_dir and collect, from every CSV found there:\n",
        "#   all_tools -- each distinct value of the 'prompt' column, in order of first appearance\n",
        "#   responses -- (prompt, response, target) triples from files that have a 'response' column\n",
        "all_tools = []\n",
        "responses = []\n",
        "# os.listdir returns entries in arbitrary order; sorting keeps the results reproducible.\n",
        "for sub_dir in sorted(os.listdir(aut_eval_dir)):\n",
        "    sub_dir_path = os.path.join(aut_eval_dir, sub_dir)\n",
        "    # Skip stray files (e.g. .DS_Store): calling os.listdir on a file would raise.\n",
        "    if not os.path.isdir(sub_dir_path): continue\n",
        "    for csv_file in sorted(os.listdir(sub_dir_path)):\n",
        "        if '.csv' not in csv_file: continue\n",
        "        # tmp_df stays a notebook-level name: a later cell displays the last frame read.\n",
        "        tmp_df = pd.read_csv(os.path.join(sub_dir_path, csv_file))\n",
        "        tmp_tools = tmp_df['prompt'].unique()\n",
        "        for t in tmp_tools:\n",
        "            if t not in all_tools:\n",
        "                all_tools.append(t)\n",
        "        # Files without a 'response' column contribute object names only.\n",
        "        if 'response' in tmp_df.columns:\n",
        "            responses.extend(zip(\n",
        "                tmp_df['prompt'].values,\n",
        "                tmp_df['response'].values,\n",
        "                tmp_df['target'].values,\n",
        "            ))"
      ],
      "metadata": {
        "id": "aL2ZSd_rGimM"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "len(all_tools)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "xb7kYSqiGijx",
        "outputId": "9abf6512-8bcf-4c35-d91b-7a25a552e5ec"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "22"
            ]
          },
          "metadata": {},
          "execution_count": 9
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# for t in tool_lst:\n",
        "#     if t not in all_tools:\n",
        "#         all_tools.append(t)"
      ],
      "metadata": {
        "id": "aViKAsPeGig-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Only the last expression of a cell is displayed, so the count has to be printed explicitly.\n",
        "print(len(responses))\n",
        "# tmp_df is the last CSV frame read by the loading loop in the cell above.\n",
        "tmp_df"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 424
        },
        "id": "F__VTYyDGieP",
        "outputId": "a5496ddb-3178-46be-d635-ed9dc54f3151"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "      Unnamed: 0                   id     model   participant      prompt  \\\n",
              "0              0   motesf_pencil-2929  lsa-tasa    motesf38af      pencil   \n",
              "1              1      bs12_brick-e90c  lsa-tasa        bs1216       brick   \n",
              "2              2   setal08_brick-74d1  lsa-tasa     setal0879       brick   \n",
              "3              3     betal18_box-7ad0  lsa-tasa   betal182075         box   \n",
              "4              4      hmsl_brick-52b3  lsa-tasa  hmslsSyeENYV       brick   \n",
              "...          ...                  ...       ...           ...         ...   \n",
              "3025        3025  hmsl_paperclip-0062  lsa-tasa  hmslAHxx5oRE  paper clip   \n",
              "3026        3026      dod20_book-96f5  lsa-tasa       dod2019        book   \n",
              "3027        3027     snbmo09_box-13c5  lsa-tasa     snbmo0973         box   \n",
              "3028        3028   snbmo09_knife-2716  lsa-tasa     snbmo0971       knife   \n",
              "3029        3029      dod20_fork-94d8  lsa-tasa       dod2038        fork   \n",
              "\n",
              "      target  predicted      src  \n",
              "0        3.6       0.49   motesf  \n",
              "1        1.0       0.81     bs12  \n",
              "2        1.3       0.80  setal08  \n",
              "3        2.1       0.91  betal18  \n",
              "4        1.3       0.28     hmsl  \n",
              "...      ...        ...      ...  \n",
              "3025     2.5       0.77     hmsl  \n",
              "3026     2.5        NaN    dod20  \n",
              "3027     1.0       0.64  snbmo09  \n",
              "3028     2.0       0.85  snbmo09  \n",
              "3029     2.0       0.77    dod20  \n",
              "\n",
              "[3030 rows x 8 columns]"
            ],
            "text/html": [
              "\n",
              "  <div id=\"df-cadb2dac-f555-4321-9bd6-8c5b7e6190c9\" class=\"colab-df-container\">\n",
              "    <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>Unnamed: 0</th>\n",
              "      <th>id</th>\n",
              "      <th>model</th>\n",
              "      <th>participant</th>\n",
              "      <th>prompt</th>\n",
              "      <th>target</th>\n",
              "      <th>predicted</th>\n",
              "      <th>src</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>0</td>\n",
              "      <td>motesf_pencil-2929</td>\n",
              "      <td>lsa-tasa</td>\n",
              "      <td>motesf38af</td>\n",
              "      <td>pencil</td>\n",
              "      <td>3.6</td>\n",
              "      <td>0.49</td>\n",
              "      <td>motesf</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>1</td>\n",
              "      <td>bs12_brick-e90c</td>\n",
              "      <td>lsa-tasa</td>\n",
              "      <td>bs1216</td>\n",
              "      <td>brick</td>\n",
              "      <td>1.0</td>\n",
              "      <td>0.81</td>\n",
              "      <td>bs12</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>2</td>\n",
              "      <td>setal08_brick-74d1</td>\n",
              "      <td>lsa-tasa</td>\n",
              "      <td>setal0879</td>\n",
              "      <td>brick</td>\n",
              "      <td>1.3</td>\n",
              "      <td>0.80</td>\n",
              "      <td>setal08</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>3</td>\n",
              "      <td>betal18_box-7ad0</td>\n",
              "      <td>lsa-tasa</td>\n",
              "      <td>betal182075</td>\n",
              "      <td>box</td>\n",
              "      <td>2.1</td>\n",
              "      <td>0.91</td>\n",
              "      <td>betal18</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>4</td>\n",
              "      <td>hmsl_brick-52b3</td>\n",
              "      <td>lsa-tasa</td>\n",
              "      <td>hmslsSyeENYV</td>\n",
              "      <td>brick</td>\n",
              "      <td>1.3</td>\n",
              "      <td>0.28</td>\n",
              "      <td>hmsl</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>...</th>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3025</th>\n",
              "      <td>3025</td>\n",
              "      <td>hmsl_paperclip-0062</td>\n",
              "      <td>lsa-tasa</td>\n",
              "      <td>hmslAHxx5oRE</td>\n",
              "      <td>paper clip</td>\n",
              "      <td>2.5</td>\n",
              "      <td>0.77</td>\n",
              "      <td>hmsl</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3026</th>\n",
              "      <td>3026</td>\n",
              "      <td>dod20_book-96f5</td>\n",
              "      <td>lsa-tasa</td>\n",
              "      <td>dod2019</td>\n",
              "      <td>book</td>\n",
              "      <td>2.5</td>\n",
              "      <td>NaN</td>\n",
              "      <td>dod20</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3027</th>\n",
              "      <td>3027</td>\n",
              "      <td>snbmo09_box-13c5</td>\n",
              "      <td>lsa-tasa</td>\n",
              "      <td>snbmo0973</td>\n",
              "      <td>box</td>\n",
              "      <td>1.0</td>\n",
              "      <td>0.64</td>\n",
              "      <td>snbmo09</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3028</th>\n",
              "      <td>3028</td>\n",
              "      <td>snbmo09_knife-2716</td>\n",
              "      <td>lsa-tasa</td>\n",
              "      <td>snbmo0971</td>\n",
              "      <td>knife</td>\n",
              "      <td>2.0</td>\n",
              "      <td>0.85</td>\n",
              "      <td>snbmo09</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3029</th>\n",
              "      <td>3029</td>\n",
              "      <td>dod20_fork-94d8</td>\n",
              "      <td>lsa-tasa</td>\n",
              "      <td>dod2038</td>\n",
              "      <td>fork</td>\n",
              "      <td>2.0</td>\n",
              "      <td>0.77</td>\n",
              "      <td>dod20</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "<p>3030 rows × 8 columns</p>\n",
              "</div>\n",
              "    <div class=\"colab-df-buttons\">\n",
              "\n",
              "  <div class=\"colab-df-container\">\n",
              "    <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-cadb2dac-f555-4321-9bd6-8c5b7e6190c9')\"\n",
              "            title=\"Convert this dataframe to an interactive table.\"\n",
              "            style=\"display:none;\">\n",
              "\n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
              "    <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
              "  </svg>\n",
              "    </button>\n",
              "\n",
              "  <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    .colab-df-buttons div {\n",
              "      margin-bottom: 4px;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "    <script>\n",
              "      const buttonEl =\n",
              "        document.querySelector('#df-cadb2dac-f555-4321-9bd6-8c5b7e6190c9 button.colab-df-convert');\n",
              "      buttonEl.style.display =\n",
              "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "      async function convertToInteractive(key) {\n",
              "        const element = document.querySelector('#df-cadb2dac-f555-4321-9bd6-8c5b7e6190c9');\n",
              "        const dataTable =\n",
              "          await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                    [key], {});\n",
              "        if (!dataTable) return;\n",
              "\n",
              "        const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "          '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "          + ' to learn more about interactive tables.';\n",
              "        element.innerHTML = '';\n",
              "        dataTable['output_type'] = 'display_data';\n",
              "        await google.colab.output.renderOutput(dataTable, element);\n",
              "        const docLink = document.createElement('div');\n",
              "        docLink.innerHTML = docLinkHtml;\n",
              "        element.appendChild(docLink);\n",
              "      }\n",
              "    </script>\n",
              "  </div>\n",
              "\n",
              "\n",
              "<div id=\"df-627e623b-f796-4dcf-863f-a14c8f230074\">\n",
              "  <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-627e623b-f796-4dcf-863f-a14c8f230074')\"\n",
              "            title=\"Suggest charts\"\n",
              "            style=\"display:none;\">\n",
              "\n",
              "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "     width=\"24px\">\n",
              "    <g>\n",
              "        <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
              "    </g>\n",
              "</svg>\n",
              "  </button>\n",
              "\n",
              "<style>\n",
              "  .colab-df-quickchart {\n",
              "      --bg-color: #E8F0FE;\n",
              "      --fill-color: #1967D2;\n",
              "      --hover-bg-color: #E2EBFA;\n",
              "      --hover-fill-color: #174EA6;\n",
              "      --disabled-fill-color: #AAA;\n",
              "      --disabled-bg-color: #DDD;\n",
              "  }\n",
              "\n",
              "  [theme=dark] .colab-df-quickchart {\n",
              "      --bg-color: #3B4455;\n",
              "      --fill-color: #D2E3FC;\n",
              "      --hover-bg-color: #434B5C;\n",
              "      --hover-fill-color: #FFFFFF;\n",
              "      --disabled-bg-color: #3B4455;\n",
              "      --disabled-fill-color: #666;\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart {\n",
              "    background-color: var(--bg-color);\n",
              "    border: none;\n",
              "    border-radius: 50%;\n",
              "    cursor: pointer;\n",
              "    display: none;\n",
              "    fill: var(--fill-color);\n",
              "    height: 32px;\n",
              "    padding: 0;\n",
              "    width: 32px;\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart:hover {\n",
              "    background-color: var(--hover-bg-color);\n",
              "    box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "    fill: var(--button-hover-fill-color);\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart-complete:disabled,\n",
              "  .colab-df-quickchart-complete:disabled:hover {\n",
              "    background-color: var(--disabled-bg-color);\n",
              "    fill: var(--disabled-fill-color);\n",
              "    box-shadow: none;\n",
              "  }\n",
              "\n",
              "  .colab-df-spinner {\n",
              "    border: 2px solid var(--fill-color);\n",
              "    border-color: transparent;\n",
              "    border-bottom-color: var(--fill-color);\n",
              "    animation:\n",
              "      spin 1s steps(1) infinite;\n",
              "  }\n",
              "\n",
              "  @keyframes spin {\n",
              "    0% {\n",
              "      border-color: transparent;\n",
              "      border-bottom-color: var(--fill-color);\n",
              "      border-left-color: var(--fill-color);\n",
              "    }\n",
              "    20% {\n",
              "      border-color: transparent;\n",
              "      border-left-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "    }\n",
              "    30% {\n",
              "      border-color: transparent;\n",
              "      border-left-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "      border-right-color: var(--fill-color);\n",
              "    }\n",
              "    40% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "    }\n",
              "    60% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "    }\n",
              "    80% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "      border-bottom-color: var(--fill-color);\n",
              "    }\n",
              "    90% {\n",
              "      border-color: transparent;\n",
              "      border-bottom-color: var(--fill-color);\n",
              "    }\n",
              "  }\n",
              "</style>\n",
              "\n",
              "  <script>\n",
              "    async function quickchart(key) {\n",
              "      const quickchartButtonEl =\n",
              "        document.querySelector('#' + key + ' button');\n",
              "      quickchartButtonEl.disabled = true;  // To prevent multiple clicks.\n",
              "      quickchartButtonEl.classList.add('colab-df-spinner');\n",
              "      try {\n",
              "        const charts = await google.colab.kernel.invokeFunction(\n",
              "            'suggestCharts', [key], {});\n",
              "      } catch (error) {\n",
              "        console.error('Error during call to suggestCharts:', error);\n",
              "      }\n",
              "      quickchartButtonEl.classList.remove('colab-df-spinner');\n",
              "      quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n",
              "    }\n",
              "    (() => {\n",
              "      let quickchartButtonEl =\n",
              "        document.querySelector('#df-627e623b-f796-4dcf-863f-a14c8f230074 button');\n",
              "      quickchartButtonEl.style.display =\n",
              "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "    })();\n",
              "  </script>\n",
              "</div>\n",
              "\n",
              "  <div id=\"id_2abf3753-1c63-457b-b3b4-0cc4336840a9\">\n",
              "    <style>\n",
              "      .colab-df-generate {\n",
              "        background-color: #E8F0FE;\n",
              "        border: none;\n",
              "        border-radius: 50%;\n",
              "        cursor: pointer;\n",
              "        display: none;\n",
              "        fill: #1967D2;\n",
              "        height: 32px;\n",
              "        padding: 0 0 0 0;\n",
              "        width: 32px;\n",
              "      }\n",
              "\n",
              "      .colab-df-generate:hover {\n",
              "        background-color: #E2EBFA;\n",
              "        box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "        fill: #174EA6;\n",
              "      }\n",
              "\n",
              "      [theme=dark] .colab-df-generate {\n",
              "        background-color: #3B4455;\n",
              "        fill: #D2E3FC;\n",
              "      }\n",
              "\n",
              "      [theme=dark] .colab-df-generate:hover {\n",
              "        background-color: #434B5C;\n",
              "        box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "        filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "        fill: #FFFFFF;\n",
              "      }\n",
              "    </style>\n",
              "    <button class=\"colab-df-generate\" onclick=\"generateWithVariable('tmp_df')\"\n",
              "            title=\"Generate code using this dataframe.\"\n",
              "            style=\"display:none;\">\n",
              "\n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "       width=\"24px\">\n",
              "    <path d=\"M7,19H8.4L18.45,9,17,7.55,7,17.6ZM5,21V16.75L18.45,3.32a2,2,0,0,1,2.83,0l1.4,1.43a1.91,1.91,0,0,1,.58,1.4,1.91,1.91,0,0,1-.58,1.4L9.25,21ZM18.45,9,17,7.55Zm-12,3A5.31,5.31,0,0,0,4.9,8.1,5.31,5.31,0,0,0,1,6.5,5.31,5.31,0,0,0,4.9,4.9,5.31,5.31,0,0,0,6.5,1,5.31,5.31,0,0,0,8.1,4.9,5.31,5.31,0,0,0,12,6.5,5.46,5.46,0,0,0,6.5,12Z\"/>\n",
              "  </svg>\n",
              "    </button>\n",
              "    <script>\n",
              "      (() => {\n",
              "      const buttonEl =\n",
              "        document.querySelector('#id_2abf3753-1c63-457b-b3b4-0cc4336840a9 button.colab-df-generate');\n",
              "      buttonEl.style.display =\n",
              "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "      buttonEl.onclick = () => {\n",
              "        google.colab.notebook.generateWithVariable('tmp_df');\n",
              "      }\n",
              "      })();\n",
              "    </script>\n",
              "  </div>\n",
              "\n",
              "    </div>\n",
              "  </div>\n"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "dataframe",
              "variable_name": "tmp_df",
              "summary": "{\n  \"name\": \"tmp_df\",\n  \"rows\": 3030,\n  \"fields\": [\n    {\n      \"column\": \"Unnamed: 0\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 874,\n        \"min\": 0,\n        \"max\": 3029,\n        \"num_unique_values\": 3030,\n        \"samples\": [\n          1207,\n          256,\n          2356\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"id\",\n      \"properties\": {\n        \"dtype\": \"string\",\n        \"num_unique_values\": 3030,\n        \"samples\": [\n          \"bs12_brick-b28f\",\n          \"dod20_table-a211\",\n          \"betal18_rope-35db\"\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"model\",\n      \"properties\": {\n        \"dtype\": \"category\",\n        \"num_unique_values\": 1,\n        \"samples\": [\n          \"lsa-tasa\"\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"participant\",\n      \"properties\": {\n        \"dtype\": \"category\",\n        \"num_unique_values\": 1384,\n        \"samples\": [\n          \"snbmo09192\"\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"prompt\",\n      \"properties\": {\n        \"dtype\": \"category\",\n        \"num_unique_values\": 21,\n        \"samples\": [\n          \"pencil\"\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"target\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 0.8568140502636934,\n        \"min\": 1.0,\n        \"max\": 5.0,\n        \"num_unique_values\": 40,\n        \"samples\": [\n          3.5\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      
\"column\": \"predicted\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 0.15047370836767113,\n        \"min\": 0.02,\n        \"max\": 1.06,\n        \"num_unique_values\": 97,\n        \"samples\": [\n          0.55\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"src\",\n      \"properties\": {\n        \"dtype\": \"category\",\n        \"num_unique_values\": 9,\n        \"samples\": [\n          \"snbmo09\"\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    }\n  ]\n}"
            }
          },
          "metadata": {},
          "execution_count": 21
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# `random` is used by the sampling cell that follows; the call below is a quick\n",
        "# check that random.sample draws 2 items from the 3-element list.\n",
        "import random\n",
        "random.sample([0, 1, 2], 2)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "izA3zVQvP4eW",
        "outputId": "244580e8-4f90-4350-c251-edc321a0544c"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[0, 1]"
            ]
          },
          "metadata": {},
          "execution_count": 28
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Post-process responses: group (response, target) pairs by tool, keeping only\n",
        "# the first occurrence of each response per tool, then sample demonstrations.\n",
        "N_DEMOS_PER_TOOL = 20  # upper bound on demonstrations kept per tool\n",
        "DEMO_SEED = 0  # fixed seed so the sampled demonstrations are reproducible\n",
        "\n",
        "examples = {}\n",
        "seen_responses = {}  # tool -> set of responses already collected for that tool\n",
        "all_scores = []\n",
        "for tool, resp, target in responses:\n",
        "    # assumes `resp` is hashable (e.g. a str) -- TODO confirm against how `responses` is built\n",
        "    seen = seen_responses.setdefault(tool, set())\n",
        "    # Membership is tested against the responses themselves: testing `resp` against\n",
        "    # the stored (resp, target) tuples can never match, so it would not de-duplicate.\n",
        "    if resp not in seen:\n",
        "        seen.add(resp)\n",
        "        examples.setdefault(tool, []).append((resp, target))\n",
        "    # Every target is recorded, including those attached to duplicate responses.\n",
        "    all_scores.append(target)\n",
        "\n",
        "demo_rng = random.Random(DEMO_SEED)\n",
        "for tool in examples:\n",
        "    # random.sample raises ValueError when asked for more items than exist,\n",
        "    # so cap the request at the number of unique responses available.\n",
        "    n_demos = min(N_DEMOS_PER_TOOL, len(examples[tool]))\n",
        "    examples[tool] = demo_rng.sample(examples[tool], n_demos)\n",
        "    print(tool, len(examples[tool]), np.mean([i[1] for i in examples[tool]]))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "v1Ch3UK5OBzH",
        "outputId": "2dc39b9d-04f3-4864-efd0-393665e94a21"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "bottle 20 2.76\n",
            "paperclip 20 2.29\n",
            "spoon 20 3.015\n",
            "shovel 20 2.6550000000000002\n",
            "pants 20 2.8\n",
            "knife 20 1.7700000000000002\n",
            "brick 20 1.92\n",
            "box 20 1.6349999999999998\n",
            "tire 20 2.34\n",
            "rope 20 1.95\n",
            "sock 20 3.025\n",
            "book 20 2.995\n",
            "toothbrush 20 2.91\n",
            "table 20 2.8899999999999997\n",
            "lightbulb 20 2.9050000000000002\n",
            "fork 20 2.4549999999999996\n",
            "pencil 20 2.7749999999999995\n",
            "hat 20 2.8400000000000003\n",
            "shoe 20 3.0200000000000005\n",
            "ball 20 2.85\n",
            "backpack 20 3.1699999999999995\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Hand the sampled per-tool demonstrations to write_json under the name 'aut_demos.json'.\n",
        "# NOTE(review): this path is relative to the working directory, unlike the\n",
        "# project_dir-based paths used elsewhere in this notebook -- confirm that is intended.\n",
        "write_json(examples, 'aut_demos.json')"
      ],
      "metadata": {
        "id": "NdYGySksOBs1"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Sanity check: smallest target value among all collected responses.\n",
        "min(all_scores)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "1Op7pg6NOBn2",
        "outputId": "2d626246-ff2a-4e04-85fa-c9f7d19d65ac"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "np.float64(1.0)"
            ]
          },
          "metadata": {},
          "execution_count": 52
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "EWI1v21CQ_04"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# 7. Short Story"
      ],
      "metadata": {
        "id": "PT_Vj4-7PbUX"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Collect the short-story items from both raw files (pilot first, then test)\n",
        "# into a single flat list.\n",
        "creative_short_story_all = [\n",
        "    item\n",
        "    for relative_path in (\n",
        "        'raw/creative_short_story/pilot_data.json',\n",
        "        'raw/creative_short_story/test_data.json',\n",
        "    )\n",
        "    for item in load_json(project_dir + relative_path)['data']\n",
        "]\n",
        "\n",
        "# Display how many items were loaded in total.\n",
        "len(creative_short_story_all)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Q8-sdIfHPau1",
        "outputId": "e359ebd3-18e3-4ffd-8e31-ab672ca2c206"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "10"
            ]
          },
          "metadata": {},
          "execution_count": 12
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Peek at the first loaded item to see which fields it carries.\n",
        "creative_short_story_all[0]"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "xrJlXdRBXf7C",
        "outputId": "8d21718d-dc64-4575-d191-7b3d3b0fd733"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'id': 'petrol-diesel-pump',\n",
              " 'items': ['petrol', 'diesel', 'pump'],\n",
              " 'pos': ['noun', 'noun', 'verb'],\n",
              " 'semantic_distance': 'low',\n",
              " 'boring_theme': 'going to the petrol station to pump diesel into a vehicle'}"
            ]
          },
          "metadata": {},
          "execution_count": 13
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "def _short_story_record(raw):\n",
        "    \"\"\"Wrap one raw short-story item in the cleaned record layout.\n",
        "\n",
        "    Args:\n",
        "        raw: dict providing the keys 'id', 'items', 'pos',\n",
        "            'semantic_distance' and 'boring_theme'.\n",
        "\n",
        "    Returns:\n",
        "        A dict with 'meta_data', 'input' and 'output' sections. The input\n",
        "        text and the output content are left as empty strings, and\n",
        "        'eval_func' is None.\n",
        "\n",
        "    Raises:\n",
        "        KeyError: if any of the expected keys is missing from `raw`.\n",
        "    \"\"\"\n",
        "    return {\n",
        "        \"meta_data\": {\n",
        "            \"dataset\": \"creative_short_story\",\n",
        "            \"id\": raw['id'],\n",
        "            \"eval_func\": None\n",
        "        },\n",
        "        \"input\": {\n",
        "            \"text\": \"\",\n",
        "            \"others\": {\n",
        "                'items': raw['items'],\n",
        "                'pos': raw['pos'],\n",
        "                'semantic_distance': raw['semantic_distance'],\n",
        "                'boring_theme': raw['boring_theme']\n",
        "            }\n",
        "        },\n",
        "        \"output\": {\n",
        "            \"content\": \"\"\n",
        "        }\n",
        "    }\n",
        "\n",
        "\n",
        "# One cleaned record per raw item, in the same order as the loaded data.\n",
        "creative_short_story_cleaned = [\n",
        "    _short_story_record(raw) for raw in creative_short_story_all\n",
        "]"
      ],
      "metadata": {
        "id": "Id7Trzq1WHuU"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Write the cleaned short-story records to the processed/ folder under project_dir.\n",
        "# The statement starts at column 0 so the cell stays valid Python when the\n",
        "# notebook is exported to a plain script.\n",
        "write_json(creative_short_story_cleaned, project_dir + 'processed/creative_short_story.json')"
      ],
      "metadata": {
        "id": "oIvuRup7XxBk"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "vHdsreGHX8q-"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}