{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72b7f0f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from odeformer.DataGenKinetics import *\n",
    "from odeformer.model.embedders import LinearPointEmbedder\n",
    "from odeformer.model.transformer import TransformerModel\n",
    "from odeformer.envs.encoders import FloatSequences\n",
    "from odeformer.trainer import Trainer, autocast_wrapper\n",
    "from odeformer.envs.environment import FunctionEnvironment\n",
    "from odeformer.envs.encoders import Equation\n",
    "from evaluate import Evaluator, setup_odeformer\n",
    "from odeformer.model.model_wrapper import ModelWrapper\n",
    "from odeformer.utils import to_cuda\n",
    "from odeformer.model.model_eval import SymbolicTransformerRegressor\n",
    "from odeformer.metrics import r2_score\n",
    "from odeformer.metrics import compute_metrics\n",
    "from odeformer.envs.generators import NodeList\n",
    "\n",
    "\n",
    "from addict import Dict\n",
    "import json\n",
    "from logging import getLogger\n",
    "import torch\n",
    "import tqdm\n",
    "import os\n",
    "import math\n",
    "\n",
    "import networkx as nx\n",
    "import pickle\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "from typing_extensions import Literal\n",
    "from typing import List, Union\n",
    "from collections import defaultdict\n",
    "from sklearn.metrics import r2_score, mean_squared_error\n",
    "\n",
    "import re"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a127308",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "706f4df0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7489c7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "logger = getLogger()\n",
    "with open(\"params.pickle\",\"rb\") as fin:\n",
    "    params = pickle.load(fin)\n",
    "params.max_epoch = 500\n",
    "params.n_steps_per_epoch = 400\n",
    "params.batch_size = 250\n",
    "params.batch_size_eval = 32\n",
    "params.max_dimension = 6\n",
    "params.dump_path = './experiments/debug/v11'\n",
    "params.rescale = False\n",
    "env = FunctionEnvironment(params)\n",
    "modules = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ecbfc364",
   "metadata": {},
   "outputs": [],
   "source": [
    "ckpt_dir = \"<YOUR CHECKPOINT DIR>\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e6594d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "dstr = torch.load(ckpt_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "237affe2",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_args = {'beam_size':1,'beam_temperature':0.1}\n",
    "dstr.set_model_args(model_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dda4a31e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_data():\n",
    "    dataframes = []\n",
    "    for i in range(1, 5):\n",
    "        filename = f\"<YOUR DATA DIR>\"\n",
    "        with open(filename, 'rb') as f:\n",
    "            df = pickle.load(f)\n",
    "            dataframes.append(df)\n",
    "    data = pd.concat(dataframes, ignore_index=True)\n",
    "    \n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a550ec62",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataframe = read_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7cf10c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def data_generate(i,dataframes):\n",
    "    batch = {}\n",
    "    batch[\"times\"] = []\n",
    "    batch[\"trajectory\"] = []\n",
    "    batch[\"trajectory_clean\"] = []\n",
    "    batch[\"tree_encoded\"] = []\n",
    "    batch[\"infos\"] = {}\n",
    "    batch[\"infos\"][\"dimension\"] = []\n",
    "    batch[\"infos\"][\"class\"] = []\n",
    "    sample = dataframes.iloc[i:i+1]\n",
    "    batch[\"tree_encoded\"] = sample['tree_encoded'].tolist()\n",
    "    for i in range(1):\n",
    "        time = sample['time'].iloc[0][0]\n",
    "        batch[\"times\"].append(time)\n",
    "        class_number = sample['class'].iloc[i]\n",
    "        substrate = sample[\"substrate\"].iloc[i]\n",
    "        product = sample['product'].iloc[i]\n",
    "        catalyst = sample['catalyst'].iloc[i]\n",
    "        other = sample['other'].iloc[i].transpose(1,0)\n",
    "        trajectory = np.concatenate((np.column_stack((substrate, product, catalyst)), other), axis=1)\n",
    "        dimension = trajectory.shape[1]\n",
    "        batch[\"trajectory\"].append(trajectory)\n",
    "        batch[\"infos\"][\"dimension\"].append(dimension)\n",
    "        batch[\"infos\"][\"class\"].append(class_number)\n",
    "    return batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcad91d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_children(G,node):\n",
    "    return list(G.predecessors(node))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "887c4c49",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_parent(G,node):\n",
    "    if list(G.successors(node)) != []:\n",
    "        return list(G.successors(node))[0]\n",
    "    else:\n",
    "        return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85ae66b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_numbers(s):\n",
    "    return int(re.findall(r'\\d+', s)[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76fa153f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71b05cd7",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MCTS_tree(object):\n",
    "    def __init__(self, dim, word2id):\n",
    "        self.dim = dim\n",
    "        self.tree = nx.DiGraph()\n",
    "        node_attr = {\"condition\":[],\"prompt\":[],\"ode\":[],\"eq\":None,\"prob\":1,\"Q\":0,\"N\":0,\"state\":\"root\", \"stop\":False}\n",
    "        self.tree.add_node(0,**node_attr)\n",
    "        self.root = 0\n",
    "        self.current_node = 0\n",
    "        self.max_node = 0\n",
    "        self.word2id = word2id\n",
    "        self.stop_idx = False\n",
    "    \n",
    "    def expand(self, data, dstr):\n",
    "        condition = self.tree.nodes[self.current_node][\"condition\"] + self.tree.nodes[self.current_node][\"prompt\"] + self.tree.nodes[self.current_node][\"ode\"]\n",
    "        if self.tree.nodes[self.current_node][\"ode\"] != []:\n",
    "            condition += [\"|\"]\n",
    "        prompts = [\"[dx{}]\".format(j) for j in range(self.dim) if \"[dx{}]\".format(j) not in condition]\n",
    "        # print(condition)\n",
    "        # print(prompts)\n",
    "        if len(prompts) == 1:\n",
    "            self.stop_idx = True\n",
    "            self.tree.nodes[self.current_node][\"stop\"] = True\n",
    "        else:\n",
    "            for j in range(len(prompts)):\n",
    "                try:\n",
    "                    a,b,lgts = dstr.fit(data[\"times\"][0], data[\"trajectory\"][0], condition = condition, prompt = prompts[j])\n",
    "                    node = a[0][0]\n",
    "                    tree = [item for item in b[0] if item not in ['<EOS>','<PAD>']]\n",
    "                    assert len(tree) == len(lgts[0][:-1])\n",
    "                    probs = [torch.nn.Softmax()(item[0]) for item in lgts[0][:-1]] \n",
    "                    scr = np.prod([probs[k][self.word2id[item]].item() for k,item in enumerate(tree)])\n",
    "                    scr *= torch.nn.Softmax()(lgts[0][-1][0])[73].item()\n",
    "                    node_attr = {\"condition\":condition,\"prompt\":[prompts[j]],\"ode\":tree,\"eq\":node,\"prob\":scr,\"Q\":0,\"N\":0,\"state\":\"leaf\",\"stop\":False}\n",
    "                    self.tree.add_node(self.max_node+j+1,**node_attr)\n",
    "                    self.tree.add_edge(self.max_node+j+1,self.current_node)\n",
    "                except:\n",
    "                    continue\n",
    "            if self.max_node == max(self.tree.nodes):\n",
    "                self.stop_idx = True\n",
    "                self.tree.nodes[self.current_node][\"state\"] = \"error\"\n",
    "            else:\n",
    "                self.tree.nodes[self.current_node][\"state\"] = \"nonleaf\"\n",
    "        self.max_node = max(self.tree.nodes)\n",
    "        \n",
    "    def select(self, c = 1.0):\n",
    "        Vs = get_children(self.tree,self.current_node)\n",
    "        if Vs != []:\n",
    "            Vs_0 = [n for n in Vs if self.tree.nodes[n][\"N\"] == 0]\n",
    "            if Vs_0 != []:\n",
    "                selection = Vs_0[np.argmax([self.tree.nodes[n][\"prob\"] for n in Vs_0])]\n",
    "            else:\n",
    "                Q = np.array([self.tree.nodes[n][\"Q\"] for n in Vs])\n",
    "                N = np.array([self.tree.nodes[n][\"N\"] for n in Vs])\n",
    "                N_tot = self.tree.nodes[self.current_node][\"N\"]\n",
    "                P = np.array([self.tree.nodes[n][\"prob\"] for n in Vs])\n",
    "                self.ucb = Q/(N+0.01) + c * P * np.sqrt(N_tot)/(1+N)\n",
    "                selection = Vs[np.argmax(self.ucb)]\n",
    "            self.current_node = selection\n",
    "        else:\n",
    "            self.stop_idx = True\n",
    "            self.tree.nodes[self.current_node][\"stop\"] = True\n",
    "        \n",
    "    def simulate(self, data, dstr):\n",
    "        condition = self.tree.nodes[self.current_node][\"condition\"] + self.tree.nodes[self.current_node][\"prompt\"] + self.tree.nodes[self.current_node][\"ode\"]\n",
    "        if self.tree.nodes[self.current_node][\"ode\"] != []:\n",
    "            condition += [\"|\"]\n",
    "        prompts = [\"[dx{}]\".format(j) for j in range(self.dim) if \"[dx{}]\".format(j) not in condition]\n",
    "        \n",
    "        if prompts == []: \n",
    "            self.tree.nodes[self.current_node][\"stop\"] = True\n",
    "            self.stop_idx = True\n",
    "            nodes = {}\n",
    "            trees = {}\n",
    "            \n",
    "            nn = self.current_node\n",
    "            while True:\n",
    "                if nn == self.root:\n",
    "                    break\n",
    "                idx = extract_numbers(self.tree.nodes[nn][\"prompt\"][0])\n",
    "                nodes[idx] = self.tree.nodes[nn][\"eq\"]\n",
    "                trees[idx] = self.tree.nodes[nn][\"ode\"]\n",
    "                nn = get_parent(self.tree,nn)\n",
    "\n",
    "            try:\n",
    "                assert len(nodes) == self.dim\n",
    "                all_nodes = []\n",
    "                for ii in range(len(nodes)):\n",
    "                    for n in nodes[ii].nodes:\n",
    "                        all_nodes.append(n)\n",
    "                dstr.predictions[0][0] = NodeList(all_nodes)\n",
    "                pred_traj = dstr.predict(data[\"times\"][0], data[\"trajectory\"][0][0])\n",
    "                # dstr.print()\n",
    "                pred_tree_encoded = []\n",
    "                for ii in range(len(trees)):\n",
    "                    pred_tree_encoded += trees[ii] + [\"|\"]\n",
    "                pred_tree_encoded = pred_tree_encoded[:-1]\n",
    "                metric = compute_metrics(pred_traj,data[\"trajectory\"][0], metrics = \"r2\")\n",
    "                score1 = max(0,metric[\"r2\"][0])\n",
    "                r_M = []\n",
    "                for iii in range(pred_traj.shape[1]):\n",
    "                    rr = compute_metrics(pred_traj[:,iii:(iii+1)],data[\"trajectory\"][0][:,iii:(iii+1)], metrics = \"r2\")[\"r2\"][0]\n",
    "                    rr = max(rr, 0)\n",
    "                    r_M.append(rr)\n",
    "                score2 = min(r_M)\n",
    "                score = (score1+score2)/2\n",
    "            except:\n",
    "                score = 0\n",
    "            \n",
    "        else:\n",
    "            np.random.shuffle(prompts)\n",
    "            nodes = {}\n",
    "            trees = {}\n",
    "\n",
    "            for prompt in prompts:\n",
    "                try:\n",
    "                    a,b,lgts = dstr.fit(data[\"times\"][0], data[\"trajectory\"][0], condition = condition, prompt = prompt)\n",
    "                    idx = extract_numbers(prompt)\n",
    "                    nodes[idx] = a[0][0]\n",
    "                    trees[idx] = [item for item in b[0] if item not in ['<EOS>','<PAD>']]\n",
    "                    condition += [prompt]+trees[idx]+[\"|\"]\n",
    "                except:\n",
    "                    continue\n",
    "\n",
    "            nn = self.current_node\n",
    "            while True:\n",
    "                if nn == self.root:\n",
    "                    break\n",
    "                idx = extract_numbers(self.tree.nodes[nn][\"prompt\"][0])\n",
    "                nodes[idx] = self.tree.nodes[nn][\"eq\"]\n",
    "                trees[idx] = self.tree.nodes[nn][\"ode\"]\n",
    "                nn = get_parent(self.tree,nn)\n",
    "\n",
    "            try:\n",
    "                assert len(nodes) == self.dim\n",
    "                all_nodes = []\n",
    "                for ii in range(len(nodes)):\n",
    "                    for n in nodes[ii].nodes:\n",
    "                        all_nodes.append(n)\n",
    "                dstr.predictions[0][0] = NodeList(all_nodes)\n",
    "                pred_traj = dstr.predict(data[\"times\"][0], data[\"trajectory\"][0][0])\n",
    "                # dstr.print()\n",
    "                pred_tree_encoded = []\n",
    "                for ii in range(len(trees)):\n",
    "                    pred_tree_encoded += trees[ii] + [\"|\"]\n",
    "                pred_tree_encoded = pred_tree_encoded[:-1]\n",
    "                metric = compute_metrics(pred_traj,data[\"trajectory\"][0], metrics = \"r2\")\n",
    "                score1 = max(0,metric[\"r2\"][0])\n",
    "                r_M = []\n",
    "                for iii in range(pred_traj.shape[1]):\n",
    "                    rr = compute_metrics(pred_traj[:,iii:(iii+1)],data[\"trajectory\"][0][:,iii:(iii+1)], metrics = \"r2\")[\"r2\"][0]\n",
    "                    rr = max(rr, 0)\n",
    "                    r_M.append(rr)\n",
    "                score2 = min(r_M)\n",
    "                score = (score1+score2)/2\n",
    "            except:\n",
    "                self.tree.nodes[self.current_node][\"state\"] = \"error\"\n",
    "                score = 0\n",
    "            # print(self.current_node,score)\n",
    "        return score\n",
    "    \n",
    "    def backpropogate(self, score):\n",
    "        self.tree.nodes[self.current_node][\"N\"]+=1\n",
    "        self.tree.nodes[self.current_node][\"Q\"]+=score\n",
    "        nn = self.current_node\n",
    "        while True:\n",
    "            nn = get_parent(self.tree, nn)\n",
    "            self.tree.nodes[nn][\"N\"]+=1\n",
    "            self.tree.nodes[nn][\"Q\"]+=score\n",
    "            if nn == self.root:\n",
    "                break\n",
    "        self.reset()\n",
    "        \n",
    "    def reset(self):\n",
    "        self.current_node = self.root\n",
    "        self.stop_idx = False\n",
    "        \n",
    "    def get_graph(self):\n",
    "        return self.tree\n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ea2d3b0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "990ceb5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mcts_op(tree, data, dstr, c = 1.0):\n",
    "    if tree.current_node == tree.root:\n",
    "        tree.select(c = c)\n",
    "        mcts_op(tree, data, dstr, c = c)\n",
    "    else:\n",
    "        if get_children(tree.tree,tree.current_node) == []:\n",
    "            if tree.tree.nodes[tree.current_node][\"N\"] == 0:\n",
    "                score = tree.simulate(data, dstr)\n",
    "                tree.backpropogate(score)\n",
    "            else:\n",
    "                tree.expand(data,dstr)\n",
    "                tree.select(c = c)\n",
    "                score = tree.simulate(data, dstr)\n",
    "                tree.backpropogate(score)\n",
    "        else:\n",
    "            tree.select(c = c)\n",
    "            mcts_op(tree, data, dstr, c = c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7995135",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97768847",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mcts_ex(tree, data, dstr, n, c = 1.0):\n",
    "    tree.expand(data,dstr)\n",
    "    for i in tqdm.tqdm(range(n)):\n",
    "        mcts_op(tree, data, dstr, c = c)\n",
    "    return tree, tree.get_graph()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d07f9c1c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46209e4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "c = 1.0\n",
    "n = 100\n",
    "trees, Gs = [],[]\n",
    "for i in tqdm.tqdm(range(len(dataframe))):\n",
    "    data = data_generate(i,dataframe)\n",
    "    dim = data[\"trajectory\"][0].shape[1]\n",
    "    tree = MCTS_tree(dim,env.equation_word2id)\n",
    "    final_tree, G = mcts_ex(tree, data, dstr, n, c)\n",
    "    trees.append(final_tree)\n",
    "    Gs.append(G)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cd011f7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a2e1953",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
