{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load ckpt models and generate csv files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob, os, time, sys\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import tensorflow.keras.applications as tka\n",
    "import tensorflow_datasets as tfds\n",
    "import tensorflow_addons as tfa\n",
    "\n",
    "from models.optimizers import OptimizerManager\n",
    "from models.losses import calc_loss, LossManager, CostWeightGenerator\n",
    "from models.backbones import SmallModel\n",
    "from utils.misc import set_gpu_devices, fix_random_seed,\\\n",
    "    config_checker, concat_configs, save_config_as_yaml, load_yaml, concat_configs\n",
    "from utils.utiltensorboard import TensorboardLogger\n",
    "from utils.metrics import logits_to_confmx, confmx_to_macrec,\\\n",
    "    confmx_to_exjac\n",
    "from utils.utiloptuna import run_optuna, run_with_good_params,\\\n",
    "    suggest_parameters, load_config_try, load_config_stat \n",
    "from utils.utilckpt import checkpoint_logger\n",
    "from dataprocess.dataloader import load_mnist\n",
    "from algorithm.alg import calc_perp_para_translation\n",
    "\n",
    "# Runtime setup: pick the GPU and make float64 the default Keras float type.\n",
    "gpu = 0  # device index handed to set_gpu_devices\n",
    "set_gpu_devices(gpu)\n",
    "tf.keras.backend.set_floatx(\"float64\")\n",
    "\n",
    "# User-defined: Directories\n",
    "# NOTE(review): ROOT_DIR is a machine-specific absolute path; adjust it before running elsewhere.\n",
    "ROOT_DIR = \"/data/t-miyagawa\"\n",
    "TRUNK_DIR = \"/eom/mnist/ckptlogs\"\n",
    "CKPT_DIR = ROOT_DIR + TRUNK_DIR  # root that return_step_ckptpath globs for trial directories\n",
    "DATA_DIR = ROOT_DIR + \"/tensorflow_datasets\"  # handed to load_mnist as data_dir\n",
    "CSV_DIR = ROOT_DIR + \"/eom/mnist/csv\"  # directory of the output CSV file\n",
    "# User-defined: model size, forwarded to SmallModel(num_classes=..., units=...)\n",
    "NUM_CLASSES = 10\n",
    "UNITS = [64, 64]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def return_step_ckptpath(ind_lr, ind_wd, ind_fmg):\n",
    "    \"\"\"Find one trial directory and index its checkpoints by training step.\n",
    "\n",
    "    Args:\n",
    "        ind_lr: index into the learning-rate grid defined below.\n",
    "        ind_wd: index into the weight-decay grid defined below.\n",
    "        ind_fmg: index into the ModGrad grid defined below. 0 selects the\n",
    "            ModGradFalse trials under Run20220216_try; any other value selects\n",
    "            the ModGradTrue<coefficient> trials under Run20220221_try.\n",
    "\n",
    "    Returns:\n",
    "        (dc_ckpt_paths, learning_rate, weight_decay). dc_ckpt_paths maps\n",
    "        step -> checkpoint path with the trailing extension removed, inserted\n",
    "        in ascending step order; the other two values are np.float64.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if zero or more than one trial directory matches.\n",
    "    \"\"\"\n",
    "    # Hyperparameter grids. They stay strings because they are substituted\n",
    "    # verbatim into the directory-name pattern.\n",
    "    learning_rate = [\"0.1\", \"0.01\", \"0.001\", \"0.0001\", \"1e-05\", \"1e-06\"][ind_lr]\n",
    "    weight_decay = [\"0.01\", \"0.001\"][ind_wd]\n",
    "    flag_modgrad = [\"0\", \"0.1\", \"0.01\", \"0.001\", \"0.0001\"][ind_fmg]\n",
    "\n",
    "    # Only the run directory and the name pattern differ between the two cases.\n",
    "    if ind_fmg == 0:\n",
    "        trial_str = CKPT_DIR+\"/Run20220216_try/_ModGrad{}LR{}WD{}_*\".format(\"False\", learning_rate, weight_decay)\n",
    "    else:\n",
    "        trial_str = CKPT_DIR+\"/Run20220221_try/_ModGrad{}{}LR{}WD{}_*\".format(\"True\", flag_modgrad, learning_rate, weight_decay)\n",
    "    trial_dir = glob.glob(trial_str)\n",
    "\n",
    "    # Exactly one directory must match the pattern\n",
    "    assert trial_dir != [], \"No directory found:\" + trial_str\n",
    "    assert len(trial_dir) == 1, \"Multiple directories found: {}\".format(trial_dir)\n",
    "    trial_dir = trial_dir[0]\n",
    "\n",
    "    # Get ckpt paths and extract their #steps\n",
    "    ckpt_paths = glob.glob(trial_dir + \"/*.data*\")\n",
    "    # Take the text between the last '_step' and '.data-', then drop everything\n",
    "    # from its last '-' on; what remains is parsed as the integer step.\n",
    "    keys = [v[v.rfind(\"_step\")+5: v.rfind(\".data-\")] for v in ckpt_paths]\n",
    "    keys = [int(v[:v.rfind(\"-\")]) for v in keys]\n",
    "    # Cut each path at its last '.', i.e. remove the trailing extension\n",
    "    ckpt_paths = [v[:v.rfind(\".\")] for v in ckpt_paths]\n",
    "    print(\"Num of ckpts: \" + str(len(keys)))\n",
    "\n",
    "    # Sort keys and ckpt_paths in #steps\n",
    "    tmpind = np.argsort(keys)\n",
    "    keys = np.array(keys)[tmpind]\n",
    "    ckpt_paths = np.array(ckpt_paths)[tmpind]\n",
    "\n",
    "    # Make up a dict\n",
    "    dc_ckpt_paths = dict(zip(keys, ckpt_paths))\n",
    "    return dc_ckpt_paths, np.float64(learning_rate), np.float64(weight_decay)\n",
    "\n",
    "# def calc_angupd(weights0, weights1):\n",
    "#     cossim = [-tf.keras.losses.cosine_similarity(\n",
    "#         tf.reshape(v, [-1]), tf.reshape(w, [-1])) for v, w in zip(\n",
    "#         weights0, weights1)]\n",
    "#     return cossim\n",
    "\n",
    "# def calc_angupd_gd(learning_rate, model):\n",
    "\n",
    "# def calc_angupd_gf():\n",
    "\n",
    "def return_optimizer(lr):\n",
    "    \"\"\"Create an optimizer through OptimizerManager.\n",
    "\n",
    "    The manager is asked for the 'SGD' optimizer (momentum 0.0, Nesterov off)\n",
    "    with the 'Constant' scheduler at learning rate `lr`.\n",
    "\n",
    "    Args:\n",
    "        lr: learning rate passed to the scheduler.\n",
    "\n",
    "    Returns:\n",
    "        Whatever calling the OptimizerManager instance returns (the optimizer).\n",
    "    \"\"\"\n",
    "    manager = OptimizerManager(\n",
    "        name_scheduler=\"Constant\",\n",
    "        name_optimizer=\"SGD\",\n",
    "        kwargs_scheduler=dict(learning_rate=lr),\n",
    "        kwargs_optimizer=dict(momentum=0.0, nesterov=False),\n",
    "    )\n",
    "    return manager()\n",
    "\n",
    "def return_gradients(model, x, y, weight_decay):\n",
    "    \"\"\"Gradients of the L2-regularised cross-entropy loss.\n",
    "\n",
    "    The loss is the mean sparse softmax cross-entropy of model(x, training=False)\n",
    "    against labels y, plus weight_decay * tf.nn.l2_loss(v) for every trainable\n",
    "    variable v.\n",
    "\n",
    "    Args:\n",
    "        model: callable returning (logits, bottleneck features).\n",
    "        x: model inputs.\n",
    "        y: integer class labels.\n",
    "        weight_decay: coefficient of the L2 penalty.\n",
    "\n",
    "    Returns:\n",
    "        List of gradients, one per entry of model.trainable_variables.\n",
    "    \"\"\"\n",
    "    with tf.GradientTape() as tape:\n",
    "        # (batch, num_classes) and (batch, final_size); the features are unused.\n",
    "        logits, _ = model(x, training=False)\n",
    "        per_example_xent = tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
    "            labels=y, logits=logits)\n",
    "        total_loss = tf.reduce_mean(per_example_xent)\n",
    "        # Add the penalty one variable at a time (same summation order as before).\n",
    "        for var in model.trainable_variables:\n",
    "            total_loss = total_loss + weight_decay * tf.nn.l2_loss(var)\n",
    "\n",
    "    # [(n-dim weight shape,)] * num of trainable layers\n",
    "    return tape.gradient(total_loss, model.trainable_variables)\n",
    "\n",
    "def calc_Hg(model, x, y, wd1, tmp_weights1, eps):\n",
    "    \"\"\"Central finite-difference estimate of the Hessian-gradient product.\n",
    "\n",
    "    With theta = tmp_weights1 and g = return_gradients(model, x, y, wd1) taken\n",
    "    at theta, this returns, per trainable variable,\n",
    "        (grad(theta + eps * g) - grad(theta - eps * g)) / (2 * eps).\n",
    "    return_gradients differentiates the L2-regularised loss, so the penalty's\n",
    "    curvature is part of the result.\n",
    "\n",
    "    Args:\n",
    "        model: network whose layers[0], layers[2] and layers[4] each take a\n",
    "            one-array weight list.\n",
    "        x, y: inputs and labels forwarded to return_gradients.\n",
    "        wd1: weight-decay coefficient forwarded to return_gradients.\n",
    "        tmp_weights1: three tensors (anything with .numpy()) loaded into\n",
    "            layers 0, 2 and 4 respectively.\n",
    "        eps: finite-difference step; used here only as the divisor.\n",
    "\n",
    "    Returns:\n",
    "        List of tensors, one per entry of model.trainable_variables.\n",
    "\n",
    "    Notes:\n",
    "        The weights are shifted by the module-level `tmp_optimizer`, so its\n",
    "        learning rate has to equal `eps` (assumes the optimizer performs the\n",
    "        plain SGD update w <- w - lr * grad; TODO confirm in OptimizerManager).\n",
    "        On return the model is left at the backward-shifted weights.\n",
    "    \"\"\"\n",
    "    def _reset_weights():\n",
    "        # Put the model back at theta.\n",
    "        model.layers[0].set_weights([tmp_weights1[0].numpy()])\n",
    "        model.layers[2].set_weights([tmp_weights1[1].numpy()])\n",
    "        model.layers[4].set_weights([tmp_weights1[2].numpy()])\n",
    "\n",
    "    # g depends only on theta, so compute it once and reuse it for both shifts\n",
    "    # (this saves one full gradient evaluation).\n",
    "    _reset_weights()\n",
    "    gradients = return_gradients(model, x, y, wd1)\n",
    "\n",
    "    # Plus epsilon grad (forward): feeding -g to the optimizer moves the\n",
    "    # weights along +g.\n",
    "    tmp_optimizer.apply_gradients(\n",
    "        zip([-v for v in gradients], model.trainable_variables))\n",
    "    grad_fw = return_gradients(model, x, y, wd1)\n",
    "\n",
    "    # Minus epsilon grad (backward)\n",
    "    _reset_weights()\n",
    "    tmp_optimizer.apply_gradients(\n",
    "        zip(gradients, model.trainable_variables))\n",
    "    grad_bw = return_gradients(model, x, y, wd1)\n",
    "\n",
    "    # Calc Hg\n",
    "    Hg = [0.5 * (v - w) / eps for v, w in zip(grad_fw, grad_bw)]\n",
    "    #normHg = tf.sqrt(tf.reduce_sum([tf.norm(v)**2 for v in Hg])).numpy()\n",
    "\n",
    "    return Hg\n",
    "\n",
    "def calc_grad(model, x, y, weights, weight_decay):\n",
    "    \"\"\"Load `weights` into the model and return the loss gradients there.\n",
    "\n",
    "    Args:\n",
    "        model: network whose layers[0], layers[2] and layers[4] each take a\n",
    "            one-array weight list.\n",
    "        x, y: inputs and labels forwarded to return_gradients.\n",
    "        weights: three tensors (anything with .numpy()), one per layer above.\n",
    "        weight_decay: coefficient forwarded to return_gradients.\n",
    "\n",
    "    Returns:\n",
    "        The list produced by return_gradients at the loaded weights. The model\n",
    "        keeps those weights afterwards.\n",
    "    \"\"\"\n",
    "    # Overwrite the three weight-carrying layers with the given values\n",
    "    for slot, layer_index in enumerate((0, 2, 4)):\n",
    "        model.layers[layer_index].set_weights([weights[slot].numpy()])\n",
    "\n",
    "    return return_gradients(model, x, y, weight_decay)\n",
    "\n",
    "# def calc_cau(model, x, y, weights, weight_decay):    #######################################################\n",
    "#     dense1 = tf.identity(weights[1])\n",
    "    \n",
    "#     # Reset model\n",
    "#     model.layers[0].set_weights([weights[0].numpy()])\n",
    "#     model.layers[2].set_weights([weights[1].numpy()])\n",
    "#     model.layers[4].set_weights([weights[2].numpy()])\n",
    "    \n",
    "#     # Plus epsilon grad\n",
    "#     gradients = return_gradients(model, x, y, weight_decay)\n",
    "#     optimizer =  return_optimizer(...learning_rate)\n",
    "#     tmp_optimizer.apply_gradients(\n",
    "#         zip([-v for v in gradients], model.trainable_variables))\n",
    "#     grad_fw = return_gradients(model, x, y, wd1)\n",
    "\n",
    "#     # Reset model\n",
    "#     model.layers[0].set_weights([tmp_weights1[0].numpy()])\n",
    "#     model.layers[2].set_weights([tmp_weights1[1].numpy()])\n",
    "#     model.layers[4].set_weights([tmp_weights1[2].numpy()])\n",
    "\n",
    "#     # Calc Hg\n",
    "#     Hg = [0.5 * (v - w) / eps for v, w in zip(grad_fw, grad_bw)]\n",
    "#     #normHg = tf.sqrt(tf.reduce_sum([tf.norm(v)**2 for v in Hg])).numpy()\n",
    "\n",
    "#     return Hg\n",
    "\n",
    "def interp_Hg(var0, var1, num_add):\n",
    "    \"\"\" var0, var1: arbit tf.Variable\"\"\"\n",
    "    delta = (var1 - var0) / (num_add + 1)\n",
    "    vars_interp = [var0 + i * delta for i in range(num_add+2)]\n",
    "    vars_interp[-1] = var1 # Avoid numerical error\n",
    "    \n",
    "    return vars_interp # list. len=num_add + 2\n",
    "\n",
    "# NOTE: the Keras CosineSimilarity loss returns the *negated* cosine similarity\n",
    "# (values near -1 mean the two inputs point the same way), so flip the sign if\n",
    "# the similarity itself is wanted.\n",
    "calc_cossim = tf.keras.losses.CosineSimilarity() # For angular update"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fetch ckpt paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# User-defined: Which trial?\n",
    "ind_lr0 = 0               # [\"0.1\", \"0.01\", \"0.001\", \"0.0001\", \"1e-05\", \"1e-06\"] 0, 1, 2, 3\n",
    "ind_wd0 = 0               # [\"0.01\", \"0.001\"] 0, 1\n",
    "ind_fmg0 = 0              # [\"0\", \"0.1\", \"0.01\", \"0.001\", \"0.0001\"] 0 (fix)\n",
    "\n",
    "ind_lr1 = 4               # [\"0.1\", \"0.01\", \"0.001\", \"0.0001\", \"1e-05\", \"1e-06\"] 4 (fix)\n",
    "ind_wd1 = ind_wd0         # [\"0.01\", \"0.001\"] ind_wd0 (fix)\n",
    "ind_fmg1 = 0              # [\"0\", \"0.1\", \"0.01\", \"0.001\", \"0.0001\"] GF\n",
    "#ind_fmg1 = ind_lr0+1      # [\"0\", \"0.1\", \"0.01\", \"0.001\", \"0.0001\"] EoM\n",
    "\n",
    "assert ind_lr0 <= ind_lr1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
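    "To compare the two trials at the same time $t$, match their step counts: $t = k \\eta = k \\frac{\\eta}{\\eta^\\prime}\\eta^\\prime (= k^\\prime \\eta^\\prime)$. $\\therefore k^\\prime = k \\frac{\\eta}{\\eta^\\prime}$, where $k$ and $\\eta =$ `lr0` are the step count and learning rate of trial 0, and $k^\\prime$ and $\\eta^\\prime =$ `lr1` are those of trial 1."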
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num of ckpts: 100000\n",
      "Num of ckpts: 51993\n",
      "Num plot points: 52\n",
      "LR0: 0.1 LR1: 1e-05 ind_fmg1: 0\n"
     ]
    }
   ],
   "source": [
    "# Fetch steps and ckpt paths\n",
    "dc_ckpt_paths0, lr0, wd0 = return_step_ckptpath(ind_lr0, ind_wd0, ind_fmg0)\n",
    "dc_ckpt_paths1, lr1, wd1 = return_step_ckptpath(ind_lr1, ind_wd1, ind_fmg1)\n",
    "\n",
    "# Calc correspondence between trial 0 and trial 1\n",
    "# See above equations. We need k's and k^prime's.\n",
    "frac = lr0/lr1\n",
    "k = list(dc_ckpt_paths0.keys())\n",
    "\n",
    "# k' = k * lr0/lr1 has to be a whole step of trial 1. Test that BEFORE rounding\n",
    "# (a test after rounding can never fail). The absolute tolerance only absorbs\n",
    "# floating-point error in the product.\n",
    "INT_STEP_ATOL = 1e-3\n",
    "kp = []\n",
    "for v in k:\n",
    "    kp_exact = v * frac\n",
    "    kp_rounded = np.round(kp_exact)\n",
    "    assert abs(kp_exact - kp_rounded) < INT_STEP_ATOL, kp_exact\n",
    "    kp.append(int(kp_rounded))\n",
    "\n",
    "steps_trial1 = set(dc_ckpt_paths1.keys())\n",
    "kp_intsec = sorted(list(set(kp).intersection(steps_trial1)))\n",
    "\n",
    "assert kp_intsec != [], \"More iterations needed for 0 or 1:\\nk={},\\nkp={}\".format(k, kp)\n",
    "\n",
    "# Keep every (k, k') pair whose k' was actually checkpointed in trial 1. Pairing\n",
    "# them explicitly keeps the two step lists aligned even when the matches are not\n",
    "# simply the first len(kp_intsec) entries of k.\n",
    "matched_steps = [(s0, s1) for s0, s1 in zip(k, kp) if s1 in steps_trial1]\n",
    "steps_intsec0 = [s0 for s0, s1 in matched_steps] # k's\n",
    "steps_intsec1 = [s1 for s0, s1 in matched_steps] # k^prime's\n",
    "print(\"Num plot points:\", len(steps_intsec0))\n",
    "print(\"LR0:\", lr0, \"LR1:\", lr1, \"ind_fmg1:\", ind_fmg1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CSV file name: /data/t-miyagawa/eom/mnist/csv/GDvsGF__LR0WD0.csv\n"
     ]
    }
   ],
   "source": [
    "# Define CSV file name\n",
    "# The tag is \"GF_\" when trial 1 has ind_fmg1 == 0 and \"EoM\" otherwise; with\n",
    "# \"GF_\" the resulting name contains a double underscore (GDvsGF__LR...).\n",
    "if ind_fmg1 == 0:\n",
    "    GFvsEoM = \"GF_\"\n",
    "else:\n",
    "    GFvsEoM = \"EoM\"\n",
    "csv_filename = \"GDvs{}_LR{}WD{}.csv\".format(GFvsEoM, ind_lr0, ind_wd0)\n",
    "csvpath = CSV_DIR + \"/\" + csv_filename\n",
    "print(\"CSV file name:\", csvpath)\n",
    "#os.remove(csvpath); sys.exit(0)\n",
    "# Refuse to touch an existing file: the writes below are in append mode.\n",
    "assert not os.path.exists(csvpath)\n",
    "\n",
    "# Header row: for each trial, the part of its step-0 checkpoint path between\n",
    "# \"ckptlogs/\" and the last \"/\".\n",
    "p0, p1 = (\n",
    "    path[path.find(\"ckptlogs/\")+9: path.rfind(\"/\")]\n",
    "    for path in (dc_ckpt_paths0[0], dc_ckpt_paths1[0])\n",
    ")\n",
    "with open(csvpath, \"a\") as f:\n",
    "    f.write(\"{},{}\\n\".format(p0, p1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Build model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Objects tracked by the checkpoint: step counter, model and optimizer.\n",
    "global_step = tf.Variable(0, name=\"global_step\", dtype=tf.int64)\n",
    "model = SmallModel(num_classes=NUM_CLASSES, units=UNITS)\n",
    "\n",
    "# Same construction as return_optimizer(lr0): SGD, constant learning rate lr0.\n",
    "kwargs_scheduler = {\"learning_rate\": lr0}\n",
    "kwargs_optimizer = {\"momentum\": 0.0, \"nesterov\": False,}\n",
    "opt_manager = OptimizerManager(name_scheduler=\"Constant\",\n",
    "    name_optimizer=\"SGD\", kwargs_scheduler=kwargs_scheduler,\n",
    "    kwargs_optimizer=kwargs_optimizer)\n",
    "optimizer = opt_manager()\n",
    "\n",
    "ckpt = tf.train.Checkpoint(step=global_step, optimizer=optimizer, net=model)\n",
    "# Build with a (1, 784) input shape -- presumably so the weights exist before\n",
    "# any checkpoint is restored or set_weights is called (TODO confirm).\n",
    "model.build(input_shape=tf.constant(np.ones([1,784])).shape)\n",
    "\n",
    "# For Hessian calc\n",
    "# Rescale inputs as x/127.5 - 1 (maps 0 -> -1 and 255 -> 1).\n",
    "PREPROC = lambda x: x/127.5 - 1\n",
    "dstr = load_mnist(preproc=PREPROC, batch_size=60000, data_dir=DATA_DIR)\n",
    "# Keep only the first batch, flattened to 784 features per example; x and y stay\n",
    "# in scope for the gradient helpers. With batch_size=60000 this is presumably\n",
    "# the whole training split (TODO confirm against load_mnist).\n",
    "for cnt, (x, y) in enumerate(dstr):\n",
    "    x = tf.reshape(x, (-1,784))\n",
    "    break\n",
    "# Finite-difference step for calc_Hg. tmp_optimizer is built with learning rate\n",
    "# eps because calc_Hg uses it to shift the weights before re-evaluating gradients.\n",
    "eps = np.float64(1e-7) # default: 1e-7\n",
    "tmp_optimizer =  return_optimizer(eps)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Calc Empirical and Theoretical Discretization Error"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "${\\bf e}_k = {\\bf e}_{100} + \\frac{\\eta^2}{2} \\sum_{s=100}^{k-1}  ( H (\\mathbf{\\theta}(s\\eta)) + \\lambda I) \\mathbf{g} (\\mathbf{\\theta}(s\\eta)) + O(\\eta^3)$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "1/52: GD step=0  \n",
      "norm_thperp0        0.018411463161117456\n",
      "norm_thperp1        0.018411463161117456\n",
      "norm_thperp1-0      0.0\n",
      "norm_thperp1-0/0    0.0\n",
      "normsq_dense10      64.23896020943124\n",
      "normsq_dense11      64.23896020943124\n",
      "normsq_dense1_1-0   0.0\n",
      "normsq_dense1_1-0/0 0.0\n",
      "sum_dense20         -0.4657772691633193\n",
      "sum_dense21         -0.4657772691633193\n",
      "sum_dense2_1-0      0.0\n",
      "sum_dense2_1-0/0    -0.0\n",
      "errex_by_norm_all   0.0\n",
      "errex               0.0\n",
      "errth             0.0\n",
      "\n",
      "2/52: GD step=10  \n",
      "norm_thperp0        0.018228174837115828\n",
      "norm_thperp1        0.018228266037235225\n",
      "norm_thperp1-0      9.120011939742212e-08\n",
      "norm_thperp1-0/0    5.003250199889587e-06\n",
      "normsq_dense10      62.539219303616676\n",
      "normsq_dense11      63.81136217630256\n",
      "normsq_dense1_1-0   1.2721428726858832\n",
      "normsq_dense1_1-0/0 0.02034152147806416\n",
      "sum_dense20         -0.46114040058443884\n",
      "sum_dense21         -0.4611427077852399\n",
      "sum_dense2_1-0      -2.3072008010771583e-06\n",
      "sum_dense2_1-0/0    5.003250199186765e-06\n",
      "errex_by_norm_all   0.09841260909251502\n",
      "errex               1.3756924899463217\n",
      "errth             15.354398847275563\n",
      "\n",
      "3/52: GD step=20  \n",
      "norm_thperp0        0.018046711170362908\n",
      "norm_thperp1        0.01804689175524004\n",
      "norm_thperp1-0      1.8058487713176752e-07\n",
      "norm_thperp1-0/0    1.0006525589456535e-05\n",
      "normsq_dense10      61.75938016966301\n",
      "normsq_dense11      63.30196935800918\n",
      "normsq_dense1_1-0   1.5425891883461702\n",
      "normsq_dense1_1-0/0 0.02497740722313644\n",
      "sum_dense20         -0.4565496925883976\n",
      "sum_dense21         -0.45655426106457764\n",
      "sum_dense2_1-0      -4.568476180022429e-06\n",
      "sum_dense2_1-0/0    1.0006525585685016e-05\n",
      "errex_by_norm_all   0.13480986892207453\n",
      "errex               1.8766481196188805\n",
      "errth             15.352462232471193\n",
      "\n",
      "4/52: GD step=30  \n",
      "norm_thperp0        0.017867053996176763\n",
      "norm_thperp1        0.01786732217754565\n",
      "norm_thperp1-0      2.681813688863721e-07\n",
      "norm_thperp1-0/0    1.5009825847269407e-05\n",
      "normsq_dense10      60.75709531469793\n",
      "normsq_dense11      62.57511897873177\n",
      "normsq_dense1_1-0   1.818023664033845\n",
      "normsq_dense1_1-0/0 0.029922820612427164\n",
      "sum_dense20         -0.452004685641056\n",
      "sum_dense21         -0.4520114701526685\n",
      "sum_dense2_1-0      -6.784511612512034e-06\n",
      "sum_dense2_1-0/0    1.5009825844813745e-05\n",
      "errex_by_norm_all   0.22553424096169877\n",
      "errex               3.135780061068378\n",
      "errth             15.351879415278963\n",
      "\n",
      "5/52: GD step=40  \n",
      "norm_thperp0        0.01768918533070741\n",
      "norm_thperp1        0.017689539347048824\n",
      "norm_thperp1-0      3.54016341415353e-07\n",
      "norm_thperp1-0/0    2.0013151244495185e-05\n",
      "normsq_dense10      59.88009024626988\n",
      "normsq_dense11      61.749328751879155\n",
      "normsq_dense1_1-0   1.8692385056092746\n",
      "normsq_dense1_1-0/0 0.031216360862544215\n",
      "sum_dense20         -0.44750492478299453\n",
      "sum_dense21         -0.4475138807667358\n",
      "sum_dense2_1-0      -8.95598374128781e-06\n",
      "sum_dense2_1-0/0    2.0013151242147275e-05\n",
      "errex_by_norm_all   0.23644585189321782\n",
      "errex               3.2663786124870144\n",
      "errth             15.351613747870452\n",
      "\n",
      "6/52: GD step=50  \n",
      "norm_thperp0        0.017513087369136007\n",
      "norm_thperp1        0.017513525485314305\n",
      "norm_thperp1-0      4.3811617829805916e-07\n",
      "norm_thperp1-0/0    2.501650160611705e-05\n",
      "normsq_dense10      58.91084470566244\n",
      "normsq_dense11      60.87479006920836\n",
      "normsq_dense1_1-0   1.9639453635459176\n",
      "normsq_dense1_1-0/0 0.03333758620095199\n",
      "sum_dense20         -0.4430499595839663\n",
      "sum_dense21         -0.4430610431439903\n",
      "sum_dense2_1-0      -1.1083560023994465e-05\n",
      "sum_dense2_1-0/0    2.5016501602668404e-05\n",
      "errex_by_norm_all   0.2526601031819037\n",
      "errex               3.470535792105991\n",
      "errth             15.351464689956673\n",
      "\n",
      "7/52: GD step=60  \n",
      "norm_thperp0        0.017338742483892857\n",
      "norm_thperp1        0.01733926299080807\n",
      "norm_thperp1-0      5.205069152115771e-07\n",
      "norm_thperp1-0/0    3.0019876914090602e-05\n",
      "normsq_dense10      58.17138872372992\n",
      "normsq_dense11      59.976795557751736\n",
      "normsq_dense1_1-0   1.8054068340218166\n",
      "normsq_dense1_1-0/0 0.03103599335742409\n",
      "sum_dense20         -0.438639344097814\n",
      "sum_dense21         -0.4386525119969331\n",
      "sum_dense2_1-0      -1.3167899119093107e-05\n",
      "sum_dense2_1-0/0    3.0019876913177087e-05\n",
      "errex_by_norm_all   0.26438813426550545\n",
      "errex               3.6156986740639905\n",
      "errth             15.351372223210875\n",
      "\n",
      "8/52: GD step=70  \n",
      "norm_thperp0        0.01716613322289291\n",
      "norm_thperp1        0.01716673443713332\n",
      "norm_thperp1-0      6.012142404095144e-07\n",
      "norm_thperp1-0/0    3.502327708885131e-05\n",
      "normsq_dense10      57.369546754523746\n",
      "normsq_dense11      59.06961098723847\n",
      "normsq_dense1_1-0   1.7000642327147233\n",
      "normsq_dense1_1-0/0 0.029633565696257982\n",
      "sum_dense20         -0.4342726368178278\n",
      "sum_dense21         -0.434287846468719\n",
      "sum_dense2_1-0      -1.5209650891190307e-05\n",
      "sum_dense2_1-0/0    3.50232770884217e-05\n",
      "errex_by_norm_all   0.27443614914615827\n",
      "errex               3.730436938608917\n",
      "errth             15.351311235332428\n",
      "\n",
      "9/52: GD step=80  \n",
      "norm_thperp0        0.01699524230778853\n",
      "norm_thperp1        0.01699592257129044\n",
      "norm_thperp1-0      6.802635019072245e-07\n",
      "norm_thperp1-0/0    4.0026702155077556e-05\n",
      "normsq_dense10      56.535644095753305\n",
      "normsq_dense11      58.16187020780144\n",
      "normsq_dense1_1-0   1.626226112048137\n",
      "normsq_dense1_1-0/0 0.02876461634175123\n",
      "sum_dense20         -0.42994940063254594\n",
      "sum_dense21         -0.4299666100891475\n",
      "sum_dense2_1-0      -1.7209456601552375e-05\n",
      "sum_dense2_1-0/0    4.002670215665761e-05\n",
      "errex_by_norm_all   0.2839668406190708\n",
      "errex               3.835746152796792\n",
      "errth             15.351269308497027\n",
      "\n",
      "10/52: GD step=90  \n",
      "norm_thperp0        0.01682605263224029\n",
      "norm_thperp1        0.016826810311961652\n",
      "norm_thperp1-0      7.576797213630637e-07\n",
      "norm_thperp1-0/0    4.50301528185689e-05\n",
      "normsq_dense10      55.69763524139821\n",
      "normsq_dense11      57.259047797970716\n",
      "normsq_dense1_1-0   1.561412556572506\n",
      "normsq_dense1_1-0/0 0.02803373159031283\n",
      "sum_dense20         -0.42566920278200726\n",
      "sum_dense21         -0.42568837073125776\n",
      "sum_dense2_1-0      -1.916794925049814e-05\n",
      "sum_dense2_1-0/0    4.5030152816374613e-05\n",
      "errex_by_norm_all   0.2921441596747894\n",
      "errex               3.919923809899848\n",
      "errth             15.35123967303559\n",
      "\n",
      "11/52: GD step=100  \n",
      "norm_thperp0        0.01665854726020455\n",
      "norm_thperp1        0.016659380747763897\n",
      "norm_thperp1-0      8.33487559345758e-07\n",
      "norm_thperp1-0/0    5.003362816257506e-05\n",
      "normsq_dense10      54.89692533436661\n",
      "normsq_dense11      56.36472569442425\n",
      "normsq_dense1_1-0   1.4678003600576446\n",
      "normsq_dense1_1-0/0 0.02673738740589851\n",
      "sum_dense20         -0.42143161481443236\n",
      "sum_dense21         -0.42145270056714335\n",
      "sum_dense2_1-0      -2.10857527109809e-05\n",
      "sum_dense2_1-0/0    5.003362816115616e-05\n",
      "errex_by_norm_all   0.29885838445472956\n",
      "errex               3.982083494174276\n",
      "errth             15.351218359657345\n",
      "errth_100stp,     3.982083494174276\n",
      "\n",
      "12/52: GD step=110  \n",
      "norm_thperp0        0.01649270942423753\n",
      "norm_thperp1        0.016493617135603454\n",
      "norm_thperp1-0      9.077113659258784e-07\n",
      "norm_thperp1-0/0    5.503712838061128e-05\n",
      "normsq_dense10      54.09767506090835\n",
      "normsq_dense11      55.48129503867315\n",
      "normsq_dense1_1-0   1.383619977764802\n",
      "normsq_dense1_1-0/0 0.025576329781399846\n",
      "sum_dense20         -0.41723621254331844\n",
      "sum_dense21         -0.4172591760263127\n",
      "sum_dense2_1-0      -2.2963482994242668e-05\n",
      "sum_dense2_1-0/0    5.5037128379307546e-05\n",
      "errex_by_norm_all   0.305074156212617\n",
      "errex               4.035736908436108\n",
      "errth             15.35120289421342\n",
      "errth_100stp,     3.982184188388566\n",
      "\n",
      "13/52: GD step=120  \n",
      "norm_thperp0        0.01632852252381785\n",
      "norm_thperp1        0.016329502898987533\n",
      "norm_thperp1-0      9.80375169684633e-07\n",
      "norm_thperp1-0/0    6.004065390819003e-05\n",
      "normsq_dense10      53.29630390504623\n",
      "normsq_dense11      54.61036850260115\n",
      "normsq_dense1_1-0   1.3140645975549177\n",
      "normsq_dense1_1-0/0 0.024655829790675196\n",
      "sum_dense20         -0.4130825760050021\n",
      "sum_dense21         -0.41310737775298545\n",
      "sum_dense2_1-0      -2.4801747983360656e-05\n",
      "sum_dense2_1-0/0    6.0040653912888174e-05\n",
      "errex_by_norm_all   0.31112658959805645\n",
      "errex               4.085987062580046\n",
      "errth             15.351191652084454\n",
      "errth_100stp,     3.982265218911923\n",
      "\n",
      "14/52: GD step=130  \n",
      "norm_thperp0        0.016165970123684065\n",
      "norm_thperp1        0.01616702162635285\n",
      "norm_thperp1-0      1.05150266878698e-06\n",
      "norm_thperp1-0/0    6.504420463121288e-05\n",
      "normsq_dense10      52.4642381631629\n",
      "normsq_dense11      53.75303389648066\n",
      "normsq_dense1_1-0   1.2887957333177624\n",
      "normsq_dense1_1-0/0 0.02456522344438948\n",
      "sum_dense20         -0.4089702894166045\n",
      "sum_dense21         -0.40899689056379707\n",
      "sum_dense2_1-0      -2.660114719255091e-05\n",
      "sum_dense2_1-0/0    6.504420463035935e-05\n",
      "errex_by_norm_all   0.3166701355993751\n",
      "errex               4.127898273017192\n",
      "errth             15.351183517001987\n",
      "errth_100stp,     3.982331403991444\n",
      "\n",
      "15/52: GD step=140  \n",
      "norm_thperp0        0.016005035952189828\n",
      "norm_thperp1        0.016006157069433997\n",
      "norm_thperp1-0      1.1211172441689954e-06\n",
      "norm_thperp1-0/0    7.004778043098384e-05\n",
      "normsq_dense10      51.666773953859675\n",
      "normsq_dense11      52.91001612932138\n",
      "normsq_dense1_1-0   1.2432421754617025\n",
      "normsq_dense1_1-0/0 0.02406270181629616\n",
      "sum_dense20         -0.40489894113441327\n",
      "sum_dense21         -0.4049273034065384\n",
      "sum_dense2_1-0      -2.8362272125104226e-05\n",
      "sum_dense2_1-0/0    7.004778043044789e-05\n",
      "errex_by_norm_all   0.32274641864450265\n",
      "errex               4.176384588094043\n",
      "errth             15.351177695580027\n",
      "errth_100stp,     3.982386177550067\n",
      "\n",
      "16/52: GD step=150  \n",
      "norm_thperp0        0.015845703899674945\n",
      "norm_thperp1        0.015846893141632113\n",
      "norm_thperp1-0      1.189241957168241e-06\n",
      "norm_thperp1-0/0    7.50513807842034e-05\n",
      "normsq_dense10      50.90863502348292\n",
      "normsq_dense11      52.08178396368467\n",
      "normsq_dense1_1-0   1.1731489402017488\n",
      "normsq_dense1_1-0/0 0.02304420339811907\n",
      "sum_dense20         -0.40086812361267965\n",
      "sum_dense21         -0.40089820931887044\n",
      "sum_dense2_1-0      -3.0085706190785544e-05\n",
      "sum_dense2_1-0/0    7.505138078739948e-05\n",
      "errex_by_norm_all   0.32830073562561635\n",
      "errex               4.2175041490762775\n",
      "errth             15.35117361174957\n",
      "errth_100stp,     3.9824320470950316\n",
      "\n",
      "17/52: GD step=160  \n",
      "norm_thperp0        0.015687958016852813\n",
      "norm_thperp1        0.01568921391643684\n",
      "norm_thperp1-0      1.2558995840274145e-06\n",
      "norm_thperp1-0/0    8.00550066922835e-05\n",
      "normsq_dense10      50.139494558761726\n",
      "normsq_dense11      51.26862229250266\n",
      "normsq_dense1_1-0   1.129127733740937\n",
      "normsq_dense1_1-0/0 0.022519727086950166\n",
      "sum_dense20         -0.3968774333628242\n",
      "sum_dense21         -0.396909205388408\n",
      "sum_dense2_1-0      -3.1772025583798325e-05\n",
      "sum_dense2_1-0/0    8.005500669208478e-05\n",
      "errex_by_norm_all   0.3336083838797173\n",
      "errex               4.253817120672236\n",
      "errth             15.35117083723346\n",
      "errth_100stp,     3.982470879219359\n",
      "\n",
      "18/52: GD step=170  \n",
      "norm_thperp0        0.015531782513213879\n",
      "norm_thperp1        0.015533103625778295\n",
      "norm_thperp1-0      1.3211125644164295e-06\n",
      "norm_thperp1-0/0    8.505865719485028e-05\n",
      "normsq_dense10      49.377707804746926\n",
      "normsq_dense11      50.470682190050056\n",
      "normsq_dense1_1-0   1.0929743853031297\n",
      "normsq_dense1_1-0/0 0.022134976164245045\n",
      "sum_dense20         -0.3929264709130398\n",
      "sum_dense21         -0.3929598927110338\n",
      "sum_dense2_1-0      -3.342179799403766e-05\n",
      "sum_dense2_1-0/0    8.505865719959188e-05\n",
      "errex_by_norm_all   0.3389800071783691\n",
      "errex               4.290491793298792\n",
      "errth             15.351169050538372\n",
      "errth_100stp,     3.9825040853449645\n",
      "\n",
      "19/52: GD step=180  \n",
      "norm_thperp0        0.015377161755445005\n",
      "norm_thperp1        0.015378546658509703\n",
      "norm_thperp1-0      1.3849030646977817e-06\n",
      "norm_thperp1-0/0    9.006233313552755e-05\n",
      "normsq_dense10      48.6149119080321\n",
      "normsq_dense11      49.688016233758276\n",
      "normsq_dense1_1-0   1.0731043257261774\n",
      "normsq_dense1_1-0/0 0.022073563102536043\n",
      "sum_dense20         -0.38901484076831405\n",
      "sum_dense21         -0.389049876352499\n",
      "sum_dense2_1-0      -3.503558418493924e-05\n",
      "sum_dense2_1-0/0    9.006233313809592e-05\n",
      "errex_by_norm_all   0.34421158907770816\n",
      "errex               4.324325534499617\n",
      "errth             15.351168009165125\n",
      "errth_100stp,     3.982532746630261\n",
      "\n",
      "20/52: GD step=190  \n",
      "norm_thperp0        0.01522408026586465\n",
      "norm_thperp1        0.015225527558793299\n",
      "norm_thperp1-0      1.4472929286482583e-06\n",
      "norm_thperp1-0/0    9.506603376844843e-05\n",
      "normsq_dense10      47.88497845957744\n",
      "normsq_dense11      48.9206038311218\n",
      "normsq_dense1_1-0   1.0356253715443628\n",
      "normsq_dense1_1-0/0 0.02162735381448685\n",
      "sum_dense20         -0.3851421513708324\n",
      "sum_dense21         -0.3851787653076011\n",
      "sum_dense2_1-0      -3.66139367686813e-05\n",
      "sum_dense2_1-0/0    9.506603377054861e-05\n",
      "errex_by_norm_all   0.3491562496135361\n",
      "errex               4.354023705613849\n",
      "errth             15.351167529111551\n",
      "errth_100stp,     3.9825577003721535\n",
      "\n",
      "21/52: GD step=200  \n",
      "norm_thperp0        0.01507252272087335\n",
      "norm_thperp1        0.015074031024595537\n",
      "norm_thperp1-0      1.5083037221870488e-06\n",
      "norm_thperp1-0/0    0.00010006975939722802\n",
      "normsq_dense10      47.171182250616724\n",
      "normsq_dense11      48.16836962705655\n",
      "normsq_dense1_1-0   0.9971873764398254\n",
      "normsq_dense1_1-0/0 0.021139757980663882\n",
      "sum_dense20         -0.38130801506078393\n",
      "sum_dense21         -0.38134617246210745\n",
      "sum_dense2_1-0      -3.8157401323513085e-05\n",
      "sum_dense2_1-0/0    0.00010006975939761049\n",
      "errex_by_norm_all   0.35398886750196634\n",
      "errex               4.3813323233746635\n",
      "errth             15.35116747046348\n",
      "errth_100stp,     3.9825796013952064\n",
      "\n",
      "22/52: GD step=210  \n",
      "norm_thperp0        0.014922473949419906\n",
      "norm_thperp1        0.01492404190614057\n",
      "norm_thperp1-0      1.5679567206636907e-06\n",
      "norm_thperp1-0/0    0.00010507351032934075\n",
      "normsq_dense10      46.45567800352728\n",
      "normsq_dense11      47.4311970387011\n",
      "normsq_dense1_1-0   0.9755190351738179\n",
      "normsq_dense1_1-0/0 0.020998919337691054\n",
      "sum_dense20         -0.37751204803756\n",
      "sum_dense21         -0.3775517145536411\n",
      "sum_dense2_1-0      -3.966651608111604e-05\n",
      "sum_dense2_1-0/0    0.00010507351033514428\n",
      "errex_by_norm_all   0.35890364122230917\n",
      "errex               4.408866226993572\n",
      "errth             15.351167725810209\n",
      "errth_100stp,     3.982598966477955\n",
      "\n",
      "23/52: GD step=220  \n",
      "norm_thperp0        0.014773918931482822\n",
      "norm_thperp1        0.014775545204385038\n",
      "norm_thperp1-0      1.626272902215814e-06\n",
      "norm_thperp1-0/0    0.00011007728617965206\n",
      "normsq_dense10      45.772143164151636\n",
      "normsq_dense11      46.70893830870442\n",
      "normsq_dense1_1-0   0.9367951445527822\n",
      "normsq_dense1_1-0/0 0.020466490747290865\n",
      "sum_dense20         -0.3737538703213348\n",
      "sum_dense21         -0.37379501213307753\n",
      "sum_dense2_1-0      -4.114181174275089e-05\n",
      "sum_dense2_1-0/0    0.00011007728617600462\n",
      "errex_by_norm_all   0.36341304504835464\n",
      "errex               4.431093474723285\n",
      "errth             15.351168214009295\n",
      "errth_100stp,     3.982616207035341\n",
      "\n",
      "24/52: GD step=230  \n",
      "norm_thperp0        0.014626842796566744\n",
      "norm_thperp1        0.014628526069530802\n",
      "norm_thperp1-0      1.683272964058194e-06\n",
      "norm_thperp1-0/0    0.00011508108670268178\n",
      "normsq_dense10      45.09842812574574\n",
      "normsq_dense11      46.00142203946329\n",
      "normsq_dense1_1-0   0.9029939137175518\n",
      "normsq_dense1_1-0/0 0.020022735852340087\n",
      "sum_dense20         -0.37003310571502146\n",
      "sum_dense21         -0.3700756895269435\n",
      "sum_dense2_1-0      -4.258381192201455e-05\n",
      "sum_dense2_1-0/0    0.00011508108670365885\n",
      "errex_by_norm_all   0.36783717512117353\n",
      "errex               4.451492503661711\n",
      "errth             15.35116887390933\n",
      "errth_100stp,     3.982631653164528\n",
      "\n",
      "25/52: GD step=240  \n",
      "norm_thperp0        0.014481230822213647\n",
      "norm_thperp1        0.01448296979955389\n",
      "norm_thperp1-0      1.7389773402428504e-06\n",
      "norm_thperp1-0/0    0.00012008491278070967\n",
      "normsq_dense10      44.40721750690567\n",
      "normsq_dense11      45.308458885069555\n",
      "normsq_dense1_1-0   0.9012413781638884\n",
      "normsq_dense1_1-0/0 0.02029493016588437\n",
      "sum_dense20         -0.3663493817666237\n",
      "sum_dense21         -0.3663933748001811\n",
      "sum_dense2_1-0      -4.399303355739903e-05\n",
      "sum_dense2_1-0/0    0.00012008491278258525\n",
      "errex_by_norm_all   0.37279999359323474\n",
      "errex               4.47819853283379\n",
      "errth             15.35116965987523\n",
      "errth_100stp,     3.982645571796242\n",
      "\n",
      "26/52: GD step=250  \n",
      "norm_thperp0        0.014337068432529694\n",
      "norm_thperp1        0.014338861838696657\n",
      "norm_thperp1-0      1.7934061669628437e-06\n",
      "norm_thperp1-0/0    0.00012508876381546347\n",
      "normsq_dense10      43.753199728958855\n",
      "normsq_dense11      44.62984588386652\n",
      "normsq_dense1_1-0   0.8766461549076681\n",
      "normsq_dense1_1-0/0 0.02003616101995493\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sum_dense20         -0.36270232973195116\n",
      "sum_dense21         -0.3627476997180099\n",
      "sum_dense2_1-0      -4.5369986058751977e-05\n",
      "sum_dense2_1-0/0    0.00012508876381434294\n",
      "errex_by_norm_all   0.37777984790403196\n",
      "errex               4.505252952577606\n",
      "errth             15.351170536617577\n",
      "errth_100stp,     3.9826581804648282\n",
      "\n",
      "27/52: GD step=260  \n",
      "norm_thperp0        0.014194341196725612\n",
      "norm_thperp1        0.01419618777604386\n",
      "norm_thperp1-0      1.8465793182476792e-06\n",
      "norm_thperp1-0/0    0.00013009263992284846\n",
      "normsq_dense10      43.11110812259733\n",
      "normsq_dense11      43.96536978074112\n",
      "normsq_dense1_1-0   0.8542616581437912\n",
      "normsq_dense1_1-0/0 0.019815349114072463\n",
      "sum_dense20         -0.3590915845377034\n",
      "sum_dense21         -0.3591382997099113\n",
      "sum_dense2_1-0      -4.6715172207889566e-05\n",
      "sum_dense2_1-0/0    0.00013009263992647154\n",
      "errex_by_norm_all   0.3825165725341224\n",
      "errex               4.528457374496313\n",
      "errth             15.35117147797909\n",
      "errth_100stp,     3.9826696579177865\n",
      "\n",
      "28/52: GD step=270  \n",
      "norm_thperp0        0.014053034827672355\n",
      "norm_thperp1        0.014054933344066024\n",
      "norm_thperp1-0      1.8985163936693222e-06\n",
      "norm_thperp1-0/0    0.00013509654085044198\n",
      "normsq_dense10      42.4877839542974\n",
      "normsq_dense11      43.31480959442136\n",
      "normsq_dense1_1-0   0.8270256401239635\n",
      "normsq_dense1_1-0/0 0.019465021781638826\n",
      "sum_dense20         -0.3555167847449301\n",
      "sum_dense21         -0.3555648138327663\n",
      "sum_dense2_1-0      -4.8029087836187756e-05\n",
      "sum_dense2_1-0/0    0.00013509654085853309\n",
      "errex_by_norm_all   0.3870211089092672\n",
      "errex               4.548461022279574\n",
      "errth             15.351172465515813\n",
      "errth_100stp,     3.9826801521852744\n",
      "\n",
      "29/52: GD step=280  \n",
      "norm_thperp0        0.013913135180471005\n",
      "norm_thperp1        0.013915084417204964\n",
      "norm_thperp1-0      1.9492367339599137e-06\n",
      "norm_thperp1-0/0    0.0001401004668376927\n",
      "normsq_dense10      41.878917734037515\n",
      "normsq_dense11      42.677938618120216\n",
      "normsq_dense1_1-0   0.7990208840827009\n",
      "normsq_dense1_1-0/0 0.01907931072042219\n",
      "sum_dense20         -0.35197757251284845\n",
      "sum_dense21         -0.35202688473507493\n",
      "sum_dense2_1-0      -4.9312222226483016e-05\n",
      "sum_dense2_1-0/0    0.00014010046684063355\n",
      "errex_by_norm_all   0.3913163799919774\n",
      "errex               4.565228755124439\n",
      "errth             15.351173484444633\n",
      "errth_100stp,     3.9826897867000035\n",
      "\n",
      "30/52: GD step=290  \n",
      "norm_thperp0        0.013774628251036852\n",
      "norm_thperp1        0.013776627010449205\n",
      "norm_thperp1-0      1.998759412353765e-06\n",
      "norm_thperp1-0/0    0.00014510441776919195\n",
      "normsq_dense10      41.27858148540605\n",
      "normsq_dense11      42.05452599339175\n",
      "normsq_dense1_1-0   0.7759445079856988\n",
      "normsq_dense1_1-0/0 0.018797751280770932\n",
      "sum_dense20         -0.3484735935630243\n",
      "sum_dense21         -0.3485241586209251\n",
      "sum_dense2_1-0      -5.056505790079768e-05\n",
      "sum_dense2_1-0/0    0.00014510441776602673\n",
      "errex_by_norm_all   0.3956150780181493\n",
      "errex               4.581509179579611\n",
      "errth             15.351174524586874\n",
      "errth_100stp,     3.9826986653583143\n",
      "\n",
      "31/52: GD step=300  \n",
      "norm_thperp0        0.013637500174697429\n",
      "norm_thperp1        0.013639547277947191\n",
      "norm_thperp1-0      2.047103249762719e-06\n",
      "norm_thperp1-0/0    0.00015010839402670347\n",
      "normsq_dense10      40.67425801469711\n",
      "normsq_dense11      41.44433796147531\n",
      "normsq_dense1_1-0   0.7700799467782034\n",
      "normsq_dense1_1-0/0 0.01893285789011677\n",
      "sum_dense20         -0.34500449714390413\n",
      "sum_dense21         -0.34505628521490195\n",
      "sum_dense2_1-0      -5.1788070997815794e-05\n",
      "sum_dense2_1-0/0    0.00015010839402541056\n",
      "errex_by_norm_all   0.4003823360300015\n",
      "errex               4.602978148586969\n",
      "errth             15.351175578226668\n",
      "errth_100stp,     3.9827068761623687\n",
      "\n",
      "32/52: GD step=310  \n",
      "norm_thperp0        0.013501737224804826\n",
      "norm_thperp1        0.013503831511605939\n",
      "norm_thperp1-0      2.094286801113468e-06\n",
      "norm_thperp1-0/0    0.00015511239526021378\n",
      "normsq_dense10      40.085164678695094\n",
      "normsq_dense11      40.84713886998497\n",
      "normsq_dense1_1-0   0.7619741912898732\n",
      "normsq_dense1_1-0/0 0.01900888264767079\n",
      "sum_dense20         -0.34156993599571317\n",
      "sum_dense21         -0.3416229177266339\n",
      "sum_dense2_1-0      -5.298173092072034e-05\n",
      "sum_dense2_1-0/0    0.00015511239525888858\n",
      "errex_by_norm_all   0.4055713605848454\n",
      "errex               4.63008811189797\n",
      "errth             15.351176639106077\n",
      "errth_100stp,     3.982714494162513\n",
      "\n",
      "33/52: GD step=320  \n",
      "norm_thperp0        0.013367325811361517\n",
      "norm_thperp1        0.013369466139732455\n",
      "norm_thperp1-0      2.1403283709375137e-06\n",
      "norm_thperp1-0/0    0.00016011642127540187\n",
      "normsq_dense10      39.52241490395436\n",
      "normsq_dense11      40.26269199304525\n",
      "normsq_dense1_1-0   0.7402770890908883\n",
      "normsq_dense1_1-0/0 0.01873056317256618\n",
      "sum_dense20         -0.3381695663156856\n",
      "sum_dense21         -0.3382237128164287\n",
      "sum_dense2_1-0      -5.41465007430908e-05\n",
      "sum_dense2_1-0/0    0.00016011642127649163\n",
      "errex_by_norm_all   0.4102019003507989\n",
      "errex               4.650138228621145\n",
      "errth             15.35117770276207\n",
      "errth_100stp,     3.982721583911941\n",
      "\n",
      "34/52: GD step=330  \n",
      "norm_thperp0        0.013234252479660017\n",
      "norm_thperp1        0.013236437725683364\n",
      "norm_thperp1-0      2.185246023347562e-06\n",
      "norm_thperp1-0/0    0.00016512047255454055\n",
      "normsq_dense10      38.973188783078655\n",
      "normsq_dense11      39.69076020815874\n",
      "normsq_dense1_1-0   0.717571425080088\n",
      "normsq_dense1_1-0/0 0.018411924902373974\n",
      "sum_dense20         -0.3348030477236552\n",
      "sum_dense21         -0.3348583305611079\n",
      "sum_dense2_1-0      -5.528283745270812e-05\n",
      "sum_dense2_1-0/0    0.0001651204725541755\n",
      "errex_by_norm_all   0.41464581804949846\n",
      "errex               4.667254029438568\n",
      "errth             15.351178765747735\n",
      "errth_100stp,     3.9827282013130922\n",
      "\n",
      "35/52: GD step=340  \n",
      "norm_thperp0        0.013102503908936202\n",
      "norm_thperp1        0.013104732966503378\n",
      "norm_thperp1-0      2.229057567176146e-06\n",
      "norm_thperp1-0/0    0.00017012454891586628\n",
      "normsq_dense10      38.43245649271858\n",
      "normsq_dense11      39.13110656199558\n",
      "normsq_dense1_1-0   0.6986500692770008\n",
      "normsq_dense1_1-0/0 0.01817864724336232\n",
      "sum_dense20         -0.33147004322798157\n",
      "sum_dense21         -0.33152643441956586\n",
      "sum_dense2_1-0      -5.6391191584292955e-05\n",
      "sum_dense2_1-0/0    0.000170124548918913\n",
      "errex_by_norm_all   0.4190872032304856\n",
      "errex               4.683818398837999\n",
      "errth             15.35117982451189\n",
      "errth_100stp,     3.98273439503398\n",
      "\n",
      "36/52: GD step=350  \n",
      "norm_thperp0        0.01297206691103556\n",
      "norm_thperp1        0.012974338691599897\n",
      "norm_thperp1-0      2.271780564336995e-06\n",
      "norm_thperp1-0/0    0.00017512864988418708\n",
      "normsq_dense10      37.900714031589395\n",
      "normsq_dense11      38.58349474903286\n",
      "normsq_dense1_1-0   0.6827807174434639\n",
      "normsq_dense1_1-0/0 0.01801498295980338\n",
      "sum_dense20         -0.32817021919181766\n",
      "sum_dense21         -0.32822769119923567\n",
      "sum_dense2_1-0      -5.7472007418013504e-05\n",
      "sum_dense2_1-0/0    0.00017512864988038644\n",
      "errex_by_norm_all   0.423591845910792\n",
      "errex               4.700797347218304\n",
      "errth             15.351180876240651\n",
      "errth_100stp,     3.9827402076814242\n",
      "\n",
      "37/52: GD step=360  \n",
      "norm_thperp0        0.012842928429093356\n",
      "norm_thperp1        0.012845241861444905\n",
      "norm_thperp1-0      2.3134323515489746e-06\n",
      "norm_thperp1-0/0    0.00018013277612824718\n",
      "normsq_dense10      37.37926120116145\n",
      "normsq_dense11      38.04768952086685\n",
      "normsq_dense1_1-0   0.6684283197053986\n",
      "normsq_dense1_1-0/0 0.017882330956413586\n",
      "sum_dense20         -0.3249032452997054\n",
      "sum_dense21         -0.3249617710232555\n",
      "sum_dense2_1-0      -5.8525723550140185e-05\n",
      "sum_dense2_1-0/0    0.00018013277613202486\n",
      "errex_by_norm_all   0.428132909514997\n",
      "errex               4.717949391360131\n",
      "errth             15.351181919445663\n",
      "errth_100stp,     3.9827456767108136\n",
      "\n",
      "38/52: GD step=370  \n",
      "norm_thperp0        0.012715075536227546\n",
      "norm_thperp1        0.01271742956624242\n",
      "norm_thperp1-0      2.3540300148738175e-06\n",
      "norm_thperp1-0/0    0.0001851369272771334\n",
      "normsq_dense10      36.86774235558566\n",
      "normsq_dense11      37.52345703955485\n",
      "normsq_dense1_1-0   0.6557146839691868\n",
      "normsq_dense1_1-0/0 0.0177855936402312\n",
      "sum_dense20         -0.321668794524526\n",
      "sum_dense21         -0.32172834729674626\n",
      "sum_dense2_1-0      -5.955277222025046e-05\n",
      "sum_dense2_1-0/0    0.0001851369272803669\n",
      "errex_by_norm_all   0.43270159409999726\n",
      "errex               4.735226184822363\n",
      "errth             15.351182952491008\n",
      "errth_100stp,     3.982750835115552\n",
      "\n",
      "39/52: GD step=380  \n",
      "norm_thperp0        0.01258849543424459\n",
      "norm_thperp1        0.012590889024661647\n",
      "norm_thperp1-0      2.393590417057098e-06\n",
      "norm_thperp1-0/0    0.0001901411038006809\n",
      "normsq_dense10      36.35301820599048\n",
      "normsq_dense11      37.01056518510772\n",
      "normsq_dense1_1-0   0.6575469791172424\n",
      "normsq_dense1_1-0/0 0.018087823558179485\n",
      "sum_dense20         -0.31846654309474687\n",
      "sum_dense21         -0.3185270966747744\n",
      "sum_dense2_1-0      -6.055358002754474e-05\n",
      "sum_dense2_1-0/0    0.0001901411038004374\n",
      "errex_by_norm_all   0.43771140861486624\n",
      "errex               4.7577503672513926\n",
      "errth             15.35118397284402\n",
      "errth_100stp,     3.982755712119745\n",
      "\n",
      "40/52: GD step=390  \n",
      "norm_thperp0        0.012463175452358663\n",
      "norm_thperp1        0.01246560758253516\n",
      "norm_thperp1-0      2.4321301764981818e-06\n",
      "norm_thperp1-0/0    0.00019514530512670426\n",
      "normsq_dense10      35.8620270100141\n",
      "normsq_dense11      36.50878382491497\n",
      "normsq_dense1_1-0   0.6467568149008684\n",
      "normsq_dense1_1-0/0 0.018034586129787596\n",
      "sum_dense20         -0.3152961704620223\n",
      "sum_dense21         -0.3153576990294136\n",
      "sum_dense2_1-0      -6.152856739127799e-05\n",
      "sum_dense2_1-0/0    0.00019514530513046354\n",
      "errex_by_norm_all   0.4426377724348677\n",
      "errex               4.779260725131281\n",
      "errth             15.351184978262257\n",
      "errth_100stp,     3.9827603337150466\n",
      "\n",
      "41/52: GD step=400  \n",
      "norm_thperp0        0.012339103045922992\n",
      "norm_thperp1        0.012341572711621202\n",
      "norm_thperp1-0      2.469665698209833e-06\n",
      "norm_thperp1-0/0    0.00020014953185967956\n",
      "normsq_dense10      35.38924072404453\n",
      "normsq_dense11      36.017885051206335\n",
      "normsq_dense1_1-0   0.6286443271618083\n",
      "normsq_dense1_1-0/0 0.01776371332925174\n",
      "sum_dense20         -0.3121573592691038\n",
      "sum_dense21         -0.31221983741842596\n",
      "sum_dense2_1-0      -6.247814932214268e-05\n",
      "sum_dense2_1-0/0    0.0002001495318528809\n",
      "errex_by_norm_all   0.44715187553427393\n",
      "errex               4.795340493841862\n",
      "errth             15.351185967584552\n",
      "errth_100stp,     3.982764723154236\n",
      "\n",
      "42/52: GD step=410  \n",
      "norm_thperp0        0.012216265795174389\n",
      "norm_thperp1        0.012218772008323513\n",
      "norm_thperp1-0      2.5062131491244283e-06\n",
      "norm_thperp1-0/0    0.00020515378358208452\n",
      "normsq_dense10      34.92136356316729\n",
      "normsq_dense11      35.537643391428674\n",
      "normsq_dense1_1-0   0.6162798282613835\n",
      "normsq_dense1_1-0/0 0.017647645033866722\n",
      "sum_dense20         -0.3090497953180731\n",
      "sum_dense21         -0.30911319805289716\n",
      "sum_dense2_1-0      -6.340273482408065e-05\n",
      "sum_dense2_1-0/0    0.0002051537835798492\n",
      "errex_by_norm_all   0.4516800092563279\n",
      "errex               4.811086863355548\n",
      "errth             15.35118693938776\n",
      "errth_100stp,     3.982768901078049\n",
      "\n",
      "43/52: GD step=420  \n",
      "norm_thperp0        0.012094651403989843\n",
      "norm_thperp1        0.012097193192467878\n",
      "norm_thperp1-0      2.541788478034601e-06\n",
      "norm_thperp1-0/0    0.00021015806021462542\n",
      "normsq_dense10      34.45974499479658\n",
      "normsq_dense11      35.067835995511295\n",
      "normsq_dense1_1-0   0.6080910007147153\n",
      "normsq_dense1_1-0/0 0.017646416153298206\n",
      "sum_dense20         -0.30597316753888926\n",
      "sum_dense21         -0.3060374702662574\n",
      "sum_dense2_1-0      -6.430272736812626e-05\n",
      "sum_dense2_1-0/0    0.0002101580602160265\n",
      "errex_by_norm_all   0.4563228532815425\n",
      "errex               4.827936656926845\n",
      "errth             15.35118789280236\n",
      "errth_100stp,     3.982772886064188\n",
      "\n",
      "44/52: GD step=430  \n",
      "norm_thperp0        0.01197424769865567\n",
      "norm_thperp1        0.011976824106074242\n",
      "norm_thperp1-0      2.5764074185717634e-06\n",
      "norm_thperp1-0/0    0.0002151623620464284\n",
      "normsq_dense10      34.010286724970406\n",
      "normsq_dense11      34.608242803288675\n",
      "normsq_dense1_1-0   0.5979560783182691\n",
      "normsq_dense1_1-0/0 0.01758162414665116\n",
      "sum_dense20         -0.3029271679582517\n",
      "sum_dense21         -0.30299234648323736\n",
      "sum_dense2_1-0      -6.517852498566512e-05\n",
      "sum_dense2_1-0/0    0.0002151623620455455\n",
      "errex_by_norm_all   0.46096013025128096\n",
      "errex               4.844602378247002\n",
      "errth             15.35118882712341\n",
      "errth_100stp,     3.98277669490786\n",
      "\n",
      "45/52: GD step=440  \n",
      "norm_thperp0        0.011855042626649107\n",
      "norm_thperp1        0.01185765271213114\n",
      "norm_thperp1-0      2.6100854820330227e-06\n",
      "norm_thperp1-0/0    0.00022016668891310242\n",
      "normsq_dense10      33.572153731974645\n",
      "normsq_dense11      34.158646694774504\n",
      "normsq_dense1_1-0   0.5864929627998592\n",
      "normsq_dense1_1-0/0 0.017469625794107876\n",
      "sum_dense20         -0.2999114916687695\n",
      "sum_dense21         -0.29997752218885854\n",
      "sum_dense2_1-0      -6.60305200890221e-05\n",
      "sum_dense2_1-0/0    0.00022016668891750244\n",
      "errex_by_norm_all   0.465530942008094\n",
      "errex               4.860254488293853\n",
      "errth             15.351189740865562\n",
      "errth_100stp,     3.9827803426182076\n",
      "\n",
      "46/52: GD step=450  \n",
      "norm_thperp0        0.01173702425543148\n",
      "norm_thperp1        0.011739667093392168\n",
      "norm_thperp1-0      2.6428379606875663e-06\n",
      "norm_thperp1-0/0    0.0002251710402203995\n",
      "normsq_dense10      33.141377089242305\n",
      "normsq_dense11      33.71883362547423\n",
      "normsq_dense1_1-0   0.5774565362319279\n",
      "normsq_dense1_1-0/0 0.017424035660225188\n",
      "sum_dense20         -0.2969258367984442\n",
      "sum_dense21         -0.29699269589798316\n",
      "sum_dense2_1-0      -6.685909953896285e-05\n",
      "sum_dense2_1-0/0    0.0002251710402161715\n",
      "errex_by_norm_all   0.4701424096756296\n",
      "errex               4.8761243621111845\n",
      "errth             15.35119063363132\n",
      "errth_100stp,     3.982783842981271\n",
      "\n",
      "47/52: GD step=460  \n",
      "norm_thperp0        0.011620180771254434\n",
      "norm_thperp1        0.011622855451208258\n",
      "norm_thperp1-0      2.6746799538235333e-06\n",
      "norm_thperp1-0/0    0.00023017541692983432\n",
      "normsq_dense10      32.71414899068728\n",
      "normsq_dense11      33.28859274844936\n",
      "normsq_dense1_1-0   0.5744437577620829\n",
      "normsq_dense1_1-0/0 0.017559489562929163\n",
      "sum_dense20         -0.2939699044804467\n",
      "sum_dense21         -0.2940375691257766\n",
      "sum_dense2_1-0      -6.766464532992345e-05\n",
      "sum_dense2_1-0/0    0.00023017541693430097\n",
      "errex_by_norm_all   0.4749859245708229\n",
      "errex               4.8944975179955925\n",
      "errth             15.35119150559739\n",
      "errth_100stp,     3.9827872085479963\n",
      "\n",
      "48/52: GD step=470  \n",
      "norm_thperp0        0.01150450047797612\n",
      "norm_thperp1        0.011507206104312023\n",
      "norm_thperp1-0      2.7056263359036964e-06\n",
      "norm_thperp1-0/0    0.00023517981863560862\n",
      "normsq_dense10      32.28301244167677\n",
      "normsq_dense11      32.86771652439277\n",
      "normsq_dense1_1-0   0.5847040827160015\n",
      "normsq_dense1_1-0/0 0.018111819142415577\n",
      "sum_dense20         -0.29104339882320307\n",
      "sum_dense21         -0.29111184635695286\n",
      "sum_dense2_1-0      -6.844753374979362e-05\n",
      "sum_dense2_1-0/0    0.00023517981863375878\n",
      "errex_by_norm_all   0.48043723621419665\n",
      "errex               4.919905314826285\n",
      "errth             15.35119235613418\n",
      "errth_100stp,     3.982790450574309\n",
      "\n",
      "49/52: GD step=480  \n",
      "norm_thperp0        0.011389971795891952\n",
      "norm_thperp1        0.011392707487674882\n",
      "norm_thperp1-0      2.735691782929789e-06\n",
      "norm_thperp1-0/0    0.0002401842455761372\n",
      "normsq_dense10      31.877767438588773\n",
      "normsq_dense11      32.45600082053227\n",
      "normsq_dense1_1-0   0.5782333819434946\n",
      "normsq_dense1_1-0/0 0.018139080255775056\n",
      "sum_dense20         -0.2881460268807765\n",
      "sum_dense21         -0.288215235016859\n",
      "sum_dense2_1-0      -6.92081360824659e-05\n",
      "sum_dense2_1-0/0    0.00024018424557733535\n",
      "errex_by_norm_all   0.4854960481420394\n",
      "errex               4.941526732089756\n",
      "errth             15.351193185051757\n",
      "errth_100stp,     3.9827935794256146\n",
      "\n",
      "50/52: GD step=490  \n",
      "norm_thperp0        0.011276583260574189\n",
      "norm_thperp1        0.011279348151334809\n",
      "norm_thperp1-0      2.7648907606203643e-06\n",
      "norm_thperp1-0/0    0.00024518869738559266\n",
      "normsq_dense10      31.486676318570108\n",
      "normsq_dense11      32.053244998770694\n",
      "normsq_dense1_1-0   0.5665686802005858\n",
      "normsq_dense1_1-0/0 0.017993918267786072\n",
      "sum_dense20         -0.28527749862354\n",
      "sum_dense21         -0.2853474454418201\n",
      "sum_dense2_1-0      -6.994681828009774e-05\n",
      "sum_dense2_1-0/0    0.0002451886973826894\n",
      "errex_by_norm_all   0.49004911568841825\n",
      "errex               4.9563460021170735\n",
      "errth             15.351193991612876\n",
      "errth_100stp,     3.9827966044193968\n",
      "\n",
      "51/52: GD step=500  \n",
      "norm_thperp0        0.011164323521725137\n",
      "norm_thperp1        0.011167116759265856\n",
      "norm_thperp1-0      2.793237540718932e-06\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "norm_thperp1-0/0    0.000250193174291613\n",
      "normsq_dense10      31.095901594075713\n",
      "normsq_dense11      31.65925199311023\n",
      "normsq_dense1_1-0   0.563350399034519\n",
      "normsq_dense1_1-0/0 0.01811654816729439\n",
      "sum_dense20         -0.2824375269091499\n",
      "sum_dense21         -0.2825081908505458\n",
      "sum_dense2_1-0      -7.06639413958765e-05\n",
      "sum_dense2_1-0/0    0.00025019317428950075\n",
      "errex_by_norm_all   0.49469603478946284\n",
      "errex               4.971681416091924\n",
      "errth             15.351194776106382\n",
      "errth_100stp,     3.9827995340832647\n",
      "\n",
      "52/52: GD step=510  \n",
      "norm_thperp0        0.011053181342040546\n",
      "norm_thperp1        0.01105600208823286\n",
      "norm_thperp1-0      2.820746192315135e-06\n",
      "norm_thperp1-0/0    0.00025519767612845414\n",
      "normsq_dense10      30.707277062021898\n",
      "normsq_dense11      31.273828376127984\n",
      "normsq_dense1_1-0   0.5665513141060856\n",
      "normsq_dense1_1-0/0 0.018450066834704276\n",
      "sum_dense20         -0.27962582745379816\n",
      "sum_dense21         -0.27969718731515014\n",
      "sum_dense2_1-0      -7.13598613519828e-05\n",
      "sum_dense2_1-0/0    0.00025519767612944624\n",
      "errex_by_norm_all   0.49956378143058067\n",
      "errex               4.989536130245832\n",
      "errth             15.351195538589566\n",
      "errth_100stp,     3.9828023761495612\n",
      "b0-b1 -4.085620730620576e-13\n",
      "-a0 ex      0.0010005003484995967\n",
      "-a1 ex      0.0010000000247382891\n",
      "-a0 thGF    0.001\n",
      "-a0 thGFwXi 0.0010005\n"
     ]
    }
   ],
   "source": [
    "#================================================================\n",
    "# Calc Discretization Error etc.\n",
    "#================================================================\n",
    "#**************************************\n",
    "# Write header to csv file\n",
    "with open(csvpath, \"a\") as f:\n",
    "    f.write(\"GD step,norm_thperp0,norm_thperp1,norm_thperp1-0,norm_thperp1-0/0,normsq_dense10,normsq_dense11,normsq_dense1_1-0,normsq_dense1_1-0/0,sum_dense20,sum_dense21,sum_dense2_1-0,sum_dense2_1-0/0,errex_by_norm_all,errex,errth,errth_100stp\\n\")\n",
    "#**************************************\n",
    "# Initialize\n",
    "decay0 = [] # For decay rate\n",
    "decay1 = [] # For decay rate\n",
    "\n",
    "Hg_tot = np.float64(0.)\n",
    "Hg_tot_100stp = np.float64(0.) # <<<Most of gap betw. th. and ex. comes from steps after 100>>>\n",
    "errexvec_100stp = None # <<<Most of gap betw. th. and ex. comes from steps after 100>>>\n",
    "\n",
    "# Model loop\n",
    "for cnt, (it_step0, it_step1) in enumerate(zip(steps_intsec0, steps_intsec1)):\n",
    "    it_path_resume0 = dc_ckpt_paths0[it_step0]\n",
    "    it_path_resume1 = dc_ckpt_paths1[it_step1]\n",
    "    \n",
    "    # Restore weights\n",
    "    _ = ckpt.restore(it_path_resume0)\n",
    "    it_w0 = [tf.identity(v) for v in model.trainable_variables]\n",
    "    _ = ckpt.restore(it_path_resume1)\n",
    "    it_w1 = [tf.identity(v) for v in model.trainable_variables]\n",
    "    \n",
    "    \n",
    "    # 1. Calc Norm Diff\n",
    "    #========================================================================\n",
    "    errex = [tf.norm(v0-v1).numpy() for v0, v1 in zip(it_w0, it_w1)]\n",
    "    errex = np.sqrt(np.sum([v**2 for v in errex]))                              # <--- That's it!\n",
    "    errex_by_norm_all = errex / np.sqrt(np.sum([tf.norm(v)**2 for v in it_w0])) # <--- That's it!\n",
    "\n",
    "    \n",
    "    # 2. Calc Norm theta_perp\n",
    "    #========================================================================\n",
    "    thperp0, _ = calc_perp_para_translation(it_w0[2])\n",
    "    norm_thperp0 = tf.norm(thperp0).numpy()                                     # <--- That's it!\n",
    "    thperp1, _ = calc_perp_para_translation(it_w1[2]) \n",
    "    norm_thperp1 = tf.norm(thperp1).numpy()                                     # <--- That's it!\n",
    "\n",
    "    norm_thperpdiff10 = norm_thperp1 - norm_thperp0                             # <--- That's it!\n",
    "    norm_thperpdiff10by0 = norm_thperpdiff10 / norm_thperp0                     # <--- That's it!\n",
    "    \n",
    "       \n",
    "    # 3. Calc Noether Charges: Scale and Translation\n",
    "    #========================================================================\n",
    "    normsq_dense10 = tf.norm(it_w0[1]) ** 2\n",
    "    normsq_dense10 = normsq_dense10.numpy()                                     # <--- That's it!\n",
    "    sum_dense20 = tf.reduce_sum(it_w0[2]).numpy()                               # <--- That's it!\n",
    "\n",
    "    normsq_dense11 = tf.norm(it_w1[1]) ** 2\n",
    "    normsq_dense11 = normsq_dense11.numpy()                                     # <--- That's it!\n",
    "    sum_dense21 = tf.reduce_sum(it_w1[2]).numpy()                               # <--- That's it!\n",
    "    \n",
    "    normsq_dense1_diff10 = normsq_dense11 - normsq_dense10                      # <--- That's it!\n",
    "    sum_dense2_diff10 = sum_dense21 - sum_dense20                               # <--- That's it!\n",
    "    normsq_dense1_diff10divby0 = normsq_dense1_diff10 / normsq_dense10          # <--- That's it!\n",
    "    sum_dense2_diff10divby0 = sum_dense2_diff10 / sum_dense20                   # <--- That's it!\n",
    "\n",
    "    \n",
    "#     # 4. Calc Norm dense1 and Grad dense1\n",
    "#     #========================================================================\n",
    "#     norm_dense10 = tf.norm(it_w0[1])   \n",
    "#     grad10 = calc_grad(model, x, y, weights=it_w0, weight_decay=wd0)[1]\n",
    "#     ngrad10 = tf.norm(grad10)\n",
    "#     ngradhat10 = norm_dense10 * ngrad10                                         \n",
    "#     ngradhat10 = ngradhat10.numpy()                                             # <--- That's it!\n",
    "#     rsq_star0 = np.sqrt(np.float64(lr0)/(0.5 * np.float64(wd0) +\\\n",
    "#         np.float64(lr0) * np.float64(wd0)**2)) * ngradhat10                     # <--- That's it!\n",
    "    \n",
    "#     norm_dense11 = tf.norm(it_w1[1])\n",
    "#     grad11 = calc_grad(model, x, y, weights=it_w1, weight_decay=wd1)[1]\n",
    "#     ngrad11 = tf.norm(grad11)\n",
    "#     ngradhat11 = norm_dense11 * ngrad11\n",
    "#     ngradhat11 = ngradhat11.numpy()                                             # <--- That's it!\n",
    "#     rsq_star1 = np.sqrt(np.float64(lr0)/(0.5 * np.float64(wd1) +\\\n",
    "#         np.float64(lr0) * np.float64(wd1)**2)) * ngradhat11                     # <--- That's it!\n",
    "\n",
    "    \n",
    "#     # 5. Calc Cosine Angular Update (CAU)\n",
    "#     #========================================================================\n",
    "#     # Calc CAU\n",
    "#     if cnt == 0:\n",
    "#         tmp_dense10 = it_w0[1]\n",
    "#         tmp_dense11 = it_w1[1]\n",
    "#         cau0 = np.float64(1.)\n",
    "#         cau1 = np.float64(1.)\n",
    "#         lcau0 = - np.log(2. + 1e-50)\n",
    "#         lcau1 = - np.log(2. + 1e-50)\n",
    "#     else:\n",
    "#         # Calc CAU and log CAU\n",
    "#         cau0 = - calc_cossim(\n",
    "#             tf.reshape(tmp_dense10, [-1]), tf.reshape(it_w0[1], [-1])).numpy()  # <--- That's it!\n",
    "#         cau1 = - calc_cossim(\n",
    "#             tf.reshape(tmp_dense11, [-1]), tf.reshape(it_w1[1], [-1])).numpy()  # <--- That's it!\n",
    "#         lcau0 = - np.log(1. - cau0 + 1e-50)                                     # <--- That's it!\n",
    "#         lcau1 = - np.log(1. - cau1 + 1e-50)                                     # <--- That's it!\n",
    "\n",
    "#         # Update tmp_wX\n",
    "#         tmp_dense10 = it_w0[1]\n",
    "#         tmp_dense11 = it_w1[1]\n",
    "    \n",
    "#     # Theoretical prediciton\n",
    "#     cau_star = (1 - np.float64(10*lr0) * np.float64(wd0))/\\\n",
    "#         (1 - 0.5 * np.float64(10*lr0)**2 * np.float64(wd0)**2)                 # <--- That's it!\n",
    "#     lcau_star = - np.log(1 - cau_star)                                          # <--- That's it!\n",
    "    \n",
    "    \n",
    "    # 6.Calc Discretization Error\n",
    "    #========================================================================\n",
    "    # <<<Most of gap betw. th. and ex. comes from steps after 100>>> HEAD\n",
    "    if cnt == 10:\n",
    "        errexvec_100stp = [v0 - v1 for v0, v1 in zip(it_w0, it_w1)]\n",
    "    # <<<Most of gap betw. th. and ex. comes from steps after 100>>> TAIL\n",
    "        \n",
    "    Hg = calc_Hg(model, x, y, wd1, it_w1, eps)\n",
    "    if cnt == 0:\n",
    "        pass\n",
    "    else:\n",
    "        tmp_add = [sum(interp_Hg(v, w, 9)[:-1]) for v, w in zip(tmp_Hg, Hg)]\n",
    "        Hg_tot += tmp_add\n",
    "    tmp_Hg = [tf.identity(v) for v in Hg]\n",
    "        \n",
    "    if cnt == 0: \n",
    "        errth = 0.\n",
    "    else:\n",
    "        normHg_tot = tf.sqrt(tf.reduce_sum([tf.norm(v)**2 for v in Hg_tot])).numpy()\n",
    "        errth = (lr0 ** 2) * 0.5 * normHg_tot                                   # <--- That's it!\n",
    "\n",
    "    # <<<Most of gap betw. th. and ex. comes from steps after 100>>> HEAD\n",
    "    if cnt == 10: \n",
    "        errth_100stp = errex\n",
    "    elif cnt > 10: \n",
    "        Hg_tot_100stp += tmp_add\n",
    "        errth_100stp =[v + (lr0 ** 2) * 0.5 * w for v, w in zip(errexvec_100stp, Hg_tot_100stp)]\n",
    "        errth_100stp = tf.sqrt(tf.reduce_sum([tf.norm(v)**2 for v in errth_100stp])).numpy()\n",
    "    # <<<Most of gap betw. th. and ex. comes from steps after 100>>> TAIL\n",
    "    \n",
    "    \n",
    "    # 7. Verbose\n",
    "    #========================================================================\n",
    "    print(\"\\n{}/{}: GD step={}  \".format(cnt+1, len(steps_intsec0), it_step0))\n",
    "\n",
    "    print(\"norm_thperp0       \", norm_thperp0)\n",
    "    print(\"norm_thperp1       \", norm_thperp1)\n",
    "    print(\"norm_thperp1-0     \", norm_thperpdiff10)\n",
    "    print(\"norm_thperp1-0/0   \", norm_thperpdiff10by0)\n",
    "    \n",
    "    print(\"normsq_dense10     \", normsq_dense10)\n",
    "    print(\"normsq_dense11     \", normsq_dense11)\n",
    "    print(\"normsq_dense1_1-0  \", normsq_dense1_diff10)\n",
    "    print(\"normsq_dense1_1-0/0\", normsq_dense1_diff10divby0)\n",
    "    print(\"sum_dense20        \", sum_dense20)\n",
    "    print(\"sum_dense21        \", sum_dense21)\n",
    "    print(\"sum_dense2_1-0     \", sum_dense2_diff10)\n",
    "    print(\"sum_dense2_1-0/0   \", sum_dense2_diff10divby0)\n",
    "\n",
    "#     print(\"ngradhat10         \", ngradhat10)\n",
    "#     print(\"ngradhat11         \", ngradhat11)    \n",
    "#     print(\"rsq_star0          \", rsq_star0)\n",
    "#     print(\"rsq_star1          \", rsq_star1)\n",
    "    \n",
    "#     print(\"cau0               \", cau0)\n",
    "#     print(\"cau1               \", cau1)\n",
    "#     print(\"cau_star           \", cau_star)\n",
    "#     print(\"lcau0              \", lcau0)\n",
    "#     print(\"lcau1              \", lcau1)\n",
    "#     print(\"lcau_star          \", lcau_star)\n",
    "    \n",
    "    print(\"errex_by_norm_all  \", errex_by_norm_all)\n",
    "    print(\"errex              \", errex)\n",
    "    #print(\"normHg_tot    \", normHg_tot)\n",
    "    print(\"errth            \", errth)\n",
    "    if cnt >= 10: print(\"errth_100stp,    \", errth_100stp)\n",
    "    \n",
    "    decay0.append(norm_thperp0)\n",
    "    decay1.append(norm_thperp1)\n",
    "    \n",
    "    #**************************************\n",
    "    # Write results to csv file\n",
    "    if cnt < 10: errth_100stp = \"n/a\"\n",
    "    with open(csvpath, \"a\") as f:\n",
    "        f.write(\"{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{}\\n\".format(it_step0, norm_thperp0, norm_thperp1, norm_thperpdiff10, norm_thperpdiff10by0, normsq_dense10, normsq_dense11, normsq_dense1_diff10, normsq_dense1_diff10divby0, sum_dense20, sum_dense21, sum_dense2_diff10, sum_dense2_diff10divby0, errex_by_norm_all, errex, errth, errth_100stp))        \n",
    "    #**************************************\n",
    "\n",
    "# 8. Calc Decay Rate\n",
    "#==================================================================\n",
    "decay0_log = np.log(np.array(decay0) + 1e-50)\n",
    "decay1_log = np.log(np.array(decay1) + 1e-50)\n",
    "x_value = [i * 10 for i in range(len(decay0))]\n",
    "a0, b0 = np.polyfit(x_value, decay0_log, 1)\n",
    "a1, b1 = np.polyfit(x_value, decay1_log, 1)\n",
    "print(\"b0-b1\", b0-b1)\n",
    "print(\"-a0 ex     \", -a0)\n",
    "print(\"-a1 ex     \", -a1)\n",
    "print(\"-a0 thGF   \", lr0 * wd0, )\n",
    "print(\"-a0 thGFwXi\", lr0 * wd0 + lr0**2 * wd0**2 /2)  # theory GF and theory GFw\\xi\n",
    "\n",
    "#**************************************\n",
    "# Write results to csv file\n",
    "with open(csvpath, \"a\") as f:\n",
    "    f.write(\"b0-b1,-a0 ex,-a1 ex,-a0 thGF,-a0 thEoM\\n\")\n",
    "    f.write(\"{},{},{},{},{},\".format(b0-b1, -a0, -a1, lr0 * wd0, lr0 * wd0 + lr0**2 * wd0**2 /2))\n",
    "#**************************************"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "b0-b1 -4.085620730620576e-13\n",
      "-a0 ex      0.0010005003484995967\n",
      "-a1 ex      0.0010000000247382891\n",
      "-a0 thGF    0.001\n",
      "-a0 thGFwXi 0.0010005\n"
     ]
    }
   ],
   "source": [
    "# Decay rate\n",
    "decay0_log = np.log(np.array(decay0) + 1e-50)\n",
    "decay1_log = np.log(np.array(decay1) + 1e-50)\n",
    "x_value = [i * 10 for i in range(len(decay0))]\n",
    "a0, b0 = np.polyfit(x_value, decay0_log, 1)\n",
    "a1, b1 = np.polyfit(x_value, decay1_log, 1)\n",
    "\n",
    "print(\"b0-b1\", b0-b1)\n",
    "print(\"-a0 ex     \", -a0)\n",
    "print(\"-a1 ex     \", -a1)\n",
    "print(\"-a0 thGF   \", lr0 * wd0, )\n",
    "print(\"-a0 thGFwXi\", lr0 * wd0 + lr0**2 * wd0**2 /2)  # theory GF and theory GFw\\xi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Surprising Matching: Decay Rate of $||\\theta_\\perp||$\n",
    "- Models:\n",
    "    - LR0: 0.001 LR1: 1e-05 ind_fmg1: 0 wd: 0.01\n",
    "    - LR0: 0.001 LR1: 1e-05 ind_fmg1: 3 wd: 0.01\n",
    "    - Used up to 68-th steps\n",
    "\n",
    "```\n",
    "thGF        1.0e-05 (=lr0*wd0)\n",
    "thGFw\\xi    1.000005e-05 (= lr0*wd0 + (lr0*wd0)^2/2)\n",
    "exGF:       1.0000000247366986e-05\n",
    "exGFw\\xi:   1.0000050247119778e-05\n",
    "exGD:       1.000005047529353e-05\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Note\n",
    "Theoretical DE is defined as $|| \\mathbf{e}_{k} || = \\frac{\\eta^2}{2} || \\sum_{s=0}^{k-1} ( H (\\mathbf{\\theta}(s\\eta)) + \\lambda I) \\mathbf{g} (\\mathbf{\\theta}(s\\eta)) || + O(\\eta^3)$.\n",
    "\n",
    "R.H.S. is approximated to $||(H({\\bf \\theta}(t)) + \\lambda I) {\\bf g} ({\\bf \\theta}(t)) ||$ with $H({\\bf \\theta}(t)) + \\lambda I \\sim \\frac{{\\bf g} ({\\bf \\theta}(t) + \\epsilon {\\bf g} ({\\bf \\theta}(t))) - {\\bf g} ({\\bf \\theta}(t) - \\epsilon {\\bf g} ({\\bf \\theta}(t)))}{2\\epsilon}$ of model 0,\n",
    "\n",
    "where `eps` dependence is given by:\n",
    "```\n",
    "Model: `/data/t-miyagawa/eom/mnist/ckptlogs/Run20220216_try/_ModGradFalseLR0.001WD0.01_20220216_092320162/ckpt_step0-1`\n",
    "- `eps` vs. $||(H({\\bf \\theta}(t)) + \\lambda I) {\\bf g} ({\\bf \\theta}(t)) ||$\n",
    "- `eps` = 1e-1:  174.41075192330862\n",
    "- `eps` = 1e-2:  516.8432639239894\n",
    "- `eps` = 1e-3:  557.9337983681397\n",
    "- `eps` = 1e-4:  558.3848941907582\n",
    "- `eps` = 1e-5:  558.3894095286483\n",
    "- `eps` = 1e-6:  558.3894546807346\n",
    "- `eps` = 1e-7:  558.3894551335071\n",
    "- `eps` = 1e-8:  558.3894549019357\n",
    "- `eps` = 1e-9:  558.3894541373063\n",
    "- `eps` = 1e-10: 558.3894438393735\n",
    "- `eps` = 1e-12: 558.3906191660091\n",
    "- `eps` = 1e-15: 559.8027317450302\n",
    "- `eps` = 1e-18: 822.418941870446\n",
    "- `eps` = 1e-20: 7339.360218755874\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$\\nabla f(\\theta) = \\frac{1}{r} \\nabla f(\\hat{\\theta})$\n",
    "\n",
    "$r_{*}^2 = \\sqrt{\\frac{\\eta}{2\\lambda+\\eta\\lambda^2}} c_*$\n",
    "\n",
    "${\\rm CAU} := \\cos \\Delta_k \\xrightarrow{t\\rightarrow \\infty} \\cos \\Delta_* = \\frac{1-\\eta\\lambda}{1 - \\eta^2\\lambda^2 /2} + O(\\eta^3)$\n",
    "\n",
    "${\\rm LCAU} := - \\log (1 - \\cos \\Delta_k) (\\geq - \\log 2)$"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": false,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
