from AdamOptimizer import AdamOptimizer
from qiskit.circuit.library import UnitaryGate, RXGate, RYGate, RZGate, CXGate, CZGate
from Tools import Tools
from LinbladNoiseModel import InverseLindbladError,LindbladError
import numpy as np
from CacheData import get_cached_unitary


class NoiseMitigatedUnitaryCircuit:
    """
    A unitary gate followed by a Lindblad noise channel and a trainable inverse channel.

    The rotation angle ``theta`` and the inverse-noise parameters ``sigmas`` are trained
    with two separate ``AdamOptimizer`` instances. After every optimizer step the values
    are projected back onto their feasible sets: ``theta`` onto [-pi, pi] and each sigma
    onto [0, 1]. The raw, un-projected values are kept in ``nonprox_theta`` and
    ``nonprox_sigmas`` so that a circuit can also be built from them (the ``prox*``
    flags of ``build_mitigated_model``). Until the first update (or after
    ``reinitialize_parameters``) these raw values are ``None``, and the prox paths fall
    back to the current ``theta`` / ``sigmas``.
    """

    def __init__(self, theta, unitary_pauli_string, pauli_ops, lambdas, sigmas,
                 support_qubits, gate_type='exp', learning_rate=0.003):
        """
        Initialize a noise-mitigated circuit capable of applying single- and two-qubit gates.

        :param theta: Rotation parameter (ignored by the parameter-free 'cx' / 'cz' gates)
        :param unitary_pauli_string: Pauli string P of the exp(i*theta*P) unitary
            (only used when gate_type == 'exp')
        :param pauli_ops: List of Pauli errors for the noise
        :param lambdas: Noise parameters for the Lindblad noise (not updated by this class)
        :param sigmas: Trainable noise parameters for the inverse Lindblad noise.
            NOTE: the sigma updates write into this object in place, so a caller-owned
            mutable sequence will be modified.
        :param support_qubits: The list of qubits that the unitary acts on
        :param gate_type: Type of gate: 'exp', 'rx', 'ry', 'rz', 'cx' or 'cz'; default 'exp'
        :param learning_rate: Adam learning rate used for both theta and sigmas; default 0.003
        """
        self.unitary_pauli_string = unitary_pauli_string
        self.theta = theta
        self.pauli_ops = pauli_ops
        self.lambdas = lambdas
        self.sigmas = sigmas
        self.support_qubits = support_qubits
        # Forward noise with the given lambdas, and the inverse noise driven by sigmas.
        self.noise_model = LindbladError(self.pauli_ops, lambdas, self.support_qubits)
        self.inverse_noise_model = InverseLindbladError(self.pauli_ops, sigmas, self.support_qubits)
        self.gate_type = gate_type  # 'exp', 'rx', 'ry', 'rz', 'cx', 'cz'

        # Raw optimizer outputs before projection; None until the first update.
        self.nonprox_theta = None
        self.nonprox_sigmas = None

        self.adam_theta = AdamOptimizer(learning_rate=learning_rate)
        self.adam_sigma = AdamOptimizer(learning_rate=learning_rate)

    def reinitialize_parameters(self, theta, sigmas):
        """
        Restart training from new values of theta and sigmas.

        Rebuilds the inverse noise model from ``sigmas``, resets both optimizers and
        discards the un-projected values left over from the previous run.
        """
        self.theta = theta
        self.sigmas = sigmas
        self.inverse_noise_model = InverseLindbladError(self.pauli_ops, sigmas, self.support_qubits)
        # Stale raw values from the previous run must not leak into prox circuits.
        # While these are None, the prox paths use the freshly set theta / sigmas.
        self.nonprox_theta = None
        self.nonprox_sigmas = None
        self.adam_theta.reinitialize()
        self.adam_sigma.reinitialize()
        print('theta ', self.theta, 'sigmas ', self.sigmas)

    def create_unitary_gate(self, theta=None):
        """
        Create the unitary gate selected by ``self.gate_type``.

        Supports exp(i*theta*P) (via ``get_cached_unitary``), RX, RY, RZ, CNOT (CX) and CZ.

        :param theta: Rotation angle; when None, ``self.theta`` is used.
            Ignored for 'cx' and 'cz'.
        :return: The gate instance to append to a circuit.
        :raises ValueError: If ``self.gate_type`` is not supported.
        """
        if theta is None:
            theta = self.theta

        if self.gate_type == 'exp':
            unitary_gate = get_cached_unitary(theta, self.unitary_pauli_string)
        elif self.gate_type == 'rx':
            unitary_gate = RXGate(theta)
        elif self.gate_type == 'ry':
            unitary_gate = RYGate(theta)
        elif self.gate_type == 'rz':
            unitary_gate = RZGate(theta)
        elif self.gate_type == 'cx':
            unitary_gate = CXGate()
        elif self.gate_type == 'cz':
            unitary_gate = CZGate()
        else:
            raise ValueError(f"Unsupported gate type: {self.gate_type}")

        return unitary_gate

    def apply_unitary(self, theta, circuit):
        """
        Append the unitary gate (built with ``theta``) to ``circuit`` on the support qubits.

        :param theta: Rotation angle; None falls back to ``self.theta``.
        :param circuit: Circuit to append to (modified in place).
        :return: None
        """
        unitary_gate = self.create_unitary_gate(theta)
        circuit.append(unitary_gate, self.support_qubits)

    def update_theta(self, grad_update):
        """
        Take one Adam step on theta using ``grad_update``.

        The raw step result is stored in ``nonprox_theta``. ``theta`` is set to that
        value clipped to [-pi, pi].
        """
        grad_dict = {'theta': grad_update}
        param_dict = {'theta': self.theta}
        updated_params = self.adam_theta.update(param_dict, grad_dict)
        self.nonprox_theta = updated_params['theta']
        self.theta = self.project_to_pi_interval(updated_params['theta'])

    def _append_per_qubit(self, circuit, ops):
        """Append ops[k] to ``circuit`` acting on support_qubits[k] (zip stops at the shorter)."""
        for qubit, op in zip(self.support_qubits, ops[:]):
            circuit.append(op, [qubit])

    def apply_noise_mitigation(self, circuit, proxsigmas=False):
        """
        Append the forward noise combined with the inverse noise to ``circuit``.

        The forward-noise result is copied and multiplied in place (``*=``) by the
        inverse-noise result, and the entries are appended one per support qubit.
        NOTE(review): the meaning of ``*=`` depends on the type returned by the noise
        models' ``apply``; confirm against LindbladError / InverseLindbladError.

        :param circuit: Circuit to append to (modified in place).
        :param proxsigmas: If True and un-projected sigmas exist, load them into the
            inverse noise model first. The model keeps them until the next
            ``update_sigmas`` / ``update_sigmas_idx`` call reloads the projected sigmas.
        :return: None
        """
        # Before any sigma update, the raw and projected sigmas coincide. The inverse
        # model already holds them, so there is nothing to load.
        if proxsigmas and self.nonprox_sigmas is not None:
            self.inverse_noise_model.update_probabilities(self.nonprox_sigmas)

        op = self.noise_model.apply(circuit).copy()
        op2 = self.inverse_noise_model.apply(circuit)
        op *= op2
        self._append_per_qubit(circuit, op)

    def apply_noise(self, circuit):
        """Append only the forward Lindblad noise to ``circuit``, one entry per support qubit."""
        self._append_per_qubit(circuit, self.noise_model.apply(circuit))

    def apply_noise_inverse(self, circuit):
        """Append only the inverse Lindblad noise to ``circuit``, one entry per support qubit."""
        self._append_per_qubit(circuit, self.inverse_noise_model.apply(circuit))

    def build_mitigated_model(self, circuit, proxtheta=False, proxsigmas=False):
        """
        Append the unitary followed by the noise-mitigation operations to ``circuit``.

        :param circuit: Circuit to append to (modified in place).
        :param proxtheta: Build the unitary from the un-projected theta (falls back to
            ``self.theta`` before the first theta update).
        :param proxsigmas: Use the un-projected sigmas for the inverse noise
            (see ``apply_noise_mitigation``).
        :return: None
        """
        #comment out when using all possible error term
        #assert len(self.lambdas) == 4**len(self.support_qubits)-1 , 'DOES NOT HAVE ENOUGH PARAMS FOR NOISE'
        # assert len(self.sigmas) == 4**len(self.support_qubits)-1 , 'DOES NOT HAVE ENOUGH PARAMS FOR INVERSE NOISE'
        theta = self.nonprox_theta if proxtheta else self.theta
        # A None theta (no update yet) is resolved to self.theta by create_unitary_gate.
        self.apply_unitary(theta, circuit)
        self.apply_noise_mitigation(circuit, proxsigmas)

    def project_to_interval(self, x, lower=0, upper=1):
        """Project ``x`` (scalar or array) onto [lower, upper] by clipping."""
        return np.clip(x, lower, upper)

    def project_to_pi_interval(self, x, lower=-np.pi, upper=np.pi):
        """
        Project ``x`` onto [-pi, pi] by clipping.

        NOTE: this clips rather than wraps, so an angle slightly above pi is pinned to pi
        rather than mapped near -pi.
        """
        return np.clip(x, lower, upper)

    def update_sigmas(self, grad_updates):
        """
        Take one Adam step on all sigmas using ``grad_updates``.

        The raw step result is stored (as a copy) in ``nonprox_sigmas``. Each entry of
        ``self.sigmas`` is overwritten in place with its value clipped to [0, 1], and
        the inverse noise model is refreshed with the projected sigmas.
        """
        grad_dict = {'sigmas': np.array(grad_updates)}
        param_dict = {'sigmas': self.sigmas}
        updated_params = self.adam_sigma.update(param_dict, grad_dict)

        # Copy before projecting: if the optimizer returned self.sigmas itself, the
        # in-place clipping below would otherwise also clip the "un-projected" values.
        unprojected = np.array(updated_params['sigmas'])
        self.nonprox_sigmas = unprojected
        for i, updated_sigma in enumerate(unprojected):
            self.sigmas[i] = self.project_to_interval(updated_sigma)
        self.inverse_noise_model.update_probabilities(self.sigmas)

    def update_sigmas_idx(self, idx, grad_update):
        """
        Take one Adam step with a gradient that is non-zero only at index ``idx``.

        Only ``self.sigmas[idx]`` is written back (clipped to [0, 1]). The optimizer's
        output for the other indices is discarded. NOTE(review): unlike
        ``update_sigmas``, this does not touch ``nonprox_sigmas``.
        """
        # Float buffer so a fractional gradient is not truncated when sigmas hold ints.
        grad_updates = np.zeros(np.shape(self.sigmas), dtype=float)
        grad_updates[idx] = grad_update
        grad_dict = {'sigmas': grad_updates}
        param_dict = {'sigmas': self.sigmas}
        updated_params = self.adam_sigma.update(param_dict, grad_dict)
        updated_term = updated_params['sigmas'][idx]

        self.sigmas[idx] = self.project_to_interval(updated_term)
        self.inverse_noise_model.update_probabilities(self.sigmas)