from BA3_loc import *
from tqdm import tqdm
import os.path as osp
import warnings
warnings.filterwarnings("ignore")
import random
import torch
import copy
import itertools
import os  # `import os.path as osp` binds only `osp`, not `os`

# Output directory for the raw generated dataset; created at import time.
data_dir = './data/CRCG-MOTIF/raw/'
os.makedirs(data_dir, exist_ok=True)

def create_motif_star_branch(size, branches, node_feature_mean, std):
    """Build a star motif: node 0 is the hub, nodes 1..size-1 are leaves.

    Each node gets one Gaussian draw stored under both the ``'features'`` and
    ``'feature'`` attributes, plus an ``is_center`` flag. ``branches`` is
    accepted for signature compatibility but not used.

    Returns:
        (G, role_id) where role_id[k] is 0 for the hub and 1 for every leaf.
    """
    graph = nx.Graph()
    roles = []
    for node in range(size):
        hub = node == 0
        feat = np.random.normal(loc=node_feature_mean, scale=std)
        graph.add_node(node, features=feat, is_center=hub)
        graph.nodes[node]['feature'] = feat
        roles.append(0 if hub else 1)
    graph.add_edges_from((0, leaf) for leaf in range(1, size))
    return graph, roles
def create_motif_path_branch(size, branches, node_feature_mean, std):
    """Build a path motif 0 - 1 - ... - (size-1).

    Each node gets one Gaussian draw, stored as a plain list under
    ``'features'`` and as the raw draw under ``'feature'``. ``branches`` is
    accepted for signature compatibility but not used.

    Returns:
        (G, role_id) with every role equal to 1.
    """
    G = nx.Graph()
    role_id = []
    for i in range(size):
        node_features = np.random.normal(node_feature_mean, std)
        # np.asarray so .tolist() also works when the draw is a plain Python
        # float (scalar mean/std); for array draws this is a no-op.
        G.add_node(i, features=np.asarray(node_features).tolist())
        G.nodes[i]['feature'] = node_features
        role_id.append(1)
        if i > 0:
            G.add_edge(i - 1, i)
    return G, role_id
def create_motif_fan_branch(size, branches, node_feature_mean, std):
    """Build a fan motif: a star centred on node 0 plus extra branch edges.

    Nodes are created as in a star: hub 0 with role 0, all others role 1, and
    one Gaussian draw per node stored under both 'features' and 'feature'.
    Then every branch root r in 1..branches-1 is linked to r + k*branches for
    k = 1 .. size//branches - 1.

    Returns:
        (G, role_id)
    """
    graph = nx.Graph()
    roles = []
    for node in range(size):
        hub = node == 0
        feat = np.random.normal(loc=node_feature_mean, scale=std)
        graph.add_node(node, features=feat, is_center=hub)
        graph.nodes[node]['feature'] = feat
        roles.append(0 if hub else 1)

    for leaf in range(1, size):
        graph.add_edge(0, leaf)
    for root in range(1, branches):
        per_branch = size // branches
        for step in range(1, per_branch):
            graph.add_edge(root, step * branches + root)
    return graph, roles
def create_motif_cuspedPolygon_branch(size, branches, node_feature_mean, std):
    """Build a "cusped polygon" motif on nodes 0..size-1.

    With half = size // 2, each first-half node a = i is paired with
    b = (i + half) % size. For every branch index c, both a and b are joined
    to the second-half node (i * branches + c + 1) % half + half; self-loops
    on b are skipped. An extra edge links half-1 and half. If no edge was
    produced, one random edge is added as a fallback.

    Nodes carry no feature attributes.

    Returns:
        (G, role_id) with one role (1) per node.
    """
    G = nx.Graph()
    half = size // 2

    # One draw, not attached to any node; kept so the global NumPy RNG
    # sequence seen by later calls is unchanged.
    np.random.normal(loc=node_feature_mean, scale=std)
    # One role per node, consistent with the other motif builders.
    role_id = [1] * size

    G.add_nodes_from(range(size))

    # The branch loop must run for every i (it used to sit after the i-loop,
    # wiring only the last i and failing with NameError for size < 2).
    for i in range(half):
        node_a = i
        node_b = (i + half) % size
        for b in range(branches):
            next_node = (i * branches + b + 1) % half + half
            G.add_edge(node_a, next_node)
            if next_node != node_b:
                G.add_edge(node_b, next_node)
    if half > 0 and branches > 0:
        G.add_edge(half - 1, half)

    if G.number_of_edges() == 0 and G.number_of_nodes() > 0:
        node_a = random.choice(list(G.nodes()))
        node_b = random.choice(list(G.nodes()))
        G.add_edge(node_a, node_b)
    return G, role_id
def create_motif_random_bipartite_branch(size, branches, node_feature_mean, std):
    """Build a random bipartite motif.

    Nodes 0..size//2-1 form part 1 and the remaining nodes form part 2. Each
    cross pair is linked independently with probability 0.5. Each node stores
    one Gaussian draw under ``'feat'``. ``branches`` is unused.

    If no edge was sampled, one fallback edge is added between a random node
    of each part, so the motif stays bipartite. With a single node, only a
    self-loop is possible.

    Returns:
        (G, role_id) with role_id[k] == k.
    """
    G = nx.Graph()
    role_id = []
    for i in range(size):
        G.add_node(i, feat=np.random.normal(node_feature_mean, std))
        role_id.append(i)

    num_nodes_per_part = size // 2
    part1 = list(range(num_nodes_per_part))
    part2 = list(range(num_nodes_per_part, size))
    for i in part1:
        for j in part2:
            if np.random.rand() < 0.5:
                G.add_edge(i, j)

    if G.number_of_edges() == 0 and G.number_of_nodes() > 0:
        if part1 and part2:
            # Pick one endpoint per part so the fallback edge is bipartite.
            G.add_edge(random.choice(part1), random.choice(part2))
        else:
            node_a = random.choice(list(G.nodes()))
            node_b = random.choice(list(G.nodes()))
            G.add_edge(node_a, node_b)
    return G, role_id
def create_motif_tree_branch(size, branches, node_feature_mean, std):
    """Build a random tree motif on nodes 0..size-1.

    Node k >= 1 is attached to a parent drawn uniformly from
    randint(max(0, k - branches), k - 1), i.e. one of the previous
    ``branches`` ids. Each node stores one Gaussian draw under ``'feature'``.

    Returns:
        (G, role_id) with role_id[k] == k.

    Raises:
        AssertionError: if ``branches`` is not positive.
    """
    assert branches > 0, "Number of branches must be greater than 0 for tree motif."
    tree = nx.Graph()
    roles = []
    for node in range(size):
        tree.add_node(node, feature=np.random.normal(node_feature_mean, std))
        roles.append(node)
        if node == 0:
            continue
        parent = random.randint(max(0, node - branches), node - 1)
        tree.add_edge(parent, node)
    return tree, roles
def create_motif_trident_branch(size, branches, node_feature_mean, std):
    """Build a trident motif made of ``branches`` three-node forks.

    Nodes 0..size-1 each get one Gaussian draw under ``'feature'``. Branch k
    uses nodes 3k (root), 3k+1 and 3k+2. The root is linked to both ends, and
    for k > 0 each end is also linked to the matching end of branch k-1.

    Returns:
        (G, role_id) with role_id[j] == j for the ``size`` created nodes.
    """
    G = nx.Graph()
    role_id = []
    for i in range(size):
        G.add_node(i, feature=np.random.normal(node_feature_mean, std))
        role_id.append(i)
    for i in range(branches):
        start_node = 3 * i
        end_node_1 = start_node + 1
        end_node_2 = start_node + 2
        G.add_edge(start_node, end_node_1)
        G.add_edge(start_node, end_node_2)
        if i > 0:
            G.add_edge(end_node_1 - 3, end_node_1)
            G.add_edge(end_node_2 - 3, end_node_2)
    # NOTE(review): if 3 * branches > size, networkx silently creates the
    # missing endpoints without a 'feature' attribute or role entry.
    # Return only after every branch has been built.
    return G, role_id
def create_motif_conicalConnection_branch(size, branches, node_feature_mean, std):
      """Build a "conical connection" motif around the middle node ``size // 2``.

      Nodes 0..size-1 each get a Gaussian draw under 'feature'. Roles are 1
      for the middle node, 2 for nodes before it and 3 for nodes after it.
      Every other node is joined to the middle node and to its neighbour
      towards the middle. Then, for each branch b, edges are added between
      mid_b = mid + b + 1 and node ids offset by ``size``.

      NOTE(review): the branch part looks unfinished (see the notes in the
      loop); confirm the intended construction before relying on it.

      Returns:
          (G, role_id)
      """
      G = nx.Graph()
      role_id = []
      mid = size // 2
      for i in range(size):
          if i == mid:
              G.add_node(i, feature=np.random.normal(node_feature_mean, std))
              role_id.append(1)
          else:
              G.add_node(i, feature=np.random.normal(node_feature_mean, std))
              if i < mid:
                  role_id.append(2)
              else:
                  role_id.append(3)
      # Spokes to the middle node plus one edge towards the middle per node.
      for i in range(size):
          if i == mid:
              continue
          if i < mid:
              G.add_edge(mid, i)
              G.add_edge(i, i + 1)
          else:
              G.add_edge(mid, i)
              G.add_edge(i, i - 1)
      for b in range(branches):
          # mid_b lies in the original 0..size-1 id range when mid + b + 1 < size.
          mid_b = mid + b + 1
          # NOTE(review): every iteration re-adds the same node id size + b
          # (overwriting its 'feature') while appending one role each time,
          # so role_id grows by `size` per branch but only one node is added.
          for i in range(size):
              if i == mid_b:
                  G.add_node(size + b, feature=np.random.normal(node_feature_mean, std))
                  role_id.append(4)
              else:
                  G.add_node(size + b, feature=np.random.normal(node_feature_mean, std))
                  role_id.append(5)
          # NOTE(review): ids i + size / i +- 1 + size that do not exist yet
          # are created by networkx without a 'feature' attribute.
          for i in range(size):
              if i == mid_b:
                  continue
              if i < mid_b:
                  G.add_edge(mid_b, i + size)
                  G.add_edge(i + size, i + 1 + size)
              else:
                  G.add_edge(mid_b, i + size)
                  G.add_edge(i + size, i - 1 + size)
          G.add_edge(mid, mid_b)
      return G,role_id
def create_motif_chainBypass_branch(size, branches, node_feature_mean, std):
    """Build a chain 0-1-...-(size-1) with ``branches`` bypass hubs.

    Each node stores one Gaussian draw under ``'features'``. Starting at
    node 1, the chain is cut into consecutive segments of (size - 2) //
    branches nodes. Every node of segment k is linked to node
    size - 2*branches + k.

    Returns:
        (G, role_id) with role_id[j] == j.

    Raises:
        ZeroDivisionError: if ``branches`` is 0.
    """
    graph = nx.Graph()
    for node in range(size):
        graph.add_node(node, features=np.random.normal(node_feature_mean, std))
    roles = list(range(size))

    graph.add_edges_from(zip(range(size - 1), range(1, size)))

    segment = (size - 2) // branches
    for k in range(branches):
        hub = size - 2 * branches + k
        first = 1 + k * segment
        for member in range(first, first + segment):
            graph.add_edge(member, hub)
    return graph, roles
def create_motif_partPolygon_branch(size, branches, node_feature_mean, std):
    """Build a "part polygon" motif on nodes 0..size-1.

    Nodes form a path 0-1-...-(size-1), and the last node is also linked to
    every interior node 1..size-2. Nodes carry no feature attributes.

    Returns:
        (G, role_id) where role_id is a (size, branches) float array. Row 0
        is all 1, the last row is all 2 and the interior rows are all 3.
    """
    n_nodes = size
    edges = [(i, i + 1) for i in range(n_nodes - 1)]
    edges += [(n_nodes - 1, i) for i in range(1, n_nodes - 1)]

    G = nx.Graph()
    # Add nodes explicitly so a one-node motif is not returned as an empty
    # graph (no edges exist in that case).
    G.add_nodes_from(range(n_nodes))
    G.add_edges_from(edges)

    # Drawn but not attached; kept so the global NumPy RNG sequence is unchanged.
    np.random.normal(node_feature_mean, std)

    role_id = np.zeros((n_nodes, branches))
    role_id[0] = 1
    role_id[n_nodes - 1] = 2
    role_id[1:n_nodes - 1] = 3
    return G, role_id
