#!/usr/bin/env python
# coding: utf-8

# In[ ]:


import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader 
import numpy as np
from scipy import io
import matplotlib.pyplot as plt
import argparse
import os
import copy
import time

def cal_domain_grad(model, XTGrid, device):
    """Gradient of the mean NLS residual loss with respect to every input point.

    The residual is the same as in ``PINN.physics_loss``::

        r_u = 0.5 * u_xx - v_t + (u**2 + v**2) * u
        r_v = 0.5 * v_xx + u_t + (u**2 + v**2) * v
        loss = mean(r_u**2 + r_v**2)

    Args:
        model: callable mapping an ``(N, 2)`` tensor of ``(x, t)`` points to an
            ``(N, 2)`` tensor whose columns are ``u`` and ``v``.
        XTGrid: ``(N, 2)`` tensor of points; it must require grad so autograd
            can differentiate with respect to it.
        device: device the points are moved to before evaluation.

    Returns:
        ``(N, 2)`` tensor whose row ``i`` is ``d loss / d (x_i, t_i)``. It does
        not carry an autograd graph: the samplers only use it as a direction.
    """
    XTGrid = XTGrid.to(device)

    def d(outputs):
        # d(outputs)/d(XTGrid) per point; the graph is kept so the result can
        # be differentiated again (second x-derivatives).
        g = torch.autograd.grad(outputs=outputs,
                                inputs=XTGrid,
                                grad_outputs=torch.ones_like(outputs),
                                create_graph=True,
                                allow_unused=True)[0]
        # allow_unused yields None when `outputs` does not depend on XTGrid
        # (e.g. the second derivative of a model linear in x): that is zero.
        return torch.zeros_like(XTGrid) if g is None else g

    UVf = model(XTGrid)
    uf, vf = UVf[:, 0], UVf[:, 1]

    # Column 0 of an input-gradient is d/dx, column 1 is d/dt.
    duf, dvf = d(uf), d(vf)
    uf_x, uf_t = duf[:, 0], duf[:, 1]
    vf_x, vf_t = dvf[:, 0], dvf[:, 1]
    uf_xx = d(uf_x)[:, 0]
    vf_xx = d(vf_x)[:, 0]

    modulus_sq = uf**2 + vf**2
    residual = (0.5*uf_xx - vf_t + modulus_sq*uf)**2 + (0.5*vf_xx + uf_t + modulus_sq*vf)**2
    loss = residual.mean()

    # The final gradient only steers the samplers, so no higher-order graph
    # is built (create_graph=False) and the forward graph is released here.
    grad = torch.autograd.grad(outputs=loss, inputs=XTGrid, allow_unused=True)[0]
    if grad is None:
        grad = torch.zeros_like(XTGrid)
    return grad

