{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import os\n",
    "\n",
    "import sys\n",
    "import glob\n",
    "from pathlib import Path, PurePath\n",
    "# Extend the import path with the `neuroformer` directory under the grandparent\n",
    "# of the working directory, plus a relative `neuroformer` entry.\n",
    "path = Path.cwd()\n",
    "parent_path = path.parents[1]\n",
    "sys.path.extend([str(parent_path / 'neuroformer'), 'neuroformer'])\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.stats import pearsonr\n",
    "from torch.utils.data.dataloader import DataLoader\n",
    "\n",
    "from neuroformer.model_neuroformer import load_model_and_tokenizer\n",
    "from neuroformer.utils import get_attr\n",
    "from neuroformer.utils import (set_seed, running_jupyter, \n",
    "                                 all_device, recursive_print,\n",
    "                                 create_modalities_dict)\n",
    "from neuroformer.datasets import load_visnav, load_V1AL\n",
    "\n",
    "# Rebind parent_path as a plain string: the grandparent of the working directory\n",
    "# followed by a trailing slash (it was assigned a pathlib path earlier in this cell).\n",
    "parent_path = os.path.dirname(os.path.dirname(os.getcwd())) + \"/\"\n",
    "\n",
    "# Set up logging: INFO level, with a timestamp, level name and logger name\n",
    "# in front of every message.\n",
    "import logging\n",
    "logging.basicConfig(\n",
    "        format=\"%(asctime)s - %(levelname)s - %(name)s -   %(message)s\",\n",
    "        datefmt=\"%m/%d/%Y %H:%M:%S\",\n",
    "        level=logging.INFO,\n",
    ")\n",
    "\n",
    "from neuroformer.default_args import DefaultArgs, parse_args\n",
    "\n",
    "# Resolve the run arguments: hard-coded defaults when executed inside Jupyter,\n",
    "# otherwise whatever parse_args() returns.\n",
    "if running_jupyter(): # or __name__ == \"__main__\":\n",
    "    print(\"Running in Jupyter\")\n",
    "    args = DefaultArgs()\n",
    "    # Alternative (medial) checkpoint; uncomment these two lines to use it instead.\n",
    "    # args.dataset = \"medial\"\n",
    "    # args.ckpt_path = \"./models/NF.15/Visnav_VR_Expt/medial/Neuroformer/predict_all_behavior/(state_history=6,_state=6,_stimulus=6,_behavior=6,_self_att=6,_modalities=(n_behavior=25))/25\"\n",
    "\n",
    "    args.dataset = \"lateral\"\n",
    "    args.ckpt_path = \"./models/NF.15/Visnav_VR_Expt/lateral/Neuroformer/all_input_contrastive/(state_history=6,_state=6,_stimulus=6,_behavior=6,_self_att=6,_modalities=(n_behavior=25))/25/\"\n",
    "    args.predict_modes = ['speed', 'phi', 'th']\n",
    "else:\n",
    "    print(\"Running in terminal\")\n",
    "    args = parse_args()\n",
    "\n",
    "\n",
    "# SET SEED - VERY IMPORTANT (do this before any model or data code runs)\n",
    "set_seed(args.seed)\n",
    "\n",
    "# Echo the flags that characterise this run.\n",
    "print(f\"CONTRASTIVE: {args.contrastive}\")\n",
    "print(f\"VISUAL: {args.visual}\")\n",
    "print(f\"PAST_STATE: {args.past_state}\")\n",
    "\n",
    "\n",
    "# Load the config, tokenizer and model from the checkpoint path chosen above.\n",
    "config, tokenizer, model = load_model_and_tokenizer(args.ckpt_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\" \n",
    "\n",
    "-- DATA --\n",
    "neuroformer/data/OneCombo3_V1AL/\n",
    "df = response\n",
    "video_stack = stimulus\n",
    "DOWNLOAD DATA URL = https://drive.google.com/drive/folders/1jNvA4f-epdpRmeG9s2E-2Sfo-pwYbjeY?usp=sharing\n",
    "\n",
    "\n",
    "\"\"\"\n",
    "print(f\"DATASET: {args.dataset}\")\n",
    "\n",
    "# Each loader returns the data dict, the full interval list, the train / test /\n",
    "# finetune interval splits, and a callback (used further down for the frames).\n",
    "if args.dataset in [\"lateral\", \"medial\"]:\n",
    "    # `selection` is optional on the config; pass None when it is absent.\n",
    "    (data, intervals, train_intervals,\n",
    "     test_intervals, finetune_intervals,\n",
    "     callback) = load_visnav(args.dataset, config,\n",
    "                             selection=getattr(config, \"selection\", None))\n",
    "elif args.dataset == \"V1AL\":\n",
    "    (data, intervals, train_intervals,\n",
    "     test_intervals, finetune_intervals,\n",
    "     callback) = load_V1AL(config)\n",
    "else:\n",
    "    # Fail here with a clear message instead of a NameError on `data` below.\n",
    "    raise ValueError(f\"Unsupported dataset {args.dataset!r}; expected 'lateral', 'medial' or 'V1AL'\")\n",
    "\n",
    "spikes = data['spikes']\n",
    "stimulus = data['stimulus']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Window and resolution settings pulled out of the loaded config.\n",
    "window = config.window.curr\n",
    "window_prev = config.window.prev\n",
    "dt = config.resolution.dt\n",
    "\n",
    "\n",
    "# -------- #\n",
    "\n",
    "# Spike data plus block-size / window settings; this dict is the first argument\n",
    "# given to NFDataloader when the datasets are built.\n",
    "spikes_dict = {\n",
    "    \"ID\": spikes,\n",
    "    \"Frames\": stimulus,\n",
    "    \"Interval\": intervals,\n",
    "    \"dt\": dt,\n",
    "    \"id_block_size\": config.block_size.id,\n",
    "    \"prev_id_block_size\": config.block_size.prev_id,\n",
    "    \"frame_block_size\": config.block_size.frame,\n",
    "    \"window\": window,\n",
    "    \"window_prev\": window_prev,\n",
    "    \"frame_window\": config.window.frame,\n",
    "}\n",
    "\n",
    "# Kept as comments rather than a bare string so the cell does not echo it as output.\n",
    "# Expected structure (presumably of the modalities dict -- TODO confirm):\n",
    "# {\n",
    "#     type_of_modality:\n",
    "#         {name of modality: {'data':data, 'dt': dt, 'predict': True/False},\n",
    "#         ...\n",
    "#         }\n",
    "#     ...\n",
    "# }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the configured frame window as a quick sanity check.\n",
    "config.window.frame"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from neuroformer.data_utils import NFDataloader\n",
    "\n",
    "modalities = create_modalities_dict(data, config.modalities)\n",
    "frames = {\n",
    "    'feats': stimulus,\n",
    "    'callback': callback,\n",
    "    'window': config.window.frame,\n",
    "    'dt': config.resolution.dt,\n",
    "}\n",
    "\n",
    "\n",
    "def _make_dataset(split_intervals):\n",
    "    \"\"\"Build an NFDataloader over `split_intervals`; every other input is shared.\"\"\"\n",
    "    return NFDataloader(spikes_dict, tokenizer, config, dataset=args.dataset,\n",
    "                        frames=frames, intervals=split_intervals, modalities=modalities)\n",
    "\n",
    "\n",
    "train_dataset = _make_dataset(train_intervals)\n",
    "test_dataset = _make_dataset(test_intervals)\n",
    "finetune_dataset = _make_dataset(finetune_intervals)\n",
    "\n",
    "# print(f'train: {len(train_dataset)}, test: {len(test_dataset)}')\n",
    "# Take one sample straight from the training dataset and print its inputs.\n",
    "iterable = iter(train_dataset)\n",
    "x, y = next(iterable)\n",
    "recursive_print(x)\n",
    "\n",
    "# Update the config\n",
    "config.id_vocab_size = tokenizer.ID_vocab_size\n",
    "\n",
    "# Smoke test: run one shuffled batch of two test samples through the model.\n",
    "loader = DataLoader(test_dataset, batch_size=2, shuffle=True, num_workers=0)\n",
    "iterable = iter(loader)\n",
    "x, y = next(iterable)\n",
    "recursive_print(y)\n",
    "preds, features, loss = model(x, y)\n",
    "\n",
    "# Set training parameters\n",
    "MAX_EPOCHS, BATCH_SIZE, SHUFFLE = 300, 32 * 5, True\n",
    "\n",
    "# Pick the display name from the first architecture flag that is set.\n",
    "model_name = (\n",
    "    \"GRU\" if config.gru_only\n",
    "    else \"MLP\" if config.mlp_only\n",
    "    else \"GRU_2.0\" if config.gru2_only\n",
    "    else \"Neuroformer\"\n",
    ")\n",
    "\n",
    "CKPT_PATH = args.ckpt_path\n",
    "\n",
    "# Define the parameters (consumed by the commented-out generate_spikes call below)\n",
    "sample = True\n",
    "top_p = top_p_t = 0.95\n",
    "temp = temp_t = 1.\n",
    "frame_end = 0\n",
    "true_past = args.true_past\n",
    "get_dt = gpu = pred_dt = True\n",
    "\n",
    "# # Run the prediction function\n",
    "# results_trial = generate_spikes(model, test_dataset, window, \n",
    "#                                 window_prev, tokenizer, \n",
    "#                                 sample=sample, top_p=top_p, top_p_t=top_p_t, \n",
    "#                                 temp=temp, temp_t=temp_t, frame_end=frame_end, \n",
    "#                                 true_past=true_past,\n",
    "#                                 get_dt=get_dt, gpu=gpu, pred_dt=pred_dt,\n",
    "#                                 plot_probs=False)\n",
    "\n",
    "# # Create a filename string with the parameters\n",
    "# filename = f\"results_trial_sample-{sample}_top_p-{top_p}_top_p_t-{top_p_t}_temp-{temp}_temp_t-{temp_t}_frame_end-{frame_end}_true_past-{true_past}_get_dt-{get_dt}_gpu-{gpu}_pred_dt-{pred_dt}.pkl\"\n",
    "\n",
    "# # Save the results in a pickle file\n",
    "# save_inference_path = os.path.join(CKPT_PATH, \"inference\")\n",
    "# if not os.path.exists(save_inference_path):\n",
    "#     os.makedirs(save_inference_path)\n",
    "\n",
    "# print(f\"Saving inference results in {os.path.join(save_inference_path, filename)}\")\n",
    "\n",
    "# with open(os.path.join(save_inference_path, filename), \"wb\") as f:\n",
    "#     pickle.dump(results_trial, f)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# predict other modality\n",
    "from neuroformer.simulation import decode_modality\n",
    "\n",
    "# Load the saved weights (model.pt in the checkpoint directory) onto the CPU.\n",
    "# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.\n",
    "# model.load_state_dict(torch.load(os.path.join(CKPT_PATH, f\"_epoch_speed.pt\"), map_location=torch.device('cpu')))\n",
    "model.load_state_dict(torch.load(os.path.join(CKPT_PATH, \"model.pt\"), map_location=torch.device('cpu')))\n",
    "\n",
    "# NOTE(review): this assignment overrides whatever predict_modes the arguments carried.\n",
    "args.predict_modes = ['speed', 'phi', 'th']\n",
    "behavior_preds = {}\n",
    "if args.predict_modes is not None:\n",
    "    block_type = 'behavior'\n",
    "    block_config = get_attr(config.modalities, block_type).variables\n",
    "\n",
    "    # Every per-mode CSV goes into <checkpoint>/inference; create it once up front.\n",
    "    save_inference_path = os.path.join(CKPT_PATH, \"inference\")\n",
    "    os.makedirs(save_inference_path, exist_ok=True)\n",
    "\n",
    "    for mode in args.predict_modes:\n",
    "        mode_config = get_attr(block_config, mode)\n",
    "        behavior_preds[mode] = decode_modality(model, test_dataset, modality=mode,\n",
    "                                               block_type=block_type, objective=get_attr(mode_config, 'objective'))\n",
    "        filename = f\"behavior_preds_{mode}.csv\"\n",
    "        print(f\"Saving inference results in {os.path.join(save_inference_path, filename)}\")\n",
    "        # Save this mode's result only: behavior_preds itself is a dict and has no to_csv.\n",
    "        # Assumes decode_modality returns a pandas DataFrame -- TODO confirm.\n",
    "        behavior_preds[mode].to_csv(os.path.join(save_inference_path, filename))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from neuroformer.visualize import plot_regression\n",
    "\n",
    "# One panel per predicted mode. squeeze=False makes subplots return a 2-D Axes\n",
    "# array even for a single mode; ravel() flattens it so ax[n] is always valid.\n",
    "fig, ax = plt.subplots(figsize=(17.5, 5), nrows=1, ncols=len(args.predict_modes), squeeze=False)\n",
    "ax = ax.ravel()\n",
    "plt.suptitle(f'Visnav {args.dataset} Multitask Decoding - Speed + Eye Gaze (phi, th)', fontsize=20, y=1.01)\n",
    "colors = ['limegreen', 'royalblue', 'darkblue']  # Define your colors here\n",
    "\n",
    "for n, mode in enumerate(args.predict_modes):\n",
    "    behavior_preds_mode = behavior_preds[mode]\n",
    "    x_true, y_true = behavior_preds_mode['cum_interval'][:200], behavior_preds_mode['true'][:200]  # Limit to 200 examples\n",
    "    x_pred, y_pred = behavior_preds_mode['cum_interval'][:200], behavior_preds_mode[f'behavior_{mode}_value'][:200]  # Limit to 200 examples\n",
    "    # Pearson correlation between predictions and ground truth, after casting to plain floats.\n",
    "    r, p = pearsonr([float(v) for v in y_pred], [float(v) for v in y_true])\n",
    "    axis = ax[n]\n",
    "    # Cycle through the palette so more modes than colors cannot raise an IndexError.\n",
    "    plot_regression(y_true, y_pred, mode, model_name, r, p, ax=axis, color=colors[n % len(colors)], save_path=args.ckpt_path)  # Use color"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "neuroformer",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
