import torch
import argparse
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import random
from torch import nn
from collections import OrderedDict
import os
# Single seed shared by every random number generator configured below.
seed = 30
from torch.func import vmap, jacrev, hessian
from itertools import cycle
from models.ns import NSNet, NSPotentialNet, NSRestrictedNet, NSRestrictedPotentialNet
# Seed torch, Python's `random` module and NumPy with the same value.
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
# Stand-alone torch generator seeded with the same value.
gen = torch.Generator()
gen.manual_seed(seed)

# Global matplotlib font size applied through matplotlib.rc.
font = {'size'   : 16}
import matplotlib
matplotlib.rc('font', **font)

# Default extents of the time axis and of the (x, y) domain.
# NOTE: x_min/x_max are assigned again once the --restricted flag is known,
# and t_max is recomputed from the time grid further down in this script.
t_max = 2.
x_max = 2.2
x_min = 0.
y_max = 0.41
y_min = 0.

# Command-line interface. Every option has a default, so the script can be
# started without arguments.
parser = argparse.ArgumentParser()
parser.add_argument('--init_weight', default=1., type=float, help='Weight for the init loss')
parser.add_argument('--der_weight', default=1., type=float, help='Weight for the derivative loss')
parser.add_argument('--out_weight', default=1., type=float, help='Weight for the output loss')
parser.add_argument('--mom_weight', default=1., type=float, help='Weight for the momentum pde loss')
parser.add_argument('--div_weight', default=1., type=float, help='Weight for the divergence pde loss')
parser.add_argument('--bc_weight', default=1., type=float, help='Weight for the boundary condition loss')
parser.add_argument('--lr_init', default=1e-4, type=float, help='Starting learning rate')
parser.add_argument('--device', default='cuda:2', type=str, help='Device to use')
parser.add_argument('--name', default='base', type=str, help='Experiment name')
parser.add_argument('--interp', default='quintic', type=str, help='Interpolation type')
parser.add_argument('--train_steps', default=100000, type=int, help='Number of training steps')
parser.add_argument('--epochs', default=2000, type=int, help='Number of epochs')
parser.add_argument('--mode', default=0, type=int, help='Mode: -1 for PINN learning, 0 for derivative learning, 1 for output learning')
# BooleanOptionalAction already yields True/False and registers the matching
# --no-<flag> option. It must not be combined with ``type=``: that parameter
# was deprecated for this action in Python 3.12 and is rejected by newer
# releases.
parser.add_argument('--use_hessian', default=False, help='Whether the hessian is used', action=argparse.BooleanOptionalAction)
parser.add_argument('--use_empirical', default=False, help='Whether to use the NS_empirical* experiment folders instead of NS_true*', action=argparse.BooleanOptionalAction)
parser.add_argument('--restricted', default=False, help='Whether to use the restricted-domain models and folders', action=argparse.BooleanOptionalAction)
parser.add_argument('--batch_size', default=512, type=int, help='Number of samples per step')
parser.add_argument('--layers', default=8, type=int, help='Number of layers in the network')
parser.add_argument('--units', default=128, type=int, help='Number of units per layer in the network')

# Unpack the parsed options into module-level names used by the rest of the
# script.
args = parser.parse_args()
init_weight = args.init_weight
device = args.device
name = args.name
train_steps = args.train_steps
epochs = args.epochs
batch_size = args.batch_size
layers = args.layers
units = args.units
lr_init = args.lr_init
mode = args.mode
mom_weight = args.mom_weight
div_weight = args.div_weight
der_weight = args.der_weight
out_weight = args.out_weight
bc_weight = args.bc_weight
interp = args.interp
use_hessian = args.use_hessian
use_empirical = args.use_empirical
restricted = args.restricted
# Not exposed on the command line; passed to every model as ``sys_weight``.
sys_weight = 1.


    
# Network architecture: `layers` hidden layers of `units` units, tanh activation.
hidden_units = [units] * layers
activation = torch.nn.Tanh()

# Experiment folder, chosen from the --use_empirical / --restricted flags.
if use_empirical:
    EXP_PATH = 'NS_empirical_restricted' if restricted else 'NS_empirical'
else:
    EXP_PATH = 'NS_true_restricted' if restricted else 'NS_true'

# Left edge (in x) of the restricted domain.
restricted_x = 0.5

# Horizontal extent of the domain and network class. The 'base' experiment
# uses the plain networks, every other experiment name the potential variants.
_is_base_experiment = name == 'base'
x_min = 0.
if restricted:
    x_max = 1.7
    model_class = NSRestrictedNet if _is_base_experiment else NSRestrictedPotentialNet
else:
    x_max = 2.2
    model_class = NSNet if _is_base_experiment else NSPotentialNet

# Constructor arguments shared by the four networks compared in this script.
_model_kwargs = dict(
    init_weight=init_weight,
    mom_weight=mom_weight,
    div_weight=div_weight,
    sys_weight=sys_weight,
    bc_weight=bc_weight,
    hidden_units=hidden_units,
    lr_init=lr_init,
    activation=activation,
    device=device,
    last_activation=False,
)


