import numpy as np
import matplotlib.pyplot as plt

from qiskit.quantum_info import Statevector, Pauli
from qiskit import QuantumCircuit, transpile
from qiskit_aer import StatevectorSimulator, AerSimulator
from qiskit.circuit.library import StatePreparation

def integrated_grad(inputs, baseline, model, numSteps):
    """Integrated-gradients attribution of ``model`` using finite-difference gradients.

    The gradient from ``num_grad_calc`` is averaged over ``numSteps`` points on
    the straight line from ``baseline`` towards ``inputs``. The path fractions
    are 0, 1/numSteps, ..., (numSteps-1)/numSteps, i.e. a left Riemann sum.
    The average is then scaled element-wise by ``inputs - baseline``.

    Args:
        inputs: point to explain (array supporting numpy arithmetic).
        baseline: reference point, same shape as ``inputs``.
        model: callable evaluated by ``num_grad_calc`` (must return a scalar).
        numSteps: number of samples along the path.

    Returns:
        float64 array of per-feature attributions, shaped like ``inputs``.
    """
    direction = inputs - baseline
    avg_grad = np.zeros_like(inputs, dtype=np.float64)

    for step in range(numSteps):
        fraction = step / numSteps
        point = baseline + fraction * direction
        # each sample contributes an equal 1/numSteps share of the path integral
        avg_grad += num_grad_calc(point, model) / numSteps

    return direction * avg_grad

def native_integrated_grad(inputs, baseline, circuit_func, numSteps, numQubits, config):
    """Integrated-gradients attribution using exact statevector gradients.

    At each of ``numSteps`` points ``baseline + (k / numSteps) * (inputs - baseline)``,
    for k = 0 .. numSteps-1, the gradient is obtained from
    ``quantum_native_grad_calc``. The gradients are averaged and the average is
    multiplied element-wise by ``inputs - baseline``.

    Args:
        inputs: point to explain (array supporting numpy arithmetic).
        baseline: reference point, same shape as ``inputs``.
        circuit_func: callable building the model circuit for a given point.
        numSteps: number of samples along the path.
        numQubits: qubit count of the model circuit.
        config: model configuration dict forwarded to the gradient routine.

    Returns:
        float64 array of per-feature attributions, shaped like ``inputs``.
    """
    direction = inputs - baseline
    avg_grad = np.zeros_like(inputs, dtype=np.float64)

    for step in range(numSteps):
        fraction = step / numSteps
        point = baseline + fraction * direction
        step_grad = quantum_native_grad_calc(point, circuit_func, numQubits, config)
        # equal weight for every sample along the path
        avg_grad += step_grad / numSteps

    return direction * avg_grad

def hadamard_integrated_grad(inputs, baseline, circuit_func, numSteps, numQubits, config, num_shots=1000):
    """Integrated-gradients attribution using shot-based Hadamard-test gradients.

    At each of ``numSteps`` points ``baseline + (k / numSteps) * (inputs - baseline)``,
    for k = 0 .. numSteps-1, the gradient is estimated with
    ``quantum_hadamard_grad_calc``. The gradients are averaged and the average
    is multiplied element-wise by ``inputs - baseline``.

    Args:
        inputs: point to explain (array supporting numpy arithmetic).
        baseline: reference point, same shape as ``inputs``.
        circuit_func: callable building the model circuit for a given point.
        numSteps: number of samples along the path.
        numQubits: qubit count of the model circuit.
        config: model configuration dict forwarded to the gradient routine.
        num_shots: shots per Hadamard-test circuit. This is forwarded to the
            gradient routine; previously it was accepted but silently ignored.

    Returns:
        float64 array of per-feature attributions, shaped like ``inputs``.
    """
    attributions = np.zeros_like(inputs, dtype=np.float64)

    for k in range(numSteps):

        currentPoint = baseline + k/numSteps * (inputs-baseline)
        grads = quantum_hadamard_grad_calc(currentPoint, circuit_func, numQubits, config, num_shots=num_shots)
        attributions += ( grads * 1/numSteps )

    attributions = (inputs - baseline) * attributions

    return attributions


def num_grad_calc(x, model, stepSize=0.0001):
    """Approximate the gradient of ``model`` at ``x`` with forward differences.

    Component i is ``(model(x + stepSize * e_i) - model(x)) / stepSize``, where
    ``e_i`` is the i-th row of ``np.eye(len(x))``. ``x`` is therefore expected
    to be 1-D.

    Args:
        x: point at which to differentiate.
        model: callable returning a scalar for a point like ``x``.
        stepSize: forward-difference step.

    Returns:
        float64 array shaped like ``x``.
    """
    f_at_x = model(x)
    gradient = np.empty_like(x, dtype=np.float64)
    unit_vectors = np.eye(len(x))

    for axis in range(len(x)):
        f_shifted = model(x + stepSize * unit_vectors[axis])
        gradient[axis] = (f_shifted - f_at_x) / stepSize

    return gradient

