from typing import Union
import numpy as np
import torch

try:
    from . import dissonance_cuda

    def dissonance_gpu(belief):
        """
        Compute per-sample dissonance with the compiled ``dissonance_cuda`` extension.

        Args:
            belief: (N, K) tensor of per-class belief masses. It is moved to the
                current CUDA device and made contiguous before being handed to
                the extension. The output buffer is float32, so float32 input is
                presumably what the kernel expects -- TODO confirm.

        Returns:
            ans_dissonance: (N, ) float32 tensor, moved back to ``belief``'s
                original device.
        """
        device = belief.device
        num_classes = belief.shape[1]
        # Zero-initialised output buffer on the current CUDA device (the same
        # device ``belief.cuda()`` targets); presumably filled in place by the
        # extension. Replaces the deprecated ``torch.cuda.FloatTensor(n).zero_()``.
        ans_dissonance = torch.zeros(belief.shape[0], device='cuda', dtype=torch.float32)
        dissonance_cuda.dissonance_gpu(belief.cuda().contiguous(), ans_dissonance, num_classes)
        return ans_dissonance.to(device)

except ImportError as e:
    # The compiled extension is optional: on failure the error is reported and
    # ``dissonance_gpu`` is simply left undefined in this module.
    print(e)
    print("cannot import name 'dissonance_cuda' from 'edl.ops.dissonance'")


def dissonance_cpu(belief):
    """
    Reference (host-side) implementation of per-sample belief dissonance.

    For each row of belief masses b_1..b_K this computes

        diss(b) = sum_k  b_k * sum_{j != k} b_j * Bal(b_j, b_k)
                         / (sum_{j != k} b_j + 1e-9)

    where ``Bal`` is the relative mass balance function ``rmbf`` below. Rows
    whose mass for a class is exactly 1 skip that class's term.

    Args:
        belief: (N, K) array-like or ``torch.Tensor`` of per-class belief masses.
            Torch tensors are detached and copied to host memory first, so CUDA
            tensors and tensors requiring grad are accepted.

    Returns:
        ans_dissonance: (N, ) float64 numpy array.
    """
    def rmbf(bj: Union[float, int], bk: Union[float, int]) -> Union[float, int]:
        """
        relative mass balance function: Bal(bj, bk)

        :return: 0 if either mass is 0, otherwise 1 - |bj - bk| / (bj + bk).
        """
        if min(bj, bk) == 0:
            return 0
        return 1 - abs(bj - bk) / (bj + bk)

    def diss(w: list) -> float:
        """
        Dissonance of a single belief vector.

        :param w: list of K belief masses for one sample.
        :return: the dissonance of ``w``.
        """
        total = sum(w)  # hoisted: invariant across k
        dissonance = 0.0
        for k, bk in enumerate(w):
            if bk == 1:
                # A class holding all the mass; presumably every other mass is
                # then 0, so this term would vanish anyway.
                continue
            numerator = sum(bj * rmbf(bj, bk) for j, bj in enumerate(w) if j != k)
            numerator *= bk
            # 1e-9 avoids division by zero when all other masses are 0.
            denominator = total - bk + 1e-9
            dissonance += numerator / denominator
        return dissonance

    if isinstance(belief, torch.Tensor):
        belief = belief.detach().cpu().numpy()
    rows = np.asarray(belief, dtype=np.float64)
    # .tolist() gives plain Python floats, which are faster than numpy scalars
    # in the scalar loops above.
    return np.array([diss(row.tolist()) for row in rows], dtype=np.float64)

if __name__ == '__main__':
    # This module has no command-line entry point; it is meant to be imported.
    ...