{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2440e3d8",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-02T04:42:53.555307Z",
     "start_time": "2024-10-02T04:42:52.148777Z"
    }
   },
   "outputs": [],
   "source": [
    "# import modules\n",
    "from tqdm import tqdm_notebook\n",
    "import numpy as np\n",
    "\n",
    "# deep learning modules\n",
    "import torch\n",
    "\n",
    "from torch.autograd import Variable\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "# Plot modules\n",
    "import matplotlib.pyplot as plt\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ed7147d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-02T04:42:53.596658Z",
     "start_time": "2024-10-02T04:42:53.557405Z"
    }
   },
   "outputs": [],
   "source": [
    "\n",
    "# Prefer the first CUDA GPU when one is available; otherwise run on the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:0\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dba2beb6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-02T04:42:54.096736Z",
     "start_time": "2024-10-02T04:42:54.056820Z"
    }
   },
   "outputs": [],
   "source": [
    "# Load the precomputed dataset. Along dim 1, index 0 is used as the source term and\n",
    "# index 1 as the corresponding solution on the full grid.\n",
    "data = torch.load('Helmholtz_neumann_51.pt')\n",
    "source, solution_whole = data[:, 0], data[:, 1]\n",
    "\n",
    "source.shape, solution_whole.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d88191e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-02T04:42:54.925958Z",
     "start_time": "2024-10-02T04:42:54.900574Z"
    }
   },
   "outputs": [],
   "source": [
    "# Uniform grid on the unit square: Nt / Nx intervals, hence Nt+1 / Nx+1 nodes per axis.\n",
    "Nt, Nx = 50, 50\n",
    "ht, hx = 1/Nt, 1/Nx  # grid spacings\n",
    "t, x = torch.linspace(0, 1, Nt+1), torch.linspace(0, 1, Nx+1)\n",
    "\n",
    "# Pass indexing='ij' explicitly: it is the layout produced by the previous implicit default\n",
    "# (T varies along dim 0, X along dim 1), and omitting the argument raises a UserWarning in\n",
    "# recent PyTorch releases.\n",
    "T, X = torch.meshgrid(t, x, indexing='ij')\n",
    "\n",
    "# grid[i, j] = (t[i], x[j]); shape (Nt+1, Nx+1, 2)\n",
    "grid = torch.stack([T, X], dim=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdd3a0f2",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-02T04:42:55.410836Z",
     "start_time": "2024-10-02T04:42:55.398460Z"
    }
   },
   "outputs": [],
   "source": [
    "# Number of grid nodes removed from each side of every solution field (0 keeps the full field).\n",
    "Ncut = 5\n",
    "if Ncut != 0:\n",
    "    # Crop the border, then flatten each sample. flatten(start_dim=1) keeps the sample\n",
    "    # dimension as it is, so the dataset size is no longer hard-coded to 2000.\n",
    "    solution = solution_whole.reshape(-1, Nt+1, Nx+1)[:, Ncut:-Ncut, Ncut:-Ncut].flatten(start_dim=1)\n",
    "else:\n",
    "    solution = solution_whole.clone().squeeze(-1)\n",
    "solution.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f117e3e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-02T04:42:56.070519Z",
     "start_time": "2024-10-02T04:42:55.910454Z"
    }
   },
   "outputs": [],
   "source": [
    "BATCH_SIZE = 500\n",
    "\n",
    "n_train = 50     # the first n_train samples are used for training\n",
    "n_test = 1000    # start index of the test split: samples n_test.. are held out\n",
    "\n",
    "# Sources and (cropped) solutions are split with the same index ranges.\n",
    "train_source, test_source = source[:n_train], source[n_test:]\n",
    "train_solution, test_solution = solution[:n_train], solution[n_test:]\n",
    "\n",
    "# Flattened grid coordinates, one (t, x) pair per row. Gradients with respect to the\n",
    "# coordinates are taken during training, hence requires_grad_.\n",
    "train_xy = grid.view(-1, 2).to(device).requires_grad_(True)\n",
    "\n",
    "# Training batches pair each solution with its source and are reshuffled on every pass.\n",
    "train_dataset = TensorDataset(train_solution, train_source)\n",
    "train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)\n",
    "\n",
    "# Test solutions and sources live in two unshuffled loaders that are iterated in lockstep.\n",
    "solution_test_loader = DataLoader(test_solution, shuffle=False, batch_size=BATCH_SIZE)\n",
    "source_test_loader = DataLoader(test_source, shuffle=False, batch_size=BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f50dc80d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-02T04:42:57.411289Z",
     "start_time": "2024-10-02T04:42:57.406610Z"
    }
   },
   "outputs": [],
   "source": [
    "# m: number of sensor values, i.e. nodes of the cropped solution grid.\n",
    "# Q: number of nodes of the full grid (the points at which the networks are queried).\n",
    "m = (Nx + 1 - 2 * Ncut) ** 2\n",
    "Q = (Nx + 1) ** 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe92a213",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-02T04:42:57.809230Z",
     "start_time": "2024-10-02T04:42:57.804142Z"
    }
   },
   "outputs": [],
   "source": [
    "def calculate_derivative(y, x) :\n",
    "    \"\"\"Differentiate ``y`` with respect to ``x`` using autograd.\n",
    "\n",
    "    The graph is kept (``create_graph=True``) so the result can be differentiated again.\n",
    "\n",
    "    Args:\n",
    "        y: tensor computed from ``x``.\n",
    "        x: tensor with ``requires_grad=True`` that ``y`` depends on.\n",
    "\n",
    "    Returns:\n",
    "        Tensor shaped like ``x`` holding the gradient of ``y.sum()`` with respect to ``x``.\n",
    "    \"\"\"\n",
    "    # ones_like keeps the seed gradient on y's own device and dtype, so the helper no longer\n",
    "    # depends on the notebook-level `device` and also works for non-float32 inputs.\n",
    "    return torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y), create_graph=True)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "601eccfc",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-02T04:42:58.816091Z",
     "start_time": "2024-10-02T04:42:58.794935Z"
    }
   },
   "outputs": [],
   "source": [
    "hidden_dim = 128\n",
    "\n",
    "class branch_conv_net(nn.Module):\n",
    "    \"\"\"Convolutional branch network mapping a one-channel 2-D field to ``hidden_dim`` values.\n",
    "\n",
    "    Three Conv2d (kernel 3) + GELU + MaxPool2d(3, stride 2) stages are followed by two\n",
    "    fully connected layers.\n",
    "\n",
    "    Args:\n",
    "        hidden_dim: width of both fully connected layers and of the output; kept as ``self.h``.\n",
    "        input_size: side length of the square input field. The default of 41 yields a\n",
    "            flattened feature size of 128, the value that used to be hard-coded in ``fc1``.\n",
    "    \"\"\"\n",
    "    def __init__(self, hidden_dim, input_size=41) :\n",
    "        super(branch_conv_net, self).__init__()\n",
    "        self.h = hidden_dim\n",
    "        self.conv1 = nn.Conv2d(1, 6, 3)\n",
    "        self.pool = nn.MaxPool2d(3, 2)\n",
    "        self.conv2 = nn.Conv2d(6, 16, 3)\n",
    "        self.conv3 = nn.Conv2d(16, 32, 3)\n",
    "        # Probe the conv stack once with a dummy field so that fc1 always matches the\n",
    "        # flattened feature size, whatever the input resolution is.\n",
    "        with torch.no_grad() :\n",
    "            flat_dim = self._features(torch.zeros(1, 1, input_size, input_size)).shape[1]\n",
    "        self.fc1 = nn.Linear(flat_dim, self.h)\n",
    "        self.fc2 = nn.Linear(self.h, self.h)\n",
    "\n",
    "    def _features(self, x) :\n",
    "        \"\"\"Apply the three conv/pool stages and flatten everything except the batch dim.\"\"\"\n",
    "        x = self.pool(F.gelu(self.conv1(x)))\n",
    "        x = self.pool(F.gelu(self.conv2(x)))\n",
    "        x = self.pool(F.gelu(self.conv3(x)))\n",
    "        return torch.flatten(x, 1)\n",
    "\n",
    "    def forward(self, x) :\n",
    "        \"\"\"Map a (batch, 1, H, W) field to a (batch, hidden_dim) tensor.\"\"\"\n",
    "        x = F.gelu(self.fc1(self._features(x)))\n",
    "        return self.fc2(x)\n",
    "    \n",
    "class trunk_net(nn.Module):\n",
    "    \"\"\"Fully connected trunk network.\n",
    "\n",
    "    Args:\n",
    "        hidden_dims: layer sizes ``[in_dim, h1, ..., out_dim]``. One ``nn.Linear`` is built for\n",
    "            every consecutive pair, so at least two entries are required.\n",
    "        act: activation applied after every layer except the last one.\n",
    "    \"\"\"\n",
    "    def __init__(self, hidden_dims, act) :\n",
    "        super(trunk_net, self).__init__()\n",
    "        if len(hidden_dims) < 2 :\n",
    "            raise ValueError('hidden_dims must contain at least an input and an output size')\n",
    "\n",
    "        self.layers = nn.ModuleList(\n",
    "            [nn.Linear(n_in, n_out) for n_in, n_out in zip(hidden_dims[:-1], hidden_dims[1:])]\n",
    "        )\n",
    "        for layer in self.layers :\n",
    "            nn.init.xavier_uniform_(layer.weight)   # Glorot init; biases keep the default init\n",
    "\n",
    "        self.act = act\n",
    "\n",
    "    def forward(self, x) :\n",
    "        \"\"\"Apply all layers, with the activation after every layer but the last.\"\"\"\n",
    "        # Looping over layers[:-1] also covers the single-layer case, where the previous\n",
    "        # implementation applied layers[0] twice (once activated, once as the output layer).\n",
    "        for layer in self.layers[:-1] :\n",
    "            x = self.act(layer(x))\n",
    "        return self.layers[-1](x)\n",
    "\n",
    "class deepOnet(nn.Module):\n",
    "    \"\"\"DeepONet with two branch networks that share one trunk network.\n",
    "\n",
    "    Args:\n",
    "        branch_sol: branch net producing the coefficients of the solution output; it must\n",
    "            expose its output width as ``.h``.\n",
    "        branch_source: branch net producing the coefficients of the source output.\n",
    "        trunk: net mapping query coordinates to ``branch_sol.h`` basis values.\n",
    "    \"\"\"\n",
    "    def __init__(self, branch_sol, branch_source, trunk):\n",
    "        super(deepOnet, self).__init__()\n",
    "        self.branch_sol = branch_sol\n",
    "        self.branch_source = branch_source\n",
    "        self.trunk = trunk\n",
    "\n",
    "    def forward(self, u, y):\n",
    "        \"\"\"Evaluate both outputs at the query points.\n",
    "\n",
    "        Args:\n",
    "            u: batch of branch inputs; both branch nets receive the same tensor.\n",
    "            y: query coordinates stacked sample after sample, i.e. ``batch * n_points`` rows\n",
    "                with the same number of points for every sample.\n",
    "\n",
    "        Returns:\n",
    "            Tuple ``(sol_out, source_out)``, each of shape ``(batch * n_points, 1)``.\n",
    "        \"\"\"\n",
    "        branch_sol_out = self.branch_sol(u)          # (batch, h)\n",
    "        branch_source_out = self.branch_source(u)    # (batch, h)\n",
    "\n",
    "        # The number of query points per sample is inferred from the batch size instead of\n",
    "        # the notebook-level constant Q, so the model also works on other grids.\n",
    "        n_batch = branch_sol_out.shape[0]\n",
    "        trunk_out = self.trunk(y).view(n_batch, -1, self.branch_sol.h)   # (batch, n_points, h)\n",
    "\n",
    "        # Contract the coefficient axis: out[a, c] = sum_b branch[a, b] * trunk[a, c, b]\n",
    "        sol_out = torch.einsum('ab,acb->ac', branch_sol_out, trunk_out)\n",
    "        source_out = torch.einsum('ab,acb->ac', branch_source_out, trunk_out)\n",
    "        return sol_out.view(-1,1), source_out.view(-1,1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6f018e42",
   "metadata": {},
   "source": [
    "# s -> u -> s\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "880fb678",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-02T04:43:00.085497Z",
     "start_time": "2024-10-02T04:43:00.057360Z"
    }
   },
   "outputs": [],
   "source": [
    "def train(model, optimizer, loss_f) :\n",
    "    \"\"\"Run one optimisation pass over ``train_loader``.\n",
    "\n",
    "    The minimised loss is the sum of three terms: the residual ``ge``, the data misfit\n",
    "    (solution on the cropped grid plus source, weighted by 100) and the boundary term.\n",
    "\n",
    "    Args:\n",
    "        model: network called as ``model(solution_batch, coords)`` that returns the predicted\n",
    "            solution and source at ``coords``.\n",
    "        optimizer: optimizer stepping the model parameters.\n",
    "        loss_f: loss callable taking ``(prediction, target)``.\n",
    "\n",
    "    Returns:\n",
    "        Tuple with the batch-averaged total, residual, data and boundary losses.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    n_sensor = Nx + 1 - 2*Ncut      # side length of the cropped solution grid\n",
    "    history = {'total': [], 'ge': [], 'data': [], 'bc': []}\n",
    "\n",
    "    for sol, source in train_loader :\n",
    "        batch_size = len(sol)\n",
    "        optimizer.zero_grad()\n",
    "        sol, source = sol.to(device), source.to(device)\n",
    "\n",
    "        # One copy of the coordinate grid per sample in the batch.\n",
    "        inv_xy = train_xy.repeat(batch_size, 1)\n",
    "        sol_output, source_output = model(sol.view(-1, 1, n_sensor, n_sensor), inv_xy)\n",
    "\n",
    "        # First derivatives of the predicted solution along coordinate 0 and coordinate 1.\n",
    "        first = calculate_derivative(sol_output, inv_xy)\n",
    "        d0 = first[:, 0].view(-1, 1)\n",
    "        d1 = first[:, 1].view(-1, 1)\n",
    "\n",
    "        # Residual: both second derivatives + predicted solution - predicted source.\n",
    "        d00 = calculate_derivative(d0, inv_xy)[:, 0].view(-1, 1)\n",
    "        d11 = calculate_derivative(d1, inv_xy)[:, 1].view(-1, 1)\n",
    "        ge = d00 + d11 + sol_output - source_output\n",
    "        loss_ge = loss_f(ge, torch.zeros_like(ge))\n",
    "\n",
    "        # The solution data only covers the cropped grid, so crop the prediction the same way.\n",
    "        sol_pred = sol_output.reshape(batch_size, Nx+1, Nx+1)\n",
    "        if Ncut != 0 :\n",
    "            sol_pred = sol_pred[:, Ncut:-Ncut, Ncut:-Ncut]\n",
    "        sol_pred_interior = sol_pred.reshape(-1, 1)\n",
    "\n",
    "        # Boundary term: derivative along coordinate 0 where that coordinate is 0 or 1, and\n",
    "        # derivative along coordinate 1 where that coordinate is 1 or 0.\n",
    "        edge_grads = [\n",
    "            first[:, 0][inv_xy[:, 0] == 0],\n",
    "            first[:, 0][inv_xy[:, 0] == 1],\n",
    "            first[:, 1][inv_xy[:, 1] == 1],\n",
    "            first[:, 1][inv_xy[:, 1] == 0],\n",
    "        ]\n",
    "        neumann = torch.cat(edge_grads).view(-1, 1)\n",
    "\n",
    "        loss_data = 100*(loss_f(sol_pred_interior, sol.view(-1, 1)) + loss_f(source_output, source.view(-1, 1)))\n",
    "        loss_bc = loss_f(neumann, torch.zeros_like(neumann))\n",
    "        loss = loss_ge + loss_data + loss_bc\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        history['total'].append(loss.item())\n",
    "        history['ge'].append(loss_ge.item())\n",
    "        history['data'].append(loss_data.item())\n",
    "        history['bc'].append(loss_bc.item())\n",
    "\n",
    "    return tuple(np.mean(history[key]) for key in ('total', 'ge', 'data', 'bc'))\n",
    "\n",
    "\n",
    "def compute_test(model) :\n",
    "    \"\"\"Return the mean MSE of the predicted source over the test batches.\n",
    "\n",
    "    ``solution_test_loader`` and ``source_test_loader`` are iterated in lockstep and the\n",
    "    per-batch losses are averaged. Only the source prediction is scored.\n",
    "\n",
    "    Args:\n",
    "        model: network called as ``model(solution_batch, coords)``.\n",
    "\n",
    "    Returns:\n",
    "        Mean of the per-batch MSE values.\n",
    "    \"\"\"\n",
    "    loss_f = nn.MSELoss()\n",
    "    model.eval()\n",
    "    n_sensor = Nx + 1 - 2*Ncut      # side length of the cropped solution grid\n",
    "    losses = []\n",
    "    # Evaluation needs no gradients. Without no_grad() every batch would build an autograd\n",
    "    # graph through train_xy (which has requires_grad=True) and waste memory and time.\n",
    "    with torch.no_grad() :\n",
    "        for sol, source in zip(solution_test_loader, source_test_loader) :\n",
    "            batch_size = len(sol)\n",
    "            sol = sol.to(device)\n",
    "            source = source.to(device)\n",
    "\n",
    "            inv_xy = train_xy.repeat(batch_size, 1)   # one grid copy per test sample\n",
    "            sol_test = sol.view(-1, 1, n_sensor, n_sensor)\n",
    "            _, source_output = model(sol_test, inv_xy)\n",
    "            losses.append(loss_f(source_output, source.view(-1, 1)).item())\n",
    "\n",
    "    return np.mean(losses)\n",
    "\n",
    "def plot(model) :\n",
    "    \"\"\"Plot true and predicted source and solution for one randomly chosen test sample.\n",
    "\n",
    "    Args:\n",
    "        model: network called as ``model(solution_batch, coords)``.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    rand = np.random.randint(len(test_source))\n",
    "    source, solution = test_source[rand].unsqueeze(0), test_solution[rand].unsqueeze(0)\n",
    "\n",
    "    # The test tensors start at index n_test of the full dataset, so the matching uncropped\n",
    "    # solution sits at n_test + rand (n_train + rand would show a different sample).\n",
    "    solution_plot = solution_whole[n_test+rand].unsqueeze(0)\n",
    "    inv_xy = train_xy.view(-1, 2)\n",
    "    with torch.no_grad() :      # inference only, no autograd graph needed\n",
    "        solution_output, source_output = model(solution.view(-1,1, Nx+1-2*Ncut, Nx+1-2*Ncut).to(device), inv_xy)\n",
    "\n",
    "    panels = [\n",
    "        ('source true', source),\n",
    "        ('source prediction', source_output),\n",
    "        ('solution true ', solution_plot),\n",
    "        ('solution prediction ', solution_output),\n",
    "    ]\n",
    "    figure = plt.figure(figsize=(15,4))\n",
    "    for position, (title, field) in enumerate(panels, start=1) :\n",
    "        ax = figure.add_subplot(1, 4, position)\n",
    "        sc = ax.pcolor(field.cpu().detach().view(Nx+1, Nx+1), cmap='jet')\n",
    "        plt.colorbar(sc)\n",
    "        ax.set_title(title)\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fb3013a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-02T04:43:01.504783Z",
     "start_time": "2024-10-02T04:43:01.489739Z"
    }
   },
   "outputs": [],
   "source": [
    "# Two separate CNN branches (solution and source outputs) share one tanh trunk that takes\n",
    "# 2-D coordinates. The branches are built first, in this order, then the trunk.\n",
    "branch_sol_model, branch_source_model = (branch_conv_net(hidden_dim).to(device) for _ in range(2))\n",
    "trunk_model = trunk_net(hidden_dims=[2] + [hidden_dim]*3, act=nn.Tanh()).to(device)\n",
    "model = deepOnet(branch_sol_model, branch_source_model, trunk_model).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8189561e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-21T01:18:23.498495Z",
     "start_time": "2024-08-14T04:58:41.816382Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "EPOCH = 1000000\n",
    "LOG_EVERY = 100             # epochs between log lines\n",
    "CHECKPOINT_EVERY = 10000    # epochs between plots / checkpoints\n",
    "\n",
    "optimizer = torch.optim.Adam([{'params': model.parameters()}], lr=1e-3)\n",
    "mse_loss = nn.MSELoss()     # built once instead of once per epoch\n",
    "\n",
    "# The loop variable is called `epoch`, not `t`: `t` is the time grid defined earlier in the\n",
    "# notebook and would otherwise be overwritten by an integer.\n",
    "for epoch in tqdm_notebook(range(EPOCH)) :\n",
    "\n",
    "    loss, loss_ge, loss_data, loss_bc = train(model, optimizer=optimizer, loss_f=mse_loss)\n",
    "\n",
    "    # Print Log\n",
    "    if epoch % LOG_EVERY == 0 :\n",
    "        # The test loss is only reported here, so it is only evaluated here instead of\n",
    "        # after every single epoch.\n",
    "        test_loss = compute_test(model)\n",
    "        print(\"%s/%s | loss_ge: %04.6f | loss_data: %04.6f | loss_bc: %04.6f |  test loss: %04.6f\"\n",
    "              % (epoch, EPOCH, loss_ge, loss_data, loss_bc, test_loss))\n",
    "\n",
    "    if epoch % CHECKPOINT_EVERY == 0 :\n",
    "        plot(model)\n",
    "        # NOTE(review): this pickles the whole model object, so loading the file requires the\n",
    "        # class definitions from this notebook; saving model.state_dict() would be more portable.\n",
    "        torch.save([train_source, test_source, train_solution, test_solution, train_xy, model],\n",
    "                   'PI-DIONs_supervised_Ncut={}_Ntrain={}'.format(Ncut, n_train))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
