{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "99e8c612",
   "metadata": {},
   "source": [
    "# Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ecb8ed7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.nn.parameter import Parameter\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import operator\n",
    "from functools import reduce\n",
    "from functools import partial\n",
    "from timeit import default_timer\n",
    "from utilities3 import *\n",
    "\n",
    "torch.manual_seed(11) # 0, 11, 17\n",
    "torch.cuda.manual_seed(11)  \n",
    "torch.cuda.manual_seed_all(11) \n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:0\") \n",
    "    print(f\"Running on GPUs: {[torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]}\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "    print(\"Running on CPU\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d58c329b",
   "metadata": {},
   "source": [
    "# Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7044d05",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --------------------------------------------------------------------\n",
    "# squeeze and excitation block\n",
    "# --------------------------------------------------------------------\n",
    "class SEBlock(nn.Module):\n",
    "    def __init__(self, ch: int, reduction: int = 1):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Conv2d(ch, ch // reduction, 1, bias=True)\n",
    "        self.fc2 = nn.Conv2d(ch // reduction, ch, 1, bias=True)\n",
    "\n",
    "    def forward(self, x):\n",
    "        s = x.mean((2, 3), keepdim=True)\n",
    "        s = F.gelu(self.fc1(s))\n",
    "        s = torch.sigmoid(self.fc2(s))\n",
    "        return x * s\n",
    "\n",
    "# --------------------------------------------------------------------\n",
    "# DC block : 3x3 and 5x5 D-Conv2d\n",
    "# --------------------------------------------------------------------\n",
    "class ResidualSEBlock2d(nn.Module):\n",
    "    def __init__(self, ch: int, dilation: int):\n",
    "        super().__init__()\n",
    "        self.conv = nn.Conv2d(ch, ch, 3, padding=dilation, dilation=dilation, bias=True)\n",
    "        self.conv5 = nn.Conv2d(ch, ch, 5, padding= 2 *dilation, dilation=dilation, bias=False)\n",
    "        self.se   = SEBlock(ch)\n",
    "        self.act  = nn.GELU()\n",
    "\n",
    "    def forward(self, x):\n",
    "        y = self.conv(x)\n",
    "        y = F.gelu(y)\n",
    "        y = self.conv5(y)\n",
    "        y = self.se(y)\n",
    "        return self.act(x + y)\n",
    "\n",
    "# --------------------------------------------------------------------\n",
    "# network body \n",
    "# --------------------------------------------------------------------\n",
    "class Net2d(nn.Module):\n",
    "    def __init__(self, width: int, H: int, W: int, num_blocks: int = 6):\n",
    "        super().__init__()\n",
    "        self.stem = nn.Conv2d(3, width, 1, bias=True)\n",
    "        dilations =[1, 3, 5, 9, 13, 19]\n",
    "        self.blocks = nn.Sequential(\n",
    "            *[ResidualSEBlock2d(width, d) for d in dilations])\n",
    "\n",
    "        self.head = nn.Sequential(\n",
    "            nn.Conv2d(width, 128, 1, bias=True),\n",
    "            nn.GELU(),\n",
    "            nn.Conv2d(128, 1, 1))\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.permute(0, 3, 1, 2)   # (B,3,H,W)\n",
    "        x = self.stem(x)\n",
    "        x = self.blocks(x)\n",
    "        x = self.head(x)\n",
    "        return x.squeeze(1)\n",
    "\n",
    "    def count_params(self):\n",
    "        return sum(p.numel() for p in self.parameters() if p.requires_grad)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26bfadd5",
   "metadata": {},
   "source": [
    "# Read data and create train and test loaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30eb9bc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "################################################################\n",
    "# configs\n",
    "################################################################\n",
    "# Training and test samples are read from two separate .mat files.\n",
    "TRAIN_PATH = 'piececonst_r421_N1024_smooth1.mat'\n",
    "TEST_PATH = 'piececonst_r421_N1024_smooth2.mat'\n",
    "\n",
    "ntrain, ntest = 1000, 200\n",
    "\n",
    "batch_size = 20\n",
    "learning_rate = 0.001\n",
    "\n",
    "# `step_size` and `gamma` are handed to the StepLR scheduler in a later cell.\n",
    "epochs = 500\n",
    "step_size = 100\n",
    "gamma = 0.5\n",
    "\n",
    "# Channel width passed to Net2d.\n",
    "width = 48\n",
    "\n",
    "# Keep every r-th grid point; s is the resulting number of points per side.\n",
    "# NOTE(review): assumes the stored fields are 421 x 421 (per the file name) - confirm.\n",
    "r = 5\n",
    "h = int((421 - 1) / r + 1)\n",
    "s = h\n",
    "\n",
    "################################################################\n",
    "# load data and data normalization\n",
    "################################################################\n",
    "# First n samples, every r-th point on the two trailing axes, cropped to s x s.\n",
    "reader = MatReader(TRAIN_PATH)\n",
    "x_train = reader.read_field('coeff')[:ntrain, ::r, ::r][:, :s, :s]\n",
    "y_train = reader.read_field('sol')[:ntrain, ::r, ::r][:, :s, :s]\n",
    "\n",
    "reader.load_file(TEST_PATH)\n",
    "x_test = reader.read_field('coeff')[:ntest, ::r, ::r][:, :s, :s]\n",
    "y_test = reader.read_field('sol')[:ntest, ::r, ::r][:, :s, :s]\n",
    "\n",
    "# Input normalizer is built from the training inputs and applied to both splits.\n",
    "x_normalizer = UnitGaussianNormalizer(x_train)\n",
    "x_train = x_normalizer.encode(x_train)\n",
    "x_test = x_normalizer.encode(x_test)\n",
    "\n",
    "# Only the training targets are encoded; y_test is left as loaded.\n",
    "y_normalizer = UnitGaussianNormalizer(y_train)\n",
    "y_train = y_normalizer.encode(y_train)\n",
    "\n",
    "# Coordinate grid on [0, 1] x [0, 1], appended to every sample as two extra channels.\n",
    "grids = [np.linspace(0, 1, s), np.linspace(0, 1, s)]\n",
    "mesh = np.meshgrid(*grids)\n",
    "grid = np.stack([axis_vals.ravel() for axis_vals in mesh], axis=-1)\n",
    "grid = torch.tensor(grid.reshape(1, s, s, 2), dtype=torch.float)\n",
    "x_train = torch.cat([x_train.reshape(ntrain, s, s, 1), grid.repeat(ntrain, 1, 1, 1)], dim=3)\n",
    "x_test = torch.cat([x_test.reshape(ntest, s, s, 1), grid.repeat(ntest, 1, 1, 1)], dim=3)\n",
    "\n",
    "# Training batches are shuffled; test batches keep their order.\n",
    "train_set = torch.utils.data.TensorDataset(x_train, y_train)\n",
    "test_set = torch.utils.data.TensorDataset(x_test, y_test)\n",
    "train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)\n",
    "test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c8df685",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the model on the device chosen in the imports cell, so the notebook\n",
    "# also runs on CPU-only machines instead of failing on a hard-coded .cuda().\n",
    "model = Net2d(width, s, s).to(device)\n",
    "print(model.count_params())\n",
    "\n",
    "# Adam with weight decay; StepLR is configured with `step_size` and `gamma`\n",
    "# from the config cell.\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)\n",
    "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f132b68",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Train and evaluate for `epochs` epochs, printing per-epoch losses and timing.\n",
    "# NOTE(review): LpLoss(size_average=False) presumably returns the sum over the\n",
    "# batch, so dividing by the dataset size gives a per-sample mean - confirm in utilities3.\n",
    "myloss = LpLoss(size_average=False)\n",
    "# Move the output normalizer to the GPU only when a GPU is actually in use.\n",
    "if device.type == \"cuda\":\n",
    "    y_normalizer.cuda()\n",
    "\n",
    "t3 = default_timer()\n",
    "for ep in range(epochs):\n",
    "    model.train()\n",
    "    t1 = default_timer()\n",
    "    train_l2 = 0.0\n",
    "    for x, y in train_loader:\n",
    "        x, y = x.to(device), y.to(device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        # Decode both the prediction and the encoded target before computing the loss.\n",
    "        out = y_normalizer.decode(model(x))\n",
    "        y = y_normalizer.decode(y)\n",
    "\n",
    "        # Flatten per sample using the real batch size so a short final batch works.\n",
    "        l2 = myloss(out.reshape(out.size(0), -1), y.reshape(y.size(0), -1))\n",
    "        l2.backward()\n",
    "        optimizer.step()\n",
    "        train_l2 += l2.item()\n",
    "\n",
    "    scheduler.step()\n",
    "    model.eval()\n",
    "    test_l2 = 0.0\n",
    "    with torch.no_grad():\n",
    "        for x, y in test_loader:\n",
    "            x, y = x.to(device), y.to(device)\n",
    "            # Test targets were never encoded, so only the prediction is decoded.\n",
    "            out = y_normalizer.decode(model(x))\n",
    "            test_l2 += myloss(out.reshape(out.size(0), -1), y.reshape(y.size(0), -1)).item()\n",
    "\n",
    "    train_l2 /= ntrain\n",
    "    test_l2 /= ntest\n",
    "\n",
    "    t2 = default_timer()\n",
    "    print(f\"Epoch: {ep:.1f}, \"\n",
    "    f\"Time Elapsed: {t2-t1:.2f}, \"\n",
    "    f\"Train Loss: {train_l2:.4f}, \"\n",
    "    f\"Test Loss: {test_l2:.4f}\")\n",
    "\n",
    "t4 = default_timer()\n",
    "# Wall-clock time divided by the configured epoch count (previously a literal 500).\n",
    "print(\"average time per epoch = \", (t4 - t3) / max(epochs, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff2356a9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60753579",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
