import numpy as np
from Tools import Tools
from GenericNoiseMitigatedUnitaryCircuit import NoiseMitigatedUnitaryCircuit
from QuantumDataGenarator import QuantumDataGenarator

from qiskit import QuantumCircuit
from qiskit_aer import Aer,AerSimulator
from qiskit.quantum_info import SparsePauliOp
from qiskit.compiler import transpile
from qiskit_aer.primitives import EstimatorV2
from CacheData import get_cached_pauli
from sklearn.model_selection import KFold,StratifiedKFold
from itertools import chain
from qiskit.quantum_info import Pauli
import pickle
from random import random

class ExecutionMode:
    """String constants that select how ``Experiments`` builds its circuits.

    * ``NOISELESS`` - only the neuron unitaries are applied.
    * ``NOISY`` - each neuron's noise channel is applied after its unitary.
    * ``MITIGATED`` - noise plus the learned inverse-noise channel.
    * ``VanDen`` - noise plus a sampled inverse drawn from a pre-estimated
      noise model after ``cx``/``cz`` neurons.
    """

    # Plain strings (not an Enum) so a caller may pass either the attribute
    # or the literal value as ``mode``.
    NOISELESS = "noiseless"
    NOISY = "noisy"
    MITIGATED = "mitigated"
    VanDen = "vandenberg"

