# -*- coding: utf-8 -*-
"""
Spyder Editor

This is a temporary script file.
"""

import torch
from torch import nn
import matplotlib.pyplot as plt
import torch.nn.functional as func
import numpy as np
import math
import h5py
import os
import time

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
_DT = 1.     # time span of one data sample (other values tried: 0.01, 0.1)
NOISE = 0.0  # amplitude of the uniform noise added to training targets (other values tried: 0.1, 0.8)

# Every generated file (data, checkpoints, results) lives in a folder named
# after the (_DT, NOISE) pair.
ROOT = f"{_DT}_{NOISE}"
os.makedirs(ROOT, exist_ok=True)

# Integrator sub-step sizes and sample time span.
epsT = 0.2   # sub-step used when integrating the analytic system (keep <= epsD; 0.01 was also used)
epsN = 0.2   # sub-step used when integrating with a network (0.01 was also used)
epsD = _DT   # time span between the two states of a training sample
N = 1        # size of the last axis of q and p

# Use the first GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Validation set: number of random initial states and the time span to predict.
points_validation = 200
dtp_validation = torch.tensor([_DT])

# Short test trajectory: plot_points steps covering a total time of lt.
plot_points = 10
lt = 0.5
dtp = torch.tensor([lt / plot_points])

# Network width and number of generated samples.
hidden_dim = 64
n_sample = 1280

# Long trajectory used for plotting: plot_Np steps covering a total time of lt_Np.
lt_Np = 200
plot_Np = 1000
dt_Np = torch.tensor([lt_Np / plot_Np])

# Initial learning rate of the Adam optimizer.
l_r = 0.05
def to_np(x):
    """Return tensor ``x`` as a NumPy array, detached from autograd and moved to the CPU."""
    host_tensor = x.detach().cpu()
    return host_tensor.numpy()


def analyH(q, p):
    """Analytic Hamiltonian H(q, p) = (p^2 + q^2) / 2, evaluated element-wise."""
    p_squared = p**2
    q_squared = q**2
    return 0.5 * (p_squared + q_squared)
class LinearBlock(nn.Module):
    """Fully connected layer followed by a sigmoid activation."""

    def __init__(self, inchannel, outchannel):
        super().__init__()
        # The attribute name ``left`` is part of the state_dict keys
        # ("left.0.weight", "left.0.bias"), so it must stay as is.
        affine = nn.Linear(inchannel, outchannel)
        activation = nn.Sigmoid()
        self.left = nn.Sequential(affine, activation)

    def forward(self, x):
        """Apply the affine map and the sigmoid to ``x``."""
        return self.left(x)

class KTrained_baseline(nn.Module):
    """Unstructured baseline: an MLP mapping the state (q, p) to 2N outputs.

    The outputs are returned as two N-wide halves and are consumed by the
    integrators in place of the pair (dH/dq, dH/dp); no Hamiltonian is
    differentiated here.

    Args:
        N: size of the last axis of ``q`` and ``p``.
        hidden_dim: width of every hidden layer.
    """

    def __init__(self, N, hidden_dim):
        super(KTrained_baseline, self).__init__()
        self.N = N
        self.cal_H = nn.Sequential(LinearBlock(2 * self.N, hidden_dim),
                                    LinearBlock(hidden_dim, hidden_dim),
                                    LinearBlock(hidden_dim, hidden_dim),
                                    LinearBlock(hidden_dim, hidden_dim),
                                    LinearBlock(hidden_dim, hidden_dim),
                                    nn.Linear(hidden_dim, 2*self.N))
        # Assign the Parameter itself. Calling ``.to(device)`` on it returns a
        # plain (non-Parameter) tensor whenever the device differs, which would
        # leave ``b`` unregistered: missing from ``parameters()`` (never
        # optimized) and from ``state_dict()``. ``module.to(device)`` moves it.
        # NOTE(review): checkpoints written while ``b`` was unregistered lack
        # the "b" key and would need ``strict=False`` to load.
        self.b = nn.Parameter(torch.zeros(1, 1, 2 * self.N), requires_grad=True)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # Uniform weights in [-sqrt(6 / fan_in), sqrt(6 / fan_in)].
                bound = math.sqrt(6. / m.in_features)
                m.weight.data.uniform_(-bound, bound)

    def forward(self, q, p):
        """Return the two N-wide halves of the network output for state (q, p).

        ``q`` and ``p`` are concatenated along dim 2, so they are expected to
        be 3-D tensors — presumably (batch, 1, N); confirm against callers.
        """
        # Autograd is enabled explicitly, so the output carries a graph even
        # when the caller runs under torch.no_grad().
        with torch.enable_grad():
            x = torch.cat([q, p], dim=2)
            x = x.requires_grad_(True)
            K = self.cal_H(x)+self.b
        return K[:, :, :self.N], K[:, :, self.N:self.N * 2]
    

