# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats
import seaborn as sns
import pandas as pd
from sklearn.decomposition import PCA

import tqdm

import plotly.graph_objects as go
from sklearn.model_selection import train_test_split
import argparse
from mpl_toolkits.axes_grid1 import make_axes_locatable

from sklearn.neighbors import KernelDensity
from scipy.stats import gaussian_kde
import os

import sympy as sym
from sympy.utilities.lambdify import lambdify
from scipy.stats import norm

# Floating-point dtype that simulation outputs and histogram meshes are cast to.
PRECISION = np.dtype("float32").type


class Model:
    """
    Euler–Maruyama particle simulation of a time-driven SDE.

    Every particle coordinate ``y_i`` evolves as

        dy_i = f_i(t) dt + noise_sigma dW_i,

    where the drift ``f_i`` depends on time only and the diffusion coefficient
    is constant. Initial positions are drawn i.i.d. from
    ``N(0, initial_noise_scale**2)``; time runs from ``t0`` to ``tn``.

    ``mean_numpy(T)`` / ``cov_numpy(T)`` are the closed-form mean and variance
    of one coordinate at time ``T`` for the drift ``sin(sin_scale * t)``.
    NOTE: they do not include ``drift_amplitude`` or the per-coordinate phase
    used when ``varying_frequency`` is set.
    """

    noise_sigma = 0.06
    sin_scale = 10
    drift_amplitude = 1.
    initial_noise_scale = 2e-2
    t0 = -1
    tn = 1

    t, T = sym.symbols('t T')

    # E[y(T)] = integral_{t0}^{T} sin(sin_scale * t) dt   (initial mean is 0).
    mean_sympy = sym.integrate(sym.sin(sin_scale * t), (t, t0, T))
    # staticmethod: a plain function stored on the class would otherwise be
    # bound to ``self`` when accessed through an instance.
    mean_numpy = staticmethod(lambdify(T, mean_sympy, 'numpy'))

    # Var[y(T)] = noise_sigma**2 * (T - t0) + initial_noise_scale**2.
    cov_sympy = sym.integrate(noise_sigma ** 2, (t, t0, T)) + initial_noise_scale ** 2
    cov_numpy = staticmethod(lambdify(T, cov_sympy, 'numpy'))

    def __init__(self, n_particles, dim, varying_frequency=True, n_tsteps=2000, add_each=2, **kwargs):
        """
        Draw the initial particle cloud and build the per-coordinate drifts.

        Args:
            n_particles: number of independent particles simulated together.
            dim: number of coordinates per particle.
            varying_frequency: if true, coordinate ``i`` gets a phase shift of
                ``i / dim * sin_scale * pi`` in its drift; otherwise every
                coordinate uses ``drift_amplitude * sin(sin_scale * t)``.
            n_tsteps: number of Euler–Maruyama steps between ``t0`` and ``tn``.
                A ``n_timesteps`` keyword takes precedence when given.
            add_each: keep one snapshot every ``add_each`` steps.
            **kwargs: any other keys are ignored, so a whole config dict can be
                passed in.
        """
        # The data-generation config spells the step count ``n_timesteps``;
        # without this it vanished into **kwargs and the default was used.
        n_tsteps = kwargs.pop("n_timesteps", n_tsteps)
        self.n_timesteps = n_tsteps

        self.DT = float(self.tn - self.t0) / self.n_timesteps

        self.dim = dim
        self.n_particles = n_particles

        y_ic = np.random.normal(0, Model.initial_noise_scale, (self.n_particles, self.dim))
        self.ys = [y_ic]
        self.ts = [self.t0]

        self.prev_y = self.ys[0]

        self.add_each = add_each
        self.varying_frequency = varying_frequency

        if self.varying_frequency:
            # ``i=i`` binds the coordinate index when each lambda is created; a
            # plain closure late-binds and gives every coordinate the phase of
            # the last index.
            self.funs = [lambda _t, i=i:
                         Model.drift_amplitude * np.sin(
                             (i % self.dim) / self.dim * Model.sin_scale * np.pi
                             + (_t * Model.sin_scale))
                         for i in range(self.dim)]
        else:
            self.funs = [lambda _t:
                         Model.drift_amplitude * np.sin(_t * Model.sin_scale)
                         for _ in range(self.dim)]

    def drift(self, y: np.ndarray, _t: float) -> np.ndarray:
        """
        Return the drift vector at time ``_t`` as a float array of shape ``(dim,)``.

        The drift depends on time only, so ``y`` is unused. ``run_simulation``
        broadcasts the returned vector over all particles.
        """
        return np.array([f(_t) for f in self.funs], dtype=float)

    def sigma(self, _y: np.ndarray, _t: float) -> float:
        """
        Return the constant diffusion coefficient ``noise_sigma``.
        """
        return Model.noise_sigma

    def dW(self) -> np.ndarray:
        """
        Sample Brownian increments ``N(0, DT)`` of shape ``(n_particles, dim)``.
        """
        return np.random.normal(loc=0.0, scale=np.sqrt(self.DT), size=(self.n_particles, self.dim))

    def run_simulation(self):
        """
        Integrate all particles from ``t0`` to ``tn`` and return the snapshots.

        Each Euler–Maruyama step is ``y_{k+1} = y_k + f(t_k) DT + sigma dW_k``.
        The initial state and every ``add_each``-th state are recorded, each
        paired with the time at which that state holds.

        NOTE: snapshots accumulate on the instance. Calling this again appends
        a second pass that starts from the last state.

        Returns:
            ``(ts, ys)`` cast to ``PRECISION``, with shapes ``(n_snapshots,)``
            and ``(n_snapshots, n_particles, dim)``.
        """
        for i in tqdm.tqdm(range(1, self.n_timesteps + 1)):
            t = self.t0 + (i - 1) * self.DT
            y = self.prev_y
            self.prev_y = y + self.drift(y, t) * self.DT + self.sigma(y, t) * self.dW()

            if i % self.add_each == 0:
                # ``prev_y`` is now the state after step i, i.e. at t0 + i*DT.
                # Recording ``t`` would lag the state by one step.
                self.ys.append(self.prev_y)
                self.ts.append(self.t0 + i * self.DT)

        return np.stack(self.ts, 0).astype(PRECISION), np.stack(self.ys, 0).astype(PRECISION)


