import os.path

import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib.cm import ScalarMappable

from flow import Flow, ScoreMatch


def draw_grid(fuc, title, save_path, vmin=None, vmax=None, save_data=None, device='cuda'):
    """Render ``fuc`` evaluated on a (t, x) grid as a heat map and save it.

    ``fuc`` is called once per time step as ``fuc(x_grid, t_col)``:
    ``x_grid`` is the 1-D tensor ``arange(-3, 3, 0.01)`` and ``t_col`` has
    shape ``(len(x_grid), 1)`` and is filled with the current t from
    ``arange(0, 1, 0.01)``. The result is written into one row of the grid,
    so it must be assignable to a tensor of shape ``(len(x_grid),)``.

    Args:
        fuc: Callable evaluated on the grid.
        title: Unused (the title call is disabled); kept for API compatibility.
        save_path: Destination passed to ``plt.savefig``.
        vmin, vmax: Colour limits forwarded to ``plt.imshow``.
        save_data: If not None, the raw grid tensor is ``torch.save``-d here.
        device: Device on which the grid is built and ``fuc`` is evaluated.
    """
    x_grid = torch.arange(-3, 3, 1e-2).to(device)
    t_grid = torch.arange(0, 1, 1e-2).to(device)
    density_grid = torch.zeros([len(t_grid), len(x_grid)]).to(device)
    for i, t in enumerate(t_grid):
        density_grid[i] = fuc(x_grid, t.unsqueeze(-1).repeat(len(x_grid)).unsqueeze(-1))
    if save_data is not None:
        torch.save(density_grid, save_data)

    plt.figure()
    x_np = x_grid.detach().cpu().numpy()
    t_np = t_grid.detach().cpu().numpy()
    grid_np = density_grid.detach().cpu().numpy()
    plt.xlim(t_np[0], t_np[-1])
    plt.ylim(x_np[0], x_np[-1])
    # Rows of the grid are time steps; transpose so t runs along the x axis.
    plt.imshow(grid_np.T, extent=[t_np[0], t_np[-1], x_np[0], x_np[-1]], aspect='auto', origin='lower',
               cmap='viridis', vmax=vmax, vmin=vmin)
    plt.xticks(fontsize=15)
    plt.xlabel(r'$t$', fontsize=15)
    plt.ylabel(r"$x_t$", fontsize=15)
    plt.yticks(fontsize=15)
    # Size the colour-bar tick labels locally rather than mutating the global
    # plt.rcParams, which would change the tick size of every later figure.
    cbar = plt.colorbar()
    cbar.ax.tick_params(labelsize=12)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.show()
    plt.close()

def draw_grid_with_plot(fuc, traj, title, save_path, vmin=None, vmax=None, save_data=None,
                        device='cuda', highlight=(4, 5, 6, 7, 8, 20, 21, 22, 23, 24, 25)):
    """Draw the ``fuc`` heat map on a (t, x) grid with trajectories overlaid.

    The heat map is built exactly like ``draw_grid``: x in ``arange(-3, 3, 0.01)``
    and t in ``arange(0, 1, 0.01)``. Each column of ``traj`` is then drawn as a
    dashed line against that t grid. ``traj`` must therefore have
    ``len(t_grid)`` (100) rows. Its type is presumably a torch tensor, since
    ``.detach()`` is called on it.

    Args:
        fuc: Callable ``fuc(x_grid, t_col)``; see ``draw_grid``.
        traj: 2-D tensor whose columns are plotted as trajectories.
        title: Unused (the title call is disabled); kept for API compatibility.
        save_path: Destination passed to ``plt.savefig`` (dpi=300).
        vmin, vmax: Colour limits forwarded to ``plt.imshow``.
        save_data: If not None, the raw grid tensor is ``torch.save``-d here.
        device: Device on which the grid is built and ``fuc`` is evaluated.
        highlight: Column indices of ``traj`` drawn in red instead of white.
    """
    fig = plt.figure()
    x_grid = torch.arange(-3, 3, 1e-2).to(device)
    t_grid = torch.arange(0, 1, 1e-2).to(device)
    density_grid = torch.zeros([len(t_grid), len(x_grid)]).to(device)
    for i, t in enumerate(t_grid):
        density_grid[i] = fuc(x_grid, t.unsqueeze(-1).repeat(len(x_grid)).unsqueeze(-1))
    if save_data is not None:
        torch.save(density_grid, save_data)

    x_grid = x_grid.detach().cpu().numpy()
    t_grid = t_grid.detach().cpu().numpy()
    density_grid = density_grid.detach().cpu().numpy()
    plt.xlim(t_grid[0], t_grid[-1])
    plt.ylim(x_grid[0], x_grid[-1])
    cax = plt.imshow(density_grid.T, extent=[t_grid[0], t_grid[-1], x_grid[0], x_grid[-1]], aspect='auto',
                     origin='lower', cmap='viridis', vmax=vmax, vmin=vmin)

    highlighted = set(highlight)
    for j, func_value in enumerate(traj.T):
        plt.plot(t_grid, func_value.detach().cpu().numpy(),
                 linestyle='--',
                 linewidth=1.,
                 color='red' if j in highlighted else 'white',
                 )

    plt.yticks(fontsize=15)
    plt.xlabel(r'$t$', fontsize=20)
    plt.ylabel(r"$x_t$", fontsize=20)
    # Tick labels run 1.0 -> 0.2, i.e. the time axis is displayed reversed.
    plt.xticks([0,0.2,0.4,0.6,0.8], [1.0,0.8,0.6,0.4,0.2],fontsize=15)
    cbar = fig.colorbar(cax)
    cbar.ax.tick_params(labelsize=15)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300)
    plt.show()
    plt.close()

