import numpy as np

from qiskit import QuantumCircuit
from qiskit.quantum_info import Statevector, DensityMatrix, state_fidelity, partial_trace, Pauli

from sklearn.metrics import log_loss, accuracy_score


def classifierCircuit(inputs, params, numQubits, numLayers):
    """Build the amplitude-encoded variational classifier circuit.

    ``inputs`` is loaded with ``initialize(..., normalize=True)`` on all
    qubits. It is followed by ``numLayers`` layers, each of which applies
    RX then RZ to every qubit and then a CNOT chain (0->1, 1->2, ...).

    Args:
        inputs: amplitude vector passed to ``QuantumCircuit.initialize``.
        params: flat angle sequence. The first ``numQubits * numLayers``
            entries are the RX angles (layer-major). The next
            ``numQubits * numLayers`` entries are the RZ angles, in the same
            order.
        numQubits: circuit width.
        numLayers: number of variational layers.

    Returns:
        QuantumCircuit with ``numQubits`` qubits and ``numQubits`` classical
        bits. The classical bits are not used by this function.
    """
    circuit = QuantumCircuit(numQubits, numQubits)
    circuit.initialize(inputs, normalize=True)

    rz_offset = numQubits * numLayers  # RZ angles start after all RX angles
    for layer in range(numLayers):
        base = layer * numQubits
        for qubit in range(numQubits):
            circuit.rx(params[base + qubit], qubit)
            circuit.rz(params[rz_offset + base + qubit], qubit)
        # Entangle neighbours in a linear chain.
        for control, target in zip(range(numQubits - 1), range(1, numQubits)):
            circuit.cx(control, target)
    return circuit

def classifierCircuit_withoutInit(params, numQubits, numLayers):
    """Build the variational ansatz alone, without any state-preparation step.

    Each of the ``numLayers`` layers applies RX then RZ to every qubit, then a
    CNOT chain (0->1, 1->2, ...).

    Args:
        params: flat angle sequence. The first ``numQubits * numLayers``
            entries are the RX angles (layer-major). The next
            ``numQubits * numLayers`` entries are the RZ angles, in the same
            order.
        numQubits: circuit width.
        numLayers: number of variational layers.

    Returns:
        QuantumCircuit with ``numQubits`` qubits and ``numQubits`` classical
        bits. The classical bits are not used by this function.
    """
    circuit = QuantumCircuit(numQubits, numQubits)

    half = numQubits * numLayers  # index where the RZ block of params begins
    for layer in range(numLayers):
        start = layer * numQubits
        for q in range(numQubits):
            circuit.rx(params[start + q], q)
            circuit.rz(params[half + start + q], q)
        for q in range(1, numQubits):
            circuit.cx(q - 1, q)
    return circuit

def classifierCircuit_withoutInitwithoutCbits(params, numQubits, numLayers):
    """Build the variational ansatz on a quantum-only register (no classical bits).

    The gate sequence is: for each of ``numLayers`` layers, RX then RZ on every
    qubit, followed by a CNOT chain (0->1, 1->2, ...). No state preparation
    is added.

    Args:
        params: flat angle sequence. The first ``numQubits * numLayers``
            entries are the RX angles (layer-major). The next
            ``numQubits * numLayers`` entries are the RZ angles, in the same
            order.
        numQubits: circuit width.
        numLayers: number of variational layers.

    Returns:
        QuantumCircuit with ``numQubits`` qubits and no classical register.
    """
    ansatz = QuantumCircuit(numQubits)
    rz_start = numQubits * numLayers

    for layer_idx in range(numLayers):
        offset = layer_idx * numQubits
        for wire in range(numQubits):
            rx_angle = params[offset + wire]
            rz_angle = params[rz_start + offset + wire]
            ansatz.rx(rx_angle, wire)
            ansatz.rz(rz_angle, wire)
        for upper in range(numQubits - 1):
            ansatz.cx(upper, upper + 1)
    return ansatz

