{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "0032b0c5-e766-4091-b60a-5a9f92a966ad",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import jsonlines\n",
    "from collections import defaultdict\n",
    "import pickle\n",
    "import numpy as np\n",
    "from collections import Counter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "c8279b49-b204-4166-98dd-8196c326e94c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def merge_dicts(dict_list):\n",
    "    dd = defaultdict(list)\n",
    "    for d in dict_list: # you can list as many input dicts as you want here\n",
    "        for key, value in d.items():\n",
    "            dd[key].append(value)\n",
    "    return dd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "90a18a81-1b49-4743-97b1-03386396358f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_max(score_dict,model_names):\n",
    "    max_model_dict = {}\n",
    "    for key, value in score_dict.items():\n",
    "        winner = np.argwhere(value == np.amax(value))\n",
    "        # print(winner.flatten())\n",
    "        winner_models = []\n",
    "        for num in winner.flatten():\n",
    "            winner_models.append(model_names[num])\n",
    "        # print(winner_models)\n",
    "        max_model_dict[key] = winner_models\n",
    "    return max_model_dict\n",
    "        "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e274bab-7174-4167-a9d9-23cfa2913d64",
   "metadata": {},
   "source": [
    "#### Test data: per-metric win counts for each model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "3eaf7d18-8b9a-4cb0-8256-15b149d8060f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# One {metric: [winning model names]} dict per record of the test file.\n",
    "winners_per_ip = []\n",
    "with jsonlines.open('test_data_prepared.jsonl') as reader:\n",
    "    for obj in reader:\n",
    "        candidates = obj['candidates']\n",
    "        # Keep names and score dicts positionally aligned for get_max.\n",
    "        model_order = [cand['model'] for cand in candidates]\n",
    "        scores = [cand['scores'] for cand in candidates]\n",
    "        all_scores_dict = merge_dicts(scores)\n",
    "        winners_per_ip.append(get_max(all_scores_dict, model_order))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "89e533d3-9b0f-43f5-b3e8-770ecbde88ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regroup the per-input winner dicts by metric:\n",
    "# metric -> one list of winning model names per input.\n",
    "metric_wise_grouped_winners = merge_dicts(winners_per_ip)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "0359df68-5b8e-46c7-99f9-20505dab5743",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Tally, per metric, how many inputs each model won (tied models all count).\n",
    "# Built from winners_per_ip instead of overwriting the grouped lists in\n",
    "# place, so this cell gives the same result every time it is re-run.\n",
    "metric_wise_grouped_winners = {\n",
    "    metric: Counter(model for winners in per_input for model in winners)\n",
    "    for metric, per_input in merge_dicts(winners_per_ip).items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "34a31623-f118-464c-985b-689044bed292",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "defaultdict(list,\n",
       "            {'logprobs': Counter({'koala-7B-HF': 172,\n",
       "                      'flan-t5-xxl': 2277,\n",
       "                      'oasst-sft-4-pythia-12b-epoch-3.5': 1312,\n",
       "                      'chatglm-6b': 921,\n",
       "                      'mpt-7b-instruct': 229,\n",
       "                      'stablelm-tuned-alpha-7b': 272,\n",
       "                      'alpaca-native': 656,\n",
       "                      'mpt-7b': 61,\n",
       "                      'llama-7b-hf-baize-lora-bf16': 25,\n",
       "                      'vicuna-13b-1.1': 201,\n",
       "                      'dolly-v2-12b': 94,\n",
       "                      'moss-moon-003-sft': 16}),\n",
       "             'rougeL': Counter({'oasst-sft-4-pythia-12b-epoch-3.5': 1975,\n",
       "                      'chatglm-6b': 636,\n",
       "                      'alpaca-native': 1069,\n",
       "                      'vicuna-13b-1.1': 742,\n",
       "                      'flan-t5-xxl': 225,\n",
       "                      'stablelm-tuned-alpha-7b': 167,\n",
       "                      'moss-moon-003-sft': 79,\n",
       "                      'llama-7b-hf-baize-lora-bf16': 95,\n",
       "                      'koala-7B-HF': 202,\n",
       "                      'dolly-v2-12b': 49,\n",
       "                      'mpt-7b-instruct': 32,\n",
       "                      'mpt-7b': 11}),\n",
       "             'rougeLsum': Counter({'oasst-sft-4-pythia-12b-epoch-3.5': 1967,\n",
       "                      'chatglm-6b': 628,\n",
       "                      'koala-7B-HF': 205,\n",
       "                      'alpaca-native': 1032,\n",
       "                      'flan-t5-xxl': 182,\n",
       "                      'vicuna-13b-1.1': 800,\n",
       "                      'llama-7b-hf-baize-lora-bf16': 117,\n",
       "                      'stablelm-tuned-alpha-7b': 178,\n",
       "                      'moss-moon-003-sft': 75,\n",
       "                      'dolly-v2-12b': 46,\n",
       "                      'mpt-7b-instruct': 30,\n",
       "                      'mpt-7b': 15}),\n",
       "             'rouge1': Counter({'flan-t5-xxl': 180,\n",
       "                      'oasst-sft-4-pythia-12b-epoch-3.5': 1992,\n",
       "                      'chatglm-6b': 610,\n",
       "                      'koala-7B-HF': 202,\n",
       "                      'alpaca-native': 1075,\n",
       "                      'llama-7b-hf-baize-lora-bf16': 117,\n",
       "                      'stablelm-tuned-alpha-7b': 176,\n",
       "                      'vicuna-13b-1.1': 808,\n",
       "                      'moss-moon-003-sft': 64,\n",
       "                      'dolly-v2-12b': 35,\n",
       "                      'mpt-7b-instruct': 32,\n",
       "                      'mpt-7b': 13}),\n",
       "             'rouge2': Counter({'oasst-sft-4-pythia-12b-epoch-3.5': 1771,\n",
       "                      'chatglm-6b': 629,\n",
       "                      'moss-moon-003-sft': 206,\n",
       "                      'alpaca-native': 956,\n",
       "                      'flan-t5-xxl': 224,\n",
       "                      'llama-7b-hf-baize-lora-bf16': 206,\n",
       "                      'vicuna-13b-1.1': 882,\n",
       "                      'koala-7B-HF': 258,\n",
       "                      'stablelm-tuned-alpha-7b': 211,\n",
       "                      'dolly-v2-12b': 112,\n",
       "                      'mpt-7b': 67,\n",
       "                      'mpt-7b-instruct': 101}),\n",
       "             'bleu': Counter({'oasst-sft-4-pythia-12b-epoch-3.5': 1806,\n",
       "                      'flan-t5-xxl': 108,\n",
       "                      'chatglm-6b': 581,\n",
       "                      'alpaca-native': 969,\n",
       "                      'llama-7b-hf-baize-lora-bf16': 161,\n",
       "                      'koala-7B-HF': 200,\n",
       "                      'stablelm-tuned-alpha-7b': 206,\n",
       "                      'vicuna-13b-1.1': 796,\n",
       "                      'mpt-7b': 16,\n",
       "                      'moss-moon-003-sft': 150,\n",
       "                      'dolly-v2-12b': 77,\n",
       "                      'mpt-7b-instruct': 41}),\n",
       "             'bertscore': Counter({'oasst-sft-4-pythia-12b-epoch-3.5': 2151,\n",
       "                      'flan-t5-xxl': 186,\n",
       "                      'alpaca-native': 1018,\n",
       "                      'chatglm-6b': 525,\n",
       "                      'koala-7B-HF': 145,\n",
       "                      'llama-7b-hf-baize-lora-bf16': 78,\n",
       "                      'stablelm-tuned-alpha-7b': 144,\n",
       "                      'vicuna-13b-1.1': 668,\n",
       "                      'moss-moon-003-sft': 77,\n",
       "                      'dolly-v2-12b': 18,\n",
       "                      'mpt-7b': 6,\n",
       "                      'mpt-7b-instruct': 20}),\n",
       "             'bleurt': Counter({'oasst-sft-4-pythia-12b-epoch-3.5': 1619,\n",
       "                      'moss-moon-003-sft': 306,\n",
       "                      'chatglm-6b': 549,\n",
       "                      'llama-7b-hf-baize-lora-bf16': 361,\n",
       "                      'koala-7B-HF': 205,\n",
       "                      'alpaca-native': 866,\n",
       "                      'stablelm-tuned-alpha-7b': 129,\n",
       "                      'dolly-v2-12b': 109,\n",
       "                      'vicuna-13b-1.1': 630,\n",
       "                      'mpt-7b-instruct': 154,\n",
       "                      'flan-t5-xxl': 93,\n",
       "                      'mpt-7b': 54}),\n",
       "             'bartscore': Counter({'stablelm-tuned-alpha-7b': 166,\n",
       "                      'moss-moon-003-sft': 559,\n",
       "                      'chatglm-6b': 586,\n",
       "                      'mpt-7b-instruct': 217,\n",
       "                      'vicuna-13b-1.1': 1077,\n",
       "                      'alpaca-native': 545,\n",
       "                      'llama-7b-hf-baize-lora-bf16': 472,\n",
       "                      'oasst-sft-4-pythia-12b-epoch-3.5': 825,\n",
       "                      'koala-7B-HF': 257,\n",
       "                      'dolly-v2-12b': 180,\n",
       "                      'mpt-7b': 104,\n",
       "                      'flan-t5-xxl': 50})})"
      ]
     },
     "execution_count": 65,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the per-metric tallies of top-scoring models.\n",
    "metric_wise_grouped_winners"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "675bd2d3-c79c-449d-b75c-a9b4b128aa3e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
