{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(50, 3, 32, 32)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = pickle.load(open('data/val.pkl', 'rb'))\n",
    "dx = dy = data['dx']\n",
    "dt = data['dt']\n",
    "coriolis_param = data['coriolis_param']\n",
    "data = data['data'][0]\n",
    "depth = data['depth']\n",
    "gravity = data['gravity']\n",
    "data = data['data']\n",
    "data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((50, 32, 32), (50, 32, 32), (50, 32, 32))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "h, u, v = data.transpose(1, 0, 2, 3)\n",
    "h = h + depth\n",
    "u, v = 0.1 * u, 0.1 * v\n",
    "h.shape, u.shape, v.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(96.69035386477402, 0.0920699070003061, 0.07996052298958502)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.abs(h).max(), np.abs(u).max(), np.abs(v).max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((50, 30, 30), (50, 30, 30))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "v_avg = 0.25 * (v[:, 1:-1, 1:-1] + v[:, :-2, 1:-1] + v[:, 1:-1, 2:] + v[:, :-2, 2:])\n",
    "u_avg = 0.25 * (u[:, 1:-1, 1:-1] + u[:, 1:-1, :-2] + u[:, 2:, 1:-1] + u[:, 2:, :-2])\n",
    "v_avg.shape, u_avg.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((49, 30, 30), (49, 30, 30), (49, 30, 30))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dudt = np.diff(u, axis=0)[:, 1:-1, 1:-1] / dt\n",
    "dvdt = np.diff(v, axis=0)[:, 1:-1, 1:-1] / dt\n",
    "dhdt = np.diff(h, axis=0)[:, 1:-1, 1:-1] / dt\n",
    "dudt.shape, dvdt.shape, dhdt.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((50, 30, 30), (50, 30, 30), (50, 30, 30), (50, 30, 30))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dhdx = (h[:, 1:-1, 2:] - h[:, 1:-1, 1:-1]) / dx\n",
    "dhdy = (h[:, 2:, 1:-1] - h[:, 1:-1, 1:-1]) / dy\n",
    "dudx = (u[:, 1:-1, 1:-1] - u[:, 1:-1, :-2]) / dx\n",
    "dvdy = (v[:, 1:-1, 1:-1] - v[:, :-2, 1:-1]) / dy\n",
    "dhdx.shape, dhdy.shape, dudx.shape, dvdy.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(6.522153693858113e-20, (49, 30, 30))"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loss1 = dudt - (coriolis_param * v_avg - gravity * dhdx)[:-1]\n",
    "np.abs(loss1).max(), loss1.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(6.437450399132683e-20, (49, 30, 30))"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loss2 = dvdt + coriolis_param * u_avg[1:] + gravity * dhdy[:-1]\n",
    "np.abs(loss2).max(), loss2.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2.2073678605447067e-17, (49, 30, 30))"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loss3 = dhdt + depth * (dudx + dvdy)[1:]\n",
    "np.abs(loss3).max(), loss3.shape"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
