#!/usr/bin/env python3


import os
import pickle
import numpy as np
import matplotlib.pyplot as plt
import time
from sklearn.decomposition import PCA
from algorithms.exact_dbscan import ExactDBSCAN
from algorithms.lsh_dbscan import LSHApproximateDBSCAN
import tensorflow as tf
from utils.aloi_loader import load_aloi_images, images_to_vectors 
import gensim.downloader as api

# For computing cluster alignment
import numpy as np
from scipy.optimize import linear_sum_assignment


class DBSCANExperiment:
    """Compare exact DBSCAN with LSH-approximate DBSCAN on a single dataset.

    ``run_experiment`` loads the data, runs exact DBSCAN (optionally cached in
    a pickle file), sweeps LSH DBSCAN over every approximation factor in
    ``config['c_values']`` and every ratio offset, prints a summary table and
    finally plots PCA projections.
    """

    def __init__(self, config=None):
        """Build the experiment configuration.

        Args:
            config: Optional dict whose entries override the defaults below.
                Extra keys are kept: ``run_experiment`` reads ``'data'`` and
                ``'labels'`` when ``dataset_name`` is ``'custom'``, and
                ``'plot_config'`` when plotting.
        """
        # Default configuration
        self.config = {
            'dataset_name': 'MNIST',
            'n': 1000,
            'eps': 1000,
            'min_pts': 100,
            'c_values': [5],
            'center_ratio': (1, 0.4),
            'ratio_offsets_cpi': [0],
            'ratio_offsets_cf': [0],
            'delta': 0.5,
            'pickling': False,
            'pca_output': 'all',
            'verbose': True,
            'aloi_specific_folders': None,  # List of specific ALOI folders to load
            'custom_colors': None  # Dictionary mapping cluster numbers to colors
        }

        # Update with provided config
        if config:
            self.config.update(config)

        # Cartesian product of the two offset lists (cpi varies slowest).
        self.config['ratio_offsets'] = [
            (cpi, cf) for cpi in self.config['ratio_offsets_cpi']
            for cf in self.config['ratio_offsets_cf']
        ]

        self.results = {}
        self.X = None
        self.object_ids = None

    def run_experiment(self):
        """Run the whole comparison and return ``self.results``.

        Returns:
            dict: ``'exact_eps'`` holds the exact DBSCAN result and
            ``'lsh_c{c}_offset{offset}'`` one LSH result per (c, offset) pair.
        """
        cfg = self.config
        verbose = cfg['verbose']
        if verbose:
            print("=== DBSCAN Comparison Experiment ===\n")

        self._load_data()
        n_points = self.X.shape[0]

        # Print configuration
        if verbose:
            print(f"Parameters: eps={cfg['eps']}, min_pts={cfg['min_pts']}")
            print(f"Approximation factors: {cfg['c_values']}")
            print(f"Center ratio: {cfg['center_ratio']}, Offsets: {cfg['ratio_offsets']}")

        self._run_exact_stage(n_points)
        self._run_lsh_stage()

        # Print results table
        print_results(self.results, cfg['eps'], cfg['min_pts'],
                      cfg['c_values'], cfg['ratio_offsets'], cfg['delta'])

        # Generate plots
        if verbose:
            print("\nGenerating PCA projection plots...")
        if cfg['pca_output'] == 'single':
            plot_exact_dbscan(self.X, self.results, cfg['eps'])
        elif cfg['pca_output'] == 'all':
            plot_results(self.X, self.results, cfg['eps'], cfg['min_pts'],
                         cfg['c_values'], cfg['ratio_offsets'], self.object_ids,
                         cfg['custom_colors'], cfg.get('plot_config'))

        if verbose:
            print("\nExperiment completed!")

        return self.results

    def _load_data(self):
        """Set ``self.X`` and ``self.object_ids`` from the configuration."""
        if self.config['dataset_name'] == 'custom':
            self.X = self.config['data']
            self.object_ids = self.config['labels']
        else:
            self.X, self.object_ids = get_data(
                dataset=self.config['dataset_name'],
                n=self.config['n'],
                aloi_specific_folders=self.config['aloi_specific_folders']
            )

    def _run_exact_stage(self, n_points):
        """Load exact-DBSCAN results from the pickle cache, or compute them.

        The cache is only used when ``config['pickling']`` is true, and it is
        written only after a fresh computation (not right after loading it).
        """
        cfg = self.config
        verbose = cfg['verbose']
        # NOTE(review): the cache name only encodes N, eps, min_pts and the
        # dataset name, so different data sharing those values (e.g. another
        # 'custom' array of the same size) would silently reuse a stale result.
        pickle_file = f'results_{n_points}_{cfg["eps"]}_{cfg["min_pts"]}_{cfg["dataset_name"]}.pkl'

        if cfg['pickling'] and os.path.exists(pickle_file):
            # pickle.load can execute arbitrary code: only load cache files
            # that this script wrote itself.
            with open(pickle_file, 'rb') as f:
                self.results = pickle.load(f)
            if verbose:
                print(f"Results loaded from {pickle_file}")
            return

        # Run exact DBSCAN
        exact = run_exact_dbscan(self.X, cfg['eps'], cfg['min_pts'])
        self.results['exact_eps'] = exact
        if verbose:
            print(f"   Time: {exact['time']:.4f}s")
            print(f"   Clusters: {exact['clusters']}")
            print(f"   Distance computations: {exact['distance_count']}")

        # Save to pickle if enabled
        if cfg['pickling']:
            with open(pickle_file, 'wb') as f:
                pickle.dump(self.results, f)
            if verbose:
                print(f"Results saved to {pickle_file}")

    def _run_lsh_stage(self):
        """Run LSH DBSCAN for every (c, ratio offset) pair and store each result."""
        cfg = self.config
        verbose = cfg['verbose']
        exact = self.results['exact_eps']

        if verbose:
            print(f"\n2. Running LSH Approximate DBSCAN for c values {cfg['c_values']}, offsets {cfg['ratio_offsets']}...")

        for c in cfg['c_values']:
            c_results = []
            for ratio_offset in cfg['ratio_offsets']:
                result = run_lsh_dbscan(
                    self.X, cfg['eps'], cfg['min_pts'], c,
                    cfg['center_ratio'], ratio_offset,
                    reference_labels=exact['labels'],
                    reference_time=exact['time'],
                    reference_distance_count=exact['distance_count'],
                    verbose=verbose,
                    delta=cfg['delta']
                )
                self.results[f'lsh_c{c}_offset{ratio_offset}'] = result
                c_results.append(result)

            if verbose:
                self._report_c_results(c, c_results)

    @staticmethod
    def _report_c_results(c, c_results):
        """Print each offset's result for ``c``, then the one with the best computation speedup."""
        for res in c_results:
            print(f"   Offset {res['ratio_offset']}: Time {res['time']:.4f}s, "
                  f"Clusters {res['clusters']}, Misalignment {res['misalignment']:.3f}, "
                  f"Speedup {res['computation_speedup']:.2f}, "
                  f"(Hash,Dist) ({res['hash_count']},{res['distance_count']}), "
                  f"Time Speedup {res['speedup']:.2f}x")

        # max() raises on an empty sequence, e.g. when an offset list is empty.
        if not c_results:
            return

        # Find best offset (highest computation speedup)
        best_result = max(c_results, key=lambda x: x['computation_speedup'])
        print(f"\n   Best result for c={c}:")
        print(f"        Offset: {best_result['ratio_offset']}")
        print(f"        (hash,distances): ({best_result['hash_count']},{best_result['distance_count']})")
        print(f"        Time: {best_result['time']:.4f}s")
        print(f"        Clusters: {best_result['clusters']}")
        print(f"        Misalignment with exact: {best_result['misalignment']:.3f}")
        print(f"        Computation speedup: {best_result['computation_speedup']:.2f}x")


