# Merge the step data spread over multiple files, drop duplicate expressions, shuffle it, and split it into train/test/validation files.
import argparse
import random
import logging
import time
import os, io

date_time_idx = time.strftime("%Y%m%d-%H%M%S")
# basicConfig(filename=...) creates a FileHandler that opens the file right away,
# so a missing log directory would crash the script at import time.
_LOG_DIR = 'alpha_integrate/synthetic_data/shufflelogs'
os.makedirs(_LOG_DIR, exist_ok=True)
logging.basicConfig(filename = f'{_LOG_DIR}/logs_{date_time_idx}.txt', filemode='w', level=logging.INFO)


def _iter_step_groups(f):
    """Yield the blank-line-separated groups of lines in the open text file *f*.

    A separator is a line that is exactly a newline; runs of separators yield
    nothing. The trailing group is yielded even when the file does not end
    with a blank line, and every yielded line ends with a newline so lines
    cannot fuse when they are later concatenated.
    """
    group = []
    for line in f:
        if line == '\n':
            if group:
                yield group
                group = []
            continue
        if not line.endswith('\n'):
            # Only the last line of a file lacking a final newline gets here.
            line += '\n'
        group.append(line)
    # Flush the last group: otherwise a file without a trailing blank line
    # silently loses its final expression.
    if group:
        yield group


def _read_unique_groups(directory, files):
    """Collect step groups from *files*, keeping only the first group per expression.

    The expression key is the text of a group's first line before the first
    double tab; later groups with an already-seen key are counted and skipped,
    so the order of *files* decides which copy is kept.

    Returns:
        (all_steps, duplicate_dict, num_duplicates): the kept groups in reading
        order, a key -> occurrence-count mapping, and the number of skipped groups.
    """
    all_steps = []
    duplicate_dict = {}
    num_duplicates = 0
    for file in files:
        data_path = os.path.join(directory, file)
        with io.open(data_path, mode='r', encoding='utf-8') as f:
            for group in _iter_step_groups(f):
                step_expr_str = group[0].split('\t\t')[0]
                if step_expr_str in duplicate_dict:
                    duplicate_dict[step_expr_str] += 1
                    num_duplicates += 1
                else:
                    duplicate_dict[step_expr_str] = 1
                    all_steps.append(group)
    return all_steps, duplicate_dict, num_duplicates


def _write_lines(file_path, lines):
    """Write *lines* verbatim to *file_path* as UTF-8 and return how many were written."""
    with io.open(file_path, mode='w', encoding='utf-8') as f:
        f.writelines(lines)
    return len(lines)


def main(args):
    """Merge, deduplicate, shuffle and split the step files of one dataset.

    Reads every regular file in
    ``alpha_integrate/synthetic_data/steps_dataset/<dataset>/``, ordered by the
    integer between the last ``_`` and the following ``.`` of the file name.
    It parses the blank-line-separated groups of step lines (one group per
    expression) and keeps only the first group seen for each expression. The
    groups are then shuffled and split into three files under
    ``alpha_integrate/synthetic_data/final_steps_dataset/<dataset>/``:

    * ``train.txt`` - lines of the first 95% of groups, flattened and shuffled
      line by line (no group separators);
    * ``test.txt`` - groups from 95% to 97%, each kept intact and followed by
      a blank line;
    * ``val.txt`` - the remaining groups, flattened and shuffled like train.

    Args:
        args: parsed command-line namespace; only ``args.dataset`` is read.

    Raises:
        ValueError: if a file name's part after the last ``_`` (up to the
            first ``.``) is not an integer.
    """
    dataset = args.dataset
    directory = f'alpha_integrate/synthetic_data/steps_dataset/{dataset}/'
    # NOTE(review): this creates the *input* directory when it is missing, which
    # then produces empty output files instead of an error - confirm intended.
    os.makedirs(directory, exist_ok=True)
    files = [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]
    # Numeric rather than lexicographic order, so "x_10.txt" sorts after "x_2.txt".
    files = sorted(files, key=lambda x: int(x.split('_')[-1].split('.')[0]))

    logging.info(f"Reading from directory: {directory}")

    TRAIN_SPLIT = 0.95     # groups [0, 95%) -> train
    TEST_VAL_SPLIT = 0.97  # groups [95%, 97%) -> test, [97%, 100%] -> val
    N_SHUFFLE = 5

    all_steps, duplicate_dict, num_duplicates = _read_unique_groups(directory, files)

    # Log the 20 most repeated expressions; the stable sort keeps ties in first-seen order.
    logging.info("Top 20 repeated elements:")
    most_repeated = sorted(duplicate_dict.items(), key=lambda item: item[1], reverse=True)
    for key, value in most_repeated[:20]:
        logging.info(f"{key}: {value}")

    logging.info(f"Obtained steps for {len(all_steps)} expressions. Found {num_duplicates} duplicates.")
    logging.info(f"Shuffling all steps {N_SHUFFLE} times")

    t0 = time.time()
    # NOTE: repeated shuffles do not make the permutation any more random; they
    # are kept so the sequence of random-module calls is unchanged.
    for _ in range(N_SHUFFLE):
        random.shuffle(all_steps)
    t1 = time.time()

    logging.info(f"Shuffling done, {(t1 - t0)/N_SHUFFLE} seconds per shuffle")

    train_split = int(len(all_steps) * TRAIN_SPLIT)
    test_val_split = int(len(all_steps) * TEST_VAL_SPLIT)
    train_groups = all_steps[:train_split]
    test_groups = all_steps[train_split:test_val_split]
    val_groups = all_steps[test_val_split:]

    path = f'alpha_integrate/synthetic_data/final_steps_dataset/{dataset}/'
    os.makedirs(path, exist_ok=True)

    # Train and val are flattened to individual step lines and shuffled again,
    # so those files carry no per-expression grouping.
    train_data = [step for group in train_groups for step in group]
    for _ in range(N_SHUFFLE):
        random.shuffle(train_data)

    val_data = [step for group in val_groups for step in group]
    for _ in range(N_SHUFFLE):
        random.shuffle(val_data)

    # Test keeps each expression's steps together, separated by a blank line.
    test_data = []
    for group in test_groups:
        test_data.extend(group)
        test_data.append('\n')
    test_steps = len(test_data) - len(test_groups)

    train_steps = _write_lines(os.path.join(path, 'train.txt'), train_data)
    _write_lines(os.path.join(path, 'test.txt'), test_data)
    val_steps = _write_lines(os.path.join(path, 'val.txt'), val_data)

    logging.info(f"Total number of steps: {sum(len(s) for s in all_steps)}")
    logging.info(f"Train steps: {train_steps}")
    logging.info(f"Validation steps: {val_steps}")
    logging.info(f"Test expressions: {len(test_groups)}")
    logging.info(f"Test steps: {test_steps}")


if __name__ == '__main__':
    # Command-line entry point: the only option is the dataset name, which
    # selects the input and output sub-directories used by main().
    parser = argparse.ArgumentParser(
        description='Merge, deduplicate and shuffle the step files of a dataset '
                    'and split them into train/test/validation files.'
    )
    parser.add_argument('--dataset', type=str, required=True, help='Name of the dataset to process')
    main(parser.parse_args())