from typing import List, Tuple, Callable, Optional

import torch


def constant_bias_fn(inputs: torch.Tensor) -> float:
    """
    Positional bias that treats every position alike.

    :param inputs: Tensor holding the values inside one window.
    :return: The sum of all entries of ``inputs`` divided by ``inputs.shape[0]``.
    """
    total = inputs.sum().item()
    length = inputs.shape[0]
    return total / length


def back_bias_fn(inputs: torch.Tensor) -> float:
    """
    Positional bias that favours the end of a window.

    The i-th entry (1-based) is weighted with ``i``; the weighted sum is then normalised by the sum of
    all weights, ``1 + 2 + ... + n``.

    :param inputs: 1-D tensor holding the values inside one window.
    :return: The normalised weighted sum.
    """
    length = inputs.shape[0]
    ramp = torch.arange(1, length + 1, dtype=inputs.dtype, device=inputs.device)
    weighted_total = torch.dot(inputs, ramp).item()
    # Gauss sum 1 + ... + n; the product is always even, so integer division is exact.
    weight_sum = (length * (length + 1)) // 2
    return weighted_total / weight_sum


def front_bias_fn(inputs: torch.Tensor) -> float:
    """
    Positional bias that favours the beginning of a window.

    The first entry is weighted with ``n``, the last with ``1``; the weighted sum is then normalised by
    the sum of all weights, ``1 + 2 + ... + n``.

    :param inputs: 1-D tensor holding the values inside one window.
    :return: The normalised weighted sum.
    """
    length = inputs.shape[0]
    ramp = torch.arange(length, 0, -1, dtype=inputs.dtype, device=inputs.device)
    weighted_total = torch.dot(inputs, ramp).item()
    # Gauss sum 1 + ... + n; the product is always even, so integer division is exact.
    weight_sum = (length * (length + 1)) // 2
    return weighted_total / weight_sum


def middle_bias_fn(inputs: torch.Tensor) -> float:
    """
    Positional bias that favours the centre of a window.

    The weights rise ``1, 2, ..., floor(n / 2)`` over the first half and fall ``ceil(n / 2), ..., 2, 1``
    over the second half. The weighted sum is normalised by the sum of all weights.

    :param inputs: 1-D tensor holding the values inside one window.
    :return: The normalised weighted sum.
    """
    length = inputs.shape[0]
    lower_half = length // 2
    # For odd lengths the falling part is one element longer, so it also covers the centre position.
    upper_half = length - lower_half

    rising = torch.arange(1, lower_half + 1, dtype=inputs.dtype, device=inputs.device)
    falling = torch.arange(upper_half, 0, -1, dtype=inputs.dtype, device=inputs.device)
    weights = torch.cat([rising, falling])

    weighted_total = torch.dot(inputs, weights).item()
    # Sum of both ramps; each product is even, so the integer division is exact.
    weight_sum = (lower_half * (lower_half + 1) + upper_half * (upper_half + 1)) // 2
    return weighted_total / weight_sum


def inverse_proportional_cardinality_fn(cardinality: int, gt_length: int) -> float:
    """
    Cardinality factor that is the reciprocal of the number of overlapping windows.

    :param cardinality: Number of windows overlapping the current window. Values below 1 are treated as 1.
    :param gt_length: Length of the current window. Unused by this function.
    :return: ``1 / cardinality``, or ``1`` if ``cardinality`` is not greater than 1.
    """
    effective_cardinality = cardinality if cardinality > 1 else 1
    return 1 / effective_cardinality


def improved_cardinality_fn(cardinality: int, gt_length: int):
    """
    Cardinality factor that decays geometrically with the number of overlapping windows.

    :param cardinality: Number of windows overlapping the current window.
    :param gt_length: Length of the current window. Must not be 0.
    :return: ``((gt_length - 1) / gt_length) ** (cardinality - 1)``.
    """
    decay_base = (gt_length - 1) / gt_length
    extra_windows = cardinality - 1
    return decay_base ** extra_windows