class KTrained(nn.Module):
    """Scalar network H(q, p) whose input-gradient serves as the vector field.

    ``cal_H`` maps the concatenated state to a single value; both forward
    variants return the gradient of that value with respect to the state,
    split into its q-part and p-part.
    """

    def __init__(self, N, hidden_dim):
        super().__init__()
        self.N = N
        layers = [LinearBlock(2 * self.N, hidden_dim)]
        layers.extend(LinearBlock(hidden_dim, hidden_dim) for _ in range(4))
        layers.append(nn.Linear(hidden_dim, 1))
        self.cal_H = nn.Sequential(*layers)
        # Re-draw every linear weight uniformly in +-sqrt(6 / fan_in).
        for module in self.modules():
            if not isinstance(module, nn.Linear):
                continue
            limit = math.sqrt(6. / module.in_features)
            module.weight.data.uniform_(-limit, limit)

    def _state_gradient(self, q, p, create_graph):
        """Differentiate the scalar network output with respect to (q, p)."""
        with torch.enable_grad():
            state = torch.cat([q, p], dim=2).requires_grad_(True)
            energy = self.cal_H(state.squeeze(1))
            (grad_state,) = torch.autograd.grad(
                energy.sum(), state, retain_graph=True, create_graph=create_graph
            )
        n = self.N
        return grad_state[:, :, :n], grad_state[:, :, n:n * 2]

    def forward_train(self, q, p):
        """Gradient pair with a differentiable graph (usable inside a loss)."""
        return self._state_gradient(q, p, create_graph=True)

    def forward(self, q, p):
        """Gradient pair without a higher-order graph (inference)."""
        return self._state_gradient(q, p, create_graph=False)

class KAnalysis(nn.Module):
    """Parameter-free module returning the gradients of the analytic Hamiltonian."""

    def __init__(self):
        super().__init__()

    def forward(self, q, p):
        """Return (dH/dq, dH/dp) of ``analyH`` at (q, p).

        Note: ``requires_grad`` is switched on in place on the tensors passed in.
        """
        with torch.enable_grad():
            q = q.requires_grad_(True)
            p = p.requires_grad_(True)
            total_energy = analyH(q, p).sum()
            # One backward pass yields both partial derivatives.
            grad_q, grad_p = torch.autograd.grad(
                total_energy, (q, p), retain_graph=True, create_graph=False
            )
        return grad_q, grad_p
    
def RK2(q, p, dt, K_t, eps):
    """Advance (q, p) by ``dt`` with the explicit midpoint (second-order RK) rule.

    Args:
        q, p: state tensors; they must broadcast against ``dt`` reshaped to
            (len(dt), 1, 1) — presumably (batch, 1, N), confirm against callers.
        dt: 1-D tensor of time spans. All entries are integrated with the same
            number of sub-steps, each entry using its own step ``dt / n_steps``.
        K_t: callable ``K_t(q, p) -> (dH/dq, dH/dp)``.
        eps: target sub-step size; the sub-step count is ``round(max|dt| / eps)``.

    Returns:
        Tuple ``(q, p)`` after the time span ``dt``.
    """
    # Always take at least one sub-step. Without the clamp a span shorter than
    # eps / 2 rounds to zero steps: ``dt / 0`` is computed and the state is
    # returned unchanged, silently skipping the integration.
    n_steps = max(1, int(np.round((torch.abs(dt) / eps).max().item())))
    h = (dt / n_steps).unsqueeze(1).unsqueeze(1)
    for _ in range(n_steps):
        # Half step to the midpoint (dq/dt = dH/dp, dp/dt = -dH/dq).
        dH_dq, dH_dp = K_t(q, p)
        q_mid = q + 0.5 * dH_dp * h
        p_mid = p - 0.5 * dH_dq * h
        # Full step using the slope evaluated at the midpoint.
        dH_dq, dH_dp = K_t(q_mid, p_mid)
        q = q + dH_dp * h
        p = p - dH_dq * h
    return q, p

