from BA3_loc import *
from tqdm import tqdm
import os.path as osp
import warnings
warnings.filterwarnings("ignore")
import random
import math
import torch
import copy

from scipy.stats import gamma
from scipy.stats import gompertz
from sklearn.metrics.pairwise import cosine_similarity
from scipy.stats import weibull_min
from scipy.special import gamma, gammaincinv
from scipy.spatial.distance import cdist
from sklearn.metrics.pairwise import euclidean_distances
from collections import deque
# Output directory for the raw CRCG-NODE data; created (if missing) at import time.
data_dir = '../data/CRCG-NODE/raw/'
os.makedirs(data_dir, exist_ok=True)



def generate_gamma(mu, sigma, size):
    """Draw gamma-distributed samples with mean ``mu`` and std ``sigma``.

    The gamma parameters are obtained by moment matching:
    ``theta = sigma**2 / mu`` (scale) and ``k = mu / theta`` (shape).

    Args:
        mu: Target mean (must be positive).
        sigma: Target standard deviation (must be positive).
        size: Number (or shape) of samples to draw.

    Returns:
        numpy.ndarray of samples from ``scipy.stats.gamma.rvs``.
    """
    # The module-level name ``gamma`` is rebound by
    # ``from scipy.special import gamma`` (the Gamma *function*, which has no
    # ``.rvs``), so bind the distribution object explicitly here.
    from scipy.stats import gamma as gamma_dist

    var = np.power(sigma, 2)
    theta = np.divide(var, mu)
    k = np.divide(mu, theta)
    return gamma_dist.rvs(a=k, scale=theta, size=size)
def generate_nodes_normal_distributed(num_nodes,mean, std):
    """Sample node features column-by-column from independent normals.

    Column ``c`` is drawn from ``N(mean[c], std[c]**2)``.

    Returns:
        (nodes, role_id): a ``(num_nodes, len(mean))`` float array and the
        list ``[0, 1, ..., num_nodes - 1]``.
    """
    feature_count = len(mean)
    features = np.zeros((num_nodes, feature_count))
    for col in range(feature_count):
        features[:, col] = np.random.normal(loc=mean[col], scale=std[col], size=num_nodes)
    return features, list(range(num_nodes))
def generate_nodes_uniform_distributed(num_nodes,mean, std):
    """Sample node features from uniforms on ``[mean[c] - std[c], mean[c] + std[c]]``.

    Returns:
        (nodes, role_id): a ``(num_nodes, len(mean))`` float array and the
        list ``[0, 1, ..., num_nodes - 1]``.
    """
    width = len(mean)
    features = np.zeros((num_nodes, width))
    for col in range(width):
        lower = mean[col]- std[col]
        upper = mean[col] + std[col]
        features[:, col] = np.random.uniform(low=lower, high=upper, size=num_nodes)
    return features, list(range(num_nodes))
def generate_nodes_exponential_distributed(num_nodes,mean, std):
    """Sample node features from shifted exponential distributions.

    Column ``i`` is ``(mean[i] - std[i]) + Exp(scale=std[i])``, which has mean
    ``mean[i]`` and standard deviation ``std[i]``.

    Returns:
        (nodes, role_id): a ``(num_nodes, len(mean))`` float array and the
        list ``[0, 1, ..., num_nodes - 1]``.
    """
    nodes = np.zeros((num_nodes, len(mean)))
    for i in range(len(mean)):
        # Exp(scale=s) has mean s and std s. Shifting by (mean - std) moves the
        # mean to mean[i] while keeping the spread, so both arguments take effect
        # and mean[i] == 0 no longer divides by zero.
        nodes[:, i] = (mean[i] - std[i]) + np.random.exponential(scale=std[i], size=num_nodes)
    role_id = list(range(num_nodes))
    return nodes, role_id
def generate_nodes_lognormal_distributed(num_nodes,mean, std):
    """Sample node features from log-normals matched to ``mean``/``std``.

    The underlying normal parameters come from the standard moment
    conversion: ``mu = log(m**2 / sqrt(s**2 + m**2))`` and
    ``sigma = sqrt(log(s**2 / m**2 + 1))``.

    Returns:
        (nodes, role_id): a ``(num_nodes, len(mean))`` float array and the
        list ``[0, 1, ..., num_nodes - 1]``.
    """
    columns = len(mean)
    samples = np.zeros((num_nodes, columns))
    for col in range(columns):
        m, s = mean[col], std[col]
        log_mu = np.log(m**2/np.sqrt(s**2+m**2))
        log_sigma = np.sqrt(np.log(s**2/m**2 + 1))
        samples[:, col] = np.random.lognormal(log_mu, log_sigma, num_nodes)
    return samples, list(range(num_nodes))
def generate_nodes_gamma_distributed(num_nodes,mean, std):
    """Sample node features column-wise by delegating to ``generate_gamma``.

    Column ``c`` is ``generate_gamma(mean[c], std[c], num_nodes)``.

    Returns:
        (nodes, role_id): a ``(num_nodes, len(mean))`` float array and the
        list ``[0, 1, ..., num_nodes - 1]``.
    """
    columns = len(mean)
    samples = np.zeros((num_nodes, columns))
    for col in range(columns):
        samples[:, col] = generate_gamma(mean[col], std[col], num_nodes)
    return samples, list(range(num_nodes))
def generate_nodes_beta_distributed(num_nodes,mean, std):
    """Sample node features from Beta distributions matched to ``mean``/``std``.

    Method of moments: ``a = ((1 - m) / s**2 - 1 / m) * m**2`` and
    ``b = a * (1 / m - 1)``. Valid Beta parameters need ``0 < m < 1`` and
    ``s**2 < m * (1 - m)``; otherwise numpy raises.

    Returns:
        (nodes, role_id): a ``(num_nodes, len(mean))`` float array and the
        list ``[0, 1, ..., num_nodes - 1]``.
    """
    mean_arr = np.array(mean)
    std_arr = np.array(std)
    # The parameters for every column are computed once rather than once per
    # loop iteration.
    a = ((1 - mean_arr) / std_arr ** 2 - 1 / mean_arr) * mean_arr ** 2
    b = a * (1 / mean_arr - 1)
    nodes = np.zeros((num_nodes, len(mean)))
    for i in range(len(mean)):
        nodes[:, i] = np.random.beta(a=a[i], b=b[i], size=num_nodes)
    role_id = list(range(num_nodes))
    return nodes, role_id
def generate_nodes_weibull_distributed(num_nodes,mean, std):
    """Sample Weibull-shaped node features rescaled to exact ``mean``/``std``.

    For each column a random shape ``k`` and scale ``lam`` (both uniform in
    [0.5, 2.0)) are drawn. Weibull samples are then standardized and rescaled
    so the column's sample mean is ``mean[i]`` and its sample std is ``std[i]``.

    Returns:
        (nodes, role_id): a ``(num_nodes, len(mean))`` float array and the
        list ``[0, 1, ..., num_nodes - 1]``.
    """
    dims = len(mean)
    # One shape/scale pair per feature column. A fixed count would index out
    # of range for wider feature vectors.
    k = np.random.uniform(low=0.5, high=2.0, size=dims)
    lam = np.random.uniform(low=0.5, high=2.0, size=dims)
    nodes = np.zeros((num_nodes, dims))
    for i in range(dims):
        col = weibull_min.rvs(k[i], scale=lam[i], size=num_nodes)
        centered = col - col.mean()
        col_std = col.std()
        # Guard the zero-spread case (e.g. a single node) against division by 0.
        standardized = centered / col_std if col_std > 0 else centered
        nodes[:, i] = standardized * std[i] + mean[i]
    role_id = list(range(num_nodes))
    return nodes, role_id
def generate_nodes_laplace_distributed(num_nodes,mean, std):
    """Sample node features from Laplace distributions matched to ``mean``/``std``.

    A Laplace distribution with scale ``b`` has std ``b * sqrt(2)``, hence
    ``b = std / sqrt(2)``.

    Returns:
        (nodes, role_id): a ``(num_nodes, len(mean))`` float array and the
        list ``[0, 1, ..., num_nodes - 1]``.
    """
    width = len(mean)
    samples = np.zeros((num_nodes, width))
    for col in range(width):
        diversity = std[col] / np.sqrt(2)
        samples[:, col] = np.random.laplace(mean[col], diversity, size=num_nodes)
    return samples, list(range(num_nodes))
def generate_nodes_logistic_distributed(num_nodes,mean, std):
    """Sample node features from logistic distributions matched to ``mean``/``std``.

    A logistic distribution with scale ``s`` has variance ``s**2 * pi**2 / 3``,
    so ``s = std * sqrt(3) / pi`` gives standard deviation ``std[i]``.

    Returns:
        (nodes, role_id): a ``(num_nodes, len(mean))`` float array and the
        list ``[0, 1, ..., num_nodes - 1]``.
    """
    nodes = np.zeros((num_nodes, len(mean)))
    for i in range(len(mean)):
        s = std[i] * np.sqrt(3) / np.pi
        nodes[:, i] = np.random.logistic(mean[i], s, size=num_nodes)
    role_id = list(range(num_nodes))
    return nodes, role_id
def generate_nodes_rayleigh_distributed(num_nodes,mean, std):
    """Sample node features from shifted Rayleigh distributions.

    Rayleigh(sigma) has mean ``sigma * sqrt(pi / 2)`` and std
    ``sigma * sqrt((4 - pi) / 2)``. ``sigma`` is chosen to give ``std[i]``,
    and the samples are shifted so the mean is ``mean[i]``.

    Returns:
        (nodes, role_id): a ``(num_nodes, len(mean))`` float array and the
        list ``[0, 1, ..., num_nodes - 1]``.
    """
    dim = len(mean)
    nodes = np.zeros((num_nodes, dim))
    for i in range(dim):
        sigma = std[i] / np.sqrt((4 - np.pi) / 2)
        shift = mean[i] - sigma * np.sqrt(np.pi / 2)
        nodes[:, i] = shift + np.random.rayleigh(sigma, size=num_nodes)
    role_id = list(range(num_nodes))
    return nodes, role_id
def generate_nodes_pareto_distributed(num_nodes,mean, std):
    """Sample node features with ``np.random.pareto`` using shape ``mean / std``.

    NOTE(review): ``np.random.pareto`` draws from the Lomax (Pareto II)
    distribution, and shape ``mean[i] / std[i]`` is not a moment match, so
    the output does not have mean ``mean[i]`` / std ``std[i]`` — confirm intent.

    Returns:
        (nodes, role_id): a ``(num_nodes, len(mean))`` float array and the
        list ``[0, 1, ..., num_nodes - 1]``.
    """
    width = len(mean)
    samples = np.zeros((num_nodes, width))
    for col in range(width):
        shape = mean[col] / std[col]
        samples[:, col] = np.random.pareto(shape, size=num_nodes)
    return samples, list(range(num_nodes))
def generate_nodes_cauchy_distributed(num_nodes,mean, std):
    """Sample node features from Cauchy distributions.

    Column ``i`` is ``standard_cauchy * (std[i] * sqrt(pi / 2)) + mean[i]``,
    i.e. location ``mean[i]`` and half-width ``std[i] * sqrt(pi / 2)``. The
    Cauchy distribution has no finite mean or variance, so these are only
    location/scale knobs.

    Returns:
        (nodes, role_id): a ``(num_nodes, len(mean))`` float array and the
        list ``[0, 1, ..., num_nodes - 1]``.
    """
    width = len(mean)
    samples = np.zeros((num_nodes, width))
    for col in range(width):
        location = mean[col]
        half_width = std[col] * np.sqrt(np.pi / 2)
        draws = np.random.standard_cauchy(size=num_nodes)
        samples[:, col] = draws * half_width + location
    return samples, list(range(num_nodes))
