import torch

def fidelity(a:torch.Tensor, b:torch.Tensor)->torch.Tensor:
    '''
    Compute the fidelity between two quantum states.

    The rank of each tensor selects how it is interpreted: a 1-dimensional
    tensor is handled as a statevector and a 2-dimensional tensor as a
    density matrix. The call is then forwarded to the matching private
    helper (statevector/statevector, statevector/density matrix or
    density matrix/density matrix).

    Arguments
    ---------
    a: torch.Tensor
        A Z-basis representation of a statevector or density matrix
    b: torch.Tensor
        A Z-basis representation of a statevector or density matrix

    Returns
    -------
    torch.Tensor:
        Fidelity between a and b

    Raises
    ------
    ValueError:
        - when a or b are not 1 or 2-dimensional tensors
    '''

    def _checked_rank(state:torch.Tensor)->int:
        '''Return the rank of state, rejecting anything but 1 or 2'''
        rank = state.dim()
        if rank not in (1, 2):
            raise ValueError('Quantum state tensor must either be 1 or 2 '
                             'dimensional!')
        return rank

    # a is validated before b is inspected at all.
    rank_a = _checked_rank(a)
    rank_b = _checked_rank(b)

    if rank_a == rank_b:
        if rank_a == 1:
            return _fidelity_sv_sv(a, b)
        return _fidelity_dm_dm(a, b)

    # Mixed case: the helper always takes the statevector first.
    if rank_a == 1:
        return _fidelity_sv_dm(a, b)
    return _fidelity_sv_dm(b, a)

def _fidelity_sv_sv(psi:torch.Tensor,phi:torch.Tensor)->torch.Tensor:
    '''
    Fidelity between state vectors psi and phi

    Evaluated as the squared magnitude of the inner product of psi and
    phi, where the inner product is taken with torch.vdot.
    '''
    overlap = torch.vdot(psi, phi)
    return overlap.abs().pow(2)

def _fidelity_sv_dm(psi:torch.Tensor,rho:torch.Tensor)->torch.Tensor:
    '''
    Fidelity between statevector psi and density matrix rho

    Evaluated as the real part of the inner product of psi with
    rho applied to psi.
    '''
    rho_psi = torch.matmul(rho, psi)
    expectation = torch.vdot(psi, rho_psi)
    return expectation.real

def _fidelity_dm_dm(rho:torch.Tensor, sigma:torch.Tensor)->torch.Tensor:
    '''
    Fidelity between density matrices rho and sigma

    Builds the matrix square root of rho from its eigendecomposition,
    forms sqrt(rho) @ sigma @ sqrt(rho), and returns the square of the
    sum of the square roots of that product's eigenvalues.
    '''
    # Matrix square root of rho: take the root of each eigenvalue and
    # rebuild the matrix in the eigenbasis returned by eigh.
    spectrum, basis = torch.linalg.eigh(rho)
    # Negative eigenvalues are clamped to zero before the root is taken,
    # which forces the result to be positive semi-definite.
    root_spectrum = spectrum.clamp(min=0.0).sqrt()
    # Scaling column j of the basis by the j-th root eigenvalue.
    scaled_basis = basis * root_spectrum[None, :]
    root_rho = torch.matmul(scaled_basis, basis.conj().T)

    sandwiched = torch.matmul(torch.matmul(root_rho, sigma), root_rho)

    # Eigenvalues are clamped at zero here as well so the square root
    # below never sees a negative value.
    overlaps = torch.linalg.eigvalsh(sandwiched).clamp(min=0.0)
    return overlaps.sqrt().sum().pow(2)
