# https://github.com/Wuziyi616/SlotDiffusion/blob/main/slotdiffusion/img_based/models/eval_utils.py#L190

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.ops as vops
from scipy.optimize import linear_sum_assignment

###########################################
# Image quality-related metrics
###########################################


def mse_metric(x, y):
    """Reconstruction MSE: per-sample summed squared error, batch-averaged.

    x/y: [B, 3, H, W], np.ndarray in [0, 1].

    The squared error is summed over the last three axes (channel + spatial
    dimensions, as is common for reconstruction MSE) and the resulting
    per-sample sums are then averaged.
    """
    squared_error = (x - y) ** 2
    # collapse the three trailing axes one at a time -> one value per sample
    per_sample = squared_error.sum(-1).sum(-1).sum(-1)
    return per_sample.mean()


def perceptual_dist(x, y, loss_fn):
    """Average the output of `loss_fn(x, y)` and return it as a Python number.

    x/y: [B, 3, H, W], torch.Tensor in [-1, 1].
    loss_fn: callable invoked as `loss_fn(x, y)`; its result must support
        `.mean().item()` (i.e. behave like a torch.Tensor).
    """
    distances = loss_fn(x, y)
    averaged = distances.mean()
    return averaged.item()


###########################################
# Segmentation-related metrics
###########################################


def preproc_masks_overlap(gt_mask, pred_mask, inst_overlap_mask=None):
    """Pre-process masks to handle overlapping instances.

    From [DINOSAUR](https://arxiv.org/pdf/2209.14860.pdf), set the overlapping
        pixels to background, for both pred and gt masks. Only on COCO dataset.
    """
    if inst_overlap_mask is None:
        return gt_mask, pred_mask
    # all masks are of shape [H*W] or [H, W], dtype int
    assert torch.unique(inst_overlap_mask).shape[0] <= 2  # at most [0, 1]
    # set gt overlapping pixels to background
    gt_mask = gt_mask.clone()
    gt_mask[inst_overlap_mask == 1] = 0
    # set pred overlapping pixels to another class
    pred_mask = pred_mask.clone()
    pred_mask[inst_overlap_mask == 1] = pred_mask.max() + 1
    return gt_mask, pred_mask


def adjusted_rand_index(true_ids, pred_ids, ignore_background=False):
    """Adjusted Rand index (ARI) between two integer-id segmentations.

    Code borrowed from https://github.com/google-research/slot-attention-video/blob/e8ab54620d0f1934b332ddc09f1dba7bc07ff601/savi/lib/metrics.py#L111

    Args:
        true_ids: Integer tensor of shape [batch_size, seq_len, H, W] holding
            the true cluster id of every pixel. A 3-D input is given a
            singleton dimension at position 1 before processing.
        pred_ids: Integer tensor of shape [batch_size, seq_len, H, W] holding
            the predicted cluster id of every pixel. A 3-D input is handled
            the same way as `true_ids`.
        ignore_background: Boolean, if True, pixels with true_ids == 0 are
            left out of the score (default: False).

    Returns:
        ARI scores as a float32 tensor of shape [batch_size].
    """
    if true_ids.ndim == 3:
        true_ids = true_ids.unsqueeze(1)
    if pred_ids.ndim == 3:
        pred_ids = pred_ids.unsqueeze(1)

    true_onehot = F.one_hot(true_ids).float()
    if ignore_background:
        # Drop the background (id 0) channel of the true one-hot encoding.
        true_onehot = true_onehot[..., 1:]
    pred_onehot = F.one_hot(pred_ids).float()

    # contingency[b, c, k]: #pixels with true cluster c and pred cluster k
    contingency = torch.einsum("bthwc,bthwk->bck", true_onehot, pred_onehot)
    true_sizes = contingency.sum(dim=-1)  # (batch_size, c)
    pred_sizes = contingency.sum(dim=-2)  # (batch_size, k)
    total = true_sizes.sum(dim=1)

    pairs_joint = (contingency * (contingency - 1)).sum(dim=[1, 2])
    pairs_true = (true_sizes * (true_sizes - 1)).sum(dim=1)
    pairs_pred = (pred_sizes * (pred_sizes - 1)).sum(dim=1)
    pairs_total = torch.clamp(total * (total - 1), min=1)

    expected = pairs_true * pairs_pred / pairs_total
    upper = (pairs_true + pairs_pred) / 2
    denom = upper - expected
    score = (pairs_joint - expected) / denom

    # The denominator vanishes in two situations:
    # 1. Both labelings put every pixel into one single cluster.
    # 2. Both labelings put at most one pixel into each cluster.
    # Either way the labelings agree, so the score is defined as 1.0.
    perfect = torch.tensor(1.).type_as(score)
    return torch.where(denom != 0, score, perfect)