def generate_nodes_neg_binom_distributed(num_nodes,mean, std):
    """Sample integer node features from moment-matched negative binomials.

    ``np.random.negative_binomial(n, p)`` has mean ``n(1-p)/p`` and variance
    ``n(1-p)/p**2``. Matching both gives ``p = mean / var`` and
    ``n = mean**2 / (var - mean)``, which requires over-dispersion
    (``std**2 > mean``).

    Raises:
        ValueError: if ``std[i]**2 <= mean[i]`` for some column.

    Returns:
        (nodes, role_id): a ``(num_nodes, len(mean))`` float array and the
        list ``[0, 1, ..., num_nodes - 1]``.
    """
    nodes = np.zeros((num_nodes, len(mean)))
    for i in range(len(mean)):
        var = std[i] ** 2
        if var <= mean[i]:
            raise ValueError("negative binomial requires std**2 > mean (over-dispersion)")
        p = mean[i] / var
        r = mean[i] ** 2 / (var - mean[i])
        nodes[:, i] = np.random.negative_binomial(r, p, size=num_nodes)
    role_id = list(range(num_nodes))
    return nodes, role_id
def generate_nodes_gumbel_distributed(num_nodes,mean, std):
    """Sample node features from Gumbel distributions matched to ``mean``/``std``.

    Gumbel(loc, scale) has std ``scale * pi / sqrt(6)`` and mean
    ``loc + euler_gamma * scale``. Therefore ``scale = sqrt(6) * std / pi``
    and ``loc = mean - euler_gamma * scale``.

    Returns:
        (nodes, role_id): a ``(num_nodes, len(mean))`` float array and the
        list ``[0, 1, ..., num_nodes - 1]``.
    """
    nodes = np.zeros((num_nodes, len(mean)))
    for i in range(len(mean)):
        scale = np.sqrt(6) * std[i] / np.pi
        # The mean offset is euler_gamma times the *scale*, not times std.
        loc = mean[i] - np.euler_gamma * scale
        nodes[:, i] = np.random.gumbel(loc=loc, scale=scale, size=num_nodes)
    role_id = list(range(num_nodes))
    return nodes, role_id
def generate_nodes_gompertz_distributed(num_nodes,mean, std):
    """Sample node features column-wise with ``scipy.stats.gompertz``.

    NOTE(review): the value passed as the shape ``c`` is
    ``mean - log(log(2) / 2) * std`` and the scale is ``exp(std / log(2))``.
    This does not look like a moment match for the Gompertz distribution, and
    scipy expects ``c > 0``. Confirm the intended parameterization.

    Returns:
        (nodes, role_id): a ``(num_nodes, len(mean))`` float array and the
        list ``[0, 1, ..., num_nodes - 1]``.
    """
    width = len(mean)
    samples = np.zeros((num_nodes, width))
    for col in range(width):
        shape_c = mean[col] - np.log(np.log(2)/2)*std[col]
        spread = np.exp(std[col]/np.log(2))
        samples[:, col] = gompertz.rvs(c=shape_c, scale=spread, size=num_nodes)
    return samples, list(range(num_nodes))
def rectangle_sequence(n):
    """Return the first ``n`` pronic (rectangular) numbers ``k * (k + 1)``, k = 1..n.

    >>> rectangle_sequence(4)
    [2, 6, 12, 20]
    """
    return [k * (k + 1) for k in range(1, n + 1)]
def binomial_coefficients(n, dim):
    """Return ``[C(n, 1), C(n, 2), ..., C(n, dim)]`` (zero where ``i > n``).

    >>> binomial_coefficients(5, 3)
    [5, 10, 10]
    """
    return [math.comb(n, k) for k in range(1, dim + 1)]
def generate_nodes_arithmetic(num_nodes,dims,step):
    """Generate one arithmetic progression per node.

    Each row starts at ``random.uniform(0, 10)`` and advances by ``step`` for
    ``dims`` terms. Labels are ``random.randint(0, 2)`` per node, drawn after
    all rows.

    Returns:
        (nodes, role_id): a ``(num_nodes, dims)`` array and the random labels.
    """
    rows = []
    for _ in range(num_nodes):
        first = random.uniform(0,10)
        rows.append([first + offset * step for offset in range(dims)])
    labels = [random.randint(0, 2) for _ in range(num_nodes)]
    return np.array(rows).reshape((num_nodes, dims)), labels
def generate_nodes_geometric(num_nodes,dims,step):
    """Generate one geometric progression per node.

    Row ``r`` is ``[s, s*step, s*step**2, ...]`` (``dims`` terms) with
    ``s = random.uniform(0, 10)``. Labels are ``random.randint(0, 2)`` per
    node, drawn after all rows.

    Returns:
        (nodes, role_id): a ``(num_nodes, dims)`` array and the random labels.
    """
    rows = []
    for _ in range(num_nodes):
        seed_value = random.uniform(0,10)
        rows.append([seed_value*step**power for power in range(dims)])
    labels = [random.randint(0, 2) for _ in range(num_nodes)]
    return np.array(rows).reshape((num_nodes, dims)), labels
def generate_nodes_fibonacci(num_nodes,dims,step):
    """Generate one Fibonacci-style sequence per node.

    Each row starts from two random seeds ``round(uniform(0, 10), 1)`` and
    continues with ``x[k] = x[k-1] + x[k-2]``. It is truncated to ``dims``
    terms, so ``dims`` of 0 or 1 also work. ``step`` is unused. Labels are
    ``random.randint(0, 2)`` per node, drawn after all rows.

    Returns:
        (nodes, role_id): a ``(num_nodes, dims)`` array and the random labels.
    """
    nodes = []
    for _ in range(num_nodes):
        start1 = round(random.uniform(0, 10), 1)
        start2 = round(random.uniform(0, 10), 1)
        fib_nums = [start1, start2]
        while len(fib_nums) < dims:
            fib_nums.append(fib_nums[-1] + fib_nums[-2])
        # Both seeds are always drawn; trim so rows have exactly ``dims``
        # entries and the final reshape succeeds for dims < 2.
        nodes.append(fib_nums[:dims])
    role_id = [random.randint(0, 2) for _ in range(num_nodes)]
    return np.array(nodes).reshape((num_nodes, dims)), role_id
def generate_nodes_square(num_nodes,dims,step):
    """Generate rows built by repeated squaring.

    Row: ``[x, x**2, (x**2)**2, ...]`` with ``x = random.uniform(0, 10)``.
    Note that the values grow doubly exponentially, so large ``dims``
    overflows floats. ``step`` is unused. Labels are ``random.randint(0, 2)``
    per node, drawn after all rows.

    Returns:
        (nodes, role_id): a ``(num_nodes, dims)`` array and the random labels.
    """
    rows = []
    for _ in range(num_nodes):
        value = random.uniform(0, 10)
        row = [value]
        for _ in range(dims - 1):
            value = value ** 2
            row.append(value)
        rows.append(row)
    labels = [random.randint(0, 2) for _ in range(num_nodes)]
    return np.array(rows).reshape((num_nodes, dims)), labels
def generate_nodes_cube(num_nodes,dims,step):
    """Generate rows built by repeated cubing.

    Row: ``[x, x**3, (x**3)**3, ...]`` with ``x = random.uniform(0, 10)``.
    The values grow very fast, so large ``dims`` overflows floats. ``step``
    is unused. Labels are ``random.randint(0, 2)`` per node, drawn after
    all rows.

    Returns:
        (nodes, role_id): a ``(num_nodes, dims)`` array and the random labels.
    """
    rows = []
    for _ in range(num_nodes):
        current = random.uniform(0, 10)
        row = [current]
        for _ in range(dims - 1):
            current = current ** 3
            row.append(current)
        rows.append(row)
    labels = [random.randint(0, 2) for _ in range(num_nodes)]
    return np.array(rows).reshape((num_nodes, dims)), labels
def generate_nodes_prime(num_nodes,dims,step):
    """Generate rows that are the first ``dims`` primes scaled per node.

    Row ``r`` is ``[2, 3, 5, 7, ...] * s_r`` with ``s_r = random.uniform(0, 10)``.
    ``step`` is unused. Labels are ``random.randint(0, 2)`` per node, drawn
    after all rows.

    Returns:
        (nodes, role_id): a ``(num_nodes, dims)`` float array and the labels.
    """
    # The prime list is the same for every node, so build it once by trial
    # division against the primes found so far.
    primes = []
    candidate = 2
    while len(primes) < dims:
        if all(candidate % p != 0 for p in primes):
            primes.append(candidate)
        candidate += 1
    prime_arr = np.array(primes)

    nodes = [prime_arr * random.uniform(0, 10) for _ in range(num_nodes)]
    role_id = [random.randint(0, 2) for _ in range(num_nodes)]
    return np.array(nodes).reshape((num_nodes, dims)), role_id
def generate_nodes_triangular(num_nodes,dims,step):
    """Generate rows of "triangular" values ``0.5 * x * (x + 1)`` for consecutive x.

    Each row starts at ``x = random.uniform(0, 10)``, and ``x`` grows by 1
    per column. ``step`` is unused. Labels are ``random.randint(0, 2)`` per
    node, drawn after all rows.

    Returns:
        (nodes, role_id): a ``(num_nodes, dims)`` array and the random labels.
    """
    rows = []
    for _ in range(num_nodes):
        x = random.uniform(0, 10)
        row = []
        for _ in range(dims):
            row.append(0.5 * x * (x + 1))
            x += 1
        rows.append(row)
    labels = [random.randint(0, 2) for _ in range(num_nodes)]
    return np.array(rows).reshape((num_nodes, dims)), labels
def generate_nodes_rectangular(num_nodes,dims,step):
    """Generate rows of pronic numbers ``k * (k + 1)`` offset by a random start.

    Row: ``[s + 2, s + 6, s + 12, ...]`` (``dims`` terms) with
    ``s = random.uniform(0, 10)``. ``step`` is unused. Labels are
    ``random.randint(0, 2)`` per node, drawn after all rows.

    Returns:
        (nodes, role_id): a ``(num_nodes, dims)`` array and the random labels.
    """
    # The pronic offsets are identical for every row, so compute them once.
    offsets = [k * (k + 1) for k in range(1, dims + 1)]
    rows = []
    for _ in range(num_nodes):
        base = random.uniform(0, 10)
        rows.append([base + off for off in offsets])
    labels = [random.randint(0, 2) for _ in range(num_nodes)]
    return np.array(rows).reshape((num_nodes, dims)), labels
def generate_nodes_binomial(num_nodes,dims,step):
    """Generate rows of binomial coefficients ``[C(n, 1), ..., C(n, dims)]``.

    ``n`` is drawn per node as ``random.randint(1, 10)``. ``step`` is unused.
    Labels are ``random.randint(0, 2)`` per node, drawn after all rows.

    Returns:
        (nodes, role_id): a ``(num_nodes, dims)`` integer array and the labels.
    """
    rows = []
    for _ in range(num_nodes):
        top = random.randint(1, 10)
        rows.append([math.comb(top, k) for k in range(1, dims + 1)])
    labels = [random.randint(0, 2) for _ in range(num_nodes)]
    return np.array(rows).reshape((num_nodes, dims)), labels
def generate_nodes_hamilton(num_nodes,dims,step):
    """Generate rows of consecutive values ``2**n - 1``.

    Each row starts at a random exponent ``n0 = random.randint(1, 10)`` and
    contains ``2**n - 1`` for ``n = n0 .. n0 + dims - 1``. ``step`` is
    unused. Labels are ``random.randint(0, 2)`` per node, drawn after all rows.

    Returns:
        (nodes, role_id): a ``(num_nodes, dims)`` integer array and the labels.
    """
    nodes = []
    for _ in range(num_nodes):
        initial_value = random.randint(1, 10)
        nodes.append([2**n - 1 for n in range(initial_value, initial_value + dims)])
    role_id = [random.randint(0, 2) for _ in range(num_nodes)]
    return np.array(nodes).reshape((num_nodes, dims)), role_id
def merge_nodes(node_set1, node_set2):
    """Stack two node-feature arrays vertically (rows of the first come first).

    1-D inputs are treated as single rows, as with ``np.vstack``.
    """
    stacked = np.vstack([node_set1, node_set2])
    return stacked
def build_sim_edges(nodes,sim_threshold):
    """Return undirected similarity edges between rows of ``nodes``.

    Args:
        nodes: 2-D array with one feature row per node.
        sim_threshold: Pairs ``(i, j)`` with ``i < j`` are connected when the
            cosine similarity of their rows is strictly greater than this.

    Returns:
        A set of ``(i, j)`` tuples of Python ints with ``i < j``.
    """
    num_nodes = nodes.shape[0]
    if num_nodes < 2:
        return set()
    # One vectorized similarity matrix instead of O(n^2) separate calls.
    sim = cosine_similarity(nodes)
    rows, cols = np.nonzero(np.triu(sim > sim_threshold, k=1))
    return {(int(i), int(j)) for i, j in zip(rows, cols)}
