{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "7ab9e907",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.ndimage import gaussian_filter1d\n",
    "from scipy.linalg import qr\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import random\n",
    "import scipy.ndimage\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "device = \"cuda:4\" if torch.cuda.is_available() else \"cpu\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab60bd7c",
   "metadata": {},
   "source": [
    "# Visualization functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "941da570",
   "metadata": {},
   "outputs": [],
   "source": [
    "def visual_F_list(F_list):\n",
    "    \"\"\"Show each sub-circuit dynamics matrix in F_list as a heatmap, side by side.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    F_list : sequence of 2-D torch.Tensor\n",
    "        One matrix per sub-circuit. Tensors may live on any device and may\n",
    "        require grad; they are detached and copied to the CPU before plotting.\n",
    "    \"\"\"\n",
    "    K = len(F_list)\n",
    "    if K == 0:\n",
    "        return  # nothing to draw\n",
    "\n",
    "    # squeeze=False keeps `axes` 2-D, so axes[0] is iterable even when K == 1.\n",
    "    fig, axes = plt.subplots(1, K, figsize=(5 * K, 3), squeeze=False)\n",
    "    for idx, (ax, f) in enumerate(zip(axes[0], F_list)):\n",
    "        mat = f.detach().cpu().numpy()\n",
    "        im = ax.imshow(mat, cmap='coolwarm', interpolation='nearest')\n",
    "        ax.set_title(f\"F_{idx + 1}\")\n",
    "        fig.colorbar(im, ax=ax)\n",
    "        ax.set_xlabel(\"Latent dim\")\n",
    "        ax.set_ylabel(\"Latent dim\")\n",
    "\n",
    "        # 1-based tick labels; rows and columns are labelled independently so\n",
    "        # non-square matrices get correct ticks on both axes.\n",
    "        n_rows, n_cols = mat.shape[:2]\n",
    "        ax.set_xticks(list(range(n_cols)))\n",
    "        ax.set_xticklabels([str(j) for j in range(1, n_cols + 1)])\n",
    "        ax.set_yticks(list(range(n_rows)))\n",
    "        ax.set_yticklabels([str(j) for j in range(1, n_rows + 1)])\n",
    "\n",
    "    fig.suptitle(\"Sub-circuit dynamic matrices F_list\")\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def visual_C(C):\n",
    "    \"\"\"Line plot of every row of the coefficient matrix C against time.\"\"\"\n",
    "    plt.figure(figsize=(10,3))\n",
    "    for row_idx, coeff_row in enumerate(C):\n",
    "        plt.plot(coeff_row.detach().cpu().numpy(), label=f'c_{row_idx+1}')\n",
    "    plt.title(\"Sub-circuit coefficients C\")\n",
    "    plt.xlabel(\"Time\")\n",
    "    plt.legend()\n",
    "    plt.show()\n",
    "\n",
    "def visual_A(data):\n",
    "    \"\"\"Heatmap of the projection matrix A: rows are neurons, columns are latent dimensions.\"\"\"\n",
    "    plt.figure(figsize=(6, 5))\n",
    "    weights = data.detach().cpu().numpy()\n",
    "    heatmap_style = dict(aspect='auto', origin='lower', cmap='viridis')\n",
    "    plt.imshow(weights, **heatmap_style)\n",
    "    plt.colorbar(label='Projection weight')\n",
    "    plt.title(\"Projection matrix A\")\n",
    "    plt.xlabel(\"Latent dimension (p)\")\n",
    "    plt.ylabel(\"Neuron #\")\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def visual_X(data):\n",
    "    \"\"\"Overlay the latent trajectories (one line per row of `data`) over time.\"\"\"\n",
    "    latents = data.detach().cpu().numpy()\n",
    "    plt.figure(figsize=(10,3))\n",
    "    for dim, trajectory in enumerate(latents):\n",
    "        plt.plot(trajectory, label=f'x_{dim+1}')\n",
    "    plt.title(\"Latent dynamics X\")\n",
    "    plt.xlabel(\"Time\")\n",
    "    plt.legend()\n",
    "    plt.show()\n",
    "\n",
    "def visual_Y(data):\n",
    "    \"\"\"Heatmap of the observations: rows are neurons, columns are time steps.\"\"\"\n",
    "    observations = data.detach().cpu().numpy()\n",
    "    plt.figure(figsize=(10,5))\n",
    "    heatmap_style = dict(aspect='auto', origin='lower', cmap='viridis')\n",
    "    plt.imshow(observations, **heatmap_style)\n",
    "    plt.colorbar(label='Firing rate')\n",
    "    plt.title(\"Synthetic observations Y\")\n",
    "    plt.xlabel(\"Time\")\n",
    "    plt.ylabel(\"Neuron #\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5aa98352",
   "metadata": {},
   "source": [
    "# Optimization routines"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8cdb5384",
   "metadata": {},
   "source": [
    "## A and X: latent-state (X) update with A held fixed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bfed7de",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_h(Y):\n",
    "\n",
    "    Y_norm = Y / (Y.norm(dim=1, keepdim=True) + 1e-8)\n",
    "    h = Y_norm @ Y_norm.T\n",
    "    h = (h + 1.0) / 2.0\n",
    "    return h\n",
    "\n",
    "\n",
    "def similarity_loss(a, h):\n",
    "    D = torch.diag(h.sum(dim=1))\n",
    "    L = D - h\n",
    "    loss = torch.trace(a.T @ L @ a)\n",
    "    return loss\n",
    "\n",
    "\n",
    "def update_x(Y, X, a, C, F_list, num_iter, lr_x, lambda_dyn_X, epoch):\n",
    "\n",
    "    A_var = a.clone().detach().requires_grad_(False)\n",
    "    X_var = X.clone().detach().requires_grad_(True)\n",
    "    C_var = C.clone().detach().requires_grad_(False)\n",
    "    \n",
    "    optimizer_x = torch.optim.Adam([X_var], lr=lr_x)\n",
    "\n",
    "    F_all = torch.stack(F_list, dim=0).clone().detach().requires_grad_(False)\n",
    "\n",
    "    for i in range(num_iter):\n",
    "\n",
    "        Y_hat = A_var @ X_var\n",
    "        loss_rec_y = F.mse_loss(Y, Y_hat)\n",
    "\n",
    "        F_t_all = torch.einsum('kt,kij->tij', C_var, F_all)\n",
    "        X_t1_pred = torch.einsum('tij,jt->it', F_t_all[:-1, :, :], X[:, :-1])                   \n",
    "        loss_dyn_x = lambda_dyn_X * F.mse_loss(X_var[:, 1:], X_t1_pred)\n",
    "\n",
    "\n",
    "        loss_x = loss_rec_y + loss_dyn_x\n",
    "\n",
    "        if epoch % 10 == 0 and i == 0:\n",
    "            print(f\"[Epoch {epoch}] loss_rec_y={loss_rec_y.item():.6f}\")\n",
    "\n",
    "\n",
    "        optimizer_x.zero_grad()\n",
    "        loss_x.backward()\n",
    "        optimizer_x.step()\n",
    "\n",
    "    next_x = X_var.detach()\n",
    "\n",
    "    return next_x\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2863dad3",
   "metadata": {},
   "source": [
    "## C and F: coefficient (C) update with F held fixed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea99bdef",
   "metadata": {},
   "outputs": [],
   "source": [
    "def decorrelation_loss(F_var: torch.Tensor, eps: float = 1e-8):\n",
    "\n",
    "    K = F_var.shape[0]\n",
    "    flat = F_var.reshape(K, -1) \n",
    "    flat = flat / (flat.norm(dim=1, keepdim=True).clamp_min(eps))\n",
    "\n",
    "    G = flat @ flat.t()\n",
    "\n",
    "    idx = torch.triu_indices(K, K, offset=1, device=F_var.device)\n",
    "    vals = (G[idx[0], idx[1]] ** 2)\n",
    "\n",
    "    return vals.sum()\n",
    "\n",
    "\n",
    "def update_c(X, C, F_list, num_iter, lr_c, lambda_sparse_c, lambda_smooth_c, epoch):\n",
    "\n",
    "    F_all = torch.stack(F_list, dim=0)\n",
    "\n",
    "    C_var = C.clone().detach().requires_grad_(True)\n",
    "    F_var = F_all.clone().detach().requires_grad_(False)\n",
    "\n",
    "    optimizer_c = torch.optim.Adam([C_var], lr=lr_c)\n",
    "\n",
    "    for i in range(num_iter):\n",
    "\n",
    "        F_t_all = torch.einsum('kt,kij->tij', C_var, F_var)\n",
    "        X_t1_pred = torch.einsum('tij,jt->it', F_t_all[:-1, :, :], X[:, :-1])\n",
    "\n",
    "        loss_dyn_c = F.mse_loss(X[:, 1:], X_t1_pred)\n",
    "\n",
    "        loss_sparse_c = lambda_sparse_c * torch.norm(C_var, p=1, dim=0).mean()\n",
    "\n",
    "        diff = C_var[:, 1:] - C_var[:, :-1]\n",
    "        loss_smooth_c = lambda_smooth_c * diff.abs().mean()\n",
    "\n",
    "        loss_c = loss_dyn_c + loss_sparse_c + loss_smooth_c\n",
    "\n",
    "        optimizer_c.zero_grad()\n",
    "        loss_c.backward()\n",
    "        optimizer_c.step()\n",
    "\n",
    "\n",
    "        if epoch % 10 == 0 and i == 0:\n",
    "            print(f\"[Epoch {epoch}] loss_dyn_c={loss_dyn_c.item():.6f}, \"\n",
    "                  f\"loss_sparse_c={loss_sparse_c.item():.6f}, \"\n",
    "                  f\"loss_smooth_c={loss_smooth_c.item():.6f}\"\n",
    "                  )\n",
    "\n",
    "        with torch.no_grad():\n",
    "            C_var.clamp_(min=0.0, max=1.0)\n",
    "            col_sums = C_var.sum(dim=0, keepdim=True)\n",
    "            col_sums = torch.clamp(col_sums, min=1e-8)\n",
    "            C_var.div_(col_sums)\n",
    "\n",
    "    C_var = C_var.detach()\n",
    "    \n",
    "    return C_var"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d4c3301",
   "metadata": {},
   "source": [
    "# Main: alternating optimization of X and C"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46ed014f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Epoch 0] loss_rec_y=0.239189\n",
      "[Epoch 10] loss_rec_y=0.014262\n",
      "[Epoch 0] loss_dyn_c=0.332270, loss_sparse_c=0.074576, loss_smooth_c=0.027128\n",
      "[Epoch 10] loss_dyn_c=0.010707, loss_sparse_c=0.050000, loss_smooth_c=0.014710\n",
      "[Epoch 0] loss_rec_y=0.006482\n",
      "[Epoch 0] loss_dyn_c=0.010657, loss_sparse_c=0.050000, loss_smooth_c=0.014869\n",
      "[Epoch 10] loss_rec_y=0.006482\n",
      "[Epoch 10] loss_dyn_c=0.010611, loss_sparse_c=0.050000, loss_smooth_c=0.014957\n",
      "[Epoch 20] loss_rec_y=0.006482\n",
      "[Epoch 20] loss_dyn_c=0.010597, loss_sparse_c=0.050000, loss_smooth_c=0.014923\n",
      "[Epoch 30] loss_rec_y=0.006482\n",
      "[Epoch 30] loss_dyn_c=0.010589, loss_sparse_c=0.050000, loss_smooth_c=0.014873\n",
      "[Epoch 40] loss_rec_y=0.006482\n",
      "[Epoch 40] loss_dyn_c=0.010577, loss_sparse_c=0.050000, loss_smooth_c=0.014813\n"
     ]
    }
   ],
   "source": [
    "def main(F_list, X, C, Y, a, epoch_num, warmup_num,\n",
    "         warmup_iter=20, joint_iter=2, lr_x=1e-2, lr_c=1e-2,\n",
    "         lambda_dyn_X=1e-4, lambda_sparse_c=0.05, lambda_smooth_c=0.08):\n",
    "    \"\"\"Alternately optimise latent states X and coefficients C (a and F_list stay fixed).\n",
    "\n",
    "    Schedule:\n",
    "      1. `warmup_num` epochs of X updates with the dynamics term disabled;\n",
    "      2. `warmup_num` epochs of C updates given that X;\n",
    "      3. `epoch_num` joint epochs, each one X update followed by one C update.\n",
    "\n",
    "    The keyword arguments expose the optimisation hyper-parameters: Adam steps per\n",
    "    update in the warm-up (`warmup_iter`) and joint (`joint_iter`) phases, the learning\n",
    "    rates, the dynamics weight used for X in the joint phase, and the C penalties.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "        Keys 'C', 'X', 'A', 'Y', 'F_list': the final estimates of C and X together\n",
    "        with the inputs that were not optimised.\n",
    "    \"\"\"\n",
    "    # --- warm-up 1: fit X to Y alone (dynamics term off) ---\n",
    "    for epoch in range(warmup_num):\n",
    "        X = update_x(Y, X, a, C, F_list, num_iter=warmup_iter, lr_x=lr_x,\n",
    "                     lambda_dyn_X=0.0, epoch=epoch)\n",
    "\n",
    "    # --- warm-up 2: fit C to the dynamics of the warmed-up X ---\n",
    "    for epoch in range(warmup_num):\n",
    "        C = update_c(X, C, F_list, num_iter=warmup_iter, lr_c=lr_c,\n",
    "                     lambda_sparse_c=lambda_sparse_c, lambda_smooth_c=lambda_smooth_c,\n",
    "                     epoch=epoch)\n",
    "\n",
    "    # --- joint optimization ---\n",
    "    for epoch in range(epoch_num):\n",
    "        X = update_x(Y, X, a, C, F_list, num_iter=joint_iter, lr_x=lr_x,\n",
    "                     lambda_dyn_X=lambda_dyn_X, epoch=epoch)\n",
    "        C = update_c(X, C, F_list, num_iter=joint_iter, lr_c=lr_c,\n",
    "                     lambda_sparse_c=lambda_sparse_c, lambda_smooth_c=lambda_smooth_c,\n",
    "                     epoch=epoch)\n",
    "\n",
    "    return {\n",
    "        'C': C,\n",
    "        'X': X,\n",
    "        'A': a,\n",
    "        'Y': Y,\n",
    "        'F_list': F_list,\n",
    "    }\n",
    "\n",
    "\n",
    "N = 21   # presumably the number of neurons; not used below\n",
    "p = 3    # latent dimension (rows of X)\n",
    "K = 3    # number of sub-circuits (rows of C)\n",
    "T = 500  # number of time steps (columns of X and C)\n",
    "\n",
    "# map_location puts every loaded tensor on `device`, next to the X and C created below;\n",
    "# without it tensors are restored onto the device they were saved from.\n",
    "data_true = torch.load(\"./data/Three_Task_Synthetic_Data_test.pt\", weights_only=True, map_location=device)\n",
    "data_a_f = torch.load(\"./data/Three_Task_Synthetic_Data_A_F.pt\", weights_only=True, map_location=device)\n",
    "\n",
    "Y = data_true['Y']\n",
    "\n",
    "F_list = data_a_f[\"F_list\"]\n",
    "a = data_a_f[\"A\"]\n",
    "\n",
    "# Fail fast with a clear message if the data disagree with the constants above.\n",
    "assert len(F_list) == K, f\"expected {K} dynamics matrices, got {len(F_list)}\"\n",
    "assert a.shape[1] == p, f\"A has {a.shape[1]} latent columns, expected p={p}\"\n",
    "assert Y.shape[1] == T, f\"Y has {Y.shape[1]} time steps, expected T={T}\"\n",
    "\n",
    "torch.manual_seed(0)\n",
    "C = torch.rand(K, T).to(device)\n",
    "X = torch.randn(p, T, requires_grad=False, device=device)\n",
    "X[:, 0] = torch.tensor([1.0, -1.0, 1.0], device=device)  # initial latent state (assumes p == 3)\n",
    "\n",
    "# ==== Run main ====\n",
    "data_est = main(F_list, X, C, Y, a, epoch_num=50, warmup_num=20)\n",
    "\n",
    "# visual_F_list(data_est[\"F_list\"])\n",
    "# visual_C(data_est[\"C\"])\n",
    "# visual_A(data_est[\"A\"])\n",
    "# visual_X(data_est[\"X\"])\n",
    "# visual_Y(data_est[\"A\"] @ data_est[\"X\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "206e0fbc",
   "metadata": {},
   "source": [
    "# Compute MSE against the ground truth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b7f2632",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss_Y': tensor(0.0065, device='cuda:4'), 'loss_X': tensor(0.0094, device='cuda:4'), 'loss_C': tensor(0.2890, device='cuda:4'), 'loss_F': tensor(0.0075, device='cuda:4'), 'loss_a': tensor(0.0041, device='cuda:4')}\n"
     ]
    }
   ],
   "source": [
    "def compute_losses(X, C, F_list, a, data_true):\n",
    "\n",
    "    true_X = data_true['X']\n",
    "    true_C = data_true['C']\n",
    "    true_F_list = data_true['F_list']\n",
    "    true_a = data_true['A']\n",
    "    true_Y = data_true['Y']\n",
    "\n",
    "    Y_hat = a @ X\n",
    "    loss_Y = torch.mean((Y_hat - true_Y) ** 2)\n",
    "\n",
    "    loss_X = torch.mean((X - true_X) ** 2)\n",
    "\n",
    "    loss_C = torch.mean((C - true_C) ** 2)\n",
    "\n",
    "    loss_F = 0.0\n",
    "    for f_est, f_true in zip(F_list, true_F_list):\n",
    "        loss_F += torch.mean((f_est - f_true) ** 2)\n",
    "\n",
    "    loss_a = torch.mean((a - true_a) ** 2)\n",
    "\n",
    "    return {\n",
    "        'loss_Y': loss_Y,\n",
    "        'loss_X': loss_X,\n",
    "        'loss_C': loss_C,\n",
    "        'loss_F': loss_F,\n",
    "        'loss_a': loss_a\n",
    "    }\n",
    "\n",
    "# Unpack the estimates returned by main() and score them against the ground truth.\n",
    "re_F_list, re_C, re_A, re_X = (data_est[key] for key in (\"F_list\", \"C\", \"A\", \"X\"))\n",
    "re_Y = re_A @ re_X\n",
    "\n",
    "results_mse = compute_losses(re_X, re_C, re_F_list, re_A, data_true)\n",
    "print(results_mse)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af49ad83",
   "metadata": {},
   "source": [
    "# Compute Pearson correlations against the ground truth (function `p`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fac1e4f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'corr_A': 0.9402853807650996, 'corr_F_list': [0.9992513474937725, 0.9930902077186291, 0.9909606149523752], 'corr_C': 0.2649108390692795, 'corr_X': 0.9934662950658272, 'corr_Y': 0.9669815197890067}\n"
     ]
    }
   ],
   "source": [
    "def p(X, C, F_list, a, data_true):\n",
    "\n",
    "    def corrcoef(a, b):\n",
    "        a_flat = a.flatten()\n",
    "        b_flat = b.flatten()\n",
    "        return np.corrcoef(a_flat, b_flat)[0, 1]\n",
    "\n",
    "    results = {}\n",
    "\n",
    "    Y_hat = a @ X\n",
    "\n",
    "    A_true = data_true['A'].cpu().numpy()\n",
    "    A_est = a.cpu().numpy()\n",
    "    results['corr_A'] = corrcoef(A_true, A_est)\n",
    "\n",
    "    corr_F = []\n",
    "    F_true_list = data_true['F_list']\n",
    "    F_est_list = F_list\n",
    "    for f_true, f_est in zip(F_true_list, F_est_list):\n",
    "        f_true = f_true.cpu().numpy()\n",
    "        f_est = f_est.cpu().numpy()\n",
    "        corr = corrcoef(f_true, f_est)\n",
    "        corr_F.append(corr)\n",
    "    results['corr_F_list'] = corr_F\n",
    "\n",
    "    C_true = data_true['C'].cpu().numpy()\n",
    "    C_est = C.cpu().numpy()\n",
    "    results['corr_C'] = corrcoef(C_true, C_est)\n",
    "\n",
    "    X_true = data_true['X'].cpu().numpy()\n",
    "    X_est = X.cpu().numpy()\n",
    "    results['corr_X'] = corrcoef(X_true, X_est)\n",
    "\n",
    "    Y_true = data_true['Y'].cpu().numpy()\n",
    "    Y_est = Y_hat.cpu().numpy()\n",
    "    results['corr_Y'] = corrcoef(Y_true, Y_est)\n",
    "\n",
    "    return results\n",
    "\n",
    "# Correlation of every estimate with its ground-truth counterpart.\n",
    "results_p = p(X=re_X, C=re_C, F_list=re_F_list, a=re_A, data_true=data_true)\n",
    "print(results_p)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
