import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import seaborn as sns
import random

import pysindy as ps

def fourth_order_diff(x, dt):
    """Estimate the time derivative of uniformly sampled data along axis 0.

    Interior samples use the five-point central difference stencil. The two
    samples at each end, where that stencil does not fit, use four-point
    one-sided stencils instead.

    Parameters
    ----------
    x : array_like
        Samples with time along the first axis. Any number of trailing axes
        is accepted (a plain 1-D signal works as well). At least five samples
        are required; shorter input raises ``IndexError``.
    dt : float
        Spacing between consecutive samples.

    Returns
    -------
    numpy.ndarray
        float64 array with the same shape as ``x`` holding the derivative
        estimates.
    """
    x = np.asarray(x)
    # Allocate from the full shape so 1-D and N-D inputs work, not only 2-D.
    dx = np.zeros(x.shape)

    # Forward one-sided stencils for the first two samples.
    dx[0] = (-11.0 / 6) * x[0] + 3 * x[1] - 1.5 * x[2] + x[3] / 3
    dx[1] = (-11.0 / 6) * x[1] + 3 * x[2] - 1.5 * x[3] + x[4] / 3

    # Five-point central stencil for every interior sample.
    dx[2:-2] = (-1.0 / 12) * x[4:] + (2.0 / 3) * x[3:-1] - (2.0 / 3) * x[1:-3] + (1.0 / 12) * x[:-4]

    # Backward one-sided stencils for the last two samples.
    dx[-2] = (11.0 / 6) * x[-2] - 3.0 * x[-3] + 1.5 * x[-4] - x[-5] / 3.0
    dx[-1] = (11.0 / 6) * x[-1] - 3.0 * x[-2] + 1.5 * x[-3] - x[-4] / 3.0
    return dx / dt

def sample_trajectory(x0, coefs, library, timesteps, dt, batch_size):
    """Roll out trajectories, drawing one ensemble member per trajectory per step.

    Each trajectory starts at ``x0`` and is advanced with explicit Euler
    steps. At every step each trajectory independently picks one coefficient
    matrix from ``coefs`` uniformly at random (via ``np.random``).

    Parameters
    ----------
    x0 : array_like
        Initial state, shape ``(n_targets,)``.
    coefs : array_like
        Ensemble coefficients, shape ``(n_models, n_targets, n_features)``.
    library : object
        Feature library whose ``transform`` maps a ``(batch_size, n_targets)``
        array to ``(batch_size, n_features)`` features.
    timesteps : int
        Number of Euler steps to take.
    dt : float
        Step size.
    batch_size : int
        Number of trajectories to generate.

    Returns
    -------
    numpy.ndarray
        Shape ``(batch_size, timesteps, n_targets)``. The first entry along
        the time axis is the state after one step; ``x0`` is not included.
    """
    # (n_models, n_targets, n_features) -> (n_models, n_features, n_targets)
    # so a row vector of features can be right-multiplied by it.
    coefs = np.transpose(coefs, (0, 2, 1))
    xs = []
    curr = np.array([x0] * batch_size)
    for _ in range(timesteps):
        # One (1, n_features) row per trajectory. The sizes come from
        # batch_size and the library output rather than fixed constants.
        curr_lib = library.transform(curr).reshape(batch_size, 1, -1)
        coef_idx = np.random.randint(0, len(coefs), batch_size)
        dx = np.matmul(curr_lib, coefs[coef_idx]).squeeze(1)
        curr = curr + dx * dt
        xs.append(curr)
    # (timesteps, batch_size, n_targets) -> (batch_size, timesteps, n_targets)
    return np.transpose(np.array(xs), (1, 0, 2))

def sample_trajectory2(x0, coefs, library, timesteps, dt, batch_size):
    """Roll out trajectories with coefficients drawn from a Gaussian fit.

    The ensemble is summarised by the element-wise mean and standard
    deviation of ``coefs``. At every explicit Euler step each trajectory
    uses a fresh coefficient matrix ``mean + std * N(0, 1)`` (noise drawn
    via ``np.random``).

    Parameters
    ----------
    x0 : array_like
        Initial state, shape ``(n_targets,)``.
    coefs : array_like
        Ensemble coefficients, shape ``(n_models, n_targets, n_features)``.
    library : object
        Feature library whose ``transform`` maps a ``(batch_size, n_targets)``
        array to ``(batch_size, n_features)`` features.
    timesteps : int
        Number of Euler steps to take.
    dt : float
        Step size.
    batch_size : int
        Number of trajectories to generate.

    Returns
    -------
    numpy.ndarray
        Shape ``(batch_size, timesteps, n_targets)``. The first entry along
        the time axis is the state after one step; ``x0`` is not included.
    """
    # (n_models, n_targets, n_features) -> (n_models, n_features, n_targets)
    coefs = np.transpose(coefs, (0, 2, 1))
    # Both statistics have shape (n_features, n_targets) and broadcast over
    # the batch axis below, so no per-trajectory copies are needed.
    coefs_mean, coefs_std = coefs.mean(0), coefs.std(0)
    noise_shape = (batch_size,) + coefs.shape[1:]
    xs = []
    curr = np.array([x0] * batch_size)
    for _ in range(timesteps):
        # One (1, n_features) row per trajectory. The sizes come from
        # batch_size and the library output rather than fixed constants.
        curr_lib = library.transform(curr).reshape(batch_size, 1, -1)
        noise = np.random.normal(0, 1, noise_shape)
        curr_coefs = coefs_mean + coefs_std * noise
        dx = np.matmul(curr_lib, curr_coefs).squeeze(1)
        curr = curr + dx * dt
        xs.append(curr)
    # (timesteps, batch_size, n_targets) -> (batch_size, timesteps, n_targets)
    return np.transpose(np.array(xs), (1, 0, 2))