def create_partial_sim_edges(nodes, partial_sim_threshold,dims):
    """Return similarity edges computed on a subset of feature columns.

    Args:
        nodes: Sequence of per-node feature vectors (e.g. a 2-D array).
        partial_sim_threshold: Pairs ``(i, j)`` with ``i < j`` are connected
            when the cosine similarity of ``nodes[i][dims]`` and
            ``nodes[j][dims]`` is strictly greater than this.
        dims: Index applied to each node's features to select the compared
            columns (an int, a slice, or a list of column indices).

    Returns:
        A list of ``(i, j)`` tuples in row-major order (``(0, 1), (0, 2), ...``).
    """
    num_nodes = len(nodes)
    if num_nodes < 2:
        return []
    # Select the compared columns for every node; keep one row per node even
    # when ``dims`` picks a single scalar.
    subset = np.asarray([np.asarray(node)[dims] for node in nodes]).reshape(num_nodes, -1)
    # One vectorized similarity matrix instead of O(n^2) separate calls.
    sim = cosine_similarity(subset)
    rows, cols = np.nonzero(np.triu(sim > partial_sim_threshold, k=1))
    return [(int(i), int(j)) for i, j in zip(rows, cols)]

# Registry of node-feature generators, keyed by the integer ``type`` accepted
# by ``generate_graph``. Each value is ``(generator_function, description)``.
# Keys 1-15 take ``(num_nodes, mean, std)``; keys 16-25 take
# ``(num_nodes, dims, step)``.
node_generators = {
    1: (generate_nodes_normal_distributed, "Normal distribution generation based on mean and standard deviation"),
    2: (generate_nodes_uniform_distributed, "Uniform distribution generation based on mean and standard deviation"),
    3: (generate_nodes_exponential_distributed, "Exponential distribution generation based on mean and standard deviation"),
    4: (generate_nodes_lognormal_distributed, "Log-normal distribution generation based on mean and standard deviation"),
    5: (generate_nodes_weibull_distributed, "Weibull distribution generation based on mean and standard deviation"),
    6: (generate_nodes_laplace_distributed, "Laplace distribution generation based on mean and standard deviation"),
    7: (generate_nodes_logistic_distributed, "Logistic distribution generation based on mean and standard deviation"),
    8: (generate_nodes_rayleigh_distributed, "Rayleigh distribution generation based on mean and standard deviation"),
    9: (generate_nodes_pareto_distributed, "Pareto distribution generation based on mean and standard deviation"),
    10: (generate_nodes_cauchy_distributed, "Cauchy distribution generation based on mean and standard deviation"),
    11: (generate_nodes_neg_binom_distributed, "Negative binomial distribution generation based on mean and standard deviation"),
    12: (generate_nodes_gumbel_distributed, "Gumbel distribution generation based on mean and standard deviation"),
    13: (generate_nodes_gompertz_distributed, "Gompertz distribution generation based on mean and standard deviation"),
    # 14/15 must point at the Gamma/Beta generators their descriptions name.
    14: (generate_nodes_gamma_distributed, "Gamma distribution generation based on mean and standard deviation"),
    15: (generate_nodes_beta_distributed, "Beta distribution generation based on mean and standard deviation"),
    16: (generate_nodes_arithmetic, "Arithmetic sequence generation"),
    17: (generate_nodes_geometric, "Geometric sequence generation"),
    18: (generate_nodes_fibonacci, "Fibonacci sequence generation"),
    19: (generate_nodes_square, "Square number sequence generation"),
    20: (generate_nodes_cube, "Cube number sequence generation"),
    21: (generate_nodes_prime, "Prime number sequence generation"),
    22: (generate_nodes_triangular, "Triangular number sequence generation"),
    23: (generate_nodes_rectangular, "Rectangular number sequence generation"),
    24: (generate_nodes_binomial, "Binomial coefficient sequence generation"),
    25: (generate_nodes_hamilton, "Hamiltonian sequence generation"),
}

def generate_graph(type, num_nodes, mean, std, sim_threshold):
    """Build an undirected cosine-similarity graph over synthetic node features.

    Args:
        type: Key into ``node_generators`` selecting the feature generator.
        num_nodes: Number of nodes to generate.
        mean, std: Passed positionally to the generator (for the sequence
            generators these act as ``dims`` and ``step``).
        sim_threshold: Connect ``i < j`` when the cosine similarity of their
            feature rows is strictly greater than this; ``None`` adds no edges.

    Returns:
        (G, role_id): a ``networkx.Graph`` with nodes ``0..num_nodes-1`` and a
        list of ``random.randint(0, 2)`` labels, one per node.
    """
    nodes, _ = node_generators[type][0](num_nodes, mean, std)

    edges = []
    if sim_threshold is not None and num_nodes > 1:
        # One vectorized similarity matrix; np.nonzero on the upper triangle
        # yields pairs in the same (i, j) row-major order as a double loop.
        sim = cosine_similarity(nodes)
        rows, cols = np.nonzero(np.triu(sim > sim_threshold, k=1))
        edges = [(int(i), int(j)) for i, j in zip(rows, cols)]

    G = nx.Graph()
    G.add_nodes_from(range(num_nodes))
    G.add_edges_from(edges)
    # NOTE(review): the generator's own role_id is discarded and replaced by
    # random labels — confirm this is intended.
    role_id = [random.randint(0, 2) for _ in range(num_nodes)]
    return G, role_id

def create_paper_citation_graph(num_papers, avg_citations_per_paper, num_classes, mean, std):
    """Random directed citation graph with Poisson-distributed out-degree.

    Each paper gets ``np.random.normal(mean, std, size=(5,))`` features and a
    uniform random label in ``[0, num_classes)``. Papers are then processed in
    id order. Each one cites ``Poisson(min(avg, #candidates))`` distinct
    candidates (capped at the candidate count). Candidates are the other
    papers that do not already cite it, so no mutual citations and no
    self-citations are created.

    Returns:
        (G, role_id): the ``nx.DiGraph`` and the per-paper labels.
    """
    G = nx.DiGraph()
    role_id = []

    for paper_id in range(num_papers):
        features = np.random.normal(mean,std,size=(5,))
        label = np.random.randint(0, num_classes)
        role_id.append(label)
        G.add_node(paper_id, features=features, label=label)

    for paper_id in range(num_papers):
        # Exclude the paper itself (self-citation) and papers that already cite it.
        remaining_papers = [p_id for p_id in range(num_papers)
                            if p_id != paper_id and not G.has_edge(p_id, paper_id)]
        num_remaining_papers = len(remaining_papers)
        adjusted_avg_citations_per_paper = min(avg_citations_per_paper, num_remaining_papers)
        num_citations = np.random.poisson(adjusted_avg_citations_per_paper)

        if num_remaining_papers == 0:
            continue
        cited_papers = np.random.choice(remaining_papers, size=min(num_citations, num_remaining_papers), replace=False)
        for cited_paper_id in cited_papers:
            G.add_edge(paper_id, cited_paper_id)

    return G, role_id

def create_paper_citation_graph2(num_papers, avg_citations_per_paper, num_classes, mean, std):
    """Citation graph where every paper cites the top-ranked "popular" papers.

    Each paper gets ``np.random.normal(mean, std, size=(5,))`` features, a
    random label, and a random ``Poisson(avg_citations_per_paper)`` popularity
    score. Every paper then cites the ``avg_citations_per_paper``
    highest-scoring papers other than itself. ``avg_citations_per_paper`` is
    used as a slice bound, so it must be an int.

    Returns:
        (G, role_id): the ``nx.DiGraph`` and the per-paper labels.
    """
    G = nx.DiGraph()
    role_id = []

    for paper_id in range(num_papers):
        features = np.random.normal(mean,std, size=(5,))
        label = np.random.randint(0, num_classes)
        role_id.append(label)
        G.add_node(paper_id, features=features, label=label)

    citation_counts = {paper_id: np.random.poisson(avg_citations_per_paper) for paper_id in range(num_papers)}

    # The scores never change, so rank the papers once instead of once per paper.
    sorted_papers = sorted(range(num_papers), key=lambda x: citation_counts[x], reverse=True)

    for paper_id in range(num_papers):
        cited_papers = [p_id for p_id in sorted_papers if p_id != paper_id]
        num_citations = min(avg_citations_per_paper, len(cited_papers))
        for cited_paper_id in cited_papers[:num_citations]:
            G.add_edge(paper_id, cited_paper_id)

    return G, role_id

def create_paper_citation_graph3(num_papers, avg_citations_per_paper, num_classes, mean, std):
    """Citation graph where papers cite random papers that share an author.

    Each paper gets ``np.random.normal(mean, std, size=(5,))`` features, a
    random label, and 1-4 author names. Each paper then cites
    ``min(avg_citations_per_paper, #co-authored papers)`` distinct
    co-authored papers chosen at random.

    NOTE(review): author names are ``Author_0 .. Author_{k-1}``, so every
    paper contains ``Author_0`` and "shares an author" with every other paper.
    The co-author filter therefore never excludes anything. Confirm whether
    authors were meant to come from a shared random pool.

    Returns:
        (G, role_id): the ``nx.DiGraph`` and the per-paper labels.
    """
    G = nx.DiGraph()
    role_id = []

    for pid in range(num_papers):
        feats = np.random.normal(mean,std, size=(5,))
        lbl = np.random.randint(0, num_classes)
        role_id.append(lbl)
        G.add_node(pid, features=feats, label=lbl)

    author_sets = {}
    for pid in range(num_papers):
        author_count = np.random.randint(1, 5)
        author_sets[pid] = {f"Author_{a}" for a in range(author_count)}

    for pid in range(num_papers):
        mine = author_sets[pid]
        coauthored = [other for other, theirs in author_sets.items()
                      if other != pid and mine & theirs]
        how_many = min(avg_citations_per_paper, len(coauthored))
        for cited in np.random.choice(coauthored, size=how_many, replace=False):
            G.add_edge(pid, cited)

    return G, role_id

def create_paper_citation_graph4(num_papers, avg_citations_per_paper, num_classes, mean, std):
    """Random citation graph densified by depth-limited transitive citation.

    1. Each paper gets ``np.random.normal(mean, std, size=(5,))`` features
       and a random label in ``[0, num_classes)``.
    2. Each paper cites ``Poisson(avg_citations_per_paper)`` distinct random
       other papers (capped at ``num_papers - 1``).
    3. For each paper in id order, a BFS over outgoing edges adds a direct
       edge from the paper to every node it reaches. Only nodes at depth
       ``< avg_citations_per_paper`` are expanded.

    NOTE(review): step 3 mutates ``G`` while later papers are processed, so
    edges added for earlier papers act as shortcuts and the effective depth
    limit depends on processing order. The average citation count is also
    reused as the hop limit. Confirm both are intended.

    Returns:
        (G, role_id): the ``nx.DiGraph`` and the per-paper labels.
    """
    G = nx.DiGraph()
    role_id = []


    for paper_id in range(num_papers):
        features = np.random.normal(mean,std, size=(5,))
        label = np.random.randint(0, num_classes)
        role_id.append(label)
        G.add_node(paper_id, features=features, label=label)


    # Random base citations; the cap keeps choice(replace=False) valid.
    for paper_id in range(num_papers):
        num_citations = np.random.poisson(avg_citations_per_paper)
        available_papers = [p_id for p_id in range(num_papers) if p_id != paper_id]
        if len(available_papers) < num_citations:
            num_citations = len(available_papers)
        cited_papers = np.random.choice(available_papers, size=num_citations, replace=False)
        for cited_paper_id in cited_papers:
            G.add_edge(paper_id, cited_paper_id)


    # Depth-limited BFS from each paper; every newly reached node gets a
    # direct edge from ``paper_id``.
    for paper_id in range(num_papers):
        visited = {paper_id}
        propagation_queue = deque([(paper_id, 0)])
        while propagation_queue:
            current_paper_id, depth = propagation_queue.popleft()
            # BFS dequeues in non-decreasing depth, so the first node at the
            # limit means every remaining node is at or beyond it.
            if depth >= avg_citations_per_paper:
                break
            for neighbor_id in G.neighbors(current_paper_id):
                if neighbor_id not in visited:
                    visited.add(neighbor_id)
                    # At depth 0 this edge already exists (a no-op update);
                    # deeper nodes add new edges out of ``paper_id``.
                    G.add_edge(paper_id, neighbor_id)
                    propagation_queue.append((neighbor_id, depth + 1))

    return G, role_id

