{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "996183ce-f0b0-4dba-a630-ee9f70c41a35",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "from scipy.stats import kendalltau\n",
    "import string\n",
    "from utils import levenshtein_distance, sbert_dissimilarity, load_sbert\n",
    "\n",
    "sbert = load_sbert()\n",
    "sbert.to(\"cuda\")  # needs a CUDA device\n",
    "# NOTE(review): sbert and the levenshtein/sbert helpers are not used by any cell in this notebook.\n",
    "\n",
    "# Attribute names parsed from the CE log lines; the position in this list is the index\n",
    "# stored in attribute_name_to_idx (and hence in read_eval_all's per-attribute arrays).\n",
    "attribute_list = [\n",
    "    \"helpfulness\", \"relevance\", \"correctness\", \"coherence\", \"complexity\",\n",
    "    \"verbosity\", \"neutrality\", \"appropriateness\", \"assertiveness\", \"harmlessness\",\n",
    "    \"sensitivity\", \"engagement\", \"answer\", \"informativeness\", \"clarity\",\n",
    "]\n",
    "attribute_name_to_idx = {name: position for position, name in enumerate(attribute_list)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce9931c9-ead4-48fa-accc-98286ed6b20c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alphabetical view of the attributes: first the positions into attribute_list that\n",
    "# produce alphabetical order, then the names themselves in that order.\n",
    "attribute_sorted_alpha_idxs = np.argsort(attribute_list)\n",
    "attribute_sorted_alpha = [attribute_list[pos] for pos in attribute_sorted_alpha_idxs]\n",
    "print(attribute_sorted_alpha)\n",
    "print(attribute_sorted_alpha_idxs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5f95d9c-0db4-4bf4-af0f-324ca1dc3e55",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ALL counterfactual\n",
    "import string\n",
    "\n",
    "def read_eval_all(file_path, num_q=30):\n",
    "    \"\"\"Per-attribute counterfactual (CF) and semifactual (SF) rates from one evaluated CE log.\n",
    "\n",
    "    Lines starting with \"$$ QUESTION:\" start a new data point. Each line starting with\n",
    "    \"$$ [CE FOR \" is one perturbation; its attribute is the last whitespace token before\n",
    "    \":: \" (anything containing \"answer\" is folded into \"answer\"). The line counts as SF if\n",
    "    it contains \"tensor([False])\" or \", False,\" and as CF if it contains \"tensor([True])\"\n",
    "    or \", True,\".\n",
    "\n",
    "    Each attribute's counts are divided by the number of data points in which that\n",
    "    attribute actually appears.\n",
    "\n",
    "    Args:\n",
    "        file_path: path to the UTF-8 text log.\n",
    "        num_q: unused; kept so existing calls keep working.\n",
    "\n",
    "    Returns:\n",
    "        (cf_rates, sf_rates), each of shape (len(attribute_list),) in attribute_list order.\n",
    "        An attribute absent from every data point normally gets 0/0 = nan.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: for a parsed attribute name that is not in attribute_list.\n",
    "    \"\"\"\n",
    "    n_attrs = len(attribute_list)\n",
    "    sf_counts = np.zeros((n_attrs,))\n",
    "    cf_counts = np.zeros((n_attrs,))\n",
    "    missing_counts = np.zeros((n_attrs,))\n",
    "    running_attributes = []\n",
    "    num_points = 0\n",
    "\n",
    "    def close_data_point(seen):\n",
    "        # every attribute not perturbed in the finished data point is missing for it\n",
    "        for item in attribute_list:\n",
    "            if item not in seen:\n",
    "                missing_counts[attribute_name_to_idx[item]] += 1\n",
    "\n",
    "    with open(file_path, \"r\", encoding='utf-8') as f:\n",
    "        for line in f:\n",
    "            if line.startswith(\"$$ QUESTION:\"):\n",
    "                # close the previous data point before starting a new one; the first\n",
    "                # QUESTION line has no predecessor, so nothing is counted as missing there\n",
    "                if num_points > 0:\n",
    "                    close_data_point(running_attributes)\n",
    "                running_attributes = []\n",
    "                num_points += 1\n",
    "                continue\n",
    "\n",
    "            # only consider perturbations, ignore original answers and other stuff\n",
    "            if not line.startswith(\"$$ [CE FOR \"):\n",
    "                continue\n",
    "\n",
    "            attribute = line.split(\":: \")[0].split()[-1]\n",
    "            if \"answer\" in attribute:\n",
    "                attribute = \"answer\"\n",
    "            running_attributes.append(attribute)\n",
    "            attribute_idx = attribute_name_to_idx[attribute]\n",
    "            # counterfactual or semifactual\n",
    "            if \"tensor([False])\" in line or \", False,\" in line:\n",
    "                sf_counts[attribute_idx] += 1\n",
    "            if \"tensor([True])\" in line or \", True,\" in line:\n",
    "                cf_counts[attribute_idx] += 1\n",
    "\n",
    "    # the last data point is not followed by another QUESTION line, so close it here\n",
    "    if num_points > 0:\n",
    "        close_data_point(running_attributes)\n",
    "\n",
    "    # normalise by the number of data points in which each attribute was perturbed\n",
    "    computed_counts = num_points - missing_counts\n",
    "    cf_counts = cf_counts / computed_counts\n",
    "    sf_counts = sf_counts / computed_counts\n",
    "    return cf_counts, sf_counts\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03bd9a6a-f9f6-419d-bef0-b76d4f1b5ba7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def barplot_attributes(dir_name=\"evaluated/\", plot_name=\"deberta1.png\", dname=\"hh-rlhf-helpfulness\",mname=\"deberta-v1\", pref=True, cf=True, all_include=False, num_q=270):\n",
    "    \"\"\"Average the per-attribute CF/SF rates over all evaluated files of one dataset/model.\n",
    "\n",
    "    A file in `dir_name` is used when its name contains both `dname` and `mname`\n",
    "    (for mname \"pythia\", names containing \"5\" are skipped; for dname \"hh-rlhf\", names\n",
    "    containing \"helpfulness\" are skipped). Names containing \"pref\" feed the preferred\n",
    "    averages, names containing \"rej\" the rejected ones; the \"all\" averages pool both.\n",
    "\n",
    "    Args:\n",
    "        dir_name: directory with the evaluated files; a trailing separator is optional.\n",
    "        plot_name, pref, cf: unused; kept so existing calls keep working.\n",
    "        dname: dataset substring to match in file names.\n",
    "        mname: model substring to match in file names.\n",
    "        all_include: read files with `read_eval_all` if True, else `read_eval_top_k`.\n",
    "        num_q: forwarded to `read_eval_all`.\n",
    "\n",
    "    Returns:\n",
    "        (pref_cf, pref_sf, rej_cf, rej_sf, all_cf, all_sf): per-attribute means over files.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if no matching \"pref\" file or no matching \"rej\" file exists.\n",
    "    \"\"\"\n",
    "    pref_cf_rows, pref_sf_rows = [], []\n",
    "    rej_cf_rows, rej_sf_rows = [], []\n",
    "    # sorted() makes the file (and hence summation) order independent of the filesystem\n",
    "    for item in sorted(os.listdir(dir_name)):\n",
    "        if not (dname in item and mname in item):\n",
    "            continue\n",
    "        if mname == \"pythia\" and \"5\" in item:\n",
    "            continue\n",
    "        # \"hh-rlhf\" is a substring of \"hh-rlhf-helpfulness\"; keep the two datasets apart\n",
    "        if dname == \"hh-rlhf\" and \"helpfulness\" in item:\n",
    "            continue\n",
    "        is_pref = \"pref\" in item\n",
    "        is_rej = \"rej\" in item\n",
    "        if not (is_pref or is_rej):\n",
    "            continue\n",
    "        name = os.path.join(dir_name, item)\n",
    "        if all_include:\n",
    "            cf_counts, sf_counts = read_eval_all(name, num_q=num_q)\n",
    "        else:\n",
    "            # NOTE(review): read_eval_top_k is not defined in this notebook, so the default\n",
    "            # all_include=False raises NameError unless that helper is restored.\n",
    "            cf_counts, sf_counts = read_eval_top_k(name)\n",
    "        if is_pref:\n",
    "            pref_cf_rows.append(cf_counts)\n",
    "            pref_sf_rows.append(sf_counts)\n",
    "        if is_rej:\n",
    "            rej_cf_rows.append(cf_counts)\n",
    "            rej_sf_rows.append(sf_counts)\n",
    "\n",
    "    if not pref_cf_rows or not rej_cf_rows:\n",
    "        raise ValueError(f\"need at least one 'pref' and one 'rej' file for dname={dname!r}, mname={mname!r} in {dir_name!r}\")\n",
    "\n",
    "    # np.vstack turns the list of 1-D rows into a (n_files, n_attrs) matrix, even for one file\n",
    "    bararray_pref_cf = np.vstack(pref_cf_rows).mean(axis=0)\n",
    "    bararray_pref_sf = np.vstack(pref_sf_rows).mean(axis=0)\n",
    "    bararray_rej_cf = np.vstack(rej_cf_rows).mean(axis=0)\n",
    "    bararray_rej_sf = np.vstack(rej_sf_rows).mean(axis=0)\n",
    "    bararray_all_cf = np.vstack(pref_cf_rows + rej_cf_rows).mean(axis=0)\n",
    "    bararray_all_sf = np.vstack(pref_sf_rows + rej_sf_rows).mean(axis=0)\n",
    "    return bararray_pref_cf, bararray_pref_sf, bararray_rej_cf, bararray_rej_sf, bararray_all_cf, bararray_all_sf\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac539f56-8b5e-4bbf-9ceb-ad8fb426a6ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean per-attribute CF/SF rates for every (dataset, model) pair.\n",
    "# Prefixes: hhh = hh-rlhf-helpfulness, hh = hh-rlhf, hs = helpsteer2; then the model substring.\n",
    "hhhv1_pref_cf, hhhv1_pref_sf, hhhv1_rej_cf, hhhv1_rej_sf, hhhv1_all_cf, hhhv1_all_sf = barplot_attributes(dname=\"hh-rlhf-helpfulness\",mname=\"v1\", all_include=True, num_q=270)\n",
    "hhhv2_pref_cf, hhhv2_pref_sf, hhhv2_rej_cf, hhhv2_rej_sf, hhhv2_all_cf, hhhv2_all_sf = barplot_attributes(dname=\"hh-rlhf-helpfulness\",mname=\"v2\", all_include=True, num_q=270)\n",
    "hhhpythia_pref_cf, hhhpythia_pref_sf, hhhpythia_rej_cf, hhhpythia_rej_sf, hhhpythia_all_cf, hhhpythia_all_sf = barplot_attributes(dname=\"hh-rlhf-helpfulness\",mname=\"pythia\", all_include=True, num_q=270)\n",
    "hhv1_pref_cf, hhv1_pref_sf, hhv1_rej_cf, hhv1_rej_sf, hhv1_all_cf, hhv1_all_sf = barplot_attributes(dname=\"hh-rlhf\",mname=\"v1\", all_include=True, num_q=135)\n",
    "hhv2_pref_cf, hhv2_pref_sf, hhv2_rej_cf, hhv2_rej_sf, hhv2_all_cf, hhv2_all_sf = barplot_attributes(dname=\"hh-rlhf\",mname=\"v2\", all_include=True, num_q=135)\n",
    "hhpythia_pref_cf, hhpythia_pref_sf, hhpythia_rej_cf, hhpythia_rej_sf, hhpythia_all_cf, hhpythia_all_sf = barplot_attributes(dname=\"hh-rlhf\",mname=\"pythia\", all_include=True, num_q=135)\n",
    "hsv1_pref_cf, hsv1_pref_sf, hsv1_rej_cf, hsv1_rej_sf, hsv1_all_cf, hsv1_all_sf = barplot_attributes(dname=\"helpsteer2\",mname=\"v1\", all_include=True, num_q=199)\n",
    "hsv2_pref_cf, hsv2_pref_sf, hsv2_rej_cf, hsv2_rej_sf, hsv2_all_cf, hsv2_all_sf = barplot_attributes(dname=\"helpsteer2\",mname=\"v2\", all_include=True, num_q=199)\n",
    "hspythia_pref_cf, hspythia_pref_sf, hspythia_rej_cf, hspythia_rej_sf, hspythia_all_cf, hspythia_all_sf = barplot_attributes(dname=\"helpsteer2\",mname=\"pythia\", all_include=True, num_q=199)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97850c8b-6f9f-4fdb-81da-9e05256da563",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from scipy.stats import kendalltau\n",
    "\n",
    "def get_correlation_between_two(a, b):\n",
    "    ordera = np.argsort(np.argsort(a))\n",
    "    orderb = np.argsort(np.argsort(b))\n",
    "    return kendalltau(ordera, orderb)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "542fcbb3-26aa-41f8-a856-8fff0f13d825",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same dataset and model: how well does the per-attribute CF ranking of chosen (pref)\n",
    "# responses agree with that of rejected (rej) responses? One Kendall tau per pair,\n",
    "# grouped hh-rlhf, hh-rlhf-helpfulness, helpsteer2; within each: v1, v2, pythia.\n",
    "pref_rej_pairs = [\n",
    "    (hhv1_pref_cf, hhv1_rej_cf),\n",
    "    (hhv2_pref_cf, hhv2_rej_cf),\n",
    "    (hhpythia_pref_cf, hhpythia_rej_cf),\n",
    "    (hhhv1_pref_cf, hhhv1_rej_cf),\n",
    "    (hhhv2_pref_cf, hhhv2_rej_cf),\n",
    "    (hhhpythia_pref_cf, hhhpythia_rej_cf),\n",
    "    (hsv1_pref_cf, hsv1_rej_cf),\n",
    "    (hsv2_pref_cf, hsv2_rej_cf),\n",
    "    (hspythia_pref_cf, hspythia_rej_cf),\n",
    "]\n",
    "for pref_cf, rej_cf in pref_rej_pairs:\n",
    "    print(get_correlation_between_two(pref_cf, rej_cf))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "345b233c-fea9-4173-8758-366544c6e23c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Observations for a fixed split (P = preferred/chosen, R = rejected) and dataset: compare results across models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5d58a64-f9fb-4ced-b6c4-69e17e9aeac9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Kendall-tau comparison of the per-attribute CF rates across all (dataset, model) runs\n",
    "import pandas as pd\n",
    "# NOTE(review): plot_2corr is not defined anywhere in this notebook (probably lost with a\n",
    "# deleted cell), so this cell raises NameError on Restart & Run All until it is restored.\n",
    "prefcfs = [\n",
    "    hhv1_pref_cf, hhv2_pref_cf, hhpythia_pref_cf,      # hh-rlhf\n",
    "    hhhv1_pref_cf, hhhv2_pref_cf, hhhpythia_pref_cf,   # hh-rlhf-helpfulness\n",
    "    hsv1_pref_cf, hsv2_pref_cf, hspythia_pref_cf,      # helpsteer2\n",
    "]\n",
    "rejcfs = [\n",
    "    hhv1_rej_cf, hhv2_rej_cf, hhpythia_rej_cf,         # hh-rlhf\n",
    "    hhhv1_rej_cf, hhhv2_rej_cf, hhhpythia_rej_cf,      # hh-rlhf-helpfulness\n",
    "    hsv1_rej_cf, hsv2_rej_cf, hspythia_rej_cf,         # helpsteer2\n",
    "]\n",
    "plot_2corr(prefcfs, rejcfs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f0d1a19-a5a4-46be-a7fa-e432a1a7ca75",
   "metadata": {},
   "outputs": [],
   "source": [
    "attribute_name_to_sorted_idx = {feat: i for i, feat in enumerate(attribute_sorted_alpha)}\n",
    "\n",
    "\n",
    "def read_compute_corr(file_path, global_arr, min_ces=3):\n",
    "    \"\"\"Per-question Kendall tau between local CE score shifts and a global per-attribute profile.\n",
    "\n",
    "    For each \"$$ QUESTION:\" block in `file_path`, every line starting with \"$$ [CE FOR \"\n",
    "    contributes one score difference (orig - ce if the path contains 'pref', ce - orig if it\n",
    "    contains 'rej'), stored at the attribute's alphabetical index. The ordinal ranking of\n",
    "    those differences is then correlated with the ordinal ranking of `global_arr`.\n",
    "\n",
    "    Args:\n",
    "        file_path: evaluated CE log (UTF-8).\n",
    "        global_arr: per-attribute values in `attribute_list` order, e.g. a CF-rate vector\n",
    "            returned by `barplot_attributes`. It is reordered alphabetically here so that\n",
    "            element i refers to the same attribute in both rankings.\n",
    "        min_ces: a question needs at least this many positive differences to get a real\n",
    "            tau; otherwise the sentinel -10 is recorded for it.\n",
    "\n",
    "    Returns:\n",
    "        (correlations, questions): one tau (or -10) and the raw QUESTION line per question,\n",
    "        in file order; both lists have the same length.\n",
    "    \"\"\"\n",
    "    n_attrs = len(attribute_sorted_alpha)\n",
    "    # Bring global_arr into the alphabetical order used for running_differences below;\n",
    "    # the global ranking is the same for every question, so compute it once.\n",
    "    global_alpha = np.asarray(global_arr)[attribute_sorted_alpha_idxs]\n",
    "    global_ranks = np.argsort(np.argsort(global_alpha))\n",
    "    correlations = []\n",
    "    questions = []\n",
    "\n",
    "    def close_question(differences, found):\n",
    "        if found >= min_ces:\n",
    "            correlations.append(kendalltau(np.argsort(np.argsort(differences)), global_ranks)[0])\n",
    "        else:\n",
    "            correlations.append(-10)\n",
    "\n",
    "    running_differences = np.zeros((n_attrs,))\n",
    "    foundce = 0\n",
    "    with open(file_path, \"r\", encoding='utf-8') as f:\n",
    "        for line in f:\n",
    "            # new data point: close the previous question (if any), then reset\n",
    "            if line.startswith(\"$$ QUESTION:\"):\n",
    "                if questions:\n",
    "                    close_question(running_differences, foundce)\n",
    "                questions.append(line)\n",
    "                running_differences = np.zeros((n_attrs,))\n",
    "                foundce = 0\n",
    "                continue\n",
    "\n",
    "            # only consider perturbations, ignore original answers and other stuff\n",
    "            if not line.startswith(\"$$ [CE FOR \"):\n",
    "                continue\n",
    "\n",
    "            # start calculating, first get attribute\n",
    "            attribute = line.split(\":: \")[0].split()[-1]\n",
    "            if \"answer\" in attribute:\n",
    "                attribute = \"answer\"\n",
    "            attribute_idx = attribute_name_to_sorted_idx[attribute]\n",
    "            # then get score difference: score for ce and score for original response\n",
    "            ce_score = float(line.split(':: ')[1].split(' (tensor')[0])\n",
    "            orig_score = float(line.split('tensor([')[1].split('])')[0])\n",
    "            if 'pref' in file_path:\n",
    "                running_differences[attribute_idx] = orig_score - ce_score\n",
    "                if running_differences[attribute_idx] > 0:\n",
    "                    foundce += 1\n",
    "            if 'rej' in file_path:\n",
    "                running_differences[attribute_idx] = ce_score - orig_score\n",
    "                if running_differences[attribute_idx] > 0:\n",
    "                    foundce += 1\n",
    "\n",
    "    # the last question has no following QUESTION line; apply the same threshold to it\n",
    "    if questions:\n",
    "        close_question(running_differences, foundce)\n",
    "    return correlations, questions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a889d964-f177-414b-b014-b2bfee5fe7a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# deberta-v2 on hh-rlhf: per-question tau for the chosen (pref) and rejected (rej) files,\n",
    "# summed, then the questions with the highest combined tau are printed.\n",
    "cepath = \"evaluated/v2hh-rlhf_deberta-v2_run100_ces_pref.txt\"\n",
    "correlations, questions = read_compute_corr(cepath, hhv2_pref_cf)\n",
    "corr_argsort = np.argsort(correlations)\n",
    "\n",
    "cepath_rej = \"evaluated/v2hh-rlhf_deberta-v2_run100_ces_rej.txt\"\n",
    "correlations_rej, questions_rej = read_compute_corr(cepath_rej, hhv2_rej_cf)\n",
    "corr_argsort_rej = np.argsort(correlations_rej)\n",
    "\n",
    "# The element-wise sum pairs question k of both files.\n",
    "# NOTE(review): assumes both files list the same questions in the same order - confirm.\n",
    "assert len(correlations) == len(correlations_rej), \"pref and rej files contain different numbers of questions\"\n",
    "correlations_both = np.array(correlations) + np.array(correlations_rej)\n",
    "corr_both_argsort = np.argsort(correlations_both)\n",
    "# Ranks 1..99, highest first; clamped so a file with fewer questions does not raise IndexError.\n",
    "n_show = min(99, len(correlations_both))\n",
    "for i in range(1, n_show + 1):\n",
    "    idx = int(corr_both_argsort[-i])\n",
    "    print(i, correlations_both[idx], questions[idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86ab8872-7914-41ad-91a6-eae10afaf855",
   "metadata": {},
   "outputs": [],
   "source": [
    "# hh-rlhf preferred responses: per-question tau of the v1 run summed with that of the v2 run;\n",
    "# the questions with the highest combined tau are printed.\n",
    "# NOTE(review): despite the *_rej names, the second file and global array are the v2 *pref*\n",
    "# ones; confirm that a v1-vs-v2 comparison (rather than pref-vs-rej) is intended.\n",
    "cepath = \"evaluated/v1hh-rlhf_run0_ces_pref.txt\"\n",
    "correlations, questions = read_compute_corr(cepath, hhv1_pref_cf)\n",
    "corr_argsort = np.argsort(correlations)\n",
    "\n",
    "cepath_rej = \"evaluated/v2hh-rlhf_run0_ces_pref.txt\"\n",
    "correlations_rej, questions_rej = read_compute_corr(cepath_rej, hhv2_pref_cf)\n",
    "corr_argsort_rej = np.argsort(correlations_rej)\n",
    "\n",
    "# The element-wise sum pairs question k of both files.\n",
    "# NOTE(review): assumes both files list the same questions in the same order - confirm.\n",
    "assert len(correlations) == len(correlations_rej), \"the two files contain different numbers of questions\"\n",
    "correlations_both = np.array(correlations) + np.array(correlations_rej)\n",
    "corr_both_argsort = np.argsort(correlations_both)\n",
    "# Ranks 1..99, highest first; clamped so a file with fewer questions does not raise IndexError.\n",
    "n_show = min(99, len(correlations_both))\n",
    "for i in range(1, n_show + 1):\n",
    "    idx = int(corr_both_argsort[-i])\n",
    "    print(i, correlations_both[idx], questions[idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dbbaad3c-4e29-418c-8069-e039a9091762",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
