{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6060e8e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66f1d879",
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets = [\"MUTAG\", \"COX2\", \"ENZYMES\", \"PROTEINS\", \"Mutagenicity\", \"AIDS\", \"IMDB-BINARY\", \"IMDB-MULTI\", \"NCI1\", \"NCI109\", \n",
    "            \"BZR\", \"DHFR\", \"ogbg-molhiv\", 'BZR_MD', 'COX2_MD', 'DHFR_MD', 'ER_MD', \"PROTEINS_full\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c782ae16",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the per-dataset GNN results into three objects:\n",
    "#   results / stds  - DataFrames with a \"test_loss_<dataset>\" and a \"test_accuracy_<dataset>\"\n",
    "#                     column per dataset, outer-joined on the index of the CSV files\n",
    "#   acc_per_dataset - parsed content of <dataset>_GNN_accuracies.json, keyed by dataset name\n",
    "acc_per_dataset = {}\n",
    "results = None  # set by the first dataset that loads; stays None if none does\n",
    "stds = None\n",
    "for d in datasets:\n",
    "    prefix = f\"../src/graph_classification/results/{d}/{d}_GNN\"\n",
    "    try:\n",
    "        data = pd.read_csv(f\"{prefix}_means.csv\", index_col=0, header=0)\n",
    "        std = pd.read_csv(f\"{prefix}_stds.csv\", index_col=0, header=0)\n",
    "        with open(f\"{prefix}_accuracies.json\", \"r\") as f:\n",
    "            accuracies = json.load(f)\n",
    "\n",
    "        # Keep only the two reported metrics; a missing column raises KeyError here.\n",
    "        data = data[[\"loss\", \"accuracy\"]]\n",
    "        std = std[[\"loss\", \"accuracy\"]]\n",
    "    except Exception as exc:\n",
    "        # Best effort: a dataset with missing or malformed result files is skipped,\n",
    "        # but the reason is reported instead of being silently swallowed.\n",
    "        print(f\"Error with {d}: {type(exc).__name__}: {exc}\")\n",
    "        continue\n",
    "\n",
    "    data.columns = [\"test_loss_\" + d, \"test_accuracy_\" + d]\n",
    "    std.columns = [\"test_loss_\" + d, \"test_accuracy_\" + d]\n",
    "\n",
    "    # Register the dataset only once all three files were read and validated.\n",
    "    acc_per_dataset[d] = accuracies\n",
    "    if results is None:\n",
    "        # First dataset that loaded successfully (not necessarily datasets[0]).\n",
    "        results = data\n",
    "        stds = std\n",
    "    else:\n",
    "        results = pd.merge(results, data, left_index=True, right_index=True, how=\"outer\")\n",
    "        stds = pd.merge(stds, std, left_index=True, right_index=True, how=\"outer\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "069a5a0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the merged tables by metric: `filter(like=...)` keeps the columns whose\n",
    "# name contains the given substring.\n",
    "results_loss, results_accuracy = (results.filter(like=metric) for metric in (\"loss\", \"accuracy\"))\n",
    "stds_loss, stds_accuracy = (stds.filter(like=metric) for metric in (\"loss\", \"accuracy\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f63ef195",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the merged mean test-accuracy table (one \"test_accuracy_<dataset>\" column per loaded dataset).\n",
    "results_accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5164adb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the merged mean test-loss table (one \"test_loss_<dataset>\" column per loaded dataset).\n",
    "results_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d9e5de6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both names are plain aliases of `results_accuracy`; no copy is made.\n",
    "all_existing_results_faster = all_existing_results_slow = results_accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9e21b57",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alias of the accuracy table used by the reporting cells below.\n",
    "# The trailing comment is a disabled step that unwrapped list-valued cells.\n",
    "results_acc = results_accuracy#_split.applymap(lambda x: x[0] if isinstance(x, list) else x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abeacdf1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alias of the accuracy standard-deviation table used by the reporting cells below.\n",
    "st_acc = stds_accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abab87a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the accuracy report: \"mean ± std\" in percent (one decimal) for every\n",
    "# row/column of the accuracy tables, then export the selected methods to CSV.\n",
    "\n",
    "# Raw result names -> display names used in the report.\n",
    "name_to = {\n",
    "    'MAG_EDGE_diffusion_distance': 'MagEdge',\n",
    "    'SPREAD_EDGE_diffusion_distance': 'SpreadEdge',\n",
    "    \"Flat\": \"No Pooling\",\n",
    "}\n",
    "st_acc = st_acc.round(3)\n",
    "\n",
    "# Numeric tables in percent, indexed by display name.\n",
    "to_report_mean = (results_acc.round(3).rename(index=name_to) * 100).round(1)\n",
    "std_to_report = (st_acc.rename(index=name_to) * 100).round(1)\n",
    "\n",
    "# String table; `to_report` is only ever bound to this final \"mean ± std\" form.\n",
    "to_report = to_report_mean.astype(str) + \" ± \" + std_to_report.astype(str)\n",
    "\n",
    "# Methods (rows) that go into the exported table, in presentation order.\n",
    "# `.loc` with a list raises KeyError if any of them is missing from the index.\n",
    "to_subset = [\"No Pooling\", \"MagEdge\",\n",
    "             \"SpreadEdge\",\n",
    "             \"NDP\", \"Graclus\", \"NMF\", \"TopK\", \"SAGPool\", \"DiffPool\", \"MinCut\"]\n",
    "\n",
    "sub = to_report.loc[to_subset, :]\n",
    "sub_mean = to_report_mean.loc[to_subset, :]\n",
    "sub_std = std_to_report.loc[to_subset, :]\n",
    "\n",
    "# `to_csv` does not create missing directories, so make sure ./plots exists first.\n",
    "os.makedirs(\"./plots\", exist_ok=True)\n",
    "sub.to_csv(\"./plots/graph_classification_accuracy.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6e862a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the first ten columns (positions 0-9) of the report table `sub`.\n",
    "sub.iloc[:,:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f5971fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the remaining columns of `sub`. The slice starts at position 10 so that, together\n",
    "# with the `[:, :10]` view, every column is displayed (starting at 11 skipped column 10).\n",
    "sub.iloc[:, 10:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f528a01",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restrict the mean-accuracy table to the datasets and pooling methods used for ranking.\n",
    "datt = [\n",
    "    \"ENZYMES\", \"PROTEINS\", \"Mutagenicity\", \"DHFR\",\n",
    "    \"IMDB-BINARY\", \"IMDB-MULTI\", \"NCI1\", \"NCI109\",\n",
    "]\n",
    "cols_to_include = [f\"test_accuracy_{d}\" for d in datt]\n",
    "\n",
    "# Boolean column mask: the kept columns stay in the table's own order, not in `datt` order.\n",
    "keep_column = to_report_mean.columns.isin(cols_to_include)\n",
    "to_report_mean = to_report_mean.loc[:, keep_column]\n",
    "\n",
    "# 'No Pooling' is deliberately left out of the ranked methods.\n",
    "baselines = [\n",
    "    'DiffPool', 'Graclus',\n",
    "    'MinCut', 'NDP', 'NMF', 'SAGPool', \"TopK\",\n",
    "    'MagEdge', 'SpreadEdge',\n",
    "]\n",
    "keep_row = to_report_mean.index.isin(baselines)\n",
    "all_means = to_report_mean.loc[keep_row]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35d0c2a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average rank of each method over the selected datasets\n",
    "# (within a column, rank 1 is the highest mean accuracy).\n",
    "per_dataset_rank = all_means.rank(axis=0, ascending=False)\n",
    "per_dataset_rank.mean(axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf9ee439",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import wilcoxon\n",
    "import numpy as np\n",
    "\n",
    "def holm_bonferroni(p_values, alpha):\n",
    "    \"\"\"\n",
    "    Apply Holm–Bonferroni correction to a list of (model, p_value) tuples.\n",
    "    Returns a set of model names that are statistically significantly different.\n",
    "    \"\"\"\n",
    "    sorted_p = sorted(enumerate(p_values), key=lambda x: x[1])  # (index, p-value)\n",
    "    m = len(p_values)\n",
    "    significant = set()\n",
    "\n",
    "    for i, (idx, p) in enumerate(sorted_p):\n",
    "        threshold = alpha / (m - i)\n",
    "        if p <= threshold:\n",
    "            significant.add(idx)\n",
    "        else:\n",
    "            break\n",
    "\n",
    "    return significant\n",
    "\n",
    "def compare_models_with_wilcoxon(accuracy_dict, alpha=0.05, exclude_as_best=None):\n",
    "    \"\"\"\n",
    "    Compare every model with the best one, per dataset, using Wilcoxon signed-rank tests.\n",
    "\n",
    "    For each dataset the eligible model with the highest mean accuracy is the\n",
    "    reference. Every other eligible model is tested against it with\n",
    "    `scipy.stats.wilcoxon`, and the p-values are corrected with `holm_bonferroni`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    accuracy_dict : dict\n",
    "        Mapping {dataset: {model: sequence of accuracies}}. The test is paired, so the\n",
    "        sequences of two models are assumed to be aligned run by run - TODO confirm.\n",
    "    alpha : float, default 0.05\n",
    "        Family-wise significance level passed to `holm_bonferroni`.\n",
    "    exclude_as_best : str or iterable of str, optional\n",
    "        Model name(s) to leave out. NOTE: these models are dropped from the whole\n",
    "        comparison, not only from the choice of the best model, so they do not\n",
    "        appear in the returned mapping.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "        {dataset: {model: {'p_value', 'significant', 'comparable_to_best'}}}.\n",
    "        The best model has p_value None, significant False, comparable_to_best True.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If a dataset has no eligible model left.\n",
    "    \"\"\"\n",
    "    if exclude_as_best is None:\n",
    "        excluded = set()\n",
    "    elif isinstance(exclude_as_best, str):\n",
    "        # A bare string would otherwise be matched by substring (\"Fl\" in \"Flat\").\n",
    "        excluded = {exclude_as_best}\n",
    "    else:\n",
    "        excluded = set(exclude_as_best)\n",
    "\n",
    "    results = {}\n",
    "\n",
    "    for dataset, model_accuracies in accuracy_dict.items():\n",
    "        eligible_models = {model: accs for model, accs in model_accuracies.items() if model not in excluded}\n",
    "        if not eligible_models:\n",
    "            raise ValueError(f\"No eligible models for best model in dataset '{dataset}'\")\n",
    "\n",
    "        mean_accuracies = {model: np.mean(accs) for model, accs in eligible_models.items()}\n",
    "        best_model = max(mean_accuracies, key=mean_accuracies.get)\n",
    "        best_scores = eligible_models[best_model]\n",
    "\n",
    "        p_values = []\n",
    "        model_list = []\n",
    "        stats_by_model = {}\n",
    "\n",
    "        for model, scores in eligible_models.items():\n",
    "            if model == best_model:\n",
    "                stats_by_model[model] = {\n",
    "                    'p_value': None,\n",
    "                    'significant': False,\n",
    "                    'comparable_to_best': True\n",
    "                }\n",
    "                continue\n",
    "\n",
    "            try:\n",
    "                _, p_value = wilcoxon(best_scores, scores)\n",
    "            except ValueError:\n",
    "                # The test could not be computed; treat the model as not distinguishable.\n",
    "                p_value = 1.0\n",
    "            if np.isnan(p_value):\n",
    "                # Same convention as the ValueError path for an undefined p-value.\n",
    "                p_value = 1.0\n",
    "\n",
    "            p_values.append(p_value)\n",
    "            model_list.append(model)\n",
    "            stats_by_model[model] = {'p_value': p_value}  # flags are filled in below\n",
    "\n",
    "        # Holm-Bonferroni returns positions into p_values, i.e. into model_list.\n",
    "        sig_indices = holm_bonferroni(p_values, alpha)\n",
    "\n",
    "        for i, model in enumerate(model_list):\n",
    "            is_significant = i in sig_indices\n",
    "            stats_by_model[model]['significant'] = is_significant\n",
    "            stats_by_model[model]['comparable_to_best'] = not is_significant\n",
    "\n",
    "        results[dataset] = stats_by_model\n",
    "\n",
    "    return results\n",
    "\n",
    "\n",
    "# Pooling methods (raw result names) taken into account for the significance analysis.\n",
    "accs_allowed = [\"MAG_EDGE_diffusion_distance\",  \"SPREAD_EDGE_diffusion_distance\", \"Flat\",\"NDP\", \"Graclus\", \"NMF\", \"TopK\", \"SAGPool\", \"DiffPool\", \"MinCut\"]\n",
    "accs_per_dataset = {}\n",
    "for dataset, model_accuracies in acc_per_dataset.items():\n",
    "    accs_per_dataset[dataset] = {model: accs for model, accs in model_accuracies.items() if model in accs_allowed}\n",
    "\n",
    "# Stored under its own name so the `results` DataFrame built by the loading cell is not\n",
    "# overwritten with a dict (which broke re-running the earlier cells).\n",
    "significance_results = compare_models_with_wilcoxon(accs_per_dataset, exclude_as_best=[\"Flat\"])\n",
    "for dataset, model_results in significance_results.items():\n",
    "    print(f\"\\nDataset: {dataset}\")\n",
    "    print(\"Comparable to best:\")\n",
    "    for model, stats in model_results.items():\n",
    "        display_name = name_to.get(model, model)\n",
    "        p_value = stats['p_value']\n",
    "        # The best model carries no p-value; print it as None instead of formatting it.\n",
    "        p_text = \"None\" if p_value is None else f\"{p_value:.4f}\"\n",
    "        print(f\"  - {display_name} ({p_text}, {stats['comparable_to_best']})\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18007a8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from magni.src.graph_classification.jsons_to_csvs import get_method_names, json_to_df\n",
    "from sklearn.metrics import roc_auc_score, average_precision_score, balanced_accuracy_score\n",
    "# Recompute ROC-AUC, balanced accuracy and average precision for ogbg-molhiv from the\n",
    "# per-run target / prediction CSV files of every pooling method.\n",
    "dataset = \"ogbg-molhiv\"\n",
    "model = \"GNN\"\n",
    "methods = get_method_names()#[:10]\n",
    "path = \"../src/graph_classification/results/\"\n",
    "dataset_path = f\"{path}/{dataset}\"\n",
    "\n",
    "methods_found = []  # methods with usable results, aligned with all_rocs\n",
    "all_rocs = []       # one list of per-run ROC-AUC values per entry of methods_found\n",
    "if not os.path.exists(dataset_path):\n",
    "    print(f\"Path {dataset_path} does not exist.\")\n",
    "else:\n",
    "    for method in methods:\n",
    "        print(f\"Method: {method}\")\n",
    "        json_path = f\"{path}/{dataset}/{model}_{method}_{dataset}_stratified.json\"\n",
    "        try:\n",
    "            df = json_to_df(json_path)\n",
    "        except json.decoder.JSONDecodeError:\n",
    "            print(f\"Error decoding JSON for {json_path}.\")\n",
    "            continue\n",
    "        # `df.empty` already covers the zero-row case, so no separate shape check is needed.\n",
    "        if df is None or df.empty:\n",
    "            continue\n",
    "\n",
    "        # presumably the number of stored runs for this method - confirm against jsons_to_csvs\n",
    "        n_runs = json_to_df(json_path, \"runs\")\n",
    "        rocs = []\n",
    "        balanced_accs = []\n",
    "        avg_precs = []\n",
    "        for run in range(n_runs):\n",
    "            targets = pd.read_csv(f\"{path}/{dataset}/{model}_{method}_{dataset}_stratified_targets_{run}.csv\")\n",
    "            predictions = pd.read_csv(f\"{path}/{dataset}/{model}_{method}_{dataset}_stratified_predictions_{run}.csv\")\n",
    "\n",
    "            rocs.append(roc_auc_score(targets, predictions))\n",
    "            # argmax over axis 1 turns the per-column scores into class indices\n",
    "            balanced_accs.append(balanced_accuracy_score(np.argmax(targets, axis=1), np.argmax(predictions, axis=1)))\n",
    "            avg_precs.append(average_precision_score(targets, predictions))\n",
    "        methods_found.append(method)\n",
    "        all_rocs.append(rocs)\n",
    "\n",
    "        print(f\"ROC-AUC: {round(np.mean(rocs), 3)} ± {round(np.std(rocs), 3)}\")\n",
    "        print(f\"Balanced Accuracy: {np.mean(balanced_accs)} ± {np.std(balanced_accs)}\")\n",
    "        print(f\"Average Precision: {np.mean(avg_precs)} ± {np.std(avg_precs)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85fbe3b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map each method to its list of per-run ROC-AUC values (the first occurrence of a\n",
    "# method wins), then show the mean ROC-AUC per method.\n",
    "dit = {}\n",
    "for method, method_rocs in zip(methods_found, all_rocs):\n",
    "    dit.setdefault(method, method_rocs)\n",
    "pd.DataFrame(dit).mean(axis=0).round(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88f6c3e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the raw per-run ROC-AUC lists, keyed by method.\n",
    "dit"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
