import os
import time
import itertools
import numpy as np
import jax
import jax.numpy as jnp
import pennylane as qml
jax.config.update("jax_enable_x64", True)


class VQLSJax:
    """Variational Quantum Linear Solver (VQLS) for A|x> ∝ |b> with PennyLane + JAX.

    A parameter network (``angle_net``) maps each forcing input to ansatz
    angles. Two batched circuits then score the resulting state |x>:

    * a Hadamard test whose observable is sum_l c_l (P_l ⊗ X_ancilla), where
      A = sum_l c_l P_l is the Pauli expansion of ``AA``. Its value is
      sum_l c_l Re<b|P_l|x>, i.e. Re<b|A|x> when the c_l are real;
    * the expectation <x|A^dagger A|x> = ||A|x>||^2.

    Args:
        args: configuration object. Read here: ``n_qubits``, ``n_layers``,
            ``ansatz_name`` ('StronglyEntanglingLayers',
            'BasicEntanglerLayersRX', 'BasicEntanglerLayersRY' or 'QCNN') and
            ``RESULT_PATH`` (directory where the ansatz drawing is saved).
        AA: system matrix of shape (2**n_qubits, 2**n_qubits); it need not be
            Hermitian.
        angle_net: object exposing ``apply({'params': p}, inputs)`` that returns
            the ansatz angles for a batch (presumably a Flax module — confirm).

    Raises:
        NotImplementedError: if ``args.ansatz_name`` is not supported.
    """

    def __init__(self, args, AA, angle_net):
        print(f'VQLS - {args.ansatz_name}')
        # arguments
        self.args = args
        self.n_qubits = args.n_qubits
        self.n_layers = args.n_layers

        # System matrix A and the Hermitian product A^dagger A, whose
        # expectation value gives ||A|x>||^2.
        self.AA = AA
        self.AA_2 = AA.T.conj() @ AA

        # The Hadamard-test ancilla is one extra wire placed after the system qubits.
        self.tot_qubits = self.n_qubits + 1
        self.ancilla_idx = self.n_qubits

        self.angle_net = angle_net

        self.ansatz, self.shape_function, ex_params = self._select_ansatz(args.ansatz_name)
        self._save_ansatz_structure(ex_params)

        # Pauli expansions of A and of A^dagger A (terms sorted by |coefficient|, descending).
        self.A_ham, self.A_coeffs, self.A_terms = self.build_hamiltonian(self.n_qubits, AA)
        self.A_coeffs_jnp = jnp.array(self.A_coeffs)
        self.n_terms = len(self.A_terms)

        self.A_2_ham, self.A_2_coeffs, self.A_2_terms = self.build_hamiltonian(self.n_qubits, self.AA_2)
        self.A_2_coeffs_jnp = jnp.array(self.A_2_coeffs)

        # Hadamard-test observable: every term of A tensored with PauliX on the ancilla.
        self.H_expect = self.append_pauli_x()

        print(f'self.n_terms: {self.n_terms}')

        # settings
        self.dev_hadamard = qml.device("default.qubit", wires=self.tot_qubits)
        self.dev_Ax_norm = qml.device("default.qubit", wires=self.n_qubits)

        self.interface = 'jax'

        # Circuit evaluators vmapped over the leading (batch) axis.
        self.qnodes_real_hadamard = jax.vmap(self.make_overlap_real_mu_qnode(), in_axes=(0, 0))
        self.qnodes_Ax_vec = jax.vmap(self.make_Ax_qnode(), in_axes=(0,))
        self.predict_vmap = jax.vmap(self.predict_qnode(), in_axes=(0,))

    def _select_ansatz(self, ansatz_name):
        """Return ``(ansatz, shape_function, example_params)`` for ``ansatz_name``.

        ``shape_function`` reshapes the network output into the ansatz's
        batched weight layout; ``example_params`` is one all-ones weight set
        used only to draw the circuit.

        Raises:
            NotImplementedError: if ``ansatz_name`` is not supported.
        """
        if ansatz_name == 'StronglyEntanglingLayers':
            ansatz = self.make_StronglyEntanglingLayers()
            shape_function = lambda angle: angle.reshape((-1, self.n_layers, self.n_qubits, 3))
            ex_params = jnp.ones((self.n_layers, self.n_qubits, 3))
        elif ansatz_name in ('BasicEntanglerLayersRX', 'BasicEntanglerLayersRY'):
            ansatz = self.make_BasicEntanglerLayers()
            shape_function = lambda angle: angle.reshape((-1, self.n_layers, self.n_qubits))
            ex_params = jnp.ones((self.n_layers, self.n_qubits))
        elif ansatz_name == 'QCNN':
            ansatz = self.make_qcnn()
            # The network output is passed to the QCNN unchanged.
            shape_function = lambda angle: angle
            # Flat vector of 17 angles per layer (assumed to match src.qcnn.QCNN — confirm).
            ex_params = jnp.ones(17 * self.n_layers)
        else:
            raise NotImplementedError('VQLSJax - WRONG ANSATZ')
        return ansatz, shape_function, ex_params

    def _save_ansatz_structure(self, ex_params):
        """Save the ansatz as text and PNG drawings in ``args.RESULT_PATH``.

        Drawing is optional: any failure is reported and construction continues.
        """
        try:
            circuit_str = qml.draw(self.ansatz, level="device", max_length=2000)(ex_params)
            with open(os.path.join(self.args.RESULT_PATH, "ansatz_structure.txt"), "w") as f:
                f.write(circuit_str)
            # Use `level` here too, matching qml.draw above; the older
            # `expansion_strategy` keyword is deprecated/removed in PennyLane
            # versions that accept `level`.
            fig, _ = qml.draw_mpl(self.ansatz, level="device")(ex_params)
            fig.savefig(os.path.join(self.args.RESULT_PATH, "ansatz_structure.png"), dpi=300, bbox_inches='tight')
        except Exception as err:
            print(f'Cannot save ansatz. ({type(err).__name__}: {err})')

    def build_hamiltonian(self, n, AA):
        """Expand ``AA`` in the n-qubit Pauli basis and return it as a Hamiltonian.

        Args:
            n: number of qubits; ``AA`` must be (2**n, 2**n).
            AA: matrix to decompose (need not be Hermitian).

        Returns:
            ``(A_ham, coeffs, terms)``, with terms sorted by decreasing
            ``|coeff|``. ``A_ham`` has its qubit-wise-commuting ("qwc")
            grouping already computed.
        """
        # PennyLane's pauli_decomposition cannot be used because A may be
        # non-Hermitian, so a custom decomposition is used instead.
        A_ham = self.pseudo_pauli_decomposition(n, AA).operation()
        A_coeffs, A_terms = A_ham.terms()

        sorted_indices = np.argsort(np.abs(A_coeffs))[::-1]
        rearranged_A_coeffs = [A_coeffs[i] for i in sorted_indices]
        rearranged_A_ops = [A_terms[i] for i in sorted_indices]
        # Grouping could be requested in the constructor, but that produced
        # different results, so it is computed explicitly afterwards.
        A_ham = qml.Hamiltonian(rearranged_A_coeffs, rearranged_A_ops)
        A_ham.compute_grouping("qwc")
        return A_ham, rearranged_A_coeffs, rearranged_A_ops

    def append_pauli_x(self,):
        """Return sum_l c_l (P_l ⊗ X_ancilla) built from ``self.A_coeffs`` / ``self.A_terms``.

        Each existing Pauli term is tensored with PauliX on the ancilla wire
        (``self.ancilla_idx``, the last wire). The "qwc" grouping is precomputed.
        """
        new_terms = [term @ qml.PauliX(self.ancilla_idx) for term in self.A_terms]
        H = qml.Hamiltonian(list(self.A_coeffs), new_terms)
        H.compute_grouping("qwc")
        return H

    @staticmethod
    def _pauli_coefficients(n, matrix, tol=1e-12):
        """Return ``{pauli_string: alpha}`` with alpha = Tr(P @ matrix) / 2**n.

        ``pauli_string`` is a tuple of 'I'/'X'/'Y'/'Z' whose k-th symbol acts
        on wire k (the first Kronecker factor is wire 0). Coefficients with
        ``|alpha| <= tol`` are dropped. Works for non-Hermitian matrices, in
        which case the coefficients are complex.

        Raises:
            ValueError: if ``matrix`` is not of shape (2**n, 2**n).
        """
        basis = {
            'I': np.array([[1, 0], [0, 1]], dtype=complex),
            'X': np.array([[0, 1], [1, 0]], dtype=complex),
            'Y': np.array([[0, -1j], [1j, 0]], dtype=complex),
            'Z': np.array([[1, 0], [0, -1]], dtype=complex),
        }
        matrix = np.asarray(matrix)
        dim = 2 ** n
        if matrix.shape != (dim, dim):
            raise ValueError(
                f'expected a ({dim}, {dim}) matrix for {n} qubits, got shape {matrix.shape}'
            )

        coeffs = {}
        for pauli_string in itertools.product('IXYZ', repeat=n):
            P = np.array([[1]], dtype=complex)
            for symbol in pauli_string:
                P = np.kron(P, basis[symbol])
            # Tr(P @ M) = sum_ij P_ij M_ji, computed without the O(dim^3) product.
            alpha = np.einsum('ij,ji->', P, matrix) / dim
            if abs(alpha) > tol:
                coeffs[pauli_string] = alpha
        return coeffs

    def pseudo_pauli_decomposition(self, n: int, matrix: np.ndarray, tol: float = 1e-12) -> "qml.pauli.PauliSentence":
        """Decompose a (possibly non-Hermitian) matrix into a PauliSentence.

        Args:
            n: number of qubits; ``matrix`` must be (2**n, 2**n).
            matrix: matrix to decompose.
            tol: coefficients with magnitude ``<= tol`` are dropped.

        Returns:
            PauliSentence mapping each PauliWord (wire k <- k-th symbol) to its
            complex coefficient Tr(P @ matrix) / 2**n.
        """
        pauli_sentence = qml.pauli.PauliSentence()
        for pauli_string, alpha in self._pauli_coefficients(n, matrix, tol).items():
            pauli_word = qml.pauli.PauliWord(dict(enumerate(pauli_string)))
            pauli_sentence[pauli_word] = alpha
        return pauli_sentence

    def make_BasicEntanglerLayers(self,):
        """Return a BasicEntanglerLayers ansatz taking weights of shape (n_layers, n_qubits).

        The single-qubit rotation follows ``args.ansatz_name``: RX for
        'BasicEntanglerLayersRX', RY otherwise.
        """
        print('VQLS - BasicEntanglerLayers')
        n_qubits = self.n_qubits
        ansatz_name = getattr(self.args, 'ansatz_name', None)
        rotation = qml.RX if ansatz_name == 'BasicEntanglerLayersRX' else qml.RY

        def ansatz(weights):
            qml.BasicEntanglerLayers(weights, wires=range(n_qubits), rotation=rotation)
        return ansatz

    def make_StronglyEntanglingLayers(self,):
        """Return a StronglyEntanglingLayers ansatz on the ``n_qubits`` system wires."""
        n_qubits = self.n_qubits

        def ansatz(weights):
            qml.StronglyEntanglingLayers(weights=weights, wires=range(n_qubits))
        return ansatz

    def make_qcnn(self,):
        """Return the QCNN ansatz produced by ``src.qcnn.QCNN(self.args).make_qcnn()``."""
        print('VQLS - QCNN')
        # Imported here, so the project module is only required when QCNN is selected.
        from src.qcnn import QCNN
        return QCNN(self.args).make_qcnn()

    def make_U_b(self,):
        """Return ``U_b(b_vec)``, which amplitude-embeds (and normalises) b on the system wires."""
        n_qubits = self.n_qubits

        def U_b(b_vec):
            qml.AmplitudeEmbedding(b_vec, wires=range(n_qubits), normalize=True)
        return U_b

    def make_ctrl_l(self,):
        """Return ``ctrl_l(term_l)``, which applies ``term_l`` controlled on the ancilla.

        Not called by the other methods of this class.
        """
        ancilla_idx = self.ancilla_idx

        def ctrl_l(term_l):
            qml.ctrl(term_l, control=ancilla_idx)
        return ctrl_l

    def make_overlap_real_mu_qnode(self):
        """Build the Hadamard-test QNode ``(b_vec, angles) -> <H_expect>``.

        The circuit prepares (|0>|b> + |1>|x>)/sqrt(2): it applies a Hadamard
        on the ancilla, the ansatz controlled on ancilla=1, then U_b controlled
        on ancilla=0. It then measures ``self.H_expect``.
        """
        dev = self.dev_hadamard
        interface = self.interface
        ancilla_idx = self.ancilla_idx
        H_expect = self.H_expect
        ansatz = self.ansatz
        U_b = self.make_U_b()

        @qml.qnode(dev, interface=interface)
        def overlap_real_mu_qnode(b_vec, angles):
            qml.Hadamard(wires=ancilla_idx)
            qml.ctrl(ansatz, control=ancilla_idx)(angles)
            qml.ctrl(U_b, control=ancilla_idx, control_values=0)(b_vec)
            return qml.expval(H_expect)
        return overlap_real_mu_qnode

    def make_Ax_qnode(self):
        """Build the QNode ``angles -> <x|A^dagger A|x>`` (= ||A|x>||^2) on the system wires."""
        dev = self.dev_Ax_norm
        interface = self.interface
        ansatz = self.ansatz
        A_2_ham = self.A_2_ham

        @qml.qnode(dev, interface=interface)
        def Ax_qnode(angles):
            ansatz(angles)
            return qml.expval(A_2_ham)
        return Ax_qnode

    def make_cost_loc(self,):
        """Build ``cost_loc(nn_params, batch_forcing, batch_RHS) -> scalar``.

        For each sample, ``numer`` is the real part of the Hadamard-test value
        and ``norm = sqrt(|<x|A^dagger A|x>|)`` = ||A|x>||. The cost is the
        batch mean of ``(numer - norm)**2``.
        """
        qnodes_real_hadamard = self.qnodes_real_hadamard
        qnodes_Ax_vec = self.qnodes_Ax_vec
        angle_net = self.angle_net
        shape_function = self.shape_function

        def cost_loc(nn_params, batch_forcing, batch_RHS):
            angle = shape_function(angle_net.apply({'params': nn_params}, batch_forcing))
            numer_real = jnp.real(qnodes_real_hadamard(batch_RHS, angle))
            # abs keeps the sqrt argument real and non-negative.
            denom_norm = jnp.sqrt(jnp.abs(qnodes_Ax_vec(angle)))
            return jnp.mean((numer_real - denom_norm) ** 2)
        return cost_loc

    def predict_qnode(self):
        """Build the QNode ``angles -> state vector`` of the ansatz on the system wires."""
        dev = self.dev_Ax_norm
        interface = self.interface
        ansatz = self.ansatz

        @qml.qnode(dev, interface=interface)
        def circuit(angles):
            ansatz(angles)
            return qml.state()
        return circuit

    def make_predict(self,):
        """Build a jitted ``predict(nn_params, batch_forcing, batch_RHS) -> real array``.

        Each normalised ansatz state x is rescaled by ||b|| / ||A x||, so that
        A @ alpha has the same norm as b. The result is then made real: every
        amplitude keeps its magnitude and takes the sign of cos(phase), i.e. a
        positive sign for a positive real part and a negative sign for a
        negative one.
        """
        angle_net = self.angle_net
        AA = self.AA  # (2**n_qubits, 2**n_qubits) system matrix
        predict_qnode = self.predict_vmap
        shape_function = self.shape_function

        @jax.jit
        def predict(nn_params, batch_forcing, batch_RHS):
            angle = shape_function(angle_net.apply({'params': nn_params}, batch_forcing))
            raw_state = predict_qnode(angle)  # one state vector per batch row

            # Row-wise A @ x for the whole batch: (A @ x_i) == (x_i @ A.T).
            Ax_result = raw_state @ AA.T

            numerator = jnp.linalg.norm(batch_RHS, axis=-1).reshape(-1, 1)
            denominator = jnp.linalg.norm(Ax_result, axis=-1).reshape(-1, 1)
            alpha_predict = raw_state * (numerator / denominator)

            return jnp.sign(jnp.cos(jnp.angle(alpha_predict))) * jnp.abs(alpha_predict)
        return predict