{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7347dd3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from train.py\n",
    "import numpy as np\n",
    "from omegaconf import OmegaConf\n",
    "import pathlib\n",
    "\n",
    "from torch import optim\n",
    "\n",
    "from models import *\n",
    "from tasks import *\n",
    "import config\n",
    "import utils\n",
    "\n",
    "import torch\n",
    "\n",
    "# Experiment config to load, given relative to configs/train/ and without the\n",
    "# .yml suffix; the part before the first \"/\" also selects the task-level base.yml.\n",
    "config_arg = \"spatial_navigation/noisy_unbiased\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c0105b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adapted from config.py, whose loader builds the experiment configuration for\n",
    "# model training. The CLI step used there (OmegaConf.from_cli()) is left out here.\n",
    "\n",
    "# The part of config_arg before the first \"/\" names the task directory.\n",
    "task_name = config_arg.split(\"/\")[0]\n",
    "train_config_path = pathlib.Path(\"configs\") / \"train\"\n",
    "\n",
    "base_config = OmegaConf.load(train_config_path / \"base.yml\")\n",
    "task_config = OmegaConf.load(train_config_path / task_name / \"base.yml\")\n",
    "expt_config = OmegaConf.load(train_config_path / f\"{config_arg}.yml\")\n",
    "\n",
    "# Pick the GPU when one is available, otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    expt_config.device = \"cuda\"\n",
    "else:\n",
    "    expt_config.device = \"cpu\"\n",
    "\n",
    "# Merge order: base, then task-level, then experiment-level config.\n",
    "args = OmegaConf.merge(base_config, task_config, expt_config)\n",
    "\n",
    "# Notebook-specific overrides applied on top of the merged config.\n",
    "args.config = config_arg\n",
    "args.seed = 0\n",
    "args.rnn.bias = True\n",
    "args.trainer.n_epochs = 10_000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2074eeb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adapted from train.py: build the task and the RNN from the merged config.\n",
    "task_cfg = args.task\n",
    "rnn_cfg = args.rnn\n",
    "\n",
    "# Every task setting is forwarded as-is, except anchor_point, which is\n",
    "# converted to a NumPy array first.\n",
    "task = spatial_navigation.SpatialNavigation(\n",
    "    box_width=task_cfg.box_width,\n",
    "    box_height=task_cfg.box_height,\n",
    "    border_region=task_cfg.border_region,\n",
    "    border_slow_factor=task_cfg.border_slow_factor,\n",
    "    init_pos=task_cfg.init_pos,\n",
    "    biased=task_cfg.biased,\n",
    "    drift_const=task_cfg.drift_const,\n",
    "    anchor_point=np.array(task_cfg.anchor_point),\n",
    "    dt=task_cfg.dt,\n",
    "    mu=task_cfg.mu,\n",
    "    sigma=task_cfg.sigma,\n",
    "    b=task_cfg.b,\n",
    "    use_place_cells=task_cfg.use_place_cells,\n",
    "    place_cells_num=task_cfg.place_cells_num,\n",
    "    place_cells_sigma=task_cfg.place_cells_sigma,\n",
    "    place_cells_dog=task_cfg.place_cells_dog,\n",
    "    place_cells_surround_scale=task_cfg.place_cells_surround_scale,\n",
    "    sequence_length=task_cfg.sequence_length,\n",
    "    batch_size=task_cfg.batch_size,\n",
    "    device=args.device,\n",
    ")\n",
    "\n",
    "# The config stores sigma2_* values; the model receives their square roots\n",
    "# as sigma_*.\n",
    "model = rnn.RNN(\n",
    "    task=task,\n",
    "    n_in=rnn_cfg.n_in,\n",
    "    n_rec=rnn_cfg.n_rec,\n",
    "    n_out=rnn_cfg.n_out,\n",
    "    n_init=rnn_cfg.n_init,\n",
    "    sigma_in=np.sqrt(rnn_cfg.sigma2_in),\n",
    "    sigma_rec=np.sqrt(rnn_cfg.sigma2_rec),\n",
    "    sigma_out=np.sqrt(rnn_cfg.sigma2_out),\n",
    "    dt=rnn_cfg.dt,\n",
    "    tau=rnn_cfg.tau,\n",
    "    bias=rnn_cfg.bias,\n",
    "    activation_fn=rnn_cfg.activation_fn,\n",
    "    device=args.device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60a7877c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the merged RNN settings (includes the bias=True override set earlier).\n",
    "args.rnn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58755948",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adapted from trainer.py (__init__ and train), with the `self.` attributes\n",
    "# flattened into plain notebook variables.\n",
    "\n",
    "# --- setup (from trainer.py __init__) ---\n",
    "train_data = task.get_generator()\n",
    "test_data = task.get_test_batch()\n",
    "unmask_every = 6  # number of unmasking levels\n",
    "lr = 0.001\n",
    "weight_decay = 0\n",
    "compute_all_metrics = True  # pass the full batch to compute_metrics as `aux`\n",
    "test_freq = 100  # evaluate on the test batch every `test_freq` epochs\n",
    "save_freq = 100  # only used by the disabled checkpointing code below\n",
    "path = \"results\"  # only used by the disabled saving code below\n",
    "# Reuse the device selected in the config cell (\"cuda\" if available, else \"cpu\")\n",
    "# rather than hard-coding \"cuda\", which fails on machines without a GPU.\n",
    "device = args.device\n",
    "\n",
    "model.set_device(device)\n",
    "optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n",
    "\n",
    "\n",
    "# --- training loop (from trainer.py train) ---\n",
    "n_epochs = args.trainer.n_epochs\n",
    "start_epoch = 0\n",
    "\n",
    "epoch = start_epoch\n",
    "train_metrics = dict()  # epoch -> copy of the training metrics for that epoch\n",
    "test_metrics = dict()  # epoch -> copy of the test metrics (every test_freq epochs)\n",
    "aux = None\n",
    "unmask_levels = torch.arange(unmask_every) + 1  # levels 1, 2, ..., unmask_every\n",
    "\n",
    "for batch in train_data:\n",
    "    # Each unmasking level gets the same amount of training time: the level\n",
    "    # index advances once every n_epochs / unmask_every epochs and is capped\n",
    "    # at the last level.\n",
    "    unmask_level = unmask_levels[\n",
    "        min(unmask_every - 1, int(epoch // (n_epochs / unmask_every)))\n",
    "    ]\n",
    "    epoch += 1\n",
    "    if epoch - start_epoch > n_epochs:\n",
    "        break\n",
    "\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    data = batch[\"data\"].to(device)\n",
    "    # Masking: keep only the indices along dim 1 that are multiples of\n",
    "    # unmask_level and zero out the rest (level 1 leaves the data untouched).\n",
    "    data[:, torch.arange(data.shape[1]) % unmask_level != 0] = 0\n",
    "    init_state = batch[\"init_state\"].to(device)\n",
    "    targets = batch[\"targets\"].to(device)\n",
    "\n",
    "    _, outputs = model(data, init_state=init_state)\n",
    "\n",
    "    if compute_all_metrics:\n",
    "        aux = batch\n",
    "\n",
    "    train_loss, train_metric = model.task.compute_metrics(outputs, targets, aux)\n",
    "    train_loss.backward()\n",
    "    optimizer.step()\n",
    "    # Fixed: this used to assign to `self.train_metrics`, a leftover from the\n",
    "    # trainer.py method; the notebook has no `self`, so it raised NameError.\n",
    "    train_metrics[epoch] = train_metric.copy()\n",
    "\n",
    "    print(f\"Epoch {epoch} (train):\")\n",
    "    for k, v in train_metric.items():\n",
    "        print(f\"  - {k} = {v}.\")\n",
    "\n",
    "    # Checkpointing from trainer.py is disabled in this notebook:\n",
    "    # if epoch % save_freq == 0:\n",
    "    #     model_path = pathlib.Path(path).joinpath(f\"model_{epoch}.pt\")\n",
    "    #     torch.save(model, model_path)\n",
    "    #     print(f\"Model saved at epoch {epoch}.\")\n",
    "\n",
    "    if epoch % test_freq == 0:\n",
    "        with torch.no_grad():\n",
    "            model.eval()\n",
    "\n",
    "            # The test batch is evaluated as-is; no masking is applied here.\n",
    "            data = test_data[\"data\"].to(device)\n",
    "            init_state = test_data[\"init_state\"].to(device)\n",
    "            targets = test_data[\"targets\"].to(device)\n",
    "\n",
    "            if compute_all_metrics:\n",
    "                aux = test_data\n",
    "\n",
    "            _, outputs = model(data, init_state=init_state)\n",
    "\n",
    "            _, test_metric = model.task.compute_metrics(outputs, targets, aux)\n",
    "            test_metrics[epoch] = test_metric.copy()\n",
    "\n",
    "            print(f\"Epoch {epoch} (test):\")\n",
    "            for k, v in test_metric.items():\n",
    "                print(f\"  - {k} = {v}.\")\n",
    "\n",
    "# Metric dumps from trainer.py are disabled in this notebook (they would also\n",
    "# need `import json`):\n",
    "# train_metrics_path = pathlib.Path(path).joinpath(\"train_metrics.json\")\n",
    "# json.dump(train_metrics, open(train_metrics_path, \"w\"), indent=4)\n",
    "# test_metrics_path = pathlib.Path(path).joinpath(\"test_metrics.json\")\n",
    "# json.dump(test_metrics, open(test_metrics_path, \"w\"), indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e4d6633",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test metrics collected during training, keyed by epoch (one entry every\n",
    "# test_freq epochs).\n",
    "test_metrics"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pytorch]",
   "language": "python",
   "name": "conda-env-pytorch-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
