#!/usr/bin/env python
# coding: utf-8

# In[3]:


import pandas as pd
import numpy as np
import os, glob, json, yaml, io, math


# In[4]:


# Base directory holding all data, and the sub-directory with the TPP datasets.
root = './data'
datasets_dir = 'TPP'

root = os.path.join(root, datasets_dir)
# Explicit list of dataset folders (relative to `root`) to process.
# The directory is intentionally not scanned with os.listdir: its result was
# immediately overwritten by this list, so the scan only cost I/O and raised
# FileNotFoundError early when the data directory was absent.
datasets = ['citibike', 'covid19', 'earthquakes', 'hawkes_1_continuous_v1', 'hawkes_2_continuous_v1', 'poisson_continuous_v1', 
            'self_correct_continuous_v1', 'stationary_renewal_continuous_v1']


def diff(per_line, prepend = np._NoValue, append = np._NoValue):
    '''
    Return the first-order differences of `per_line` via `np.diff`.

    `prepend` and `append` are forwarded unchanged to `np.diff`; leaving them
    at their defaults behaves like calling `np.diff(per_line)` directly.
    Having a named wrapper lets it be passed to `Series.apply` together with
    keyword arguments.
    '''
    boundary_values = {'prepend': prepend, 'append': append}
    return np.diff(per_line, **boundary_values)

# In[5]:


# For every dataset: read its card and training split, compute the mean and
# standard deviation of the inter-event times and of the event coordinates,
# and write those statistics back into the dataset card.
for dataset in datasets:
    print(f'Processing {dataset}...')
    dataset_dir = os.path.join(root, dataset)
    card_path = os.path.join(dataset_dir, 'dataset_card.yml')

    # Load the property card. It is read as UTF-8 because this script writes
    # it back as UTF-8 below; `with` closes the handle even if parsing fails.
    with open(card_path, 'r', encoding = 'utf8') as f_property:
        dataset_card = yaml.safe_load(f_property)
    start_time = dataset_card['t_0']
    end_time = dataset_card['T']

    file_name = 'train.json'
    with open(os.path.join(dataset_dir, file_name), 'r', encoding = 'utf8') as f_data:
        dataset_json = json.load(f_data)

    # The mean and standard deviation of time.
    # Each time sequence is turned into consecutive differences, with
    # start_time prepended and end_time appended before differencing.
    df_dataset = pd.DataFrame.from_dict(dataset_json)
    df_dataset.time_seq = df_dataset.time_seq.apply(np.array, dtype = np.float32)
    df_dataset.time_seq = df_dataset.time_seq.apply(diff, prepend = start_time, append = end_time)
    # Tiny offset; presumably so that no interval is exactly zero downstream.
    df_dataset.time_seq = df_dataset.time_seq + 1e-30

    # Pool the intervals of all sequences, dropping the last one of each
    # sequence (the difference up to end_time). Concatenating once avoids
    # re-copying the growing array per sequence; the leading empty float64
    # array keeps the result identical to the incremental version, including
    # when there are no sequences at all.
    time_interval = np.concatenate(
        [np.array([])] + [item[:-1] for item in df_dataset.time_seq.tolist()]
    )

    mean_time = time_interval.mean()
    std_time = time_interval.std()

    # The mean and standard deviation of coordinates.
    # Flatten the per-sequence event lists into one list of events, then take
    # the statistics along axis 0 (i.e. across events).
    coordinates = np.array(
        [event for event_seq in df_dataset.event for event in event_seq]
    )
    mean_event = coordinates.mean(axis = 0).tolist()
    std_event = coordinates.std(axis = 0).tolist()

    # Release the large intermediates before the next dataset is loaded.
    del coordinates, time_interval, df_dataset

    # .item() converts the NumPy scalars to plain Python floats for YAML.
    dataset_card['mean_time'] = mean_time.item()
    dataset_card['std_time'] = std_time.item()
    dataset_card['mean_coordinate'] = mean_event
    dataset_card['std_coordinate'] = std_event

    # Overwrite the card in place with the added statistics.
    with io.open(card_path, 'w', encoding = 'utf8') as outfile:
        yaml.dump(dataset_card, outfile, default_flow_style=False, allow_unicode=True)