import argparse
import numpy as np

from sequence_utils import build_char_vocab, encode_text, run_lm_task, try_load_ptb_text


def main(argv=None) -> None:
    """Run the PTB character-level language-model experiment from the CLI.

    Parses the command-line options, loads the PTB text splits, encodes them
    with a character vocabulary built from the training split, optionally
    truncates every split, and passes the result to ``run_lm_task``.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) reads the process
            command line, exactly as before this parameter existed.

    Raises:
        SystemExit: Via ``parser.error`` (exit status 2) when ``--gains``
            contains a non-numeric entry or ``--ptb-max-chars`` is negative.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="ptb_char")
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--scan-epochs", type=int, default=3)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--hidden", type=int, default=128)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--ptb-path", type=str, default=None)
    parser.add_argument("--ptb-block-size", type=int, default=80)
    parser.add_argument("--ptb-steps-per-epoch", type=int, default=300)
    parser.add_argument("--ptb-val-steps", type=int, default=60)
    parser.add_argument("--ptb-max-chars", type=int, default=None)
    parser.add_argument("--eval-every", type=int, default=5)
    parser.add_argument("--gains", type=str, default=None)
    parser.add_argument("--step-labels", type=str, choices=["final", "fptt"], default="final")
    parser.add_argument("--no-plot", action="store_true")
    parser.add_argument("--no-eprop", action="store_true")
    parser.add_argument("--plot-path", type=str, default=None)
    args = parser.parse_args(argv)
    # Expose the positive form of --no-plot on the namespace handed to
    # run_lm_task.
    args.plot = not args.no_plot

    # A negative value would be used as a slice end below and silently drop
    # the *last* N tokens instead of capping the length, so reject it before
    # any data is loaded.
    if args.ptb_max_chars is not None and args.ptb_max_chars < 0:
        parser.error("--ptb-max-chars must be non-negative")

    # Default sweep: 8 evenly spaced gains starting at 0.5, with the 1.6
    # endpoint excluded. It keeps np.linspace's default dtype, whereas
    # user-supplied gains are float32; both are preserved from the original.
    gains_default = np.linspace(0.5, 1.6, 8, endpoint=False)
    gains = gains_default
    if args.gains:
        tokens = [tok.strip() for tok in args.gains.split(",")]
        try:
            # Blank entries (e.g. from a trailing comma) are skipped.
            values = [float(tok) for tok in tokens if tok]
        except ValueError:
            # Report malformed input as a usage error, not a traceback.
            parser.error(
                f"--gains must be a comma-separated list of numbers, got {args.gains!r}"
            )
        # A spec with no usable entries (e.g. ",") falls back to the default.
        if values:
            gains = np.array(values, dtype=np.float32)

    train_text, valid_text, test_text = try_load_ptb_text(args.ptb_path)
    # The vocabulary is built from the training split only and reused to
    # encode the validation and test splits.
    stoi, _ = build_char_vocab(train_text)
    train_ids = encode_text(train_text, stoi)
    valid_ids = encode_text(valid_text, stoi)
    test_ids = encode_text(test_text, stoi)

    if args.ptb_max_chars is not None:
        # The same cap is applied to every split.
        limit = args.ptb_max_chars
        train_ids = train_ids[:limit]
        valid_ids = valid_ids[:limit]
        test_ids = test_ids[:limit]

    vocab_size = len(stoi)
    task_data = {
        "task_type": "lm",
        "task_name": "PTB Char LM",
        "train_data": train_ids,
        "valid_data": valid_ids,
        "test_data": test_ids,
        "input_size": vocab_size,
        "output_size": vocab_size,
    }

    # NOTE(review): both stage banners are printed before run_lm_task starts;
    # presumably run_lm_task performs both stages itself — confirm.
    print("STAGE 1: Scanning gains for PTB char LM...")
    print("STAGE 2: Local Rule vs BPTT/E-Prop/FPTT on PTB char LM...")
    run_lm_task(task_data, args, gains)


# Run the experiment only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()
