## Date: 2022/11/08
## Code: python

# import useful libraries
# import paddle tools
import paddle
import paddle_quantum
import paddle.optimizer
from paddle_quantum.state import zero_state
from paddle_quantum.ansatz import Circuit
from paddle_quantum import set_backend
from paddle_quantum.qinfo import partial_trace
from paddle_quantum.state import State

# import AdaptiveLayer tool
from AdaptiveLayer import AdaptiveLayer, trace_loss_func

# basic python tools and math related libraries
import numpy as np
from typing import List, Iterable, Union, Optional
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")

# Global paddle_quantum configuration, applied as a side effect when this module is imported:
# double-precision complex arithmetic (complex128) and the state-vector simulator backend.
paddle_quantum.set_dtype("complex128")    
set_backend("state_vector")


# function pool
# generate the list of target qubits given the number of qubits in the circuit.
def adaptive_targ_list(num_qubits: int, cutoff: Optional[int] = None) -> List[List[int]]:
    """Return the target qubits of every step of the general adaptive ladder.

    Step ``i`` acts on the contiguous qubits ``i, i+1, ..., min(2i+1, num_qubits-1)``,
    i.e. on ``min(i + 2, num_qubits - i)`` qubits, further capped at ``cutoff``.

    Args:
        num_qubits (int): the number of qubits of the target state.
        cutoff (int, optional): the maximum step unitary width. Defaults to
            ``2 * ceil(log2(num_qubits))`` (but at least 1), which is also the
            largest value accepted. Non-integer values are truncated.

    Returns:
        List[List[int]]: one list of target qubits per adaptive step; empty
        when ``num_qubits < 1``.

    Raises:
        ValueError: if ``cutoff`` is smaller than 1 or exceeds the maximum.
    """
    if num_qubits < 1:
        return []

    # Widest step the schedule allows. Kept >= 1 so that a single-qubit problem
    # still gets a non-empty target list (log2(1) == 0 would otherwise give width 0).
    max_cutoff = max(1, 2 * int(np.ceil(np.log2(num_qubits))))
    if cutoff is None:
        cutoff = max_cutoff
    elif cutoff < 1:
        raise ValueError("Cutoff must be at least 1")
    elif cutoff > max_cutoff:
        raise ValueError("Cutoff exceed the maximum allowed value")
    max_width = int(cutoff)

    # step i covers qubits i .. min(2i+1, num_qubits-1), i.e. min(i+2, num_qubits-i) qubits
    return [list(range(i, i + min(i + 2, num_qubits - i, max_width))) for i in range(num_qubits)]


def linear_targ_list(num_qubits: int, cutoff: Optional[int] = None) -> List[List[int]]:
    """Return the target qubits of every step of the linear ladder.

    Step ``i`` acts on the neighbouring pair ``[i, i + 1]``; the final step acts
    on the last qubit alone, so the unitary width is at most 2.

    Args:
        num_qubits (int): the number of qubits of the target state.
        cutoff (int, optional): accepted so this function can be called
            interchangeably with ``adaptive_targ_list``; the linear width is
            fixed and ``cutoff`` is ignored.

    Returns:
        List[List[int]]: one list of target qubits per step; empty when
        ``num_qubits < 1``.
    """
    return [list(range(i, min(i + 2, num_qubits))) for i in range(num_qubits)]


