import torch
from math import pi, sqrt
import numpy as np
torch.set_default_dtype(torch.float64)

from integrator.edtrk4 import EDTRK4

h = 0.02  # solver step size
# Sampled frames are split round-robin into `nstart` interleaved sequences
# when the data is saved.
nstart = 10
# A frame is sampled every `frame_per_start` solver steps ...
frame_per_start = 10
# ... so each of the `nstart` sequences gets one sample every `nplt` steps.
nplt = frame_per_start * nstart
tstart = 500 * h * nplt  # frames with tstart < t < ttrain form the training set
ttrain = 1500 * h * nplt  # frames with t > ttrain form the test set
tmax = 1700 * h * nplt  # simulate over the time interval [0, tmax]
N = 48  # number of grid points
M = 16  # passed to EDTRK4; its meaning is defined by the solver

# NOTE(review): `seed` is not used in this script as shown — confirm whether
# anything should be seeded with it.
seed = 123

# The train and test files differ only in their directory; build the shared
# file name once so the two cannot drift apart.
_data_stem = (
    f'N{N}_h{h:.3f}_coarse{nplt}_nstart{nstart}'
    f'_train_{int(tstart)}_{int(ttrain)}_test{int(tmax)}.npz'
)
train_data_name = f'train_data/{_data_stem}'
test_data_name = f'test_data/{_data_stem}'

# Domain length: the grid points xs built below span (0, L].
L = 2 * pi / sqrt(0.085)

solver = EDTRK4(L, N, M, h)

# Filled while the solver runs: cloned `solver.u` snapshots and their matching
# `solver.t` values, split into a training set and a test set.
train_data, test_data, train_times, test_times = [], [], [], []

def callback(solver: EDTRK4, train_arr: list, test_arr: list):
    """Record the solver state on every `frame_per_start`-th step.

    A snapshot taken with ``tstart + 1e-6 < t < ttrain + 1e-6`` is cloned into
    ``train_arr`` and its time appended to the module-level ``train_times``.
    A snapshot with ``t > ttrain + 1e-6`` goes to ``test_arr`` / ``test_times``.
    On every sampled step a carriage-return progress line is printed.
    """
    # Only act on sampling steps.
    if solver.n_step % frame_per_start != 0:
        return

    t = solver.t
    if tstart + 1E-6 < t < ttrain + 1E-6:
        if not train_arr:
            print(f'first entry into training set at time {solver.t}')
        train_arr.append(solver.u.clone())
        train_times.append(t)
    elif t > ttrain + 1E-6:
        if not test_arr:
            print(f'\nTraining set sampled, first entry into test set at time {solver.t}')
        test_arr.append(solver.u.clone())
        test_times.append(t)

    print(f'time = {solver.t :.4f} / {tmax :.4f}', end = '\r')
    
# N grid points at L/N, 2L/N, ..., L.
xs = (torch.arange(1, N + 1) / N) * L

# Initial condition: a single cosine period across the domain.
ic = torch.cos((2 * pi / L) * xs)

solver.set_ic(ic)

# NOTE(review): simulate presumably forwards train_arr/test_arr to `callback`
# (its signature matches) — confirm in EDTRK4.simulate. Keyword order is kept
# as-is in case simulate relies on it.
solver.simulate(
    T=tmax,
    callback=callback,
    train_arr=train_data,
    test_arr=test_data,
)

def _save_split(path, frames, frame_times):
    """Stack one split's sampled snapshots and write them to ``path`` (.npz).

    ``frames`` are the cloned ``solver.u`` snapshots in sampling order and
    ``frame_times`` the matching ``solver.t`` values. Frame ``k`` is assigned
    to interleaved sequence ``k % nstart`` at position ``k // nstart``.

    Saved keys:
      data        -- shape (nstart, nframes // nstart, N)
      times       -- the flat list of sample times in sampling order
                     (unchanged, for existing readers)
      frame_times -- shape (nstart, nframes // nstart), aligned element-wise
                     with the first two axes of ``data``
      xs, dt, N, M, L -- grid and run parameters

    Raises:
        ValueError: if no frames were collected, or if their count is not a
            multiple of ``nstart`` (the reshape would otherwise fail with an
            opaque error).
    """
    from pathlib import Path

    if not frames:
        raise ValueError(f'no frames collected for {path}')
    if len(frames) % nstart != 0:
        raise ValueError(
            f'{len(frames)} frames collected for {path}, '
            f'not a multiple of nstart={nstart}'
        )

    # Stacking along dim=1 puts the frame index on axis 1. This assumes each
    # snapshot is a length-N 1-D tensor, as the reshape(N, ...) requires —
    # TODO confirm against EDTRK4.
    stacked = torch.stack(frames, dim=1).cpu().numpy()
    # Split the frame axis into (k // nstart, k % nstart), then move the
    # sequence index to the front.
    data = np.transpose(stacked.reshape(N, -1, nstart), (2, 1, 0))
    # Apply the same split to the sample times so frame_times[j, i] labels
    # data[j, i]. (Previously the test split was built from the training times
    # and the result was discarded.)
    aligned_times = np.transpose(
        np.array(frame_times).reshape(-1, nstart), (1, 0)
    )

    # Create the output directory so a long run is not lost at the final save.
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    np.savez(
        path,
        times=frame_times,
        frame_times=aligned_times,
        data=data,
        xs=xs.cpu().numpy(),
        dt=h * nplt,
        N=N,
        M=M,
        L=L,
    )


if test_times:
    print(f'\nlast entry into test set at time {test_times[-1]}')
else:
    # Still emit a newline so the '\r' progress line is not overwritten.
    print('\nno entries collected into test set')
print(f'train data size {len(train_data)}')
print(f'test data size {len(test_data)}')

# Save the training set first, so it survives even if the test split turns
# out to be empty or malformed.
_save_split(train_data_name, train_data, train_times)
_save_split(test_data_name, test_data, test_times)