{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ecb8ed7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.nn.parameter import Parameter\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import operator\n",
    "from functools import reduce\n",
    "from functools import partial\n",
    "from timeit import default_timer\n",
    "from utilities3 import *\n",
    "\n",
    "torch.manual_seed(11) # 0, 11, 17\n",
    "torch.cuda.manual_seed(11)  \n",
    "torch.cuda.manual_seed_all(11) \n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:0\") \n",
    "    print(f\"Running on GPUs: {[torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]}\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "    print(\"Running on CPU\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7044d05",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --------------------------------------------------------------------\n",
    "# squeeze and excitation block\n",
    "# --------------------------------------------------------------------\n",
    "class SEBlock(nn.Module):\n",
    "    def __init__(self, ch: int, reduction: int = 1):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Conv2d(ch, ch // reduction, 1, bias=True)\n",
    "        self.fc2 = nn.Conv2d(ch // reduction, ch, 1, bias=True)\n",
    "\n",
    "    def forward(self, x):\n",
    "        s = x.mean((2, 3), keepdim=True)\n",
    "        s = F.gelu(self.fc1(s))\n",
    "        s = torch.sigmoid(self.fc2(s))\n",
    "        return x * s\n",
    "\n",
    "# --------------------------------------------------------------------\n",
    "# DC block : two dilated 3x3 Conv2d layers + SE gate, wrapped in a residual skip\n",
    "# --------------------------------------------------------------------\n",
    "class ResidualSEBlock2d(nn.Module):\n",
    "    \"\"\"Residual block: two dilated 3x3 convolutions, an SE gate, then a skip connection.\n",
    "\n",
    "    Both convolutions use ``padding == dilation`` with a 3x3 kernel, so the\n",
    "    spatial size of the input is preserved and ``x + y`` is well defined.\n",
    "\n",
    "    Args:\n",
    "        ch: number of channels going in and out of the block.\n",
    "        dilation: dilation (and padding) shared by both convolutions.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, ch: int, dilation: int):\n",
    "        super().__init__()\n",
    "        # Both convolutions are configured identically.\n",
    "        conv_kwargs = dict(kernel_size=3, padding=dilation, dilation=dilation, bias=True)\n",
    "        self.conv = nn.Conv2d(ch, ch, **conv_kwargs)\n",
    "        # NOTE: 3x3 kernel despite the attribute name; the name is kept as-is\n",
    "        # because it is part of the state_dict keys.\n",
    "        self.conv5 = nn.Conv2d(ch, ch, **conv_kwargs)\n",
    "        self.se = SEBlock(ch)\n",
    "        self.act = nn.GELU()\n",
    "\n",
    "    def forward(self, x):  # b,w,x,y\n",
    "        residual = x\n",
    "        out = self.conv(x)\n",
    "        out = F.gelu(out)\n",
    "        out = self.conv5(out)\n",
    "        out = self.se(out)  # channel-wise re-weighting\n",
    "        return self.act(residual + out)\n",
    "\n",
    "# --------------------------------------------------------------------\n",
    "# network body \n",
    "# --------------------------------------------------------------------\n",
    "class Net2d(nn.Module):\n",
    "    \"\"\"Pointwise linear stem -> stack of dilated residual SE blocks -> 1x1 conv head.\n",
    "\n",
    "    The network consumes a channels-last tensor ``(batch, x, y, in_channels)``\n",
    "    and returns a channels-last tensor ``(batch, x, y, 1)``.\n",
    "\n",
    "    Args:\n",
    "        width: channel count produced by the stem and used by every block.\n",
    "        H, W: accepted for interface compatibility; they are not used to\n",
    "            build or run the network.\n",
    "        num_blocks: accepted for interface compatibility; it is not used.\n",
    "            The number of residual blocks is ``len(dilations)``.\n",
    "        in_channels: size of the last input axis fed to the stem. The default\n",
    "            of 12 is the value that was previously hard-coded.\n",
    "        dilations: one dilation rate per residual block. ``None`` selects\n",
    "            ``DEFAULT_DILATIONS``, the schedule that was previously hard-coded.\n",
    "    \"\"\"\n",
    "\n",
    "    DEFAULT_DILATIONS = (15, 25, 17, 13, 7, 5, 3, 1)\n",
    "\n",
    "    def __init__(self, width: int, H: int, W: int, num_blocks: int = 12,\n",
    "                 in_channels: int = 12, dilations=None):\n",
    "        super().__init__()\n",
    "        if dilations is None:\n",
    "            dilations = self.DEFAULT_DILATIONS\n",
    "\n",
    "        # Sub-modules are created in the order stem -> blocks -> head so that\n",
    "        # seeded parameter initialisation matches the original definition.\n",
    "        self.stem = nn.Linear(in_channels, width)\n",
    "        self.blocks = nn.Sequential(\n",
    "            *[ResidualSEBlock2d(width, d) for d in dilations]\n",
    "        )\n",
    "\n",
    "        self.head = nn.Sequential(\n",
    "            nn.Conv2d(width, 128, 1, bias=True),\n",
    "            nn.GELU(),\n",
    "            nn.Conv2d(128, 1, 1)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.stem(x)  # lifts the last axis from in_channels to width\n",
    "        x = x.permute(0, 3, 1, 2) # b,w,x,y\n",
    "        x = self.blocks(x)\n",
    "        x = self.head(x)\n",
    "        # back to channels-last; contiguous() so callers can reshape/view freely\n",
    "        return x.permute(0,2,3,1).contiguous()\n",
    "\n",
    "    def count_params(self):\n",
    "        \"\"\"Return the number of trainable scalar parameters.\"\"\"\n",
    "        return sum(p.numel() for p in self.parameters() if p.requires_grad)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30eb9bc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "################################################################\n",
    "# configuration\n",
    "################################################################\n",
    "\n",
    "TRAIN_PATH = 'NavierStokes_V1e-5_N1200_T20.mat'\n",
    "TEST_PATH = 'NavierStokes_V1e-5_N1200_T20.mat'\n",
    "\n",
    "# Train and test come from the same file: the first ntrain samples are used\n",
    "# for training and the last ntest samples for testing.\n",
    "ntrain = 1000\n",
    "ntest = 200\n",
    "\n",
    "# modes = 12\n",
    "width = 64\n",
    "\n",
    "batch_size = 20\n",
    "batch_size2 = batch_size\n",
    "\n",
    "epochs = 500\n",
    "learning_rate = 0.001\n",
    "scheduler_step = 100\n",
    "scheduler_gamma = 0.5\n",
    "\n",
    "runtime = np.zeros(2, )\n",
    "t1 = default_timer()\n",
    "\n",
    "sub = 1      # spatial subsampling stride\n",
    "S = 64       # spatial resolution after subsampling (checked by the assert below)\n",
    "s=S\n",
    "T_in = 10    # number of input time steps\n",
    "T = 10       # number of time steps to predict\n",
    "step = 1\n",
    "\n",
    "################################################################\n",
    "# load data\n",
    "################################################################\n",
    "\n",
    "# Read field 'u' once per file and slice both the input window and the\n",
    "# target window from it, instead of reading the field twice.\n",
    "reader = MatReader(TRAIN_PATH)\n",
    "u_train = reader.read_field('u')\n",
    "train_a = u_train[:ntrain,::sub,::sub,:T_in]\n",
    "train_u = u_train[:ntrain,::sub,::sub,T_in:T+T_in]\n",
    "\n",
    "reader = MatReader(TEST_PATH)\n",
    "u_test = reader.read_field('u')\n",
    "test_a = u_test[-ntest:,::sub,::sub,:T_in]\n",
    "test_u = u_test[-ntest:,::sub,::sub,T_in:T+T_in]\n",
    "\n",
    "print(train_u.shape)\n",
    "print(test_u.shape)\n",
    "assert (S == train_u.shape[-2])\n",
    "assert (T == train_u.shape[-1])\n",
    "\n",
    "train_a = train_a.reshape(ntrain,S,S,T_in)\n",
    "test_a = test_a.reshape(ntest,S,S,T_in)\n",
    "\n",
    "# Normalised coordinate channels in [0, 1], each of shape (1, S, S, 1).\n",
    "gridx = torch.linspace(0, 1, S, dtype=torch.float).reshape(1, S, 1, 1).repeat(1, 1, S, 1)\n",
    "gridy = torch.linspace(0, 1, S, dtype=torch.float).reshape(1, 1, S, 1).repeat(1, S, 1, 1)\n",
    "\n",
    "gridx_train = gridx.repeat(ntrain, 1, 1, 1)\n",
    "gridy_train = gridy.repeat(ntrain, 1, 1, 1)\n",
    "gridx_test  = gridx.repeat(ntest,  1, 1, 1)\n",
    "gridy_test  = gridy.repeat(ntest,  1, 1, 1)\n",
    "\n",
    "# Append the two coordinate channels: the last axis becomes T_in + 2.\n",
    "train_a = torch.cat((train_a, gridx_train, gridy_train), dim=-1)\n",
    "test_a  = torch.cat((test_a,  gridx_test,  gridy_test),  dim=-1)\n",
    "\n",
    "train_loader = DataLoader(TensorDataset(train_a, train_u), batch_size=batch_size, shuffle=True)\n",
    "test_loader  = DataLoader(TensorDataset(test_a,  test_u),  batch_size=batch_size, shuffle=False)\n",
    "\n",
    "t2 = default_timer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c8df685",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "# Place the model on the device selected in the first cell, so the notebook\n",
    "# also runs on the CPU fallback path (a bare .cuda() raises without a GPU).\n",
    "model = Net2d(width, s,s).to(device)\n",
    "print(model.count_params())\n",
    "\n",
    "# Adam with weight decay; the learning rate is multiplied by scheduler_gamma\n",
    "# every scheduler_step epochs.\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)\n",
    "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5cf2708",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "myloss = LpLoss(size_average=False)\n",
    "gridx = gridx.to(device)\n",
    "gridy = gridy.to(device)\n",
    "\n",
    "\n",
    "def _rollout(net, xx, yy):\n",
    "    \"\"\"Autoregressively predict the target window and accumulate the per-step loss.\n",
    "\n",
    "    Args:\n",
    "        net: the model; called once per time step on the current input window.\n",
    "        xx: input batch whose last axis is the history frames followed by the\n",
    "            two grid channels.\n",
    "        yy: target batch; its last axis is indexed by time step.\n",
    "\n",
    "    Returns:\n",
    "        (step_loss, pred): the sum of ``myloss`` over all time steps, and the\n",
    "        predictions concatenated along the last axis.\n",
    "    \"\"\"\n",
    "    # Use the actual batch size: the final batch of a loader may be smaller\n",
    "    # than the global batch_size.\n",
    "    nb = xx.shape[0]\n",
    "    # The grid channels do not change between time steps, so expand them once.\n",
    "    grid_x = gridx.repeat([nb, 1, 1, 1])\n",
    "    grid_y = gridy.repeat([nb, 1, 1, 1])\n",
    "\n",
    "    step_loss = 0\n",
    "    preds = []\n",
    "    for t in range(0, T, step):\n",
    "        y = yy[..., t:t + step]\n",
    "        im = net(xx)\n",
    "        step_loss += myloss(im.reshape(nb, -1), y.reshape(nb, -1))\n",
    "        preds.append(im)\n",
    "        # Slide the window: drop the oldest `step` frames and the two old grid\n",
    "        # channels, then append the new prediction and the grid channels.\n",
    "        xx = torch.cat((xx[..., step:-2], im, grid_x, grid_y), dim=-1)\n",
    "\n",
    "    return step_loss, torch.cat(preds, dim=-1)\n",
    "\n",
    "\n",
    "t3 = default_timer()\n",
    "for ep in range(epochs):\n",
    "    model.train()\n",
    "    t1 = default_timer()\n",
    "    train_l2_step = 0\n",
    "    train_l2_full = 0\n",
    "    for xx, yy in train_loader:\n",
    "        xx = xx.to(device)\n",
    "        yy = yy.to(device)\n",
    "        nb = xx.shape[0]\n",
    "\n",
    "        loss, pred = _rollout(model, xx, yy)\n",
    "\n",
    "        train_l2_step += loss.item()\n",
    "        # The full-trajectory error is only reported, never back-propagated.\n",
    "        with torch.no_grad():\n",
    "            l2_full = myloss(pred.reshape(nb, -1), yy.reshape(nb, -1))\n",
    "        train_l2_full += l2_full.item()\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "\n",
    "        optimizer.step()\n",
    "\n",
    "    # Evaluation pass; model.train() is restored at the top of the next epoch.\n",
    "    model.eval()\n",
    "    test_l2_step = 0\n",
    "    test_l2_full = 0\n",
    "    with torch.no_grad():\n",
    "        for xx, yy in test_loader:\n",
    "            xx = xx.to(device)\n",
    "            yy = yy.to(device)\n",
    "            nb = xx.shape[0]\n",
    "\n",
    "            loss, pred = _rollout(model, xx, yy)\n",
    "\n",
    "            test_l2_step += loss.item()\n",
    "            test_l2_full += myloss(pred.reshape(nb, -1), yy.reshape(nb, -1)).item()\n",
    "\n",
    "    t2 = default_timer()\n",
    "    scheduler.step()\n",
    "    # columns: epoch, seconds, train step loss, train full loss, test step loss, test full loss\n",
    "    print(ep, t2 - t1, train_l2_step / ntrain / (T / step), train_l2_full / ntrain, test_l2_step / ntest / (T / step),\n",
    "          test_l2_full / ntest)\n",
    "t4 = default_timer()\n",
    "# NOTE: despite the label, the value printed is the mean wall-clock time per epoch.\n",
    "print(\"total time = \", ((t4-t3)/epochs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bfeb4d2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
