import json
from pathlib import Path
from typing import Optional

import numpy as np


def _convert_sequence(
    seq_idx: int,
    time_seq: np.ndarray,
    mask_seq: np.ndarray,
    mark_seq: np.ndarray,
    dim_process: int,
) -> Optional[dict]:
    """Convert one padded sequence row into an EasyTPP record dict.

    Entries whose mask equals 1 are observed events; everything else is padding.

    Returns:
        The record dict, or None when the sequence has no observed events
        (an empty sequence has no first event to anchor the time gaps).

    Raises:
        ValueError: if the mask is not a run of 1s followed only by 0s, or if
            the observed event times decrease.
    """
    observed = mask_seq == 1

    # Once a 0 appears in the mask, all subsequent values must be 0 (trailing padding only).
    if not np.all(observed) and not np.all(mask_seq[np.argmax(mask_seq == 0):] == 0):
        raise ValueError(f"Mask sequence at index {seq_idx} is invalid: zeros are not continuous.")

    observed_times = time_seq[observed]
    observed_marks = mark_seq[observed]
    if observed_times.size == 0:
        return None

    # Gap to the previous event; prepending the first time makes the first gap 0.
    time_since_last_event = np.diff(observed_times, prepend=observed_times[0])

    # Times must be non-decreasing; simultaneous events (gap 0) are allowed.
    if np.any(time_since_last_event < 0):
        raise ValueError(f"Negative time difference found in sequence {seq_idx}.")

    return {
        "dim_process": dim_process,
        "seq_idx": seq_idx,
        "seq_len": len(observed_times),
        "time_since_start": observed_times.tolist(),
        "time_since_last_event": time_since_last_event.tolist(),
        "type_event": observed_marks.tolist(),
    }


def main(data_dir: str = 'reddit', seed: Optional[int] = None) -> None:
    """Convert ``<data_dir>/reddit.npz`` into EasyTPP-style JSON splits.

    The archive must contain the arrays ``times``, ``masks`` and ``marks`` with
    one sequence per row. An entry is an observed event where its mask equals
    1, and padding otherwise. Each sequence with at least one observed event
    becomes one record dict. The records are shuffled and written as a
    60/20/20 split to ``train.json``, ``dev.json`` and ``test.json`` inside
    ``data_dir``; the test split also receives the rounding remainder.

    Args:
        data_dir: Directory holding ``reddit.npz``; the JSON files are written
            there as well. Defaults to ``'reddit'``.
        seed: Seed for the shuffle. ``None`` (default) gives a different split
            on each run; pass an int for a reproducible split.

    Raises:
        ValueError: if the three arrays differ in shape, no event is observed,
            the observed marks are not exactly ``0..K-1``, a mask has
            non-trailing padding, or event times decrease within a sequence.
    """
    data_path = Path(data_dir)

    # The context manager closes the archive; indexing an NpzFile reads each
    # array fully into memory, so the arrays stay usable after the block.
    with np.load(data_path / 'reddit.npz') as data:
        times = data['times']
        masks = data['masks']
        marks = data['marks']

    if not (times.shape == masks.shape == marks.shape):
        raise ValueError(
            f"Array shapes differ: times {times.shape}, masks {masks.shape}, marks {marks.shape}."
        )

    # Validate marks on observed events only: padded entries may hold any
    # placeholder value and must not affect the 0-indexing check or dim_process.
    unique_marks = np.unique(marks[masks == 1])
    if unique_marks.size == 0:
        raise ValueError("No observed events found: every mask entry is 0.")
    # np.unique is sorted, so equality with 0..K-1 means 0-indexed with no gaps.
    if not np.array_equal(unique_marks, np.arange(unique_marks.size)):
        raise ValueError(f"Marks are not 0-indexed or some event types are missing. Found marks: {unique_marks}")
    dim_process = int(unique_marks.size)

    easytpp_data = []
    for seq_idx in range(times.shape[0]):
        seq_dict = _convert_sequence(
            seq_idx, times[seq_idx], masks[seq_idx], marks[seq_idx], dim_process
        )
        # All-padding sequences carry no events and are skipped.
        if seq_dict is not None:
            easytpp_data.append(seq_dict)

    # Shuffle before splitting so the splits do not follow file order.
    order = np.random.default_rng(seed).permutation(len(easytpp_data))
    easytpp_data = [easytpp_data[i] for i in order]

    total_sequences = len(easytpp_data)
    train_size = int(0.6 * total_sequences)
    dev_size = int(0.2 * total_sequences)
    splits = {
        'train': easytpp_data[:train_size],
        'dev': easytpp_data[train_size:train_size + dev_size],
        'test': easytpp_data[train_size + dev_size:],
    }

    for split_name, split_data in splits.items():
        with open(data_path / f'{split_name}.json', 'w', encoding='utf-8') as out_file:
            json.dump(split_data, out_file, indent=0)


if __name__ == "__main__":
    # Script entry point: run the conversion only when executed directly, not
    # on import. main() returns None, so SystemExit(None) exits with status 0.
    raise SystemExit(main())