# defining the local loss function
def local_loss_opt(circuit: Circuit, 
                   state_in: Optional[State], 
                   target_state: paddle.Tensor, 
                   adapt_iter: int, 
                   ITR: int, 
                   LR: float, 
                   opt: Optional[type] = None) -> Iterable[Union[float, paddle.Tensor, List[float]]]:
    """Optimize one ladder step against a local (reduced-state) trace cost.

    Only the leading ``adapt_iter + 1`` qubits enter the loss. Both the circuit
    output and ``target_state`` are reduced with ``partial_trace`` before
    ``trace_loss_func`` compares them, so the optimizer drives that subsystem
    towards the corresponding target subsystem.

    Args:
        circuit (Circuit): quantum circuit or layer to train.
        state_in (State): input state of the circuit; passed to ``circuit`` as is.
        target_state (paddle.Tensor): target density matrix; its first dimension
            is taken to be ``2**n`` for an ``n``-qubit target.
        adapt_iter (int): ladder step index; the first ``adapt_iter + 1`` qubits
            are kept by the partial trace.
        ITR (int): number of optimization iterations; must be at least 1.
        LR (float): learning rate.
        opt (type, optional): optimizer class (e.g. from ``paddle.optimizer``).
            It is instantiated with ``learning_rate=LR`` and the circuit
            parameters. Defaults to ``paddle.optimizer.Adam``.

    Returns:
        float: loss of the last iteration.
        State: output state of the last iteration, rebuilt from a NumPy copy so
            it carries no autograd history.
        List[float]: loss recorded at every iteration.

    Raises:
        ValueError: if ``ITR`` is smaller than 1.

    Note:
        The returned loss and state are evaluated before the final optimizer
        update, so they lag the circuit's final parameters by one step.
    """
    if ITR < 1:
        # with no iterations there would be no loss or output state to return
        raise ValueError("ITR must be a positive integer")

    optimizer_cls = paddle.optimizer.Adam if opt is None else opt
    optimizer = optimizer_cls(learning_rate=LR, parameters=circuit.parameters())
    # drop any gradients left on the circuit parameters from earlier use
    optimizer.clear_grad()

    # the kept subsystem is the first adapt_iter + 1 qubits; A_or_B=2 traces out the rest
    num_kept = adapt_iter + 1
    num_target_qubits = int(np.log2(target_state.shape[0]))
    reduce_target = partial_trace(target_state, 2**num_kept, 2**(num_target_qubits - num_kept), 2)

    local_loss_store = []
    for _ in range(ITR):
        # forward pass: output state and its density matrix |psi><psi|
        state_out = circuit(state_in)
        state_out_density = state_out.ket @ state_out.bra

        # trace distance between the reduced output and the reduced target
        reduce_out = partial_trace(state_out_density, 2**num_kept, 2**(circuit.num_qubits - num_kept), 2)
        loss = trace_loss_func(reduce_out, reduce_target)
        local_loss_store.append(loss.numpy()[0])

        # backward() fills the parameter gradients that minimize() applies
        loss.backward()
        optimizer.minimize(loss)
        optimizer.clear_grad()

    # cut the output state loose from the autograd graph before handing it on
    state_out.data = paddle.to_tensor(state_out.data.numpy())
    return loss.numpy()[0], state_out, local_loss_store


# defining the global loss function
def global_loss_opt(circuit: Circuit, state_in: Optional[State], target_state: paddle.Tensor, 
                    adapt_iter: int, ITR: int, LR: float, 
                    opt: Optional[type] = None) -> Iterable[Union[float, paddle.Tensor, List[float]]]:
    """Optimize one ladder step against the global trace cost.

    The whole circuit output is treated as one system. Its density matrix is
    compared directly with ``target_state`` by ``trace_loss_func``.

    Args:
        circuit (Circuit): quantum circuit or layer to train.
        state_in (State): input state of the circuit; passed to ``circuit`` as is.
        target_state (paddle.Tensor): target density matrix. Presumably it must
            have the same size as the circuit's output density matrix; confirm
            against ``trace_loss_func``.
        adapt_iter (int): ladder step index. Unused here; accepted so the
            signature matches ``local_loss_opt``.
        ITR (int): number of optimization iterations; must be at least 1.
        LR (float): learning rate.
        opt (type, optional): optimizer class (e.g. from ``paddle.optimizer``).
            It is instantiated with ``learning_rate=LR`` and the circuit
            parameters. Defaults to ``paddle.optimizer.Adam``.

    Returns:
        float: loss of the last iteration.
        State: output state of the last iteration, rebuilt from a NumPy copy so
            it carries no autograd history.
        List[float]: loss recorded at every iteration.

    Raises:
        ValueError: if ``ITR`` is smaller than 1.

    Note:
        The returned loss and state are evaluated before the final optimizer
        update, so they lag the circuit's final parameters by one step.
    """
    if ITR < 1:
        # with no iterations there would be no loss or output state to return
        raise ValueError("ITR must be a positive integer")

    optimizer_cls = paddle.optimizer.Adam if opt is None else opt
    optimizer = optimizer_cls(learning_rate=LR, parameters=circuit.parameters())
    # drop any gradients left on the circuit parameters from earlier use
    optimizer.clear_grad()

    global_loss_store = []
    for _ in range(ITR):
        # forward pass: output state and its density matrix |psi><psi|
        state_out = circuit(state_in)
        state_out_density = state_out.ket @ state_out.bra

        loss = trace_loss_func(state_out_density, target_state)
        global_loss_store.append(loss.numpy()[0])

        # backward() fills the parameter gradients that minimize() applies
        loss.backward()
        optimizer.minimize(loss)
        optimizer.clear_grad()

    # cut the output state loose from the autograd graph before handing it on
    state_out.data = paddle.to_tensor(state_out.data.numpy())
    return loss.numpy()[0], state_out, global_loss_store


