{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46d7d4b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from odeformer.DataGenKinetics import *\n",
    "from odeformer.model.embedders import LinearPointEmbedder\n",
    "from odeformer.model.transformer import TransformerModel\n",
    "from odeformer.envs.encoders import FloatSequences\n",
    "from odeformer.trainer import Trainer, autocast_wrapper\n",
    "from odeformer.envs.environment import FunctionEnvironment\n",
    "from odeformer.envs.encoders import Equation\n",
    "from evaluate import Evaluator, setup_odeformer\n",
    "from odeformer.model.model_wrapper import ModelWrapper\n",
    "from odeformer.utils import to_cuda\n",
    "from odeformer.model.model_eval import SymbolicTransformerRegressor\n",
    "from odeformer.metrics import r2_score\n",
    "from odeformer.metrics import compute_metrics\n",
    "from odeformer.envs.generators import NodeList\n",
    "\n",
    "\n",
    "from addict import Dict\n",
    "import json\n",
    "from logging import getLogger\n",
    "import torch\n",
    "import tqdm\n",
    "import os\n",
    "import math\n",
    "\n",
    "import networkx as nx\n",
    "import pickle\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "from typing_extensions import Literal\n",
    "from typing import List, Union\n",
    "from collections import defaultdict\n",
    "from sklearn.metrics import r2_score, mean_squared_error\n",
    "\n",
    "import re\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a15a2edb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85ac33e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_children(G, node):\n",
    "    \"\"\"Return the predecessors of ``node`` in ``G`` as a list.\n",
    "\n",
    "    The search tree in this notebook adds its edges child -> parent, so these are the node's children.\n",
    "    \"\"\"\n",
    "    children = [child for child in G.predecessors(node)]\n",
    "    return children\n",
    "\n",
    "def get_parent(G, node):\n",
    "    \"\"\"Return the first successor of ``node`` in ``G``, or ``None`` if it has no successor.\"\"\"\n",
    "    successors = list(G.successors(node))\n",
    "    return successors[0] if successors else None\n",
    "    \n",
    "def extract_numbers(s):\n",
    "    \"\"\"Return the first run of decimal digits in ``s`` as an ``int``.\n",
    "\n",
    "    Raises ``IndexError`` when ``s`` contains no digit.\n",
    "    \"\"\"\n",
    "    digit_runs = re.findall(r'\\d+', s)\n",
    "    first_run = digit_runs[0]\n",
    "    return int(first_run)\n",
    "\n",
    "def combine(a,b,c):\n",
    "    \"\"\"Fold a (sign, mantissa, exponent) token triple into one number.\n",
    "\n",
    "    ``a`` is the sign token; anything other than \"+\" is treated as negative.\n",
    "    ``b`` and ``c`` each carry an integer after a one-character prefix that is stripped before parsing.\n",
    "    Returns ``sign * int(b[1:]) * 10 ** int(c[1:])``.\n",
    "    \"\"\"\n",
    "    sign = 1 if a == \"+\" else -1\n",
    "    mantissa = int(b[1:])\n",
    "    exponent = int(c[1:])\n",
    "    return sign * mantissa * 10 ** exponent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa2158b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MCTS_tree(object):\n",
    "    \"\"\"Search tree over the order in which the ``[dx*]`` equations are generated.\n",
    "\n",
    "    Every non-root node stores one fitted equation (``prompt``, ``ode``, ``eq``), the token\n",
    "    context it was fitted under (``condition``), its model probability (``prob``) and the\n",
    "    visit statistics ``Q`` (summed score) and ``N`` (visit count). Edges run child -> parent.\n",
    "    \"\"\"\n",
    "\n",
    "    # Logit index whose probability is multiplied in after the last equation token.\n",
    "    # NOTE(review): presumably the id of '<EOS>' in word2id -- confirm before using another vocabulary.\n",
    "    END_TOKEN_ID = 73\n",
    "\n",
    "    def __init__(self, dim, word2id):\n",
    "        \"\"\"Create a tree that holds only the root node (id 0).\n",
    "\n",
    "        Args:\n",
    "            dim: number of ``[dx*]`` prompts, i.e. equations to generate.\n",
    "            word2id: mapping from an equation token to its index in the logits.\n",
    "        \"\"\"\n",
    "        self.dim = dim\n",
    "        self.tree = nx.DiGraph()\n",
    "        node_attr = {\"condition\":[],\"prompt\":[],\"ode\":[],\"eq\":None,\"prob\":1,\"Q\":0,\"N\":0,\"state\":\"root\", \"stop\":False}\n",
    "        self.tree.add_node(0,**node_attr)\n",
    "        self.root = 0\n",
    "        self.current_node = 0\n",
    "        self.max_node = 0\n",
    "        self.word2id = word2id\n",
    "        self.stop_idx = False\n",
    "\n",
    "    def _context(self):\n",
    "        \"\"\"Return ``(condition, prompts)`` for the current node.\n",
    "\n",
    "        ``condition`` is a fresh list with the node's condition, prompt and equation tokens\n",
    "        (closed by \"|\" when the node has an equation); ``prompts`` are the ``[dx*]`` tokens not in it.\n",
    "        \"\"\"\n",
    "        attrs = self.tree.nodes[self.current_node]\n",
    "        condition = attrs[\"condition\"] + attrs[\"prompt\"] + attrs[\"ode\"]\n",
    "        if attrs[\"ode\"] != []:\n",
    "            condition += [\"|\"]\n",
    "        prompts = [\"[dx{}]\".format(j) for j in range(self.dim) if \"[dx{}]\".format(j) not in condition]\n",
    "        return condition, prompts\n",
    "\n",
    "    def _collect_path(self, nodes):\n",
    "        \"\"\"Store the equation of every node from the current node up to the root into ``nodes``.\"\"\"\n",
    "        nn = self.current_node\n",
    "        while nn != self.root:\n",
    "            idx = extract_numbers(self.tree.nodes[nn][\"prompt\"][0])\n",
    "            nodes[idx] = self.tree.nodes[nn][\"eq\"]\n",
    "            nn = get_parent(self.tree, nn)\n",
    "\n",
    "    def _score(self, data, dstr, nodes):\n",
    "        \"\"\"Install ``nodes`` as the prediction of ``dstr`` and score its trajectory against ``data``.\n",
    "\n",
    "        The score is the mean of the overall R2 (floored at 0) and the smallest per-dimension R2.\n",
    "        Raises when an equation is missing or when prediction or metric computation fails.\n",
    "        \"\"\"\n",
    "        if len(nodes) != self.dim:\n",
    "            raise ValueError(\"expected {} equations, got {}\".format(self.dim, len(nodes)))\n",
    "        all_nodes = [n for ii in range(len(nodes)) for n in nodes[ii].nodes]\n",
    "        dstr.predictions[0][0] = NodeList(all_nodes)\n",
    "        truth = data[\"trajectory\"][0]\n",
    "        pred_traj = dstr.predict(data[\"times\"][0], truth[0])\n",
    "        metric = compute_metrics(pred_traj, truth, metrics = \"r2\")\n",
    "        score1 = max(0, metric[\"r2\"][0])\n",
    "        r_M = []\n",
    "        for iii in range(pred_traj.shape[1]):\n",
    "            rr = compute_metrics(pred_traj[:,iii:(iii+1)], truth[:,iii:(iii+1)], metrics = \"r2\")[\"r2\"][0]\n",
    "            # Argument order is kept as in the original code (it decides how NaN is treated).\n",
    "            r_M.append(max(rr, 0))\n",
    "        score2 = min(r_M)\n",
    "        return (score1+score2)/2\n",
    "\n",
    "    def expand(self, data, dstr):\n",
    "        \"\"\"Fit one equation per remaining prompt and attach each as a child of the current node.\n",
    "\n",
    "        When exactly one prompt remains the node is flagged ``stop`` and no child is created.\n",
    "        Prompts whose fit or scoring fails are skipped; if no child at all was added the node\n",
    "        is marked ``error`` and the search is stopped, otherwise it becomes ``nonleaf``.\n",
    "        \"\"\"\n",
    "        condition, prompts = self._context()\n",
    "        if len(prompts) == 1:\n",
    "            self.stop_idx = True\n",
    "            self.tree.nodes[self.current_node][\"stop\"] = True\n",
    "        else:\n",
    "            for j, prompt in enumerate(prompts):\n",
    "                try:\n",
    "                    a,b,lgts = dstr.fit(data[\"times\"][0], data[\"trajectory\"][0], condition = condition, prompt = prompt)\n",
    "                    node = a[0][0]\n",
    "                    tree = [item for item in b[0] if item not in ['<EOS>','<PAD>']]\n",
    "                    if len(tree) != len(lgts[0][:-1]):\n",
    "                        raise ValueError(\"token/logit length mismatch\")\n",
    "                    # Explicit dim: each logit vector is 1-D (a single entry is read with .item()).\n",
    "                    probs = [torch.softmax(item[0], dim=-1) for item in lgts[0][:-1]]\n",
    "                    scr = np.prod([probs[k][self.word2id[item]].item() for k,item in enumerate(tree)])\n",
    "                    scr *= torch.softmax(lgts[0][-1][0], dim=-1)[self.END_TOKEN_ID].item()\n",
    "                    node_attr = {\"condition\":condition,\"prompt\":[prompt],\"ode\":tree,\"eq\":node,\"prob\":scr,\"Q\":0,\"N\":0,\"state\":\"leaf\",\"stop\":False}\n",
    "                    self.tree.add_node(self.max_node+j+1,**node_attr)\n",
    "                    self.tree.add_edge(self.max_node+j+1,self.current_node)\n",
    "                except Exception:\n",
    "                    # Best effort: a prompt that cannot be fitted simply gets no child.\n",
    "                    continue\n",
    "            if self.max_node == max(self.tree.nodes):\n",
    "                # No node id beyond max_node exists, i.e. every fit failed.\n",
    "                self.stop_idx = True\n",
    "                self.tree.nodes[self.current_node][\"state\"] = \"error\"\n",
    "            else:\n",
    "                self.tree.nodes[self.current_node][\"state\"] = \"nonleaf\"\n",
    "        self.max_node = max(self.tree.nodes)\n",
    "\n",
    "    def select(self, c = 1.0):\n",
    "        \"\"\"Move ``current_node`` to one of its children, or flag a stop when it has none.\n",
    "\n",
    "        Unvisited children (``N == 0``) are taken first, highest ``prob`` first. Once every child\n",
    "        was visited, the one maximising ``Q/(N+0.01) + c*prob*sqrt(N_parent)/(1+N)`` is taken.\n",
    "        \"\"\"\n",
    "        Vs = get_children(self.tree,self.current_node)\n",
    "        if Vs == []:\n",
    "            self.stop_idx = True\n",
    "            self.tree.nodes[self.current_node][\"stop\"] = True\n",
    "            return\n",
    "        Vs_0 = [n for n in Vs if self.tree.nodes[n][\"N\"] == 0]\n",
    "        if Vs_0 != []:\n",
    "            selection = Vs_0[np.argmax([self.tree.nodes[n][\"prob\"] for n in Vs_0])]\n",
    "        else:\n",
    "            Q = np.array([self.tree.nodes[n][\"Q\"] for n in Vs])\n",
    "            N = np.array([self.tree.nodes[n][\"N\"] for n in Vs])\n",
    "            N_tot = self.tree.nodes[self.current_node][\"N\"]\n",
    "            P = np.array([self.tree.nodes[n][\"prob\"] for n in Vs])\n",
    "            self.ucb = Q/(N+0.01) + c * P * np.sqrt(N_tot)/(1+N)\n",
    "            selection = Vs[np.argmax(self.ucb)]\n",
    "        self.current_node = selection\n",
    "\n",
    "    def simulate(self, data, dstr):\n",
    "        \"\"\"Complete the equation system from the current node and return its score in [0, 1].\n",
    "\n",
    "        If prompts remain they are fitted in random order, each conditioned on the equations\n",
    "        produced so far; equations already fixed on the path to the root take precedence.\n",
    "        If none remain the node is flagged ``stop``. Any failure while scoring yields 0, and\n",
    "        after a rollout it additionally marks the current node as ``error``.\n",
    "        \"\"\"\n",
    "        condition, prompts = self._context()\n",
    "        nodes = {}\n",
    "        rollout = prompts != []\n",
    "        if rollout:\n",
    "            np.random.shuffle(prompts)\n",
    "            for prompt in prompts:\n",
    "                try:\n",
    "                    a,b,lgts = dstr.fit(data[\"times\"][0], data[\"trajectory\"][0], condition = condition, prompt = prompt)\n",
    "                    idx = extract_numbers(prompt)\n",
    "                    nodes[idx] = a[0][0]\n",
    "                    tokens = [item for item in b[0] if item not in ['<EOS>','<PAD>']]\n",
    "                    condition += [prompt]+tokens+[\"|\"]\n",
    "                except Exception:\n",
    "                    # A failed fit leaves this dimension empty; scoring then reports 0.\n",
    "                    continue\n",
    "        else:\n",
    "            self.tree.nodes[self.current_node][\"stop\"] = True\n",
    "            self.stop_idx = True\n",
    "\n",
    "        self._collect_path(nodes)\n",
    "        try:\n",
    "            score = self._score(data, dstr, nodes)\n",
    "        except Exception:\n",
    "            if rollout:\n",
    "                self.tree.nodes[self.current_node][\"state\"] = \"error\"\n",
    "            score = 0\n",
    "        return score\n",
    "\n",
    "    def backpropogate(self, score):\n",
    "        \"\"\"Add ``score`` to ``Q`` and 1 to ``N`` on every node from the current node up to the root, then reset.\n",
    "\n",
    "        Also works when the current node is the root itself; walking past the root used to\n",
    "        index the graph with ``None`` and raise ``KeyError``.\n",
    "        \"\"\"\n",
    "        nn = self.current_node\n",
    "        while nn is not None:\n",
    "            self.tree.nodes[nn][\"N\"]+=1\n",
    "            self.tree.nodes[nn][\"Q\"]+=score\n",
    "            if nn == self.root:\n",
    "                break\n",
    "            nn = get_parent(self.tree, nn)\n",
    "        self.reset()\n",
    "\n",
    "    # Correctly spelled alias; the original name stays available for existing callers.\n",
    "    backpropagate = backpropogate\n",
    "\n",
    "    def reset(self):\n",
    "        \"\"\"Return to the root and clear the stop flag for the next search iteration.\"\"\"\n",
    "        self.current_node = self.root\n",
    "        self.stop_idx = False\n",
    "\n",
    "    def get_graph(self):\n",
    "        \"\"\"Return the underlying ``networkx`` graph.\"\"\"\n",
    "        return self.tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce211d5e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c171a82",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root logger.\n",
    "logger = getLogger()\n",
    "\n",
    "# Restore the pickled run parameters, then apply the overrides used for this evaluation.\n",
    "# NOTE: pickle.load / torch.load can execute arbitrary code; only open files you trust.\n",
    "with open(\"params.pickle\", \"rb\") as fin:\n",
    "    params = pickle.load(fin)\n",
    "for _name, _value in (\n",
    "    (\"max_epoch\", 500),\n",
    "    (\"n_steps_per_epoch\", 400),\n",
    "    (\"batch_size\", 250),\n",
    "    (\"batch_size_eval\", 32),\n",
    "    (\"max_dimension\", 6),\n",
    "    (\"dump_path\", './experiments/debug/v1'),\n",
    "    (\"rescale\", False),\n",
    "):\n",
    "    setattr(params, _name, _value)\n",
    "env = FunctionEnvironment(params)\n",
    "modules = dict()\n",
    "\n",
    "# ckpt_dir is handed straight to torch.load, so it must name the serialized model file.\n",
    "ckpt_dir = \"<YOUR CHECKPOINT DIR>\"\n",
    "dstr = torch.load(ckpt_dir)\n",
    "model_args = dict(beam_size=1, beam_temperature=None)\n",
    "dstr.set_model_args(model_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4bc28d8e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed71d8f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_tree(tree_dir):\n",
    "    \"\"\"Load the pickled sequence stored at ``tree_dir`` and return the second field of every entry.\n",
    "\n",
    "    NOTE: ``pickle.load`` can execute arbitrary code; only open files you produced yourself.\n",
    "    \"\"\"\n",
    "    with open(tree_dir, \"rb\") as fin:\n",
    "        records = pickle.load(fin)\n",
    "    trees = []\n",
    "    for record in records:\n",
    "        trees.append(record[1])\n",
    "    return trees"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b19c558",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c137b937",
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_data(idx1, idx2, data_dir=\"<YOUR DATA DIR>\", chunk=100):\n",
    "    \"\"\"Load the pickled DataFrame for class ``idx1`` and return its ``idx2``-th chunk of rows.\n",
    "\n",
    "    Args:\n",
    "        idx1: class index; substituted into ``data_dir`` with ``str.format``.\n",
    "        idx2: 1-based chunk number; rows ``(idx2-1)*chunk`` up to ``idx2*chunk`` are returned.\n",
    "        data_dir: path of the pickle file, or a ``str.format`` template that takes ``idx1``.\n",
    "            The default is a placeholder so that ``read_data(i, j)`` stays callable; set a real path before running.\n",
    "        chunk: number of rows per chunk.\n",
    "\n",
    "    Returns:\n",
    "        The selected rows. The index is renumbered from 0 before slicing.\n",
    "    \"\"\"\n",
    "    filename = data_dir.format(idx1)\n",
    "    with open(filename, 'rb') as f:\n",
    "        # NOTE: pickle can run arbitrary code on load -- only read files you generated yourself.\n",
    "        df = pickle.load(f)\n",
    "    data = pd.concat([df], ignore_index=True)\n",
    "    # Slice with the idx2 argument; the previous code read an undefined global named `j` here.\n",
    "    return data.iloc[(idx2-1)*chunk:idx2*chunk]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb64caf3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d35417f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def data_generate(i,dataframes):\n",
    "    \"\"\"Build a single-sample batch dict from row ``i`` of ``dataframes``.\n",
    "\n",
    "    Reads the columns ``tree_encoded``, ``time``, ``class``, ``substrate``, ``product``,\n",
    "    ``catalyst`` and ``other``. The trajectory stacks substrate, product and catalyst as\n",
    "    columns and appends the transposed ``other`` array to their right.\n",
    "\n",
    "    Returns a dict with the keys ``times``, ``trajectory``, ``trajectory_clean`` (always empty),\n",
    "    ``tree_encoded`` and ``infos`` (``dimension`` = number of trajectory columns, ``class``).\n",
    "    \"\"\"\n",
    "    sample = dataframes.iloc[i:i+1]\n",
    "    tree_encoded = sample['tree_encoded'].tolist()\n",
    "    time = sample['time'].iloc[0][0]\n",
    "    class_number = sample['class'].iloc[0]\n",
    "    species = tuple(sample[name].iloc[0] for name in (\"substrate\", \"product\", \"catalyst\"))\n",
    "    other = sample['other'].iloc[0].transpose(1,0)\n",
    "    trajectory = np.concatenate((np.column_stack(species), other), axis=1)\n",
    "    return {\n",
    "        \"times\": [time],\n",
    "        \"trajectory\": [trajectory],\n",
    "        \"trajectory_clean\": [],\n",
    "        \"tree_encoded\": tree_encoded,\n",
    "        \"infos\": {\"dimension\": [trajectory.shape[1]], \"class\": [class_number]},\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e22cf76",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a14e21e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_conditon(tree):\n",
    "    \"\"\"Pick the most visited leaf of ``tree`` and derive the generation context that continues from it.\n",
    "\n",
    "    Returns ``(conditions, prompts, order, selection)``:\n",
    "    ``conditions`` is the token context of the selected leaf closed by \"|\";\n",
    "    ``prompts`` are the ``[dx*]`` tokens missing from it, shuffled in place with ``np.random``;\n",
    "    ``order`` lists the equation indices already in ``conditions`` followed by those of ``prompts``;\n",
    "    ``selection`` is the id of the chosen leaf.\n",
    "    The number of equations is taken to be the number of children of node 0.\n",
    "    \"\"\"\n",
    "    leaves = []\n",
    "    for n in tree.nodes:\n",
    "        if tree.nodes[n][\"state\"] == \"leaf\":\n",
    "            leaves.append(n)\n",
    "    visits = [tree.nodes[nl][\"N\"] for nl in leaves]\n",
    "    selection = leaves[np.argmax(visits)]\n",
    "    attrs = tree.nodes[selection]\n",
    "    conditions = attrs[\"condition\"]+attrs[\"prompt\"]+attrs[\"ode\"]+[\"|\"]\n",
    "    n_equations = len(get_children(tree,0))\n",
    "    prompts = []\n",
    "    for idx in range(n_equations):\n",
    "        candidate = \"[dx{}]\".format(idx)\n",
    "        if candidate not in conditions:\n",
    "            prompts.append(candidate)\n",
    "    np.random.shuffle(prompts)\n",
    "    fixed = [extract_numbers(item) for item in conditions if \"[dx\" in item]\n",
    "    remaining = [extract_numbers(prompt) for prompt in prompts]\n",
    "    order = fixed + remaining\n",
    "    return conditions,prompts,order,selection\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fecd34c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cdbc10b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def metric_cal(tree,data,dstr):\n",
    "    \"\"\"Complete the equation system starting from the most visited leaf and load it into ``dstr``.\n",
    "\n",
    "    The remaining prompts returned by ``generate_conditon`` are fitted one after another, each\n",
    "    conditioned on the equations produced so far. Equations stored on the path from the\n",
    "    selected leaf to node 0 are then added, and all of them are written to\n",
    "    ``dstr.predictions[0][0]`` in dimension order.\n",
    "\n",
    "    Returns ``(dstr, pred_tree_encoded, order)`` where ``pred_tree_encoded`` is the full token\n",
    "    context without its trailing \"|\". If the equations cannot be assembled, \"error\" is\n",
    "    printed and ``dstr.predictions`` is left as the last ``fit`` call set it.\n",
    "    \"\"\"\n",
    "    condition, prompts, order, selection = generate_conditon(tree)\n",
    "    nodes = {}\n",
    "    for prompt in prompts:\n",
    "        a,b,_ = dstr.fit(data[\"times\"][0], data[\"trajectory\"][0], condition = condition, prompt = prompt)\n",
    "        nodes[extract_numbers(prompt)] = a[0][0]\n",
    "        condition += [prompt]+[item for item in b[0] if item not in ['<EOS>','<PAD>']]+[\"|\"]\n",
    "    pred_tree_encoded = condition[:-1]\n",
    "    nd = selection\n",
    "    while nd != 0:\n",
    "        idx = extract_numbers(tree.nodes[nd][\"prompt\"][0])\n",
    "        nodes[idx] = tree.nodes[nd][\"eq\"]\n",
    "        nd = get_parent(tree,nd)\n",
    "    try:\n",
    "        all_nodes = [n for ii in range(len(nodes)) for n in nodes[ii].nodes]\n",
    "        dstr.predictions[0][0] = NodeList(all_nodes)\n",
    "    except Exception:\n",
    "        # Best effort, as before, but no longer swallowing KeyboardInterrupt/SystemExit.\n",
    "        print(\"error\")\n",
    "    return dstr, pred_tree_encoded, order"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1e43aa3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac88e290",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _fold_constants(tokens):\n",
    "    \"\"\"Collapse every (sign, mantissa, exponent) token triple in ``tokens`` into one number.\n",
    "\n",
    "    Scanning stops at the first '<EOS>'. A triple that ``combine`` cannot parse becomes 0.\n",
    "    \"\"\"\n",
    "    folded = []\n",
    "    pos = 0\n",
    "    while pos < len(tokens) and tokens[pos] != '<EOS>':\n",
    "        if tokens[pos] in ('+', '-'):\n",
    "            try:\n",
    "                value = combine(tokens[pos], tokens[pos+1], tokens[pos+2])\n",
    "            except Exception:\n",
    "                value = 0\n",
    "            folded.append(value)\n",
    "            pos += 3\n",
    "        else:\n",
    "            folded.append(tokens[pos])\n",
    "            pos += 1\n",
    "    return folded\n",
    "\n",
    "\n",
    "def _constant_match_level(pred, truth):\n",
    "    \"\"\"Grade the worst relative constant error between two position-aligned sequences.\n",
    "\n",
    "    Returns 1 (error > 10) up to 7 (error <= 0.01). The error of a pair is\n",
    "    ``|p - t| / min(|p|, |t|)``; equal values count as 0 and a mismatch against 0 as infinite.\n",
    "    \"\"\"\n",
    "    max_alpha = 0\n",
    "    for num_1, num_2 in zip(pred, truth):\n",
    "        if isinstance(num_1, str):\n",
    "            continue\n",
    "        gap = abs(num_1-num_2)\n",
    "        # Magnitudes in the denominator: a signed minimum made alpha negative for negative\n",
    "        # constants (graded as a perfect match) and divided by zero for zero constants.\n",
    "        scale = min(abs(num_1), abs(num_2))\n",
    "        if gap == 0:\n",
    "            alpha = 0\n",
    "        elif scale == 0:\n",
    "            alpha = float(\"inf\")\n",
    "        else:\n",
    "            alpha = gap/scale\n",
    "        if alpha > max_alpha:\n",
    "            max_alpha = alpha\n",
    "    for level, bound in enumerate((10, 5, 1, 0.5, 0.1, 0.01), start=1):\n",
    "        if max_alpha > bound:\n",
    "            return level\n",
    "    return 7\n",
    "\n",
    "\n",
    "def evaluate_all(tree,data,dstr,pred_tree_encoded):\n",
    "    \"\"\"Score one prediction: expression match level plus trajectory R2.\n",
    "\n",
    "    Args:\n",
    "        tree: unused; kept for call compatibility.\n",
    "        data: batch dict; the first entry of ``times``, ``trajectory`` and ``tree_encoded`` is read.\n",
    "        dstr: regressor whose current ``predictions`` are evaluated through ``dstr.predict``.\n",
    "        pred_tree_encoded: predicted tokens, each equation introduced by its ``[dx*]`` prompt.\n",
    "\n",
    "    Returns:\n",
    "        ``(em, r2score, r_M)``. ``em`` is 0 when the token skeletons differ, otherwise 1..7 with\n",
    "        higher meaning a smaller worst-case relative constant error. ``r2score`` is the overall\n",
    "        R2 and ``r_M`` the per-dimension R2 values floored at 0. On any failure \"error 2\" is\n",
    "        printed and ``(-1, 0, [0, ...])`` is returned.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        pred_traj = dstr.predict(data[\"times\"][0], data[\"trajectory\"][0][0])\n",
    "        pred_trees = dstr.predictions\n",
    "\n",
    "        # Regroup the predicted tokens by equation index so they follow dimension order.\n",
    "        by_dim = {}\n",
    "        idx = None\n",
    "        for item in pred_tree_encoded:\n",
    "            if \"[dx\" in item:\n",
    "                idx = extract_numbers(item)\n",
    "                by_dim[idx] = []\n",
    "            else:\n",
    "                by_dim[idx].append(item)\n",
    "        ordered = []\n",
    "        for i in range(len(by_dim)):\n",
    "            ordered += by_dim[i]\n",
    "            if by_dim[i][-1] != \"|\":\n",
    "                ordered += [\"|\"]\n",
    "\n",
    "        pred_folded = _fold_constants(ordered)\n",
    "        truth_folded = _fold_constants(data['tree_encoded'][0])\n",
    "\n",
    "        # Skeletons: every constant becomes \"C\". The regrouped prediction always ends with \"|\",\n",
    "        # which is dropped before comparing (presumably the reference has no trailing \"|\").\n",
    "        a = [item if isinstance(item, str) else \"C\" for item in pred_folded]\n",
    "        b = [item if isinstance(item, str) else \"C\" for item in truth_folded]\n",
    "        a = [item for item in a if \"[dx\" not in item][:-1]\n",
    "\n",
    "        if a == b:\n",
    "            em = _constant_match_level(pred_folded, truth_folded)\n",
    "        else:\n",
    "            em = 0\n",
    "\n",
    "        metric = compute_metrics(pred_traj,data[\"trajectory\"][0], predicted_tree = pred_trees, metrics = \"r2,is_valid_tree\")\n",
    "        r2score = metric[\"r2\"][0]\n",
    "\n",
    "        r_M = []\n",
    "        for iii in range(pred_traj.shape[1]):\n",
    "            rr = compute_metrics(pred_traj[:,iii:(iii+1)],data[\"trajectory\"][0][:,iii:(iii+1)], metrics = \"r2\")[\"r2\"][0]\n",
    "            r_M.append(max(rr, 0))\n",
    "        return em, r2score, r_M\n",
    "    except Exception:\n",
    "        print(\"error 2\")\n",
    "        return -1, 0, [0 for _ in range(data[\"trajectory\"][0].shape[1])]\n",
    "        \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2935ceb2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0547563e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# (class, chunk) pairs whose tree file and data chunk differ in length.\n",
    "errors = []\n",
    "# Per (class, chunk): lists of em, r2_m, r2_M and order, one entry per sample.\n",
    "results = {}\n",
    "# Location handed to read_data; replace the placeholder with the real path.\n",
    "data_dir = \"<YOUR DATA DIR>\"\n",
    "\n",
    "for i in range(1,21):\n",
    "    for j in range(1,6):\n",
    "        tree_dir = \"<YOU TREE DIR>\".format(i,j)\n",
    "        trees = read_tree(tree_dir)\n",
    "        # Pass the data location explicitly; calling read_data with only (i, j) left out its third argument.\n",
    "        dataframe = read_data(i, j, data_dir)\n",
    "        results[(i,j)] = {\"em\": [], \"r2_m\": [], \"r2_M\": [], \"order\": []}\n",
    "        if len(trees) != len(dataframe):\n",
    "            errors.append((i,j))\n",
    "        for k in tqdm.tqdm(range(len(trees))):\n",
    "            tree = trees[k]\n",
    "            try:\n",
    "                # Building the batch is inside the try so that one unreadable row is recorded\n",
    "                # as a failed sample instead of aborting the whole run.\n",
    "                data = data_generate(k,dataframe)\n",
    "                dstr, pred_tree_encoded, order = metric_cal(tree,data,dstr)\n",
    "                em,r2_m,r2_M = evaluate_all(tree,data,dstr,pred_tree_encoded)\n",
    "            except Exception:\n",
    "                em = -1\n",
    "                r2_m = 0\n",
    "                r2_M = [0]\n",
    "                order = [-1]\n",
    "            results[(i,j)][\"em\"].append(em)\n",
    "            results[(i,j)][\"r2_m\"].append(r2_m)\n",
    "            results[(i,j)][\"r2_M\"].append(r2_M)\n",
    "            results[(i,j)][\"order\"].append(order)\n",
    "\n",
    "        \n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "100837b4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5b38b63",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_result(i):\n",
    "    \"\"\"Aggregate the entries of the global ``results`` for class ``i`` over chunks 1..5.\n",
    "\n",
    "    Returns a dict of per-sample averages:\n",
    "    ``acc_m`` / ``acc_M``: share of samples whose overall / mean per-dimension R2 is >= 0.9;\n",
    "    ``r2_m`` / ``r2_M``: mean of those two R2 values, each floored at 0;\n",
    "    ``acc_form``: share of samples with ``em >= 1``.\n",
    "    \"\"\"\n",
    "    hits_m = 0\n",
    "    hits_M = 0\n",
    "    total_m = 0\n",
    "    total_M = 0\n",
    "    hits_form = 0\n",
    "    count = 0\n",
    "    for j in range(1,6):\n",
    "        chunk = results[(i,j)]\n",
    "        for k in tqdm.tqdm(range(len(chunk[\"em\"]))):\n",
    "            overall_r2 = chunk[\"r2_m\"][k]\n",
    "            if overall_r2 >= 0.9:\n",
    "                hits_m += 1\n",
    "            mean_dim_r2 = np.mean(chunk[\"r2_M\"][k])\n",
    "            if mean_dim_r2 >= 0.9:\n",
    "                hits_M += 1\n",
    "\n",
    "            total_m += max(0,overall_r2)\n",
    "            total_M += max(0,mean_dim_r2)\n",
    "\n",
    "            if chunk[\"em\"][k] >= 1:\n",
    "                hits_form += 1\n",
    "            count += 1\n",
    "    return {\"acc_m\":hits_m/count,\"acc_M\":hits_M/count,\"r2_m\":total_m/count,\"r2_M\":total_M/count,\"acc_form\":hits_form/count}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "941f164a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1132b0ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summary metrics per class, keyed by class index 1..20.\n",
    "results_cls = dict()\n",
    "for i in range(1, 21):\n",
    "    class_summary = generate_result(i)\n",
    "    results_cls[i] = class_summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2db03d90",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3ede978",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "817925c1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "DLogic",
   "language": "python",
   "name": "dlogic"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
