from typing import List
import numpy as np
from QCompute.QPlatform.QOperation import CircuitLine
from QCompute.QPlatform.QOperation.RotationGate import RotationGateOP

# Public API of this module.
__all__ = ["set_cirline_param_", "get_cirline_param", "partial_trace"]


def set_cirline_param_(clines: "List[CircuitLine]", params: "np.ndarray[float]") -> "List[CircuitLine]":
    """Write ``params`` into the rotation gates of ``clines`` in place.

    Parameters are consumed in order: each circuit line whose operation is a
    ``RotationGateOP`` takes the next ``len(argumentList)`` values from
    ``params``. This is the inverse of ``get_cirline_param``, which reads the
    arguments back in the same order.

    :param clines: circuit lines to update; their rotation gates are modified
        in place.
    :param params: flat sequence of parameter values (a NumPy array or any
        array-like, e.g. a list of floats).
    :return: the same ``clines`` list, for convenience.
    :raises ValueError: if ``params`` holds fewer values than the rotation
        gates require; in that case no circuit line is modified.
    """
    # Accept lists as well as arrays; ``.tolist()`` below needs an ndarray.
    params = np.asarray(params)

    rotation_ops = [cline.data for cline in clines if isinstance(cline.data, RotationGateOP)]
    num_required = sum(len(op.argumentList) for op in rotation_ops)
    # Validate before mutating so a short ``params`` cannot leave the circuit
    # half-updated or give a gate a truncated argument list.
    if params.size < num_required:
        raise ValueError(
            f"expected at least {num_required} parameters for the rotation gates, got {params.size}"
        )

    pointer = 0
    for op in rotation_ops:
        num_params = len(op.argumentList)
        # ``tolist`` stores plain Python numbers rather than NumPy scalars.
        op.argumentList = params[pointer:pointer + num_params].tolist()
        pointer += num_params
    return clines


def get_cirline_param(clines: "List[CircuitLine]") -> "np.ndarray[float]":
    """Collect the arguments of every rotation gate in ``clines``.

    Circuit lines are visited in order and the ``argumentList`` of each
    ``RotationGateOP`` is concatenated. This yields the flat layout that
    ``set_cirline_param_`` consumes.

    :param clines: circuit lines to read; they are not modified.
    :return: array of all rotation-gate arguments (empty when there are none).
    """
    return np.asarray([
        angle
        for line in clines
        if isinstance(line.data, RotationGateOP)
        for angle in line.data.argumentList
    ])


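# NOTE: QCompute and QEP use little-endian qubit ordering: in a bitstring (or matrix index) qubit 0 is the
# rightmost / least-significant bit, so the bitstring 01 denotes the state |10\rangle written in |q0 q1\rangle order.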
def partial_trace(rho: np.ndarray, trace_out_qinds: List[int]) -> np.ndarray:
    """Return the reduced matrix obtained by tracing out the given qubits.

    Qubits follow the little-endian convention: qubit ``q`` corresponds to
    bit ``q`` (weight ``2**q``) of a row/column index of ``rho``. The
    remaining qubits keep their relative order and are renumbered from 0.

    :param rho: square ``2**n x 2**n`` matrix (typically a density matrix);
        it does not need to be Hermitian.
    :param trace_out_qinds: indices of the qubits to trace out. If it is empty
        (or ``None``), nothing is traced and ``rho`` itself is returned.
    :return: the ``2**(n-k) x 2**(n-k)`` reduced matrix, where ``k`` is the
        number of traced-out qubits (a ``1 x 1`` matrix holding the trace
        when every qubit is traced out).
    :raises ValueError: if ``rho`` is not a square matrix whose side is a
        power of two, or if ``trace_out_qinds`` contains duplicates or
        indices outside ``range(n)``.
    """
    qinds = [] if trace_out_qinds is None else [int(q) for q in trace_out_qinds]
    if not qinds:
        return rho

    rho = np.asarray(rho)
    if rho.ndim != 2 or rho.shape[0] != rho.shape[1]:
        raise ValueError(f"rho must be a square matrix, got shape {rho.shape}")
    dim = rho.shape[0]
    if dim < 1 or dim & (dim - 1):
        raise ValueError(f"rho dimension must be a power of two, got {dim}")
    # Exact integer log2 (avoids float rounding in np.log2).
    num_qubits = dim.bit_length() - 1

    if len(set(qinds)) != len(qinds):
        raise ValueError(f"duplicate qubit indices in trace_out_qinds: {qinds}")
    out_of_range = [q for q in qinds if not 0 <= q < num_qubits]
    if out_of_range:
        raise ValueError(
            f"qubit indices {out_of_range} out of range for a {num_qubits}-qubit matrix"
        )
    num_rest_qubits = num_qubits - len(qinds)

    # After a C-order reshape into 2n binary axes, row axis ``a`` holds the
    # index bit of weight 2**(n-1-a), i.e. qubit n-1-a (little endian).
    # Its column partner is axis a + n.
    tensor = rho.reshape([2] * (2 * num_qubits))
    # Trace the highest row axes first. Removing an axis then never shifts
    # the position of a row axis that is still to be processed.
    for axis in sorted((num_qubits - 1 - q for q in qinds), reverse=True):
        half = tensor.ndim // 2
        tensor = np.trace(tensor, axis1=axis, axis2=axis + half)
    # np.asarray: a full trace leaves a NumPy scalar, not an array.
    return np.asarray(tensor).reshape(2**num_rest_qubits, 2**num_rest_qubits)
