import os
import os.path as osp
import tempfile
from unittest.mock import MagicMock

import matplotlib.pyplot as plt
import mmcv
import numpy as np
import pytest

from openmixup.models.utils.visualization import color_val_matplotlib, imshow_infos, ImshowInfosContextManager

def test_color():
    """``color_val_matplotlib`` turns BGR-style colours into RGB floats.

    Accepted inputs (an ``mmcv.Color`` member, a colour name, a BGR tuple, a
    single gray-level int and an integer array) are checked against their
    expected ``(r, g, b)`` value scaled by 1/255; rejected inputs must raise.
    """
    expected_conversions = [
        (mmcv.Color.blue, (0., 0., 1.)),
        ('green', (0., 1., 0.)),
        # tuples are read as BGR, so the channel order flips in the output
        ((1, 2, 3), (3 / 255, 2 / 255, 1 / 255)),
        # a lone int is used as the value of every channel
        (100, (100 / 255, 100 / 255, 100 / 255)),
        (np.zeros(3, dtype=int), (0., 0., 0.)),
    ]
    for color, rgb in expected_conversions:
        assert color_val_matplotlib(color) == rgb

    rejected_inputs = [
        # forbid white color (given as a list)
        ([255, 255, 255], TypeError),
        # forbid float
        (1.0, TypeError),
        # overflowed channel value
        ((0, 0, 500), AssertionError),
    ]
    for color, error in rejected_inputs:
        with pytest.raises(error):
            color_val_matplotlib(color)


def test_imshow_infos():
    """Check that ``imshow_infos`` draws the result and writes the image file.

    A colour ``(H, W, 3)`` image and a grayscale ``(H, W)`` image are both
    rendered with ``show=False`` to a file in a throw-away directory. The
    returned image must keep the input's spatial size and, for the colour
    case, must differ from the input (i.e. something was drawn on it).
    """
    result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}

    # A per-test TemporaryDirectory is cleaned up even when an assertion
    # fails. A fixed path under the shared system temp dir plus os.remove
    # leaks the file on failure and always leaks the directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_filename = osp.join(tmp_dir, 'image.jpg')

        # colour image: same shape out, and pixels were modified
        image = np.ones((10, 10, 3), np.uint8)
        out_image = imshow_infos(
            image, result, out_file=tmp_filename, show=False)
        assert osp.isfile(tmp_filename)
        assert image.shape == out_image.shape
        assert not np.allclose(image, out_image)
        # remove the file so the next isfile() check proves a fresh write
        os.remove(tmp_filename)

        # grayscale image: output spatial size matches the 2-D input
        image = np.ones((10, 10), np.uint8)
        out_image = imshow_infos(
            image, result, out_file=tmp_filename, show=False)
        assert osp.isfile(tmp_filename)
        assert image.shape == out_image.shape[:2]


def test_figure_context_manager():
    """Check that ``ImshowInfosContextManager`` reuses figures across calls.

    Several images are shown through one manager. Three ways of ending the
    display wait are simulated by replacing
    ``fig_show.canvas.start_event_loop``:

    * returning immediately (time out) -> ``put_img_infos`` returns 0;
    * pressing the space key            -> returns 0;
    * closing the figure                -> returns 1.

    In every case the drawn image keeps the input shape and differs from it.
    ``fig_save`` must stay the same object throughout. ``fig_show`` must
    stay the same object while the figure remains open.
    """
    images = [
        np.random.randint(0, 255, (100, 100, 3), np.uint8) for _ in range(5)
    ]
    result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}

    with ImshowInfosContextManager() as manager:
        fig_show = manager.fig_show
        fig_save = manager.fig_save

        # Test time out: the mocked event loop returns straight away
        fig_show.canvas.start_event_loop = MagicMock()
        fig_show.canvas.end_event_loop = MagicMock()
        for image in images:
            ret, out_image = manager.put_img_infos(image, result, show=True)
            assert ret == 0
            assert image.shape == out_image.shape
            assert not np.allclose(image, out_image)
            assert fig_show is manager.fig_show
            assert fig_save is manager.fig_save

        # Test continue key: a space key press ends the wait.
        # NOTE(review): canvas.key_press_event / canvas.close_event were
        # deprecated in Matplotlib 3.6; confirm the pinned version has them.
        fig_show.canvas.start_event_loop = (
            lambda _: fig_show.canvas.key_press_event(' '))
        for image in images:
            ret, out_image = manager.put_img_infos(image, result, show=True)
            assert ret == 0
            assert image.shape == out_image.shape
            assert not np.allclose(image, out_image)
            assert fig_show is manager.fig_show
            assert fig_save is manager.fig_save

        # Test close figure manually
        fig_show = manager.fig_show

        def destroy(*_, **__):
            fig_show.canvas.close_event()
            plt.close(fig_show)

        fig_show.canvas.start_event_loop = destroy
        shown_image = images[0]
        ret, out_image = manager.put_img_infos(shown_image, result, show=True)
        assert ret == 1
        # compare against the image actually passed in, not the leftover
        # loop variable from the previous section
        assert shown_image.shape == out_image.shape
        assert not np.allclose(shown_image, out_image)
        assert fig_save is manager.fig_save
