import numpy as np
from Tools import Tools
from GenericNoiseMitigatedUnitaryCircuit import NoiseMitigatedUnitaryCircuit
from QuantumDataGenarator import QuantumDataGenarator

from qiskit import QuantumCircuit
from qiskit_aer import Aer,AerSimulator
from qiskit.quantum_info import SparsePauliOp
from qiskit.compiler import transpile
from qiskit_aer.primitives import EstimatorV2
from CacheData import get_cached_pauli
from sklearn.model_selection import KFold,StratifiedKFold
from itertools import chain
from qiskit_ibm_runtime.fake_provider import FakeVigoV2
from qiskit.quantum_info import Pauli
import pickle
from random import random

class ExecutionMode:
    """String constants selecting how ``ExperimentsHEA`` builds each circuit.

    ``ExperimentsHEA.universal_algorithm`` compares ``self.mode`` against
    these values to decide which noise / mitigation steps are appended.
    """
    # No noise channel is appended after the gates.
    NOISELESS = "noiseless"
    # Each neuron's noise is appended; no mitigation step.
    NOISY = "noisy"
    # Noise is appended and each neuron's inverse-noise step is applied after it.
    MITIGATED = "mitigated"
    # Noise is appended and, after 'cx'/'cz' gates, a Pauli sampled from the
    # pre-estimated noise model is applied.
    VANDEN = "vandenberg"