def create_motif_completeGraph(size, node_feature_mean, std):
    """Build a complete graph on ``size`` nodes.

    Node i's 'feature' is drawn from N(node_feature_mean[i % len(mean)],
    std[i % len(std)]). Both arguments are indexable sequences that are
    cycled over the nodes.

    Returns:
        (G, role_id) with role_id[i] == i.
    """
    graph = nx.complete_graph(size)
    roles = []
    for node in range(size):
        mu = node_feature_mean[node % len(node_feature_mean)]
        sigma = std[node % len(std)]
        graph.nodes[node]["feature"] = np.random.normal(mu, sigma)
        roles.append(node)
    return graph, roles
def create_motif_netShape(size, node_feature_mean, std):
    """Build an n x n grid with n = int(sqrt(size)).

    Node (row, col) has id row*n + col and one Gaussian draw under
    'feature'. Each node is linked to the node above it and to the node on
    its left. The node count n*n may be smaller than ``size``.

    Returns:
        (G, role_id) where role_id lists each node's row index.
    """
    grid = nx.Graph()
    rows = []
    side = int(np.sqrt(size))
    for r in range(side):
        for c in range(side):
            idx = r * side + c
            grid.add_node(idx, feature=np.random.normal(node_feature_mean, std))
            rows.append(r)
            if r:
                grid.add_edge(idx, idx - side)
            if c:
                grid.add_edge(idx, idx - 1)
    return grid, rows
def create_motif_dircycle(size, node_feature_mean, std):
    """Build a cycle 0-1-...-(size-1)-0, stored as an undirected nx.Graph.

    Each node gets one Gaussian draw under 'features'.

    Returns:
        (G, role_id) with role_id[i] == i.
    """
    cycle = nx.Graph()
    for node in range(size):
        cycle.add_node(node, features=np.random.normal(loc=node_feature_mean, scale=std))
    cycle.add_edges_from((node, (node + 1) % size) for node in range(size))
    return cycle, list(range(size))
def create_motif_dualRing(size, node_feature_mean, std):
    """Build two rings joined by spokes.

    The inner ring uses nodes 0..n1-1 (n1 = size // 2) and the outer ring
    uses nodes n1..size-1. Inner node i is joined to outer node i + n1. Each
    node gets one Gaussian draw under 'features'.

    Returns:
        (G, role_id) with role_id[i] == i.
    """
    inner = size // 2
    outer = size - inner
    graph = nx.empty_graph(n=inner + outer)
    roles = []
    for node in range(size):
        graph.add_node(node, features=np.random.normal(loc=node_feature_mean, scale=std))
        roles.append(node)
    inner_ring = [(k, (k + 1) % inner) for k in range(inner)]
    outer_ring = [(k + inner, (k + 1) % outer + inner) for k in range(outer)]
    spokes = [(k, k + inner) for k in range(inner)]
    graph.add_edges_from(inner_ring + outer_ring + spokes)
    return graph, roles
def create_motif_triangle(size, node_feature_mean, std):
    """Build a triangular lattice with n = int(sqrt(2*size)) rows.

    Row i holds i+1 nodes. Node (i, j) has id i*(i+1)//2 + j and one Gaussian
    draw under 'feature'. It is linked to its left neighbour (i, j-1) and to
    the previous-row nodes (i-1, j-1) and (i-1, j) when they exist. The node
    count n*(n+1)/2 need not equal ``size``.

    Returns:
        (G, role_id) where role_id lists each node's row index.
    """
    G = nx.Graph()
    role_id = []
    n = int(np.sqrt(2 * size))
    for i in range(n):
        row_start = i * (i + 1) // 2
        prev_row_start = i * (i - 1) // 2
        for j in range(i + 1):
            node_idx = row_start + j
            G.add_node(node_idx, feature=np.random.normal(node_feature_mean, std))
            role_id.append(i)
            if i > 0:
                if j < i:
                    # (i-1, j) exists only for j < i. For j == i the index
                    # would wrap to (i, 0) of the current row.
                    G.add_edge(node_idx, prev_row_start + j)
                if j > 0:
                    G.add_edge(node_idx, prev_row_start + j - 1)
            if j > 0:
                G.add_edge(node_idx, node_idx - 1)
    return G, role_id