def _build_model():
    """Instantiate `model_class` with the shared settings and move it to `device`."""
    return model_class(**_model_kwargs).to(device)


# One network per checkpoint loaded later on; the instantiation order matches
# the original script (model_1, model_0, model_sob, model_pinn).
model_1 = _build_model()
model_0 = _build_model()
model_sob = _build_model()
model_pinn = _build_model()

torch.cuda.empty_cache()
# The networks are only evaluated here, so switch all of them to eval mode.
for _net in (model_1, model_0, model_sob, model_pinn):
    _net.eval()

# %%
import os
# %%
# Load the trained weights of each network. ``map_location=device`` remaps the
# stored tensors onto the device selected with --device; without it torch.load
# tries to restore them onto the device they were saved from, which fails when
# that device is not available on the current machine.
for _net, _checkpoint in (
    (model_1, 'NSnet_Output'),
    (model_0, 'NSnet_Derivative'),
    (model_sob, 'NSnet_Sobolev'),
    (model_pinn, 'NSnet_PINN'),
):
    _net.load_state_dict(
        torch.load(f'{EXP_PATH}/{name}/saved_models/{_checkpoint}', map_location=device)
    )

# %%
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.colors import TwoSlopeNorm



# Folder that receives the comparison plots of this experiment.
_plots_dir = f'{EXP_PATH}/{name}/plots'
if not os.path.exists(_plots_dir):
    os.mkdir(_plots_dir)

# Reference data; it is always read from the NS_true folder, whatever EXP_PATH is.
with open('NS_true/data_original.npy', 'rb') as f:
    pde_true = np.load(f)


# Time grid from 0 to 2 with step dt; t_max is its last value.
dt = 0.01
dx = 0.01
t_vec = np.arange(0., 2. + dt, dt)
t_max = t_vec.max()
# Plotted times are t_max * t_ind / steps_print.
steps_print = 10
# Time offset of the evaluation window inside the reference data: rows with
# time >= t_init are used and shifted so that they start at 0.
t_init = 8.

plot_downsample = 2

from matplotlib import gridspec
def plot_21(pred, true, X, Y, name, path):
    """Save a three-panel PDF comparing a predicted field with its reference.

    The figure stacks the reference field, the predicted field (both drawn
    with the colour range of the reference, sharing one colourbar) and their
    absolute difference (with its own colourbar).

    Args:
        pred: predicted field; passed to ``contourf`` together with X and Y.
        true: reference field; NaNs are replaced by 0 when computing the
            shared colour range.
        X, Y: coordinates handed to ``contourf``.
        name: quantity name used in the three panel titles.
        path: destination file; the figure is written in PDF format.
    """
    def _rasterize(contour_set):
        # Matplotlib 3.8 turned ContourSet into a single Collection and
        # deprecated the ``collections`` attribute (removed in 3.10), so use
        # whichever interface the installed version provides.
        from matplotlib.collections import Collection
        if isinstance(contour_set, Collection):
            contour_set.set_rasterized(True)
        else:
            for piece in contour_set.collections:
                piece.set_rasterized(True)

    def _panel(row, field, panel_title, **contour_kwargs):
        # Draw one filled-contour panel in the left column of the grid.
        ax = fig.add_subplot(gs[row, 0])
        ax.set_xlim((x_min-0.01, x_max + 0.01))
        ax.set_ylim((y_min-0.01, y_max + 0.01))
        contour = ax.contourf(X, Y, field, cmap='jet', **contour_kwargs)
        _rasterize(contour)
        ax.set_title(panel_title)
        return contour

    fig = plt.figure(figsize=(11, 12), layout='tight')
    # 3 rows of plots; the narrow second column hosts the colourbars.
    gs = gridspec.GridSpec(3, 2, width_ratios=[20, 1], height_ratios=[1, 1, 1])

    # Reference and prediction share the colour range of the reference field.
    vmin, vmax = (np.min(np.nan_to_num(true)), np.max(np.nan_to_num(true)))
    levels = np.linspace(vmin, vmax, 100)
    shared = dict(vmin=vmin, vmax=vmax, levels=levels)

    contour_true = _panel(0, true, f'True {name}', **shared)
    _panel(1, pred, f'Predicted {name}', **shared)

    # Shared colourbar spanning the first two rows.
    fig.colorbar(contour_true, cax=fig.add_subplot(gs[0:2, 1]))

    contour_err = _panel(2, np.abs(true-pred), f'{name} error', levels=50)
    fig.colorbar(contour_err, cax=fig.add_subplot(gs[2, 1]))

    fig.subplots_adjust(wspace=0.1, hspace=0.4)
    fig.savefig(path, format='pdf')
    plt.close(fig)


