{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Arjun\\Code\\ai\\self_detection\\venv\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from data import load_data, save_to_json, load_from_json\n",
    "from matplotlib.gridspec import GridSpec\n",
    "import seaborn as sns\n",
    "import numpy as np\n",
    "import plotly.graph_objects as go\n",
    "from plotly.subplots import make_subplots\n",
    "from tabulate import tabulate\n",
    "import pandas as pd\n",
    "from itertools import zip_longest\n",
    "from pandas import MultiIndex\n",
    "import altair as alt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import log\n",
    "from scipy.stats import kendalltau\n",
    "\n",
    "COLORS = ['red', 'blue', 'green', 'orange', 'purple', 'cyan', 'magenta', 'yellow', 'black', 'grey', 'pink']\n",
    "\n",
    "def avg(l):\n",
    "    return sum(l) / len(l)\n",
    "\n",
    "def brier_score(pred):\n",
    "    return avg([(x - 1) ** 2 for x in pred])\n",
    "\n",
    "def log_loss(pred):\n",
    "    return -1 * avg([log(x) for x in pred])\n",
    "\n",
    "def kendall_tau_for_results(model_results):\n",
    "    detection_scores = [i['detection_score'] for i in model_results] \n",
    "    self_preferences = [i['self_preference'] for i in model_results]\n",
    "    \n",
    "    return kendalltau(detection_scores, self_preferences).correlation\n",
    "\n",
    "def kendall_tau(x, y):\n",
    "    return kendalltau(x, y).correlation\n",
    "\n",
    "MODEL_TO_STRING = {\n",
    "    'claude': 'Claude 2.1',\n",
    "    'llama': 'LLaMA-2-7b-chat',\n",
    "    'human': 'Human',\n",
    "    'gpt4': 'GPT-4 11/06',\n",
    "    'gpt35': 'Llama-2-7b-chat',\n",
    "\n",
    "    'xsum_500_ft_gpt35': '[XSUM] FT GPT-3.5 Turbo 11/06 (500 examples)',\n",
    "    'xsum_10_ft_gpt35': '[XSUM] FT GPT-3.5 Turbo 11/06 (10 examples)',\n",
    "    'xsum_2_ft_gpt35': '[XSUM] FT GPT-3.5 Turbo 11/06 (2 examples)',\n",
    "    'xsum_always_1_ft_gpt35': '[XSUM] FT GPT-3.5 Turbo 11/06 (answers always 1)',\n",
    "    'xsum_random_ft_gpt35': '[XSUM] FT GPT-3.5 Turbo 11/06 (random answers)',\n",
    "    'xsum_readability_ft_gpt35': '[XSUM] FT GPT-3.5 Turbo 11/06 (readability)',\n",
    "    'xsum_length_ft_gpt35': '[XSUM] FT GPT-3.5 Turbo 11/06 (length)',\n",
    "    'xsum_vowelcount_ft_gpt35': '[XSUM] FT GPT-3.5 Turbo 11/06 (vowel count)',\n",
    "\n",
    "    'cnn_500_ft_gpt35': '[CNN] FT GPT-3.5 Turbo 11/06 (500 examples)',\n",
    "    'cnn_10_ft_gpt35': '[CNN] FT GPT-3.5 Turbo 11/06 (10 examples)',\n",
    "    'cnn_2_ft_gpt35': '[CNN] FT GPT-3.5 Turbo 11/06 (2 examples)',\n",
    "    'cnn_always_1_ft_gpt35': '[CNN] FT GPT-3.5 Turbo 11/06 (answers always 1)',\n",
    "    'cnn_random_ft_gpt35': '[CNN] FT GPT-3.5 Turbo 11/06 (random answers)',\n",
    "    'cnn_readability_ft_gpt35': '[CNN] FT GPT-3.5 Turbo 11/06 (readability)',\n",
    "    'cnn_length_ft_gpt35': '[CNN] FT GPT-3.5 Turbo 11/06 (length)',\n",
    "    'cnn_vowelcount_ft_gpt35': '[CNN] FT GPT-3.5 Turbo 11/06 (vowel count)',\n",
    "\n",
    "    'xsum_500_ft_llama': '[XSUM] FT Llama-2-7b-chat (500 examples)',\n",
    "    'xsum_10_ft_llama': '[XSUM] FT Llama-2-7b-chat (10 examples)',\n",
    "    'xsum_2_ft_llama': '[XSUM] FT Llama-2-7b-chat (2 examples)',\n",
    "    'xsum_always_1_ft_llama': '[XSUM] FT Llama-2-7b-chat (answers always 1)',\n",
    "    'xsum_random_ft_llama': '[XSUM] FT Llama-2-7b-chat (random answers)',\n",
    "    'xsum_readability_ft_llama': '[XSUM] FT Llama-2-7b-chat (readability)',\n",
    "    'xsum_length_ft_llama': '[XSUM] FT Llama-2-7b-chat (length)',\n",
    "    'xsum_vowelcount_ft_llama': '[XSUM] FT Llama-2-7b-chat (vowel count)',\n",
    "\n",
    "    'cnn_500_ft_llama': '[CNN] FT Llama-2-7b-chat (500 examples)',\n",
    "    'cnn_10_ft_llama': '[CNN] FT Llama-2-7b-chat (10 examples)',\n",
    "    'cnn_2_ft_llama': '[CNN] FT Llama-2-7b-chat (2 examples)',\n",
    "    'cnn_always_1_ft_llama': '[CNN] FT Llama-2-7b-chat (answers always 1)',\n",
    "    'cnn_random_ft_llama': '[CNN] FT Llama-2-7b-chat (random answers)',\n",
    "    'cnn_readability_ft_llama': '[CNN] FT Llama-2-7b-chat (readability)',\n",
    "    'cnn_length_ft_llama': '[CNN] FT Llama-2-7b-chat (length)',\n",
    "    'cnn_vowelcount_ft_llama': '[CNN] FT Llama-2-7b-chat (vowel count)',\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "main_models = ['gpt4', 'gpt35', 'llama']\n",
    "xsum_models_gpt35 = ['xsum_2_ft_gpt35', 'xsum_10_ft_gpt35', 'xsum_500_ft_gpt35', 'xsum_always_1_ft_gpt35', 'xsum_random_ft_gpt35', 'xsum_readability_ft_gpt35', 'xsum_length_ft_gpt35', 'xsum_vowelcount_ft_gpt35']\n",
    "cnn_models_gpt35 = ['cnn_2_ft_gpt35', 'cnn_10_ft_gpt35', 'cnn_500_ft_gpt35', 'cnn_always_1_ft_gpt35', 'cnn_random_ft_gpt35', 'cnn_readability_ft_gpt35', 'cnn_length_ft_gpt35', 'cnn_vowelcount_ft_gpt35']\n",
    "\n",
    "xsum_models_llama = ['xsum_2_ft_llama', 'xsum_10_ft_llama', 'xsum_500_ft_llama', 'xsum_always_1_ft_llama', 'xsum_random_ft_llama', 'xsum_readability_ft_llama', 'xsum_length_ft_llama', 'xsum_vowelcount_ft_llama']\n",
    "cnn_models_llama = ['cnn_2_ft_llama', 'cnn_10_ft_llama', 'cnn_500_ft_llama', 'cnn_always_1_ft_llama', 'cnn_random_ft_llama', 'cnn_readability_ft_llama', 'cnn_length_ft_llama', 'cnn_vowelcount_ft_llama']\n",
    "\n",
    "models = main_models + xsum_models_gpt35 + cnn_models_gpt35 + xsum_models_llama + cnn_models_llama\n",
    "\n",
    "xsum_responses, xsum_articles, xsum_keys = load_data('xsum')\n",
    "cnn_responses, cnn_articles, cnn_keys = load_data('cnn')\n",
    "\n",
    "xsum_results = {}\n",
    "cnn_results = {}\n",
    "for model in models:\n",
    "    xsum_results[model] = load_from_json(f'results/xsum/{model}_results.json')\n",
    "    cnn_results[model] = load_from_json(f'results/cnn/{model}_results.json')\n",
    "    \n",
    "    if model in main_models:\n",
    "        continue\n",
    "    elif '_2_ft_' in model:\n",
    "        xsum_results[model] = [i for i in xsum_results[model] if i['key'] in xsum_keys[2:]]\n",
    "        cnn_results[model] = [i for i in cnn_results[model] if i['key'] in cnn_keys[2:]]\n",
    "    elif '_10_ft_' in model:\n",
    "        xsum_results[model] = [i for i in xsum_results[model] if i['key'] in xsum_keys[10:]]\n",
    "        cnn_results[model] = [i for i in cnn_results[model] if i['key'] in cnn_keys[10:]]\n",
    "    else:\n",
    "        xsum_results[model] = [i for i in xsum_results[model] if i['key'] in xsum_keys[500:]]\n",
    "        cnn_results[model] = [i for i in cnn_results[model] if i['key'] in cnn_keys[500:]]\n",
    "\n",
    "# For the label result data, \"self_preference\" is the model's preference for the first summary\n",
    "cnn_correct_label_results = {}\n",
    "cnn_wrong_label_results = {}\n",
    "cnn_random_label_results = {}\n",
    "xsum_correct_label_results = {}\n",
    "xsum_wrong_label_results = {}\n",
    "xsum_random_label_results = {}\n",
    "for model in main_models:\n",
    "    dataset = 'cnn'\n",
    "    cnn_correct_label_results[model] = load_from_json(f'label_results/correct_label_results/{dataset}/{model}_results.json')\n",
    "    cnn_wrong_label_results[model] = load_from_json(f'label_results/wrong_label_results/{dataset}/{model}_results.json')\n",
    "    # cnn_random_label_results[model] = load_from_json(f'label_results/random_label_results/{dataset}/{model}_results.json')\n",
    "    dataset = 'xsum'\n",
    "    xsum_correct_label_results[model] = load_from_json(f'label_results/correct_label_results/{dataset}/{model}_results.json')\n",
    "    xsum_wrong_label_results[model] = load_from_json(f'label_results/wrong_label_results/{dataset}/{model}_results.json')\n",
    "    # xsum_random_label_results[model] = load_from_json(f'label_results/random_label_results/{dataset}/{model}_results.json')\n",
    "\n",
    "for results in [cnn_correct_label_results, cnn_wrong_label_results, xsum_correct_label_results, xsum_wrong_label_results]:\n",
    "    results['llama'] = [result for result in results['llama'] if any(r['key'] == result['key'] for r in results['gpt4'])]\n",
    "\n",
    "# Individual setting results\n",
    "xsum_detection_results = {}\n",
    "xsum_score_results = {}\n",
    "cnn_detection_results = {}\n",
    "cnn_score_results = {}\n",
    "for model in models:\n",
    "    dataset = 'xsum'\n",
    "    xsum_detection_results[model] = load_from_json(f'individual_setting_results/recognition_results/{dataset}/{model}_results.json')\n",
    "    xsum_score_results[model] = load_from_json(f'individual_setting_results/score_results/{dataset}/{model}_results.json')\n",
    "    dataset = 'cnn'\n",
    "    cnn_detection_results[model] = load_from_json(f'individual_setting_results/recognition_results/{dataset}/{model}_results.json')\n",
    "    cnn_score_results[model] = load_from_json(f'individual_setting_results/score_results/{dataset}/{model}_results.json')\n",
    "\n",
    "    if model in main_models:\n",
    "        continue\n",
    "    elif '_2_ft_' in model:\n",
    "        xsum_detection_results[model] = [i for i in xsum_detection_results[model] if i['key'] in xsum_keys[2:]]\n",
    "        xsum_score_results[model] = [i for i in xsum_score_results[model] if i['key'] in xsum_keys[2:]]\n",
    "        cnn_detection_results[model] = [i for i in cnn_detection_results[model] if i['key'] in cnn_keys[2:]]\n",
    "        cnn_score_results[model] = [i for i in cnn_score_results[model] if i['key'] in cnn_keys[2:]]\n",
    "    elif '_10_ft_' in model:\n",
    "        xsum_detection_results[model] = [i for i in xsum_detection_results[model] if i['key'] in xsum_keys[10:]]\n",
    "        xsum_score_results[model] = [i for i in xsum_score_results[model] if i['key'] in xsum_keys[10:]]\n",
    "        cnn_detection_results[model] = [i for i in cnn_detection_results[model] if i['key'] in cnn_keys[10:]]\n",
    "        cnn_score_results[model] = [i for i in cnn_score_results[model] if i['key'] in cnn_keys[10:]]\n",
    "    else:\n",
    "        xsum_detection_results[model] = [i for i in xsum_detection_results[model] if i['key'] in xsum_keys[500:]]\n",
    "        xsum_score_results[model] = [i for i in xsum_score_results[model] if i['key'] in xsum_keys[500:]]\n",
    "        cnn_detection_results[model] = [i for i in cnn_detection_results[model] if i['key'] in cnn_keys[500:]]\n",
    "        cnn_score_results[model] = [i for i in cnn_score_results[model] if i['key'] in cnn_keys[500:]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "| Model                     |       XSUM |        CNN |   # Bad Outputs (e.g. NaN) - XSUM |   # Bad Outputs (e.g. NaN) - CNN |\n",
      "|---------------------------|------------|------------|-----------------------------------|----------------------------------|\n",
      "| xsum_2_ft_gpt35           |   0.239259 |   0.409474 |                                 0 |                                0 |\n",
      "| xsum_10_ft_gpt35          |   0.351976 |   0.440032 |                                 0 |                                0 |\n",
      "| xsum_500_ft_gpt35         |   0.181138 |   0.830678 |                                 0 |                                1 |\n",
      "| xsum_always_1_ft_gpt35    |   0.5      |   0.5      |                                 0 |                                0 |\n",
      "| xsum_random_ft_gpt35      |   0.5      |   0.5      |                                 0 |                                0 |\n",
      "| xsum_readability_ft_gpt35 |   0.569504 |   0.418801 |                                 0 |                                0 |\n",
      "| xsum_length_ft_gpt35      |   0.295818 |   0.630802 |                                 0 |                                0 |\n",
      "| xsum_vowelcount_ft_gpt35  |   0.228472 |   0.840382 |                                 0 |                                0 |\n",
      "| cnn_2_ft_gpt35            |   0.198612 |   0.32733  |                                 0 |                                1 |\n",
      "| cnn_10_ft_gpt35           |   0.110884 |   0.412831 |                                 0 |                                0 |\n",
      "| cnn_500_ft_gpt35          |   0.252215 |   0.44     |                                 0 |                                0 |\n",
      "| cnn_always_1_ft_gpt35     |   0.5      |   0.5      |                                 0 |                                0 |\n",
      "| cnn_random_ft_gpt35       |   0.500977 |   0.502331 |                                 0 |                                0 |\n",
      "| cnn_readability_ft_gpt35  |   0.632709 |   0.537255 |                                 0 |                                0 |\n",
      "| cnn_length_ft_gpt35       |   0.305934 |   0.557903 |                                 0 |                                0 |\n",
      "| cnn_vowelcount_ft_gpt35   |   0.351465 |   0.653654 |                                 0 |                                0 |\n",
      "| xsum_2_ft_llama           |   0.497366 |   0.477095 |                                 0 |                                0 |\n",
      "| xsum_10_ft_llama          |   0.492147 |   0.443973 |                                 0 |                                0 |\n",
      "| xsum_500_ft_llama         |   0.372354 |   0.442329 |                                 0 |                                0 |\n",
      "| xsum_always_1_ft_llama    | nan        | nan        |                                50 |                               50 |\n",
      "| xsum_random_ft_llama      |   0.501927 |   0.490814 |                                 0 |                                0 |\n",
      "| xsum_readability_ft_llama |   0.498449 |   0.489863 |                                 0 |                                0 |\n",
      "| xsum_length_ft_llama      |   0.5367   |   0.503469 |                                 0 |                                0 |\n",
      "| xsum_vowelcount_ft_llama  |   0.532137 |   0.528725 |                                 0 |                                0 |\n",
      "| cnn_2_ft_llama            |   0.498157 |   0.456382 |                                 0 |                                0 |\n",
      "| cnn_10_ft_llama           |   0.501493 |   0.46072  |                                 0 |                                0 |\n",
      "| cnn_500_ft_llama          | nan        | nan        |                                 1 |                                5 |\n",
      "| cnn_always_1_ft_llama     | nan        | nan        |                                 4 |                                8 |\n",
      "| cnn_random_ft_llama       |   0.499854 |   0.500691 |                                 0 |                                0 |\n",
      "| cnn_readability_ft_llama  |   0.491107 |   0.500197 |                                 0 |                                0 |\n",
      "| cnn_length_ft_llama       | nan        | nan        |                                50 |                               50 |\n",
      "| cnn_vowelcount_ft_llama   |   0.494051 |   0.489912 |                                 0 |                                0 |\n",
      "\n",
      "\n",
      "Without degraded generations\n",
      "\n",
      "\n",
      "| Model                     |     XSUM |      CNN |   # Bad Outputs (e.g. NaN) - XSUM |   # Bad Outputs (e.g. NaN) - CNN |\n",
      "|---------------------------|----------|----------|-----------------------------------|----------------------------------|\n",
      "| xsum_2_ft_gpt35           | 0.239259 | 0.409474 |                                 0 |                                0 |\n",
      "| xsum_10_ft_gpt35          | 0.351976 | 0.440032 |                                 0 |                                0 |\n",
      "| xsum_500_ft_gpt35         | 0.181138 | 0.830678 |                                 0 |                                1 |\n",
      "| xsum_always_1_ft_gpt35    | 0.5      | 0.5      |                                 0 |                                0 |\n",
      "| xsum_random_ft_gpt35      | 0.5      | 0.5      |                                 0 |                                0 |\n",
      "| xsum_readability_ft_gpt35 | 0.569504 | 0.418801 |                                 0 |                                0 |\n",
      "| xsum_length_ft_gpt35      | 0.295818 | 0.630802 |                                 0 |                                0 |\n",
      "| xsum_vowelcount_ft_gpt35  | 0.228472 | 0.840382 |                                 0 |                                0 |\n",
      "| cnn_2_ft_gpt35            | 0.198612 | 0.32733  |                                 0 |                                1 |\n",
      "| cnn_10_ft_gpt35           | 0.110884 | 0.412831 |                                 0 |                                0 |\n",
      "| cnn_500_ft_gpt35          | 0.252215 | 0.44     |                                 0 |                                0 |\n",
      "| cnn_always_1_ft_gpt35     | 0.5      | 0.5      |                                 0 |                                0 |\n",
      "| cnn_random_ft_gpt35       | 0.500977 | 0.502331 |                                 0 |                                0 |\n",
      "| cnn_readability_ft_gpt35  | 0.632709 | 0.537255 |                                 0 |                                0 |\n",
      "| cnn_length_ft_gpt35       | 0.305934 | 0.557903 |                                 0 |                                0 |\n",
      "| cnn_vowelcount_ft_gpt35   | 0.351465 | 0.653654 |                                 0 |                                0 |\n",
      "| xsum_2_ft_llama           | 0.497366 | 0.477095 |                                 0 |                                0 |\n",
      "| xsum_10_ft_llama          | 0.492147 | 0.443973 |                                 0 |                                0 |\n",
      "| xsum_500_ft_llama         | 0.372354 | 0.442329 |                                 0 |                                0 |\n",
      "| xsum_random_ft_llama      | 0.501927 | 0.490814 |                                 0 |                                0 |\n",
      "| xsum_readability_ft_llama | 0.498449 | 0.489863 |                                 0 |                                0 |\n",
      "| xsum_length_ft_llama      | 0.5367   | 0.503469 |                                 0 |                                0 |\n",
      "| xsum_vowelcount_ft_llama  | 0.532137 | 0.528725 |                                 0 |                                0 |\n",
      "| cnn_2_ft_llama            | 0.498157 | 0.456382 |                                 0 |                                0 |\n",
      "| cnn_10_ft_llama           | 0.501493 | 0.46072  |                                 0 |                                0 |\n"
     ]
    }
   ],
   "source": [
    "# Cross-Model Evals\n",
    "\n",
    "print(tabulate(\n",
    "    [\n",
    "        [\n",
    "            model,\n",
    "            avg([r[\"new_preference\"] for r in load_from_json(f\"comparisons/xsum/{model}_comparisons.json\") if \"new_preference\" in r]),\n",
    "            avg([r[\"new_preference\"] for r in load_from_json(f\"comparisons/cnn/{model}_comparisons.json\") if \"new_preference\" in r]),\n",
    "            len([r for r in load_from_json(f\"comparisons/xsum/{model}_comparisons.json\") if \"new_preference\" not in r or r['new_preference'] != r['new_preference']]),\n",
    "            len([r for r in load_from_json(f\"comparisons/cnn/{model}_comparisons.json\") if \"new_preference\" not in r or r['new_preference'] != r['new_preference']]),\n",
    "        ]\n",
    "        for model in xsum_models_gpt35 + cnn_models_gpt35 + xsum_models_llama + cnn_models_llama\n",
    "    ],\n",
    "    headers=[\"Model\", \"XSUM\", \"CNN\", \"# Bad Outputs (e.g. NaN) - XSUM\", \"# Bad Outputs (e.g. NaN) - CNN\"],\n",
    "    tablefmt=\"github\",\n",
    "))\n",
    "\n",
    "ones_twos_models = ['cnn_always_1_ft_llama', 'cnn_readability_ft_llama', 'cnn_vowelcount_ft_llama', 'cnn_random_ft_llama']\n",
    "generation_error_models = ['cnn_500_ft_llama', 'cnn_length_ft_llama', 'xsum_always_1_ft_llama']\n",
    "\n",
    "print('\\n\\nWithout degraded generations\\n\\n')\n",
    "print(tabulate(\n",
    "    [\n",
    "        [\n",
    "            model,\n",
    "            avg([r[\"new_preference\"] for r in load_from_json(f\"comparisons/xsum/{model}_comparisons.json\") if \"new_preference\" in r]),\n",
    "            avg([r[\"new_preference\"] for r in load_from_json(f\"comparisons/cnn/{model}_comparisons.json\") if \"new_preference\" in r]),\n",
    "            len([r for r in load_from_json(f\"comparisons/xsum/{model}_comparisons.json\") if \"new_preference\" not in r or r['new_preference'] != r['new_preference']]),\n",
    "            len([r for r in load_from_json(f\"comparisons/cnn/{model}_comparisons.json\") if \"new_preference\" not in r or r['new_preference'] != r['new_preference']]),\n",
    "        ]\n",
    "        for model in xsum_models_gpt35 + cnn_models_gpt35 + xsum_models_llama + cnn_models_llama if model not in ones_twos_models + generation_error_models\n",
    "    ],\n",
    "    headers=[\"Model\", \"XSUM\", \"CNN\", \"# Bad Outputs (e.g. NaN) - XSUM\", \"# Bad Outputs (e.g. NaN) - CNN\"],\n",
    "    tablefmt=\"github\",\n",
    "))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Correct Labels\n",
      "Model      Self-Preference (XSUM)    Self-Preference (CNN)\n",
      "-------  ------------------------  -----------------------\n",
      "gpt4                     0.73245                  0.936168\n",
      "gpt35                    0.605016                 0.46125\n",
      "llama                    0.819782                 0.973418\n",
      "\n",
      "\n",
      "\n",
      "Incorrect Labels\n",
      "Model      Self-Preference (XSUM)    Self-Preference (CNN)\n",
      "-------  ------------------------  -----------------------\n",
      "gpt4                     0.315728                 0.89418\n",
      "gpt35                    0.455228                 0.400281\n",
      "llama                    0.834755                 0.973516\n"
     ]
    }
   ],
   "source": [
    "# Label Results\n",
    "\n",
    "table = [[model,\n",
    "            # avg([i['detection_score'] for i in xsum_correct_label_results[model]]),\n",
    "            avg([i['self_preference'] for i in xsum_correct_label_results[model]]),\n",
    "            # avg([i['detection_score'] for i in cnn_correct_label_results[model]]),\n",
    "            avg([i['self_preference'] for i in cnn_correct_label_results[model]])\n",
    "] for model in main_models]\n",
    "\n",
    "print('Correct Labels')\n",
    "print(tabulate(table, headers=['Model', 'Self-Preference (XSUM)', 'Self-Preference (CNN)']))\n",
    "\n",
    "table = [[model,\n",
    "            # avg([i['detection_score'] for i in xsum_wrong_label_results[model]]),\n",
    "            avg([i['self_preference'] for i in xsum_wrong_label_results[model]]),\n",
    "            # avg([i['detection_score'] for i in cnn_wrong_label_results[model]]),\n",
    "            avg([i['self_preference'] for i in cnn_wrong_label_results[model]])\n",
    "] for model in main_models]\n",
    "\n",
    "print('\\n\\n')\n",
    "print('Incorrect Labels')\n",
    "print(tabulate(table, headers=['Model', 'Self-Preference (XSUM)', 'Self-Preference (CNN)']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Individual Setting\n",
    "\n",
    "table = [[model, \n",
    "                        avg([result['ratio'] for result in xsum_detection_results[model] if result['target_model'] == 'gpt4' and result['ratio'] == result['ratio']]),\n",
    "                        avg([result['ratio'] for result in xsum_score_results[model] if result['target_model'] == 'gpt4' and result['ratio'] == result['ratio']]),\n",
    "                        avg([result['ratio'] for result in xsum_detection_results[model] if result['target_model'] == 'gpt35' and result['ratio'] == result['ratio']]),\n",
    "                        avg([result['ratio'] for result in xsum_score_results[model] if result['target_model'] == 'gpt35' and result['ratio'] == result['ratio']]),\n",
    "                        avg([result['ratio'] for result in xsum_detection_results[model] if result['target_model'] == 'llama' and result['ratio'] == result['ratio']]),                        \n",
    "                        avg([result['ratio'] for result in xsum_score_results[model] if result['target_model'] == 'llama' and result['ratio'] == result['ratio']]),\n",
    "                        avg([result['ratio'] for result in xsum_detection_results[model] if result['target_model'] == 'human' and result['ratio'] == result['ratio']]),\n",
    "                        avg([result['ratio'] for result in xsum_score_results[model] if result['target_model'] == 'human' and result['ratio'] == result['ratio']]),\n",
    "                        avg([result['ratio'] for result in xsum_detection_results[model] if result['target_model'] == 'claude' and result['ratio'] == result['ratio']]),\n",
    "                        avg([result['ratio'] for result in xsum_score_results[model] if result['target_model'] == 'claude' and result['ratio'] == result['ratio']]),\n",
    "\n",
    "                        avg([result['ratio'] for result in cnn_detection_results[model] if result['target_model'] == 'gpt4' and result['ratio'] == result['ratio']]),\n",
    "                        avg([result['ratio'] for result in cnn_score_results[model] if result['target_model'] == 'gpt4' and result['ratio'] == result['ratio']]),\n",
    "                        avg([result['ratio'] for result in cnn_detection_results[model] if result['target_model'] == 'gpt35' and result['ratio'] == result['ratio']]),\n",
    "                        avg([result['ratio'] for result in cnn_score_results[model] if result['target_model'] == 'gpt35' and result['ratio'] == result['ratio']]),\n",
    "                        avg([result['ratio'] for result in cnn_detection_results[model] if result['target_model'] == 'llama' and result['ratio'] == result['ratio']]),\n",
    "                        avg([result['ratio'] for result in cnn_score_results[model] if result['target_model'] == 'llama' and result['ratio'] == result['ratio']]),\n",
    "                        avg([result['ratio'] for result in cnn_detection_results[model] if result['target_model'] == 'human' and result['ratio'] == result['ratio']]),\n",
    "                        avg([result['ratio'] for result in cnn_score_results[model] if result['target_model'] == 'human' and result['ratio'] == result['ratio']]),\n",
    "                        avg([result['ratio'] for result in cnn_detection_results[model] if result['target_model'] == 'claude' and result['ratio'] == result['ratio']]),\n",
    "                        avg([result['ratio'] for result in cnn_score_results[model] if result['target_model'] == 'claude' and result['ratio'] == result['ratio']]),\n",
    "] for model in models if model != 'cnn_length_ft_llama']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model                        Self-Rec    Self-Pref    Self-Rec    Self-Pref    Self-Rec    Self-Pref    Self-Rec    Self-Pref    Self-Rec    Self-Pref\n",
      "-------------------------  ----------  -----------  ----------  -----------  ----------  -----------  ----------  -----------  ----------  -----------\n",
      "gpt4                            0.5          0.5         0.602        0.516       0.619        0.52        0.715        0.536       0.634        0.518\n",
      "gpt35                           0.493        0.492       0.5          0.5         0.502        0.502       0.518        0.516       0.498        0.499\n",
      "llama                           0.501        0.5         0.495        0.501       0.5          0.5         0.495        0.502       0.503        0.501\n",
      "xsum_2_ft_gpt35                 0.491        0.492       0.5          0.5         0.501        0.503       0.53         0.52        0.503        0.502\n",
      "xsum_10_ft_gpt35                0.492        0.494       0.5          0.5         0.503        0.502       0.54         0.518       0.507        0.502\n",
      "xsum_500_ft_gpt35               0.495        0.536       0.5          0.5         0.506        0.537       0.671        0.602       0.607        0.578\n",
      "xsum_always_1_ft_gpt35          0.49         0.499       0.5          0.5         0.493        0.501       0.495        0.501       0.495        0.5\n",
      "xsum_random_ft_gpt35            0.488        0.499       0.5          0.5         0.492        0.501       0.492        0.501       0.494        0.5\n",
      "xsum_readability_ft_gpt35       0.507        0.496       0.5          0.5         0.53         0.53        0.568        0.577       0.531        0.524\n",
      "xsum_length_ft_gpt35            0.502        0.489       0.5          0.5         0.507        0.5         0.541        0.52        0.511        0.503\n",
      "xsum_vowelcount_ft_gpt35        0.5          0.49        0.5          0.5         0.5          0.501       0.508        0.518       0.501        0.503\n",
      "cnn_2_ft_gpt35                  0.484        0.494       0.5          0.5         0.49         0.503       0.516        0.521       0.494        0.503\n",
      "cnn_10_ft_gpt35                 0.49         0.495       0.5          0.5         0.495        0.505       0.525        0.525       0.498        0.504\n",
      "cnn_500_ft_gpt35                0.721        0.494       0.5          0.5         0.723        0.512       0.888        0.625       0.806        0.538\n",
      "cnn_always_1_ft_gpt35           0.497        0.499       0.5          0.5         0.5          0.5         0.501        0.505       0.502        0.5\n",
      "cnn_random_ft_gpt35             0.498        0.494       0.5          0.5         0.501        0.499       0.501        0.505       0.5          0.499\n",
      "cnn_readability_ft_gpt35        0.489        0.467       0.5          0.5         0.507        0.5         0.543        0.579       0.508        0.499\n",
      "cnn_length_ft_gpt35             0.505        0.481       0.5          0.5         0.519        0.489       0.544        0.514       0.517        0.494\n",
      "cnn_vowelcount_ft_gpt35         0.497        0.496       0.5          0.5         0.499        0.497       0.544        0.514       0.508        0.5\n",
      "xsum_2_ft_llama                 0.504        0.5         0.494        0.501       0.5          0.5         0.492        0.502       0.505        0.501\n",
      "xsum_10_ft_llama                0.505        0.5         0.497        0.501       0.5          0.5         0.501        0.501       0.51         0.501\n",
      "xsum_500_ft_llama               0.503        0.496       0.484        0.501       0.5          0.5         0.463        0.508       0.491        0.498\n",
      "xsum_always_1_ft_llama          0.5          0.5         0.5          0.487       0.5          0.5         0.5          0.516       0.5          0.479\n",
      "xsum_random_ft_llama            0.501        0.5         0.498        0.5         0.5          0.5         0.498        0.503       0.502        0.5\n",
      "xsum_readability_ft_llama       0.498        0.5         0.499        0.5         0.5          0.5         0.496        0.502       0.502        0.5\n",
      "xsum_length_ft_llama            0.5          0.5         0.474        0.5         0.5          0.5         0.467        0.501       0.488        0.5\n",
      "xsum_vowelcount_ft_llama        0.509        0.499       0.48         0.5         0.5          0.5         0.481        0.501       0.497        0.5\n",
      "cnn_2_ft_llama                  0.5          0.5         0.497        0.5         0.5          0.5         0.499        0.502       0.501        0.501\n",
      "cnn_10_ft_llama                 0.502        0.5         0.498        0.5         0.5          0.5         0.5          0.502       0.506        0.5\n",
      "cnn_500_ft_llama                0.508        0.498       0.501        0.499       0.5          0.5         0.499        0.498       0.502        0.499\n",
      "cnn_always_1_ft_llama           0.5          0.5         0.5          0.5         0.5          0.5         0.5          0.5         0.5          0.5\n",
      "cnn_random_ft_llama             0.501        0.5         0.5          0.5         0.5          0.5         0.5          0.5         0.501        0.5\n",
      "cnn_readability_ft_llama        0.511        0.501       0.508        0.499       0.5          0.5         0.518        0.498       0.504        0.499\n",
      "cnn_vowelcount_ft_llama         0.5          0.501       0.503        0.501       0.5          0.5         0.502        0.501       0.505        0.502\n"
     ]
    }
   ],
   "source": [
    "# table = [[row[0]] + [round(i, 2) for i in row[1:]] for row in table]\n",
    "rounded_table = [[row[0]] + ['{:.3f}'.format(i) for i in row[1:]] for row in table]\n",
    "\n",
    "print(tabulate([row[0:1] + row[11:] for row in rounded_table], headers = ['Model', 'Self-Rec', 'Self-Pref', 'Self-Rec', 'Self-Pref', 'Self-Rec', 'Self-Pref', 'Self-Rec', 'Self-Pref', 'Self-Rec', 'Self-Pref', ]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model                        Ambiguous    Correct    Incorrect    Ambiguous    Correct    Incorrect\n",
      "-------------------------  -----------  ---------  -----------  -----------  ---------  -----------\n",
      "gpt4                             0.383      0.595        0.022        0.088      0.877        0.034\n",
      "gpt35                            0.62       0.149        0.23         0.517      0.151        0.332\n",
      "llama                            1          0            0            1          0            0.001\n",
      "xsum_2_ft_gpt35                  0.815      0.046        0.139        0.442      0.15         0.409\n",
      "xsum_10_ft_gpt35                 0.805      0.086        0.109        0.479      0.181        0.34\n",
      "xsum_500_ft_gpt35                0.194      0.651        0.155        0.193      0.654        0.153\n",
      "xsum_always_1_ft_gpt35           1          0            0            1          0            0\n",
      "xsum_random_ft_gpt35             1          0            0            1          0            0\n",
      "xsum_readability_ft_gpt35        0.286      0.383        0.332        0.28       0.412        0.308\n",
      "xsum_length_ft_gpt35             0.79       0.082        0.128        0.597      0.128        0.275\n",
      "xsum_vowelcount_ft_gpt35         0.601      0.117        0.282        0.17       0.239        0.591\n",
      "cnn_2_ft_gpt35                   0.665      0.167        0.169        0.454      0.188        0.358\n",
      "cnn_10_ft_gpt35                  0.55       0.311        0.139        0.34       0.317        0.343\n",
      "cnn_500_ft_gpt35                 0.054      0.932        0.013        0.031      0.955        0.014\n",
      "cnn_always_1_ft_gpt35            1          0            0            1          0            0\n",
      "cnn_random_ft_gpt35              1          0            0            1          0            0\n",
      "cnn_readability_ft_gpt35         0.171      0.629        0.2          0.147      0.61         0.243\n",
      "cnn_length_ft_gpt35              0.152      0.093        0.754        0.125      0.124        0.75\n",
      "cnn_vowelcount_ft_gpt35          0.143      0.104        0.752        0.07       0.137        0.793\n",
      "xsum_2_ft_llama                  0.952      0.033        0.015        0.997      0.001        0.002\n",
      "xsum_10_ft_llama                 0.881      0.083        0.037        0.976      0.018        0.006\n",
      "xsum_500_ft_llama                0.922      0.061        0.017        0.892      0.086        0.021\n",
      "xsum_always_1_ft_llama           1          0            0            1          0            0\n",
      "xsum_random_ft_llama             0.957      0.025        0.018        0.998      0.002        0.001\n",
      "xsum_readability_ft_llama        0.978      0.011        0.011        1          0.001        0\n",
      "xsum_length_ft_llama             0.523      0.355        0.122        0.957      0.035        0.009\n",
      "xsum_vowelcount_ft_llama         0.914      0.065        0.021        0.981      0.016        0.003\n",
      "cnn_2_ft_llama                   0.833      0.113        0.055        0.868      0.092        0.041\n",
      "cnn_10_ft_llama                  0.89       0.077        0.033        0.988      0.009        0.003\n",
      "cnn_500_ft_llama                 0.926      0.035        0.039        0.923      0.04         0.037\n",
      "cnn_always_1_ft_llama            0.976      0.013        0.011        0.973      0.018        0.009\n",
      "cnn_random_ft_llama              0.982      0.009        0.01         0.984      0.007        0.009\n",
      "cnn_readability_ft_llama         0.765      0.103        0.131        0.779      0.102        0.119\n",
      "cnn_length_ft_llama              0.536      0.351        0.113        0.696      0.232        0.073\n",
      "cnn_vowelcount_ft_llama          0.942      0.037        0.021        0.938      0.037        0.025\n"
     ]
    }
   ],
   "source": [
    "def print_ambig_table(results):\n",
    "    task = 'detection' # 'comparison\n",
    "    task2 = 'comparison'\n",
    "    table = [[model, \n",
    "                            avg([result[f'forward_{task}'] == result[f'backward_{task}'] for result in results[model]]), \n",
    "                            avg([result[f'forward_{task}'] == '1' and result[f'backward_{task}'] == '2' for result in results[model]]),\n",
    "                            avg([result[f'forward_{task}'] == '2' and result[f'backward_{task}'] == '1' for result in results[model]]),\n",
    "                            \n",
    "                            avg([result[f'forward_{task2}'] == result[f'backward_{task2}'] for result in results[model]]), \n",
    "                            avg([result[f'forward_{task2}'] == '1' and result[f'backward_{task2}'] == '2' for result in results[model]]),\n",
    "                            avg([result[f'forward_{task2}'] == '2' and result[f'backward_{task2}'] == '1' for result in results[model]]),\n",
    "    ] for model in models]\n",
    "\n",
    "    # table = [row + [row[2] / (row[2] + row[3]) if any(i != 0 for i in [row[2], row[3]]) else 0] for row in table]\n",
    "    # table = [row + [row[5] / (row[5] + row[6]) if any(i != 0 for i in [row[2], row[3]]) else 0] for row in table]\n",
    "\n",
    "    table = [row[0:1] + [round(i, 3) for i in row[1:]] for row in table]\n",
    "    print(tabulate(table, headers = ['Model', 'Ambiguous', 'Correct', 'Incorrect', 'Ambiguous', 'Correct', 'Incorrect']))\n",
    "    # print(tabulate(sorted())\n",
    "\n",
    "print_ambig_table(cnn_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model                        Self-Rec (XSUM)    Zapped    Self-Pref (XSUM)    Zapped    Self-Rec (CNN)    Zapped    Self-Pref \"(CNN)    Zapped\n",
      "-------------------------  -----------------  --------  ------------------  --------  ----------------  --------  ------------------  --------\n",
      "cnn_length_ft_gpt35                 0.574447  0.588               0.571714  0.571             0.169116  0.1355              0.187859  0.1565\n",
      "cnn_vowelcount_ft_gpt35             0.607817  0.615               0.585983  0.5945            0.175772  0.1645              0.171075  0.1675\n",
      "xsum_vowelcount_ft_gpt35            0.60012   0.5945              0.597612  0.5945            0.416476  0.3065              0.326374  0.3135\n",
      "xsum_2_ft_gpt35                     0.630637  0.649242            0.618081  0.626263          0.452511  0.351954            0.376425  0.338677\n",
      "gpt35                               0.534677  0.57725             0.581617  0.604             0.480843  0.43125             0.431062  0.358\n",
      "xsum_10_ft_gpt35                    0.674499  0.689899            0.656532  0.658838          0.488748  0.459091            0.421113  0.382828\n",
      "cnn_2_ft_gpt35                      0.619661  0.65506             0.586618  0.594689          0.496871  0.456663            0.42312   0.38477\n",
      "xsum_length_ft_gpt35                0.572245  0.5835              0.566803  0.5745            0.474401  0.2995              0.427378  0.4175\n",
      "cnn_readability_ft_llama            0.501337  0.4345              0.463783  0.2895            0.495163  0.48525             0.48851   0.4445\n",
      "cnn_10_ft_gpt35                     0.648517  0.690152            0.627085  0.644697          0.586736  0.615909            0.487369  0.464141\n",
      "cnn_500_ft_llama                    0.55599   0.682               0.434032  0.283             0.59203   0.746               0.499934  0.50225\n",
      "xsum_always_1_ft_llama              0.499844  0.322632            0.499799  0.455263          0.5       0.52                0.5       0.504211\n",
      "xsum_readability_ft_gpt35           0.404802  0.373               0.398861  0.3715            0.505353  0.526               0.530689  0.562\n",
      "cnn_random_ft_gpt35                 0.499629  0.4155              0.499699  0.448             0.500416  0.572               0.500584  0.5925\n",
      "llama                               0.513524  0.5275              0.510752  0.5485            0.505046  0.6375              0.50492   0.64175\n",
      "cnn_length_ft_llama                 0.489451  0.432               0.486789  0.3995            0.548378  0.6765              0.54111   0.666\n",
      "cnn_readability_ft_gpt35            0.450028  0.403               0.415862  0.3835            0.616615  0.727               0.628716  0.701\n",
      "xsum_500_ft_gpt35                   0.895799  0.913               0.89785   0.914             0.73849   0.75375             0.749659  0.7645\n",
      "xsum_always_1_ft_gpt35              0.499999  0.598               0.5       0.8565            0.5       0.5635              0.5       0.8945\n",
      "xsum_random_ft_gpt35                0.499999  0.598               0.5       0.86              0.5       0.559               0.5       0.898\n",
      "cnn_random_ft_llama                 0.673236  0.9335              0.675831  0.9555            0.637675  0.867               0.653747  0.902\n",
      "cnn_vowelcount_ft_llama             0.580018  0.83                0.580531  0.811             0.571426  0.90475             0.580851  0.908\n",
      "cnn_2_ft_llama                      0.357445  0.218938            0.502014  0.506012          0.566808  0.649549            0.702605  0.926603\n",
      "cnn_always_1_ft_llama               0.5       0.5                 0.5       0.5               0.949     0.949               0.933     0.93325\n",
      "gpt4                                0.671656  0.70575             0.705006  0.71875           0.74652   0.88675             0.912338  0.939\n",
      "cnn_500_ft_gpt35                    0.763612  0.781               0.786999  0.792             0.959415  0.977               0.970093  0.973\n",
      "xsum_500_ft_llama                   0.454197  0.323               0.485196  0.4695            0.792802  0.9855              0.788255  0.9735\n",
      "xsum_vowelcount_ft_llama            0.480679  0.37625             0.576182  0.7355            0.780833  0.971               0.903391  0.9965\n",
      "xsum_10_ft_llama                    0.525555  0.569697            0.665453  0.851515          0.680572  0.920202            0.80993   0.997475\n",
      "cnn_always_1_ft_gpt35               0.5       1                   0.5       0.9985            0.5       1                   0.5       0.9975\n",
      "cnn_10_ft_llama                     0.519166  0.520455            0.655903  0.794192          0.66481   0.876768            0.825356  0.998485\n",
      "xsum_length_ft_llama                0.342242  0.0185              0.482781  0.39475           0.535485  0.6105              0.803763  0.9985\n",
      "xsum_2_ft_llama                     0.591835  0.699775            0.742704  0.927605          0.798989  0.975701            0.904906  0.999749\n",
      "xsum_random_ft_llama                0.543247  0.7465              0.647942  0.9415            0.618494  0.9685              0.753334  1\n",
      "xsum_readability_ft_llama           0.558076  0.7605              0.708908  0.9495            0.675009  0.9765              0.793965  1\n"
     ]
    }
   ],
   "source": [
    "def print_zapped_table():\n",
    "    table = [[model, \n",
    "                        avg([result['detection_score'] for result in xsum_results[model]]),\n",
    "                        avg([0.5 if result['detection_score'] == 0.5 else 1 if result['detection_score'] > 0.5 else 0 for result in xsum_results[model]]),\n",
    "\n",
    "                        avg([result['self_preference'] for result in xsum_results[model]]),\n",
    "                        avg([0.5 if result['self_preference'] == 0.5 else 1 if result['self_preference'] > 0.5 else 0 for result in xsum_results[model]]),\n",
    "\n",
    "                        avg([result['detection_score'] for result in cnn_results[model]]),\n",
    "                        avg([0.5 if result['detection_score'] == 0.5 else 1 if result['detection_score'] > 0.5 else 0 for result in cnn_results[model]]),\n",
    "\n",
    "                        avg([result['self_preference'] for result in cnn_results[model]]),\n",
    "                        avg([0.5 if result['self_preference'] == 0.5 else 1 if result['self_preference'] > 0.5 else 0 for result in cnn_results[model]]),\n",
    "\n",
    "    ] for model in models]\n",
    "\n",
    "    print(tabulate(sorted(table, key = lambda x: x[-1]), headers = ['Model', 'Self-Rec (XSUM)', 'Zapped', 'Self-Pref (XSUM)', 'Zapped', 'Self-Rec (CNN)', 'Zapped', 'Self-Pref (CNN)', 'Zapped']))\n",
    "    # print(tabulate(sorted())\n",
    "\n",
    "print_zapped_table()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model                        Self-Rec    Self-Pref    Self-Rec    Self-Pref\n",
      "-------------------------  ----------  -----------  ----------  -----------\n",
      "gpt4                            0.672        0.705       0.747        0.912\n",
      "gpt35                           0.535        0.582       0.481        0.431\n",
      "llama                           0.514        0.511       0.505        0.505\n",
      "xsum_2_ft_gpt35                 0.631        0.618       0.453        0.376\n",
      "xsum_10_ft_gpt35                0.674        0.657       0.489        0.421\n",
      "xsum_500_ft_gpt35               0.896        0.898       0.738        0.75\n",
      "xsum_always_1_ft_gpt35          0.5          0.5         0.5          0.5\n",
      "xsum_random_ft_gpt35            0.5          0.5         0.5          0.5\n",
      "xsum_readability_ft_gpt35       0.405        0.399       0.505        0.531\n",
      "xsum_length_ft_gpt35            0.572        0.567       0.474        0.427\n",
      "xsum_vowelcount_ft_gpt35        0.6          0.598       0.416        0.326\n",
      "cnn_2_ft_gpt35                  0.62         0.587       0.497        0.423\n",
      "cnn_10_ft_gpt35                 0.649        0.627       0.587        0.487\n",
      "cnn_500_ft_gpt35                0.764        0.787       0.959        0.97\n",
      "cnn_always_1_ft_gpt35           0.5          0.5         0.5          0.5\n",
      "cnn_random_ft_gpt35             0.5          0.5         0.5          0.501\n",
      "cnn_readability_ft_gpt35        0.45         0.416       0.617        0.629\n",
      "cnn_length_ft_gpt35             0.574        0.572       0.169        0.188\n",
      "cnn_vowelcount_ft_gpt35         0.608        0.586       0.176        0.171\n",
      "xsum_2_ft_llama                 0.592        0.743       0.799        0.905\n",
      "xsum_10_ft_llama                0.526        0.665       0.681        0.81\n",
      "xsum_500_ft_llama               0.454        0.485       0.793        0.788\n",
      "xsum_always_1_ft_llama          0.5          0.5         0.5          0.5\n",
      "xsum_random_ft_llama            0.543        0.648       0.618        0.753\n",
      "xsum_readability_ft_llama       0.558        0.709       0.675        0.794\n",
      "xsum_length_ft_llama            0.342        0.483       0.535        0.804\n",
      "xsum_vowelcount_ft_llama        0.481        0.576       0.781        0.903\n",
      "cnn_2_ft_llama                  0.357        0.502       0.567        0.703\n",
      "cnn_10_ft_llama                 0.519        0.656       0.665        0.825\n",
      "cnn_500_ft_llama                0.556        0.434       0.592        0.5\n",
      "cnn_always_1_ft_llama           0.5          0.5         0.949        0.933\n",
      "cnn_random_ft_llama             0.673        0.676       0.638        0.654\n",
      "cnn_readability_ft_llama        0.501        0.464       0.495        0.489\n",
      "cnn_length_ft_llama             0.489        0.487       0.548        0.541\n",
      "cnn_vowelcount_ft_llama         0.58         0.581       0.571        0.581\n"
     ]
    }
   ],
   "source": [
    "# Main pairwise results\n",
    "table = [[model, \n",
    "                    avg([result['detection_score'] for result in xsum_results[model]]),\n",
    "                    avg([result['self_preference'] for result in xsum_results[model]]),\n",
    "\n",
    "                    avg([result['detection_score'] for result in cnn_results[model]]),\n",
    "                    avg([result['self_preference'] for result in cnn_results[model]]),\n",
    "] for model in models]\n",
    "\n",
    "table = [[row[0]] + [round(i, 3) for i in row[1:]] for row in table]\n",
    "print(tabulate(table, headers = ['Model', 'Self-Rec', 'Self-Pref', 'Self-Rec', 'Self-Pref']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-------------------------  --------  --------  ---------  --------  --------  --------  --------  --------  ----------  --------  ---------  ---------  --------  --------  ----------  ----------  --------  ----------  --------  --------  --------  --------\n",
      "gpt4                       0.31075   0.5385    0.15075    0.671656  0.22775   0.5925    0.17975   0.705006    0.707988  0.383     0.59475    0.02225    0.74652   0.0885    0.8775      0.034       0.912338   0.543711   0.781284  0.767239  0.963938  0.962699\n",
      "gpt35                      0.582     0.2685    0.1495     0.534677  0.57775   0.302     0.12025   0.581617    0.405312  0.62025   0.1495     0.23025    0.480843  0.517     0.1515      0.3315      0.431062   0.36235    0.642344  0.715216  0.39368   0.313665\n",
      "llama                      0.83225   0.08675   0.081      0.513524  0.75475   0.13      0.11525   0.510752    0.726415  0.99975   0.00025    0          0.505046  0.9995    0           0.0005      0.50492    0.498769   0.517139  0.530071  1         0\n",
      "xsum_2_ft_gpt35            0.399242  0.432828  0.167929   0.630637  0.293687  0.47298   0.233333  0.618081    0.694621  0.81488   0.0463427  0.138778   0.452511  0.441884  0.149549    0.408567    0.376425   0.642794   0.720471  0.669646  0.250338  0.267953\n",
      "xsum_10_ft_gpt35           0.377273  0.486869  0.135859   0.674499  0.293687  0.510354  0.19596   0.656532    0.662852  0.804798  0.0863636  0.108838   0.488748  0.479293  0.180808    0.339899    0.421113   0.587592   0.781833  0.72256   0.442432  0.347236\n",
      "xsum_500_ft_gpt35          0.0955    0.848     0.0565     0.895799  0.094     0.851     0.055     0.89785     0.743466  0.1935    0.651      0.1555     0.73849   0.1925    0.6545      0.153       0.749659   0.816604   0.937535  0.939294  0.807192  0.810526\n",
      "xsum_always_1_ft_gpt35     1         0         0          0.499999  1         0         0         0.5         0.147688  1         0          0          0.5       1         0           0           0.5        0.0959225  0         0         0         0\n",
      "xsum_random_ft_gpt35       1         0         0          0.499999  1         0         0         0.5         0.150951  1         0          0          0.5       1         0           0           0.5        0.0977254  0         0         0         0\n",
      "xsum_readability_ft_gpt35  0.373     0.2015    0.4255     0.404802  0.3135    0.2365    0.45      0.398861    0.807903  0.286     0.3825     0.3315     0.505353  0.28      0.412       0.308       0.530689   0.81995    0.321372  0.344501  0.535714  0.572222\n",
      "xsum_length_ft_gpt35       0.6035    0.2695    0.127      0.572245  0.163     0.487     0.35      0.566803    0.604136  0.7905    0.0815     0.128      0.474401  0.597     0.128       0.275       0.427378   0.321082   0.679697  0.58184   0.389021  0.317618\n",
      "xsum_vowelcount_ft_gpt35   0.175     0.5115    0.3135     0.60012   0.0615    0.5655    0.373     0.597612    0.820122  0.6005    0.1175     0.282      0.416476  0.17      0.239       0.591       0.326374   0.664171   0.62      0.602557  0.294118  0.287952\n",
      "cnn_2_ft_gpt35             0.519038  0.362475  0.118487   0.619661  0.444138  0.371994  0.151553  0.586618    0.438782  0.664579  0.166583   0.168838   0.496871  0.453908  0.188126    0.357966    0.42312    0.353979   0.753646  0.710526  0.496639  0.344495\n",
      "cnn_10_ft_gpt35            0.476515  0.411616  0.111869   0.648517  0.416919  0.419697  0.163384  0.627085    0.569248  0.55      0.311111   0.138889   0.586736  0.339899  0.317172    0.342929    0.487369   0.478166   0.7863    0.719792  0.691358  0.48049\n",
      "cnn_500_ft_gpt35           0.1925    0.667     0.1405     0.763612  0.2215    0.676     0.1025    0.786999    0.648386  0.0545    0.932      0.0135     0.959415  0.031     0.955       0.014       0.970093   0.415257   0.826006  0.868337  0.985722  0.985552\n",
      "cnn_always_1_ft_gpt35      1         0         0          0.5       1         0         0         0.5         0.226509  1         0          0          0.5       1         0           0           0.5        0.232014   0         0         0         0\n",
      "cnn_random_ft_gpt35        1         0         0          0.499629  1         0         0         0.499699    0.111571  1         0          0          0.500416  1         0           0           0.500584   0.280479   0         0         0         0\n",
      "cnn_readability_ft_gpt35   0.621     0.0885    0.2905     0.450028  0.3125    0.2235    0.464     0.415862    0.792907  0.1705    0.6295     0.2        0.616615  0.147     0.61        0.243       0.628716   0.832363   0.233509  0.325091  0.758891  0.715123\n",
      "cnn_length_ft_gpt35        0.2235    0.463     0.3135     0.574447  0.264     0.4385    0.2975    0.571714    0.761458  0.1525    0.093      0.7545     0.169116  0.125     0.1245      0.7505      0.187859   0.630281   0.596265  0.595788  0.109735  0.142286\n",
      "cnn_vowelcount_ft_gpt35    0.159     0.527     0.314      0.607817  0.1685    0.5005    0.331     0.585983    0.797703  0.1435    0.1045     0.752      0.175772  0.07      0.1365      0.7935      0.171075   0.686878   0.626635  0.601924  0.122008  0.146774\n",
      "xsum_2_ft_llama            0.624248  0.219689  0.156062   0.591835  0.713427  0.161824  0.124749  0.742704    0.603905  0.952154  0.0328156  0.0150301  0.798989  0.997244  0.001002    0.00175351  0.904906   0.553306   0.584667  0.564685  0.685864  0.363636\n",
      "xsum_10_ft_llama           0.537626  0.294949  0.167424   0.525555  0.602525  0.238636  0.158838  0.665453    0.576345  0.880556  0.0828283  0.0366162  0.680572  0.975758  0.0184343   0.00580808  0.80993    0.518152   0.637903  0.600381  0.693446  0.760417\n",
      "xsum_500_ft_llama          0.262     0.654     0.084      0.454197  0.302     0.5925    0.1055    0.485196    0.579361  0.922     0.061      0.017      0.792802  0.892     0.0865      0.0215      0.788255   0.510419   0.886179  0.848854  0.782051  0.800926\n",
      "xsum_always_1_ft_llama     1         0         0          0.499844  1         0         0         0.499799    0.328545  1         0          0          0.5       1         0           0           0.5        0.214696   0         0         0         0\n",
      "xsum_random_ft_llama       0.7445    0.141     0.1145     0.543247  0.776     0.1195    0.1045    0.647942    0.587104  0.9565    0.0255     0.018      0.618494  0.9975    0.002       0.0005      0.753334   0.526178   0.551859  0.533482  0.586207  0.8\n",
      "xsum_readability_ft_llama  0.823     0.0855    0.0915     0.558076  0.897     0.041     0.062     0.708908    0.559532  0.9775    0.011      0.0115     0.675009  0.9995    0.0005      0           0.793965   0.438187   0.483051  0.398058  0.488889  1\n",
      "xsum_length_ft_llama       0.304     0.2865    0.4095     0.342242  0.117     0.3875    0.4955    0.482781    0.380532  0.523     0.355      0.122      0.535485  0.9565    0.035       0.0085      0.803763   0.401326   0.411638  0.438845  0.744235  0.804598\n",
      "xsum_vowelcount_ft_llama   0.225     0.318     0.457      0.480679  0.263     0.2945    0.4425    0.576182    0.581408  0.9145    0.065      0.0205     0.780833  0.981     0.016       0.003       0.903391   0.556139   0.410323  0.399593  0.760234  0.842105\n",
      "cnn_2_ft_llama             0.789078  0.135271  0.0756513  0.357445  0.597445  0.231463  0.171092  0.502014    0.671583  0.832665  0.112725   0.0546092  0.566808  0.867735  0.0916834   0.0405812   0.702605   0.544096   0.64133   0.574984  0.673653  0.693182\n",
      "cnn_10_ft_llama            0.67702   0.199747  0.123232   0.519166  0.657828  0.188384  0.153788  0.655903    0.678791  0.890152  0.0770202  0.0328283  0.66481   0.987879  0.00934343  0.00277778  0.825356   0.609962   0.618452  0.550554  0.701149  0.770833\n",
      "cnn_500_ft_llama           0.9245    0.0355    0.04       0.55599   0.933     0.0295    0.0375    0.434032    0.301602  0.926     0.035      0.039      0.59203   0.9235    0.0395      0.037       0.499934  -0.0340433  0.470199  0.440299  0.472973  0.51634\n",
      "cnn_always_1_ft_llama      0.9885    0.008     0.0035     0.5       0.9845    0.009     0.0065    0.5       nan         0.9755    0.013      0.0115     0.949     0.973     0.0185      0.0085      0.933      0.565979   0.695652  0.580645  0.530612  0.685185\n",
      "cnn_random_ft_llama        0.9945    0.003     0.0025     0.673236  0.9955    0.003     0.0015    0.675831    0.699919  0.9815    0.0085     0.01       0.637675  0.9835    0.007       0.0095      0.653747   0.760943   0.545455  0.666667  0.459459  0.424242\n",
      "cnn_readability_ft_llama   0.844     0.0745    0.0815     0.501337  0.847     0.0765    0.0765    0.463783    0.512918  0.7655    0.1035     0.131      0.495163  0.779     0.102       0.119       0.48851    0.701186   0.477564  0.5       0.441365  0.461538\n",
      "cnn_length_ft_llama        0.794     0.0685    0.1375     0.489451  0.8195    0.057     0.1235    0.486789    0.29198   0.536     0.351      0.113      0.548378  0.6955    0.2315      0.073       0.54111    0.562209   0.332524  0.315789  0.756466  0.760263\n",
      "cnn_vowelcount_ft_llama    0.957     0.0215    0.0215     0.580018  0.948     0.0245    0.0275    0.580531    0.554505  0.942     0.037      0.021      0.571426  0.9375    0.037       0.0255      0.580851   0.497272   0.5       0.471154  0.637931  0.592\n",
      "-------------------------  --------  --------  ---------  --------  --------  --------  --------  --------  ----------  --------  ---------  ---------  --------  --------  ----------  ----------  --------  ----------  --------  --------  --------  --------\n"
     ]
    }
   ],
   "source": [
    "task = 'detection'\n",
    "task2 = 'comparison'\n",
    "table = []\n",
    "table += [[model, \n",
    "                    avg([result[f'forward_{task}'] == result[f'backward_{task}'] for result in xsum_results[model]]), \n",
    "                    avg([result[f'forward_{task}'] == '1' and result[f'backward_{task}'] == '2' for result in xsum_results[model]]),\n",
    "                    avg([result[f'forward_{task}'] == '2' and result[f'backward_{task}'] == '1' for result in xsum_results[model]]),\n",
    "                    avg([result['detection_score'] for result in xsum_results[model]]),\n",
    "                    \n",
    "                    avg([result[f'forward_{task2}'] == result[f'backward_{task2}'] for result in xsum_results[model]]), \n",
    "                    avg([result[f'forward_{task2}'] == '1' and result[f'backward_{task2}'] == '2' for result in xsum_results[model]]),\n",
    "                    avg([result[f'forward_{task2}'] == '2' and result[f'backward_{task2}'] == '1' for result in xsum_results[model]]),\n",
    "                    avg([result['self_preference'] for result in xsum_results[model]]),\n",
    "\n",
    "                    kendall_tau_for_results(xsum_results[model]),\n",
    "\n",
    "                    avg([result[f'forward_{task}'] == result[f'backward_{task}'] for result in cnn_results[model]]), \n",
    "                    avg([result[f'forward_{task}'] == '1' and result[f'backward_{task}'] == '2' for result in cnn_results[model]]),\n",
    "                    avg([result[f'forward_{task}'] == '2' and result[f'backward_{task}'] == '1' for result in cnn_results[model]]),\n",
    "                    avg([result['detection_score'] for result in cnn_results[model]]),\n",
    "                    \n",
    "                    avg([result[f'forward_{task2}'] == result[f'backward_{task2}'] for result in cnn_results[model]]), \n",
    "                    avg([result[f'forward_{task2}'] == '1' and result[f'backward_{task2}'] == '2' for result in cnn_results[model]]),\n",
    "                    avg([result[f'forward_{task2}'] == '2' and result[f'backward_{task2}'] == '1' for result in cnn_results[model]]),\n",
    "                    avg([result['self_preference'] for result in cnn_results[model]]),\n",
    "\n",
    "                    kendall_tau_for_results(cnn_results[model])\n",
    "] for model in models]\n",
    "\n",
    "table = [row + [row[2] / (row[2] + row[3]) if any(i != 0 for i in [row[2], row[3]]) else 0] for row in table] # Recognition Score (Ambig Removed) XSUM\n",
    "table = [row + [row[6] / (row[6] + row[7]) if any(i != 0 for i in [row[6], row[7]]) else 0] for row in table] # Preference Score (Ambig Removed) XSUM\n",
    "\n",
    "table = [row + [row[11] / (row[11] + row[12]) if any(i != 0 for i in [row[11], row[12]]) else 0] for row in table] # Recognition Score (Ambig Removed) CNN\n",
    "table = [row + [row[15] / (row[15] + row[16]) if any(i != 0 for i in [row[15], row[16]]) else 0] for row in table] # Preference Score (Ambig Removed) CNN\n",
    "\n",
    "print(tabulate(table))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 331,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model                         Recog      Pref         Tau\n",
      "-------------------------  --------  --------  ----------\n",
      "cnn_length_ft_gpt35        0.169116  0.187859   0.630281\n",
      "cnn_vowelcount_ft_gpt35    0.175772  0.171075   0.686878\n",
      "xsum_vowelcount_ft_gpt35   0.416476  0.326374   0.664171\n",
      "xsum_2_ft_gpt35            0.452511  0.376425   0.642794\n",
      "xsum_length_ft_gpt35       0.474401  0.427378   0.321082\n",
      "gpt35                      0.480843  0.431062   0.36235\n",
      "xsum_10_ft_gpt35           0.488748  0.421113   0.587592\n",
      "cnn_readability_ft_llama   0.495163  0.48851    0.701186\n",
      "cnn_2_ft_gpt35             0.496871  0.42312    0.353979\n",
      "xsum_always_1_ft_gpt35     0.5       0.5        0.0959225\n",
      "xsum_random_ft_gpt35       0.5       0.5        0.0977254\n",
      "xsum_always_1_ft_llama     0.5       0.5        0.214696\n",
      "cnn_always_1_ft_gpt35      0.5       0.5        0.232014\n",
      "cnn_random_ft_gpt35        0.500416  0.500584   0.280479\n",
      "llama                      0.505046  0.50492    0.498769\n",
      "xsum_readability_ft_gpt35  0.505353  0.530689   0.81995\n",
      "xsum_length_ft_llama       0.535485  0.803763   0.401326\n",
      "cnn_length_ft_llama        0.548378  0.54111    0.562209\n",
      "cnn_2_ft_llama             0.566808  0.702605   0.544096\n",
      "cnn_vowelcount_ft_llama    0.571426  0.580851   0.497272\n",
      "cnn_10_ft_gpt35            0.586736  0.487369   0.478166\n",
      "cnn_500_ft_llama           0.59203   0.499934  -0.0340433\n",
      "cnn_readability_ft_gpt35   0.616615  0.628716   0.832363\n",
      "xsum_random_ft_llama       0.618494  0.753334   0.526178\n",
      "cnn_random_ft_llama        0.637675  0.653747   0.760943\n",
      "cnn_10_ft_llama            0.66481   0.825356   0.609962\n",
      "xsum_readability_ft_llama  0.675009  0.793965   0.438187\n",
      "xsum_10_ft_llama           0.680572  0.80993    0.518152\n",
      "xsum_500_ft_gpt35          0.73849   0.749659   0.816604\n",
      "gpt4                       0.74652   0.912338   0.543711\n",
      "xsum_vowelcount_ft_llama   0.780833  0.903391   0.556139\n",
      "xsum_500_ft_llama          0.792802  0.788255   0.510419\n",
      "xsum_2_ft_llama            0.798989  0.904906   0.553306\n",
      "cnn_always_1_ft_llama      0.949     0.933      0.565979\n",
      "cnn_500_ft_gpt35           0.959415  0.970093   0.415257\n"
     ]
    }
   ],
   "source": [
    "results = cnn_results\n",
    "print(tabulate(sorted([[model, avg([result['detection_score'] for result in results[model]]), avg([result['self_preference'] for result in results[model]]), kendall_tau_for_results(results[model])] for model in models], key = lambda x:x[1]), headers=['Model', 'Recog', 'Pref', 'Tau']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 309,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model                        Self-Rec [X]    Self-Pref [X]    Self-Rec [C]    Self-Pref [C]    XSUM T       CNN T\n",
      "-------------------------  --------------  ---------------  --------------  ---------------  --------  ----------\n",
      "cnn_10_ft_llama                  0.519166         0.655903        0.66481          0.825356  0.678791   0.609962\n",
      "cnn_2_ft_llama                   0.357445         0.502014        0.566808         0.702605  0.671583   0.544096\n",
      "cnn_500_ft_llama                 0.55599          0.434032        0.59203          0.499934  0.301602  -0.0340433\n",
      "cnn_always_1_ft_llama            0.967            0.961           0.949            0.933     0.57601    0.565979\n",
      "cnn_length_ft_llama              0.489451         0.486789        0.548378         0.54111   0.29198    0.562209\n",
      "cnn_random_ft_llama              0.673236         0.675831        0.637675         0.653747  0.699919   0.760943\n",
      "cnn_readability_ft_llama         0.501337         0.463783        0.495163         0.48851   0.512918   0.701186\n",
      "cnn_vowelcount_ft_llama          0.580018         0.580531        0.571426         0.580851  0.554505   0.497272\n",
      "llama                            0.513524         0.510752        0.505046         0.50492   0.726415   0.498769\n",
      "xsum_10_ft_llama                 0.525555         0.665453        0.680572         0.80993   0.576345   0.518152\n",
      "xsum_2_ft_llama                  0.591835         0.742704        0.798989         0.904906  0.603905   0.553306\n",
      "xsum_500_ft_llama                0.454197         0.485196        0.792802         0.788255  0.579361   0.510419\n",
      "xsum_always_1_ft_llama           0.499844         0.499799        0.5              0.5       0.328545   0.214696\n",
      "xsum_length_ft_llama             0.342242         0.482781        0.535485         0.803763  0.380532   0.401326\n",
      "xsum_random_ft_llama             0.543247         0.647942        0.618494         0.753334  0.587104   0.526178\n",
      "xsum_readability_ft_llama        0.558076         0.708908        0.675009         0.793965  0.559532   0.438187\n",
      "xsum_vowelcount_ft_llama         0.480679         0.576182        0.780833         0.903391  0.581408   0.556139\n"
     ]
    }
   ],
   "source": [
    "# Print table showing the kendall_tau_for_result for cn and xsum on each main model\n",
    "print(tabulate(sorted([[model, avg([result['detection_score'] for result in xsum_results[model]]), avg([result['self_preference'] for result in xsum_results[model]]), avg([result['detection_score'] for result in cnn_results[model]]), avg([result['self_preference'] for result in cnn_results[model]]),kendall_tau_for_results(xsum_results[model]), kendall_tau_for_results(cnn_results[model])] for model in models if 'llama' in model], key = lambda x:x[0]), headers=['Model', 'Self-Rec [X]', 'Self-Pref [X]', 'Self-Rec [C]', 'Self-Pref [C]', 'XSUM T', 'CNN T']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_scatterplots(results, include_ambiguous=True):\n",
    "    num_models = len(results.keys())\n",
    "\n",
    "    plt.figure(figsize=(num_models * 6, num_models * 2))\n",
    "    colors = ['blue', 'green', 'red']\n",
    "\n",
    "    for i, model in enumerate(results.keys()):\n",
    "        if not include_ambiguous:\n",
    "            detection_scores = [i['detection_score'] for i in results[model] if i['forward_comparison'] != i['backward_comparison']]\n",
    "            self_preferences = [i['self_preference'] for i in results[model] if i['forward_comparison'] != i['backward_comparison']]\n",
    "        else:\n",
    "            detection_scores = [i['detection_score'] for i in results[model]]\n",
    "            self_preferences = [i['self_preference'] for i in results[model]]\n",
    "        \n",
    "        plt.subplot(1, 3, i+1)\n",
    "        plt.scatter(detection_scores, self_preferences, color=colors[i])\n",
    "        plt.xlabel('Detection Score')\n",
    "        plt.ylabel('Self-Preference')\n",
    "        plt.title(MODEL_TO_STRING[model])\n",
    "\n",
    "    plt.suptitle('Detection Score vs Self-Preference (Token Probability)', fontsize=16, y=1) \n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_scatterplots(results, include_ambiguous=True):\n",
    "    num_models = len(results.keys())\n",
    "\n",
    "    plt.figure(figsize=(num_models * 6, num_models * 2))\n",
    "    colors = ['blue', 'green', 'red']\n",
    "\n",
    "    for i, model in enumerate(results.keys()):\n",
    "        if not include_ambiguous:\n",
    "            detection_scores = [i['detection_score'] for i in results[model] if i['forward_comparison'] != i['backward_comparison']]\n",
    "            self_preferences = [i['self_preference'] for i in results[model] if i['forward_comparison'] != i['backward_comparison']]\n",
    "        else:\n",
    "            detection_scores = [i['detection_score'] for i in results[model]]\n",
    "            self_preferences = [i['self_preference'] for i in results[model]]\n",
    "        \n",
    "        plt.subplot(1, 3, i+1)\n",
    "        plt.scatter(detection_scores, self_preferences, color=colors[i])\n",
    "        plt.xlabel('Detection Score')\n",
    "        plt.ylabel('Self-Preference')\n",
    "        plt.title(MODEL_TO_STRING[model])\n",
    "\n",
    "    plt.suptitle('Detection Score vs Self-Preference (Token Probability)', fontsize=16, y=1) \n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_scatterplots_with_marginals(results, include_ambiguous=True):\n",
    "    def is_valid(item):\n",
    "        return 'detection_score' in item and 'self_preference' in item \n",
    "\n",
    "    for i, model in enumerate(results.keys()):\n",
    "        if not include_ambiguous:\n",
    "            detection_scores = [item['detection_score'] for item in results[model] if is_valid(item) and item['forward_comparison'] != item['backward_comparison'] and item['forward_detection'] != item['backward_detection']]\n",
    "            self_preferences = [item['self_preference'] for item in results[model] if is_valid(item) and item['forward_comparison'] != item['backward_comparison'] and item['forward_detection'] != item['backward_detection']]\n",
    "        else:\n",
    "            detection_scores = [item['detection_score'] for item in results[model] if is_valid(item)]\n",
    "            self_preferences = [item['self_preference'] for item in results[model] if is_valid(item)]\n",
    "\n",
    "        # Create a jointplot for each model\n",
    "        joint_plot = sns.jointplot(x=detection_scores, y=self_preferences, kind=\"scatter\", color=COLORS[i % len(COLORS)], marginal_kws=dict(bins=15, fill=True))\n",
    "\n",
    "        joint_plot.ax_joint.set_xlim(0, 1.0)\n",
    "        joint_plot.ax_joint.set_ylim(0, 1.0)\n",
    "\n",
    "        # Adjust the title position and font size\n",
    "        joint_plot.fig.suptitle(f'{MODEL_TO_STRING[model]}', fontsize=14, y=1.05)\n",
    "        # joint_plot.fig.suptitle(f'Detection Score vs Self-Preference (Token Probability) for {MODEL_TO_STRING[model]}', fontsize=14, y=1.05)\n",
    "\n",
    "        # Adjust axis labels font size\n",
    "        joint_plot.set_axis_labels('Detection Score', 'Self-Preference', fontsize=12)\n",
    "\n",
    "        # Show the plot\n",
    "        plt.savefig(f'plots/scatterplots/xsum/{model}.png', bbox_inches='tight')\n",
    "        plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_scatterplots(results)\n",
    "show_scatterplots(results, include_ambiguous=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_heatmaps(results, include_ambiguous=True):\n",
    "    num_models = len(results.keys())\n",
    "\n",
    "    plt.figure(figsize=(num_models * 6, num_models * 2))\n",
    "    colors = ['blue', 'green', 'red', 'orange']\n",
    "\n",
    "    for i, model in enumerate(results.keys()):\n",
    "        if not include_ambiguous:\n",
    "            detection_scores = [i['detection_score'] for i in results[model] if i['forward_comparison'] != i['backward_comparison']]\n",
    "            self_preferences = [i['self_preference'] for i in results[model] if i['forward_comparison'] != i['backward_comparison']]\n",
    "        else:\n",
    "            detection_scores = [i['detection_score'] for i in results[model]]\n",
    "            self_preferences = [i['self_preference'] for i in results[model]]\n",
    "        \n",
    "        plt.subplot(1, num_models, i+1)\n",
    "        plt.hexbin(detection_scores, self_preferences, gridsize=30, cmap='Blues')\n",
    "        plt.colorbar(label='Density')\n",
    "        plt.xlabel('Detection Score')\n",
    "        plt.ylabel('Self-Preference')\n",
    "        plt.title(MODEL_TO_STRING[model])\n",
    "\n",
    "    plt.suptitle('Detection Score vs Self-Preference (Token Probability)', fontsize=16, y=1) \n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_heatmaps(results)\n",
    "show_heatmaps(results, include_ambiguous=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_detection_score_vs_correlation(*results_dicts):\n",
    "    # Extract 'a' and 'b' values\n",
    "    x_values = []\n",
    "    y_values = []\n",
    "    plot_models = []\n",
    "    d = {}\n",
    "    for results in results_dicts:\n",
    "        keys = [model for model in models if any(s in model for s in ['_2_ft_', '_10_ft_', '_500_ft_'])] #list(results.keys())\n",
    "        plot_models += keys\n",
    "        x_values += [avg([i['detection_score'] for i in results[key] if 'detection_score' in i]) for key in keys]\n",
    "        y_values += [avg([i['self_preference'] for i in results[key] if 'self_preference' in i]) for key in keys]\n",
    "        y_values += [kendall_tau_for_results(results[key]) for key in keys]\n",
    "        for key in keys:\n",
    "            d[key] = [(i['detection_score'], i['self_preference']) for i in results[key] if 'detection_score' in i and 'self_preference' in i]\n",
    "\n",
    "    save_to_json(d, 'xsum_plot_data.json')\n",
    "    # Create a scatter plot\n",
    "    plt.figure(figsize=(8, 8))\n",
    "\n",
    "    # Generate a color map with a unique color for each point\n",
    "    markers = (['o'] * 3 + ['^'] * 3) * 4\n",
    "    colors = ['red'] * 12 + ['blue'] * 12\n",
    "    # colors = ['red', 'blue', 'green', 'orange', 'yellow', 'purple', 'black', 'pink', 'grey'][:len(plot_models)]\n",
    "    \n",
    "    # Plot each point\n",
    "    print(plot_models)\n",
    "    for i, (a, b, color, marker) in enumerate(zip_longest(x_values, y_values, colors, markers)):\n",
    "        plt.scatter(a, b, color=color, marker=marker, label=MODEL_TO_STRING[plot_models[i]])\n",
    "\n",
    "    plt.xlim(0, 1)\n",
    "    plt.ylim(0, 1)\n",
    "    plt.gca().set_aspect('equal', adjustable='box')\n",
    "    \n",
    "    # Create a legend below the plot in a vertical column\n",
    "    plt.legend(title=\"Key\", loc='upper center', bbox_to_anchor=(0.5, -0.15), fancybox=True, shadow=True, ncol=1)\n",
    "\n",
    "    # Add grid, title, and axis labels\n",
    "    plt.grid(True)\n",
    "    # plt.title('Detection Score vs. Self-Preference')\n",
    "    plt.xlabel('Self-Recognition Score')\n",
    "    plt.ylabel(\"Self-Preference\")\n",
    "\n",
    "    # Show the plot\n",
    "    plt.savefig(f'plots/xsum_scaling_law.png', bbox_inches='tight')\n",
    "    plt.show()\n",
    "    return zip(plot_models, x_values, y_values)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
