## Date: 2022/11/08
## Code: python

import paddle
from paddle_quantum.linalg import haar_unitary, NKron
from paddle_quantum.state.common import ghz_state
from paddle_quantum.state import State
from paddle_quantum.trotter import get_1d_heisenberg_hamiltonian
import numpy as np
import random
from typing import Iterable, List


# Some useful functions
def SchmidtRankState(num_qubits: int, rank: int) -> paddle.Tensor:
    """Generate a random pure state with a given Schmidt rank in the computational basis

    The qubits are split into a leading part of ``num_qubits - sub_dim`` qubits
    and a trailing part of ``sub_dim`` qubits, with ``sub_dim`` drawn at random.
    ``rank`` distinct basis states are picked on each side, paired one-to-one,
    and weighted with the square roots of a Dirichlet sample.

    Args:
        num_qubits (int): the number of qubits, at least 2
        rank (int): Schmidt rank, between 1 and 2**(num_qubits // 2)

    Returns:
        paddle.Tensor: column state vector of shape (2**num_qubits, 1) whose
        amplitudes sit on ``rank`` computational basis states (this is the
        state vector itself, not its density operator)

    Raises:
        AssertionError: if ``num_qubits`` or ``rank`` is outside the ranges above
    """

    assert num_qubits >= 2, "The number of qubits should be at least 2"
    # without this check rank < 1 only fails later, inside the log2/int conversion
    assert rank >= 1, "The Schmidt rank should be at least 1"
    assert rank <= 2**int(num_qubits/2), "The Schmidt rank is too large"

    # the smaller side needs at least one qubit and at least `rank` basis states
    min_sub_dim = max(1, int(np.ceil(np.log2(rank))))
    sub_dim = int(np.random.randint(low=min_sub_dim, high=int(num_qubits/2) + 1))
    sup_dim = num_qubits - sub_dim

    # randomize basis elements: `rank` distinct indices on each side
    sub_state_indexA = np.random.choice(range(2**sup_dim), rank, replace=False)
    sub_state_indexB = np.random.choice(range(2**sub_dim), rank, replace=False)

    # index of |a>|b> in the full register: bits of a followed by the sub_dim bits of b
    index_list = [
        (int(index_a) << sub_dim) | int(index_b)
        for index_a, index_b in zip(sub_state_indexA, sub_state_indexB)
    ]

    # generating the random Schmidt coefficients (squares sum to one)
    schmidt_coefs = np.sqrt(np.random.dirichlet(range(1, rank+1)))
    state_vector = paddle.zeros((2**num_qubits, 1))
    for i in range(rank):
        state_vector[index_list[i]] = schmidt_coefs[i]

    return state_vector


def purification(rho: paddle.Tensor, nB: int) -> paddle.Tensor:
    """Generate a random purification state of given density and given number of ancillary

    Each of the ``rank_rho`` largest eigenpairs of ``rho`` is tensored with
    ``haar_random_U`` applied to a distinct, randomly chosen computational basis
    vector of the ancilla, weighted by the square root of the eigenvalue.

    Args:
        rho (paddle.Tensor): input density
        nB (int): number of ancillary qubits

    Returns:
        paddle.Tensor: output pure state, flattened to one dimension of length
        ``rho.shape[0] * 2**nB``

    Raises:
        AssertionError: if the numerical rank of ``rho`` exceeds ``2**nB``
    """

    rank_rho = np.linalg.matrix_rank(rho.numpy())
    assert rank_rho <= 2**nB, "Not enough ancillary dimension"
    # NOTE(review): assumes haar_unitary(nB) returns a 2**nB x 2**nB unitary - confirm
    haar_random_U = haar_unitary(nB)
    e, v = paddle.linalg.eigh(rho)

    dim_A = rho.shape[0]
    purified_state = paddle.zeros([dim_A * 2**(nB), 1])
    random_bit_string_idx = random.sample(range(2**nB), rank_rho)
    for i in range(rank_rho):
        temp_array = paddle.zeros([2**nB, 1])
        temp_array[random_bit_string_idx[i], 0] = 1
        # eigh returns eigenvalues in ascending order, so the i-th largest
        # eigenpair sits at position -1 - i (counting back from the end)
        top = -1 - i
        purified_state = purified_state + ((e[top])**(1/2)) * NKron(v[:, top].reshape([dim_A, 1]), (haar_random_U @ temp_array))

    return purified_state.reshape([purified_state.shape[0]])


def fidelity_loss(rho: paddle.Tensor, sigma: paddle.Tensor, tol: float = 1e-14) -> paddle.Tensor:
    """Evaluate trace(sqrt(sqrt(rho) @ sigma @ sqrt(rho))) through eigendecompositions

    Args:
        rho (paddle.Tensor): density rho
        sigma (paddle.Tensor): density sigma
        tol (float, optional): offset added to every eigenvalue before the
            square root is taken (guards against NaN). Defaults to 1e-14.

    Returns:
        paddle.Tensor: real part of the trace, a single value
    """

    shift = tol * paddle.ones([rho.shape[0]])

    def _shifted_sqrt(mat):
        # square root of `mat` from its eigendecomposition, eigenvalues offset by tol
        values, vectors = paddle.linalg.eigh(mat)
        return vectors @ paddle.diag(paddle.sqrt(values + shift)) @ vectors.conj().T

    root_rho = _shifted_sqrt(rho)
    root_inner = _shifted_sqrt(root_rho @ sigma @ root_rho)

    return paddle.trace(root_inner).real()


