import time
from pathlib import Path
import numpy as np
from sklearn.cluster import KMeans
import kmedoids
from tqdm import tqdm
from DPM_utils import *
from scipy.special import logsumexp, betaln
from scipy.stats import wishart,dirichlet,beta

def initialize_mixture_parameters(X, partial_labels, K, alpha=1.0, likelihood="gaussian"):
    """
    Build the prior hyper-parameters of the mixture model.

    :param X: the data matrix of shape [N, d]; only the feature dimension d is read
    :param partial_labels: the partial labels; not used by this function, kept so
        existing callers do not break
    :param K: the number of clusters considered (must be at least 1)
    :param alpha: total Dirichlet concentration. Every component receives
        ``alpha / K``, so the default of 1.0 yields the symmetric prior 1/K
    :param likelihood: either "gaussian" or "bernoulli"
    :return: params: dictionary of prior parameters. Always holds "epsilon_0" [K, ];
        the Gaussian case adds "kappa_0" [K, ], "nu_0" [K, ], "m_0" [K, d] and
        "L_0" [K, d, d]; the Bernoulli case adds "alpha_0" and "beta_0", both [K, d]
    :raises ValueError: if K is smaller than 1 or the likelihood is not supported
    """
    if K < 1:
        raise ValueError(f"K must be a positive integer, got {K}")

    d = np.shape(X)[1]

    # The concentration used to be overwritten with 1 / K, silently discarding the
    # caller's alpha. Scaling by K keeps the default result (1 / K) and mirrors how
    # the Beta prior eta is scaled in HMRF_GMM.__init__.
    epsilon_0 = (alpha / K) * np.ones((K, ))
    params: dict[str, np.ndarray] = {"epsilon_0": epsilon_0}

    if likelihood == "gaussian":
        params.update({
            "kappa_0": np.ones((K, )),
            "nu_0": d * np.ones((K, )),
            "m_0": np.zeros((K, d)),
            # One identity matrix per component.
            "L_0": np.eye(d).reshape(1, d, d) * np.ones((K, 1, 1)),
        })
    elif likelihood == "bernoulli":
        # Beta(1, 1) prior on every per-feature Bernoulli parameter.
        params.update({"alpha_0": np.ones((K, d)), "beta_0": np.ones((K, d))})
    else:
        raise ValueError(f"Unsupported likelihood '{likelihood}'")
    return params

def initialize_phi(N, K):
    """
    Draw random responsibilities (posterior class probabilities).

    The entries are sampled with ``np.random.rand`` and then passed through
    ``normalize`` along axis 1.
    :param N: Number of instances
    :param K: Number of classes
    :return: phi: responsibilities of shape [N, K]
    """
    print("[INFO] Random initialization of phi")
    raw_scores = np.random.rand(N, K)
    return normalize(raw_scores, axis=1)

def initialise_phi_with_kmeans(X, K, likelihood="gaussian"):
    """
    Initialise the responsibilities from a hard clustering of the data.

    For "bernoulli" the data are clustered with FasterPAM k-medoids under the
    Hamming metric and the score of a sample for a cluster is exp(-fraction of
    features that differ from the medoid). For "gaussian" the data are clustered
    with KMeans and the score is exp(-0.5 * Euclidean distance to the centre).
    The scores are then passed through ``normalize`` along axis 1.

    :param X: data array of shape [N, d]
    :param K: number of clusters
    :param likelihood: either "gaussian" or "bernoulli"
    :return: (phi, mu): phi of shape [N, K] and the K cluster representatives
        (medoids or centres) of shape [K, d]
    :raises ValueError: if the likelihood is not supported
    """
    # Reject unknown likelihoods up front; previously they fell through both
    # branches and crashed on the undefined ``phi`` / ``mu`` at the return.
    if likelihood not in ("gaussian", "bernoulli"):
        raise ValueError(f"Unsupported likelihood '{likelihood}'")

    if likelihood == "bernoulli":
        print("[INFO] KMedoids initialization of phi")
        km = kmedoids.KMedoids(K, metric='hamming', method='fasterpam').fit(X)
        mu = X[km.medoid_indices_]
        # Fraction of mismatching features between every sample and every medoid.
        diffs = np.not_equal(X[:, np.newaxis, :], mu[np.newaxis, :, :])
        distances = diffs.mean(axis=2)
        phi = np.exp(-distances)
    else:
        print("[INFO] KMeans initialization of phi")
        mu = KMeans(K).fit(X).cluster_centers_
        # Pairwise sample-to-centre differences of shape [N, K, d], reduced with
        # a 2-norm over the feature axis.
        offsets = X.reshape(X.shape[0], 1, X.shape[1]) - mu.reshape(1, K, X.shape[1])
        phi = np.exp(-0.5 * LA.norm(offsets, 2, 2))
    return normalize(phi, 1), mu

