{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8390390",
   "metadata": {},
   "outputs": [],
   "source": [
    "from odeformer.DataGenKinetics import *\n",
    "from odeformer.model.embedders import LinearPointEmbedder\n",
    "from odeformer.model.transformer import TransformerModel\n",
    "from odeformer.envs.encoders import FloatSequences\n",
    "from odeformer.trainer import Trainer, autocast_wrapper\n",
    "from odeformer.envs.environment import FunctionEnvironment\n",
    "from odeformer.envs.encoders import Equation\n",
    "from evaluate import Evaluator, setup_odeformer\n",
    "from odeformer.model.model_wrapper import ModelWrapper\n",
    "from odeformer.utils import to_cuda\n",
    "from odeformer.model.model_eval import SymbolicTransformerRegressor\n",
    "from odeformer.metrics import r2_score\n",
    "from odeformer.metrics import compute_metrics\n",
    "from odeformer.envs.generators import NodeList\n",
    "\n",
    "\n",
    "from addict import Dict\n",
    "import json\n",
    "from logging import getLogger\n",
    "import torch\n",
    "import tqdm\n",
    "import os\n",
    "import math\n",
    "\n",
    "import pickle\n",
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "488e11dc",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "125d304b",
   "metadata": {},
   "outputs": [],
   "source": [
    "logger = getLogger()\n",
    "\n",
    "# NOTE(review): pickle.load can execute arbitrary code -- only load a trusted params file.\n",
    "with open('params.pickle', 'rb') as params_file:\n",
    "    params = pickle.load(params_file)\n",
    "\n",
    "# Settings applied on top of the pickled parameters before building the environment.\n",
    "PARAM_OVERRIDES = {\n",
    "    'max_epoch': 500,\n",
    "    'n_steps_per_epoch': 400,\n",
    "    'batch_size': 250,\n",
    "    'batch_size_eval': 32,\n",
    "    'max_dimension': 6,\n",
    "    'dump_path': './experiments/debug/v1',\n",
    "    'rescale': False,\n",
    "}\n",
    "for _name, _value in PARAM_OVERRIDES.items():\n",
    "    setattr(params, _name, _value)\n",
    "\n",
    "env = FunctionEnvironment(params)\n",
    "modules = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f026411",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation batch size and checkpoint path (replace the placeholder before running).\n",
    "eval_bs, ckpt_dir = 50, '<YOUR CHECKPOINT DIR>'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72fc0686",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pickled dataset shards 1..20 and stack them into a single DataFrame.\n",
    "# Put '{}' in the template where the shard number belongs, e.g. '<YOUR DATA DIR>/data_{}.pkl'.\n",
    "# (The previous f-string had no placeholder, so every iteration re-read the same file.)\n",
    "# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.\n",
    "data_file_template = '<YOUR DATA DIR>'\n",
    "shards = []\n",
    "for shard_idx in range(1, 21):\n",
    "    filename = data_file_template.format(shard_idx)\n",
    "    with open(filename, 'rb') as f:\n",
    "        shards.append(pickle.load(f))\n",
    "dataframes = pd.concat(shards, ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23bf2a32",
   "metadata": {},
   "outputs": [],
   "source": [
    "def batch_generate_test(batch_size, dataframes, random_state=None):\n",
    "    \"\"\"Draw a random evaluation batch from the kinetics DataFrame.\n",
    "\n",
    "    Args:\n",
    "        batch_size: number of rows sampled (without replacement) from `dataframes`.\n",
    "        dataframes: DataFrame with the columns 'time', 'class', 'substrate', 'product',\n",
    "            'catalyst', 'other' and 'tree_encoded'.\n",
    "        random_state: optional seed / RandomState forwarded to `DataFrame.sample`\n",
    "            for reproducible batches; the default None keeps the unseeded behaviour.\n",
    "\n",
    "    Returns:\n",
    "        dict with 'times', 'trajectory', 'trajectory_clean' (always left empty),\n",
    "        'tree_encoded' and 'infos' ({'dimension': [...], 'class': [...]}), holding\n",
    "        one entry per sampled row, all in the same order.\n",
    "    \"\"\"\n",
    "    sample = dataframes.sample(n=batch_size, random_state=random_state)\n",
    "\n",
    "    batch = {\n",
    "        'times': [],\n",
    "        'trajectory': [],\n",
    "        'trajectory_clean': [],  # never filled here; kept for callers expecting the key\n",
    "        'tree_encoded': sample['tree_encoded'].tolist(),\n",
    "        'infos': {'dimension': [], 'class': []},\n",
    "    }\n",
    "    rows = zip(sample['time'], sample['class'], sample['substrate'],\n",
    "               sample['product'], sample['catalyst'], sample['other'])\n",
    "    for time, class_number, substrate, product, catalyst, other in rows:\n",
    "        # Use this row's own time grid (previously row 0's grid was used for every sample).\n",
    "        batch['times'].append(time[0])\n",
    "        # `other` is transposed so its rows line up with the stacked substrate/product/catalyst\n",
    "        # columns -- presumably stored as (species, time); TODO confirm.\n",
    "        trajectory = np.concatenate(\n",
    "            (np.column_stack((substrate, product, catalyst)), other.transpose(1, 0)), axis=1\n",
    "        )\n",
    "        batch['trajectory'].append(trajectory)\n",
    "        batch['infos']['dimension'].append(trajectory.shape[1])\n",
    "        batch['infos']['class'].append(class_number)\n",
    "    return batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8889e7ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "def combine(a, b, c):\n",
    "    \"\"\"Decode a (sign, mantissa, exponent) token triple into a number.\n",
    "\n",
    "    `b` and `c` carry a one-character prefix that is stripped before int() parsing\n",
    "    (for instance 'N123' and 'E-2'); any sign token other than '+' counts as negative.\n",
    "    Returns sign * mantissa * 10 ** exponent (a float when the exponent is negative).\n",
    "    \"\"\"\n",
    "    mantissa = int(b[1:])\n",
    "    exponent = int(c[1:])\n",
    "    sign = 1 if a == '+' else -1\n",
    "    return sign * mantissa * 10 ** exponent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6be4c33",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49235c2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _decode_tokens(tokens):\n",
    "    \"\"\"Collapse encoded constants in a token list into plain numbers.\n",
    "\n",
    "    Scans `tokens` up to the first '<EOS>' (or the end). Each '+'/'-' token starts a\n",
    "    (sign, mantissa, exponent) triple that is replaced by `combine(...)`; a triple\n",
    "    that cannot be decoded (truncated or malformed) becomes 0. Other tokens are kept.\n",
    "    \"\"\"\n",
    "    decoded = []\n",
    "    pos = 0\n",
    "    while pos < len(tokens) and tokens[pos] != '<EOS>':\n",
    "        token = tokens[pos]\n",
    "        if token in ('+', '-'):\n",
    "            try:\n",
    "                value = combine(token, tokens[pos + 1], tokens[pos + 2])\n",
    "            except Exception:\n",
    "                value = 0\n",
    "            decoded.append(value)\n",
    "            pos += 3\n",
    "        else:\n",
    "            decoded.append(token)\n",
    "            pos += 1\n",
    "    return decoded\n",
    "\n",
    "\n",
    "def _match_level(pred_decoded, true_decoded):\n",
    "    \"\"\"Grade how closely a decoded prediction matches the decoded ground truth.\n",
    "\n",
    "    Returns 0 when the symbolic skeletons differ (every number replaced by 'C').\n",
    "    Otherwise the worst relative error `alpha` over all constant pairs is mapped to\n",
    "    1 (alpha > 10), 2 (> 5), 3 (> 1), 4 (> 0.5), 5 (> 0.1), 6 (> 0.01) or 7 (<= 0.01).\n",
    "    \"\"\"\n",
    "    def skeleton(decoded):\n",
    "        return [tok if isinstance(tok, str) else 'C' for tok in decoded]\n",
    "\n",
    "    if skeleton(pred_decoded) != skeleton(true_decoded):\n",
    "        return 0\n",
    "    max_alpha = 0\n",
    "    for pred_val, true_val in zip(pred_decoded, true_decoded):\n",
    "        if isinstance(pred_val, str):\n",
    "            continue\n",
    "        # Relative error w.r.t. the smaller magnitude. abs() keeps negative constants\n",
    "        # from yielding a negative (i.e. 'perfect') error; a zero denominator no longer\n",
    "        # raises ZeroDivisionError.\n",
    "        denom = min(abs(pred_val), abs(true_val))\n",
    "        if denom == 0:\n",
    "            alpha = 0 if pred_val == true_val else math.inf\n",
    "        else:\n",
    "            alpha = abs(pred_val - true_val) / denom\n",
    "        max_alpha = max(max_alpha, alpha)\n",
    "    for bound, level in ((10, 1), (5, 2), (1, 3), (0.5, 4), (0.1, 5), (0.01, 6)):\n",
    "        if max_alpha > bound:\n",
    "            return level\n",
    "    return 7\n",
    "\n",
    "\n",
    "def _fraction_above(values, threshold):\n",
    "    \"\"\"Fraction of `values` strictly greater than `threshold` (0.0 for an empty list).\n",
    "\n",
    "    Replaces the former sum / (len + 0.01) form, which biased every accuracy low.\n",
    "    \"\"\"\n",
    "    if len(values) == 0:\n",
    "        return 0.0\n",
    "    return float(np.mean(np.asarray(values, dtype=float) > threshold))\n",
    "\n",
    "\n",
    "def evaluate_all_con(ckpt_dir, eval_bs, dataframes):\n",
    "    \"\"\"Evaluate a saved ODEFormer model on random batches drawn from `dataframes`.\n",
    "\n",
    "    Each trajectory is fitted one dimension at a time: the prompts '[dx0]', '[dx1]', ...\n",
    "    are visited in random order and every decoded equation is appended to `condition`,\n",
    "    so later dimensions are conditioned on earlier ones. The per-dimension predictions\n",
    "    are stitched into one NodeList, integrated with `dstr.predict` and scored.\n",
    "\n",
    "    Args:\n",
    "        ckpt_dir: path handed to `torch.load`; the loaded object must provide\n",
    "            set_model_args / fit / predict / print / predictions.\n",
    "        eval_bs: samples per batch; `len(dataframes) // eval_bs` batches are drawn with\n",
    "            `batch_generate_test`, so samples may repeat across batches.\n",
    "        dataframes: DataFrame in the format expected by `batch_generate_test`.\n",
    "\n",
    "    Returns:\n",
    "        dict with\n",
    "          'gt': [predicted trajectory, true trajectory] per sample ('invalid' instead of\n",
    "                the prediction when the per-dimension predictions could not be assembled),\n",
    "          'r2': overall R^2 per sample, clipped below at 0 (0 on scoring failure),\n",
    "          'r2_macro': per-dimension R^2 list per sample, each passed through max(r, 0),\n",
    "          'accuracy' / 'accuracy_M' / 'accuracy_min': {90, 95, 98, 99} -> fraction of\n",
    "                samples whose overall / mean per-dim / min per-dim R^2 exceeds 0.90..0.99,\n",
    "          'em': structural match level per sample (see `_match_level`),\n",
    "          'average_r2': mean of 'r2' (0.0 when nothing was evaluated).\n",
    "    \"\"\"\n",
    "    # NOTE(review): torch.load unpickles the checkpoint -- only load trusted files.\n",
    "    dstr = torch.load(ckpt_dir)\n",
    "    dstr.set_model_args({'beam_size': 1, 'beam_temperature': 0.1})\n",
    "\n",
    "    num_chunks = len(dataframes) // eval_bs\n",
    "    exact_match = []\n",
    "    pred_gt = []\n",
    "    r2 = []\n",
    "    r2_M = []\n",
    "\n",
    "    for _chunk in tqdm.tqdm(range(num_chunks)):\n",
    "        samples_eval = batch_generate_test(eval_bs, dataframes)\n",
    "        for i in range(eval_bs):\n",
    "            times = samples_eval['times'][i]\n",
    "            true_traj = samples_eval['trajectory'][i]\n",
    "            n_dims = true_traj.shape[1]\n",
    "\n",
    "            # Fit one dimension at a time, conditioning on the equations found so far.\n",
    "            nodes = {}\n",
    "            trees = {}\n",
    "            prompts = ['[dx{}]'.format(d) for d in range(n_dims)]\n",
    "            condition = []\n",
    "            for _step in range(n_dims):\n",
    "                prompt = prompts.pop(np.random.choice(range(len(prompts))))\n",
    "                # '[dx12]' -> 12; the former prompt[-2] only read the last digit.\n",
    "                dim_idx = int(prompt[3:-1])\n",
    "                a, b, _ = dstr.fit(times, true_traj, condition=condition, prompt=prompt)\n",
    "                try:\n",
    "                    nodes[dim_idx] = a[0][0]\n",
    "                    trees[dim_idx] = [tok for tok in b[0] if tok not in ['<EOS>', '<PAD>']]\n",
    "                    condition += [prompt] + trees[dim_idx] + ['|']\n",
    "                except Exception:\n",
    "                    continue\n",
    "\n",
    "            # Stitch the per-dimension predictions together; a dimension that failed to\n",
    "            # decode raises KeyError here and the sample is recorded as invalid.\n",
    "            try:\n",
    "                all_nodes = [n for d in range(n_dims) for n in nodes[d].nodes]\n",
    "                dstr.predictions[0][0] = NodeList(all_nodes)\n",
    "                pred_tree_encoded = []\n",
    "                for d in range(n_dims):\n",
    "                    pred_tree_encoded += trees[d] + ['|']\n",
    "                pred_tree_encoded = pred_tree_encoded[:-1]  # drop the trailing separator\n",
    "            except Exception:\n",
    "                pred_gt.append(['invalid', true_traj])\n",
    "                exact_match.append(0)\n",
    "                r2.append(0)\n",
    "                r2_M.append([0] * n_dims)\n",
    "                continue\n",
    "\n",
    "            print('============ Generate formula %i ... ============' % i)\n",
    "            dstr.print()\n",
    "            pred_traj = dstr.predict(times, true_traj[0])\n",
    "            pred_trees = dstr.predictions\n",
    "\n",
    "            # Score into locals and append once afterwards, so a failure part-way cannot\n",
    "            # append to `exact_match` twice and misalign the per-sample lists.\n",
    "            try:\n",
    "                # pred_tree_encoded is a flat token list, like the ground truth (it was\n",
    "                # previously indexed as pred_tree_encoded[0][j], i.e. characters of token 0).\n",
    "                match_level = _match_level(_decode_tokens(pred_tree_encoded),\n",
    "                                           _decode_tokens(samples_eval['tree_encoded'][i]))\n",
    "                metric = compute_metrics(pred_traj, true_traj, predicted_tree=pred_trees,\n",
    "                                         metrics='r2,is_valid_tree')\n",
    "                r2score = metric['r2'][0]\n",
    "                per_dim_r2 = [\n",
    "                    max(compute_metrics(pred_traj[:, d:d + 1], true_traj[:, d:d + 1],\n",
    "                                        metrics='r2')['r2'][0], 0)\n",
    "                    for d in range(pred_traj.shape[1])\n",
    "                ]\n",
    "                print(r2score)\n",
    "                print(metric['is_valid_tree'][0])\n",
    "                r2_clipped = r2score if r2score > 0 else 0\n",
    "            except Exception:\n",
    "                match_level, r2_clipped, per_dim_r2 = 0, 0, [0] * n_dims\n",
    "            exact_match.append(match_level)\n",
    "            r2.append(r2_clipped)\n",
    "            r2_M.append(per_dim_r2)\n",
    "            pred_gt.append([pred_traj, true_traj])\n",
    "\n",
    "    r2_mean = [np.mean(item) for item in r2_M]\n",
    "    r2_min = [np.min(item) for item in r2_M]\n",
    "    thresholds = {90: 0.9, 95: 0.95, 98: 0.98, 99: 0.99}\n",
    "    return {\n",
    "        'gt': pred_gt,\n",
    "        'r2': r2,\n",
    "        'r2_macro': r2_M,\n",
    "        'accuracy': {k: _fraction_above(r2, t) for k, t in thresholds.items()},\n",
    "        'accuracy_M': {k: _fraction_above(r2_mean, t) for k, t in thresholds.items()},\n",
    "        'accuracy_min': {k: _fraction_above(r2_min, t) for k, t in thresholds.items()},\n",
    "        'em': exact_match,\n",
    "        'average_r2': float(np.mean(r2)) if r2 else 0.0,\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba4d0f49",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f993feae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the evaluation over the loaded dataset.\n",
    "result = evaluate_all_con(ckpt_dir=ckpt_dir, eval_bs=eval_bs, dataframes=dataframes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7b07d77",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e838e495",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2590184",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07005ef0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "740f38d8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
