from .dataPDE import *
import numpy as np
import os
from utils import to_pickle, from_pickle

# Discretisation / integration parameters shared by every function below.
M = 50          # number of spatial grid points
width = 10.     # spatial domain length; grid is width * arange(M) / M, i.e. [0, width)
dt = 0.001      # spacing between saved time samples in t_eval
a = -6.         # coefficient forwarded to KdV1d as `a` (presumably the nonlinear term — confirm in dataPDE)
b = 1.          # coefficient forwarded to KdV1d as `b` (presumably the dispersive term — confirm in dataPDE)

# Build the shared solver in float64, then restore the caller's default dtype
# even if construction raises, so importing this module never leaks a global
# torch dtype change.
default_type = torch.get_default_dtype()
torch.set_default_dtype(torch.float64)
try:
    kdv1d = KdV1d(width=width, ndiv=M, a=a, b=b, device='cpu')
finally:
    torch.set_default_dtype(default_type)


def sample_orbits(trials, n_steps, **kwargs):
    """Simulate ``trials`` KdV trajectories from random two-pulse initial states.

    Each initial condition is the sum of two sech^2 pulses (soliton-like
    profiles) with random parameters ``k1, k2`` in [0.5, 2.0) and centres at
    ``width * d1`` and ``width * d2``. The state is then circularly shifted by
    a random number of grid cells and mirrored with probability 1/2. Each
    state is integrated with ``kdv1d.dvdmint`` and sampled at
    ``t_eval = dt * [0, 1, ..., n_steps]``.

    Args:
        trials: number of trajectories to generate.
        n_steps: number of time steps per trajectory.
        **kwargs: accepted for interface compatibility; not used.

    Returns:
        dict with keys:
            'u'        -- stacked trajectories; leading axis indexes the trial,
                          trailing axes are (1, M).
            'dudt'     -- time derivative from ``kdv1d.dudt``, same shape as 'u'.
            'energies' -- dict returned by ``get_energies(u)``.
            'dt', 'dx', 'M', 't_eval' -- discretisation metadata.
            'meta'     -- ``{'n_steps': n_steps}``.
    """
    t_eval = np.arange(0, n_steps + 1) * dt
    # Spatial grid on [0, width); loop-invariant, so it is built only once.
    x = width * np.arange(M) / M

    def sech(z):
        return 1 / np.cosh(z)

    # (-6 / a) rescales the pulse height for the coefficient `a`
    # (it is exactly 1 for the default a = -6).
    scale = -6. / a

    u_results = []
    for i in range(trials):
        print('generating KdV dataset,', i, '/', trials, end='\r')
        # NOTE: the order of these RNG draws must not change — get_dataset
        # seeds NumPy's global RNG, and reordering would alter the dataset.
        k1, k2 = np.random.uniform(0.5, 2.0, 2)
        d1 = np.random.uniform(0.2, 0.3, 1)
        d2 = d1 + np.random.uniform(0.2, 0.5, 1)  # second pulse lies to the right of the first
        u0 = scale * 2 * k1**2 * sech(k1 * (x - width * d1))**2
        u0 = u0 + scale * 2 * k2**2 * sech(k2 * (x - width * d2))**2
        # Random circular shift (u0[shift:] followed by u0[:shift]).
        shift = np.random.randint(0, M)
        u0 = np.roll(u0, -shift)
        # Mirror with probability 1/2; copy() yields a contiguous array.
        if np.random.randint(0, 2) == 1:
            u0 = u0[::-1].copy()
        u_result = kdv1d.dvdmint(u0, t_eval)
        u_results.append(u_result.reshape(-1, 1, M))
    u_results = np.stack(u_results, axis=0)

    # Evaluate derivatives on a flat batch of (1, M) states, then restore layout.
    dudt = kdv1d.dudt(u_results.reshape(-1, 1, M)).reshape(u_results.shape)
    data = {
        'u': u_results,
        'dudt': dudt,
        'energies': get_energies(u_results),
        'dt': dt,
        'dx': kdv1d.dx,
        'M': M,
        't_eval': t_eval,
        'meta': {'n_steps': n_steps},
    }
    return data


def make_orbits_dataset(trials=100, test_trials=10, steps=500, test_steps=10000, **kwargs):
    """Build a training set and attach a longer-horizon test set under ``'test'``.

    The training trajectories are sampled first and the test trajectories
    second. This order determines how the global NumPy RNG stream is consumed.

    Args:
        trials: number of training trajectories.
        test_trials: number of test trajectories.
        steps: time steps per training trajectory.
        test_steps: time steps per test trajectory.
        **kwargs: forwarded unchanged to ``sample_orbits``.

    Returns:
        The training-data dict from ``sample_orbits``, with the test-data dict
        stored under the key ``'test'``.
    """
    train_split = sample_orbits(trials=trials, n_steps=steps, **kwargs)
    train_split['test'] = sample_orbits(trials=test_trials, n_steps=test_steps, **kwargs)
    return train_split


def get_energies(state):
    """Compute ``kdv1d.get_energy`` for every state in a batch.

    ``state`` is flattened to a batch of ``(1, M)`` states. The per-state
    energies are then reshaped to ``state.shape[:-2]``, i.e. the input shape
    without its two trailing axes.

    Returns:
        ``{'energy': energies}``.
    """
    batch_shape = state.shape[:-2]
    flat_states = state.reshape(-1, 1, M)
    energies = kdv1d.get_energy(flat_states).reshape(batch_shape)
    return {'energy': energies}


def get_dataset(name, save_dir, **kwargs):
    '''Return a KdV dataset, building and caching it if no saved copy loads.

    Looks for ``{save_dir}/dataset/{name}-dataset.pkl``. If it cannot be loaded
    for any ordinary error, the global NumPy RNG is seeded with 100. The
    dataset is then regenerated with ``make_orbits_dataset(**kwargs)`` and
    written back to that path.

    Side effect: sets torch's global default dtype to float64 (not restored).

    Args:
        name: dataset name used in the cache filename.
        save_dir: directory under which a ``dataset/`` subfolder is created.
        **kwargs: forwarded to ``make_orbits_dataset`` when rebuilding.

    Returns:
        The dataset dict.
    '''
    torch.set_default_dtype(torch.float64)

    os.makedirs(f'{save_dir}/dataset', exist_ok=True)
    path = f'{save_dir}/dataset/{name}-dataset.pkl'

    try:
        # NOTE(review): from_pickle presumably unpickles the file — only point
        # save_dir at trusted locations.
        data = from_pickle(path)
    except Exception:
        # Catch ordinary errors only (missing/corrupt file). A bare `except:`
        # would also swallow KeyboardInterrupt/SystemExit and start a rebuild.
        np.random.seed(100)  # fixed seed so a rebuilt dataset is reproducible
        print("Had a problem loading data from {}. Rebuilding dataset...".format(path))
        data = make_orbits_dataset(**kwargs)
        to_pickle(data, path)
    else:
        print("Successfully loaded data from {}".format(path))
    return data

