import torch

from qtorch.config import QTORCH_CONFIG
from qtorch.quantumstate.statevector import _validateStatevector
from qtorch.quantumstate.densitymatrix import _validateDensityMatrix
from qtorch.unitaries import pauli_x

def BitFlipChannel(qs:torch.Tensor, num_qubits:int, prob:float, **kwargs
                   )->torch.Tensor:
    '''
    Apply the bit flip channel to every qubit in the quantum state.

    The behaviour depends on the representation of ``qs``:

    * statevector (1-D tensor): a single stochastic trajectory is sampled.
      Each qubit is independently flipped with probability ``prob``, using
      draws from the global torch random number generator. Repeated calls
      can therefore return different states.
    * density matrix (2-D tensor): the channel is applied deterministically,
      qubit by qubit, as ``rho -> prob * pauli_x(rho, i) + (1 - prob) * rho``.

    The input tensor is cloned before any gate is applied.

    Arguments
    ---------
    qs: torch.Tensor
        The quantum state to apply the channel to (1-D statevector or
        2-D density matrix)
    num_qubits: int
        The number of qubits in the system
    prob: float
        The probability of each bit flip, expected to lie in [0, 1]

    Keyword Arguments (for validation)
    ----------------------------------
    atol: float
        absolute tolerance to be passed to torch.isclose(),
        defaults to 1e-8
    rtol: float
        relative tolerance to be passed to torch.isclose(),
        defaults to 1e-5

    Returns
    -------
    torch.Tensor:
        The resulting quantum state

    Raises
    ------
    ValueError
        If ``qs`` is not a 1- or 2-dimensional tensor. When validation is
        enabled (``QTORCH_CONFIG['skipValidation']`` is falsy), also raised
        if ``prob`` is outside [0, 1] or is NaN. The state validators may
        raise their own errors as well.
    '''
    if not QTORCH_CONFIG['skipValidation']:
        if qs.dim() == 1:
            _validateStatevector(num_qubits, qs, **kwargs)
        elif qs.dim() == 2:
            _validateDensityMatrix(num_qubits, qs, **kwargs)
        else:
            raise ValueError('Quantum state must be a 1 or 2-dimensional '
                             'tensor.')
        # Written as a negated range check so that NaN (for which every
        # comparison is False) is rejected rather than silently accepted.
        if not 0.0 <= prob <= 1.0:
            raise ValueError(f'Invalid probability: {prob} supplied for '
                             'bit flip channel.')
    if qs.dim() == 1:
        psi = qs.clone()
        # One uniform [0, 1) draw per qubit; a qubit is flipped when its
        # draw falls below prob, i.e. with probability exactly prob.
        flips = torch.rand(num_qubits) < prob
        for i in torch.where(flips)[0].tolist():
            psi = pauli_x(psi, num_qubits, i)
        return psi
    elif qs.dim() == 2:
        rho = qs.clone()
        # NOTE(review): assumes pauli_x applied to a density matrix performs
        # the conjugation X_i rho X_i -- confirm in qtorch.unitaries.
        for i in range(num_qubits):
            rho = prob*pauli_x(rho, num_qubits, i) + (1.0-prob)*rho
        return rho
    else:
        raise ValueError('Quantum state must be a 1 or 2-dimensional tensor.')