from scipy.optimize import linear_sum_assignment
import lpips
import numpy as np
import torch as pt
import torch.nn as nn
import torch.nn.functional as ptnf


class MetricWrap:
    """Evaluate a set of named metrics on a model ``output`` dict and a ``batch`` dict.

    Every keyword argument given to the constructor is one metric specification, a dict with
    - ``"map"``: ``{"output": {kwarg: output_key}, "batch": {kwarg: batch_key}}``, i.e. which
      entries of ``output`` / ``batch`` are handed to the metric and under which argument name;
    - ``"metric"``: callable invoked with those keyword arguments;
    - ``"transform"`` (optional): callable invoked with the keyword arguments first; whatever
      it returns is unpacked as the keyword arguments of ``"metric"``;
    - ``"weight"`` (optional): factor multiplied onto the metric value.

    If ``detach`` is true, ``.detach()`` is called on every mapped value before use.
    """

    def __init__(self, detach=False, **metrics):
        self.metrics = metrics
        self.detach = detach

    # @pt.compile
    def __call__(self, output: dict, batch: dict) -> dict:
        """Return ``{metric name: (weighted) metric value}`` for every configured metric."""
        results = {}
        for name, spec in self.metrics.items():
            mapping = spec["map"]
            # entries from ``batch`` are added last, so they win on a duplicated argument name
            kwargs = {arg: output[src] for arg, src in mapping["output"].items()}
            kwargs.update({arg: batch[src] for arg, src in mapping["batch"].items()})
            if self.detach:
                # detach only; moving to cpu here would be slow
                kwargs = {arg: val.detach() for arg, val in kwargs.items()}
            if "transform" in spec:
                kwargs = spec["transform"](**kwargs)
            score = spec["metric"](**kwargs)
            if "weight" in spec:
                score = score * spec["weight"]
            results[name] = score
        return results


class CrossEntropyLoss:
    """Callable wrapper around ``ptnf.cross_entropy``. Tensors are in shape (b,c,..).

    ``weight`` and ``reduce`` are forwarded as the ``weight`` and ``reduction`` arguments.
    """

    def __init__(self, weight=None, reduce="mean"):
        self.reduce = reduce
        self.weight = weight

    def __call__(self, input, target):
        """Return the cross entropy between ``input`` and ``target``."""
        return ptnf.cross_entropy(
            input, target, weight=self.weight, reduction=self.reduce
        )


class CrossEntropyLossGrouped(CrossEntropyLoss):
    """Sum of one cross entropy per channel group.

    ``groups`` lists the number of classes of every group; the channel dimension of the
    input is consumed group after group in that order.
    """

    def __init__(self, groups, weight=None, reduce="mean"):
        super().__init__(weight, reduce)
        self.groups = groups

    def __call__(self, input, target):
        """
        input: shape=(b,sum(self.groups),h,w); per-group class scores, concatenated on dim 1
        target: shape=(b,g,h,w); one class index per group (tuple format)
        return: sum over groups of the base-class cross entropy
        """
        assert input.ndim == target.ndim
        losses = []
        offset = 0
        for gidx, width in enumerate(self.groups):  # or range(target.size(1)):
            scores = input[:, offset : offset + width]
            losses.append(super().__call__(scores, target[:, gidx]))
            offset += width
        # assert offset == input.size(1)
        return sum(losses)


class Entropy:
    """Entropy (in bits) of the softmax over the per-channel mean of a (b,c,h,w) tensor.

    Only ``dim=1`` and ``reduce="mean"`` are accepted. With ``binariz`` the input is first
    replaced by its one-hot argmax along dim 1, using a straight-through construction so
    that gradients still flow to the original input.
    """

    def __init__(self, dim=1, binariz=True, reduce="mean"):
        assert dim == 1
        assert reduce == "mean"
        self.binariz = binariz
        self.dim = dim

    def __call__(self, input):
        """
        input: in shape (b,c,h,w)
        return: scalar entropy over the c channels
        """
        if self.binariz:
            hard = pt.zeros_like(input)
            hard.scatter_(1, input.argmax(1, keepdim=True), 1.0)
            # straight-through: forward value is the one-hot, gradient is that of ``input``
            input = hard + input - input.detach()
        usage = input.mean((0, 2, 3))  # (c,)
        prob = usage.softmax(0)  # , usage.dtype
        return Entropy.entropy(prob, 0)

    @staticmethod
    def entropy(prob, dim):
        """Return ``-sum(prob * log2(prob))`` along ``dim``."""
        weighted = prob * prob.log2()
        return -weighted.sum(dim)


