import torch

from .bit_flip import BitFlipChannel
from .phase_flip import PhaseFlipChannel

def BitPhaseFlipChannel(rho:torch.Tensor, num_qubits:int, prob_x:float, 
                        prob_z:float|None=None, **kwargs)->torch.Tensor:
    '''
    Compose a per-qubit bit flip channel with a per-qubit phase flip channel.

    The bit flip channel (BitFlipChannel) is applied to every qubit of `rho`
    first. The phase flip channel (PhaseFlipChannel) is then applied to every
    qubit of that intermediate state.

    Arguments
    ---------
    rho: torch.Tensor
        The input quantum state. It is passed straight to BitFlipChannel,
        which determines the accepted representation.
    num_qubits: int
        Number of qubits in the system. It is forwarded to both channels.
    prob_x: float
        Per-qubit bit flip probability.
    prob_z: float, optional
        Per-qubit phase flip probability. Only an explicit `None` (the
        default) falls back to `prob_x`; `0.0` is honoured as-is.

    Keyword Arguments (for validation)
    ----------------------------------
    Every keyword argument is forwarded unchanged to BOTH BitFlipChannel and
    PhaseFlipChannel. The ones those channels are expected to accept are:
    atol: float
        absolute tolerance to be passed to torch.isclose(),
        expected to default to 1e-8 (torch.isclose's own default)
    rtol: float
        relative tolerance to be passed to torch.isclose(),
        expected to default to 1e-5 (torch.isclose's own default)
    NOTE(review): these defaults live in the sub-channels, not here; confirm
    against BitFlipChannel / PhaseFlipChannel.

    Returns
    -------
    torch.Tensor:
        The state after both channels have been applied.
    '''
    # Resolve the phase flip probability up front; `is None` (not
    # truthiness) so an explicit 0.0 still disables phase flips.
    phase_prob = prob_x if prob_z is None else prob_z

    # Bit flips first, then phase flips on the already bit-flipped state.
    after_bit_flip = BitFlipChannel(rho, num_qubits, prob_x, **kwargs)
    return PhaseFlipChannel(after_bit_flip, num_qubits, phase_prob, **kwargs)