def Nonsep_SymInt(q, p, x, y, dt, K_t, eps, w=2000):
    """Advance the extended state (q, p, x, y) by ``dt`` with an explicit symplectic scheme.

    The system is integrated on a doubled phase space: (x, y) is a copy of
    (q, p), the Hamiltonian is evaluated at the mixed points (q, y) and (x, p),
    and a rotation of strength ``w`` couples the two copies. Each sub-step is
    the symmetric composition A(h/2) B(h/2) C(h) B(h/2) A(h/2); the structure
    matches Tao's integrator for non-separable Hamiltonians.

    Args:
        q, p, x, y: state tensors; they must broadcast against ``dt`` reshaped
            to (len(dt), 1, 1) — presumably (batch, 1, N), confirm against callers.
        dt: 1-D tensor of time spans; every entry uses the same sub-step count.
        K_t: callable ``K_t(a, b) -> (dH/da, dH/db)``.
        eps: target sub-step size; the sub-step count is ``round(max|dt| / eps)``.
        w: coupling strength between the two phase-space copies (default 2000,
            the value that used to be hard-coded).

    Returns:
        Tuple ``(q, p, x, y)`` after the time span ``dt``.
    """
    # Always take at least one sub-step. Without the clamp a span shorter than
    # eps / 2 rounds to zero steps and the state is returned unintegrated.
    n_steps = max(1, int(np.round((torch.abs(dt) / eps).max().item())))
    h = (dt / n_steps).unsqueeze(1).unsqueeze(1)
    # The rotation angle is the same in every sub-step, so evaluate it once.
    cos_wh = torch.cos(2 * w * h)
    sin_wh = torch.sin(2 * w * h)
    for _ in range(n_steps):
        # A, half step: flow of H(q, y) updates p and x.
        x1, p1 = K_t(q, y)
        p = p - x1 * h * 0.5
        x = x + p1 * h * 0.5
        # B, half step: flow of H(x, p) updates q and y.
        q1, y1 = K_t(x, p)
        q = q + y1 * h * 0.5
        y = y - q1 * h * 0.5
        # C, full step: rotate the half-differences, keep the means.
        q1 = 0.5 * (q - x)
        p1 = 0.5 * (p - y)
        x1 = cos_wh * q1 + sin_wh * p1
        y1 = -sin_wh * q1 + cos_wh * p1
        q1 = 0.5 * (q + x)
        p1 = 0.5 * (p + y)
        q = q1 + x1
        p = p1 + y1
        x = q1 - x1
        y = p1 - y1
        # B, half step.
        q1, y1 = K_t(x, p)
        q = q + y1 * h * 0.5
        y = y - q1 * h * 0.5
        # A, half step.
        x1, p1 = K_t(q, y)
        p = p - x1 * h * 0.5
        x = x + p1 * h * 0.5
    return q, p, x, y


