{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List\n",
    "import pickle\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "def read_results(tasks: List[str], data_sources: List[str], mask_ratios: List[float], context_type, base_dir: str = '/vol/research/lyc/llm_memorize'):\n",
    "    \"\"\"Average the evaluation metrics stored in a set of pickled answer files.\n",
    "\n",
    "    One ``answer_{task}_{data_source}_{mask_ratio}.pkl`` file is read for every\n",
    "    combination of task, data source and mask ratio. The ``continue_conversation``\n",
    "    task ignores ``data_sources`` and always reads the ``conversation`` file.\n",
    "\n",
    "    Args:\n",
    "        tasks: Task names used in the file names.\n",
    "        data_sources: Data sources; each must be 'news', 'conversation' or 'arxiv'.\n",
    "        mask_ratios: Ratios interpolated into the file names (e.g. 0.2).\n",
    "        context_type: Key selecting one entry of ``ans.metrics`` in every file.\n",
    "        base_dir: Directory that holds the answer pickles.\n",
    "\n",
    "    Returns:\n",
    "        A ``pd.Series`` holding the mean of each reported metric over all files read.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If a data source is not one of the supported names.\n",
    "    \"\"\"\n",
    "    known_sources = ('news', 'conversation', 'arxiv')\n",
    "\n",
    "    all_ans_paths = []\n",
    "    for task in tasks:\n",
    "        # continue_conversation is only read from the conversation file, so data_sources is ignored for it.\n",
    "        sources = ['conversation'] if task == 'continue_conversation' else data_sources\n",
    "        for data_source in sources:\n",
    "            if data_source not in known_sources:\n",
    "                # Fail loudly: otherwise the path built in the previous iteration would be appended again.\n",
    "                raise ValueError(f'Unknown data source: {data_source!r}')\n",
    "            # arxiv answers for the Random-phrase context are read from the arxiv_buggy/ sub-directory.\n",
    "            sub_dir = 'arxiv_buggy/' if data_source == 'arxiv' and context_type == 'Random-phrase' else ''\n",
    "            for mask_ratio in mask_ratios:\n",
    "                all_ans_paths.append(f'{base_dir}/{sub_dir}answer_{task}_{data_source}_{mask_ratio}.pkl')\n",
    "\n",
    "    results = []\n",
    "    for ans_path in all_ans_paths:\n",
    "        print(ans_path)\n",
    "\n",
    "        # NOTE: pickle.load can execute arbitrary code; only read answer files you trust.\n",
    "        with open(ans_path, 'rb') as f:\n",
    "            ans = pickle.load(f)\n",
    "\n",
    "        r = ans.metrics[context_type]\n",
    "        if 'precision' in r:\n",
    "            # Collapse the precision/recall/f1 entries to their means under bertscore_* names.\n",
    "            r['bertscore_precision'] = np.mean(r['precision'])\n",
    "            r['bertscore_recall'] = np.mean(r['recall'])\n",
    "            r['bertscore_f1'] = np.mean(r['f1'])\n",
    "        results.append(r)\n",
    "\n",
    "    df = pd.DataFrame(results)\n",
    "    df = df[['bleu', 'meteor', 'rouge1', 'rouge2', 'rougeL', 'bertscore_precision', 'bertscore_recall', 'bertscore_f1']]\n",
    "    avg_results = df.mean()\n",
    "\n",
    "    return avg_results\n",
    "\n",
    "def main(context_type):\n",
    "    \"\"\"Print the metrics averaged over all tasks and data sources, one row per mask ratio.\"\"\"\n",
    "    per_ratio = {}\n",
    "    for mask_ratio in (0.2, 0.35, 0.5, 0.65, 0.8):\n",
    "        averaged = read_results(\n",
    "            tasks=['summarisation', 'qa', 'continue_conversation'],\n",
    "            data_sources=['news', 'arxiv'],\n",
    "            mask_ratios=[mask_ratio],\n",
    "            context_type=context_type,\n",
    "        )\n",
    "        print(averaged)\n",
    "        per_ratio[mask_ratio] = averaged\n",
    "\n",
    "    # Rows are keyed by mask ratio, columns by metric name.\n",
    "    print(pd.DataFrame.from_dict(per_ratio, orient='index'))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the currently defined `main` for the 'self-info-phrase' context (top to bottom, that is the first definition, which prints one row per mask ratio).\n",
    "main('self-info-phrase')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same report for the 'Random-phrase' context.\n",
    "main('Random-phrase')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same report for the 'no2-phrase' context.\n",
    "main('no2-phrase')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def main(context_type):\n",
    "    \"\"\"Build a per-task metrics table for every mask ratio.\n",
    "\n",
    "    For each ratio the table holds one row per task plus an 'avg' row, indexed by\n",
    "    (Method, Task) where Method is 'SC-{ratio}'.\n",
    "\n",
    "    Args:\n",
    "        context_type: Key forwarded to ``read_results``.\n",
    "\n",
    "    Returns:\n",
    "        A ``pd.DataFrame`` with the per-ratio tables stacked vertically.\n",
    "    \"\"\"\n",
    "    tasks = ['summarisation', 'qa', 'conversation']\n",
    "    data_sources = ['news', 'arxiv']\n",
    "    mask_ratios = [0.2, 0.35, 0.5, 0.65, 0.8]\n",
    "\n",
    "    all_results = []\n",
    "    for ratio in mask_ratios:\n",
    "        ratio_results = []\n",
    "        method = f\"SC-{ratio}\"\n",
    "        for task in tasks:\n",
    "            if task == 'conversation':\n",
    "                avg_results = read_results(tasks=['continue_conversation'], data_sources=['conversation'], mask_ratios=[ratio], context_type=context_type)\n",
    "            else:\n",
    "                avg_results = read_results(tasks=[task], data_sources=data_sources, mask_ratios=[ratio], context_type=context_type)\n",
    "            avg_results['Method'] = method\n",
    "            avg_results['Task'] = task\n",
    "            # The string labels make the Series object-dtype; cast the metric columns back to float\n",
    "            # so the averages below are computed on numeric data.\n",
    "            avg_r = avg_results.to_frame().T.set_index(['Method', 'Task']).astype(float)\n",
    "            print(ratio, '----', task)\n",
    "            print(avg_r)\n",
    "            ratio_results.append(avg_r)\n",
    "        print(ratio, '*****')\n",
    "        df = pd.concat(ratio_results)\n",
    "        print(df)\n",
    "        # Append the mean over the task rows as an extra (method, 'avg') row.\n",
    "        df.loc[(method, 'avg'), :] = df.mean()\n",
    "        print(df, \"^^^^^^^^^^\")\n",
    "        all_results.append(df)\n",
    "\n",
    "    df = pd.concat(all_results, axis=0)\n",
    "    print(df)\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-task table for every mask ratio (rows indexed by (Method, Task)), built by the `main` defined in the previous cell.\n",
    "df = main('self-info-phrase')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the table as CSV. NOTE(review): absolute, user-specific path; consider a configurable output directory.\n",
    "df.to_csv('/user/HS502/yl02706/LLMs_Memorize/self-info-phrase.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def main(context_type):\n",
    "    \"\"\"Build a per-task metrics table averaged over all mask ratios.\n",
    "\n",
    "    The table holds one row per task plus an 'avg' row, indexed by (Method, Task)\n",
    "    with Method fixed to 'Original'.\n",
    "\n",
    "    Args:\n",
    "        context_type: Key forwarded to ``read_results``.\n",
    "\n",
    "    Returns:\n",
    "        The resulting ``pd.DataFrame``.\n",
    "    \"\"\"\n",
    "    tasks = ['summarisation', 'qa', 'conversation']\n",
    "    data_sources = ['news', 'arxiv']\n",
    "    mask_ratios = [0.2, 0.35, 0.5, 0.65, 0.8]\n",
    "\n",
    "    all_results = []\n",
    "    method = 'Original'\n",
    "    for task in tasks:\n",
    "        if task == 'conversation':\n",
    "            avg_results = read_results(tasks=['continue_conversation'], data_sources=['conversation'], mask_ratios=mask_ratios, context_type=context_type)\n",
    "        else:\n",
    "            avg_results = read_results(tasks=[task], data_sources=data_sources, mask_ratios=mask_ratios, context_type=context_type)\n",
    "        avg_results['Method'] = method\n",
    "        avg_results['Task'] = task\n",
    "        # The string labels make the Series object-dtype; cast the metric columns back to float\n",
    "        # so the average below is computed on numeric data.\n",
    "        avg_r = avg_results.to_frame().T.set_index(['Method', 'Task']).astype(float)\n",
    "        print(avg_r)\n",
    "        all_results.append(avg_r)\n",
    "    df = pd.concat(all_results)\n",
    "    print(df)\n",
    "    # Append the mean over the task rows as an extra (method, 'avg') row.\n",
    "    df.loc[(method, 'avg'), :] = df.mean()\n",
    "    print(df, \"^^^^^^^^^^\")\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Table labelled 'Original', built from the 'no2-phrase' metrics averaged over all mask ratios (uses the `main` from the previous cell).\n",
    "df2 = main('no2-phrase')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the table currently held in `df`.\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): when run top to bottom `df` only holds 'SC-*' methods, so this label is absent and drop() raises KeyError;\n",
    "# confirm which frame/label was intended. The call also mutates `df` in place.\n",
    "df.drop(index=('Original', 'avg'), inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect `df` again after the drop in the previous cell.\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the rows of `df2` on top of the rows of `df` for a combined view.\n",
    "pd.concat([df2, df])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Metric-wise difference between the ('Original', 'avg') row of `df2` and the ('SC-0.2', 'avg') row of `df`.\n",
    "df2.loc[('Original', 'avg')] - df.loc[('SC-0.2', 'avg')]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _strip_leading_zero(number, digits):\n",
    "    \"\"\"Format `number` with `digits` decimals, dropping only a leading zero (0.123 -> .123, -0.123 -> -.123).\"\"\"\n",
    "    text = f'{number:.{digits}f}'\n",
    "    if text.startswith('0.'):\n",
    "        return text[1:]\n",
    "    if text.startswith('-0.'):\n",
    "        # Keep the sign: blindly slicing off the first character used to turn -0.12 into 0.12.\n",
    "        return '-' + text[2:]\n",
    "    return text\n",
    "\n",
    "# Render every row of `df` as \"value (gap to the 'Original' row of the same task in df2)\".\n",
    "# NOTE: this rebinds `df` to a table of strings, so the cell cannot be re-run without rebuilding `df`.\n",
    "new_sc = []\n",
    "for index, line in df.iterrows():\n",
    "    task = index[1]\n",
    "    origin_r = df2.loc[('Original', task)]\n",
    "    gap = origin_r - line\n",
    "    new_line = {}\n",
    "    for key, value in line.items():\n",
    "        # BERTScore gaps keep three decimals, all other metrics two.\n",
    "        gap_digits = 3 if 'bert' in key else 2\n",
    "        new_line[key] = f'{_strip_leading_zero(value, 3)} ({_strip_leading_zero(gap[key], gap_digits)})'\n",
    "    new_sc.append(pd.Series(new_line, name=index))\n",
    "df = pd.concat(new_sc, axis=1).T\n",
    "print(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _strip_leading_zero(number, digits):\n",
    "    \"\"\"Format `number` with `digits` decimals, dropping only a leading zero (0.123 -> .123, -0.123 -> -.123).\"\"\"\n",
    "    text = f'{number:.{digits}f}'\n",
    "    if text.startswith('0.'):\n",
    "        return text[1:]\n",
    "    if text.startswith('-0.'):\n",
    "        return '-' + text[2:]\n",
    "    return text\n",
    "\n",
    "# Render every value of `df2` with three decimals and no leading zero.\n",
    "# NOTE: this rebinds `df2` to a table of strings, so the cell cannot be re-run without rebuilding `df2`.\n",
    "new_sc = []\n",
    "for index, line in df2.iterrows():\n",
    "    new_line = {key: _strip_leading_zero(value, 3) for key, value in line.items()}\n",
    "    new_sc.append(pd.Series(new_line, name=index))\n",
    "df2 = pd.concat(new_sc, axis=1).T\n",
    "print(df2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview the combined table after both frames were converted to formatted strings.\n",
    "pd.concat([df2, df])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Emit the combined table as LaTeX source.\n",
    "print(pd.concat([df2, df]).to_latex())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def main(context_type):\n",
    "    \"\"\"Build a table of metrics averaged over all tasks, one row per mask ratio.\n",
    "\n",
    "    Rows are indexed by (Method, Ratio) where Method is `context_type`.\n",
    "\n",
    "    Args:\n",
    "        context_type: Key forwarded to ``read_results``; also used as the Method label.\n",
    "\n",
    "    Returns:\n",
    "        The resulting ``pd.DataFrame``.\n",
    "    \"\"\"\n",
    "    tasks = ['summarisation', 'qa', 'reconstruction', 'continue_conversation']\n",
    "    data_sources = ['news', 'arxiv']\n",
    "    mask_ratios = [0.2, 0.35, 0.5, 0.65, 0.8]\n",
    "\n",
    "    all_results = []\n",
    "    for mask in mask_ratios:\n",
    "        avg_results = read_results(tasks=tasks, data_sources=data_sources, mask_ratios=[mask], context_type=context_type)\n",
    "        avg_results['Method'] = context_type\n",
    "        avg_results['Ratio'] = mask\n",
    "        # The added labels make the Series object-dtype; cast the metric columns back to float.\n",
    "        avg_r = avg_results.to_frame().T.set_index(['Method', 'Ratio']).astype(float)\n",
    "        print(avg_r)\n",
    "        all_results.append(avg_r)\n",
    "    df = pd.concat(all_results)\n",
    "    print(df)\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trial run of the ratio-level `main` from the previous cell; the result is displayed but not stored.\n",
    "main('Random-phrase')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ratio-level tables for the two contexts compared in the next cell:\n",
    "# `df` is plotted there as 'Random', `df2` as 'Selective Context'.\n",
    "df = main('Random-phrase')\n",
    "df2 = main('self-info-phrase')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Move the index levels into ordinary columns so that 'Ratio' can serve as the x axis.\n",
    "df2 = df2.reset_index()\n",
    "df = df.reset_index()\n",
    "\n",
    "fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 4), dpi=120)\n",
    "\n",
    "ratio_ticks = [0.2, 0.35, 0.5, 0.65, 0.8]\n",
    "panels = [('bleu', 'BLEU'), ('rouge1', 'ROUGE1'), ('bertscore_f1', 'BERTScore')]\n",
    "\n",
    "# One panel per metric, each showing both curves.\n",
    "for ax, (metric, title) in zip(axes, panels):\n",
    "    df2.plot(y=metric, x='Ratio', ax=ax, marker='^', label='Selective Context')\n",
    "    df.plot(y=metric, x='Ratio', ax=ax, marker='+', label='Random')\n",
    "    ax.set_title(title)\n",
    "    ax.set_xlabel('')\n",
    "    ax.set_xticks(ratio_ticks)\n",
    "\n",
    "# A single x-axis caption shared by the three panels.\n",
    "fig.text(0.52, -0.03, 'Context reduction ratio', ha='center', fontsize=14)\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def task_wise(context_type):\n",
    "    \"\"\"Build a metrics table with one row per (Task, Ratio) pair.\n",
    "\n",
    "    Args:\n",
    "        context_type: Key forwarded to ``read_results``.\n",
    "\n",
    "    Returns:\n",
    "        A ``pd.DataFrame`` indexed by (Task, Ratio).\n",
    "    \"\"\"\n",
    "    tasks = ['summarisation', 'qa', 'reconstruction', 'continue_conversation']\n",
    "    data_sources = ['news', 'arxiv']\n",
    "    mask_ratios = [0.2, 0.35, 0.5, 0.65, 0.8]\n",
    "\n",
    "    all_results = []\n",
    "    for mask in mask_ratios:\n",
    "        for task in tasks:\n",
    "            avg_results = read_results(tasks=[task], data_sources=data_sources, mask_ratios=[mask], context_type=context_type)\n",
    "            avg_results['Task'] = task\n",
    "            avg_results['Ratio'] = mask\n",
    "            # The added labels make the Series object-dtype; cast the metric columns back to float.\n",
    "            avg_r = avg_results.to_frame().T.set_index(['Task', 'Ratio']).astype(float)\n",
    "            print(avg_r)\n",
    "            all_results.append(avg_r)\n",
    "\n",
    "    df = pd.concat(all_results)\n",
    "    print(df)\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def data_wise(context_type):\n",
    "    \"\"\"Build a metrics table with one row per (Data, Ratio) pair.\n",
    "\n",
    "    The 'conversation' source is evaluated on the continue_conversation task only;\n",
    "    the other sources are averaged over summarisation, qa and reconstruction.\n",
    "\n",
    "    Args:\n",
    "        context_type: Key forwarded to ``read_results``.\n",
    "\n",
    "    Returns:\n",
    "        A ``pd.DataFrame`` indexed by (Data, Ratio).\n",
    "    \"\"\"\n",
    "    tasks = ['summarisation', 'qa', 'reconstruction']\n",
    "    data_sources = ['news', 'arxiv', 'conversation']\n",
    "    mask_ratios = [0.2, 0.35, 0.5, 0.65, 0.8]\n",
    "\n",
    "    all_results = []\n",
    "    for mask in mask_ratios:\n",
    "        for data_source in data_sources:\n",
    "            if data_source == 'conversation':\n",
    "                avg_results = read_results(tasks=['continue_conversation'], data_sources=[data_source], mask_ratios=[mask], context_type=context_type)\n",
    "            else:\n",
    "                avg_results = read_results(tasks=tasks, data_sources=[data_source], mask_ratios=[mask], context_type=context_type)\n",
    "            avg_results['Data'] = data_source\n",
    "            avg_results['Ratio'] = mask\n",
    "            # The added labels make the Series object-dtype; cast the metric columns back to float.\n",
    "            avg_r = avg_results.to_frame().T.set_index(['Data', 'Ratio']).astype(float)\n",
    "            print(avg_r)\n",
    "            all_results.append(avg_r)\n",
    "\n",
    "    df = pd.concat(all_results)\n",
    "    print(df)\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview the per-data-source table for the 'self-info-phrase' context.\n",
    "data_wise('self-info-phrase')\n",
    "# task_wise('self-info-phrase')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "df = task_wise('self-info-phrase')\n",
    "df = df.reset_index()\n",
    "\n",
    "fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 4), dpi=120)\n",
    "\n",
    "# Marker and colour per legend label.\n",
    "markers = {'Conversation': '+', 'Reconstruction': 's', 'Summarisation': 'v', 'QA': '^'}\n",
    "colors = {'Conversation': 'salmon', 'Reconstruction': 'y', 'Summarisation': 'grey', 'QA': 'violet'}\n",
    "\n",
    "# Internal task name -> label shown in the legend.\n",
    "legend_names = {\n",
    "    'continue_conversation': 'Conversation',\n",
    "    'reconstruction': 'Reconstruction',\n",
    "    'summarisation': 'Summarisation',\n",
    "    'qa': 'QA',\n",
    "}\n",
    "panels = [('bleu', 'BLEU'), ('rouge1', 'ROUGE1'), ('bertscore_f1', 'BERTScore')]\n",
    "\n",
    "# One curve per task in each of the three metric panels.\n",
    "for task_name, group in df.groupby('Task'):\n",
    "    task_name = legend_names.get(task_name, task_name)\n",
    "    for ax, (metric, _) in zip(axes, panels):\n",
    "        group.plot(y=metric, x='Ratio', ax=ax, marker=markers[task_name], label=task_name, color=colors[task_name])\n",
    "\n",
    "for ax, (_, title) in zip(axes, panels):\n",
    "    ax.set_title(title)\n",
    "    ax.set_xlabel('')\n",
    "    ax.set_xticks([0.2, 0.35, 0.5, 0.65, 0.8])\n",
    "\n",
    "# A single x-axis caption shared by the three panels.\n",
    "fig.text(0.52, -0.03, 'Context reduction ratio', ha='center', fontsize=14)\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "df = data_wise('self-info-phrase')\n",
    "df = df.reset_index()\n",
    "\n",
    "fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 4), dpi=120)\n",
    "\n",
    "# Marker and colour per legend label.\n",
    "markers = {'arxiv': '+', 'BBC': 's', 'ShareGPT': 'v'}\n",
    "colors = {'arxiv': 'salmon', 'BBC': 'y', 'ShareGPT': 'violet'}\n",
    "\n",
    "# Internal data-source name -> label shown in the legend ('arxiv' is shown unchanged).\n",
    "legend_names = {'news': 'BBC', 'conversation': 'ShareGPT'}\n",
    "panels = [('bleu', 'BLEU'), ('rouge1', 'ROUGE1'), ('bertscore_f1', 'BERTScore')]\n",
    "\n",
    "# One curve per data source in each of the three metric panels.\n",
    "for data_source, group in df.groupby('Data'):\n",
    "    data_source = legend_names.get(data_source, data_source)\n",
    "    for ax, (metric, _) in zip(axes, panels):\n",
    "        group.plot(y=metric, x='Ratio', ax=ax, marker=markers[data_source], label=data_source, color=colors[data_source])\n",
    "\n",
    "for ax, (_, title) in zip(axes, panels):\n",
    "    ax.set_title(title)\n",
    "    ax.set_xlabel('')\n",
    "    ax.set_xticks([0.2, 0.35, 0.5, 0.65, 0.8])\n",
    "\n",
    "# A single x-axis caption shared by the three panels.\n",
    "fig.text(0.52, -0.03, 'Context reduction ratio', ha='center', fontsize=14)\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load a pickled object from the news directory. pickle.load can execute arbitrary code, so only open trusted files.\n",
    "# NOTE(review): later cells read `article.context` as (role, utterance) pairs and dump `data` to the conversation path;\n",
    "# confirm that this is the file that should be loaded here.\n",
    "with open('/vol/research/lyc/llm_memorize/news/NewsContextManager_sent.pkl', 'rb') as f:\n",
    "    data = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at units[1] of the first article (read_lexical_units below asserts that index 1 is the 'phrase' level).\n",
    "print(data.articles[0].units[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach a chat-style prompt to every article: every turn except the last one,\n",
    "# one per line as 'role: utterance', followed by an open 'gpt: ' turn.\n",
    "articles = []\n",
    "for article in data.articles:\n",
    "    turns = [f\"{role}: {utterance}\\n\" for role, utterance in article.context[:-1]]\n",
    "    article.prompt = ''.join(turns) + 'gpt: '\n",
    "    articles.append(article)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Text entries of units[1] for the first article.\n",
    "articles[0].units[1].text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Self-information scores stored alongside those text entries.\n",
    "articles[0].units[1].self_info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Store the list of prompt-augmented articles back on the loaded object.\n",
    "data.articles = articles"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check the prompt generated for the third article.\n",
    "print(data.articles[2].prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write `data` (now carrying the prompts) to the conversation pickle, replacing the existing file.\n",
    "# NOTE(review): when run top to bottom `data` was loaded from the news pickle above; confirm before running.\n",
    "with open('/vol/research/lyc/llm_memorize/conversation/ConversationContextManager_sent.pkl', 'wb') as f:\n",
    "    pickle.dump(data, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import GPT2Tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pretrained 'gpt2' tokenizer. NOTE(review): `tokenizer` is not referenced by any later cell in this notebook.\n",
    "tokenizer = GPT2Tokenizer.from_pretrained('gpt2')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Number of entries per article at each unit level; read_lexical_units below asserts\n",
    "# that units[0], units[1] and units[2] are the 'sent', 'phrase' and 'token' levels.\n",
    "n_sent = [len(article.units[0].text) for article in data.articles]\n",
    "n_phrase = [len(article.units[1].text) for article in data.articles]\n",
    "n_token = [len(article.units[2].text) for article in data.articles]\n",
    "\n",
    "# Average counts over all articles.\n",
    "print(np.mean(n_sent), np.mean(n_phrase), np.mean(n_token))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Text entries of units[1] for the third article.\n",
    "data.articles[2].units[1].text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_lexical_units(article, mask_level = 'phrase'):\n",
    "    \"\"\"Render the lexical units of one article as LaTeX colorize/sdelete markup.\n",
    "\n",
    "    Every unit is wrapped in a colorize macro whose first argument is its\n",
    "    self-information rescaled to 0-100. Units scoring at or below the median are\n",
    "    additionally wrapped in sdelete. Units whose score equals 100 are emitted\n",
    "    verbatim and always kept.\n",
    "\n",
    "    Args:\n",
    "        article: Object whose ``units`` list holds the 'sent', 'phrase' and 'token'\n",
    "            units at indices 0, 1 and 2; each unit exposes ``unit_type``, ``text``\n",
    "            and ``self_info``.\n",
    "        mask_level: One of 'sent', 'phrase' or 'token'.\n",
    "\n",
    "    Returns:\n",
    "        The full markup (one unit per line), three newlines, then the markup of the\n",
    "        kept units only.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``mask_level`` is not a supported level.\n",
    "    \"\"\"\n",
    "    unit_index = {'sent': 0, 'phrase': 1, 'token': 2}\n",
    "    if mask_level not in unit_index:\n",
    "        # Previously an unknown level fell through and failed later with UnboundLocalError.\n",
    "        raise ValueError(f'Unknown mask_level: {mask_level!r}')\n",
    "    lexical_units = article.units[unit_index[mask_level]]\n",
    "    assert lexical_units.unit_type == mask_level\n",
    "\n",
    "    tokens = lexical_units.text\n",
    "    self_info = lexical_units.self_info\n",
    "    # 100 is handled as a special score below, so it is excluded from the maximum.\n",
    "    scored = [i for i in self_info if i != 100]\n",
    "\n",
    "    if scored:\n",
    "        max_score = max(scored)\n",
    "        min_score = min(self_info)\n",
    "        score_range = max_score - min_score\n",
    "        mid = np.percentile(self_info, 50)\n",
    "    else:\n",
    "        # No unit needs normalising (empty input or only 100-scores); the values are never used.\n",
    "        min_score = score_range = mid = 0\n",
    "\n",
    "    lines = []\n",
    "    highlighted = []\n",
    "    for token, score in zip(tokens, self_info):\n",
    "        if score == 100:\n",
    "            lines.append(token)\n",
    "            highlighted.append(token)\n",
    "            continue\n",
    "        # Guard against a zero range (all scores identical), which used to divide by zero.\n",
    "        normalized_score = ((score - min_score) / score_range) * 100 if score_range else 0.0\n",
    "        if score > mid:\n",
    "            line = f\"\\\\colorize{{{normalized_score}}}{{{token}}}\"\n",
    "            highlighted.append(line)\n",
    "        else:\n",
    "            token = f\"\\\\sdelete{{{token}}}\"\n",
    "            line = f\"\\\\colorize{{{normalized_score}}}{{{token}}}\"\n",
    "        lines.append(line)\n",
    "\n",
    "    return '\\n'.join(lines) + '\\n\\n\\n' + '\\n'.join(highlighted)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the phrase-level LaTeX markup for the third article.\n",
    "print(read_lexical_units(data.articles[2], mask_level = 'phrase'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.16 ('lyc')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "76684c5ca2233d26fd87f46a97dbb973f4bcb37221e033069f14c90ad955cc3d"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