def _gen_kde_dataset_2d(model, root_path, ts, ts_full, ys, y_min, y_max, n_function_samples):
    """
    Fit one 2-D Gaussian KDE per snapshot, plot it, and sample it at uniform points.

    Returns a DataFrame with columns ``time``, ``x``, ``y``, ``density``; for
    each snapshot ``n_function_samples`` points ``(x, y)`` are drawn uniformly
    from ``[y_min, y_max)**2``.
    """
    dim = 2
    # Rule-of-thumb bandwidth 1.06 * sigma * n^(-1/5), with sigma taken as the
    # *initial* particle spread for every snapshot.
    bandwidth = 1.06 * Model.initial_noise_scale * model.n_particles ** (-1 / 5)
    kdes = [KernelDensity(kernel="gaussian", bandwidth=bandwidth).fit(ys[i, ...])
            for i in tqdm.tqdm(range(ys.shape[0]))]

    grid_dim = 100
    xx = np.linspace(y_min, y_max, grid_dim)
    yy = np.linspace(y_min, y_max, grid_dim)
    # "ij" indexing: grid axis 1 follows x (dim_1) and axis 2 follows y (dim_2),
    # which is the axis order plot_nd labels its panels with.
    XY_grid = np.stack(np.meshgrid(xx, yy, indexing="ij"), -1)  # (grid_dim, grid_dim, 2)
    XY = XY_grid.reshape(-1, dim)
    n_t = ts.shape[0]
    t_grid = np.broadcast_to(ts.reshape(-1, 1, 1, 1), (n_t, grid_dim, grid_dim, 1))
    xy_grid = np.broadcast_to(XY_grid[np.newaxis], (n_t, grid_dim, grid_dim, dim))
    XYT_grid = np.concatenate([t_grid, xy_grid], axis=-1)  # last axis: (t, x, y)

    preds = np.stack([np.exp(kde.score_samples(XY)) for kde in tqdm.tqdm(kdes)], 0)
    preds = preds.reshape((-1, grid_dim, grid_dim))

    plot_2d(XYT_grid, preds, model, ts_full, ys, root_path)
    plot_nd(dim=dim, centroid_mesh_txy=XYT_grid, mass_on_grid=preds, model=model, ts_full=ts_full, ys=ys,
            root_path=root_path)

    chunks = []
    for t, kde in zip(ts, kdes):
        X = np.random.uniform(y_min, y_max, n_function_samples * dim).reshape(-1, dim)
        T = np.full((X.shape[0], 1), t)
        pred_density = np.exp(kde.score_samples(X)).reshape(-1, 1)
        chunks.append(np.concatenate([T, X, pred_density], -1))
    return pd.DataFrame(np.concatenate(chunks, 0), columns=["time", "x", "y", "density"])


