import torch
from numba import cuda
from qtorch.config import QTORCH_CONFIG
from qtorch import RTYPE

import math

from .utils import validateInput

def _rz_on_statevector(psi:torch.Tensor, theta:torch.Tensor, total_qubits:int, 
                       target_qubit:int)->torch.Tensor:
    '''
    Returns Rz(theta)|psi> for a single target qubit, using the decomposition
    Rz = cos(theta/2) I - i sin(theta/2) Z.

    Arguments
    ---------
    psi: torch.Tensor
        The statevector to rotate. It is indexed with a boolean mask of
        length `2**total_qubits`; it is not modified (a copy is made).
    theta: torch.Tensor
        The rotation angle. It is passed to `torch.sin` / `torch.cos`, so it
        has to be a tensor here (the public `rz` wrapper converts floats).
    total_qubits: int
        The number of qubits in the quantum state
    target_qubit: int
        The index of the target qubit; amplitude `k` belongs to the |1>
        branch of the target when bit `target_qubit` of `k` is set.
    
    Returns
    -------
    torch.Tensor:
        A new tensor holding the rotated statevector
    '''
    target_bit = 1 << target_qubit
    basis_index = torch.arange(2**total_qubits, device=psi.device)
    bit_is_set = (basis_index & target_bit) > 0

    # Z|psi>: flip the sign of every amplitude whose target bit is 1
    z_psi = psi.clone()
    z_psi[bit_is_set] = -psi[bit_is_set]

    half_angle = theta/2
    return -1j*torch.sin(half_angle)*z_psi + torch.cos(half_angle) * psi
    

def _rz_on_densitymatrix(rho:torch.Tensor, theta:torch.Tensor, total_qubits:int, 
                         target_qubit:int)->torch.Tensor:
    '''
    Returns Rz(theta) rho Rz(theta)^dagger for a single target qubit.

    With Rz = cos(theta/2) I - i sin(theta/2) Z the conjugation expands to

        cos^2(theta/2) rho + sin^2(theta/2) Z rho Z
            + (i sin(theta) / 2) (rho Z - Z rho)

    and each of the three terms is built element-wise from `rho`.

    Arguments
    ---------
    rho: torch.Tensor
        The density matrix to rotate. It is indexed with a boolean mask of
        shape `(2**total_qubits, 2**total_qubits)`; it is not modified.
    theta: torch.Tensor
        The rotation angle. It is passed to `torch.sin` / `torch.cos`, so it
        has to be a tensor here (the public `rz` wrapper converts floats).
    total_qubits: int
        The number of qubits in the quantum state
    target_qubit: int
        The index of the target qubit
    
    Returns
    -------
    torch.Tensor:
        A new tensor holding the rotated density matrix
    '''
    target_bit = 1 << target_qubit
    basis_index = torch.arange(2**total_qubits, device=rho.device)
    bit_is_set = (basis_index & target_bit) > 0
    # True where the row and column index disagree on the target bit
    bits_differ = bit_is_set[:, None] ^ bit_is_set[None, :]

    # Z rho Z: entries whose row/column disagree on the target bit change sign
    z_rho_z = rho.clone()
    z_rho_z[bits_differ] = -rho[bits_differ]

    # rho Z - Z rho: zero where the bits agree, +/- 2 rho where they differ,
    # negative in the columns whose target bit is set
    commutator = torch.zeros_like(rho)
    commutator[bits_differ] = 2*rho[bits_differ]
    commutator[:, bit_is_set] = -commutator[:, bit_is_set]

    cos_sq = torch.cos(theta/2)**2
    sin_sq = torch.sin(theta/2)**2
    cross_coeff = 1j*torch.sin(theta)/2

    return cos_sq * rho + sin_sq * z_rho_z + cross_coeff * commutator

