"""
Metrics module for SPOTLIGHT anomaly detection
"""

from bisect import bisect_left

import numpy as np


def compute_accuracy(detected_anomalies, true_anomalies, tolerance=1):
    """
    Compute accuracy of anomaly detection.

    A true anomaly at time ``t`` counts as detected when at least one
    detection ``d`` satisfies ``t - tolerance <= d <= t + tolerance``.
    Timestamps and tolerance may be any real numbers (ints or floats).

    Args:
        detected_anomalies: List of detected anomaly timestamps
        true_anomalies: List of true anomaly timestamps
        tolerance: Time tolerance for matching (default 1)

    Returns:
        accuracy: Fraction of true anomalies that were detected within
            tolerance. When there are no true anomalies, 1.0 if there are
            also no detections and 0.0 otherwise.
    """
    # len() rather than truthiness so NumPy arrays are accepted as well.
    if len(true_anomalies) == 0:
        return 1.0 if len(detected_anomalies) == 0 else 0.0

    # Sorting once lets each true anomaly be checked with a binary search
    # instead of enumerating every integer offset inside the window, which
    # only worked for integer timestamps.
    detected_sorted = sorted(detected_anomalies)
    n_detected = len(detected_sorted)
    matches = 0

    for true_anomaly in true_anomalies:
        # First detection at or after the window start; it is a match iff it
        # does not lie beyond the window end.
        idx = bisect_left(detected_sorted, true_anomaly - tolerance)
        if idx < n_detected and detected_sorted[idx] <= true_anomaly + tolerance:
            matches += 1

    return matches / len(true_anomalies)


def _count_within_tolerance(points, sorted_refs, tolerance):
    """Count the entries of ``points`` that have a reference within tolerance."""
    n_refs = len(sorted_refs)
    count = 0
    for point in points:
        # First reference at or after the window start; a match iff it is
        # not past the window end.
        idx = bisect_left(sorted_refs, point - tolerance)
        if idx < n_refs and sorted_refs[idx] <= point + tolerance:
            count += 1
    return count


def compute_precision_recall_f1(
    detected_anomalies, true_anomalies, total_timestamps, tolerance=1
):
    """
    Compute precision, recall, and F1 score for anomaly detection.

    A detection and a true anomaly match when their timestamps differ by at
    most ``tolerance``. Precision is the fraction of detections that match
    some true anomaly; recall is the fraction of true anomalies that match
    some detection. The two are counted separately so that several
    detections of one anomaly (or one detection covering several anomalies)
    cannot skew recall.

    Args:
        detected_anomalies: List of detected anomaly timestamps
        true_anomalies: List of true anomaly timestamps
        total_timestamps: Total number of timestamps in the sequence
            (accepted for interface compatibility; not used in the
            computation)
        tolerance: Time tolerance for matching

    Returns:
        precision, recall, f1: Performance metrics. Degenerate inputs:
            no detections and no true anomalies -> (1.0, 1.0, 1.0);
            no detections but some true anomalies -> (0.0, 0.0, 0.0);
            detections but no true anomalies -> (0.0, 1.0, 0.0).
    """
    # len() rather than truthiness so NumPy arrays are accepted as well.
    n_detected = len(detected_anomalies)
    n_true = len(true_anomalies)

    if n_detected == 0:
        if n_true == 0:
            # Nothing to find and nothing reported: a perfect result,
            # in line with compute_accuracy returning 1.0 for this case.
            return 1.0, 1.0, 1.0
        return 0.0, 0.0, 0.0

    if n_true == 0:
        # Every detection is a false positive; recall is vacuously 1.0.
        return 0.0, 1.0, 0.0

    # Detections that land near a true anomaly (the rest are false positives).
    matched_detections = _count_within_tolerance(
        detected_anomalies, sorted(true_anomalies), tolerance
    )
    # True anomalies that were found (the rest are false negatives).
    matched_true = _count_within_tolerance(
        true_anomalies, sorted(detected_anomalies), tolerance
    )

    precision = matched_detections / n_detected
    recall = matched_true / n_true
    f1 = (
        2 * precision * recall / (precision + recall)
        if (precision + recall) > 0
        else 0.0
    )

    return precision, recall, f1


def compute_detection_delay(detected_anomalies, true_anomalies, tolerance=5):
    """
    Compute average detection delay.

    For each true anomaly, only detections at or after it and no later than
    ``tolerance`` time units afterwards are considered; the smallest such
    lag is that anomaly's delay. True anomalies with no detection in that
    window contribute nothing to the average.

    Args:
        detected_anomalies: List of detected anomaly timestamps
        true_anomalies: List of true anomaly timestamps
        tolerance: Maximum delay considered valid

    Returns:
        avg_delay: Mean of the per-anomaly delays, or None if no true
            anomaly had a detection inside its window
    """
    per_anomaly_delays = []

    for onset in true_anomalies:
        deadline = onset + tolerance
        # Lags of every detection that falls inside [onset, deadline].
        lags = [hit - onset for hit in detected_anomalies if onset <= hit <= deadline]
        if lags:
            per_anomaly_delays.append(min(lags))

    if not per_anomaly_delays:
        return None
    return np.mean(per_anomaly_delays)
