import cirq
import sympy
import tensorflow as tf
import num2words.lang_EN


class CircuitLayerBuilder:
    '''
    Append parameterised two-qubit gate layers coupling every pixel qubit to
    both the output qubit and the color qubit.

    Each gate is raised to a sympy symbol named ``prefix`` followed by an
    integer index, so its exponent can be resolved later (presumably as a
    trainable parameter by the caller).
    '''

    def __init__(self, pixel_qubits, color_qubit, output_qubit):
        '''
        :param pixel_qubits: sized, iterable collection of qubits; symbols are
            allocated per entry
        :param color_qubit: qubit that every pixel qubit is coupled to
        :param output_qubit: qubit that every pixel qubit is coupled to
        '''
        self.pixel_qubits = pixel_qubits
        self.color_qubit = color_qubit
        self.output_qubit = output_qubit

    @staticmethod
    def _make_symbols(prefix, count):
        '''Return the ``count`` sympy symbols ``prefix0`` ... ``prefix{count-1}``.'''
        return sympy.symbols(prefix + '0:{}'.format(count))

    def add_layer(self, circuit, gate, prefix):
        '''
        Append ``gate`` between each pixel qubit and the output qubit, then
        between the same pixel qubit and the color qubit.

        Both gates for pixel ``n`` share the symbol ``prefix + str(n)``.
        This is exactly ``add_mixed_layer(circuit, gate, gate, prefix)``.

        :param circuit: cirq.Circuit to append to (modified in place)
        :param gate: two-qubit gate such as cirq.XX or cirq.ZZ
        :param prefix: symbol-name prefix for this layer
        '''
        self.add_mixed_layer(circuit, gate, gate, prefix)

    def add_mixed_layer(self, circuit, gate1, gate2, prefix):
        '''
        Append ``gate1`` between each pixel qubit and the output qubit, then
        ``gate2`` between the same pixel qubit and the color qubit.

        Both gates for pixel ``n`` share the symbol ``prefix + str(n)``.

        :param circuit: cirq.Circuit to append to (modified in place)
        :param gate1: two-qubit gate applied towards the output qubit
        :param gate2: two-qubit gate applied towards the color qubit
        :param prefix: symbol-name prefix for this layer
        '''
        symbols = self._make_symbols(prefix, len(self.pixel_qubits))
        for n, qubit in enumerate(self.pixel_qubits):
            circuit.append(gate1(qubit, self.output_qubit) ** symbols[n])
            circuit.append(gate2(qubit, self.color_qubit) ** symbols[n])

    def add_mixed_layer2(self, circuit, gate1, gate2, prefix):
        '''
        For each pixel qubit, append ``gate1`` towards the output and color
        qubits, then ``gate2`` towards the output and color qubits.

        Pixel ``n`` uses symbol index ``2n`` for both ``gate1`` operations and
        ``2n + 1`` for both ``gate2`` operations, so each pixel has two symbols.

        :param circuit: cirq.Circuit to append to (modified in place)
        :param gate1: first two-qubit gate (e.g. cirq.XX)
        :param gate2: second two-qubit gate (e.g. cirq.ZZ)
        :param prefix: symbol-name prefix for this layer
        '''
        symbols = self._make_symbols(prefix, 2 * len(self.pixel_qubits))
        for n, qubit in enumerate(self.pixel_qubits):
            # Order matters for moment placement: gate1 (output, color),
            # then gate2 (output, color).
            for gate, symbol in ((gate1, symbols[2 * n]),
                                 (gate2, symbols[2 * n + 1])):
                circuit.append(gate(qubit, self.output_qubit) ** symbol)
                circuit.append(gate(qubit, self.color_qubit) ** symbol)

def controlled_x(qubits, exponent=1.):
    '''
    Yield operations implementing a multi-controlled ``X ** exponent`` gate.

    The last qubit is the target and every preceding qubit is a control.
    With one control the result is a single ``CNOT ** exponent``. With more
    controls the gate is built recursively, using V = X ** (exponent / 2):

    1. controlled-V from the last control to the target;
    2. a multi-controlled X from the other controls onto the last control;
    3. controlled-V^-1 from the last control to the target;
    4. the same multi-controlled X again (undoing step 2);
    5. a multi-controlled V from the other controls to the target.

    Each level makes three recursive calls on one fewer qubit, so the number
    of emitted operations grows exponentially with the number of controls.

    :param qubits: sequence of qubits ``[control_1, ..., control_k, target]``
    :param exponent: exponent of the controlled X gate
    :return: generator of cirq operations
    :raises ValueError: (on first iteration) if fewer than 2 qubits are given
    '''
    # Accept any sequence; the recursion below concatenates lists.
    qubits = list(qubits)
    if len(qubits) < 2:
        raise ValueError(
            'controlled_x needs at least one control and one target qubit, '
            'got {} qubit(s)'.format(len(qubits)))

    if len(qubits) == 2:
        # Base case: a single (possibly fractional) CNOT.
        yield cirq.CNOT(*qubits) ** exponent
        # This return is required; without it the generator would fall
        # through into the recursive decomposition below.
        return

    *other_controls, last_control, target = qubits
    half = exponent / 2.
    yield cirq.CNOT(last_control, target) ** half
    yield from controlled_x(other_controls + [last_control], 1.)
    yield cirq.CNOT(last_control, target) ** (-half)
    yield from controlled_x(other_controls + [last_control], 1.)
    yield from controlled_x(other_controls + [target], half)


