import torch
from numba import cuda
from qtorch.config import QTORCH_CONFIG

from .utils import validateInput

def _y_on_statevector(psi:torch.Tensor, total_qubits:int, target_qubit:int
                      )->torch.Tensor:
    '''
    Computes Y|psi> for a single target qubit.

    Y is decomposed as Y = -i Z X: the amplitudes are first permuted by
    flipping the target bit of their basis index (X), the amplitudes whose
    index has the target bit set are negated (Z), and finally the whole
    vector is scaled by -i.

    Arguments
    ---------
    psi: torch.Tensor
        The statevector that Y acts on
    total_qubits: int
        Number of qubits described by the statevector
    target_qubit: int
        Index of the qubit Y acts on (bit `target_qubit` of the basis index)

    Returns
    -------
    torch.Tensor:
        Y|psi>, as a new tensor
    '''
    bit = 1 << target_qubit
    basis = torch.arange(2**total_qubits, device=psi.device)

    flipped = torch.empty_like(psi)
    flipped[basis ^ bit] = psi          # X: amplitude b moves to b ^ bit

    target_is_one = (basis & bit) > 0
    flipped[target_is_one] *= -1        # Z: phase flip where the bit is 1

    return -1j * flipped                # Y = -iZX

def _y_on_densitymatrix(rho:torch.Tensor, total_qubits:int, target_qubit:int
                        )->torch.Tensor:
    '''
    Applies the Y gate to the density matrix rho, i.e. computes Y rho Y^dagger

    Since Y = -iZX, Y rho Y^dagger = (-iZX) rho (iXZ) = Z X rho X Z: the
    global phases cancel, so only a permutation and a sign pattern remain.

    Arguments
    ---------
    rho: torch.Tensor
        The density matrix to apply Y to
    total_qubits: int
        The number of qubits that the density matrix describes
    target_qubit: int
        The index of the target qubit

    Returns
    -------
    torch.Tensor:
        The density matrix after applying Y to rho (a new tensor; rho is not
        modified)
    '''
    mask = 1 << target_qubit
    I = torch.arange(2**total_qubits, device=rho.device)
    swap_indices = I ^ mask
    # X rho X: permute rows and columns by flipping the target bit.
    # Advanced indexing returns a new tensor, so the in-place sign flip below
    # does not touch rho.
    rho1 = rho[swap_indices[:, None], swap_indices[None, :]]
    # Z (.) Z: entry (i, j) picks up a factor of -1 exactly when i and j
    # differ on the target bit.
    rho1[((I[:, None] ^ I[None, :]) & mask) > 0] *= -1
    return rho1

def pauli_y(qs:torch.Tensor, total_qubits:int, target_qubit:int)->torch.Tensor:
    '''
    Applies the Y gate to the passed quantum state

    A 1-dimensional `qs` is treated as a statevector and a 2-dimensional `qs`
    as a density matrix.

    Arguments
    ---------
    qs: torch.Tensor
        The quantum state to apply Y to
    total_qubits: int
        The number of qubits that the quantum state describes
    target_qubit: int
        The index of the target qubit

    Returns
    -------
    torch.Tensor:
        The quantum state after applying the Y gate

    Raises
    ------
    NotImplementedError
        - If `qs` is not a 1 or 2-dimensional tensor
    '''
    if not QTORCH_CONFIG['skipValidation']:
        validateInput(qs, total_qubits, target_qubit, num_targets=1)

    if qs.dim() == 1:
        return _y_on_statevector(qs, total_qubits, target_qubit)
    if qs.dim() == 2:
        return _y_on_densitymatrix(qs, total_qubits, target_qubit)
    # Name the offending rank so callers can see why dispatch failed.
    raise NotImplementedError(
        f'pauli_y expects a 1-D statevector or a 2-D density matrix, '
        f'got a tensor with {qs.dim()} dimensions')

@cuda.jit(device=True)
def y_statevector_kernel(out:torch.Tensor, psi:torch.Tensor, T:int, 
                         i:int)->None:
    '''
    Device function to calculate the i-th term of applying Y on a statevector

    Uses Y = -iZX, matching the torch implementation of Y on statevectors:
    (Y psi)[i] = -i * s * psi[i ^ T], where s = -1 if bit T of i is set and
    +1 otherwise.

    Amplitude k appears to be stored as the pair (psi[2*k], psi[2*k+1]),
    presumably (real, imag) — NOTE(review): confirm against the caller.

    Arguments
    ---------
    out: torch.Tensor
        where the output of the operation is stored
    psi: torch.Tensor
        the statevector to apply the operation on
    T: int
        `1 << target_qubit`
    i: int
        The index of the amplitude to be calculated
    '''
    s = 1 - 2 * int( (i&T) != 0)
    j = i ^ T
    re = psi[2*j]
    im = psi[2*j+1]
    # Multiplying by -i maps (re + i*im) to (im - i*re); the previous version
    # omitted this factor and produced Z X psi = i Y psi.
    out[2*i]   = im * s
    out[2*i+1] = -re * s

@cuda.jit(device=True)
def y_densitymatrix_kernel(out:torch.Tensor, rho:torch.Tensor, T:int, 
                           i:int, j:int)->None:
    '''
    Device function computing entry (i, j) of Y rho Y^dagger.

    Conjugation by Y = -iZX reduces to Z X rho X Z (the phases cancel), so
    the entry is read from rho[i ^ T, j ^ T] and negated exactly when i and j
    differ on the target bit.

    Each column j of the logical matrix occupies the two adjacent entries
    2*j and 2*j+1 of a row (presumably the real and imaginary parts —
    NOTE(review): confirm against the caller).

    Arguments
    ---------
    out: torch.Tensor
        destination for the computed entry
    rho: torch.Tensor
        the density matrix being transformed
    T: int
        `1 << target_qubit`
    i, j: int, int
        row and (logical) column of the output entry
    '''
    src_row = i ^ T
    src_col = 2 * (j ^ T)
    sign = -1 if (i ^ j) & T else 1
    out[i, 2*j] = rho[src_row, src_col] * sign
    out[i, 2*j + 1] = rho[src_row, src_col + 1] * sign
