{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8390390",
   "metadata": {},
   "outputs": [],
   "source": [
    "from odeformer.DataGenKinetics import *\n",
    "from odeformer.model.embedders import LinearPointEmbedder\n",
    "from odeformer.model.transformer import TransformerModel\n",
    "from odeformer.envs.encoders import FloatSequences\n",
    "from odeformer.trainer import Trainer, autocast_wrapper\n",
    "from odeformer.envs.environment import FunctionEnvironment\n",
    "from odeformer.envs.encoders import Equation\n",
    "from evaluate import Evaluator, setup_odeformer\n",
    "from odeformer.model.model_wrapper import ModelWrapper\n",
    "from odeformer.utils import to_cuda\n",
    "from odeformer.model.model_eval import SymbolicTransformerRegressor\n",
    "from odeformer.metrics import r2_score\n",
    "from odeformer.metrics import compute_metrics\n",
    "from odeformer.envs.generators import NodeList\n",
    "\n",
    "from addict import Dict\n",
    "import json\n",
    "# pickle and np are used below; import them explicitly rather than relying on the star import.\n",
    "import pickle\n",
    "from logging import getLogger\n",
    "import numpy as np\n",
    "import torch\n",
    "import tqdm\n",
    "import os\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c22a958",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "488e11dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root logger (no name given). NOTE(review): its default level is WARNING, so the\n",
    "# logger.info(...) calls in later cells print nothing unless logging is configured.\n",
    "logger = getLogger(name=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "125d304b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: pickle.load can run arbitrary code -- only load a params file you trust.\n",
    "with open(\"params.pickle\", \"rb\") as params_file:\n",
    "    params = pickle.load(params_file)\n",
    "\n",
    "# Overrides applied on top of the pickled parameters.\n",
    "param_overrides = {\n",
    "    \"max_epoch\": 200,\n",
    "    \"n_steps_per_epoch\": 1000,\n",
    "    \"batch_size\": 100,\n",
    "    \"batch_size_eval\": 32,\n",
    "    \"max_dimension\": 6,\n",
    "    \"dump_path\": './experiments/debug/v1',\n",
    "    \"rescale\": False,\n",
    "}\n",
    "for attr_name, attr_value in param_overrides.items():\n",
    "    setattr(params, attr_name, attr_value)\n",
    "\n",
    "env = FunctionEnvironment(params)\n",
    "modules = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b1f0a80",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load onto the CPU regardless of the device the checkpoint was saved from: this\n",
    "# works without CUDA, keeps the vocabulary-resizing cells below on one device,\n",
    "# and the modules are moved to the GPU further down unless params.cpu is set.\n",
    "# NOTE: torch.load unpickles an arbitrary Python object -- only load trusted files.\n",
    "model = torch.load(\"odeformer.pt\", map_location=\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11cd1fd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reuse the pretrained embedder, but resize its token table to env.float_words.\n",
    "modules[\"embedder\"] = model.embedder\n",
    "old_embeddings = model.embedder.embeddings\n",
    "s1, embedding_dim = old_embeddings.weight.shape\n",
    "num_words = len(env.float_words)\n",
    "assert num_words >= s1, f\"float vocabulary is smaller than the pretrained table: {num_words} < {s1}\"\n",
    "\n",
    "# Pretrained rows are copied into the head of the new table; rows for new tokens\n",
    "# start at zero. Tensor.copy_ accepts a source on another device/dtype (unlike\n",
    "# torch.cat with a freshly created CPU tensor), and no_grad keeps the copy out\n",
    "# of autograd.\n",
    "embedding = nn.Embedding(num_words, embedding_dim, padding_idx=old_embeddings.padding_idx)\n",
    "with torch.no_grad():\n",
    "    embedding.weight.zero_()\n",
    "    embedding.weight[:s1].copy_(old_embeddings.weight)\n",
    "embedding.weight.requires_grad = True\n",
    "modules[\"embedder\"].embeddings = embedding\n",
    "modules[\"embedder\"].env = env"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "899f6f7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The pretrained encoder is reused unchanged.\n",
    "modules.update(encoder=model.encoder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffe46542",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reuse the pretrained decoder, resizing its input embeddings and output\n",
    "# projection to env.equation_words.\n",
    "modules[\"decoder\"] = model.decoder\n",
    "old_dec_embeddings = model.decoder.embeddings\n",
    "old_proj = model.decoder.proj\n",
    "s2, embedding_dim_2 = old_dec_embeddings.weight.shape\n",
    "num_words2 = len(env.equation_words)\n",
    "assert num_words2 >= s2, f\"equation vocabulary is smaller than the pretrained table: {num_words2} < {s2}\"\n",
    "\n",
    "embedding2 = nn.Embedding(num_words2, embedding_dim_2, padding_idx=old_dec_embeddings.padding_idx)\n",
    "proj = nn.Linear(old_proj.in_features, num_words2)\n",
    "with torch.no_grad():\n",
    "    # Pretrained rows first, zero rows for new tokens (copy_ handles device/dtype).\n",
    "    embedding2.weight.zero_()\n",
    "    embedding2.weight[:s2].copy_(old_dec_embeddings.weight)\n",
    "    # Keep the pretrained output bias instead of nn.Linear's random init;\n",
    "    # new tokens start with a zero bias.\n",
    "    proj.bias.zero_()\n",
    "    if old_proj.bias is not None:\n",
    "        n_keep = min(old_proj.bias.shape[0], num_words2)\n",
    "        proj.bias[:n_keep].copy_(old_proj.bias[:n_keep])\n",
    "embedding2.weight.requires_grad = True\n",
    "modules[\"decoder\"].embeddings = embedding2\n",
    "# Tie the output projection weight to the expanded input embeddings.\n",
    "proj.weight = embedding2.weight\n",
    "modules[\"decoder\"].proj = proj\n",
    "modules[\"decoder\"].id2word = env.equation_id2word\n",
    "modules[\"decoder\"].word2id = {s: i for i, s in env.equation_id2word.items()}\n",
    "modules[\"decoder\"].n_words = len(env.equation_id2word)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7549cf0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop the global name for the loaded checkpoint; its submodules stay reachable\n",
    "# through `modules`.\n",
    "del model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c21ffa65",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dump each module's architecture at DEBUG level, labelled with its name.\n",
    "for k, v in modules.items():\n",
    "    logger.debug(f\"{k}: {v}\")\n",
    "# Count trainable parameters per module (Module.parameters() yields each\n",
    "# Parameter once, so tied weights are not double-counted within a module).\n",
    "for k, v in modules.items():\n",
    "    logger.info(\n",
    "        f\"Number of parameters ({k}): {sum(p.numel() for p in v.parameters() if p.requires_grad)}\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b65b4f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move every module to the current CUDA device unless running on CPU.\n",
    "if not params.cpu:\n",
    "    for module_name in modules:\n",
    "        modules[module_name].cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ae08db7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09290bc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expose the embedder's get_length_after_batching on env (presumably used when\n",
    "# batching samples -- confirm in FunctionEnvironment).\n",
    "embedder_length_fn = modules[\"embedder\"].get_length_after_batching\n",
    "env.get_length_after_batching = embedder_length_fn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5e4fb9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the trainer from the assembled modules, the environment and the params.\n",
    "trainer = Trainer(modules, env, params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acfe4b73",
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_checkpoint(trainer, name):\n",
    "    \"\"\"\n",
    "    Save the embedder/encoder/decoder as a SymbolicTransformerRegressor checkpoint.\n",
    "\n",
    "    The modules are wrapped in a ModelWrapper and a SymbolicTransformerRegressor,\n",
    "    written with torch.save to ``<trainer.params.dump_path>/<name>.pt`` and then\n",
    "    re-loaded from that same file with torch.load.\n",
    "\n",
    "    Args:\n",
    "        trainer: object exposing ``params``, ``env`` and ``modules`` (a mapping\n",
    "            with the keys \"embedder\", \"encoder\" and \"decoder\").\n",
    "        name: checkpoint file stem; \".pt\" is appended.\n",
    "\n",
    "    Returns:\n",
    "        The object re-loaded from the written file, or None when\n",
    "        ``trainer.params.is_master`` is false (nothing is saved in that case).\n",
    "\n",
    "    The modules are switched to eval mode while the checkpoint is built and\n",
    "    written; their previous train/eval mode is restored afterwards.\n",
    "    \"\"\"\n",
    "    if not trainer.params.is_master:\n",
    "        return None\n",
    "    path = os.path.join(trainer.params.dump_path, \"%s.pt\" % name)\n",
    "    logger.info(\"Saving %s to %s ...\" % (name, path))\n",
    "    os.makedirs(trainer.params.dump_path, exist_ok=True)\n",
    "\n",
    "    def _unwrap(key):\n",
    "        # With multi_gpu the network lives under the wrapper's `.module`.\n",
    "        module = trainer.modules[key]\n",
    "        return module.module if trainer.params.multi_gpu else module\n",
    "\n",
    "    embedder, encoder, decoder = (_unwrap(k) for k in (\"embedder\", \"encoder\", \"decoder\"))\n",
    "    networks = (embedder, encoder, decoder)\n",
    "    # Remember each module's mode so the caller's training loop is not silently\n",
    "    # left in eval mode (dropout etc. disabled) after a checkpoint is written.\n",
    "    was_training = [net.training for net in networks]\n",
    "    for net in networks:\n",
    "        net.eval()\n",
    "\n",
    "    try:\n",
    "        model_kwargs = {\n",
    "            'beam_length_penalty': trainer.params.beam_length_penalty,\n",
    "            'beam_size': trainer.params.beam_size,\n",
    "            'max_generated_output_len': trainer.params.max_generated_output_len,\n",
    "            'beam_early_stopping': trainer.params.beam_early_stopping,\n",
    "            'beam_temperature': trainer.params.beam_temperature,\n",
    "            'beam_type': trainer.params.beam_type,\n",
    "        }\n",
    "        mw = ModelWrapper(\n",
    "            env=trainer.env,\n",
    "            embedder=embedder,\n",
    "            encoder=encoder,\n",
    "            decoder=decoder,\n",
    "            **model_kwargs\n",
    "        )\n",
    "        model = SymbolicTransformerRegressor(\n",
    "            path,\n",
    "            model=mw,\n",
    "            from_pretrained=trainer.params.from_pretrained,\n",
    "            max_input_points=trainer.params.max_points,\n",
    "            rescale=trainer.params.rescale,\n",
    "            params=trainer.params,\n",
    "            model_kwargs=model_kwargs,\n",
    "        )\n",
    "        torch.save(model, path)\n",
    "    finally:\n",
    "        for net, mode in zip(networks, was_training):\n",
    "            net.train(mode)\n",
    "\n",
    "    # Reload the exact file written above (not a name rebuilt from a hard-coded\n",
    "    # \"kin_ckpt_\" prefix and trainer.epoch, which breaks for any other `name`).\n",
    "    return torch.load(path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55d60d7e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14136099",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Most recent training loss; stays None while params.export_data is set (no\n",
    "# enc_dec_step runs then), so the progress bar must not format it blindly.\n",
    "loss = None\n",
    "\n",
    "for _ in range(params.max_epoch):\n",
    "\n",
    "    print(\"============ Starting epoch %i ... ============\" % trainer.epoch)\n",
    "\n",
    "    trainer.inner_epoch = 0\n",
    "\n",
    "    # The context manager closes the bar even if a training step raises.\n",
    "    with tqdm.tqdm(total=trainer.n_steps_per_epoch) as pbar:\n",
    "        for _step in range(trainer.n_steps_per_epoch):\n",
    "            # Run every task once per step, in a fresh random order.\n",
    "            for task_id in np.random.permutation(len(params.tasks)):\n",
    "                task = params.tasks[task_id]\n",
    "                if params.export_data:\n",
    "                    trainer.export_data(task)\n",
    "                else:\n",
    "                    loss = trainer.enc_dec_step(task)\n",
    "                trainer.iter()\n",
    "            description = \"Epoch {}: \".format(trainer.epoch)\n",
    "            if loss is not None:\n",
    "                description += f'Loss: {loss:.4f}'\n",
    "            pbar.set_description(description)\n",
    "            pbar.update(1)\n",
    "\n",
    "    logger.info(\"============ End of epoch %i ============\" % trainer.epoch)\n",
    "    trainer.epoch += 1\n",
    "\n",
    "    dstr = save_checkpoint(trainer, \"kin_ckpt_\" + str(trainer.epoch))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d97ad36",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86ebf725",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26145c92",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
