{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "976fbfea",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.insert(0,'..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4b164f6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "\n",
    "from train import train\n",
    "import priors\n",
    "import encoders\n",
    "import positional_encodings\n",
    "import utils\n",
    "import bar_distribution\n",
    "\n",
    "\n",
    "from samlib.utils import chunker"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29d423b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "mykwargs = \\\n",
    "{\n",
    " 'bptt': 5*5+1,\n",
    "'nlayers': 6,\n",
    " 'dropout': 0.0, 'steps_per_epoch': 100,\n",
    " 'batch_size': 100}\n",
    "mnist_jobs_5shot_pi_prior_search = [\n",
    "    pretrain_and_eval( {'num_features': 28 * 28, 'fuse_x_y': False, 'num_outputs': 5,\n",
    "                                            'translations': False, 'jonas_style': True}, priors.stroke.DataLoader, Losses.ce, enc, emsize=emsize, nhead=nhead, warmup_epochs=warmup_epochs, nhid=nhid, y_encoder_generator=encoders.get_Canonical(5), lr=lr, epochs=epochs, single_eval_pos_gen=mykwargs['bptt']-1,\n",
    "                  extra_prior_kwargs_dict={'num_features': 28*28, 'fuse_x_y': False, 'num_outputs':5, 'only_train_for_last_idx': True,\n",
    "                                          'min_max_strokes': (1,max_strokes), 'min_max_len': (min_len, max_len), 'min_max_width': (min_width, max_width), 'max_offset': max_offset, 'max_target_offset': max_target_offset},\n",
    "                  **mykwargs)\n",
    "    for max_strokes, min_len, max_len, min_width, max_width, max_offset, max_target_offset in random_hypers\n",
    "    for enc in [encoders.Linear] for emsize in [1024] for nhead in [4] for nhid in [emsize*2] for warmup_epochs in [5] for lr in [.00001] for epochs in [128,1024] for _ in range(1)]\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "deb93d1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "@torch.inference_mode()\n",
    "def get_acc(finetuned_model, eval_pos, device='cpu', steps=100, train_mode=False, **mykwargs):\n",
    "    finetuned_model.to(device)\n",
    "    finetuned_model.eval()\n",
    "\n",
    "    t_dl = priors.omniglot.DataLoader(steps, batch_size=1000, seq_len=mykwargs['bptt'], train=train_mode,\n",
    "                                      **mykwargs['extra_prior_kwargs_dict'])\n",
    "\n",
    "    ps = []\n",
    "    ys = []\n",
    "    for x, y in tqdm(t_dl):\n",
    "        p = finetuned_model(tuple(e.to(device) for e in x), single_eval_pos=eval_pos)\n",
    "        ps.append(p)\n",
    "        ys.append(y)\n",
    "\n",
    "    ps = torch.cat(ps, 1)\n",
    "    ys = torch.cat(ys, 1)\n",
    "\n",
    "    def acc(ps, ys):\n",
    "        return (ps.argmax(-1) == ys.to(ps.device)).float().mean()\n",
    "\n",
    "    a = acc(ps[eval_pos], ys[eval_pos]).cpu()\n",
    "    print(a.item())\n",
    "    return a\n",
    "\n",
    "\n",
    "def train_and_eval(*args, **kwargs):\n",
    "    r = train(*args, **kwargs)\n",
    "    model = r[-1]\n",
    "    acc = get_acc(model, -1, device='cuda:0', **kwargs).cpu()\n",
    "    model.to('cpu')\n",
    "    return [acc]\n",
    "\n",
    "def pretrain_and_eval(extra_prior_kwargs_dict_eval,*args, **kwargs):\n",
    "    r = train(*args, **kwargs)\n",
    "    model = r[-1]\n",
    "    kwargs['extra_prior_kwargs_dict'] = extra_prior_kwargs_dict_eval\n",
    "    acc = get_acc(model, -1, device='cuda:0', **kwargs).cpu()\n",
    "    model.to('cpu')\n",
    "    return r, acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "706ecbb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "emsize = 1024\n",
    "# mnist_jobs_5shot_pi[20].result()[-1].state_dict()\n",
    "mykwargs = \\\n",
    "    {'bptt': 5 * 5 + 1,\n",
    "     'nlayers': 6,\n",
    "     'nhead': 4, 'emsize': emsize,\n",
    "     'encoder_generator': encoders.Linear, 'nhid': emsize * 2}\n",
    "results = train_and_eval(priors.omniglot.DataLoader, Losses.ce, y_encoder_generator=encoders.get_Canonical(5),\n",
    "                   load_weights_from_this_state_dict=mnist_jobs_5shot_pi_prior_search[67][0][-1].state_dict(), epochs=32, lr=.00001, dropout=dropout,\n",
    "                   single_eval_pos_gen=mykwargs['bptt'] - 1,\n",
    "                   extra_prior_kwargs_dict={'num_features': 28 * 28, 'fuse_x_y': False, 'num_outputs': 5,\n",
    "                                            'translations': True, 'jonas_style': True},\n",
    "                   batch_size=100, steps_per_epoch=200, **mykwargs)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "611554b2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