def create_paper_citation_graph5(num_papers, avg_citations_per_paper, num_classes, mean, std):
    """Random citation graph augmented with citations to topically similar papers.

    After random ``Poisson(avg_citations_per_paper)`` citations (capped at
    ``num_papers - 1``), each paper is given a random 10-d topic vector. Each
    paper then also cites its ``avg_citations_per_paper`` most cosine-similar
    other papers. Ties are broken toward the higher index.

    Returns:
        (G, role_id): the ``nx.DiGraph`` and the per-paper labels.
    """
    G = nx.DiGraph()
    role_id = []

    for pid in range(num_papers):
        feats = np.random.normal(mean,std, size=(5,))
        lbl = np.random.randint(0, num_classes)
        role_id.append(lbl)
        G.add_node(pid, features=feats, label=lbl)

    for pid in range(num_papers):
        how_many = min(np.random.poisson(avg_citations_per_paper), num_papers - 1)
        others = [q for q in range(num_papers) if q != pid]
        for cited in np.random.choice(others, size=how_many, replace=False):
            G.add_edge(pid, cited)

    topic_vectors = np.random.rand(num_papers, 10)
    topic_similarities = cosine_similarity(topic_vectors)

    for pid, sim_row in enumerate(topic_similarities):
        ranked = sorted(((score, idx) for idx, score in enumerate(sim_row) if idx != pid), reverse=True)
        top_k = min(avg_citations_per_paper, len(ranked))
        for _, cited in ranked[:top_k]:
            if G.has_edge(pid, cited):
                continue
            G.add_edge(pid, cited)

    return G, role_id


def create_paper_citation_graph6(num_papers, avg_citations_per_paper, num_classes, mean, std,publication_years=None):
    """Random citation graph plus one citation to an older, not-yet-cited paper.

    ``publication_years`` defaults to random years in [2000, 2022). After
    random ``Poisson(avg_citations_per_paper)`` citations (capped at
    ``num_papers - 1``, ignoring years), each paper additionally cites the
    lowest-id paper published strictly earlier that it does not already cite.

    Returns:
        (G, role_id): the ``nx.DiGraph`` and the per-paper labels.
    """
    G = nx.DiGraph()
    role_id = []

    if publication_years is None:
        publication_years = np.random.randint(2000, 2022, size=num_papers)

    for pid in range(num_papers):
        feats = np.random.normal(mean,std, size=(5,))
        lbl = np.random.randint(0, num_classes)
        role_id.append(lbl)
        G.add_node(pid, features=feats, label=lbl)

    for pid in range(num_papers):
        how_many = min(np.random.poisson(avg_citations_per_paper), num_papers - 1)
        others = [q for q in range(num_papers) if q != pid]
        for cited in np.random.choice(others, size=how_many, replace=False):
            G.add_edge(pid, cited)

    for pid in range(num_papers):
        own_year = publication_years[pid]
        older = next(
            (c for c in range(num_papers)
             if publication_years[c] < own_year and not G.has_edge(pid, c)),
            None,
        )
        if older is not None:
            G.add_edge(pid, older)

    return G, role_id


def create_author_influence_citation_graph(num_papers, avg_citations_per_paper, num_classes, mean, std):
    """Random citation graph plus one citation to a paper sharing an author id.

    Each paper gets 1-4 random author ids in ``[0, num_papers)``, each with a
    random influence score. The scores are stored but not otherwise used.
    Every paper cites exactly ``min(avg_citations_per_paper, num_papers - 1)``
    random other papers. It then also cites the lowest-id other paper that
    shares an author id with it and is not already cited.

    Returns:
        (G, role_id): the ``nx.DiGraph`` and the per-paper labels.
    """
    G = nx.DiGraph()
    role_id = []

    for pid in range(num_papers):
        feats = np.random.normal(mean,std, size=(5,))
        lbl = np.random.randint(0, num_classes)
        role_id.append(lbl)
        G.add_node(pid, features=feats, label=lbl)

    author_influence_data = {}
    for pid in range(num_papers):
        n_authors = np.random.randint(1, 5)
        author_ids = np.random.randint(0, num_papers, size=n_authors)
        scores = np.random.uniform(0, 1, size=n_authors)
        author_influence_data[pid] = dict(zip(author_ids, scores))

    for pid in range(num_papers):
        how_many = min(avg_citations_per_paper, num_papers - 1)
        others = [q for q in range(num_papers) if q != pid]
        for cited in np.random.choice(others, size=how_many, replace=False):
            G.add_edge(pid, cited)

    for pid in range(num_papers):
        my_authors = set(author_influence_data[pid])
        for cid in range(num_papers):
            if cid == pid:
                continue
            if my_authors.isdisjoint(author_influence_data[cid]):
                continue
            if G.has_edge(pid, cid):
                continue
            G.add_edge(pid, cid)
            break

    return G, role_id

def create_common_citation_count_citation_graph(num_papers, avg_citations_per_paper, num_classes, mean, std):
    """Random citation graph plus one citation driven by shared random counts.

    Each paper gets a random integer vector in ``[0, 10)`` of length
    ``num_papers``. After random ``Poisson(avg_citations_per_paper)``
    citations (capped at ``num_papers - 1``), each paper also cites the
    lowest-id other paper whose vector contains some positive value that also
    appears in its own vector, provided that paper is not already cited.

    Returns:
        (G, role_id): the ``nx.DiGraph`` and the per-paper labels.
    """
    G = nx.DiGraph()
    role_id = []

    for pid in range(num_papers):
        feats = np.random.normal(mean,std, size=(5,))
        lbl = np.random.randint(0, num_classes)
        role_id.append(lbl)
        G.add_node(pid, features=feats, label=lbl)

    for pid in range(num_papers):
        how_many = min(np.random.poisson(avg_citations_per_paper), num_papers - 1)
        others = [q for q in range(num_papers) if q != pid]
        for cited in np.random.choice(others, size=how_many, replace=False):
            G.add_edge(pid, cited)

    common_citation_count_data = {pid: np.random.randint(0, 10, size=num_papers) for pid in range(num_papers)}

    for pid in range(num_papers):
        own_counts = common_citation_count_data[pid]
        positive_values = set(own_counts[own_counts > 0].tolist())
        for cid in range(num_papers):
            if cid == pid:
                continue
            shares_value = not positive_values.isdisjoint(common_citation_count_data[cid].tolist())
            if shares_value and not G.has_edge(pid, cid):
                G.add_edge(pid, cid)
                break

    return G, role_id

def create_citation_graph_based_on_citation_density(num_papers, avg_citations_per_paper, num_classes, mean, std):
    """Random citation graph plus one citation to a highly-cited paper.

    Each paper gets ``np.random.normal(mean, std, size=(5,))`` features and a
    random label. It then cites ``Poisson(avg_citations_per_paper)`` random
    other papers (capped at ``num_papers - 1``). Papers are ranked by
    in-degree / num_papers. Each paper additionally cites the highest-ranked
    paper among the top ``avg_citations_per_paper`` that is neither itself
    nor already cited. ``avg_citations_per_paper`` is used as a slice bound,
    so it must be an int.

    Returns:
        (G, role_id): the ``nx.DiGraph`` and the per-paper labels.
    """
    G = nx.DiGraph()
    role_id = []

    for paper_id in range(num_papers):
        # Honor the caller's mean/std like the sibling generators. With
        # mean=0, std=1 this draws the same values as N(0, 1).
        features = np.random.normal(mean, std, size=(5,))
        label = np.random.randint(0, num_classes)
        role_id.append(label)
        G.add_node(paper_id, features=features, label=label)

    for paper_id in range(num_papers):
        num_citations = min(np.random.poisson(avg_citations_per_paper), num_papers - 1)
        candidate_papers = [p_id for p_id in range(num_papers) if p_id != paper_id]
        cited_papers = np.random.choice(candidate_papers, size=num_citations, replace=False)
        for cited_paper_id in cited_papers:
            G.add_edge(paper_id, cited_paper_id)

    citation_density = {paper_id: len(list(G.predecessors(paper_id))) / num_papers for paper_id in range(num_papers)}
    # Densities are fixed from here on, so rank once instead of once per paper.
    sorted_papers = sorted(citation_density, key=citation_density.get, reverse=True)

    for paper_id in range(num_papers):
        for cited_paper_id in sorted_papers[:avg_citations_per_paper]:
            if cited_paper_id != paper_id and not G.has_edge(paper_id, cited_paper_id):
                G.add_edge(paper_id, cited_paper_id)
                break

    return G, role_id

def create_citation_graph_based_on_network_structure(num_papers, avg_citations_per_paper, num_classes, mean, std):
    """Random citation graph with a (currently inert) structure-based second pass.

    Each paper cites ``Poisson(avg_citations_per_paper)`` random other papers
    (capped at ``num_papers - 1``).

    NOTE(review): the second pass picks a random paper the current one
    already cites and adds that edge only if it is missing. That is never the
    case, so no edges are added and the pass only consumes random draws.
    Confirm the intended rule (e.g. citing a neighbour's neighbour).

    Returns:
        (G, role_id): the ``nx.DiGraph`` and the per-paper labels.
    """
    G = nx.DiGraph()
    role_id = []

    for pid in range(num_papers):
        feats = np.random.normal(mean,std,size=(5,))
        lbl = np.random.randint(0, num_classes)
        role_id.append(lbl)
        G.add_node(pid, features=feats, label=lbl)

    for pid in range(num_papers):
        how_many = min(np.random.poisson(avg_citations_per_paper), num_papers - 1)
        others = [q for q in range(num_papers) if q != pid]
        for cited in np.random.choice(others, size=how_many, replace=False):
            G.add_edge(pid, cited)

    for pid in range(num_papers):
        successors = list(G.neighbors(pid))
        if not successors:
            continue
        pick = np.random.choice(successors)
        if not G.has_edge(pid, pick):
            G.add_edge(pid, pick)

    return G, role_id

def generate_author_list(num_authors=3, num_fields=5):
    """Return ``num_authors`` random ``(author_id, fields)`` pairs.

    ``author_id`` is drawn from ``[1, 1000)``. ``fields`` is an array of 1-3
    distinct values sampled without replacement from ``range(num_fields)``.
    """
    result = []
    for _ in range(num_authors):
        author_id = np.random.randint(1, 1000)
        field_count = np.random.randint(1, 4)
        fields = np.random.choice(range(num_fields), size=field_count, replace=False)
        result.append((author_id, fields))
    return result
def get_paper_field(authors):
    """Return the union of all authors' fields as a set.

    Each author is indexable, with its collection of fields at position 1
    (as produced by ``generate_author_list``).
    """
    return {field for author in authors for field in author[1]}
def create_citation_graph_based_on_author_field(num_papers, avg_citations_per_paper, num_classes, mean, std):
    """Random citation graph plus one citation to a paper in the same field set.

    Each paper gets ``np.random.normal(mean, std, size=(5,))`` features and a
    random label. It cites ``Poisson(avg_citations_per_paper)`` random other
    papers (capped at ``num_papers - 1``). Each paper is then given one fixed
    random author list (``generate_author_list``); its field set is the union
    of its authors' fields. Finally, each paper cites the lowest-id other
    paper with exactly the same field set.

    Returns:
        (G, role_id): the ``nx.DiGraph`` and the per-paper labels.
    """
    G = nx.DiGraph()
    role_id = []

    for paper_id in range(num_papers):
        features = np.random.normal(mean,std,size=(5,))
        label = np.random.randint(0, num_classes)
        role_id.append(label)
        G.add_node(paper_id, features=features, label=label)

    for paper_id in range(num_papers):
        # Cap the draw: choice(replace=False) raises ValueError when asked for
        # more than the num_papers - 1 available candidates.
        num_citations = min(np.random.poisson(avg_citations_per_paper), num_papers - 1)
        candidates = [p_id for p_id in range(num_papers) if p_id != paper_id]
        for cited_paper_id in np.random.choice(candidates, size=num_citations, replace=False):
            G.add_edge(paper_id, cited_paper_id)

    # Draw each paper's authors once so a paper has a stable field set,
    # rather than re-sampling authors every time it is compared.
    paper_fields = [get_paper_field(generate_author_list()) for _ in range(num_papers)]

    for paper_id in range(num_papers):
        for cited_paper_id in range(num_papers):
            if cited_paper_id != paper_id and paper_fields[paper_id] == paper_fields[cited_paper_id]:
                G.add_edge(paper_id, cited_paper_id)
                break

    return G, role_id