def create_motif_ringShape(size, node_feature_mean, std):
    """Build a ring 0-1-...-(size-1)-0 with random chords.

    Distinct unordered node pairs (no self-loops) are sampled with
    ``np.random.randint`` as extra edges until there are ``size`` of them.
    This count is capped at the number of distinct pairs that exist. Nodes
    carry no feature attributes.

    Returns:
        (G, role_id). role_id is 0 everywhere except 1 for node 0 and 2 for
        the last node; a single node ends up with role 2.
    """
    # Radius draws for a 2-D ring layout that is never stored on the graph;
    # kept so the NumPy RNG sequence seen by later calls is unchanged.
    for _ in range(size):
        np.random.normal(node_feature_mean, std)

    ring_edges = [(i, (i + 1) % size) for i in range(size)]

    # For size 1 or 2 fewer than `size` distinct pairs exist, so an
    # uncapped loop would never terminate.
    n_extra = min(size, size * (size - 1) // 2)
    extra_edges = set()
    while len(extra_edges) < n_extra:
        u = np.random.randint(0, size)
        v = np.random.randint(0, size)
        if u == v or (u, v) in extra_edges or (v, u) in extra_edges:
            continue
        extra_edges.add((u, v))

    G = nx.Graph()
    G.add_nodes_from(range(size))
    G.add_edges_from(ring_edges + list(extra_edges))

    # Per-node feature draws, also not stored on the graph (RNG sequence kept).
    for _ in range(size):
        np.random.normal(node_feature_mean, std)

    role_id = [0] * size
    if role_id:
        role_id[0] = 1
        role_id[-1] = 2
    return G, role_id
def create_motif_diamond(size, node_feature_mean, std):
    """Build a chain of ``size`` triangles on 2*size+1 nodes.

    Spine nodes 0..size form a path. Apex node size+k+1 is joined to spine
    nodes k and k+1, and an extra edge links node size to node 2*size.
    Features are sampled per node but not attached to the graph.

    Returns:
        (G, role_id). role_id is 0 except 1 for node 0 and 2 for the last node.
    """
    total = 2 * size + 1
    for _ in range(total):
        np.random.normal(node_feature_mean, std)  # sampled, not attached
    links = []
    for k in range(size):
        apex = size + k + 1
        links.extend([(k, k + 1), (k, apex), (apex, k + 1)])
    links.append((size, size * 2))
    graph = nx.Graph()
    graph.add_nodes_from(range(total))
    graph.add_edges_from(links)
    roles = [0] * total
    roles[0], roles[-1] = 1, 2
    return graph, roles
def create_motif_HShape(size, node_feature_mean, std):
    """Build an "H"-shaped motif on 2*size+1 nodes.

    Nodes 0..size and size..2*size form two paths that share node ``size``.
    Each node k < size is also linked to node size+k+1. Features are sampled
    per node but not attached to the graph.

    Returns:
        (G, role_id). role_id is 0 except 1 for node 0 and 2 for the last node.
    """
    total = 2 * size + 1
    for _ in range(total):
        np.random.normal(node_feature_mean, std)  # sampled, not attached
    links = []
    for k in range(size):
        links += [(k, k + 1), (k, size + k + 1), (size + k, size + k + 1)]
    graph = nx.Graph()
    graph.add_nodes_from(range(total))
    graph.add_edges_from(links)
    roles = [0] * total
    roles[0], roles[-1] = 1, 2
    return graph, roles
def create_motif_wheel(size, node_feature_mean, std):
    """Build a wheel: hub node 0 joined to rim nodes 1..size, which form a cycle.

    Features are sampled per node but not attached to the graph.

    Returns:
        (G, role_id) with role_id[i] == i for all size+1 nodes.
    """
    total = size + 1
    for _ in range(total):
        np.random.normal(node_feature_mean, std)  # sampled, not attached
    links = []
    for rim in range(1, size + 1):
        links.append((0, rim))
        links.append((rim, rim % size + 1))
    graph = nx.Graph()
    graph.add_nodes_from(range(total))
    graph.add_edges_from(links)
    return graph, list(range(total))
def create_motif_hourglass(size, node_feature_mean, std):
    """Build an "hourglass" motif on 2*size+1 nodes.

    Node 0 is joined to nodes 1..size, and node k+1 is joined to node
    k+size+1. Nodes size+1..2*size-1 are joined to the last node 2*size.
    The last layer node is itself node 2*size, so no self-loop is added.
    Features are sampled per node but not attached to the graph.

    Returns:
        (G, role_id) with role_id[i] == i.
    """
    num_nodes = 2 * size + 1
    bottom = num_nodes - 1
    role_id = list(range(num_nodes))
    for _ in range(num_nodes):
        np.random.normal(node_feature_mean, std)  # sampled, not attached
    edges = []
    for i in range(size):
        edges.append((0, i + 1))
        edges.append((i + 1, i + size + 1))
        if i + size + 1 != bottom:
            # For i == size-1 this endpoint *is* the bottom node.
            edges.append((i + size + 1, bottom))
    G = nx.Graph()
    G.add_nodes_from(range(num_nodes))
    G.add_edges_from(edges)
    return G, role_id
def create_motif_DCD(size, node_feature_mean, std):
    """Build a three-row strip on 3*size+1 nodes.

    For each k < size it adds the path edges (k, k+1), (size+k, size+k+1)
    and (2*size+k, 2*size+k+1), plus the diagonals (k, size+k+1) and
    (size+k, 2*size+k+1). Features are sampled per node but not attached.

    Returns:
        (G, role_id) with role_id[i] == i.
    """
    total = 3 * size + 1
    for _ in range(total):
        np.random.normal(node_feature_mean, std)  # sampled, not attached
    links = []
    for k in range(size):
        middle, lower = size + k, 2 * size + k
        links += [
            (k, k + 1),
            (k, middle + 1),
            (middle, middle + 1),
            (middle, lower + 1),
            (lower, lower + 1),
        ]
    graph = nx.Graph()
    graph.add_nodes_from(range(total))
    graph.add_edges_from(links)
    return graph, list(range(total))
def create_motif_Cyclocross(size, node_feature_mean, std):
    """Build two paths 0..size and size..2*size that share node ``size``.

    The edges (size, 0) and (size, 2*size) are also added. Features are
    sampled per node but not attached to the graph.

    Returns:
        (G, role_id) with role_id[i] == i.
    """
    total = 2 * size + 1
    for _ in range(total):
        np.random.normal(node_feature_mean, std)  # sampled, not attached
    links = []
    for k in range(size):
        links.extend(((k, k + 1), (size + k, size + k + 1)))
    links.extend(((size, 0), (size, 2 * size)))
    graph = nx.Graph()
    graph.add_nodes_from(range(total))
    graph.add_edges_from(links)
    return graph, list(range(total))
def create_motif_ladder(size, node_feature_mean, std):
    """Build a ladder with rails 0..size-1 and size..2*size-1 and rungs (k, k+size).

    Features are sampled per node but not attached to the graph.

    Returns:
        (G, role_id) with role_id[i] == i.
    """
    total = 2 * size
    for _ in range(total):
        np.random.normal(node_feature_mean, std)  # sampled, not attached
    rail_a = [(k, k + 1) for k in range(size - 1)]
    rail_b = [(k + size, k + size + 1) for k in range(size - 1)]
    rungs = [(k, k + size) for k in range(size)]
    graph = nx.Graph()
    graph.add_nodes_from(range(total))
    graph.add_edges_from(rail_a + rail_b + rungs)
    return graph, list(range(total))
def create_motif_bowtie(size, node_feature_mean, std):
    """Build a "bowtie" motif.

    Nodes 0 and 1 are joined. Even nodes >= 2 attach to node 0 and odd
    nodes >= 3 attach to node 1. Each node has one Gaussian draw under
    'feature'.

    Returns:
        (G, role_id) where role_id is a float array: 0 for node 0, 1 for
        node 1 and the even nodes, and 2 for the odd nodes >= 3.
    """
    graph = nx.Graph()
    for node in range(size):
        graph.add_node(node, feature=np.random.normal(node_feature_mean, std))
    graph.add_edge(0, 1)
    for node in range(2, size):
        graph.add_edge(node % 2, node)
    roles = np.zeros((size,))
    roles[0] = 0
    roles[1] = 1
    roles[2::2] = 1  # even nodes (attached to node 0)
    roles[3::2] = 2  # odd nodes (attached to node 1)
    return graph, roles
def create_motif_cross(size, node_feature_mean, std):
    """Build a "cross" motif on nodes 0..size-1, with half = size // 2.

    Node ``half`` is joined to node 0 and gets role 1. Every other node i is
    joined to (i + 1) % half and to (i + half) % size; it gets role 2 if
    i < half and role 3 otherwise. Each node gets a 5-sample 'feature' draw.

    Returns:
        (G, role_id) where role_id is an int32 array.
    """
    half = size // 2
    graph = nx.Graph()
    roles = np.zeros(size, dtype=np.int32)
    for node in range(size):
        graph.add_node(node, feature=np.random.normal(node_feature_mean, std, [5]))
        if node == half:
            graph.add_edge(node, 0)
            roles[node] = 1
            continue
        graph.add_edge(node, (node + 1) % half)
        graph.add_edge(node, (node + half) % size)
        roles[node] = 2 if node < half else 3
    return graph, roles



def create_benzene_ring(size, node_feature_mean, std):
    """Build a ring of ``size`` nodes, each with the constant feature [0]*5.

    Node i is joined to node i-1, and the last node closes the ring back to
    node 0. ``node_feature_mean`` and ``std`` are unused.

    Returns:
        (G, role_id) with every role equal to 0.
    """
    ring = nx.Graph()
    for atom in range(size):
        ring.add_node(atom, feature=np.array([0, 0, 0, 0, 0]))
        if atom:
            ring.add_edge(atom, atom - 1)
        if atom == size - 1:
            ring.add_edge(atom, 0)
    return ring, [0] * size
def create_methane(size, node_feature_mean, std):
    """Build a methane-like star: central node 0 bonded to nodes 1..4.

    All five nodes carry the constant feature [1]*5. ``size``,
    ``node_feature_mean`` and ``std`` are unused.

    Returns:
        (G, role_id) with role_id[i] == i.
    """
    n_atoms = 5  # one centre + four bonded neighbours
    G = nx.Graph()
    for i in range(n_atoms):
        G.add_node(i, feature=np.array([1, 1, 1, 1, 1]))
    # Create every bonded node explicitly; otherwise networkx adds node 4
    # implicitly, with no 'feature' attribute and no role entry.
    G.add_edges_from((0, h) for h in range(1, n_atoms))
    return G, list(range(n_atoms))
def create_ethane(size, node_feature_mean, std):
    """Build an ethane-like graph with 8 nodes.

    Node 0 is bonded to nodes 1..4, and node 4 is bonded to nodes 5..7. All
    nodes carry the constant feature [2]*5. ``size``, ``node_feature_mean``
    and ``std`` are unused.

    Returns:
        (G, role_id) with role_id[i] == i.
    """
    n_atoms = 8
    G = nx.Graph()
    for i in range(n_atoms):
        G.add_node(i, feature=np.array([2, 2, 2, 2, 2]))
    # All bond endpoints exist up front, so nodes 6 and 7 are no longer
    # created implicitly without a 'feature' attribute or role entry.
    G.add_edges_from([(0, 1), (0, 2), (0, 3), (0, 4), (4, 5), (4, 6), (4, 7)])
    return G, list(range(n_atoms))
def create_benzoic_acid(size, node_feature_mean, std):
    """Build a benzoic-acid-like motif.

    A 6-node ring (nodes 0..5, role 0) has a 5-node chain (nodes 6..10,
    role 1) attached to ring node 0. Every node carries the constant feature
    [3]*5. ``size``, ``node_feature_mean`` and ``std`` are unused.

    Returns:
        (G, role_id)
    """
    G = nx.Graph()
    role_id = []
    benzene_ring = list(range(6))
    carboxyl_group = list(range(6, 11))

    for node_idx in benzene_ring:
        G.add_node(node_idx, feature=np.array([3, 3, 3, 3, 3]))
        role_id.append(0)
    for k, node_idx in enumerate(benzene_ring):
        G.add_edge(node_idx, benzene_ring[(k + 1) % len(benzene_ring)])

    for node_idx in carboxyl_group:
        G.add_node(node_idx, feature=np.array([3, 3, 3, 3, 3]))
        role_id.append(1)
    # Ring node 0 -> 6 -> 7 -> 8 -> 9 -> 10.
    chain = [benzene_ring[0]] + carboxyl_group
    G.add_edges_from(zip(chain, chain[1:]))
    return G, role_id
def create_nitrobenzene(size, node_feature_mean, std):
    """Build a nitrobenzene-like motif.

    A 6-node ring (nodes 0..5, role 0) has ring node 0 bonded to node 6,
    which is bonded to nodes 7 and 8 (nodes 6..8 have role 1). Every node
    carries the constant feature [4]*5. ``size``, ``node_feature_mean`` and
    ``std`` are unused.

    Returns:
        (G, role_id)
    """
    G = nx.Graph()
    role_id = []
    benzene_ring = list(range(6))
    nitro_group = [6, 7, 8]

    for node_idx in benzene_ring:
        G.add_node(node_idx, feature=np.array([4, 4, 4, 4, 4]))
        role_id.append(0)
    for k, node_idx in enumerate(benzene_ring):
        G.add_edge(node_idx, benzene_ring[(k + 1) % len(benzene_ring)])

    for node_idx in nitro_group:
        G.add_node(node_idx, feature=np.array([4, 4, 4, 4, 4]))
        role_id.append(1)
    G.add_edge(benzene_ring[0], nitro_group[0])
    G.add_edge(nitro_group[0], nitro_group[1])
    G.add_edge(nitro_group[0], nitro_group[2])
    return G, role_id
def create_ethanol(size, node_feature_mean, std):
    """Build an ethanol-like motif.

    Nodes 0..2 (role 0) are joined in a 3-cycle, and node 3 (role 1) is
    bonded to node 2. Each node gets its own Gaussian draw under 'feature'.
    ``size`` is unused.

    Returns:
        (G, role_id)
    """
    G = nx.Graph()
    role_id = []
    ethanol_structure = [0, 1, 2]
    for node_idx in ethanol_structure:
        G.add_node(node_idx, feature=np.random.normal(node_feature_mean, std))
        role_id.append(0)
    for k, node_idx in enumerate(ethanol_structure):
        G.add_edge(node_idx, ethanol_structure[(k + 1) % len(ethanol_structure)])

    hydroxyl_group_node = 3
    # Keep the hydroxyl node's own draw; previously it was overwritten with
    # the feature array of backbone node 2.
    G.add_node(hydroxyl_group_node, feature=np.random.normal(node_feature_mean, std))
    role_id.append(1)
    G.add_edge(ethanol_structure[-1], hydroxyl_group_node)
    return G, role_id
def create_thioether(size, node_feature_mean, std):
    """Build a path 0 - 1 - ... - (size-1).

    Every node has role 0 and its own Gaussian draw under 'feature'.

    Returns:
        (G, role_id)
    """
    chain = nx.Graph()
    for atom in range(size):
        chain.add_node(atom, feature=np.random.normal(node_feature_mean, std))
    chain.add_edges_from((atom, atom + 1) for atom in range(size - 1))
    return chain, [0] * size
def create_simplified_dopamine(size, node_feature_mean, std):
    """Build a simplified dopamine-like motif on node ids 1..11 (there is no node 0).

    Nodes 1..6 form a ring. Nodes 7, 8, 9 and 10 hang off nodes 1, 2, 4 and
    6, and node 11 hangs off node 10. Each node stores a Gaussian 'feature'
    draw and a ``role_id`` attribute of 0. ``size`` is unused.

    Returns:
        (G, roles) where roles lists the per-node ``role_id`` attributes in
        node order.
    """
    G = nx.Graph()
    for atom in range(1, 12):
        G.add_node(atom, feature=np.random.normal(node_feature_mean, std), role_id=0)
    ring = [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 1)]
    pendants = [(1, 7), (2, 8), (4, 9), (6, 10), (10, 11)]
    G.add_edges_from(ring + pendants)
    roles = [attrs['role_id'] for _, attrs in G.nodes(data=True)]
    return G, roles
def create_hexamethylbenzene(size, node_feature_mean, std):
    """Build a hexamethylbenzene-like motif.

    Nodes 0..5 (role 0) form a 6-cycle. Each node 6..size-1 (role 1) is
    bonded to ring node i % 6. Every node gets its own Gaussian draw under
    'feature'.

    Returns:
        (G, role_id)
    """
    G = nx.Graph()
    role_id = []
    ring = list(range(6))
    for node_idx in ring:
        G.add_node(node_idx, feature=np.random.normal(node_feature_mean, std))
        role_id.append(0)
    for k in range(len(ring) - 1):
        G.add_edge(ring[k], ring[k + 1])
    G.add_edge(ring[-1], ring[0])

    for node_idx in range(6, size):
        G.add_node(node_idx, feature=np.random.normal(node_feature_mean, std))
        role_id.append(1)
        G.add_edge(node_idx, ring[node_idx % 6])
    return G, role_id
def create_acetic_acid(size, node_feature_mean, std):
    """Build a simplified acetic-acid-like motif on 6 nodes.

    Carbons 0-1 have role 0, hydrogens 2-3 role 1 and oxygens 4-5 role 2.
    Each node gets its own Gaussian 'feature' draw. ``size`` is unused.

    Returns:
        (G, role_id)
    """
    G = nx.Graph()
    roles = []
    carbons, hydrogens, oxygens = [0, 1], [2, 3], [4, 5]
    for role, group in enumerate((carbons, hydrogens, oxygens)):
        for atom in group:
            G.add_node(atom, feature=np.random.normal(node_feature_mean, std))
            roles.append(role)
    G.add_edges_from([
        (carbons[0], hydrogens[0]),
        (carbons[0], carbons[1]),
        (carbons[1], oxygens[0]),
        (carbons[1], oxygens[1]),
        (oxygens[0], hydrogens[1]),
    ])
    return G, roles
def create_ammonia(size, node_feature_mean, std):
    """Build an ammonia-like star: node 0 (role 0) bonded to nodes 1..3 (role 1).

    Each node gets its own Gaussian 'feature' draw. ``size`` is unused.

    Returns:
        (G, role_id)
    """
    G = nx.Graph()
    roles = []
    for atom in range(4):
        G.add_node(atom, feature=np.random.normal(node_feature_mean, std))
        roles.append(0 if atom == 0 else 1)
    G.add_edges_from((0, h) for h in range(1, 4))
    return G, roles
def create_vitamin_c(size, node_feature_mean, std):
    """Build a vitamin-C-like motif on 18 nodes.

    Carbons 0-5 have role 0, hydrogens 6-11 role 1 and oxygens 12-17 role 2.
    Each node gets its own Gaussian 'feature' draw. ``size`` is unused.

    Returns:
        (G, role_id)
    """
    G = nx.Graph()
    roles = []
    carbon = list(range(0, 6))
    hydrogen = list(range(6, 12))
    oxygen = list(range(12, 18))
    for role, group in enumerate((carbon, hydrogen, oxygen)):
        for atom in group:
            G.add_node(atom, feature=np.random.normal(node_feature_mean, std))
            roles.append(role)
    G.add_edges_from([
        (carbon[0], carbon[1]), (carbon[1], carbon[2]), (carbon[2], carbon[3]),
        (carbon[2], carbon[4]), (carbon[3], carbon[5]),
        (carbon[0], oxygen[0]),
        (carbon[1], hydrogen[0]), (carbon[1], hydrogen[1]),
        (carbon[3], hydrogen[2]), (carbon[3], hydrogen[3]),
        (carbon[4], hydrogen[4]), (carbon[5], hydrogen[5]),
        (carbon[5], oxygen[1]),
        (oxygen[1], oxygen[2]), (oxygen[1], oxygen[3]),
        (oxygen[2], oxygen[4]), (oxygen[2], oxygen[5]),
    ])
    return G, roles
