import unittest
import data_processing
import cirq
import os
import numpy as np

from constants import N_QUBITS, REFERENCE
from gates import controlled_x
from cirq import Simulator
from cirq.contrib.svg import SVGCircuit


class ComponentsTest(unittest.TestCase):
    """Smoke tests for the image-encoding circuit and the multi-controlled X gate."""

    # Absolute tolerance for comparing simulated amplitudes / unitaries with stored
    # references. cirq.Simulator defaults to complex64, so exact-zero reference entries
    # cannot be matched with assert_allclose's default atol=0.
    _ATOL = 1e-6

    @staticmethod
    def _simulate(circuit, qubits):
        """Simulate ``circuit`` from the all-zeros state and return the final state vector.

        ``qubits`` fixes the qubit order, i.e. which qubit is most significant in the
        returned vector.
        """
        result = Simulator().simulate(circuit, qubit_order=qubits, initial_state=0)
        return result.final_state_vector

    @staticmethod
    def _scaled_magnitudes(state, n_qubits):
        """Return ``|Re(state)| * 2**(n_qubits / 2)`` rounded to 3 decimals, for printing.

        Only the real part is kept, matching the previous ``astype(float)`` cast of the
        complex state vector; imaginary components are ignored in this debug output.
        """
        real_part = np.asarray(state).real.astype(float)
        return np.abs(np.around(real_part * np.power(2, n_qubits / 2), 3))

    def print_info(self, out_state_rep, num):
        """Print a state-vector summary, two entries per line.

        Prints ``out_state_rep`` as-is, then one line per consecutive pair of entries:
        the pair index as a zero-padded ``num``-bit binary string, followed by both
        entries truncated to ``int``. cirq orders basis states big-endian, so each pair
        differs only in the last qubit of the simulation's qubit order.

        Args:
            out_state_rep: flat sequence of (already scaled) amplitudes; must have even length.
            num: bit width used to label each pair.
        """
        print(out_state_rep)
        width = f'0{num}b'
        for pair_index, start in enumerate(range(0, len(out_state_rep), 2)):
            first, second = out_state_rep[start], out_state_rep[start + 1]
            print(f'{pair_index:{width}} x ({int(first)},{int(second)})')

    def testing_image_to_circuit(self):
        """Encode an all-ones image and compare the circuit's unitary with a stored reference."""
        image = np.ones(2**N_QUBITS)
        circuit, qubits = data_processing.circuit_from_image(image, N_QUBITS, return_qubits=True)
        print(circuit)

        out_state = self._simulate(circuit, qubits)
        self.print_info(self._scaled_magnitudes(out_state, N_QUBITS), N_QUBITS)

        # The path is resolved against the working directory, so the suite is expected to
        # run from the directory that contains ``tests/``.
        ref_path = os.path.join(os.getcwd(), 'tests', 'testing_image_to_circuit.txt')
        # The file holds real numbers; .view(complex) reinterprets each consecutive
        # (real, imag) pair of float64 values as one complex entry.
        reference = np.loadtxt(ref_path).view(complex)
        # NOTE(review): circuit.unitary() uses cirq's default (sorted) qubit order, not
        # ``qubits``; confirm the reference was generated with the same ordering.
        np.testing.assert_allclose(reference, circuit.unitary(), atol=self._ATOL)

    def testing_controlledX(self):
        """Set every control qubit to |1>, apply a multi-controlled X, and compare to REFERENCE.

        The same circuit is built twice: once with cirq's ``X.controlled`` and once with
        the project's ``gates.controlled_x``. Both constructions must produce REFERENCE
        (presumably the |1...1> state vector; see ``constants``).
        """
        n_q = 4
        # n_q - 1 control qubits stacked in a column, followed by one target qubit.
        qubits = cirq.GridQubit.rect(n_q - 1, 1)
        qubits.append(cirq.GridQubit(n_q, 1))

        constructions = (
            ('cirq X.controlled', lambda: cirq.X.controlled(n_q - 1)(*qubits)),
            ('gates.controlled_x', lambda: controlled_x(qubits)),
        )
        for label, make_controlled_x in constructions:
            with self.subTest(label):
                circuit = cirq.Circuit()
                # Flip every control so the controlled X is guaranteed to fire.
                circuit.append([cirq.X(q) for q in qubits[:-1]])
                circuit.append(make_controlled_x())
                print(circuit)

                out_state = self._simulate(circuit, qubits)
                self.print_info(self._scaled_magnitudes(out_state, n_q), n_q - 1)
                np.testing.assert_allclose(REFERENCE, out_state, atol=self._ATOL)

if __name__ == "__main__":
    # Run this module's tests when the file is executed directly as a script.
    unittest.main()
