import json
from pathlib import Path

import numpy as np
import pandas as pd

# Requested number of quantile bins for inter-event time gaps (passed as
# `q` to pd.qcut with duplicates='drop', so fewer bins may actually result).
QUANTIZE_BINS = 10
# Number of events per output sequence; make_windows discards any trailing
# events that do not fill a complete window.
WINDOW_SIZE = 256


def make_windows(x: np.ndarray, window_size: int) -> list[np.ndarray]:
    """Split ``x`` into consecutive, non-overlapping chunks of ``window_size``.

    Elements at the end of ``x`` that cannot fill a complete chunk are
    discarded. Each chunk is a slice of ``x`` (a view when ``x`` is a NumPy
    array).

    Args:
        x: Sequence to split.
        window_size: Length of every returned chunk.

    Returns:
        The list of full-length chunks, in order.
    """
    n_windows = len(x) // window_size
    return [
        x[k * window_size:(k + 1) * window_size]
        for k in range(n_windows)
    ]

def process(name: str) -> None:
    """Convert one LOBSTER message file into train/dev/test event sequences.

    Pipeline:
      1. read ``lobster/{name}/message.csv`` and keep rows whose ``Type`` is
         1 or 3;
      2. encode each ``Direction``/``Type`` combination as an integer mark;
      3. quantile-bin the inter-event time gaps and cumulatively sum the bin
         indices into discretized timestamps;
      4. cut the stream into non-overlapping windows of ``WINDOW_SIZE``
         events (a trailing partial window is dropped);
      5. write the first 60% of windows to ``train.json``, the next 20% to
         ``dev.json`` and the remainder to ``test.json`` under
         ``../data/lobster/{name}/`` (the directory is created if missing).

    Both paths are resolved relative to the current working directory.

    Args:
        name: Dataset sub-directory name, e.g. ``'msft'``.
    """
    messages = _load_messages(name)
    events = _build_events(messages)
    _write_splits(events, Path(f'../data/lobster/{name}'))


def _load_messages(name: str) -> pd.DataFrame:
    """Read ``lobster/{name}/message.csv`` and keep only Type 1 and 3 rows."""
    df = pd.read_csv(
        f'lobster/{name}/message.csv',
        names=['Time', 'Type', 'OrderID', 'Size', 'Price', 'Direction'],
    )
    # .copy() yields an independent frame; without it, adding columns to the
    # boolean-filtered result triggers pandas' SettingWithCopyWarning.
    return df[df['Type'].isin([1, 3])].copy()


def _build_events(df: pd.DataFrame) -> list[dict]:
    """Convert filtered messages into one event-sequence record per window."""
    # Mark = "<Direction>_<Type>" string, integer-coded via the categorical.
    df['Type_Direction'] = (
        df['Direction'].astype(str) + '_' + df['Type'].astype(str)
    ).astype('category')
    df['Type_Direction_Code'] = df['Type_Direction'].cat.codes

    # The first row has no predecessor; its gap is set to 0.
    df['Diff'] = df['Time'].diff().fillna(0)
    # labels=False -> integer bin indices. duplicates='drop' merges repeated
    # quantile edges, so fewer than QUANTIZE_BINS bins may be produced.
    # NOTE(review): bin edges are fitted on the whole file, including rows
    # that later land in dev/test; consider fitting on the train part only.
    df['Diff_Quantized'] = pd.qcut(
        df['Diff'], q=QUANTIZE_BINS, duplicates='drop', labels=False,
    )

    # Running sum of bin indices serves as a discretized timestamp.
    times = df['Diff_Quantized'].to_numpy().cumsum()
    marks = df['Type_Direction_Code'].to_numpy()
    # int() guarantees a plain Python int for json.dump.
    dim = int(df['Type_Direction_Code'].nunique())

    time_windows = make_windows(times, WINDOW_SIZE)
    mark_windows = make_windows(marks, WINDOW_SIZE)

    events = []
    for seq_idx, (t, x) in enumerate(zip(time_windows, mark_windows)):
        events.append({
            "dim_process": dim,
            "seq_idx": seq_idx,
            "seq_len": len(t),
            "time_since_start": (t - t[0]).tolist(),
            # The first event of each window is assigned a gap of 0.
            "time_since_last_event": np.concatenate([[0], np.diff(t)]).tolist(),
            "type_event": x.tolist(),
        })
    return events


def _write_splits(events: list[dict], save_dir: Path) -> None:
    """Write a chronological 60/20/20 train/dev/test split of ``events``."""
    total_events = len(events)
    train_end = int(total_events * 0.6)
    val_end = int(total_events * 0.8)
    splits = {
        "train.json": events[:train_end],
        "dev.json": events[train_end:val_end],
        "test.json": events[val_end:],
    }

    # The output directory may not exist yet; open() would otherwise fail
    # with FileNotFoundError.
    save_dir.mkdir(parents=True, exist_ok=True)
    for filename, split in splits.items():
        with open(save_dir / filename, "w", encoding="utf-8") as f:
            json.dump(split, f)


if __name__ == '__main__':
    # Datasets to convert. Other LOBSTER datasets that can be listed here:
    # 'aapl', 'goog', 'amzn', 'intc'.
    DATASETS = ('msft',)
    for dataset in DATASETS:
        process(dataset)
