import json
import sys
import random
from tqdm import tqdm
from collections import defaultdict

def get_file_line_count(file_path):
    """Return the number of lines in *file_path*.

    A final line without a trailing newline still counts as a line, so this
    matches the number of iterations of ``for line in open(file_path)`` for
    newline-delimited files (e.g. JSONL). It is used as the tqdm ``total``.

    The file is read in binary mode. No text decoding happens, so the result
    does not depend on the platform's default encoding, and a non-UTF-8
    locale cannot raise ``UnicodeDecodeError`` here.
    """
    with open(file_path, 'rb') as f:
        return sum(1 for _ in f)

def _load_asin_mapping(mapping_file):
    """Return the ``"asin2idx"`` dictionary stored in the JSON object in *mapping_file*."""
    with open(mapping_file, 'r', encoding='utf-8') as f:
        return json.load(f)["asin2idx"]


def _load_asin_set(meta_file, desc):
    """Return the set of ``parent_asin`` values in the JSONL item-metadata file *meta_file*.

    Blank lines are skipped. *desc* is the tqdm progress-bar label.
    """
    asins = set()
    with open(meta_file, 'r', encoding='utf-8') as f:
        for line in tqdm(f, total=get_file_line_count(meta_file), desc=desc):
            if line.strip():
                asins.add(json.loads(line)['parent_asin'])
    return asins


def _collect_reviews(raw_input_file, line_count, asin_set, asin_to_id, desc, map_asin):
    """Scan the raw review JSONL, keeping reviews whose item is in both *asin_set* and *asin_to_id*.

    Returns ``(reviews, purchase_count, deleted)``:
      * reviews -- list of ``{"user_id": <raw user_id>, "parent_asin": <value>}``, where
        <value> is ``asin_to_id[asin]`` when *map_asin* is true, else the raw ASIN string;
      * purchase_count -- raw user_id -> number of kept reviews by that user;
      * deleted -- number of reviews dropped because their item was missing from either lookup.
    Blank lines are skipped and are not counted as deleted.
    """
    purchase_count = defaultdict(int)
    reviews = []
    deleted = 0
    with open(raw_input_file, 'r', encoding='utf-8') as f:
        for line in tqdm(f, total=line_count, desc=desc):
            if not line.strip():
                continue
            item = json.loads(line)
            asin = item['parent_asin']
            if asin not in asin_set or asin not in asin_to_id:
                deleted += 1
                continue
            user_id = item['user_id']
            purchase_count[user_id] += 1
            reviews.append({
                "user_id": user_id,
                "parent_asin": asin_to_id[asin] if map_asin else asin,
            })
    return reviews, purchase_count, deleted


def _write_jsonl(path, records):
    """Write each dict in *records* to *path* as one UTF-8 JSON object per line."""
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(record, ensure_ascii=False) + '\n' for record in records)


