"""
PyTorch log-likelihood modules for Hawkes process models.

This module implements various maximum likelihood estimation models for Hawkes processes:
* HawkesMLE - with constant α parameters
* HawkesFeatureMLE - past version with α parameterized by features via logistic link and 1 unique feature
* HawkesMultiFeatureMLE - extension supporting multiple features
* HawkesSlowLoopModelOgata - Ogata's recursive implementation for exponential Hawkes processes
"""
from __future__ import annotations
import torch
import torch.nn as nn
import numpy as np
from typing import Optional, Union
from numpy.typing import NDArray
from openhawkes.utils import fill_tril


class HawkesMLE(nn.Module):
    """Constant-α exponential Hawkes log-likelihood model.

    The intensity of type ``i`` evaluated by ``forward`` is

        λ_i(t) = μ_i + Σ_{t_j < t} α[i, u_j] · β_i · exp(-β_i (t - t_j)),

    where ``u_j`` is the type of the past event ``j`` and ``β_i`` is either the
    shared decay (``beta_unique=True``) or the decay of the excited type ``i``.
    Because the kernel is multiplied by β, it integrates to 1 over [0, ∞), so
    ``α[i, j]`` is the expected number of type-``i`` events directly triggered
    by one type-``j`` event.

    Attributes:
        nb_types: Number of event types in the process
        mu: Base intensity parameters (one per event type)
        alpha: Excitation parameters (nb_types x nb_types matrix); row = excited
            type, column = exciting type
        beta: Decay parameter(s), either shared (scalar) or one per excited type
        beta_unique: Whether to use a single shared beta or one per event type
    """

    def __init__(self, nb_types: int, beta_unique: bool = True):
        """Initialize the Hawkes MLE model.

        Args:
            nb_types: Number of event types in the Hawkes process
            beta_unique: If True, use a single shared beta parameter;
                        otherwise, use one beta per event type
        """
        super().__init__()
        self.nb_types = nb_types
        self.mu = nn.Parameter(torch.full((nb_types,), 0.1))
        self.alpha = nn.Parameter(torch.full((nb_types, nb_types), 0.1))
        self.beta_unique = beta_unique

        if self.beta_unique:
            self.beta = nn.Parameter(torch.tensor(1.0))
        else:
            self.beta = nn.Parameter(torch.ones(nb_types))

    def forward(
        self,
        event_times: torch.Tensor,  # (S,N)
        event_types: torch.Tensor,  # (S,N)
        mask: torch.Tensor,  # (S,N)
        t0: torch.Tensor,  # (S,1)
        t1: torch.Tensor,  # (S,1)
    ) -> torch.Tensor:
        """Compute the log-likelihood of each sequence in a padded batch.

        Args:
            event_times: Tensor of event timestamps, shape (batch_size, max_seq_len)
            event_types: Tensor of event type indices, shape (batch_size, max_seq_len);
                padded positions must still hold a valid type index
            mask: Binary mask for valid events, shape (batch_size, max_seq_len)
            t0: Start time for each sequence, shape (batch_size, 1)
            t1: End time for each sequence, shape (batch_size, 1)

        Returns:
            Per-sequence log-likelihood, shape (batch_size,)
        """
        device = event_times.device
        S, N = event_times.shape
        D = self.nb_types
        mu, alpha, beta = self.mu, self.alpha, self.beta
        t0 = t0.reshape(S, 1)
        t1 = t1.reshape(S, 1)

        # ------- (1) Log intensities computation -----------------------------------
        # dt[s, i, j] = t_i - t_j, kept only for valid pairs j < i
        dt = event_times[:, :, None] - event_times[:, None]
        dt = fill_tril(dt, -1) * mask.unsqueeze(2)

        row = event_types.unsqueeze(2).expand(S, N, N)  # excited (target) event types
        col = event_types.unsqueeze(1).expand(S, N, N)  # exciting (source) event types
        a_vals = alpha[row, col]  # α is shared across the batch: no need to copy it S times

        # Decay of the excited type (or the shared decay)
        beta_row = beta if self.beta_unique else beta[row]
        kernel = beta_row * torch.exp(-beta_row * dt) * mask.unsqueeze(2)

        excitation = fill_tril(a_vals * kernel, -1).sum(2)
        lam = mu[event_types] * mask + excitation
        log_int = (torch.log(lam.clamp_min(1e-15)) * mask).sum(1)

        # ------- (2) Compensator computation ---------------------------------------
        # Background part: ∫_{t0}^{t1} Σ_i μ_i dt
        Tm = (t1 - t0).squeeze(-1)
        comp_base = Tm * mu.sum()

        # Excitation part: for excited type i and past event j,
        #   ∫_{t_j}^{t1} α[i, u_j] β_i e^{-β_i (t - t_j)} dt = α[i, u_j] (1 - e^{-β_i (t1 - t_j)}).
        # β must be the decay of the excited type i so this matches the kernel above.
        diff_t = (t1 - event_times).clamp_min(0.0)  # (S, N)
        beta_target = beta if self.beta_unique else beta.view(1, D, 1)
        integ = (1 - torch.exp(-beta_target * diff_t.unsqueeze(1))) * mask.unsqueeze(1)  # (S, 1|D, N)

        i_idx = torch.arange(D, device=device).view(1, D, 1)
        j_idx = event_types.unsqueeze(1)
        a_ij = alpha[i_idx, j_idx]  # (S, D, N)

        comp = (a_ij * integ).sum(2).sum(1)

        return log_int - (comp + comp_base)


