import numpy as np
import pytest
import torch
from PIL import Image

from pytorch_fid import fid_score, inception


@pytest.fixture
def device():
    return torch.device('cpu')


def test_calculate_fid_given_statistics(mocker, tmp_path, device):
    dim = 2048
    m1, m2 = np.zeros((dim,)), np.ones((dim,))
    sigma = np.eye(dim)

    def dummy_statistics(path, model, batch_size, dims, device):
        if path.endswith('1'):
            return m1, sigma
        elif path.endswith('2'):
            return m2, sigma
        else:
            raise ValueError

    mocker.patch('pytorch_fid.fid_score.compute_statistics_of_path',
                 side_effect=dummy_statistics)

    dir_names = ['1', '2']
    paths = []
    for name in dir_names:
        path = tmp_path / name
        path.mkdir()
        paths.append(str(path))

    fid_value = fid_score.calculate_fid_given_paths(paths,
                                                    batch_size=dim,
                                                    device=device,
                                                    dims=dim)

    # Given equal covariance, FID is just the squared norm of difference
    assert fid_value == np.sum((m1 - m2)**2)


def test_compute_statistics_of_path(mocker, tmp_path, device):
    model = mocker.MagicMock(inception.InceptionV3)()
    model.side_effect = lambda inp: [inp.mean(dim=(2, 3), keepdim=True)]

    size = (4, 4, 3)
    arrays = [np.zeros(size), np.ones(size) * 0.5, np.ones(size)]
    images = [(arr * 255).astype(np.uint8) for arr in arrays]

    paths = []
    for idx, image in enumerate(images):
        paths.append(str(tmp_path / '{}.png'.format(idx)))
        Image.fromarray(image, mode='RGB').save(paths[-1])

    stats = fid_score.compute_statistics_of_path(str(tmp_path), model,
                                                 batch_size=len(images),
                                                 dims=3,
                                                 device=device)

    assert np.allclose(stats[0], np.ones((3,)) * 0.5, atol=1e-3)
    assert np.allclose(stats[1], np.ones((3, 3)) * 0.25)


def test_compute_statistics_of_path_from_file(mocker, tmp_path, device):
    model = mocker.MagicMock(inception.InceptionV3)()

    mu = np.random.randn(5)
    sigma = np.random.randn(5, 5)

    path = tmp_path / 'stats.npz'
    with path.open('wb') as f:
        np.savez(f, mu=mu, sigma=sigma)

    stats = fid_score.compute_statistics_of_path(str(path), model,
                                                 batch_size=1,
                                                 dims=5,
                                                 device=device)

    assert np.allclose(stats[0], mu)
    assert np.allclose(stats[1], sigma)


def test_image_types(tmp_path):
    in_arr = np.ones((24, 24, 3), dtype=np.uint8) * 255
    in_image = Image.fromarray(in_arr, mode='RGB')

    paths = []
    for ext in fid_score.IMAGE_EXTENSIONS:
        paths.append(str(tmp_path / 'img.{}'.format(ext)))
        in_image.save(paths[-1])

    dataset = fid_score.ImagePathDataset(paths)

    for img in dataset:
        assert np.allclose(np.array(img), in_arr)
