{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Update text labels and generate variants\n",
    "1. assign text uid;\n",
    "2. update sentiment and relation labels;\n",
    "3. generate multiple text variants."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Assign unique ids for texts\n",
    "This is critical for (1) non-overlapping train/val/test split; and (2) training batch sampling that supports contrastive loss."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Index(['raw text', 'dataset', 'task', 'control', 'raw label', 'input text'], dtype='object')\n",
      "Index(['raw text', 'dataset', 'task', 'control', 'raw label', 'input text',\n",
      "       'text uid'],\n",
      "      dtype='object')\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "df = pd.read_pickle('./data/tmp/zuco_label_input_text.df')\n",
    "print(df.columns)\n",
    "\n",
    "# Rows with the same input text receive the same integer id.\n",
    "uids, unique_texts = df['input text'].factorize()\n",
    "df['text uid'] = uids.tolist()\n",
    "print(df.columns)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Update sentiment/relation labels\n",
    "1. to separate columns;\n",
    "2. with more tractable natural language terms."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "raw label\n",
      "1                                                                 140\n",
      "0                                                                 137\n",
      "-1                                                                123\n",
      "EDUCATION                                                         120\n",
      "JOB_TITLE                                                         114\n",
      "                                                                 ... \n",
      "Where was he born? Philadelphia                                     1\n",
      "Which army did he join?  U.S. Army                                  1\n",
      "Which university did he get his degree from? Tulane University      1\n",
      "Which war interrupted his sutdies? World War II                     1\n",
      "Who ran an oil drilling company? his father                         1\n",
      "Name: count, Length: 82, dtype: int64\n"
     ]
    }
   ],
   "source": [
    "# How often each raw label occurs over all tasks.\n",
    "raw_label_counts = df.value_counts(subset='raw label')\n",
    "print(raw_label_counts)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sentiment label for task1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "raw label\n",
      "1     140\n",
      "0     137\n",
      "-1    123\n",
      "Name: count, dtype: int64\n"
     ]
    }
   ],
   "source": [
    "# Take an explicit copy of the task1 rows: a column is added to this frame later,\n",
    "# and writing into a plain slice of df raises SettingWithCopyWarning.\n",
    "df_task1 = df.loc[df['task'] == 'task1'].copy()\n",
    "print(df_task1.value_counts('raw label'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_7470/322930557.py:9: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  df_task1['sentiment label'] = df_task1['raw label'].apply(create_readable_sentiment_label)\n"
     ]
    }
   ],
   "source": [
    "def create_readable_sentiment_label(src):\n",
    "    \"\"\"Translate a raw sentiment code into a natural-language label.\n",
    "\n",
    "    '-1' -> 'negative', '0' -> 'neutral', '1' -> 'positive'.\n",
    "    Any other value fails the assertion.\n",
    "    \"\"\"\n",
    "    readable = {'-1': 'negative', '0': 'neutral', '1': 'positive'}\n",
    "    assert src in readable, f'unexpected sentiment label: {src!r}'\n",
    "    return readable[src]\n",
    "\n",
    "# assign() builds a new frame instead of writing a column into a slice of df,\n",
    "# which avoids SettingWithCopyWarning and keeps this cell safe to re-run.\n",
    "df_task1 = df_task1.assign(**{'sentiment label': df_task1['raw label'].apply(create_readable_sentiment_label)})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Relation label for task3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "raw label\n",
      "EDUCATION                120\n",
      "JOB_TITLE                114\n",
      "NATIONALITY              109\n",
      "EMPLOYER                 104\n",
      "WIFE                     102\n",
      "POLITICAL_AFFILIATION     92\n",
      "FOUNDER                   84\n",
      "VISITED                   48\n",
      "AWARD                     45\n",
      "Name: count, dtype: int64\n"
     ]
    }
   ],
   "source": [
    "# Take an explicit copy of the task3 rows: a column is added to this frame later,\n",
    "# and writing into a plain slice of df raises SettingWithCopyWarning.\n",
    "df_task3 = df.loc[df['task'] == 'task3'].copy()\n",
    "print(df_task3.value_counts('raw label'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_7470/1638297978.py:14: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  df_task3['relation label'] = df_task3['raw label'].apply(create_readable_relation_label)\n"
     ]
    }
   ],
   "source": [
    "def create_readable_relation_label(src):\n",
    "    \"\"\"Translate a raw relation tag into an LM-friendly natural-language term.\n",
    "\n",
    "    Any value outside the nine known tags fails the assertion.\n",
    "    \"\"\"\n",
    "    readable = {\n",
    "        'AWARD': 'awarding', 'EDUCATION': 'education', 'EMPLOYER': 'employment',\n",
    "        'FOUNDER': 'foundation', 'JOB_TITLE': 'job title', 'NATIONALITY': 'nationality',\n",
    "        'POLITICAL_AFFILIATION': 'political affiliation', 'VISITED': 'visit', 'WIFE': 'marriage',\n",
    "    }\n",
    "    assert src in readable, f'unexpected relation label: {src!r}'\n",
    "    return readable[src]\n",
    "\n",
    "# assign() builds a new frame instead of writing a column into a slice of df,\n",
    "# which avoids SettingWithCopyWarning and keeps this cell safe to re-run.\n",
    "df_task3 = df_task3.assign(**{'relation label': df_task3['raw label'].apply(create_readable_relation_label)})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1888, 9)\n",
      "Index(['raw text', 'dataset', 'task', 'control', 'raw label', 'input text',\n",
      "       'text uid', 'sentiment label', 'relation label'],\n",
      "      dtype='object')\n"
     ]
    }
   ],
   "source": [
    "# Stack the labelled task1/task3 rows; a label column that does not apply to a row stays NaN.\n",
    "df_task1_task3 = pd.concat([df_task1, df_task3], ignore_index=False)\n",
    "# Left-merge the labels back onto every row of df. The keys are spelled out (all original\n",
    "# columns) and validate='many_to_one' raises if the labelled rows are not unique on them,\n",
    "# because duplicate keys on the right would silently multiply rows of df.\n",
    "df_class_label = pd.merge(df, df_task1_task3, how='left', on=list(df.columns), validate='many_to_one')\n",
    "print(df_class_label.shape)\n",
    "print(df_class_label.columns)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate text variants\n",
    "- Run [_gen_variants_llm_regular.py](_gen_variants_llm_regular.py) to generate 6 finely paraphrased variants using an LLM with detailed instructions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1888, 2) Index(['lexical simplification (v0)', 'lexical simplification (v1)'], dtype='object')\n",
      "(1888, 2) Index(['semantic clarity (v0)', 'semantic clarity (v1)'], dtype='object')\n",
      "(1888, 2) Index(['syntax simplification (v0)', 'syntax simplification (v1)'], dtype='object')\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "df1 = pd.read_pickle('./data/tmp/zuco_label_lexical_simplification.df')\n",
    "df2 = pd.read_pickle('./data/tmp/zuco_label_semantic_clarity.df')\n",
    "df3 = pd.read_pickle('./data/tmp/zuco_label_syntax_simplification.df')\n",
    "# Show the shape and column names of each loaded table.\n",
    "for variant_table in (df1, df2, df3):\n",
    "    print(variant_table.shape, variant_table.columns)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Run [_gen_variants_llm_general.py](_gen_variants_llm_general.py) to generate 8+8 simplified/rewritten variants using an LLM with general instructions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1888, 8) Index(['simplified text (v0)', 'simplified text (v1)', 'simplified text (v2)',\n",
      "       'simplified text (v3)', 'simplified text (v4)', 'simplified text (v5)',\n",
      "       'simplified text (v6)', 'simplified text (v7)'],\n",
      "      dtype='object')\n",
      "(1888, 8) Index(['rewritten text (v0)', 'rewritten text (v1)', 'rewritten text (v2)',\n",
      "       'rewritten text (v3)', 'rewritten text (v4)', 'rewritten text (v5)',\n",
      "       'rewritten text (v6)', 'rewritten text (v7)'],\n",
      "      dtype='object')\n"
     ]
    }
   ],
   "source": [
    "df_simplified = pd.read_pickle('./data/tmp/zuco_simplified_text.df')\n",
    "df_rewritten = pd.read_pickle('./data/tmp/zuco_rewritten_text.df')\n",
    "# Show the shape and column names of each loaded table.\n",
    "for general_table in (df_simplified, df_rewritten):\n",
    "    print(general_table.shape, general_table.columns)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Run [_gen_variants_t5_naive.py](_gen_variants_t5_naive.py) to generate 1+1 simplified/rewritten variants using the integrated LM."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1888, 8) Index(['raw text', 'dataset', 'task', 'control', 'raw label', 'input text',\n",
      "       'naive rewritten', 'naive simplified'],\n",
      "      dtype='object')\n"
     ]
    }
   ],
   "source": [
    "naive_path = './data/tmp/zuco_label_naive.df'\n",
    "df_naive = pd.read_pickle(naive_path)\n",
    "print(df_naive.shape, df_naive.columns)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Merge labels and text variants\n",
    "- Our final 8 variants consist of (1) the 6 finely paraphrased LLM variants and (2) the 2 naive variants. Samples in (1) whose output is marked with `<ERROR>` are replaced by the corresponding general-simplified variants."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1888, 15) Index(['raw text', 'dataset', 'task', 'control', 'raw label', 'input text',\n",
      "       'text uid', 'sentiment label', 'relation label',\n",
      "       'lexical simplification (v0)', 'lexical simplification (v1)',\n",
      "       'semantic clarity (v0)', 'semantic clarity (v1)',\n",
      "       'syntax simplification (v0)', 'syntax simplification (v1)'],\n",
      "      dtype='object')\n"
     ]
    }
   ],
   "source": [
    "# Put the label table and the three paraphrase tables side by side.\n",
    "frames_to_join = [df_class_label, df1, df2, df3]\n",
    "df_6var_error = pd.concat(frames_to_join, axis='columns')\n",
    "print(df_6var_error.shape, df_6var_error.columns)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- handle `<ERROR>` cases"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1888, 15) Index(['raw text', 'dataset', 'task', 'control', 'raw label', 'input text',\n",
      "       'text uid', 'sentiment label', 'relation label',\n",
      "       'lexical simplification (v0)', 'lexical simplification (v1)',\n",
      "       'semantic clarity (v0)', 'semantic clarity (v1)',\n",
      "       'syntax simplification (v0)', 'syntax simplification (v1)'],\n",
      "      dtype='object')\n",
      "(0, 15)\n"
     ]
    }
   ],
   "source": [
    "from copy import deepcopy\n",
    "\n",
    "ERROR_MARKER = '<ERROR>'\n",
    "\n",
    "def rows_with_error(frame, columns=None):\n",
    "    \"\"\"Return a boolean Series that is True for rows containing ERROR_MARKER.\n",
    "\n",
    "    Only `columns` are searched (every column when None). Values are converted to\n",
    "    str first and the marker is matched literally, not as a regular expression.\n",
    "    \"\"\"\n",
    "    searched = frame if columns is None else frame[columns]\n",
    "    hits = searched.astype(str).apply(lambda col: col.str.contains(ERROR_MARKER, regex=False))\n",
    "    return hits.any(axis=1)\n",
    "\n",
    "# Rows of the merged table that contain the marker in any column\n",
    "error_rows = df_6var_error[rows_with_error(df_6var_error)]\n",
    "\n",
    "error_check_columns = ['lexical simplification (v0)', 'lexical simplification (v1)', \n",
    "                       'semantic clarity (v0)', 'semantic clarity (v1)', \n",
    "                       'syntax simplification (v0)', 'syntax simplification (v1)']\n",
    "\n",
    "# Columns of df_simplified that replace error_check_columns, position by position\n",
    "replace_columns_df_simplified = ['simplified text (v0)', 'simplified text (v1)', \n",
    "                                 'simplified text (v2)', 'simplified text (v3)', \n",
    "                                 'simplified text (v4)', 'simplified text (v5)']\n",
    "\n",
    "df_all = deepcopy(df_6var_error)\n",
    "# When any of the six paraphrase columns of a row holds the marker, all six are\n",
    "# overwritten with that row's first six simplified variants (one bulk assignment).\n",
    "bad_index = df_6var_error.index[rows_with_error(df_6var_error, error_check_columns)]\n",
    "if len(bad_index) > 0:\n",
    "    df_all.loc[bad_index, error_check_columns] = df_simplified.loc[bad_index, replace_columns_df_simplified].to_numpy()\n",
    "\n",
    "print(df_all.shape, df_all.columns)\n",
    "# Sanity check: no row of the result should still contain the marker\n",
    "error_rows_6var = df_all[rows_with_error(df_all)]\n",
    "print(error_rows_6var.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- add naive variants"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Index(['raw text', 'dataset', 'task', 'control', 'raw label', 'input text',\n",
      "       'text uid', 'sentiment label', 'relation label',\n",
      "       'lexical simplification (v0)', 'lexical simplification (v1)',\n",
      "       'semantic clarity (v0)', 'semantic clarity (v1)',\n",
      "       'syntax simplification (v0)', 'syntax simplification (v1)',\n",
      "       'naive rewritten', 'naive simplified'],\n",
      "      dtype='object')\n"
     ]
    }
   ],
   "source": [
    "# Add the two naive variant columns from df_naive to the merged table.\n",
    "naive_columns = ['naive rewritten', 'naive simplified']\n",
    "df_all[naive_columns] = df_naive[naive_columns]\n",
    "print(df_all.columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the merged labels and 8 text variants for the next pipeline step.\n",
    "df_all.to_pickle('./data/tmp/zuco_label_8variants.df')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "glim",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