def ARI_metric(x, y, inst_overlap_mask=None):
    """Batch-mean ARI (background included) as a Python float.

    x/y: [B, H, W], both are seg_masks after argmax; `x` is passed as the
    true ids and `y` as the predicted ids. If `inst_overlap_mask` is given,
    each sample is first passed through `preproc_masks_overlap` on copies of
    the inputs.
    """
    assert 'int' in str(x.dtype)
    assert 'int' in str(y.dtype)
    if inst_overlap_mask is not None:
        # work on copies so the caller's tensors stay untouched
        x = x.clone()
        y = y.clone()
        for idx in range(x.shape[0]):
            gt_i, pred_i = preproc_masks_overlap(x[idx], y[idx],
                                                 inst_overlap_mask[idx])
            x[idx] = gt_i
            y[idx] = pred_i
    scores = adjusted_rand_index(x, y, ignore_background=False)
    return scores.mean().item()


def fARI_metric(x, y, inst_overlap_mask=None):
    """Batch-mean foreground ARI (true background ignored) as a Python float.

    x/y: [B, H, W], both are seg_masks after argmax; `x` is passed as the
    true ids and `y` as the predicted ids. If `inst_overlap_mask` is given,
    each sample is first passed through `preproc_masks_overlap` on copies of
    the inputs.
    """
    assert 'int' in str(x.dtype)
    assert 'int' in str(y.dtype)
    if inst_overlap_mask is not None:
        # work on copies so the caller's tensors stay untouched
        x = x.clone()
        y = y.clone()
        for idx in range(x.shape[0]):
            gt_i, pred_i = preproc_masks_overlap(x[idx], y[idx],
                                                 inst_overlap_mask[idx])
            x[idx] = gt_i
            y[idx] = pred_i
    scores = adjusted_rand_index(x, y, ignore_background=True)
    return scores.mean().item()


def bbox_precision_recall(gt_pres_mask, gt_bbox, pred_bbox, ovthresh=0.5):
    """Compute precision and recall of predicted bounding boxes.

    Each present ground-truth box is compared with its highest-IoU predicted
    box; it counts as a true positive when that IoU is at least `ovthresh`
    and the predicted box has not already been claimed by an earlier
    ground-truth box.

    Args:
        gt_pres_mask: A boolean tensor of shape [N]; only gt boxes whose entry
            is True are evaluated.
        gt_bbox: A tensor of shape [N, 4]
        pred_bbox: A tensor of shape [M, 4]; rows whose first coordinate is
            negative are discarded (presumably padding -- confirm with the
            caller).
        ovthresh: Minimum IoU for a match to count as a true positive.

    Returns:
        A tuple `(precision, recall)` of Python floats, where precision is
        #true-positives / #kept predicted boxes and recall is
        #true-positives / #present gt boxes. If either set of boxes is empty
        after filtering there can be no true positive; `(0.0, 0.0)` is
        returned in that case instead of dividing by zero.
    """
    # Boolean-mask indexing already returns new tensors, so the inputs are
    # never modified.
    gt_bbox = gt_bbox[gt_pres_mask.bool()]
    pred_bbox = pred_bbox[pred_bbox[:, 0] >= 0.]
    assert gt_bbox.shape[1] == pred_bbox.shape[1] == 4
    N, M = gt_bbox.shape[0], pred_bbox.shape[0]
    if N == 0 or M == 0:
        # No possible match; an undefined 0/0 ratio is reported as 0.
        return 0., 0.

    bbox_ious = vops.box_iou(gt_bbox, pred_bbox)  # [N, M]
    best_ious, best_idxs = bbox_ious.max(dim=1)

    # Find the best iou match for each ground truth bbox.
    tp = 0
    bbox_used = [False] * M
    for best_iou, best_idx in zip(best_ious.tolist(), best_idxs.tolist()):
        if best_iou >= ovthresh and not bbox_used[best_idx]:
            tp += 1
            bbox_used[best_idx] = True

    # compute precision and recall
    precision = tp / float(M)
    recall = tp / float(N)
    return precision, recall


