"""Some utilities for testing."""
from typing import Optional

import torch

from sfrp_torch import sfrp


# Properties the projections are expected to satisfy (candidates for tests):
# - consistency (i.e. the same result when applied to the same matrix)
# - covariance with permutations of the matrix's rows
# - approximate preservation of distances


###############################################################################


def materialize_projection_matrix(
    d_original: int,
    d_projection: int,
    projection_type: Optional[str] = None,
    algorithm: Optional[str] = None,
    # Required for only some sparse projections. Otherwise gets ignored.
    sparse_region_size: Optional[int] = None,
    *,
    seed: int,
    device: torch.device,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Return the explicit matrix form of ``sfrp.project``.

    The matrix is obtained by running ``sfrp.project`` on a
    ``d_original x d_original`` identity matrix. In effect, this projects
    every standard basis vector of the original space in a single call.

    Args:
        d_original: Dimensionality of the original space. This is the size
            of the identity matrix that gets projected.
        d_projection: Target dimensionality, forwarded to ``sfrp.project``.
        projection_type: Forwarded to ``sfrp.project`` unchanged.
        algorithm: Forwarded to ``sfrp.project`` unchanged.
        sparse_region_size: Forwarded to ``sfrp.project`` unchanged.
        seed: Seed forwarded to ``sfrp.project``.
        device: Device on which the identity matrix is allocated.
        dtype: Dtype of the identity matrix.

    Returns:
        The result of ``sfrp.project`` on the identity input. This is
        presumably a ``(d_original, d_projection)`` tensor if ``sfrp.project``
        projects rows; confirm against ``sfrp``.
    """
    # NOTE: The whole identity matrix is materialized at once, so this can
    # OOM for large ``d_original``. Computing the matrix in parts would avoid
    # that.
    projection_kwargs = dict(
        d_projection=d_projection,
        projection_type=projection_type,
        algorithm=algorithm,
        sparse_region_size=sparse_region_size,
        seed=seed,
    )
    basis = torch.eye(d_original, dtype=dtype, device=device)
    return sfrp.project(basis, **projection_kwargs)


def materialize_transposed_projection_matrix(
    d_original: int,
    d_projection: int,
    projection_type: Optional[str] = None,
    algorithm: Optional[str] = None,
    # Required for only some sparse projections. Otherwise gets ignored.
    sparse_region_size: Optional[int] = None,
    *,
    seed: int,
    device: torch.device,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Return the explicit matrix form of ``sfrp.transposed_project``.

    The matrix is obtained by running ``sfrp.transposed_project`` on a
    ``d_projection x d_projection`` identity matrix. In effect, this maps
    every standard basis vector of the projected space back in a single
    call.

    Args:
        d_original: Dimensionality of the original space, forwarded to
            ``sfrp.transposed_project``.
        d_projection: Dimensionality of the projected space. This is the
            size of the identity matrix that gets transformed.
        projection_type: Forwarded to ``sfrp.transposed_project`` unchanged.
        algorithm: Forwarded to ``sfrp.transposed_project`` unchanged.
        sparse_region_size: Forwarded to ``sfrp.transposed_project``
            unchanged.
        seed: Seed forwarded to ``sfrp.transposed_project``.
        device: Device on which the identity matrix is allocated.
        dtype: Dtype of the identity matrix.

    Returns:
        The result of ``sfrp.transposed_project`` on the identity input.
        This is presumably a ``(d_projection, d_original)`` tensor if the
        transform acts on rows; confirm against ``sfrp``.
    """
    # NOTE: The whole identity matrix is materialized at once, so this can
    # OOM for large ``d_projection``. Computing the matrix in parts would
    # avoid that.
    transpose_kwargs = dict(
        d_original=d_original,
        projection_type=projection_type,
        algorithm=algorithm,
        sparse_region_size=sparse_region_size,
        seed=seed,
    )
    basis = torch.eye(d_projection, dtype=dtype, device=device)
    return sfrp.transposed_project(basis, **transpose_kwargs)