class HawkesFeatureMLE(nn.Module):
    """Hawkes process model where α depends on features via a logistic link.

    For sequence ``s`` with feature vector ``x = features[s]`` (one value per
    event type), the excitation matrix is

        α_s[i, j] = sigmoid(gamma[i, j] + theta1[i, j] · x[i] + theta2[i, j] · x[j]).

    In ``forward`` the row ``i`` of α is the type of the event whose intensity is
    evaluated and the column ``j`` is the type of a past event. The kernel is
    ``exp(-β (t - t_j))`` and is not multiplied by β:

        λ_i(t) = μ_i + Σ_{t_j < t} α_s[i, u_j] · exp(-β_i (t - t_j)).

    Attributes:
        nb_types: Number of event types in the process
        beta_unique: Whether to use a single shared beta or one per event type
        mu: Base intensity parameters, shape (nb_types,)
        gamma: Constant term in the logistic link, shape (nb_types, nb_types)
        theta1: Coefficients applied to the feature of the row (excited) type i
        theta2: Coefficients applied to the feature of the column (exciting) type j
        beta: Decay parameter(s); when not shared, ``beta[i]`` is the decay of
            the intensity of type ``i``
    """

    def __init__(self, nb_types: int, beta_unique: bool = True):
        """Initialize the feature-based Hawkes MLE model.

        Args:
            nb_types: Number of event types in the Hawkes process
            beta_unique: If True, use a single shared beta parameter;
                        otherwise, use one beta per event type
        """
        super().__init__()
        self.nb_types = nb_types
        self.beta_unique = beta_unique

        # Initialize model parameters
        self.mu = nn.Parameter(torch.rand(nb_types) * 0.1 + 0.01)
        self.gamma = nn.Parameter(torch.zeros(nb_types, nb_types))
        self.theta1 = nn.Parameter(torch.zeros(nb_types, nb_types))
        self.theta2 = nn.Parameter(torch.zeros(nb_types, nb_types))

        # Initialize decay parameter(s)
        self.beta = (
            nn.Parameter(torch.tensor(1.0))
            if beta_unique
            else nn.Parameter(torch.ones(nb_types))
        )

    def forward(
        self,
        event_times: torch.Tensor,
        event_types: torch.Tensor,
        mask: torch.Tensor,
        t0: torch.Tensor,
        t1: torch.Tensor,
        features: torch.Tensor,  # (S, D)
    ) -> torch.Tensor:
        """Compute the log-likelihood of each sequence in a padded batch, with features.

        Args:
            event_times: Tensor of event timestamps, shape (batch_size, max_seq_len)
            event_types: Tensor of event type indices, shape (batch_size, max_seq_len);
                padded positions must still hold a valid type index
            mask: Binary mask for valid events, shape (batch_size, max_seq_len)
            t0: Start time for each sequence, shape (batch_size, 1)
            t1: End time for each sequence, shape (batch_size, 1)
            features: Node features, shape (batch_size, nb_types)

        Returns:
            Per-sequence log-likelihood, shape (batch_size,)
        """
        device = event_times.device
        S, N = event_times.shape
        D = self.nb_types
        t0 = t0.reshape(S, 1)
        t1 = t1.reshape(S, 1)

        # ---------- Compute α(features), vectorized over the batch --------------
        # x[:, :, None] scales row i, x[:, None, :] scales column j.
        # Cast to the parameter dtype so α matches the model's precision.
        x = features.to(self.gamma.dtype)
        z = self.gamma + self.theta1 * x[:, :, None] + self.theta2 * x[:, None, :]
        alpha = torch.sigmoid(z)  # (S, D, D)

        mu = self.mu
        beta = self.beta

        # ---------- (1) Calculate log intensities --------------------------------
        # dt[s, i, j] = t_i - t_j, kept only for valid pairs j < i
        dt = event_times[:, :, None] - event_times[:, None]
        dt = fill_tril(dt, -1) * mask.unsqueeze(2)

        # Decay of the excited type i (broadcast over the source axis j)
        if self.beta_unique:
            kernel = torch.exp(-beta * dt) * mask.unsqueeze(2)
        else:
            kernel = torch.exp(-beta[event_types].unsqueeze(2) * dt) * mask.unsqueeze(2)

        # α_s[type_i, type_j] for every pair of events
        row = event_types.unsqueeze(2).expand(S, N, N)
        col = event_types.unsqueeze(1).expand(S, N, N)
        batch = torch.arange(S, device=device).view(S, 1, 1).expand(-1, N, N)
        a_vals = alpha[batch, row, col]

        excitation = fill_tril(a_vals * kernel, -1).sum(2)
        lam = mu[event_types] * mask + excitation
        log_int = (torch.log(lam.clamp_min(1e-15)) * mask).sum(1)

        # ---------- (2) Calculate compensator ------------------------------------
        Tm = (t1 - t0).squeeze(-1)
        comp_base = Tm * mu.sum()

        # ∫_{t_j}^{t1} α e^{-β_i (t - t_j)} dt = α / β_i · (1 - e^{-β_i (t1 - t_j)}).
        # Both the divisor and the exponent use the excited type's decay β_i,
        # matching the kernel above.
        diff_t = (t1 - event_times).clamp_min(0.0)  # (S, N)
        beta_target = beta if self.beta_unique else beta.view(1, D, 1)
        integ = (1 - torch.exp(-beta_target * diff_t.unsqueeze(1))) * mask.unsqueeze(1)  # (S, 1|D, N)

        s_idx = torch.arange(S, device=device).view(S, 1, 1)
        i_idx = torch.arange(D, device=device).view(1, D, 1)
        j_idx = event_types.unsqueeze(1)
        a_ij = alpha[s_idx, i_idx, j_idx]  # (S, D, N)

        comp = (a_ij / beta_target * integ).sum(2).sum(1)

        return log_int - (comp + comp_base)


