import argparse
import json
import pathlib
import random
import timeit

import torch

from recognizers.string_sampling.finite_automaton_weight_pushing import (
    push_finite_automaton_weights
)
from recognizers.string_sampling.prepare_sampler import (
    lift_finite_automaton
)
from recognizers.language_sampling.sample_automaton import (
    sample_finite_automaton
)

def _synchronize(device):
    """Block until all queued work on ``device`` has finished.

    CUDA operations are launched asynchronously, so wall-clock timing
    without a synchronization point measures little more than kernel launch
    overhead. For non-CUDA devices this is a no-op.
    """
    if device.type == 'cuda':
        torch.cuda.synchronize(device)


def main():
    """Benchmark finite-automaton weight pushing across automaton sizes.

    For each state count 10, 20, ..., 200, this samples a random automaton
    from a fixed seed, lifts it with ``lift_finite_automaton`` for the
    requested device, and times ``push_finite_automaton_weights`` with
    ``timeit``. The mean duration per call for each size is written as JSON
    to ``--output``, along with the run's settings.

    Command-line arguments:
        --device: torch device to run on (required).
        --num-repetitions: timed calls per automaton size (default 1, must
            be >= 1).
        --output: path of the JSON results file (required).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=torch.device, required=True)
    parser.add_argument('--num-repetitions', type=int, default=1)
    parser.add_argument('--output', type=pathlib.Path, required=True)
    args = parser.parse_args()
    if args.num_repetitions < 1:
        # The mean duration divides by this value. Reject it before any
        # expensive sampling/lifting work instead of crashing afterwards.
        parser.error('--num-repetitions must be at least 1')

    max_length = 500
    dtype = torch.float32
    device = args.device
    print(f'device: {device}')
    num_repetitions = args.num_repetitions

    durations = []
    is_first = True
    for num_states in range(10, 200 + 1, 10):
        print(f'{num_states} states')
        # Re-seed for every size so each automaton is reproducible regardless
        # of which other sizes were benchmarked before it.
        generator = random.Random(123)
        # NOTE(review): the literal 2 is passed straight to
        # sample_finite_automaton; presumably the alphabet size -- confirm.
        M = sample_finite_automaton(num_states, 2, generator)
        M = lift_finite_automaton(M, max_length, dtype, device)
        # Drain any asynchronous work queued by lifting so it is not billed
        # to the first timed call.
        _synchronize(device)

        def func():
            push_finite_automaton_weights(M, dtype, device)
            # Wait for the device to finish so timeit measures the whole
            # computation, not just the kernel launches.
            _synchronize(device)

        if is_first:
            # A single untimed call before the first measurement, so one-off
            # start-up costs are not included in the timings.
            print('warming up')
            func()
            is_first = False

        total_duration = timeit.timeit(func, number=num_repetitions)
        mean_duration = total_duration / num_repetitions
        durations.append(dict(num_states=num_states, duration_in_seconds=mean_duration))

    with args.output.open('w') as fout:
        json.dump(dict(
            max_length=max_length,
            num_repetitions=num_repetitions,
            device=device.type,
            durations=durations
        ), fout, indent=2)

if __name__ == "__main__":
    # Run the benchmark only when this file is executed as a script,
    # not when it is imported.
    main()