class LPINNSampler():
    """Langevin-dynamics collocation sampler (the 'lpinn' method).

    Each ``update`` moves every point ``L`` times by
    ``tau * g / |g| + beta * sqrt(2 * tau) * N(0, I)``, where ``g`` is the
    gradient returned by ``phy_lf``, and clamps the result to
    x in [-5, 5], t in [0, pi/2].

    ``patience``, ``min_delta``, ``initial``, ``traj`` and ``criterion`` are
    kept for interface compatibility; nothing in this class reads them.
    """

    def __init__(self, fixed_uniform, device, Nf, X_star, patience=10, min_delta=0.001, beta=0.2, tau=0.002, L=1):
        self.device = device
        self.cnt = 0        # update counter, advanced by the caller
        self.beta = beta    # scale of the injected Gaussian noise
        self.tau = tau      # Langevin step size
        self.L = L          # Langevin steps per update
        self.XTGrid = self._as_leaf(fixed_uniform)
        self.initial = self._as_leaf(fixed_uniform)
        self.Nf = Nf
        self.X_star = self._as_leaf(X_star)
        self.traj = []
        self.criterion = nn.MSELoss()

    def _as_leaf(self, data):
        """Return an independent float32 copy of *data* on ``self.device`` that requires grad."""
        # clone().detach() instead of torch.tensor(tensor), which copies but warns.
        copy_ = torch.as_tensor(data, dtype=torch.float32).detach().clone()
        return copy_.to(self.device).requires_grad_(True)

    def update(self, phy_lf, model):
        """Advance ``self.XTGrid`` by ``L`` Langevin steps.

        Args:
            phy_lf: callable ``phy_lf(model, samples, device)`` returning an
                ``(N, 2)`` gradient of the residual loss w.r.t. ``samples``
                (e.g. ``cal_domain_grad``).
            model: network forwarded to ``phy_lf``.
        """
        samples = self.XTGrid.clone().detach().requires_grad_(True)

        for _ in range(self.L):
            grad = phy_lf(model, samples, self.device)
            with torch.no_grad():
                # Unit-length drift; the 1e-16 shift keeps the norm non-zero
                # when the gradient vanishes. Done under no_grad so no graph
                # is built for the normalisation.
                norm = torch.sqrt(torch.sum((grad + 1e-16)**2, dim=1)).reshape(-1, 1)
                drift = grad / norm
                noise_scale = self.beta*torch.sqrt(torch.tensor(2 * self.tau, device=self.device))
                samples = samples + self.tau * drift + noise_scale * torch.randn(samples.shape, device=self.device)
                samples[:, 0] = torch.clamp(samples[:, 0], min=-5, max=5)          # x-axis clamping
                samples[:, 1] = torch.clamp(samples[:, 1], min=0, max=np.pi/2)     # t-axis clamping
            samples = samples.clone().detach().requires_grad_(True)

        self.XTGrid = samples.detach()

class L_INFSampler():
    """Adversarial sign-gradient collocation sampler (the 'l_inf' method).

    Every ``update`` draws ``Nf`` fresh uniform points over the domain and
    moves them ``n_iter`` times by ``step_size * sign(grad)``, clamping to the
    domain after each step.
    """

    # Domain bounds shared by the initial draw and the clamping, so the two
    # cannot drift apart (the draw previously used x in (-1, 1), t in (0, 1)).
    X_RANGE = (-5, 5)
    T_RANGE = (0, np.pi/2)

    def __init__(self, device, Nf, step_size = 0.05 , n_iter = 20):
        self.device = device
        self.step_size = step_size
        self.n_iter = n_iter
        self.Nf = Nf

    def update(self, grad_f, model):
        """Replace ``self.XTGrid`` with freshly drawn, gradient-ascended points.

        Args:
            grad_f: callable ``grad_f(model, samples, device)`` returning an
                ``(N, 2)`` gradient w.r.t. ``samples`` (e.g. ``cal_domain_grad``).
            model: network forwarded to ``grad_f``.
        """
        x_lo, x_hi = self.X_RANGE
        t_lo, t_hi = self.T_RANGE
        x_new = torch.zeros(self.Nf, 1, dtype=torch.float32, device=self.device).uniform_(x_lo, x_hi)
        t_new = torch.zeros(self.Nf, 1, dtype=torch.float32, device=self.device).uniform_(t_lo, t_hi)
        self.XTGrid = torch.cat((x_new, t_new), dim=1)

        samples = self.XTGrid.clone().detach().requires_grad_(True)
        for _ in range(self.n_iter):
            grad = grad_f(model, samples, self.device)
            with torch.no_grad():
                samples = samples + self.step_size * torch.sign(grad)
                samples[:, 0] = torch.clamp(samples[:, 0], min=x_lo, max=x_hi)  # x-axis clamping
                samples[:, 1] = torch.clamp(samples[:, 1], min=t_lo, max=t_hi)  # t-axis clamping
            samples = samples.clone().detach().requires_grad_(True)
        self.XTGrid = samples.detach()

