{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "33716b29",
   "metadata": {},
   "source": [
    "# $u(x,T) \\rightarrow f(x)$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2440e3d8",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-02T04:44:13.091328Z",
     "start_time": "2024-10-02T04:44:11.664780Z"
    }
   },
   "outputs": [],
   "source": [
    "# import modules\n",
    "from tqdm import tqdm_notebook\n",
    "import numpy as np\n",
    "\n",
    "# deep learning modules\n",
    "import torch\n",
    "from torch.autograd import Variable\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "# Plot modules\n",
    "import matplotlib.pyplot as plt\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ed7147d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-02T04:44:13.129210Z",
     "start_time": "2024-10-02T04:44:13.092890Z"
    }
   },
   "outputs": [],
   "source": [
    "\n",
    "# Use Gpu\n",
    "\n",
    "device = torch.device(\"cuda:3\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d88191e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-02T04:44:13.169790Z",
     "start_time": "2024-10-02T04:44:13.130805Z"
    }
   },
   "outputs": [],
   "source": [
    "Nt, Nx = 30,30\n",
    "ht, hx = 1/Nt, 1/Nx\n",
    "\n",
    "t, x = torch.linspace(0,1,Nt+1), torch.linspace(0,1,Nx+1)\n",
    "T, X = torch.meshgrid(t,x)\n",
    "\n",
    "grid = torch.stack([T,X], axis=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06025678",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-02T04:44:13.192205Z",
     "start_time": "2024-10-02T04:44:13.172966Z"
    }
   },
   "outputs": [],
   "source": [
    "data_solution, data_source = torch.load('RD_random_init_homogeneous_Dirichlet_data_31.pt')\n",
    "data_solution, data_sourec = data_solution[:2000], data_source[:2000]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ad8bdd3",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-02T04:44:13.200348Z",
     "start_time": "2024-10-02T04:44:13.194073Z"
    }
   },
   "outputs": [],
   "source": [
    "data_solution.size(), data_source.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "943b554b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-02T04:44:13.215030Z",
     "start_time": "2024-10-02T04:44:13.201743Z"
    }
   },
   "outputs": [],
   "source": [
    "mask = torch.ones(31,31)\n",
    "mask[1:-1, 1:-1] = 0\n",
    "mask = mask.bool()\n",
    "data_solution[:,mask].size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f117e3e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-02T04:44:13.362621Z",
     "start_time": "2024-10-02T04:44:13.217544Z"
    }
   },
   "outputs": [],
   "source": [
    "BATCH_SIZE=1000\n",
    "n_train = 50\n",
    "n_test = 1000\n",
    "\n",
    "train_source = data_source[:n_train]\n",
    "test_source = data_source[n_test:]\n",
    "train_solution = data_solution[:, [0,-1]].view(-1, 2*(Nt+1))[:n_train]\n",
    "test_solution = data_solution[:, [0,-1]].view(-1, 2*(Nt+1))[n_test:]\n",
    "\n",
    "train_label = data_solution[:,mask][:n_train]\n",
    "test_label = data_solution[:,mask][n_test:]\n",
    "\n",
    "train_tx = grid.view(-1,2).to(device).requires_grad_(True)\n",
    "train_x = x.view(-1,1).to(device)\n",
    "\n",
    "train_dataset = TensorDataset(train_solution, train_label, train_source)\n",
    "test_dataset = TensorDataset(test_solution, test_label)\n",
    "\n",
    "solution_train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
    "source_train_loader = DataLoader(train_source, batch_size=BATCH_SIZE, shuffle=True)\n",
    "\n",
    "solution_test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)\n",
    "source_test_loader = DataLoader(test_source, batch_size=BATCH_SIZE, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "413b8b61",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-02T04:44:13.368108Z",
     "start_time": "2024-10-02T04:44:13.364567Z"
    }
   },
   "outputs": [],
   "source": [
    "train_solution.size(), train_source.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3b231d6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-02T04:44:13.464978Z",
     "start_time": "2024-10-02T04:44:13.369199Z"
    }
   },
   "outputs": [],
   "source": [
    "plt.plot(train_solution[-1].cpu())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e56a3bb8",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-02T04:44:13.530576Z",
     "start_time": "2024-10-02T04:44:13.466507Z"
    }
   },
   "outputs": [],
   "source": [
    "plt.plot(train_label[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f50dc80d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-02T04:44:13.533982Z",
     "start_time": "2024-10-02T04:44:13.531517Z"
    }
   },
   "outputs": [],
   "source": [
    "m = train_solution.size()[1] # Sensors\n",
    "Q = train_source.size()[1] # Inverse discrete"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe92a213",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-02T04:44:13.549460Z",
     "start_time": "2024-10-02T04:44:13.536537Z"
    }
   },
   "outputs": [],
   "source": [
    "def calculate_derivative(y, x) :\n",
    "    return torch.autograd.grad(y, x, create_graph=True,\\\n",
    "                        grad_outputs=torch.ones(y.size()).to(device))[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "601eccfc",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-02T04:44:13.569932Z",
     "start_time": "2024-10-02T04:44:13.552070Z"
    }
   },
   "outputs": [],
   "source": [
    "class branch_net(nn.Module):\n",
    "    def __init__(self, hidden_dims) :                    # Hidden_dims : [h1, h2, h3, ..., hn]\n",
    "        super(branch_net, self).__init__()\n",
    "        \n",
    "        self.layers = []\n",
    "        for i in range(len(hidden_dims)-1) :\n",
    "            self.layers.append(nn.Linear(hidden_dims[i], hidden_dims[i+1])) # hidden layers\n",
    "        self.layers = nn.ModuleList(self.layers)\n",
    "        \n",
    "        for layer in self.layers :                       # Weight initialization\n",
    "            nn.init.xavier_uniform_(layer.weight)        # Also known as Glorot initialization\n",
    "            \n",
    "        self.act = nn.ReLU()                             # Nonlinear activation function\n",
    "    \n",
    "    def forward(self, x) :\n",
    "        x = self.act(self.layers[0](x))\n",
    "        for layer in self.layers[1:-1] :\n",
    "            x = self.act(layer(x)) \n",
    "        x = self.layers[-1](x)\n",
    "        return x\n",
    "\n",
    "class trunk_net(nn.Module):\n",
    "    def __init__(self, hidden_dims, act) :                    # Hidden_dims : [h1, h2, h3, ..., hn]\n",
    "        super(trunk_net, self).__init__()\n",
    "        \n",
    "        self.layers = []\n",
    "        for i in range(len(hidden_dims)-1) :\n",
    "            self.layers.append(nn.Linear(hidden_dims[i], hidden_dims[i+1])) # hidden layers\n",
    "        self.layers = nn.ModuleList(self.layers)\n",
    "        \n",
    "        for layer in self.layers :                       # Weight initialization\n",
    "            nn.init.xavier_uniform_(layer.weight)        # Also known as Glorot initialization\n",
    "            \n",
    "        self.act = act#nn.Tanh()   #nn.ReLU() #                          # Nonlinear activation function\n",
    "    \n",
    "    def forward(self, x) :\n",
    "        x = self.act(self.layers[0](x))\n",
    "        for layer in self.layers[1:-1] :\n",
    "            x = self.act(layer(x)) \n",
    "        x = self.layers[-1](x)\n",
    "        return x\n",
    "\n",
    "class deepOnet(nn.Module):\n",
    "    def __init__(self, branch, trunk):\n",
    "        super(deepOnet, self).__init__()\n",
    "        self.branch = branch\n",
    "        self.trunk = trunk\n",
    "        \n",
    "    def forward(self, u, y):\n",
    "        batch_size = len(u)\n",
    "        branch_out = self.branch(u)\n",
    "        trunk_out = self.trunk(y).view(batch_size, -1, 32)\n",
    "        out = torch.einsum('ab,acb->ac', branch_out, trunk_out)\n",
    "        return out.view(-1,1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "880fb678",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-02T04:44:13.588849Z",
     "start_time": "2024-10-02T04:44:13.571712Z"
    }
   },
   "outputs": [],
   "source": [
    "k, D = 0.01, 0.01\n",
    "\n",
    "def train(model_inv, model_for, optimizer, loss_f) :\n",
    "    model_inv.train()\n",
    "    model_for.train()\n",
    "    loss_list, loss_ge_list, loss_data_list = [], [], []\n",
    "    \n",
    "    for sol, label, source in solution_train_loader :\n",
    "        # s : batch_size, m\n",
    "        optimizer.zero_grad()\n",
    "        batch_size = len(sol)\n",
    "        sol = sol.to(device) # batch_size, m\n",
    "        source = source.to(device)\n",
    "        label = label.to(device)\n",
    "        y_res = train_x.repeat(batch_size,1) # batch_size*Q, 1\n",
    "        source_output = model_inv(sol, y_res) # batch_size*Q, 1\n",
    "\n",
    "        tx_res = train_tx.repeat(batch_size,1) #batch_size*Q^2, 2\n",
    "        sol_output = model_for(sol, tx_res) #batch_size*Q^2, 1\n",
    "\n",
    "        sol_output_grad = calculate_derivative(sol_output, tx_res) #batch_size*m, 1\n",
    "        sol_output_t = sol_output_grad[:,0].view(-1,1)\n",
    "        sol_output_xx = calculate_derivative(sol_output_grad[:,1], tx_res)[:,1].view(-1,1)\n",
    "        \n",
    "        ge = sol_output_t - D*sol_output_xx - k*sol_output.pow(2) -\\\n",
    "                source_output.view(batch_size, Q).repeat(1, Nt+1).view(-1,1)\n",
    "\n",
    "        loss_ge = loss_f(ge, torch.zeros_like(ge))\n",
    "        sol_pred = sol_output.reshape(batch_size, Nt+1, Nx+1)[:, mask]\n",
    "        \n",
    "        loss_data = 100*(loss_f(sol_pred, label)+loss_f(source_output, source.view(-1,1)))\n",
    "        loss = loss_ge+loss_data\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        loss_list.append(loss.item())\n",
    "        loss_ge_list.append(loss_ge.item())\n",
    "        loss_data_list.append(loss_data.item())\n",
    "        \n",
    "    return np.mean(loss_list), np.mean(loss_ge_list), np.mean(loss_data_list)\n",
    "\n",
    "\n",
    "def compute_test(model_inv, model_for) :\n",
    "    loss_f = nn.MSELoss()\n",
    "    model_inv.eval()\n",
    "    model_for.eval()\n",
    "    losses = []\n",
    "    for (sol, label), source in zip(solution_test_loader, source_test_loader) :\n",
    "        batch_size = len(sol)\n",
    "        sol=sol.to(device)\n",
    "        source=source.to(device)\n",
    "        \n",
    "        x_test = train_x.repeat(batch_size,1) \n",
    "        source_output = model_inv(sol, x_test) \n",
    "        losses.append(loss_f(source_output, source.view(-1,1)).item())\n",
    "        \n",
    "    return np.mean(losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fb3013a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-02T04:44:13.612396Z",
     "start_time": "2024-10-02T04:44:13.590216Z"
    }
   },
   "outputs": [],
   "source": [
    "branch_source_model = branch_net(hidden_dims=[m,32,32,32]).to(device)\n",
    "trunk_source_model = trunk_net(hidden_dims=[1,32,32,32], act=nn.ReLU()).to(device)\n",
    "\n",
    "branch_sol_model = branch_net(hidden_dims=[m,32,32,32]).to(device)\n",
    "trunk_sol_model = trunk_net(hidden_dims=[2,32,32,32], act=nn.Tanh()).to(device)\n",
    "\n",
    "model_inv = deepOnet(branch_source_model, trunk_source_model).to(device) \n",
    "model_for = deepOnet(branch_sol_model, trunk_sol_model).to(device) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8189561e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-02T04:44:15.744310Z",
     "start_time": "2024-10-02T04:44:13.614515Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    " EPOCH=1000000\n",
    "optimizer=torch.optim.Adam([{'params': model_inv.parameters()},\\\n",
    "                            {'params': model_for.parameters()}])#, lr=1e-5\n",
    "\n",
    "for t in tqdm_notebook(range(EPOCH)) :\n",
    "    \n",
    "    loss,loss_ge,loss_data = train(model_inv, model_for,\\\n",
    "                                   optimizer=optimizer, loss_f=nn.MSELoss())\n",
    "    test_loss = compute_test(model_inv, model_for)\n",
    "    # Print Log\n",
    "    if t%10 == 0 :\n",
    "        print(\"%s/%s | loss_ge: %04.6f | loss_data: %04.6f | test loss: %04.6f\" \\\n",
    "              % (t, EPOCH, loss_ge, loss_data, test_loss))\n",
    "        torch.save([model_inv, model_for], 'PI-DIONs_supervised_Ntrain={}'.format(n_train))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cb41ed3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