class ExperimentsHEA:
    def __init__(self,num_layers ,number_of_qubits,mode,seed=54):
        """Configure the experiment and build the ansatz neurons.

        Args:
            num_layers: Number of ansatz layers created by ``HEA``.
            number_of_qubits: Number of qubits of the circuit.
            mode: Execution mode; compared against the ``ExecutionMode``
                constants elsewhere in the class.
            seed: Seed for NumPy's global RNG and for the initial theta draw.
        """
        self.seed = seed
        # Seeds NumPy's *global* RNG, so every later np.random call in the
        # process is affected, not only this instance.
        np.random.seed(self.seed)
        self.number_of_qubits = number_of_qubits
        self.mode=mode

        '''
        Select appropriate Observable based on problem, inputs and system.
        '''
        
        # Alternative observable: identity with -1 on the first and last diagonal entries.
        #self.observable = self.create_observable_matrix()

        # Default: the Pauli string 'Z' repeated once per qubit, converted to a
        # matrix by Tools.pstr2mat (presumably Z on every qubit — confirm in Tools).
        self.observable = Tools.pstr2mat('Z'*self.number_of_qubits)

        # Alternative observable: Z on the first qubit only (used for MNIST).
        #self.observable = self.pauli_z_first_qubit()
        
        self.num_layers = num_layers
        self.pauli_observable = SparsePauliOp.from_operator(self.observable)
        # Empty template circuit; universal_algorithm copies it for every sample.
        self.qc = QuantumCircuit(self.number_of_qubits)
        
        # Pre-estimated noise model, only loaded for the VANDEN mode. It is later
        # iterated as (pauli_op, prob) pairs by sample_vandenberg_inverse.
        self.presetimated_noise = None
        if self.mode == ExecutionMode.VANDEN :
            # NOTE(review): the path is relative to the current working directory,
            # and pickle.load can execute arbitrary code — only load a trusted file.
            with open('vandenNoiseModel.pkl', 'rb') as f:
                self.presetimated_noise = pickle.load(f)

        # CPU simulator backend.
        self.back_end = AerSimulator(max_parallel_experiments=8) 
        #self.back_end =  Aer.get_backend('qasm_simulator',max_parallel_experiments=0)
        # self.back_end = FakeVigoV2()
        # NOTE(review): self.estimator is not created here, but
        # expectation_value_estimator reads it.
        # self.estimator = EstimatorV2()
        
        # Default number of sampled circuits per expectation-value estimate.
        self.shots = 2048

        # Probability passed as `prob` to the gradient routines by training_circuit:
        # the chance that a given parameter is updated in a step (trade-off
        # between number of measurements and convergence).
        self.prob=1

        # GPU simulator backend (alternative to the CPU backend above).
        #self.back_end = Aer.get_backend('qasm_simulator',device='GPU')
        
        # One initial theta per rotation gate: 3 rotations per qubit per layer.
        self._init_thetas = Tools.randThetaParameterInitiatization(self.num_layers*self.number_of_qubits*3,self.seed)
        # Build self.neurons; must run after _init_thetas is available.
        self.HEA()
        

    def HEA(self):
        """Build the hardware-efficient ansatz as a list of neurons.

        Fills ``self.neurons`` layer by layer. Each layer contains, in order:

        * for every qubit, three rotation neurons (``rx``, ``ry``, ``rz``) whose
          angles are taken sequentially from ``self._init_thetas``;
        * a ``cx`` neuron on every neighbouring pair ``[j, j + 1]``;
        * one closing ``cx`` neuron on ``[number_of_qubits - 1, 0]``.

        Every neuron receives the list of Pauli errors associated with its gate
        plus freshly drawn noise parameters (``lambdas``) and inverse-noise
        parameters (``sigmas``).

        The error list may be set to all Pauli strings when nothing is known
        about the device (``Tools.generateAllPossiblePauliString``). Here a
        sparse subset is used instead, motivated by single-qubit dephasing (Z)
        and correlated two-qubit errors (XX, YY, ZZ) dominating on many NISQ
        devices: a commonly used two-qubit subset is
        {IX, XI, IZ, ZI, XX, YY, ZZ}, and {X, Z} for single-qubit gates.
        """
        self.neurons = []
        theta_cursor = 0

        # Running index of the CNOT being created. It is only needed by the
        # (currently disabled) dynamic-error scenario, where the lambdas of
        # selected CNOTs, e.g. indices 0, 2 and 7, are increased by 0.05.
        error_induced_cnot_loc = 0

        last_qubit = self.number_of_qubits - 1

        for _layer in range(self.num_layers):
            # --- single-qubit rotations: independent parameters per gate ---
            for qubit in range(self.number_of_qubits):
                for rotation in ('rx', 'ry', 'rz'):
                    # A fresh list per gate, so no two rotation neurons share it.
                    single_qubit_errors = ['X', 'Z']

                    theta = self._init_thetas[theta_cursor]
                    theta_cursor += 1

                    # Draw order (sigmas first, then lambdas) is kept on purpose.
                    sigmas = Tools.randinverseNoiseParameterInitiatization(len(single_qubit_errors))
                    lambdas = Tools.randNoiseParameterInitiatization(len(single_qubit_errors))

                    rotation_neuron = NoiseMitigatedUnitaryCircuit(
                        theta,
                        None,
                        single_qubit_errors,
                        lambdas,
                        sigmas,
                        [qubit],
                        gate_type=rotation
                    )
                    self.neurons.append(rotation_neuron)

            # --- two-qubit gates: independent parameters per gate ---
            # One error list per layer, shared by all CNOT neurons of that layer.
            two_qubit_errors = ['ZI', 'XX', 'YY', 'IZ']

            # Neighbouring pairs followed by the pair that closes the ring.
            cnot_pairs = [[control, control + 1] for control in range(last_qubit)]
            cnot_pairs.append([last_qubit, 0])

            for pair in cnot_pairs:
                sigmas = Tools.randinverseNoiseParameterInitiatization(len(two_qubit_errors))
                lambdas = Tools.randNoiseParameterInitiatization(len(two_qubit_errors))

                entangling_neuron = NoiseMitigatedUnitaryCircuit(
                    None,
                    None,
                    two_qubit_errors,
                    lambdas,
                    sigmas,
                    pair,
                    gate_type='cx'
                )
                self.neurons.append(entangling_neuron)
                error_induced_cnot_loc += 1

    def view_params(self):
        """Print theta, sigmas and lambdas of every neuron, one line per neuron."""
        for neuron in self.neurons:
            print("theta", neuron.theta, neuron.sigmas, neuron.lambdas)

    def reinitialize_parameters():
    # Reinitialize the parameters
        loc=0
        for i in range(len(self.neurons)):
            if self.neurons[i].gate_type in ['rx', 'ry', 'rz']:
                self.neurons[i].reinitialize_parameters(
                    self._init_thetas[loc], 
                    Tools.randinverseNoiseParameterInitiatization(len(self.neurons[i].sigmas)) 
                )
                loc+=1
            else:   
                self.neurons[i].reinitialize_parameters(
                    None, 
                    Tools.randinverseNoiseParameterInitiatization(len(self.neurons[i].sigmas)) 
                )

    def universal_algorithm(self, qdata, current_neuron, shifted_theta=None, sigma_q=None, apply_sigma=False):
        """
        Build one sampled circuit for the whole ansatz.

        Starting from an empty copy of ``self.qc`` initialised with ``qdata``,
        the following steps are appended for every neuron, in this order:

        1. the neuron's unitary (with ``shifted_theta`` instead of its own theta
           when the neuron equals ``current_neuron`` and ``shifted_theta`` is given);
        2. the neuron's noise, in the NOISY, MITIGATED and VANDEN modes;
        3. the Pauli operator ``sigma_q`` on the neuron's support qubits, only for
           ``current_neuron`` and only when ``apply_sigma`` is set;
        4. the neuron's inverse noise, in the MITIGATED mode;
        5. a Pauli sampled by ``sample_vandenberg_inverse`` on all qubits, in the
           VANDEN mode after 'cx'/'cz' gates.

        Finally all qubits are measured. The overhead (gamma) is not applied here;
        it is handled when the expectation value is computed.

        Args:
            qdata: Quantum data used to initialise the circuit.
            current_neuron: The neuron that receives the shifted theta / sigma_q
                (``None`` when no neuron is singled out).
            shifted_theta: Optional shifted theta for ``current_neuron``
                (used by ``get_exp_circuit_neuron_shifted``).
            sigma_q: Optional Pauli operator for ``current_neuron``
                (used by ``applysigma``).
            apply_sigma: Whether ``sigma_q`` should be applied (used by ``applysigma``).

        Returns:
            A tuple ``(circuit, count)``: the measured circuit and the accumulated
            sign count, later used as the exponent of ``(-1)`` in
            ``expectation_value_qasm``.
        """
        circuit = self.qc.copy_empty_like()
        circuit.initialize(qdata, range(self.number_of_qubits))
        count = 0

        for neuron in self.neurons:
            # Apply the unitary of the neuron; only the selected neuron uses the shifted angle.
            # NOTE(review): `==` is an identity check unless the neuron class defines __eq__ — confirm.
            if shifted_theta is not None and neuron == current_neuron:
                neuron.apply_unitary(shifted_theta, circuit)
            else:
                neuron.apply_unitary(neuron.theta, circuit)

            # Apply the noise of the neuron in every mode except NOISELESS.
            if self.mode in [ExecutionMode.NOISY , ExecutionMode.MITIGATED, ExecutionMode.VANDEN]:
                neuron.apply_noise(circuit)

            # If sigma_q is provided, apply the Pauli operator to the selected neuron,
            # one element of the operator per support qubit.
            if apply_sigma and sigma_q is not None and neuron == current_neuron:
                operator = get_cached_pauli(sigma_q)
                for q, p in zip(neuron.support_qubits, operator[:]):
                    circuit.append(p, [q])
                    
            # Apply the inverse noise of the neuron and collect its sign counts.
            if self.mode == ExecutionMode.MITIGATED:
                neuron.apply_noise_inverse(circuit)
                count += neuron.inverse_noise_model.get_sign_counts()

            # VANDEN mitigation: after two-qubit gates, apply a Pauli sampled
            # from the pre-estimated noise model on all qubits.
            if self.mode == ExecutionMode.VANDEN and neuron.gate_type in ['cx','cz']:
                operator,sign = self.sample_vandenberg_inverse()
                for q, p in zip(range(self.number_of_qubits), operator[:]):
                    circuit.append(p, [q]) 
                count+= sign 
        
    
        circuit.measure_all()
    
        return circuit, count        

    def apply_parameter_shift(self, qdata, label, exp_value, prob=1):
        """Update the rotation angles with the parameter-shift rule.

        For every neuron that has a theta, the neuron is selected with
        probability ``prob``. For a selected neuron the expectation value is
        estimated at ``theta + pi/4`` and ``theta - pi/4``, the gradient of the
        squared error ``(exp_value - label) ** 2`` is formed from the two
        estimates, and the thresholded sign of that gradient is passed to the
        neuron's ``update_theta``.

        Args:
            qdata: Quantum data used to initialise the circuits.
            label: Target value for ``qdata``.
            exp_value: Current prediction for ``qdata``.
            prob: Probability of updating each parametrised neuron.

        Returns:
            ``(grads, grads_sign)``: the gradients of the updated neurons and
            their thresholded signs, in neuron order.
        """
        shift = np.pi / 4
        gradients = []
        gradient_signs = []

        for current_neuron in self.neurons:
            # Gates without an angle (e.g. CNOT) have nothing to update.
            if current_neuron.theta is None:
                continue

            # Randomly skip this neuron with probability 1 - prob.
            if np.random.choice([0, 1], p=[prob, 1 - prob]) != 0:
                continue

            theta_up = current_neuron.theta + shift
            theta_down = current_neuron.theta - shift

            # A single-shot variant is also possible: for every shot pick the
            # plus or minus shift with probability 1/2 and average the outcomes.
            value_up = self.get_exp_circuit_neuron_shifted(qdata, current_neuron, theta_up)
            value_down = self.get_exp_circuit_neuron_shifted(qdata, current_neuron, theta_down)

            gradient = 2*(exp_value - label)*(1/np.sqrt(2))* (value_up - value_down)
            gradients.append(gradient)

            sign = Tools.get_sign_with_threshold(gradient)
            gradient_signs.append(sign)
            current_neuron.update_theta(sign)

        return gradients, gradient_signs


    def get_exp_circuit_neuron(self, qdata, shots=None):
        """Estimate the expectation value of the unmodified ansatz on ``qdata``.

        One circuit is sampled per shot with ``universal_algorithm`` (no shifted
        angle, no extra Pauli), all circuits are run through ``executor`` and the
        outcomes are combined by ``expectation_value_qasm``.

        Args:
            qdata: Quantum data used to initialise each circuit.
            shots: Number of sampled circuits; defaults to ``self.shots``.

        Returns:
            The value returned by ``expectation_value_qasm``.
        """
        n_samples = self.shots if shots is None else shots

        sampled = [
            self.universal_algorithm(
                qdata=qdata,
                current_neuron=None,
                shifted_theta=None,
                sigma_q=None,
                apply_sigma=False
            )
            for _ in range(n_samples)
        ]
        instances = [circuit for circuit, _ in sampled]
        signs = [sign_count for _, sign_count in sampled]

        # Transpiling the batch for the backend is currently disabled.
        results = self.executor(instances)
        return self.expectation_value_qasm(results, signs)

    def executor(self, circuits, shots=1):
        """Run ``circuits`` on ``self.back_end`` and return the measured counts.

        Args:
            circuits: Circuit or list of circuits handed to the backend.
            shots: Shots per circuit (1 by default, since every sampled circuit
                is an independent instance).

        Returns:
            Whatever ``result().get_counts()`` returns for the submitted job.
        """
        # The whole batch is submitted in a single job. An alternative that
        # submits the circuits in chunks of 8 and chains the per-chunk counts
        # was used previously and is currently disabled.
        job = self.back_end.run(circuits, shots = shots)
        outcome = job.result()
        return outcome.get_counts()
        
    def expectation_value_qasm(self, results, signs):
        """Turn measured counts into an overhead-scaled expectation value.

        For every result the first measured bitstring selects a diagonal entry
        of ``self.observable``; that entry is multiplied by ``(-1) ** sign``
        using the matching sign count. The mean of these values (or the single
        value, when ``results`` is not a list) is combined with the product of
        all neurons' overhead gammas through ``Tools.multiply_overhead``.

        Args:
            results: Counts of one circuit, or a list with one counts mapping
                per circuit.
            signs: Sign counts, one per circuit, in the same order.

        Returns:
            The value returned by ``Tools.multiply_overhead``.
        """
        overhead = 1
        for neuron in self.neurons:
            overhead *= neuron.inverse_noise_model.get_overhead_gamma()

        # A single (non-list) result is handled on its own.
        if not isinstance(results, list):
            outcome = int(list(results.keys())[0], 2)
            signed_value = self.observable[outcome, outcome]*(-1)**signs[0]
            return Tools.multiply_overhead(signed_value, overhead)

        signed_values = []
        for position, counts in enumerate(results):
            # Only the first key is read: each circuit is run with a single shot.
            outcome = int(list(counts.keys())[0], 2)
            signed_values.append(self.observable[outcome, outcome]*(-1)**signs[position])

        return Tools.multiply_overhead(np.mean(signed_values), overhead)


    def get_exp_circuit_neuron_shifted(self, qdata, current_neuron, shifted_theta, shots=None):
        """Estimate the expectation value with one neuron's angle replaced.

        Same procedure as ``get_exp_circuit_neuron``, except that
        ``current_neuron`` is applied with ``shifted_theta`` instead of its own
        theta.

        Args:
            qdata: Quantum data used to initialise each circuit.
            current_neuron: Neuron whose angle is replaced.
            shifted_theta: Angle used for ``current_neuron``.
            shots: Number of sampled circuits; defaults to half of ``self.shots``.

        Returns:
            The value returned by ``expectation_value_qasm``.
        """
        n_samples = self.shots // 2 if shots is None else shots

        sampled = [
            self.universal_algorithm(
                qdata=qdata,
                current_neuron=current_neuron,
                shifted_theta=shifted_theta,
                sigma_q=None,
                apply_sigma=False
            )
            for _ in range(n_samples)
        ]
        instances = [circuit for circuit, _ in sampled]
        signs = [sign_count for _, sign_count in sampled]

        # Transpiling the batch for the backend is currently disabled.
        results = self.executor(instances)
        return self.expectation_value_qasm(results, signs)


    def initialize_circuit(self, circuit, state_vector):
        """
        Initialize all qubits of a quantum circuit with a given state vector.

        :param circuit: The Qiskit QuantumCircuit object, modified in place.
        :param state_vector: State passed to ``circuit.initialize``.
        """
        all_qubits = range(self.number_of_qubits)
        circuit.initialize(state_vector, all_qubits)


    def get_signleshot_estimate(self, circuit):
        """Measure ``circuit`` once and evaluate the observable on the outcome.

        A measurement of all qubits is appended to ``circuit`` (in place), the
        circuit is executed with one shot, and the measured bitstring is turned
        into a basis vector ``v``. The returned value is the real part of
        ``v^† · observable · v``.
        """
        circuit.measure_all()
        # Transpiling for the backend is currently disabled.
        result = self.execute(circuit, shots=1)

        counts = result.get_counts(circuit)
        measured = list(counts.keys())[0]

        # One-hot vector of the measured computational basis state.
        basis_state = np.zeros(2**self.number_of_qubits)
        basis_state[int(measured, 2)] = 1

        return np.real(np.conj(basis_state) @ self.observable @ basis_state)


    def expectation_value_estimator(self,circuit,shots):    
        """Return the expectation value of ``self.pauli_observable`` on ``circuit``.

        The circuit and observable are submitted as a single pub to the
        estimator, and the first expectation value of the first pub result is
        returned.

        The constructor does not create ``self.estimator``, so an
        ``EstimatorV2`` is created on first use when none has been assigned;
        an estimator set by the caller is used as is.

        Args:
            circuit: Circuit whose expectation value is estimated.
            shots: Accepted for interface compatibility; not used.

        Returns:
            The estimated expectation value.
        """
        estimator = getattr(self, "estimator", None)
        if estimator is None:
            # Created lazily and cached for later calls.
            estimator = EstimatorV2()
            self.estimator = estimator

        job = estimator.run([(circuit,[self.pauli_observable])])
        return job.result()[0].data.evs[0]


    def sigma_gradient(self, qdata, label, exp_val, prob=1):
        """Update the sigmas of every neuron.

        For each Pauli operator of each neuron an update is computed with
        ``apply_sigma_gradient`` with probability ``prob``; otherwise the update
        is 0. After all operators of a neuron are processed, the list of
        updates is passed to the neuron's ``update_sigmas``.

        Args:
            qdata: Quantum data used to initialise the circuits.
            label: Target value for ``qdata``.
            exp_val: Current prediction for ``qdata``.
            prob: Probability of computing an update for each operator.

        Returns:
            ``(grads, grads_sign)``: per neuron, the list of updates and the
            list of their ``np.sign`` values.
        """
        all_updates = []
        all_signs = []

        for neuron in self.neurons:
            neuron_updates = []
            neuron_signs = []

            for sigma_q in neuron.pauli_ops:
                selected = np.random.choice([0, 1], p=[prob, 1 - prob]) == 0
                if selected:
                    update = self.apply_sigma_gradient(qdata, label, neuron, sigma_q, exp_val)
                else:
                    update = 0
                neuron_updates.append(update)
                neuron_signs.append(np.sign(update))

            all_updates.append(neuron_updates)
            all_signs.append(neuron_signs)
            neuron.update_sigmas(neuron_updates)

        return all_updates, all_signs

    def applysigma(self, qdata, current_neuron, sigma_q, shots=None):
        """Estimate the expectation value with ``sigma_q`` inserted at one neuron.

        Same procedure as ``get_exp_circuit_neuron``, except that the Pauli
        operator ``sigma_q`` is applied at ``current_neuron`` in every sampled
        circuit.

        Args:
            qdata: Quantum data used to initialise each circuit.
            current_neuron: Neuron at which ``sigma_q`` is applied.
            sigma_q: Pauli operator to insert.
            shots: Number of sampled circuits; defaults to ``self.shots``.

        Returns:
            The value returned by ``expectation_value_qasm``.
        """
        n_samples = self.shots if shots is None else shots

        sampled = [
            self.universal_algorithm(
                qdata=qdata,
                current_neuron=current_neuron,
                shifted_theta=None,
                sigma_q=sigma_q,
                apply_sigma=True
            )
            for _ in range(n_samples)
        ]
        instances = [circuit for circuit, _ in sampled]
        signs = [sign_count for _, sign_count in sampled]

        # Transpiling the batch for the backend is currently disabled.
        results = self.executor(instances)
        return self.expectation_value_qasm(results, signs)


    def apply_sigma_gradient(self, qdata, label, current_neuron, sigma_q, pred, shots=None):
        """Return the thresholded sign of the loss gradient w.r.t. one sigma.

        For every shot a fair coin decides between a single-shot estimate of the
        plain circuit and the negated single-shot estimate of the circuit with
        ``sigma_q`` inserted at ``current_neuron``. The mean of these samples,
        multiplied by 4 and by ``pred - label``, is passed to
        ``Tools.get_sign_with_threshold``.

        Args:
            qdata: Quantum data used to initialise the circuits.
            label: Target value for ``qdata``.
            current_neuron: Neuron owning the sigma being differentiated.
            sigma_q: Pauli operator associated with that sigma.
            pred: Current prediction for ``qdata`` (supplied by the caller
                instead of being re-estimated here).
            shots: Number of single-shot samples; defaults to ``self.shots``.

        Returns:
            The value returned by ``Tools.get_sign_with_threshold``.
        """
        n_samples = self.shots if shots is None else shots

        samples = []
        for _ in range(n_samples):
            use_plain_circuit = np.random.choice([0, 1], p=[0.5, 0.5]) == 0
            if use_plain_circuit:
                sample = self.get_exp_circuit_neuron(qdata, shots=1)
            else:
                sample = (-1)*self.applysigma(qdata, current_neuron, sigma_q, shots=1)
            samples.append(sample)

        return Tools.get_sign_with_threshold((pred-label)*(4* np.mean(samples)))

    def execute(self, circuit, shots):
        """Run ``circuit`` on ``self.back_end`` with ``shots`` shots and return the job result."""
        job = self.back_end.run(circuit, shots=shots)
        return job.result()


    def expectation_value_density_matrix(self, density_matrix):
        """
        Compute the expectation value as Trace(rho * observable).

        :param density_matrix: The quantum density matrix (rho).
        :return: The trace of ``density_matrix @ self.observable``.
        """
        weighted = density_matrix @ self.observable
        return np.trace(weighted)

    def testing_circuit(self,test_data,test_label):
        accuracy=[]
        accuracy_pred=[]
        for index in range(len(test_label)):
            qdata = test_data[index]
            label = test_label[index] 

            prediction = self.get_exp_circuit_neuron(qdata)
            accuracy.append(self.accuracy_calc(prediction, label))
            accuracy_pred.append(self.accuracy_calc(prediction, label,True))
        return np.mean(accuracy),np.mean(accuracy_pred)

    def training_circuit(self, train_data, data_set, epochs=100, batches=50):
        """Train the circuit sample by sample, without cross-validation.

        Every epoch processes ``batches`` samples; the sample index wraps
        around the training set. For each sample the prediction is estimated,
        the thetas are updated with ``apply_parameter_shift`` and, in the
        MITIGATED mode, the sigmas with ``sigma_gradient`` (both using
        ``self.prob``).

        Args:
            train_data: Quantum data samples.
            data_set: Labels, indexed like ``train_data``.
            epochs: Number of epochs.
            batches: Number of samples per epoch.

        Returns:
            ``(accuracy, mse, accuracy2)``: per-epoch mean accuracy with the
            thresholded prediction, per-epoch mean squared error, and per-epoch
            mean accuracy with a label sampled from the prediction.
        """
        epoch_accuracy = []
        epoch_sampled_accuracy = []
        epoch_mse = []

        for epoch in range(epochs):
            losses = []
            hits = []
            sampled_hits = []

            for step in range(batches):
                # Wrap around so any epochs/batches combination stays in range.
                index = (epoch*batches + step) % len(train_data)
                qdata = train_data[index]
                label = data_set[index]

                prediction = self.get_exp_circuit_neuron(qdata)
                losses.append(self.mse_loss(prediction, label))
                hits.append(self.accuracy_calc(prediction, label))
                sampled_hits.append(self.accuracy_calc(prediction, label, True))

                self.apply_parameter_shift(qdata, label, prediction, self.prob)
                if self.mode == ExecutionMode.MITIGATED:
                    self.sigma_gradient(qdata, label, prediction, self.prob)

            epoch_mse.append(np.mean(losses))
            epoch_accuracy.append(np.mean(hits))
            epoch_sampled_accuracy.append(np.mean(sampled_hits))

        return epoch_accuracy, epoch_mse, epoch_sampled_accuracy


    def train(self,features,labels,epochs=10,samples = 500,k_folds =5):
        """Train and validate the circuit with stratified k-fold cross-validation.

        For every fold the parameters are re-initialised with
        ``self.reinitialize_parameters()``, the training part is split evenly
        over ``epochs`` (``len(train) // epochs`` samples per epoch, each sample
        used once), and the trained circuit is evaluated on the validation part.

        The per-fold training curves are now collected across all folds;
        previously the containers were re-created inside the fold loop, so only
        the last fold's curves were returned.

        Args:
            features: Quantum data samples.
            labels: Labels, indexed like ``features``; also used for stratification.
            epochs: Number of epochs per fold (must be non-zero).
            samples: Accepted for interface compatibility; not used.
            k_folds: Number of folds.

        Returns:
            ``(fold_accuracy, fold_mse, accuracy_scores, mse_scores)``:
            per fold the list of per-epoch training accuracies, per fold the list
            of per-epoch training MSEs, the validation accuracy of every fold and
            the validation MSE of every fold.
        """
        accuracy_scores=[]
        mse_scores=[]
        # Created once, outside the fold loop, so the curves of every fold are kept.
        fold_accuracy = []
        fold_mse = []

        kf = StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=42)

        # Perform k-fold cross-validation
        for fold, (train_indices, val_indices) in enumerate(kf.split(features,labels)):
            print(f'\nStarting fold {fold + 1}/{k_folds}')

            train_subset = [features[i] for i in train_indices]
            train_labels = [labels[i] for i in train_indices]
            val_subset = [features[i] for i in val_indices]
            val_labels = [labels[i] for i in val_indices]

            # Every fold starts from the same initial parameters.
            self.reinitialize_parameters()

            # Samples per epoch: the training subset is split evenly over the epochs.
            batches = len(train_subset) // epochs

            epoch_mse = []
            epoch_accuracy = []
            for i in range(epochs):
                loss = []
                acc = []
                for j in range(batches):
                    index = i*batches + j

                    label = train_labels[index]
                    qdata = train_subset[index]

                    prediction = self.get_exp_circuit_neuron(qdata)
                    loss.append(self.mse_loss(prediction, label))
                    acc.append(self.accuracy_calc(prediction, label))

                    # Update thetas, and sigmas when mitigation is enabled.
                    self.apply_parameter_shift(qdata, label,prediction)
                    if self.mode == ExecutionMode.MITIGATED:
                        self.sigma_gradient(qdata, label, prediction)

                epoch_mse.append(np.mean(loss))
                epoch_accuracy.append(np.mean(acc))

            fold_accuracy.append(epoch_accuracy)
            fold_mse.append(epoch_mse)

            # Validation with the parameters trained on this fold.
            val_loss = []
            val_acc = []
            for i, qdata in enumerate(val_subset):
                label = val_labels[i]
                prediction = self.get_exp_circuit_neuron(qdata)
                val_loss.append(self.mse_loss(prediction, label))
                val_acc.append(self.accuracy_calc(prediction, label))

            # Store results for this fold
            accuracy_scores.append(np.mean(val_acc))
            mse_scores.append(np.mean(val_loss))

        # Final cross-validation results
        print(f'\nCross-Validation Results over {k_folds} folds:')
        print(f'Accuracy: {np.mean(accuracy_scores)},Average Accuracy: {np.mean(accuracy_scores):.2f} ± {np.std(accuracy_scores):.2f}')
        print(f'MSE :{np.mean(mse_scores):.4f},Average MSE: {np.mean(mse_scores):.4f} ± {np.std(mse_scores):.4f}')

        return fold_accuracy, fold_mse, accuracy_scores, mse_scores


    def mse_loss(self, prediction, label):
        """Return the squared error between ``prediction`` and ``label``."""
        error = prediction - label
        return error ** 2

    def accuracy_calc(self, prediction, label, flag = False):
        """Return 1 when the predicted label equals ``label``, else 0.

        With ``flag`` set, the predicted label is sampled from ``prediction``
        via ``sample_based_on_expectation``; otherwise it is 1 for a positive
        prediction and -1 for any other value.
        """
        if flag:
            predicted_label = self.sample_based_on_expectation(prediction)
        elif prediction > 0:
            predicted_label = 1
        else:
            predicted_label = -1

        if predicted_label == label:
            return 1
        return 0

    def sample_based_on_expectation(self, exp_value):
        """Sample a label in {+1, -1} whose mean equals ``exp_value``.

        +1 is drawn with probability ``(1 + exp_value) / 2`` and -1 with the
        complementary probability.

        Raises:
            ValueError: If ``exp_value`` is not within [-1, 1].
        """
        within_range = -1 <= exp_value <= 1
        if not within_range:
            raise ValueError("Expectation value must be in the range [-1, 1].")

        # Probabilities of the two outcomes; the second equals (1 - exp_value) / 2.
        probability_plus = (1 + exp_value) / 2
        probability_minus = 1 - probability_plus

        outcomes = [1, -1]
        weights = [probability_plus, probability_minus]
        return np.random.choice(outcomes, p=weights)

    def compute_mse(self, predictions, true_labels):
        """Compute the Mean Squared Error between predictions and labels.

        Both arguments are converted to NumPy arrays, so lists are accepted.
        """
        residuals = np.array(predictions) - np.array(true_labels)
        return np.mean(residuals ** 2)


    def compute_accuracy(self, predictions, true_labels):
        """Compute the fraction of predictions equal to their label.

        Both arguments are converted to NumPy arrays first (as ``compute_mse``
        does), so the comparison is element-wise for plain lists as well.
        Comparing two lists directly yields a single bool, which made the
        result 0.0 or 1.0 for list inputs.

        Args:
            predictions: Predicted labels.
            true_labels: Target labels, same length as ``predictions``.

        Returns:
            The mean of the element-wise equality.
        """
        matches = np.asarray(predictions) == np.asarray(true_labels)
        return np.mean(matches)

    def create_observable_matrix(self):
        """Return the identity matrix with -1 at the first and last diagonal entry.

        The matrix has dimension ``2 ** number_of_qubits``.
        """
        dimension = 2**self.number_of_qubits
        diagonal = np.ones(dimension)
        diagonal[0] = -1
        diagonal[-1] = -1
        return np.diag(diagonal)

    def pauli_z_first_qubit(self):
        """Creates the observable Z1 = Z ⊗ I ⊗ ... ⊗ I for an n-qubit system.

        Z is the leftmost tensor factor; one identity factor is appended on the
        right for each remaining qubit.
        """
        pauli_z = np.array([[1, 0], [0, -1]])
        identity = np.eye(2)

        observable = pauli_z
        remaining_qubits = self.number_of_qubits - 1
        while remaining_qubits > 0:
            observable = np.kron(observable, identity)
            remaining_qubits -= 1

        return observable
    
    def sample_vandenberg_inverse(self):
        """Sample a Pauli operator from the pre-estimated noise model.

        Starting from the identity on all qubits, every ``(pauli_op, prob)``
        entry of ``self.presetimated_noise`` is included independently with
        probability ``prob``; each included entry is multiplied into the
        operator and increases the sign counter by one.

        Returns:
            ``(operator, sgn_tot)``: the accumulated operator and the number of
            entries that were included.
        """

        operator = Pauli("I"*self.number_of_qubits)
        sgn_tot = 0

        for pauli_op, prob in self.presetimated_noise:
            # One uniform draw in [0, 1) per entry of the noise model.
            random_num = random()
            if random_num < prob: 
                # NOTE(review): relies on `*` between the identity Pauli and the
                # object returned by get_cached_pauli; in recent Qiskit releases
                # `*` on Pauli is scalar multiplication and composition is
                # `compose`/`dot`/`@` — confirm against the installed version.
                operator*=get_cached_pauli(pauli_op) 
                sgn_tot +=1 

        return operator, sgn_tot