## Generate training/test datasets for the viscous Burgers equation using RK4 time stepping.
## Burgers: Ut + U Ux = v Uxx

import os

import numpy as np
import numpy.random as rand
import torch

from pde.burgers import Burgers
from integrator.RK4 import RK4

from sample_ic import sample_trig_ic


# Every tensor created below (grid, initial conditions) is double precision.
torch.set_default_dtype(torch.float64)


L = 0  # left end of the spatial domain
R = 1  # right end of the spatial domain (excluded from the grid)
n = 96  # number of grid points; does not include boundary point R
nu = 0.01  # viscosity (the `v` in the header equation)

dt = 0.0001  # fine RK4 step, used to generate an accurate reference solution
coarse = 100  # every `coarse`-th state is saved, so saved data has step dt * coarse
n_step = 10000  # number of fine steps; final time is dt * n_step
n_train_path = 16
n_test_path = 4

# Draw the per-path initial-condition seeds from a generator seeded with `seed`
# so the datasets are reproducible run to run. Drawing all seeds in one call
# without replacement also guarantees the training and test sets share no seed.
seed = 123
_seed_rng = np.random.default_rng(seed)
_ic_seeds = _seed_rng.choice(10000000, size = n_train_path + n_test_path, replace = False)
train_ic_seeds = _ic_seeds[:n_train_path]
test_ic_seeds = _ic_seeds[n_train_path:]

suffix = 'gauss_debug'
train_fname = f'train_data/experiment/n{n}_nstep{n_step:d}_dt{dt:.5f}_nu{nu:.3f}_coarse{coarse}_path{n_train_path}_{suffix}.npz'
test_fname = f'test_data/experiment/n{n}_nstep{n_step:d}_dt{dt:.5f}_nu{nu:.3f}_coarse{coarse}_path{n_test_path}_{suffix}.npz'

# np.savez does not create missing parent directories; create them up front so
# saving after the (long) simulations cannot fail on a missing path.
for _fname in (train_fname, test_fname):
    os.makedirs(os.path.dirname(_fname) or '.', exist_ok = True)


# Uniform grid on [L, R) with spacing (R - L) / n: build n + 1 equispaced points
# including both ends, then drop x = R so exactly n points remain.
# NOTE(review): excluding R is what a periodic domain needs (x = R would duplicate
# x = L); confirm that Burgers treats the grid as periodic.
xs = torch.linspace(L, R, n+1)[:-1]


pde = Burgers(xs, nu = nu)
solver = RK4(pde, device = 'cuda')

def callback(solver: RK4, arr: list):
    """Record every `coarse`-th solver state into `arr` and print progress.

    Hook handed to ``solver.simulate`` together with the ``arr`` keyword
    (presumably invoked once per time step -- confirm in RK4). Reads the
    module-level settings ``coarse``, ``dt`` and ``n_step``.

    Args:
        solver: the running integrator; its ``n_step``, ``t`` and ``x``
            attributes are read.
        arr: list that accumulates cloned copies of ``solver.x``.
    """
    if solver.n_step % coarse != 0:
        return
    t_end = dt * n_step
    print(f'time = {solver.t :.4f} / {t_end :.4f}', end = '\r')
    # Clone so later in-place updates of solver.x do not alter the stored snapshot.
    arr.append(solver.x.clone())


# Training set: column `col` of `ic` holds the initial condition for path `col`;
# all paths are handed to the solver at once via set_ic.
train_data = []
ic = torch.zeros((n, n_train_path))
for col, ic_seed in enumerate(train_ic_seeds):
    ic[:, col] = sample_trig_ic(L, R, xs, ic_seed)

solver.set_ic(ic)
solver.simulate(dt, n_step = n_step, callback = callback, arr = train_data)

# Stack the recorded snapshots along a new leading (time) axis, then reorder to
# (path, time, space) -- assumes solver.x is (n, n_path) like `ic`; confirm in RK4.
snapshots = torch.stack(train_data)
data = np.array(snapshots.permute(2, 0, 1).cpu())
np.savez(train_fname, data = data, dt = dt * coarse, xs = xs, nu = nu)

print('\nTraining paths sampled.')


# Test set: same procedure as the training set, with its own initial-condition seeds.
test_data = []
ic = torch.zeros((n, n_test_path))
for col, ic_seed in enumerate(test_ic_seeds):
    ic[:, col] = sample_trig_ic(L, R, xs, ic_seed)

solver.set_ic(ic)
solver.simulate(dt, n_step = n_step, callback = callback, arr = test_data)

# Reorder the stacked snapshots to (path, time, space) -- assumes solver.x is
# (n, n_path) like `ic`; confirm in RK4.
snapshots = torch.stack(test_data)
data = np.array(snapshots.permute(2, 0, 1).cpu())
np.savez(test_fname, data = data, dt = dt * coarse, xs = xs, nu = nu)

print('\nTest paths sampled, done')

