import os
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import pytest
from sklearn.metrics import pairwise_distances

from .._utils import pad_array, pad_arrays, pairwise_call, save_fig


def test_pad_array():
    """Check that pad_array pads arrays at the end up to a target shape.

    Covers 1-D and 2-D inputs. Also covers padding to the array's own shape,
    which must leave the values unchanged, and NaN as the fill value.
    """
    # 1-D integer input: the extra slot is filled with the fill value.
    vec = np.array([1, 2])
    assert np.array_equal(pad_array(vec, (3,), fill_value=0), [1, 2, 0])
    # Target shape equal to the current shape: values are unchanged.
    assert np.array_equal(pad_array(vec, vec.shape, fill_value=0), vec)

    # 2-D float input.
    mat = np.array([[1, 2], [3, 4]], dtype=float)
    assert np.array_equal(pad_array(mat, mat.shape, fill_value=0), mat)

    # Adding a row filled with NaN; NaN != NaN, so compare with equal_nan.
    padded = pad_array(mat, (3, 2), fill_value=np.nan)
    expected = np.array([[1, 2], [3, 4], [np.nan, np.nan]], dtype=float)
    assert np.array_equal(padded, expected, equal_nan=True)


def test_pad_arrays():
    """Check that pad_arrays brings arrays of different lengths to one shape.

    With ``shape=None`` the target shape is left to pad_arrays. The test
    requires that all outputs share a shape and that each input's original
    entries survive at the front of its padded version.
    """
    shorter = np.array([1, 2])
    longer = np.array([1, 2, 3])

    padded = pad_arrays([shorter, longer], shape=None)

    # Shapes are plain tuples; compare them directly.
    assert padded[0].shape == padded[1].shape
    # Padding must not disturb the original values. This assumes pad_arrays
    # pads at the end, as pad_array does (see test_pad_array).
    assert np.array_equal(padded[0][:shorter.size], shorter)
    assert np.array_equal(padded[1][:longer.size], longer)


def test_save_fig():
    """Check the file names that save_fig derives from its keyword arguments.

    Expected naming, as asserted below:
    - With no kwargs, the file is ``fig.pdf``.
    - Kwargs become ``key=value`` pairs joined by ``:``, sorted by key, with
      ``True`` rendered as ``T``.
    - ``order`` moves the listed keys to the front in the given order; the
      remaining keys keep their sorted order.

    The kwargs dict is deliberately written out of alphabetical order, so the
    sorting is actually exercised.
    """
    fig = plt.figure()
    kwargs = {
        'arg3': 'arg3',
        'arg1': 0,
        'arg2': True,
    }

    # Close the figure and remove the directory even if an assertion fails;
    # otherwise both would leak for the rest of the test session.
    try:
        with tempfile.TemporaryDirectory() as dirpath:
            filepath = save_fig(fig, dirpath)
            assert os.path.basename(filepath) == 'fig.pdf'

            filepath = save_fig(fig, dirpath, **kwargs)
            assert os.path.basename(filepath) == 'arg1=0:arg2=T:arg3=arg3.pdf'

            filepath = save_fig(fig, dirpath, order=['arg2'], **kwargs)
            assert os.path.basename(filepath) == 'arg2=T:arg1=0:arg3=arg3.pdf'

            filepath = save_fig(fig, dirpath, order=['arg2', 'arg3'], **kwargs)
            assert os.path.basename(filepath) == 'arg2=T:arg3=arg3:arg1=0.pdf'
    finally:
        plt.close(fig)


@pytest.mark.parametrize('symmetric', [False, True])
@pytest.mark.parametrize('n_jobs', [1, 2])
def test_pairwise_call(symmetric, n_jobs):
    """Compare pairwise_call against sklearn's Euclidean distance matrix.

    The test runs for both the symmetric and full evaluation modes, and for
    serial and two-job execution.
    """
    n_points, n_dims = 5, 2
    points = np.random.RandomState(0).uniform(size=(n_points, n_dims))

    # Reference matrix computed by sklearn.
    expected = pairwise_distances(points, metric='euclidean')

    # Same matrix built element by element through pairwise_call.
    actual = pairwise_call(
        points,
        lambda u, v: np.linalg.norm(u - v),
        symmetric=symmetric,
        n_jobs=n_jobs,
    )

    assert np.allclose(expected, actual)