def _select_grid_cells(mass_on_grid, proportion_nonzero):
    """
    Pick every occupied grid cell plus a random subset of empty cells.

    Empty cells are subsampled so that occupied cells make up about
    ``proportion_nonzero`` of the selection (or more, if there are not enough
    empty cells). Returns a tuple of index arrays for fancy indexing.

    Raises:
        ValueError: if ``proportion_nonzero`` is not in ``(0, 1]``.
    """
    if not 0 < proportion_nonzero <= 1:
        raise ValueError(f"proportion_nonzero must be in (0, 1], got {proportion_nonzero}")
    idx_gt_0 = np.where(mass_on_grid > 0.)
    idx_0 = np.where(mass_on_grid == 0.)
    n_empty_available = idx_0[0].shape[0]
    n_empty_wanted = int(idx_gt_0[0].shape[0] / proportion_nonzero * (1 - proportion_nonzero))
    # Sampling without replacement cannot draw more cells than exist.
    picked = np.random.choice(n_empty_available, min(n_empty_wanted, n_empty_available), replace=False)
    return tuple(np.concatenate([nz, z[picked]]) for nz, z in zip(idx_gt_0, idx_0))


def generate_data(config: dict) -> pd.DataFrame:
    """
    Simulate particles and turn their empirical distribution into a data table.

    The global NumPy RNG is seeded with 1234. Depending on ``config``:

    * ``dim == 1`` and ``config["kde"]``: delegate to ``gen_kde_dataset_1d``.
    * ``dim == 2`` and ``config["kde"]``: per-snapshot 2-D KDE evaluated at
      uniformly sampled points (plots are always written).
    * otherwise: histogram the particles on a regular grid per snapshot and
      export bin centres with their densities, split into disjoint
      train / val / test subsets.

    Args:
        config: settings dict. Keys read here are ``root_path``, ``dim``,
            ``kde`` and, on the histogram path, ``proportion_nonzero`` and
            ``create_plots``. The whole dict is also forwarded to ``Model``.

    Returns:
        DataFrame with a ``time`` column, coordinate columns and ``density``.
        The histogram path adds a categorical ``subset`` column.

    Raises:
        ValueError: if ``proportion_nonzero`` is not in ``(0, 1]``.
    """
    np.random.seed(1234)
    root_path = f'{config["root_path"]}/data_images'

    dim = config["dim"]
    model = Model(**config)

    ts, ys = model.run_simulation()
    # One time stamp per (snapshot, particle), in ys' row-major order.
    ts_full = np.repeat(ts[..., np.newaxis], model.n_particles)

    # Spatial domain used for the grids, KDE evaluation and uniform sampling.
    y_min, y_max = -.7, .7

    n_function_samples = 5_000
    if dim == 1 and config["kde"]:
        return gen_kde_dataset_1d(config, model, model.n_particles, root_path, ts, y_max, y_min, ys, n_function_samples)
    if dim == 2 and config["kde"]:
        return _gen_kde_dataset_2d(model, root_path, ts, ts_full, ys, y_min, y_max, n_function_samples)

    n_bins_per_dim = 30
    bin_edges = np.linspace(y_min, y_max, n_bins_per_dim + 1).astype(PRECISION)
    bin_width = (y_max - y_min) / n_bins_per_dim
    bin_volume = bin_width ** dim

    # Bin centres; every axis uses the same edges.
    edge_means = [0.5 * (bin_edges[:-1] + bin_edges[1:])] * dim
    centroid_mesh_txy = create_domain_mesh(dim, edge_means, n_bins_per_dim, ts)

    counts = np.stack([np.histogramdd(ys[i, ...], bins=[bin_edges] * dim)[0]
                       for i in tqdm.tqdm(range(ys.shape[0]))], 0).astype(PRECISION)
    # Normalise each snapshot separately. Each time slice is then a probability
    # mass over the grid, and density_on_grid integrates to 1 per time step,
    # like the KDE datasets and the analytical pdf in plot_1d. Particles outside
    # [y_min, y_max] are not counted; a slice with no particles stays all-zero.
    slice_totals = counts.sum(axis=tuple(range(1, counts.ndim)), keepdims=True)
    mass_on_grid = np.divide(counts, slice_totals, out=np.zeros_like(counts), where=slice_totals > 0)
    density_on_grid = mass_on_grid / bin_volume

    idx_combined = _select_grid_cells(mass_on_grid, config["proportion_nonzero"])

    # 10% train. The remaining 90% is split again into val (10% of it) and
    # test (90% of it). Splitting the held-out part, not the full range, keeps
    # the three subsets disjoint.
    n_cells = len(idx_combined[0])
    idx_train_, idx_heldout_ = train_test_split(np.arange(n_cells), test_size=0.9, random_state=42)
    idx_val_, idx_test_ = train_test_split(idx_heldout_, test_size=0.9, random_state=43)

    subsets = ["train", "val", "test"]
    columns = ["time"] + [f"x_{i}" for i in range(1, dim + 1)]
    dfs = []
    for idx_idx, subset in zip([idx_train_, idx_val_, idx_test_], subsets):
        # Indices into the full (time, bin, ...) mesh for this subset.
        cur_data_idx = tuple(idx_[idx_idx] for idx_ in idx_combined)
        dfs.append(pd.DataFrame(centroid_mesh_txy[cur_data_idx], columns=columns)
                   .assign(density=density_on_grid[cur_data_idx],
                           subset=pd.Categorical([subset] * len(idx_idx), subsets)))

    data_df = pd.concat(dfs)

    print("plotting..")
    if config["create_plots"]:
        if dim == 1:
            # plot_1d expects densities of shape (n_t, n_x) and a (n_x, n_t, 2)
            # (t, x) mesh, so swap the time and bin axes of the centroid mesh.
            plot_1d(model, root_path, density_on_grid, np.swapaxes(centroid_mesh_txy, 0, 1))
        if dim == 2:
            plot_2d(centroid_mesh_txy, mass_on_grid, model, ts_full, ys,
                    root_path)
        if dim > 1:
            plot_nd(dim, model, ts_full, ys,
                    root_path,
                    mass_on_grid, centroid_mesh_txy)
    return data_df