def create_adrenaline(size, node_feature_mean, std):
    """Build an adrenaline-like motif on 19 nodes.

    Carbons 0-5 have role 0, hydrogens 6-12 role 1, nitrogens 13-14 role 2
    and oxygens 15-18 role 3. Each node gets its own Gaussian 'feature'
    draw. ``size`` is unused.

    Returns:
        (G, role_id)
    """
    G = nx.Graph()
    roles = []
    carbon = list(range(0, 6))
    hydrogen = list(range(6, 13))
    nitrogen = list(range(13, 15))
    oxygen = list(range(15, 19))
    for role, group in enumerate((carbon, hydrogen, nitrogen, oxygen)):
        for atom in group:
            G.add_node(atom, feature=np.random.normal(node_feature_mean, std))
            roles.append(role)
    G.add_edges_from([
        (carbon[0], carbon[1]), (carbon[1], carbon[2]), (carbon[1], carbon[3]),
        (carbon[2], carbon[4]), (carbon[2], carbon[5]), (carbon[5], carbon[3]),
        (carbon[5], oxygen[0]), (carbon[5], oxygen[1]),
        (carbon[0], hydrogen[0]), (carbon[1], hydrogen[1]),
        (carbon[3], hydrogen[2]), (carbon[4], hydrogen[3]),
        (carbon[4], oxygen[2]), (carbon[4], oxygen[3]),
        (carbon[5], hydrogen[4]),
        (nitrogen[0], carbon[0]), (nitrogen[0], hydrogen[5]),
        (nitrogen[1], carbon[0]), (nitrogen[1], carbon[1]),
        (nitrogen[1], carbon[2]), (nitrogen[1], hydrogen[6]),
    ])
    return G, roles
def create_glucose(size, node_feature_mean, std):
    """Build a glucose-like motif on 24 nodes.

    Carbons 0-5 (role 0) form a chain. Carbon k carries oxygen 6+k (role 1)
    and hydrogens 12+2k and 13+2k (role 2). Each node gets its own Gaussian
    'feature' draw. ``size`` is unused.

    Returns:
        (G, role_id)
    """
    G = nx.Graph()
    roles = []
    carbon = list(range(0, 6))
    oxygen = list(range(6, 12))
    hydrogen = list(range(12, 24))
    for role, group in enumerate((carbon, oxygen, hydrogen)):
        for atom in group:
            G.add_node(atom, feature=np.random.normal(node_feature_mean, std))
            roles.append(role)
    for k, c in enumerate(carbon):
        if k < len(carbon) - 1:
            G.add_edge(c, carbon[k + 1])
        G.add_edge(c, oxygen[k])
        G.add_edge(c, hydrogen[2 * k])
        G.add_edge(c, hydrogen[2 * k + 1])
    return G, roles
def create_fullerenes(size, node_feature_mean, std):
    """Build a plain 60-node cycle (role 0 everywhere) as a fullerene stand-in.

    Only ring edges are added, not the C60 cage. Each node gets its own
    Gaussian 'feature' draw. ``size`` is unused.

    Returns:
        (G, role_id)
    """
    n_atoms = 60
    G = nx.Graph()
    for atom in range(n_atoms):
        G.add_node(atom, feature=np.random.normal(node_feature_mean, std))
    G.add_edges_from((atom, (atom + 1) % n_atoms) for atom in range(n_atoms))
    return G, [0] * n_atoms
def create_pyridine(size, node_feature_mean, std):
    """Build a 7-node pyridine-like motif.

    Nodes 0-5 are carbons (role 0) arranged in a 6-cycle; node 6 is a
    nitrogen (role 1) attached to carbon 0. Each node carries a
    ``'feature'`` attribute drawn from
    ``np.random.normal(node_feature_mean, std)``.

    ``size`` is accepted for interface compatibility and is ignored.

    Returns:
        (G, role_id).
    """
    role_id = [0] * 6 + [1]
    G = nx.Graph()
    G.add_nodes_from(
        (atom, {'feature': np.random.normal(node_feature_mean, std)})
        for atom in range(len(role_id))
    )
    G.add_edge(0, 6)            # nitrogen first, so edge order stays stable
    nx.add_cycle(G, range(6))   # carbon ring 0-1-2-3-4-5-0
    return G, role_id
def create_pyrrole(size, node_feature_mean, std):
    """Build a 6-node pyrrole-like motif.

    Nodes 0-4 are carbons (role 0) arranged in a 5-cycle; node 5 is a
    nitrogen (role 1) attached to carbon 0. Each node carries a
    ``'feature'`` attribute drawn from
    ``np.random.normal(node_feature_mean, std)``.

    ``size`` is accepted for interface compatibility and is ignored.

    Returns:
        (G, role_id).
    """
    role_id = [0] * 5 + [1]
    G = nx.Graph()
    G.add_nodes_from(
        (atom, {'feature': np.random.normal(node_feature_mean, std)})
        for atom in range(len(role_id))
    )
    G.add_edge(0, 5)            # nitrogen pendant on carbon 0
    nx.add_cycle(G, range(5))   # carbon ring 0-1-2-3-4-0
    return G, role_id
def create_indole(size, node_feature_mean, std):
    """Build a 10-node indole-like motif.

    Nodes 0-8 are carbons (role 0) forming the 9-cycle 0-1-...-8-0; node 9
    is a nitrogen (role 1) attached to carbon 0. Each node carries a
    ``'feature'`` attribute drawn from
    ``np.random.normal(node_feature_mean, std)``.

    ``size`` is accepted for interface compatibility and is ignored.

    Returns:
        (G, role_id).
    """
    role_id = [0] * 9 + [1]
    G = nx.Graph()
    G.add_nodes_from(
        (atom, {'feature': np.random.normal(node_feature_mean, std)})
        for atom in range(len(role_id))
    )
    # Pendant nitrogen and ring-closing bond first, then the carbon chain.
    G.add_edges_from([(0, 9), (0, 8)])
    nx.add_path(G, range(9))
    return G, role_id
def create_thiazole(size, node_feature_mean, std):
    """Build a 6-node thiazole-like motif.

    Nodes 0-3 are carbons (role 0) forming the 4-cycle 0-1-2-3-0; node 4 is
    a nitrogen (role 1) on carbon 0 and node 5 a sulfur (role 2) on
    carbon 3. Each node carries a ``'feature'`` attribute drawn from
    ``np.random.normal(node_feature_mean, std)``.

    ``size`` is accepted for interface compatibility and is ignored.

    Returns:
        (G, role_id).
    """
    role_id = [0] * 4 + [1, 2]
    G = nx.Graph()
    G.add_nodes_from(
        (atom, {'feature': np.random.normal(node_feature_mean, std)})
        for atom in range(len(role_id))
    )
    G.add_edges_from([(0, 4), (0, 3), (3, 5)])
    nx.add_path(G, range(4))
    return G, role_id
def create_imidazole(size, node_feature_mean, std):
    """Build a 7-node imidazole-like motif.

    Nodes 0-4 are carbons (role 0) in the 5-cycle 0-1-2-3-4-0; nodes 5 and
    6 are nitrogens (role 1) attached to carbons 1 and 3 respectively. Each
    node carries a ``'feature'`` attribute drawn from
    ``np.random.normal(node_feature_mean, std)``.

    ``size`` is accepted for interface compatibility and is ignored.

    Returns:
        (G, role_id).
    """
    role_id = [0] * 5 + [1] * 2
    G = nx.Graph()
    G.add_nodes_from(
        (atom, {'feature': np.random.normal(node_feature_mean, std)})
        for atom in range(len(role_id))
    )
    nx.add_cycle(G, range(5))
    G.add_edges_from([(1, 5), (3, 6)])
    return G, role_id
def create_pyrimidine(size, node_feature_mean, std):
    """Build an 8-node pyrimidine-like motif.

    Nodes 0-5 are carbons (role 0) forming the 6-cycle 0-1-2-3-4-5-0;
    nodes 6 and 7 are nitrogens (role 1) attached to carbons 0 and 2. Each
    node carries a ``'feature'`` attribute drawn from
    ``np.random.normal(node_feature_mean, std)``.

    ``size`` is accepted for interface compatibility and is ignored.

    Returns:
        (G, role_id).
    """
    role_id = [0] * 6 + [1] * 2
    G = nx.Graph()
    G.add_nodes_from(
        (atom, {'feature': np.random.normal(node_feature_mean, std)})
        for atom in range(len(role_id))
    )
    nx.add_path(G, range(6))
    G.add_edges_from([(0, 6), (2, 7)])
    # Close the six-membered ring. Without this bond the carbons form a
    # path, unlike every other ring builder in this file.
    G.add_edge(5, 0)
    return G, role_id
def create_porphyrin(size, node_feature_mean, std):
    """Build a 24-node porphyrin-like motif: a single 24-cycle of carbons.

    Every node has role 0 and a ``'feature'`` attribute drawn from
    ``np.random.normal(node_feature_mean, std)``. Edges are
    ``(0, 1), ..., (22, 23), (23, 0)``.

    ``size`` is accepted for interface compatibility and is ignored.

    Returns:
        (G, role_id).
    """
    role_id = [0] * 24
    G = nx.Graph()
    G.add_nodes_from(
        (atom, {'feature': np.random.normal(node_feature_mean, std)})
        for atom in range(len(role_id))
    )
    nx.add_cycle(G, range(24))
    return G, role_id
def create_nitrophenol(size, node_feature_mean, std):
    """Build a 7-node nitrophenol-like motif.

    Nodes 0-5 have role 0 and node 6 role 1. All seven nodes form the
    7-cycle 0-1-...-6-0, with an extra chord 6-1. Each node carries a
    ``'feature'`` attribute drawn from
    ``np.random.normal(node_feature_mean, std)``.

    ``size`` is accepted for interface compatibility and is ignored.

    Returns:
        (G, role_id).
    """
    role_id = [0] * 6 + [1]
    G = nx.Graph()
    G.add_nodes_from(
        (atom, {'feature': np.random.normal(node_feature_mean, std)})
        for atom in range(len(role_id))
    )
    nx.add_cycle(G, range(7))
    G.add_edge(6, 1)
    return G, role_id
def create_hydrated_sulfuric_acid(size, node_feature_mean, std):
    """Build an 8-node hydrated-sulfuric-acid-like motif.

    Nodes 0-3 (role 0) form the path 0-1-2-3; nodes 4-7 (role 1) are
    pendants, node ``4 + i`` attached to node ``i``. Each node carries a
    ``'feature'`` attribute drawn from
    ``np.random.normal(node_feature_mean, std)``.

    ``size`` is accepted for interface compatibility and is ignored.

    Returns:
        (G, role_id).
    """
    role_id = [0] * 4 + [1] * 4
    G = nx.Graph()
    G.add_nodes_from(
        (atom, {'feature': np.random.normal(node_feature_mean, std)})
        for atom in range(len(role_id))
    )
    nx.add_path(G, range(4))
    G.add_edges_from((4 + core, core) for core in range(4))
    return G, role_id
def create_methyl_anthranilate(size, node_feature_mean, std):
    """Build an 11-node methyl-anthranilate-like motif.

    Nodes 0-6 have role 0 and nodes 7-10 role 1. Edges:

    * the path 0-1-...-9;
    * 8-0, which closes the 9-cycle 0..8;
    * 9-0 and 10-0.

    Each node carries a ``'feature'`` attribute drawn from
    ``np.random.normal(node_feature_mean, std)``.

    ``size`` is accepted for interface compatibility and is ignored.

    Returns:
        (G, role_id).
    """
    role_id = [0] * 7 + [1] * 4
    G = nx.Graph()
    G.add_nodes_from(
        (atom, {'feature': np.random.normal(node_feature_mean, std)})
        for atom in range(len(role_id))
    )
    nx.add_path(G, range(10))
    G.add_edges_from([(8, 0), (9, 0), (10, 0)])
    return G, role_id
