# hierarchical_clustering.py
import os
import glob
import numpy as np
import pandas as pd
from sklearn.cluster import DBSCAN
from sklearn.neighbors import NearestNeighbors
import json
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns  

def find_eps(X, min_pts, *, min_eps=0.01, default_eps=1.0):
    """
    Find an appropriate eps for DBSCAN using the k-distance heuristic.

    For every point, the distance to its (min_pts - 1)-th nearest other point
    is computed. NearestNeighbors is queried with the same X it was fitted on,
    so column 0 of the result is each point's zero distance to itself. These
    k-distances are sorted. The "knee" is approximated as the position right
    after the largest jump between consecutive values, and eps is the
    k-distance at that position.

    Parameters:
    - X: np.array, point set (one point per row)
    - min_pts: int, minimum points for DBSCAN (must be >= 1)
    - min_eps: float, lower bound applied to every returned value; keeps eps
      strictly positive, which DBSCAN requires
    - default_eps: float, value returned (before clamping) when X has fewer
      than min_pts points

    Returns:
    - float, estimated eps value, never below min_eps

    Raises:
    - ValueError if min_pts < 1
    """
    if min_pts < 1:
        raise ValueError(f"min_pts must be >= 1, got {min_pts}")

    if len(X) < min_pts:
        # Too few points to form even one DBSCAN neighbourhood; any eps will
        # label everything as noise, so just return the fallback.
        return float(max(default_eps, min_eps))

    # k nearest neighbours, including each point itself at distance 0
    nbrs = NearestNeighbors(n_neighbors=min_pts).fit(X)
    distances, _ = nbrs.kneighbors(X)

    # Sorted distance to the (min_pts-1)-th other point
    k_distances = np.sort(distances[:, min_pts - 1])

    if len(k_distances) < 2:
        # Only reachable with min_pts == 1 and a single point. The distance is
        # then the point to itself (0), so the clamp below is what keeps eps valid.
        eps = k_distances[0]
    else:
        # Knee = index just after the largest consecutive jump
        knee = int(np.argmax(np.diff(k_distances))) + 1
        eps = k_distances[knee]

    # Clamp on every path so eps is never too small (or zero)
    return float(max(eps, min_eps))

def compute_centroids(X, labels):
    """
    Compute the centroid (mean point) of each cluster, excluding noise.

    Clusters are visited in ascending label order, so row i of the result
    belongs to the i-th smallest non-noise label. This keeps the output, and
    any CSV written from it, deterministic.

    Parameters:
    - X: np.array, point set (one point per row)
    - labels: np.array, DBSCAN labels aligned with the rows of X; -1 marks noise

    Returns:
    - np.array of centroids (shape: num_clusters x dim); when there are no
      non-noise labels, an empty array of shape (0, dim)
    """
    X = np.asarray(X)
    labels = np.asarray(labels)

    # Sorted unique cluster ids, noise (-1) excluded
    cluster_ids = np.unique(labels[labels != -1])
    if cluster_ids.size == 0:
        # Keep the trailing dimensions so callers still see (0, dim)
        return np.empty((0,) + X.shape[1:])

    # Every id comes from labels, so each selection is non-empty
    return np.array([X[labels == k].mean(axis=0) for k in cluster_ids])

