import torch
from torch import Tensor
import torch.nn.functional as F

from criterion.dice_loss import multiclass_dice_coeff


def dice_score(masks_pred, true_masks):
    """Compute a batch-weighted foreground Dice score for a segmentation batch.

    Both inputs are converted to one-hot ``(N, C, H, W)`` float tensors: the
    prediction through an ``argmax`` over dim 1, the ground truth directly
    from its class indices. Channel 0 (treated as background) is dropped
    before the tensors are handed to ``multiclass_dice_coeff``.

    Args:
        masks_pred: 4-D tensor of per-class scores laid out as
            ``(N, C, H, W)``; dim 1 is used as the class dimension.
        true_masks: Tensor of class indices that can be reshaped to
            ``(-1, H, W)`` (e.g. ``(N, H, W)`` or ``(N, 1, H, W)``). Any
            integer or float dtype is accepted; values are cast to int64.
            Values must lie in ``[0, C)`` -- ``F.one_hot`` rejects labels
            outside that range.

    Returns:
        A tuple ``(weighted_score, batch_size)`` where ``weighted_score`` is
        the value returned by ``multiclass_dice_coeff`` multiplied by
        ``batch_size``, and ``batch_size`` is the leading dimension of the
        reshaped ground truth. NOTE(review): presumably callers sum both
        entries across batches and divide to get a sample-weighted mean --
        confirm against the call site.
    """
    num_classes = masks_pred.size(1)
    height, width = masks_pred.size(2), masks_pred.size(3)

    # F.one_hot only accepts int64 index tensors; masks loaded as uint8/int32
    # (or float) would otherwise raise. The cast is a no-op for int64 input.
    true_labels = true_masks.reshape([-1, height, width]).long()
    true_onehot = F.one_hot(true_labels, num_classes).permute(0, 3, 1, 2).float()

    # Collapse the per-class scores to a hard prediction, then one-hot encode
    # it so prediction and ground truth share the same layout.
    pred_labels = masks_pred.argmax(dim=1)
    pred_onehot = F.one_hot(pred_labels, num_classes).permute(0, 3, 1, 2).float()

    # Compute the Dice score, ignoring background (channel 0).
    d_s = multiclass_dice_coeff(
        pred_onehot[:, 1:, ...],
        true_onehot[:, 1:, ...],
        reduce_batch_first=False,
    )

    batch_size = true_onehot.shape[0]
    return d_s * batch_size, batch_size
