import torch
from torch.utils.data import TensorDataset
# Number of coordinate columns and of solution columns in every sample
in_dim, out_dim = 3, 3
# Total number of columns per sample: coordinates, solution values and the
# in_dim * out_dim remaining columns (stored later as the derivatives)
data_len = in_dim + out_dim + in_dim * out_dim
# Default start time
init_t = 0.
# Container for the parsed samples
data_total = []
import numpy as np

import random
import os

# One seed shared by every random number generator the script touches
seed = 30
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Stand-alone torch generator initialised with the same seed
gen = torch.Generator().manual_seed(seed)

# Grid along x: lower bound, upper bound and spacing
x_min, x_max, dx = 0., 2.2, 0.01
# Grid along y: lower bound, upper bound and spacing
y_min, y_max, dy = 0., 0.41, 0.01
# Grid in time: lower bound, upper bound and spacing
t_min, t_max, dt = 0., 10., 0.05

# Time at which the experiment starts
init_t = 8.0

# The restricted domain starts at this x
restriction_x = 0.5

# Grid position used while loading the files, starting at the lower bounds
t_now, x_now, y_now = t_min, x_min, y_min

# Directories where the data is saved and name of the raw input file
EXP_PATH, REST_PATH = 'NS_true', 'NS_true_restricted'
file_name = 'NS_base.txt'

# Create the output directory when it is missing. exist_ok=True replaces the
# separate exists() check, so there is no window between checking and creating.
os.makedirs(EXP_PATH, exist_ok=True)

# Files written by the preprocessing step; every one of them is read back
# by the loading step that follows.
_cache_stems = ('data_original', 'data_clean', 'coordinates', 'solution', 'derivatives')

# Skip the preprocessing only when the whole cache is present, so that an
# interrupted earlier run cannot leave a partial cache behind.
# NOTE: a derivatives.npy written by an earlier version of this script may
# contain the solution values instead of the derivatives; delete it to force
# a regeneration.
if not all(os.path.isfile(f'{EXP_PATH}/{stem}.npy') for stem in _cache_stems):
    print('Processing the file...')
    records = []
    with open(file_name, 'rb') as f:
        count = data_len
        data_point = []

        # A record spans data_len lines. On its first line the current grid
        # position (t, x, y) is stored; the first in_dim lines themselves are
        # not parsed (presumably they hold the file's own coordinates - TODO
        # confirm); each remaining line holds one value.
        for line in f:
            if count == data_len:
                data_point.extend((t_now, x_now, y_now))
            elif count <= data_len - in_dim:
                # '*^' is rewritten to 'e' for float(); 'Indeterminate'
                # becomes nan. Stripping first makes this independent of the
                # line ending.
                text = line.decode().strip()
                data_point.append(float(text.replace('*^', 'e').replace('Indeterminate', 'nan')))

            count -= 1
            # When the count is 0 the record is complete: store it and move
            # to the next grid position (y fastest, then x, then t)
            if count == 0:
                records.append(data_point)
                data_point = []
                count = data_len
                y_now = np.round(y_now + dy, decimals=2)
                if y_now > y_max:
                    y_now = y_min
                    x_now = np.round(x_now + dx, decimals=2)
                if x_now > x_max:
                    x_now = x_min
                    t_now = np.round(t_now + dt, decimals=2)

    if not records:
        raise ValueError(f'{file_name} does not contain a complete record')

    # One (n_records, data_len) float64 tensor built in a single conversion
    data_total = torch.from_numpy(np.array(records, dtype=np.float64))

    # Save the data with nans
    np.save(f'{EXP_PATH}/data_original.npy', data_total.numpy())

    # Drop every row that contains a nan and save the result
    data_clean = data_total[~torch.any(data_total.isnan(), dim=1)]
    np.save(f'{EXP_PATH}/data_clean.npy', data_clean.numpy())

    print('Processing finished!')
    print(f'data shape: {data_total.shape}')

    # Save the coordinates of the points
    pts = data_clean[:, :in_dim]
    np.save(f'{EXP_PATH}/coordinates.npy', pts.numpy())

    # Save the function values
    out = data_clean[:, in_dim:(in_dim + out_dim)]
    np.save(f'{EXP_PATH}/solution.npy', out.numpy())

    # Save the derivatives (the remaining columns)
    Dout = data_clean[:, (in_dim + out_dim):]
    np.save(f'{EXP_PATH}/derivatives.npy', Dout.numpy())

# When the cache is present, read every saved array back as a torch tensor
if os.path.isfile(f'{EXP_PATH}/data_original.npy'):
    Dout, out, pts, data_total, data_clean = (
        torch.from_numpy(np.load(f'{EXP_PATH}/{stem}.npy'))
        for stem in ('derivatives', 'solution', 'coordinates', 'data_original', 'data_clean')
    )

# Locate the rows at the initial time and the rows strictly after it
# (column 0 holds the time)
time_column = data_total[:, 0]
init_idx = torch.argwhere(time_column == init_t)
non_init_idx = torch.argwhere(time_column > init_t)
print(f'Final index for initial conditions: {init_idx.max()}')

