import os
import sys
sys.path.insert(0, os.getcwd())
import json
import tqdm
import torch
import numpy as np
import pennylane as qml
import matplotlib.pyplot as plt
import circuit.var_config as vc

from tqdm import tqdm
from pennylane import CircuitGraph
from pennylane import numpy as pnp
from torch.nn import functional as F


# Working directory captured once at import time; data_dumper builds its
# output file path relative to this directory.
current_path = os.getcwd()

# Shared PennyLane "default.qubit" simulator with vc.num_qubits wires. It must be
# created before the @qml.qnode(dev) decorator on circuit_qnode is evaluated.
dev = qml.device("default.qubit", wires=vc.num_qubits)

# Gate-name -> PennyLane operation tables used by circuit_qnode.
# Gates that take no parameters and act on a single wire.
_FIXED_GATES = {
    'Identity': qml.Identity,
    'PauliX': qml.PauliX,
    'PauliY': qml.PauliY,
    'PauliZ': qml.PauliZ,
    'Hadamard': qml.Hadamard,
}
# Single-qubit rotations taking one angle.
_ROTATION_GATES = {'RX': qml.RX, 'RY': qml.RY, 'RZ': qml.RZ}
# Two-qubit gates acting on (control, target).
_CONTROLLED_GATES = {'CNOT': qml.CNOT, 'CZ': qml.CZ, 'CY': qml.CY}


# app=1: Fidelity Task; app=2: MAXCUT; app=3: VQE
@qml.qnode(dev)
def circuit_qnode(circuit_list, app=1, hamiltonian=None, edge=None):
    """Build a circuit from a gate list and return the measurement for ``app``.

    Args:
        circuit_list: Iterable of gate descriptors. Each descriptor is a tuple
            whose first item is the gate name:
              ('Identity' | 'PauliX' | 'PauliY' | 'PauliZ' | 'Hadamard', wire)
              ('RX' | 'RY' | 'RZ', wire, angle)
              ('CNOT' | 'CZ' | 'CY', control, target)
              ('U3', wire, theta, phi, delta)
            The bare strings 'START' (skipped) and 'END' (stops parsing) are
            also accepted as sequence markers.
        app: 1 = fidelity task (returns ``qml.state()``), 2 = MAXCUT/QAOA,
            3 = VQE.
        hamiltonian: Observable whose expectation value is returned for app=3,
            and for app=2 when ``edge`` is given.
        edge: For app=2, when None the circuit returns ``qml.sample()``
            instead of an expectation value.

    Raises:
        ValueError: If a gate name is not in the supported pool, or if app=2
            (with ``edge``) / app=3 is requested without a Hamiltonian.
    """
    for params in list(circuit_list):
        # Bare-string markers are handled before indexing into ``params``.
        if params == 'START':
            continue
        if params == 'END':
            break

        gate = params[0]
        if gate in _FIXED_GATES:
            _FIXED_GATES[gate](wires=params[1])
        elif gate in _ROTATION_GATES:
            angle = pnp.array(params[2], requires_grad=True)
            _ROTATION_GATES[gate](angle, wires=params[1])
        elif gate in _CONTROLLED_GATES:
            _CONTROLLED_GATES[gate](wires=[params[1], params[2]])
        elif gate == 'U3':
            theta = pnp.array(params[2], requires_grad=True)
            phi = pnp.array(params[3], requires_grad=True)
            delta = pnp.array(params[4], requires_grad=True)
            qml.U3(theta, phi, delta, wires=params[1])
        else:
            print(params)
            raise ValueError("There exists operations not in the allowed operation pool!")

    if app == 1:
        return qml.state()
    if app == 2:
        if edge is None:
            return qml.sample()
        if hamiltonian is None:
            raise ValueError("Please pass a hamiltonian as an observation for QAOA_MAXCUT!")
        return qml.expval(hamiltonian)
    if app == 3:
        if hamiltonian is None:
            raise ValueError("Please pass a hamiltonian as an observation for VQE!")
        return qml.expval(hamiltonian)
    print("Note: Currently, there are no correspoding appllications!")


