from Experiments import Experiments,ExecutionMode
from QuantumDataGenarator import QuantumDataGenarator
from sklearn.model_selection import train_test_split
import json
import numpy as np
import matplotlib.pyplot as plt

if __name__ == '__main__':
    # 'fullselection prob' experiment. NOTE: the Adam parameters are set to
    # 0.01, 0.001/0.003, 0.003 in `GenericNoiseMitigatedUnitaryCircuit.py`.
    # You can change the number of shots for faster execution, but it might
    # decrease the performance.

    number_of_qubits = 4
    num_of_neurons = 6

    # Both lists have num_of_neurons entries; presumably entry i is the
    # unitary label / qubit pair for neuron i -- confirm in Experiments.
    list_of_unitary = ['XX', 'YX', 'YI', 'YZ', 'ZX', 'IX']
    support_qubits = [[0, 1], [1, 2], [2, 3], [0, 2], [1, 3], [0, 3]]

    data, label = QuantumDataGenarator.genDataset(number_of_qubits, samples=5000)

    # Stratified split: the train/test label proportions match the full set.
    X_train, X_test, y_train, y_test = train_test_split(
        data, label, stratify=label, test_size=0.2
    )

    # Everything is kept the same for each run (including the train/test
    # split above), but the seed of the Qiskit circuit simulator is not fixed.
    # Each run therefore produces different outputs because of the simulator's
    # inherent randomness (shot noise, sampling fluctuations, etc.), which lets
    # us estimate the empirical mean and standard deviation across runs.

    filename = 'mitigation_test.json'

    num_times_to_run = 1
    if num_times_to_run < 1:
        # The statistics and the plot below need at least one run.
        raise ValueError('num_times_to_run must be >= 1')

    mses = []           # one per-epoch training-MSE curve per run
    test_accuracy = []  # one test accuracy per run
    for run in range(1, num_times_to_run + 1):
        print('RUN: ', run)
        exp = Experiments(number_of_qubits, num_of_neurons, list_of_unitary,
                          support_qubits, ExecutionMode.MITIGATED)
        # training_circuit returns three values; only the middle one (the
        # training-MSE curve) is used here.
        _, mse, _ = exp.training_circuit(X_train, y_train, epochs=20, batches=50)
        mses.append(mse)

        # testing_circuit returns two values; the second one is recorded.
        _, acc_pred = exp.testing_circuit(X_test, y_test)
        test_accuracy.append(acc_pred)

    # Aggregate once and reuse for printing, saving and plotting.
    mse_mean = np.mean(mses, axis=0)
    mse_std = np.std(mses, axis=0)
    acc_mean = np.mean(test_accuracy)
    acc_std = np.std(test_accuracy)

    print('MSE mean: ', mse_mean)
    print('MSE std: ', mse_std)
    print('Test Acc: ', acc_mean)
    print('Test Acc std: ', acc_std)

    # Save before plotting: plt.show() may block until the window is closed,
    # and a plotting failure must not lose the results.
    results = {
        'Mse_mean': mse_mean.tolist(),
        'Mse_std': mse_std.tolist(),
        'Test_acc_mean': acc_mean.tolist(),
        'Test_acc_std': acc_std.tolist(),
    }
    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=4)  # indent=4 -> human-readable output

    print("Data saved to ", filename)

    # Plot the mean training loss across runs, with a +/- one-std band.
    epochs_axis = np.arange(1, len(mse_mean) + 1)
    plt.plot(epochs_axis, mse_mean, color='green', label='mitigation',
             linestyle='--')
    plt.fill_between(epochs_axis, mse_mean - mse_std, mse_mean + mse_std,
                     color='green', alpha=0.2)
    plt.ylabel('Training loss (MSE)')
    plt.xlabel('Epochs')
    plt.legend(loc='upper right')
    plt.show()
