# File: sample_community_relations.py
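# Description: Samples vertices from the DBLP community file and writes, for
#              every pair of sampled vertices, a label telling whether the two
#              share a community (0) or not (1).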


import random

def load_communities(file_path):
    """Read communities from a whitespace-separated text file.

    Each non-blank line lists the integer vertex ids of one community.

    Args:
        file_path: Path of the community file to read.

    Returns:
        A list of communities in file order, each a list of ints.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a token cannot be parsed as an integer.
    """
    communities = []
    with open(file_path, "r", encoding="utf-8") as file:
        for line in file:
            tokens = line.split()
            # Blank or whitespace-only lines describe no community; skipping
            # them keeps empty lists out of the result.
            if not tokens:
                continue
            communities.append([int(token) for token in tokens])
    return communities

def get_all_vertices(communities):
    """Return every distinct vertex that appears in any community.

    Args:
        communities: Iterable of communities, each an iterable of vertices.

    Returns:
        A list holding each vertex exactly once, in set iteration order.
    """
    # One union over all communities removes the duplicates.
    distinct = set().union(*communities)
    return list(distinct)

def sample_vertices(vertices, sample_size, seed=42):
    """Draw a reproducible random sample of vertices, returned sorted.

    Args:
        vertices: Sequence of vertices to sample from (e.g. a list).
        sample_size: Number of vertices to draw without replacement.
        seed: Seed of the generator; the same seed and the same input order
            always produce the same sample.

    Returns:
        A new sorted list containing ``sample_size`` sampled vertices.

    Raises:
        ValueError: If ``sample_size`` is negative or exceeds
            ``len(vertices)``.
    """
    # A private generator draws exactly what the seeded module-level one
    # would, without resetting the global random state for other callers.
    rng = random.Random(seed)
    return sorted(rng.sample(vertices, sample_size))

def build_relation_pairs(sampled_vertices, communities):
    """Label every unordered pair of sampled vertices by shared community.

    Args:
        sampled_vertices: Sequence of sampled vertices; pairs are emitted in
            the order ``(sampled_vertices[i], sampled_vertices[j])`` for
            ``i < j``.
        communities: Iterable of communities, each an iterable of vertices.

    Returns:
        A list of ``(v1, v2, relation)`` tuples where ``relation`` is 0 when
        the two vertices appear together in at least one community and 1
        otherwise.
    """
    vertex_set = set(sampled_vertices)

    # Map each sampled vertex to the sampled vertices it shares a community
    # with. Only sampled vertices are ever looked up below, so non-sampled
    # members need not be stored.
    community_map = {}
    for community in communities:
        sampled_members = vertex_set.intersection(community)
        for vertex in sampled_members:
            # Grow the existing set in place instead of rebuilding a union.
            community_map.setdefault(vertex, set()).update(sampled_members)

    no_neighbors = frozenset()
    pairs = []
    for i, v1 in enumerate(sampled_vertices):
        # The neighbor set depends only on v1, so look it up once per row.
        neighbors = community_map.get(v1, no_neighbors)
        for v2 in sampled_vertices[i + 1:]:
            pairs.append((v1, v2, 0 if v2 in neighbors else 1))
    return pairs

def save_pairs_to_file(pairs, output_path):
    """Write relation pairs to a text file, one ``v1 v2 relation`` per line.

    Args:
        pairs: Iterable of ``(v1, v2, relation)`` triples.
        output_path: Destination path; an existing file is overwritten.
    """
    with open(output_path, "w") as out:
        out.writelines(
            f"{first} {second} {label}\n" for first, second, label in pairs
        )

def main(input_file, output_file, sample_size=10000):
    """Run the pipeline: load, sample, label pairs, and save.

    Args:
        input_file: Path of the community file to read.
        output_file: Path the labelled pairs are written to.
        sample_size: Number of vertices to sample.
    """
    communities = load_communities(input_file)
    chosen = sample_vertices(get_all_vertices(communities), sample_size)
    save_pairs_to_file(build_relation_pairs(chosen, communities), output_file)
    print(f"Sampled relation pairs saved to {output_file}.")

if __name__ == "__main__":
    # Default DBLP locations; replace with your own file paths.
    main(
        input_file="../data/dblp/com-dblp.top5000.cmty.txt",
        output_file="../data/dblp/dblp10000.txt",
    )