"""
Utility helpers – masking, batching, regularisers.
"""
from __future__ import annotations
import numpy as np
import torch
from typing import List, Tuple


# ---------------------------------------------------------------------
# Tensor masks
# ---------------------------------------------------------------------


def fill_tril(x: torch.Tensor, diagonal: int = 0) -> torch.Tensor:
    """
    Return a lower-triangular copy of ``x``.

    Entries at positions ``[..., j, k]`` with ``k - j > diagonal`` are set to
    zero; everything else is kept. With the default ``diagonal=0`` this zeroes
    the strictly upper triangle (column index greater than row index).
    ``torch.tril`` acts on the last two dimensions, so leading dimensions are
    treated as a batch.

    NOTE(review): the original note "so t_j < t_k" suggests rows/columns index
    time-ordered events; confirm against callers.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor with at least two dimensions.
    diagonal : int, optional
        Offset of the last diagonal to keep (0 = main diagonal, negative =
        below it, positive = above it).

    Returns
    -------
    torch.Tensor
        New tensor with the same shape as ``x``.
    """
    lower = x.tril(diagonal=diagonal)
    return lower


# ---------------------------------------------------------------------
# Batching of event sequences
# ---------------------------------------------------------------------


def event_batching(
    time_lists: List[List[float]],
    type_lists: List[List[int]]
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Right-pad variable-length event sequences to the length of the longest one.

    Parameters
    ----------
    time_lists : List[List[float]]
        One list of event times per sequence.
    type_lists : List[List[int]]
        One list of event types per sequence. Must agree with ``time_lists``
        both in the number of sequences and in the length of each paired
        sequence.

    Returns
    -------
    times : np.ndarray
        Padded times of shape (B, L), dtype float32, where B is the number of
        sequences and L the longest sequence length (0 if there are none).
    types : np.ndarray
        Padded types of shape (B, L), dtype int32.
    mask : np.ndarray
        1 for real events and 0 for padding, shape (B, L), dtype int32.

    Raises
    ------
    ValueError
        If the two inputs contain a different number of sequences, or a time
        sequence and its type sequence have different lengths. Previously
        such mismatches were silently truncated.
    """
    if len(time_lists) != len(type_lists):
        raise ValueError(
            f"time_lists has {len(time_lists)} sequences but "
            f"type_lists has {len(type_lists)}"
        )

    batch = len(time_lists)
    max_len = max((len(seq) for seq in time_lists), default=0)

    t_out = np.zeros((batch, max_len), dtype=np.float32)
    y_out = np.zeros((batch, max_len), dtype=np.int32)
    m_out = np.zeros((batch, max_len), dtype=np.int32)

    for i, (t_seq, y_seq) in enumerate(zip(time_lists, type_lists)):
        seq_len = len(t_seq)
        if len(y_seq) != seq_len:
            raise ValueError(
                f"sequence {i}: {seq_len} times but {len(y_seq)} types"
            )
        # Positions past seq_len keep their zero padding (and mask 0).
        t_out[i, :seq_len] = np.asarray(t_seq, dtype=np.float32)
        y_out[i, :seq_len] = np.asarray(y_seq, dtype=np.int32)
        m_out[i, :seq_len] = 1

    return t_out, y_out, m_out


def split_into_batches(seq: List, n_batches: int) -> List[List]:
    """
    Split a sequence into ``n_batches`` contiguous, near-equal chunks.

    Chunk sizes differ by at most one: the first ``len(seq) % n_batches``
    chunks each hold one extra element. Concatenating the chunks in order
    reproduces ``seq``.

    Parameters
    ----------
    seq : List
        The sequence to split (any sliceable sequence).
    n_batches : int
        Number of chunks to create.

    Returns
    -------
    List[List]
        ``n_batches`` slices of ``seq``. If ``n_batches <= 0`` or
        ``n_batches >= len(seq)``, each element is instead wrapped in its own
        single-element list, so fewer than ``n_batches`` chunks are returned
        when ``seq`` is shorter than ``n_batches``.
    """
    n = len(seq)
    if n_batches <= 0 or n_batches >= n:
        return [[x] for x in seq]

    # Spread the remainder over the leading chunks instead of dumping it all
    # into the last one (e.g. 10 items / 4 batches -> sizes 3, 3, 2, 2).
    base, extra = divmod(n, n_batches)
    return [
        seq[i * base + min(i, extra):(i + 1) * base + min(i + 1, extra)]
        for i in range(n_batches)
    ]


# ---------------------------------------------------------------------
# Regularisation
# ---------------------------------------------------------------------


def l1_regularisation(params: List[torch.Tensor], lam: float) -> torch.Tensor:
    """
    Compute the L1 penalty ``lam * sum_p sum(|p|)`` over a list of tensors.

    Parameters
    ----------
    params : List[torch.Tensor]
        Tensors to penalise. Each per-tensor sum is moved to the device of
        ``params[0]`` before accumulation.
    lam : float
        Regularisation strength.

    Returns
    -------
    torch.Tensor
        0-d tensor holding the penalty (differentiable w.r.t. ``params``).
        An empty ``params`` yields a 0-d zero tensor instead of raising.
    """
    if not params:
        return torch.zeros(())
    device = params[0].device
    # The implicit start value 0 keeps the dtype of the parameters.
    return lam * sum(p.abs().sum().to(device) for p in params)


def log_barrier(
    params: List[torch.Tensor], t: float, *, eps: float = 1e-7
) -> torch.Tensor:
    """
    Compute the log-barrier term ``sum_p sum(-log(p + eps)) / t``.

    The term grows without bound as entries approach ``-eps`` from above.
    It is NaN for entries below ``-eps``, since the log of a negative
    number is NaN.

    Parameters
    ----------
    params : List[torch.Tensor]
        Tensors on which the barrier is applied. Each per-tensor sum is moved
        to the device of ``params[0]`` before accumulation.
    t : float
        Barrier parameter; the term is scaled by ``1 / t``.
    eps : float, optional
        Offset added inside the log so exact zeros stay finite. Default 1e-7.

    Returns
    -------
    torch.Tensor
        0-d tensor holding the barrier term. An empty ``params`` yields a 0-d
        zero tensor instead of raising.
    """
    if not params:
        return torch.zeros(())
    device = params[0].device
    return sum((-torch.log(p + eps)).sum().to(device) / t for p in params)


def sigmoid(z):
    """
    Logistic sigmoid ``1 / (1 + exp(-z))``, computed without overflow.

    Parameters
    ----------
    z : float or np.ndarray
        Input value(s); applied element-wise to arrays.

    Returns
    -------
    float or np.ndarray
        Sigmoid of input, in [0, 1].
    """
    # The naive 1 / (1 + np.exp(-z)) overflows in exp for z below ~-709
    # (RuntimeWarning, or FloatingPointError under np.errstate(over="raise")).
    # exp(-log(1 + exp(-z))) with a stable logaddexp is the same function and
    # never overflows: large |z| only underflows towards 0 or 1.
    return np.exp(-np.logaddexp(0.0, -z))


def check_alpha_conditions(
    alpha_mat: np.ndarray,
    *,
    n_near_zero: int = 4,
    zero_tol: float = 1e-3,
    low: float = 0.01,
    high: float = 0.33,
) -> bool:
    """
    Check that an alpha matrix has the expected sparsity and value range.

    The matrix passes when all of the following hold:

    * it contains no NaN entries;
    * exactly ``n_near_zero`` entries are below ``zero_tol``;
    * every remaining entry lies in the closed interval [``low``, ``high``].

    Parameters
    ----------
    alpha_mat : np.ndarray
        Alpha matrix to check.
    n_near_zero : int, optional
        Required number of (near-)zero entries. Default 4.
    zero_tol : float, optional
        Entries strictly below this count as near zero. Default 1e-3.
    low, high : float, optional
        Inclusive bounds for the remaining entries. Defaults 0.01 and 0.33.

    Returns
    -------
    bool
        True if all conditions are met, False otherwise.
    """
    alpha = np.asarray(alpha_mat)
    # NaN fails both `< zero_tol` and `>= zero_tol`, so without this check it
    # would be neither counted nor range-checked and the matrix could pass.
    if np.isnan(alpha).any():
        return False

    near_zero = alpha < zero_tol
    if np.count_nonzero(near_zero) != n_near_zero:
        return False

    rest = alpha[~near_zero]
    return bool(np.all((rest >= low) & (rest <= high)))


def spectral_radius_is_lt_one(alpha_mat: np.ndarray, *, eps: float = 1e-12) -> bool:
    """
    Test whether the spectral radius of ``alpha_mat`` is strictly below 1.

    The spectral radius is the largest eigenvalue modulus. The comparison is
    made against ``1 - eps`` so values within ``eps`` of 1 are rejected.

    Parameters
    ----------
    alpha_mat : np.ndarray
        Square matrix (real or complex).
    eps : float, optional
        Numerical safety margin subtracted from 1. Default is 1e-12.

    Returns
    -------
    bool
        True if the spectral radius is below ``1 - eps`` (sub-critical),
        False otherwise.
    """
    eigenvalues = np.linalg.eigvals(alpha_mat)
    spectral_radius = np.abs(eigenvalues).max()
    threshold = 1.0 - eps
    return spectral_radius < threshold


def generate_alphas_with_features(
    rng: np.random.Generator,
    gamma,
    theta1,
    theta2,
    *,
    x_range: Tuple[float, float] = (-6.0, 6.0),
    max_tries: int | None = None,
):
    """
    Rejection-sample a d x d alpha matrix with spectral radius below 1.

    Each attempt draws a feature vector ``x`` uniformly from ``x_range`` and sets

        alpha[i, j] = sigmoid(gamma[i, j] + theta1[i, j] * x[i] + theta2[i, j] * x[j])

    The first ``alpha`` that passes ``spectral_radius_is_lt_one`` is returned.

    Parameters
    ----------
    rng : np.random.Generator
        Random number generator used to draw ``x``.
    gamma, theta1, theta2 : np.ndarray
        Parameter matrices of size d x d.
    x_range : tuple of float, optional
        ``(low, high)`` bounds for the uniform draw of ``x``. Default (-6, 6).
    max_tries : int or None, optional
        Maximum number of attempts. ``None`` (default) retries forever, as
        before.

    Returns
    -------
    tuple
        ``(alpha, x)``: the accepted d x d alpha matrix and its feature vector.

    Raises
    ------
    RuntimeError
        If ``max_tries`` attempts are made without finding a valid matrix.
    """
    gamma = np.asarray(gamma)
    theta1 = np.asarray(theta1)
    theta2 = np.asarray(theta2)
    d = len(gamma)
    low, high = x_range

    attempt = 0
    while max_tries is None or attempt < max_tries:
        attempt += 1
        x = rng.uniform(low, high, size=d)

        # Broadcasting replaces the original double loop:
        # x[:, None] varies along rows (x[i]), x[None, :] along columns (x[j]).
        z = gamma + theta1 * x[:, None] + theta2 * x[None, :]
        alpha = sigmoid(z)

        if spectral_radius_is_lt_one(alpha):
            return alpha, x

    raise RuntimeError(
        f"no alpha matrix with spectral radius < 1 found in {max_tries} tries"
    )