class HawkesMultiFeatureMLE(nn.Module):
    """Hawkes process model with multiple features per event type.

    Extends the single-feature logistic link to K features per type:

        α_s[i, j] = sigmoid(gamma[i, j]
                            + Σ_k theta1[k, i, j] · features[s, k, i]
                            + Σ_k theta2[k, i, j] · features[s, k, j]).

    In ``forward`` the row ``i`` of α is the type of the event whose intensity is
    evaluated and the column ``j`` is the type of a past event. The kernel is
    ``exp(-β (t - t_j))`` and is not multiplied by β.

    Attributes:
        nb_types: Number of event types in the process
        nb_of_features: Number of features per event type (K)
        beta_unique: Whether to use a single shared beta or one per event type
        mu: Base intensity parameters, shape (nb_types,)
        gamma: Constant term in the logistic link, shape (nb_types, nb_types)
        theta1: Coefficients for the row (excited) type's features,
            shape (nb_of_features, nb_types, nb_types)
        theta2: Coefficients for the column (exciting) type's features,
            shape (nb_of_features, nb_types, nb_types)
        beta: Decay parameter(s); when not shared, ``beta[i]`` is the decay of
            the intensity of type ``i``
    """

    def __init__(self, nb_types: int, nb_of_features: int, beta_unique: bool = True):
        """Initialize the multi-feature Hawkes MLE model.

        Args:
            nb_types: Number of event types in the Hawkes process
            nb_of_features: Number of features per dimension
            beta_unique: If True, use a single shared beta parameter;
                        otherwise, use one beta per event type
        """
        super().__init__()
        self.nb_types = nb_types
        self.nb_of_features = nb_of_features
        self.beta_unique = beta_unique

        # Initialize model parameters
        self.mu = nn.Parameter(torch.rand(nb_types) * 0.1 + 0.01)
        self.gamma = nn.Parameter(torch.zeros(nb_types, nb_types))
        # Parameters for multiple features
        self.theta1 = nn.Parameter(torch.zeros(nb_of_features, nb_types, nb_types))
        self.theta2 = nn.Parameter(torch.zeros(nb_of_features, nb_types, nb_types))

        # Initialize decay parameter(s)
        self.beta = (
            nn.Parameter(torch.tensor(1.0))
            if beta_unique
            else nn.Parameter(torch.ones(nb_types))
        )

    def forward(
        self,
        event_times: torch.Tensor,
        event_types: torch.Tensor,
        mask: torch.Tensor,
        t0: torch.Tensor,
        t1: torch.Tensor,
        features: torch.Tensor,  # (S, D) or (S, K, D)
    ) -> torch.Tensor:
        """Compute the log-likelihood of each sequence in a padded batch, with multiple features.

        Args:
            event_times: Tensor of event timestamps, shape (batch_size, max_seq_len)
            event_types: Tensor of event type indices, shape (batch_size, max_seq_len);
                padded positions must still hold a valid type index
            mask: Binary mask for valid events, shape (batch_size, max_seq_len)
            t0: Start time for each sequence, shape (batch_size, 1)
            t1: End time for each sequence, shape (batch_size, 1)
            features: Node features, shape (batch_size, nb_types) or
                (batch_size, nb_of_features, nb_types). A 2-D input is
                replicated across the K feature slots.

        Returns:
            Per-sequence log-likelihood, shape (batch_size,)

        Raises:
            ValueError: If ``features`` is neither 2-D nor 3-D.
        """
        device = event_times.device
        S, N = event_times.shape
        D = self.nb_types
        K = self.nb_of_features
        t0 = t0.reshape(S, 1)
        t1 = t1.reshape(S, 1)

        # Normalize features to (S, K, D)
        if features.dim() == 2:
            features = features.unsqueeze(1).expand(-1, K, -1)
        elif features.dim() != 3:
            raise ValueError(
                f"features must have shape (S, D) or (S, K, D), got {tuple(features.shape)}"
            )
        # einsum requires matching dtypes, so cast to the parameter dtype
        features = features.to(self.theta1.dtype)

        # ---------- Compute α(features) with multiple feature dimensions --------
        # Row-type contribution: Σ_k features[s, k, i] · theta1[k, i, j]
        contrib1 = torch.einsum('ski,kij->sij', features, self.theta1)
        # Column-type contribution: Σ_k features[s, k, j] · theta2[k, i, j]
        contrib2 = torch.einsum('skj,kij->sij', features, self.theta2)

        z = self.gamma.unsqueeze(0) + contrib1 + contrib2  # (S, D, D)
        alpha = torch.sigmoid(z)  # (S, D, D)

        mu = self.mu
        beta = self.beta

        # ---------- (1) Calculate log intensities --------------------------------
        # dt[s, i, j] = t_i - t_j, kept only for valid pairs j < i
        dt = event_times[:, :, None] - event_times[:, None]
        dt = fill_tril(dt, -1) * mask.unsqueeze(2)

        # Decay of the excited type i (broadcast over the source axis j)
        if self.beta_unique:
            kernel = torch.exp(-beta * dt) * mask.unsqueeze(2)
        else:
            kernel = torch.exp(-beta[event_types].unsqueeze(2) * dt) * mask.unsqueeze(2)

        # α_s[type_i, type_j] for every pair of events
        row = event_types.unsqueeze(2).expand(S, N, N)
        col = event_types.unsqueeze(1).expand(S, N, N)
        batch = torch.arange(S, device=device).view(S, 1, 1).expand(-1, N, N)
        a_vals = alpha[batch, row, col]

        excitation = fill_tril(a_vals * kernel, -1).sum(2)
        lam = mu[event_types] * mask + excitation
        log_int = (torch.log(lam.clamp_min(1e-15)) * mask).sum(1)

        # ---------- (2) Calculate compensator ------------------------------------
        Tm = (t1 - t0).squeeze(-1)
        comp_base = Tm * mu.sum()

        # ∫_{t_j}^{t1} α e^{-β_i (t - t_j)} dt = α / β_i · (1 - e^{-β_i (t1 - t_j)}).
        # Both the divisor and the exponent use the excited type's decay β_i,
        # matching the kernel above.
        diff_t = (t1 - event_times).clamp_min(0.0)  # (S, N)
        beta_target = beta if self.beta_unique else beta.view(1, D, 1)
        integ = (1 - torch.exp(-beta_target * diff_t.unsqueeze(1))) * mask.unsqueeze(1)  # (S, 1|D, N)

        s_idx = torch.arange(S, device=device).view(S, 1, 1)
        i_idx = torch.arange(D, device=device).view(1, D, 1)
        j_idx = event_types.unsqueeze(1)
        a_ij = alpha[s_idx, i_idx, j_idx]  # (S, D, N)

        comp = (a_ij / beta_target * integ).sum(2).sum(1)

        return log_int - (comp + comp_base)