def reverse_index(lst, value, start=None, end=None):
    """
    Search for a specified value in a list from right to left.

    Only positions ``start <= i < end`` are considered. Negative bounds count
    from the end of the list (like slice bounds) and out-of-range bounds are
    clamped to ``[0, len(lst)]``, so the returned index is always a valid
    position in ``lst``.

    Parameters:
    lst (list): The list (or other sliceable sequence) to search through.
    value (any): The value to find (compared with ``==``, as ``list.index`` does).
    start (int, optional): The starting index for the search. If None, default is 0.
    end (int, optional): The end index (exclusive) for the search. If None, default is the length of the list.

    Returns:
    int: The index in ``lst`` of the right-most element in ``[start, end)`` equal to ``value``.

    Raises:
    ValueError: If the value is not found in the searched range.
    """
    size = len(lst)
    start = 0 if start is None else start
    end = size if end is None else end

    # Handle negative indices the same way slicing does.
    if start < 0:
        start += size
    if end < 0:
        end += size

    # Clamp to the list bounds; otherwise an `end` past the list would be used
    # in the index arithmetic below and yield a position outside `lst`.
    start = min(max(start, 0), size)
    end = min(max(end, 0), size)

    # Position of the first match in the reversed window (raises ValueError
    # when the window is empty or has no match).
    reversed_index = lst[start:end][::-1].index(value)

    # Convert the reversed-window index back to an index into `lst`.
    return end - reversed_index - 1