# computes the gradient of quantum model, with respect to input features x
def quantum_native_grad_calc(x, circuit:QuantumCircuit, numQubits, config):
    """Exact (statevector) gradient of the quantum model w.r.t. the input features.

    The model amplitude-encodes ``x`` with ``initialize(..., normalize=True)``,
    applies ``U = circuit(x)`` and reads ``expVal = (1 - <Z_m>) / 2``, where
    ``Z_m`` is Pauli-Z on ``config["measured_qubit"]``. The raw gradient is
    ``2 * Re(U^dagger Z_m U |x>) * (-1/2)``. It is obtained by simulating the
    circuit initialize -> U -> Z_m -> U^dagger. The derivative of the
    normalization step itself is not included.

    Args:
        x: input feature vector used as state amplitudes (presumably real and
            of length ``2**numQubits`` -- confirm against callers).
        circuit: callable returning the model ``QuantumCircuit`` for ``x``
            (annotated as ``QuantumCircuit`` but invoked as ``circuit(x)``).
        numQubits: number of qubits of the model circuit.
        config: dict with keys ``"num_qubits"``, ``"measured_qubit"``,
            ``"dataset"`` and ``"activation"`` (``"tanh"`` or ``"identity"``).

    Returns:
        numpy array with one gradient entry per statevector amplitude.

    Raises:
        ValueError: if ``config["activation"]`` is not supported. Previously an
            unsupported value surfaced as ``UnboundLocalError`` after the
            simulation had already run.
    """
    activation = config["activation"]
    if activation not in ("tanh", "identity"):
        raise ValueError(f"unsupported activation: {activation!r}")

    qc_operations = circuit(x)
    qc_operations_inverse = qc_operations.inverse()

    # U|x> with x amplitude-encoded
    qc = QuantumCircuit(numQubits)
    qc.initialize(x, normalize=True)
    qc.compose(qc_operations, inplace=True)

    s = Statevector.from_instruction(qc)

    # Qiskit Pauli labels put qubit 0 rightmost, hence the reversal below so
    # that "Z" acts on config["measured_qubit"].
    pauliString = ["I"] * config["num_qubits"]
    pauliString[config["measured_qubit"]] = "Z"

    # compute expVal for quantum model function. expectation_value is typed as
    # complex; keep only the real part so the returned gradients stay real-valued.
    expVal = float(np.real(
        (1 - s.expectation_value(Pauli(''.join(pauliString)[::-1]))) / 2.0
    ))

    # Extend to U^dagger Z_m U |x>. Twice its real part is d<Z_m>/dx
    # (normalization ignored).
    qc.z(config['measured_qubit'])
    qc.compose(qc_operations_inverse, inplace=True)

    circuit_grads = 2*np.real(Statevector.from_instruction(qc).data) * (-1/2) # -1/2 is from the expVal expression above

    # Adjust for the overflow state. The gradient w.r.t. the overflow entry
    # (x[0] for NIST/bars, x[-1] for MNIST/Fashion) is folded into the feature
    # entries through the factor -x_j / sqrt(x_overflow).
    # NOTE(review): that factor equals d/dx_j sqrt(1 - sum_k x_k^2) with
    # x_overflow standing in for 1 - sum_k x_k^2. Confirm this matches how the
    # overflow entry is built upstream; an overflow entry of 0 divides by zero.
    if config["dataset"] == "NIST" or config["dataset"]=="bars":
        circuit_grads[1:] = circuit_grads[1:] + circuit_grads[0] * 1/2*(x[0])**(-1/2)*(-2)*x[1:]
    elif config["dataset"] == "MNIST" or config["dataset"]=="Fashion":
        circuit_grads[0:784] = circuit_grads[0:784] + circuit_grads[-1] * 1/2*(x[-1])**(-1/2)*(-2)*x[0:784]

    # Chain rule for the activation. The tanh factor 5 * sech^2(10 * (expVal - 1/2))
    # is the derivative of tanh(10 * (expVal - 1/2)) / 2 (presumably the model's
    # output activation -- confirm against the model definition).
    if activation == "tanh":
        grads = 1/2 * (1 / np.cosh( 10 * (expVal - 1/2)))**2 * 10 * circuit_grads
    else:  # "identity"
        grads = circuit_grads

    return grads