class EntropyGrouped(Entropy):
    """Sum of the base-class entropy over consecutive channel groups.

    ``groups`` lists the channel count of every group; together they must cover dim 1 of
    the input exactly.
    """

    def __init__(self, groups, dim=1, binariz=True, reduce="mean"):
        super().__init__(dim, binariz, reduce)
        self.groups = groups

    def __call__(self, input):
        """
        input: in shape (b,sum(self.groups),h,w)
        return: unweighted sum of the per-group entropies
        """
        parts = []
        lo = 0
        for width in self.groups:
            hi = lo + width
            # every group contributes with coefficient 1 (no ``width / sum(groups)`` weighting)
            parts.append(super().__call__(input[:, lo:hi, :, :]))
            lo = hi
        assert hi == input.size(1)
        return sum(parts)


class CosineSimilarity:
    """Mean pairwise cosine similarity of a tensor with itself.

    The input is L2-normalized along ``dim``; ``pattern`` is the ``einsum`` equation applied
    to two copies of the normalized tensor. Only ``reduce="mean"`` is accepted.
    """

    def __init__(self, dim, pattern, reduce="mean"):
        assert reduce == "mean"
        self.dim = dim
        self.pattern = pattern

    def __call__(self, input):
        """Return the mean over all entries of ``einsum(pattern, x, x)`` with unit-norm ``x``."""
        length = pt.norm(input, p=2, dim=self.dim, keepdim=True, dtype=input.dtype)
        unit = input / length
        return pt.einsum(self.pattern, unit, unit).mean()


class CodebookCosineSimilarity:
    """Average pairwise cosine similarity between the code vectors of several codebooks."""

    def __call__(self, input):
        """
        input: ``nn.ModuleList`` whose members expose a 2-D ``weight`` of shape (n,c)
        return: mean over the members of their codebook cosine similarity
        raises: NotImplementedError for any other input type
        """
        if not isinstance(input, nn.ModuleList):
            # raising a plain string is itself a TypeError; use a real exception class
            raise NotImplementedError(
                f"expected nn.ModuleList, got {type(input).__name__}"
            )
        simis = [
            CodebookCosineSimilarity.calculate_codebook_cosine_similarity(m.weight)
            for m in input
        ]
        return sum(simis) / len(input)

    @staticmethod
    def calculate_codebook_cosine_similarity(x):
        """
        x: in shape (n,c); rows are L2-normalized before comparison
        return: mean of the (n,n) matrix of row-to-row cosine similarities
        """
        x = x / pt.norm(x, p=2, dim=1, keepdim=True, dtype=x.dtype)
        return pt.einsum("nc,mc->nm", x, x).mean()


class UtilizLoss:
    """Maximize VAE codebook utilization.

    The loss is the coefficient of variation (std / mean) of the per-channel usage, i.e. of
    the input averaged over batch and spatial dimensions; it is zero when all channels are
    used equally.
    """

    def __init__(self, normaliz=True, binariz=True):
        self.binariz = binariz
        self.normaliz = normaliz

    def __call__(self, input):  # target=None
        """
        input: in shape (b,c,h,w)
        return: scalar coefficient of variation over the c channels
        """
        if self.normaliz:
            input = input.softmax(1)  # , input.dtype
        if self.binariz:
            hard = pt.zeros_like(input).scatter_(1, input.argmax(1, keepdim=True), 1.0)
            # straight-through: forward value is the one-hot, gradient is that of ``input``
            input = hard + input - input.detach()
        usage = input.mean((0, 2, 3))  # (c,)
        spread = usage.std(0)
        if not self.normaliz:
            return spread / usage.mean(0)
        # usage sums to one after the softmax, so its mean is 1/c: std / mean == std * c
        return spread * usage.size(0)


class UtilizLossGrouped(UtilizLoss):
    """Utilization loss per channel group, combined with weights ``g / sum(groups)``.

    Dim 1 of the input holds ``sum(groups) * d`` channels: group after group, each with
    ``g * d`` channels that are averaged over ``d`` before the base-class loss is applied.
    """

    def __init__(self, groups, normaliz=True, binariz=True):
        super().__init__(normaliz, binariz)
        self.groups = groups

    def __call__(self, input):
        """
        input: in shape (b,sum(self.groups)*d,h,w)
        return: weighted sum of the per-group utilization losses
        """
        total = sum(self.groups)
        d = input.size(1) // total
        parts = []
        lo = 0
        for g in self.groups:
            hi = lo + g * d
            # (b,g*d,h,w) -> (b,g,d,h,w) -> average over d -> (b,g,h,w)
            grouped = input[:, lo:hi, :, :].unflatten(1, [g, d]).mean(2)
            parts.append(super().__call__(grouped) * (g / total))
            lo = hi
        assert hi == input.size(1)
        return sum(parts)  # / len(self.groups)