def classifierCircuit_angles(inputs, params, numQubits, numParamLayers):
    """Build a classifier circuit that angle-encodes ``inputs`` between trainable layers.

    ``inputs`` and ``params`` share one flat layout. For ``L`` layers, the
    first ``numQubits * L`` entries are RX angles (layer-major). The next
    ``numQubits * L`` entries are RZ angles, in the same order.

    Every layer, input or trainable, applies RX then RZ to each qubit and is
    followed by a CNOT chain (0->1, 1->2, ...). Layers alternate, starting
    with an input layer. Once one kind runs out, all remaining layers are of
    the other kind.

    Args:
        inputs: flat encoding angles. The number of input layers is
            ``len(inputs) // (2 * numQubits)``; any trailing remainder is
            ignored.
        params: flat trainable angles. At least
            ``2 * numQubits * numParamLayers`` entries are needed.
        numQubits: circuit width.
        numParamLayers: number of trainable layers.

    Returns:
        QuantumCircuit with ``numQubits`` qubits and ``numQubits`` classical
        bits. The classical bits are not used by this function.
    """
    # 2 rotation gates (RX, RZ) per qubit in each input layer
    numInputLayers = len(inputs) // (numQubits * 2)
    numTotalLayers = numInputLayers + numParamLayers

    qc = QuantumCircuit(numQubits, numQubits)

    inputLayerCounter = 0
    paramLayerCounter = 0

    for j in range(numTotalLayers):
        # Alternate input (even j) / param (odd j) layers. When the parameter
        # layers are exhausted, keep emitting input layers instead of indexing
        # past the end of ``params``.
        useInputLayer = inputLayerCounter < numInputLayers and (
            j % 2 == 0 or paramLayerCounter >= numParamLayers
        )
        if useInputLayer:
            angles, layersOfKind, layer = inputs, numInputLayers, inputLayerCounter
            inputLayerCounter += 1
        else:
            angles, layersOfKind, layer = params, numParamLayers, paramLayerCounter
            paramLayerCounter += 1

        for i in range(numQubits):
            qc.rx(angles[layer * numQubits + i], i)
            # The RZ block starts after all RX angles of this kind.
            qc.rz(angles[numQubits * layersOfKind + layer * numQubits + i], i)

        for i in range(numQubits - 1):
            qc.cx(i, i + 1)
    return qc

def classifier_forward(inputs, params, numQubits, numLayers, config):
    """Return the classifier's prediction for one amplitude-encoded sample.

    The function simulates ``classifierCircuit`` as an exact statevector and
    computes ``<Z>`` on qubit ``config["measured_qubit"]``. It converts that
    to the probability of measuring ``|1>``, ``(1 - <Z>) / 2``, and then
    applies the configured activation.

    Args:
        inputs: amplitude vector handed to ``classifierCircuit``.
        params: flat parameter vector for ``classifierCircuit``.
        numQubits: circuit width.
        numLayers: number of variational layers.
        config: mapping with ``"measured_qubit"`` (qubit index in qiskit's
            numbering) and ``"activation"`` (``"tanh"`` or ``"identity"``).

    Returns:
        Prediction in ``[0, 1]``.

    Raises:
        ValueError: if ``config["activation"]`` is not ``"tanh"`` or ``"identity"``.
    """
    s = Statevector.from_instruction(classifierCircuit(inputs, params, numQubits, numLayers))

    # Size the observable from the simulated state itself, so it always
    # matches the state's dimension.
    pauliString = ["I"] * s.num_qubits
    pauliString[config["measured_qubit"]] = "Z"

    # Qiskit Pauli labels put qubit 0 in the rightmost character, hence the reversal.
    # (1 - <Z>) / 2 is the probability of measuring |1> on the measured qubit.
    expVal = (1 - s.expectation_value(Pauli(''.join(pauliString)[::-1]))) / 2.0

    activation = config["activation"]
    if activation == "tanh":
        # Steep sigmoid-like squashing centred at 0.5, still mapping into [0, 1].
        return 1 / 2 * np.tanh(10 * (expVal - 1 / 2)) + 1 / 2
    if activation == "identity":
        return expVal
    raise ValueError(
        f"unsupported activation {activation!r}; expected 'tanh' or 'identity'"
    )