def Gen_Data():
    """Generate and store all training, test and validation data under ROOT.

    Files written:
        data.h5          'u': (n_sample, 4, N, 2) start/end states [q, p, x, y]
                         with (x, y) a copy of (q, p); 'dt': time span per sample.
        data_hnn.h5      'u': (n_sample, 2, N, 2) start state [q, p] and the
                         finite-difference targets [-dp/dt, dq/dt].
        test.dat         trajectory of plot_points steps from (q, p) = (0, -3).
        validation0.dat  points_validation random start states.
        validation1.dat  the matching states after dtp_validation.

    Random numbers are drawn in the order q0, p0, noise on q, noise on p per
    sample (the noise draws happen even when NOISE is 0).
    """
    # The analytic gradient module has no state, so one instance serves all
    # three stages (and exists even if n_sample is 0).
    f_true = KAnalysis()
    f_true.eval()

    # ---- training samples -------------------------------------------------
    datau = [[] for _ in range(2)]
    datau_hnn = [[] for _ in range(2)]
    datat = []
    for _ in range(n_sample):
        t1_t2 = torch.tensor([epsD])
        q0 = 8.*(torch.rand(1, 1, N)-0.5)
        p0 = 8.*(torch.rand(1, 1, N)-0.5)
        datau_hnn[0].append(torch.cat([q0, p0], dim=1).unsqueeze(-1))
        # The auxiliary copy (x, y) starts equal to (q, p).
        datau[0].append(torch.cat([q0, p0, q0, p0], dim=1).unsqueeze(-1))
        with torch.no_grad():
            q1, p1, _x1, _y1 = Nonsep_SymInt(q0, p0, q0, p0, t1_t2, f_true.forward, epsT)
            q1 = q1 + NOISE*(torch.rand(1, 1, N)-0.5)  # uniform noise on the targets
            p1 = p1 + NOISE*(torch.rand(1, 1, N)-0.5)
            # Finite-difference derivatives over the sample span.
            dq0 = (q1 - q0)/t1_t2
            dp0 = (p1 - p0)/t1_t2
            # Stored as [-dp/dt, dq/dt], i.e. the targets for (dH/dq, dH/dp).
            datau_hnn[1].append(torch.cat([-dp0, dq0], dim=1).unsqueeze(-1))
            datau[1].append(torch.cat([q1, p1, q1, p1], dim=1).unsqueeze(-1))
            datat.append(t1_t2)

    datau = [torch.cat(datau[j]) for j in range(2)]
    datau = torch.cat(datau, dim = -1).float()
    datat = torch.tensor(datat).float()
    # Context managers close the files even if writing a dataset fails.
    with h5py.File(os.path.join(ROOT, "data.h5"), "w") as hf:
        hf.create_dataset('u', data=datau)
        hf.create_dataset('dt', data=datat)

    datau_hnn = [torch.cat(datau_hnn[j]) for j in range(2)]
    datau_hnn = torch.cat(datau_hnn, dim = -1).float()
    with h5py.File(os.path.join(ROOT, "data_hnn.h5"), "w") as hf_hnn:
        hf_hnn.create_dataset('u', data=datau_hnn)

    # ---- short test trajectory --------------------------------------------
    qT = torch.tensor([[[0.]]])
    pT = torch.tensor([[[-3.]]])
    xT = qT
    yT = pT
    qpxyT = [torch.cat([qT, pT, xT, yT], dim=1)]
    print('test data')
    with torch.no_grad():
        # Each call advances by dtp; how many sub-steps that takes is decided
        # inside Nonsep_SymInt from dtp and epsT.
        for _ in range(plot_points):
            qT, pT, xT, yT = Nonsep_SymInt(qT, pT, xT, yT, dtp, f_true.forward, epsT)
            qpxyT.append(torch.cat([qT, pT, xT, yT], dim=1))
    torch.save(torch.cat(qpxyT), os.path.join(ROOT, 'test.dat'))

    # ---- validation pairs -------------------------------------------------
    qpxyT0 = []
    qpxyT1 = []
    print('validation data')
    with torch.no_grad():
        for _ in range(points_validation):
            qT0 = 8.*(torch.rand(1, 1, N)-0.5)
            pT0 = 8.*(torch.rand(1, 1, N)-0.5)
            xT0 = qT0
            yT0 = pT0
            qT1, pT1, xT1, yT1 = Nonsep_SymInt(qT0, pT0, xT0, yT0, dtp_validation, f_true.forward, epsT)
            qpxyT0.append(torch.cat([qT0, pT0, xT0, yT0], dim=1))
            qpxyT1.append(torch.cat([qT1, pT1, xT1, yT1], dim=1))
    torch.save(torch.cat(qpxyT0), os.path.join(ROOT, 'validation0.dat'))
    torch.save(torch.cat(qpxyT1), os.path.join(ROOT, 'validation1.dat'))

class Dataset_HNN(torch.utils.data.Dataset):
    """Derivative-matching samples read from ``data_hnn.h5`` under ROOT.

    Args:
        data_type: ``'train'`` selects the first 90% of the samples; any other
            value selects the remaining 10%.
    """

    def __init__(self, data_type):
        # Open read-only and close the file as soon as the array is in memory,
        # also when reading fails.
        with h5py.File(os.path.join(ROOT, 'data_hnn.h5'), 'r') as f:
            samples = f['u'][:]
        split = int(samples.shape[0] * 0.9)
        chosen = samples[:split] if data_type == 'train' else samples[split:]
        # The whole split is kept on the target device.
        self.u = torch.from_numpy(chosen).to(device)

    def __getitem__(self, index):
        """Return sample ``index``."""
        return self.u[index]

    def __len__(self):
        """Number of samples in this split."""
        return self.u.shape[0]
        
