import sys
sys.path.append("../")
import os
import math
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d
from scipy.stats import halfnorm

from other import set_random_seed, make_folder

"""
Generates Lotka-Volterra data with drift-diffusion representation.
"""

# The structure of the simulation loop was adapted from:
# https://matplotlib.org/stable/gallery/mplot3d/lorenz_attractor.html

def calc_drift(x, y, a, b, c, d):
    """Return the deterministic Lotka-Volterra drift ``[dx/dt, dy/dt]``.

    Without interaction, ``x`` grows at rate ``a`` and ``y`` decays at rate
    ``d``. The ``b`` and ``c`` terms couple the two through ``x * y``.

    Args:
        x, y: current state components.
        a: linear growth coefficient of ``x``.
        b: interaction loss coefficient of ``x``.
        c: interaction gain coefficient of ``y``.
        d: linear decay coefficient of ``y``.

    Returns:
        ``np.array([dx/dt, dy/dt])``.
    """
    # Keep the exact operand order of the formulas so float results are unchanged.
    rate_x = a * x - b * x * y
    rate_y = c * x * y - d * y
    drift = np.array([rate_x, rate_y])
    return drift

def calc_diff(x, y):
    """Return the diagonal diffusion terms ``[g11, g22]`` at state ``(x, y)``.

    The polynomial coefficients are fixed:

        g11 = 0.066 x^2 - 0.064 x y + 0.019 x y^2
        g22 = -0.037 y + 0.071 y^2

    Returns:
        ``np.array([g11, g22])``.
    """
    # Each monomial is computed separately and summed in the original
    # left-to-right order, so float rounding matches exactly.
    xx_term = 0.066 * x ** 2
    xy_term = -0.064 * x * y
    xyy_term = 0.019 * x * (y ** 2)
    y_term = -0.037 * y
    yy_term = 0.071 * (y ** 2)

    g11 = xx_term + xy_term + xyy_term
    g22 = y_term + yy_term
    return np.array([g11, g22])

def derivative(x, y, a, b, c, d, t, dt):
    """Return one Euler-Maruyama state increment over a step of length ``dt``.

    The increment is ``drift * dt + diff * N(0, 1) * sqrt(dt)``. It uses two
    independent standard-normal draws from the global ``np.random`` state.
    Despite the name, this is an increment, not a time derivative.

    Args:
        x, y: current state components.
        a, b, c, d: drift coefficients forwarded to ``calc_drift``.
        t: current time; accepted for interface compatibility but unused.
        dt: step size.

    Returns:
        ``np.array`` of shape ``(2,)`` holding ``[dx, dy]``.

    NOTE(review): ``calc_diff``'s output multiplies the noise directly, i.e. it
    is treated as the noise amplitude rather than a variance (no square root
    is taken). Confirm this matches how the G coefficients were estimated.
    """
    deterministic_part = calc_drift(x, y, a, b, c, d) * dt
    amplitude = calc_diff(x, y)
    dW = np.random.randn(2)
    stochastic_part = amplitude * dW * np.sqrt(dt)
    return deterministic_part + stochastic_part

def simulation(init_conds, steps, dt, a, b, c, d):
    """Integrate the stochastic Lotka-Volterra system with Euler-Maruyama steps.

    Args:
        init_conds: initial state ``(x0, y0)``, stored as row 0 of the result.
            ``derivative`` returns two components, so this presumably must have
            length 2.
        steps: total number of rows in the returned trajectory, *including*
            the initial state. ``steps - 1`` increments are taken. Must be >= 1.
        dt: time step.
        a, b, c, d: drift coefficients forwarded to ``derivative``.

    Returns:
        ``np.ndarray`` of shape ``(steps, len(init_conds))``. Row ``i`` is the
        state at time ``i * dt``.
    """
    # Exactly `steps` rows: row 0 holds the initial state, rows 1..steps-1
    # are filled by the loop below.
    x = np.zeros([steps, len(init_conds)])
    x[0] = init_conds

    # Only steps - 1 increments are needed. Not evaluating a final, discarded
    # increment avoids a wasted drift/diffusion evaluation and a wasted
    # random draw that would otherwise advance the global RNG state.
    for i in range(steps - 1):
        step = derivative(x[i][0], x[i][1], a, b, c, d, i * dt, dt)
        x[i + 1] = x[i] + step
    return x

def pipeline(folder, init_conds=(2.1, 1.0), steps=10000, dt=1e-2,
             a=0.937, b=0.995, c=1.090, d=1.114,
             plot_path="../results/gen_ss.png"):
    """Simulate one trajectory, save it to ``folder``, and save a plot of it.

    The trajectory is written with ``np.save`` to ``<folder>/x_train.npy``.
    A line plot of both state components is written to ``plot_path``.

    Args:
        folder: output directory for ``x_train.npy``; created via
            ``make_folder``.
        init_conds: initial state ``(x0, y0)``.
        steps: number of rows in the trajectory, including the initial state.
        dt: time step.
        a, b, c, d: Lotka-Volterra drift coefficients.
        plot_path: where the plot image is saved. Its parent directory is
            created if missing. Defaults to the previously hard-coded path.

    Reference coefficient notes (value ± uncertainty), kept with this
    generator. The F1/F2 values equal the default ``a``-``d`` arguments.

        F1   (0.937 ± 0.361)x + (-0.995 ± 0.228)xy
        F2   (-1.114 ± 0.107)y + (1.090 ± 0.055)xy
        G11  (0.050 ± 0.016)x^2
        G22  (0.057 ± 0.013)y^2
        G12  0

    Second diffusion set (these match the coefficients in ``calc_diff``):

        G11  (0.066 ± 0.016)x^2 + (-0.064 ± 0.015)xy + (0.019 ± 0.003)xy^2
        G22  (-0.037 ± 0.010)y + (0.071 ± 0.003)y^2
    """
    x_train = simulation(init_conds, steps, dt, a, b, c, d)

    make_folder(folder)
    # os.path.join handles a missing trailing separator (and an empty folder)
    # without manual string surgery. np.save appends the ".npy" extension.
    np.save(os.path.join(folder, "x_train"), x_train)

    # Use a dedicated figure so repeated calls do not draw onto a stale
    # current figure, and always release it afterwards.
    fig, ax = plt.subplots()
    try:
        ax.plot(x_train)
        plot_dir = os.path.dirname(plot_path)
        if plot_dir:
            os.makedirs(plot_dir, exist_ok=True)
        fig.savefig(plot_path)
    finally:
        plt.close(fig)
    

def main():
    """Seed the RNG via the project's ``set_random_seed`` and build the dataset.

    The generated data is written under ``../data/gen_ss/``.
    """
    seed = 123
    output_dir = "../data/gen_ss/"
    set_random_seed(seed)
    pipeline(output_dir)


# Run the generator only when this file is executed as a script.
if __name__ == "__main__":
    main()