def accuracy(y_true, y_pred):
    '''
    Compute the fraction of samples whose predicted label agrees in sign with
    the true label.

    Both inputs are squeezed (size-1 dimensions removed) and thresholded at
    zero. A value > 0 counts as the positive class; anything else, including
    exactly 0, counts as the negative class. Presumably labels are encoded as
    -1/+1 (TODO confirm with callers).

    TODO: Replace with tensorflow method

    :param y_true: True data labels
    :param y_pred: Predicted data labels (e.g. expectation values); assumed
        broadcast-compatible with y_true after squeezing
    :return: scalar float32 tensor, the mean agreement in [0, 1]
    '''
    true_positive = tf.squeeze(y_true) > 0.0
    pred_positive = tf.squeeze(y_pred) > 0.0
    # Use tf.equal rather than `==`. `==` on tensors is element-wise only when
    # TF2 tensor equality is enabled; with identity-based equality it returns a
    # Python bool and the "accuracy" would silently be 0 or 1.
    matches = tf.cast(tf.equal(true_positive, pred_positive), tf.float32)
    return tf.reduce_mean(matches)


def _layer_tag(layer):
    '''
    Spell a layer index in words for use inside a sympy symbol prefix.

    Words rather than digits keep symbol names unique. With digits, layer 1 /
    pixel 10 and layer 11 / pixel 0 would both be named 'xx110'; with words
    they are 'xxone10' and 'xxeleven0'.

    For indices >= 100, num2words output contains spaces (e.g. 'one hundred'),
    and for larger values commas. sympy.symbols treats both as separators
    between symbol names, so spaces become '_' and commas are dropped.
    Indices 0-99 are returned unchanged.
    '''
    words = num2words.num2words(layer)
    return words.replace(',', '').replace(' ', '_')


def create_model(n_qubits, n_layers, network_type='CRADL'):
    '''
    Create the parameterised quantum circuit and its readout operator.

    TODO: Should this be called create_circuit?

    The circuit acts on ``n_qubits`` pixel qubits, one color qubit and one
    output qubit. It is built as follows:

    1. X then H are applied to the output qubit.
    2. ``n_layers // 2`` iterations of entangling layers are added. An odd
       ``n_layers`` is rounded down.
    3. A final H is applied to the output qubit.

    Each iteration of step 2 adds, depending on ``network_type``:

    * ``'CRADL'``: an XX layer, then a ZZ layer. Each gate couples a pixel
      qubit to the output and color qubits.
    * ``'CRADML'``: a mixed layer (XX to output, ZZ to color), then a mixed
      layer (ZZ to output, XX to color).
    * ``'CRAML'``: one layer applying XX and then ZZ from each pixel qubit to
      both the output and color qubits, with two symbols per pixel.

    :param n_qubits: Number of qubits that will be provided as input to the circuit
    :param n_layers: Number of layers deep to make the circuit
    :param network_type: Type of network: 'CRADL', 'CRADML' or 'CRAML'
    :return: quantum circuit, and Z(output_qubit)
    :raises ValueError: if ``network_type`` is not recognised
    '''
    if network_type not in ('CRADL', 'CRADML', 'CRAML'):
        raise ValueError('unrecognized network type {}'.format(network_type))

    pixel_qubits = cirq.GridQubit.rect(n_qubits, 1)
    color_qubit = cirq.GridQubit(-1, -1)
    output_qubit = cirq.GridQubit(-2, -2)
    circuit = cirq.Circuit()

    circuit.append(cirq.X(output_qubit))
    circuit.append(cirq.H(output_qubit))

    builder = CircuitLayerBuilder(pixel_qubits, color_qubit, output_qubit)
    for layer in range(n_layers // 2):
        tag = _layer_tag(layer)
        if network_type == 'CRADL':
            builder.add_layer(circuit, cirq.XX, 'xx' + tag)
            builder.add_layer(circuit, cirq.ZZ, 'zz' + tag)
        elif network_type == 'CRADML':
            builder.add_mixed_layer(circuit, cirq.XX, cirq.ZZ, 'xxzz' + tag)
            builder.add_mixed_layer(circuit, cirq.ZZ, cirq.XX, 'zzxx' + tag)
        else:  # 'CRAML', the only remaining type after validation above
            builder.add_mixed_layer2(circuit, cirq.XX, cirq.ZZ, 'xxzz' + tag)

    circuit.append(cirq.H(output_qubit))

    return circuit, cirq.Z(output_qubit)