def gen_kde_dataset_1d(config, model, n_particles, root_path, ts, y_max, y_min, ys, n_function_samples):
    """
    Build a 1-D density dataset from one Gaussian KDE per snapshot.

    Fits a KDE to each snapshot ``ys[i]`` and saves a comparison plot via
    ``plot_1d`` on a 300-point grid over ``[y_min, y_max]``. Each KDE is then
    evaluated at ``n_function_samples`` uniformly drawn positions.

    Args:
        config: unused (kept for call compatibility).
        model: simulation model, forwarded to ``plot_1d``.
        n_particles: particle count, used in the bandwidth rule of thumb.
        root_path: directory the plot is written to.
        ts: snapshot times, shape ``(n_t,)``.
        y_max, y_min: spatial domain bounds.
        ys: particle states; ``ys[i]`` must be a 2-D ``(n_particles, 1)``
            sample (as produced by ``Model`` with ``dim == 1``).
        n_function_samples: uniform evaluation points per snapshot.

    Returns:
        DataFrame with columns ``time``, ``x``, ``density`` and
        ``n_t * n_function_samples`` rows.
    """
    xx = np.linspace(y_min, y_max, 300)

    # Rule-of-thumb bandwidth 1.06 * sigma * n^(-1/5) (Silverman's form), with
    # sigma taken as the *initial* particle spread for every snapshot.
    bandwidth = 1.06 * Model.initial_noise_scale * n_particles ** (-1 / 5)
    kdes = [KernelDensity(kernel="gaussian", bandwidth=bandwidth).fit(ys[i, ...])
            for i in tqdm.tqdm(range(ys.shape[0]))]
    preds = np.stack([np.exp(kde.score_samples(xx.reshape(-1, 1))) for kde in tqdm.tqdm(kdes)], 0)

    print("plotting..")
    # meshgrid(ts, xx) yields the (n_x, n_t, 2) (t, x) mesh layout plot_1d reads.
    plot_1d(model, root_path, preds, np.stack(np.meshgrid(ts, xx), -1))
    plt.close()

    print("creating data..")
    chunks = []
    for t, kde in tqdm.tqdm(zip(ts, kdes), total=ts.shape[0]):
        X = np.random.uniform(y_min, y_max, n_function_samples).reshape(-1, 1)
        pred_density = np.exp(kde.score_samples(X))
        # column_stack keeps a (n, 3) layout even when n == 1, which squeeze did not.
        chunks.append(np.column_stack([np.full(X.shape[0], t), X[:, 0], pred_density]))
    data = np.concatenate(chunks, 0)
    return pd.DataFrame(data, columns=["time", "x", "density"])


