# This script runs the training loop for basic QML (quantum machine learning) models that use amplitude encoding.
# Change the dataset, classes, and model hyperparameters by editing the `config` dictionary below.

import json
import sys
import pickle as pkl
import numpy as np
import matplotlib.pyplot as plt
import os

from shutil import copy as cp
from scipy.optimize import minimize, dual_annealing

from qiskit import QuantumCircuit
from qiskit.quantum_info import Statevector, DensityMatrix, state_fidelity, partial_trace, Pauli


from sklearn.metrics import log_loss, accuracy_score


from data_utils import return_scaled_data

# Experiment configuration: dataset/class selection, circuit size, data
# encoding options and the activation applied to the circuit readout.
config = {
    "dataset": "MNIST",
    "classes": [0, 3],
    "num_qubits": 10,
    "num_layers": 10,
    "encoding_scaler": "overflow",
    "encoding_order": "C",
    "activation": "identity",
}

# Index of the qubit used for readout; defaults to the last qubit.
# NOTE(review): confirm classifier_forward actually reads this key rather than
# hard-coding the last qubit before relying on changing it.
config["measured_qubit"] = config["num_qubits"] - 1

NUM_QUBITS, NUM_LAYERS = config["num_qubits"], config["num_layers"]

print(config)

# Encoded train/test inputs and their labels, prepared by the project's data helper.
(
    scaled_train_data,
    scaled_test_data,
    train_labels,
    test_labels,
) = return_scaled_data(config)

def classifierCircuit(inputs, params, numQubits, numLayers):
    """Build the variational classifier circuit for a single input.

    The input vector is loaded with ``QuantumCircuit.initialize``. It is followed
    by ``numLayers`` layers. Each layer applies an RX and then an RZ rotation to
    every qubit, then a chain of CNOTs from qubit ``k`` to qubit ``k + 1``.

    ``params`` holds ``2 * numQubits * numLayers`` angles in two contiguous
    halves. The first half holds the RX angles and the second half the RZ
    angles. Both are indexed as ``layer * numQubits + qubit``.
    """
    circuit = QuantumCircuit(numQubits, numQubits)
    circuit.initialize(inputs)

    # Start of the RZ half of the parameter vector.
    rz_offset = numQubits * numLayers

    for layer in range(numLayers):
        base = layer * numQubits
        for qubit in range(numQubits):
            circuit.rx(params[base + qubit], qubit)
            circuit.rz(params[rz_offset + base + qubit], qubit)

        # Nearest-neighbour entangling chain: 0->1, 1->2, ..., (n-2)->(n-1).
        for control, target in zip(range(numQubits - 1), range(1, numQubits)):
            circuit.cx(control, target)

    return circuit

def classifier_forward(inputs, params, numQubits, numLayers):
    """Return the classifier's predicted probability of class 1 for one input.

    The circuit from ``classifierCircuit`` is simulated exactly as a
    statevector. The readout is the probability that the measured qubit is in
    |1>, computed as ``(1 - <Z>) / 2``. That value is then passed through the
    activation named by ``config["activation"]``.

    Args:
        inputs: Input vector passed to ``classifierCircuit`` (amplitude encoded).
        params: Flat rotation-angle vector of length ``2 * numQubits * numLayers``.
        numQubits: Number of qubits in the circuit.
        numLayers: Number of variational layers.

    Returns:
        A float in [0, 1].

    Raises:
        ValueError: If ``config["measured_qubit"]`` is out of range, or if
            ``config["activation"]`` is not ``"tanh"`` or ``"identity"``.
    """
    state = Statevector.from_instruction(
        classifierCircuit(inputs, params, numQubits, numLayers)
    )

    # Read out the configured qubit; default to the last one. The default
    # matches the label "Z" + "I" * (n - 1), because Qiskit Pauli labels are
    # little-endian (the leftmost character is the highest-index qubit).
    measured = config.get("measured_qubit", numQubits - 1)
    if not 0 <= measured < numQubits:
        raise ValueError(
            f"measured_qubit={measured} is out of range for {numQubits} qubits"
        )

    z_expectation = np.real(state.expectation_value(Pauli("Z"), [measured]))
    expVal = (1 - z_expectation) / 2.0

    activation = config["activation"]
    if activation == "tanh":
        # Equivalent to a sigmoid with slope 20 centred at 0.5: sharpens the readout.
        return 1 / 2 * np.tanh(10 * (expVal - 1 / 2)) + 1 / 2
    if activation == "identity":
        return expVal
    raise ValueError(
        f"Unknown activation {activation!r}; expected 'tanh' or 'identity'"
    )