def prepare_datasets(dataset_name, sampling_ratio=20, *, data_root="/home/yqiao47/dataset", seed=None):
    """Build train/test review files for *dataset_name* under ``{data_root}/{dataset_name}/``.

    A review is kept only when its ``parent_asin`` appears both in the item
    metadata file and in the ``asin2idx`` mapping of the same split.

    Reads:
      * ``{name}.jsonl`` -- raw reviews; each line needs ``user_id`` and ``parent_asin``.
      * ``{name}_item_train.jsonl`` / ``{name}_item_test.jsonl`` -- item metadata, one
        object with a ``parent_asin`` per line.
      * ``mapping_train.json`` / ``mapping_test.json`` -- JSON objects with an ``"asin2idx"`` dict.

    Writes:
      * ``{name}_review_train.jsonl`` -- ``{"user_id": idx, "parent_asin": train asin2idx value}``
        for every kept review by a user with more than one kept training review.
      * ``user2idx.json`` -- raw user_id -> dense index (0, 1, ...) in order of first appearance.
      * ``{name}_review_test.jsonl`` -- same shape using the test asin2idx, restricted to users
        with more than one kept test review who also appear in ``user2idx.json``.
      * ``{name}_review_new_sampled.jsonl`` -- ``{"user_id": idx, "asin": raw ASIN}`` test reviews
        of a random subset of those users.
      * ``mapped_user_asin_sampled.json`` -- ``str(user idx)`` -> list of distinct raw ASINs
        (first-seen order) for the sampled users.

    Args:
        dataset_name: Sub-directory / file stem of the dataset.
        sampling_ratio: Sample ``len(users) // sampling_ratio`` test users, but at least one
            when any exist. Must be >= 1.
        data_root: Directory containing one sub-directory per dataset.
        seed: When given, sampling uses a private ``random.Random(seed)`` so output is
            reproducible. When ``None``, the module-level ``random`` generator is used.

    Raises:
        ValueError: If ``sampling_ratio`` < 1.
    """
    # Validate before any expensive I/O. A zero ratio previously crashed only at the end.
    if sampling_ratio < 1:
        raise ValueError(f"sampling_ratio must be a positive integer, got {sampling_ratio!r}")

    base = f"{data_root}/{dataset_name}"
    raw_input_file = f"{base}/{dataset_name}.jsonl"
    user_id_to_idx_file = f"{base}/user2idx.json"

    # Training paths
    train_meta_file = f"{base}/{dataset_name}_item_train.jsonl"
    train_mapping_file = f"{base}/mapping_train.json"
    train_output_file = f"{base}/{dataset_name}_review_train.jsonl"

    # Testing paths
    test_meta_file = f"{base}/{dataset_name}_item_test.jsonl"
    test_mapping_file = f"{base}/mapping_test.json"
    test_output_file = f"{base}/{dataset_name}_review_test.jsonl"
    test_output_sampled_file = f"{base}/{dataset_name}_review_new_sampled.jsonl"
    test_mapping_sampled_file = f"{base}/mapped_user_asin_sampled.json"

    # The raw review file is scanned once per split; count its lines only once.
    raw_line_count = get_file_line_count(raw_input_file)

    # ======= Processing Training Data =======
    train_parent_asin_to_id = _load_asin_mapping(train_mapping_file)

    print("Loading training metadata...")
    train_asin_set = _load_asin_set(train_meta_file, "Processing train items")

    print("Loading and processing training reviews...")
    train_raw_review_data, train_user_purchase_count, train_delete_num = _collect_reviews(
        raw_input_file, raw_line_count, train_asin_set, train_parent_asin_to_id,
        "Processing train reviews", map_asin=True,
    )

    # Keep only users with more than one kept review. Dense indices are assigned
    # in order of each user's first kept review.
    print("Filtering training reviews...")
    user_id_to_idx = {}
    train_cleaned_review_data = []
    for interaction in tqdm(train_raw_review_data, desc="Filtering train reviews"):
        user_id = interaction["user_id"]
        if train_user_purchase_count[user_id] > 1:
            # len() is evaluated before insertion, so a new user gets the next free index.
            idx = user_id_to_idx.setdefault(user_id, len(user_id_to_idx))
            train_cleaned_review_data.append({"user_id": idx, "parent_asin": interaction["parent_asin"]})

    train_filtered_user_count = sum(1 for count in train_user_purchase_count.values() if count == 1)

    print("Training data statistics:")
    print(f"  Final count: {len(train_cleaned_review_data)}")
    print(f"  Deleted count (invalid products): {train_delete_num}")
    print(f"  Filtered users (only 1 purchase): {train_filtered_user_count}")

    print("Saving training dataset...")
    _write_jsonl(train_output_file, train_cleaned_review_data)

    print("Saving user ID mapping...")
    with open(user_id_to_idx_file, 'w', encoding='utf-8') as f:
        json.dump(user_id_to_idx, f, ensure_ascii=False, indent=4)

    # ======= Processing Testing Data =======
    test_parent_asin_to_id = _load_asin_mapping(test_mapping_file)
    test_asin_set = _load_asin_set(test_meta_file, "Processing test items")

    # Keep raw ASINs here: the sampled outputs use them, and the full test file maps them below.
    test_raw_review_data, test_user_purchase_count, test_delete_num = _collect_reviews(
        raw_input_file, raw_line_count, test_asin_set, test_parent_asin_to_id,
        "Processing test reviews", map_asin=False,
    )

    # Test users must have more than one kept test review and already have a training index.
    print("Filtering testing reviews...")
    test_filtered_reviews = []       # raw-ASIN records, used for sampling
    test_review_data_with_idx = []   # asin2idx-mapped records, full test file
    for interaction in tqdm(test_raw_review_data, desc="Filtering test reviews"):
        user_id = interaction["user_id"]
        if test_user_purchase_count[user_id] > 1 and user_id in user_id_to_idx:
            idx = user_id_to_idx[user_id]
            asin = interaction["parent_asin"]
            test_filtered_reviews.append({"user_id": idx, "asin": asin})
            test_review_data_with_idx.append({"user_id": idx, "parent_asin": test_parent_asin_to_id[asin]})

    print("Saving full test dataset...")
    _write_jsonl(test_output_file, test_review_data_with_idx)

    # Group distinct ASINs per user. A dict is used as an insertion-ordered set,
    # so list order is first-seen rather than hash-dependent.
    grouped = defaultdict(dict)
    for review in test_filtered_reviews:
        grouped[str(review["user_id"])][review["asin"]] = None
    user_asin_map = {user: list(asins) for user, asins in grouped.items()}

    # Sample about 1/sampling_ratio of users (at least one, but never more than exist).
    user_list = list(user_asin_map)
    sample_size = min(len(user_list), max(1, len(user_list) // sampling_ratio))
    rng = random if seed is None else random.Random(seed)
    selected_users = set(rng.sample(user_list, sample_size))

    # Iterate user_list (not the set) so output key order is deterministic.
    sampled_user_asin_map = {user: user_asin_map[user] for user in user_list if user in selected_users}
    sampled_filtered_reviews = [
        review for review in test_filtered_reviews if str(review["user_id"]) in selected_users
    ]

    print("Saving sampled test dataset...")
    _write_jsonl(test_output_sampled_file, sampled_filtered_reviews)

    print("Saving sampled user-ASIN mapping...")
    with open(test_mapping_sampled_file, 'w', encoding='utf-8') as f:
        json.dump(sampled_user_asin_map, f, ensure_ascii=False, indent=4)

    print("\nTest data statistics:")
    print(f"  Total test reviews: {len(test_review_data_with_idx)}")
    print(f"  Deleted count (invalid products): {test_delete_num}")
    print(f"  Total users before sampling: {len(user_asin_map)}")
    print(f"  Total users after sampling: {len(sampled_user_asin_map)}")
    print(f"  Total reviews after sampling: {len(sampled_filtered_reviews)}")

    print("\nAll datasets prepared successfully!")

def main():
    """Command-line entry point: ``<script> DATASET_NAME [SAMPLING_RATIO]``.

    SAMPLING_RATIO defaults to 20. When DATASET_NAME is missing, or
    SAMPLING_RATIO is not a positive integer, this prints a usage message to
    stderr and exits with status 1. The check runs before any files are read.
    """
    usage = f"usage: {sys.argv[0]} DATASET_NAME [SAMPLING_RATIO (positive int, default 20)]"
    if len(sys.argv) < 2:
        print(usage, file=sys.stderr)
        sys.exit(1)

    dataset_name = sys.argv[1]
    sampling_ratio = 20
    if len(sys.argv) > 2:
        try:
            sampling_ratio = int(sys.argv[2])
        except ValueError:
            sampling_ratio = 0  # fall through to the positive-integer check below
    if sampling_ratio < 1:
        print(f"error: SAMPLING_RATIO must be a positive integer, got {sys.argv[2]!r}", file=sys.stderr)
        print(usage, file=sys.stderr)
        sys.exit(1)

    prepare_datasets(dataset_name, sampling_ratio)


if __name__ == "__main__":
    main()