import pytest
from nesim.utils.dimensionality import EffectiveDimensionality
import torch

# Absolute tolerance for comparing original vs. batched eigenvalue outputs.
absolute_tolerance = 1e-5
# Maximum allowed deviation of the computed effective dimensionality from
# its expected value.
functionality_tolerance = 1e-2
# Prefer the first CUDA device when one is present; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Values passed as EffectiveDimensionality(batch_size=...).
batch_sizes = [1, 2, 32, 64]

# First dimension (number of rows) of the generated test matrices.
matrix_num_rows = [1024, 4096]

# Second dimension (number of columns) of the generated test matrices.
matrix_num_columns = [2, 16, 64, 1024]


@pytest.mark.parametrize("batch_size", batch_sizes)
@pytest.mark.parametrize("num_rows", matrix_num_rows)
@pytest.mark.parametrize("num_columns", matrix_num_columns)
def test_consistent_results_on_different_batch_sizes(
    batch_size: int, num_rows: int, num_columns: int
):
    """Check that the batched eigenvalue path matches the original one.

    For every combination of batch size and matrix shape, the output of
    ``get_eigenvalues_torch_batched`` is compared element-wise (with
    ``absolute_tolerance``) against ``get_eigenvalues_torch_original``
    computed on the same random matrix.
    """
    # A locally seeded CPU generator makes every parametrized case
    # reproducible without mutating the global RNG state other tests use.
    generator = torch.Generator().manual_seed(0)
    matrix = torch.randn(num_rows, num_columns, generator=generator).to(device)
    e = EffectiveDimensionality(flatten=True, device=device, batch_size=batch_size)

    original_value = e.get_eigenvalues_torch_original(matrix)
    new_value = e.get_eigenvalues_torch_batched(matrix)

    assert torch.allclose(
        original_value, new_value, atol=absolute_tolerance
    ), f"Mismatch:\noriginal_value: {original_value}\nnew_value: {new_value}"


@pytest.mark.parametrize("batch_size", batch_sizes)
@pytest.mark.parametrize("num_rows", matrix_num_rows)
def test_functionality(batch_size: int, num_rows: int):
    """Check that a matrix spanned by two directions has effective dim ~2.

    The test matrix is ``[c1, c2, c1, c2]`` (columns), so it has rank 2.
    The two base columns are built to be zero-mean, mutually orthogonal and
    of equal norm, which makes the two non-zero eigenvalues equal (up to
    float error) for any random draw. Two independent ``torch.randn``
    columns would only be approximately uncorrelated and equal-variance, and
    for smaller ``num_rows`` that sampling noise may be comparable to
    ``functionality_tolerance``.
    """
    # Seeded local generator: reproducible, and leaves the global RNG alone.
    generator = torch.Generator().manual_seed(0)
    base = torch.randn(num_rows, 2, generator=generator)
    # Center each column so the construction holds even if the implementation
    # removes the mean before computing eigenvalues.
    base = base - base.mean(dim=0, keepdim=True)
    # QR gives orthonormal columns spanning the centered columns, so they stay
    # zero-mean; rescale so entries have roughly unit variance.
    q, _ = torch.linalg.qr(base)
    q = q * num_rows**0.5

    column_data_1 = q[:, 0].to(device)  # shape: (num_rows,)
    column_data_2 = q[:, 1].to(device)  # shape: (num_rows,)

    # shape: (num_rows, 4), with columns duplicated so the rank is 2
    matrix = torch.stack(
        [column_data_1, column_data_2, column_data_1, column_data_2], dim=-1
    )
    e = EffectiveDimensionality(flatten=True, device=device, batch_size=batch_size)

    eigen_values = e.get_eigenvalues_torch_original(matrix)
    effective_dim = e.get_effective_dim(eigen_values=eigen_values)
    # The dimensionality should be a value that is very close to 2.
    assert (
        abs(effective_dim - 2.0) < functionality_tolerance
    ), f"Expected effective_dim to be a value very close to 2 (between: {2 - functionality_tolerance} to {2 + functionality_tolerance}), but got: {effective_dim}"
