{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-01-31T08:22:28.608474Z",
     "start_time": "2020-01-31T08:22:27.561591Z"
    }
   },
   "outputs": [],
   "source": [
    "# import modules\n",
    "from tqdm import tqdm_notebook\n",
    "import pickle as pkl\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "# deep learning modules\n",
    "import torch\n",
    "from torch.autograd import Variable\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "# Plot modules\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-01-31T08:22:28.628241Z",
     "start_time": "2020-01-31T08:22:28.610444Z"
    }
   },
   "outputs": [],
   "source": [
    "# Select the compute device: the first CUDA GPU when one is available, otherwise the CPU.\n",
    "use_cuda = torch.cuda.is_available()\n",
    "device = torch.device('cuda:0' if use_cuda else 'cpu')\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Coefficients of the Fokker-Planck residual used in train():\n",
    "#   f_t + v*(f_x - beta*f_v) - beta*f - q*f_vv = 0\n",
    "beta = 1  # multiplies the velocity-drag terms (beta*v*f_v and beta*f)\n",
    "q = 1     # multiplies the velocity-diffusion term f_vv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-01-31T08:22:28.758289Z",
     "start_time": "2020-01-31T08:22:28.751782Z"
    }
   },
   "outputs": [],
   "source": [
    "# Make grid points on [tmin, tmax] x [xmin, xmax] x [vmin, vmax].\n",
    "# An imaginary step N*1j makes np.mgrid return N points including both end points.\n",
    "# Rows are ordered with t varying slowest and v fastest; later cells rely on this\n",
    "# (reshape(Nx, Nv) of a fixed-t slice, Nx*Nv-sized time slices in mass_conservation).\n",
    "\n",
    "tmin, tmax = 0, 1\n",
    "xmin, xmax = 0, 1\n",
    "vmin, vmax = -5, 5\n",
    "\n",
    "Nt, Nx, Nv = 31, 31, 31\n",
    "sxy = np.mgrid[tmin:tmax:Nt*1j, xmin:xmax:Nx*1j, vmin:vmax:Nv*1j].reshape(3, -1).T\n",
    "df_real = pd.DataFrame(sxy, columns=['t', 'x', 'v'])\n",
    "\n",
    "X_train = df_real.values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-01-31T08:22:30.844840Z",
     "start_time": "2020-01-31T08:22:30.807697Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# generate initial data on the t == tmin slice\n",
    "\n",
    "X_ini = df_real.loc[df_real['t'] == tmin].copy()\n",
    "\n",
    "def initial_generate(row):\n",
    "    \"\"\"Initial condition f(0, x, v) = exp(-v**2)/sqrt(pi) (unit integral over v in R).\n",
    "\n",
    "    `row` is a (t, x, v) triple: a DataFrame row or a tuple of column Series.\n",
    "    \"\"\"\n",
    "    t, x, v = row\n",
    "    return np.exp(-v**2)/np.sqrt(np.pi)\n",
    "\n",
    "def initial_der_generate(row):\n",
    "    \"\"\"v-derivative of initial_generate: -2*v*exp(-v**2)/sqrt(pi).\"\"\"\n",
    "    t, x, v = row\n",
    "    return -2*v*np.exp(-v**2)/np.sqrt(np.pi)\n",
    "\n",
    "# Vectorised evaluation on whole columns (same values and index as a row-wise apply).\n",
    "ini_columns = (X_ini['t'], X_ini['x'], X_ini['v'])\n",
    "f_ini = initial_generate(ini_columns)\n",
    "f_der_ini = initial_der_generate(ini_columns)\n",
    "\n",
    "X_ini = X_ini.values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-01-31T08:22:31.373430Z",
     "start_time": "2020-01-31T08:22:30.993080Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Surface plot of the initial condition f(0, x, v).\n",
    "fig = plt.figure(figsize=(15,5))\n",
    "x = np.linspace(xmin, xmax, Nx) # discretization of space\n",
    "v = np.linspace(vmin, vmax, Nv) # discretization of velocity\n",
    "X, V = np.meshgrid(x, v)\n",
    "\n",
    "ax = fig.add_subplot(1,2,1, projection='3d')\n",
    "# f_ini is ordered x-major / v-minor, so reshape(Nx, Nv).T matches the (Nv, Nx) meshgrid.\n",
    "ax.plot_surface(X, V, f_ini.values.reshape(Nx, Nv).T, alpha=0.7)\n",
    "\n",
    "ax.zaxis._axinfo['juggled'] = (1,2,2)\n",
    "ax.set_title('Initial Condition f_0')\n",
    "ax.set_xlabel('x')\n",
    "ax.set_ylabel('v')\n",
    "ax.set_zlabel('initial')\n",
    "# 3D axes are named x/y/z regardless of the plotted variables.\n",
    "ax.locator_params(axis='z', nbins=6)\n",
    "ax.locator_params(axis='x', nbins=5)\n",
    "ax.locator_params(axis='y', nbins=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total initial mass: grid mean of f_0 times the area of the (x, v) domain\n",
    "# (a Riemann-sum approximation of the integral of f_0 over x and v).\n",
    "initial_mass = f_ini.mean()*(xmax-xmin)*(vmax-vmin)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Equilibrium at t = tmax: the normalised Maxwellian M(v) = exp(-beta*v^2/(2q)) / sqrt(2*pi*q/beta),\n",
    "# uniform in x and scaled so that its total mass equals the initial mass.\n",
    "X_ter = df_real.loc[df_real['t'] == tmax].copy()\n",
    "\n",
    "def equilibrium(row):\n",
    "    \"\"\"Normalised Maxwellian M(v); `row` is a (t, x, v) triple (DataFrame row or tuple of columns).\"\"\"\n",
    "    t, x, v = row\n",
    "    return (1/((2*np.pi*(q/beta))**(0.5))) * np.exp(-0.5*(beta/q)*v**2)\n",
    "\n",
    "# mass = mean(f_0) * |x-domain| * |v-domain|; the spatial density is mass / |x-domain|,\n",
    "# so integrating rho*M(v) over x and v recovers the initial mass for any domain length.\n",
    "mass_ini = f_ini.mean()*(xmax-xmin)*(vmax-vmin)\n",
    "rho = mass_ini/(xmax-xmin)\n",
    "maxwellian = equilibrium((X_ter['t'], X_ter['x'], X_ter['v']))*rho"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Surface plot of the mass-matched Maxwellian equilibrium.\n",
    "fig = plt.figure(figsize=(15,5))\n",
    "x = np.linspace(xmin, xmax, Nx) # discretization of space\n",
    "v = np.linspace(vmin, vmax, Nv) # discretization of velocity\n",
    "X, V = np.meshgrid(x, v)\n",
    "\n",
    "ax = fig.add_subplot(1,2,1, projection='3d')\n",
    "# Values are ordered x-major / v-minor, so reshape(Nx, Nv).T matches the (Nv, Nx) meshgrid.\n",
    "ax.plot_surface(X, V, maxwellian.values.reshape(Nx, Nv).T, alpha=0.7)\n",
    "\n",
    "ax.zaxis._axinfo['juggled'] = (1,2,2)\n",
    "ax.set_title('Maxwellian equilibrium')\n",
    "ax.set_xlabel('x')\n",
    "ax.set_ylabel('v')\n",
    "ax.set_zlabel('f')\n",
    "# 3D axes are named x/y/z regardless of the plotted variables.\n",
    "ax.locator_params(axis='z', nbins=6)\n",
    "ax.locator_params(axis='x', nbins=5)\n",
    "ax.locator_params(axis='y', nbins=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initial condition vs. Maxwellian along v on the last spatial slice (x = xmax).\n",
    "v_axis = np.linspace(vmin, vmax, Nv)\n",
    "plt.plot(v_axis, f_ini.values[-Nv:], label='initial')\n",
    "plt.plot(v_axis, maxwellian.values[-Nv:], label='maxwellian')\n",
    "plt.xlabel('v')\n",
    "plt.ylabel('f')\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid points on the spatial boundaries x = xmin and x = xmax (periodicity is enforced in train()).\n",
    "on_left = df_real['x'] == xmin\n",
    "on_right = df_real['x'] == xmax\n",
    "X_bdry_left = df_real.loc[on_left]\n",
    "X_bdry_right = df_real.loc[on_right]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert initial and boundary data into float32 tensors on `device`.\n",
    "# The helper also accepts tensors, so re-running this cell does not fail.\n",
    "\n",
    "def to_device_tensor(values, requires_grad=False):\n",
    "    \"\"\"Return `values` (ndarray, pandas object or tensor) as a float32 leaf tensor on `device`.\"\"\"\n",
    "    if isinstance(values, (pd.Series, pd.DataFrame)):\n",
    "        values = values.values\n",
    "    tensor = torch.as_tensor(values, dtype=torch.float32).to(device).detach()\n",
    "    return tensor.requires_grad_(requires_grad)\n",
    "\n",
    "# Inputs that the losses differentiate through need requires_grad=True.\n",
    "X_ini = to_device_tensor(X_ini, requires_grad=True)\n",
    "f_ini = to_device_tensor(f_ini).view(-1,1)\n",
    "f_der_ini = to_device_tensor(f_der_ini).view(-1,1)\n",
    "X_bdry_left = to_device_tensor(X_bdry_left, requires_grad=True)\n",
    "X_bdry_right = to_device_tensor(X_bdry_right, requires_grad=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy import interpolate\n",
    "\n",
    "# Reference solution: a numerical solution stored on an (x, u) grid, u = v0/sqrt(c^2 + v0^2),\n",
    "# mapped back onto the (t, x, v) training grid by the change of variables below.\n",
    "# NOTE: pd.read_pickle can execute arbitrary code -- only load trusted files.\n",
    "solution = pd.read_pickle('numerical_solution_c=100_exp_v2.pkl')[1:]\n",
    "\n",
    "xmin, xmax = 0, 1\n",
    "umin, umax = -0.995, 0.995\n",
    "c=100\n",
    "\n",
    "# Each solution[k] has shape (len(u), len(x)).\n",
    "x = np.linspace(xmin, xmax, solution[0].shape[1])\n",
    "u = np.linspace(umin, umax, solution[0].shape[0])\n",
    "x_mesh, u_mesh = np.meshgrid(x, u)\n",
    "\n",
    "smin, smax = 0, tmax\n",
    "ymin, ymax = 0, 1\n",
    "vmin, vmax = -5, 5\n",
    "\n",
    "Ns, Ny, Nv = 31, 31, 31\n",
    "syv = np.mgrid[smin:smax:Ns*1j, ymin:ymax:Ny*1j, vmin:vmax:Nv*1j].reshape(3, -1).T\n",
    "df = pd.DataFrame(syv, columns=['s', 'y', 'v'])\n",
    "\n",
    "# Vectorised change of variables (s, y, v) -> (s, x, u); x is wrapped into [0, 1).\n",
    "s_col, y_col, v_col = df['s'].values, df['y'].values, df['v'].values\n",
    "v0 = v_col*np.exp(beta*s_col)\n",
    "df_transformed = pd.DataFrame({'s': s_col,\n",
    "                               'x': (y_col - v_col*(np.exp(beta*s_col)-1)/beta) % 1,\n",
    "                               'u': v0/np.sqrt(c**2 + v0**2)})\n",
    "\n",
    "def reference_at(frame, s):\n",
    "    \"\"\"Bicubic interpolation of solution[frame] at the transformed points of time s, times exp(beta*s).\n",
    "\n",
    "    RectBivariateSpline(x, u, z.T, kx=3, ky=3, s=0) is the same interpolating spline that\n",
    "    interp2d(x, u, z, kind='cubic') builds on a rectangular grid; interp2d was removed in SciPy 1.14.\n",
    "    \"\"\"\n",
    "    spline = interpolate.RectBivariateSpline(x, u, solution[frame].T, kx=3, ky=3, s=0)\n",
    "    points = df_transformed.loc[np.isclose(df_transformed['s'], s)]\n",
    "    return spline.ev(points['x'].values, points['u'].values)*np.exp(beta*s)\n",
    "\n",
    "out_init = reference_at(0, smin)\n",
    "out_mid = reference_at(int(len(solution)/2), smax/2)\n",
    "out_term = reference_at(-1, smax)\n",
    "\n",
    "# Matching (t, x, v) inputs; rows line up with out_* because both grids share bounds, sizes and ordering.\n",
    "x_init = df_real[np.isclose(df_real['t'], smin)].values\n",
    "x_mid = df_real[np.isclose(df_real['t'], smax/2)].values\n",
    "x_term = df_real[np.isclose(df_real['t'], smax)].values\n",
    "\n",
    "compare_x = torch.FloatTensor(np.concatenate([x_init, x_mid, x_term])).to(device)\n",
    "compare_y = np.concatenate([out_init, out_mid, out_term]).reshape(-1,1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy import interpolate\n",
    "\n",
    "# Reference solution at the final time. The transformed x-values of the (y, v) grid do not\n",
    "# form a tensor-product grid, so the spline is evaluated point-wise (not on unique() values).\n",
    "spline_term = interpolate.RectBivariateSpline(x, u, solution[-1].T, kx=3, ky=3, s=0)\n",
    "df_plot = df_transformed.loc[np.isclose(df_transformed['s'], smax)]\n",
    "out_term = spline_term.ev(df_plot['x'].values, df_plot['u'].values)*np.exp(beta*smax)\n",
    "\n",
    "fig = plt.figure(figsize=(15,5))\n",
    "k = np.linspace(ymin, ymax, Ny) # discretization of space\n",
    "l = np.linspace(vmin, vmax, Nv) # discretization of velocity\n",
    "K, L = np.meshgrid(k, l)\n",
    "\n",
    "ax = fig.add_subplot(1,2,1, projection='3d')\n",
    "# Points are ordered y-major / v-minor, so reshape(Ny, Nv).T matches the (Nv, Ny) meshgrid.\n",
    "ax.plot_surface(K, L, out_term.reshape(Ny, Nv).T, alpha=0.7)\n",
    "\n",
    "ax.zaxis._axinfo['juggled'] = (1,2,2)\n",
    "ax.set_title('Reference solution at t={}'.format(smax))\n",
    "ax.set_xlabel('y')\n",
    "ax.set_ylabel('v')\n",
    "ax.set_zlabel('f')\n",
    "# 3D axes are named x/y/z regardless of the plotted variables.\n",
    "ax.locator_params(axis='z', nbins=6)\n",
    "ax.locator_params(axis='x', nbins=5)\n",
    "ax.locator_params(axis='y', nbins=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Neural Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "EPOCH = 1000000\n",
    "# Full-batch training: mass_conservation splits a batch into whole time slices of Nx*Nv\n",
    "# points in grid order, which is only valid when one batch holds the entire grid unshuffled.\n",
    "BATCH_SIZE = len(X_train)\n",
    "# Make dataloader\n",
    "data_train = TensorDataset(torch.FloatTensor(X_train))\n",
    "train_loader = DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-01-31T08:22:33.806480Z",
     "start_time": "2020-01-31T08:22:33.800574Z"
    }
   },
   "outputs": [],
   "source": [
    "# Fully connected network (t, x, v) -> f: 3 -> 256 -> 256 -> 256 -> 256 -> 1, tanh hidden activations.\n",
    "\n",
    "class f_Net(nn.Module):\n",
    "    \"\"\"MLP approximating the distribution function f(t, x, v).\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Creation order determines the random initialisation for a given seed.\n",
    "        self.fc1 = nn.Linear(3, 256)\n",
    "        self.fc2 = nn.Linear(256, 256)\n",
    "        self.fc3 = nn.Linear(256, 256)\n",
    "        self.fc4 = nn.Linear(256, 256)\n",
    "        self.fc5 = nn.Linear(256, 1)\n",
    "        self.act1 = nn.Tanh()\n",
    "        # Registered but not used in forward().\n",
    "        self.act2 = nn.Softplus()\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map inputs of shape (..., 3) to outputs of shape (..., 1).\"\"\"\n",
    "        hidden = x\n",
    "        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):\n",
    "            hidden = self.act1(layer(hidden))\n",
    "        return self.fc5(hidden)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-01-31T08:22:38.737662Z",
     "start_time": "2020-01-31T08:22:38.729613Z"
    }
   },
   "outputs": [],
   "source": [
    "def calculate_derivative(y, x):\n",
    "    \"\"\"Gradient of y w.r.t. x (vector-Jacobian product with ones), keeping the graph.\n",
    "\n",
    "    When each row of y depends only on the same row of x (as for an MLP applied row-wise),\n",
    "    this is the per-row gradient dy/dx. create_graph=True allows higher derivatives.\n",
    "    \"\"\"\n",
    "    # ones_like matches y's shape, dtype and device (no dependence on the global `device`).\n",
    "    return torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y), create_graph=True)[0]\n",
    "\n",
    "def calculate_all_partial(y, x):\n",
    "    \"\"\"Return (f_t, f_x, f_v, f_vv) for outputs y at inputs x with columns (t, x, v).\"\"\"\n",
    "    del_f = calculate_derivative(y, x)\n",
    "\n",
    "    f_t, f_x, f_v = del_f[:,0], del_f[:,1], del_f[:,2]\n",
    "    # Differentiate f_v once more and keep the v-component.\n",
    "    f_vv = calculate_derivative(f_v, x)[:, 2]\n",
    "\n",
    "    return f_t, f_x, f_v, f_vv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mass_conservation(output, points_per_slice=None, area=None):\n",
    "    \"\"\"Approximate mass (integral of f over x and v) of every time slice in `output`.\n",
    "\n",
    "    output: network values in grid order (t slowest), points_per_slice rows per time\n",
    "        level (default Nx*Nv). area: measure of the (x, v) domain (default\n",
    "        (xmax-xmin)*(vmax-vmin)). Returns one mean*area value per slice, shaped\n",
    "        (n_slices, 1) for an (N, 1) input.\n",
    "    \"\"\"\n",
    "    if points_per_slice is None:\n",
    "        points_per_slice = Nx*Nv\n",
    "    if area is None:\n",
    "        area = (xmax-xmin)*(vmax-vmin)\n",
    "    # reshape raises if the batch does not consist of whole time slices.\n",
    "    slices = output.reshape(-1, points_per_slice, *output.shape[1:])\n",
    "    return slices.mean(dim=1)*area"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def max_L2(compare, reference=None, n_slices=3):\n",
    "    \"\"\"Largest per-time-slice mean squared error between predictions and the reference.\n",
    "\n",
    "    compare: model output tensor; reference: array of the same shape (default: the\n",
    "    global compare_y). Both are split into n_slices equal consecutive blocks (the\n",
    "    initial, middle and final time slices). Returns a NumPy scalar (supports .item()).\n",
    "    \"\"\"\n",
    "    if reference is None:\n",
    "        reference = compare_y\n",
    "    squared_error = (compare.detach().cpu().numpy() - reference)**2\n",
    "    return np.max([np.mean(block) for block in np.split(squared_error, n_slices)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-01-31T08:22:38.915588Z",
     "start_time": "2020-01-31T08:22:38.908622Z"
    }
   },
   "outputs": [],
   "source": [
    "# Training function\n",
    "\n",
    "def train(f_model, trainloader, optimizer, loss_f):\n",
    "    \"\"\"Run one epoch of PINN training; return the batch-averaged losses and reference error.\n",
    "\n",
    "    loss1: PDE residual GE and its x- and v-derivatives; loss2: initial value and v-derivative;\n",
    "    loss3: periodicity of values and gradients between x = xmin and x = xmax;\n",
    "    loss4: mass of each time slice vs. initial_mass.\n",
    "    Returns (loss, loss1, loss2, loss3, loss4, err), err being max_L2 against the reference.\n",
    "    \"\"\"\n",
    "    f_model.train()\n",
    "    loss_list, loss_list1, loss_list2, loss_list3, loss_list4, err_list = [], [], [], [], [], []\n",
    "\n",
    "    for (data,) in trainloader:\n",
    "        optimizer.zero_grad()\n",
    "        # Fresh leaf tensor on the device so that derivatives w.r.t. the inputs can be taken.\n",
    "        X_v = data.to(device).detach().requires_grad_(True)\n",
    "        output = f_model(X_v)\n",
    "\n",
    "        output_bdry_l, output_bdry_r = f_model(X_bdry_left), f_model(X_bdry_right)\n",
    "        output_bdry_l_der = calculate_derivative(output_bdry_l, X_bdry_left).reshape(-1)\n",
    "        output_bdry_r_der = calculate_derivative(output_bdry_r, X_bdry_right).reshape(-1)\n",
    "        output_ini = f_model(X_ini)\n",
    "        output_der_ini = calculate_derivative(output_ini, X_ini)[:,2].view(-1,1)\n",
    "        f_t, f_x, f_v, f_vv = calculate_all_partial(output, X_v)\n",
    "        mass = mass_conservation(output)\n",
    "\n",
    "        # Residual of f_t + v*(f_x - beta*f_v) - beta*f - q*f_vv = 0\n",
    "        GE = f_t + X_v[:,2]*(f_x - beta*f_v) - beta*output.view(-1) - q*f_vv\n",
    "        del_GE = calculate_derivative(GE, X_v)\n",
    "        GE_x, GE_v = del_GE[:,1], del_GE[:,2]\n",
    "\n",
    "        loss1 = (loss_f(GE, torch.zeros_like(GE)) + loss_f(GE_x, torch.zeros_like(GE_x))\n",
    "                 + loss_f(GE_v, torch.zeros_like(GE_v)))\n",
    "        loss2 = loss_f(output_ini, f_ini) + loss_f(output_der_ini, f_der_ini)\n",
    "        loss3 = (loss_f(output_bdry_l - output_bdry_r, torch.zeros_like(output_bdry_l))\n",
    "                 + loss_f(output_bdry_l_der - output_bdry_r_der, torch.zeros_like(output_bdry_l_der)))\n",
    "        loss4 = loss_f(mass, torch.ones_like(mass)*initial_mass)\n",
    "        # loss1-loss3 are weighted by the measure of the set their points are sampled from.\n",
    "        loss = (tmax*(xmax-xmin)*(vmax-vmin)*loss1 + (xmax-xmin)*(vmax-vmin)*loss2\n",
    "                + tmax*(vmax-vmin)*loss3 + loss4)\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Evaluation against the reference only; no autograd graph is needed.\n",
    "        with torch.no_grad():\n",
    "            err = max_L2(f_model(compare_x))\n",
    "\n",
    "        loss_list.append(loss.item())\n",
    "        loss_list1.append(loss1.item())\n",
    "        loss_list2.append(loss2.item())\n",
    "        loss_list3.append(loss3.item())\n",
    "        loss_list4.append(loss4.item())\n",
    "        err_list.append(err.item())\n",
    "    return (np.mean(loss_list), np.mean(loss_list1), np.mean(loss_list2), np.mean(loss_list3),\n",
    "            np.mean(loss_list4), np.mean(err_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-01-31T08:22:39.725732Z",
     "start_time": "2020-01-31T08:22:39.713029Z"
    }
   },
   "outputs": [],
   "source": [
    "# Network inputs for the first, median and last time levels of the grid.\n",
    "data = [torch.FloatTensor(df_real.loc[df_real['t']==i].values).to(device)\n",
    "        for i in [df_real['t'].min(), df_real['t'].median(), df_real['t'].max()]]\n",
    "\n",
    "def plot_results(f_model, show):\n",
    "    \"\"\"Plot the network prediction on the (x, v) grid at the three time levels in `data`.\n",
    "\n",
    "    Shows the figure and returns None when `show` is true, otherwise returns the figure.\n",
    "    \"\"\"\n",
    "    times = [df_real['t'].min(), df_real['t'].median(), df_real['t'].max()]\n",
    "    fig = plt.figure(figsize=(20,5))\n",
    "\n",
    "    x = np.linspace(xmin, xmax, Nx) # discretization of space\n",
    "    v = np.linspace(vmin, vmax, Nv) # discretization of velocity\n",
    "    X, V = np.meshgrid(x, v)\n",
    "    for i, (t_i, inputs) in enumerate(zip(times, data), start=1):\n",
    "        # Inference only: no autograd graph needed.\n",
    "        with torch.no_grad():\n",
    "            prediction = f_model(inputs).cpu().numpy().reshape(Nx, Nv)\n",
    "\n",
    "        ax = fig.add_subplot(1,3, i, projection='3d')\n",
    "        ax.plot_surface(X, V, prediction.T, alpha=0.7)\n",
    "\n",
    "        ax.zaxis._axinfo['juggled'] = (1,2,2)\n",
    "        ax.set_title('Neural Network Solution at t={}'.format(t_i))\n",
    "        ax.set_xlabel('x')\n",
    "        ax.set_ylabel('v')\n",
    "        ax.set_zlabel('f')\n",
    "        # 3D axes are named x/y/z regardless of the plotted variables.\n",
    "        ax.locator_params(axis='z', nbins=6)\n",
    "        ax.locator_params(axis='x', nbins=5)\n",
    "        ax.locator_params(axis='y', nbins=5)\n",
    "\n",
    "    if show:\n",
    "        plt.show()\n",
    "    else:\n",
    "        return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-01-31T08:22:41.568063Z",
     "start_time": "2020-01-31T08:22:41.557360Z"
    }
   },
   "outputs": [],
   "source": [
    "# initialization: a freshly initialised network placed on the selected device\n",
    "f_model = f_Net().to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-02-01T01:33:11.155490Z",
     "start_time": "2020-01-31T08:22:45.105245Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "optimizer=torch.optim.Adam([{'params': f_model.parameters()}], lr=1e-4)\n",
    "criterion = nn.MSELoss()\n",
    "total_loss = []\n",
    "CHECKPOINT_PATH = 'Fokker-Planck_GE_H1_IC_H1_exp_ini.pt'\n",
    "\n",
    "for t in tqdm_notebook(range(1, EPOCH+1)):\n",
    "    loss, loss1, loss2, loss3, loss4, err = train(f_model, trainloader=train_loader,\n",
    "                                                  optimizer=optimizer, loss_f=criterion)\n",
    "    # Print Log\n",
    "    if t%100 == 0:\n",
    "        print('%s/%s | loss: %04.4f | loss1: %06.6f | loss2: %06.6f | loss3 : %06.6f | loss4 : %06.6f | err : %06.6f ' %\n",
    "              (t, EPOCH, loss, loss1, loss2, loss3, loss4, err))\n",
    "\n",
    "    if t%5000 == 0:\n",
    "        plot_results(f_model, show=True)\n",
    "\n",
    "    total_loss.append(loss)\n",
    "\n",
    "    # Save the model (and loss history) periodically, and also right before stopping\n",
    "    # early so the checkpoint holds the final model rather than an older one.\n",
    "    converged = err < 1e-4\n",
    "    if t % 10000 == 0 or converged:\n",
    "        torch.save([f_model, total_loss], CHECKPOINT_PATH)\n",
    "    if converged:\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "# Benchmark: train N_RUNS freshly initialised models for N_EPOCHS_PER_RUN epochs each.\n",
    "# errs_list gets one error history per run; epochs/times only get an entry for runs\n",
    "# whose error dropped below 1e-4 (first such epoch and wall-clock seconds until then).\n",
    "N_RUNS = 100\n",
    "N_EPOCHS_PER_RUN = 2000  # separate name, so the global EPOCH used above is not overwritten\n",
    "RESULTS_PATH = 'Fokker_Planck_H1_exp_ini_fix_bdry'\n",
    "\n",
    "continued=False\n",
    "\n",
    "if continued:\n",
    "    # NOTE: unpickling can execute code; only resume from a trusted results file.\n",
    "    errs_list, epochs, times = pd.read_pickle(RESULTS_PATH)\n",
    "else:\n",
    "    epochs, times, errs_list = [], [], []\n",
    "\n",
    "# Count completed runs with errs_list: runs that never reached the threshold have no\n",
    "# entry in epochs, so len(epochs) would under-count them when resuming.\n",
    "for run in tqdm_notebook(range(N_RUNS - len(errs_list))):\n",
    "    f_model = f_Net().to(device)\n",
    "    optimizer = torch.optim.Adam([{'params': f_model.parameters()}], lr=1e-4)\n",
    "    criterion = nn.MSELoss()\n",
    "    errs = []\n",
    "\n",
    "    start = time.time()\n",
    "    reached = False\n",
    "    for t in tqdm_notebook(range(1, N_EPOCHS_PER_RUN+1)):\n",
    "        loss, loss1, loss2, loss3, loss4, err = train(f_model, trainloader=train_loader,\n",
    "                                                      optimizer=optimizer, loss_f=criterion)\n",
    "        # Record only the first epoch at which the error threshold is reached.\n",
    "        if err < 1e-4 and not reached:\n",
    "            print('In EPOCH {}, max_L2 error is under 1e-4'.format(t))\n",
    "            epochs.append(t)\n",
    "            times.append(time.time() - start)\n",
    "            reached = True\n",
    "\n",
    "        errs.append(err)\n",
    "    errs_list.append(errs)\n",
    "    # Persist after every run so an interrupted benchmark can be resumed.\n",
    "    with open(RESULTS_PATH, 'wb') as results_file:\n",
    "        pkl.dump([errs_list, epochs, times], results_file)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pydata",
   "language": "python",
   "name": "pydata"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
