{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36256a71",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.nn.parameter import Parameter\n",
    "import torch.fft\n",
    "import matplotlib.pyplot as plt\n",
    "import operator\n",
    "from functools import reduce\n",
    "from functools import partial\n",
    "from timeit import default_timer\n",
    "from utilities3 import *\n",
    "\n",
    "torch.manual_seed(11)\n",
    "torch.cuda.manual_seed(11)  \n",
    "torch.cuda.manual_seed_all(11) \n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:0\") \n",
    "    print(f\"Running on GPUs: {[torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]}\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "    print(\"Running on CPU\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9efe8f0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SpectralConv2d_fast(nn.Module):\n",
    "    def __init__(self, in_channels, out_channels, modes1, modes2):\n",
    "        super().__init__()\n",
    "        self.in_channels = in_channels\n",
    "        self.out_channels = out_channels\n",
    "        self.modes1 = modes1  # number of low-frequency modes to keep along dim1\n",
    "        self.modes2 = modes2  # number of low-frequency modes to keep along dim2\n",
    "\n",
    "        scale = 1 / (in_channels * out_channels)\n",
    "        self.weights1 = nn.Parameter(\n",
    "            scale * torch.randn(in_channels, out_channels, modes1, modes2, dtype=torch.cfloat)\n",
    "        )\n",
    "        self.weights2 = nn.Parameter(\n",
    "            scale * torch.randn(in_channels, out_channels, modes1, modes2, dtype=torch.cfloat)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "\n",
    "        batchsize, _, height, width = x.shape\n",
    "\n",
    "\n",
    "        x_ft = torch.fft.rfft2(x, norm=\"ortho\")  # shape (b, in_ch, h, w//2+1), dtype=complex\n",
    "\n",
    "        out_ft = torch.zeros(\n",
    "            batchsize, self.out_channels, height, width // 2 + 1,\n",
    "            dtype=torch.cfloat, device=x.device)\n",
    "\n",
    "        out_ft[:, :, : self.modes1, : self.modes2] = torch.einsum(\n",
    "            \"bixy,ioxy->boxy\", \n",
    "            x_ft[:, :, : self.modes1, : self.modes2],\n",
    "            self.weights1)\n",
    "\n",
    "        out_ft[:, :, -self.modes1 :, : self.modes2] = torch.einsum(\n",
    "            \"bixy,ioxy->boxy\",\n",
    "            x_ft[:, :, -self.modes1 :, : self.modes2],\n",
    "            self.weights2)\n",
    "\n",
    "        x = torch.fft.irfft2(out_ft, s=(height, width), norm=\"ortho\")\n",
    "        return x\n",
    "\n",
    "class SimpleBlock2d(nn.Module):\n",
    "    def __init__(self, modes1, modes2, width):\n",
    "        super(SimpleBlock2d, self).__init__()\n",
    "\n",
    "        self.modes1 = modes1\n",
    "        self.modes2 = modes2\n",
    "        self.width = width\n",
    "        self.fc0 = nn.Linear(3, self.width)\n",
    "\n",
    "        self.conv0 = SpectralConv2d_fast(self.width, self.width, self.modes1, self.modes2)\n",
    "        self.conv1 = SpectralConv2d_fast(self.width, self.width, self.modes1, self.modes2)\n",
    "        self.conv2 = SpectralConv2d_fast(self.width, self.width, self.modes1, self.modes2)\n",
    "        self.conv3 = SpectralConv2d_fast(self.width, self.width, self.modes1, self.modes2)\n",
    "        self.w0 = nn.Conv1d(self.width, self.width, 1)\n",
    "        self.w1 = nn.Conv1d(self.width, self.width, 1)\n",
    "        self.w2 = nn.Conv1d(self.width, self.width, 1)\n",
    "        self.w3 = nn.Conv1d(self.width, self.width, 1)\n",
    "\n",
    "        self.fc1 = nn.Linear(self.width, 128)\n",
    "        self.fc2 = nn.Linear(128, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        batchsize = x.shape[0]\n",
    "        size_x, size_y = x.shape[1], x.shape[2]\n",
    "\n",
    "        x = self.fc0(x)\n",
    "        x = x.permute(0, 3, 1, 2)\n",
    "\n",
    "        x1 = self.conv0(x)\n",
    "        x2 = self.w0(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y)\n",
    "        x = F.gelu(x1 + x2)\n",
    "        \n",
    "        x1 = self.conv1(x)\n",
    "        x2 = self.w1(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y)\n",
    "        x = F.gelu(x1 + x2)\n",
    "        \n",
    "        x1 = self.conv2(x)\n",
    "        x2 = self.w2(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y)\n",
    "        x = F.gelu(x1 + x2)\n",
    "        \n",
    "        x1 = self.conv3(x)\n",
    "        x2 = self.w3(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y)\n",
    "        x = x1 + x2\n",
    "\n",
    "\n",
    "        x = x.permute(0, 2, 3, 1)\n",
    "        x = self.fc1(x)\n",
    "        x = F.gelu(x)\n",
    "        x = self.fc2(x)\n",
    "        return x\n",
    "\n",
    "class Net2d(nn.Module):\n",
    "    def __init__(self, modes, width):\n",
    "        super(Net2d, self).__init__()\n",
    "\n",
    "        self.conv1 = SimpleBlock2d(modes, modes,  width)\n",
    "\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.conv1(x)\n",
    "        return x.squeeze()\n",
    "\n",
    "\n",
    "    def count_params(self):\n",
    "        c = 0\n",
    "        for p in self.parameters():\n",
    "            c += reduce(operator.mul, list(p.size()))\n",
    "\n",
    "        return c\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c12dc80",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "################################################################\n",
    "# configs\n",
    "################################################################\n",
    "TRAIN_PATH = 'piececonst_r421_N1024_smooth1.mat'\n",
    "TEST_PATH = 'piececonst_r421_N1024_smooth2.mat'\n",
    "\n",
    "ntrain = 1000\n",
    "ntest = 200\n",
    "\n",
    "batch_size = 20\n",
    "learning_rate = 0.001\n",
    "\n",
    "epochs = 500\n",
    "step_size = 100\n",
    "gamma = 0.5\n",
    "\n",
    "modes = 42\n",
    "width = 48\n",
    "\n",
    "r = 5\n",
    "h = int(((421 - 1)/r) + 1)\n",
    "s = h\n",
    "\n",
    "################################################################\n",
    "# load data and data normalization\n",
    "################################################################\n",
    "reader = MatReader(TRAIN_PATH)\n",
    "x_train = reader.read_field('coeff')[:ntrain,::r,::r][:,:s,:s]\n",
    "y_train = reader.read_field('sol')[:ntrain,::r,::r][:,:s,:s]\n",
    "\n",
    "reader.load_file(TEST_PATH)\n",
    "x_test = reader.read_field('coeff')[:ntest,::r,::r][:,:s,:s]\n",
    "y_test = reader.read_field('sol')[:ntest,::r,::r][:,:s,:s]\n",
    "\n",
    "x_normalizer = UnitGaussianNormalizer(x_train)\n",
    "x_train = x_normalizer.encode(x_train)\n",
    "x_test = x_normalizer.encode(x_test)\n",
    "\n",
    "y_normalizer = UnitGaussianNormalizer(y_train)\n",
    "y_train = y_normalizer.encode(y_train)\n",
    "\n",
    "grids = []\n",
    "grids.append(np.linspace(0, 1, s))\n",
    "grids.append(np.linspace(0, 1, s))\n",
    "grid = np.vstack([xx.ravel() for xx in np.meshgrid(*grids)]).T\n",
    "grid = grid.reshape(1,s,s,2)\n",
    "grid = torch.tensor(grid, dtype=torch.float)\n",
    "x_train = torch.cat([x_train.reshape(ntrain,s,s,1), grid.repeat(ntrain,1,1,1)], dim=3)\n",
    "x_test = torch.cat([x_test.reshape(ntest,s,s,1), grid.repeat(ntest,1,1,1)], dim=3)\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True)\n",
    "test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12e4a6b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Net2d(modes, width).cuda()\n",
    "print(model.count_params())\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)\n",
    "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12a216aa",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "myloss = LpLoss(size_average=False)\n",
    "y_normalizer.cuda()\n",
    "t3 = default_timer()\n",
    "for ep in range(epochs):\n",
    "    model.train()\n",
    "    t1 = default_timer()\n",
    "    train_l2 = 0\n",
    "    for x, y in train_loader:\n",
    "        x, y = x.cuda(), y.cuda()\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        out = model(x)\n",
    "        out = y_normalizer.decode(out)\n",
    "        y = y_normalizer.decode(y)\n",
    "        loss = myloss(out.view(batch_size,-1), y.view(batch_size,-1))\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        train_l2 += loss.item()\n",
    "\n",
    "    scheduler.step()\n",
    "    model.eval()\n",
    "    test_l2 = 0.0\n",
    "    with torch.no_grad():\n",
    "        for x, y in test_loader:\n",
    "            x, y = x.cuda(), y.cuda()\n",
    "\n",
    "            out = model(x)\n",
    "            out = y_normalizer.decode(model(x))\n",
    "\n",
    "            test_l2 += myloss(out.view(batch_size,-1), y.view(batch_size,-1)).item()\n",
    "\n",
    "    train_l2 /= ntrain\n",
    "    test_l2 /= ntest\n",
    "\n",
    "    t2 = default_timer()\n",
    "    print(f\"Epoch: {ep:.1f}, \"\n",
    "    f\"Time Elapsed: {t2-t1:.2f}, \"\n",
    "    f\"Train Loss: {train_l2:.4f}, \"\n",
    "    f\"Test Loss: {test_l2:.4f}\")\n",
    "    \n",
    "t4 = default_timer()\n",
    "print(\"total time = \", ((t4-t3)/500))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06bdf921",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d0a9fa4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aadcf707",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b35b2023",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
