GMM estimator
rex.gmm_estimator.GMMEstimator
__init__(data: jax.typing.ArrayLike, name: str = 'GMM', threshold: float = 1e-07, verbose: bool = True)
Gaussian mixture model (GMM) estimator for 1D delay data.
Parameters:
- data (ArrayLike): 1D array of delay data.
- name (str, default: 'GMM'): Name of the model.
- threshold (float, default: 1e-07): Threshold for determining if the data is deterministic.
- verbose (bool, default: True): Whether to print progress.
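The docstring does not say how `threshold` is applied. One plausible reading, sketched below, is that the delays count as deterministic when their spread is below `threshold`. The helper name `is_deterministic` and the use of the standard deviation as the spread measure are assumptions for illustration, not rex's implementation.

```python
import numpy as np

def is_deterministic(data, threshold=1e-7):
    """Hypothetical check: treat the delays as deterministic when their
    standard deviation is below `threshold` (assumed criterion)."""
    data = np.asarray(data, dtype=float).ravel()
    return bool(np.std(data) < threshold)

# A constant 5 ms delay is flagged as deterministic; jittery delays are not.
print(is_deterministic(np.full(100, 0.005)))                              # True
print(is_deterministic(np.random.default_rng(0).normal(0.005, 0.001, 100)))  # False
```

Under this reading, a deterministic input would not need a mixture fit at all, since a single constant delay already describes it.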
fit(num_steps: int = 100, num_components: int = 2, step_size: float = 0.05, seed: int = 0)
Fit the model to the data.
Parameters:
- num_steps (int, default: 100): Number of steps to train the model.
- num_components (int, default: 2): Number of components in the mixture model.
- step_size (float, default: 0.05): Step size for the optimizer.
- seed (int, default: 0): Random seed.
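The parameters of `fit` point to an iterative, gradient-based fit. As a rough illustration only, the sketch below fits `num_components` Gaussians to 1D data by minimizing the mean negative log-likelihood. Mixture weights are parameterized as logits and scales as log standard deviations, and the gradients are written out by hand in NumPy. Several details are assumptions and are not taken from rex: the Adam-style update (the docstring only says "the optimizer"), standardizing the data before fitting, and initializing the means from randomly chosen data points. `fit_gmm` is a hypothetical name.

```python
import numpy as np

def fit_gmm(data, num_steps=100, num_components=2, step_size=0.05, seed=0):
    """Sketch: fit a 1D GMM by minimizing the mean negative log-likelihood
    with hand-written Adam updates (assumed optimizer)."""
    x_raw = np.asarray(data, dtype=float).ravel()
    # Standardize so a fixed step size is scale-free (assumption, not from rex).
    loc, scale = x_raw.mean(), x_raw.std() + 1e-12
    x = (x_raw - loc) / scale
    K = num_components
    rng = np.random.default_rng(seed)
    # theta packs [means | log std devs | weight logits]; means start at data points.
    theta = np.concatenate([rng.choice(x, size=K, replace=False), np.zeros(K), np.zeros(K)])
    m, v = np.zeros_like(theta), np.zeros_like(theta)
    b1, b2, eps = 0.9, 0.999, 1e-8
    losses = []  # recorded in standardized units
    for t in range(1, num_steps + 1):
        mu, log_sd, logits = theta[:K], theta[K:2 * K], theta[2 * K:]
        log_w = logits - np.logaddexp.reduce(logits)          # log mixture weights
        z = (x[:, None] - mu) / np.exp(log_sd)                # (N, K) standardized residuals
        log_joint = log_w - 0.5 * z**2 - log_sd - 0.5 * np.log(2 * np.pi)
        log_px = np.logaddexp.reduce(log_joint, axis=1)       # log p(x_i)
        losses.append(-log_px.mean())
        r = np.exp(log_joint - log_px[:, None])               # responsibilities
        grad = -np.concatenate([
            (r * z / np.exp(log_sd)).mean(axis=0),            # d loss / d mu
            (r * (z**2 - 1.0)).mean(axis=0),                  # d loss / d log_sd
            (r - np.exp(log_w)).mean(axis=0),                 # d loss / d logits
        ])
        # Adam update with bias correction.
        m = b1 * m + (1 - b1) * grad
        v = b2 * v + (1 - b2) * grad**2
        theta = theta - step_size * (m / (1 - b1**t)) / (np.sqrt(v / (1 - b2**t)) + eps)
    mu, log_sd, logits = theta[:K], theta[K:2 * K], theta[2 * K:]
    weights = np.exp(logits - np.logaddexp.reduce(logits))
    # Map means and standard deviations back to the original data scale.
    return weights, mu * scale + loc, np.exp(log_sd) * scale, np.array(losses)

rng = np.random.default_rng(1)
delays = np.concatenate([rng.normal(0.0, 1.0, 500), rng.normal(4.0, 1.0, 500)])
weights, means, stds, losses = fit_gmm(delays, num_steps=200)
print(weights.round(2), np.sort(means).round(1))
```

An adaptive optimizer is used here because plain gradient descent with a fixed step size can become unstable once a component's standard deviation gets small.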
get_dist(percentile: float = 0.99) -> base.StaticDist
Get the fitted distribution.
Parameters:
- percentile (float, default: 0.99): Percentile used to prune components that contribute little to the mixture.
Returns:
- base.StaticDist: The distribution object.
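One plausible reading of `percentile`, sketched below, is a cumulative-weight cutoff. Sort the components by weight, keep the smallest set whose cumulative weight reaches `percentile`, and renormalize. This rule and the helper `prune_components` are hypothetical and are not confirmed by the source.

```python
import numpy as np

def prune_components(weights, means, stds, percentile=0.99):
    """Hypothetical pruning rule: keep the heaviest components whose cumulative
    weight first reaches `percentile`, then renormalize the kept weights."""
    weights = np.asarray(weights, dtype=float)
    order = np.argsort(weights)[::-1]                  # heaviest component first
    cumulative = np.cumsum(weights[order])
    num_keep = int(np.searchsorted(cumulative, percentile)) + 1
    keep = order[:min(num_keep, len(weights))]         # guard against rounding past the end
    kept = weights[keep] / weights[keep].sum()
    return kept, np.asarray(means)[keep], np.asarray(stds)[keep]

w, mu, sd = prune_components([0.6, 0.05, 0.35], [1.0, 9.0, 3.0], [0.1, 0.5, 0.2], percentile=0.9)
print(mu)  # the 0.05-weight component at 9.0 is dropped
```

Raising `percentile` toward 1 keeps more components. With the default of 0.99, the example above would keep all three.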
plot_hist(ax: plt.Axes = None, edgecolor: str = None, facecolor: str = None, bins: int = 100, xmin: float = None, xmax: float = None, num_points: int = 1000, plot_dist: bool = True) -> plt.Axes
Plot the histogram of the data and the fitted distribution.
Parameters:
- ax (Axes, default: None): Axes to plot on.
- edgecolor (str, default: None): Edge color of the histogram.
- facecolor (str, default: None): Face color of the histogram.
- bins (int, default: 100): Number of bins for the histogram.
- xmin (float, default: None): Minimum x value for the histogram. Can be used to exclude outliers.
- xmax (float, default: None): Maximum x value for the histogram. Can be used to exclude outliers.
- num_points (int, default: 1000): Number of points at which the fitted distribution is evaluated for plotting.
- plot_dist (bool, default: True): Whether to plot the fitted distribution.
Returns:
- Axes: The axes with the plot.
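The fitted curve drawn over the histogram is a mixture density evaluated on a grid of `num_points` points. The sketch below shows that evaluation for a 1D Gaussian mixture. The fallback to the data range when `xmin`/`xmax` are None is an assumption about the defaults, and `mixture_pdf_curve` is a hypothetical helper.

```python
import numpy as np

def mixture_pdf_curve(data, weights, means, stds, xmin=None, xmax=None, num_points=1000):
    """Evaluate a 1D Gaussian-mixture density on `num_points` evenly spaced
    points. Using the data range when xmin/xmax are None is an assumption."""
    data = np.asarray(data, dtype=float)
    lo = data.min() if xmin is None else xmin
    hi = data.max() if xmax is None else xmax
    xs = np.linspace(lo, hi, num_points)
    w, mu, sd = (np.asarray(a, dtype=float) for a in (weights, means, stds))
    z = (xs[:, None] - mu) / sd                        # (num_points, K)
    # Weighted sum of the component densities at each grid point.
    pdf = (w * np.exp(-0.5 * z**2) / (sd * np.sqrt(2 * np.pi))).sum(axis=1)
    return xs, pdf

xs, pdf = mixture_pdf_curve(np.array([-10.0, 10.0]), [0.3, 0.7], [-2.0, 3.0], [1.0, 0.5])
```

To make the curve comparable to the bars, draw the histogram on a density scale, for example with `ax.hist(data, bins=bins, density=True)`, before calling `ax.plot(xs, pdf)`.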
plot_loss(ax: plt.Axes = None, edgecolor: str = None) -> plt.Axes
Plot the training loss.
Parameters:
- ax (Axes, default: None): Axes to plot on.
- edgecolor (str, default: None): Edge color of the plot.
Returns:
- Axes: The axes with the plot.
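Every plotting method here takes an optional `ax` and returns the axes. A common matplotlib convention, sketched below for a loss curve, is to create a new figure when `ax` is None and otherwise draw into the axes passed in. The helper `plot_loss_curve` is hypothetical, and treating `edgecolor` as the line color is an assumption.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

def plot_loss_curve(losses, ax=None, edgecolor=None):
    """Plot a per-step loss history; create new axes only if none are given."""
    if ax is None:
        _, ax = plt.subplots()
    # color=None lets matplotlib pick the next color from its default cycle.
    ax.plot(np.arange(len(losses)), losses, color=edgecolor)
    ax.set_xlabel("step")
    ax.set_ylabel("loss")
    return ax

ax = plot_loss_curve([1.4, 1.35, 1.32, 1.31], edgecolor="tab:blue")
```

Returning the axes lets callers compose several diagnostics, such as the loss and the weights, in one figure by passing subplot axes.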
plot_normalized_weights(ax: plt.Axes = None, edgecolor: str = None) -> plt.Axes
Plot the normalized mixture weights.
Parameters:
- ax (Axes, default: None): Axes to plot on.
- edgecolor (str, default: None): Edge color of the plot.
Returns:
- Axes: The axes with the plot.
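The docstring does not say how the weights are stored during fitting. If they are kept as unconstrained logits, which is an assumption, the "normalized" weights would be their softmax: non-negative and summing to one. Below is a minimal, numerically stable version. `normalized_weights` is a hypothetical name.

```python
import numpy as np

def normalized_weights(logits):
    """Softmax: map unconstrained logits to mixture weights that sum to 1."""
    logits = np.asarray(logits, dtype=float)
    shifted = logits - logits.max()      # subtract the max to avoid overflow in exp
    w = np.exp(shifted)
    return w / w.sum()

print(normalized_weights([0.0, 0.0]))    # equal logits give equal weights
```

Subtracting the maximum logit does not change the result mathematically, but it keeps `np.exp` from overflowing when a logit is large.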
animate_training(num_frames: int = 30, fig: plt.Figure = None, ax: plt.Axes = None, edgecolor: str = None, facecolor: str = None, bins: int = 40, xmin: float = None, xmax: float = None, num_points: int = 1000) -> matplotlib.animation.FuncAnimation
Animate the training process.
Parameters:
- num_frames (int, default: 30): Number of frames to animate.
- fig (Figure, default: None): Figure to plot on.
- ax (Axes, default: None): Axes to plot on.
- edgecolor (str, default: None): Edge color of the histogram.
- facecolor (str, default: None): Face color of the histogram.
- bins (int, default: 40): Number of bins for the histogram.
- xmin (float, default: None): Minimum x value for the histogram. Can be used to exclude outliers.
- xmax (float, default: None): Maximum x value for the histogram. Can be used to exclude outliers.
- num_points (int, default: 1000): Number of points at which the fitted distribution is evaluated for plotting.
Returns:
- matplotlib.animation.FuncAnimation: The animation object.
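`animate_training` returns a `matplotlib.animation.FuncAnimation`. The sketch below shows one way such an animation could be built. It assumes the training loop recorded a `(weights, means, stds)` snapshot at each step. The snapshot format, the helper name `animate_snapshots`, and spacing the frames evenly over the recorded steps are all hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; use a GUI backend to view interactively
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation

def animate_snapshots(data, snapshots, num_frames=30, bins=40, num_points=1000):
    data = np.asarray(data, dtype=float)
    # Spread the frames evenly over the recorded training steps (assumption).
    steps = np.linspace(0, len(snapshots) - 1, num_frames).round().astype(int)
    fig, ax = plt.subplots()
    ax.hist(data, bins=bins, density=True, alpha=0.5)
    xs = np.linspace(data.min(), data.max(), num_points)
    (line,) = ax.plot(xs, np.zeros_like(xs))

    def update(frame):
        # Redraw the mixture density for the snapshot shown in this frame.
        w, mu, sd = (np.asarray(a, dtype=float) for a in snapshots[steps[frame]])
        z = (xs[:, None] - mu) / sd
        line.set_ydata((w * np.exp(-0.5 * z**2) / (sd * np.sqrt(2 * np.pi))).sum(axis=1))
        ax.set_title(f"step {steps[frame]}")
        return (line,)

    return FuncAnimation(fig, update, frames=num_frames, blit=False)

rng = np.random.default_rng(0)
snaps = [([1.0], [0.0], [1.0 + 0.1 * i]) for i in range(10)]
anim = animate_snapshots(rng.normal(0.0, 1.0, 200), snaps, num_frames=5, bins=10, num_points=50)
```

Keep a reference to the returned object, because matplotlib stops an animation whose object has been garbage-collected. If Pillow is installed, `anim.save("training.gif", writer="pillow")` writes the animation to disk.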