{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac8b1654",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import sklearn.metrics\n",
    "import warnings\n",
    "import csv\n",
    "from os.path import exists\n",
    "from tqdm import *\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import jpype.imports\n",
    "import time \n",
    "\n",
    "\n",
    "\n",
    "# Paths. NOTE(review): absolute, cluster-specific locations; edit them here when running elsewhere.\n",
    "resdir = '/ocean/projects/mth230012p/user/hrt-factor/results/'\n",
    "datadir = '/ocean/projects/mth230012p/user/hrt-factor/data/'\n",
    "modeldir = '/ocean/projects/mth230012p/user/hrt-factor/models/'\n",
    "figdir = 'figures/'\n",
    "\n",
    "# Cut-offs used when selecting columns below.\n",
    "N_LEADING_COLUMNS = 331   # number of leading fully observed columns to keep\n",
    "MIN_COLUMN_TOTAL = 400    # a column is kept only if its sum exceeds this\n",
    "\n",
    "# Left-join the alteration matrix with the metastasis table on 'uid' and cache the result.\n",
    "genes = pd.read_csv(datadir + 'missing_gene_alteration_matrix.csv',index_col=0)\n",
    "metastases = pd.read_csv(datadir + 'metastasis_bysamples.csv',index_col=0)\n",
    "ds = genes.merge(metastases,on=\"uid\",how=\"left\")\n",
    "ds.to_pickle(datadir + 'merged_dataset.pkl')\n",
    "\n",
    "# Drop every column that contains a missing value, then keep the leading block of columns.\n",
    "nomissing = ds.dropna(axis='columns')\n",
    "nomissing = nomissing.iloc[:,:N_LEADING_COLUMNS]\n",
    "\n",
    "# One boolean per column: does the column total exceed the cut-off?\n",
    "nonzeros = nomissing.sum(0) > MIN_COLUMN_TOTAL\n",
    "# reset_index() moves the column labels into 'index' and leaves the flag in column 0.\n",
    "nonzeros = nonzeros.reset_index()\n",
    "# Labels of the retained columns (presumably gene names -- confirm against the CSV header).\n",
    "gene_names = nonzeros[nonzeros[0]]['index'].reset_index(drop=True)\n",
    "\n",
    "# Names looked up by Y_target position in rollup_ds.\n",
    "tumor_names = ['adrenal_gland', 'biliary_tract', 'bladder_or_urinary_tract', 'bone',\n",
    "       'bowel', 'breast', 'cns_brain', 'dist_lymph', 'genital_female',\n",
    "       'genital_male', 'head_and_neck', 'kidney', 'liver', 'lung', 'lymph',\n",
    "       'mediastinum', 'other', 'ovary', 'peripheral_nervous_system',\n",
    "       'peritoneum', 'pleura', 'regional_lymph', 'skin']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50f4cc55",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def compile_results(search,duration,label):\n",
    "    \"\"\"Reshape a search's adjacency output into one row per (Y, X) pair.\n",
    "\n",
    "    Relies on the notebook globals ``X``, ``Y``, ``truth`` and ``ds`` being\n",
    "    set before the call.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    search : object whose ``get_pcalg()`` result is sliced and melted like a DataFrame.\n",
    "    duration : value written to the ``duration`` column of every row.\n",
    "    label : value written to the ``method`` column of every row.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    DataFrame with columns Y_target, X_target, edge, is_edge, method,\n",
    "    dataset and duration.\n",
    "    \"\"\"\n",
    "    pcalg = search.get_pcalg()\n",
    "    n_x = X.shape[1]\n",
    "\n",
    "    # Rows from position n_x onward are relabelled Y0, Y1, ...\n",
    "    y_rows = pcalg[n_x:].copy(deep=True)\n",
    "    y_rows.index = [\"Y\" + str(j) for j in range(0, Y.shape[1])]\n",
    "\n",
    "    # Keep the first n_x columns and go long: one row per (Y row, X column).\n",
    "    x_columns = y_rows.columns[:n_x]\n",
    "    table = pd.melt(y_rows[x_columns], ignore_index=False).reset_index()\n",
    "    table.columns = [\"Y_target\", \"X_target\", \"edge\"]\n",
    "\n",
    "    # Drop the one-character prefix of each label and parse the rest as a number.\n",
    "    table[\"Y_target\"] = pd.to_numeric(table[\"Y_target\"].str.slice(1))\n",
    "    table[\"X_target\"] = pd.to_numeric(table[\"X_target\"].str.slice(1))\n",
    "\n",
    "    def _in_truth(row):\n",
    "        return row[\"X_target\"] in truth[row[\"Y_target\"]]\n",
    "\n",
    "    table['is_edge'] = table.apply(_in_truth, axis=1)\n",
    "    table['method'] = label\n",
    "    table['dataset'] = ds\n",
    "    table['duration'] = duration\n",
    "    return table\n",
    "   \n",
    "\n",
    "# Ground-truth arrays already read from disk, keyed by file path.\n",
    "_TRUTH_CACHE = {}\n",
    "\n",
    "\n",
    "def get_truth(row):\n",
    "    \"\"\"Tell whether the (X_target, Y_target) pair of ``row`` is a true edge.\n",
    "\n",
    "    The truth array is read from ``datadir + <Dataset> + '_xtrue.npy'`` when\n",
    "    the dataset name contains 'collide', otherwise from ``'_ytrue.npy'``.\n",
    "    Each file is loaded once and kept in ``_TRUTH_CACHE``, because this\n",
    "    function is applied row by row; restart the kernel (or clear the cache)\n",
    "    if the files change on disk.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    row : mapping with keys 'Dataset', 'X_target' and 'Y_target'.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    bool, or None when the dataset name contains 'actualsmall' (and not\n",
    "    'collide'), for which no truth file is consulted.\n",
    "    \"\"\"\n",
    "    dataset = row['Dataset']\n",
    "    if \"collide\" in dataset:\n",
    "        suffix = \"_xtrue.npy\"\n",
    "    elif 'actualsmall' in dataset:\n",
    "        return None\n",
    "    else:\n",
    "        suffix = \"_ytrue.npy\"\n",
    "\n",
    "    path = datadir + dataset + suffix\n",
    "    if path not in _TRUTH_CACHE:\n",
    "        _TRUTH_CACHE[path] = np.load(path)\n",
    "    y_true = _TRUTH_CACHE[path]\n",
    "    return row[\"X_target\"] in y_true[row[\"Y_target\"]]\n",
    "\n",
    "\n",
    "def get_metrics(results,search_type,alpha_list = np.arange(0,1.00,0.005)):\n",
    "    \"\"\"Sweep p-value thresholds and compute error rates for one search type.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    results : DataFrame with columns 'Type', 'p' and a boolean 'is_edge'.\n",
    "    search_type : value of 'Type' selecting the rows to evaluate.\n",
    "    alpha_list : iterable of thresholds; handed back unchanged as the first output.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (alpha_list, tpr, fpr, fdr_bh, power_bh). The last four are lists with\n",
    "    one entry per threshold; a ratio with a zero denominator comes out as NaN.\n",
    "    \"\"\"\n",
    "    subset = results[results['Type'] == search_type]\n",
    "    p = subset['p']\n",
    "    is_edge = subset['is_edge']\n",
    "\n",
    "    # Everything that does not depend on the threshold is computed once.\n",
    "    positives = is_edge == 1\n",
    "    negatives = is_edge == 0\n",
    "    n_pos = np.sum(positives)\n",
    "    n_neg = np.sum(negatives)\n",
    "    n_edges = np.sum(is_edge)\n",
    "    rank_frac = p.rank() / subset.shape[0]\n",
    "\n",
    "    tpr, fpr, fdr_bh, power_bh = [], [], [], []\n",
    "    for alpha in alpha_list:\n",
    "        below = p <= alpha\n",
    "        tpr.append(np.sum(below & positives) / n_pos)\n",
    "        fpr.append(np.sum(below & negatives) / n_neg)\n",
    "\n",
    "        # NOTE(review): each p-value is compared with its own alpha * rank / m cut-off;\n",
    "        # this is not the step-up Benjamini-Hochberg rule. Kept as-is so results stay the same.\n",
    "        bh_hits = p < alpha * rank_frac\n",
    "        fdr_bh.append(np.sum(bh_hits & (~is_edge)) / np.sum(bh_hits))\n",
    "        power_bh.append(np.sum(bh_hits & is_edge) / n_edges)\n",
    "    return alpha_list, tpr, fpr, fdr_bh, power_bh\n",
    "\n",
    "def rollup_ds(label,model_type,alpha = 0.2,early_stop = 0.5):\n",
    "    \"\"\"Collect the per-(X, Y) result pickles of one dataset and summarise them.\n",
    "\n",
    "    Reads the notebook globals ``datadir``, ``resdir``, ``gene_names`` and\n",
    "    ``tumor_names``. Every (X_target, Y_target) pair without a result file is\n",
    "    appended to ``bash/redo.txt`` as ``label X_target Y_target model_type``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    label : dataset name; selects the ``_x.npy`` / ``_y.npy`` / truth files.\n",
    "    model_type : model tag embedded in the result file names.\n",
    "    alpha : level used for the ``reject`` column (p < alpha * rank / m).\n",
    "    early_stop : p-values above this are replaced by 1.0.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (ds_comb, ds_rollup)\n",
    "        ds_comb   -- all rows read from the result files, with X_target,\n",
    "                     Y_target and (unless the label contains 'actualsmall')\n",
    "                     is_edge added. Rows whose 'Type' was missing appear\n",
    "                     twice: once as 'Hybrid' and once as 'GSO'.\n",
    "        ds_rollup -- one row per (X_target, Y_target, Type[, is_edge]) with the\n",
    "                     maximum p, gene / tumor names, the reject flag and the\n",
    "                     summed duration.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError if no truth file rule matches ``label``.\n",
    "    FileNotFoundError if not a single result file exists.\n",
    "    \"\"\"\n",
    "    has_truth = \"actualsmall\" not in label\n",
    "    truth = None\n",
    "    if label in ['synth_med','synth'] or 'semisynth' in label:\n",
    "        truth = np.load(datadir + label + \"_ytrue.npy\")\n",
    "    elif \"collide\" in label:\n",
    "        truth = np.load(datadir + label + \"_xtrue.npy\")\n",
    "    if has_truth and truth is None:\n",
    "        raise ValueError(\"no ground-truth file rule matches dataset label \" + repr(label))\n",
    "\n",
    "    x_dim = np.load(datadir + label + \"_x.npy\").shape[1]\n",
    "    y_dim = np.load(datadir + label + \"_y.npy\").shape[1]\n",
    "\n",
    "    # Result files drop the underscores of the dataset label.\n",
    "    prefix = resdir + label.replace('_', '') + \"_\" + model_type\n",
    "    outfile = \"bash/redo.txt\"\n",
    "    parts = []\n",
    "    missing = []\n",
    "    for X_target in trange(x_dim):\n",
    "        for Y_target in range(y_dim):\n",
    "            suffix = \"_\" + str(X_target) + \"_\" + str(Y_target) + \"_combined.pkl\"\n",
    "            filename = prefix + suffix\n",
    "            if not os.path.exists(filename):\n",
    "                # Fall back to the '_univariate_' naming variant.\n",
    "                filename = prefix + \"_univariate\" + suffix\n",
    "            if not os.path.exists(filename):\n",
    "                missing.append([label, X_target, Y_target, model_type])\n",
    "                continue\n",
    "\n",
    "            part = pd.read_pickle(filename)\n",
    "            part.loc[part['p'] > early_stop,'p'] = 1.0\n",
    "            if has_truth:\n",
    "                part['is_edge'] = X_target in truth[Y_target]\n",
    "            part[\"X_target\"] = X_target\n",
    "            part[\"Y_target\"] = Y_target\n",
    "            parts.append(part)\n",
    "\n",
    "    # Record the pairs that still need to be run, even if nothing was found at all.\n",
    "    if missing:\n",
    "        with open(outfile, 'a', newline='') as f:\n",
    "            csv.writer(f, delimiter =' ').writerows(missing)\n",
    "    if not parts:\n",
    "        raise FileNotFoundError(\"no result files found for \" + label + \" / \" + model_type + \" under \" + resdir)\n",
    "\n",
    "    # One concat at the end instead of growing the frame inside the loop.\n",
    "    ds_comb = pd.concat(parts, ignore_index=True, axis=0)\n",
    "\n",
    "    # Rows without a 'Type' count for both searches: keep them as 'Hybrid' and\n",
    "    # append a copy tagged 'GSO'. (The copy used to be built and then thrown away.)\n",
    "    untyped = ds_comb[\"Type\"].isnull()\n",
    "    gso_rows = ds_comb.loc[untyped].copy()\n",
    "    gso_rows[\"Type\"] = \"GSO\"\n",
    "    ds_comb.loc[untyped,\"Type\"] = \"Hybrid\"\n",
    "    if not gso_rows.empty:\n",
    "        ds_comb = pd.concat([ds_comb, gso_rows], ignore_index=True, axis=0)\n",
    "\n",
    "    keys = [\"X_target\",\"Y_target\",\"Type\"]\n",
    "    group_cols = keys + [\"is_edge\"] if has_truth else keys\n",
    "    ds_rollup = ds_comb[group_cols + [\"p\"]].groupby(group_cols).max().reset_index()\n",
    "    ds_rollup['gene'] = ds_rollup['X_target'].apply( lambda x : gene_names[x])\n",
    "    ds_rollup['tumor'] = ds_rollup['Y_target'].apply( lambda x : tumor_names[x])\n",
    "\n",
    "    # NOTE(review): ranks are taken over all rows of the rollup, i.e. pooled across 'Type'.\n",
    "    cutoff_FDR = alpha*ds_rollup['p'].rank(method=\"first\")/ds_rollup.shape[0]\n",
    "    ds_rollup['reject'] = ds_rollup['p'] < cutoff_FDR\n",
    "\n",
    "    ds_dur = ds_comb[keys + [\"duration\"]].groupby(keys).sum().reset_index()\n",
    "    ds_rollup = ds_rollup.merge(ds_dur,on=keys)\n",
    "    return ds_comb, ds_rollup"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eeaac3f5",
   "metadata": {},
   "source": [
    "# Compile Experimental Results "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92702eb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _summarise_scsl(dataset_names, model_types, search_type=\"Hybrid\"):\n",
    "    \"\"\"Return the stacked threshold-sweep tables for every (dataset, model type) pair.\n",
    "\n",
    "    For each pair the results are rolled up with ``rollup_ds`` and scored with\n",
    "    ``get_metrics``; the table gets one row per threshold plus descriptive\n",
    "    columns derived from the dataset name.\n",
    "    \"\"\"\n",
    "    frames = []\n",
    "    for name in dataset_names:\n",
    "        for model_type in model_types:\n",
    "            _, rollup = rollup_ds(name, model_type)\n",
    "            thresh, tpr, fpr, fdr_bh, power_bh = get_metrics(rollup, search_type)\n",
    "            summary = pd.DataFrame(thresh,columns=['thresh'])\n",
    "            summary['dataset'] = name\n",
    "            summary['model_type'] = model_type\n",
    "            summary['tpr'] = tpr\n",
    "            summary['fpr'] = fpr\n",
    "            summary['fdr_bh'] = fdr_bh\n",
    "            summary['power_bh'] = power_bh\n",
    "            if \"-\" in name:\n",
    "                # 'semisynth-10-10' -> num_X '10', num_Y '10' (stored as strings)\n",
    "                tokens = name.split('-')\n",
    "                summary[\"num_X\"] = tokens[1]\n",
    "                summary[\"num_Y\"] = tokens[2]\n",
    "            if \"collide\" in name:\n",
    "                summary[\"ds_type\"] = \"collide\"\n",
    "                if \"_\" in name:\n",
    "                    # 'synth_collide_large_0.2' -> corr 0.2\n",
    "                    summary[\"corr\"] = float(name.split(\"_\")[-1])\n",
    "            else:\n",
    "                summary[\"ds_type\"] = \"synth\"\n",
    "            frames.append(summary)\n",
    "    # Single concat instead of re-copying the accumulated frame on every iteration.\n",
    "    return pd.concat(frames, ignore_index=True, axis=0)\n",
    "\n",
    "\n",
    "# Main experiments: both model types.\n",
    "ds_list = ['collidesynth-10-10', 'collidesynth-15-15', 'collidesynth-20-20',\n",
    "           'collidesynth-5-5', 'semisynth-10-10', 'semisynth-15-15','semisynth-20-20',\n",
    "           'semisynth-5-5','synth_med','synth','synth_collide_large_0.2',\n",
    "           'synth_collide_large_0.4','synth_collide_large_0.6','synth_collide_large_0.8' ]\n",
    "result_agg = _summarise_scsl(ds_list, ['logit','mlp'])\n",
    "result_agg.to_pickle(resdir + \"scsl_combined.pkl\")\n",
    "\n",
    "\n",
    "# 'cont' collider datasets: MLP only.\n",
    "ds_list = ['synth_collide_cont_0.1',\n",
    "           'synth_collide_cont_0.3','synth_collide_cont_0.5','synth_collide_cont_0.7','synth_collide_cont_0.9']\n",
    "result_agg = _summarise_scsl(ds_list, ['mlp'])\n",
    "result_agg.to_pickle(resdir + \"scsl_cont.pkl\")\n",
    "\n",
    "\n",
    "# Gather every PC-algorithm result pickle found in resdir.\n",
    "pc_frames = []\n",
    "for filename in os.listdir(resdir):\n",
    "    # Skip benchmark outputs and the aggregate that this cell itself writes.\n",
    "    if 'pcalgo' not in filename or 'benchmark' in filename or 'pcalgo_combined' in filename:\n",
    "        continue\n",
    "    f = os.path.join(resdir, filename)\n",
    "    print(f)\n",
    "    # The level is the second-to-last '_'-separated token of the file name.\n",
    "    alpha = float(filename.split(\"_\")[-2])\n",
    "    temp = pd.read_pickle(f)\n",
    "    dataset_name = temp.loc[0,\"Dataset\"]\n",
    "    temp['alpha'] = alpha\n",
    "    # NOTE(review): files that are neither 'collide' nor 'semisynth' get no ds_type (NaN),\n",
    "    # unlike the SCSL table, which labels every non-collide dataset 'synth' -- confirm.\n",
    "    if 'collide' in filename:\n",
    "        temp['ds_type'] = 'collide'\n",
    "        if \"large\" in filename:\n",
    "            temp[\"corr\"] = float(dataset_name.split(\"_\")[-1])\n",
    "    if \"semisynth\" in filename:\n",
    "        temp['ds_type'] = 'synth'\n",
    "    if \"-\" in dataset_name:\n",
    "        temp['num_X'] = dataset_name.split(\"-\")[1]\n",
    "        temp['num_Y'] = dataset_name.split(\"-\")[2]\n",
    "    pc_frames.append(temp)\n",
    "\n",
    "if not pc_frames:\n",
    "    raise FileNotFoundError(\"no 'pcalgo' result pickles found in \" + resdir)\n",
    "results = pd.concat(pc_frames, ignore_index=True, axis=0)\n",
    "\n",
    "# Per-row ground truth and the numerators / denominators of the two rates.\n",
    "results['truth'] = results.apply(get_truth, axis = 1)\n",
    "results['numer_tpr'] = results['in_graph']*results['truth']\n",
    "results['numer_fpr'] = results['in_graph']*(~results['truth'])\n",
    "results['denom_fpr'] = ~results['truth']\n",
    "\n",
    "# Summing booleans counts them; dropna=False keeps groups whose key columns are NaN.\n",
    "group_keys = ['Dataset','Model Type','alpha','ds_type','num_X','num_Y','corr']\n",
    "pcalgo_agg = results.groupby(group_keys,dropna=False).sum()\n",
    "pcalgo_agg = pcalgo_agg.drop(['X_target','Y_target'],axis=1)\n",
    "pcalgo_agg['fpr'] = (pcalgo_agg['numer_fpr']/pcalgo_agg['denom_fpr'])\n",
    "pcalgo_agg['tpr'] = (pcalgo_agg['numer_tpr']/pcalgo_agg['truth'])\n",
    "pcalgo_agg = pcalgo_agg.reset_index()\n",
    "# The '*_stderr' columns hold twice the binomial standard error of the rate.\n",
    "pcalgo_agg['fpr_stderr'] =2*np.sqrt((pcalgo_agg['fpr']*(1-pcalgo_agg['fpr']))/pcalgo_agg['denom_fpr'])\n",
    "pcalgo_agg['tpr_stderr'] =2*np.sqrt((pcalgo_agg['tpr']*(1-pcalgo_agg['tpr']))/pcalgo_agg['truth'])\n",
    "pcalgo_agg['num_X'] = pcalgo_agg['num_X'].astype(float)\n",
    "pcalgo_agg['num_Y'] = pcalgo_agg['num_Y'].astype(float)\n",
    "pcalgo_agg.to_pickle(resdir + \"pcalgo_combined.pkl\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60972f91",
   "metadata": {},
   "source": [
    "# Recreate Experimental Figures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8c4589b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure 4\n",
    "logit_comb, logit_rollup = rollup_ds('synth_med','logit')\n",
    "\n",
    "# Keep the rows that carry a cumulative rank and average it per (type, iteration, model type).\n",
    "ranked = logit_comb[~logit_comb['p_rank_cum'].isnull()]\n",
    "# The column is selected before aggregating: pandas >= 2.0 raises if .mean() meets non-numeric columns.\n",
    "ds_iterview = ranked.groupby([\"Type\",\"iter\",\"Model Type\"])[['p_rank_cum']].mean().reset_index()\n",
    "\n",
    "# (value of 'Type', legend entry, line style)\n",
    "curves = [(\"Exhaustive\", \"Naive\", \"solid\"), (\"GSO\", \"GSO\", \"dashed\"), (\"Hybrid\", \"Hybrid\", \"dotted\")]\n",
    "for search_type, legend, style in curves:\n",
    "    curve = ds_iterview[ds_iterview['Type'] == search_type]\n",
    "    plt.plot(curve['iter'],np.log(curve['p_rank_cum']),label=legend,linestyle=style)\n",
    "plt.xlabel(\"Iteration\")\n",
    "plt.ylabel(\"Log-rank of test statistic\")\n",
    "plt.legend()\n",
    "plt.xlim([-50, 1000])\n",
    "plt.savefig(\"iter.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39f1b0ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the two aggregates written by the compile cell; later cells reuse `scsl` and `pcalgo`.\n",
    "scsl = pd.read_pickle(resdir + \"scsl_combined.pkl\")\n",
    "pcalgo = pd.read_pickle(resdir + \"pcalgo_combined.pkl\")\n",
    "model_type = \"logit\"\n",
    "ds = 'synth_med'\n",
    "\n",
    "scsl_subset = scsl[(scsl['dataset'] == ds)&(scsl['model_type'] == model_type)]\n",
    "pc_subset = pcalgo[(pcalgo['Dataset'] == ds)&(pcalgo['Model Type'] == model_type)]\n",
    "\n",
    "# TPR (solid) and FPR (dashed) against the p-value threshold; the black diagonal is y = x.\n",
    "plt.plot(scsl_subset['thresh'],scsl_subset['tpr'],label=\"SCSL (TPR)\",color=\"blue\",linestyle = \"-\",marker= 'v',markevery=5)\n",
    "plt.plot(pc_subset['alpha'],pc_subset['tpr'],color=\"brown\",marker='x',label=\"PC-p (TPR)\")\n",
    "plt.plot(scsl_subset['thresh'],scsl_subset['fpr'],label=\"SCSL (FPR)\",color=\"blue\",linestyle = \"dashed\",marker= 'v',markevery=5)\n",
    "plt.plot(pc_subset['alpha'],pc_subset['fpr'],color=\"brown\",marker='x',linestyle = \"dashed\",label=\"PC-p (FPR)\")\n",
    "plt.plot(pc_subset['alpha'],pc_subset['alpha'],color='black',linestyle='dashed')\n",
    "plt.xlim([0,0.3])\n",
    "plt.legend()\n",
    "plt.xlabel(\"P-value threshold\")\n",
    "plt.show()\n",
    "\n",
    "# Achieved FPR relative to the nominal level, on a log threshold axis.\n",
    "# Threshold 0 is dropped first: log(0) is -inf and fpr / 0 is undefined.\n",
    "scsl_pos = scsl_subset[scsl_subset['thresh'] > 0]\n",
    "pc_pos = pc_subset[pc_subset['alpha'] > 0]\n",
    "plt.plot(np.log(scsl_pos['thresh']),scsl_pos['fpr']/scsl_pos['thresh'],label=\"SCSL\",color=\"blue\",linestyle = \"dashed\",marker= 'v',markevery=5)\n",
    "plt.plot(np.log(pc_pos['alpha']),pc_pos['fpr']/pc_pos['alpha'],color=\"brown\",marker='x',label=\"PC-p\")\n",
    "plt.axhline(y= 1,color=\"black\",linestyle=\"dashed\")\n",
    "plt.xlabel(\"Log of p-value Threshold\")\n",
    "plt.ylabel(\"Actual FPR / Target\")\n",
    "plt.xlim([-5,-1])\n",
    "# The curves are labelled, so show the legend and render the figure explicitly.\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cae60042",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (dataset name, figure title) pairs; one figure is drawn per pair.\n",
    "ds_list = [\n",
    "['semisynth-5-5',\"Algorithm 2 (5, 5)\"],   \n",
    "['semisynth-10-10',\"Algorithm 2 (10, 10)\"],   \n",
    "['semisynth-15-15',\"Algorithm 2 (15, 15)\"],   \n",
    "['synth_med',\"Algorithm 2 (47, 12)\"],\n",
    "['synth',\"Algorithm 2 (47, 23)\"],\n",
    "['synth_collide_large_0.2',\"Algorithm 3 (0.2)\"],\n",
    "['synth_collide_large_0.4',\"Algorithm 3 (0.4)\"],\n",
    "['synth_collide_large_0.6',\"Algorithm 3 (0.6)\"],\n",
    "['synth_collide_large_0.8',\"Algorithm 3 (0.8)\"]]\n",
    "\n",
    "# Uses `scsl` and `pcalgo` loaded in the previous cell; logit models only.\n",
    "for datas in ds_list:\n",
    "    ds = datas[0]\n",
    "    label = datas[1]\n",
    "    res_subset = scsl[(scsl[\"dataset\"]==ds) & (scsl[\"model_type\"] == 'logit')]\n",
    "    pc_subset = pcalgo[(pcalgo[\"Dataset\"]==ds) & (pcalgo[\"Model Type\"] == 'logit')]\n",
    "    plt.plot(res_subset['thresh'],res_subset['tpr'],color=\"blue\",label=\"Proposal (TPR)\")\n",
    "    plt.plot(res_subset['thresh'],res_subset['fpr'],color=\"blue\",label=\"Proposal (FPR)\",linestyle =\"dashed\")\n",
    "    plt.plot(pc_subset['alpha'],pc_subset['tpr'],color=\"red\",label=\"PC w/GCM (TPR)\")\n",
    "    plt.plot(pc_subset['alpha'],pc_subset['fpr'],color=\"red\",linestyle = \"dashed\",label=\"PC w/GCM (FPR)\")\n",
    "    # Reference diagonal y = x.\n",
    "    plt.plot([0.005,0.5],[0.005,0.5],color='black',linestyle='dashed')\n",
    "    plt.xlim([0.005,0.3])\n",
    "    plt.legend()\n",
    "    plt.title(label)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "102b9f0b",
   "metadata": {},
   "source": [
    "# Recreate Benchmarking Figures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9c1fd92",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Benchmark methods and datasets whose '<dataset>_<method>_benchmarks.pkl' files are collected.\n",
    "method_list = ['fges',\n",
    "'boss',\n",
    "'sp',\n",
    "'grasp',\n",
    "'pc',\n",
    "'fci',\n",
    "'gfci',\n",
    "'bfci',\n",
    "'grasp-fci',\n",
    "'ccd']\n",
    "\n",
    "\n",
    "ds_list = ['collidesynth-5-5',\n",
    "'collidesynth-10-10',\n",
    "'collidesynth-15-15',\n",
    "'collidesynth-20-20',\n",
    "'semisynth-5-5',\n",
    "'semisynth-10-10',\n",
    "'semisynth-15-15',\n",
    "'semisynth-20-20',\n",
    "'synth_med',\n",
    "'synth',\n",
    "'synth_collide_large_0.0',\n",
    "'synth_collide_large_0.1',\n",
    "'synth_collide_large_0.2',\n",
    "'synth_collide_large_0.3',\n",
    "'synth_collide_large_0.4',\n",
    "'synth_collide_large_0.5',\n",
    "'synth_collide_large_0.6',\n",
    "'synth_collide_large_0.7',\n",
    "'synth_collide_large_0.8',\n",
    "'synth_collide_large_0.9',\n",
    "'synth_collide_cont_0.0',\n",
    "'synth_collide_cont_0.1',\n",
    "'synth_collide_cont_0.2',\n",
    "'synth_collide_cont_0.3',\n",
    "'synth_collide_cont_0.4',\n",
    "'synth_collide_cont_0.5',\n",
    "'synth_collide_cont_0.6',\n",
    "'synth_collide_cont_0.7',\n",
    "'synth_collide_cont_0.8',\n",
    "'synth_collide_cont_0.9']\n",
    "\n",
    "# Read whichever benchmark files exist; absent (dataset, method) pairs are skipped.\n",
    "benchmark_frames = []\n",
    "for ds in ds_list:\n",
    "    for method in method_list:\n",
    "        filepath = resdir + ds +'_'+ method +'_benchmarks.pkl'\n",
    "        if os.path.exists(filepath):\n",
    "            benchmark_frames.append(pd.read_pickle(filepath))\n",
    "\n",
    "# Fail loudly instead of silently reusing a `results` frame left over from an earlier cell.\n",
    "if not benchmark_frames:\n",
    "    raise FileNotFoundError(\"no '*_benchmarks.pkl' files found in \" + resdir)\n",
    "results = pd.concat(benchmark_frames,axis=0)\n",
    "\n",
    "# Confusion-matrix indicators per candidate edge (a non-zero 'edge' means the method reported it).\n",
    "reported = results['edge'] != 0\n",
    "not_reported = results['edge'] == 0\n",
    "results['tp'] = reported & results['is_edge']\n",
    "results['fn'] = not_reported & results['is_edge']\n",
    "results['tn'] = not_reported & ~results['is_edge']\n",
    "results['fp'] = reported & ~results['is_edge']\n",
    "\n",
    "# Summing the indicators gives counts per (dataset, method).\n",
    "results_compiled = results.groupby(by=['dataset','method']).sum()\n",
    "results_compiled['fpr'] = results_compiled['fp']/(results_compiled['tn']+ results_compiled['fp'])\n",
    "results_compiled['tpr'] = results_compiled['tp']/(results_compiled['tp']+ results_compiled['fn'])\n",
    "benchmarks_compiled = results_compiled.reset_index() \n",
    "\n",
    "# Save into the current working directory (overrides the figdir set in the first cell).\n",
    "figdir = ''\n",
    "\n",
    "# NOTE(review): `res_large` is not defined anywhere in this notebook, so this line raises\n",
    "# NameError on a fresh kernel. It presumably derives from `benchmarks_compiled` plus a\n",
    "# 'num_nodes' column -- the derivation needs to be restored before this cell can run.\n",
    "res_large = res_large.sort_values(by = ['num_nodes'])\n",
    "metric = 'fpr'\n",
    "\n",
    "# (method key, legend entry, marker), drawn in this order.\n",
    "series = [\n",
    "    ('gfci', \"GFCI\", 'v'),\n",
    "    ('fges', \"FGES\", 's'),\n",
    "    ('grasp-fci', \"GRaSP-FCI\", '.'),\n",
    "    ('grasp', \"GRaSP\", 'o'),\n",
    "    ('boss', \"BOSS\", '^'),\n",
    "    ('pc', \"PC\", '>'),\n",
    "]\n",
    "for label, legend, marker in series:\n",
    "    rows = res_large[res_large['method'] == label]\n",
    "    plt.plot(rows['num_nodes'],rows[metric],label=legend,marker=marker)\n",
    "#plt.legend()\n",
    "plt.xlabel(\"Number of Nodes\")\n",
    "plt.ylabel(\"False Positive Rate\")\n",
    "plt.ylim(top=0.1)\n",
    "plt.savefig(figdir  + metric + \"_benchmarks.pdf\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hrt-funs",
   "language": "python",
   "name": "hrt-funs"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
