import torch
import numba.cuda as cuda
import math
from torch.autograd import Function
from itertools import product, count

from qtorch.unitaries import rx,ry,rz,cnot
from qtorch.unitaries.kernels import (
    rx_densitymatrix_kernel, rx_density_matrix_theta_grad_kernel,
    ry_densitymatrix_kernel, ry_density_matrix_theta_grad_kernel,
    rz_densitymatrix_kernel, rz_density_matrix_theta_grad_kernel,
    cnot_density_matrix_kernel
)

class MicroGateEnsembleRho(Function):
    """Apply an ensemble of single "micro" gates to a density matrix.

    The forward pass stacks, along a new leading axis, the following
    transforms of ``rho`` (``t`` is the target qubit
    ``qubit_support[target_support_id]``):

    * index 0: ``rho`` itself,
    * index 1-3: ``rx``, ``ry`` and ``rz`` with angle ``theta`` on ``t``,
    * index 4 and up: one ``cnot`` per *other* entry of ``qubit_support``,
      passed as the control qubit with ``t`` as the target.

    Tensors living on a CUDA device are processed by the Numba kernels
    defined on this class; everything else goes through the pure-torch
    reference implementations (``*_torch``).
    """

    @staticmethod
    def _real_view_dtype(complex_dtype:torch.dtype)->torch.dtype:
        """Return the real dtype used to reinterpret a complex tensor.

        Raises:
            TypeError: If ``complex_dtype`` is neither ``complex64`` nor
                ``complex128``.
        """
        if complex_dtype == torch.complex64:
            return torch.float32
        if complex_dtype == torch.complex128:
            return torch.float64
        raise TypeError('Expected `rho` to be of type complex64 or '
                        'complex128')

    @staticmethod
    def forward(rho:torch.Tensor, theta:torch.Tensor, total_qubits:int,
                qubit_support:torch.Tensor, target_support_id:int, 
                thread_block_size:int)->torch.Tensor:
        """Compute the stacked gate outputs for ``rho``.

        Args:
            rho: Complex matrix, indexed as
                ``[2**total_qubits, 2**total_qubits]``.
            theta: Rotation angle. The CUDA path reads it with
                ``theta.item()``, so it must hold exactly one element.
            total_qubits: Number of qubits; fixes the matrix dimension.
            qubit_support: 1-D integer tensor of qubit indices.
            target_support_id: Position in ``qubit_support`` of the target
                qubit.
            thread_block_size: Edge length of the square CUDA thread block.
                Ignored on the torch path.

        Returns:
            Tensor of shape
            ``[qubit_support.shape[0] + 3, 2**total_qubits, 2**total_qubits]``
            with the dtype and device of ``rho``.

        Raises:
            TypeError: On the CUDA path, if ``rho`` is not complex64 or
                complex128.
        """
        # Dispatch on where the data actually lives: a CPU tensor cannot be
        # handed to a CUDA kernel even if a GPU happens to be present.
        if not rho.is_cuda:
            return MicroGateEnsembleRho.forward_torch(
                rho, theta, total_qubits, qubit_support, target_support_id)

        FTYPE = MicroGateEnsembleRho._real_view_dtype(rho.dtype)
        num_gates = 3 + qubit_support.shape[0]
        out = torch.empty([num_gates,2**total_qubits,2**total_qubits], 
                          dtype=rho.dtype, device=rho.device)
        threads_per_block = (thread_block_size,thread_block_size)
        # One z-slice of the grid per gate in the ensemble.
        blocks_per_grid = (math.ceil(rho.shape[0]/thread_block_size),
                           math.ceil(rho.shape[1]/thread_block_size),
                           num_gates)
        MicroGateEnsembleRho.forward_cuda[
            blocks_per_grid,
            threads_per_block
        ](out.detach().view(FTYPE),
          # A dtype-changing view needs unit stride in the last dimension.
          rho.detach().contiguous().view(FTYPE),
          theta.item(),
          total_qubits,
          qubit_support.to(rho.device),
          target_support_id
        )

        return out
    
    @staticmethod
    def setup_context(ctx, inputs, output):
        """Stash on ``ctx`` everything ``backward`` needs."""
        (rho, theta, total_qubits,qubit_support, target_support_id, 
         thread_block_size) = inputs
        ctx.total_qubits = total_qubits
        ctx.num_support = qubit_support.shape[0]
        ctx.target_support_id = target_support_id
        ctx.thread_block_size = thread_block_size
        ctx.save_for_backward(rho,theta,qubit_support)
    
    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``rho`` and ``theta``.

        Args:
            ctx: Context populated by ``setup_context``.
            grad_output: Gradient with respect to the forward output; it is
                indexed with the same leading gate axis as that output.

        Returns:
            ``(rho_grad, theta_grad, None, None, None, None)``. The last
            four entries correspond to the non-differentiable inputs.

        Raises:
            TypeError: On the CUDA path, if ``rho`` is not complex64 or
                complex128.
        """
        rho, theta, qubit_support = ctx.saved_tensors

        if not rho.is_cuda:
            rho_grad = MicroGateEnsembleRho.rho_grad_torch(
                grad_output, theta, ctx.total_qubits, qubit_support,
                ctx.target_support_id)
            theta_grad = MicroGateEnsembleRho.theta_grad_torch(
                grad_output, rho, theta, ctx.total_qubits, qubit_support,
                ctx.target_support_id)
        else:
            FTYPE = MicroGateEnsembleRho._real_view_dtype(rho.dtype)
            num_gates = ctx.num_support + 3

            rho_grad = torch.empty(
                [num_gates, 2**ctx.total_qubits, 2**ctx.total_qubits],
                dtype=rho.dtype, device=rho.device)
            # Scratch buffer written by the kernel: it must live on the GPU
            # next to `rho`, wherever `theta` itself is stored.
            theta_grad = torch.zeros([3, rho.shape[0], rho.shape[1]],
                                     dtype=theta.dtype,device=rho.device)
            
            threads_per_block = (ctx.thread_block_size, ctx.thread_block_size)
            blocks_per_grid1 = (math.ceil(rho.shape[0]/ctx.thread_block_size),
                               math.ceil(rho.shape[1]/ctx.thread_block_size),
                               num_gates)
            # Only the three rotation gates depend on theta.
            blocks_per_grid2 = (blocks_per_grid1[0], blocks_per_grid1[1], 3)

            # Incoming gradients are frequently expanded or strided views;
            # the dtype-changing view below needs unit last-dim stride.
            grad_view = grad_output.detach().contiguous().view(FTYPE)
            rho_view = rho.detach().contiguous().view(FTYPE)
            support = qubit_support.to(rho.device)
            theta_value = theta.item()
            
            MicroGateEnsembleRho.rho_grad_cuda[
                blocks_per_grid1,
                threads_per_block
            ](
                rho_grad.detach().view(FTYPE),
                grad_view,
                theta_value, 
                ctx.total_qubits, support, ctx.target_support_id
            )
            MicroGateEnsembleRho.theta_grad_cuda[
                blocks_per_grid2,
                threads_per_block
            ](
                theta_grad.detach(), 
                grad_view, 
                rho_view, 
                theta_value, 
                ctx.total_qubits, support, ctx.target_support_id
            )

            theta_grad = theta_grad.sum()

        # Hand the scalar back on theta's device and with theta's shape so
        # that e.g. a shape-[1] parameter is accepted by autograd.
        theta_grad = theta_grad.to(theta.device).reshape(theta.shape)

        # NOTE(review): rho_grad keeps the leading gate axis that `rho` does
        # not have; this appears to rely on autograd summing the gradient
        # down to the input's shape - confirm.
        return rho_grad, theta_grad, None, None, None, None
    
    @staticmethod
    def forward_torch(rho:torch.Tensor, theta:torch.Tensor, total_qubits:int,
                qubit_support:torch.Tensor, target_support_id:int
                )->torch.Tensor:
        """Torch implementation of ``forward`` built on ``rx/ry/rz/cnot``.

        Returns:
            Tensor of shape
            ``[qubit_support.shape[0] + 3, 2**total_qubits, 2**total_qubits]``.
        """
        num_support = qubit_support.shape[0]
        target_qubit = qubit_support[target_support_id].item()
        out = torch.empty([num_support+3,2**total_qubits,2**total_qubits],
                          dtype=rho.dtype,device=rho.device)

        out[0] = rho
        out[1] = rx(rho,theta,total_qubits,target_qubit)
        out[2] = ry(rho,theta,total_qubits,target_qubit)
        out[3] = rz(rho,theta,total_qubits,target_qubit)
        for i in range(4,num_support+3):
            # Walk the support in order, skipping the target's own slot.
            control_support_id = i-4 if i-4 < target_support_id else i-3
            control_qubit = qubit_support[control_support_id].item()
            out[i] = cnot(rho, total_qubits, target_qubit, control_qubit)
        return out
    
    @staticmethod
    @cuda.jit
    def forward_cuda(out:torch.Tensor, rho:torch.Tensor, theta:float,
                     total_qubits:int, qubit_support:torch.Tensor,
                     target_support_id:int)->None:
        """CUDA kernel for ``forward``.

        ``out`` and ``rho`` are the real-dtype views created by the caller;
        the identity branch addresses complex column ``j`` as real columns
        ``2*j`` and ``2*j+1``. ``blockIdx.z`` selects the gate, and each
        thread handles one ``(i, j)`` matrix entry.
        """
        # Thread and block indices.
        tx = cuda.threadIdx.x
        ty = cuda.threadIdx.y
        bx = cuda.blockIdx.x
        by = cuda.blockIdx.y
        G  = cuda.blockIdx.z
        bsizex = cuda.blockDim.x
        bsizey = cuda.blockDim.y

        # Global indices
        i = bx * bsizex + tx
        j = by * bsizey + ty

        # Total number of rows (and columns for rho) is 2^total_qubits.
        N = 1 << total_qubits
        # Check bounds.
        if i >= N or j >= N:
            return
        target_qubit = qubit_support[target_support_id]
        # Bit mask selecting the target qubit inside a basis-state index.
        T = 1 << target_qubit

        if G == 0:
            # Identity gate: copy the (real, imag) pair unchanged.
            out[G, i, 2*j]   = rho[i, 2*j]
            out[G, i, 2*j+1] = rho[i, 2*j+1]
        elif G == 1:
            rx_densitymatrix_kernel(out[G], rho, T, theta, i, j)
        elif G == 2:
            ry_densitymatrix_kernel(out[G], rho, T, theta, i, j)
        elif G == 3:
            rz_densitymatrix_kernel(out[G], rho, T, theta, i, j)
        else:
            # Map the gate slot to a control position, skipping the target.
            if G - 4 < target_support_id:
                control_support_id = G - 4
            else:
                control_support_id = G - 3
            control_qubit = qubit_support[control_support_id]
            C = 1 << control_qubit
            cnot_density_matrix_kernel(out[G], rho, T, C, i, j)
    
    @staticmethod
    def rho_grad_torch(grad_output:torch.Tensor, theta:torch.Tensor, 
                       total_qubits:int, qubit_support:torch.Tensor,
                       target_support_id:int)->torch.Tensor:
        """Torch implementation of the gradient with respect to ``rho``.

        Each slice of ``grad_output`` is pushed back through its gate: the
        rotations are applied with ``-theta`` and the CNOTs are applied
        unchanged.

        Returns:
            Tensor with the same leading gate axis as ``grad_output``.
        """
        num_support = qubit_support.shape[0]
        target_qubit = qubit_support[target_support_id].item()
        out = torch.empty([num_support+3, 2**total_qubits, 2**total_qubits],
                          dtype=grad_output.dtype, device=grad_output.device)
        out[0] = grad_output[0]
        out[1] = rx(grad_output[1], -theta, total_qubits, target_qubit)
        out[2] = ry(grad_output[2], -theta, total_qubits, target_qubit)
        out[3] = rz(grad_output[3], -theta, total_qubits, target_qubit)
        for i in range(4, num_support+3):
            # Same slot-to-control mapping as in the forward pass.
            control_support_id = i-4 if i-4 < target_support_id else i-3
            control_qubit = qubit_support[control_support_id].item()
            out[i] = cnot(grad_output[i], total_qubits, target_qubit, 
                          control_qubit)

        return out
    
    @staticmethod
    @cuda.jit
    def rho_grad_cuda(out:torch.Tensor, grad_output:torch.Tensor, theta:float, 
                      total_qubits:int, qubit_support:torch.Tensor, 
                      target_support_id:int)->None:
        """CUDA kernel counterpart of ``rho_grad_torch``.

        Same launch layout as ``forward_cuda``; each gate slice reads its
        own slice of ``grad_output`` and rotations use ``-theta``.
        """
        # Thread and block indices.
        tx = cuda.threadIdx.x
        ty = cuda.threadIdx.y
        bx = cuda.blockIdx.x
        by = cuda.blockIdx.y
        G  = cuda.blockIdx.z
        bsizex = cuda.blockDim.x
        bsizey = cuda.blockDim.y

        # Global indices
        i = bx * bsizex + tx
        j = by * bsizey + ty

        # Total number of rows (and columns for rho) is 2^total_qubits.
        N = 1 << total_qubits
        # Check bounds.
        if i >= N or j >= N:
            return
        
        target_qubit = qubit_support[target_support_id]
        T = 1 << target_qubit

        if G == 0:
            # Identity gate: the gradient passes straight through.
            out[G, i, 2*j]   = grad_output[G, i, 2*j]
            out[G, i, 2*j+1] = grad_output[G, i, 2*j+1]
        elif G == 1:
            rx_densitymatrix_kernel(out[G], grad_output[G], T, -theta, i, j)
        elif G == 2:
            ry_densitymatrix_kernel(out[G], grad_output[G], T, -theta, i, j)
        elif G == 3:
            rz_densitymatrix_kernel(out[G], grad_output[G], T, -theta, i, j)
        else:
            # Map the gate slot to a control position, skipping the target.
            if G - 4 < target_support_id:
                control_support_id = G - 4
            else:
                control_support_id = G - 3
            control_qubit = qubit_support[control_support_id]
            C = 1 << control_qubit
            cnot_density_matrix_kernel(out[G], grad_output[G], T, C, i, j)
    
    @staticmethod
    def theta_grad_torch(grad_output:torch.Tensor, rho:torch.Tensor, 
                         theta:torch.Tensor, total_qubits:int, 
                         qubit_support:torch.Tensor, target_support_id:int
                         )->torch.Tensor:
        """Torch implementation of the gradient with respect to ``theta``.

        Multiplies ``conj(grad_output[1:4])`` elementwise with three stacked
        terms (one each for Rx, Ry, Rz), takes the real part and sums
        everything into a 0-dim tensor.

        NOTE(review): the stacked terms look like d(out[k])/d(theta) for
        rotations of the form exp(-i*theta*P/2) - confirm against the
        conventions used by ``qtorch.unitaries``.
        """
        target_qubit = qubit_support[target_support_id].item()
        T = 1 << target_qubit
        I = torch.arange(2**total_qubits,device=rho.device)
        # Basis-state indices with the target-qubit bit flipped.
        flipped = I ^ T
        
        # -1 where the row and column indices differ in the target bit.
        s1 = torch.where((I[:,None] ^ I[None,:]) & T > 0, -torch.ones_like(rho),
                         torch.ones_like(rho))
        # -1 for columns whose target bit is set; s2.T is the row analogue.
        s2 = torch.where(I[None,:]&T > 0, -torch.ones_like(I[None,:]), 
                         torch.ones_like(I[None,:]))

        half_sin = torch.sin(theta)/2
        i_half_cos = 1j*torch.cos(theta)/2
        # rho with the target bit flipped on both, column-only, row-only.
        rho_both = rho[flipped[:,None], flipped[None,:]]
        rho_col = rho[I[:,None], flipped[None,:]]
        rho_row = rho[flipped[:,None], I[None,:]]

        return ((grad_output[1:4,:,:].conj() * torch.stack([
            # Rx Term
            (-half_sin * rho
            + half_sin * rho_both
            + i_half_cos * (
                rho_col - rho_row
            )),

            # Ry Term:
            (-half_sin * rho  
            + half_sin * rho_both * s1
            + i_half_cos * (
                rho_col * s2*1j 
                - rho_row * s2.T*(-1j)
            )),

            # Rz Term:
            (-half_sin * rho 
            + half_sin * rho * s1
            + i_half_cos * (
                rho*s2 - rho*s2.T
            ))

        ], dim=0)).real).sum()
    
    @staticmethod
    @cuda.jit
    def theta_grad_cuda(out:torch.Tensor, grad_output:torch.Tensor, 
                        rho:torch.Tensor, theta:float, total_qubits:int,
                        qubit_support:torch.Tensor, target_support_id:int
                        )->None:
        """CUDA kernel counterpart of ``theta_grad_torch``.

        ``blockIdx.z`` in ``{0, 1, 2}`` selects Rx, Ry or Rz, which read
        ``grad_output[1]``, ``grad_output[2]`` and ``grad_output[3]``
        respectively and write into ``out[0..2]``. The caller sums ``out``
        afterwards.
        """
        # Thread indices
        tx = cuda.threadIdx.x
        ty = cuda.threadIdx.y
        bx = cuda.blockIdx.x
        by = cuda.blockIdx.y
        G = cuda.blockIdx.z
        bsizex = cuda.blockDim.x
        bsizey = cuda.blockDim.y

        # Global indices
        i = bx * bsizex + tx
        j = by * bsizey + ty

        # Compute matrix dimension
        N = 1 << total_qubits  # 2^total_qubits
        if i >= N or j >= N or G >=3:
            return  # Out of bounds
        
        target_qubit = qubit_support[target_support_id]
        T = 1 << target_qubit

        # Gate slot G here maps to forward slot G + 1 (slot 0 is identity).
        if G == 0: 
            rx_density_matrix_theta_grad_kernel(out[G], rho, grad_output[1], T, theta, i, j)
            
        elif G == 1:
            ry_density_matrix_theta_grad_kernel(out[G], rho, grad_output[2], T, theta, i, j)
        else: # G == 2
            rz_density_matrix_theta_grad_kernel(out[G], rho, grad_output[3], T, theta, i, j)