"""Build training datasets from a gridded (u, v, p) solution dump.

Reads the raw solver output ``NS_out.txt`` (one record of ``data_len`` lines
per (t, x, y) grid point), caches it as ``.npy`` files, and writes
``TensorDataset`` files for the initial conditions at ``init_t`` and for
randomly sampled, interpolated points with finite-difference derivatives,
plus versions restricted to the part of the domain beyond ``restriction_x``.
"""
import os
import random

import numpy as np
import torch
from torch.utils.data import TensorDataset

# Inputs are (t, x, y); outputs are (u, v, p).
in_dim = 3
out_dim = 3
# Lines per record in the raw file, and columns per row of the parsed data.
data_len = in_dim + out_dim
# Accumulator used while parsing the raw file.
data_total = []

# Seed every RNG in use so the generated datasets are reproducible.
seed = 30
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
# Dedicated seeded generator (not referenced elsewhere in this script).
gen = torch.Generator()
gen.manual_seed(seed)

# Grid of the raw data: uniform steps, both endpoints included.
# x definitions
x_min = 0.
x_max = 2.2
dx = 0.01
# y definition
y_min = 0.
y_max = 0.41
dy = 0.01
# t definition
t_min = 0.
t_max = 10.
dt = 0.01

# Initial time for the experiment: data at this time become the initial
# conditions, data after it the PDE data, and times are shifted so that
# init_t maps to t = 0.
init_t = 8.0

# The restricted datasets keep only the domain beyond x = restriction_x,
# shifted so that this cut maps to x = 0.
restriction_x = 0.5

# Number of random points sampled for the empirical dataset (before points
# inside the cylinder or at t <= init_t are removed).
N = int(1e6)

# Interpolation method passed to scipy's RegularGridInterpolator
interp_type = 'quintic'

# Finite-difference scheme: 'centered'; any other value gives forward differences
der_type = 'forward'

# Step sizes for the finite-difference derivatives
der_dx = 1e-4
der_dy = 1e-4
der_dt = 1e-4

# Grid coordinates of the first record of the raw file; advanced record by
# record while parsing.
t_now = t_min
x_now = x_min
y_now = y_min

# Paths to save the data
EXP_PATH = 'NS_empirical'
REST_PATH = 'NS_empirical_restricted'
file_name = 'NS_out.txt'

# Parse the raw solver output once and cache it as .npy files in EXP_PATH.
from array import array


def _parse_raw_value(raw_line):
    """Convert one raw byte line of the solver dump to a float.

    The dump uses ``*^`` as the exponent marker (e.g. ``1.5*^-3``) and the
    token ``Indeterminate`` for undefined values (presumably Mathematica
    output). The line is stripped first so ``Indeterminate`` is recognised
    with both ``\\n`` and ``\\r\\n`` line endings.
    """
    text = raw_line.decode().strip().replace('*^', 'e')
    return float(text.replace('Indeterminate', 'nan'))


# Generate the path
os.makedirs(EXP_PATH, exist_ok=True)