# defining the QuantumLadder (QSSM) class
class QuantumLadder:
    """
    Prepare a target quantum state through a ladder of successively optimized layers.

    Each ladder step builds a trainable ``AdaptiveLayer`` on the qubits given by
    the selected target-qubit schedule. It trains that layer with a classical
    optimizer under a local or global trace cost, then passes the resulting
    state on to the next step.
    """

    # target-qubit schedules selectable through ``adapt_med``
    ADAPT_MED = {"adapt": adaptive_targ_list, "linear": linear_targ_list}

    # cost types accepted by the ``method`` argument of the learning methods
    _METHODS = ("local", "global")

    def __init__(self, num_qubits: int, ITR: int = 100, LR: float = 0.1, N_L: int = 10, 
                 adapt_med: str = "adapt", max_width: Optional[int] = None) -> None:
        """
        Args:
            num_qubits (int): the number of qubits in the problem.
            ITR (int, optional): optimization iterations per ladder step. Defaults to 100.
            LR (float, optional): learning rate. Defaults to 0.1.
            N_L (int, optional): depth of each step's layer. Defaults to 10.
            adapt_med (str, optional): target-qubit schedule, a key of ``ADAPT_MED``.
                Defaults to "adapt".
            max_width (int, optional): maximum step unitary width. Only the "adapt"
                schedule uses it; None selects that schedule's default.

        Raises:
            ValueError: if ``adapt_med`` is not a key of ``ADAPT_MED``.
        """
        # hyper-parameters
        self.__num_qubits = num_qubits
        self.__ITR = ITR
        self.__LR = LR
        self.__N_L = N_L
        self.max_width = max_width

        if adapt_med not in self.ADAPT_MED:
            raise ValueError(f"Invalid adaptive method, must be in {list(self.ADAPT_MED.keys())}")
        self.adapt_med = adapt_med

        # One target-qubit list per ladder step. Only the adaptive schedule
        # takes a width cutoff; the linear schedule has a fixed width.
        targ_fn = self.ADAPT_MED[adapt_med]
        if adapt_med == "adapt":
            self.__adapt_iter_target = targ_fn(self.num_qubits, cutoff=self.max_width)
        else:
            self.__adapt_iter_target = targ_fn(self.num_qubits)

    # read-only views of the configuration
    @property
    def num_qubits(self):
        """Number of qubits in the problem."""
        return self.__num_qubits

    @property
    def get_ITR(self):
        """Optimization iterations per ladder step."""
        return self.__ITR

    @property
    def get_LR(self):
        """Learning rate."""
        return self.__LR

    @property
    def num_layers(self):
        """Depth of each step's layer."""
        return self.__N_L

    @property
    def adapt_iter_target(self):
        """Target-qubit list of every ladder step."""
        return self.__adapt_iter_target

    @property
    def adapt_ITR(self):
        """Number of ladder steps."""
        return len(self.adapt_iter_target)

    #-----------------------------------
    # class methods
    def _optimize_step(self, circuit: Circuit, state_in: Optional[State], target_rho: paddle.Tensor,
                       adapt_iter: int, method: str) -> Iterable[Union[float, State, List[float]]]:
        """Train one ladder step with the chosen cost; returns (loss, output state, loss history)."""
        if method == "local":
            return local_loss_opt(circuit, state_in, target_rho, adapt_iter, self.get_ITR, self.get_LR)
        if method == "global":
            return global_loss_opt(circuit, state_in, target_rho, adapt_iter, self.get_ITR, self.get_LR)
        raise ValueError("No such method!")

    # state learning method
    def state_learning(self, 
                       target_rho: paddle.Tensor,
                       method: str = "local", 
                       draw_layer: bool = False) -> Iterable[Union[State, List[paddle.Tensor], List[List[float]]]]:
        """Ladder learning for state learning tasks, every step on the full register.

        Args:
            target_rho (paddle.Tensor): target density matrix.
            method (str): "local" or "global" cost. Defaults to "local".
            draw_layer (bool, optional): print each step's circuit. Defaults to False.

        Returns:
            State:  final output state
            List:   trained parameters of every step
            List:   per-iteration loss values of every step

        Raises:
            ValueError: if ``method`` is neither "local" nor "global".
        """
        # fail before any layer is built
        if method not in self._METHODS:
            raise ValueError("No such method!")

        opt_param_store = []
        step_loss = []

        # the ladder starts from |0...0> on all qubits
        update_state = zero_state(self.num_qubits)

        for adapt_iter, targets in enumerate(tqdm(self.adapt_iter_target)):
            adapt_cir = AdaptiveLayer(self.num_qubits, self.num_layers, targets)
            adapt_cir.complex_ent_layer()

            if draw_layer:
                print(adapt_cir)

            # the optimized output state feeds the next step
            _, update_state, loss_store = self._optimize_step(adapt_cir, update_state, target_rho,
                                                              adapt_iter, method)
            opt_param_store.append(adapt_cir.param)
            step_loss.append(loss_store)

        return update_state, opt_param_store, step_loss

    # state learning method
    def fast_state_learning(self, target_rho: paddle.Tensor, method: str = "local", 
                            draw_layer: bool = False) -> Iterable[Union[State, List[paddle.Tensor], List[List[float]]]]:
        """Ladder learning for state learning tasks using sub-circuits (fast mode).

        Step ``i`` simulates only qubits ``0 .. max(targets_i)``. Between steps the
        output state is padded with ``|0>`` on the qubits the next step adds.

        Args:
            target_rho (paddle.Tensor): target density matrix.
            method (str): "local" or "global" cost. Defaults to "local".
            draw_layer (bool, optional): print each step's circuit. Defaults to False.

        Returns:
            State:  final output state
            List:   trained parameters of every step
            List:   per-iteration loss values of every step

        Raises:
            ValueError: if ``method`` is neither "local" nor "global".
        """
        # fail before any layer is built
        if method not in self._METHODS:
            raise ValueError("No such method!")

        opt_param_store = []
        step_loss = []
        target_qubits_list = self.adapt_iter_target
        num_steps = len(target_qubits_list)

        # None lets the first sub-circuit start from its own default input
        update_state = None

        for adapt_iter, targets in enumerate(tqdm(target_qubits_list)):
            # the sub-circuit spans qubits 0 .. highest target qubit
            layer_num_qubits = targets[-1] + 1
            adapt_cir = AdaptiveLayer(layer_num_qubits, self.num_layers, targets)
            adapt_cir.complex_ent_layer()

            if draw_layer:
                print(adapt_cir)

            # NOTE(review): with method="global" the sub-circuit state (layer_num_qubits
            # qubits) is compared with the full target_rho; unless trace_loss_func handles
            # the size mismatch, this only works once the sub-circuit spans all qubits — confirm.
            _, state_out, loss_store = self._optimize_step(adapt_cir, update_state, target_rho,
                                                           adapt_iter, method)

            if adapt_iter + 1 < num_steps:
                # Widen the state for the next sub-circuit: paddle.eye(2**k, 1) is the
                # column vector e_0, i.e. |0...0> on the k added qubits.
                # NOTE(review): paddle.eye defaults to a real dtype; confirm paddle.kron
                # accepts it alongside the complex ket.
                next_layer_num_qubits = target_qubits_list[adapt_iter + 1][-1] + 1
                extra_qubits = next_layer_num_qubits - layer_num_qubits
                update_state = State(paddle.kron(state_out.ket, paddle.eye(2**extra_qubits, 1)))
            else:
                update_state = state_out

            opt_param_store.append(adapt_cir.param)
            step_loss.append(loss_store)

        return update_state, opt_param_store, step_loss
