import torch
from numba import cuda

import math

from qtorch.config import QTORCH_CONFIG
from qtorch import RTYPE

from .utils import validateInput

def _rx_on_statevector(psi:torch.Tensor, theta:torch.Tensor, total_qubits:int, 
                       target_qubit:int)->torch.Tensor:
    '''
    Rotates a single qubit of the statevector |psi> about the X axis

    The gate is expanded as Rx(theta) = cos(theta/2) I - i sin(theta/2) X, so
    the result is a weighted sum of |psi> and of |psi> with the target bit of
    every basis index flipped.

    Arguments
    ---------
    psi: torch.Tensor
        The statevector to apply the gate to; indexed with basis indices
        0 .. 2**total_qubits - 1
    theta: torch.Tensor
        The rotation angle for the gate (passed to `torch.sin`/`torch.cos`,
        so it has to be a tensor here)
    total_qubits: int
        The number of qubits in the quantum state
    target_qubit: int
        The index of the target qubit (bit position in the basis index)
    
    Returns
    -------
    torch.Tensor:
        A new tensor holding Rx(theta)|psi>; `psi` itself is not modified
    '''
    half_angle = theta / 2
    flip_mask = 1 << target_qubit

    # partner[k] is basis index k with the target bit toggled
    basis = torch.arange(2**total_qubits, device=psi.device)
    partner = basis ^ flip_mask

    # X term: amplitude k lands on index k ^ flip_mask. The XOR map is a
    # permutation, so every entry of the uninitialised buffer gets written.
    rotated = torch.empty_like(psi)
    rotated[partner] = (-1j * torch.sin(half_angle)) * psi

    # Identity term (the inverse gate only flips the sign of the sin term)
    return rotated + torch.cos(half_angle) * psi
    

def _rx_on_densitymatrix(rho:torch.Tensor, theta:torch.Tensor, total_qubits:int,
                         target_qubit:int)->torch.Tensor:
    '''
    Conjugates the density matrix rho with the Rx gate

    With Rx = cos(theta/2) I - i sin(theta/2) X the conjugation expands to

        Rx rho Rx^dagger = cos^2(theta/2) rho + sin^2(theta/2) X rho X
                         + (i sin(theta) / 2) (rho X - X rho)

    Multiplying by X on the left/right only permutes rows/columns (the target
    bit of the row/column index is flipped), so no matrix product is needed.

    Arguments
    ---------
    rho: torch.Tensor
        The density matrix to apply the gate to; rows and columns are indexed
        with basis indices 0 .. 2**total_qubits - 1
    theta: torch.Tensor
        The rotation angle for the gate (passed to `torch.sin`/`torch.cos`,
        so it has to be a tensor here)
    total_qubits: int
        The number of qubits in the quantum state
    target_qubit: int
        The index of the target qubit (bit position in the basis index)
    
    Returns
    -------
    torch.Tensor:
        A new tensor holding the density matrix after the gate
    '''
    # flipped[k] is basis index k with the target bit toggled
    flipped = torch.arange(2**total_qubits, device=rho.device) ^ (1 << target_qubit)

    x_rho = rho[flipped, :]          # (X rho)[i, j]   = rho[i ^ T, j]
    rho_x = rho[:, flipped]          # (rho X)[i, j]   = rho[i, j ^ T]
    x_rho_x = x_rho[:, flipped]      # (X rho X)[i, j] = rho[i ^ T, j ^ T]

    half_angle = theta / 2
    return (
        torch.cos(half_angle)**2 * rho
        + torch.sin(half_angle)**2 * x_rho_x
        + (1j*torch.sin(theta)/2) * (rho_x - x_rho)
    )

def rx(qs:torch.Tensor, theta:float|torch.Tensor, total_qubits:int, 
       target_qubit:int)->torch.Tensor:
    '''
    Applies the Rx gate to the passed quantum state

    Dispatches on the rank of `qs`: a 1-dimensional tensor is treated as a
    statevector, a 2-dimensional tensor as a density matrix.

    Arguments
    ---------
    qs: torch.Tensor
        The quantum state to apply the gate to
    theta: torch.Tensor|float
        The rotation angle for the gate. A non-tensor value is converted to
        a 0-dim tensor of dtype `RTYPE` on the device of `qs`, regardless of
        whether validation is enabled.
    total_qubits: int
        The number of qubits in the quantum state
    target_qubit: int
        The index of the target qubit
    
    Returns
    -------
    torch.Tensor:
        The quantum state after applying the gate
    
    Raises
    ------
    AssertionError
        - If validation is enabled and `theta` is not a 0-dim tensor
    NotImplementedError
        - If `qs` is not a 1 or 2-dimensional tensor
    '''
    validate = not QTORCH_CONFIG['skipValidation']
    if validate:
        validateInput(qs, total_qubits, target_qubit, num_targets=1)

    # The promotion must not depend on the validation switch: the helpers
    # call torch.sin/torch.cos on theta, which reject plain Python numbers.
    if not isinstance(theta, torch.Tensor):
        theta = torch.tensor(theta,dtype=RTYPE,device=qs.device)

    if validate:
        assert theta.dim() == 0, '`theta` must be a scalar (0-dim tensor)'
    
    if qs.dim() == 1:
        return _rx_on_statevector(qs, theta, total_qubits, target_qubit)
    elif qs.dim() == 2:
        return _rx_on_densitymatrix(qs, theta, total_qubits,target_qubit)
    else:
        raise NotImplementedError(
            'rx only supports 1-dimensional (statevector) or 2-dimensional '
            f'(density matrix) states, got a {qs.dim()}-dimensional tensor'
        )

