import json
import os
import sys

def combine_feature_factor(
    feature_file='RPM_CODE/GOQA/result/goqa_only_feature.json',
    factor_file='RPM_CODE/GOQA/result/goqa_only_factor.json',
    output_file='RPM_CODE/GOQA/result/goqa_feature.json',
):
    """Annotate every feature with the factor names it belongs to and save the result.

    Reads the per-user feature list and the per-user factor clustering, adds a
    ``"factor"`` list to each feature dict, writes the annotated users to
    ``output_file`` as indented JSON, then runs ``verify_data_consistency`` on
    the three files.

    Args:
        feature_file: Path of the JSON file holding a list of users, each with
            ``"user_id"`` and ``"profile"`` (items with ``"item_id"`` and an
            optional ``"feature"`` list).
        factor_file: Path of the JSON file holding a list of users, each with
            ``"user_id"`` and ``["factorization"]["factors"]``.
        output_file: Path the annotated JSON is written to.

    A feature that matches no factor is labelled ``["Unclassified"]``. The same
    label is used for every feature of a user that has no entry in the factor
    file, instead of aborting the whole run.
    """
    with open(feature_file, 'r', encoding='utf-8') as f:
        feature_data = json.load(f)

    with open(factor_file, 'r', encoding='utf-8') as f:
        factor_data = json.load(f)

    # Index factor records by stringified user id so each lookup is O(1).
    # setdefault keeps the first record for a duplicated id, matching a
    # first-match-wins linear scan.
    factors_by_user = {}
    for factor_user in factor_data:
        factors_by_user.setdefault(str(factor_user["user_id"]), factor_user)

    # The loaded feature data is annotated in place; it is not used for
    # anything else afterwards, so no copy is needed.
    result_data = feature_data

    users_without_factors = []

    for user in result_data:
        user_id = user["user_id"]

        user_factor = factors_by_user.get(str(user_id))
        if user_factor is None:
            # No clustering for this user: an empty mapping makes every
            # feature fall through to "Unclassified".
            users_without_factors.append(user_id)
            clusters = {}
        else:
            clusters = user_factor["factorization"]["factors"]

        for profile_item in user["profile"]:
            if "feature" not in profile_item:
                continue

            for feature in profile_item["feature"]:
                all_factors = find_all_factors(feature, profile_item["item_id"], clusters)
                feature["factor"] = all_factors if all_factors else ["Unclassified"]

    if users_without_factors:
        print(
            f"Warning: {len(users_without_factors)} users have no factor entry; "
            "their features were marked Unclassified"
        )

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(result_data, f, indent=2)

    print(f"Combined data saved to {output_file}")

    verify_data_consistency(feature_file, factor_file, output_file)

def find_all_factors(feature, item_id, clusters):
    """Collect the names of every factor whose members include the given feature.

    A member matches when both its ``"item_id"`` equals ``item_id`` and its
    ``"feature_name"`` equals ``feature["feature_name"]``. Regular factors are
    scanned in the mapping's iteration order and contribute their name once per
    matching member. The special ``"uncovered_features"`` entry is checked
    last and contributes a single ``"Uncovered"`` label if any member matches.

    Returns:
        A list of factor names (possibly empty, possibly with repeats).
    """
    def same_feature(entry):
        # item_id is compared first so the feature name is only read for
        # entries belonging to the requested item.
        return (entry["item_id"] == item_id and
                entry["feature_name"] == feature["feature_name"])

    labels = [
        name
        for name, members in clusters.items()
        if name != "uncovered_features"
        for member in members
        if same_feature(member)
    ]

    # any() stops at the first hit, so "Uncovered" is added at most once.
    if "uncovered_features" in clusters and any(
        same_feature(entry) for entry in clusters["uncovered_features"]
    ):
        labels.append("Uncovered")

    return labels