class Dataset(torch.utils.data.Dataset):
    """State-pair samples and their time spans read from ``data.h5`` under ROOT.

    Args:
        data_type: ``'train'`` selects the first 90% of the samples; any other
            value selects the remaining 10%.
    """

    def __init__(self, data_type):
        # Open read-only and close the file as soon as both arrays are in
        # memory, also when reading fails.
        with h5py.File(os.path.join(ROOT, 'data.h5'), 'r') as f:
            states = f['u'][:]
            spans = f['dt'][:]
        split = int(states.shape[0] * 0.9)
        part = slice(None, split) if data_type == 'train' else slice(split, None)
        # The whole split is kept on the target device.
        self.u = torch.from_numpy(states[part]).to(device)
        self.dt = torch.from_numpy(spans[part]).to(device)

    def __getitem__(self, index):
        """Return ``(states, time span)`` of sample ``index``."""
        return self.u[index], self.dt[index]

    def __len__(self):
        """Number of samples in this split."""
        return self.u.shape[0]



# Networks under comparison. They are built in a fixed order (hnn, nssnn,
# nssnn_rk, baseline) because each construction draws random initial weights.
f_neur_hnn = KTrained(N, hidden_dim).to(device)            # derivative matching
f_neur_nssnn = KTrained(N, hidden_dim).to(device)          # trained through Nonsep_SymInt
f_neur_nssnn_rk = KTrained(N, hidden_dim).to(device)       # trained through RK2
f_neur_baseline = KTrained_baseline(N, hidden_dim).to(device)  # unstructured baseline

# Training and test criterion (func.mse_loss is the alternative).
loss_func = func.l1_loss


def _train_batch_loss(model, net, data_batch, training):
    """Return the loss of ``net`` on one batch for the scheme named ``model``.

    Data tensors are laid out as (batch, component, N, time), with time index
    0 holding the start state and 1 the target.
    """
    # Only the Hamiltonian networks have a separate graph-building variant;
    # the baseline always uses its plain forward.
    if model == "model_baseline.pt":
        field = net.forward
    else:
        field = net.forward_train if training else net.forward

    if model == "model_hnn.pt":
        # Derivative matching: compare predicted gradients with the targets.
        qpxyT = data_batch
        q0T = qpxyT[:, 0:1, :, 0]
        p0T = qpxyT[:, 1:2, :, 0]
        dq0T = qpxyT[:, 0:1, :, 1]
        dp0T = qpxyT[:, 1:2, :, 1]
        dq0N, dp0N = field(q0T, p0T)
        return loss_func(dq0T, dq0N) + loss_func(dp0T, dp0N)

    # Trajectory matching: integrate over dt and compare the end states.
    qpxyT, dt = data_batch
    q0T = qpxyT[:, 0:1, :, 0]
    p0T = qpxyT[:, 1:2, :, 0]
    if model == "model_nssnn.pt":
        x0T = qpxyT[:, 2:3, :, 0]
        y0T = qpxyT[:, 3:4, :, 0]
        q1N, p1N, x1N, y1N = Nonsep_SymInt(q0T, p0T, x0T, y0T, dt, field, epsN)
        prediction = torch.cat([q1N, p1N, x1N, y1N], dim=1)
        target = qpxyT[:, :, :, 1]
    elif model == "model_nssnn_rk.pt":
        q1N, p1N = RK2(q0T, p0T, dt, field, epsN)
        prediction = torch.cat([q1N, p1N], dim=1)
        # RK2 has no auxiliary copy, so only q and p are compared.
        target = qpxyT[:, 0:2, :, 1]
    else:  # "model_baseline.pt"
        q1N, p1N = RK2(q0T, p0T, dt, field, epsN)
        prediction = torch.cat([q1N, p1N, q1N, p1N], dim=1)
        target = qpxyT[:, :, :, 1]
    return loss_func(target, prediction)


