import os
import sys
import pandas as pd
from collections import defaultdict
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
import json

import matplotlib
import matplotlib.pyplot as plt
matplotlib.use('Agg')

def _load_item_categories(item_file_path, parent_asin_to_id):
    """Map each item id to its ``label`` from a JSONL item file.

    Items are keyed by ``parent_asin_to_id[parent_asin]`` when the ASIN is in
    the mapping, otherwise by the raw ``parent_asin``. Items without a
    ``label`` field and blank lines are skipped.
    """
    item_categories = {}
    with open(item_file_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            item = json.loads(line)
            if 'label' not in item:
                continue
            asin = item['parent_asin']
            item_categories[parent_asin_to_id.get(asin, asin)] = item['label']
    return item_categories


def _load_user_interactions(review_file_path, parent_asin_to_id):
    """Group, per user, the ids of the items they reviewed (JSONL review file).

    ``parent_asin`` goes through the same ``parent_asin_to_id`` translation as
    in ``_load_item_categories``, so interaction ids match the category keys.
    Rows whose ``user_id`` is None and blank lines are skipped.
    """
    user_item_interactions = defaultdict(list)
    with open(review_file_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            interaction = json.loads(line)
            user_id = interaction['user_id']
            if user_id is None:
                continue
            asin = interaction['parent_asin']
            user_item_interactions[user_id].append(parent_asin_to_id.get(asin, asin))
    return user_item_interactions


def _build_user_category_matrix(user_item_interactions, item_categories):
    """Return one row per user holding the share of purchases per category.

    Columns are ``user_id`` followed by one column per category. Purchases of
    items with no known category are ignored; users left with no categorised
    purchase are dropped.
    """
    user_category_counts = defaultdict(lambda: defaultdict(int))
    for user_id, items in user_item_interactions.items():
        for item_id in items:
            category = item_categories.get(item_id)
            # `is not None` so that falsy labels such as 0 are still counted.
            if category is not None:
                user_category_counts[user_id][category] += 1

    # Outer keys (users) become columns, so transpose to get users as rows.
    counts = pd.DataFrame(user_category_counts).fillna(0).T
    totals = counts.sum(axis=1)
    keep = totals > 0
    shares = counts[keep].div(totals[keep], axis=0)
    shares.index.name = "user_id"
    return shares.reset_index()


def _describe_clusters(user_category_df, category_columns):
    """Describe each K-Means cluster by its mean purchase share per category.

    Returns a dict mapping ``"Cluster <id>"`` to a sentence that lists every
    category with its mean share, highest first, plus one extra fixed entry.
    """
    # Average only the share columns: `user_id` is typically a string and
    # GroupBy.mean() raises TypeError on non-numeric columns in pandas >= 2.0.
    distribution = user_category_df.groupby("cluster_kmeans")[category_columns].mean()

    cluster_descriptions = {}
    for cluster_id, row in distribution.iterrows():
        ranked = sorted(row.items(), key=lambda pair: pair[1], reverse=True)
        category_probs = [f"{category} ({prob:.2%})" for category, prob in ranked]
        cluster_descriptions[f"Cluster {cluster_id}"] = (
            "Users in this cluster purchase " + ", ".join(category_probs) + "."
        )

    # NOTE(review): not a real cluster — a fixed instruction stored under the
    # next unused cluster id; presumably read by a downstream consumer. Confirm.
    next_id = distribution.index.max() + 1
    cluster_descriptions[f"Cluster {next_id}"] = (
        "These two items represent the same product. Align their features accordingly."
    )
    return cluster_descriptions


def _plot_cluster_counts(counts, method, color, dataset_name, split, out_file):
    """Save a bar chart of the number of users per cluster for one method."""
    plt.figure(figsize=(10, 5))
    counts.plot(kind="bar", color=color, alpha=0.7)
    plt.title(f"User Count per Cluster ({method}) - {dataset_name} ({split})")
    plt.xlabel("Cluster ID")
    plt.ylabel("User Count")
    plt.xticks(rotation=0)
    plt.grid(axis="y", linestyle="--", alpha=0.6)
    plt.savefig(out_file, dpi=300, bbox_inches='tight')
    plt.close()


def run_user_clustering(dataset_name, k=10, split='train', data_root="/home/yqiao47/dataset"):
    """Cluster users by the category mix of their purchases and save reports.

    Reads ``asin2idx.json``, ``{dataset_name}_item_{split}.jsonl`` and
    ``{dataset_name}_review_{split}.jsonl`` from ``data_root/dataset_name``.
    Each user's purchases become per-category purchase shares. The
    standardised shares are clustered with K-Means and with a Gaussian
    mixture. The following files are written into the same directory:

    * ``user_clusters_{k}_{split}.csv`` – shares plus both cluster labels;
    * ``cluster_{k}_description_{split}.json`` – one English summary per
      K-Means cluster plus one extra fixed instruction entry;
    * ``cluster_counts_{k}_{split}.csv`` – users per cluster for both methods;
    * ``kmeans_cluster_distribution_{k}_{split}.png`` and
      ``gmm_cluster_distribution_{k}_{split}.png`` – bar charts of the counts.

    Args:
        dataset_name: Sub-directory of ``data_root`` and file-name prefix.
        k: Number of clusters for both K-Means and the GMM.
        split: Split name embedded in input and output file names.
        data_root: Directory holding one sub-directory per dataset.

    Raises:
        ValueError: If no user bought any item that has a ``label``.
    """
    print(f"Running user clustering for dataset: {dataset_name}, split: {split}")

    output_dir = os.path.join(data_root, dataset_name)
    parent_asin_to_id_file = os.path.join(output_dir, "asin2idx.json")
    item_file_path = os.path.join(output_dir, f"{dataset_name}_item_{split}.jsonl")
    review_file_path = os.path.join(output_dir, f"{dataset_name}_review_{split}.jsonl")

    with open(parent_asin_to_id_file, "r", encoding="utf-8") as f:
        parent_asin_to_id = json.load(f)["asin2idx"]

    item_categories = _load_item_categories(item_file_path, parent_asin_to_id)

    print("Loading user interactions...")
    user_item_interactions = _load_user_interactions(review_file_path, parent_asin_to_id)

    print("Counting user purchases by category...")
    user_category_df = _build_user_category_matrix(user_item_interactions, item_categories)
    if user_category_df.empty:
        raise ValueError(
            f"No user in {review_file_path} purchased a labelled item from "
            f"{item_file_path}; nothing to cluster."
        )

    # Standardise the share columns so every category weighs equally.
    X = user_category_df.drop(columns=["user_id"])
    category_columns = list(X.columns)
    X_scaled = StandardScaler().fit_transform(X)

    print(f"Running K-Means clustering with k={k}...")
    kmeans = KMeans(n_clusters=k, random_state=42, n_init=10)
    user_category_df["cluster_kmeans"] = kmeans.fit_predict(X_scaled)

    print(f"Running GMM clustering with k={k}...")
    gmm = GaussianMixture(n_components=k, random_state=42)
    user_category_df["cluster_gmm"] = gmm.fit_predict(X_scaled)

    output_file = os.path.join(output_dir, f"user_clusters_{k}_{split}.csv")
    user_category_df.to_csv(output_file, index=False)
    print(f"Clustering results saved as {output_file}")

    cluster_descriptions = _describe_clusters(user_category_df, category_columns)

    print("\n==== Cluster Purchase Probabilities Summary ====\n")
    for cluster_name, description in cluster_descriptions.items():
        print(f"{cluster_name}: {description}")

    with open(os.path.join(output_dir, f"cluster_{k}_description_{split}.json"), "w") as f:
        json.dump(cluster_descriptions, f, indent=2)

    kmeans_cluster_counts = user_category_df["cluster_kmeans"].value_counts().sort_index()
    gmm_cluster_counts = user_category_df["cluster_gmm"].value_counts().sort_index()

    # The two label sets need not cover the same cluster ids; fill gaps with 0
    # and cast back so the CSV holds integer counts rather than floats.
    cluster_counts_df = pd.DataFrame({
        "K-Means": kmeans_cluster_counts,
        "GMM": gmm_cluster_counts,
    }).fillna(0).astype(int)

    counts_file = os.path.join(output_dir, f"cluster_counts_{k}_{split}.csv")
    cluster_counts_df.to_csv(counts_file, index=True)

    _plot_cluster_counts(
        kmeans_cluster_counts, "K-Means", "blue", dataset_name, split,
        os.path.join(output_dir, f"kmeans_cluster_distribution_{k}_{split}.png"),
    )
    _plot_cluster_counts(
        gmm_cluster_counts, "GMM", "red", dataset_name, split,
        os.path.join(output_dir, f"gmm_cluster_distribution_{k}_{split}.png"),
    )

    print(f"Results saved for {split} split.")

def main():
    """Command-line entry point: ``script.py DATASET_NAME [K] [SPLIT]``.

    K defaults to 10 and SPLIT to "train"; both are forwarded to
    ``run_user_clustering``. Exits with a usage message (non-zero status)
    when DATASET_NAME is missing or K is not a positive integer.
    """
    prog = os.path.basename(sys.argv[0]) if sys.argv else "user_clustering"
    usage = f"usage: {prog} DATASET_NAME [K] [SPLIT]"

    if len(sys.argv) < 2:
        sys.exit(usage)
    dataset_name = sys.argv[1]

    try:
        k = int(sys.argv[2]) if len(sys.argv) > 2 else 10
    except ValueError:
        sys.exit(f"{usage}\nK must be an integer, got {sys.argv[2]!r}")
    if k < 1:
        sys.exit(f"{usage}\nK must be a positive integer, got {k}")

    split = sys.argv[3] if len(sys.argv) > 3 else "train"

    run_user_clustering(dataset_name, k, split)


if __name__ == "__main__":
    main()