import numpy as np
import sglang as sgl
from sglang.lang.interpreter import ProgramState


def aggregate_score(scores, method="last"):
    """Collapse every per-item score sequence into a single value.

    Parameters
    ----------
    scores : iterable of sequences
        One sequence of scores per item.
    method : {"last", "mean", "min"}, optional
        ``"last"`` keeps the final element of each sequence, ``"mean"``
        applies ``np.mean`` and ``"min"`` applies ``np.min``.
        Default is ``"last"``.

    Returns
    -------
    list
        One aggregated value per entry of *scores*, in input order.

    Raises
    ------
    ValueError
        If *method* is not one of the supported names.
    """
    # Resolve the reducer first so an unknown method fails before any
    # element of *scores* is touched.
    if method == "last":

        def reduce_one(seq):
            return seq[-1]

    elif method == "mean":
        reduce_one = np.mean
    elif method == "min":
        reduce_one = np.min
    else:
        raise ValueError(f"Unknown aggregate method: {method}")

    return [reduce_one(seq) for seq in scores]


def add_messages_to_state(s: ProgramState, messages):
    """Append chat messages to an sglang program state, in order.

    Parameters
    ----------
    s : ProgramState
        State that each message is appended to with ``+=``.
    messages : iterable of dict
        Each entry must provide a ``"role"`` key (``"system"``, ``"user"``
        or ``"assistant"``) and a ``"content"`` key.

    Raises
    ------
    ValueError
        If a message carries a role other than the three supported ones.
    """
    for msg in messages:
        role = msg["role"]
        # Pick the sglang constructor matching the role; the content is
        # only read once the role is known to be valid.
        if role == "system":
            build = sgl.system
        elif role == "user":
            build = sgl.user
        elif role == "assistant":
            build = sgl.assistant
        else:
            raise ValueError(f"Unknown message role: {role}")
        s += build(msg["content"])


def _softmax(x: np.ndarray) -> np.ndarray:
    """Return the softmax of *x*, computed in a numerically stable way.

    Parameters
    ----------
    x : np.ndarray
        Input scores.

    Returns
    -------
    np.ndarray
        Probabilities with the same shape as *x*; normalisation is taken
        over all elements of *x*.
    """
    # Subtracting the maximum leaves the result unchanged but keeps the
    # largest exponent at zero, so ``np.exp`` cannot overflow.
    peak = np.max(x)
    unnormalised = np.exp(x - peak)
    total = np.sum(unnormalised)
    return unnormalised / total


def compute_softmax_divergence(
    p_scores,
    q_scores,
    method: str = "kl",
    eps: float = 1e-12,
):
    """Compute divergence between two score lists based on their softmax distributions.

    Both *p_scores* and *q_scores* are first turned into probability
    distributions via softmax; a divergence metric is then computed between
    the two resulting distributions.

    Parameters
    ----------
    p_scores : Sequence[float]
        The first list/array of (unnormalised) scores.
    q_scores : Sequence[float]
        The second list/array of (unnormalised) scores. Must have the same
        shape as *p_scores*.
    method : {"kl", "js"}, optional
        Divergence metric to use (case-insensitive). ``"kl"`` stands for the
        Kullback-Leibler divergence ``D_KL(P || Q)``, whereas ``"js"``
        computes the symmetric Jensen-Shannon divergence. Default is ``"kl"``.
    eps : float, optional
        Small value added to numerator and denominator inside the logarithm
        for numerical stability. Default is ``1e-12``.

    Returns
    -------
    float
        The divergence value.

    Raises
    ------
    ValueError
        If *method* is not supported, if either score list is empty, or if
        the two score lists do not have the same shape.
    """
    kind = method.lower()
    if kind not in ("kl", "js"):
        raise ValueError(f"Unknown divergence method: {method}. Supported methods are 'kl' and 'js'.")

    p_raw = np.asarray(p_scores, dtype=np.float64)
    q_raw = np.asarray(q_scores, dtype=np.float64)

    # Validate explicitly: without this, a length-1 list would silently
    # broadcast against a longer one and yield a meaningless number.
    if p_raw.size == 0 or q_raw.size == 0:
        raise ValueError("Score lists must be non-empty.")
    if p_raw.shape != q_raw.shape:
        raise ValueError("Score lists must have the same length.")

    p = _softmax(p_raw)
    q = _softmax(q_raw)

    if kind == "kl":
        return float(np.sum(p * np.log((p + eps) / (q + eps))))

    # Jensen-Shannon: average of the KL divergences of P and Q from their
    # midpoint distribution M.
    m = 0.5 * (p + q)
    kl_pm = np.sum(p * np.log((p + eps) / (m + eps)))
    kl_qm = np.sum(q * np.log((q + eps) / (m + eps)))
    return float(0.5 * (kl_pm + kl_qm))


