{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "251a30ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import sklearn.metrics\n",
    "import scipy.stats\n",
    "import os\n",
    "import re"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9f1e3fbd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "101\n"
     ]
    }
   ],
   "source": [
    "df_all = []\n",
    "# get all filnames beginning with results/geodesic_optimization_results_\n",
    "filenames = os.listdir('./results/')\n",
    "print(len(filenames))\n",
    "filenames = [f for f in filenames if re.match(\n",
    "    r'geodesic_optimization_results_\\d+', f\n",
    "    )]\n",
    "filenames = sorted(filenames, key=lambda x: int(re.search(r'\\d+', x).group()))\n",
    "for filename in filenames:\n",
    "    df = pd.read_csv(f'./results/{filename}', index_col=0)\n",
    "    df_all.append(df)\n",
    "df_all = pd.concat(df_all)\n",
    "df_all.reset_index(inplace=True, drop=True)\n",
    "\n",
    "df_all = df_all.drop(\n",
    "    df_all[df_all['sites_source_index']==6].index\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d1b64055",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.stats import t\n",
    "\n",
    "\n",
    "def corrected_std(differences, n_train, n_test):\n",
    "    \"\"\"Corrects standard deviation using Nadeau and Bengio's approach.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    differences : ndarray of shape (n_samples,)\n",
    "        Vector containing the differences in the score metrics of two models.\n",
    "    n_train : int\n",
    "        Number of samples in the training set.\n",
    "    n_test : int\n",
    "        Number of samples in the testing set.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    corrected_std : float\n",
    "        Variance-corrected standard deviation of the set of differences.\n",
    "    \"\"\"\n",
    "    # kr = k times r, r times repeated k-fold crossvalidation,\n",
    "    # kr equals the number of times the model was evaluated\n",
    "    kr = len(differences)\n",
    "    corrected_var = np.var(differences, ddof=1) * (1 / kr + n_test / n_train)\n",
    "    corrected_std = np.sqrt(corrected_var)\n",
    "    return corrected_std\n",
    "\n",
    "\n",
    "def compute_corrected_ttest(differences, df, n_train, n_test):\n",
    "    \"\"\"Computes right-tailed paired t-test with corrected variance.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    differences : array-like of shape (n_samples,)\n",
    "        Vector containing the differences in the score metrics of two models.\n",
    "    df : int\n",
    "        Degrees of freedom.\n",
    "    n_train : int\n",
    "        Number of samples in the training set.\n",
    "    n_test : int\n",
    "        Number of samples in the testing set.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    t_stat : float\n",
    "        Variance-corrected t-statistic.\n",
    "    p_val : float\n",
    "        Variance-corrected p-value.\n",
    "    \"\"\"\n",
    "    mean = np.mean(differences)\n",
    "    std = corrected_std(differences, n_train, n_test)\n",
    "    t_stat = mean / std\n",
    "    return t_stat\n",
    "\n",
    "\n",
    "def get_p_val(t_stat, df=100):\n",
    "    p_val = t.sf(np.abs(t_stat), df)  # right-tailed t-test\n",
    "    return p_val\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "45f8ebbd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['dummy', 'baseline_no_recenter', 'baseline_recenter',\n",
       "       'baseline_fit_intercept', 'geodesic_optim'], dtype=object)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_all.method.unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "64de5b78",
   "metadata": {},
   "outputs": [],
   "source": [
    "source_sizes = {1: 739, 2: 389, 3: 294, 4: 592, 5: 1093}\n",
    "contrast_pairs = [\n",
    "    ('dummy', 'baseline_no_recenter'),\n",
    "    ('dummy', 'baseline_recenter'),\n",
    "    ('baseline_no_recenter', 'baseline_fit_intercept'),\n",
    "    ('baseline_fit_intercept', 'geodesic_optim')\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "d2170c77",
   "metadata": {},
   "outputs": [],
   "source": [
    "p_vals = list()\n",
    "for ref, pro in contrast_pairs:\n",
    "    for source in df_all.sites_source_index.unique():\n",
    "    # for source in [1]:\n",
    "        source = int(source)\n",
    "    #     print(source, target)\n",
    "        df_ref = df_all.query(\n",
    "            f\"method == '{ref}' & sites_source_index == {source}\"\n",
    "        ).reset_index(drop=True)\n",
    "        df_pro = df_all.query(\n",
    "            f\"method == '{pro}' & sites_source_index == {source}\"\n",
    "        ).reset_index(drop=True)\n",
    "\n",
    "        assert np.all(df_ref.index == df_pro.index)\n",
    "        assert np.all(df_ref.y_true == df_pro.y_true)\n",
    "        differences = list()\n",
    "        for gg, index in df_ref.groupby('split_index').groups.items():\n",
    "            df_sub1 = df_ref.iloc[index]\n",
    "            df_sub2 = df_pro.iloc[index]\n",
    "            sp_ref = scipy.stats.spearmanr(df_sub1['y_true'], df_sub1['y_pred']).statistic\n",
    "            sp_pro = scipy.stats.spearmanr(df_sub2['y_true'], df_sub2['y_pred']).statistic\n",
    "            mae_ref = sklearn.metrics.mean_absolute_error(df_sub1['y_true'], df_sub1['y_pred'])\n",
    "            mae_pro = sklearn.metrics.mean_absolute_error(df_sub2['y_true'], df_sub2['y_pred'])\n",
    "            r2_ref = sklearn.metrics.r2_score(df_sub1['y_true'], df_sub1['y_pred'])\n",
    "            r2_pro = sklearn.metrics.r2_score(df_sub2['y_true'], df_sub2['y_pred'])\n",
    "\n",
    "            differences.append(\n",
    "                dict(spearman=sp_pro - sp_ref,\n",
    "                     mae=mae_pro - mae_ref,\n",
    "                     r2=r2_pro - r2_ref)\n",
    "            )\n",
    "        differences = pd.DataFrame(differences)\n",
    "        n_train = source_sizes[source]\n",
    "        n_test = int(len(df_sub1['y_pred']) // 2)\n",
    "        df_p = differences.apply(lambda x: get_p_val(compute_corrected_ttest(differences=x, df=100, n_train=n_train, n_test=n_test)))\n",
    "        df_p['contr'] = f'{pro}-{ref}'\n",
    "        df_p['source'] = source\n",
    "        p_vals.append(df_p)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "6c6f0176",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_pval = pd.concat(p_vals, axis=1).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "bb0a9e39",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>spearman</th>\n",
       "      <th>mae</th>\n",
       "      <th>r2</th>\n",
       "      <th>contr</th>\n",
       "      <th>source</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000165</td>\n",
       "      <td>baseline_no_recenter-dummy</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000202</td>\n",
       "      <td>baseline_no_recenter-dummy</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.352656</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000012</td>\n",
       "      <td>baseline_no_recenter-dummy</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.342457</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.329652</td>\n",
       "      <td>baseline_no_recenter-dummy</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.026381</td>\n",
       "      <td>0.000001</td>\n",
       "      <td>0.000271</td>\n",
       "      <td>baseline_no_recenter-dummy</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>0.307687</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>baseline_recenter-dummy</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>0.000019</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>baseline_recenter-dummy</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>baseline_recenter-dummy</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>baseline_recenter-dummy</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>baseline_recenter-dummy</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>baseline_fit_intercept-baseline_no_recenter</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>0.006434</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>baseline_fit_intercept-baseline_no_recenter</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>0.00058</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>baseline_fit_intercept-baseline_no_recenter</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>baseline_fit_intercept-baseline_no_recenter</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>baseline_fit_intercept-baseline_no_recenter</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.056037</td>\n",
       "      <td>0.015017</td>\n",
       "      <td>geodesic_optim-baseline_fit_intercept</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>0.000001</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000003</td>\n",
       "      <td>geodesic_optim-baseline_fit_intercept</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>0.000534</td>\n",
       "      <td>0.04902</td>\n",
       "      <td>0.14088</td>\n",
       "      <td>geodesic_optim-baseline_fit_intercept</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>0.414214</td>\n",
       "      <td>0.001135</td>\n",
       "      <td>0.025483</td>\n",
       "      <td>geodesic_optim-baseline_fit_intercept</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.002563</td>\n",
       "      <td>0.000052</td>\n",
       "      <td>geodesic_optim-baseline_fit_intercept</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    spearman       mae        r2                                        contr  \\\n",
       "0        0.0       0.0  0.000165                   baseline_no_recenter-dummy   \n",
       "1        0.0       0.0  0.000202                   baseline_no_recenter-dummy   \n",
       "2   0.352656       0.0  0.000012                   baseline_no_recenter-dummy   \n",
       "3   0.342457       0.0  0.329652                   baseline_no_recenter-dummy   \n",
       "4   0.026381  0.000001  0.000271                   baseline_no_recenter-dummy   \n",
       "5   0.307687       0.0       0.0                      baseline_recenter-dummy   \n",
       "6   0.000019       0.0       0.0                      baseline_recenter-dummy   \n",
       "7        0.0       0.0       0.0                      baseline_recenter-dummy   \n",
       "8        0.0       0.0       0.0                      baseline_recenter-dummy   \n",
       "9        0.0       0.0       0.0                      baseline_recenter-dummy   \n",
       "10       0.0       0.0       0.0  baseline_fit_intercept-baseline_no_recenter   \n",
       "11  0.006434       0.0       0.0  baseline_fit_intercept-baseline_no_recenter   \n",
       "12   0.00058       0.0       0.0  baseline_fit_intercept-baseline_no_recenter   \n",
       "13       0.0       0.0       0.0  baseline_fit_intercept-baseline_no_recenter   \n",
       "14       0.0       0.0       0.0  baseline_fit_intercept-baseline_no_recenter   \n",
       "15       0.0  0.056037  0.015017        geodesic_optim-baseline_fit_intercept   \n",
       "16  0.000001       0.0  0.000003        geodesic_optim-baseline_fit_intercept   \n",
       "17  0.000534   0.04902   0.14088        geodesic_optim-baseline_fit_intercept   \n",
       "18  0.414214  0.001135  0.025483        geodesic_optim-baseline_fit_intercept   \n",
       "19       0.0  0.002563  0.000052        geodesic_optim-baseline_fit_intercept   \n",
       "\n",
       "   source  \n",
       "0       1  \n",
       "1       2  \n",
       "2       3  \n",
       "3       4  \n",
       "4       5  \n",
       "5       1  \n",
       "6       2  \n",
       "7       3  \n",
       "8       4  \n",
       "9       5  \n",
       "10      1  \n",
       "11      2  \n",
       "12      3  \n",
       "13      4  \n",
       "14      5  \n",
       "15      1  \n",
       "16      2  \n",
       "17      3  \n",
       "18      4  \n",
       "19      5  "
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_pval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "9e1a81c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_pval.to_csv('./results/p_values.csv')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "synapse",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
