"""Code for random projections."""
import dataclasses
from typing import Optional

import torch


###############################################################################


@dataclasses.dataclass(frozen=True, eq=True)
class RandomProjectionParams:
    """Immutable, comparable bundle of settings describing a random projection.

    The dataclass is frozen, so fields cannot be reassigned after
    construction, and instances compare equal field-by-field.
    """
    d_projection: int

    projection_type: str
    algorithm: str

    seed: int

    # Optional: only some sparse random projections need a value here.
    sparse_region_size: Optional[int] = None

    def to_json(self):
        """Return the parameters as a plain dict keyed by field name.

        ``d_projection`` and ``seed`` are passed through ``int``;
        ``sparse_region_size`` is passed through ``int`` unless it is
        ``None``, in which case ``None`` is emitted. The two string fields
        are emitted unchanged.
        """
        payload = {
            'd_projection': int(self.d_projection),
            'projection_type': self.projection_type,
            'algorithm': self.algorithm,
            'seed': int(self.seed),
        }

        # Handled last so the optional field keeps its trailing position.
        region = self.sparse_region_size
        payload['sparse_region_size'] = None if region is None else int(region)
        return payload

    @classmethod
    def from_json(cls, jason: dict) -> 'RandomProjectionParams':
        """Build an instance from a dict whose keys are the field names.

        The dict is unpacked straight into the constructor; no validation or
        type coercion is performed here.
        """
        instance = cls(**jason)
        return instance


###############################################################################


@dataclasses.dataclass
class RandomProjector:
    """Applies the random projection described by ``params``.

    The class holds no projection logic of its own: each method forwards the
    stored parameters to the ``sfrp`` object imported from ``sfrp_torch``.
    """
    params: RandomProjectionParams

    def __post_init__(self):
        # The backend is imported when an instance is constructed (not when
        # this module is imported) and kept on the instance for later calls.
        from sfrp_torch import sfrp as backend
        self.sfrp = backend

    #######################################################

    def _shared_kwargs(self) -> dict:
        """Return the keyword arguments that both projection calls forward."""
        cfg = self.params
        return {
            'projection_type': cfg.projection_type,
            'algorithm': cfg.algorithm,
            'sparse_region_size': cfg.sparse_region_size,
            'seed': cfg.seed,
        }

    def project(self, x: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward ``x`` to ``sfrp.project`` using the stored parameters.

        Shapes:
            x:       [n?, d_original]
            ret/out: [n?, d_projection]

        ``out`` is optional and is passed through to the backend unchanged.
        """
        return self.sfrp.project(
            x,
            d_projection=self.params.d_projection,
            **self._shared_kwargs(),
            out=out,
        )

    def transposed_project(self, x: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor:
        """Forward ``x`` to ``sfrp.transposed_project`` using the stored parameters.

        Shapes:
            x:   [n?, d_projection]
            out: [n?, d_original]

        Unlike ``project``, ``out`` has no default here and ``d_projection``
        is not forwarded.
        """
        return self.sfrp.transposed_project(
            x,
            **self._shared_kwargs(),
            out=out,
        )

    #######################################################

    @classmethod
    def create(cls, **kwargs):
        """Construct an instance by forwarding ``kwargs`` to the constructor."""
        projector = cls(**kwargs)
        return projector
