import torch
from numba import cuda
from qtorch.config import QTORCH_CONFIG

from .utils import validateInput

def _z_on_statevector(psi:torch.Tensor, total_qubits:int, target_qubit:int
                      )->torch.Tensor:
    '''
    Applies the Z gate to |psi>

    Arguments
    ---------
    psi: torch.Tensor
        The statevector to apply Z to
    total_qubits: int
        The number of qubits that the statevector describes
    target_qubit: int
        The index of the target qubit
    
    Returns
    -------
    torch.Tensor:
        The statevector after applying Z to |psi>
    '''
    phi = psi.clone()
    I = torch.arange(2**total_qubits,device=psi.device)
    phi[I & (1 << target_qubit) > 0] *= -1
    return phi

def _z_on_densitymatrix(rho:torch.Tensor, total_qubits:int, target_qubit:int
                        )->torch.Tensor:
    '''
    Applies the Z gate to the density matrix rho

    Arguments
    ---------
    rho: torch.Tensor
        The density matrix to apply Z to
    total_qubits: int
        The number of qubits that the density matrix describes
    target_qubit: int
        The index of the target qubit
    
    Returns
    -------
    torch.Tensor:
        The statevector after applying Z to rho
    '''
    rho1 = rho.clone()
    I = torch.arange(2**total_qubits,device=rho.device)
    rho1[(I[:,None] ^ I[None,:]) & (1<<target_qubit) > 0] *= -1
    return rho1

def pauli_z(qs:torch.Tensor, total_qubits:int, target_qubit:int)->torch.Tensor:
    '''
    Applies the Z gate to the passed quantum state

    A 1-dimensional `qs` is treated as a statevector and a 2-dimensional
    `qs` as a density matrix. Input validation through `validateInput` is
    skipped when `QTORCH_CONFIG['skipValidation']` is truthy.

    Arguments
    ---------
    qs: torch.Tensor
        The quantum state to apply Z to
    total_qubits: int
        The number of qubits that the quantum state describes
    target_qubit: int
        The index of the target qubit

    Returns
    -------
    torch.Tensor:
        The quantum state after applying the Z gate

    Raises
    ------
    NotImplementedError
        - If `qs` is not a 1 or 2-dimensional tensor
    Any exception raised by `validateInput` (defined in `.utils`) propagates
    unchanged when validation is enabled.
    '''
    if not QTORCH_CONFIG['skipValidation']:
        validateInput(qs, total_qubits, target_qubit, num_targets=1)

    ndim = qs.dim()
    if ndim == 1:
        return _z_on_statevector(qs, total_qubits, target_qubit)
    if ndim == 2:
        return _z_on_densitymatrix(qs, total_qubits, target_qubit)
    raise NotImplementedError(
        'pauli_z supports 1-dimensional statevectors and 2-dimensional '
        f'density matrices, got a tensor with {ndim} dimensions'
    )

@cuda.jit(device=True)
def z_statevector_kernel(out:torch.Tensor, psi:torch.Tensor, T:int, 
                         i:int)->None:
    '''
    Device function computing the i-th amplitude of Z applied to a statevector

    The amplitude is copied from `psi` to `out`, negated when `i & T` is
    nonzero (i.e. when the target bit of basis index i is set).

    Arguments
    ---------
    out: torch.Tensor
        where the output of the operation is stored
        (NOTE(review): inside a numba device function this is presumably a
        CUDA device array view, not a torch.Tensor object — confirm at the
        launching kernel)
    psi: torch.Tensor
        the statevector to apply the operation on; the 2*i / 2*i+1 indexing
        implies each complex amplitude is stored as an interleaved
        (real, imag) pair of scalars
    T: int
        `1 << target_qubit`
    i: int
        The index of the amplitude to be calculated
    '''
    sign = -1 if (i & T) != 0 else 1
    re = 2 * i  # position of the real part; the imaginary part follows it
    out[re] = psi[re] * sign
    out[re + 1] = psi[re + 1] * sign

@cuda.jit(device=True)
def z_densitymatrix_kernel(out:torch.Tensor, rho:torch.Tensor, T:int, 
                           i:int, j:int)->None:
    '''
    Device function computing the (i,j)-th entry of Z applied to a density
    matrix.

    The entry is copied from `rho` to `out`, negated when the target bit
    (`T`) differs between the row index i and the column index j.

    Arguments
    ---------
    out: torch.Tensor
        where the output of the operation is stored
        (NOTE(review): inside a numba device function this is presumably a
        CUDA device array view, not a torch.Tensor object — confirm at the
        launching kernel)
    rho: torch.Tensor
        the density matrix to apply the operation on; the [i, 2*j] /
        [i, 2*j+1] indexing implies each complex entry is stored as an
        interleaved (real, imag) pair along the column axis
    T: int
        `1 << target_qubit`
    i, j: int, int
        the indices of the output density matrix to be calculated
    '''
    bits_differ = ((i ^ j) & T) != 0
    sign = -1 if bits_differ else 1
    col = 2 * j  # column of the real part; the imaginary part follows it
    out[i, col] = rho[i, col] * sign
    out[i, col + 1] = rho[i, col + 1] * sign
