{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e8d8e22",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-13T06:54:35.571835Z",
     "start_time": "2024-11-13T06:54:35.564889Z"
    }
   },
   "outputs": [],
   "source": [
    "# import modules\n",
    "import copy\n",
    "import pickle as pkl\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm import tqdm_notebook\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "# deep learning modules\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "# Plot modules\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d import axes3d\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "126cc2c3",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-13T06:54:35.700259Z",
     "start_time": "2024-11-13T06:54:35.695465Z"
    }
   },
   "outputs": [],
   "source": [
    "\n",
    "# Use Gpu\n",
    "\n",
    "device = torch.device(\"cuda:2\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c091636",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-13T06:55:01.519707Z",
     "start_time": "2024-11-13T06:55:01.482455Z"
    }
   },
   "outputs": [],
   "source": [
    "# map_location='cpu' lets the file load even if it was saved from a GPU this machine\n",
    "# does not have; the tensors are moved to `device` explicitly below.\n",
    "data = torch.load('Helmholtz_neumann_51.pt', map_location='cpu')\n",
    "\n",
    "# NOTE(review): assumes data[0, 0] is the source term f and data[0, 1] the solution u\n",
    "# of one sample, both laid out on the (Nx+1) x (Ny+1) grid built below - confirm.\n",
    "source = data[0,0]\n",
    "solution = data[0,1]\n",
    "\n",
    "# Flatten to column vectors (assumed to share the flattening order of `grid`).\n",
    "source = source.view(-1,1).to(device)\n",
    "solution_whole = solution.view(-1,1).to(device)\n",
    "source.shape, solution.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d88191e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-13T06:55:01.953977Z",
     "start_time": "2024-11-13T06:55:01.947434Z"
    }
   },
   "outputs": [],
   "source": [
    "# Uniform (Nx+1) x (Ny+1) grid on the unit square; k is the Helmholtz wavenumber.\n",
    "Nx, Ny = 50, 50\n",
    "hx, hy = 1/Nx, 1/Ny  # grid spacings\n",
    "k = 1\n",
    "x, y = torch.linspace(0, 1, Nx+1), torch.linspace(0, 1, Ny+1)\n",
    "# indexing=\"ij\" is torch's current default; spelling it out silences the deprecation\n",
    "# warning and documents the layout: X[i, j] = x[i], Y[i, j] = y[j].\n",
    "X, Y = torch.meshgrid(x, y, indexing=\"ij\")\n",
    "\n",
    "# (N, 2) tensor of (x, y) points, flattened in row-major (i, j) order.\n",
    "grid = torch.stack([X, Y], dim=2).view(-1, 2).to(device)\n",
    "\n",
    "# Points on any edge of the square (corners included). Boolean-mask indexing already\n",
    "# returns a new tensor, so no clone() is needed.\n",
    "on_boundary = (grid[:,0]==0) | (grid[:,1]==0) | (grid[:,0]==1) | (grid[:,1]==1)\n",
    "grid_bc = grid[on_boundary]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a58588ff",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-13T06:55:03.305844Z",
     "start_time": "2024-11-13T06:55:03.127135Z"
    }
   },
   "outputs": [],
   "source": [
    "# Sanity check: every collocation point, with the boundary subset highlighted.\n",
    "fig, ax = plt.subplots(figsize=(5, 5))\n",
    "ax.scatter(grid[:,0].cpu().detach(), grid[:,1].cpu().detach(), label='all grid points')\n",
    "ax.scatter(grid_bc[:,0].cpu().detach(), grid_bc[:,1].cpu().detach(), label='boundary points')\n",
    "ax.set(xlabel='x', ylabel='y', title='Collocation points', aspect='equal')\n",
    "ax.legend(loc='upper left', bbox_to_anchor=(1.02, 1));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b4b2111",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-13T06:55:06.068928Z",
     "start_time": "2024-11-13T06:55:06.061874Z"
    }
   },
   "outputs": [],
   "source": [
    "# Number of grid lines dropped on each side: only the interior block of the reference\n",
    "# solution is used as training data (Ncut = 0 keeps the whole grid).\n",
    "Ncut = 5\n",
    "assert Ncut >= 0 and min(Nx, Ny) + 1 > 2*Ncut, \"Ncut leaves no interior points\"\n",
    "# Explicit end indices instead of -Ncut, because [0:-0] would be an empty slice;\n",
    "# the result has shape (1, M) for every Ncut.\n",
    "solution = solution_whole.reshape(Nx+1, Ny+1)[Ncut:Nx+1-Ncut, Ncut:Ny+1-Ncut].reshape(1,-1)\n",
    "solution.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c495868",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-13T06:55:06.616603Z",
     "start_time": "2024-11-13T06:55:06.611373Z"
    }
   },
   "outputs": [],
   "source": [
    "# Collocation and boundary coordinates become differentiable inputs, so the PDE\n",
    "# residual and Neumann terms can take derivatives of the network with respect to them.\n",
    "for points in (grid, grid_bc):\n",
    "    points.requires_grad_(True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a833c645",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-13T06:55:06.994179Z",
     "start_time": "2024-11-13T06:55:06.991515Z"
    }
   },
   "outputs": [],
   "source": [
    "def calculate_derivative(y, x) :\n",
    "    return torch.autograd.grad(y, x, create_graph=True,\\\n",
    "                        grad_outputs=torch.ones(y.size()).to(device))[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0eccdaf0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-13T06:55:07.672447Z",
     "start_time": "2024-11-13T06:55:07.661616Z"
    }
   },
   "outputs": [],
   "source": [
    "class net(nn.Module):\n",
    "    def __init__(self, hidden_dims, act) :                    # Hidden_dims : [h1, h2, h3, ..., hn]\n",
    "        super(net, self).__init__()\n",
    "        \n",
    "        self.layers = []\n",
    "        for i in range(len(hidden_dims)-1) :\n",
    "            self.layers.append(nn.Linear(hidden_dims[i], hidden_dims[i+1])) # hidden layers\n",
    "        self.layers = nn.ModuleList(self.layers)\n",
    "        \n",
    "        for layer in self.layers :                       # Weight initialization\n",
    "            nn.init.xavier_uniform_(layer.weight)        # Also known as Glorot initialization\n",
    "            \n",
    "        self.act = act#nn.Tanh()   #nn.ReLU() #                          # Nonlinear activation function\n",
    "        \n",
    "        \n",
    "    def forward(self, x) :\n",
    "        x = self.act(self.layers[0](x))\n",
    "        for layer in self.layers[1:-1] :\n",
    "            x = self.act(layer(x)) \n",
    "        x = self.layers[-1](x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f09a673e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-11-13T07:37:24.451378Z",
     "start_time": "2024-11-13T07:37:24.429337Z"
    }
   },
   "outputs": [],
   "source": [
    "def train(model_solution, model_source, optimizer, loss_f, data_weight=100) :\n",
    "    \"\"\"Take one optimizer step on the joint solution/source (inverse Helmholtz) loss.\n",
    "\n",
    "    loss = data_weight * data misfit + PDE residual + Neumann boundary residual, where\n",
    "    the PDE residual is u_xx + u_yy + k^2 u - f with u = model_solution and\n",
    "    f = model_source, and the boundary term penalizes the normal derivative of u on\n",
    "    all four edges of the unit square. Reads the notebook globals grid, grid_bc,\n",
    "    solution, Nx, Ny, Ncut and k.\n",
    "\n",
    "    Args:\n",
    "        model_solution, model_source: networks evaluated on (N, 2) point sets.\n",
    "        optimizer: optimizer over the parameters of both networks.\n",
    "        loss_f: loss module such as nn.MSELoss().\n",
    "        data_weight: weight of the data-misfit term (default 100, the former constant).\n",
    "    Returns:\n",
    "        (loss_ge, loss_bc) as Python floats.\n",
    "    \"\"\"\n",
    "    model_source.train()\n",
    "    model_solution.train()\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    source_output = model_source(grid)\n",
    "    solution_output = model_solution(grid)\n",
    "    solution_bc_output = model_solution(grid_bc)[:,0].view(-1,1)\n",
    "\n",
    "    grad_solution = calculate_derivative(solution_output, grid)\n",
    "    grad_solution_bc = calculate_derivative(solution_bc_output, grid_bc)\n",
    "\n",
    "    solution_x, solution_y = grad_solution[:,0].view(-1,1), grad_solution[:,1].view(-1,1)\n",
    "    solution_xx = calculate_derivative(solution_x, grid)[:,0].view(-1,1)\n",
    "    solution_yy = calculate_derivative(solution_y, grid)[:,1].view(-1,1)\n",
    "\n",
    "    # Helmholtz residual at every grid point.\n",
    "    ge = solution_xx + solution_yy + (k**2)*solution_output - source_output\n",
    "\n",
    "    # Zero-flux Neumann condition: du/dx on x=0 and x=1, du/dy on y=0 and y=1.\n",
    "    left = grid_bc[:,0]==0\n",
    "    right = grid_bc[:,0]==1\n",
    "    top = grid_bc[:,1]==1\n",
    "    bottom = grid_bc[:,1]==0\n",
    "    neumann = torch.cat([grad_solution_bc[:,0][left], grad_solution_bc[:,0][right],\n",
    "                         grad_solution_bc[:,1][top], grad_solution_bc[:,1][bottom]])\n",
    "    # zeros_like follows the real boundary count 2*(Ny+1) + 2*(Nx+1) (and dtype/device)\n",
    "    # instead of assuming 4*(Nx+1), which only holds for square grids.\n",
    "    neumann_target = torch.zeros_like(neumann)\n",
    "\n",
    "    # Restrict the prediction to the same interior block as `solution`. Explicit end\n",
    "    # indices keep Ncut = 0 working ([0:-0] is an empty slice), and Ny is used for\n",
    "    # the second axis so non-square grids reshape correctly.\n",
    "    solution_output_cut = solution_output.view(Nx+1, Ny+1)[Ncut:Nx+1-Ncut, Ncut:Ny+1-Ncut].reshape(1,-1)\n",
    "\n",
    "    # reshape_as accepts a target stored either as (M,) or as (1, M).\n",
    "    loss_data = loss_f(solution_output_cut, solution.reshape_as(solution_output_cut))\n",
    "    loss_ge = loss_f(ge, torch.zeros_like(ge))\n",
    "    loss_bc = loss_f(neumann, neumann_target)\n",
    "\n",
    "    loss = data_weight*loss_data + loss_ge + loss_bc\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    return loss_ge.item(), loss_bc.item()\n",
    "\n",
    "def test(model_source, loss_f):\n",
    "    \"\"\"Relative L2 error of model_source on `grid` against the true `source`.\n",
    "\n",
    "    Returns sqrt(mean((source - pred)^2)) / sqrt(mean(source^2)) as a float, which equals\n",
    "    ||source - pred||_2 / ||source||_2. loss_f is accepted for call compatibility but unused.\n",
    "    \"\"\"\n",
    "    model_source.eval()\n",
    "    # Evaluation needs no gradients; no_grad avoids recording an autograd graph\n",
    "    # (the parameters and `grid` require grad) on every call.\n",
    "    with torch.no_grad():\n",
    "        source_output = model_source(grid)\n",
    "        rel_err = (source - source_output).pow(2).mean().sqrt() / source.pow(2).mean().sqrt()\n",
    "    return rel_err.item()\n",
    "\n",
    "\n",
    "def plot(model_solution, model_source):\n",
    "    \"\"\"Draw true vs. predicted source and solution on the full grid as a 1 x 4 panel.\n",
    "\n",
    "    Reads the notebook globals grid, source, solution_whole, Nx and Ny.\n",
    "    \"\"\"\n",
    "    model_solution.eval()\n",
    "    model_source.eval()\n",
    "\n",
    "    # Plotting only needs values, so skip building an autograd graph.\n",
    "    with torch.no_grad():\n",
    "        source_output = model_source(grid)\n",
    "        solution_output = model_solution(grid)\n",
    "\n",
    "    panels = [\n",
    "        (source, 'source true'),\n",
    "        (source_output, 'source prediction'),\n",
    "        (solution_whole, 'solution true '),\n",
    "        (solution_output, 'solution prediction '),\n",
    "    ]\n",
    "    figure, axes = plt.subplots(1, 4, figsize=(15,4))\n",
    "    for ax, (values, title) in zip(axes, panels):\n",
    "        # Flat values follow the (i, j) = (x index, y index) order of `grid`; transpose\n",
    "        # so that x runs horizontally and y vertically, as the axis labels state.\n",
    "        field = values.detach().cpu().view(Nx+1, Ny+1).T\n",
    "        sc = ax.pcolor(field)\n",
    "        figure.colorbar(sc, ax=ax)\n",
    "        ax.set_title(title)\n",
    "        ax.set(xlabel='x index', ylabel='y index')\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01bd73c0",
   "metadata": {
    "ExecuteTime": {
     "start_time": "2024-11-13T07:37:24.606Z"
    }
   },
   "outputs": [],
   "source": [
    "# Seed so weight initialization (and thus the run) is reproducible.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "model_solution = net(hidden_dims=[2,64,64,64,1], act=nn.Tanh()).to(device)\n",
    "model_source = net(hidden_dims=[2,64,64,64,1], act=nn.Tanh()).to(device)\n",
    "\n",
    "EPOCH = 1000000    # number of optimizer steps (one full-grid step each); lower for quick runs\n",
    "LOG_EVERY = 100    # print losses every LOG_EVERY steps\n",
    "PLOT_EVERY = 1000  # draw the comparison figure every PLOT_EVERY steps\n",
    "\n",
    "# A single Adam optimizer over both networks, so u and f are fitted jointly.\n",
    "optimizer = torch.optim.Adam([{'params': model_solution.parameters()},\n",
    "                              {'params': model_source.parameters()}], lr=1e-3)\n",
    "loss_f = nn.MSELoss()   # created once instead of on every iteration\n",
    "\n",
    "for t in tqdm(range(EPOCH)):\n",
    "    loss_ge, loss_bc = train(model_solution, model_source, optimizer=optimizer, loss_f=loss_f)\n",
    "\n",
    "    if t % LOG_EVERY == 0:\n",
    "        # The source error is only reported here, so it is only evaluated here.\n",
    "        test_loss = test(model_source, loss_f=loss_f)\n",
    "        print(\"%s/%s | loss_ge: %04.6f | loss_bc: %04.6f |  test loss: %04.6f\"\n",
    "              % (t, EPOCH, loss_ge, loss_bc, test_loss))\n",
    "\n",
    "    if t % PLOT_EVERY == 0:\n",
    "        plot(model_solution, model_source)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82f6bb1b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b89fb276",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33bd818b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