def train(model):
    """Train the network selected by ``model`` and keep its best checkpoint.

    Args:
        model: one of ``"model_baseline.pt"``, ``"model_hnn.pt"``,
            ``"model_nssnn.pt"``, ``"model_nssnn_rk.pt"``. The name selects the
            network and the training scheme, and is the checkpoint file name
            under ROOT.

    Raises:
        ValueError: if ``model`` is not one of the names above.

    Runs 401 epochs of Adam with a step learning-rate decay (x0.8 every 10
    epochs). After every epoch the held-out split is evaluated, and the
    state_dict is saved whenever the mean test loss improves.
    """
    networks = {
        "model_baseline.pt": f_neur_baseline,
        "model_hnn.pt": f_neur_hnn,
        "model_nssnn.pt": f_neur_nssnn,
        "model_nssnn_rk.pt": f_neur_nssnn_rk,
    }
    if model not in networks:
        # Fail early with a clear message instead of an unbound-variable error.
        raise ValueError("unknown model name: %r (expected one of %s)"
                         % (model, sorted(networks)))
    net = networks[model]
    model_path = os.path.join(ROOT, model)

    optimizer = torch.optim.Adam(net.parameters(), lr=l_r)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)

    dataset_cls = Dataset_HNN if model == "model_hnn.pt" else Dataset
    train_data_loader = torch.utils.data.DataLoader(dataset_cls('train'), batch_size=512, shuffle=True)
    test_data_loader = torch.utils.data.DataLoader(dataset_cls('test'), batch_size=512, shuffle=True)

    lowest_test_loss = 99999
    for i in range(401):
        train_loss = 0
        train_sample = 0
        test_loss = 0
        test_sample = 0

        net.train()
        for data_batch in train_data_loader:
            optimizer.zero_grad()
            loss = _train_batch_loss(model, net, data_batch, training=True)
            train_loss += loss.item()
            train_sample += 1
            loss.backward()
            optimizer.step()
        scheduler.step()

        net.eval()
        with torch.no_grad():
            for data_batch in test_data_loader:
                loss = _train_batch_loss(model, net, data_batch, training=False)
                test_loss += loss.item()
                test_sample += 1

        if i % 100 == 0:
            print(model)
            print(i, train_loss / train_sample, test_loss / test_sample)
            # NOTE(review): this file goes to the current working directory
            # (not ROOT) and is overwritten at every log point — confirm that
            # this is intended.
            with open(model+'loss.dat','w') as f:
                f.write(str(train_loss / train_sample))
                f.write('\t\t')
                f.write(str(test_loss / test_sample))

        # Keep the checkpoint with the lowest mean test loss seen so far.
        if lowest_test_loss > test_loss / test_sample:
            torch.save(net.state_dict(), model_path)
            lowest_test_loss = test_loss / test_sample

