import json
import os
import re
import sys
from tqdm import tqdm
from collections import defaultdict, Counter

def get_file_line_count(file_path):
    """Return the number of lines in *file_path*.

    Used as the ``total`` for progress bars and for reporting output sizes.
    The file is decoded as UTF-8 with ``errors='replace'`` so that merely
    counting lines never fails because of the platform's default encoding or
    a stray invalid byte; text-mode (universal newline) line splitting is
    unchanged.
    """
    with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
        return sum(1 for _ in f)

# Directory that contains one sub-directory per dataset.
DEFAULT_DATA_ROOT = "/home/yqiao47/dataset"

# A kept title must be non-empty and consist solely of 7-bit ASCII characters.
_ASCII_TITLE_RE = re.compile(r'^[\x00-\x7F]+$')


def _category_key(name):
    """Return a case/punctuation-insensitive comparison key for a category name.

    '&' is spelled out as 'and' and every non-alphanumeric character is
    dropped, so the dataset name ``Clothing_Shoes_and_Jewelry`` and the
    category ``Clothing, Shoes & Jewelry`` both map to
    ``clothingshoesandjewelry``.
    """
    return re.sub(r'[^0-9a-z]', '', str(name).lower().replace('&', 'and'))


def _iter_jsonl(path, desc, unit="it"):
    """Yield one parsed JSON object per non-blank line of *path*, with a progress bar."""
    with open(path, 'r', encoding='utf-8') as f:
        for line in tqdm(f, total=get_file_line_count(path), desc=desc, unit=unit):
            line = line.strip()
            if line:  # tolerate blank / trailing empty lines
                yield json.loads(line)


def _write_jsonl(path, records):
    """Write *records* to *path* as UTF-8 JSON Lines (one object per line)."""
    with open(path, 'w', encoding='utf-8') as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + '\n')


def _build_asin_mapping(items):
    """Return ``(asin2idx, idx2asin)`` indexing *items* by their list position.

    ``idx2asin`` uses string keys so it round-trips through JSON unchanged.
    """
    asin2idx = {}
    idx2asin = {}
    for idx, item in enumerate(items):
        asin = item["parent_asin"]
        asin2idx[asin] = idx
        idx2asin[str(idx)] = asin
    return asin2idx, idx2asin


def _load_candidate_items(meta_data, image_folder, dataset_name, stats):
    """Read item metadata, keeping items with an image, an ASCII title and categories.

    Returns ``(items, level_1_counts)``. Each item is a dict with
    ``parent_asin``, ``title``, ``category`` and ``rate`` keys.
    ``level_1_counts`` counts the first category left after a leading
    category naming the dataset itself has been removed. Rejected items are
    tallied into *stats*.
    """
    root_key = _category_key(dataset_name)
    items = []
    level_1_counts = Counter()

    for item in _iter_jsonl(meta_data, "Processing items", unit="item"):
        stats["total_items"] += 1
        parent_asin = item.get('parent_asin')

        image_path = os.path.join(image_folder, f"{parent_asin}_MAIN.jpg")
        if not os.path.exists(image_path):
            stats["no_image_items"] += 1
            continue

        title = item.get('title')
        # Missing, non-string and non-ASCII titles are all counted as "no title".
        if not title or not isinstance(title, str) or not _ASCII_TITLE_RE.match(title):
            stats["no_title_items"] += 1
            continue

        category = item.get('categories')
        if not category:
            stats["no_categories_items"] += 1
            continue

        if isinstance(category, list):
            # Drop a leading category that just names the dataset (e.g.
            # "Toys & Games" for Toys_and_Games) so level 1 is informative.
            if _category_key(category[0]) == root_key:
                category = category[1:]
            if category:
                level_1_counts[category[0]] += 1

        items.append({
            'parent_asin': parent_asin,
            'title': title,
            'category': category,
            # rating_number may be absent or null; 0 keeps the later sort valid.
            'rate': item.get('rating_number') or 0,
        })

    return items, level_1_counts


