from typing import List, Dict, Optional
import numpy as np
from copy import deepcopy
import QCompute as qc
from QCompute.QPlatform.Error import RuntimeError
from project_qsl.hamiltonian import QHamiltonian
from project_qsl.utils import set_cirline_param_, get_cirline_param


class QCircuit(qc.QEnv):
    """Parameterized quantum circuit built on top of a QCompute ``QEnv``.

    Gates are appended to the underlying environment through thin wrapper
    methods that address qubits by integer index into ``self.qlist``. Every
    rotation gate added via ``rx``/``ry``/``rz`` counts as one parameter and
    is tracked in ``num_params``.
    """

    def __init__(
        self,
        num_qubits: int,
        backend: qc.BackendName = qc.BackendName.LocalBaiduSim2
    ) -> None:
        """Create a circuit on ``num_qubits`` qubits.

        Args:
            num_qubits: Number of qubits allocated via ``self.Q.createList``.
            backend: Backend handed to ``QEnv.backend``.
        """
        super().__init__()
        self._num_qubits = num_qubits
        self.qlist = self.Q.createList(num_qubits)
        self.backend(backend)
        # Incremented once per rotation gate appended through rx/ry/rz.
        self._num_params = 0

    @property
    def num_qubits(self) -> int:
        """Number of qubits this circuit was created with."""
        return self._num_qubits

    @property
    def num_params(self) -> int:
        """Number of rotation gates (parameters) appended so far."""
        return self._num_params

    def get_params(self) -> "np.ndarray[float]":
        """Return the circuit's current parameters as read by ``get_cirline_param``."""
        return get_cirline_param(self.circuit)

    def set_params(self, new_params: "np.ndarray[float]") -> None:
        """Overwrite the circuit's parameters in place.

        Args:
            new_params: New parameter values; must have exactly ``num_params``
                entries.

        Raises:
            RuntimeError: If ``len(new_params) != num_params``. (This is the
                ``RuntimeError`` imported from ``QCompute.QPlatform.Error``.)
        """
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would silently skip this validation.
        if len(new_params) != self._num_params:
            raise RuntimeError(
                f"The number of parameters inside the circuit is {self._num_params:d}, while the number of newly passed parameters is {len(new_params):d}"
            )
        set_cirline_param_(self.circuit, new_params)

    def copy(self) -> "QCircuit":
        """Return an independent deep copy of this circuit."""
        return deepcopy(self)

    # --- Parameterized rotations. The gate is applied before the counter is
    # bumped so an invalid ``qindex`` (IndexError) leaves ``num_params``
    # consistent with the gates actually in the circuit.

    def rx(self, qindex: int, theta: float) -> None:
        """Append ``RX(theta)`` on qubit ``qindex`` (counts as one parameter)."""
        qc.RX(theta)(self.qlist[qindex])
        self._num_params += 1

    def ry(self, qindex: int, theta: float) -> None:
        """Append ``RY(theta)`` on qubit ``qindex`` (counts as one parameter)."""
        qc.RY(theta)(self.qlist[qindex])
        self._num_params += 1

    def rz(self, qindex: int, theta: float) -> None:
        """Append ``RZ(theta)`` on qubit ``qindex`` (counts as one parameter)."""
        qc.RZ(theta)(self.qlist[qindex])
        self._num_params += 1

    # --- Fixed single-qubit gates.

    def h(self, qindex: int) -> None:
        """Append ``H`` on qubit ``qindex``."""
        qc.H(self.qlist[qindex])

    def x(self, qindex: int) -> None:
        """Append ``X`` on qubit ``qindex``."""
        qc.X(self.qlist[qindex])

    def y(self, qindex: int) -> None:
        """Append ``Y`` on qubit ``qindex``."""
        qc.Y(self.qlist[qindex])

    def z(self, qindex: int) -> None:
        """Append ``Z`` on qubit ``qindex``."""
        qc.Z(self.qlist[qindex])

    def s(self, qindex: int) -> None:
        """Append ``S`` on qubit ``qindex``."""
        qc.S(self.qlist[qindex])

    def sdg(self, qindex: int) -> None:
        """Append ``SDG`` on qubit ``qindex``."""
        qc.SDG(self.qlist[qindex])

    # --- Multi-qubit gates.

    def cx(self, ctrl_qindex: int, target_qindex: int) -> None:
        """Append ``CX`` with control ``ctrl_qindex`` and target ``target_qindex``."""
        qc.CX(self.qlist[ctrl_qindex], self.qlist[target_qindex])

    def cz(self, ctrl_qindex: int, target_qindex: int) -> None:
        """Append ``CZ`` on qubits ``ctrl_qindex`` and ``target_qindex``."""
        qc.CZ(self.qlist[ctrl_qindex], self.qlist[target_qindex])

    def cswap(self, ctrl_qindex: int, swap_qinds: List[int]) -> None:
        """Append ``CSWAP`` controlled on ``ctrl_qindex``.

        Only the first two entries of ``swap_qinds`` are used as the swap pair.
        """
        qc.CSWAP(self.qlist[ctrl_qindex], self.qlist[swap_qinds[0]], self.qlist[swap_qinds[1]])

    # --- Execution.

    def _append_measure_z(self, measure_qindices: Optional[List[int]]) -> None:
        """Append ``MeasureZ`` on the requested qubits, or on all qubits.

        Only a ``list`` selects a subset (index ``i`` is recorded into classical
        slot ``i``); any other value, including ``None``, measures every qubit
        of the environment.
        """
        if isinstance(measure_qindices, list):
            qc.MeasureZ(
                [self.qlist[i] for i in measure_qindices],
                measure_qindices
            )
        else:
            # measure all
            qc.MeasureZ(*self.Q.toListPair())

    def run(self, nshots: int = 1024, measure_qindices: Optional[List[int]] = None) -> Dict:
        """Measure the circuit on the configured QCompute backend.

        NOTE: measurement gates are appended to *this* circuit, so calling
        ``run`` mutates it (a second call appends measurements again). Run on
        a ``copy()`` if the circuit will be reused.

        Args:
            nshots: Number of shots passed to ``commit``.
            measure_qindices: List of qubit indices to measure; ``None``
                measures all qubits.

        Returns:
            The ``"counts"`` entry of the ``commit`` result.
        """
        self._append_measure_z(measure_qindices)
        results = self.commit(nshots)
        return results["counts"]

    def run_qiskit(
        self,
        qiskit_backend,
        nshots: int = 1024,
        measure_qindices: Optional[List[int]] = None,
        error_mitigator=None
    ) -> Dict:
        """Export the circuit to OpenQASM and execute it on a Qiskit backend.

        NOTE: like ``run``, this appends measurement gates to (and publishes)
        *this* circuit, mutating it.

        Args:
            qiskit_backend: Qiskit backend used for ``transpile`` and ``run``.
            nshots: Number of shots.
            measure_qindices: List of qubit indices to measure; ``None``
                measures all qubits.
            error_mitigator: Optional object whose ``apply(results)`` is called
                on the raw Qiskit result before counts are extracted.

        Returns:
            ``get_counts()`` of the (optionally mitigated) Qiskit result.
        """
        # Local imports keep qiskit an optional dependency of this module.
        from QCompute import CircuitToQasm
        from qiskit import QuantumCircuit, transpile

        self._append_measure_z(measure_qindices)
        self.publish(applyModule=False)
        qasm = CircuitToQasm().convert(self.program)

        results = qiskit_backend.run(
            transpile(
                QuantumCircuit.from_qasm_str(qasm), qiskit_backend
            ),
            shots=nshots
        ).result()

        if error_mitigator is not None:
            results = error_mitigator.apply(results)
        return results.get_counts()

    def expectval(self, h: QHamiltonian, nshots: int = 1024, qiskit_backend=None) -> float:
        """Estimate the expectation value of ``h`` by sampling.

        For every ``(coef, qpauli)`` pair yielded by ``h.h_terms()``, a copy of
        this circuit is extended with ``qpauli`` and the qubits listed in
        ``qpauli.measuredQs`` are measured. Each outcome bitstring contributes
        ``coef * (count / nshots) * (-1) ** (number of '1' bits)``.

        Args:
            h: Hamiltonian to evaluate.
            nshots: Shots per Hamiltonian term.
            qiskit_backend: If given, sample via ``run_qiskit`` on this Qiskit
                backend; otherwise use ``run`` on the QCompute backend.

        Returns:
            The estimated expectation value.
        """
        val = 0.0
        for coef, qpauli in h.h_terms():
            # Work on a copy: joining the term and appending measurements
            # would otherwise modify this circuit.
            cir = self.copy()
            cir.join(qpauli)
            if qiskit_backend is not None:
                counts = cir.run_qiskit(qiskit_backend, nshots, qpauli.measuredQs)
            else:
                counts = cir.run(nshots, qpauli.measuredQs)
            for bitstring, count in counts.items():
                # Weight each outcome by the parity of its measured bits.
                val += coef * (count / nshots) * ((-1) ** bitstring.count("1"))
        return val
