import torch
from numba import cuda
from qtorch.config import QTORCH_CONFIG

from .utils import validateInput

def _cnot_on_statevector(psi:torch.Tensor, total_qubits:int, target_qubit:int, 
                         control_qubit:int)->torch.Tensor:
    '''
    Applies the CNOT gate to |psi>

    Arguments
    ---------
    psi: torch.Tensor
        The statevector to apply CNOT to
    total_qubits: int
        The number of qubits that the statevector describes
    target_qubit: int
        The index of the target qubit
    control_qubit: int
        The index of the control qubit
    
    Returns
    -------
    torch.Tensor:
        The statevector after applying CNOT to |psi>
    '''
    phi = psi.clone()
    I = torch.arange(2**total_qubits,device=psi.device)
    active_mask = (I & (1 << control_qubit)) != 0
    swap_indices = I ^ (1 << target_qubit)
    phi[active_mask] = psi[swap_indices][active_mask]
    return phi

def _cnot_on_densitymatrix(rho:torch.Tensor, total_qubits:int, target_qubit:int,
                           control_qubit:int)->torch.Tensor:
    '''
    Applies the CNOT gate to the density matrix rho

    Arguments
    ---------
    rho: torch.Tensor
        The density matrix to apply CNOT to
    total_qubits: int
        The number of qubits that the density matrix describes
    target_qubit: int
        The index of the target qubit
    control_qubit: int
        The index of the control qubit
    
    Returns
    -------
    torch.Tensor:
        The density matrix after applying CNOT to rho
    '''
    I = torch.arange(2**total_qubits, device=rho.device)
    active_mask = I & (1 << control_qubit)
    active_indices = torch.where(active_mask)[0]
    passive_indices = torch.where(torch.logical_not(active_mask))[0]
    swap_indices = active_indices ^ (1 << target_qubit)

    rho1 = torch.empty_like(rho)
    
    rho1[passive_indices[:,None], 
         passive_indices[None,:]] = rho[passive_indices[:,None], 
                                        passive_indices[None,:]]
    
    rho1[passive_indices[:,None], 
         active_indices[None,:]] = rho[passive_indices[:,None], 
                                       swap_indices[None,:]]
    
    rho1[active_indices[:,None], 
         passive_indices[None,:]] = rho[swap_indices[:,None], 
                                        passive_indices[None,:]]
    
    rho1[active_indices[:,None], 
         active_indices[None,:]] = rho[swap_indices[:,None], 
                                       swap_indices[None,:]]

    return rho1

def cnot(qs:torch.Tensor, total_qubits:int, target_qubit:int, control_qubit:int
         )->torch.Tensor:
    '''
    Applies the CNOT gate to the passed quantum state

    The representation is chosen from the rank of `qs`:
    - a 1-dimensional tensor is treated as a statevector
    - a 2-dimensional tensor is treated as a density matrix

    Arguments
    ---------
    qs: torch.Tensor
        The quantum state to apply CNOT to
    total_qubits: int
        The number of qubits that the quantum state describes
    target_qubit: int
        The index of the target qubit
    control_qubit: int
        The index of the control qubit
    
    Returns
    -------
    torch.Tensor:
        The quantum state after applying the CNOT gate
    
    Raises
    ------
    NotImplementedError
        - If `qs` is not a 1 or 2-dimensional tensor
    '''
    # Validation can be switched off globally via QTORCH_CONFIG
    if not QTORCH_CONFIG['skipValidation']:
        validateInput(qs, total_qubits, target_qubit, control_qubit)

    if qs.dim() == 1:
        return _cnot_on_statevector(qs, total_qubits, target_qubit,
                                    control_qubit)
    if qs.dim() == 2:
        return _cnot_on_densitymatrix(qs, total_qubits, target_qubit,
                                      control_qubit)
    raise NotImplementedError(
        f"cnot expects a 1-D statevector or a 2-D density matrix, "
        f"got a {qs.dim()}-D tensor")

@cuda.jit(device=True)
def cnot_statevector_kernel(out:torch.Tensor, psi:torch.Tensor, T:int, C:int, 
                            i:int)->None:
    '''
    Device function writing amplitude i of CNOT |psi> into `out`

    Amplitude k occupies the two adjacent slots 2*k and 2*k+1 of `psi` and
    `out` (presumably real and imaginary parts; confirm against the caller).
    Amplitude i of the result is amplitude i ^ T of psi when `i & C` is
    nonzero, and amplitude i otherwise.

    Arguments
    ---------
    out: torch.Tensor
        where the output of the operation is stored
    psi: torch.Tensor
        the statevector to apply the operation on
    T: int
        `1 << target_qubit`
    C: int
        `1 << control_qubit`
    i: int
        The index of the amplitude to be calculated
    '''
    # Flip the target bit only when the control bit is set
    flip = T if (i & C) != 0 else 0
    src = i ^ flip
    dst = 2 * i
    out[dst] = psi[2 * src]
    out[dst + 1] = psi[2 * src + 1]

@cuda.jit(device=True)
def cnot_density_matrix_kernel(out:torch.Tensor, rho:torch.Tensor, T:int, C:int, 
                               i:int, j:int):
    '''
    Device function writing entry (i, j) of CNOT rho CNOT^dagger into `out`

    Entry (r, c) occupies the two adjacent slots 2*c and 2*c+1 of row r
    (presumably real and imaginary parts; confirm against the caller). The
    result at (i, j) is rho at (p(i), p(j)), where p(k) = k ^ T if `k & C` is
    nonzero and p(k) = k otherwise.

    Arguments
    ---------
    out: torch.Tensor
        where the output of the operation is stored
    rho: torch.Tensor
        the density matrix to apply the operation on
    T: int
        `1 << target_qubit`
    C: int
        `1 << control_qubit`
    i, j: int, int
        the indices of the output density matrix to be calculated
    '''
    src_row = i ^ (T if (i & C) != 0 else 0)
    src_col = j ^ (T if (j & C) != 0 else 0)
    # Copy both adjacent slots that make up entry (src_row, src_col)
    for part in range(2):
        out[i, 2 * j + part] = rho[src_row, 2 * src_col + part]
