import json
import pandas as pd
import random
import argparse
from itertools import combinations
from tqdm import tqdm


def _iter_test_records(user_asin_map, user_cluster_map, neg_samples):
    """Yield one test record per unordered item pair of each clustered user.

    Args:
        user_asin_map: Mapping of user id to the sequence of item ids (ASINs)
            recorded for that user.
        user_cluster_map: Mapping of user id to that user's cluster label.
        neg_samples: Maximum number of negative items drawn for each pair.

    Yields:
        Dicts with the keys ``node``, ``edge_info``, ``positive_sample`` and
        ``negative_samples``.
    """
    # Every item seen for any user; negatives are drawn from this catalogue.
    all_asins = {asin for asins in user_asin_map.values() for asin in asins}

    for user_id, asins in tqdm(user_asin_map.items(), desc="Processing users",
                               total=len(user_asin_map)):
        # A pair needs at least two items, and the record needs a cluster label.
        if len(asins) < 2 or user_id not in user_cluster_map:
            continue
        cluster = user_cluster_map[user_id]

        # The candidate pool depends only on the user, so it is built once
        # per user instead of once per item pair.
        negative_asins = list(all_asins - set(asins))
        sample_size = min(neg_samples, len(negative_asins))

        for asin1, asin2 in combinations(asins, 2):
            yield {
                "node": asin1,
                "edge_info": cluster,
                "positive_sample": asin2,
                # A fresh, independent draw for every pair.
                "negative_samples": random.sample(negative_asins, sample_size),
            }


def generate_test_data(dataset_name, k=10, neg_samples=20,
                       base_dir="/home/yqiao47/dataset"):
    """Build the pairwise test set for a dataset and write it as JSON Lines.

    Reads ``mapped_user_asin_sampled.json`` (user id -> list of item ids) and
    ``user_clusters_{k}_train.csv`` (columns ``user_id`` and
    ``cluster_kmeans``) from ``{base_dir}/{dataset_name}``. For every user
    that has at least two items and appears in the cluster file, each
    unordered pair of the user's items becomes one record holding the first
    item (``node``), the user's cluster (``edge_info``), the second item
    (``positive_sample``) and up to ``neg_samples`` randomly drawn items the
    user does not have (``negative_samples``).

    Records are written one JSON object per line to ``test_data.json`` in the
    same directory. They are streamed to disk as they are produced, so the
    output file is opened (and truncated) before processing starts.

    Args:
        dataset_name: Name of the dataset sub-directory.
        k: Number of clusters; selects which cluster CSV is read.
        neg_samples: Maximum number of negative items per pair. Fewer are
            returned when the user's negative pool is smaller.
        base_dir: Root directory that contains the dataset sub-directories.

    Raises:
        FileNotFoundError: If either input file does not exist.
        KeyError: If the cluster CSV lacks the expected columns.
    """
    dataset_dir = f"{base_dir}/{dataset_name}"
    output_mapping_file = f"{dataset_dir}/mapped_user_asin_sampled.json"
    user_cluster_file = f"{dataset_dir}/user_clusters_{k}_train.csv"
    test_data_file = f"{dataset_dir}/test_data.json"

    with open(output_mapping_file, 'r', encoding='utf-8') as f:
        user_asin_map = json.load(f)

    user_cluster_df = pd.read_csv(user_cluster_file)
    # JSON object keys are always strings, so the CSV ids are cast to match.
    user_cluster_map = dict(zip(user_cluster_df["user_id"].astype(str),
                                user_cluster_df["cluster_kmeans"]))

    with open(test_data_file, 'w', encoding='utf-8') as f:
        for item in _iter_test_records(user_asin_map, user_cluster_map, neg_samples):
            f.write(json.dumps(item, ensure_ascii=False) + '\n')

def main():
    """Parse the command-line options and run test-data generation.

    Options:
        --dataset_name: required dataset name.
        --k: number of clusters (defaults to 10).
        --neg_samples: negatives per positive pair (defaults to 20).
    """
    cli = argparse.ArgumentParser(description='Generate test data from user clusters')
    cli.add_argument('--dataset_name', type=str, required=True,
                     help='Name of the dataset')
    cli.add_argument('--k', type=int, default=10,
                     help='Number of clusters (default: 10)')
    cli.add_argument('--neg_samples', type=int, default=20,
                     help='Number of negative samples per positive pair (default: 20)')

    opts = cli.parse_args()
    generate_test_data(opts.dataset_name, opts.k, opts.neg_samples)


# Run the command-line entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()