{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.integrate import solve_ivp, odeint\n",
    "import deepdish as dd\n",
    "import pandas as pd\n",
    "import cv2\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.animation as animation\n",
    "import h5py\n",
    "\n",
    "def pendulum(t, thetas, g, l):\n",
    "    theta, dot_theta = thetas # y0, y1\n",
    "    dots = (dot_theta, -(g/l)*np.sin(theta))\n",
    "    return  dots # y0_dot, y1_dot\n",
    "\n",
    "def get_trajectory(timesteps, dt, angle, omega0, length, g):\n",
    "    theta0 = np.radians(angle) # initial anglee\n",
    "\n",
    "    tmin = 0.0\n",
    "    tmax = timesteps*dt\n",
    "    ts = np.linspace(tmin, tmax, timesteps)\n",
    "    sol = solve_ivp(pendulum, [tmin, tmax], [theta0, omega0], t_eval = ts, args=(g,length))\n",
    "\n",
    "    # save the x, y coordinated of the pendulum\n",
    "    xy = np.zeros_like(sol.y)\n",
    "    xy[0] = length*np.sin(sol.y[0])\n",
    "    xy[1] = length*np.cos(sol.y[0])\n",
    "\n",
    "    cartesian = xy.T\n",
    "    phase_space = sol.y.T\n",
    "    \n",
    "    labels =  {'initial_angle': angle, \n",
    "               'initial_velocity': omega0, \n",
    "               'gravity': g, \n",
    "               'length': length}\n",
    "    \n",
    "    return cartesian, phase_space, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.integrate import solve_ivp, odeint\n",
    "import deepdish as dd\n",
    "import pandas as pd\n",
    "import cv2\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.animation as animation\n",
    "import h5py\n",
    "\n",
    "def pendulum(t, thetas, g, l):\n",
    "    theta, dot_theta = thetas # y0, y1\n",
    "    dots = (dot_theta, -(g/l)*np.sin(theta))\n",
    "    return  dots # y0_dot, y1_dot\n",
    "\n",
    "def get_trajectory(timesteps, dt, angle, omega0, length, g):\n",
    "    theta0 = np.radians(angle) # initial anglee\n",
    "\n",
    "    tmin = 0.0\n",
    "    tmax = timesteps*dt\n",
    "    ts = np.linspace(tmin, tmax, timesteps)\n",
    "    sol = solve_ivp(pendulum, [tmin, tmax], [theta0, omega0], t_eval = ts, args=(g,length))\n",
    "\n",
    "    # save the x, y coordinated of the pendulum\n",
    "    xy = np.zeros_like(sol.y)\n",
    "    xy[0] = length*np.sin(sol.y[0])\n",
    "    xy[1] = length*np.cos(sol.y[0])\n",
    "\n",
    "    cartesian = xy.T\n",
    "    phase_space = sol.y.T\n",
    "    \n",
    "    labels =  {'initial_angle': angle, \n",
    "               'initial_velocity': omega0, \n",
    "               'gravity': g, \n",
    "               'length': length}\n",
    "    \n",
    "    return cartesian, phase_space, labels\n",
    "\n",
    "def coords_to_pixels(world_size, x, y, res):\n",
    "        \"\"\"Maps coordinates from world space to pixel space\n",
    "        Args: x,y (float): x,y coordinate of the world space.\n",
    "            res (int): Image resolution in pixels (images are square).\n",
    "        Returns: (int, int): Tuple of coordinates in pixel space.\n",
    "        \"\"\"\n",
    "        pix_x = np.rint(res*(x + world_size)/(2*world_size))\n",
    "        pix_y = np.rint(res*(y + world_size)/(2*world_size))\n",
    "        return (int(pix_x), int(pix_y))\n",
    "\n",
    "\n",
    "def draw(q, resolution, world_size, ball_size, ball_color, filter_size, resize=None):\n",
    "    \"\"\"Returns array of the environment evolution\n",
    "       vid (np.ndarray): Rendered rollout as a sequence of images\n",
    "    \"\"\"    \n",
    "    background_color = 0.0\n",
    "    \n",
    "    rollout_length = len(q)\n",
    "    \n",
    "    if resize is None:\n",
    "        resize = resolution\n",
    "\n",
    "    vid = np.zeros((rollout_length, resize, resize), dtype=np.float32)\n",
    "    for t in range(rollout_length):\n",
    "        temp = np.zeros((resolution, resolution))\n",
    "        pixels = coords_to_pixels(world_size, q[t,0], q[t,1], resolution)\n",
    "        temp = cv2.circle(temp, pixels, ball_size, ball_color, thickness=-1)\n",
    "        temp = cv2.blur(temp, (filter_size, filter_size))\n",
    "        temp = cv2.blur(temp, (filter_size, filter_size))\n",
    "        if resize is not None:\n",
    "            temp = cv2.resize(temp, (resize, resize))\n",
    "        vid[t] = temp\n",
    "\n",
    "    vid += background_color\n",
    "    vid[vid > 1.] = 1.\n",
    "    vid[vid < 0.] = 0.\n",
    "    \n",
    "    return vid\n",
    "\n",
    "def create_dataset(fname, dt, timesteps, final_resolution, angle_range, velocity_range, g_range, length_range, g_exclude = None, length_exclude = None):\n",
    "    world_size = 2.5\n",
    "    resolution=512\n",
    "    ball_size=int(resolution/8)\n",
    "    filter_size=int(ball_size/8)\n",
    "    ball_color=1\n",
    "    params = {'dt': dt,\n",
    "        'timesteps': timesteps,\n",
    "        'world_size': world_size,\n",
    "        'initial_resolution': resolution,\n",
    "        'final_resolution': final_resolution, \n",
    "        'filter_size': filter_size, \n",
    "        'ball_size': ball_size,\n",
    "        'color': ball_color,\n",
    "        'angle_range': angle_range,\n",
    "        'velocity_range': velocity_range,\n",
    "        'g_range': g_range,\n",
    "        'length_range': length_range\n",
    "    }\n",
    "    num_datapoints = angle_range[2] * velocity_range[2] * g_range[2] * length_range[2]\n",
    "\n",
    "    fg = h5py.File(fname, 'w')\n",
    "\n",
    "    for k in fg.keys():\n",
    "        del fg[k]\n",
    "    for k in fg.attrs.keys():\n",
    "        del fg.attrs[k]\n",
    "\n",
    "    # Set global attributes\n",
    "    for k,v in params.items():\n",
    "        fg.attrs[k] = v\n",
    "\n",
    "    fg.create_dataset('frames',(num_datapoints, timesteps, final_resolution, final_resolution), \n",
    "                     chunks=(1, timesteps, final_resolution, final_resolution), compression='gzip')\n",
    "    fg.create_dataset('phase_space',(num_datapoints, timesteps, 2), \n",
    "                     chunks=(1, timesteps, 2), compression='gzip')\n",
    "\n",
    "    fg.create_group('labels')\n",
    "    fg.create_dataset('labels/pendulum_length', (num_datapoints,1), chunks=True)\n",
    "    fg.create_dataset('labels/g', (num_datapoints,1), chunks=True)\n",
    "    fg.create_dataset('labels/initial_angle', (num_datapoints,1), chunks=True)\n",
    "    fg.create_dataset('labels/initial_velocity', (num_datapoints,1), chunks=True)\n",
    "    fg.keys()\n",
    "\n",
    "    i=0\n",
    "    print(f'Creating {num_datapoints} datapoints')\n",
    "    \n",
    "    angle_list = np.random.uniform(*angle_range[:-1],num_datapoints)\n",
    "    velocity_list = np.random.uniform(*velocity_range[:-1],num_datapoints)\n",
    "    g_list = np.random.uniform(*g_range[:-1],num_datapoints)\n",
    "    length_list = np.random.uniform(*length_range[:-1],num_datapoints)\n",
    "\n",
    "    if g_exclude != None or length_exclude != None:\n",
    "        for idx in range(num_datapoints):\n",
    "            while g_list[idx] >= g_exclude[0] and g_list[idx] <= g_exclude[1] and length_list[idx] >= length_exclude[0] and length_list[idx] <= length_exclude[1]:\n",
    "                g_list[idx] = np.random.uniform(*g_range[:-1],1)[0]\n",
    "                length_list[idx] = np.random.uniform(*length_range[:-1],1)[0]\n",
    "\n",
    "    for angle,omega0,g,pendulum_length in zip(angle_list,velocity_list,g_list,length_list):        \n",
    "        fg['labels']['pendulum_length'][i]  = pendulum_length\n",
    "        fg['labels']['g'][i]                = g\n",
    "        fg['labels']['initial_angle'][i]    = angle\n",
    "        fg['labels']['initial_velocity'][i] = omega0\n",
    "\n",
    "        i+=1\n",
    "        if (i % 100) == 0:\n",
    "            print(f'{i:04d}/{num_datapoints} finished')\n",
    "    \n",
    "    fg.close()\n",
    "\n",
    "final_resolution = 64\n",
    "dt = 0.01\n",
    "timesteps = 1000\n",
    "\n",
    "n_div = 6\n",
    "angle_range = (30, 170, n_div)\n",
    "velocity_range = (-2, 2, n_div)\n",
    "g_range = (8, 12, n_div)\n",
    "length_range = (1.2, 1.4, n_div)\n",
    "\n",
    "g_exclude = (8, 12)\n",
    "length_exclude = (1.2, 1.4)\n",
    "\n",
    "num_datapoints = angle_range[2] * velocity_range[2] * g_range[2] * length_range[2]\n",
    "\n",
    "ang_str = '-'.join([str(a) for a in angle_range[:2]])\n",
    "vel_str = '-'.join([f'{a:.2f}' for a in velocity_range[:2]])\n",
    "g_str = '-'.join([f'{a:.2f}' for a in g_range[:2]])\n",
    "len_str = '-'.join([f'{a:.2f}' for a in length_range[:2]])\n",
    "\n",
    "fname = f'../data/pixel_pendulum_n_{num_datapoints}_steps_{timesteps}_dt_{dt:.2f}'+\\\n",
    "        f'_angle_{ang_str}_vel_{vel_str}_len_{len_str}_g_{g_str}_val.hd5'\n",
    "\n",
    "create_dataset(fname, dt, timesteps, final_resolution, angle_range, velocity_range, g_range, length_range)\n",
    "print(fname)\n",
    "\n",
    "n_div = 6\n",
    "angle_range = (30, 170, n_div)\n",
    "velocity_range = (-2, 2, n_div)\n",
    "g_range = (8, 12, n_div)\n",
    "length_range = (1.2, 1.4, n_div)\n",
    "\n",
    "num_datapoints = angle_range[2] * velocity_range[2] * g_range[2] * length_range[2]\n",
    "\n",
    "ang_str = '-'.join([str(a) for a in angle_range[:2]])\n",
    "vel_str = '-'.join([f'{a:.2f}' for a in velocity_range[:2]])\n",
    "g_str = '-'.join([f'{a:.2f}' for a in g_range[:2]])\n",
    "len_str = '-'.join([f'{a:.2f}' for a in length_range[:2]])\n",
    "\n",
    "fname = f'../data/pixel_pendulum_n_{num_datapoints}_steps_{timesteps}_dt_{dt:.2f}'+\\\n",
    "        f'_angle_{ang_str}_vel_{vel_str}_len_{len_str}_g_{g_str}_test.hd5'\n",
    "\n",
    "create_dataset(fname, dt, timesteps, final_resolution, angle_range, velocity_range, g_range, length_range)\n",
    "print(fname)\n",
    "\n",
    "\n",
    "n_div = 6\n",
    "angle_range = (30, 170, n_div)\n",
    "velocity_range = (-2, 2, n_div)\n",
    "g_range = (12, 12.5, n_div)\n",
    "length_range = (1.4, 1.45, n_div)\n",
    "\n",
    "num_datapoints = angle_range[2] * velocity_range[2] * g_range[2] * length_range[2]\n",
    "\n",
    "ang_str = '-'.join([str(a) for a in angle_range[:2]])\n",
    "vel_str = '-'.join([f'{a:.2f}' for a in velocity_range[:2]])\n",
    "g_str = '-'.join([f'{a:.2f}' for a in g_range[:2]])\n",
    "len_str = '-'.join([f'{a:.2f}' for a in length_range[:2]])\n",
    "\n",
    "fname = f'../data/pixel_pendulum_n_{num_datapoints}_steps_{timesteps}_dt_{dt:.2f}'+\\\n",
    "        f'_angle_{ang_str}_vel_{vel_str}_len_{len_str}_g_{g_str}.hd5'\n",
    "\n",
    "\n",
    "g_range = (8, 12.5, n_div)\n",
    "length_range = (1.2, 1.45, n_div)\n",
    "create_dataset(fname, dt, timesteps, final_resolution, angle_range, velocity_range, g_range, length_range,g_exclude=g_exclude,length_exclude=length_exclude)\n",
    "print(fname)\n",
    "\n",
    "\n",
    "\n",
    "g_exclude = (8, 12.5)\n",
    "length_exclude = (1.2, 1.45)\n",
    "n_div = 6\n",
    "angle_range = (30, 170, n_div)\n",
    "velocity_range = (-2, 2, n_div)\n",
    "g_range = (12.5, 13.0, n_div)\n",
    "length_range = (1.45, 1.50, n_div)\n",
    "\n",
    "num_datapoints = angle_range[2] * velocity_range[2] * g_range[2] * length_range[2]\n",
    "\n",
    "ang_str = '-'.join([str(a) for a in angle_range[:2]])\n",
    "vel_str = '-'.join([f'{a:.2f}' for a in velocity_range[:2]])\n",
    "g_str = '-'.join([f'{a:.2f}' for a in g_range[:2]])\n",
    "len_str = '-'.join([f'{a:.2f}' for a in length_range[:2]])\n",
    "\n",
    "fname = f'../data/pixel_pendulum_n_{num_datapoints}_steps_{timesteps}_dt_{dt:.2f}'+\\\n",
    "        f'_angle_{ang_str}_vel_{vel_str}_len_{len_str}_g_{g_str}.hd5'\n",
    "\n",
    "g_range = (8, 13.0, n_div)\n",
    "length_range = (1.2, 1.50, n_div)\n",
    "create_dataset(fname, dt, timesteps, final_resolution, angle_range, velocity_range, g_range, length_range,g_exclude=g_exclude,length_exclude=length_exclude)\n",
    "print(fname)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Creating 10000 datapoints\n",
      "0100/10000 finished\n",
      "0200/10000 finished\n",
      "0300/10000 finished\n",
      "0400/10000 finished\n",
      "0500/10000 finished\n",
      "0600/10000 finished\n",
      "0700/10000 finished\n",
      "0800/10000 finished\n",
      "0900/10000 finished\n",
      "1000/10000 finished\n",
      "1100/10000 finished\n",
      "1200/10000 finished\n",
      "1300/10000 finished\n",
      "1400/10000 finished\n",
      "1500/10000 finished\n",
      "1600/10000 finished\n",
      "1700/10000 finished\n",
      "1800/10000 finished\n",
      "1900/10000 finished\n",
      "2000/10000 finished\n",
      "2100/10000 finished\n",
      "2200/10000 finished\n",
      "2300/10000 finished\n",
      "2400/10000 finished\n",
      "2500/10000 finished\n",
      "2600/10000 finished\n",
      "2700/10000 finished\n",
      "2800/10000 finished\n"
     ]
    }
   ],
   "source": [
    "n_div = 10\n",
    "angle_range = (30, 170, n_div)\n",
    "velocity_range = (-2, 2, n_div)\n",
    "g_range = (8, 12, n_div)\n",
    "length_range = (1.2, 1.4, n_div)\n",
    "\n",
    "num_datapoints = angle_range[2] * velocity_range[2] * g_range[2] * length_range[2]\n",
    "\n",
    "ang_str = '-'.join([str(a) for a in angle_range[:2]])\n",
    "vel_str = '-'.join([f'{a:.2f}' for a in velocity_range[:2]])\n",
    "g_str = '-'.join([f'{a:.2f}' for a in g_range[:2]])\n",
    "len_str = '-'.join([f'{a:.2f}' for a in length_range[:2]])\n",
    "\n",
    "fname = f'../data/pixel_pendulum_n_{num_datapoints}_steps_{timesteps}_dt_{dt:.2f}'+\\\n",
    "        f'_angle_{ang_str}_vel_{vel_str}_len_{len_str}_g_{g_str}.hd5'\n",
    "\n",
    "create_dataset(fname, dt, timesteps, final_resolution, angle_range, velocity_range, g_range, length_range)\n",
    "print(fname)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Validation set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_div = 6\n",
    "angle_range = (30, 170, n_div)\n",
    "velocity_range = (-2, 2, n_div)\n",
    "g_range = (12, 12.5, n_div)\n",
    "length_range = (1.4, 1.45, n_div)\n",
    "\n",
    "num_datapoints = angle_range[2] * velocity_range[2] * g_range[2] * length_range[2]\n",
    "\n",
    "ang_str = '-'.join([str(a) for a in angle_range[:2]])\n",
    "vel_str = '-'.join([f'{a:.2f}' for a in velocity_range[:2]])\n",
    "g_str = '-'.join([f'{a:.2f}' for a in g_range[:2]])\n",
    "len_str = '-'.join([f'{a:.2f}' for a in length_range[:2]])\n",
    "\n",
    "fname = f'../data/pixel_pendulum_n_{num_datapoints}_steps_{timesteps}_dt_{dt:.2f}'+\\\n",
    "        f'_angle_{ang_str}_vel_{vel_str}_len_{len_str}_g_{g_str}.hd5'\n",
    "\n",
    "create_dataset(fname, dt, timesteps, final_resolution, angle_range, velocity_range, g_range, length_range)\n",
    "print(fname)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_div = 6\n",
    "angle_range = (30, 170, n_div)\n",
    "velocity_range = (-2, 2, n_div)\n",
    "g_range = (12.5, 13.0, n_div)\n",
    "length_range = (1.45, 1.50, n_div)\n",
    "\n",
    "num_datapoints = angle_range[2] * velocity_range[2] * g_range[2] * length_range[2]\n",
    "\n",
    "ang_str = '-'.join([str(a) for a in angle_range[:2]])\n",
    "vel_str = '-'.join([f'{a:.2f}' for a in velocity_range[:2]])\n",
    "g_str = '-'.join([f'{a:.2f}' for a in g_range[:2]])\n",
    "len_str = '-'.join([f'{a:.2f}' for a in length_range[:2]])\n",
    "\n",
    "fname = f'../data/pixel_pendulum_n_{num_datapoints}_steps_{timesteps}_dt_{dt:.2f}'+\\\n",
    "        f'_angle_{ang_str}_vel_{vel_str}_len_{len_str}_g_{g_str}.hd5'\n",
    "\n",
    "create_dataset(fname, dt, timesteps, final_resolution, angle_range, velocity_range, g_range, length_range)\n",
    "print(fname)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fg = h5py.File(fname, 'r')\n",
    "vani=fg['frames'][128]\n",
    "fig = plt.figure()\n",
    "ims = []\n",
    "for i in range(len(vani)):\n",
    "    ims.append([plt.imshow(vani[i], cmap='gray', vmin=0, vmax=1, aspect='equal', animated=True)])\n",
    "ani = animation.ArtistAnimation(fig, ims, interval=10, blit=True, repeat_delay=1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(vani[10], cmap='gray', vmin=0, vmax=1, aspect='equal')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
