"""
Notes on the data formats and script to map our dataset into the EasyTPP pkl format.

# pkl is a dictionary with keys: dict_keys(['dim_process', 'test'])
# data['dim_process'] = 17
# len(data['test']) -> 500  # 500 test sequences, list of list of dictionary
# len(data['test'][0]) -> 62  # first seq length: 62
# data['test'][0] -> list of dictionary
# data['test'][0][0] -> {'time_since_last_event': 0, 'time_since_start': 151205.8087, 'type_event': 16}
# data['test'][0][1] -> {'time_since_last_event': 0.0011999999987892, 'time_since_start': 151205.8099, 'type_event': 0}

# {'idx_event': 1,  # some don't have it
#  'type_event': 8,
#  'time_since_start': 0.0,
#  'time_since_last_event': 0.0}



# In our format, each user has a dictionary with: dict_keys(['user', 'T', 'times', 'marks'])
# data[0]['T'] -> 192.0
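# 'times' does not have to start at 0 or end at 'T'; the script below pads such sequences with boundary events of type -1
# 'marks' -> event types, one per entry of 'times' (written out as 'type_event')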
"""

import pickle
import json

# with open('train.pkl', 'rb') as f:  # inspect the data format supported by EasyTPP
#     data = pickle.load(f)
# print(data)


# mapping our data
dataset_name = 'taobao'
name_mapping = {'train': 'train', 'valid': 'dev', 'test': 'test'}  # the EasyTPP repo takes 'dev' as dict key


def build_event_sequence(times, marks, T):
    """Turn one user's event times and marks into an EasyTPP event list.

    A boundary event with mark -1 is added at time 0 when the first event
    happens after 0, and another one at time ``T`` when the last event
    happens before ``T``.  The input lists are left untouched.

    Args:
        times: Event timestamps of one user.  They are used in the given
            order; sortedness is assumed and not checked.
        marks: Event types, one per entry of ``times``.
        T: End of the observation window.

    Returns:
        A list of dicts with the keys 'time_since_last_event',
        'time_since_start' and 'type_event', one dict per (padded) event.

    Raises:
        ValueError: If ``times`` and ``marks`` differ in length, if the
            sequence is empty, if the first time is negative, or if the
            last time lies after ``T``.
    """
    if len(times) != len(marks):
        # Without this check a mismatch would silently pair events with the
        # wrong marks (or drop marks) instead of failing.
        raise ValueError('got {} times but {} marks'.format(len(times), len(marks)))
    if not times:
        raise ValueError('sequence has no events')

    # Work on copies so the padding below never mutates the caller's lists.
    times, marks = list(times), list(marks)

    if times[0] < 0:
        raise ValueError('first event time {} is negative'.format(times[0]))
    if times[0] > 0:
        times.insert(0, 0)
        marks.insert(0, -1)

    if times[-1] > T:
        raise ValueError('last event time {} exceeds T={}'.format(times[-1], T))
    if times[-1] < T:
        times.append(T)
        marks.append(-1)

    # The first event has no predecessor, so its inter-event time is 0.
    user_seq = [{'time_since_last_event': 0., 'time_since_start': times[0], 'type_event': marks[0]}]
    for previous, current, mark in zip(times, times[1:], marks[1:]):
        user_seq.append({'time_since_last_event': current - previous,
                         'time_since_start': current, 'type_event': mark})
    return user_seq


def convert_split(split, dataset=dataset_name, dim_process=1000):
    """Convert ``<split>_<dataset>.jsonl`` into ``<split>_<dataset>.pkl``.

    Each non-blank line of the input file is a JSON object with the keys
    'times', 'marks' and 'T'.  The pickle holds a dict with 'dim_process'
    and the list of converted sequences stored under ``name_mapping[split]``.

    Args:
        split: One of the keys of ``name_mapping``.
        dataset: Dataset name used in the input and output file names.
        dim_process: Value stored under the 'dim_process' key.

    Returns:
        The dictionary that was written to the pickle file.

    Raises:
        ValueError: If a line holds an invalid sequence; the message names
            the file and the line number.
    """
    target_key = name_mapping[split]
    data = {'dim_process': dim_process}  # new dictionary for each split
    data[target_key] = []

    source_path = f'{split}_{dataset}.jsonl'
    with open(source_path, 'rb') as f:
        for line_number, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing empty line
            user_dict = json.loads(line)
            try:
                user_seq = build_event_sequence(user_dict['times'], user_dict['marks'], user_dict['T'])
            except ValueError as err:
                raise ValueError(f'{source_path}, line {line_number}: {err}') from err
            data[target_key].append(user_seq)

    with open(f'{split}_{dataset}.pkl', 'wb') as f:
        pickle.dump(data, f)
    return data


def main():
    """Convert every train/validation/test split of the dataset."""
    for split in name_mapping:
        print('Current split: {}'.format(split))
        convert_split(split)


if __name__ == '__main__':
    main()
