

def generate_gamma(mu, sigma, size):
    """Draw gamma samples whose mean and standard deviation match the inputs.

    A gamma distribution with shape ``k`` and scale ``theta`` has mean
    ``k * theta`` and variance ``k * theta**2``; both parameters are derived
    here from the requested moments.

    Args:
        mu: Target mean (used as a divisor).
        sigma: Target standard deviation.
        size: Output shape forwarded to ``gamma.rvs``.

    Returns:
        The samples produced by ``gamma.rvs(a=k, scale=theta, size=size)``.
    """
    variance = np.power(sigma, 2)
    scale = np.divide(variance, mu)  # theta = sigma^2 / mu
    shape = np.divide(mu, scale)  # k = mu / theta
    return gamma.rvs(a=shape, scale=scale, size=size)
def generate_nodes_normal_distributed(num_nodes, mean, std):
    """Sample node features column by column from normal distributions.

    Args:
        num_nodes: Number of nodes (rows) to generate.
        mean: Per-feature means; its length sets the number of features.
        std: Per-feature standard deviations, indexed like ``mean``.

    Returns:
        ``(nodes, role_id)`` where ``nodes`` has shape
        ``(num_nodes, len(mean))`` and ``role_id`` is
        ``[0, 1, ..., num_nodes - 1]``.
    """
    feature_count = len(mean)
    nodes = np.zeros((num_nodes, feature_count))
    for col in range(feature_count):
        column = np.random.normal(loc=mean[col], scale=std[col], size=num_nodes)
        nodes[:, col] = column
    return nodes, list(range(num_nodes))
def generate_nodes_uniform_distributed(num_nodes, mean, std):
    """Sample node features from uniform distributions centred on ``mean``.

    Each feature ``i`` is drawn from the interval
    ``[mean[i] - std[i], mean[i] + std[i])``, i.e. ``std`` acts as the
    half-width of the interval.

    Args:
        num_nodes: Number of nodes (rows) to generate.
        mean: Per-feature interval centres; its length sets the feature count.
        std: Per-feature interval half-widths, indexed like ``mean``.

    Returns:
        ``(nodes, role_id)`` with ``nodes`` of shape ``(num_nodes, len(mean))``
        and ``role_id == list(range(num_nodes))``.
    """
    feature_count = len(mean)
    nodes = np.zeros((num_nodes, feature_count))
    for col in range(feature_count):
        lower = mean[col] - std[col]
        upper = mean[col] + std[col]
        nodes[:, col] = np.random.uniform(low=lower, high=upper, size=num_nodes)
    return nodes, list(range(num_nodes))
def generate_nodes_exponential_distributed(num_nodes, mean, std):
    """Sample node features from scaled exponential distributions.

    Feature ``i`` is an exponential draw with scale ``std[i] / mean[i]``
    multiplied by ``mean[i]``.

    Args:
        num_nodes: Number of nodes (rows) to generate.
        mean: Per-feature multipliers (also divisors of ``std``); its length
            sets the feature count.
        std: Per-feature values, indexed like ``mean``.

    Returns:
        ``(nodes, role_id)`` with ``nodes`` of shape ``(num_nodes, len(mean))``
        and ``role_id == list(range(num_nodes))``.
    """
    feature_count = len(mean)
    nodes = np.zeros((num_nodes, feature_count))
    for col in range(feature_count):
        draws = np.random.exponential(scale=std[col] / mean[col], size=num_nodes)
        nodes[:, col] = draws * mean[col]
    return nodes, list(range(num_nodes))
def generate_nodes_lognormal_distributed(num_nodes, mean, std):
    """Sample node features from log-normal distributions with given moments.

    For each feature the parameters of the underlying normal distribution are
    derived from the requested mean ``m`` and standard deviation ``s``:
    ``mu = log(m^2 / sqrt(s^2 + m^2))`` and ``sigma = sqrt(log(s^2/m^2 + 1))``.

    Args:
        num_nodes: Number of nodes (rows) to generate.
        mean: Per-feature target means; its length sets the feature count.
        std: Per-feature target standard deviations, indexed like ``mean``.

    Returns:
        ``(nodes, role_id)`` with ``nodes`` of shape ``(num_nodes, len(mean))``
        and ``role_id == list(range(num_nodes))``.
    """
    feature_count = len(mean)
    nodes = np.zeros((num_nodes, feature_count))
    for col in range(feature_count):
        m = mean[col]
        s = std[col]
        log_mu = np.log(m**2 / np.sqrt(s**2 + m**2))
        log_sigma = np.sqrt(np.log(s**2 / m**2 + 1))
        nodes[:, col] = np.random.lognormal(log_mu, log_sigma, num_nodes)
    return nodes, list(range(num_nodes))
def generate_nodes_gamma_distributed(num_nodes, mean, std):
    """Sample node features from gamma distributions via ``generate_gamma``.

    Args:
        num_nodes: Number of nodes (rows) to generate.
        mean: Per-feature target means; its length sets the feature count.
        std: Per-feature target standard deviations, indexed like ``mean``.

    Returns:
        ``(nodes, role_id)`` with ``nodes`` of shape ``(num_nodes, len(mean))``
        and ``role_id == list(range(num_nodes))``.
    """
    feature_count = len(mean)
    nodes = np.zeros((num_nodes, feature_count))
    for col in range(feature_count):
        nodes[:, col] = generate_gamma(mean[col], std[col], num_nodes)
    return nodes, list(range(num_nodes))
def generate_nodes_beta_distributed(num_nodes, mean, std):
    """Sample node features from beta distributions with given moments.

    The shape parameters are obtained by the method of moments:
    ``a = ((1 - m) / s^2 - 1 / m) * m^2`` and ``b = a * (1 / m - 1)``.
    They are computed once for all features instead of being recomputed on
    every loop iteration.

    Args:
        num_nodes: Number of nodes (rows) to generate.
        mean: Per-feature target means; its length sets the feature count.
        std: Per-feature target standard deviations, same length as ``mean``.

    Returns:
        ``(nodes, role_id)`` with ``nodes`` of shape ``(num_nodes, len(mean))``
        and ``role_id == list(range(num_nodes))``.
    """
    mean_arr = np.array(mean)
    std_arr = np.array(std)
    # Method-of-moments shape parameters for every feature at once.
    alpha = ((1 - mean_arr) / std_arr ** 2 - 1 / mean_arr) * mean_arr ** 2
    beta = alpha * (1 / mean_arr - 1)

    feature_count = len(mean)
    nodes = np.zeros((num_nodes, feature_count))
    for col in range(feature_count):
        nodes[:, col] = np.random.beta(a=alpha[col], b=beta[col], size=num_nodes)
    return nodes, list(range(num_nodes))
def generate_nodes_weibull_distributed(num_nodes, mean, std):
    """Sample Weibull-shaped node features rescaled to a given mean and std.

    For every feature a Weibull sample is drawn with a random shape and scale
    (both uniform in ``[0.5, 2.0)``), standardised with its own sample mean
    and standard deviation, and then mapped to ``std[i]`` / ``mean[i]``.

    The shape/scale vectors used to be hard-coded to length 5, which raised
    ``IndexError`` for more than five features. They are now sized to the
    feature count.

    Args:
        num_nodes: Number of nodes (rows) to generate.
        mean: Per-feature target means; its length sets the feature count.
        std: Per-feature target standard deviations, indexed like ``mean``.

    Returns:
        ``(nodes, role_id)`` with ``nodes`` of shape ``(num_nodes, len(mean))``
        and ``role_id == list(range(num_nodes))``.
    """
    feature_count = len(mean)
    # Draw at least 5 values so seeded runs with <= 5 features consume the
    # random stream exactly as before.
    param_count = max(feature_count, 5)
    k = np.random.uniform(low=0.5, high=2.0, size=param_count)
    lam = np.random.uniform(low=0.5, high=2.0, size=param_count)

    nodes = np.zeros((num_nodes, feature_count))
    for col in range(feature_count):
        raw = weibull_min.rvs(k[col], scale=lam[col], size=num_nodes)
        standardized = (raw - raw.mean()) / raw.std()
        nodes[:, col] = standardized * std[col] + mean[col]
    return nodes, list(range(num_nodes))
def generate_nodes_laplace_distributed(num_nodes, mean, std):
    """Sample node features from Laplace distributions.

    Feature ``i`` uses location ``mean[i]`` and scale ``std[i] / sqrt(2)``
    (a Laplace distribution with scale ``b`` has variance ``2 * b^2``).

    Args:
        num_nodes: Number of nodes (rows) to generate.
        mean: Per-feature locations; its length sets the feature count.
        std: Per-feature standard deviations, indexed like ``mean``.

    Returns:
        ``(nodes, role_id)`` with ``nodes`` of shape ``(num_nodes, len(mean))``
        and ``role_id == list(range(num_nodes))``.
    """
    feature_count = len(mean)
    nodes = np.zeros((num_nodes, feature_count))
    for col in range(feature_count):
        location = mean[col]
        diversity = std[col] / np.sqrt(2)
        nodes[:, col] = np.random.laplace(location, diversity, size=num_nodes)
    return nodes, list(range(num_nodes))
def generate_nodes_logistic_distributed(num_nodes, mean, std):
    """Sample node features from logistic distributions.

    Feature ``i`` uses location ``mean[i]`` and scale ``std[i] / sqrt(3)``.

    Args:
        num_nodes: Number of nodes (rows) to generate.
        mean: Per-feature locations; its length sets the feature count.
        std: Per-feature values the scale is derived from, indexed like
            ``mean``.

    Returns:
        ``(nodes, role_id)`` with ``nodes`` of shape ``(num_nodes, len(mean))``
        and ``role_id == list(range(num_nodes))``.
    """
    feature_count = len(mean)
    nodes = np.zeros((num_nodes, feature_count))
    for col in range(feature_count):
        location = mean[col]
        spread = std[col] / np.sqrt(3)
        nodes[:, col] = np.random.logistic(location, spread, size=num_nodes)
    return nodes, list(range(num_nodes))
def generate_nodes_rayleigh_distributed(num_nodes, mean, std):
    """Sample node features from Rayleigh distributions.

    Feature ``i`` is drawn with scale ``std[i] / sqrt(2 * pi)``.

    NOTE(review): ``mean`` is only used for its length; the drawn values do
    not depend on the mean values themselves. Confirm this is intended.

    Args:
        num_nodes: Number of nodes (rows) to generate.
        mean: Sequence whose length sets the feature count.
        std: Per-feature values the scale is derived from.

    Returns:
        ``(nodes, role_id)`` with ``nodes`` of shape ``(num_nodes, len(mean))``
        and ``role_id == list(range(num_nodes))``.
    """
    feature_count = len(mean)
    nodes = np.zeros((num_nodes, feature_count))
    for col in range(feature_count):
        rayleigh_scale = std[col] / np.sqrt(2 * np.pi)
        nodes[:, col] = np.random.rayleigh(rayleigh_scale, size=num_nodes)
    return nodes, list(range(num_nodes))
def generate_nodes_pareto_distributed(num_nodes, mean, std):
    """Sample node features with ``np.random.pareto``.

    Feature ``i`` uses the shape parameter ``mean[i] / std[i]``; no further
    shifting or scaling is applied to the draws.

    Args:
        num_nodes: Number of nodes (rows) to generate.
        mean: Per-feature numerators of the shape; its length sets the
            feature count.
        std: Per-feature denominators of the shape, indexed like ``mean``.

    Returns:
        ``(nodes, role_id)`` with ``nodes`` of shape ``(num_nodes, len(mean))``
        and ``role_id == list(range(num_nodes))``.
    """
    feature_count = len(mean)
    nodes = np.zeros((num_nodes, feature_count))
    for col in range(feature_count):
        shape = mean[col] / std[col]
        nodes[:, col] = np.random.pareto(shape, size=num_nodes)
    return nodes, list(range(num_nodes))
def generate_nodes_cauchy_distributed(num_nodes, mean, std):
    """Sample node features from shifted and scaled standard Cauchy draws.

    Feature ``i`` is ``standard_cauchy * (std[i] * sqrt(pi / 2)) + mean[i]``.

    Args:
        num_nodes: Number of nodes (rows) to generate.
        mean: Per-feature location offsets; its length sets the feature count.
        std: Per-feature values the scale is derived from, indexed like
            ``mean``.

    Returns:
        ``(nodes, role_id)`` with ``nodes`` of shape ``(num_nodes, len(mean))``
        and ``role_id == list(range(num_nodes))``.
    """
    feature_count = len(mean)
    nodes = np.zeros((num_nodes, feature_count))
    for col in range(feature_count):
        location = mean[col]
        half_width = std[col] * np.sqrt(np.pi / 2)
        draws = np.random.standard_cauchy(size=num_nodes)
        nodes[:, col] = draws * half_width + location
    return nodes, list(range(num_nodes))