def draw_grid_line(fuc, title, save_path, device='cuda'):
    """Plot the x-averaged value of ``fuc`` against t as a line chart.

    ``fuc(x_grid, t_col)`` is evaluated for x in ``arange(-5, 5, 0.01)`` and
    each t in ``arange(0, 0.99, 0.01)``. ``t_col`` has shape
    ``(len(x_grid), 1)``. The mean over x of each row is plotted.

    Args:
        fuc: Callable evaluated on the grid.
        title: Figure title.
        save_path: Destination passed to ``plt.savefig``.
        device: Device on which the grid is built and ``fuc`` is evaluated.
    """
    x_grid = torch.arange(-5, 5, 1e-2).to(device)
    t_grid = torch.arange(0, 0.99, 1e-2).to(device)
    density_grid = torch.zeros([len(t_grid), len(x_grid)]).to(device)
    for i, t in enumerate(t_grid):
        density_grid[i] = fuc(x_grid, t.unsqueeze(-1).repeat(len(x_grid)).unsqueeze(-1))
    plt.figure()
    # detach(): if fuc's output requires grad, density_grid does too, and
    # calling .numpy() on it would raise.
    t_np = t_grid.detach().cpu().numpy()
    grid_np = density_grid.detach().cpu().numpy()
    plt.plot(t_np, grid_np.mean(-1))
    plt.title(title)
    plt.grid(True)
    plt.savefig(save_path)
    plt.show()
    plt.close()

def draw_line(fucs, title, save_path, labels=None, yrange=(-5,5), device='cuda'):
    """Plot each function in ``fucs`` over t in [0, 1] (100 points).

    Every ``fuc`` is called as ``fuc(t)`` with ``t`` of shape ``(100, 1)``.
    Each output column becomes one line coloured by column index. With exactly
    two functions, the second is drawn dashed.

    Args:
        fucs: Sequence of callables.
        title: Y-axis label.
        save_path: Destination passed to ``plt.savefig``.
        labels: Optional per-function legend labels.
        yrange: (ymin, ymax) passed to ``plt.ylim``.
        device: Device on which ``t`` is created before calling each ``fuc``.
    """
    t_values = torch.linspace(0, 1, 100).to(device).unsqueeze(-1)
    t_np = t_values.cpu().numpy()
    plt.figure()
    line_style = ['-', '--'] if len(fucs) == 2 else ('-', ) * 99

    for i, fuc in enumerate(fucs):
        func_values = fuc(t_values).detach().cpu().numpy()
        # Treat a 1-D result as a single column. Transposing a 1-D array is a
        # no-op, which would otherwise plot every value as a one-point line.
        func_values = func_values.reshape(func_values.shape[0], -1)
        label_kw = {} if labels is None else {'label': labels[i]}
        for j, func_value in enumerate(func_values.T):
            plt.plot(t_np, func_value, linestyle=line_style[i],
                     color=plt.cm.tab10(j % 9), **label_kw)
    plt.xlabel('t')
    plt.ylim(*yrange)
    plt.ylabel(title)
    if labels is not None:
        plt.legend()
    plt.savefig(save_path)
    plt.show()
    plt.close()