def create_citation_graph_based_on_centrality(num_papers, avg_citations_per_paper, num_classes, centrality_measure='degree'):
    """Random citation graph plus one citation to the most central other paper.

    Each paper gets N(0, 1) features of length 5 and a random label. It cites
    ``min(Poisson(avg_citations_per_paper), num_papers - 1)`` random other
    papers. Node centrality (``'degree'``, ``'betweenness'`` or
    ``'closeness'``) is computed on that graph. Each paper then adds an edge
    to the highest-centrality paper other than itself.

    Raises:
        ValueError: for an unsupported ``centrality_measure``, or for
            ``'betweenness'`` when the graph is not weakly connected.

    Returns:
        (G, role_id): the ``nx.DiGraph`` and the per-paper labels.
    """
    G = nx.DiGraph()
    role_id = []

    for paper_id in range(num_papers):
        features = np.random.normal(0, 1, size=(5,))
        label = np.random.randint(0, num_classes)
        role_id.append(label)
        G.add_node(paper_id, features=features, label=label)

    for paper_id in range(num_papers):
        num_citations = min(np.random.poisson(avg_citations_per_paper), num_papers - 1)
        candidates = [p_id for p_id in range(num_papers) if p_id != paper_id]
        for cited_paper_id in np.random.choice(candidates, size=num_citations, replace=False):
            G.add_edge(paper_id, cited_paper_id)

    if centrality_measure == 'degree':
        centrality = nx.degree_centrality(G)
    elif centrality_measure == 'betweenness':
        # nx.is_connected is not implemented for directed graphs (it raises
        # NetworkXNotImplemented), so test weak connectivity instead.
        if not nx.is_weakly_connected(G):
            raise ValueError("Graph must be connected for betweenness centrality!")
        centrality = nx.betweenness_centrality(G)
    elif centrality_measure == 'closeness':
        centrality = nx.closeness_centrality(G)
    else:
        raise ValueError("Unsupported centrality measure!")

    for paper_id in range(num_papers):
        candidates = list(set(range(num_papers)) - {paper_id})
        candidate_centralities = {candidate: centrality.get(candidate, 0) for candidate in candidates}
        if not candidate_centralities:
            # Only one paper exists: this random target is necessarily a self-loop.
            G.add_edge(paper_id, np.random.randint(num_papers))
        else:
            max_centrality_paper = max(candidate_centralities, key=candidate_centralities.get)
            G.add_edge(paper_id, max_centrality_paper)

    return G, role_id


def create_citation_graph_based_on_geographical_location(num_papers, avg_citations_per_paper, num_classes):
    """Build a random directed citation graph plus one same-country citation per paper.

    Each paper is assigned a random author location from a fixed list of ten
    countries, a 5-dim standard-normal ``features`` vector and a random
    ``label``. Random citations are added first (Poisson count capped at
    ``num_papers - 1``, no self-citations). Then each paper cites the first
    other paper (lowest id) that shares its location, if one exists.

    Returns:
        ``(G, role_id)``: an ``nx.DiGraph`` and the list of per-paper labels.
    """
    G = nx.DiGraph()
    role_id = []

    locations = ["USA", "UK", "Canada", "Germany", "France", "China", "Japan", "Australia", "India", "Brazil"]
    author_locations = [random.choice(locations) for _ in range(num_papers)]

    for paper_id in range(num_papers):
        features = np.random.normal(0, 1, size=(5,))
        label = np.random.randint(0, num_classes)
        role_id.append(label)
        G.add_node(paper_id, features=features, label=label)

    for paper_id in range(num_papers):
        # Cap the draw: sampling without replacement raises ValueError when asked
        # for more papers than the num_papers - 1 candidates.
        num_citations = min(np.random.poisson(avg_citations_per_paper), num_papers - 1)
        others = [p_id for p_id in range(num_papers) if p_id != paper_id]
        for cited_paper_id in np.random.choice(others, size=num_citations, replace=False):
            G.add_edge(paper_id, cited_paper_id)

    for paper_id, author_location in enumerate(author_locations):
        for other_paper_id, location in enumerate(author_locations):
            if other_paper_id != paper_id and location == author_location:
                G.add_edge(paper_id, other_paper_id)
                break

    return G, role_id

def create_team_size_citation_graph(num_papers, avg_citations_per_paper, num_classes,
                                    team_size_range=(50, 100), max_team_size_gap=5):
    """Build a random directed citation graph plus citations between similarly sized teams.

    Each paper gets a team size drawn uniformly from ``team_size_range``
    (both ends inclusive), a 5-dim standard-normal ``features`` vector and a
    random ``label``. Random citations are added first (Poisson count capped at
    ``num_papers - 1``, no self-citations). Then every paper cites every other
    paper whose team size differs from its own by at most ``max_team_size_gap``.

    Args:
        num_papers: number of paper nodes.
        avg_citations_per_paper: Poisson mean for the random citations.
        num_classes: number of distinct labels.
        team_size_range: inclusive ``(low, high)`` team-size bounds.
        max_team_size_gap: largest size difference that still yields a citation.

    Returns:
        ``(G, role_id)``: an ``nx.DiGraph`` and the list of per-paper labels.
    """
    G = nx.DiGraph()
    role_id = []

    low, high = team_size_range
    team_sizes = np.random.randint(low, high + 1, size=num_papers)

    for paper_id in range(num_papers):
        features = np.random.normal(0, 1, size=(5,))
        label = np.random.randint(0, num_classes)
        role_id.append(label)
        G.add_node(paper_id, features=features, label=label)

    for paper_id in range(num_papers):
        # Sampling without replacement cannot return more papers than exist.
        num_citations = min(np.random.poisson(avg_citations_per_paper), num_papers - 1)
        others = [p_id for p_id in range(num_papers) if p_id != paper_id]
        for cited_paper_id in np.random.choice(others, size=num_citations, replace=False):
            G.add_edge(paper_id, cited_paper_id)

    # Pairwise |size_i - size_j| <= gap. The diagonal is cleared so that no
    # paper cites itself. np.nonzero yields pairs in row-major order.
    close = np.abs(team_sizes[:, None] - team_sizes[None, :]) <= max_team_size_gap
    np.fill_diagonal(close, False)
    for src, dst in zip(*np.nonzero(close)):
        G.add_edge(int(src), int(dst))
    return G, role_id

def create_citation_graph_based_on_credibility(num_papers, avg_citations_per_paper, num_classes):
    """Build a random directed citation graph plus one citation per paper to a credible paper.

    Papers get a 5-dim standard-normal ``features`` vector and a random
    ``label``. Random citations are added first (Poisson count capped at
    ``num_papers - 1``, no self-citations). Each paper then receives a uniform
    ``[0, 1)`` credibility score. Papers scoring at or above the 75th
    percentile count as credible, and every paper additionally cites one
    random credible paper other than itself.

    Returns:
        ``(G, role_id)``: an ``nx.DiGraph`` and the list of per-paper labels.
    """
    G = nx.DiGraph()
    role_id = []

    for paper_id in range(num_papers):
        features = np.random.normal(0, 1, size=(5,))
        label = np.random.randint(0, num_classes)
        role_id.append(label)
        G.add_node(paper_id, features=features, label=label)

    for paper_id in range(num_papers):
        # Sampling without replacement cannot return more papers than exist.
        num_citations = min(np.random.poisson(avg_citations_per_paper), num_papers - 1)
        others = [p_id for p_id in range(num_papers) if p_id != paper_id]
        for cited_paper_id in np.random.choice(others, size=num_citations, replace=False):
            G.add_edge(paper_id, cited_paper_id)

    credibility_scores = np.random.rand(num_papers)
    if num_papers == 0:
        return G, role_id

    # The threshold depends only on the scores, so compute it once.
    credibility_threshold = np.percentile(credibility_scores, 75)
    credible_papers = [other_id for other_id, score in enumerate(credibility_scores)
                       if score >= credibility_threshold]

    for paper_id in range(num_papers):
        # A distinct name (other_id) is required here: reusing ``paper_id`` inside
        # the comprehension would shadow the loop variable and filter out every paper.
        candidates = [other_id for other_id in credible_papers if other_id != paper_id]
        if candidates:
            G.add_edge(paper_id, int(np.random.choice(candidates)))

    return G, role_id



def generate_citation_graph(num_papers, num_authors_per_paper, num_authors):
    """Build a directed author graph from randomly assigned paper authorships.

    Each paper gets ``num_authors_per_paper`` distinct author ids drawn without
    replacement from ``range(num_authors)``, so ``num_authors_per_paper`` must
    not exceed ``num_authors``. Every author of a paper is then linked to every
    different author of every other paper.

    Args:
        num_papers: number of papers to simulate.
        num_authors_per_paper: authors sampled per paper.
        num_authors: size of the author id pool.

    Returns:
        ``(G, role_id, citations)``:
        - ``G``: ``nx.DiGraph`` whose nodes are author ids.
        - ``role_id``: array of ``num_papers`` random labels in {0, 1}.
        - ``citations``: dict mapping a paper id to one successor author id
          (see the note on the citations loop).
    """

    G = nx.DiGraph()
    # NOTE(review): author_list is filled below but never read.
    author_list = []
    paper_author_relations = {}


    # Assign authors to each paper.
    for paper_id in range(num_papers):
        authors = np.random.choice(num_authors, size=num_authors_per_paper, replace=False)
        paper_author_relations[paper_id] = authors
        for author in authors:
            author_list.append(author)


    # Link each author to every author of every other paper (self-links skipped).
    for paper_id, authors in paper_author_relations.items():
        for author in authors:
            G.add_node(author)
            for cited_paper_id, cited_authors in paper_author_relations.items():
                if cited_paper_id != paper_id:
                    for cited_author in cited_authors:
                        if author != cited_author and not G.has_edge(author, cited_author):
                            G.add_edge(author, cited_author)
    

    # For each paper, record the first successor (in adjacency order) of its
    # authors whose id is also a key of paper_author_relations.
    # NOTE(review): this tests an *author* id against *paper* ids, mixing the two
    # id spaces. Confirm this is intended.
    citations = {}
    for paper_id, authors in paper_author_relations.items():
        for author in authors:
            for successor_author in G.successors(author):
                if successor_author in paper_author_relations and paper_id not in citations:
                    citations[paper_id] = successor_author


    # NOTE(review): one label per *paper*, but the graph nodes are *authors*, so
    # len(role_id) may differ from G.number_of_nodes(). Verify against consumers.
    role_id = np.random.randint(0, 2, size=num_papers)

    # NOTE(review): returns three values, whereas the sibling generators return
    # (G, role_id). Callers that unpack two values will fail.
    return G, role_id, citations

def convert_to_networkx_graph(citation_network):
    """Convert a square 0/1 adjacency matrix into a directed networkx graph.

    Every row index becomes a node, including papers that neither cite nor are
    cited. This keeps ``G.number_of_nodes()`` equal to the matrix size and to
    the length of the per-paper ``role_id`` arrays built alongside it. An edge
    ``i -> j`` is added wherever ``citation_network[i][j] == 1``.

    Args:
        citation_network: square array-like adjacency matrix (row cites column).

    Returns:
        ``nx.DiGraph`` with nodes ``0 .. n-1``.
    """
    citation_network = np.asarray(citation_network)
    G = nx.DiGraph()
    G.add_nodes_from(range(citation_network.shape[0]))
    # np.nonzero returns indices in row-major order, i.e. the same order a
    # nested i/j loop would visit them.
    sources, targets = np.nonzero(citation_network == 1)
    G.add_edges_from(zip(sources.tolist(), targets.tolist()))
    return G