def minimize_sum_of_indices(matrix: np.ndarray) -> tuple[list[int], float]:
    """Solve the linear assignment problem for a cost matrix.

    Returns:
        ``(assignment, total)`` where ``assignment`` lists the chosen column
        for each assigned row (in scipy's row order) and ``total`` is the
        summed cost of that minimum-cost assignment.
    """
    assignment = linear_sum_assignment(matrix)
    chosen_costs = matrix[assignment]
    return assignment[1].tolist(), chosen_costs.sum()

def minimize_misalignment(A: list[int], B: list[int]) -> float:
    """Return the fraction of points whose labels disagree under the best cluster matching.

    ``A`` is the reference labeling and ``B`` the labeling being scored; -1
    marks noise. Each reference cluster is matched one-to-one to a label of
    ``B`` (Hungarian algorithm) so that disagreements are minimised. A point
    counts as misaligned when:

    * it is noise in ``A`` but not in ``B``; or
    * it belongs to reference cluster ``a`` but its ``B`` label is not the
      label matched to ``a`` (this includes being noise in ``B``).

    Args:
        A: Reference labels, one per point.
        B: Labels to compare, same length as ``A``.

    Returns:
        Misaligned points divided by the number of points (0.0 for empty input).

    Raises:
        ValueError: if ``A`` and ``B`` have different lengths.
    """
    ref = np.asarray(A, dtype=int)
    cand = np.asarray(B, dtype=int)
    if ref.shape != cand.shape:
        raise ValueError(f"label arrays differ in shape: {ref.shape} vs {cand.shape}")
    n = ref.size
    if n == 0:
        return 0.0

    # Reference noise is only penalised when B places the point in a cluster.
    noise_cost = int(np.count_nonzero((ref == -1) & (cand != -1)))

    k = int(max(ref.max(), cand.max())) + 1
    if k <= 0:  # every point is noise in both labelings
        return noise_cost / n

    clustered = ref != -1
    # Number of points in each reference cluster.
    cluster_sizes = np.bincount(ref[clustered], minlength=k)
    # contingency[a, b] = points in reference cluster a that B labels b.
    both = clustered & (cand != -1)
    contingency = np.zeros((k, k), dtype=np.int64)
    np.add.at(contingency, (ref[both], cand[both]), 1)
    # Matching a -> b costs the members of a that B does NOT label b.
    cost_matrix = cluster_sizes[:, None] - contingency

    rows, cols = linear_sum_assignment(cost_matrix)
    matched_cost = int(cost_matrix[rows, cols].sum())
    return (matched_cost + noise_cost) / n

