import csv
import itertools
import json
import sys
from collections import defaultdict

def load_json(file_path):
    """Parse the single JSON document stored in *file_path* (UTF-8) and return it."""
    with open(file_path, encoding='utf-8') as handle:
        payload = json.load(handle)
    return payload

def load_jsonl(file_path):
    """Read a JSON Lines file (UTF-8) and return its records as a list.

    Blank or whitespace-only lines (for example an extra empty line at the
    end of the file) are skipped. Passing them to ``json.loads`` would raise.

    Args:
        file_path: path to the ``.jsonl`` file.

    Returns:
        List of decoded JSON values, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # json.loads('') raises, so ignore empty lines
            data.append(json.loads(line))
    return data

def load_csv(file_path):
    """Read a CSV file whose first row is a header into a list of dicts.

    Each data row becomes ``dict(zip(header, row))``. A row shorter than the
    header yields only its leading columns. Extra trailing fields in a longer
    row are dropped.

    Args:
        file_path: path to the CSV file (UTF-8, optionally with a BOM).

    Returns:
        List of row dicts in file order. Returns ``[]`` for an empty file.
        Blank lines are skipped rather than turned into empty dicts.
    """
    data = []
    # newline='' is what the csv module requires so quoted fields containing
    # newlines are parsed correctly. 'utf-8-sig' strips a leading BOM, which
    # would otherwise be glued onto the first header name, and decodes
    # BOM-less UTF-8 exactly like 'utf-8'.
    with open(file_path, 'r', encoding='utf-8-sig', newline='') as f:
        reader = csv.reader(f)
        header = next(reader, None)
        if header is None:
            return data  # empty file: no header, no rows
        for row in reader:
            if not row:
                continue  # csv.reader yields [] for blank lines
            data.append(dict(zip(header, row)))
    return data

def _group_purchases(reviews):
    """Map each str(user_id) to the set of str(parent_asin) values that user reviewed."""
    user_purchases = defaultdict(set)
    for review in reviews:
        user_purchases[str(review["user_id"])].add(str(review["parent_asin"]))
    return user_purchases


def _iter_edges(user_purchases, user_cluster_map, self_loop_cluster):
    """Yield edge dicts: per user, every co-purchase pair, then that user's self-loops."""
    for user_id, products in user_purchases.items():
        # Sort for reproducible output. Iteration order of a set of str
        # depends on hash randomization (PYTHONHASHSEED), so list(set) would
        # change edge order and product1/product2 orientation between runs.
        product_list = sorted(products)
        cluster_id = user_cluster_map.get(user_id)  # None if user not in CSV

        for product1, product2 in itertools.combinations(product_list, 2):
            yield {
                "product1": product1,
                "product2": product2,
                "user_id": user_id,
                "cluster_id": cluster_id,
            }

        # NOTE(review): self-loops are emitted once per user who reviewed the
        # product, so a product with N reviewers gets N identical self-loop
        # lines. Kept as-is in case downstream relies on that multiplicity;
        # confirm whether one self-loop per product is intended.
        for product in product_list:
            yield {
                "product1": product,
                "product2": product,
                "user_id": "N/A",
                "cluster_id": self_loop_cluster,
            }


def mine_edges(dataset_name, k, base_dir="/home/yqiao47/dataset", self_loop_cluster=10):
    """Build a product co-purchase edge list for one dataset and write it as JSON Lines.

    Inputs are ``{base_dir}/{dataset_name}/{dataset_name}_review.jsonl``, whose
    records are read via their "user_id" and "parent_asin" fields, and
    ``{base_dir}/{dataset_name}/user_clusters_{k}.csv``, whose columns are read
    via "user_id" and "cluster_kmeans". The output is
    ``{base_dir}/{dataset_name}/{dataset_name}_edge_{k}.jsonl``.

    For each user, one edge is written per unordered pair of distinct products
    they reviewed. Each edge carries the user's id and cluster label: a string
    from the CSV, or ``None`` if the user is missing from it. After a user's
    pairs, a self-loop edge with user_id "N/A" and cluster_id
    ``self_loop_cluster`` is written for each of that user's products.

    Args:
        dataset_name: dataset folder name and file-name prefix.
        k: value interpolated into the cluster and edge file names.
        base_dir: directory containing the per-dataset folders. The default
            is the original hard-coded location.
        self_loop_cluster: cluster_id stored on self-loop edges (default 10,
            as before). NOTE(review): this is an int while real cluster ids
            are CSV strings. It presumably acts as a sentinel outside the
            k-means labels, which may not hold when k > 10. Confirm.

    Returns:
        None. Prints the output path and the number of edges written.
    """
    review_data = f"{base_dir}/{dataset_name}/{dataset_name}_review.jsonl"
    cluster_data = f"{base_dir}/{dataset_name}/user_clusters_{k}.csv"
    output_edge_file = f"{base_dir}/{dataset_name}/{dataset_name}_edge_{k}.jsonl"

    # Load and group everything before opening the output, so a malformed
    # input never leaves a truncated edge file behind.
    user_clusters = load_csv(cluster_data)
    user_cluster_map = {row["user_id"]: row["cluster_kmeans"] for row in user_clusters}
    user_purchases = _group_purchases(load_jsonl(review_data))

    # Stream edges straight to disk. The pair count grows quadratically with
    # products per user, so collecting them all in a list first wastes memory.
    edge_count = 0
    with open(output_edge_file, 'w', encoding='utf-8') as f:
        for edge in _iter_edges(user_purchases, user_cluster_map, self_loop_cluster):
            f.write(json.dumps(edge) + '\n')
            edge_count += 1

    print(f"Edge file saved to {output_edge_file}. Total edges: {edge_count}")

def main(argv=None):
    """Command-line entry point: ``<script> DATASET_NAME K``.

    Args:
        argv: full argument vector including the program name. Defaults to
            ``sys.argv``. Arguments after the first two are ignored.

    Exits with status 1, after printing a usage line to stderr, when fewer
    than two arguments are supplied.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 3:
        # Previously this exited silently; tell the user what was expected.
        prog = argv[0] if argv else "script"
        print(f"usage: {prog} DATASET_NAME K", file=sys.stderr)
        sys.exit(1)

    dataset_name = argv[1]
    # k stays a string; mine_edges only interpolates it into file names.
    k = argv[2]

    mine_edges(dataset_name, k)

if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    main()