{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93cf2d0f-ed2f-4dc8-bd89-82ff44882d75",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "# from torchcfm.optimal_transport import OTPlanSampler\n",
    "\n",
    "from typing import List\n",
    "import time\n",
    "from torchdyn.core import NeuralODE\n",
    "\n",
    "from tqdm import tqdm\n",
    "from torch.distributions.multivariate_normal import MultivariateNormal"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc222e96-8d5d-4d49-a4e2-40e20eb4d27a",
   "metadata": {},
   "source": [
    "# 1. Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53ce556b-72a6-43d3-b2da-3342eeeb7b66",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the stored LFP sessions and keep the first one.\n",
    "# NOTE(review): allow_pickle=True unpickles arbitrary objects - only load trusted files.\n",
    "dat_LFP = np.load('/hpc/home/gw74/latent_GLGP/stein/steinmetz_lfp.npz', allow_pickle=True)['dat']\n",
    "session = 0\n",
    "lfp_all = dat_LFP[session]\n",
    "\n",
    "# The raw array is indexed (region, trial, time_bin); keep every 10th bin of bins 50..99.\n",
    "lfp_all_use = lfp_all['lfp'][:, :, 50:100:10]\n",
    "# lfp_all_use = lfp_all['lfp'][:,:,50:(50 + 50)]\n",
    "\n",
    "nRegion, nTrial, nT = lfp_all_use.shape\n",
    "\n",
    "# Reorder the axes to (trial, time, region) and convert to a float32 tensor.\n",
    "lfp_all_trans = np.ascontiguousarray(np.transpose(lfp_all_use, (1, 2, 0)), dtype=np.float64)\n",
    "lfp_all_trans = torch.from_numpy(lfp_all_trans).to(torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fea0c1d4-d79f-4aac-adb6-ceae6d95e758",
   "metadata": {},
   "outputs": [],
   "source": [
    "area = 'VISp'\n",
    "region_id = lfp_all['brain_area_lfp'].index(area)\n",
    "\n",
    "# lfp_all_trans is indexed (trial, time, region): reduce over trials (axis 0)\n",
    "# to get one mean and one standard deviation per time bin.\n",
    "region_lfp = lfp_all_trans[:, :, region_id]\n",
    "mean_lfp = torch.mean(region_lfp, 0)\n",
    "\n",
    "# The spread must use the same axis as the mean. Indexing [ss, :, region_id]\n",
    "# would instead give the spread of a single trial over time.\n",
    "sd_tmp = torch.std(region_lfp, dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "981014f6-211f-4b11-91cc-b401ebde7be4",
   "metadata": {},
   "outputs": [],
   "source": [
    "t_axis = torch.arange(0, nT)*0.01\n",
    "\n",
    "# Single trials in grey; trial mean and mean +/- 2 sd in red.\n",
    "plt.plot(t_axis, lfp_all_trans[:, :, region_id].T, c='grey', alpha=0.2)\n",
    "for band in (mean_lfp, mean_lfp + 2*sd_tmp, mean_lfp - 2*sd_tmp):\n",
    "    plt.plot(t_axis, band, c='red')\n",
    "\n",
    "plt.xlabel('time')\n",
    "plt.ylabel('LFP')\n",
    "plt.title(area)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da0fcfbd-317e-4c59-9e00-08e1cbb9fb74",
   "metadata": {},
   "source": [
    "# 2. Functions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3a473d6-fb5c-4cb3-8c23-e44bb2ab0474",
   "metadata": {},
   "source": [
    "## 2.1 Common Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a2b90dc-27de-4390-8917-c7930af82d79",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(torch.nn.Module):\n",
    "    \"\"\"Fully connected network with three hidden SELU layers of width ``w``.\n",
    "\n",
    "    When ``time_varying`` is true the first layer takes one extra input\n",
    "    feature. ``out_dim`` falls back to ``dim`` when it is not given.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dim, out_dim=None, w=64, time_varying=False): # 64\n",
    "        super().__init__()\n",
    "        self.time_varying = time_varying\n",
    "        in_features = dim + 1 if time_varying else dim\n",
    "        out_features = dim if out_dim is None else out_dim\n",
    "        # Layer order is fixed so the state_dict keys stay net.0 / net.2 / net.4 / net.6.\n",
    "        layers = [\n",
    "            torch.nn.Linear(in_features, w),\n",
    "            torch.nn.SELU(),\n",
    "            torch.nn.Linear(w, w),\n",
    "            torch.nn.SELU(),\n",
    "            torch.nn.Linear(w, w),\n",
    "            torch.nn.SELU(),\n",
    "            torch.nn.Linear(w, out_features),\n",
    "        ]\n",
    "        self.net = torch.nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Apply the network to ``x``.\"\"\"\n",
    "        return self.net(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "763f2286-af19-4dd3-9967-92cdb45e9e42",
   "metadata": {},
   "outputs": [],
   "source": [
    "class torch_wrapper(torch.nn.Module):\n",
    "    \"\"\"Expose a model that takes ``[x, (x00,) t]`` rows through a ``forward(t, x)`` call.\n",
    "\n",
    "    ``x00`` is an optional fixed tensor that is concatenated to the state on\n",
    "    the last axis before the time column is appended.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model, x00 = None):\n",
    "        super().__init__()\n",
    "        self.model = model\n",
    "        self.x00 = x00\n",
    "\n",
    "    def forward(self, t, x, *args, **kwargs):\n",
    "        # Optionally append the fixed conditioning block to the current state.\n",
    "        state = x if self.x00 is None else torch.cat([x, self.x00], -1)\n",
    "        # One copy of the time value per row, as a trailing column.\n",
    "        t_col = t.repeat(state.shape[0])[:, None]\n",
    "        return self.model(torch.cat([state, t_col], 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31ea7b03-7682-42ff-9fed-fdf54c839210",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_traj(model, n_samp, nt_gen, seed, x_start = None, x00 = False):\n",
    "    \"\"\"Integrate the model's vector field over t in [0, 1] with a dopri5 NeuralODE.\n",
    "\n",
    "    Args:\n",
    "        model: network called through ``torch_wrapper``.\n",
    "        n_samp: number of start points drawn when ``x_start`` is None.\n",
    "        nt_gen: number of time points in the returned trajectory.\n",
    "        seed: torch seed used only when ``x_start`` is drawn here.\n",
    "        x_start: optional start states; when None, standard-normal states of\n",
    "            width ``dim`` (read from the notebook globals) are drawn.\n",
    "        x00: when true, ``x_start`` is also passed to the model as a fixed\n",
    "            conditioning input at every solver step.\n",
    "\n",
    "    Returns:\n",
    "        The output of ``NeuralODE.trajectory`` on ``torch.linspace(0, 1, nt_gen)``\n",
    "        (indexed time-first elsewhere in this notebook).\n",
    "    \"\"\"\n",
    "    if x_start is None:\n",
    "        torch.manual_seed(seed)\n",
    "        x_start = torch.randn(n_samp, dim)\n",
    "\n",
    "    field = torch_wrapper(model, x_start) if x00 else torch_wrapper(model)\n",
    "    node = NeuralODE(field, solver=\"dopri5\",\n",
    "                     sensitivity=\"adjoint\", atol=1e-4, rtol=1e-4)\n",
    "\n",
    "    # Generation only: no gradients are needed through the solver.\n",
    "    with torch.no_grad():\n",
    "        return node.trajectory(x_start, t_span=torch.linspace(0, 1, nt_gen))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac1e078a-0342-4cd6-b868-2626fd7752fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def warmup_lr(step):\n",
    "    \"\"\"Linear warm-up multiplier: step / warmup, capped at 1 once step reaches warmup.\n",
    "\n",
    "    ``warmup`` is read from the notebook globals at call time.\n",
    "    \"\"\"\n",
    "    capped_step = min(step, warmup)\n",
    "    return capped_step / warmup"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14fe8d0d-9637-4ae4-bcd6-4830238684f9",
   "metadata": {},
   "source": [
    "## 2.2 GP-I-CFM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fe600ef-55cc-4fa3-8371-29a2c6d14610",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_r(ti, tj):\n",
    "    \"\"\"Pairwise differences r[..., i, j] = ti[..., i] - tj[..., j]; exact zeros become 1e-15.\"\"\"\n",
    "    r = ti.unsqueeze(-1) - tj.unsqueeze(-2)\n",
    "    r[r == 0] = 1e-15\n",
    "    return r\n",
    "\n",
    "\n",
    "def k11(r, alpha, l):\n",
    "    \"\"\"Squared-exponential kernel alpha^2 * exp(-r^2 / (2 l^2)).\"\"\"\n",
    "    amp = alpha**2\n",
    "    return amp*torch.exp(-0.5 * ((r/l)**2))\n",
    "\n",
    "\n",
    "def k12(r, alpha, l):\n",
    "    \"\"\"(alpha^2 / l^2) * r * exp(-r^2 / (2 l^2)): derivative of k11 in its second time argument when r = ti - tj.\"\"\"\n",
    "    amp = alpha**2/l**2\n",
    "    return amp*r*torch.exp(-0.5*((r/l)**2))\n",
    "\n",
    "\n",
    "def k22(r, alpha, l):\n",
    "    \"\"\"(alpha^2 / l^4) * (l^2 - r^2) * exp(-r^2 / (2 l^2)): mixed second derivative of k11.\"\"\"\n",
    "    amp = alpha**2/l**4\n",
    "    return amp*(l**2 - r**2)*torch.exp(-0.5*((r/l)**2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a32f13c8-f074-4af2-9aeb-0a15ce48a525",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cov_mat(ti, tj, alpha, l, sig2_diag = 1e-8):\n",
    "    \"\"\"Joint covariance [[K11, K12], [K12^T, K22]] built from the kernels k11, k12, k22.\n",
    "\n",
    "    Args:\n",
    "        ti, tj: 1-D time tensors of equal length (the jitter added to the\n",
    "            diagonal blocks is a square identity of size len(ti)).\n",
    "        alpha, l: kernel amplitude and length-scale.\n",
    "        sig2_diag: value added to the diagonal of both diagonal blocks.\n",
    "\n",
    "    Returns:\n",
    "        Symmetrized tensor of shape (2*len(ti), 2*len(ti)).\n",
    "    \"\"\"\n",
    "    r = calc_r(ti, tj)\n",
    "    nt = r.shape[0]\n",
    "\n",
    "    # Build the jitter on r's dtype/device so non-default inputs do not hit a mismatch.\n",
    "    jitter = torch.eye(nt, dtype=r.dtype, device=r.device)*sig2_diag\n",
    "\n",
    "    Sig11 = k11(r, alpha, l) + jitter\n",
    "    Sig12 = k12(r, alpha, l)\n",
    "    Sig22 = k22(r, alpha, l) + jitter\n",
    "\n",
    "    block_row1 = torch.cat([Sig11, Sig12], dim=1)\n",
    "    block_row2 = torch.cat([Sig12.T, Sig22], dim=1)\n",
    "    Sig = torch.cat([block_row1, block_row2], dim=0)\n",
    "\n",
    "    # Average with the transpose to remove round-off asymmetry.\n",
    "    return (Sig + Sig.T)/2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce9ec9d8-0022-4c1f-ac57-59a6751bf40c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def samp_x_dx(t, alpha, l, x_obs, t_obs, sig2_diag=1e-8):\n",
    "    \"\"\"Sample GP values and time-derivatives at times ``t`` given observations.\n",
    "\n",
    "    Every (batch item, feature) series of ``x_obs`` is treated as its own\n",
    "    zero-mean GP with the kernel defined by k11/k12/k22, observed at\n",
    "    ``t_obs``. Values and derivatives at ``t`` are drawn jointly from the\n",
    "    Gaussian conditional.\n",
    "\n",
    "    Args:\n",
    "        t: 1-D tensor of query times (length nt).\n",
    "        alpha, l: kernel amplitude and length-scale.\n",
    "        x_obs: tensor of shape (nB, len(t_obs), dim).\n",
    "        t_obs: 1-D tensor of observation times.\n",
    "        sig2_diag: value added to the diagonal of the covariance blocks.\n",
    "\n",
    "    Returns:\n",
    "        (x_samps, dx_samps), each of shape (nB, nt, dim).\n",
    "    \"\"\"\n",
    "    nB, nt, dim = x_obs.shape[0], t.shape[0], x_obs.shape[2]\n",
    "    nt_obs = t_obs.shape[0]\n",
    "\n",
    "    r_obs_x = calc_r(t_obs, t)\n",
    "    r_obs_obs = calc_r(t_obs, t_obs)\n",
    "    Sig_11 = cov_mat(t, t, alpha, l, sig2_diag)\n",
    "\n",
    "    # Cross-covariance between the observations and the stacked [x(t); dx(t)].\n",
    "    Sig_21 = torch.cat([k11(r_obs_x, alpha, l), k12(r_obs_x, alpha, l)], dim=1)\n",
    "    Sig_12 = Sig_21.T\n",
    "\n",
    "    jitter = torch.eye(nt_obs, dtype=r_obs_obs.dtype, device=r_obs_obs.device)*sig2_diag\n",
    "    Sig_22 = k11(r_obs_obs, alpha, l) + jitter\n",
    "    Sig_22_inv = torch.linalg.inv(Sig_22)\n",
    "\n",
    "    # Gaussian conditioning: mean map and conditional covariance.\n",
    "    mu_A = Sig_12 @ Sig_22_inv\n",
    "    Sig_cond = Sig_11 - mu_A @ Sig_21\n",
    "    Sig_cond = (Sig_cond + Sig_cond.T)/2\n",
    "\n",
    "    # Sig_cond is symmetric here, so the symmetric eigen-solver applies.\n",
    "    if not bool((torch.linalg.eigvalsh(Sig_cond) >= 0).all()):\n",
    "        # Rebuild from singular values plus a small offset to remove negative eigenvalues.\n",
    "        U, S, Vh = torch.linalg.svd(Sig_cond)\n",
    "        Sig_cond = Vh.T @ torch.diag(S + 1e-6) @ Vh\n",
    "        Sig_cond = (Sig_cond + Sig_cond.T)/2\n",
    "\n",
    "    # Conditional mean, shape (nB, 2*nt, dim): first nt rows are values, last nt derivatives.\n",
    "    mu = mu_A @ x_obs.reshape(nB, nt_obs, dim)\n",
    "\n",
    "    # One row per (batch item, feature) with the 2*nt time entries as the event axis.\n",
    "    # The permute is required: a plain reshape of (nB, 2*nt, dim) would mix time and\n",
    "    # feature entries inside a row, so Sig_cond would correlate the wrong elements.\n",
    "    mu_rows = mu.permute(0, 2, 1).reshape(nB * dim, 2 * nt)\n",
    "\n",
    "    try:\n",
    "        # The single covariance matrix broadcasts over all nB*dim rows.\n",
    "        dist = MultivariateNormal(loc=mu_rows, covariance_matrix=Sig_cond)\n",
    "        samp_rows = dist.rsample()\n",
    "    except (RuntimeError, ValueError):\n",
    "        # ValueError: argument validation rejected the covariance;\n",
    "        # RuntimeError: the Cholesky factorization failed.\n",
    "        print('Sampling failed; using numpy fallback.')\n",
    "        samp_rows = torch.zeros_like(mu_rows)\n",
    "        cov_single = Sig_cond.detach().cpu().numpy()\n",
    "        for row in range(nB * dim):\n",
    "            mu_single = mu_rows[row].detach().cpu().numpy()\n",
    "            sample = np.random.multivariate_normal(mu_single, cov_single)\n",
    "            samp_rows[row] = torch.from_numpy(sample)\n",
    "\n",
    "    # Back to (nB, 2*nt, dim), undoing the permutation used for sampling.\n",
    "    x_dx_samps = samp_rows.reshape(nB, dim, 2 * nt).permute(0, 2, 1).contiguous()\n",
    "\n",
    "    # Separate x and dx samples\n",
    "    x_samps = x_dx_samps[:, :nt, :]\n",
    "    dx_samps = x_dx_samps[:, nt:, :]\n",
    "\n",
    "    return x_samps, dx_samps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c50ba51-e595-4e1d-b0f8-a66b28a99245",
   "metadata": {},
   "outputs": [],
   "source": [
    "def GP_FMv1(model, optimizer, x_data, alpha, l, nt_per, batch_size, t_obs, n_epochs, sig2_diag = 0, grad_clip = 1.0):\n",
    "    \"\"\"Train one vector-field model on GP-conditioned paths, all features jointly.\n",
    "\n",
    "    For every batch, ``nt_per`` random times are drawn inside each interval of\n",
    "    ``t_obs``, path values and derivatives are sampled with ``samp_x_dx``, and\n",
    "    the model is regressed (mean squared error) onto the sampled derivatives.\n",
    "    The model input is ``[x_t, x_obs[:, 0], t]``, i.e. 2*dim + 1 features.\n",
    "\n",
    "    Args:\n",
    "        model: network mapping (rows, 2*dim + 1) to (rows, dim).\n",
    "        optimizer: optimizer for ``model``; a LambdaLR schedule driven by\n",
    "            ``warmup_lr`` is attached to it.\n",
    "        x_data: tensor of shape (N, len(t_obs), dim).\n",
    "        alpha, l, sig2_diag: forwarded to ``samp_x_dx``.\n",
    "        nt_per: number of random times per observation interval.\n",
    "        batch_size: number of consecutive series per batch; a shorter last\n",
    "            batch is used when N is not a multiple of it.\n",
    "        t_obs: 1-D tensor of observation times.\n",
    "        n_epochs: number of passes over the batches.\n",
    "        grad_clip: maximum gradient norm.\n",
    "\n",
    "    Returns:\n",
    "        (model, losses) with one float loss per optimizer step.\n",
    "    \"\"\"\n",
    "    N = x_data.shape[0]\n",
    "    dim = x_data.shape[2]\n",
    "\n",
    "    # Consecutive index batches; unlike a reshape to [nbatch, batch_size] this\n",
    "    # also works when N is not a multiple of batch_size.\n",
    "    batch_idx = torch.arange(N).split(batch_size)\n",
    "\n",
    "    sched = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_lr)\n",
    "\n",
    "    # Interval lengths and starts do not change between iterations.\n",
    "    int_rep = (t_obs[1:] - t_obs[:-1]).repeat_interleave(nt_per)\n",
    "    start_rep = t_obs[:-1].repeat_interleave(nt_per)\n",
    "\n",
    "    losses: List[float] = []\n",
    "    n_skipped = 0\n",
    "    model.train()\n",
    "    for k in tqdm(range(n_epochs)):\n",
    "        for idx in batch_idx:\n",
    "            x_obs = x_data[idx,:,:]\n",
    "            n_in_batch = x_obs.shape[0]\n",
    "\n",
    "            # nt_per uniform random times inside every observation interval.\n",
    "            t_raw = torch.rand(nt_per*(len(t_obs) - 1))\n",
    "            t_batch = int_rep*t_raw + start_rep\n",
    "\n",
    "            try:\n",
    "                xt_batch, ut_batch = samp_x_dx(t_batch, alpha, l, x_obs, t_obs, sig2_diag)\n",
    "            except Exception as err:\n",
    "                # Skip the batch: falling through would pair the previous batch's\n",
    "                # samples with this batch's times and start states.\n",
    "                if n_skipped == 0:\n",
    "                    print(f'GP sampling failed ({err}); skipping batch.')\n",
    "                n_skipped += 1\n",
    "                continue\n",
    "\n",
    "            # Flatten to rows ordered (series, time); t and x00 follow the same order.\n",
    "            t = t_batch.repeat(1,n_in_batch).T\n",
    "            xt = torch.reshape(xt_batch, (-1,dim))\n",
    "            ut = torch.reshape(ut_batch, (-1,dim))\n",
    "            x00 = x_obs[:,0,:].repeat_interleave(len(t_raw), dim = 0)\n",
    "\n",
    "            vt = model(torch.cat([xt, x00, t], dim=-1))\n",
    "            loss = torch.mean((vt - ut) ** 2)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)\n",
    "\n",
    "            optimizer.step()\n",
    "            sched.step()\n",
    "\n",
    "            # Logging\n",
    "            losses.append(loss.item())\n",
    "\n",
    "    if n_skipped > 0:\n",
    "        print(f'Skipped {n_skipped} batches because GP sampling failed.')\n",
    "\n",
    "    return model, losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ac2ec62-03a9-4778-9e91-da23bbfeacf5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def GP_FMv2(model_list, optimizer_list,\n",
    "            x_data, alpha, l, nt_per, batch_size, t_obs, n_epochs, sig2_diag = 0, grad_clip = 1.0):\n",
    "    \"\"\"Train one vector-field model per feature on GP-conditioned paths.\n",
    "\n",
    "    Same sampling scheme as ``GP_FMv1``, but feature ``ii`` has its own model\n",
    "    ``model_list[ii]`` (input ``[x_t[ii], x_obs[:, 0, ii], t]``, 3 features)\n",
    "    and its own optimizer ``optimizer_list[ii]``.\n",
    "\n",
    "    Args:\n",
    "        model_list, optimizer_list: one model / optimizer per feature.\n",
    "        x_data: tensor of shape (N, len(t_obs), dim).\n",
    "        alpha, l, sig2_diag: forwarded to ``samp_x_dx``.\n",
    "        nt_per: number of random times per observation interval.\n",
    "        batch_size: number of consecutive series per batch; a shorter last\n",
    "            batch is used when N is not a multiple of it.\n",
    "        t_obs: 1-D tensor of observation times.\n",
    "        n_epochs: number of passes over the batches.\n",
    "        grad_clip: maximum gradient norm.\n",
    "\n",
    "    Returns:\n",
    "        (model_list, losses) where each loss is the float mean over features\n",
    "        for one batch.\n",
    "    \"\"\"\n",
    "    N = x_data.shape[0]\n",
    "    dim = x_data.shape[2]\n",
    "\n",
    "    # Consecutive index batches; also valid when N is not a multiple of batch_size.\n",
    "    batch_idx = torch.arange(N).split(batch_size)\n",
    "\n",
    "    # One warm-up scheduler per optimizer. With a single shared variable only the\n",
    "    # last optimizer would ever be stepped, and the others would stay at the\n",
    "    # initial warm-up learning rate.\n",
    "    sched_list = []\n",
    "    for ii in range(dim):\n",
    "        model_list[ii].train()\n",
    "        sched_list.append(torch.optim.lr_scheduler.LambdaLR(optimizer_list[ii], lr_lambda=warmup_lr))\n",
    "\n",
    "    int_rep = (t_obs[1:] - t_obs[:-1]).repeat_interleave(nt_per)\n",
    "    start_rep = t_obs[:-1].repeat_interleave(nt_per)\n",
    "\n",
    "    loss_val = torch.zeros(dim)\n",
    "    losses: List[float] = []\n",
    "    n_skipped = 0\n",
    "\n",
    "    for k in tqdm(range(n_epochs)):\n",
    "        for idx in batch_idx:\n",
    "            x_obs = x_data[idx,:,:]\n",
    "            n_in_batch = x_obs.shape[0]\n",
    "\n",
    "            t_raw = torch.rand(nt_per*(len(t_obs) - 1))\n",
    "            t_batch = int_rep*t_raw + start_rep\n",
    "\n",
    "            try:\n",
    "                xt_batch, ut_batch = samp_x_dx(t_batch, alpha, l, x_obs, t_obs, sig2_diag)\n",
    "            except Exception as err:\n",
    "                # Skip the batch instead of reusing samples from a previous one.\n",
    "                if n_skipped == 0:\n",
    "                    print(f'GP sampling failed ({err}); skipping batch.')\n",
    "                n_skipped += 1\n",
    "                continue\n",
    "\n",
    "            t = t_batch.repeat(1,n_in_batch).T\n",
    "            xt = torch.reshape(xt_batch, (-1,dim))\n",
    "            ut = torch.reshape(ut_batch, (-1,dim))\n",
    "            x00 = x_obs[:,0,:].repeat_interleave(len(t_raw), dim = 0)\n",
    "\n",
    "            for ii in range(dim):\n",
    "                vt = model_list[ii](torch.cat([xt[:,ii:(ii+1)], x00[:,ii:(ii+1)], t], dim=-1))\n",
    "                loss = torch.mean((vt - ut[:,ii:(ii+1)]) ** 2)\n",
    "\n",
    "                optimizer_list[ii].zero_grad()\n",
    "                loss.backward()\n",
    "                torch.nn.utils.clip_grad_norm_(model_list[ii].parameters(), grad_clip)\n",
    "\n",
    "                optimizer_list[ii].step()\n",
    "                sched_list[ii].step()\n",
    "\n",
    "                loss_val[ii] = loss.item()\n",
    "\n",
    "            # Logging: store a plain float, as the List[float] annotation says.\n",
    "            losses.append(loss_val.mean().item())\n",
    "\n",
    "    if n_skipped > 0:\n",
    "        print(f'Skipped {n_skipped} batches because GP sampling failed.')\n",
    "\n",
    "    return model_list, losses"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13cab548-8ba1-461a-aab0-4c8b3127ea1d",
   "metadata": {},
   "source": [
    "## 2.3 I-CFM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f54d808a-e6d1-4d83-869e-55584ea0dd28",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_conditional_pt(x0, x1, t, sigma):\n",
    "    \"\"\"Draw x_t = t*x1 + (1 - t)*x0 + sigma*eps with eps standard normal of x0's shape.\n",
    "\n",
    "    ``t`` holds one value per row of ``x0`` and is broadcast over the remaining axes.\n",
    "    \"\"\"\n",
    "    t_col = t.reshape((-1,) + (1,) * (x0.dim() - 1))\n",
    "    interp = t_col * x1 + (1 - t_col) * x0\n",
    "    noise = torch.randn_like(x0)\n",
    "    return interp + sigma * noise\n",
    "\n",
    "\n",
    "def compute_conditional_vector_field(x0, x1):\n",
    "    \"\"\"Regression target of the straight-line path from x0 to x1: the constant x1 - x0.\"\"\"\n",
    "    displacement = x1 - x0\n",
    "    return displacement"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c632ebd-4df3-48e7-8197-56356e77789c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def I_FM(x1, model, optimizer, sigma = 1e-1, n_epochs = 10000, grad_clip = 1.0):\n",
    "    \"\"\"Fit ``model`` by flow matching from standard-normal noise to the samples ``x1``.\n",
    "\n",
    "    Each step pairs a fresh Gaussian ``x0`` with the full ``x1`` batch, draws one\n",
    "    uniform time per row, and regresses the model output at the noisy\n",
    "    interpolant onto ``x1 - x0``. A LambdaLR schedule driven by ``warmup_lr``\n",
    "    is attached to ``optimizer``.\n",
    "\n",
    "    Returns:\n",
    "        (model, losses) with one float loss per step.\n",
    "    \"\"\"\n",
    "    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_lr)\n",
    "    loss_history: List[float] = []\n",
    "\n",
    "    model.train()\n",
    "    for _ in tqdm(range(n_epochs)):\n",
    "        x0 = torch.randn_like(x1)\n",
    "        t = torch.rand(x0.shape[0]).type_as(x0)\n",
    "\n",
    "        xt = sample_conditional_pt(x0, x1, t, sigma=sigma)\n",
    "        target = compute_conditional_vector_field(x0, x1)\n",
    "        pred = model(torch.cat([xt, t[:, None]], dim=-1))\n",
    "        loss = (pred - target).pow(2).mean()\n",
    "\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)\n",
    "\n",
    "        optimizer.step()\n",
    "        scheduler.step()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        loss_history.append(loss.item())\n",
    "\n",
    "    return model, loss_history"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dac59344-0834-4ee3-bf5c-d7cfd3595dad",
   "metadata": {},
   "source": [
    "# 3. Training"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41e21111-c357-4e80-a534-e8018d565ae2",
   "metadata": {},
   "source": [
    "## 3.1 I-CFM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6922f5fe-e2e4-4151-b68e-57db4b862116",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First time bin of every trial, and the number of regions (feature dimension).\n",
    "x_t0 = lfp_all_trans[:, 0]\n",
    "dim = x_t0.shape[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26e50533-000a-4cca-952a-3de93405815f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Schedule for the I-CFM model; warmup is read by warmup_lr.\n",
    "n_epochs, warmup = 50000, 5000\n",
    "\n",
    "model0 = MLP(dim=dim, out_dim=dim, time_varying=True)\n",
    "optimizer0 = torch.optim.Adam(model0.parameters(), lr=2e-4)\n",
    "# Training call kept for reference; it is disabled in this notebook.\n",
    "# model0, losses0 = I_FM(x_t0, model0, optimizer0, sigma = 1e-2, n_epochs = n_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3421671-ad5b-4d36-a12b-b3140aa8468a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.save(model0.state_dict(), f'model0.pt')\n",
    "# Restore stored weights from the working directory (tensors only, no pickled code).\n",
    "state0 = torch.load(\"model0.pt\", weights_only=True)\n",
    "model0.load_state_dict(state0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c9422a7-179e-4ab6-9749-5aa5a0cb563e",
   "metadata": {},
   "source": [
    "## 3.2 GP-I-CFM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6c45911-c782-44e2-8257-320580e44e41",
   "metadata": {},
   "outputs": [],
   "source": [
    "# test: draw GP-conditioned samples at random times and overlay them on the observations\n",
    "t = torch.rand(200)\n",
    "t_obs = torch.linspace(0, 1, nT)\n",
    "x_obs = lfp_all_trans\n",
    "\n",
    "alpha, l = 2, 2/nT\n",
    "\n",
    "x_samp, dx_samp = samp_x_dx(t, alpha, l, x_obs, t_obs, 1e-2)\n",
    "batch, dim_check = 0, 6\n",
    "# Sampled values, sampled derivatives, then the observed points (larger markers).\n",
    "for xs, ys, size in ((t, x_samp[batch,:,dim_check], 2),\n",
    "                     (t, dx_samp[batch,:,dim_check], 2),\n",
    "                     (t_obs, x_obs[batch,:,dim_check], 5)):\n",
    "    plt.scatter(xs, ys, s = size)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b4584c7-f654-4c5f-b96e-6e7ee4d1e6c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Kernel amplitude / length-scale and batching used for GP-I-CFM training.\n",
    "alpha, l = 2, 2/nT\n",
    "# NOTE(review): batch_size presumably divides the number of trials - confirm.\n",
    "nt_per, batch_size = 1, 107\n",
    "t_obs = torch.linspace(0, 1, nT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da49cb91-3e95-460e-9e54-d66897dd3361",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Epoch count, warm-up length (read by warmup_lr) and diagonal jitter for the GP sampler.\n",
    "n_epochs, warmup, sig2_diag = 20000, 2000, 1e-2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3390e8ea-8779-47ff-86b6-2e08738dae6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%capture output\n",
    "# Input is [x_t, x00] (2*dim features) plus the time column added by time_varying.\n",
    "model_all = MLP(dim=dim*2, out_dim=dim, w=128, time_varying=True)\n",
    "optimizer = torch.optim.Adam(model_all.parameters(), lr=8e-5) #, weight_decay = 1e-4)\n",
    "# NOTE(review): the next cell loads model_all.pt into model_all, replacing the\n",
    "# weights trained here - confirm whether training or loading is intended.\n",
    "train_kwargs = dict(sig2_diag=sig2_diag, grad_clip=0.5)\n",
    "model_all, losses_all = GP_FMv1(model_all, optimizer,\n",
    "                                lfp_all_trans, alpha, l, nt_per, batch_size,\n",
    "                                t_obs, n_epochs, **train_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66404125-70f3-4e3e-8516-df083797c716",
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.save(model_all.state_dict(), f'model_all.pt')\n",
    "# Restore stored weights from the working directory (tensors only, no pickled code).\n",
    "state_all = torch.load(\"model_all.pt\", weights_only=True)\n",
    "model_all.load_state_dict(state_all)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "795d3eb4-8ef6-4bfc-b13d-0ca9f58a7d4a",
   "metadata": {},
   "source": [
    "# 4. Generate Samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1c1d06c-bfb1-47a5-9201-9ff9cff2c99a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of generated samples, time points per trajectory, and the torch seed.\n",
    "n_samp, nt_gen, seed = 10000, 50, 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38e3976b-bb89-4c98-bc90-b7a3e25081e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stage 1: push Gaussian noise through model0 and keep the final solver state as start points.\n",
    "traj0 = gen_traj(model0, n_samp, 2, seed)\n",
    "x_start = traj0[-1]\n",
    "# x_start = lfp_all_trans[:, 0, :]\n",
    "# Stage 2: integrate model_all from those start points, which are also passed as x00.\n",
    "traj = gen_traj(model_all, n_samp, nt_gen, seed, x_start=x_start, x00=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4eedd4b-de27-4854-a3eb-77b95892fa7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flag samples whose largest absolute value (over time, then over regions)\n",
    "# exceeds the maximum of the recorded data.\n",
    "peak_over_time = traj.abs().max(dim=0).values\n",
    "peak_per_sample = peak_over_time.max(dim=1).values\n",
    "over_100_idx = peak_per_sample > lfp_all_trans.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a05280fe-6665-44dd-8e8d-bf35862361ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of samples that were not flagged; Tensor.sum avoids iterating the\n",
    "# tensor element by element as the builtin sum() does.\n",
    "(~over_100_idx).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "244d28e2-8c68-4348-9d1f-874750c5c146",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Indices of the samples that were not flagged; at most the first 1000 are used.\n",
    "idx_all = (~over_100_idx).nonzero(as_tuple=False).flatten()\n",
    "idx_use = idx_all[:1000]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fff3052d-2ba4-449c-8141-d6513ac96b6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# over_100_idx = ((torch.max(torch.abs(traj), dim = 0).values).max(dim=1).values > 80)\n",
    "# Average the retained samples at each time point; one line per region.\n",
    "mean_traj = traj[:, idx_use, :].mean(dim=1)\n",
    "plt.plot(mean_traj);\n",
    "plt.legend(lfp_all['brain_area_lfp'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "099d5f9c-cc14-43c4-a499-540eeeb374c3",
   "metadata": {},
   "source": [
    "# 5. Let's Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32414cd9-3316-4841-93f7-441b70ba8d71",
   "metadata": {},
   "outputs": [],
   "source": [
    "# raw data: per-region mean and standard deviation across trials at each time bin\n",
    "mean_data = []\n",
    "sd_data = []\n",
    "\n",
    "for ii in range(lfp_all_trans.shape[2]):\n",
    "    # (trial, time) slice of region ii\n",
    "    region_lfp = lfp_all_trans[:, :, ii]\n",
    "    mean_data.append(torch.mean(region_lfp, 0))\n",
    "    # Reduce over trials (axis 0), the same axis as the mean. Indexing\n",
    "    # [ss, :, ii] would give the spread of a single trial over time.\n",
    "    sd_data.append(torch.std(region_lfp, dim=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0989ab69-5add-4614-8dc2-a7b7bc274871",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output directory for the saved figures, and shared matplotlib settings.\n",
    "plot_path = \"/hpc/home/gw74/diff_model/FM/LFP/plots\"\n",
    "plt.rcParams.update({\n",
    "    'svg.fonttype': 'none',\n",
    "    'text.usetex': False,\n",
    "    'font.size': 14,\n",
    "    'figure.figsize': [4, 2.5],\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fbfb549-e957-4f67-8be8-7b1ec2eca87c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean trace of the last region; indexing with -1 avoids depending on a loop\n",
    "# variable left over from an earlier cell.\n",
    "mean_data[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "162230f8-40ab-417e-8cf8-bff0a24b2731",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One line per region; iterate over mean_data itself rather than a hard-coded\n",
    "# region count so sessions with a different number of regions also work.\n",
    "t_real = torch.arange(0, nT)*0.01\n",
    "for region_mean in mean_data:\n",
    "    plt.plot(t_real, region_mean)\n",
    "# plt.title('real');\n",
    "plt.savefig(plot_path + \"/1_real.svg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e033fb26-9286-466e-b958-11d0b3c1f5f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stricter flag: peak absolute value above 90% of the data maximum.\n",
    "# NOTE(review): idx_use is not recomputed after this cell, so the plots below\n",
    "# still use the earlier selection - confirm whether that is intended.\n",
    "peak_over_time = traj.abs().max(dim=0).values\n",
    "over_100_idx = peak_over_time.max(dim=1).values > 0.9*lfp_all_trans.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d11a2ce2-5d01-47f9-a740-60909eb158d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_traj = torch.mean(traj[:,idx_use,:], 1)\n",
    "# Time axis derived from the data instead of the literals 0.04 and 50: it spans\n",
    "# the same 0.01-per-observed-bin range as the raw-data plots, with one point\n",
    "# per generated time step.\n",
    "t_gen = torch.linspace(0, (nT - 1)*0.01, mean_traj.shape[0])\n",
    "plt.plot(t_gen, mean_traj);\n",
    "plt.legend(lfp_all['brain_area_lfp'])\n",
    "plt.savefig(plot_path + \"/2_mean_gen.svg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f287caf5-cbd4-4b6b-b384-5f957b5d63d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Time axis derived from the data instead of the literals 0.04 and 50.\n",
    "t_gen = torch.linspace(0, (nT - 1)*0.01, traj.shape[0])\n",
    "for ii in range(dim):\n",
    "    # A fresh figure per region so a saved file never contains another region's\n",
    "    # lines, whichever matplotlib backend is active.\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.plot(t_gen, traj[:,idx_use,ii], c = 'gray', alpha = 0.1)\n",
    "    fig.savefig(plot_path + f\"/region_{lfp_all['brain_area_lfp'][ii]}.svg\")\n",
    "    plt.show()\n",
    "    plt.close(fig)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
