import h5py
import numpy as np

def preprocess_events(timestamps, x_coords, y_coords, polarities, num_steps):
    """Return the event streams as arrays together with the per-step window width.

    An earlier version buffered events window by window, but every buffered
    event was flushed unchanged, so the net effect was only to copy the four
    streams (in their original order) into new NumPy arrays. This version makes
    that copy directly instead of looping over every event in Python.

    Args:
        timestamps: Event timestamps (any array-like accepted by ``np.array``).
        x_coords: Event x coordinates, same length as ``timestamps``.
        y_coords: Event y coordinates, same length as ``timestamps``.
        polarities: Event polarities, same length as ``timestamps``.
        num_steps: Number of equal-width windows the span
            ``[min(timestamps), max(timestamps)]`` is divided into.

    Returns:
        ``(timestamps, x_coords, y_coords, polarities, time_window)``.
        The first four are new arrays (copies of the inputs, order
        preserved). ``time_window`` equals
        ``(max(timestamps) - min(timestamps)) / num_steps``.

    Raises:
        ValueError: if ``num_steps`` is not positive, the stream is empty, or
            the four streams have different lengths.
    """
    if num_steps <= 0:
        raise ValueError(f"num_steps must be positive, got {num_steps}")

    t = np.array(timestamps)
    x = np.array(x_coords)
    y = np.array(y_coords)
    p = np.array(polarities)

    if t.size == 0:
        raise ValueError("cannot preprocess an empty event stream")
    if not (len(t) == len(x) == len(y) == len(p)):
        raise ValueError(
            f"event streams differ in length: t={len(t)}, x={len(x)}, "
            f"y={len(y)}, p={len(p)}"
        )

    # Only the extremes are needed; np.unique would sort the whole stream.
    time_window = (t.max() - t.min()) / num_steps
    return t, x, y, p, time_window

def integrate_events_segment_to_frame(x: np.ndarray, y: np.ndarray, p: np.ndarray, H: int, W: int, j_l: int = 0, j_r: "int | None" = None) -> np.ndarray:
    """Accumulate a slice of events into a 2-channel per-pixel event-count frame.

    Channel 0 counts events whose polarity equals 0; channel 1 counts every
    other event.

    Args:
        x: Event x coordinates; converted with ``astype(int)`` (fractional
            parts are truncated) before indexing.
        y: Event y coordinates; converted the same way.
        p: Event polarities, same length as ``x`` and ``y``.
        H: Frame height.
        W: Frame width.
        j_l: First event index of the half-open slice ``[j_l, j_r)``.
        j_r: End index of the slice. ``None`` (the default) means "to the
            end". The old default ``-1`` silently dropped the last event;
            passing ``-1`` explicitly still has that Python-slice meaning.

    Returns:
        float64 array of shape ``(2, H, W)`` holding event counts.

    Raises:
        ValueError: if any selected coordinate lies outside the ``H x W`` frame.
    """
    frame = np.zeros(shape=[2, H * W])
    x = x[j_l:j_r].astype(int)  # widen before y * W + x to avoid overflow
    y = y[j_l:j_r].astype(int)
    p = p[j_l:j_r]

    # Without this check, x >= W silently wraps into the next row and other
    # out-of-range values fail deep inside bincount/indexing.
    if x.size and (x.min() < 0 or x.max() >= W or y.min() < 0 or y.max() >= H):
        raise ValueError(
            f"event coordinates out of range for a {H}x{W} frame: "
            f"x in [{x.min()}, {x.max()}], y in [{y.min()}, {y.max()}]"
        )

    is_zero_polarity = p == 0
    for c, mask in enumerate((is_zero_polarity, ~is_zero_polarity)):
        position = y[mask] * W + x[mask]
        # minlength makes the count vector exactly H * W long, even for empty masks.
        frame[c] += np.bincount(position, minlength=H * W)
    return frame.reshape((2, H, W))

