from __future__ import annotations

from typing import List, Tuple, Dict
import numpy as np

# A 1D interval stored as a (start, end) pair of floats.
# merge_intervals() discards pairs whose end is not strictly greater than start.
Interval = Tuple[float, float]


def merge_intervals(intervals: List[Interval], eps: float = 1e-12) -> List[Interval]:
    """
    Collapse a collection of intervals into a sorted list of disjoint intervals.

    Endpoints are converted to float. Pairs whose end is not strictly greater
    than their start are dropped. Two intervals are fused when the later one
    starts no more than ``eps`` beyond the end of the earlier one, so
    overlapping and (nearly) touching intervals become a single interval.

    Returns a new list of ``(start, end)`` tuples ordered by start.
    """
    spans: List[Interval] = []
    for lo, hi in intervals:
        lo_f, hi_f = float(lo), float(hi)
        if hi_f > lo_f:
            spans.append((lo_f, hi_f))
    spans.sort()

    merged: List[Interval] = []
    for lo, hi in spans:
        # A new run starts when nothing is collected yet or when there is a
        # gap larger than eps after the previous run.
        starts_new_run = (not merged) or (lo > merged[-1][1] + eps)
        if starts_new_run:
            merged.append((lo, hi))
            continue
        prev_lo, prev_hi = merged[-1]
        merged[-1] = (prev_lo, max(prev_hi, hi))
    return merged


def interval_length(intervals: List[Interval]) -> float:
    """
    Sum of ``end - start`` over every interval in the list.

    The intervals are taken as given: no merging is performed, so
    overlapping entries contribute their length more than once.
    """
    # Built-in sum is kept on purpose so the floating-point accumulation
    # matches exactly what callers have always received.
    widths = [float(hi) - float(lo) for lo, hi in intervals]
    return float(sum(widths))


def intersect_length(A: List[Interval], B: List[Interval], eps: float = 1e-12) -> float:
    """
    Total length of the overlap between two interval sets.

    Both inputs are first normalised with ``merge_intervals`` (using the same
    ``eps``), then swept simultaneously from left to right, accumulating the
    length of every pairwise overlap.
    """
    left_set = merge_intervals(A, eps)
    right_set = merge_intervals(B, eps)

    overlap_total = 0.0
    ia = 0
    ib = 0
    while ia < len(left_set) and ib < len(right_set):
        a_start, a_end = left_set[ia]
        b_start, b_end = right_set[ib]

        lo = max(a_start, b_start)
        hi = min(a_end, b_end)
        if hi > lo:
            overlap_total += hi - lo

        # Step past whichever interval finishes first; when the two ends
        # coincide (within eps) both intervals are exhausted together.
        if a_end < b_end - eps:
            ia += 1
            continue
        if b_end < a_end - eps:
            ib += 1
            continue
        ia += 1
        ib += 1
    return float(overlap_total)


def hat_intervals_from_selected_basis(
    sel_idx: np.ndarray,
    projection_num: int,
    degree_projection: int,
    eps: float = 1e-12,
) -> List[Interval]:
    """
    Map selected B-spline basis indices to their merged supports on [0, 1].

    The knot vector is rebuilt here and has to agree with the one used in
    training:
        knots = [0,...,0] (degree times) + linspace(0,1,inner) + [1,...,1] (degree times)
        inner = projection_num - degree_projection + 1

    Basis j is supported on [knots[j], knots[j + degree + 1]].

    Indices whose support would fall outside the knot vector are skipped, as
    are supports no longer than ``eps``. An empty list is returned when
    ``inner <= 1``.

    Raises:
        ValueError: if ``projection_num`` is not positive or
            ``degree_projection`` is negative.
    """
    n_basis = int(projection_num)
    degree = int(degree_projection)

    if n_basis <= 0:
        raise ValueError("projection_num must be positive.")
    if degree < 0:
        raise ValueError("degree_projection must be nonnegative.")

    n_inner = n_basis - degree + 1
    if n_inner <= 1:
        # Too few interior knots for any usable support.
        return []

    knots = np.concatenate(
        [
            np.zeros(degree, dtype=float),
            np.linspace(0.0, 1.0, n_inner, dtype=float),
            np.ones(degree, dtype=float),
        ]
    )

    idx = np.asarray(sel_idx, dtype=int).reshape(-1)

    # A start index j is usable only when j >= 0 and the matching right knot
    # j + degree + 1 still lies inside the knot vector. Filtering on j first
    # avoids shifting indices that are out of range anyway.
    max_start = int(knots.size) - degree - 2
    usable = idx[(idx >= 0) & (idx <= max_start)]

    left = knots[usable]
    right = knots[usable + degree + 1]
    nonempty = right > left + eps

    supports: List[Interval] = list(
        zip(left[nonempty].tolist(), right[nonempty].tolist())
    )
    return merge_intervals(supports, eps=eps)


def interval_metrics(
    true_iv: List[Interval],
    hat_iv: List[Interval],
    domain: Interval,
    eps: float = 1e-12,
) -> Dict[str, float]:
    """
    Length-based classification metrics for interval sets on a 1D domain.

    Both interval sets are clipped to ``domain`` and merged. Lengths then play
    the role of counts: TP is the overlap of truth and estimate, FP the
    estimated length outside the truth, FN the true length that was missed,
    and TN the part of the domain covered by neither.

    Returns:
      A dict with TPR, Recall (same value as TPR), FPR, Precision, F1,
      Accuracy, len_true, len_hat, len_domain and SymDiff_len (FP + FN).
      Any ratio whose denominator is not positive is reported as 0.0.

    Raises:
      ValueError: if the domain's upper bound is not greater than its lower
        bound.
    """
    lower, upper = map(float, domain)
    if not (upper > lower):
        raise ValueError(f"Invalid domain: {domain}")

    def _restrict(segments: List[Interval]) -> List[Interval]:
        # Clamp each segment to the domain and keep only those longer than eps.
        clamped = (
            (max(lower, float(lo)), min(upper, float(hi))) for lo, hi in segments
        )
        return [(lo, hi) for lo, hi in clamped if hi > lo + eps]

    def _ratio(num: float, den: float) -> float:
        # Guarded division: a non-positive denominator yields 0.0.
        return num / den if den > 0 else 0.0

    truth = merge_intervals(_restrict(true_iv), eps=eps)
    estimate = merge_intervals(_restrict(hat_iv), eps=eps)

    truth_len = interval_length(truth)
    estimate_len = interval_length(estimate)
    domain_len = float(upper - lower)

    tp = intersect_length(truth, estimate, eps=eps)
    covered_len = interval_length(merge_intervals(truth + estimate, eps=eps))

    # Clamp at zero so floating-point noise cannot produce negative lengths.
    fp = max(0.0, estimate_len - tp)
    fn = max(0.0, truth_len - tp)
    tn = max(0.0, domain_len - covered_len)

    precision = _ratio(tp, tp + fp)
    recall = _ratio(tp, tp + fn)
    fpr = _ratio(fp, fp + tn)
    accuracy = _ratio(tp + tn, domain_len)

    pr_sum = precision + recall
    if pr_sum == 0:
        f1 = 0.0
    else:
        f1 = 2.0 * precision * recall / pr_sum

    report = {
        "TPR": recall,
        "Recall": recall,
        "FPR": fpr,
        "Precision": precision,
        "F1": f1,
        "Accuracy": accuracy,
        "len_true": truth_len,
        "len_hat": estimate_len,
        "len_domain": domain_len,
        "SymDiff_len": fp + fn,
    }
    return {name: float(value) for name, value in report.items()}
