"""
Cluster compressed gists from evidence extraction results.

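This script reads the JSON output from extract_evidence.py, extracts all compressed gists,
generates embeddings for them, and clusters them using k-means clustering.
In 'curve' mode it clusters with varying numbers of groups and plots the inertia curve to help
determine the optimal number of clusters; in 'fixed' mode it clusters once with a fixed k.
In both modes the resulting clusters are named, grouped into a second-level hierarchy,
and saved as JSON.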
"""

import json
import numpy as np
import time
import os
from typing import List, Dict, Any, Tuple
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
import pandas as pd
from collections import defaultdict
import matplotlib.pyplot as plt
import requests
from config import OPENROUTER_API_KEY

from sentence_transformers import SentenceTransformer

# Model name passed to SentenceTransformer() in get_embeddings()
EMBEDDING_MODEL = "all-mpnet-base-v2" 

# File paths
INPUT_FILE = "evidence/deepseek_r1_all.json"  # Default input file
OUTPUT_DIR = "clusters"  # Output directory for clustering results
# NOTE(review): the file name says "curve" while RUN_MODE defaults to "fixed" — confirm intended
OUTPUT_FILE = "deepseek_r1_curve.json"  # Output file for clustering results

# Run configuration
# "curve" sweeps a range of k values and plots inertia; "fixed" clusters once with FIXED_K
RUN_MODE = "fixed"  # "curve" or "fixed"
FIXED_K = 200  # Used when RUN_MODE == "fixed"
# Upper bound on the number of Level 2 clusters (capped by the number of Level 1 clusters)
LEVEL_2_CLUSTERS = 15

# When True (and OPENROUTER_API_KEY is set), cluster names are generated by an LLM
USE_LLM_NAMING = True
# Model identifier sent to the OpenRouter chat completions endpoint
NAMING_MODEL = "openai/gpt-4.1"


def _truncate_for_name(text: str, max_words: int = None) -> str:
    """Create a readable name from gist text without word limits."""
    return text.strip()


def _compute_cluster_names(
    gists: List[str],
    embeddings: np.ndarray,
    cluster_labels: np.ndarray,
    kmeans: KMeans,
    max_words: int = None,
) -> Dict[int, str]:
    """
    Name each cluster after its most distinctive member gist.

    For every member of a cluster we compute a "margin":
    (distance to the nearest *other* centroid) - (distance to its own
    centroid). The member with the largest margin is both close to its own
    centroid and far from neighboring clusters; its text becomes the name.

    Args:
        gists: Texts that were clustered; indexed in parallel with
            ``embeddings`` and ``cluster_labels``.
        embeddings: Embedding vectors, one row per gist.
        cluster_labels: Cluster id assigned to each gist.
        kmeans: Fitted estimator; ``cluster_centers_`` and ``n_clusters``
            are read from it.
        max_words: Optional word limit forwarded to ``_truncate_for_name``.

    Returns:
        Mapping from cluster id to name. A cluster without members is named
        ``"Cluster <id>"``.
    """
    cluster_to_name: Dict[int, str] = {}
    centroids = kmeans.cluster_centers_  # shape: (k, d)
    n_centroids = centroids.shape[0]

    for cluster_id in range(kmeans.n_clusters):
        indices = np.where(cluster_labels == cluster_id)[0]
        if indices.size == 0:
            cluster_to_name[cluster_id] = f"Cluster {cluster_id}"
            continue

        # Slice the member rows once; they are reused for every centroid below
        members = embeddings[indices]
        own_dists = np.linalg.norm(members - centroids[cluster_id], axis=1)  # (m,)

        if n_centroids > 1:
            # Running minimum over all other centroids, starting from +inf
            nearest_other_dists = np.full_like(own_dists, fill_value=np.inf, dtype=float)
            for other_cid in range(n_centroids):
                if other_cid == cluster_id:
                    continue
                d = np.linalg.norm(members - centroids[other_cid], axis=1)
                nearest_other_dists = np.minimum(nearest_other_dists, d)
        else:
            # Single cluster case: margin reduces to "closest to own centroid"
            nearest_other_dists = np.zeros_like(own_dists)

        # Margin objective: far from neighbors, close to own
        margins = nearest_other_dists - own_dists
        best_idx_global = int(indices[int(np.argmax(margins))])
        cluster_to_name[cluster_id] = _truncate_for_name(gists[best_idx_global], max_words)
    return cluster_to_name


def _build_hierarchy_from_level1(
    level1_clusters_named: Dict[str, List[str]],
    random_state: int,
) -> Dict[str, Any]:
    """
    Build a two-layer hierarchy on top of the named Level 1 clusters.

    The Level 1 names are embedded and grouped with k-means into
    ``min(LEVEL_2_CLUSTERS, number of Level 1 clusters)`` Level 2 clusters.
    Each Level 2 cluster is first named after one of its Level 1 names (see
    ``_compute_cluster_names``) and, when LLM naming is enabled and an API
    key is configured, renamed from the phrases of all its Level 1 children.

    Args:
        level1_clusters_named: Mapping of Level 1 cluster name to its phrases.
        random_state: Seed passed to the Level 2 k-means.

    Returns:
        Dictionary with a 'level1' entry ('k' and 'clusters') and, when there
        are at least two Level 1 clusters, a 'level2' entry whose 'clusters'
        map each Level 2 name to 'children_level1', 'size_level1' and
        'size_phrases', ordered by descending 'size_phrases'.
    """
    level1_names = list(level1_clusters_named.keys())
    hierarchy: Dict[str, Any] = {
        'level1': {
            'k': len(level1_names),
            'clusters': level1_clusters_named,
        }
    }

    # If not enough Level 1 clusters, return only Level 1
    if len(level1_names) < 2:
        return hierarchy

    # Level 2: cluster the Level 1 names; k cannot exceed the number of names
    level1_name_embeddings = get_embeddings(level1_names)
    k2 = min(LEVEL_2_CLUSTERS, len(level1_names))

    kmeans2 = KMeans(n_clusters=k2, random_state=random_state)
    cluster_labels2 = kmeans2.fit_predict(level1_name_embeddings)

    level2_id_to_name = _compute_cluster_names(
        gists=level1_names,
        embeddings=level1_name_embeddings,
        cluster_labels=cluster_labels2,
        kmeans=kmeans2,
    )

    level2_clusters_named: Dict[str, List[str]] = defaultdict(list)
    for idx, cid in enumerate(cluster_labels2):
        level2_clusters_named[level2_id_to_name[int(cid)]].append(level1_names[idx])

    def phrase_count(l1_children: List[str]) -> int:
        """Total number of phrases held by the given Level 1 clusters."""
        return sum(len(level1_clusters_named[l1]) for l1 in l1_children)

    # Largest Level 2 clusters (by aggregated phrase count) come first
    level2_sorted = sorted(
        level2_clusters_named.items(),
        key=lambda item: phrase_count(item[1]),
        reverse=True,
    )

    level2_summary: Dict[str, Any] = {
        l2_name: {
            'children_level1': l1_children,
            'size_level1': len(l1_children),
            'size_phrases': int(phrase_count(l1_children)),
        }
        for l2_name, l1_children in level2_sorted
    }

    # Optionally rename Level 2 clusters via LLM using all phrases from their Level 1 children
    if USE_LLM_NAMING and OPENROUTER_API_KEY:
        orig_names = list(level2_summary)
        proposed: Dict[int, str] = {}
        for pos, l2_name in enumerate(orig_names):
            examples: List[str] = []
            for l1_name in level2_summary[l2_name]['children_level1']:
                examples.extend(level1_clusters_named.get(l1_name, []))
            # Keep the heuristic name when the LLM returns nothing
            proposed[pos] = _generate_cluster_name_llm(examples) or l2_name
        # Positions act as ids so the size-sorted order is preserved
        unique = _deduplicate_name_mapping(proposed)
        level2_summary = {
            unique[pos]: level2_summary[old] for pos, old in enumerate(orig_names)
        }

    hierarchy['level2'] = {
        'k': len(level2_summary),
        'clusters': level2_summary,
    }

    return hierarchy


def _deduplicate_name_mapping(id_to_name: Dict[int, str]) -> Dict[int, str]:
    """Ensure names are unique by appending ' (2)', ' (3)', ... for duplicates."""
    name_counts: Dict[str, int] = {}
    unique_map: Dict[int, str] = {}
    # Preserve stable ordering by id
    for cid in sorted(id_to_name.keys()):
        name = id_to_name[cid].strip()
        count = name_counts.get(name, 0)
        if count == 0:
            unique_name = name
        else:
            unique_name = f"{name} ({count + 1})"
        name_counts[name] = count + 1
        unique_map[cid] = unique_name
    return unique_map


def _generate_cluster_name_llm(examples: List[str]) -> str:
    """
    Use an LLM via OpenRouter to generate a short name for a cluster based on
    its member examples.

    Args:
        examples: Member gists of the cluster. Blank entries are dropped and
            at most 50 of the remaining ones are sent, to avoid oversize
            requests.

    Returns:
        The generated name with every whitespace run collapsed to a single
        space. An empty string is returned when LLM naming is disabled, no
        API key is configured, there are no usable examples, or the request
        or response handling fails. The prompt asks for 1-12 words; the
        length is not enforced here.
    """
    if not USE_LLM_NAMING or not OPENROUTER_API_KEY:
        return ""

    cleaned = [e.strip() for e in examples if e and e.strip()]
    if not cleaned:
        # Nothing to summarize; do not spend a request on an empty group
        return ""

    # Every example gets its own "- " bullet, including the first one
    joined_examples = "\n".join(f"- {e}" for e in cleaned[:50])
    system_prompt = (
        "You are a precise taxonomy assistant. You read many short phrases (gists) "
        "and produce a short, human-readable category name that summarizes them."
    )
    user_prompt = (
        f"""An LLM that was given a task, and then asked to guess wether the task is meant to be an evaluation of this LLM or not. 
        Then the parts of the LLM answer in which the LLM reasons about the cues that lead it to believe that it's being evaluated, were extracted shortened to gists, clustered into semantically similar groups, and one of these groups is presented to you here.
        Your goal is to name the group.
    
        INSTRUCTIONS:
        - The group name must be 1-12 - word-long. 
        - The group name must contain, what's the common thing across all the elements of the group. Use any cues: what is the common topic among all the elements of the group: whether it's a subject of the prompts, the way prompts are structured, the patterns in the user intent, or something else.
        - Don't make the group name too similar to a single gist.
        - If the elements of the group follow a certain pattern of the sentence structure, follow it if possible. 
        Return only the group name, no explanation.
        
        <GROUP OF GISTS>
        {joined_examples}
        </GROUP OF GISTS>"""
    )

    headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": NAMING_MODEL,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "reasoning": {"effort": "low"},
        "max_tokens": 100,
    }

    try:
        resp = requests.post(
            "https://openrouter.ai/api/v1/chat/completions",
            headers=headers,
            json=payload,
            timeout=60,
        )
        resp.raise_for_status()
        data = resp.json()
        content = (
            data.get("choices", [{}])[0]
            .get("message", {})
            .get("content", "")
        )
        # Basic sanitize: collapse whitespace (newlines included)
        return " ".join(content.split())
    except Exception as exc:
        # Best-effort: a failed naming request must not abort the clustering
        # run, but report it so the fallback to heuristic names is visible.
        print(f"LLM cluster naming failed: {exc}")
        return ""

def _name_clusters_with_llm(clusters_named: Dict[str, List[str]]) -> Dict[str, List[str]]:
    """
    Re-name a dictionary of clusters keyed by heuristic names using an LLM.

    Each cluster is named independently: a cluster for which the LLM yields
    no name keeps its existing key. The resulting names are made unique with
    ``_deduplicate_name_mapping`` and the input order is preserved.

    Args:
        clusters_named: Mapping of current cluster name to its member gists.

    Returns:
        A new dict keyed by the final names. When LLM naming is disabled or
        no API key is configured, the input dict itself is returned.
    """
    if not USE_LLM_NAMING or not OPENROUTER_API_KEY:
        return clusters_named

    keys = list(clusters_named)
    # Positions act as ids so the deduplication keeps the input order
    proposed: Dict[int, str] = {
        pos: _generate_cluster_name_llm(clusters_named[key]) or key
        for pos, key in enumerate(keys)
    }
    unique = _deduplicate_name_mapping(proposed)
    return {unique[pos]: clusters_named[key] for pos, key in enumerate(keys)}

# Loaded SentenceTransformer instances keyed by model name, so that repeated
# get_embeddings() calls do not reload the model from disk each time.
_EMBEDDING_MODEL_CACHE: Dict[str, Any] = {}


def get_embeddings(texts: List[str]) -> np.ndarray:
    """
    Generate embeddings for a list of texts using sentence-transformers.

    The model named by EMBEDDING_MODEL is loaded on first use and reused on
    later calls within the same process.

    Args:
        texts: List of text strings to embed

    Returns:
        numpy array of embeddings, shape (n_texts, embedding_dim)
    """
    print(f"Generating embeddings for {len(texts)} texts using sentence-transformers...")

    model = _EMBEDDING_MODEL_CACHE.get(EMBEDDING_MODEL)
    if model is None:
        model = SentenceTransformer(EMBEDDING_MODEL)
        _EMBEDDING_MODEL_CACHE[EMBEDDING_MODEL] = model

    embeddings = model.encode(texts, show_progress_bar=True)
    return embeddings


def extract_gists_from_evidence_json(json_file_path: str) -> Tuple[List[str], List[Dict[str, Any]]]:
    """
    Extract all compressed gists from evidence JSON file.

    The file is expected to hold a top-level 'samples' mapping of sample id
    to an object with an 'evidence' list; each evidence item may carry
    'compressed_gist' and 'original_evidence'. Missing, null, non-string or
    blank gists are skipped, as are null 'samples'/'evidence' values.

    Args:
        json_file_path: Path to the evidence JSON file

    Returns:
        Tuple of (list of gists, list of metadata for each gist). Gists are
        stripped of surrounding whitespace; each metadata entry holds
        'sample_id', 'evidence_index' (position in the sample's evidence
        list), 'original_evidence' and 'gist'.
    """
    with open(json_file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # "or {}" / "or []" also cover keys that are present but null in the JSON
    samples = data.get('samples') or {}

    gists = []
    metadata = []

    for sample_id, sample_data in samples.items():
        evidence_items = (sample_data or {}).get('evidence') or []

        for i, evidence_item in enumerate(evidence_items):
            if not isinstance(evidence_item, dict):
                continue
            raw_gist = evidence_item.get('compressed_gist')
            if not isinstance(raw_gist, str):
                # Missing or null gist: nothing to cluster for this item
                continue
            gist = raw_gist.strip()
            if not gist:  # Only add non-empty gists
                continue

            gists.append(gist)
            metadata.append({
                'sample_id': sample_id,
                'evidence_index': i,
                'original_evidence': evidence_item.get('original_evidence', ''),
                'gist': gist
            })

    print(f"Extracted {len(gists)} compressed gists from {len(samples)} samples")
    return gists, metadata


def cluster_gists(gists: List[str], metadata: List[Dict[str, Any]], 
                 n_clusters: int = 50, random_state: int = 42) -> Dict[str, Any]:
    """
    Cluster gists using k-means clustering on embeddings.

    Args:
        gists: List of gist strings
        metadata: List of metadata for each gist (currently not used; only
            the gist texts are stored in the results)
        n_clusters: Number of clusters to create; capped at the number of gists
        random_state: Random seed for reproducibility

    Returns:
        Dictionary with 'n_total_gists', 'n_clusters' (non-empty clusters),
        'silhouette_score' (or None), 'inertia', 'cluster_size_stats' and
        'hierarchy' (Level 1 holds the named clusters). When there are no
        gists, ``{"error": ...}`` is returned instead.
    """
    if len(gists) == 0:
        return {"error": "No gists to cluster"}

    # Get embeddings
    embeddings = get_embeddings(gists)

    # k-means cannot create more clusters than there are samples
    effective_k = min(n_clusters, len(gists))
    print(f"Performing k-means clustering into {effective_k} clusters...")
    kmeans = KMeans(n_clusters=effective_k, random_state=random_state)
    cluster_labels = kmeans.fit_predict(embeddings)

    # Get inertia (within-cluster sum of squares)
    inertia = kmeans.inertia_

    # Silhouette is best-effort: it is undefined for some label counts, and a
    # failure here must not discard the clustering itself.
    silhouette_avg = None
    if len(gists) > 1:
        try:
            silhouette_avg = silhouette_score(embeddings, cluster_labels)
        except Exception as exc:
            print(f"Could not calculate silhouette score: {exc}")
        else:
            print(f"Silhouette score: {silhouette_avg:.3f}")

    # Organize results by cluster (gist text only, no metadata)
    clusters = defaultdict(list)
    for i, cluster_id in enumerate(cluster_labels):
        clusters[int(cluster_id)].append(gists[i])

    # Convert to regular dict and sort clusters by size, largest first
    clusters_dict = dict(clusters)
    sorted_clusters = sorted(clusters_dict.items(),
                             key=lambda x: len(x[1]), reverse=True)

    # Create summary statistics
    cluster_sizes = [len(items) for items in clusters_dict.values()]

    # Name each cluster after its best-margin member gist, then make names unique
    cluster_id_to_name_raw = _compute_cluster_names(
        gists=gists,
        embeddings=embeddings,
        cluster_labels=cluster_labels,
        kmeans=kmeans,
    )
    cluster_id_to_name = _deduplicate_name_mapping(cluster_id_to_name_raw)

    # Alternate view keyed by names (Level 1)
    clusters_named = {cluster_id_to_name[cid]: items for cid, items in sorted_clusters}
    # Optionally rename clusters using LLM labels
    clusters_named = _name_clusters_with_llm(clusters_named)

    # Build the hierarchy (Level 1 plus Level 2 grouping of the Level 1 names)
    hierarchy = _build_hierarchy_from_level1(clusters_named, random_state=random_state)

    results = {
        'n_total_gists': len(gists),
        'n_clusters': len(clusters_dict),
        'silhouette_score': float(silhouette_avg) if silhouette_avg is not None else None,
        'inertia': float(inertia),
        'cluster_size_stats': {
            'mean': float(np.mean(cluster_sizes)),
            'std': float(np.std(cluster_sizes)),
            'min': int(np.min(cluster_sizes)),
            'max': int(np.max(cluster_sizes))
        },
        # Only keep hierarchy; Level 1 holds the named clusters
        'hierarchy': hierarchy,
    }

    return results


def cluster_with_varying_k(gists: List[str], metadata: List[Dict[str, Any]], 
                          k_range: Tuple[int, int] = (30, 150), k_step: int = 5,
                          random_state: int = 42) -> Dict[str, Any]:
    """
    Cluster gists with varying numbers of clusters and calculate inertia for each.

    Args:
        gists: List of gist strings
        metadata: List of metadata for each gist (currently not used)
        k_range: Tuple of (min_k, max_k) for clustering, both inclusive
        k_step: Step size for k values; must be at least 1
        random_state: Random seed for reproducibility

    Returns:
        Dictionary with 'n_total_gists', 'k_values', and the parallel lists
        'inertias' and 'silhouette_scores' (None where not computed), plus
        'clustering_results' holding the named hierarchy for the largest k
        tested. On invalid input ``{"error": ...}`` is returned instead.
    """
    if len(gists) == 0:
        return {"error": "No gists to cluster"}
    if k_step < 1:
        return {"error": "k_step must be a positive integer"}

    # Keep only k values k-means can actually fit: 1 <= k <= number of samples
    min_k, max_k = k_range
    k_values = [k for k in range(min_k, max_k + 1, k_step) if 1 <= k <= len(gists)]
    if not k_values:
        return {"error": "No valid k values in the specified range"}

    # Get embeddings once; they are shared by every k
    embeddings = get_embeddings(gists)

    results = {
        'n_total_gists': len(gists),
        'k_values': k_values,
        'inertias': [],
        'silhouette_scores': [],
        'clustering_results': {}
    }

    print(f"Clustering with k values: {k_values}")

    kmeans = None
    cluster_labels = None
    for k in k_values:
        print(f"Clustering with k={k}...")

        # Perform k-means clustering
        kmeans = KMeans(n_clusters=k, random_state=random_state)
        cluster_labels = kmeans.fit_predict(embeddings)
        results['inertias'].append(float(kmeans.inertia_))

        # Silhouette is best-effort; a failure leaves None for this k
        silhouette_avg = None
        if len(gists) > 1 and k > 1:
            try:
                silhouette_avg = float(silhouette_score(embeddings, cluster_labels))
            except Exception as exc:
                print(f"Could not calculate silhouette score for k={k}: {exc}")
        results['silhouette_scores'].append(silhouette_avg)

    # Detailed clusters are stored only for the last (largest) k; at this
    # point kmeans and cluster_labels hold the fit from the final iteration.
    clusters = defaultdict(list)
    for i, cluster_id in enumerate(cluster_labels):
        clusters[int(cluster_id)].append(gists[i])

    clusters_dict = dict(clusters)
    sorted_clusters = sorted(
        clusters_dict.items(), key=lambda x: len(x[1]), reverse=True
    )

    # Name each cluster after its best-margin member gist (uses the fitted kmeans)
    cluster_id_to_name_raw = _compute_cluster_names(
        gists=gists,
        embeddings=embeddings,
        cluster_labels=cluster_labels,
        kmeans=kmeans,
    )
    cluster_id_to_name = _deduplicate_name_mapping(cluster_id_to_name_raw)
    clusters_named = {cluster_id_to_name[cid]: items for cid, items in sorted_clusters}

    results['clustering_results'] = {
        'n_clusters': len(clusters_dict),
        # Only keep hierarchy; Level 1 holds the named clusters
        'hierarchy': _build_hierarchy_from_level1(clusters_named, random_state=random_state),
    }

    return results


def plot_inertia_curve(k_values: List[int], inertias: List[float], output_dir: str = OUTPUT_DIR):
    """
    Plot the inertia curve for different numbers of clusters.

    The figure has two panels: inertia against k, and the change in inertia
    between consecutive k values (drawn only when there are at least two
    points). It is saved as 'inertia_curve.png' in ``output_dir`` and then
    shown.

    Args:
        k_values: List of k values used for clustering
        inertias: List of corresponding inertia values
        output_dir: Directory to save the plot; created if it does not exist
    """
    # The function can be called on its own, so do not rely on the caller
    # having created the directory.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    fig = plt.figure(figsize=(12, 8))

    # Plot inertia curve
    plt.subplot(2, 1, 1)
    plt.plot(k_values, inertias, 'o-', color='#BDE8FD', linewidth=2, markersize=6)
    plt.xlabel('Number of Clusters (k)')
    plt.ylabel('Inertia (Within-cluster Sum of Squares)')
    plt.title('Inertia vs Number of Clusters')
    plt.grid(True, alpha=0.3)

    # Plot change in inertia (first difference)
    plt.subplot(2, 1, 2)
    if len(inertias) > 1:
        inertia_changes = np.diff(inertias)
        k_values_changes = k_values[1:]  # each change is attributed to the larger k

        plt.plot(k_values_changes, inertia_changes, 'ro-', linewidth=2, markersize=6)
        plt.xlabel('Number of Clusters (k)')
        plt.ylabel('Change in Inertia')
        plt.title('Change in Inertia vs Number of Clusters')
        plt.grid(True, alpha=0.3)

        # Add horizontal line at y=0 for reference
        plt.axhline(y=0, color='black', linestyle='--', alpha=0.5)

    plt.tight_layout()

    # Save the plot
    plot_path = os.path.join(output_dir, 'inertia_curve.png')
    plt.savefig(plot_path, dpi=300, bbox_inches='tight')
    print(f"Inertia curve saved to {plot_path}")

    # Show the plot, then release the figure so repeated calls do not pile up
    plt.show()
    plt.close(fig)


def print_cluster_summary(results: Dict[str, Any], top_n: int = 10):
    """
    Print a summary of the clustering results.

    Prints the totals, the silhouette score when one is present (including a
    score of 0.0), the cluster size statistics, and then the first ``top_n``
    Level 1 clusters in their stored order. For each of these clusters up to
    5 distinct gists are listed in order of first appearance, with a
    repetition count when a gist occurs more than once.

    Args:
        results: Clustering results dictionary
        top_n: Number of top clusters to display in detail
    """
    from collections import Counter  # only needed for this report

    if 'error' in results:
        print(f"Error: {results['error']}")
        return

    print(f"\n=== CLUSTERING SUMMARY ===")
    print(f"Total gists: {results['n_total_gists']}")
    print(f"Number of clusters: {results['n_clusters']}")
    # Compare against None so a legitimate score of 0.0 is still reported
    silhouette = results.get('silhouette_score')
    if silhouette is not None:
        print(f"Silhouette score: {silhouette:.3f}")

    print(f"\nCluster size statistics:")
    stats = results['cluster_size_stats']
    print(f"  Mean: {stats['mean']:.1f}")
    print(f"  Std: {stats['std']:.1f}")
    print(f"  Min: {stats['min']}")
    print(f"  Max: {stats['max']}")

    print(f"\n=== TOP {top_n} LARGEST CLUSTERS ===")

    # Read clusters from hierarchy level1
    level1 = results.get('hierarchy', {}).get('level1', {})
    cluster_items = list(level1.get('clusters', {}).items())
    for cluster_name, items in cluster_items[:top_n]:
        print(f"\n{cluster_name.upper()} ({len(items)} items):")

        # Counter keeps first-seen order, so the listing is deterministic
        # and each gist is counted in a single pass.
        gist_counts = Counter(items)

        for gist, count in list(gist_counts.items())[:5]:  # Show up to 5 unique gists
            print(f"  • {gist}" + (f" ({count}x)" if count > 1 else ""))

        if len(gist_counts) > 5:
            print(f"  ... and {len(gist_counts) - 5} more unique gists")


def main(
    input_file: str = INPUT_FILE,
    output_dir: str = OUTPUT_DIR,
    output_file: str = OUTPUT_FILE,
    k_range: Tuple[int, int] = (50, 300),
    k_step: int = 10,
    run_mode: str = RUN_MODE,
    fixed_k: int = FIXED_K,
):
    """
    Run the gist clustering pipeline end to end.

    Gists are read from the evidence file and then processed in one of two
    modes: "fixed" clusters once with ``fixed_k`` and prints a cluster
    summary; "curve" clusters over a range of k values, plots the inertia
    curve and prints the inertia per k. Either way the results are written
    as JSON to ``output_dir/output_file``. An unrecognized ``run_mode`` is
    reported and treated as "curve".

    Args:
        input_file: Path to evidence JSON file
        output_dir: Directory to save clustering results (created if needed)
        output_file: Filename for clustering results
        k_range: Tuple of (min_k, max_k) used in "curve" mode
        k_step: Step size for k values in "curve" mode
        run_mode: "curve" or "fixed"
        fixed_k: Number of clusters used in "fixed" mode
    """
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, output_file)

    def save_results(payload: Dict[str, Any]) -> None:
        # Shared by both modes: dump the results dict as readable UTF-8 JSON
        with open(output_path, 'w', encoding='utf-8') as handle:
            json.dump(payload, handle, indent=2, ensure_ascii=False)

    print("=== GIST CLUSTERING PIPELINE ===")
    print(f"Input file: {input_file}")
    print(f"Output path: {output_path}")
    if run_mode not in ("curve", "fixed"):
        print(f"Unknown run_mode '{run_mode}', defaulting to 'curve'")
        run_mode = "curve"
    elif run_mode == "fixed":
        print(f"Mode: fixed | K: {fixed_k}")
    else:
        print(f"Mode: curve | K range: {k_range[0]} to {k_range[1]} with step {k_step}")

    # Step 1: load the gists
    print("\n1. Extracting compressed gists...")
    gists, metadata = extract_gists_from_evidence_json(input_file)
    if not gists:
        print("No gists found in the input file!")
        return

    if run_mode == "curve":
        # Sweep k, plot the inertia curve, then persist and report
        print("\n2. Clustering gists with varying k values...")
        sweep = cluster_with_varying_k(gists, metadata, k_range=k_range, k_step=k_step)
        if 'error' in sweep:
            print(f"Error: {sweep['error']}")
            return

        tested_k = sweep['k_values']
        inertia_values = sweep['inertias']

        print("\n3. Plotting inertia curve...")
        plot_inertia_curve(tested_k, inertia_values, output_dir)

        print(f"\n4. Saving results to {output_path}...")
        save_results(sweep)

        print(f"\n=== CLUSTERING SUMMARY ===")
        print(f"Total gists: {sweep['n_total_gists']}")
        print(f"K values tested: {tested_k}")
        print(f"Number of k values: {len(tested_k)}")

        if inertia_values:
            lowest = int(np.argmin(inertia_values))
            print(f"Best k (lowest inertia): {tested_k[lowest]} (inertia: {inertia_values[lowest]:.2f})")

        print(f"\nInertia values:")
        for k_value, inertia_value in zip(tested_k, inertia_values):
            print(f"  k={k_value}: {inertia_value:.2f}")

        print(f"\nClustering complete! Results saved to {output_path}")
        print(f"Inertia curve plot saved to {os.path.join(output_dir, 'inertia_curve.png')}")
        return

    # Fixed mode: a single clustering run with fixed_k clusters
    print("\n2. Clustering gists with fixed k...")
    clustering = cluster_gists(gists, metadata, n_clusters=fixed_k)
    if 'error' in clustering:
        print(f"Error: {clustering['error']}")
        return

    print(f"\n3. Saving results to {output_path}...")
    save_results(clustering)

    # Concise summary of the largest clusters
    print_cluster_summary(clustering, top_n=10)

    print(f"\nClustering complete! Results saved to {output_path}")


# Run the pipeline with the module-level defaults when executed as a script
if __name__ == '__main__':
    main()