def generate_triangle_citation_network(num_papers, avg_citations_per_paper, num_classes):
    """Build a random citation network with roughly ``avg_citations_per_paper`` citations each.

    Each paper draws a citation count uniformly from
    ``[avg - 2, avg + 2]`` (a negative count means none). It then cites that
    many random other papers; repeated picks collapse, so the real out-degree
    can be lower. Labels are uniform in ``[0, num_classes - 1]``.

    NOTE(review): despite the name, no triangle structure is enforced here.

    Returns:
        ``(G, role_id)``: the graph from ``convert_to_networkx_graph`` and the
        per-paper label array.
    """
    citation_network = np.zeros((num_papers, num_papers), dtype=int)
    role_id = np.zeros(num_papers, dtype=int)

    num_citations = [random.randint(avg_citations_per_paper - 2, avg_citations_per_paper + 2) for _ in range(num_papers)]

    for i in range(num_papers):
        role_id[i] = random.randint(0, num_classes - 1)
        # With a single paper there is nobody else to cite. The rejection loop
        # below would spin forever because randint(0, 0) always returns i.
        if num_papers < 2:
            continue
        for _ in range(num_citations[i]):
            citation_paper = i
            while citation_paper == i:
                citation_paper = random.randint(0, num_papers - 1)
            citation_network[i][citation_paper] = 1

    G = convert_to_networkx_graph(citation_network)
    return G, role_id



def generate_citation_network_with_distance(num_papers, max_distance, self_citation_prob):
    """Build a random citation network in which nearer papers are more likely to be cited.

    A random integer distance in ``[1, max_distance]`` is drawn for every
    ordered pair of distinct papers before any citation is decided. Each pair
    then passes a gate with probability ``self_citation_prob``. A pair that
    passes becomes a citation with probability ``1 / distance``. Labels are
    random in {0, 1}.

    Returns:
        ``(G, role_id)``: the graph from ``convert_to_networkx_graph`` and the
        per-paper label array.
    """
    adjacency = np.zeros((num_papers, num_papers), dtype=int)
    labels = np.zeros(num_papers, dtype=int)
    distances = np.zeros((num_papers, num_papers), dtype=int)

    # All distances are drawn up front, before any label or citation draw.
    for src in range(num_papers):
        for dst in range(num_papers):
            if src == dst:
                continue
            distances[src][dst] = random.randint(1, max_distance)

    for src in range(num_papers):
        labels[src] = random.randint(0, 1)
        for dst in range(num_papers):
            if src == dst:
                continue
            d = distances[src][dst]
            # Second draw happens only when the gate passes (short-circuit).
            if d <= max_distance and random.random() < self_citation_prob and random.random() < 1 / d:
                adjacency[src][dst] = 1

    return convert_to_networkx_graph(adjacency), labels



def generate_flow_direction_prob(num_domains):
    """Return a random domain-to-domain knowledge-flow probability matrix.

    Entry ``[i][j]`` is the relative probability that domain ``i`` cites domain
    ``j``. The diagonal is zero and each row is normalized to sum to 1. A row
    with no off-diagonal mass (only possible when ``num_domains == 1``) is left
    as zeros instead of becoming NaN through a 0/0 division.

    Args:
        num_domains: number of domains (matrix side length).

    Returns:
        ``np.ndarray`` of shape ``(num_domains, num_domains)``.
    """
    flow_direction_prob = np.zeros((num_domains, num_domains))
    off_diagonal = ~np.eye(num_domains, dtype=bool)
    # Boolean-mask assignment fills the off-diagonal cells in row-major order.
    flow_direction_prob[off_diagonal] = np.random.rand(num_domains * (num_domains - 1))
    row_sums = flow_direction_prob.sum(axis=1, keepdims=True)
    np.divide(flow_direction_prob, row_sums, out=flow_direction_prob, where=row_sums > 0)
    return flow_direction_prob
def generate_citation_network_with_knowledge_flow(num_papers, num_domains):
    """Build a random citation network driven by a domain-to-domain flow matrix.

    Paper ``k`` belongs to domain ``k % num_domains``. Paper ``i`` cites paper
    ``j`` (``i != j``) with probability ``flow[i_domain][j_domain]``, where the
    matrix comes from ``generate_flow_direction_prob``. Labels are random in
    {0, 1}.

    Returns:
        ``(G, role_id)``: the graph from ``convert_to_networkx_graph`` and the
        per-paper label array.
    """
    flow = generate_flow_direction_prob(num_domains)
    adjacency = np.zeros((num_papers, num_papers), dtype=int)
    labels = np.zeros(num_papers, dtype=int)

    for src in range(num_papers):
        labels[src] = random.randint(0, 1)
        for dst in range(num_papers):
            if src == dst:
                continue
            cite_chance = flow[src % num_domains][dst % num_domains]
            if random.random() < cite_chance:
                adjacency[src][dst] = 1

    return convert_to_networkx_graph(adjacency), labels


def generate_citation_network_with_chain_length(num_papers, max_chain_length, self_citation_prob):
    """Build a random citation network where shorter citation chains are favoured.

    For every ordered pair of distinct papers a chain length in
    ``[1, max_chain_length]`` is drawn. The pair then passes a gate with
    probability ``self_citation_prob``. A pair that passes becomes a citation
    with probability ``1 / chain_length``. Labels are random in {0, 1}.

    Returns:
        ``(G, role_id)``: the graph from ``convert_to_networkx_graph`` and the
        per-paper label array.
    """
    adjacency = np.zeros((num_papers, num_papers), dtype=int)
    labels = np.zeros(num_papers, dtype=int)

    for src in range(num_papers):
        labels[src] = random.randint(0, 1)
        for dst in range(num_papers):
            if src == dst:
                continue
            # A chain length is drawn for every pair, even when the gate fails.
            chain_length = random.randint(1, max_chain_length)
            citation_chance = 1 / chain_length
            if random.random() < self_citation_prob and random.random() < citation_chance:
                adjacency[src][dst] = 1

    return convert_to_networkx_graph(adjacency), labels

def generate_citation_network_with_diversity(num_papers, num_domains, diversity_threshold, self_citation_prob):
    """Build a random citation network biased by each paper's reference diversity.

    Every paper gets ``num_papers`` random reference domains in
    ``[0, num_domains)``. Its diversity score is the number of distinct domains
    among them. For a pair ``(i, j)``, ``i != j``:

    - If paper ``i`` is more diverse than ``diversity_threshold``, the pair
      becomes a citation with probability ``self_citation_prob``.
    - Otherwise it must pass that gate and then a second draw with probability
      ``1 / (1 + threshold - score_i)``.

    Labels are random in {0, 1}.

    Returns:
        ``(G, role_id)``: the graph from ``convert_to_networkx_graph`` and the
        per-paper label array.
    """
    adjacency = np.zeros((num_papers, num_papers), dtype=int)
    labels = np.zeros(num_papers, dtype=int)

    domain_refs = np.random.randint(0, num_domains, size=(num_papers, num_papers))
    diversity = np.array([len(set(row)) for row in domain_refs], dtype=float)

    for src in range(num_papers):
        labels[src] = random.randint(0, 1)
        score = diversity[src]
        for dst in range(num_papers):
            if src == dst:
                continue
            if score > diversity_threshold:
                if random.random() < self_citation_prob:
                    adjacency[src][dst] = 1
            elif random.random() < self_citation_prob and random.random() < 1 / (1 + diversity_threshold - score):
                adjacency[src][dst] = 1

    return convert_to_networkx_graph(adjacency), labels


def generate_citation_network_with_reference_count(num_papers, max_reference_count, self_citation_prob):
    """Build a random citation network biased by per-paper reference counts.

    Each paper draws a reference count in ``[1, max_reference_count]``. For an
    ordered pair ``(i, j)``, ``i != j``:

    - If ``i`` has the larger count, the pair becomes a citation with
      probability ``self_citation_prob``.
    - Otherwise it must pass that gate and then a second draw with probability
      ``count_i / count_j``.

    Labels are random in {0, 1}.

    Returns:
        ``(G, role_id)``: the graph from ``convert_to_networkx_graph`` and the
        per-paper label array.
    """
    adjacency = np.zeros((num_papers, num_papers), dtype=int)
    labels = np.zeros(num_papers, dtype=int)
    ref_counts = np.random.randint(1, max_reference_count + 1, size=num_papers)

    for src in range(num_papers):
        labels[src] = random.randint(0, 1)
        for dst in range(num_papers):
            if src == dst:
                continue
            if ref_counts[src] > ref_counts[dst]:
                cite = random.random() < self_citation_prob
            else:
                ratio = ref_counts[src] / ref_counts[dst]
                cite = random.random() < self_citation_prob and random.random() < ratio
            if cite:
                adjacency[src][dst] = 1

    return convert_to_networkx_graph(adjacency), labels


def generate_citation_network_with_research_object(num_papers, num_objects, self_citation_prob):
    """Build a random citation network biased by the popularity of each paper's research object.

    Each of ``num_objects`` objects gets a uniform ``[0, 1)`` popularity, and
    each paper is assigned a random object. For an ordered pair ``(i, j)``,
    ``i != j``:

    - If ``i``'s object is more popular, the pair becomes a citation with
      probability ``self_citation_prob``.
    - Otherwise it must pass that gate and then a second draw with probability
      ``pop_i / pop_j``.

    Labels are random in {0, 1}.

    Returns:
        ``(G, role_id)``: the graph from ``convert_to_networkx_graph`` and the
        per-paper label array.
    """
    adjacency = np.zeros((num_papers, num_papers), dtype=int)
    labels = np.zeros(num_papers, dtype=int)
    popularity = np.random.rand(num_objects)
    paper_object = np.random.randint(0, num_objects, size=num_papers)

    for src in range(num_papers):
        labels[src] = random.randint(0, 1)
        src_pop = popularity[paper_object[src]]
        for dst in range(num_papers):
            if src == dst:
                continue
            dst_pop = popularity[paper_object[dst]]
            if src_pop > dst_pop:
                cite = random.random() < self_citation_prob
            else:
                ratio = src_pop / dst_pop
                cite = random.random() < self_citation_prob and random.random() < ratio
            if cite:
                adjacency[src][dst] = 1

    return convert_to_networkx_graph(adjacency), labels


def generate_citation_network_with_journal_reputation(num_papers, num_journals, self_citation_prob):
    """Build a random citation network biased by the reputation of each paper's venue.

    Each of ``num_journals`` venues gets a uniform ``[0, 1)`` reputation, and
    each paper is assigned a random venue. For an ordered pair ``(i, j)``,
    ``i != j``:

    - If ``i``'s venue has the higher reputation, the pair becomes a citation
      with probability ``self_citation_prob``.
    - Otherwise it must pass that gate and then a second draw with probability
      ``rep_i / rep_j``.

    Labels are random in {0, 1}.

    Returns:
        ``(G, role_id)``: the graph from ``convert_to_networkx_graph`` and the
        per-paper label array.
    """
    adjacency = np.zeros((num_papers, num_papers), dtype=int)
    labels = np.zeros(num_papers, dtype=int)
    reputation = np.random.rand(num_journals)
    venue_of = np.random.randint(0, num_journals, size=num_papers)

    for src in range(num_papers):
        labels[src] = random.randint(0, 1)
        src_rep = reputation[venue_of[src]]
        for dst in range(num_papers):
            if src == dst:
                continue
            dst_rep = reputation[venue_of[dst]]
            if src_rep > dst_rep:
                cite = random.random() < self_citation_prob
            else:
                ratio = src_rep / dst_rep
                cite = random.random() < self_citation_prob and random.random() < ratio
            if cite:
                adjacency[src][dst] = 1

    return convert_to_networkx_graph(adjacency), labels


def generate_citation_network_with_open_access(num_papers, open_access_prob, self_citation_prob):
    """Build a random citation network where open-access papers cite more readily.

    Each paper is open access with probability ``open_access_prob``. For an
    ordered pair ``(i, j)``, ``i != j``:

    - An open-access ``i`` cites ``j`` with probability ``self_citation_prob``.
    - A closed ``i`` must pass that same gate twice.

    Labels are random in {0, 1}.

    NOTE(review): the second draw for closed papers reuses
    ``self_citation_prob``. A different probability may have been intended;
    confirm.

    Returns:
        ``(G, role_id)``: the graph from ``convert_to_networkx_graph`` and the
        per-paper label array.
    """
    adjacency = np.zeros((num_papers, num_papers), dtype=int)
    labels = np.zeros(num_papers, dtype=int)
    is_open = np.random.rand(num_papers) < open_access_prob

    for src in range(num_papers):
        labels[src] = random.randint(0, 1)
        for dst in range(num_papers):
            if src == dst:
                continue
            passed = random.random() < self_citation_prob
            if passed and not is_open[src]:
                passed = random.random() < self_citation_prob
            if passed:
                adjacency[src][dst] = 1

    return convert_to_networkx_graph(adjacency), labels



