{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9a756c8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-import edited project modules (e.g. `scripts`) automatically before each cell runs.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "44ea0149",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import torch\n",
    "import numpy as np\n",
    "from types import SimpleNamespace\n",
    "from scripts import launch_pretraining, launch_finetuning\n",
    "\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm.notebook import tqdm\n",
    "from IPython.display import clear_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "907572e2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cpu')"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Every pre-training / fine-tuning run launched below is placed on this device.\n",
    "device = torch.device(type='cpu')\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a0d9c8fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_config(data_seed, train_seed, f):\n",
    "    \"\"\"Build the pre-training / fine-tuning config for one sweep run.\n",
    "\n",
    "    Args:\n",
    "        data_seed: passed through as ``data_seed``; also forms the hundreds\n",
    "            part of the combined ``seed = data_seed * 100 + train_seed`` that\n",
    "            names the run's save directories.\n",
    "        train_seed: reused as ``pt_seed``, ``ft_seed`` and ``init_point_seed``.\n",
    "            Must lie in [0, 100) so that distinct (data_seed, train_seed)\n",
    "            pairs never share a combined seed / save directory.\n",
    "        f: feature type applied to every feature pair (e.g. 'tick').\n",
    "\n",
    "    Returns:\n",
    "        SimpleNamespace passed to ``launch_pretraining`` / ``launch_finetuning``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``train_seed`` is outside [0, 100).\n",
    "    \"\"\"\n",
    "    # train_seed fills the last two decimal digits of the combined seed; values\n",
    "    # outside [0, 100) would silently reuse another data seed's directories.\n",
    "    if not 0 <= train_seed < 100:\n",
    "        raise ValueError(f'train_seed must be in [0, 100), got {train_seed}')\n",
    "    seed = data_seed * 100 + train_seed\n",
    "    exp_name = f'f_all={f}'\n",
    "\n",
    "    # Features are grouped into consecutive pairs (0, 1), (2, 3), ...; deriving the\n",
    "    # pairs and the per-pair multiview probabilities from num_features keeps all\n",
    "    # three values consistent when the feature count changes.\n",
    "    num_features = 32\n",
    "    feature_pairs = [(r, r + 1) for r in range(0, num_features, 2)]\n",
    "\n",
    "    # Three log-spaced LR segments; [:-1] drops each segment's right endpoint,\n",
    "    # which equals the next segment's left endpoint, so no LR appears twice.\n",
    "    lrs = (\n",
    "        np.logspace(-4.5, -2.25, 10).tolist()[:-1]\n",
    "        + np.logspace(-2.25, -1.25, 9).tolist()[:-1]\n",
    "        + np.logspace(-1.25, 0, 6).tolist()\n",
    "    )\n",
    "\n",
    "    return SimpleNamespace(\n",
    "        data_seed = data_seed,\n",
    "        data_protocol = [\n",
    "            {'feature_type': f, 'ids': pair, 'margin': 0.1, 'noise': 0.0}\n",
    "            for pair in feature_pairs\n",
    "        ],\n",
    "        multiview_probs = [1.0] * len(feature_pairs),\n",
    "        num_features = num_features,\n",
    "        train_samples = 512,\n",
    "        test_samples = 2000,\n",
    "        batch_size = 16,\n",
    "        num_hidden = 32,\n",
    "        num_layers = 3,\n",
    "        activation = 'relu',\n",
    "        last_layer_norm = 10,\n",
    "        riemann_opt = False,\n",
    "        init_dirichlet = None,\n",
    "        pt_iters = 40000,\n",
    "        ft_iters = 20000,\n",
    "        ckpt_iters = 100,\n",
    "        log_iters = 5,\n",
    "        pt_seed = train_seed,\n",
    "        ft_seed = train_seed,\n",
    "        init_point_seed = train_seed,\n",
    "        savedir = f'experiments-final/{exp_name}/PT-FCN-seed={seed}',\n",
    "        ft_savedir = f'experiments-final/{exp_name}/FT-FCN-seed={seed}',\n",
    "        lrs = lrs,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "de6aeb2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep grid: one pre-training + fine-tuning run per (feature type, data seed, train seed).\n",
    "FEATURE_TYPES = ['tick']\n",
    "DATA_SEEDS = range(1, 11)\n",
    "TRAIN_SEEDS = range(1, 6)\n",
    "NUM_FT_LR = 10  # forwarded to launch_finetuning as num_ft_lr\n",
    "\n",
    "for f in FEATURE_TYPES:\n",
    "    for data_seed in DATA_SEEDS:\n",
    "        for train_seed in TRAIN_SEEDS:\n",
    "            print(f'{f}, data seed: {data_seed}, #{train_seed}')\n",
    "            config = get_config(data_seed, train_seed, f)\n",
    "            launch_pretraining(config, device)\n",
    "            launch_finetuning(config, device, num_ft_lr=NUM_FT_LR)\n",
    "\n",
    "        # Clear this cell's output once per data seed so the printed progress\n",
    "        # (and anything the launchers print) does not pile up.\n",
    "        clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "423e3c31",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
