import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Callable
from warnings import warn

from .gate_ensemble_psi import GateEnsemblePsi

class QDARTS(nn.Module):
    '''
    Module to implement the Quantum Differentiable Architecture Search
    (QuantumDARTS) algorithm (https://proceedings.mlr.press/v202/wu23v.html).

    The module is used to search for a particular quantum circuit architecture
    for a variational quantum algorithm.

    The search space comprises of quantum circuits of n qubits and m layers.
    Each layer consists of n gates, targetting each of the qubits in the
    circuit. The possible gates are:
        - Identity
        - Pauli X Rotation
        - Pauli Y Rotation
        - Pauli Z Rotation
        - CNOT gate which can be controlled by any of the other qubits
    
    The algorithm is end-to-end meaning that the parameters of the rotation
    gates are learned simultaneously with the gate probabilities. Each rotation
    gate acting on a single qubit in a layer shares the same rotation angle.

    Attributes
    ----------
    num_qubits: int
        The total number of qubits
    num_layers: int
        The toal number of layers in the architecture
    tau: float
        The temperature value used for Gumbel-Softmax,
        defualt = 0.5
    gumbel_hard: bool
        Flag on whether to use hard sampling in forward pass of Gumbel-Softmax,
        default = True
    thread_block_size: int
        The size of the thread block used to call the gate ensemble kernel.
        default = 16
    noise_model: Callable[[torch.Tensor], torch.Tensor] | None
        A function to implement some noise model on the quantum simulation. The
        function should take only a state vector as input and return the noisy
        state vector. The noise is applied after every layer.
    
    Buffers
    -------
    psi0: torch.Tensor
        The statevector to initialize the quantum simulation with.
    
    Methods
    -------
    forward(logits, angles, softmax_temperature, skip_validation):
        Generates the probability distribution of gates in the architecture 
        search according to the passed logits, samples a circuit via 
        Gumbel-Softmax and applies it to |psi0>, returning the resulting
        Statevector
    '''
    def __init__(self, num_qubits:int, num_layers:int, gumbel_temp:float=0.5,
                 gumbel_hard:bool=True, psi0:torch.Tensor|None=None, 
                 thread_block_size:int=16,
                 noise_model:Callable[[torch.Tensor],torch.Tensor]|None=None):
        '''
        Arguments
        ---------
        num_qubits: int
            The total number of qubits.
        num_layers: int
            The toal number of layers in the architecture.
        gumbel_temp: float, optional
            The temperature value used for Gumbel-Softmax.
            defualt = 0.5
        gumbel_hard: bool, optional
            Flag on whether to use hard sampling in forward pass of 
            Gumbel-Softmax.
            default = True
        psi0: torch.Tensor, optional
            The statevector to initialize the quantum simulation with. If None,
            psi0 is initalized to the |0...0> state.
            default = None
        thread_block_size: int, optional
            The size of the thread block used to call the gate ensemble kernel.
            default = 16
        noise_model: Callable[[torch.Tensor], torch.Tensor], optional
            A function to implement some noise model on the quantum simulation. 
            The function should take only a state vector as input and return the
            noisy state vector. The noise is applied after every layer.
            default = None
        
        Raises
        ------
        ValueError:
            - psi0 has the wrong shape
        '''
        super().__init__()
        self.num_qubits = num_qubits
        self.num_layers = num_layers
        self.tau = gumbel_temp
        self.gumbel_hard = gumbel_hard

        if psi0 is not None:
            if psi0.shape != (2**self.num_qubits, ):
                raise ValueError('`psi0` must have shape [2^Q] for Q qubits')
            if psi0.dtype not in [torch.complex64, torch.complex128]:
                warn('`psi0` does not have a complex dtype, it will be '
                'converted to a complex type')
                if psi0.dtype == torch.float64:
                    psi0 = psi0.to(torch.complex128)
                else:
                    psi0 = psi0.to(torch.complex64)
        else:
            psi0 = torch.tensor([1.0]+[0.0]*(2**num_qubits - 1),
                                dtype=torch.complex64)
        self.thread_block_size = thread_block_size
        self.register_buffer('psi0', psi0)

        self.noise_model = noise_model
    
    def forward(self, logits:torch.Tensor, angles:torch.Tensor, 
                angles_optimizer:torch.optim.Optimizer, num_iter:int,
                angle_loss_fn:Callable[[torch.Tensor,torch.Tensor],torch.Tensor],
                softmax_temperature:float=1.0, skip_validation:bool=True,
                psi0:torch.Tensor|None=None
                )->torch.Tensor:
        '''
        Generates the probability distribution of gates in the architecture 
        search according to the passed logits, samples a circuit via 
        Gumbel-Softmax and applies it to |psi0>, returning the resulting
        Statevector
        
        Arguments
        ---------
        logits: torch.Tensor
            Shape [L,Q,Q+3] - The unnormalized log probabilities for each of the 
            gates in the quantum architecture.
        angles: torch.Tensor
            Shape [L, Q]
            The rotation angles for each of the gates in the 
            quantum architecture.
        softmax_temperature: float, optional
            The temperature scaling to be used in the softmax function which
            converts the logits to the gate probabilities.
            default = 1.0
        skip_validation: bool, optional
            Flag on whether to skip the data validation logic.
            default = True
        psi0: torch.Tensor, optional
            Batch of initial state vectors to override the initial state buffer
        
        Returns
        -------
        torch.Tensor
            phi - The quantum state obtained by applying the sampled circuit to
            the initial state |psi0>
        
        Raises
        ------
        ValueError
            - logits has the wrong shape
            - angles has the wrong shape
        '''
        if not skip_validation:
            if logits.shape != (self.num_layers, self.num_qubits, 
                                self.num_qubits+3):
                raise ValueError('`logits` must have shape: [L, Q, Q+3] for L '
                                 'layers and Q qubits')
            if angles.shape != (self.num_layers, self.num_qubits):
                raise ValueError('`angles` must have shape [L, Q] for L '
                                    'layers and Q qubits')
            
        if psi0 is not None:
            assert psi0.dim() == 2, '`psi0` must be a batch of state vectors.'
            if psi0.size(1) != 2**self.num_qubits:
                raise ValueError('Statevector size of `psi0` does not match qubit count.')
            batch_size = psi0.size(0)
        else:
            batch_size = 1
            
        h = F.gumbel_softmax(logits/softmax_temperature, self.tau, 
                             self.gumbel_hard, dim=-1)
        
        def qsim(h:torch.Tensor, angles:torch.Tensor)->torch.Tensor:
            if psi0 is None:
                psi = self.psi0.unsqueeze(0).expand(batch_size,-1).contiguous()
            else:
                psi = psi0.clone().contiguous()
            for layer in range(self.num_layers):
                for target_qubit in range(self.num_qubits):
                    psi = torch.sum(h[layer,target_qubit,:,None]
                                    * GateEnsemblePsi.apply(
                                        psi, angles[layer, target_qubit],
                                        self.num_qubits, target_qubit, 
                                        self.thread_block_size)
                                        , dim = 1)
                if self.noise_model is not None:
                    psi = self.noise_model(psi)
            return psi
        
        for iter in range(num_iter):
            angles_optimizer.zero_grad()
            psi = qsim(h.detach(), angles)
            angle_loss = angle_loss_fn(psi, angles)
            angle_loss.backward()
            angles_optimizer.step()
        
        psi = qsim(h, angles.detach())
        return psi, h