def plot_results(model: NSNet, title_mode:str, title:str):
    """Append loss statistics to errors.txt and save field plots for one model.

    Reads ``{EXP_PATH}/{name}/plots{title_mode}/testdata.npy``, appends the
    mean of the square-rooted losses over its last ten rows to
    ``{EXP_PATH}/{name}/errors.txt`` and, for the time indexes 0, 5 and 10,
    saves pressure, velocity and vorticity comparisons (via ``plot_21``) plus
    a contour plot of the norm of the velocity Laplacian into the same
    ``plots{title_mode}`` folder.

    Args:
        model: network exposing ``forward`` (batched evaluation) and
            ``forward_single`` (differentiated with jacrev / hessian).
        title_mode: suffix of the per-model folder, also used as the model
            label written to errors.txt.
        title: unused; kept so existing callers keep working.
    """
    def _rasterize(contour_set):
        # Matplotlib 3.8 turned ContourSet into a single Collection and
        # deprecated the ``collections`` attribute (removed in 3.10), so use
        # whichever interface the installed version provides.
        from matplotlib.collections import Collection
        if isinstance(contour_set, Collection):
            contour_set.set_rasterized(True)
        else:
            for piece in contour_set.collections:
                piece.set_rasterized(True)

    out_dir = f'{EXP_PATH}/{name}/plots{title_mode}'
    with open(f'{out_dir}/testdata.npy', 'rb') as f:
        loss_combination = np.load(f)

    # Columns of the stored history: 0 step, 1 momentum, 2 divergence,
    # 3 output, 4 derivative, 5 init, 6 boundary, 7 total, 8 time.
    mom_loss = loss_combination[:,1]
    div_loss = loss_combination[:,2]
    out_loss = loss_combination[:,3]
    der_loss = loss_combination[:,4]
    init_loss = loss_combination[:,5]
    bc_loss = loss_combination[:,6]
    time_list = loss_combination[:,8]

    with open(f'{EXP_PATH}/{name}/errors.txt', 'a') as f:
        print(f'Plotting results for model {title_mode}', file=f)
        for label, history in (
            ('Mom', mom_loss),
            ('Div', div_loss),
            ('Out', out_loss),
            ('Der', der_loss),
            ('Init', init_loss),
            ('Bc', bc_loss),
        ):
            print(f'{label} loss: {np.mean(np.sqrt(history[-10:]))}', file=f)
        print(f'Time: {np.mean(time_list[-10:])}', file=f)

    print(f'Plotting results for model {title_mode}')
    for t_ind in [0,5,10]:
        t = t_max*t_ind/steps_print
        print(f'Plotting {t}')
        # Rows of the reference data belonging to this snapshot; the data is
        # stored with an absolute time offset of t_init.
        out_indexes = np.argwhere(np.float32(pde_true[:,0]) == t_init+t).reshape((-1))
        curr_pde = pde_true[out_indexes]

        # Rebuild the 2D grid from the flat x (column 1) and y (column 2) lists.
        len_x = len(np.unique(curr_pde[:,1]))
        len_y = len(np.unique(curr_pde[:,2]))
        X = curr_pde[:,1].reshape((len_x,len_y)).T
        Y = curr_pde[:,2].reshape((len_x,len_y)).T
        T = t*np.ones_like(X.reshape((-1)))
        pts = np.vstack([T,X.reshape(-1),Y.reshape(-1)]).T
        # Network inputs (t, x, y); built once and reused for the forward
        # pass, the Jacobian and the Hessian.
        inputs = torch.tensor(pts).to(device).float()

        def _true_field(column):
            # Reference column laid out like X.
            return curr_pde[:,column].reshape(X.T.shape).T

        # Network outputs are (x-velocity, y-velocity, pressure).
        pred_plot = model.forward(inputs).detach().cpu().numpy()
        xvel_pred_plot = pred_plot[:,0].reshape(X.shape)
        yvel_pred_plot = pred_plot[:,1].reshape(X.shape)
        pres_pred_plot = pred_plot[:,2].reshape(X.shape)

        plot_21(pres_pred_plot, _true_field(5), X, Y, 'Pressure', f'{out_dir}/pressure_results{t}.pdf')
        plot_21(xvel_pred_plot, _true_field(3), X, Y, 'X-Velocity', f'{out_dir}/xvel_results{t}.pdf')
        plot_21(yvel_pred_plot, _true_field(4), X, Y, 'Y-Velocity', f'{out_dir}/yvel_results{t}.pdf')

        # Jacobian indexed as [point, output, input]: vorticity = dv/dx - du/dy.
        der_plot = vmap(jacrev(model.forward_single))(inputs).detach().cpu().numpy().reshape(-1,3,3)
        vorticity_pred_plot = (der_plot[:,1,1] - der_plot[:,0,2]).reshape(X.shape)
        # Reference derivatives: column 10 is used as dv/dx, column 8 as du/dy.
        vorticity_true_plot = (_true_field(10) - _true_field(8)).reshape(X.shape)

        plot_21(vorticity_pred_plot, vorticity_true_plot, X, Y, 'Vorticity', f'{out_dir}/vorticity_results{t}.pdf')

        # Hessian indexed as [point, output, input, input]; sum the spatial
        # diagonal for the two velocity outputs, then take the norm per point.
        hess = vmap(hessian(model.forward_single))(inputs).detach().cpu().reshape(-1,3,3,3)
        lapl_u = torch.diagonal(hess[:,:2,1:,1:], dim1=2, dim2=3).sum(dim=2)
        lapl_plot = torch.norm(lapl_u, dim=1).reshape(X.shape).numpy()
        plt.figure()
        cont = plt.contourf(X,Y,lapl_plot,10,cmap='jet')
        _rasterize(cont)
        plt.colorbar(cont)
        plt.axis('scaled')
        plt.savefig(f'{out_dir}/laplacian_results{t}.pdf', format='pdf')
        plt.close()
        
        

