import torch

from .statevector import _validateStatevector

from qtorch import CTYPE,RTYPE

def _validateDensityMatrix(num_qubits:int, data:torch.Tensor, **kwargs)->None:
    '''
    Runs a series of checks to make sure that the passed tensor is a
    valid density matrix

    Arguments
    ---------
    num_qubit: int
        The number of qubits
    data: torch.Tensor
        The data to validate
    
    Keyword Arguments
    -----------------
    atol: float
        absolute tolerance to be passed to torch.isclose(),
        defaults to 1e-8
    rtol: float
        relative tolerance to be passed to torch.isclose(),
        defaults to 1e-5
    
    Raises
    ------
    ValueError
        - data is not a square matrix
        - the size of data is not 2**num_qbits
        - the trace of data is not 1.0
        - data is not Hermitian
        - data has any negative eigenvalues (not positive semi-definite)
    '''
    if data.shape[0] != data.shape[1]:
        raise ValueError('Density Matrix data must a square matrix!')
    if data.shape[0] != 2**num_qubits:
        raise ValueError(f'Expected square matrix of size 2^{num_qubits}!')
    if not torch.isclose(torch.trace(data).abs(),torch.tensor(1.0,dtype=RTYPE)):
        raise ValueError('Density Matrix must have unit trace!')
    if not torch.isclose(data.t(), data.conj()).all():
        raise ValueError('Density Matrix must be Hermitian!')
    
    eigvals = torch.linalg.eigvalsh(data)
    atol = kwargs['atol'] if 'atol' in kwargs else 1e-8
    rtol = kwargs['rtol'] if 'rtol' in kwargs else 1e-5
    if not (eigvals >= 0).all():
        # Make sure any negative eigenvalues are close to zero
        if not torch.isclose(eigvals[torch.where(eigvals<0)], 
                             torch.tensor(0.0,dtype=RTYPE), atol=atol, rtol=rtol
                             ).all():
            raise ValueError(f'Density Matrix must be positive semi-definite!')

def _validateEnsemble(num_qubits:int, psis:torch.Tensor, 
                      ensemble_probabilities:torch.Tensor, **kwargs)->None:
    '''
    Runs a series of checks to validate that the ensemble can be turned
    into a density matrix

    Arguments
    ---------
    num_qubits: int
        The number of qubits
    psis: torch.Tensor
        Collection of the n state vectors in the ensemble
        Shape: [n, 2**num_qubits] 
    ensemble_probabilities: torch.Tensor
        The probabilities of each of the state vectors in the ensemble
        Shape: [n]
    
    Keyword Arguments (for validation)
    ----------------------------------
    atol: float
        absolute tolerance to be passed to torch.isclose(),
        defaults to 1e-8
    rtol: float
        relative tolerance to be passed to torch.isclose(),
        defaults to 1e-5
    
    Raises
    ------
    ValueError
        - if any probability is negative
        - if the sum of the probabilities in the ensemble is not 1.0
    '''
    if psis.shape[0] != ensemble_probabilities.shape[0]:
        raise ValueError(f'The number of state vectors in `psis` '
                         f'({psis.shape[0]}) does not match the number of '
                         'ensemble probabilities '
                         f'({ensemble_probabilities.shape[0]})')
    if ((ensemble_probabilities < 0.0).any() or 
        (ensemble_probabilities > 1.0).any()):
        raise ValueError(f'All probabilities must be in the range [0.0, 1.0]')
    atol = kwargs['atol'] if 'atol' in kwargs else 1e-8
    rtol = kwargs['rtol'] if 'rtol' in kwargs else 1e-5
    for i, psi in enumerate(psis):
        try:
            _validateStatevector(num_qubits, psi, atol=atol, rtol=rtol)
        except ValueError as e:
            e.add_note(f'Statevector index {i}')
            raise
    if not torch.isclose(torch.sum(ensemble_probabilities), 
                         torch.tensor(1.0,device=ensemble_probabilities.device),
                         atol=atol, rtol=rtol):
        raise ValueError('The total probability of the ensemble is not 1.0')