def rz(qs:torch.Tensor, theta:float|torch.Tensor, total_qubits:int, 
       target_qubit:int)->torch.Tensor:
    '''
    Applies the Rz gate to the passed quantum state

    Dispatches on the rank of `qs`: a 1-dimensional tensor is treated as a
    statevector and a 2-dimensional tensor as a density matrix.

    Arguments
    ---------
    qs: torch.Tensor
        The quantum state to apply the gate to
    theta: torch.Tensor|float
        The rotation angle for the gate. A non-tensor value is converted to a
        0-dim tensor of dtype `RTYPE` on `qs.device`, whether or not
        validation is enabled.
    total_qubits: int
        The number of qubits in the quantum state
    target_qubit: int
        The index of the target qubit
    
    Returns
    -------
    torch.Tensor:
        The quantum state after applying the gate
    
    Raises
    ------
    NotImplementedError
        - If `qs` is not a 1 or 2-dimensional tensor
    AssertionError
        - If validation is enabled and `theta` is not a 0-dim tensor
    '''
    validate = not QTORCH_CONFIG['skipValidation']
    if validate:
        validateInput(qs, total_qubits, target_qubit, num_targets=1)

    # Converting theta is normalisation, not validation: the helpers call
    # torch.sin / torch.cos on it, which reject plain Python floats. Doing
    # this only when validation was enabled made float angles crash as soon
    # as `skipValidation` was switched on.
    if not isinstance(theta, torch.Tensor):
        theta = torch.tensor(theta,dtype=RTYPE,device=qs.device)

    if validate:
        assert theta.dim() == 0, '`theta` must be a scalar (0-dim tensor)'
    
    if qs.dim() == 1:
        return _rz_on_statevector(qs, theta, total_qubits, target_qubit)
    elif qs.dim() == 2:
        return _rz_on_densitymatrix(qs, theta, total_qubits,target_qubit)
    else:
        raise NotImplementedError(
            f'rz only supports 1 or 2-dimensional states, got {qs.dim()} dimensions'
        )

@cuda.jit(device=True)
def rz_statevector_kernel(out:torch.Tensor, psi:torch.Tensor, T:int, 
                          theta:float, i:int)->None:
    '''
    Device function that writes the i-th amplitude of Rz(theta)|psi> to `out`.

    Elements `2*i` and `2*i+1` of `psi` / `out` are combined the way the real
    and imaginary parts of a complex amplitude are when it is multiplied by
    `cos(theta/2) - 1j * s * sin(theta/2)`, with `s = +1` if `i & T` is zero
    and `s = -1` otherwise. Presumably both arrays are flat real views of
    complex vectors (real part first) -- confirm against the launching
    kernel.

    Arguments
    ---------
    out: torch.Tensor
        where the output of the operation is stored
    psi: torch.Tensor
        the statevector to apply the operation on
    T: int
        `1 << target_qubit`
    theta: float
        The angle passed to the gate
    i: int
        The index of the amplitude to be calculated
    '''
    half_angle = theta/2
    c = math.cos(half_angle)
    s = math.sin(half_angle)

    # Eigenvalue of Z for basis state i: -1 when the target bit is set
    if (i & T) != 0:
        sign = -1
    else:
        sign = 1

    re = 2*i
    im = re + 1
    # psi[re] is deliberately read again after out[re] has been written, so
    # the sequence of reads and writes is the same as before the rewrite.
    out[re] = c*psi[re] + s*psi[im] * sign
    out[im] = c*psi[im] - s*psi[re] * sign

