from typing import List, Optional
import numpy as np
import QCompute as qc
from QCompute.QPlatform.Error import RuntimeError
from project_qsl.circuit import QCircuit
from project_qsl.utils import set_cirline_param_, get_cirline_param


class Univ2QCir(QCircuit):
    """Two-qubit parameterized block with four ``ry`` angles.

    Gate sequence: ``ry(0, p0)``, ``ry(1, p1)``, ``cx(0, 1)``,
    ``ry(0, p2)``, ``ry(1, p3)``.
    """

    def __init__(
        self,
        params: Optional["np.ndarray[float]"] = None,
        backend: qc.BackendName = qc.BackendName.LocalBaiduSim2
    ) -> None:
        """Build the block on a fresh two-qubit circuit.

        Args:
            params: numpy array holding exactly 4 rotation angles (any shape,
                read in flattened order). When omitted, angles are drawn
                uniformly from ``[0, 2*pi)``.
            backend: backend forwarded to ``QCircuit``.

        Raises:
            ValueError: if ``params`` is not an ``np.ndarray`` or does not
                contain exactly 4 values.
        """
        super().__init__(2, backend)

        if params is None:
            params = 2 * np.pi * np.random.rand(4)
        elif not isinstance(params, np.ndarray):
            raise ValueError("Parameters must be put into an np array.")
        else:
            # Flatten so 0-d and multi-dimensional inputs are validated by
            # element count instead of crashing in len() or passing sub-arrays to ry.
            params = np.ravel(params)
            if params.size != 4:
                raise ValueError(f"`Univ2QCir` has exactly 4 parameters, but {params.size:d} parameters are passed.")

        first_q0, first_q1, second_q0, second_q1 = params
        self.ry(0, first_q0)
        self.ry(1, first_q1)
        self.cx(0, 1)
        self.ry(0, second_q0)
        self.ry(1, second_q1)


class GHZSimCircuit(QCircuit):
    """Circuit assembled block by block, one block per qubit position.

    Each :meth:`add_block` call appends a block starting at the current qubit
    pointer, then advances the pointer by one:

    * while the pointer is below the last qubit: ``depth`` layers of
      :class:`Univ2QCir` on qubits ``(pointer, pointer + 1)``;
    * on the last qubit: ``depth`` single ``ry`` rotations.

    ``num_params`` / ``get_params`` / ``set_params`` refer to the most recently
    added block; the ``*_all_*`` variants refer to the whole circuit.
    """

    def __init__(self, num_qubits: int, backend: qc.BackendName = qc.BackendName.LocalBaiduSim2) -> None:
        super().__init__(num_qubits, backend)
        # Qubit index the next block will start on.
        self._qubit_pointer = 0
        # The current block occupies self.circuit[start:start + width].
        self._cur_block_cline_start = 0
        self._block_cline_width = 0
        # Depth and parameter count of the current block.
        self._depth = 0
        self._block_num_params = 0

    @property
    def num_all_params(self) -> int:
        """Number of parameters in the whole circuit."""
        # NOTE(review): `_num_params` is never initialised in this class;
        # presumably QCircuit sets it and updates it on parameterized gates — confirm.
        return self._num_params

    @property
    def num_params(self) -> int:
        """Number of parameters in the most recently added block (0 before any block)."""
        return self._block_num_params

    def _current_block_slice(self) -> slice:
        """Slice of ``self.circuit`` covering the most recently added block."""
        start = self._cur_block_cline_start
        return slice(start, start + self._block_cline_width)

    def get_all_params(self) -> np.ndarray:
        """Return the parameters of every circuit line."""
        return get_cirline_param(self.circuit)

    def set_all_params(self, new_params: np.ndarray) -> None:
        """Overwrite all parameters of the circuit.

        Raises:
            RuntimeError: (the class imported from ``QCompute.QPlatform.Error``)
                if ``len(new_params)`` differs from ``num_all_params``.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        expected = self.num_all_params
        if len(new_params) != expected:
            raise RuntimeError(
                f"The number of parameters inside the circuit is {expected:d}, while the number of newly passed parameters is {len(new_params):d}"
            )
        set_cirline_param_(self.circuit, new_params)

    def get_params(self) -> np.ndarray:
        """Return the parameters of the most recently added block."""
        return get_cirline_param(self.circuit[self._current_block_slice()])

    def set_params(self, new_params: np.ndarray) -> None:
        """Overwrite the parameters of the most recently added block.

        Raises:
            RuntimeError: if ``len(new_params)`` differs from ``num_params``.
        """
        expected = self.num_params
        if len(new_params) != expected:
            raise RuntimeError(
                f"The number of parameters inside the current block is {expected:d}, while the number of newly passed parameters is {len(new_params):d}"
            )
        set_cirline_param_(self.circuit[self._current_block_slice()], new_params)

    def add_block(self, depth: int = 1, params: Optional["np.ndarray[float]"] = None) -> None:
        """Append a block of ``depth`` layers at the current qubit pointer.

        Args:
            depth: number of layers in the block; must be >= 1.
            params: angles laid out layer by layer. A two-qubit block takes
                ``4 * depth`` values (an ``np.ndarray``, as required by
                :class:`Univ2QCir`); the last-qubit block takes ``depth`` values.
                Random angles in ``[0, 2*pi)`` are used when omitted.

        Raises:
            RuntimeError: if every qubit position already has a block.
            ValueError: if ``depth < 1`` or ``params`` has the wrong length/type.
        """
        last_qubit = self.num_qubits - 1
        # Validate everything before mutating state, so a failed call leaves
        # the circuit bookkeeping untouched.
        if self._qubit_pointer > last_qubit:
            raise RuntimeError(f"Can't add more circuit block, the number of qubits required exceeds the number of qubits {self.num_qubits:d} in the circuit.")
        if depth < 1:
            raise ValueError(f"`depth` must be a positive integer, but {depth} is passed.")

        if self._qubit_pointer < last_qubit:
            # One Univ2QCir layer: 4 ry angles over 5 gates (4 ry + 1 cx);
            # assumes join() appends one circuit line per gate — as the
            # original width of 5 per layer did.
            params_per_layer, lines_per_layer = 4, 5
            if params is not None and len(params) != params_per_layer * depth:
                raise ValueError(f"This block has exactly {params_per_layer * depth:d} parameters, but {len(params):d} parameters are passed.")
            # Build (and thereby validate) every layer before joining any.
            layers = [
                Univ2QCir(
                    None if params is None else params[params_per_layer * k:params_per_layer * (k + 1)],
                    self.backend,
                )
                for k in range(depth)
            ]
            for layer in layers:
                self.join(layer, startQRegIndex=self._qubit_pointer)
                self._num_params += layer.num_params
        else:
            # Last qubit: one ry per layer.
            params_per_layer, lines_per_layer = 1, 1
            if params is None:
                params = 2 * np.pi * np.random.rand(depth)
            elif len(params) != depth:
                raise ValueError(f"This block has exactly {depth:d} parameters, but {len(params):d} parameters are passed.")
            for k in range(depth):
                self.ry(self._qubit_pointer, params[k])

        self._cur_block_cline_start += self._block_cline_width
        self._block_cline_width = lines_per_layer * depth
        self._block_num_params = params_per_layer * depth
        self._depth = depth
        self._qubit_pointer += 1
