{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import json\n",
    "import numpy as np\n",
    "import os.path\n",
    "import pandas as pd\n",
    "\n",
    "from itertools import product"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_evals(root):\n",
    "    \"\"\"Collect per-checkpoint metrics for one cross-validation experiment directory.\n",
    "\n",
    "    `root` must contain `experiment_config.json` (read for `cv_num_folds`, `cv_iterations`,\n",
    "    `cv_save_top_k` and `cv_save_last`) and one sub-directory per trial holding\n",
    "    `evals.json` (checkpoint name -> {metric name: value}) and `params.json` (read for\n",
    "    `data_split_idx` and `iteration`).\n",
    "\n",
    "    Returns:\n",
    "        eval_dict: metric name -> float array indexed [fold, iteration, checkpoint slot];\n",
    "            slots that no trial wrote stay 0.\n",
    "        checkpoint_dict: metric name -> object array of the same shape holding the\n",
    "            checkpoint name written to each slot (None where nothing was written).\n",
    "\n",
    "    Metric names are taken from the first checkpoint of the first trial (trials sorted by\n",
    "    directory name). The checkpoint axis starts with `cv_save_top_k + cv_save_last` slots\n",
    "    and is widened whenever a trial stored more checkpoints than that.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: if `root` has no trial sub-directories (or a JSON file is missing).\n",
    "        ValueError: if the first trial's `evals.json` contains no checkpoints.\n",
    "    \"\"\"\n",
    "    with open(os.path.join(root, 'experiment_config.json')) as f:\n",
    "        hyperparameters = json.load(f)\n",
    "\n",
    "    trial_dirs = sorted(glob.glob('*/', root_dir=root))\n",
    "    if not trial_dirs:\n",
    "        raise FileNotFoundError(f'no trial directories found in {root!r}')\n",
    "\n",
    "    def load_trial_json(trial_dir, filename):\n",
    "        with open(os.path.join(root, trial_dir, filename)) as f:\n",
    "            return json.load(f)\n",
    "\n",
    "    # Read every trial exactly once; the metric names come from the first one.\n",
    "    trials = [(load_trial_json(d, 'evals.json'), load_trial_json(d, 'params.json')) for d in trial_dirs]\n",
    "    first_evals = trials[0][0]\n",
    "    if not first_evals:\n",
    "        raise ValueError(f'{trial_dirs[0]!r} in {root!r} has no checkpoints in evals.json')\n",
    "    eval_keys = list(next(iter(first_evals.values())).keys())\n",
    "\n",
    "    n_folds = hyperparameters['cv_num_folds']\n",
    "    n_iterations = hyperparameters['cv_iterations']\n",
    "    # cv_save_last is added as a number (presumably a bool flag, so True adds one slot).\n",
    "    n_slots = hyperparameters['cv_save_top_k'] + hyperparameters['cv_save_last']\n",
    "\n",
    "    eval_dict = {key: np.zeros((n_folds, n_iterations, n_slots)) for key in eval_keys}\n",
    "    checkpoint_dict = {key: np.empty((n_folds, n_iterations, n_slots), dtype=object) for key in eval_keys}\n",
    "\n",
    "    for evals, params in trials:\n",
    "        fold, iteration = params['data_split_idx'], params['iteration']\n",
    "        for slot, checkpoint in enumerate(evals):\n",
    "            for key in eval_keys:\n",
    "                if slot >= eval_dict[key].shape[2]:\n",
    "                    # This trial kept more checkpoints than the config predicts: add one slot.\n",
    "                    eval_dict[key] = np.concatenate(\n",
    "                        (eval_dict[key], np.zeros((n_folds, n_iterations, 1))), axis=2)\n",
    "                    checkpoint_dict[key] = np.concatenate(\n",
    "                        (checkpoint_dict[key], np.empty((n_folds, n_iterations, 1), dtype=object)), axis=2)\n",
    "                eval_dict[key][fold, iteration, slot] = evals[checkpoint][key]\n",
    "                checkpoint_dict[key][fold, iteration, slot] = checkpoint\n",
    "\n",
    "    return eval_dict, checkpoint_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_all_evals(root, verbose=3):\n",
    "    \"\"\"Summarise test accuracy for every experiment directory directly under `root`.\n",
    "\n",
    "    An entry counts as an experiment when its `experiment_config.json` loads, has a\n",
    "    `dataset_name` entry, and `get_evals` succeeds on it; everything else is skipped.\n",
    "\n",
    "    Args:\n",
    "        root: directory containing the experiment sub-directories.\n",
    "        verbose: selects both what is printed and which checkpoint slots are summarised.\n",
    "            verbose >= 1 summarises every slot; verbose < 1 summarises only slot index\n",
    "            `verbose` (0 -> first slot, -1 -> last slot). > 1 prints mean-max and\n",
    "            mean-median, > 2 also per-fold max and mean, > 3 also a top-5-by-validation\n",
    "            mean, the test loss averaged over iterations, and the entries that were skipped.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping method name -> string with one '<mean> pm <std> ' fragment per\n",
    "        summarised slot, where pm is the LaTeX plus-minus macro and both numbers are\n",
    "        percentages rounded to one decimal (std: per-fold std over iterations, averaged\n",
    "        over folds). The method name is the directory name up to its first '_', or the\n",
    "        full directory name when that prefix is already taken.\n",
    "    \"\"\"\n",
    "    # Pass 1: find loadable experiments, grouped by dataset. The loaded evals are kept\n",
    "    # so each experiment is read from disk only once.\n",
    "    experiments_by_dataset = {}\n",
    "    for path in glob.glob('*', root_dir=root):\n",
    "        experiment_dir = os.path.join(root, path)\n",
    "        try:\n",
    "            with open(os.path.join(experiment_dir, 'experiment_config.json')) as f:\n",
    "                config = json.load(f)\n",
    "            evals, _ = get_evals(experiment_dir)\n",
    "            experiments_by_dataset.setdefault(config['dataset_name'], []).append((path, evals))\n",
    "        except Exception as e:\n",
    "            # Best-effort: anything that is not a complete experiment is skipped, but say why.\n",
    "            if verbose > 3:\n",
    "                print('skipping', experiment_dir, repr(e))\n",
    "\n",
    "    # Pass 2: summarise each experiment, paths in sorted order within a dataset.\n",
    "    results = {}\n",
    "    for experiments in experiments_by_dataset.values():\n",
    "        for path, evals in sorted(experiments, key=lambda item: item[0]):\n",
    "            name = path.split('_')[0]\n",
    "            if name in results:\n",
    "                name = path\n",
    "            results[name] = ''\n",
    "\n",
    "            all_test_acc = evals['test_acc']  # indexed [fold, iteration, checkpoint slot]\n",
    "            # If every fold other than fold 0 is still all zeros, only the first split was run:\n",
    "            # summarise fold 0 alone so the empty folds do not drag the mean down.\n",
    "            other_folds = all_test_acc[1:]\n",
    "            only_first_fold = other_folds.size > 0 and np.max(other_folds) <= 0.000001\n",
    "\n",
    "            slots = range(all_test_acc.shape[2]) if verbose >= 1 else [verbose]\n",
    "            for slot in slots:\n",
    "                test_acc = all_test_acc[:, :, slot]\n",
    "                if only_first_fold:\n",
    "                    test_acc = test_acc[:1, :]\n",
    "\n",
    "                if verbose > 2:\n",
    "                    print('max', np.max(test_acc, axis=1).flatten())\n",
    "                    print('mean', np.mean(test_acc, axis=1).flatten())\n",
    "                if verbose > 1:\n",
    "                    print('mean max', np.mean(np.max(test_acc, axis=1)))\n",
    "                    print('mean median', np.mean(np.median(test_acc, axis=1)))\n",
    "\n",
    "                mean_pct = round(np.mean(test_acc) * 100, 1)\n",
    "                # std over iterations within each fold, then averaged over folds\n",
    "                std_pct = round(np.mean(np.std(test_acc, axis=1) * 100), 1)\n",
    "                results[name] += str(mean_pct) + r' \\pm ' + str(std_pct) + ' '\n",
    "\n",
    "                if verbose > 3:\n",
    "                    top5 = [all_test_acc[fold, np.argsort(evals['val_acc'][fold, :, 0])][-5:]\n",
    "                            for fold in range(all_test_acc.shape[0])]\n",
    "                    print('mean top5', np.mean(top5))\n",
    "                    print('mean loss', np.mean(evals['test_loss'], axis=1).flatten())\n",
    "\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _leading_float(summary):\n",
    "    \"\"\"Return the first whitespace-separated number of a `get_all_evals` summary string.\n",
    "\n",
    "    Returns -inf when the string is empty or does not start with a number, so such a\n",
    "    method can never be picked as the best one.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        return float(summary.split()[0])\n",
    "    except (IndexError, ValueError):\n",
    "        return float('-inf')\n",
    "\n",
    "\n",
    "def get_latex_all_experiment_evals(root, datasets=None, models=None, methods=None, verbose=0):\n",
    "    \"\"\"Print a LaTeX table of `get_all_evals` summaries: one column per dataset.\n",
    "\n",
    "    Expects results laid out as `<root>/<dataset>/<model>/<experiment dirs>`. Rows are\n",
    "    grouped by model: one row holding the model name, followed by one row per method.\n",
    "    Within every (dataset, model) group the method with the highest leading mean accuracy\n",
    "    is wrapped in the bm macro; methods without results are left blank.\n",
    "\n",
    "    Args:\n",
    "        root: results root directory.\n",
    "        datasets: dataset sub-directories (table columns); default: every entry under `root`.\n",
    "        models: model sub-directories; default: every entry under the first dataset.\n",
    "        methods: method names as returned by `get_all_evals`; default: those found for the\n",
    "            first dataset/model pair.\n",
    "        verbose: forwarded to `get_all_evals` (it also selects which checkpoint slot is reported).\n",
    "    \"\"\"\n",
    "    if datasets is None:\n",
    "        datasets = glob.glob('*', root_dir=root)\n",
    "    if models is None:\n",
    "        models = glob.glob('*', root_dir=os.path.join(root, datasets[0]))\n",
    "    if methods is None:\n",
    "        methods = list(get_all_evals(os.path.join(root, datasets[0], models[0]), verbose=verbose).keys())\n",
    "\n",
    "    # (model, '') is the model's header row; (model, method) are the rows below it.\n",
    "    row_keys = list(product(models, [''] + methods))\n",
    "    row_index = {key: i for i, key in enumerate(row_keys)}\n",
    "    dataset_index = {dataset: i for i, dataset in enumerate(datasets)}\n",
    "    cells = np.full((len(row_keys), len(datasets)), '', dtype=object)\n",
    "\n",
    "    for dataset in datasets:\n",
    "        for model in models:\n",
    "            eval_dict = get_all_evals(os.path.join(root, dataset, model), verbose=verbose)\n",
    "            # Pick the winner among methods that actually have results; indexing into this\n",
    "            # filtered list keeps missing methods from shifting the winner (max keeps the\n",
    "            # first on ties).\n",
    "            present = [method for method in methods if method in eval_dict]\n",
    "            best = max(present, key=lambda m: _leading_float(eval_dict[m])) if present else None\n",
    "            for method in present:\n",
    "                summary = eval_dict[method]\n",
    "                if method == best:\n",
    "                    summary = r'\\bm{' + summary + '}'\n",
    "                cells[row_index[(model, method)], dataset_index[dataset]] = '$' + summary + '$'\n",
    "\n",
    "    # Underscores (e.g. 'Cora_ML') are LaTeX special characters outside math mode.\n",
    "    def latex_text(text):\n",
    "        return text.replace('_', r'\\_')\n",
    "\n",
    "    labels = np.array([[latex_text(method if method != '' else model)] for model, method in row_keys],\n",
    "                      dtype=object)\n",
    "    table = pd.DataFrame(np.concatenate([labels, cells], axis=1),\n",
    "                         columns=[''] + [latex_text(dataset) for dataset in datasets])\n",
    "    # escape=False keeps the $...$ and bm markup intact (pandas < 2.0 escapes by default).\n",
    "    print(table.to_latex(index=False, escape=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Three datasets x three models x five methods; verbose=0 reports checkpoint slot 0.\n",
    "get_latex_all_experiment_evals(\n",
    "    'ray_results_cv',\n",
    "    verbose=0,\n",
    "    models=['GCN', 'GAT', 'GIN'],\n",
    "    datasets=['Cora_ML', 'Citeseer', 'ogbn-arxiv'],\n",
    "    methods=['none', 'batch', 'graph', 'pair', 'graph2'],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same table restricted to Cora_ML and Citeseer (checkpoint slot 0 via verbose=0).\n",
    "get_latex_all_experiment_evals(\n",
    "    'ray_results_cv',\n",
    "    verbose=0,\n",
    "    models=['GCN', 'GAT', 'GIN'],\n",
    "    datasets=['Cora_ML', 'Citeseer'],\n",
    "    methods=['none', 'batch', 'graph', 'pair', 'graph2'],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Results under ray_results_cv_shallow; verbose=-1 reports the last checkpoint slot.\n",
    "# NOTE(review): dataset directory assumed to be 'ogbn-arxiv' as in ray_results_cv\n",
    "# (was 'ogb-arxiv', which would silently produce an empty column) -- confirm on disk.\n",
    "get_latex_all_experiment_evals(\n",
    "    'ray_results_cv_shallow',\n",
    "    datasets=['Cora_ML', 'Citeseer', 'ogbn-arxiv'],\n",
    "    models=['GCN', 'GAT'],\n",
    "    verbose=-1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# No `methods` given: they default to the methods found for the first dataset/model (MUTAG/GIN).\n",
    "get_latex_all_experiment_evals(\n",
    "    'ray_results_cv',\n",
    "    verbose=0,\n",
    "    models=['GIN', 'GCN', 'GAT'],\n",
    "    datasets=['MUTAG', 'PROTEINS', 'PTC_MR'],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-slot summaries for ogbn-arxiv / GCN (verbose=1 summarises every checkpoint slot).\n",
    "get_all_evals(root='ray_results_cv/ogbn-arxiv/GCN', verbose=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw (eval_dict, checkpoint_dict) arrays for a single experiment directory.\n",
    "get_evals(root='ray_results_cv/Cora_ML/GCN/graph_20240328-071504')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rs_venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "30bc9a2a51b38fcc7fe646cb1ad2d15d35612af30b028c1d1ce1a5fdcf8fcffe"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
