import torch
from numba import cuda
from qtorch.config import QTORCH_CONFIG
from qtorch import RTYPE

import math

from .utils import validateInput

def _ry_on_statevector(psi:torch.Tensor, theta:torch.Tensor, total_qubits:int, 
                       target_qubit:int)->torch.Tensor:
    '''
    Applies the Ry gate to |psi>

    Arguments
    ---------
    psi: torch.Tensor
        The statevector to apply the gate to
    theta: torch.Tensor|float
        The rotation angle for the gate
    total_qubits: int
        The number of qubits in the quantum state
    target_qubit: int
        The index of the target qubit
    
    Returns
    -------
    torch.Tensor:
        The statevector after applying the gate to |psi>
    '''
    # Ry = cos(theta/2) I - i sin(theta/2) Y
    #    = cos(theta/2) I - sin(theta/2) ZX 
    phi = torch.empty_like(psi)
    I = torch.arange(2**total_qubits,device=psi.device)
    # X
    phi[I ^ (1 << target_qubit)] = psi
    # Z
    phi[I & (1 << target_qubit) > 0] *= -1

    return torch.cos(theta/2)*psi - torch.sin(theta/2)*phi
    

def _ry_on_densitymatrix(rho:torch.Tensor, theta:torch.Tensor, total_qubits:int, 
                         target_qubit:int)->torch.Tensor:
    '''
    Applies the Ry gate to the density matrix rho

    Arguments
    ---------
    rho: torch.Tensor
        The density matrix to apply the gate to
    theta: torch.Tensor|float
        The rotation angle for the gate
    total_qubits: int
        The number of qubits in the quantum state
    target_qubit: int
        The index of the target qubit
    
    Returns
    -------
    torch.Tensor:
        The density matrix after applying the gate to rho
    '''
    # Ry = cos(theta/2) I - sin(theta/2) ZX

    # Ry rho Rx = cos^2(theta/2) rho + sin^2(theta/2) ZX rho XZ
    #           + (i sin(theta) / 2) (rho XZ - ZX rho)
    I = torch.arange(2**total_qubits,device=rho.device)
    swap_indices = I ^ (1 << target_qubit)
    phase_mask = ((I[:,None] ^ I[None,:]) & (1<<target_qubit)) > 0

    # sin^2 term
    # X rho X
    rho1 = rho[swap_indices[:,None], swap_indices[None,:]]
    # ZX rho XZ
    rho1[phase_mask] *= -1

    # rho X
    rho2r = rho[I[:,None], swap_indices[None,:]]
    # rho XZ 
    rho2r[:, I&(1<<target_qubit) > 0] *= -1

    # X rho
    rho2l = rho[swap_indices[:,None], I[None,:]]
    # ZX rho
    rho2l[I&(1<<target_qubit) > 0, :] *= -1
    
    return (
        torch.cos(theta/2)**2 * rho
        + torch.sin(theta/2)**2 * rho1
        - (torch.sin(theta)/2) * (rho2r+rho2l)
    )

def ry(qs:torch.Tensor, theta:float|torch.Tensor, total_qubits:int, 
       target_qubit:int)->torch.Tensor:
    '''
    Applies the Ry gate to the passed quantum state

    Arguments
    ---------
    qs: torch.Tensor
        The quantum state to apply the gate to: a 1-dimensional statevector
        or a 2-dimensional density matrix
    theta: torch.Tensor|float
        The rotation angle for the gate. A Python number is converted to a
        tensor of dtype `RTYPE` on `qs`'s device.
    total_qubits: int
        The number of qubits in the quantum state
    target_qubit: int
        The index of the target qubit
    
    Returns
    -------
    torch.Tensor:
        The quantum state after applying the gate
    
    Raises
    ------
    NotImplementedError
        - If `qs` is not a 1 or 2-dimensional tensor
    AssertionError
        - If validation is enabled and `theta` is not a 0-dim tensor
    '''
    validate = not QTORCH_CONFIG['skipValidation']
    if validate:
        validateInput(qs, total_qubits, target_qubit, num_targets=1)

    # The conversion must happen even when validation is skipped: the
    # statevector/density-matrix helpers call torch.cos/torch.sin on theta,
    # which reject Python numbers.
    if not isinstance(theta, torch.Tensor):
        theta = torch.tensor(theta,dtype=RTYPE,device=qs.device)

    if validate:
        assert theta.dim() == 0, '`theta` must be a scalar (0-dim tensor)'
    
    if qs.dim() == 1:
        return _ry_on_statevector(qs, theta, total_qubits, target_qubit)
    elif qs.dim() == 2:
        return _ry_on_densitymatrix(qs, theta, total_qubits,target_qubit)
    else:
        raise NotImplementedError(
            f'`qs` must be a 1-dim statevector or a 2-dim density matrix, '
            f'got a {qs.dim()}-dim tensor'
        )

