import os
import time
import itertools
import numpy as np
import jax
import jax.numpy as jnp
import pennylane as qml
jax.config.update("jax_enable_x64", True)


class VQLSJax:
    """Circuits, cost and prediction helpers for a VQLS-style solver in JAX.

    The constructor decomposes the system matrix ``AA`` and ``AA^dagger AA``
    into Pauli terms, selects an ansatz by ``args.ansatz_name``, and builds
    batched (``jax.vmap``) QNodes on ``default.qubit`` devices. The ansatz
    angles are produced by ``angle_net``; see ``make_cost_loc`` and
    ``make_predict`` for how they are consumed.
    """

    def __init__(self, args, AA, angle_net):
        """Build the Hamiltonians, devices and batched QNodes.

        Args:
            args: Configuration object. This class reads ``key``,
                ``n_qubits``, ``n_layers``, ``ansatz_name`` and
                ``RESULT_PATH`` from it; it is also handed to the QCNN
                builders.
            AA: System matrix. It is multiplied with ``n_qubits``-qubit Pauli
                strings, so it is expected to be ``2**n_qubits`` square.
            angle_net: Object exposing ``apply({'params': ...}, batch)`` that
                returns the ansatz angles (presumably a Flax module - confirm
                against the caller).

        Raises:
            NotImplementedError: If ``args.ansatz_name`` is not one of the
                supported names.
        """
        print('VQLS - StronglyEntanglingLayers')
        # arguments
        self.args = args
        self.key = args.key
        self.n_qubits = args.n_qubits
        self.n_layers = args.n_layers

        self.AA = AA
        self.AA_2 = (AA.T.conj() @ AA)

        # One extra wire, placed after the system wires, is used as ancilla.
        self.tot_qubits = self.n_qubits + 1
        self.ancilla_idx = self.n_qubits

        self.angle_net = angle_net

        # Pick the ansatz, the function that reshapes the network output into
        # the layout the ansatz expects (leading batch axis kept), and one
        # un-batched example parameter set used only for drawing the circuit.
        if args.ansatz_name == 'BasicEntanglerLayers':
            self.ansatz = self.make_BasicEntanglerLayers()
            self.shape_function = lambda angle: angle.reshape((-1, self.n_layers, self.n_qubits))
            ex_params = jnp.ones((self.n_layers, self.n_qubits))
        elif args.ansatz_name == 'StronglyEntanglingLayers':
            self.ansatz = self.make_StronglyEntanglingLayers()
            self.shape_function = lambda angle: angle.reshape((-1, self.n_layers, self.n_qubits, 3))
            ex_params = jnp.ones((self.n_layers, self.n_qubits, 3))
        elif args.ansatz_name == 'QCNN':
            self.ansatz = self.make_qcnn()
            self.shape_function = lambda angle: angle
            # Parameter count is dictated by src.qcnn.QCNN (not visible here).
            ex_params = jnp.ones(17 * self.n_layers,)
        elif args.ansatz_name == 'QCNNNN':
            self.shape_function = lambda angle: angle
            self.ansatz = self.make_qcnnnn()
            # Parameter count is dictated by src.qcnn.QCNNNN (not visible here).
            ex_params = jnp.ones((17+3 * self.n_qubits) * self.n_layers,)
        else:
            raise NotImplementedError('VQLSJax - WRONG ANSATZ')

        # Ansatz: save a text and a PNG rendering (best effort).
        self._save_ansatz_structure(ex_params)

        self.A_ham, self.A_coeffs, self.A_terms = self.build_hamiltonian(self.n_qubits, AA)
        self.A_coeffs_jnp = jnp.array(self.A_coeffs)
        self.n_terms = len(self.A_terms)

        self.A_2_ham, self.A_2_coeffs, self.A_2_terms = self.build_hamiltonian(self.n_qubits, self.AA_2)
        self.A_2_coeffs_jnp = jnp.array(self.A_2_coeffs)

        # Every term of A tensored with PauliX on the ancilla wire.
        self.H_expect = self.append_pauli_x()

        print(f'self.n_terms: { self.n_terms}')

        # settings
        self.dev_hadamard = qml.device("default.qubit", wires=self.tot_qubits)
        self.dev_Ax_norm = qml.device("default.qubit", wires=self.n_qubits)

        self.interface = 'jax'

        # Batched QNodes: the leading axis of every argument is the batch.
        qnodes_real_hadamard = self.make_overlap_real_mu_qnode()
        self.qnodes_real_hadamard = jax.vmap(qnodes_real_hadamard, in_axes=(0, 0))

        qnodes_Ax_vec = self.make_Ax_qnode()
        self.qnodes_Ax_vec = jax.vmap(qnodes_Ax_vec, in_axes=(0,))

        predict_qnode = self.predict_qnode()
        self.predict_vmap = jax.vmap(predict_qnode, in_axes=(0,))

    def _save_ansatz_structure(self, ex_params):
        """Write text and PNG drawings of the ansatz into ``args.RESULT_PATH``.

        Drawing is purely diagnostic, so any failure is reported with a
        printed warning instead of aborting construction. The two outputs are
        attempted independently so one failing does not skip the other.

        Args:
            ex_params: Example (un-batched) parameters used to trace the
                ansatz for drawing.
        """
        try:
            result_path = self.args.RESULT_PATH
            circuit_str = qml.draw(self.ansatz, level="device", max_length=2000)(ex_params)
            with open(os.path.join(result_path, "ansatz_structure.txt"), "w") as f:
                f.write(circuit_str)
        except Exception as err:
            print(f'VQLS - could not save ansatz text diagram: {err!r}')

        try:
            # Local import: matplotlib is only needed for this optional export.
            import matplotlib.pyplot as plt

            result_path = self.args.RESULT_PATH
            # ``level`` matches the qml.draw call above; ``expansion_strategy``
            # is the older spelling that newer PennyLane releases reject.
            fig, _ = qml.draw_mpl(self.ansatz, level="device")(ex_params)
            try:
                fig.savefig(os.path.join(result_path, "ansatz_structure.png"), dpi=300, bbox_inches='tight')
            finally:
                # Release the figure so it does not stay registered in pyplot.
                plt.close(fig)
        except Exception as err:
            print(f'VQLS - could not save ansatz figure: {err!r}')

    def compute_expressibility(self, ansatz_fn, param_shape, n_qubits, n_samples=800, seed=42):
        """Frobenius distance between the ansatz-averaged and Haar-averaged states.

        Parameters are drawn from a standard normal (``torch.randn``), the
        resulting density matrices are averaged, and the result is compared
        with the average of ``n_samples`` Haar-random pure states.

        Args:
            ansatz_fn: Callable applying the ansatz for a given parameter
                tensor; it must act on wires ``0 .. n_qubits - 1``.
            param_shape: Shape passed to ``torch.randn`` for each sample.
            n_qubits: Number of qubits used for the device, the density
                matrix and the Haar reference.
            n_samples: Number of samples for both averages.
            seed: Seed applied to both torch and NumPy global RNGs.

        Returns:
            The Frobenius norm of the difference of the two averaged density
            matrices.
        """
        import torch  # local: torch is only required by this diagnostic

        torch.manual_seed(seed)
        np.random.seed(seed)

        # Use the ``n_qubits`` argument everywhere so the sampled density
        # matrices and the Haar reference always have matching dimensions.
        wires = range(n_qubits)
        dev = qml.device("default.qubit", wires=n_qubits)

        @qml.qnode(dev, interface="torch")
        def circuit(params):
            ansatz_fn(params)
            return qml.density_matrix(wires=wires)

        # Sampled density matrices.
        rho_list = [circuit(torch.randn(param_shape)) for _ in range(n_samples)]

        # Average density matrix.
        rho_avg = torch.mean(torch.stack(rho_list), dim=0).detach().numpy()

        # Haar-averaged state.
        def random_unitary(N):
            # QR of a complex Gaussian matrix, with the phases of R's diagonal
            # folded back into Q.
            Z = np.random.randn(N, N) + 1.0j * np.random.randn(N, N)
            Q, R = np.linalg.qr(Z)
            D = np.diag(np.diagonal(R) / np.abs(np.diagonal(R)))
            return Q @ D

        def haar_integral(num_qubits, samples):
            N = 2**num_qubits
            rho_haar = np.zeros((N, N), dtype=complex)
            zero_state = np.zeros(N, dtype=complex)
            zero_state[0] = 1
            for _ in range(samples):
                psi = random_unitary(N) @ zero_state
                rho_haar += np.outer(psi, psi.conj())
            return rho_haar / samples

        rho_haar = haar_integral(n_qubits, n_samples)

        # Frobenius distance.
        return np.linalg.norm(rho_avg - rho_haar, ord="fro")

    def build_hamiltonian(self, n, AA):
        """Decompose ``AA`` into Pauli terms sorted by decreasing ``|coefficient|``.

        Args:
            n: Number of qubits (``n_qubits``).
            AA: Matrix to decompose.

        Returns:
            Tuple ``(hamiltonian, coeffs, terms)`` where ``hamiltonian`` is a
            ``qml.Hamiltonian`` with "qwc" grouping computed and ``coeffs`` /
            ``terms`` are the sorted coefficient and operator lists.
        """
        # The matrix may be non-Hermitian, so PennyLane's own Pauli
        # decomposition cannot be used here.
        A_ham = self.pseudo_pauli_decomposition(n, AA)
        A_ham = A_ham.operation()
        A_coeffs, A_terms = A_ham.terms()

        sorted_indices = np.argsort(np.abs(A_coeffs))[::-1]
        rearranged_A_coeffs = [A_coeffs[i] for i in sorted_indices]
        rearranged_A_ops = [A_terms[i] for i in sorted_indices]
        # Grouping could be requested in the constructor, but that gave
        # different results, so it is computed explicitly afterwards.
        A_ham = qml.Hamiltonian(rearranged_A_coeffs, rearranged_A_ops)
        A_ham.compute_grouping("qwc")
        return A_ham, rearranged_A_coeffs, rearranged_A_ops

    def append_pauli_x(self,):
        """Return the Hamiltonian of ``A`` with PauliX on the ancilla appended.

        Each term in ``self.A_terms`` is multiplied (``@``) with
        ``qml.PauliX(self.ancilla_idx)``; coefficients are kept unchanged.

        Returns:
            A ``qml.Hamiltonian`` with "qwc" grouping computed.
        """
        new_coeffs = list(self.A_coeffs)
        # Tensor every existing term with PauliX on the last wire (n_qubits).
        new_terms = [term @ qml.PauliX(self.ancilla_idx) for term in self.A_terms]

        H = qml.Hamiltonian(new_coeffs, new_terms)
        H.compute_grouping("qwc")
        return H

    def pseudo_pauli_decomposition(self, n: int, matrix: np.ndarray) -> "qml.pauli.PauliSentence":
        """Expand ``matrix`` in the ``n``-qubit Pauli basis.

        For every Pauli string ``P`` the coefficient is
        ``trace(P @ matrix) / 2**n``; coefficients may be complex. Terms with
        ``|coefficient| <= 1e-12`` are dropped.

        Args:
            n: Number of qubits; ``4**n`` Pauli strings are enumerated.
            matrix: Matrix to expand; it must be multipliable with a
                ``2**n x 2**n`` matrix.

        Returns:
            A ``qml.pauli.PauliSentence`` mapping Pauli words (wire ``i`` is
            position ``i`` of the string) to their coefficients.
        """
        I = np.array([[1, 0], [0, 1]], dtype=complex)
        X = np.array([[0, 1], [1, 0]], dtype=complex)
        Y = np.array([[0, -1j], [1j, 0]], dtype=complex)
        Z = np.array([[1, 0], [0, -1]], dtype=complex)

        basis_dict = {
            'I': I,
            'X': X,
            'Y': Y,
            'Z': Z
        }
        pauli_sentence = qml.pauli.PauliSentence()

        for pauli_string in itertools.product(['I', 'X', 'Y', 'Z'], repeat=n):
            # Dense matrix of the Pauli string, built by Kronecker products.
            P = np.array([[1]], dtype=complex)
            for symbol in pauli_string:
                P = np.kron(P, basis_dict[symbol])
            alpha = (1/(2**n)) * np.trace(P @ matrix)

            if abs(alpha) > 1e-12:
                wire_ops = dict(enumerate(pauli_string))
                pauli_word = qml.pauli.PauliWord(wire_ops)
                pauli_sentence[pauli_word] = alpha
        return pauli_sentence

    def make_BasicEntanglerLayers(self,):
        """Return an ansatz callable applying ``qml.BasicEntanglerLayers``."""
        print('VQLS - BasicEntanglerLayers')
        n_qubits = self.n_qubits

        def ansatz(weights):
            '''
                weights: (n_layers, n_qubits)
            '''
            qml.BasicEntanglerLayers(weights, wires=range(n_qubits))
        return ansatz

    def make_StronglyEntanglingLayers(self,):
        """Return an ansatz callable applying ``qml.StronglyEntanglingLayers``."""
        n_qubits = self.n_qubits

        def ansatz(weights):  # ansatz
            qml.StronglyEntanglingLayers(weights=weights, wires=range(n_qubits))
        return ansatz

    def make_qcnn(self,):
        """Return the ansatz built by ``src.qcnn.QCNN(args).make_qcnn()``."""
        print('VQLS - QCNN')
        from src.qcnn import QCNN
        # Distinct local name so the imported class is not shadowed.
        qcnn = QCNN(self.args)
        return qcnn.make_qcnn()

    def make_qcnnnn(self,):
        """Return the ansatz built by ``src.qcnn.QCNNNN(args).make_qcnn()``."""
        print('VQLS - QCNNNN')
        from src.qcnn import QCNNNN
        # Distinct local name so the imported class is not shadowed.
        qcnnnn = QCNNNN(self.args)
        return qcnnnn.make_qcnn()

    def make_U_b(self,):
        """Return a callable that amplitude-embeds ``b_vec`` on the system wires."""
        n_qubits = self.n_qubits

        def U_b(b_vec):  # embedding forcing
            # normalize=True lets PennyLane rescale b_vec to unit norm.
            qml.AmplitudeEmbedding(b_vec, wires=range(n_qubits), normalize=True)
        return U_b

    def make_ctrl_l(self,):
        """Return a callable wrapping a Pauli term in ``qml.ctrl`` on the ancilla."""
        ancilla_idx = self.ancilla_idx

        def ctrl_l(term_l):  # control pauli terms
            qml.ctrl(term_l, control=ancilla_idx)
        return ctrl_l

    def make_overlap_real_mu_qnode(self):
        """Build the QNode measuring ``self.H_expect`` on the ancilla circuit.

        The circuit puts a Hadamard on the ancilla, applies the ansatz
        controlled on ancilla = 1 and the ``b`` embedding controlled on
        ancilla = 0, then returns the expectation value of ``H_expect``
        (the terms of ``A`` with PauliX on the ancilla).

        Returns:
            A QNode ``f(b_vec, angles)`` on ``self.dev_hadamard``.
        """
        dev = self.dev_hadamard
        interface = self.interface
        ancilla_idx = self.ancilla_idx
        H_expect = self.H_expect
        ansatz = self.ansatz
        U_b = self.make_U_b()

        @qml.qnode(dev, interface=interface)
        def overlap_real_mu_qnode(b_vec, angles):
            qml.Hadamard(wires=ancilla_idx)
            qml.ctrl(ansatz, control=ancilla_idx)(angles)
            qml.ctrl(U_b, control=ancilla_idx, control_values=0)(b_vec)
            return qml.expval(H_expect)
        return overlap_real_mu_qnode

    def make_Ax_qnode(self):
        """Build the QNode returning the expectation of ``A^dagger A`` in the ansatz state.

        Returns:
            A QNode ``f(angles)`` on ``self.dev_Ax_norm``.
        """
        dev = self.dev_Ax_norm
        interface = self.interface
        ansatz = self.ansatz
        A_2_ham = self.A_2_ham

        @qml.qnode(dev, interface=interface)
        def Ax_qnode(angles):
            ansatz(angles)
            return qml.expval(A_2_ham)
        return Ax_qnode

    def make_cost_loc(self,):
        """Return the training cost ``cost_loc(nn_params, batch_forcing, batch_RHS)``.

        The returned function maps ``batch_forcing`` to angles with
        ``angle_net``, evaluates the batched overlap QNode and the batched
        ``A^dagger A`` QNode, and returns the batch mean of
        ``(Re(overlap) - sqrt(|<A^dagger A>|))**2``.
        """
        qnodes_real_hadamard = self.qnodes_real_hadamard
        qnodes_Ax_vec = self.qnodes_Ax_vec

        angle_net = self.angle_net
        shape_function = self.shape_function

        def cost_loc(nn_params, batch_forcing, batch_RHS):
            angle = angle_net.apply({'params': nn_params}, batch_forcing)
            angle = shape_function(angle)

            fAx_overlap_real = qnodes_real_hadamard(batch_RHS, angle)
            numer_real = jnp.real(fAx_overlap_real)

            # abs() before sqrt keeps the argument non-negative even if the
            # expectation value carries a small negative or complex part.
            Ax_vec = qnodes_Ax_vec(angle)
            Ax_vec_abs = jnp.abs(Ax_vec)
            denom_norm = jnp.sqrt(Ax_vec_abs)

            # ones = jnp.ones_like(numer_real)
            # cost =  ones - numer_real / denom_norm
            cost = (numer_real - denom_norm)**2
            return jnp.mean(cost)
        return cost_loc

    def predict_qnode(self):
        """Build the QNode returning the full state prepared by the ansatz.

        Returns:
            A QNode ``f(angles)`` on ``self.dev_Ax_norm`` returning ``qml.state()``.
        """
        dev = self.dev_Ax_norm
        interface = self.interface
        ansatz = self.ansatz

        @qml.qnode(dev, interface=interface)
        def circuit(angles):
            ansatz(angles)
            return qml.state()
        return circuit

    def make_predict(self,):
        """Return the jitted ``predict(nn_params, batch_forcing, batch_RHS)``.

        The returned function prepares the ansatz state for each batch
        element, rescales it by ``||b|| / ||A x||`` and returns a real array:
        the magnitude of each rescaled amplitude, signed by the sign of the
        cosine of its phase.
        """
        angle_net = self.angle_net
        AA = self.AA
        predict_qnode = self.predict_vmap
        shape_function = self.shape_function

        @jax.jit
        def predict(nn_params, batch_forcing, batch_RHS):
            angle = angle_net.apply({'params': nn_params}, batch_forcing)
            angle = shape_function(angle)

            raw_state = predict_qnode(angle)

            # Per-sample norm of the right-hand side, kept as a column so it
            # broadcasts against the state amplitudes.
            batch_b_abs = jnp.abs(batch_RHS)
            numerator = jnp.linalg.norm(batch_b_abs, axis=-1).reshape(-1, 1)

            # Per-sample norm of A applied to the un-scaled state.
            Ax_result = jax.vmap(lambda x: AA @ x)(raw_state)
            Ax_abs = jnp.abs(Ax_result)
            denominator = jnp.linalg.norm(Ax_abs, axis=-1).reshape(-1, 1)

            alpha_predict = raw_state * (numerator / denominator)

            # Collapse each complex amplitude to a signed real magnitude.
            alpha_predict = jnp.sign(jnp.cos(jnp.angle(alpha_predict))) * jnp.abs(alpha_predict)
            return alpha_predict
        return predict