if restricted:
    # Keep only the points with x >= restricted_x and shift x so that the
    # restricted domain starts at 0.
    pde_true = pde_true[pde_true[:, 1] >= restricted_x]
    pde_true[:, 1] = pde_true[:, 1] - restricted_x


from plotting.ns_plotting import plot_errors
pde_true = np.nan_to_num(pde_true)

# Evaluation set: every row from t_init onwards, with time shifted to start at 0.
_in_time_window = np.float32(pde_true[:, 0]) >= t_init
all_data = pde_true[_in_time_window]
all_data[:, 0] = all_data[:, 0] - t_init
# (t, x, y) inputs of the networks.
pts = all_data[:, :3]
# Evaluation chunk size; this replaces the --batch_size value from here on.
batch_size = 512
n_batches = len(pts) // batch_size

#pts = pts[:n_batches*batch_size]
# Per-snapshot output errors of each model, filled further down.
error_snapshots_1 = []
error_snapshots_0 = []
error_snapshots_sob = []
error_snapshots_pinn = []


def _predict_in_batches(model, points, chunk_size):
    """Evaluate ``model`` and its input Jacobian on ``points`` chunk by chunk.

    Args:
        model: network exposing ``forward`` and ``forward_single``.
        points: array with one (t, x, y) row per evaluation point.
        chunk_size: number of points sent to the device at once.

    Returns:
        ``(outputs, jacobians)`` as NumPy arrays of shape ``(len(points), 3)``
        and ``(len(points), 3, 3)``.
    """
    outputs = np.zeros((len(points), 3))
    jacobians = np.zeros((len(points), 3, 3))
    jacobian_fn = vmap(jacrev(model.forward_single))
    # Stepping by chunk_size covers the trailing partial chunk as well and
    # never evaluates an empty batch when len(points) is an exact multiple
    # of chunk_size.
    for start in range(0, len(points), chunk_size):
        stop = start + chunk_size
        batch = torch.tensor(points[start:stop]).to(device).float()
        outputs[start:stop] = model.forward(batch).detach().cpu().numpy()
        jacobians[start:stop] = jacobian_fn(batch).detach().cpu().numpy().reshape(-1, 3, 3)
    return outputs, jacobians


pred_plot_1, pred_der_1 = _predict_in_batches(model_1, pts, batch_size)
pred_plot_0, pred_der_0 = _predict_in_batches(model_0, pts, batch_size)
pred_plot_sob, pred_der_sob = _predict_in_batches(model_sob, pts, batch_size)
pred_plot_pinn, pred_der_pinn = _predict_in_batches(model_pinn, pts, batch_size)

# Reference outputs (columns 3:6) and reference Jacobians (columns 6 onwards,
# viewed as 3x3 matrices) of every evaluation point.
_true_outputs = all_data[:, 3:6]
_true_ders = all_data[:, 6:].reshape((-1, 3, 3))

# Per-point Euclidean norm of the output error of each model.
error_1 = np.linalg.norm(pred_plot_1 - _true_outputs, axis=1)
error_0 = np.linalg.norm(pred_plot_0 - _true_outputs, axis=1)
error_sob = np.linalg.norm(pred_plot_sob - _true_outputs, axis=1)
error_pinn = np.linalg.norm(pred_plot_pinn - _true_outputs, axis=1)

# Per-point Frobenius norm of the Jacobian error of each model.
der_error_0 = np.linalg.norm(pred_der_0 - _true_ders, axis=(1,2))
der_error_1 = np.linalg.norm(pred_der_1 - _true_ders, axis=(1,2))
der_error_sob = np.linalg.norm(pred_der_sob - _true_ders, axis=(1,2))
der_error_pinn = np.linalg.norm(pred_der_pinn - _true_ders, axis=(1,2))


mean_1 = np.mean(error_1)
std_1 = np.std(error_1)
mean_0 = np.mean(error_0)
std_0 = np.std(error_0)
mean_sob = np.mean(error_sob)
std_sob = np.std(error_sob)
mean_pinn = np.mean(error_pinn)
std_pinn = np.std(error_pinn)