def generate_nodes_neg_binom_distributed(num_nodes, mean, std):
    """Sample node features from negative binomial distributions.

    ``np.random.negative_binomial(n, p)`` has mean ``n * (1 - p) / p`` and
    variance ``n * (1 - p) / p^2``. Solving for a target mean ``m`` and
    variance ``v = std^2`` gives ``p = m / v`` and ``n = m^2 / (v - m)``.
    The previous code passed ``1 - m / (v + m)`` as ``p``, which reproduced
    neither the requested mean nor the requested variance.

    Args:
        num_nodes: Number of nodes (rows) to generate.
        mean: Per-feature target means (must be positive); its length sets
            the feature count.
        std: Per-feature target standard deviations; ``std[i]**2`` must be
            greater than ``mean[i]``.

    Returns:
        ``(nodes, role_id)`` with ``nodes`` of shape ``(num_nodes, len(mean))``
        and ``role_id == list(range(num_nodes))``.

    Raises:
        ValueError: If a mean is not positive or a variance does not exceed
            its mean (no negative binomial has such moments).
    """
    feature_count = len(mean)
    nodes = np.zeros((num_nodes, feature_count))
    for col in range(feature_count):
        target_mean = mean[col]
        variance = std[col] ** 2
        if target_mean <= 0 or variance <= target_mean:
            raise ValueError(
                "negative binomial requires mean > 0 and std**2 > mean "
                f"(feature {col}: mean={target_mean}, std={std[col]})"
            )
        success_prob = target_mean / variance
        num_successes = target_mean ** 2 / (variance - target_mean)
        nodes[:, col] = np.random.negative_binomial(num_successes, success_prob, size=num_nodes)
    return nodes, list(range(num_nodes))
def generate_nodes_gumbel_distributed(num_nodes, mean, std):
    """Sample node features from Gumbel distributions with given moments.

    A Gumbel distribution with location ``loc`` and scale ``beta`` has mean
    ``loc + euler_gamma * beta`` and standard deviation
    ``beta * pi / sqrt(6)``. Hence ``beta = sqrt(6) * std / pi`` and
    ``loc = mean - euler_gamma * beta``. The previous code subtracted
    ``euler_gamma * std`` (not ``* beta``), which biased the sample mean.

    Args:
        num_nodes: Number of nodes (rows) to generate.
        mean: Per-feature target means; its length sets the feature count.
        std: Per-feature target standard deviations, indexed like ``mean``.

    Returns:
        ``(nodes, role_id)`` with ``nodes`` of shape ``(num_nodes, len(mean))``
        and ``role_id == list(range(num_nodes))``.
    """
    feature_count = len(mean)
    nodes = np.zeros((num_nodes, feature_count))
    for col in range(feature_count):
        scale = np.sqrt(6) * std[col] / np.pi
        loc = mean[col] - np.euler_gamma * scale
        nodes[:, col] = np.random.gumbel(loc=loc, scale=scale, size=num_nodes)
    return nodes, list(range(num_nodes))
def generate_nodes_gompertz_distributed(num_nodes, mean, std):
    """Sample node features with ``gompertz.rvs``.

    Feature ``i`` uses shape ``c = mean[i] - log(log(2) / 2) * std[i]`` and
    ``scale = exp(std[i] / log(2))``.

    NOTE(review): these formulas are reproduced as found; it is unclear
    whether the resulting samples have the requested mean and standard
    deviation. Verify against the intended parameterisation.

    Args:
        num_nodes: Number of nodes (rows) to generate.
        mean: Per-feature values; its length sets the feature count.
        std: Per-feature values, indexed like ``mean``.

    Returns:
        ``(nodes, role_id)`` with ``nodes`` of shape ``(num_nodes, len(mean))``
        and ``role_id == list(range(num_nodes))``.
    """
    feature_count = len(mean)
    nodes = np.zeros((num_nodes, feature_count))
    for col in range(feature_count):
        shape_c = mean[col] - np.log(np.log(2) / 2) * std[col]
        spread = np.exp(std[col] / np.log(2))
        nodes[:, col] = gompertz.rvs(c=shape_c, scale=spread, size=num_nodes)
    return nodes, list(range(num_nodes))
def rectangle_sequence(n):
    """Return the first ``n`` values of the sequence 2, 6, 12, 20, ...

    Term ``i`` (1-based) is the running sum ``2 + 4 + ... + 2*i``, which
    equals ``i * (i + 1)``.
    """
    return [i * (i + 1) for i in range(1, n + 1)]
def binomial_coefficients(n, dim):
    """Return ``[C(n, 1), C(n, 2), ..., C(n, dim)]`` using ``math.comb``."""
    return [math.comb(n, k) for k in range(1, dim + 1)]
def generate_nodes_arithmetic(num_nodes, dims, step):
    """Generate nodes whose features form an arithmetic progression.

    Each node starts at a random value in ``[0, 10]`` and its ``j``-th
    feature is ``start + j * step``.

    Args:
        num_nodes: Number of nodes (rows) to generate.
        dims: Number of features per node.
        step: Common difference between consecutive features.

    Returns:
        ``(nodes, role_id)``: a ``(num_nodes, dims)`` array and a list of
        random labels in ``{0, 1, 2}``, one per node.
    """
    rows = []
    for _ in range(num_nodes):
        first = random.uniform(0, 10)
        rows.append([first + k * step for k in range(dims)])
    labels = [random.randint(0, 2) for _ in range(num_nodes)]
    return np.array(rows).reshape((num_nodes, dims)), labels
def generate_nodes_geometric(num_nodes, dims, step):
    """Generate nodes whose features form a geometric progression.

    Each node starts at a random value in ``[0, 10]`` and its ``j``-th
    feature is ``start * step**j``.

    Args:
        num_nodes: Number of nodes (rows) to generate.
        dims: Number of features per node.
        step: Common ratio between consecutive features.

    Returns:
        ``(nodes, role_id)``: a ``(num_nodes, dims)`` array and a list of
        random labels in ``{0, 1, 2}``, one per node.
    """
    rows = []
    for _ in range(num_nodes):
        first = random.uniform(0, 10)
        rows.append([first * step**k for k in range(dims)])
    labels = [random.randint(0, 2) for _ in range(num_nodes)]
    return np.array(rows).reshape((num_nodes, dims)), labels
def generate_nodes_fibonacci(num_nodes, dims, step):
    """Generate nodes whose features follow a Fibonacci-style recurrence.

    Each node starts from two random seeds in ``[0, 10]`` (rounded to one
    decimal); every further feature is the sum of the previous two.

    Rows are truncated to ``dims`` entries so that ``dims < 2`` no longer
    makes the final reshape fail (two seeds are always drawn, so the random
    stream is unchanged).

    Args:
        num_nodes: Number of nodes (rows) to generate.
        dims: Number of features per node.
        step: Unused; kept so all sequence generators share one signature.

    Returns:
        ``(nodes, role_id)``: a ``(num_nodes, dims)`` array and a list of
        random labels in ``{0, 1, 2}``, one per node.
    """
    rows = []
    for _ in range(num_nodes):
        first = round(random.uniform(0, 10), 1)
        second = round(random.uniform(0, 10), 1)
        row = [first, second]
        while len(row) < dims:
            row.append(row[-1] + row[-2])
        rows.append(row[:dims])  # handles dims == 0 or 1
    role_id = [random.randint(0, 2) for _ in range(num_nodes)]
    return np.array(rows).reshape((num_nodes, dims)), role_id
def generate_nodes_square(num_nodes, dims, step):
    """Generate nodes by repeatedly squaring a random starting value.

    Each node starts at a random value in ``[0, 10]``; every further feature
    is the square of the previous one (``v, v^2, v^4, ...``).

    Args:
        num_nodes: Number of nodes (rows) to generate.
        dims: Number of features per node.
        step: Unused; kept so all sequence generators share one signature.

    Returns:
        ``(nodes, role_id)``: a ``(num_nodes, dims)`` array and a list of
        random labels in ``{0, 1, 2}``, one per node.
    """
    rows = []
    for _ in range(num_nodes):
        value = random.uniform(0, 10)
        row = [value]
        for _ in range(dims - 1):
            value = value ** 2
            row.append(value)
        rows.append(row)
    labels = [random.randint(0, 2) for _ in range(num_nodes)]
    return np.array(rows).reshape((num_nodes, dims)), labels
def generate_nodes_cube(num_nodes, dims, step):
    """Generate nodes by repeatedly cubing a random starting value.

    Each node starts at a random value in ``[0, 10]``; every further feature
    is the cube of the previous one (``v, v^3, v^9, ...``).

    Args:
        num_nodes: Number of nodes (rows) to generate.
        dims: Number of features per node.
        step: Unused; kept so all sequence generators share one signature.

    Returns:
        ``(nodes, role_id)``: a ``(num_nodes, dims)`` array and a list of
        random labels in ``{0, 1, 2}``, one per node.
    """
    rows = []
    for _ in range(num_nodes):
        value = random.uniform(0, 10)
        row = [value]
        for _ in range(dims - 1):
            value = value ** 3
            row.append(value)
        rows.append(row)
    labels = [random.randint(0, 2) for _ in range(num_nodes)]
    return np.array(rows).reshape((num_nodes, dims)), labels
def generate_nodes_prime(num_nodes, dims, step):
    """Generate nodes as random multiples of the first ``dims`` primes.

    Each node is ``start * [2, 3, 5, 7, ...]`` with ``start`` drawn uniformly
    from ``[0, 10]``. The prime list does not depend on the node, so it is
    computed once instead of once per node.

    Args:
        num_nodes: Number of nodes (rows) to generate.
        dims: Number of features per node (number of primes used).
        step: Unused; kept so all sequence generators share one signature.

    Returns:
        ``(nodes, role_id)``: a ``(num_nodes, dims)`` array and a list of
        random labels in ``{0, 1, 2}``, one per node.
    """
    # Trial division against the primes found so far.
    primes = []
    candidate = 2
    while len(primes) < dims:
        if all(candidate % p != 0 for p in primes):
            primes.append(candidate)
        candidate += 1
    prime_row = np.array(primes)

    nodes = [prime_row * random.uniform(0, 10) for _ in range(num_nodes)]
    role_id = [random.randint(0, 2) for _ in range(num_nodes)]
    return np.array(nodes).reshape((num_nodes, dims)), role_id
def generate_nodes_triangular(num_nodes, dims, step):
    """Generate nodes from the triangular-number formula ``v * (v + 1) / 2``.

    Each node starts at a random ``v`` in ``[0, 10]``; feature ``j`` is
    ``0.5 * v * (v + 1)`` evaluated after ``v`` has been incremented ``j``
    times by 1.

    Args:
        num_nodes: Number of nodes (rows) to generate.
        dims: Number of features per node.
        step: Unused; kept so all sequence generators share one signature.

    Returns:
        ``(nodes, role_id)``: a ``(num_nodes, dims)`` array and a list of
        random labels in ``{0, 1, 2}``, one per node.
    """
    rows = []
    for _ in range(num_nodes):
        value = random.uniform(0, 10)
        row = []
        for _ in range(dims):
            row.append(0.5 * value * (value + 1))
            value += 1
        rows.append(row)
    labels = [random.randint(0, 2) for _ in range(num_nodes)]
    return np.array(rows).reshape((num_nodes, dims)), labels
def generate_nodes_rectangular(num_nodes, dims, step):
    """Generate nodes by offsetting ``rectangle_sequence(dims)``.

    Each node draws a random offset in ``[0, 10]`` and adds it to every term
    of the sequence.

    Args:
        num_nodes: Number of nodes (rows) to generate.
        dims: Number of features per node.
        step: Unused; kept so all sequence generators share one signature.

    Returns:
        ``(nodes, role_id)``: a ``(num_nodes, dims)`` array and a list of
        random labels in ``{0, 1, 2}``, one per node.
    """
    rows = []
    for _ in range(num_nodes):
        offset = random.uniform(0, 10)
        rows.append([offset + term for term in rectangle_sequence(dims)])
    labels = [random.randint(0, 2) for _ in range(num_nodes)]
    return np.array(rows).reshape((num_nodes, dims)), labels
def generate_nodes_binomial(num_nodes, dims, step):
    """Generate nodes from binomial coefficients of a random ``n``.

    Each node draws ``n`` uniformly from ``1..10`` and uses
    ``binomial_coefficients(n, dims)`` as its feature vector.

    Args:
        num_nodes: Number of nodes (rows) to generate.
        dims: Number of features per node.
        step: Unused; kept so all sequence generators share one signature.

    Returns:
        ``(nodes, role_id)``: a ``(num_nodes, dims)`` array and a list of
        random labels in ``{0, 1, 2}``, one per node.
    """
    rows = [
        binomial_coefficients(random.randint(1, 10), dims)
        for _ in range(num_nodes)
    ]
    labels = [random.randint(0, 2) for _ in range(num_nodes)]
    return np.array(rows).reshape((num_nodes, dims)), labels
