import json
import sys
import pickle as pkl
import os

from shutil import copy as cp

import pennylane as qml
from pennylane import numpy as np


from sklearn.metrics import log_loss, accuracy_score


from data_utils import return_scaled_data

# Experiment configuration. The whole dict is handed to return_scaled_data();
# within this file only "num_qubits", "num_layers", "activation", "dataset",
# "encoding_scaler" and "classes" are read directly.
config = {
    "dataset": "bars",
    "classes": [0, 1],
    "num_qubits": 8,
    "measured_qubit": 0,
    "num_layers": 8,
    "encoding_scaler": "angle",
    "encoding_order": "C",
    "activation": "tanh",
}

# Circuit dimensions used by the training script below.
NUM_LAYERS = config["num_layers"]
NUM_QUBITS = config["num_qubits"]


# Train/test features and labels, prepared by the project helper.
(
    scaled_train_data,
    scaled_test_data,
    train_labels,
    test_labels,
) = return_scaled_data(config)

print(config)

# Device on which the classifier QNode is executed.
dev = qml.device("default.qubit", wires=NUM_QUBITS)

@qml.qnode(dev)
def classifierCircuit_angles(inputs, params, numQubits, numParamLayers):
    """Layered RX/RZ circuit that interleaves data-encoding and trainable layers.

    Every layer applies ``RX`` then ``RZ`` on each of the ``numQubits`` wires,
    followed by a chain of CNOTs between neighbouring wires (i -> i + 1).

    Args:
        inputs: Flat sequence of encoding angles. It is read as
            ``numInputLayers = len(inputs) // (2 * numQubits)`` layers: the
            first ``numInputLayers * numQubits`` entries are the RX angles
            (layer by layer), the following ``numInputLayers * numQubits``
            entries are the RZ angles. Any trailing entries are ignored.
        params: Flat sequence of trainable angles with the same layout, i.e.
            ``numParamLayers * numQubits`` RX angles followed by
            ``numParamLayers * numQubits`` RZ angles.
        numQubits: Number of wires the rotations and CNOT chain act on.
        numParamLayers: Number of trainable layers encoded in ``params``.

    Returns:
        The expectation value of Pauli-Z on wire 0.

    Layer scheduling: even-numbered steps use the next input layer and
    odd-numbered steps use the next parameter layer. Once one kind is used up,
    all remaining steps use the other kind, so exactly ``numInputLayers`` input
    layers and ``numParamLayers`` parameter layers are applied.
    """
    # Floor division keeps only complete input layers.
    numInputLayers = len(inputs) // (numQubits * 2)
    numTotalLayers = numInputLayers + numParamLayers

    inputLayerCounter = 0
    paramLayerCounter = 0

    for j in range(numTotalLayers):
        inputsLeft = inputLayerCounter < numInputLayers
        paramsLeft = paramLayerCounter < numParamLayers

        # Without the `not paramsLeft` fallback, having more input layers than
        # parameter layers would make odd steps index past the end of `params`.
        if inputsLeft and (j % 2 == 0 or not paramsLeft):
            rxOffset = inputLayerCounter * numQubits
            rzOffset = numInputLayers * numQubits + rxOffset
            for i in range(numQubits):
                qml.RX(inputs[rxOffset + i], wires=i)
                qml.RZ(inputs[rzOffset + i], wires=i)
            inputLayerCounter += 1
        else:
            rxOffset = paramLayerCounter * numQubits
            rzOffset = numParamLayers * numQubits + rxOffset
            for i in range(numQubits):
                qml.RX(params[rxOffset + i], wires=i)
                qml.RZ(params[rzOffset + i], wires=i)
            paramLayerCounter += 1

        # Entangle neighbouring wires after every rotation layer.
        for i in range(numQubits - 1):
            qml.CNOT(wires=[i, i + 1])

    # NOTE: the measured wire is hard-coded to 0; config["measured_qubit"] is
    # not consulted here.
    return qml.expval(qml.PauliZ(0))

def classifier_forward(inputs, params, numQubits, numParamLayers):
    """Run the classifier circuit on one sample and return its prediction.

    The circuit's Pauli-Z expectation value ``z`` is first rescaled to
    ``(1 - z) / 2`` and then passed through the activation selected by
    ``config["activation"]``:

    * ``"tanh"``: ``0.5 * tanh(10 * (v - 0.5)) + 0.5``, a steep sigmoid-like
      squashing centred on 0.5.
    * ``"identity"``: the rescaled value is returned unchanged.

    Args:
        inputs: Encoding angles for a single sample.
        params: Trainable circuit angles.
        numQubits: Number of wires used by the circuit.
        numParamLayers: Number of trainable layers encoded in ``params``.

    Returns:
        The activated prediction for this sample.

    Raises:
        ValueError: If ``config["activation"]`` is not a supported name.
    """
    expVal = (1 - classifierCircuit_angles(inputs, params, numQubits, numParamLayers)) / 2.0

    activation = config["activation"]
    if activation == "tanh":
        return 0.5 * np.tanh(10 * (expVal - 0.5)) + 0.5
    if activation == "identity":
        return expVal

    # Previously an unknown name fell through and failed with an unbound local.
    raise ValueError(
        f"Unsupported activation {activation!r}; expected 'tanh' or 'identity'"
    )