def create_domain_mesh(dim, edge_means, n_bins_per_dim, ts):
    """
    Build the (time, space) mesh of bin centres for every snapshot.

    Args:
        dim: number of spatial axes (equal to ``len(edge_means)``).
        edge_means: one 1-D array of bin centres per spatial axis, in the same
            axis order as the histogram the mesh is paired with.
        n_bins_per_dim: kept for interface compatibility; the per-axis lengths
            are taken from ``edge_means``.
        ts: 1-D array of snapshot times.

    Returns:
        ``PRECISION`` array of shape ``(len(ts), len(edge_means[0]), ...,
        dim + 1)`` with ``out[k, i_1, ..., i_dim] ==
        (ts[k], edge_means[0][i_1], ..., edge_means[dim - 1][i_dim])``.
    """
    # "ij" indexing keeps mesh axis d aligned with histogram axis d. The default
    # "xy" swaps the first two axes and mislabels x_1/x_2 for dim >= 2.
    centroid_mesh_xy = np.stack(np.meshgrid(*edge_means, indexing="ij"), -1).astype(PRECISION)
    spatial_shape = centroid_mesh_xy.shape[:-1]
    n_t = len(ts)
    centroid_mesh_xy = np.broadcast_to(centroid_mesh_xy[np.newaxis], (n_t,) + centroid_mesh_xy.shape)
    t_col = np.asarray(ts, dtype=PRECISION).reshape((-1,) + (1,) * (len(spatial_shape) + 1))
    t_mesh = np.broadcast_to(t_col, (n_t,) + spatial_shape + (1,))
    return np.concatenate([t_mesh, centroid_mesh_xy], -1)


def add_colorbar(im, ax, fig):
    """Attach a vertical colorbar for ``im`` in a slim axis appended right of ``ax``; return it."""
    cbar_ax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)
    colorbar = fig.colorbar(im, cax=cbar_ax, orientation='vertical')
    return colorbar


def plot_1d(model, root_path, density_grid, txy_grid):
    """
    Compare a simulated 1-D density against the analytical Gaussian solution.

    Saves a three-panel figure (simulation, analytical, absolute error) to
    ``{root_path}/FokkerPlanck_{model.dim}D_hist.png`` and closes it.

    Args:
        model: simulation model; only ``model.dim`` is read (for the file name).
        root_path: directory the PNG is written to.
        density_grid: array of shape ``(n_t, n_x)``, with one row per time step.
        txy_grid: ``(n_x, n_t, 2)`` mesh whose last axis is ``(t, x)``.
            ``txy_grid[0, :, 0]`` holds the times and ``txy_grid[:, 0, 1]``
            the positions.
    """
    t_min, x_min = txy_grid.min(0).min(0).tolist()
    t_max, x_max = txy_grid.max(0).max(0).tolist()

    tt = txy_grid[0, :, 0]
    xx = txy_grid[:, 0, 1]
    # Closed-form mean / variance from Model. broadcast_to also covers the case
    # where sympy reduced an expression to a constant.
    mean = np.broadcast_to(np.asarray(Model.mean_numpy(tt), dtype=float), tt.shape)
    var = np.broadcast_to(np.asarray(Model.cov_numpy(tt), dtype=float), tt.shape)
    # Gaussian pdf for every (t, x) pair -> shape (n_t, n_x), like density_grid.
    pdfs = norm.pdf(xx[np.newaxis, :], loc=mean[:, np.newaxis], scale=np.sqrt(var)[:, np.newaxis])

    extent = [t_min, t_max, x_min, x_max]
    im_args = dict(vmin=density_grid.min(), vmax=density_grid.max(), origin='lower',
                   cmap="mako", aspect="auto", extent=extent)

    fig, axs = plt.subplots(3, 1, sharex="all", sharey="all", figsize=(15, 15))
    axs = axs.flatten()
    im = axs[0].imshow(density_grid.T, **im_args)
    add_colorbar(im, axs[0], fig)

    im = axs[1].imshow(pdfs.T, **im_args)
    add_colorbar(im, axs[1], fig)

    # origin='lower' like the other panels. Without it this panel was drawn
    # upside down relative to its x axis.
    im = axs[2].imshow(np.abs(density_grid.T - pdfs.T), origin='lower', cmap="Reds", aspect="auto",
                       extent=extent)
    add_colorbar(im, axs[2], fig)

    for ax, title in zip(axs, ["Particle Simulation", "Analytical Solution", "Absolute Error"]):
        ax.set_title(title)
        ax.set_xlabel("t")
        ax.set_ylabel("x")

    plt.tight_layout()
    plt.savefig(f"{root_path}/FokkerPlanck_{model.dim}D_hist.png")
    plt.close(fig)


