"""Python convience wrapper around the C++ storage-free random projection code."""
from typing import Optional

import torch


###############################################################################

# Dense projection types (their forward kernels are the `project_dn_*`
# entries of _CC_METHOD_NAME_MAP below).
_DN_PROJECTION_TYPES = ('bernoulli_v2', 'bernoulli_v3_xorwow',)
# Sparse projection types; project() and transposed_project() refuse to run
# these unless a non-None sparse_region_size is passed.
_SP_PROJECTION_TYPES = ('sp_bernoulli_v1_xorwow', 'sp_bernoulli_v2_xorwow',)
# Every accepted projection type. Order matters: the first entry is the
# projection_type used when a caller passes None.
_PROJECTION_TYPES = (
    *_DN_PROJECTION_TYPES,
    *_SP_PROJECTION_TYPES,
)

# Allowed algorithms per projection type. The first entry of each tuple is the
# "default" algorithm, used when a caller passes algorithm=None.
_PROJECTION_TYPE_TO_ALGORITHMS = {
    'bernoulli_v2': ('alg3',),
    'bernoulli_v3_xorwow': ('alg3',),
    'sp_bernoulli_v1_xorwow': ('alg3',),
    'sp_bernoulli_v2_xorwow': ('alg1',),
}

# (projection_type, algorithm) -> name of the forward-projection function,
# looked up with getattr() on the compiled `build.sfrp_torch` extension.
_CC_METHOD_NAME_MAP = {
    ('bernoulli_v2', 'alg3'): 'project_dn_bernoulli_v2_alg3',
    #
    ('bernoulli_v3_xorwow', 'alg3'): 'project_dn_bernoulli_v3_xorwow_alg3',
    #
    ('sp_bernoulli_v1_xorwow', 'alg3'): 'project_sp_bernoulli_v1_xorwow_alg3',
    #
    ('sp_bernoulli_v2_xorwow', 'alg1'): 'project_sp_bernoulli_v2_xorwow_alg1',
}

# (projection_type, algorithm) -> name of the transposed-projection function on
# the compiled extension. Only the pairs listed here have a transposed kernel;
# any other pair cannot be used with transposed_project().
_CC_TRANSPOSED_METHOD_NAME_MAP = {
    # NOTE(review): this key uses algorithm 'alg3' but the extension function
    # name ends in '_alg1' — confirm against the names sfrp_torch exports.
    ('bernoulli_v3_xorwow', 'alg3'): 'transposed_project_dn_bernoulli_v3_xorwow_alg1',
    ('sp_bernoulli_v2_xorwow', 'alg1'): 'transposed_project_sp_bernoulli_v2_xorwow_alg1',
}

###############################################################################