def generate_nodes_hamilton(num_nodes, dims, step):
    """Generate nodes from consecutive values of ``2**n - 1``.

    Each node draws a starting exponent ``n0`` uniformly from ``1..10`` and
    uses ``2**n - 1`` for ``n = n0, ..., n0 + dims - 1`` as its features.

    A leftover debug ``print`` of the full node list has been removed.

    Args:
        num_nodes: Number of nodes (rows) to generate.
        dims: Number of features per node.
        step: Unused; kept so all sequence generators share one signature.

    Returns:
        ``(nodes, role_id)``: a ``(num_nodes, dims)`` array and a list of
        random labels in ``{0, 1, 2}``, one per node.
    """
    nodes = []
    for _ in range(num_nodes):
        first_exponent = random.randint(1, 10)
        nodes.append([2**n - 1 for n in range(first_exponent, first_exponent + dims)])
    role_id = [random.randint(0, 2) for _ in range(num_nodes)]
    return np.array(nodes).reshape((num_nodes, dims)), role_id
def merge_nodes(node_set1, node_set2):
    """Stack two node arrays vertically (rows of the first, then the second)."""
    stacked = np.vstack([node_set1, node_set2])
    return stacked
def build_sim_edges(nodes, sim_threshold):
    """Return the set of node pairs whose cosine similarity exceeds a threshold.

    Args:
        nodes: 2-D array with one row of features per node.
        sim_threshold: A pair ``(i, j)`` with ``i < j`` is kept when the
            similarity of rows ``i`` and ``j`` is strictly greater than this.

    Returns:
        A set of ``(i, j)`` index tuples.
    """
    node_count = nodes.shape[0]
    return {
        (a, b)
        for a in range(node_count)
        for b in range(a + 1, node_count)
        if cosine_similarity(nodes[a:a + 1, :], nodes[b:b + 1, :])[0, 0] > sim_threshold
    }
def create_partial_sim_edges(nodes, partial_sim_threshold, dims):
    """List node pairs that are similar on a subset of their features.

    Args:
        nodes: Sequence of per-node feature arrays.
        partial_sim_threshold: A pair is kept when its cosine similarity is
            strictly greater than this value.
        dims: Index expression applied to each node's features before
            comparing (presumably a list of feature indices - confirm with
            callers).

    Returns:
        A list of ``(i, j)`` tuples with ``i < j``, in row-major order.
    """
    node_count = len(nodes)
    return [
        (a, b)
        for a in range(node_count)
        for b in range(a + 1, node_count)
        if cosine_similarity(
            nodes[a][dims].reshape(1, -1), nodes[b][dims].reshape(1, -1)
        )[0][0] > partial_sim_threshold
    ]

# Registry of node generators: id -> (generator function, description).
# Ids 1-15 take (num_nodes, mean, std); ids 16-25 take (num_nodes, dims, step).
# Ids 14 and 15 previously pointed at the normal generator although their
# descriptions say Gamma and Beta; they now use the matching generators.
node_generators = {
    1: (generate_nodes_normal_distributed, "Normal distribution generation based on mean and standard deviation"),
    2: (generate_nodes_uniform_distributed, "Uniform distribution generation based on mean and standard deviation"),
    3: (generate_nodes_exponential_distributed, "Exponential distribution generation based on mean and standard deviation"),
    4: (generate_nodes_lognormal_distributed, "Log-normal distribution generation based on mean and standard deviation"),
    5: (generate_nodes_weibull_distributed, "Weibull distribution generation based on mean and standard deviation"),
    6: (generate_nodes_laplace_distributed, "Laplace distribution generation based on mean and standard deviation"),
    7: (generate_nodes_logistic_distributed, "Logistic distribution generation based on mean and standard deviation"),
    8: (generate_nodes_rayleigh_distributed, "Rayleigh distribution generation based on mean and standard deviation"),
    9: (generate_nodes_pareto_distributed, "Pareto distribution generation based on mean and standard deviation"),
    10: (generate_nodes_cauchy_distributed, "Cauchy distribution generation based on mean and standard deviation"),
    11: (generate_nodes_neg_binom_distributed, "Negative binomial distribution generation based on mean and standard deviation"),
    12: (generate_nodes_gumbel_distributed, "Gumbel distribution generation based on mean and standard deviation"),
    13: (generate_nodes_gompertz_distributed, "Gompertz distribution generation based on mean and standard deviation"),
    14: (generate_nodes_gamma_distributed, "Gamma distribution generation based on mean and standard deviation"),
    15: (generate_nodes_beta_distributed, "Beta distribution generation based on mean and standard deviation"),
    16: (generate_nodes_arithmetic, "Arithmetic sequence generation"),
    17: (generate_nodes_geometric, "Geometric sequence generation"),
    18: (generate_nodes_fibonacci, "Fibonacci sequence generation"),
    19: (generate_nodes_square, "Square number sequence generation"),
    20: (generate_nodes_cube, "Cube number sequence generation"),
    21: (generate_nodes_prime, "Prime number sequence generation"),
    22: (generate_nodes_triangular, "Triangular number sequence generation"),
    23: (generate_nodes_rectangular, "Rectangular number sequence generation"),
    24: (generate_nodes_binomial, "Binomial coefficient sequence generation"),
    25: (generate_nodes_hamilton, "Hamiltonian sequence generation")
}

def generate_graph(type, num_nodes, mean, std, sim_threshold):
    """Build an undirected similarity graph over generated node features.

    Node features come from ``node_generators[type]``. Two nodes are
    connected when the cosine similarity of their feature rows is strictly
    greater than ``sim_threshold``.

    The former trailing ``edge_index`` bookkeeping (a torch tensor built from
    the edges and re-indexed) never influenced the returned values and
    raised ``IndexError`` on ``edge_index.size(1)`` whenever the graph had no
    edges (including ``sim_threshold=None``). It has been removed, so
    edgeless graphs are now returned normally.

    Args:
        type: Key into ``node_generators`` (the name shadows the builtin but
            is kept for caller compatibility).
        num_nodes: Number of nodes to generate.
        mean: Second argument forwarded to the generator (per-feature means,
            or ``dims`` for the sequence generators).
        std: Third argument forwarded to the generator (per-feature standard
            deviations, or ``step`` for the sequence generators).
        sim_threshold: Similarity threshold; ``None`` yields a graph without
            edges.

    Returns:
        ``(G, role_id)``: the ``nx.Graph`` and a list of random labels in
        ``{0, 1, 2}``, one per node. The labels returned by the generator are
        discarded, as before.
    """
    generator = node_generators[type][0]
    nodes, _ = generator(num_nodes, mean, std)

    G = nx.Graph()
    G.add_nodes_from(range(num_nodes))
    if sim_threshold is not None:
        for i in range(num_nodes):
            for j in range(i + 1, num_nodes):
                similarity = cosine_similarity(nodes[i:i + 1, :], nodes[j:j + 1, :])[0, 0]
                if similarity > sim_threshold:
                    G.add_edge(i, j)

    role_id = [random.randint(0, 2) for _ in range(num_nodes)]
    return G, role_id

def create_paper_citation_graph(num_papers, avg_citations_per_paper, num_classes, mean, std):
    """Build a random directed citation graph without reciprocal citations.

    Every paper gets 5 normally distributed features and a random class
    label. Each paper then cites a Poisson-distributed number of papers,
    chosen among those that do not already cite it.

    The candidate list previously included the citing paper itself, so a
    paper could cite itself; it is now excluded, consistent with the other
    citation-graph builders in this module.

    Args:
        num_papers: Number of paper nodes.
        avg_citations_per_paper: Poisson rate for the citations per paper
            (capped by the number of available candidates).
        num_classes: Labels are drawn from ``0 .. num_classes - 1``.
        mean: Mean of the feature distribution.
        std: Standard deviation of the feature distribution.

    Returns:
        ``(G, role_id)``: the ``nx.DiGraph`` (nodes carry ``features`` and
        ``label`` attributes) and the list of labels by paper id.
    """
    G = nx.DiGraph()
    role_id = []

    for paper_id in range(num_papers):
        features = np.random.normal(mean, std, size=(5,))
        label = np.random.randint(0, num_classes)
        role_id.append(label)
        G.add_node(paper_id, features=features, label=label)

    for paper_id in range(num_papers):
        # Candidates: other papers that have not cited this one yet.
        candidates = [
            p_id for p_id in range(num_papers)
            if p_id != paper_id and not G.has_edge(p_id, paper_id)
        ]
        if not candidates:
            continue
        rate = min(avg_citations_per_paper, len(candidates))
        num_citations = min(np.random.poisson(rate), len(candidates))
        for cited_paper_id in np.random.choice(candidates, size=num_citations, replace=False):
            G.add_edge(paper_id, cited_paper_id)

    return G, role_id

def create_paper_citation_graph2(num_papers, avg_citations_per_paper, num_classes, mean, std):
    """Build a citation graph where every paper cites the top-ranked papers.

    Every paper gets 5 normally distributed features, a random class label
    and a Poisson-distributed popularity count. Each paper cites the
    ``avg_citations_per_paper`` papers with the highest counts, excluding
    itself.

    The popularity ranking does not change between papers, so it is sorted
    once instead of once per paper.

    Args:
        num_papers: Number of paper nodes.
        avg_citations_per_paper: Poisson rate of the popularity counts and
            the number of papers each paper cites.
        num_classes: Labels are drawn from ``0 .. num_classes - 1``.
        mean: Mean of the feature distribution.
        std: Standard deviation of the feature distribution.

    Returns:
        ``(G, role_id)``: the ``nx.DiGraph`` (nodes carry ``features`` and
        ``label`` attributes) and the list of labels by paper id.
    """
    G = nx.DiGraph()
    role_id = []

    for paper_id in range(num_papers):
        features = np.random.normal(mean, std, size=(5,))
        label = np.random.randint(0, num_classes)
        role_id.append(label)
        G.add_node(paper_id, features=features, label=label)

    citation_counts = {paper_id: np.random.poisson(avg_citations_per_paper) for paper_id in range(num_papers)}
    # Stable sort: ties keep ascending paper-id order.
    ranked_papers = sorted(range(num_papers), key=lambda p_id: citation_counts[p_id], reverse=True)

    for paper_id in range(num_papers):
        candidates = [p_id for p_id in ranked_papers if p_id != paper_id]
        num_citations = min(avg_citations_per_paper, len(candidates))
        for cited_paper_id in candidates[:num_citations]:
            G.add_edge(paper_id, cited_paper_id)

    return G, role_id

def create_paper_citation_graph3(num_papers, avg_citations_per_paper, num_classes, mean, std):
    """Build a citation graph where papers cite papers sharing an author.

    Every paper gets 5 normally distributed features, a random class label
    and 1-4 authors named ``Author_0 .. Author_{k-1}``. Each paper then cites
    up to ``avg_citations_per_paper`` randomly chosen other papers that have
    at least one author name in common.

    NOTE(review): because every paper has at least one author, every paper
    contains ``Author_0``, so all other papers always qualify as sharing an
    author.

    Args:
        num_papers: Number of paper nodes.
        avg_citations_per_paper: Maximum number of citations per paper.
        num_classes: Labels are drawn from ``0 .. num_classes - 1``.
        mean: Mean of the feature distribution.
        std: Standard deviation of the feature distribution.

    Returns:
        ``(G, role_id)``: the ``nx.DiGraph`` (nodes carry ``features`` and
        ``label`` attributes) and the list of labels by paper id.
    """
    G = nx.DiGraph()
    role_id = []
    for paper_id in range(num_papers):
        features = np.random.normal(mean, std, size=(5,))
        label = np.random.randint(0, num_classes)
        role_id.append(label)
        G.add_node(paper_id, features=features, label=label)

    # Author names per paper, stored as sets for the overlap test below.
    author_sets = {}
    for paper_id in range(num_papers):
        author_count = np.random.randint(1, 5)
        author_sets[paper_id] = {f"Author_{k}" for k in range(author_count)}

    for paper_id in range(num_papers):
        own_authors = author_sets[paper_id]
        co_authored = [
            other_id
            for other_id, names in author_sets.items()
            if own_authors & names and other_id != paper_id
        ]
        num_citations = min(avg_citations_per_paper, len(co_authored))
        chosen = np.random.choice(co_authored, size=num_citations, replace=False)
        for cited_paper_id in chosen:
            G.add_edge(paper_id, cited_paper_id)

    return G, role_id