class HawkesSlowLoopModelOgata(nn.Module):
    """Recursive exponential Hawkes log-likelihood based on Ogata's method.

    Events of each sequence are processed in time order, and a (D x D) state
    matrix ``R`` carries the decayed excitation. ``R[i, k]`` is the current
    excitation of type ``i`` caused by past events of type ``k``. This avoids
    building an N x N pairwise matrix.

    The kernel is ``α[i, k] · exp(-β (t - t_k))`` and is not multiplied by β,
    with a single shared β:

        λ_i(t) = μ_i + Σ_{t_k < t} α[i, u_k] · exp(-β (t - t_k)).

    Attributes:
        nb_types: Number of event types in the process
        mu: Base intensity parameters, shape (nb_types,)
        alpha: Excitation parameters matrix, shape (nb_types, nb_types);
            row = excited type, column = exciting type
        beta: Shared decay parameter (clamped to >= 1e-6 in ``forward``)
    """

    def __init__(self, nb_types: int):
        """Initialize the Ogata-style recursive Hawkes model.

        Args:
            nb_types: Number of event types in the Hawkes process
        """
        super().__init__()
        self.nb_types = nb_types
        self.mu = nn.Parameter(torch.full((nb_types,), 0.1))
        self.alpha = nn.Parameter(torch.full((nb_types, nb_types), 0.1))
        self.beta = nn.Parameter(torch.tensor(1.0))

    def forward(self, event_times, event_types, input_mask, t0, t1):
        """Compute the log-likelihood using Ogata's recursive method.

        Args:
            event_times: Tensor of event timestamps, shape (batch_size, max_seq_len);
                events need not be sorted (they are sorted per sequence)
            event_types: Tensor of event type indices, shape (batch_size, max_seq_len)
            input_mask: Mask for valid events (entries > 0.5 are kept),
                shape (batch_size, max_seq_len)
            t0: Start time for each sequence, shape (batch_size, 1)
            t1: End time for each sequence, shape (batch_size, 1)

        Returns:
            Scalar tensor: the log-likelihood summed over all sequences
        """
        S = event_times.shape[0]
        D = self.nb_types
        device = event_times.device
        mu = self.mu
        alpha = self.alpha
        beta = self.beta.clamp(min=1e-6)

        starts = t0.reshape(S)
        ends = t1.reshape(S)
        loglik = torch.zeros(S, device=device)

        for s in range(S):
            # Keep only valid events of this sequence, in time order
            valid = input_mask[s] > 0.5
            t_seq = event_times[s][valid]
            d_seq = event_types[s][valid].long()
            order = torch.argsort(t_seq)
            t_seq = t_seq[order]
            d_seq = d_seq[order]

            R = torch.zeros(D, D, device=device)  # excitation state just after t_prev
            ll = 0.0
            comp = 0.0
            t_prev = starts[s]

            for t_k, i_k in zip(t_seq, d_seq.tolist()):
                Dt = t_k - t_prev
                decay = torch.exp(-beta * Dt)

                # Compensator over [t_prev, t_k]: background + decaying excitation
                comp = comp + mu.sum() * Dt
                comp = comp + (R.sum() / (beta + 1e-15)) * (1 - decay)

                # State just before t_k. Multiplication creates a new tensor, so the
                # in-place column update below never touches the tensor autograd saved.
                R = R * decay

                lam_vec = mu + R.sum(dim=1)  # intensity of every type at t_k^-
                ll = ll + torch.log(lam_vec[i_k] + 1e-15)

                # An event of type i_k excites every type i by alpha[i, i_k]
                R[:, i_k] += alpha[:, i_k]

                t_prev = t_k

            # Final interval [t_last (or t0), t1]
            Dt_end = ends[s] - t_prev
            decay_end = torch.exp(-beta * Dt_end)
            comp = comp + mu.sum() * Dt_end
            comp = comp + (R.sum() / (beta + 1e-15)) * (1 - decay_end)

            loglik[s] = ll - comp

        return loglik.sum()


