import os
import sys
import numpy as np
import pickle

# Put the parent of this script's directory at the front of the import path
# (presumably where dbscan_experiment lives -- confirm against the repo layout).
_script_dir = os.path.dirname(os.path.abspath(__file__))
experiments_dir = os.path.dirname(_script_dir)
sys.path.insert(0, experiments_dir)

from dbscan_experiment import DBSCANExperiment, minimize_misalignment
from sklearn.decomposition import PCA
from PIL import Image

def _random_square_images(rng, count, height, width, side, value, background=0):
    """Return ``count`` uint8 images, each containing one randomly placed square.

    Args:
        rng: NumPy random ``Generator`` used to draw the square positions.
        count: Number of images to generate.
        height: Image height in pixels.
        width: Image width in pixels.
        side: Side length of the square in pixels.
        value: Pixel value written inside the square.
        background: Pixel value used everywhere else.

    Returns:
        A list of ``count`` arrays of shape ``(height, width)`` and dtype uint8.
    """
    # Largest top-left coordinates that keep the square fully inside the image.
    max_x = width - side
    max_y = height - side

    images = []
    for _ in range(count):
        image = np.full((height, width), background, dtype=np.uint8)

        # Draw x before y for every image so the random stream is consumed
        # in a fixed order and the dataset stays reproducible.
        x = rng.integers(0, max_x + 1)
        y = rng.integers(0, max_y + 1)

        image[y:y + side, x:x + side] = value
        images.append(image)
    return images


def synthetic_squares():
    """Build a reproducible two-class dataset of images containing one square each.

    The first 500 images hold a 50x50 square of value 100 (label ``'square'``),
    the next 500 hold a 60x60 square of value 90 (label ``'dim square'``).
    All images are 100x100 with a background of 0, and square positions are
    drawn from a generator seeded with 42.

    Returns:
        A tuple ``(all_images, labels)`` where ``all_images`` is a uint8 array
        of shape ``(1000, 10000)`` (one flattened image per row) and
        ``labels`` is an array of 1000 label strings in matching order.
    """
    rng = np.random.default_rng(42)

    # Image dimensions
    image_height = 100
    image_width = 100

    # Number of images per class
    n = 500

    # The brighter class is generated first; this order fixes both the row
    # order of the result and the order in which random numbers are consumed.
    square_images1 = _random_square_images(
        rng, n, image_height, image_width, side=50, value=100, background=0
    )
    square_images2 = _random_square_images(
        rng, n, image_height, image_width, side=60, value=90, background=0
    )

    # Stack both classes and flatten every image into a single row
    all_images = np.array(square_images1 + square_images2).reshape(
        -1, image_height * image_width
    )
    labels = np.array(['square'] * n + ['dim square'] * n)
    return all_images, labels

def main():
    """Run DBSCAN experiments on the synthetic squares data and print misalignment.

    Two phases:

    1. Run ``DBSCANExperiment`` on the raw pixel data for eps 3300 and 5500
       and print the misalignment of the ``'exact_eps'`` labels against the
       ground-truth classes.
    2. Project the data onto 12 PCA components, collect every distinct
       pairwise Euclidean distance in [600, 607], use each one as eps, and
       report the eps with the lowest misalignment.

    If no distance falls in the range, or no run scores below the initial
    1.0, the final line reports ``Min eps: None``.
    """
    X, labels = synthetic_squares()

    # Map the string class labels to integer ids for comparison with the
    # cluster labels.
    unique_labels = np.unique(labels)
    label_to_int = {label: i for i, label in enumerate(unique_labels)}
    int_labels = np.array([label_to_int[label] for label in labels])

    config = {
        'dataset_name': 'custom',
        'data': X,
        'labels': labels,
        'eps': 7500,  # placeholder; overwritten before every run below
        'min_pts': 3,
        'c_values': [],
        'center_ratio': (1, 1),
        'delta': 0.001,
        'pickling': False,
        'pca_output': 'all',
        'verbose': False,
        'plot_config': {
            'separate_plots': True,
            'show_subplot_titles': False,
        }
    }

    for eps in [3300, 5500]:
        config['eps'] = eps
        experiment = DBSCANExperiment(config)
        results = experiment.run_experiment()
        print(f"Misalignment: {minimize_misalignment(results['exact_eps']['labels'], int_labels)}")

    X_12D = PCA(n_components=12).fit_transform(X)
    print(f"X_12D shape: {X_12D.shape}")

    # Filter while iterating so that only in-range distances are stored,
    # rather than holding every pairwise distance in memory first.
    n_points = X_12D.shape[0]
    pairwise_dists = (
        np.linalg.norm(X_12D[i] - X_12D[j])
        for i in range(n_points)
        for j in range(i + 1, n_points)
    )
    set_of_dists_in_range = sorted(
        {dist for dist in pairwise_dists if 600 <= dist <= 607}
    )
    print(f"Set of distances in range: {set_of_dists_in_range}")

    min_misalignment = 1.0
    # Initialised up front so the summary print cannot hit an unbound name
    # when the loop body never improves on the starting value.
    min_eps = None
    config['data'] = X_12D
    config['pca_output'] = 'none'
    for eps in set_of_dists_in_range:
        config['eps'] = eps
        experiment = DBSCANExperiment(config)
        results = experiment.run_experiment()

        # Use a separate name so the ground-truth ``labels`` is not shadowed.
        cluster_labels = results['exact_eps']['labels']

        # np.bincount rejects negative values (presumably -1 marks noise
        # points), so shift the labels to start at 0 and skip the first bin.
        min_label = min(cluster_labels)
        if min_label < 0:
            shifted_labels = np.array(cluster_labels) - min_label
            counts = np.bincount(shifted_labels)
            print(f"Max cluster size: {max(counts[1:]) if len(counts) > 1 else 0}")
        else:
            counts = np.bincount(cluster_labels)
            print(f"Max cluster size: {max(counts)}")

        # Evaluate once per run and reuse the value.
        # NOTE(review): assumes minimize_misalignment is deterministic -- confirm.
        misalignment = minimize_misalignment(cluster_labels, int_labels)
        print(f"Misalignment: {misalignment}")
        if misalignment < min_misalignment:
            min_misalignment = misalignment
            min_eps = eps
    print(f"Min misalignment: {min_misalignment}, Min eps: {min_eps}")

# Run the experiments only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()