def create_paper_citation_graph4(num_papers, avg_citations_per_paper, num_classes, mean, std):
    """Build a random citation graph and then propagate citations by BFS.

    Every paper gets 5 normally distributed features and a random class
    label, then cites a Poisson-distributed number of other papers. In a
    final pass each paper additionally gets a direct edge to every paper
    reached by a breadth-first walk over the current edges, limited by depth.

    Args:
        num_papers: Number of paper nodes.
        avg_citations_per_paper: Poisson rate of the initial citations; also
            used as the depth limit of the propagation walk.
        num_classes: Labels are drawn from ``0 .. num_classes - 1``.
        mean: Mean of the feature distribution.
        std: Standard deviation of the feature distribution.

    Returns:
        ``(G, role_id)``: the ``nx.DiGraph`` (nodes carry ``features`` and
        ``label`` attributes) and the list of labels by paper id.
    """
    G = nx.DiGraph()
    role_id = []


    # Create the paper nodes with random features and labels.
    for paper_id in range(num_papers):
        features = np.random.normal(mean,std, size=(5,))
        label = np.random.randint(0, num_classes)
        role_id.append(label)
        G.add_node(paper_id, features=features, label=label)


    # Initial random citations; the count is capped at the number of other papers.
    for paper_id in range(num_papers):
        num_citations = np.random.poisson(avg_citations_per_paper)
        available_papers = [p_id for p_id in range(num_papers) if p_id != paper_id]
        if len(available_papers) < num_citations:
            num_citations = len(available_papers)
        cited_papers = np.random.choice(available_papers, size=num_citations, replace=False)
        for cited_paper_id in cited_papers:
            G.add_edge(paper_id, cited_paper_id)


    # Propagation: link each paper directly to everything reached by the walk.
    # The graph is modified while it is being walked, so edges added for
    # earlier papers influence the walks of later papers.
    for paper_id in range(num_papers):
        visited = {paper_id}
        propagation_queue = deque([(paper_id, 0)])
        while propagation_queue:
            current_paper_id, depth = propagation_queue.popleft()
            # Entries are queued in non-decreasing depth, so the first entry
            # at the depth limit ends the whole walk.
            if depth >= avg_citations_per_paper:
                break
            for neighbor_id in G.neighbors(current_paper_id):
                if neighbor_id not in visited:
                    visited.add(neighbor_id)
                    G.add_edge(paper_id, neighbor_id)
                    propagation_queue.append((neighbor_id, depth + 1))

    return G, role_id

def create_paper_citation_graph5(num_papers, avg_citations_per_paper, num_classes, mean, std):
    """Build a citation graph from random citations plus topic similarity.

    Every paper gets 5 normally distributed features and a random class
    label, then cites a Poisson-distributed number of other papers. Finally,
    random 10-dimensional topic vectors are drawn and each paper also cites
    its ``avg_citations_per_paper`` most cosine-similar papers.

    Args:
        num_papers: Number of paper nodes.
        avg_citations_per_paper: Poisson rate of the random citations and
            number of most-similar papers cited.
        num_classes: Labels are drawn from ``0 .. num_classes - 1``.
        mean: Mean of the feature distribution.
        std: Standard deviation of the feature distribution.

    Returns:
        ``(G, role_id)``: the ``nx.DiGraph`` (nodes carry ``features`` and
        ``label`` attributes) and the list of labels by paper id.
    """
    G = nx.DiGraph()
    role_id = []
    for paper_id in range(num_papers):
        features = np.random.normal(mean, std, size=(5,))
        label = np.random.randint(0, num_classes)
        role_id.append(label)
        G.add_node(paper_id, features=features, label=label)

    # Random baseline citations.
    for paper_id in range(num_papers):
        num_citations = min(np.random.poisson(avg_citations_per_paper), num_papers - 1)
        others = [p_id for p_id in range(num_papers) if p_id != paper_id]
        for cited_paper_id in np.random.choice(others, size=num_citations, replace=False):
            G.add_edge(paper_id, cited_paper_id)

    topic_similarities = cosine_similarity(np.random.rand(num_papers, 10))

    # Similarity-driven citations: most similar papers first.
    for paper_id in range(num_papers):
        ranked = sorted(
            (
                (similarity, idx)
                for idx, similarity in enumerate(topic_similarities[paper_id])
                if idx != paper_id
            ),
            reverse=True,
        )
        limit = min(avg_citations_per_paper, len(ranked))
        for _, cited_paper_id in ranked[:limit]:
            if G.has_edge(paper_id, cited_paper_id):
                continue
            G.add_edge(paper_id, cited_paper_id)

    return G, role_id


def create_paper_citation_graph6(num_papers, avg_citations_per_paper, num_classes, mean, std,publication_years=None):
    """Build a citation graph from random citations plus one older paper.

    Every paper gets 5 normally distributed features and a random class
    label, then cites a Poisson-distributed number of other papers. Finally
    each paper cites the lowest-id paper with an earlier publication year
    that it does not cite yet, if there is one.

    Args:
        num_papers: Number of paper nodes.
        avg_citations_per_paper: Poisson rate of the random citations.
        num_classes: Labels are drawn from ``0 .. num_classes - 1``.
        mean: Mean of the feature distribution.
        std: Standard deviation of the feature distribution.
        publication_years: Per-paper years, indexable by paper id. When
            ``None``, random years in ``2000 .. 2021`` are drawn.

    Returns:
        ``(G, role_id)``: the ``nx.DiGraph`` (nodes carry ``features`` and
        ``label`` attributes) and the list of labels by paper id.
    """
    G = nx.DiGraph()
    role_id = []

    if publication_years is None:
        publication_years = np.random.randint(2000, 2022, size=num_papers)

    for paper_id in range(num_papers):
        features = np.random.normal(mean, std, size=(5,))
        label = np.random.randint(0, num_classes)
        role_id.append(label)
        G.add_node(paper_id, features=features, label=label)

    # Random baseline citations.
    for paper_id in range(num_papers):
        num_citations = min(np.random.poisson(avg_citations_per_paper), num_papers - 1)
        others = [p_id for p_id in range(num_papers) if p_id != paper_id]
        for cited_paper_id in np.random.choice(others, size=num_citations, replace=False):
            G.add_edge(paper_id, cited_paper_id)

    # One extra citation to the first older, not-yet-cited paper.
    for paper_id in range(num_papers):
        own_year = publication_years[paper_id]
        older_paper = next(
            (
                candidate
                for candidate in range(num_papers)
                if publication_years[candidate] < own_year
                and not G.has_edge(paper_id, candidate)
            ),
            None,
        )
        if older_paper is not None:
            G.add_edge(paper_id, older_paper)

    return G, role_id


def create_author_influence_citation_graph(num_papers, avg_citations_per_paper, num_classes, mean, std):
    """Build a citation graph from random citations plus one shared-author link.

    Every paper gets 5 normally distributed features, a random class label
    and 1-4 random author ids (each with a random influence score). Each
    paper cites ``avg_citations_per_paper`` random other papers (capped at
    ``num_papers - 1``) and then the lowest-id other paper that shares an
    author id and is not cited yet, if any.

    The influence scores are drawn and stored but not used for edge
    selection.

    Args:
        num_papers: Number of paper nodes; also the upper bound (exclusive)
            of the author ids.
        avg_citations_per_paper: Number of random citations per paper.
        num_classes: Labels are drawn from ``0 .. num_classes - 1``.
        mean: Mean of the feature distribution.
        std: Standard deviation of the feature distribution.

    Returns:
        ``(G, role_id)``: the ``nx.DiGraph`` (nodes carry ``features`` and
        ``label`` attributes) and the list of labels by paper id.
    """
    G = nx.DiGraph()
    role_id = []
    for paper_id in range(num_papers):
        features = np.random.normal(mean, std, size=(5,))
        label = np.random.randint(0, num_classes)
        role_id.append(label)
        G.add_node(paper_id, features=features, label=label)

    # paper id -> {author id: influence score}
    author_influence_data = {}
    for paper_id in range(num_papers):
        author_count = np.random.randint(1, 5)
        author_ids = np.random.randint(0, num_papers, size=author_count)
        scores = np.random.uniform(0, 1, size=author_count)
        author_influence_data[paper_id] = dict(zip(author_ids, scores))

    # Random baseline citations (fixed count, no Poisson draw here).
    for paper_id in range(num_papers):
        num_citations = min(avg_citations_per_paper, num_papers - 1)
        others = [p_id for p_id in range(num_papers) if p_id != paper_id]
        for cited_paper_id in np.random.choice(others, size=num_citations, replace=False):
            G.add_edge(paper_id, cited_paper_id)

    # One extra citation to the first not-yet-cited paper sharing an author.
    for paper_id in range(num_papers):
        own_authors = set(author_influence_data[paper_id])
        shared = next(
            (
                other_id
                for other_id in range(num_papers)
                if other_id != paper_id
                and not own_authors.isdisjoint(author_influence_data[other_id])
                and not G.has_edge(paper_id, other_id)
            ),
            None,
        )
        if shared is not None:
            G.add_edge(paper_id, shared)

    return G, role_id

def create_common_citation_count_citation_graph(num_papers, avg_citations_per_paper, num_classes, mean, std):
    """Build a citation graph from random citations plus one count-based link.

    Every paper gets 5 normally distributed features and a random class
    label, then cites a Poisson-distributed number of other papers. Each
    paper is also given a random vector of counts in ``0 .. 9``; it then
    cites the lowest-id other paper that is not cited yet and whose count
    vector contains one of this paper's positive counts, if any.

    Args:
        num_papers: Number of paper nodes (and length of each count vector).
        avg_citations_per_paper: Poisson rate of the random citations.
        num_classes: Labels are drawn from ``0 .. num_classes - 1``.
        mean: Mean of the feature distribution.
        std: Standard deviation of the feature distribution.

    Returns:
        ``(G, role_id)``: the ``nx.DiGraph`` (nodes carry ``features`` and
        ``label`` attributes) and the list of labels by paper id.
    """
    G = nx.DiGraph()
    role_id = []
    for paper_id in range(num_papers):
        features = np.random.normal(mean, std, size=(5,))
        label = np.random.randint(0, num_classes)
        role_id.append(label)
        G.add_node(paper_id, features=features, label=label)

    # Random baseline citations.
    for paper_id in range(num_papers):
        num_citations = min(np.random.poisson(avg_citations_per_paper), num_papers - 1)
        others = [p_id for p_id in range(num_papers) if p_id != paper_id]
        for cited_paper_id in np.random.choice(others, size=num_citations, replace=False):
            G.add_edge(paper_id, cited_paper_id)

    count_vectors = {
        paper_id: np.random.randint(0, 10, size=num_papers)
        for paper_id in range(num_papers)
    }

    for paper_id in range(num_papers):
        own_counts = count_vectors[paper_id]
        for other_id in range(num_papers):
            if other_id == paper_id:
                continue
            other_counts = count_vectors[other_id]
            shares_positive = any(
                count > 0 and count in other_counts for count in own_counts
            )
            if shares_positive and not G.has_edge(paper_id, other_id):
                G.add_edge(paper_id, other_id)
                break

    return G, role_id

def create_citation_graph_based_on_citation_density(num_papers, avg_citations_per_paper, num_classes, mean, std):
    """Build a citation graph from random citations plus one popular paper.

    Every paper gets 5 standard-normal features and a random class label,
    then cites a Poisson-distributed number of other papers. A citation
    density (in-degree divided by ``num_papers``) is computed once from these
    random citations; each paper then cites the first of the
    ``avg_citations_per_paper`` densest papers that is neither itself nor
    already cited.

    The density ranking is fixed after the random phase, so it is sorted
    once instead of once per paper.

    NOTE(review): ``mean`` and ``std`` are accepted but unused; features are
    drawn from N(0, 1). Confirm whether they were meant to be used.

    Args:
        num_papers: Number of paper nodes.
        avg_citations_per_paper: Poisson rate of the random citations and
            size of the "densest papers" shortlist.
        num_classes: Labels are drawn from ``0 .. num_classes - 1``.
        mean: Unused.
        std: Unused.

    Returns:
        ``(G, role_id)``: the ``nx.DiGraph`` (nodes carry ``features`` and
        ``label`` attributes) and the list of labels by paper id.
    """
    G = nx.DiGraph()
    role_id = []

    for paper_id in range(num_papers):
        features = np.random.normal(0, 1, size=(5,))
        label = np.random.randint(0, num_classes)
        role_id.append(label)
        G.add_node(paper_id, features=features, label=label)

    # Random baseline citations.
    for paper_id in range(num_papers):
        num_citations = min(np.random.poisson(avg_citations_per_paper), num_papers - 1)
        candidate_papers = [p_id for p_id in range(num_papers) if p_id != paper_id]
        for cited_paper_id in np.random.choice(candidate_papers, size=num_citations, replace=False):
            G.add_edge(paper_id, cited_paper_id)

    citation_density = {
        paper_id: len(list(G.predecessors(paper_id))) / num_papers
        for paper_id in range(num_papers)
    }
    # Shortlist of the densest papers (stable sort keeps id order on ties).
    sorted_papers = sorted(citation_density, key=citation_density.get, reverse=True)
    densest_papers = sorted_papers[:avg_citations_per_paper]

    for paper_id in range(num_papers):
        for cited_paper_id in densest_papers:
            if cited_paper_id != paper_id and not G.has_edge(paper_id, cited_paper_id):
                G.add_edge(paper_id, cited_paper_id)
                break

    return G, role_id