# If the preprocessed data file exists already skip
if not os.path.isfile(f'{EXP_PATH}/data_original.npy'):
    print('Processing the file...')
    # Flat float64 buffer (8 bytes per value) rather than one small tensor per
    # record: the file holds one record per point of the (t, x, y) grid.
    flat_values = array('d')
    with open(file_name, 'rb') as f:
        # Each record spans data_len lines. Its first in_dim lines are not
        # parsed (the coordinates come from the tracked grid position
        # t_now, x_now, y_now); its last out_dim lines are the solution values.
        count = data_len
        data_point = []
        for line in f:  # stream the file instead of reading all lines at once
            if count == data_len:
                data_point.extend((t_now, x_now, y_now))
            elif count <= out_dim:
                data_point.append(_parse_raw_value(line))

            count -= 1
            # When the count is 0 the record is complete
            if count == 0:
                flat_values.extend(data_point)
                data_point = []
                count = data_len
                # Advance the grid position: y fastest, then x, then t.
                # Rounding to 2 decimals keeps the coordinates on the exact
                # grid values instead of accumulating float error.
                y_now = np.round(y_now+dy, decimals=2)
                if y_now > y_max:
                    y_now = y_min
                    x_now = np.round(x_now + dx, decimals=2)
                if x_now > x_max:
                    x_now = x_min
                    t_now = np.round(t_now+dt, decimals=2)

    # Only complete records were appended, so the buffer length is a multiple
    # of data_len; an empty file yields a (0, data_len) tensor.
    data_total = torch.from_numpy(
        np.frombuffer(flat_values, dtype=np.float64).reshape((-1, data_len)).copy()
    )

    # Save the data with nans
    with open(f'{EXP_PATH}/data_original.npy', 'wb') as f:
        np.save(f, data_total.numpy())

    # Rows without any NaN
    data_nonan = data_total[~torch.any(data_total.isnan(), dim=1)]

    # NaNs -> 0 for the coordinate/solution files written below
    data_total = torch.nan_to_num(data_total, 0.)
    # Save data without the nans
    with open(f'{EXP_PATH}/data_clean.npy', 'wb') as f:
        np.save(f, data_nonan.numpy())

    print('Processing finished!')
    print(f'data shape: {data_total.shape}')

    # Save the coordinates of the points
    pts = data_total[:, :in_dim]
    with open(f'{EXP_PATH}/coordinates.npy', 'wb') as f:
        np.save(f, pts.numpy())

    # Save the function values
    out = data_total[:, in_dim:(in_dim+out_dim)]
    with open(f'{EXP_PATH}/solution.npy', 'wb') as f:
        np.save(f, out.numpy())

# Reload the cached arrays. data_original.npy exists once the parsing step has
# run, so this also executes right after a fresh parse and the rest of the
# script always works from the on-disk copies.
# NOTE: data_total reloaded here still contains NaNs (data_original.npy is
# written before nan_to_num), whereas solution.npy / coordinates.npy were
# written after NaNs were replaced by 0.
if os.path.isfile(f'{EXP_PATH}/data_original.npy'):
    out = torch.from_numpy(np.load(f'{EXP_PATH}/solution.npy'))
    pts = torch.from_numpy(np.load(f'{EXP_PATH}/coordinates.npy'))
    data_total = torch.from_numpy(np.load(f'{EXP_PATH}/data_original.npy'))


# Split the grid data at init_t. Exact float comparison is intentional: the
# time column holds values rounded with np.round(..., decimals=2), so init_t
# matches exactly when it lies on the dt grid.
is_init = data_total[:,0] == init_t
is_after_init = data_total[:,0] > init_t
init_idx = torch.argwhere(is_init)
if init_idx.numel() == 0:
    # Without this check init_idx.max() fails with an opaque reduction error.
    raise ValueError(f'init_t={init_t} is not one of the time values in the data')
print(f'Final index for initial conditions: {init_idx.max()}')

# Remove everything before the initial time and shift time so that init_t
# maps to t = 0 (boolean indexing returns copies, data_total is untouched).
data_pde = data_total[is_after_init]
data_init = data_total[is_init]
data_pde[:,0] = data_pde[:,0] - init_t
data_init[:,0] = data_init[:,0] - init_t

# Get the initial vectors of values
init_pts = data_init[:,:in_dim]
init_out = data_init[:,in_dim:(in_dim+out_dim)]

# Remove the points in the cylinder: centre (1/5, 1/5), radius 1/20, plus a
# margin of one grid step dx.
outside_cylinder = torch.sqrt((init_pts[:,1]-1/5.)**2 + (init_pts[:,2]-1/5.)**2) > 1/20.+dx
init_pts = init_pts[outside_cylinder]
init_out = init_out[outside_cylinder]

print('Saving the initial values...')
init_data = TensorDataset(init_pts, init_out)
torch.save(init_data, f'{EXP_PATH}/init_data.pth')



