{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2440e3d8",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-24T05:23:44.819883Z",
     "start_time": "2024-09-24T05:23:43.173686Z"
    }
   },
   "outputs": [],
   "source": [
    "# import modules\n",
    "from tqdm import tqdm_notebook\n",
    "import numpy as np\n",
    "\n",
    "# deep learning modules\n",
    "import torch\n",
    "\n",
    "from torch.autograd import Variable\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "# Plot modules\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ed7147d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-24T05:23:50.191296Z",
     "start_time": "2024-09-24T05:23:50.078573Z"
    }
   },
   "outputs": [],
   "source": [
    "\n",
    "# Use Gpu\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dba2beb6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-24T05:23:51.030522Z",
     "start_time": "2024-09-24T05:23:50.970497Z"
    }
   },
   "outputs": [],
   "source": [
    "data = torch.load('Darcy_consistent_f_{}.pt'.format(50))\n",
    "\n",
    "source = F.interpolate(data[:2000,0,:,].view(-1,1,50,50), scale_factor=[0.6,0.6]).view(-1,900)\n",
    "solution = F.interpolate(data[:2000,1,:,].view(-1,1,50,50), scale_factor=[0.6,0.6]).view(-1,900)\n",
    "\n",
    "source.shape, solution.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d88191e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-24T05:23:52.314561Z",
     "start_time": "2024-09-24T05:23:52.306722Z"
    }
   },
   "outputs": [],
   "source": [
    "Nt, Nx = 29, 29\n",
    "ht, hx = 1/Nt, 1/Nx\n",
    "t, x = torch.linspace(0,1,Nt+1), torch.linspace(0,1,Nx+1)\n",
    "T, X = torch.meshgrid(t,x)\n",
    "\n",
    "grid = torch.stack([T,X], axis=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5704fd9f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-24T05:23:52.466048Z",
     "start_time": "2024-09-24T05:23:52.459721Z"
    }
   },
   "outputs": [],
   "source": [
    "grid.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f117e3e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-24T05:24:31.092476Z",
     "start_time": "2024-09-24T05:24:30.942728Z"
    }
   },
   "outputs": [],
   "source": [
    "BATCH_SIZE=1000\n",
    "n_train = 50\n",
    "n_test = 1000\n",
    "\n",
    "train_source = source[:n_train]\n",
    "test_source = source[n_test:]\n",
    "train_solution = solution[:n_train]\n",
    "test_solution = solution[n_test:]\n",
    "train_xy = grid.view(-1,2).to(device).requires_grad_(True)\n",
    "\n",
    "train_dataset = TensorDataset(train_solution, train_source)\n",
    "train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
    "\n",
    "solution_test_loader = DataLoader(test_solution, batch_size=BATCH_SIZE, shuffle=False)\n",
    "source_test_loader = DataLoader(test_source, batch_size=BATCH_SIZE, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f50dc80d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-24T05:24:31.096780Z",
     "start_time": "2024-09-24T05:24:31.094253Z"
    }
   },
   "outputs": [],
   "source": [
    "m =(Nx+1)*(Nx+1)#(Nx+1-2*Ncut)**2 # Sensors\n",
    "Q =(Nx+1)*(Nx+1) # Inverse discrete"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe92a213",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-24T05:24:31.455393Z",
     "start_time": "2024-09-24T05:24:31.450214Z"
    }
   },
   "outputs": [],
   "source": [
    "def calculate_derivative(y, x) :\n",
    "    return torch.autograd.grad(y, x, create_graph=True,\\\n",
    "                        grad_outputs=torch.ones(y.size()).to(device))[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "601eccfc",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-24T05:24:31.751923Z",
     "start_time": "2024-09-24T05:24:31.731360Z"
    }
   },
   "outputs": [],
   "source": [
    "hidden_dim =128\n",
    "class branch_conv_net(nn.Module):\n",
    "    def __init__(self) :                    # Hidden_dims : [h1, h2, h3, ..., hn]\n",
    "        super(branch_conv_net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 6, 2)\n",
    "        self.pool = nn.MaxPool2d(2, 2)\n",
    "        self.conv2 = nn.Conv2d(6, 16, 2)\n",
    "        self.conv3 = nn.Conv2d(16, 32, 2)\n",
    "        self.fc1 = nn.Linear(128, 128)\n",
    "        self.fc2 = nn.Linear(128, hidden_dim)\n",
    "\n",
    "    def forward(self, x) :\n",
    "        x = self.pool(F.gelu(self.conv1(x)))\n",
    "        x = self.pool(F.gelu(self.conv2(x)))\n",
    "        x = self.pool(F.gelu(self.conv3(x)))\n",
    "        x = torch.flatten(x, 1) # flatten all dimensions except batch\n",
    "        x = F.gelu(self.fc1(x))\n",
    "        x = self.fc2(x)\n",
    "        return x\n",
    "    \n",
    "class trunk_net(nn.Module):\n",
    "    def __init__(self, hidden_dims, act) :                    # Hidden_dims : [h1, h2, h3, ..., hn]\n",
    "        super(trunk_net, self).__init__()\n",
    "        \n",
    "        self.layers = []\n",
    "        for i in range(len(hidden_dims)-1) :\n",
    "            self.layers.append(nn.Linear(hidden_dims[i], hidden_dims[i+1])) # hidden layers\n",
    "        self.layers = nn.ModuleList(self.layers)\n",
    "        \n",
    "        for layer in self.layers :                       # Weight initialization\n",
    "            nn.init.xavier_uniform_(layer.weight)        # Also known as Glorot initialization\n",
    "            \n",
    "        self.act = act#nn.Tanh()   #nn.ReLU() #                          # Nonlinear activation function\n",
    "        \n",
    "        \n",
    "    def forward(self, x) :\n",
    "        x = self.act(self.layers[0](x))\n",
    "        for layer in self.layers[1:-1] :\n",
    "            x = self.act(layer(x)) \n",
    "        x = self.layers[-1](x)\n",
    "        return x\n",
    "\n",
    "\n",
    "\n",
    "class deepOnet(nn.Module):\n",
    "    def __init__(self, branch_sol, branch_source, trunk, positive=False):\n",
    "        super(deepOnet, self).__init__()\n",
    "        self.branch_sol = branch_sol\n",
    "        self.branch_source = branch_source\n",
    "        self.trunk = trunk\n",
    "        \n",
    "        self.positive=positive\n",
    "            \n",
    "    def forward(self, u, y):\n",
    "        branch_sol_out = self.branch_sol(u)\n",
    "        branch_source_out = self.branch_source(u)\n",
    "        trunk_out = self.trunk(y).view(-1, m, hidden_dim)\n",
    "\n",
    "        sol_out= torch.einsum('ab,acb->ac', branch_sol_out, trunk_out)#(branch_sol_out * trunk_out).sum(dim=1).unsqueeze(1)\n",
    "        source_out= torch.einsum('ab,acb->ac',branch_source_out, trunk_out)#(branch_source_out * trunk_out).sum(dim=1).unsqueeze(1)\n",
    "        if self.positive:\n",
    "            source_out = F.sigmoid(source_out)\n",
    "        return sol_out.view(-1,1), source_out.view(-1,1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6f018e42",
   "metadata": {},
   "source": [
    "# s -> u -> s\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "880fb678",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-24T05:25:20.432328Z",
     "start_time": "2024-09-24T05:25:20.406339Z"
    }
   },
   "outputs": [],
   "source": [
    "def train(model, optimizer, loss_f) :\n",
    "    model.train()\n",
    "    loss_list, loss_ge_list, loss_data_list, loss_bc_list = [], [], [], []\n",
    "    \n",
    "    for sol, source in train_loader :\n",
    "        batch_size=len(sol)\n",
    "        optimizer.zero_grad()\n",
    "        sol = sol.to(device)  #batch_size, m\n",
    "        source = source.to(device)\n",
    "        \n",
    "        inv_xy = train_xy.unsqueeze(0).repeat(batch_size, 1, 1).view(-1,2) # m,2 \n",
    "        sol_batch = sol.view(-1,1, Nx+1, Nx+1) # batch_size, 1, Nx, Nx\n",
    "        sol_output, source_output = model(sol_batch, inv_xy) # (batch_size, m), (batch_size, m)\n",
    "        \n",
    "        sol_output_grad = calculate_derivative(sol_output, inv_xy)\n",
    "        sol_x, sol_y = sol_output_grad[:,0].view(-1,1), sol_output_grad[:,1].view(-1,1)\n",
    "        \n",
    "        x, y = inv_xy[:,0].view(-1,1), inv_xy[:,1].view(-1,1)\n",
    "        ge = calculate_derivative(source_output*sol_x, inv_xy)[:,0].view(-1,1) +\\\n",
    "             calculate_derivative(source_output*sol_y, inv_xy)[:,1].view(-1,1) +\\\n",
    "             +100*x*(1-x)*y*(1-y)#+ torch.ones_like(sol_output)\n",
    "        loss_ge = loss_f(ge, torch.zeros_like(ge))\n",
    "        \n",
    "        loss_data = 1000*(loss_f(sol_output, sol.view(-1,1)) + loss_f(source_output, source.view(-1,1)))\n",
    "        loss_bc = torch.zeros_like(loss_data)#loss_f(sol_pred_bc, torch.zeros_like(sol_pred_bc))#torch.zeros_like(loss_data)#\n",
    "        loss = loss_ge+loss_data+loss_bc\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        loss_list.append(loss.item())\n",
    "        loss_ge_list.append(loss_ge.item())\n",
    "        loss_data_list.append(loss_data.item())\n",
    "        loss_bc_list.append(loss_bc.item())\n",
    "        \n",
    "    return np.mean(loss_list), np.mean(loss_ge_list), np.mean(loss_data_list), np.mean(loss_bc_list)\n",
    "\n",
    "\n",
    "def compute_test(model) :\n",
    "    loss_f = nn.MSELoss()\n",
    "    model.eval()\n",
    "    losses = []\n",
    "    for sol, source in zip(solution_test_loader, source_test_loader) :\n",
    "        batch_size = len(sol)\n",
    "        sol=sol.to(device)\n",
    "        source=source.to(device)\n",
    "        \n",
    "        inv_xy = train_xy.unsqueeze(0).repeat(batch_size, 1, 1).view(-1,2) #Q*n_test x 1\n",
    "        sol_test = sol.view(-1,1, Nx+1, Nx+1) # n_test*Q x m\n",
    "        sol_output, source_output = model(sol_test, inv_xy) # n_test x Q\n",
    "        losses.append(loss_f(source_output, source.view(-1,1)).item())\n",
    "        \n",
    "    return np.mean(losses)\n",
    "                   \n",
    "def plot(model) :\n",
    "    model.eval()\n",
    "    rand = np.random.randint(len(test_source))\n",
    "    source, solution = test_source[rand].unsqueeze(0), test_solution[rand].unsqueeze(0)\n",
    "    inv_xy = train_xy.unsqueeze(0).view(-1,2) \n",
    "    solution_output, source_output = model(solution.view(-1,1, Nx+1, Nx+1).to(device) , inv_xy)\n",
    "    \n",
    "    figure = plt.figure(figsize=(15,4)) \n",
    "    ax = figure.add_subplot(1,4,1)\n",
    "    sc=ax.pcolor(source.cpu().detach().view(Nx+1, Nx+1))\n",
    "    plt.colorbar(sc)\n",
    "    ax.set_title('source true')\n",
    "    ax = figure.add_subplot(1,4,2)\n",
    "    sc=ax.pcolor(source_output.cpu().detach().view(Nx+1, Nx+1))\n",
    "    plt.colorbar(sc)\n",
    "    ax.set_title('source prediction')\n",
    "    ax = figure.add_subplot(1,4,3)\n",
    "    sc=ax.pcolor(solution.cpu().detach().view(Nx+1, Nx+1))\n",
    "    plt.colorbar(sc)\n",
    "    ax.set_title('solution true ')\n",
    "    ax = figure.add_subplot(1,4,4)\n",
    "    sc=ax.pcolor(solution_output.cpu().detach().view(Nx+1, Nx+1))\n",
    "    plt.colorbar(sc)\n",
    "    ax.set_title('solution prediction ')\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fb3013a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-24T05:31:29.295328Z",
     "start_time": "2024-09-24T05:31:29.278491Z"
    }
   },
   "outputs": [],
   "source": [
    "branch_sol_model = branch_conv_net().to(device)\n",
    "branch_source_model = branch_conv_net().to(device)\n",
    "trunk_model = trunk_net(hidden_dims=[2,128,128,128], act=nn.Tanh()).to(device)\n",
    "model = deepOnet(branch_sol_model, branch_source_model, trunk_model, positive=True).to(device) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8189561e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-24T20:33:49.518608Z",
     "start_time": "2024-09-24T05:31:30.101511Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "EPOCH=1000000\n",
    "\n",
    "optimizer=torch.optim.Adam([{'params': model.parameters()}], lr=1e-3)#, lr=1e-5\n",
    "\n",
    "for t in tqdm_notebook(range(EPOCH)) :\n",
    "    \n",
    "    loss,loss_ge,loss_data,loss_bc = train(model,\\\n",
    "                                   optimizer=optimizer, loss_f=nn.MSELoss())\n",
    "    test_loss = compute_test(model)\n",
    "    # Print Log\n",
    "    if t%100 == 0 :\n",
    "        print(\"%s/%s | loss_ge: %04.6f | loss_data: %04.6f | loss_bc: %04.6f |  test loss: %04.6f\" \\\n",
    "              % (t, EPOCH, loss_ge, loss_data, loss_bc, test_loss))\n",
    "        \n",
    "    if t%10000 == 0:\n",
    "        torch.save([train_source, test_source, train_solution, test_solution, train_xy, \\\n",
    "            model], 'PI-DIONs_supervised_Ntrain={}.pt'.format(n_train))        \n",
    "        plot(model)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