def create_citation_graph_based_on_network_structure(num_papers, avg_citations_per_paper, num_classes, mean, std):
    """Build a random citation graph with a neighbour-based second pass.

    Every paper gets 5 normally distributed features and a random class
    label, then cites a Poisson-distributed number of other papers. In the
    second pass a random neighbour of each paper is picked and cited if no
    edge to it exists yet.

    NOTE(review): the neighbour is picked from ``G.neighbors(paper_id)``,
    which for a directed graph should be the papers it already cites, so the
    ``has_edge`` test looks always true and the pass would add no edges (it
    still consumes random numbers). Confirm the intended behaviour.

    Args:
        num_papers: Number of paper nodes.
        avg_citations_per_paper: Poisson rate of the random citations.
        num_classes: Labels are drawn from ``0 .. num_classes - 1``.
        mean: Mean of the feature distribution.
        std: Standard deviation of the feature distribution.

    Returns:
        ``(G, role_id)``: the ``nx.DiGraph`` (nodes carry ``features`` and
        ``label`` attributes) and the list of labels by paper id.
    """
    G = nx.DiGraph()
    role_id = []
    for paper_id in range(num_papers):
        features = np.random.normal(mean, std, size=(5,))
        label = np.random.randint(0, num_classes)
        role_id.append(label)
        G.add_node(paper_id, features=features, label=label)

    # Random baseline citations.
    for paper_id in range(num_papers):
        num_citations = min(np.random.poisson(avg_citations_per_paper), num_papers - 1)
        others = [p_id for p_id in range(num_papers) if p_id != paper_id]
        for cited_paper_id in np.random.choice(others, size=num_citations, replace=False):
            G.add_edge(paper_id, cited_paper_id)

    # Neighbour-based pass.
    for paper_id in range(num_papers):
        neighbour_ids = list(G.neighbors(paper_id))
        if not neighbour_ids:
            continue
        picked = np.random.choice(neighbour_ids)
        if not G.has_edge(paper_id, picked):
            G.add_edge(paper_id, picked)

    return G, role_id

def generate_author_list(num_authors=3, num_fields=5):
    """Generate random authors, each with a random set of research fields.

    Each author gets an id in ``1 .. 999`` and between 1 and 3 distinct
    fields drawn from ``0 .. num_fields - 1``. The number of fields is capped
    at ``num_fields`` so that sampling without replacement cannot fail when
    fewer than 3 fields exist.

    Args:
        num_authors: Number of authors to generate.
        num_fields: Number of distinct fields available.

    Returns:
        A list of ``(author_id, author_fields)`` tuples, where
        ``author_fields`` is a NumPy array of distinct field indices.
    """
    authors = []
    for _ in range(num_authors):
        author_id = np.random.randint(1, 1000)
        # Same draw as before; only the result is capped.
        field_count = min(np.random.randint(1, 4), num_fields)
        author_fields = np.random.choice(range(num_fields), size=field_count, replace=False)
        authors.append((author_id, author_fields))
    return authors
def get_paper_field(authors):
    """Return the union of the field collections of all ``authors``.

    Args:
        authors: Iterable of records whose element ``[1]`` is an iterable of
            fields (e.g. the tuples from ``generate_author_list``).

    Returns:
        A set containing every field of every author.
    """
    return {field for author in authors for field in author[1]}
def create_citation_graph_based_on_author_field(num_papers, avg_citations_per_paper, num_classes, mean, std):
    """Build a citation graph from random citations plus one same-field link.

    Every paper gets 5 normally distributed features and a random class
    label, then cites a Poisson-distributed number of other papers. In the
    second pass a random author list is generated for each paper and compared
    with a freshly generated author list for every other paper; the first
    paper whose field set is identical is cited.

    The number of random citations is now capped at ``num_papers - 1``.
    Previously a Poisson draw larger than the number of other papers made
    ``np.random.choice(..., replace=False)`` raise ``ValueError``.

    NOTE(review): author lists are regenerated for every comparison, so a
    paper has no stable set of fields across comparisons. Confirm this is
    intended.

    Args:
        num_papers: Number of paper nodes.
        avg_citations_per_paper: Poisson rate of the random citations.
        num_classes: Labels are drawn from ``0 .. num_classes - 1``.
        mean: Mean of the feature distribution.
        std: Standard deviation of the feature distribution.

    Returns:
        ``(G, role_id)``: the ``nx.DiGraph`` (nodes carry ``features`` and
        ``label`` attributes) and the list of labels by paper id.
    """
    G = nx.DiGraph()
    role_id = []

    for paper_id in range(num_papers):
        features = np.random.normal(mean, std, size=(5,))
        label = np.random.randint(0, num_classes)
        role_id.append(label)
        G.add_node(paper_id, features=features, label=label)

    # Random baseline citations.
    for paper_id in range(num_papers):
        candidates = [p_id for p_id in range(num_papers) if p_id != paper_id]
        # Cap: cannot sample more distinct papers than there are candidates.
        num_citations = min(np.random.poisson(avg_citations_per_paper), len(candidates))
        for cited_paper_id in np.random.choice(candidates, size=num_citations, replace=False):
            G.add_edge(paper_id, cited_paper_id)

    # One extra citation to the first paper with an identical field set.
    for paper_id in range(num_papers):
        paper_field = get_paper_field(generate_author_list())
        for cited_paper_id in range(num_papers):
            if cited_paper_id == paper_id:
                continue
            cited_paper_field = get_paper_field(generate_author_list())
            if paper_field == cited_paper_field:
                G.add_edge(paper_id, cited_paper_id)
                break

    return G, role_id


def create_citation_graph_based_on_centrality(num_papers, avg_citations_per_paper, num_classes, centrality_measure='degree'):
    """Build a citation graph where every paper also cites the most central one.

    Every paper gets 5 standard-normal features and a random class label,
    then cites a Poisson-distributed number of other papers. A centrality
    score is computed on that random graph and each paper additionally cites
    the other paper with the highest score.

    The 'betweenness' branch used ``nx.is_connected``, which networkx does
    not implement for directed graphs, so that branch always raised. It now
    checks weak connectivity with ``nx.is_weakly_connected``.

    Args:
        num_papers: Number of paper nodes.
        avg_citations_per_paper: Poisson rate of the random citations.
        num_classes: Labels are drawn from ``0 .. num_classes - 1``.
        centrality_measure: ``'degree'``, ``'betweenness'`` or
            ``'closeness'``.

    Returns:
        ``(G, role_id)``: the ``nx.DiGraph`` (nodes carry ``features`` and
        ``label`` attributes) and the list of labels by paper id.

    Raises:
        ValueError: For an unsupported ``centrality_measure``, or when
            'betweenness' is requested on a graph that is not (weakly)
            connected.
    """
    G = nx.DiGraph()
    role_id = []

    for paper_id in range(num_papers):
        features = np.random.normal(0, 1, size=(5,))
        label = np.random.randint(0, num_classes)
        role_id.append(label)
        G.add_node(paper_id, features=features, label=label)

    # Random baseline citations.
    for paper_id in range(num_papers):
        num_citations = min(np.random.poisson(avg_citations_per_paper), num_papers - 1)
        others = [p_id for p_id in range(num_papers) if p_id != paper_id]
        for cited_paper_id in np.random.choice(others, size=num_citations, replace=False):
            G.add_edge(paper_id, cited_paper_id)

    if centrality_measure == 'degree':
        centrality = nx.degree_centrality(G)
    elif centrality_measure == 'betweenness':
        # is_weakly_connected is undefined for the null graph, hence the guard.
        if G.number_of_nodes() > 0 and not nx.is_weakly_connected(G):
            raise ValueError("Graph must be connected for betweenness centrality!")
        centrality = nx.betweenness_centrality(G)
    elif centrality_measure == 'closeness':
        centrality = nx.closeness_centrality(G)
    else:
        raise ValueError("Unsupported centrality measure!")

    for paper_id in range(num_papers):
        candidates = list(set(range(num_papers)) - {paper_id})
        candidate_centralities = {candidate: centrality.get(candidate, 0) for candidate in candidates}
        if not candidate_centralities:
            # Only reachable when there is a single paper.
            G.add_edge(paper_id, np.random.randint(num_papers))
        else:
            max_centrality_paper = max(candidate_centralities, key=candidate_centralities.get)
            G.add_edge(paper_id, max_centrality_paper)

    return G, role_id


def create_citation_graph_based_on_geographical_location(num_papers, avg_citations_per_paper, num_classes):
    """Build a random directed citation graph with one extra same-location citation per paper.

    Each paper is assigned a random location from a fixed list of ten
    countries, a 5-dimensional standard-normal ``features`` vector and a
    uniformly drawn ``label`` in ``[0, num_classes)``.  Each paper cites a
    Poisson-distributed number of distinct other papers and, in addition, the
    lowest-numbered other paper that shares its location (if any).

    Args:
        num_papers: Number of paper nodes to create.
        avg_citations_per_paper: Mean of the Poisson citation-count draw.
        num_classes: Exclusive upper bound for the random node labels.

    Returns:
        ``(G, role_id)`` where ``G`` is an ``nx.DiGraph`` and ``role_id`` is the
        list of node labels in paper-id order.
    """
    G = nx.DiGraph()
    role_id = []

    locations = ["USA", "UK", "Canada", "Germany", "France", "China", "Japan", "Australia", "India", "Brazil"]
    author_locations = [random.choice(locations) for _ in range(num_papers)]

    for paper_id in range(num_papers):
        features = np.random.normal(0, 1, size=(5,))
        label = np.random.randint(0, num_classes)
        role_id.append(label)
        G.add_node(paper_id, features=features, label=label)

    for paper_id in range(num_papers):
        # Cap the draw: sampling without replacement raises ValueError when
        # more items are requested than there are other papers.
        num_citations = min(np.random.poisson(avg_citations_per_paper), num_papers - 1)
        other_papers = [p_id for p_id in range(num_papers) if p_id != paper_id]
        cited_papers = np.random.choice(other_papers, size=num_citations, replace=False)
        for cited_paper_id in cited_papers:
            G.add_edge(paper_id, cited_paper_id)

    # One extra citation to the first other paper from the same location.
    for paper_id in range(num_papers):
        author_location = author_locations[paper_id]
        for other_paper_id, location in enumerate(author_locations):
            if other_paper_id != paper_id and location == author_location:
                G.add_edge(paper_id, other_paper_id)
                break

    return G, role_id

def create_team_size_citation_graph(num_papers, avg_citations_per_paper, num_classes):
    """Build a random directed citation graph plus citations between similar team sizes.

    Each paper gets a random team size in ``[50, 100]``, a 5-dimensional
    standard-normal ``features`` vector and a uniformly drawn ``label`` in
    ``[0, num_classes)``.  Each paper cites a Poisson-distributed number of
    distinct other papers and, in addition, every other paper whose team size
    differs from its own by at most 5.

    Args:
        num_papers: Number of paper nodes to create.
        avg_citations_per_paper: Mean of the Poisson citation-count draw.
        num_classes: Exclusive upper bound for the random node labels.

    Returns:
        ``(G, role_id)`` where ``G`` is an ``nx.DiGraph`` and ``role_id`` is the
        list of node labels in paper-id order.
    """
    G = nx.DiGraph()
    role_id = []

    team_sizes = np.random.randint(50, 101, size=num_papers)

    for paper_id in range(num_papers):
        features = np.random.normal(0, 1, size=(5,))
        label = np.random.randint(0, num_classes)
        role_id.append(label)
        G.add_node(paper_id, features=features, label=label)

    for paper_id in range(num_papers):
        # Cap the draw: sampling without replacement raises ValueError when
        # more items are requested than there are other papers.
        num_citations = min(np.random.poisson(avg_citations_per_paper), num_papers - 1)
        other_papers = [p_id for p_id in range(num_papers) if p_id != paper_id]
        cited_papers = np.random.choice(other_papers, size=num_citations, replace=False)
        for cited_paper_id in cited_papers:
            G.add_edge(paper_id, cited_paper_id)

    # Cite every other paper with a team size within +/- 5.
    for paper_id in range(num_papers):
        paper_team_size = team_sizes[paper_id]
        for cited_paper_id in range(num_papers):
            if paper_id == cited_paper_id:
                continue
            if abs(paper_team_size - team_sizes[cited_paper_id]) <= 5:
                G.add_edge(paper_id, cited_paper_id)
    return G, role_id

