import numpy as np
import torch


def get_avg_f1(y_pred_lst, y_lst):
    """Return the mean of ``f1_max`` evaluated separately on every column.

    Column ``i`` of each input is sliced out with ``[:, i]``, converted to a
    tensor, reshaped into a single-column matrix and handed to ``f1_max``
    (``y_pred_lst`` as the scores, ``y_lst`` as the targets). The per-column
    results are then averaged with NumPy.

    Args:
        y_pred_lst: predicted scores; must expose ``.shape[1]`` and 2-D
            ``[:, i]`` indexing (presumably a NumPy array -- confirm
            against callers).
        y_lst: targets, indexed the same way as ``y_pred_lst``.

    Returns:
        The NumPy mean of the per-column ``f1_max`` values.
    """

    def _column_score(col):
        # Slice the target first, then the prediction, as a pair of columns.
        labels = y_lst[:, col]
        scores = y_pred_lst[:, col]
        column_f1 = f1_max(
            torch.tensor(scores).unsqueeze(dim=1),
            torch.tensor(labels).unsqueeze(dim=1),
        )
        return column_f1.numpy()

    per_column = [_column_score(col) for col in range(y_pred_lst.shape[1])]
    return np.mean(np.array(per_column))


def f1_max(pred, target):
    """
    Return the best F1 score over every cut-off of the globally ranked scores.

    Code from https://torchdrug.ai/docs/_modules/torchdrug/metrics/metric.html#f1_max

    Entries are admitted one at a time in order of decreasing score across the
    whole of ``pred``. After each admission, precision is averaged over the
    rows that already have at least one admitted entry, recall is averaged
    over all rows, and F1 is computed from those two averages.

    Args:
        pred: score tensor; it is sorted and indexed along dim 1, so it is
            assumed to be 2-D (rows x columns) -- confirm against callers.
        target: label tensor gathered with the row-wise ranking of ``pred``;
            assumed to have the same shape as ``pred`` and presumably to
            hold 0/1 labels -- confirm against callers.

    Returns:
        A 0-dimensional tensor holding the maximum F1 value.
    """
    # Per-row ranking: reorder each row's labels by decreasing score.
    row_rank = pred.argsort(descending=True, dim=1)
    ranked_target = target.gather(1, row_rank)

    # Row-wise precision / recall after taking the top-k entries of that row.
    hits = ranked_target.cumsum(1)
    row_precision = hits / torch.ones_like(ranked_target).cumsum(1)
    row_recall = hits / (ranked_target.sum(1, keepdim=True) + 1e-10)

    # Flag, in the original column layout, the highest-scored entry of each row.
    row_head = torch.zeros_like(ranked_target).bool()
    row_head[:, 0] = True
    row_head = torch.scatter(row_head, 1, row_rank, row_head)

    # Global ranking over every entry of the flattened score matrix.
    global_rank = pred.flatten().argsort(descending=True)

    # Flat original index of each entry of the row-sorted layout ...
    row_offset = (
        torch.arange(row_rank.shape[0], device=row_rank.device).unsqueeze(1)
        * row_rank.shape[1]
    )
    flat_rank = (row_rank + row_offset).flatten()
    # ... and its inverse: original flat index -> position in row-sorted layout.
    to_sorted = torch.zeros_like(flat_rank)
    to_sorted[flat_rank] = torch.arange(
        flat_rank.shape[0], device=flat_rank.device
    )

    # Re-express the flags and the global ranking in row-sorted coordinates.
    row_head = row_head.flatten()[global_rank]
    global_rank = to_sorted[global_rank]

    def _step(curve):
        # Change in a row's metric caused by admitting one more of its entries.
        # The first entry of a row has no predecessor, so its baseline is 0
        # (the ``- 1`` lookup is discarded by ``where`` in that case).
        baseline = torch.where(
            row_head, torch.zeros_like(curve), curve[global_rank - 1]
        )
        return curve[global_rank] - baseline

    # Precision: running sum of row precisions / number of rows seen so far.
    mean_precision = _step(row_precision.flatten()).cumsum(0) / row_head.cumsum(0)
    # Recall: running sum of row recalls / total number of rows.
    mean_recall = _step(row_recall.flatten()).cumsum(0) / pred.shape[0]

    f1_curve = (
        2 * mean_precision * mean_recall / (mean_precision + mean_recall + 1e-10)
    )
    return f1_curve.max()