# Each derivative statistic is taken from the error array of the same model.
# The statistics of models 1 and 0 used to be computed from each other's
# arrays, which swapped the two reported derivative errors.
mean_der_1 = np.mean(der_error_1)
std_der_1 = np.std(der_error_1)
mean_der_0 = np.mean(der_error_0)
std_der_0 = np.std(der_error_0)
mean_der_sob = np.mean(der_error_sob)
std_der_sob = np.std(der_error_sob)
mean_der_pinn = np.mean(der_error_pinn)
std_der_pinn = np.std(der_error_pinn)

# One row per model: tag, output statistics, derivative statistics.
_error_rows = (
    ('1', mean_1, std_1, error_1, mean_der_1, std_der_1, der_error_1),
    ('0', mean_0, std_0, error_0, mean_der_0, std_der_0, der_error_0),
    ('sob', mean_sob, std_sob, error_sob, mean_der_sob, std_der_sob, der_error_sob),
    ('pinn', mean_pinn, std_pinn, error_pinn, mean_der_pinn, std_der_pinn, der_error_pinn),
)

for _tag, _mean, _std, _err, _mean_der, _std_der, _der_err in _error_rows:
    print(f'Error {_tag}: {_mean}, std: {_std}')

# errors.txt is recreated here; plot_results appends to it afterwards.
with open(f'{EXP_PATH}/{name}/errors.txt', 'w') as f:
    for _tag, _mean, _std, _err, _mean_der, _std_der, _der_err in _error_rows:
        print(f'Error {_tag}: {_mean}, std: {_std}', file=f)
        print(f'L2 loss {_tag}: {np.sqrt(np.mean(_err**2))}', file=f)
        print(f'Derivative error {_tag}: {_mean_der}, std: {_std_der}', file=f)
        print(f'Derivative L2 loss {_tag}: {np.sqrt(np.mean(_der_err**2))}\n', file=f)

for _tag, _mean, _std, _err, _mean_der, _std_der, _der_err in _error_rows:
    print(f'Derivative error {_tag}: {_mean_der}, std: {_std_der}')


# Split the per-point output errors by time snapshot, one entry per distinct
# time value, for each of the four models.
for t in np.unique(all_data[:, 0]):
    snapshot_mask = all_data[:, 0] == t
    true_vals = all_data[snapshot_mask, 3:6]

    pred_1 = pred_plot_1[snapshot_mask]
    pred_0 = pred_plot_0[snapshot_mask]
    pred_sob = pred_plot_sob[snapshot_mask]
    pred_pinn = pred_plot_pinn[snapshot_mask]

    error_1 = np.linalg.norm(pred_1 - true_vals, axis=1)
    error_snapshots_1.append(error_1)

    error_0 = np.linalg.norm(pred_0 - true_vals, axis=1)
    error_snapshots_0.append(error_0)

    error_sob = np.linalg.norm(pred_sob - true_vals, axis=1)
    error_snapshots_sob.append(error_sob)

    error_pinn = np.linalg.norm(pred_pinn - true_vals, axis=1)
    error_snapshots_pinn.append(error_pinn)
    
 # Get the unique points
out_indexes = np.argwhere(np.float32(all_data[:,0]) == 0.).reshape((-1))
# The indexes above are positions inside all_data, so the snapshot has to be
# taken from all_data too. Indexing pde_true with them only selects the same
# rows when pde_true holds no rows before t_init.
curr_pde = all_data[out_indexes]

# Rebuild the 2D grid from the flat x (column 1) and y (column 2) lists of the
# first snapshot.
points_x = np.unique(curr_pde[:,1])
points_y = np.unique(curr_pde[:,2])
len_x = len(points_x)
len_y = len(points_y)

X = curr_pde[:,1].reshape((len_x,len_y))
Y = curr_pde[:,2].reshape((len_x,len_y))
# NOTE(review): `t` is whatever value the preceding loop left behind, and T is
# not read again in the visible part of the script.
T = t*np.ones_like(X.reshape((-1)))


# Output error of each model averaged over all snapshots, laid out on the grid.
mean_error_1 = np.mean(np.array(error_snapshots_1), axis=0).reshape((X.shape[0], X.shape[1]))
mean_error_0 = np.mean(np.array(error_snapshots_0), axis=0).reshape((X.shape[0], X.shape[1]))
mean_error_sob = np.mean(np.array(error_snapshots_sob), axis=0).reshape((X.shape[0], X.shape[1]))
mean_error_pinn = np.mean(np.array(error_snapshots_pinn), axis=0).reshape((X.shape[0], X.shape[1]))

from plotting.ns_plotting import ns_errorplot, ns_compareplot

# Cap the time-averaged errors of the OUTL and PINN models at 1.5 before
# plotting.
for _capped in (mean_error_1, mean_error_pinn):
    _capped[_capped > 1.5] = 1.5