@cuda.jit(device=True)
def rz_densitymatrix_kernel(out:torch.Tensor, rho:torch.Tensor, T:int, 
                            theta:float, i:int, j:int)->None:
    '''
    Device function that writes the (i,j)-th entry of Rz rho Rz^dagger to
    `out`.

    Columns `2*j` and `2*j+1` of row `i` are combined the way the real and
    imaginary parts of a complex entry are when it is multiplied by

        cos^2(theta/2) + s_i*s_j*sin^2(theta/2) + 1j*(sin(theta)/2)*(s_j - s_i)

    where `s_k` is +1 if `k & T` is zero and -1 otherwise. Presumably `rho`
    and `out` are real views of complex matrices with the real/imaginary
    parts interleaved along the last axis -- confirm against the launching
    kernel.

    Arguments
    ---------
    out: torch.Tensor
        where the output of the operation is stored
    rho: torch.Tensor
        the density matrix to apply the operation on
    T: int
        `1 << target_qubit`
    theta: float
        The angle passed to the gate
    i, j: int, int
        the indices of the output density matrix to be calculated
    '''
    cos_sq = math.cos(theta/2)**2
    sin_sq = math.sin(theta/2)**2
    half_sin = math.sin(theta) / 2

    # s_i * s_j: -1 when row and column disagree on the target bit
    parity = -1 if ((i ^ j) & (T)) != 0 else 1
    col_sign = -1 if (j & (T)) != 0 else 1
    row_sign = -1 if (i & (T)) != 0 else 1
    # s_j - s_i: 0 when the bits agree, +/-2 when they differ
    sign_gap = col_sign - row_sign

    re = 2*j
    im = re + 1
    # rho[i, re] is deliberately read again after out[i, re] has been
    # written, so the sequence of reads and writes is unchanged.
    out[i, re] = (
        cos_sq * rho[i, re]
        + sin_sq * rho[i, re] * parity
        - half_sin * rho[i, im] * sign_gap
    )
    out[i, im] = (
        cos_sq * rho[i, im]
        + sin_sq * rho[i, im] * parity
        + half_sin * rho[i, re] * sign_gap
    )

@cuda.jit(device=True)
def rz_density_matrix_theta_grad_kernel(out:torch.Tensor, rho:torch.Tensor, 
                                        grad_output:torch.Tensor, T:int, 
                                        theta:float, i:int, j:int)->None:
    '''
    Device function to calculate the contribution of the (i,j)-th term of the 
    density matrix (rho) towards the theta gradient of Rz.

    The two bracketed factors are the theta-derivatives of the expressions
    written to columns `2*j` and `2*j+1` by `rz_densitymatrix_kernel`
    (d/dtheta cos^2(theta/2) = -sin(theta)/2, d/dtheta sin^2(theta/2) =
    sin(theta)/2, d/dtheta sin(theta)/2 = cos(theta)/2). Each is weighted by
    the matching entry of `grad_output` and their sum is stored in
    `out[i,j]`.
    
    Arguments
    ---------
    out: torch.Tensor
        where the output of the operation is stored; indexed as `[i,j]`
    rho: torch.Tensor
        the input density matrix; indexed as `[i,2*j]` and `[i,2*j+1]`
    grad_output: torch.Tensor
        the gradient passed back with respect to the output of the Rz
        operation; indexed like `rho`
    T: int
        `1 << target_qubit`
    theta: float
        The angle passed to the gate
    i, j: int, int
        the index of the output contribution term to be calculated
    '''
    half_sin = math.sin(theta)/2
    half_cos = math.cos(theta)/2

    # -1 when row and column disagree on the target bit, +1 otherwise
    parity = -1 if ((i ^ j) & (T)) != 0 else 1
    col_sign = -1 if (j & (T)) != 0 else 1
    row_sign = -1 if (i & (T)) != 0 else 1
    sign_gap = col_sign - row_sign

    re = 2*j
    im = re + 1

    # derivative of the value the forward kernel stores at [i, 2*j]
    d_re = (
        - half_sin * rho[i,re]
        + half_sin * rho[i,re] * parity
        - half_cos * rho[i,im] * sign_gap
    )
    # derivative of the value the forward kernel stores at [i, 2*j+1]
    d_im = (
        - half_sin * rho[i,im]
        + half_sin * rho[i,im] * parity
        + half_cos * rho[i,re] * sign_gap
    )

    out[i,j] = grad_output[i,re] * d_re + grad_output[i,im] * d_im
