import h5py
import numpy as np
from sklearn.cluster import DBSCAN

def load_events(file_path):
    """Read the whole ``/events`` dataset of an HDF5 file into memory.

    The file is opened read-only and closed again before returning.

    Args:
        file_path: Path of the HDF5 file to read.

    Returns:
        The full contents of the ``/events`` dataset as an in-memory array.
        NOTE(review): ``preprocess_events`` presumably expects a 2-D layout
        with columns (timestamp, x, y, polarity) — confirm against the files.
    """
    with h5py.File(file_path, 'r') as handle:
        # Slicing materialises the data before the context manager closes
        # the file, so the returned array stays valid afterwards.
        return handle['/events'][:]

def filter_sparse_events(events, eps=5, min_samples=10):
    """Keep only the events that are DBSCAN core samples in (x, y) space.

    Clustering is purely spatial: only columns 1 and 2 are used as
    coordinates. The remaining columns are carried through untouched.

    Args:
        events: 2-D array of events; columns 1 and 2 are treated as x and y.
        eps: DBSCAN neighbourhood radius, in the same units as x/y.
        min_samples: DBSCAN neighbourhood size required for a core sample.

    Returns:
        The rows of ``events`` that are core samples, in their original
        order. Both noise points and cluster *border* points are removed.
        An empty input is returned unchanged.
    """
    events = np.asarray(events)
    if len(events) == 0:
        # DBSCAN.fit raises ValueError on zero samples; nothing to filter.
        return events

    coords = events[:, 1:3]
    clustering = DBSCAN(eps=eps, min_samples=min_samples).fit(coords)

    # A boolean mask (rather than fancy indexing with the indices directly)
    # guarantees the surviving rows keep their input order.
    is_core = np.zeros(len(events), dtype=bool)
    is_core[clustering.core_sample_indices_] = True
    return events[is_core]

def _split_into_windows(events, num_steps):
    """Partition non-empty ``events`` into ``num_steps`` equal-duration time windows.

    Column 0 is taken as the timestamp. Window ``k`` covers
    ``[start + k*w, start + (k+1)*w)``, where ``w = (end - start) / num_steps``.
    The final window also includes ``end`` itself. Windows may be empty.
    Rows within a window keep their input order.
    """
    timestamps = events[:, 0]
    start_time = timestamps.min()
    time_window = (timestamps.max() - start_time) / num_steps

    if time_window > 0:
        window_idx = np.floor((timestamps - start_time) / time_window).astype(np.int64)
        # The latest event(s) land exactly on the right edge (index num_steps);
        # fold them into the last window instead of creating an extra one.
        window_idx = np.clip(window_idx, 0, num_steps - 1)
    else:
        # Zero total duration: every event shares a single window.
        window_idx = np.zeros(len(timestamps), dtype=np.int64)

    # Stable sort keeps the original order inside each window; for
    # time-sorted input it is the identity permutation.
    order = np.argsort(window_idx, kind="stable")
    boundaries = np.searchsorted(window_idx[order], np.arange(1, num_steps))
    return np.split(events[order], boundaries)


def preprocess_events(events, num_steps, min_events_threshold, filter_fn=None):
    """Split events into equal time windows and denoise each sufficiently dense one.

    Each window is assigned by its actual timestamp. A long gap in the data
    leaves the intermediate windows empty instead of shifting later
    boundaries.

    Args:
        events: 2-D array whose columns 0-3 are (timestamp, x, y, polarity).
        num_steps: Number of equal-duration windows spanning
            ``[min(timestamp), max(timestamp)]``; must be >= 1.
        min_events_threshold: Windows with fewer events than this are
            discarded entirely. Empty windows are always discarded.
        filter_fn: Callable applied to each retained window's rows (a 2-D
            array) that returns the rows to keep. Defaults to
            ``filter_sparse_events``.

    Returns:
        Tuple ``(timestamps, x_coords, y_coords, polarities)`` of 1-D arrays
        holding the surviving events, window by window.

    Raises:
        ValueError: If ``num_steps`` is less than 1.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")
    if filter_fn is None:
        filter_fn = filter_sparse_events

    events = np.asarray(events)
    if len(events) == 0:
        return tuple(np.empty(0, dtype=events.dtype) for _ in range(4))

    kept = [
        filter_fn(window)
        for window in _split_into_windows(events, num_steps)
        if len(window) > 0 and len(window) >= min_events_threshold
    ]
    if kept:
        result = np.concatenate(kept, axis=0)
    else:
        result = np.empty((0, events.shape[1]), dtype=events.dtype)

    return result[:, 0], result[:, 1], result[:, 2], result[:, 3]

# Usage example. It runs only when this file is executed as a script, so
# importing the module for its functions does not read the hard-coded file.
if __name__ == "__main__":
    file_path = '/mnt/DataDrive151/ZNC/EVENT/S005C003P021R002A060.h5'
    events = load_events(file_path)

    # print count before processing
    print(f"events before: {len(events)}")

    num_steps = 64
    min_events_threshold = 10

    (
        processed_timestamps,
        processed_x_coords,
        processed_y_coords,
        processed_polarities,
    ) = preprocess_events(events, num_steps, min_events_threshold)

    # print count after processing
    print(f"events after: {len(processed_timestamps)}")