def create_anthracene(size, node_feature_mean, std):
    """Build a 9-node anthracene-like motif: a single 9-cycle of carbons.

    Every node has role 0 and a ``'feature'`` attribute drawn from
    ``np.random.normal(node_feature_mean, std)``.

    ``size`` is accepted for interface compatibility and is ignored.

    Returns:
        (G, role_id).
    """
    role_id = [0] * 9
    G = nx.Graph()
    G.add_nodes_from(
        (atom, {'feature': np.random.normal(node_feature_mean, std)})
        for atom in range(len(role_id))
    )
    nx.add_cycle(G, range(9))
    return G, role_id


# Registry of molecule-like motif builders: key -> (builder, description).
# Keys run 1..26 in listing order (the trailing comment gives each key).
# The builders defined in this section take (size, node_feature_mean, std).
molecular_generators = dict(enumerate([
    (create_benzene_ring, "Benzene ring, adjustable node features"),  # 1
    (create_methane, "Methane, adjustable node features"),  # 2
    (create_ethane, "Ethane, adjustable node features"),  # 3
    (create_benzoic_acid, "Benzoic acid, adjustable node features"),  # 4
    (create_nitrobenzene, "Nitrobenzene, adjustable node features"),  # 5
    (create_ethanol, "Ethanol, adjustable node features"),  # 6
    (create_thioether, "Thioether, adjustable node features"),  # 7
    (create_simplified_dopamine, "Simplified dopamine, adjustable node features"),  # 8
    (create_hexamethylbenzene, "Hexamethylbenzene, adjustable node features"),  # 9
    (create_acetic_acid, "Acetic acid, adjustable node features"),  # 10
    (create_ammonia, "Ammonia, adjustable features"),  # 11
    (create_vitamin_c, "Vitamin C, adjustable features"),  # 12
    (create_adrenaline, "Adrenaline, adjustable features"),  # 13
    (create_glucose, "Glucose, adjustable features"),  # 14
    (create_fullerenes, "Fullerenes, adjustable features"),  # 15
    (create_pyridine, "Pyridine, adjustable features"),  # 16
    (create_pyrrole, "Pyrrole, adjustable features"),  # 17
    (create_indole, "Indole, adjustable features"),  # 18
    (create_thiazole, "Thiazole, adjustable features"),  # 19
    (create_imidazole, "Imidazole, adjustable features"),  # 20
    (create_pyrimidine, "Pyrimidine, adjustable features"),  # 21
    (create_porphyrin, "Porphyrin, adjustable features"),  # 22
    (create_nitrophenol, "Nitrophenol, adjustable features"),  # 23
    (create_hydrated_sulfuric_acid, "Hydrated sulfuric acid, adjustable features"),  # 24
    (create_methyl_anthranilate, "Methyl anthranilate, adjustable features"),  # 25
    (create_anthracene, "Anthracene, adjustable features"),  # 26
], start=1))


# Registry of synthetic motif builders: key -> (builder, description).
# Keys 1-10 point at ``*_branch`` builders; for example
# create_motif_star_branch takes (size, branches, node_feature_mean, std).
# Keys 11-25 point at the non-branch shapes (complete graph, cycle, wheel, ...).
motif_generators = {
    1: (create_motif_star_branch, "Star shape, star node count, branch, adjustable node features"),
    2: (create_motif_path_branch, "Path shape, path node count, branch, adjustable node features"),
    3: (create_motif_fan_branch, "Fan shape, fan node count, branch, adjustable node features"),
    4: (create_motif_cuspedPolygon_branch, "Cusped polygon, cusped polygon node count, branch, adjustable node features"),
    5: (create_motif_random_bipartite_branch, "Random bipartite, random bipartite node count, branch, adjustable node features"),
    6: (create_motif_tree_branch, "Tree shape, tree node count, branch, adjustable node features"),
    7: (create_motif_trident_branch, "Trident shape, trident node count, branch, adjustable node features"),
    8: (create_motif_conicalConnection_branch, "Conical connection graph, conical connection node count, branch, adjustable node features"),
    9: (create_motif_chainBypass_branch, "Chain bypass shape, chain bypass node count, branch, adjustable node features"),
    # NOTE(review): key 10 reuses create_motif_trident_branch (same builder as
    # key 7) although its description says "Partial polygon". It presumably
    # should point at a partial-polygon builder; confirm against the builders
    # defined earlier in the file.
    10: (create_motif_trident_branch, "Partial polygon, partial polygon node count, branch, adjustable node features"),
    11: (create_motif_completeGraph, "Complete graph, complete graph node count, adjustable features"),
    12: (create_motif_dircycle, "Cycle graph, cycle graph node count, adjustable features"),
    13: (create_motif_dualRing, "Dual ring graph, dual ring node count, adjustable features"),
    14: (create_motif_triangle, "Triangle graph, triangle node count, adjustable features"),
    15: (create_motif_ringShape, "Ring shape graph, ring shape node count, adjustable features"),
    16: (create_motif_diamond, "Diamond graph, diamond node count, adjustable features"),
    17: (create_motif_HShape, "H-shape graph, H-shape node count, adjustable features"),
    18: (create_motif_wheel, "Wheel graph, wheel node count, adjustable features"),
    19: (create_motif_hourglass, "Hourglass graph, hourglass node count, adjustable features"),
    20: (create_motif_DCD, "DCD triple diamond graph, DCD triple diamond node count, adjustable features"),
    21: (create_motif_Cyclocross, "Cycle cross graph, cycle cross node count, adjustable features"),
    22: (create_motif_netShape, "Net shape graph, net shape node count, adjustable features"),
    23: (create_motif_ladder, "Ladder graph, ladder node count, adjustable features"),
    24: (create_motif_bowtie, "Bowtie graph, bowtie node count, adjustable features"),
    25: (create_motif_cross, "Cross arm graph, cross arm node count, adjustable features")
}


def include_smaller_graph(G1, G2):
    """Merge two motifs so that the smaller one is absorbed into the larger.

    The arguments are swapped if needed so that ``G1`` has no more nodes
    than ``G2``. The merged graph contains:

    * the union of both graphs' node labels (node attributes are NOT copied);
    * only those edges of ``G1`` whose endpoints are both labels of ``G2``;
    * every edge of ``G2``.

    If the result is disconnected, consecutive connected components are
    chained together by one random edge each.

    Returns:
        G: the merged ``nx.Graph``.
        role_id: one int per node; 0 everywhere except the last (up to) six
            entries, which are drawn uniformly from {1, 2, 3}.
        edge_index: ``torch.long`` tensor of shape (2, E) listing ``G``'s
            edges by node label. If ``G`` has no edges it is a self-loop on
            a random node; if it has no nodes, shape (2, 0).
    """
    if len(G1.nodes()) > len(G2.nodes()):
        G1, G2 = G2, G1

    G = nx.Graph()
    G.add_nodes_from(G2.nodes())
    G.add_nodes_from(G1.nodes())

    # Keep the smaller graph's edges only where they fall inside the larger one.
    G.add_edges_from((u, v) for u, v in G1.edges() if u in G2 and v in G2)
    G.add_edges_from(G2.edges())

    components = list(nx.connected_components(G))
    for left, right in zip(components, components[1:]):
        # random.sample() rejects sets on Python 3.11+; choice() on a list
        # draws one element from the same RNG call.
        u = random.choice(list(left))
        v = random.choice(list(right))
        G.add_edge(u, v)

    role_id = [0] * G.number_of_nodes()
    tail_roles = torch.randint(1, 4, (6,)).tolist()
    # Never assign more tail roles than there are nodes; a plain
    # role_id[-6:] = ... would grow the list on graphs with < 6 nodes.
    tail = min(6, len(role_id))
    if tail:
        role_id[-tail:] = tail_roles[:tail]

    # Edges are emitted with G's own (non-negative integer) node labels.
    edges = list(G.edges())
    if edges:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    elif G.number_of_nodes():
        anchor = random.choice(list(G.nodes()))
        edge_index = torch.tensor([[anchor], [anchor]], dtype=torch.long)
    else:
        edge_index = torch.empty((2, 0), dtype=torch.long)

    return G, role_id, edge_index

def multiple_edges_connection(G1, G2, common_edges):
    """Join two motifs through several cross edges ("entanglement").

    ``common_edges`` is clamped to the smaller of the two edge counts. That
    many random edges are removed from a copy of ``G2``, and the graphs are
    disjoint-unioned: ``G1``'s nodes become ``0..n1-1`` and ``G2``'s become
    ``n1..n1+n2-1``, each in iteration order. For each of ``common_edges``
    randomly sampled edges ``(u, v)`` of ``G1``, a cross edge is added from
    ``u``'s new label to the ``G2`` node at ``v``'s position, provided that
    node exists. Nodes that are still isolated are attached to a random other
    node, and components outside the largest one are attached to it.

    The caller's ``G1`` and ``G2`` are not modified.

    Returns:
        G: the combined ``nx.Graph`` with labels ``0..n1+n2-1``.
        role_id: one int per node; 0 everywhere except the last (up to) six
            entries, which are drawn uniformly from {1, 2, 3}.
        edge_index: ``torch.long`` tensor of shape (2, E) listing ``G``'s
            edges. If ``G`` has no edges it is a self-loop on a random node;
            if it has no nodes, shape (2, 0).
    """
    common_edges = min(common_edges, G1.number_of_edges(), G2.number_of_edges())

    # random.sample() requires a sequence on Python 3.11+ (EdgeView is set-like).
    common_edges1 = random.sample(list(G1.edges()), common_edges)
    common_edges2 = random.sample(list(G2.edges()), common_edges)

    G2 = G2.copy()  # don't mutate the caller's graph
    G2.remove_edges_from(common_edges2)

    n1 = G1.number_of_nodes()
    total = n1 + G2.number_of_nodes()
    G = nx.disjoint_union(G1, G2)

    # disjoint_union relabels nodes by iteration order, so translate the
    # sampled G1 labels into positions before using them as labels of G.
    position1 = {node: i for i, node in enumerate(G1.nodes())}
    for u, v in common_edges1:
        target = position1[v] + n1
        # A target past G2's range would create a new, attribute-less node.
        if target < total:
            G.add_edge(position1[u], target)

    nodes = list(G.nodes())
    isolated_nodes = [n for n in nodes if G.degree(n) == 0]
    for u in isolated_nodes:
        # Skip when no valid partner exists (avoids an endless retry loop).
        if not any(v != u and not G.has_edge(u, v) for v in nodes):
            continue
        v = random.choice(nodes)
        while G.has_edge(u, v) or u == v:
            v = random.choice(nodes)
        G.add_edge(u, v)

    if G.number_of_nodes() and not nx.is_connected(G):
        largest_component = max(nx.connected_components(G), key=len)
        anchors = list(largest_component)
        for u in [n for n in G.nodes() if n not in largest_component]:
            G.add_edge(u, random.choice(anchors))

    role_id = [0] * G.number_of_nodes()
    tail_roles = torch.randint(1, 4, (6,)).tolist()
    # Never assign more tail roles than there are nodes.
    tail = min(6, len(role_id))
    if tail:
        role_id[-tail:] = tail_roles[:tail]

    edges = list(G.edges())
    if edges:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    elif G.number_of_nodes():
        anchor = random.choice(list(G.nodes()))
        edge_index = torch.tensor([[anchor], [anchor]], dtype=torch.long)
    else:
        edge_index = torch.empty((2, 0), dtype=torch.long)

    return G, role_id, edge_index