def create_citation_graph_based_on_credibility(num_papers, avg_citations_per_paper, num_classes):
    """Build a random directed citation graph plus one citation to a high-credibility paper.

    Each paper gets a 5-dimensional standard-normal ``features`` vector and a
    uniformly drawn ``label`` in ``[0, num_classes)``.  Each paper cites a
    Poisson-distributed number of distinct other papers.  Every paper is then
    given a uniform random credibility score, and each paper additionally
    cites one randomly chosen other paper whose score is at or above the 75th
    percentile.

    Args:
        num_papers: Number of paper nodes to create.
        avg_citations_per_paper: Mean of the Poisson citation-count draw.
        num_classes: Exclusive upper bound for the random node labels.

    Returns:
        ``(G, role_id)`` where ``G`` is an ``nx.DiGraph`` and ``role_id`` is the
        list of node labels in paper-id order.
    """
    G = nx.DiGraph()
    role_id = []

    for paper_id in range(num_papers):
        features = np.random.normal(0, 1, size=(5,))
        label = np.random.randint(0, num_classes)
        role_id.append(label)
        G.add_node(paper_id, features=features, label=label)

    for paper_id in range(num_papers):
        # Cap the draw: sampling without replacement raises ValueError when
        # more items are requested than there are other papers.
        num_citations = min(np.random.poisson(avg_citations_per_paper), num_papers - 1)
        other_papers = [p_id for p_id in range(num_papers) if p_id != paper_id]
        cited_papers = np.random.choice(other_papers, size=num_citations, replace=False)
        for cited_paper_id in cited_papers:
            G.add_edge(paper_id, cited_paper_id)

    credibility_scores = np.random.rand(num_papers)

    if num_papers:
        # The threshold does not depend on the citing paper, so compute it once.
        credibility_threshold = np.percentile(credibility_scores, 75)
        credible_papers = [
            candidate_id
            for candidate_id, score in enumerate(credibility_scores)
            if score >= credibility_threshold
        ]
        for paper_id in range(num_papers):
            # Use a distinct loop variable: shadowing paper_id in the
            # comprehension made the self-exclusion test always False, so the
            # candidate list was always empty.
            candidates = [candidate_id for candidate_id in credible_papers if candidate_id != paper_id]
            if candidates:
                cited_paper_id = np.random.choice(candidates)
                G.add_edge(paper_id, cited_paper_id)

    return G, role_id



def generate_citation_graph(num_papers, num_authors_per_paper, num_authors):
    """Build a directed author graph from random paper/author assignments.

    Each paper is assigned ``num_authors_per_paper`` distinct authors drawn
    from ``range(num_authors)``.  Every author of a paper gets a directed edge
    to every different author of every *other* paper.

    Args:
        num_papers: Number of papers to generate author lists for.
        num_authors_per_paper: Authors sampled (without replacement) per paper.
        num_authors: Size of the author population.

    Returns:
        ``(G, role_id, citations)``: the author ``nx.DiGraph``, an array of
        ``num_papers`` random 0/1 values, and a dict mapping a paper id to the
        first successor (scanning the paper's authors in order) whose id is
        also a key of the paper/author mapping.
    """
    G = nx.DiGraph()

    # paper id -> array of author ids
    authors_of = {
        paper_id: np.random.choice(num_authors, size=num_authors_per_paper, replace=False)
        for paper_id in range(num_papers)
    }

    for paper_id, authors in authors_of.items():
        for author in authors:
            G.add_node(author)
            for other_id, other_authors in authors_of.items():
                if other_id == paper_id:
                    continue
                for other_author in other_authors:
                    if author == other_author or G.has_edge(author, other_author):
                        continue
                    G.add_edge(author, other_author)

    # NOTE(review): the membership test compares an *author* id against the
    # *paper* ids used as keys -- confirm this is intended.
    citations = {}
    no_match = object()
    for paper_id, authors in authors_of.items():
        first_match = next(
            (
                successor
                for author in authors
                for successor in G.successors(author)
                if successor in authors_of
            ),
            no_match,
        )
        if first_match is not no_match:
            citations[paper_id] = first_match

    role_id = np.random.randint(0, 2, size=num_papers)

    return G, role_id, citations

def convert_to_networkx_graph(citation_network):
    """Convert a square 0/1 adjacency matrix into a directed graph.

    Args:
        citation_network: 2-D array; entry ``[i][j] == 1`` means paper ``i``
            cites paper ``j``.  Only the leading ``shape[0] x shape[0]`` block
            is read.

    Returns:
        An ``nx.DiGraph`` with one node per row (``0 .. shape[0] - 1``) and an
        edge ``i -> j`` for every entry equal to 1.
    """
    G = nx.DiGraph()
    num_papers = citation_network.shape[0]
    # Register every paper up front: otherwise papers without any citation
    # are dropped from the graph and the node count no longer matches the
    # per-paper role_id vectors the callers return next to the graph.
    G.add_nodes_from(range(num_papers))
    for i in range(num_papers):
        for j in range(num_papers):
            if citation_network[i][j] == 1:
                G.add_edge(i, j)
    return G
def generate_triangle_citation_network(num_papers, avg_citations_per_paper, num_classes):
    """Generate a random citation network with a near-constant out-degree.

    Each paper draws a citation budget uniformly from
    ``[avg_citations_per_paper - 2, avg_citations_per_paper + 2]`` and cites
    that many randomly chosen other papers (repeats collapse onto the same
    edge).  Each paper also receives a uniform random label in
    ``[0, num_classes - 1]``.

    Args:
        num_papers: Number of papers.
        avg_citations_per_paper: Centre of the per-paper citation budget.
        num_classes: Number of distinct labels.

    Returns:
        ``(G, role_id)``: the directed graph produced by
        ``convert_to_networkx_graph`` and the integer label array.
    """
    citation_network = np.zeros((num_papers, num_papers), dtype=int)
    role_id = np.zeros(num_papers, dtype=int)

    num_citations = [random.randint(avg_citations_per_paper - 2, avg_citations_per_paper + 2) for _ in range(num_papers)]

    # With a single paper there is nothing else to cite; the rejection loop
    # below would otherwise spin forever.
    has_other_papers = num_papers > 1

    for i in range(num_papers):
        role_id[i] = random.randint(0, num_classes - 1)
        if not has_other_papers:
            continue
        for _ in range(num_citations[i]):
            # Rejection-sample a target different from the citing paper.
            citation_paper = i
            while citation_paper == i:
                citation_paper = random.randint(0, num_papers - 1)
            citation_network[i][citation_paper] = 1

    G = convert_to_networkx_graph(citation_network)

    return G, role_id



def generate_citation_network_with_distance(num_papers, max_distance, self_citation_prob):
    """Generate a citation network whose edge probability decays with a random distance.

    A random integer distance in ``[1, max_distance]`` is drawn for every
    ordered pair of distinct papers.  A citation ``i -> j`` is then created
    when two independent uniform draws fall below ``self_citation_prob`` and
    ``1 / distance`` respectively.  Each paper gets a random 0/1 label.

    Returns:
        ``(G, role_id)``: the directed graph produced by
        ``convert_to_networkx_graph`` and the integer label array.
    """
    adjacency = np.zeros((num_papers, num_papers), dtype=int)
    roles = np.zeros(num_papers, dtype=int)
    pair_distance = np.zeros((num_papers, num_papers), dtype=int)

    # First pass: draw every pairwise distance.
    for src in range(num_papers):
        for dst in range(num_papers):
            if src == dst:
                continue
            pair_distance[src][dst] = random.randint(1, max_distance)

    # Second pass: labels and citation decisions.
    for src in range(num_papers):
        roles[src] = random.randint(0, 1)
        for dst in range(num_papers):
            if src == dst:
                continue
            hops = pair_distance[src][dst]
            if hops > max_distance:
                continue
            accept_prob = 1 / hops
            # The second draw only happens when the first one passes.
            if random.random() < self_citation_prob and random.random() < accept_prob:
                adjacency[src][dst] = 1

    graph = convert_to_networkx_graph(adjacency)
    return graph, roles



def generate_flow_direction_prob(num_domains):
    """Return a random row-normalised domain-to-domain flow matrix.

    Off-diagonal entries are drawn uniformly from ``[0, 1)`` and each row is
    divided by its sum; the diagonal stays at zero.

    Args:
        num_domains: Number of domains (matrix is ``num_domains x num_domains``).

    Returns:
        A float array in which every row with at least one non-zero entry sums
        to 1.  Rows whose sum is zero (e.g. the single row when
        ``num_domains == 1``) are left as all zeros instead of becoming NaN.
    """
    flow_direction_prob = np.zeros((num_domains, num_domains))
    for i in range(num_domains):
        for j in range(num_domains):
            if i != j:
                flow_direction_prob[i][j] = np.random.rand()

    row_sums = np.sum(flow_direction_prob, axis=1, keepdims=True)
    # Skip rows that sum to zero so 0/0 does not poison the matrix with NaN.
    return np.divide(
        flow_direction_prob,
        row_sums,
        out=np.zeros_like(flow_direction_prob),
        where=row_sums != 0,
    )
def generate_citation_network_with_knowledge_flow(num_papers, num_domains):
    """Generate a citation network driven by a random domain flow matrix.

    Paper ``i`` is mapped to domain ``i % num_domains``.  A citation
    ``i -> j`` (``i != j``) is created when a uniform draw falls below the
    flow probability from ``i``'s domain to ``j``'s domain.  Each paper gets a
    random 0/1 label.

    Returns:
        ``(G, role_id)``: the directed graph produced by
        ``convert_to_networkx_graph`` and the integer label array.
    """
    flow = generate_flow_direction_prob(num_domains)

    adjacency = np.zeros((num_papers, num_papers), dtype=int)
    roles = np.zeros(num_papers, dtype=int)

    for citing in range(num_papers):
        roles[citing] = random.randint(0, 1)
        for cited in range(num_papers):
            if cited == citing:
                continue
            flow_prob = flow[citing % num_domains][cited % num_domains]
            if random.random() < flow_prob:
                adjacency[citing][cited] = 1

    graph = convert_to_networkx_graph(adjacency)
    return graph, roles


def generate_citation_network_with_chain_length(num_papers, max_chain_length, self_citation_prob):
    """Generate a citation network whose edge probability decays with a random chain length.

    For every ordered pair of distinct papers a chain length is drawn
    uniformly from ``[1, max_chain_length]``.  The citation is created when a
    first uniform draw falls below ``self_citation_prob`` and a second one
    falls below ``1 / chain_length``.  Each paper gets a random 0/1 label.

    Returns:
        ``(G, role_id)``: the directed graph produced by
        ``convert_to_networkx_graph`` and the integer label array.
    """
    adjacency = np.zeros((num_papers, num_papers), dtype=int)
    roles = np.zeros(num_papers, dtype=int)

    for citing in range(num_papers):
        roles[citing] = random.randint(0, 1)
        for cited in range(num_papers):
            if cited == citing:
                continue
            hops = random.randint(1, max_chain_length)
            accept_prob = 1 / hops
            # The second draw only happens when the first one passes.
            if random.random() < self_citation_prob and random.random() < accept_prob:
                adjacency[citing][cited] = 1

    graph = convert_to_networkx_graph(adjacency)
    return graph, roles

def generate_citation_network_with_diversity(num_papers, num_domains, diversity_threshold, self_citation_prob):
    """Generate a citation network where diverse papers cite more readily.

    Every paper gets ``num_papers`` random reference domains in
    ``[0, num_domains)``; its diversity score is the number of distinct
    domains among them.  For each ordered pair of distinct papers a uniform
    draw must fall below ``self_citation_prob``.  Papers whose score exceeds
    ``diversity_threshold`` then cite unconditionally; the others must pass a
    second draw against ``1 / (1 + diversity_threshold - score)``.  Each paper
    gets a random 0/1 label.

    Returns:
        ``(G, role_id)``: the directed graph produced by
        ``convert_to_networkx_graph`` and the integer label array.
    """
    adjacency = np.zeros((num_papers, num_papers), dtype=int)
    roles = np.zeros(num_papers, dtype=int)

    references = np.random.randint(0, num_domains, size=(num_papers, num_papers))

    scores = np.zeros(num_papers)
    for idx, row in enumerate(references):
        scores[idx] = len(set(row))

    for citing in range(num_papers):
        roles[citing] = random.randint(0, 1)
        for cited in range(num_papers):
            if cited == citing:
                continue
            if random.random() < self_citation_prob:
                if scores[citing] > diversity_threshold:
                    adjacency[citing][cited] = 1
                else:
                    accept_prob = 1 / (1 + diversity_threshold - scores[citing])
                    if random.random() < accept_prob:
                        adjacency[citing][cited] = 1

    graph = convert_to_networkx_graph(adjacency)
    return graph, roles