# Arguments common to every time-averaged plot below.
_tavg_error_kwargs = dict(path=f'{EXP_PATH}/{name}/plots/', t=0., X=X, Y=Y, curr_pde=curr_pde)

# All four models: absolute errors, then differences with respect to DERL.
ns_errorplot(
    to_plot=np.array([mean_error_0, mean_error_1, mean_error_sob, mean_error_pinn]),
    model_names=['DERL', 'OUTL', 'SOB', 'PINN'],
    name='error_tavg',
    **_tavg_error_kwargs,
)
ns_compareplot(
    to_plot=np.array([mean_error_1-mean_error_0, mean_error_sob-mean_error_0, mean_error_pinn-mean_error_0]),
    model_names=['OUTL', 'SOB', 'PINN'],
    name='error_tavg_compare',
    **_tavg_error_kwargs,
)

# Same two plots without the PINN model.
ns_errorplot(
    to_plot=np.array([mean_error_0, mean_error_1, mean_error_sob]),
    model_names=['DERL', 'OUTL', 'SOB'],
    name='error_tavg_nopinn',
    **_tavg_error_kwargs,
)
ns_compareplot(
    to_plot=np.array([mean_error_1-mean_error_0, mean_error_sob-mean_error_0]),
    model_names=['OUTL', 'SOB'],
    name='error_tavg_compare_nopinn',
    **_tavg_error_kwargs,
)

# Recompute the time-averaged errors, which discards the capping applied above.
_grid_shape = (X.shape[0], X.shape[1])
mean_error_1 = np.mean(np.array(error_snapshots_1), axis=0).reshape(_grid_shape)
mean_error_0 = np.mean(np.array(error_snapshots_0), axis=0).reshape(_grid_shape)
mean_error_sob = np.mean(np.array(error_snapshots_sob), axis=0).reshape(_grid_shape)
mean_error_pinn = np.mean(np.array(error_snapshots_pinn), axis=0).reshape(_grid_shape)

batch_size = 512


def _consistencies_in_batches(model, points, chunk_size):
    """Collect ``model.get_consistencies`` over ``points`` chunk by chunk.

    Args:
        model: network exposing ``get_consistencies``, which returns a
            (momentum, divergence) pair for a batch of points.
        points: array with one (t, x, y) row per evaluation point.
        chunk_size: number of points sent to the device at once.

    Returns:
        ``(momentum, divergence)`` NumPy arrays, each the concatenation of the
        per-chunk results in point order.
    """
    moms = []
    divs = []
    # Stepping by chunk_size covers the trailing partial chunk as well and
    # never evaluates an empty batch when len(points) is an exact multiple
    # of chunk_size.
    for start in range(0, len(points), chunk_size):
        batch = torch.tensor(points[start:start + chunk_size]).to(device).float()
        mom, div = model.get_consistencies(batch)
        moms.append(mom.detach().cpu().numpy())
        divs.append(div.detach().cpu().numpy())
    return np.concatenate(moms), np.concatenate(divs)


mom_1, div_1 = _consistencies_in_batches(model_1, pts, batch_size)
mom_0, div_0 = _consistencies_in_batches(model_0, pts, batch_size)
mom_sob, div_sob = _consistencies_in_batches(model_sob, pts, batch_size)
mom_pinn, div_pinn = _consistencies_in_batches(model_pinn, pts, batch_size)

# Write mean, standard deviation and root-mean-square of the momentum and
# divergence values of each model to consistencies.txt.
_consistency_rows = (
    ('1', mom_1, div_1),
    ('0', mom_0, div_0),
    ('sob', mom_sob, div_sob),
    ('pinn', mom_pinn, div_pinn),
)
with open(f'{EXP_PATH}/{name}/consistencies.txt', 'w') as f:
    for _tag, _mom, _div in _consistency_rows:
        _mom_line = f'Momentum {_tag}: {np.mean(_mom)}, pm {np.std(_mom)}'
        if _tag != '1':
            # Existing file format: every model except '1' repeats the
            # divergence statistics on its momentum line.
            _mom_line += f', div: {np.mean(_div)}, pm {np.std(_div)}'
        print(_mom_line, file=f)
        print(f'Momentum {_tag} L2 loss: {np.sqrt(np.mean(_mom**2))}', file=f)
        print(f'Divergence {_tag}: {np.mean(_div)}, pm {np.std(_div)}', file=f)
        print(f'Divergence {_tag} L2 loss: {np.sqrt(np.mean(_div**2))}\n', file=f)

# Now repeat the averaging over time, this time with the consistencies.
print(mom_1.shape)

# Row positions of every time snapshot inside pts, in increasing time order.
_snapshot_rows = [
    np.argwhere(pts[:, 0] == _t_value).reshape((-1))
    for _t_value in np.unique(pts[:, 0])
]


def _per_snapshot(values):
    """Return `values` arranged as one row per time snapshot."""
    return np.array([values[_rows] for _rows in _snapshot_rows])


