import argparse
import cProfile
import pathlib
import random
import timeit
import pstats
from xml.etree.ElementInclude import include

import torch

from recognizers.dataset_generation.generate_datasets import (
    generate_random_string,
    get_saved_language
)

def main():
    """Estimate how often a randomly generated string belongs to a saved language.

    Command-line arguments:
        --sampler: path to a .pt file containing a CFG prepared for sampling.
        --dtype: name of the torch dtype ('float16' or 'float32') handed to
            get_saved_language.
        --device: torch device handed to get_saved_language.
        --random-seed: seed for the random.Random instance that drives string
            generation.

    Draws a fixed number of strings with generate_random_string, labels each
    one with language.uncached_label, and prints the share of draws that were
    labelled as accepted. Duplicate draws are counted every time they occur.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--sampler', type=pathlib.Path, required=True,
        help='A .pt file containing a CFG prepared for sampling.')
    parser.add_argument('--dtype', choices=['float16', 'float32'], default='float16')
    parser.add_argument('--device', type=torch.device, required=True)
    parser.add_argument('--random-seed', type=int, required=True,
        help='Random seed used for random sampling.')
    args = parser.parse_args()

    # Resolve the dtype name chosen on the command line to the torch dtype object.
    dtype = getattr(torch, args.dtype)
    device = args.device

    # A dedicated Random instance keeps sampling reproducible for a given seed
    # without touching the module-level random state.
    generator = random.Random(args.random_seed)

    # Forwarded to generate_random_string as its length range; presumably
    # (min, max) string length -- confirm inclusivity against that helper.
    length_range = (0, 20)
    language = get_saved_language(args.sampler, dtype, device)
    alphabet_size = language.alphabet_size()
    repetitions = 10000

    # Only the running count is needed for the estimate, so the sampled
    # strings themselves are not retained.
    n_accepted = sum(
        1
        for _ in range(repetitions)
        if language.uncached_label(
            generate_random_string(length_range, alphabet_size, generator)
        )
    )

    mean_accepted = n_accepted / repetitions
    print(f'Percentage of random strings in the language: {mean_accepted*100:.2f}%')

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()