class R3Sampler(nn.Module):
    """Residual-retaining collocation sampler (the 'r3' method).

    ``update`` keeps the points whose residual loss is above the mean and
    refills the set to ``Nf`` points with fresh uniform samples over
    x in [-5, 5], t in [0, pi/2].
    """

    def __init__(self,Nf, fixed_uniform, X_star, device):
        super(R3Sampler, self).__init__()

        self.Nf = Nf
        self.device = device
        # Independent float32 copy of the starting points, as a grad-requiring
        # leaf (clone().detach() instead of torch.tensor(tensor), which warns).
        self.XTGrid = torch.as_tensor(fixed_uniform, dtype=torch.float32, device=self.device) \
                           .detach().clone().requires_grad_(True)
        self.X_star = X_star  # stored only; not read by this class

    def update(self, loss):
        """Resample ``self.XTGrid`` from the per-point residual ``loss``.

        Args:
            loss: 1-D tensor with one residual value per row of ``self.XTGrid``.
        """
        with torch.no_grad():
            keep = loss > loss.mean()
            # Match the grid's device (indexing with a mask on another device
            # fails) without forcing a GPU->CPU copy every call.
            keep = keep.to(self.XTGrid.device)
            retained = self.XTGrid[keep].detach()

            n_new = self.Nf - retained.shape[0]
            x_new = torch.zeros(n_new, 1, dtype=torch.float32, device=self.device).uniform_(-5, 5)
            t_new = torch.zeros(n_new, 1, dtype=torch.float32, device=self.device).uniform_(0, np.pi/2)
            fresh = torch.cat((x_new, t_new), dim=1)

            grid = torch.cat((retained.to(device=self.device, dtype=torch.float32), fresh), dim=0)
            self.XTGrid = grid.detach().requires_grad_(True)
    
class RADSampler(nn.Module):
    """Residual-based adaptive-distribution collocation sampler (the 'rad' method).

    ``update`` draws ``Nf`` uniform candidates over x in [-5, 5],
    t in [0, pi/2], evaluates the squared NLS residual ``eps`` at each and
    resamples ``Nf`` of them (with replacement) with probability proportional
    to ``eps**k / mean(eps**k) + c``.
    """

    def __init__(self, k, c, device, Nf):
        super(RADSampler, self).__init__()
        self.device = device
        self.Nf = Nf
        self.k = k  # exponent sharpening the residual density
        self.c = c  # uniform floor added to the normalised density
        self.XTGrid = self._uniform_points().requires_grad_(True)

    def _uniform_points(self):
        """Return ``Nf`` points drawn uniformly from the domain as an ``(Nf, 2)`` float32 tensor."""
        x_new = torch.zeros(self.Nf, 1, dtype=torch.float32, device=self.device).uniform_(-5, 5)
        t_new = torch.zeros(self.Nf, 1, dtype=torch.float32, device=self.device).uniform_(0, np.pi/2)
        return torch.cat((x_new, t_new), dim=1)

    @staticmethod
    def _pde_error(model, XTGrid):
        """Per-point squared residual ``r_u**2 + r_v**2`` (same form as ``PINN.physics_loss``)."""
        def d(outputs):
            # d(outputs)/d(XTGrid); the graph is kept for the second derivative.
            return torch.autograd.grad(outputs=outputs,
                                       inputs=XTGrid,
                                       grad_outputs=torch.ones_like(outputs),
                                       create_graph=True,
                                       allow_unused=True)[0]

        UVf = model(XTGrid)
        uf, vf = UVf[:, 0], UVf[:, 1]
        duf, dvf = d(uf), d(vf)
        uf_x, uf_t = duf[:, 0], duf[:, 1]
        vf_x, vf_t = dvf[:, 0], dvf[:, 1]
        uf_xx = d(uf_x)[:, 0]
        vf_xx = d(vf_x)[:, 0]

        modulus_sq = uf**2 + vf**2
        r_u = 0.5*uf_xx - vf_t + modulus_sq*uf
        r_v = 0.5*vf_xx + uf_t + modulus_sq*vf
        # Both residual components are squared; previously r_v was added
        # unsquared, which distorted the sampling density.
        return r_u**2 + r_v**2

    def update(self, model):
        """Replace ``self.XTGrid`` with ``Nf`` residual-weighted samples.

        Args:
            model: network mapping ``(N, 2)`` ``(x, t)`` points to ``(N, 2)`` ``(u, v)``.
        """
        XTGrid = self._uniform_points().requires_grad_(True)
        err = self._pde_error(model, XTGrid).detach()

        err = (err**self.k)/((err**self.k).mean()) + self.c
        err_norm = err/(err.sum())

        indice = torch.multinomial(err_norm, self.Nf, replacement=True)
        # A fresh leaf: no autograd graph from the residual evaluation is kept.
        self.XTGrid = XTGrid[indice].detach().requires_grad_(True)

    
