import tomita_generator as tomita
import argparse
from concurrent.futures import ProcessPoolExecutor
import random
from utils import *
from tqdm import tqdm


##
#python create_negative_samples.py --seed 1234 --num_samples 1000 --grammar 1 --min_len 1 --max_len 50 --fpath ../data/MSI/Tomita-1/train_neg_src.txt
##

def get_params():
    """Parse the command-line options for Tomita data generation.

    Returns:
        argparse.Namespace: holds the integer options ``seed``, ``grammar``,
        ``min_len``, ``max_len`` and ``num_samples``, the required string
        ``fpath``, and ``exclude`` (a list of strings, empty when the flag
        is not given).
    """
    cli = argparse.ArgumentParser(description='tomita data generation')

    # Integer-valued flags paired with their defaults; registration order is
    # kept so the generated --help text lists them in the same sequence.
    int_options = (
        ("--seed", 1234),
        ("--grammar", 1),
        ("--min_len", 3),
        ("--max_len", 20),
        ("--num_samples", 4),
    )
    for flag, fallback in int_options:
        cli.add_argument(flag, type=int, default=fallback)

    cli.add_argument("--fpath", type=str, required=True)
    cli.add_argument("--exclude", nargs='+', default=[])

    return cli.parse_args()


if __name__ == "__main__":
    # Script entry point: draw candidate strings until ``--num_samples``
    # distinct strings have been collected that the selected grammar accepts
    # (``belongs_to_lang``) and that are not in the excluded data, then write
    # them to ``--fpath`` one per line.
    #
    # NOTE(review): the usage comment at the top of this file names
    # create_negative_samples.py and a *_neg_* output path, yet the filter
    # below keeps strings for which ``belongs_to_lang`` is true -- confirm
    # which of the two is intended.
    args = get_params()

    # An empty length range would otherwise surface as an obscure IndexError
    # from random.choice; fail early with a readable message instead.
    if args.min_len > args.max_len:
        raise SystemExit(
            "--min_len ({}) must not exceed --max_len ({})".format(
                args.min_len, args.max_len))

    random.seed(args.seed)

    # Resolve the grammar class by number, e.g. --grammar 1 -> Tomita1Language.
    klass_name = "Tomita{}Language".format(args.grammar)
    klass = getattr(tomita, klass_name)
    # The meaning of the two constructor arguments is defined in
    # tomita_generator and is not visible from this file.
    grammar = klass(0.4, 0.4)

    # ``seen`` provides fast duplicate detection, while ``strings`` records
    # the order in which samples were accepted. Writing from the list (rather
    # than iterating a set of str) keeps the line order of the output file
    # independent of per-process string hash randomisation.
    seen = set()
    strings = []

    excluded_strings = load_excluded_data(args.exclude)

    lengths = range(args.min_len, args.max_len + 1)
    use_random_strings = args.grammar in [5, 6]

    # NOTE: the loop only ends once enough eligible strings were found; it
    # will not terminate if fewer than ``num_samples`` distinct, non-excluded
    # strings of the requested lengths are accepted by the grammar.
    with tqdm(total=args.num_samples) as pbar:
        while len(strings) < args.num_samples:
            ll = random.choice(lengths)
            if use_random_strings:
                string = get_random_string(ll)
            else:
                string = grammar.generate_string(ll, ll)

            # Skip duplicates up front so the progress bar advances only
            # when a genuinely new sample has been accepted.
            if string in seen:
                continue

            if grammar.belongs_to_lang(string) and string not in excluded_strings:
                seen.add(string)
                strings.append(string)
                pbar.update(1)

    with open(args.fpath, "w") as ff:
        ff.writelines(ss + "\n" for ss in strings)