import json
import os

import numpy as np
from datasets import load_dataset

# Name of the EasyTPP dataset on the Hugging Face Hub to convert
# (e.g. 'stackoverflow' or 'retweet').
dataset_name = 'stackoverflow'

# Hugging Face split name -> output JSON file stem.
split_name = {'train': 'train', 'validation': 'dev', 'test': 'test'}

# Seed for NumPy's global RNG, which drives the gap jitter below.
SEED = 123


def min_positive_dt(data):
    """Return the smallest strictly positive inter-event gap in *data*.

    The first entry of every ``time_since_last_event`` list is ignored, the
    same way ``jitter_sequence`` ignores it. Sequences without any positive
    gap (e.g. single-event sequences) are skipped instead of raising.
    Returns ``None`` when no positive gap exists anywhere in *data*.
    """
    return min(
        (
            t
            for seq in data
            for t in seq['time_since_last_event'][1:]
            if t > 0
        ),
        default=None,
    )


def jitter_sequence(user_seq, min_dt):
    """Make every inter-event gap (after the first) of one sequence positive.

    Gaps after the first that are ``<= 0`` are replaced by draws from
    ``U[0, min_dt)`` taken from NumPy's global RNG. ``time_since_start`` is
    then rebuilt as the original first timestamp plus the cumulative sum of
    the gaps. The first entries of both lists are kept as they were.

    Args:
        user_seq: One record with 'time_since_start', 'time_since_last_event',
            'type_event' and 'seq_len' fields.
        min_dt: Upper bound for replacement gaps; the script passes the value
            of ``min_positive_dt`` for the current split.

    Returns:
        ``(time_since_start, time_since_last_event, fixed)``. When no gap
        needed fixing, the original lists are returned and ``fixed`` is False.

    Raises:
        ValueError: if the time lists differ in length, if a repaired sequence
            disagrees with 'type_event' / 'seq_len', or if a gap needs fixing
            but ``min_dt`` is not a positive number.
    """
    ts = user_seq['time_since_start']
    dts = user_seq['time_since_last_event']
    if len(ts) != len(dts):
        raise ValueError(
            'time_since_start and time_since_last_event differ in length '
            f'({len(ts)} vs {len(dts)})'
        )

    gaps = np.asarray(dts[1:], dtype=float)
    # An empty `gaps` (single-event sequence) has nothing to fix.
    if not (gaps <= 0).any():
        return ts, dts, False

    if min_dt is None or min_dt <= 0:
        raise ValueError(f'cannot jitter non-positive gaps with min_dt={min_dt!r}')

    # One draw per gap; a draw is only used where the gap is non-positive.
    # Sampling from [0, min_dt) irrespective of the old value also turns
    # negative gaps positive instead of leaving them negative.
    noise = np.random.uniform(low=0.0, high=min_dt, size=gaps.shape)
    gaps = np.where(gaps > 0, gaps, noise)
    cumulative = ts[0] + np.cumsum(gaps)

    new_dts = [float(dts[0])] + [float(t) for t in gaps]
    new_ts = [float(ts[0])] + [float(t) for t in cumulative]
    if len(new_ts) != len(user_seq['type_event']) or len(new_ts) != user_seq['seq_len']:
        raise ValueError(
            f'sequence length {len(new_ts)} does not match type_event '
            f'({len(user_seq["type_event"])}) / seq_len ({user_seq["seq_len"]})'
        )
    return new_ts, new_dts, True


def main():
    """Convert every split of ``dataset_name`` into ./data/<name>_json/<split>.json."""
    dataset = load_dataset('easytpp/' + dataset_name)
    np.random.seed(SEED)

    out_dir = f'./data/{dataset_name}_json'
    # open(..., 'w') does not create missing parent directories.
    os.makedirs(out_dir, exist_ok=True)

    for split, out_name in split_name.items():
        print(f'Current split: {split}')

        data = dataset[split]
        min_dt = min_positive_dt(data)
        print(f'Min non-zero dt: {min_dt}')

        all_sequences = []
        count_dt0 = 0
        for user_seq in data:
            ts, dts, fixed = jitter_sequence(user_seq, min_dt)
            if fixed:
                count_dt0 += 1
            all_sequences.append(
                {
                    'dim_process': user_seq['dim_process'],
                    'seq_idx': user_seq['seq_idx'],
                    'seq_len': user_seq['seq_len'],
                    'time_since_start': ts,
                    'time_since_last_event': dts,
                    'type_event': user_seq['type_event'],
                }
            )
        print(f'Detect dt=0 for {count_dt0} sequences out of {len(all_sequences)}.')

        with open(f'{out_dir}/{out_name}.json', 'w', encoding='utf-8') as f:
            json.dump(all_sequences, f)


if __name__ == '__main__':
    main()
