import pickle as pkl
import json

from os import listdir, makedirs, mkdir
from os.path import isdir, splitext

import numpy as np
import matplotlib.pyplot as plt

from qiskit import QuantumCircuit
from qiskit.quantum_info import Statevector, Pauli, Operator

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.metrics import log_loss, accuracy_score


from sklearn.datasets import load_digits

from data_utils import return_scaled_data, return_unscaled_data, generate_bars_and_stripes
from model_utils import *
from integrated_grads_utils import integrated_grad,  native_integrated_grad, hadamard_integrated_grad

# --- Experiment configuration ---------------------------------------------
dataset = "NIST"
dir_name = f"trained_qml_circuits/{dataset}"  # root folder of trained circuits
# Both strings are used to select matching experiment folders by name and are
# also passed into the model config.
encoding_scaler = "overflow"
activation_func = "tanh"
num_qubits = 6
num_layers = 12

# Integrated-gradients baseline: a unit vector of length 2**num_qubits with all
# weight on the first entry. It is sized from num_qubits so the two settings
# cannot drift apart. NOTE(review): presumably interpreted as state amplitudes
# (the |0...0> basis state) by hadamard_integrated_grad; confirm.
baselineImage = np.zeros(2 ** num_qubits)
baselineImage[0] = 1.0
# Normalization is a no-op for a basis vector, but it keeps the baseline
# unit-norm if the construction above is ever changed.
baselineImage = baselineImage / np.linalg.norm(baselineImage)

numSamplesPerExp = 8  # number of test images explained per experiment
numIntegrationSteps = 50  # integration steps passed to hadamard_integrated_grad

def _parse_classes(folder_name):
    """Return the two class labels encoded right after "class" in *folder_name*.

    The labels are the two single characters that follow the substring
    "class" (e.g. "...class01..." -> (0, 1)), so only single-digit labels are
    supported. Returns None when "class" does not occur in the name.
    """
    index = folder_name.find("class")
    if index == -1:
        return None
    start = index + len("class")
    return int(folder_name[start]), int(folder_name[start + 1])


# Experiment folders whose names contain both the encoding scaler and the
# activation. listdir() order is filesystem-dependent, so sort the list to make
# the index-based pick below reproducible.
experimentFolders = sorted(
    c for c in listdir(dir_name) if (encoding_scaler in c and activation_func in c)
)
experiment_index = 1
if len(experimentFolders) <= experiment_index:
    raise SystemExit(
        f"Need at least {experiment_index + 1} matching experiment folders in "
        f"{dir_name!r}, found {len(experimentFolders)}."
    )
print(experimentFolders[experiment_index])

# Inputs and attributions are reshaped to this shape for plotting and saving.
image_shape = (8, 8)

for ef in [experimentFolders[experiment_index]]:
    classes = _parse_classes(ef)
    if classes is None:
        # Without the class pair the data loaders cannot be configured.
        print(f"Substring 'class' not found in {ef!r}; skipping.")
        continue
    class0, class1 = classes

    # Pick the result file deterministically (listdir order is arbitrary).
    files = sorted(listdir(f"{dir_name}/{ef}"))
    result_file = files[0]
    result_stem = splitext(result_file)[0]
    print(ef, "   ", result_file)

    config = {
        "dataset": dataset,
        "classes": [class0, class1],
        "num_qubits": num_qubits,
        "num_layers": num_layers,
        "measured_qubit": num_qubits - 1,
        "encoding_scaler": encoding_scaler,
        "encoding_order": "C",
        "activation": activation_func,
    }

    # Everything before the shot loop is independent of num_shots, so the
    # trained model and data are loaded and evaluated once per folder.
    # NOTE: pickle.load can execute arbitrary code; only load trusted results.
    with open(f"{dir_name}/{ef}/{result_file}", "rb") as f:
        trained_result = pkl.load(f)

    # Has .x and .fun attributes; presumably a scipy OptimizeResult (confirm).
    opt_params = trained_result.x
    print(len(opt_params))

    scaled_train_data, scaled_test_data, train_labels, test_labels = return_scaled_data(config)
    _, unscaled_test_data, _, _ = return_unscaled_data(config)

    # Sanity check: the recomputed training loss should match the stored one.
    opt_loss = classifier_loss(opt_params, scaled_train_data, num_qubits, num_layers, train_labels, config)
    print("Stored loss: ", trained_result.fun)
    print("Calc loss: ", opt_loss)

    print("Train accuracy: ", classifier_accuracy(opt_params, scaled_train_data, num_qubits, num_layers, train_labels, config))
    print("Test accuracy: ", classifier_accuracy(opt_params, scaled_test_data, num_qubits, num_layers, test_labels, config))

    # The returned circuit does not depend on x. NOTE(review): presumably
    # hadamard_integrated_grad prepares the input state itself; confirm.
    circuit_func = lambda x: classifierCircuit_withoutInitwithoutCbits(opt_params, num_qubits, num_layers)

    # Never index past the available test samples.
    num_samples = min(numSamplesPerExp, len(scaled_test_data))

    for num_shots in [10, 100, 500]:
        # squeeze=False keeps `axes` 2-D even when num_samples == 1.
        fig, axes = plt.subplots(2, num_samples, squeeze=False)

        ig_arrays = np.zeros((num_samples, *image_shape))
        image_arrays = np.zeros((num_samples, *image_shape))

        for idx in range(num_samples):
            ig = hadamard_integrated_grad(scaled_test_data[idx], baselineImage, circuit_func, numIntegrationSteps, num_qubits, config, num_shots=num_shots)

            ig_arrays[idx] = np.reshape(ig, image_shape)
            image_arrays[idx] = np.reshape(unscaled_test_data[idx, :], image_shape)

            axes[0, idx].imshow(image_arrays[idx])
            axes[1, idx].imshow(ig_arrays[idx])

        # makedirs also creates missing parents; plain mkdir fails when e.g.
        # paper_data/shot_noise/<dataset> does not exist yet.
        out_dir = f"paper_data/shot_noise/{dataset}/shots{num_shots}intSteps{numIntegrationSteps}/{ef}"
        makedirs(out_dir, exist_ok=True)

        np.save(f"{out_dir}/{result_stem}_attributions.npy", ig_arrays)
        np.save(f"{out_dir}/{result_stem}_raw_images.npy", image_arrays)

        print("")
        plt.show()
        # Release the figure so non-interactive backends don't accumulate them.
        plt.close(fig)