def generate_citation_network_with_reference_count(num_papers, max_reference_count, self_citation_prob):
    """Generate a citation network biased by per-paper reference counts.

    Each paper gets a random reference count in ``[1, max_reference_count]``.
    For each ordered pair of distinct papers a uniform draw must fall below
    ``self_citation_prob``.  If the citing paper has the strictly larger
    count the citation is created; otherwise a second draw must fall below
    the ratio ``count[citing] / count[cited]``.  Each paper gets a random 0/1
    label.

    Returns:
        ``(G, role_id)``: the directed graph produced by
        ``convert_to_networkx_graph`` and the integer label array.
    """
    adjacency = np.zeros((num_papers, num_papers), dtype=int)
    roles = np.zeros(num_papers, dtype=int)

    ref_counts = np.random.randint(1, max_reference_count + 1, size=num_papers)

    for citing in range(num_papers):
        roles[citing] = random.randint(0, 1)
        for cited in range(num_papers):
            if cited == citing:
                continue
            if random.random() < self_citation_prob:
                if ref_counts[citing] > ref_counts[cited]:
                    adjacency[citing][cited] = 1
                else:
                    ratio = ref_counts[citing] / ref_counts[cited]
                    if random.random() < ratio:
                        adjacency[citing][cited] = 1

    graph = convert_to_networkx_graph(adjacency)
    return graph, roles


def generate_citation_network_with_research_object(num_papers, num_objects, self_citation_prob):
    """Generate a citation network biased by research-object popularity.

    Each of ``num_objects`` research objects gets a uniform random popularity
    and every paper is assigned one object at random.  For each ordered pair
    of distinct papers a uniform draw must fall below ``self_citation_prob``.
    If the citing paper's object is strictly more popular the citation is
    created; otherwise a second draw must fall below the popularity ratio
    citing/cited.  Each paper gets a random 0/1 label.

    Returns:
        ``(G, role_id)``: the directed graph produced by
        ``convert_to_networkx_graph`` and the integer label array.
    """
    adjacency = np.zeros((num_papers, num_papers), dtype=int)
    roles = np.zeros(num_papers, dtype=int)

    popularity = np.random.rand(num_objects)
    object_of = np.random.randint(0, num_objects, size=num_papers)

    for citing in range(num_papers):
        roles[citing] = random.randint(0, 1)
        for cited in range(num_papers):
            if cited == citing:
                continue
            if random.random() < self_citation_prob:
                citing_pop = popularity[object_of[citing]]
                cited_pop = popularity[object_of[cited]]
                if citing_pop > cited_pop:
                    adjacency[citing][cited] = 1
                elif random.random() < citing_pop / cited_pop:
                    adjacency[citing][cited] = 1

    graph = convert_to_networkx_graph(adjacency)
    return graph, roles


def generate_citation_network_with_journal_reputation(num_papers, num_journals, self_citation_prob):
    """Generate a citation network biased by journal reputation.

    Each of ``num_journals`` journals gets a uniform random reputation and
    every paper is assigned one journal at random.  For each ordered pair of
    distinct papers a uniform draw must fall below ``self_citation_prob``.
    If the citing paper's journal has the strictly higher reputation the
    citation is created; otherwise a second draw must fall below the
    reputation ratio citing/cited.  Each paper gets a random 0/1 label.

    Returns:
        ``(G, role_id)``: the directed graph produced by
        ``convert_to_networkx_graph`` and the integer label array.
    """
    adjacency = np.zeros((num_papers, num_papers), dtype=int)
    roles = np.zeros(num_papers, dtype=int)

    reputation = np.random.rand(num_journals)
    journal_of = np.random.randint(0, num_journals, size=num_papers)

    for citing in range(num_papers):
        roles[citing] = random.randint(0, 1)
        for cited in range(num_papers):
            if cited == citing:
                continue
            if random.random() < self_citation_prob:
                citing_rep = reputation[journal_of[citing]]
                cited_rep = reputation[journal_of[cited]]
                if citing_rep > cited_rep:
                    adjacency[citing][cited] = 1
                elif random.random() < citing_rep / cited_rep:
                    adjacency[citing][cited] = 1

    graph = convert_to_networkx_graph(adjacency)
    return graph, roles


def generate_citation_network_with_open_access(num_papers, open_access_prob, self_citation_prob):
    """Generate a citation network where open-access papers cite more readily.

    Each paper is flagged open access with probability ``open_access_prob``.
    For each ordered pair of distinct papers a uniform draw must fall below
    ``self_citation_prob``.  Open-access citing papers then cite
    unconditionally; the others must pass a second draw against the same
    ``self_citation_prob``.  Each paper gets a random 0/1 label.

    Returns:
        ``(G, role_id)``: the directed graph produced by
        ``convert_to_networkx_graph`` and the integer label array.
    """
    adjacency = np.zeros((num_papers, num_papers), dtype=int)
    roles = np.zeros(num_papers, dtype=int)

    is_open = np.random.rand(num_papers) < open_access_prob

    for citing in range(num_papers):
        roles[citing] = random.randint(0, 1)
        for cited in range(num_papers):
            if cited == citing:
                continue
            if random.random() < self_citation_prob:
                if is_open[citing]:
                    adjacency[citing][cited] = 1
                elif random.random() < self_citation_prob:
                    adjacency[citing][cited] = 1

    graph = convert_to_networkx_graph(adjacency)
    return graph, roles



# Registry of citation-graph generators: numeric id -> (callable, description).
# The callables do NOT share one signature (for example, entry 12 takes
# (num_papers, avg_citations_per_paper, num_classes, ...) while entry 16 takes
# (num_papers, num_authors_per_paper, num_authors) and returns three values),
# so callers must pass arguments appropriate to the selected entry.
paper_generators = {
    1: (create_paper_citation_graph, "Random citation relationship generation"),
    2: (create_paper_citation_graph2, "Citation based on paper citation count"),
    3: (create_paper_citation_graph3, "Citation based on author co-citation"),
    4: (create_paper_citation_graph4, "Citation based on citation relationship propagation"),
    5: (create_paper_citation_graph5, "Citation based on topic similarity"),
    6: (create_paper_citation_graph6, "Citation based on citation time"),
    7: (create_author_influence_citation_graph, "Citation based on author influence"),
    8: (create_common_citation_count_citation_graph, "Citation based on common citation count"),
    9: (create_citation_graph_based_on_citation_density, "Citation based on citation density"),
    10: (create_citation_graph_based_on_network_structure, "Citation based on network structure"),
    11: (create_citation_graph_based_on_author_field, "Citation based on author field"),
    12: (create_citation_graph_based_on_centrality, "Citation based on citation network centrality"),
    13: (create_citation_graph_based_on_geographical_location, "Citation based on author geographical location"),
    14: (create_team_size_citation_graph, "Citation based on research team size"),
    15: (create_citation_graph_based_on_credibility, "Citation based on citation credibility"),
    16: (generate_citation_graph, "Citation based on academic lineage relationship"),
    17: (generate_triangle_citation_network, "Citation based on citation structure"),
    18: (generate_citation_network_with_distance, "Citation based on citation distance"),
    19: (generate_citation_network_with_knowledge_flow, "Rules based on knowledge flow"),
    20: (generate_citation_network_with_chain_length, "Rules based on citation chain"),
    21: (generate_citation_network_with_diversity, "Citation based on diversity"),
    22: (generate_citation_network_with_reference_count, "Citation based on reference count"),
    23: (generate_citation_network_with_research_object, "Citation based on research object"),
    24: (generate_citation_network_with_journal_reputation, "Citation based on journal/conference reputation"),
    25: (generate_citation_network_with_open_access, "Rules based on open access")
}

def generate_graph1(type1,type2,num_nodes, mean, std,sim_threshold):
    """Build an undirected similarity graph over two batches of generated nodes.

    Two feature matrices of ``num_nodes`` rows each are produced by the
    generators registered under ``type1`` and ``type2`` in
    ``node_generators`` and stacked.  An edge ``(i, j)`` with ``i < j`` is
    added whenever the cosine similarity of the two feature rows exceeds
    ``sim_threshold``.

    Args:
        type1: Key into ``node_generators`` for the first batch.
        type2: Key into ``node_generators`` for the second batch.
        num_nodes: Number of nodes per batch (the graph has twice as many).
        mean: Per-dimension means forwarded to the node generators.
        std: Per-dimension spreads forwarded to the node generators.
        sim_threshold: Similarity cut-off; ``None`` yields a graph without edges.

    Returns:
        ``(G, role_id)``: an ``nx.Graph`` over all generated nodes and a list
        of random role ids in ``[0, 2]``, one per node.
    """
    # The role ids returned by the node generators are not used.
    nodes1, _ = node_generators[type1][0](num_nodes,mean,std)
    nodes2, _ = node_generators[type2][0](num_nodes,mean,std)
    nodes = np.concatenate((nodes1, nodes2), axis=0)

    num_nodes = len(nodes)
    edges = []
    # The threshold check is loop-invariant, so test it once.
    if sim_threshold is not None:
        for i in range(num_nodes):
            for j in range(i+1, num_nodes):
                similarity = cosine_similarity(nodes[i:i+1, :], nodes[j:j+1, :])[0, 0]
                if similarity > sim_threshold:
                    edges.append((i, j))

    G = nx.Graph()
    G.add_nodes_from(range(num_nodes))
    G.add_edges_from(edges)
    role_id = [random.randint(0, 2) for _ in range(num_nodes)]

    # A trailing edge_index re-indexing block was removed: its results were
    # never returned, and it raised IndexError (size(1) on a 1-D tensor)
    # whenever the graph had no edges.
    return G,role_id

def generate_graph2(type1,type2,num_nodes, mean, std,partial_sim_threshold, dims):
    """Build an undirected graph from cosine similarity on a subset of feature dimensions.

    Two feature matrices of ``num_nodes`` rows each are produced by the
    generators registered under ``type1`` and ``type2`` in
    ``node_generators`` and stacked.  An edge ``(i, j)`` with ``i < j`` is
    added whenever the cosine similarity of the two rows, restricted to
    ``dims``, exceeds ``partial_sim_threshold``.  If either
    ``partial_sim_threshold`` or ``dims`` is ``None`` no edges are added.

    Returns:
        ``(G, role_id)``: an ``nx.Graph`` over all generated nodes and a list
        of random role ids in ``[0, 2]``, one per node.
    """
    first_batch, _ = node_generators[type1][0](num_nodes,mean,std)
    second_batch, _ = node_generators[type2][0](num_nodes,mean,std)
    nodes = np.concatenate((first_batch, second_batch), axis=0)

    total = len(nodes)
    compare = partial_sim_threshold is not None and dims is not None

    G = nx.Graph()
    G.add_nodes_from(range(total))

    similar_pairs = []
    if compare:
        for a in range(total):
            for b in range(a + 1, total):
                left = nodes[a][dims].reshape(1, -1)
                right = nodes[b][dims].reshape(1, -1)
                if cosine_similarity(left, right)[0][0] > partial_sim_threshold:
                    similar_pairs.append((a, b))
    G.add_edges_from(similar_pairs)

    role_id = [random.randint(0, 2) for _ in range(total)]
    return G,role_id

# Registry of node-connection strategies: numeric id -> (callable, description).
node_connectors = {
    1: (generate_graph1, "Similar edges with autonomous node similarity judgment"),
    2: (generate_graph2, "Partial similar edges with autonomous multi-dimensional node similarity judgment"),
}
# Module-level defaults read as globals by the dataset builders below
# (generate_Y0 .. generate_Y4, generate_false_cause_dataset1,
# generate_false_dataset).
num_papers=10
avg_citations_per_paper=3
num_classes=3
# Per-dimension feature parameters (five dimensions); mean and std are equal.
mean=[1.5, 2.0, 1.2, 1.3, 1.8]
std=[1.5, 2.0, 1.2, 1.3, 1.8]
def generate_Y0():
    """Return a class-0 sample: a graph from ``create_paper_citation_graph2`` built with the module defaults."""
    graph, roles = create_paper_citation_graph2(
        num_papers, avg_citations_per_paper, num_classes, mean, std
    )
    return graph, roles, 0
def generate_Y1():
    """Return a class-1 sample: a graph from ``create_paper_citation_graph3`` built with the module defaults."""
    graph, roles = create_paper_citation_graph3(
        num_papers, avg_citations_per_paper, num_classes, mean, std
    )
    return graph, roles, 1
def generate_Y2():
    """Return a class-2 sample: a graph from ``create_paper_citation_graph4`` built with the module defaults."""
    graph, roles = create_paper_citation_graph4(
        num_papers, avg_citations_per_paper, num_classes, mean, std
    )
    return graph, roles, 2
def generate_Y3():
    """Return a class-3 sample: a graph from ``create_paper_citation_graph5`` built with the module defaults."""
    graph, roles = create_paper_citation_graph5(
        num_papers, avg_citations_per_paper, num_classes, mean, std
    )
    return graph, roles, 3