class HMRF_GMM:
    """
    The semi-supervised Dirichelet process Gaussian mixture model
    """
    def __init__(
        self,
        X,
        K,
        partial_labels,
        alpha=1.0,
        epsilon=1e-9,
        lambda_=1.,
        init="kmeans",
        weight_prior="Dirichelet distribution",
        eta=1,
        mode='complete',
        likelihood="bernoulli",
    ):
        """
        init function
        :param X: data numpy array of shape [N, d]
        :param K: Trunction level for the approximate betas
        :param partial_labels: array of partial labels where unlabeled data is set to None labeled data is in {1,..., T}
        :param alpha: prior on the Dirichlet dist
        :param epsilon: for stability
        :param lambda_: the hyperparamater for the HMRF
        :param init: initialization Kmeans or random
        :param weight_prior: weight prior Dirichelet distribution or Dirichelet Proceszs
        :param eta: Beta prior
        :param mode: Faster if we consider a tree HMRF on the hidden labels but complete is recommended
        """
        if mode == "tree":
            self.dict_N, self.mask = construct_neighborhood_tree(partial_labels)
        else:
            self.dict_N, self.mask = construct_neighborhood_complete_graph(partial_labels)
        self.N = X.shape[0]
        self.K = K
        self.d = X.shape[1]
        self.eps = epsilon
        self.likelihood = likelihood.lower()
        if self.likelihood not in {"gaussian", "bernoulli"}:
            raise ValueError("likelihood must be either 'gaussian' or 'bernoulli'")
        self.params_0 = initialize_mixture_parameters(
            X,
            partial_labels,
            self.K,
            alpha=alpha,
            likelihood=self.likelihood,
        )
        if init == "kmeans":
            self.phi, mu_0 = initialise_phi_with_kmeans(X, K, likelihood=self.likelihood)
            if self.likelihood == "gaussian":
                self.params_0["m_0"] = mu_0
        else:
            self.phi = initialize_phi(self.N, K)

        self.X = np.asarray(X, dtype=float)
        self.lambda_ = lambda_
        self.weight_prior = weight_prior
        self.eta = eta / K
        self.S = 64
        self.to = 1024
        self.delta = 0.5
        self.tuples_ml = get_tuples(self.dict_N)

    def compute_N(self, phi):
        """
        Function computing the sum over all instances of the responsibilities --> N_k = sum_n phi_{nk}
        :param phi: of shape [N, K]
        :return: N: of shape [1, K]
        """
        return np.sum(phi, 0)

    def compute_nu(self, N): # for GMM
        """
        The function computing the seconnd variational parameter of the wishart distribution
        :param N: of shape [1, K]
        :return: nu: the second variational parameter of the wishart distributionof shape [K, ]
        """
        nu = self.params_0["nu_0"] + N + 1
        return nu

    def compute_gamma_1(self, N):
        """
        The function computing the first variational parameter of the beta distribution
        :param N: of shape [1, K]
        :return: gamma_1: the first variational parameter of the beta distribution of shape [K, ]
        """
        gamma_1 = 1 + N
        return gamma_1

    def compute_gamma_2(self, N):
        """
        Update the second parameter of the Beta factors (stick-breaking prior).
        :param N: per-component responsibility totals of shape [K, ]
        :return: gamma_2: self.eta plus the reversed ``cumsum_ex`` of the reversed N,
            of shape [K, ]
        """
        # presumably cumsum_ex is an exclusive cumulative sum, which would make the
        # second term sum_{j > k} N_j — confirm in DPM_utils.
        reversed_counts = N[::-1]
        tail_counts = cumsum_ex(reversed_counts)[::-1]
        return self.eta + tail_counts

    def compute_epsilon(self, N):
        """
        The function computing the first variational parameter of the Dirichelet distribution (if Dirichelet dist prior is considered)
         :param N: of shape [1, K]
        :return: epsilon: the first variational parameter of the Dirichelet distribution of shape [K, ]
        """
        epsilon = self.params_0["epsilon_0"] + N
        return epsilon

    def compute_kappa(self, N): # for GMM
        """
        The function computing the scalar of the variance for the approximating distribution of the means
        :param N: of shape [1, K]
        :return: kappa: of shape [1, K]
        """
        kappa = self.params_0["kappa_0"] + N
        return kappa

    def compute_m(self, phi, N): # for GMM
        """
        The function computing the scalar of the mean for the approximating distribution of the means
        :param phi: of shape [N, K]
        :param N: of shape [1, K]
        :return: m: of shape [K, d]
        """
        m = (self.params_0["kappa_0"].reshape(self.K, 1) * self.params_0["m_0"]
             + np.sum(np.reshape(phi, [self.N, self.K, 1]) * np.reshape(self.X, [self.N, 1, self.d]), axis=0) )\
            / (self.params_0["kappa_0"] + N ).reshape(self.K, 1)
        return m

    def compute_L(self, phi, m, nu): # for GMM
        """
        Function compution the wishart mean matrix of the approximating distribution of the convariance matrices
        :param phi: of shape [N, K]
        :param m: of shape [K, d]
        :param nu: of shape [1, K]
        :return:
        """
        m_0 = np.reshape(self.params_0["m_0"], [self.K, self.d, 1])
        m =  np.reshape(m, [self.K, self.d, 1])
        L_inv = self.params_0["L_0"] + self.params_0["kappa_0"].reshape(self.K, 1, 1) * np.matmul(m - m_0, np.reshape(m - m_0,(self.K, 1, self.d))) \
                + np.sum(np.reshape(phi, [self.N, self.K, 1, 1]) * np.matmul(self.X.reshape(self.N, 1, self.d, 1) - m.reshape(1, self.K, self.d, 1),
                                                                             self.X.reshape(self.N, 1, 1, self.d) - m.reshape(1, self.K, 1, self.d) ) , axis=0)
        return np.linalg.inv(L_inv)

    def compute_bern_param_ab(self, phi):
        """
        Function computing the bernoulli parameters alpha and beta
        :param phi: of shape [N, K]
        :return:
        """
        ones = np.dot(phi.T, self.X)
        zeros = np.dot(phi.T, 1 - self.X)
        alpha_post = self.params_0["alpha_0"] + ones
        beta_post = self.params_0["beta_0"] + zeros
        return alpha_post, beta_post

    def compute_phi_bernoulli(self, alpha, beta_param, epsilon, V, phi_t, gamma_1, gamma_2):
        """Function computing the variational parameters phi for the multivariate bernoulli case
        :param alpha: of shape [K, d]
        :param beta_param: of shape [K, d]
        :param epsilon: of shape [K, ]; only read with the Dirichlet distribution prior
        :param V: pairwise potentials of shape [K, K]
        :param phi_t: responsibilities of the previous iteration, of shape [N, K]
        :param gamma_1: of shape [K, ]; only read with the stick-breaking prior
        :param gamma_2: of shape [K, ]; only read with the stick-breaking prior
        :return: phi: updated responsibilities of shape [N, K]
        """
        # Expected log mixing weights under the chosen prior on the weights.
        if self.weight_prior == "Dirichelet distribution":
            val = digamma(epsilon) - digamma(np.sum(epsilon))
        else :
            val = digamma(gamma_1) - digamma(gamma_1 + gamma_2) + cumsum_ex(digamma(gamma_2) - digamma(gamma_1 + gamma_2))

        # These digamma differences do not depend on the sample, so they are
        # evaluated once instead of once per row of X.
        digamma_total = digamma(alpha + beta_param)
        e_log_theta = (digamma(alpha) - digamma_total).reshape(self.K, self.d)
        e_log_not_theta = (digamma(beta_param) - digamma_total).reshape(self.K, self.d)

        log_phi = np.zeros((self.N, self.K))
        for n in range(self.N):
            if self.mask[n] == 0:
                # presumably mask[n] == 0 marks a sample without label constraints —
                # confirm in construct_neighborhood_*.
                log_phi[n, :] = log_phi[n, :] + val
            else:
                # HMRF penalty: potentials weighted by the neighbours' previous phi.
                for j in self.dict_N[n]:
                    log_phi[n, :] = log_phi[n, :] - self.lambda_ * np.sum(phi_t[j, :].reshape(1, self.K) * V, axis=1)
            x_n = self.X[n, :].reshape(1, self.d)
            log_phi[n, :] = log_phi[n, :] + np.sum(e_log_theta * x_n, axis=1) \
                             + np.sum(e_log_not_theta * (1 - x_n), axis=1)

        # Normalise each row in log space before exponentiating.
        log_phi = log_phi - logsumexp(log_phi,axis=1)[:,np.newaxis]
        phi = np.exp(log_phi)
        return normalize(phi, axis=1)


    def compute_phi(self, m, L, epsilon, V, nu, phi_t, kappa, gamma_1, gamma_2):
        """
        Function computing the variational parameters phi (Gaussian likelihood)
        :param m: of shape [K, d]
        :param L: of shape [K, d, d]
        :param epsilon: of shape [K, ]; only read with the Dirichlet distribution prior
        :param V: pairwise potentials of shape [K, K]
        :param nu: of shape [K, ]
        :param phi_t: responsibilities of the previous iteration, of shape [N, K]
        :param kappa: of shape [K, ]
        :param gamma_1: of shape [K, ]; only read with the stick-breaking prior
        :param gamma_2: of shape [K, ]; only read with the stick-breaking prior
        :return: phi: updated responsibilities of shape [N, K]
        """
        # Expected log mixing weights under the chosen prior on the weights.
        if self.weight_prior == "Dirichelet distribution":
            val = digamma(epsilon) - digamma(np.sum(epsilon))
        else :
            val = digamma(gamma_1) - digamma(gamma_1 + gamma_2) + cumsum_ex(digamma(gamma_2) - digamma(gamma_1 + gamma_2))


        log_phi = np.zeros((self.N, self.K))

        for n in range(self.N):
            if self.mask[n] == 0:
                # presumably mask[n] == 0 marks a sample without label constraints —
                # confirm in construct_neighborhood_*.
                log_phi[n, :] = log_phi[n, :] + val
            else:
                # HMRF penalty: potentials weighted by the neighbours' previous phi.
                for j in self.dict_N[n]:
                    log_phi[n, :] = log_phi[n, :] - self.lambda_ * np.sum(phi_t[j, :].reshape(1, self.K) * V, axis=1)
            # Per-component Gaussian term for sample n, in order of appearance:
            #   - 0.5 * d / kappa_k
            #   - 0.5 * nu_k * trace(L_k (x_n - m_k)(x_n - m_k)^T)
            #   + 0.5 * (log|det L_k| + multivar_digamma(nu_k, d))
            #   - 0.5 * d * log(pi)
            log_phi[n, :] = log_phi[n, :] - 0.5 * self.d/kappa - 0.5 * nu * np.trace(np.matmul(L, np.matmul((self.X[n,:].reshape(1, self.d) - m).reshape(self.K, self.d, 1),
                                                                                       (self.X[n,:].reshape(1, self.d) - m).reshape(self.K, 1, self.d) )), axis1=1 , axis2=2) \
                            + 0.5 * (LA.slogdet(L)[1] + multivar_digamma(nu, self.d)) - 0.5 * self.d * np.log(np.pi)

        # Normalise each row in log space before exponentiating.
        log_phi = log_phi - logsumexp(log_phi,axis=1)[:,np.newaxis]
        phi = np.exp(log_phi)
        return normalize(phi, axis=1)


    def compute_potentials_bernoulli_(self, alpha, beta_param):
        """Symmetrized KL divergence between two multivariate Bernoulli means.
        :param alpha: of shape [K, d]
        :param beta_param: of shape [K, d]
        :return:
        """
        V = np.zeros((self.K, self.K))
        for k in range(self.K - 1):
            for l in range(k + 1, self.K):
                kl_kl = (
                    betaln(alpha[l, :], beta_param[l, :]) - betaln(alpha[k, :], beta_param[k, :])
                    + (alpha[k, :] - alpha[l, :]) * (digamma(alpha[k, :]) - digamma(alpha[k, :] + beta_param[k, :]))
                    + (beta_param[k, :] - beta_param[l, :]) * (digamma(beta_param[k, :]) - digamma(alpha[k, :] + beta_param[k, :]))
                )
                kl_lk = (
                    betaln(alpha[k, :], beta_param[k, :]) - betaln(alpha[l, :], beta_param[l, :])
                    + (alpha[l, :] - alpha[k, :]) * (digamma(alpha[l, :]) - digamma(alpha[l, :] + beta_param[l, :]))
                    + (beta_param[l, :] - beta_param[k, :]) * (digamma(beta_param[l, :]) - digamma(alpha[l, :] + beta_param[l, :]))
                )
                V[k, l] = np.sum(kl_kl + kl_lk)
        return V + V.T


    def compute_potentials(self, m, L, nu, kappa):
        """
        Function computing the pairwise potential matrix for the Gaussian case.

        Only pairs k < l are evaluated; the result is mirrored, so the returned
        matrix is symmetric with a zero diagonal. Each entry is built from the
        variational parameters of the two components (presumably a symmetrised
        divergence between their factors — confirm against the derivation).
        :param m: of shape [K, d]
        :param L: of shape [K, d, d]
        :param nu: of shape [K, ]
        :param kappa: of shape [K, ]
        :return: V: of shape [K, K]
        """
        V = np.zeros((self.K, self.K))
        # Inverted once, then indexed per component in the trace terms below.
        L_inv = LA.inv(L)
        for k in range(self.K - 1):
            for l in range(k + 1, self.K):
                # Terms, in order of appearance (self.eps guards the divisions and logs):
                #   kappa ratio and log-ratio terms in both directions,
                #   nu-weighted quadratic forms of the mean difference under L_k and L_l,
                #   0.5 * (nu_k - nu_l) * difference of (log det L + multivar_digamma),
                #   0.5 * nu * (trace(L_inv_other L_self) - d) in both directions.
                V[k, l] = self.d *(kappa[k]/(self.eps + kappa[l]) - 1) + np.log(kappa[k] / (self.eps + kappa[l])) + nu[k] * \
                            np.trace(np.matmul(L[k,:,:], np.matmul((m[k, :] - m[l, :]).reshape(self.d,1), (m[k, :] - m[l, :]).reshape(1, self.d) ))) \
                            + self.d *(kappa[l]/(self.eps + kappa[k]) - 1) + np.log(kappa[l] / (self.eps + kappa[k])) + nu[l] * \
                            np.trace(np.matmul(L[l,:,:], np.matmul((m[l, :] - m[k, :]).reshape(self.d,1), (m[l, :] - m[k, :]).reshape(1, self.d) ))) \
                            + 0.5 * (nu[k] - nu[l]) * ( np.log(self.eps + LA.det(L[k,:,:])) - np.log(self.eps + LA.det(L[l,:,:]))
                                                       + multivar_digamma(nu[k], self.d) - multivar_digamma(nu[l], self.d)) \
                            + 0.5 * nu[k] * (np.trace(np.matmul(L_inv[l, :,:], L[k,:,:])) - self.d) \
                            + 0.5 * nu[l] * (np.trace(np.matmul(L_inv[k, :,:], L[l,:,:])) - self.d)
        # Mirror the strict upper triangle.
        return V + V.T

    def Inference(self, max_iter=1000, rel_tol=1e-9, debug=False):
        """
        Run the fixed-point (coordinate ascent) updates until the ELBO stabilises.

        Every iteration recomputes the component parameters and the pairwise
        potentials from the current phi, updates phi, and evaluates the ELBO. The
        loop stops once at least three ELBO values exist and the relative change
        between the last two is below ``rel_tol``, or after ``max_iter`` iterations.
        :param max_iter: Number of max iterations
        :param rel_tol: threshold on |(elbo_t - elbo_{t-1}) / elbo_{t-1}|
        :param debug: if True, write per-iteration diagnostics through the progress bar
        :return: loss: list with the first value returned by compute_elbo /
            compute_elbo_bernoulli (the ELBO divided by N) at each iteration
        """
        loss = []
        stop_criterion = False
        progress = tqdm(total=max_iter, desc="HMRF-DPGMM", unit="iter", leave=False)

        # try/finally guarantees the progress bar is closed even if an update raises.
        try:
            for i in range(max_iter):
                if debug:
                    progress.write(f"[DEBUG] Iteration  {i + 1}")
                update_t0 = time.perf_counter()

                # Fixed point equations: parameters given the current phi
                N = self.compute_N(self.phi)
                if self.likelihood == "gaussian":
                    nu = self.compute_nu(N)
                    kappa = self.compute_kappa(N)
                    m = self.compute_m(self.phi, N)
                    L = self.compute_L(self.phi, m, nu)
                    V = self.compute_potentials(m, L, nu, kappa)
                else:
                    alpha, beta_param = self.compute_bern_param_ab(self.phi)
                    V = self.compute_potentials_bernoulli_(alpha, beta_param)

                # Only the parameters of the selected weight prior are computed;
                # the others are passed on as None.
                if self.weight_prior == "Dirichelet distribution":
                    epsilon = self.compute_epsilon(N)
                    gamma_1 = None
                    gamma_2 = None
                else:
                    epsilon = None
                    gamma_1 = self.compute_gamma_1(N)
                    gamma_2 = self.compute_gamma_2(N)

                # The phi update reads the neighbours' previous responsibilities,
                # so keep a copy of the old phi.
                phi_t = np.copy(self.phi)
                if self.likelihood == "gaussian":
                    self.phi = self.compute_phi(m, L, epsilon, V, nu, phi_t, kappa, gamma_1, gamma_2)
                else:
                    self.phi = self.compute_phi_bernoulli(alpha, beta_param, epsilon, V, phi_t, gamma_1, gamma_2)

                update_time = time.perf_counter() - update_t0
                progress.set_postfix_str(f"upd {update_time:.2f}s", refresh=False)
                elbo_t0 = time.perf_counter()

                # Compute evidence lower bound
                if self.likelihood == "gaussian":
                    l, log_likelihood_term, hmrf_term = self.compute_elbo(self.phi,nu,kappa,epsilon,m,L,V,N,gamma_1,gamma_2)
                else:
                    l, log_likelihood_term, hmrf_term = self.compute_elbo_bernoulli(self.phi,alpha,beta_param,epsilon,V,N,gamma_1,gamma_2)
                elbo_time = time.perf_counter() - elbo_t0
                progress.set_postfix_str(
                    f"elbo {elbo_time:.2f}s",
                    refresh=False,
                )

                loss.append(l)
                # Relative change with respect to the previous ELBO; undefined
                # (reported as inf) on the first iteration.
                loss_change = np.abs((loss[-1] - loss[-2]) / loss[-2]) if i > 0 else np.inf
                delta_str = f"{loss_change:.2e}" if np.isfinite(loss_change) else "inf"
                progress.set_postfix_str(
                    f"upd {update_time:.2f}s | elbo {elbo_time:.2f}s | Δ {delta_str}",
                    refresh=False,
                )
                progress.update(1)

                if debug:
                    progress.write(f"- elbo is  {l}")
                    progress.write(f"- log_likelihood_term is  {log_likelihood_term}")
                    progress.write(f"- hmrf_term is  {hmrf_term}")
                    if i > 0:
                        progress.write(f"- relative elbo change is  {loss_change}")
                    progress.write(
                        "[TIMING] updates {:.3f}s |  elbo {:.3f}s".format(
                            update_time,
                            elbo_time,
                        )
                    )

                # Convergence is only tested once three ELBO values are available.
                if len(loss) > 2:
                    stop_criterion = loss_change < rel_tol
                if stop_criterion:
                    # Note: i is zero-based here, while the debug header prints i + 1.
                    progress.write(f"[INFO] Converged at iteration  {i}")
                    break
                if i == max_iter - 1:
                    progress.write("[INFO] Reached maximum iterations without convergence")
        finally:
            progress.close()
        return loss


    def compute_elbo_bernoulli(self, phi, alpha, beta_param, epsilon, V, N, gamma_1, gamma_2): 
        """
        Function compute the evidence lower bound as defined for HRMF-DPCMM From the variational parameters.
        :param phi: of shape [N,K]
        :param alpha: of shape [K, d]
        :param beta_param: of shape [K, d]
        :param epsilon: of shape [1, K]
        :param V: of shape [K, K]
        :param N: of shape [1, K]
        :param gamma_1: of shape [1, K]
        :param gamma_2: of shape [1, K]
        :return:
        """
        hmrf_term = 0
        log_likelihood_term = 0

        if self.weight_prior == "Dirichelet distribution":
            val = digamma(epsilon) - digamma(np.sum(epsilon))
        else:
            val = digamma(gamma_1) - digamma(gamma_1 + gamma_2) + cumsum_ex(
                digamma(gamma_2) - digamma(gamma_1 + gamma_2))

        entropy_z = 0.0
        for n in range(self.N):
            logp_n = (
                np.sum(
                    (digamma(alpha) - digamma(alpha + beta_param)).reshape(self.K, self.d) * self.X[n, :].reshape(1, self.d),
                    axis=1,
                )
                + np.sum(
                    (digamma(beta_param) - digamma(alpha + beta_param)).reshape(self.K, self.d) * (1 - self.X[n, :].reshape(1, self.d)),
                    axis=1,
                )
            )
            log_likelihood_term += np.dot(phi[n, :], logp_n)
            if self.mask[n] == 0:
                log_likelihood_term += np.dot(phi[n, :], val)
            entropy_z -= np.dot(phi[n, :], np.log(phi[n, :] + self.eps))

        log_likelihood_term += entropy_z

        for tuple in self.tuples_ml:
            hmrf_term += - self.lambda_ * np.sum(np.sum(phi[tuple[0],:].reshape(self.K, 1) * phi[tuple[1],:].reshape(1,self.K) * V))

        alpha_0 = self.params_0["alpha_0"]
        beta_0 = self.params_0["beta_0"]
        beta_kl = (
            betaln(alpha_0, beta_0) - betaln(alpha, beta_param)
            + (alpha - alpha_0) * (digamma(alpha) - digamma(alpha + beta_param))
            + (beta_param - beta_0) * (digamma(beta_param) - digamma(alpha + beta_param))
        )
        log_likelihood_term -= np.sum(beta_kl)

        if self.weight_prior != "Dirichelet distribution":
            for k in range(self.K):
                log_likelihood_term += beta(gamma_1[k], gamma_2[k]).entropy()

        if self.weight_prior == "Dirichelet distribution":
            log_likelihood_term += dirichlet(epsilon).entropy()

        elbo = log_likelihood_term + hmrf_term
        return elbo/self.N, log_likelihood_term, hmrf_term


    def compute_elbo(self, phi, nu, kappa, epsilon, m, L, V, N, gamma_1, gamma_2):
        """
        Evaluate the evidence lower bound for the Gaussian likelihood from the
        variational parameters.
        :param phi: of shape [N, K]
        :param nu: of shape [K, ]
        :param kappa: of shape [K, ]
        :param epsilon: of shape [K, ]; only read with the Dirichlet distribution prior
        :param m: of shape [K, d]
        :param L: of shape [K, d, d]
        :param V: pairwise potentials of shape [K, K]
        :param N: per-component responsibility totals of shape [K, ]
        :param gamma_1: of shape [K, ]; only read with the stick-breaking prior
        :param gamma_2: of shape [K, ]; only read with the stick-breaking prior
        :return: (elbo / N, log_likelihood_term, hmrf_term) where
            elbo = log_likelihood_term + hmrf_term
        """
        hmrf_term = 0
        log_likelihood_term = 0

        # Expected log mixing weights under the chosen prior on the weights.
        if self.weight_prior == "Dirichelet distribution":
            val = digamma(epsilon) - digamma(np.sum(epsilon))
        else:
            val = digamma(gamma_1) - digamma(gamma_1 + gamma_2) + cumsum_ex(
                digamma(gamma_2) - digamma(gamma_1 + gamma_2))

        for n in range(self.N):
            for k in range(self.K):
                # Responsibility-weighted quadratic form of sample n under component k.
                log_likelihood_term += - 0.5 * phi[n, k] * nu[k] * np.trace(np.matmul(L[k,:,:], np.matmul(self.X[n,:].reshape(self.d, 1) - m[k].reshape(self.d, 1),
                                                                                        self.X[n,:].reshape(1, self.d) - m[k].reshape(1, self.d) )) )
                if self.mask[n] == 0:
                    log_likelihood_term += phi[n,k] * val[k]

        for k in range(self.K):
            log_likelihood_term += 0.5 * (nu[k] - self.d + N[k]) * (multivar_digamma(nu[k], self.d) + np.log(self.eps + LA.det(L[k, :, :]))) - 0.5 * self.d * nu[k]
            # sum_n phi_nk * d / kappa_k equals d * N_k / kappa_k, so this term belongs
            # here, once per component. It used to sit inside the loop over samples
            # above, where it was added N times.
            log_likelihood_term -= 0.5 * self.d * N[k] / kappa[k]

        # HMRF penalty over the stored index pairs (named ``pair`` so the builtin
        # ``tuple`` is no longer shadowed).
        for pair in self.tuples_ml:
            hmrf_term += - self.lambda_ * np.sum(np.sum(phi[pair[0],:].reshape(self.K, 1) * phi[pair[1],:].reshape(1,self.K) * V))

        for k in range(self.K):
            # Wishart entropy, remaining Gaussian-Wishart terms, and the entropy of
            # the assignments to component k (self.eps keeps the logs finite).
            log_likelihood_term += wishart(nu[k],L[k,:,:]).entropy() + 0.5*(multivar_digamma(nu[k], self.d) + np.log(self.eps + LA.det(L[k, :, :]))) + 0.5*self.d*np.log(self.eps + kappa[k]) - np.sum(phi[:,k] * np.log(phi[:,k] + self.eps))
            if self.weight_prior != "Dirichelet distribution":
                log_likelihood_term += beta(gamma_1[k], gamma_2[k]).entropy()

        if self.weight_prior == "Dirichelet distribution":
            log_likelihood_term += dirichlet(epsilon).entropy()

        elbo = log_likelihood_term + hmrf_term
        return elbo/self.N, log_likelihood_term, hmrf_term

    def infer_clusters(self):
        """
        Function returning the clustering assignments for each data sample
        :return: y_pred: array of shape [N, ]
        """
        return np.argmax(self.phi, axis=1)

    def plot_elbo(
        self,
        elbos: "list[float] | np.ndarray | None" = None,
        ax: "plt.Axes | None" = None,
        save_path: "str | Path | None" = None,
        show: "bool | None" = None,
        title: "str | None" = None,
        tick_interval: int = 20,
    ) -> "plt.Axes | None":
        """Plot ELBO values per iteration with x-ticks every ``tick_interval`` steps.

        :param elbos: 1-D sequence of ELBO values; ``None`` or an empty sequence
            leaves ``ax`` untouched
        :param ax: axes to draw on; a new figure is created when omitted
        :param save_path: if given, the figure is saved there (dpi 250)
        :param show: whether to call ``plt.show()``; defaults to True only when
            this call created the figure
        :param title: axes title, "Stochastic ELBO" when omitted
        :param tick_interval: spacing of the x-ticks, must be positive
        :return: the axes drawn on, or ``ax`` unchanged (possibly None) when there
            is nothing to plot
        :raises ValueError: if ``elbos`` is not 1-D or ``tick_interval`` is not positive
        """
        # The annotations above are strings so that ``plt`` is only looked up when
        # something is actually plotted, not when the class is defined.
        if elbos is None:
            return ax
        elbo_values = np.asarray(elbos, dtype=float)
        if elbo_values.ndim != 1:
            raise ValueError("elbos must be a 1-D sequence of scalars")
        if elbo_values.size == 0:
            return ax
        # Validate before any figure is created or drawn on; previously a bad
        # tick_interval raised only after the curve had been plotted.
        if tick_interval <= 0:
            raise ValueError("tick_interval must be positive")

        iterations = np.arange(1, elbo_values.size + 1)
        created_fig = ax is None
        if created_fig:
            fig, ax = plt.subplots(figsize=(7, 4))
        else:
            fig = ax.figure

        ax.plot(iterations, elbo_values, linewidth=1.6, color="tab:blue")
        ax.set_xlabel("Iteration")
        ax.set_ylabel("ELBO")
        ax.set_title(title if title is not None else "Stochastic ELBO")
        ax.grid(True, linestyle="--", alpha=0.3)

        # Ticks at 0 and at every multiple of tick_interval up to the first
        # multiple that is not smaller than the last iteration.
        max_tick = ((iterations[-1] - 1) // tick_interval + 1) * tick_interval
        ticks = np.arange(tick_interval, max_tick + 1, tick_interval)
        ticks = np.concatenate(([0], ticks)) if ticks.size else np.array([0])
        ax.set_xticks(ticks)

        if save_path is not None:
            fig.savefig(Path(save_path), dpi=250, bbox_inches="tight")

        should_show = show if show is not None else created_fig
        if should_show:
            plt.show()
        return ax
