import einops
import torch
import torch.nn.functional as F

# original: https://github.com/google-research/google-research/blob/e77c5d6a164b060c5d21417c264c2ead505302f6/invariant_slot_attention/lib/metrics.py#L112
# slightly modified from https://github.com/zpbao/Discovery_Obj_Move/blob/main/utils.py

def preprocess_and_batch_aris(true_masks, soft_pred_masks):
    """Harden soft slot masks and score them against the ground truth with ARI.

    Args:
        true_masks: Ground-truth masks laid out as (b, o, h, w), where ``o``
            is the number of objects.
        soft_pred_masks: Predicted soft masks laid out as (b, s, h, w), where
            ``s`` is the number of slots.

    Returns:
        A tuple ``(batch_fg_ari, batch_ari)``: the results of
        ``adjusted_rand_index`` called with ``ignore_background=True`` and
        ``ignore_background=False`` respectively.
    """
    # Merge the spatial grid into a single point axis and move groups last.
    gt_points = einops.rearrange(true_masks, 'b o h w -> b (h w) o')
    slot_points = einops.rearrange(soft_pred_masks, 'b s h w -> b (h w) s')

    # Hard assignment: every point goes to the slot with the largest value.
    winning_slot = slot_points.argmax(dim=-1)
    n_slots = slot_points.shape[-1]
    hard_points = F.one_hot(winning_slot, num_classes=n_slots).to(slot_points.dtype)

    # Same inputs scored twice: foreground-only first, then all points.
    fg_score, full_score = (
        adjusted_rand_index(gt_points, hard_points, ignore_background=skip_bg)
        for skip_bg in (True, False)
    )
    return fg_score, full_score
    

def adjusted_rand_index(true_mask_oh, pred_mask_oh, ignore_background=True):
    """Compute the Adjusted Rand Index (ARI) per batch element.

    Args:
        true_mask_oh: Tensor of shape (batch, n_points, n_true_groups) holding
            the ground-truth group membership of every point. Expected to be
            one-hot along the last axis (not verified here). Any dtype is
            accepted; non-floating inputs are cast to a floating dtype.
        pred_mask_oh: Tensor of shape (batch, n_points, n_pred_groups) holding
            the predicted group membership of every point, same expectations
            as ``true_mask_oh``.
        ignore_background: If True, channel 0 of ``true_mask_oh`` is dropped,
            so points belonging to it do not contribute to the score.

    Returns:
        Tensor of shape (batch,) with the ARI of each sample. Samples whose
        ARI denominator is zero are assigned a score of 1.0.

    Raises:
        AssertionError: If ``n_points`` is not larger than at least one of the
            two group counts.
    """
    _, n_points, n_true_groups = true_mask_oh.shape
    n_pred_groups = pred_mask_oh.shape[-1]
    assert not (n_points <= n_true_groups and n_points <= n_pred_groups), ("adjusted_rand_index requires n_groups < n_points. We don't handle the special cases that can occur when you have one cluster per datapoint.")

    if ignore_background:
        true_mask_oh = true_mask_oh[..., 1:]

    # einsum needs both operands in the same dtype, and the divisions below
    # need a floating one; ground-truth masks are often bool or integer.
    work_dtype = torch.promote_types(true_mask_oh.dtype, pred_mask_oh.dtype)
    if not work_dtype.is_floating_point:
        work_dtype = torch.float32
    true_mask_oh = true_mask_oh.to(work_dtype)
    pred_mask_oh = pred_mask_oh.to(work_dtype)

    # Contingency table: nij[b, k, i] = number of points in true group k and
    # predicted group i.
    nij = torch.einsum('bji,bjk->bki', pred_mask_oh, true_mask_oh)
    a = torch.sum(nij, dim=1)  # points per predicted group
    b = torch.sum(nij, dim=2)  # points per true group

    # Number of points that actually entered the table. With the background
    # dropped this is smaller than n_points, and the expected index has to be
    # normalised by the counted points rather than by all of them.
    counted_points = torch.sum(a, dim=1)

    rindex = torch.sum(nij * (nij - 1), dim=[1, 2])
    aindex = torch.sum(a * (a - 1), dim=1)
    bindex = torch.sum(b * (b - 1), dim=1)
    # Clamp so that samples with zero or one counted point do not divide by
    # zero; their numerator is zero anyway.
    n_pairs = torch.clamp(counted_points * (counted_points - 1), min=1)
    expected_rindex = aindex * bindex / n_pairs
    max_rindex = (aindex + bindex) / 2
    denominator = max_rindex - expected_rindex
    ari = (rindex - expected_rindex) / denominator

    # There are two cases for which the denominator can be zero:
    # 1. If both label_pred and label_true assign all pixels to a single cluster.
    #    (max_rindex == expected_rindex == rindex == num_points * (num_points-1))
    # 2. If both label_pred and label_true assign max 1 point to each cluster.
    #    (max_rindex == expected_rindex == rindex == 0)
    # In both cases, we want the ARI score to be 1.0:
    return torch.where(denominator == 0, torch.ones_like(ari), ari)