import torch
from torch.autograd import Function
import numba.cuda as cuda
import math

from qtorch.unitaries import rx,ry,rz,cnot
from qtorch.unitaries.kernels import (
    rx_statevector_kernel,
    ry_statevector_kernel,
    rz_statevector_kernel,
    cnot_statevector_kernel
)

class GateEnsemblePsi(Function):
    """Apply an ensemble of single-target gates to a batch of state vectors.

    For every state vector in the batch the forward pass produces
    ``total_qubits + 3`` result vectors, stacked along dimension 1:

    * slot 0 -- an unmodified copy of the input state,
    * slots 1, 2, 3 -- ``rx``, ``ry`` and ``rz`` applied with angle ``theta``
      on ``target_qubit``,
    * slots 4 .. ``total_qubits + 2`` -- ``cnot`` with ``target_qubit`` and
      each of the remaining qubit indices, in increasing order, passed as the
      control qubit.

    CUDA tensors are processed by the Numba kernels defined on this class;
    all other tensors go through the pure PyTorch reference implementations
    (``forward_torch``, ``psi_grad_torch``, ``theta_grad_torch``).
    """

    @staticmethod
    def _real_view_dtype(complex_dtype:torch.dtype)->torch.dtype:
        """Return the real dtype used to view a complex tensor in the kernels.

        Raises:
            TypeError: if ``complex_dtype`` is neither complex64 nor
                complex128.
        """
        if complex_dtype == torch.complex64:
            return torch.float32
        if complex_dtype == torch.complex128:
            return torch.float64
        raise TypeError('Expected `psi` to be of type complex64 or '
                        'complex128')

    @staticmethod
    def forward(psi:torch.Tensor, theta:torch.Tensor, total_qubits:int, 
                target_qubit:int,thread_block_size:int)->torch.Tensor:
        """Compute the gate ensemble for a batch of state vectors.

        Args:
            psi: tensor of shape ``[B, 2**total_qubits]``. On the CUDA path
                it must be complex64 or complex128.
            theta: 0-dimensional tensor holding the rotation angle.
            total_qubits: number of qubits described by each state vector.
            target_qubit: index of the qubit the gates act on.
            thread_block_size: threads per CUDA block (x dimension); only
                used when ``psi`` lives on a CUDA device.

        Returns:
            Tensor of shape ``[B, total_qubits + 3, 2**total_qubits]`` with
            the dtype and device of ``psi``.

        Raises:
            TypeError: on the CUDA path, if ``psi`` is not complex64 or
                complex128.
        """
        assert psi.dim() == 2, '`psi` must be a batched 1-d tensor'
        assert psi.shape[1] == 2**total_qubits, (
            f'Expected psi shape to be [B, {2**total_qubits}], recieved '
            f'{psi.shape}')
        assert theta.dim() == 0, (
            'Expected `theta` to be a 0-dimensional (scalar) tensor')
        assert 0 <= target_qubit < total_qubits, (
            f'Expected 0 <= target_qubit < {total_qubits}, recieved '
            f'{target_qubit}')

        # Dispatch on where the data actually lives: a CPU tensor cannot be
        # handed to the Numba kernels even when a GPU is installed.
        if not psi.is_cuda:
            return GateEnsemblePsi.forward_torch(psi, theta, total_qubits,
                                                 target_qubit)

        assert thread_block_size > 0, (
            '`thread_block_size` must be a positive integer')
        FTYPE = GateEnsemblePsi._real_view_dtype(psi.dtype)
        batch_size = psi.size(0)
        num_gates = total_qubits+3
        out = torch.empty([batch_size, num_gates, 2**total_qubits],
                          dtype=psi.dtype, device=psi.device)
        threads_per_block = (thread_block_size,)
        # Grid axes: x tiles the amplitudes, y selects the gate slot and
        # z selects the batch element.
        blocks_per_grid = (math.ceil(psi.shape[1]/thread_block_size),
                           num_gates,
                           batch_size)
        # The kernels operate on real views of the complex buffers.
        # `view(dtype)` needs a unit stride in the last dimension, hence the
        # `contiguous()` on the caller-provided tensor.
        GateEnsemblePsi.forward_cuda[blocks_per_grid,threads_per_block](
            out.detach().view(FTYPE),
            psi.detach().contiguous().view(FTYPE),
            theta.item(), total_qubits, target_qubit, batch_size)
        return out

    @staticmethod
    def setup_context(ctx, inputs, output):
        """Store the inputs that ``backward`` needs on ``ctx``."""
        psi, theta,total_qubits, target_qubit, thread_block_size = inputs
        ctx.total_qubits = total_qubits
        ctx.target_qubit = target_qubit
        ctx.thread_block_size = thread_block_size
        ctx.save_for_backward(psi, theta)

    @staticmethod
    def backward(ctx, grad_output):
        """Compute the gradients with respect to ``psi`` and ``theta``.

        Args:
            ctx: context populated by ``setup_context``.
            grad_output: gradient with respect to the forward output, shape
                ``[B, total_qubits + 3, 2**total_qubits]``.

        Returns:
            A 5-tuple matching the inputs of ``forward``: the gradient for
            ``psi`` (shape ``[B, 2**total_qubits]``), the gradient
            contribution for ``theta`` (one value per batch element) and
            ``None`` for the three non-tensor arguments.
        """
        psi, theta = ctx.saved_tensors

        # Same dispatch rule as in `forward`: follow the device of the data.
        if not psi.is_cuda:
            psi_grad = GateEnsemblePsi.psi_grad_torch(grad_output, theta,
                                                      ctx.total_qubits,
                                                      ctx.target_qubit)
            theta_grad = GateEnsemblePsi.theta_grad_torch(grad_output, psi,
                                                          theta,
                                                          ctx.total_qubits,
                                                          ctx.target_qubit)
            return psi_grad, theta_grad, None, None, None

        batch_size = psi.size(0)
        N = 2**ctx.total_qubits
        num_gates = ctx.total_qubits+3
        FTYPE = GateEnsemblePsi._real_view_dtype(psi.dtype)
        # Read the angle once; every `.item()` call is a host round trip.
        theta_value = theta.item()

        psi_grad = torch.empty([batch_size, num_gates, N],
                               dtype=psi.dtype, device=psi.device)
        # Scratch buffer for the rotated states. It is filled by a kernel
        # that reads `psi`, so it has to live on the device of `psi` (theta
        # may be kept on another device since only its value is used).
        theta_grad = torch.empty([batch_size, 3, N],
                                 dtype=psi.dtype, device=psi.device)

        threads_per_block = (ctx.thread_block_size,)
        blocks_x = math.ceil(N/ctx.thread_block_size)

        # `grad_output` is produced by autograd and may be non-contiguous
        # (e.g. an expanded tensor); make it viewable as a real buffer.
        GateEnsemblePsi.psi_grad_cuda[(blocks_x, num_gates, batch_size),
                                      threads_per_block](
            psi_grad.detach().view(FTYPE),
            grad_output.detach().contiguous().view(FTYPE),
            theta_value, ctx.total_qubits, ctx.target_qubit, batch_size)

        # Only the three rotation slots depend on theta.
        GateEnsemblePsi.theta_grad_cuda[(blocks_x, 3, batch_size),
                                        threads_per_block](
            theta_grad.detach().view(FTYPE),
            psi.detach().contiguous().view(FTYPE),
            theta_value,
            ctx.total_qubits,
            ctx.target_qubit,
            batch_size)

        # Every gate slot received the same input state, so its gradient is
        # the sum of the per-slot contributions.
        psi_grad = psi_grad.sum(dim=1)
        # Same reduction as `theta_grad_torch`: one real value per batch
        # element.
        # NOTE(review): this relies on autograd reducing the per-sample
        # vector to the 0-d shape of `theta` -- confirm.
        theta_grad = (((grad_output[:,1:4].conj()*theta_grad).real)/2
                      ).sum(dim=(1,2))

        return psi_grad, theta_grad, None, None, None

    @staticmethod
    def forward_torch(psi, theta, total_qubits, target_qubit):
        """Reference forward pass built from the ``qtorch.unitaries`` gates.

        Returns a tensor of shape ``[B, total_qubits + 3, 2**total_qubits]``
        laid out as described in the class docstring.
        """
        batch_size = psi.size(0)
        out = torch.empty([batch_size, total_qubits+3,2**total_qubits], 
                          dtype=psi.dtype, device=psi.device)
        for b in range(batch_size):
            out[b, 0] = psi[b]
            out[b, 1] = rx(psi[b], theta, total_qubits, target_qubit)
            out[b, 2] = ry(psi[b], theta, total_qubits, target_qubit)
            out[b, 3] = rz(psi[b], theta, total_qubits, target_qubit)
            for i in range(4,total_qubits+3):
                # Enumerate every qubit index except `target_qubit`.
                control_qubit = i-4 if i-4 < target_qubit else i-3
                out[b, i] = cnot(psi[b], total_qubits, target_qubit, 
                                 control_qubit)
        return out

    @staticmethod
    @cuda.jit
    def forward_cuda(out:torch.Tensor, psi:torch.Tensor, theta:float, 
                     total_qubits:int, target_qubit:int,batch_size:int)->None:
        """CUDA kernel for the forward pass.

        ``out`` and ``psi`` are the real views of the complex buffers, so
        amplitude ``i`` occupies entries ``2*i`` and ``2*i+1`` of the last
        dimension. Each thread handles one amplitude index of one gate slot
        (``blockIdx.y``) of one batch element (``blockIdx.z``).
        """
        tx = cuda.threadIdx.x
        bx = cuda.blockIdx.x
        G = cuda.blockIdx.y
        B = cuda.blockIdx.z
        bsize = cuda.blockDim.x

        i = bx*bsize + tx

        N = 1 << total_qubits
        T = 1 << target_qubit

        # Threads of the last block may fall past the end of the vector.
        if i >= N or G >= total_qubits+3 or B >= batch_size:
            return

        if G == 0:
            # Identity slot: copy both float components of amplitude i.
            out[B, G, 2*i] = psi[B, 2*i]
            out[B, G, 2*i+1] = psi[B, 2*i+1]
        elif G == 1:
            rx_statevector_kernel(out[B, G], psi[B], T, theta, i)
        elif G == 2:
            ry_statevector_kernel(out[B, G], psi[B], T, theta, i)
        elif G == 3:
            rz_statevector_kernel(out[B, G], psi[B], T, theta, i)
        else:
            # Same control enumeration as in `forward_torch`.
            if G-4 < target_qubit:
                control_qubit = G-4
            else:
                control_qubit = G-3
            C = 1 << control_qubit
            cnot_statevector_kernel(out[B, G], psi[B], T, C, i)

    @staticmethod
    def psi_grad_torch(grad_output, theta, total_qubits, target_qubit):
        """Reference gradient with respect to ``psi``.

        Each gate slot of ``grad_output`` is passed through the matching
        gate -- the rotations with the negated angle, ``cnot`` unchanged --
        and the results are summed over the gate dimension.

        Returns a tensor of shape ``[B, 2**total_qubits]``.
        """
        batch_size = grad_output.size(0)
        out = torch.empty([batch_size, total_qubits+3, 2**total_qubits],
                          dtype=grad_output.dtype, device=grad_output.device)
        for b in range(batch_size):
            out[b, 0] = grad_output[b, 0]
            out[b, 1] = rx(grad_output[b, 1], -theta, total_qubits,
                           target_qubit)
            out[b, 2] = ry(grad_output[b, 2], -theta, total_qubits,
                           target_qubit)
            out[b, 3] = rz(grad_output[b, 3], -theta, total_qubits,
                           target_qubit)
            for i in range(4,total_qubits+3):
                control_qubit = i-4 if i-4 < target_qubit else i-3
                out[b, i] = cnot(grad_output[b, i], total_qubits, target_qubit, 
                            control_qubit)

        return out.sum(dim=1)

    @staticmethod
    @cuda.jit
    def psi_grad_cuda(out:torch.Tensor, grad_output:torch.Tensor, 
                      theta:float, total_qubits:int, target_qubit:int,
                      batch_size:int)->None:
        """CUDA kernel producing the per-gate terms of the ``psi`` gradient.

        Mirrors ``psi_grad_torch`` without the final sum: slot ``G`` of
        ``out`` receives slot ``G`` of ``grad_output`` passed through gate
        ``G`` (rotations use ``-theta``). Both arguments are real views of
        complex buffers; the caller sums ``out`` over the gate dimension.
        """
        tx = cuda.threadIdx.x
        bx = cuda.blockIdx.x
        G = cuda.blockIdx.y
        B = cuda.blockIdx.z
        bsize = cuda.blockDim.x

        i = bx*bsize + tx

        N = 1 << total_qubits
        T = 1 << target_qubit

        if i >= N or G >= (total_qubits+3) or B >= batch_size:
            return

        if G == 0:
            out[B, G, 2*i] = grad_output[B, G, 2*i]
            out[B, G, 2*i+1] = grad_output[B, G, 2*i+1]
        elif G == 1:
            rx_statevector_kernel(out[B, G], grad_output[B, G], T, -theta, i)
        elif G == 2:
            ry_statevector_kernel(out[B, G], grad_output[B, G], T, -theta, i)
        elif G == 3:
            rz_statevector_kernel(out[B, G], grad_output[B, G], T, -theta, i)
        else:
            if G-4 < target_qubit:
                control_qubit = G-4
            else:
                control_qubit = G-3
            C = 1 << control_qubit
            cnot_statevector_kernel(out[B, G], grad_output[B, G], T, C, i)

    @staticmethod
    def theta_grad_torch(grad_output:torch.Tensor, psi:torch.Tensor, 
                         theta:torch.Tensor, total_qubits:int, target_qubit:int
                         )->torch.Tensor:
        """Reference gradient with respect to ``theta``.

        The three rotations are evaluated at ``theta + pi`` and halved
        (presumably the identity dR(theta)/dtheta = R(theta + pi) / 2 for
        these gates -- confirm against ``qtorch.unitaries``), multiplied by
        the conjugate of the matching slots of ``grad_output`` and the real
        part is summed over the gate and amplitude dimensions.

        Returns a real tensor with one entry per batch element.
        """
        batch_size = psi.size(0)
        out = torch.empty([batch_size, 3,2**total_qubits], dtype=psi.dtype,
                          device=psi.device)
        for b in range(batch_size):
            out[b, 0] = rx(psi[b], theta+torch.pi, total_qubits,
                           target_qubit)
            out[b, 1] = ry(psi[b], theta+torch.pi, total_qubits,
                           target_qubit)
            out[b, 2] = rz(psi[b], theta+torch.pi, total_qubits,
                           target_qubit)

        return (grad_output[:,1:4].conj()*out/2).real.sum(dim=(1,2))

    @staticmethod
    @cuda.jit
    def theta_grad_cuda(out:torch.Tensor, psi:torch.Tensor, theta:float,
                        total_qubits:int, target_qubit:int, batch_size:int
                        )->None:
        """CUDA kernel evaluating the three rotations at ``theta + pi``.

        ``blockIdx.y`` selects rx (0), ry (1) or rz (2). ``out`` and ``psi``
        are real views of complex buffers; the contraction with
        ``grad_output`` is done by the caller.
        """
        tx = cuda.threadIdx.x
        bx = cuda.blockIdx.x
        G = cuda.blockIdx.y
        B = cuda.blockIdx.z
        bsize = cuda.blockDim.x

        i = bx*bsize + tx

        N = 1 << total_qubits
        T = 1 << target_qubit

        if i >= N or G >= 3 or B >= batch_size:
            return

        # `math.pi` keeps the device code independent of the torch module.
        shifted_theta = theta + math.pi

        if G == 0:
            rx_statevector_kernel(out[B, G], psi[B], T, shifted_theta, i)
        elif G == 1:
            ry_statevector_kernel(out[B, G], psi[B], T, shifted_theta, i)
        else: # G == 2
            rz_statevector_kernel(out[B, G], psi[B], T, shifted_theta, i)