# computes the gradient of quantum model, with respect to input features x
def quantum_hadamard_grad_calc(x, circuit:QuantumCircuit, numQubits, config, num_shots=1000):
    """Shot-based gradient of the quantum model w.r.t. the input features.

    This estimates the same raw gradient as ``quantum_native_grad_calc``,
    ``-Re(<i| U^dagger Z_m U |x>)`` for every basis state ``i``. Each entry
    comes from a Hadamard test on an ancilla (qubit 0):

    - H on the ancilla.
    - Controlled on |1>: prepare |x>. Controlled on |0>: prepare |i>.
    - Controlled on |1>: apply ``U^dagger Z_m U``.
    - H on the ancilla, then measure it.

    Ideally ``P(0) = (1 + Re<i|U^dagger Z_m U|x>) / 2``. The model output
    ``expVal`` used by the activation chain rule is computed exactly from the
    statevector.

    Args:
        x: input feature vector used as state amplitudes (presumably real and
            of length ``2**numQubits`` -- confirm against callers).
        circuit: callable returning the model ``QuantumCircuit`` for ``x``
            (annotated as ``QuantumCircuit`` but invoked as ``circuit(x)``).
        numQubits: number of qubits of the model circuit.
        config: dict with keys ``"num_qubits"``, ``"measured_qubit"``,
            ``"activation"`` (``"tanh"`` or ``"identity"``) and optionally
            ``"dataset"``.
        num_shots: shots per Hadamard-test circuit.

    Returns:
        numpy array of length ``2**numQubits`` with the estimated gradient.

    Raises:
        ValueError: if ``config["activation"]`` is not supported.
    """
    activation = config["activation"]
    if activation not in ("tanh", "identity"):
        # fail fast instead of running 2**numQubits circuits and then hitting
        # an unbound ``grads``
        raise ValueError(f"unsupported activation: {activation!r}")

    backend = AerSimulator()

    qc_operations = circuit(x)
    qc_operations_inverse = qc_operations.inverse()

    # circuit for computing the expectation values so that we can feed them to chain rule
    expval_circuit = QuantumCircuit(numQubits, 1)
    expval_circuit.initialize(x, normalize=True)
    expval_circuit.compose(qc_operations, inplace=True)

    s = Statevector.from_instruction(expval_circuit)
    # Qiskit Pauli labels put qubit 0 rightmost, hence the reversal below.
    pauliString = ["I"] * config["num_qubits"]
    pauliString[config["measured_qubit"]] = "Z"

    # compute expVal for quantum model function. Keep only the real part so
    # the returned gradients stay real-valued.
    expVal = float(np.real(
        (1 - s.expectation_value(Pauli(''.join(pauliString)[::-1]))) / 2.0
    ))

    # W = U^dagger Z_m U, applied controlled on the ancilla in the Hadamard test
    subcircuit = QuantumCircuit(numQubits)
    subcircuit.compose(qc_operations, inplace=True)
    subcircuit.z(config['measured_qubit'])
    subcircuit.compose(qc_operations_inverse, inplace=True)
    controlled_grad_gate = subcircuit.to_gate().control(1)

    # ancilla |1> branch prepares the (normalized) input state |x>
    controlled_state_prep = StatePreparation(x, normalize=True).control()

    dim = 2**numQubits
    circuit_grads = np.zeros(dim)

    for i in range(dim):
        # ancilla |0> branch prepares the computational basis state |i>
        basis_state = [1 if j == i else 0 for j in range(dim)]
        controlled_basis_prep = StatePreparation(basis_state).control(ctrl_state=0)

        full_circuit = QuantumCircuit(numQubits+1, 1)
        full_circuit.h(0)
        full_circuit.append(controlled_state_prep, range(numQubits+1))
        full_circuit.append(controlled_basis_prep, range(numQubits+1))
        full_circuit.append(controlled_grad_gate, range(numQubits+1))
        full_circuit.h(0)
        full_circuit.measure(0, 0)

        transpiled_qc = transpile(full_circuit, backend)
        counts = backend.run(transpiled_qc, shots=num_shots).result().get_counts()

        # 4 * P(0) - 2 = 2 * Re<i|W|x>. Outcomes never observed are absent from
        # ``counts``, so a missing "0" means P(0) = 0 rather than a KeyError.
        hadamard_result = (counts.get("0", 0) / num_shots) * 4 - 2
        circuit_grads[i] = hadamard_result * -1/2  # -1/2 is d(expVal)/d<Z_m>

    # Adjust for the overflow state, matching quantum_native_grad_calc. The
    # gradient w.r.t. the overflow entry (x[0] for NIST/bars, x[-1] for
    # MNIST/Fashion) is folded into the feature entries via -x_j / sqrt(x_overflow).
    # NOTE(review): confirm this factor matches how the overflow entry is built
    # upstream. config.get keeps configs without "dataset" working as before.
    dataset = config.get("dataset")
    if dataset == "NIST" or dataset == "bars":
        circuit_grads[1:] = circuit_grads[1:] + circuit_grads[0] * 1/2*(x[0])**(-1/2)*(-2)*x[1:]
    elif dataset == "MNIST" or dataset == "Fashion":
        circuit_grads[0:784] = circuit_grads[0:784] + circuit_grads[-1] * 1/2*(x[-1])**(-1/2)*(-2)*x[0:784]

    # Chain rule for the activation: 5 * sech^2(10 * (expVal - 1/2)) is the
    # derivative of tanh(10 * (expVal - 1/2)) / 2 (presumably the model's
    # output activation -- confirm against the model definition).
    if activation == "tanh":
        grads = 1/2 * (1 / np.cosh( 10 * (expVal - 1/2)))**2 * 10 * circuit_grads
    else:  # "identity"
        grads = circuit_grads

    return grads