{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.10.12","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"!git clone https://github.com/zysophia/Doubly_Adaptive_MCMC.git\n%cd Doubly_Adaptive_MCMC/src","metadata":{"_uuid":"85ef5870-b084-4d67-b32a-fd8e3ae8a2b5","_cell_guid":"53547483-1938-4b6d-995d-73a0f3e48e86","collapsed":false,"jupyter":{"outputs_hidden":false},"execution":{"iopub.status.busy":"2023-08-07T05:26:38.792968Z","iopub.execute_input":"2023-08-07T05:26:38.793375Z","iopub.status.idle":"2023-08-07T05:26:41.014341Z","shell.execute_reply.started":"2023-08-07T05:26:38.793334Z","shell.execute_reply":"2023-08-07T05:26:41.013062Z"},"trusted":true},"execution_count":2,"outputs":[{"name":"stdout","text":"Cloning into 'Doubly_Adaptive_MCMC'...\nremote: Enumerating objects: 92, done.\u001b[K\nremote: Counting objects: 100% (92/92), done.\u001b[K\nremote: Compressing objects: 100% (72/72), done.\u001b[K\nremote: Total 92 (delta 26), reused 81 (delta 19), pack-reused 0\u001b[K\nReceiving objects: 100% (92/92), 5.06 MiB | 24.90 MiB/s, done.\nResolving deltas: 100% (26/26), done.\n/kaggle/working/Doubly_Adaptive_MCMC/src\n","output_type":"stream"}]},{"cell_type":"code","source":"from gibbsChains import * \nfrom algorithms import * \nfrom meanEstimator import * \n\n\nimport torch\nfrom torch import nn\n#import matplotlib.pyplot as plt\nimport numpy as np \n#from sklearn.metrics import mean_squared_error\nimport time\nimport torch.autograd.functional as F\n#import statsmodels.api as sm\nfrom torch.autograd import Variable\n#import seaborn as sns\n\nimport pandas as 
pd","metadata":{"_uuid":"f5321d57-1f63-4b07-aeee-dbc2176eba7c","_cell_guid":"7432abbc-12da-417d-8429-5bc87eec0dba","collapsed":false,"jupyter":{"outputs_hidden":false},"execution":{"iopub.status.busy":"2023-08-07T05:30:41.568492Z","iopub.execute_input":"2023-08-07T05:30:41.568875Z","iopub.status.idle":"2023-08-07T05:30:41.575207Z","shell.execute_reply.started":"2023-08-07T05:30:41.568844Z","shell.execute_reply":"2023-08-07T05:30:41.574253Z"},"trusted":true},"execution_count":17,"outputs":[]},{"cell_type":"code","source":"class CFG:\n    times= 1\n\n    d = 1\n\n    M= 10\n    h = 0.001\n    N = 200\n    \n    epoch_stage_1=1000\n    epoch_stage_2=100\n    \n    optim = ['SGD','Adam'][1]\n    lr = 0.001\n    seed = 44\n    hidden_size = 256\n    \n    \n    lambda1 =1\n    lambda2 =1\n    \n    \n    print_step = 100","metadata":{"execution":{"iopub.status.busy":"2023-08-07T06:00:43.268250Z","iopub.execute_input":"2023-08-07T06:00:43.268687Z","iopub.status.idle":"2023-08-07T06:00:43.279072Z","shell.execute_reply.started":"2023-08-07T06:00:43.268651Z","shell.execute_reply":"2023-08-07T06:00:43.277897Z"},"trusted":true},"execution_count":83,"outputs":[]},{"cell_type":"code","source":"true_w =  0.9226402\ntrue_v =  1.0834867\ntrue_q =  1.174333","metadata":{"execution":{"iopub.status.busy":"2023-08-07T05:52:03.570347Z","iopub.execute_input":"2023-08-07T05:52:03.570698Z","iopub.status.idle":"2023-08-07T05:52:03.575712Z","shell.execute_reply.started":"2023-08-07T05:52:03.570671Z","shell.execute_reply":"2023-08-07T05:52:03.574705Z"},"trusted":true},"execution_count":72,"outputs":[]},{"cell_type":"code","source":"logging.info(\"----- This is a new run -----\")\nn = 3\nchain = IsingChainLattice(n = n)\nlogging.info(\"Ising model, n = %d\", n)\n\nbmin = -0.02\nbmax = 0.0\neps = 0.1\ndelta = 0.25\nkappa = 0.1\ndist = 64\n\nchain.beta = bmin\ndp = [[]]\nfor i in range(n**2):\n    for j in range(len(dp)):\n        d = dp.pop(0)\n        dp.append(d+[0])\n        
dp.append(d+[1])\nz = 0\nfor d in dp:\n    hh = chain.get_Hamiltonian(d)\n    z += np.exp(-bmin*hh)\nreal_z = z/2**(n**2)\nprint(\"real z:\", real_z)\nlogging.info(\"real value z = %.20f\", real_z)\n\n# TPA for parallel and super Gibbs\n\ntao_dict = {256: 1.260, 128: 1.372, 64:1.539, 32: 1.794, 16: 2.197, 8: 2.86, 4:4.0}\nHmax = chain.get_Hmax()\nHmin = chain.get_Hmin()\ngamma = 0.24\ntao = tao_dict[dist]\nm = tao/2/np.log(1+ gamma) * np.log(Hmax)\nk = int(m*dist)\nprint(\"k = \", k)\nchain.beta = bmin\nq = np.log(chain.get_upper_Q())\ntvd = kappa/ (k*q)\n#res = TPA_k_d(bmin, bmax, k, dist, chain, tvd)\n#schedule, TPAsteps = res[\"schedule\"], res[\"steps\"]\n\nschedule,TPAsteps = [-0.02, 0.0],0\n\ndef parallelGibbs1(schedule = None, TPAsteps = 0, bmin = 0, bmax = 1, gibbsChain = None, eps = 0.1, delta = 0.25, kappa = 0.2, d = 64, trace = True):\n    \n    print(\"running Parallel Gibbs...\")\n    print(f\"l = {len(schedule)}, trace = {trace}, e = {eps}, delta = {delta}, kappa = {kappa}, bmin = {bmin}, bmax = {bmax}, d = {d}\")\n\n    z = 1.0\n    w = 1.0\n    v = 1.0\n    sample_complexity = 0\n    Hmax = gibbsChain.get_Hmax()\n    Hmin = gibbsChain.get_Hmin()\n    sample_complexity += TPAsteps\n    l = len(schedule)\n    \n    # get mean-estimator params\n    epsprime = ((1+eps)**(1/l)-1) / ((1+eps)**(1/l)+1)\n    delprime = delta/2/l\n    \n    \n    \n    \n    for i in range(l-1):\n        print(i)\n        gap = schedule[i+1]-schedule[i]\n\n        gibbsChain.beta = schedule[i]\n        gibbsChain.set_startpoint()\n        func_f = lambda x: np.exp(-gap/2*gibbsChain.get_Hamiltonian(x))\n        a = np.exp(-gap/2*Hmax)\n        b = np.exp(-gap/2*Hmin)\n        \n        begin1 = time.time()\n        res1 = mean_estimator(gibbsChain, func_f, epsprime, delprime, a, b, use_trace=trace)\n        end1 = time.time()\n        \n        wi = res1[\"mean_value\"]\n        sample_complexity += res1[\"steps\"]\n        \n        print('wi',wi, res1[\"steps\"])\n        
W=[wi, res1[\"steps\"],end1-begin1]\n        \n        gibbsChain.beta = schedule[i+1]\n        gibbsChain.set_startpoint()\n        func_g = lambda x: np.exp(gap/2*gibbsChain.get_Hamiltonian(x)) \n        a = np.exp(gap/2*Hmin)\n        b = np.exp(gap/2*Hmax)\n        \n        begin2 = time.time()\n        res2 = mean_estimator(gibbsChain, func_g, epsprime, delprime, a, b, use_trace=trace)\n        end2 = time.time()\n        \n        vi = res2[\"mean_value\"]\n        sample_complexity += res2[\"steps\"]\n        \n        print('vi',vi, res2[\"steps\"])\n        V=[vi,res2[\"steps\"],end2-begin2]\n        \n        w *= wi\n        v *= vi\n    print('v',v)\n    print('w',w)\n    z = v/w\n    \n    return z, sample_complexity,W,V,\n\nz, steps,W,V = parallelGibbs1(schedule = schedule, TPAsteps = TPAsteps, bmin = bmin, bmax = bmax, gibbsChain = chain, eps = eps, delta = delta, kappa= kappa, d = dist, trace = False)\nprint(\"parallelGibbs takes \", steps, \"steps, while z = \", z, \"TPA takes \", TPAsteps, \"steps\")","metadata":{"_uuid":"02fcdb69-54ce-4c34-abce-22161601452f","_cell_guid":"bd571579-c8da-4d29-99e4-6b609220d142","collapsed":false,"jupyter":{"outputs_hidden":false},"execution":{"iopub.status.busy":"2023-08-07T06:00:48.766539Z","iopub.execute_input":"2023-08-07T06:00:48.766914Z","iopub.status.idle":"2023-08-07T06:00:55.637591Z","shell.execute_reply.started":"2023-08-07T06:00:48.766880Z","shell.execute_reply":"2023-08-07T06:00:55.636694Z"},"trusted":true},"execution_count":84,"outputs":[{"name":"stdout","text":"real z: 1.1743327316567358\nk =  587\nrunning Parallel Gibbs...\nl = 2, trace = False, e = 0.1, delta = 0.25, kappa = 0.1, bmin = -0.02, bmax = 0.0, d = 64\n0\n\nR =  0.1119544028286068 lambda =  0.9949431154076287\nT =  1\nmi= 13020 tv = 0.0002977991689405186 variance_upper 0.023423421003214998 bernstein 0.17117219231786907\nmi= 14322 tv = 0.0002998460056530724 variance_upper 0.021362948723785007 bernstein 0.1557084020636092\nmi= 15754 tv = 
0.0002971584100386677 variance_upper 0.019477721195698506 bernstein 0.14163469494544745\nmi= 17330 tv = 0.0002976826035168887 variance_upper 0.01776827352966579 bernstein 0.12884142948045163\nmi= 19063 tv = 0.0002960443841007518 variance_upper 0.016208628323281973 bernstein 0.11720676344846032\nmi= 20969 tv = 0.0002930527378012635 variance_upper 0.014786466632059996 bernstein 0.10662485379064243\nmi= 23066 tv = 0.00029192336791363476 variance_upper 0.013495479256843166 bernstein 0.09700587661115068\nmi= 25372 tv = 0.00029283724854053413 variance_upper 0.012324854345887691 bernstein 0.08826740331217231\nmi= 27910 tv = 0.00029465635004231037 variance_upper 0.011260640614171044 bernstein 0.08031960311665082\nmi= 30700 tv = 0.0002950099147222977 variance_upper 0.010289917694920162 bernstein 0.07309337266416278\ndone, I = 10 i = 10 0.0238230365969691 0.05992810352914354 0.9340726323348647\nwi 0.9340726323348647 30700\n\nR =  0.12877821624045382 lambda =  0.9949431154076287\nT =  1\nmi= 13020 tv = 0.0004048129047104488 variance_upper 0.03101804260822404 bernstein 0.19692659349733205\nmi= 14322 tv = 0.0004065119035984197 variance_upper 0.02828868497861293 bernstein 0.1791351666530046\nmi= 15754 tv = 0.0004091281527747502 variance_upper 0.025807663778719094 bernstein 0.1629628846348488\nmi= 17330 tv = 0.00040988324196469193 variance_upper 0.023545022285281537 bernstein 0.14824604597586136\nmi= 19063 tv = 0.00040935631472606134 variance_upper 0.021484070612877063 bernstein 0.13486616400814389\nmi= 20969 tv = 0.0004095396555741252 variance_upper 0.01961016206233254 bernstein 0.12270347604689599\nmi= 23066 tv = 0.00040443982077061107 variance_upper 0.017893597906256345 bernstein 0.11162867934013151\nmi= 25372 tv = 0.00041303332374346806 variance_upper 0.016358456426575158 bernstein 0.1015935915009997\nmi= 27910 tv = 0.0004110133405448681 variance_upper 0.014940534361185333 bernstein 0.09243940949483753\nmi= 30700 tv = 0.00040876047718787974 variance_upper 0.013650038470968148 
bernstein 0.08411977864821824\ndone, I = 10 i = 10 0.0238230365969691 0.05992810352914354 1.0744392752043948\nvi 1.0744392752043948 30700\nv 1.0744392752043948\nw 0.9340726323348647\nparallelGibbs takes  61400 steps, while z =  1.1502737988572271 TPA takes  0 steps\n","output_type":"stream"}]},{"cell_type":"code","source":"np.mean(cs2)","metadata":{"execution":{"iopub.status.busy":"2023-08-07T06:00:55.642512Z","iopub.execute_input":"2023-08-07T06:00:55.644642Z","iopub.status.idle":"2023-08-07T06:00:55.654040Z","shell.execute_reply.started":"2023-08-07T06:00:55.644606Z","shell.execute_reply":"2023-08-07T06:00:55.653237Z"},"trusted":true},"execution_count":85,"outputs":[{"execution_count":85,"output_type":"execute_result","data":{"text/plain":"1.0833329548164838"},"metadata":{}}]},{"cell_type":"code","source":"def get_values(T,types,n,beta):\n    chain = IsingChainLattice(n = n)\n    chain.beta= beta\n    uniform_mixing = chain.get_uniform_mixing()\n    chain.restart_and_sample(steps =uniform_mixing)\n    if types=='f':\n        func_f = lambda x: np.exp(-gap/2*chain.get_Hamiltonian(x))\n    if types=='g':\n        func_f = lambda x: np.exp(gap/2*chain.get_Hamiltonian(x)) \n    chainvals = [0 for _ in range(T)]\n    for t in range(T):\n        chain.step()\n        chainvals[t] = func_f(chain.current)\n    return chainvals\n\n\ngap = schedule[1]-schedule[0]\ncs1 = [get_values(CFG.M,'f',n,-0.02) for  i in range(CFG.N)]\ncs1 = np.array(cs1)\ncs2 = [get_values(CFG.M,'g',n,0) for  i in range(CFG.N)]\ncs2 = np.array(cs2)\n\n\ndef seed_torch(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.backends.cudnn.deterministic = True\nseed_torch(seed = CFG.seed)\n\ndef deltaW(N, m, h, seed,generator=None):\n    if generator is None:\n        generator = np.random.default_rng(seed=seed)\n    return generator.normal(0.0, np.sqrt(h), (N, m))\n\ndef get_wt(CFG,seed=1):\n    n,h,N,d = CFG.M,CFG.h,CFG.N,CFG.d\n    ts = np.array([i*h for i in range(n)])\n   
 w_t_increments = np.array([deltaW(N=n-1, m=d, h=h, generator=None,seed=int(i+seed)) for i in range(N)])\n    w_t = np.array([np.insert(np.cumsum(i.T),0,np.zeros(1)) for i in w_t_increments])\n    return w_t,ts,w_t_increments#.reshape(CFG.N,CFG.M-1)\n\ndef sinkhorn_loss(x, y, epsilon, n, niter):\n    C = cost_matrix(x, y)\n    mu = Variable(1. / n * torch.FloatTensor(n).fill_(1), requires_grad=False)\n    nu = Variable(1. / n * torch.FloatTensor(n).fill_(1), requires_grad=False)\n    rho = 1 \n    tau = -.8  \n    lam = rho / (rho + epsilon)  \n    thresh = 10**(-1) \n    def ave(u, u1):\n        return tau * u + (1 - tau) * u1\n    def M(u, v):\n        return (-C + u.unsqueeze(1) + v.unsqueeze(0)) / epsilon\n    def lse(A):\n        return torch.log(torch.exp(A).sum(1, keepdim=True) + 1e-6)\n    u, v, err = 0. * mu, 0. * nu, 0.\n    actual_nits = 0  \n    for i in range(niter):\n        u1 = u  \n        u = epsilon * (torch.log(mu) - lse(M(u, v)).squeeze()) + u\n        v = epsilon * (torch.log(nu) - lse(M(u, v).t()).squeeze()) + v\n        err = (u - u1).abs().sum()\n        actual_nits += 1\n        if (err < thresh).data.numpy():\n            break\n    U, V = u, v\n    pi = torch.exp(M(U, V))  \n    cost = torch.sum(pi * C)\n    return cost\n\ndef cost_matrix(x, y, p=2):\n    x_col = x.unsqueeze(1)\n    y_lin = y.unsqueeze(0)\n    c = torch.sum((torch.abs(x_col - y_lin)) ** p, 2)\n    return c   \n\n\nclass wasserstein_loss(nn.Module):\n    def __init__(self, **kwargs):\n        super(wasserstein_loss, self).__init__(**kwargs)\n    def forward(self,x,y):\n        return sinkhorn_loss(x.t(),y.t(),epsilon=1e-5,n=N,niter=1000)\n    \ndef gradients(u, x, order=1):  \n    if order == 1:  \n        return torch.autograd.grad(u, x,grad_outputs=torch.ones_like(u),  create_graph=True,only_inputs=True,)[0]  \n    else:  \n        return gradients(gradients(u, x), x,order=order-1)\n\nclass F_model(nn.Module):\n    def __init__(self):\n        super(F_model, 
self).__init__()\n        self.hidden_layer1 = torch.nn.Linear(CFG.d, CFG.hidden_size)\n        self.hidden_layer2 = torch.nn.Linear(CFG.hidden_size, CFG.hidden_size)\n        self.output_layer = torch.nn.Linear(CFG.hidden_size, CFG.d)\n        self.tanh = nn.Tanh()\n    def forward(self, x):\n        x = self.hidden_layer1(x)\n        x = self.hidden_layer2(x)\n        x =  self.tanh(x)\n        x = self.hidden_layer2(x)\n        x = self.hidden_layer2(x)\n        #x = self.hidden_layer2(x)\n        x =  self.tanh(x)\n        x = self.output_layer(x)\n        return x\n    \nclass G_model(nn.Module):\n    def __init__(self):\n        super(G_model, self).__init__()\n        self.hidden_layer1 = torch.nn.Linear(CFG.d, CFG.hidden_size)\n        self.hidden_layer2 = torch.nn.Linear(CFG.hidden_size, CFG.hidden_size)\n        self.output_layer = torch.nn.Linear(CFG.hidden_size, CFG.d*CFG.d)\n        self.tanh = nn.Tanh()\n    def forward(self, x):\n        x = self.hidden_layer1(x)\n        x = self.hidden_layer2(x)\n        x =  self.tanh(x)\n        x = self.hidden_layer2(x)\n        x = self.hidden_layer2(x)\n        #x = self.hidden_layer2(x)\n        x =  self.tanh(x)\n        x = self.output_layer(x).view(x.shape[0],CFG.d,CFG.d)\n        return x\n    \nclass X_0_model(nn.Module):\n    def __init__(self):\n        super(X_0_model, self).__init__()\n        self.hidden_layer1 = torch.nn.Linear(CFG.d, CFG.hidden_size)\n        self.hidden_layer2 = torch.nn.Linear(CFG.hidden_size, CFG.hidden_size)\n        self.output_layer = torch.nn.Linear(CFG.hidden_size, CFG.d)\n        self.tanh = nn.Tanh()\n    def forward(self, x):\n        x = self.hidden_layer1(x)\n        x = self.output_layer(x)\n        return x #torch.mean(x)#.view(CFG.N,1)\n    \n\nclass Pde_model(nn.Module):\n    def __init__(self):\n        super(Pde_model, self).__init__()\n        self.hidden_layer1 = torch.nn.Linear(CFG.d+1, CFG.hidden_size)\n        self.hidden_layer2 = 
torch.nn.Linear(CFG.hidden_size, CFG.hidden_size)\n        self.output_layer = torch.nn.Linear(CFG.hidden_size, 1)\n        self.tanh = nn.Tanh()\n        self.initialize_weights()\n    def forward(self,x):\n        x = self.hidden_layer1(x)\n        x =  self.tanh(x)\n        x = self.hidden_layer2(x)\n        x =  self.tanh(x)\n        x = self.output_layer(x)\n        return x\n    \n    def initialize_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Linear):\n                torch.nn.init.normal_(m.weight.data, 0, 0.01)\n                m.bias.data.zero_()\n                \nclass loss_function(nn.Module):\n    def __init__(self, **kwargs):\n        super(loss_function, self).__init__(**kwargs)\n        \n    def get_1d_wasserstein(self,u_values, v_values):\n        u_sorter = torch.argsort(u_values)\n        v_sorter = torch.argsort(v_values)\n        all_values =torch.cat((u_values, v_values))\n        all_values= torch.sort(all_values)[0]\n        deltas = torch.diff(all_values)\n        u_cdf_indices = torch.searchsorted(u_values[u_sorter],all_values[:-1],right=True)\n        v_cdf_indices = torch.searchsorted(v_values[v_sorter],all_values[:-1],right=True)\n        u_cdf = u_cdf_indices / len(u_values)\n        v_cdf = v_cdf_indices / len(v_values)\n        return torch.sum(torch.multiply(torch.abs(u_cdf - v_cdf), deltas))\n    \n    def get_wasserstein(self,x,y):\n        return sinkhorn_loss(x, y,epsilon=1e-2,n=CFG.N,niter=2000)\n    \n    def get_1d_wasserstein_sum(self,x,y):\n        for i in range(CFG.d):\n            if i == 0 :\n                sums = self.get_1d_wasserstein(x[:,i],y[:,i])\n            else:\n                sums = sums + self.get_1d_wasserstein(x[:,i],y[:,i])\n        return sums\n    \n    def get_out(self,t,yt,w_t_increments,par,model_list):\n        xt = model_list[2](par)\n        xt_l = []\n        for i in range(CFG.M):\n            if i == 0 :\n                xt_l.append(xt)\n                
xt1 =xt\n            else:\n                # Euler-Maruyama step: evaluate drift f and diffusion g at the current state xt1, not at the initial state xt\n                xt1 =xt1+ model_list[0](xt1)*CFG.h + (model_list[1](xt1) @ w_t_increments[:,i-1,:].unsqueeze(-1)).view(CFG.N,CFG.d)\n                xt_l.append(xt1)\n\n        f= [model_list[0](xt_l[j])for j in range(CFG.M)]\n        l =[model_list[1](xt_l[j])for j in range(CFG.M)]\n        return xt_l,f,l\n\n    def forward(self,t,yt,w_t_increments,par,model_list):\n        xt = model_list[2](par)\n        xt_l = []\n        for i in range(CFG.M):\n            if i == 0 :\n                xt_l.append(xt)\n                xt1 =xt\n            else:\n                # same Euler-Maruyama rollout as get_out: coefficients are evaluated at the current state xt1\n                xt1 =xt1+ model_list[0](xt1)*CFG.h + (model_list[1](xt1) @ w_t_increments[:,i-1,:].unsqueeze(-1)).view(CFG.N,CFG.d)\n                xt_l.append(xt1)\n        wr =0 \n\n        for i in range(CFG.M):\n            wr = wr+self.get_1d_wasserstein(xt_l[i].view(-1),yt[:,i])\n        \n        return wr #self.get_wasserstein(xt_l[-1],yt)# + torch.mean(loss_sum)\n    \nclass loss_pde(nn.Module):\n    def __init__(self, **kwargs):\n        super(loss_pde, self).__init__(**kwargs)\n    def forward(self,t,xt_l,f,ll,model_list):\n        # Loss = squared residual of u_t + f . grad_x u + 1/2 * sum(Hess_x u * ll), summed over time steps 0..M-2\n        #        and averaged over samples, plus the squared terminal mismatch u(T, X_T) - sum(X_T).\n        T = torch.cat([t[-1].view(-1) for K in range(CFG.N)],dim = 0).view(CFG.N,1)\n        XT = xt_l[-1]\n\n        uTx =torch.cat([T,XT],dim=-1)\n        uT = model_list[3](uTx)\n        loss1 = (uT - torch.sum(XT,dim=1))**2\n        \n        loss_sum =0 \n        for s in range(0,CFG.M-1):#range(CFG.M):\n            xi= xt_l[s]\n            ti = torch.cat([t[s].view(-1) for K in range(CFG.N)],dim = 0).view(CFG.N,1)\n            x =  torch.cat([ti,xi],dim=-1)\n            u = model_list[3](x)\n            \n            d1 = gradients(u,x)\n            \n            hessian=[]\n            for i in range(CFG.N):\n                P = torch.cat([gradients(j, x)[i] for j in d1[i]],dim=0).view(CFG.d+1, -1)\n                hessian.append(P[:,1:][1:])\n               \n            ut = d1[:,0]\n            \n            a = torch.sum(f[s] * d1[:,1:],dim=1)\n
            uxy = torch.cat(hessian,dim=0).unsqueeze(0).view(CFG.N,CFG.d,CFG.d)\n            b= 1/2* torch.sum(torch.sum(uxy*ll[s],dim=1),dim=-1)\n            # accumulate the residual over every time step (previously each iteration overwrote loss_sum, so only the last step was trained)\n            loss_sum = loss_sum + (ut + a+ b )**2\n        return torch.mean(loss_sum)+1*torch.mean(loss1)\n        \ndef get_optimizer(model):\n    if CFG.optim == 'SGD':\n        optimizer = torch.optim.SGD(model.parameters(), lr=CFG.lr)\n    if CFG.optim == 'Adam':\n        optimizer = torch.optim.Adam(model.parameters(), lr=CFG.lr, amsgrad = False)\n    return optimizer\n\ndef get_wt_nn(CFG,seed= 1):\n    n,h,N = CFG.M,CFG.h,CFG.N\n    ts = np.array([i*h for i in range(n)])\n    w_t_increments = np.array([deltaW(N=n-1, m=1, h=h, generator=None,seed=int(i+seed)) for i in range(N)])\n    w_t = np.array([np.insert(np.cumsum(i.T),0,np.zeros(1)) for i in w_t_increments])\n    return w_t,ts,w_t_increments.reshape(CFG.N,CFG.M-1)\n    \n    \ndef get_predict_sample(CFG,model_list):\n    par = np.array([1 for i in range(CFG.N*CFG.d)])\n    par = torch.tensor(par.astype(np.float32)).reshape(CFG.N,CFG.d).requires_grad_(True).cuda()\n    w_t,ts,w_t_increments = get_wt(CFG,seed=100000)\n    w_t_increments = torch.tensor(w_t_increments.astype(np.float32)).requires_grad_(False).cuda()\n    t = torch.tensor(ts.astype(np.float32)).reshape(CFG.M,1).requires_grad_(True).cuda()\n    t0 = torch.cat([t[0].view(-1) for K in range(CFG.N)],dim = 0).view(CFG.N,1).cuda()\n    \n    xt = model_list[2](par)\n    xt_l = []\n    for i in range(CFG.M):\n        if i == 0 :\n            xt_l.append(xt)\n            xt1 =xt\n        else:\n            # Euler-Maruyama step evaluated at the current state xt1, matching the training rollout in loss_function\n            xt1 =xt1+ model_list[0](xt1)*CFG.h + (model_list[1](xt1) @ w_t_increments[:,i-1,:].unsqueeze(-1)).view(CFG.N,CFG.d)\n            xt_l.append(xt1)\n    return xt_l\n\n\ndef get_nn_solution(CFG,yt):\n    model_list =[F_model().cuda() ,G_model().cuda(), X_0_model().cuda(),Pde_model().cuda()]\n    par = np.array([1 for i in range(CFG.N*CFG.d)])\n    par = 
torch.tensor(par.astype(np.float32)).reshape(CFG.N,CFG.d).requires_grad_(True).cuda()\n    yt = torch.tensor(yt.astype(np.float32)).requires_grad_(False).cuda()\n    loss_ito = loss_function()\n    loss_pdes = loss_pde()\n    optimizer_list = [get_optimizer(model_list[i]) for i in range(3)]+[torch.optim.Adam(model_list[3].parameters(), lr=CFG.lr*10, amsgrad = False)]\n    result= []\n\n    loss_ito_log = []\n\n    w_t,ts,w_t_increments = get_wt(CFG,seed=1)\n    w_t_increments = torch.tensor(w_t_increments.astype(np.float32)).requires_grad_(False).cuda()\n    t = torch.tensor(ts.astype(np.float32)).reshape(CFG.M,1).requires_grad_(True).cuda()\n    t0 = torch.cat([t[0].view(-1) for K in range(CFG.N)],dim = 0).view(CFG.N,1).cuda()\n    print('stage 1 begin')\n    b1 = time.time()\n    for epoch in range(CFG.epoch_stage_1):\n        for optimizer in optimizer_list[0:3]:\n            optimizer.zero_grad()\n            \n        loss_1 = loss_ito(t,yt,w_t_increments,par,model_list)\n        loss_1.backward()\n        for optimizer in optimizer_list[0:3]:\n            optimizer.step()\n        if epoch % CFG.print_step== 0: \n            print('Epoch:{}/{}\\t Ito_loss={:.4f}\\t '.format(CFG.epoch_stage_1,epoch ,loss_1,))\n    \n    e1 = time.time()\n    time1 =e1 - b1\n    xt_l,f,l = loss_ito.get_out(t,yt,w_t_increments,par,model_list)\n    xt_l = [torch.tensor(i.detach().cpu().numpy()).requires_grad_(True).cuda() for i in xt_l]\n    f = [torch.tensor(i.detach().cpu().numpy()).requires_grad_(False).cuda() for i in f]\n    l = [torch.tensor(i.detach().cpu().numpy()).requires_grad_(False).cuda() for i in l]\n    ll = [l[s] @ torch.transpose(l[s], 2, 1) for s in range(CFG.M)]\n    \n    print('stage 2 begin')\n    b2 = time.time()\n    #print('Ex',np.mean(np.sum(yt,axis=1)))\n    EX_LSIT=[]\n    for epoch in range(CFG.epoch_stage_2):\n        optimizer_list[-1].zero_grad()\n        loss_2 = loss_pdes(t,xt_l,f,ll,model_list)\n        loss_2.backward()\n        
optimizer_list[-1].step()\n        \n        if epoch % 1== 0:\n            x =  torch.cat([t0,model_list[2](par)],dim=-1)\n            EX= model_list[3](x)\n            if epoch in [i for i in range(CFG.epoch_stage_2 - 10, CFG.epoch_stage_2)]:\n                EX_LSIT.append(EX[0].item())\n            \n            print('Epoch:{}/{}\\t PDE_loss={:.4f}\\t EX={:.4f}\\t'.format(CFG.epoch_stage_2,epoch ,loss_2,EX[0].item()))\n            \n\n    e2 = time.time()\n    time2 =e2 - b2\n    return result,model_list,np.mean(EX_LSIT),time1,time2  #,[,loss_ito_log,loss_sum_log]","metadata":{"_uuid":"82822e5a-936a-4499-a253-eaa6f3f38595","_cell_guid":"b90ca60e-db60-4609-b7d2-1f84da1fa6a4","collapsed":false,"jupyter":{"outputs_hidden":false},"execution":{"iopub.status.busy":"2023-08-07T06:00:55.658459Z","iopub.execute_input":"2023-08-07T06:00:55.660632Z","iopub.status.idle":"2023-08-07T06:01:06.318592Z","shell.execute_reply.started":"2023-08-07T06:00:55.660599Z","shell.execute_reply":"2023-08-07T06:01:06.316674Z"},"trusted":true},"execution_count":86,"outputs":[]},{"cell_type":"code","source":"result,model_list,EX1,time1a,time2a=get_nn_solution(CFG,cs1)\npredict1 = get_predict_sample(CFG,model_list)\n\n\nresult,model_list,EX2,time1b,time2b=get_nn_solution(CFG,cs2)\npredict2 = get_predict_sample(CFG,model_list)","metadata":{"_uuid":"068c2c3b-3515-4021-86fa-9f56cfeb32c6","_cell_guid":"6e5f541d-0b0c-4ea7-b774-736f4cecd1f5","collapsed":false,"jupyter":{"outputs_hidden":false},"execution":{"iopub.status.busy":"2023-08-07T06:01:06.324926Z","iopub.execute_input":"2023-08-07T06:01:06.327553Z","iopub.status.idle":"2023-08-07T06:09:35.524929Z","shell.execute_reply.started":"2023-08-07T06:01:06.327513Z","shell.execute_reply":"2023-08-07T06:09:35.523887Z"},"trusted":true},"execution_count":87,"outputs":[{"name":"stdout","text":"stage 1 begin\nEpoch:1000/0\t Ito_loss=10.9582\t \nEpoch:1000/100\t Ito_loss=0.1213\t \nEpoch:1000/200\t Ito_loss=0.1139\t \nEpoch:1000/300\t Ito_loss=0.1058\t 
\nEpoch:1000/400\t Ito_loss=0.1171\t \nEpoch:1000/500\t Ito_loss=0.1048\t \nEpoch:1000/600\t Ito_loss=0.1040\t \nEpoch:1000/700\t Ito_loss=0.1040\t \nEpoch:1000/800\t Ito_loss=0.0951\t \nEpoch:1000/900\t Ito_loss=0.0806\t \nstage 2 begin\nEpoch:100/0\t PDE_loss=0.8620\t EX=0.0988\t\nEpoch:100/1\t PDE_loss=0.6945\t EX=0.5807\t\nEpoch:100/2\t PDE_loss=0.2751\t EX=1.2814\t\nEpoch:100/3\t PDE_loss=0.6175\t EX=0.8440\t\nEpoch:100/4\t PDE_loss=0.1515\t EX=0.5240\t\nEpoch:100/5\t PDE_loss=0.1934\t EX=0.4412\t\nEpoch:100/6\t PDE_loss=0.2533\t EX=0.5212\t\nEpoch:100/7\t PDE_loss=0.1781\t EX=0.7294\t\nEpoch:100/8\t PDE_loss=0.0534\t EX=1.0438\t\nEpoch:100/9\t PDE_loss=0.0455\t EX=1.2251\t\nEpoch:100/10\t PDE_loss=0.1167\t EX=1.0981\t\nEpoch:100/11\t PDE_loss=0.0346\t EX=0.8940\t\nEpoch:100/12\t PDE_loss=0.0033\t EX=0.7573\t\nEpoch:100/13\t PDE_loss=0.0366\t EX=0.7065\t\nEpoch:100/14\t PDE_loss=0.0620\t EX=0.7265\t\nEpoch:100/15\t PDE_loss=0.0581\t EX=0.8014\t\nEpoch:100/16\t PDE_loss=0.0362\t EX=0.9096\t\nEpoch:100/17\t PDE_loss=0.0198\t EX=1.0132\t\nEpoch:100/18\t PDE_loss=0.0238\t EX=1.0640\t\nEpoch:100/19\t PDE_loss=0.0327\t EX=1.0409\t\nEpoch:100/20\t PDE_loss=0.0230\t EX=0.9695\t\nEpoch:100/21\t PDE_loss=0.0052\t EX=0.8932\t\nEpoch:100/22\t PDE_loss=0.0017\t EX=0.8428\t\nEpoch:100/23\t PDE_loss=0.0119\t EX=0.8292\t\nEpoch:100/24\t PDE_loss=0.0186\t EX=0.8502\t\nEpoch:100/25\t PDE_loss=0.0137\t EX=0.8971\t\nEpoch:100/26\t PDE_loss=0.0070\t EX=0.9539\t\nEpoch:100/27\t PDE_loss=0.0079\t EX=0.9953\t\nEpoch:100/28\t PDE_loss=0.0115\t EX=0.9980\t\nEpoch:100/29\t PDE_loss=0.0089\t EX=0.9634\t\nEpoch:100/30\t PDE_loss=0.0026\t EX=0.9175\t\nEpoch:100/31\t PDE_loss=0.0006\t EX=0.8869\t\nEpoch:100/32\t PDE_loss=0.0032\t EX=0.8805\t\nEpoch:100/33\t PDE_loss=0.0059\t EX=0.8908\t\nEpoch:100/34\t PDE_loss=0.0060\t EX=0.9078\t\nEpoch:100/35\t PDE_loss=0.0043\t EX=0.9276\t\nEpoch:100/36\t PDE_loss=0.0034\t EX=0.9494\t\nEpoch:100/37\t PDE_loss=0.0040\t EX=0.9661\t\nEpoch:100/38\t 
PDE_loss=0.0044\t EX=0.9656\t\nEpoch:100/39\t PDE_loss=0.0032\t EX=0.9459\t\nEpoch:100/40\t PDE_loss=0.0013\t EX=0.9207\t\nEpoch:100/41\t PDE_loss=0.0006\t EX=0.9052\t\nEpoch:100/42\t PDE_loss=0.0012\t EX=0.9030\t\nEpoch:100/43\t PDE_loss=0.0019\t EX=0.9079\t\nEpoch:100/44\t PDE_loss=0.0018\t EX=0.9152\t\nEpoch:100/45\t PDE_loss=0.0013\t EX=0.9260\t\nEpoch:100/46\t PDE_loss=0.0012\t EX=0.9398\t\nEpoch:100/47\t PDE_loss=0.0017\t EX=0.9492\t\nEpoch:100/48\t PDE_loss=0.0020\t EX=0.9459\t\nEpoch:100/49\t PDE_loss=0.0016\t EX=0.9327\t\nEpoch:100/50\t PDE_loss=0.0009\t EX=0.9205\t\nEpoch:100/51\t PDE_loss=0.0007\t EX=0.9157\t\nEpoch:100/52\t PDE_loss=0.0008\t EX=0.9160\t\nEpoch:100/53\t PDE_loss=0.0009\t EX=0.9178\t\nEpoch:100/54\t PDE_loss=0.0007\t EX=0.9225\t\nEpoch:100/55\t PDE_loss=0.0005\t EX=0.9318\t\nEpoch:100/56\t PDE_loss=0.0006\t EX=0.9408\t\nEpoch:100/57\t PDE_loss=0.0008\t EX=0.9422\t\nEpoch:100/58\t PDE_loss=0.0009\t EX=0.9359\t\nEpoch:100/59\t PDE_loss=0.0008\t EX=0.9287\t\nEpoch:100/60\t PDE_loss=0.0007\t EX=0.9246\t\nEpoch:100/61\t PDE_loss=0.0006\t EX=0.9217\t\nEpoch:100/62\t PDE_loss=0.0007\t EX=0.9196\t\nEpoch:100/63\t PDE_loss=0.0007\t EX=0.9215\t\nEpoch:100/64\t PDE_loss=0.0005\t EX=0.9282\t\nEpoch:100/65\t PDE_loss=0.0004\t EX=0.9346\t\nEpoch:100/66\t PDE_loss=0.0005\t EX=0.9359\t\nEpoch:100/67\t PDE_loss=0.0005\t EX=0.9336\t\nEpoch:100/68\t PDE_loss=0.0005\t EX=0.9310\t\nEpoch:100/69\t PDE_loss=0.0005\t EX=0.9281\t\nEpoch:100/70\t PDE_loss=0.0005\t EX=0.9241\t\nEpoch:100/71\t PDE_loss=0.0005\t EX=0.9216\t\nEpoch:100/72\t PDE_loss=0.0006\t EX=0.9234\t\nEpoch:100/73\t PDE_loss=0.0005\t EX=0.9279\t\nEpoch:100/74\t PDE_loss=0.0005\t EX=0.9310\t\nEpoch:100/75\t PDE_loss=0.0005\t EX=0.9319\t\nEpoch:100/76\t PDE_loss=0.0005\t EX=0.9322\t\nEpoch:100/77\t PDE_loss=0.0005\t EX=0.9315\t\nEpoch:100/78\t PDE_loss=0.0004\t EX=0.9287\t\nEpoch:100/79\t PDE_loss=0.0004\t EX=0.9253\t\nEpoch:100/80\t PDE_loss=0.0004\t EX=0.9246\t\nEpoch:100/81\t PDE_loss=0.0005\t 
EX=0.9264\t\nEpoch:100/82\t PDE_loss=0.0005\t EX=0.9281\t\nEpoch:100/83\t PDE_loss=0.0004\t EX=0.9294\t\nEpoch:100/84\t PDE_loss=0.0004\t EX=0.9310\t\nEpoch:100/85\t PDE_loss=0.0005\t EX=0.9318\t\nEpoch:100/86\t PDE_loss=0.0005\t EX=0.9302\t\nEpoch:100/87\t PDE_loss=0.0004\t EX=0.9277\t\nEpoch:100/88\t PDE_loss=0.0004\t EX=0.9266\t\nEpoch:100/89\t PDE_loss=0.0004\t EX=0.9267\t\nEpoch:100/90\t PDE_loss=0.0004\t EX=0.9270\t\nEpoch:100/91\t PDE_loss=0.0004\t EX=0.9279\t\nEpoch:100/92\t PDE_loss=0.0004\t EX=0.9296\t\nEpoch:100/93\t PDE_loss=0.0004\t EX=0.9306\t\nEpoch:100/94\t PDE_loss=0.0004\t EX=0.9299\t\nEpoch:100/95\t PDE_loss=0.0004\t EX=0.9287\t\nEpoch:100/96\t PDE_loss=0.0004\t EX=0.9280\t\nEpoch:100/97\t PDE_loss=0.0004\t EX=0.9273\t\nEpoch:100/98\t PDE_loss=0.0004\t EX=0.9269\t\nEpoch:100/99\t PDE_loss=0.0004\t EX=0.9276\t\nstage 1 begin\nEpoch:1000/0\t Ito_loss=9.1885\t \nEpoch:1000/100\t Ito_loss=0.0640\t \nEpoch:1000/200\t Ito_loss=0.0786\t \nEpoch:1000/300\t Ito_loss=0.0616\t \nEpoch:1000/400\t Ito_loss=0.0676\t \nEpoch:1000/500\t Ito_loss=0.0789\t \nEpoch:1000/600\t Ito_loss=0.0665\t \nEpoch:1000/700\t Ito_loss=0.0716\t \nEpoch:1000/800\t Ito_loss=0.0754\t \nEpoch:1000/900\t Ito_loss=0.0840\t \nstage 2 begin\nEpoch:100/0\t PDE_loss=1.1653\t EX=0.0901\t\nEpoch:100/1\t PDE_loss=0.9792\t EX=0.5643\t\nEpoch:100/2\t PDE_loss=0.2886\t EX=1.5795\t\nEpoch:100/3\t PDE_loss=0.2731\t EX=1.3307\t\nEpoch:100/4\t PDE_loss=0.0645\t EX=0.8436\t\nEpoch:100/5\t PDE_loss=0.0605\t EX=0.6718\t\nEpoch:100/6\t PDE_loss=0.1693\t EX=0.7982\t\nEpoch:100/7\t PDE_loss=0.0822\t EX=1.0826\t\nEpoch:100/8\t PDE_loss=0.0029\t EX=1.3491\t\nEpoch:100/9\t PDE_loss=0.0738\t EX=1.3568\t\nEpoch:100/10\t PDE_loss=0.0831\t EX=1.1584\t\nEpoch:100/11\t PDE_loss=0.0099\t EX=0.9598\t\nEpoch:100/12\t PDE_loss=0.0151\t EX=0.8608\t\nEpoch:100/13\t PDE_loss=0.0533\t EX=0.8607\t\nEpoch:100/14\t PDE_loss=0.0573\t EX=0.9352\t\nEpoch:100/15\t PDE_loss=0.0263\t EX=1.0591\t\nEpoch:100/16\t PDE_loss=0.0012\t 
EX=1.1913\t\nEpoch:100/17\t PDE_loss=0.0185\t EX=1.2606\t\nEpoch:100/18\t PDE_loss=0.0367\t EX=1.2205\t\nEpoch:100/19\t PDE_loss=0.0208\t EX=1.1022\t\nEpoch:100/20\t PDE_loss=0.0014\t EX=0.9846\t\nEpoch:100/21\t PDE_loss=0.0097\t EX=0.9354\t\nEpoch:100/22\t PDE_loss=0.0218\t EX=0.9612\t\nEpoch:100/23\t PDE_loss=0.0197\t EX=1.0201\t\nEpoch:100/24\t PDE_loss=0.0092\t EX=1.0798\t\nEpoch:100/25\t PDE_loss=0.0008\t EX=1.1399\t\nEpoch:100/26\t PDE_loss=0.0096\t EX=1.1869\t\nEpoch:100/27\t PDE_loss=0.0138\t EX=1.1716\t\nEpoch:100/28\t PDE_loss=0.0094\t EX=1.0907\t\nEpoch:100/29\t PDE_loss=0.0010\t EX=1.0167\t\nEpoch:100/30\t PDE_loss=0.0049\t EX=1.0016\t\nEpoch:100/31\t PDE_loss=0.0080\t EX=1.0216\t\nEpoch:100/32\t PDE_loss=0.0090\t EX=1.0403\t\nEpoch:100/33\t PDE_loss=0.0031\t EX=1.0691\t\nEpoch:100/34\t PDE_loss=0.0019\t EX=1.1223\t\nEpoch:100/35\t PDE_loss=0.0037\t EX=1.1559\t\nEpoch:100/36\t PDE_loss=0.0067\t EX=1.1268\t\nEpoch:100/37\t PDE_loss=0.0033\t EX=1.0722\t\nEpoch:100/38\t PDE_loss=0.0016\t EX=1.0468\t\nEpoch:100/39\t PDE_loss=0.0019\t EX=1.0445\t\nEpoch:100/40\t PDE_loss=0.0043\t EX=1.0402\t\nEpoch:100/41\t PDE_loss=0.0034\t EX=1.0469\t\nEpoch:100/42\t PDE_loss=0.0018\t EX=1.0838\t\nEpoch:100/43\t PDE_loss=0.0009\t EX=1.1220\t\nEpoch:100/44\t PDE_loss=0.0026\t EX=1.1201\t\nEpoch:100/45\t PDE_loss=0.0027\t EX=1.0921\t\nEpoch:100/46\t PDE_loss=0.0018\t EX=1.0743\t\nEpoch:100/47\t PDE_loss=0.0008\t EX=1.0654\t\nEpoch:100/48\t PDE_loss=0.0017\t EX=1.0514\t\nEpoch:100/49\t PDE_loss=0.0021\t EX=1.0481\t\nEpoch:100/50\t PDE_loss=0.0017\t EX=1.0713\t\nEpoch:100/51\t PDE_loss=0.0008\t EX=1.0996\t\nEpoch:100/52\t PDE_loss=0.0012\t EX=1.1040\t\nEpoch:100/53\t PDE_loss=0.0016\t EX=1.0929\t\nEpoch:100/54\t PDE_loss=0.0015\t EX=1.0857\t\nEpoch:100/55\t PDE_loss=0.0008\t EX=1.0765\t\nEpoch:100/56\t PDE_loss=0.0010\t EX=1.0605\t\nEpoch:100/57\t PDE_loss=0.0012\t EX=1.0559\t\nEpoch:100/58\t PDE_loss=0.0013\t EX=1.0717\t\nEpoch:100/59\t PDE_loss=0.0008\t 
EX=1.0886\t\nEpoch:100/60\t PDE_loss=0.0008\t EX=1.0914\t\nEpoch:100/61\t PDE_loss=0.0010\t EX=1.0897\t\nEpoch:100/62\t PDE_loss=0.0011\t EX=1.0887\t\nEpoch:100/63\t PDE_loss=0.0008\t EX=1.0792\t\nEpoch:100/64\t PDE_loss=0.0008\t EX=1.0654\t\nEpoch:100/65\t PDE_loss=0.0009\t EX=1.0645\t\nEpoch:100/66\t PDE_loss=0.0010\t EX=1.0755\t\nEpoch:100/67\t PDE_loss=0.0008\t EX=1.0827\t\nEpoch:100/68\t PDE_loss=0.0008\t EX=1.0847\t\nEpoch:100/69\t PDE_loss=0.0009\t EX=1.0883\t\nEpoch:100/70\t PDE_loss=0.0009\t EX=1.0873\t\nEpoch:100/71\t PDE_loss=0.0008\t EX=1.0769\t\nEpoch:100/72\t PDE_loss=0.0007\t EX=1.0692\t\nEpoch:100/73\t PDE_loss=0.0008\t EX=1.0722\t\nEpoch:100/74\t PDE_loss=0.0008\t EX=1.0772\t\nEpoch:100/75\t PDE_loss=0.0008\t EX=1.0792\t\nEpoch:100/76\t PDE_loss=0.0007\t EX=1.0837\t\nEpoch:100/77\t PDE_loss=0.0008\t EX=1.0874\t\nEpoch:100/78\t PDE_loss=0.0008\t EX=1.0824\t\nEpoch:100/79\t PDE_loss=0.0007\t EX=1.0750\t\nEpoch:100/80\t PDE_loss=0.0008\t EX=1.0742\t\nEpoch:100/81\t PDE_loss=0.0008\t EX=1.0759\t\nEpoch:100/82\t PDE_loss=0.0008\t EX=1.0764\t\nEpoch:100/83\t PDE_loss=0.0007\t EX=1.0802\t\nEpoch:100/84\t PDE_loss=0.0007\t EX=1.0849\t\nEpoch:100/85\t PDE_loss=0.0008\t EX=1.0832\t\nEpoch:100/86\t PDE_loss=0.0008\t EX=1.0783\t\nEpoch:100/87\t PDE_loss=0.0007\t EX=1.0770\t\nEpoch:100/88\t PDE_loss=0.0007\t EX=1.0765\t\nEpoch:100/89\t PDE_loss=0.0008\t EX=1.0756\t\nEpoch:100/90\t PDE_loss=0.0007\t EX=1.0786\t\nEpoch:100/91\t PDE_loss=0.0007\t EX=1.0826\t\nEpoch:100/92\t PDE_loss=0.0007\t EX=1.0820\t\nEpoch:100/93\t PDE_loss=0.0007\t EX=1.0797\t\nEpoch:100/94\t PDE_loss=0.0007\t EX=1.0789\t\nEpoch:100/95\t PDE_loss=0.0007\t EX=1.0772\t\nEpoch:100/96\t PDE_loss=0.0007\t EX=1.0759\t\nEpoch:100/97\t PDE_loss=0.0007\t EX=1.0784\t\nEpoch:100/98\t PDE_loss=0.0007\t EX=1.0810\t\nEpoch:100/99\t PDE_loss=0.0007\t EX=1.0807\t\n","output_type":"stream"}]},{"cell_type":"code","source":"df = pd.DataFrame(columns=['Method','n','wi','vi','q','true_wi','true_vi','true_q'\n     
                     ,'error_wi','error_vi','error_q','wi sample ponints','vi sample ponints','wi time','vi time']\n                 )","metadata":{"_uuid":"4a5ba455-917b-4279-bf0f-5f5a03d0251d","_cell_guid":"2bc0898e-4033-48a1-ad6a-c41c36d68280","collapsed":false,"jupyter":{"outputs_hidden":false},"execution":{"iopub.status.busy":"2023-08-07T06:09:38.635797Z","iopub.execute_input":"2023-08-07T06:09:38.636151Z","iopub.status.idle":"2023-08-07T06:09:38.643044Z","shell.execute_reply.started":"2023-08-07T06:09:38.636120Z","shell.execute_reply":"2023-08-07T06:09:38.642080Z"},"trusted":true},"execution_count":90,"outputs":[]},{"cell_type":"code","source":"# One summary row per method, in the column order of df (defined in the previous cell).\n# Going by the columns they fill, W and V hold (estimate, sample count, time) for MCMC-C.\n# The three error_* entries are 0 placeholders that a later cell overwrites.\n\n# MCMC-R estimates, computed once and reused for wi, vi and q = vi / wi.\nw_mcmc_r = torch.mean(torch.cat(predict1, dim=1)).item()\nv_mcmc_r = torch.mean(torch.cat(predict2, dim=1)).item()\n\nm1 = ['MCMC-C', n, np.round(W[0], decimals=7), np.round(V[0], decimals=7), np.round(z, decimals=7),\n      true_w, true_v, true_q, 0, 0, 0, W[1], V[1], np.round(W[2], 5), np.round(V[2], 5)]\n\nm2 = ['MCMC-R', n, np.round(w_mcmc_r, 7), np.round(v_mcmc_r, 7), np.round(v_mcmc_r / w_mcmc_r, 7),\n      true_w, true_v, true_q, 0, 0, 0, cs1.shape[0]*cs1.shape[1], cs2.shape[0]*cs2.shape[1],\n      np.round(time1a, 3), np.round(time1b, 3)]\n\n# NOTE(review): MCMC-T reports the cs1/cs2 sample counts as well - presumably it is trained on the same samples; confirm.\nm3 = ['MCMC-T', n, np.round(EX1, 7), np.round(EX2, 7), np.round(EX2 / EX1, 7),\n      true_w, true_v, true_q, 0, 0, 0, cs1.shape[0]*cs1.shape[1], cs2.shape[0]*cs2.shape[1],\n      np.round(time2a, 3), np.round(time2b, 3)]","metadata":{"_uuid":"3bace2df-4aab-4523-8728-005b48895ea0","_cell_guid":"1ae878e3-f820-4898-b5a0-d7b9b3a7d1f8","collapsed":false,"jupyter":{"outputs_hidden":false},"execution":{"iopub.status.busy":"2023-08-07T06:09:39.128911Z","iopub.execute_input":"2023-08-07T06:09:39.130061Z","iopub.status.idle":"2023-08-07T06:09:39.141609Z","shell.execute_reply.started":"2023-08-07T06:09:39.130019Z","shell.execute_reply":"2023-08-07T06:09:39.140607Z"},"trusted":true},"execution_count":91,"outputs":[]},
{"cell_type":"code","source":"df.loc[0] = m1\ndf.loc[1] = m2\ndf.loc[2] = m3","metadata":{"_uuid":"74c5a275-41c1-409d-a4a8-0f55efd237d0","_cell_guid":"0344ed3b-ca81-4120-a26d-db3a4ed8b207","collapsed":false,"jupyter":{"outputs_hidden":false},"execution":{"iopub.status.busy":"2023-08-07T06:09:39.695541Z","iopub.execute_input":"2023-08-07T06:09:39.695880Z","iopub.status.idle":"2023-08-07T06:09:39.711407Z","shell.execute_reply.started":"2023-08-07T06:09:39.695850Z","shell.execute_reply":"2023-08-07T06:09:39.710441Z"},"trusted":true},"execution_count":92,"outputs":[]},{"cell_type":"code","source":"# Squared error of each estimate against its reference value (the true_* columns).\ndf['error_wi'] = (df['true_wi'] - df['wi'])**2\ndf['error_vi'] = (df['true_vi'] - df['vi'])**2\ndf['error_q'] = (df['true_q'] - df['q'])**2","metadata":{"_uuid":"55947ed4-5559-4818-a284-3e457e9f7c73","_cell_guid":"5e264bdd-ca6c-41af-9686-48f6655e523c","collapsed":false,"jupyter":{"outputs_hidden":false},"execution":{"iopub.status.busy":"2023-08-07T06:09:40.056772Z","iopub.execute_input":"2023-08-07T06:09:40.057082Z","iopub.status.idle":"2023-08-07T06:09:40.065428Z","shell.execute_reply.started":"2023-08-07T06:09:40.057054Z","shell.execute_reply":"2023-08-07T06:09:40.064340Z"},"trusted":true},"execution_count":93,"outputs":[]},{"cell_type":"code","source":"df","metadata":{"_uuid":"2cdf71ee-f2e4-4670-b050-a75924f31113","_cell_guid":"99cf2de1-4d28-4bce-acb0-8cd7c3746a79","collapsed":false,"jupyter":{"outputs_hidden":false},"execution":{"iopub.status.busy":"2023-08-07T06:09:41.001891Z","iopub.execute_input":"2023-08-07T06:09:41.002281Z","iopub.status.idle":"2023-08-07T06:09:41.019682Z","shell.execute_reply.started":"2023-08-07T06:09:41.002249Z","shell.execute_reply":"2023-08-07T06:09:41.018449Z"},"trusted":true},"execution_count":94,"outputs":[{"execution_count":94,"output_type":"execute_result","data":{"text/plain":"   Method  n        wi        vi         q  true_wi   true_vi    true_q  \\\n0  MCMC-C  3  0.934073  1.074439  1.150274  0.92264  1.083487  1.174333   \n1  MCMC-R  3  0.926999  1.077446  1.162295  0.92264  1.083487  1.174333   \n2  MCMC-T  
3  0.928355  1.079516  1.162827  0.92264  1.083487  1.174333   \n\n   error_wi  error_vi   error_q  wi sample ponints  vi sample ponints  \\\n0  0.000131  0.000082  0.000579              30700              30700   \n1  0.000019  0.000036  0.000145               2000               2000   \n2  0.000033  0.000016  0.000132               2000               2000   \n\n     wi time    vi time  \n0    3.40208    3.43238  \n1   20.16300   19.54400  \n2  236.10900  233.22500  ","text/html":"<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>Method</th>\n      <th>n</th>\n      <th>wi</th>\n      <th>vi</th>\n      <th>q</th>\n      <th>true_wi</th>\n      <th>true_vi</th>\n      <th>true_q</th>\n      <th>error_wi</th>\n      <th>error_vi</th>\n      <th>error_q</th>\n      <th>wi sample ponints</th>\n      <th>vi sample ponints</th>\n      <th>wi time</th>\n      <th>vi time</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>0</th>\n      <td>MCMC-C</td>\n      <td>3</td>\n      <td>0.934073</td>\n      <td>1.074439</td>\n      <td>1.150274</td>\n      <td>0.92264</td>\n      <td>1.083487</td>\n      <td>1.174333</td>\n      <td>0.000131</td>\n      <td>0.000082</td>\n      <td>0.000579</td>\n      <td>30700</td>\n      <td>30700</td>\n      <td>3.40208</td>\n      <td>3.43238</td>\n    </tr>\n    <tr>\n      <th>1</th>\n      <td>MCMC-R</td>\n      <td>3</td>\n      <td>0.926999</td>\n      <td>1.077446</td>\n      <td>1.162295</td>\n      <td>0.92264</td>\n      <td>1.083487</td>\n      <td>1.174333</td>\n      <td>0.000019</td>\n      <td>0.000036</td>\n      <td>0.000145</td>\n      <td>2000</td>\n      <td>2000</td>\n      <td>20.16300</td>\n      
<td>19.54400</td>\n    </tr>\n    <tr>\n      <th>2</th>\n      <td>MCMC-T</td>\n      <td>3</td>\n      <td>0.928355</td>\n      <td>1.079516</td>\n      <td>1.162827</td>\n      <td>0.92264</td>\n      <td>1.083487</td>\n      <td>1.174333</td>\n      <td>0.000033</td>\n      <td>0.000016</td>\n      <td>0.000132</td>\n      <td>2000</td>\n      <td>2000</td>\n      <td>236.10900</td>\n      <td>233.22500</td>\n    </tr>\n  </tbody>\n</table>\n</div>"},"metadata":{}}]},{"cell_type":"code","source":"","metadata":{},"execution_count":null,"outputs":[]}]}