def validation(model):
    """Evaluate a trained model on the stored validation pairs.

    Args:
        model: one of ``"model_baseline.pt"``, ``"model_hnn.pt"``,
            ``"model_hnn_tao.pt"``, ``"model_nssnn.pt"``,
            ``"model_nssnn_rk.pt"``. ``"model_hnn_tao.pt"`` reuses the
            ``model_hnn.pt`` checkpoint but integrates with Nonsep_SymInt.

    Raises:
        ValueError: if ``model`` is not one of the names above.

    Prints the mean relative energy deviation (measured with ``analyH``) and
    the mean absolute state error, and writes both to
    ``ROOT/<model>Hamiltonian.dat``.
    """
    # name -> (network, checkpoint file, integrate with Nonsep_SymInt?)
    setups = {
        "model_baseline.pt": (f_neur_baseline, "model_baseline.pt", False),
        "model_hnn.pt": (f_neur_hnn, "model_hnn.pt", False),
        "model_hnn_tao.pt": (f_neur_hnn, "model_hnn.pt", True),
        "model_nssnn.pt": (f_neur_nssnn, "model_nssnn.pt", True),
        "model_nssnn_rk.pt": (f_neur_nssnn_rk, "model_nssnn_rk.pt", False),
    }
    if model not in setups:
        # Fail before any file is read instead of with a late NameError.
        raise ValueError("unknown model name: %r (expected one of %s)"
                         % (model, sorted(setups)))
    net, checkpoint, symplectic = setups[model]
    # The map_location callable keeps the stored tensors where they were
    # deserialized; the module is moved to the target device afterwards.
    net.load_state_dict(torch.load(os.path.join(ROOT, checkpoint),
                                   map_location=lambda storage, location: storage))
    net.to(device)
    net.eval()

    qpxyT0 = torch.load(os.path.join(ROOT, 'validation0.dat'))
    qpxyT1 = torch.load(os.path.join(ROOT, 'validation1.dat'))
    dt = dtp_validation.to(device)  # same for every sample, so move it once
    Geo = 0.
    Err = 0.
    with torch.no_grad():
        for i in range(points_validation):
            qT0 = qpxyT0[i:i+1,0:1,:].to(device)
            pT0 = qpxyT0[i:i+1,1:2,:].to(device)
            xT0 = qpxyT0[i:i+1,2:3,:].to(device)
            yT0 = qpxyT0[i:i+1,3:4,:].to(device)
            qT1 = qpxyT1[i:i+1,0:1,:].to(device)
            pT1 = qpxyT1[i:i+1,1:2,:].to(device)
            if symplectic:
                qN1, pN1, _xN1, _yN1 = Nonsep_SymInt(qT0, pT0, xT0, yT0, dt, net.forward, epsN)
            else:
                qN1, pN1 = RK2(qT0, pT0, dt, net.forward, epsN)
            # Energy change relative to the initial energy; the small constant
            # keeps the denominator away from zero.
            H0 = analyH(qT0, pT0)
            Geotemp = torch.abs((analyH(qN1, pN1) - H0)/torch.sqrt(H0**2+0.0000001)).sum().item()
            # Absolute deviation from the reference end state.
            Errtemp = (torch.abs(qN1-qT1)+torch.abs(pN1-pT1)).sum().item()
            Geo = Geo + Geotemp
            Err = Err + Errtemp
    print(model, Geo/points_validation, Err/points_validation)
    with open(os.path.join(ROOT, model+'Hamiltonian.dat'),'w') as f:
        f.write(str(Geo/points_validation))
        f.write('\t\t')
        f.write(str(Err/points_validation))
        
def Gen_Data_plot():
    """Integrate the analytic system from (q, p) = (0, -3) and save the long reference trajectory.

    The trajectory holds the initial state plus ``plot_Np`` further states,
    each ``dt_Np`` apart, and is written to ``ROOT/test_plot.dat``.
    """
    reference_field = KAnalysis()
    reference_field.eval()

    q = torch.tensor([[[0.]]])
    p = torch.tensor([[[-3.]]])
    x, y = q, p  # the auxiliary copy starts on top of (q, p)
    trajectory = [torch.cat([q, p, x, y], dim=1)]
    with torch.no_grad():
        for _ in range(plot_Np):
            q, p, x, y = Nonsep_SymInt(q, p, x, y, dt_Np, reference_field.forward, epsT)
            trajectory.append(torch.cat([q, p, x, y], dim=1))

    target_file = os.path.join(ROOT, 'test_plot.dat')
    torch.save(torch.cat(trajectory), target_file)
    
