{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6392b2da",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import path\n",
    "\n",
    "#WIP module for deep system identification: (deepSI)\n",
    "import deepSI\n",
    "from deepSI import System_data_list, System_data\n",
    "from deepSI.utils import fit_with_early_stopping\n",
    "from deepSI.fit_systems import SS_encoder, SS_encoder_deriv_general #models"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb879485",
   "metadata": {},
   "source": [
    "## CT CED: training discrete-time (`dis`) and derivative-based (`con`) state-space encoders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89b651cb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "&&&&&&&&& ./models-encoder-CT-CED/dis-1000 &&&&&&&&&&&&&&&\n",
      "Starting training with early stopping with settings: stop_frac=66.67% step0=3,000 max_step=None\n",
      "\n",
      "######## Early stopping iteration check: 1 ########\n",
      "\t epochs done: 300.0 steps done: 4,800 last best val loss at step: 4,464\n",
      "\t Current val: 0.449092 Lowest val: 0.437440\n",
      "\t stopping condition: (stop_frac*step0+best_id)/last_id=134.667% < 66.667% = stop_frac\n",
      "\n",
      "######## Early stopping iteration check: 2 ########\n",
      "\t epochs done: 600.0 steps done: 9,600 last best val loss at step: 5,648\n",
      "\t Current val: 0.525420 Lowest val: 0.422011\n",
      "\t stopping condition: (stop_frac*step0+best_id)/last_id=79.667% < 66.667% = stop_frac\n",
      "\n",
      "######## Early stopping iteration check: 3 ########\n",
      "\t epochs done: 900.0 steps done: 14,400 last best val loss at step: 5,648\n",
      "\t Current val: 0.561240 Lowest val: 0.422011\n",
      "\t stopping condition: (stop_frac*step0+best_id)/last_id=53.111% < 66.667% = stop_frac\n",
      "\n",
      "&&&&&&&&& ./models-encoder-CT-CED/con-1000-0.00010 &&&&&&&&&&&&&&&\n",
      "Starting training with early stopping with settings: stop_frac=66.67% step0=3,000 max_step=None\n",
      "\n",
      "######## Early stopping iteration check: 1 ########\n",
      "\t epochs done: 300.0 steps done: 4,800 last best val loss at step: 4,560\n",
      "\t Current val: 0.805035 Lowest val: 0.736379\n",
      "\t stopping condition: (stop_frac*step0+best_id)/last_id=136.667% < 66.667% = stop_frac\n",
      "\n",
      "######## Early stopping iteration check: 2 ########\n",
      "\t epochs done: 600.0 steps done: 9,600 last best val loss at step: 4,560\n",
      "\t Current val: 0.961422 Lowest val: 0.736379\n",
      "\t stopping condition: (stop_frac*step0+best_id)/last_id=68.333% < 66.667% = stop_frac\n",
      "\n",
      "######## Early stopping iteration check: 3 ########\n",
      "\t epochs done: 900.0 steps done: 14,400 last best val loss at step: 4,560\n",
      "\t Current val: 0.797336 Lowest val: 0.736379\n",
      "\t stopping condition: (stop_frac*step0+best_id)/last_id=45.556% < 66.667% = stop_frac\n",
      "\n",
      "&&&&&&&&& ./models-encoder-CT-CED/con-1000-0.00018 &&&&&&&&&&&&&&&\n",
      "Starting training with early stopping with settings: stop_frac=66.67% step0=3,000 max_step=None\n",
      "\n",
      "######## Early stopping iteration check: 1 ########\n",
      "\t epochs done: 300.0 steps done: 4,800 last best val loss at step: 3,424\n",
      "\t Current val: 0.839236 Lowest val: 0.742873\n",
      "\t stopping condition: (stop_frac*step0+best_id)/last_id=113.000% < 66.667% = stop_frac\n",
      "\n",
      "######## Early stopping iteration check: 2 ########\n",
      "\t epochs done: 600.0 steps done: 9,600 last best val loss at step: 3,424\n",
      "\t Current val: 0.806553 Lowest val: 0.742873\n",
      "\t stopping condition: (stop_frac*step0+best_id)/last_id=56.500% < 66.667% = stop_frac\n",
      "\n",
      "&&&&&&&&& ./models-encoder-CT-CED/con-1000-0.00032 &&&&&&&&&&&&&&&\n",
      "Starting training with early stopping with settings: stop_frac=66.67% step0=3,000 max_step=None\n",
      "\n",
      "######## Early stopping iteration check: 1 ########\n",
      "\t epochs done: 300.0 steps done: 4,800 last best val loss at step: 2,480\n",
      "\t Current val: 0.746548 Lowest val: 0.730959\n",
      "\t stopping condition: (stop_frac*step0+best_id)/last_id=93.333% < 66.667% = stop_frac\n",
      "\n",
      "######## Early stopping iteration check: 2 ########\n",
      "\t epochs done: 600.0 steps done: 9,600 last best val loss at step: 2,480\n",
      "\t Current val: 0.765370 Lowest val: 0.730959\n",
      "\t stopping condition: (stop_frac*step0+best_id)/last_id=46.667% < 66.667% = stop_frac\n",
      "\n",
      "&&&&&&&&& ./models-encoder-CT-CED/con-1000-0.00056 &&&&&&&&&&&&&&&\n",
      "Starting training with early stopping with settings: stop_frac=66.67% step0=3,000 max_step=None\n",
      "\n",
      "Stopping early due to a KeyboardInterrupt\n",
      "######## Early stopping iteration check: 1 ########\n",
      "\t epochs done: 14.0 steps done: 224 last best val loss at step: 224\n",
      "\t Current val: 1.092329 Lowest val: 1.092329\n",
      "\t stopping condition: (stop_frac*step0+best_id)/last_id=992.857% < 66.667% = stop_frac\n",
      "\n",
      "Stopping early due to a KeyboardInterrupt\n",
      "######## Early stopping iteration check: 2 ########\n",
      "\t epochs done: 16.0 steps done: 256 last best val loss at step: 224\n",
      "\t Current val: 1.233922 Lowest val: 1.092329\n",
      "\t stopping condition: (stop_frac*step0+best_id)/last_id=868.750% < 66.667% = stop_frac\n",
      "\n",
      "Stopping early due to a KeyboardInterrupt\n",
      "######## Early stopping iteration check: 3 ########\n",
      "\t epochs done: 16.0 steps done: 256 last best val loss at step: 256\n",
      "\t Current val: 1.014077 Lowest val: 1.014077\n",
      "\t stopping condition: (stop_frac*step0+best_id)/last_id=881.250% < 66.667% = stop_frac\n",
      "\n",
      "Stopping early due to a KeyboardInterrupt\n",
      "######## Early stopping iteration check: 4 ########\n",
      "\t epochs done: 16.0 steps done: 256 last best val loss at step: 256\n",
      "\t Current val: 1.046906 Lowest val: 1.014077\n",
      "\t stopping condition: (stop_frac*step0+best_id)/last_id=881.250% < 66.667% = stop_frac\n",
      "\n",
      "Stopping early due to a KeyboardInterrupt\n",
      "######## Early stopping iteration check: 5 ########\n",
      "\t epochs done: 16.0 steps done: 256 last best val loss at step: 256\n",
      "\t Current val: 1.032059 Lowest val: 1.014077\n",
      "\t stopping condition: (stop_frac*step0+best_id)/last_id=881.250% < 66.667% = stop_frac\n",
      "\n",
      "Stopping early due to a KeyboardInterrupt\n",
      "######## Early stopping iteration check: 6 ########\n",
      "\t epochs done: 16.0 steps done: 256 last best val loss at step: 256\n",
      "\t Current val: 1.039566 Lowest val: 1.014077\n",
      "\t stopping condition: (stop_frac*step0+best_id)/last_id=881.250% < 66.667% = stop_frac\n",
      "\n",
      "Stopping early due to a KeyboardInterrupt\n",
      "######## Early stopping iteration check: 7 ########\n",
      "\t epochs done: 16.0 steps done: 256 last best val loss at step: 256\n",
      "\t Current val: 1.039566 Lowest val: 1.014077\n",
      "\t stopping condition: (stop_frac*step0+best_id)/last_id=881.250% < 66.667% = stop_frac\n",
      "\n"
     ]
    }
   ],
   "source": [
    "dataset = 'CED'  # which dataset to run: 'CED' or 'CCT'\n",
    "norm_list = np.geomspace(0.0001, 10, num=21)  # f_norm values swept for the derivative-based encoder\n",
    "\n",
    "# Per-dataset model orders, training settings and data splits\n",
    "if dataset == 'CED':\n",
    "    nx = 3\n",
    "    na = nb = 5\n",
    "    batch_size = 32\n",
    "    epochs_per_early_stopping_check = 300  # minimum epochs\n",
    "    nf = 30  # T in paper (nf = number of future steps)\n",
    "\n",
    "    data_full = deepSI.datasets.CED()\n",
    "    train = System_data_list([data_i[:300] for data_i in data_full])\n",
    "    test = System_data_list([data_i[300:] for data_i in data_full])\n",
    "    val = System_data_list([t[:100] for t in test])  # NOTE: validation overlaps the start of test\n",
    "\n",
    "elif dataset == 'CCT':\n",
    "    nx = 2\n",
    "    na = nb = 4\n",
    "    batch_size = 64  # there is a bit more data\n",
    "    epochs_per_early_stopping_check = 300  # minimum epochs\n",
    "    nf = 60  # T in paper (nf = number of future steps)\n",
    "\n",
    "    train, test = deepSI.datasets.Cascaded_Tanks()\n",
    "    val, test = test[:len(test)//2], test  # NOTE: validation is the first half of test\n",
    "\n",
    "else:\n",
    "    # Fail here instead of with a NameError on `train` further down.\n",
    "    raise ValueError(f'Unknown dataset {dataset!r}; expected one of: CED, CCT')\n",
    "\n",
    "fit_kwargs = dict(train_sys_data=train, val_sys_data=val, batch_size=batch_size,\n",
    "                  epochs=epochs_per_early_stopping_check, verbose=0, loss_kwargs=dict(nf=nf))\n",
    "\n",
    "\n",
    "def fit_and_save(name, make_model, fit_kwargs):\n",
    "    \"\"\"Train a fresh model with early stopping and save it, unless `name` already exists.\n",
    "\n",
    "    Existing results are skipped, so re-running the cell resumes an interrupted sweep.\n",
    "\n",
    "    Args:\n",
    "        name: output path; the model is saved to `name`, `name + '-best'` and `name + '-last'`.\n",
    "        make_model: zero-argument callable returning an untrained deepSI system\n",
    "            (only called when training is actually needed).\n",
    "        fit_kwargs: keyword arguments forwarded to `fit_with_early_stopping`.\n",
    "    \"\"\"\n",
    "    print(f'&&&&&&&&& {name} &&&&&&&&&&&&&&&')\n",
    "    if os.path.exists(name):\n",
    "        return\n",
    "    model = make_model()\n",
    "    fit_with_early_stopping(model, fit_kwargs=fit_kwargs)\n",
    "    # The state at the end of fit_with_early_stopping is saved as `name` and `-best`\n",
    "    # (presumably the lowest-validation-loss weights -- confirm in deepSI).\n",
    "    model.save_system(name)\n",
    "    model.save_system(name + '-best')\n",
    "    # Then reload the '_last' checkpoint and store it separately.\n",
    "    model.checkpoint_load_system('_last')\n",
    "    model.save_system(name + '-last')\n",
    "\n",
    "\n",
    "for i in range(1000, 1001):  # run index; only used in the output names\n",
    "    # discrete-time encoder\n",
    "    fit_and_save(f'./models-encoder-CT-{dataset}/dis-{i}',\n",
    "                 lambda: SS_encoder(nx=nx, na=na, nb=nb), fit_kwargs)\n",
    "\n",
    "    # derivative-based encoder, one model per f_norm (default arg binds the current f_norm)\n",
    "    for f_norm in norm_list:\n",
    "        fit_and_save(f'./models-encoder-CT-{dataset}/con-{i}-{f_norm:.5f}',\n",
    "                     lambda f_norm=f_norm: SS_encoder_deriv_general(nx=nx, na=na, nb=nb, f_norm=f_norm),\n",
    "                     fit_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03daf800",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
