import torch

def getZDistribution(quantumstate:torch.Tensor)->torch.Tensor:
    '''
    Returns the Z-basis measurement probabilities of a quantum state

    A statevector |psi> yields |psi_i|^2 for every basis index i; a density
    matrix rho yields its (real) diagonal entries rho_ii.

    Arguments
    ---------
    quantumstate: torch.Tensor
        The Z-basis representation of a statevector (1-dimensional) or
        density matrix (2-dimensional) to measure

    Returns
    -------
    torch.Tensor:
        Z-basis measurement probabilities, one entry per basis state

    Raises
    ------
    ValueError
        -  `quantumstate` is not a 1 or 2-dimensional tensor
    '''
    rank = quantumstate.dim()
    if rank == 2:
        # Density matrix: the populations sit on the diagonal; any imaginary
        # part there is numerical noise, so keep only the real component.
        return torch.diagonal(quantumstate).real
    if rank == 1:
        # Statevector: Born rule, squared magnitude of each amplitude.
        return torch.abs(quantumstate) ** 2
    raise ValueError('`quantumstate` must be a 1 or 2 dimensional tensor')

def getExpectation(quantumstate:torch.Tensor, observable:torch.Tensor
                   )->torch.Tensor:
    '''
    Returns the expectation value of the observable on the quantum state

    The state and observable may have different dtypes (e.g. a complex
    statevector with a real-valued Pauli matrix, or a complex128 state with
    float32 eigenvalues); both are promoted to a common dtype before the
    contraction.

    Arguments
    ---------
    quantumstate: torch.Tensor
        The Z-basis representation of a statevector or density matrix to measure
    observable: torch.Tensor
        The Z-basis representation of the observable (Hermitian matrix), or if
        the observable is diagonalizable in the Z-basis, a 1-dimensional tensor
        containing the Z-basis eigenvalues.

    Returns
    -------
    torch.Tensor:
        The (real) expectation value of the observable on the quantum state

    Raises
    ------
    ValueError
        - `quantumstate` is not a 1 or 2-dimensional tensor
        - `observable` is not a 1 or 2-dimensional tensor
    '''
    # Validate in the same order as before: the state first, then the observable.
    if quantumstate.dim() not in (1, 2):
        raise ValueError('`quantumstate` must be a 1 or 2 dimensional tensor')
    if observable.dim() not in (1, 2):
        raise ValueError('`observable` must be a 1 or 2 dimensional tensor')

    if observable.dim() == 1:
        # Diagonal observable: <O> = sum_i p_i * lambda_i, with p_i the
        # Z-basis measurement probabilities of the state.
        if quantumstate.dim() == 1:
            probs = quantumstate.abs() ** 2
        else:
            probs = torch.diagonal(quantumstate).real
        # torch.dot does not promote dtypes, so align them explicitly.
        dtype = torch.promote_types(probs.dtype, observable.dtype)
        return torch.dot(probs.to(dtype), observable.to(dtype)).real

    # Full (Hermitian) matrix observable. matmul/vdot also require matching
    # dtypes, so promote both operands (e.g. complex state + real matrix).
    dtype = torch.promote_types(quantumstate.dtype, observable.dtype)
    state = quantumstate.to(dtype)
    obs = observable.to(dtype)
    if state.dim() == 1:
        # <psi|O|psi>; vdot conjugates its first argument.
        return torch.vdot(state, obs @ state).real
    # Tr(rho O) = sum_ij rho_ij O_ji: an O(n^2) elementwise contraction that
    # avoids materialising the full O(n^3) product rho @ O.
    return (state * obs.transpose(0, 1)).sum().real