def draw_x_line(fucs, title, save_path, ylim=(-0.5, 0.5),  labels=None, device='cuda', mode=None, vis=True):
    """Plot each function in ``fucs`` over x in [-3, 3] (1000 points).

    Args:
        fucs: Sequence of callables, each called as ``fuc(x)`` with a 1-D tensor.
        title: Unused: the y label is always ``\\nabla \\log(x_t)``. Kept for
            API compatibility.
        save_path: Destination passed to ``plt.savefig`` (dpi=300). Its parent
            directory is created if needed.
        ylim: Optional (ymin, ymax) passed to ``plt.ylim``.
        labels: Optional per-function legend labels.
        device: Device on which ``x`` is created.
        mode: Optional summary shown as the title: ``'MSE'``, ``'MAE'`` (both
            computed on the last function's values), or ``'om'``/``'tv'``
            (which need at least two functions).
        vis: If False, the figure is neither saved nor shown.

    Returns:
        List of numpy arrays, one per function, with the evaluated values.
    """
    save_dir = os.path.dirname(save_path)
    # A bare filename has dirname '' and os.makedirs('') would raise.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    x_values = torch.linspace(-3, 3, 1000).to(device)
    x_np = x_values.detach().cpu().numpy()
    # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    plt.figure()

    all_fuc_values = []
    for i, fuc in enumerate(fucs):
        func_values = fuc(x_values).detach().cpu().numpy()
        all_fuc_values.append(func_values)
        if labels is not None:
            plt.plot(x_np, func_values, label=labels[i], linestyle='-')
        else:
            plt.plot(x_np, func_values, linestyle='-')
    if ylim is not None:
        plt.ylim(*ylim)
    plt.grid(True)
    # Only build a legend when there are labels to show; otherwise matplotlib
    # warns "No artists with labels found to put in legend".
    if len(fucs) > 1 and labels is not None:
        plt.legend()
    # NOTE(review): the 'MSE' figure is the trapezoidal integral of |values|
    # (not a mean of squares), and 'Variance' is np.std. The labels do not
    # match the computations; confirm intent.
    if mode == 'MSE':
        plt.title(f'Mean Square Error {trapezoid(np.abs(func_values), x_np):.6f}, Variance {np.std(func_values):.6f}')
    if mode == 'MAE':
        plt.title(f'Mean Absolute Error {np.sum(np.abs(func_values)):.6f}, Variance {np.std(func_values):.6f}')
    if mode == 'om':
        plt.title(f'Prediction Osillation Measure: {trapezoid(np.abs(all_fuc_values[0]), x_np):.6f} '
                  f'GT Osillation Measure: {trapezoid(np.abs(all_fuc_values[1]), x_np):.6f}')
    if mode == 'tv':
        plt.title(f'Prediction Total Variation: {np.sum(np.abs(all_fuc_values[0])):.6f} '
                  f'GT Total Variation: {np.sum(np.abs(all_fuc_values[1])):.6f}')

    plt.xticks(fontsize='15')
    plt.yticks(fontsize='15')
    plt.xlabel(r'$x_t$', fontsize='20')
    plt.ylabel(r'$\nabla \log(x_t)$', fontsize='20')
    plt.tight_layout()
    if vis:
        plt.savefig(save_path, dpi=300)
        plt.show()
    plt.close()
    return all_fuc_values


def draw_t_line(fucs, title, save_path, ylim=(-0.5, 0.5),  labels=None, device='cuda'):
    """Plot each function in ``fucs`` over t in [0, 0.4] (1000 points).

    Every ``fuc`` is called as ``fuc(t)`` with ``t`` of shape ``(1000, 1)``.

    Args:
        fucs: Sequence of callables.
        title: Y-axis label.
        save_path: Destination passed to ``plt.savefig``. Its parent directory
            is created if needed.
        ylim: Optional (ymin, ymax) passed to ``plt.ylim``.
        labels: Optional per-function legend labels.
        device: Device on which ``t`` is created.
    """
    save_dir = os.path.dirname(save_path)
    # A bare filename has dirname '' and os.makedirs('') would raise.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    t_values = torch.linspace(0.0, 0.4, 1000).to(device).unsqueeze(-1)
    t_np = t_values.cpu().numpy()
    plt.figure()

    for i, fuc in enumerate(fucs):
        func_values = fuc(t_values).detach().cpu().numpy()
        if labels is not None:
            plt.plot(t_np, func_values, label=labels[i], linestyle='-')
        else:
            plt.plot(t_np, func_values, linestyle='-')
    plt.xlabel('t')
    if ylim is not None:
        plt.ylim(*ylim)
    plt.ylabel(title)
    plt.grid(True)
    # Without labels there is nothing to show, and matplotlib would warn.
    if labels is not None:
        plt.legend()
    plt.savefig(save_path)
    plt.show()
    plt.close()

