{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "initial_id",
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pickle\n",
    "import plotly.graph_objs as go"
   ]
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# load perplexity and neg-log-likelihood evaluation logs\n",
    "import os\n",
    "import pandas as pd\n",
    "\n",
    "def load_eval_result(file_name, name=None, ):\n",
    "    \"\"\"Parse an evaluation log into a {benchmark_name: metrics} dict.\n",
    "\n",
    "    A line starting with ' validation results' names a benchmark (its last\n",
    "    space-separated token, final character dropped). The line after it must\n",
    "    start with 'avg loss:'; its space-separated tokens 2, 5 and 9 are read\n",
    "    as loss, ppl and ppl_a.\n",
    "\n",
    "    Returns the parsed dict merged with {'name': name or file_name}.\n",
    "    Raises AssertionError when the following line is not an 'avg loss:' line.\n",
    "    \"\"\"\n",
    "    with open(file_name, 'r') as log_file:\n",
    "        log_lines = log_file.readlines()\n",
    "\n",
    "    parsed = {}\n",
    "    # pair every line with its successor; the final line has none and is skipped\n",
    "    for header, metrics in zip(log_lines, log_lines[1:]):\n",
    "        if not header.startswith(' validation results'):\n",
    "            continue\n",
    "        benchmark_name = header.split(' ')[-1][:-1]\n",
    "        assert metrics.startswith('avg loss:')\n",
    "        tokens = metrics.split(' ')\n",
    "        loss = float(tokens[2])\n",
    "        ppl = float(tokens[5])\n",
    "        ppl_a = float(tokens[9])\n",
    "        parsed[benchmark_name] = {'loss': loss, 'ppl': ppl, 'ppl_a': ppl_a}\n",
    "    label = name if name else file_name\n",
    "    return parsed | {'name': label}\n",
    "\n",
    "# s01 sweep: one row per checkpoint log that exists on disk. Both steps_en\n",
    "# and steps_ru are recorded as step * 0.5.\n",
    "data_micro_commute_s01 = []\n",
    "for step in range(2000, 25000, 2000):\n",
    "    f = f'../eval_results_megatron/model_global_step{step}.txt'\n",
    "    if not os.path.exists(f):\n",
    "        continue\n",
    "    res = load_eval_result(f)\n",
    "    try:\n",
    "        data_micro_commute_s01.append([\n",
    "            'v01_step01',\n",
    "            f'load{step}', \n",
    "            step * 0.5,\n",
    "            step * 0.5,\n",
    "            *[res[f'eval_data_{lang}.jsonl']['loss'] for lang in ['en', 'ru']], \n",
    "            *[res[f'eval_data_{lang}.jsonl']['ppl'] for lang in ['en', 'ru']],\n",
    "        ])\n",
    "    except KeyError as missing:\n",
    "        # Only a benchmark absent from the log is expected here; a bare\n",
    "        # `except:` would also hide KeyboardInterrupt and genuine bugs.\n",
    "        print(f'failed to load {f}: missing {missing}')\n",
    "df_micro_commute_s01 = pd.DataFrame(data_micro_commute_s01, columns=['exp_series',\n",
    "                                             'ver', \n",
    "                                             *[f'steps_{lang}' for lang in ['en', 'ru']],\n",
    "                                             *[f'loss_{lang}' for lang in ['en', 'ru']],\n",
    "                                             *[f'ppl_{lang}' for lang in ['en', 'ru']],])\n",
    "\n",
    "# s03 reorder runs: one row per (load step, steps2, direction1, direction2)\n",
    "# whose log exists on disk. For each phase, steps_en receives a\n",
    "# 0.5 - (direction - 4) / 3 * 0.45 share and steps_ru the mirrored 0.5 + ... share.\n",
    "data_reorder_v01_s03 = []\n",
    "steps1 = 4000  # identical for every run, so set once outside the loops\n",
    "for step in range(2000, 25000, 6000):\n",
    "    for steps2 in [2000, 4000]:\n",
    "        for direction1 in range(1, 8):\n",
    "            for direction2 in range(1, 8):\n",
    "                direction_sum = direction1 + direction2\n",
    "                # keep only pairs whose sum is even and lies in [4, 12]\n",
    "                if direction_sum % 2 or not 4 <= direction_sum <= 12:\n",
    "                    continue\n",
    "                f = f'../eval_results_megatron/model_v01_step03_load{step}_direction{direction1}_4000_direction{direction2}_global_step{step + steps1 + steps2}.txt'\n",
    "                if not os.path.exists(f):\n",
    "                    continue\n",
    "                res = load_eval_result(f)\n",
    "                try:\n",
    "                    data_reorder_v01_s03.append([\n",
    "                        'v01_step03',\n",
    "                        f'load{step}_dir{direction1}_dir{direction2}', \n",
    "                        step * 0.5 + steps1 * (0.5 - (direction1 - 4) / 3 * 0.45) + steps2 * (0.5 - (direction2 - 4) / 3 * 0.45),\n",
    "                        step * 0.5 + steps1 * (0.5 + (direction1 - 4) / 3 * 0.45) + steps2 * (0.5 + (direction2 - 4) / 3 * 0.45),\n",
    "                        *[res[f'eval_data_{lang}.jsonl']['loss'] for lang in ['en', 'ru']], \n",
    "                        *[res[f'eval_data_{lang}.jsonl']['ppl'] for lang in ['en', 'ru']],\n",
    "                        direction1, direction2, steps1, steps2\n",
    "                    ])\n",
    "                except KeyError as missing:\n",
    "                    # Only a benchmark absent from the log is expected here; a bare\n",
    "                    # `except:` would also hide KeyboardInterrupt and genuine bugs.\n",
    "                    print(f'failed to load {f}: missing {missing}')\n",
    "df_reorder_v01_s03 = pd.DataFrame(data_reorder_v01_s03, columns=['exp_series',\n",
    "                                             'ver', \n",
    "                                             *[f'steps_{lang}' for lang in ['en', 'ru']],\n",
    "                                             *[f'loss_{lang}' for lang in ['en', 'ru']],\n",
    "                                             *[f'ppl_{lang}' for lang in ['en', 'ru']],\n",
    "                                             *['direction1', 'direction2', 'steps1', 'steps2'],])"
   ],
   "id": "9e98cce729517a9b"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# For every checkpoint, multiply the difference of the two cross-language\n",
    "# 'hvps' entries with each language's 'grads' entry, summed over layers 0..75.\n",
    "# NOTE(review): 'hvps' presumably holds Hessian-vector products - confirm.\n",
    "p12 = []\n",
    "for step in list(range(2000, 25000, 2000)) + [25000]:\n",
    "    # NOTE(review): pickle.load can execute arbitrary code; only load result\n",
    "    # files from a trusted source.\n",
    "    # `with` closes the handle, which the bare open() call left dangling.\n",
    "    with open(f'/path/to/results_step{step}.pickle', 'rb') as pickle_file:\n",
    "        hg_data = pickle.load(pickle_file)\n",
    "    res = [\n",
    "        np.array([\n",
    "            ((hg_data['hvps']['en']['ru'][layer] - hg_data['hvps']['ru']['en'][layer]) * hg_data['grads'][lang][layer]).sum()\n",
    "            for layer in range(0, 76)\n",
    "        ]).sum()\n",
    "        for lang in ['en', 'ru']\n",
    "    ]\n",
    "    p12.append({'step': step,} | {lang: value for (lang, value) in zip(['en', 'ru'], res)})\n",
    "    print(step, res)"
   ],
   "id": "51a6d0812a66f28d"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# s01 rows keyed by the step parsed from 'ver' (text after the first 4 chars)\n",
    "df_micro_commute_s01_dict = {\n",
    "    int(record['ver'][4:]): dict(record)\n",
    "    for _, record in df_micro_commute_s01.iterrows()\n",
    "}\n",
    "\n",
    "# s03 rows with steps2 == 4000, grouped by the load step parsed from 'ver'\n",
    "df_reorder_v01_s03_dict = {}\n",
    "for _, record in df_reorder_v01_s03.iterrows():\n",
    "    ver = record['ver']\n",
    "    step = int(ver[4:ver.find('_')])\n",
    "    if record['steps2'] != 4000:\n",
    "        continue\n",
    "    df_reorder_v01_s03_dict.setdefault(step, []).append(dict(record))"
   ],
   "id": "9197558b603592e"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# llm table 1\n",
    "# Prints one LaTeX row per p12 step that has s03 reorder runs: the\n",
    "# dir4/dir4 losses, the loss differences of dir5/dir3 and dir3/dir5 against\n",
    "# dir4/dir4, and the rescaled p12 values.\n",
    "for r in p12:\n",
    "    step = r['step']\n",
    "    # Do not index df_micro_commute_s01_dict here: its values were never used,\n",
    "    # and p12 contains step 25000, which the s01 sweep (2000..24000) never\n",
    "    # produces, so the lookup raised KeyError.\n",
    "    if step not in df_reorder_v01_s03_dict:\n",
    "        continue\n",
    "    rows = df_reorder_v01_s03_dict[step]\n",
    "    row_4_4 = [row for row in rows if row['ver'].endswith('_dir4_dir4')][0]\n",
    "    row_3_5 = [row for row in rows if row['ver'].endswith('_dir3_dir5')][0]\n",
    "    row_5_3 = [row for row in rows if row['ver'].endswith('_dir5_dir3')][0]\n",
    "    loss_en_base = row_4_4['loss_en']\n",
    "    loss_ru_base = row_4_4['loss_ru']\n",
    "    el12_en = row_5_3['loss_en'] - row_4_4['loss_en']\n",
    "    el12_ru = row_5_3['loss_ru'] - row_4_4['loss_ru']\n",
    "    el21_en = row_3_5['loss_en'] - row_4_4['loss_en']\n",
    "    el21_ru = row_3_5['loss_ru'] - row_4_4['loss_ru']\n",
    "    p_adjusted = {lang: r[lang] * 4000**2 * 1.5e-4**2 * .15**2 / 8 for lang in ['en', 'ru']}\n",
    "    print(f'{step:6d} & '\n",
    "          f'{loss_en_base:6.2f} & {loss_ru_base:6.2f} & '\n",
    "          f'{el12_en:8.4f} & {el12_ru:8.4f} & {el21_en:8.4f} & {el21_ru:8.4f} & '\n",
    "          f'{p_adjusted[\"en\"]:8.4f} & {p_adjusted[\"ru\"]:8.4f} \\\\\\\\')"
   ],
   "id": "2e65f33a31e952b9"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# llm table 2\n",
    "# one LaTeX row per p12 entry: step, rescaled en value, rescaled ru value\n",
    "for entry in p12:\n",
    "    p_en, p_ru = (entry[lang] * 4000**2 * 1.5e-4**2 * .15**2 / 8 for lang in ['en', 'ru'])\n",
    "    print(f'{entry[\"step\"]:5d} & {p_en:8.4f} & {p_ru:8.4f} \\\\\\\\')"
   ],
   "id": "246e96be2b5b4682"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Scatter of loss_en (x) against loss_ru (y): the s01 rows plus, per load\n",
    "# step, the s03 rows with steps2 == 4000 and direction1 + direction2 == 8.\n",
    "fig = go.Figure()\n",
    "\n",
    "x_fn = lambda df: df.loss_en\n",
    "y_fn = lambda df: df.loss_ru\n",
    "\n",
    "\n",
    "def hover_text(frame):\n",
    "    \"\"\"Return one 'column: value<br>...' hover string per row of `frame`.\"\"\"\n",
    "    return [''.join([f\"{key}: {frame.iloc[i][key]}<br>\" for key in frame.keys()]) for i in range(frame.shape[0])]\n",
    "\n",
    "\n",
    "# add data point\n",
    "df0 = df_reorder_v01_s03\n",
    "df0 = df0[df0['ver'] == 'load20000_dir4_dir4']\n",
    "df0 = df0.drop(columns=['direction1', 'direction2', 'steps1', 'steps2'])\n",
    "# step1 equal sampling\n",
    "df = df_micro_commute_s01\n",
    "df = df[(df['loss_en'] + df['loss_ru'] < 8)]\n",
    "df = pd.concat([df, df0])\n",
    "fig.add_scatter(\n",
    "    x=df.loss_en, \n",
    "    y=df.loss_ru, \n",
    "    name=f's01',\n",
    "    mode='markers+lines', \n",
    "    marker=dict(\n",
    "        colorscale=\"Cividis\",\n",
    "        size=8,\n",
    "    ),\n",
    "    hovertemplate='<br>%{text}<br>',\n",
    "    text=hover_text(df),\n",
    "    showlegend=False,\n",
    ")\n",
    "# step3 load 4000 in different directions\n",
    "for load_step in [2000, 8000, 14000, 20000]:\n",
    "    df = df_reorder_v01_s03\n",
    "    # combine boolean masks with `&` rather than arithmetic `*`\n",
    "    df = df[(df['loss_en'] + df['loss_ru'] < 8)\n",
    "            & (df.steps2 == 4000)\n",
    "            & (df.direction1 + df.direction2 == 8)\n",
    "    ]\n",
    "    # steps_en / steps_ru are sums of float products, so match the total\n",
    "    # with a tolerance instead of exact equality\n",
    "    df = df[np.isclose(df.steps_en + df.steps_ru, load_step + 4000 + 4000)]\n",
    "    fig.add_scatter(\n",
    "        x=x_fn(df), \n",
    "        y=y_fn(df), \n",
    "        # the trace name used an undefined `dir1` (NameError on a fresh kernel);\n",
    "        # each trace covers every direction pair, so the load step identifies it\n",
    "        name=f's03_load{load_step}_4000_4000',\n",
    "        mode='markers+lines', \n",
    "        marker=dict(\n",
    "            color='black',\n",
    "            colorscale=\"Turbo\",\n",
    "            size=8,\n",
    "        ),\n",
    "        hovertemplate='<br>%{text}<br>',\n",
    "        text=hover_text(df),\n",
    "        showlegend=False,\n",
    "    )\n",
    "# axis captions are drawn as annotations placed next to the axes\n",
    "fig.add_annotation(\n",
    "    xref=\"x domain\",\n",
    "    yref=\"y domain\",\n",
    "    x=1.031,\n",
    "    y=-.037,\n",
    "    text=\"NLL<br>English\",\n",
    "    font=dict(size=16),\n",
    "    showarrow=False,\n",
    ")\n",
    "fig.add_annotation(\n",
    "    xref=\"x domain\",\n",
    "    yref=\"y domain\",\n",
    "    x=-.03,\n",
    "    y=.99,\n",
    "    text=\"NLL<br>Russian\",\n",
    "    font=dict(size=16),\n",
    "    showarrow=False\n",
    ")\n",
    "# `scene` axis titles only apply to 3-D traces, so they are not set here\n",
    "fig.update_layout(\n",
    "    autosize=False,\n",
    "    margin=dict(l=20, r=20, t=20, b=20),\n",
    "    width=800,\n",
    "    height=600,\n",
    ")\n",
    "fig.show()"
   ],
   "id": "f7d6502deacfa868"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "7f07d2abe35e289d"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "de6f0ff3bc1643"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "3e02acaf1e91e404"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "af38400d2b96b84e"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "72dbd0fc461ca814"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "2c8914671816ee14"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "3483110cfa334256"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "50d93edf8837afa0"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "3c71eebe97c314cf"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "45b4336fb92b5981"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
