""""""
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import numpy as np


@dataclass
class RandomProjectionSketch:
    """Gaussian random projection from ``input_dim`` down to ``sketch_dim`` columns.

    The projection matrix ``R`` has shape ``(input_dim, sketch_dim)`` with
    i.i.d. entries drawn from ``N(0, 1 / sketch_dim)``. With that variance,
    for any row vector ``x`` the squared norm of ``x @ R`` equals
    ``||x||**2`` in expectation (a Johnson-Lindenstrauss style projection).

    When ``sketch_dim`` is ``None`` or not smaller than ``input_dim``, a
    projection would not reduce dimensionality, so no random matrix is drawn.
    Instead ``sketch_dim`` is set to ``input_dim``, ``R`` is the identity, and
    :meth:`sketch` returns a copy of its input values.

    Attributes:
        input_dim: Number of columns expected in matrices passed to
            :meth:`sketch`.
        sketch_dim: Target number of columns. ``None`` at construction means
            "no reduction" and is replaced by ``input_dim``.
        random_state: Seed forwarded to ``np.random.default_rng``; the same
            seed reproduces the same ``R``.
        R: ``float32`` projection matrix of shape ``(input_dim, sketch_dim)``,
            built in ``__post_init__``.
        use_sketch: ``True`` when ``R`` is a random projection, ``False`` when
            it is the identity.
    """

    input_dim: int
    sketch_dim: Optional[int] = None
    random_state: Optional[int] = None

    def __post_init__(self) -> None:
        """Validate ``sketch_dim`` and build the projection matrix ``R``.

        Raises:
            ValueError: If ``sketch_dim`` is given but is smaller than 1.
        """
        if self.sketch_dim is not None and self.sketch_dim < 1:
            # A zero-width sketch would divide by zero in the scale below, and a
            # negative one would fail inside NumPy with an unclear message.
            raise ValueError(
                f"sketch_dim must be a positive integer or None, got {self.sketch_dim}"
            )
        if self.sketch_dim is None or self.sketch_dim >= self.input_dim:
            # No dimensionality reduction possible: fall back to the identity.
            self.sketch_dim = self.input_dim
            self.R: np.ndarray = np.eye(self.input_dim, dtype=np.float32)
            self.use_sketch: bool = False
        else:
            rng = np.random.default_rng(self.random_state)
            # Std 1/sqrt(sketch_dim) => E[||x @ R||^2] == ||x||^2.
            self.R = rng.normal(
                loc=0.0,
                scale=1.0 / np.sqrt(self.sketch_dim),
                size=(self.input_dim, self.sketch_dim),
            ).astype(np.float32)
            self.use_sketch = True

    def sketch(self, M: np.ndarray) -> np.ndarray:
        """Project the rows of ``M`` into the sketch space.

        Args:
            M: Array-like of shape ``(K, input_dim)``.

        Returns:
            A new array of shape ``(K, sketch_dim)``. Its dtype is NumPy's
            promotion of ``M``'s dtype with ``float32`` (for example, ``float64``
            input stays ``float64`` and integer input becomes ``float64``).

        Raises:
            ValueError: If ``M`` is not 2-D with exactly ``input_dim`` columns.
        """
        M = np.asanyarray(M)
        if M.ndim != 2 or M.shape[1] != self.input_dim:
            raise ValueError(
                f"M must have shape (K, {self.input_dim}), got {M.shape}"
            )
        if not self.use_sketch:
            # Multiplying by the identity costs O(K * input_dim**2). It also
            # smears any inf/NaN across the whole row (inf * 0 == NaN). Instead,
            # copy with the same dtype promotion the matmul would have applied.
            return M.astype(np.result_type(M.dtype, self.R.dtype), copy=True)
        return M @ self.R