def verify_data_consistency(feature_file, factor_file, feature_factor_file):
    """Load the three files and print a consistency report.

    Args:
        feature_file: Path of the original per-user feature JSON.
        factor_file: Path of the per-user factor clustering JSON.
        feature_factor_file: Path of the combined (annotated) JSON to verify.

    The report covers:
      * whether the three files contain the same set of users,
      * how many features lack a ``"factor"`` attribute or carry a non-list one,
      * how many features carry factor labels that differ from what the factor
        file implies,
      * the ten most frequently used factor labels.

    User ids are compared as strings, the same way the combine step matches
    them, so ``1`` and ``"1"`` refer to the same user. A user with no entry in
    the factor file is counted and reported rather than aborting the check.
    Nothing is returned; all results are printed.
    """
    print("\nStarting verification: Checking data consistency...")

    with open(feature_file, 'r', encoding='utf-8') as f:
        feature_data = json.load(f)

    with open(factor_file, 'r', encoding='utf-8') as f:
        factor_data = json.load(f)

    with open(feature_factor_file, 'r', encoding='utf-8') as f:
        feature_factor_data = json.load(f)

    feature_users = {str(user["user_id"]) for user in feature_data}
    factor_users = {str(user["user_id"]) for user in factor_data}
    feature_factor_users = {str(user["user_id"]) for user in feature_factor_data}

    print(f"User count verification: feature={len(feature_users)}, factor={len(factor_users)}, feature_factor={len(feature_factor_users)}")

    if feature_users != factor_users or feature_users != feature_factor_users:
        print("Warning: User sets do not match")

    missing_factor = 0
    total_features = 0
    list_type_check = 0

    for user in feature_factor_data:
        for profile_item in user["profile"]:
            if "feature" in profile_item:
                for feature in profile_item["feature"]:
                    total_features += 1
                    if "factor" not in feature:
                        missing_factor += 1
                    elif not isinstance(feature["factor"], list):
                        list_type_check += 1

    print(f"Factor attribute verification: {missing_factor} out of {total_features} features missing factor attribute")
    print(f"Factor type verification: {list_type_check} factors are not list type (all should be list type)")

    # Index factor records by stringified user id; setdefault keeps the first
    # record when an id is duplicated.
    factors_by_user = {}
    for factor_user in factor_data:
        factors_by_user.setdefault(str(factor_user["user_id"]), factor_user)

    factor_mismatch = 0
    users_without_factors = 0
    factor_counts = {}

    for user in feature_factor_data:
        user_factor = factors_by_user.get(str(user["user_id"]))
        if user_factor is None:
            # Nothing to compare against for this user; labels are still counted.
            users_without_factors += 1
            clusters = {}
        else:
            clusters = user_factor["factorization"]["factors"]

        # (item_id, feature_name) -> factor names, one entry per occurrence.
        factor_map = {}
        for factor_name, features_in_factor in clusters.items():
            if factor_name == "uncovered_features":
                continue

            for feature in features_in_factor:
                key = (feature["item_id"], feature["feature_name"])
                factor_map.setdefault(key, []).append(factor_name)

        # find_all_factors appends a single "Uncovered" label for features
        # listed under "uncovered_features", so the expectation must too.
        uncovered_keys = {
            (entry["item_id"], entry["feature_name"])
            for entry in clusters.get("uncovered_features", ())
        }

        for profile_item in user["profile"]:
            if "feature" not in profile_item:
                continue

            for feature in profile_item["feature"]:
                if "factor" not in feature:
                    continue

                key = (profile_item["item_id"], feature["feature_name"])

                for label in feature["factor"]:
                    factor_counts[label] = factor_counts.get(label, 0) + 1

                expected_factors = list(factor_map.get(key, ()))
                if key in uncovered_keys:
                    expected_factors.append("Uncovered")

                # Features unknown to the factor file, or explicitly
                # "Unclassified", are not compared.
                if expected_factors and "Unclassified" not in feature["factor"]:
                    if sorted(expected_factors) != sorted(feature["factor"]):
                        factor_mismatch += 1

    if users_without_factors:
        print(f"Warning: {users_without_factors} users have no factor entry; their features were not checked against factors")

    print(f"Factor matching verification: {factor_mismatch} features matched to different factors than expected")

    print("\nFactor usage counts (top 10):")
    for factor, count in sorted(factor_counts.items(), key=lambda x: x[1], reverse=True)[:10]:
        print(f"  {factor}: {count} times")

    print("\nVerification complete")

# Run the combine step only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    combine_feature_factor()