from src.lib.metrics.numpy_metrics import (
    masked_coverage,
    masked_pi_width,
    masked_winkler_score,
)
from src.lib.utils.data_utils import find_close
from tsl.utils.casting import torch_to_numpy


def evaluate_multiquantile_predictions(
    q_hat, r_true, mask, alphas, target_qs, loader_name="eval"
):

    results = dict()
    for target_alpha in alphas:
        idx_low = find_close(target_alpha / 2, target_qs)
        idx_high = find_close(1 - target_alpha / 2, target_qs)
        pi_interval = [q_hat[idx_low], q_hat[idx_high]]

        results[f"{loader_name}_coverage_at_{(1 - target_alpha) * 100}"] = (
            masked_coverage(
                torch_to_numpy(pi_interval),
                torch_to_numpy(r_true),
                torch_to_numpy(mask),
            )
        )
        # compute pi width
        results[f"{loader_name}_pi_width_at_{(1 - target_alpha) * 100}"] = (
            masked_pi_width(
                torch_to_numpy(pi_interval),
                torch_to_numpy(r_true),
                torch_to_numpy(mask),
            )
        )
        # compute winkler score
        results[f"{loader_name}_winkler_at_{(1 - target_alpha) * 100}"] = (
            masked_winkler_score(
                torch_to_numpy(pi_interval),
                torch_to_numpy(r_true),
                torch_to_numpy(mask),
                alpha=target_alpha,
            )
        )
    return results
