import argparse
import random
import logging
import time
import os, io

date_time_idx = time.strftime("%Y%m%d-%H%M%S")
# Each run logs to its own timestamped file. The directory is created first
# because logging.basicConfig opens the file immediately and would raise
# FileNotFoundError on a fresh checkout where the directory does not exist.
_log_directory = 'alpha_integrate/synthetic_data/shufflelogs'
os.makedirs(_log_directory, exist_ok=True)
logging.basicConfig(filename=f'{_log_directory}/logs_{date_time_idx}.txt', filemode='w', level=logging.INFO)


def _sorted_data_files(directory):
    """Return the paths of the regular files in *directory*, in numeric order.

    Files are ordered by the integer found between the last ``_`` and the
    following ``.`` of their name (``steps_12.txt`` -> 12). A name with no
    integer in that position raises ``ValueError``.
    """
    names = [
        name for name in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, name))
    ]
    names.sort(key=lambda name: int(name.split('_')[-1].split('.')[0]))
    return [os.path.join(directory, name) for name in names]


def _iter_step_groups(data_path):
    """Yield each blank-line-separated group of step lines in *data_path*.

    Each group is a non-empty list of the raw lines (newlines included) found
    between blank lines; consecutive blank lines yield nothing. The final group
    is yielded even when the file does not end with a blank line, and if its
    last line has no trailing newline one is added. Without that newline the
    line would merge with the next one when the groups are re-written.
    """
    group = []
    with io.open(data_path, mode='r', encoding='utf-8') as f:
        for line in f:
            if line != '\n':
                group.append(line)
            elif group:
                yield group
                group = []
    # Flush the trailing group: the file need not end with a blank line.
    if group:
        if not group[-1].endswith('\n'):
            group[-1] += '\n'
        yield group


def _flatten_and_shuffle(groups, n_shuffle):
    """Concatenate the step lines of *groups* and shuffle them *n_shuffle* times."""
    lines = [line for group in groups for line in group]
    for _ in range(n_shuffle):
        random.shuffle(lines)
    return lines


def main(args):
    """Merge, de-duplicate, shuffle and split step datasets into train/test/val.

    For every name in ``args.datasets``, every file in
    ``alpha_integrate/synthetic_data/steps_dataset/<name>/`` is read as groups
    of step lines separated by blank lines. A group is keyed by the text
    before the first double-tab of its first line. Only the first group seen
    for a key is kept; later ones are counted as duplicates.

    The kept groups are shuffled and split roughly 95% / 2% / 3% into
    train / test / val. The results are written to
    ``alpha_integrate/synthetic_data/final_steps_dataset/<a+b+...>/``:

    * ``train.txt`` and ``val.txt`` hold the step lines of their groups,
      shuffled line by line, so group boundaries are not kept.
    * ``test.txt`` keeps each group intact, followed by a blank line.

    Progress and counts are reported through ``logging.info``.

    Args:
        args: namespace whose ``datasets`` attribute is a list of dataset
            directory names.
    """
    datasets = args.datasets
    base_directory = 'alpha_integrate/synthetic_data/steps_dataset/'

    # Create a combined directory name from all datasets
    combined_dataset_name = '+'.join(datasets)
    combined_directory = f'alpha_integrate/synthetic_data/final_steps_dataset/{combined_dataset_name}/'
    os.makedirs(combined_directory, exist_ok=True)

    TRAIN_SPLIT = 0.95     # fraction of groups before the test slice
    TEST_VAL_SPLIT = 0.97  # groups past this fraction go to validation
    N_SHUFFLE = 5

    all_steps = []
    num_duplicates = 0
    duplicate_dict = dict()  # group key -> number of times it was seen

    for dataset in datasets:
        directory = os.path.join(base_directory, dataset)
        data_paths = _sorted_data_files(directory)

        logging.info(f"Reading from directory: {directory}")

        for data_path in data_paths:
            for group in _iter_step_groups(data_path):
                step_expr_str = group[0].split('\t\t')[0]
                if step_expr_str in duplicate_dict:
                    duplicate_dict[step_expr_str] += 1
                    num_duplicates += 1
                else:
                    duplicate_dict[step_expr_str] = 1
                    all_steps.append(group)

    logging.info("Top 20 repeated elements:")
    most_repeated = sorted(duplicate_dict.items(), key=lambda item: item[1], reverse=True)[:20]
    for key, value in most_repeated:
        logging.info(f"{key}: {value}")

    logging.info(f"Obtained steps for {len(all_steps)} expressions from datasets {datasets}. Found {num_duplicates} duplicates.")
    logging.info(f"Shuffling all steps {N_SHUFFLE} times")

    t0 = time.time()
    for _ in range(N_SHUFFLE):
        random.shuffle(all_steps)
    t1 = time.time()

    logging.info(f"Shuffling done, {(t1 - t0)/N_SHUFFLE} seconds per shuffle")

    train_split = int(len(all_steps) * TRAIN_SPLIT)
    test_val_split = int(len(all_steps) * TEST_VAL_SPLIT)

    # Train is flattened and shuffled before val, matching the original order
    # of random calls.
    train_data = _flatten_and_shuffle(all_steps[:train_split], N_SHUFFLE)
    val_data = _flatten_and_shuffle(all_steps[test_val_split:], N_SHUFFLE)
    test_groups = all_steps[train_split:test_val_split]

    with io.open(os.path.join(combined_directory, 'train.txt'), mode='w', encoding='utf-8') as f:
        f.writelines(train_data)

    with io.open(os.path.join(combined_directory, 'test.txt'), mode='w', encoding='utf-8') as f:
        for group in test_groups:
            f.writelines(group)
            f.write('\n')

    with io.open(os.path.join(combined_directory, 'val.txt'), mode='w', encoding='utf-8') as f:
        f.writelines(val_data)

    logging.info(f"Total number of steps: {sum(len(s) for s in all_steps)}")
    logging.info(f"Train steps: {len(train_data)}")
    logging.info(f"Validation steps: {len(val_data)}")
    logging.info(f"Test steps: {sum(len(g) for g in test_groups)}")


if __name__ == '__main__':
    # Command-line entry point: collect the dataset names and hand them to main().
    cli = argparse.ArgumentParser(
        description='Shuffle and dump the data in multiple files into one file.',
    )
    cli.add_argument(
        '--datasets',
        type=str,
        nargs='+',
        required=True,
        help='List of dataset names to process',
    )
    main(cli.parse_args())
