"""
Simulation of exponential-kernel multivariate Hawkes processes
using Ogata's thinning algorithm.
"""
from __future__ import annotations
from dataclasses import dataclass
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Optional, Tuple, Union


@dataclass
class Events:
    """
    Result of a single ``ExpHawkes.simulate()`` run.

    Attributes
    ----------
    times : list of float
        Accepted event times, in the order they were generated.
    types : list of int
        Dimension index of each event, aligned element-wise with ``times``.
    dims : int
        Number of dimensions of the process that produced the events.
    """

    # Positional order (times, types, dims) is part of the constructor API.
    times: List[float]
    types: List[int]
    dims: int


class ExpHawkes:
    """
    Multivariate Hawkes process with exponential kernel.

    Parameters
    ----------
    mu : array-like, shape (d,)
        Baseline intensities for each dimension
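    alpha : array-like, shape (d, d)
        Excitation matrix: ``alpha[i, j]`` scales the kernel
        ``beta[i, j] * exp(-beta[i, j] * t)`` through which an event of type
        ``j`` raises the intensity of dimension ``i``. For positive
        ``beta[i, j]`` the kernel integrates to one, so ``alpha[i, j]`` is the
        total intensity mass added; the immediate jump is
        ``beta[i, j] * alpha[i, j]``.
    beta : float or array-like, shape (d,) or (d, d)
        Decay rate(s) of the exponential kernel. A scalar is shared by all
        pairs; a length-d vector gives one rate per target dimension ``i``
        (one row of the expanded matrix).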
    seed : int, optional
        Random seed for reproducibility
    """

    def __init__(
        self,
        mu: Union[List[float], np.ndarray],
        alpha: Union[List[List[float]], np.ndarray],
        beta: Union[float, List[float], np.ndarray],
        seed: Optional[int] = None,
    ) -> None:
        """Initialize the Hawkes process with given parameters."""
        # Convert inputs to numpy arrays
        self.mu = np.asarray(mu, dtype=float)
        self.alpha = np.asarray(alpha, dtype=float)
        
        # Handle beta parameter: ensure it's a (d, d) matrix
        beta_array = np.asarray(beta, dtype=float)
        if beta_array.ndim == 0 or beta_array.size == 1:  # scalar
            self.beta = np.tile(beta_array, (len(self.mu), len(self.mu)))
        elif beta_array.ndim == 1:  # vector
            self.beta = np.tile(beta_array.reshape(-1, 1), (1, len(self.mu)))
        else:  # matrix
            self.beta = beta_array
        
        # Process dimensions
        self.dims = len(self.mu)
        
        # Initialize random number generator
        self._rng = np.random.default_rng(seed)
        
        # Initialize state variables
        self._current_time = 0.0
        self._residual_matrix = np.zeros((self.dims, self.dims))
        self.times = []
        self.types = []

    def simulate(
        self,
        *,
        max_time: Optional[float] = None,
        max_events: Optional[int] = None,
        plot: bool = False,
    ) -> Events:
        """
        Simulate the Hawkes process.
        
        Parameters
        ----------
        max_time : float, optional
            Maximum simulation time
        max_events : int, optional
            Maximum number of events to simulate
        plot : bool, default=False
            Whether to plot the intensity and events
            
        Returns
        -------
        Events
            Container with simulation results
            
        Notes
        -----
        Exactly one of max_time or max_events must be provided.
        """
        # Reset simulation state
        self._current_time = 0.0
        self._residual_matrix = np.zeros((self.dims, self.dims))
        self.times = []
        self.types = []
        
        # Validate parameters
        if (max_time is None) == (max_events is None):
            raise ValueError("Provide exactly one of max_time or max_events.")
            
        # Run simulation
        if max_time is not None:
            self._simulate_until_time(max_time)
        else:
            self._simulate_until_events(max_events)

        # Plot if requested
        if plot:
            self._plot_intensity_and_events()

        return Events(self.times, self.types, self.dims)

    def _simulate_until_events(self, max_events: int) -> None:
        """
        Simulate the process until reaching the specified number of events.
        
        Parameters
        ----------
        max_events : int
            Maximum number of events to simulate
        """
        while len(self.times) < max_events:
            self._simulate_next_step()

    def _simulate_until_time(self, horizon: float) -> None:
        """
        Simulate the process until reaching the specified time horizon.
        
        Parameters
        ----------
        horizon : float
            Time horizon for simulation
        """
        while True:
            self._simulate_next_step()
            if self._current_time > horizon:
                # Last proposal exceeded horizon - discard it if needed
                if self.times and self.times[-1] > horizon:
                    self.times.pop()
                    self.types.pop()
                break

    def _simulate_next_step(self) -> float:
        """
        Perform one step of Ogata's thinning algorithm.
        
        Returns
        -------
        float
            Time increment proposed (whether accepted or not)
        """
        # Calculate current intensity
        current_intensity = self.mu + self._residual_matrix.sum(axis=1)
        total_intensity = current_intensity.sum()
        
        if total_intensity <= 0:
            raise RuntimeError("Total intensity vanished → process died.")

        # Propose next jump time
        delta_t = self._rng.exponential(1 / total_intensity)
        proposed_time = self._current_time + delta_t

        # Decay memory based on time passed
        decay_factors = np.exp(-self.beta * delta_t)
        decayed_residuals = self._residual_matrix * decay_factors
        
        # Calculate new intensity after decay
        new_intensity = self.mu + decayed_residuals.sum(axis=1)
        new_total_intensity = new_intensity.sum()

        # Acceptance test for thinning algorithm
        if self._rng.random() < new_total_intensity / total_intensity:  # accept
            self._current_time = proposed_time
            self._residual_matrix = decayed_residuals
            
            # Choose which dimension triggered
            dim = self._rng.choice(self.dims, p=new_intensity / new_total_intensity)
            
            # Update residuals with new event
            self._residual_matrix[:, dim] += self.beta[:, dim] * self.alpha[:, dim]
            
            # Record the event
            self.times.append(self._current_time)
            self.types.append(dim)
        else:  # reject - only advance time and decay
            self._current_time = proposed_time
            self._residual_matrix = decayed_residuals
            
        return delta_t

    def _calculate_intensity(self, time_grid: np.ndarray) -> np.ndarray:
        """
        Calculate intensity function on a time grid.
        
        Parameters
        ----------
        time_grid : array-like
            Times at which to evaluate the intensity
            
        Returns
        -------
        np.ndarray, shape (len(time_grid), dims)
            Intensity values for each time and dimension
        """
        time_steps = len(time_grid)
        intensity = np.zeros((time_steps, self.dims))
        
        # For each time point in the grid
        for k, eval_time in enumerate(time_grid):
            # Start with baseline intensity
            intensity[k] = self.mu.copy()
            
            # Add contribution from each past event
            for event_time, event_type in zip(self.times, self.types):
                if event_time >= eval_time:
                    break
                    
                # Calculate contribution using exponential kernel
                time_diff = eval_time - event_time
                intensity[k] += (
                    self.beta[:, event_type] * 
                    self.alpha[:, event_type] * 
                    np.exp(-self.beta[:, event_type] * time_diff)
                )
                
        return intensity

