import argparse

import torch

from recognizers.automata.finite_automaton import FiniteAutomaton
from recognizers.grammars.context_free_grammar import ContextFreeGrammar
from recognizers.string_sampling.finite_automaton.weight_lifting import (
    lift_finite_automaton_weights
)
from recognizers.string_sampling.finite_automaton.weight_pushing import (
    push_finite_automaton_weights
)
from recognizers.string_sampling.cnf_cfg.weight_lifting import (
    lift_cnf_cfg_weights
)
from recognizers.string_sampling.cnf_cfg.weight_pushing import (
    push_cnf_cfg_weights
)

def main() -> None:

    parser = argparse.ArgumentParser()
    parser.add_argument('--input', required=True)
    parser.add_argument('--output', required=True)
    parser.add_argument('--max-length', type=int, required=True)
    parser.add_argument('--dtype', choices=['float16', 'float32'], default='float32')
    parser.add_argument('--device', type=torch.device, required=True)
    args = parser.parse_args()

    dtype = getattr(torch, args.dtype)
    device = args.device

    data = torch.load(args.input, weights_only=False)
    language = data.pop('language')
    match language:
        case FiniteAutomaton():
            wfa = lift_finite_automaton_weights(
                language,
                args.max_length,
                dtype,
                device
            )
            sampler = push_finite_automaton_weights(
                wfa,
                dtype,
                device
            )
        case ContextFreeGrammar():
            wcfg = lift_cnf_cfg_weights(
                language,
                args.max_length,
                dtype,
                device
            )
            sampler = push_cnf_cfg_weights(
                wcfg,
                dtype,
                device
            )
        case _:
            raise ValueError
    data['sampler'] = sampler
    torch.save(data, args.output)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()
