# NSP/datagen/generate.py
from __future__ import annotations
import argparse
import pathlib
from typing import Optional
import json
from .languages import LANGUAGES
from .generator import DataGenerator, GenConfig, write_jsonl, write_vocab


def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
    """Parse the command-line options for NSP data generation.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse fall back to ``sys.argv[1:]``, which is what the CLI
            entry point relies on; passing an explicit list allows the
            parser to be driven programmatically.

    Returns:
        The parsed namespace with attributes ``language``, ``max_context``,
        ``min_str_len``, ``max_str_len``, ``seed``, ``num_train``,
        ``num_test`` and ``data_name``.
    """
    p = argparse.ArgumentParser(description="Generate NSP training/test data.")

    # Sampling configuration.
    p.add_argument("--language", type=str, default="tomita2",
                   choices=sorted(LANGUAGES.keys()),
                   help="Language to sample positive strings from.")
    p.add_argument("--max-context", type=int, default=150,
                   help="Strict upper bound on total tokens in an example.")
    p.add_argument("--min-str-len", type=int, default=0,
                   help="Minimum length of each positive string s_i.")
    p.add_argument("--max-str-len", type=int, default=80,
                   help="Maximum length of each positive string s_i.")
    p.add_argument("--seed", type=int, default=42, help="RNG seed.")

    # Output control: the destination directory is derived from --language
    # and --data-name; a split is only produced when its count is > 0.
    p.add_argument("--num-train", type=int, default=0,
                   help="Number of training examples.")
    p.add_argument("--num-test", type=int, default=0,
                   help="Number of test examples.")
    p.add_argument("--data-name", type=str, default="check",
                   help="Subfolder name under data/<language>/<data-name>, "
                        "e.g. data1. Will write train.jsonl, test.jsonl, "
                        "vocab.json and meta.json there.")
    return p.parse_args(argv)


def build_and_generate(args: argparse.Namespace) -> None:
    """Generate the requested splits and write them to disk.

    Output goes to ``data/<language>/<data-name>/`` (created if missing):
    ``train.jsonl`` and/or ``test.jsonl`` for every split whose requested
    count is positive, plus ``vocab.json`` and ``meta.json`` unconditionally.

    Args:
        args: Parsed CLI namespace; reads ``language``, ``max_context``,
            ``min_str_len``, ``max_str_len``, ``seed``, ``num_train``,
            ``num_test`` and ``data_name``. The namespace is not modified.

    Raises:
        KeyError: If ``args.language`` is not a key of ``LANGUAGES``.
    """
    lang = LANGUAGES[args.language]

    # The language's own minimum length acts as a floor, but a larger value
    # requested via --min-str-len is honored instead of being discarded.
    # Kept in a local so the caller's namespace is left untouched.
    # NOTE(review): assumes lang.min_len is a non-negative int -- confirm.
    min_str_len = max(args.min_str_len, lang.min_len)

    cfg = GenConfig(
        language=lang,
        max_context_len=args.max_context,
        min_string_len=min_str_len,
        max_string_len=args.max_str_len,
        seed=args.seed,
    )
    gen = DataGenerator(cfg)
    meta = {
        "language": lang.name,
        "sigma": lang.sigma,
        "eos": lang.eos,
        "bos": lang.bos,
        "max_context_len": args.max_context,
        "min_string_len": min_str_len,
        "max_string_len": args.max_str_len,
        "train_num": args.num_train,
        "test_num": args.num_test,
        "seed": args.seed,
    }

    out_dir = pathlib.Path("data") / args.language / args.data_name
    out_dir.mkdir(parents=True, exist_ok=True)

    # Both splits come from the same generator instance, train first.
    # `meta` is one shared dict, so "avg_first_len" is overwritten per split:
    # each .jsonl gets the value of its own split, and meta.json ends up with
    # the value of the last split that was generated.
    if args.num_train > 0:
        train, avg_first_len = gen.generate_many(args.num_train)
        meta["avg_first_len"] = avg_first_len
        write_jsonl(out_dir / "train.jsonl", train, meta=meta)
    if args.num_test > 0:
        test, avg_first_len = gen.generate_many(args.num_test)
        meta["avg_first_len"] = avg_first_len
        write_jsonl(out_dir / "test.jsonl", test, meta=meta)

    # Always write vocab.json, regardless of which splits were requested.
    write_vocab(out_dir / "vocab.json", gen.vocab)

    # Persist the metadata on its own as well. ensure_ascii=False emits
    # non-ASCII characters verbatim, so the file encoding is pinned to UTF-8
    # rather than left to the platform's locale default.
    (out_dir / "meta.json").write_text(
        json.dumps(meta, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )
    


def main():
    """CLI entry point: parse the options, validate them, run generation.

    Exits via ``SystemExit`` with a usage hint when neither a training nor
    a test split with a positive example count was requested.
    """
    cli = parse_args()
    wants_train = cli.num_train > 0
    wants_test = cli.num_test > 0
    if not (wants_train or wants_test):
        # Nothing to generate -- stop before any output directory is created.
        raise SystemExit("Specify at least one of --num-train or --num-test > 0.")
    build_and_generate(cli)


# Run the CLI only when this module is executed as the main program, not when
# it is imported. The relative imports at the top of the file mean it has to
# be launched as a module inside its package (``python -m ...``).
if __name__ == "__main__":
    main()
