{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Ground Truth problems\n",
    "\n",
    "These are problems for which the data generating process is a known model, \n",
    "\n",
    "$$ y = \\phi^*(\\mathbf{x}, \\theta^*) $$\n",
    "\n",
    "We assess how well symbolic regression algorithms find the form of the model, $\\phi^*$, with some leniency on $\\theta^*$ (we allow the model to be off by a constant or a scalar). "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from tabulate import tabulate\n",
    "import pandas as pd\n",
    "import json\n",
    "import numpy as np\n",
    "from glob import glob\n",
    "from tqdm import tqdm\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib\n",
    "matplotlib.rc('pdf', fonttype=42)\n",
    "import os\n",
    "sns.set(font_scale=1.2)\n",
    "rdir = '../results/'\n",
    "figdir = \"../experiment/figs/\"\n",
    "\n",
    "def save(h=None,name='tmp'):\n",
    "    name = name.strip().replace(' ','-').replace('%','pct')\n",
    "    if h == None:\n",
    "        h = plt.gcf()\n",
    "    h.tight_layout()\n",
    "    print('saving',name+'.pdf')\n",
    "    if not os.path.exists(figdir):\n",
    "        os.makedirs(figdir)\n",
    "    plt.savefig(figdir+'/'+name+'.pdf', dpi=400, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# read data from feather"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_results = pd.read_feather(rdir+'ground-truth_results.feather')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_results.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "htssr_mask = (df_results[\"algorithm\"] == \"htssr\")\n",
    "if htssr_mask.any():\n",
    "    df_results = df_results[htssr_mask]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# compute symbolic solutions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_results.loc[:,'symbolic_solution'] = df_results[['symbolic_error_is_zero',\n",
    "                                                    'symbolic_error_is_constant',\n",
    "                                                    'symbolic_fraction_is_constant']\n",
    "                                                   ].apply(any,raw=True, axis=1)\n",
    "# clean up any corner cases (constant models, failures)\n",
    "df_results.loc[:,'symbolic_solution'] = df_results['symbolic_solution'] & ~df_results['simplified_symbolic_model'].isna() \n",
    "df_results.loc[:,'symbolic_solution'] = df_results['symbolic_solution'] & ~(df_results['simplified_symbolic_model'] == '0')\n",
    "df_results.loc[:,'symbolic_solution'] = df_results['symbolic_solution'] & ~(df_results['simplified_symbolic_model'] == 'nan')\n",
    "\n",
    "# save results for detailed tabulating\n",
    "df_results.to_feather(rdir+'ground-truth_solns.feather')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert \"target_noise\" in df_results.columns\n",
    "if \"model_size\" not in df_results.columns:\n",
    "    df_results[\"model_size\"] = df_results[\"symbolic_model\"].apply(lambda x: len(x))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## summarize results by dataset, including ranking"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_results2 = df_results.merge(df_results.groupby(['dataset','target_noise'])['algorithm'].nunique().reset_index(),\n",
    "                              on=['dataset','target_noise'],suffixes=('','_count'))\n",
    "# count repeat trials\n",
    "df_results2 = df_results2.merge(\n",
    "           df_results2.groupby(['algorithm','dataset','target_noise'])['random_state'].nunique().reset_index(),\n",
    "           on=['algorithm','dataset','target_noise'],suffixes=('','_repeats'))\n",
    "\n",
    "# accuracy-based exact solutions \n",
    "df_results2['accuracy_solution'] = df_results2['r2_test'].apply(lambda x: x > 0.999).astype(float)\n",
    "\n",
    "# get mean solution rates for algs on datasets at specific noise levels, averaged over trials \n",
    "for soln in ['accuracy_solution','symbolic_solution']:\n",
    "    df_results2 = df_results2.merge(\n",
    "        df_results2.groupby(['algorithm','dataset','target_noise'])[soln].mean().reset_index(),\n",
    "                                  on=['algorithm','dataset', 'target_noise'],suffixes=('','_rate'))\n",
    "                                       \n",
    "# # rankings\n",
    "for col in [c for c in df_results2.columns if c.endswith('test') or c.endswith('size')]:\n",
    "    ascending = 'r2' not in col\n",
    "    # df_results2[col+'_rank_per_trial']=df_results2.groupby(\n",
    "    #                     ['dataset','target_noise','random_state'])[col].apply(\n",
    "    #                                                                           lambda x: \n",
    "    #                                                                           round(x,3).rank(\n",
    "    #                                                                           ascending=ascending).astype(int))\n",
    "    df_results2[col + '_rank_per_trial'] = (\n",
    "        df_results2\n",
    "        .groupby(['dataset', 'target_noise', 'random_state'])[col]\n",
    "        .transform(lambda x: x.round(3).rank(ascending=ascending).astype(int))\n",
    "    )\n",
    "    # df_results2[col + '_rank_per_trial'] = (\n",
    "    #     df_results2[col]                      # pega a coluna\n",
    "    #     .round(3)                          # arredonda\n",
    "    #     .groupby([df_results2['dataset'],\n",
    "    #     df_results2['target_noise'],\n",
    "    #     df_results2['random_state']])\n",
    "    #     .rank(method='first', ascending=ascending)  # rankeia dentro do grupo\n",
    "    #     .astype(int)\n",
    "    # )\n",
    "\n",
    "\n",
    "# df_sum = df_results2.groupby(['algorithm','dataset','target_noise','data_group'],as_index=False).median()\n",
    "df_sum = df_results2.groupby(['algorithm','dataset','target_noise','data_group'],as_index=False).median(numeric_only=True)\n",
    "# rankings and normalized scores per dataset\n",
    "for col in [c for c in df_sum.columns if any([c.endswith(n) for n in ['test','size','rate']])]:\n",
    "    ascending = 'r2' not in col and 'solution' not in col\n",
    "    # df_sum[col+'_rank']=df_sum.groupby(['dataset','target_noise'])[col].apply(\n",
    "    #     lambda x:  round(x,3).rank(ascending=ascending).astype(int) )\n",
    "    df_sum[col + '_rank'] = (\n",
    "        df_sum\n",
    "        .groupby(['dataset', 'target_noise'])[col]\n",
    "        .transform(lambda x: x.round(3).rank(ascending=ascending).astype(int))\n",
    "    )\n",
    "    # df_sum[col+'_norm'] = df_sum.groupby(['dataset','target_noise'])[col].apply(lambda x: (x-x.min())/(x.max()-x.min()))\n",
    "    df_sum[col+'_norm'] = df_sum.groupby(['dataset','target_noise'])[col].transform(lambda x: (x-x.min())/(x.max()-x.min()))\n",
    "# df_sum['success_rate'] = df_results2.groupby(['algorithm','dataset'])['solution'].mean().reset_index()\n",
    "for soln in ['accuracy_solution','symbolic_solution']:\n",
    "    df_sum[soln +'_rate_(%)'] = df_sum[soln+'_rate'].apply(lambda x: x*100)\n",
    "df_sum['rmse_test'] = df_sum['mse_test'].apply(np.sqrt)\n",
    "df_sum['log_mse_test'] = df_sum['mse_test'].apply(lambda x: np.log(1+x))\n",
    "df_results = df_results2\n",
    "# df_sum"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# save summary data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not htssr_mask.any():\n",
    "    df_sum.to_csv(rdir+'/GT_symbolic_dataset_results_sum.csv.gz',compression='gzip', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bench_sum_df =  pd.read_csv(rdir+'/GT_symbolic_dataset_results_sum.csv.gz')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "set(bench_sum_df.columns) - set(df_sum.columns), set(df_sum.columns) - set(bench_sum_df.columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "del bench_sum_df['feature_noise']\n",
    "del bench_sum_df['process_time']\n",
    "del df_sum['simplicity']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "merger_df = bench_sum_df.groupby([\"dataset\", \"target_noise\", \"data_group\"]).sum().reset_index()[[\"dataset\", \"target_noise\", \"data_group\"]]\n",
    "ext_sum_df = pd.merge(merger_df, df_sum, how=\"left\", on=[\"dataset\", \"target_noise\", \"data_group\"])\n",
    "ext_sum_df[\"algorithm\"] = \"htssr\"\n",
    "ext_sum_df[\"symbolic_solution_rate_(%)\"] = ext_sum_df[\"symbolic_solution_rate_(%)\"].fillna(0.0)\n",
    "df_sum = pd.concat([bench_sum_df, ext_sum_df])\n",
    "df_sum = df_sum[df_sum[\"target_noise\"].isin([0.0, 0.01])]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# solution rates by alg/dataset/noise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_sum.groupby(['target_noise','algorithm','data_group'])['symbolic_solution_rate_(%)'].mean().round(2).unstack() #.transpose()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# plot comparisons"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.set_style('whitegrid')\n",
    "def compare(df_compare=None, x='r2_test',y='algorithm', row=None, col=None, scale=None, xlim=[], est=np.mean,\n",
    "            orient='h', hue=None, **kwargs):\n",
    "    df_compare = df_compare.copy()\n",
    "    if row==None and col == None:\n",
    "        aspect=1\n",
    "    else:\n",
    "        aspect=0.55\n",
    "#     plt.figure(figsize=(8,7))\n",
    "    tmp = df_compare.groupby(['target_noise',y])[x].apply(est).unstack().mean()\n",
    "    order = tmp.sort_values(ascending=False).index\n",
    "    \n",
    "    for c in [x,y,row,col]:\n",
    "        if c:\n",
    "            df_compare = df_compare.rename(columns={c:c.replace('_',' ').title()})\n",
    "        \n",
    "    x = x.replace('_',' ').title()\n",
    "    y = y.replace('_',' ').title()\n",
    "    if row:\n",
    "        row = row.replace('_',' ').title()\n",
    "    if col:\n",
    "        col = col.replace('_',' ').title()\n",
    "    \n",
    "    if scale=='log' and len(xlim)>0 and xlim[0] == 0:\n",
    "        df_compare.loc[:,x] += 1\n",
    "        xlim[0] = 1\n",
    "        xnew = '1 + '+x\n",
    "        df_compare=df_compare.rename(columns={x:xnew})\n",
    "        x = xnew\n",
    "    if orient=='v':\n",
    "        tmp = x\n",
    "        x = y\n",
    "        y = tmp\n",
    "    if col and not row:\n",
    "        col_wrap = min(4, df_compare[col].nunique()) \n",
    "    else:\n",
    "        col_wrap=None\n",
    "        \n",
    "    cat_args = dict(\n",
    "                data=df_compare, \n",
    "                kind='point',\n",
    "                y=y,\n",
    "                x=x,\n",
    "                order=order,\n",
    "                row=row,\n",
    "                col=col,\n",
    "                col_wrap=col_wrap,\n",
    "                palette='flare_r',\n",
    "                #  palette='Paired',\n",
    "                margin_titles=True,\n",
    "                aspect=aspect,\n",
    "                hue=hue,\n",
    "                legend_out=False,\n",
    "    )\n",
    "    cat_args.update(kwargs)\n",
    "    g = sns.catplot( **cat_args )\n",
    "    if hue:\n",
    "        g._legend.remove() #(title=hue.replace('_',' ').title())\n",
    "        g.axes.flat[-1].legend(title=hue.replace('_',' ').title(),\n",
    "                               fontsize=10\n",
    "                              )\n",
    "    for ax in g.axes.flat: \n",
    "        ax.yaxis.grid(True)\n",
    "        ax.set_ylabel('')\n",
    "        ax.set_xlabel(ax.get_xlabel().replace('Symbolic ',''))\n",
    "        if col:\n",
    "            ttl = ax.get_title()\n",
    "            ax.set_title(ttl.replace(col,'').replace('=',''))\n",
    "\n",
    "    \n",
    "    if len(xlim)>0:\n",
    "        plt.xlim(xlim[0],xlim[1])\n",
    "    if scale:\n",
    "        plt.gca().set_xscale(scale)\n",
    "\n",
    "    sns.despine(left=True, bottom=True)\n",
    "    savename = '-'.join(['cat-'+cat_args['kind']+'plot',x+ '-by-'+ y])\n",
    "    if row: savename += '_'+row\n",
    "    if col: savename += '_'+col\n",
    "    \n",
    "    save(g, savename )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for metric in ['symbolic_solution_rate_(%)']:\n",
    "# for metric in ['symbolic_solution_rate_(%)','r2_test','accuracy_solution']:\n",
    "# for metric in ['r2_test']:\n",
    "    for kind in ['point']: #,'strip']:\n",
    "        args =dict(df_compare=df_sum, x=metric, est=np.mean, orient='h',\n",
    "                   kind=kind) \n",
    "        if kind=='point':\n",
    "            args['join'] = False \n",
    "            # args['markers']=['o','s','x','+']\n",
    "            args['markers']=['x','+','s', 'o']\n",
    "        if metric == 'r2_test':\n",
    "            args['xlim'] = [-1, 1]\n",
    "        compare(**args,\n",
    "                hue='target_noise', \n",
    "                col=None,\n",
    "                ) \n",
    "        compare(**args, \n",
    "                hue='target_noise', \n",
    "                col='data_group',\n",
    "                ) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make the PairGrid\n",
    "df_plot = df_sum.copy()\n",
    "tmp = df_plot.groupby(['target_noise','algorithm'])['symbolic_solution_rate'].mean().unstack().mean()\n",
    "order = tmp.sort_values(ascending=False).index\n",
    "df_plot['size_diff'] = df_plot['model_size']-df_plot['simplified_complexity']+1\n",
    "x_vars=[\n",
    "#         'accuracy_solution_rate_(%)',\n",
    "#         'mse_test',\n",
    "#         'r2_test_rank',\n",
    "#         'r2_test_norm',\n",
    "        'symbolic_solution_rate_(%)',\n",
    "        'r2_test',\n",
    "        'simplified_complexity',\n",
    "#         'size_diff',\n",
    "#         'model_size',\n",
    "#         'training time (s)',\n",
    "#         'solution'\n",
    "]\n",
    "g = sns.PairGrid(df_plot, \n",
    "                 x_vars=x_vars,\n",
    "                 y_vars=['algorithm'],\n",
    "                 height=6.5, \n",
    "                 aspect=0.7,\n",
    "                 hue='target_noise',\n",
    "#                  hue_order=[0.01,0.001,0]\n",
    "#                  hue='dataset'\n",
    "                )\n",
    "g.map(sns.pointplot, \n",
    "#       size=10,\n",
    "      orient=\"h\",\n",
    "      # jitter=False,\n",
    "      order=order,\n",
    "      palette=\"flare_r\",\n",
    "      errwidth=2,\n",
    "      linewidth=0.01,\n",
    "      markeredgecolor='w',\n",
    "      join=False,\n",
    "      estimator=np.mean,\n",
    "      n_boot=1000,\n",
    "      markers=['x','o','s','+'],\n",
    "      # markeralpha=0.5\n",
    "      alpha=0.5\n",
    "     )\n",
    "plt.legend(title='Target Noise')\n",
    "titles = [x.replace('_',' ').title().replace('(S)','(s)') for x in x_vars]\n",
    "\n",
    "for ax, title in zip(g.axes.flat, titles):\n",
    "\n",
    "    # remove xlabel\n",
    "    ax.set_xlabel('')\n",
    "    ax.set_ylabel('')\n",
    "    # Set a different title for each axes\n",
    "    ax.set(title=title)\n",
    "    \n",
    "    if any([n in title.lower() for n in ['size','complexity','time']]):\n",
    "        ax.set_xscale('log')\n",
    "    if 'R2' in title and 'Rank' not in title:\n",
    "        ax.set(title=title.replace('R2','$R^2$'))\n",
    "        ax.set_xlim([0,1])\n",
    "\n",
    "    # Make the grid horizontal instead of vertical\n",
    "    ax.yaxis.grid(True)\n",
    "save(g, 'pairgrid_'+'_'.join(x_vars))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "palinspa",
   "language": "python",
   "name": "palinspa"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
