{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e4ec90b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "from natsort import natsorted\n",
    "import glob\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from copydf import copyDF\n",
    "import matplotlib.pyplot as plt\n",
    "import oapackage as oa\n",
    "from scipy.stats import hmean\n",
    "from natsort import natsorted\n",
    "\n",
    "import matplotlib.patches as mpatches\n",
    "import matplotlib.lines as mlines\n",
    "import seaborn as sns\n",
    "from matplotlib.colors import ListedColormap\n",
    "import matplotlib\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b318acdd",
   "metadata": {},
   "source": [
    "# Model Selection\n",
    "## Integrate three selection strategies\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a719630",
   "metadata": {},
   "outputs": [],
   "source": [
    "# functions for DTO\n",
    "def l2norm(matrix_1, matrix_2):\n",
    "    \"\"\"calculate Euclidean distance\n",
    "    Args:\n",
    "        matrix_1 (n*d np array): n is the number of instances, d is num of metric\n",
    "        matrix_2 (n*d np array): same as matrix_1\n",
    "    Returns:\n",
    "        float: the row-wise Euclidean distance \n",
    "    \"\"\"\n",
    "    return np.power(np.sum(np.power(matrix_1-matrix_2, 2), axis=1), 0.5)\n",
    "    \n",
    "\n",
    "def DTO(performacne_metric):\n",
    "    \"\"\"Pick the candidate model closest to the utopia point (Distance To Optimal).\n",
    "\n",
    "    Every column is divided by its own maximum, so the utopia point becomes the\n",
    "    all-ones vector; the model whose normalized row has the smallest Euclidean\n",
    "    distance to that point is selected.\n",
    "\n",
    "    Args:\n",
    "        performacne_metric (array): 2D array, row: models, column: AUC of different subgroups.\n",
    "            Object-dtype arrays holding numbers are accepted and converted to float.\n",
    "    Returns:\n",
    "        int: row index of the model with the smallest distance (first one on ties)\n",
    "    Raises:\n",
    "        ValueError: if the input is not a non-empty 2D array\n",
    "    \"\"\"\n",
    "    metric = np.asarray(performacne_metric, dtype=float)\n",
    "    if metric.ndim != 2 or metric.size == 0:\n",
    "        raise ValueError('performacne_metric must be a non-empty 2D array')\n",
    "\n",
    "    # per-column best value (the utopia point before normalization)\n",
    "    utopia = metric.max(axis=0)\n",
    "    normalized_metric = metric / utopia\n",
    "\n",
    "    # distance of every model to the normalized utopia point (1, ..., 1)\n",
    "    distances = l2norm(normalized_metric, np.ones_like(normalized_metric))\n",
    "\n",
    "    # return the index of model with the smallest distance\n",
    "    return np.argmin(distances)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "250b3958",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Functions for Pareto\n",
    "\n",
    "def find_mingroup(df):\n",
    "    \"\"\"Find the subgroup whose best AUC over all rows of ``df`` is the lowest.\n",
    "\n",
    "    The ``auc-group`` columns are ordered with ``natsorted`` so the returned\n",
    "    position refers to the same group ordering that ``get_groups`` produces.\n",
    "\n",
    "    Args:\n",
    "        df (pd.DataFrame): one row per candidate model, with columns matching ``auc-group``\n",
    "    Returns:\n",
    "        int: position (in natural column order) of the group with the smallest maximum AUC\n",
    "    \"\"\"\n",
    "    group_cols = natsorted(list(df.filter(regex=(\"auc-group\"))))\n",
    "    aucs = df[group_cols].to_numpy()\n",
    "    # best AUC reached for each group, then the group where that best is lowest\n",
    "    best_per_group = [max(aucs[:, i].tolist()) for i in range(aucs.shape[1])]\n",
    "    return np.argmin(best_per_group)\n",
    "    \n",
    "\n",
    "def get_groups(df, metric = 'auc'):\n",
    "    \"\"\"Collect the per-group values of ``metric`` as a list of float arrays.\n",
    "\n",
    "    Args:\n",
    "        df (pd.DataFrame): table with columns matching ``<metric>-group``\n",
    "        metric (str): metric prefix used to pick the columns\n",
    "    Returns:\n",
    "        list: one float array per group column (columns taken in natural sort order),\n",
    "            each holding that column's values over the rows of ``df``\n",
    "    \"\"\"\n",
    "    group_cols = natsorted(list(df.filter(regex=(metric+'-group'))))\n",
    "    per_group = df[group_cols].values.T\n",
    "    return [row.astype('float') for row in per_group]\n",
    "\n",
    "def cal_pareto(groups, min_group_no):\n",
    "    \"\"\"Choose one model from the Pareto set built over the per-group scores.\n",
    "\n",
    "    Every model contributes the tuple of its scores across ``groups`` to an\n",
    "    ``oa.ParetoDoubleLong`` object; among the indices that object reports back,\n",
    "    the one with the highest score in group ``min_group_no`` is returned.\n",
    "\n",
    "    Args:\n",
    "        groups (list): one 1D array per group, all indexed by model\n",
    "        min_group_no (int): position in ``groups`` of the group used for the final pick\n",
    "    Returns:\n",
    "        the model index (as stored in the Pareto object) of the selected model\n",
    "    \"\"\"\n",
    "    n_models = groups[0].shape[0]\n",
    "    front = oa.ParetoDoubleLong()\n",
    "    for model_idx in range(n_models):\n",
    "        scores = oa.doubleVector(tuple([g[model_idx] for g in groups]))\n",
    "        front.addvalue(scores, model_idx)\n",
    "\n",
    "    # presumably the indices of the Pareto-optimal models -- see the oapackage docs\n",
    "    pareto_indices = front.allindices()\n",
    "    stacked = np.stack(groups)\n",
    "    front_scores = stacked[:, pareto_indices].T\n",
    "\n",
    "    best = np.argmax(front_scores[:, min_group_no])\n",
    "    return pareto_indices[best]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ef4157e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# config\n",
    "\n",
    "exp_dict = {'HAM10000': [{'attribute': 'Sex', 'sens_num': 2}], 'HAM100004': [{'attribute': 'Age', 'sens_num': 4}], \n",
    "           'CXP': [{'attribute': 'Age', 'sens_num': 5}, {'attribute': 'Sex', 'sens_num': 2}, {'attribute': 'Race', 'sens_num': 2}],\n",
    "           'MIMIC_CXR': [{'attribute': 'Age', 'sens_num': 5}, {'attribute': 'Sex', 'sens_num': 2}, {'attribute': 'Race', 'sens_num': 2}],\n",
    "           'PAPILA': [{'attribute': 'Age', 'sens_num': 2}, {'attribute': 'Sex', 'sens_num': 2}],\n",
    "           'ADNI': [{'attribute': 'Age', 'sens_num': 2}, {'attribute': 'Sex', 'sens_num': 2}],\n",
    "           'OCT': [{'attribute': 'Age', 'sens_num': 2}],\n",
    "           'COVID_CT_MD': [{'attribute': 'Age', 'sens_num': 2}, {'attribute': 'Sex', 'sens_num': 2}],\n",
    "           'Fitz17k': [{'attribute': 'skin_type', 'sens_num': 6}],}\n",
    "\n",
    "model_selections = ['overall_auc', 'pareto', 'DTO']\n",
    "#model_selections = ['pareto', 'DTO']\n",
    "methods = ['baseline', 'resampling', 'DomainInd', 'LAFTR', 'CFair','LNL', 'EnD', 'ODR', 'GroupDRO', 'SWAD', \n",
    "            'SAM']\n",
    "path = 'somewhere/model_records/'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b5fd2ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every selection strategy / dataset / sensitive attribute: pick one run per\n",
    "# method from its validation statistics, then tabulate that run's test statistics.\n",
    "for model_selection in model_selections:\n",
    "    # create the output folder up front so to_csv below cannot fail on a missing directory\n",
    "    out_dir = os.path.join('./results/selections', model_selection)\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "\n",
    "    for dataset, sens_exps in exp_dict.items():\n",
    "        if dataset in ['OCT', 'ADNI', 'COVID_CT_MD']:\n",
    "            backbone = 'cusResNet18_3d'\n",
    "        else:\n",
    "            backbone = 'cusResNet18'\n",
    "\n",
    "        for sens_exp in sens_exps:\n",
    "            attribute = sens_exp['attribute']\n",
    "\n",
    "            hashs, mean_rows, std_rows = [], [], []\n",
    "            # set from the 'baseline' runs and reused by the pareto strategy for the other methods\n",
    "            min_group_no = None\n",
    "\n",
    "            for method in methods:\n",
    "                method_dir = os.path.join(path, dataset, attribute, backbone, method)\n",
    "                val_stats = natsorted(glob.glob(os.path.join(method_dir, '*_val_pred_stat.csv')))\n",
    "                if not val_stats:\n",
    "                    raise FileNotFoundError('no *_val_pred_stat.csv found in ' + method_dir)\n",
    "\n",
    "                # first row of every validation file, tagged with the run hash taken from\n",
    "                # the file name (second '_'-separated token)\n",
    "                val_rows = []\n",
    "                for val_stat in val_stats:\n",
    "                    val_stat_df = pd.read_csv(val_stat)\n",
    "                    val_stat_df['hash'] = os.path.basename(val_stat).split('_')[1]\n",
    "                    val_rows.append(pd.DataFrame(val_stat_df.iloc[0]).T)\n",
    "                val_metrics = pd.concat(val_rows, ignore_index=True)\n",
    "\n",
    "                # decide min/max AUC group before selection\n",
    "                if method == 'baseline':\n",
    "                    min_group_no = find_mingroup(val_metrics)\n",
    "\n",
    "                # model selection --> row position of the chosen run\n",
    "                if model_selection == 'overall_auc':\n",
    "                    best_row = pd.to_numeric(val_metrics['Val Overall AUC']).idxmax()\n",
    "                elif model_selection == 'DTO':\n",
    "                    aucs = val_metrics[list(val_metrics.filter(regex=(\"auc-group\")))].values\n",
    "                    best_row = DTO(aucs)\n",
    "                elif model_selection == 'pareto':\n",
    "                    if min_group_no is None:\n",
    "                        raise RuntimeError('baseline must be processed first so the pareto strategy has a reference group')\n",
    "                    best_row = cal_pareto(get_groups(val_metrics, 'auc'), min_group_no)\n",
    "                else:\n",
    "                    raise ValueError(\"selection strategy not supported\")\n",
    "\n",
    "                hash_id = val_metrics.iloc[best_row]['hash']\n",
    "                hashs.append(hash_id)\n",
    "\n",
    "                test_result = pd.read_csv(os.path.join(\n",
    "                    method_dir, '{meth}_{hashid}_test_pred_stat.csv'.format(meth = method, hashid = hash_id)))\n",
    "                # row 0 is used as the mean and row 1 as the std of the test metrics\n",
    "                mean_rows.append(pd.DataFrame(test_result.iloc[0]).T)\n",
    "                std_rows.append(pd.DataFrame(test_result.iloc[1]).T)\n",
    "\n",
    "            # drop the saved index column when the CSVs carry one\n",
    "            mean_metrics = pd.concat(mean_rows).drop(columns = ['Unnamed: 0'], errors = 'ignore')\n",
    "            std_metrics = pd.concat(std_rows).drop(columns = ['Unnamed: 0'], errors = 'ignore')\n",
    "\n",
    "            mean_metrics = mean_metrics.astype('float').round(4)\n",
    "            std_metrics = std_metrics.astype('float').round(3)\n",
    "\n",
    "            # join every cell as '<mean><plus-minus sign><std>'\n",
    "            stacked = pd.concat([mean_metrics.reset_index(drop=True).stack(),\n",
    "                                 std_metrics.reset_index(drop=True).stack()], axis=1)\n",
    "            out = stacked.apply(lambda x: u\"\\u00B1\".join(x.astype('str')), axis=1).unstack()\n",
    "            out.insert(0, 'methods', methods)\n",
    "            out['hash'] = hashs\n",
    "\n",
    "            out.to_csv(os.path.join(out_dir, '{}-{}.csv'.format(dataset, attribute)), index = False)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "768c3b9d",
   "metadata": {},
   "source": [
    "### Calculate Max/Min/Gap for Different selections\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ece40d3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_max_min_gap(df):\n",
    "    mean_std = df.drop(columns=['hash']).to_numpy()\n",
    "\n",
    "    rt = np.char.split(mean_std[:, 1:].astype('str'), u\"\\u00B1\")\n",
    "    means = np.zeros(rt.shape)\n",
    "    \n",
    "    for i in range(rt.shape[0]):\n",
    "        for j in range(rt.shape[1]):\n",
    "            means[i][j] = float(rt[i][j][0])\n",
    "    means = pd.DataFrame(means, columns = df.columns[1:-1])\n",
    "\n",
    "    aucs = means[list(means.filter(regex=(\"auc-group\")))]\n",
    "    aucs = aucs.to_numpy()\n",
    "    max_auc, min_auc = [], []\n",
    "    for i in range(aucs.shape[0]):\n",
    "        temp = aucs[i].tolist()\n",
    "        max_auc.append(max(temp))\n",
    "        min_auc.append(min(temp))\n",
    "    gap = np.asarray(max_auc)-np.asarray(min_auc)\n",
    "    return np.asarray(max_auc), np.asarray(min_auc), gap\n",
    "\n",
    "\n",
    "# Append 'Max AUC' / 'Min AUC' / 'AUC Gap' columns to every result table, in place.\n",
    "test_results_path = './results/selections/'\n",
    "selections = ['DTO', 'overall_auc', 'pareto']\n",
    "for selection in selections:\n",
    "    # keep only CSV tables; the folder may also hold e.g. checkpoint folders or hidden files\n",
    "    exp_paths = [name for name in os.listdir(os.path.join(test_results_path, selection))\n",
    "                 if name.endswith('.csv')]\n",
    "    for exp_path in exp_paths:\n",
    "        result_path = os.path.join(test_results_path, selection, exp_path)\n",
    "        results = pd.read_csv(result_path)\n",
    "\n",
    "        max_auc, min_auc, gap = get_max_min_gap(results)\n",
    "        results['Max AUC'] = max_auc\n",
    "        results['Min AUC'] = min_auc\n",
    "        results['AUC Gap'] = gap\n",
    "\n",
    "        # overwrite the table that was just read\n",
    "        results.to_csv(result_path, index = False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2153be1",
   "metadata": {},
   "source": [
    "## Rank the methods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b206b8c4",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Turn every result table into a table of per-column ranks of the methods.\n",
    "methods = ['baseline', 'resampling', 'DomainInd', 'LAFTR', 'CFair', 'LNL', 'EnD', 'ODR', 'GroupDRO', 'SWAD', 'SAM']\n",
    "for selection in selections:\n",
    "    result_path = './results/selections/{}/'.format(selection)\n",
    "    results = glob.glob(result_path + '*.csv')\n",
    "\n",
    "    # ranks go to a sibling '<selection>_rank' folder; create it if it is missing\n",
    "    rank_dir = './results/selections/{}/'.format(selection+'_rank')\n",
    "    os.makedirs(rank_dir, exist_ok=True)\n",
    "\n",
    "    for result in results:\n",
    "        names = os.path.basename(result).split('.csv')[0].split('-')\n",
    "        dataset, sensitive  = names[0], names[1]\n",
    "\n",
    "        mean_std_df = pd.read_csv(result).drop(columns=['hash'])\n",
    "        # first column is skipped (method names); the rest is parsed\n",
    "        metric_cols = mean_std_df.columns[1:]\n",
    "\n",
    "        # keep the part before the plus-minus sign of every cell\n",
    "        means = pd.DataFrame(\n",
    "            [[float(str(cell).split(u\"\\u00B1\")[0]) for cell in row] for row in mean_std_df[metric_cols].to_numpy()],\n",
    "            columns = metric_cols)\n",
    "        # todo some ascending, some descending\n",
    "        means = means.rank(ascending = False, numeric_only = True)\n",
    "        # take the method names from the table itself so they always match its rows\n",
    "        means.insert(0, 'methods', mean_std_df['methods'].to_numpy())\n",
    "\n",
    "        means.to_csv(rank_dir + '{}-{}.csv'.format(dataset, sensitive), index = False)\n",
    "means"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6b1fa7e",
   "metadata": {},
   "source": [
    "## Statistical Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfd3cfb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from scipy import stats\n",
    "import scikit_posthocs as sp\n",
    "import glob\n",
    "import Orange\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf934c3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Critical-difference diagram of the methods for every metric and selection strategy.\n",
    "selections = ['overall_auc', 'pareto', 'DTO']\n",
    "Metrics = ['Max AUC', 'Min AUC', 'AUC Gap', 'Test Overall AUC']\n",
    "methods_toplot = ['ERM', 'Resampling', 'DomainInd', 'LAFTR', 'CFair', 'LNL', 'EnD', 'ODR', 'GroupDRO', 'SWAD', 'SAM']\n",
    "\n",
    "# make sure the folder the diagrams are saved to exists\n",
    "cd_dir = 'results/selections/cd_diagrams'\n",
    "os.makedirs(cd_dir, exist_ok=True)\n",
    "\n",
    "for metric in Metrics:\n",
    "    for selection in selections:\n",
    "        result_path = './results/selections/{}/'.format(selection + '_rank')\n",
    "        results = glob.glob(result_path + '*.csv')\n",
    "        if not results:\n",
    "            # without this, avgrank below would be undefined or left over from a previous iteration\n",
    "            raise FileNotFoundError('no rank tables found in ' + result_path)\n",
    "\n",
    "        ranks = []\n",
    "        for result in results:\n",
    "            values = pd.read_csv(result)[metric].values\n",
    "            if metric == 'AUC Gap':\n",
    "                # the rank tables put the largest value first; a smaller gap is better,\n",
    "                # so mirror the ranks: r -> (number of ranked methods + 1) - r\n",
    "                values = (len(values) + 1) - values\n",
    "            ranks.append(values)\n",
    "\n",
    "        # average rank of every method over all rank tables\n",
    "        avgrank = np.mean(ranks, 0)\n",
    "        cd = Orange.evaluation.compute_CD(avgrank, len(ranks), alpha='0.05', test='nemenyi')\n",
    "        print('cd=', cd)\n",
    "        Orange.evaluation.graph_ranks(avgrank, methods_toplot, cd=cd, width=5, textspace=1.5,\n",
    "                                filename = os.path.join(cd_dir, '{}_{}_in_distribution.pdf'.format(selection, metric)))\n",
    "            \n",
    "            \n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "791cca96",
   "metadata": {},
   "source": [
    "### Rank ERM across datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21789629",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the selection strategies on the first row of every result table.\n",
    "selections = ['overall_auc', 'DTO', 'pareto']\n",
    "Metrics = ['Min AUC']\n",
    "\n",
    "methods_toplot = ['ERM']\n",
    "# display names keyed by strategy, so labels cannot get out of sync with `selections`\n",
    "selection_labels = {'overall_auc': 'Overall Performance', 'DTO': 'DTO', 'pareto': 'Pareto Optimal'}\n",
    "\n",
    "strategy_rows = []\n",
    "for selection in selections:\n",
    "\n",
    "    value_path = './results/selections/{}/'.format(selection)\n",
    "    value_results = natsorted(glob.glob(value_path + '*.csv'))\n",
    "\n",
    "    mean_values = []\n",
    "    column_names = []\n",
    "    for value_result in value_results:\n",
    "        dataset, sensitive = os.path.basename(value_result).split('.csv')[0].split('-')\n",
    "        column_names.append(dataset  + '-' + sensitive)\n",
    "\n",
    "        # process values: keep the part before the plus-minus sign of every cell\n",
    "        value_df = pd.read_csv(value_result).drop(columns=['hash'])\n",
    "        metric_cols = value_df.columns[1:]\n",
    "        means = pd.DataFrame(\n",
    "            [[float(str(cell).split(u\"\\u00B1\")[0]) for cell in row] for row in value_df[metric_cols].to_numpy()],\n",
    "            columns = metric_cols)\n",
    "\n",
    "        # first row of the table -- presumably the baseline (ERM) entry; verify against the CSV\n",
    "        values = means[Metrics].iloc[0]\n",
    "        mean_values.append(values)\n",
    "\n",
    "    mean_values = np.stack(mean_values).squeeze()\n",
    "    strategy_rows.append(pd.DataFrame([mean_values], columns=column_names))\n",
    "\n",
    "total_df = pd.concat(strategy_rows)\n",
    "total_df['Selection Strategy'] = [selection_labels[selection] for selection in selections]\n",
    "total_df = total_df[ ['Selection Strategy'] + [ col for col in total_df.columns if col != 'Selection Strategy' ] ]\n",
    "\n",
    "# rank the strategies within every dataset column, then average per strategy\n",
    "total_rank_df = total_df.rank(ascending = False, numeric_only = True)\n",
    "rank_values = total_rank_df.to_numpy()\n",
    "avgrank = np.mean(rank_values, axis=1)\n",
    "\n",
    "# number of ranked columns, taken from the data instead of a hard-coded count\n",
    "n_experiments = rank_values.shape[1]\n",
    "cd = Orange.evaluation.compute_CD(avgrank, n_experiments, alpha='0.05', test='nemenyi')\n",
    "print('cd=', cd)\n",
    "Orange.evaluation.graph_ranks(avgrank, total_df['Selection Strategy'].values, cd=cd, width=5, textspace=1.5,)\n",
    "                        #filename = 'results/selections/cd_diagrams/{}_{}_in_distribution.pdf'.format(selection, metric))\n",
    "           "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.6.8 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  },
  "vscode": {
   "interpreter": {
    "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
