{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import pandas as pd\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Experimental data\n",
    "\n",
    "*Roitman, JD and Shadlen, MN (2002), “Response of neurons in the lateral intraparietal area during a combined visual discrimination reaction time task”, Journal of Neuroscience, Vol. 22(21), 9475-9489.*\n",
    "\n",
    "\n",
    "#### Additional information\n",
    "- `coh` (column `coherence` in the cleaned CSV loaded below): coherence of the trial, multiplied by 10 (i.e. 32 means a coherence of 3.2%)\n",
    "- `correct` (presumably column `decision` in the cleaned CSV): whether the subject was correct (1 = correct, 0 = error)\n",
    "- `rt`: reaction time, presumably in ms (values range from 5 to 1762, most lie in [200, 1200])\n",
    "- In Shinn et al., animal N (coded as 0 below) is used for all main figures.\n",
    "- Number of trials per monkey and coherence level: n > 500 for animal 0, n > 400 for animal 1.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_PATH = \"../../../data/applications/roitman_data_clean.csv\"\n",
    "ddm_data = pd.read_csv(DATA_PATH)\n",
    "\n",
    "# fail early if the cleaned CSV lacks the expected columns (rt, coherence, decision, animal)\n",
    "missing_columns = {\"rt\", \"coherence\", \"decision\", \"animal\"} - set(ddm_data.columns)\n",
    "assert not missing_columns, f\"{DATA_PATH} is missing columns: {sorted(missing_columns)}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>rt</th>\n",
       "      <th>coherence</th>\n",
       "      <th>decision</th>\n",
       "      <th>animal</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>464.0</td>\n",
       "      <td>256.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>318.0</td>\n",
       "      <td>64.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>531.0</td>\n",
       "      <td>128.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>567.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>398.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6143</th>\n",
       "      <td>743.0</td>\n",
       "      <td>64.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6144</th>\n",
       "      <td>704.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6145</th>\n",
       "      <td>490.0</td>\n",
       "      <td>512.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6146</th>\n",
       "      <td>558.0</td>\n",
       "      <td>256.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6147</th>\n",
       "      <td>690.0</td>\n",
       "      <td>128.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>6148 rows × 4 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "         rt  coherence  decision  animal\n",
       "0     464.0      256.0       1.0     0.0\n",
       "1     318.0       64.0       1.0     0.0\n",
       "2     531.0      128.0       1.0     0.0\n",
       "3     567.0        0.0       1.0     0.0\n",
       "4     398.0        0.0       1.0     0.0\n",
       "...     ...        ...       ...     ...\n",
       "6143  743.0       64.0       1.0     1.0\n",
       "6144  704.0        0.0       1.0     1.0\n",
       "6145  490.0      512.0       1.0     1.0\n",
       "6146  558.0      256.0       1.0     1.0\n",
       "6147  690.0      128.0       1.0     1.0\n",
       "\n",
       "[6148 rows x 4 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# pandas truncates the display to the first and last rows of the frame\n",
    "ddm_data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
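    "# Generative model\n",
    "Fit drift-diffusion models (DDMs) to the filtered data with the *pyddm* toolbox, then sample synthetic trials from the fitted models:\n",
    "\n",
    "https://pyddm.readthedocs.io/en/stable/"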
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# install into the kernel's own environment; pinned to the pyddm version this notebook was run with\n",
    "%pip install -q pyddm==0.7.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pyddm model components, fitting routine and loss function (each name imported once)\n",
    "from pyddm import Fittable, Model\n",
    "from pyddm.functions import fit_adjust_model\n",
    "from pyddm.models import (\n",
    "    BoundCollapsingExponential,\n",
    "    BoundConstant,\n",
    "    DriftConstant,\n",
    "    DriftLinear,\n",
    "    ICPointSourceCenter,\n",
    "    LossRobustBIC,\n",
    "    NoiseConstant,\n",
    "    OverlayNonDecision,\n",
    ")\n",
    "\n",
    "# local helper: filters the Roitman data and puts it into the format needed by pyddm\n",
    "from roitman_utils import filter_roitman_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Condition to fit. Coherence is stored x10, so 128 means 12.8 % coherence;\n",
    "# animal 0 is monkey N (the animal used for the main figures in Shinn et al.).\n",
    "COHERENCE = 128\n",
    "ANIMAL = 0\n",
    "\n",
    "data = filter_roitman_data(ddm_data,\n",
    "                           coherence=COHERENCE,\n",
    "                           animal=ANIMAL,\n",
    "                           n_trial=\"all\",\n",
    "                           attach_model_mask=False,\n",
    "                           partition=None,\n",
    "                           data_mode='pyddm')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DDM 1: linear drift, exponentially collapsing bound"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Info: Params [  2.36556079 -15.36818391   0.72002534   0.99788879   0.15029011] gave 97.07400176299583\n"
     ]
    }
   ],
   "source": [
    "# DDM 1: drift = constant + coefficient x times the particle position (time coefficient t fixed to 0),\n",
    "# with an exponentially collapsing bound; drift, x, B, tau and nondectime are fitted.\n",
    "model_fit = Model(name='DDM 1: linear drift, collapsing bound (fitted)',\n",
    "                  drift=DriftLinear(drift=Fittable(minval=0, maxval=5), t=0,\n",
    "                                    x=Fittable(minval=-20, maxval=10)),\n",
    "                  noise=NoiseConstant(noise=1),\n",
    "                  bound=BoundCollapsingExponential(B=Fittable(minval=0.5, maxval=4),\n",
    "                                                   tau=Fittable(minval=0.1, maxval=4)),\n",
    "                  overlay=OverlayNonDecision(nondectime=Fittable(minval=0.1, maxval=0.4)),\n",
    "                  IC=ICPointSourceCenter(),\n",
    "                  dx=.001, dt=.01, T_dur=2)\n",
    "\n",
    "# fit the free (Fittable) parameters to the filtered data with a robust BIC loss\n",
    "fit_adjust_model(data, model_fit,\n",
    "                 fitting_method=\"differential_evolution\",\n",
    "                 lossfunction=LossRobustBIC, verbose=False)\n",
    "\n",
    "# solve the fitted model; the solution is sampled from below\n",
    "sol = model_fit.solve()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# draw 1000 synthetic trials from the fitted DDM 1; explicit seed for reproducibility\n",
    "generated_sample = sol.resample(k=1000, seed=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert pyddm samples to float32 RT tensors with a trailing singleton axis:\n",
    "# upper-boundary (correct) RTs first, then lower-boundary (error) RTs; the choice label is not kept.\n",
    "# The pyddm sample keeps its own name, so re-running this cell does not hit an already-converted tensor.\n",
    "generated_data_corr = torch.tensor(generated_sample.choice_upper, dtype=torch.float32)\n",
    "generated_data_err = torch.tensor(generated_sample.choice_lower, dtype=torch.float32)\n",
    "generated_data = torch.cat([generated_data_corr, generated_data_err]).unsqueeze(-1)\n",
    "\n",
    "real_data_corr = torch.tensor(data.choice_upper, dtype=torch.float32)\n",
    "real_data_err = torch.tensor(data.choice_lower, dtype=torch.float32)\n",
    "real_data = torch.cat([real_data_corr, real_data_err]).unsqueeze(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "ddm_dir = \"../../../data/applications/ddm\"\n",
    "os.makedirs(ddm_dir, exist_ok=True)  # torch.save does not create missing directories\n",
    "torch.save(generated_data, os.path.join(ddm_dir, \"generated_data.pt\"))\n",
    "torch.save(real_data, os.path.join(ddm_dir, \"real_data.pt\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DDM 2: constant drift, constant bound"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Info: Params [1.82345925 1.0614431  0.17550922] gave 530.9352220973424\n"
     ]
    }
   ],
   "source": [
    "# DDM 2: constant drift and constant (non-collapsing) bound; drift, B and nondectime are fitted\n",
    "model_fit = Model(name='DDM 2: constant drift, constant bound (fitted)',\n",
    "                  drift=DriftConstant(drift=Fittable(minval=0, maxval=5)),\n",
    "                  noise=NoiseConstant(noise=1),\n",
    "                  bound=BoundConstant(B=Fittable(minval=0.5, maxval=5)),\n",
    "                  overlay=OverlayNonDecision(nondectime=Fittable(minval=0.1, maxval=0.4)),\n",
    "                  IC=ICPointSourceCenter(),\n",
    "                  dx=.001, dt=.01, T_dur=2)\n",
    "\n",
    "# fit the free (Fittable) parameters to the filtered data with a robust BIC loss\n",
    "fit_adjust_model(data, model_fit,\n",
    "                 fitting_method=\"differential_evolution\",\n",
    "                 lossfunction=LossRobustBIC, verbose=False)\n",
    "\n",
    "# solve the fitted model; the solution is sampled from below\n",
    "sol = model_fit.solve()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# draw 1000 synthetic trials from the fitted DDM 2; explicit seed for reproducibility\n",
    "generated_sample = sol.resample(k=1000, seed=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the pyddm sample to a float32 RT tensor with a trailing singleton axis:\n",
    "# upper-boundary (correct) RTs first, then lower-boundary (error) RTs; the choice label is not kept.\n",
    "# The pyddm sample keeps its own name, so re-running this cell does not hit an already-converted tensor.\n",
    "generated_data_corr = torch.tensor(generated_sample.choice_upper, dtype=torch.float32)\n",
    "generated_data_err = torch.tensor(generated_sample.choice_lower, dtype=torch.float32)\n",
    "generated_data = torch.cat([generated_data_corr, generated_data_err]).unsqueeze(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "ddm_dir = \"../../../data/applications/ddm\"\n",
    "os.makedirs(ddm_dir, exist_ok=True)  # torch.save does not create missing directories\n",
    "torch.save(generated_data, os.path.join(ddm_dir, \"generated_data2.pt\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
