from typing import List, Tuple
from project_qsl.circuit import QCircuit

# Public API of this module: only the SWAP-test circuit class is exported.
__all__ = ["SWAPTESTCircuit"]


class SWAPTESTCircuit(QCircuit):
    """SWAP-test circuit comparing the states prepared by two sub-circuits.

    Qubit layout:
      * qubit 0 is the control (ancilla) qubit;
      * ``cir1`` is joined starting at qubit 1;
      * ``cir2`` is joined starting at qubit ``1 + cir1.num_qubits``.

    The ancilla is put between two Hadamard gates that surround a sequence of
    controlled-SWAPs. Measuring it gives the SWAP-test statistic returned by
    ``prob0``.
    """

    def __init__(self, cir1: QCircuit, cir2: QCircuit, swap_qinds: List[Tuple[int, int]]) -> None:
        """Build the SWAP test between ``cir1`` and ``cir2``.

        Args:
            cir1: Circuit preparing the first state.
            cir2: Circuit preparing the second state. NOTE: ``cir2.backend``
                is called with ``cir1.backendName``. This acts on the
                caller's object in place.
            swap_qinds: Pairs ``(i, j)``. ``i`` is a qubit index local to
                ``cir1`` and ``j`` is a qubit index local to ``cir2``. Each
                pair gets one controlled-SWAP, controlled by qubit 0. Only the
                first two elements of each pair are read.
        """
        # One ancilla plus the qubits of both sub-circuits.
        num_qubits = cir1.num_qubits + cir2.num_qubits + 1
        # NOTE(review): presumably aligns cir2's backend with cir1's before the
        # two are joined into one circuit; verify against QCircuit.backend.
        cir2.backend(cir1.backendName)
        super().__init__(num_qubits, cir1.backendName)

        self.h(0)
        # Assumes join(cir, k) places cir's qubit 0 at global qubit k (TODO confirm).
        self.join(cir1, 1)
        self.join(cir2, 1 + cir1.num_qubits)
        for qs in swap_qinds:
            # Translate local indices to global ones: cir1 is shifted by the
            # ancilla (1), cir2 additionally by cir1's width.
            self.cswap(0, [qs[0]+1, qs[1]+1+cir1.num_qubits])
        self.h(0)

    def prob0(self, num_shots: int = 1024, qiskit_backend=None, error_mitigator=None) -> float:
        """Estimate the probability of measuring the ancilla (qubit 0) as ``0``.

        In the standard SWAP test (pure states, every qubit paired), this
        probability equals ``(1 + |<psi|phi>|^2) / 2``.

        Args:
            num_shots: Number of shots to sample.
            qiskit_backend: If given (truthy), the circuit is executed with
                ``run_qiskit`` on this backend. Otherwise the built-in ``run``
                is used.
            error_mitigator: Forwarded to ``run_qiskit`` only. It is ignored
                when ``qiskit_backend`` is not given.

        Returns:
            Fraction of shots whose qubit-0 outcome was ``"0"``. Returns
            ``0.0`` if that outcome was never observed.
        """
        if qiskit_backend:
            results = self.run_qiskit(qiskit_backend, num_shots, measure_qindices=[0], error_mitigator=error_mitigator)
        else:
            results = self.run(num_shots, measure_qindices=[0])
        # Count dictionaries typically omit outcomes that never occurred, so
        # treat a missing "0" key as zero counts instead of raising KeyError.
        # Assumes `results` is dict-like (supports .get) — TODO confirm for run().
        return results.get("0", 0) / num_shots