## Date: 2022/11/08
## Code: python

# import useful libraries
# import paddle tools
import paddle
import paddle_quantum
from paddle_quantum.ansatz import Circuit
from paddle_quantum import set_backend

# basic python tools and math related libraries
import numpy as np
from typing import List
import warnings
warnings.filterwarnings("ignore")

# Global paddle_quantum configuration (module-level side effect, executed on import):
# - default dtype: complex128, i.e. double-precision complex numbers
# - backend: "state_vector" simulation
paddle_quantum.set_dtype("complex128")
set_backend("state_vector")


# Helper functions used during development
# Distance measures
def trace_loss_func(psi: paddle.Tensor, rho: paddle.Tensor) -> paddle.Tensor:
    """Squared Hilbert-Schmidt (Frobenius) distance between two matrices.

    Computes ``Tr[(psi - rho)^dagger (psi - rho)]``, i.e. the sum of the
    squared magnitudes of the entries of ``psi - rho``. This is the *square*
    of the Frobenius / Schatten-2 norm of the difference, not the norm itself.

    Args:
        psi (paddle.Tensor): first input (e.g. a density matrix).
        rho (paddle.Tensor): second input, same shape as ``psi``.

    Returns:
        paddle.Tensor: the loss cast to float64, differentiable w.r.t. both inputs.
    """
    diff = psi - rho
    # Tr(D^dagger D) == sum_ij conj(D_ij) * D_ij. Evaluating it elementwise
    # gives the same value without forming the full product D^dagger @ D.
    # The sum is real and non-negative, so the float64 cast loses nothing.
    loss = paddle.sum(diff.conj() * diff).cast('float64')
    return loss


# tools for constructing circuit layers
def _random_pauli_tag(size: tuple, dim: int = 3):
    """Draw uniformly random integer tags from ``{0, ..., dim - 1}``.

    In this file the tags select a single-qubit rotation per gate slot
    (AdaptiveLayer maps 0 -> RZ, 1 -> RY, anything else -> RX).

    Args:
        size (tuple): shape of the returned array.
        dim (int): number of distinct tag values. Defaults to 3.

    Returns:
        np.ndarray: integer tags with shape ``size``.
    """
    tags = np.random.choice(a=dim, size=size)
    return tags


#-----------------------------
# The Adaptive Layer class
# defining the adaptive layer class based on circuit class
class AdaptiveLayer(Circuit):
    """Circuit that appends parametrised layers onto a chosen subset of qubits.

    Base: Circuit object

    The builder methods (``iterating_layer``, ``complex_ent_layer``) append
    gates to this circuit in place, acting on ``target_qubits``.
    """

    def __init__(self, num_qubits: int, num_layers: int, target_qubits: List[int]) -> None:
        """__init__ function

        Args:
            num_qubits (int): the total number of qubits of the circuit
            num_layers (int): layer depth appended by each builder method
            target_qubits (List[int]): the target qubits acted on by the parametrised unitaries
        """
        # Initialise the Circuit / paddle Layer machinery before storing our
        # own attributes on the instance.
        super().__init__(num_qubits)
        self.__target_qubits = target_qubits
        self.__num_layers = num_layers

    # ---- properties -------------------------------------------------------
    @property
    def target_qubits(self) -> List[int]:
        """List[int]: indices of the qubits acted on by the parametrised gates."""
        return self.__target_qubits

    @target_qubits.setter
    def target_qubits(self, new_targets: List[int]) -> None:
        # The setter must reuse the property's name; otherwise ``target_qubits``
        # stays read-only and assigning to it raises AttributeError.
        self.__target_qubits = new_targets

    # Alias kept for backward compatibility: reading it returns the targets and
    # ``layer.update_target_qubits = [...]`` updates ``target_qubits``.
    update_target_qubits = target_qubits

    @property
    def num_layers(self) -> int:
        """int: number of layers appended by each builder method."""
        return self.__num_layers

    # ---- layer builders ---------------------------------------------------
    def __sqrt_H_init(self) -> None:
        """Append an RY(pi/4) layer ("sqrt(H)" initialisation).

        ---sqrt(H)----
        ---sqrt(H)----
        ---sqrt(H)----
        """
        # No qubits_idx is passed, so Circuit.ry's default qubit selection is
        # used -- presumably every qubit of the circuit, as drawn above
        # (TODO confirm against the installed paddle_quantum version).
        self.ry(param=np.pi / 4)

    def __rand_layer_circuit(self, pauli) -> None:
        """Append one random rotation layer followed by a chain of CZ gates.

        Args:
            pauli: integer tags, one per target qubit; 0 selects RZ, 1 selects
                RY and any other value selects RX.
        """
        targets = self.target_qubits
        for qi, qubit in enumerate(targets):
            tag = pauli[qi]
            if tag == 0:
                self.rz(qubit)
            elif tag == 1:
                self.ry(qubit)
            else:
                self.rx(qubit)

        # Entangle neighbouring target qubits (in list order) with CZ gates.
        for first, second in zip(targets, targets[1:]):
            self.cz([first, second])

    def _rand_circuit(self, pauli_targ) -> None:
        """Append ``num_layers`` random layers.

        Args:
            pauli_targ: 2-D array of tags indexed as ``[qubit, layer]``;
                column ``j`` drives layer ``j``.
        """
        for j in range(self.num_layers):
            self.__rand_layer_circuit(pauli_targ[:, j])

    def iterating_layer(self, ini_layer=False) -> None:
        """Append McClean-style random layers (parameterized circuit of the BP paper).

        Args:
            ini_layer (bool): if True, first append the RY(pi/4) initial layer.
        """
        if ini_layer:
            self.__sqrt_H_init()

        # One random rotation tag per (target qubit, layer) pair.
        pauli_targ = _random_pauli_tag((len(self.target_qubits), self.num_layers))
        self._rand_circuit(pauli_targ)

    def complex_ent_layer(self) -> None:
        """Append paddle_quantum's complex entangled layer on the target qubits.

        With two or more target qubits this delegates to
        ``complex_entangled_layer`` with depth ``num_layers``. With a single
        target qubit it instead appends ``num_layers`` U3 gates on that qubit
        (presumably because the entangled layer needs at least two qubits).
        """
        targets = self.target_qubits
        if len(targets) > 1:
            self.complex_entangled_layer(qubits_idx=targets, num_qubits=self.num_qubits, depth=self.num_layers)
        else:
            for _layer in range(self.num_layers):
                self.u3(qubits_idx=targets[0])
    