def loss(params, input_batch, numQubits, numLayers, true_labels):
    """Return the binary log-loss of the classifier on ``input_batch``.

    This is the objective passed to ``scipy.optimize.minimize``. As a side
    effect it prints the accuracy on the same batch so training progress is
    visible.

    Args:
        params: Flat rotation-angle vector passed to ``classifier_forward``.
        input_batch: Sequence of encoded inputs, one per sample.
        numQubits: Number of qubits in the circuit.
        numLayers: Number of variational layers.
        true_labels: 0/1 labels aligned with ``input_batch``.

    Returns:
        ``sklearn.metrics.log_loss`` of ``true_labels`` against the predicted
        probabilities of class 1.
    """
    predictions = np.zeros(len(input_batch))
    for i, inputs in enumerate(input_batch):
        predictions[i] = classifier_forward(inputs, params, numQubits, numLayers)

    # Score against this batch's own labels, not the global training labels,
    # so the function works for any batch (e.g. the test set or a mini-batch).
    print("loss called, acc = ", accuracy_score(true_labels, [np.round(e) for e in predictions]))

    return log_loss(true_labels, predictions, labels=[0, 1])

def accuracy(params, input_batch, numQubits, numLayers, true_labels) -> float:
    """Return the fraction of samples in ``input_batch`` classified correctly.

    Each output of ``classifier_forward`` is turned into a hard 0/1 label with
    ``np.round`` before comparison. ``np.round`` rounds half to even, so 0.5
    becomes 0.
    """
    scores = [
        classifier_forward(sample, params, numQubits, numLayers)
        for sample in input_batch
    ]
    hard_labels = [np.round(score) for score in scores]
    return accuracy_score(true_labels, hard_labels)



print("Starting training")

# Random starting angles in [0, 1). In classifierCircuit the first half drives
# the RX gates and the second half drives the RZ gates.
initial_params = np.random.rand(2 * NUM_LAYERS * NUM_QUBITS)
print(
    accuracy(initial_params, scaled_train_data, NUM_QUBITS, NUM_LAYERS, train_labels)
)

# Derivative-free local optimisation (COBYLA) of the training-set log-loss.
result = minimize(
    loss,
    initial_params,
    args=(scaled_train_data, NUM_QUBITS, NUM_LAYERS, train_labels),
    method="COBYLA",
    options={"maxiter": 100},
)
# Alternative global optimiser, kept for reference (angles bounded to [0, 2*pi]):
# param_bounds = [(0, 2*np.pi) for i in range(len(initial_params))]
# result = dual_annealing(lambda x: loss(x, scaled_train_data, NUM_QUBITS, NUM_LAYERS, train_labels), bounds=param_bounds, maxiter=1000, maxfun=1000)

print(result)

print(result.x)

print(
    "training_acc: ",
    accuracy(result.x, scaled_train_data, NUM_QUBITS, NUM_LAYERS, train_labels),
)
print(
    "test_acc: ",
    accuracy(result.x, scaled_test_data, NUM_QUBITS, NUM_LAYERS, test_labels),
)


# Results directory, e.g. "MNIST_overflowidentity_qml_class03_results".
# The class tag is built outside the f-string: reusing the f-string's own quote
# character inside a replacement field is a SyntaxError before Python 3.12.
classes_tag = "".join(str(c) for c in config["classes"])
results_dir = (
    f'{config["dataset"]}_{config["encoding_scaler"]}{config["activation"]}'
    f"_qml_class{classes_tag}_results"
)
# exist_ok avoids the check-then-create race of isdir() followed by mkdir().
os.makedirs(results_dir, exist_ok=True)

result_path = os.path.join(
    results_dir, f"NISTresult_{NUM_QUBITS}_layer_{NUM_LAYERS}_trial3.pkl"
)
# The context manager closes the file even if pickling fails.
with open(result_path, "wb") as fileObj:
    pkl.dump(result, fileObj)

print("result saved")