"""
fit_hawkes.py
=============

High-level helper to train Hawkes-process models (constant-α, feature-dependent,
slow-loop, or EM).  The helper now supports two ways of feeding the
observations:

1.  Flat lists  (times, types, nb_batches) – legacy interface.
2.  One list per match / trajectory (seq_times, seq_types) – the helper
    handles padding automatically.  Optional per-trajectory features can be
    supplied via seq_feats.

Both interfaces are fully backward-compatible with earlier code.
"""

from __future__ import annotations

from typing import Optional, Sequence, List

import numpy as np
import torch
import torch.optim as optim

from openhawkes.models import (
    HawkesMLE,
    HawkesMultiFeatureMLE,
    ModelEM,      
)
from openhawkes.utils import (
    event_batching,
    split_into_batches,
    l1_regularisation,
    log_barrier,
)


def fit_hawkes(
    nb_types: int,
    *,
    # ------------------------------------------------------------ data
    # one list per match / trajectory
    seq_times: Optional[Sequence[Sequence[float]]] = None,
    seq_types: Optional[Sequence[Sequence[int]]] = None,
    seq_feats: Optional[np.ndarray] = None,   # shape (S,D)

    # legacy interface: flat lists + manual batching
    times: Optional[Sequence[float]] = None,
    types: Optional[Sequence[int]] = None,
    nb_batches: int = 1,

    # ------------------------------------------------------------ hyp-opt
    max_epochs: int = 10_000,
    lr: float = 5e-3,
    lambda_l1: float = 0.0,
    beta_unique: bool = True,
    nb_of_features: int = 1,
    model_type: str = "vectorized",  # "vectorized" | "slow_loop[_ogata]" | "em"

    # -------- adaptive-LR / early stop
    epsilon: float = 1e-5,
    start_plateau: int = 5,
    epsilon_plateau: float = 0.02,
    adjust_lr: float = 0.8,

    # -------- positivity barrier
    log_b: bool = False,
    factor_b: float = 0.5,

    # ------------------------------------------------------------ system
    device: Optional[torch.device] = None
) -> torch.nn.Module:
    """
    Fit a multi-dimensional Hawkes process.

    Parameters
    ----------
    nb_types
        Number of event types (dimensions) of the process.
    seq_times, seq_types
        list (size S) of lists of timestamps / type-indices.
        Each inner list is one match / trajectory; lengths may differ.
        When both are given they take precedence over `times` / `types`.
    seq_feats
        Optional `(S, D_feat)` array of exogenous features – one row per
        trajectory.  When given, a `HawkesMultiFeatureMLE` model is trained
        instead of `HawkesMLE`.
    times, types, nb_batches
        Legacy flat-list interface kept for backward compatibility; the flat
        lists are cut into `nb_batches` pieces with `split_into_batches`.
    model_type
        Case-insensitive.  "em" runs expectation–maximisation (flat-list
        interface only); any value starting with "slow_loop" (e.g.
        "slow_loop_ogata") trains `HawkesSlowLoopModelOgata`; every other
        value selects the vectorized models.
    beta_unique
        Fit a scalar β (True) or one β per dimension (False).
    nb_of_features
        Forwarded to `HawkesMultiFeatureMLE`; only used with `seq_feats`.
    lambda_l1
        Weight of the l1 penalty (applied to `alpha`, or to `theta1` and
        `theta2` for the feature model).  Disabled when <= 0.
    max_epochs, lr
        Maximum number of Adam steps (or EM iterations) and initial
        learning rate.
    epsilon
        Training stops as soon as the loss changes by less than `epsilon`
        between two consecutive epochs.
    start_plateau, epsilon_plateau, adjust_lr
        After `start_plateau` consecutive epochs whose loss change is below
        `epsilon_plateau`, the learning rate is multiplied by `adjust_lr`.
    log_b, factor_b
        If `log_b` is True, `log_barrier([mu, beta], factor_b)` is added to
        the loss.
    device
        Torch device; defaults to CUDA when available, otherwise CPU.

    Returns
    -------
    torch.nn.Module
        Trained model.  The gradient-based models are returned as fitted;
        the EM path returns a bare module carrying plain tensors `mu`,
        `alpha` and `beta`.

    Raises
    ------
    ValueError
        If the EM model is requested together with the per-trajectory
        interface or without `times` / `types`; if neither interface is
        fully supplied; or if features are combined with a slow-loop model.
    """
    # ---------------------------------------------------------------- device
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Normalise once so model selection and the forward pass always agree on
    # which family was requested.
    kind = model_type.lower()
    is_slow_loop = kind.startswith("slow_loop")

    # ---------------------------------------------------------------- EM path
    if kind == "em":
        # Identity checks (not truthiness): array-likes have no unambiguous
        # truth value and an empty list is still "supplied".
        if seq_times is not None or seq_types is not None or seq_feats is not None:
            raise ValueError(
                "EM model only supports the legacy flat-list interface for now."
            )
        if times is None or types is None:
            raise ValueError("Provide `times` and `types` for EM fitting.")

        em = ModelEM(
            beta_unique=beta_unique,
            nb_types=nb_types,
            beta_init=1.0,
            maxiter=max_epochs,
        )
        em.fit(np.asarray(times, float), np.asarray(types, int))

        # wrap into nn.Module so downstream code can rely on .mu, etc.
        class EMWrapper(torch.nn.Module):
            pass

        mdl = EMWrapper()
        mdl.mu = torch.tensor(em.mu_, dtype=torch.float32, device=device)
        mdl.alpha = torch.tensor(em.alpha_, dtype=torch.float32, device=device)
        mdl.beta = torch.tensor(em.beta, dtype=torch.float32, device=device)
        return mdl

    # ===================================================================== #
    # 1. Build padded tensors                                                #
    # ===================================================================== #
    if seq_times is not None and seq_types is not None:
        # -- list-of-lists interface
        t_batches: List[List[float]] = [list(seq) for seq in seq_times]
        y_batches: List[List[int]] = [list(seq) for seq in seq_types]
        if times is not None or types is not None:
            print(
                "Ignoring legacy times / types because "
                "seq_times / seq_types were supplied."
            )
    else:
        # -- legacy flat lists
        if times is None or types is None:
            raise ValueError(
                "Supply either (seq_times, seq_types) or (times, types)."
            )
        t_batches = split_into_batches(list(times), nb_batches)
        y_batches = split_into_batches(list(types), nb_batches)

    # pad to (S, N)
    batch_t, batch_y, batch_m = event_batching(t_batches, y_batches)
    event_times = torch.tensor(batch_t, dtype=torch.float32, device=device)
    event_types = torch.tensor(batch_y, dtype=torch.long, device=device)
    mask = torch.tensor(batch_m, dtype=torch.float32, device=device)

    # Observation window per sequence, stored as a single row of S values;
    # an empty sequence contributes 0.0 for both bounds.
    min_t = torch.tensor(
        [[min(seq) if seq else 0.0 for seq in t_batches]],
        dtype=torch.float32,
        device=device,
    )
    max_t = torch.tensor(
        [[max(seq) if seq else 0.0 for seq in t_batches]],
        dtype=torch.float32,
        device=device,
    )

    # ===================================================================== #
    # 2. Select model                                                        #
    # ===================================================================== #
    features = seq_feats.copy() if seq_feats is not None else None
    # Defined on every path so the training loop never hits an unbound name.
    feats_tensor = None

    if is_slow_loop:
        if features is not None:
            raise ValueError("Slow-loop models do not support exogenous features.")
        from openhawkes.models import HawkesSlowLoopModelOgata
        model = HawkesSlowLoopModelOgata(nb_types).to(device)
    elif features is None:
        model = HawkesMLE(nb_types=nb_types, beta_unique=beta_unique).to(device)
    else:
        feats_tensor = torch.tensor(features, dtype=torch.float32, device=device)
        model = HawkesMultiFeatureMLE(
            nb_types=nb_types,
            nb_of_features=nb_of_features,
            beta_unique=beta_unique,
        ).to(device)

    # ===================================================================== #
    # 3. Optimiser and training loop                                         #
    # ===================================================================== #
    opt = optim.Adam(model.parameters(), lr=lr)
    prev_loss, plateau_cnt = None, 0

    for epoch in range(max_epochs):
        opt.zero_grad()

        # ---------- forward (features are only passed to the feature model)
        if feats_tensor is None:
            loglik = model(event_times, event_types, mask, min_t, max_t)
        else:
            loglik = model(event_times, event_types, mask, min_t, max_t, feats_tensor)

        # ---------- loss & regularisers
        loss = -loglik.sum()
        if lambda_l1 > 0:
            if feats_tensor is None:
                loss += l1_regularisation([model.alpha], lambda_l1)
            else:
                loss += l1_regularisation([model.theta1, model.theta2], lambda_l1)
        if log_b:
            loss += log_barrier([model.mu, model.beta], factor_b)
        loss.backward()
        opt.step()

        # ---------- positivity
        with torch.no_grad():
            model.mu.clamp_(min=1e-12)
            model.beta.clamp_(min=1e-12)
            if hasattr(model, "alpha"):
                model.alpha.clamp_(min=1e-12)

        # Read the scalar once: every .item() call forces a device sync.
        loss_val = loss.item()
        delta = None if prev_loss is None else abs(prev_loss - loss_val)
        prev_loss = loss_val

        # ---------- plateau LR-scheduler
        if delta is not None:
            if delta < epsilon_plateau:
                plateau_cnt += 1
            else:
                plateau_cnt = 0
            if plateau_cnt >= start_plateau:
                for g in opt.param_groups:
                    g["lr"] *= adjust_lr
                plateau_cnt = 0

        # ---------- simple early stop
        if delta is not None and delta < epsilon:
            print("Early stop: loss has converged.")
            break

        if epoch % 100 == 0:
            print(f"[{epoch:4d}] loss = {loss_val:.4f}")

    return model