class CircuitManager:
    """Generate random quantum circuits and derive graph/image encodings from them.

    Circuits are sequences of gate tuples such as ``('Hadamard', 0)``,
    ``('RX', 0, 1.57)`` or ``('CNOT', 0, 1)`` (see ``circuit_qnode``).
    A manager samples unique random circuits from a pool of allowed gates.
    It also converts a circuit into several feature representations:
      * one-hot gate (node feature) matrices;
      * adjacency matrices of the circuit DAG (plain, per edge type, and
        degree-weighted);
      * ``max_depth x num_qubits`` image encodings.
    """

    # Positional sampling weights for ``generate_circuits``: the i-th weight
    # applies to the i-th entry of ``allowed_gates``. Used only when the gate
    # pool has exactly this many gates; otherwise gates are sampled uniformly.
    _DEFAULT_GATE_PROBS = (0.03, 0.03, 0.03, 0.06, 0.15, 0.15, 0.15, 0.1, 0.1, 0.1, 0.1)

    # Gate names that act on a (control, target) pair of qubits.
    _TWO_QUBIT_GATES = ('CNOT', 'CY', 'CZ')

    # Bare-string sequence markers that may bracket a circuit list.
    _BOUNDARY_MARKERS = ('START', 'END')

    def __init__(self, num_qubits, num_circuits, num_gates, max_depth, allowed_gates):
        """Store generation settings and create the progress bar.

        Args:
            num_qubits: Number of qubits a circuit may act on.
            num_circuits: Number of unique circuits ``generate_circuits`` returns.
            num_gates: Number of gates in each generated circuit.
            max_depth: Maximum allowed circuit depth; also the height of the
                image encodings.
            allowed_gates: Ordered gate-name pool.
        """
        self.num_qubits = num_qubits
        self.num_circuits = num_circuits
        self.num_gates = num_gates
        self.max_depth = max_depth
        # Gate name -> 1-based integer code. The image encodings are
        # zero-initialised, so code 0 means "no gate at this position".
        self.gate_dict = {gate: code for code, gate in enumerate(allowed_gates, start=1)}
        self.pbar = tqdm(range(self.num_circuits), desc ="generated_num_circuits")

    @classmethod
    def _is_boundary_marker(cls, param):
        """Return True if ``param`` is a 'START'/'END' marker (bare or as a tuple head)."""
        if isinstance(param, str):
            return param in cls._BOUNDARY_MARKERS
        return len(param) > 0 and param[0] in cls._BOUNDARY_MARKERS

    def _gate_probabilities(self):
        """Return sampling weights aligned with ``gate_dict``, or None for uniform sampling."""
        if len(self.gate_dict) == len(self._DEFAULT_GATE_PROBS):
            return list(self._DEFAULT_GATE_PROBS)
        return None

    def encode_gate_type(self):
        """Return a one-hot ``torch`` vector for every gate type plus START/END.

        The vocabulary is ``['START', *allowed_gates, 'END']``. The value for
        vocabulary entry ``i`` is row ``i`` of a one-hot matrix of that size.
        """
        ops = ['START', *self.gate_dict, 'END']
        type_onehot = F.one_hot(torch.arange(len(ops)), num_classes=len(ops))
        return {op: type_onehot[i] for i, op in enumerate(ops)}

    def generate_circuits(self):
        """Sample ``num_circuits`` structurally unique random circuits.

        Each candidate gets ``num_gates`` random gates:
          * gate types come from ``gate_dict``, weighted by
            ``_DEFAULT_GATE_PROBS`` when the pool size matches and uniform
            otherwise;
          * qubits and angles are drawn uniformly.
        A candidate is discarded as soon as its depth exceeds ``max_depth``.
        It is also discarded when its ``CircuitGraph`` hash matches an
        already accepted circuit.

        Returns:
            list[tuple]: The accepted circuits, each a tuple of gate tuples.
        """
        circuit_list = []
        seen_hashes = set()  # O(1) duplicate checks
        gate_names = list(self.gate_dict.keys())
        gate_probs = self._gate_probabilities()

        def random_circuit_generator():
            # Return a list of gate tuples, or None once the depth limit is exceeded.
            circuit_ops = []
            for _ in range(self.num_gates):
                gate = np.random.choice(gate_names, p=gate_probs).tolist()
                qubit = np.random.choice(self.num_qubits)
                if gate in self._TWO_QUBIT_GATES:
                    # The second qubit is drawn uniformly from the other qubits.
                    all_choice = np.arange(self.num_qubits)
                    possible_choice = np.delete(all_choice, np.where(all_choice == qubit))
                    qubit2 = np.random.choice(possible_choice).item()
                    circuit_ops.append((gate, qubit, qubit2))
                elif gate in ('RX', 'RY', 'RZ'):
                    angle = np.random.uniform(0, 2 * np.pi)
                    circuit_ops.append((gate, qubit, angle))
                elif gate == 'U3':
                    theta = np.random.uniform(0, 2 * np.pi)
                    phi = np.random.uniform(0, 2 * np.pi)
                    delta = np.random.uniform(0, 2 * np.pi)
                    circuit_ops.append((gate, qubit, theta, phi, delta))
                else:
                    circuit_ops.append((gate, qubit))
                # Appending gates never lowers depth, so reject early.
                if qml.specs(circuit_qnode)(circuit_ops)['resources'].depth > self.max_depth:
                    return None
            return circuit_ops

        while len(circuit_list) < self.num_circuits:
            circuit_ops = random_circuit_generator()
            if circuit_ops is None:
                continue
            circuit_hash = self.get_circuit_graph(circuit_ops).hash
            if circuit_hash not in seen_hashes:
                seen_hashes.add(circuit_hash)
                circuit_list.append(tuple(circuit_ops))
                self.pbar.update(1)
        self.pbar.close()
        return circuit_list

    def get_circuit_graph(self, circuit_list):
        """Run ``circuit_list`` through ``circuit_qnode`` and return its ``CircuitGraph``.

        The QNode is executed once with its defaults (app=1). The tape it
        built is then read back through ``circuit_qnode.qtape``.
        """
        circuit_qnode(circuit_list)
        tape = circuit_qnode.qtape
        return CircuitGraph(tape.operations, tape.observables, tape.wires)

    # Original encoding: "GSQAS: Graph Self-supervised Quantum Architecture Search" -> https://arxiv.org/pdf/2303.12381
    def get_gate_matrix(self, circuit_list, type='original'):
        """Return the node-feature (gate) matrix of the circuit DAG.

        There is one row per node of ``['START', *gates in circuit order, 'END']``.
        Each row is the one-hot gate type from ``encode_gate_type``, followed
        by ``num_qubits`` wire flags:
          * 1 on every wire the gate acts on;
          * all 1 for START/END.
        With ``type='improved'`` the control wire of a two-qubit gate is
        flagged -1 instead of 1, so control and target can be told apart.

        Args:
            circuit_list: Sequence of gate tuples, optionally bracketed by
                'START'/'END' markers.
            type: 'original' or 'improved'. The name shadows the builtin but is
                kept for backward compatibility with keyword callers.

        Returns:
            tuple: ``(cl, gate_matrix)``.
              * ``cl`` is ``circuit_list`` as a list, guaranteed to start with
                'START' and end with 'END'.
              * ``gate_matrix`` is a list of int lists.
        """
        cl = list(circuit_list)
        if not cl or cl[0] != 'START':
            cl.insert(0, 'START')
        if cl[-1] != 'END':
            cl.append('END')
        cg = self.get_circuit_graph(circuit_list)
        gate_dict = self.encode_gate_type()

        # START node touches every qubit.
        gate_matrix = [gate_dict['START'].tolist() + [1] * self.num_qubits]
        # Gate nodes.
        for op in cg.operations_in_order:
            op_qubits = [0] * self.num_qubits
            if op.name in self._TWO_QUBIT_GATES:
                op_qubits[op.control_wires[0]] = -1 if type == "improved" else 1
                op_qubits[op.target_wires[0]] = 1
            else:
                for wire in op.wires:
                    op_qubits[wire] = 1
            gate_matrix.append(gate_dict[op.name].tolist() + op_qubits)
        # END node touches every qubit.
        gate_matrix.append(gate_dict['END'].tolist() + [1] * self.num_qubits)
        return cl, gate_matrix


    # get the adjacent matrix of a circuit (See “GSQAS: Graph Self-supervised Quantum Architecture Search” -> https://arxiv.org/pdf/2303.12381)
    def get_adj_matrix(self, circuit_list):
        """Return the 0/1 adjacency matrix of the circuit DAG as an int ndarray.

        Nodes are ordered as ``['START', *gates in circuit order, 'END']``.
        ``adj[i][j] == 1`` marks an edge from node ``i`` to node ``j``.

        Edge rules implemented below:
          * START -> END when fewer than ``vc.num_qubits`` wires are used.
          * START -> gate when the gate has no ancestors.
          * Two-qubit gate: ancestors are scanned from ``ancestors[::-1]``
            (closest first, assuming ``ancestors_in_order`` returns circuit
            order). At most two of them are linked, skipping ancestors on a
            wire already linked. If only one is linked, START -> gate is also
            added.
          * Single-qubit gate: linked from ``ancestors[-1]``.
          * gate -> END for a two-qubit gate when a ``StateMP`` is met while
            scanning its descendants, before both of its wires are covered.
          * gate -> END for a single-qubit gate when its first descendant is
            a ``StateMP``.
        """
        op_list=['START']
        cg = self.get_circuit_graph(circuit_list)
        for op in cg.operations_in_order:
            op_list.append(op)
        op_list.append('END')
        op_len = len(op_list)
        adj_matrix = np.zeros((op_len, op_len), dtype = int)
        if len(cg.wires) < vc.num_qubits:
            adj_matrix[0][-1] = 1
        for ind, op in enumerate(cg.operations_in_order):
            # Node index of `op` in op_list is ind+1 (index 0 is START).
            ancestors = cg.ancestors_in_order([op])
            descendants = cg.descendants_in_order([op])
            if len(ancestors) == 0:
                adj_matrix[0][ind+1] = 1
            else:
                if op.name in ['CNOT', 'CZ', 'CY']:
                    count = 0
                    control_flag = False
                    target_flag = False
                    for ancestor in ancestors[::-1]:
                        same_wires = set(ancestor.wires) & set(op.wires)
                        # Skip ancestors on a wire that is already linked.
                        if op.control_wires[0] in same_wires and control_flag == True:
                            continue
                        if op.target_wires[0] in same_wires and target_flag == True:
                            continue
                        wires = set()
                        wires.update(same_wires)
                        if count < 2:
                            # reverse_index searches only op_list[0:ind+1], i.e. nodes before `op`.
                            adj_matrix[reverse_index(op_list, ancestor, 0, ind+1)][ind+1] = 1
                            count += 1
                        if op.control_wires[0] in wires:
                            control_flag = True
                        if op.target_wires[0] in wires:
                            target_flag = True
                    if count == 1:
                        adj_matrix[0][ind+1] = 1
                else:
                    direct_ancestor = ancestors[-1]
                    adj_matrix[reverse_index(op_list, direct_ancestor, 0, ind+1)][ind+1] = 1
            if op.name in ['CNOT', 'CZ', 'CY']:
                wires = set()
                for descendant in descendants:
                    wires.update(set(descendant.wires) & set(op.wires))
                    if isinstance(descendant, qml.measurements.StateMP):
                        adj_matrix[op_list.index(op, ind+1)][-1] = 1
                    if len(wires) == 2:
                        break
            else:
                if isinstance(descendants[0], qml.measurements.StateMP):
                    adj_matrix[op_list.index(op, ind+1)][-1] = 1
        return adj_matrix

    # get a group of adjacency matrices according to different edge types
    def get_adj_matrix_group(self, circuit_list):
        '''
        Return one 0/1 adjacency matrix per edge type, stacked as an int ndarray.

        The shape is ``(7, n, n)`` where ``n`` is the number of gates + 2.
        Nodes are ordered as ``['START', *gates in circuit order, 'END']``.
        ``adj[t][i][j] == 1`` marks an edge of type ``t + 1`` (numbering below)
        from node ``i`` to node ``j``.

        type 1: single-qubit gate <-> single-qubit gate (START and END should belong to this type as well)
        type 2: single-qubit gate <-> control wire of a two-qubit gate
        type 3: single-qubit gate <-> target wire of a two-qubit gate
        type 4: control wire of a two-qubit gate -> target wire of a two-qubit gate
        type 5: target wire of a two-qubit gate -> control wire of a two-qubit gate
        type 6: control wire of a two-qubit gate <-> control wire of a two-qubit gate
        type 7: target wire of a two-qubit gate <-> target wire of a two-qubit gate
        '''
        op_list=['START']
        cg = self.get_circuit_graph(circuit_list)
        for op in cg.operations_in_order:
            op_list.append(op)
        op_list.append('END')
        op_len = len(op_list)
        adj_matrix = np.zeros((7, op_len, op_len), dtype = int)
        if len(cg.wires) < vc.num_qubits:
            adj_matrix[0][0][-1] = 1   
        for ind, op in enumerate(cg.operations_in_order):
            ancestors = cg.ancestors_in_order([op])
            descendants = cg.descendants_in_order([op])
            if len(ancestors) == 0:
                if op.name in ['CNOT', 'CZ', 'CY']:
                    adj_matrix[1][0][ind+1] = 1
                    adj_matrix[2][0][ind+1] = 1
                else:
                    adj_matrix[0][0][ind+1] = 1
            else:
                if op.name in ['CNOT', 'CZ', 'CY']:
                    count = 0
                    control_flag = False
                    target_flag = False
                    for ancestor in ancestors[::-1]:
                        same_wires = set(ancestor.wires) & set(op.wires)
                        if op.control_wires[0] in same_wires and control_flag == True:
                            continue
                        if op.target_wires[0] in same_wires and target_flag == True:
                            continue
                        wires = set()
                        wires.update(same_wires)
                        if count < 2:
                            if ancestor.name in ['CNOT', 'CZ', 'CY']:
                                # Gates strictly between `ancestor` and `op`; an edge of a
                                # given wire pairing is only added when the first such
                                # gate (if any) is not on the corresponding wire of `op`.
                                path_nodes = cg.nodes_between(ancestor, op)
                                if op in path_nodes:
                                    path_nodes.remove(op)
                                if ancestor in path_nodes:
                                    path_nodes.remove(ancestor)
                                if ancestor.control_wires[0] == op.target_wires[0]:
                                    if len(path_nodes) == 0 or op.target_wires[0] not in path_nodes[0].wires:
                                        adj_matrix[3][reverse_index(op_list, ancestor, 0, ind+1)][ind+1] = 1
                                        count += 1
                                if ancestor.target_wires[0] == op.control_wires[0]:
                                    if len(path_nodes) == 0 or op.control_wires[0] not in path_nodes[0].wires:
                                        adj_matrix[4][reverse_index(op_list, ancestor, 0, ind+1)][ind+1] = 1
                                        count += 1
                                if ancestor.control_wires[0] == op.control_wires[0]:
                                    if len(path_nodes) == 0 or op.control_wires[0] not in path_nodes[0].wires:
                                        adj_matrix[5][reverse_index(op_list, ancestor, 0, ind+1)][ind+1] = 1
                                        count += 1
                                if ancestor.target_wires[0] == op.target_wires[0]:
                                    if len(path_nodes) == 0 or op.target_wires[0] not in path_nodes[0].wires:
                                        adj_matrix[6][reverse_index(op_list, ancestor, 0, ind+1)][ind+1] = 1
                                        count += 1
                            else:
                                # Single-qubit ancestor: type 2 if it feeds the control wire, else type 3.
                                if op.control_wires[0] in wires:
                                    adj_matrix[1][reverse_index(op_list, ancestor, 0, ind+1)][ind+1] = 1
                                else:
                                    adj_matrix[2][reverse_index(op_list, ancestor, 0, ind+1)][ind+1] = 1
                                count += 1
                            if op.control_wires[0] in wires:
                                control_flag = True
                            if op.target_wires[0] in wires:
                                target_flag = True
                    if count == 1:
                        # NOTE(review): `wires` here is from the last ancestor visited in the loop,
                        # which may not be the linked one — confirm this is intended.
                        if op.control_wires[0] in wires:
                            adj_matrix[2][0][ind+1] = 1
                        else:
                            adj_matrix[1][0][ind+1] = 1
                else:
                    direct_ancestor = ancestors[-1]
                    if direct_ancestor.name in ['CNOT', 'CZ', 'CY']:
                        if direct_ancestor.control_wires[0] == op.wires[0]:
                            adj_matrix[1][reverse_index(op_list, direct_ancestor, 0, ind+1)][ind+1] = 1
                        else:
                            adj_matrix[2][reverse_index(op_list, direct_ancestor, 0, ind+1)][ind+1] = 1
                    else:
                        adj_matrix[0][reverse_index(op_list, direct_ancestor, 0, ind+1)][ind+1] = 1
                    
            if op.name in ['CNOT', 'CZ', 'CY']:
                wires = set()
                for descendant in descendants:
                    wires.update(set(descendant.wires) & set(op.wires))
                    if isinstance(descendant, qml.measurements.StateMP):
                        if op.control_wires[0] in wires:
                            adj_matrix[1][op_list.index(op, ind+1)][-1] = 1
                        if op.target_wires[0] in wires:
                            adj_matrix[2][op_list.index(op, ind+1)][-1] = 1
                    if len(wires) == 2:
                        break
            else:
                if isinstance(descendants[0], qml.measurements.StateMP):
                    adj_matrix[0][op_list.index(op, ind+1)][-1] = 1 
        return adj_matrix

    # get the adjacency matrix with the degree of connections
    def get_adj_matrix_with_degree(self, circuit_list):
        """Return the DAG adjacency matrix whose entries are edge multiplicities.

        Same node order as ``get_adj_matrix``. Differences from it:
          * START -> END holds the number of unused device wires
            (``vc.num_qubits - len(cg.wires)``).
          * START -> a two-qubit gate with no ancestors holds 2.
          * A two-qubit ancestor that shares both wires with a two-qubit gate,
            with no gate in between, holds 2.
          * A two-qubit gate -> END holds 2 when its only descendant is the
            ``StateMP`` measurement.
        All other edges hold 1.
        """
        op_list=['START']
        cg = self.get_circuit_graph(circuit_list)
        for op in cg.operations_in_order:
            op_list.append(op)
        op_list.append('END')
        op_len = len(op_list)
        adj_matrix = np.zeros((op_len, op_len), dtype = int)
        if len(cg.wires) < vc.num_qubits:
            adj_matrix[0][-1] = vc.num_qubits - len(cg.wires)
        for ind, op in enumerate(cg.operations_in_order):
            ancestors = cg.ancestors_in_order([op])
            descendants = cg.descendants_in_order([op])
            if len(ancestors) == 0:
                if op.name in ['CNOT', 'CZ', 'CY']:
                    adj_matrix[0][ind+1] = 2
                else:
                    adj_matrix[0][ind+1] = 1
            else:
                if op.name in ['CNOT', 'CZ', 'CY']:
                    count = 0
                    control_flag = False
                    target_flag = False
                    for ancestor in ancestors[::-1]:
                        same_wires = set(ancestor.wires) & set(op.wires)
                        if op.control_wires[0] in same_wires and control_flag == True:
                            continue
                        if op.target_wires[0] in same_wires and target_flag == True:
                            continue
                        wires = set()
                        wires.update(same_wires)
                        if count < 2:
                            if ancestor.name in ['CNOT', 'CZ', 'CY']:
                                path_nodes = cg.nodes_between(ancestor, op)
                                if op in path_nodes:
                                    path_nodes.remove(op)
                                if ancestor in path_nodes:
                                    path_nodes.remove(ancestor)
                                # Directly adjacent and sharing both wires -> double edge.
                                if len(path_nodes) == 0 and len(wires) == 2:
                                    adj_matrix[reverse_index(op_list, ancestor, 0, ind+1)][ind+1] = 2
                                    count += 2
                                else:
                                    adj_matrix[reverse_index(op_list, ancestor, 0, ind+1)][ind+1] = 1
                                    count += 1
                            else:
                                adj_matrix[reverse_index(op_list, ancestor, 0, ind+1)][ind+1] = 1
                                count += 1
                            if op.control_wires[0] in wires:
                                control_flag = True
                            if op.target_wires[0] in wires:
                                target_flag = True
                    if count == 1:
                        adj_matrix[0][ind+1] = 1
                else:
                    direct_ancestor = ancestors[-1]
                    adj_matrix[reverse_index(op_list, direct_ancestor, 0, ind+1)][ind+1] = 1
                    
            if op.name in ['CNOT', 'CZ', 'CY']:
                wires = set()
                for descendant in descendants:
                    wires.update(set(descendant.wires) & set(op.wires))
                    if isinstance(descendant, qml.measurements.StateMP):
                        if len(descendants) == 1:
                            adj_matrix[op_list.index(op, ind+1)][-1] = 2
                        else:
                            adj_matrix[op_list.index(op, ind+1)][-1] = 1
                    if len(wires) == 2:
                        break
            else:
                if isinstance(descendants[0], qml.measurements.StateMP):
                    adj_matrix[op_list.index(op, ind+1)][-1] = 1 
        return adj_matrix

    # Image encoding: "Neural Predictor based Quantum Architecture Search" -> https://arxiv.org/pdf/2103.06524
    def get_image_imatrix(self, circuit_list):
        """Return the ``(max_depth, num_qubits)`` image encoding of a circuit.

        Each gate is placed in the earliest free layer of its qubit(s), and
        the cell holds the gate's ``gate_dict`` code (0 = empty). A two-qubit
        gate occupies the same layer on both qubits: the later of the two
        qubits' next free layers. 'START'/'END' markers are skipped.

        Raises:
            IndexError: If the circuit needs more than ``max_depth`` layers.
            KeyError: If a gate name is not in ``gate_dict``.
        """
        pos_flag = np.zeros(self.num_qubits, dtype=int)  # next free layer per qubit
        image_matrix = np.zeros((self.max_depth, self.num_qubits), dtype=int)
        for param in circuit_list:
            if self._is_boundary_marker(param):
                continue
            code = self.gate_dict[param[0]]
            if param[0] in self._TWO_QUBIT_GATES:
                control, target = param[1], param[2]
                layer = max(pos_flag[control], pos_flag[target])
                image_matrix[layer, control] = code
                image_matrix[layer, target] = code
                pos_flag[control] = pos_flag[target] = layer + 1
            else:
                qubit = param[1]
                image_matrix[pos_flag[qubit], qubit] = code
                pos_flag[qubit] += 1
        return image_matrix

    def get_improved_image_representation(self, circuit_list):
        """Return a ``(max_depth, num_qubits, 3)`` image encoding of a circuit.

        Layout follows ``get_image_imatrix``. Per cell, the channels are:
          0. the gate's ``gate_dict`` code;
          1. the number of qubits the gate acts on (1 or 2);
          2. for a two-qubit gate's control cell, ``target + 1``; otherwise 0.
        Empty cells are all zero. 'START'/'END' markers are skipped.
        """
        pos_flag = np.zeros(self.num_qubits, dtype=int)  # next free layer per qubit
        representation = np.zeros((self.max_depth, self.num_qubits, 3), dtype=int)
        for param in circuit_list:
            if self._is_boundary_marker(param):
                continue
            code = self.gate_dict[param[0]]
            if param[0] in self._TWO_QUBIT_GATES:
                control, target = param[1], param[2]
                layer = max(pos_flag[control], pos_flag[target])
                representation[layer, control] = (code, 2, target + 1)
                representation[layer, target] = (code, 2, 0)
                pos_flag[control] = pos_flag[target] = layer + 1
            else:
                qubit = param[1]
                representation[pos_flag[qubit], qubit] = (code, 1, 0)
                pos_flag[qubit] += 1
        return representation

    # For quantum graph encoding, gates should have fixed indegree and outdegree.
    def get_degree_condition(self, circuit_list):
        """Return the expected ``(indegree, outdegree)`` lists for every DAG node.

        Node order is START, each gate, END:
          * START has in/out degree 0 / ``vc.num_qubits``;
          * two-qubit gates have 2 / 2;
          * other gates have 1 / 1;
          * END has ``vc.num_qubits`` / 0.
        'START'/'END' markers inside ``circuit_list`` are skipped, not
        counted as gates.
        """
        indegree = [0]
        outdegree = [vc.num_qubits]
        for param in circuit_list:
            if self._is_boundary_marker(param):
                continue
            degree = 2 if param[0] in self._TWO_QUBIT_GATES else 1
            indegree.append(degree)
            outdegree.append(degree)
        indegree.append(vc.num_qubits)
        outdegree.append(0)
        return indegree, outdegree

    # TODO
    def reconstruct_circuit_from_graph(self, gate_matrix, adj_matrix_with_degree):
        """Not implemented yet; currently returns None."""
        pass

    # TODO
    def reconstruct_circuit_from_image(self, improved_image_representation):
        """Not implemented yet; currently returns None."""
        pass

    @property
    def get_num_qubits(self):
        """Number of qubits circuits may act on."""
        return self.num_qubits

    @property
    def get_num_circuits(self):
        """Number of circuits ``generate_circuits`` produces."""
        return self.num_circuits

    @property
    def get_num_gates(self):
        """Number of gates per generated circuit."""
        return self.num_gates

    @property
    def get_max_depth(self):
        """Maximum allowed circuit depth."""
        return self.max_depth

    @property
    def get_gate_dict(self):
        """Mapping from gate name to its 1-based integer code."""
        return self.gate_dict