# Registry of citation-rule graph generators keyed 1..25 (in list order).
# Each value is a ``(generator_function, description)`` tuple.
paper_generators = dict(enumerate([
    (create_paper_citation_graph, "Random citation relationship generation"),
    (create_paper_citation_graph2, "Citation based on paper citation count"),
    (create_paper_citation_graph3, "Citation based on author co-citation"),
    (create_paper_citation_graph4, "Citation based on citation relationship propagation"),
    (create_paper_citation_graph5, "Citation based on topic similarity"),
    (create_paper_citation_graph6, "Citation based on citation time"),
    (create_author_influence_citation_graph, "Citation based on author influence"),
    (create_common_citation_count_citation_graph, "Citation based on common citation count"),
    (create_citation_graph_based_on_citation_density, "Citation based on citation density"),
    (create_citation_graph_based_on_network_structure, "Citation based on network structure"),
    (create_citation_graph_based_on_author_field, "Citation based on author field"),
    (create_citation_graph_based_on_centrality, "Citation based on citation network centrality"),
    (create_citation_graph_based_on_geographical_location, "Citation based on author geographical location"),
    (create_team_size_citation_graph, "Citation based on research team size"),
    (create_citation_graph_based_on_credibility, "Citation based on citation credibility"),
    (generate_citation_graph, "Citation based on academic lineage relationship"),
    (generate_triangle_citation_network, "Citation based on citation structure"),
    (generate_citation_network_with_distance, "Citation based on citation distance"),
    (generate_citation_network_with_knowledge_flow, "Rules based on knowledge flow"),
    (generate_citation_network_with_chain_length, "Rules based on citation chain"),
    (generate_citation_network_with_diversity, "Citation based on diversity"),
    (generate_citation_network_with_reference_count, "Citation based on reference count"),
    (generate_citation_network_with_research_object, "Citation based on research object"),
    (generate_citation_network_with_journal_reputation, "Citation based on journal/conference reputation"),
    (generate_citation_network_with_open_access, "Rules based on open access"),
], start=1))

def generate_graph1(type1, type2, num_nodes, mean, std, sim_threshold):
    """Build an undirected graph linking feature vectors with high cosine similarity.

    ``num_nodes`` feature vectors are drawn from each of the node generators
    ``node_generators[type1]`` and ``node_generators[type2]``, so the graph has
    ``2 * num_nodes`` nodes. Every pair ``i < j`` whose cosine similarity
    exceeds ``sim_threshold`` is joined by an edge. When ``sim_threshold`` is
    ``None``, no edges are added. An edgeless result is valid and is returned
    as is.

    Returns:
        ``(G, role_id)``: an ``nx.Graph`` over nodes ``0 .. 2*num_nodes-1`` and a
        list of random labels in {0, 1, 2}, one per node.
    """
    nodes1, _ = node_generators[type1][0](num_nodes, mean, std)
    nodes2, _ = node_generators[type2][0](num_nodes, mean, std)
    nodes = np.concatenate((nodes1, nodes2), axis=0)
    num_nodes = len(nodes)

    edges = []
    if sim_threshold is not None and num_nodes > 1:
        # One similarity matrix instead of one sklearn call per pair. The strict
        # upper triangle gives pairs i < j in row-major order.
        similarity = cosine_similarity(nodes)
        rows, cols = np.nonzero(np.triu(similarity > sim_threshold, k=1))
        edges = list(zip(rows.tolist(), cols.tolist()))

    G = nx.Graph()
    G.add_nodes_from(range(num_nodes))
    G.add_edges_from(edges)
    role_id = [random.randint(0, 2) for _ in range(num_nodes)]
    return G, role_id

def generate_graph2(type1, type2, num_nodes, mean, std, partial_sim_threshold, dims):
    """Build an undirected graph linking nodes that are similar on selected feature dimensions.

    ``num_nodes`` feature vectors are drawn from each of
    ``node_generators[type1]`` and ``node_generators[type2]``. Every pair
    ``i < j`` whose cosine similarity, computed only over the columns selected
    by ``dims``, exceeds ``partial_sim_threshold`` is joined by an edge. No
    edges are added when either the threshold or ``dims`` is ``None``.

    Args:
        dims: column selector for the feature matrix (a list of indices, a slice,
            or a single index).

    Returns:
        ``(G, role_id)``: an ``nx.Graph`` over nodes ``0 .. 2*num_nodes-1`` and a
        list of random labels in {0, 1, 2}, one per node.
    """
    nodes1, _ = node_generators[type1][0](num_nodes, mean, std)
    nodes2, _ = node_generators[type2][0](num_nodes, mean, std)
    nodes = np.concatenate((nodes1, nodes2), axis=0)
    num_nodes = len(nodes)

    edges = []
    if partial_sim_threshold is not None and dims is not None and num_nodes > 1:
        # The reshape keeps a 2-D (nodes, features) layout even when ``dims`` is
        # a single integer index.
        sub_features = nodes[:, dims].reshape(num_nodes, -1)
        similarity = cosine_similarity(sub_features)
        rows, cols = np.nonzero(np.triu(similarity > partial_sim_threshold, k=1))
        edges = list(zip(rows.tolist(), cols.tolist()))

    G = nx.Graph()
    G.add_nodes_from(range(num_nodes))
    G.add_edges_from(edges)
    role_id = [random.randint(0, 2) for _ in range(num_nodes)]
    return G, role_id

# Edge-construction strategies keyed 1..2: (builder function, description).
node_connectors = dict(enumerate([
    (generate_graph1, "Similar edges with autonomous node similarity judgment"),
    (generate_graph2, "Partial similar edges with autonomous multi-dimensional node similarity judgment"),
], start=1))

# Module-level defaults read by the dataset helpers in this module.
num_papers = 10
avg_citations_per_paper = 3
num_classes = 3
# mean / std vectors passed to the paper-graph generators.
mean = [1.5, 2.0, 1.2, 1.3, 1.8]
std = [1.5, 2.0, 1.2, 1.3, 1.8]
def generate_Y0():
    """Sample a class-0 graph via ``create_paper_citation_graph2``; return ``(G, role_id, 0)``."""
    graph, roles = create_paper_citation_graph2(num_papers, avg_citations_per_paper, num_classes, mean, std)
    return graph, roles, 0
def generate_Y1():
    """Sample a class-1 graph via ``create_paper_citation_graph3``; return ``(G, role_id, 1)``."""
    graph, roles = create_paper_citation_graph3(num_papers, avg_citations_per_paper, num_classes, mean, std)
    return graph, roles, 1
def generate_Y2():
    """Sample a class-2 graph via ``create_paper_citation_graph4``; return ``(G, role_id, 2)``."""
    graph, roles = create_paper_citation_graph4(num_papers, avg_citations_per_paper, num_classes, mean, std)
    return graph, roles, 2
def generate_Y3():
    """Sample a class-3 graph via ``create_paper_citation_graph5``; return ``(G, role_id, 3)``."""
    graph, roles = create_paper_citation_graph5(num_papers, avg_citations_per_paper, num_classes, mean, std)
    return graph, roles, 3
def generate_Y4():
    """Sample a class-4 graph via ``create_paper_citation_graph6``; return ``(G, role_id, 4)``."""
    graph, roles = create_paper_citation_graph6(
        num_papers,
        avg_citations_per_paper,
        num_classes,
        mean=[1.5, 2.0, 1.2, 1.3, 1.8],
        std=[1.5, 2.0, 1.2, 1.3, 1.8],
        publication_years=None,
    )
    return graph, roles, 4
def generate_real_dataset():
    """Sample one of the five class generators uniformly at random.

    Returns:
        ``(G, role_id, label, motif1_present, ..., motif5_present)``. The five
        booleans are the hard-coded motif flags associated with the chosen class.
    """
    # class id -> (generator, (motif1, motif2, motif3, motif4, motif5))
    class_table = {
        0: (generate_Y0, (True, True, False, False, False)),
        1: (generate_Y1, (True, False, True, False, False)),
        2: (generate_Y2, (False, True, False, False, True)),
        3: (generate_Y3, (False, False, False, True, True)),
        4: (generate_Y4, (False, False, True, True, False)),
    }
    chosen = random.choice([0, 1, 2, 3, 4])
    generator, motif_flags = class_table[chosen]
    G, role_id, label = generator()
    return (G, role_id, label, *motif_flags)

def adjacent_connection(G1, G2):
    """Splice two graphs together through a pair of new bridge nodes.

    Steps:

    1. One random edge ``(u, v)`` of each graph is subdivided by a fresh node:
       ``max_id + 1`` in ``G1`` and ``max_id + 2`` in ``G2``.
    2. The graphs are composed and the two fresh nodes are linked.
    3. Any node outside the largest (weakly) connected component is attached to
       a random node of that component.
    4. Nodes present in both inputs get a random ``role_id`` node attribute.

    The inputs are copied first, so the caller's graphs are not modified.

    Args:
        G1, G2: networkx graphs with integer node ids. The result has ``G1``'s
            graph class (as produced by ``nx.compose``).

    Returns:
        ``(G, role_id)``. ``role_id`` has one entry per node of ``G``: all zeros
        except the last (up to) six entries, which are random ints in ``[1, 3]``.
        If either input has no nodes or no edges, ``(nx.Graph(), [])`` is returned.
    """
    nodes1 = set(G1.nodes())
    nodes2 = set(G2.nodes())
    if not nodes1 or not nodes2 or not G1.edges() or not G2.edges():
        # Two values, matching the normal return, so ``graph, role_id = ...`` unpacks.
        return nx.Graph(), []
    common_nodes = nodes1.intersection(nodes2)

    # Subdividing edges below must not mutate the caller's graphs.
    G1 = G1.copy()
    G2 = G2.copy()

    edge1 = random.choice(list(G1.edges()))
    edge2 = random.choice(list(G2.edges()))

    largest_id = max(nodes1.union(nodes2))
    new_node1 = largest_id + 1
    new_node2 = largest_id + 2

    G1.remove_edge(*edge1)
    G1.add_edge(edge1[0], new_node1)
    G1.add_edge(new_node1, edge1[1])
    G2.remove_edge(*edge2)
    G2.add_edge(edge2[0], new_node2)
    G2.add_edge(new_node2, edge2[1])

    G1.add_node(new_node2)
    G2.add_node(new_node1)

    G = nx.compose(G1, G2)
    G.add_edge(new_node1, new_node2)

    for node in common_nodes:
        G.add_node(node, role_id=np.random.randint(low=1, high=len(common_nodes) + 3))

    # connected_components is undirected-only and weakly_connected_components is
    # directed-only. Pick whichever matches G so neither raises NetworkXNotImplemented.
    if G.is_directed():
        components = list(nx.weakly_connected_components(G))
    else:
        components = list(nx.connected_components(G))
    if len(components) > 1:
        largest_component = max(components, key=len)
        anchors = list(largest_component)
        isolated_nodes = [n for n in G.nodes() if n not in largest_component]
        for u in isolated_nodes:
            G.add_edge(u, random.choice(anchors))

    role_id = [0] * G.number_of_nodes()
    # Cap at the node count: assigning six values into a shorter tail slice
    # would grow the list past the number of nodes.
    num_marked = min(6, len(role_id))
    role_id[len(role_id) - num_marked:] = torch.randint(1, 4, (num_marked,)).tolist()
    return G, role_id