class PINN(nn.Module):
    """Physics-informed network for the nonlinear Schrodinger (NLS) equation.

    The network maps points ``(x, t)`` to ``(u, v)``, the real and imaginary
    parts of ``h``. ``train`` minimises a weighted sum of

    * an initial-condition loss on the ``t == 0`` rows of ``X_star``,
    * a periodic boundary loss (values and x-derivatives matched at the two
      spatial boundaries), and
    * the PDE residual loss on collocation points chosen by
      ``sampling_method`` ('fixed', 'rar', 'r3', 'lpinn', 'l_inf' or 'rad').

    Collocation points are drawn from x in [-5, 5], t in [0, pi/2].
    ``run_id`` (optional) tags the grid/error files saved during training.
    """

    def __init__(self,k , c , t, X_star, u_star, v_star, exact_u, exact_h, space_domain, time_domain, Layers, N0, Nb, Nf, Nt, beta, tau, L,
                 Activation = nn.Tanh(), model_name = "PINN.model", device = 'cpu', display_freq=1000, sampling_method = 'fixed',
                 run_id=None):
        super(PINN, self).__init__()

        LBs = [space_domain[0], time_domain[0]]
        UBs = [space_domain[1], time_domain[1]]

        # Lower / upper corners of the (x, t) domain.
        self.LBs = torch.tensor(LBs, dtype=torch.float32).to(device)
        self.UBs = torch.tensor(UBs, dtype=torch.float32).to(device)

        self.Layers = Layers
        self.in_dim = Layers[0]
        self.out_dim = Layers[-1]
        self.Activation = Activation

        self.device = device

        # Uniform collocation set: used as-is by 'fixed' and as the starting
        # set of the 'r3' and 'lpinn' samplers.
        x_init = torch.zeros(Nf, 1, dtype=torch.float32, device=self.device).uniform_(-5, 5)
        t_init = torch.zeros(Nf, 1, dtype=torch.float32, device=self.device).uniform_(0, np.pi/2)
        self.fixed_uniform = torch.cat((x_init, t_init), dim=1).requires_grad_(True)

        self.N0 = N0
        self.Nb = Nb
        self.Nf = Nf
        self.Nt = Nt
        self.beta = beta
        self.tau = tau
        self.L = L

        self.t = t
        self.X_star = X_star
        self.u_star = u_star
        self.exact_u = exact_u
        self.v_star = v_star
        self.exact_h = exact_h
        self.XT0, self.u0, self.v0 = self.initial_condition(self.LBs[0], self.UBs[0])

        self.XT0 = self.XT0.to(device)
        self.u0 = self.u0.to(device)
        self.v0 = self.v0.to(device)
        self.XTbL, self.XTbU = self.boundary_condition(self.LBs[0], self.UBs[0])

        self.XTbL = self.XTbL.to(device)
        self.XTbU = self.XTbU.to(device)

        self._nn = self.build_model()
        self._nn.to(self.device)
        self.Loss = torch.nn.MSELoss(reduction='mean')

        self.model_name = model_name
        self.display_freq = display_freq

        self.k = k
        self.c = c
        self.method = sampling_method
        self.run_id = run_id
        self.r3_sampler = R3Sampler(self.Nf, self.fixed_uniform, self.X_star, device=self.device)
        self.lpinn_sampler = LPINNSampler(fixed_uniform=self.fixed_uniform, device=self.device, Nf=self.Nf, beta=self.beta, tau=self.tau, L = self.L, X_star=self.X_star)
        self.l_inf_sampler = L_INFSampler(device=self.device, Nf=self.Nf)
        self.rad_sampler = RADSampler(k=self.k, c=self.c, device = self.device, Nf=self.Nf)

        # Output folder for the grids/errors saved in train(); an existing
        # folder is fine, any other OS error is surfaced.
        os.makedirs("../models/" + self.method, exist_ok=True)

    def build_model(self):
        """Build the MLP: Linear layers sized by ``Layers`` with ``Activation`` between them."""
        seq = nn.Sequential()
        for layer_idx in range(len(self.Layers)-1):
            this_module = nn.Linear(self.Layers[layer_idx], self.Layers[layer_idx+1])
            nn.init.xavier_normal_(this_module.weight)
            seq.add_module("Linear" + str(layer_idx), this_module)
            # No activation after the output layer.
            if not layer_idx == len(self.Layers)-2:
                seq.add_module("Activation" + str(layer_idx), self.Activation)
        return seq

    def forward(self, x):
        """Evaluate the network on ``x`` reshaped to ``(-1, in_dim)``; returns ``(-1, out_dim)``."""
        x = x.to(self.device)
        x = x.reshape((-1,self.in_dim))
        return torch.reshape(self._nn.forward(x), (-1, self.out_dim))

    @staticmethod
    def _grad(outputs, inputs):
        """Return d(outputs)/d(inputs) with unit grad_outputs, keeping the graph."""
        return torch.autograd.grad(outputs=outputs,
                                   inputs=inputs,
                                   grad_outputs=torch.ones_like(outputs),
                                   create_graph=True,
                                   allow_unused=True)[0]

    def initial_condition(self, LB, UB):
        """Select the reference points with ``LB <= x < UB`` and ``t == 0``.

        Returns:
            (XT0, u0, v0): the selected rows of ``X_star`` and the matching
            entries of ``u_star`` and ``v_star``.
        """
        LB = torch.as_tensor(LB).cpu()
        UB = torch.as_tensor(UB).cpu()

        indices = (self.X_star[:,0] >= LB) & (self.X_star[:,0] < UB) & (self.X_star[:,1] == 0.)
        return self.X_star[indices], self.u_star[indices], self.v_star[indices]

    def boundary_condition(self, LB, UB):
        """Build the boundary point sets at ``x = LB`` and ``x = UB``.

        One point per entry of ``self.t``, evenly spaced over the time domain
        ``[LBs[1], UBs[1])``. (Previously the times spanned ``[0, 1)``, which
        left the periodic condition unenforced on part of ``[0, pi/2]``.)

        Returns:
            (XTL, XTU): ``(len(t), 2)`` CPU tensors requiring grad, so
            x-derivatives can be taken in ``BC_loss``.
        """
        LB = torch.as_tensor(LB).cpu()
        UB = torch.as_tensor(UB).cpu()

        n_t = self.t.shape[0]
        t_lo, t_hi = float(self.LBs[1]), float(self.UBs[1])
        tb = torch.tensor(np.linspace(t_lo, t_hi, n_t, endpoint=False), dtype=torch.float32)
        XTL = torch.cat((LB*torch.ones((n_t, 1)), tb.reshape(-1, 1)), dim=1)
        XTL.requires_grad_()
        XTU = torch.cat((UB*torch.ones((n_t, 1)), tb.reshape(-1, 1)), dim=1)
        XTU.requires_grad_()

        return XTL, XTU

    def IC_loss(self):
        """MSE between predicted and reference ``(u, v)`` at the initial-condition points."""
        UV0_pred = self.forward(self.XT0)
        u0_pred = UV0_pred[:,0].reshape(-1)
        v0_pred = UV0_pred[:,1].reshape(-1)
        return self.Loss(u0_pred, self.u0) + self.Loss(v0_pred, self.v0)

    def BC_loss(self):
        """Periodic boundary loss: match u, v and their x-derivatives at both spatial ends."""
        UVb_L, UVb_U = self.forward(self.XTbL), self.forward(self.XTbU)
        ub_l, vb_l = UVb_L[:, 0], UVb_L[:, 1]
        ub_u, vb_u = UVb_U[:, 0], UVb_U[:, 1]
        ub_l_x = self._grad(ub_l, self.XTbL)[:, 0]
        vb_l_x = self._grad(vb_l, self.XTbL)[:, 0]
        ub_u_x = self._grad(ub_u, self.XTbU)[:, 0]
        vb_u_x = self._grad(vb_u, self.XTbU)[:, 0]
        return self.Loss(ub_l, ub_u) + self.Loss(vb_l, vb_u) + \
               self.Loss(ub_l_x, ub_u_x) + self.Loss(vb_l_x, vb_u_x)

    def classical_sampling(self):
        """Collocation points for the non-adaptive methods.

        'fixed' returns a copy of ``fixed_uniform``; 'rar' returns ``Nf`` fresh
        uniform points on every call.

        Raises:
            ValueError: for any other ``self.method`` (previously this fell
                through to an UnboundLocalError).
        """
        method = self.method
        if method == 'fixed':
            return self.fixed_uniform.detach().clone().requires_grad_(True)
        if method == 'rar':
            x_new = torch.zeros(self.Nf, 1, dtype=torch.float32, device=self.device).uniform_(-5, 5)
            t_new = torch.zeros(self.Nf, 1, dtype=torch.float32, device=self.device).uniform_(0, np.pi/2)
            return torch.cat((x_new, t_new), dim=1).requires_grad_(True)
        raise ValueError("Unknown sampling method: {!r}".format(method))

    def physics_loss(self, XTGrid):
        """Per-point squared NLS residual ``r_u**2 + r_v**2`` at ``XTGrid`` (shape ``(N,)``)."""
        XTGrid = XTGrid.to(self.device)
        UVf = self.forward(XTGrid)
        uf, vf = UVf[:, 0], UVf[:, 1]
        # Column 0 of an input-gradient is d/dx, column 1 is d/dt.
        duf, dvf = self._grad(uf, XTGrid), self._grad(vf, XTGrid)
        uf_x, uf_t = duf[:, 0], duf[:, 1]
        vf_x, vf_t = dvf[:, 0], dvf[:, 1]
        uf_xx = self._grad(uf_x, XTGrid)[:, 0]
        vf_xx = self._grad(vf_x, XTGrid)[:, 0]

        loss = (0.5*uf_xx - vf_t + (uf**2 + vf**2)*uf)**2 + (0.5*vf_xx + uf_t + (uf**2 + vf**2)*vf)**2
        return loss

    def _collocation_points(self, it, prev_loss):
        """Return this iteration's collocation points as a fresh grad-requiring leaf.

        ``prev_loss`` is the per-point residual of the previous iteration; only
        'r3' uses it (and only from the second iteration on).
        """
        if self.method == 'r3':
            if it != 0:
                with torch.no_grad():
                    self.r3_sampler.update(prev_loss)
            grid = self.r3_sampler.XTGrid
        elif self.method == 'lpinn':
            # The sampler moves the points on every iteration after the first.
            if it != 0:
                self.lpinn_sampler.update(cal_domain_grad, self._nn)
            grid = self.lpinn_sampler.XTGrid
            self.lpinn_sampler.cnt += 1
        elif self.method == 'l_inf':
            self.l_inf_sampler.update(cal_domain_grad, self._nn)
            grid = self.l_inf_sampler.XTGrid
        elif self.method == 'rad':
            self.rad_sampler.update(self._nn)
            grid = self.rad_sampler.XTGrid
        else:
            return self.classical_sampling()
        return grid.detach().clone().to(device=self.device, dtype=torch.float32).requires_grad_(True)

    def train(self, n_iters, weights=(1.0,1.0,1.0)):
        """Train with Adam (lr 1e-3, decayed by 0.9 every 5000 steps).

        NOTE: this overrides ``nn.Module.train(mode)``, so ``nn.Module.eval()``
        (which calls ``self.train(False)``) does not switch modes on this class.

        Args:
            n_iters: number of optimisation steps.
            weights: multipliers for (IC loss, BC loss, physics loss).

        Returns:
            (training_losses, rel_error): the total loss of every iteration and
            the relative L2 error (in %) of ``|h|`` on ``X_star``, recorded at
            index ``(it + 1) // display_freq`` every ``display_freq``
            iterations (unfilled slots stay -10).

        Side effects: saves the whole module to ``model_name`` whenever the
        total loss improves; on each display step saves the collocation points
        and the ``|h|`` error field under ``../models/<method>/``.
        """
        params = list(self.parameters())
        optimizer = optim.Adam(params, lr=1e-3)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size = 5000, gamma=0.9, last_epoch=-1)
        min_loss = 999999.0
        training_losses = [-10]*n_iters
        # Sized and indexed by display_freq (was hard-coded to 1000).
        rel_error = [-10]*(1 + n_iters//self.display_freq)

        run_tag = self.run_id
        if run_tag is None:
            # Legacy naming: the script's repetition counter `i` used to be
            # read as a module global; keep it when present, else use ''.
            run_tag = globals().get('i', '')

        start = time.time()
        loss = None
        for it in range(n_iters):
            total_ICloss = self.IC_loss()
            total_BCloss = self.BC_loss()

            XTGrid = self._collocation_points(it, loss)

            optimizer.zero_grad()
            loss = self.physics_loss(XTGrid)
            total_physics_loss = loss.mean()
            total_loss = weights[0]*total_ICloss + weights[1]*total_BCloss \
                         + weights[2]*total_physics_loss

            total_loss.backward()
            optimizer.step()
            scheduler.step()
            if total_loss < min_loss:
                torch.save(self, self.model_name)
                min_loss = float(total_loss)

            training_losses[it] = float(total_loss)

            if (it+1) % self.display_freq == 0:
                with torch.no_grad():
                    outputs = self.forward(self.X_star)
                    outputs_h = torch.sqrt(outputs[:, 0]**2 + outputs[:, 1]**2)
                    exact_h_T = self.exact_h.cpu().T
                    # Same layout as exact_h.T (was hard-coded to (201, 256)).
                    pred_h = outputs_h.reshape(exact_h_T.shape).cpu().detach()
                    h_error = exact_h_T - pred_h
                    rel_l2 = np.linalg.norm(h_error) / np.linalg.norm(exact_h_T)
                    rel_error[(it+1)//self.display_freq] = float(rel_l2*100)
                print("Iteration Number = {}".format(it+1))
                print("\tIC Loss = {}".format(float(total_ICloss)))
                print("\tBC Loss = {}".format(float(total_BCloss)))
                print("\tPhysics Loss = {}".format(float(total_physics_loss)))
                print("\tTraining Loss = {}".format(float(total_loss)))
                print("\tRelative L2 error (test) = {}".format(float(rel_l2*100)))

                suffix = str(it+1)+'_'+ str(self.beta)+'_'+ str(self.L)+'_'+str(run_tag)
                torch.save(XTGrid, "../models/"+self.method +'/'+str(self.Nf)+"_grid_"+suffix)
                # Error of the quantity actually predicted/evaluated (|h|);
                # previously exact_u (read from a script global) minus |h|.
                torch.save(h_error, "../models/"+self.method +'/'+str(self.Nf)+"_error_"+suffix)

        end = time.time()
        print(end-start)
        return training_losses, rel_error

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--nodes', type=int, default = 128, help='The number of nodes per hidden layer in the neural network')
    parser.add_argument('--N0', type=int, default = 100, help='The number of points to use on the initial condition')
    parser.add_argument('--Nb', type=int, default = 100, help='The number of points to use on the boundary condition')
    parser.add_argument('--Nf', type=int, default = 1000, help='The number of collocation points to use')
    parser.add_argument('--Nt', type=int, default = 1000, help='The number of points to use to calculate the MSE loss')
    parser.add_argument('--epochs', type=int, default = 200000, help='The number of epochs to train the neural network')
    parser.add_argument('--model-name', type=str, default='PINN_model', help='File name to save the model')
    parser.add_argument('--display-freq', type=int, default=1000, help='How often to display loss information')
    parser.add_argument('--layers', type=int, default = 4, help='The number of hidden layers in the neural network')
    # beta and tau are real-valued: with type=int any CLI value such as
    # `--beta 0.5` failed to parse (only the float defaults worked).
    parser.add_argument('--beta', type=float, default=0.2, help='High residual concentration parameter')
    parser.add_argument('--tau', type=float, default=0.002, help='Langevin step size')
    parser.add_argument('--L', type=int, default=1, help='The number of Langevin iteration')
    # Honoured now (it used to be ignored); the default 'rad' matches the
    # method the script previously hard-coded.
    parser.add_argument('--method', type=str, default='rad', help='sampling method')
    # NOTE(review): presumably swallows the `-f <connection file>` argument
    # passed when run inside a Jupyter kernel.
    parser.add_argument('-f')
    args = parser.parse_args()

    data = io.loadmat("../data/NLS.mat")
    t = torch.tensor(data['tt'], dtype = torch.float32).reshape(-1)
    x = torch.tensor(data['x'], dtype = torch.float32).reshape(-1)
    # Complex field h = u + i v: real part, imaginary part and modulus |h|.
    exact = data['uu']
    exact_u = torch.tensor(np.real(exact), dtype = torch.float32)
    exact_v = torch.tensor(np.imag(exact), dtype = torch.float32)
    exact_h = torch.sqrt(exact_u**2 + exact_v**2)
    # Every (x, t) pair as rows of X_star, with matching flattened targets.
    X, T = np.meshgrid(x,t)
    X_star = torch.tensor(np.hstack((X.flatten()[:,None], T.flatten()[:,None])), dtype = torch.float32)
    u_star = torch.flatten(torch.transpose(exact_u,0,1))
    v_star = torch.flatten(torch.transpose(exact_v,0,1))
    h_star = torch.flatten(torch.transpose(exact_h,0,1))

    os.makedirs("../models/", exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    NHiddenLayers = args.layers

    boundaries = [-5, 5]
    t_domain = [0., np.pi/2]

    Layers = [2] + [args.nodes]*NHiddenLayers + [2]
    Activation = nn.Tanh()

    # RAD density parameters (exponent k, floor c).
    k = 1
    c = 1

    repeat = [0,1,2,3,4]
    sampling_me = args.method

    # NOTE: do not rename `i` (or `exact_u`): PINN.train can fall back to
    # reading these module globals when naming/saving its output files.
    for i in repeat:
        pinn = PINN(  k = k,
                      c = c,
                      t= t,
                      X_star = X_star,
                      u_star = u_star,
                      v_star = v_star,
                      exact_u = exact_u,
                      exact_h = exact_h,
                      space_domain = boundaries,
                      time_domain = t_domain,
                      Layers = Layers,
                      N0 = args.N0,
                      Nb = args.Nb,
                      Nf = args.Nf,
                      Nt = args.Nt,
                      beta = args.beta,
                      tau = args.tau,
                      L = args.L,
                      Activation = Activation,
                      device = device,
                      model_name = "../models/" + args.model_name + ".model_"+sampling_me+'_'+str(args.layers)+'_'+str(args.Nf)+'_'+ str(args.beta)+'_'+ str(args.L)+'_' +str(i),
                      display_freq = args.display_freq, sampling_method = sampling_me )

        Losses_train, Losses_rel_l2 = pinn.train(args.epochs, weights = (100, 1, 1)) # initial, boundary, residual

        torch.save(Losses_train, "../models/" + args.model_name + ".loss_"+sampling_me+'_'+str(args.layers)+'_'+str(args.Nf)+'_'+ str(args.beta)+'_'+ str(args.L) + '_'+str(i))
        torch.save(Losses_rel_l2, "../models/" + args.model_name + ".rel_l2_"+sampling_me+'_'+str(args.layers)+'_'+str(args.Nf)+'_'+ str(args.beta)+'_'+ str(args.L)+ '_'+ str(i))