def plot_2d(centroid_mesh_txy, mass_on_grid, model, ts_full, ys, root_path):
    """
    Save 2-D visualisations of the particle cloud and its gridded distribution.

    * ``{root_path}/FokkerPlanck_{dim}D_scatter.png``: two 3-D scatter views of
      (time, X, Y) for 10 000 particle states drawn with replacement.
    * ``{root_path}/FokkerPlanck_{dim}D_hist_iso.html``: plotly volume
      rendering of ``mass_on_grid`` over the mesh, using every second time slice.

    Args:
        centroid_mesh_txy: mesh whose last axis holds ``(t, x, y)`` and whose
            leading axes match ``mass_on_grid``.
        mass_on_grid: gridded values, time on axis 0.
        model: only ``model.dim`` is read (for file names).
        ts_full: time of every (snapshot, particle) entry of ``ys``.
        ys: particle states with a last axis of length >= 2.
        root_path: output directory.
    """
    idx = np.random.choice(np.prod(ts_full.shape), 10_000)
    plot_inputs = dict(
        xs=ts_full.reshape(-1)[idx],
        ys=ys[..., 0].reshape(-1)[idx],
        zs=ys[..., 1].reshape(-1)[idx])
    fig = plt.figure(figsize=plt.figaspect(0.5))
    for subplot_idx, (elev, azim) in enumerate([(10., 80.), (70., -80.)], start=1):
        ax = fig.add_subplot(1, 2, subplot_idx, projection='3d')
        ax.scatter(**plot_inputs, marker='+', alpha=0.1)
        ax.view_init(elev=elev, azim=azim)
        ax.set_xlabel('time')
        ax.set_ylabel('X')
        ax.set_zlabel('Y')
    # Save under root_path like every other figure; it previously went to the CWD.
    fig.savefig(f"{root_path}/FokkerPlanck_{model.dim}D_scatter.png")
    plt.close(fig)

    sample_each = 2
    fig = go.Figure(data=go.Volume(
        x=centroid_mesh_txy[::sample_each, ..., 0].flatten(),
        y=centroid_mesh_txy[::sample_each, ..., 1].flatten(),
        z=centroid_mesh_txy[::sample_each, ..., 2].flatten(),
        value=mass_on_grid[::sample_each, ...].flatten(),
        isomin=np.percentile(mass_on_grid.flatten(), 10),
        isomax=np.percentile(mass_on_grid.flatten(), 99.1),
        opacity=0.1,  # needs to be small to see through all surfaces
        surface_count=21,  # needs to be a large number for good volume rendering
        opacityscale="extremes"
    ))
    fig.update_layout(
        title='21 isosurfaces from 10th to 99th percentile.',
        autosize=False,
        width=500,
        height=500,
        margin=dict(l=65, r=50, b=65, t=90),
        scene=dict(
            xaxis_title='Time',
            yaxis_title='X',
            zaxis_title='Y',
        ),
    )
    fig.write_html(f"{root_path}/FokkerPlanck_{model.dim}D_hist_iso.html",
                   include_plotlyjs="cdn", animation_opts=dict(transition=dict(easing="cubic")))


