{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52b37b5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from pathlib import Path\n",
    "import sys\n",
    "sys.path.append(\"../../\")\n",
    "import json\n",
    "import os\n",
    "\n",
    "from src.kooporch import full_to_reduced_system, reduced_to_full_system"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f75a7b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Utility functions to convert between numpy and torch\n",
    "\n",
    "def from_numpy_to_torch(T, dtype =torch.cfloat, device = \"cpu\"):\n",
    "    return [torch.from_numpy(x).to(dtype=dtype, device=device) for x in T]\n",
    "\n",
    "def from_torch_to_numpy(T):\n",
    "    return [x.detach().cpu().numpy() for x in T]\n",
    "\n",
    "def primal_predict(T,X_init,n_samples,sampfreq,real_valued=True):\n",
    "    D,R,L = T\n",
    "    t = np.arange(n_samples)/sampfreq\n",
    "    Z = np.exp(D[:,None] * t[None,:])\n",
    "    A =Z *(L.conj().T @ X_init.reshape(-1,1))\n",
    "    pred = (R @ A).T\n",
    "    if real_valued:\n",
    "        pred = pred.real\n",
    "    return pred\n",
    "\n",
    "def linear_primal_predict(T1,T2,ratio,X_init,n_samples,sampfreq):\n",
    "    D1,R1,L1 = T1\n",
    "    D2,R2,L2 = T2\n",
    "    T1 = R1 @ (np.exp(D1[:,None]/sampfreq) * L1.conj().T)\n",
    "    T2 = R2 @ (np.exp(D2[:,None]/sampfreq) * L2.conj().T)\n",
    "    T = (1-ratio) * T1 + ratio * T2\n",
    "    x_current = X_init\n",
    "    traj = [x_current]\n",
    "    for t in range(n_samples-1):\n",
    "        x_current = T @ x_current\n",
    "        traj.append(x_current)\n",
    "    traj = np.array(traj)\n",
    "    return traj.real"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62d6a124",
   "metadata": {},
   "outputs": [],
   "source": [
    "EXP_ID = 2 \n",
    "\n",
    "B_DTYPE = torch.cfloat\n",
    "D_DTYPE = torch.float32\n",
    "\n",
    "# loading the data\n",
    "result_path = Path(f\"results/exp_{EXP_ID}\")\n",
    "data_1 = np.load(result_path / \"T1.npz\")\n",
    "D1,L1,R1,signal_1 = data_1['D'], data_1['L'], data_1['R'], data_1['signal']\n",
    "T1 = (D1,R1,L1)\n",
    "T1 = full_to_reduced_system(T1)\n",
    "data_2 = np.load(result_path / \"T2.npz\")\n",
    "D2,L2,R2,signal_2 = data_2['D'], data_2['L'], data_2['R'], data_2['signal']\n",
    "T2 = (D2,R2,L2)\n",
    "T2 = full_to_reduced_system(T2)\n",
    "traj_init = np.load(result_path / \"signal_init.npy\")\n",
    "\n",
    "\n",
    "# loading the settings\n",
    "with open(result_path / \"settings.json\", 'r') as f:\n",
    "    settings = json.load(f)\n",
    "n_points = settings[\"n_points\"]\n",
    "ratios = np.linspace(0,1,n_points+2)\n",
    "\n",
    "# loading the barycenters\n",
    "base_lst = {}\n",
    "for filename in os.listdir(result_path / \"base\"):\n",
    "    if filename.endswith(\".npz\"):\n",
    "        i = int(filename.split(\"_\")[-1].split(\".\")[0])\n",
    "    data = np.load(result_path / \"base\" / filename)\n",
    "    D, R, L = data['D'], data['R'], data['L']\n",
    "    base_lst[i] = (D,R,L)\n",
    "\n",
    "hs_lst = {}\n",
    "for filename in os.listdir(result_path / \"hs\"):\n",
    "    if filename.endswith(\".npz\"):\n",
    "        i = int(filename.split(\"_\")[-1].split(\".\")[0])\n",
    "    data = np.load(result_path / \"hs\" / filename)\n",
    "    D, R, L = data['D'], data['R'], data['L']\n",
    "    hs_lst[i] = (D,R,L)\n",
    "\n",
    "ot_lst = {}\n",
    "for filename in os.listdir(result_path / \"ot\"):\n",
    "    if filename.endswith(\".npz\"):\n",
    "        i = int(filename.split(\"_\")[-1].split(\".\")[0])\n",
    "    data = np.load(result_path / \"ot\" / filename)\n",
    "    D, R, L = data['D'], data['R'], data['L']\n",
    "    ot_lst[i] = (D,R,L)\n",
    "\n",
    "\n",
    "## Focus on frequency and decay \n",
    "freq_s = T1[0].imag/(2*np.pi)\n",
    "decay_s = T1[0].real\n",
    "\n",
    "freq_t = T2[0].imag/(2*np.pi)\n",
    "decay_t = T2[0].real"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "967dc563",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams[\"font.family\"] = \"Helvetica\"\n",
    "plt.rcParams[\"xtick.labelsize\"] = 8\n",
    "plt.rcParams[\"ytick.labelsize\"] = 8\n",
    "plt.rcParams[\"axes.labelsize\"] = 11\n",
    "plt.rcParams[\"legend.fontsize\"] = 11\n",
    "\n",
    "fig,ax = plt.subplots(1,1,figsize=(4,4))\n",
    "ax.scatter(decay_s,freq_s,label=r\"$\\mathbf{T}^{(0)}$\",color=\"C2\",marker=\"X\",s=100)\n",
    "\n",
    "for i, ratio in enumerate(ratios[1:-1]):\n",
    "\n",
    "    freq_o = ot_lst[i+1][0].imag/(2*np.pi)\n",
    "    decay_o = ot_lst[i+1][0].real\n",
    "    ax.scatter(decay_o,freq_o,marker=\"v\",color=\"C0\",alpha=ratio, label=r\"$\\mathbf{T}^{(%.1f)}_{OT}$\"%ratio)\n",
    "\n",
    "ax.scatter(decay_s,freq_s,color=\"C2\",label=\" \",alpha=0)\n",
    "ax.scatter(decay_s,freq_s,color=\"C2\",label=\" \",alpha=0)\n",
    "\n",
    "for i, ratio in enumerate(ratios[1:-1]):\n",
    "    freq_h = hs_lst[i+1][0].imag/(2*np.pi)\n",
    "    decay_h = hs_lst[i+1][0].real\n",
    "    ax.scatter(decay_h,freq_h,marker=\"^\",alpha=ratio,color=\"C1\", label=r\"$\\mathbf{T}^{(%.1f)}_{HS}$\"%ratio)\n",
    "\n",
    "ax.scatter(decay_t,freq_t,label=r\"$\\mathbf{T}^{(1)}$\",color=\"C3\",marker=\"X\",s=100)\n",
    "ax.set_xlabel(r\"Decay (Hz)\")\n",
    "ax.set_ylabel(r\"Frequency (Hz)\")\n",
    "ax.legend(ncol=2,columnspacing=-0,handletextpad=-0.4, bbox_to_anchor=(1.2, 1.02))\n",
    "plt.savefig(result_path/\"eigenvalues_interpolation.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6b6744d",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams[\"font.family\"] = \"Helvetica\"\n",
    "plt.rcParams[\"xtick.labelsize\"] = 10\n",
    "plt.rcParams[\"ytick.labelsize\"] = 10\n",
    "plt.rcParams[\"axes.labelsize\"] = 14\n",
    "plt.rcParams[\"legend.fontsize\"] = 14\n",
    "plt.rcParams['axes.titlesize'] = 14\n",
    "\n",
    "fig = plt.figure(figsize=(13, 6))\n",
    "ax_ot = fig.add_subplot(1, 3, 3, projection='3d')\n",
    "ax_hs = fig.add_subplot(1, 3, 2, projection='3d')\n",
    "ax_linear = fig.add_subplot(1, 3, 1, projection='3d')\n",
    "\n",
    "max_length = 4000\n",
    "init_time = np.arange(traj_init.shape[0])/settings[\"sampfreq\"]\n",
    "pred_time = np.arange(traj_init.shape[0],max_length + traj_init.shape[0])/settings[\"sampfreq\"]\n",
    "\n",
    "facecolors = plt.colormaps['turbo'](np.linspace(0, 1, len(ratios)))\n",
    "\n",
    "\n",
    "# OT interpolation predictions\n",
    "target_pred = primal_predict(reduced_to_full_system(T2), traj_init, settings[\"n_samples\"], settings[\"sampfreq\"])[:max_length,-1]\n",
    "ax_ot.plot(pred_time,np.full_like(pred_time,1),target_pred,color=facecolors[-1],alpha=0.7,label=r\"$\\mathbf{T}^{(1)}$ $(\\gamma= 1.0)$\")\n",
    "ax_ot.plot(init_time,np.full_like(init_time,1),traj_init[:,0], color=\"k\",alpha=0.5)\n",
    "for idx in [8,6,4,2]:\n",
    "    pred = primal_predict(reduced_to_full_system(ot_lst[idx]), traj_init, settings[\"n_samples\"], settings[\"sampfreq\"])[:max_length, -1]\n",
    "    ax_ot.plot(pred_time, np.full_like(pred_time, ratios[idx]), pred, alpha=0.7, color=facecolors[idx], label=r\"$\\gamma= $\"+f\"{ratios[idx]:.1f}\")\n",
    "    ax_ot.plot(init_time,np.full_like(init_time,ratios[idx]),traj_init[:,0], color=\"k\",alpha=0.5)\n",
    "source_pred = primal_predict(reduced_to_full_system(T1), traj_init, settings[\"n_samples\"], settings[\"sampfreq\"])[:max_length,-1]\n",
    "ax_ot.plot(pred_time,np.full_like(pred_time,0),source_pred,color=facecolors[0], alpha=0.7,label=r\"$\\mathbf{T}^{(0)}$ $(\\gamma= 0.0)$\")\n",
    "ax_ot.plot(init_time,np.full_like(init_time,0),traj_init[:,0], color=\"k\",alpha=0.5)\n",
    "ax_ot.set_title('(c) SGOT',y=0.85, pad=0)\n",
    "ax_ot.set_xlabel('Time (s)')\n",
    "ax_ot.tick_params(axis='x', which='major', pad=-3)\n",
    "ax_ot.set_ylabel(r'Ratio $\\gamma$')\n",
    "ax_ot.set_box_aspect([3, 3, 1])  # aspect ratio is 2:1:1\n",
    "ax_ot.view_init(elev=20, azim=-80, roll=0)\n",
    "\n",
    "\n",
    "# HS interpolation predictions\n",
    "target_pred = primal_predict(reduced_to_full_system(T2), traj_init, settings[\"n_samples\"], settings[\"sampfreq\"])[:max_length,-1]\n",
    "ax_hs.plot(pred_time,np.full_like(pred_time,1),target_pred,color=facecolors[-1],alpha=0.7,label=r\"$\\mathbf{T}^{(0)}$\")\n",
    "ax_hs.plot(init_time,np.full_like(init_time,1),traj_init[:,0], color=\"k\",label=\"Initial signal\",alpha=0.5)\n",
    "for idx in [8,6,4,2]:\n",
    "    pred = primal_predict(reduced_to_full_system(hs_lst[idx]), traj_init, settings[\"n_samples\"], settings[\"sampfreq\"])[:max_length, -1]\n",
    "    ax_hs.plot(pred_time, np.full_like(pred_time, ratios[idx]), pred, alpha=0.7, color=facecolors[idx])\n",
    "    ax_hs.plot(init_time,np.full_like(init_time,ratios[idx]),traj_init[:,0], color=\"k\",label=\"Initial signal\",alpha=0.5)\n",
    "source_pred = primal_predict(reduced_to_full_system(T1), traj_init, settings[\"n_samples\"], settings[\"sampfreq\"])[:max_length,-1]\n",
    "ax_hs.plot(pred_time,np.full_like(pred_time,0),source_pred,color=facecolors[0], alpha=0.7,label=r\"$\\mathbf{T}^{(0)}$\")\n",
    "ax_hs.plot(init_time,np.full_like(init_time,0),traj_init[:,0], color=\"k\",label=\"Initial signal\",alpha=0.5)\n",
    "ax_hs.set_title('(b) Constrained Hilbert-Schmidt',y=0.85, pad=0)\n",
    "ax_hs.set_xlabel('Time (s)')\n",
    "ax_hs.tick_params(axis='x', which='major', pad=-3)\n",
    "ax_hs.set_ylabel(r'Ratio $\\gamma$')\n",
    "ax_hs.set_box_aspect([3, 3, 1])  # aspect ratio is 2:1:1\n",
    "ax_hs.view_init(elev=20, azim=-80, roll=0)\n",
    "\n",
    "# Linear interpolation predictions\n",
    "target_pred = primal_predict(reduced_to_full_system(T2), traj_init, settings[\"n_samples\"], settings[\"sampfreq\"])[:max_length,-1]\n",
    "ax_linear.plot(pred_time,np.full_like(pred_time,1),target_pred,color=facecolors[-1],alpha=0.7,label=r\"$\\mathbf{T}^{(0)}$\")\n",
    "ax_linear.plot(init_time,np.full_like(init_time,1),traj_init[:,0], color=\"k\",label=\"Initial signal\",alpha=0.5)\n",
    "for idx in [2,4,6,8]:\n",
    "    pred = linear_primal_predict(reduced_to_full_system(T1),reduced_to_full_system(T2),ratios[idx],traj_init,settings[\"n_samples\"], settings[\"sampfreq\"])[:max_length, -1,0]\n",
    "    ax_linear.plot(pred_time, np.full_like(pred_time, ratios[idx]), pred, alpha=0.7, color=facecolors[idx])\n",
    "    ax_linear.plot(init_time,np.full_like(init_time,ratios[idx]),traj_init[:,0], color=\"k\",label=\"Initial signal\",alpha=0.5)\n",
    "source_pred = primal_predict(reduced_to_full_system(T1), traj_init, settings[\"n_samples\"], settings[\"sampfreq\"])[:max_length,-1].squeeze()\n",
    "ax_linear.plot(pred_time,np.full_like(pred_time,0),source_pred,color=facecolors[0], alpha=0.7,label=r\"$\\mathbf{T}^{(0)}$\")\n",
    "ax_linear.plot(init_time,np.full_like(init_time,0),traj_init[:,0], color=\"k\",label=\"Initial signal\",alpha=0.5)\n",
    "ax_linear.set_title('(a) Hilbert-Schmidt',y =0.85,pad=0)\n",
    "ax_linear.set_xlabel('Time (s)')\n",
    "ax_linear.tick_params(axis='x', which='major', pad=-3)\n",
    "ax_linear.set_ylabel(r'Ratio $\\gamma$')\n",
    "ax_linear.set_box_aspect([3, 3, 1])  # aspect ratio is 2:1:1\n",
    "ax_linear.view_init(elev=20, azim=-80, roll=0)\n",
    "\n",
    "plt.subplots_adjust(left=0, right=1, top=1, bottom=0,hspace=0.1, wspace=0.1)\n",
    "handles, labels = ax_ot.get_legend_handles_labels()\n",
    "fig.legend(handles[::-1], labels[::-1], loc='upper center', bbox_to_anchor=(0.5, 0.25), ncol=len(labels),frameon=False,columnspacing=1.0,handletextpad=0.2,)\n",
    "fig.tight_layout()\n",
    "plt.savefig(result_path/\"interpolation.pdf\", bbox_inches='tight', pad_inches=0.1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2372e5ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "ot_losses = []\n",
    "ot_time = []\n",
    "for filename in os.listdir(result_path / \"ot\"):\n",
    "    ot_losses.append(np.load(result_path / \"ot\" / filename)[\"losses\"])\n",
    "    ot_time.append(np.load(result_path / \"ot\" / filename)[\"duration\"])\n",
    "ot_time = np.array(ot_time)\n",
    "ot_lens = np.array([len(x)*10 for x in ot_losses])\n",
    "ot_gd_time  = np.mean(ot_time / ot_lens)\n",
    "\n",
    "hs_losses = []\n",
    "hs_time = []\n",
    "for filename in os.listdir(result_path / \"hs\"):\n",
    "    hs_losses.append(np.load(result_path / \"hs\" / filename)[\"losses\"])\n",
    "    hs_time.append(np.load(result_path / \"hs\" / filename)[\"duration\"])\n",
    "hs_time = np.array(hs_time)\n",
    "hs_lens = np.array([len(x) for x in hs_losses])\n",
    "hs_gd_time  = np.mean(hs_time / hs_lens)\n",
    "\n",
    "\n",
    "\n",
    "fig = plt.figure(figsize=(4, 4))\n",
    "for i, loss in enumerate(ot_losses):\n",
    "    plt.plot(np.arange(len(loss))*10,loss/np.max(loss), color='C0',alpha=0.5, label= \"SGOT\" if i==0 else \"\")\n",
    "\n",
    "for i, loss in enumerate(hs_losses):\n",
    "    plt.plot(np.arange(len(loss))*1,loss/np.max(loss), color='C1', alpha=0.5, label= \"Hilbert-Schmidt\" if i==0 else \"\")\n",
    "plt.xlabel(\"Gradient descent steps\")\n",
    "plt.ylabel(\"Normalized loss\")\n",
    "plt.legend()\n",
    "fig.tight_layout()\n",
    "plt.savefig(result_path/\"convergence.pdf\", bbox_inches='tight', pad_inches=0.1)\n",
    "\n",
    "print(f\"SGOT average time per gradient step: {ot_gd_time*1000:.4f} seconds\")\n",
    "print(f\"Hilbert-Schmidt average time per gradient step: {hs_gd_time*1000:.4f} seconds\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "KooPOT",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