def project(
    x: torch.Tensor,
    # Needed only when `out` is None; otherwise inferred from out.shape[-1].
    d_projection: Optional[int] = None,
    projection_type: Optional[str] = None,
    algorithm: Optional[str] = None,
    # Mandatory for sparse projection types; unused by dense ones.
    sparse_region_size: Optional[int] = None,
    *,
    seed: int,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Apply a storage-free random projection to ``x`` using the C++ extension.

    ``x`` is either a single vector ``[d_original]`` or a batch
    ``[n, d_original]``; the result is ``[d_projection]`` or
    ``[n, d_projection]`` respectively.

    Args:
        x: Input vector or matrix.
        d_projection: Output width. Required when ``out`` is None; otherwise
            taken from ``out.shape[-1]`` and, if also given, must equal it.
        projection_type: One of ``_PROJECTION_TYPES``; None selects the first
            entry.
        algorithm: Algorithm for ``projection_type``; None selects the first
            one listed in ``_PROJECTION_TYPE_TO_ALGORITHMS``.
        sparse_region_size: Must be non-None for the sparse projection types
            (``_SP_PROJECTION_TYPES``); ignored for dense ones.
        seed: Seed forwarded to the C++ kernel.
        out: Optional destination tensor. It must have the same rank as ``x``
            and, for a matrix input, the same number of rows. When omitted, a
            new tensor with ``x``'s dtype and device is allocated. The kernel's
            return value is ignored, so it presumably fills ``out`` in place.

    Returns:
        ``out`` (squeezed back to a vector when ``x`` was a vector).

    Raises:
        ValueError: Missing or inconsistent output width, unsupported rank,
            mismatched ``out`` shape, unknown projection type or algorithm,
            or a missing ``sparse_region_size`` for a sparse projection type.
    """
    from build import sfrp_torch as cc_sfrp_torch

    # Output width: reconcile the explicit argument with the trailing dim of `out`.
    if out is None:
        if d_projection is None:
            raise ValueError('Either d_projection or out must be provided.')
    elif d_projection is None:
        d_projection = out.shape[-1]
    elif d_projection != out.shape[-1]:
        raise ValueError('Mismatch between the provided d_projection and the value inferred from out.')

    # Rank: the kernel works on batches, so vectors become a one-row batch.
    n_dims = x.dim()
    if n_dims not in (1, 2):
        raise ValueError('The tensor must either be a vector or matrix.')
    output_vector = n_dims == 1

    if out is not None:
        if out.dim() != n_dims:
            if output_vector:
                raise ValueError('The tensor was a vector but the provided out tensor was not a vector.')
            raise ValueError('The tensor was a matrix but the provided out tensor was not a matrix.')
        if not output_vector and x.shape[0] != out.shape[0]:
            raise ValueError('Mismatch in first dimension of the arguments x and out.')

    if output_vector:
        x = x[None, :]
        if out is not None:
            out = out[None, :]

    # Resolve defaults, then validate against the lookup tables.
    if projection_type is None:
        projection_type = _PROJECTION_TYPES[0]
    if projection_type not in _PROJECTION_TYPES:
        raise ValueError(f'Invalid projection_type: {projection_type}')

    supported_algorithms = _PROJECTION_TYPE_TO_ALGORITHMS[projection_type]
    if algorithm is None:
        algorithm = supported_algorithms[0]
    if algorithm not in supported_algorithms:
        raise ValueError(f'Invalid algorithm for projection_type: {algorithm}')

    if out is None:
        out = torch.empty([x.shape[0], d_projection], dtype=x.dtype, device=x.device)

    kernel = getattr(cc_sfrp_torch, _CC_METHOD_NAME_MAP[(projection_type, algorithm)])

    # Sparse kernels take sparse_region_size between the seed and the output.
    is_sparse = projection_type in _SP_PROJECTION_TYPES
    if is_sparse and sparse_region_size is None:
        raise ValueError('A non-None sparse_region_size must be passed when using this projection type.')
    extra_args = (sparse_region_size,) if is_sparse else ()
    kernel(x, seed, *extra_args, out)

    return torch.squeeze(out, dim=0) if output_vector else out


###############################################################################


def transposed_project(
    x: torch.Tensor,
    # d_original must be provided if out is None.
    d_original: Optional[int] = None,
    projection_type: Optional[str] = None,
    algorithm: Optional[str] = None,
    # Required for sparse projection types; ignored for dense ones.
    sparse_region_size: Optional[int] = None,
    *,
    seed: int,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Apply the transposed random projection to ``x`` using the C++ extension.

    Maps ``x`` of shape ``[d_projection]`` or ``[n, d_projection]`` to shape
    ``[d_original]`` or ``[n, d_original]``. This is presumably the adjoint
    of ``project`` for the same seed and parameters (not verifiable from this
    wrapper).

    Only the ``(projection_type, algorithm)`` pairs listed in
    ``_CC_TRANSPOSED_METHOD_NAME_MAP`` have a transposed kernel. At the time of
    writing, the default projection type (``_PROJECTION_TYPES[0]``) is not one
    of them, so callers generally need to pass ``projection_type`` explicitly.

    Args:
        x: Input vector or matrix.
        d_original: Output width. Required when ``out`` is None; otherwise
            taken from ``out.shape[-1]`` and, if also given, must equal it.
        projection_type: One of ``_PROJECTION_TYPES``; None selects the first
            entry.
        algorithm: Algorithm for ``projection_type``; None selects the first
            one listed in ``_PROJECTION_TYPE_TO_ALGORITHMS``.
        sparse_region_size: Must be non-None for the sparse projection types
            (``_SP_PROJECTION_TYPES``); ignored for dense ones.
        seed: Seed forwarded to the C++ kernel.
        out: Optional destination tensor. It must have the same rank as ``x``
            and, for a matrix input, the same number of rows. When omitted, a
            new tensor with ``x``'s dtype and device is allocated. The kernel's
            return value is ignored, so it presumably fills ``out`` in place.

    Returns:
        ``out`` (squeezed back to a vector when ``x`` was a vector).

    Raises:
        ValueError: Missing or inconsistent output width, unsupported rank,
            mismatched ``out`` shape, unknown projection type or algorithm,
            no transposed kernel for the pair, or a missing
            ``sparse_region_size`` for a sparse projection type.
    """
    from build import sfrp_torch as cc_sfrp_torch

    if d_original is None and out is None:
        raise ValueError('Either d_original or out must be provided.')
    elif d_original is not None and out is not None:
        if d_original != out.shape[-1]:
            raise ValueError('Mismatch between the provided d_original and the value inferred from out.')
    elif d_original is None:
        d_original = out.shape[-1]

    if x.dim() == 1:
        output_vector = True
        # The kernel works on batches: treat a vector as a one-row batch.
        x = x[None, :]
        if out is not None:
            if out.dim() != 1:
                raise ValueError('The tensor was a vector but the provided out tensor was not a vector.')
            out = out[None, :]

    elif x.dim() == 2:
        output_vector = False
        if out is not None:
            if out.dim() != 2:
                raise ValueError('The tensor was a matrix but the provided out tensor was not a matrix.')
            if x.shape[0] != out.shape[0]:
                raise ValueError('Mismatch in first dimension of the arguments x and out.')

    else:
        raise ValueError('The tensor must either be a vector or matrix.')

    if projection_type is None:
        projection_type = _PROJECTION_TYPES[0]
    if projection_type not in _PROJECTION_TYPES:
        raise ValueError(f'Invalid projection_type: {projection_type}')

    if algorithm is None:
        algorithm = _PROJECTION_TYPE_TO_ALGORITHMS[projection_type][0]
    if algorithm not in _PROJECTION_TYPE_TO_ALGORITHMS[projection_type]:
        raise ValueError(f'Invalid algorithm for projection_type: {algorithm}')

    # Use .get(): most valid (projection_type, algorithm) pairs have no
    # transposed kernel, and subscripting would raise a bare KeyError instead
    # of the ValueError below. The getattr default additionally covers an
    # extension build that lacks the named function.
    method_name = _CC_TRANSPOSED_METHOD_NAME_MAP.get((projection_type, algorithm))
    method = None if method_name is None else getattr(cc_sfrp_torch, method_name, None)
    if method is None:
        raise ValueError('No transposed projection method currently exists for the provided projection_type and algorithm.')

    is_sparse = projection_type in _SP_PROJECTION_TYPES
    if is_sparse and sparse_region_size is None:
        raise ValueError('A non-None sparse_region_size must be passed when using this projection type.')

    # Allocate only after every validation has passed.
    if out is None:
        out = torch.empty([x.shape[0], d_original], dtype=x.dtype, device=x.device)

    if is_sparse:
        method(x, seed, sparse_region_size, out)
    else:
        method(x, seed, out)

    if output_vector:
        out = torch.squeeze(out, dim=0)

    return out