def draw_conditional_trajectory(flow, z0, z1, save_path):
    """Plot ``alpha_t * x_1 + beta_t * x_0`` for the first 20 sample pairs.

    ``alpha_t``/``beta_t`` come from ``mog_util.schedule`` applied to
    ``flow.t_trans(t)`` for 100 points t in [0, 1]. Each row of the result is
    drawn as one line. The output is presumably shaped ``(20, 100)`` via
    broadcasting; TODO confirm against ``schedule``'s return shapes.

    Args:
        flow: A ``Flow`` instance (asserted).
        z0, z1: Tensors; only their first 20 entries are used.
        save_path: Destination passed to ``plt.savefig``.
    """
    from mog_util import schedule
    assert isinstance(flow, Flow)
    t = torch.linspace(0, 1, 100).to('cuda')
    num = 20
    x_0 = z0[:num]
    x_1 = z1[:num]
    alpha_t, beta_t = schedule(flow.t_trans(t), flow.trajectory)
    x_evolution = alpha_t * x_1 + beta_t * x_0
    t_np = t.cpu().numpy()
    plt.figure()
    for path in x_evolution:
        plt.plot(t_np, path.detach().cpu().numpy())
    plt.title('conditional trajectory')
    plt.xlabel('time')
    plt.ylabel('value')
    plt.savefig(save_path)
    plt.show()
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close()

def draw_forward(flow, z0, z1, save_path):
    """Simulate the forward SDE from the first 20 samples of ``z1`` and plot it.

    For 100 times t in [0.01, 0.99], ``flow.sde.sde(x, 1 - t)`` supplies
    ``(drift, diffusion)`` for an Euler-Maruyama-style update. Separately,
    ``flow.sde.marginal_prob(z1[:20], 1 - t, return_coef=True)`` supplies the
    two coefficient series that are plotted. Three figures are written, each
    to its own file:
    ``save_path + ' trajectory.png'``, ``save_path + ' dift'`` and
    ``save_path + ' diffusion'``.

    Args:
        flow: Object exposing ``flow.sde.sde`` and ``flow.sde.marginal_prob``.
        z0: Unused; kept for API compatibility.
        z1: Starting samples; only the first 20 are used.
        save_path: Filename prefix for the three figures.
    """
    t = torch.linspace(1e-2, 1-1e-2, 100).to('cuda')
    num = 20
    x_start = z1[:num]
    x_1 = x_start
    x_evolution = []
    drifts = []
    diffusions = []

    for i in t:
        drift, diffusion = flow.sde.sde(x_1, 1-i)
        # Step: drift scaled by dt = 0.01, noise by 0.1 = sqrt(dt).
        x_1 = x_1 + drift * 0.01 + diffusion * 0.1 * torch.randn_like(x_1)
        x_evolution.append(x_1)

        drift, diffusion = flow.sde.marginal_prob(x_start, (1-i).unsqueeze(-1), return_coef=True)
        drifts.append(drift)
        diffusions.append(diffusion)

    t_np = t.cpu().numpy()
    # (series, title, filename suffix). ' dift' (sic) is kept so existing
    # output filenames do not change.
    panels = (
        (x_evolution, 'forward trajectory', ' trajectory.png'),
        (drifts, 'forward drift', ' dift'),
        (diffusions, 'forward diffusion', ' diffusion'),
    )
    for series, panel_title, suffix in panels:
        values = torch.stack(series).squeeze()
        # Each panel gets its own figure. Otherwise, on non-blocking backends
        # the drift/diffusion lines are drawn on top of the previous panel.
        plt.figure()
        plt.plot(t_np, values.detach().cpu().numpy())
        plt.title(panel_title)
        plt.xlabel('time')
        plt.ylabel('value')
        plt.savefig(save_path + suffix)
        plt.show()
        plt.close()

