{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import *\n",
    "from models import *\n",
    "from data import *\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "\n",
    "# auto reload\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from manifold_flow.vector_transforms import create_vector_transform"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Coupling Transform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from m_flow.experiments.architectures.vector_transforms import create_vector_transform\n",
    "from manifold_flow.transforms.projections import Projection, CompositeProjection\n",
    "from manifold_flow.transforms import *\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
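    "## LASA dataset\n",
    "\n",
    "Overview of every LASA shape: one demonstration (index 0) of each, with its start marked by a green circle and its end by a red x."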
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the first demonstration of every LASA shape in a grid with 6 columns and as\n",
    "# many rows as needed. Green circle = trajectory start, red x = trajectory end.\n",
    "n_cols = 6\n",
    "n_rows = -(-len(data_set_names) // n_cols)  # ceiling division\n",
    "fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, 4 * n_rows), squeeze=False)\n",
    "for ax, name in zip(axs.flat, data_set_names):\n",
    "    pos, vel = lasa_to_torch(name)\n",
    "    # pos[k, dim, t]: dims 0/1 are x/y; k=0 is presumably the first demonstration\n",
    "    ax.plot(pos[0, 0], pos[0, 1], label=name)\n",
    "    ax.set_title(name)\n",
    "    ax.plot(pos[0, 0, 0], pos[0, 1, 0], 'o', color='green')\n",
    "    ax.plot(pos[0, 0, -1], pos[0, 1, -1], 'x', color='red')\n",
    "# Blank out grid cells left over when the shape count is not a multiple of n_cols.\n",
    "for ax in axs.flat[len(data_set_names):]:\n",
    "    ax.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, loader, lr = 1e-3, epochs = 10, device = 'cpu'):\n",
    "    \"\"\"Fit `model` to predict velocities from states with an MSE loss.\n",
    "\n",
    "    Args:\n",
    "        model: torch module mapping a batch of states `x` to predicted `x_dot`.\n",
    "        loader: sized iterable of `(x, x_dot)` batches, e.g. a DataLoader.\n",
    "        lr: Adam learning rate.\n",
    "        epochs: number of passes over `loader`.\n",
    "        device: device the model and every batch are moved to.\n",
    "\n",
    "    Returns:\n",
    "        list of float: mean per-batch loss of each epoch (the value that is printed).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `loader` has no batches.\n",
    "    \"\"\"\n",
    "    n_batches = len(loader)\n",
    "    if n_batches == 0:\n",
    "        raise ValueError('train() got an empty loader')\n",
    "\n",
    "    # Conventional order: move the model first, then build the optimizer over its parameters.\n",
    "    model = model.to(device)\n",
    "    model.train()\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
    "    loss_fn = torch.nn.MSELoss()  # stateless, so one instance serves every batch\n",
    "\n",
    "    losses = []\n",
    "    for e in range(epochs):\n",
    "        total_loss = 0.0\n",
    "        for x, x_dot in tqdm(loader, total=n_batches):\n",
    "            x = x.float().to(device)\n",
    "            x_dot = x_dot.float().to(device)\n",
    "\n",
    "            loss = loss_fn(model(x), x_dot)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            total_loss += loss.item()\n",
    "\n",
    "        epoch_loss = total_loss / n_batches\n",
    "        print(\"Epoch: \", e, \"Loss: \", epoch_loss)\n",
    "        # Store the mean (not the running sum) so the history matches the printed\n",
    "        # value and stays comparable across batch sizes.\n",
    "        losses.append(epoch_loss)\n",
    "\n",
    "    return losses"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
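    "## Load the data\n",
    "\n",
    "Stack the selected LASA shapes into one dataset and wrap it in a shuffled `DataLoader` for training."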
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the chosen LASA shapes into one dataset; input_dim later assumes 2 dims per shape.\n",
    "dataset_names = [\"Sine\", \"JShape\", \"Line\", \"Angle\"]\n",
    "# dataset_names = [\"BendedLine\"]\n",
    "pos, vel = lasa_to_torch_stacked(dataset_names)\n",
    "dataset = TrajData(pos, vel, num_demos=3)\n",
    "# Shuffled mini-batches for training; an unshuffled copy is built later for plotting.\n",
    "loader = torch.utils.data.DataLoader(dataset, batch_size=100, shuffle=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For Rosenbrock Data\n",
    "# Optional: train on Rosenbrock data instead of LASA. Off by default.\n",
    "# NOTE(review): the model cell derives input_dim from dataset_names and reads\n",
    "# dataset.pos_eq, so check both before enabling this.\n",
    "USE_ROSENBROCK = False\n",
    "ROSENBROCK_DIM = 8  # or 16\n",
    "if USE_ROSENBROCK:\n",
    "    dataset = get_rosenbrock(ROSENBROCK_DIM)\n",
    "    loader = torch.utils.data.DataLoader(dataset, batch_size=100, shuffle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Make the Model and Optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed torch's global RNG so model initialisation is repeatable.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Every stacked shape contributes an (x, y) pair; the dynamics live in latent_dim.\n",
    "input_dim = 2 * len(dataset_names)\n",
    "latent_dim = 2\n",
    "flow_steps = 2\n",
    "\n",
    "# Presumably the equilibrium/goal state; the 0:1 slice keeps the leading axis.\n",
    "x_eq = dataset.pos_eq[0:1].float()\n",
    "\n",
    "# Invertible flow over the full input space built from \"rq-coupling\" layers.\n",
    "rq_transform = create_vector_transform(\n",
    "    input_dim,\n",
    "    flow_steps,\n",
    "    linear_transform_type=\"lu\",\n",
    "    base_transform_type=\"rq-coupling\",\n",
    "    hidden_features=30,\n",
    "    num_transform_blocks=2,\n",
    "    dropout_probability=0.0,\n",
    "    use_batch_norm=False,\n",
    "    num_bins=10,\n",
    "    tail_bound=10,\n",
    "    apply_unconditional_transform=False,\n",
    "    context_features=None,\n",
    ")\n",
    "# Flow combined with a projection from input_dim to latent_dim (per the constructor args).\n",
    "transform_c = CompositeProjection(rq_transform, input_dim, latent_dim)\n",
    "transform = transform_c  # point `transform` at another transform here to experiment\n",
    "\n",
    "# ELCD gets x_eq=None; the equilibrium x_eq is passed to ELCD_Transform instead.\n",
    "gncds = ELCD(d=latent_dim, x_eq=None, hidden_dim=16)\n",
    "model_t = ELCD_Transform(x_eq, transform=transform, model=gncds)\n",
    "\n",
    "# NOTE: train() creates its own Adam optimizer, so optimizer_t is never stepped during training.\n",
    "params_t = model_t.parameters()\n",
    "optimizer_t = torch.optim.Adam(params_t, lr=1e-3)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One epoch is only a quick run; raise `epochs` for a better fit.\n",
    "losses_t = train(model_t, loader, lr=1e-3, epochs=1, device='cpu')\n",
    "losses_t"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Build an unshuffled copy of the dataset and keep its first `traj_len` samples for visualization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuilt from the same pos/vel as the training dataset, but iterated in order (no shuffling).\n",
    "dataset_det = TrajData(pos, vel, num_demos=3)\n",
    "traj_len = dataset_det.traj_len\n",
    "\n",
    "# Keep the first traj_len samples (presumably the first trajectory). Stop as soon as\n",
    "# they are collected rather than iterating the whole dataset twice and slicing.\n",
    "X_data, X_dot_data = [], []\n",
    "for sample_idx, sample in enumerate(dataset_det):\n",
    "    if sample_idx >= traj_len:\n",
    "        break\n",
    "    X_data.append(sample[0].float())\n",
    "    X_dot_data.append(sample[1].float())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: positions of stacked shape s (state columns 2*s and 2*s + 1) in X_data.\n",
    "s = 0\n",
    "plt.scatter([x[2 * s] for x in X_data], [x[1 + 2 * s] for x in X_data])\n",
    "plt.title(f'Demonstration positions, shape index {s}')\n",
    "plt.xlabel(f'x[{2 * s}]')\n",
    "plt.ylabel(f'x[{2 * s + 1}]')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compute the learned velocity (dx/dt) at each point of the demonstration trajectory; it is plotted together with a generated rollout below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Converting model_t to cpu\")\n",
    "model_t_cpu = model_t.to('cpu')\n",
    "\n",
    "\n",
    "def model_t_f(x, t):\n",
    "    \"\"\"Evaluate the learned velocity at a single state `x`.\n",
    "\n",
    "    Has an ODE-style `(x, t)` signature (t is unused) so it can be passed to\n",
    "    helpers such as plot_2d_vector_field in the plotting cell.\n",
    "    \"\"\"\n",
    "    res = model_t_cpu(torch.Tensor(x).unsqueeze(0)).squeeze().detach().numpy()\n",
    "    # NOTE(review): clears the parameters' .grad buffers after each call; likely\n",
    "    # unnecessary for plotting - confirm before removing.\n",
    "    optimizer_t.zero_grad()\n",
    "    return res\n",
    "\n",
    "\n",
    "# Plot-window settings used by the optional vector-field call in the plotting cell.\n",
    "model_d = len(dataset_names) * 2\n",
    "n = 2\n",
    "x_c, y_c = 0, 0\n",
    "\n",
    "print(\"Calculating Gradients\")\n",
    "# Learned velocity at every sample of X_data; detach() drops the autograd graph.\n",
    "x_pred = [model_t(x.unsqueeze(0)).squeeze().detach().numpy() for x in X_data]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Generate a trajectory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Roll out model_t.forward_discrete from the start state dataset.pos[0, :, 0]\n",
    "# (assumes pos is indexed (demo, dim, time) - TODO confirm) until\n",
    "# forward_step_until_converge stops; .02 is presumably the step size (see utils).\n",
    "Y = forward_step_until_converge(model_t.forward_discrete, dataset.pos[0,:,0].unsqueeze(0), .02).detach().numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f, ax = plt.subplots(1, 1, figsize=(10, 10))\n",
    "# Optional background vector field (needs model_t_f, x_c, y_c, n, model_d from above):\n",
    "# ax = plot_2d_vector_field(model_t_f, x_c -n, x_c + n, y_c - n, y_c + n, model_d = model_d)\n",
    "\n",
    "every_other = 10  # draw one arrow per `every_other` samples\n",
    "s = 1  # which stacked shape to draw\n",
    "x_col, y_col = 2 * s, 2 * s + 1  # that shape's coordinates in the state vector\n",
    "\n",
    "# Learned velocity arrows along the demonstration, drawn once.\n",
    "pos_sub = X_data[::every_other]\n",
    "vel_sub = x_pred[::every_other]\n",
    "ax.quiver([p[x_col] for p in pos_sub],\n",
    "          [p[y_col] for p in pos_sub],\n",
    "          [v[x_col] for v in vel_sub],\n",
    "          [v[y_col] for v in vel_sub],\n",
    "          color='blue')\n",
    "\n",
    "# Generated rollout Y, coloured by step index.\n",
    "rollout = ax.scatter(Y[:, 0, x_col], Y[:, 0, y_col], c=np.arange(len(Y)), cmap='viridis')\n",
    "f.colorbar(rollout, ax=ax, label='rollout step')\n",
    "ax.set_title(f'Learned velocities and generated rollout (shape index {s})')\n",
    "ax.set_xlabel(f'x[{x_col}]')\n",
    "ax.set_ylabel(f'x[{y_col}]')\n",
    "\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The fit above is rough because the model was trained for only one epoch; increase `epochs` in the training cell for a better result."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "myenv",
   "language": "python",
   "name": "myenv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