# Keep only those rows (indexing with an index tensor returns copies) and
# shift the time so that the initial time becomes 0
data_pde = data_total[non_init_idx.flatten()]
data_init = data_total[init_idx.flatten()]
data_pde[:, 0] -= init_t
data_init[:, 0] -= init_t

# Split the initial rows into coordinates, solution values and derivatives
init_sol_end = in_dim + out_dim
init_pts = data_init[:, :in_dim]
init_out = data_init[:, in_dim:init_sol_end]
init_Dout = data_init[:, init_sol_end:]

# Remove the points that are in the cylinder: keep a point only when its
# distance from (x, y) = (1/5, 1/5) exceeds 1/20 + dx
init_cyl_dist = torch.sqrt((init_pts[:, 1] - 1/5.)**2 + (init_pts[:, 2] - 1/5.)**2)
init_ok_idx = torch.argwhere(init_cyl_dist > 1/20. + dx)
init_keep = init_ok_idx.reshape((-1))
init_pts, init_out, init_Dout = init_pts[init_keep], init_out[init_keep], init_Dout[init_keep]

print('Saving the initial values...')
init_data = TensorDataset(init_pts, init_out, init_Dout)
torch.save(init_data, f'{EXP_PATH}/init_data.pth')
print('Finished!')



# Split the rows after the initial time into coordinates, solution values
# and derivatives
pde_pts = data_pde[:, :in_dim]
pde_out = data_pde[:, in_dim:(in_dim + out_dim)]
pde_Dout = data_pde[:, (in_dim + out_dim):]

# Remove the points that are in the cylinder: keep a point only when its
# distance from (x, y) = (1/5, 1/5) exceeds 1/20 + dx
pde_ok_idx = torch.argwhere(torch.sqrt((pde_pts[:, 1] - 1/5.)**2 + (pde_pts[:, 2] - 1/5.)**2) > 1/20. + dx)
pde_keep = pde_ok_idx.reshape((-1))
pde_pts = pde_pts[pde_keep]
pde_out = pde_out[pde_keep]
pde_Dout = pde_Dout[pde_keep]

print('Saving the PDE values...')
# The dataset is built once, right before it is saved
pde_data = TensorDataset(pde_pts, pde_out, pde_Dout)
torch.save(pde_data, f'{EXP_PATH}/pde_data.pth')
print('Finished!')


print('Generating the restricted data...')
# Restriction of the initial conditions: keep x >= restriction_x and shift x
# so that the restricted domain starts at 0
# NOTE(review): the initial points use >= while the PDE points below use >,
# so PDE points at x == restriction_x are left out here - confirm intended.
restricted_init_indexes = torch.argwhere(init_pts[:, 1] >= restriction_x)
init_rows = restricted_init_indexes.reshape((-1))
restricted_init_pts = init_pts[init_rows]
restricted_init_pts[:, 1] -= restriction_x
restricted_init_out = init_out[init_rows]
restricted_init_Dout = init_Dout[init_rows]
restricted_init_data = TensorDataset(restricted_init_pts, restricted_init_out, restricted_init_Dout)

# Restriction of the empirical data: keep x > restriction_x and apply the
# same shift in x
restricted_indexes = torch.argwhere(pde_pts[:, 1] > restriction_x)
pde_rows = restricted_indexes.reshape((-1))
restricted_pts_eval = pde_pts[pde_rows]
restricted_pts_eval[:, 1] -= restriction_x
restricted_out_eval = pde_out[pde_rows]
restricted_Dout_eval = pde_Dout[pde_rows]
restricted_pde_data = TensorDataset(restricted_pts_eval, restricted_out_eval, restricted_Dout_eval)

# We need boundary conditions as well.
# A point belongs to the boundary of the restricted domain when it lies on
# y == y_min, y == y_max, x == x_max or x == restriction_x, and is not to the
# left of the restriction. The grid constants are used instead of repeating
# their literal values, so the selection follows the grid definition.
bc_x = pde_pts[:, 1]
bc_y = pde_pts[:, 2]
on_y_edge = (bc_y == y_min) | (bc_y == y_max)
on_x_edge = (bc_x == x_max) | (bc_x == restriction_x)
restricted_bc_indexes = torch.argwhere((on_y_edge | on_x_edge) & (bc_x >= restriction_x))
bc_rows = restricted_bc_indexes.reshape((-1))

# Indexing with an index tensor returns a copy, so shifting x here leaves
# pde_pts untouched
restricted_bc_pts = pde_pts[bc_rows]
restricted_bc_pts[:, 1] = restricted_bc_pts[:, 1] - restriction_x
restricted_bc_out = pde_out[bc_rows]
restricted_bc_Dout = pde_Dout[bc_rows]
restricted_bc_data = TensorDataset(restricted_bc_pts, restricted_bc_out, restricted_bc_Dout)

print(restricted_bc_pts.shape)

# Create the directory for the restricted data when it is missing.
# exist_ok=True replaces the separate exists() check, so there is no window
# between checking and creating.
os.makedirs(REST_PATH, exist_ok=True)

# Save the restricted PDE, initial-condition and boundary datasets
torch.save(restricted_pde_data, f'{REST_PATH}/pde_data.pth')
torch.save(restricted_init_data, f'{REST_PATH}/init_data.pth')
torch.save(restricted_bc_data, f'{REST_PATH}/bc_data.pth')

print('Finished!')