import torch
from torch.autograd import Function
import numba.cuda as cuda
import math

from qtorch.unitaries import rx,ry,rz,cnot
from qtorch.unitaries.kernels import (
    rx_statevector_kernel,
    ry_statevector_kernel,
    rz_statevector_kernel,
    cnot_statevector_kernel
)

class MicroGateEnsemblePsi(Function):
    """Apply an ensemble of elementary gates to a single statevector.

    The forward pass maps a statevector ``psi`` to a stack with
    ``qubit_support.shape[0] + 3`` rows:

    * row 0: ``psi`` unchanged;
    * rows 1-3: ``rx``, ``ry`` and ``rz`` called with angle ``theta`` on the
      target qubit ``qubit_support[target_support_id]``;
    * rows 4 and up: ``cnot`` called with the target qubit and, as control,
      each of the remaining entries of ``qubit_support`` in order.

    CPU tensors are handled with the ``qtorch.unitaries`` torch functions;
    CUDA tensors are handled with the Numba kernels defined on this class.
    The implementation is selected from the device of ``psi``.

    NOTE(review): the backward formulas (negated angle for the ``psi``
    gradient, ``theta + pi`` and a factor 1/2 for the ``theta`` gradient)
    presumably assume rotations of the form ``exp(-i * theta * sigma / 2)``
    and a self-inverse ``cnot`` -- confirm against ``qtorch.unitaries``.
    """

    @staticmethod
    def _real_dtype(psi: torch.Tensor) -> torch.dtype:
        """Return the real dtype used to reinterpret ``psi`` for the kernels.

        Raises:
            TypeError: if ``psi`` is neither complex64 nor complex128.
        """
        if psi.dtype == torch.complex64:
            return torch.float32
        if psi.dtype == torch.complex128:
            return torch.float64
        raise TypeError('Expected `psi` to be of type complex64 or '
                        'complex128')

    @staticmethod
    def _control_support_id(row: int, target_support_id: int) -> int:
        """Map an output row (>= 4) to the index of its control qubit.

        Rows enumerate the entries of ``qubit_support`` in order while
        skipping the target entry itself.
        """
        slot = row - 4
        return slot if slot < target_support_id else slot + 1

    @staticmethod
    def forward(psi: torch.Tensor, theta: torch.Tensor, total_qubits: int,
                qubit_support: torch.Tensor, target_support_id: int,
                thread_block_size: int) -> torch.Tensor:
        """Compute the stack of gate outputs for ``psi``.

        Args:
            psi: statevector; indexed as a 1-D tensor with one entry per
                basis state.
            theta: rotation angle; read with ``.item()`` on the CUDA path, so
                it must hold exactly one element there.
            total_qubits: number of qubits; the output has
                ``2**total_qubits`` columns.
            qubit_support: 1-D tensor of qubit indices.
            target_support_id: position in ``qubit_support`` of the target
                qubit.
            thread_block_size: threads per block for the CUDA kernels (unused
                on the CPU path).

        Returns:
            Tensor of shape ``[qubit_support.shape[0] + 3, 2**total_qubits]``
            with the dtype and device of ``psi``.

        Raises:
            TypeError: on the CUDA path, if ``psi`` is not complex64 or
                complex128.
        """
        # Dispatch on where the data lives rather than on whether the host
        # has a GPU, so CPU tensors keep working on CUDA-capable machines.
        if not psi.is_cuda:
            return MicroGateEnsemblePsi.forward_torch(
                psi, theta, total_qubits, qubit_support, target_support_id
            )

        ftype = MicroGateEnsemblePsi._real_dtype(psi)
        num_gates = 3 + qubit_support.shape[0]
        out = torch.empty([num_gates, 2**total_qubits], dtype=psi.dtype,
                          device=psi.device)
        threads_per_block = (thread_block_size,)
        # Grid x covers the statevector entries, grid y selects the gate.
        blocks_per_grid = (math.ceil(psi.shape[0] / thread_block_size),
                           num_gates)
        MicroGateEnsemblePsi.forward_cuda[
            blocks_per_grid,
            threads_per_block
        ](
            out.detach().view(ftype),
            # Reinterpreting complex as real requires a unit-stride last dim.
            psi.detach().contiguous().view(ftype),
            theta.item(),
            total_qubits,
            # The kernel indexes this tensor on the device.
            qubit_support.to(psi.device),
            target_support_id,
        )
        return out

    @staticmethod
    def setup_context(ctx, inputs, output):
        """Stash the non-tensor arguments and save tensors for backward."""
        (psi, theta, total_qubits, qubit_support, target_support_id,
         thread_block_size) = inputs
        ctx.total_qubits = total_qubits
        ctx.target_support_id = target_support_id
        ctx.num_support = qubit_support.shape[0]
        ctx.thread_block_size = thread_block_size
        ctx.save_for_backward(psi, theta, qubit_support)

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``psi`` and ``theta`` (``None`` for the rest).

        ``grad_output`` has one row per gate, matching the forward output.
        """
        psi, theta, qubit_support = ctx.saved_tensors

        if not psi.is_cuda:
            psi_grad = MicroGateEnsemblePsi.psi_grad_torch(
                grad_output, theta, ctx.total_qubits, qubit_support,
                ctx.target_support_id
            )
            theta_grad = MicroGateEnsemblePsi.theta_grad_torch(
                grad_output, psi, theta, ctx.total_qubits, qubit_support,
                ctx.target_support_id
            )
        else:
            ftype = MicroGateEnsemblePsi._real_dtype(psi)
            num_gates = ctx.num_support + 3
            angle = theta.item()
            support = qubit_support.to(psi.device)
            # Autograd may hand over a non-contiguous gradient, which cannot
            # be reinterpreted as a real tensor.
            grad_contig = grad_output.detach().contiguous()

            psi_grad = torch.empty([num_gates, 2**ctx.total_qubits],
                                   dtype=psi.dtype, device=psi.device)
            # Scratch buffer written by the kernel and multiplied with
            # ``grad_output`` below, so it lives on the same device as
            # ``psi``.
            theta_grad = torch.empty([3, psi.shape[0]], dtype=psi.dtype,
                                     device=psi.device)
            threads_per_block = (ctx.thread_block_size,)
            blocks_x = math.ceil(psi.shape[0] / ctx.thread_block_size)

            MicroGateEnsemblePsi.psi_grad_cuda[
                (blocks_x, num_gates),
                threads_per_block
            ](
                psi_grad.view(ftype),
                grad_contig.view(ftype),
                angle,
                ctx.total_qubits, support, ctx.target_support_id
            )
            MicroGateEnsemblePsi.theta_grad_cuda[
                (blocks_x, 3),
                threads_per_block
            ](
                theta_grad.view(ftype),
                psi.detach().contiguous().view(ftype),
                angle,
                ctx.total_qubits, support, ctx.target_support_id
            )
            # Same contraction as ``theta_grad_torch``.
            theta_grad = (grad_output[1:4, :].conj() * theta_grad).real / 2

        # ``psi`` feeds every output row, so its gradient is the sum of the
        # per-row contributions; ``theta`` receives a single scalar shaped,
        # typed and placed like ``theta`` itself.
        psi_grad = psi_grad.sum(dim=0)
        theta_grad = theta_grad.sum().reshape(theta.shape).to(
            device=theta.device, dtype=theta.dtype
        )
        return psi_grad, theta_grad, None, None, None, None

    @staticmethod
    def forward_torch(psi: torch.Tensor, theta: torch.Tensor,
                      total_qubits: int, qubit_support: torch.Tensor,
                      target_support_id: int) -> torch.Tensor:
        """Torch implementation of the forward pass (see ``forward``)."""
        num_gates = qubit_support.shape[0] + 3
        target_qubit = qubit_support[target_support_id].item()
        out = torch.empty([num_gates, 2**total_qubits], dtype=psi.dtype,
                          device=psi.device)
        out[0] = psi
        for row, rotation in enumerate((rx, ry, rz), start=1):
            out[row] = rotation(psi, theta, total_qubits, target_qubit)
        for row in range(4, num_gates):
            control_support_id = MicroGateEnsemblePsi._control_support_id(
                row, target_support_id
            )
            control_qubit = qubit_support[control_support_id].item()
            out[row] = cnot(psi, total_qubits, target_qubit, control_qubit)
        return out

    @staticmethod
    @cuda.jit
    def forward_cuda(out:torch.Tensor, psi:torch.Tensor, theta:float,
                     total_qubits:int, qubit_support:torch.Tensor,
                     target_support_id:int)->None:
        """CUDA kernel for the forward pass.

        ``out`` and ``psi`` are the real-dtype views created by ``forward``:
        complex entry ``i`` occupies real entries ``2*i`` and ``2*i+1``.
        Grid x (with the thread index) selects the statevector entry, grid y
        selects the output row.
        """
        tx = cuda.threadIdx.x
        bx = cuda.blockIdx.x
        G = cuda.blockIdx.y
        bsize = cuda.blockDim.x

        i = bx*bsize + tx

        # Threads past the end of the statevector have nothing to do.
        N = 1 << total_qubits
        if i >= N:
            return

        target_qubit = qubit_support[target_support_id]
        T = 1 << target_qubit

        if G == 0:
            # Row 0 is a plain copy of psi.
            out[G, 2*i] = psi[2*i]
            out[G, 2*i+1] = psi[2*i+1]
        elif G == 1:
            rx_statevector_kernel(out[G], psi, T, theta, i)
        elif G == 2:
            ry_statevector_kernel(out[G], psi, T, theta, i)
        elif G == 3:
            rz_statevector_kernel(out[G], psi, T, theta, i)
        else:
            # Remaining rows walk qubit_support, skipping the target entry.
            if G-4 < target_support_id:
                control_support_id = G-4
            else:
                control_support_id = G-3
            control_qubit = qubit_support[control_support_id]
            C = 1 << control_qubit
            cnot_statevector_kernel(out[G], psi, T, C, i)

    @staticmethod
    def psi_grad_torch(grad_output: torch.Tensor, theta: torch.Tensor,
                       total_qubits: int, qubit_support: torch.Tensor,
                       target_support_id: int) -> torch.Tensor:
        """Torch implementation of the per-row gradient w.r.t. ``psi``.

        Row ``k`` of the result is row ``k`` of ``grad_output`` pulled back
        through gate ``k``: unchanged for row 0, the rotation with negated
        angle for rows 1-3, and the same ``cnot`` for the remaining rows.

        Returns:
            Tensor with the same shape as ``grad_output``; the rows still
            have to be summed to obtain the gradient of ``psi``.
        """
        num_gates = qubit_support.shape[0] + 3
        target_qubit = qubit_support[target_support_id].item()
        out = torch.empty([num_gates, 2**total_qubits],
                          dtype=grad_output.dtype,
                          device=grad_output.device)
        out[0] = grad_output[0]
        for row, rotation in enumerate((rx, ry, rz), start=1):
            out[row] = rotation(grad_output[row], -theta, total_qubits,
                                target_qubit)
        for row in range(4, num_gates):
            control_support_id = MicroGateEnsemblePsi._control_support_id(
                row, target_support_id
            )
            control_qubit = qubit_support[control_support_id].item()
            out[row] = cnot(grad_output[row], total_qubits, target_qubit,
                            control_qubit)
        return out

    @staticmethod
    @cuda.jit
    def psi_grad_cuda(out:torch.Tensor, grad_output:torch.Tensor, theta:float,
                      total_qubits:int, qubit_support:torch.Tensor,
                      target_support_id:int)->None:
        """CUDA kernel mirroring ``psi_grad_torch``.

        ``out`` and ``grad_output`` are real-dtype views with one row per
        gate. Grid x (with the thread index) selects the statevector entry,
        grid y selects the row.
        """
        tx = cuda.threadIdx.x
        bx = cuda.blockIdx.x
        G = cuda.blockIdx.y
        bsize = cuda.blockDim.x

        i = bx*bsize + tx

        # Threads past the end of the statevector have nothing to do.
        N = 1 << total_qubits
        if i >= N:
            return

        target_qubit = qubit_support[target_support_id]
        T = 1 << target_qubit

        if G == 0:
            # Row 0 of the forward pass is a copy, so its gradient is too.
            out[G, 2*i] = grad_output[G, 2*i]
            out[G, 2*i+1] = grad_output[G, 2*i+1]
        elif G == 1:
            rx_statevector_kernel(out[G], grad_output[G], T, -theta, i)
        elif G == 2:
            ry_statevector_kernel(out[G], grad_output[G], T, -theta, i)
        elif G == 3:
            rz_statevector_kernel(out[G], grad_output[G], T, -theta, i)
        else:
            # Remaining rows walk qubit_support, skipping the target entry.
            if G-4 < target_support_id:
                control_support_id = G-4
            else:
                control_support_id = G-3
            control_qubit = qubit_support[control_support_id]
            C = 1 << control_qubit
            cnot_statevector_kernel(out[G], grad_output[G], T, C, i)

    @staticmethod
    def theta_grad_torch(grad_output: torch.Tensor, psi: torch.Tensor,
                         theta: torch.Tensor, total_qubits: int,
                         qubit_support: torch.Tensor, target_support_id: int
                         ) -> torch.Tensor:
        """Torch implementation of the element-wise gradient w.r.t. ``theta``.

        Applies ``rx``/``ry``/``rz`` to ``psi`` with angle ``theta + pi`` and
        contracts the result with rows 1-3 of ``grad_output``.

        Returns:
            Real tensor of shape ``[3, 2**total_qubits]``; its sum is the
            gradient of ``theta``.
        """
        shifted = torch.empty([3, 2**total_qubits], dtype=psi.dtype,
                              device=psi.device)
        target_qubit = qubit_support[target_support_id].item()
        angle = theta + torch.pi
        for row, rotation in enumerate((rx, ry, rz)):
            shifted[row] = rotation(psi, angle, total_qubits, target_qubit)

        return (grad_output[1:4, :].conj()*shifted/2).real

    @staticmethod
    @cuda.jit
    def theta_grad_cuda(out:torch.Tensor, psi:torch.Tensor, theta:float,
                        total_qubits:int, qubit_support:torch.Tensor,
                        target_support_id:int)->None:
        """CUDA kernel writing the three rotations of ``psi`` at ``theta+pi``.

        ``out`` (3 rows) and ``psi`` are real-dtype views. Grid x (with the
        thread index) selects the statevector entry, grid y selects the
        rotation (0: rx, 1: ry, 2: rz).
        """
        tx = cuda.threadIdx.x
        bx = cuda.blockIdx.x
        G = cuda.blockIdx.y
        bsize = cuda.blockDim.x

        i = bx*bsize + tx

        # Ignore threads past the statevector and any extra grid rows.
        N = 1 << total_qubits
        if i >= N or G >= 3:
            return
        target_qubit = qubit_support[target_support_id]
        T = 1 << target_qubit

        # ``math.pi`` keeps the device code free of torch attribute lookups.
        shifted = theta + math.pi

        if G == 0:
            rx_statevector_kernel(out[G], psi, T, shifted, i)
        elif G == 1:
            ry_statevector_kernel(out[G], psi, T, shifted, i)
        else: # G == 2
            rz_statevector_kernel(out[G], psi, T, shifted, i)