import numpy as np

from utils import solve_VdP, solve_lorenz_system, solve_rossler_system

# Module-wide generator with a fixed seed, so that every run of this script
# consumes the same sequence of pseudo-random numbers.
rng = np.random.default_rng(seed=1)


# Rescale a dataset to a fixed peak amplitude (applied below to both the
# Lorenz and the Rössler trajectories).
def normalize_dataset(dataset, max_value=3):
    """Rescale *dataset* so its largest absolute entry equals *max_value*.

    A single global factor is used for the whole array, so the relative
    proportions between all entries are preserved.

    Args:
        dataset: Array-like of numbers; converted with ``np.asarray``.
        max_value: Target value for the largest absolute entry after
            scaling.

    Returns:
        A ``(normalized_data, scaling_factor)`` tuple, where
        ``normalized_data`` equals ``dataset * scaling_factor``.  If every
        entry of *dataset* is zero there is nothing to rescale: the data is
        returned unscaled (as a new array) with a factor of ``1.0``.

    Raises:
        ValueError: Raised by ``np.max`` if *dataset* is empty.
    """
    data = np.asarray(dataset)
    peak = np.max(np.abs(data))
    if peak == 0:
        # Dividing by a zero peak would give an infinite factor and turn the
        # whole dataset into NaN (0 * inf); keep the zeros unchanged instead.
        return data * 1.0, 1.0
    scaling_factor = max_value / peak
    return data * scaling_factor, scaling_factor


def _reversed_axes(trajectories):
    """Return *trajectories* with its three axes in reverse order (2, 1, 0)."""
    return np.transpose(trajectories, axes=[2, 1, 0])


def _draw_planar_states(count=50):
    """Draw a (count, 2) array with every entry uniform in [-3, 3).

    Passed as the first argument to the Van der Pol solver (presumably the
    initial conditions -- confirm against ``utils.solve_VdP``).
    """
    return rng.random(size=(count, 2)) * 6 - 3


def _draw_spatial_states(z_extent, count=50):
    """Draw a (count, 3) array of random states for the 3-D systems.

    The first two columns are uniform in [-20, 20) and the third column is
    uniform in [0, z_extent).  The (count, 2) draw is made before the
    (count, 1) draw so that the shared generator is consumed in a fixed order.
    """
    xy = rng.random(size=(count, 2)) * 40 - 20
    z = rng.random(size=(count, 1)) * z_extent
    return np.concatenate([xy, z], axis=1)


if __name__ == '__main__':
    # NOTE(review): in the "normalized" files below, every split is rescaled
    # by its own maximum and the scaling factor is discarded, so train, valid
    # and test generally end up with different scale factors -- confirm that
    # this is intended.

    # --- Van der Pol (simple ODE) ---
    # Training uses a shorter time grid than validation/testing.
    vdp_short_grid = np.arange(0, 20.1, 0.1)
    vdp_train = solve_VdP(_draw_planar_states(), vdp_short_grid)
    vdp_long_grid = np.arange(0, 50.1, 0.1)
    vdp_valid = solve_VdP(_draw_planar_states(), vdp_long_grid)
    vdp_test = solve_VdP(_draw_planar_states(), vdp_long_grid)

    vdp_outputs = (
        ('VdP_train.npy', vdp_train),
        ('VdP_valid.npy', vdp_valid),
        ('VdP_test.npy', vdp_test),
    )
    for filename, trajectories in vdp_outputs:
        np.save(filename, _reversed_axes(trajectories))

    # --- Lorenz (chaotic system) ---
    lorenz_short_grid = np.arange(0, 5.01, 0.01)
    lorenz_train = solve_lorenz_system(_draw_spatial_states(50), lorenz_short_grid)
    lorenz_long_grid = np.arange(0, 50.01, 0.01)
    lorenz_valid = solve_lorenz_system(_draw_spatial_states(50), lorenz_long_grid)
    lorenz_test = solve_lorenz_system(_draw_spatial_states(50), lorenz_long_grid)

    lorenz_outputs = (
        ('Lorenz_train.npy', 'Lorenz_normalized_train.npy', _reversed_axes(lorenz_train)),
        ('Lorenz_valid.npy', 'Lorenz_normalized_valid.npy', _reversed_axes(lorenz_valid)),
        ('Lorenz_test.npy', 'Lorenz_normalized_test.npy', _reversed_axes(lorenz_test)),
    )
    # Raw files first, then the rescaled copies.
    for raw_name, _, flipped in lorenz_outputs:
        np.save(raw_name, flipped)
    for _, normalized_name, flipped in lorenz_outputs:
        np.save(normalized_name, normalize_dataset(flipped)[0])

    # --- Rössler (chaotic system) ---
    rossler_short_grid = np.arange(0, 10.01, 0.01)
    rossler_train = solve_rossler_system(_draw_spatial_states(40), rossler_short_grid)
    rossler_long_grid = np.arange(0, 200.01, 0.01)
    rossler_valid = solve_rossler_system(_draw_spatial_states(40), rossler_long_grid)
    rossler_test = solve_rossler_system(_draw_spatial_states(40), rossler_long_grid)

    rossler_outputs = (
        ('Rossler_train.npy', 'Rossler_normalized_train.npy', _reversed_axes(rossler_train)),
        ('Rossler_valid.npy', 'Rossler_normalized_valid.npy', _reversed_axes(rossler_valid)),
        ('Rossler_test.npy', 'Rossler_normalized_test.npy', _reversed_axes(rossler_test)),
    )
    for raw_name, _, flipped in rossler_outputs:
        np.save(raw_name, flipped)
    for _, normalized_name, flipped in rossler_outputs:
        np.save(normalized_name, normalize_dataset(flipped)[0])

    print('Datasets were successfully generated.')
