import sys
sys.path.append("../")
import os
import math
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d
from scipy.stats import halfnorm

from other import set_random_seed, make_folder

"""
Generate stochastic Lotka-Volterra trajectories from a drift-diffusion (SDE)
representation, integrated with the Euler-Maruyama scheme, and save them as .npy files.
"""

# The simulation loop was adapted from the Matplotlib Lorenz-attractor example:
# https://matplotlib.org/stable/gallery/mplot3d/lorenz_attractor.html

def calc_drift(x, y, a, b, c, d):
    """Return the deterministic (drift) part of the Lotka-Volterra system.

    The two components are
        a*x - b*x*y   (x grows at rate a and is reduced by the x*y interaction)
        c*x*y - d*y   (y grows through the x*y interaction and decays at rate d)

    Args:
        x, y: current state (scalars, or equally-shaped arrays).
        a, b, c, d: model coefficients.

    Returns:
        ndarray stacking the two drift components along a new leading axis.
    """
    first_rate = a * x - b * x * y
    second_rate = c * x * y - d * y
    return np.array((first_rate, second_rate))

def calc_diff(x, y):
    """Return the state-dependent noise amplitude for each component.

    The amplitude is a fixed symmetric 2x2 matrix applied to the state vector
    [x, y], so the noise on each component scales with both x and y.

    Args:
        x, y: current state.

    Returns:
        ndarray equal to M . [x, y] with M = [[0.25, -0.09], [-0.09, 0.25]].
    """
    mixing = np.array([[0.25, -0.09],
                       [-0.09, 0.25]])
    state = np.array([x, y])
    return np.dot(mixing, state)

def derivative(x, y, a, b, c, d, t, dt):
    """Return one Euler-Maruyama increment of the stochastic Lotka-Volterra system.

    The increment is drift * dt + amplitude * N(0, I) * sqrt(dt). The drift
    comes from calc_drift, the amplitude from calc_diff, and a fresh 2-vector
    of standard-normal noise is drawn from NumPy's global RNG on every call.

    Args:
        x, y: current state.
        a, b, c, d: drift coefficients forwarded to calc_drift.
        t: current time; accepted for interface symmetry but not used.
        dt: time step.

    Returns:
        ndarray with the change to add to [x, y] (not a time derivative).
    """
    drift_part = calc_drift(x, y, a, b, c, d) * dt
    amplitude = calc_diff(x, y)
    noise = np.random.randn(2)
    # Brownian increments scale with sqrt(dt), not dt.
    noise_part = amplitude * noise * np.sqrt(dt)
    return drift_part + noise_part

def simulation(init_conds, steps, dt, a, b, c, d):
    """Integrate the stochastic Lotka-Volterra system with Euler-Maruyama steps.

    Args:
        init_conds: initial state; its length sets the number of columns
            (the step function reads only columns 0 and 1).
        steps: number of rows in the returned trajectory, *including* the
            initial state, so ``steps - 1`` integration steps are taken.
        dt: time step.
        a, b, c, d: drift coefficients forwarded to derivative().

    Returns:
        float ndarray of shape (steps, len(init_conds)); row i is the state at
        time i * dt.

    Raises:
        IndexError: if steps is 0 (there is no row for the initial state).
    """
    traj = np.zeros([steps, len(init_conds)])
    traj[0] = init_conds

    # Row 0 is the initial state, so only steps - 1 increments are needed.
    # No increment (and no random draw) is computed past the last row.
    for i in range(steps - 1):
        increment = derivative(traj[i][0], traj[i][1], a, b, c, d, i * dt, dt)
        traj[i + 1] = traj[i] + increment
    return traj

def pipeline(folder, end='', init_conds=(4, 2), steps=10000, dt=1e-2, a=1.0, b=1.0, c=1.0, d=1.0):
    """Simulate one trajectory and save it as ``<folder>/x_<end>.npy``.

    Args:
        folder: output directory, passed to make_folder (presumably creates it
            if missing -- confirm in ``other``). A trailing path separator is
            optional.
        end: file-name suffix; converted with str().
        init_conds, steps, dt, a, b, c, d: forwarded to simulation().
    """
    x_train = simulation(init_conds, steps, dt, a, b, c, d)

    make_folder(folder)
    # os.path.join adds a separator only when one is missing. np.save appends
    # the ".npy" extension itself.
    np.save(os.path.join(folder, "x_" + str(end)), x_train)

def main():
    """Write the train and test trajectories, each generated under its own fixed seed."""
    out_dir = "../data/lotkavolterra/"
    # (seed, keyword overrides for pipeline). Any argument not listed keeps
    # pipeline's default. Runs execute in order, reseeding before each one.
    runs = [
        (1000, dict(end='train')),
        (65, dict(end='test', init_conds=(2.1, 1.0))),
    ]
    for seed, overrides in runs:
        set_random_seed(seed)
        pipeline(out_dir, **overrides)


if __name__ == "__main__":
    main()