import numpy as np
from Tools import Tools
import random
from qiskit.quantum_info import Pauli
from CacheData import get_cached_pauli
        
class BaseLindbladError:
    """Shared state for Pauli-Lindblad noise channels.

    The channel is described by Pauli generators ``pauli_ops`` paired
    element-wise with rates ``params``. For generator ``k`` the identity
    probability is ``(1 + exp(-2 * params[k])) / 2``. The complementary value
    ``1 - identity`` is the probability that subclasses apply that generator.

    Subclasses implement :meth:`apply`.
    """

    def __init__(self, pauli_ops, params, support_qubits):
        """Store the channel description and precompute sampling probabilities.

        Args:
            pauli_ops: Pauli generators, paired element-wise with ``params``.
            params: One rate per generator. Presumably non-negative; a
                negative rate would yield probabilities outside [0, 1].
                TODO confirm with callers.
            support_qubits: Qubits the generators act on. Its length sets the
                width of the operators subclasses build.
        """
        self.pauli_ops = pauli_ops
        self.params = params
        self.support_qubits = support_qubits
        # Populated by subclasses that record which generators fired in
        # apply(). It is initialised here so get_applied_errors() never
        # raises AttributeError.
        self._applied_errors = []
        self._initialize_probabilities()

    def _initialize_probabilities(self):
        """Recompute identity and Pauli probabilities from ``self.params``."""
        self._identity_probs = [(1 + np.exp(-2 * param)) / 2 for param in self.params]
        self._probabilities = [1 - prob for prob in self._identity_probs]

    def update_probabilities(self, new_params):
        """Replace the rates and recompute the derived probabilities."""
        self.params = new_params
        self._initialize_probabilities()

    def apply(self, circuit):
        """Apply the noise model to the circuit. Must be overridden."""
        raise NotImplementedError("This method should be implemented by subclasses.")

    def get_applied_errors(self):
        """Return the generators recorded by the most recent ``apply`` call.

        Returns an empty list if nothing has been recorded yet.
        """
        return self._applied_errors

    def _apply_pauli_string(self, circuit, pauli_string):
        """Place ``pauli_string`` on ``self.support_qubits`` of ``circuit``.

        This delegates to ``Tools.apply_gates_on_qubits``. Its exact semantics
        are defined in the Tools module.
        """
        Tools.apply_gates_on_qubits(circuit, pauli_string, self.support_qubits)


class LindbladError(BaseLindbladError):
    """Draw Pauli errors from the forward Pauli-Lindblad channel.

    Each call to :meth:`apply` draws one independent sample. Generator ``k``
    fires with probability ``self._probabilities[k]``, and the fired
    generators are multiplied together.
    """

    def __init__(self, pauli_ops, params, support_qubits, seed=None):
        """
        Args:
            pauli_ops, params, support_qubits: see ``BaseLindbladError``.
            seed: Optional seed for a private ``random.Random`` stream owned
                by this instance. When ``None``, the module-level ``random``
                generator is used, so a single ``random.seed(...)`` at program
                start makes runs reproducible.
        """
        super().__init__(pauli_ops, params, support_qubits)
        # Seed once here rather than inside apply(). Reseeding on every call
        # made each call return the same sample and clobbered the global
        # RNG. A Random instance (not the module) keeps this object picklable.
        self._rng = None if seed is None else random.Random(seed)

    def apply(self, circuit):
        """Draw one Pauli error from the channel.

        Args:
            circuit: Unused. Kept for interface compatibility.

        Returns:
            A ``Pauli`` on ``len(self.support_qubits)`` qubits. It is the
            product of the sampled generators, or the identity if none fired.
            The generators that fired are available via
            ``get_applied_errors()``.
        """
        draw = random.random if self._rng is None else self._rng.random
        # A fresh list keeps any list returned by an earlier
        # get_applied_errors() call from being mutated.
        self._applied_errors = []
        operator = Pauli("I" * len(self.support_qubits))

        for pauli_op, prob in zip(self.pauli_ops, self._probabilities):
            if draw() < prob:
                # On qiskit Paulis `*` is scalar multiplication; dot() is the
                # operator product.
                operator = operator.dot(get_cached_pauli(pauli_op))
                self._applied_errors.append(pauli_op)

        return operator

class InverseLindbladError(BaseLindbladError):
    """Draw Pauli corrections for the inverse Pauli-Lindblad channel.

    Sampling uses the same per-generator probabilities as the forward channel.
    In addition, it counts how many generators fired (the "sign count"). The
    overall rescaling factor is given by :meth:`get_overhead_gamma`.
    """

    def __init__(self, pauli_ops, params, support_qubits, seed=None):
        """
        Args:
            pauli_ops, params, support_qubits: see ``BaseLindbladError``.
            seed: Optional seed for a private ``random.Random`` stream owned
                by this instance. When ``None``, the module-level ``random``
                generator is used, so a single ``random.seed(...)`` at program
                start makes runs reproducible.
        """
        super().__init__(pauli_ops, params, support_qubits)
        self._sign = 0
        # Seed once here rather than inside apply(). Reseeding on every call
        # made each call return the same sample and clobbered the global
        # RNG. A Random instance (not the module) keeps this object picklable.
        self._rng = None if seed is None else random.Random(seed)

    def apply(self, circuit):
        """Draw one Pauli correction and record its sign count.

        Args:
            circuit: Unused. Kept for interface compatibility.

        Returns:
            A ``Pauli`` on ``len(self.support_qubits)`` qubits. It is the
            product of the sampled generators, or the identity if none fired.
            The number of fired generators is available via
            ``get_sign_counts()``.
        """
        draw = random.random if self._rng is None else self._rng.random
        self._sign = 0
        # A fresh list keeps any list returned by an earlier
        # get_applied_errors() call from being mutated.
        self._applied_errors = []
        operator = Pauli("I" * len(self.support_qubits))

        for pauli_op, prob in zip(self.pauli_ops, self._probabilities):
            if draw() < prob:
                # On qiskit Paulis `*` is scalar multiplication; dot() is the
                # operator product.
                operator = operator.dot(get_cached_pauli(pauli_op))
                self._applied_errors.append(pauli_op)
                self._sign += 1  # one sign flip per sampled generator

        return operator

    def get_sign_counts(self):
        """Return how many generators fired in the most recent ``apply`` call.

        Presumably callers use ``(-1) ** count`` as the sample's sign.
        TODO confirm.
        """
        return self._sign

    def get_overhead_gamma(self):
        """Return ``exp(2 * sum(params))``, the overhead factor of the inverse."""
        return np.exp(2 * np.sum(self.params))