{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import pandas as pd\n",
    "import warnings\n",
    "import os\n",
    "import numpy as np\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import matplotlib as mpl\n",
    "from uncertainties import ufloat\n",
    "from scipy import stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rc('font', family='serif')\n",
    "plt.rc('xtick', labelsize='x-small')\n",
    "plt.rc('ytick', labelsize='x-small')\n",
    "plt.rc('text', usetex=False)\n",
    "sns.set(style=\"ticks\", font_scale=1.5, color_codes=True)\n",
    "sns.set_style({'font.family':'serif', 'font.serif':'Times New Roman'})\n",
    "mpl.rcParams['figure.dpi'] = 300"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Results Loading Code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mean_std(x):\n",
    "    \"\"\"Return the mean of ``x`` with its standard deviation (``np.std``) as a ``ufloat``.\"\"\"\n",
    "    center = np.mean(x)\n",
    "    spread = np.std(x)\n",
    "    return ufloat(center, spread)\n",
    "def my_mean(x):\n",
    "    return(str(round(np.mean(x),2)))\n",
    "def sort_by(list1, list2):\n",
    "    return [x for _,x in sorted(zip(list2,list1))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_splits(fname):\n",
    "    splits = !grep -hnr 'TEST EVALS\\|VAL EVALS\\|TEST EASY\\|VAL EASY' {fname}\n",
    "    length = !wc -l {fname}\n",
    "    length = int(length[0].split(' ')[0])\n",
    "    starts = []\n",
    "    names  = []\n",
    "    for s in splits:\n",
    "        start,sname = s.split(':')\n",
    "        starts.append(int(start))\n",
    "        names.append(sname)\n",
    "    starts.append(length)\n",
    "    sdict = {}\n",
    "    for i in range(0,len(starts)-1):\n",
    "        sdict[names[i]] = slice(starts[i],starts[i+1])\n",
    "    return sdict\n",
    "\n",
    "def get_tokens(line):\n",
    "    return line.rstrip().split(\" \")[1:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_next_matching_block(lines,start):\n",
    "    index = start\n",
    "    found = False\n",
    "    for l in lines[start:]:\n",
    "        if l.startswith('INPUT:'):\n",
    "            found = True\n",
    "            break\n",
    "        index+=1\n",
    "    if found:\n",
    "        inp   = lines[index]\n",
    "        ref   = get_tokens(lines[index+1])\n",
    "        pred  = get_tokens(lines[index+2])\n",
    "        return (inp,ref,pred), index+3\n",
    "    else:\n",
    "        return None, None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_scores(lines, train_tags=None, reduced=True):\n",
    "    \"\"\"Score every INPUT/reference/prediction block found in ``lines``.\n",
    "\n",
    "    Args:\n",
    "        lines: log lines of one split, parsed with ``find_next_matching_block``.\n",
    "        train_tags: optional collection of tag strings; when given, only\n",
    "            examples whose ``';'.join(ref[1:])`` is NOT in it are scored.\n",
    "        reduced: if True return summary statistics, otherwise per-example values.\n",
    "\n",
    "    Returns:\n",
    "        ``(acc, acc_std, f1, f1_std)`` when ``reduced`` is True, otherwise\n",
    "        ``(acc, preds, f1, f1s)`` with the per-example lists ordered by the\n",
    "        input line + reference text. ``acc`` is the exact-match rate, ``f1`` is\n",
    "        computed from the token counts summed over all scored examples and\n",
    "        ``f1_std`` is the standard deviation of the per-example F1 values.\n",
    "        Every statistic is NaN when no example is scored.\n",
    "    \"\"\"\n",
    "    start = 0\n",
    "    preds, tps, fps, fns, f1s, novel, sorter = [], [], [], [], [], [], []\n",
    "    while True:\n",
    "        data, start = find_next_matching_block(lines, start)\n",
    "        if data is None:\n",
    "            break\n",
    "        inp, ref, pred_here = data\n",
    "        tp = len([p for p in pred_here if p in ref])\n",
    "        fp = len([p for p in pred_here if p not in ref])\n",
    "        fn = len([p for p in ref if p not in pred_here])\n",
    "        sorter.append(inp + \"\\t\".join(ref))\n",
    "        preds.append(pred_here == ref)\n",
    "        tps.append(tp)\n",
    "        fps.append(fp)\n",
    "        fns.append(fn)\n",
    "        f1s.append(_f1_from_counts(tp, fp, fn))\n",
    "        if train_tags is not None:\n",
    "            novel.append(\";\".join(ref[1:]) not in train_tags)\n",
    "\n",
    "    # Without train_tags there is nothing to filter on: keep every example.\n",
    "    mask = np.array(novel, dtype=bool) if novel else np.ones(len(f1s), dtype=bool)\n",
    "    tps, fps, fns, preds, f1s, sorter = (np.array(v)[mask] for v in (tps, fps, fns, preds, f1s, sorter))\n",
    "    if len(f1s) == 0:\n",
    "        # Empty split or no novel example: report NaN explicitly instead of\n",
    "        # running into numpy's empty-slice / invalid-value warnings.\n",
    "        nan = float('nan')\n",
    "        return (nan, nan, nan, nan) if reduced else (nan, [], nan, [])\n",
    "    f1 = _f1_from_counts(np.sum(tps), np.sum(fps), np.sum(fns))\n",
    "    acc = np.mean(preds)\n",
    "    if not reduced:\n",
    "        preds, f1s = (sort_by(v, sorter) for v in (preds, f1s))\n",
    "        return acc, preds, f1, f1s\n",
    "    return acc, np.std(preds), f1, np.std(f1s)\n",
    "\n",
    "\n",
    "def _f1_from_counts(tp, fp, fn):\n",
    "    \"\"\"Return F1 for the given counts; 0.0 when there are no true positives.\n",
    "\n",
    "    With ``tp == 0`` precision and recall are zero or undefined (0/0), so this\n",
    "    guard also avoids dividing by zero for empty predictions or references.\n",
    "    \"\"\"\n",
    "    if tp == 0:\n",
    "        return 0.0\n",
    "    prec = tp / (tp + fp)\n",
    "    rec = tp / (tp + fn)\n",
    "    return 2 * prec * rec / (prec + rec)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "def get_tags(datafolder, hints, seed):\n",
    "    train_tags= set([line.rstrip().split('\\t')[2] for line in open(datafolder+f\"train.hints-{hints}.{seed}.txt\")])\n",
    "#   for split in (\"test_hard\", \"val_hard\"):\n",
    "#       split_tags = set([line.rstrip().split('\\t')[2] for line in open(datafolder+f\"{split}.hints-{hints}.{seed}.txt\")])\n",
    "    return train_tags"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "get_tags(\"./data/SIGDataSet.large/spanish/\", hints=4, seed=0);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "./checkpoints_seperate_large/SIGDataSet/spanish/logs/2proto.vae.true.hints.16.seed.0.cond.log\n",
      "split: TEST EVALS\n",
      " (0.515625, 0.49975579974123363, 0.8501291989664083, 0.17237705255322905)\n",
      "split: VAL EVALS\n",
      " (0.2441860465116279, 0.42960356283514334, 0.7231329690346084, 0.2502584644495325)\n",
      "split: TEST EASY\n",
      " (nan, nan, nan, nan)\n",
      "split: VAL EASY\n",
      " (nan, nan, nan, nan)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/gridsan/eakyurek/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:32: RuntimeWarning: invalid value encountered in long_scalars\n",
      "/home/gridsan/eakyurek/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:33: RuntimeWarning: invalid value encountered in long_scalars\n",
      "/home/gridsan/eakyurek/anaconda3/lib/python3.7/site-packages/numpy/core/_methods.py:234: RuntimeWarning: Degrees of freedom <= 0 for slice\n",
      "  keepdims=keepdims)\n",
      "/home/gridsan/eakyurek/anaconda3/lib/python3.7/site-packages/numpy/core/_methods.py:195: RuntimeWarning: invalid value encountered in true_divide\n",
      "  arrmean, rcount, out=arrmean, casting='unsafe', subok=False)\n",
      "/home/gridsan/eakyurek/anaconda3/lib/python3.7/site-packages/numpy/core/_methods.py:226: RuntimeWarning: invalid value encountered in double_scalars\n",
      "  ret = ret.dtype.type(ret / rcount)\n",
      "/home/gridsan/eakyurek/anaconda3/lib/python3.7/site-packages/numpy/core/fromnumeric.py:3373: RuntimeWarning: Mean of empty slice.\n",
      "  out=out, **kwargs)\n",
      "/home/gridsan/eakyurek/anaconda3/lib/python3.7/site-packages/numpy/core/_methods.py:170: RuntimeWarning: invalid value encountered in double_scalars\n",
      "  ret = ret.dtype.type(ret / rcount)\n"
     ]
    }
   ],
   "source": [
    "testfile = \"./checkpoints_seperate_large/SIGDataSet/spanish/logs/2proto.vae.true.hints.16.seed.0.cond.log\"\n",
    "# NOTE(review): the log comes from a hints.16 run but the tags are loaded with hints=4 -- confirm this is intended.\n",
    "train_tags = get_tags('data/SIGDataSet.large/spanish/', hints=4,seed=0)\n",
    "with open(testfile,'r') as f:  # do not leave the log handle open\n",
    "    testlines = f.readlines()\n",
    "print(testfile)\n",
    "for (s,r) in get_splits(testfile).items():\n",
    "    print(f\"split: {s}\\n\",calculate_scores(testlines[r],train_tags=train_tags))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SIGMORPHON\n",
    "### Scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_lang_scores(df=None,\n",
    "                    langs=(\"spanish\",\"turkish\",\"swahili\"),\n",
    "                    hintss=(4,8,16),\n",
    "                    seeds=(0,1,2,3,4),\n",
    "                    vaes =(\"true\",\"false\"),\n",
    "                    models=(\"baseline\",\"0proto\",\"1proto\",\"2proto\"),\n",
    "                    exppath=\"./checkpoints\",\n",
    "                    datapath=\"./data/SIGDataSet.large/\",\n",
    "                    novel=False,\n",
    "                    reduced=True,\n",
    "                    min_lines=2142,\n",
    "                   ):\n",
    "    \"\"\"Score the ``.cond.log`` files of a SIGDataSet experiment and append the rows to ``df``.\n",
    "\n",
    "    For every (lang, hints, seed, vae, model) combination the log\n",
    "    ``<exppath>/SIGDataSet/<lang>/logs/<identifier>.cond.log`` is read and each\n",
    "    split found by ``get_splits`` is scored with ``calculate_scores``. ``df`` is\n",
    "    modified in place; each appended row is\n",
    "    ``(lang, hints, seed, vae, model, split, acc, acc_std, f1, f1_std)``.\n",
    "\n",
    "    'baseline' and 'geca' log names carry no vae setting, so the same file is\n",
    "    scored (and reported) once per entry of ``vaes``.\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame with ten columns to append to (required).\n",
    "        novel: if True, score only examples whose tags are not in the set\n",
    "            returned by ``get_tags`` for that lang/hints/seed.\n",
    "        reduced: forwarded to ``calculate_scores``.\n",
    "        min_lines: logs with fewer lines are reported as broken and skipped.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``df`` is None.\n",
    "    \"\"\"\n",
    "    if df is None:\n",
    "        raise ValueError(\"get_lang_scores needs a DataFrame to append rows to (df=...)\")\n",
    "    train_tags = None\n",
    "    for lang in langs:\n",
    "        langpath = os.path.join(exppath,\"SIGDataSet\",lang)\n",
    "        for hints in hintss:\n",
    "            for seed in seeds:\n",
    "                if novel:\n",
    "                    train_tags = get_tags(datapath + lang + '/', hints=hints,seed=seed)\n",
    "                for vae in vaes:\n",
    "                    for model in models:\n",
    "                        if model in (\"baseline\", \"geca\"):\n",
    "                            identifier = \"{}.hints.{}.seed.{}\".format(model,hints,seed)\n",
    "                        else:\n",
    "                            identifier = \"{}.vae.{}.hints.{}.seed.{}\".format(model,vae,hints,seed)\n",
    "                        condfile = os.path.join(langpath,\"logs\",identifier+\".cond.log\")\n",
    "                        if not os.path.exists(condfile):\n",
    "                            print(f\"file doesnot exist: {condfile}\")\n",
    "                            continue\n",
    "                        with open(condfile,'r') as f:  # close each log after reading\n",
    "                            lines = f.readlines()\n",
    "                        if len(lines) < min_lines:\n",
    "                            print(\"format broken in \"+condfile)\n",
    "                            continue\n",
    "                        for (s,r) in get_splits(condfile).items():\n",
    "                            acc, accstd, f1, f1std = calculate_scores(lines[r], train_tags=train_tags, reduced=reduced)\n",
    "                            df.loc[len(df.index)] = (lang,hints,seed,vae,model,s,acc,accstd,f1,f1std)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "format broken in ./checkpoints_sig_copy/SIGDataSet/turkish/logs/baseline.hints.4.seed.3.cond.log\n",
      "format broken in ./checkpoints_sig_copy/SIGDataSet/turkish/logs/baseline.hints.4.seed.3.cond.log\n",
      "file doesnot exist: ./checkpoints_sig_copy/SIGDataSet/turkish/logs/baseline.hints.4.seed.4.cond.log\n",
      "file doesnot exist: ./checkpoints_sig_copy/SIGDataSet/turkish/logs/baseline.hints.4.seed.4.cond.log\n",
      "file doesnot exist: ./checkpoints_sig_copy/SIGDataSet/turkish/logs/baseline.hints.8.seed.0.cond.log\n",
      "file doesnot exist: ./checkpoints_sig_copy/SIGDataSet/turkish/logs/baseline.hints.8.seed.0.cond.log\n",
      "file doesnot exist: ./checkpoints_sig_copy/SIGDataSet/turkish/logs/baseline.hints.8.seed.1.cond.log\n",
      "file doesnot exist: ./checkpoints_sig_copy/SIGDataSet/turkish/logs/baseline.hints.8.seed.1.cond.log\n",
      "file doesnot exist: ./checkpoints_sig_copy/SIGDataSet/turkish/logs/baseline.hints.8.seed.2.cond.log\n",
      "file doesnot exist: ./checkpoints_sig_copy/SIGDataSet/turkish/logs/baseline.hints.8.seed.2.cond.log\n",
      "file doesnot exist: ./checkpoints_sig_copy/SIGDataSet/turkish/logs/baseline.hints.8.seed.3.cond.log\n",
      "file doesnot exist: ./checkpoints_sig_copy/SIGDataSet/turkish/logs/baseline.hints.8.seed.3.cond.log\n",
      "file doesnot exist: ./checkpoints_sig_copy/SIGDataSet/turkish/logs/baseline.hints.8.seed.4.cond.log\n",
      "file doesnot exist: ./checkpoints_sig_copy/SIGDataSet/turkish/logs/baseline.hints.8.seed.4.cond.log\n",
      "file doesnot exist: ./checkpoints_sig_copy/SIGDataSet/turkish/logs/baseline.hints.16.seed.0.cond.log\n",
      "file doesnot exist: ./checkpoints_sig_copy/SIGDataSet/turkish/logs/baseline.hints.16.seed.0.cond.log\n",
      "file doesnot exist: ./checkpoints_sig_copy/SIGDataSet/turkish/logs/baseline.hints.16.seed.1.cond.log\n",
      "file doesnot exist: ./checkpoints_sig_copy/SIGDataSet/turkish/logs/baseline.hints.16.seed.1.cond.log\n",
      "file doesnot exist: ./checkpoints_sig_copy/SIGDataSet/turkish/logs/baseline.hints.16.seed.2.cond.log\n",
      "file doesnot exist: ./checkpoints_sig_copy/SIGDataSet/turkish/logs/baseline.hints.16.seed.2.cond.log\n",
      "file doesnot exist: ./checkpoints_sig_copy/SIGDataSet/turkish/logs/baseline.hints.16.seed.3.cond.log\n",
      "file doesnot exist: ./checkpoints_sig_copy/SIGDataSet/turkish/logs/baseline.hints.16.seed.3.cond.log\n",
      "file doesnot exist: ./checkpoints_sig_copy/SIGDataSet/turkish/logs/baseline.hints.16.seed.4.cond.log\n",
      "file doesnot exist: ./checkpoints_sig_copy/SIGDataSet/turkish/logs/baseline.hints.16.seed.4.cond.log\n"
     ]
    }
   ],
   "source": [
    "dfcopy = pd.DataFrame(columns=('Language', 'Hints', 'Seed', 'Vae','Model','Split','Acc','Acc_std','F1','F1_std',))\n",
    "get_lang_scores(df=dfcopy,exppath=\"./checkpoints_sig_copy\",datapath=\"./data/SIGDataSet.large.copy/\",  models=(\"baseline\",), novel=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dfrare = pd.DataFrame(columns=('Language', 'Hints', 'Seed', 'Vae','Model','Split','Acc','Acc_std','F1','F1_std',))\n",
    "get_lang_scores(df=dfrare,exppath=\"./checkpoints_seperate_large\",datapath=\"./data/SIGDataSet.large/\",  models=(\"0proto\",\"1proto\",\"2proto\"), novel=False)\n",
    "get_lang_scores(df=dfrare,exppath=\"./checkpoints_large_test\",datapath=\"./data/SIGDataSet.large/\",  models=(\"baseline\",), novel=False)\n",
    "get_lang_scores(df=dfrare,exppath=\"./checkpoints_large_test_geca\",datapath=\"./data/SIGDataSet.large/\",models=(\"geca\",), novel=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dfrare_novel = pd.DataFrame(columns=('Language', 'Hints', 'Seed', 'Vae','Model','Split','Acc','Acc_std','F1','F1_std',))\n",
    "get_lang_scores(df=dfrare_novel,exppath=\"./checkpoints_seperate_large\",datapath=\"./data/SIGDataSet.large/\",  models=(\"0proto\",\"1proto\",\"2proto\"), novel=True)\n",
    "get_lang_scores(df=dfrare_novel,exppath=\"./checkpoints_large_test\",datapath=\"./data/SIGDataSet.large/\",  models=(\"baseline\",), novel=True)\n",
    "get_lang_scores(df=dfrare_novel,exppath=\"./checkpoints_large_test_geca\",datapath=\"./data/SIGDataSet.large/\",models=(\"geca\",), novel=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dfnorare = pd.DataFrame(columns=('Language', 'Hints', 'Seed', 'Vae','Model','Split','Acc','Acc_std','F1','F1_std',))\n",
    "get_lang_scores(df=dfnorare,exppath=\"./checkpoints_morph_norare\",datapath=\"./data/SIGDataSet.large/\",  models=(\"0proto\",\"1proto\",\"2proto\"), novel=False, vaes=(\"false\",))\n",
    "get_lang_scores(df=dfnorare,exppath=\"./checkpoints_large_test\",datapath=\"./data/SIGDataSet.large/\",  models=(\"baseline\",), novel=False, vaes=(\"false\",))\n",
    "get_lang_scores(df=dfnorare,exppath=\"./checkpoints_large_test_geca_norare\",datapath=\"./data/SIGDataSet.large/\",models=(\"geca\",), novel=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dfnorare_novel = pd.DataFrame(columns=('Language', 'Hints', 'Seed', 'Vae','Model','Split','Acc','Acc_std','F1','F1_std',))\n",
    "get_lang_scores(df=dfnorare_novel,exppath=\"./checkpoints_morph_norare\",datapath=\"./data/SIGDataSet.large/\",  models=(\"0proto\",\"1proto\",\"2proto\"), novel=True, vaes=(\"false\",))\n",
    "get_lang_scores(df=dfnorare_novel,exppath=\"./checkpoints_large_test\",datapath=\"./data/SIGDataSet.large/\",  models=(\"baseline\",), novel=True, vaes=(\"false\",))\n",
    "get_lang_scores(df=dfnorare_novel,exppath=\"./checkpoints_large_test_geca_norare\",datapath=\"./data/SIGDataSet.large/\",models=(\"geca\",), novel=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_morph_results_table(df, hints=8, vae=\"false\", score=\"F1\", markdown=False):\n",
    "    \"\"\"Build a Model x '<Language> <Split> <Score>' table of ``mean_std`` aggregates.\n",
    "\n",
    "    Rows of ``df`` are filtered to the given ``hints`` and ``vae`` and to the\n",
    "    splits 'TEST EVALS', 'VAL EVALS' and 'TEST EASY', which are relabelled\n",
    "    'Fut', 'Past' and 'Pres'. ``score`` is aggregated with ``mean_std`` per\n",
    "    (Model, Split, Language); 'Fut' and 'Past' are then replaced by their\n",
    "    average ('Fut-Past <score>') and the result is pivoted by language.\n",
    "\n",
    "    ``markdown`` is accepted for backward compatibility and is not used.\n",
    "    \"\"\"\n",
    "    splits_s = [\"TEST EVALS\", \"VAL EVALS\", \"TEST EASY\"]\n",
    "    splits_s_alt = [\"Fut\", \"Past\", \"Pres\"]\n",
    "    cols_s = [\"Language\", \"Seed\", \"Model\", \"Split\"] + [score]\n",
    "    # The mask is computed on the original labels, before they are replaced.\n",
    "    selected = (df['Split'].isin(splits_s)) & (df['Vae'] == vae) & (df['Hints'] == hints)\n",
    "    df = (df.replace(splits_s, splits_s_alt)\n",
    "            .loc[selected, cols_s]\n",
    "            .reset_index()\n",
    "            .drop(columns=['index'])\n",
    "            .groupby(by=[\"Model\",\"Split\",\"Language\"])\n",
    "            .agg({score:mean_std})\n",
    "            .reset_index(\"Split\")\n",
    "            .pivot(columns=\"Split\"))\n",
    "\n",
    "    df.columns = df.columns.swaplevel(1,0)\n",
    "    df['Fut-Past '+ score] =  (df['Fut'] + df['Past'])/2\n",
    "    df = df.drop(columns=['Fut','Past'])\n",
    "    df.columns = [' '.join(col).strip() for col in df.columns.values]\n",
    "    df = df.reset_index(\"Language\").pivot(columns=\"Language\")\n",
    "    df.columns = df.columns.swaplevel(1,0)\n",
    "    df.columns = [' '.join(col).strip().title() for col in df.columns.values]\n",
    "    # The former trailing `return df.iloc[:,[3,0,4,1,5,2]]` was unreachable and has been removed.\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_copy_table = get_morph_results_table(dfcopy, hints=8, vae=\"false\", score=\"F1\", markdown=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "| Model    | Spanish Pres F1   | Swahili Pres F1   | Spanish Fut-Past F1   | Swahili Fut-Past F1   |\n",
      "|:---------|:------------------|:------------------|:----------------------|:----------------------|\n",
      "| baseline | 0.882+/-0.017     | 0.900+/-0.019     | 0.652+/-0.006         | 0.770+/-0.010         |\n"
     ]
    }
   ],
   "source": [
    "print(df_copy_table.to_markdown())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_rare_table = get_morph_results_table(dfrare, hints=16, vae=\"false\", score=\"Acc\", markdown=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_norare_table = get_morph_results_table(dfnorare, hints=16, vae=\"false\", score=\"Acc\", markdown=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_rare_table = df_rare_table.reset_index(\"Model\")\n",
    "df_rare_table[\"Model\"] = df_rare_table[\"Model\"] .astype(str) + ' +rare'\n",
    "df_rare_table.set_index(\"Model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(df_rare_table.to_latex(index=False,float_format=\"{:0.2f}\", caption=\"Morphology Results\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_norare_table = df_norare_table.reset_index(\"Model\")\n",
    "df_norare_table[\"Model\"] = df_norare_table[\"Model\"] .astype(str)\n",
    "df_norare_table.set_index(\"Model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# DataFrame.append was removed in pandas 2.0; pd.concat works on old and new versions.\n",
    "morph_rows = pd.DataFrame([df_rare_table.iloc[4],\n",
    "                           df_norare_table.iloc[0],\n",
    "                           df_rare_table.iloc[0],\n",
    "                           df_norare_table.iloc[1],\n",
    "                           df_rare_table.iloc[1],\n",
    "                           df_norare_table.iloc[2],\n",
    "                           df_rare_table.iloc[2]])\n",
    "df_morph = pd.concat([df_norare_table[3:5], morph_rows], ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_morph "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(df_morph.to_latex(index=False,float_format=\"{:0.2f}\", caption=\"Morphology Results\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_rare_novel_table = get_morph_results_table(dfrare_novel, hints=8, vae=\"false\", score=\"F1\", markdown=False)\n",
    "df_rare_novel_table = df_rare_novel_table.reset_index(\"Model\")\n",
    "df_rare_novel_table[\"Model\"] = df_rare_novel_table[\"Model\"] .astype(str) + ' +rare'\n",
    "df_rare_novel_table.set_index(\"Model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_norare_novel_table = get_morph_results_table(dfnorare_novel, hints=8, vae=\"false\", score=\"F1\", markdown=False)\n",
    "df_norare_novel_table = df_norare_novel_table.reset_index(\"Model\")\n",
    "df_norare_novel_table[\"Model\"] = df_norare_novel_table[\"Model\"] .astype(str) \n",
    "df_norare_novel_table.set_index(\"Model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# DataFrame.append was removed in pandas 2.0; pd.concat works on old and new versions.\n",
    "morph_novel_rows = pd.DataFrame([df_rare_novel_table.iloc[4],\n",
    "                                 df_norare_novel_table.iloc[0],\n",
    "                                 df_rare_novel_table.iloc[0],\n",
    "                                 df_norare_novel_table.iloc[1],\n",
    "                                 df_rare_novel_table.iloc[1],\n",
    "                                 df_norare_novel_table.iloc[2],\n",
    "                                 df_rare_novel_table.iloc[2]])\n",
    "df_morph_novel = pd.concat([df_norare_novel_table[3:5], morph_novel_rows], ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(df_morph_novel.to_latex(index=False,float_format=\"{:0.2f}\", caption=\"Morphology Results\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# #options\n",
    "splits_s = [\"TEST EVALS\", \"TEST EASY\", \"VAL EVALS\"]\n",
    "splits_s_alt = [\"Future Tense\", \"Present Tense\", \"Past Tense\"]\n",
    "palette = {\"baseline\":\"grey\",\n",
    "          \"0proto\":\"lightsalmon\",\n",
    "          \"1proto\":\"salmon\",\n",
    "          \"2proto\":\"coral\",\n",
    "          \"geca\":\"cornflowerblue\"}\n",
    "def get_morph_results_graph(df, hints, vae, score):\n",
    "    \"\"\"Return the per-language rows for plotting plus an 'Average<score>' pseudo-language.\n",
    "\n",
    "    Rows are filtered to the given ``hints``/``vae`` and to the module-level\n",
    "    ``splits_s`` (relabelled with ``splits_s_alt``). The extra rows hold the\n",
    "    mean of ``score`` per (Model, Split, Seed), i.e. averaged over languages,\n",
    "    with ``Language`` set to ``'Average' + score``.\n",
    "    \"\"\"\n",
    "    cols = [\"Language\", \"Seed\", \"Model\", \"Split\", score]\n",
    "    #filter based on options (mask uses the original split labels)\n",
    "    mask = (df['Split'].isin(splits_s)) & (df['Vae'] == vae)  & (df[\"Hints\"]==hints)\n",
    "    df = df.replace(splits_s, splits_s_alt).loc[mask, cols]\n",
    "    # The score column may be object dtype when rows were inserted one by one;\n",
    "    # make it numeric so it can be averaged.\n",
    "    df = df.assign(**{score: pd.to_numeric(df[score])})\n",
    "    #aggregate to get the mean; select the score column explicitly because\n",
    "    #averaging the non-numeric Language column raises on pandas >= 2.0\n",
    "    agg = df.groupby(by=[\"Model\",\"Split\",\"Seed\"], as_index=False)[score].mean()\n",
    "    agg[\"Language\"] = \"Average\"+score\n",
    "    agg = agg[df.columns]\n",
    "    #new df with mean (DataFrame.append was removed in pandas 2.0)\n",
    "    return pd.concat([df, agg], ignore_index=True)\n",
    "\n",
    "def show_values_on_bars(axs):\n",
    "    def _show_on_single_plot(ax):        \n",
    "        for p in ax.patches:\n",
    "            _x = p.get_x() + p.get_width() / 2\n",
    "            _y = p.get_y() +  0.02\n",
    "            value = '{:.2f}'.format(p.get_height())\n",
    "            ax.text(_x, _y, value, ha=\"center\",rotation=\"vertical\",fontsize=12) \n",
    "\n",
    "    if isinstance(axs, np.ndarray):\n",
    "        for idx, ax in np.ndenumerate(axs):\n",
    "            _show_on_single_plot(ax)\n",
    "    else:\n",
    "        _show_on_single_plot(axs)\n",
    "        \n",
    "def get_morph_graph(df, hints, vae, score, ylim=0.0):\n",
    "    \"\"\"Draw one bar panel per language (plus the average) of ``score`` per split and model.\n",
    "\n",
    "    The data comes from ``get_morph_results_graph``; bars are annotated with\n",
    "    ``show_values_on_bars`` and the seaborn grid object is returned.\n",
    "    \"\"\"\n",
    "    plot_df = get_morph_results_graph(df, hints, vae, score).rename(columns={'Split':'Set'})\n",
    "    average_label = \"Average\"+score\n",
    "    grid = sns.catplot(data=plot_df,\n",
    "                       x=\"Set\",\n",
    "                       y=score,\n",
    "                       hue=\"Model\",\n",
    "                       col=\"Language\",\n",
    "                       col_order=[average_label, \"spanish\",\"turkish\",\"swahili\"],\n",
    "                       hue_order=[\"baseline\",\"geca\",\"0proto\",\"1proto\",\"2proto\"],\n",
    "                       kind=\"bar\",\n",
    "                       ci='sd',\n",
    "                       legend_out=True,\n",
    "                       palette=palette)\n",
    "\n",
    "    #fix labels and save\n",
    "    axes = grid.axes.flatten()\n",
    "    show_values_on_bars(axes)\n",
    "    axes[0].set_ylim(ylim,)\n",
    "    # one panel per col_order entry, in that order\n",
    "    for ax, title in zip(axes, (average_label, \"Spanish\", \"Turkish\", \"Swahili\")):\n",
    "        ax.set_title(title)\n",
    "    grid.set_xticklabels(rotation=15)\n",
    "    grid._legend.set_title(f\"hints: {hints}\\nvae: {vae}\")\n",
    "    return grid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "for hints in (4,8,16):\n",
    "    for vae in (\"true\",\"false\"):\n",
    "        for score in (\"Acc\",\"F1\"):\n",
    "            # catplot creates its own figure, so there is no plt.figure() here:\n",
    "            # it only produced an extra empty figure on every iteration.\n",
    "            g = get_morph_graph(dfrare, hints=hints, vae=vae, score=score)\n",
    "            #g.savefig(\"morph_results_{}_hints_{}_vae_{}_{}.pdf\".format(score,hints,vae,score), dpi=300, verbose=True)\n",
    "            plt.show()\n",
    "            plt.close(g.fig)  # release the figure; 18 figures at 300 dpi add up"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SCAN \n",
    "### Scores & Tables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_scan_scores(df=None,\n",
    "                    tasks=(\"jump\",\"around_right\"),\n",
    "                    seeds=(0,1,2,3,4),\n",
    "                    vaes =(\"true\",\"false\"),\n",
    "                    models=(\"0proto\",\"1proto\",\"2proto\"),\n",
    "                    exppath=\"./checkpoints\",\n",
    "                    reduced=True,\n",
    "                   ):\n",
    "    \"\"\"Score the ``.cond.log`` files of a SCAN experiment and append the rows to ``df``.\n",
    "\n",
    "    For every (task, seed, vae, model) combination the log\n",
    "    ``<exppath>/SCANDataSet/logs/<model>.vae.<vae>.<task>.seed.<seed>.cond.log``\n",
    "    is read and each split found by ``get_splits`` is scored with\n",
    "    ``calculate_scores``. ``df`` is modified in place; each appended row is\n",
    "    ``(task, split, seed, vae, model, acc, acc_std, f1, f1_std)``.\n",
    "    Missing files are reported and skipped.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``df`` is None.\n",
    "    \"\"\"\n",
    "    if df is None:\n",
    "        raise ValueError(\"get_scan_scores needs a DataFrame to append rows to (df=...)\")\n",
    "    taskpath = os.path.join(exppath,\"SCANDataSet\")\n",
    "    for task in tasks:\n",
    "        for seed in seeds:\n",
    "            for vae in vaes:\n",
    "                for model in models:\n",
    "                    identifier = \"{}.vae.{}.{}.seed.{}\".format(model,vae,task,seed)\n",
    "                    condfile = os.path.join(taskpath,\"logs\",identifier+\".cond.log\")\n",
    "                    if not os.path.exists(condfile):\n",
    "                        print(f\"file doesnot exist: {condfile}\")\n",
    "                        continue\n",
    "                    with open(condfile,'r') as f:  # close each log after reading\n",
    "                        lines = f.readlines()\n",
    "                    for (s,r) in get_splits(condfile).items():\n",
    "                        acc, accstd, f1, f1std = calculate_scores(lines[r], reduced=reduced)\n",
    "                        df.loc[len(df.index)] = (task,s,seed,vae,model,acc,accstd,f1,f1std)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_scan_results_table(df, markdown=False, baseline_csv=\"stats/scan-geca-baseline.csv\"):\n",
    "    \"\"\"Build a Model x Task table of ``mean_std`` accuracies for the SCAN runs.\n",
    "\n",
    "    Uses the 'TEST EVALS' rows of ``df`` with ``Vae == 'false'`` and adds the\n",
    "    rows of ``baseline_csv`` (no header; its first four columns are read as\n",
    "    task, seed, model, accuracy) before aggregating per (Model, Task).\n",
    "\n",
    "    Args:\n",
    "        df: scores as filled in by ``get_scan_scores``.\n",
    "        markdown: accepted for backward compatibility; not used.\n",
    "        baseline_csv: path of the CSV holding the externally computed scores.\n",
    "    \"\"\"\n",
    "    splits_s = [\"TEST EVALS\"]\n",
    "    score_s = \"Acc\"\n",
    "    cols_s = [\"Task\", \"Seed\", \"Model\", \"Split\", score_s]\n",
    "    vae_s = \"false\"\n",
    "\n",
    "    df = (df.loc[(df['Split'].isin(splits_s)) & (df['Vae'] == vae_s), cols_s]\n",
    "            .reset_index()\n",
    "            .drop(columns=['index', 'Split']))\n",
    "    # add geca and baseline scores\n",
    "    geca_baseline_s = pd.read_csv(baseline_csv, header=None)\n",
    "    for index,row in geca_baseline_s.iterrows():\n",
    "        task, seed, model, val = row.iloc[0:4]  # explicit positional slice\n",
    "        df.loc[len(df)] = [task, seed, model, float(val)]\n",
    "\n",
    "    df = (df.groupby(by=[\"Model\",\"Task\"])\n",
    "            .agg({score_s:mean_std})\n",
    "            .reset_index(\"Task\")\n",
    "            .pivot(columns=\"Task\")\n",
    "            .rename(columns={\"around_right\":\"AROUND RIGHT\", \"jump\":\"JUMP\"}))\n",
    "\n",
    "    df.columns = [' '.join(col).strip().title() for col in  df.columns.values]\n",
    "    df = df.reset_index('Model')\n",
    "\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#create SCAN table\n",
    "dfscannorare = pd.DataFrame(columns=(\"Task\",'Split', 'Seed', 'Vae','Model','Acc','Acc_std','F1','F1_std',))\n",
    "get_scan_scores(df=dfscannorare, exppath=\"./checkpoints.bak\",vaes=(\"false\",))\n",
    "dfscan_other_norare = pd.DataFrame(columns=(\"Task\",'Split', 'Seed', 'Vae','Model','Acc','Acc_std','F1','F1_std',))\n",
    "get_scan_scores(df=dfscan_other_norare,\n",
    "                 models=(\"2proto\",),\n",
    "                 tasks=(\"jump\",),\n",
    "                 exppath=\"./checkpoints_dgx_scan/\",\n",
    "                 vaes=(\"false\",),\n",
    "                )\n",
    "#results with seed 5-9\n",
    "dfscan_other_norare['Seed'] = dfscan_other_norare['Seed'] + 5\n",
    "dfscannorare=dfscannorare.append(dfscan_other_norare, ignore_index=True)\n",
    "dfscannorare"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "#create SCAN table\n",
    "dfscan = pd.DataFrame(columns=(\"Task\",'Split', 'Seed', 'Vae','Model','Acc','Acc_std','F1','F1_std',))\n",
    "get_scan_scores(df=dfscan, exppath=\"./checkpoints_scan_rare2\",vaes=(\"false\",))\n",
    "dfscan_other = pd.DataFrame(columns=(\"Task\",'Split', 'Seed', 'Vae','Model','Acc','Acc_std','F1','F1_std',))\n",
    "get_scan_scores(df=dfscan_other,\n",
    "                 models=(\"2proto\",),\n",
    "                 tasks=(\"jump\",),\n",
    "                 exppath=\"./checkpoints_scan_rare3/\",\n",
    "                 vaes=(\"false\",),\n",
    "                )\n",
    "#results with seed 5-9\n",
    "dfscan_other['Seed'] = dfscan_other['Seed'] + 5\n",
    "dfscan=dfscan.append(dfscan_other, ignore_index=True)\n",
    "dfscan"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "table_scannorare = get_scan_results_table(dfscannorare)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "table_scannorare"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# float_format takes a one-argument callable or a %-style string, so pass the bound str.format\n",
    "print(table_scannorare.to_latex(index=False, float_format=\"{:0.2f}\".format, caption=\"SCAN Experiments with no rare filtering\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Results table for dfscan, with a placeholder \"0proto\" row appended and the rows reordered.\n",
    "table_scanrare = get_scan_results_table(dfscan)\n",
    "# the appended row lands at the next integer label, i.e. the current row count\n",
    "table_scanrare.loc[table_scanrare.shape[0]] = (\"0proto\",\"NaN\",\"NaN\")\n",
    "# move the appended row (position 4) up to third place\n",
    "table_scanrare = table_scanrare.iloc[[0,1,4,2,3]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "table_scanrare"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# float_format takes a one-argument callable or a %-style string, so pass the bound str.format\n",
    "print(table_scanrare.to_latex(index=False, float_format=\"{:0.2f}\".format, caption=\"SCAN Experiments with rare filtering\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Significance Analyses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def merge_test_val(df):\n",
    "    \"\"\"Pair each run's TEST EVALS row with its VAL EVALS row and combine their scores.\n",
    "\n",
    "    Rows are matched on (Language, Hints, Seed, Vae, Model). The Acc_std and F1_std cells of\n",
    "    the two splits are combined with `+`; for list-valued cells this concatenates the lists\n",
    "    (NOTE(review): list-valued cells are assumed from how callers iterate them -- confirm).\n",
    "\n",
    "    Returns a frame holding the five key columns plus the combined Acc_std and F1_std.\n",
    "    \"\"\"\n",
    "    key_cols = [\"Language\",\"Hints\",\"Seed\",\"Vae\",\"Model\"]\n",
    "    indexed = df.set_index(key_cols)\n",
    "    test_rows = indexed[indexed['Split'] == 'TEST EVALS']\n",
    "    val_rows = indexed[indexed['Split'] == 'VAL EVALS']\n",
    "    # overlapping column names get a _test / _val suffix so both splits survive the join\n",
    "    joined = test_rows.join(val_rows, lsuffix='_test', rsuffix='_val')\n",
    "    for metric in ('Acc_std', 'F1_std'):\n",
    "        joined[metric] = joined[metric + '_test'] + joined[metric + '_val']\n",
    "    return joined[['Acc_std', 'F1_std']].reset_index()\n",
    "\n",
    "def get_others(df):\n",
    "    \"\"\"Return the key columns plus Acc_std and F1_std for the TEST EASY rows of df.\"\"\"\n",
    "    key_cols = [\"Language\",\"Hints\",\"Seed\",\"Vae\",\"Model\"]\n",
    "    indexed = df.set_index(key_cols)\n",
    "    easy_rows = indexed[indexed['Split'] == 'TEST EASY']\n",
    "    return easy_rows[['Acc_std', 'F1_std']].reset_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-language scores for the proto models and GECA, loaded with reduced=False.\n",
    "dfunreduced_rare = pd.DataFrame(columns=('Language', 'Hints', 'Seed', 'Vae','Model','Split','Acc','Acc_std','F1','F1_std',))\n",
    "get_lang_scores(\n",
    "    df=dfunreduced_rare,\n",
    "    exppath=\"./checkpoints_seperate_large\",\n",
    "    datapath=\"./data/SIGDataSet.large/\",\n",
    "    models=(\"0proto\",\"1proto\",\"2proto\"),\n",
    "    vaes=(\"false\",),\n",
    "    novel=False,\n",
    "    reduced=False,\n",
    ")\n",
    "get_lang_scores(\n",
    "    df=dfunreduced_rare,\n",
    "    exppath=\"./checkpoints_large_test_geca\",\n",
    "    datapath=\"./data/SIGDataSet.large/\",\n",
    "    models=(\"geca\",),\n",
    "    novel=False,\n",
    "    reduced=False,\n",
    ")\n",
    "# Suffix the model names so these runs stay distinguishable after merging with the no-rare frame.\n",
    "dfunreduced_rare['Model'] = dfunreduced_rare['Model'].astype(str) + ' +rare'\n",
    "dfunreduced_rare"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dfunreduced_rare.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-language scores from the \"norare\" checkpoints: proto models, baseline and GECA.\n",
    "dfunreduced_norare = pd.DataFrame(\n",
    "    columns=('Language', 'Hints', 'Seed', 'Vae','Model','Split','Acc','Acc_std','F1','F1_std',)\n",
    ")\n",
    "get_lang_scores(\n",
    "    df=dfunreduced_norare,\n",
    "    exppath=\"./checkpoints_morph_norare\",\n",
    "    datapath=\"./data/SIGDataSet.large/\",\n",
    "    models=(\"0proto\",\"1proto\",\"2proto\"),\n",
    "    vaes=(\"false\",),\n",
    "    novel=False,\n",
    "    reduced=False,\n",
    ")\n",
    "get_lang_scores(\n",
    "    df=dfunreduced_norare,\n",
    "    exppath=\"./checkpoints_large_test\",\n",
    "    datapath=\"./data/SIGDataSet.large/\",\n",
    "    models=(\"baseline\",),\n",
    "    novel=False,\n",
    "    reduced=False,\n",
    ")\n",
    "get_lang_scores(\n",
    "    df=dfunreduced_norare,\n",
    "    exppath=\"./checkpoints_large_test_geca_norare\",\n",
    "    datapath=\"./data/SIGDataSet.large/\",\n",
    "    models=(\"geca\",),\n",
    "    novel=False,\n",
    "    reduced=False,\n",
    ")\n",
    "dfunreduced_norare.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(dfunreduced_norare)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# DataFrame.append was removed in pandas 2.0; pd.concat stacks the two frames the same way.\n",
    "dfunreduced = pd.concat([dfunreduced_rare, dfunreduced_norare], ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dfunreduced"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter_for_eval(df, hints=8, vae=\"false\"):\n",
    "    \"\"\"Return the rows of df whose Hints equals `hints` and whose Vae equals `vae`.\"\"\"\n",
    "    hints_match = df[\"Hints\"] == hints\n",
    "    vae_match = df['Vae'] == vae\n",
    "    return df[hints_match & vae_match]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "df_merged = filter_for_eval(merge_test_val(dfunreduced))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_merged"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_others = filter_for_eval(get_others(dfunreduced))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_others "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "models=(\"baseline\",\"geca\",\"0proto\",\"1proto\",\"2proto\",\"geca +rare\",\"0proto +rare\",\"1proto +rare\",\"2proto +rare\")\n",
    "langs = (\"spanish\", \"turkish\", \"swahili\")\n",
    "hints=8\n",
    "vae=\"false\"\n",
    "cols = [\"Language\", \"Seed\", \"Model\", \"Acc_std\", \"F1_std\"]\n",
    "def get_significance_data(df, models,langs,hints,vae,cols):\n",
    "    \"\"\"Pool the per-seed score sequences of every (language, model, metric) over seeds 0-4.\n",
    "\n",
    "    For each language in `langs` and model in `models`, the first row of `df` matching the\n",
    "    model, language and seed supplies an iterable of scores for \"Acc_std\" and for \"F1_std\";\n",
    "    the five iterables are concatenated in seed order. Acc_std entries are cast to int,\n",
    "    F1_std entries are kept unchanged. `hints`, `vae` and `cols` are accepted but unused.\n",
    "\n",
    "    Returns a nested dict: data[language][model][metric] -> list of scores.\n",
    "    Raises IndexError when some (model, language, seed) combination has no row in `df`.\n",
    "    \"\"\"\n",
    "    def scores_for(lang, model, seed, metric):\n",
    "        # cell of the first row matching this model / language / seed\n",
    "        row_mask = (df[\"Model\"]==model) & (df[\"Language\"]==lang) & (df[\"Seed\"]==seed)\n",
    "        return df.loc[row_mask, metric].tolist()[0]\n",
    "\n",
    "    data = {}\n",
    "    for lang in langs:\n",
    "        by_model = {}\n",
    "        for model in models:\n",
    "            by_metric = {}\n",
    "            for metric in (\"Acc_std\", \"F1_std\"):\n",
    "                pooled = []\n",
    "                for seed in range(5):\n",
    "                    scores = scores_for(lang, model, seed, metric)\n",
    "                    if metric == \"Acc_std\":\n",
    "                        scores = [int(score) for score in scores]\n",
    "                    pooled.extend(scores)\n",
    "                by_metric[metric] = pooled\n",
    "            by_model[model] = by_metric\n",
    "        data[lang] = by_model\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pval_pstfut = get_significance_data(df_merged, models,langs,hints,vae,cols)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.mean(pval_pstfut[\"turkish\"]['baseline']['Acc_std'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(pval_pstfut[\"swahili\"][\"baseline\"][\"Acc_std\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_d_avg(pvals, models, langs):\n",
    "    \"\"\"Average the pooled scores of every (language, model, metric).\n",
    "\n",
    "    `pvals` is indexed as pvals[language][model][metric] with metric in {\"Acc_std\", \"F1_std\"}.\n",
    "    Returns a nested dict of the same shape whose leaves are np.mean of the corresponding list.\n",
    "    \"\"\"\n",
    "    d_avg = {}\n",
    "    for l in langs:\n",
    "        d_avg[l] = {}\n",
    "        for m in models:\n",
    "            d_avg[l][m] = {}\n",
    "            for t in [\"Acc_std\", \"F1_std\"]:\n",
    "                d_avg[l][m][t] = np.mean(pvals[l][m][t])\n",
    "    # previously the function fell off the end and returned None, so callers never saw the averages\n",
    "    return d_avg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "d_avg_pstfut = get_d_avg(pval_pstfut,models,langs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the get_d_avg result is stored in d_avg_pstfut; show its Turkish averages per model\n",
    "pd.DataFrame(d_avg_pstfut[\"turkish\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the get_d_avg result is stored in d_avg_pstfut; show its Spanish averages per model\n",
    "pd.DataFrame(d_avg_pstfut[\"spanish\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the get_d_avg result is stored in d_avg_pstfut; show its Swahili averages per model\n",
    "pd.DataFrame(d_avg_pstfut[\"swahili\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def t_test(d, model1, model2, t):\n",
    "    \"\"\"Return the p-value of scipy's related-samples t-test between two models' `t` scores.\"\"\"\n",
    "    first_scores = d[model1][t]\n",
    "    second_scores = d[model2][t]\n",
    "    outcome = stats.ttest_rel(first_scores, second_scores)\n",
    "    return outcome.pvalue\n",
    "\n",
    "def get_pvals(d, t):\n",
    "    \"\"\"Build the table of pairwise t-test p-values for metric `t`.\n",
    "\n",
    "    Every pair drawn from the notebook-level `models` tuple is compared with t_test on `d`.\n",
    "    Entries above the diagonal are blanked, and NaN cells are shown as empty strings.\n",
    "    \"\"\"\n",
    "    sign = {\n",
    "        m1: {m2: t_test(d,m1,m2,t) for m2 in models}\n",
    "        for m1 in models\n",
    "    }\n",
    "    lower_triangle = remove_upper_diagonal(pd.DataFrame(sign))\n",
    "    return lower_triangle.replace(np.nan, '', regex=True)\n",
    "\n",
    "def remove_upper_diagonal(df):\n",
    "    \"\"\"Keep the lower triangle of df (diagonal included) and set every entry above it to NaN.\"\"\"\n",
    "    # np.bool was removed in NumPy 1.24; the builtin bool yields the same boolean mask\n",
    "    return df.where(np.tril(np.ones(df.shape)).astype(bool))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"acc\")       \n",
    "display(get_pvals(pval_pstfut[\"turkish\"],\"Acc_std\"))\n",
    "print(\"f1\")\n",
    "display(get_pvals(pval_pstfut[\"turkish\"],\"F1_std\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"acc\")       \n",
    "display(get_pvals(pval_pstfut[\"spanish\"],\"Acc_std\"))\n",
    "print(\"f1\")\n",
    "display(get_pvals(pval_pstfut[\"spanish\"],\"F1_std\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"acc\")       \n",
    "display(get_pvals(pval_pstfut[\"swahili\"],\"Acc_std\"))\n",
    "print(\"f1\")\n",
    "display(get_pvals(pval_pstfut[\"swahili\"],\"F1_std\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(get_pvals(pval_pstfut[\"turkish\"],\"F1_std\").to_latex(caption=\"Turkish F1 Significance\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(get_pvals(pval_pstfut[\"spanish\"],\"F1_std\").to_latex(caption=\"Spanish F1 Significance\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(get_pvals(pval_pstfut[\"swahili\"],\"F1_std\").to_latex(caption=\"Swahili F1 Significance\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pval_prs =  get_significance_data(df_others, models,langs,hints,vae,cols)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(get_pvals(pval_prs[\"turkish\"],\"F1_std\").to_latex(caption=\"Turkish F1 Significance\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(get_pvals(pval_prs[\"spanish\"],\"F1_std\").to_latex(caption=\"Spanish F1 Significance\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(get_pvals(pval_prs[\"swahili\"],\"F1_std\").to_latex(caption=\"Swahili F1 Significance\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Ablations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect SCAN scores for the ablation models from ./checkpoints_ablations/.\n",
    "dfscan_ablations = pd.DataFrame(\n",
    "    columns=(\"Task\",'Split', 'Seed', 'Vae','Model','Acc','Acc_std','F1','F1_std',)\n",
    ")\n",
    "get_scan_scores(\n",
    "    df=dfscan_ablations,\n",
    "    models=(\"ID.1proto\",\"nocopy.1proto\", \"nocopy.2proto\"),\n",
    "    exppath=\"./checkpoints_ablations/\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dfscan_ablations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# def mean_std(x):\n",
    "#     return(str(round(np.mean(x),2))+\" (${+-}\"+str(round(np.std(x),2))+\"$)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# def get_scan_scores_unreduced(df=None,\n",
    "#                     tasks=(\"jump\",\"around_right\"),\n",
    "#                     seeds=(0,1,2,3,4),\n",
    "#                     vaes =(\"true\",\"false\"),\n",
    "#                     models=(\"0proto\",\"1proto\",\"2proto\"),\n",
    "#                     exppath=\"./checkpoints\",\n",
    "#                    ):\n",
    "#     for task in tasks:\n",
    "#             for seed in seeds:\n",
    "#                 for vae in vaes:\n",
    "#                     for model in  models:\n",
    "#                         taskpath=os.path.join(exppath,\"SCANDataSet\")\n",
    "#                         identifier =\"{}.vae.{}.{}.seed.{}\".format(model,vae,task,seed)\n",
    "#                         condfile=os.path.join(taskpath,\"logs\",identifier+\".cond.log\") \n",
    "#                         if os.path.exists(condfile):\n",
    "#                             lines  = open(condfile,'r').readlines()\n",
    "#                             print(\"processing: \"+condfile)\n",
    "#                             for (s,r) in get_splits(condfile).items():\n",
    "#                                 acc, accstd, f1, f1std = calculate_scores_unreduced(lines[r])  \n",
    "#                                 df.loc[len(df.index)] = (task,s,seed,vae,model,acc,accstd,f1,f1std)\n",
    "#                         else:\n",
    "#                             print(f\"file doesnot exist: {condfile}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# def get_morph_results_table_alt_std_deprecated(df, hints=4, vae=\"false\", lang=\"average\", markdown=False):\n",
    "#     splits_s = [\"TEST EVALS\", \"VAL EVALS\", \"TEST EASY\"]\n",
    "#     splits_s_alt = [\"Future Tense\", \"Past Tense\", \"Present Tense\"]\n",
    "#     score_s = [\"Acc\", \"F1\"]\n",
    "#     cols_s = [\"Language\", \"Seed\", \"Model\", \"Split\"] + score_s\n",
    "\n",
    "#     df = df.replace(splits_s, splits_s_alt).\\\n",
    "#                     loc[(df['Split'].isin(splits_s)) & (df['Vae'] == vae) & (df['Hints'] == hints), cols_s].\\\n",
    "#                     reset_index().\\\n",
    "#                     drop(columns=['index'])\n",
    "\n",
    "#     func = my_mean if markdown else mean_std\n",
    "#     if lang == \"average\":\n",
    "#         df = df.groupby(by=[\"Model\",\"Split\", \"Seed\"]).\\\n",
    "#                     agg({\"Acc\":\"mean\", \"F1\":\"mean\"}).\\\n",
    "#                     reset_index().\\\n",
    "#                     groupby(by=[\"Model\",\"Split\"]).\\\n",
    "#                     agg({'Acc':func, 'F1':func}).\\\n",
    "#                     reset_index(\"Split\").pivot(columns=\"Split\")\n",
    "#     else:\n",
    "#         df= df[df[\"Language\"]==lang].groupby(by=[\"Model\",\"Split\", \"Seed\"]).\\\n",
    "#             agg({\"Acc\":\"mean\", \"F1\":\"mean\"}).\\\n",
    "#             reset_index().\\\n",
    "#             groupby(by=[\"Model\",\"Split\"]).\\\n",
    "#             agg({'Acc':func, 'F1':func}).\\\n",
    "#             reset_index(\"Split\").pivot(columns=\"Split\")\n",
    "        \n",
    "\n",
    "#     df.columns = df.columns.swaplevel(1,0)\n",
    "#     #return(df_mean_std.iloc[[3,4,0,1,2,],[0,3,1,4,2,5]])\n",
    "#     return(df.iloc[[0,1,2],[0,3,1,4,2,5]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# #TODO: implement novel_tags\n",
    "# def calculate_scores_unreduced_deprecated(lines):\n",
    "#     start=0 \n",
    "#     finished = False\n",
    "#     preds, tps, fps, fns, f1s, novel, sorter = [],[],[],[],[],[],[]\n",
    "#     while not finished:\n",
    "#         data, start = find_next_matching_block(lines,start)\n",
    "#         if data is not None:\n",
    "#             inp, ref, pred_here = data\n",
    "#             tp = (len([p for p in pred_here if p in ref]))\n",
    "#             fp = (len([p for p in pred_here if p not in ref]))\n",
    "#             fn = (len([p for p in ref if p not in pred_here]))\n",
    "#             prec = tp / (tp + fp)\n",
    "#             rec = tp / (tp + fn)\n",
    "#             if prec == 0 or rec == 0:\n",
    "#                 f1 = 0\n",
    "#             else:\n",
    "#                 f1 = 2 * prec * rec / (prec + rec)\n",
    "#             sorter.append(inp+\"\\t\".join(ref))\n",
    "#             f1s.append(f1)\n",
    "#             tps.append(tp)\n",
    "#             fps.append(fp)\n",
    "#             fns.append(fn)\n",
    "#             preds.append(pred_here == ref)\n",
    "#         else:\n",
    "#             finished = True\n",
    "#     f1s = sort_by(f1s,sorter)\n",
    "#     tps = sort_by(tps,sorter)\n",
    "#     fps = sort_by(fps,sorter)\n",
    "#     fns = sort_by(fns,sorter)\n",
    "#     preds = sort_by(preds,sorter)\n",
    "#     tp, fp, fn = np.sum(tps), np.sum(fps), np.sum(fns)\n",
    "#     prec = tp / (tp + fp)\n",
    "#     rec = tp / (tp + fn)\n",
    "#     if prec == 0 or rec == 0:\n",
    "#         f1 = 0\n",
    "#     else:\n",
    "#         f1 = 2 * prec * rec / (prec + rec)\n",
    "#     return np.mean(preds), preds, f1, f1s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# def get_lang_scores_unreduced(df=None,\n",
    "#                     langs=(\"spanish\",\"turkish\",\"swahili\"),\n",
    "#                     hintss=(4,8,16),\n",
    "#                     seeds=(0,1,2,3,4),\n",
    "#                     vaes =(\"true\",\"false\"),\n",
    "#                     models=(\"baseline\",\"0proto\",\"1proto\",\"2proto\"),\n",
    "#                     exppath=\"./checkpoints\",\n",
    "#                     datapath=\"data/SIGDataSet.large\",\n",
    "#                     novel=False,\n",
    "#                    ):\n",
    "#     train_tags=None\n",
    "#     for lang in langs:\n",
    "#         for hints in hintss:\n",
    "#             for seed in seeds:\n",
    "#                 if novel:\n",
    "#                     train_tags = get_tags(datapath + lang + '/', hints=hints,seed=seed)\n",
    "#                 for vae in vaes:\n",
    "#                     for model in  models:\n",
    "#                         langpath=os.path.join(exppath,\"SIGDataSet\",lang)\n",
    "#                         if model == \"baseline\" or model == \"geca\":\n",
    "#                             identifier =\"{}.hints.{}.seed.{}\".format(model,hints,seed)\n",
    "#                         else:\n",
    "#                             identifier =\"{}.vae.{}.hints.{}.seed.{}\".format(model,vae,hints,seed)\n",
    "#                         condfile=os.path.join(langpath,\"logs\",identifier+\".cond.log\") \n",
    "#                         if os.path.exists(condfile):\n",
    "#                             lines  = open(condfile,'r').readlines()\n",
    "#                             if len(lines) < 2142:\n",
    "#                                 print(\"format broken in \"+condfile)\n",
    "#                                 continue\n",
    "# #                             print(\"processing: \"+condfile)\n",
    "#                             for (s,r) in get_splits(condfile).items():\n",
    "#                                 acc, accstd, f1, f1std = calculate_scores(lines[r], train_tags=train_tags, reduced=False)  \n",
    "#                                 df.loc[len(df.index)] = (lang,hints,seed,vae,model,s,acc,accstd,f1,f1std)\n",
    "#                         else:\n",
    "#                             print(f\"file doesnot exist: {condfile}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