class Experiments:
    """Train and evaluate a stack of parameterised gates ("neurons") on quantum data.

    Each neuron is a ``NoiseMitigatedUnitaryCircuit`` holding a trainable angle
    ``theta`` together with per-gate noise parameters (``lambdas``) and learnable
    inverse-noise parameters (``sigmas``). Depending on ``mode`` the circuit is
    run noiselessly, with injected noise, with the learned inverse noise
    (``ExecutionMode.MITIGATED``), or with an inverse sampled from a noise model
    loaded from ``vandenNoiseModel.pkl`` (``ExecutionMode.VanDen``).
    Expectation values are estimated from many single-shot runs on
    ``AerSimulator``.
    """

    def __init__(self, number_of_qubits, num_of_neurons, list_of_unitary_strings, support_qubits_list, mode, seed=54,
                 list_errors=None):
        """Build the neurons, simulator backends and training settings.

        Args:
            number_of_qubits: Width of the register the data is loaded into.
            num_of_neurons: Number of neurons (parameterised gates).
            list_of_unitary_strings: One unitary string per neuron, forwarded to
                ``NoiseMitigatedUnitaryCircuit``.
            support_qubits_list: One sequence of qubit indices per neuron.
            mode: One of the ``ExecutionMode`` values.
            seed: Seeds NumPy's global RNG and the initial thetas.
            list_errors: Pauli error labels attached to every neuron's noise and
                inverse-noise model. Defaults to ``['XX', 'XY', 'XZ']``. Other
                choices used in experiments: the sparse set
                ``['IX', 'XI', 'IZ', 'ZI', 'XX', 'YY', 'ZZ']`` (single-qubit
                dephasing plus correlated two-qubit errors), or every Pauli
                string via ``Tools.generateAllPossiblePauliString(k)``.

        Raises:
            ValueError: if ``support_qubits_list`` does not have one entry per neuron.
        """
        if num_of_neurons != len(support_qubits_list):
            raise ValueError('Neurons and support qubit list mismatches')

        self.seed = seed
        # Seeds NumPy's *global* RNG, so every np.random draw made by this
        # class (and by any helper that uses np.random) is reproducible.
        np.random.seed(self.seed)
        self.number_of_qubits = number_of_qubits
        self.mode = mode

        # Observable measured at the end of every circuit. Alternatives used in
        # other experiments: Tools.pstr2mat('Z' * self.number_of_qubits) (all Z)
        # or self.pauli_z_first_qubit().
        self.observable = self.create_observable_matrix()

        self.num_of_neurons = num_of_neurons
        self.support_qubits_list = support_qubits_list
        self.list_of_unitary_strings = list_of_unitary_strings
        self.pauli_observable = SparsePauliOp.from_operator(self.observable)
        # Empty template; every sampled circuit is a copy_empty_like() of it.
        self.qc = QuantumCircuit(self.number_of_qubits)

        self.presetimated_noise = None
        if self.mode == ExecutionMode.VanDen:
            # NOTE: pickle can execute arbitrary code - only load a trusted,
            # locally produced noise-model file here.
            with open('vandenNoiseModel.pkl', 'rb') as f:
                self.presetimated_noise = pickle.load(f)

        # CPU simulator; up to 8 circuits of one job are simulated in parallel.
        self.back_end = AerSimulator(max_parallel_experiments=8)
        self.estimator = EstimatorV2()

        # Number of single-shot circuit samples per expectation estimate.
        self.shots = 1024

        # Probability that a given parameter's gradient is evaluated in a
        # training step: trades measurement count against convergence speed.
        self.prob = 1

        self._init_thetas = Tools.randThetaParameterInitiatization(self.num_of_neurons, self.seed)

        if list_errors is None:
            list_errors = ['XX', 'XY', 'XZ']

        self.neurons = []
        for i in range(self.num_of_neurons):
            # Draw order (sigmas, then lambdas) is kept so seeded runs are unchanged.
            sigmas = Tools.randinverseNoiseParameterInitiatization(len(list_errors))
            lambdas = Tools.randNoiseParameterInitiatization(len(list_errors))
            # Each neuron gets its own copy of the error list, as before.
            self.neurons.append(NoiseMitigatedUnitaryCircuit(
                self._init_thetas[i], self.list_of_unitary_strings[i], list(list_errors),
                lambdas, sigmas, self.support_qubits_list[i]))

    def view_params(self):
        '''
        Print, per neuron, the theta (model) value followed by its sigmas
        (inverse noise) and lambdas (noise) parameters.
        '''
        for neuron in self.neurons:
            print("theta", neuron.theta, neuron.sigmas, neuron.lambdas)

    def reinitialize_parameters(self):
        '''
        Reset each neuron with its original initial theta and freshly sampled
        inverse-noise values (used between cross-validation folds).
        '''
        for i in range(len(self.neurons)):
            self.neurons[i].reinitialize_parameters(self._init_thetas[i], Tools.randinverseNoiseParameterInitiatization(len(self.neurons[i].sigmas)) )

    def universal_algorithm(self, qdata, current_neuron, shifted_theta=None, sigma_q=None, apply_sigma=False):
        """Build one measured sample circuit for the current execution mode.

        For every neuron, in order: unitary (with ``shifted_theta`` for
        ``current_neuron`` if given), then noise (all modes except noiseless),
        then the optional Pauli ``sigma_q`` on ``current_neuron``, then the
        learned inverse noise (MITIGATED) or a sampled inverse after cx/cz gates
        (VanDen). The overhead (gamma) is applied later, in
        ``expectation_value_qasm``.

        Args:
            qdata: State vector used to initialise the register.
            current_neuron: Neuron the shift / Pauli applies to (matched by identity), or None.
            shifted_theta: Replacement theta for ``current_neuron``.
            sigma_q: Pauli label applied after ``current_neuron``'s noise.
            apply_sigma: Must be True for ``sigma_q`` to be applied.

        Returns:
            ``(circuit, count)`` where ``count`` is the number of negative-sign
            terms sampled; the outcome is later weighted by ``(-1) ** count``.
        """
        circuit = self.qc.copy_empty_like()
        self.initialize_circuit(circuit, qdata)
        count = 0
        noisy = self.mode in (ExecutionMode.NOISY, ExecutionMode.MITIGATED, ExecutionMode.VanDen)

        for neuron in self.neurons:
            # Identity, not ==, so only the requested neuron is affected.
            is_current = neuron is current_neuron

            if shifted_theta is not None and is_current:
                neuron.apply_unitary(shifted_theta, circuit)
            else:
                neuron.apply_unitary(neuron.theta, circuit)

            if noisy:
                neuron.apply_noise(circuit)

            if apply_sigma and sigma_q is not None and is_current:
                operator = get_cached_pauli(sigma_q)
                for q, p in zip(neuron.support_qubits, operator[:]):
                    circuit.append(p, [q])

            if self.mode == ExecutionMode.MITIGATED:
                neuron.apply_noise_inverse(circuit)
                count += neuron.inverse_noise_model.get_sign_counts()

            if self.mode == ExecutionMode.VanDen and neuron.gate_type in ('cx', 'cz'):
                operator, sign = self.sample_vandenberg_inverse()
                # One single-qubit Pauli factor per qubit of the register.
                for q, p in zip(range(self.number_of_qubits), operator[:]):
                    circuit.append(p, [q])
                count += sign

        circuit.measure_all()
        return circuit, count

    def apply_parameter_shift(self, qdata, label, exp_value, prob=1):
        """Take one sign-of-gradient step on every neuron's theta.

        Each neuron is considered with probability ``prob``. Its gradient is
        estimated from the expectation at ``theta +/- pi/4`` (the exact
        parameter-shift for a gate of the form exp(-i*theta*P) - presumably how
        ``apply_unitary`` is defined; confirm). Only the thresholded sign of the
        gradient is passed to ``update_theta``.

        Returns:
            ``(grads, grads_sign)`` for the neurons that were actually updated.
        """
        grads = []
        grads_sign = []
        for current_neuron in self.neurons:
            if np.random.choice([0, 1], p=[prob, 1 - prob]) != 0:
                continue

            shifted_theta_plus = current_neuron.theta + np.pi / 4
            shifted_theta_minus = current_neuron.theta - np.pi / 4
            exp_value_plus = self.get_exp_circuit_neuron_shifted(qdata, current_neuron, shifted_theta_plus)
            exp_value_minus = self.get_exp_circuit_neuron_shifted(qdata, current_neuron, shifted_theta_minus)

            # d(MSE)/d(theta) up to a positive factor; only its sign is used.
            gradient = 2 * (exp_value - label) * (1 / np.sqrt(2)) * (exp_value_plus - exp_value_minus)
            gradient_sign = Tools.get_sign_with_threshold(gradient)

            grads.append(gradient)
            grads_sign.append(gradient_sign)
            current_neuron.update_theta(gradient_sign)

        return grads, grads_sign

    def _sampled_expectation(self, qdata, shots, current_neuron=None, shifted_theta=None, sigma_q=None,
                             apply_sigma=False):
        """Run ``shots`` independently sampled single-shot circuits and return the signed estimate."""
        instances = []
        signs = []
        for _ in range(shots):
            circuit, count = self.universal_algorithm(
                qdata=qdata,
                current_neuron=current_neuron,
                shifted_theta=shifted_theta,
                sigma_q=sigma_q,
                apply_sigma=apply_sigma,
            )
            instances.append(circuit)
            signs.append(count)

        results = self.executor(instances)
        return self.expectation_value_qasm(results, signs)

    def get_exp_circuit_neuron(self, qdata, proxtheta=False, proxsigmas=False, shots=None):
        """Estimate the model output for ``qdata`` with the current parameters.

        ``proxtheta`` and ``proxsigmas`` are unused and kept for compatibility.
        ``shots`` defaults to ``self.shots``.
        """
        if shots is None:
            shots = self.shots
        return self._sampled_expectation(qdata, shots)

    def executor(self, circuits, shots=1):
        """Run ``circuits`` on the backend and return their counts.

        Qiskit's ``Result.get_counts()`` returns a single dict when exactly one
        circuit was run and a list of dicts otherwise.
        """
        return self.back_end.run(circuits, shots=shots).result().get_counts()

    def expectation_value_qasm(self, results, signs):
        """Turn single-shot counts and sign counts into an expectation estimate.

        Each single-shot outcome contributes ``O[i, i] * (-1) ** sign``, where
        ``i`` is the measured basis index. This is only valid for observables
        that are diagonal in the computational basis. The mean is scaled by the
        product of the neurons' overhead gammas via ``Tools.multiply_overhead``.

        NOTE(review): the overhead is applied in every mode, including
        NOISELESS/NOISY; confirm ``get_overhead_gamma()`` is 1 when no inverse
        noise is applied.
        """
        overhead = 1
        for neuron in self.neurons:
            overhead *= neuron.inverse_noise_model.get_overhead_gamma()

        # A single circuit yields a bare counts dict rather than a list.
        if not isinstance(results, list):
            results = [results]

        samples = []
        for counts, sign in zip(results, signs):
            # Each circuit ran with one shot, so ``counts`` has a single key.
            index = int(next(iter(counts)), 2)
            samples.append(self.observable[index, index] * (-1) ** sign)
        return Tools.multiply_overhead(np.mean(samples), overhead)

    def get_exp_circuit_neuron_shifted(self, qdata, current_neuron, shifted_theta, shots=None):
        """Estimate the output with ``current_neuron``'s theta replaced by ``shifted_theta``.

        ``shots`` defaults to half of ``self.shots`` (two shifted evaluations per gradient).
        """
        if shots is None:
            shots = self.shots // 2
        return self._sampled_expectation(qdata, shots, current_neuron=current_neuron, shifted_theta=shifted_theta)

    def initialize_circuit(self, circuit, state_vector):
        """
        Initialize a quantum circuit with a given state vector.

        :param circuit: The Qiskit QuantumCircuit object.
        :param state_vector: Amplitudes for all ``number_of_qubits`` qubits.
        """
        circuit.initialize(state_vector, range(self.number_of_qubits))

    def get_signleshot_estimate(self, circuit):
        """Run ``circuit`` once and return ``<b|O|b>`` for the measured bitstring ``b``.

        Measurements are added to a copy, so the caller's circuit is not
        mutated and the method can be called repeatedly on the same circuit.
        """
        measured = circuit.measure_all(inplace=False)
        result = self.execute(measured, shots=1)
        bitstring = next(iter(result.get_counts()))
        index = int(bitstring, 2)
        # <e_i|O|e_i> is simply the i-th diagonal element.
        return np.real(self.observable[index, index])

    def expectation_value_estimator(self, circuit, shots):
        """Estimate ``<observable>`` for ``circuit`` with Aer's ``EstimatorV2``.

        NOTE(review): ``shots`` is currently ignored; the estimator runs with
        its default settings.
        """
        job = self.estimator.run([(circuit, [self.pauli_observable])])
        exp_val = job.result()[0].data.evs[0]
        return exp_val

    def sigma_gradient(self, qdata, label, exp_val, prob=1):
        """Take one sign-of-gradient step on every neuron's inverse-noise parameters.

        Each (neuron, Pauli) gradient is evaluated with probability ``prob``
        (otherwise it is 0). A neuron's sigmas are updated right after its own
        gradients are computed, so later neurons see the updated values.

        Returns:
            ``(grads, grads_sign)``, each a list with one inner list per neuron.
        """
        grads = []
        grads_sign = []
        for neuron in self.neurons:
            neuron_grads = []
            for sigma_q in neuron.pauli_ops:
                if np.random.choice([0, 1], p=[prob, 1 - prob]) == 0:
                    update = self.apply_sigma_gradient(qdata, label, neuron, sigma_q, exp_val)
                else:
                    update = 0
                neuron_grads.append(update)
            grads.append(neuron_grads)
            grads_sign.append([np.sign(g) for g in neuron_grads])
            neuron.update_sigmas(neuron_grads)

        return grads, grads_sign

    def applysigma(self, qdata, current_neuron, sigma_q, shots=None):
        """Estimate the output with Pauli ``sigma_q`` inserted after ``current_neuron``'s noise."""
        if shots is None:
            shots = self.shots
        return self._sampled_expectation(qdata, shots, current_neuron=current_neuron, sigma_q=sigma_q,
                                         apply_sigma=True)

    def apply_sigma_gradient(self, qdata, label, current_neuron, sigma_q, pred, shots=None):
        """Thresholded sign of the loss gradient w.r.t. one inverse-noise parameter.

        Each of the ``shots`` samples is, with equal probability, either a plain
        single-shot estimate or the negated single-shot estimate with
        ``sigma_q`` inserted. ``pred`` is the model prediction for ``qdata``.
        """
        if shots is None:
            shots = self.shots

        exp_val = []
        for _ in range(shots):
            if np.random.choice([0, 1], p=[0.5, 0.5]) == 0:
                exp_val.append(self.get_exp_circuit_neuron(qdata, shots=1))
            else:
                exp_val.append((-1) * self.applysigma(qdata, current_neuron, sigma_q, shots=1))

        return Tools.get_sign_with_threshold((pred - label) * (4 * np.mean(exp_val)))

    def execute(self, circuit, shots):
        """Run ``circuit`` on the backend and return the raw ``Result``."""
        return self.back_end.run(circuit, shots=shots).result()

    def expectation_value_density_matrix(self, density_matrix):
        """
        Compute the expectation value as Trace(rho * observable).

        :param density_matrix: The quantum density matrix.
        :return: The expectation value.
        """
        return np.trace(density_matrix @ self.observable)

    def testing_circuit(self, test_data, test_label):
        """Return ``(sign accuracy, sampled accuracy)`` of the model on a test set."""
        accuracy = []
        accuracy_pred = []
        for qdata, label in zip(test_data, test_label):
            prediction = self.get_exp_circuit_neuron(qdata)
            accuracy.append(self.accuracy_calc(prediction, label))
            accuracy_pred.append(self.accuracy_calc(prediction, label, True))
        return np.mean(accuracy), np.mean(accuracy_pred)

    def training_circuit(self, train_data, data_set, epochs=100, batches=50):
        """Train for ``epochs`` epochs of ``batches`` samples each, cycling through ``train_data``.

        ``data_set`` holds the labels aligned with ``train_data``. Parameters are
        only updated for samples whose loss exceeds a fixed threshold.

        Returns:
            ``(accuracy, mse, accuracy_pred)``: per-epoch sign accuracy, mean
            loss, and sampled accuracy.
        """
        accuracy = []
        accuracy_pred = []
        mse = []
        # Samples with a loss at or below this value trigger no update.
        best_loss = 0.20

        for i in range(epochs):
            loss = []
            acc = []
            acc_p = []
            for j in range(batches):
                index = (i * batches + j) % len(train_data)
                qdata = train_data[index]
                label = data_set[index]

                prediction = self.get_exp_circuit_neuron(qdata)
                sample_loss = self.mse_loss(prediction, label)
                loss.append(sample_loss)
                acc.append(self.accuracy_calc(prediction, label))
                acc_p.append(self.accuracy_calc(prediction, label, True))
                if sample_loss > best_loss:
                    self.apply_parameter_shift(qdata, label, prediction, self.prob)
                    if self.mode == ExecutionMode.MITIGATED:
                        self.sigma_gradient(qdata, label, prediction, self.prob)

            mse.append(np.mean(loss))
            accuracy.append(np.mean(acc))
            accuracy_pred.append(np.mean(acc_p))

            print(f'Epoch {i+1}: Loss = {np.mean(loss):.4f},Theta = {[neuron.theta for neuron in self.neurons]},Sigma = {[neuron.sigmas for neuron in self.neurons]}, Accuracy = {np.mean(acc):.2f}, , Accuracy Probabilistic = {np.mean(acc_p):.2f}')

        return accuracy, mse, accuracy_pred

    def train_CV(self, features, labels, epochs=10, samples=500, k_folds=5):
        """Stratified k-fold cross-validation.

        For every fold the parameters are reset via ``reinitialize_parameters``,
        the model is trained for ``epochs`` epochs over consecutive slices of
        the fold's training split (``len(split) // epochs`` samples per epoch),
        and then scored on the held-out split.

        Args:
            features: Sequence of state vectors.
            labels: Sequence of labels, presumably +/-1 to match ``accuracy_calc``;
                also used for stratification.
            epochs: Epochs per fold.
            samples: Unused; kept for backward compatibility.
            k_folds: Number of folds.

        Returns:
            ``(fold_accuracy, fold_mse, accuracy_scores, mse_scores)``. The first
            two hold one list of per-epoch training metrics per fold; the last
            two hold one validation metric per fold.
        """
        accuracy_scores = []
        mse_scores = []
        # Accumulated across folds (one per-epoch curve per fold).
        fold_accuracy = []
        fold_mse = []
        best_loss = 0.20

        kf = StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=42)

        for fold, (train_indices, val_indices) in enumerate(kf.split(features, labels)):
            print(f'\nStarting fold {fold + 1}/{k_folds}')

            train_subset = [features[i] for i in train_indices]
            train_labels = [labels[i] for i in train_indices]
            val_subset = [features[i] for i in val_indices]
            val_labels = [labels[i] for i in val_indices]

            self.reinitialize_parameters()

            # Samples per epoch, so every training sample is used at most once per fold.
            batches = len(train_subset) // epochs

            epoch_mse = []
            epoch_accuracy = []
            for i in range(epochs):
                loss = []
                acc = []
                for j in range(batches):
                    index = i * batches + j
                    label = train_labels[index]
                    qdata = train_subset[index]

                    prediction = self.get_exp_circuit_neuron(qdata)
                    sample_loss = self.mse_loss(prediction, label)
                    loss.append(sample_loss)
                    acc.append(self.accuracy_calc(prediction, label))
                    if sample_loss > best_loss:
                        self.apply_parameter_shift(qdata, label, prediction, self.prob)
                        if self.mode == ExecutionMode.MITIGATED:
                            self.sigma_gradient(qdata, label, prediction, self.prob)

                epoch_mse.append(np.mean(loss))
                epoch_accuracy.append(np.mean(acc))

            fold_accuracy.append(epoch_accuracy)
            fold_mse.append(epoch_mse)

            val_loss = []
            val_acc = []
            for qdata, label in zip(val_subset, val_labels):
                prediction = self.get_exp_circuit_neuron(qdata)
                val_loss.append(self.mse_loss(prediction, label))
                val_acc.append(self.accuracy_calc(prediction, label))

            accuracy_scores.append(np.mean(val_acc))
            mse_scores.append(np.mean(val_loss))

        print(f'\nCross-Validation Results over {k_folds} folds:')
        print(f'Accuracy: {np.mean(accuracy_scores)},Average Accuracy: {np.mean(accuracy_scores):.2f} ± {np.std(accuracy_scores):.2f}')
        print(f'MSE :{np.mean(mse_scores):.4f},Average MSE: {np.mean(mse_scores):.4f} ± {np.std(mse_scores):.4f}')

        return fold_accuracy, fold_mse, accuracy_scores, mse_scores

    def mse_loss(self, prediction, label):
        """Compute the squared error of a single prediction."""
        return (prediction - label) ** 2

    def accuracy_calc(self, prediction, label, flag=False):
        """Return 1 if the predicted label equals ``label``, else 0.

        With ``flag`` False the predicted label is ``sign(prediction)`` (0 maps
        to -1). With ``flag`` True a +/-1 label is sampled with
        ``P(+1) = (1 + prediction) / 2``. Mitigated estimates are scaled by the
        overhead and can leave [-1, 1], so ``prediction`` is first clipped to
        that range.
        """
        if flag:
            predicted_label = self.sample_based_on_expectation(np.clip(prediction, -1, 1))
        else:
            predicted_label = 1 if prediction > 0 else -1
        return 1 if predicted_label == label else 0

    def sample_based_on_expectation(self, exp_value):
        """Sample +1/-1 so that the sample's mean equals ``exp_value``.

        Raises:
            ValueError: if ``exp_value`` is outside [-1, 1].
        """
        if not -1 <= exp_value <= 1:
            raise ValueError("Expectation value must be in the range [-1, 1].")

        p_plus = (1 + exp_value) / 2
        p_minus = 1 - p_plus  # Equivalent to (1 - exp_value) / 2
        return np.random.choice([1, -1], p=[p_plus, p_minus])

    def compute_mse(self, predictions, true_labels):
        """Compute Mean Squared Error."""
        predictions = np.array(predictions)
        true_labels = np.array(true_labels)
        return np.mean((predictions - true_labels) ** 2)

    def compute_accuracy(self, predictions, true_labels):
        """Fraction of positions where prediction equals label.

        Both inputs are converted to arrays so the comparison is element-wise.
        Comparing two plain lists with ``==`` would yield a single bool.
        """
        return np.mean(np.asarray(predictions) == np.asarray(true_labels))

    def create_observable_matrix(self):
        """Diagonal observable: -1 on |0...0> and |1...1>, +1 on every other basis state."""
        matrix = np.eye(2**self.number_of_qubits)
        matrix[0, 0] = -1
        matrix[-1, -1] = -1
        return matrix

    def pauli_z_first_qubit(self):
        """Creates the observable Z ⊗ I ⊗ ... ⊗ I on an n-qubit system.

        NOTE(review): with Qiskit's little-endian ordering the leftmost kron
        factor acts on qubit ``n - 1``, not qubit 0 - confirm which "first
        qubit" is intended.
        """
        Z = np.array([[1, 0], [0, -1]])
        I = np.eye(2)
        observable = Z
        for _ in range(1, self.number_of_qubits):
            observable = np.kron(observable, I)
        return observable

    def sample_vandenberg_inverse(self):
        """Sample one Pauli from the pre-estimated inverse-noise model.

        Each ``(pauli_label, prob)`` entry of ``self.presetimated_noise`` is
        independently included with probability ``prob``. Included Paulis are
        multiplied together, and each one adds 1 to the returned sign count.
        Draws come from NumPy's seeded RNG so runs are reproducible.

        Returns:
            ``(operator, sgn_tot)``: the product Pauli and the number of included terms.
        """
        operator = Pauli("I" * self.number_of_qubits)
        sgn_tot = 0

        for pauli_op, prob in self.presetimated_noise:
            if np.random.random() < prob:
                # ``Pauli * Pauli`` is scalar (phase) multiplication in Qiskit
                # and rejects a Pauli operand, so compose explicitly. The global
                # phase is irrelevant: universal_algorithm only appends the
                # unsigned single-qubit factors ``operator[q]``.
                operator = operator.compose(get_cached_pauli(pauli_op))
                sgn_tot += 1

        return operator, sgn_tot