class L1Loss:
    """Callable wrapper around ``ptnf.l1_loss``; a missing target means all zeros."""

    def __init__(self, reduce="mean"):
        self.reduce = reduce

    def __call__(self, input, target=None):
        """Return the L1 loss of ``input`` against ``target`` (zeros when ``target`` is None)."""
        reference = pt.zeros_like(input) if target is None else target
        return ptnf.l1_loss(input, reference, reduction=self.reduce)


class MSELoss:
    """Callable wrapper around ``ptnf.mse_loss``; ``reduce`` is passed as ``reduction``."""

    def __init__(self, reduce="mean"):
        self.reduce = reduce

    def __call__(self, input, target):
        """Return the squared-error loss between ``input`` and ``target``."""
        mode = self.reduce
        return ptnf.mse_loss(input, target, reduction=mode)


class HuberLoss:
    """Callable wrapper around ``ptnf.huber_loss``.

    ``reduce`` is passed as ``reduction``; ``delta`` is the threshold between the quadratic
    and the linear part of the loss.
    """

    def __init__(self, reduce="mean", delta=1.0):
        self.reduce = reduce
        self.delta = delta

    def __call__(self, input, target):
        """Return the Huber loss between ``input`` and ``target``."""
        # ``delta`` must be forwarded explicitly, otherwise the configured value is ignored
        return ptnf.huber_loss(
            input, target, reduction=self.reduce, delta=self.delta
        )


class KLDivLoss:
    """Callable wrapper around ``ptnf.kl_div``.

    ``reduce`` is passed as ``reduction`` and ``log_target`` is forwarded unchanged; see
    ``ptnf.kl_div`` for the expected format of ``input`` and ``target``.
    """

    def __init__(self, reduce="mean", log_target=False):
        self.log_target = log_target
        self.reduce = reduce

    def __call__(self, input, target):
        """Return ``ptnf.kl_div(input, target)`` with the configured options."""
        options = {"reduction": self.reduce, "log_target": self.log_target}
        return ptnf.kl_div(input, target, **options)


class LPIPSLoss:
    """Perceptual loss computed by a frozen, pretrained ``lpips.LPIPS`` network.

    net: backbone name passed to ``lpips.LPIPS``
    reduce: only "mean" is accepted
    stop_amp: stored on the instance; NOTE(review): it is not consulted anywhere in this
        class, autocast is always disabled in ``__call__`` -- confirm whether it should gate that
    """

    def __init__(self, net="vgg", reduce="mean", stop_amp=True):
        assert reduce == "mean"
        self.lpips = lpips.LPIPS(pretrained=True, net=net, eval_mode=True)
        # the network is a fixed feature extractor: never train its parameters
        for p in self.lpips.parameters():
            p.requires_grad = False
        self.lpips.compile()
        # self.lpips = pt.quantization.quantize_dynamic(self.lpips)  # slow
        self.stop_amp = stop_amp

    @pt.autocast("cuda", pt.float, enabled=False)
    def __call__(self, input, target):
        """
        input: images to evaluate; cast to float32 before the network is applied
        target: reference images; cast to float32 as well
        return: mean of the network output
        """
        self.lpips.to(input.device)  # to the same device, won't repeat once done
        # autocast is off here, so both operands must match the float32 weights
        return self.lpips(target.float(), input.float()).mean()


