import pickle
from pathlib import Path

import numpy as np
from scipy.integrate import odeint
from scipy.fftpack import diff as psdiff

# Seed NumPy's global RNG so the random speed perturbations drawn with
# np.random.randn in the generation loop are reproducible between runs.
np.random.seed(2022)
def kdv_exact(x, c):
    """Evaluate the initial-condition profile ``0.5*c*cos(0.5*sqrt(c)*x)``.

    Parameters
    ----------
    x : array_like
        Positions at which to evaluate the profile.
    c : float or array_like
        Speed parameter; it sets both the amplitude (``c/2``) and the
        wavenumber (``sqrt(c)/2``). Must broadcast against ``x``. A negative
        ``c`` yields NaN because of the square root.

    Returns
    -------
    numpy.ndarray
        The profile values, broadcast over ``x`` and ``c``.

    NOTE(review): despite the name, this is NOT the exact KdV soliton, which
    is ``0.5*c*sech(0.5*sqrt(c)*x)**2``. Confirm that the cosine profile is
    intentional before relying on the name.
    """
    wavenumber = 0.5 * np.sqrt(c)
    amplitude = 0.5 * c
    return amplitude * np.cos(wavenumber * x)

def kdv(u, t, L):
    """Right-hand side of the spatially discretized KdV system.

    Returns ``du/dt = -6 * u * u_x - u_xxx``, with the spatial derivatives
    computed pseudo-spectrally by ``scipy.fftpack.diff`` over a period of
    length ``L``. The signature ``(u, t, L)`` matches what ``odeint`` passes
    when called with ``args=(L,)``. ``t`` is not used because the right-hand
    side does not depend on time explicitly.

    NOTE(review): spectral differentiation assumes ``u`` samples exactly one
    period on a uniform, endpoint-excluded grid. Confirm at the call site.
    """
    first = psdiff(u, period=L)
    third = psdiff(u, order=3, period=L)

    # Nonlinear advection term; the sign is applied below.
    advection = 6 * u * first
    return -advection - third

def kdv_solution(u0, t, L):
    """Integrate the semi-discrete KdV system in time with ``odeint``.

    Parameters
    ----------
    u0 : array_like
        Initial values of u on the spatial grid.
    t : array_like
        Times at which the solution is requested; ``t[0]`` is the initial time.
    L : float
        Length of the periodic domain, forwarded to ``kdv``.

    Returns
    -------
    numpy.ndarray
        The ``odeint`` result: row ``i`` holds u at time ``t[i]``, and the first
        row is ``u0``.
    """
    rhs_extra_args = (L,)
    # mxstep raises odeint's limit on internal steps per output interval
    # (presumably needed because the u_xxx term makes the system stiff).
    return odeint(kdv, u0, t, args=rhs_extra_args, mxstep=5000)

# Periodic spatial domain [0, L) sampled at N uniformly spaced points. The
# endpoint x = L is excluded because it coincides with x = 0.
L = 50.0 * 16
N = 1024
# Spacing of the grid below: it has N points with step L / N. The previous
# L / (N - 1) is the spacing of an endpoint-inclusive grid, which this is not.
dx = L / N
x = np.linspace(0, (1-1.0/N)*L, N)

# Time samples stored for every trajectory. These are identical on every
# iteration, so they are built once outside the loop.
T = 1
t = np.linspace(0, T, 100)

# Draw random speed pairs until N_SAMPLES finite trajectories have been
# collected or MAX_TRIALS draws have been made.
N_SAMPLES = 1200
MAX_TRIALS = 10000

sols = []
a = []
for _ in range(MAX_TRIALS):
    # Superpose two kdv_exact profiles with randomly perturbed speeds. This is
    # not an exact two-soliton state on a periodic domain; kdv_exact evaluates a
    # cosine profile. Keep the draw order (c1, then c2) so the seeded dataset
    # stays reproducible.
    c1 = np.random.randn(1) * 0.1
    c2 = np.random.randn(1) * 0.2
    u0 = kdv_exact(x-0.33*L, (0.75+c1)*0.5) + kdv_exact(x-0.65*L, (0.4+c2)*0.5)

    # A negative speed makes sqrt(c) NaN inside kdv_exact. Such a sample would be
    # rejected after integration anyway, so skip the expensive odeint call.
    if not np.isfinite(u0).all():
        continue

    sol = kdv_solution(u0, t, L)
    # Reject trajectories that blew up (NaN or +/-inf anywhere).
    if np.isfinite(sol).all():
        sols.append(sol)
        a.append(u0)
    if len(sols) >= N_SAMPLES:
        break

if not sols:
    raise RuntimeError(
        f"no finite KdV trajectories were produced in {MAX_TRIALS} trials"
    )

# Each odeint result is (len(t), N), so after stacking sols is
# (n_samples, len(t), N). The transpose reorders it to (n_samples, N, len(t)).
sols = np.stack(sols, axis=0)
a = np.stack(a, axis=0)
sols = sols.transpose(0,-1,1)

data = {'a': a, 'u': sols}
path = '/usr/commondata/public/Neural_Dynamics/CTmixer/dataset/kdv_equation/dataset_sr1024.pkl'
# Create the output directory up front so a long generation run does not fail
# at the final write.
Path(path).parent.mkdir(parents=True, exist_ok=True)
with open(path, "wb") as f:
    pickle.dump(data, f, protocol=4)