import os
import time
import itertools
import numpy as np
import jax
import jax.numpy as jnp
import pennylane as qml
jax.config.update("jax_enable_x64", True)


class VQLSJax:
    def __init__(self, args, angle_net, B_matrix, C_matrix):
        """Build the ansatz, the Pauli decompositions and the batched QNodes.

        Args:
            args: Configuration object. This constructor reads ``key``,
                ``n_qubits``, ``n_layers``, ``ansatz_name`` and
                ``RESULT_PATH`` from it; the QCNN builders receive the
                whole object.
            angle_net: Model whose ``apply`` method yields the circuit angles
                (consumed by ``make_cost_loc`` and ``make_predict``).
            B_matrix: Matrix whose Pauli coefficients are used unscaled.
            C_matrix: Matrix whose Pauli coefficients are multiplied by the
                per-sample ``pde_param`` inside the QNodes.

        Raises:
            NotImplementedError: If ``args.ansatz_name`` is not one of
                ``BasicEntanglerLayers``, ``StronglyEntanglingLayers``,
                ``QCNN`` or ``QCNNNN``.
        """
        # arguments
        self.args = args
        self.key = args.key
        self.n_qubits = args.n_qubits
        # One extra wire is reserved for the ancilla; it gets the last index.
        self.tot_qubits = self.n_qubits + 1
        self.ancilla_idx = self.n_qubits

        self.angle_net = angle_net
        self.n_layers = args.n_layers

        self.B_matrix = B_matrix
        self.C_matrix = C_matrix

        # Products X^dagger @ Y of the two matrices; their Pauli
        # decompositions are handed to make_Ax_vec below.
        self.B_2_matrix = (B_matrix.T.conj() @ B_matrix)
        self.BC_matrix = (B_matrix.T.conj() @ C_matrix)
        self.CB_matrix = (C_matrix.T.conj() @ B_matrix)
        self.C_2_matrix = (C_matrix.T.conj() @ C_matrix)

        # settings
        self.dev_hadamard = qml.device("default.qubit", wires=self.tot_qubits)
        self.dev_Ax_norm = qml.device("default.qubit", wires=self.n_qubits)
        self.interface = 'jax'

        # Pick the ansatz, the reshape applied to the network output, and a
        # dummy parameter array that is only used for drawing the circuit.
        if args.ansatz_name == 'BasicEntanglerLayers':
            self.ansatz = self.make_BasicEntanglerLayers()
            self.shape_function = lambda angle: angle.reshape((-1, self.n_layers, self.n_qubits))
            ex_params = jnp.ones((self.n_layers, self.n_qubits))
        elif args.ansatz_name == 'StronglyEntanglingLayers':
            # Announce this ansatz only when it is actually selected; the
            # other builders print their own banner.
            print('VQLS - StronglyEntanglingLayers')
            self.ansatz = self.make_StronglyEntanglingLayers()
            self.shape_function = lambda angle: angle.reshape((-1, self.n_layers, self.n_qubits, 3))
            ex_params = jnp.ones((self.n_layers, self.n_qubits, 3))
        elif args.ansatz_name == 'QCNN':
            self.ansatz = self.make_qcnn()
            self.shape_function = lambda angle: angle
            ex_params = jnp.ones(17 * self.n_layers,)
        elif args.ansatz_name == 'QCNNNN':
            self.shape_function = lambda angle: angle
            self.ansatz = self.make_qcnnnn()
            ex_params = jnp.ones((17+3 * self.n_qubits) * self.n_layers,)
        else:
            raise NotImplementedError('VQLSJax - WRONG ANSATZ')

        # Dumping the circuit structure is best-effort: a drawing failure must
        # not abort construction, but it is reported instead of being hidden.
        # The two outputs are attempted independently of each other.
        try:
            circuit_str = qml.draw(self.ansatz, level="device", max_length=2000)(ex_params)
            # The text drawing contains non-ASCII box characters, so the
            # encoding is fixed rather than left to the platform default.
            with open(os.path.join(args.RESULT_PATH, "ansatz_structure.txt"), "w", encoding="utf-8") as f:
                f.write(circuit_str)
        except Exception as exc:
            print(f'VQLS - could not write ansatz_structure.txt: {exc!r}')
        try:
            # Same ``level`` keyword as the text drawing above.
            fig, _ = qml.draw_mpl(self.ansatz, level="device")(ex_params)
            fig.savefig(os.path.join(args.RESULT_PATH, "ansatz_structure.png"), dpi=300, bbox_inches='tight')
        except Exception as exc:
            print(f'VQLS - could not write ansatz_structure.png: {exc!r}')

        # Pauli decompositions: (Hamiltonian, coefficients, terms, terms @ X_ancilla).
        self.B_ham, self.B_coeffs, self.B_terms, self.B_terms_Pauli_X = self.build_hamiltonian(self.n_qubits, B_matrix)
        self.C_ham, self.C_coeffs, self.C_terms, self.C_terms_Pauli_X = self.build_hamiltonian(self.n_qubits, C_matrix)

        self.B_2_ham, self.B_2_coeffs, self.B_2_terms, self.B_2_terms_Pauli_X = self.build_hamiltonian(self.n_qubits, self.B_2_matrix)
        self.BC_ham, self.BC_coeffs, self.BC_terms, self.BC_terms_Pauli_X = self.build_hamiltonian(self.n_qubits, self.BC_matrix)
        self.CB_ham, self.CB_coeffs, self.CB_terms, self.CB_terms_Pauli_X = self.build_hamiltonian(self.n_qubits, self.CB_matrix)
        self.C_2_ham, self.C_2_coeffs, self.C_2_terms, self.C_2_terms_Pauli_X = self.build_hamiltonian(self.n_qubits, self.C_2_matrix)

        # The ancilla circuit measures the terms that carry the extra PauliX
        # on the ancilla wire.
        self.qnodes_real_hadamard = self.make_overlap_real_mu_qnode(
            self.ansatz,
            self.B_coeffs, self.B_terms_Pauli_X,
            self.C_coeffs, self.C_terms_Pauli_X,
        )
        self.qnodes_Ax_vec = self.make_Ax_vec(
            self.ansatz,
            self.B_2_coeffs, self.B_2_terms,
            self.BC_coeffs, self.BC_terms,
            self.CB_coeffs, self.CB_terms,
            self.C_2_coeffs, self.C_2_terms,
        )
        self.predict_vmap = self.make_predict_qnode(self.ansatz)


    def compute_expressibility(self, ansatz_fn, param_shape, n_qubits, n_samples=800, seed=42):
        """Compare the ansatz's average state with an average of Haar-random states.

        Draws ``n_samples`` standard-normal parameter sets, averages the
        density matrices the ansatz produces for them, averages the density
        matrices of ``n_samples`` Haar-random states, and returns the
        Frobenius norm of the difference.

        Args:
            ansatz_fn: Callable applying the ansatz gates for one parameter set.
            param_shape: Shape passed to ``torch.randn`` for each sample.
            n_qubits: Number of qubits used for the device, the returned
                density matrix and the Haar reference alike.
            n_samples: Number of samples for both averages.
            seed: Seed applied to both the torch and the NumPy generators.

        Returns:
            The Frobenius distance between the two averaged density matrices.
        """
        # torch is imported lazily; nothing else in this class uses it.
        import torch

        torch.manual_seed(seed)
        np.random.seed(seed)

        # Use the ``n_qubits`` argument everywhere so both averaged matrices
        # always have the same 2**n_qubits dimension.
        wires = range(n_qubits)
        dev = qml.device("default.qubit", wires=n_qubits)

        @qml.qnode(dev, interface="torch")
        def circuit(params):
            ansatz_fn(params)
            return qml.density_matrix(wires=wires)

        # Sampled density matrices.
        rho_list = []
        for _ in range(n_samples):
            params = torch.randn(param_shape)
            rho_list.append(circuit(params))

        # Average density matrix of the ansatz.
        rho_avg = torch.mean(torch.stack(rho_list), dim=0).detach().numpy()

        # Average density matrix of Haar-random states.
        def random_unitary(N):
            # QR of a complex Gaussian matrix; the phases of R's diagonal are
            # multiplied back into Q.
            Z = np.random.randn(N, N) + 1.0j * np.random.randn(N, N)
            Q, R = np.linalg.qr(Z)
            D = np.diag(np.diagonal(R) / np.abs(np.diagonal(R)))
            return Q @ D

        def haar_integral(num_qubits, samples):
            N = 2**num_qubits
            rho_haar = np.zeros((N, N), dtype=complex)
            zero_state = np.zeros(N, dtype=complex)
            zero_state[0] = 1
            for _ in range(samples):
                psi = random_unitary(N) @ zero_state
                rho_haar += np.outer(psi, psi.conj())
            return rho_haar / samples

        rho_haar = haar_integral(n_qubits, n_samples)

        # Frobenius distance between the two averages.
        fro_dist = np.linalg.norm(rho_avg - rho_haar, ord="fro")
        return fro_dist


    def build_hamiltonian(self, n, AA):
        """Decompose ``AA`` into Pauli terms ordered by decreasing |coefficient|.

        Args:
            n: Number of qubits.
            AA: Matrix handed to ``pseudo_pauli_decomposition``.

        Returns:
            A 4-tuple ``(hamiltonian, coeffs, ops, ops_with_x)``: the
            ``qml.Hamiltonian`` built from the sorted terms (with "qwc"
            grouping computed), the sorted coefficients as a ``jnp`` array,
            the sorted term operators, and the same operators each multiplied
            by ``PauliX`` on the ancilla wire.
        """
        # The matrix is non-Hermitian, so PennyLane's own pauli_decomposition
        # cannot be used.
        sentence = self.pseudo_pauli_decomposition(n, AA)
        coeffs, terms = sentence.operation().terms()

        # Term indices from largest to smallest coefficient magnitude.
        order = np.argsort(np.abs(coeffs))[::-1]

        sorted_coeffs = []
        sorted_ops = []
        sorted_ops_x = []
        for idx in order:
            sorted_coeffs.append(coeffs[idx])
            sorted_ops.append(terms[idx])
            sorted_ops_x.append(terms[idx] @ qml.PauliX(self.ancilla_idx))

        # Grouping could be requested in the constructor, but that gave
        # different results, so it is computed afterwards instead.
        hamiltonian = qml.Hamiltonian(sorted_coeffs, sorted_ops)
        hamiltonian.compute_grouping("qwc")
        return hamiltonian, jnp.array(sorted_coeffs), sorted_ops, sorted_ops_x
    
    def pseudo_pauli_decomposition(self, n:int, matrix: np.ndarray) -> qml.pauli.PauliSentence:
        """Expand a ``2**n x 2**n`` matrix in the n-qubit Pauli basis.

        For every Pauli string ``P`` over ``{I, X, Y, Z}`` the coefficient
        ``alpha = tr(P @ matrix) / 2**n`` is computed; strings with
        ``|alpha| <= 1e-12`` are dropped. The matrix does not have to be
        Hermitian, so coefficients may be complex.

        Args:
            n: Number of qubits.
            matrix: Square array of shape ``(2**n, 2**n)``.

        Returns:
            A ``PauliSentence`` mapping each retained Pauli word (wire ``k``
            carries the k-th tensor factor) to its coefficient.

        Raises:
            ValueError: If ``matrix`` does not have shape ``(2**n, 2**n)``.
        """
        matrix = np.asarray(matrix, dtype=complex)
        dim = 2 ** n
        if matrix.shape != (dim, dim):
            raise ValueError(
                f"matrix must have shape ({dim}, {dim}) for n={n}, got {matrix.shape}"
            )

        basis_dict = {
            'I': np.array([[1, 0], [0, 1]], dtype=complex),
            'X': np.array([[0, 1], [1, 0]], dtype=complex),
            'Y': np.array([[0, -1j], [1j, 0]], dtype=complex),
            'Z': np.array([[1, 0], [0, -1]], dtype=complex),
        }
        # Transposed once so the trace below is a plain element-wise product.
        matrix_t = matrix.T
        pauli_sentence = qml.pauli.PauliSentence()

        for pauli_string in itertools.product('IXYZ', repeat=n):
            P = np.array([[1]], dtype=complex)
            for symbol in pauli_string:
                P = np.kron(P, basis_dict[symbol])
            # tr(P @ M) = sum_ij P[i, j] * M[j, i]; summing the element-wise
            # product avoids forming the full matrix product for every string.
            alpha = np.sum(P * matrix_t) / dim

            if abs(alpha) > 1e-12:
                wire_ops = dict(enumerate(pauli_string))
                pauli_sentence[qml.pauli.PauliWord(wire_ops)] = alpha
        return pauli_sentence
    


    def make_BasicEntanglerLayers(self):
        """Return an ansatz callable that applies ``qml.BasicEntanglerLayers``.

        The callable acts on wires ``0 .. n_qubits - 1``.
        """
        print('VQLS - BasicEntanglerLayers')
        wires = range(self.n_qubits)

        def ansatz(weights):
            """Apply the layers; ``weights`` has shape (n_layers, n_qubits)."""
            qml.BasicEntanglerLayers(weights, wires=wires)

        return ansatz
    
    def make_StronglyEntanglingLayers(self):
        """Return an ansatz callable that applies ``qml.StronglyEntanglingLayers``.

        The callable acts on wires ``0 .. n_qubits - 1``.
        """
        wires = range(self.n_qubits)

        def ansatz(weights):
            # Variational circuit applied to the data wires.
            qml.StronglyEntanglingLayers(weights=weights, wires=wires)

        return ansatz
    
    def make_qcnn(self):
        """Return the ansatz produced by ``src.qcnn.QCNN(args).make_qcnn()``."""
        print('VQLS - QCNN')
        # Imported here, as in the original, so the module is only needed
        # when this ansatz is requested.
        from src.qcnn import QCNN

        builder = QCNN(self.args)
        return builder.make_qcnn()

    def make_qcnnnn(self):
        """Return the ansatz produced by ``src.qcnn.QCNNNN(args).make_qcnn()``."""
        print('VQLS - QCNNNN')
        # Imported here, as in the original, so the module is only needed
        # when this ansatz is requested.
        from src.qcnn import QCNNNN

        builder = QCNNNN(self.args)
        return builder.make_qcnn()
    
    def make_U_b(self):
        """Return a callable that amplitude-embeds a vector on the data wires.

        The embedding uses wires ``0 .. n_qubits - 1`` and ``normalize=True``.
        """
        wires = range(self.n_qubits)

        def U_b(b_vec):
            # Embeds the forcing / right-hand-side vector.
            qml.AmplitudeEmbedding(b_vec, wires=wires, normalize=True)

        return U_b

    def make_ctrl_l(self):
        """Return a callable that controls a given Pauli term on the ancilla wire."""
        control_wire = self.ancilla_idx

        def ctrl_l(term_l):
            # ``term_l`` is the operator to be controlled by the ancilla.
            qml.ctrl(term_l, control=control_wire)

        return ctrl_l

    def make_overlap_real_mu_qnode(self, ansatz, B_coeffs, B_terms, C_coeffs, C_terms):
        """Create the batched ancilla-controlled overlap QNode.

        The circuit puts the ancilla into superposition with a Hadamard,
        applies ``ansatz`` controlled on the ancilla, applies the amplitude
        embedding of the right-hand side controlled on the ancilla being 0,
        and returns the expectation value of the Hamiltonian with
        coefficients ``[B_coeffs, pde_param * C_coeffs]`` over the terms
        ``B_terms + C_terms``.

        Args:
            ansatz: Callable applying the variational circuit.
            B_coeffs: Coefficients used unscaled.
            B_terms: Operators paired with ``B_coeffs``.
            C_coeffs: Coefficients multiplied by ``pde_param``.
            C_terms: Operators paired with ``C_coeffs``.

        Returns:
            The QNode wrapped in ``jax.vmap`` over axis 0 of
            ``(pde_param, batch_RHS, batch_angles)``.
        """
        device = self.dev_hadamard
        ancilla = self.ancilla_idx
        observable_terms = B_terms + C_terms
        prepare_b = self.make_U_b()

        @qml.qnode(device, interface=self.interface)
        def overlap_real_mu_qnode(pde_param, batch_RHS, batch_angles):
            # Only the C part of the coefficient vector depends on the sample.
            coeffs = jnp.concatenate([jnp.array(B_coeffs), pde_param * C_coeffs])
            observable = qml.Hamiltonian(coeffs, observable_terms).simplify()

            qml.Hadamard(wires=ancilla)
            qml.ctrl(ansatz, control=ancilla)(batch_angles)
            qml.ctrl(prepare_b, control=ancilla, control_values=0)(batch_RHS)

            return qml.expval(observable)

        return jax.vmap(overlap_real_mu_qnode, in_axes=(0, 0, 0))

    def make_Ax_vec(self, ansatz
                    , B_2_coeffs, B_2_terms
                    , BC_coeffs, BC_terms
                    , CB_coeffs, CB_terms
                    , C_2_coeffs, C_2_terms
                    ):
        """Create the batched QNode measuring the expansion of ``A^dagger A``.

        With ``A = B + p * C`` (``p`` being ``pde_param``)::

            A^dagger A = B^dagger B + p * B^dagger C
                         + conj(p) * C^dagger B + |p|^2 * C^dagger C

        so the four coefficient groups are weighted by ``1``, ``p``,
        ``conj(p)`` and ``p * conj(p)``. For a real ``p`` the last weight
        equals ``p ** 2``.

        Args:
            ansatz: Callable applying the variational circuit.
            B_2_coeffs, B_2_terms: Pauli coefficients / operators of ``B^dagger B``.
            BC_coeffs, BC_terms: Pauli coefficients / operators of ``B^dagger C``.
            CB_coeffs, CB_terms: Pauli coefficients / operators of ``C^dagger B``.
            C_2_coeffs, C_2_terms: Pauli coefficients / operators of ``C^dagger C``.

        Returns:
            The QNode wrapped in ``jax.vmap`` over axis 0 of
            ``(pde_param, angles)``.
        """
        dev = self.dev_Ax_norm
        interface = self.interface

        H_ops = B_2_terms + BC_terms + CB_terms + C_2_terms

        @qml.qnode(dev, interface=interface)
        def qnodes_Ax_vec(pde_param, angles):
            pde_param_conj = pde_param.conj()
            H_coeffs = jnp.concatenate([B_2_coeffs
                                        , pde_param * BC_coeffs
                                        , pde_param_conj * CB_coeffs
                                        # |p|^2, consistent with the conj(p)
                                        # weight on the C^dagger B term.
                                        , (pde_param * pde_param_conj) * C_2_coeffs
                                        ])
            H_simplified = qml.Hamiltonian(H_coeffs, H_ops).simplify()
            ansatz(angles)
            return qml.expval(H_simplified)
        return jax.vmap(qnodes_Ax_vec, in_axes=(0,0,))

    def make_cost_loc(self,): 
        qnodes_real_hadamard = self.qnodes_real_hadamard
        qnodes_Ax_vec = self.qnodes_Ax_vec
        n_layers = self.n_layers
        n_qubits = self.n_qubits
        angle_net = self.angle_net
        shape_function = self.shape_function

        def cost_loc(nn_params, batch_pde_param, batch_forcing, batch_RHS):

            angle = angle_net.apply({'params': nn_params}, batch_pde_param, batch_forcing)
            angle = shape_function(angle)
            
            fAx_overlap_real = qnodes_real_hadamard(batch_pde_param, batch_RHS, angle)
            numer_real = jnp.real(fAx_overlap_real)

            Ax_vec = qnodes_Ax_vec(batch_pde_param, angle)
            Ax_vec_abs = jnp.abs(Ax_vec) 
            Ax_vec_norm = jnp.sqrt(Ax_vec_abs)

            denom_norm = Ax_vec_norm

            # ones = jnp.ones_like(numer_real)
            # cost =  ones - numer_real / denom_norm
            cost =  (numer_real - denom_norm)**2
            return jnp.mean(cost)
        return cost_loc


    def make_predict_qnode(self, ansatz):
        """Return a batched QNode that runs ``ansatz`` and returns the state.

        The QNode runs on ``dev_Ax_norm`` and is wrapped in ``jax.vmap`` over
        axis 0 of the angle array.
        """
        @qml.qnode(self.dev_Ax_norm, interface=self.interface)
        def circuit(angles):
            ansatz(angles)
            return qml.state()

        return jax.vmap(circuit, in_axes=0)

    def make_predict(self,):
        n_layers = self.n_layers
        n_qubits = self.n_qubits
        angle_net = self.angle_net

        predict_qnode = self.predict_vmap
        shape_function = self.shape_function

        @jax.jit
        def predict(nn_params, batch_pde_param, batch_A_matrix, batch_forcing ,batch_RHS):
            angle = angle_net.apply({'params': nn_params}, batch_pde_param, batch_forcing)
            angle = shape_function(angle)

            raw_state = predict_qnode(angle)

            batch_b_abs = jnp.abs(batch_RHS)
            numerator = jnp.linalg.norm(batch_b_abs, axis=-1).reshape(-1, 1)

            Ax_result = jax.vmap(lambda A, x: A @ x)(batch_A_matrix, raw_state)

            Ax_abs = jnp.abs(Ax_result) 
            denominator = jnp.linalg.norm(Ax_abs, axis=-1).reshape(-1, 1)

            alpha_predict = raw_state * (numerator / denominator)

            alpha_predict = jnp.sign(jnp.cos(jnp.angle(alpha_predict))) * jnp.abs(alpha_predict)
            return alpha_predict
        return predict