{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import pickle\n",
    "import torch\n",
    "\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ckpt_path = \"./models/NF.15/Visnav_VR_Expt/lateral/Neuroformer/predict_all_behavior/(state_history=6,_state=6,_stimulus=6,_behavior=6,_self_att=6,_modalities=(n_behavior=25))/25\"\n",
    "path = \"./models/NF.15/Visnav_VR_Expt/lateral/Neuroformer/predict_all_behavior/(state_history=6,_state=6,_stimulus=6,_behavior=6,_self_att=6,_modalities=(n_behavior=25))/25/inference/results_trial_sample-True_top_p-0.95_top_p_t-0.95_temp-1.0_temp_t-1.0_frame_end-0_true_past-False_get_dt-True_gpu-True_pred_dt-True.pkl\"\n",
    "\n",
    "with open(path, 'rb') as f:\n",
    "    results = pickle.load(f)\n",
    "\n",
    "print(results.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for key in results.keys():\n",
    "    if isinstance(results[key], list) and all(isinstance(i, (int, float)) for i in results[key]):\n",
    "        print(key)\n",
    "        results[key] = torch.tensor(results[key]).cpu().numpy()\n",
    "results['true'] = np.array([float(i.cpu()) for i in results['true']])\n",
    "# plot distribution of true and predicted values\n",
    "plt.figure(figsize=(20, 10))\n",
    "plt.subplot(1, 2, 1)\n",
    "plt.hist(results['true'], bins=100)\n",
    "plt.title('True values')\n",
    "\n",
    "plt.subplot(1, 2, 2)\n",
    "plt.hist(results['ID'], bins=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(results['ID'])\n",
    "len(results['true'])\n",
    "\n",
    "print(f\"len ID: {len(results['ID'])}, len true: {len(results['true'])}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# print the frequency and ID of top 10 most frequent predictions\n",
    "\n",
    "from collections import Counter\n",
    "\n",
    "true_topk = Counter(results['true']).most_common(20)\n",
    "pred_topk = Counter(results['ID']).most_common(20)\n",
    "\n",
    "# data frame with true and predicted\n",
    "df_topk = pd.DataFrame(true_topk, columns=['true', 'true_freq'])\n",
    "df_topk['pred'] = [i[0] for i in pred_topk]\n",
    "df_topk['pred_freq'] = [i[1] for i in pred_topk]\n",
    "df_topk\n",
    "\n",
    "# find common topk predictions\n",
    "common_topk = []\n",
    "for i in df_topk['true']:\n",
    "    if i in df_topk['pred'].values:\n",
    "        common_topk.append(i)\n",
    "\n",
    "print(f\"Common topk: {common_topk}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load_speed_predictions\n",
    "\n",
    "speed_path = os.path.join(ckpt_path, 'inference', 'behavior_preds_speed.csv')\n",
    "behavior_preds = pd.read_csv(speed_path)\n",
    "behavior_preds.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import pearsonr\n",
    "# from neuroformer.visualize import set_plot_params\n",
    "# from neuroformer.visualize import set_research_params\n",
    "model_name = \"Neuroformer\"\n",
    "\n",
    "\n",
    "x_true, y_true = behavior_preds['cum_interval'], behavior_preds['true']\n",
    "x_pred, y_pred = behavior_preds['cum_interval'], behavior_preds['behavior_speed_value']\n",
    "\n",
    "# pearson r\n",
    "r, p = pearsonr([float(y) for y in y_pred], [float(y) for y in y_true])\n",
    "\n",
    "# plot\n",
    "fig, ax = plt.subplots(figsize=(5, 5))\n",
    "ax.scatter(y_true, y_pred, s=100, c='k', alpha=0.5)\n",
    "\n",
    "# get the current axis limits after plotting your data\n",
    "xlims = ax.get_xlim()\n",
    "ylims = ax.get_ylim()\n",
    "s_f = 0.8\n",
    "# the line of perfect prediction should span the minimum to the maximum of the current x and y limits\n",
    "combined_limits = [min(xlims[0], ylims[0]) * s_f, max(xlims[1], ylims[1]) * s_f]\n",
    "ax.plot(combined_limits, combined_limits, 'k--', color='red')\n",
    "\n",
    "ax.set_xlabel('True speed', fontsize=20)\n",
    "ax.set_ylabel('Predicted speed', fontsize=20)\n",
    "ax.set_title(f'{model_name}, Regression', fontsize=20)\n",
    "# add pearson r to figure\n",
    "ax.text(0.05, 0.9, 'r = {:.2f}'.format(r), fontsize=20, transform=ax.transAxes)\n",
    "# add p to figure\n",
    "ax.text(0.05, 0.8, 'p < 0.001'.format(p), fontsize=20, transform=ax.transAxes)\n",
    "\n",
    "# axis limits = [-1.5, 1.5]\n",
    "# ax.set_xlim(axis_limits)\n",
    "# ax.set_ylim(axis_limits)\n",
    "# plt.savefig(os.path.join(save_path, 'regression_2.pdf'), dpi=300, bbox_inches='tight')\n",
    "\n",
    "\n",
    "# plot\n",
    "fig, ax = plt.subplots(figsize=(2.5, 2.5))\n",
    "ax.scatter(y_true, y_pred, c='k', alpha=0.5)\n",
    "\n",
    "# get the current axis limits after plotting your data\n",
    "xlims = ax.get_xlim()\n",
    "ylims = ax.get_ylim()\n",
    "s_f = 0.8\n",
    "# the line of perfect prediction should span the minimum to the maximum of the current x and y limits\n",
    "combined_limits = [min(xlims[0], ylims[0]) * s_f, max(xlims[1], ylims[1]) * s_f]\n",
    "ax.plot(combined_limits, combined_limits, 'k--', color='red')\n",
    "\n",
    "ax.set_xlabel('True speed',)\n",
    "ax.set_ylabel('Predicted speed',)\n",
    "ax.set_title(f'{model_name}, Regression',)\n",
    "# add pearson r to figure\n",
    "ax.text(0.05, 0.9, 'r = {:.2f}'.format(r), transform=ax.transAxes)\n",
    "# add p to figure\n",
    "ax.text(0.05, 0.8, 'p < 0.001'.format(p), transform=ax.transAxes)\n",
    "\n",
    "# axis limits = [-1.5, 1.5]\n",
    "# ax.set_xlim(axis_limits)\n",
    "# ax.set_ylim(axis_limits)\n",
    "# plt.savefig(os.path.join(save_path, 'regression_2.pdf'), dpi=300, bbox_inches='tight')\n",
    "\n",
    "\n",
    "# %%\n",
    "plt.figure(figsize=(5, 2.5))\n",
    "x = np.arange(len(behavior_preds))\n",
    "plt.title(f'Speed Predictions, {model_name} Regression vs. True')\n",
    "plt.plot(x, y_true, c='r', label='True')\n",
    "plt.plot(x, y_pred, c='b', label='Regression')\n",
    "plt.xlabel('Time (0.05s)')\n",
    "plt.ylabel('Speed (z-scored)')\n",
    "plt.legend(loc='upper left', framealpha=0.9)\n",
    "# plt.savefig(os.path.join(save_path, 'speed_preds.pdf'), bbox_inches='tight')\n",
    "\n",
    "\n",
    "plt.figure(figsize=(10, 5))\n",
    "x = np.arange(len(behavior_preds))\n",
    "plt.title(f'Speed Predictions, {model_name} Regression vs. True')\n",
    "plt.plot(x, y_true, c='r', label='True')\n",
    "plt.plot(x, y_pred, c='b', label='Regression')\n",
    "plt.xlabel('Time (0.05s)')\n",
    "plt.ylabel('Speed (z-scored)')\n",
    "plt.legend(loc='upper left', framealpha=0.9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "neuroformer",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
