{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import networkx as nx\n",
    "import json\n",
    "import os\n",
    "from os import path\n",
    "from addict import Dict\n",
    "import sys\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "sys.path.append(os.path.expanduser('~/mlp/lgw/'))\n",
    "from lgw.graph_generator import GraphGen, BigGraphGen\n",
    "from lgw.args import get_args\n",
    "from codes.utils.inspect_utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# GraphLog data location. Set the LGW_DATA_DIR environment variable to point\n",
    "# elsewhere instead of editing this cell; '~' is expanded in either case.\n",
    "data_loc = os.path.expanduser(os.environ.get('LGW_DATA_DIR', '~/checkpoint/lgw/data'))\n",
    "data_name = 'comp_r10_n100_ov'\n",
    "loc = os.path.join(data_loc, data_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "51"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Each split directory holds one sub-folder per rule world; only folders that\n",
    "# contain a config.json are complete worlds. The listing is sorted because\n",
    "# os.listdir order is filesystem-dependent, which made world order irreproducible.\n",
    "all_paths = {}\n",
    "modes = [\"train\", \"valid\", \"test\"]\n",
    "for mode in modes:\n",
    "    data_path = os.path.join(loc, mode)\n",
    "    all_paths[mode] = [\n",
    "        folder\n",
    "        for folder in sorted(os.listdir(data_path))\n",
    "        if os.path.isdir(os.path.join(data_path, folder))\n",
    "        and os.path.exists(os.path.join(data_path, folder, 'config.json'))\n",
    "    ]\n",
    "len(all_paths['train'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load every rule world, keyed first by split and then by world folder name.\n",
    "worlds = {\n",
    "    mode: {\n",
    "        rule_world: load_world(get_paths(mode, rule_world, loc))\n",
    "        for rule_world in rule_worlds\n",
    "    }\n",
    "    for mode, rule_worlds in all_paths.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relation pattern of every training graph, pooled across all splits and worlds.\n",
    "all_patterns = [\n",
    "    get_rel_pattern(graph)\n",
    "    for split_worlds in worlds.values()\n",
    "    for world in split_worlds.values()\n",
    "    for graph in world.train\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Stats methods\n",
    "def get_graph_stats(world):\n",
    "    \"\"\"Mean node count, edge count, average in-degree and average out-degree.\n",
    "\n",
    "    Averages over every networkx graph produced by `load_networkx_graphs` for the\n",
    "    train/valid/test splits of `world`. The statistics are computed directly from\n",
    "    each graph instead of parsing `nx.info` text, whose format changed in later\n",
    "    networkx releases and which was removed in networkx 3.0.\n",
    "    Assumes directed graphs (uses `in_degree` / `out_degree`).\n",
    "    \"\"\"\n",
    "    num_nodes = []\n",
    "    num_edges = []\n",
    "    in_deg = []\n",
    "    out_deg = []\n",
    "    for mode in ['train', 'valid', 'test']:\n",
    "        nx_graphs, _ = load_networkx_graphs(world[mode])\n",
    "        for nxg in nx_graphs:\n",
    "            n_nodes = nxg.number_of_nodes()\n",
    "            num_nodes.append(n_nodes)\n",
    "            num_edges.append(nxg.number_of_edges())\n",
    "            # Guard empty graphs so the per-node averages cannot divide by zero.\n",
    "            in_deg.append(sum(d for _, d in nxg.in_degree()) / n_nodes if n_nodes else 0.0)\n",
    "            out_deg.append(sum(d for _, d in nxg.out_degree()) / n_nodes if n_nodes else 0.0)\n",
    "    return np.mean(num_nodes), np.mean(num_edges), np.mean(in_deg), np.mean(out_deg)\n",
    "\n",
    "def get_simple_graph_stats(world):\n",
    "    \"\"\"Summarise node and edge counts over the train/valid/test graphs of `world`.\n",
    "\n",
    "    The node count of a graph is the number of distinct endpoints among its edges\n",
    "    (the first two fields of every edge). Returns the string\n",
    "    \"<mean nodes>-<std nodes>,<mean edges>-<std edges>\" with means rounded to 3\n",
    "    decimals and standard deviations to 2; later cells split it on ',' and '-'.\n",
    "    \"\"\"\n",
    "    node_counts, edge_counts = [], []\n",
    "    for split in ('train', 'valid', 'test'):\n",
    "        for graph in world[split]:\n",
    "            edges = graph['edges']\n",
    "            edge_counts.append(len(edges))\n",
    "            endpoints = {endpoint for edge in edges for endpoint in edge[:2]}\n",
    "            node_counts.append(len(endpoints))\n",
    "    nodes_summary = \"{}-{}\".format(round(np.mean(node_counts), 3), round(np.std(node_counts), 2))\n",
    "    edges_summary = \"{}-{}\".format(round(np.mean(edge_counts), 3), round(np.std(edge_counts), 2))\n",
    "    return nodes_summary + \",\" + edges_summary"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset Statistics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per rule world: class/descriptor counts, resolution length and the\n",
    "# packed \"mean-std\" node/edge summaries from get_simple_graph_stats.\n",
    "world_modes = ['train', 'valid', 'test']\n",
    "world_ids, world_type = [], []\n",
    "num_class, num_des = [], []\n",
    "avg_resolution_length = []\n",
    "num_nodes, num_edges = [], []\n",
    "for world_mode in world_modes:\n",
    "    for world_id, world in worlds[world_mode].items():\n",
    "        world_ids.append(world_id)\n",
    "        num_class.append(len(get_class(world)))\n",
    "        num_des.append(len(get_descriptors(world)))\n",
    "        world_type.append(world_mode)\n",
    "        avg_resolution_length.append(get_avg_resolution_length(world))\n",
    "        summary_parts = get_simple_graph_stats(world).split(',')\n",
    "        num_nodes.append(summary_parts[0])\n",
    "        num_edges.append(summary_parts[1])\n",
    "graphlog_stats = pd.DataFrame({\n",
    "    'world_id': world_ids,\n",
    "    # Trailing integer of the folder name, e.g. 'rule_12' -> 12, used for sorting.\n",
    "    'world_id_num': [int(w.split('_')[-1]) for w in world_ids],\n",
    "    'num_class': num_class,\n",
    "    'ND': num_des,\n",
    "    'Average Resolution Length': avg_resolution_length,\n",
    "    'Split': world_type,\n",
    "    'Average Nodes': num_nodes,\n",
    "    'Average Edges': num_edges,\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "## post processing\n",
    "# get_simple_graph_stats packs each summary as a \"mean-std\" string; split it into\n",
    "# a numeric mean column plus a separate std column. The dtype guard and the\n",
    "# 'edge_to_noise_ratio' check make this cell safe to re-run (a second pass would\n",
    "# otherwise call .split on floats and crash).\n",
    "for mean_col, std_col in [('Average Nodes', 'nodes-std'), ('Average Edges', 'edges-std')]:\n",
    "    if not pd.api.types.is_numeric_dtype(graphlog_stats[mean_col]):\n",
    "        graphlog_stats[std_col] = graphlog_stats[mean_col].apply(lambda x: float(x.split('-')[-1]))\n",
    "        graphlog_stats[mean_col] = graphlog_stats[mean_col].apply(lambda x: float(x.split('-')[0]))\n",
    "# Ratio of resolution length to total edges, computed once from the unrounded\n",
    "# resolution length so a re-run does not recompute it from the rounded values.\n",
    "if 'edge_to_noise_ratio' not in graphlog_stats.columns:\n",
    "    graphlog_stats['edge_to_noise_ratio'] = graphlog_stats['Average Resolution Length'] / graphlog_stats['Average Edges']\n",
    "graphlog_stats['Average Resolution Length'] = graphlog_stats['Average Resolution Length'].apply(lambda x: round(x, 2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Add Supervised Learning Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Add supervised learning results\n",
    "# Map the model class names found in the result CSVs to short display names.\n",
    "clean_model_names = {\n",
    "    \"GatedNodeGatEncoder\": \"GAT\",\n",
    "    \"GatedGatEncoder\": \"E-GAT\",\n",
    "    \"RepresentationGCNEncoder\": \"GCN\",\n",
    "    \"CompositionRGCNEncoder\": \"RGCN\",\n",
    "    \"Param\": \"Param\",\n",
    "}\n",
    "\n",
    "def _add_clean_model_names(results):\n",
    "    \"\"\"Add `rep_fn` / `comp_fn` columns with the short names of the representation\n",
    "    and composition classes (last dotted component of their paths); returns `results`.\"\"\"\n",
    "    results['rep_fn'] = results.model_representation_fn_path.apply(lambda x: clean_model_names[x.split('.')[-1]])\n",
    "    results['comp_fn'] = results.model_composition_fn_path.apply(lambda x: clean_model_names[x.split('.')[-1]])\n",
    "    return results\n",
    "\n",
    "supervised_all = _add_clean_model_names(pd.read_csv('raw_data/supervised_result_complete.csv'))\n",
    "supervised_test = _add_clean_model_names(pd.read_csv('raw_data/supervised_result_complete_test.csv'))\n",
    "supervised_valid = _add_clean_model_names(pd.read_csv('raw_data/supervised_result_complete_valid.csv'))\n",
    "\n",
    "# Held-out worlds are evaluated in their own result files, and each file stores\n",
    "# the relevant accuracy under a different column name.\n",
    "TEST_RULES = ['rule_54', 'rule_55', 'rule_56']\n",
    "VALID_RULES = ['rule_51', 'rule_52', 'rule_53']\n",
    "\n",
    "def get_baseline(rep_fn, comp_fn, rule):\n",
    "    \"\"\"Return the supervised accuracy of model `rep_fn`-`comp_fn` on rule world `rule`.\n",
    "\n",
    "    Test worlds read `test_test_accuracy` from `supervised_test`, valid worlds read\n",
    "    `valid_test_accuracy` from `supervised_valid`, and all other worlds read\n",
    "    `train_test_accuracy` from `supervised_all`. If several rows match, the first\n",
    "    is used. Raises IndexError if no row matches.\n",
    "    \"\"\"\n",
    "    if rule in TEST_RULES:\n",
    "        results, acc_col = supervised_test, 'test_test_accuracy'\n",
    "    elif rule in VALID_RULES:\n",
    "        results, acc_col = supervised_valid, 'valid_test_accuracy'\n",
    "    else:\n",
    "        results, acc_col = supervised_all, 'train_test_accuracy'\n",
    "    # Select the accuracy column directly; assigning into a filtered slice\n",
    "    # triggers SettingWithCopyWarning.\n",
    "    match = results.loc[(results.rep_fn == rep_fn) & (results.comp_fn == comp_fn) & (results.general_train_rule == rule), acc_col]\n",
    "    if match.empty:\n",
    "        raise IndexError(f'no supervised result for rep_fn={rep_fn!r}, comp_fn={comp_fn!r}, rule={rule!r}')\n",
    "    return match.values[0]\n",
    "\n",
    "model_order = [\"GAT-E-GAT\", \"GCN-E-GAT\", \"Param-E-GAT\", \"GAT-RGCN\", \"GCN-RGCN\", \"Param-RGCN\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Short labels (M1..M6) used when reporting which baseline wins on each world.\n",
    "shortened_model_names = dict(zip(\n",
    "    ['GAT-E-GAT', 'GCN-E-GAT', 'Param-E-GAT', 'GAT-RGCN', 'GCN-RGCN', 'Param-RGCN'],\n",
    "    ['M1', 'M2', 'M3', 'M4', 'M5', 'M6'],\n",
    "))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One column per baseline, holding its supervised accuracy on each rule world.\n",
    "for model in model_order:\n",
    "    # e.g. \"Param-E-GAT\" -> rep_fn \"Param\", comp_fn \"E-GAT\" (split on the first '-' only).\n",
    "    rep_fn, comp_fn = model.split('-', 1)\n",
    "    graphlog_stats[model] = graphlog_stats['world_id'].apply(\n",
    "        lambda rule: round(get_baseline(rep_fn, comp_fn, rule), 3)\n",
    "    ).astype(float)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Set Difficulty\n",
    "# Bucket each world by the accuracy of the GAT-E-GAT baseline:\n",
    "# >= 0.7 is Easy, >= 0.54 is Medium, anything lower is Hard.\n",
    "def _difficulty(accuracy):\n",
    "    if accuracy >= 0.7:\n",
    "        return \"Easy\"\n",
    "    if accuracy >= 0.54:\n",
    "        return \"Medium\"\n",
    "    return \"Hard\"\n",
    "\n",
    "graphlog_stats['Difficulty'] = graphlog_stats[\"GAT-E-GAT\"].apply(_difficulty)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Columns of the LaTeX summary: dataset statistics, then one accuracy column per baseline.\n",
    "display_columns = [\n",
    "    'world_id', 'num_class', 'ND', 'Split',\n",
    "    'Average Resolution Length', 'Average Nodes', 'Average Edges', 'Difficulty',\n",
    "    *model_order,\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hard training worlds, ordered by rule number. Filter first, then sort, so the\n",
    "# boolean mask indexes the same frame it was built from (indexing a sorted frame\n",
    "# with a mask from the unsorted one triggers pandas' reindexing UserWarning).\n",
    "hard_train_mask = (graphlog_stats[\"Split\"] == 'train') & (graphlog_stats[\"Difficulty\"] == \"Hard\")\n",
    "graphlog_stats[hard_train_mask].sort_values(by='world_id_num')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lrrlrrrlrrrrrr}\n",
      "\\toprule\n",
      "world\\_id &  num\\_class &    ND &  Split &  Average Resolution Length &  Average Nodes &  Average Edges & Difficulty &  GAT-E-GAT &  GCN-E-GAT &  Param-E-GAT &  GAT-RGCN &  GCN-RGCN &  Param-RGCN \\\\\n",
      "\\midrule\n",
      "  rule\\_0 &         17 &   286 &  train &                       4.49 &         15.487 &         19.295 &       Hard &      0.481 &      0.500 &        0.494 &     0.486 &     0.462 &       0.462 \\\\\n",
      "  rule\\_1 &         15 &   239 &  train &                       4.10 &         11.565 &         13.615 &       Hard &      0.432 &      0.411 &        0.428 &     0.406 &     0.400 &       0.408 \\\\\n",
      "  rule\\_2 &         17 &   157 &  train &                       3.21 &          9.809 &         11.165 &       Hard &      0.412 &      0.357 &        0.373 &     0.347 &     0.347 &       0.319 \\\\\n",
      "  rule\\_3 &         16 &   189 &  train &                       3.63 &         11.137 &         13.273 &       Hard &      0.429 &      0.404 &        0.473 &     0.373 &     0.401 &       0.451 \\\\\n",
      "  rule\\_4 &         16 &   189 &  train &                       3.94 &         12.622 &         15.501 &     Medium &      0.624 &      0.606 &        0.619 &     0.475 &     0.481 &       0.595 \\\\\n",
      "  rule\\_5 &         14 &   275 &  train &                       4.41 &         14.545 &         18.872 &       Hard &      0.526 &      0.539 &        0.548 &     0.429 &     0.461 &       0.455 \\\\\n",
      "  rule\\_6 &         16 &   249 &  train &                       5.06 &         16.257 &         20.164 &       Hard &      0.528 &      0.514 &        0.536 &     0.498 &     0.495 &       0.476 \\\\\n",
      "  rule\\_7 &         17 &   288 &  train &                       4.47 &         13.161 &         16.333 &     Medium &      0.613 &      0.558 &        0.598 &     0.487 &     0.486 &       0.537 \\\\\n",
      "  rule\\_8 &         15 &   404 &  train &                       5.43 &         15.997 &         19.134 &     Medium &      0.627 &      0.643 &        0.629 &     0.523 &     0.563 &       0.569 \\\\\n",
      "  rule\\_9 &         19 &  1011 &  train &                       7.22 &         24.151 &         32.668 &       Easy &      0.758 &      0.744 &        0.739 &     0.683 &     0.651 &       0.623 \\\\\n",
      " rule\\_10 &         18 &   524 &  train &                       5.87 &         18.011 &         22.202 &     Medium &      0.656 &      0.654 &        0.663 &     0.596 &     0.563 &       0.605 \\\\\n",
      " rule\\_11 &         17 &   194 &  train &                       4.29 &         11.459 &         13.037 &     Medium &      0.552 &      0.525 &        0.533 &     0.445 &     0.456 &       0.419 \\\\\n",
      " rule\\_12 &         15 &   306 &  train &                       4.14 &         11.238 &         12.919 &       Easy &      0.771 &      0.726 &        0.603 &     0.511 &     0.561 &       0.523 \\\\\n",
      " rule\\_13 &         16 &   149 &  train &                       3.58 &         11.238 &         13.549 &       Hard &      0.453 &      0.402 &        0.419 &     0.347 &     0.298 &       0.344 \\\\\n",
      " rule\\_14 &         16 &   224 &  train &                       4.14 &         11.371 &         13.403 &       Hard &      0.448 &      0.457 &        0.401 &     0.314 &     0.318 &       0.332 \\\\\n",
      " rule\\_15 &         14 &   224 &  train &                       3.82 &         12.661 &         15.105 &       Hard &      0.494 &      0.423 &        0.501 &     0.402 &     0.397 &       0.435 \\\\\n",
      " rule\\_16 &         16 &   205 &  train &                       3.59 &         11.345 &         13.293 &       Hard &      0.318 &      0.332 &        0.292 &     0.328 &     0.306 &       0.291 \\\\\n",
      " rule\\_17 &         17 &   147 &  train &                       3.16 &          8.163 &          8.894 &       Hard &      0.347 &      0.308 &        0.274 &     0.164 &     0.161 &       0.181 \\\\\n",
      " rule\\_18 &         18 &   923 &  train &                       6.63 &         25.035 &         33.080 &       Easy &      0.700 &      0.680 &        0.713 &     0.650 &     0.641 &       0.618 \\\\\n",
      " rule\\_19 &         16 &   416 &  train &                       6.10 &         17.180 &         20.818 &       Easy &      0.790 &      0.774 &        0.777 &     0.731 &     0.729 &       0.702 \\\\\n",
      " rule\\_20 &         20 &  2024 &  train &                       8.63 &         34.059 &         45.985 &       Easy &      0.830 &      0.799 &        0.854 &     0.756 &     0.741 &       0.750 \\\\\n",
      " rule\\_21 &         13 &   272 &  train &                       4.58 &         10.559 &         11.754 &     Medium &      0.621 &      0.610 &        0.632 &     0.531 &     0.516 &       0.580 \\\\\n",
      " rule\\_22 &         17 &   422 &  train &                       5.21 &         16.540 &         20.681 &     Medium &      0.586 &      0.593 &        0.628 &     0.530 &     0.506 &       0.573 \\\\\n",
      " rule\\_23 &         15 &   383 &  train &                       4.97 &         17.067 &         21.111 &       Hard &      0.508 &      0.522 &        0.493 &     0.455 &     0.473 &       0.476 \\\\\n",
      " rule\\_24 &         18 &   879 &  train &                       6.33 &         21.402 &         26.152 &       Easy &      0.706 &      0.704 &        0.743 &     0.656 &     0.641 &       0.638 \\\\\n",
      " rule\\_25 &         15 &   278 &  train &                       3.84 &         11.093 &         12.775 &       Hard &      0.424 &      0.419 &        0.382 &     0.358 &     0.345 &       0.412 \\\\\n",
      " rule\\_26 &         15 &   352 &  train &                       4.71 &         14.157 &         17.115 &     Medium &      0.565 &      0.534 &        0.532 &     0.466 &     0.461 &       0.499 \\\\\n",
      " rule\\_27 &         16 &   393 &  train &                       4.98 &         14.296 &         16.499 &       Easy &      0.713 &      0.714 &        0.722 &     0.632 &     0.604 &       0.647 \\\\\n",
      " rule\\_28 &         16 &   391 &  train &                       4.82 &         17.551 &         21.897 &     Medium &      0.575 &      0.564 &        0.571 &     0.503 &     0.499 &       0.552 \\\\\n",
      " rule\\_29 &         16 &   144 &  train &                       3.87 &         10.193 &         11.774 &       Hard &      0.468 &      0.445 &        0.475 &     0.325 &     0.336 &       0.389 \\\\\n",
      " rule\\_30 &         17 &   177 &  train &                       3.51 &         10.270 &         11.764 &       Hard &      0.381 &      0.426 &        0.382 &     0.357 &     0.316 &       0.336 \\\\\n",
      " rule\\_31 &         19 &   916 &  train &                       5.90 &         20.147 &         26.562 &       Easy &      0.788 &      0.789 &        0.770 &     0.669 &     0.674 &       0.641 \\\\\n",
      " rule\\_32 &         16 &   287 &  train &                       4.66 &         16.270 &         20.929 &     Medium &      0.674 &      0.671 &        0.700 &     0.621 &     0.594 &       0.615 \\\\\n",
      " rule\\_33 &         18 &   312 &  train &                       4.50 &         14.738 &         18.266 &     Medium &      0.695 &      0.660 &        0.709 &     0.710 &     0.679 &       0.668 \\\\\n",
      " rule\\_34 &         18 &   504 &  train &                       5.00 &         15.345 &         18.614 &       Easy &      0.908 &      0.888 &        0.906 &     0.768 &     0.762 &       0.811 \\\\\n",
      " rule\\_35 &         19 &   979 &  train &                       6.23 &         21.867 &         28.266 &       Easy &      0.831 &      0.750 &        0.782 &     0.680 &     0.700 &       0.662 \\\\\n",
      " rule\\_36 &         19 &   252 &  train &                       4.66 &         13.900 &         16.613 &       Easy &      0.742 &      0.698 &        0.698 &     0.659 &     0.627 &       0.651 \\\\\n",
      " rule\\_37 &         17 &   260 &  train &                       4.00 &         11.956 &         14.010 &       Easy &      0.843 &      0.826 &        0.826 &     0.673 &     0.698 &       0.716 \\\\\n",
      " rule\\_38 &         17 &   568 &  train &                       5.21 &         15.305 &         20.075 &       Easy &      0.748 &      0.762 &        0.733 &     0.644 &     0.630 &       0.719 \\\\\n",
      " rule\\_39 &         15 &   182 &  train &                       3.98 &         12.552 &         14.800 &       Easy &      0.737 &      0.642 &        0.635 &     0.592 &     0.603 &       0.587 \\\\\n",
      " rule\\_40 &         17 &   181 &  train &                       3.69 &         11.556 &         14.437 &     Medium &      0.552 &      0.584 &        0.575 &     0.525 &     0.472 &       0.479 \\\\\n",
      " rule\\_41 &         15 &   113 &  train &                       3.58 &         10.162 &         11.553 &     Medium &      0.619 &      0.601 &        0.626 &     0.490 &     0.468 &       0.470 \\\\\n",
      " rule\\_42 &         14 &    95 &  train &                       2.96 &          8.939 &          9.751 &       Hard &      0.511 &      0.472 &        0.483 &     0.386 &     0.393 &       0.395 \\\\\n",
      " rule\\_43 &         16 &   162 &  train &                       3.36 &         11.077 &         13.337 &     Medium &      0.622 &      0.567 &        0.579 &     0.473 &     0.482 &       0.437 \\\\\n",
      " rule\\_44 &         18 &   705 &  train &                       4.75 &         15.310 &         18.172 &       Hard &      0.538 &      0.561 &        0.603 &     0.498 &     0.519 &       0.450 \\\\\n",
      " rule\\_45 &         15 &   151 &  train &                       3.39 &          9.127 &         10.001 &     Medium &      0.569 &      0.580 &        0.592 &     0.535 &     0.524 &       0.524 \\\\\n",
      " rule\\_46 &         19 &  2704 &  train &                       7.94 &         31.458 &         43.489 &       Easy &      0.850 &      0.820 &        0.828 &     0.773 &     0.762 &       0.749 \\\\\n",
      " rule\\_47 &         18 &   647 &  train &                       6.66 &         22.139 &         27.789 &       Easy &      0.723 &      0.667 &        0.708 &     0.620 &     0.649 &       0.611 \\\\\n",
      " rule\\_48 &         16 &   978 &  train &                       6.15 &         17.802 &         21.674 &       Easy &      0.812 &      0.798 &        0.812 &     0.772 &     0.763 &       0.753 \\\\\n",
      " rule\\_49 &         14 &   169 &  train &                       3.41 &          9.983 &         11.177 &       Easy &      0.714 &      0.734 &        0.700 &     0.511 &     0.491 &       0.615 \\\\\n",
      " rule\\_50 &         16 &   286 &  train &                       3.99 &         12.274 &         16.117 &     Medium &      0.651 &      0.653 &        0.656 &     0.555 &     0.583 &       0.570 \\\\\n",
      " rule\\_51 &         16 &   332 &  valid &                       4.44 &         16.384 &         21.817 &       Easy &      0.746 &      0.742 &        0.738 &     0.667 &     0.657 &       0.689 \\\\\n",
      " rule\\_52 &         17 &   351 &  valid &                       4.81 &         16.231 &         20.613 &     Medium &      0.697 &      0.716 &        0.754 &     0.653 &     0.655 &       0.670 \\\\\n",
      " rule\\_53 &         15 &   165 &  valid &                       3.65 &         10.838 &         12.378 &       Hard &      0.458 &      0.464 &        0.525 &     0.334 &     0.364 &       0.373 \\\\\n",
      " rule\\_54 &         13 &   303 &   test &                       5.25 &         13.503 &         15.567 &     Medium &      0.638 &      0.623 &        0.603 &     0.587 &     0.586 &       0.555 \\\\\n",
      " rule\\_55 &         16 &   293 &   test &                       4.83 &         16.444 &         20.944 &     Medium &      0.625 &      0.582 &        0.578 &     0.561 &     0.528 &       0.571 \\\\\n",
      " rule\\_56 &         15 &   241 &   test &                       4.40 &         14.010 &         16.702 &     Medium &      0.653 &      0.681 &        0.692 &     0.522 &     0.513 &       0.550 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# print() keeps the raw LaTeX text. escape=True is explicit because pandas >= 2.0\n",
    "# no longer escapes by default, which would leave the underscores in world ids\n",
    "# (rule_0, ...) and column names unescaped and break the LaTeX build.\n",
    "print(graphlog_stats.sort_values(by='world_id_num')[display_columns].to_latex(index=False, escape=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "M1    26\n",
       "M3    20\n",
       "M2    10\n",
       "M4     1\n",
       "Name: best, dtype: int64"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "## best ranking\n",
    "# Label each world with the short name (M1..M6) of its highest-accuracy baseline.\n",
    "# idxmax picks the first column on ties (same as a stable descending sort) and\n",
    "# skips NaN accuracies, which would otherwise scramble a sort-based ranking.\n",
    "graphlog_stats['best'] = (\n",
    "    graphlog_stats[list(shortened_model_names)]\n",
    "    .idxmax(axis=1)\n",
    "    .map(shortened_model_names)\n",
    ")\n",
    "## Aggregating\n",
    "graphlog_stats.best.value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the full per-world table, creating the output directory if it is missing.\n",
    "os.makedirs('clean_data', exist_ok=True)\n",
    "graphlog_stats.to_csv('clean_data/graphlog_stats.csv')\n",
    "# Column means of the numeric statistics, as the last expression so they are\n",
    "# displayed; numeric_only=True avoids a TypeError on string columns in pandas >= 2.0.\n",
    "graphlog_stats.mean(numeric_only=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (clutrr-2.0)",
   "language": "python",
   "name": "clutrr-2.0"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}