import numpy as np
import pytest
import torch

from mmseg.models.losses import Accuracy, reduce_loss, weight_reduce_loss


def test_utils():
    """Check ``reduce_loss`` and ``weight_reduce_loss`` on a random loss."""
    raw_loss = torch.rand(1, 3, 4, 4)
    # Element-wise weights: one on the first two indices of each of the
    # last two dimensions, zero everywhere else.
    mask = torch.zeros(1, 3, 4, 4)
    mask[:, :, :2, :2] = 1

    # reduce_loss(): 'none' must hand back the very same tensor object
    assert reduce_loss(raw_loss, 'none') is raw_loss

    np.testing.assert_almost_equal(
        reduce_loss(raw_loss, 'mean').numpy(), raw_loss.mean())

    np.testing.assert_almost_equal(
        reduce_loss(raw_loss, 'sum').numpy(), raw_loss.sum())

    # weight_reduce_loss(): without a weight and with 'none' the input
    # object itself is returned
    untouched = weight_reduce_loss(raw_loss, weight=None, reduction='none')
    assert untouched is raw_loss

    weighted_mean = weight_reduce_loss(
        raw_loss, weight=mask, reduction='mean')
    np.testing.assert_almost_equal(
        weighted_mean.numpy(), (raw_loss * mask).mean())

    weighted_sum = weight_reduce_loss(raw_loss, weight=mask, reduction='sum')
    np.testing.assert_almost_equal(
        weighted_sum.numpy(), (raw_loss * mask).sum())

    # Weights whose shape does not match the loss must be rejected: first a
    # 2-D slice (fewer dimensions), then a slice with only two channels.
    for bad_weight in (mask[0, 0, ...], mask[:, 0:2, ...]):
        with pytest.raises(AssertionError):
            weight_reduce_loss(raw_loss, weight=bad_weight, reduction='mean')


def test_ce_loss():
    """Check ``CrossEntropyLoss`` built through ``build_loss``."""
    from mmseg.models import build_loss

    # use_mask and use_sigmoid cannot be true at the same time
    conflicting_cfg = dict(
        type='CrossEntropyLoss',
        use_mask=True,
        use_sigmoid=True,
        loss_weight=1.0)
    with pytest.raises(AssertionError):
        build_loss(conflicting_cfg)

    # A single sample with two scores whose label is index 1.
    logits = torch.Tensor([[100, -100]])
    target = torch.Tensor([1]).long()

    # use_sigmoid=False with class weights
    weighted_ce = build_loss(
        dict(
            type='CrossEntropyLoss',
            use_sigmoid=False,
            class_weight=[0.8, 0.2],
            loss_weight=1.0))
    assert torch.allclose(weighted_ce(logits, target), torch.tensor(40.))

    # use_sigmoid=False without class weights
    plain_ce = build_loss(
        dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
    assert torch.allclose(plain_ce(logits, target), torch.tensor(200.))

    # use_sigmoid=True
    sigmoid_ce = build_loss(
        dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0))
    assert torch.allclose(sigmoid_ce(logits, target), torch.tensor(100.))

    # The use_sigmoid=True loss again, now on a constant (2, 21, 8, 8)
    # prediction with an all-ones (2, 8, 8) label map.
    dense_logits = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
    dense_target = torch.ones(2, 8, 8).long()
    full_loss = sigmoid_ce(dense_logits, dense_target)
    assert torch.allclose(full_loss, torch.tensor(0.9503), atol=1e-4)

    # Set one position per sample to 255 and pass ignore_index=255.
    dense_target[:, 0, 0] = 255
    ignored_loss = sigmoid_ce(dense_logits, dense_target, ignore_index=255)
    assert torch.allclose(ignored_loss, torch.tensor(0.9354), atol=1e-4)

    # TODO test use_mask