def generate_Y4():
    """Return a class-4 sample: a graph from ``create_paper_citation_graph6``.

    Unlike the other class builders, the feature parameters are passed as
    explicit literals and ``publication_years`` is passed as ``None``.
    """
    graph, roles = create_paper_citation_graph6(
        num_papers,
        avg_citations_per_paper,
        num_classes,
        mean=[1.5, 2.0, 1.2, 1.3, 1.8],
        std=[1.5, 2.0, 1.2, 1.3, 1.8],
        publication_years=None,
    )
    return graph, roles, 4
def generate_real_dataset():
    """Draw one labelled sample together with its five motif-presence flags.

    A class ``y`` is chosen uniformly from 0..4, the matching ``generate_Y*``
    builder is called, and a fixed per-class pattern of motif flags is
    attached.

    Returns:
        ``(G, role_id, label, motif1_present, ..., motif5_present)``.
    """
    # class -> (builder, (motif1, motif2, motif3, motif4, motif5))
    variants = {
        0: (generate_Y0, (True, True, False, False, False)),
        1: (generate_Y1, (True, False, True, False, False)),
        2: (generate_Y2, (False, True, False, False, True)),
        3: (generate_Y3, (False, False, False, True, True)),
        4: (generate_Y4, (False, False, True, True, False)),
    }
    y = random.choice([0, 1, 2, 3, 4])
    build_sample, motif_flags = variants[y]
    G, role_id, label = build_sample()
    return (G, role_id, label, *motif_flags)

def adjacent_connection(G1,G2):
    """Join two graphs through a pair of freshly inserted bridge nodes.

    One random edge of each input graph is subdivided by a new node; the two
    graphs are composed and the two new nodes are connected to each other.
    Nodes present in both inputs get a random ``role_id`` node attribute.
    If the result is still split into several components, every node outside
    the largest component is wired to a random node inside it.

    NOTE: ``G1`` and ``G2`` are modified in place (edge subdivision and the
    added bridge nodes).  Node labels must support ``max`` and ``+ 1``
    (integers in practice).

    Args:
        G1: First graph.
        G2: Second graph, of the same directedness as ``G1``.

    Returns:
        ``(G, role_id)``: the combined graph and a per-node role list that is
        0 everywhere except the last (up to) six entries, which are random
        values in ``[1, 3]``.  If either input has no nodes or no edges, an
        empty ``nx.Graph`` and an empty list are returned.
    """
    nodes1 = set(G1.nodes())
    nodes2 = set(G2.nodes())
    if not nodes1 or not nodes2 or not G1.edges() or not G2.edges():
        # Same arity as the normal return: callers unpack two values.
        return nx.Graph(), []
    common_nodes = nodes1.intersection(nodes2)

    edge1 = random.choice(list(G1.edges()))
    edge2 = random.choice(list(G2.edges()))

    highest_label = max(nodes1.union(nodes2))
    new_node1 = highest_label + 1
    new_node2 = highest_label + 2

    # Subdivide one edge per graph with its bridge node.
    G1.remove_edge(*edge1)
    G1.add_edge(edge1[0], new_node1)
    G1.add_edge(new_node1, edge1[1])
    G2.remove_edge(*edge2)
    G2.add_edge(edge2[0], new_node2)
    G2.add_edge(new_node2, edge2[1])

    G1.add_node(new_node2)
    G2.add_node(new_node1)

    G = nx.compose(G1, G2)

    G.add_edge(new_node1, new_node2)

    for node in common_nodes:
        G.add_node(node, role_id=np.random.randint(low=1, high=len(common_nodes) + 3))

    # connected_components is undirected-only and the weakly_* functions are
    # directed-only, so pick the variant matching the graph type.
    if G.is_directed():
        components = list(nx.weakly_connected_components(G))
    else:
        components = list(nx.connected_components(G))
    if len(components) > 1:
        largest_component = max(components, key=len)
        isolated_nodes = [n for n in G.nodes() if n not in largest_component]
        for u in isolated_nodes:
            v = random.choice(list(largest_component))
            G.add_edge(u, v)

    # G always holds at least the two bridge nodes.  Limit the random tail to
    # the node count so the role list never grows past one entry per node.
    total_nodes = G.number_of_nodes()
    role_id = [0] * total_nodes
    tail = min(6, total_nodes)
    role_id[-tail:] = torch.randint(1, 4, (tail,)).tolist()

    # A trailing edge_index re-indexing block was removed: it did not affect
    # the returned values and referenced num_nodes before assignment.
    return G,role_id

def generate_false_cause_dataset1():
    """Return a labelled sample with a spurious citation graph attached.

    A sample is drawn from ``generate_real_dataset``.  The first motif flag
    that is set selects which extra citation-graph generator is used; the
    generated graph is joined to the real one with ``adjacent_connection``.

    Returns:
        ``(graph, role_id, label)``.  ``label`` is the label of the real
        sample.  If no motif flag is set, the result of
        ``generate_false_dataset()`` is returned as is.
    """
    mean = [1.0, 2.0, 1.0, 1.5, 3.0]
    std = [1.0, 2.0, 1.0, 1.5, 3.0]
    (G, role_id, label,
     motif1_present, motif2_present, motif3_present,
     motif4_present, motif5_present) = generate_real_dataset()

    # Checked in order; the first motif that is present decides the generator.
    spurious_builders = (
        (motif1_present, create_author_influence_citation_graph),
        (motif2_present, create_common_citation_count_citation_graph),
        (motif3_present, create_citation_graph_based_on_citation_density),
        (motif4_present, create_citation_graph_based_on_network_structure),
        (motif5_present, create_citation_graph_based_on_author_field),
    )
    for present, build_spurious in spurious_builders:
        if present:
            spurious, _ = build_spurious(num_papers, avg_citations_per_paper, num_classes, mean, std)
            graph, role_id = adjacent_connection(G, spurious)
            return graph, role_id, label

    # generate_false_dataset returns three values; unpacking it into two
    # names raised ValueError.  Take its label too so graph and label match.
    graph, role_id, label = generate_false_dataset()
    return graph, role_id, label

def _generate_false_cause_dataset(feature_stats):
    """Shared body of generate_false_cause_dataset2 .. 6.

    Draws a sample from ``generate_real_dataset`` and, for the first motif
    flag that is set, builds a graph with ``generate_graph`` and joins it to
    the real graph via ``adjacent_connection``.

    Args:
        feature_stats: Five per-dimension values passed to ``generate_graph``
            as both its third and fourth argument (presumably mean and std --
            ``generate_graph`` is defined outside this section; confirm).

    Returns:
        ``(graph, role_id, label)``.  When no motif flag is set the real graph
        and its role ids are returned unchanged.
    """
    # (first argument to generate_graph, last argument to generate_graph)
    # for motif1 .. motif5, exactly as in the original per-variant chains.
    graph_choices = ((6, 0.3), (7, 0.3), (8, 0.3), (9, 0.5), (10, 0.5))

    graph, role_id, label, *motif_flags = generate_real_dataset()

    for present, (graph_type, threshold) in zip(motif_flags, graph_choices):
        if present:
            spurious, _ = generate_graph(
                graph_type,
                random.randint(5, 10),
                list(feature_stats),
                list(feature_stats),
                threshold,
            )
            graph, role_id = adjacent_connection(graph, spurious)
            break
    return graph, role_id, label


def generate_false_cause_dataset2():
    """False-cause sample whose extra graph uses feature values [1.0, 2.0, 1.0, 1.5, 3.0]."""
    return _generate_false_cause_dataset([1.0, 2.0, 1.0, 1.5, 3.0])


def generate_false_cause_dataset3():
    """False-cause sample whose extra graph uses feature values [5.0, 7.0, 4.0, 6.0, 10.0]."""
    return _generate_false_cause_dataset([5.0, 7.0, 4.0, 6.0, 10.0])


def generate_false_cause_dataset4():
    """False-cause sample whose extra graph uses feature values [10.0, 15.0, 8.0, 12.0, 20.0]."""
    return _generate_false_cause_dataset([10.0, 15.0, 8.0, 12.0, 20.0])


def generate_false_cause_dataset5():
    """False-cause sample whose extra graph uses feature values [8.0, 10.0, 7.0, 9.0, 10.0]."""
    return _generate_false_cause_dataset([8.0, 10.0, 7.0, 9.0, 10.0])


def generate_false_cause_dataset6():
    """False-cause sample whose extra graph uses feature values [20.0, 25.0, 18.0, 22.0, 20.0]."""
    return _generate_false_cause_dataset([20.0, 25.0, 18.0, 22.0, 20.0])
def generate_false_dataset():
    """Return a false-cause sample with one more random citation graph attached.

    Takes the result of ``generate_false_cause_dataset1`` and joins it, via
    ``adjacent_connection``, to a graph from ``create_paper_citation_graph``
    built with feature parameters ``[1.0, 2.0, 1.0, 1.5, 3.0]``.

    Returns:
        ``(graph, role_id, label)`` where ``label`` comes from
        ``generate_false_cause_dataset1``.
    """
    base_graph, _, label = generate_false_cause_dataset1()

    feature_mean = [1.0, 2.0, 1.0, 1.5, 3.0]
    feature_std = [1.0, 2.0, 1.0, 1.5, 3.0]

    # The role ids of the extra graph are discarded.
    extra_graph, _ = create_paper_citation_graph(
        num_papers, avg_citations_per_paper, num_classes, feature_mean, feature_std
    )

    graph, role_id = adjacent_connection(base_graph, extra_graph)
    return graph, role_id, label

def add_noise(G, delete_edge_prob, add_edge_prob, delete_node_prob, add_node_prob,label=None):
    """Return a randomly perturbed deep copy of ``G``.

    The perturbation runs in four stages on the copy:

    1. delete ``int(delete_edge_prob * |E|)`` random edges;
    2. attempt ``int(add_edge_prob * |V| * (|V| - 1) / 2)`` random edge
       insertions between distinct nodes;
    3. delete ``int(delete_node_prob * |V|)`` random nodes;
    4. add ``int(add_node_prob * |V|)`` new nodes, each wired to a random
       non-empty subset of the existing nodes, retrying until the graph is
       connected (weakly connected for directed graphs).

    Args:
        G: Input graph; it is not modified.
        delete_edge_prob: Fraction of edges to delete.
        add_edge_prob: Fraction of all possible node pairs to try to connect.
        delete_node_prob: Fraction of nodes to delete.
        add_node_prob: Fraction (of the remaining node count) of nodes to add.
        label: Passed through unchanged.

    Returns:
        ``(G_noisy, role_id_noisy, label_noisy)`` where ``role_id_noisy`` holds
        one random role id in ``[0, 2]`` per node of ``G_noisy``.
    """
    G_noisy = copy.deepcopy(G)

    # random.sample needs a real sequence; networkx views are set-like and
    # are rejected with TypeError on Python 3.11+.
    num_edges_to_delete = int(delete_edge_prob * G_noisy.number_of_edges())
    edges_to_delete = random.sample(list(G_noisy.edges()), num_edges_to_delete)
    G_noisy.remove_edges_from(edges_to_delete)

    num_edges_to_add = int(add_edge_prob * G_noisy.number_of_nodes() * (G_noisy.number_of_nodes()-1)/2)
    node_pool = list(G_noisy.nodes())
    for _ in range(num_edges_to_add):
        node1, node2 = random.sample(node_pool, 2)
        # add_edge is a no-op for an edge that already exists.
        G_noisy.add_edge(node1, node2)

    num_nodes_to_delete = int(delete_node_prob * G_noisy.number_of_nodes())
    nodes_to_delete = random.sample(list(G_noisy.nodes()), num_nodes_to_delete)
    G_noisy.remove_nodes_from(nodes_to_delete)

    # is_weakly_connected is directed-only, is_connected undirected-only.
    is_connected = nx.is_weakly_connected if G_noisy.is_directed() else nx.is_connected

    num_nodes_to_add = int(add_node_prob * G_noisy.number_of_nodes())
    for _ in range(num_nodes_to_add):
        existing_nodes = list(G_noisy.nodes())

        # Start from the historical id (node count + 1) but skip ids already
        # in use, which can happen after node deletions.
        node_id = G_noisy.number_of_nodes() + 1
        while node_id in G_noisy:
            node_id += 1
        G_noisy.add_node(node_id)

        # Sample neighbours from the pre-existing nodes only, so the new node
        # can never be chosen as its own neighbour.
        while True:
            nodes_to_connect = random.sample(existing_nodes, random.randint(1, len(existing_nodes)))
            new_edges = [(node_id, n) for n in nodes_to_connect]
            G_noisy.add_edges_from(new_edges)
            if is_connected(G_noisy):
                break
            G_noisy.remove_edges_from(new_edges)

    # Role ids are per node, so size the list by the node count (the original
    # used the edge count here).
    role_id_noisy = [random.randint(0, 2) for _ in range(G_noisy.number_of_nodes())]
    label_noisy = label

    return G_noisy, role_id_noisy, label_noisy