def draw_hist(data, names, save_path, title, range=(-3, 3)):
    """Overlay normalised 100-bin histograms of several tensors.

    Args:
        data: Sequence of tensors; each is flattened by ``plt.hist``.
        names: Legend label for each entry of ``data``.
        save_path: Destination passed to ``plt.savefig``, or None to skip
            saving. Its parent directory is created if needed.
        title: Unused (the title call is disabled); kept for API compatibility.
        range: Histogram range. If None, the overall min/max of ``data`` is
            used. (The name shadows the builtin but is part of the public
            signature.)
    """
    if save_path is not None:
        save_dir = os.path.dirname(save_path)
        # A bare filename has dirname '' and os.makedirs('') would raise.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
    # detach(): .numpy() raises on tensors that require grad.
    arrays = [_data.detach().cpu().numpy() for _data in data]
    plt.figure()
    if range is None:
        range = (min(arr.min() for arr in arrays),
                 max(arr.max() for arr in arrays))
    for i, arr in enumerate(arrays):
        plt.hist(arr, bins=100, density=True, alpha=0.3, label=names[i],
                 cumulative=False, range=range)
    plt.xticks(fontsize=20, rotation=45)
    plt.ylabel('Density', fontsize=20)
    plt.xlabel('Feature', fontsize=20)
    plt.yticks(fontsize=20)
    plt.legend(fontsize=20)
    plt.tight_layout()
    if save_path is not None:
        plt.savefig(save_path)
    plt.show()
    plt.close()

def line(data, title, save_path):
    """Plot ``data`` against evenly spaced t in [0, 1].

    One t point is generated per row of ``data``. Previously this was fixed at
    100, so ``data`` of any other length could not be plotted. The axis is
    built on the CPU, so no GPU is required.

    Args:
        data: Tensor whose first dimension indexes time.
        title: Y-axis label.
        save_path: Destination passed to ``plt.savefig``.
    """
    func_values = data.detach().cpu().numpy()
    t_values = torch.linspace(0, 1, func_values.shape[0]).unsqueeze(-1).numpy()
    plt.figure()
    plt.plot(t_values, func_values)
    plt.xlabel('t')
    plt.ylabel(title)
    plt.grid(True)
    plt.savefig(save_path)
    plt.show()
    plt.close()

def hist_3d(data, samples, bins=200, range_min=-3, range_max=3, mayavi=None):
    """Show two 2-D sample sets as overlaid 3-D histogram surfaces (Plotly).

    The first two columns of each tensor are binned with ``np.histogram2d`` on
    ``[range_min, range_max]^2``. Counts are capped at 100 for ``data`` and at
    750 for ``samples``. The result is shown interactively via ``fig.show()``.

    Args:
        data: Tensor whose ``[:, 0]``/``[:, 1]`` are x/y; legend
            'SDE Sampled Data' (blue).
        samples: Same layout; legend 'ODE Sampled Data' (orange).
        bins: Number of bins per axis.
        range_min, range_max: Histogram range on both axes.
        mayavi: Unused; kept for API compatibility.
    """
    import plotly.graph_objects as go

    z_real = data.detach().cpu().numpy()
    z_fake = samples.detach().cpu().numpy()
    hist_range = [[range_min, range_max],
                  [range_min, range_max]]

    H_real, xedges, yedges = np.histogram2d(z_real[:, 0], z_real[:, 1],
                                            bins=bins, range=hist_range)
    H_fake, _, _ = np.histogram2d(z_fake[:, 0], z_fake[:, 1],
                                  bins=bins, range=hist_range)
    # Cap the counts so a few very tall bins don't flatten the rest.
    H_fake = np.minimum(H_fake, 750)
    H_real = np.minimum(H_real, 100)

    x_centers = 0.5 * (xedges[:-1] + xedges[1:])
    y_centers = 0.5 * (yedges[:-1] + yedges[1:])
    X, Y = np.meshgrid(x_centers, y_centers)

    fig = go.Figure()
    # np.histogram2d indexes H[x_bin, y_bin], while meshgrid's X/Y are indexed
    # [y, x]. Transpose so each z value sits over its own (x, y) bin.
    fig.add_trace(go.Surface(
        x=X,
        y=Y,
        z=H_real.T,
        colorscale='Blues',
        opacity=0.5,
        name='SDE Sampled Data',
        showscale=False,
        hoverinfo='x+y+z'
    ))
    fig.add_trace(go.Surface(
        x=X,
        y=Y,
        z=H_fake.T,
        colorscale='Oranges',
        opacity=0.5,
        name='ODE Sampled Data',
        showscale=False,
        hoverinfo='x+y+z'
    ))

    fig.update_layout(
        title="3D Histogram (Surface) of Real vs Generated Data",
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Count',
            aspectratio=dict(x=1, y=1, z=0.7)
        ),
        width=800,
        height=600,
        legend=dict(
            x=0.8,
            y=0.9,
            bgcolor='rgba(255, 255, 255, 0)',
            bordercolor='rgba(0,0,0,0)'
        )
    )

    # Invisible single-marker traces whose only job is to add coloured legend
    # entries for the two surfaces.
    fig.add_trace(go.Scatter3d(
        x=[None], y=[None], z=[None],
        mode='markers',
        marker=dict(
            size=10,
            color='blue'
        ),
        name='SDE Sampled Data'
    ))
    fig.add_trace(go.Scatter3d(
        x=[None], y=[None], z=[None],
        mode='markers',
        marker=dict(
            size=10,
            color='orange'
        ),
        name='ODE Sampled Data'
    ))

    fig.show()

