{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ecb8ed7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.nn.parameter import Parameter\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import operator\n",
    "from functools import reduce\n",
    "from functools import partial\n",
    "from timeit import default_timer\n",
    "from utilities3 import *\n",
    "\n",
    "torch.manual_seed(11) # 0, 11, 17\n",
    "torch.cuda.manual_seed(11)  \n",
    "torch.cuda.manual_seed_all(11) \n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:0\") \n",
    "    print(f\"Running on GPUs: {[torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]}\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "    print(\"Running on CPU\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7044d05",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --------------------------------------------------------------------\n",
    "# squeeze and excitation block\n",
    "# --------------------------------------------------------------------\n",
    "class SEBlock(nn.Module):\n",
    "    def __init__(self, ch: int, reduction: int = 1):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Conv2d(ch, ch // reduction, 1, bias=True)\n",
    "        self.fc2 = nn.Conv2d(ch // reduction, ch, 1, bias=True)\n",
    "\n",
    "    def forward(self, x):\n",
    "        s = x.mean((2, 3), keepdim=True)\n",
    "        s = F.gelu(self.fc1(s))\n",
    "        s = torch.sigmoid(self.fc2(s))\n",
    "        return x * s\n",
    "\n",
    "# --------------------------------------------------------------------\n",
    "# DC block : 3x3 and 5x5 D-Conv2d\n",
    "# --------------------------------------------------------------------\n",
    "class ResidualSEBlock2d(nn.Module):\n",
    "    def __init__(self, ch: int, dilation):         \n",
    "        super().__init__()\n",
    "        if isinstance(dilation, int):\n",
    "            dilation = (dilation, dilation)      \n",
    "        pad3 = dilation                            \n",
    "        pad5 = (dilation[0] * 2, dilation[1] * 2)  \n",
    "        self.conv  = nn.Conv2d(ch, ch, 3,\n",
    "                               padding=pad3,\n",
    "                               dilation=dilation,\n",
    "                               bias= True)\n",
    "        self.conv5 = nn.Conv2d(ch, ch, 5,\n",
    "                               padding=pad5,\n",
    "                               dilation=dilation,\n",
    "                               bias=False)\n",
    "\n",
    "        self.se  = SEBlock(ch)\n",
    "        self.act = nn.GELU()\n",
    "\n",
    "    def forward(self, x): \n",
    "        y = F.gelu(self.conv(x))\n",
    "        y = self.conv5(y)\n",
    "        y = self.se(y)\n",
    "        return self.act(x + y)\n",
    "\n",
    "# --------------------------------------------------------------------\n",
    "# network body\n",
    "# --------------------------------------------------------------------\n",
    "class Net2d(nn.Module):\n",
    "    def __init__(self, width: int, H: int, W: int):\n",
    "        super().__init__()\n",
    "        self.stem = nn.Conv2d(4, width, 1, bias=True)\n",
    "\n",
    "        dilations_y = [1, 2, 8, 12, 6, 2, 1]     \n",
    "        dilations_x = [16, 56, 42, 36,  32, 24, 1]    \n",
    "\n",
    "        dilations   = list(zip(dilations_x, dilations_y))\n",
    "\n",
    "        self.blocks = nn.Sequential(\n",
    "            *[ResidualSEBlock2d(width, dxy) for dxy in dilations]\n",
    "        )\n",
    "\n",
    "        self.head = nn.Sequential(\n",
    "            nn.Conv2d(width, 128, 1),\n",
    "            nn.GELU(),\n",
    "            nn.Conv2d(128, 1, 1)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.permute(0, 3, 1, 2)      # (B, 3, H, W)\n",
    "        x = self.stem(x)\n",
    "        x = self.blocks(x)\n",
    "        x = self.head(x)\n",
    "        return x.squeeze(1)\n",
    "\n",
    "    def count_params(self):\n",
    "        return sum(p.numel() for p in self.parameters() if p.requires_grad)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30eb9bc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "################################################################\n",
    "# configs\n",
    "################################################################\n",
    "INPUT_X = 'NACA_Cylinder_X.npy'\n",
    "INPUT_Y = 'NACA_Cylinder_Y.npy'\n",
    "OUTPUT_Sigma = 'NACA_Cylinder_Q.npy'\n",
    "\n",
    "ntrain = 1000\n",
    "ntest = 200\n",
    "\n",
    "batch_size = 20\n",
    "learning_rate = 0.001\n",
    "\n",
    "epochs = 500\n",
    "step_size = 100\n",
    "gamma = 0.5\n",
    "\n",
    "\n",
    "width = 64\n",
    "\n",
    "r1=1\n",
    "r2=1\n",
    "s1 = int(((221 - 1) / r1) + 1)\n",
    "s2 = int(((51 - 1) / r2) + 1)\n",
    "\n",
    "\n",
    "################################################################\n",
    "# load data and data normalization\n",
    "################################################################\n",
    "\n",
    "\n",
    "inputX = np.load(INPUT_X)\n",
    "inputX = torch.tensor(inputX, dtype=torch.float)\n",
    "inputY = np.load(INPUT_Y)\n",
    "inputY = torch.tensor(inputY, dtype=torch.float)\n",
    "input = torch.stack([inputX, inputY], dim=-1)\n",
    "\n",
    "output = np.load(OUTPUT_Sigma)[:, 4]\n",
    "output = torch.tensor(output, dtype=torch.float)\n",
    "print(input.shape, output.shape)\n",
    "\n",
    "x_train = input[:ntrain, ::r1, ::r2][:, :s1, :s2]\n",
    "y_train = output[:ntrain, ::r1, ::r2][:, :s1, :s2]\n",
    "x_test = input[ntrain:ntrain + ntest, ::r1, ::r2][:, :s1, :s2]\n",
    "y_test = output[ntrain:ntrain + ntest, ::r1, ::r2][:, :s1, :s2]\n",
    "print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)\n",
    "\n",
    "\n",
    "x_normalizer = UnitGaussianNormalizer(x_train)\n",
    "x_train = x_normalizer.encode(x_train)\n",
    "x_test = x_normalizer.encode(x_test)\n",
    "\n",
    "y_normalizer = UnitGaussianNormalizer(y_train)\n",
    "y_train = y_normalizer.encode(y_train)\n",
    "\n",
    "xs = np.linspace(0, 1, s1)\n",
    "ys = np.linspace(0, 1, s2)\n",
    "xx, yy = np.meshgrid(xs, ys, indexing='xy')   \n",
    "grid = np.stack([xx, yy], axis=-1)               \n",
    "grid = grid.transpose(1, 0, 2)                    \n",
    "grid = grid[None, ...]                          \n",
    "grid_t = torch.tensor(grid, dtype=torch.float32) \n",
    "\n",
    "grid_train = grid_t.repeat(ntrain, 1, 1, 1)  \n",
    "grid_test  = grid_t.repeat(ntest,  1, 1, 1) \n",
    "\n",
    "\n",
    "x_train_aug = torch.cat([x_train, grid_train], dim=-1)\n",
    "x_test_aug  = torch.cat([x_test,  grid_test],  dim=-1) \n",
    "\n",
    "print(x_train_aug.shape, x_test_aug.shape)\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    torch.utils.data.TensorDataset(x_train_aug, y_train),\n",
    "    batch_size=batch_size, shuffle=True)\n",
    "\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    torch.utils.data.TensorDataset(x_test_aug, y_test),\n",
    "    batch_size=batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c8df685",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "model = Net2d(width, s1,s2).cuda()\n",
    "\n",
    "print(model.count_params())\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4) \n",
    "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f132b68",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "myloss = LpLoss(size_average=False)\n",
    "y_normalizer.cuda()\n",
    "t3 = default_timer()\n",
    "for ep in range(epochs):\n",
    "    model.train()\n",
    "    t1 = default_timer()\n",
    "    train_l2 = 0\n",
    "    for x, y in train_loader:\n",
    "        x, y = x.cuda(), y.cuda()\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        out = y_normalizer.decode(model(x))\n",
    "        y = y_normalizer.decode(y)\n",
    "    \n",
    "        l2 = myloss(out.view(batch_size, -1), y.view(batch_size, -1))\n",
    "        l2.backward()\n",
    "\n",
    "        optimizer.step()\n",
    "\n",
    "        train_l2 += l2.item()\n",
    "\n",
    "    scheduler.step()\n",
    "\n",
    "    model.eval()\n",
    "    test_l2 = 0.0\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        for x, y in test_loader:\n",
    "            x, y = x.cuda(), y.cuda()\n",
    "\n",
    "            out = y_normalizer.decode(model(x))\n",
    "            test_l2 += myloss(out.view(batch_size, -1), y.view(batch_size, -1)).item()\n",
    "\n",
    "\n",
    "    train_l2 /= ntrain\n",
    "    test_l2 /= ntest\n",
    "\n",
    "    t2 = default_timer()\n",
    "    print(f\"Epoch: {ep:.1f}, \"\n",
    "    f\"Time Elapsed: {t2-t1:.2f}, \"\n",
    "    f\"Train Loss: {train_l2:.4f}, \"\n",
    "    f\"Test Loss: {test_l2:.4f}\"\n",
    ")\n",
    "    \n",
    "t4 = default_timer()\n",
    "print(\"total time = \", ((t4-t3)/500))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbeeeff4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