def compute_top1_mismatch(
    p_scores,
    q_scores,
) -> float:
    """Report whether two score lists disagree on their top-ranked item.

    The position of the maximum of *p_scores* is compared with the
    position of the maximum of *q_scores*.

    Parameters
    ----------
    p_scores : Sequence[float]
        First score list.
    q_scores : Sequence[float]
        Second score list.

    Returns
    -------
    float
        ``1.0`` if the argmax indices differ, else ``0.0``.

    Raises
    ------
    ValueError
        If either list is empty or the lists differ in length.
    """
    if not (len(p_scores) and len(q_scores)):
        raise ValueError("Score lists must be non-empty.")
    if len(p_scores) != len(q_scores):
        raise ValueError("Score lists must have the same length.")

    best_p = np.argmax(p_scores)
    best_q = np.argmax(q_scores)
    return 1.0 if best_p != best_q else 0.0


def compute_pairwise_inversion_ratio(
    p_scores,
    q_scores,
) -> float:
    """Compute the fraction of pairwise order disagreements
    (a.k.a. Kendall tau distance normalised to [0, 1]) between
    *p_scores* and *q_scores*.

    A pair *(i, j)* is discordant when one list scores *i* strictly
    above *j* while the other scores *i* strictly below *j*. A pair that
    is tied in either list is never counted as discordant, so the result
    does not depend on how a sort would happen to break ties.

    Parameters
    ----------
    p_scores : Sequence[float]
        First score list.
    q_scores : Sequence[float]
        Second score list.

    Returns
    -------
    float
        The ratio of discordant pairs over the total number of
        possible pairs. A value of ``0`` indicates no disagreement,
        whereas ``1`` means completely reversed order.

    Raises
    ------
    ValueError
        If either list is empty or the lists differ in length.
    """
    if len(p_scores) == 0 or len(q_scores) == 0:
        raise ValueError("Score lists must be non-empty.")
    if len(p_scores) != len(q_scores):
        raise ValueError("Score lists must have the same length.")

    n = len(p_scores)
    if n < 2:
        return 0.0  # No pairs to compare.

    p = np.asarray(p_scores, dtype=np.float64)
    q = np.asarray(q_scores, dtype=np.float64)

    # Enumerate every unordered pair (i, j) with i < j exactly once.
    first, second = np.triu_indices(n, k=1)

    # Compare signs of the score differences rather than argsort ranks:
    # a tie gives sign 0, so it can never make the product negative.
    p_direction = np.sign(p[first] - p[second])
    q_direction = np.sign(q[first] - q[second])
    discordant = int(np.count_nonzero(p_direction * q_direction < 0))

    total_pairs = n * (n - 1) // 2
    return discordant / total_pairs


def compute_divergence(
    p_scores,
    q_scores,
    method: str = "softmax_divergence",
    **kwargs,
):
    """Unified interface to quantify ranking divergence between two score lists.

    Parameters
    ----------
    p_scores : Sequence[float]
        First list/array of scores.
    q_scores : Sequence[float]
        Second list/array of scores.
    method : {"softmax_divergence", "top1_mismatch", "disorder_pairs"}, optional
        Which metric to use (case-insensitive):

        * ``"softmax_divergence"`` - Call :func:`compute_softmax_divergence`.
          The keyword arguments ``divergence`` (default ``"kl"``) and ``eps``
          (default ``1e-12``) are forwarded as its ``method`` and ``eps``.
        * ``"top1_mismatch"`` - Call :func:`compute_top1_mismatch`.
        * ``"disorder_pairs"`` - Call :func:`compute_pairwise_inversion_ratio`.

    **kwargs
        Only ``divergence`` and ``eps`` are read, and only for
        ``"softmax_divergence"``; any other keyword argument is ignored.

    Returns
    -------
    float
        Divergence score according to the selected *method*.

    Raises
    ------
    ValueError
        If *method* is not one of the supported names.
    """
    key = method.lower()
    if key == "softmax_divergence":
        div_type = kwargs.pop("divergence", "kl")
        eps = kwargs.pop("eps", 1e-12)
        return compute_softmax_divergence(p_scores, q_scores, method=div_type, eps=eps)
    elif key == "top1_mismatch":
        return compute_top1_mismatch(p_scores, q_scores)
    elif key == "disorder_pairs":
        return compute_pairwise_inversion_ratio(p_scores, q_scores)
    else:
        # f-string so the offending value actually appears in the message.
        raise ValueError(
            f"Unknown divergence method: {method}. Supported values are 'softmax_divergence', 'top1_mismatch', and 'disorder_pairs'."
        )