print('Generation of the empirical dataset...')
t_max_data = t_max - init_t
# Generate the random (t, x, y) coordinates, staying one derivative step away
# from the grid boundaries so the shifted points used by the finite
# differences stay inside the interpolation grid. The sampling order
# (x, then y, then t) fixes how the seeded torch RNG stream is consumed.
points_x = torch.distributions.Uniform(x_min+der_dx, x_max-der_dx).sample((N,1))
points_y = torch.distributions.Uniform(y_min+der_dy, y_max-der_dy).sample((N,1))
points_t = torch.distributions.Uniform(t_min+der_dt, t_max-der_dt).sample((N,1))
pts_eval = torch.column_stack((points_t, points_x, points_y))

# Keep only samples outside the cylinder (centre (1/5, 1/5), radius 1/20 plus
# a margin of one grid step dx) and strictly after init_t. This filtering
# happens BEFORE interpolation. Each query point is interpolated
# independently, so the kept points get the same values. The costly spline
# evaluations for discarded samples are skipped; with t sampled over
# [0, t_max], that is every sample with t <= init_t (about 80% for
# init_t=8, t_max=10).
outside_cylinder = torch.sqrt((pts_eval[:,1]-1/5.)**2 + (pts_eval[:,2]-1/5.)**2) > 1/20.+dx
after_init = pts_eval[:,0] > init_t
pts_eval = pts_eval[outside_cylinder & after_init]
n_eval = pts_eval.shape[0]

# Shifters for the derivatives: one float64 offset row per point, holding the
# step in the column of the differentiated coordinate (columns are t, x, y).
dt_vec = np.zeros((n_eval, 3))
dt_vec[:, 0] = der_dt
dx_vec = np.zeros((n_eval, 3))
dx_vec[:, 1] = der_dx
dy_vec = np.zeros((n_eval, 3))
dy_vec[:, 2] = der_dy

# Interpolation grid; its (t, x, y) ordering (y fastest) must match the order
# in which the solution rows were stored.
t_vec = np.arange(start=t_min, stop=t_max+dt, step=dt)
x_vec = np.arange(start=x_min, stop=x_max+dx, step=dx)
y_vec = np.arange(start=y_min, stop=y_max+dy, step=dy)
grid = (t_vec, x_vec, y_vec)
grid_shape = (len(t_vec), len(x_vec), len(y_vec))
# np.arange with a float step can gain or lose an endpoint through rounding;
# fail with a clear message instead of an opaque reshape error.
if int(np.prod(grid_shape)) != out.shape[0]:
    raise ValueError(
        f'Interpolation grid {grid_shape} does not match the '
        f'{out.shape[0]} rows of solution data'
    )

# Interpolation
from scipy.interpolate import RegularGridInterpolator
print('Generating the interpolators, using the following interpolation method:', interp_type)
u_interp = RegularGridInterpolator(grid, out[:,0].reshape(grid_shape).numpy(), method=interp_type)
v_interp = RegularGridInterpolator(grid, out[:,1].reshape(grid_shape).numpy(), method=interp_type)
p_interp = RegularGridInterpolator(grid, out[:,2].reshape(grid_shape).numpy(), method=interp_type)
print('Created interpolators!')


def _finite_difference(interp, points, shift, step):
    """Finite-difference derivative of ``interp`` at ``points`` along ``shift``.

    ``shift`` holds the offset ``step`` in the column of the differentiated
    coordinate. Uses a centered difference when ``der_type == 'centered'``
    and a forward difference otherwise. Returns an (n, 1) tensor.
    """
    if der_type == 'centered':
        diff = (interp(points + shift) - interp(points - shift)) / (2*step)
    else:
        diff = (interp(points + shift) - interp(points)) / step
    return torch.tensor(diff).reshape((-1, 1))


