{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "7ab9e907",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.ndimage import gaussian_filter1d\n",
    "from scipy.linalg import qr\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import random\n",
    "import scipy.ndimage\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "device = \"cuda:4\" if torch.cuda.is_available() else \"cpu\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab60bd7c",
   "metadata": {},
   "source": [
    "# visual function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "941da570",
   "metadata": {},
   "outputs": [],
   "source": [
    "def visual_F_list(F_list):\n",
    "\n",
    "    K = len(F_list)\n",
    "    plt.figure(figsize=(5 * K, 3))\n",
    "    for idx, f in enumerate(F_list):\n",
    "        plt.subplot(1, K, idx + 1)\n",
    "        plt.imshow(f.cpu(), cmap='coolwarm', interpolation='nearest')\n",
    "        plt.title(f\"F_{idx + 1}\")\n",
    "        plt.colorbar()\n",
    "        plt.xlabel(\"Latent dim\")\n",
    "        plt.ylabel(\"Latent dim\")\n",
    "\n",
    "        p = f.shape[0]\n",
    "        ticks = list(range(p))\n",
    "        labels = list(range(1, p + 1))\n",
    "        plt.xticks(ticks, labels)\n",
    "        plt.yticks(ticks, labels)\n",
    "        \n",
    "    plt.suptitle(\"Sub-circuit dynamic matrices F_list\")\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def visual_C(C):\n",
    "\n",
    "    plt.figure(figsize=(10,3))\n",
    "    for k in range(C.shape[0]):\n",
    "        plt.plot(C[k].cpu().numpy(), label=f'c_{k+1}')\n",
    "    plt.title(\"Sub-circuit coefficients C\")\n",
    "    plt.xlabel(\"Time\")\n",
    "    plt.legend()\n",
    "    plt.show()\n",
    "\n",
    "def visual_A(data):\n",
    "\n",
    "    plt.figure(figsize=(6, 5))\n",
    "    plt.imshow(data.detach().cpu().numpy(), aspect='auto', origin='lower', cmap='viridis')\n",
    "    plt.colorbar(label='Projection weight')\n",
    "    plt.title(\"Projection matrix A\")\n",
    "    plt.xlabel(\"Latent dimension (p)\")\n",
    "    plt.ylabel(\"Neuron #\")\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def visual_X(data):\n",
    "\n",
    "    X = data.detach().cpu().numpy()\n",
    "    plt.figure(figsize=(10,3))\n",
    "    for dim in range(X.shape[0]):\n",
    "        plt.plot(X[dim], label=f'x_{dim+1}')\n",
    "    plt.title(\"Latent dynamics X\")\n",
    "    plt.xlabel(\"Time\")\n",
    "    plt.legend()\n",
    "    plt.show()\n",
    "\n",
    "def visual_Y(data):\n",
    "    Y = data.detach().cpu().numpy()\n",
    "\n",
    "    plt.figure(figsize=(10,5))\n",
    "    plt.imshow(Y, aspect='auto', origin='lower', cmap='viridis')\n",
    "    plt.colorbar(label='Firing rate')\n",
    "    plt.title(\"Synthetic observations Y\")\n",
    "    plt.xlabel(\"Time\")\n",
    "    plt.ylabel(\"Neuron #\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5aa98352",
   "metadata": {},
   "source": [
    "# optimal"
   ]
  },
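  {
   "cell_type": "markdown",
   "id": "1a2b3c4d",
   "metadata": {},
   "source": [
    "The two subsections below alternate block-coordinate updates over the factors of the model. Read off the code (this summary is inferred from the implementation, not from a separate specification), the model is\n",
    "\n",
    "$$Y \\approx A X, \\qquad x_{t+1} \\approx F(t)\\,x_t, \\qquad F(t) = \\sum_{k=1}^{K} c_k(t)\\,F_k,$$\n",
    "\n",
    "where $Y \\in \\mathbb{R}^{N \\times T}$ holds the observed activity, $A \\in \\mathbb{R}^{N \\times p}$ is a nonnegative projection with unit-norm columns, $X \\in \\mathbb{R}^{p \\times T}$ holds the latent trajectories, each $F_k \\in \\mathbb{R}^{p \\times p}$ is a sub-circuit dynamics matrix, and $C \\in \\mathbb{R}^{K \\times T}$ holds nonnegative mixing coefficients that sum to one at every time step. `update_a_and_x` fits $A$ and $X$; `update_c_and_f` fits $C$ and the $F_k$."
   ]
  },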
  {
   "cell_type": "markdown",
   "id": "8cdb5384",
   "metadata": {},
   "source": [
    "## a and x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bfed7de",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_h(Y):\n",
    "\n",
    "    Y_norm = Y / (Y.norm(dim=1, keepdim=True) + 1e-8)\n",
    "    h = Y_norm @ Y_norm.T\n",
    "    h = (h + 1.0) / 2.0\n",
    "    return h\n",
    "\n",
    "\n",
    "def similarity_loss(a, h):\n",
    "    D = torch.diag(h.sum(dim=1))\n",
    "    L = D - h\n",
    "    loss = torch.trace(a.T @ L @ a)\n",
    "    return loss\n",
    "\n",
    "\n",
    "def update_a_and_x(Y, X, a, C, F_list, num_iter, lr_a, lr_x, lambda_sparse_A, lambda_sim_A, lambda_dyn_X, epoch):\n",
    "\n",
    "    N, T = Y.shape\n",
    "    p = X.shape[0]\n",
    "\n",
    "    h = compute_h(Y)\n",
    "\n",
    "    A_var = a.clone().detach().requires_grad_(True)\n",
    "    X_var = X.clone().detach().requires_grad_(False)\n",
    "\n",
    "    optimizer_a = torch.optim.Adam([A_var], lr=lr_a)\n",
    "\n",
    "    for i in range(num_iter):\n",
    "\n",
    "        Y_hat = A_var @ X_var\n",
    "        loss_mse = F.mse_loss(Y, Y_hat)\n",
    "\n",
    "        loss_sparse = lambda_sparse_A * torch.norm(A_var, p=1)\n",
    "\n",
    "        loss_sim = lambda_sim_A * similarity_loss(A_var, h)\n",
    "\n",
    "        loss_a = loss_mse + loss_sparse + loss_sim\n",
    "\n",
    "        if epoch % 10 == 0 and i == 0:\n",
    "            print(f\"[Epoch {epoch}] \"\n",
    "                  f\"loss_rec_y={loss_mse.item():.6f}, \"\n",
    "                  f\"loss_sparse_a={loss_sparse.item():.6f}, \"\n",
    "                  f\"loss_sim_a={loss_sim.item():.6f}\")\n",
    "\n",
    "        optimizer_a.zero_grad()\n",
    "        loss_a.backward()\n",
    "        optimizer_a.step()\n",
    "\n",
    "        with torch.no_grad():\n",
    "            A_var.clamp_(min=0.0, max=1.0)\n",
    "            col_norms = torch.norm(A_var, p=2, dim=0, keepdim=True)\n",
    "            col_norms = torch.clamp(col_norms, min=1e-8)\n",
    "            A_var.div_(col_norms)\n",
    "    \n",
    "    next_a = A_var.detach()\n",
    "\n",
    "\n",
    "    A_var = next_a.clone().detach().requires_grad_(False)\n",
    "    X_var = X.clone().detach().requires_grad_(True)\n",
    "    C_var = C.clone().detach().requires_grad_(False)\n",
    "    \n",
    "    optimizer_x = torch.optim.Adam([X_var], lr=lr_x)\n",
    "\n",
    "    F_all = torch.stack(F_list, dim=0).clone().detach().requires_grad_(False)\n",
    "\n",
    "    for i in range(num_iter):\n",
    "\n",
    "        Y_hat = A_var @ X_var\n",
    "        loss_rec_y = F.mse_loss(Y, Y_hat)\n",
    "\n",
    "        F_t_all = torch.einsum('kt,kij->tij', C_var, F_all)\n",
    "        X_t1_pred = torch.einsum('tij,jt->it', F_t_all[:-1, :, :], X[:, :-1])                   \n",
    "        loss_dyn_x = lambda_dyn_X * F.mse_loss(X_var[:, 1:], X_t1_pred)\n",
    "\n",
    "        loss_x = loss_rec_y + loss_dyn_x\n",
    "\n",
    "        if epoch % 10 == 0 and i == 0:\n",
    "            print(f\"[Epoch {epoch}] loss_rec_y={loss_rec_y.item():.6f}\")\n",
    "\n",
    "        optimizer_x.zero_grad()\n",
    "        loss_x.backward()\n",
    "        optimizer_x.step()\n",
    "\n",
    "    next_x = X_var.detach()\n",
    "\n",
    "    return next_a, next_x\n"
   ]
  },
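  {
   "cell_type": "markdown",
   "id": "2b3c4d5e",
   "metadata": {},
   "source": [
    "A quick sanity check of the similarity regularizer defined above (illustrative only; the toy tensors `_Y_toy`, `_a_same`, `_a_diff` are made up for this check and are not part of the pipeline). Two neurons with proportional time series get similarity 1 in `compute_h`, so `similarity_loss` is smaller when their rows of `a` agree than when they differ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c4d5e6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy check of compute_h / similarity_loss (not part of the fitting pipeline).\n",
    "_Y_toy = torch.tensor([[1.0, 2.0, 3.0],\n",
    "                       [2.0, 4.0, 6.0],\n",
    "                       [1.0, 0.0, -1.0]], device=device)\n",
    "_h_toy = compute_h(_Y_toy)  # pairwise cosine similarities mapped to [0, 1]\n",
    "\n",
    "# Rows of `a` agree for the two proportional neurons -> lower Laplacian penalty.\n",
    "_a_same = torch.tensor([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]], device=device)\n",
    "# Rows disagree for the two proportional neurons -> higher penalty.\n",
    "_a_diff = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]], device=device)\n",
    "\n",
    "print(similarity_loss(_a_same, _h_toy).item(), similarity_loss(_a_diff, _h_toy).item())"
   ]
  },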
  {
   "cell_type": "markdown",
   "id": "2863dad3",
   "metadata": {},
   "source": [
    "## c and f"
   ]
  },
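  {
   "cell_type": "markdown",
   "id": "4d5e6f7a",
   "metadata": {},
   "source": [
    "As implemented in the next cell (again summarized loosely from the code, with the exact mean reductions left implicit), the $C$ step minimizes\n",
    "\n",
    "$$\\operatorname{mean}_t\\,\\|x_{t+1} - F(t)\\,x_t\\|^2 \\;+\\; \\lambda_{\\mathrm{sparse}}\\operatorname{mean}_t\\|c(t)\\|_1 \\;+\\; \\lambda_{\\mathrm{smooth}}\\operatorname{mean}_t\\|c(t{+}1)-c(t)\\|_1,$$\n",
    "\n",
    "followed by a projection of $C$ onto nonnegative columns that sum to one. The $F$ step reuses the dynamics term and adds $\\ell_1$ sparsity plus the pairwise `decorrelation_loss`, then rescales each $F_k$ by its spectral radius."
   ]
  },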
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea99bdef",
   "metadata": {},
   "outputs": [],
   "source": [
    "def decorrelation_loss(F_var: torch.Tensor, eps: float = 1e-8,):\n",
    "\n",
    "    K = F_var.shape[0]\n",
    "    flat = F_var.reshape(K, -1)                \n",
    "    flat = flat / (flat.norm(dim=1, keepdim=True).clamp_min(eps))\n",
    "\n",
    "    G = flat @ flat.t()\n",
    "\n",
    "    idx = torch.triu_indices(K, K, offset=1, device=F_var.device)\n",
    "    vals = (G[idx[0], idx[1]] ** 2)\n",
    "\n",
    "    return vals.sum()\n",
    "\n",
    "\n",
    "def update_c_and_f(X, C, F_list, num_iter, lr_c, lr_f, lambda_sparse_c, lambda_smooth_c, lambda_sparse_f, lambda_decor_f, epoch):\n",
    "    p, _ = X.shape\n",
    "    K, T = C.shape\n",
    "    \n",
    "    F_all = torch.stack(F_list, dim=0)\n",
    "\n",
    "    C_var = C.clone().detach().requires_grad_(True)\n",
    "    F_var = F_all.clone().detach().requires_grad_(False)\n",
    "\n",
    "    optimizer_c = torch.optim.Adam([C_var], lr=lr_c)\n",
    "\n",
    "    for i in range(num_iter):\n",
    "\n",
    "        # compute dynamics loss\n",
    "        F_t_all = torch.einsum('kt,kij->tij', C_var, F_var)\n",
    "        X_t1_pred = torch.einsum('tij,jt->it', F_t_all[:-1, :, :], X[:, :-1])\n",
    "\n",
    "        loss_dyn_c = F.mse_loss(X[:, 1:], X_t1_pred)\n",
    "\n",
    "        loss_sparse_c = lambda_sparse_c * torch.norm(C_var, p=1, dim=0).mean()\n",
    "\n",
    "        diff = C_var[:, 1:] - C_var[:, :-1]  # [K, T-1]\n",
    "        loss_smooth_c = lambda_smooth_c * diff.abs().mean()\n",
    "\n",
    "        loss_c = loss_dyn_c + loss_sparse_c + loss_smooth_c\n",
    "\n",
    "        optimizer_c.zero_grad()\n",
    "        loss_c.backward()\n",
    "        optimizer_c.step()\n",
    "\n",
    "        if epoch % 10 == 0 and i == 0:\n",
    "            print(f\"[Epoch {epoch}] loss_dyn_c={loss_dyn_c.item():.6f}, \"\n",
    "                  f\"loss_sparse_c={loss_sparse_c.item():.6f}, \"\n",
    "                  f\"loss_smooth_c={loss_smooth_c.item():.6f}\"\n",
    "                  )\n",
    "\n",
    "        with torch.no_grad():\n",
    "            C_var.clamp_(min=0.0, max=1.0)\n",
    "            col_sums = C_var.sum(dim=0, keepdim=True)  # [1, T]\n",
    "            col_sums = torch.clamp(col_sums, min=1e-8)\n",
    "            C_var.div_(col_sums)\n",
    "\n",
    "    C_opt = C_var.detach()\n",
    "\n",
    "\n",
    "    C_var = C_opt.clone().detach().requires_grad_(False)\n",
    "    F_var = F_all.clone().detach().requires_grad_(True)\n",
    "\n",
    "    optimizer_f = torch.optim.Adam([F_var], lr=lr_f)\n",
    "\n",
    "    for i in range(num_iter):\n",
    "\n",
    "        F_t_all = torch.einsum('kt,kij->tij', C_var, F_var)\n",
    "        X_t1_pred = torch.einsum('tij,jt->it', F_t_all[:-1, :, :], X[:, :-1])\n",
    "\n",
    "        loss_dyn_f = F.mse_loss(X[:, 1:], X_t1_pred)\n",
    "\n",
    "        loss_sparse_f = lambda_sparse_f * torch.norm(F_var, p=1)\n",
    "\n",
    "        loss_decor_f = lambda_decor_f * decorrelation_loss(F_var)\n",
    "\n",
    "        if epoch % 10 == 0 and i == 0:\n",
    "            print(f\"[Epoch {epoch}] loss_dyn_f={loss_dyn_f.item():.6f}, \"\n",
    "                  f\"loss_sparse_f={loss_sparse_f.item():.6f}, \"\n",
    "                  f\"loss_decor_f={loss_decor_f.item():.6f}\")\n",
    "\n",
    "        loss_f = loss_dyn_f + loss_sparse_f + loss_decor_f\n",
    "\n",
    "        optimizer_f.zero_grad()\n",
    "        loss_f.backward()\n",
    "        optimizer_f.step()\n",
    "\n",
    "        with torch.no_grad():\n",
    "            for k in range(K):\n",
    "                eigvals = torch.linalg.eigvals(F_var[k])\n",
    "                max_abs_eig = eigvals.abs().max()\n",
    "                F_var[k] /= max_abs_eig\n",
    "\n",
    "    F_opt = [F_var[i].detach() for i in range(K)]\n",
    "\n",
    "    return C_opt, F_opt\n"
   ]
  },
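  {
   "cell_type": "markdown",
   "id": "5e6f7a8b",
   "metadata": {},
   "source": [
    "The two `einsum` calls used in the dynamics terms above build the time-varying matrix $F(t)=\\sum_k c_k(t)F_k$ and apply it to $x_t$. The small check below (toy shapes `_K`, `_p`, `_T` are arbitrary) confirms they match an explicit per-time-step loop."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f7a8b9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy check that the einsum-based one-step prediction equals an explicit loop.\n",
    "_K, _p, _T = 2, 3, 5\n",
    "_C = torch.rand(_K, _T, device=device)\n",
    "_F = torch.randn(_K, _p, _p, device=device)\n",
    "_X = torch.randn(_p, _T, device=device)\n",
    "\n",
    "_F_t = torch.einsum('kt,kij->tij', _C, _F)                 # F(t), shape [T, p, p]\n",
    "_pred = torch.einsum('tij,jt->it', _F_t[:-1], _X[:, :-1])  # F(t) x_t, shape [p, T-1]\n",
    "_loop = torch.stack([_F_t[t] @ _X[:, t] for t in range(_T - 1)], dim=1)\n",
    "\n",
    "print(torch.allclose(_pred, _loop, atol=1e-5))"
   ]
  },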
  {
   "cell_type": "markdown",
   "id": "3d4c3301",
   "metadata": {},
   "source": [
    "# main"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46ed014f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Epoch 0] loss_rec_y=1.184915, loss_sparse_a=0.160320, loss_sim_a=0.071039\n",
      "[Epoch 0] loss_rec_y=0.243432\n",
      "[Epoch 10] loss_rec_y=0.045665, loss_sparse_a=0.057226, loss_sim_a=0.012231\n",
      "[Epoch 10] loss_rec_y=0.042825\n",
      "[Epoch 20] loss_rec_y=0.016001, loss_sparse_a=0.055982, loss_sim_a=0.012611\n",
      "[Epoch 20] loss_rec_y=0.015095\n",
      "[Epoch 30] loss_rec_y=0.014100, loss_sparse_a=0.051011, loss_sim_a=0.016797\n",
      "[Epoch 30] loss_rec_y=0.014145\n",
      "[Epoch 40] loss_rec_y=0.014407, loss_sparse_a=0.043628, loss_sim_a=0.022128\n",
      "[Epoch 40] loss_rec_y=0.014570\n",
      "[Epoch 50] loss_rec_y=0.015362, loss_sparse_a=0.038035, loss_sim_a=0.025108\n",
      "[Epoch 50] loss_rec_y=0.015499\n",
      "[Epoch 60] loss_rec_y=0.015899, loss_sparse_a=0.033725, loss_sim_a=0.026966\n",
      "[Epoch 60] loss_rec_y=0.015801\n",
      "[Epoch 70] loss_rec_y=0.015524, loss_sparse_a=0.032215, loss_sim_a=0.027406\n",
      "[Epoch 70] loss_rec_y=0.015576\n",
      "[Epoch 80] loss_rec_y=0.015867, loss_sparse_a=0.031622, loss_sim_a=0.027772\n",
      "[Epoch 80] loss_rec_y=0.015878\n",
      "[Epoch 90] loss_rec_y=0.016094, loss_sparse_a=0.030918, loss_sim_a=0.028176\n",
      "[Epoch 90] loss_rec_y=0.016070\n",
      "[Epoch 100] loss_rec_y=0.016083, loss_sparse_a=0.030921, loss_sim_a=0.028174\n",
      "[Epoch 100] loss_rec_y=0.016077\n",
      "[Epoch 110] loss_rec_y=0.016082, loss_sparse_a=0.030917, loss_sim_a=0.028178\n",
      "[Epoch 110] loss_rec_y=0.016095\n",
      "[Epoch 120] loss_rec_y=0.016076, loss_sparse_a=0.030923, loss_sim_a=0.028172\n",
      "[Epoch 120] loss_rec_y=0.016079\n",
      "[Epoch 130] loss_rec_y=0.016096, loss_sparse_a=0.030917, loss_sim_a=0.028178\n",
      "[Epoch 130] loss_rec_y=0.016065\n",
      "[Epoch 140] loss_rec_y=0.016067, loss_sparse_a=0.030927, loss_sim_a=0.028168\n",
      "[Epoch 140] loss_rec_y=0.016080\n",
      "[Epoch 150] loss_rec_y=0.016100, loss_sparse_a=0.030916, loss_sim_a=0.028179\n",
      "[Epoch 150] loss_rec_y=0.016079\n",
      "[Epoch 160] loss_rec_y=0.016094, loss_sparse_a=0.030917, loss_sim_a=0.028177\n",
      "[Epoch 160] loss_rec_y=0.016059\n",
      "[Epoch 170] loss_rec_y=0.016096, loss_sparse_a=0.030917, loss_sim_a=0.028177\n",
      "[Epoch 170] loss_rec_y=0.016070\n",
      "[Epoch 180] loss_rec_y=0.016087, loss_sparse_a=0.030918, loss_sim_a=0.028177\n",
      "[Epoch 180] loss_rec_y=0.016075\n",
      "[Epoch 190] loss_rec_y=0.016081, loss_sparse_a=0.030920, loss_sim_a=0.028175\n",
      "[Epoch 190] loss_rec_y=0.016078\n",
      "[Epoch 0] loss_dyn_c=0.287023, loss_sparse_c=0.074759, loss_smooth_c=0.027236\n",
      "[Epoch 0] loss_dyn_f=0.013466, loss_sparse_f=0.010421, loss_decor_f=0.002814\n",
      "[Epoch 10] loss_dyn_c=0.003792, loss_sparse_c=0.050000, loss_smooth_c=0.003146\n",
      "[Epoch 10] loss_dyn_f=0.003775, loss_sparse_f=0.010184, loss_decor_f=0.002802\n",
      "[Epoch 20] loss_dyn_c=0.003728, loss_sparse_c=0.050000, loss_smooth_c=0.002766\n",
      "[Epoch 20] loss_dyn_f=0.003722, loss_sparse_f=0.010113, loss_decor_f=0.002858\n",
      "[Epoch 30] loss_dyn_c=0.003624, loss_sparse_c=0.050000, loss_smooth_c=0.002729\n",
      "[Epoch 30] loss_dyn_f=0.003624, loss_sparse_f=0.010069, loss_decor_f=0.002888\n",
      "[Epoch 40] loss_dyn_c=0.003571, loss_sparse_c=0.050000, loss_smooth_c=0.002677\n",
      "[Epoch 40] loss_dyn_f=0.003576, loss_sparse_f=0.010039, loss_decor_f=0.002905\n",
      "[Epoch 50] loss_dyn_c=0.003118, loss_sparse_c=0.050000, loss_smooth_c=0.002688\n",
      "[Epoch 50] loss_dyn_f=0.003117, loss_sparse_f=0.010153, loss_decor_f=0.002917\n",
      "[Epoch 60] loss_dyn_c=0.003130, loss_sparse_c=0.050000, loss_smooth_c=0.002690\n",
      "[Epoch 60] loss_dyn_f=0.003130, loss_sparse_f=0.010145, loss_decor_f=0.002916\n",
      "[Epoch 70] loss_dyn_c=0.003144, loss_sparse_c=0.050000, loss_smooth_c=0.002690\n",
      "[Epoch 70] loss_dyn_f=0.003145, loss_sparse_f=0.010137, loss_decor_f=0.002914\n",
      "[Epoch 80] loss_dyn_c=0.003155, loss_sparse_c=0.050000, loss_smooth_c=0.002692\n",
      "[Epoch 80] loss_dyn_f=0.003156, loss_sparse_f=0.010130, loss_decor_f=0.002913\n",
      "[Epoch 90] loss_dyn_c=0.003158, loss_sparse_c=0.050000, loss_smooth_c=0.002689\n",
      "[Epoch 90] loss_dyn_f=0.003158, loss_sparse_f=0.010129, loss_decor_f=0.002913\n",
      "[Epoch 100] loss_dyn_c=0.003160, loss_sparse_c=0.050000, loss_smooth_c=0.002690\n",
      "[Epoch 100] loss_dyn_f=0.003160, loss_sparse_f=0.010127, loss_decor_f=0.002913\n",
      "[Epoch 110] loss_dyn_c=0.003161, loss_sparse_c=0.050000, loss_smooth_c=0.002692\n",
      "[Epoch 110] loss_dyn_f=0.003160, loss_sparse_f=0.010126, loss_decor_f=0.002913\n",
      "[Epoch 120] loss_dyn_c=0.003159, loss_sparse_c=0.050000, loss_smooth_c=0.002688\n",
      "[Epoch 120] loss_dyn_f=0.003159, loss_sparse_f=0.010127, loss_decor_f=0.002913\n",
      "[Epoch 130] loss_dyn_c=0.003160, loss_sparse_c=0.050000, loss_smooth_c=0.002689\n",
      "[Epoch 130] loss_dyn_f=0.003160, loss_sparse_f=0.010128, loss_decor_f=0.002913\n",
      "[Epoch 140] loss_dyn_c=0.003158, loss_sparse_c=0.050000, loss_smooth_c=0.002691\n",
      "[Epoch 140] loss_dyn_f=0.003157, loss_sparse_f=0.010129, loss_decor_f=0.002913\n",
      "[Epoch 150] loss_dyn_c=0.003159, loss_sparse_c=0.050000, loss_smooth_c=0.002691\n",
      "[Epoch 150] loss_dyn_f=0.003159, loss_sparse_f=0.010127, loss_decor_f=0.002913\n",
      "[Epoch 160] loss_dyn_c=0.003161, loss_sparse_c=0.050000, loss_smooth_c=0.002687\n",
      "[Epoch 160] loss_dyn_f=0.003159, loss_sparse_f=0.010127, loss_decor_f=0.002913\n",
      "[Epoch 170] loss_dyn_c=0.003159, loss_sparse_c=0.050000, loss_smooth_c=0.002691\n",
      "[Epoch 170] loss_dyn_f=0.003159, loss_sparse_f=0.010128, loss_decor_f=0.002913\n",
      "[Epoch 180] loss_dyn_c=0.003161, loss_sparse_c=0.050000, loss_smooth_c=0.002691\n",
      "[Epoch 180] loss_dyn_f=0.003162, loss_sparse_f=0.010127, loss_decor_f=0.002913\n",
      "[Epoch 190] loss_dyn_c=0.003158, loss_sparse_c=0.050000, loss_smooth_c=0.002692\n",
      "[Epoch 190] loss_dyn_f=0.003158, loss_sparse_f=0.010127, loss_decor_f=0.002913\n",
      "[Epoch 0] loss_rec_y=0.016087, loss_sparse_a=0.000000, loss_sim_a=0.000000\n",
      "[Epoch 0] loss_rec_y=0.015789\n",
      "[Epoch 0] loss_dyn_c=0.003160, loss_sparse_c=0.050000, loss_smooth_c=0.002688\n",
      "[Epoch 0] loss_dyn_f=0.003153, loss_sparse_f=0.010126, loss_decor_f=0.002913\n",
      "[Epoch 10] loss_rec_y=0.013235, loss_sparse_a=0.000000, loss_sim_a=0.000000\n",
      "[Epoch 10] loss_rec_y=0.012968\n",
      "[Epoch 10] loss_dyn_c=0.003172, loss_sparse_c=0.050000, loss_smooth_c=0.003048\n",
      "[Epoch 10] loss_dyn_f=0.003175, loss_sparse_f=0.010137, loss_decor_f=0.002913\n",
      "[Epoch 20] loss_rec_y=0.010713, loss_sparse_a=0.000000, loss_sim_a=0.000000\n",
      "[Epoch 20] loss_rec_y=0.010480\n",
      "[Epoch 20] loss_dyn_c=0.003202, loss_sparse_c=0.050000, loss_smooth_c=0.003050\n",
      "[Epoch 20] loss_dyn_f=0.003205, loss_sparse_f=0.010144, loss_decor_f=0.002915\n",
      "[Epoch 30] loss_rec_y=0.008537, loss_sparse_a=0.000000, loss_sim_a=0.000000\n",
      "[Epoch 30] loss_rec_y=0.008339\n",
      "[Epoch 30] loss_dyn_c=0.003241, loss_sparse_c=0.050000, loss_smooth_c=0.003049\n",
      "[Epoch 30] loss_dyn_f=0.003241, loss_sparse_f=0.010157, loss_decor_f=0.002916\n",
      "[Epoch 40] loss_rec_y=0.006710, loss_sparse_a=0.000000, loss_sim_a=0.000000\n",
      "[Epoch 40] loss_rec_y=0.006546\n",
      "[Epoch 40] loss_dyn_c=0.003294, loss_sparse_c=0.050000, loss_smooth_c=0.003055\n",
      "[Epoch 40] loss_dyn_f=0.003293, loss_sparse_f=0.010172, loss_decor_f=0.002917\n"
     ]
    }
   ],
   "source": [
    "def main(Y, N, K, p, T, epoch_num, warmup_num):\n",
    "\n",
    "    torch.manual_seed(0)\n",
    "    F_list = [torch.eye(p).to(device) + 0.1 * torch.randn(p, p).to(device) for _ in range(K)]\n",
    "    C = torch.rand(K, T).to(device)\n",
    "    a = torch.rand(N, p, requires_grad=False, device=device)\n",
    "    X = torch.randn(p, T, requires_grad=False, device=device)\n",
    "    X[:, 0] = torch.tensor([1, -1, 1])\n",
    "\n",
    "    for epoch in range(warmup_num):\n",
    "        a, X = update_a_and_x(Y, X, a, C, F_list, num_iter=20, lr_a=1e-3, lr_x=1e-2, lambda_sparse_A=0.005, lambda_sim_A=0.001, lambda_dyn_X=0.0, epoch=epoch)\n",
    "\n",
    "    for epoch in range(warmup_num):\n",
    "        C, F_list = update_c_and_f(X, C, F_list, num_iter=20, lr_c=1e-2, lr_f=1e-3, lambda_sparse_c=0.05, lambda_smooth_c=0.08, lambda_sparse_f=0.001, lambda_decor_f=0.001, epoch=epoch)\n",
    "            \n",
    "\n",
    "    # --- joint optimization ---\n",
    "    for epoch in range(epoch_num):\n",
    "\n",
    "        a, X = update_a_and_x(Y, X, a, C, F_list, num_iter=2, lr_a=1e-3, lr_x=1e-2, lambda_sparse_A=0.0, lambda_sim_A=0.0, lambda_dyn_X=0.0001, epoch=epoch)\n",
    "\n",
    "        C, F_list = update_c_and_f(X, C, F_list, num_iter=2, lr_c=1e-2, lr_f=1e-3, lambda_sparse_c=0.05, lambda_smooth_c=0.08, lambda_sparse_f=0.001, lambda_decor_f=0.001, epoch=epoch)\n",
    "\n",
    "\n",
    "    data = {\n",
    "        'C': C,\n",
    "        'X': X,\n",
    "        'A': a,\n",
    "        'Y': Y,\n",
    "        'F_list': F_list\n",
    "    }\n",
    "    return data\n",
    "\n",
    "\n",
    "N = 21         \n",
    "p = 3           \n",
    "K = 3           \n",
    "T = 500     \n",
    "\n",
    "data_true = torch.load(\"./data/Three_Task_Synthetic_Data_train.pt\", weights_only=True)\n",
    "\n",
    "Y = data_true['Y']\n",
    "\n",
    "data_est = main(Y, N, p, K, T, epoch_num=50, warmup_num=200)\n",
    "\n",
    "# visual_F_list(data_est[\"F_list\"])\n",
    "# visual_C(data_est[\"C\"])\n",
    "# visual_A(data_est[\"A\"])\n",
    "# visual_X(data_est[\"X\"])\n",
    "# visual_Y(data_est[\"A\"] @ data_est[\"X\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "206e0fbc",
   "metadata": {},
   "source": [
    "# compute MSE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b7f2632",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss_Y': tensor(0.0052, device='cuda:4'), 'loss_X': tensor(0.0095, device='cuda:4'), 'loss_C': tensor(0.6071, device='cuda:4'), 'loss_F': tensor(0.0075, device='cuda:4'), 'loss_a': tensor(0.0041, device='cuda:4')}\n"
     ]
    }
   ],
   "source": [
    "def compute_losses(X, C, F_list, a, data_true):\n",
    "\n",
    "    true_X = data_true['X']\n",
    "    true_C = data_true['C']\n",
    "    true_F_list = data_true['F_list']\n",
    "    true_a = data_true['A']\n",
    "    true_Y = data_true['Y']\n",
    "\n",
    "    Y_hat = a @ X\n",
    "    loss_Y = torch.mean((Y_hat - true_Y) ** 2)\n",
    "\n",
    "    loss_X = torch.mean((X - true_X) ** 2)\n",
    "\n",
    "    loss_C = torch.mean((C - true_C) ** 2)\n",
    "\n",
    "    loss_F = 0.0\n",
    "    for f_est, f_true in zip(F_list, true_F_list):\n",
    "        loss_F += torch.mean((f_est - f_true) ** 2)\n",
    "\n",
    "    loss_a = torch.mean((a - true_a) ** 2)\n",
    "\n",
    "    return {\n",
    "        'loss_Y': loss_Y,\n",
    "        'loss_X': loss_X,\n",
    "        'loss_C': loss_C,\n",
    "        'loss_F': loss_F,\n",
    "        'loss_a': loss_a\n",
    "    }\n",
    "\n",
    "def save_data(re_F_list, re_A):\n",
    "    data = {\n",
    "        'F_list': re_F_list,\n",
    "        'A': re_A,\n",
    "    }\n",
    "    torch.save(data, \"./data/Three_Task_Synthetic_Data_A_F.pt\")\n",
    "\n",
    "\n",
    "re_F_list = data_est[\"F_list\"]\n",
    "re_C = data_est[\"C\"]\n",
    "re_A = data_est[\"A\"]\n",
    "re_X = data_est[\"X\"]\n",
    "re_Y = data_est[\"A\"] @ data_est[\"X\"]\n",
    "\n",
    "order = torch.tensor([1, 2, 0], device=re_A.device)\n",
    "re_A = re_A[:, order]\n",
    "re_X = re_X[order, :]\n",
    "\n",
    "save_data(re_F_list, re_A)\n",
    "\n",
    "results_mse = compute_losses(re_X, re_C, re_F_list, re_A, data_true)\n",
    "print(results_mse)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af49ad83",
   "metadata": {},
   "source": [
    "# compute_p"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "2fac1e4f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'corr_A': 0.9402853807650996, 'corr_F_list': [0.9992513474937725, 0.9930902077186291, 0.9909606149523752], 'corr_C': -0.4778150230736332, 'corr_X': 0.9925176803867974, 'corr_Y': 0.9710753126681247}\n"
     ]
    }
   ],
   "source": [
    "def p(X, C, F_list, a, data_true):\n",
    "\n",
    "    def corrcoef(a, b):\n",
    "        a_flat = a.flatten()\n",
    "        b_flat = b.flatten()\n",
    "        return np.corrcoef(a_flat, b_flat)[0, 1]\n",
    "\n",
    "    results = {}\n",
    "\n",
    "    Y_hat = a @ X\n",
    "\n",
    "    A_true = data_true['A'].cpu().numpy()\n",
    "    A_est = a.cpu().numpy()\n",
    "    results['corr_A'] = corrcoef(A_true, A_est)\n",
    "\n",
    "    corr_F = []\n",
    "    F_true_list = data_true['F_list']\n",
    "    F_est_list = F_list\n",
    "    for f_true, f_est in zip(F_true_list, F_est_list):\n",
    "        f_true = f_true.cpu().numpy()\n",
    "        f_est = f_est.cpu().numpy()\n",
    "        corr = corrcoef(f_true, f_est)\n",
    "        corr_F.append(corr)\n",
    "    results['corr_F_list'] = corr_F\n",
    "\n",
    "    C_true = data_true['C'].cpu().numpy()\n",
    "    C_est = C.cpu().numpy()\n",
    "    results['corr_C'] = corrcoef(C_true, C_est)\n",
    "\n",
    "    X_true = data_true['X'].cpu().numpy()\n",
    "    X_est = X.cpu().numpy()\n",
    "    results['corr_X'] = corrcoef(X_true, X_est)\n",
    "\n",
    "    Y_true = data_true['Y'].cpu().numpy()\n",
    "    Y_est = Y_hat.cpu().numpy()\n",
    "    results['corr_Y'] = corrcoef(Y_true, Y_est)\n",
    "\n",
    "    return results\n",
    "\n",
    "results_p = p(re_X, re_C, re_F_list, re_A, data_true)\n",
    "print(results_p)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