def get_data(dataset='MNIST', n=100, aloi_ids=None, aloi_specific_folders=None):
    """Load ``dataset`` and return ``(X, object_ids)``.

    Args:
        dataset: One of 'MNIST', 'FashionMNIST', 'ALOI', 'GloVe'.
        n: For MNIST/FashionMNIST/GloVe, the number of points sampled without
            replacement with a fixed seed (42). For ALOI without
            ``aloi_specific_folders``, passed to the loader as ``num_folders``.
        aloi_ids: Optional ALOI object ids to keep; other points are dropped.
        aloi_specific_folders: Optional folders to load for ALOI instead of
            using ``num_folders=n``.

    Returns:
        ``(X, object_ids)``: X has one row per point (float32 for
        MNIST/FashionMNIST/GloVe); object_ids are the per-point labels, or
        None for GloVe.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    supported = ('MNIST', 'FashionMNIST', 'ALOI', 'GloVe')
    if dataset not in supported:
        raise ValueError(f"Unknown dataset {dataset!r}; expected one of {supported}")

    print(f"Loading {dataset} data...")
    rng = np.random.default_rng(42)

    if dataset in ('MNIST', 'FashionMNIST'):
        # Both Keras loaders return ((train images, train labels), test split).
        keras_sets = tf.keras.datasets
        loader = keras_sets.mnist if dataset == 'MNIST' else keras_sets.fashion_mnist
        (train_images, object_ids), _ = loader.load_data()
        # One flattened row per image, without hard-coding the training-set size.
        X = train_images.reshape(len(train_images), -1).astype(np.float32)
        indices = rng.choice(len(X), size=n, replace=False)
        X = X[indices]
        object_ids = object_ids[indices]
    elif dataset == 'ALOI':
        if aloi_specific_folders is not None:
            X, object_ids = load_aloi_images(specific_folders=aloi_specific_folders)
        else:
            X, object_ids = load_aloi_images(num_folders=n)
        X = images_to_vectors(X)
        if aloi_ids is not None:
            mask = np.isin(object_ids, aloi_ids)
            X = X[mask]
            object_ids = object_ids[mask]
    else:  # 'GloVe' word embeddings
        print("Downloading GloVe embeddings...")
        glove_model = api.load("glove-wiki-gigaword-100")
        vectors = glove_model.vectors
        indices = rng.choice(len(glove_model.index_to_key), size=n, replace=False)
        X = vectors[indices].astype(np.float32)
        object_ids = None

    print(f"Loaded {X.shape[0]} points in {X.shape[1]} dimensions from {dataset}")
    return X, object_ids

def run_exact_dbscan(X, eps, min_pts):
    """Fit exact DBSCAN on ``X`` and return its labels and statistics.

    Args:
        X: Data matrix passed to ``ExactDBSCAN.fit``.
        eps: Neighbourhood radius.
        min_pts: Minimum neighbourhood size for a core point.

    Returns:
        dict with 'algorithm', 'labels', 'time' (seconds), 'distance_count',
        'clusters' (distinct labels other than -1) and 'misalignment' (0.0:
        this run is the reference other runs are compared with).
    """
    print("\n1. Running Exact DBSCAN...")
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    exact_dbscan = ExactDBSCAN(eps=eps, min_pts=min_pts)
    exact_dbscan.fit(X)
    exact_time = time.perf_counter() - start_time

    labels = exact_dbscan.labels_
    unique_labels = set(labels)
    # Label -1 is noise and does not count as a cluster.
    clusters = len(unique_labels) - (1 if -1 in unique_labels else 0)

    return {
        'algorithm': 'Exact DBSCAN (eps)',
        'labels': labels,
        'time': exact_time,
        'distance_count': exact_dbscan.distance_count,
        'clusters': clusters,
        'misalignment': 0.000
    }


def run_lsh_dbscan(X, eps, min_pts, c, center_ratio, ratio_offset, reference_labels=None, reference_time=None, reference_distance_count=None, verbose=True, delta=0.5):
    """Fit LSH-approximate DBSCAN on ``X`` for one configuration and return statistics.

    Args:
        X: Data matrix passed to ``LSHApproximateDBSCAN.fit``.
        eps, min_pts: DBSCAN parameters.
        c: Approximation factor.
        center_ratio, ratio_offset, delta: Passed through to ``LSHApproximateDBSCAN``.
        reference_labels: If given, adds 'misalignment' against these labels.
        reference_time: If given, adds 'speedup' = reference_time / this run's time.
        reference_distance_count: If given, adds 'computation_speedup' =
            reference_distance_count / (distance_count + hash_count).
        verbose: Print a progress line and pass ``verbose`` to the estimator.

    Returns:
        dict of labels, timing, operation counts, cluster count, the
        configuration, and whichever comparison metrics were requested.
        A speedup whose denominator is zero is reported as ``float('inf')``.
    """
    def _ratio(numerator, denominator):
        # A zero denominator (0 s elapsed or no counted operations) would
        # otherwise raise ZeroDivisionError.
        return numerator / denominator if denominator else float('inf')

    if verbose:
        print(f"\nRunning LSH DBSCAN (c={c}, center_ratio={center_ratio}, offset={ratio_offset})...")
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    lsh_dbscan = LSHApproximateDBSCAN(eps=eps, min_pts=min_pts, c=c, center_ratio=center_ratio,
                                      ratio_offset=ratio_offset, verbose=verbose, delta=delta)
    lsh_dbscan.fit(X)
    lsh_time = time.perf_counter() - start_time

    labels = lsh_dbscan.labels_
    unique_labels = set(labels)
    # Label -1 is noise and does not count as a cluster.
    clusters = len(unique_labels) - (1 if -1 in unique_labels else 0)

    result = {
        'algorithm': f'LSH DBSCAN (c={c})',
        'labels': labels,
        'time': lsh_time,
        'distance_count': lsh_dbscan.distance_count,
        'hash_count': lsh_dbscan.hash_count,
        'total_computations': lsh_dbscan.distance_count + lsh_dbscan.hash_count,
        'clusters': clusters,
        'c': c,
        'center_ratio': center_ratio,
        'ratio_offset': ratio_offset
    }

    # Add comparison metrics if a reference is provided
    if reference_labels is not None:
        result['misalignment'] = minimize_misalignment(reference_labels, labels)

    if reference_time is not None:
        result['speedup'] = _ratio(reference_time, lsh_time)

    if reference_distance_count is not None:
        result['computation_speedup'] = _ratio(reference_distance_count, result['total_computations'])

    return result


def print_results(results, eps, min_pts, c_values, ratio_offsets, delta):
    """Print a summary table comparing exact DBSCAN with the best LSH runs.

    For every ``c`` in ``c_values`` only the ratio offset with the highest
    computation speedup (as chosen by ``_find_best_lsh_result``) is reported,
    both in the table and in the computation-savings list.

    Args:
        results: Dict with ``'exact_eps'`` and ``'lsh_c{c}_offset{offset}'`` entries.
        eps, min_pts, delta: Not used here; kept for call compatibility.
        c_values: Approximation factors to report.
        ratio_offsets: Offsets that were swept for each ``c``.
    """
    print("\n" + "="*120)
    print("SUMMARY TABLE")
    print("="*120)
    print(f"{'Algorithm':<40} {'Time(s)':<10} {'Clusters':<10} {'Misalignment':<12} {'Comp Speedup':>14} {'Time Speedup':>14}")
    print("-"*120)

    # Reference exact DBSCAN. Misalignment is measured against this run, so
    # its own misalignment is 0 (stored as such by run_exact_dbscan).
    exact_result = results['exact_eps']
    exact_misalignment = exact_result.get('misalignment', 0.0)
    print(f"{exact_result['algorithm']:<40} {exact_result['time']:<10.4f} {exact_result['clusters']:<10} {exact_misalignment:<12.3f} {'1.00x':>14} {'1.00x':>14}")

    # Best offset per c, looked up once and shared by both sections below.
    best_per_c = [(c, _find_best_lsh_result(results, c, ratio_offsets)) for c in c_values]

    # LSH results - show only best offset for each c
    for _, best_result in best_per_c:
        if best_result is not None:
            print(f"{best_result['algorithm']:<40} {best_result['time']:<10.4f} {best_result['clusters']:<10} {best_result['misalignment']:<12.3f} {best_result['computation_speedup']:>13.2f}x {best_result['speedup']:>13.2f}x")

    print("="*120)

    # Computation savings summary
    print("\nCOMPUTATION SAVINGS:")
    print("-" * 40)
    for c, best_result in best_per_c:
        if best_result is not None:
            print(f"LSH c={c} (offset={best_result['ratio_offset']}): {best_result['computation_speedup']:.2f}x faster than exact DBSCAN")


def plot_results(X, results, eps, min_pts, c_values, ratio_offsets, object_ids=None, custom_colors=None, plot_config=None):
    """Project ``X`` to 2-D with PCA and plot exact vs. LSH DBSCAN clusterings.

    Args:
        X: Data matrix that was clustered.
        results: Experiment results ('exact_eps' plus LSH entries).
        eps, min_pts: Shown in titles.
        c_values, ratio_offsets: Select which LSH results are drawn.
        object_ids: Optional ground-truth ids for an extra panel.
        custom_colors: When not None, the automatic color matching is skipped.
        plot_config: Optional overrides, merged shallowly over the defaults below.
    """
    settings = dict({
        'figsize': (24, 18),
        'layout': (2, 3),
        'show_axes_labels': False,
        'show_grid': True,
        'show_legends': True,
        'legend_position': (1.02, 1),
        'legend_fontsize': 9,
        'max_legend_entries': 10,
        'title_fontsize': 16,
        'show_overall_title': False,
        'show_subplot_titles': False,
        'vertical_spacing': 0.15,
        'top_spacing': 0.05,
        'separate_plots': False,
        'ground_truth_colors': [
            '#ACBBE8', '#EB9DA2', '#E8E6A5', '#BBE8B5', '#F0B884', '#C5ACE8'
        ],
        'hierarchical_colors': [
            '#1955FA', '#FAE502', '#027D35', '#F77B1B', '#E00909', '#911BF7'
        ],
        'alternate_palettes': [
            ['#3AA2F2', '#B8A000', '#015A26', '#A0520E', '#8A0D07', '#5A1BA4'],
            ['#010F87', '#FFD700', '#015A26', '#FF8C00', '#FF6C52', '#8A2BE2'],
            ['#6362A6', '#D4B800', '#015A26', '#C5620E', '#B80707', '#7208C4']
        ]
    })
    settings.update(plot_config or {})

    # 2-D PCA projection shared by every panel.
    projected = PCA(n_components=2).fit_transform(X)

    exact_colors, lsh_colors = _generate_color_mappings(
        results, c_values, ratio_offsets, settings, custom_colors
    )

    # One figure per panel, or all panels in a single grid.
    draw = _create_separate_plots if settings['separate_plots'] else _create_combined_plots
    draw(projected, results, eps, min_pts, c_values, ratio_offsets,
         object_ids, exact_colors, lsh_colors, settings)
    

def plot_exact_dbscan(X, results, eps):
    """Show the exact DBSCAN clustering on a 2-D PCA projection of ``X``.

    Args:
        X: Data matrix; projected onto its first two principal components.
        results: Experiment results; reads ``results['exact_eps']['labels']``
            and ``results['exact_eps']['clusters']``.
        eps: Only shown in the plot title.
    """
    # PCA, plt and np come from the module-level imports.
    X_2d = PCA(n_components=2).fit_transform(X)

    exact = results['exact_eps']
    labels = np.asarray(exact['labels'])  # elementwise == below needs an array
    n_clusters = exact['clusters']

    # 12 distinct qualitative colors (the ColorBrewer "Paired" set), cycled
    # when there are more than 12 clusters.
    colors = [
        '#E31A1C',  # Red
        '#1F78B4',  # Blue
        '#33A02C',  # Green
        '#FF7F00',  # Orange
        '#6A3D9A',  # Purple
        '#B15928',  # Brown
        '#FB9A99',  # Pink
        '#A6CEE3',  # Light Blue
        '#B2DF8A',  # Light Green
        '#FDBF6F',  # Light Orange
        '#CAB2D6',  # Light Purple
        '#FFFF99',  # Yellow
    ]

    # Make plot
    plt.figure(figsize=(8, 6))

    noise_label = -1
    # Sorted so each cluster's color does not depend on set iteration order.
    cluster_labels = [l for l in np.unique(labels) if l != noise_label]

    # Plot clusters
    for i, label in enumerate(cluster_labels):
        cluster_points = X_2d[labels == label]
        plt.scatter(cluster_points[:, 0], cluster_points[:, 1],
                    c=colors[i % len(colors)], alpha=0.7, s=30,
                    label=f'Cluster {label}')

    # Plot noise points
    noise_mask = labels == noise_label
    if np.any(noise_mask):
        noise_points = X_2d[noise_mask]
        plt.scatter(noise_points[:, 0], noise_points[:, 1],
                    c='black', alpha=0.5, s=20, label='Noise')

    plt.title(f"Exact DBSCAN (eps={eps})\n{n_clusters} clusters")
    plt.xlabel("PC1")
    plt.ylabel("PC2")
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

def _generate_color_mappings(results, c_values, ratio_offsets, config, custom_colors):
    """Assign plot colors to the exact clusters and to each best LSH clustering.

    Exact clusters take ``config['hierarchical_colors']`` in sorted-label
    order (cycling). Each LSH cluster inherits the color of the exact cluster
    it overlaps most. When several LSH clusters map to the same exact
    cluster, the 2nd, 3rd, ... use the same palette slot in the alternate
    palettes. LSH clusters that overlap no exact cluster fall back to the
    hierarchical palette.

    Returns:
        ``(exact_color_map, hierarchical_color_maps)``: {label: color} for the
        exact run, and {c: {label: color}} for each c with an LSH result.
        Both are None when ``custom_colors`` is given.
    """
    if custom_colors is not None:
        return None, None

    hierarchical = config['hierarchical_colors']
    all_palettes = [hierarchical] + config['alternate_palettes']

    exact_labels = results['exact_eps']['labels']
    exact_clusters = sorted(l for l in set(exact_labels) if l != -1)
    # Position of each exact cluster; it is also its slot in every palette.
    exact_rank = {cluster: i for i, cluster in enumerate(exact_clusters)}
    exact_color_map = {
        cluster: hierarchical[i % len(hierarchical)]
        for cluster, i in exact_rank.items()
    }

    hierarchical_color_maps = {}
    for c in c_values:
        best_result = _find_best_lsh_result(results, c, ratio_offsets)
        if not best_result:
            continue
        labels = best_result['labels']

        color_map = {}
        usage = {}  # exact cluster -> LSH clusters already mapped to it
        for lsh_cluster in sorted(l for l in set(labels) if l != -1):
            match = _find_best_matching_exact_cluster(
                lsh_cluster, labels, exact_clusters, exact_labels
            )
            if match is None:
                color_map[lsh_cluster] = hierarchical[len(color_map) % len(hierarchical)]
                continue

            times_used = usage.get(match, 0)
            if times_used == 0:
                color_map[lsh_cluster] = exact_color_map[match]
            else:
                palette = all_palettes[times_used % len(all_palettes)]
                # Wrap the slot: there can be more exact clusters than palette entries.
                color_map[lsh_cluster] = palette[exact_rank[match] % len(palette)]
            usage[match] = times_used + 1

        hierarchical_color_maps[c] = color_map

    return exact_color_map, hierarchical_color_maps


def _find_best_lsh_result(results, c, ratio_offsets):
    """Return the LSH result for ``c`` with the largest computation speedup.

    Only offsets stored in ``results`` under ``'lsh_c{c}_offset{offset}'``
    with a strictly positive 'computation_speedup' are candidates. Ties go
    to the earliest offset in ``ratio_offsets``. Returns None when nothing
    qualifies.
    """
    keys = (f'lsh_c{c}_offset{offset}' for offset in ratio_offsets)
    candidates = [
        results[key]
        for key in keys
        if key in results and results[key]['computation_speedup'] > 0
    ]
    # max() keeps the first of equal maxima, matching "earliest offset wins".
    return max(candidates, key=lambda entry: entry['computation_speedup'], default=None)


def _find_best_matching_exact_cluster(lsh_cluster, lsh_labels, exact_clusters, exact_labels):
    """Return the exact cluster sharing the most points with ``lsh_cluster``.

    Args:
        lsh_cluster: Label of the LSH cluster to match.
        lsh_labels: LSH label per point.
        exact_clusters: Candidate exact-cluster labels, in tie-break order.
        exact_labels: Exact label per point (same length as ``lsh_labels``).

    Returns:
        The first candidate in ``exact_clusters`` with the maximal positive
        overlap, or None if no candidate shares any point.
    """
    lsh_labels = np.asarray(lsh_labels)
    exact_labels = np.asarray(exact_labels)
    # Exact labels of this LSH cluster's members, computed once (not per candidate).
    member_exact_labels = exact_labels[lsh_labels == lsh_cluster]

    best_exact_cluster = None
    max_overlap = 0
    for exact_cluster in exact_clusters:
        overlap = int(np.count_nonzero(member_exact_labels == exact_cluster))
        if overlap > max_overlap:
            max_overlap = overlap
            best_exact_cluster = exact_cluster

    return best_exact_cluster


def _plot_single_clustering(ax, X_2d, labels, title, color_map, config):
    """Scatter one clustering on ``ax``: clusters in color, noise (-1) in black.

    Args:
        ax: Matplotlib axes to draw on.
        X_2d: Projected points; row i corresponds to ``labels[i]``.
        labels: Cluster label per point (-1 = noise).
        title: Panel title, shown only when ``config['show_subplot_titles']``.
        color_map: Optional {label: color}. Labels it lacks (or a missing or
            empty map) fall back to ``config['hierarchical_colors']``, cycled
            by position.
        config: Plot settings dict (see the defaults in ``plot_results``).
    """
    labels = np.asarray(labels)  # elementwise == below needs an array, not a list
    noise_label = -1
    # Sorted so legend truncation and fallback colors are deterministic.
    cluster_labels = [l for l in np.unique(labels) if l != noise_label]
    palette = config['hierarchical_colors']

    # Limit legend entries to avoid overcrowding
    max_legend_entries = config.get('max_legend_entries', 10)

    # Plot clusters
    for i, label in enumerate(cluster_labels):
        cluster_points = X_2d[labels == label]
        color = color_map.get(label) if color_map else None
        if color is None:
            color = palette[i % len(palette)]
        # Only the first max_legend_entries clusters get a legend entry.
        legend_label = f'Cluster {label}' if i < max_legend_entries else None
        ax.scatter(cluster_points[:, 0], cluster_points[:, 1],
                   c=color, alpha=0.7, s=30, label=legend_label)

    # Plot noise
    noise_mask = labels == noise_label
    if np.any(noise_mask):
        noise_points = X_2d[noise_mask]
        ax.scatter(noise_points[:, 0], noise_points[:, 1],
                   c='black', alpha=0.5, s=20, label='Noise')

    # Add note about truncated legend if there are more clusters than shown
    if len(cluster_labels) > max_legend_entries:
        ax.text(0.02, 0.02, f'... and {len(cluster_labels) - max_legend_entries} more clusters',
                transform=ax.transAxes, fontsize=8, alpha=0.7)

    # Style the plot
    if config['show_subplot_titles'] and title:  # Only set title if titles are enabled and title is not empty
        ax.set_title(title, fontsize=config['title_fontsize'])
    if config['show_grid']:
        ax.grid(True, alpha=0.3)
    if config['show_legends']:
        ax.legend(bbox_to_anchor=config['legend_position'], loc='upper left', fontsize=config['legend_fontsize'])
    if not config['show_axes_labels']:
        ax.set_xticklabels([])
        ax.set_yticklabels([])


def _plot_ground_truth(ax, X_2d, object_ids, config):
    """Scatter the points on ``ax`` colored by ground-truth object id.

    Hides ``ax`` when ``object_ids`` is None. Colors cycle through
    ``config['ground_truth_colors']`` in sorted id order.
    """
    if object_ids is None:
        ax.set_visible(False)
        return

    # Convert once instead of rebuilding the array for every object id.
    object_ids = np.asarray(object_ids)
    unique_objects = np.unique(object_ids)  # sorted, distinct ids
    palette = config['ground_truth_colors']

    # Limit legend entries to avoid overcrowding
    max_legend_entries = config.get('max_legend_entries', 10)

    for i, obj_id in enumerate(unique_objects):
        points = X_2d[object_ids == obj_id]
        # Only the first max_legend_entries ids get a legend entry.
        legend_label = f'Object {obj_id}' if i < max_legend_entries else None
        ax.scatter(points[:, 0], points[:, 1],
                   c=[palette[i % len(palette)]],
                   label=legend_label,
                   alpha=0.9, s=30)

    # Add note about truncated legend if there are more objects than shown
    if len(unique_objects) > max_legend_entries:
        ax.text(0.02, 0.02, f'... and {len(unique_objects) - max_legend_entries} more objects',
                transform=ax.transAxes, fontsize=8, alpha=0.7)

    if config['show_subplot_titles']:
        ax.set_title('Ground Truth (by Object ID)', fontsize=config['title_fontsize'])
    if config['show_grid']:
        ax.grid(True, alpha=0.3)
    if config['show_legends']:
        ax.legend(bbox_to_anchor=config['legend_position'], loc='upper left', fontsize=config['legend_fontsize'])
    if not config['show_axes_labels']:
        ax.set_xticklabels([])
        ax.set_yticklabels([])


def _create_combined_plots(X_2d, results, eps, min_pts, c_values, ratio_offsets,
                          object_ids, exact_color_map, hierarchical_color_maps, config):
    """Draw all panels (exact, best LSH per c, ground truth) in one figure.

    Panels fill ``config['layout']`` (rows, cols) row by row: exact DBSCAN
    first, then one panel per ``c``, then the ground truth. If the layout has
    fewer cells than the ``len(c_values) + 2`` panels, rows are added so that
    nothing is dropped. Cells without a panel are hidden.
    """
    n_panels = len(c_values) + 2
    n_rows, n_cols = config['layout']
    if n_rows * n_cols < n_panels:
        n_rows = -(-n_panels // n_cols)  # ceiling division
    # squeeze=False always returns a 2-D array, so ravel() also works for a 1x1 layout.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=config['figsize'], squeeze=False)
    axes = axes.ravel()

    # Plot exact DBSCAN
    title = f"Exact DBSCAN (eps={eps})\n{results['exact_eps']['clusters']} clusters" if config['show_subplot_titles'] else ""
    _plot_single_clustering(
        axes[0], X_2d, results['exact_eps']['labels'],
        title, exact_color_map, config
    )

    # Plot LSH results: one panel per c, hidden if that c has no result.
    for panel, c in enumerate(c_values, start=1):
        best_result = _find_best_lsh_result(results, c, ratio_offsets)
        if not best_result:
            axes[panel].set_visible(False)
            continue
        lsh_color_map = hierarchical_color_maps.get(c) if hierarchical_color_maps else None
        title = f"LSH DBSCAN (c={c})\n{best_result['clusters']} clusters\nMisalignment: {best_result['misalignment']:.3f}" if config['show_subplot_titles'] else ""
        _plot_single_clustering(
            axes[panel], X_2d, best_result['labels'],
            title, lsh_color_map, config
        )

    # Plot ground truth in the last panel slot.
    _plot_ground_truth(axes[n_panels - 1], X_2d, object_ids, config)

    # Hide grid cells that received no panel.
    for spare_ax in axes[n_panels:]:
        spare_ax.set_visible(False)

    fig.subplots_adjust(
        left=0.08, right=0.85,  # Leave space for legends
        bottom=0.1, top=1-config['top_spacing'],
        hspace=config['vertical_spacing'],
        wspace=0.3  # Horizontal spacing between columns
    )

    # Finalize plot
    if config['show_overall_title']:
        fig.suptitle(f'DBSCAN Comparison: eps={eps}, min_pts={min_pts}', fontsize=14)

    plt.show()


def _create_separate_plots(X_2d, results, eps, min_pts, c_values, ratio_offsets,
                          object_ids, exact_color_map, hierarchical_color_maps, config):
    """Show each panel as its own 8x6 figure.

    The panels are exact DBSCAN, then the best LSH run for every ``c``, then
    the ground truth (skipped when ``object_ids`` is None). Each figure is
    laid out tightly, announced on stdout, and shown immediately.
    """
    def finish(announcement):
        # Common tail for every standalone figure.
        plt.tight_layout()
        print(announcement)
        plt.show()

    _, ax = plt.subplots(1, 1, figsize=(8, 6))
    exact = results['exact_eps']
    _plot_single_clustering(
        ax, X_2d, exact['labels'],
        f"Exact DBSCAN (eps={eps})\n{exact['clusters']} clusters",
        exact_color_map, config
    )
    finish("Exact DBSCAN plot created")

    for c in c_values:
        best = _find_best_lsh_result(results, c, ratio_offsets)
        if not best:
            continue
        _, ax = plt.subplots(1, 1, figsize=(8, 6))
        colors = hierarchical_color_maps.get(c) if hierarchical_color_maps else None
        title = (f"LSH DBSCAN (c={c})\n{best['clusters']} clusters\n"
                 f"Misalignment: {best['misalignment']:.3f}")
        _plot_single_clustering(ax, X_2d, best['labels'], title, colors, config)
        finish(f"LSH DBSCAN (c={c}) plot created")

    if object_ids is None:
        return
    _, ax = plt.subplots(1, 1, figsize=(8, 6))
    _plot_ground_truth(ax, X_2d, object_ids, config)
    finish("Ground truth plot created")
