import json
import csv
import sys
from collections import defaultdict
import os
from tqdm import tqdm

def load_csv(file_path):
    """Load a user-to-cluster mapping from a CSV file.

    The first row is treated as the header and must contain the columns
    ``user_id`` and ``cluster_kmeans``; any other columns are ignored.

    Args:
        file_path: Path to a UTF-8 encoded CSV file.

    Returns:
        A dict mapping each ``user_id`` value (kept as a string) to its
        ``cluster_kmeans`` value converted to ``int``. If a user id appears
        more than once, the last row wins.

    Raises:
        ValueError: If the file has no header row, a required column is
            missing from the header, or a cluster value is not an integer.
    """
    user_cluster_map = {}
    # newline='' lets the csv module do its own line-ending handling, as the
    # csv documentation recommends for files passed to csv.reader.
    with open(file_path, 'r', encoding='utf-8', newline='') as f:
        reader = csv.reader(f)
        header = next(reader, None)
        if header is None:
            # Avoid leaking a bare StopIteration for an empty file.
            raise ValueError(f"{file_path} is empty: expected a header row")
        user_idx = header.index("user_id")
        cluster_idx = header.index("cluster_kmeans")

        for row in reader:
            if not row:
                # csv.reader yields [] for blank lines (e.g. a stray
                # trailing newline); there is nothing to map.
                continue
            user_cluster_map[row[user_idx]] = int(row[cluster_idx])
    return user_cluster_map

def stream_jsonl(file_path):
    """Lazily yield one parsed JSON value per non-blank line of a file.

    Args:
        file_path: Path to a UTF-8 encoded JSON Lines file.

    Yields:
        The object decoded by ``json.loads`` for each line, in file order.
        Lines that are empty or contain only whitespace are skipped.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            record = line.strip()
            if not record:
                # Blank lines (commonly a trailing newline at end of file)
                # carry no record; json.loads("") would raise on them.
                continue
            yield json.loads(record)

def mine_edges(dataset_name, k, split, base_dir="/home/yqiao47/dataset"):
    """Write co-purchase and self-loop edges for one dataset split as JSONL.

    Reads ``{dataset_name}_review_{split}.jsonl`` and
    ``user_clusters_{k}_train.csv`` from ``{base_dir}/{dataset_name}``,
    groups the distinct ``parent_asin`` values by ``user_id``, and writes
    ``{dataset_name}_edge_{split}.jsonl`` into the same directory
    (overwriting any existing file).

    For every user present in the cluster CSV, and for every product of that
    user, one self-loop record is written (``user_id`` "self_generated",
    ``cluster_id`` 10), followed by one record per unordered pair of that
    user's products, tagged with the user's id and cluster. Users that are
    not in the cluster CSV contribute no edges.

    Args:
        dataset_name: Dataset directory name and file-name prefix.
        k: Value substituted into the cluster CSV file name.
        split: Split name substituted into the review/edge file names.
            The cluster CSV is always the ``_train`` one, whatever the split.
        base_dir: Root directory that contains the dataset directories.

    Returns:
        Total number of records written (co-purchase edges + self-loops).
    """
    dataset_dir = os.path.join(base_dir, dataset_name)
    review_data = os.path.join(dataset_dir, f"{dataset_name}_review_{split}.jsonl")
    cluster_data = os.path.join(dataset_dir, f"user_clusters_{k}_train.csv")
    output_edge_file = os.path.join(dataset_dir, f"{dataset_name}_edge_{split}.jsonl")

    user_cluster_map = load_csv(cluster_data)

    # Collect the distinct products per user; a set drops repeated reviews
    # of the same product by the same user.
    user_purchases = defaultdict(set)
    for review in tqdm(stream_jsonl(review_data), desc="Loading reviews"):
        user_purchases[str(review["user_id"])].add(str(review["parent_asin"]))

    edge_count = 0
    self_connection_count = 0
    skipped_users = 0

    with open(output_edge_file, 'w', encoding='utf-8') as f:
        for user_id, purchased in tqdm(user_purchases.items(), desc="Processing users"):
            cluster_id = user_cluster_map.get(user_id)
            if cluster_id is None:
                skipped_users += 1
                continue

            # Set iteration order for strings varies between interpreter
            # runs; sorting makes the edge order and the product1/product2
            # orientation reproducible.
            products = sorted(purchased)
            for idx, product in enumerate(products):
                # NOTE(review): a self-loop is emitted per (user, product),
                # so a product bought by several users gets several
                # identical self-loop lines; the fixed cluster_id of 10 is
                # kept as-is — confirm both are intended.
                self_edge = {
                    "product1": product,
                    "product2": product,
                    "user_id": "self_generated",
                    "cluster_id": 10
                }
                f.write(json.dumps(self_edge) + '\n')
                self_connection_count += 1

                # Pair this product with every later one so each unordered
                # pair is written exactly once per user.
                for other in products[idx + 1:]:
                    edge = {
                        "product1": product,
                        "product2": other,
                        "user_id": user_id,
                        "cluster_id": cluster_id
                    }
                    f.write(json.dumps(edge) + '\n')
                    edge_count += 1

    total_edges = edge_count + self_connection_count
    print(f"Edge file saved: {edge_count} co-purchase edges, {self_connection_count} self-connections")
    if skipped_users:
        # Surface users dropped for lack of a cluster instead of hiding them.
        print(f"Skipped {skipped_users} users with no cluster assignment")
    return total_edges

def main():
    """Command-line entry point: mine edges for the train and test splits.

    Expects two positional arguments, ``dataset_name`` and ``k``. With fewer
    arguments it prints a usage line and exits with status 1. Otherwise it
    runs ``mine_edges`` for "train" and then "test" and prints both totals.
    """
    cli_args = sys.argv[1:]
    if len(cli_args) < 2:
        print("Usage: python mine_edges.py dataset_name k")
        sys.exit(1)

    dataset_name, k = cli_args[0], cli_args[1]

    # Train is processed before test, one split at a time.
    totals = {}
    for split in ("train", "test"):
        totals[split] = mine_edges(dataset_name, k, split)

    print(f"\nCompleted edge mining for {dataset_name}:")
    print(f"  Training edges: {totals['train']}")
    print(f"  Testing edges: {totals['test']}")

# Run the command-line entry point only when this file is executed as a
# script, so importing the module does not start edge mining.
if __name__ == "__main__":
    main()