def square_loss(params, input_batch, numQubits, numLayers, true_labels):
    """Return the mean squared error between labels and classifier outputs.

    Args:
        params: Trainable circuit angles.
        input_batch: Iterable of samples, each passed to ``classifier_forward``.
        numQubits: Number of wires used by the circuit.
        numLayers: Number of trainable layers encoded in ``params``.
        true_labels: Target values, aligned with ``input_batch``.
    """
    outputs = []
    for sample in input_batch:
        outputs.append(classifier_forward(sample, params, numQubits, numLayers))

    residuals = true_labels - qml.math.stack(outputs)
    return np.mean(residuals ** 2)

def binary_cross_entropy_loss(params, input_batch, numQubits, numLayers, true_labels):
    """Return the mean binary cross-entropy of the classifier over a batch.

    Args:
        params: Trainable circuit angles.
        input_batch: Iterable of samples, each passed to ``classifier_forward``.
        numQubits: Number of wires used by the circuit.
        numLayers: Number of trainable layers encoded in ``params``.
        true_labels: Target values, aligned with ``input_batch``.
    """
    raw_outputs = [
        classifier_forward(sample, params, numQubits, numLayers)
        for sample in input_batch
    ]

    # Keep probabilities strictly inside (0, 1) so neither log() sees zero.
    tiny = 1e-15
    probs = np.clip(np.array(raw_outputs), tiny, 1 - tiny)

    positive_term = true_labels * np.log(probs)
    negative_term = (1 - true_labels) * np.log(1 - probs)
    return -np.mean(positive_term + negative_term)


def classification_loss(params, input_batch, numQubits, numLayers, true_labels):
    """Return scikit-learn's ``log_loss`` of the classifier outputs for a batch.

    The label set is fixed to ``[0, 1]``.

    Args:
        params: Trainable circuit angles.
        input_batch: Iterable of samples, each passed to ``classifier_forward``.
        numQubits: Number of wires used by the circuit.
        numLayers: Number of trainable layers encoded in ``params``.
        true_labels: Target labels, aligned with ``input_batch``.
    """
    probabilities = []
    for sample in input_batch:
        probabilities.append(classifier_forward(sample, params, numQubits, numLayers))

    return log_loss(true_labels, probabilities, labels=[0, 1])

def accuracy(params, input_batch, numQubits, numLayers, true_labels)-> float:
    """Return the fraction of samples whose rounded prediction equals its label.

    Args:
        params: Trainable circuit angles.
        input_batch: Sized collection of samples for ``classifier_forward``.
        numQubits: Number of wires used by the circuit.
        numLayers: Number of trainable layers encoded in ``params``.
        true_labels: Target labels, aligned with ``input_batch``.
    """
    scores = np.zeros(len(input_batch))
    for position, sample in enumerate(input_batch):
        scores[position] = classifier_forward(sample, params, numQubits, numLayers)

    # Round each continuous score to the nearest integer to get a hard label.
    hard_labels = list(map(np.round, scores))
    return accuracy_score(true_labels, hard_labels)


def callback_fun(intermediate_result):
    """Print the training loss evaluated at ``intermediate_result``.

    Relies on the module-level ``loss`` callable, which is defined further down
    in this script, so it can only be called after that definition has run.
    """
    current_value = loss(intermediate_result)
    print("Func val: ", current_value)

print("Starting training")

opt = qml.GradientDescentOptimizer(stepsize=0.2)
# One RX and one RZ angle per qubit per trainable layer, all initialised to 1.0.
params = np.array([1.0] * (2 * NUM_LAYERS * NUM_QUBITS), requires_grad=True)


def loss(x):
    """Binary cross-entropy of parameters ``x`` over the full training set."""
    return binary_cross_entropy_loss(x, scaled_train_data, NUM_QUBITS, NUM_LAYERS, train_labels)


numIterations = 100
for i in range(numIterations):
    params = opt.step(loss, params)
    # Progress report every 10th step (cost and accuracy on the training set).
    if i % 10 == 0:
        print(f"Step {i}: cost = {loss(params):.6f} acc = {accuracy(params, scaled_train_data, NUM_QUBITS, NUM_LAYERS, train_labels)}")

print(f"\nOptimized parameters: {params}")
print(f"Final cost value: {loss(params):.6f}")
print(f"Training accuracy: {accuracy(params, scaled_train_data, NUM_QUBITS, NUM_LAYERS, train_labels)}")
print(f"Testing accuracy: {accuracy(params, scaled_test_data, NUM_QUBITS, NUM_LAYERS, test_labels)}")


print(config)

# Build the results directory name once. The class tag is computed outside the
# f-string because reusing the same quote character inside an f-string
# expression is a SyntaxError before Python 3.12.
class_tag = "".join(str(c) for c in config["classes"])
results_dir = (
    f'batched{config["dataset"]}_{config["encoding_scaler"]}'
    f'{config["activation"]}_qml_class{class_tag}_results'
)
# exist_ok avoids the check-then-create race of isdir() followed by mkdir().
os.makedirs(results_dir, exist_ok=True)

result_path = os.path.join(
    results_dir, f"NISTresult_{NUM_QUBITS}_layer_{NUM_LAYERS}_trial5.pkl"
)
# The context manager closes the file even if pickling fails.
with open(result_path, "wb") as fileObj:
    pkl.dump(params, fileObj)

print("result saved")