def batch_bbox_precision_recall(gt_pres_mask, gt_bbox, pred_bbox):
    """Average per-sample bounding-box precision and recall over a batch.

    All three inputs are indexed along their first (batch) dimension and
    every sample is scored with `bbox_precision_recall`. Returns the tuple
    `(mean precision, mean recall)` computed with `np.mean`.
    """
    batch_size = gt_pres_mask.shape[0]
    per_sample = [
        bbox_precision_recall(gt_pres_mask[b], gt_bbox[b], pred_bbox[b])
        for b in range(batch_size)
    ]
    precisions = [prec for prec, _ in per_sample]
    recalls = [rec for _, rec in per_sample]
    return np.mean(precisions), np.mean(recalls)


def hungarian_miou(gt_mask, pred_mask, ignore_background=True):
    """Mean IoU under the optimal one-to-one (Hungarian) gt/pred matching.

    both mask: [H*W] after argmax, 0 is gt background index.

    Args:
        gt_mask: Integer tensor of shape [H*W] with ground-truth ids.
        pred_mask: Integer tensor of shape [H*W] with predicted ids.
        ignore_background: If True, gt id 0 is excluded from the matching.

    Returns:
        The sum of matched IoUs divided by the number of gt classes that
        actually occur in `gt_mask` (after optionally removing background);
        gt classes left unmatched because there are fewer predicted classes
        contribute an IoU of 0. Returns `np.nan` if `ignore_background` is
        True and `gt_mask` only contains background.
    """
    # in case GT only contains bg class
    if gt_mask.max().item() == 0 and ignore_background:
        return np.nan
    true_oh = F.one_hot(gt_mask).float()  # [HW, max_id + 1]
    if ignore_background:
        true_oh = true_oh[..., 1:]  # only foreground
    # `one_hot` creates a column for every id up to the maximum, including ids
    # that never occur in this mask. Such empty columns are not objects and
    # would be scored as misses with IoU 0, so drop them.
    true_oh = true_oh[:, true_oh.sum(0) > 0]  # [HW, N]
    pred_oh = F.one_hot(pred_mask).float()  # [HW, M]
    N = true_oh.shape[-1]
    # compute all pairwise IoU; the matmul counts co-occurring pixels without
    # materializing a [HW, N, M] intermediate
    intersect = true_oh.t() @ pred_oh  # [N, M]
    union = true_oh.sum(0)[:, None] + pred_oh.sum(0)[None] - intersect  # same
    iou = intersect / (union + 1e-8)  # [N, M]
    iou = iou.detach().cpu().numpy()
    # find the best one-to-one match between gt and pred
    row_ind, col_ind = linear_sum_assignment(iou, maximize=True)
    # Dividing by N covers both cases: with M >= N every gt is matched, so
    # this is the plain mean; with M < N the undetected gts count as IoU 0.
    return iou[row_ind, col_ind].sum() / float(N)


def mean_best_overlap(gt_mask, pred_mask):
    """Mean best overlap (mBO): average over gt objects of their best IoU.

    both mask: [H*W] after argmax, 0 is gt background index.

    From [DINOSAUR](https://arxiv.org/pdf/2209.14860.pdf), ignore background.

    Returns:
        The mean, over all foreground gt ids that occur in `gt_mask`, of the
        highest IoU with any predicted mask (a predicted mask may be the best
        match of several gt objects). Returns `np.nan` if `gt_mask` only
        contains background.
    """
    # in case GT only contains bg class
    if gt_mask.max().item() == 0:
        return np.nan
    true_oh = F.one_hot(gt_mask).float()  # [HW, max_id + 1]
    # quote from email exchange with the authors:
    # > the ground truth mask corresponding to all background pixels is not
    # > used for matching
    true_oh = true_oh[..., 1:]  # only foreground
    # `one_hot` creates a column for every id up to the maximum, including ids
    # that never occur in this mask. Such empty columns are not objects and
    # would be scored with a best IoU of 0, so drop them.
    true_oh = true_oh[:, true_oh.sum(0) > 0]  # [HW, N]
    pred_oh = F.one_hot(pred_mask).float()  # [HW, M]
    # compute all pairwise IoU; the matmul counts co-occurring pixels without
    # materializing a [HW, N, M] intermediate
    intersect = true_oh.t() @ pred_oh  # [N, M]
    union = true_oh.sum(0)[:, None] + pred_oh.sum(0)[None] - intersect  # same
    iou = intersect / (union + 1e-8)  # [N, M]
    iou = iou.detach().cpu().numpy()
    # find the best pred_mask for each gt_mask
    # each pred_mask can be used multiple times
    best_iou = iou.max(1)
    return best_iou.mean()