class HungarianMIoU:
    """Mean IoU after matching ``c`` predicted instances with ``d`` ground-truth instances
    one-to-one (Hungarian algorithm)."""

    def __init__(self, num_pd=-1, num_gt=-1, fg=False):
        # num_pd / num_gt are only range-checked here; they are optional thanks to the
        # hungarian match, and the class counts are inferred from the data in ``__call__``
        assert num_pd < 256 and num_gt < 256
        self.fg = fg

    @pt.no_grad()
    def __call__(self, input, target):
        """
        input: predicted segmentation in shape (b,..), indexes not one-hot
        target: ground-truth segmentation in shape (b,..), indexes not one-hot
        return: scalar, matched mIoU averaged over the batch
        """
        flat_pd, flat_gt = HungarianMIoU.segment_ensure(input, target)  # (b,n)
        mask_pd, mask_gt = HungarianMIoU.index_to_onehot(
            flat_pd, flat_gt, -1, -1, self.fg
        )  # (b,n,c) (b,n,d)
        per_sample = HungarianMIoU.hungarian_miou(mask_pd, mask_gt)  # (b,)
        return per_sample.mean()

    # @pt.no_grad()
    @staticmethod
    def segment_ensure(idx_pd, idx_gt, c=-1, d=-1):
        """
        Validate two index maps and flatten everything but the batch dimension.

        idx_pd: shape=(b,..), dtype=int, indexes not onehot
        idx_gt: shape=(b,..), dtype=int, indexes not onehot
        c, d: when > 1, exclusive upper bounds checked against the respective indexes
        return: both maps in shape (b,n)
        """
        assert not idx_pd.is_floating_point()
        assert not idx_gt.is_floating_point()
        assert idx_pd.shape == idx_gt.shape
        assert idx_pd.ndim >= 2
        if c > 1:
            assert idx_pd.max() < c  # set(idx_pd.unique()) <= set(range(c))
        if d > 1:
            assert idx_gt.max() < d  # set(idx_gt.unique()) <= set(range(d))
        return idx_pd.flatten(1), idx_gt.flatten(1)  # (b,n) (b,n)

    # @pt.no_grad()
    @staticmethod
    def index_to_onehot(idx_pd, idx_gt, c=-1, d=-1, fg=False):
        """
        idx_pd: shape=(b,n), dtype=int, indexes not onehot
        idx_gt: shape=(b,n), dtype=int, indexes not onehot
        c: number of prediction objects; -1 lets ``one_hot`` infer it
        d: number of ground-truth objects; optional thanks to hungarian match
        fg: drop ground-truth channel 0, which is supposed to be the background
        return: boolean masks in shape (b,n,c) and (b,n,d)
        """
        mask_pd = ptnf.one_hot(idx_pd.long(), c).bool()  # (b,n,c)
        mask_gt = ptnf.one_hot(idx_gt.long(), d).bool()  # (b,n,d)
        if not fg:
            return mask_pd, mask_gt
        return mask_pd, mask_gt[:, :, 1:]  # suppose 0 is background

    # @pt.no_grad()
    @staticmethod
    def hungarian_miou(oh_pd, oh_gt, eps=1e-8):
        """
        oh_pd: shape=(b,n,c), dtype=bool, one-hot not indexes
        oh_gt: shape=(b,n,d), dtype=bool, one-hot not indexes
        return: shape=(b,), dtype=float
        """
        _, _, c = oh_pd.shape
        _, _, d = oh_gt.shape

        pd = oh_pd.bool()
        gt = oh_gt.bool()  # TODO remove all-zero segments below

        overlap = (gt[:, :, :, None] & pd[:, :, None, :]).sum(1)  # (b,d,c)
        area_gt = gt.sum(1)[:, :, None]  # (b,d,1)
        area_pd = pd.sum(1)[:, None, :]  # (b,1,c)
        union = area_gt + area_pd - overlap
        iou = overlap.float() / (union.float() + eps)  # (b,d,c)

        matched = []
        for table in iou.detach().cpu().numpy():  # (d,c)
            # assign rows to cols  # TODO XXX transpose to (c,d)???
            rows, cols = linear_sum_assignment(table, maximize=True)  # (min(d,c),)
            matched.append(table[rows, cols])
        iou2 = pt.as_tensor(np.stack(matched), dtype=iou.dtype, device=iou.device)

        # less predicted instances than ground-truth; zero the missings
        if c < d:
            return iou2.sum(1) / d
        # more predicted instances than ground-truth; use the best matches
        return iou2.mean(1)  # (b,)


