{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7636673f",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import sys, os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c28d3ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob, os\n",
    "import mediapy as media\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from load_model_from_ckpt import load_model, get_sampler, init_samples\n",
    "from datasets import get_dataset, data_transform, inverse_data_transform\n",
    "from runners.ncsn_runner import conditioning_fn\n",
    "\n",
    "from os.path import expanduser\n",
    "home = expanduser(\"~\")\n",
    "\n",
    "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d956de9-6fba-4824-bfdb-3d36b4178836",
   "metadata": {},
   "source": [
    "# Set directories to download model, data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08f0476c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# SET THESE!!!\n",
    "GDRIVE_URL = \"https://drive.google.com/drive/folders/1bM6wqU_kymoljz5uYQRCYNup_8adBfLH\" # smmnist_big_5c5_unetm_b2\n",
    "EXP_PATH = os.path.join(home, \"/path/to/dir/mcvd-pytorch\")\n",
    "DATA_PATH = os.path.join(home, \"/path/to/dir/mcvd-pytorch\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eefab24a",
   "metadata": {},
   "source": [
    "# Download experiment (model checkpoint, config, etc.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eef31cee",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# GDRIVE_URL = GDRIVE_URL.removesuffix(\"?usp=sharing\")\n",
    "# !gdown --fuzzy {GDRIVE_URL} -O {EXP_PATH}/ --folder"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7497708",
   "metadata": {},
   "source": [
    "# Load model checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bed12d90",
   "metadata": {},
   "outputs": [],
   "source": [
    "ckpt_path = glob.glob(os.path.join(EXP_PATH, \"bouncing_balls_div/logs/checkpoint_*.pt\"))[0]\n",
    "scorenet, config = load_model(ckpt_path, device)\n",
    "print(config)\n",
    "sampler = get_sampler(config)\n",
    "\n",
    "config.data.dataset = \"BOUNCING_BALLS\"\n",
    "dataset, test_dataset = get_dataset(DATA_PATH, config, video_frames_pred=config.data.num_frames)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc8c5945",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(device)\n",
    "print(ckpt_path)\n",
    "print(config)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "717769fc",
   "metadata": {},
   "source": [
    "# Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed043f4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset, test_dataset = get_dataset(DATA_PATH, config, video_frames_pred=config.data.num_frames)\n",
    "print(dataset.__len__(), test_dataset.__len__())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "277c205a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataloader = DataLoader(dataset, batch_size=config.training.batch_size, shuffle=True,\n",
    "#                         num_workers=config.data.num_workers)\n",
    "# train_iter = iter(dataloader)\n",
    "# x, y = next(train_iter)\n",
    "\n",
    "test_loader = DataLoader(test_dataset, batch_size=5, shuffle=False,\n",
    "                         num_workers=config.data.num_workers, drop_last=True)\n",
    "\n",
    "test_iter = iter(test_loader)\n",
    "print(test_dataset.__getitem__(50)[0].shape)\n",
    "test_x, test_y = next(test_iter)\n",
    "print(test_x.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cebee4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_x = data_transform(config, test_x)\n",
    "real, cond, cond_mask = conditioning_fn(config, test_x, num_frames_pred=config.data.num_frames,\n",
    "                                        prob_mask_cond=getattr(config.data, 'prob_mask_cond', 0.0),\n",
    "                                        prob_mask_future=getattr(config.data, 'prob_mask_future', 0.0))\n",
    "\n",
    "print(real.shape, cond.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4b6e45b",
   "metadata": {},
   "outputs": [],
   "source": [
    "i = 1\n",
    "# media.show_images(torch.cat([cond[i].permute(0, 2, 3, 1), real[i].permute(0, 2, 3, 1)]))\n",
    "media.show_images(torch.cat([cond[i], real[i]]))\n",
    "media.show_images(cond[i])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d37ea28",
   "metadata": {},
   "source": [
    "# Load initial samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55bde034",
   "metadata": {},
   "outputs": [],
   "source": [
    "init = init_samples(len(real), config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9253a5ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(init.shape)\n",
    "media.show_images(init[i])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47b62c71",
   "metadata": {},
   "source": [
    "# Predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d2a372b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "! module load ninja\n",
    "import time\n",
    "startTime = time.time()\n",
    "pred = sampler(init, scorenet, cond=cond, cond_mask=cond_mask, subsample=100, verbose=True)\n",
    "\n",
    "from models.projection import Projection\n",
    "# 3.7 mars, 24.8 Jupiter\n",
    "project_height = Projection(cond=pred[0][0].unsqueeze(0), acceleration=3.7)\n",
    "\n",
    "\n",
    "\n",
    "# print(pred[0].shape, x.shape)\n",
    "\n",
    "\n",
    "# media.show_images(x)\n",
    "\n",
    "batch_list = []\n",
    "for j in range(5):\n",
    "    # for i in range(5):\n",
    "        x = project_height.apply(pred[j][0:]).squeeze(1)\n",
    "        # print(x.shape)\n",
    "        # media.show_images(x)        \n",
    "        batch_list.append(x)\n",
    "        print(j)\n",
    "        # save_images(x[j].unsqueeze(0), f'./cop-m-{j+1}.png')\n",
    "\n",
    "cop = torch.stack(batch_list)\n",
    "executionTime = (time.time() - startTime)\n",
    "print('Execution time in seconds: ' + str(executionTime))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "881fc304",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(pred.shape)\n",
    "i = 1\n",
    "media.show_images(torch.cat([cond[i], real[i]]))\n",
    "# for i in range(5):\n",
    "media.show_images(torch.cat([cond[i], pred[i]]))\n",
    "print(pred.shape, pred.max(), pred.min())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7937cd9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "i = 4\n",
    "media.show_images(torch.cat([cond[i], real[i]]))\n",
    "media.show_images(real[2])\n",
    "media.show_images(torch.cat([cond[i], pred[i]]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0f85ca1-3100-4d14-beeb-0b103ce4fcb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "\n",
    "def save_images(images, path, **kwargs):\n",
    "    grid = torchvision.utils.make_grid(images, **kwargs)\n",
    "    ndarr = grid.permute(1, 2, 0).to('cpu').numpy()\n",
    "    ndarr = (ndarr * 255).astype(np.uint8)\n",
    "\n",
    "    print(ndarr.shape)\n",
    "    im = Image.fromarray(ndarr)\n",
    "    im.save(path)\n",
    "\n",
    "print(pred[0][i].unsqueeze(0).shape)\n",
    "for i in range(5):\n",
    "    save_images(pred[0][i].unsqueeze(0), f'./co-earth-{i+1}.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f3ba975-e72a-4823-b5f0-f96b805c88d4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from models.projection import Projection\n",
    "# 3.7 mars, 24.8 Jupiter\n",
    "project_height = Projection(cond=pred[0][0].unsqueeze(0), acceleration=3.7)\n",
    "\n",
    "\n",
    "\n",
    "# print(pred[0].shape, x.shape)\n",
    "\n",
    "\n",
    "# media.show_images(x)\n",
    "\n",
    "batch_list = []\n",
    "for j in range(5):\n",
    "    # for i in range(5):\n",
    "        x = project_height.apply(pred[j][0:]).squeeze(1)\n",
    "        # print(x.shape)\n",
    "        # media.show_images(x)        \n",
    "        batch_list.append(x)\n",
    "        print(j)\n",
    "        save_images(x[j].unsqueeze(0), f'./cop-m-{j+1}.png')\n",
    "\n",
    "cop = torch.stack(batch_list)\n",
    "        \n",
    "# x = project_height.apply(torch.cat([cond[i], pred[i]])).squeeze(1)\n",
    "# print(pred[0].shape, x.shape)\n",
    "\n",
    "media.show_images(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d97883f-248c-4b3d-a76e-35511d0ab522",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from torchvision.transforms import ToTensor\n",
    "import random\n",
    "from PIL import Image, ImageDraw\n",
    "\n",
    "\n",
    "\n",
    "def create_background_with_gray_lines(frame_size=(64, 64), num_lines=50):\n",
    "    \"\"\"Create a background with random gray lines.\"\"\"\n",
    "    background = Image.new(\"L\", frame_size, \"white\")\n",
    "    draw = ImageDraw.Draw(background)\n",
    "\n",
    "    for idx in range(int(num_lines/2)):\n",
    "\n",
    "        const_point = int((frame_size[0] / (num_lines/2)) * idx) +1\n",
    "\n",
    "        gray_shade = 127\n",
    "        draw.line([(const_point, 0), (const_point, 63)], fill=gray_shade)\n",
    "        draw.line([(0, const_point), (63, const_point)], fill=gray_shade)\n",
    "\n",
    "    return background\n",
    "\n",
    "def create_bw_frame(background, ball_y, ball_x, ball_radius=5):\n",
    "    \"\"\"Create a single black and white frame with the ball at the specified y position.\"\"\"\n",
    "    frame = background.copy()  # Copy the background with gray lines\n",
    "    draw = ImageDraw.Draw(frame)\n",
    "    x_position = ball_x\n",
    "    top_left = (x_position - ball_radius, ball_y - ball_radius)\n",
    "    bottom_right = (x_position + ball_radius, ball_y + ball_radius)\n",
    "    draw.ellipse([top_left, bottom_right], fill=\"black\")\n",
    "    return frame\n",
    "\n",
    "def generate_bw_animation_frames(num_frames=10, frame_size=(64, 64)):\n",
    "    \"\"\"Generate a series of black and white frames for the animation.\"\"\"\n",
    "    frames = []\n",
    "    # ball_x = random.randrange(8, 56, 1)\n",
    "    ball_x = 32\n",
    "    # max_height = random.randrange(48, 56, 1)\n",
    "    max_height = 56\n",
    "    background = create_background_with_gray_lines(frame_size)\n",
    "\n",
    "    for i in range(num_frames):\n",
    "        t = i / (num_frames - 1)\n",
    "        ball_y = (64 - max_height) + int(position_change(9.8, i) / 2)\n",
    "        frame = create_bw_frame(background, ball_y, ball_x, ball_radius=5)\n",
    "        frames.append(frame)\n",
    "    return frames\n",
    "\n",
    "def position_change(acceleration, time):\n",
    "        # change_in_position = initial_velocity * time + 0.5 * acceleration * time^2\n",
    "        change_in_position = 0.5 * acceleration * time ** 2\n",
    "\n",
    "        return change_in_position\n",
    "\n",
    "# Generate black and white frames\n",
    "cond = generate_bw_animation_frames()\n",
    "cond = torch.stack([ToTensor()(frame) for frame in cond], dim=0).squeeze(1)[:5]\n",
    "\n",
    "print(cond.shape)\n",
    "print(pred.shape)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (Score-SDE)",
   "language": "python",
   "name": "score-sde"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
