"""Python convience wrapper around the C++ storage-free random projection code."""
from typing import Optional

import torch

from sfrp_torch.sfrp_cc_module_wrapper import sfrp_torch_cc


###############################################################################

# Algorithms available for each projection type. Within each tuple, the FIRST
# algorithm is the default used when the caller does not choose one. The key
# order also matters: the first key is the default projection type.
_PROJECTION_TYPE_TO_ALGORITHMS = {
    'hypercubic_v1': ('alg3', 'alg2', 'alg1'),
    'hypercubic_v2': ('alg3',),
}

# All supported projection types, in declaration order (index 0 is the default).
_PROJECTION_TYPES = tuple(_PROJECTION_TYPE_TO_ALGORITHMS)

# (projection_type, algorithm) -> name of the function looked up on the
# sfrp_torch_cc module. Names follow "rand_proj_<projection_type>_<algorithm>".
# Algorithms are visited in sorted order so the map iterates as
# v1/alg1, v1/alg2, v1/alg3, v2/alg3.
_CC_METHOD_NAME_MAP = {
    (proj_type, alg): f'rand_proj_{proj_type}_{alg}'
    for proj_type, algs in _PROJECTION_TYPE_TO_ALGORITHMS.items()
    for alg in sorted(algs)
}

###############################################################################


def project(
    x: torch.Tensor,
    # d_projection must be provided if out is None.
    d_projection: Optional[int] = None,
    projection_type: Optional[str] = None,
    algorithm: Optional[str] = None,
    *,
    seed: int,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Compute a storage-free random projection of ``x`` via the C++ extension.

    Args:
        x: Input of shape ``[d_og]`` (a single vector) or ``[n, d_og]``
            (a batch of row vectors).
        d_projection: Output dimension. Required when ``out`` is None; when
            both are given it must equal ``out.shape[-1]``.
        projection_type: One of ``_PROJECTION_TYPES``; defaults to the first
            entry (``'hypercubic_v1'``).
        algorithm: One of the algorithms listed for ``projection_type`` in
            ``_PROJECTION_TYPE_TO_ALGORITHMS``; defaults to the first entry.
        seed: Seed forwarded unchanged to the C++ routine.
        out: Optional preallocated output of shape ``[d_projection]`` (vector
            input) or ``[n, d_projection]`` (matrix input). It is handed to the
            C++ routine, which presumably fills it in place.
            NOTE(review): its dtype/device are not checked against ``x`` here.

    Returns:
        The projection, shaped ``[d_projection]`` for a vector input and
        ``[n, d_projection]`` for a matrix input. When ``out`` is None, the
        result is freshly allocated with ``x``'s dtype and device; otherwise it
        is ``out`` (or a view of it).

    Raises:
        ValueError: If neither ``d_projection`` nor ``out`` is given, or they
            disagree; if ``x`` is not 1-D or 2-D; if ``out``'s rank or leading
            dimension does not match ``x``; or if ``projection_type`` or
            ``algorithm`` is unknown.
    """
    if d_projection is None and out is None:
        raise ValueError('Either d_projection or out must be provided.')

    # Validate ranks BEFORE reading out.shape[-1]: a 0-dim out would otherwise
    # escape as an IndexError instead of a ValueError.
    if x.dim() == 1:
        output_vector = True
        if out is not None and out.dim() != 1:
            raise ValueError('The tensor was a vector but the provided out tensor was not a vector.')
    elif x.dim() == 2:
        output_vector = False
        if out is not None:
            if out.dim() != 2:
                raise ValueError('The tensor was a matrix but the provided out tensor was not a matrix.')
            if x.shape[0] != out.shape[0]:
                raise ValueError('Mismatch in first dimension of the arguments x and out.')
    else:
        raise ValueError('The tensor must either be a vector or matrix.')

    # out now has rank >= 1, so its last dimension can be read safely.
    if out is not None:
        if d_projection is None:
            d_projection = out.shape[-1]
        elif d_projection != out.shape[-1]:
            raise ValueError('Mismatch between the provided d_projection and the value inferred from out.')

    if projection_type is None:
        projection_type = _PROJECTION_TYPES[0]
    if projection_type not in _PROJECTION_TYPES:
        raise ValueError(f'Invalid projection_type: {projection_type}')

    algorithms = _PROJECTION_TYPE_TO_ALGORITHMS[projection_type]
    if algorithm is None:
        algorithm = algorithms[0]
    if algorithm not in algorithms:
        raise ValueError(f'Invalid algorithm for projection_type: {algorithm}')

    # The C++ routine is always handed 2-D tensors: promote a vector input (and
    # its out tensor) to a single-row batch. unsqueeze returns views, so writes
    # into the promoted out still land in the caller's tensor.
    if output_vector:
        x = x.unsqueeze(0)
        if out is not None:
            out = out.unsqueeze(0)

    if out is None:
        out = torch.empty([x.shape[0], d_projection], dtype=x.dtype, device=x.device)

    method = getattr(sfrp_torch_cc, _CC_METHOD_NAME_MAP[(projection_type, algorithm)])
    method(x, seed, out)

    # Undo the single-row promotion for vector inputs.
    return out.squeeze(0) if output_vector else out