def compute_window_indices(binary_labels: torch.Tensor) -> List[Tuple[int, int]]:
    """
    Compute a list of indices where (ground truth and predicted) anomalies begin and end.

    A position counts as anomalous if its label is greater than 0. This makes the function work for
    boolean tensors and for positive labels other than 1, and the input is never modified.

    :param binary_labels: 1-D tensor of labels. May be empty.
    :return: List of ``(start, end)`` tuples, one per maximal run of anomalous positions, in ascending
        order. ``start`` is the first index of the run and ``end`` is one past its last index.
    """
    # Work on a small signed integer copy so that the subtraction below is valid for every input dtype
    # (in-place arithmetic on bool tensors raises, and unsigned dtypes would wrap around).
    active = (binary_labels > 0).to(torch.int8)

    # Difference to the previous position, with an implicit 0 in front of the sequence:
    # 1 where a window starts, -1 at the first position after a window, 0 everywhere else.
    boundaries = torch.cat([active[:1], active[1:] - active[:-1]])

    indices = torch.nonzero(boundaries, as_tuple=True)[0].tolist()
    if len(indices) % 2 != 0:
        # The last window runs until the end of the sequence and therefore has no closing boundary.
        indices.append(binary_labels.shape[0])

    # Starts sit at even positions of the boundary list, the matching ends at odd positions.
    return list(zip(indices[::2], indices[1::2]))


def compute_overlap(preds: torch.Tensor, pred_indices: List[Tuple[int, int]],
                    gt_indices: List[Tuple[int, int]], alpha: float,
                    bias_fn: Callable, cardinality_fn: Callable,
                    use_window_weight: bool = False) -> float:
    """
    Compute the average overlap score of the windows in ``gt_indices`` with the windows in ``pred_indices``.

    Every window in ``gt_indices`` that is overlapped by at least one window of ``pred_indices`` contributes
    ``alpha + (1 - alpha) * cardinality_fn(cardinality, length) * bias_fn(preds[start:end])``; windows
    without any overlap contribute 0. The final score is the (optionally length-weighted) mean over *all*
    windows in ``gt_indices``.

    :param preds: Tensor of labels from which ``pred_indices`` was derived. The slice covering each window
        of ``gt_indices`` is handed to ``bias_fn``.
    :param pred_indices: ``(start, end)`` windows (``end`` exclusive) of the sequence that is scored. Must be
        in ascending order, because both window lists are traversed in a single forward sweep.
    :param gt_indices: ``(start, end)`` windows (``end`` exclusive) that act as reference, in ascending order.
    :param alpha: Weight of the existence reward; the overlap reward is weighted with ``1 - alpha``.
    :param bias_fn: Called with the slice of ``preds`` inside a reference window; returns the overlap size.
    :param cardinality_fn: Called with ``(cardinality, window_length)``; returns the cardinality factor.
    :param use_window_weight: If True, every reference window is weighted with its length, otherwise all
        reference windows have the same weight.
    :return: The overlap score.
    :raises ZeroDivisionError: If ``gt_indices`` is empty (or, with ``use_window_weight``, if all reference
        windows have length 0).
    """
    n_gt_windows = len(gt_indices)
    n_pred_windows = len(pred_indices)

    # The normaliser must cover every reference window, including those that lie behind the last
    # window of pred_indices. Hence it is computed up front and not during the sweep below, which
    # stops early once pred_indices is exhausted.
    total_gt_points = sum(gt_end - gt_start for gt_start, gt_end in gt_indices)

    total_score = 0.0
    j = 0
    for gt_start, gt_end in gt_indices:
        if j >= n_pred_windows:
            # No windows left that could overlap this or any later reference window.
            break

        window_length = gt_end - gt_start

        # Skip all windows that end before the current reference window starts.
        while j < n_pred_windows and pred_indices[j][1] <= gt_start:
            j += 1
        # Count all windows that start before the current reference window ends.
        cardinality = 0
        while j < n_pred_windows and pred_indices[j][0] < gt_end:
            j += 1
            cardinality += 1

        if cardinality == 0:
            # cardinality == 0 means no overlap at all, hence no contribution
            continue

        # The last window that overlaps the current reference window could also overlap the next one.
        # Therefore, we must consider it again in the next loop iteration.
        j -= 1

        cardinality_multiplier = cardinality_fn(cardinality, window_length)

        # We calculate omega directly in the bias function, because this can greatly improve running time
        # for the constant bias, for example.
        omega = bias_fn(preds[gt_start:gt_end])

        # Either weight evenly across all windows or based on window length
        weight = window_length if use_window_weight else 1

        # Existence reward (if cardinality > 0 then this is certainly 1)
        total_score += alpha * weight
        # Overlap reward
        total_score += (1 - alpha) * cardinality_multiplier * omega * weight

    denom = total_gt_points if use_window_weight else n_gt_windows

    return total_score / denom