def density_evolution_3d(traj, bins=500, range_min=-3, range_max=3,
                         save_path=None):
    """Draw how a 1-D sample distribution evolves over time as 3-D bars.

    Row i of ``traj`` is histogrammed (``density=True``) on ``bins`` equal bins
    over ``[range_min, range_max]`` and placed at t = i / T. Bar colour follows
    bar height on the ``cool`` colormap. The y-axis tick labels run 1.0 -> 0.0,
    so time is displayed reversed.

    Args:
        traj: Tensor of shape (T, N) or (T, N, 1): T time steps of N samples.
        bins: Number of histogram bins along x.
        range_min, range_max: Histogram range along x.
        save_path: If not None, the figure is saved here (dpi=300).

    Raises:
        ValueError: If ``traj`` is not (T, N) or (T, N, 1).
    """
    from matplotlib import cm
    # Kept for older matplotlib versions, where importing Axes3D is what
    # registers the '3d' projection.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

    if traj.dim() == 3 and traj.size(2) == 1:
        traj = traj.squeeze(2)
    elif traj.dim() != 2:
        raise ValueError("traj must have shape (T, N, 1) or (T, N)")

    traj_np = traj.detach().cpu().numpy()
    n_steps = traj_np.shape[0]

    bin_edges = np.linspace(range_min, range_max, bins + 1)
    bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
    bin_width = bin_edges[1] - bin_edges[0]

    # Step i sits at t = i / n_steps, so t spans [0, 1 - 1/n_steps].
    t_values = np.arange(n_steps) * 1 / n_steps
    X, Y = np.meshgrid(bin_centers, t_values)

    Z = np.zeros_like(X)
    for i in range(n_steps):
        counts, _ = np.histogram(traj_np[i], bins=bin_edges, density=True)
        Z[i, :] = counts

    fig = plt.figure(figsize=(15, 10))
    ax = fig.add_subplot(111, projection='3d')

    xpos, ypos = X.ravel(), Y.ravel()
    zpos = np.zeros_like(xpos)
    dx = bin_width * 1
    dy = 1/n_steps * 1
    dz = Z.ravel()

    # Colour each bar by its height relative to the tallest bar.
    max_count = dz.max()
    colors = cm.cool(dz / max_count)

    mappable = ScalarMappable(cmap=cm.cool)
    mappable.set_clim(0, max_count)
    cbar = fig.colorbar(mappable, ax=ax, shrink=0.5, aspect=15, pad=0.1)
    cbar.ax.tick_params(labelsize=30)

    ax.bar3d(xpos, ypos, zpos, dx, dy, dz, color=colors, zsort='average', alpha=0.7)

    ax.set_xlabel(r'Value ($x_t$)', fontsize=30, labelpad=25)
    ax.set_ylabel(r'Time ($t$)', fontsize=30, labelpad=25)
    ax.set_zlabel('Density', fontsize=30, labelpad=25)

    ax.view_init(elev=30, azim=-45)
    ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
    ax.set_yticklabels([1.0, 0.8, 0.6, 0.4, 0.2, .0], fontsize=30)
    ax.tick_params(axis='x', labelsize=25, pad=10)
    ax.tick_params(axis='y', labelsize=25, pad=10)
    ax.tick_params(axis='z', labelsize=25, pad=10)
    plt.tight_layout()
    if save_path is not None:
        plt.savefig(save_path, bbox_inches='tight', dpi=300)
    plt.show()
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)