import json
import random
import sys

_DEFAULT_DATASET_ROOT = "/home/yqiao47/dataset"


def _load_jsonl(path):
    """Read a JSON Lines file and return its records as a list.

    Blank (whitespace-only) lines, e.g. a trailing newline at the end of the
    file, are skipped rather than handed to ``json.loads`` (which would raise).
    """
    records = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                records.append(json.loads(line))
    return records


def _write_jsonl(path, records):
    """Write *records* to *path*, one JSON document per line."""
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(record) + "\n" for record in records)


def _build_mapping(records):
    """Return ``{"asin2idx": ..., "idx2asin": ...}`` for *records*.

    Only records carrying a ``"parent_asin"`` key are indexed, so an index is
    the position among those records -- not necessarily the line number in the
    split file. If a ``parent_asin`` occurs more than once, the index of its
    last occurrence wins and ``idx2asin`` has no entry for the earlier indices.
    """
    asins = [item["parent_asin"] for item in records if "parent_asin" in item]
    asin2idx = {asin: idx for idx, asin in enumerate(asins)}
    # JSON object keys are always strings, so the reverse map is keyed by str(idx).
    idx2asin = {str(idx): asin for asin, idx in asin2idx.items()}
    return {"asin2idx": asin2idx, "idx2asin": idx2asin}


def _write_json(path, obj):
    """Serialize *obj* to *path* as indented JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=4)


def split_dataset(dataset_name, train_ratio=0.90, seed=None, base_dir=_DEFAULT_DATASET_ROOT):
    """Randomly split ``<name>_item_new.jsonl`` into train/test files plus mappings.

    Reads ``{base_dir}/{dataset_name}/{dataset_name}_item_new.jsonl`` and writes,
    in the same directory:

    * ``{dataset_name}_item_train.jsonl`` -- the first ``int(n * train_ratio)``
      shuffled records;
    * ``{dataset_name}_item_test.jsonl`` -- the remaining records;
    * ``mapping_train.json`` / ``mapping_test.json`` -- ``asin2idx`` and
      ``idx2asin`` dictionaries built from each split's ``parent_asin`` values.

    Args:
        dataset_name: Dataset sub-directory and file-name prefix.
        train_ratio: Fraction of records assigned to the training split; must
            lie in ``[0, 1]``.
        seed: If given, shuffle with a private ``random.Random(seed)`` so the
            split is reproducible. If ``None`` (default), the module-level
            ``random.shuffle`` is used, as before.
        base_dir: Root directory that contains the dataset directories.

    Raises:
        ValueError: If ``train_ratio`` is outside ``[0, 1]``.
        OSError: If the input cannot be read or outputs cannot be written.
        json.JSONDecodeError: If a non-blank input line is not valid JSON.
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio!r}")

    dataset_dir = f"{base_dir}/{dataset_name}"
    item_file = f"{dataset_dir}/{dataset_name}_item_new.jsonl"
    train_file = f"{dataset_dir}/{dataset_name}_item_train.jsonl"
    test_file = f"{dataset_dir}/{dataset_name}_item_test.jsonl"

    print(f"Loading data from {item_file}...")
    products = _load_jsonl(item_file)

    if seed is None:
        random.shuffle(products)
    else:
        random.Random(seed).shuffle(products)
    split_idx = int(len(products) * train_ratio)

    train_products = products[:split_idx]
    test_products = products[split_idx:]

    print(f"Saving training set to {train_file}...")
    _write_jsonl(train_file, train_products)

    print(f"Saving test set to {test_file}...")
    _write_jsonl(test_file, test_products)

    print(f"Dataset split complete: {len(train_products)} training samples, {len(test_products)} test samples.")

    train_mapping_file = f"{dataset_dir}/mapping_train.json"
    test_mapping_file = f"{dataset_dir}/mapping_test.json"

    print(f"Saving training mapping to {train_mapping_file}...")
    _write_json(train_mapping_file, _build_mapping(train_products))

    print(f"Saving test mapping to {test_mapping_file}...")
    _write_json(test_mapping_file, _build_mapping(test_products))

    print("Mapping files saved successfully!")

def main():
    """Command-line entry point: ``python <script> <dataset_name>``.

    When the dataset name is missing, exits with status 1 and prints a usage
    message to stderr. Otherwise, splits that dataset using the default
    ``split_dataset`` settings.
    """
    if len(sys.argv) < 2:
        prog = sys.argv[0] if sys.argv else "split_dataset"
        # sys.exit(str) writes the message to stderr and exits with status 1.
        sys.exit(f"Usage: {prog} <dataset_name>")

    dataset_name = sys.argv[1]
    split_dataset(dataset_name)


if __name__ == "__main__":
    main()