def plot_nd(dim, model, ts_full, ys, root_path,
            mass_on_grid, centroid_mesh_txy):
    """
    Save two overview plots of an N-D particle distribution.

    1. ``{root_path}/FokkerPlanck_{dim}D_scatter_PCA.png``: 10 000 particle
       states drawn with replacement, projected onto their first two principal
       components and plotted against time, coloured by ``|y|``.
    2. ``{root_path}/FokkerPlanck_{dim}D_pairplot.png``: a ``(dim + 1) x (dim + 1)``
       grid. The upper triangle (row i < column j) shows the 2-D marginal of
       ``mass_on_grid`` over axes i and j, where axis 0 is time and axes
       1..dim are the coordinates. The diagonal and lower triangle are blanked.

    Args:
        dim: number of spatial coordinates.
        model: only ``model.dim`` is read (for file names).
        ts_full: time of every (snapshot, particle) entry of ``ys``.
        ys: particle states whose last axis has length ``dim``.
        root_path: output directory.
        mass_on_grid: gridded values with ``dim + 1`` axes, time first.
        centroid_mesh_txy: accepted for interface compatibility; not used.
    """
    pca = PCA(n_components=2)
    projected = pca.fit_transform(ys.reshape(-1, dim))
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    idx = np.random.choice(np.prod(ts_full.shape), 10_000)
    ax.scatter(xs=ts_full.reshape(-1)[idx],
               ys=projected[..., 0].reshape(-1)[idx],
               zs=projected[..., 1].reshape(-1)[idx],
               c=np.linalg.norm(ys, axis=-1).reshape(-1)[idx],
               marker='+', alpha=0.1)
    plt.savefig(f"{root_path}/FokkerPlanck_{model.dim}D_scatter_PCA.png")
    plt.close(fig)

    n_axes = dim + 1  # time + spatial coordinates
    fig, axs = plt.subplots(n_axes, n_axes, figsize=(10, 10))
    for j in range(n_axes):
        # Upper triangle: sum out every axis except the (i, j) pair.
        for i in range(j):
            sum_over_idx = tuple(a for a in range(mass_on_grid.ndim) if a != i and a != j)
            cur_grid = mass_on_grid.sum(axis=sum_over_idx)
            im = axs[i, j].imshow(cur_grid, cmap="mako")
            add_colorbar(im, axs[i, j], fig)

        # Diagonal and lower triangle: hide axis, spines, ticks and background.
        for i in range(j, n_axes):
            cell = axs[i, j]
            cell.xaxis.set_visible(False)
            plt.setp(cell.spines.values(), visible=False)
            cell.tick_params(left=False, labelleft=False)
            cell.patch.set_visible(False)

    for k in range(n_axes):
        label = "time" if k == 0 else f"dim_{k}"
        axs[k, 0].set_ylabel(label, size="large")
        axs[0, k].set_title(label)

    plt.tight_layout()
    plt.savefig(f"{root_path}/FokkerPlanck_{model.dim}D_pairplot.png")
    plt.close(fig)


def main(dim: int, root_path="."):
    """
    Generate the Fokker–Planck dataset for ``dim`` dimensions and save it.

    Creates ``{root_path}/data_images/`` (plots) and ``{root_path}/data/``, runs
    ``generate_data`` with the settings below, and writes the result with
    ``DataFrame.to_parquet`` to
    ``{root_path}/data/FokkerPlanck_{dim}D[_kde][_multi_freq].snappy``.
    """
    os.makedirs(f"{root_path}/data_images/", exist_ok=True)
    os.makedirs(f"{root_path}/data/", exist_ok=True)

    config = dict(
        n_particles=50_000,
        n_timesteps=4000,
        add_each=20,

        dim=dim,

        proportion_nonzero=0.5,
        varying_frequency=True,

        create_plots=True,
        root_path=root_path,
        kde=True
    )
    df = generate_data(config)
    print("writing snappy..")
    kde = "_kde" if config["kde"] else ""
    varying_frequency = "_multi_freq" if config["varying_frequency"] else ""
    # Write into the data directory created above. A hard-coded "./data"
    # ignored root_path and failed when that directory did not exist.
    df.to_parquet(f"{root_path}/data/FokkerPlanck_{dim}D{kde}{varying_frequency}.snappy")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Generate Fokker-Planck particle-simulation datasets.')
    parser.add_argument('--dim', type=int, default=1, help='Dimension of the data.')
    parser.add_argument('--root_path', type=str, default=".",
                        help='Directory under which data/ and data_images/ are written.')
    # NOTE(review): parsed but currently unused.
    parser.add_argument("--silent", action="store_true")
    args = parser.parse_args()
    print(args)

    main(args.dim, root_path=args.root_path)
