"""
Plot helpers – kept separate to avoid heavy matplotlib import unless needed.
"""
from __future__ import annotations

import numpy as np
import matplotlib.pyplot as plt
from .simulation import ExpHawkes


def plot_intensity(process: ExpHawkes, *, steps: int = 1000) -> None:
    """
    Draw the intensity curve of every dimension on the current axes.

    The intensity is evaluated on ``steps`` evenly spaced points between 0
    and the last entry of ``process.times``; each dimension is drawn as its
    own labelled line and the figure is shown immediately.

    Args:
        process: An ExpHawkes process that has been simulated
        steps: Number of grid points for plotting

    Raises:
        ValueError: If ``process.times`` is empty (process not simulated)
    """
    event_times = process.times
    if not event_times:
        raise ValueError("Process not yet simulated.")

    # The grid spans from time zero up to the final recorded time.
    horizon = event_times[-1]
    grid = np.linspace(0, horizon, steps)
    intensities = process._calculate_intensity(grid)

    # One column of the intensity array per dimension.
    for k in range(process.d):
        plt.plot(grid, intensities[:, k], label=f"λ_{k}")

    plt.ylabel("Intensity")
    plt.xlabel("Time")
    plt.legend()
    plt.show()


def plot_jumps(process: ExpHawkes) -> None:
    """
    Plot jump events for each dimension on separate subplots.

    One row is drawn per dimension, all sharing the time axis. Every entry
    of ``process.times`` whose matching entry in ``process.types`` equals the
    row's index is marked with a vertical black line. The figure is shown
    immediately.

    Args:
        process: An ExpHawkes process that has been simulated
    """
    # Read the dimension count exactly once so the subplot grid and the loop
    # below can never disagree.
    # NOTE(review): other code in this module reads the same count as
    # ``process.dims``; confirm ``d`` is the canonical attribute on ExpHawkes.
    n_dims = process.d

    # squeeze=False always yields a 2-D array of axes, so the single-dimension
    # case needs no special handling; keep the only column as a 1-D sequence.
    _, axes_grid = plt.subplots(
        n_dims, 1, sharex=True, figsize=(8, 2 * n_dims), squeeze=False
    )
    axes = axes_grid[:, 0]

    for dim in range(n_dims):
        ax = axes[dim]

        # Times of the events attributed to the current dimension
        jumps = [
            time
            for time, event_type in zip(process.times, process.types)
            if event_type == dim
        ]

        # Plot a vertical line for each jump
        for jump_time in jumps:
            ax.axvline(jump_time, color="k")

        # The y-axis carries no information here: fix its range, hide ticks.
        ax.set_ylabel(f"Dim {dim}")
        ax.set_ylim(0, 1)
        ax.set_yticks([])

    # With a shared x-axis only the bottom row needs the axis label.
    axes[-1].set_xlabel("Time")
    plt.tight_layout()
    plt.show()


def plot_intensity_and_jumps_same_yscale(process: ExpHawkes, *, steps: int = 1000) -> None:
    """
    Stack, per dimension, an intensity panel above a jump-marker panel.

    All intensity panels share one y-range, running from 0 to 1.1 times the
    largest intensity value found across every dimension, so the curves are
    directly comparable. The jump panels mark each event of that dimension
    with a vertical black line. The figure is shown immediately.

    Args:
        process: An ExpHawkes process that has been simulated
        steps: Number of grid points for plotting

    Raises:
        ValueError: If ``process.times`` is empty (process not simulated)
    """
    event_times = process.times
    if not event_times:
        raise ValueError("Process not yet simulated.")

    grid = np.linspace(0, event_times[-1], steps)
    intensities = process._calculate_intensity(grid)
    # Common upper limit: the global maximum plus 10% headroom.
    shared_top = intensities.max() * 1.1

    n_dims = process.dims
    _, axes = plt.subplots(2 * n_dims, 1, sharex=True, figsize=(8, 3 * n_dims))

    # Rows alternate intensity / jumps, so even rows pair with the odd row
    # directly below them.
    for dim, (ax_rate, ax_events) in enumerate(zip(axes[0::2], axes[1::2])):
        # Intensity panel
        ax_rate.plot(grid, intensities[:, dim])
        ax_rate.set_ylim(0, shared_top)
        ax_rate.set_title(f"Dimension {dim}")
        ax_rate.set_ylabel("Intensity")

        # Jump panel: one vertical line per event of this dimension
        for jump_time, event_type in zip(event_times, process.types):
            if event_type == dim:
                ax_events.axvline(jump_time, color="k")

        ax_events.set_ylim(0, 1)
        ax_events.set_ylabel("Jumps")
        ax_events.set_yticks([])

    # With a shared x-axis only the bottom row needs the axis label.
    axes[-1].set_xlabel("Time")
    plt.tight_layout()
    plt.show()