def ts_precision_and_recall(anomalies: torch.Tensor, predictions: torch.Tensor, alpha: float = 0,
                            recall_bias_fn: Callable[[torch.Tensor], float] = constant_bias_fn,
                            recall_cardinality_fn: Callable[[int, int], float] = inverse_proportional_cardinality_fn,
                            precision_bias_fn: Optional[Callable[[torch.Tensor], float]] = None,
                            precision_cardinality_fn: Optional[Callable[[int, int], float]] = None,
                            anomaly_ranges: Optional[List[Tuple[int, int]]] = None,
                            prediction_ranges: Optional[List[Tuple[int, int]]] = None,
                            weighted_precision: bool = False) -> Tuple[float, float]:
    """Computes time series precision and recall.

    :param anomalies: Tensor of shape (length,) containing the true labels.
    :type anomalies: torch.Tensor
    :param predictions: Tensor of shape (length,) containing the predicted labels.
    :type predictions: torch.Tensor
    :param alpha: Weight for existence term in recall. Precision is always computed without existence term.
    :type alpha: float
    :param recall_bias_fn: Function that computes the bias term for the labels inside one window.
    :type recall_bias_fn: Callable[[torch.Tensor], float]
    :param recall_cardinality_fn: Function that computes the cardinality factor of one window. It is called
        with the number of overlapping windows and the length of the window.
    :type recall_cardinality_fn: Callable[[int, int], float]
    :param precision_bias_fn: Function that computes the bias term for the labels inside one window.
        If None, this will be the same as recall_bias_fn.
    :type precision_bias_fn: Optional[Callable[[torch.Tensor], float]]
    :param precision_cardinality_fn: Function that computes the cardinality factor of one window.
        If None, this will be the same as recall_cardinality_fn.
    :type precision_cardinality_fn: Optional[Callable[[int, int], float]]
    :param anomaly_ranges: Precomputed ``(start, end)`` windows of ``anomalies``. If None, they are computed
        from ``anomalies`` with :func:`compute_window_indices`.
    :type anomaly_ranges: Optional[List[Tuple[int, int]]]
    :param prediction_ranges: Precomputed ``(start, end)`` windows of ``predictions``. If None, they are
        computed from ``predictions`` with :func:`compute_window_indices`.
    :type prediction_ranges: Optional[List[Tuple[int, int]]]
    :param weighted_precision: If True, the precision score of a predicted window will be weighted with the
        length of the window in the final score. Otherwise, each window will have the same weight.
    :type weighted_precision: bool
    :return: The time-series precision and recall for the given labels in the form (precision, recall).
    :rtype: Tuple[float, float]
    """
    has_anomalies = torch.any(anomalies > 0).item()
    has_predictions = torch.any(predictions > 0).item()

    # Catch special cases which would cause a division by zero
    if not has_predictions and not has_anomalies:
        # In this case, the classifier is perfect, so it makes sense to set precision and recall to 1
        return 1.0, 1.0
    if not has_predictions or not has_anomalies:
        # Either nothing was detected or everything that was detected is a false alarm
        return 0.0, 0.0

    # Set precision functions to the same as recall functions if they are not given
    if precision_bias_fn is None:
        precision_bias_fn = recall_bias_fn
    if precision_cardinality_fn is None:
        precision_cardinality_fn = recall_cardinality_fn

    if anomaly_ranges is None:
        anomaly_ranges = compute_window_indices(anomalies)
    if prediction_ranges is None:
        prediction_ranges = compute_window_indices(predictions)

    # Recall: how well is every true anomaly window covered by the predictions?
    recall = compute_overlap(predictions, prediction_ranges, anomaly_ranges, alpha, recall_bias_fn,
                             recall_cardinality_fn)
    # Precision: the same computation with swapped roles and without existence reward (alpha = 0)
    precision = compute_overlap(anomalies, anomaly_ranges, prediction_ranges, 0, precision_bias_fn,
                                precision_cardinality_fn, use_window_weight=weighted_precision)

    return precision, recall
