import torch
import torch.nn as nn

from typing import Callable
from warnings import warn

from .gate_ensemble_rho import GateEnsembleRho

class RhoDARTS(nn.Module):
    '''
    Module to implement our modification of QuantumDARTS -- Rho-DARTS.

    The module is used to search for a particular quantum circuit architecture
    for a variational quantum algorithm.

    The search space comprises of quantum circuits of n qubits and m layers.
    Each layer consists of n gates, targetting each of the qubits in the
    circuit. The possible gates are:
        - Identity
        - Pauli X Rotation
        - Pauli Y Rotation
        - Pauli Z Rotation
        - CNOT gate which can be controlled by any of the other qubits
    
    The algorithm is end-to-end meaning that the parameters of the rotation
    gates are learned simultaneously with the gate probabilities. Each rotation
    gate acting on a single qubit in a layer shares the same rotation angle.

    In our implementation, we represent quantum states using density matrices
    updated by probabilistic ensembles of unitary gates to ensure physical 
    interpretation of the quantum states during back-propagation.

    Attributes
    ----------
    num_qubits: int
        The total number of qubits
    num_layers: int
        The toal number of layers in the architecture
    thread_block_size: int
        The size of the thread block used to call the gate ensemble kernel.
        default = 16
    noise_model: Callable[[torch.Tensor], torch.Tensor]
        A function to implement some noise model on the quantum simulation. The
        function should take only a density matrix as input and return the noisy
        density matrix. The noise is applied after every layer.
    
    Buffers
    -------
    rho: torch.Tensor
        The density matrix to initialize the quantum simulation with.

    Methods
    -------
    forward(logits, angles, softmax_temperature, skip_validation):
        Generates the ensemble circuit according to the passed logits and angles
        and applies it to the initial state, returning the resulting density 
        matrix.
    '''
    def __init__(self, num_qubits:int, num_layers:int, psi0:torch.Tensor|None=None,
                 thread_block_size:int=16,
                 noise_model:Callable[[torch.Tensor],torch.Tensor]|None=None):
        '''
        Arguments
        ---------
        num_qubits: int
            The total number of qubits
        num_layers: int
            The toal number of layers in the architecture
        psi0: torch.Tensor, optional
            The statevector of the pure state used to initialize the quantum 
            simulation. If None, psi0 is assumed to be the |0...0> state.
            default = None
        thread_block_size: int, optional
            The size of the thread block used to call the gate ensemble kernel.
            default = 16
        noise_model: Callable[[torch.Tensor], torch.Tensor], optional
            A function to implement some noise model on the quantum simulation. 
            The function should take only a density matrix as input and return 
            the noisy density matrix. The noise is applied after every layer.
            default = None
        '''
        super().__init__()
        self.num_qubits = num_qubits
        self.num_layers = num_layers

        if psi0 is not None:
            if psi0.shape != (2**self.num_qubits, ):
                raise ValueError('`psi0` must have shape [2^Q] for Q qubits')
            if psi0.dtype not in [torch.complex64, torch.complex128]:
                warn('`psi0` does not have a complex dtype, it will be '
                'converted to a complex type')
                if psi0.dtype == torch.float64:
                    psi0 = psi0.to(torch.complex128)
                else:
                    psi0 = psi0.to(torch.complex64)
            rho = torch.outer(psi0, psi0.conj())
        else:
            rho = torch.zeros([2**self.num_qubits,2**self.num_qubits], 
                              dtype=torch.complex64)
            rho[0,0] = 1.0
        
        self.thread_block_size = thread_block_size
        self.register_buffer('rho0', rho)

        self.noise_model = noise_model
    
    def forward(self, logits:torch.Tensor, angles:torch.Tensor, 
                softmax_temperature:float=1.0, skip_validation:bool=True,
                psi0:torch.Tensor|None=None):
        '''
        Generates the ensemble circuit according to the passed logits and angles
        and applies it to the initial state, returning the resulting density 
        matrix.

        Arguments
        ---------
        logits: torch.Tensor
            Shape [L, Q, Q+3] - The unnormalized log probabilities for each of 
            the gates in the quantum architecture.
        angles: torch.Tensor
            Shape [L, Q]
            The rotation angles for each of the gates in the quantum 
            architecture.
        softmax_temperature: float, optional
            The temperature scaling to be used in the softmax function which 
            converts the logits to the gate probabilities.
            Default 1.0
        skip_validation: bool, optional
            Flag on whether to skip the data validation logic.
        psi0: torch.Tensor, optional
            Batch of initial states to override the initial state buffer
        
        Returns
        -------
        torch.Tensor
            rho - The density matrix obtained by applying the ensemble of 
            circuits to the initial state |psi0>
        
        Raises
        ------
        ValueError
            - logits has the wrong shape
            - angles has the wrong shape
        '''
            
        if not skip_validation:
            if logits.shape != (self.num_layers, self.num_qubits, self.num_qubits+3):
                raise ValueError('`logits` must have shape: [L, Q, Q+3] for L layers and Q qubits')
            if angles.shape != (self.num_layers, self.num_qubits):
                raise ValueError('`angles` must have shape [L, Q] for L layers and Q qubits')
        
        if psi0 is not None:
            assert psi0.dim() == 2, '`psi0` must be a batch of state vectors.'
            if psi0.size(1) != 2**self.num_qubits:
                raise ValueError('Statevector size of `psi0` does not match qubit count.')
            batch_size = psi0.size(0)
            rho = torch.vmap(torch.outer)(psi0, psi0.conj()).contiguous()
        else:
            batch_size = 1
            rho = self.rho0.expand(batch_size,-1,-1).contiguous()
        
        h = torch.softmax(logits/softmax_temperature, dim=-1)

        for layer in range(self.num_layers):
            for target_qubit in range(self.num_qubits):
                rho = torch.sum(h[layer,target_qubit,:,None,None] 
                                * GateEnsembleRho.apply(
                                    rho, 
                                    angles[layer,target_qubit], 
                                    self.num_qubits, 
                                    target_qubit, 
                                    self.thread_block_size)
                                ,dim = 1)
            if self.noise_model is not None:
                rho = torch.vmap(self.noise_model)(rho)
        return rho
