import sys
sys.path.append("../")
import os
import math
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d

from other import set_random_seed, make_folder

# The simulation code here was adapted from:
# https://matplotlib.org/stable/gallery/mplot3d/lorenz_attractor.html

def derivative(x, y, z, s, r, b, t, scale):
    """Evaluate the Lorenz vector field with randomly sampled coefficients.

    ``sigma``, ``rho`` and ``beta`` are drawn from normal distributions with
    means ``s``, ``r``, ``b`` and standard deviation ``scale``. Three further
    coefficients (on the ``y``, ``x*z`` and ``x*y`` terms) are drawn with
    mean 1 and standard deviation 0, so they always equal 1.0. They are still
    drawn because each draw advances NumPy's global RNG; dropping them would
    shift every subsequent sample.

    Args:
        x, y, z: Current state components.
        s, r, b: Means of the sigma, rho and beta draws.
        t: Current time; not used.
        scale: Standard deviation of the sigma/rho/beta draws.

    Returns:
        ``(rates, coeffs)`` where ``rates`` is ``np.array([dx/dt, dy/dt, dz/dt])``
        and ``coeffs`` is the sampled ``np.array([sigma, rho, beta])``.
    """
    draw = np.random.normal
    # Draw order matters for reproducibility: sigma, rho, the three unit
    # coefficients, then beta.
    sigma = draw(s, scale)
    rho = draw(r, scale)
    y_coef, xz_coef, xy_coef = draw(1, 0), draw(1, 0), draw(1, 0)
    beta = draw(b, scale)

    dx_dt = sigma * (y - x)
    dy_dt = rho * x - y_coef * y - xz_coef * x * z
    dz_dt = xy_coef * x * y - beta * z

    rates = np.array([dx_dt, dy_dt, dz_dt])
    coeffs = np.array([sigma, rho, beta])
    return rates, coeffs

def simulation(init_conds, steps, dt, s, r, b, scale):
    """Integrate the randomly-perturbed Lorenz system with forward Euler.

    At every step ``derivative`` is evaluated at the current state (sampling
    fresh coefficients) and the returned rate advances the state by ``dt``.

    Args:
        init_conds: Initial state. Its length sets the width of the state
            arrays; ``derivative`` returns three components, so three values
            are expected.
        steps: Number of time points to produce, including the initial one.
            Must be at least 1.
        dt: Euler step size.
        s, r, b: Means of the sampled sigma, rho and beta coefficients.
        scale: Standard deviation used when sampling those coefficients.

    Returns:
        Tuple ``(x, dx, ts, params)``:
            x: ``(steps, len(init_conds))`` states, with ``x[0] == init_conds``.
            dx: ``(steps, len(init_conds))`` derivative evaluated at each state.
            ts: ``(steps,)`` times ``i * dt``.
            params: ``(steps, 3)`` sampled ``(s, r, b)`` used at each step.

    Raises:
        ValueError: If ``steps`` is less than 1 (there is no row for the
            initial state).
    """
    if steps < 1:
        raise ValueError(f"steps must be at least 1, got {steps}")

    # Row 0 holds the initial state, so ``steps`` rows hold the initial point
    # plus ``steps - 1`` integrated points; no extra row is needed.
    x = np.zeros([steps, len(init_conds)])
    dx = np.zeros([steps, len(init_conds)])
    ts = np.zeros([steps,])
    params = np.zeros([steps, 3])

    x[0] = init_conds

    # Step through "time": evaluate the derivative at the current point and
    # use it to estimate the next point.
    for i in range(steps):
        t = i * dt
        x_dot, curr_params = derivative(x[i][0], x[i][1], x[i][2], s, r, b, t, scale)
        dx[i] = x_dot
        ts[i] = t
        params[i] = curr_params
        # The derivative at the final point is recorded but not integrated:
        # there is no row left to hold the next state.
        if i + 1 < steps:
            x[i + 1] = x[i] + x_dot * dt
    return x, dx, ts, params

def pipeline(folder, init_conds=(0., 1., 1.05), steps=10000, dt=1e-2, s=10, r=28, b=8.0/3,
             scale=10.0, train=True, end='', max_tries=None):
    """Simulate one trajectory and save it as ``.npy`` files under ``folder``.

    Simulations whose states or derivatives contain NaN or inf (diverged
    trajectories) are discarded and re-run. Each re-run continues from the
    current global NumPy RNG state, so it samples different coefficients.

    Args:
        folder: Output directory, passed to ``make_folder`` (presumably creates
            it if missing -- confirm against ``other.make_folder``).
        init_conds, steps, dt, s, r, b, scale: Forwarded to ``simulation``.
        train: If True, save as ``x_train``, ``x_dot``, ``x_ts``, ``x_params``;
            otherwise as ``x_test``, ``x_dot_test``, ``x_ts_test``,
            ``x_params_test``.
        end: Suffix appended to every file name (e.g. a run index).
        max_tries: Optional cap on the number of simulations attempted.
            ``None`` (the default) retries until a finite trajectory is found.

    Raises:
        RuntimeError: If ``max_tries`` simulations all produced non-finite data.
    """
    def _is_finite(states, rates):
        # Both saved arrays must be finite: dx[-1] is evaluated at the final
        # state but never integrated, so it can overflow while x stays finite.
        return np.isfinite(states).all() and np.isfinite(rates).all()

    x_train, x_dot_train_measured, ts, params = simulation(init_conds, steps, dt, s, r, b, scale)
    count = 1
    while not _is_finite(x_train, x_dot_train_measured):
        if max_tries is not None and count >= max_tries:
            raise RuntimeError(
                f"no finite trajectory after {count} simulation attempt(s)"
            )
        print(count)
        x_train, x_dot_train_measured, ts, params = simulation(init_conds, steps, dt, s, r, b, scale)
        count += 1

    make_folder(folder)
    if train:
        names = ("x_train", "x_dot", "x_ts", "x_params")
    else:
        names = ("x_test", "x_dot_test", "x_ts_test", "x_params_test")
    arrays = (x_train, x_dot_train_measured, ts, params)
    for name, array in zip(names, arrays):
        # np.save appends the ".npy" extension.
        np.save(os.path.join(folder, name + end), array)


def main():
    """Generate the ``lorenz_rmse`` training datasets.

    For each coefficient noise scale (1.0, 5.0, 10.0), produce ten training
    trajectories with suffixes ``"0"`` to ``"9"``. Before each trajectory,
    ``set_random_seed(seed_base + i)`` is called (presumably seeding the global
    NumPy RNG -- confirm against ``other.set_random_seed``). The default
    initial condition ``(0., 1., 1.05)`` is then perturbed with one
    standard-normal draw per component.
    """
    base_init = (0., 1., 1.05)
    # (output folder, seed base, coefficient noise scale)
    configs = (
        ("../data/lorenz_rmse/scale-1.0", 2990, 1.0),
        ("../data/lorenz_rmse/scale-5.0", 3990, 5.0),
        ("../data/lorenz_rmse/scale-10.0", 4990, 10.0),
    )
    for folder, seed_base, scale in configs:
        for i in range(10):
            set_random_seed(seed_base + i)
            # Components are perturbed in order, one randn() draw each.
            x0 = tuple(c + np.random.randn() for c in base_init)
            pipeline(folder, init_conds=x0, scale=scale, end=str(i))


if __name__ == "__main__":
    main()