import argparse
import json
import pathlib
import random
import timeit

import torch

from recognizers.language_sampling.sample_automaton import (
    sample_nf_td_pda
)


def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=torch.device, required=True)
    parser.add_argument('--num-repetitions', type=int, default=1)
    parser.add_argument('--output', type=pathlib.Path, required=True)
    args = parser.parse_args()

    dtype = torch.float32
    device = args.device
    print(f'device: {device}')
    num_repetitions = args.num_repetitions

    durations = []
    is_first = True
    for non_terminals_number in range(2, 4+1, 2):
        print(f"Number of non-terminals in the underlying CNF CFG: {non_terminals_number}")
        generator = random.Random(123)

        def func():
            M = sample_nf_td_pda(3, non_terminals_number, generator)
        
        if is_first:
            print('warming up')
            func()
            is_first = False
    
        total_duration = timeit.timeit(func, number=num_repetitions)
        mean_duration = total_duration / num_repetitions
        durations.append(dict(non_terminals_number=non_terminals_number, duration_in_seconds=mean_duration))

    with args.output.open('w') as fout:
        json.dump(dict(
            #max_length=max_length,
            #non_terminals_number=non_terminals_number,
            device=device.type,
            durations=durations
        ), fout, indent=2)

# Only run the benchmark when executed as a script, not when imported.
if __name__ == "__main__":
    main()