def getPureState(psi:torch.Tensor)->torch.Tensor:
    '''
    Builds the density matrix of the pure state |psi>

    Arguments
    ---------
    psi: torch.Tensor
        1-dimensional statevector tensor (torch.outer rejects other ranks)

    Returns
    -------
    torch.Tensor:
        The outer product |psi><psi|, whose (i, j) entry is
        psi[i] * conj(psi[j])
    '''
    ket = psi
    bra = psi.conj()
    projector = torch.outer(ket, bra)
    return projector

def DensityMatrix(num_qubits:int, data:torch.Tensor|None=None, 
                  ensemble_probabilities:torch.Tensor|None=None, 
                  **kwargs)->torch.Tensor:
    '''
    Validates and returns a density matrix from either a pure state's 
    statevector, matrix data or an ensemble of statevectors. If no data is 
    passed, the density matrix |0...0><0...0| is returned.

    Arguments
    ---------
    num_qubits: int
        The number of qubits that the density matrix describes
    data: torch.Tensor, optional
        - If `data` is a 1-dimensional tensor, it represents a pure state.
        - If `data` is a 2-dimensional tensor and `ensemble_probabilities` is 
        None, it contains the Z-basis representation of the density matrix.
        - If `data` is a 2-dimensional tensor and `ensemble_probabilities` is 
        not None, it contains a list of the statevectors in the ensemble
        (one statevector per row).
    ensemble_probabilities: torch.Tensor, optional
        The probabilities of the statevectors in the ensemble (shape [n]).
        Only used when `data` is 2-dimensional; ignored otherwise.
    
    Keyword Arguments (for validation)
    ----------------------------------
    atol: float
        absolute tolerance to be passed to torch.isclose(),
        defaults to 1e-8
    rtol: float
        relative tolerance to be passed to torch.isclose(),
        defaults to 1e-5
    
    Returns
    -------
    torch.Tensor:
        The Z-basis density matrix representation of the passed data,
        with dtype CTYPE.
    
    Raises
    ------
    TypeError
        - if `data` is not a tensor or not None
        - if `ensemble_probabilities` is not a tensor or not None
    ValueError
        - if `data` is not a 1 or 2-dimensional tensor
        - if validation of the matrix, statevector or ensemble fails
    '''
    if data is None:
        dim = 2**num_qubits
        rho = torch.zeros((dim, dim), dtype=CTYPE)
        rho[0, 0] = 1.0
        return rho
    if not isinstance(data, torch.Tensor):
        raise TypeError('`data` must be None or a tensor')

    if data.dim() == 2:
        if ensemble_probabilities is None:
            _validateDensityMatrix(num_qubits, data, **kwargs)
            rho = data
        elif isinstance(ensemble_probabilities, torch.Tensor):
            _validateEnsemble(num_qubits, data, ensemble_probabilities, 
                              **kwargs)
            # rho = sum_k p_k |psi_k><psi_k|
            # outers[k, i, j] = psi_k[i] * conj(psi_k[j])
            outers = data[:, :, None] * data.conj()[:, None, :]
            # Reshape to [n, 1, 1] so each p_k scales its whole outer
            # product. A bare [n] tensor would broadcast against the LAST
            # axis and scale matrix columns instead (wrong result when
            # n == 2**num_qubits, shape error otherwise).
            weights = ensemble_probabilities.reshape(-1, 1, 1)
            rho = torch.sum(weights * outers, dim=0)
        else:
            raise TypeError('`ensemble_probabilities` must be None or a '
                            'tensor')
    elif data.dim() == 1:
        _validateStatevector(num_qubits, data, **kwargs)
        rho = getPureState(data)
    else:
        raise ValueError('`data` must be a 1 or 2 dimensional tensor')

    if rho.dtype != CTYPE:
        rho = rho.to(CTYPE)
    return rho