## Generate train/test datasets for the elastic pendulum (EP) with RK4.
## Each path is integrated with a fine step dt; a snapshot is stored every `coarse` steps.

import os

import torch
import numpy.random as rand
import numpy as np
from numpy import pi

from pde.ep import EP
from integrator.RK4 import RK4


torch.set_default_dtype(torch.float64)


dt = 0.002  # fine RK4 step used to generate the exact reference solution
coarse = 150  # a snapshot is stored every `coarse` steps, so the dataset dt is dt * coarse
n_step = 6000  # total RK4 steps per path; final time is dt * n_step
n_train_path = 128
n_test_path = 32
n = 4  # state size per path (length of the vector returned by sample_ic)

# Parameters passed to EP; l0 also scales the sampled radius and radial velocity in sample_ic.
l0 = 2
g = 9.801
k_over_mg = 1

seed = 1234
# Seed the global RNG so the per-path IC seeds, and therefore both datasets, are
# reproducible. Previously `seed` was defined but never used.
rand.seed(seed)
train_ic_seeds = rand.randint(low = 0, high = 10000000 * n_train_path, size = n_train_path)
test_ic_seeds = rand.randint(low = 0, high = 10000000 * n_test_path, size = n_test_path)

train_fname = f'train_data/EP_dt{dt:.3f}_nstep{n_step}_coarse{coarse}_path{n_train_path}.npz'
test_fname = f'test_data/EP_dt{dt:.3f}_nstep{n_step}_coarse{coarse}_path{n_test_path}.npz'


pde = EP(g = g, k_over_mg = k_over_mg, l0 = l0)
# Fall back to CPU instead of crashing on machines without CUDA.
# NOTE(review): assumes RK4 accepts any torch device string — confirm.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
solver = RK4(pde, device = device)

def callback(solver: RK4, arr: list):
    """Per-step hook for ``solver.simulate``: keep every ``coarse``-th state.

    On steps where ``solver.n_step`` is a multiple of ``coarse``, print the
    progress in place and append a copy of ``solver.x`` to ``arr``.
    On all other steps it does nothing.
    """
    if solver.n_step % coarse:
        return
    total_time = dt * n_step
    print(f'time = {solver.t :.4f} / {total_time :.4f}', end = '\r')
    # Clone so later in-place solver updates do not alter the stored snapshot.
    arr.append(solver.x.clone())


def sample_ic(seed):
    """Draw one random initial condition.

    Args:
        seed: integer seed for a private numpy ``RandomState``, or ``None`` to
            draw from the global numpy RNG.

    Returns:
        A 1-D tensor ``[phi0, r0 * l0, phit, rt * l0]``:
        - ``phi0`` is uniform in [-pi/8, pi/8].
        - ``r0`` is uniform in [3/4, 5/4].
        - ``phit`` is normal with std 0.1*pi.
        - ``rt`` is normal with std 0.1.
        Presumably angle, radius, angular velocity and radial velocity in the
        order EP expects — confirm against pde.ep.EP.
    """
    # A private RandomState seeded with `seed` yields exactly the same draws as
    # `rand.seed(seed)` followed by global draws, without clobbering the
    # process-wide numpy RNG state as a hidden side effect.
    rng = rand if seed is None else rand.RandomState(seed)
    r0 = rng.uniform(3/4, 5/4)
    rt = rng.normal(scale = 0.1)  # dr/dt = rt * l0
    phi0 = rng.uniform(-pi/8, pi/8)
    phit = rng.normal(scale = 0.1 * pi)
    return torch.tensor([phi0, r0 * l0, phit, rt * l0])
    
    
def _simulate_and_save(seeds, fname, arr):
    """Integrate one path per seed with the module-level solver and save the snapshots.

    Args:
        seeds: per-path seeds passed to ``sample_ic``.
        fname: output ``.npz`` path; its parent directory is created if missing.
        arr: list that ``callback`` fills with ``solver.x`` snapshots.

    Returns:
        The array saved under the key ``data``. Its layout is (path, snapshot, state),
        assuming ``solver.x`` keeps the (n, path) layout of the initial
        condition — TODO confirm in RK4.
    """
    # One column per path.
    ic = torch.zeros((n, len(seeds)))
    for i, s in enumerate(seeds):
        ic[:, i] = sample_ic(s)
    solver.set_ic(ic)
    solver.simulate(dt, n_step = n_step, callback = callback, arr = arr)
    # Stack to (snapshot, state, path), then permute to (path, snapshot, state).
    out = torch.stack(arr).permute(2, 0, 1).cpu().numpy()
    # Create the output directory up front so a long simulation is not lost
    # to a missing folder at save time.
    out_dir = os.path.dirname(fname)
    if out_dir:
        os.makedirs(out_dir, exist_ok = True)
    np.savez(fname, data = out, dt = dt * coarse, g = g, l0 = l0, k_over_mg = k_over_mg)
    return out


train_data = []
data = _simulate_and_save(train_ic_seeds, train_fname, train_data)

print('\nTraining paths sampled.')


# test data
test_data = []
data = _simulate_and_save(test_ic_seeds, test_fname, test_data)

print('\nTest paths sampled, done')