def plot_samples(xs, samples, num_samples=4, dpi=300, figsize=None, filename=None):
    """Draw a reference trajectory and sampled trajectories as a row of 3-D panels.

    The first panel shows ``xs`` in red. Panels ``1 .. num_samples - 1`` show
    ``samples[1] .. samples[num_samples - 1]`` in blue. All panels are drawn
    without grid, panes, axis lines or ticks.

    Parameters
    ----------
    xs : array
        Reference trajectory; columns 0, 1 and 2 are plotted as x, y and z.
    samples : sequence of arrays
        Sampled trajectories, each indexed like ``xs``. Must contain at least
        ``num_samples`` entries when ``num_samples > 1``.
    num_samples : int
        Total number of panels, including the reference panel.
    dpi : int
        Figure resolution.
    figsize : tuple or None
        Figure size passed to ``plt.figure``; ``None`` uses the default.
    filename : str or None
        If given, the figure is saved there and not shown. Otherwise the
        figure is shown.
    """
    sns.set()

    # Styling follows
    # https://dawes.wordpress.com/2014/06/27/publication-ready-3d-figures-from-matplotlib/
    fig = plt.figure(figsize=figsize, dpi=dpi)

    # Fully transparent white, used to hide the panes and axis lines.
    transparent = (1.0, 1.0, 1.0, 0.0)
    for i in range(num_samples):
        ax = fig.add_subplot(1, num_samples, i + 1, projection='3d')
        if i == 0:
            ax.plot(xs[:, 0], xs[:, 1], xs[:, 2], color='red')
        else:
            # NOTE(review): panel i shows samples[i], so samples[0] is never
            # drawn. Left as is to keep existing figures unchanged; confirm
            # whether samples[i - 1] was intended.
            ax.plot(samples[i][:, 0], samples[i][:, 1], samples[i][:, 2], color='blue')

        ax.grid(False)
        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
            axis.set_pane_color(transparent)
            axis.line.set_color(transparent)

        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_zticks([])

    fig.subplots_adjust(wspace=0)

    # Save or show, never both: showing after the figure has been saved and
    # closed would have nothing to display.
    if filename is not None:
        fig.savefig(filename)
    else:
        plt.show()
    plt.close(fig)

# Ensemble-SINDy experiments on the Lorenz data sets. For each configuration:
# fit an ensemble of SINDy models, print the spread of the ensemble
# coefficients, then plot trajectories sampled from the ensemble next to the
# test trajectory.
SEED = 862023
dt = 0.01
feature_names = ['x', 'y', 'z']

# Each entry is (scale, SR3 threshold). The scale selects the data directory
# ("scale-1.0", "scale-5.0", "scale-10.0") and the output file suffix
# ("1", "5", "10").
for scale, threshold in ((1.0, 0.5), (5.0, 0.5), (10.0, 0.9)):
    np.random.seed(SEED)
    random.seed(SEED)
    data_dir = '../data/lorenz/scale-{:.1f}'.format(scale)
    x_train = np.load(data_dir + '/x_train.npy')
    x_dot = fourth_order_diff(x_train, dt)
    x_test = np.load(data_dir + '/x_test_0.npy')
    x0 = x_test[0]

    # Instantiate and fit the SINDy model
    library = ps.PolynomialLibrary(degree=2, include_bias=False)
    optimizer = ps.SR3(
        threshold=threshold, thresholder="l0", max_iter=1000, normalize_columns=False, tol=1e-1
    )
    model = ps.SINDy(feature_names=feature_names, feature_library=library, optimizer=optimizer)
    model.fit(x_train, x_dot=x_dot, t=dt, ensemble=True, quiet=True, n_models=500)
    ensemble_coefs = np.array(model.coef_list)
    model.print()
    # NOTE(review): this replaces the list of ensemble coefficients on the
    # model with their mean, but nothing below reads it back -- confirm intent.
    model.coef_list = np.mean(ensemble_coefs, 0)
    the_std = ensemble_coefs.std(0)
    print(the_std)
    print(np.max(the_std))

    # Re-seed so the sampled trajectories start from a fixed random state
    # regardless of how much randomness the fit consumed.
    np.random.seed(SEED)
    random.seed(SEED)
    samples = sample_trajectory(x0, ensemble_coefs, library, 10000, dt, 10)
    plot_samples(
        x_test, samples, 5, 300, (20, 20),
        "../results/esindy_lorenz_gen_{:g}".format(scale),
    )