## Date: 2022/11/08
## Code: python

# import useful libraries
# import paddle tools
import paddle_quantum
import paddle
from paddle_quantum import set_backend
from paddle_quantum.state import State
from paddle_quantum.qinfo import state_fidelity
from paddle_quantum.ansatz import Circuit

# import simulation related packages
from QuantumLadder import QuantumLadder
from LadderExtenstion import ModelPool

# basic python tools and math related libraries
import numpy as np
from tqdm import tqdm
import json

# drawing tools
from matplotlib import pyplot as plt 
import warnings
# NOTE(review): with no category/module filter this silences every warning for
# the whole process (including deprecation notices from paddle/numpy); consider
# narrowing it to the specific warnings that are actually noisy.
warnings.filterwarnings("ignore")

# Global paddle_quantum settings: use complex128 (double-precision complex)
# numbers and the state-vector simulation backend.
paddle_quantum.set_dtype("complex128")    
set_backend("state_vector")


def json_store(file_name: str, result_dict: dict = None):
    """
    Append ``result_dict`` to ``file_name`` as one JSON document per line.

    The file is opened in append mode, so repeated calls accumulate records
    instead of overwriting earlier ones. Each record is terminated by a
    newline (JSON Lines format). Without the separator, consecutive calls
    would produce ``{...}{...}``, which a standard JSON reader cannot parse.

    Args:
        file_name: path of the file to append to (created if missing).
        result_dict: JSON-serialisable object to store; ``None`` is written
            as ``null``.
    """
    with open(file_name, 'a', encoding='utf-8') as res_store:
        json.dump(result_dict, res_store)
        res_store.write('\n')


############################################################
################# Experimental platform ####################
############################################################
def loss_func(rho, psi):
    """
    Squared Hilbert-Schmidt (Frobenius) distance between ``rho`` and ``psi``.

    Computes ``Tr[(psi - rho)^dagger (psi - rho)] = ||psi - rho||_F^2``. Note
    that this is not the squared trace distance, which would use the trace
    norm (sum of singular values) of the difference instead.

    Args:
        rho: matrix tensor (in this file, the circuit's output density matrix).
        psi: matrix tensor of the same shape as ``rho`` (the target).

    Returns:
        The distance as a ``float64`` paddle tensor.
    """

    diff = psi - rho
    # Tr(A^dagger A) is real and non-negative; the cast presumably keeps only
    # the real part and drops the (numerically ~0) imaginary part.
    loss = paddle.trace(diff.conj().t() @ diff).cast('float64')
    return loss


def complex_ansatz(N, D):
    """
    Build an ``N``-qubit ``Circuit`` holding a complex entangled layer of depth ``D``.

    The circuit is returned untrained; callers are expected to initialise or
    optimise its parameters themselves.
    """
    circuit = Circuit(N)
    circuit.complex_entangled_layer(depth=D)
    return circuit


def Origin_train_loss(N, D, psi, ITR, LR=0.1):
    """
    Train the global QNN ansatz so that its output state approximates ``psi``.

    A ``complex_ansatz(N, D)`` circuit is randomly initialised and optimised
    with Adam for ``ITR`` iterations. Each iteration minimises ``loss_func``
    between the circuit's output density matrix ``|phi><phi|`` and ``psi``.

    Args:
        N: number of qubits.
        D: depth of the complex entangled layer.
        psi: target tensor that is subtracted from the output density matrix
            inside ``loss_func``; presumably a (2**N, 2**N) density matrix
            (TODO confirm).
        ITR: number of optimisation iterations; must be at least 1.
        LR: Adam learning rate.

    Returns:
        ``(loss, cir, output_state)``:
        - ``loss``: the final iteration's loss as a NumPy scalar.
        - ``cir``: the trained circuit.
        - ``output_state``: the final iteration's output state.
        Both ``loss`` and ``output_state`` are evaluated before that
        iteration's parameter update is applied.

    Raises:
        ValueError: if ``ITR < 1``, since no loss would ever be computed.
    """
    if ITR < 1:
        raise ValueError(f"ITR must be at least 1, got {ITR}")

    cir = complex_ansatz(N, D)
    cir.randomize_param()
    opt = paddle.optimizer.Adam(learning_rate=LR, parameters=cir.parameters())

    # optimization iteration
    for _ in tqdm(range(ITR)):

        output_state = cir()
        output_density = output_state.ket @ output_state.bra
        loss = loss_func(output_density, psi)
        # compute gradients, apply the Adam step, then reset the gradients
        # so they do not accumulate into the next iteration
        loss.backward()
        opt.minimize(loss)
        opt.clear_grad()

    # reshape(-1)[0] works whether the loss tensor has shape [1] or is 0-D
    # (newer Paddle releases return 0-D tensors, where ``numpy()[0]`` fails)
    return loss.numpy().reshape(-1)[0], cir, output_state


# Examining the adaptive method
if __name__ == "__main__":

    # Hyper parameters
    num_qubits = 10
    circuit_depth = 3
    learning_rate = 0.1
    learning_ITR = 200

    # Key of the target state in ModelPool; change it (e.g. to a StateData
    # entry) to reproduce the experimental results for another target.
    modelname = "test"
    psi = ModelPool[modelname]

    # pass learning_rate explicitly so that changing it above takes effect
    loss, cir, rho_state = Origin_train_loss(
        num_qubits, circuit_depth, psi, ITR=learning_ITR, LR=learning_rate
    )
    rho = rho_state.ket @ rho_state.bra
    print("fidelity is ", state_fidelity(rho.numpy(), psi.numpy()))

    # Storing simulation data
    np.save(f"GlobalQNNSimulation_NQ_{num_qubits}_model_{modelname}.npy", rho_state.numpy())