def create_event_sequence(timestamps, x_coords, y_coords, polarities, height, width, num_steps, time_window):
    """Split events into ``num_steps`` consecutive time windows and stack one frame per window.

    Window ``k`` covers ``[t0 + k*time_window, t0 + (k+1)*time_window)`` with
    ``t0 = timestamps.min()``. The final window is closed on the right, so an
    event at exactly ``t0 + num_steps*time_window`` is counted rather than
    dropped; that event is the latest one when ``time_window = span / num_steps``.
    Window edges are computed in floating point, so for ``num_steps`` that are
    not powers of two the final edge can still round a hair below the latest
    event. If ``time_window`` is 0, every window except the last is empty and
    all events land in the last frame.

    Each window is turned into a ``(2, height, width)`` count frame by
    ``integrate_events_segment_to_frame``.

    Returns:
        Array of shape ``(num_steps, 2, height, width)``. An empty event stream
        yields all-zero frames.

    Raises:
        ValueError: if ``num_steps`` is not positive.
    """
    if num_steps <= 0:
        raise ValueError(f"num_steps must be positive, got {num_steps}")

    timestamps = np.asarray(timestamps)
    x_coords = np.asarray(x_coords)
    y_coords = np.asarray(y_coords)
    polarities = np.asarray(polarities)

    if timestamps.size == 0:
        return np.zeros((num_steps, 2, height, width))

    start_time = timestamps.min()
    event_sequence = []
    for step in range(num_steps):
        window_start = start_time + step * time_window
        window_end = window_start + time_window

        if step == num_steps - 1:
            below_end = timestamps <= window_end  # close the last window
        else:
            below_end = timestamps < window_end
        mask = (timestamps >= window_start) & below_end

        x_segment = x_coords[mask]
        # Pass the slice end explicitly: a default of j_r=-1 would drop the
        # last event of every window.
        frame = integrate_events_segment_to_frame(
            x_segment, y_coords[mask], polarities[mask], height, width, 0, len(x_segment)
        )
        event_sequence.append(frame)

    return np.stack(event_sequence)

# Load one event recording and turn it into a stack of per-window event-count frames.
# NOTE(review): assumes '/events' is an (N, 4) array with columns (t, x, y, p);
# confirm against the recording format.
file_path = '/mnt/DataDrive151/ZNC/EVENT/S005C003P021R002A060.h5'
with h5py.File(file_path, 'r') as h5_file:
    # Read the dataset into memory so the file is closed before the heavy processing.
    events = h5_file['/events'][:]

timestamps = events[:, 0]
x_coords = events[:, 1]
y_coords = events[:, 2]
polarities = events[:, 3]

max_x, max_y = np.max(x_coords), np.max(y_coords)
min_x, min_y = np.min(x_coords), np.min(y_coords)
print(f"X range: {min_x} to {max_x}")
print(f"Y range: {min_y} to {max_y}")

height, width = 480, 640
num_steps = 16

# Fail early with a clear message instead of silently wrapping pixels or
# crashing inside the frame integration.
if min_x < 0 or min_y < 0 or max_x >= width or max_y >= height:
    raise ValueError(
        f"event coordinates (x: {min_x}..{max_x}, y: {min_y}..{max_y}) "
        f"do not fit a {width}x{height} frame"
    )

processed_timestamps, processed_x_coords, processed_y_coords, processed_polarities, time_window = preprocess_events(
    timestamps, x_coords, y_coords, polarities, num_steps)

event_sequence = create_event_sequence(processed_timestamps, processed_x_coords, processed_y_coords, processed_polarities, height, width, num_steps, time_window)

print(f"event sequence shape: {event_sequence.shape}")
print(event_sequence)
print(f"event sequence max: {event_sequence.max()}")

# event_sequence is a (num_steps, 2, height, width) = (16, 2, 480, 640) array of per-pixel event counts