def generate_false_cause_dataset1():
    """Attach a citation-rule subgraph for the first present motif to a real-class graph.

    A graph is sampled with ``generate_real_dataset``. For the first motif flag
    that is set, a second graph is built with the matching generator, in this
    order:

    1. author influence
    2. common citation count
    3. citation density
    4. network structure
    5. author field

    That graph is spliced on with ``adjacent_connection``.

    Returns:
        ``(graph, role_id, label)``; ``label`` is the real class label.
    """
    mean = [1.0, 2.0, 1.0, 1.5, 3.0]
    std = [1.0, 2.0, 1.0, 1.5, 3.0]
    G, role_id, label, *motifs_present = generate_real_dataset()

    motif_generators = (
        create_author_influence_citation_graph,
        create_common_citation_count_citation_graph,
        create_citation_graph_based_on_citation_density,
        create_citation_graph_based_on_network_structure,
        create_citation_graph_based_on_author_field,
    )
    for present, make_motif_graph in zip(motifs_present, motif_generators):
        if present:
            motif_graph, _ = make_motif_graph(num_papers, avg_citations_per_paper, num_classes, mean, std)
            graph, role_id = adjacent_connection(G, motif_graph)
            return graph, role_id, label

    # No flag set. Every class in generate_real_dataset currently sets at least
    # one, so this is a defensive fallback. generate_false_dataset returns three
    # values; keep the real label.
    graph, role_id, _ = generate_false_dataset()
    return graph, role_id, label

def _generate_false_cause_dataset_with_features(feature_params):
    """Shared body of generate_false_cause_dataset2..6.

    Samples a real-class graph, then handles the first motif flag that is set:

    - It calls ``generate_graph`` with graph type 6..10 (motif 1..5) and a
      random integer in ``[5, 10]`` as its second argument.
    - ``feature_params`` is passed as both the mean and the std vector.
    - The similarity threshold is 0.3 for motifs 1-3 and 0.5 for motifs 4-5.

    The resulting graph is spliced on with ``adjacent_connection``.

    Returns:
        ``(graph, role_id, label)``. If no flag is set, the real graph and its
        ``role_id`` are returned unchanged.
    """
    graph, role_id, label, *motifs_present = generate_real_dataset()
    # (generate_graph type id, similarity threshold) for motifs 1..5.
    motif_specs = ((6, 0.3), (7, 0.3), (8, 0.3), (9, 0.5), (10, 0.5))
    for present, (graph_type, threshold) in zip(motifs_present, motif_specs):
        if present:
            # mean and std are passed as distinct list objects.
            motif_graph, _ = generate_graph(graph_type, random.randint(5, 10),
                                            list(feature_params), list(feature_params), threshold)
            graph, role_id = adjacent_connection(graph, motif_graph)
            break
    return graph, role_id, label


def generate_false_cause_dataset2():
    """False-cause sample whose motif graph uses features [1, 2, 1, 1.5, 3]."""
    return _generate_false_cause_dataset_with_features([1.0, 2.0, 1.0, 1.5, 3.0])


def generate_false_cause_dataset3():
    """False-cause sample whose motif graph uses features [5, 7, 4, 6, 10]."""
    return _generate_false_cause_dataset_with_features([5.0, 7.0, 4.0, 6.0, 10.0])


def generate_false_cause_dataset4():
    """False-cause sample whose motif graph uses features [10, 15, 8, 12, 20]."""
    return _generate_false_cause_dataset_with_features([10.0, 15.0, 8.0, 12.0, 20.0])


def generate_false_cause_dataset5():
    """False-cause sample whose motif graph uses features [8, 10, 7, 9, 10]."""
    return _generate_false_cause_dataset_with_features([8.0, 10.0, 7.0, 9.0, 10.0])


def generate_false_cause_dataset6():
    """False-cause sample whose motif graph uses features [20, 25, 18, 22, 20]."""
    return _generate_false_cause_dataset_with_features([20.0, 25.0, 18.0, 22.0, 20.0])
def generate_false_dataset():
    """Join a paper-citation graph onto a graph from generate_false_cause_dataset1.

    The citation graph comes from ``create_paper_citation_graph`` with feature
    mean and std both ``[1.0, 2.0, 1.0, 1.5, 3.0]``. The two graphs are merged
    with ``adjacent_connection``.

    NOTE(review): ``num_papers`` and ``avg_citations_per_paper`` are not
    defined in this function. They must exist as module-level globals
    elsewhere in the file; confirm. ``num_classes`` is the module-level
    constant.

    Returns:
        (graph, role_id, label): the merged graph and role ids from
        ``adjacent_connection``, and the label from
        ``generate_false_cause_dataset1``.
    """
    G, _, label = generate_false_cause_dataset1()

    mean = [1.0, 2.0, 1.0, 1.5, 3.0]
    std = [1.0, 2.0, 1.0, 1.5, 3.0]
    citation_graph, _ = create_paper_citation_graph(
        num_papers, avg_citations_per_paper, num_classes, mean, std)

    graph, role_id = adjacent_connection(G, citation_graph)
    return graph, role_id, label

def add_noise(G, delete_edge_prob, add_edge_prob, delete_node_prob, add_node_prob,label=None):
    """Return a perturbed deep copy of ``G``; ``G`` itself is not modified.

    Steps, applied in order to the copy:
      1. Remove ``int(delete_edge_prob * |E|)`` randomly chosen edges.
      2. Make ``int(add_edge_prob * |V| * (|V|-1) / 2)`` attempts to add an
         edge between two random distinct nodes. Already-present edges are
         skipped, so fewer edges may actually be added.
      3. Remove ``int(delete_node_prob * |V|)`` random nodes.
      4. Add ``int(add_node_prob * |V|)`` fresh nodes. Each is linked to a
         random non-empty subset of the existing nodes, resampling until the
         graph is connected (weakly connected for directed graphs).

    The ``*_prob`` arguments are fractions of the current size, not
    per-element probabilities.

    Args:
        G: networkx graph to perturb.
        label: passed through unchanged.

    Returns:
        (G_noisy, role_id_noisy, label_noisy). ``role_id_noisy`` holds one
        random role in {0, 1, 2} per node of ``G_noisy``.
    """
    G_noisy = copy.deepcopy(G)

    # random.sample() requires a sequence; Python 3.11+ rejects networkx's
    # set-like views, so materialize them as lists.
    edges = list(G_noisy.edges())
    num_edges_to_delete = int(delete_edge_prob * len(edges))
    G_noisy.remove_edges_from(random.sample(edges, num_edges_to_delete))

    # The node set is fixed during edge insertion, so build the list once.
    nodes = list(G_noisy.nodes())
    n = len(nodes)
    num_edges_to_add = int(add_edge_prob * n * (n - 1) / 2)  # 0 when n < 2
    for _ in range(num_edges_to_add):
        node1, node2 = random.sample(nodes, 2)
        if not G_noisy.has_edge(node1, node2):
            G_noisy.add_edge(node1, node2)

    nodes = list(G_noisy.nodes())
    num_nodes_to_delete = int(delete_node_prob * len(nodes))
    G_noisy.remove_nodes_from(random.sample(nodes, num_nodes_to_delete))

    # is_weakly_connected only accepts directed graphs.
    is_connected = (nx.is_weakly_connected if G_noisy.is_directed()
                    else nx.is_connected)
    num_nodes_to_add = int(add_node_prob * G_noisy.number_of_nodes())
    for _ in range(num_nodes_to_add):
        # number_of_nodes()+1 may already be taken after deletions. Bump it
        # until unused so a genuinely new node is created.
        node_id = G_noisy.number_of_nodes() + 1
        while node_id in G_noisy:
            node_id += 1
        # Neighbours are drawn from the pre-existing nodes only (no self-loop).
        # This list is non-empty: num_nodes_to_add >= 1 implies |V| >= 1.
        candidates = list(G_noisy.nodes())
        G_noisy.add_node(node_id)

        # Resample neighbour sets until the graph is connected. This mirrors
        # the original retry loop; it may take many tries if the graph has
        # many components.
        while True:
            targets = random.sample(candidates,
                                    random.randint(1, len(candidates)))
            new_edges = [(node_id, t) for t in targets]
            G_noisy.add_edges_from(new_edges)
            if is_connected(G_noisy):
                break
            G_noisy.remove_edges_from(new_edges)

    # One role per node (the original sized this by the edge count).
    role_id_noisy = [random.randint(0, 2)
                     for _ in range(G_noisy.number_of_nodes())]
    label_noisy = label

    return G_noisy, role_id_noisy, label_noisy


# Configuration for the synthetic citation dataset. These are read as module
# globals by generate_citation_network / generate_dataset below;
# num_classes is also read by generate_false_dataset.
num_classes: int = 5  # graph labels are drawn from range(num_classes)
feature_dim: int = 5  # length of each node's feature vector
num_graphs: int = 2000  # number of graphs produced by generate_dataset()
num_nodes_per_graph: int = 12  # nodes in every generated graph
p_edge: float = 0.1  # chance of each extra (non-chain) edge
noise_std: float = 0.05  # std-dev of Gaussian noise around the class centre


def generate_citation_network(num_nodes):
    """Build one synthetic directed graph whose nodes all share a class label.

    A label is drawn with ``np.random.randint(0, num_classes)``. Every node
    gets ``features`` equal to a constant centre vector, with all entries
    ``(label + 1) / num_classes``, plus Gaussian noise of std ``noise_std``.
    Each node is also tagged with ``label``.

    Edges:
      * a chain ``i -> i + 1`` for every consecutive pair;
      * for each ``i`` in ``0..num_nodes-2``, an extra edge ``j -> i`` for
        each ``j < i - 1`` independently with probability ``p_edge``.

    NOTE(review): the extra-edge pass stops at ``num_nodes - 2``, so the last
    node never receives random extra edges; confirm this is intended.

    Returns:
        (G, role_id, label): an ``nx.DiGraph``, an int array of length
        ``num_nodes`` filled with ``label``, and the label itself.
    """
    label = np.random.randint(0, num_classes)
    centre = np.full((feature_dim,), (label + 1) / num_classes, dtype=float)

    graph = nx.DiGraph()
    for node in range(num_nodes):
        noise = np.random.normal(0, noise_std, size=(feature_dim,))
        graph.add_node(node, features=centre + noise, label=label)
    role_id = np.full((num_nodes,), label, dtype=int)

    for node in range(num_nodes - 1):
        graph.add_edge(node, node + 1)
        # Draw all coins for this target first, then insert in the same order.
        extra_sources = [src for src in range(node - 1)
                         if np.random.rand() < p_edge]
        graph.add_edges_from((src, node) for src in extra_sources)

    return graph, role_id, label

def generate_dataset(n_graphs=None, n_nodes=None):
    """Generate synthetic citation graphs and collect them field by field.

    Args:
        n_graphs: number of graphs to generate. Defaults to the module-level
            ``num_graphs``.
        n_nodes: nodes per graph. Defaults to the module-level
            ``num_nodes_per_graph``.

    Returns:
        dict with one list entry per graph under each key:
          'edge_index':   int array of shape (2, num_edges); still (2, 0)
                          when a graph has no edges
          'features':     (n_nodes, feature_dim) array of node features
          'label':        graph label from generate_citation_network
          'role_id':      per-node role array from generate_citation_network
          'pos':          node positions from nx.spring_layout, in node order
          'ground_truth': result of find_gd(edge_index, role_id)
    """
    if n_graphs is None:
        n_graphs = num_graphs
    if n_nodes is None:
        n_nodes = num_nodes_per_graph

    dataset = {key: [] for key in ('edge_index', 'features', 'label',
                                   'role_id', 'pos', 'ground_truth')}

    for _ in tqdm(range(n_graphs), desc="Generating graphs"):
        G, role_id, label = generate_citation_network(n_nodes)

        # reshape(-1, 2) keeps the (2, E) layout even when E == 0. Without
        # it an empty edge list would give shape (0,).
        edge_index = np.array(list(G.edges), dtype=int).reshape(-1, 2).T
        feats_mat = np.stack([G.nodes[i]['features'] for i in range(n_nodes)],
                             axis=0)
        pos = np.array(list(nx.spring_layout(G).values()))

        dataset['edge_index'].append(edge_index)
        dataset['features'].append(feats_mat)
        dataset['pos'].append(pos)
        dataset['ground_truth'].append(find_gd(edge_index, role_id))
        dataset['role_id'].append(role_id)
        dataset['label'].append(label)

    return dataset

if __name__ == "__main__":
    # Generate the dataset with the module-level settings and write it out.
    # np.save wraps the dict in a 0-d object array (pickled), so read it back
    # with np.load(path, allow_pickle=True).item().
    out_path = osp.join(data_dir, 'train.npy')
    os.makedirs(data_dir, exist_ok=True)
    dataset = generate_dataset()
    np.save(out_path, dataset)
    print(f"Saved {num_graphs} graphs (each with {num_nodes_per_graph} nodes) to {out_path}")
