import argparse
import cProfile
import pathlib
import random
import timeit
import pstats

import torch

from recognizers.dataset_generation.generate_datasets import get_saved_language, generate_example

def main():
    """Benchmark and profile ``generate_example`` on a saved sampler.

    Command-line options:
      --sampler      .pt file containing an automaton prepared for sampling.
      --dtype        float16 (default) or float32; passed to get_saved_language.
      --output       File receiving the profile report and mean duration
                     (default: benchmarking-output.txt).
      --device       torch device the language is loaded onto (required).
      --repetitions  Number of timed calls to generate_example (default 10000).
      --seed         Seed for the single random generator shared by all
                     repetitions (default 123).

    The script times ``--repetitions`` calls with ``timeit`` while ``cProfile``
    records them. It then writes the profile, sorted by cumulative time,
    followed by a final ``Mean duration: ...`` line to ``--output``.

    Note: the timing runs while the profiler is active, so the reported mean
    includes cProfile's per-call overhead. Treat it as an upper bound on the
    unprofiled cost.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--sampler', type=pathlib.Path, required=True,
        help='A .pt file containing an automaton prepared for sampling.')
    parser.add_argument('--dtype', choices=['float16', 'float32'], default='float16')
    parser.add_argument('--output', type=pathlib.Path, default='benchmarking-output.txt')
    parser.add_argument('--device', type=torch.device, required=True)
    parser.add_argument('--repetitions', type=int, default=10000,
        help='Number of calls to generate_example to time.')
    parser.add_argument('--seed', type=int, default=123,
        help='Seed for the random generator shared by all repetitions.')
    args = parser.parse_args()
    # Guard the division below and reject meaningless benchmark sizes.
    if args.repetitions < 1:
        parser.error('--repetitions must be at least 1')

    dtype = getattr(torch, args.dtype)
    device = args.device

    length_range = (0, 100)
    language = get_saved_language(args.sampler, dtype, device)
    restricted_language = language.with_length_range(length_range)
    alphabet_size = language.alphabet_size()
    excluded_strings = None
    include_edit_distance = False

    # Seed ONE generator up front. The run is still reproducible, but each
    # repetition continues the random stream. Re-seeding inside func would
    # presumably regenerate the identical example on every call, so the mean
    # would describe a single sample rather than the length distribution.
    generator = random.Random(args.seed)

    def func():
        # Result is discarded; only the cost of generation is measured.
        generate_example(
            restricted_language,
            length_range,
            alphabet_size,
            only_negative=False,
            perturbation_probability=0.5,
            strict_num_edits_distribution=False,
            include_log_probability=False,
            include_next_symbols=False,
            include_edit_distance=include_edit_distance,
            generator=generator,
            excluded_strings=excluded_strings,
        )

    profiler = cProfile.Profile()
    # runcall profiles timeit.timeit and hands back its return value (total
    # seconds). This avoids exec-ing a code string against locals().
    total_duration = profiler.runcall(timeit.timeit, func, number=args.repetitions)
    mean_duration = total_duration / args.repetitions

    with open(args.output, 'w', encoding='utf-8') as f:
        stats = pstats.Stats(profiler, stream=f)
        stats.strip_dirs()
        stats.sort_stats('cumulative')
        stats.print_stats()
        f.write(f'Mean duration: {mean_duration:.5f}\n')

if __name__ == "__main__":
    # Entry point: run the benchmark only when executed as a script, not on import.
    main()