class mBO:
    """Mean best overlap: every one of the ``d`` ground-truth instances is scored by its
    best IoU among the ``c`` predicted instances."""

    def __init__(self, num_pd=-1, num_gt=-1, fg=False):
        # num_pd / num_gt are only range-checked; class counts are inferred from the data
        assert num_pd < 256 and num_gt < 256
        self.fg = fg

    def __call__(self, input, target):
        """
        input: predicted segmentation in shape (b,..), indexes not one-hot
        target: ground-truth segmentation in shape (b,..), indexes not one-hot
        return: scalar, best overlap averaged over ground-truth instances and the batch
        """
        flat_pd, flat_gt = HungarianMIoU.segment_ensure(input, target)  # (b,n)
        mask_pd, mask_gt = HungarianMIoU.index_to_onehot(
            flat_pd, flat_gt, -1, -1, self.fg
        )  # (b,n,c) (b,n,d)
        per_sample = mBO.mean_best_overlap(mask_pd, mask_gt)  # (b,)
        return per_sample.mean()

    # @pt.no_grad()
    @staticmethod
    def mean_best_overlap(oh_pd, oh_gt, eps=1e-8):
        """
        oh_pd: shape=(b,n,c), dtype=bool, one-hot not indexes
        oh_gt: shape=(b,n,d), dtype=bool, one-hot not indexes
        return: shape=(b,), dtype=float
        """
        pd = oh_pd.bool()
        gt = oh_gt.bool()  # TODO remove all-zero segments below

        overlap = (gt[:, :, :, None] & pd[:, :, None, :]).sum(1)  # (b,d,c)
        area_gt = gt.sum(1)[:, :, None]  # (b,d,1)
        area_pd = pd.sum(1)[:, None, :]  # (b,1,c)
        union = area_gt + area_pd - overlap
        iou = overlap.float() / (union.float() + eps)  # (b,d,c)

        # find the best prediction for every ground-truth
        best = iou.max(2).values  # (b,d)
        return best.mean(1)  # (b,)


class ARI:
    """
    Adjusted Rand index between a predicted and a ground-truth segmentation.
    https://scikit-learn.org/stable/modules/generated/sklearn.metrics.adjusted_rand_score.html
    Note: The larger num_pd and num_gt, the "better" ari will be, especially arifg.
    """

    def __init__(self, num_pd=-1, num_gt=-1, fg=False):
        # num_pd / num_gt are only range-checked; class counts are inferred from the data
        assert num_pd < 256 and num_gt < 256
        self.fg = fg

    @pt.no_grad()
    def __call__(self, input, target):
        """
        input: predicted segmentation in shape (b,..), indexes not one-hot
        target: ground-truth segmentation in shape (b,..), indexes not one-hot
        return: scalar, ARI averaged over the batch
        """
        idx_pd, idx_gt = HungarianMIoU.segment_ensure(input, target)  # (b,n)
        oh_pd, oh_gt = HungarianMIoU.index_to_onehot(
            idx_pd, idx_gt, -1, -1, self.fg
        )  # (b,n,c) (b,n,d)
        acc = ARI.adjusted_rand_score(oh_pd, oh_gt)
        return acc.mean()

    # @pt.no_grad()
    @staticmethod
    def adjusted_rand_score(oh_pd, oh_gt, eps=1e-8):
        """
        oh_pd: shape=(b,n,c), dtype=bool, one-hot not indexes
        oh_gt: shape=(b,n,d), dtype=bool, one-hot not indexes
        return: shape=(b,), dtype=float
        """
        oh_pd = oh_pd.bool()
        oh_gt = oh_gt.bool()  # TODO remove all-zero segments below

        # contingency table; the following two impls: 1:4@cpu; 3:0.3@gpu
        # einsum("bnc,bnd->bcd", oh_gt.double(), oh_pr.double())
        N = (oh_gt[:, :, :, None] & oh_pd[:, :, None, :]).sum(1)  # (b,d,c), long
        A = N.sum(2)  # (b,d), points per ground-truth segment
        B = N.sum(1)  # (b,c), points per predicted segment
        num = A.sum(1)  # (b,), points counted in the table

        # Pair counts grow like n^2 and their product like n^4, which overflows int64 for
        # a few hundred pixels per side; do the arithmetic below in float64 instead.
        idx_r = (N * (N - 1)).sum([1, 2]).double()  # (b,)
        idx_a = (A * (A - 1)).sum(1).double()  # (b,)
        idx_b = (B * (B - 1)).sum(1).double()  # (b,)
        idx_n = (num * (num - 1)).clip(1).double()  # (b,)

        idx_r_exp = idx_a * idx_b / idx_n
        idx_r_max = (idx_a + idx_b) / 2.0
        denominat = idx_r_max - idx_r_exp
        ari = (idx_r - idx_r_exp) / (denominat + eps)

        # the denominator can be zero:
        # - both pd & gt idxs assign all pixels to one cluster
        # - both pd & gt idxs assign max 1 point to each cluster
        ari.masked_fill_(denominat == 0, 1)
        return ari.float()  # same dtype as before: float32
