import torch
import numba.cuda as cuda
import math
from torch.autograd import Function
from itertools import product, count

from qtorch.unitaries import rx,ry,rz,cnot
from qtorch.unitaries.kernels import (
    rx_densitymatrix_kernel, rx_density_matrix_theta_grad_kernel,
    ry_densitymatrix_kernel, ry_density_matrix_theta_grad_kernel,
    rz_densitymatrix_kernel, rz_density_matrix_theta_grad_kernel,
    cnot_density_matrix_kernel
)

class GateEnsembleRho(Function):
    """Apply an ensemble of gates to a batch of density matrices in one call.

    For each batch element ``rho[b]`` the forward pass returns
    ``total_qubits + 3`` results stacked along dim 1:

    * index 0: ``rho[b]`` unchanged,
    * index 1: ``rx`` with angle ``theta`` on ``target_qubit``,
    * index 2: ``ry`` with angle ``theta`` on ``target_qubit``,
    * index 3: ``rz`` with angle ``theta`` on ``target_qubit``,
    * index 4 onwards: ``cnot`` targeting ``target_qubit``, controlled by
      each other qubit in increasing order.

    CUDA tensors are processed by the numba kernels defined below; tensors
    on any other device use the pure-torch reference implementations
    (the ``*_torch`` static methods).
    """

    @staticmethod
    def forward(rho: torch.Tensor, theta: torch.Tensor, total_qubits: int,
                target_qubit: int, thread_block_size: int) -> torch.Tensor:
        """Compute every gate of the ensemble applied to ``rho``.

        Args:
            rho: Batched density matrices, shape ``[B, N, N]`` with
                ``N = 2**total_qubits``. On CUDA it must be ``complex64``
                or ``complex128``.
            theta: 0-dim tensor holding the angle shared by Rx/Ry/Rz.
            total_qubits: Number of qubits the matrices describe.
            target_qubit: Qubit the rotations act on and the CNOT target.
            thread_block_size: Edge length of the square CUDA thread block
                (only used on the CUDA path).

        Returns:
            Tensor of shape ``[B, total_qubits + 3, N, N]`` with the dtype
            and device of ``rho``.

        Raises:
            AssertionError: If ``rho``/``theta`` shapes or ``target_qubit``
                are invalid.
            TypeError: On the CUDA path, if ``rho`` is not complex64 or
                complex128.
        """
        n = 2**total_qubits
        assert rho.dim() == 3, '`rho` must be a batched 2-d tensor'
        assert rho.shape[1:] == (n, n), (
            f'Expected `rho.shape` to be {[n, n]}, '
            f'received {rho.shape}')
        assert theta.dim() == 0, (
            'Expected `theta` to be a 0-dimensional (scalar) tensor')
        # An out-of-range target would make the CUDA kernels index past the
        # end of the matrices, so reject it up front.
        assert 0 <= target_qubit < total_qubits, (
            f'Expected `target_qubit` in [0, {total_qubits}), '
            f'received {target_qubit}')

        # Dispatch on where the data lives, not on whether a GPU exists:
        # the numba kernels cannot consume CPU tensors.
        if not rho.is_cuda:
            return GateEnsembleRho.forward_torch(rho, theta, total_qubits,
                                                 target_qubit)

        ftype = GateEnsembleRho._real_dtype(rho.dtype)
        batch_size = rho.size(0)
        num_gates = total_qubits + 3
        out = torch.empty([batch_size, num_gates, n, n],
                          dtype=rho.dtype, device=rho.device)
        if batch_size == 0:
            # A grid with a zero dimension is not a valid launch.
            return out

        threads_per_block = (thread_block_size, thread_block_size)
        blocks_per_grid = GateEnsembleRho._grid(
            n, thread_block_size, num_gates * batch_size)
        # `.view(ftype)` requires a unit last stride, hence `.contiguous()`.
        GateEnsembleRho.forward_cuda[blocks_per_grid, threads_per_block](
            out.detach().view(ftype),
            rho.detach().contiguous().view(ftype),
            theta.item(), total_qubits, target_qubit, batch_size)
        return out

    @staticmethod
    def setup_context(ctx, inputs, output):
        """Stash the non-tensor arguments and the tensors needed by backward."""
        rho, theta, total_qubits, target_qubit, thread_block_size = inputs
        ctx.total_qubits = total_qubits
        ctx.target_qubit = target_qubit
        ctx.thread_block_size = thread_block_size
        ctx.save_for_backward(rho, theta)

    @staticmethod
    def backward(ctx, grad_output):
        """Back-propagate ``grad_output`` to ``rho`` and ``theta``.

        ``rho`` receives the sum over the ensemble of each slice of
        ``grad_output`` passed back through its gate: the rotations are
        re-applied with ``-theta`` and the CNOT slices go through ``cnot``
        again. ``theta`` receives the real part of
        ``sum(conj(grad_output[:, 1:4]) * d(out[:, 1:4]) / d theta)``.
        Gradients are only computed for inputs that require them; the
        trailing three ``None`` correspond to the non-tensor arguments.
        """
        rho, theta = ctx.saved_tensors
        need_rho, need_theta = ctx.needs_input_grad[:2]
        rho_grad = None
        theta_grad = None

        if not rho.is_cuda:
            if need_rho:
                rho_grad = GateEnsembleRho.rho_grad_torch(
                    grad_output, theta, ctx.total_qubits, ctx.target_qubit)
            if need_theta:
                theta_grad = GateEnsembleRho.theta_grad_torch(
                    grad_output, rho, theta, ctx.total_qubits,
                    ctx.target_qubit)
        else:
            ftype = GateEnsembleRho._real_dtype(rho.dtype)
            batch_size = rho.size(0)
            n = 2**ctx.total_qubits
            num_gates = ctx.total_qubits + 3
            block = ctx.thread_block_size
            threads_per_block = (block, block)
            launch = batch_size > 0  # skip zero-sized grids
            theta_value = theta.item()
            # Upstream gradients may be expanded/strided; `.view(ftype)`
            # needs a unit last stride.
            grad_real = grad_output.detach().contiguous().view(ftype)

            if need_rho:
                rho_grad = torch.empty([batch_size, num_gates, n, n],
                                       dtype=rho.dtype, device=rho.device)
                if launch:
                    grid = GateEnsembleRho._grid(n, block,
                                                 num_gates * batch_size)
                    GateEnsembleRho.rho_grad_cuda[grid, threads_per_block](
                        rho_grad.detach().view(ftype), grad_real,
                        theta_value, ctx.total_qubits, ctx.target_qubit,
                        batch_size)
                rho_grad = rho_grad.sum(dim=1)

            if need_theta:
                # Allocate on rho's device: theta may be a CPU scalar while
                # the kernel runs on the GPU.
                theta_grad = torch.empty([batch_size, 3, n, n],
                                         dtype=theta.dtype, device=rho.device)
                if launch:
                    grid = GateEnsembleRho._grid(n, block, 3 * batch_size)
                    GateEnsembleRho.theta_grad_cuda[grid, threads_per_block](
                        theta_grad.detach(), grad_real,
                        rho.detach().contiguous().view(ftype), theta_value,
                        ctx.total_qubits, ctx.target_qubit, batch_size)
                theta_grad = theta_grad.sum(dim=(1, 2, 3))

        if theta_grad is not None:
            # theta is shared by the whole batch: reduce the per-batch
            # contributions and match theta's shape, dtype and device.
            theta_grad = theta_grad.sum().reshape(theta.shape).to(
                dtype=theta.dtype, device=theta.device)

        return rho_grad, theta_grad, None, None, None

    @staticmethod
    def _real_dtype(dtype: torch.dtype) -> torch.dtype:
        """Return the real dtype used to reinterpret a complex ``dtype``.

        Viewing a complex tensor as this dtype doubles its last dimension,
        interleaving real and imaginary parts.

        Raises:
            TypeError: If ``dtype`` is not complex64 or complex128.
        """
        if dtype == torch.complex64:
            return torch.float32
        if dtype == torch.complex128:
            return torch.float64
        raise TypeError('Expected `rho` to be of type complex64 or '
                        'complex128')

    @staticmethod
    def _grid(n: int, block: int, depth: int) -> tuple:
        """Return a launch grid covering an ``n x n`` matrix ``depth`` times."""
        blocks = math.ceil(n / block)
        return (blocks, blocks, depth)

    @staticmethod
    @cuda.jit
    def forward_cuda(out:torch.Tensor, rho:torch.Tensor, theta:float, 
                     total_qubits:int, target_qubit:int, batch_size:int):
        # One thread per (i, j) matrix entry; blockIdx.z enumerates
        # (batch, gate) pairs as bz = B * (total_qubits + 3) + G.
        # `out` and `rho` are complex tensors viewed as real, so column 2*j
        # holds the real part and column 2*j+1 the imaginary part of (i, j).
        tx = cuda.threadIdx.x
        ty = cuda.threadIdx.y
        bx = cuda.blockIdx.x
        by = cuda.blockIdx.y
        bz = cuda.blockIdx.z
        bsizex = cuda.blockDim.x
        bsizey = cuda.blockDim.y

        # Global indices
        i = bx * bsizex + tx
        j = by * bsizey + ty
        G = bz % (total_qubits+3) # gate index
        B = bz // (total_qubits+3) # batch index

        # Total number of rows (and columns for rho) is 2^total_qubits.
        N = 1 << total_qubits
        T = 1 << target_qubit  # bit mask of the target qubit

        # Threads past the matrix edge (partial blocks) do nothing.
        if i >= N or j >= N or B >= batch_size:
            return

        if G == 0:
            # Identity member of the ensemble: copy both real components.
            out[B, G, i, 2*j]   = rho[B, i, 2*j]
            out[B, G, i, 2*j+1] = rho[B, i, 2*j+1]
        elif G == 1:
            rx_densitymatrix_kernel(out[B, G], rho[B], T, theta, i, j)
        elif G == 2:
            ry_densitymatrix_kernel(out[B, G], rho[B], T, theta, i, j)
        elif G == 3:
            rz_densitymatrix_kernel(out[B, G], rho[B], T, theta, i, j)
        else:
            # Gates 4.. are CNOTs; skip the target when picking the control.
            if G - 4 < target_qubit:
                control_qubit = G - 4
            else:
                control_qubit = G - 3
            C = 1 << control_qubit
            cnot_density_matrix_kernel(out[B, G], rho[B], T, C, i, j)

    @staticmethod
    def forward_torch(rho, theta, total_qubits, target_qubit):
        """Reference (non-CUDA) forward pass; same output as ``forward``.

        Returns a ``[B, total_qubits + 3, N, N]`` tensor built by calling the
        ``rx``/``ry``/``rz``/``cnot`` helpers on each batch element.
        """
        batch_size = rho.size(0)
        n = 2**total_qubits
        out = torch.empty([batch_size, total_qubits+3, n, n],
                          dtype=rho.dtype, device=rho.device)

        for b in range(batch_size):
            out[b, 0, :, :] = rho[b]
            out[b, 1, :, :] = rx(rho[b], theta, total_qubits, target_qubit)
            out[b, 2, :, :] = ry(rho[b], theta, total_qubits, target_qubit)
            out[b, 3, :, :] = rz(rho[b], theta, total_qubits, target_qubit)
            for i in range(4, total_qubits+3):
                # Controls run over every qubit except the target.
                control_qubit = i-4 if i-4 < target_qubit else i-3
                out[b, i, :, :] = cnot(rho[b], total_qubits, target_qubit,
                                       control_qubit)
        return out

    @staticmethod
    def rho_grad_torch(grad_output, theta, total_qubits, target_qubit):
        """Reference (non-CUDA) gradient with respect to ``rho``.

        Each slice ``grad_output[:, g]`` is passed back through gate ``g``
        (rotations with ``-theta``, CNOTs unchanged) and the results are
        summed over the ensemble, giving a ``[B, N, N]`` tensor.
        """
        batch_size = grad_output.size(0)
        out = torch.empty([batch_size, total_qubits+3, 
                           2**total_qubits, 2**total_qubits],
                           dtype=grad_output.dtype, device=grad_output.device)
        for b in range(batch_size):
            out[b,0] = grad_output[b,0]
            out[b,1] = rx(grad_output[b,1], -theta, total_qubits, target_qubit)
            out[b,2] = ry(grad_output[b,2], -theta, total_qubits, target_qubit)
            out[b,3] = rz(grad_output[b,3], -theta, total_qubits, target_qubit)
            for i in range(4, total_qubits+3):
                control_qubit = i-4 if i-4 < target_qubit else i-3
                out[b,i] = cnot(grad_output[b,i], total_qubits, target_qubit, 
                            control_qubit)

        return out.sum(dim=1)

    @staticmethod
    @cuda.jit
    def rho_grad_cuda(out:torch.Tensor, grad_output:torch.Tensor,
                      theta:float, total_qubits:int, target_qubit:int,
                      batch_size:int)->None:
        # CUDA counterpart of `rho_grad_torch` before the ensemble sum:
        # writes the back-propagated slice of every gate into out[B, G].
        # Indexing/layout conventions match `forward_cuda` (interleaved
        # real/imag columns, bz = B * (total_qubits + 3) + G).
        tx = cuda.threadIdx.x
        ty = cuda.threadIdx.y
        bx = cuda.blockIdx.x
        by = cuda.blockIdx.y
        bz  = cuda.blockIdx.z
        bsizex = cuda.blockDim.x
        bsizey = cuda.blockDim.y

        # Global indices
        i = bx * bsizex + tx
        j = by * bsizey + ty

        # Total number of rows (and columns for rho) is 2^total_qubits.
        N = 1 << total_qubits
        T = 1 << target_qubit
        G = bz % (total_qubits+3)
        B = bz // (total_qubits+3)

        # Check bounds.
        if i >= N or j >= N or B >= batch_size:
            return

        if G == 0:
            out[B, G, i, 2*j]   = grad_output[B, G, i, 2*j]
            out[B, G, i, 2*j+1] = grad_output[B, G, i, 2*j+1]
        elif G == 1:
            # Rotations are undone by re-applying them with -theta.
            rx_densitymatrix_kernel(out[B, G], grad_output[B, G], T, -theta,
                                    i, j)
        elif G == 2:
            ry_densitymatrix_kernel(out[B, G], grad_output[B, G], T, -theta,
                                    i, j)
        elif G == 3:
            rz_densitymatrix_kernel(out[B, G], grad_output[B, G], T, -theta,
                                    i, j)
        else:
            if G - 4 < target_qubit:
                control_qubit = G - 4
            else:
                control_qubit = G - 3
            C = 1 << control_qubit
            cnot_density_matrix_kernel(out[B, G], grad_output[B, G], T, C, i, j)

    @staticmethod
    @cuda.jit
    def theta_grad_cuda(theta_grad:torch.Tensor, grad_output:torch.Tensor, 
                        rho:torch.Tensor, theta:float, total_qubits:int,
                        target_qubit:int, batch_size:int):
        # CUDA counterpart of `theta_grad_torch`: each thread produces one
        # (i, j) entry of the per-entry theta contribution for rotation G
        # (0 = Rx, 1 = Ry, 2 = Rz); the host sums theta_grad afterwards.
        # blockIdx.z enumerates (batch, rotation) pairs as bz = B * 3 + G.
        # theta_grad is real; grad_output/rho are complex viewed as real.
        tx = cuda.threadIdx.x
        ty = cuda.threadIdx.y
        bx = cuda.blockIdx.x
        by = cuda.blockIdx.y
        bz = cuda.blockIdx.z
        bsizex = cuda.blockDim.x
        bsizey = cuda.blockDim.y

        # Global indices
        i = bx * bsizex + tx
        j = by * bsizey + ty

        # Compute matrix dimension
        N = 1 << total_qubits  # 2^total_qubits
        T = 1 << target_qubit
        G = bz % 3
        B = bz // 3

        if i >= N or j >= N or G >=3 or B >= batch_size:
            return  # Out of bounds

        # Rotation G lives at ensemble index G + 1 in grad_output.
        if G == 0: 
            rx_density_matrix_theta_grad_kernel(theta_grad[B, G], rho[B],
                                                grad_output[B, 1], T, theta,
                                                i, j)
            
        elif G == 1:
            ry_density_matrix_theta_grad_kernel(theta_grad[B, G], rho[B],
                                                grad_output[B, 2], T, theta,
                                                i, j)
        else: # G == 2
            rz_density_matrix_theta_grad_kernel(theta_grad[B, G], rho[B],
                                                grad_output[B, 3], T, theta,
                                                i, j)

    @staticmethod
    def theta_grad_torch(grad_output, rho, theta, total_qubits, target_qubit):
        """
        Batched gradient of theta for the Rx, Ry, Rz members of the ensemble.

        Parameters:
        -----------
        grad_output: [B, Q+3, N, N] — Gradient from upstream
        rho:         [B, N, N]      — Density matrix input
        theta:       []             — Rotation angle shared by the batch
        total_qubits: int           — Not used here; kept so all *_torch
                                      helpers share a signature
        target_qubit: int

        Returns:
        --------
        theta_grad: [B] — Contribution of each batch element to dL/dtheta;
                    summing over the batch gives the full gradient.
        """
        device = rho.device
        _, _, N, _ = grad_output.shape
        I = torch.arange(N, device=device)

        T = 1 << target_qubit

        # Ix flips the target bit of every basis index.
        Ix = I ^ T
        # s1[i, j] = -1 when i and j differ in the target bit, else +1.
        s1 = torch.where((I[:, None] ^ I[None, :]) & T > 0, -1.0, 1.0).to(rho.dtype)
        # s2[0, j] = -1 when j has the target bit set, else +1.
        s2 = torch.where(I[None, :] & T > 0, -1.0, 1.0).to(rho.dtype)

        # Index-permuted copies of rho (target bit flipped on rows/columns).
        x_rho_x = rho[:, Ix[:, None], Ix[None, :]]               # [B, N, N]
        i_rho_x = rho[:, I[:, None], Ix[None, :]]                # [B, N, N]
        x_rho_i = rho[:, Ix[:, None], I[None, :]]                # [B, N, N]

        # d/dtheta of each rotated density matrix.
        rx_term = (
            -torch.sin(theta) / 2 * rho
            + torch.sin(theta) / 2 * x_rho_x
            + 1j * torch.cos(theta) / 2 * (i_rho_x - x_rho_i)
        )

        ry_term = (
            -torch.sin(theta) / 2 * rho
            + torch.sin(theta) / 2 * x_rho_x * s1
            + 1j * torch.cos(theta) / 2 * (
                i_rho_x * s2 * 1j - x_rho_i * s2.T * (-1j)
            )
        )

        rz_term = (
            -torch.sin(theta) / 2 * rho
            + torch.sin(theta) / 2 * rho * s1
            + 1j * torch.cos(theta) / 2 * (
                rho * s2 - rho * s2.T
            )
        )

        # Stack gradient terms [B, 3, N, N]
        d_theta_rho = torch.stack([rx_term, ry_term, rz_term], dim=1)

        # conj(grad_output[:, 1:4]) * d_theta_rho -> [B, 3, N, N]
        grad_product = grad_output[:, 1:4].conj() * d_theta_rho

        # theta is real, so only the real part contributes.
        return grad_product.real.sum(dim=(1,2,3))  # shape [B]