print('Calculating the derivatives, using the following derivative type:', der_type)
u_x = _finite_difference(u_interp, pts_eval, dx_vec, der_dx)
u_y = _finite_difference(u_interp, pts_eval, dy_vec, der_dy)
u_t = _finite_difference(u_interp, pts_eval, dt_vec, der_dt)
print('U derivatives calculated!')
v_x = _finite_difference(v_interp, pts_eval, dx_vec, der_dx)
v_y = _finite_difference(v_interp, pts_eval, dy_vec, der_dy)
v_t = _finite_difference(v_interp, pts_eval, dt_vec, der_dt)
print('V derivatives calculated!')
p_x = _finite_difference(p_interp, pts_eval, dx_vec, der_dx)
p_y = _finite_difference(p_interp, pts_eval, dy_vec, der_dy)
p_t = _finite_difference(p_interp, pts_eval, dt_vec, der_dt)
print('P derivatives calculated!')

print('Calculating the solution values...')
out_eval = torch.column_stack(
    [torch.tensor(interp(pts_eval)).reshape((-1, 1)) for interp in (u_interp, v_interp, p_interp)]
)
print('Solution values calculated!')
# Stack the derivatives into (n, 3, 3): rows (u, v, p), columns (t, x, y)
Dout_eval = torch.column_stack((u_t, u_x, u_y, v_t, v_x, v_y, p_t, p_x, p_y)).reshape((-1,3,3))

# Shift time so that init_t maps to t = 0, matching the initial-condition data
# (done only after all interpolator evaluations, which use absolute time).
pts_eval[:,0] = pts_eval[:,0] - init_t

print(pts_eval[:,0])

print('Saving the data...')

print(f'pts_eval.shape: {pts_eval.shape}')
print(f'out_eval.shape: {out_eval.shape}')
print(f'Dout_eval.shape: {Dout_eval.shape}')

pde_data = TensorDataset(pts_eval, out_eval, Dout_eval)
torch.save(pde_data, f'{EXP_PATH}/pde_data_{interp_type}.pth')

print('Generating the restricted data...')
# The restricted datasets keep the part of the domain beyond
# x = restriction_x and shift x so that this cut maps to x = 0. Grid-based
# initial data use x >= restriction_x, random PDE samples use
# x > restriction_x, and the boundary data are the grid points exactly on the
# cut. Boolean indexing returns copies, so the in-place shifts below do not
# modify the source tensors.

# Restriction of the initial conditions
in_restriction = init_pts[:,1] >= restriction_x
restricted_init_pts = init_pts[in_restriction]
restricted_init_pts[:,1] = restricted_init_pts[:,1] - restriction_x
restricted_init_out = init_out[in_restriction]
restricted_init_data = TensorDataset(restricted_init_pts, restricted_init_out)

# Restriction of the empirical data
in_restriction = pts_eval[:,1] > restriction_x
restricted_pts_eval = pts_eval[in_restriction]
restricted_pts_eval[:,1] = restricted_pts_eval[:,1] - restriction_x
restricted_out_eval = out_eval[in_restriction]
restricted_Dout_eval = Dout_eval[in_restriction]
restricted_pde_data = TensorDataset(restricted_pts_eval, restricted_out_eval, restricted_Dout_eval)

pde_pts = data_pde[:,:in_dim]
pde_out = data_pde[:,in_dim:(in_dim+out_dim)]

# Boundary conditions on the cut. Exact float comparison works because grid
# x values are rounded to 2 decimals when the raw data is parsed.
on_cut = pde_pts[:,1] == restriction_x
restricted_bc_pts = pde_pts[on_cut]
restricted_bc_pts[:,1] = restricted_bc_pts[:,1] - restriction_x
restricted_bc_out = pde_out[on_cut]
restricted_bc_data = TensorDataset(restricted_bc_pts, restricted_bc_out)

os.makedirs(REST_PATH, exist_ok=True)

torch.save(restricted_pde_data, f'{REST_PATH}/pde_data_{interp_type}.pth')
torch.save(restricted_init_data, f'{REST_PATH}/init_data.pth')
torch.save(restricted_bc_data, f'{REST_PATH}/bc_data.pth')

print('Finished!')