def classifier_forward_angles(inputs, params, numQubits, numLayers, config):
    """Return the classifier's prediction for one angle-encoded sample.

    The function simulates ``classifierCircuit_angles`` as an exact
    statevector and computes ``<Z>`` on qubit ``config["measured_qubit"]``.
    It converts that to the probability of measuring ``|1>``,
    ``(1 - <Z>) / 2``, and then applies the configured activation.

    Args:
        inputs: flat encoding angles for ``classifierCircuit_angles``.
        params: flat trainable angles for ``classifierCircuit_angles``.
        numQubits: circuit width.
        numLayers: number of trainable layers.
        config: mapping with ``"measured_qubit"`` (qubit index in qiskit's
            numbering) and ``"activation"`` (``"tanh"`` or ``"identity"``).

    Returns:
        Prediction in ``[0, 1]``.

    Raises:
        ValueError: if ``config["activation"]`` is not ``"tanh"`` or ``"identity"``.
    """
    s = Statevector.from_instruction(classifierCircuit_angles(inputs, params, numQubits, numLayers))

    # Size the observable from the simulated state itself, so it always
    # matches the state's dimension.
    pauliString = ["I"] * s.num_qubits
    pauliString[config["measured_qubit"]] = "Z"

    # Qiskit Pauli labels put qubit 0 in the rightmost character, hence the reversal.
    # (1 - <Z>) / 2 is the probability of measuring |1> on the measured qubit.
    expVal = (1 - s.expectation_value(Pauli(''.join(pauliString)[::-1]))) / 2.0

    activation = config["activation"]
    if activation == "tanh":
        # Steep sigmoid-like squashing centred at 0.5, still mapping into [0, 1].
        return 1 / 2 * np.tanh(10 * (expVal - 1 / 2)) + 1 / 2
    if activation == "identity":
        return expVal
    raise ValueError(
        f"unsupported activation {activation!r}; expected 'tanh' or 'identity'"
    )

def classifier_loss(params, input_batch, numQubits, numLayers, true_labels, config):
    """Return the binary log-loss of ``classifier_forward`` over a batch.

    Args:
        params: flat parameter vector shared by every sample.
        input_batch: sequence of amplitude vectors, one per sample.
        numQubits: circuit width.
        numLayers: number of variational layers.
        true_labels: ground-truth labels, drawn from ``{0, 1}``.
        config: forwarded unchanged to ``classifier_forward``.

    Returns:
        ``sklearn.metrics.log_loss`` of the labels against the predictions,
        with the label set fixed to ``[0, 1]``.
    """
    sample_count = len(input_batch)
    scores = np.zeros(sample_count)
    for k, sample in enumerate(input_batch):
        scores[k] = classifier_forward(sample, params, numQubits, numLayers, config)

    # labels=[0, 1] keeps log_loss well-defined even if a batch holds one class only.
    return log_loss(true_labels, scores, labels=[0, 1])

def classifier_accuracy(params, input_batch, numQubits, numLayers, true_labels, config) -> float:
    """Return the fraction of samples whose rounded prediction equals its label.

    Each sample is scored with ``classifier_forward``. The scores are rounded
    with ``np.round``, which rounds halves to even, so exactly 0.5 becomes 0.

    Args:
        params: flat parameter vector shared by every sample.
        input_batch: sequence of amplitude vectors, one per sample.
        numQubits: circuit width.
        numLayers: number of variational layers.
        true_labels: ground-truth labels.
        config: forwarded unchanged to ``classifier_forward``.

    Returns:
        ``sklearn.metrics.accuracy_score`` of the labels against the rounded scores.
    """
    scores = np.zeros(len(input_batch))
    for position, sample in enumerate(input_batch):
        scores[position] = classifier_forward(sample, params, numQubits, numLayers, config)

    hard_predictions = np.round(scores)
    return accuracy_score(true_labels, hard_predictions)

def classifier_accuracy_angles(params, input_batch, numQubits, numLayers, true_labels, config) -> float:
    """Return the fraction of angle-encoded samples classified correctly.

    Each sample is scored with ``classifier_forward_angles``. The scores are
    rounded with ``np.round``, which rounds halves to even, so exactly 0.5
    becomes 0.

    Args:
        params: flat trainable angles shared by every sample.
        input_batch: sequence of flat encoding-angle vectors, one per sample.
        numQubits: circuit width.
        numLayers: number of trainable layers.
        true_labels: ground-truth labels.
        config: forwarded unchanged to ``classifier_forward_angles``.

    Returns:
        ``sklearn.metrics.accuracy_score`` of the labels against the rounded scores.
    """
    scores = np.zeros(len(input_batch))
    for position, sample in enumerate(input_batch):
        scores[position] = classifier_forward_angles(sample, params, numQubits, numLayers, config)

    hard_predictions = np.round(scores)
    return accuracy_score(true_labels, hard_predictions)