def test_accuracy():
    """Exercise ``Accuracy`` on empty input, top-k variants and bad args."""
    # test for empty pred
    empty_acc = Accuracy(topk=1)(torch.empty(0, 4), torch.empty(0))
    assert empty_acc.item() == 0

    # Five samples with four scores each.
    scores = torch.Tensor([[0.2, 0.3, 0.6, 0.5], [0.1, 0.1, 0.2, 0.6],
                           [0.9, 0.0, 0.0, 0.1], [0.4, 0.7, 0.1, 0.1],
                           [0.0, 0.0, 0.99, 0]])
    # Index of the highest score in every row.
    argmax_label = torch.Tensor([2, 3, 0, 1, 2]).long()

    # test for top1
    top1 = Accuracy(topk=1)(scores, argmax_label)
    assert top1.item() == 100

    # test for top1 with score thresh=0.8
    top1_thresh = Accuracy(topk=1, thresh=0.8)(scores, argmax_label)
    assert top1_thresh.item() == 40

    # test for top2: each label is either the best or the second-best score
    runner_up_label = torch.Tensor([3, 2, 0, 0, 2]).long()
    top2 = Accuracy(topk=2)(scores, runner_up_label)
    assert top2.item() == 100

    # test for both top1 and top2
    both = Accuracy(topk=(1, 2))(scores, argmax_label)
    assert all(single.item() == 100 for single in both)

    # Construction stays inside each ``pytest.raises`` block below because
    # the AssertionError may come from the constructor or from the call.

    # topk is larger than pred class number
    with pytest.raises(AssertionError):
        metric = Accuracy(topk=5)
        metric(scores, argmax_label)

    # wrong topk type
    with pytest.raises(AssertionError):
        metric = Accuracy(topk='wrong type')
        metric(scores, argmax_label)

    # label size is larger than required
    oversized_label = torch.Tensor([2, 3, 0, 1, 2, 0]).long()
    with pytest.raises(AssertionError):
        metric = Accuracy()
        metric(scores, oversized_label)

    # wrong pred dimension
    with pytest.raises(AssertionError):
        metric = Accuracy()
        metric(scores[:, :, None], argmax_label)


def test_lovasz_loss():
    """Smoke-test ``LovaszLoss`` construction and forward calls.

    Checks that invalid configurations raise ``AssertionError`` and that the
    loss can be called for each combination of ``loss_type`` ('multi_class'
    or 'binary') and ``per_image`` (False or True). The returned values are
    not inspected.
    """
    from mmseg.models import build_loss

    # loss_type should be 'binary' or 'multi_class'
    with pytest.raises(AssertionError):
        loss_cfg = dict(
            type='LovaszLoss',
            loss_type='Binary',
            reduction='none',
            loss_weight=1.0)
        build_loss(loss_cfg)

    # reduction should be 'none' when per_image is False.
    with pytest.raises(AssertionError):
        loss_cfg = dict(type='LovaszLoss', loss_type='multi_class')
        build_loss(loss_cfg)

    # test lovasz loss with loss_type = 'multi_class' and per_image = False
    loss_cfg = dict(type='LovaszLoss', reduction='none', loss_weight=1.0)
    lovasz_loss = build_loss(loss_cfg)
    logits = torch.rand(1, 3, 4, 4)
    labels = (torch.rand(1, 4, 4) * 2).long()
    lovasz_loss(logits, labels)

    # test lovasz loss with loss_type = 'multi_class' and per_image = True
    loss_cfg = dict(
        type='LovaszLoss',
        per_image=True,
        reduction='mean',
        class_weight=[1.0, 2.0, 3.0],
        loss_weight=1.0)
    lovasz_loss = build_loss(loss_cfg)
    logits = torch.rand(1, 3, 4, 4)
    labels = (torch.rand(1, 4, 4) * 2).long()
    lovasz_loss(logits, labels, ignore_index=None)

    # test lovasz loss with loss_type = 'binary' and per_image = False
    loss_cfg = dict(
        type='LovaszLoss',
        loss_type='binary',
        reduction='none',
        loss_weight=1.0)
    lovasz_loss = build_loss(loss_cfg)
    logits = torch.rand(2, 4, 4)
    # torch.rand() samples from [0, 1), so truncating it directly with
    # .long() would give all-zero labels; scale by 2 first so that both
    # label values 0 and 1 can occur.
    labels = (torch.rand(2, 4, 4) * 2).long()
    lovasz_loss(logits, labels)

    # test lovasz loss with loss_type = 'binary' and per_image = True
    loss_cfg = dict(
        type='LovaszLoss',
        loss_type='binary',
        per_image=True,
        reduction='mean',
        loss_weight=1.0)
    lovasz_loss = build_loss(loss_cfg)
    logits = torch.rand(2, 4, 4)
    # Scaled before truncation for the same reason as above.
    labels = (torch.rand(2, 4, 4) * 2).long()
    lovasz_loss(logits, labels, ignore_index=None)
