{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Text similarity experiment"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Requires results from ```scripts/stability.py```, where `<model_family>` denotes the LLM and is `llama3` for Llama-8B-Instruct and `mistral` for Ministral-8B-Instruct.\n",
    "\n",
    "Produces Figures 3 and 12."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import Levenshtein as lv\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib as mpl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paper-style figures: LaTeX-rendered Computer Modern text, large fonts and markers,\n",
    "# and axes without margins whose limits are rounded to tick values.\n",
    "# `mpl.rcParams` and `plt.rcParams` are the same object, so one update covers both.\n",
    "mpl.rcParams.update({\n",
    "    'text.latex.preamble': r'\\usepackage{amsmath,amsfonts,geometry}',\n",
    "    'axes.formatter.use_mathtext': True,\n",
    "    'font.family': 'serif',\n",
    "    'font.serif': ['Computer Modern Roman'],\n",
    "    'text.usetex': True,\n",
    "    'font.size': 45,\n",
    "    'figure.figsize': (13, 8),\n",
    "    'lines.markersize': 20,\n",
    "    'axes.autolimit_mode': 'round_numbers',\n",
    "    'axes.xmargin': 0,\n",
    "    'axes.ymargin': 0,\n",
    "})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Functions to compute the normalized edit distance of the counterfactual and interventional responses from the factual response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prefix(row):\n",
    "    \"\"\"Character length of the text preceding each intervention.\n",
    "\n",
    "    For every position ``pos`` in ``row['interventions']``, concatenate the tokens of\n",
    "    ``row['token_list']`` whose key, converted with ``int``, is smaller than ``pos`` and\n",
    "    record the length of the resulting string. Returns a list with one entry per intervention.\n",
    "    \"\"\"\n",
    "    # NOTE(review): token_list appears to map stringified token positions to token text; confirm in scripts/stability.py\n",
    "    tokens = row['token_list']\n",
    "    lengths = []\n",
    "    for position in row['interventions']:\n",
    "        kept = (tokens[key] for key in tokens if int(key) < position)\n",
    "        lengths.append(len(''.join(kept)))\n",
    "    return lengths\n",
    "\n",
    "def _normalized_distance(fact, other, prefix_len):\n",
    "    \"\"\"Levenshtein distance between ``fact`` and ``other`` divided by the longer of\n",
    "    their suffixes after the first ``prefix_len`` characters; 0 when both suffixes are empty.\"\"\"\n",
    "    denominator = max(len(fact[prefix_len:]), len(other[prefix_len:]))\n",
    "    if denominator == 0:\n",
    "        return 0.0\n",
    "    return lv.distance(fact, other) / denominator\n",
    "\n",
    "\n",
    "def edit_distance_lv(row, intervention=None):\n",
    "    \"\"\"Normalized edit distances of the counterfactual and interventional responses.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    row : pandas.Series or mapping\n",
    "        Needs 'interventions', 'factual_response', and per-intervention sequences\n",
    "        'cf_response', 'prior_response' and 'prefix' (prefix lengths from ``prefix``).\n",
    "    intervention : int, optional\n",
    "        If given, score only this intervention index (clamped to the last one);\n",
    "        otherwise score every intervention of the row.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple of numpy.ndarray\n",
    "        ``(edit_distance_cf, edit_distance_prior)``: Levenshtein distance of each\n",
    "        'cf_response' and 'prior_response' entry from the factual response, divided by\n",
    "        the longer of the two texts after the first ``row['prefix'][i]`` characters.\n",
    "        The two values are computed independently, so an empty suffix pair for one\n",
    "        of them does not zero the other.\n",
    "    \"\"\"\n",
    "    n_total = len(row['interventions'])\n",
    "    if intervention is not None:\n",
    "        # Condition on a single intervention, clamped to the last available index.\n",
    "        indices = [min(intervention, n_total - 1)]\n",
    "    else:\n",
    "        indices = range(n_total)\n",
    "\n",
    "    fact = row['factual_response']\n",
    "    edit_distance_cf = np.zeros(len(indices))\n",
    "    edit_distance_prior = np.zeros(len(indices))\n",
    "    for j, idx in enumerate(indices):\n",
    "        prefix_len = row['prefix'][idx]\n",
    "        edit_distance_cf[j] = _normalized_distance(fact, row['cf_response'][idx], prefix_len)\n",
    "        edit_distance_prior[j] = _normalized_distance(fact, row['prior_response'][idx], prefix_len)\n",
    "    return edit_distance_cf, edit_distance_prior\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Read input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Repository root (with trailing separator): the parent of the kernel's working directory.\n",
    "# NOTE(review): assumes the notebook runs from a direct subfolder of the repository root.\n",
    "root_dir = os.path.join(os.path.dirname(os.getcwd()), '')\n",
    "model_family = ['llama3', 'mistral']\n",
    "input_file_llama3 = f\"{root_dir}outputs/stability/{model_family[0]}/stability_strings\"\n",
    "input_file_mistral = f\"{root_dir}outputs/stability/{model_family[1]}/stability_strings\"\n",
    "\n",
    "df_llama3 = pd.read_parquet(f\"{input_file_llama3}.parquet\", engine='fastparquet')\n",
    "df_mistral = pd.read_parquet(f\"{input_file_mistral}.parquet\", engine='fastparquet')\n",
    "# Pristine copies from which the downstream results are derived.\n",
    "df_orig_llama3 = df_llama3.copy()\n",
    "df_orig_mistral = df_mistral.copy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Prepare results to plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def edit_distance_vs_param(df, param='temperature'):\n",
    "    \"\"\"Per-intervention edit distances for one sampling-parameter sweep.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df : pandas.DataFrame\n",
    "        Results of scripts/stability.py; not modified.\n",
    "    param : {'temperature', 'p', 'k'}\n",
    "        'temperature' keeps the 'vocabulary' and 'categorical' samplers and sweeps the\n",
    "        temperature. 'p' / 'k' keep temperature 0.6 with the 'top-p token' / 'top-k token'\n",
    "        sampler plus the full-vocabulary sampler, whose sampler_param 0 is mapped to\n",
    "        p=1 / k=128256 so it sits at the end of the sweep.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pandas.DataFrame\n",
    "        One row per intervention with columns 'Counterfactual Token Generation',\n",
    "        'Interventional Token Generation' and the swept column ('temperature' or\n",
    "        'sampler_param').\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``param`` is not one of the supported sweeps.\n",
    "    \"\"\"\n",
    "    if param == 'temperature':\n",
    "        df = df[df['sampler_type'].isin(['vocabulary', 'categorical'])].copy()\n",
    "        var_to_plot = 'temperature'\n",
    "    elif param in ('p', 'k'):\n",
    "        sampler_type, full_vocab_param = ('top-p token', 1) if param == 'p' else ('top-k token', 128256)\n",
    "        df = df[df['temperature'].isin([0.6])]\n",
    "        # Copy so the column assignments below act on an independent frame (no SettingWithCopyWarning).\n",
    "        df = df[df['sampler_type'].isin([sampler_type, 'vocabulary'])].copy()\n",
    "        # Full-vocabulary sampling is stored with sampler_param 0; place it at p=1 / k=128256\n",
    "        # (128256 is the Llama-3 vocabulary size) so it appears as the end point of the sweep.\n",
    "        df['sampler_param'] = df['sampler_param'].replace([0], [full_vocab_param])\n",
    "        var_to_plot = 'sampler_param'\n",
    "    else:\n",
    "        raise ValueError(f'param must be one of temperature, p or k, got {param!r}')\n",
    "\n",
    "    df['prefix'] = df.apply(prefix, axis=1)\n",
    "    df = df.reset_index()\n",
    "\n",
    "    methods = ['Counterfactual Token Generation', 'Interventional Token Generation']\n",
    "    df[methods] = df.apply(edit_distance_lv, axis=1, result_type='expand')\n",
    "    # Each cell holds one distance per intervention; explode to one row per intervention.\n",
    "    return df[methods + [var_to_plot]].explode(methods)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "vars_to_plot = ['temperature', 'p', 'k']\n",
    "\n",
    "\n",
    "def _split_by_method(results):\n",
    "    \"\"\"Split every sweep's results into a counterfactual-only and an interventional-only frame.\"\"\"\n",
    "    cf, interv = {}, {}\n",
    "    for var, res in results.items():\n",
    "        x_col = 'temperature' if var == 'temperature' else 'sampler_param'\n",
    "        cf[var] = res[[x_col, 'Counterfactual Token Generation']]\n",
    "        interv[var] = res[[x_col, 'Interventional Token Generation']]\n",
    "    return cf, interv\n",
    "\n",
    "\n",
    "results_llama3 = {var: edit_distance_vs_param(df_orig_llama3.copy(), var) for var in vars_to_plot}\n",
    "results_cf_llama3, results_interv_llama3 = _split_by_method(results_llama3)\n",
    "\n",
    "results_mistral = {var: edit_distance_vs_param(df_orig_mistral.copy(), var) for var in vars_to_plot}\n",
    "results_cf_mistral, results_interv_mistral = _split_by_method(results_mistral)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Figure 3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
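    "Requires results from ```scripts/stability.py``` for both values of `<model_family>` (`llama3` and `mistral`), loaded above. The Figure 12 cells below additionally read the run where `<model_family>` is `llama3` and `<categorical>` is `True`.\n"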
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_edit_distance_vs_param(results_cf1,\n",
    "                                results_interv1,\n",
    "                                results_cf2,\n",
    "                                results_interv2,\n",
    "                                output_dir,\n",
    "                                param='temperature'):\n",
    "    \"\"\"Plot mean edit distance against a sampling parameter for two models (Figure 3).\n",
    "\n",
    "    Series 1 is drawn in cyan shades and series 2 in orange shades; counterfactual\n",
    "    results use solid lines with circles, interventional results dashed lines with\n",
    "    triangles. The per-axes legend is removed (``plot_legend`` saves a shared one).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    results_cf1, results_cf2 : pandas.DataFrame\n",
    "        Swept column plus a 'Counterfactual Token Generation' column.\n",
    "    results_interv1, results_interv2 : pandas.DataFrame\n",
    "        Swept column plus an 'Interventional Token Generation' column.\n",
    "    output_dir : str\n",
    "        Path prefix (normally ending with a separator) for the saved PDF; its parent\n",
    "        directory is created if missing.\n",
    "    param : {'temperature', 'p', 'k'}\n",
    "        Selects the swept column, axis label, ticks, limits and output file name.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``param`` is not one of the supported sweeps.\n",
    "    \"\"\"\n",
    "    if param=='temperature':\n",
    "        var_to_plot='temperature'\n",
    "        xlabel=r'$\\tau$'\n",
    "        xticks=[0.0,0.2,0.4,0.6,0.8,1.0]\n",
    "        xticklabels=['0.0','0.2','0.4','0.6','0.8','1.0']\n",
    "        xlims=[-0.05,1.05]\n",
    "        file_name='tau_vs_ed'\n",
    "    elif param=='p':\n",
    "        var_to_plot='sampler_param'\n",
    "        xlabel=r'$p$'\n",
    "        xticks=[0.75,0.8,0.85,0.9,0.95,1]\n",
    "        xticklabels=['0.75','0.80','0.85','0.90','0.95','1.00']\n",
    "        xlims=[0.74,1.01]\n",
    "        file_name='p_vs_ed'\n",
    "    elif param=='k':\n",
    "        var_to_plot='sampler_param'\n",
    "        xlabel=r'$k$'\n",
    "        xticks=[1,10,100,1000,10000,100000]\n",
    "        xticklabels=[r'$1$',r'$10$',r'$10^2$',r'$10^3$',r'$10^4$',r'$10^5$']\n",
    "        xlims=[0.9,160000]\n",
    "        file_name='k_vs_ed'\n",
    "    else:\n",
    "        raise ValueError(f'param must be one of temperature, p or k, got {param!r}')\n",
    "\n",
    "    def melt(results, method):\n",
    "        # Long format: columns var_to_plot, 'method' and 'edit_distance'.\n",
    "        return results.melt(id_vars=[var_to_plot], value_vars=[method],\n",
    "                            var_name='method', value_name='edit_distance')\n",
    "\n",
    "    res_to_plot_cf1 = melt(results_cf1, 'Counterfactual Token Generation')\n",
    "    res_to_plot_cf2 = melt(results_cf2, 'Counterfactual Token Generation')\n",
    "    res_to_plot_interv1 = melt(results_interv1, 'Interventional Token Generation')\n",
    "    res_to_plot_interv2 = melt(results_interv2, 'Interventional Token Generation')\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    # Draw order matters: interventional lines are drawn last and therefore on top.\n",
    "    sns.lineplot(data=res_to_plot_cf1, ax=ax, x=var_to_plot, y='edit_distance', estimator='mean',\n",
    "                hue='method', palette=[\"#08b5c4\"], style=True, markers='o', dashes=False, markeredgecolor=None)\n",
    "    sns.lineplot(data=res_to_plot_cf2, ax=ax, x=var_to_plot, y='edit_distance', estimator='mean',\n",
    "                hue='method', palette=[\"#ff7f0e\"], style=True, markers='o', dashes=False, markeredgecolor=None)\n",
    "    sns.lineplot(data=res_to_plot_interv1, ax=ax, x=var_to_plot, y='edit_distance', estimator='mean',\n",
    "                hue='method', palette=[\"#8ce5ed\"], style=True, markers='^', dashes=[(4,4)], alpha=0.8, markeredgecolor=None)\n",
    "    sns.lineplot(data=res_to_plot_interv2, ax=ax, x=var_to_plot, y='edit_distance', estimator='mean',\n",
    "                hue='method', palette=[\"#f0a665\"], style=True, markers='^', dashes=[(4,4)], alpha=0.8, markeredgecolor=None)\n",
    "\n",
    "    ax.set_xlabel(xlabel)\n",
    "    ax.set_ylabel('Edit Distance')\n",
    "    ax.get_legend().remove()\n",
    "    ax.spines[['right', 'top']].set_visible(False)\n",
    "    if param=='k':\n",
    "        ax.set_xscale('log')\n",
    "    ax.set_xticks(xticks)\n",
    "    ax.set_xticklabels(xticklabels)\n",
    "    ax.set_xlim(xlims[0],xlims[1])\n",
    "    ax.set_yticks([0.45,0.50,0.55,0.60,0.65,0.70,0.75])\n",
    "    ax.set_yticklabels(['0.45','0.50','0.55','0.60','0.65','0.70','0.75'])\n",
    "    ax.set_ylim(0.44,0.76)\n",
    "    out_path = f'{output_dir}{file_name}.pdf'\n",
    "    out_parent = os.path.dirname(out_path)\n",
    "    if out_parent:\n",
    "        os.makedirs(out_parent, exist_ok=True)\n",
    "    fig.savefig(out_path, bbox_inches='tight')\n",
    "    print(f'Saved figure at {out_path}')\n",
    "\n",
    "def plot_legend(output_dir):\n",
    "    \"\"\"Save the legend shared by the Figure 3 panels as a wide, axis-free PDF strip.\"\"\"\n",
    "    # Throw-away lines on a scratch figure, drawn only to obtain legend handles and labels.\n",
    "    scratch_fig, scratch_ax = plt.subplots()\n",
    "    scratch_ax.plot([0, 1], [0, 1], color='grey', marker='^', linestyle='dashed', dashes=(4, 4),\n",
    "                    label='Interventional Token Generation', alpha=0.8)\n",
    "    scratch_ax.plot([0, 1], [0, 1], color='grey', marker='o', label='Counterfactual Token Generation')\n",
    "    handles, labels = scratch_ax.get_legend_handles_labels()\n",
    "    plt.close(scratch_fig)\n",
    "\n",
    "    _, legend_ax = plt.subplots(figsize=(36, 1))\n",
    "    legend_ax.legend(handles, labels, ncol=2, frameon=False, loc='center')\n",
    "    for side in ('right', 'top', 'left', 'bottom'):\n",
    "        legend_ax.spines[side].set_visible(False)\n",
    "    legend_ax.set_xticks([])\n",
    "    legend_ax.set_yticks([])\n",
    "    plt.savefig(f'{output_dir}legend.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One Figure 3 panel per entry of vars_to_plot, comparing llama3 (series 1) and mistral (series 2).\n",
    "figure_dir = f'{root_dir}outputs/figures/edit_distance/'\n",
    "for var_to_plot in vars_to_plot:\n",
    "    plot_edit_distance_vs_param(results_cf1=results_cf_llama3[var_to_plot],\n",
    "                                results_interv1=results_interv_llama3[var_to_plot],\n",
    "                                results_cf2=results_cf_mistral[var_to_plot],\n",
    "                                results_interv2=results_interv_mistral[var_to_plot],\n",
    "                                output_dir=figure_dir,\n",
    "                                param=var_to_plot)\n",
    "plot_legend(output_dir=figure_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Figure 12"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_edit_distance_vs_param_categorical(results_cf1,\n",
    "                                results_interv1,\n",
    "                                results_cf2,\n",
    "                                results_interv2,\n",
    "                                output_dir,\n",
    "                                param='temperature'):\n",
    "    \"\"\"Plot mean edit distance with 95% confidence intervals for two runs (Figure 12).\n",
    "\n",
    "    Each of the four series is drawn as markers with error bars (mean +/- 1.96 * std / sqrt(n)\n",
    "    at every x value of ``results_cf1``) and shifted horizontally so they do not overlap.\n",
    "    The legend labels series 1 as 'GM' and series 2 as 'ITS'.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    results_cf1, results_cf2 : pandas.DataFrame\n",
    "        Swept column plus a 'Counterfactual Token Generation' column.\n",
    "    results_interv1, results_interv2 : pandas.DataFrame\n",
    "        Swept column plus an 'Interventional Token Generation' column.\n",
    "    output_dir : str\n",
    "        Path prefix (normally ending with a separator) for the saved PDF; its parent\n",
    "        directory is created if missing.\n",
    "    param : {'temperature', 'p', 'k'}\n",
    "        Selects the swept column, axis label, ticks, limits and output file name. The\n",
    "        horizontal offsets and x-limit padding are sized for the temperature axis.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``param`` is not one of the supported sweeps.\n",
    "    \"\"\"\n",
    "    if param=='temperature':\n",
    "        var_to_plot='temperature'\n",
    "        xlabel=r'$\\tau$'\n",
    "        xticks=[0.0,0.2,0.4,0.6,0.8,1.0]\n",
    "        xticklabels=['0.0','0.2','0.4','0.6','0.8','1.0']\n",
    "        xlims=[-0.05,1.05]\n",
    "        file_name='tau_vs_ed'\n",
    "    elif param=='p':\n",
    "        var_to_plot='sampler_param'\n",
    "        xlabel=r'$p$'\n",
    "        xticks=[0.75,0.8,0.85,0.9,0.95,1]\n",
    "        xticklabels=['0.75','0.80','0.85','0.90','0.95','1.00']\n",
    "        xlims=[0.74,1.01]\n",
    "        file_name='p_vs_ed'\n",
    "    elif param=='k':\n",
    "        var_to_plot='sampler_param'\n",
    "        xlabel=r'$k$'\n",
    "        xticks=[1,10,100,1000,10000,100000]\n",
    "        xticklabels=[r'$1$',r'$10$',r'$10^2$',r'$10^3$',r'$10^4$',r'$10^5$']\n",
    "        xlims=[0.9,160000]\n",
    "        file_name='k_vs_ed'\n",
    "    else:\n",
    "        raise ValueError(f'param must be one of temperature, p or k, got {param!r}')\n",
    "    file_name=file_name+'_categorical'\n",
    "\n",
    "    # x positions: distinct values of the swept column in results_cf1, in order of appearance.\n",
    "    x = results_cf1[var_to_plot].unique()\n",
    "\n",
    "    def mean_and_ci(results, method):\n",
    "        # Mean and normal-approximation 95% CI half-width (1.96 * std / sqrt(n)) at each x;\n",
    "        # x values absent from `results` yield NaN.\n",
    "        means, half_widths = [], []\n",
    "        for t in x:\n",
    "            values = results.loc[results[var_to_plot] == t, method].astype(float)\n",
    "            means.append(values.mean())\n",
    "            half_widths.append(values.std() * 1.96 / np.sqrt(len(values)))\n",
    "        return means, half_widths\n",
    "\n",
    "    offset = 0.039  # horizontal shift between neighbouring series, in x-axis units\n",
    "    cols = sns.husl_palette(6)\n",
    "    # (results, method column, x shift, marker, palette index)\n",
    "    series = [\n",
    "        (results_cf1, 'Counterfactual Token Generation', -offset, 'o', 3),\n",
    "        (results_cf2, 'Counterfactual Token Generation', 0, 's', 1),\n",
    "        (results_interv1, 'Interventional Token Generation', offset, '^', 5),\n",
    "        (results_interv2, 'Interventional Token Generation', 2 * offset, 'X', 2),\n",
    "    ]\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    for results, method, shift, marker, col_idx in series:\n",
    "        y, yerr = mean_and_ci(results, method)\n",
    "        ax.errorbar([t + shift for t in x], y, yerr, fmt=marker, color=cols[col_idx],\n",
    "                    elinewidth=4, capsize=8, capthick=4)\n",
    "\n",
    "    # Marker-only proxy artists for the legend, interventional entries first.\n",
    "    from matplotlib.lines import Line2D\n",
    "    legend_entries = [(2, 'X', 'Interventional (ITS)'),\n",
    "                      (5, '^', 'Interventional (GM)'),\n",
    "                      (1, 's', 'Counterfactual (ITS)'),\n",
    "                      (3, 'o', 'Counterfactual (GM)')]\n",
    "    handles = [Line2D([], [], color=cols[i], marker=m, linestyle='', label=label)\n",
    "               for i, m, label in legend_entries]\n",
    "    ax.legend(handles=handles, ncol=1, loc='center right', handletextpad=0.1, bbox_to_anchor=(1.8, 0.5))\n",
    "\n",
    "    ax.set_xlabel(xlabel)\n",
    "    ax.set_ylabel('Edit Distance')\n",
    "    ax.spines[['right', 'top']].set_visible(False)\n",
    "    if param == 'k':\n",
    "        # The k ticks span five decades, as in plot_edit_distance_vs_param.\n",
    "        ax.set_xscale('log')\n",
    "    ax.set_xticks(xticks)\n",
    "    ax.set_xticklabels(xticklabels)\n",
    "    ax.set_xlim(xlims[0] - 0.02, xlims[1] + 0.06)\n",
    "    ax.set_yticks([0.45,0.50,0.55,0.60,0.65,0.70,0.75])\n",
    "    ax.set_yticklabels(['0.45','0.50','0.55','0.60','0.65','0.70','0.75'])\n",
    "    ax.set_ylim(0.44,0.76)\n",
    "    out_path = f'{output_dir}{file_name}.pdf'\n",
    "    out_parent = os.path.dirname(out_path)\n",
    "    if out_parent:\n",
    "        os.makedirs(out_parent, exist_ok=True)\n",
    "    fig.savefig(out_path, bbox_inches='tight')\n",
    "    print(f'Saved figure at {out_path}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure 12 compares the llama3 temperature sweep with the run stored under\n",
    "# outputs/stability/categorical (presumably scripts/stability.py with <categorical> = True).\n",
    "input_dir = f\"{root_dir}outputs/stability/categorical\"\n",
    "input_file_categorical = f\"{input_dir}/stability_strings\"\n",
    "df_categorical = pd.read_parquet(f\"{input_file_categorical}.parquet\", engine='fastparquet')\n",
    "df_orig_categorical = df_categorical.copy()\n",
    "\n",
    "# A dedicated name keeps the Figure 3 list `vars_to_plot` (temperature, p, k) intact, so\n",
    "# re-running the Figure 3 cell after this one still draws all three panels.\n",
    "vars_to_plot_categorical = ['temperature']\n",
    "results_categorical = {var: edit_distance_vs_param(df_orig_categorical.copy(), var)\n",
    "                       for var in vars_to_plot_categorical}\n",
    "results_cf_categorical = {'temperature': results_categorical['temperature'][['temperature', 'Counterfactual Token Generation']]}\n",
    "results_interv_categorical = {'temperature': results_categorical['temperature'][['temperature', 'Interventional Token Generation']]}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Series 1 (legend: GM) is the llama3 run, series 2 (legend: ITS) the categorical run.\n",
    "sweep = 'temperature'\n",
    "plot_edit_distance_vs_param_categorical(\n",
    "    results_cf1=results_cf_llama3[sweep],\n",
    "    results_interv1=results_interv_llama3[sweep],\n",
    "    results_cf2=results_cf_categorical[sweep],\n",
    "    results_interv2=results_interv_categorical[sweep],\n",
    "    output_dir=f'{root_dir}outputs/figures/edit_distance/',\n",
    "    param=sweep,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