@cuda.jit(device=True)
def ry_statevector_kernel(out:torch.Tensor, psi:torch.Tensor, T:int, 
                          theta:float, i:int)->None:
    '''
    Device function writing the i-th amplitude of Ry|psi> into `out`.

    Uses Ry = cos(theta/2) I - sin(theta/2) ZX, so

        out_i = cos(theta/2) psi_i - sin(theta/2) z_i psi_(i ^ T)

    where z_i = -1 when `i & T` is non-zero and +1 otherwise. Amplitude k
    occupies the two consecutive slots 2*k and 2*k + 1 of `psi` and `out`
    (presumably the real and imaginary parts of a complex statevector
    viewed as real -- confirm against the caller). Both slots get the same
    real coefficients.

    Arguments
    ---------
    out: torch.Tensor
        destination buffer for the result
    psi: torch.Tensor
        the statevector the gate acts on
    T: int
        `1 << target_qubit`
    theta: float
        The angle passed to the gate
    i: int
        The index of the amplitude to be calculated
    '''
    half_angle = theta / 2
    c = math.cos(half_angle)
    s = math.sin(half_angle)
    z = -1 if (i & T) != 0 else 1

    here = 2 * i
    partner = 2 * (i ^ T)
    out[here]     = c * psi[here]     - s * psi[partner]     * z
    out[here + 1] = c * psi[here + 1] - s * psi[partner + 1] * z

@cuda.jit(device=True)
def ry_densitymatrix_kernel(out:torch.Tensor, rho:torch.Tensor, T:int, 
                            theta:float, i:int, j:int)->None:
    '''
    Device function writing entry (i, j) of Ry rho Ry^dagger into `out`.

    With c = cos(theta/2), s = sin(theta/2) and z_k = -1 when `k & T` is
    non-zero (+1 otherwise):

        out_ij = c^2 rho_ij + s^2 z_i z_j rho_(i^T)(j^T)
                 - (sin(theta)/2) (z_j rho_i(j^T) + z_i rho_(i^T)j)

    Column j is stored in the two consecutive slots 2*j and 2*j + 1 of the
    second axis (presumably real and imaginary parts -- confirm against the
    caller). Both slots are computed with the same real coefficients.

    Arguments
    ---------
    out: torch.Tensor
        destination buffer for the result
    rho: torch.Tensor
        the density matrix the gate acts on
    T: int
        `1 << target_qubit`
    theta: float
        The angle passed to the gate
    i, j: int, int
        row and column of the output entry to compute
    '''
    half_angle = theta / 2
    cos2 = math.cos(half_angle)**2
    sin2 = math.sin(half_angle)**2
    half_sin = math.sin(theta) / 2

    z_both = -1 if ((i ^ j) & T) != 0 else 1
    z_col = -1 if (j & T) != 0 else 1
    z_row = -1 if (i & T) != 0 else 1

    ip = i ^ T
    jp = j ^ T
    for part in range(2):
        col = 2 * j + part
        pcol = 2 * jp + part
        out[i, col] = (
            cos2 * rho[i, col]
            + sin2 * rho[ip, pcol] * z_both
            - half_sin * (rho[i, pcol] * z_col + rho[ip, col] * z_row)
        )

@cuda.jit(device=True)
def ry_density_matrix_theta_grad_kernel(out:torch.Tensor, rho:torch.Tensor, 
                                        grad_output:torch.Tensor, T:int, 
                                        theta:float, i:int, j:int)->None:
    '''
    Device function writing the (i, j) contribution to d(loss)/d(theta) for
    the Ry gate applied to a density matrix.

    Differentiating Ry rho Ry^dagger with respect to theta gives, with
    z_k = -1 when `k & T` is non-zero (+1 otherwise):

        d_ij = -(sin(theta)/2) rho_ij
               + (sin(theta)/2) z_i z_j rho_(i^T)(j^T)
               - (cos(theta)/2) (z_j rho_i(j^T) + z_i rho_(i^T)j)

    The value stored is the sum, over the two slots 2*j and 2*j + 1 of the
    second axis, of grad_output times the matching slot of d_ij.

    Arguments
    ---------
    out: torch.Tensor
        where the (i, j) contribution is stored
    rho: torch.Tensor
        the input density matrix
    grad_output: torch.Tensor
        the gradient passed back with respect to the output of the Ry
        operation
    T: int
        `1 << target_qubit`
    theta: float
        The angle passed to the gate
    i, j: int, int
        the index of the output contribution term to be calculated
    '''
    half_sin = math.sin(theta)/2
    half_cos = math.cos(theta)/2

    z_both = -1 if ((i ^ j) & T) != 0 else 1
    z_col = -1 if (j & T) != 0 else 1
    z_row = -1 if (i & T) != 0 else 1

    ip = i ^ T
    jp = j ^ T
    col_re, col_im = 2 * j, 2 * j + 1
    pcol_re, pcol_im = 2 * jp, 2 * jp + 1

    d_re = (
        -half_sin * rho[i, col_re]
        + half_sin * rho[ip, pcol_re] * z_both
        - half_cos * (rho[i, pcol_re] * z_col + rho[ip, col_re] * z_row)
    )
    d_im = (
        -half_sin * rho[i, col_im]
        + half_sin * rho[ip, pcol_im] * z_both
        - half_cos * (rho[i, pcol_im] * z_col + rho[ip, col_im] * z_row)
    )

    out[i, j] = grad_output[i, col_re] * d_re + grad_output[i, col_im] * d_im