def multiple_nodes_connection(G1, G2, common_nodes):
    """Join two motifs through paired nodes ("crossing").

    ``common_nodes`` is clamped to the smaller of the two node counts. That
    many random nodes are removed from a copy of ``G2``, and the graphs are
    disjoint-unioned: ``G1`` becomes ``0..n1-1`` and the reduced ``G2``
    becomes ``n1..``, each in iteration order. For each of ``common_nodes``
    randomly sampled ``G1`` nodes at position ``p``, the edge ``p``-``p + n1``
    is added when that ``G2`` node exists. If the result is disconnected,
    consecutive components are chained by one random edge each.

    The caller's ``G1`` and ``G2`` are not modified.

    Returns:
        G: the combined ``nx.Graph``.
        role_id: one int per node; 0 everywhere except the last (up to) six
            entries, which are drawn uniformly from {1, 2, 3}.
        edge_index: ``torch.long`` tensor of shape (2, E) listing ``G``'s
            edges. If ``G`` has no edges it is a self-loop on a random node;
            if it has no nodes, shape (2, 0).
    """
    common_nodes = min(common_nodes, G1.number_of_nodes(), G2.number_of_nodes())

    # random.sample() requires a sequence on Python 3.11+ (NodeView is set-like).
    common_nodes1 = set(random.sample(list(G1.nodes()), common_nodes))
    common_nodes2 = set(random.sample(list(G2.nodes()), common_nodes))

    G2 = G2.copy()  # don't mutate the caller's graph
    G2.remove_nodes_from(common_nodes2)

    n1 = G1.number_of_nodes()
    G = nx.disjoint_union(G1, G2)
    total = G.number_of_nodes()

    # disjoint_union relabels nodes by iteration order; use positions as labels.
    position1 = {node: i for i, node in enumerate(G1.nodes())}
    for node in common_nodes1:
        pos = position1[node]
        if pos + n1 < total:
            G.add_edge(pos, pos + n1)

    if total and not nx.is_connected(G):
        components = list(nx.connected_components(G))
        for left, right in zip(components, components[1:]):
            u = random.choice(list(left))
            v = random.choice(list(right))
            G.add_edge(u, v)

    role_id = [0] * G.number_of_nodes()
    tail_roles = torch.randint(1, 4, (6,)).tolist()
    # Never assign more tail roles than there are nodes.
    tail = min(6, len(role_id))
    if tail:
        role_id[-tail:] = tail_roles[:tail]

    # Labels are already 0..total-1, so edges need no re-indexing.
    edges = list(G.edges())
    if edges:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    elif G.number_of_nodes():
        anchor = random.choice(list(G.nodes()))
        edge_index = torch.tensor([[anchor], [anchor]], dtype=torch.long)
    else:
        edge_index = torch.empty((2, 0), dtype=torch.long)

    return G, role_id, edge_index


def adjacent_connection(G1, G2):
    """Join two motifs through a pair of new bridge nodes ("adjacent").

    The steps are:

    1. Subdivide a random edge ``(a, b)`` of ``G1`` with a new node ``x``,
       and a random edge ``(c, d)`` of ``G2`` with a new node ``y``. ``x``
       and ``y`` are the two labels after the largest label of either graph.
    2. Compose the graphs by label with ``nx.compose``, which merges equal
       labels and lets ``G2``'s attributes take precedence.
    3. Add the edge ``x``-``y``.
    4. Give each label present in both inputs a random ``role_id`` node
       attribute.
    5. Link every node outside the largest component to a random node of
       that component.

    The caller's ``G1`` and ``G2`` are not modified.

    Returns:
        G: the combined ``nx.Graph``.
        role_id: one int per node; 0 everywhere except the last (up to) six
            entries, which are drawn uniformly from {1, 2, 3}.
        edge_index: ``torch.long`` tensor of shape (2, E) listing ``G``'s edges.
        node_features: float array of shape (max_label + 1, d). Row ``i`` is
            node ``i``'s flattened ``'feature'`` attribute, or zeros when the
            node has none (e.g. the bridge nodes).

        If either input has no nodes or no edges, the return value is an
        empty graph, an empty role list, a (2, 0) edge tensor and a (0, 0)
        feature array.
    """
    nodes1 = set(G1.nodes())
    nodes2 = set(G2.nodes())
    if not nodes1 or not nodes2 or not G1.edges() or not G2.edges():
        # Same arity as the normal return, so callers can always unpack 4 values.
        return nx.Graph(), [], torch.empty((2, 0), dtype=torch.long), np.zeros((0, 0))

    common_nodes = nodes1.intersection(nodes2)

    edge1 = random.choice(list(G1.edges()))
    edge2 = random.choice(list(G2.edges()))

    top_label = max(nodes1.union(nodes2))
    new_node1 = top_label + 1
    new_node2 = top_label + 2

    # Work on copies so the caller's graphs are left untouched.
    G1 = G1.copy()
    G2 = G2.copy()

    # Subdivide the chosen edges with the bridge nodes.
    G1.remove_edge(*edge1)
    G1.add_edge(edge1[0], new_node1)
    G1.add_edge(new_node1, edge1[1])
    G2.remove_edge(*edge2)
    G2.add_edge(edge2[0], new_node2)
    G2.add_edge(new_node2, edge2[1])

    G1.add_node(new_node2)
    G2.add_node(new_node1)

    G = nx.compose(G1, G2)
    G.add_edge(new_node1, new_node2)

    for node in common_nodes:
        G.add_node(node, role_id=np.random.randint(low=1, high=len(common_nodes) + 3))

    if not nx.is_connected(G):
        largest_component = max(nx.connected_components(G), key=len)
        anchors = list(largest_component)
        for u in [n for n in G.nodes() if n not in largest_component]:
            G.add_edge(u, random.choice(anchors))

    role_id = [0] * G.number_of_nodes()
    tail_roles = torch.randint(1, 4, (6,)).tolist()
    # Never assign more tail roles than there are nodes.
    tail = min(6, len(role_id))
    if tail:
        role_id[-tail:] = tail_roles[:tail]

    # G always contains the new_node1-new_node2 edge, so this is never empty.
    edge_index = torch.tensor(list(G.edges()), dtype=torch.long).t().contiguous()
    num_nodes = int(edge_index.max()) + 1

    # Place each feature at its node's label. Filling rows in iteration order
    # would misalign them once featureless nodes sit between labelled ones.
    feature_rows = {
        node: np.asarray(data['feature'], dtype=float).flatten()
        for node, data in G.nodes(data=True)
        if 'feature' in data
    }
    dim = len(next(iter(feature_rows.values()))) if feature_rows else 0
    node_features = np.zeros((num_nodes, dim))
    for node, row in feature_rows.items():
        if 0 <= node < num_nodes:
            node_features[node] = row

    return G, role_id, edge_index, node_features

def reindex_graph(graph, start_index=0):
    """Return a copy of ``graph`` whose nodes are consecutive integers.

    Nodes are numbered ``start_index, start_index + 1, ...`` in ``graph``'s
    node iteration order. Node and edge attribute dicts are copied over.

    Returns:
        (new_graph, node_mapping), where ``node_mapping`` maps each original
        node to its new integer label.
    """
    node_mapping = {old: new for new, old in enumerate(graph.nodes(), start=start_index)}

    new_graph = nx.Graph()
    new_graph.add_nodes_from(
        (node_mapping[old], attrs) for old, attrs in graph.nodes(data=True)
    )
    new_graph.add_edges_from(
        (node_mapping[u], node_mapping[v], attrs)
        for u, v, attrs in graph.edges(data=True)
    )
    return new_graph, node_mapping

def feature_connection(graph1, graph2):
    """Place two graphs side by side and join them with one random edge.

    ``graph1`` is relabelled to ``0..n1-1`` and ``graph2`` to
    ``n1..n1+n2-1`` via ``reindex_graph``, keeping node and edge attributes.
    A random node of each part is then connected.

    Returns:
        combined_graph: the joined ``nx.Graph``.
        role_id: one int per node; 0 everywhere except the last (up to) six
            entries, which are drawn uniformly from {1, 2, 3}.
        edge_index: ``torch.long`` tensor of shape (2, E) listing the edges.
        node_features: list of ``'feature'`` values converted to Python
            lists, in node order, for the nodes that have one.

    Raises:
        ValueError: if either graph has no nodes.
    """
    if len(graph1.nodes) == 0 or len(graph2.nodes) == 0:
        raise ValueError("One or both graphs are empty")

    G1_reindexed, _ = reindex_graph(graph1)
    G2_reindexed, _ = reindex_graph(graph2, start_index=len(G1_reindexed.nodes))

    combined_graph = nx.Graph()
    combined_graph.add_nodes_from(G1_reindexed.nodes(data=True))
    combined_graph.add_edges_from(G1_reindexed.edges(data=True))
    combined_graph.add_nodes_from(G2_reindexed.nodes(data=True))
    combined_graph.add_edges_from(G2_reindexed.edges(data=True))

    node1 = random.choice(list(G1_reindexed.nodes))
    node2 = random.choice(list(G2_reindexed.nodes))
    combined_graph.add_edge(node1, node2)

    # np.asarray lets list-valued features through as well as ndarrays.
    node_features = [
        np.asarray(data['feature']).tolist()
        for _, data in combined_graph.nodes(data=True)
        if 'feature' in data
    ]

    role_id = [0] * combined_graph.number_of_nodes()
    tail_roles = torch.randint(1, 4, (6,)).tolist()
    # Never assign more tail roles than there are nodes.
    tail = min(6, len(role_id))
    if tail:
        role_id[-tail:] = tail_roles[:tail]

    # At least the node1-node2 edge exists, so this is never empty.
    edge_index = torch.tensor(list(combined_graph.edges()), dtype=torch.long).t().contiguous()
    return combined_graph, role_id, edge_index, node_features


# Ways of joining two motifs: key -> (connector, description). English
# renderings of the descriptions:
#   1 adjacent     - the two motifs are connected by an edge
#   2 crossing     - the motifs share some vertices; the count is configurable
#   3 entanglement - the motifs are connected by several edges; configurable
#   4 containment  - the motif with fewer vertices is contained in the other
motif_connectors = dict(enumerate([
    (adjacent_connection, "相邻,即两个motif通过一条边连接在一起"),
    (multiple_nodes_connection, "交叉,两个motif间公用一些顶点,公用顶点数可配置"),
    (multiple_edges_connection, "纠缠,即两个motif通过多条边连接在一起,边数可配置"),
    (include_smaller_graph, "包含,两个motif中顶点较少的完全被另一个所包含"),
], start=1))


def generate_graph_dataset(molecular_generators, motif_connectors):
    """Compose one synthetic graph from three randomly chosen motifs.

    RNG draws happen in exactly this order:

    1. Draw a key ``m`` in [1, 10] and a branch count ``a`` in [2, 4]. Build
       motif ``m`` with a size in [5, 20] and ``a`` branches, retrying until
       the result is non-empty.
    2. Set ``k = m + 10`` with probability 0.8, otherwise draw ``k`` in
       [11, 25]. Build motif ``k`` with size ``2 * a``, retrying until
       non-empty.
    3. Join the two motifs with ``motif_connectors[1]``.
    4. Draw ``n`` in [1, 10]. Build motif ``n`` with a size in [5, 20] and
       ``int(0.5 * a)`` branches, retrying until non-empty.
    5. Attach it with ``motif_connectors[r]`` for ``r`` drawn in [2, 4].
       Connectors 2 and 3 get an extra argument of 3.

    NOTE(review): despite the parameter name, keys 1-10 are called with a
    branch-count argument. That matches this file's ``motif_generators``
    (``*_branch`` builders), not ``molecular_generators`` (whose visible
    builders take ``(size, node_feature_mean, std)``). Presumably callers
    pass ``motif_generators``; confirm.

    Returns:
        (G, role_id, label, edge_index) with ``label = (m + a) % 3``.
    """
    def _features():
        return [1.5, 2.0, 1.2, 1.3, 1.8]

    first_key = random.randint(1, 10)
    branches = random.randint(2, 4)
    while True:
        first_motif, _ = molecular_generators[first_key][0](
            random.randint(5, 20), branches, _features(), _features())
        if first_motif:
            break

    second_key = first_key + 10 if random.random() < 0.8 else random.randint(11, 25)
    while True:
        second_motif, _ = molecular_generators[second_key][0](
            2 * branches, _features(), _features())
        if second_motif:
            break

    combined, _, _, _ = motif_connectors[1][0](first_motif, second_motif)

    third_key = random.randint(1, 10)
    while True:
        third_motif, _ = molecular_generators[third_key][0](
            random.randint(5, 20), int(0.5 * branches), _features(), _features())
        if third_motif:
            break

    connector_key = random.randint(2, 4)
    extra_args = (3,) if connector_key in (2, 3) else ()
    G, role_id, edge_index = motif_connectors[connector_key][0](combined, third_motif, *extra_args)

    return G, role_id, (first_key + branches) % 3, edge_index