@cuda.jit(device=True)
def rx_statevector_kernel(out:torch.Tensor, psi:torch.Tensor, T:int, 
                          theta:float, i:int)->None:
    '''
    Device function computing amplitude `i` of Rx(theta) applied to a
    statevector

    Amplitude k is read from / written to the two consecutive entries
    `2*k` and `2*k+1`. The arithmetic matches an interleaved
    (real, imaginary) layout of cos(theta/2) psi[i] - i sin(theta/2) psi[i^T]
    -- presumably `psi`/`out` are real views of complex buffers; confirm
    against the launching kernel.

    Arguments
    ---------
    out: torch.Tensor
        where the two entries of output amplitude `i` are stored
    psi: torch.Tensor
        the statevector to apply the operation on
    T: int
        `1 << target_qubit`
    theta: float
        The angle passed to the gate
    i: int
        The index of the amplitude to be calculated
    '''
    half_angle = theta/2
    c = math.cos(half_angle)
    s = math.sin(half_angle)

    # Slots of amplitude i and of its partner (target bit flipped)
    own_lo = 2*i
    own_hi = own_lo + 1
    partner_lo = 2*(i^T)
    partner_hi = partner_lo + 1

    # -i * (a + ib) = b - ia: the partner's two entries swap roles
    out[own_lo] = c*psi[own_lo] + s*psi[partner_hi]
    out[own_hi] = c*psi[own_hi] - s*psi[partner_lo]

@cuda.jit(device=True)
def rx_densitymatrix_kernel(out:torch.Tensor, rho:torch.Tensor, T:int, 
                            theta:float, i:int, j:int)->None:
    '''
    Device function computing element (i,j) of Rx(theta) applied to a density
    matrix.

    Element (r, c) is read from / written to columns `2*c` and `2*c+1` of row
    `r`. The arithmetic matches an interleaved (real, imaginary) layout of

        cos^2(theta/2) rho[i,j] + sin^2(theta/2) rho[i^T,j^T]
        + (i sin(theta) / 2) (rho[i,j^T] - rho[i^T,j])

    -- presumably `rho`/`out` are real views of complex buffers; confirm
    against the launching kernel.

    Arguments
    ---------
    out: torch.Tensor
        where the two entries of output element (i,j) are stored
    rho: torch.Tensor
        the density matrix to apply the operation on
    T: int
        `1 << target_qubit`
    theta: float
        The angle passed to the gate
    i, j: int, int
        the indices of the output density matrix to be calculated
    '''
    half_angle = theta/2
    cos_sq = math.cos(half_angle)**2
    sin_sq = math.sin(half_angle)**2
    cross = math.sin(theta) / 2

    # Row / column slots with the target bit flipped
    i_flip = i ^ T
    col_lo = 2*j
    col_hi = col_lo + 1
    flip_lo = 2 * (j ^ T)
    flip_hi = flip_lo + 1

    # i * (a + ib) = -b + ia: the cross term mixes the two slots
    out[i, col_lo] = (
        cos_sq * rho[i, col_lo]
        + sin_sq * rho[i_flip, flip_lo]
        - cross * (rho[i, flip_hi] - rho[i_flip, col_hi])
    )
    out[i, col_hi] = (
        cos_sq * rho[i, col_hi]
        + sin_sq * rho[i_flip, flip_hi]
        + cross * (rho[i, flip_lo] - rho[i_flip, col_lo])
    )

@cuda.jit(device=True)
def rx_density_matrix_theta_grad_kernel(out:torch.Tensor, rho:torch.Tensor, 
                                        grad_output:torch.Tensor, T:int, 
                                        theta:float, i:int, j:int)->None:
    '''
    Device function computing the contribution of element (i,j) of the
    density matrix (rho) to the theta gradient of Rx

    The value written is the theta-derivative of the two output entries of
    element (i,j) (columns `2*j` and `2*j+1`), each weighted by the matching
    entry of `grad_output`. The derivative coefficients are
    d/dtheta cos^2(theta/2) = -sin(theta)/2,
    d/dtheta sin^2(theta/2) = +sin(theta)/2 and
    d/dtheta sin(theta)/2   =  cos(theta)/2.
    Presumably the caller reduces `out` over all (i,j) to obtain the scalar
    gradient -- confirm against the launching kernel.
    
    Arguments
    ---------
    out: torch.Tensor
        where the output of the operation is stored (one value at `[i,j]`)
    rho: torch.Tensor
        the input density matrix
    grad_output: torch.Tensor
        the gradient passed back with respect to the Rx operation
    T: int
        `1 << target_qubit`
    theta: float
        The angle passed to the gate
    i, j: int, int
        the index of the output contribution term to be calculated
    '''
    half_sin = math.sin(theta)/2
    half_cos = math.cos(theta)/2

    # Row / column slots with the target bit flipped
    i_flip = i^T
    col_lo = 2*j
    col_hi = col_lo + 1
    flip_lo = 2*(j^T)
    flip_hi = flip_lo + 1

    # Derivative of the first output entry of element (i,j)
    d_lo = (
        - half_sin * rho[i,col_lo]
        + half_sin * rho[i_flip,flip_lo]
        - half_cos * (rho[i,flip_hi] - rho[i_flip,col_hi])
    )
    # Derivative of the second output entry of element (i,j)
    d_hi = (
        - half_sin * rho[i,col_hi]
        + half_sin * rho[i_flip,flip_hi]
        + half_cos * (rho[i,flip_lo] - rho[i_flip,col_lo])
    )

    out[i,j] = grad_output[i,col_lo] * d_lo + grad_output[i,col_hi] * d_hi