mom_1_tavg = _per_snapshot(mom_1)
mom_0_tavg = _per_snapshot(mom_0)
mom_sob_tavg = _per_snapshot(mom_sob)
mom_pinn_tavg = _per_snapshot(mom_pinn)

div_1_tavg = _per_snapshot(div_1)
div_0_tavg = _per_snapshot(div_0)
div_sob_tavg = _per_snapshot(div_sob)
div_pinn_tavg = _per_snapshot(div_pinn)

# Cap the per-snapshot values of the OUTL ('1') and PINN models at 1.5 before
# averaging.
for _capped in (mom_1_tavg, div_1_tavg, mom_pinn_tavg, div_pinn_tavg):
    _capped[_capped > 1.5] = 1.5

# Average over the snapshots and lay the result out on the grid.
_tavg_shape = (X.shape[0], X.shape[1])
mom_1_tavg = np.mean(mom_1_tavg, axis=0).reshape(_tavg_shape)
mom_0_tavg = np.mean(mom_0_tavg, axis=0).reshape(_tavg_shape)
mom_sob_tavg = np.mean(mom_sob_tavg, axis=0).reshape(_tavg_shape)
mom_pinn_tavg = np.mean(mom_pinn_tavg, axis=0).reshape(_tavg_shape)

div_1_tavg = np.mean(div_1_tavg, axis=0).reshape(_tavg_shape)
div_0_tavg = np.mean(div_0_tavg, axis=0).reshape(_tavg_shape)
div_sob_tavg = np.mean(div_sob_tavg, axis=0).reshape(_tavg_shape)
div_pinn_tavg = np.mean(div_pinn_tavg, axis=0).reshape(_tavg_shape)

# The PINN averages are capped once more after averaging.
div_pinn_tavg[div_pinn_tavg > 1.5] = 1.5
mom_pinn_tavg[mom_pinn_tavg > 1.5] = 1.5

# Arguments common to every time-averaged consistency plot.
_tavg_cons_kwargs = dict(path=f'{EXP_PATH}/{name}/plots/', t=0., X=X, Y=Y, curr_pde=curr_pde)

# Absolute momentum / divergence maps, with and without the PINN model.
ns_errorplot(
    to_plot=np.array([mom_0_tavg, mom_1_tavg, mom_sob_tavg, mom_pinn_tavg]),
    model_names=['DERL', 'OUTL', 'SOB', 'PINN'],
    name='mom_tavg',
    **_tavg_cons_kwargs,
)
ns_errorplot(
    to_plot=np.array([div_0_tavg, div_1_tavg, div_sob_tavg, div_pinn_tavg]),
    model_names=['DERL', 'OUTL', 'SOB', 'PINN'],
    name='div_tavg',
    **_tavg_cons_kwargs,
)
ns_errorplot(
    to_plot=np.array([mom_0_tavg, mom_1_tavg, mom_sob_tavg]),
    model_names=['DERL', 'OUTL', 'SOB'],
    name='mom_tavg_nopinn',
    **_tavg_cons_kwargs,
)
ns_errorplot(
    to_plot=np.array([div_0_tavg, div_1_tavg, div_sob_tavg]),
    model_names=['DERL', 'OUTL', 'SOB'],
    name='div_tavg_nopinn',
    **_tavg_cons_kwargs,
)

# Differences with respect to the DERL model, with and without the PINN model.
ns_compareplot(
    to_plot=np.array([mom_1_tavg-mom_0_tavg, mom_sob_tavg-mom_0_tavg, mom_pinn_tavg-mom_0_tavg]),
    model_names=['OUTL', 'SOB', 'PINN'],
    name='mom_tavg_compare',
    **_tavg_cons_kwargs,
)
ns_compareplot(
    to_plot=np.array([div_1_tavg-div_0_tavg, div_sob_tavg-div_0_tavg, div_pinn_tavg-div_0_tavg]),
    model_names=['OUTL', 'SOB', 'PINN'],
    name='div_tavg_compare',
    **_tavg_cons_kwargs,
)
ns_compareplot(
    to_plot=np.array([mom_1_tavg-mom_0_tavg, mom_sob_tavg-mom_0_tavg]),
    model_names=['OUTL', 'SOB'],
    name='mom_tavg_compare_nopinn',
    **_tavg_cons_kwargs,
)
ns_compareplot(
    to_plot=np.array([div_1_tavg-div_0_tavg, div_sob_tavg-div_0_tavg]),
    model_names=['OUTL', 'SOB'],
    name='div_tavg_compare_nopinn',
    **_tavg_cons_kwargs,
)
    

