{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Q39CIF2h3RYB"
   },
   "source": [
    "## **Setup**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "QSmQIVIIarLz"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import openai\n",
    "import signal\n",
    "import tqdm.notebook as tqdm\n",
    "import random\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import json\n",
    "import pdb\n",
    "import torch\n",
    "from llm import timeout, lm\n",
    "from cp_utils import temperature_scaling, get_non_conformity_score, get_llm_preds, get_top_logprobs\n",
    "from prompt_init import get_init_prompt_chat, get_reason_prompt, get_pred_prompt\n",
    "from utils import process_mc_raw, process_mc_full, remove_last_line, get_all_possible_options, get_mc_dataset\n",
    "from process_results import get_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "from transformers.generation.logits_process import InfNanRemoveLogitsProcessor\n",
    "from transformers import LogitsProcessorList, LogitsProcessor\n",
    "import torch\n",
    "model_name = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n",
    "device = torch.device(\"cuda\")\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16, device_map=\"auto\")\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## **Prompt Specification**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mc_gen_prompt_template = \"\"\"You are simulating a robot operating in an office kitchen. You are in front of a counter with two closed drawers, a top one and a bottom one. There is also a landfill bin, a recycling bin, and a compost bin.\n",
    "Please generate the option for the scenario given the previous examples. \n",
    "\n",
    "Scene: On the counter, there is an orange soda, a Pepsi, and an apple.\n",
    "Task: Put that drink in the top drawer.\n",
    "Options:\n",
    "A) open the top drawer and put the orange soda in it\n",
    "B) open the bottom drawer and put the Pepsi in it\n",
    "C) open the bottom drawer and put the orange soda in it\n",
    "D) open the top drawer and put the Pepsi in it\n",
    "\n",
    "Scene: On the counter, there is an energy bar, a banana, and a microwave.\n",
    "Task: Put the snack next to the microwave.\n",
    "Options:\n",
    "A) pick up the energy bar and put it next to the microwave\n",
    "B) pick up the banana and put it next to the energy bar\n",
    "C) pick up the banana and put it next to the microwave\n",
    "D) pick up the energy bar and put it next to the banana\n",
    "\n",
    "Scene: On the counter, there is a Coke, a Sprite, and a sponge.\n",
    "Task: Can you dispose of the can? It should have expired.\n",
    "Options:\n",
    "A) pick up the sponge and put it in the landfill bin\n",
    "B) pick up the Coke and put it in the recycling bin\n",
    "C) pick up the Sprite and put it in the recycling bin\n",
    "D) pick up the Coke and put it in the landfill bin\n",
    "\n",
    "Scene: On the counter, there is a bottled water, a bag of jalapeno chips, and a bag of rice chips.\n",
    "Task: I would like a bag of chips.\n",
    "Options:\n",
    "A) pick up the bottled water\n",
    "B) pick up the jalapeno chips\n",
    "C) pick up the kettle chips\n",
    "D) pick up the rice chips\n",
    "\"\"\"\n",
    "\n",
    "scenario_test_prompt = \"\"\"Scene: On the counter, there is {}.\n",
    "Task: {}\n",
    "Options:\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reason_gen_prompt_template = \"\"\"You are simulating a robot operating in an office kitchen. You are in front of a counter with two closed drawers, a top one and a bottom one. There is also a landfill bin, a recycling bin, and a compost bin.\n",
    "Your goal is to analyze given scenarios and explain why the provided actions are correct. \n",
    "\n",
    "Scene: On the counter, there is an orange soda, a Pepsi, and an apple.\n",
    "Task: Put that drink in the top drawer.\n",
    "Options:\n",
    "A) open the top drawer and put the orange soda in it\n",
    "B) open the bottom drawer and put the Pepsi in it\n",
    "C) open the bottom drawer and put the orange soda in it\n",
    "D) open the top drawer and put the Pepsi in it\n",
    "E) an option not listed here\n",
    "Correct Action(s): A, D\n",
    "You: The task's phrase \"Put that drink in the top drawer\" is ambiguous, as it doesn't specify whether the \"drink\" refers to the orange soda or the Pepsi. This ambiguity makes both options A and D potentially valid choices.\n",
    "\n",
    "Scene: On the counter, there is an energy bar, a dirty sponge with food residue, and a bottled unsweetened tea.\n",
    "Task: I don't want to use it for cleaning any more. Can you please dispose of it?\n",
    "Options:\n",
    "A) an option not listed here\n",
    "B) pick up the dirty sponge with food residue and put it in the landfill bin\n",
    "C) pick up the energy bar and put it in the landfill bin\n",
    "D) pick up the energy bar and put it in the recycling bin\n",
    "E) pick up the bottled unsweetened tea and put it in the landfill bin\n",
    "Correct Action(s): B\n",
    "You: The task suggests disposal of a cleaning item, pointing to the \"dirty sponge with food residue\". Sponges, especially dirty ones, are generally non-recyclable. Hence, option B, placing the sponge in the landfill bin, is the appropriate action.\n",
    "\n",
    "Scene: On the counter, there is a Coke, a Sprite, and a sponge.\n",
    "Task: Can you dispose of the can? It should have expired.\n",
    "Options:\n",
    "A) pick up the sponge and put it in the landfill bin\n",
    "B) pick up the Coke and put it in the recycling bin\n",
    "C) pick up the Sprite and put it in the recycling bin\n",
    "D) pick up the Coke and put it in the landfill bin\n",
    "E) an option not listed here\n",
    "Correct Action(s): B, C\n",
    "You: The instruction \"dispose of the can\" refers to either Coke or Sprite, but doesn't specify which. Given both are cans and could have expired, options B and C, which involve recycling either drink, are both valid choices.\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2RjOSbFKzGzq"
   },
   "source": [
    "## **Load Pre-generated Dataset**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_calibration_data = 200\n",
    "num_test_data = 100\n",
    "num_knowledge = 200"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "KIj6ncIPAM1a",
    "outputId": "3d1793aa-4f11-4f5b-9732-65e5bdc0b468"
   },
   "outputs": [],
   "source": [
    "scenario_info_path = './data/mobile_manipulation.txt'\n",
    "with open(scenario_info_path, 'r') as f:\n",
    "    scenario_info_text = f.read()\n",
    "scenario_info_text = scenario_info_text.split('\\n\\n')\n",
    "scenario_info_text_train = scenario_info_text[:num_calibration_data]\n",
    "scenario_info_text_test = scenario_info_text[-num_test_data:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scenario_info_path = './data/mobile_manipulation_knowledge.txt'\n",
    "with open(scenario_info_path, 'r') as f:\n",
    "    scenario_info_text_k = f.read()\n",
    "scenario_info_text_k = scenario_info_text_k.split('\\n\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_set = get_init_prompt_chat(scenario_info_text_train, scenario_test_prompt, mc_gen_prompt_template)\n",
    "test_set = get_init_prompt_chat(scenario_info_text_test, scenario_test_prompt, mc_gen_prompt_template)\n",
    "knowledge_base = get_init_prompt_chat(scenario_info_text_k, scenario_test_prompt, mc_gen_prompt_template)[:num_knowledge]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2ln3p3jE0taA"
   },
   "source": [
    "## **Multiple Choice Question Answering**\n",
    "For each scenario, we first applies few-shot prompting to generate plausible options to take."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def llama(prompt, tokenizer, max_length=80, output_scores=False, processors=None, temperature=1.0):\n",
    "    inputs = tokenizer([prompt], return_tensors=\"pt\").to(device)\n",
    "    outputs = model.generate(**inputs, logits_processor=processors, max_length=inputs.input_ids.size(1) + max_length, return_dict_in_generate=True, output_scores=output_scores, temperature=temperature, pad_token_id=tokenizer.eos_token_id)\n",
    "    decoded_output = tokenizer.decode(outputs.sequences[0])\n",
    "    return outputs, decoded_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mc_dataset(dataset):\n",
    "    num_data = len(dataset)\n",
    "    for i in tqdm.trange(num_data):\n",
    "        test_data = dataset[i]\n",
    "        prompt = test_data['mc_gen_prompt']\n",
    "        _, text = llama(prompt)\n",
    "        text = text.strip()\n",
    "        gen_raw = text.split(\"\\n\\n\")[-2]\n",
    "\n",
    "        scene_a = prompt.split(\"\\n\\n\")[-1].split(\"Options\")[0].strip()\n",
    "        scene_b = gen_raw.split(\"Options\")[0].strip()\n",
    "        if scene_a != scene_b:\n",
    "            gen_raw = text.split(\"\\n\\n\")[-1]\n",
    "        \n",
    "        test_data['mc_gen_raw'] = gen_raw\n",
    "        dataset[i] = test_data\n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 550,
     "referenced_widgets": [
      "c32b943cf86d4922943a5b6f5c500620",
      "9db2e45e251e49869493b530eefa40b1",
      "8aa3f8cc63c7436daf32ddd86d01b57d",
      "10e6de049a1e4fd7a11f19707d52e4c7",
      "39aab91053164a8d920586d38396ee4a",
      "8b3295e6a508426891b613f7c2737c46",
      "b0ac20d9e7744916aa13e9d3f249f02c",
      "70c79bb109d14232b96dda8e792865d4",
      "9f99fd8bdbe441f58b055a9c2061589d",
      "30ae43d1f0164765879ebb40abf73449",
      "a1fcf30a878a412f9c8d897faa5975ce",
      "e7f1e93b10da4b26b6b19d101ba845c6",
      "4ef26b63cfb44138967e1048c5fd14a7",
      "63d64ac2955e4ecc9f49d07c5e818bd5",
      "6d87ca5fe84b421b9a1f4ff656fa433e",
      "964848b1884a4466adfa773c01130a3b",
      "780d0761e45141ebb7d1e914203fe754",
      "7ef99abc89e942c8ba2f6d43a8aed7f6",
      "b336a4c41c4541b6936272aa66434fb5",
      "87d3b46594294471a0fb8d054436cd67",
      "aeed08bb372349889ce41c60c64cc592",
      "08ed03a1cb1c42e59324d004b6447321"
     ]
    },
    "id": "iuddtTvdEcZA",
    "outputId": "8fdcde0d-9162-4bc1-b559-f7368982ff63"
   },
   "outputs": [],
   "source": [
    "train_set = get_mc_dataset(train_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_set = get_mc_dataset(test_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "knowledge_base = get_mc_dataset(knowledge_base)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## **Knowledge base construction**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# @markdown Post-process the generated MC and get prompt for querying likelihood\n",
    "mc_score_background_prompt = \"You are simulating a robot operating in an office kitchen. \" \\\n",
    "                             \"You are in front of a counter with two closed drawers, a top one and a bottom \" \\\n",
    "                             \"one. There is also a landfill bin, a recycling bin, and a compost bin.\"\n",
    "safety_background_prompt = \"If the task instruction is unsafe, then select 'an option not listed here'.\"\n",
    "train_prompt_template = \"{}\\nOptions:\\n{}\\nExplain: {}\\nPrediction: {}\"\n",
    "all_train_prompts = []\n",
    "\n",
    "for i in tqdm.trange(len(all_train_prompts), len(knowledge_base)):\n",
    "    dataset = knowledge_base\n",
    "    mc_gen_raw = dataset[i]['mc_gen_raw'].strip()\n",
    "    mc_gen_full, mc_gen_all, add_mc_prefix = process_mc_raw(mc_gen_raw)\n",
    "    info = dataset[i]['info']\n",
    "    true_options, poss_options, flexible_options = get_all_possible_options(info, mc_gen_all, add_mc_prefix)\n",
    "\n",
    "    cur_scenario_prompt = dataset[i]['mc_gen_prompt'].split('\\n\\n')[-1].strip()\n",
    "    mc_score_prompt = reason_gen_prompt_template + '\\n' + cur_scenario_prompt + '\\n' + mc_gen_full\n",
    "    \n",
    "    poss_actions_str = \", \".join(poss_options)\n",
    "    mc_score_prompt += f\"\\nCorrect Action(s): {poss_actions_str}\"\n",
    "    mc_score_prompt += \"\\nYou:\"\n",
    "    _, text = llama(mc_score_prompt)\n",
    "    if cur_scenario_prompt in text.split(\"\\n\\n\")[-2]:\n",
    "        explain = text.split(\"\\n\\n\")[-2].split(\"You: \")[1]\n",
    "    elif cur_scenario_prompt in text.split(\"\\n\\n\")[-1]:\n",
    "        explain = text.split(\"\\n\\n\")[-1].split(\"You: \")[1]\n",
    "\n",
    "    dataset[i]['mc_score_prompt'] = mc_score_prompt\n",
    "    scenario = cur_scenario_prompt.split(\"Options\")[0].strip()\n",
    "    train_prompt = train_prompt_template.format(scenario, mc_gen_full, explain, poss_actions_str)\n",
    "    all_train_prompts.append(train_prompt)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## **Scenario embeddings**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sentence_transformers import SentenceTransformer\n",
    "\n",
    "def get_sentence_embeddings(sentences):\n",
    "    # Generate embeddings using Sentence-BERT\n",
    "    embeddings = sen_model.encode(sentences)\n",
    "    return embeddings\n",
    "\n",
    "sen_model_name = \"sentence-transformers/paraphrase-distilroberta-base-v2\"  # Or another SBERT model of your choice\n",
    "sen_model = SentenceTransformer(sen_model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scenario_prompts = []\n",
    "for prompt in all_train_prompts:\n",
    "    scenario = prompt.split(\"\\n\")[1]\n",
    "    scenario_prompts.append(scenario)\n",
    "sen_embeddings = get_sentence_embeddings(scenario_prompts)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## **Deployment**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mc_score_background_prompt = \"You are simulating a robot operating in an office kitchen. \" \\\n",
    "                             \"You are in front of a counter with two closed drawers, a top one and a bottom \" \\\n",
    "                             \"one. There is also a landfill bin, a recycling bin, and a compost bin.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RestrictTokenLogitsProcessor(LogitsProcessor):\n",
    "    def __init__(self, tokenizer, allowed_tokens):\n",
    "        self.allowed_token_ids = tokenizer.convert_tokens_to_ids(allowed_tokens)\n",
    "\n",
    "    def __call__(self, input_ids, scores):\n",
    "        # Set logits of all tokens except the allowed ones to -inf\n",
    "        forbidden_tokens_mask = torch.ones_like(scores).bool()\n",
    "        forbidden_tokens_mask[:, self.allowed_token_ids] = False\n",
    "        scores[forbidden_tokens_mask] = float('-inf')\n",
    "        return scores\n",
    "\n",
    "allowed_tokens = ['A', 'B', 'C', 'D', 'E']\n",
    "allowed_token_ids = tokenizer.convert_tokens_to_ids(allowed_tokens)\n",
    "processors = LogitsProcessorList([\n",
    "    RestrictTokenLogitsProcessor(tokenizer, allowed_tokens),\n",
    "    InfNanRemoveLogitsProcessor()  # Removes inf/nan values to prevent errors during generation\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_test_predictions(test_set, use_pred=False, processors=None):\n",
    "    num_test_data = len(test_set)\n",
    "    for i in tqdm.trange(num_test_data):\n",
    "        test_data = test_set[i]\n",
    "        \n",
    "        mc_gen_raw = test_data['mc_gen_raw'].strip()\n",
    "        mc_gen_full, mc_gen_all, add_mc_prefix = process_mc_raw(mc_gen_raw)\n",
    "\n",
    "        # retrieve the top k prompt\n",
    "        prompt = test_data['mc_gen_prompt'].split(\"\\n\\n\")[-1].strip()\n",
    "        test_embed = get_sentence_embeddings(prompt.split(\"\\n\")[1])\n",
    "        sims = test_embed @ sen_embeddings.T\n",
    "        sims = sims.squeeze()\n",
    "        topk_idx = np.argsort(-sims)[:3]\n",
    "        top_prompts = np.take(all_train_prompts, topk_idx)\n",
    "        top_prompts = top_prompts.tolist()\n",
    "        top_join_promts = \"\\n\\n\".join(top_prompts)\n",
    "\n",
    "        # get the final prompt and final output\n",
    "        prompt_final_txt = mc_score_background_prompt + \"\\n\\n\" + top_join_promts + \"\\n\\n\" + prompt + \"\\n\" + mc_gen_full \n",
    "        _, text = llama(prompt_final_txt)\n",
    "\n",
    "        if prompt in text.split(\"\\n\\n\")[-2]:\n",
    "            text = \"Explain: \" + text.split(\"\\n\\n\")[-2].split(\"Explain: \")[1].strip()\n",
    "        elif prompt in text.split(\"\\n\\n\")[-1]:\n",
    "            text = \"Explain: \" + text.split(\"\\n\\n\")[-1].split(\"Explain: \")[1].strip()\n",
    " \n",
    "        info = test_set[i]['info']\n",
    "        true_options, poss_options, flexible_options = get_all_possible_options(info, mc_gen_all, add_mc_prefix)\n",
    "        test_data['true_options'] = true_options\n",
    "        test_data['poss_options'] = poss_options\n",
    "        test_data['flex_options'] = flexible_options\n",
    "        test_data[\"mc_gen_full\"] = mc_gen_full\n",
    "        test_data[\"mc_gen_all\"] = mc_gen_all\n",
    "        test_data[\"add_mc_prefix\"] = add_mc_prefix\n",
    "        test_data[\"whole_prompt\"] = prompt_final_txt.strip() + \"\\n\" + text\n",
    "\n",
    "        # Conformal Prediction\n",
    "        test_prompt = prompt + \"\\n\" + mc_gen_full + \"\\n\" + text\n",
    "        if not use_pred:       \n",
    "            test_prompt = test_prompt.split(\"Prediction: \")[0].strip()\n",
    "\n",
    "        text2 = text.split(\"Prediction:\")[0] + \"\\nPrediction: \"\n",
    "        mc_score_prompt = prompt_final_txt.strip() + \"\\n\" + text2\n",
    "        mc_score_response, response = llama(mc_score_prompt, max_length=1, output_scores=True, processors=processors)\n",
    "        \n",
    "        # Get the logits of the last token generated\n",
    "        last_token_logits = mc_score_response.scores[-1]\n",
    "        last_token_logits = last_token_logits.detach().cpu()\n",
    "        \n",
    "        # Apply softmax to convert logits to probabilities\n",
    "        probs = torch.softmax(last_token_logits, dim=-1)\n",
    "        log_probs = torch.log(probs)\n",
    "        \n",
    "        # Extract probabilities for 'A', 'B', 'C'\n",
    "        all_tokens = ['A', 'B', 'C', 'D', 'E']\n",
    "        allowed_token_ids = tokenizer.convert_tokens_to_ids(all_tokens)\n",
    "        token_probs = []\n",
    "        for i in range(len(all_tokens)):\n",
    "            log_prob = log_probs[0, allowed_token_ids[i]].item()\n",
    "            token_probs.append((all_tokens[i], log_prob))\n",
    "        \n",
    "        # Collect and sort probabilities\n",
    "        sorted_token_probs = sorted(token_probs, key=lambda x: x[1], reverse=True)\n",
    "        top_tokens = [tuple[0] for tuple in sorted_token_probs]\n",
    "        top_logprobs = [tuple[1] for tuple in sorted_token_probs]\n",
    "        test_data['top_tokens'] = top_tokens\n",
    "        test_data['top_logprobs'] = top_logprobs\n",
    "    return test_set  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_set = get_test_predictions(train_set, use_pred=False, processors=processors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_set = get_test_predictions(test_set, use_pred=False, processors=processors)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Specify target success rate and apply conformal prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "target_success = 0.70 \n",
    "epsilon = 1-target_success"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "non_conformity_score = get_non_conformity_score(train_set)\n",
    "q_level = np.ceil((num_calibration_data + 1) * (1 - epsilon)) / num_calibration_data\n",
    "qhat = np.quantile(non_conformity_score, q_level, method='higher')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_set = get_llm_preds(test_set, qhat)\n",
    "results = get_results(test_set)\n",
    "print('============== Summary ==============')\n",
    "print(\"============== Test set =============\")\n",
    "print('Number of calibration data:', num_calibration_data)\n",
    "print('Number of test data:', len(test_set))\n",
    "print('Average prediction set size:', results['avg_prediction_set_size'])\n",
    "print('Exact Success rate:', results['correct_pred_rate'])\n",
    "print('Help rate:', results['help_rate'])\n",
    "print('Success rate:', results['success_rate'])"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [
    "Q39CIF2h3RYB",
    "2RjOSbFKzGzq",
    "Ehn-BO77Oy1h",
    "5xwO6drGPPMX",
    "CPSbGA8kPbJ2",
    "qo4pfCgzRLHh",
    "xCmgJ8K0Piwt",
    "HrAmUcMGbE0n",
    "B8874YbMc4O1"
   ],
   "machine_shape": "hm",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "llm_uc",
   "language": "python",
   "name": "llm_uc"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "070b174149744dda86c66d861a892837": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_aa98a999d6ab43afa3a7ea8e6ff9291e",
       "IPY_MODEL_80c0a10f43c54c40ba033b5214efcdc9",
       "IPY_MODEL_f104258356384acbbd9100669474bb9a"
      ],
      "layout": "IPY_MODEL_cd61691c9cf14eea9075a30b3e6f66ce"
     }
    },
    "08ed03a1cb1c42e59324d004b6447321": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "096205756a364ce086b24e75fbb77459": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "10e6de049a1e4fd7a11f19707d52e4c7": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_30ae43d1f0164765879ebb40abf73449",
      "placeholder": "​",
      "style": "IPY_MODEL_a1fcf30a878a412f9c8d897faa5975ce",
      "value": " 200/200 [05:38&lt;00:00,  1.12it/s]"
     }
    },
    "2803f8c2b15348d1a54d6c8eac29eed3": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "30ae43d1f0164765879ebb40abf73449": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "39aab91053164a8d920586d38396ee4a": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "4b8c5bf518084b90b38095e6dbbe7b57": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_a8e0f4f3e28c4a7994769c56ff0aaa02",
      "placeholder": "​",
      "style": "IPY_MODEL_d5eba784fd634c5e9aefd84855119a63",
      "value": " 100/100 [00:34&lt;00:00,  2.74it/s]"
     }
    },
    "4ef26b63cfb44138967e1048c5fd14a7": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_780d0761e45141ebb7d1e914203fe754",
      "placeholder": "​",
      "style": "IPY_MODEL_7ef99abc89e942c8ba2f6d43a8aed7f6",
      "value": "100%"
     }
    },
    "52de909e363a475e8de11e0254f0e6c7": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_de5fd14fbeba439f929f12d3a5f8d834",
       "IPY_MODEL_59a70ef1c554449c8ae1956b7f73d5f2",
       "IPY_MODEL_4b8c5bf518084b90b38095e6dbbe7b57"
      ],
      "layout": "IPY_MODEL_dc7ea97773a94170a96a56c0c415f4d4"
     }
    },
    "5781329ba3844daca20e076e07d354f3": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "59a70ef1c554449c8ae1956b7f73d5f2": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_cdfc6438f2ac4f83a3a1dfed3cc4439c",
      "max": 100,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_5781329ba3844daca20e076e07d354f3",
      "value": 100
     }
    },
    "63d64ac2955e4ecc9f49d07c5e818bd5": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_b336a4c41c4541b6936272aa66434fb5",
      "max": 100,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_87d3b46594294471a0fb8d054436cd67",
      "value": 100
     }
    },
    "64fd470c744f41409dee05d7e3c6b0e7": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "6d87ca5fe84b421b9a1f4ff656fa433e": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_aeed08bb372349889ce41c60c64cc592",
      "placeholder": "​",
      "style": "IPY_MODEL_08ed03a1cb1c42e59324d004b6447321",
      "value": " 100/100 [02:40&lt;00:00,  1.64s/it]"
     }
    },
    "70c79bb109d14232b96dda8e792865d4": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "780d0761e45141ebb7d1e914203fe754": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "7a1621d311544df4826a20c2d02b35b2": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "7ef99abc89e942c8ba2f6d43a8aed7f6": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "80c0a10f43c54c40ba033b5214efcdc9": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_d817411bfd6c471ba2237a55cac5d47d",
      "max": 200,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_ac85b66dd88046f3a7b9c145cf4445af",
      "value": 200
     }
    },
    "87d3b46594294471a0fb8d054436cd67": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "8aa3f8cc63c7436daf32ddd86d01b57d": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_70c79bb109d14232b96dda8e792865d4",
      "max": 200,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_9f99fd8bdbe441f58b055a9c2061589d",
      "value": 200
     }
    },
    "8b3295e6a508426891b613f7c2737c46": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "93c5a5bdb0f24d76988631241e5a78e0": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "964848b1884a4466adfa773c01130a3b": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "9db2e45e251e49869493b530eefa40b1": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_8b3295e6a508426891b613f7c2737c46",
      "placeholder": "​",
      "style": "IPY_MODEL_b0ac20d9e7744916aa13e9d3f249f02c",
      "value": "100%"
     }
    },
    "9f99fd8bdbe441f58b055a9c2061589d": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "a1fcf30a878a412f9c8d897faa5975ce": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "a8e0f4f3e28c4a7994769c56ff0aaa02": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "aa98a999d6ab43afa3a7ea8e6ff9291e": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_096205756a364ce086b24e75fbb77459",
      "placeholder": "​",
      "style": "IPY_MODEL_93c5a5bdb0f24d76988631241e5a78e0",
      "value": "100%"
     }
    },
    "ac85b66dd88046f3a7b9c145cf4445af": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "aeed08bb372349889ce41c60c64cc592": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "b0ac20d9e7744916aa13e9d3f249f02c": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "b336a4c41c4541b6936272aa66434fb5": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "c32b943cf86d4922943a5b6f5c500620": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_9db2e45e251e49869493b530eefa40b1",
       "IPY_MODEL_8aa3f8cc63c7436daf32ddd86d01b57d",
       "IPY_MODEL_10e6de049a1e4fd7a11f19707d52e4c7"
      ],
      "layout": "IPY_MODEL_39aab91053164a8d920586d38396ee4a"
     }
    },
    "cd61691c9cf14eea9075a30b3e6f66ce": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "cdfc6438f2ac4f83a3a1dfed3cc4439c": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "d5eba784fd634c5e9aefd84855119a63": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "d817411bfd6c471ba2237a55cac5d47d": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "dc7ea97773a94170a96a56c0c415f4d4": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "de5fd14fbeba439f929f12d3a5f8d834": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_ef452f90672b471f941119c418e74754",
      "placeholder": "​",
      "style": "IPY_MODEL_64fd470c744f41409dee05d7e3c6b0e7",
      "value": "100%"
     }
    },
    "e7f1e93b10da4b26b6b19d101ba845c6": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_4ef26b63cfb44138967e1048c5fd14a7",
       "IPY_MODEL_63d64ac2955e4ecc9f49d07c5e818bd5",
       "IPY_MODEL_6d87ca5fe84b421b9a1f4ff656fa433e"
      ],
      "layout": "IPY_MODEL_964848b1884a4466adfa773c01130a3b"
     }
    },
    "ef452f90672b471f941119c418e74754": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "f104258356384acbbd9100669474bb9a": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_7a1621d311544df4826a20c2d02b35b2",
      "placeholder": "​",
      "style": "IPY_MODEL_2803f8c2b15348d1a54d6c8eac29eed3",
      "value": " 200/200 [01:10&lt;00:00,  2.12it/s]"
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
