import torch
from numba import cuda

from qtorch.config import QTORCH_CONFIG

from .utils import validateInput


def _x_on_statevector(psi:torch.Tensor, total_qubits:int, target_qubit:int
                      )->torch.Tensor:
    '''
    Applies the X gate to |psi>

    Arguments
    ---------
    psi: torch.Tensor
        The statevector to apply X to
    total_qubits: int
        The number of qubits that the statevector describes
    target_qubit: int
        The index of the target qubit
    
    Returns
    -------
    torch.Tensor:
        The statevector after applying X to |psi>
    '''
    phi = torch.empty_like(psi)
    I = torch.arange(2**total_qubits,device=psi.device)
    phi[I ^ (1 << target_qubit)] = psi
    return phi

def _x_on_densitymatrix(rho:torch.Tensor, total_qubits:int, target_qubit:int
                        )->torch.Tensor:
    '''
    Applies the X gate to the density matrix rho

    Arguments
    ---------
    rho: torch.Tensor
        The density matrix to apply X to
    total_qubits: int
        The number of qubits that the density matrix describes
    target_qubit: int
        The index of the target qubit
    
    Returns
    -------
    torch.Tensor:
        The statevector after applying X to rho
    '''
    rho1 = torch.empty_like(rho)
    I = torch.arange(2**total_qubits,device=rho.device)
    swap_indices = I ^ (1 << target_qubit)
    rho1 = rho[swap_indices[:,None], swap_indices[None,:]]
    return rho1

def pauli_x(qs:torch.Tensor, total_qubits:int, target_qubit:int)->torch.Tensor:
    '''
    Applies the X gate to the passed quantum state

    A 1-dimensional `qs` is treated as a statevector and a 2-dimensional `qs`
    as a density matrix. Unless QTORCH_CONFIG['skipValidation'] is truthy,
    the arguments are first checked with `validateInput`.

    Arguments
    ---------
    qs: torch.Tensor
        The quantum state to apply X to
    total_qubits: int
        The number of qubits that the quantum state describes
    target_qubit: int
        The index of the target qubit

    Returns
    -------
    torch.Tensor:
        The quantum state after applying the X gate

    Raises
    ------
    NotImplementedError
        - If `qs` is not a 1 or 2-dimensional tensor
    '''
    if not QTORCH_CONFIG['skipValidation']:
        validateInput(qs, total_qubits, target_qubit, num_targets=1)

    ndim = qs.dim()
    if ndim == 1:
        return _x_on_statevector(qs, total_qubits, target_qubit)
    if ndim == 2:
        return _x_on_densitymatrix(qs, total_qubits, target_qubit)
    # Report the offending rank so callers can tell what was passed in
    raise NotImplementedError(
        'pauli_x expects a 1-dimensional statevector or a 2-dimensional '
        f'density matrix, got a tensor with {ndim} dimensions')

@cuda.jit(device=True)
def x_statevector_kernel(out:torch.Tensor, psi:torch.Tensor, T:int, 
                         i:int)->None:
    '''
    Device function that writes the i-th output amplitude of X|psi>

    Amplitude i of the result is copied from amplitude (i XOR T) of the input.
    Each amplitude occupies two consecutive entries, at positions 2*index and
    2*index + 1. NOTE(review): this presumably means `out` and `psi` are flat
    real buffers holding (real, imag) pairs; confirm with the launching kernel.

    Arguments
    ---------
    out: torch.Tensor
        where the output of the operation is stored (NOTE(review): inside a
        numba device function this is presumably a device array - confirm)
    psi: torch.Tensor
        the statevector to apply the operation on
    T: int
        `1 << target_qubit`
    i: int
        The index of the amplitude to be calculated
    '''
    src = 2 * (i ^ T)
    dst = 2 * i
    out[dst] = psi[src]
    out[dst + 1] = psi[src + 1]

@cuda.jit(device=True)
def x_densitymatrix_kernel(out:torch.Tensor, rho:torch.Tensor, T:int, 
                           i:int, j:int)->None:
    '''
    Device function that writes the (i,j)-th entry of X rho X

    Entry (i, j) of the result is copied from entry (i XOR T, j XOR T) of the
    input. Along the column axis each entry occupies two consecutive slots,
    at positions 2*col and 2*col + 1. NOTE(review): this presumably means
    columns interleave (real, imag) pairs; confirm with the launching kernel.

    Arguments
    ---------
    out: torch.Tensor
        where the output of the operation is stored
    rho: torch.Tensor
        the density matrix to apply the operation on
    T: int
        `1 << target_qubit`
    i, j: int, int
        the indices of the output density matrix to be calculated
    '''
    src_row = i ^ T
    src_col = 2 * (j ^ T)
    dst_col = 2 * j
    out[i, dst_col] = rho[src_row, src_col]
    out[i, dst_col + 1] = rho[src_row, src_col + 1]
