{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "702c4661",
   "metadata": {},
   "source": [
    "# Mix Instruct Experimentation\n",
    "\n",
    "This notebook runs relative multivariate stochastic order testing on the [mix-instruct test dataset](https://huggingface.co/datasets/llm-blender/mix-instruct).\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "301c4986",
   "metadata": {},
   "source": [
    "### Loading Required Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0de8dfc7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "from os.path import dirname, realpath\n",
    "\n",
    "# Start from the current working directory and climb three directory levels;\n",
    "# the resulting folder is added to the import search path (presumably the\n",
    "# repository root that contains the `soe` package -- verify).\n",
    "filepath = realpath(\".\")\n",
    "dir_of_file = dirname(filepath)\n",
    "parent_dir_of_file = dirname(dir_of_file)\n",
    "parents_parent_dir_of_file = dirname(parent_dir_of_file)\n",
    "\n",
    "# Position 1 keeps the first sys.path entry ahead of the added folder.\n",
    "sys.path.insert(1, parents_parent_dir_of_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6f91f760",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pickle\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "pd.set_option('display.max_columns', None)\n",
    "pd.set_option('display.max_rows', None)\n",
    "from IPython.display import display\n",
    "from typing import *\n",
    "from collections import defaultdict,Counter\n",
    "from soe.mvtesting import MVStochasticOrderTesting"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42083bef",
   "metadata": {},
   "source": [
    "### Required Helper Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "35a0f306-2a5a-4ef5-b9d3-1121cee3ba50",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_mv_testing(pickle_file_path, loss=\"logistic\", abs_test=False, abs_tau=0.25, beta=8,\n",
    "                   n_bootstrap=1000, alpha=0.05, use_cuda=True):\n",
    "    \"\"\"Run a multivariate stochastic order test on pickled scores and tabulate the result.\n",
    "\n",
    "    Args:\n",
    "        pickle_file_path: Path of a pickle file holding a dict-like object. Its keys are\n",
    "            used as model names and its values are handed to the test as the score inputs.\n",
    "        loss: Forwarded to ``MVStochasticOrderTesting`` as ``cost``.\n",
    "        abs_test: When true, call ``compute_absolute_test``; otherwise call\n",
    "            ``compute_relative_test``.\n",
    "        abs_tau: ``tau`` argument of the absolute test; ignored for the relative test.\n",
    "        beta: Forwarded as ``cost_kwargs={'beta': beta}``.\n",
    "        n_bootstrap: Forwarded as ``n_bootstrap`` (was hard-coded to 1000).\n",
    "        alpha: Forwarded as ``alpha`` to whichever test is run (was hard-coded to 0.05).\n",
    "        use_cuda: Forwarded as ``use_cuda`` (was hard-coded to True); pass False on a\n",
    "            machine without a GPU -- TODO confirm the library then runs on CPU.\n",
    "\n",
    "    Returns:\n",
    "        A ``pandas.DataFrame`` built from the model names, listed in the order given by\n",
    "        the test's result and indexed from 1.\n",
    "    \"\"\"\n",
    "    # NOTE(security): pickle.load can execute arbitrary code while deserialising;\n",
    "    # only point this helper at pickle files you produced yourself.\n",
    "    with open(pickle_file_path, 'rb') as handle:\n",
    "        data = pickle.load(handle)\n",
    "\n",
    "    model_names = list(data.keys())\n",
    "    scores_list = list(data.values())\n",
    "\n",
    "    test = MVStochasticOrderTesting(\n",
    "        scores_list,\n",
    "        n_bootstrap=n_bootstrap,\n",
    "        use_sinkhorn=True,\n",
    "        cost=loss,\n",
    "        verbose=True,\n",
    "        use_cuda=use_cuda,\n",
    "        cost_kwargs={'beta': beta},\n",
    "    )\n",
    "\n",
    "    if abs_test:\n",
    "        print(f\"Starting Absolute Testing with tau={abs_tau}\")\n",
    "        ranks = test.compute_absolute_test(alpha=alpha, tau=abs_tau)\n",
    "    else:\n",
    "        print(\"Starting Relative Testing\")\n",
    "        ranks = test.compute_relative_test(alpha=alpha)\n",
    "\n",
    "    # Each entry of `ranks` is used as a position in `model_names`\n",
    "    # (presumably the test returns input indices in ranked order -- verify).\n",
    "    l_mv = [model_names[rank] for rank in ranks]\n",
    "    df_mv = pd.DataFrame(l_mv)\n",
    "    # Re-index from 1 so the row label can be read as the rank.\n",
    "    df_mv.index = np.arange(1, len(df_mv) + 1)\n",
    "    return df_mv"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3918c8fa-4dc6-4047-b0fb-15955e9d5a90",
   "metadata": {},
   "source": [
    "## Multivariate Testing "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4df7674f",
   "metadata": {},
   "source": [
    "#### Log CDF Normalization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a6385fb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting Relative Testing\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "bootstrap quantiles violations:  27%|██▋       | 266/1000 [22:09<1:00:59,  4.99s/it]"
     ]
    }
   ],
   "source": [
    "# Relative test on the log-CDF-normalised scores. The saved progress bar shows\n",
    "# roughly 5 s per bootstrap iteration, so expect this cell to run for a long time.\n",
    "mv_order_logcdf = get_mv_testing(\n",
    "    \"../MixInstruct_Preprocessing/mv_logcdfNorm_mix_instruct_test.pickle\",\n",
    "    beta=0.2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a89b9aeb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Give the ranking table a readable column header, then render it as the cell output.\n",
    "mv_order_logcdf.columns = pd.Index([\"Ranking\"])\n",
    "mv_order_logcdf"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