class ModelEM:
    """EM algorithm for exponential Hawkes processes. Inspired by Ogata's method and code from https://github.com/stmorse/hawkes

    Fits a multivariate exponential Hawkes process with kernel
    α_{ij} ω_j e^{-ω_j t}, where ``i`` is the excited type and ``j`` is the type
    of the exciting (parent) event:

        λ_i(t) = μ_i + Σ_{t_k < t} α[i, u_k] · ω_{u_k} · exp(-ω_{u_k} (t - t_k)).

    ``ω`` is stored in ``beta``. The observation window is [0, t_N], where t_N
    is the last event time.

    Attributes:
        D: Number of event types
        beta_unique: Whether to use a single shared beta or one per event type
        maxiter: Maximum number of EM iterations
        epsilon: Convergence threshold on the per-event log-likelihood change
        regularize: Stored for API compatibility; ``fit`` does not currently use it
        beta: Decay parameter(s), either shared (float) or per-type array (D,);
            overwritten with the fitted value(s) by ``fit``
        mu_: Fitted base intensities (D,)
        alpha_: Fitted excitation matrix (D, D)
    """

    def __init__(
        self,
        nb_types: int,
        beta_init: Union[float, NDArray[np.floating]] = 1.0,
        maxiter: int = 10000,
        epsilon: float = 1e-5,
        regularize: bool = False,
        beta_unique: bool = True,
    ) -> None:
        """Initialize the EM algorithm for a Hawkes process.

        Args:
            nb_types: Number of event types
            beta_init: Initial value(s) for the decay parameter beta. With
                ``beta_unique=False`` this may be a scalar (used for every type)
                or an array-like of length ``nb_types``.
            maxiter: Maximum number of EM iterations
            epsilon: Convergence threshold for log-likelihood change
            regularize: Whether to apply regularization during fitting
                (currently not used by ``fit``)
            beta_unique: If True, use a single shared beta parameter;
                        otherwise, use one beta per event type

        Raises:
            ValueError: If a per-type ``beta_init`` cannot be broadcast to (nb_types,).
        """
        self.D = nb_types
        self.beta_unique = beta_unique
        self.maxiter = maxiter
        self.epsilon = epsilon
        self.regularize = regularize

        self.beta: Union[float, NDArray[np.floating]]
        if beta_unique:
            self.beta = float(beta_init)
        else:
            # Accept a scalar or a length-D array. Copy so later in-place updates
            # never alias the caller's array.
            self.beta = np.broadcast_to(
                np.asarray(beta_init, dtype=float), (self.D,)
            ).copy()

        # Fitted parameters (set by fit)
        self.mu_: Optional[NDArray[np.floating]] = None  # Base intensities (D,)
        self.alpha_: Optional[NDArray[np.floating]] = None  # Excitation matrix (D,D)

    def fit(
        self,
        times: NDArray[np.floating],
        types: NDArray[np.integer]
    ) -> "ModelEM":
        """Fit the model using the EM algorithm.

        Parameters:
            times: Array of event timestamps, shape (N,), assumed sorted and
                non-empty; the last timestamp is used as the observation horizon
            types: Array of event type indices, shape (N,), values in {0,...,D-1}

        Returns:
            self: The fitted model instance (for method chaining). ``mu_``,
            ``alpha_`` and ``beta`` are updated.
        """
        times = np.asarray(times, dtype=float)
        u = np.asarray(types).astype(int)  # event types
        N = times.shape[0]
        D = self.D
        Tm = times[-1]  # observation horizon

        # Initialization: Poisson rates per type, small uniform excitation
        counts = np.bincount(u, minlength=D).astype(float)
        mu = counts / Tm
        alpha = np.full((D, D), 0.1, float)
        beta = self.beta if self.beta_unique else np.array(self.beta, dtype=float)

        # diffs[i, j] = t_i - t_j for j < i, else 0
        diffs = np.subtract.outer(times, times)
        tri_i, tri_j = np.triu_indices(N)  # j >= i (including the diagonal)
        diffs[tri_i, tri_j] = 0.0

        pair_index = (u[:, None] * D + u[None, :]).ravel()  # flat (child type, parent type)
        horizon = Tm - times  # time remaining after each event

        old_LL = -np.inf
        for _ in range(self.maxiter):
            # ========= E-step: responsibilities =========
            # Decay of each parent event's own type (scalar when shared)
            beta_src = beta if self.beta_unique else beta[u]
            K = beta_src * np.exp(-beta_src * diffs)

            Aij = alpha[u[:, None], u[None, :]]  # α_{u_i, u_j}
            Aij[tri_i, tri_j] = 0.0  # only j < i can be a parent of i
            ag = Aij * K

            rates = mu[u] + ag.sum(axis=1)  # λ(t_i)
            p_ij = ag / rates[:, None]  # P(parent of i is j); zero for j >= i
            p_ii = mu[u] / rates  # P(i is a background event)

            # ========= M-step =========
            mu = np.bincount(u, weights=p_ii, minlength=D) / Tm

            # numer[a, b] = Σ p_ij over child type a and parent type b
            numer = np.bincount(
                pair_index, weights=p_ij.ravel(), minlength=D * D
            ).reshape(D, D)

            # G[b] = Σ_{j: u_j = b} (1 - e^{-ω (T - t_j)}): exposure of parent type b
            G = np.bincount(u, weights=1.0 - np.exp(-beta_src * horizon), minlength=D)
            # Types with no exposure (never observed, or only at the horizon) cannot
            # be identified. Give them zero excitation instead of 0/0 = NaN.
            alpha = np.divide(
                numer, G[None, :], out=np.zeros_like(numer), where=G[None, :] > 0
            )

            # Update ω: ratio of expected triggered events to expected waiting time
            weighted_dt = p_ij * diffs
            if self.beta_unique:
                sum_dt = weighted_dt.sum()
                if sum_dt > 0:
                    beta = p_ij.sum() / sum_dt
            else:
                sum_p_j = np.bincount(u, weights=p_ij.sum(axis=0), minlength=D)
                sum_dt_j = np.bincount(u, weights=weighted_dt.sum(axis=0), minlength=D)
                has_dt = sum_dt_j > 0
                beta[has_dt] = sum_p_j[has_dt] / sum_dt_j[has_dt]

            # ========= Convergence test =========
            # Approximate per-event log-likelihood. Each event contributes its full
            # kernel mass Σ_i α[i, u_j] to the compensator (boundary effects ignored).
            term1 = np.log(rates).sum()
            term2 = Tm * mu.sum()
            term3 = alpha[:, u].sum()
            new_LL = (term1 - term2 - term3) / N

            if abs(new_LL - old_LL) < self.epsilon:
                break
            old_LL = new_LL

        self.mu_ = mu
        self.alpha_ = alpha
        self.beta = beta
        return self