{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sbi_mcmc.utils.experiment_utils import *\n",
    "from sbi_mcmc.utils.utils import read_from_file, save_to_file\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_as_chunks(test_dataset, save_dir, chunk_size=5000):\n",
    "    \"\"\"\n",
    "    Splits test_dataset into chunks and saves each chunk as a separate file.\n",
    "\n",
    "    If test_dataset is a dictionary, it will chunk each value across keys consistently.\n",
    "    \"\"\"\n",
    "\n",
    "    def get_length(obj):\n",
    "        if isinstance(obj, dict):\n",
    "            return len(next(iter(obj.values())))\n",
    "        return len(obj)\n",
    "\n",
    "    def get_chunk(obj, start, end):\n",
    "        if isinstance(obj, dict):\n",
    "            return {k: v[start:end] for k, v in obj.items()}\n",
    "        return obj[start:end]\n",
    "\n",
    "    num_chunks = (\n",
    "        get_length(test_dataset) + chunk_size - 1\n",
    "    ) // chunk_size  # Ceiling division\n",
    "\n",
    "    for i in range(num_chunks):\n",
    "        start = i * chunk_size\n",
    "        end = min((i + 1) * chunk_size, get_length(test_dataset))\n",
    "        chunk = get_chunk(test_dataset, start, end)\n",
    "        file_path = save_dir / f\"test_dataset_chunk_{i + 1}.pkl\"\n",
    "        save_to_file(chunk, file_path)"
   ]
  },
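  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of reading the chunks back and re-assembling them, assuming the `test_dataset_chunk_{i}.pkl` naming used above and `read_from_file` from `sbi_mcmc.utils.utils` (array case only; dict-valued datasets would need per-key concatenation):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_chunks(save_dir):\n",
    "    \"\"\"Sketch: re-assemble array chunks written by save_as_chunks.\"\"\"\n",
    "    chunks = []\n",
    "    i = 1\n",
    "    while True:\n",
    "        path = save_dir / f\"test_dataset_chunk_{i}.pkl\"\n",
    "        if not path.exists():\n",
    "            break\n",
    "        chunks.append(read_from_file(path))\n",
    "        i += 1\n",
    "    return np.concatenate(chunks, axis=0)"
   ]
  },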
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Prepare the test datasets for different tasks."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GEV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pymc as pm\n",
    "import pymc_extras.distributions as pmx\n",
    "from sbi_mcmc.tasks import PyMCTask\n",
    "\n",
    "\n",
    "# GEV with wider priors\n",
    "class GeneralizedExtremeValueWide(PyMCTask):\n",
    "    def __init__(self):\n",
    "        var_names = [\"mu\", \"sigma\", \"xi\"]\n",
    "        task_name = \"GEV\"\n",
    "        super().__init__(\n",
    "            var_names=var_names,\n",
    "            task_name=task_name,\n",
    "        )\n",
    "\n",
    "    def setup_pymc_model(self, observation=None) -> pm.Model:\n",
    "        if observation is None:\n",
    "            observation = self.get_observation()\n",
    "        p = 1 / 10\n",
    "        with pm.Model() as pymc_model:\n",
    "            # Priors\n",
    "            mu = pm.Normal(\"mu\", mu=3.8, sigma=0.4)\n",
    "            sigma = pm.HalfNormal(\"sigma\", sigma=0.6)\n",
    "            xi = pm.TruncatedNormal(\n",
    "                \"xi\", mu=0, sigma=0.4, lower=-1.2, upper=1.2\n",
    "            )\n",
    "\n",
    "            # Estimation\n",
    "            gev = pmx.GenExtreme(\n",
    "                \"gev\",\n",
    "                mu=mu,\n",
    "                sigma=sigma,\n",
    "                xi=xi,\n",
    "                observed=observation,\n",
    "            )\n",
    "            # Return level\n",
    "            z_p = pm.Deterministic(\n",
    "                \"z_p\", mu - sigma / xi * (1 - (-np.log(1 - p)) ** (-xi))\n",
    "            )\n",
    "        return pymc_model\n",
    "\n",
    "    def get_observation(self):\n",
    "        # fmt: off\n",
    "        data = np.array([4.03, 3.83, 3.65, 3.88, 4.01, 4.08, 4.18, 3.80,\n",
    "                        4.36, 3.96, 3.98, 4.69, 3.85, 3.96, 3.85, 3.93,\n",
    "                        3.75, 3.63, 3.57, 4.25, 3.97, 4.05, 4.24, 4.22,\n",
    "                        3.73, 4.37, 4.06, 3.71, 3.96, 4.06, 4.55, 3.79,\n",
    "                        3.89, 4.11, 3.85, 3.86, 3.86, 4.21, 4.01, 4.11,\n",
    "                        4.24, 3.96, 4.21, 3.74, 3.85, 3.88, 3.66, 4.11,\n",
    "                        3.71, 4.18, 3.90, 3.78, 3.91, 3.72, 4.00, 3.66,\n",
    "                        3.62, 4.33, 4.55, 3.75, 4.08, 3.90, 3.88, 3.94,\n",
    "                        4.33])\n",
    "        # fmt: on\n",
    "        return data\n",
    "\n",
    "\n",
    "task = GeneralizedExtremeValueWide()\n",
    "paths = get_paths(task)\n",
    "test_dataset = task.sample(1000)\n",
    "\n",
    "save_to_file(\n",
    "    test_dataset,\n",
    "    paths[\"dataset_dir\"] / \"test_dataset_chunk_1.pkl\",\n",
    ")"
   ]
  },
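  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick prior predictive sanity check on the widened priors; a minimal sketch (the draw count and binning are arbitrary, and non-finite return-level draws are filtered out before plotting):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: draw from the prior predictive to eyeball the widened priors.\n",
    "# z_p is the return level registered as a Deterministic above.\n",
    "with task.setup_pymc_model():\n",
    "    prior_idata = pm.sample_prior_predictive(500)\n",
    "\n",
    "z_p_draws = prior_idata.prior[\"z_p\"].values.flatten()\n",
    "plt.hist(z_p_draws[np.isfinite(z_p_draws)], bins=50)\n",
    "plt.xlabel(\"z_p (prior draws)\")\n",
    "plt.show()"
   ]
  },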
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bernoulli GLM "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "task = BernoulliGLMTask()\n",
    "paths = get_paths(task)\n",
    "\n",
    "test_dataset = task.sample(10000)\n",
    "\n",
    "save_to_file(\n",
    "    test_dataset,\n",
    "    paths[\"dataset_dir\"] / \"test_dataset.pkl\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_dataset = read_from_file(paths[\"dataset_dir\"] / \"test_dataset.pkl\")\n",
    "save_as_chunks(test_dataset, paths[\"dataset_dir\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Psychometric Task\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from brainbox.behavior.training import (\n",
    "    compute_performance,\n",
    "    compute_psychometric,\n",
    "    get_signed_contrast,\n",
    ")\n",
    "from brainbox.io.one import SessionLoader\n",
    "from one.api import ONE\n",
    "\n",
    "ONE.setup(base_url=\"https://openalyx.internationalbrainlab.org\", silent=True)\n",
    "one = ONE(password=\"international\")  # noqa"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sbi_mcmc.tasks import PsychometricTask\n",
    "\n",
    "task = PsychometricTask(overdispersion=True)\n",
    "paths = get_paths(task)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We provided the csv directly so downloading and pre-processing is not necessary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Download the data\n",
    "# eid = \"4ecb5d24-f5cc-402c-be28-9d0f7cb14b3a\"\n",
    "# trials = one.load_object(eid, \"trials\", collection=\"alf\")\n",
    "# eids = one.search(task=\"biasedChoiceWorld\")\n",
    "# print(len(eids))\n",
    "# errors = []\n",
    "# for eid in eids:\n",
    "#     try:\n",
    "#         sl = SessionLoader(eid=eid, one=one)\n",
    "#         sl.load_trials()\n",
    "#     except Exception as e:\n",
    "#         errors.append((eid, e))\n",
    "# print(errors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Create a DataFrame\n",
    "# eids = one.search(task=\"biasedChoiceWorld\")\n",
    "# data = []\n",
    "# for i in tqdm(range(len(eids))):\n",
    "#     eid = eids[i]\n",
    "#     try:\n",
    "#         sl = SessionLoader(eid=eid, one=one)\n",
    "#         sl.load_trials()\n",
    "#     except Exception as e:\n",
    "#         print(e)\n",
    "#         continue\n",
    "#     trials = sl.trials\n",
    "#     for block in [0.5, 0.2, 0.8]:\n",
    "#         signed_contrast = get_signed_contrast(trials)\n",
    "#         trials[\"signed_contrast\"] = signed_contrast\n",
    "#         performance, contrasts, n_contrasts = compute_performance(\n",
    "#             trials,\n",
    "#             signed_contrast=signed_contrast,\n",
    "#             block=block,\n",
    "#             prob_right=True,\n",
    "#         )\n",
    "#         if np.isnan(n_contrasts).any():\n",
    "#             continue\n",
    "#         data.append(\n",
    "#             {\n",
    "#                 \"eid\": eid,\n",
    "#                 \"block\": block,\n",
    "#                 \"performance\": performance.tolist(),\n",
    "#                 \"contrasts\": contrasts.tolist(),\n",
    "#                 \"n_contrasts\": n_contrasts.tolist(),\n",
    "#             }\n",
    "#         )\n",
    "# # Convert to DataFrame\n",
    "# df = pd.DataFrame(data)\n",
    "\n",
    "# # Save as CSV\n",
    "# csv_path = paths[\"dataset_dir\"] / \"psychometric_data.csv\"\n",
    "# df.to_csv(csv_path, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ast\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "csv_path = task.task_info_dir / \"psychometric_data.csv\"\n",
    "# Read the csv file\n",
    "df = pd.read_csv(\n",
    "    csv_path,\n",
    "    converters={\n",
    "        \"performance\": ast.literal_eval,\n",
    "        \"n_contrasts\": ast.literal_eval,\n",
    "        \"contrasts\": ast.literal_eval,\n",
    "    },\n",
    ")\n",
    "df[\"trial_count\"] = df[\"n_contrasts\"].apply(sum)\n",
    "# Convert lists to numpy arrays and multiply elementwise\n",
    "df[\"right_count\"] = df.apply(\n",
    "    lambda row: [\n",
    "        round(p * n)\n",
    "        for p, n in zip(row[\"performance\"], row[\"n_contrasts\"], strict=True)\n",
    "    ],\n",
    "    axis=1,\n",
    ")\n",
    "df = df[[\"contrasts\", \"n_contrasts\", \"right_count\"]]\n",
    "\n",
    "# Convert lists to tuples to make them hashable and count occurrences\n",
    "unique_contrasts = df[\"contrasts\"].apply(tuple)\n",
    "unique_counts = unique_contrasts.value_counts()\n",
    "print(unique_counts)\n",
    "\n",
    "# For simplicity, only keep rows with contrasts=(-100.0, -25.0, -12.5, -6.25, 0.0, 6.25, 12.5, 25.0, 100.0)\n",
    "df = df[\n",
    "    df[\"contrasts\"].apply(tuple)\n",
    "    == (-100.0, -25.0, -12.5, -6.25, 0.0, 6.25, 12.5, 25.0, 100.0)\n",
    "]"
   ]
  },
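  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick consistency check on the recovered counts; a sketch guarding the rounding above (rightward-choice counts must be valid binomial counts):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: right_count must lie in [0, n] for every contrast level.\n",
    "bad_rows = df.apply(\n",
    "    lambda row: any(\n",
    "        r < 0 or r > n\n",
    "        for r, n in zip(row[\"right_count\"], row[\"n_contrasts\"], strict=True)\n",
    "    ),\n",
    "    axis=1,\n",
    ")\n",
    "assert not bad_rows.any()"
   ]
  },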
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the columns of df and convert to numpy array of shape (N_dataset, N_contrasts, 3)\n",
    "# Each dataset contain [contrast, n_trials, n_right] for each contrast level\n",
    "test_dataset = np.stack(\n",
    "    [\n",
    "        df[\"contrasts\"].tolist(),\n",
    "        df[\"n_contrasts\"].tolist(),\n",
    "        df[\"right_count\"].tolist(),\n",
    "    ],\n",
    "    axis=-1,\n",
    ")\n",
    "test_dataset[..., 0] /= 100  # scale the contrast to be between -1 and 1\n",
    "save_to_file(test_dataset, paths[\"dataset_dir\"] / \"test_dataset.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sbi_mcmc.utils.utils import read_from_file\n",
    "\n",
    "test_dataset = read_from_file(paths[\"dataset_dir\"] / \"test_dataset.pkl\")\n",
    "print(test_dataset.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_as_chunks(test_dataset, paths[\"dataset_dir\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DDM task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sbi_mcmc.tasks.ddm import CustomDDM\n",
    "\n",
    "task = CustomDDM(dt=0.0001)\n",
    "paths = get_paths(task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = pd.read_csv(task.task_info_dir / \"stan_data.csv\").rename(\n",
    "    columns={\"id\": \"subject\"}\n",
    ")\n",
    "ids = data[\"subject\"].unique()\n",
    "print(len(ids))\n",
    "\n",
    "\n",
    "def process_data_block(data, condition_val=0):\n",
    "    \"\"\"Process a block of data (congruent or incongruent) into required format.\n",
    "\n",
    "    Args:\n",
    "        data: DataFrame with rt and response columns\n",
    "        condition_val: 0 for congruent, 1 for incongruent condition\n",
    "\n",
    "    Returns:\n",
    "        Processed numpy array with shape (60, 4) containing:\n",
    "        [rt, missing flag, condition type, response]\n",
    "    \"\"\"\n",
    "    missing = np.zeros(60)\n",
    "    condition_type = np.ones(60) * condition_val\n",
    "    data = data[[\"rt\", \"response\"]].values\n",
    "    if data.shape[0] < 60:\n",
    "        missing[data.shape[0] :] = 1\n",
    "        data = np.concatenate(\n",
    "            [data, np.zeros((60 - data.shape[0], 2))], axis=0\n",
    "        )\n",
    "\n",
    "    return np.stack([data[:, 0], missing, condition_type], axis=1)\n",
    "\n",
    "\n",
    "test_dataset = []\n",
    "num_missing_trials = {}\n",
    "for subject_id in tqdm(ids[:]):\n",
    "    data_1p = data[(data[\"subject\"] == subject_id) & (data[\"rt\"] > 0)]\n",
    "    # Replace response 0 with -1\n",
    "    data_1p.loc[:, \"response\"] = data_1p[\"response\"].replace(0, -1)\n",
    "    data_1p.loc[:, \"rt\"] = data_1p[\"rt\"] * data_1p[\"response\"]\n",
    "    data_c = data_1p[data_1p[\"block\"] == 1]  # congruent\n",
    "    data_i = data_1p[data_1p[\"block\"] == 0]  # incongruent\n",
    "\n",
    "    # Process congruent and incongruent blocks\n",
    "    data_c = process_data_block(data_c, condition_val=0)\n",
    "    data_i = process_data_block(data_i, condition_val=1)\n",
    "\n",
    "    # Concatenate the data\n",
    "    observation = np.concatenate([data_c, data_i], axis=0)\n",
    "    if np.sum(observation[:, 1]) > 0:\n",
    "        num_missing_trials[subject_id] = np.sum(observation[:, 1])\n",
    "\n",
    "    # code stimulus types, picture == 1\n",
    "    stimulus_type = np.concatenate(\n",
    "        (\n",
    "            np.zeros(30),\n",
    "            np.ones(30),  # condition 1: congruent\n",
    "            np.zeros(30),\n",
    "            np.ones(30),\n",
    "        )\n",
    "    )  # condition 2: incongruent\n",
    "    observation = np.concatenate(\n",
    "        [observation, stimulus_type[:, None]], axis=-1\n",
    "    )\n",
    "    test_dataset.append(observation)\n",
    "\n",
    "# print(f\"Number of missing trials: \\n{num_missing_trials}\")\n",
    "print(f\"{len(num_missing_trials)} subjects with missing trials\")\n",
    "print(f\"max missing trials: {max(num_missing_trials.values())}\")\n",
    "test_dataset = np.stack(test_dataset, axis=0)\n",
    "print(test_dataset.shape)"
   ]
  },
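  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small consistency check on the processed array; a sketch based on the column layout built above (signed rt, missing flag, condition, stimulus type):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: 120 trials per subject (60 congruent + 60 incongruent),\n",
    "# binary flags in columns 1-3, non-zero signed rt on observed trials.\n",
    "assert test_dataset.shape[1:] == (120, 4)\n",
    "for col in (1, 2, 3):\n",
    "    assert set(np.unique(test_dataset[..., col])) <= {0.0, 1.0}\n",
    "observed = test_dataset[..., 1] == 0\n",
    "assert np.all(test_dataset[..., 0][observed] != 0)"
   ]
  },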
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the test dataset\n",
    "save_to_file(test_dataset, paths[\"dataset_dir\"] / \"test_dataset.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sbi_mcmc.utils.utils import read_from_file\n",
    "\n",
    "test_dataset = read_from_file(paths[\"dataset_dir\"] / \"test_dataset.pkl\")\n",
    "save_as_chunks(test_dataset, paths[\"dataset_dir\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below was used for generating additional chunks of data for the DDM task. The data was provided by the original authors. We used `test_dataset_chunk_3` and `test_dataset_chunk_4`, together with the `test_dataset_chunk_1` above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import os\n",
    "# import pickle\n",
    "\n",
    "# import numpy as np\n",
    "\n",
    "# file_path = os.path.expanduser(\"~/Downloads/emp_data.p\")\n",
    "# with open(file_path, \"rb\") as f:\n",
    "#     result = pickle.load(f)\n",
    "# # Create a mask to identify people who meet the condition\n",
    "# valid_people_mask = np.all(\n",
    "#     (np.abs(result[\"data_array\"][..., 0]) <= 20)\n",
    "#     & (np.abs(result[\"data_array\"][..., 0]) > 0),\n",
    "#     axis=1,\n",
    "# )\n",
    "\n",
    "# # Filter the data_array to keep only valid people\n",
    "# filtered_data_array = result[\"data_array\"][valid_people_mask]\n",
    "\n",
    "# reordered_data_array = np.zeros_like(filtered_data_array)\n",
    "\n",
    "# for i in range(filtered_data_array.shape[0]):  # For each person\n",
    "#     # Get indices that would sort the values in dimension 1\n",
    "#     sort_indices = np.argsort(filtered_data_array[i, :, 1])\n",
    "\n",
    "#     # Reorder the data for this person\n",
    "#     reordered_data_array[i] = filtered_data_array[i, sort_indices]\n",
    "\n",
    "# # Verify that 1s come before 0s for each person\n",
    "# assert np.all(reordered_data_array[:, :60, 1] == 0)\n",
    "# assert np.all(reordered_data_array[:, 60:, 1] == 1)\n",
    "\n",
    "# missing = np.zeros_like(reordered_data_array)[..., 0]\n",
    "# stimulus_type = np.concatenate(\n",
    "#     (\n",
    "#         np.zeros(30),\n",
    "#         np.ones(30),  # condition 1: congruent\n",
    "#         np.zeros(30),\n",
    "#         np.ones(30),\n",
    "#     )\n",
    "# )  # condition 2: incongruent\n",
    "# stimulus_type = np.tile(stimulus_type, (reordered_data_array.shape[0], 1))\n",
    "\n",
    "# data = np.stack(\n",
    "#     [\n",
    "#         reordered_data_array[..., 0],\n",
    "#         missing,\n",
    "#         reordered_data_array[..., 1],\n",
    "#         stimulus_type,\n",
    "#     ],\n",
    "#     axis=-1,\n",
    "# )\n",
    "\n",
    "\n",
    "# save_as_chunks(\n",
    "#     data[: 22 * 5000],\n",
    "#     Path(\"ddm_data\"),\n",
    "# )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "abw_review",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