# transform adjacency matrices into edge-index lists
def transform_adj_to_edge_list(adj_matrix):
    """Convert adjacency matrices into COO edge-index tensors.

    Args:
        adj_matrix: Either a stack of adjacency matrices shaped
            ``(num_types, n, n)`` (e.g. the output of
            ``CircuitManager.get_adj_matrix_group``) or a single ``(n, n)``
            matrix. Nested lists, NumPy arrays and tensors are accepted.

    Returns:
        list[torch.Tensor]: One ``(2, num_edges)`` tensor per matrix. Row 0
        holds source node indices and row 1 destination indices of every
        non-zero entry, in row-major order. A single 2-D input yields a
        one-element list.

    Raises:
        ValueError: If the input is neither 2-D nor 3-D.
    """
    # as_tensor avoids the copy-construct warning torch.tensor emits for tensor inputs.
    adj = torch.as_tensor(adj_matrix, dtype=torch.int)
    if adj.dim() == 2:
        adj = adj.unsqueeze(0)
    if adj.dim() != 3:
        raise ValueError(f"expected a 2-D or 3-D adjacency input, got {adj.dim()}-D")
    return [torch.stack(matrix.nonzero(as_tuple=True), dim=0) for matrix in adj]

# dump circuit features in json file
def data_dumper(circuit_manager: CircuitManager, f_name: str ='data.json', circuits=None):
    """Compute DAG/image features for each circuit and dump them to a JSON file.

    The output is written to ``<current_path>/circuit/data/<f_name>``. The
    directory is created if missing, and the path is built portably rather
    than with Windows-only backslashes.

    Args:
        circuit_manager: Manager used to compute every feature.
        f_name: Output file name.
        circuits: Circuits to dump. When omitted, falls back to the
            module-level ``circuits`` variable set by the ``__main__`` block,
            for backward compatibility.

    Raises:
        ValueError: If no circuits are given and no module-level ``circuits`` exists.
    """
    if circuits is None:
        # Backward compatibility: the original version read this global directly.
        circuits = globals().get('circuits')
        if circuits is None:
            raise ValueError("No circuits to dump: pass `circuits` to data_dumper explicitly.")

    file_path = os.path.join(current_path, 'circuit', 'data', f_name)
    os.makedirs(os.path.dirname(file_path), exist_ok=True)

    circuit_features = []
    for circuit in circuits:
        op_list, gate_matrix = circuit_manager.get_gate_matrix(circuit, type='original')
        _, gate_matrix_improved = circuit_manager.get_gate_matrix(circuit, type='improved')
        indegree, outdegree = circuit_manager.get_degree_condition(circuit)
        adj_matrix = circuit_manager.get_adj_matrix(circuit)
        adj_matrix_with_degree = circuit_manager.get_adj_matrix_with_degree(circuit)
        image_matrix = circuit_manager.get_image_imatrix(circuit)
        improved_image_representation = circuit_manager.get_improved_image_representation(circuit)
        depth = qml.specs(circuit_qnode)(circuit)['resources'].depth
        circuit_features.append({'op_list': op_list, 'depth': depth, 'gate_matrix': gate_matrix, 'improved_gate_matrix': gate_matrix_improved,
                                 'indegree': indegree, 'outdegree': outdegree,
                                 'adj_matrix': adj_matrix.tolist(), 'adj_matrix_with_degree': adj_matrix_with_degree.tolist(),
                                 'image_matrix': image_matrix.tolist(), 'improved_image_representation': improved_image_representation.tolist()})
    with open(file_path, 'w', encoding='utf-8') as file:
        json.dump(circuit_features, file)

