{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f66dbff0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import numpy.random as rn\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import scipy.stats as st\n",
    "import csv\n",
    "import sys\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "import warnings\n",
    "import itertools\n",
    "\n",
    "datadir = 'data/'\n",
    "figdir = 'figures/'\n",
    "\n",
    "#Algorithm 2 (with discrete responses)\n",
    "def collider_synth(X,max_feats=30,K=30,k=5,num_Y=15,prob_collide = 0.5, sd_X =1, mu_X = 1, sd_Y=0.5,mu_Y=5, intercept = -6):    \n",
    "    nonzeros = np.where(X.sum(0) > 400)[0]\n",
    "    X = X[:, nonzeros]\n",
    "    np.random.shuffle(X)\n",
    "    top_feats = np.argsort(X.sum(0))[-max_feats:]  \n",
    "    sig_feats = top_feats[np.random.choice(range(max_feats), K, replace=False)]\n",
    "\n",
    "    # k = 10  # number of significant features for each Y_j\n",
    "    x_feats = []\n",
    "    y_feats = []\n",
    "    y_coefs = []\n",
    "    Y = np.zeros((X.shape[0],num_Y))\n",
    "    for j in range(num_Y):\n",
    "        x_feats.append(sig_feats[np.random.choice(range(K), k, replace=False)])\n",
    "        y_depend = np.random.binomial(1,prob_collide,j)\n",
    "        y_feats.append(np.nonzero(y_depend))\n",
    "        y_coefs.append(np.random.randn(len(y_feats[j][0]))*sd_Y + mu_Y)\n",
    "\n",
    "\n",
    "    x_coefs = (np.random.randn(k*num_Y)*sd_X + mu_X).reshape(num_Y,k)\n",
    "\n",
    "\n",
    "    for j in range(num_Y):\n",
    "        logodds = intercept + np.matmul(X[:,x_feats[j]] , x_coefs[j].T)\n",
    "        if(y_coefs[j].shape[0] > 0 ):\n",
    "            logodds = logodds + np.matmul(Y[:,y_feats[j][0]] , y_coefs[j].T)\n",
    "        like = 1 / (1 + np.exp(-logodds))\n",
    "        Y[:,j] = np.random.binomial(1,like)\n",
    "    return X, Y, x_feats, y_feats\n",
    "\n",
    "#Algorithm 2 (with continuous responses)\n",
    "def collider_cont(X,max_feats=30,K=30,k=5,num_Y=15,prob_collide = 0.5, sd_X =1, mu_X = 1, sd_Y=0.5,mu_Y=1, intercept = 0,sd=1):    \n",
    "    nonzeros = np.where(X.sum(0) > 400)[0]\n",
    "    X = X[:, nonzeros]\n",
    "    np.random.shuffle(X)\n",
    "    top_feats = np.argsort(X.sum(0))[-max_feats:]  \n",
    "    sig_feats = top_feats[np.random.choice(range(max_feats), K, replace=False)]\n",
    "\n",
    "    # k = 10  # number of significant features for each Y_j\n",
    "    x_feats = []\n",
    "    y_feats = []\n",
    "    y_coefs = []\n",
    "    Y = np.zeros((X.shape[0],num_Y))\n",
    "    Y_response = np.zeros((X.shape[0],num_Y))\n",
    "    for j in range(num_Y):\n",
    "        x_feats.append(sig_feats[np.random.choice(range(K), k, replace=False)])\n",
    "        y_depend = np.random.binomial(1,prob_collide,j)\n",
    "        y_feats.append(np.nonzero(y_depend))\n",
    "        y_coefs.append(np.random.randn(len(y_feats[j][0]))*sd_Y + mu_Y)\n",
    "\n",
    "\n",
    "    x_coefs = (np.random.randn(k*num_Y)*sd_X + mu_X).reshape(num_Y,k)\n",
    "\n",
    "\n",
    "    for j in range(num_Y):\n",
    "        response = intercept + np.matmul(X[:,x_feats[j]] , x_coefs[j].T)\n",
    "        if(y_coefs[j].shape[0] > 0 ):\n",
    "            response = response + np.matmul(Y[:,y_feats[j][0]] , y_coefs[j].T)\n",
    "        Y_response[:,j] = response\n",
    "        Y[:,j] = np.random.normal(response,sd)\n",
    "    return X, Y,Y_response, x_feats, y_feats\n",
    "\n",
    "#Algorithm 3 (real-world confoduning)\n",
    "def semi_synth(X, Y, max_feats, K=30, k=10,truncate=True):\n",
    "    if(truncate == True):\n",
    "        nonzeros = np.where(X.sum(0) > 400)[0]\n",
    "        X = X[:, nonzeros]\n",
    "    np.random.shuffle(X)\n",
    "    np.random.shuffle(Y)\n",
    "    top_feats = np.argsort(X.sum(0))[-max_feats:]  \n",
    "    sig_feats = top_feats[np.random.choice(range(max_feats), K, replace=False)]\n",
    "\n",
    "    # k = 10  # number of significant features for each Y_j\n",
    "    y_feats = []\n",
    "    for j in range(Y.shape[1]):\n",
    "        y_feats.append(sig_feats[np.random.choice(range(K), k, replace=False)])\n",
    "    print(y_feats)\n",
    "    y_coefs = abs(np.random.randn(X.shape[1]*Y.shape[1]).reshape(Y.shape[1], X.shape[1]))+2\n",
    "    coef_mask = np.zeros(y_coefs.shape)\n",
    "    for j in range(len(y_feats)):\n",
    "        coef_mask[j, y_feats[j]] = 1\n",
    "\n",
    "    y_coefs = y_coefs * coef_mask  # zero out null features\n",
    "    like = 1 / (1 + np.exp(-np.matmul(X, y_coefs.T)))\n",
    "    Y_tilde = np.zeros_like(Y)\n",
    "\n",
    "    for i in range(X.shape[0]):\n",
    "        like_i = like[i:i+1]\n",
    "        prob_y = like_i * Y + (1-like_i) * (1 - Y)\n",
    "        prob_y = prob_y.sum(1)\n",
    "        prob_y = prob_y / prob_y.sum()\n",
    "        Y_tilde[i, :] = Y[np.random.choice(range(Y.shape[0]), p=prob_y)]\n",
    "    return X, Y_tilde, y_feats"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d34ae19",
   "metadata": {},
   "source": [
    "# Merge genes and metastases together\n",
    "Two versions of the dataset are maintained. \n",
    "- The \"actual\" dataset merges together the genes and metastases together without including primary site information. As these sites are useful from a predictive modelling perspective, these views are mostly only useful as a base for constructing semi-synthetic data. \n",
    "- The \"actualsmall\" dataset includes primary site information and is mostly useful when extracting causal $p$-values on the actual dataset that have biologically interesting implications. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "647955a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "#full dataset\n",
    "genes = pd.read_csv(datadir + 'missing_gene_alteration_matrix.csv',index_col=0)\n",
    "metastases = pd.read_csv(datadir + 'metastasis_bysamples.csv',index_col=0)\n",
    "ds = genes.merge(metastases,on=\"uid\",how=\"left\")\n",
    "ds.to_pickle(datadir + 'merged_dataset.pkl')\n",
    "nomissing = ds.dropna(axis='columns')\n",
    "ds = nomissing.to_numpy()\n",
    "X = nomissing.iloc[:,:331].to_numpy()\n",
    "Y = nomissing.iloc[:,332:].to_numpy()\n",
    "np.save(datadir + \"actual_x\",X)\n",
    "np.save(datadir + \"actual_y\",Y)\n",
    "\n",
    "#reduced dataset\n",
    "primary_site = pd.read_csv(datadir + 'primary_sites.csv')\n",
    "counts = primary_site.reset_index().groupby([\"PRIMARY_SITE\"]).count()['index'].reset_index()\n",
    "counts = counts[counts['index'] > 200]\n",
    "primary_site = primary_site.merge(counts, how='left',on=\"PRIMARY_SITE\")\n",
    "primary_site['include'] = ~primary_site['index'].isnull()\n",
    "primary_site.index = genes.index\n",
    "primary_site = primary_site[['PRIMARY_SITE','include']]\n",
    "ds = genes.merge(primary_site,on=\"uid\",how=\"left\")\n",
    "ds =ds.merge(metastases,on=\"uid\",how=\"left\")\n",
    "ds = ds.dropna(axis='columns')\n",
    "select = np.where(ds.iloc[:,333])[0]\n",
    "X = ds.iloc[:,:332]\n",
    "X = X.iloc[:,np.where(X.sum(0) > 400)[0]]\n",
    "X = X.iloc[select,:]\n",
    "prim = pd.get_dummies(ds.iloc[select,332])\n",
    "Y = ds.iloc[select,334:]\n",
    "np.save(datadir + \"actualsmall_x\",np.concatenate((X.to_numpy(),prim.iloc[:,1:].to_numpy()),axis=1))\n",
    "np.save(datadir + \"actualsmall_y\",Y)\n",
    "ds.iloc[select,332].to_pickle(datadir + \"actualsmall_primary.pkl\")  "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "217ac435",
   "metadata": {},
   "source": [
    "# Create semi-synthetic datasets\n",
    "The naming convention is as follows: \n",
    "* semisynth-numX-numY refers to discrete datasets using real-world confondunding where num_X refers to the numbers of nodes used within the X dataset and num_Y refers to the number of nodes Y dataset. \n",
    "* collidesynth-numX-numY refers to discrete datasets using synthetic confounding where num_X refers to the numbers of nodes used within the X dataset and num_Y refers to the number of nodes Y dataset. Confoundedness parameter set at 0.5. \n",
    "* synth_collide_large_p refers to discrete datasets using synthetic confounding where $p$ refers to the confoundedness parameter. num_X set to 27 and num_y set to 15. \n",
    "* synth_collide_cont_p refers to continuous datasets using synthetic confounding where $p$ refers to the confoundedness parameter. num_X set to 27 and num_y set to 15. \n",
    "\n",
    ". "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db567c99",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "X = ds.iloc[:,:332]\n",
    "X = X.iloc[:,np.where(X.sum(0) > 750)[0]].to_numpy()\n",
    "Y = ds.iloc[:,334:].to_numpy()\n",
    "\n",
    "synth_x, synth_y, synth_true = semi_synth(X, Y, 20, K=20, k=5)\n",
    "np.save(datadir + \"semisynth-20-20_x\",synth_x)\n",
    "np.save(datadir + \"semisynth-20-20_y\",synth_y)\n",
    "np.save(datadir + \"semisynth-20-20_ytrue\",synth_true)\n",
    "\n",
    "synth_x, synth_y, synth_true = semi_synth(X[:,:15], Y[:,:15], 15, K=15, k=4)\n",
    "np.save(datadir + \"semisynth-15-15_x\",synth_x)\n",
    "np.save(datadir + \"semisynth-15-15_y\",synth_y)\n",
    "np.save(datadir + \"semisynth-15-15_ytrue\",synth_true)\n",
    "\n",
    "synth_x, synth_y, synth_true = semi_synth(X[:,:10], Y[:,:10], 10, K=10, k=3)\n",
    "np.save(datadir + \"semisynth-10-10_x\",synth_x)\n",
    "np.save(datadir + \"semisynth-10-10_y\",synth_y)\n",
    "np.save(datadir + \"semisynth-10-10_ytrue\",synth_true)\n",
    "\n",
    "synth_x, synth_y, synth_true = semi_synth(X[:,:5], Y[:,:5], 5, K=5, k=2)\n",
    "np.save(datadir + \"semisynth-5-5_x\",synth_x)\n",
    "np.save(datadir + \"semisynth-5-5_y\",synth_y)\n",
    "np.save(datadir + \"semisynth-5-5_ytrue\",synth_true)\n",
    "\n",
    "collide_synth_x, collide_synth_y, collide_synth_xtrue, collide_synth_ytrue = collider_synth(X[:,:5],prob_collide = 0.3,max_feats=5,K=5,k=2,num_Y=5)\n",
    "np.save(datadir + \"collidesynth-5-5_x\",collide_synth_x)\n",
    "np.save(datadir + \"collidesynth-5-5_y\",collide_synth_y)\n",
    "np.save(datadir + \"collidesynth-5-5_xtrue\",collide_synth_xtrue)\n",
    "np.save(datadir + \"collidesynth-5-5_ytrue\",collide_synth_ytrue)\n",
    "\n",
    "collide_synth_x, collide_synth_y, collide_synth_xtrue, collide_synth_ytrue = collider_synth(X[:,:10],prob_collide = 0.3,max_feats=10,K=10,k=3,num_Y=10)\n",
    "np.save(datadir + \"collidesynth-10-10_x\",collide_synth_x)\n",
    "np.save(datadir + \"collidesynth-10-10_y\",collide_synth_y)\n",
    "np.save(datadir + \"collidesynth-10-10_xtrue\",collide_synth_xtrue)\n",
    "np.save(datadir + \"collidesynth-10-10_ytrue\",collide_synth_ytrue)\n",
    "\n",
    "\n",
    "collide_synth_x, collide_synth_y, collide_synth_xtrue, collide_synth_ytrue = collider_synth(X[:,:15],prob_collide = 0.3,max_feats=15,K=15,k=4,num_Y=15)\n",
    "np.save(datadir + \"collidesynth-15-15_x\",collide_synth_x)\n",
    "np.save(datadir + \"collidesynth-15-15_y\",collide_synth_y)\n",
    "np.save(datadir + \"collidesynth-15-15_xtrue\",collide_synth_xtrue)\n",
    "np.save(datadir + \"collidesynth-15-15_ytrue\",collide_synth_ytrue)\n",
    "\n",
    "collide_synth_x, collide_synth_y, collide_synth_xtrue, collide_synth_ytrue = collider_synth(X[:,:20],prob_collide = 0.3,max_feats=20,K=20,k=5,num_Y=20)\n",
    "np.save(datadir + \"collidesynth-20-20_x\",collide_synth_x)\n",
    "np.save(datadir + \"collidesynth-20-20_y\",collide_synth_y)\n",
    "np.save(datadir + \"collidesynth-20-20_xtrue\",collide_synth_xtrue)\n",
    "np.save(datadir + \"collidesynth-20-20_ytrue\",collide_synth_ytrue)\n",
    "\n",
    "\n",
    "for i in range(0,10,1):\n",
    "    collide_synth_x, collide_synth_y, collide_synth_xtrue, collide_synth_ytrue = collider_synth(X,num_Y=15,prob_collide = i/10)\n",
    "    np.save(datadir + \"synth_collide_large_\" +str(i/10) + \"_x\",collide_synth_x)\n",
    "    np.save(datadir + \"synth_collide_large_\" +str(i/10)  + \"_y\",collide_synth_y)\n",
    "    np.save(datadir + \"synth_collide_large_\" +str(i/10) + \"_xtrue\",collide_synth_xtrue)\n",
    "    np.save(datadir + \"synth_collide_large_\" +str(i/10) + \"_ytrue\",collide_synth_ytrue)\n",
    "    np.savetxt(datadir + \"synth_collide_large_\" +str(i/10) + \"_x.csv\",collide_synth_x, delimiter=\",\")\n",
    "    np.savetxt(datadir + \"synth_collide_large_\" +str(i/10) + \"_y.csv\",collide_synth_y, delimiter=\",\")\n",
    "    np.savetxt(datadir + \"synth_collide_large_\" +str(i/10) + \"_xtrue.csv\",collide_synth_xtrue, delimiter=\",\")\n",
    "    \n",
    "\n",
    "\n",
    "for i in range(0,10,1):\n",
    "    collide_synth_x, collide_synth_y, collide_synth_response, collide_synth_xtrue, collide_synth_ytrue = collider_cont(X,num_Y=15,prob_collide = i/10)\n",
    "    np.save(datadir + \"synth_collide_cont_\" +str(i/10) + \"_x\",collide_synth_x)\n",
    "    np.save(datadir + \"synth_collide_cont_\" +str(i/10)  + \"_y\",collide_synth_y)\n",
    "    np.save(datadir + \"synth_collide_cont_\" +str(i/10) + \"_xtrue\",collide_synth_xtrue)\n",
    "    np.save(datadir + \"synth_collide_cont_\" +str(i/10) + \"_ytrue\",collide_synth_ytrue)\n",
    "    np.savetxt(datadir + \"synth_collide_cont_\" +str(i/10) + \"_x.csv\",collide_synth_x, delimiter=\",\")\n",
    "    np.savetxt(datadir + \"synth_collide_cont_\" +str(i/10) + \"_y.csv\",collide_synth_y, delimiter=\",\")\n",
    "    np.savetxt(datadir + \"synth_collide_cont_\" +str(i/10) + \"_xtrue.csv\",collide_synth_xtrue, delimiter=\",\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