def random_pauli_word(sys_qubit_range: Iterable, k: int, set_of_pauli: list):
    """Generating random pauli word

    Args:
        sys_qubit_range (Iterable): qubit indices to draw from; any iterable
            (range, list, set, generator, ...) is accepted
        k (int): number of distinct qubit indices to draw
        set_of_pauli (list): single-qubit labels to choose from, with
            replacement, e.g. ["I", "X", "Y", "Z"]

    Returns:
        str: Pauli word in string, comma-separated terms such as "X0,Z3" ordered
        by qubit index; terms whose label starts with "I" are dropped, so the
        result may be the empty string

    Raises:
        ValueError: if ``k`` exceeds the number of available indices
    """

    # random.sample only accepts sequences; materialise the population so that
    # every iterable allowed by the annotation works
    population = list(sys_qubit_range)

    # random sample K indices from all number of qubits
    q_index_list = sorted(random.sample(population, k=k))

    # random choose k single qubit pauli
    pauli_string_list = random.choices(set_of_pauli, k=k)
    pauli_word_temp = [pauli + str(q_index) for pauli, q_index in zip(pauli_string_list, q_index_list)]

    # remove identities
    return ",".join(word for word in pauli_word_temp if word[0] != "I")


def random_pauli_string(num_qubits: int, k: int, num_items: int,
                        coef_range: tuple = (-2,2), set_of_pauli: list = ["I","X","Y","Z"]) -> list:
    """Function generates random k-local pauli string

    Args:
        num_qubits (int): The total number of qubits
        k (int): Index for k-local
        num_items (int): The number of items in the hamiltonian
        coef_range: (tuple): range for generating random coefficients, default to (-2,2)
        set_of_pauli (list): single-qubit labels handed to random_pauli_word;
            the default list is only read here, never modified

    Returns:
        list: generated Pauli string, a list of ``[coefficient, pauli_word]``
        pairs with pairwise distinct, non-empty Pauli words

    Raises:
        RuntimeError: if 10000 consecutive draws produce no unused Pauli word
    """

    res_pauli_string = []
    # the empty word (all identities) is registered up front so it is never emitted
    seen_words = {''}
    system_index_range = range(num_qubits)

    # repeat num_items times for k-local hamiltonian
    for _ in range(num_items):

        # redraw until an unused word shows up, giving up after 10000 attempts
        for _attempt in range(10000):
            pauli_word = random_pauli_word(system_index_range, k, set_of_pauli)
            if pauli_word not in seen_words:
                seen_words.add(pauli_word)
                break
        else:
            raise RuntimeError("Too many repeated terms!")

        res_pauli_string.append([random.uniform(coef_range[0], coef_range[1]), pauli_word])

    return res_pauli_string


def heisenberg_ground_state(num_qubits: int, model: List[float] = [1,1,1]):
    """Diagonalise the 1D Heisenberg chain and return its lowest eigenpair

    Args:
        num_qubits (int): number of qubits
        model (List[float]): couplings, read as model[0] = Jx, model[1] = Jy,
            model[2] = Jz

    Returns:
        tuple: the lowest eigenvalue of the Hamiltonian matrix, followed by a
        State built from the matching eigenvector
    """

    hamiltonian = get_1d_heisenberg_hamiltonian(
        num_qubits, j_x=model[0], j_y=model[1], j_z=model[2]
    )
    # eigh sorts eigenvalues in ascending order, so entry 0 is the ground level
    energies, eigenvectors = np.linalg.eigh(hamiltonian.construct_h_matrix())
    ground_vector = np.array(eigenvectors[:, 0])
    return energies[0], State(paddle.to_tensor(ground_vector))


# Loading the numerical experimental test quantum states.
# Each .npy file under ./StateData is read at import time and wrapped in a State.
heis_GS_state = State(np.load("./StateData/Heisenberg_GS_array_(1110)_(pbc).npy"))
heis_XXZ_GS_state = State(np.load("./StateData/Heisenberg_GS_array_(1120)_(pbc).npy"))
lih_GS_state = State(np.load("./StateData/lih_GS_array_12q.npy"))

Gaussian_dis_state = State(np.load("./StateData/Gaussian_dis_state_12q.npy"))
MNIST_state = State(np.load("./StateData/MNIST_array_12q.npy"))
RandomRank12_state = State(np.load("./StateData/randomrank12_array_12q.npy"))


# import ModelPool as a source package for state learning experiments
# every entry is the outer product ket @ bra of the named state
ModelPool = {
    label: target.ket @ target.bra
    for label, target in (
        ("Heisenberg", heis_GS_state),
        ("XXZ", heis_XXZ_GS_state),
        ("LiH", lih_GS_state),
        ("RandomRank12", RandomRank12_state),
        ("GHZ-like", ghz_state(12)),
        ("MNIST_state", MNIST_state),
        ("Gaussian_distribution", Gaussian_dis_state),
        ("test", ghz_state(4)),
    )
}