def _connected_asins(review_data, candidate_asins, stats):
    """Return the candidate items that share a reviewer with another distinct candidate.

    Every review line is counted into ``stats["total_reviews"]``. Repeated
    reviews of the same item by one user do not make that item "connected".
    """
    print('Loading user-item interactions...')
    user_items = defaultdict(set)
    for review in _iter_jsonl(review_data, "Loading reviews"):
        stats["total_reviews"] += 1
        item_id = review.get('parent_asin')
        if item_id in candidate_asins:
            user_items[review.get('user_id')].add(item_id)

    # Every item in a user's set of >= 2 distinct items has a co-purchase
    # neighbour, so the union of those sets is exactly the non-isolated
    # nodes -- no need to materialise O(k^2) pairwise edges per user.
    print('Building item connections...')
    connected = set()
    for items in tqdm(user_items.values(), total=len(user_items), desc="Building connections"):
        if len(items) > 1:
            connected.update(items)
    return connected


def preprocess_dataset(dataset_name, top_categories=10, data_root=None):
    """Build filtered item/review files and index mappings for one dataset.

    Reads ``meta_<name>.jsonl``, ``<name>.jsonl`` and ``images/`` from
    ``<data_root>/<name>`` and generates only those outputs that do not exist:

    * ``<name>_item_new.jsonl``: items with a ``<asin>_MAIN.jpg`` image, an
      ASCII title and categories, whose level-1 category is among the
      ``top_categories`` most frequent, and that share a reviewer with another
      kept item; sorted by ``rate`` (``rating_number``) descending.
    * ``asin2idx.json``: ``{"asin2idx": {...}, "idx2asin": {...}}`` in item
      file order.
    * ``user2idx.json``: index of users with more than one review of a mapped
      item.
    * ``<name>_review.jsonl``: ``{"user_id": user_idx, "parent_asin": item_idx}``
      for each review of a mapped item by a mapped user.

    Args:
        dataset_name: Dataset directory name / file stem.
        top_categories: Number of most frequent level-1 categories to keep.
        data_root: Directory containing the dataset folder; ``None`` means
            ``DEFAULT_DATA_ROOT``.

    Returns:
        The dict of counters collected while processing, or ``None`` when all
        outputs already exist or a required input file is missing.
    """
    print(f"Preprocessing for dataset: {dataset_name}")

    base_dir = os.path.join(data_root or DEFAULT_DATA_ROOT, dataset_name)
    meta_data = os.path.join(base_dir, f"meta_{dataset_name}.jsonl")
    review_data = os.path.join(base_dir, f"{dataset_name}.jsonl")
    image_folder = os.path.join(base_dir, "images")

    item_output_file = os.path.join(base_dir, f"{dataset_name}_item_new.jsonl")
    review_output_file = os.path.join(base_dir, f"{dataset_name}_review.jsonl")
    mapping_file = os.path.join(base_dir, "asin2idx.json")
    user_mapping_file = os.path.join(base_dir, "user2idx.json")

    # Only missing outputs are generated; existing ones are reused.
    need_items = not os.path.exists(item_output_file)
    need_reviews = not os.path.exists(review_output_file)
    need_asin_mapping = not os.path.exists(mapping_file)
    need_user_mapping = not os.path.exists(user_mapping_file)

    if not (need_items or need_reviews or need_asin_mapping or need_user_mapping):
        print(f"All output files already exist for {dataset_name}. Skipping processing.")
        return None

    print("Files that need to be generated:")
    if need_items:
        print(f"  - {dataset_name}_item_new.jsonl")
    if need_reviews:
        print(f"  - {dataset_name}_review.jsonl")
    if need_asin_mapping:
        print("  - asin2idx.json")
    if need_user_mapping:
        print("  - user2idx.json")

    if not os.path.exists(meta_data):
        print(f"Error: Meta data file not found: {meta_data}")
        return None
    if not os.path.exists(review_data):
        print(f"Error: Review data file not found: {review_data}")
        return None
    if not os.path.exists(image_folder):
        print(f"Warning: Image folder not found: {image_folder}")
        print("Creating image folder...")
        os.makedirs(image_folder, exist_ok=True)

    stats = {
        "total_items": 0,
        "no_image_items": 0,
        "no_title_items": 0,
        "no_categories_items": 0,
        "isolated_items": 0,
        "final_items": 0,
        "total_reviews": 0,
        "filtered_reviews": 0,
        "single_purchase_users": 0
    }

    # PHASE 1: item file and ASIN mapping
    asin2idx = {}

    if not need_items and not need_asin_mapping:
        print("Loading existing ASIN mapping...")
        with open(mapping_file, 'r', encoding='utf-8') as f:
            asin2idx = json.load(f).get("asin2idx", {})
    elif not need_items:
        # Items exist but the mapping does not: derive it from the item file.
        print("Loading existing item data to create ASIN mapping...")
        existing_items = list(_iter_jsonl(item_output_file, "Loading items"))

        print("Creating ASIN to index mapping from existing items...")
        asin2idx, idx2asin = _build_asin_mapping(existing_items)
        print(f"Saving ASIN mapping to: {mapping_file}")
        with open(mapping_file, 'w', encoding='utf-8') as f:
            json.dump({"asin2idx": asin2idx, "idx2asin": idx2asin}, f, indent=4)
    else:
        # need_items is always True here: rebuild the item file from metadata.
        processed_items, level_1_counts = _load_candidate_items(
            meta_data, image_folder, dataset_name, stats)

        print("\nMetadata statistics:")
        print(f"  Total items: {stats['total_items']}")
        print(f"  No image items: {stats['no_image_items']}")
        print(f"  No title items: {stats['no_title_items']}")
        print(f"  No categories items: {stats['no_categories_items']}")
        print(f"  Items after initial filtering: {len(processed_items)}")

        top_n_labels = [label for label, _ in level_1_counts.most_common(top_categories)]
        print(f"\nTop {top_categories} Level 1 tags:")
        for label in top_n_labels:
            print(f"  {label}: {level_1_counts[label]} items")

        print("\nFiltering dataset to keep only top categories...")
        top_label_set = set(top_n_labels)
        processed_items_top = []
        processed_items_asin = set()
        for item in tqdm(processed_items, desc="Filtering by category"):
            categories = item.get('category', [])
            if isinstance(categories, list) and categories and categories[0] in top_label_set:
                item['label'] = categories[0]  # explicit class label for downstream use
                processed_items_top.append(item)
                processed_items_asin.add(item.get('parent_asin'))
        print(f"Items after category filtering: {len(processed_items_top)}")

        connected = _connected_asins(review_data, processed_items_asin, stats)
        stats["isolated_items"] = len(processed_items_asin - connected)
        print(f"Total unique connected nodes: {len(connected)}")
        print(f"Isolated items (to be removed): {stats['isolated_items']}")

        kept_items = [
            item for item in tqdm(processed_items_top, desc="Filtering isolated items")
            if item.get('parent_asin') in connected
        ]
        stats["final_items"] = len(kept_items)
        # Most-rated first; sorted() is stable, so ties keep metadata order.
        processed_items_wo_isolation = sorted(
            kept_items, key=lambda x: x.get('rate', 0), reverse=True)
        print(f"Final item count after all filtering: {stats['final_items']}")

        print(f"\nSaving processed items to: {item_output_file}")
        _write_jsonl(item_output_file, processed_items_wo_isolation)

        if need_asin_mapping:
            if processed_items_wo_isolation:
                print("Creating ASIN to index mapping...")
                asin2idx, idx2asin = _build_asin_mapping(processed_items_wo_isolation)
                print(f"Saving ASIN mapping to: {mapping_file}")
                with open(mapping_file, 'w', encoding='utf-8') as f:
                    json.dump({"asin2idx": asin2idx, "idx2asin": idx2asin}, f, indent=4)
        else:
            # The existing mapping was built from an earlier item file and is reused as-is.
            print(f"Warning: keeping existing {mapping_file}; delete it to rebuild it "
                  f"from the new item file.")

    # PHASE 2: user mapping and review file
    user_id_to_idx = {}
    cleaned_review_data = None
    if need_reviews or need_user_mapping:
        if not asin2idx and os.path.exists(mapping_file):
            print("Loading ASIN mapping...")
            with open(mapping_file, 'r', encoding='utf-8') as f:
                asin2idx = json.load(f).get("asin2idx", {})

        if not asin2idx:
            print("Cannot process reviews: ASIN mapping not available.")
            return stats

        # (user_id, item_idx) pairs; only populated when reviews are scanned below.
        raw_review_data = None
        if not need_user_mapping and os.path.exists(user_mapping_file):
            print("Loading existing user mapping...")
            with open(user_mapping_file, 'r', encoding='utf-8') as f:
                user_id_to_idx = json.load(f)
        else:
            user_purchase_count = defaultdict(int)
            raw_review_data = []
            for review in _iter_jsonl(review_data, "Counting user purchases"):
                parent_asin = review.get('parent_asin')
                if parent_asin in asin2idx:
                    user_id = review.get('user_id')
                    user_purchase_count[user_id] += 1
                    raw_review_data.append((user_id, asin2idx[parent_asin]))
                else:
                    stats["filtered_reviews"] += 1

            # Only users with more than one mapped review get an index,
            # numbered in order of first appearance.
            multi_purchase_users = [u for u, n in user_purchase_count.items() if n > 1]
            user_id_to_idx = {u: i for i, u in enumerate(multi_purchase_users)}
            stats["single_purchase_users"] = len(user_purchase_count) - len(user_id_to_idx)

            if need_user_mapping:
                print(f"Saving user ID mapping to: {user_mapping_file}")
                with open(user_mapping_file, 'w', encoding='utf-8') as f:
                    json.dump(user_id_to_idx, f, ensure_ascii=False, indent=4)

        if need_reviews:
            if raw_review_data is not None:
                cleaned_review_data = [
                    {"user_id": user_id_to_idx[user_id], "parent_asin": item_idx}
                    for user_id, item_idx in tqdm(raw_review_data, desc="Filtering single-purchase users")
                    if user_id in user_id_to_idx
                ]
            else:
                # User mapping came from disk, so the reviews still need a scan.
                print("Processing reviews from scratch...")
                cleaned_review_data = []
                unmapped_users = set()
                for review in _iter_jsonl(review_data, "Processing reviews"):
                    parent_asin = review.get('parent_asin')
                    if parent_asin not in asin2idx:
                        stats["filtered_reviews"] += 1
                        continue
                    user_id = review.get('user_id')
                    user_idx = user_id_to_idx.get(user_id)
                    if user_idx is None:
                        unmapped_users.add(user_id)
                    else:
                        cleaned_review_data.append({
                            "user_id": user_idx,
                            "parent_asin": asin2idx[parent_asin]
                        })
                stats["single_purchase_users"] = len(unmapped_users)

            print(f"Saving processed reviews to: {review_output_file}")
            _write_jsonl(review_output_file, cleaned_review_data)

    # Final report
    print("\n--- PREPROCESSING COMPLETE ---")
    print(f"Final statistics for {dataset_name}:")

    if not need_items:
        # Items were not rebuilt in this run; report what is on disk.
        stats["final_items"] = get_file_line_count(item_output_file)

    if not need_reviews:
        review_count = get_file_line_count(review_output_file)
    else:
        review_count = len(cleaned_review_data) if cleaned_review_data is not None else 0

    print(f"  Items processed: {stats['total_items'] or 'N/A'}")
    print(f"  Final items kept: {stats['final_items']}")
    print(f"  Total reviews: {stats['total_reviews'] or 'N/A'}")
    print(f"  Final reviews kept: {review_count}")
    print(f"  Users with multiple purchases: {len(user_id_to_idx) or 'N/A'}")
    print(f"  Single purchase users (filtered): {stats['single_purchase_users']}")

    return stats

def main():
    """Command-line entry point: ``dataset_name [top_categories]``.

    Exits with status 1 on bad usage (missing dataset name, or a
    ``top_categories`` that is not a positive integer) or when
    preprocessing raises.
    """
    usage = "Usage: python preprocess_dataset.py dataset_name [top_categories]"
    if len(sys.argv) < 2:
        print(usage)
        sys.exit(1)

    dataset_name = sys.argv[1]
    top_categories = 10
    if len(sys.argv) > 2:
        try:
            top_categories = int(sys.argv[2])
        except ValueError:
            top_categories = None
        # Zero/negative would silently drop every item in the category filter.
        if top_categories is None or top_categories < 1:
            print(f"Error: top_categories must be a positive integer, got {sys.argv[2]!r}")
            print(usage)
            sys.exit(1)

    try:
        preprocess_dataset(dataset_name, top_categories)
    except Exception as e:
        # Top-level boundary: report with a traceback and signal failure to the shell.
        print(f"Error preprocessing dataset {dataset_name}: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)

# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()