def generate_Y0():
    """Build a class-0 sample.

    A ``molecular_generators[1]`` (benzene ring) motif and a
    ``molecular_generators[2]`` (methane) motif, each with a size drawn from
    [5, 10], are joined with ``motif_connectors[2]``
    (``multiple_nodes_connection``) using 2 shared nodes.

    Returns:
        (graph, role_id, label, edge_index) with ``label == 0``.
    """
    random.randint(2, 4)  # unused draw, kept so the RNG stream is unchanged

    def _motif(key):
        build = molecular_generators[key][0]
        return build(random.randint(5, 10), [1.5, 2.0, 1.2, 1.3, 1.8], [1.5, 2.0, 1.2, 1.3, 1.8])

    benzene, _ = _motif(1)
    methane, _ = _motif(2)
    graph, role_id, edge_index = motif_connectors[2][0](benzene, methane, 2)
    return graph, role_id, 0, edge_index

def generate_Y1():
    """Build a class-1 sample.

    A ``molecular_generators[1]`` (benzene ring) motif and a
    ``molecular_generators[3]`` (ethane) motif are joined with
    ``motif_connectors[1]`` (``adjacent_connection``). A second ethane motif
    is then attached with ``motif_connectors[2]`` using 2 shared nodes. Each
    motif's size is drawn from [5, 10].

    Returns:
        (graph, role_id, label, edge_index) with ``label == 1``.
    """
    random.randint(2, 4)  # unused draw, kept so the RNG stream is unchanged

    def _motif(key):
        build = molecular_generators[key][0]
        return build(random.randint(5, 10), [1.5, 2.0, 1.2, 1.3, 1.8], [1.5, 2.0, 1.2, 1.3, 1.8])

    benzene, _ = _motif(1)
    ethane, _ = _motif(3)
    pair, _, _, _ = motif_connectors[1][0](benzene, ethane)
    second_ethane, _ = _motif(3)
    graph, role_id, edge_index = motif_connectors[2][0](pair, second_ethane, 2)
    return graph, role_id, 1, edge_index
def generate_Y2():
    """Build a class-2 sample.

    A ``molecular_generators[1]`` (benzene ring) motif and a
    ``molecular_generators[2]`` (methane) motif are joined with
    ``motif_connectors[3]`` (``multiple_edges_connection``) using 2 edges.
    A ``molecular_generators[3]`` (ethane) motif is then attached with
    ``motif_connectors[2]`` using 2 shared nodes. Each motif's size is drawn
    from [5, 10].

    Returns:
        (graph, role_id, label, edge_index) with ``label == 2``.
    """
    random.randint(2, 4)  # unused draw, kept so the RNG stream is unchanged

    def _motif(key):
        build = molecular_generators[key][0]
        return build(random.randint(5, 10), [1.5, 2.0, 1.2, 1.3, 1.8], [1.5, 2.0, 1.2, 1.3, 1.8])

    benzene, _ = _motif(1)
    methane, _ = _motif(2)
    pair, _, _ = motif_connectors[3][0](benzene, methane, 2)
    ethane, _ = _motif(3)
    graph, role_id, edge_index = motif_connectors[2][0](pair, ethane, 2)
    return graph, role_id, 2, edge_index
def generate_Y3():
    """Build a class-3 sample.

    A ``molecular_generators[4]`` (benzoic acid) motif and a
    ``molecular_generators[5]`` (nitrobenzene) motif, each with a size drawn
    from [5, 10], are joined with ``motif_connectors[1]``
    (``adjacent_connection``).

    Returns:
        (graph, role_id, label, edge_index) with ``label == 3``.
    """
    random.randint(2, 4)  # unused draw, kept so the RNG stream is unchanged

    def _motif(key):
        build = molecular_generators[key][0]
        return build(random.randint(5, 10), [1.5, 2.0, 1.2, 1.3, 1.8], [1.5, 2.0, 1.2, 1.3, 1.8])

    benzoic_acid, _ = _motif(4)
    nitrobenzene, _ = _motif(5)
    graph, role_id, edge_index, _ = motif_connectors[1][0](benzoic_acid, nitrobenzene)
    return graph, role_id, 3, edge_index
def generate_Y4():
    """Build a class-4 sample.

    A ``molecular_generators[3]`` (ethane) motif and a
    ``molecular_generators[4]`` (benzoic acid) motif, each with a size drawn
    from [5, 10], are joined with ``motif_connectors[2]``
    (``multiple_nodes_connection``) using 2 shared nodes.

    Returns:
        (graph, role_id, label, edge_index) with ``label == 4``.
    """
    random.randint(2, 4)  # unused draw, kept so the RNG stream is unchanged

    def _motif(key):
        build = molecular_generators[key][0]
        return build(random.randint(5, 10), [1.5, 2.0, 1.2, 1.3, 1.8], [1.5, 2.0, 1.2, 1.3, 1.8])

    ethane, _ = _motif(3)
    benzoic_acid, _ = _motif(4)
    graph, role_id, edge_index = motif_connectors[2][0](ethane, benzoic_acid, 2)
    return graph, role_id, 4, edge_index
def generate_real_dataset():
    """Draw one sample from a uniformly chosen ground-truth class 0-4.

    Returns:
        A 9-tuple ``(G, role_id, label, edge_index, motif1_present, ...,
        motif5_present)``. The five trailing booleans are the motif-presence
        flags hard-coded for the chosen class below.
    """
    # class -> (sample builder, (motif1..motif5 present flags))
    class_table = {
        0: (generate_Y0, (True, True, False, False, False)),
        1: (generate_Y1, (True, False, True, False, True)),
        2: (generate_Y2, (True, True, False, False, True)),
        3: (generate_Y3, (False, False, False, True, True)),
        4: (generate_Y4, (False, False, True, True, False)),
    }
    y = random.choice([0, 1, 2, 3, 4])
    builder, flags = class_table[y]
    G, role_id, label, edge_index = builder()
    return (G, role_id, label, edge_index) + flags

def generate_false_cause_dataset(motif_generators, motif_connectors):
    """Take a real-class sample and attach a spurious motif to it.

    The spurious motif is ``motif_generators[7]``, built with a size drawn
    from [5, 10] and a branch count drawn from [2, 4]. It is joined to the
    sample with ``motif_connectors[1]``. The sample's class label is
    returned unchanged.

    Returns:
        (graph, role_id, label, edge_index).
    """
    (graph, _, label, _,
     _, _, _, _, _) = generate_real_dataset()
    branch_count = random.randint(2, 4)
    build = motif_generators[7][0]
    spurious, _ = build(random.randint(5, 10), branch_count,
                        [1.5, 2.0, 1.2, 1.3, 1.8], [1.5, 2.0, 1.2, 1.3, 1.8])
    graph, role_id, edge_index, _ = motif_connectors[1][0](graph, spurious)
    return graph, role_id, label, edge_index

def generate_false_cause_dataset0(motif_generators, motif_connectors):
    """Take a real-class sample and, with probability 0.9, attach a noise motif.

    The noise motif is ``motif_generators[k]`` for ``k`` drawn from [6, 10],
    built with size 10 and a branch count drawn from [2, 4]. It is joined
    with ``motif_connectors[1]``. The class label is returned unchanged.

    Returns:
        (graph, role_id, label, edge_index).
    """
    (graph, role_id, label, edge_index,
     _, _, _, _, _) = generate_real_dataset()
    branch_count = random.randint(2, 4)
    if random.random() >= 0.9:
        return graph, role_id, label, edge_index

    (motif_key,) = random.sample(range(6, 11), 1)
    build = motif_generators[motif_key][0]
    noise_motif, _ = build(10, branch_count,
                           [1.5, 2.0, 1.2, 1.3, 1.8], [1.5, 2.0, 1.2, 1.3, 1.8])
    graph, role_id, edge_index, _ = motif_connectors[1][0](graph, noise_motif)
    return graph, role_id, label, edge_index

def generate_false_cause_dataset1(molecular_generators, motif_connectors):
    """Attach a spurious motif keyed to the first motif the real graph contains.

    A real graph is drawn with ``generate_real_dataset()``.  Its presence
    flags are checked in order motif1..motif5.  For the first flag equal to
    True, the matching generator builds a motif of random size 5..10
    (motif1 -> ``molecular_generators[6]``, ..., motif5 -> ``[10]``), and
    ``motif_connectors[1][0]`` joins that motif onto the real graph.

    If no flag is set, the real graph is discarded and the result of
    ``generate_false_dataset(molecular_generators, motif_connectors)`` is
    returned instead.

    Returns:
        ``(graph, role_id, label, edge_index)``.
    """
    (graph, role_id, label, edge_index,
     motif1_present, motif2_present, motif3_present,
     motif4_present, motif5_present) = generate_real_dataset()
    # Not used below; drawn only so the sequence of random draws is unchanged.
    a = random.randint(2, 4)

    flags_by_generator = zip(
        range(6, 11),
        (motif1_present, motif2_present, motif3_present,
         motif4_present, motif5_present),
    )
    for generator_index, present in flags_by_generator:
        if present == True:  # noqa: E712 -- same comparison as the original chain
            build = molecular_generators[generator_index][0]
            extra_motif, _extra_roles = build(
                random.randint(5, 10),
                [1.5, 2.0, 1.2, 1.3, 1.8],
                [1.5, 2.0, 1.2, 1.3, 1.8],
            )
            graph, role_id, edge_index, _node_features = motif_connectors[1][0](graph, extra_motif)
            return graph, role_id, label, edge_index

    graph, role_id, label, edge_index = generate_false_dataset(
        molecular_generators, motif_connectors)
    return graph, role_id, label, edge_index

def generate_false_cause_dataset2(molecular_generators, motif_connectors):
    """Attach a spurious motif to a real graph with 20% chance per present motif.

    For motif1..motif5 in order, a present motif triggers a 20% roll.  The
    first successful roll picks generator ``6 + k``.  That generator builds a
    motif of random size 5..10, which ``motif_connectors[1][0]`` joins onto the
    real graph.

    If no roll succeeds, the result of ``generate_false_dataset`` is returned
    instead.  That function in turn calls this one, so the two recurse into
    each other a random number of times.

    Returns:
        ``(graph, role_id, label, edge_index)``.
    """
    (graph, role_id, label, edge_index,
     motif1_present, motif2_present, motif3_present,
     motif4_present, motif5_present) = generate_real_dataset()
    # Not used below; drawn only so the sequence of random draws is unchanged.
    a = random.randint(2, 4)

    flags = (motif1_present, motif2_present, motif3_present,
             motif4_present, motif5_present)
    # The generator is consumed lazily.  The 20% roll is made only for present
    # motifs, and scanning stops at the first success, exactly like an
    # if/elif chain.
    chosen = next(
        (6 + k for k, present in enumerate(flags)
         if present == True and random.random() < 0.2),  # noqa: E712
        None,
    )

    if chosen is None:
        graph, role_id, label, edge_index = generate_false_dataset(
            molecular_generators, motif_connectors)
        return graph, role_id, label, edge_index

    extra_motif, _extra_roles = molecular_generators[chosen][0](
        random.randint(5, 10),
        [1.5, 2.0, 1.2, 1.3, 1.8],
        [1.5, 2.0, 1.2, 1.3, 1.8],
    )
    graph, role_id, edge_index, _node_features = motif_connectors[1][0](graph, extra_motif)
    return graph, role_id, label, edge_index

def generate_false_cause_dataset(molecular_generators, motif_connectors):
    """Like ``generate_false_cause_dataset2`` but with no fallback.

    For motif1..motif5 in order, a present motif triggers a 20% roll.  On the
    first success, generator ``molecular_generators[6..10]`` builds a motif of
    random size 5..10, and ``motif_connectors[1][0]`` joins it onto the real
    graph.  If every roll fails, the real graph from
    ``generate_real_dataset()`` is returned unchanged.

    NOTE(review): this definition reuses the name of an earlier function in
    the module and replaces it when the module loads.

    Returns:
        ``(graph, role_id, label, edge_index)``.
    """
    (graph, role_id, label, edge_index,
     has_m1, has_m2, has_m3, has_m4, has_m5) = generate_real_dataset()
    random.randint(2, 4)  # result unused; drawn to keep the random stream identical

    for gen_idx, present in zip(range(6, 11), (has_m1, has_m2, has_m3, has_m4, has_m5)):
        if present == True and random.random() < 0.2:  # noqa: E712
            motif, _motif_roles = molecular_generators[gen_idx][0](
                random.randint(5, 10),
                [1.5, 2.0, 1.2, 1.3, 1.8],
                [1.5, 2.0, 1.2, 1.3, 1.8],
            )
            graph, role_id, edge_index, _node_features = motif_connectors[1][0](graph, motif)
            break
    return graph, role_id, label, edge_index
