{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generating Cardiogenesis Raw Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyboolnet.file_exchange import bnet2primes, primes2bnet\n",
    "from pyboolnet.prime_implicants import find_constants, create_variables\n",
    "from pyboolnet.repository import get_primes\n",
    "\n",
    "from pyboolnet.repository import get_primes\n",
    "from pyboolnet.state_transition_graphs import create_stg_image\n",
    "from pyboolnet.state_transition_graphs import energy, random_walk, add_style_path, add_style_anonymous, stg2image, \\\n",
    "    primes2stg\n",
    "from pyboolnet.state_transition_graphs import sccgraph2image\n",
    "from pyboolnet.state_transition_graphs import stg2sccgraph, stg2condensationgraph, best_first_reachability\n",
    "\n",
    "from pyboolnet.attractors import compute_attractors_tarjan, compute_attractors, find_attractor_state_by_randomwalk_and_ctl\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Boolean network definition\n",
    "\n",
    "## Each line of `grn` is one node rule in .bnet format: `<node>, <update function>`.\n",
    "## Note: Foxc1/2 changed to Foxc1_2 and Nkx2.5 changed to Nkx2_5 (presumably because '/'\n",
    "## and '.' are not accepted in node names).\n",
    "## The exogen_*_I nodes keep their own value (inputs); each feeds the matching exogen_*_II node.\n",
    "\n",
    "grn = \"\"\"\n",
    "    Bmp2,   (!canWnt & exogen_Bmp2_II)\n",
    "    canWnt,    exogen_canWnt_II\n",
    "    Dkk1,   (Mesp1 | (canWnt & !exogen_Bmp2_II))\n",
    "    Fgf8,   (!Mesp1 & (Foxc1_2 | Tbx1))\n",
    "    Foxc1_2,    (canWnt & exogen_canWnt_II)\n",
    "    GATAs,    (Nkx2_5 | Mesp1 | Tbx5)\n",
    "    Isl1,    (Tbx1 | Mesp1 | Fgf8 | (canWnt & exogen_canWnt_II))\n",
    "    Mesp1,    (canWnt & !exogen_Bmp2_II)\n",
    "    Nkx2_5,    ((Isl1 & GATAs) | Tbx1 | (Mesp1 & Dkk1) | (Bmp2 & GATAs) | Tbx5)\n",
    "    Tbx1,    Foxc1_2\n",
    "    Tbx5,    (!(Tbx1 | canWnt) & (Nkx2_5 | Tbx5 | Mesp1) & !(Dkk1 & !(Mesp1 | Tbx5)))\n",
    "    exogen_Bmp2_I,    exogen_Bmp2_I\n",
    "    exogen_Bmp2_II,    exogen_Bmp2_I\n",
    "    exogen_canWnt_I,    exogen_canWnt_I\n",
    "    exogen_canWnt_II,    exogen_canWnt_I\n",
    "    \"\"\"\n",
    "\n",
    "## Parse the rules into pyboolnet prime implicants, keyed by node name.\n",
    "bnet = bnet2primes(grn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One representative state per attractor of the asynchronous state transition graph,\n",
    "# as a bit string (downstream code assumes the bit order follows sorted(bnet)).\n",
    "\n",
    "attractor_info = compute_attractors(bnet, \"asynchronous\")\n",
    "attrs = []\n",
    "for attractor in attractor_info[\"attractors\"]:\n",
    "    attrs.append(attractor[\"state\"][\"str\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# generating trajectories using random walk policy\n",
    "\n",
    "def generate_trajectories_fixed_interval(primes, initial_state, length, interval=10):\n",
    "    \"\"\"Simulate one long perturb-and-relax trajectory on the asynchronous STG of `primes`.\n",
    "\n",
    "    Each of the `length` steps picks a node uniformly at random, flips its value in the\n",
    "    current state (the \"action\"), then runs `random_walk` for `interval` asynchronous\n",
    "    updates from the flipped state; the last state of that walk is the next recorded state.\n",
    "\n",
    "    Args:\n",
    "        primes: pyboolnet prime implicants; node order is `sorted(primes)`.\n",
    "        initial_state: bit string with one character per node, in `sorted(primes)` order.\n",
    "        length: number of perturbation steps to simulate.\n",
    "        interval: \"k\" from the paper (number of asynchronous updates per timestep).\n",
    "\n",
    "    Returns:\n",
    "        dict with keys\n",
    "          \"states\": length + 1 states (the initial one as a name -> int dict, the rest\n",
    "              as returned by `random_walk`),\n",
    "          \"states_readable\": the same states as bit strings,\n",
    "          \"actions\": flipped node names, followed by a trailing None,\n",
    "          \"actions_int\": flipped node indices, followed by a trailing None.\n",
    "        The trailing None makes the action lists as long as the state lists.\n",
    "    \"\"\"\n",
    "    # fixed node order, sorted once instead of on every step\n",
    "    node_names = sorted(primes)\n",
    "\n",
    "    initial_state_full = {}\n",
    "    for name, value in zip(node_names, initial_state):\n",
    "        if value.isdigit():\n",
    "            initial_state_full[name] = int(value)\n",
    "\n",
    "    states = [initial_state_full]\n",
    "    states_readable = [initial_state]\n",
    "    actions = []\n",
    "    actions_int = []\n",
    "    prev_state = initial_state\n",
    "\n",
    "    for _ in range(length):\n",
    "        action = np.random.randint(len(initial_state))\n",
    "        actions_int.append(action)\n",
    "        actions.append(node_names[action])\n",
    "\n",
    "        # flip the chosen node, then let the network evolve for `interval` asynchronous updates;\n",
    "        # use the `primes` argument rather than the notebook-global `bnet`\n",
    "        new_state = prev_state[:action] + str(1 - int(prev_state[action])) + prev_state[action + 1:]\n",
    "        path = random_walk(primes, \"asynchronous\", initial_state=new_state, length=interval)\n",
    "\n",
    "        # only the end of the walk is recorded; encode it in the same sorted-name order as the input\n",
    "        # (assumes random_walk yields state dicts keyed by node name -- TODO confirm)\n",
    "        last_state = path[-1]\n",
    "        last_readable = \"\".join(str(last_state[name]) for name in node_names)\n",
    "\n",
    "        states.append(last_state)\n",
    "        states_readable.append(last_readable)\n",
    "        prev_state = last_readable\n",
    "\n",
    "    # pad so the action lists line up with the state lists (no action after the final state)\n",
    "    actions_int.append(None)\n",
    "    actions.append(None)\n",
    "\n",
    "    return {\"states\": states, \"states_readable\": states_readable, \"actions\": actions, \"actions_int\": actions_int}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "# One long random walk per attractor: N_STEPS perturbations, each followed by\n",
    "# K_UPDATES asynchronous updates (the \"10\" in results_list_fixed_10).\n",
    "# Seed both RNGs for a reproducible dataset: node choices use numpy, while pyboolnet's\n",
    "# random_walk presumably draws from Python's `random` module (TODO confirm).\n",
    "SEED = 0\n",
    "N_STEPS = 10000\n",
    "K_UPDATES = 10\n",
    "\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "\n",
    "results_list_fixed_10 = [\n",
    "    generate_trajectories_fixed_interval(bnet, attractor_state, N_STEPS, K_UPDATES)\n",
    "    for attractor_state in attrs\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pool the visited states of every walk (one walk per attractor) instead of assuming exactly 6.\n",
    "states_list = [state for result in results_list_fixed_10 for state in result[\"states_readable\"]]\n",
    "states_pd = pd.Series(states_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1]),\n",
       " array([0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1]),\n",
       " array([1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0]),\n",
       " array([0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0]),\n",
       " array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),\n",
       " array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0]),\n",
       " array([0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0]),\n",
       " array([0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]),\n",
       " array([1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1]),\n",
       " array([0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1]),\n",
       " array([0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1]),\n",
       " array([0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),\n",
       " array([0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1]),\n",
       " array([0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0]),\n",
       " array([0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0]),\n",
       " array([0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1]),\n",
       " array([0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0]),\n",
       " array([1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1]),\n",
       " array([0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1]),\n",
       " array([0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1])]"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The 20 states visited most often across all random walks, as bit strings\n",
    "# (most_represented_readable) and as int vectors (most_represented_s), most frequent first.\n",
    "\n",
    "state_counts = states_pd.value_counts()\n",
    "most_represented_readable = state_counts.index[:20].tolist()\n",
    "most_represented_s = [np.array(list(map(int, bits))) for bits in most_represented_readable]\n",
    "most_represented_s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [],
   "source": [
    "# split random walk data into trajectories of set length\n",
    "#\n",
    "# Each chunk spans traj_len + 1 consecutive states, i.e. traj_len transitions\n",
    "# (s_t, a_t, s_t+1); neighbouring chunks share their boundary state.\n",
    "\n",
    "traj_len = 30\n",
    "traj_list = []\n",
    "for walk in results_list_fixed_10:\n",
    "    trajs_df = pd.DataFrame(walk)\n",
    "    # one-hot table sized by the number of nodes instead of a hard-coded 15\n",
    "    action_eye = np.eye(len(trajs_df[\"states_readable\"].iloc[0]))\n",
    "    # the last row of each walk has action None, so only len - 1 rows can start a\n",
    "    # transition; this keeps every chunk at exactly traj_len transitions\n",
    "    n_chunks = (len(trajs_df) - 1) // traj_len\n",
    "    for j in range(n_chunks):\n",
    "        start = j * traj_len\n",
    "        traj_df = trajs_df.iloc[start : start + traj_len + 1]\n",
    "        states_r = traj_df[\"states_readable\"].to_numpy()\n",
    "        states_s = np.array([[int(bit) for bit in state_r] for state_r in states_r])\n",
    "        action_ids = np.array([int(a) for a in traj_df[\"actions_int\"].iloc[:-1]])\n",
    "\n",
    "        traj_dict = {}\n",
    "        traj_dict[\"observations_r\"] = states_r[:-1]\n",
    "        # next states are shifted by one ([1:]); slicing [:-1] here made them equal to observations_r\n",
    "        traj_dict[\"next_observations_r\"] = states_r[1:]\n",
    "        traj_dict[\"observations\"] = states_s[:-1]\n",
    "        traj_dict[\"next_observations\"] = states_s[1:]\n",
    "        traj_dict[\"actions\"] = action_ids\n",
    "        traj_dict[\"actions_onehot\"] = action_eye[action_ids]\n",
    "        traj_dict['timesteps'] = traj_dict[\"observations\"].shape[0]\n",
    "        traj_dict['terminals'] = np.array([False] * (traj_dict['timesteps'] - 1) + [True])\n",
    "        traj_dict['rewards'] = np.zeros(traj_dict[\"actions\"].shape)\n",
    "        traj_dict['costs'] = np.zeros(traj_dict[\"actions\"].shape)\n",
    "        traj_list.append(traj_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "\n",
    "# Write dataset to radt dataset dir (created if missing so the cell also works on a fresh checkout)\n",
    "os.makedirs('dataset', exist_ok=True)\n",
    "with open('dataset/cardiogenesis_60000.pkl', 'wb') as handle:\n",
    "    pickle.dump(traj_list, handle)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hindsight Relabeling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "from copy import deepcopy\n",
    "\n",
    "# Reload the raw trajectories so this section can run without regenerating them.\n",
    "# The file is written by this notebook above; pickle.load is only safe on trusted files.\n",
    "with open('dataset/cardiogenesis_60000.pkl', 'rb') as raw_file:\n",
    "    data = pickle.load(raw_file)\n",
    "\n",
    "# independent copy: relabeling `data` in place leaves data_copy untouched\n",
    "data_copy = deepcopy(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hindsight relabeling: give each trajectory in `data` one avoid state.\n",
    "# The avoid state is sampled uniformly from the N_AVOID_CANDIDATES most represented states\n",
    "# (only the top 8 of the 20 computed above are used as candidates).\n",
    "# success = 1 iff that state never appears among the trajectory's observations.\n",
    "\n",
    "N_AVOID_CANDIDATES = 8\n",
    "assert N_AVOID_CANDIDATES <= len(most_represented_readable), \"not enough candidate states\"\n",
    "\n",
    "for traj in data:\n",
    "    random_avoid_i = np.random.choice(N_AVOID_CANDIDATES)\n",
    "    random_avoid_r = most_represented_readable[random_avoid_i]\n",
    "    random_avoid_s = most_represented_s[random_avoid_i]\n",
    "    # NOTE(review): only observations_r is checked, so an avoid state reached only as the final\n",
    "    # next observation still counts as success -- confirm this is intended.\n",
    "    traj[\"success\"] = 0 if random_avoid_r in traj[\"observations_r\"] else 1\n",
    "    # padding with an arbitrarily small buffer to create a \"box\" representation:\n",
    "    # [lower bounds..., upper bounds...]\n",
    "    traj[\"avoid_states\"] = np.array([np.concatenate([random_avoid_s - 0.001, random_avoid_s + 0.001])])\n",
    "\n",
    "# the deep copy is labelled as unconstrained: always successful, no avoid state\n",
    "for traj2 in data_copy:\n",
    "    traj2[\"success\"] = 1\n",
    "    traj2[\"avoid_states\"] = np.array([])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final dataset: relabeled trajectories followed by their unconstrained copies.\n",
    "# Not idempotent -- re-running this cell appends data_copy again.\n",
    "data = [*data, *data_copy]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# Export the relabeled dataset (avoid-annotated trajectories followed by unconstrained copies)\n",
    "with open('dataset/cardiogenesis_60000_avoid.pkl', 'wb') as handle:\n",
    "    pickle.dump(data, handle)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