def miou_metric(gt_mask, pred_mask, inst_overlap_mask=None):
    """Batch-averaged Hungarian mIoU with the background class included.

    both mask: [B, H, W], both are seg_masks after argmax.

    Every sample is flattened to [H*W], optionally passed through
    `preproc_masks_overlap`, and scored by `hungarian_miou` with
    `ignore_background=False`. NaN samples are skipped in the average; if all
    samples are NaN a warning is printed and NaN is returned.
    """
    assert 'int' in str(gt_mask.dtype)
    assert 'int' in str(pred_mask.dtype)
    batch_size = gt_mask.shape[0]
    if inst_overlap_mask is not None:
        overlaps = inst_overlap_mask.flatten(1, 2)
    else:
        # one placeholder per sample so the indexing below stays uniform
        overlaps = [None] * batch_size
    flat_gt = gt_mask.flatten(1, 2)
    flat_pred = pred_mask.flatten(1, 2)
    scores = []
    for b in range(batch_size):
        gt_b, pred_b = preproc_masks_overlap(flat_gt[b], flat_pred[b],
                                             overlaps[b])
        scores.append(hungarian_miou(gt_b, pred_b, ignore_background=False))
    if all(np.isnan(score) for score in scores):
        print('WARNING: all miou in a batch are nan')
        return np.nan
    return np.nanmean(scores)


def fmiou_metric(gt_mask, pred_mask, inst_overlap_mask=None):
    """Batch-averaged Hungarian mIoU over foreground classes only.

    both mask: [B, H, W], both are seg_masks after argmax.

    Every sample is flattened to [H*W], optionally passed through
    `preproc_masks_overlap`, and scored by `hungarian_miou` with
    `ignore_background=True`. NaN samples are skipped in the average; if all
    samples are NaN a warning is printed and NaN is returned.
    """
    assert 'int' in str(gt_mask.dtype)
    assert 'int' in str(pred_mask.dtype)
    batch_size = gt_mask.shape[0]
    if inst_overlap_mask is not None:
        overlaps = inst_overlap_mask.flatten(1, 2)
    else:
        # one placeholder per sample so the indexing below stays uniform
        overlaps = [None] * batch_size
    flat_gt = gt_mask.flatten(1, 2)
    flat_pred = pred_mask.flatten(1, 2)
    scores = []
    for b in range(batch_size):
        gt_b, pred_b = preproc_masks_overlap(flat_gt[b], flat_pred[b],
                                             overlaps[b])
        scores.append(hungarian_miou(gt_b, pred_b, ignore_background=True))
    if all(np.isnan(score) for score in scores):
        print('WARNING: all miou in a batch are nan')
        return np.nan
    return np.nanmean(scores)


def mbo_metric(gt_mask, pred_mask, inst_overlap_mask=None):
    """Batch-averaged mean best overlap (mBO).

    both mask: [B, H, W], both are seg_masks after argmax.

    Every sample is flattened to [H*W], optionally passed through
    `preproc_masks_overlap`, and scored by `mean_best_overlap`. NaN samples
    are skipped in the average; if all samples are NaN a warning is printed
    and NaN is returned.
    """
    assert 'int' in str(gt_mask.dtype)
    assert 'int' in str(pred_mask.dtype)
    batch_size = gt_mask.shape[0]
    if inst_overlap_mask is not None:
        overlaps = inst_overlap_mask.flatten(1, 2)
    else:
        # one placeholder per sample so the indexing below stays uniform
        overlaps = [None] * batch_size
    flat_gt = gt_mask.flatten(1, 2)
    flat_pred = pred_mask.flatten(1, 2)
    scores = []
    for b in range(batch_size):
        gt_b, pred_b = preproc_masks_overlap(flat_gt[b], flat_pred[b],
                                             overlaps[b])
        scores.append(mean_best_overlap(gt_b, pred_b))
    if all(np.isnan(score) for score in scores):
        print('WARNING: all mbo in a batch are nan')
        return np.nan
    return np.nanmean(scores)