print('Plotting results comparisons')
# Error comparison plots at three times, once with and once without the PINN
# model.
_comparison_dir = f'{EXP_PATH}/{name}/plots'
for t_ind in (0, 5, 10):
    t = t_max*t_ind/steps_print
    print(f'Plotting {t}')
    # Rows of the (time-shifted) evaluation set belonging to this time step.
    curr_pde = all_data[np.float32(all_data[:, 0]) == t]

    plot_errors(
        [model_0, model_1, model_sob, model_pinn],
        ['DERL', 'OUTL', 'SOB', 'PINN'],
        _comparison_dir,
        t,
        curr_pde=curr_pde,
        apx='',
    )
    plot_errors(
        [model_0, model_1, model_sob],
        ['DERL', 'OUTL', 'SOB'],
        _comparison_dir,
        t,
        curr_pde=curr_pde,
        apx='_nopinn',
    )
  
# Per-model loss summary and field plots: (network, folder suffix, label).
for _net, _folder_suffix, _label in (
    (model_1, 'Output', 'OUTL'),
    (model_0, 'Derivative', 'DERL'),
    (model_pinn, 'PINN', 'PINN'),
    (model_sob, 'Sobolev', 'SOB'),
):
    plot_results(_net, _folder_suffix, _label)

# Default width of the moving-average window applied to the loss curves.
N = 20


def moving_average(values, window):
    """Return the running mean of ``values`` over ``window`` samples.

    Only fully overlapping windows are kept, so the result has
    ``len(values) - window + 1`` entries. The window is clamped to the range
    ``[1, len(values)]`` so short series and degenerate window sizes still
    give a well-defined result; an empty input yields an empty array.
    """
    series = np.asarray(values, dtype=float)
    if series.size == 0:
        return series
    width = max(1, min(int(window), series.size))
    return np.convolve(series, np.ones((width,))/width, mode='valid')


def plot_loss_curves(to_plot, step_list, names, path, title, colors, window=None):
    """Plot smoothed loss histories on a log scale and save them as a PDF.

    Args:
        to_plot: sequence of 1D loss histories, one per curve.
        step_list: step values for the x axis; each curve uses its first
            ``len(smoothed curve)`` entries.
        names: legend label of each curve.
        path: destination file, written in PDF format.
        title: figure title.
        colors: line colour of each curve.
        window: moving-average width; defaults to the module-level ``N``.
    """
    if window is None:
        window = N
    plt.figure()
    for i, series in enumerate(to_plot):
        plot_y = moving_average(series, window)
        # Slice by the smoothed length rather than by -(window-1): the latter
        # is empty for window == 1 and mismatches when a history is shorter
        # than the window or than step_list.
        plt.plot(step_list[:len(plot_y)], plot_y, label=names[i], color=colors[i])
    plt.legend()
    plt.yscale('log')
    plt.title(title)
    plt.savefig(path, format='pdf')
    plt.close()
  
  
def _load_loss_history(folder_suffix):
    """Load testdata.npy from the `plots<folder_suffix>` folder of this experiment."""
    with open(f'{EXP_PATH}/{name}/plots{folder_suffix}/testdata.npy', 'rb') as f:
        return np.load(f)


derivative_losses = _load_loss_history('Derivative')
output_losses = _load_loss_history('Output')
sobolev_losses = _load_loss_history('Sobolev')
pinn_losses = _load_loss_history('PINN')

# The steps of the derivative run are used as the x axis of every curve.
step_list = derivative_losses[:,0]

# One figure per loss column (columns 1 to 6 of the stored histories).
_loss_plots = (
    ('mom', 'Momentum loss'),
    ('div', 'Divergence loss'),
    ('out', 'Output loss'),
    ('der', 'Derivative loss'),
    ('init', 'Init loss'),
    ('bc', 'BC loss'),
)
for _column, (_file_tag, _plot_title) in enumerate(_loss_plots, start=1):
    plot_loss_curves(
        [
            derivative_losses[:, _column],
            output_losses[:, _column],
            sobolev_losses[:, _column],
            pinn_losses[:, _column],
        ],
        step_list,
        ['Derivative', 'Output', 'Sobolev', 'PINN'],
        f'{EXP_PATH}/{name}/plots/losses_{_file_tag}.pdf',
        title=_plot_title,
        colors=['blue', 'red', 'purple', 'green'],
    )

    
# Rebuild the PINN network from scratch, reload its checkpoint and plot its
# results again.
# NOTE(review): plot_results has already been called for the PINN model above,
# so this appends a second PINN section to errors.txt and rewrites the same
# PDFs - confirm whether this second pass is still wanted.
model_pinn = model_class(
    init_weight=init_weight,
    mom_weight=mom_weight,
    sys_weight=sys_weight,
    div_weight=div_weight,
    bc_weight=bc_weight,
    hidden_units=hidden_units,
    lr_init=lr_init,
    activation=activation,
    device=device,
    last_activation=False,
).to(device)


torch.cuda.empty_cache()
model_pinn.eval()

# map_location remaps the stored tensors onto the device selected with
# --device, so a checkpoint saved on another device can still be loaded.
model_pinn.load_state_dict(
    torch.load(f'{EXP_PATH}/{name}/saved_models/NSnet_PINN', map_location=device)
)
plot_results(model_pinn, 'PINN', 'PINN')
