import numpy as np
from functools import reduce
from itertools import product
from scipy.linalg import expm

class Tools:
    """
        Namespace of static helpers for Pauli-string algebra, random parameter
        initialisation and applying Pauli gates to a circuit.
    """

    # Single-qubit Pauli matrices keyed by their one-letter label. The key
    # order fixes the order in which generateAllPossiblePauliString emits
    # strings (itertools.product order over these keys).
    PauliMatrices = {'I': np.eye(2),
                     'X': np.array([[0, 1], [1, 0]]),
                     'Y': np.array([[0, -1j], [1j, 0]]),
                     'Z': np.array([[1, 0], [0, -1]])}

    @staticmethod
    def kron(matrices: list):
        """
            Return the Kronecker (tensor) product M0 ⊗ M1 ⊗ ... of ``matrices``.

            An empty input yields the 1x1 identity (the neutral element of the
            Kronecker product) instead of the TypeError that ``reduce`` raises
            on an empty iterable.
        """
        matrices = list(matrices)
        if not matrices:
            return np.eye(1)
        return reduce(np.kron, matrices)

    @staticmethod
    def pstr2mat(pauli_string: str, pauliMatrices=PauliMatrices):
        """
            Turn a Pauli string (e.g. "XZ") into its matrix X ⊗ Z.

            Each character is looked up in ``pauliMatrices``; a character that
            is not a key raises KeyError. The empty string maps to the 1x1
            identity.
        """
        return Tools.kron([pauliMatrices[P] for P in pauli_string])

    @staticmethod
    def generateAllPossiblePauliString(length: int, excludeIdentity=True, pauliMatrices=PauliMatrices):
        """
            Generate all Pauli strings of the given ``length``.

            Strings are produced in itertools.product order over the keys of
            ``pauliMatrices``. When ``excludeIdentity`` is true, the
            all-identity string ``'I' * length`` is removed wherever it
            appears. The result therefore does not depend on 'I' being the
            first key of ``pauliMatrices``.
        """
        pauli_strings = [''.join(combo) for combo in product(pauliMatrices.keys(), repeat=length)]
        if excludeIdentity:
            identity = 'I' * length
            return [s for s in pauli_strings if s != identity]
        return pauli_strings

    @staticmethod
    def randThetaParameterInitiatization(size=1, seed=54):
        """
            Draw ``size`` angles uniformly from [-pi, pi).

            Re-seeds NumPy's global RNG with ``seed`` first, for reproducible
            quick checks. This also affects later ``np.random`` calls.
        """
        np.random.seed(seed)
        return np.random.uniform(-np.pi, np.pi, size)

    @staticmethod
    def randParameterInitiatization(nParams, seed=54):
        """
            Return ``(theta, params)``.

            ``theta`` is a single angle drawn uniformly from [-pi, pi).
            ``params`` is an all-zero vector of length ``nParams``
            (uniform(0, 0) always yields 0).

            Re-seeds NumPy's global RNG with ``seed``.
        """
        np.random.seed(seed)  # reproducible theta
        return np.random.uniform(-np.pi, np.pi), \
            np.random.uniform(0, 0, size=(nParams, ))

    @staticmethod
    def randNoiseParameterInitiatization(nParams, seed=54, max_value=.2):
        """
            Return ``nParams`` values drawn uniformly from [0, max_value).

            Re-seeds NumPy's global RNG with ``seed``.
        """
        np.random.seed(seed)
        return np.random.rand(nParams) * max_value

    @staticmethod
    def randinverseNoiseParameterInitiatization(nParams, seed=54):
        """
            Return an all-zero vector of length ``nParams``.

            ``seed`` is accepted for signature parity with the other
            initialisers but is ignored: the global RNG is not re-seeded here.
        """
        return np.random.uniform(0, 0, size=(nParams, ))

    @staticmethod
    def randInit(nParams, seed=54):
        """
            Return ``(theta, params)``.

            ``theta`` is drawn uniformly from [-pi, pi). ``params`` holds
            ``nParams`` values drawn uniformly from [0, 1).

            Re-seeds NumPy's global RNG with ``seed``.
        """
        np.random.seed(seed)
        return np.random.uniform(-np.pi, np.pi), \
            np.random.uniform(0, 1, size=(nParams, ))

    @staticmethod
    def unitary_exponent(theta, pauli_string):
        """
            Create a unitary matrix U = exp(i * theta * P) where P is a tensor product of Pauli matrices.

            Parameters:
            - theta: The angle for the exponential term.
            - pauli_string: A string of Pauli operators, e.g., "ZX" for Z⊗X.

            Returns:
            - unitary_matrix: The resulting unitary matrix.
        """
        P = Tools.pstr2mat(pauli_string)
        # Closed form of expm(1j * theta * P): a Pauli string squares to the
        # identity (P @ P = I), so the exponential series collapses to
        # cos(theta) * I + i * sin(theta) * P. This is cheaper than a general
        # matrix exponential.
        return np.cos(theta) * np.eye(P.shape[0]) + 1j * np.sin(theta) * P

    @staticmethod
    def apply_gates_on_qubits(circuit, pauli_string, qubits):
        """
          Apply a sequence of Pauli gates to specified qubits in a quantum circuit.

          Args:
              circuit (QuantumCircuit): The quantum circuit to which the gates will be applied.
                  It must provide ``id``, ``x``, ``y`` and ``z`` methods taking a qubit.
              pauli_string (str): A string representing the sequence of Pauli gates (e.g., "XI").
              qubits (list): The qubits to which the gates will be applied,
                  one per character of ``pauli_string``.

          Raises:
              AssertionError: if ``pauli_string`` and ``qubits`` differ in length.
              ValueError: if ``pauli_string`` contains a character other than I/X/Y/Z.
                  The whole string is validated before any gate is appended,
                  so the circuit is left untouched on error.
        """
        if len(pauli_string) != len(qubits):
            # Explicit raise instead of `assert` so the check survives
            # `python -O`. AssertionError is kept for backward compatibility.
            raise AssertionError('Pauli string must match the number of support qubits.')

        # Dictionary to map Pauli characters to Qiskit gate methods
        pauli_gates = {
            'I': circuit.id,  # Identity gate
            'X': circuit.x,   # Pauli-X gate
            'Y': circuit.y,   # Pauli-Y gate
            'Z': circuit.z,   # Pauli-Z gate
        }

        # Validate every character first so that a bad one cannot leave the
        # circuit with only a prefix of the gates applied.
        for gate in pauli_string:
            if gate not in pauli_gates:
                raise ValueError(f"Unsupported Pauli gate: {gate}")

        for gate, qubit in zip(pauli_string, qubits):
            pauli_gates[gate](qubit)

    @staticmethod
    def get_sign_with_threshold(value, threshold=1e-6):
        """
            Return the sign of ``value``, treating ``|value| < threshold`` as zero.

            Returns 0 for near-zero values, otherwise ``np.sign(value)``.
        """
        if np.abs(value) < threshold:
            return 0
        return np.sign(value)

    @staticmethod
    def multiply_overhead(value, overhead, threshold=1e-6):
        """
            Return ``value * overhead`` when ``|value| < threshold``, else ``value`` unchanged.

            NOTE(review): only values *below* the threshold are scaled, which
            keeps near-zero values near zero. Confirm this direction is
            intended, rather than scaling values at or above the threshold.
        """
        if np.abs(value) < threshold:
            return value * overhead
        return value