import torch

from qtorch import CTYPE

def _validateStatevector(num_qubits:int, psi:torch.Tensor, **kwargs)->None:
    '''
    Runs a series of checks to ensure that the statevector psi is valid.

    Arguments
    ---------
    num_qubits: int
        The number of qubits
    psi: torch.Tensor
        The statevector tensor to validate
    
    Keyword Arguments
    -----------------
    atol: float
        absolute tolerance to be passed to torch.isclose(),
        defaults to 1e-8
    rtol: float
        relative tolerance to be passed to torch.isclose(),
        defaults to 1e-5
    
    Raises
    ------
    ValueError
        - psi is not a one-dimensional tensor
        - The length of psi does not equal 2**num_qubits
        - psi is not normalized (its 2-norm is not close to 1)
    '''
    if psi.dim() != 1:
        raise ValueError('`psi` should be a one-dimensional tensor.')
    expected_len = 2**num_qubits
    if psi.shape[0] != expected_len:
        raise ValueError(f'Expected statevector length {expected_len} for '
                         f'{num_qubits} qubits, received ({psi.shape[0]})')
    atol = kwargs.get('atol', 1e-8)
    rtol = kwargs.get('rtol', 1e-5)
    norm = psi.norm()
    # torch.isclose requires both operands to share a dtype. The norm of a
    # complex128/float64 psi is float64, so derive the reference value from
    # the norm itself (same dtype and device) rather than from the default
    # float dtype.
    if not torch.isclose(norm, torch.ones_like(norm), atol=atol, rtol=rtol):
        raise ValueError('The Statevector is not normalized!')

def Statevector(num_qubits:int, data:torch.Tensor|None=None, **kwargs
                )->torch.Tensor:
    '''
    Build an n-qubit statevector as a CTYPE tensor.

    When `data` is given it is converted to CTYPE (if needed) and checked
    with `_validateStatevector`. When `data` is omitted, the computational
    basis state |0...0> of length 2**num_qubits is returned instead.

    Arguments
    ---------
    num_qubits: int
        The number of qubits the statevector describes
    data: torch.Tensor, optional
        Amplitudes to validate; non-tensor inputs are passed to
        torch.tensor() with dtype=CTYPE
    
    Keyword Arguments (for validation)
    ----------------------------------
    atol: float
        absolute tolerance forwarded to torch.isclose(),
        defaults to 1e-8
    rtol: float
        relative tolerance forwarded to torch.isclose(),
        defaults to 1e-5
    
    Returns
    -------
    torch.Tensor:
        The Z-basis statevector representation of the passed data.
    '''
    if data is not None:
        if isinstance(data, torch.Tensor):
            # only cast when the dtype differs, so a CTYPE input is returned as-is
            vec = data if data.dtype == CTYPE else data.to(dtype=CTYPE)
        else:
            vec = torch.tensor(data, dtype=CTYPE)
        _validateStatevector(num_qubits, vec, **kwargs)
        return vec
    ground = torch.zeros(2**num_qubits, dtype=CTYPE)
    ground[0] = 1.0
    return ground
