{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h4 style=\"margin:3px;padding:3px;\">Walkthrough</h4>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This walkthrough is used to generate the plots and tables in the TabMini paper. For convenience, we have already exported our benchmark results to Microsoft Excel and added a tab in the long format. We have then saved the [Excel file](results/test_scores.xlsx) as well as the tabs in the wide and long format as [test_scores_wide.csv](results/test_scores_wide.csv) and [test_scores_long.csv](results/test_scores_long.csv), respectively. In order to run the cells, you need to have [CriticalDifferenceDiagrams.jl](https://mirkobunse.github.io/CriticalDifferenceDiagrams.jl/stable/), [CSV.jl](https://csv.juliadata.org/stable/), [DataFrames.jl](https://dataframes.juliadata.org/stable/), [PGFPlots.jl](https://kristofferc.github.io/PGFPlotsX.jl/stable/), [Plots.jl](https://docs.juliaplots.org/stable/), [PyCall.jl](https://github.com/JuliaPy/PyCall.jl), [StatsBase.jl](https://juliastats.org/StatsBase.jl/stable/), and [StatsPlots.jl](https://github.com/JuliaPlots/StatsPlots.jl) installed. Additionally, you need the Python libraries [numpy](https://numpy.org/), [pandas](https://pandas.pydata.org/), and [PyMFE](https://pymfe.readthedocs.io/en/latest/)."
   ]
  },
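  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The wide and long layouts hold the same scores: the wide table has one row per dataset and one column per method, while the long table has one row per dataset and method pair. As a minimal sketch, assuming the wide table has an identifier column named `dataset` and using made-up toy values rather than benchmark results, the long layout could also be derived with `stack` from DataFrames.jl:\n",
    "\n",
    "```julia\n",
    "using DataFrames\n",
    "\n",
    "# toy wide table with made-up values (assumption: the identifier column is called `dataset`)\n",
    "toy_wide = DataFrame(dataset=[\"a\", \"b\"], TabPFN=[0.9, 0.8], AutoGluon=[0.85, 0.82])\n",
    "\n",
    "# turn every column except `dataset` into (approach, auc) pairs, one row per pair\n",
    "toy_long = stack(toy_wide, Not(:dataset); variable_name=:approach, value_name=:auc)\n",
    "```\n",
    "\n",
    "The column names `approach` and `auc` match the ones that the critical difference diagrams in this notebook read from the long table."
   ]
  },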
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h5>Imports</h5>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "include(\"helpers/generate_correlations.jl\")\n",
    "include(\"helpers/generate_metafeatures.jl\")\n",
    "\n",
    "using CriticalDifferenceDiagrams\n",
    "using CSV\n",
    "using DataFrames\n",
    "using PGFPlots\n",
    "using Plots\n",
    "using StatsBase\n",
    "using StatsPlots\n",
    "\n",
    "\n",
    "results_wide = CSV.read(\"results/test_scores_wide.csv\", DataFrame)\n",
    "results_long = CSV.read(\"results/test_scores_long.csv\", DataFrame);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h5>Meta-Feature Generation</h5>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# write CSV\n",
    "py\"generate_metafeatures\"(\"results/test_scores_wide.csv\")\n",
    "# read CSV\n",
    "metafeatures = CSV.read(\"metafeatures.csv\", DataFrame);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h5>Figure 1</h5>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Figure 1 is a composite figure and made up of other figures, generated below."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h5>Table 1</h5>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Table 1 is constructed with values from the relevant studies."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h5>Figure 2a</h5>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "methods = [\"AutoPrognosis\" \"AutoGluon\" \"TabPFN\" \"HyperFast\" \"Logistic regression\"]\n",
    "sample_size_ranges = [1:12, 13:22, 23:31, 32:39, 40:44]\n",
    "xticks_labels = [\"32 to 100\", \"101 to 200\", \"201 to 300\", \"301 to 400\", \"401 to 500\"]\n",
    "\n",
    "Q3s = zeros(length(sample_size_ranges), length(methods))\n",
    "Q2s = zeros(length(sample_size_ranges), length(methods))\n",
    "Q1s = zeros(length(sample_size_ranges), length(methods))\n",
    "for (idx_a, approach) in enumerate(methods)\n",
    "    for (idx_r, sample_size_range) in enumerate(sample_size_ranges)\n",
    "        Q3s[idx_r, idx_a] = quantile(results_wide[sample_size_range, approach], 0.75)\n",
    "        Q2s[idx_r, idx_a] = quantile(results_wide[sample_size_range, approach], 0.5)\n",
    "        Q1s[idx_r, idx_a] = quantile(results_wide[sample_size_range, approach], 0.25)\n",
    "    end\n",
    "end\n",
    "\n",
    "Plots.plot(Q2s,\n",
    "    ribbon=(Q2s .- Q1s, Q3s .- Q2s),\n",
    "    fillalpha=0.3,\n",
    "    ylabel=\"Mean test AUC\",\n",
    "    xlabel=\"Sample size range\",\n",
    "    xticks=(1:5, xticks_labels),\n",
    "    label=methods,\n",
    "    linewidth=5,\n",
    "    legend=:bottomright,\n",
    "    margin=10Plots.mm,\n",
    "    palette=:tab10\n",
    ")\n",
    "\n",
    "# Plots.scalefontsizes(1.2)\n",
    "# savefig(\"plots/auc.svg\")\n",
    "# savefig(\"plots/auc.pdf\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h5>Figure 2b</h5>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cdd_plot = CriticalDifferenceDiagrams.plot(\n",
    "    results_long,\n",
    "    :approach,\n",
    "    :dataset,\n",
    "    :auc,\n",
    "    maximize_outcome=true\n",
    ")\n",
    "\n",
    "# PGFPlots.save(\"plots/cdd.svg\", cdd_plot)\n",
    "# PGFPlots.save(\"plots/cdd.pdf\", cdd_plot);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h5>Dataset Reduction</h5>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As mentioned in the experimental results, we have also performed pairwise mean test AUC comparisons for all datasets that TabPFN was not meta-trained on. This dataset reduction prevented us from finding statistically significant performance differences between any of\n",
    "the evaluated methods, though (p > 0.05)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# datasets that TabPFN was not meta-trained on\n",
    "datasets_reduced = [\n",
    "    # M = 32 - 100 (12 datasets)\n",
    "    [\"analcatdata_aids\", \"analcatdata_asbestos\", \"analcatdata_bankruptcy\", \"analcatdata_creditscore\",\n",
    "    \"analcatdata_cyyoung8092\", \"analcatdata_cyyoung9302\", \"analcatdata_fraud\", \"analcatdata_japansolvent\",\n",
    "    \"labor\", \"lupus\", \"parity5\", \"postoperative_patient_data\"],\n",
    "    # M = 101 - 200 (6 datasets)\n",
    "    [\"analcatdata_boxing1\", \"analcatdata_boxing2\", \"appendicitis\", \"glass2\", \"molecular_biology_promoters\",\n",
    "    \"mux6\"],\n",
    "    # M = 201 - 300 (1 dataset)\n",
    "    [\"hungarian\"],\n",
    "    # M = 301 - 400 (3 datasets)\n",
    "    [\"bupa\", \"colic\", \"horse_colic\"],\n",
    "    # M = 401 - 500 (2 datasets)\n",
    "    [\"clean1\", \"house_votes_84\"]\n",
    "]\n",
    "\n",
    "# dataframe using only the reduced datasets\n",
    "results_long_reduced = DataFrame([String[], String[], Int[]], names(results_long))\n",
    "for datasets in datasets_reduced\n",
    "    for dataset in datasets\n",
    "        append!(results_long_reduced, results_long[results_long.dataset .== dataset, :])\n",
    "    end\n",
    "end\n",
    "\n",
    "cdd_plot_reduced = CriticalDifferenceDiagrams.plot(\n",
    "    results_long_reduced,\n",
    "    :approach,\n",
    "    :dataset,\n",
    "    :auc,\n",
    "    maximize_outcome=true\n",
    ")\n",
    "\n",
    "# PGFPlots.save(\"plots/cdd_reduced.svg\", cdd_plot_reduced)\n",
    "# PGFPlots.save(\"plots/cdd_reduced.pdf\", cdd_plot_reduced);"
   ]
  },
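  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Critical difference diagrams such as the one above summarize pairwise comparisons between methods. To illustrate what a single pairwise comparison looks like, here is a minimal sketch. It assumes a Wilcoxon signed-rank test from HypothesisTests.jl, which is not among the packages listed at the top of this notebook and is not necessarily the exact procedure used for the paper, and it runs on made-up toy scores:\n",
    "\n",
    "```julia\n",
    "using HypothesisTests\n",
    "\n",
    "# made-up per-dataset AUC values for two hypothetical methods (not benchmark results)\n",
    "auc_a = [0.81, 0.77, 0.92, 0.68, 0.85, 0.74, 0.90, 0.79]\n",
    "auc_b = [0.80, 0.79, 0.89, 0.70, 0.86, 0.70, 0.91, 0.74]\n",
    "\n",
    "# paired test on the per-dataset score differences\n",
    "p = pvalue(SignedRankTest(auc_a, auc_b))\n",
    "\n",
    "# p > 0.05 means no statistically significant difference at the 5% level\n",
    "println(p > 0.05 ? \"not significant\" : \"significant\")\n",
    "```\n",
    "\n",
    "With more than two methods, the p-values of all pairs would additionally need a correction for multiple comparisons."
   ]
  },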
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h5>Figure 3</h5>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "methods = [\"AutoPrognosis\", \"AutoGluon\", \"TabPFN\", \"HyperFast\"]\n",
    "py\"generate_correlations\"(methods, \"results/test_scores_wide.csv\")\n",
    "\n",
    "# from correlations.txt\n",
    "clustering = [0, 0, 0, 1]\n",
    "complexity = [3, 3, 0, 0]\n",
    "concept = [0, 4, 0, 0]\n",
    "general = [0, 0, 0, 0]\n",
    "infotheory = [2, 0, 0, 0]\n",
    "itemset = [2, 0, 0, 0]\n",
    "landmarking = [0, 1, 0, 0]\n",
    "modelbased = [0, 0, 0, 0]\n",
    "statistical = [3, 2, 10, 9]\n",
    "\n",
    "StatsPlots.groupedbar(\n",
    "        [clustering complexity concept general infotheory itemset landmarking modelbased statistical],\n",
    "        bar_position=:stack,\n",
    "        xticks=(1:4, [\"AutoPrognosis\" \"AutoGluon\" \"TabPFN\" \"HyperFast\"]),\n",
    "        label=[\"Clustering\" \"Complexity\" \"Concept\" \"General\" \"Info theory\" \"Itemset\" \"Landmarking\" #=\n",
    "        =# \"Model-based\" \"Statistical\"],\n",
    "        linecolor=:white,\n",
    "        palette=:tab20\n",
    ")\n",
    "\n",
    "# Plots.scalefontsizes(1.2)\n",
    "# savefig(\"plots/bar.svg\")\n",
    "# savefig(\"plots/bar.pdf\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h5>Table 2</h5>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for characteristic in [\"nr_inst\", \"nr_attr\", \"freq_class.min\", \"EPV\", \"nr_bin\"]\n",
    "    print(\"$(characteristic):\\n\n",
    "          Mean: $(mean(metafeatures[!, characteristic]))\\n\n",
    "          Std: $(std(metafeatures[!, characteristic], corrected=false))\\n\n",
    "          Min: $(minimum(metafeatures[!, characteristic]))\\n\n",
    "          25%: $(quantile(metafeatures[!, characteristic], 0.25))\\n\n",
    "          50%: $(quantile(metafeatures[!, characteristic], 0.5))\\n\n",
    "          75%: $(quantile(metafeatures[!, characteristic], 0.75))\\n\n",
    "          Max: $(maximum(metafeatures[!, characteristic]))\\n\")\n",
    "end;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h5>Figure 4</h5>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Plots.scatter(metafeatures[!, \"nr_inst\"], metafeatures[!, \"nr_attr\"],\n",
    "    xlabel=\"Sample size\",\n",
    "    ylabel=\"Feature set size\",\n",
    "    ylim=(0, 70),\n",
    "    markersize=10,\n",
    "    markerstrokewidth=0,\n",
    "    palette=:tab10,\n",
    "    legend=false\n",
    ")\n",
    "# Plots.scalefontsizes(1.5)\n",
    "# savefig(\"plots/scatter.svg\")\n",
    "# savefig(\"plots/scatter.pdf\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h5>Table 3</h5>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Table 3 is constructed with raw values from our benchmark results."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h5>Figure 5</h5>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Figure 5 is made with the meta-feature correlations generated for Figure 3."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 1.8.3",
   "language": "julia",
   "name": "julia-1.8"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