def generate_false_cause_dataset3(molecular_generators, motif_connectors):
    """Occasionally (5% per present motif) attach a spurious motif to a real graph.

    For motif1..motif5 in order, a present motif triggers a 5% roll.  On the
    first success, generator ``molecular_generators[6..10]`` is called as
    ``(size, branches, feature_mean, feature_std)``, where ``size`` is a random
    int in 5..10 and ``branches`` a random int in 2..4.
    ``motif_connectors[1][0]`` then joins the resulting motif onto the real
    graph.  If every roll fails, the real graph is returned unchanged.

    Returns:
        ``(graph, role_id, label, edge_index)``.
    """
    (G, role_id, label, edge_index,
     motif1_present, motif2_present, motif3_present,
     motif4_present, motif5_present) = generate_real_dataset()
    a = random.randint(2, 4)

    # Start from the unmodified real graph.  Previously ``graph`` was bound only
    # inside the branches, so whenever no 5% roll succeeded (the common case)
    # the return statement raised UnboundLocalError.
    graph = G
    presence = (motif1_present, motif2_present, motif3_present,
                motif4_present, motif5_present)
    for generator_index, present in zip(range(6, 11), presence):
        # As in the original elif chain, the roll is made only for present
        # motifs, and the first success wins.
        if present and random.random() < 0.05:
            motif, _motif_roles = molecular_generators[generator_index][0](
                random.randint(5, 10),
                a,
                [1.5, 2.0, 1.2, 1.3, 1.8],
                [1.5, 2.0, 1.2, 1.3, 1.8],
            )
            graph, role_id, edge_index, _node_features = motif_connectors[1][0](G, motif)
            break

    return graph, role_id, label, edge_index

def generate_false_dataset(molecular_generators, motif_connectors):
    """Build a noisy graph padded with five random distractor motifs.

    Steps:
      1. Take ``(G, role_id, label, edge_index)`` from
         ``generate_false_cause_dataset2``.  That function can fall back to
         calling this one, so the two recurse into each other a random number
         of times.
      2. Pick five distinct generator indices from 6..26.  Build a size-25
         motif with each, via ``molecular_generators[i][0](25, mean, std)``.
      3. Join the five motifs onto the graph one after another with
         ``motif_connectors[1][0]``.
      4. Perturb the result with ``add_noise(graph, 0, 0.1, 0, 0.1, label)``.

    Returns:
        ``(G_noisy, role_id_noisy, label_noisy, edge_index)``.

    NOTE(review): ``edge_index`` comes from the last connector call, i.e. it is
    computed *before* ``add_noise`` inserts extra edges and nodes.  It may
    therefore not describe ``G_noisy``; confirm whether callers rely on it.
    """
    graph, _role_id, label, edge_index = generate_false_cause_dataset2(
        molecular_generators, motif_connectors)
    random.randint(2, 4)  # result unused; drawn to keep the random stream identical
    chosen = random.sample(range(6, 27), 5)

    # Build every distractor first, then connect them: all generator calls
    # happen before any connector call, which keeps the original order of
    # random draws.
    distractors = []
    for generator_index in chosen:
        motif, _motif_roles = molecular_generators[generator_index][0](
            25, [1.5, 2.0, 1.2, 1.3, 1.8], [1.5, 2.0, 1.2, 1.3, 1.8])
        distractors.append(motif)

    for motif in distractors:
        graph, _roles, edge_index, _node_features = motif_connectors[1][0](graph, motif)

    G_noisy, role_id_noisy, label_noisy = add_noise(graph, 0, 0.1, 0, 0.1, label)
    return G_noisy, role_id_noisy, label_noisy, edge_index

def generate_false_dataset2(molecular_generators, motif_connectors):
    """Attach one random motif (index 6..10) to a cause-2 graph, then add noise.

    Args:
        molecular_generators: indexable collection.  ``molecular_generators[i][0]``
            is called here as ``(size, branches, feature_mean, feature_std)``.
        motif_connectors: ``motif_connectors[1][0](graph, motif)`` must
            return ``(graph, role_id, edge_index, node_features)``.

    Returns:
        ``(G_noisy, role_id_noisy, label_noisy, edge_index)``.  ``edge_index``
        is taken before ``add_noise`` runs.

    NOTE(review): ``generate_false_cause_dataset1``/``2`` call generators 6..10
    with three arguments ``(size, mean, std)``.  This function, like
    ``generate_false_dataset3``, passes four.  Confirm which signature the
    generators actually have.
    """
    # The result is discarded (the next line overwrites it).  The call is kept
    # so the sequence of random draws, and hence seeded datasets, is unchanged.
    generate_real_dataset()
    G, role_id, label, edge_index = generate_false_cause_dataset2(
        molecular_generators, motif_connectors)

    a = random.randint(2, 4)
    numbers = random.sample(range(6, 11), 1)
    # Use this function's own parameter.  The previous body looked up
    # ``motif_generators``, which is not a parameter here and raises NameError
    # unless a global of that name happens to exist.  The former ``else``
    # branch was unreachable because numbers[0] is always in 6..10.
    motifr1, _role_idr1 = molecular_generators[numbers[0]][0](
        10, a, [1.5, 2.0, 1.2, 1.3, 1.8], [1.5, 2.0, 1.2, 1.3, 1.8])
    graph, _role_id1, edge_index, _node_features = motif_connectors[1][0](G, motifr1)
    G_noisy, role_id_noisy, label_noisy = add_noise(graph, 0, 0.1, 0, 0.1, label)
    return G_noisy, role_id_noisy, label_noisy, edge_index

def generate_false_dataset3(molecular_generators, motif_connectors):
    """Attach one random motif (generator index 6..10) to a real graph, without noise.

    ``molecular_generators[i][0]`` is called as
    ``(10, branches, feature_mean, feature_std)``, where ``branches`` is a
    random int in 2..4.  ``motif_connectors[1][0]`` then joins the resulting
    motif onto the graph from ``generate_real_dataset()``.

    Returns:
        ``(graph, role_id, label, edge_index)``.  ``graph``, ``role_id`` and
        ``edge_index`` come from the connector; ``label`` comes from
        ``generate_real_dataset()``.
    """
    (G, _role_id, label, _edge_index,
     _m1, _m2, _m3, _m4, _m5) = generate_real_dataset()

    a = random.randint(2, 4)
    # random.sample(..., 1) is kept (rather than random.choice) so seeded runs
    # draw the same values.
    numbers = random.sample(range(6, 11), 1)
    # numbers[0] is always within 6..10, so the former else-branch (a
    # three-argument generator call) could never run and has been removed.
    motifr1, _role_idr1 = molecular_generators[numbers[0]][0](
        10, a, [1.5, 2.0, 1.2, 1.3, 1.8], [1.5, 2.0, 1.2, 1.3, 1.8])
    graph, role_id1, edge_index, _node_features = motif_connectors[1][0](G, motifr1)

    return graph, role_id1, label, edge_index

def add_noise(G, delete_edge_prob, add_edge_prob, delete_node_prob, add_node_prob, label=None):
    """Return a randomly perturbed deep copy of ``G``; ``G`` itself is not modified.

    Perturbations are applied in this order:

    1. Delete ``int(delete_edge_prob * |E|)`` distinct, randomly chosen edges.
    2. Make ``int(add_edge_prob * |V| * (|V| - 1) / 2)`` attempts to add an
       edge between two random distinct nodes.  Attempts that hit an existing
       edge are skipped, so fewer edges may actually be added.
    3. Delete ``int(delete_node_prob * |V|)`` distinct, randomly chosen nodes.
    4. Add ``int(add_node_prob * |V|)`` new nodes.  Each one is wired to a
       random, non-empty subset of the other nodes.  The subset is re-drawn
       until the whole graph is connected, so a new node may end up bridging
       components.

    Args:
        G: graph to perturb.  It must support ``copy.deepcopy`` and the
            networkx ``Graph`` methods used below.
        delete_edge_prob, add_edge_prob, delete_node_prob, add_node_prob:
            fractions controlling the counts described above.
        label: returned unchanged.

    Returns:
        ``(G_noisy, role_id_noisy, label_noisy)``.  ``role_id_noisy`` is a list
        of random integers in ``[0, 2]``, one per node of ``G_noisy``; the
        values are not derived from ``G``.  ``label_noisy`` is ``label``.
    """
    G_noisy = copy.deepcopy(G)

    # random.sample() requires a sequence, and Python 3.11+ rejects networkx's
    # set-like node/edge views.  list() gives the same order as the implicit
    # tuple() conversion on older Pythons.
    num_edges_to_delete = int(delete_edge_prob * G_noisy.number_of_edges())
    edges_to_delete = random.sample(list(G_noisy.edges()), num_edges_to_delete)
    G_noisy.remove_edges_from(edges_to_delete)

    # Edge insertion never changes the node set, so snapshot it once.
    node_list = list(G_noisy.nodes())
    num_nodes = len(node_list)
    num_edges_to_add = int(add_edge_prob * num_nodes * (num_nodes - 1) / 2)
    for _ in range(num_edges_to_add):
        node1, node2 = random.sample(node_list, 2)
        if not G_noisy.has_edge(node1, node2):
            G_noisy.add_edge(node1, node2)

    num_nodes_to_delete = int(delete_node_prob * G_noisy.number_of_nodes())
    nodes_to_delete = random.sample(list(G_noisy.nodes()), num_nodes_to_delete)
    G_noisy.remove_nodes_from(nodes_to_delete)

    num_nodes_to_add = int(add_node_prob * G_noisy.number_of_nodes())
    for _ in range(num_nodes_to_add):
        # Keep the old id rule (node count + 1) but skip ids already in use.
        # Collisions can happen after node deletion or with non-contiguous
        # labels, and add_node() on an existing id silently reuses that node.
        node_id = G_noisy.number_of_nodes() + 1
        while node_id in G_noisy:
            node_id += 1
        # Candidates exclude the new node itself, so no self-loop is created.
        candidates = list(G_noisy.nodes())
        G_noisy.add_node(node_id)
        if not candidates:
            continue  # a lone node is trivially connected
        while True:
            nodes_to_connect = random.sample(candidates, random.randint(1, len(candidates)))
            G_noisy.add_edges_from((node_id, n) for n in nodes_to_connect)
            if nx.is_connected(G_noisy):
                break
            G_noisy.remove_edges_from((node_id, n) for n in nodes_to_connect)

    # One role id per node.  The old code sized this list by the edge count,
    # and also computed an earlier per-node list that was never used.
    role_id_noisy = [random.randint(0, 2) for _ in range(G_noisy.number_of_nodes())]
    label_noisy = label
    return G_noisy, role_id_noisy, label_noisy
    
    


    
def generate_graph_dataset_with_noise(molecular_generators, motif_connectors, num_samples, delete_edge_prob, add_edge_prob, delete_node_prob, add_node_prob):
    """Generate ``num_samples`` graphs with ``generate_graph_dataset`` and perturb each one.

    Each sample is passed through ``add_noise`` with the given probabilities.

    Args:
        molecular_generators, motif_connectors: forwarded to
            ``generate_graph_dataset``.
        num_samples: number of graphs to produce.  0 yields an empty list.
        delete_edge_prob, add_edge_prob, delete_node_prob, add_node_prob:
            forwarded to ``add_noise``.

    Returns:
        A list of ``(G_noisy, role_id_noisy, label_noisy)`` tuples, one per sample.
    """
    dataset = []
    for _ in range(num_samples):
        G, role_id, label = generate_graph_dataset(molecular_generators, motif_connectors)
        # add_noise's signature is (G, 4 probabilities, label).  The old call
        # put role_id and label into the probability slots and passed 7
        # positional arguments, which raised TypeError on every call.
        G_noisy, role_id_noisy, label_noisy = add_noise(
            G, delete_edge_prob, add_edge_prob, delete_node_prob, add_node_prob, label=label)
        dataset.append((G_noisy, role_id_noisy, label_noisy))
    # Return the accumulated dataset.  The old code returned only the last
    # sample, and raised UnboundLocalError when num_samples == 0.
    return dataset
