from datasets import load_dataset
import numpy as np
import json

# Datasets to summarise, in report order. 'amazon', 'taxi' and 'taobao' come
# from the Hugging Face hub; every other name is read from local JSON files.
datasets = [
    'amazon',
    'retweet',
    'taxi',
    'taobao',
    'stackoverflow',
    'lastfm',
    'mimic',
    'ehr',
]
# Alternative subset: ['ehr', 'nlb', 'lastfm']

# Directory holding the local JSON export (<split>.json) for each dataset.
datasets_path = dict(
    retweet='../data/retweet_json/',
    stackoverflow='../data/stackoverflow_json/',
    ehr='../data/ehrshot_all/ehrshot_cpt4/',
    nlb='../data/nlb_1rep/',
    lastfm='../data/lastfm_json/',
    mimic='../data/mimicii/',
)

# Datasets fetched from the Hugging Face hub as 'easytpp/<name>'. They use a
# 'validation' split; the local JSON exports use 'dev' instead.
HF_DATASETS = ('taxi', 'taobao', 'amazon')
HF_SPLITS = ('train', 'validation', 'test')
JSON_SPLITS = ('train', 'dev', 'test')


def _load_splits(dataset):
    """Yield ``(split_name, sequences)`` for every split of *dataset*.

    Hub datasets are downloaded once with ``load_dataset`` and then indexed
    per split (previously the whole dataset was reloaded for every split).
    Other datasets are read from ``datasets_path[dataset] + '<split>.json'``.
    Each yielded ``sequences`` supports ``len()`` and iterates dict-like
    records with at least ``'type_event'`` and ``'dim_process'`` keys.
    """
    if dataset in HF_DATASETS:
        hub_data = load_dataset('easytpp/' + dataset)
        for split in HF_SPLITS:
            yield split, hub_data[split]
    else:
        for split in JSON_SPLITS:
            with open(datasets_path[dataset] + f'{split}.json', 'r', encoding='utf-8') as f:
                split_data = json.load(f)
            # Yield after closing the file so it is not held open while the
            # caller processes the split.
            yield split, split_data


def _report_dataset(dataset):
    """Print per-split counts, then sequence-length stats pooled over all splits.

    For each split: the total number of events, ``dim_process`` (reported as
    num_marks) and the number of sequences. Afterwards: min/max/mean sequence
    length across every split of the dataset.
    """
    seq_len = []
    for split, data in _load_splits(dataset):
        print(f'split: {split}')
        num_events = 0
        # Like the original, report dim_process from the last sequence of the
        # split; stays None for an empty split instead of leaking a stale value.
        num_marks = None
        for seq in data:
            n = len(seq['type_event'])
            num_events += n
            seq_len.append(n)
            num_marks = seq['dim_process']
        print(f'num_events: {num_events}')
        # Double quotes outside, single inside: nesting the same quote type
        # in an f-string is a SyntaxError before Python 3.12.
        print(f"num_marks: {num_marks}")
        print(f'num_seqs: {len(data)}')
    if not seq_len:
        # min()/max() raise ValueError on an empty list.
        print('no sequences found')
        return
    print(f'min seq_len: {min(seq_len)}')
    print(f'max seq_len: {max(seq_len)}')
    print(f'mean seq_len: {np.mean(np.array(seq_len))}')


for dataset in datasets:
    print('\n' + dataset)
    _report_dataset(dataset)

# TODO: also report statistics for the dts datasets, if they are needed.