import argparse
import random
import sys
import torch

from recognizers.language_sampling.sample_cnf_cfg import sample_cnf_cfg
from recognizers.language_sampling.sample_dfa import sample_dfa
from recognizers.language_sampling.sample_dot_depth import sample_dot_depth
from recognizers.language_sampling.sample_podfa import sample_podfa
from recognizers.dataset_generation.weighted_language import FiniteLanguageError, PODFALanguageError


def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--language-class', choices=[
        'regular',
        'context-free',
        'podfa',
        'star-free'
    ], required=True)
    parser.add_argument('--random-seed', type=int, required=True)
    parser.add_argument('--mean-num-states', type=int)
    parser.add_argument('--mean-alphabet-size', type=int)
    parser.add_argument('--mean-num-variables', type=int)
    parser.add_argument('--mean-num-lexical-rules', type=int)
    parser.add_argument('--mean-num-binary-rules', type=int)
    parser.add_argument('--output', required=True)
    args = parser.parse_args()

    generator = random.Random(args.random_seed)
    match args.language_class:
        case 'regular':
            try:
                mean_num_states = args.mean_num_states if args.mean_num_states else 10
                mean_alphabet_size = args.mean_alphabet_size if args.mean_alphabet_size else 10
                language = sample_dfa(
                    mean_num_states=mean_num_states,
                    mean_alphabet_size=mean_alphabet_size,
                    transition_probability=0.8,
                    accept_probability=0.4,
                    generator=generator
                )
            except PODFALanguageError as e:
                print(f'error: {e}', file=sys.stderr)
                sys.exit(1)
        case 'context-free':
            try:
                mean_num_variables = args.mean_num_variables if args.mean_num_variables else 20
                mean_alphabet_size = args.mean_alphabet_size if args.mean_alphabet_size else 20
                mean_num_lexical_rules = args.mean_num_lexical_rules if args.mean_num_lexical_rules else 35
                mean_num_binary_rules = args.mean_num_binary_rules if args.mean_num_binary_rules else 20
                language = sample_cnf_cfg(
                    mean_num_variables=mean_num_variables,
                    mean_num_terminals=mean_alphabet_size,
                    mean_num_lexical_rules=mean_num_lexical_rules,
                    mean_num_binary_rules=mean_num_binary_rules,
                    mean_num_chains=3,
                    mean_chain_length=5,
                    generator=generator
                )
            except FiniteLanguageError as e:
                print(f'error: {e}', file=sys.stderr)
                sys.exit(1)
        case 'podfa':
            mean_num_states = args.mean_num_states if args.mean_num_states else 10
            mean_alphabet_size = args.mean_alphabet_size if args.mean_alphabet_size else 10
            language = sample_podfa(
                mean_num_states=mean_num_states,
                mean_alphabet_size=mean_alphabet_size,
                transition_probability=0.5,
                accept_probability=0.4,
                generator=generator
            )
        case 'star-free':
            language, depth = sample_dot_depth(
                mean_alphabet_size=10,
                mean_depth=10,
                max_bool_ops=500,
                max_concat_ops=500,
                accept_probability=0.4,
                generator=generator
            )
        case _:
            raise ValueError
    if args.language_class == 'star-free':
        torch.save(dict(
            random_seed=args.random_seed,
            language_class=args.language_class,
            language=language,
            depth=depth
        ), args.output)
    else:
        torch.save(dict(
            random_seed=args.random_seed,
            language_class=args.language_class,
            language=language
        ), args.output)

# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()