def validation_plot(model):
    """Roll a trained model out over the long reference trajectory, then plot and dump it.

    Args:
        model: one of ``"model_baseline.pt"``, ``"model_hnn.pt"``,
            ``"model_hnn_tao.pt"``, ``"model_nssnn.pt"``,
            ``"model_nssnn_rk.pt"``. ``"model_hnn_tao.pt"`` reuses the
            ``model_hnn.pt`` checkpoint but integrates with Nonsep_SymInt.

    Raises:
        ValueError: if ``model`` is not one of the names above.

    Files written under ROOT:
        _<model>_out.png  phase-space plot (prediction blue, reference red).
        <model>qp.dat     predicted q, p and reference q, p per step.
        <model>dqp.dat    time and the ``dpq`` deviation per step.
    """
    # name -> (network, checkpoint file, integrate with Nonsep_SymInt?)
    setups = {
        "model_baseline.pt": (f_neur_baseline, "model_baseline.pt", False),
        "model_hnn.pt": (f_neur_hnn, "model_hnn.pt", False),
        "model_hnn_tao.pt": (f_neur_hnn, "model_hnn.pt", True),
        "model_nssnn.pt": (f_neur_nssnn, "model_nssnn.pt", True),
        "model_nssnn_rk.pt": (f_neur_nssnn_rk, "model_nssnn_rk.pt", False),
    }
    if model not in setups:
        # Fail before any file is read instead of with a late NameError.
        raise ValueError("unknown model name: %r (expected one of %s)"
                         % (model, sorted(setups)))
    net, checkpoint, symplectic = setups[model]
    net.load_state_dict(torch.load(os.path.join(ROOT, checkpoint),
                                   map_location=lambda storage, location: storage))
    net.to(device)
    net.eval()

    # Start the rollout from the first state of the reference trajectory.
    qpxyT = torch.load(os.path.join(ROOT, 'test_plot.dat'))
    qN = qpxyT[0:1,0:1,0:1].to(device)
    pN = qpxyT[0:1,1:2,0:1].to(device)
    xN = qpxyT[0:1,2:3,0:1].to(device)
    yN = qpxyT[0:1,3:4,0:1].to(device)
    qpxyN = [torch.cat([qN, pN, xN, yN], dim=1).detach().cpu()]

    dt = dt_Np.to(device)  # same for every step, so move it once
    with torch.no_grad():
        for _ in range(plot_Np):
            if symplectic:
                qN, pN, xN, yN = Nonsep_SymInt(qN, pN, xN, yN, dt, net.forward, epsN)
            else:
                qN, pN = RK2(qN, pN, dt, net.forward, epsN)
            qpxyN.append(torch.cat([qN, pN, qN, pN], dim=1).detach().cpu())
    qpxyN = to_np(torch.cat(qpxyN))
    qpxyT = to_np(qpxyT)

    # NOTE(review): this deviation is built from (q^2 + 1)(p^2 + 1), which is
    # not the Hamiltonian defined by analyH — confirm it is the intended metric.
    dpq = np.abs((qpxyN[:,0,0]**2+1)*(qpxyN[:,1,0]**2+1)-(qpxyT[:,0,0]**2+1)*(qpxyT[:,1,0]**2+1))

    plt.clf()
    plt.plot(qpxyN[:, 0, 0], qpxyN[:, 1, 0], c="b", linewidth = 1)
    plt.plot(qpxyT[:, 0, 0], qpxyT[:, 1, 0], c="r", linewidth = 1)
    plt.xlim(-4, 4)
    plt.ylim(-4, 4)
    plt.yticks(np.linspace(-4, 4, 5),size=24, color='black')
    plt.xticks(np.linspace(-4, 4, 5),size=24, color='black')
    plt.draw()
    # Save before showing: once show() has returned the figure may already be
    # closed, and savefig would then write an empty image.
    plt.savefig(os.path.join(ROOT, '_'+model+'_out.png'))
    plt.show()
    plt.pause(1.)

    with open(os.path.join(ROOT, model+'qp.dat'),'w') as f:
        for i in range(plot_Np):
            row = (qpxyN[i, 0, 0], qpxyN[i, 1, 0], qpxyT[i, 0, 0], qpxyT[i, 1, 0])
            f.write('\t\t'.join(str(value) for value in row))
            f.write('\n')
    with open(os.path.join(ROOT, model+'dqp.dat'),'w') as f:
        for i in range(plot_Np):
            f.write(str(i*lt_Np/plot_Np))
            f.write('\t\t')
            f.write(str(dpq[i]))
            f.write('\n')
if __name__ == '__main__':
    # Full pipeline: data generation, training, validation, trajectory plots.
    Gen_Data()
    print("gen_data finished")

    # Train the four networks one after another and time each run.
    trained_tags = ("baseline", "hnn", "nssnn", "nssnn_rk")
    elapsed = {}
    for tag in trained_tags:
        start = time.time()
        train("model_" + tag + ".pt")
        end = time.time()
        elapsed[tag] = end - start
        print("model_" + tag, elapsed[tag])

    with open(os.path.join(ROOT, 'trainingtime.dat'), 'w') as f:
        for tag in trained_tags:
            f.write("time_" + tag + '\t\t' + str(elapsed[tag]) + '\n')
    print("training * 4 finished")

    # "hnn_tao" evaluates the hnn checkpoint with the symplectic integrator.
    evaluated_tags = ("baseline", "hnn", "hnn_tao", "nssnn", "nssnn_rk")
    for tag in evaluated_tags:
        validation("model_" + tag + ".pt")
    print("validation * 5 finished")

    Gen_Data_plot()
    for tag in evaluated_tags:
        validation_plot("model_" + tag + ".pt")
        print("pred " + tag + " finished")
    
    
    
    
