import QCompute as qc
from project_qsl import QCircuit, GHZSimCircuit, SWAPTESTCircuit
import numpy

__all__ = ["distance2norm", "fidelity", "distance2norm_scipyopt"]


class GHZCircuit(QCircuit):
    """Circuit that prepares an ``num_qubits``-qubit GHZ state.

    Qubit 0 is put into superposition with a Hadamard and then used as the
    control of a CNOT onto every other qubit. Assuming the register starts in
    |0...0>, this yields (|0...0> + |1...1>) / sqrt(2).
    """

    def __init__(self, num_qubits: int, backend: qc.BackendName = qc.BackendName.LocalBaiduSim2) -> None:
        super().__init__(num_qubits, backend)

        control = 0
        self.h(control)
        # Entangle every remaining qubit with the control qubit.
        targets = range(1, num_qubits)
        for target in targets:
            self.cx(control, target)


def fidelity(cir: GHZSimCircuit, num_shots: int = 1024, qiskit_backend=None, error_mitigator=None) -> float:
    """Estimate the overlap between the state prepared by ``cir`` and a GHZ state.

    Every qubit of ``cir`` is swap-tested against the matching qubit of a
    freshly built :class:`GHZCircuit` of the same size, and ``2 * prob0 - 1``
    is returned. Assuming ``prob0`` is the probability of the swap-test
    ancilla reading 0, this equals Tr(rho sigma), i.e. |<psi|GHZ>|^2 for pure
    states.

    Args:
        cir: Circuit whose output state is compared against the GHZ state.
        num_shots: Number of shots used to estimate ``prob0``.
        qiskit_backend: Forwarded unchanged to ``SWAPTESTCircuit.prob0``.
        error_mitigator: Forwarded unchanged to ``SWAPTESTCircuit.prob0``.

    Returns:
        The estimate ``2 * prob0 - 1``. Shot noise can push it slightly
        outside [0, 1]; the value is not clamped.
    """
    # Build the reference on the same backend as ``cir``, as distance2norm
    # does, so both swap-test inputs target one backend.
    target_state_cir = GHZCircuit(cir.num_qubits, cir.backendName)
    swap_qinds = [(q, q) for q in range(cir.num_qubits)]
    swaptest_cir = SWAPTESTCircuit(cir, target_state_cir, swap_qinds)
    p0 = swaptest_cir.prob0(num_shots, qiskit_backend, error_mitigator)
    return 2 * p0 - 1


def distance2norm(cir: GHZSimCircuit, num_shots: int = 1024, qiskit_backend=None, error_mitigator=None) -> float:
    """Estimate the squared Hilbert-Schmidt distance between ``cir``'s state and a GHZ state.

    Based on the identity
    ``||rho - sigma||_2^2 = Tr(rho^2) + Tr(sigma^2) - 2 Tr(rho sigma)``.
    Each trace is estimated with a swap test as ``2 * prob0 - 1``. Only the
    first ``cir._qubit_pointer`` qubits of each register are swap-tested.

    Args:
        cir: Circuit whose output state ``rho`` is compared to the GHZ state ``sigma``.
        num_shots: Number of shots for each swap-test estimate.
        qiskit_backend: Forwarded unchanged to ``SWAPTESTCircuit.prob0``.
        error_mitigator: Forwarded unchanged to ``SWAPTESTCircuit.prob0``.

    Returns:
        The absolute value of the estimate. ``abs`` hides negative values
        that shot noise can produce.
    """
    pointer_index = cir._qubit_pointer
    swap_qinds = [(q, q) for q in range(pointer_index)]

    # Tr(rho sigma): overlap between the circuit state and the GHZ reference.
    ghzcir = GHZCircuit(cir.num_qubits, cir.backendName)
    swaptest_cir = SWAPTESTCircuit(cir.copy(), ghzcir, swap_qinds)
    overlap = 2 * swaptest_cir.prob0(num_shots, qiskit_backend, error_mitigator) - 1

    # NOTE(review): the threshold 4 is hard-coded; presumably it is the
    # register size used in experiments. Confirm whether it should be
    # cir.num_qubits.
    if pointer_index < 4:
        # Tr(rho^2): purity of the swap-tested subsystem. It is only needed
        # here, so the extra swap test is skipped in the other branch.
        swaptest_cir = SWAPTESTCircuit(cir.copy(), cir.copy(), swap_qinds)
        purity = 2 * swaptest_cir.prob0(num_shots, qiskit_backend, error_mitigator) - 1
        # 0.5 stands in for Tr(sigma^2). It matches the purity of a GHZ state
        # reduced to a strict subset of its qubits. TODO confirm this intent.
        return abs(purity + 0.5 - 2 * overlap)

    # Tr(rho^2) and Tr(sigma^2) are both taken as 1 here (pure states).
    return abs(2 - 2 * overlap)


def distance2norm_scipyopt(x: numpy.ndarray, cir: GHZSimCircuit, num_shots: int = 1024, qiskit_backend=None, error_mitigator=None) -> float:
    """Load parameters ``x`` into ``cir`` and return ``distance2norm`` of the result.

    The parameter vector comes first so the function can be used as an
    objective of the form ``fun(x, *args)``, presumably with
    ``scipy.optimize``. The remaining arguments are passed through to
    ``distance2norm`` unchanged.

    Side effects:
        ``cir`` is mutated in place via ``cir.set_params(x)``.
    """
    cir.set_params(x)
    # Delegate so the distance estimate is defined in exactly one place.
    return distance2norm(cir, num_shots, qiskit_backend, error_mitigator)
