{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-import edited project modules automatically before each cell runs\n",
    "# (autoreload mode 2), and format code cells with black via lab_black.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%load_ext lab_black"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import math\n",
    "import random\n",
    "from itertools import product\n",
    "from tqdm import tqdm\n",
    "from typing import Dict, List, Set, Union, Tuple\n",
    "import yaml\n",
    "from ast import literal_eval\n",
    "\n",
    "import json\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "\n",
    "# from transformers import GPTNeoXForCausalLM, AutoTokenizer\n",
    "import torch\n",
    "from typing import List\n",
    "from matplotlib import font_manager as fm, pyplot as plt\n",
    "import numpy as np\n",
    "import wandb\n",
    "\n",
    "# import statsmodels.api as sm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/home/kevin/code/rycolab/context-vs-prior-finetuning'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# IMPORTANT: Run as if from project root so that imports work.\n",
    "# This notebook sits one directory below the project root. Only step up when\n",
    "# the `model_utils` package (imported below) is not visible from the cwd, so\n",
    "# re-running this cell is a no-op instead of climbing one more level each time.\n",
    "if not os.path.isdir(\"model_utils\"):\n",
    "    pardir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))\n",
    "    os.chdir(pardir)\n",
    "os.getcwd()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/kevin/mambaforge/envs/sftcontext/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Could not find the bitsandbytes CUDA binary at PosixPath('/home/kevin/mambaforge/envs/sftcontext/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda120.so')\n",
      "The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.\n"
     ]
    }
   ],
   "source": [
    "from model_utils.utils import (\n",
    "    construct_paths_and_dataset_kwargs,\n",
    "    construct_test_results_dir,\n",
    "    EvalConfig,\n",
    ")\n",
    "from preprocessing.dataset import load_dataset_from_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---------------------------------------------------------------\n",
    "# Parameters: Weights & Biases (wandb) settings for this analysis\n",
    "# ---------------------------------------------------------------\n",
    "LOG_DATASETS = True\n",
    "PROJECT_NAME = \"sftcontext\"\n",
    "GROUP_NAME = None\n",
    "TAGS = [\n",
    "    \"basefakepedia\",\n",
    "    \"analysis\",\n",
    "    \"summarize\",\n",
    "    \"across-models\",\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Construct dataframes for analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output directory for this analysis (relative to the project root);\n",
    "# created up front and left untouched if it already exists.\n",
    "analysis_dir = \"analysis/summarize/generalization_datasets\"\n",
    "os.makedirs(name=analysis_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[EvalConfig(dataset_name='BaseFakepedia', subsplit='nodup_relpid', k_demonstrations=0, context_weight_format='instruction', do_steering=False), EvalConfig(dataset_name='BaseFakepedia', subsplit='nodup_relpid', k_demonstrations=0, context_weight_format='float', do_steering=False)]\n"
     ]
    }
   ],
   "source": [
    "dataset_names = [\"BaseFakepedia\"]\n",
    "\n",
    "# Evaluation grid: zero-shot (k=0), non-steered BaseFakepedia evals,\n",
    "# one per context-weight format.\n",
    "zero_shot_evals = [\n",
    "    {\n",
    "        \"dataset_name\": \"BaseFakepedia\",\n",
    "        \"subsplit\": \"nodup_relpid\",\n",
    "        \"k_demonstrations\": 0,\n",
    "        \"context_weight_format\": eval_cwf,\n",
    "        \"do_steering\": False,\n",
    "    }\n",
    "    for eval_cwf in (\"instruction\", \"float\")\n",
    "]\n",
    "evals = [EvalConfig(**eval_kwargs) for eval_kwargs in zero_shot_evals]\n",
    "print(evals)\n",
    "\n",
    "# Training-run grid: each combination below identifies one model run.\n",
    "subsplit_names = [\"nodup_relpid\"]\n",
    "seeds = [1, 2, 3]\n",
    "train_sizes = [2048]\n",
    "# TEACH_METHOD is \"finetune\" for False and \"few_shot\" for True.\n",
    "no_train_statuses = [False]\n",
    "# JSON-encoded LoRA target-module lists (decoded into LORA_MODULES per run).\n",
    "peft_modules = [\n",
    "    json.dumps(\n",
    "        [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\"],\n",
    "        # [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
    "        separators=(\",\", \":\"),\n",
    "    ),\n",
    "]\n",
    "context_weight_formats = [\"instruction\", \"float\"]\n",
    "\n",
    "# (model_id, batch_size, grad_accum, quantize, peft); uncomment rows to add models.\n",
    "model_id_and_bs_and_ga_and_quantize_and_peft_tuples = [\n",
    "    (\"Meta-Llama-3.1-8B-Instruct\", 8, 2, None, True),\n",
    "    # (\"Meta-Llama-3.1-8B\", 8, 2, None, True),\n",
    "    # (\"Meta-Llama-3.1-8B\", 4, 2, None, True),\n",
    "    (\"Mistral-7B-Instruct-v0.3\", 8, 2, None, True),\n",
    "    # (\"Mistral-7B-v0.3\", 4, 2, None, True),\n",
    "    (\"gemma-2-9b-it\", 2, 8, None, True),\n",
    "    # (\"gemma-2-9b\", 4, 2, None, True),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "overwrite = False  # NOTE(review): unused in this cell; kept for later cells\n",
    "df_dict = []  # one record per test example, across all runs and evals\n",
    "p_scores_df_dict = []  # stays empty: test_pscore.csv results are not loaded here\n",
    "metrics_dict = []  # one record per (run, eval) that has a metrics.json\n",
    "\n",
    "\n",
    "def _eval_metadata(eval_name, eval_k_demonstrations, eval_ctx_weight_format, no_train):\n",
    "    \"\"\"Columns identifying an eval and how the evaluated model was taught.\"\"\"\n",
    "    return {\n",
    "        \"EVAL_NAME\": eval_name,\n",
    "        \"EVAL_K_DEMONSTRATIONS\": eval_k_demonstrations,\n",
    "        \"EVAL_CTX_WEIGHT_FORMAT\": eval_ctx_weight_format,\n",
    "        \"TEACH_METHOD\": \"few_shot\" if no_train else \"finetune\",\n",
    "    }\n",
    "\n",
    "\n",
    "# `product` visits the run grid in the same order as the equivalent nested\n",
    "# for-loops (first iterable outermost), so records come out in that order.\n",
    "for (\n",
    "    ds,\n",
    "    sp,\n",
    "    seed,\n",
    "    ts,\n",
    "    nts,\n",
    "    pm,\n",
    "    cwf,\n",
    "    (model_id, bs, ga, quantize, peft),\n",
    ") in product(\n",
    "    dataset_names,\n",
    "    subsplit_names,\n",
    "    seeds,\n",
    "    train_sizes,\n",
    "    no_train_statuses,\n",
    "    peft_modules,\n",
    "    context_weight_formats,\n",
    "    model_id_and_bs_and_ga_and_quantize_and_peft_tuples,\n",
    "):\n",
    "    # Run configuration; also attached as columns to every loaded record.\n",
    "    dict_vals = dict(\n",
    "        DATASET_NAME=ds,\n",
    "        SUBSPLIT=sp,\n",
    "        SEED=seed,\n",
    "        TRAIN_SIZE=ts,\n",
    "        MODEL_ID=model_id,\n",
    "        PEFT=peft and not nts,\n",
    "        LORA_MODULES=json.loads(pm),\n",
    "        LOAD_IN_4BIT=(quantize == \"4bit\"),\n",
    "        LOAD_IN_8BIT=(quantize == \"8bit\"),\n",
    "        BATCH_SZ=bs,\n",
    "        GRAD_ACCUM=ga,\n",
    "        NO_TRAIN=nts,\n",
    "        CONTEXT_WEIGHT_AT_END=False,\n",
    "        CONTEXT_WEIGHT_FORMAT=cwf,\n",
    "        ANSWER_FORMAT_PROMPT_POSITION=\"end\",\n",
    "        ADD_ANSWER_FORMAT_PROMPT=False,\n",
    "    )\n",
    "    (\n",
    "        data_dir,\n",
    "        input_dir,\n",
    "        model_dir,\n",
    "        results_dir,\n",
    "        val_results_path,\n",
    "        data_id,\n",
    "        full_model_id,\n",
    "        DATASET_KWARGS_IDENTIFIABLE,\n",
    "        MODEL_KWARGS_IDENTIFIABLE,\n",
    "    ) = construct_paths_and_dataset_kwargs(**dict_vals)\n",
    "    for (\n",
    "        eval_name,\n",
    "        eval_subsplit,\n",
    "        eval_k_demonstrations,\n",
    "        eval_ctx_weight_format,\n",
    "        eval_do_steering,\n",
    "    ) in evals:\n",
    "        if eval_do_steering:\n",
    "            # The results dir below is always built with do_steering=False and no\n",
    "            # steering parameters, so a steering eval would silently load (and\n",
    "            # mislabel) the non-steered results. Fail loudly instead.\n",
    "            raise NotImplementedError(\n",
    "                f\"Steering evals are not supported by this loader: {eval_name}\"\n",
    "            )\n",
    "        test_results_dir = construct_test_results_dir(\n",
    "            base_results_dir=results_dir,\n",
    "            subsplit=eval_subsplit,\n",
    "            context_weight_format=eval_ctx_weight_format,\n",
    "            eval_name=eval_name,\n",
    "            k_demonstrations=eval_k_demonstrations,\n",
    "            in_domain_demonstrations=False,\n",
    "            answer_format_prompt_position=None,\n",
    "            add_answer_format_prompt=False,\n",
    "            do_steering=False,\n",
    "            steering_prior_value=None,\n",
    "            steering_context_value=None,\n",
    "            steering_layer=None,\n",
    "        )\n",
    "        test_results_path = os.path.join(test_results_dir, \"test.csv\")\n",
    "        test_metrics_path = os.path.join(test_results_dir, \"metrics.json\")\n",
    "        test_metrics_query_only_path = os.path.join(\n",
    "            test_results_dir, \"metrics_query_only.json\"\n",
    "        )\n",
    "        eval_meta = _eval_metadata(\n",
    "            eval_name, eval_k_demonstrations, eval_ctx_weight_format, nts\n",
    "        )\n",
    "        if os.path.isfile(test_results_path):\n",
    "            # Load per-example predictions/results and tag each row with the run\n",
    "            # config. List-valued config (LORA_MODULES) is repeated per row because\n",
    "            # pandas would read a bare list as the column's values, one per row.\n",
    "            res = pd.read_csv(test_results_path)\n",
    "            for k, v in dict_vals.items():\n",
    "                if isinstance(v, list):\n",
    "                    v = [v] * len(res)\n",
    "                res[k] = v\n",
    "            scores: List[dict] = res.to_dict(\"records\")\n",
    "            df_dict += [{**dict_vals, **eval_meta, **d} for d in scores]\n",
    "        if os.path.isfile(test_metrics_path):\n",
    "            # Load aggregate metrics; query-only metrics are optional and get a\n",
    "            # QO_ prefix so they do not collide with the full-context metrics.\n",
    "            metrics = load_dataset_from_path(test_metrics_path)\n",
    "            if os.path.isfile(test_metrics_query_only_path):\n",
    "                metrics_query_only = load_dataset_from_path(\n",
    "                    test_metrics_query_only_path\n",
    "                )\n",
    "                metrics_query_only = {\n",
    "                    f\"QO_{k}\": v for k, v in metrics_query_only.items()\n",
    "                }\n",
    "            else:\n",
    "                metrics_query_only = {}\n",
    "            metrics_dict.append(\n",
    "                {**dict_vals, **eval_meta, **metrics, **metrics_query_only}\n",
    "            )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 35 entries, 0 to 34\n",
      "Data columns (total 30 columns):\n",
      " #   Column                         Non-Null Count  Dtype  \n",
      "---  ------                         --------------  -----  \n",
      " 0   DATASET_NAME                   35 non-null     object \n",
      " 1   SUBSPLIT                       35 non-null     object \n",
      " 2   SEED                           35 non-null     int64  \n",
      " 3   TRAIN_SIZE                     35 non-null     int64  \n",
      " 4   MODEL_ID                       35 non-null     object \n",
      " 5   PEFT                           35 non-null     bool   \n",
      " 6   LORA_MODULES                   35 non-null     object \n",
      " 7   LOAD_IN_4BIT                   35 non-null     bool   \n",
      " 8   LOAD_IN_8BIT                   35 non-null     bool   \n",
      " 9   BATCH_SZ                       35 non-null     int64  \n",
      " 10  GRAD_ACCUM                     35 non-null     int64  \n",
      " 11  NO_TRAIN                       35 non-null     bool   \n",
      " 12  CONTEXT_WEIGHT_AT_END          35 non-null     bool   \n",
      " 13  CONTEXT_WEIGHT_FORMAT          35 non-null     object \n",
      " 14  ANSWER_FORMAT_PROMPT_POSITION  35 non-null     object \n",
      " 15  ADD_ANSWER_FORMAT_PROMPT       35 non-null     bool   \n",
      " 16  EVAL_NAME                      35 non-null     object \n",
      " 17  EVAL_K_DEMONSTRATIONS          35 non-null     int64  \n",
      " 18  EVAL_CTX_WEIGHT_FORMAT         35 non-null     object \n",
      " 19  TEACH_METHOD                   35 non-null     object \n",
      " 20  acc                            35 non-null     float64\n",
      " 21  context_acc                    35 non-null     float64\n",
      " 22  context_mr                     35 non-null     float64\n",
      " 23  context_pct_other              35 non-null     float64\n",
      " 24  overall_mr                     35 non-null     float64\n",
      " 25  overall_pct_other              35 non-null     float64\n",
      " 26  pair_acc                       35 non-null     float64\n",
      " 27  prior_acc                      35 non-null     float64\n",
      " 28  prior_mr                       35 non-null     float64\n",
      " 29  prior_pct_other                35 non-null     float64\n",
      "dtypes: bool(6), float64(10), int64(5), object(9)\n",
      "memory usage: 6.9+ KB\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Index(['DATASET_NAME', 'SUBSPLIT', 'SEED', 'TRAIN_SIZE', 'MODEL_ID', 'PEFT',\n",
       "       'LORA_MODULES', 'LOAD_IN_4BIT', 'LOAD_IN_8BIT', 'BATCH_SZ',\n",
       "       'GRAD_ACCUM', 'NO_TRAIN', 'CONTEXT_WEIGHT_AT_END',\n",
       "       'CONTEXT_WEIGHT_FORMAT', 'ANSWER_FORMAT_PROMPT_POSITION',\n",
       "       'ADD_ANSWER_FORMAT_PROMPT', 'EVAL_NAME', 'EVAL_K_DEMONSTRATIONS',\n",
       "       'EVAL_CTX_WEIGHT_FORMAT', 'TEACH_METHOD', 'acc', 'context_acc',\n",
       "       'context_mr', 'context_pct_other', 'overall_mr', 'overall_pct_other',\n",
       "       'pair_acc', 'prior_acc', 'prior_mr', 'prior_pct_other'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One row per (training run, eval) pair whose metrics.json was found;\n",
    "# .info() prints dtypes and null counts, then the column index is displayed.\n",
    "metrics_df = pd.DataFrame(data=metrics_dict)\n",
    "metrics_df.info()\n",
    "metrics_df.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>SEED</th>\n",
       "      <th>MODEL_ID</th>\n",
       "      <th>EVAL_CTX_WEIGHT_FORMAT</th>\n",
       "      <th>CONTEXT_WEIGHT_FORMAT</th>\n",
       "      <th>pair_acc</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>Meta-Llama-3.1-8B-Instruct</td>\n",
       "      <td>instruction</td>\n",
       "      <td>instruction</td>\n",
       "      <td>0.930</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>Meta-Llama-3.1-8B-Instruct</td>\n",
       "      <td>float</td>\n",
       "      <td>instruction</td>\n",
       "      <td>0.476</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>Mistral-7B-Instruct-v0.3</td>\n",
       "      <td>instruction</td>\n",
       "      <td>instruction</td>\n",
       "      <td>0.914</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>Mistral-7B-Instruct-v0.3</td>\n",
       "      <td>float</td>\n",
       "      <td>instruction</td>\n",
       "      <td>0.056</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>gemma-2-9b-it</td>\n",
       "      <td>instruction</td>\n",
       "      <td>instruction</td>\n",
       "      <td>0.842</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>1</td>\n",
       "      <td>gemma-2-9b-it</td>\n",
       "      <td>float</td>\n",
       "      <td>instruction</td>\n",
       "      <td>0.726</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>1</td>\n",
       "      <td>Meta-Llama-3.1-8B-Instruct</td>\n",
       "      <td>instruction</td>\n",
       "      <td>float</td>\n",
       "      <td>0.880</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>1</td>\n",
       "      <td>Meta-Llama-3.1-8B-Instruct</td>\n",
       "      <td>float</td>\n",
       "      <td>float</td>\n",
       "      <td>0.924</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>1</td>\n",
       "      <td>Mistral-7B-Instruct-v0.3</td>\n",
       "      <td>instruction</td>\n",
       "      <td>float</td>\n",
       "      <td>0.952</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>1</td>\n",
       "      <td>gemma-2-9b-it</td>\n",
       "      <td>instruction</td>\n",
       "      <td>float</td>\n",
       "      <td>0.850</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>1</td>\n",
       "      <td>gemma-2-9b-it</td>\n",
       "      <td>float</td>\n",
       "      <td>float</td>\n",
       "      <td>0.852</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>2</td>\n",
       "      <td>Meta-Llama-3.1-8B-Instruct</td>\n",
       "      <td>instruction</td>\n",
       "      <td>instruction</td>\n",
       "      <td>0.916</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>2</td>\n",
       "      <td>Meta-Llama-3.1-8B-Instruct</td>\n",
       "      <td>float</td>\n",
       "      <td>instruction</td>\n",
       "      <td>0.810</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>2</td>\n",
       "      <td>Mistral-7B-Instruct-v0.3</td>\n",
       "      <td>instruction</td>\n",
       "      <td>instruction</td>\n",
       "      <td>0.948</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>2</td>\n",
       "      <td>Mistral-7B-Instruct-v0.3</td>\n",
       "      <td>float</td>\n",
       "      <td>instruction</td>\n",
       "      <td>0.820</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>2</td>\n",
       "      <td>gemma-2-9b-it</td>\n",
       "      <td>instruction</td>\n",
       "      <td>instruction</td>\n",
       "      <td>0.846</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>2</td>\n",
       "      <td>gemma-2-9b-it</td>\n",
       "      <td>float</td>\n",
       "      <td>instruction</td>\n",
       "      <td>0.710</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>2</td>\n",
       "      <td>Meta-Llama-3.1-8B-Instruct</td>\n",
       "      <td>instruction</td>\n",
       "      <td>float</td>\n",
       "      <td>0.892</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>2</td>\n",
       "      <td>Meta-Llama-3.1-8B-Instruct</td>\n",
       "      <td>float</td>\n",
       "      <td>float</td>\n",
       "      <td>0.950</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>2</td>\n",
       "      <td>Mistral-7B-Instruct-v0.3</td>\n",
       "      <td>instruction</td>\n",
       "      <td>float</td>\n",
       "      <td>0.780</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>2</td>\n",
       "      <td>Mistral-7B-Instruct-v0.3</td>\n",
       "      <td>float</td>\n",
       "      <td>float</td>\n",
       "      <td>0.940</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>2</td>\n",
       "      <td>gemma-2-9b-it</td>\n",
       "      <td>instruction</td>\n",
       "      <td>float</td>\n",
       "      <td>0.866</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>2</td>\n",
       "      <td>gemma-2-9b-it</td>\n",
       "      <td>float</td>\n",
       "      <td>float</td>\n",
       "      <td>0.860</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>3</td>\n",
       "      <td>Meta-Llama-3.1-8B-Instruct</td>\n",
       "      <td>instruction</td>\n",
       "      <td>instruction</td>\n",
       "      <td>0.948</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>3</td>\n",
       "      <td>Meta-Llama-3.1-8B-Instruct</td>\n",
       "      <td>float</td>\n",
       "      <td>instruction</td>\n",
       "      <td>0.706</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>3</td>\n",
       "      <td>Mistral-7B-Instruct-v0.3</td>\n",
       "      <td>instruction</td>\n",
       "      <td>instruction</td>\n",
       "      <td>0.898</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>3</td>\n",
       "      <td>Mistral-7B-Instruct-v0.3</td>\n",
       "      <td>float</td>\n",
       "      <td>instruction</td>\n",
       "      <td>0.100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>3</td>\n",
       "      <td>gemma-2-9b-it</td>\n",
       "      <td>instruction</td>\n",
       "      <td>instruction</td>\n",
       "      <td>0.844</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>3</td>\n",
       "      <td>gemma-2-9b-it</td>\n",
       "      <td>float</td>\n",
       "      <td>instruction</td>\n",
       "      <td>0.690</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>3</td>\n",
       "      <td>Meta-Llama-3.1-8B-Instruct</td>\n",
       "      <td>instruction</td>\n",
       "      <td>float</td>\n",
       "      <td>0.930</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>30</th>\n",
       "      <td>3</td>\n",
       "      <td>Meta-Llama-3.1-8B-Instruct</td>\n",
       "      <td>float</td>\n",
       "      <td>float</td>\n",
       "      <td>0.940</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>31</th>\n",
       "      <td>3</td>\n",
       "      <td>Mistral-7B-Instruct-v0.3</td>\n",
       "      <td>instruction</td>\n",
       "      <td>float</td>\n",
       "      <td>0.784</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32</th>\n",
       "      <td>3</td>\n",
       "      <td>Mistral-7B-Instruct-v0.3</td>\n",
       "      <td>float</td>\n",
       "      <td>float</td>\n",
       "      <td>0.944</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>33</th>\n",
       "      <td>3</td>\n",
       "      <td>gemma-2-9b-it</td>\n",
       "      <td>instruction</td>\n",
       "      <td>float</td>\n",
       "      <td>0.850</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>34</th>\n",
       "      <td>3</td>\n",
       "      <td>gemma-2-9b-it</td>\n",
       "      <td>float</td>\n",
       "      <td>float</td>\n",
       "      <td>0.848</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    SEED                    MODEL_ID EVAL_CTX_WEIGHT_FORMAT  \\\n",
       "0      1  Meta-Llama-3.1-8B-Instruct            instruction   \n",
       "1      1  Meta-Llama-3.1-8B-Instruct                  float   \n",
       "2      1    Mistral-7B-Instruct-v0.3            instruction   \n",
       "3      1    Mistral-7B-Instruct-v0.3                  float   \n",
       "4      1               gemma-2-9b-it            instruction   \n",
       "5      1               gemma-2-9b-it                  float   \n",
       "6      1  Meta-Llama-3.1-8B-Instruct            instruction   \n",
       "7      1  Meta-Llama-3.1-8B-Instruct                  float   \n",
       "8      1    Mistral-7B-Instruct-v0.3            instruction   \n",
       "9      1               gemma-2-9b-it            instruction   \n",
       "10     1               gemma-2-9b-it                  float   \n",
       "11     2  Meta-Llama-3.1-8B-Instruct            instruction   \n",
       "12     2  Meta-Llama-3.1-8B-Instruct                  float   \n",
       "13     2    Mistral-7B-Instruct-v0.3            instruction   \n",
       "14     2    Mistral-7B-Instruct-v0.3                  float   \n",
       "15     2               gemma-2-9b-it            instruction   \n",
       "16     2               gemma-2-9b-it                  float   \n",
       "17     2  Meta-Llama-3.1-8B-Instruct            instruction   \n",
       "18     2  Meta-Llama-3.1-8B-Instruct                  float   \n",
       "19     2    Mistral-7B-Instruct-v0.3            instruction   \n",
       "20     2    Mistral-7B-Instruct-v0.3                  float   \n",
       "21     2               gemma-2-9b-it            instruction   \n",
       "22     2               gemma-2-9b-it                  float   \n",
       "23     3  Meta-Llama-3.1-8B-Instruct            instruction   \n",
       "24     3  Meta-Llama-3.1-8B-Instruct                  float   \n",
       "25     3    Mistral-7B-Instruct-v0.3            instruction   \n",
       "26     3    Mistral-7B-Instruct-v0.3                  float   \n",
       "27     3               gemma-2-9b-it            instruction   \n",
       "28     3               gemma-2-9b-it                  float   \n",
       "29     3  Meta-Llama-3.1-8B-Instruct            instruction   \n",
       "30     3  Meta-Llama-3.1-8B-Instruct                  float   \n",
       "31     3    Mistral-7B-Instruct-v0.3            instruction   \n",
       "32     3    Mistral-7B-Instruct-v0.3                  float   \n",
       "33     3               gemma-2-9b-it            instruction   \n",
       "34     3               gemma-2-9b-it                  float   \n",
       "\n",
       "   CONTEXT_WEIGHT_FORMAT  pair_acc  \n",
       "0            instruction     0.930  \n",
       "1            instruction     0.476  \n",
       "2            instruction     0.914  \n",
       "3            instruction     0.056  \n",
       "4            instruction     0.842  \n",
       "5            instruction     0.726  \n",
       "6                  float     0.880  \n",
       "7                  float     0.924  \n",
       "8                  float     0.952  \n",
       "9                  float     0.850  \n",
       "10                 float     0.852  \n",
       "11           instruction     0.916  \n",
       "12           instruction     0.810  \n",
       "13           instruction     0.948  \n",
       "14           instruction     0.820  \n",
       "15           instruction     0.846  \n",
       "16           instruction     0.710  \n",
       "17                 float     0.892  \n",
       "18                 float     0.950  \n",
       "19                 float     0.780  \n",
       "20                 float     0.940  \n",
       "21                 float     0.866  \n",
       "22                 float     0.860  \n",
       "23           instruction     0.948  \n",
       "24           instruction     0.706  \n",
       "25           instruction     0.898  \n",
       "26           instruction     0.100  \n",
       "27           instruction     0.844  \n",
       "28           instruction     0.690  \n",
       "29                 float     0.930  \n",
       "30                 float     0.940  \n",
       "31                 float     0.784  \n",
       "32                 float     0.944  \n",
       "33                 float     0.850  \n",
       "34                 float     0.848  "
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Select the run identifiers and the pair-accuracy metric for comparison.\n",
    "# Other columns previously considered here: DATASET_NAME, NO_TRAIN, EVAL_NAME,\n",
    "# EVAL_K_DEMONSTRATIONS, TEACH_METHOD, acc.\n",
    "# `.copy()` detaches the result from `metrics_df`, so later column assignments\n",
    "# on `metrics_df_short` do not raise pandas' SettingWithCopyWarning.\n",
    "SHORT_COLUMNS = [\n",
    "    \"SEED\",\n",
    "    \"MODEL_ID\",\n",
    "    \"EVAL_CTX_WEIGHT_FORMAT\",\n",
    "    \"CONTEXT_WEIGHT_FORMAT\",\n",
    "    \"pair_acc\",\n",
    "]\n",
    "metrics_df_short = metrics_df[SHORT_COLUMNS].copy()\n",
    "metrics_df_short"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the Llama-3.1-8B *instruct* runs. `MODEL_ID` stores the full\n",
    "# checkpoint name, so the base ID \"Meta-Llama-3.1-8B\" would match no rows.\n",
    "metrics_instruct_only = metrics_df_short[\n",
    "    metrics_df_short[\"MODEL_ID\"] == \"Meta-Llama-3.1-8B-Instruct\"\n",
    "]\n",
    "metrics_instruct_only"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.plotly.v1+json": {
       "config": {
        "plotlyServerURL": "https://plot.ly"
       },
       "data": [
        {
         "colorscale": [
          [
           0,
           "rgb(255,255,217)"
          ],
          [
           0.125,
           "rgb(237,248,177)"
          ],
          [
           0.25,
           "rgb(199,233,180)"
          ],
          [
           0.375,
           "rgb(127,205,187)"
          ],
          [
           0.5,
           "rgb(65,182,196)"
          ],
          [
           0.625,
           "rgb(29,145,192)"
          ],
          [
           0.75,
           "rgb(34,94,168)"
          ],
          [
           0.875,
           "rgb(37,52,148)"
          ],
          [
           1,
           "rgb(8,29,88)"
          ]
         ],
         "hoverongaps": false,
         "text": [
          [
           "0.901 ± 0.026",
           "0.938 ± 0.013"
          ],
          [
           "0.931 ± 0.016",
           "0.664 ± 0.171"
          ]
         ],
         "textfont": {
          "size": 84
         },
         "texttemplate": "%{text}",
         "type": "heatmap",
         "x": [
          "🫵",
          "1️⃣"
         ],
         "y": [
          "1️⃣",
          "🫵"
         ],
         "z": [
          [
           0.9006666666666666,
           0.9380000000000001
          ],
          [
           0.9313333333333333,
           0.664
          ]
         ]
        }
       ],
       "layout": {
        "font": {
         "family": "Computer Modern",
         "size": 60
        },
        "height": 600,
        "template": {
         "data": {
          "bar": [
           {
            "error_x": {
             "color": "#2a3f5f"
            },
            "error_y": {
             "color": "#2a3f5f"
            },
            "marker": {
             "line": {
              "color": "#E5ECF6",
              "width": 0.5
             },
             "pattern": {
              "fillmode": "overlay",
              "size": 10,
              "solidity": 0.2
             }
            },
            "type": "bar"
           }
          ],
          "barpolar": [
           {
            "marker": {
             "line": {
              "color": "#E5ECF6",
              "width": 0.5
             },
             "pattern": {
              "fillmode": "overlay",
              "size": 10,
              "solidity": 0.2
             }
            },
            "type": "barpolar"
           }
          ],
          "carpet": [
           {
            "aaxis": {
             "endlinecolor": "#2a3f5f",
             "gridcolor": "white",
             "linecolor": "white",
             "minorgridcolor": "white",
             "startlinecolor": "#2a3f5f"
            },
            "baxis": {
             "endlinecolor": "#2a3f5f",
             "gridcolor": "white",
             "linecolor": "white",
             "minorgridcolor": "white",
             "startlinecolor": "#2a3f5f"
            },
            "type": "carpet"
           }
          ],
          "choropleth": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "type": "choropleth"
           }
          ],
          "contour": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "contour"
           }
          ],
          "contourcarpet": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "type": "contourcarpet"
           }
          ],
          "heatmap": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "heatmap"
           }
          ],
          "heatmapgl": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "heatmapgl"
           }
          ],
          "histogram": [
           {
            "marker": {
             "pattern": {
              "fillmode": "overlay",
              "size": 10,
              "solidity": 0.2
             }
            },
            "type": "histogram"
           }
          ],
          "histogram2d": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "histogram2d"
           }
          ],
          "histogram2dcontour": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "histogram2dcontour"
           }
          ],
          "mesh3d": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "type": "mesh3d"
           }
          ],
          "parcoords": [
           {
            "line": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "parcoords"
           }
          ],
          "pie": [
           {
            "automargin": true,
            "type": "pie"
           }
          ],
          "scatter": [
           {
            "fillpattern": {
             "fillmode": "overlay",
             "size": 10,
             "solidity": 0.2
            },
            "type": "scatter"
           }
          ],
          "scatter3d": [
           {
            "line": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatter3d"
           }
          ],
          "scattercarpet": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattercarpet"
           }
          ],
          "scattergeo": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattergeo"
           }
          ],
          "scattergl": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattergl"
           }
          ],
          "scattermapbox": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattermapbox"
           }
          ],
          "scatterpolar": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatterpolar"
           }
          ],
          "scatterpolargl": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatterpolargl"
           }
          ],
          "scatterternary": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatterternary"
           }
          ],
          "surface": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "surface"
           }
          ],
          "table": [
           {
            "cells": {
             "fill": {
              "color": "#EBF0F8"
             },
             "line": {
              "color": "white"
             }
            },
            "header": {
             "fill": {
              "color": "#C8D4E3"
             },
             "line": {
              "color": "white"
             }
            },
            "type": "table"
           }
          ]
         },
         "layout": {
          "annotationdefaults": {
           "arrowcolor": "#2a3f5f",
           "arrowhead": 0,
           "arrowwidth": 1
          },
          "autotypenumbers": "strict",
          "coloraxis": {
           "colorbar": {
            "outlinewidth": 0,
            "ticks": ""
           }
          },
          "colorscale": {
           "diverging": [
            [
             0,
             "#8e0152"
            ],
            [
             0.1,
             "#c51b7d"
            ],
            [
             0.2,
             "#de77ae"
            ],
            [
             0.3,
             "#f1b6da"
            ],
            [
             0.4,
             "#fde0ef"
            ],
            [
             0.5,
             "#f7f7f7"
            ],
            [
             0.6,
             "#e6f5d0"
            ],
            [
             0.7,
             "#b8e186"
            ],
            [
             0.8,
             "#7fbc41"
            ],
            [
             0.9,
             "#4d9221"
            ],
            [
             1,
             "#276419"
            ]
           ],
           "sequential": [
            [
             0,
             "#0d0887"
            ],
            [
             0.1111111111111111,
             "#46039f"
            ],
            [
             0.2222222222222222,
             "#7201a8"
            ],
            [
             0.3333333333333333,
             "#9c179e"
            ],
            [
             0.4444444444444444,
             "#bd3786"
            ],
            [
             0.5555555555555556,
             "#d8576b"
            ],
            [
             0.6666666666666666,
             "#ed7953"
            ],
            [
             0.7777777777777778,
             "#fb9f3a"
            ],
            [
             0.8888888888888888,
             "#fdca26"
            ],
            [
             1,
             "#f0f921"
            ]
           ],
           "sequentialminus": [
            [
             0,
             "#0d0887"
            ],
            [
             0.1111111111111111,
             "#46039f"
            ],
            [
             0.2222222222222222,
             "#7201a8"
            ],
            [
             0.3333333333333333,
             "#9c179e"
            ],
            [
             0.4444444444444444,
             "#bd3786"
            ],
            [
             0.5555555555555556,
             "#d8576b"
            ],
            [
             0.6666666666666666,
             "#ed7953"
            ],
            [
             0.7777777777777778,
             "#fb9f3a"
            ],
            [
             0.8888888888888888,
             "#fdca26"
            ],
            [
             1,
             "#f0f921"
            ]
           ]
          },
          "colorway": [
           "#636efa",
           "#EF553B",
           "#00cc96",
           "#ab63fa",
           "#FFA15A",
           "#19d3f3",
           "#FF6692",
           "#B6E880",
           "#FF97FF",
           "#FECB52"
          ],
          "font": {
           "color": "#2a3f5f"
          },
          "geo": {
           "bgcolor": "white",
           "lakecolor": "white",
           "landcolor": "#E5ECF6",
           "showlakes": true,
           "showland": true,
           "subunitcolor": "white"
          },
          "hoverlabel": {
           "align": "left"
          },
          "hovermode": "closest",
          "mapbox": {
           "style": "light"
          },
          "paper_bgcolor": "white",
          "plot_bgcolor": "#E5ECF6",
          "polar": {
           "angularaxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           },
           "bgcolor": "#E5ECF6",
           "radialaxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           }
          },
          "scene": {
           "xaxis": {
            "backgroundcolor": "#E5ECF6",
            "gridcolor": "white",
            "gridwidth": 2,
            "linecolor": "white",
            "showbackground": true,
            "ticks": "",
            "zerolinecolor": "white"
           },
           "yaxis": {
            "backgroundcolor": "#E5ECF6",
            "gridcolor": "white",
            "gridwidth": 2,
            "linecolor": "white",
            "showbackground": true,
            "ticks": "",
            "zerolinecolor": "white"
           },
           "zaxis": {
            "backgroundcolor": "#E5ECF6",
            "gridcolor": "white",
            "gridwidth": 2,
            "linecolor": "white",
            "showbackground": true,
            "ticks": "",
            "zerolinecolor": "white"
           }
          },
          "shapedefaults": {
           "line": {
            "color": "#2a3f5f"
           }
          },
          "ternary": {
           "aaxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           },
           "baxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           },
           "bgcolor": "#E5ECF6",
           "caxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           }
          },
          "title": {
           "x": 0.05
          },
          "xaxis": {
           "automargin": true,
           "gridcolor": "white",
           "linecolor": "white",
           "ticks": "",
           "title": {
            "standoff": 15
           },
           "zerolinecolor": "white",
           "zerolinewidth": 2
          },
          "yaxis": {
           "automargin": true,
           "gridcolor": "white",
           "linecolor": "white",
           "ticks": "",
           "title": {
            "standoff": 15
           },
           "zerolinecolor": "white",
           "zerolinewidth": 2
          }
         }
        },
        "width": 700,
        "xaxis": {
         "title": {
          "text": "Eval Intent Format"
         }
        },
        "yaxis": {
         "title": {
          "text": "Train Intent Format"
         }
        }
       }
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Figure saved as 'confusion_matrix_pair_accuracy.png'\n"
     ]
    },
    {
     "data": {
      "application/vnd.plotly.v1+json": {
       "config": {
        "plotlyServerURL": "https://plot.ly"
       },
       "data": [
        {
         "colorscale": [
          [
           0,
           "rgb(255,255,217)"
          ],
          [
           0.125,
           "rgb(237,248,177)"
          ],
          [
           0.25,
           "rgb(199,233,180)"
          ],
          [
           0.375,
           "rgb(127,205,187)"
          ],
          [
           0.5,
           "rgb(65,182,196)"
          ],
          [
           0.625,
           "rgb(29,145,192)"
          ],
          [
           0.75,
           "rgb(34,94,168)"
          ],
          [
           0.875,
           "rgb(37,52,148)"
          ],
          [
           1,
           "rgb(8,29,88)"
          ]
         ],
         "hoverongaps": false,
         "text": [
          [
           "0.839 ± 0.098",
           "0.942 ± 0.003"
          ],
          [
           "0.920 ± 0.026",
           "0.325 ± 0.429"
          ]
         ],
         "textfont": {
          "size": 84
         },
         "texttemplate": "%{text}",
         "type": "heatmap",
         "x": [
          "🫵",
          "1️⃣"
         ],
         "y": [
          "1️⃣",
          "🫵"
         ],
         "z": [
          [
           0.8386666666666667,
           0.942
          ],
          [
           0.9199999999999999,
           0.3253333333333333
          ]
         ]
        }
       ],
       "layout": {
        "font": {
         "family": "Computer Modern",
         "size": 60
        },
        "height": 600,
        "template": {
         "data": {
          "bar": [
           {
            "error_x": {
             "color": "#2a3f5f"
            },
            "error_y": {
             "color": "#2a3f5f"
            },
            "marker": {
             "line": {
              "color": "#E5ECF6",
              "width": 0.5
             },
             "pattern": {
              "fillmode": "overlay",
              "size": 10,
              "solidity": 0.2
             }
            },
            "type": "bar"
           }
          ],
          "barpolar": [
           {
            "marker": {
             "line": {
              "color": "#E5ECF6",
              "width": 0.5
             },
             "pattern": {
              "fillmode": "overlay",
              "size": 10,
              "solidity": 0.2
             }
            },
            "type": "barpolar"
           }
          ],
          "carpet": [
           {
            "aaxis": {
             "endlinecolor": "#2a3f5f",
             "gridcolor": "white",
             "linecolor": "white",
             "minorgridcolor": "white",
             "startlinecolor": "#2a3f5f"
            },
            "baxis": {
             "endlinecolor": "#2a3f5f",
             "gridcolor": "white",
             "linecolor": "white",
             "minorgridcolor": "white",
             "startlinecolor": "#2a3f5f"
            },
            "type": "carpet"
           }
          ],
          "choropleth": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "type": "choropleth"
           }
          ],
          "contour": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "contour"
           }
          ],
          "contourcarpet": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "type": "contourcarpet"
           }
          ],
          "heatmap": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "heatmap"
           }
          ],
          "heatmapgl": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "heatmapgl"
           }
          ],
          "histogram": [
           {
            "marker": {
             "pattern": {
              "fillmode": "overlay",
              "size": 10,
              "solidity": 0.2
             }
            },
            "type": "histogram"
           }
          ],
          "histogram2d": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "histogram2d"
           }
          ],
          "histogram2dcontour": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "histogram2dcontour"
           }
          ],
          "mesh3d": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "type": "mesh3d"
           }
          ],
          "parcoords": [
           {
            "line": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "parcoords"
           }
          ],
          "pie": [
           {
            "automargin": true,
            "type": "pie"
           }
          ],
          "scatter": [
           {
            "fillpattern": {
             "fillmode": "overlay",
             "size": 10,
             "solidity": 0.2
            },
            "type": "scatter"
           }
          ],
          "scatter3d": [
           {
            "line": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatter3d"
           }
          ],
          "scattercarpet": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattercarpet"
           }
          ],
          "scattergeo": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattergeo"
           }
          ],
          "scattergl": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattergl"
           }
          ],
          "scattermapbox": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattermapbox"
           }
          ],
          "scatterpolar": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatterpolar"
           }
          ],
          "scatterpolargl": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatterpolargl"
           }
          ],
          "scatterternary": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatterternary"
           }
          ],
          "surface": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "surface"
           }
          ],
          "table": [
           {
            "cells": {
             "fill": {
              "color": "#EBF0F8"
             },
             "line": {
              "color": "white"
             }
            },
            "header": {
             "fill": {
              "color": "#C8D4E3"
             },
             "line": {
              "color": "white"
             }
            },
            "type": "table"
           }
          ]
         },
         "layout": {
          "annotationdefaults": {
           "arrowcolor": "#2a3f5f",
           "arrowhead": 0,
           "arrowwidth": 1
          },
          "autotypenumbers": "strict",
          "coloraxis": {
           "colorbar": {
            "outlinewidth": 0,
            "ticks": ""
           }
          },
          "colorscale": {
           "diverging": [
            [
             0,
             "#8e0152"
            ],
            [
             0.1,
             "#c51b7d"
            ],
            [
             0.2,
             "#de77ae"
            ],
            [
             0.3,
             "#f1b6da"
            ],
            [
             0.4,
             "#fde0ef"
            ],
            [
             0.5,
             "#f7f7f7"
            ],
            [
             0.6,
             "#e6f5d0"
            ],
            [
             0.7,
             "#b8e186"
            ],
            [
             0.8,
             "#7fbc41"
            ],
            [
             0.9,
             "#4d9221"
            ],
            [
             1,
             "#276419"
            ]
           ],
           "sequential": [
            [
             0,
             "#0d0887"
            ],
            [
             0.1111111111111111,
             "#46039f"
            ],
            [
             0.2222222222222222,
             "#7201a8"
            ],
            [
             0.3333333333333333,
             "#9c179e"
            ],
            [
             0.4444444444444444,
             "#bd3786"
            ],
            [
             0.5555555555555556,
             "#d8576b"
            ],
            [
             0.6666666666666666,
             "#ed7953"
            ],
            [
             0.7777777777777778,
             "#fb9f3a"
            ],
            [
             0.8888888888888888,
             "#fdca26"
            ],
            [
             1,
             "#f0f921"
            ]
           ],
           "sequentialminus": [
            [
             0,
             "#0d0887"
            ],
            [
             0.1111111111111111,
             "#46039f"
            ],
            [
             0.2222222222222222,
             "#7201a8"
            ],
            [
             0.3333333333333333,
             "#9c179e"
            ],
            [
             0.4444444444444444,
             "#bd3786"
            ],
            [
             0.5555555555555556,
             "#d8576b"
            ],
            [
             0.6666666666666666,
             "#ed7953"
            ],
            [
             0.7777777777777778,
             "#fb9f3a"
            ],
            [
             0.8888888888888888,
             "#fdca26"
            ],
            [
             1,
             "#f0f921"
            ]
           ]
          },
          "colorway": [
           "#636efa",
           "#EF553B",
           "#00cc96",
           "#ab63fa",
           "#FFA15A",
           "#19d3f3",
           "#FF6692",
           "#B6E880",
           "#FF97FF",
           "#FECB52"
          ],
          "font": {
           "color": "#2a3f5f"
          },
          "geo": {
           "bgcolor": "white",
           "lakecolor": "white",
           "landcolor": "#E5ECF6",
           "showlakes": true,
           "showland": true,
           "subunitcolor": "white"
          },
          "hoverlabel": {
           "align": "left"
          },
          "hovermode": "closest",
          "mapbox": {
           "style": "light"
          },
          "paper_bgcolor": "white",
          "plot_bgcolor": "#E5ECF6",
          "polar": {
           "angularaxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           },
           "bgcolor": "#E5ECF6",
           "radialaxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           }
          },
          "scene": {
           "xaxis": {
            "backgroundcolor": "#E5ECF6",
            "gridcolor": "white",
            "gridwidth": 2,
            "linecolor": "white",
            "showbackground": true,
            "ticks": "",
            "zerolinecolor": "white"
           },
           "yaxis": {
            "backgroundcolor": "#E5ECF6",
            "gridcolor": "white",
            "gridwidth": 2,
            "linecolor": "white",
            "showbackground": true,
            "ticks": "",
            "zerolinecolor": "white"
           },
           "zaxis": {
            "backgroundcolor": "#E5ECF6",
            "gridcolor": "white",
            "gridwidth": 2,
            "linecolor": "white",
            "showbackground": true,
            "ticks": "",
            "zerolinecolor": "white"
           }
          },
          "shapedefaults": {
           "line": {
            "color": "#2a3f5f"
           }
          },
          "ternary": {
           "aaxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           },
           "baxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           },
           "bgcolor": "#E5ECF6",
           "caxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           }
          },
          "title": {
           "x": 0.05
          },
          "xaxis": {
           "automargin": true,
           "gridcolor": "white",
           "linecolor": "white",
           "ticks": "",
           "title": {
            "standoff": 15
           },
           "zerolinecolor": "white",
           "zerolinewidth": 2
          },
          "yaxis": {
           "automargin": true,
           "gridcolor": "white",
           "linecolor": "white",
           "ticks": "",
           "title": {
            "standoff": 15
           },
           "zerolinecolor": "white",
           "zerolinewidth": 2
          }
         }
        },
        "width": 700,
        "xaxis": {
         "title": {
          "text": "Eval Intent Format"
         }
        },
        "yaxis": {
         "title": {
          "text": "Train Intent Format"
         }
        }
       }
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Figure saved as 'confusion_matrix_pair_accuracy.png'\n"
     ]
    },
    {
     "data": {
      "application/vnd.plotly.v1+json": {
       "config": {
        "plotlyServerURL": "https://plot.ly"
       },
       "data": [
        {
         "colorscale": [
          [
           0,
           "rgb(255,255,217)"
          ],
          [
           0.125,
           "rgb(237,248,177)"
          ],
          [
           0.25,
           "rgb(199,233,180)"
          ],
          [
           0.375,
           "rgb(127,205,187)"
          ],
          [
           0.5,
           "rgb(65,182,196)"
          ],
          [
           0.625,
           "rgb(29,145,192)"
          ],
          [
           0.75,
           "rgb(34,94,168)"
          ],
          [
           0.875,
           "rgb(37,52,148)"
          ],
          [
           1,
           "rgb(8,29,88)"
          ]
         ],
         "hoverongaps": false,
         "text": [
          [
           "0.855 ± 0.009",
           "0.853 ± 0.006"
          ],
          [
           "0.844 ± 0.002",
           "0.709 ± 0.018"
          ]
         ],
         "textfont": {
          "size": 84
         },
         "texttemplate": "%{text}",
         "type": "heatmap",
         "x": [
          "🫵",
          "1️⃣"
         ],
         "y": [
          "1️⃣",
          "🫵"
         ],
         "z": [
          [
           0.8553333333333333,
           0.8533333333333334
          ],
          [
           0.844,
           0.7086666666666667
          ]
         ]
        }
       ],
       "layout": {
        "font": {
         "family": "Computer Modern",
         "size": 60
        },
        "height": 600,
        "template": {
         "data": {
          "bar": [
           {
            "error_x": {
             "color": "#2a3f5f"
            },
            "error_y": {
             "color": "#2a3f5f"
            },
            "marker": {
             "line": {
              "color": "#E5ECF6",
              "width": 0.5
             },
             "pattern": {
              "fillmode": "overlay",
              "size": 10,
              "solidity": 0.2
             }
            },
            "type": "bar"
           }
          ],
          "barpolar": [
           {
            "marker": {
             "line": {
              "color": "#E5ECF6",
              "width": 0.5
             },
             "pattern": {
              "fillmode": "overlay",
              "size": 10,
              "solidity": 0.2
             }
            },
            "type": "barpolar"
           }
          ],
          "carpet": [
           {
            "aaxis": {
             "endlinecolor": "#2a3f5f",
             "gridcolor": "white",
             "linecolor": "white",
             "minorgridcolor": "white",
             "startlinecolor": "#2a3f5f"
            },
            "baxis": {
             "endlinecolor": "#2a3f5f",
             "gridcolor": "white",
             "linecolor": "white",
             "minorgridcolor": "white",
             "startlinecolor": "#2a3f5f"
            },
            "type": "carpet"
           }
          ],
          "choropleth": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "type": "choropleth"
           }
          ],
          "contour": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "contour"
           }
          ],
          "contourcarpet": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "type": "contourcarpet"
           }
          ],
          "heatmap": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "heatmap"
           }
          ],
          "heatmapgl": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "heatmapgl"
           }
          ],
          "histogram": [
           {
            "marker": {
             "pattern": {
              "fillmode": "overlay",
              "size": 10,
              "solidity": 0.2
             }
            },
            "type": "histogram"
           }
          ],
          "histogram2d": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "histogram2d"
           }
          ],
          "histogram2dcontour": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "histogram2dcontour"
           }
          ],
          "mesh3d": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "type": "mesh3d"
           }
          ],
          "parcoords": [
           {
            "line": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "parcoords"
           }
          ],
          "pie": [
           {
            "automargin": true,
            "type": "pie"
           }
          ],
          "scatter": [
           {
            "fillpattern": {
             "fillmode": "overlay",
             "size": 10,
             "solidity": 0.2
            },
            "type": "scatter"
           }
          ],
          "scatter3d": [
           {
            "line": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatter3d"
           }
          ],
          "scattercarpet": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattercarpet"
           }
          ],
          "scattergeo": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattergeo"
           }
          ],
          "scattergl": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattergl"
           }
          ],
          "scattermapbox": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattermapbox"
           }
          ],
          "scatterpolar": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatterpolar"
           }
          ],
          "scatterpolargl": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatterpolargl"
           }
          ],
          "scatterternary": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatterternary"
           }
          ],
          "surface": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "surface"
           }
          ],
          "table": [
           {
            "cells": {
             "fill": {
              "color": "#EBF0F8"
             },
             "line": {
              "color": "white"
             }
            },
            "header": {
             "fill": {
              "color": "#C8D4E3"
             },
             "line": {
              "color": "white"
             }
            },
            "type": "table"
           }
          ]
         },
         "layout": {
          "annotationdefaults": {
           "arrowcolor": "#2a3f5f",
           "arrowhead": 0,
           "arrowwidth": 1
          },
          "autotypenumbers": "strict",
          "coloraxis": {
           "colorbar": {
            "outlinewidth": 0,
            "ticks": ""
           }
          },
          "colorscale": {
           "diverging": [
            [
             0,
             "#8e0152"
            ],
            [
             0.1,
             "#c51b7d"
            ],
            [
             0.2,
             "#de77ae"
            ],
            [
             0.3,
             "#f1b6da"
            ],
            [
             0.4,
             "#fde0ef"
            ],
            [
             0.5,
             "#f7f7f7"
            ],
            [
             0.6,
             "#e6f5d0"
            ],
            [
             0.7,
             "#b8e186"
            ],
            [
             0.8,
             "#7fbc41"
            ],
            [
             0.9,
             "#4d9221"
            ],
            [
             1,
             "#276419"
            ]
           ],
           "sequential": [
            [
             0,
             "#0d0887"
            ],
            [
             0.1111111111111111,
             "#46039f"
            ],
            [
             0.2222222222222222,
             "#7201a8"
            ],
            [
             0.3333333333333333,
             "#9c179e"
            ],
            [
             0.4444444444444444,
             "#bd3786"
            ],
            [
             0.5555555555555556,
             "#d8576b"
            ],
            [
             0.6666666666666666,
             "#ed7953"
            ],
            [
             0.7777777777777778,
             "#fb9f3a"
            ],
            [
             0.8888888888888888,
             "#fdca26"
            ],
            [
             1,
             "#f0f921"
            ]
           ],
           "sequentialminus": [
            [
             0,
             "#0d0887"
            ],
            [
             0.1111111111111111,
             "#46039f"
            ],
            [
             0.2222222222222222,
             "#7201a8"
            ],
            [
             0.3333333333333333,
             "#9c179e"
            ],
            [
             0.4444444444444444,
             "#bd3786"
            ],
            [
             0.5555555555555556,
             "#d8576b"
            ],
            [
             0.6666666666666666,
             "#ed7953"
            ],
            [
             0.7777777777777778,
             "#fb9f3a"
            ],
            [
             0.8888888888888888,
             "#fdca26"
            ],
            [
             1,
             "#f0f921"
            ]
           ]
          },
          "colorway": [
           "#636efa",
           "#EF553B",
           "#00cc96",
           "#ab63fa",
           "#FFA15A",
           "#19d3f3",
           "#FF6692",
           "#B6E880",
           "#FF97FF",
           "#FECB52"
          ],
          "font": {
           "color": "#2a3f5f"
          },
          "geo": {
           "bgcolor": "white",
           "lakecolor": "white",
           "landcolor": "#E5ECF6",
           "showlakes": true,
           "showland": true,
           "subunitcolor": "white"
          },
          "hoverlabel": {
           "align": "left"
          },
          "hovermode": "closest",
          "mapbox": {
           "style": "light"
          },
          "paper_bgcolor": "white",
          "plot_bgcolor": "#E5ECF6",
          "polar": {
           "angularaxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           },
           "bgcolor": "#E5ECF6",
           "radialaxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           }
          },
          "scene": {
           "xaxis": {
            "backgroundcolor": "#E5ECF6",
            "gridcolor": "white",
            "gridwidth": 2,
            "linecolor": "white",
            "showbackground": true,
            "ticks": "",
            "zerolinecolor": "white"
           },
           "yaxis": {
            "backgroundcolor": "#E5ECF6",
            "gridcolor": "white",
            "gridwidth": 2,
            "linecolor": "white",
            "showbackground": true,
            "ticks": "",
            "zerolinecolor": "white"
           },
           "zaxis": {
            "backgroundcolor": "#E5ECF6",
            "gridcolor": "white",
            "gridwidth": 2,
            "linecolor": "white",
            "showbackground": true,
            "ticks": "",
            "zerolinecolor": "white"
           }
          },
          "shapedefaults": {
           "line": {
            "color": "#2a3f5f"
           }
          },
          "ternary": {
           "aaxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           },
           "baxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           },
           "bgcolor": "#E5ECF6",
           "caxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           }
          },
          "title": {
           "x": 0.05
          },
          "xaxis": {
           "automargin": true,
           "gridcolor": "white",
           "linecolor": "white",
           "ticks": "",
           "title": {
            "standoff": 15
           },
           "zerolinecolor": "white",
           "zerolinewidth": 2
          },
          "yaxis": {
           "automargin": true,
           "gridcolor": "white",
           "linecolor": "white",
           "ticks": "",
           "title": {
            "standoff": 15
           },
           "zerolinecolor": "white",
           "zerolinewidth": 2
          }
         }
        },
        "width": 700,
        "xaxis": {
         "title": {
          "text": "Eval Intent Format"
         }
        },
        "yaxis": {
         "title": {
          "text": "Train Intent Format"
         }
        }
       }
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Figure saved as 'confusion_matrix_pair_accuracy.png'\n"
     ]
    }
   ],
   "source": [
    "import plotly.graph_objects as go\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "for model, metrics_instruct_only in metrics_df_short.groupby(\"MODEL_ID\"):\n",
    "    # Calculate mean and standard deviation of pair_acc across seeds\n",
    "    grouped = metrics_instruct_only.groupby(\n",
    "        [\"CONTEXT_WEIGHT_FORMAT\", \"EVAL_CTX_WEIGHT_FORMAT\"]\n",
    "    )[\"pair_acc\"]\n",
    "    mean_metrics = grouped.mean().unstack()\n",
    "    std_metrics = grouped.std().unstack()\n",
    "\n",
    "    # Reorder the index and columns\n",
    "    mean_metrics = mean_metrics.reindex(\n",
    "        columns=[\"instruction\", \"float\"], index=[\"float\", \"instruction\"]\n",
    "    )\n",
    "    std_metrics = std_metrics.reindex(\n",
    "        columns=[\"instruction\", \"float\"], index=[\"float\", \"instruction\"]\n",
    "    )\n",
    "\n",
    "    # Create text for the heatmap cells\n",
    "    text = [\n",
    "        [f\"{mean:.3f} ± {std:.3f}\" for mean, std in zip(row_means, row_stds)]\n",
    "        for row_means, row_stds in zip(mean_metrics.values, std_metrics.values)\n",
    "    ]\n",
    "\n",
    "    # Create the heatmap using plotly\n",
    "    fig = go.Figure(\n",
    "        data=go.Heatmap(\n",
    "            z=mean_metrics.values,\n",
    "            x=[\"🫵\", \"1️⃣\"],\n",
    "            y=[\"1️⃣\", \"🫵\"],\n",
    "            hoverongaps=False,\n",
    "            text=text,\n",
    "            texttemplate=\"%{text}\",\n",
    "            textfont={\"size\": 84},\n",
    "            colorscale=\"YlGnBu\",\n",
    "        )\n",
    "    )\n",
    "\n",
    "    fig.update_layout(\n",
    "        # title=\"Confusion Matrix of Mean Pair Accuracy (± Std Dev) for Meta-Llama-3.1-8B-Instruct\",\n",
    "        xaxis_title=\"Eval Intent Format\",\n",
    "        yaxis_title=\"Train Intent Format\",\n",
    "        width=700,\n",
    "        height=600,\n",
    "        font=dict(family=\"Computer Modern\", size=60),\n",
    "    )\n",
    "\n",
    "    fig.show()\n",
    "\n",
    "    # Save the figure to a high resolution PNG\n",
    "    fig.write_image(\n",
    "        f\"{model}_confusion_matrix_pair_accuracy.png\", scale=4, width=1400, height=1200\n",
    "    )\n",
    "\n",
    "    # Optionally, you can also save it as an interactive HTML file\n",
    "    # fig.write_html(\"confusion_matrix_pair_accuracy.html\")\n",
    "\n",
    "    print(\"Figure saved as 'confusion_matrix_pair_accuracy.png'\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sftcontext",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
