import numpy as np

import tensorly as tl
from ..robust_decomposition import robust_pca
from ...testing import assert_array_equal, assert_, assert_array_almost_equal


def test_RPCA():
    """Test robust tensor PCA on a low-rank tensor corrupted by sparse gross errors.

    The first part decomposes a fully observed tensor and checks that the
    low-rank and sparse parts add up to the input and match the ground truth.
    The second part marks roughly 5% of the entries as missing (their clean
    value is overwritten with ones and they are masked out). It checks that
    the observed entries are still decomposed correctly and that the low-rank
    prediction on the missing entries is close to the original clean tensor.
    """
    tol = 1e-5
    # ``decimal`` in ``assert_array_almost_equal`` is an integer number of
    # decimal places (checks abs(diff) < 1.5 * 10**-decimal), not an absolute
    # tolerance. Passing ``tol`` (1e-5) would allow ~1.5 of error per entry and
    # make the reconstruction check vacuous.
    reconstruction_decimal = 3
    # Solver settings shared by both decompositions below.
    rpca_kwargs = {
        "reg_E": 0.4,
        "mu_max": 10e12,
        "learning_rate": 1.2,
        "n_iter_max": 200,
        "tol": tol,
        "verbose": True,
    }

    sample = np.array([[1.0, 2, 3, 4], [2, 4, 6, 8]])
    # Stack 100 copies of the (2, 4) rank-1 sample -> shape (100, 2, 4).
    clean = np.vstack([sample[None, ...]] * 100)
    noise_probability = 0.05
    rng = tl.check_random_state(12345)
    # Sparse gross errors: each entry is +100 or -100 with total probability
    # ``noise_probability`` (split evenly), and 0 otherwise.
    noise = rng.choice(
        [0.0, 100.0, -100.0],
        size=clean.shape,
        replace=True,
        p=[1 - noise_probability, noise_probability / 2, noise_probability / 2],
    )
    tensor = tl.tensor(clean + noise)
    # Plain NumPy copies, reused (and corrupted) in the missing-value test.
    corrupted_clean = np.copy(clean)
    corrupted_noise = np.copy(noise)
    clean = tl.tensor(clean)
    noise = tl.tensor(noise)
    clean_pred, noise_pred = robust_pca(tensor, mask=None, **rpca_kwargs)
    # check recovery: the low-rank and sparse parts must add up to the input
    assert_array_almost_equal(
        tensor, clean_pred + noise_pred, decimal=reconstruction_decimal
    )
    # check low rank recovery
    assert_array_almost_equal(clean, clean_pred, decimal=1)
    # Check for sparsity of the gross error: entries above 0.01 in the
    # prediction must be exactly those above 0.01 in the true noise.
    assert_array_equal((noise_pred > 0.01), (noise > 0.01))
    # check sparse gross error recovery
    assert_array_almost_equal(noise, noise_pred, decimal=1)

    ############################
    # Test with missing values #
    ############################
    # Add some corruption: entries where mask == 0 (about 5%) are treated as
    # missing and their clean value is replaced by ones.
    mask = rng.choice([0, 1], clean.shape, replace=True, p=[0.05, 0.95])
    corrupted_clean[mask == 0] = 1
    tensor = tl.tensor(corrupted_clean + corrupted_noise)
    corrupted_noise = tl.tensor(corrupted_noise)
    corrupted_clean = tl.tensor(corrupted_clean)
    mask = tl.tensor(mask, dtype=tl.float64)
    # Decompose the tensor
    clean_pred, noise_pred = robust_pca(tensor, mask=mask, **rpca_kwargs)
    # check recovery: the two parts must still add up to the input
    assert_array_almost_equal(
        tensor, clean_pred + noise_pred, decimal=reconstruction_decimal
    )
    # check low rank recovery on the observed entries only
    assert_array_almost_equal(corrupted_clean * mask, clean_pred * mask, decimal=1)
    # check sparse gross error recovery on the observed entries only
    assert_array_almost_equal(corrupted_noise * mask, noise_pred * mask, decimal=1)

    # Check for recovery of the corrupted/missing part: on the missing entries,
    # the low-rank prediction is compared with the original (uncorrupted)
    # clean tensor, via a relative error.
    missing = 1 - mask
    error = tl.norm(clean * missing - clean_pred * missing, 2) / tl.norm(
        clean * missing, 2
    )
    assert_(error <= 1e-2)