def arrays_equal(a, b, atol=1e-6):
    """
    Check whether two point sets are equal, ignoring the order of the points.

    Each array is reordered row-wise in lexicographic order (first coordinate
    as the primary key), which keeps every point's coordinates together. The
    reordered arrays are then compared with np.allclose. Because the ordering
    uses exact values, two points whose leading coordinates differ by less
    than atol may be ordered differently in a and b. That can give a false
    "not equal", but never a false "equal".

    Parameters:
    - a, b: array-like point sets (one point per row; 1-D input is treated as
      one coordinate per point)
    - atol: float, absolute tolerance for np.allclose

    Returns:
    - bool, True if both contain the same points (within atol)
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    if a.shape != b.shape:
        return False
    if a.size == 0:
        return True
    if a.ndim == 1:
        a = a.reshape(-1, 1)
        b = b.reshape(-1, 1)

    # np.lexsort uses the LAST key as primary, so reverse the columns to make
    # column 0 the primary key. Sorting whole rows, not each column on its
    # own, is what keeps this an actual point-set comparison.
    a_sorted = a[np.lexsort(a.T[::-1])]
    b_sorted = b[np.lexsort(b.T[::-1])]
    return bool(np.allclose(a_sorted, b_sorted, atol=atol))

def hierarchical_clustering(X, dim, base_filename):
    """
    Perform hierarchical clustering using DBSCAN recursively on centroids.

    At each level:
    - Compute eps via find_eps (k-distance heuristic)
    - Run DBSCAN with computed eps and min_pts = dim + 1
    - Remove noise points (label -1)
    - Compute centroids and save them to
      "<base_filename>_level_<level>_centroids.csv"
    - Repeat on the centroids until a stopping condition holds

    Stopping conditions:
    1. The current point set is empty
    2. Every point at the current level is noise
    3. Only one centroid remains
    4. Centroids unchanged from previous level (position and count)

    After the loop, the per-level summary is written to
    "<base_filename>_hierarchy.json". This happens even when no level produced
    centroids, in which case an empty list is written.

    Parameters:
    - X: array-like, initial point set (one point per row, dim columns)
    - dim: int, dimension; also determines min_pts and the CSV column names
    - base_filename: str, prefix for all output files

    Returns:
    - list of dicts, one per saved level, with keys 'level', 'num_centroids'
      and 'centroids' (list of coordinate lists)
    """
    X = np.asarray(X)

    # Both depend only on dim, so compute them once rather than per level.
    min_pts = dim + 1
    cols = [f"x{i+1}" for i in range(dim)]

    level = 0
    prev_centroids = None
    hierarchy = []  # Track levels with centroids for visualization and JSON

    while True:
        if len(X) == 0:
            # DBSCAN cannot be fitted on zero samples; stop before calling it.
            print(f"Level {level}: No points to cluster. Stopping.")
            break

        eps = find_eps(X, min_pts)
        print(f"Level {level}: Computed eps={eps:.4f}, min_pts={min_pts}")

        labels = DBSCAN(eps=eps, min_samples=min_pts).fit(X).labels_

        # Remove noise
        non_noise_mask = labels != -1
        if not np.any(non_noise_mask):
            print(f"Level {level}: All points are noise. Stopping.")
            break

        centroids = compute_centroids(X[non_noise_mask], labels[non_noise_mask])
        if len(centroids) == 0:
            # Defensive: cannot happen once at least one point is non-noise.
            print(f"Level {level}: No centroids computed. Stopping.")
            break

        centroid_filename = f"{base_filename}_level_{level}_centroids.csv"
        pd.DataFrame(centroids, columns=cols).to_csv(centroid_filename, index=False)
        print(f"Saved {centroid_filename} with {len(centroids)} centroids.")

        hierarchy.append({
            'level': level,
            'num_centroids': len(centroids),
            'centroids': centroids.tolist(),
        })

        if len(centroids) == 1:
            print("Stopping: Only one centroid remains.")
            break
        if prev_centroids is not None and arrays_equal(centroids, prev_centroids):
            print("Stopping: Centroids unchanged from previous level.")
            break

        # This level's centroids become the point set for the next level.
        prev_centroids = centroids.copy()
        X = centroids
        level += 1

    hierarchy_filename = f"{base_filename}_hierarchy.json"
    with open(hierarchy_filename, 'w', encoding='utf-8') as f:
        json.dump(hierarchy, f)
    print(f"Saved hierarchy summary to {hierarchy_filename}")

    return hierarchy

def visualize_hierarchy(X, dim, hierarchy, base_filename):
    """
    Plot the original points and the centroids of every level for 1D, 2D or 3D data.

    Uses a seaborn "whitegrid" style with the "muted" palette. The figure is
    saved to "<base_filename>_visualization.png" and is always closed, even if
    plotting or saving raises. Returns immediately without plotting when
    dim > 3.

    Parameters:
    - X: np.array, original point set (one point per row)
    - dim: int, dimension (1, 2 and 3 are plotted)
    - hierarchy: list of dicts with 'level' and 'centroids' keys
    - base_filename: str, base name for saving visualization
    """
    if dim > 3:
        return  # No visualization for 4D+

    # Set scientific plotting style
    sns.set(style="whitegrid", palette="muted")

    fig = plt.figure(figsize=(8, 6))
    try:
        if dim == 1:
            # 1D: points and centroids on the x-axis only, like a number line,
            # in a wider and flatter figure
            fig.set_size_inches(10, 4)

            plt.scatter(X[:, 0], np.zeros_like(X[:, 0]), label='Points', alpha=0.6, s=20, c='blue')

            # Cycle colors/markers so each level is distinguishable
            colors = ['red', 'green', 'orange', 'purple', 'brown', 'pink']
            markers = ['x', '^', 's', 'D', 'v', 'o']

            for entry in hierarchy:
                level = entry['level']
                cents = np.array(entry['centroids'])
                if len(cents) > 0:
                    color = colors[level % len(colors)]
                    marker = markers[level % len(markers)]
                    plt.scatter(cents[:, 0], np.zeros_like(cents[:, 0]),
                                label=f'Level {level} Centroids',
                                marker=marker, s=80, c=color)

            # Hide the meaningless y-axis
            plt.yticks([])
            plt.gca().spines['left'].set_visible(False)
            plt.gca().spines['right'].set_visible(False)
            plt.gca().spines['top'].set_visible(False)

            plt.xlabel('X')
            plt.grid(True, axis='x', alpha=0.3)

        elif dim == 2:
            # 2D: scatter plot, one series per level
            plt.scatter(X[:, 0], X[:, 1], label='Points', alpha=0.5, s=10)
            for entry in hierarchy:
                level = entry['level']
                cents = np.array(entry['centroids'])
                if len(cents) > 0:
                    plt.scatter(cents[:, 0], cents[:, 1], label=f'Level {level} Centroids', marker='x', s=50)
            plt.xlabel('X1')
            plt.ylabel('X2')

        elif dim == 3:
            # 3D: 3D scatter plot on a dedicated 3D axes
            ax = fig.add_subplot(111, projection='3d')
            ax.scatter(X[:, 0], X[:, 1], X[:, 2], label='Points', alpha=0.5, s=10)
            for entry in hierarchy:
                level = entry['level']
                cents = np.array(entry['centroids'])
                if len(cents) > 0:
                    ax.scatter(cents[:, 0], cents[:, 1], cents[:, 2], label=f'Level {level} Centroids', marker='x', s=50)
            ax.set_xlabel('X1')
            ax.set_ylabel('X2')
            ax.set_zlabel('X3')

        plt.legend()
        plt.title(f'Hierarchical Clustering Visualization ({dim}D)')
        vis_filename = f"{base_filename}_visualization.png"
        fig.savefig(vis_filename)
    finally:
        # Close this specific figure on every path so that repeated calls
        # (e.g. one per input file) cannot accumulate open figures on errors.
        plt.close(fig)
    print(f"Saved visualization to {vis_filename}")

def _parse_dim_from_filename(filename):
    """
    Extract the dimension N from a "points_<N>d.csv" file name.

    Parameters:
    - filename: str, file name or path

    Returns:
    - int N, or None when the base name does not match that pattern exactly
    """
    import re  # local import: only the script entry point needs it

    match = re.fullmatch(r"points_(\d+)d\.csv", os.path.basename(filename))
    return int(match.group(1)) if match else None


def main():
    """
    Cluster (and, for dim <= 3, visualize) every points_<N>d.csv in the working directory.

    Files are processed in sorted name order. A file is skipped with a message
    in any of these cases:
    - its name does not encode a positive dimension
    - pandas cannot parse it
    - its column count differs from the dimension in its name
    """
    for filename in sorted(glob.glob("points_*d.csv")):
        dim = _parse_dim_from_filename(filename)
        if dim is None or dim < 1:
            print(f"Skipping invalid file: {filename}")
            continue

        # Load data; one broken file should not abort the whole batch
        try:
            X = pd.read_csv(filename).values
        except (pd.errors.EmptyDataError, pd.errors.ParserError) as err:
            print(f"Could not read {filename} ({err}). Skipping.")
            continue

        if X.shape[1] != dim:
            print(f"Dimension mismatch in {filename}. Skipping.")
            continue

        base_filename = os.path.splitext(filename)[0]  # e.g., points_2d

        print(f"Processing {filename} as {dim}-dimensional data with {len(X)} points.")

        hierarchy = hierarchical_clustering(X, dim, base_filename)

        # Visualization is only defined for 1D-3D data
        if dim <= 3:
            visualize_hierarchy(X, dim, hierarchy, base_filename)


if __name__ == "__main__":
    main()