if __name__ == '__main__':
    # Seed every RNG this script touches so the generated dataset is reproducible.
    seed = 42
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    circuit_manager = CircuitManager(vc.num_qubits, vc.num_circuits, vc.num_gates, vc.max_depth, vc.allowed_gates)
    circuits = circuit_manager.generate_circuits()
    print("Number of unique circuits generated:", len(circuits))

    first_circuit = circuits[0]
    print("The first curcuit list: ", first_circuit)

    # Compute every representation of the first circuit, then report them in order.
    op_list, gate_matrix = circuit_manager.get_gate_matrix(first_circuit, type='original')
    gate_matrix_improved = circuit_manager.get_gate_matrix(first_circuit, type='improved')[1]
    indegree, outdegree = circuit_manager.get_degree_condition(first_circuit)
    adj_matrix = circuit_manager.get_adj_matrix(first_circuit)
    adj_matrix_with_degree = circuit_manager.get_adj_matrix_with_degree(first_circuit)
    image_matrix = circuit_manager.get_image_imatrix(first_circuit)
    improved_image_representation = circuit_manager.get_improved_image_representation(first_circuit)

    print("The first curcuit info: ")
    report = (
        ("op_list: ", op_list),
        ("gate_matrix", gate_matrix),
        ("improved_gate_matrix", gate_matrix_improved),
        ("indegree", indegree),
        ('outdegree', outdegree),
        ("adj_matrix: ", adj_matrix),
        ("adj_matrix_with_degree:", adj_matrix_with_degree),
        ("image_matrix: ", image_matrix),
        ("improved_image_representation: ", improved_image_representation),
    )
    for label, value in report:
        print(label, value)
    print("depth: ", qml.specs(circuit_qnode)(first_circuit)['resources'].depth)
    print("gate_dict: ", circuit_manager.gate_dict)

    # Visualise the first circuit, then dump features for all circuits.
    fig, ax = qml.draw_mpl(circuit_qnode)(first_circuit)
    plt.show()
    data_dumper(circuit_manager, f_name=f'data_{vc.num_qubits}_qubits.json')